add some guards for edge cases

scalar = 0 * logdist
scalar < 1 * logdist
logdist with parameters < 0
logdist with same parameters (should be scalar)
logdist with low > high
This commit is contained in:
NunoSempere 2024-11-20 11:54:31 +00:00
parent ae1e1bbe97
commit 84bdfa004f

View File

@ -254,6 +254,16 @@ func multiplyBetaDists(beta1 Beta, beta2 Beta) Beta {
return Beta{a: beta1.a + beta2.a, b: beta1.b + beta2.b}
}
func multiplyLogDistAndScalar(l Lognormal, s Scalar) (Dist, error) {
if s == 0.0 {
return Scalar(0.0), nil
} else if s < 0.0 {
return operateDistsAsSamples(s, l, "+")
} else {
return multiplyLogDists(l, Lognormal{low: float64(s), high: float64(s)}), nil
}
}
func multiplyDists(old_dist Dist, new_dist Dist) (Dist, error) {
switch o := old_dist.(type) {
@ -263,17 +273,20 @@ func multiplyDists(old_dist Dist, new_dist Dist) (Dist, error) {
case Lognormal:
return multiplyLogDists(o, n), nil
case Scalar:
return multiplyLogDists(o, Lognormal{low: float64(n), high: float64(n)}), nil
return multiplyLogDistAndScalar(o, n)
}
}
case Scalar:
{
if o == 1 {
switch o {
case 1.0:
return new_dist, nil
case 0.0:
return Scalar(0.0), nil
}
switch n := new_dist.(type) {
case Lognormal:
return multiplyLogDists(Lognormal{low: float64(o), high: float64(o)}, n), nil
return multiplyLogDistAndScalar(n, o)
case Scalar:
return Scalar(float64(o) * float64(n)), nil
}
@ -304,14 +317,14 @@ func divideDists(old_dist Dist, new_dist Dist) (Dist, error) {
fmt.Println("Error: Can't divide by 0.0")
return nil, errors.New("Error: division by zero scalar")
}
return multiplyLogDists(o, Lognormal{low: 1.0 / float64(n), high: 1.0 / float64(n)}), nil
return multiplyLogDistAndScalar(o, Scalar(1.0/n))
}
}
case Scalar:
{
switch n := new_dist.(type) {
case Lognormal:
return multiplyLogDists(Lognormal{low: float64(o), high: float64(o)}, Lognormal{low: 1.0 / n.high, high: 1.0 / n.low}), nil
return multiplyLogDistAndScalar(Lognormal{low: 1.0 / n.high, high: 1.0 / n.low}, o)
case Scalar:
if n == 0.0 {
fmt.Println("Error: Can't divide by 0.0")
@ -374,8 +387,15 @@ func parseWordsIntoOpAndDist(words []string, vars map[string]Dist) (string, Dist
case 2:
new_low, err1 := pretty.ParseFloat(words[0])
new_high, err2 := pretty.ParseFloat(words[1])
if err1 != nil || err2 != nil {
switch {
case err1 != nil || err2 != nil:
return parseWordsErr("Trying to operate by a distribution, but distribution is not specified as two floats")
case new_low <= 0.0 || new_high <= 0.0:
return parseWordsErr("Trying to parse two floats as a lognormal, but the two floats must be greater than 0")
case new_low == new_high:
return parseWordsErr("Trying to parse two floats as a lognormal, but the two floats must be different. Try a single scalar instead?")
case new_low > new_high:
return parseWordsErr("Trying to parse two floats as a lognormal, but the first number is larger than the second number")
}
dist = Lognormal{low: new_low, high: new_high}
case 3: