diff --git a/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res b/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res index 0a131d93..6045d6e0 100644 --- a/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res +++ b/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res @@ -19,7 +19,7 @@ describe("eval on distribution functions", () => { testEval("lognormal(5,2)", "Ok(Lognormal(5,2))") }) describe("unaryMinus", () => { - testEval("mean(-normal(5,2))", "Ok(-5.002887370380851)") + testEval("mean(-normal(5,2))", "Ok(-5)") }) describe("to", () => { testEval("5 to 2", "Error(TODO: Low value must be less than high value.)") @@ -45,10 +45,30 @@ describe("eval on distribution functions", () => { describe("add", () => { testEval("add(normal(5,2), normal(10,2))", "Ok(Normal(15,2.8284271247461903))") testEval("add(normal(5,2), lognormal(10,2))", "Ok(Sample Set Distribution)") - testEval("add(normal(5,2), 3)", "Ok(Point Set Distribution)") - testEval("add(3, normal(5,2))", "Ok(Point Set Distribution)") - testEval("3+normal(5,2)", "Ok(Point Set Distribution)") - testEval("normal(5,2)+3", "Ok(Point Set Distribution)") + testEval("add(normal(5,2), 3)", "Ok(Normal(8,2))") + testEval("add(3, normal(5,2))", "Ok(Normal(8,2))") + testEval("3+normal(5,2)", "Ok(Normal(8,2))") + testEval("normal(5,2)+3", "Ok(Normal(8,2))") + }) + describe("subtract", () => { + testEval("10 - normal(5, 1)", "Ok(Normal(5,1))") + testEval("normal(5, 1) - 10", "Ok(Normal(-5,1))") + }) + describe("multiply", () => { + testEval("normal(10, 2) * 2", "Ok(Normal(20,4))") + testEval("2 * normal(10, 2)", "Ok(Normal(20,4))") + testEval("lognormal(5,2) * lognormal(10,2)", "Ok(Lognormal(15,4))") + testEval("lognormal(10, 2) * lognormal(5, 2)", "Ok(Lognormal(15,4))") + testEval("2 * lognormal(5, 2)", "Ok(Lognormal(5.693147180559945,2))") + testEval("lognormal(5, 2) * 2", "Ok(Lognormal(5.693147180559945,2))") + }) + describe("division", () => { + testEval("lognormal(5,2) / lognormal(10,2)", "Ok(Lognormal(-5,4))") + testEval("lognormal(10,2) / lognormal(5,2)", "Ok(Lognormal(5,4))") + testEval("lognormal(5, 2) / 2", "Ok(Lognormal(4.306852819440055,2))") + testEval("2 / lognormal(5, 2)", "Ok(Lognormal(-4.306852819440055,2))") + testEval("2 / normal(10, 2)", "Ok(Point Set Distribution)") + testEval("normal(10, 2) / 2", "Ok(Normal(5,1))") }) describe("truncate", () => { testEval("truncateLeft(normal(5,2), 3)", "Ok(Point Set Distribution)") @@ -101,6 +121,10 @@ describe("parse on distribution functions", () => { testParse("3 ^ normal(5,1)", "Ok((:pow 3 (:normal 5 1)))") testParse("normal(5,2) ^ 3", "Ok((:pow (:normal 5 2) 3))") }) + describe("subtraction", () => { + testParse("10 - normal(5,1)", "Ok((:subtract 10 (:normal 5 1)))") + testParse("normal(5,1) - 10", "Ok((:subtract (:normal 5 1) 10))") + }) describe("pointwise arithmetic expressions", () => { testParse(~skip=true, "normal(5,2) .+ normal(5,1)", "Ok((:dotAdd (:normal 5 2) (:normal 5 1)))") testParse( diff --git a/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res b/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res index 69abc726..1652c799 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res @@ -44,6 +44,23 @@ module Normal = { | #Subtract => Some(subtract(n1, n2)) | _ => None } + + let operateFloatFirst = (operation: Operation.Algebraic.t, n1: float, n2: t) => + switch operation { + | #Add => Some(#Normal({mean: n1 +. n2.mean, stdev: n2.stdev})) + | #Subtract => Some(#Normal({mean: n1 -. n2.mean, stdev: n2.stdev})) + | #Multiply => Some(#Normal({mean: n1 *. n2.mean, stdev: n1 *. n2.stdev})) + | _ => None + } + + let operateFloatSecond = (operation: Operation.Algebraic.t, n1: t, n2: float) => + switch operation { + | #Add => Some(#Normal({mean: n1.mean +. n2, stdev: n1.stdev})) + | #Subtract => Some(#Normal({mean: n1.mean -. n2, stdev: n1.stdev})) + | #Multiply => Some(#Normal({mean: n1.mean *. n2, stdev: n1.stdev *. n2})) + | #Divide => Some(#Normal({mean: n1.mean /. n2, stdev: n1.stdev /. n2})) + | _ => None + } } module Exponential = { @@ -152,6 +169,22 @@ module Lognormal = { | #Divide => Some(divide(n1, n2)) | _ => None } + + let operateFloatFirst = (operation: Operation.Algebraic.t, n1: float, n2: t) => + switch operation { + | #Multiply => + n1 > 0.0 ? Some(#Lognormal({mu: Js.Math.log(n1) +. n2.mu, sigma: n2.sigma})) : None + | #Divide => n1 > 0.0 ? Some(#Lognormal({mu: Js.Math.log(n1) -. n2.mu, sigma: n2.sigma})) : None + | _ => None + } + + let operateFloatSecond = (operation: Operation.Algebraic.t, n1: t, n2: float) => + switch operation { + | #Multiply => + n2 > 0.0 ? Some(#Lognormal({mu: n1.mu +. Js.Math.log(n2), sigma: n1.sigma})) : None + | #Divide => n2 > 0.0 ? Some(#Lognormal({mu: n1.mu -. Js.Math.log(n2), sigma: n1.sigma})) : None + | _ => None + } } module Uniform = { @@ -343,8 +376,28 @@ module T = { } | (#Normal(v1), #Normal(v2)) => Normal.operate(op, v1, v2) |> E.O.dimap(r => #AnalyticalSolution(r), () => #NoSolution) + | (#Float(v1), #Normal(v2)) => + Normal.operateFloatFirst(op, v1, v2) |> E.O.dimap( + r => #AnalyticalSolution(r), + () => #NoSolution, + ) + | (#Normal(v1), #Float(v2)) => + Normal.operateFloatSecond(op, v1, v2) |> E.O.dimap( + r => #AnalyticalSolution(r), + () => #NoSolution, + ) | (#Lognormal(v1), #Lognormal(v2)) => Lognormal.operate(op, v1, v2) |> E.O.dimap(r => #AnalyticalSolution(r), () => #NoSolution) + | (#Float(v1), #Lognormal(v2)) => + Lognormal.operateFloatFirst(op, v1, v2) |> E.O.dimap( + r => #AnalyticalSolution(r), + () => #NoSolution, + ) + | (#Lognormal(v1), #Float(v2)) => + Lognormal.operateFloatSecond(op, v1, v2) |> E.O.dimap( + r => #AnalyticalSolution(r), + () => #NoSolution, + ) | _ => #NoSolution }