rework sampling for mixtures a bit
This commit is contained in:
parent
e473223bbd
commit
3ca32655d5
27
fermi.go
27
fermi.go
|
@ -21,7 +21,7 @@ type Stack struct {
|
||||||
|
|
||||||
type Dist interface {
|
type Dist interface {
|
||||||
Samples() []float64
|
Samples() []float64
|
||||||
Sampler(sample.State) float64
|
Sampler(int, sample.State) float64
|
||||||
}
|
}
|
||||||
|
|
||||||
type Scalar float64
|
type Scalar float64
|
||||||
|
@ -49,7 +49,7 @@ func (p Scalar) Samples() []float64 {
|
||||||
}
|
}
|
||||||
return xs
|
return xs
|
||||||
}
|
}
|
||||||
func (p Scalar) Sampler(r sample.State) float64 {
|
func (p Scalar) Sampler(i int, r sample.State) float64 {
|
||||||
return float64(p)
|
return float64(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -58,7 +58,7 @@ func (ln Lognormal) Samples() []float64 {
|
||||||
// Can't do parallel because then I'd have to await throughout the code
|
// Can't do parallel because then I'd have to await throughout the code
|
||||||
return sample.Sample_serially(sampler, N_SAMPLES)
|
return sample.Sample_serially(sampler, N_SAMPLES)
|
||||||
}
|
}
|
||||||
func (ln Lognormal) Sampler(r sample.State) float64 {
|
func (ln Lognormal) Sampler(i int, r sample.State) float64 {
|
||||||
return sample.Sample_to(ln.low, ln.high, r)
|
return sample.Sample_to(ln.low, ln.high, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -66,18 +66,20 @@ func (beta Beta) Samples() []float64 {
|
||||||
sampler := func(r sample.State) float64 { return sample.Sample_beta(beta.a, beta.b, r) }
|
sampler := func(r sample.State) float64 { return sample.Sample_beta(beta.a, beta.b, r) }
|
||||||
return sample.Sample_serially(sampler, N_SAMPLES)
|
return sample.Sample_serially(sampler, N_SAMPLES)
|
||||||
}
|
}
|
||||||
func (beta Beta) Sampler(r sample.State) float64 {
|
func (beta Beta) Sampler(i int, r sample.State) float64 {
|
||||||
return sample.Sample_beta(beta.a, beta.b, r)
|
return sample.Sample_beta(beta.a, beta.b, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fs FilledSamples) Samples() []float64 {
|
func (fs FilledSamples) Samples() []float64 {
|
||||||
return fs.xs
|
return fs.xs
|
||||||
}
|
}
|
||||||
func (fs FilledSamples) Sampler(r sample.State) float64 {
|
func (fs FilledSamples) Sampler(i int, r sample.State) float64 {
|
||||||
// This is a bit subtle, because sampling from a FilledSamples item iteratively converges
|
// This is a bit subtle, because sampling from FilledSamples randomly iteratively converges
|
||||||
// to something different than the initial distribution
|
// to something different than the initial distribution
|
||||||
n := len(fs.xs)
|
// So instead we have an i parameter.
|
||||||
i := sample.Sample_int(n, r)
|
// Not sure how I feel about it
|
||||||
|
// n := len(fs.xs)
|
||||||
|
// i := sample.Sample_int(n, r)
|
||||||
return fs.xs[i]
|
return fs.xs[i]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -379,7 +381,8 @@ func parseMixture(words []string, vars map[string]Dist) (Dist, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var dists []Dist
|
var dists []Dist
|
||||||
var fs [][]float64
|
var fs []func(int, sample.State) float64
|
||||||
|
var ss [][]float64
|
||||||
var weights []float64
|
var weights []float64
|
||||||
|
|
||||||
for i, word := range words {
|
for i, word := range words {
|
||||||
|
@ -389,9 +392,11 @@ func parseMixture(words []string, vars map[string]Dist) (Dist, error) {
|
||||||
return nil, printAndReturnErr("Expected mixture variable but didn't get a variable. \nMixture syntax: \nmx x 2.5 y 8 z 10\ni.e.: mx var weight var2 weight2 ... var_n weight_n")
|
return nil, printAndReturnErr("Expected mixture variable but didn't get a variable. \nMixture syntax: \nmx x 2.5 y 8 z 10\ni.e.: mx var weight var2 weight2 ... var_n weight_n")
|
||||||
}
|
}
|
||||||
samples := dist.Samples()
|
samples := dist.Samples()
|
||||||
|
f := dist.Sampler
|
||||||
// Inefficient to draw N_SAMPLES for each of the distributions, but conceptually simpler.
|
// Inefficient to draw N_SAMPLES for each of the distributions, but conceptually simpler.
|
||||||
dists = append(dists, dist)
|
dists = append(dists, dist)
|
||||||
fs = append(fs, samples)
|
fs = append(fs, f)
|
||||||
|
ss = append(ss, samples)
|
||||||
} else {
|
} else {
|
||||||
weight, err := pretty.ParseFloat(word)
|
weight, err := pretty.ParseFloat(word)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -401,7 +406,7 @@ func parseMixture(words []string, vars map[string]Dist) (Dist, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Sample from mixture
|
// Sample from mixture
|
||||||
xs, err := sample.Sample_mixture_serially(fs, weights, N_SAMPLES)
|
xs, err := sample.Sample_mixture_serially_from_samplers(fs, weights, N_SAMPLES)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, printAndReturnErr(err.Error())
|
return nil, printAndReturnErr(err.Error())
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
|
|
||||||
type State = *rand.Rand
|
type State = *rand.Rand
|
||||||
type func64 = func(State) float64
|
type func64 = func(State) float64
|
||||||
|
type func64i = func(int, State) float64
|
||||||
|
|
||||||
var global_state = rand.New(rand.NewPCG(uint64(1), uint64(2)))
|
var global_state = rand.New(rand.NewPCG(uint64(1), uint64(2)))
|
||||||
|
|
||||||
|
@ -157,7 +158,7 @@ func Sample_serially(f func64, n_samples int) []float64 {
|
||||||
return xs
|
return xs
|
||||||
}
|
}
|
||||||
|
|
||||||
func Sample_mixture_serially(fs [][]float64, weights []float64, n_samples int) ([]float64, error) {
|
func Sample_mixture_serially_from_samples(fs [][]float64, weights []float64, n_samples int) ([]float64, error) {
|
||||||
|
|
||||||
// Checks
|
// Checks
|
||||||
if len(weights) != len(fs) {
|
if len(weights) != len(fs) {
|
||||||
|
@ -205,6 +206,47 @@ func Sample_mixture_serially(fs [][]float64, weights []float64, n_samples int) (
|
||||||
return xs, nil
|
return xs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Sample_mixture_serially_from_samplers(fs []func64i, weights []float64, n_samples int) ([]float64, error) {
|
||||||
|
|
||||||
|
// Checks
|
||||||
|
if len(weights) != len(fs) {
|
||||||
|
return nil, errors.New("Mixture must have dists and weights alternated")
|
||||||
|
}
|
||||||
|
// fmt.Println("weights initially: ", weights)
|
||||||
|
var sum_weights float64 = 0
|
||||||
|
for _, weight := range weights {
|
||||||
|
sum_weights += weight
|
||||||
|
}
|
||||||
|
|
||||||
|
var total float64 = 0
|
||||||
|
var cumsummed_normalized_weights = append([]float64(nil), weights...)
|
||||||
|
for i, weight := range weights {
|
||||||
|
total += weight / sum_weights
|
||||||
|
cumsummed_normalized_weights[i] = total
|
||||||
|
}
|
||||||
|
if total == 0.0 {
|
||||||
|
return nil, errors.New("Cummulative sum of weights in mixture must be > 0.0")
|
||||||
|
}
|
||||||
|
|
||||||
|
// fmt.Printf("Weights: %v\n", cumsummed_normalized_weights)
|
||||||
|
xs := make([]float64, n_samples)
|
||||||
|
for i := 0; i < n_samples; i++ {
|
||||||
|
var flag int = 0
|
||||||
|
var p float64 = global_state.Float64()
|
||||||
|
for j, cnw := range cumsummed_normalized_weights {
|
||||||
|
if p < cnw {
|
||||||
|
xs[i] = fs[j](i, global_state)
|
||||||
|
flag = 1
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if flag == 0 {
|
||||||
|
xs[i] = fs[len(fs)-1](i, global_state)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return xs, nil
|
||||||
|
}
|
||||||
|
|
||||||
func Sample_parallel(f func64, n_samples int) []float64 {
|
func Sample_parallel(f func64, n_samples int) []float64 {
|
||||||
var num_threads = 16
|
var num_threads = 16
|
||||||
var xs = make([]float64, n_samples)
|
var xs = make([]float64, n_samples)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user