From 348b1c9ac673de7e1617588b736454a5ecdf95f0 Mon Sep 17 00:00:00 2001 From: Sam Nolan Date: Wed, 13 Apr 2022 14:36:30 +1000 Subject: [PATCH] Add normal distribution analytical simplifications --- .../ReducerInterface_Distribution_test.res | 19 +++++++++---- .../SymbolicDist/SymbolicDist.res | 27 +++++++++++++++++++ .../ReducerInterface_GenericDistribution.res | 3 ++- 3 files changed, 43 insertions(+), 6 deletions(-) diff --git a/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res b/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res index 0a131d93..023bdcdf 100644 --- a/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res +++ b/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res @@ -19,7 +19,7 @@ describe("eval on distribution functions", () => { testEval("lognormal(5,2)", "Ok(Lognormal(5,2))") }) describe("unaryMinus", () => { - testEval("mean(-normal(5,2))", "Ok(-5.002887370380851)") + testEval("mean(-normal(5,2))", "Ok(-5)") }) describe("to", () => { testEval("5 to 2", "Error(TODO: Low value must be less than high value.)") @@ -45,10 +45,10 @@ describe("eval on distribution functions", () => { describe("add", () => { testEval("add(normal(5,2), normal(10,2))", "Ok(Normal(15,2.8284271247461903))") testEval("add(normal(5,2), lognormal(10,2))", "Ok(Sample Set Distribution)") - testEval("add(normal(5,2), 3)", "Ok(Point Set Distribution)") - testEval("add(3, normal(5,2))", "Ok(Point Set Distribution)") - testEval("3+normal(5,2)", "Ok(Point Set Distribution)") - testEval("normal(5,2)+3", "Ok(Point Set Distribution)") + testEval("add(normal(5,2), 3)", "Ok(Normal(8,2))") + testEval("add(3, normal(5,2))", "Ok(Normal(8,2))") + testEval("3+normal(5,2)", "Ok(Normal(8,2))") + testEval("normal(5,2)+3", "Ok(Normal(8,2))") }) describe("truncate", () => { testEval("truncateLeft(normal(5,2), 3)", "Ok(Point Set Distribution)") @@ -93,6 +93,11 @@ describe("eval on distribution functions", () => { testEval("mx(normal(5,2), normal(10,1), normal(15, 1))", "Ok(Point Set Distribution)") testEval("mixture(normal(5,2), normal(10,1), [0.2, 0.4])", "Ok(Point Set Distribution)") }) + + describe("subtract", () => { + testEval("10 - normal(5, 1)", "Ok(Normal(5,1))") + testEval("normal(5, 1) - 10", "Ok(Normal(-5,1))") + }) }) describe("parse on distribution functions", () => { @@ -101,6 +106,10 @@ describe("parse on distribution functions", () => { testParse("3 ^ normal(5,1)", "Ok((:pow 3 (:normal 5 1)))") testParse("normal(5,2) ^ 3", "Ok((:pow (:normal 5 2) 3))") }) + describe("subtraction", () => { + testParse("10 - normal(5,1)", "Ok((:subtract 10 (:normal 5 1)))") + testParse("normal(5,1) - 10", "Ok((:subtract (:normal 5 1) 10))") + }) describe("pointwise arithmetic expressions", () => { testParse(~skip=true, "normal(5,2) .+ normal(5,1)", "Ok((:dotAdd (:normal 5 2) (:normal 5 1)))") testParse( diff --git a/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res b/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res index e1d01390..9f33d0d4 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res @@ -44,6 +44,23 @@ module Normal = { | #Subtract => Some(subtract(n1, n2)) | _ => None } + + let operateFloatFirst = (operation: Operation.Algebraic.t, n1: float, n2: t) => + switch operation { + | #Add => Some(#Normal({mean: n1 +. n2.mean, stdev: n2.stdev})) + | #Subtract => Some(#Normal({mean: n1 -. n2.mean, stdev: n2.stdev})) + | #Multiply => Some(#Normal({mean: n1 *. n2.mean, stdev: n1 *. n2.stdev})) + | _ => None + } + + let operateFloatSecond = (operation: Operation.Algebraic.t, n1: t, n2: float) => + switch operation { + | #Add => Some(#Normal({mean: n1.mean +. n2, stdev: n1.stdev})) + | #Subtract => Some(#Normal({mean: n1.mean -. n2, stdev: n1.stdev})) + | #Multiply => Some(#Normal({mean: n1.mean *. n2, stdev: n1.stdev})) + | #Divide => Some(#Normal({mean: n1.mean /. n2, stdev: n1.stdev /. n2})) + | _ => None + } } module Exponential = { @@ -341,6 +358,16 @@ module T = { } | (#Normal(v1), #Normal(v2)) => Normal.operate(op, v1, v2) |> E.O.dimap(r => #AnalyticalSolution(r), () => #NoSolution) + | (#Normal(v1), #Float(v2)) => + Normal.operateFloatSecond(op, v1, v2) |> E.O.dimap( + r => #AnalyticalSolution(r), + () => #NoSolution, + ) + | (#Float(v1), #Normal(v2)) => + Normal.operateFloatFirst(op, v1, v2) |> E.O.dimap( + r => #AnalyticalSolution(r), + () => #NoSolution, + ) | (#Lognormal(v1), #Lognormal(v2)) => Lognormal.operate(op, v1, v2) |> E.O.dimap(r => #AnalyticalSolution(r), () => #NoSolution) | _ => #NoSolution diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res index 918c4b70..b0ce84e3 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res @@ -117,7 +117,8 @@ module Helpers = { | Error(err) => GenDistError(ArgumentError(err)) } } - | Some(EvDistribution(b)) => switch parseDistributionArray(args) { + | Some(EvDistribution(b)) => + switch parseDistributionArray(args) { | Ok(distributions) => mixtureWithDefaultWeights(distributions) | Error(err) => GenDistError(ArgumentError(err)) }