From 18d742b63c1cccdf1d9cef54be14a166ba15d2f8 Mon Sep 17 00:00:00 2001 From: Ozzie Gooen Date: Fri, 1 Apr 2022 13:21:24 -0400 Subject: [PATCH] Added symbolic functions and tests for reducer interface distribution code --- .../ReducerInterface_Distribution_test.res | 26 ++++++++++ .../ReducerInterface_GenericDistribution.res | 52 +++++++++++++++++-- 2 files changed, 74 insertions(+), 4 deletions(-) create mode 100644 packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res diff --git a/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res b/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res new file mode 100644 index 00000000..8fd6f9da --- /dev/null +++ b/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res @@ -0,0 +1,26 @@ +open Jest +open Reducer_TestHelpers + +let makeTest = (str, result) => test(str, () => expectEvalToBe(str, result)) + +describe("eval", () => { + Only.describe("expressions", () => { + makeTest("normal(5,2)", "Ok(Normal(5,2))") + makeTest("lognormal(5,2)", "Ok(Lognormal(5,2))") + makeTest("mean(normal(5,2))", "Ok(5)") + makeTest("mean(lognormal(1,2))", "Ok(20.085536923187668)") + makeTest("normalize(normal(5,2))", "Ok(Normal(5,2))") + makeTest("toPointSet(normal(5,2))", "Ok(Point Set Distribution)") + makeTest("toSampleSet(normal(5,2), 100)", "Ok(Sample Set Distribution)") + makeTest("add(normal(5,2), normal(10,2))", "Ok(Normal(15,2.8284271247461903))") + makeTest("add(normal(5,2), lognormal(10,2))", "Ok(Sample Set Distribution)") + makeTest("dotAdd(normal(5,2), lognormal(10,2))", "Ok(Point Set Distribution)") + makeTest("dotAdd(normal(5,2), 3)", "Ok(Point Set Distribution)") + makeTest("add(normal(5,2), 3)", "Ok(Point Set Distribution)") + makeTest("add(3, normal(5,2))", "Ok(Point Set Distribution)") + makeTest("3+normal(5,2)", "Ok(Point Set Distribution)") + makeTest("add(3, 3)", "Ok(6)") + makeTest("truncateLeft(normal(5,2), 3)", "Ok(Point Set Distribution)") + makeTest("mean(add(3, normal(5,2)))", "Ok(8.004619792609384)") + }) +}) diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res index 18f152d4..0bc7e23f 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res @@ -74,11 +74,54 @@ let genericOutputToReducerValue = (o: GenericDist_GenericOperation.outputType): | GenDistError(Other(s)) => Error(RETodo(s)) } +module SymbolicConstructor = { + let oneFloat = name => + switch name { + | "exponential" => Ok(SymbolicDist.Exponential.make) + | _ => Error("impossible path") + } + + let twoFloat = name => + switch name { + | "normal" => Ok(SymbolicDist.Normal.make) + | "uniform" => Ok(SymbolicDist.Uniform.make) + | "beta" => Ok(SymbolicDist.Beta.make) + | "lognormal" => Ok(SymbolicDist.Lognormal.make) + | _ => Error("impossible path") + } + + let threeFloat = name => + switch name { + | "triangular" => Ok(SymbolicDist.Triangular.make) + | _ => Error("impossible path") + } + + let symbolicResultToOutput = ( + symbolicResult: result, + ): option => + switch symbolicResult { + | Ok(r) => Some(Dist(Symbolic(r))) + | Error(r) => Some(GenDistError(Other(r))) + } +} + let dispatchToGenericOutput = (call: ExpressionValue.functionCall): option< GenericDist_GenericOperation.outputType, > => { let (fnName, args) = call switch (fnName, args) { + | ("exponential" as fnName, [EvNumber(f1)]) => + SymbolicConstructor.oneFloat(fnName) + ->E.R.bind(r => r(f1)) + ->SymbolicConstructor.symbolicResultToOutput + | (("normal" | "uniform" | "beta" | "lognormal") as fnName, [EvNumber(f1), EvNumber(f2)]) => + SymbolicConstructor.twoFloat(fnName) + ->E.R.bind(r => r(f1, f2)) + ->SymbolicConstructor.symbolicResultToOutput + | ("triangular" as fnName, [EvNumber(f1), EvNumber(f2), EvNumber(f3)]) => + SymbolicConstructor.threeFloat(fnName) + ->E.R.bind(r => r(f1, f2, f3)) + ->SymbolicConstructor.symbolicResultToOutput | ("sample", [EvDistribution(dist)]) => toFloatFn(#Sample, dist) | ("mean", [EvDistribution(dist)]) => toFloatFn(#Mean, dist) | ("normalize", [EvDistribution(dist)]) => toDistFn(Normalize, dist) @@ -88,20 +131,21 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall): option< | ("inv", [EvDistribution(dist), EvNumber(float)]) => toFloatFn(#Inv(float), dist) | ("toSampleSet", [EvDistribution(dist), EvNumber(float)]) => toDistFn(ToSampleSet(Belt.Int.fromFloat(float)), dist) - | ("truncateLeft", [EvDistribution(dist), EvNumber(float)]) => toDistFn(Truncate(Some(float), None), dist) + | ("truncateLeft", [EvDistribution(dist), EvNumber(float)]) => + toDistFn(Truncate(Some(float), None), dist) | ("truncateRight", [EvDistribution(dist), EvNumber(float)]) => toDistFn(Truncate(None, Some(float)), dist) | ("truncate", [EvDistribution(dist), EvNumber(float1), EvNumber(float2)]) => toDistFn(Truncate(Some(float1), Some(float2)), dist) | (("add" | "multiply" | "subtract" | "divide" | "exponentiate") as arithmetic, [a, b] as args) => - catchAndConvertTwoArgsToDists(args) -> E.O2.fmap(((fst, snd)) => + catchAndConvertTwoArgsToDists(args)->E.O2.fmap(((fst, snd)) => twoDiststoDistFn(Algebraic, arithmetic, fst, snd) ) | ( ("dotAdd" | "dotMultiply" | "dotSubtract" | "dotDivide" | "dotExponentiate") as arithmetic, [a, b] as args, ) => - catchAndConvertTwoArgsToDists(args) -> E.O2.fmap(((fst, snd)) => + catchAndConvertTwoArgsToDists(args)->E.O2.fmap(((fst, snd)) => twoDiststoDistFn(Pointwise, arithmetic, fst, snd) ) | _ => None @@ -109,5 +153,5 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall): option< } let dispatch = call => { - dispatchToGenericOutput(call) -> E.O2.fmap(genericOutputToReducerValue) + dispatchToGenericOutput(call)->E.O2.fmap(genericOutputToReducerValue) }