From b0f48286d5e230a6b6ed9041f89ce9aa7c33a5ae Mon Sep 17 00:00:00 2001 From: NunoSempere Date: Mon, 10 Jun 2024 01:12:02 +0200 Subject: [PATCH] add code to multiply beta distributions --- f.go | 36 ++++++++++++++++++++++++++++-------- sample/sample.go | 9 +++++++++ 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/f.go b/f.go index c58369c..a545cdd 100644 --- a/f.go +++ b/f.go @@ -42,7 +42,9 @@ type Lognormal struct { func (ln Lognormal) Samples() []float64 { sampler := func(r sample.Src) float64 { return sample.Sample_to(ln.low, ln.high, r) } - return sample.Sample_parallel(sampler, N_SAMPLES) + // return sample.Sample_parallel(sampler, N_SAMPLES) + // Can't do parallel because then I'd have to await throughout the code + return sample.Sample_serially(sampler, N_SAMPLES) } type Beta struct { @@ -52,7 +54,8 @@ type Beta struct { func (beta Beta) Samples() []float64 { sampler := func(r sample.Src) float64 { return sample.Sample_beta(beta.a, beta.b, r) } - return sample.Sample_parallel(sampler, N_SAMPLES) + // return sample.Sample_parallel(sampler, N_SAMPLES) + return sample.Sample_serially(sampler, N_SAMPLES) } type FilledSamples struct { @@ -156,12 +159,12 @@ func multiplyBetaDists(beta1 Beta, beta2 Beta) Beta { func multiplyAsSamples(dist1 Dist, dist2 Dist) Dist { // dist2 = Beta{a: 1, b: 2} - fmt.Printf("dist1: %v\n", dist1) - fmt.Printf("dist2: %v\n", dist2) + // fmt.Printf("dist1: %v\n", dist1) + // fmt.Printf("dist2: %v\n", dist2) xs := dist1.Samples() ys := dist2.Samples() - fmt.Printf("xs: %v\n", xs) - fmt.Printf("ys: %v\n", ys) + // fmt.Printf("xs: %v\n", xs) + // fmt.Printf("ys: %v\n", ys) zs := make([]float64, N_SAMPLES) for i := 0; i < N_SAMPLES; i++ { @@ -188,6 +191,9 @@ func multiplyDists(old_dist Dist, new_dist Dist) (Dist, error) { } case Scalar: { + if o.p == 1 { + return new_dist, nil + } switch n := new_dist.(type) { case Lognormal: return multiplyLogDists(Lognormal{low: o.p, high: o.p}, n), nil @@ -197,8 +203,16 @@ func multiplyDists(old_dist Dist, new_dist Dist) (Dist, error) { return multiplyAsSamples(o, n), nil } } + case Beta: + switch n := new_dist.(type) { + case Beta: + return multiplyBetaDists(o, n), nil + default: + return multiplyAsSamples(o, n), nil + } default: - return nil, errors.New("Can't multiply dists") + return multiplyAsSamples(old_dist, new_dist), nil + // return nil, errors.New("Can't multiply dists") } } @@ -227,7 +241,6 @@ func joinDists(old_dist Dist, new_dist Dist, op string) (Dist, error) { /* Pretty print distributions */ func prettyPrint90CI(low float64, high float64) { // fmt.Printf("=> %.1f %.1f\n", low, high) - fmt.Printf("=> ") switch { case math.Abs(low) >= 1_000_000_000_000: fmt.Printf("%.1fT", low/1_000_000_000_000) @@ -264,6 +277,7 @@ func prettyPrint90CI(low float64, high float64) { func prettyPrintDist(dist Dist) { switch v := dist.(type) { case Lognormal: + fmt.Printf("=> ") prettyPrint90CI(v.low, v.high) case FilledSamples: tmp_xs := make([]float64, N_SAMPLES) @@ -276,6 +290,9 @@ func prettyPrintDist(dist Dist) { high_int := N_SAMPLES * 19 / 20 high := tmp_xs[high_int] prettyPrint90CI(low, high) + case Beta: + fmt.Printf("=> beta ") + prettyPrint90CI(v.a, v.b) default: fmt.Printf("%v", v) } @@ -342,6 +359,9 @@ EventForLoop: joint_dist, err := joinDists(old_dist, new_dist, op) if err != nil { + fmt.Printf("%v\n", err) + fmt.Printf("Dist on stack: ") + prettyPrintDist(old_dist) continue EventForLoop } old_dist = joint_dist diff --git a/sample/sample.go b/sample/sample.go index 42ac5a8..8e8844c 100644 --- a/sample/sample.go +++ b/sample/sample.go @@ -137,6 +137,15 @@ func Sample_mixture(fs []func64, weights []float64, r Src) float64 { } +func Sample_serially(f func64, n_samples int) []float64 { + var r = rand.New(rand.NewPCG(uint64(1), uint64(2))) + xs := make([]float64, n_samples) + for i := 0; i < n_samples; i++ { + xs[i] = f(r) + } + return xs +} + func Sample_parallel(f func64, n_samples int) []float64 { var num_threads = 16 var xs = make([]float64, n_samples)