debug mixture implementation

This commit is contained in:
NunoSempere 2024-12-24 15:39:45 +01:00
parent 2f663b1262
commit 884adba214
5 changed files with 22 additions and 6 deletions

View File

@ -357,7 +357,7 @@ func operateDists(old_dist Dist, new_dist Dist, op string) (Dist, error) {
func parseMixture(words []string, vars map[string]Dist) (Dist, error) {
// mx, mix, var weight var weight var weight ...
// Check syntax
if len(words)%2 != 1 || words[0] != "mx" {
if len(words)%2 != 0 {
return nil, printAndReturnErr("Not a mixture. \nMixture syntax: \nmx x 2.5 y 8 z 10\ni.e.: mx var weight var2 weight2 ... var_n weight_n")
}
@ -365,7 +365,7 @@ func parseMixture(words []string, vars map[string]Dist) (Dist, error) {
var fs [][]float64
var weights []float64
for i, word := range words[1:] {
for i, word := range words {
if i%2 == 0 {
dist, exists := vars[word]
if !exists {
@ -400,7 +400,7 @@ func parseWordsIntoOpAndDist(words []string, vars map[string]Dist) (string, Dist
var dist Dist
switch words[0] {
case "*", "/", "+", "-":
case "*", "/", "+", "-", "mx":
op = words[0]
words = words[1:]
default:
@ -448,13 +448,14 @@ func parseWordsIntoOpAndDist(words []string, vars map[string]Dist) (string, Dist
return parseWordsErr("Input not understood or not implemented yet")
}
default:
switch words[0] {
switch op {
case "mx":
tmp, err := parseMixture(words, vars)
if err != nil {
return parseWordsErr("Error parsing a mixture: " + err.Error())
}
dist = tmp
op = "*"
default:
return parseWordsErr("Input not understood or not implemented yet")
}

2
go.mod
View File

@ -1,3 +1,5 @@
module git.nunosempere.com/NunoSempere/fermi
go 1.22.1
require github.com/pkg/errors v0.9.1 // indirect

2
go.sum Normal file
View File

@ -0,0 +1,2 @@
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=

9
mixture.fermi Normal file
View File

@ -0,0 +1,9 @@
1 100
=. x
2K 200K
=. y
mx x 40% y 60%
stats
exit

View File

@ -1,6 +1,7 @@
package sample
import (
"fmt"
"math"
"sync"
@ -180,12 +181,13 @@ func Sample_mixture_serially(fs [][]float64, weights []float64, n_samples int) (
return nil, errors.New("Cummulative sum of weights in mixture must be > 0.0")
}
var flag int = 0
var p float64 = global_state.Float64()
fmt.Printf("Weights: %v\n", cumsummed_normalized_weights)
xs := make([]float64, n_samples)
// var global_state = rand.New(rand.NewPCG(uint64(1), uint64(2)))
for i := 0; i < n_samples; i++ {
var flag int = 0
var p float64 = global_state.Float64()
for j, cnw := range cumsummed_normalized_weights {
if p < cnw {
xs[i] = fs[j][i]