diff --git a/go/squiggle.go b/go/squiggle.go index 8b46c786..46903415 100644 --- a/go/squiggle.go +++ b/go/squiggle.go @@ -58,23 +58,24 @@ func sample_mixture(fs []func64, weights []float64) float64 { } var total float64 = 0 + var cumsummed_normalized_weights = append([]float64(nil), weights...) for i, weight := range weights { total += weight / sum_weights - weights[i] = total + cumsummed_normalized_weights[i] = total } var result float64 var flag int = 0 var p float64 = r.Float64() - for i, weight := range weights { - if p < weight { + for i, cnw := range cumsummed_normalized_weights { + if p < cnw { result = fs[i]() flag = 1 break } } - fmt.Println(weights) + fmt.Println(cumsummed_normalized_weights) if flag == 0 { result = fs[len(fs)-1]() @@ -98,7 +99,7 @@ func main() { ws := [4](float64){1 - p_c, p_c / 2, p_c / 4, p_c / 4} fmt.Println("weights #1", ws) - var n_samples int = 20 + var n_samples int = 1_000_000 var avg float64 = 0 for i := 0; i < n_samples; i++ { x := sample_mixture(fs[0:], ws[0:])