From 84bdfa004f946fe0aa2246a2b54ab18e274c5a5b Mon Sep 17 00:00:00 2001 From: NunoSempere Date: Wed, 20 Nov 2024 11:54:31 +0000 Subject: [PATCH] add some guards for edge cases scalar = 0 * logdist scalar < 1 * logdist logdist with parameters < 0 logdist with same parameters (should be scalar) logdist with low > high --- fermi.go | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/fermi.go b/fermi.go index 8a4e4fc..581b95e 100644 --- a/fermi.go +++ b/fermi.go @@ -254,6 +254,16 @@ func multiplyBetaDists(beta1 Beta, beta2 Beta) Beta { return Beta{a: beta1.a + beta2.a, b: beta1.b + beta2.b} } +func multiplyLogDistAndScalar(l Lognormal, s Scalar) (Dist, error) { + if s == 0.0 { + return Scalar(0.0), nil + } else if s < 0.0 { + return operateDistsAsSamples(s, l, "+") + } else { + return multiplyLogDists(l, Lognormal{low: float64(s), high: float64(s)}), nil + } +} + func multiplyDists(old_dist Dist, new_dist Dist) (Dist, error) { switch o := old_dist.(type) { @@ -263,17 +273,20 @@ func multiplyDists(old_dist Dist, new_dist Dist) (Dist, error) { case Lognormal: return multiplyLogDists(o, n), nil case Scalar: - return multiplyLogDists(o, Lognormal{low: float64(n), high: float64(n)}), nil + return multiplyLogDistAndScalar(o, n) } } case Scalar: { - if o == 1 { + switch o { + case 1.0: return new_dist, nil + case 0.0: + return Scalar(0.0), nil } switch n := new_dist.(type) { case Lognormal: - return multiplyLogDists(Lognormal{low: float64(o), high: float64(o)}, n), nil + return multiplyLogDistAndScalar(n, o) case Scalar: return Scalar(float64(o) * float64(n)), nil } @@ -304,14 +317,14 @@ func divideDists(old_dist Dist, new_dist Dist) (Dist, error) { fmt.Println("Error: Can't divide by 0.0") return nil, errors.New("Error: division by zero scalar") } - return multiplyLogDists(o, Lognormal{low: 1.0 / float64(n), high: 1.0 / float64(n)}), nil + return multiplyLogDistAndScalar(o, Scalar(1.0/n)) } } case Scalar: { switch n := new_dist.(type) { case Lognormal: - return multiplyLogDists(Lognormal{low: float64(o), high: float64(o)}, Lognormal{low: 1.0 / n.high, high: 1.0 / n.low}), nil + return multiplyLogDistAndScalar(Lognormal{low: 1.0 / n.high, high: 1.0 / n.low}, o) case Scalar: if n == 0.0 { fmt.Println("Error: Can't divide by 0.0") @@ -374,8 +387,15 @@ func parseWordsIntoOpAndDist(words []string, vars map[string]Dist) (string, Dist case 2: new_low, err1 := pretty.ParseFloat(words[0]) new_high, err2 := pretty.ParseFloat(words[1]) - if err1 != nil || err2 != nil { + switch { + case err1 != nil || err2 != nil: return parseWordsErr("Trying to operate by a distribution, but distribution is not specified as two floats") + case new_low <= 0.0 || new_high <= 0.0: + return parseWordsErr("Trying to parse two floats as a lognormal, but the two floats must be greater than 0") + case new_low == new_high: + return parseWordsErr("Trying to parse two floats as a lognormal, but the two floats must be different. Try a single scalar instead?") + case new_low > new_high: + return parseWordsErr("Trying to parse two floats as a lognormal, but the first number is larger than the second number") } dist = Lognormal{low: new_low, high: new_high} case 3: