diff --git a/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res b/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res index c44a972c..92a8277d 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res @@ -10,7 +10,9 @@ type rec error = | NotFinite(string, string, float) | DivideByZero(string) | NormalError(normalError) + | SafeMath(SafeMath.error) | MultipleErrors(array) + | NinetiethPercentileShouldBeOrdered module Error = { let mapErrorArrayToError = (errors: array): option => { @@ -51,8 +53,12 @@ module Normal = { let from90PercentCI = (low, high) => { let construct = () => { let mean = E.A.Floats.mean([low, high]) - let stdev = (high -. low) /. (2. *. normal95confidencePoint) - make(~mean, ~stdev) + let stdev = + SafeMath.F.divide( + ~num=high -. low, + ~denominator=2. *. normal95confidencePoint, + )->E.R2.errMap(r => SafeMath(r)) + stdev->E.R.bind(stdev => make(~mean, ~stdev)) } [ Error.checkIsFinite(low, "Normal", "low"), @@ -75,7 +81,7 @@ module Normal = { dangerouslyMake(~mean, ~stdev) } - // TODO: is this useful here at all? would need the integral as well ... + // Note: This isn't being used right now let pointwiseProduct = (n1: t, n2: t) => { let mean = (n1.mean *. n2.stdev ** 2. +. n2.mean *. n1.stdev ** 2.) /. (n1.stdev ** 2. +. n2.stdev ** 2.) @@ -104,11 +110,15 @@ module Normal = { | #Subtract => Some(make(~mean=n1.mean -. n2, ~stdev=n1.stdev)) | #Multiply => Some(make(~mean=n1.mean *. n2, ~stdev=n1.stdev *. Js.Math.abs_float(n2))) | #Divide => - [Error.divideByZero(n2, "Normal operateFloatSecond")] - ->ifNoErrorsThanDo(() => { - make(~mean=n1.mean /. n2, ~stdev=n1.stdev /. Js.Math.abs_float(n2)) - }) - ->Some + { + let mean = SafeMath.F.divide(~num=n1.mean, ~denominator=n2)->E.R2.errMap(r => SafeMath(r)) + let stdev = + SafeMath.F.divide( + ~num=n1.stdev, + ~denominator=Js.Math.abs_float(n2), + )->E.R2.errMap(r => SafeMath(r)) + E.R.merge(mean, stdev)->E.R.bind(((mean, stdev)) => make(~mean, ~stdev)) + }->Some | _ => None } } @@ -279,7 +289,7 @@ module From90thPercentile = { switch (low, high) { | (low, high) if low <= 0.0 && low < high => Normal.from90PercentCI(low, high) | (low, high) if low < high => Ok(Lognormal.from90PercentCI(low, high)) - | (_, _) => Error("Low value must be less than high value.") + | (_, _) => Error(NinetiethPercentileShouldBeOrdered) } } diff --git a/packages/squiggle-lang/src/rescript/Utility/SafeMath.res b/packages/squiggle-lang/src/rescript/Utility/SafeMath.res new file mode 100644 index 00000000..e5517e7e --- /dev/null +++ b/packages/squiggle-lang/src/rescript/Utility/SafeMath.res @@ -0,0 +1,38 @@ +@genType +type rec error = + | DivideByZero + | IsNaN + | IsInfinite + | NotFinite(string, float) + +module Error = { + let checkIsFinite = (value, fnName) => + E.Float.isFinite(value) ? None : Some(NotFinite(fnName, value)) + + let divideByZero = value => 0.0 == value ? None : Some(DivideByZero) +} + +module Float = { + type t = float + let make = (v: t) => + switch Error.checkIsFinite(v, "toSafe") { + | None => Ok(v) + | Some(e) => Error(e) + } +} + +module F = { + type t = float + let divide = (~num: t, ~denominator: t) => { + if denominator == 0.0 { + Error(DivideByZero) + } else { + let result = num /. denominator + if E.Float.isFinite(result) { + Ok(result) + } else { + Error(NotFinite("divide", result)) + } + } + } +}