diff --git a/packages/squiggle-lang/src/rescript/FunctionRegistry.res b/packages/squiggle-lang/src/rescript/FunctionRegistry.res index 31db2269..10f48f43 100644 --- a/packages/squiggle-lang/src/rescript/FunctionRegistry.res +++ b/packages/squiggle-lang/src/rescript/FunctionRegistry.res @@ -92,11 +92,11 @@ module FnDefinition = { type definitionMatch = MatchSimple.t let getArgValues = (f: fnDefinition, args: array): option> => { - let inputTypes = f.inputs + let mainInputTypes = f.inputs if E.A.length(f.inputs) !== E.A.length(args) { None } else { - E.A.zip(inputTypes, args) + E.A.zip(mainInputTypes, args) ->E.A2.fmap(((input, arg)) => matchInput(input, arg)) ->E.A.O.openIfAllSome } @@ -273,7 +273,7 @@ module Wrappers = { module Process = { let twoDistsOrNumbersToDist = ( - ~fn: (float, float) => result, + ~fn: ((float, float)) => result, ~values: (distOrNumber, distOrNumber), ) => { let toSampleSet = r => GenericDist.toSampleSetDist(r, 1000) @@ -303,11 +303,11 @@ module Process = { } switch values { - | (Number(a1), Number(a2)) => fn(a1, a2)->E.R2.fmap(Wrappers.evDistribution) - | (Dist(a1), Number(a2)) => singleVarSample(a1, r => fn(r, a2)) - | (Number(a1), Dist(a2)) => singleVarSample(a2, r => fn(a1, r)) + | (Number(a1), Number(a2)) => fn((a1, a2))->E.R2.fmap(Wrappers.evDistribution) + | (Dist(a1), Number(a2)) => singleVarSample(a1, r => fn((r, a2))) + | (Number(a1), Dist(a2)) => singleVarSample(a2, r => fn((a1, r))) | (Dist(a1), Dist(a2)) => { - let altFn = (a, b) => fn(a, b)->mapFnResult + let altFn = (a, b) => fn((a, b))->mapFnResult let sampleSetResult = E.R.merge(toSampleSet(a1), toSampleSet(a2)) ->E.R2.errMap(DistributionTypes.Error.toString) @@ -321,47 +321,41 @@ module Process = { } let twoDistsOrNumbersToDistUsingSymbolicDist = ( - ~fn: (float, float) => result, + ~fn: ((float, float)) => result, ~values, ) => { - twoDistsOrNumbersToDist(~fn=(a, b) => fn(a, b)->E.R2.fmap(Wrappers.symbolic), ~values) + twoDistsOrNumbersToDist(~fn=r => r->fn->E.R2.fmap(Wrappers.symbolic), ~values) } } +let twoArgs = (fn, (a1, a2)) => fn(a1, a2) + +let process = (~fn, r) => + r->E.R.bind(Process.twoDistsOrNumbersToDistUsingSymbolicDist(~fn, ~values=_)) + module NormalFn = { let fnName = "normal" - let twoFloatsToSymoblic = (a1, a2) => - SymbolicDist.Normal.make(a1, a2) - let twoFloatsToSymbolic90P = (a1, a2) => - SymbolicDist.Normal.from90PercentCI(a1, a2)->Ok + let mainInputType = I_DistOrNumber let toFn = Function.make( ~name="Normal", ~definitions=[ + Function.makeDefinition(~name=fnName, ~inputs=[mainInputType, mainInputType], ~run=inputs => { + inputs->Prepare.twoDistOrNumber->process(~fn=twoArgs(SymbolicDist.Normal.make)) + }), Function.makeDefinition( ~name=fnName, - ~inputs=[I_DistOrNumber, I_DistOrNumber], - ~run=inputs => { - inputs - ->Prepare.twoDistOrNumber - ->E.R.bind(Process.twoDistsOrNumbersToDistUsingSymbolicDist(~fn=twoFloatsToSymoblic, ~values=_)) - }, + ~inputs=[I_Record([("mean", mainInputType), ("stdev", mainInputType)])], + ~run=inputs => + inputs->Prepare.twoDistOrNumberFromRecord->process(~fn=twoArgs(SymbolicDist.Normal.make)), ), Function.makeDefinition( ~name=fnName, - ~inputs=[I_Record([("mean", I_DistOrNumber), ("stdev", I_DistOrNumber)])], + ~inputs=[I_Record([("p5", mainInputType), ("p95", mainInputType)])], ~run=inputs => inputs ->Prepare.twoDistOrNumberFromRecord - ->E.R.bind(Process.twoDistsOrNumbersToDistUsingSymbolicDist(~fn=twoFloatsToSymoblic, ~values=_)), - ), - Function.makeDefinition( - ~name=fnName, - ~inputs=[I_Record([("p5", I_DistOrNumber), ("p95", I_DistOrNumber)])], - ~run=inputs => - inputs - ->Prepare.twoDistOrNumberFromRecord - ->E.R.bind(Process.twoDistsOrNumbersToDistUsingSymbolicDist(~fn=twoFloatsToSymbolic90P, ~values=_)), + ->process(~fn=r => twoArgs(SymbolicDist.Normal.from90PercentCI, r)->Ok), ), ], ) @@ -369,36 +363,29 @@ module NormalFn = { module LognormalFn = { let fnName = "lognormal" - let twoFloatsToSymoblic = (a1, a2) => - SymbolicDist.Lognormal.make(a1, a2)->E.R2.fmap(Wrappers.symbolic) - let twoFloatsToSymbolic90P = (a1, a2) => - SymbolicDist.Lognormal.from90PercentCI(a1, a2)->Wrappers.symbolic->Ok - let twoFloatsToMeanStdev = (a1, a2) => - SymbolicDist.Lognormal.fromMeanAndStdev(a1, a2)->E.R2.fmap(Wrappers.symbolic) + let mainInputType = I_DistOrNumber let toFn = Function.make( ~name="Lognormal", ~definitions=[ - Function.makeDefinition(~name=fnName, ~inputs=[I_DistOrNumber, I_DistOrNumber], ~run=inputs => - inputs - ->Prepare.twoDistOrNumber - ->E.R.bind(Process.twoDistsOrNumbersToDist(~fn=twoFloatsToSymoblic, ~values=_)) + Function.makeDefinition(~name=fnName, ~inputs=[mainInputType, mainInputType], ~run=inputs => + inputs->Prepare.twoDistOrNumber->process(~fn=twoArgs(SymbolicDist.Lognormal.make)) ), Function.makeDefinition( ~name=fnName, - ~inputs=[I_Record([("p5", I_DistOrNumber), ("p95", I_DistOrNumber)])], + ~inputs=[I_Record([("p5", mainInputType), ("p95", mainInputType)])], ~run=inputs => inputs ->Prepare.twoDistOrNumberFromRecord - ->E.R.bind(Process.twoDistsOrNumbersToDist(~fn=twoFloatsToSymbolic90P, ~values=_)), + ->process(~fn=r => twoArgs(SymbolicDist.Lognormal.from90PercentCI, r)->Ok), ), Function.makeDefinition( ~name=fnName, - ~inputs=[I_Record([("mean", I_DistOrNumber), ("stdev", I_DistOrNumber)])], + ~inputs=[I_Record([("mean", mainInputType), ("stdev", mainInputType)])], ~run=inputs => inputs ->Prepare.twoDistOrNumberFromRecord - ->E.R.bind(Process.twoDistsOrNumbersToDist(~fn=twoFloatsToMeanStdev, ~values=_)), + ->process(~fn=twoArgs(SymbolicDist.Lognormal.fromMeanAndStdev)), ), ], )