go: finish debugging weights code.

This commit is contained in:
NunoSempere 2024-02-16 14:15:48 +01:00
parent d3cb97684a
commit 5029f67429

View File

@ -58,23 +58,24 @@ func sample_mixture(fs []func64, weights []float64) float64 {
} }
var total float64 = 0 var total float64 = 0
var cumsummed_normalized_weights = append([]float64(nil), weights...)
for i, weight := range weights { for i, weight := range weights {
total += weight / sum_weights total += weight / sum_weights
weights[i] = total cumsummed_normalized_weights[i] = total
} }
var result float64 var result float64
var flag int = 0 var flag int = 0
var p float64 = r.Float64() var p float64 = r.Float64()
for i, weight := range weights { for i, cnw := range cumsummed_normalized_weights {
if p < weight { if p < cnw {
result = fs[i]() result = fs[i]()
flag = 1 flag = 1
break break
} }
} }
fmt.Println(weights) fmt.Println(cumsummed_normalized_weights)
if flag == 0 { if flag == 0 {
result = fs[len(fs)-1]() result = fs[len(fs)-1]()
@ -98,7 +99,7 @@ func main() {
ws := [4](float64){1 - p_c, p_c / 2, p_c / 4, p_c / 4} ws := [4](float64){1 - p_c, p_c / 2, p_c / 4, p_c / 4}
fmt.Println("weights #1", ws) fmt.Println("weights #1", ws)
var n_samples int = 20 var n_samples int = 1_000_000
var avg float64 = 0 var avg float64 = 0
for i := 0; i < n_samples; i++ { for i := 0; i < n_samples; i++ {
x := sample_mixture(fs[0:], ws[0:]) x := sample_mixture(fs[0:], ws[0:])