add concurrency

This commit is contained in:
NunoSempere 2024-02-25 19:18:16 -03:00
parent e74c1127a5
commit c4167681d7

View File

@ -5,6 +5,7 @@ import (
"git.nunosempere.com/NunoSempere/probppl/choose" "git.nunosempere.com/NunoSempere/probppl/choose"
"math" "math"
rand "math/rand/v2" rand "math/rand/v2"
"sync"
) )
type src = *rand.Rand type src = *rand.Rand
@ -119,7 +120,7 @@ func draw148PplFromDistributionAndCheck(d IntProbs, r src, show bool) int64 {
func getUnnormalizedBayesianUpdateForDistribution(d IntProbs, r src) int64 { func getUnnormalizedBayesianUpdateForDistribution(d IntProbs, r src) int64 {
var sum int64 = 0 var sum int64 = 0
n := 10_000 n := 1000
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
/* if i%1000 == 0 { /* if i%1000 == 0 {
fmt.Println(i) fmt.Println(i)
@ -133,22 +134,36 @@ func getUnnormalizedBayesianUpdateForDistribution(d IntProbs, r src) int64 {
func main() { func main() {
var r = rand.New(rand.NewPCG(uint64(1), uint64(2))) n_dists := 1000
n_dists := 10
var dists = make([]IntProbsWeights, n_dists) var dists = make([]IntProbsWeights, n_dists)
for i := 0; i < n_dists; i++ { // Prepare for concurrency
num_threads := 8
var wg sync.WaitGroup
wg.Add(num_threads)
for i := range num_threads {
go func() {
defer wg.Done()
var r = rand.New(rand.NewPCG(uint64(i), uint64(i+1)))
for j := i * (n_dists / num_threads); j < (i+1)*(n_dists/num_threads); j++ {
people_known_distribution := generatePeopleKnownDistribution(r) people_known_distribution := generatePeopleKnownDistribution(r)
result := getUnnormalizedBayesianUpdateForDistribution(people_known_distribution, r) result := getUnnormalizedBayesianUpdateForDistribution(people_known_distribution, r)
/*
if i%10 == 0 { if i%10 == 0 {
fmt.Printf("%d/%d\n", i, n_dists) fmt.Printf("%d/%d\n", i, n_dists)
} }
*/
if result > 0 { if result > 0 {
dists[i] = IntProbsWeights{IntProbs: people_known_distribution, w: result} dists[j] = IntProbsWeights{IntProbs: people_known_distribution, w: result}
} }
} }
}()
}
wg.Wait()
// Now calculate the posterior // Now calculate the posterior
sum_weights := int64(0) sum_weights := int64(0)
for _, dist := range dists { for _, dist := range dists {