From 348b1c9ac673de7e1617588b736454a5ecdf95f0 Mon Sep 17 00:00:00 2001 From: Sam Nolan Date: Wed, 13 Apr 2022 14:36:30 +1000 Subject: [PATCH 1/2] Add normal distribution analytical simplifications --- .../ReducerInterface_Distribution_test.res | 19 +++++++++---- .../SymbolicDist/SymbolicDist.res | 27 +++++++++++++++++++ .../ReducerInterface_GenericDistribution.res | 3 ++- 3 files changed, 43 insertions(+), 6 deletions(-) diff --git a/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res b/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res index 0a131d93..023bdcdf 100644 --- a/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res +++ b/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res @@ -19,7 +19,7 @@ describe("eval on distribution functions", () => { testEval("lognormal(5,2)", "Ok(Lognormal(5,2))") }) describe("unaryMinus", () => { - testEval("mean(-normal(5,2))", "Ok(-5.002887370380851)") + testEval("mean(-normal(5,2))", "Ok(-5)") }) describe("to", () => { testEval("5 to 2", "Error(TODO: Low value must be less than high value.)") @@ -45,10 +45,10 @@ describe("eval on distribution functions", () => { describe("add", () => { testEval("add(normal(5,2), normal(10,2))", "Ok(Normal(15,2.8284271247461903))") testEval("add(normal(5,2), lognormal(10,2))", "Ok(Sample Set Distribution)") - testEval("add(normal(5,2), 3)", "Ok(Point Set Distribution)") - testEval("add(3, normal(5,2))", "Ok(Point Set Distribution)") - testEval("3+normal(5,2)", "Ok(Point Set Distribution)") - testEval("normal(5,2)+3", "Ok(Point Set Distribution)") + testEval("add(normal(5,2), 3)", "Ok(Normal(8,2))") + testEval("add(3, normal(5,2))", "Ok(Normal(8,2))") + testEval("3+normal(5,2)", "Ok(Normal(8,2))") + testEval("normal(5,2)+3", "Ok(Normal(8,2))") }) describe("truncate", () => { testEval("truncateLeft(normal(5,2), 3)", "Ok(Point Set Distribution)") @@ -93,6 +93,11 @@ describe("eval on distribution functions", () => { testEval("mx(normal(5,2), normal(10,1), normal(15, 1))", "Ok(Point Set Distribution)") testEval("mixture(normal(5,2), normal(10,1), [0.2, 0.4])", "Ok(Point Set Distribution)") }) + + describe("subtract", () => { + testEval("10 - normal(5, 1)", "Ok(Normal(5,1))") + testEval("normal(5, 1) - 10", "Ok(Normal(-5,1))") + }) }) describe("parse on distribution functions", () => { @@ -101,6 +106,10 @@ describe("parse on distribution functions", () => { testParse("3 ^ normal(5,1)", "Ok((:pow 3 (:normal 5 1)))") testParse("normal(5,2) ^ 3", "Ok((:pow (:normal 5 2) 3))") }) + describe("subtraction", () => { + testParse("10 - normal(5,1)", "Ok((:subtract 10 (:normal 5 1)))") + testParse("normal(5,1) - 10", "Ok((:subtract (:normal 5 1) 10))") + }) describe("pointwise arithmetic expressions", () => { testParse(~skip=true, "normal(5,2) .+ normal(5,1)", "Ok((:dotAdd (:normal 5 2) (:normal 5 1)))") testParse( diff --git a/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res b/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res index e1d01390..9f33d0d4 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res @@ -44,6 +44,23 @@ module Normal = { | #Subtract => Some(subtract(n1, n2)) | _ => None } + + let operateFloatFirst = (operation: Operation.Algebraic.t, n1: float, n2: t) => + switch operation { + | #Add => Some(#Normal({mean: n1 +. n2.mean, stdev: n2.stdev})) + | #Subtract => Some(#Normal({mean: n1 -. n2.mean, stdev: n2.stdev})) + | #Multiply => Some(#Normal({mean: n1 *. n2.mean, stdev: n1 *. n2.stdev})) + | _ => None + } + + let operateFloatSecond = (operation: Operation.Algebraic.t, n1: t, n2: float) => + switch operation { + | #Add => Some(#Normal({mean: n1.mean +. n2, stdev: n1.stdev})) + | #Subtract => Some(#Normal({mean: n1.mean -. n2, stdev: n1.stdev})) + | #Multiply => Some(#Normal({mean: n1.mean *. n2, stdev: n1.stdev})) + | #Divide => Some(#Normal({mean: n1.mean /. n2, stdev: n1.stdev /. n2})) + | _ => None + } } module Exponential = { @@ -341,6 +358,16 @@ module T = { } | (#Normal(v1), #Normal(v2)) => Normal.operate(op, v1, v2) |> E.O.dimap(r => #AnalyticalSolution(r), () => #NoSolution) + | (#Normal(v1), #Float(v2)) => + Normal.operateFloatSecond(op, v1, v2) |> E.O.dimap( + r => #AnalyticalSolution(r), + () => #NoSolution, + ) + | (#Float(v1), #Normal(v2)) => + Normal.operateFloatFirst(op, v1, v2) |> E.O.dimap( + r => #AnalyticalSolution(r), + () => #NoSolution, + ) | (#Lognormal(v1), #Lognormal(v2)) => Lognormal.operate(op, v1, v2) |> E.O.dimap(r => #AnalyticalSolution(r), () => #NoSolution) | _ => #NoSolution diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res index 918c4b70..b0ce84e3 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res @@ -117,7 +117,8 @@ module Helpers = { | Error(err) => GenDistError(ArgumentError(err)) } } - | Some(EvDistribution(b)) => switch parseDistributionArray(args) { + | Some(EvDistribution(b)) => + switch parseDistributionArray(args) { | Ok(distributions) => mixtureWithDefaultWeights(distributions) | Error(err) => GenDistError(ArgumentError(err)) } From 948a8dd6512cc1e7e589c77b882012d3d08e6f84 Mon Sep 17 00:00:00 2001 From: Sam Nolan Date: Wed, 13 Apr 2022 15:11:14 +1000 Subject: [PATCH 2/2] Add analytic solutions for normal and lognormal --- .../ReducerInterface_Distribution_test.res | 25 +++++++++--- .../SymbolicDist/SymbolicDist.res | 38 ++++++++++++++++--- 2 files changed, 52 insertions(+), 11 deletions(-) diff --git a/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res b/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res index 023bdcdf..6045d6e0 100644 --- a/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res +++ b/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res @@ -50,6 +50,26 @@ describe("eval on distribution functions", () => { testEval("3+normal(5,2)", "Ok(Normal(8,2))") testEval("normal(5,2)+3", "Ok(Normal(8,2))") }) + describe("subtract", () => { + testEval("10 - normal(5, 1)", "Ok(Normal(5,1))") + testEval("normal(5, 1) - 10", "Ok(Normal(-5,1))") + }) + describe("multiply", () => { + testEval("normal(10, 2) * 2", "Ok(Normal(20,4))") + testEval("2 * normal(10, 2)", "Ok(Normal(20,4))") + testEval("lognormal(5,2) * lognormal(10,2)", "Ok(Lognormal(15,4))") + testEval("lognormal(10, 2) * lognormal(5, 2)", "Ok(Lognormal(15,4))") + testEval("2 * lognormal(5, 2)", "Ok(Lognormal(5.693147180559945,2))") + testEval("lognormal(5, 2) * 2", "Ok(Lognormal(5.693147180559945,2))") + }) + describe("division", () => { + testEval("lognormal(5,2) / lognormal(10,2)", "Ok(Lognormal(-5,4))") + testEval("lognormal(10,2) / lognormal(5,2)", "Ok(Lognormal(5,4))") + testEval("lognormal(5, 2) / 2", "Ok(Lognormal(4.306852819440055,2))") + testEval("2 / lognormal(5, 2)", "Ok(Lognormal(-4.306852819440055,2))") + testEval("2 / normal(10, 2)", "Ok(Point Set Distribution)") + testEval("normal(10, 2) / 2", "Ok(Normal(5,1))") + }) describe("truncate", () => { testEval("truncateLeft(normal(5,2), 3)", "Ok(Point Set Distribution)") testEval("truncateRight(normal(5,2), 3)", "Ok(Point Set Distribution)") @@ -93,11 +113,6 @@ describe("eval on distribution functions", () => { testEval("mx(normal(5,2), normal(10,1), normal(15, 1))", "Ok(Point Set Distribution)") testEval("mixture(normal(5,2), normal(10,1), [0.2, 0.4])", "Ok(Point Set Distribution)") }) - - describe("subtract", () => { - testEval("10 - normal(5, 1)", "Ok(Normal(5,1))") - testEval("normal(5, 1) - 10", "Ok(Normal(-5,1))") - }) }) describe("parse on distribution functions", () => { diff --git a/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res b/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res index 9f33d0d4..cb47cc17 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res @@ -57,7 +57,7 @@ module Normal = { switch operation { | #Add => Some(#Normal({mean: n1.mean +. n2, stdev: n1.stdev})) | #Subtract => Some(#Normal({mean: n1.mean -. n2, stdev: n1.stdev})) - | #Multiply => Some(#Normal({mean: n1.mean *. n2, stdev: n1.stdev})) + | #Multiply => Some(#Normal({mean: n1.mean *. n2, stdev: n1.stdev *. n2})) | #Divide => Some(#Normal({mean: n1.mean /. n2, stdev: n1.stdev /. n2})) | _ => None } @@ -167,6 +167,22 @@ module Lognormal = { | #Divide => Some(divide(n1, n2)) | _ => None } + + let operateFloatFirst = (operation: Operation.Algebraic.t, n1: float, n2: t) => + switch operation { + | #Multiply => + n1 > 0.0 ? Some(#Lognormal({mu: Js.Math.log(n1) +. n2.mu, sigma: n2.sigma})) : None + | #Divide => n1 > 0.0 ? Some(#Lognormal({mu: Js.Math.log(n1) -. n2.mu, sigma: n2.sigma})) : None + | _ => None + } + + let operateFloatSecond = (operation: Operation.Algebraic.t, n1: t, n2: float) => + switch operation { + | #Multiply => + n2 > 0.0 ? Some(#Lognormal({mu: n1.mu +. Js.Math.log(n2), sigma: n1.sigma})) : None + | #Divide => n2 > 0.0 ? Some(#Lognormal({mu: n1.mu -. Js.Math.log(n2), sigma: n1.sigma})) : None + | _ => None + } } module Uniform = { @@ -358,18 +374,28 @@ module T = { } | (#Normal(v1), #Normal(v2)) => Normal.operate(op, v1, v2) |> E.O.dimap(r => #AnalyticalSolution(r), () => #NoSolution) - | (#Normal(v1), #Float(v2)) => - Normal.operateFloatSecond(op, v1, v2) |> E.O.dimap( - r => #AnalyticalSolution(r), - () => #NoSolution, - ) | (#Float(v1), #Normal(v2)) => Normal.operateFloatFirst(op, v1, v2) |> E.O.dimap( r => #AnalyticalSolution(r), () => #NoSolution, ) + | (#Normal(v1), #Float(v2)) => + Normal.operateFloatSecond(op, v1, v2) |> E.O.dimap( + r => #AnalyticalSolution(r), + () => #NoSolution, + ) | (#Lognormal(v1), #Lognormal(v2)) => Lognormal.operate(op, v1, v2) |> E.O.dimap(r => #AnalyticalSolution(r), () => #NoSolution) + | (#Float(v1), #Lognormal(v2)) => + Lognormal.operateFloatFirst(op, v1, v2) |> E.O.dimap( + r => #AnalyticalSolution(r), + () => #NoSolution, + ) + | (#Lognormal(v1), #Float(v2)) => + Lognormal.operateFloatSecond(op, v1, v2) |> E.O.dimap( + r => #AnalyticalSolution(r), + () => #NoSolution, + ) | _ => #NoSolution }