first pass mixture

This commit is contained in:
NunoSempere 2024-12-24 15:31:21 +01:00
parent ad70db5f14
commit 2f663b1262
3 changed files with 128 additions and 30 deletions

BIN
fermi Executable file

Binary file not shown.

View File

@ -50,13 +50,13 @@ func (p Scalar) Samples() []float64 {
} }
func (ln Lognormal) Samples() []float64 { func (ln Lognormal) Samples() []float64 {
sampler := func(r sample.Src) float64 { return sample.Sample_to(ln.low, ln.high, r) } sampler := func(r sample.State) float64 { return sample.Sample_to(ln.low, ln.high, r) }
// Can't do parallel because then I'd have to await throughout the code // Can't do parallel because then I'd have to await throughout the code
return sample.Sample_serially(sampler, N_SAMPLES) return sample.Sample_serially(sampler, N_SAMPLES)
} }
func (beta Beta) Samples() []float64 { func (beta Beta) Samples() []float64 {
sampler := func(r sample.Src) float64 { return sample.Sample_beta(beta.a, beta.b, r) } sampler := func(r sample.State) float64 { return sample.Sample_beta(beta.a, beta.b, r) }
return sample.Sample_serially(sampler, N_SAMPLES) return sample.Sample_serially(sampler, N_SAMPLES)
} }
@ -156,7 +156,6 @@ func prettyPrintDist(dist Dist) {
func printAndReturnErr(err_msg string) error { func printAndReturnErr(err_msg string) error {
fmt.Println(err_msg) fmt.Println(err_msg)
// fmt.Println(HELP_MSG)
fmt.Println("Type \"help\" (without quotes) to see a pseudogrammar and examples") fmt.Println("Type \"help\" (without quotes) to see a pseudogrammar and examples")
return errors.New(err_msg) return errors.New(err_msg)
} }
@ -354,6 +353,43 @@ func operateDists(old_dist Dist, new_dist Dist, op string) (Dist, error) {
} }
} }
/* Mixtures */
func parseMixture(words []string, vars map[string]Dist) (Dist, error) {
// mx, mix, var weight var weight var weight ...
// Check syntax
if len(words)%2 != 1 || words[0] != "mx" {
return nil, printAndReturnErr("Not a mixture. \nMixture syntax: \nmx x 2.5 y 8 z 10\ni.e.: mx var weight var2 weight2 ... var_n weight_n")
}
var dists []Dist
var fs [][]float64
var weights []float64
for i, word := range words[1:] {
if i%2 == 0 {
dist, exists := vars[word]
if !exists {
return nil, printAndReturnErr("Expected mixture variable but didn't get a variable. \nMixture syntax: \nmx x 2.5 y 8 z 10\ni.e.: mx var weight var2 weight2 ... var_n weight_n")
}
samples := dist.Samples()
dists = append(dists, dist)
fs = append(fs, samples)
} else {
weight, err := pretty.ParseFloat(word)
if err != nil {
return nil, printAndReturnErr("Expected mixture weight but didn't get a float. \nMixture syntax: \nmx x 2.5 y 8 z 10\ni.e.: mx var weight var2 weight2 ... var_n weight_n")
}
weights = append(weights, weight)
}
}
// Sample from mixture
xs, err := sample.Sample_mixture_serially(fs, weights, N_SAMPLES)
if err != nil {
return nil, printAndReturnErr(err.Error())
}
return FilledSamples{xs: xs}, nil
}
/* Parser and repl */ /* Parser and repl */
func parseWordsErr(err_msg string) (string, Dist, error) { func parseWordsErr(err_msg string) (string, Dist, error) {
return "", nil, printAndReturnErr(err_msg) return "", nil, printAndReturnErr(err_msg)
@ -368,7 +404,7 @@ func parseWordsIntoOpAndDist(words []string, vars map[string]Dist) (string, Dist
op = words[0] op = words[0]
words = words[1:] words = words[1:]
default: default:
op = "*" // later, change the below to op = "*"
} }
switch len(words) { switch len(words) {
@ -400,18 +436,28 @@ func parseWordsIntoOpAndDist(words []string, vars map[string]Dist) (string, Dist
} }
dist = Lognormal{low: new_low, high: new_high} dist = Lognormal{low: new_low, high: new_high}
case 3: case 3:
if words[0] == "beta" || words[0] == "b" { switch {
case words[0] == "beta" || words[0] == "b":
a, err1 := pretty.ParseFloat(words[1]) a, err1 := pretty.ParseFloat(words[1])
b, err2 := pretty.ParseFloat(words[2]) b, err2 := pretty.ParseFloat(words[2])
if err1 != nil || err2 != nil { if err1 != nil || err2 != nil {
return parseWordsErr("Trying to specify a beta distribution? Try beta 1 2") return parseWordsErr("Trying to specify a beta distribution? Try beta 1 2")
} }
dist = Beta{a: a, b: b} dist = Beta{a: a, b: b}
} else { default:
return parseWordsErr("Input not understood or not implemented yet") return parseWordsErr("Input not understood or not implemented yet")
} }
default: default:
return parseWordsErr("Input not understood or not implemented yet") switch words[0] {
case "mx":
tmp, err := parseMixture(words, vars)
if err != nil {
return parseWordsErr("Error parsing a mixture: " + err.Error())
}
dist = tmp
default:
return parseWordsErr("Input not understood or not implemented yet")
}
} }
return op, dist, nil return op, dist, nil
} }

View File

@ -1,44 +1,49 @@
package sample package sample
import "math" import (
import "sync" "math"
import rand "math/rand/v2" "sync"
rand "math/rand/v2"
"github.com/pkg/errors"
)
// https://pkg.go.dev/math/rand/v2 // https://pkg.go.dev/math/rand/v2
type Src = *rand.Rand type State = *rand.Rand
type func64 = func(Src) float64 type func64 = func(State) float64
var global_r = rand.New(rand.NewPCG(uint64(1), uint64(2))) var global_state = rand.New(rand.NewPCG(uint64(1), uint64(2)))
func Sample_unit_uniform(r Src) float64 { func Sample_unit_uniform(r State) float64 {
return r.Float64() return r.Float64()
} }
func Sample_unit_normal(r Src) float64 { func Sample_unit_normal(r State) float64 {
return r.NormFloat64() return r.NormFloat64()
} }
func Sample_uniform(start float64, end float64, r Src) float64 { func Sample_uniform(start float64, end float64, r State) float64 {
return Sample_unit_uniform(r)*(end-start) + start return Sample_unit_uniform(r)*(end-start) + start
} }
func Sample_normal(mean float64, sigma float64, r Src) float64 { func Sample_normal(mean float64, sigma float64, r State) float64 {
return mean + Sample_unit_normal(r)*sigma return mean + Sample_unit_normal(r)*sigma
} }
func Sample_lognormal(logmean float64, logstd float64, r Src) float64 { func Sample_lognormal(logmean float64, logstd float64, r State) float64 {
return (math.Exp(Sample_normal(logmean, logstd, r))) return (math.Exp(Sample_normal(logmean, logstd, r)))
} }
func Sample_normal_from_90_ci(low float64, high float64, r Src) float64 { func Sample_normal_from_90_ci(low float64, high float64, r State) float64 {
var normal90 float64 = 1.6448536269514727 var normal90 float64 = 1.6448536269514727
var mean float64 = (high + low) / 2.0 var mean float64 = (high + low) / 2.0
var std float64 = (high - low) / (2.0 * normal90) var std float64 = (high - low) / (2.0 * normal90)
return Sample_normal(mean, std, r) return Sample_normal(mean, std, r)
} }
func Sample_to(low float64, high float64, r Src) float64 { func Sample_to(low float64, high float64, r State) float64 {
// Given a (positive) 90% confidence interval, // Given a (positive) 90% confidence interval,
// returns a sample from a lognorma with a matching 90% c.i. // returns a sample from a lognorma with a matching 90% c.i.
// Key idea: If we want a lognormal with 90% confidence interval [a, b] // Key idea: If we want a lognormal with 90% confidence interval [a, b]
@ -49,7 +54,7 @@ func Sample_to(low float64, high float64, r Src) float64 {
return math.Exp(Sample_normal_from_90_ci(loglow, loghigh, r)) return math.Exp(Sample_normal_from_90_ci(loglow, loghigh, r))
} }
func Sample_gamma(alpha float64, r Src) float64 { func Sample_gamma(alpha float64, r State) float64 {
// a simple method for generating gamma variables, marsaglia and wan tsang, 2001 // a simple method for generating gamma variables, marsaglia and wan tsang, 2001
// https://dl.acm.org/doi/pdf/10.1145/358407.358414 // https://dl.acm.org/doi/pdf/10.1145/358407.358414
@ -99,13 +104,13 @@ func Sample_gamma(alpha float64, r Src) float64 {
} }
} }
func Sample_beta(a float64, b float64, r Src) float64 { func Sample_beta(a float64, b float64, r State) float64 {
gamma_a := Sample_gamma(a, r) gamma_a := Sample_gamma(a, r)
gamma_b := Sample_gamma(b, r) gamma_b := Sample_gamma(b, r)
return gamma_a / (gamma_a + gamma_b) return gamma_a / (gamma_a + gamma_b)
} }
func Sample_mixture(fs []func64, weights []float64, r Src) float64 { func Sample_mixture_once(fs []func64, weights []float64, r State) float64 {
// fmt.Println("weights initially: ", weights) // fmt.Println("weights initially: ", weights)
var sum_weights float64 = 0 var sum_weights float64 = 0
@ -141,13 +146,60 @@ func Sample_mixture(fs []func64, weights []float64, r Src) float64 {
func Sample_serially(f func64, n_samples int) []float64 { func Sample_serially(f func64, n_samples int) []float64 {
xs := make([]float64, n_samples) xs := make([]float64, n_samples)
// var global_r = rand.New(rand.NewPCG(uint64(1), uint64(2))) // var global_state = rand.New(rand.NewPCG(uint64(1), uint64(2)))
for i := 0; i < n_samples; i++ { for i := 0; i < n_samples; i++ {
xs[i] = f(global_r) xs[i] = f(global_state)
} }
return xs return xs
} }
func Sample_mixture_serially(fs [][]float64, weights []float64, n_samples int) ([]float64, error) {
// Checks
if len(weights) != len(fs) {
return nil, errors.New("Mixture must have dists and weights alternated")
}
for _, f := range fs {
if len(f) < n_samples {
return nil, errors.New("Mixture components don't have enough samples")
}
}
// fmt.Println("weights initially: ", weights)
var sum_weights float64 = 0
for _, weight := range weights {
sum_weights += weight
}
var total float64 = 0
var cumsummed_normalized_weights = append([]float64(nil), weights...)
for i, weight := range weights {
total += weight / sum_weights
cumsummed_normalized_weights[i] = total
}
if total == 0.0 {
return nil, errors.New("Cummulative sum of weights in mixture must be > 0.0")
}
var flag int = 0
var p float64 = global_state.Float64()
xs := make([]float64, n_samples)
// var global_state = rand.New(rand.NewPCG(uint64(1), uint64(2)))
for i := 0; i < n_samples; i++ {
for j, cnw := range cumsummed_normalized_weights {
if p < cnw {
xs[i] = fs[j][i]
flag = 1
break
}
}
if flag == 0 {
xs[i] = fs[len(fs)-1][i]
}
}
return xs, nil
}
func Sample_parallel(f func64, n_samples int) []float64 { func Sample_parallel(f func64, n_samples int) []float64 {
var num_threads = 16 var num_threads = 16
var xs = make([]float64, n_samples) var xs = make([]float64, n_samples)
@ -178,13 +230,13 @@ func main() {
var p_c float64 = p_a * p_b var p_c float64 = p_a * p_b
ws := [4](float64){1 - p_c, p_c / 2, p_c / 4, p_c / 4} ws := [4](float64){1 - p_c, p_c / 2, p_c / 4, p_c / 4}
Sample_0 := func(r Src) float64 { return 0 } Sample_0 := func(r State) float64 { return 0 }
Sample_1 := func(r Src) float64 { return 1 } Sample_1 := func(r State) float64 { return 1 }
Sample_few := func(r Src) float64 { return Sample_to(1, 3, r) } Sample_few := func(r State) float64 { return Sample_to(1, 3, r) }
Sample_many := func(r Src) float64 { return Sample_to(2, 10, r) } Sample_many := func(r State) float64 { return Sample_to(2, 10, r) }
fs := [4](func64){Sample_0, Sample_1, Sample_few, Sample_many} fs := [4](func64){Sample_0, Sample_1, Sample_few, Sample_many}
model := func(r Src) float64 { return Sample_mixture(fs[0:], ws[0:], r) } model := func(r State) float64 { return Sample_mixture(fs[0:], ws[0:], r) }
n_samples := 1_000_000 n_samples := 1_000_000
xs := Sample_parallel(model, n_samples) xs := Sample_parallel(model, n_samples)
var avg float64 = 0 var avg float64 = 0