diff --git a/f.go b/f.go index edffb0c..600ab17 100644 --- a/f.go +++ b/f.go @@ -19,7 +19,20 @@ const GENERAL_ERR_MSG = "Valid inputs: 2 || * 2 || / 2 || 2 20 || * 2 20 || / 2 type Dist interface { Samples() []float64 - Type() string +} + +// Point implementing distribution + +type Scalar struct { + p float64 +} + +func (p Scalar) Samples() []float64 { + xs := make([]float64, 1_000_000) + for i := 0; i < 1_000_000; i++ { + xs[i] = p.p + } + return xs } // Lognormal implementing Distribution @@ -62,21 +75,6 @@ func (beta FilledSamples) Type() string { return "FilledSamples" } -// Actually, I should look up how do do a) enums in go, b) union types -/*type Lognormal struct { - low float64 - high float64 -} -*/ - -/* -type Dist struct { - Type string - Lognormal Lognormal - Samples []float64 -} -*/ - // Parse line into Distribution func parseLineErr(err_msg string) (string, Dist, error) { fmt.Println(GENERAL_ERR_MSG) @@ -157,10 +155,55 @@ func multiplyBetaDists(beta1 Beta, beta2 Beta) Beta { return Beta{a: beta1.a + beta2.a, b: beta1.b + beta2.b} } -func joinDists(old_dist Dist, new_dist Dist, op string) (Dist, error) { - switch { +func multiplyAsSamples(dist1 Dist, dist2 Dist) Dist { + xs := dist1.Samples() + ys := dist2.Samples() + zs := make([]float64, 1_000_000) + for i, x := range xs { + zs[i] = x * ys[i] } + return FilledSamples{xs: xs} +} + +func multiplyDists(old_dist Dist, new_dist Dist) (Dist, error) { + + switch o := old_dist.(type) { + case Lognormal: + { + switch n := new_dist.(type) { + case Lognormal: + return multiplyLogDists(o, n), nil + case Scalar: + return multiplyLogDists(o, Lognormal{low: n.p, high: n.p}), nil + default: + return multiplyAsSamples(o, n), nil + } + } + case Scalar: + { + switch n := new_dist.(type) { + case Lognormal: + return multiplyLogDists(Lognormal{low: o.p, high: o.p}, n), nil + case Scalar: + return Scalar{p: o.p * n.p}, nil + default: + return multiplyAsSamples(o, n), nil + } + } + default: + return nil, errors.New("Can't multiply dists") + } +} + +func joinDists(old_dist Dist, new_dist Dist, op string) (Dist, error) { + + switch op { + case "*": + return multiplyDists(old_dist, new_dist) + default: + return old_dist, errors.New("Can't combine distributions in this way") + } /* switch { case old_dist.Type == "Lognormal" && new_dist.Type == "Lognormal" && op == "*": @@ -172,7 +215,7 @@ func joinDists(old_dist Dist, new_dist Dist, op string) (Dist, error) { fmt.Printf("For now, can't do anything besides multiplying lognormals\n") } */ - return old_dist, errors.New("Can't combine distributions in this way") + // return old_dist, errors.New("Can't combine distributions in this way") } /* Pretty print distributions */