add some guards for edge cases

scalar = 0 * logdist
scalar < 1 * logdist
logdist with parameters < 0
logdist with same parameters (should be scalar)
logdist with low > high
This commit is contained in:
NunoSempere 2024-11-20 11:54:31 +00:00
parent ae1e1bbe97
commit 84bdfa004f

View File

@ -254,6 +254,16 @@ func multiplyBetaDists(beta1 Beta, beta2 Beta) Beta {
return Beta{a: beta1.a + beta2.a, b: beta1.b + beta2.b} return Beta{a: beta1.a + beta2.a, b: beta1.b + beta2.b}
} }
func multiplyLogDistAndScalar(l Lognormal, s Scalar) (Dist, error) {
if s == 0.0 {
return Scalar(0.0), nil
} else if s < 0.0 {
return operateDistsAsSamples(s, l, "+")
} else {
return multiplyLogDists(l, Lognormal{low: float64(s), high: float64(s)}), nil
}
}
func multiplyDists(old_dist Dist, new_dist Dist) (Dist, error) { func multiplyDists(old_dist Dist, new_dist Dist) (Dist, error) {
switch o := old_dist.(type) { switch o := old_dist.(type) {
@ -263,17 +273,20 @@ func multiplyDists(old_dist Dist, new_dist Dist) (Dist, error) {
case Lognormal: case Lognormal:
return multiplyLogDists(o, n), nil return multiplyLogDists(o, n), nil
case Scalar: case Scalar:
return multiplyLogDists(o, Lognormal{low: float64(n), high: float64(n)}), nil return multiplyLogDistAndScalar(o, n)
} }
} }
case Scalar: case Scalar:
{ {
if o == 1 { switch o {
case 1.0:
return new_dist, nil return new_dist, nil
case 0.0:
return Scalar(0.0), nil
} }
switch n := new_dist.(type) { switch n := new_dist.(type) {
case Lognormal: case Lognormal:
return multiplyLogDists(Lognormal{low: float64(o), high: float64(o)}, n), nil return multiplyLogDistAndScalar(n, o)
case Scalar: case Scalar:
return Scalar(float64(o) * float64(n)), nil return Scalar(float64(o) * float64(n)), nil
} }
@ -304,14 +317,14 @@ func divideDists(old_dist Dist, new_dist Dist) (Dist, error) {
fmt.Println("Error: Can't divide by 0.0") fmt.Println("Error: Can't divide by 0.0")
return nil, errors.New("Error: division by zero scalar") return nil, errors.New("Error: division by zero scalar")
} }
return multiplyLogDists(o, Lognormal{low: 1.0 / float64(n), high: 1.0 / float64(n)}), nil return multiplyLogDistAndScalar(o, Scalar(1.0/n))
} }
} }
case Scalar: case Scalar:
{ {
switch n := new_dist.(type) { switch n := new_dist.(type) {
case Lognormal: case Lognormal:
return multiplyLogDists(Lognormal{low: float64(o), high: float64(o)}, Lognormal{low: 1.0 / n.high, high: 1.0 / n.low}), nil return multiplyLogDistAndScalar(Lognormal{low: 1.0 / n.high, high: 1.0 / n.low}, o)
case Scalar: case Scalar:
if n == 0.0 { if n == 0.0 {
fmt.Println("Error: Can't divide by 0.0") fmt.Println("Error: Can't divide by 0.0")
@ -374,8 +387,15 @@ func parseWordsIntoOpAndDist(words []string, vars map[string]Dist) (string, Dist
case 2: case 2:
new_low, err1 := pretty.ParseFloat(words[0]) new_low, err1 := pretty.ParseFloat(words[0])
new_high, err2 := pretty.ParseFloat(words[1]) new_high, err2 := pretty.ParseFloat(words[1])
if err1 != nil || err2 != nil { switch {
case err1 != nil || err2 != nil:
return parseWordsErr("Trying to operate by a distribution, but distribution is not specified as two floats") return parseWordsErr("Trying to operate by a distribution, but distribution is not specified as two floats")
case new_low <= 0.0 || new_high <= 0.0:
return parseWordsErr("Trying to parse two floats as a lognormal, but the two floats must be greater than 0")
case new_low == new_high:
return parseWordsErr("Trying to parse two floats as a lognormal, but the two floats must be different. Try a single scalar instead?")
case new_low > new_high:
return parseWordsErr("Trying to parse two floats as a lognormal, but the first number is larger than the second number")
} }
dist = Lognormal{low: new_low, high: new_high} dist = Lognormal{low: new_low, high: new_high}
case 3: case 3: