diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res index 8dae2586..a50fbede 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res @@ -6,8 +6,6 @@ let defaultEnv: DistributionOperation.env = { xyPointLength: MagicNumbers.Environment.defaultXYPointLength, } -let runGenericOperation = DistributionOperation.run(~env=defaultEnv) - module Helpers = { let arithmeticMap = r => switch r { @@ -39,37 +37,44 @@ module Helpers = { let toFloatFn = ( fnCall: DistributionTypes.DistributionOperation.toFloat, dist: DistributionTypes.genericDist, + ~env: DistributionOperation.env, ) => { FromDist(DistributionTypes.DistributionOperation.ToFloat(fnCall), dist) - ->runGenericOperation + ->DistributionOperation.run(~env) ->Some } let toStringFn = ( fnCall: DistributionTypes.DistributionOperation.toString, dist: DistributionTypes.genericDist, + ~env: DistributionOperation.env, ) => { FromDist(DistributionTypes.DistributionOperation.ToString(fnCall), dist) - ->runGenericOperation + ->DistributionOperation.run(~env) ->Some } let toBoolFn = ( fnCall: DistributionTypes.DistributionOperation.toBool, dist: DistributionTypes.genericDist, + ~env: DistributionOperation.env, ) => { FromDist(DistributionTypes.DistributionOperation.ToBool(fnCall), dist) - ->runGenericOperation + ->DistributionOperation.run(~env) ->Some } - let toDistFn = (fnCall: DistributionTypes.DistributionOperation.toDist, dist) => { + let toDistFn = ( + fnCall: DistributionTypes.DistributionOperation.toDist, + dist, + ~env: DistributionOperation.env, + ) => { FromDist(DistributionTypes.DistributionOperation.ToDist(fnCall), dist) - ->runGenericOperation + ->DistributionOperation.run(~env) ->Some } - let twoDiststoDistFn = (direction, arithmetic, dist1, dist2) => { + let twoDiststoDistFn = (direction, arithmetic, dist1, dist2, ~env: DistributionOperation.env) => { FromDist( DistributionTypes.DistributionOperation.ToDistCombination( direction, @@ -77,7 +82,7 @@ module Helpers = { #Dist(dist2), ), dist1, - )->runGenericOperation + )->DistributionOperation.run(~env) } let parseNumber = (args: expressionValue): Belt.Result.t => @@ -104,33 +109,38 @@ module Helpers = { let mixtureWithGivenWeights = ( distributions: array, weights: array, + ~env: DistributionOperation.env, ): DistributionOperation.outputType => E.A.length(distributions) == E.A.length(weights) - ? Mixture(Belt.Array.zip(distributions, weights))->runGenericOperation + ? Mixture(Belt.Array.zip(distributions, weights))->DistributionOperation.run(~env) : GenDistError( ArgumentError("Error, mixture call has different number of distributions and weights"), ) let mixtureWithDefaultWeights = ( distributions: array, + ~env: DistributionOperation.env, ): DistributionOperation.outputType => { let length = E.A.length(distributions) let weights = Belt.Array.make(length, 1.0 /. Belt.Int.toFloat(length)) - mixtureWithGivenWeights(distributions, weights) + mixtureWithGivenWeights(distributions, weights, ~env) } - let mixture = (args: array): DistributionOperation.outputType => { + let mixture = ( + args: array, + ~env: DistributionOperation.env, + ): DistributionOperation.outputType => { let error = (err: string): DistributionOperation.outputType => err->DistributionTypes.ArgumentError->GenDistError switch args { | [EvArray(distributions)] => switch parseDistributionArray(distributions) { - | Ok(distrs) => mixtureWithDefaultWeights(distrs) + | Ok(distrs) => mixtureWithDefaultWeights(distrs, ~env) | Error(err) => error(err) } | [EvArray(distributions), EvArray(weights)] => switch (parseDistributionArray(distributions), parseNumberArray(weights)) { - | (Ok(distrs), Ok(wghts)) => mixtureWithGivenWeights(distrs, wghts) + | (Ok(distrs), Ok(wghts)) => mixtureWithGivenWeights(distrs, wghts, ~env) | (Error(err), Ok(_)) => error(err) | (Ok(_), Error(err)) => error(err) | (Error(err1), Error(err2)) => error(`${err1}|${err2}`) @@ -143,14 +153,14 @@ module Helpers = { Belt.Array.slice(args, ~offset=0, ~len=E.A.length(args) - 1), ) switch E.R.merge(distributions, weights) { - | Ok(d, w) => mixtureWithGivenWeights(d, w) + | Ok(d, w) => mixtureWithGivenWeights(d, w, ~env) | Error(err) => error(err) } } | Some(EvNumber(_)) | Some(EvDistribution(_)) => switch parseDistributionArray(args) { - | Ok(distributions) => mixtureWithDefaultWeights(distributions) + | Ok(distributions) => mixtureWithDefaultWeights(distributions, ~env) | Error(err) => error(err) } | _ => error("Last argument of mx must be array or distribution") @@ -193,9 +203,10 @@ module SymbolicConstructors = { } } -let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment): option< - DistributionOperation.outputType, -> => { +let dispatchToGenericOutput = ( + call: ExpressionValue.functionCall, + env: DistributionOperation.env, +): option => { let (fnName, args) = call switch (fnName, args) { | ("exponential" as fnName, [EvNumber(f)]) => @@ -215,13 +226,13 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment) SymbolicConstructors.threeFloat(fnName) ->E.R.bind(r => r(f1, f2, f3)) ->SymbolicConstructors.symbolicResultToOutput - | ("sample", [EvDistribution(dist)]) => Helpers.toFloatFn(#Sample, dist) - | ("mean", [EvDistribution(dist)]) => Helpers.toFloatFn(#Mean, dist) - | ("integralSum", [EvDistribution(dist)]) => Helpers.toFloatFn(#IntegralSum, dist) - | ("toString", [EvDistribution(dist)]) => Helpers.toStringFn(ToString, dist) - | ("toSparkline", [EvDistribution(dist)]) => Helpers.toStringFn(ToSparkline(20), dist) + | ("sample", [EvDistribution(dist)]) => Helpers.toFloatFn(#Sample, dist, ~env) + | ("mean", [EvDistribution(dist)]) => Helpers.toFloatFn(#Mean, dist, ~env) + | ("integralSum", [EvDistribution(dist)]) => Helpers.toFloatFn(#IntegralSum, dist, ~env) + | ("toString", [EvDistribution(dist)]) => Helpers.toStringFn(ToString, dist, ~env) + | ("toSparkline", [EvDistribution(dist)]) => Helpers.toStringFn(ToSparkline(20), dist, ~env) | ("toSparkline", [EvDistribution(dist), EvNumber(n)]) => - Helpers.toStringFn(ToSparkline(Belt.Float.toInt(n)), dist) + Helpers.toStringFn(ToSparkline(Belt.Float.toInt(n)), dist, ~env) | ("exp", [EvDistribution(a)]) => // https://mathjs.org/docs/reference/functions/exp.html Helpers.twoDiststoDistFn( @@ -229,60 +240,74 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment) "pow", GenericDist.fromFloat(MagicNumbers.Math.e), a, + ~env, )->Some - | ("normalize", [EvDistribution(dist)]) => Helpers.toDistFn(Normalize, dist) + | ("normalize", [EvDistribution(dist)]) => Helpers.toDistFn(Normalize, dist, ~env) | ("klDivergence", [EvDistribution(a), EvDistribution(b)]) => - Some(runGenericOperation(FromDist(ToScore(KLDivergence(b)), a))) - | ("isNormalized", [EvDistribution(dist)]) => Helpers.toBoolFn(IsNormalized, dist) - | ("toPointSet", [EvDistribution(dist)]) => Helpers.toDistFn(ToPointSet, dist) + Some(DistributionOperation.run(FromDist(ToScore(KLDivergence(b)), a), ~env)) + | ("isNormalized", [EvDistribution(dist)]) => Helpers.toBoolFn(IsNormalized, dist, ~env) + | ("toPointSet", [EvDistribution(dist)]) => Helpers.toDistFn(ToPointSet, dist, ~env) | ("scaleLog", [EvDistribution(dist)]) => - Helpers.toDistFn(Scale(#Logarithm, MagicNumbers.Math.e), dist) - | ("scaleLog10", [EvDistribution(dist)]) => Helpers.toDistFn(Scale(#Logarithm, 10.0), dist) + Helpers.toDistFn(Scale(#Logarithm, MagicNumbers.Math.e), dist, ~env) + | ("scaleLog10", [EvDistribution(dist)]) => Helpers.toDistFn(Scale(#Logarithm, 10.0), dist, ~env) | ("scaleLog", [EvDistribution(dist), EvNumber(float)]) => - Helpers.toDistFn(Scale(#Logarithm, float), dist) + Helpers.toDistFn(Scale(#Logarithm, float), dist, ~env) | ("scaleLogWithThreshold", [EvDistribution(dist), EvNumber(base), EvNumber(eps)]) => - Helpers.toDistFn(Scale(#LogarithmWithThreshold(eps), base), dist) + Helpers.toDistFn(Scale(#LogarithmWithThreshold(eps), base), dist, ~env) | ("scalePow", [EvDistribution(dist), EvNumber(float)]) => - Helpers.toDistFn(Scale(#Power, float), dist) + Helpers.toDistFn(Scale(#Power, float), dist, ~env) | ("scaleExp", [EvDistribution(dist)]) => - Helpers.toDistFn(Scale(#Power, MagicNumbers.Math.e), dist) - | ("cdf", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Cdf(float), dist) - | ("pdf", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Pdf(float), dist) - | ("inv", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Inv(float), dist) + Helpers.toDistFn(Scale(#Power, MagicNumbers.Math.e), dist, ~env) + | ("cdf", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Cdf(float), dist, ~env) + | ("pdf", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Pdf(float), dist, ~env) + | ("inv", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Inv(float), dist, ~env) | ("toSampleSet", [EvDistribution(dist), EvNumber(float)]) => - Helpers.toDistFn(ToSampleSet(Belt.Int.fromFloat(float)), dist) + Helpers.toDistFn(ToSampleSet(Belt.Int.fromFloat(float)), dist, ~env) | ("toSampleSet", [EvDistribution(dist)]) => - Helpers.toDistFn(ToSampleSet(MagicNumbers.Environment.defaultSampleCount), dist) + Helpers.toDistFn(ToSampleSet(env.sampleCount), dist, ~env) | ("fromSamples", [EvArray(inputArray)]) => { let _wrapInputErrors = x => SampleSetDist.NonNumericInput(x) let parsedArray = Helpers.parseNumberArray(inputArray)->E.R2.errMap(_wrapInputErrors) switch parsedArray { - | Ok(array) => runGenericOperation(FromSamples(array)) + | Ok(array) => DistributionOperation.run(FromSamples(array), ~env) | Error(e) => GenDistError(SampleSetError(e)) }->Some } - | ("inspect", [EvDistribution(dist)]) => Helpers.toDistFn(Inspect, dist) + | ("inspect", [EvDistribution(dist)]) => Helpers.toDistFn(Inspect, dist, ~env) | ("truncateLeft", [EvDistribution(dist), EvNumber(float)]) => - Helpers.toDistFn(Truncate(Some(float), None), dist) + Helpers.toDistFn(Truncate(Some(float), None), dist, ~env) | ("truncateRight", [EvDistribution(dist), EvNumber(float)]) => - Helpers.toDistFn(Truncate(None, Some(float)), dist) + Helpers.toDistFn(Truncate(None, Some(float)), dist, ~env) | ("truncate", [EvDistribution(dist), EvNumber(float1), EvNumber(float2)]) => - Helpers.toDistFn(Truncate(Some(float1), Some(float2)), dist) - | ("mx" | "mixture", args) => Helpers.mixture(args)->Some + Helpers.toDistFn(Truncate(Some(float1), Some(float2)), dist, ~env) + | ("mx" | "mixture", args) => Helpers.mixture(args, ~env)->Some | ("log", [EvDistribution(a)]) => Helpers.twoDiststoDistFn( Algebraic(AsDefault), "log", a, GenericDist.fromFloat(MagicNumbers.Math.e), + ~env, )->Some | ("log10", [EvDistribution(a)]) => - Helpers.twoDiststoDistFn(Algebraic(AsDefault), "log", a, GenericDist.fromFloat(10.0))->Some + Helpers.twoDiststoDistFn( + Algebraic(AsDefault), + "log", + a, + GenericDist.fromFloat(10.0), + ~env, + )->Some | ("unaryMinus", [EvDistribution(a)]) => - Helpers.twoDiststoDistFn(Algebraic(AsDefault), "multiply", a, GenericDist.fromFloat(-1.0))->Some + Helpers.twoDiststoDistFn( + Algebraic(AsDefault), + "multiply", + a, + GenericDist.fromFloat(-1.0), + ~env, + )->Some | (("add" | "multiply" | "subtract" | "divide" | "pow" | "log") as arithmetic, [_, _] as args) => Helpers.catchAndConvertTwoArgsToDists(args)->E.O2.fmap(((fst, snd)) => - Helpers.twoDiststoDistFn(Algebraic(AsDefault), arithmetic, fst, snd) + Helpers.twoDiststoDistFn(Algebraic(AsDefault), arithmetic, fst, snd, ~env) ) | ( ("dotAdd" @@ -293,7 +318,7 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment) [_, _] as args, ) => Helpers.catchAndConvertTwoArgsToDists(args)->E.O2.fmap(((fst, snd)) => - Helpers.twoDiststoDistFn(Pointwise, arithmetic, fst, snd) + Helpers.twoDiststoDistFn(Pointwise, arithmetic, fst, snd, ~env) ) | ("dotExp", [EvDistribution(a)]) => Helpers.twoDiststoDistFn( @@ -301,6 +326,7 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment) "dotPow", GenericDist.fromFloat(MagicNumbers.Math.e), a, + ~env, )->Some | _ => None }