From 5ece2994ba612a9969834e964c9a023a8b8d1690 Mon Sep 17 00:00:00 2001 From: Ozzie Gooen Date: Thu, 31 Mar 2022 19:58:08 -0400 Subject: [PATCH] Full attempt at getting genericDist into Reducer External Lib --- .../GenericDist_GenericOperation.res | 12 ++ .../GenericDist_GenericOperation.resi | 2 + .../ReducerInterface_ExpressionValue.res | 4 +- .../ReducerInterface_ExternalLibrary.res | 147 ++++++++++++++---- 4 files changed, 135 insertions(+), 30 deletions(-) diff --git a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res index 67db34e1..351389ae 100644 --- a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res +++ b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res @@ -48,12 +48,24 @@ module OutputLocal = { | _ => None } + let toFloatR = (t: t): result => + switch t { + | Float(r) => Ok(r) + | e => Error(toErrorOrUnreachable(e)) + } + let toString = (t: t) => switch t { | String(d) => Some(d) | _ => None } + let toStringR = (t: t): result => + switch t { + | String(r) => Ok(r) + | e => Error(toErrorOrUnreachable(e)) + } + //This is used to catch errors in other switch statements. let fromResult = (r: result): outputType => switch r { diff --git a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.resi b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.resi index abbd713e..3c3e132a 100644 --- a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.resi +++ b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.resi @@ -26,7 +26,9 @@ module Output: { let toDist: t => option let toDistR: t => result let toFloat: t => option + let toFloatR: t => result let toString: t => option + let toStringR: t => result let toError: t => option let fmap: (~env: env, t, GenericDist_Types.Operation.singleParamaterFunction) => t } diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_ExpressionValue.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_ExpressionValue.res index 2ad3a402..633e6a3c 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_ExpressionValue.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_ExpressionValue.res @@ -36,7 +36,7 @@ let rec toString = aValue => ->Js.String.concatMany("") `{${pairs}}` } - // | Dist() => + | EvDist(dist) => `${GenericDist.toString(dist)}` } let toStringWithType = aValue => @@ -47,7 +47,7 @@ let toStringWithType = aValue => | EvSymbol(_) => `Symbol::${toString(aValue)}` | EvArray(_) => `Array::${toString(aValue)}` | EvRecord(_) => `Record::${toString(aValue)}` - // | Dist(_) => + | EvDist(_) => `Distribution::${toString(aValue)}` } let argsToString = (args: array): string => { diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_ExternalLibrary.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_ExternalLibrary.res index 608c9792..d4586978 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_ExternalLibrary.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_ExternalLibrary.res @@ -10,38 +10,130 @@ module Sample = { let customAdd = (a: float, b: float): float => {a +. b} } +module Dist = { + let env: GenericDist_GenericOperation.env = { + sampleCount: 1000, + xyPointLength: 1000, + } + + let {toDistR, toFloatR} = module(GenericDist_GenericOperation.Output) + let runGenericOperation = GenericDist_GenericOperation.run(~env) + + let genericDistReturnToEvReturn = x => + switch x { + | Ok(thing) => Ok(ReducerInterface_ExpressionValue.EvDist(thing)) + | Error(err) => Error(Reducer_ErrorValue.RETodo("")) // TODO: + } + + let numberReturnToEvReturn = x => + switch x { + | Ok(n) => Ok(ReducerInterface_ExpressionValue.EvNumber(n)) + | Error(err) => Error(Reducer_ErrorValue.RETodo("")) // TODO: + } + + let arithmeticMap = r => + switch r { + | "add" => #Add + | "dotAdd" => #Add + | "subtract" => #Subtract + | "dotSubtract" => #Subtract + | "divide" => #Divide + | "logarithm" => #Divide + | "dotDivide" => #Divide + | "exponentiate" => #Exponentiate + | "dotExponentiate" => #Exponentiate + | "multiply" => #Multiply + | "dotMultiply" => #Multiply + | "dotLogarithm" => #Divide + | _ => #Multiply + } + + let catchAndConvertTwoArgsToDists = (args: array): option<( + GenericDist_Types.genericDist, + GenericDist_Types.genericDist, + )> => { + switch args { + | [EvDist(a), EvDist(b)] => Some((a, b)) + | [EvNumber(a), EvDist(b)] => Some((GenericDist.fromFloat(a), b)) + | [EvDist(a), EvNumber(b)] => Some((a, GenericDist.fromFloat(b))) + | _ => None + } + } + + let toFloatFn = (fnCall: GenericDist_Types.Operation.toFloat, dist) => { + FromDist(GenericDist_Types.Operation.ToFloat(fnCall), dist) + ->runGenericOperation + ->toFloatR + ->numberReturnToEvReturn + ->Some + } + + let toDistFn = (fnCall: GenericDist_Types.Operation.toDist, dist) => { + FromDist(GenericDist_Types.Operation.ToDist(fnCall), dist) + ->runGenericOperation + ->toDistR + ->genericDistReturnToEvReturn + ->Some + } + + let twoDiststoDistFn = (direction, arithmetic, dist1, dist2) => { + FromDist( + GenericDist_Types.Operation.ToDistCombination( + direction, + arithmeticMap(arithmetic), + #Dist(dist2), + ), + dist1, + ) + ->runGenericOperation + ->toDistR + ->genericDistReturnToEvReturn + } + + let dispatch = (call: ExpressionValue.functionCall): option> => { + let (fnName, args) = call + switch (fnName, args) { + | ("cdf", [EvDist(dist), EvNumber(float)]) => toFloatFn(#Cdf(float), dist) + | ("pdf", [EvDist(dist), EvNumber(float)]) => toFloatFn(#Pdf(float), dist) + | ("inv", [EvDist(dist), EvNumber(float)]) => toFloatFn(#Inv(float), dist) + | ("mean", [EvDist(dist)]) => toFloatFn(#Mean, dist) + | ("normalize", [EvDist(dist)]) => toDistFn(Normalize, dist) + | ("toPointSet", [EvDist(dist)]) => toDistFn(ToPointSet, dist) + | ("toSampleSet", [EvDist(dist), EvNumber(float)]) => + toDistFn(ToSampleSet(Belt.Int.fromFloat(float)), dist) + | ("truncateLeft", [EvDist(dist), EvNumber(float)]) => + toDistFn(Truncate(Some(float), None), dist) + | ("truncateRight", [EvDist(dist), EvNumber(float)]) => + toDistFn(Truncate(None, Some(float)), dist) + | ("truncate", [EvDist(dist), EvNumber(float1), EvNumber(float2)]) => + toDistFn(Truncate(Some(float1), Some(float2)), dist) + | ("sample", [EvDist(dist)]) => toFloatFn(#Sample, dist) + | ( + ("add" | "multiply" | "subtract" | "divide" | "exponentiate") as arithmetic, + [a, b] as args, + ) => + catchAndConvertTwoArgsToDists(args) |> E.O.fmap(((fst, snd)) => + twoDiststoDistFn(Algebraic, arithmetic, fst, snd) + ) + | ( + ("dotAdd" | "dotSubtract" | "dotDivide" | "dotExponentiate" | "dotMultiply") as arithmetic, + [a, b] as args, + ) => + catchAndConvertTwoArgsToDists(args) |> E.O.fmap(((fst, snd)) => + twoDiststoDistFn(Pointwise, arithmetic, fst, snd) + ) + | _ => None + } + } +} + /* Map external calls of Reducer */ -let env: GenericDist_GenericOperation.env = { - sampleCount: 100, - xyPointLength: 100, -} let dispatch = (call: ExpressionValue.functionCall, chain): result => - switch call { - | ("add", [EvNumber(a), EvNumber(b)]) => Sample.customAdd(a, b)->EvNumber->Ok - | ("add", [EvDist(a), EvDist(b)]) => { - let x = GenericDist_GenericOperation.Output.toDistR( - GenericDist_GenericOperation.run(~env, FromDist(ToDistCombination(Algebraic, #Add, #Dist(b)), a)) - ) - switch x { - | Ok(thing) => Ok(EvDist(thing)) - | Error(err) => Error(Reducer_ErrorValue.RETodo("")) // TODO: - } - } - | ("add", [EvNumber(a), EvDist(b)]) => { - let x = GenericDist_GenericOperation.Output.toDistR( - GenericDist_GenericOperation.run(~env, FromDist(ToDistCombination(Algebraic, #Add, #Dist(b)), a)) - ) - switch x { - | Ok(thing) => Ok(EvDist(thing)) - | Error(err) => Error(Reducer_ErrorValue.RETodo("")) // TODO: - } - } - | call => chain(call) - - /* + Dist.dispatch(call) |> E.O.default(chain(call)) +/* If your dispatch is too big you can divide it into smaller dispatches and pass the call so that it gets called finally. The final chain(call) invokes the builtin default functions of the interpreter. @@ -57,4 +149,3 @@ Remember from the users point of view, there are no different modules: // "doSth( constructorType2 )" doSth gets dispatched to the correct module because of the type signature. You get function and operator abstraction for free. You don't need to combine different implementations into one type. That would be duplicating the repsonsibility of the dispatcher. */ - }