diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res index a9f7dfbe..26766b6e 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res @@ -6,13 +6,63 @@ type genericDist = type asAlgebraicCombinationStrategy = AsDefault | AsSymbolic | AsMonteCarlo | AsConvolution +type expressionType = + | ArrayType + | ArrayStringType + | BoolType + | CallType + | DistributionType + | LambdaType + | NumberType + | RecordType + | StringType + | SymbolType + +type argumentError = + | WrongTypeError(expressionType, expressionType) + | IncorrectNumberOfArgumentsError(int, int) + | MustBeFinite + | MustBePositive + | OtherArgumentError(string) + +module ArgumentError = { + + type t = argumentError + + let expressionTypeToString = (eType: expressionType): string => + switch eType { + | ArrayType => "array" + | ArrayStringType => "arraystring" + | BoolType => "boolean" + | CallType => "call" + | DistributionType => "distribution" + | LambdaType => "lambda" + | NumberType => "number" + | RecordType => "record" + | StringType => "string" + | SymbolType => "symbol" + } + + let toString = (err: t) : string => + switch err { + | WrongTypeError(expected, actual) => + `Argument has wrong type. Expected ${expressionTypeToString(expected)} but got ${expressionTypeToString(actual)}` + | IncorrectNumberOfArgumentsError(expected, actual) => `Expected ${Belt.Int.toString(expected)} arguments but got ${Belt.Int.toString(actual)}` + | MustBeFinite => "Argument must be finite" + | MustBePositive => "Argument must be positive" + | OtherArgumentError(msg) => msg + } +} + + + @genType type error = | NotYetImplemented | Unreachable | DistributionVerticalShiftIsInvalid | SampleSetError(SampleSetDist.sampleSetError) - | ArgumentError(string) + | ArgumentError(ArgumentError.t) | OperationError(Operation.Error.t) | PointSetConversionError(SampleSetDist.pointsetConversionError) | SparklineError(PointSetTypes.sparklineError) // This type of error is for when we find a sparkline of a discrete distribution. This should probably at some point be actually implemented @@ -33,7 +83,7 @@ module Error = { | NotYetImplemented => "Function Not Yet Implemented" | Unreachable => "Unreachable" | DistributionVerticalShiftIsInvalid => "Distribution Vertical Shift is Invalid" - | ArgumentError(s) => `Argument Error ${s}` + | ArgumentError(s) => ArgumentError.toString(s) | LogarithmOfDistributionError(s) => `Logarithm of input error: ${s}` | SampleSetError(TooFewSamples) => "Too Few Samples" | SampleSetError(NonNumericInput(err)) => `Found a non-number in input: ${err}` @@ -51,6 +101,7 @@ module Error = { let sampleErrorToDistErr = (err: SampleSetDist.sampleSetError): error => SampleSetError(err) } + @genType module DistributionOperation = { @genType diff --git a/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res b/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res index c925e39e..468773b8 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res @@ -5,10 +5,8 @@ let normal95confidencePoint = 1.6448536269514722 module Normal = { type t = normal - let make = (mean: float, stdev: float): result => - stdev > 0.0 - ? Ok(#Normal({mean: mean, stdev: stdev})) - : Error("Standard deviation of normal distribution must be larger than 0") + let make = (mean: SafeFloat.finite, stdev: SafeFloat.positive): symbolicDist => + #Normal({mean: SafeFloat.Finite.toFloat(mean), stdev: SafeFloat.Positive.toFloat(stdev)}) let pdf = (x, t: t) => Jstat.Normal.pdf(x, t.mean, t.stdev) let cdf = (x, t: t) => Jstat.Normal.cdf(x, t.mean, t.stdev) @@ -68,14 +66,10 @@ module Normal = { module Exponential = { type t = exponential - let make = (rate: float): result => - rate > 0.0 - ? Ok( + let make = (rate: SafeFloat.positive): symbolicDist => #Exponential({ - rate: rate, - }), - ) - : Error("Exponential distributions rate must be larger than 0.") + rate: SafeFloat.Positive.toFloat(rate), + }) let pdf = (x, t: t) => Jstat.Exponential.pdf(x, t.rate) let cdf = (x, t: t) => Jstat.Exponential.cdf(x, t.rate) let inv = (p, t: t) => Jstat.Exponential.inv(p, t.rate) @@ -86,10 +80,8 @@ module Exponential = { module Cauchy = { type t = cauchy - let make = (local, scale): result => - scale > 0.0 - ? Ok(#Cauchy({local: local, scale: scale})) - : Error("Cauchy distribution scale parameter must larger than 0.") + let make = (local: SafeFloat.finite, scale: SafeFloat.positive): symbolicDist => + #Cauchy({local: SafeFloat.Finite.toFloat(local), scale: SafeFloat.Positive.toFloat(scale)}) let pdf = (x, t: t) => Jstat.Cauchy.pdf(x, t.local, t.scale) let cdf = (x, t: t) => Jstat.Cauchy.cdf(x, t.local, t.scale) let inv = (p, t: t) => Jstat.Cauchy.inv(p, t.local, t.scale) @@ -100,10 +92,12 @@ module Cauchy = { module Triangular = { type t = triangular - let make = (low, medium, high): result => - low < medium && medium < high - ? Ok(#Triangular({low: low, medium: medium, high: high})) - : Error("Triangular values must be increasing order.") + let make = (low: SafeFloat.finite, medium: SafeFloat.finite, high: SafeFloat.finite): result =>{ + let (l, m, h) = (SafeFloat.Finite.toFloat(low), SafeFloat.Finite.toFloat(medium), SafeFloat.Finite.toFloat(high)) + l < m && m < h + ? Ok(#Triangular({low: l, medium: m, high: h})) + : Error(OtherArgumentError("Triangular values must be increasing order.")) + } let pdf = (x, t: t) => Jstat.Triangular.pdf(x, t.low, t.high, t.medium) // not obvious in jstat docs that high comes before medium? let cdf = (x, t: t) => Jstat.Triangular.cdf(x, t.low, t.high, t.medium) let inv = (p, t: t) => Jstat.Triangular.inv(p, t.low, t.high, t.medium) @@ -114,10 +108,8 @@ module Triangular = { module Beta = { type t = beta - let make = (alpha, beta) => - alpha > 0.0 && beta > 0.0 - ? Ok(#Beta({alpha: alpha, beta: beta})) - : Error("Beta distribution parameters must be positive") + let make = (alpha: SafeFloat.positive, beta: SafeFloat.positive) => + #Beta({alpha: SafeFloat.Positive.toFloat(alpha), beta: SafeFloat.Positive.toFloat(beta)}) let pdf = (x, t: t) => Jstat.Beta.pdf(x, t.alpha, t.beta) let cdf = (x, t: t) => Jstat.Beta.cdf(x, t.alpha, t.beta) let inv = (p, t: t) => Jstat.Beta.inv(p, t.alpha, t.beta) @@ -128,10 +120,8 @@ module Beta = { module Lognormal = { type t = lognormal - let make = (mu, sigma) => - sigma > 0.0 - ? Ok(#Lognormal({mu: mu, sigma: sigma})) - : Error("Lognormal standard deviation must be larger than 0") + let make = (mu: SafeFloat.finite, sigma: SafeFloat.positive) => + #Lognormal({mu: SafeFloat.Finite.toFloat(mu), sigma: SafeFloat.Positive.toFloat(sigma)}) let pdf = (x, t: t) => Jstat.Lognormal.pdf(x, t.mu, t.sigma) let cdf = (x, t: t) => Jstat.Lognormal.cdf(x, t.mu, t.sigma) let inv = (p, t: t) => Jstat.Lognormal.inv(p, t.mu, t.sigma) @@ -199,8 +189,16 @@ module Lognormal = { module Uniform = { type t = uniform - let make = (low, high) => - high > low ? Ok(#Uniform({low: low, high: high})) : Error("High must be larger than low") + let make = (low: SafeFloat.finite, high: SafeFloat.finite) => { + let l = SafeFloat.Finite.toFloat(low) + let h = SafeFloat.Finite.toFloat(high) + if h > l { + Ok(#Uniform({low: l, high: h})) + } + else { + Error(DistributionTypes.OtherArgumentError("High must be larger than low")) + } + } let pdf = (x, t: t) => Jstat.Uniform.pdf(x, t.low, t.high) let cdf = (x, t: t) => Jstat.Uniform.cdf(x, t.low, t.high) @@ -218,16 +216,8 @@ module Uniform = { module Gamma = { type t = gamma - let make = (shape: float, scale: float) => { - if shape > 0. { - if scale > 0. { - Ok(#Gamma({shape: shape, scale: scale})) - } else { - Error("scale must be larger than 0") - } - } else { - Error("shape must be larger than 0") - } + let make = (shape: SafeFloat.positive, scale: SafeFloat.positive) => { + #Gamma({shape: SafeFloat.Positive.toFloat(shape), scale: SafeFloat.Positive.toFloat(scale)}) } let pdf = (x: float, t: t) => Jstat.Gamma.pdf(x, t.shape, t.scale) let cdf = (x: float, t: t) => Jstat.Gamma.cdf(x, t.shape, t.scale) @@ -255,12 +245,16 @@ module Float = { } module From90thPercentile = { - let make = (low, high) => - switch (low, high) { - | (low, high) if low <= 0.0 && low < high => Ok(Normal.from90PercentCI(low, high)) - | (low, high) if low < high => Ok(Lognormal.from90PercentCI(low, high)) - | (_, _) => Error("Low value must be less than high value.") + let make = (low: SafeFloat.positive, high: SafeFloat.positive) : result => { + let l = SafeFloat.Positive.toFloat(low) + let h = SafeFloat.Positive.toFloat(high) + if l < h { + Ok(Lognormal.from90PercentCI(l, h)) } + else { + Error(OtherArgumentError("Low value must be less than high value.")) + } + } } module T = { diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_FunctionParser.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_FunctionParser.res new file mode 100644 index 00000000..18263792 --- /dev/null +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_FunctionParser.res @@ -0,0 +1,103 @@ +type expressionValue = ReducerInterface_ExpressionValue.expressionValue +type error = DistributionTypes.error +type argumentError = DistributionTypes.ArgumentError.t + +let expressionValueToType = (value: expressionValue): DistributionTypes.expressionType => + switch value { + | EvArray(_) => ArrayType + | EvArrayString(_) => ArrayStringType + | EvBool(_) => BoolType + | EvCall(_) => CallType + | EvDistribution(_) => DistributionType + | EvLambda(_) => LambdaType + | EvNumber(_) => NumberType + | EvRecord(_) => RecordType + | EvString(_) => StringType + | EvSymbol(_) => SymbolType + } + +module Primitive = { + let distribution = (argument: expressionValue): result => + switch argument { + | EvDistribution(dist) => Ok(dist) + | _ => Error(WrongTypeError(DistributionType, expressionValueToType(argument))) + } + + let finite = (argument: expressionValue): result => + switch argument { + | EvNumber(num) => + switch SafeFloat.Finite.make(num) { + | Some(safeNum) => Ok(safeNum) + | None => Error(MustBeFinite) + } + | _ => Error(WrongTypeError(NumberType, expressionValueToType(argument))) + } + + let positive = (argument: expressionValue): result => + switch argument { + | EvNumber(num) => + switch SafeFloat.Positive.make(num) { + | Some(safeNum) => Ok(safeNum) + | None => Error(MustBePositive) + } + | _ => Error(WrongTypeError(NumberType, expressionValueToType(argument))) + } +} + +module Functions = { + let function1 = ( + f: 'a => 'b, + parseArg1: expressionValue => result<'a, argumentError>, + args: array, + ): result<'b, argumentError> => + switch args { + | [arg1] => E.R.fmap(f, parseArg1(arg1)) + | _ => Error(IncorrectNumberOfArgumentsError(1, E.A.length(args))) + } + + + let function2Bind = ( + f: ('a, 'b) => result<'c, argumentError>, + parseArg1: expressionValue => result<'a, argumentError>, + parseArg2: expressionValue => result<'b, argumentError>, + args: array, + ): result<'c, argumentError> => + switch args { + | [arg1, arg2] => E.R.merge(parseArg1(arg1), parseArg2(arg2)) -> E.R.bind(((a, b)) => f(a, b)) + | _ => Error(IncorrectNumberOfArgumentsError(2, E.A.length(args))) + } + + let function2 = ( + f: ('a, 'b) => 'c, + parseArg1: expressionValue => result<'a, argumentError>, + parseArg2: expressionValue => result<'b, argumentError>, + args: array, + ): result<'c, argumentError> => + function2Bind((a, b) => Ok(f(a, b)), parseArg1, parseArg2, args) + + let function3Bind = ( + f: ('a, 'b, 'c) => result<'d, argumentError>, + parseArg1: expressionValue => result<'a, argumentError>, + parseArg2: expressionValue => result<'b, argumentError>, + parseArg3: expressionValue => result<'c, argumentError>, + args: array, + ): result<'d, argumentError> => + switch args { + | [arg1, arg2, arg3] => E.R.merge3(parseArg1(arg1), parseArg2(arg2), parseArg3(arg3)) -> E.R.bind(((a, b, c)) => f(a, b, c)) + | _ => Error(IncorrectNumberOfArgumentsError(2, E.A.length(args))) + } +} + +type function = Function(string, array => result) + +let allFunctions: array = + [ Function("exponential", Functions.function1( SymbolicDist.Exponential.make, Primitive.positive)) + , Function("normal", Functions.function2( SymbolicDist.Normal.make, Primitive.finite, Primitive.positive,)) + , Function("uniform", Functions.function2Bind( SymbolicDist.Uniform.make, Primitive.finite, Primitive.finite,)) + , Function("beta", Functions.function2( SymbolicDist.Beta.make, Primitive.positive, Primitive.positive,)) + , Function("lognormal", Functions.function2( SymbolicDist.Lognormal.make, Primitive.finite, Primitive.positive,)) + , Function("cauchy", Functions.function2( SymbolicDist.Cauchy.make, Primitive.finite, Primitive.positive,)) + , Function("gamma", Functions.function2( SymbolicDist.Gamma.make, Primitive.positive, Primitive.positive,)) + , Function("to", Functions.function2Bind( SymbolicDist.From90thPercentile.make, Primitive.positive, Primitive.positive,)) + , Function("triangular", Functions.function3Bind( SymbolicDist.Triangular.make, Primitive.finite, Primitive.finite, Primitive.finite,)) +] diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res index 73614aee..44c38687 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res @@ -1,5 +1,6 @@ module ExpressionValue = ReducerInterface_ExpressionValue type expressionValue = ReducerInterface_ExpressionValue.expressionValue +type argumentError = DistributionTypes.argumentError module Helpers = { let arithmeticMap = r => @@ -80,25 +81,25 @@ module Helpers = { )->DistributionOperation.run(~env) } - let parseNumber = (args: expressionValue): Belt.Result.t => + let parseNumber = (args: expressionValue): Belt.Result.t => switch args { | EvNumber(x) => Ok(x) - | _ => Error("Not a number") + | _ => Error(OtherArgumentError("Not a number")) } - let parseNumberArray = (ags: array): Belt.Result.t, string> => + let parseNumberArray = (ags: array): Belt.Result.t, argumentError> => E.A.fmap(parseNumber, ags) |> E.A.R.firstErrorOrOpen - let parseDist = (args: expressionValue): Belt.Result.t => + let parseDist = (args: expressionValue): Belt.Result.t => switch args { | EvDistribution(x) => Ok(x) | EvNumber(x) => Ok(GenericDist.fromFloat(x)) - | _ => Error("Not a distribution") + | _ => Error(OtherArgumentError("Not a distribution")) } let parseDistributionArray = (ags: array): Belt.Result.t< array, - string, + argumentError, > => E.A.fmap(parseDist, ags) |> E.A.R.firstErrorOrOpen let mixtureWithGivenWeights = ( @@ -109,7 +110,7 @@ module Helpers = { E.A.length(distributions) == E.A.length(weights) ? Mixture(Belt.Array.zip(distributions, weights))->DistributionOperation.run(~env) : GenDistError( - ArgumentError("Error, mixture call has different number of distributions and weights"), + ArgumentError(OtherArgumentError("Error, mixture call has different number of distributions and weights")) ) let mixtureWithDefaultWeights = ( @@ -125,7 +126,7 @@ module Helpers = { args: array, ~env: DistributionOperation.env, ): DistributionOperation.outputType => { - let error = (err: string): DistributionOperation.outputType => + let error = (err: DistributionTypes.argumentError): DistributionOperation.outputType => err->DistributionTypes.ArgumentError->GenDistError switch args { | [EvArray(distributions)] => @@ -138,7 +139,7 @@ module Helpers = { | (Ok(distrs), Ok(wghts)) => mixtureWithGivenWeights(distrs, wghts, ~env) | (Error(err), Ok(_)) => error(err) | (Ok(_), Error(err)) => error(err) - | (Error(err1), Error(err2)) => error(`${err1}|${err2}`) + | (Error(err1), Error(err2)) => error(err1) } | _ => switch E.A.last(args) { @@ -158,36 +159,19 @@ module Helpers = { | Ok(distributions) => mixtureWithDefaultWeights(distributions, ~env) | Error(err) => error(err) } - | _ => error("Last argument of mx must be array or distribution") + | _ => error(OtherArgumentError("Last argument of mx must be array or distribution")) } } } } module SymbolicConstructors = { - let oneFloat = name => - switch name { - | "exponential" => Ok(SymbolicDist.Exponential.make) - | _ => Error("Unreachable state") - } - - let twoFloat = name => - switch name { - | "normal" => Ok(SymbolicDist.Normal.make) - | "uniform" => Ok(SymbolicDist.Uniform.make) - | "beta" => Ok(SymbolicDist.Beta.make) - | "lognormal" => Ok(SymbolicDist.Lognormal.make) - | "cauchy" => Ok(SymbolicDist.Cauchy.make) - | "gamma" => Ok(SymbolicDist.Gamma.make) - | "to" => Ok(SymbolicDist.From90thPercentile.make) - | _ => Error("Unreachable state") - } - - let threeFloat = name => - switch name { - | "triangular" => Ok(SymbolicDist.Triangular.make) - | _ => Error("Unreachable state") - } + let checkSymbolicConstructors = (call: ExpressionValue.functionCall) : option> => { + let (fnName, args) = call + let function = E.A.find((ReducerInterface_FunctionParser.Function(name, argsParser)) => name == fnName, ReducerInterface_FunctionParser.allFunctions) + E.O.fmap((ReducerInterface_FunctionParser.Function(_, argsParser)) => argsParser(args), function) + } + let symbolicResultToOutput = ( symbolicResult: result, @@ -204,23 +188,8 @@ let dispatchToGenericOutput = ( ): option => { let (fnName, args) = call switch (fnName, args) { - | ("exponential" as fnName, [EvNumber(f)]) => - SymbolicConstructors.oneFloat(fnName) - ->E.R.bind(r => r(f)) - ->SymbolicConstructors.symbolicResultToOutput | ("delta", [EvNumber(f)]) => SymbolicDist.Float.makeSafe(f)->SymbolicConstructors.symbolicResultToOutput - | ( - ("normal" | "uniform" | "beta" | "lognormal" | "cauchy" | "gamma" | "to") as fnName, - [EvNumber(f1), EvNumber(f2)], - ) => - SymbolicConstructors.twoFloat(fnName) - ->E.R.bind(r => r(f1, f2)) - ->SymbolicConstructors.symbolicResultToOutput - | ("triangular" as fnName, [EvNumber(f1), EvNumber(f2), EvNumber(f3)]) => - SymbolicConstructors.threeFloat(fnName) - ->E.R.bind(r => r(f1, f2, f3)) - ->SymbolicConstructors.symbolicResultToOutput | ("sample", [EvDistribution(dist)]) => Helpers.toFloatFn(#Sample, dist, ~env) | ("mean", [EvDistribution(dist)]) => Helpers.toFloatFn(#Mean, dist, ~env) | ("integralSum", [EvDistribution(dist)]) => Helpers.toFloatFn(#IntegralSum, dist, ~env) @@ -323,7 +292,7 @@ let dispatchToGenericOutput = ( a, ~env, )->Some - | _ => None + | _ => checkSymbolicConstructors(call) } } diff --git a/packages/squiggle-lang/src/rescript/Utility/E.res b/packages/squiggle-lang/src/rescript/Utility/E.res index 15678e1a..3975db00 100644 --- a/packages/squiggle-lang/src/rescript/Utility/E.res +++ b/packages/squiggle-lang/src/rescript/Utility/E.res @@ -253,6 +253,13 @@ module R = { | (_, Error(e)) => Error(e) | (Ok(a), Ok(b)) => Ok((a, b)) } + let merge3 = (a, b, c) => + switch (a, b, c) { + | (Error(e), _, _) => Error(e) + | (_, Error(e), _) => Error(e) + | (_, _, Error(e)) => Error(e) + | (Ok(a), Ok(b), Ok(c)) => Ok((a, b, c)) + } let toOption = (e: Belt.Result.t<'a, 'b>) => switch e { | Ok(r) => Some(r) @@ -531,10 +538,13 @@ module A = { let keepMap = Belt.Array.keepMap let slice = Belt.Array.slice let init = Array.init + let filter = (fn, xs) => Belt.Array.keep(xs, fn) let reduce = Belt.Array.reduce let reducei = Belt.Array.reduceWithIndex let isEmpty = r => length(r) < 1 let stableSortBy = Belt.SortArray.stableSortBy + let find = (f: 'a => bool, xs: array<'a>) => first(filter(f, xs)) + let toRanges = (a: array<'a>) => switch a |> Belt.Array.length { | 0 diff --git a/packages/squiggle-lang/src/rescript/Utility/SafeFloat.res b/packages/squiggle-lang/src/rescript/Utility/SafeFloat.res new file mode 100644 index 00000000..04b954f2 --- /dev/null +++ b/packages/squiggle-lang/src/rescript/Utility/SafeFloat.res @@ -0,0 +1,34 @@ +type finite = Finite(float) +module Finite = { + type t = finite + let valid = Js.Float.isFinite + let make = (x: float) : option => + if valid(x) { + Some(Finite(x)) + } + else { + None + } + let toFloat = (x: t) => + switch x { + | Finite(inner) => inner + } +} + +type positive = Positive(float) +module Positive = { + type t = positive + let valid = (x: float) => Finite.valid(x) && x > 0. + let make = (x: float) : option => + if valid(x) { + Some(Positive(x)) + } + else { + None + } + + let toFloat = (x: t) => + switch x { + | Positive(inner) => inner + } +}