diff --git a/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res b/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res index 61babfce..c44a972c 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res @@ -8,6 +8,7 @@ type normalError = NormalStandardDeviationGreaterThanZero(float) @genType type rec error = | NotFinite(string, string, float) + | DivideByZero(string) | NormalError(normalError) | MultipleErrors(array) @@ -19,38 +20,44 @@ module Error = { | _ => Some(MultipleErrors(errors)) } } + + let checkIsFinite = (value, fnName, propertyName) => + E.Float.isFinite(value) ? None : Some(NotFinite(fnName, propertyName, value)) + + let divideByZero = (value, fnName) => 0.0 == value ? None : Some(DivideByZero(fnName)) } +let ifNoErrorsThanDo = (errors, fn) => + errors->E.A.O.concatSomes->Error.mapErrorArrayToError->E.O.errorToResult(fn) + module Normal = { type t = normal - let make = (~mean: float, ~stdev: float): result => { - let firstElementFinite = Js.Float.isFinite(mean) - ? Some(NotFinite("Normal", "mean", mean)) - : None - let secondElementFinite = Js.Float.isFinite(stdev) - ? Some(NotFinite("Normal", "mean", mean)) - : None - let stdevError = - stdev <= 0.0 ? Some(NormalError(NormalStandardDeviationGreaterThanZero(stdev))) : None - - let error = - [firstElementFinite, secondElementFinite, stdevError] - ->E.A.O.concatSomes - ->Error.mapErrorArrayToError - - switch error { - | Some(r) => Error(r) - | None => Ok(#Normal({mean: mean, stdev: stdev})) - } - } let dangerouslyMake = (~mean: float, ~stdev: float) => #Normal({mean: mean, stdev: stdev}) + let inputValidation = (~mean, ~stdev) => + [ + Error.checkIsFinite(mean, "Normal", "mean"), + Error.checkIsFinite(stdev, "Normal", "stdev"), + stdev <= 0.0 ? Some(NormalError(NormalStandardDeviationGreaterThanZero(stdev))) : None, + ] + ->E.A.O.concatSomes + ->Error.mapErrorArrayToError + + let make = (~mean: float, ~stdev: float): result => + inputValidation(~mean, ~stdev)->E.O.errorToResult(() => Ok(dangerouslyMake(~mean, ~stdev))) + let pdf = (x, t: t) => Jstat.Normal.pdf(x, t.mean, t.stdev) let cdf = (x, t: t) => Jstat.Normal.cdf(x, t.mean, t.stdev) let from90PercentCI = (low, high) => { - let mean = E.A.Floats.mean([low, high]) - let stdev = (high -. low) /. (2. *. normal95confidencePoint) - dangerouslyMake(~mean, ~stdev) + let construct = () => { + let mean = E.A.Floats.mean([low, high]) + let stdev = (high -. low) /. (2. *. normal95confidencePoint) + make(~mean, ~stdev) + } + [ + Error.checkIsFinite(low, "Normal", "low"), + Error.checkIsFinite(high, "Normal", "high"), + ]->ifNoErrorsThanDo(construct) } let inv = (p, t: t) => Jstat.Normal.inv(p, t.mean, t.stdev) let sample = (t: t) => Jstat.Normal.sample(t.mean, t.stdev) @@ -96,7 +103,12 @@ module Normal = { | #Add => Some(make(~mean=n1.mean +. n2, ~stdev=n1.stdev)) | #Subtract => Some(make(~mean=n1.mean -. n2, ~stdev=n1.stdev)) | #Multiply => Some(make(~mean=n1.mean *. n2, ~stdev=n1.stdev *. Js.Math.abs_float(n2))) - | #Divide => Some(make(~mean=n1.mean /. n2, ~stdev=n1.stdev /. Js.Math.abs_float(n2))) + | #Divide => + [Error.divideByZero(n2, "Normal operateFloatSecond")] + ->ifNoErrorsThanDo(() => { + make(~mean=n1.mean /. n2, ~stdev=n1.stdev /. Js.Math.abs_float(n2)) + }) + ->Some | _ => None } } @@ -265,7 +277,7 @@ module Float = { module From90thPercentile = { let make = (low, high) => switch (low, high) { - | (low, high) if low <= 0.0 && low < high => Ok(Normal.from90PercentCI(low, high)) + | (low, high) if low <= 0.0 && low < high => Normal.from90PercentCI(low, high) | (low, high) if low < high => Ok(Lognormal.from90PercentCI(low, high)) | (_, _) => Error("Low value must be less than high value.") } diff --git a/packages/squiggle-lang/src/rescript/Utility/E.res b/packages/squiggle-lang/src/rescript/Utility/E.res index e0bcaf5c..217ac1b6 100644 --- a/packages/squiggle-lang/src/rescript/Utility/E.res +++ b/packages/squiggle-lang/src/rescript/Utility/E.res @@ -136,6 +136,13 @@ module O = { | None => Error(error) } + let errorToResult = (error: option<'a>, fn) => { + switch error { + | None => fn() + | Some(e) => Error(e) + } + } + let compare = (compare, f1: option, f2: option) => switch (f1, f2) { | (Some(f1), Some(f2)) => Some(compare(f1, f2) ? f1 : f2) @@ -198,6 +205,13 @@ module Float = { let with3DigitsPrecision = Js.Float.toPrecisionWithPrecision(_, ~digits=3) let toFixed = Js.Float.toFixed let toString = Js.Float.toString + let isFinite = Js.Float.isFinite + let safeDivision = (num: float, denom: float) => + if denom == 0.0 { + Error("Division by zero") + } else { + Ok(num /. denom) + } } module I = {