forked from personal/squiggle.c
		
	fix dumb beta sampling bug
This commit is contained in:
		
							parent
							
								
									4dad518d3f
								
							
						
					
					
						commit
						5e39c386f7
					
				
							
								
								
									
										38
									
								
								squiggle.c
									
									
									
									
									
								
							
							
						
						
									
										38
									
								
								squiggle.c
									
									
									
									
									
								
							|  | @ -74,46 +74,48 @@ float sample_to(float low, float high, uint32_t* seed) | ||||||
|     return sample_lognormal(logmean, logsigma, seed); |     return sample_lognormal(logmean, logsigma, seed); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| float sample_gamma(float alpha, uint32_t* seed){ | float sample_gamma(float alpha, uint32_t* seed) | ||||||
|  | { | ||||||
| 
 | 
 | ||||||
|     // A Simple Method for Generating Gamma Variables, Marsaglia and Wan Tsang, 2001
 |     // A Simple Method for Generating Gamma Variables, Marsaglia and Wan Tsang, 2001
 | ||||||
|     // https://dl.acm.org/doi/pdf/10.1145/358407.358414
 |     // https://dl.acm.org/doi/pdf/10.1145/358407.358414
 | ||||||
|     // see also the references/ folder
 |     // see also the references/ folder
 | ||||||
| 	if(alpha >=1){ |     if (alpha >= 1) { | ||||||
|         float d, c, x, v, u; |         float d, c, x, v, u; | ||||||
| 		d = alpha - 1.0/3.0; |         d = alpha - 1.0 / 3.0; | ||||||
| 		c = 1.0/sqrt(9.0 * d); |         c = 1.0 / sqrt(9.0 * d); | ||||||
| 		while(1){ |         while (1) { | ||||||
| 
 | 
 | ||||||
|             do { |             do { | ||||||
|                 x = sample_unit_normal(seed); |                 x = sample_unit_normal(seed); | ||||||
|                 v = 1.0 + c * x; |                 v = 1.0 + c * x; | ||||||
| 			} while(v <= 0.0); |             } while (v <= 0.0); | ||||||
| 
 | 
 | ||||||
|             v = pow(v, 3); |             v = pow(v, 3); | ||||||
|             u = sample_unit_uniform(seed); |             u = sample_unit_uniform(seed); | ||||||
| 			if( u < 1.0 - 0.0331 * pow(x, 4)){ // Condition 1
 |             if (u < 1.0 - 0.0331 * pow(x, 4)) { // Condition 1
 | ||||||
|                 // the 0.0331 doesn't inspire much confidence
 |                 // the 0.0331 doesn't inspire much confidence
 | ||||||
|                 // however, this isn't the whole story
 |                 // however, this isn't the whole story
 | ||||||
|                 // by knowing that Condition 1 implies condition 2
 |                 // by knowing that Condition 1 implies condition 2
 | ||||||
|                 // we realize that this is just a way of making the algorithm faster
 |                 // we realize that this is just a way of making the algorithm faster
 | ||||||
|                 // i.e., of not using the logarithms
 |                 // i.e., of not using the logarithms
 | ||||||
| 					return d*v; |                 return d * v; | ||||||
|             } |             } | ||||||
| 			if(log(u) < 0.5*pow(x,2) + d*(1.0 - v + log(v))){ // Condition 2
 |             if (log(u) < 0.5 * pow(x, 2) + d * (1.0 - v + log(v))) { // Condition 2
 | ||||||
| 				return d*v; |                 return d * v; | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
| 	}else{ |     } else { | ||||||
| 		return sample_gamma(1 + alpha, seed) * pow(sample_unit_uniform(seed), 1/alpha); |         return sample_gamma(1 + alpha, seed) * pow(sample_unit_uniform(seed), 1 / alpha); | ||||||
|         // see note in p. 371 of https://dl.acm.org/doi/pdf/10.1145/358407.358414
 |         // see note in p. 371 of https://dl.acm.org/doi/pdf/10.1145/358407.358414
 | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| float sample_beta(float a, float b, uint32_t* seed){ | float sample_beta(float a, float b, uint32_t* seed) | ||||||
|  | { | ||||||
|     float gamma_a = sample_gamma(a, seed); |     float gamma_a = sample_gamma(a, seed); | ||||||
|     float gamma_b = sample_gamma(b, seed); |     float gamma_b = sample_gamma(b, seed); | ||||||
| 	return a / (a + b); |     return gamma_a / (gamma_a + gamma_b); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Array helpers
 | // Array helpers
 | ||||||
|  | @ -134,18 +136,20 @@ void array_cumsum(float* array_to_sum, float* array_cumsummed, int length) | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| float array_mean(float* array, int length){ | float array_mean(float* array, int length) | ||||||
|  | { | ||||||
|     float sum = array_sum(array, length); |     float sum = array_sum(array, length); | ||||||
|     return sum / length; |     return sum / length; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| float array_std(float* array, int length){ | float array_std(float* array, int length) | ||||||
|  | { | ||||||
|     float mean = array_mean(array, length); |     float mean = array_mean(array, length); | ||||||
|     float std = 0.0; |     float std = 0.0; | ||||||
|     for (int i = 0; i < length; i++) { |     for (int i = 0; i < length; i++) { | ||||||
|         std += pow(array[i] - mean, 2.0); |         std += pow(array[i] - mean, 2.0); | ||||||
|     } |     } | ||||||
| 		std=sqrt(std/length); |     std = sqrt(std / length); | ||||||
|     return std; |     return std; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user