switch on type using new distribution interface

This commit is contained in:
NunoSempere 2024-06-10 00:24:06 +02:00
parent 8b1792f861
commit 0e9ef33b8e

81
f.go
View File

@ -19,7 +19,20 @@ const GENERAL_ERR_MSG = "Valid inputs: 2 || * 2 || / 2 || 2 20 || * 2 20 || / 2
type Dist interface { type Dist interface {
Samples() []float64 Samples() []float64
Type() string }
// Point implementing distribution
type Scalar struct {
p float64
}
func (p Scalar) Samples() []float64 {
xs := make([]float64, 1_000_000)
for i := 0; i < 1_000_000; i++ {
xs[i] = p.p
}
return xs
} }
// Lognormal implementing Distribution // Lognormal implementing Distribution
@ -62,21 +75,6 @@ func (beta FilledSamples) Type() string {
return "FilledSamples" return "FilledSamples"
} }
// Actually, I should look up how do do a) enums in go, b) union types
/*type Lognormal struct {
low float64
high float64
}
*/
/*
type Dist struct {
Type string
Lognormal Lognormal
Samples []float64
}
*/
// Parse line into Distribution // Parse line into Distribution
func parseLineErr(err_msg string) (string, Dist, error) { func parseLineErr(err_msg string) (string, Dist, error) {
fmt.Println(GENERAL_ERR_MSG) fmt.Println(GENERAL_ERR_MSG)
@ -157,10 +155,55 @@ func multiplyBetaDists(beta1 Beta, beta2 Beta) Beta {
return Beta{a: beta1.a + beta2.a, b: beta1.b + beta2.b} return Beta{a: beta1.a + beta2.a, b: beta1.b + beta2.b}
} }
func joinDists(old_dist Dist, new_dist Dist, op string) (Dist, error) { func multiplyAsSamples(dist1 Dist, dist2 Dist) Dist {
switch { xs := dist1.Samples()
ys := dist2.Samples()
zs := make([]float64, 1_000_000)
for i, x := range xs {
zs[i] = x * ys[i]
} }
return FilledSamples{xs: xs}
}
func multiplyDists(old_dist Dist, new_dist Dist) (Dist, error) {
switch o := old_dist.(type) {
case Lognormal:
{
switch n := new_dist.(type) {
case Lognormal:
return multiplyLogDists(o, n), nil
case Scalar:
return multiplyLogDists(o, Lognormal{low: n.p, high: n.p}), nil
default:
return multiplyAsSamples(o, n), nil
}
}
case Scalar:
{
switch n := new_dist.(type) {
case Lognormal:
return multiplyLogDists(Lognormal{low: o.p, high: o.p}, n), nil
case Scalar:
return Scalar{p: o.p * n.p}, nil
default:
return multiplyAsSamples(o, n), nil
}
}
default:
return nil, errors.New("Can't multiply dists")
}
}
func joinDists(old_dist Dist, new_dist Dist, op string) (Dist, error) {
switch op {
case "*":
return multiplyDists(old_dist, new_dist)
default:
return old_dist, errors.New("Can't combine distributions in this way")
}
/* /*
switch { switch {
case old_dist.Type == "Lognormal" && new_dist.Type == "Lognormal" && op == "*": case old_dist.Type == "Lognormal" && new_dist.Type == "Lognormal" && op == "*":
@ -172,7 +215,7 @@ func joinDists(old_dist Dist, new_dist Dist, op string) (Dist, error) {
fmt.Printf("For now, can't do anything besides multiplying lognormals\n") fmt.Printf("For now, can't do anything besides multiplying lognormals\n")
} }
*/ */
return old_dist, errors.New("Can't combine distributions in this way") // return old_dist, errors.New("Can't combine distributions in this way")
} }
/* Pretty print distributions */ /* Pretty print distributions */