savepoint

This commit is contained in:
NunoSempere 2024-06-09 23:35:36 +02:00
parent 318b8da414
commit b0adde937f

45
f.go
View File

@ -17,8 +17,9 @@ const GENERAL_ERR_MSG = "Valid inputs: 2 || * 2 || / 2 || 2 20 || * 2 20 || / 2
// Distribution interface
// https://go.dev/tour/methods/9
type Distribution interface {
type Dist interface {
Samples() []float64
Type() string
}
// Lognormal implementing Distribution
@ -31,6 +32,9 @@ func (ln Lognormal) Samples() []float64 {
sampler := func(r sample.Src) float64 { return sample.Sample_to(ln.low, ln.high, r) }
return sample.Sample_parallel(sampler, 1_000_000)
}
func (ln Lognormal) Type() string {
return "Lognormal"
}
// Beta implementing Distribution
type Beta struct {
@ -41,7 +45,9 @@ type Beta struct {
func (beta Beta) Samples() []float64 {
sampler := func(r sample.Src) float64 { return sample.Sample_beta(beta.a, beta.b, r) }
return sample.Sample_parallel(sampler, 1_000_000)
}
func (beta Beta) Type() string {
return "Beta"
}
// FilledSamples implementing Distribution
@ -52,6 +58,9 @@ type FilledSamples struct {
func (fs FilledSamples) Samples() []float64 {
return fs.xs
}
func (beta FilledSamples) Type() string {
return "FilledSamples"
}
// Actually, I should look up how do do a) enums in go, b) union types
/*type Lognormal struct {
@ -60,17 +69,20 @@ func (fs FilledSamples) Samples() []float64 {
}
*/
/*
type Dist struct {
Type string
Lognormal Lognormal
Samples []float64
}
*/
// Parse line into Distribution
func parseLineErr(err_msg string) (string, Dist, error) {
fmt.Println(GENERAL_ERR_MSG)
fmt.Println(err_msg)
return "", Dist{}, errors.New(err_msg)
var errorDist Dist
return "", errorDist, errors.New(err_msg)
}
func parseLine(line string, vars map[string]Dist) (string, Dist, error) {
@ -103,7 +115,7 @@ func parseLine(line string, vars map[string]Dist) (string, Dist, error) {
case var_word_exists:
dist = var_word
case err1 == nil:
dist = Dist{Type: "Lognormal", Lognormal: Lognormal{low: single_float, high: single_float}, Samples: nil}
dist = Lognormal{low: single_float, high: single_float}
case err1 != nil && !var_word_exists:
return parseLineErr("Trying to operate on a scalar, but scalar is neither a float nor an assigned variable")
}
@ -113,7 +125,7 @@ func parseLine(line string, vars map[string]Dist) (string, Dist, error) {
if err1 != nil || err2 != nil {
return parseLineErr("Trying to operate by a distribution, but distribution is not specified as two floats")
}
dist = Dist{Type: "Lognormal", Lognormal: Lognormal{low: new_low, high: new_high}, Samples: nil}
dist = Lognormal{low: new_low, high: new_high}
default:
return parseLineErr("Other input methods not implemented yet")
}
@ -147,14 +159,19 @@ func multiplyBetaDists(beta1 Beta, beta2 Beta) Beta {
func joinDists(old_dist Dist, new_dist Dist, op string) (Dist, error) {
switch {
case old_dist.Type == "Lognormal" && new_dist.Type == "Lognormal" && op == "*":
return Dist{Type: "Lognormal", Lognormal: multiplyLogDists(old_dist.Lognormal, new_dist.Lognormal), Samples: nil}, nil
case old_dist.Type == "Lognormal" && new_dist.Type == "Lognormal" && op == "/":
tmp_dist := Lognormal{low: 1.0 / new_dist.Lognormal.high, high: 1.0 / new_dist.Lognormal.low}
return Dist{Type: "Lognormal", Lognormal: multiplyLogDists(old_dist.Lognormal, tmp_dist), Samples: nil}, nil
default:
fmt.Printf("For now, can't do anything besides multiplying lognormals\n")
}
/*
switch {
case old_dist.Type == "Lognormal" && new_dist.Type == "Lognormal" && op == "*":
return Dist{Type: "Lognormal", Lognormal: multiplyLogDists(old_dist.Lognormal, new_dist.Lognormal), Samples: nil}, nil
case old_dist.Type == "Lognormal" && new_dist.Type == "Lognormal" && op == "/":
tmp_dist := Lognormal{low: 1.0 / new_dist.Lognormal.high, high: 1.0 / new_dist.Lognormal.low}
return Dist{Type: "Lognormal", Lognormal: multiplyLogDists(old_dist.Lognormal, tmp_dist), Samples: nil}, nil
default:
fmt.Printf("For now, can't do anything besides multiplying lognormals\n")
}
*/
return old_dist, errors.New("Can't combine distributions in this way")
}
@ -196,8 +213,8 @@ func prettyPrintLognormal(low float64, high float64) {
}
func prettyPrintDist(dist Dist) {
if dist.Type == "Lognormal" {
prettyPrintLognormal(dist.Lognormal.low, dist.Lognormal.high)
if dist.Type() == "Lognormal" {
prettyPrintLognormal(dist.low, dist.high)
} else {
fmt.Printf("%v", dist)
}