diff --git a/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res b/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res index 023bdcdf..6045d6e0 100644 --- a/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res +++ b/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res @@ -50,6 +50,26 @@ describe("eval on distribution functions", () => { testEval("3+normal(5,2)", "Ok(Normal(8,2))") testEval("normal(5,2)+3", "Ok(Normal(8,2))") }) + describe("subtract", () => { + testEval("10 - normal(5, 1)", "Ok(Normal(5,1))") + testEval("normal(5, 1) - 10", "Ok(Normal(-5,1))") + }) + describe("multiply", () => { + testEval("normal(10, 2) * 2", "Ok(Normal(20,4))") + testEval("2 * normal(10, 2)", "Ok(Normal(20,4))") + testEval("lognormal(5,2) * lognormal(10,2)", "Ok(Lognormal(15,4))") + testEval("lognormal(10, 2) * lognormal(5, 2)", "Ok(Lognormal(15,4))") + testEval("2 * lognormal(5, 2)", "Ok(Lognormal(5.693147180559945,2))") + testEval("lognormal(5, 2) * 2", "Ok(Lognormal(5.693147180559945,2))") + }) + describe("division", () => { + testEval("lognormal(5,2) / lognormal(10,2)", "Ok(Lognormal(-5,4))") + testEval("lognormal(10,2) / lognormal(5,2)", "Ok(Lognormal(5,4))") + testEval("lognormal(5, 2) / 2", "Ok(Lognormal(4.306852819440055,2))") + testEval("2 / lognormal(5, 2)", "Ok(Lognormal(-4.306852819440055,2))") + testEval("2 / normal(10, 2)", "Ok(Point Set Distribution)") + testEval("normal(10, 2) / 2", "Ok(Normal(5,1))") + }) describe("truncate", () => { testEval("truncateLeft(normal(5,2), 3)", "Ok(Point Set Distribution)") testEval("truncateRight(normal(5,2), 3)", "Ok(Point Set Distribution)") @@ -93,11 +113,6 @@ describe("eval on distribution functions", () => { testEval("mx(normal(5,2), normal(10,1), normal(15, 1))", "Ok(Point Set Distribution)") testEval("mixture(normal(5,2), normal(10,1), [0.2, 0.4])", "Ok(Point Set Distribution)") }) - - describe("subtract", () => { - testEval("10 - normal(5, 1)", "Ok(Normal(5,1))") - testEval("normal(5, 1) - 10", "Ok(Normal(-5,1))") - }) }) describe("parse on distribution functions", () => { diff --git a/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res b/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res index 9f33d0d4..cb47cc17 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res @@ -57,7 +57,7 @@ module Normal = { switch operation { | #Add => Some(#Normal({mean: n1.mean +. n2, stdev: n1.stdev})) | #Subtract => Some(#Normal({mean: n1.mean -. n2, stdev: n1.stdev})) - | #Multiply => Some(#Normal({mean: n1.mean *. n2, stdev: n1.stdev})) + | #Multiply => Some(#Normal({mean: n1.mean *. n2, stdev: n1.stdev *. n2})) | #Divide => Some(#Normal({mean: n1.mean /. n2, stdev: n1.stdev /. n2})) | _ => None } @@ -167,6 +167,22 @@ module Lognormal = { | #Divide => Some(divide(n1, n2)) | _ => None } + + let operateFloatFirst = (operation: Operation.Algebraic.t, n1: float, n2: t) => + switch operation { + | #Multiply => + n1 > 0.0 ? Some(#Lognormal({mu: Js.Math.log(n1) +. n2.mu, sigma: n2.sigma})) : None + | #Divide => n1 > 0.0 ? Some(#Lognormal({mu: Js.Math.log(n1) -. n2.mu, sigma: n2.sigma})) : None + | _ => None + } + + let operateFloatSecond = (operation: Operation.Algebraic.t, n1: t, n2: float) => + switch operation { + | #Multiply => + n2 > 0.0 ? Some(#Lognormal({mu: n1.mu +. Js.Math.log(n2), sigma: n1.sigma})) : None + | #Divide => n2 > 0.0 ? Some(#Lognormal({mu: n1.mu -. Js.Math.log(n2), sigma: n1.sigma})) : None + | _ => None + } } module Uniform = { @@ -358,18 +374,28 @@ module T = { } | (#Normal(v1), #Normal(v2)) => Normal.operate(op, v1, v2) |> E.O.dimap(r => #AnalyticalSolution(r), () => #NoSolution) - | (#Normal(v1), #Float(v2)) => - Normal.operateFloatSecond(op, v1, v2) |> E.O.dimap( - r => #AnalyticalSolution(r), - () => #NoSolution, - ) | (#Float(v1), #Normal(v2)) => Normal.operateFloatFirst(op, v1, v2) |> E.O.dimap( r => #AnalyticalSolution(r), () => #NoSolution, ) + | (#Normal(v1), #Float(v2)) => + Normal.operateFloatSecond(op, v1, v2) |> E.O.dimap( + r => #AnalyticalSolution(r), + () => #NoSolution, + ) | (#Lognormal(v1), #Lognormal(v2)) => Lognormal.operate(op, v1, v2) |> E.O.dimap(r => #AnalyticalSolution(r), () => #NoSolution) + | (#Float(v1), #Lognormal(v2)) => + Lognormal.operateFloatFirst(op, v1, v2) |> E.O.dimap( + r => #AnalyticalSolution(r), + () => #NoSolution, + ) + | (#Lognormal(v1), #Float(v2)) => + Lognormal.operateFloatSecond(op, v1, v2) |> E.O.dimap( + r => #AnalyticalSolution(r), + () => #NoSolution, + ) | _ => #NoSolution }