From 4b07226b4546ccd8b9ccdf715495ffdca126a82a Mon Sep 17 00:00:00 2001 From: Ozzie Gooen Date: Thu, 19 May 2022 17:04:31 -0400 Subject: [PATCH] Continuing cleanup of FunctionRegistry --- .../src/rescript/FunctionRegistry.res | 151 +++++++++--------- 1 file changed, 73 insertions(+), 78 deletions(-) diff --git a/packages/squiggle-lang/src/rescript/FunctionRegistry.res b/packages/squiggle-lang/src/rescript/FunctionRegistry.res index a1138129..31db2269 100644 --- a/packages/squiggle-lang/src/rescript/FunctionRegistry.res +++ b/packages/squiggle-lang/src/rescript/FunctionRegistry.res @@ -241,38 +241,28 @@ module Registry = { let impossibleError = "Wrong inputs / Logically impossible" module Prepare = { - let twoNumberInputs = (inputs: array) => { + let recordWithTwoArgsToValues = (inputs: array): result, string> => + switch inputs { + | [Record([(_, n1), (_, n2)])] => Ok([n1, n2]) + | _ => Error(impossibleError) + } + + let twoNumberInputs = (inputs: array): result<(float, float), string> => { switch inputs { | [Number(n1), Number(n2)] => Ok(n1, n2) | _ => Error(impossibleError) } } - let twoDistOrNumber = (values: array) => { + let twoDistOrNumber = (values: array): result<(distOrNumber, distOrNumber), string> => { switch values { | [DistOrNumber(a1), DistOrNumber(a2)] => Ok(a1, a2) | _ => Error(impossibleError) } } - let twoNumberInputsRecord = (v1: string, v2: string, inputs: array) => - switch inputs { - | [Record([(name1, n1), (name2, n2)])] if name1 == v1 && name2 == v2 => - twoNumberInputs([n1, n2]) - | _ => Error(impossibleError) - } - - let twoNumberInputsRecord2 = (inputs: array) => - switch inputs { - | [Record([(_, n1), (_, n2)])] => twoNumberInputs([n1, n2]) - | _ => Error(impossibleError) - } - - let twoNumberInputsRecord3 = (inputs: array) => - switch inputs { - | [Record([(_, n1), (_, n2)])] => Ok([n1, n2]) - | _ => Error(impossibleError) - } + let twoDistOrNumberFromRecord = (values: array) => + values->recordWithTwoArgsToValues->E.R.bind(twoDistOrNumber) } module Wrappers = { @@ -281,60 +271,69 @@ module Wrappers = { let symbolicEvDistribution = r => r->Symbolic->evDistribution } -let twoDistsOrNumbers = ( - ~fn: (float, float) => result, - ~values: (distOrNumber, distOrNumber), -) => { - let toSampleSet = r => GenericDist.toSampleSetDist(r, 1000) - let sampleSetToExpressionValue = ( - b: Belt.Result.t, - ) => - switch b { - | Ok(r) => Ok(ReducerInterface_ExpressionValue.EvDistribution(SampleSet(r))) - | Error(d) => Error(DistributionTypes.Error.toString(d)) - } +module Process = { + let twoDistsOrNumbersToDist = ( + ~fn: (float, float) => result, + ~values: (distOrNumber, distOrNumber), + ) => { + let toSampleSet = r => GenericDist.toSampleSetDist(r, 1000) + let sampleSetToExpressionValue = ( + b: Belt.Result.t, + ) => + switch b { + | Ok(r) => Ok(ReducerInterface_ExpressionValue.EvDistribution(SampleSet(r))) + | Error(d) => Error(DistributionTypes.Error.toString(d)) + } - let mapFnResult = r => - switch r { - | Ok(r) => Ok(GenericDist.sample(r)) - | Error(r) => Error(Operation.Other(r)) - } + let mapFnResult = r => + switch r { + | Ok(r) => Ok(GenericDist.sample(r)) + | Error(r) => Error(Operation.Other(r)) + } - let singleVarSample = (a, fn) => { - let sampleSetResult = - toSampleSet(a) |> E.R2.bind(dist => - SampleSetDist.samplesMap( - ~fn=f => fn(f)->mapFnResult, - dist, - )->E.R2.errMap(r => DistributionTypes.SampleSetError(r)) - ) - sampleSetResult->sampleSetToExpressionValue - } - - switch values { - | (Number(a1), Number(a2)) => fn(a1, a2)->E.R2.fmap(Wrappers.evDistribution) - | (Dist(a1), Number(a2)) => singleVarSample(a1, r => fn(r, a2)) - | (Number(a1), Dist(a2)) => singleVarSample(a2, r => fn(a1, r)) - | (Dist(a1), Dist(a2)) => { - let altFn = (a, b) => fn(a, b)->mapFnResult + let singleVarSample = (a, fn) => { let sampleSetResult = - E.R.merge(toSampleSet(a1), toSampleSet(a2)) - ->E.R2.errMap(DistributionTypes.Error.toString) - ->E.R.bind(((t1, t2)) => { - SampleSetDist.map2(~fn=altFn, ~t1, ~t2)->E.R2.errMap(Operation.Error.toString) - }) - ->E.R2.errMap(r => DistributionTypes.OtherError(r)) + toSampleSet(a) |> E.R2.bind(dist => + SampleSetDist.samplesMap( + ~fn=f => fn(f)->mapFnResult, + dist, + )->E.R2.errMap(r => DistributionTypes.SampleSetError(r)) + ) sampleSetResult->sampleSetToExpressionValue } + + switch values { + | (Number(a1), Number(a2)) => fn(a1, a2)->E.R2.fmap(Wrappers.evDistribution) + | (Dist(a1), Number(a2)) => singleVarSample(a1, r => fn(r, a2)) + | (Number(a1), Dist(a2)) => singleVarSample(a2, r => fn(a1, r)) + | (Dist(a1), Dist(a2)) => { + let altFn = (a, b) => fn(a, b)->mapFnResult + let sampleSetResult = + E.R.merge(toSampleSet(a1), toSampleSet(a2)) + ->E.R2.errMap(DistributionTypes.Error.toString) + ->E.R.bind(((t1, t2)) => { + SampleSetDist.map2(~fn=altFn, ~t1, ~t2)->E.R2.errMap(Operation.Error.toString) + }) + ->E.R2.errMap(r => DistributionTypes.OtherError(r)) + sampleSetResult->sampleSetToExpressionValue + } + } + } + + let twoDistsOrNumbersToDistUsingSymbolicDist = ( + ~fn: (float, float) => result, + ~values, + ) => { + twoDistsOrNumbersToDist(~fn=(a, b) => fn(a, b)->E.R2.fmap(Wrappers.symbolic), ~values) } } module NormalFn = { let fnName = "normal" - let twoFloatsToSymoblic = (a1: float, a2: float) => - SymbolicDist.Normal.make(a1, a2)->E.R2.fmap(Wrappers.symbolic) - let twoFloatsToSymbolic90P = (a1: float, a2: float) => - SymbolicDist.Normal.from90PercentCI(a1, a2)->Wrappers.symbolic->Ok + let twoFloatsToSymoblic = (a1, a2) => + SymbolicDist.Normal.make(a1, a2) + let twoFloatsToSymbolic90P = (a1, a2) => + SymbolicDist.Normal.from90PercentCI(a1, a2)->Ok let toFn = Function.make( ~name="Normal", @@ -345,7 +344,7 @@ module NormalFn = { ~run=inputs => { inputs ->Prepare.twoDistOrNumber - ->E.R.bind(twoDistsOrNumbers(~fn=twoFloatsToSymoblic, ~values=_)) + ->E.R.bind(Process.twoDistsOrNumbersToDistUsingSymbolicDist(~fn=twoFloatsToSymoblic, ~values=_)) }, ), Function.makeDefinition( @@ -353,18 +352,16 @@ module NormalFn = { ~inputs=[I_Record([("mean", I_DistOrNumber), ("stdev", I_DistOrNumber)])], ~run=inputs => inputs - ->Prepare.twoNumberInputsRecord3 - ->E.R.bind(Prepare.twoDistOrNumber) - ->E.R.bind(twoDistsOrNumbers(~fn=twoFloatsToSymoblic, ~values=_)), + ->Prepare.twoDistOrNumberFromRecord + ->E.R.bind(Process.twoDistsOrNumbersToDistUsingSymbolicDist(~fn=twoFloatsToSymoblic, ~values=_)), ), Function.makeDefinition( ~name=fnName, ~inputs=[I_Record([("p5", I_DistOrNumber), ("p95", I_DistOrNumber)])], ~run=inputs => inputs - ->Prepare.twoNumberInputsRecord3 - ->E.R.bind(Prepare.twoDistOrNumber) - ->E.R.bind(twoDistsOrNumbers(~fn=twoFloatsToSymbolic90P, ~values=_)), + ->Prepare.twoDistOrNumberFromRecord + ->E.R.bind(Process.twoDistsOrNumbersToDistUsingSymbolicDist(~fn=twoFloatsToSymbolic90P, ~values=_)), ), ], ) @@ -385,25 +382,23 @@ module LognormalFn = { Function.makeDefinition(~name=fnName, ~inputs=[I_DistOrNumber, I_DistOrNumber], ~run=inputs => inputs ->Prepare.twoDistOrNumber - ->E.R.bind(twoDistsOrNumbers(~fn=twoFloatsToSymoblic, ~values=_)) + ->E.R.bind(Process.twoDistsOrNumbersToDist(~fn=twoFloatsToSymoblic, ~values=_)) ), Function.makeDefinition( ~name=fnName, ~inputs=[I_Record([("p5", I_DistOrNumber), ("p95", I_DistOrNumber)])], ~run=inputs => inputs - ->Prepare.twoNumberInputsRecord3 - ->E.R.bind(Prepare.twoDistOrNumber) - ->E.R.bind(twoDistsOrNumbers(~fn=twoFloatsToSymbolic90P, ~values=_)), + ->Prepare.twoDistOrNumberFromRecord + ->E.R.bind(Process.twoDistsOrNumbersToDist(~fn=twoFloatsToSymbolic90P, ~values=_)), ), Function.makeDefinition( ~name=fnName, ~inputs=[I_Record([("mean", I_DistOrNumber), ("stdev", I_DistOrNumber)])], ~run=inputs => inputs - ->Prepare.twoNumberInputsRecord3 - ->E.R.bind(Prepare.twoDistOrNumber) - ->E.R.bind(twoDistsOrNumbers(~fn=twoFloatsToMeanStdev, ~values=_)), + ->Prepare.twoDistOrNumberFromRecord + ->E.R.bind(Process.twoDistsOrNumbersToDist(~fn=twoFloatsToMeanStdev, ~values=_)), ), ], )