diff --git a/packages/squiggle-lang/src/rescript/FunctionRegistry.res b/packages/squiggle-lang/src/rescript/FunctionRegistry.res index df04cef9..a1138129 100644 --- a/packages/squiggle-lang/src/rescript/FunctionRegistry.res +++ b/packages/squiggle-lang/src/rescript/FunctionRegistry.res @@ -130,11 +130,12 @@ module Function = { type definitionId = int type match = Match.t, definitionId> - let make = (name, definitions): function => { + let make = (~name, ~definitions): function => { name: name, definitions: definitions, } - let makeDefinition = (name, inputs, run): fnDefinition => { + + let makeDefinition = (~name, ~inputs, ~run): fnDefinition => { name: name, inputs: inputs, run: run, @@ -239,39 +240,50 @@ module Registry = { let impossibleError = "Wrong inputs / Logically impossible" -let twoNumberInputs = (inputs: array) => { - switch inputs { - | [Number(n1), Number(n2)] => Ok(n1, n2) - | _ => Error(impossibleError) +module Prepare = { + let twoNumberInputs = (inputs: array) => { + switch inputs { + | [Number(n1), Number(n2)] => Ok(n1, n2) + | _ => Error(impossibleError) + } } + + let twoDistOrNumber = (values: array) => { + switch values { + | [DistOrNumber(a1), DistOrNumber(a2)] => Ok(a1, a2) + | _ => Error(impossibleError) + } + } + + let twoNumberInputsRecord = (v1: string, v2: string, inputs: array) => + switch inputs { + | [Record([(name1, n1), (name2, n2)])] if name1 == v1 && name2 == v2 => + twoNumberInputs([n1, n2]) + | _ => Error(impossibleError) + } + + let twoNumberInputsRecord2 = (inputs: array) => + switch inputs { + | [Record([(_, n1), (_, n2)])] => twoNumberInputs([n1, n2]) + | _ => Error(impossibleError) + } + + let twoNumberInputsRecord3 = (inputs: array) => + switch inputs { + | [Record([(_, n1), (_, n2)])] => Ok([n1, n2]) + | _ => Error(impossibleError) + } } -let twoNumberInputsRecord = (v1, v2, inputs: array) => - switch inputs { - | [Record([(name1, n1), (name2, n2)])] if name1 == v1 && name2 == v2 => twoNumberInputs([n1, n2]) - | _ => Error(impossibleError) - } - -let contain = r => ReducerInterface_ExpressionValue.EvDistribution(Symbolic(r)) - -let meanStdev = (mean, stdev) => SymbolicDist.Normal.make(mean, stdev)->E.R2.fmap(contain) - -let p5and95 = (p5, p95) => contain(SymbolicDist.Normal.from90PercentCI(p5, p95)) - -let convertTwoInputs = (inputs: array): result => - twoNumberInputs(inputs)->E.R.bind(((mean, stdev)) => meanStdev(mean, stdev)) - -let twoDistOrStdev = (a1: value, a2: value) => { - switch (a1, a2) { - | (DistOrNumber(a1), DistOrNumber(a2)) => Ok(a1, a2) - | _ => Error(impossibleError) - } +module Wrappers = { + let symbolic = r => DistributionTypes.Symbolic(r) + let evDistribution = r => ReducerInterface_ExpressionValue.EvDistribution(r) + let symbolicEvDistribution = r => r->Symbolic->evDistribution } -let distTwo = ( +let twoDistsOrNumbers = ( ~fn: (float, float) => result, - a1: value, - a2: value, + ~values: (distOrNumber, distOrNumber), ) => { let toSampleSet = r => GenericDist.toSampleSetDist(r, 1000) let sampleSetToExpressionValue = ( @@ -299,12 +311,11 @@ let distTwo = ( sampleSetResult->sampleSetToExpressionValue } - switch (a1, a2) { - | (DistOrNumber(Number(a1)), DistOrNumber(Number(a2))) => - fn(a1, a2)->E.R2.fmap(r => ReducerInterface_ExpressionValue.EvDistribution(r)) - | (DistOrNumber(Dist(a1)), DistOrNumber(Number(a2))) => singleVarSample(a1, r => fn(r, a2)) - | (DistOrNumber(Number(a1)), DistOrNumber(Dist(a2))) => singleVarSample(a2, r => fn(a1, r)) - | (DistOrNumber(Dist(a1)), DistOrNumber(Dist(a2))) => { + switch values { + | (Number(a1), Number(a2)) => fn(a1, a2)->E.R2.fmap(Wrappers.evDistribution) + | (Dist(a1), Number(a2)) => singleVarSample(a1, r => fn(r, a2)) + | (Number(a1), Dist(a2)) => singleVarSample(a2, r => fn(a1, r)) + | (Dist(a1), Dist(a2)) => { let altFn = (a, b) => fn(a, b)->mapFnResult let sampleSetResult = E.R.merge(toSampleSet(a1), toSampleSet(a2)) @@ -315,57 +326,87 @@ let distTwo = ( ->E.R2.errMap(r => DistributionTypes.OtherError(r)) sampleSetResult->sampleSetToExpressionValue } - | _ => Error(impossibleError) } } -let normal = Function.make( - "Normal", - [ - Function.makeDefinition("normal", [I_DistOrNumber, I_DistOrNumber], inputs => { - let combine = (a1: float, a2: float) => - SymbolicDist.Normal.make(a1, a2)->E.R2.fmap(r => DistributionTypes.Symbolic(r)) - distTwo(~fn=combine, inputs[0], inputs[1]) - }), - Function.makeDefinition( - "normal", - [I_Record([("mean", I_Numeric), ("stdev", I_Numeric)])], - inputs => - twoNumberInputsRecord("mean", "stdev", inputs)->E.R.bind(((mean, stdev)) => - meanStdev(mean, stdev) - ), - ), - Function.makeDefinition("normal", [I_Record([("p5", I_Numeric), ("p95", I_Numeric)])], inputs => - twoNumberInputsRecord("p5", "p95", inputs)->E.R.bind(((v1, v2)) => Ok(p5and95(v1, v2))) - ), - ], -) +module NormalFn = { + let fnName = "normal" + let twoFloatsToSymoblic = (a1: float, a2: float) => + SymbolicDist.Normal.make(a1, a2)->E.R2.fmap(Wrappers.symbolic) + let twoFloatsToSymbolic90P = (a1: float, a2: float) => + SymbolicDist.Normal.from90PercentCI(a1, a2)->Wrappers.symbolic->Ok -let logNormal = Function.make( - "Lognormal", - [ - Function.makeDefinition("lognormal", [I_Numeric, I_Numeric], inputs => - twoNumberInputs(inputs)->E.R.bind(((mu, sigma)) => - SymbolicDist.Lognormal.make(mu, sigma)->E.R2.fmap(contain) - ) - ), - Function.makeDefinition( - "lognormal", - [I_Record([("p5", I_Numeric), ("p95", I_Numeric)])], - inputs => - twoNumberInputsRecord("p5", "p95", inputs)->E.R.bind(((p5, p95)) => Ok( - contain(SymbolicDist.Lognormal.from90PercentCI(p5, p95)), - )), - ), - Function.makeDefinition( - "lognormal", - [I_Record([("mean", I_Numeric), ("stdev", I_Numeric)])], - inputs => - twoNumberInputsRecord("mean", "stdev", inputs)->E.R.bind(((mean, stdev)) => - SymbolicDist.Lognormal.fromMeanAndStdev(mean, stdev)->E.R2.fmap(contain) - ), - ), - ], -) + let toFn = Function.make( + ~name="Normal", + ~definitions=[ + Function.makeDefinition( + ~name=fnName, + ~inputs=[I_DistOrNumber, I_DistOrNumber], + ~run=inputs => { + inputs + ->Prepare.twoDistOrNumber + ->E.R.bind(twoDistsOrNumbers(~fn=twoFloatsToSymoblic, ~values=_)) + }, + ), + Function.makeDefinition( + ~name=fnName, + ~inputs=[I_Record([("mean", I_DistOrNumber), ("stdev", I_DistOrNumber)])], + ~run=inputs => + inputs + ->Prepare.twoNumberInputsRecord3 + ->E.R.bind(Prepare.twoDistOrNumber) + ->E.R.bind(twoDistsOrNumbers(~fn=twoFloatsToSymoblic, ~values=_)), + ), + Function.makeDefinition( + ~name=fnName, + ~inputs=[I_Record([("p5", I_DistOrNumber), ("p95", I_DistOrNumber)])], + ~run=inputs => + inputs + ->Prepare.twoNumberInputsRecord3 + ->E.R.bind(Prepare.twoDistOrNumber) + ->E.R.bind(twoDistsOrNumbers(~fn=twoFloatsToSymbolic90P, ~values=_)), + ), + ], + ) +} -let allFunctions = [normal, logNormal] +module LognormalFn = { + let fnName = "lognormal" + let twoFloatsToSymoblic = (a1, a2) => + SymbolicDist.Lognormal.make(a1, a2)->E.R2.fmap(Wrappers.symbolic) + let twoFloatsToSymbolic90P = (a1, a2) => + SymbolicDist.Lognormal.from90PercentCI(a1, a2)->Wrappers.symbolic->Ok + let twoFloatsToMeanStdev = (a1, a2) => + SymbolicDist.Lognormal.fromMeanAndStdev(a1, a2)->E.R2.fmap(Wrappers.symbolic) + + let toFn = Function.make( + ~name="Lognormal", + ~definitions=[ + Function.makeDefinition(~name=fnName, ~inputs=[I_DistOrNumber, I_DistOrNumber], ~run=inputs => + inputs + ->Prepare.twoDistOrNumber + ->E.R.bind(twoDistsOrNumbers(~fn=twoFloatsToSymoblic, ~values=_)) + ), + Function.makeDefinition( + ~name=fnName, + ~inputs=[I_Record([("p5", I_DistOrNumber), ("p95", I_DistOrNumber)])], + ~run=inputs => + inputs + ->Prepare.twoNumberInputsRecord3 + ->E.R.bind(Prepare.twoDistOrNumber) + ->E.R.bind(twoDistsOrNumbers(~fn=twoFloatsToSymbolic90P, ~values=_)), + ), + Function.makeDefinition( + ~name=fnName, + ~inputs=[I_Record([("mean", I_DistOrNumber), ("stdev", I_DistOrNumber)])], + ~run=inputs => + inputs + ->Prepare.twoNumberInputsRecord3 + ->E.R.bind(Prepare.twoDistOrNumber) + ->E.R.bind(twoDistsOrNumbers(~fn=twoFloatsToMeanStdev, ~values=_)), + ), + ], + ) +} + +let allFunctions = [NormalFn.toFn, LognormalFn.toFn]