From 06285dbdc10fd6efd97c253cab6766e41451a64f Mon Sep 17 00:00:00 2001 From: Ozzie Gooen Date: Sat, 2 Apr 2022 16:25:41 -0400 Subject: [PATCH] Additional testing for GenericDist-Reducer interface, plus getting log, exp to work with 1 param --- .../ReducerInterface_Distribution_test.res | 88 +++++++++++++++++-- .../ReducerInterface_GenericDistribution.res | 48 ++++++---- 2 files changed, 113 insertions(+), 23 deletions(-) diff --git a/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res b/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res index 55ae387c..b514232a 100644 --- a/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res +++ b/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res @@ -1,10 +1,18 @@ open Jest -open Reducer_TestHelpers -let testEval = (str, result) => test(str, () => expectEvalToBe(str, result)) +let testSkip: (bool, string, unit => assertion) => unit = (skip: bool) => + if skip { + Skip.test + } else { + test + } +let testEval = (~skip=false, str, result) => + testSkip(skip)(str, () => Reducer_TestHelpers.expectEvalToBe(str, result)) +let testParse = (~skip=false, str, result) => + testSkip(skip)(str, () => Reducer_TestHelpers.expectParseToBe(str, result)) describe("eval", () => { - Only.describe("expressions", () => { + describe("expressions", () => { testEval("normal(5,2)", "Ok(Normal(5,2))") testEval("5 to 2", "Error(TODO: Low value must be less than high value.)") testEval("to(2,5)", "Ok(Lognormal(1.1512925464970227,0.278507821238345))") @@ -17,12 +25,82 @@ describe("eval", () => { testEval("toSampleSet(normal(5,2), 100)", "Ok(Sample Set Distribution)") testEval("add(normal(5,2), normal(10,2))", "Ok(Normal(15,2.8284271247461903))") testEval("add(normal(5,2), lognormal(10,2))", "Ok(Sample Set Distribution)") - testEval("pointwiseAdd(normal(5,2), lognormal(10,2))", "Ok(Point Set Distribution)") - testEval("pointwiseAdd(normal(5,2), 3)", "Ok(Point Set Distribution)") testEval("add(normal(5,2), 3)", "Ok(Point Set Distribution)") testEval("add(3, normal(5,2))", "Ok(Point Set Distribution)") testEval("3+normal(5,2)", "Ok(Point Set Distribution)") + testEval("normal(5,2)+3", "Ok(Point Set Distribution)") testEval("add(3, 3)", "Ok(6)") testEval("truncateLeft(normal(5,2), 3)", "Ok(Point Set Distribution)") + testEval("truncateRight(normal(5,2), 3)", "Ok(Point Set Distribution)") + testEval("truncate(normal(5,2), 3, 8)", "Ok(Point Set Distribution)") + }) + + describe("exp", () => { + testEval("exp(normal(5,2))", "Ok(Point Set Distribution)") + }) + + describe("pow", () => { + testEval("pow(3, uniform(5,8))", "Ok(Point Set Distribution)") + testEval("pow(uniform(5,8), 3)", "Ok(Point Set Distribution)") + testEval("pow(uniform(5,8), uniform(9, 10))", "Ok(Sample Set Distribution)") + }) + + describe("log", () => { + testEval("log(2, uniform(5,8))", "Ok(Point Set Distribution)") + testEval("log(normal(5,2), 3)", "Ok(Point Set Distribution)") + testEval("log(normal(5,2), normal(10,1))", "Ok(Sample Set Distribution)") + testEval("log(uniform(5,8))", "Ok(Point Set Distribution)") + testEval("log10(uniform(5,8))", "Ok(Point Set Distribution)") + }) + + describe("dotLog", () => { + testEval("dotLog(normal(5,2), 3)", "Ok(Point Set Distribution)") + testEval("dotLog(normal(5,2), 3)", "Ok(Point Set Distribution)") + testEval("dotLog(normal(5,2), normal(10,1))", "Ok(Point Set Distribution)") + }) + + describe("dotAdd", () => { + testEval("dotAdd(normal(5,2), lognormal(10,2))", "Ok(Point Set Distribution)") + testEval("dotAdd(normal(5,2), 3)", "Ok(Point Set Distribution)") + }) + + describe("equality", () => { + testEval(~skip=true, "normal(5,2) == normal(5,2)", "Ok(true)") + }) + + describe("mixture", () => { + testEval( + ~skip=true, + "mx(normal(5,2), normal(10,1), normal(15, 1))", + "Ok(Point Set Distribution)", + ) + testEval( + ~skip=true, + "mixture(normal(5,2), normal(10,1), [.2,, .4])", + "Ok(Point Set Distribution)", + ) + }) +}) + +describe("MathJs parse", () => { + describe("literals operators paranthesis", () => { + testParse("mean(normal(5,2) + normal(5,1))", "Ok((:mean (:add (:normal 5 2) (:normal 5 1))))") + testParse("normal(5,2) .* normal(5,1)", "Ok((:dotMultiply (:normal 5 2) (:normal 5 1)))") + testParse("normal(5,2) ./ normal(5,1)", "Ok((:dotDivide (:normal 5 2) (:normal 5 1)))") + testParse("normal(5,2) .^ normal(5,1)", "Ok((:dotPow (:normal 5 2) (:normal 5 1)))") + testParse("normal(5,2) ^ normal(5,1)", "Ok((:pow (:normal 5 2) (:normal 5 1)))") + testParse("3 ^ normal(5,1)", "Ok((:pow 3 (:normal 5 1)))") + testParse("normal(5,2) ^ 3", "Ok((:pow (:normal 5 2) 3))") + testParse("5 == normal(5,2)", "Ok((:equal 5 (:normal 5 2)))") + describe("adding two normals", () => { + testParse( + ~skip=true, + "normal(5,2) .+ normal(5,1)", + "Ok((:dotAdd (:normal 5 2) (:normal 5 1)))", + ) + }) + describe("exponential of one distribution", () => { + testParse(~skip=true, "exp(normal(5,2)", "Ok((:pow (:normal 5 2) 3))") + }) }) }) diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res index 12d0ddb4..11376001 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res @@ -12,17 +12,17 @@ module Helpers = { let arithmeticMap = r => switch r { | "add" => #Add - | "pointwiseAdd" => #Add + | "dotAdd" => #Add | "subtract" => #Subtract - | "pointwiseSubtract" => #Subtract + | "dotSubtract" => #Subtract | "divide" => #Divide - | "logarithm" => #Logarithm - | "pointwiseDivide" => #Divide - | "exponentiate" => #Exponentiate - | "pointwiseExponentiate" => #Exponentiate + | "log" => #Logarithm + | "dotDivide" => #Divide + | "pow" => #Exponentiate + | "dotPow" => #Exponentiate | "multiply" => #Multiply - | "pointwiseMultiply" => #Multiply - | "pointwiseLogarithm" => #Logarithm + | "dotMultiply" => #Multiply + | "dotLog" => #Logarithm | _ => #Multiply } @@ -93,6 +93,10 @@ module SymbolicConstructors = { } } +module Math = { + let e = 2.718281828459 +} + let dispatchToGenericOutput = (call: ExpressionValue.functionCall): option< GenericDist_GenericOperation.outputType, > => { @@ -115,6 +119,9 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall): option< ->SymbolicConstructors.symbolicResultToOutput | ("sample", [EvDistribution(dist)]) => Helpers.toFloatFn(#Sample, dist) | ("mean", [EvDistribution(dist)]) => Helpers.toFloatFn(#Mean, dist) + | ("exp", [EvDistribution(a)]) => + // https://mathjs.org/docs/reference/functions/exp.html + Helpers.twoDiststoDistFn(Algebraic, "pow", GenericDist.fromFloat(Math.e), a)->Some | ("normalize", [EvDistribution(dist)]) => Helpers.toDistFn(Normalize, dist) | ("toPointSet", [EvDistribution(dist)]) => Helpers.toDistFn(ToPointSet, dist) | ("cdf", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Cdf(float), dist) @@ -128,25 +135,30 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall): option< Helpers.toDistFn(Truncate(None, Some(float)), dist) | ("truncate", [EvDistribution(dist), EvNumber(float1), EvNumber(float2)]) => Helpers.toDistFn(Truncate(Some(float1), Some(float2)), dist) - | ( - ("add" | "multiply" | "subtract" | "divide" | "exponentiate" | "log") as arithmetic, - [a, b] as args, - ) => + | ("log", [EvDistribution(a)]) => + Helpers.twoDiststoDistFn(Algebraic, "log", a, GenericDist.fromFloat(Math.e))->Some + | ("log10", [EvDistribution(a)]) => + Helpers.twoDiststoDistFn(Algebraic, "log", a, GenericDist.fromFloat(10.0))->Some + | (("add" | "multiply" | "subtract" | "divide" | "pow" | "log") as arithmetic, [a, b] as args) => Helpers.catchAndConvertTwoArgsToDists(args)->E.O2.fmap(((fst, snd)) => Helpers.twoDiststoDistFn(Algebraic, arithmetic, fst, snd) ) | ( - ("pointwiseAdd" - | "pointwiseMultiply" - | "pointwiseSubtract" - | "pointwiseDivide" - | "pointwiseExponentiate" - | "pointwiseLogarithm") as arithmetic, + ("dotAdd" + | "dotMultiply" + | "dotSubtract" + | "dotDivide" + | "dotPow" + | "dotLog") as arithmetic, [a, b] as args, ) => Helpers.catchAndConvertTwoArgsToDists(args)->E.O2.fmap(((fst, snd)) => Helpers.twoDiststoDistFn(Pointwise, arithmetic, fst, snd) ) + | ("dotLog", [EvDistribution(a)]) => + Helpers.twoDiststoDistFn(Pointwise, "dotLog", a, GenericDist.fromFloat(Math.e))->Some + | ("dotExp", [EvDistribution(a)]) => + Helpers.twoDiststoDistFn(Pointwise, "dotPow", GenericDist.fromFloat(Math.e), a)->Some | _ => None } }