From c4167681d7c4057cf82254ff81abc5310708f18d Mon Sep 17 00:00:00 2001 From: NunoSempere Date: Sun, 25 Feb 2024 19:18:16 -0300 Subject: [PATCH] add concurrency --- probppl.go | 41 ++++++++++++++++++++++++++++------------- 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/probppl.go b/probppl.go index 1cec82c..801874b 100644 --- a/probppl.go +++ b/probppl.go @@ -5,6 +5,7 @@ import ( "git.nunosempere.com/NunoSempere/probppl/choose" "math" rand "math/rand/v2" + "sync" ) type src = *rand.Rand @@ -119,7 +120,7 @@ func draw148PplFromDistributionAndCheck(d IntProbs, r src, show bool) int64 { func getUnnormalizedBayesianUpdateForDistribution(d IntProbs, r src) int64 { var sum int64 = 0 - n := 10_000 + n := 1000 for i := 0; i < n; i++ { /* if i%1000 == 0 { fmt.Println(i) @@ -133,22 +134,36 @@ func getUnnormalizedBayesianUpdateForDistribution(d IntProbs, r src) int64 { func main() { - var r = rand.New(rand.NewPCG(uint64(1), uint64(2))) - - n_dists := 10 + n_dists := 1000 var dists = make([]IntProbsWeights, n_dists) - for i := 0; i < n_dists; i++ { - people_known_distribution := generatePeopleKnownDistribution(r) - result := getUnnormalizedBayesianUpdateForDistribution(people_known_distribution, r) - if i%10 == 0 { - fmt.Printf("%d/%d\n", i, n_dists) - } - if result > 0 { - dists[i] = IntProbsWeights{IntProbs: people_known_distribution, w: result} - } + // Prepare for concurrency + num_threads := 8 + var wg sync.WaitGroup + wg.Add(num_threads) + + for i := range num_threads { + go func() { + defer wg.Done() + var r = rand.New(rand.NewPCG(uint64(i), uint64(i+1))) + for j := i * (n_dists / num_threads); j < (i+1)*(n_dists/num_threads); j++ { + people_known_distribution := generatePeopleKnownDistribution(r) + result := getUnnormalizedBayesianUpdateForDistribution(people_known_distribution, r) + /* + if i%10 == 0 { + fmt.Printf("%d/%d\n", i, n_dists) + } + */ + if result > 0 { + dists[j] = IntProbsWeights{IntProbs: people_known_distribution, w: result} + } + } + + }() + } + wg.Wait() // Now calculate the posterior sum_weights := int64(0) for _, dist := range dists {