fix dumb beta sampling bug
This commit is contained in:
parent
4dad518d3f
commit
5e39c386f7
88
squiggle.c
88
squiggle.c
|
@ -74,46 +74,48 @@ float sample_to(float low, float high, uint32_t* seed)
|
||||||
return sample_lognormal(logmean, logsigma, seed);
|
return sample_lognormal(logmean, logsigma, seed);
|
||||||
}
|
}
|
||||||
|
|
||||||
float sample_gamma(float alpha, uint32_t* seed){
|
float sample_gamma(float alpha, uint32_t* seed)
|
||||||
|
{
|
||||||
|
|
||||||
// A Simple Method for Generating Gamma Variables, Marsaglia and Wan Tsang, 2001
|
// A Simple Method for Generating Gamma Variables, Marsaglia and Wan Tsang, 2001
|
||||||
// https://dl.acm.org/doi/pdf/10.1145/358407.358414
|
// https://dl.acm.org/doi/pdf/10.1145/358407.358414
|
||||||
// see also the references/ folder
|
// see also the references/ folder
|
||||||
if(alpha >=1){
|
if (alpha >= 1) {
|
||||||
float d, c, x, v, u;
|
float d, c, x, v, u;
|
||||||
d = alpha - 1.0/3.0;
|
d = alpha - 1.0 / 3.0;
|
||||||
c = 1.0/sqrt(9.0 * d);
|
c = 1.0 / sqrt(9.0 * d);
|
||||||
while(1){
|
while (1) {
|
||||||
|
|
||||||
do {
|
|
||||||
x = sample_unit_normal(seed);
|
|
||||||
v = 1.0 + c * x;
|
|
||||||
} while(v <= 0.0);
|
|
||||||
|
|
||||||
v = pow(v, 3);
|
do {
|
||||||
u = sample_unit_uniform(seed);
|
x = sample_unit_normal(seed);
|
||||||
if( u < 1.0 - 0.0331 * pow(x, 4)){ // Condition 1
|
v = 1.0 + c * x;
|
||||||
// the 0.0331 doesn't inspire much confidence
|
} while (v <= 0.0);
|
||||||
// however, this isn't the whole story
|
|
||||||
// by knowing that Condition 1 implies condition 2
|
v = pow(v, 3);
|
||||||
// we realize that this is just a way of making the algorithm faster
|
u = sample_unit_uniform(seed);
|
||||||
// i.e., of not using the logarithms
|
if (u < 1.0 - 0.0331 * pow(x, 4)) { // Condition 1
|
||||||
return d*v;
|
// the 0.0331 doesn't inspire much confidence
|
||||||
}
|
// however, this isn't the whole story
|
||||||
if(log(u) < 0.5*pow(x,2) + d*(1.0 - v + log(v))){ // Condition 2
|
// by knowing that Condition 1 implies condition 2
|
||||||
return d*v;
|
// we realize that this is just a way of making the algorithm faster
|
||||||
}
|
// i.e., of not using the logarithms
|
||||||
}
|
return d * v;
|
||||||
}else{
|
}
|
||||||
return sample_gamma(1 + alpha, seed) * pow(sample_unit_uniform(seed), 1/alpha);
|
if (log(u) < 0.5 * pow(x, 2) + d * (1.0 - v + log(v))) { // Condition 2
|
||||||
// see note in p. 371 of https://dl.acm.org/doi/pdf/10.1145/358407.358414
|
return d * v;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return sample_gamma(1 + alpha, seed) * pow(sample_unit_uniform(seed), 1 / alpha);
|
||||||
|
// see note in p. 371 of https://dl.acm.org/doi/pdf/10.1145/358407.358414
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
float sample_beta(float a, float b, uint32_t* seed){
|
float sample_beta(float a, float b, uint32_t* seed)
|
||||||
float gamma_a = sample_gamma(a, seed);
|
{
|
||||||
float gamma_b = sample_gamma(b, seed);
|
float gamma_a = sample_gamma(a, seed);
|
||||||
return a / (a + b);
|
float gamma_b = sample_gamma(b, seed);
|
||||||
|
return gamma_a / (gamma_a + gamma_b);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Array helpers
|
// Array helpers
|
||||||
|
@ -134,18 +136,20 @@ void array_cumsum(float* array_to_sum, float* array_cumsummed, int length)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
float array_mean(float* array, int length){
|
float array_mean(float* array, int length)
|
||||||
float sum = array_sum(array, length);
|
{
|
||||||
return sum / length;
|
float sum = array_sum(array, length);
|
||||||
|
return sum / length;
|
||||||
}
|
}
|
||||||
|
|
||||||
float array_std(float* array, int length){
|
float array_std(float* array, int length)
|
||||||
float mean = array_mean(array, length);
|
{
|
||||||
|
float mean = array_mean(array, length);
|
||||||
float std = 0.0;
|
float std = 0.0;
|
||||||
for (int i = 0; i < length; i++) {
|
for (int i = 0; i < length; i++) {
|
||||||
std += pow(array[i] - mean, 2.0);
|
std += pow(array[i] - mean, 2.0);
|
||||||
}
|
}
|
||||||
std=sqrt(std/length);
|
std = sqrt(std / length);
|
||||||
return std;
|
return std;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user