From 5e39c386f7e73efca69df35fb36ddbc91618ed22 Mon Sep 17 00:00:00 2001 From: NunoSempere Date: Sun, 23 Jul 2023 09:29:00 +0200 Subject: [PATCH] fix dumb beta sampling bug --- squiggle.c | 88 ++++++++++++++++++++++++++++-------------------------- 1 file changed, 46 insertions(+), 42 deletions(-) diff --git a/squiggle.c b/squiggle.c index 2a68d9b..58dbeb0 100644 --- a/squiggle.c +++ b/squiggle.c @@ -74,46 +74,48 @@ float sample_to(float low, float high, uint32_t* seed) return sample_lognormal(logmean, logsigma, seed); } -float sample_gamma(float alpha, uint32_t* seed){ +float sample_gamma(float alpha, uint32_t* seed) +{ - // A Simple Method for Generating Gamma Variables, Marsaglia and Wan Tsang, 2001 - // https://dl.acm.org/doi/pdf/10.1145/358407.358414 - // see also the references/ folder - if(alpha >=1){ - float d, c, x, v, u; - d = alpha - 1.0/3.0; - c = 1.0/sqrt(9.0 * d); - while(1){ - - do { - x = sample_unit_normal(seed); - v = 1.0 + c * x; - } while(v <= 0.0); + // A Simple Method for Generating Gamma Variables, Marsaglia and Wan Tsang, 2001 + // https://dl.acm.org/doi/pdf/10.1145/358407.358414 + // see also the references/ folder + if (alpha >= 1) { + float d, c, x, v, u; + d = alpha - 1.0 / 3.0; + c = 1.0 / sqrt(9.0 * d); + while (1) { - v = pow(v, 3); - u = sample_unit_uniform(seed); - if( u < 1.0 - 0.0331 * pow(x, 4)){ // Condition 1 - // the 0.0331 doesn't inspire much confidence - // however, this isn't the whole story - // by knowing that Condition 1 implies condition 2 - // we realize that this is just a way of making the algorithm faster - // i.e., of not using the logarithms - return d*v; - } - if(log(u) < 0.5*pow(x,2) + d*(1.0 - v + log(v))){ // Condition 2 - return d*v; - } - } - }else{ - return sample_gamma(1 + alpha, seed) * pow(sample_unit_uniform(seed), 1/alpha); - // see note in p. 371 of https://dl.acm.org/doi/pdf/10.1145/358407.358414 - } + do { + x = sample_unit_normal(seed); + v = 1.0 + c * x; + } while (v <= 0.0); + + v = pow(v, 3); + u = sample_unit_uniform(seed); + if (u < 1.0 - 0.0331 * pow(x, 4)) { // Condition 1 + // the 0.0331 doesn't inspire much confidence + // however, this isn't the whole story + // by knowing that Condition 1 implies condition 2 + // we realize that this is just a way of making the algorithm faster + // i.e., of not using the logarithms + return d * v; + } + if (log(u) < 0.5 * pow(x, 2) + d * (1.0 - v + log(v))) { // Condition 2 + return d * v; + } + } + } else { + return sample_gamma(1 + alpha, seed) * pow(sample_unit_uniform(seed), 1 / alpha); + // see note in p. 371 of https://dl.acm.org/doi/pdf/10.1145/358407.358414 + } } -float sample_beta(float a, float b, uint32_t* seed){ - float gamma_a = sample_gamma(a, seed); - float gamma_b = sample_gamma(b, seed); - return a / (a + b); +float sample_beta(float a, float b, uint32_t* seed) +{ + float gamma_a = sample_gamma(a, seed); + float gamma_b = sample_gamma(b, seed); + return gamma_a / (gamma_a + gamma_b); } // Array helpers @@ -134,18 +136,20 @@ void array_cumsum(float* array_to_sum, float* array_cumsummed, int length) } } -float array_mean(float* array, int length){ - float sum = array_sum(array, length); - return sum / length; +float array_mean(float* array, int length) +{ + float sum = array_sum(array, length); + return sum / length; } -float array_std(float* array, int length){ - float mean = array_mean(array, length); +float array_std(float* array, int length) +{ + float mean = array_mean(array, length); float std = 0.0; for (int i = 0; i < length; i++) { std += pow(array[i] - mean, 2.0); } - std=sqrt(std/length); + std = sqrt(std / length); return std; }