From ee6551a6941e7e4b358ffe21ce7f594c4400e2f7 Mon Sep 17 00:00:00 2001 From: Sam Nolan Date: Sat, 23 Apr 2022 16:57:06 -0400 Subject: [PATCH] Prevent negative standard deviation in symbolic multiplication Introduced in #242 --- .../ReducerInterface/ReducerInterface_Distribution_test.res | 1 + .../rescript/Distributions/SymbolicDist/SymbolicDist.res | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res b/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res index d4199f89..605797b9 100644 --- a/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res +++ b/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res @@ -20,6 +20,7 @@ describe("eval on distribution functions", () => { }) describe("unaryMinus", () => { testEval("mean(-normal(5,2))", "Ok(-5)") + testEval("-normal(5,2)", "Ok(Normal(-5,2))") }) describe("to", () => { testEval("5 to 2", "Error(Math Error: Low value must be less than high value.)") diff --git a/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res b/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res index 7ce95721..92249eae 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res @@ -52,7 +52,7 @@ module Normal = { switch operation { | #Add => Some(#Normal({mean: n1 +. n2.mean, stdev: n2.stdev})) | #Subtract => Some(#Normal({mean: n1 -. n2.mean, stdev: n2.stdev})) - | #Multiply => Some(#Normal({mean: n1 *. n2.mean, stdev: n1 *. n2.stdev})) + | #Multiply => Some(#Normal({mean: n1 *. n2.mean, stdev: Js.Math.abs_float(n1) *. n2.stdev})) | _ => None } @@ -60,8 +60,8 @@ module Normal = { switch operation { | #Add => Some(#Normal({mean: n1.mean +. n2, stdev: n1.stdev})) | #Subtract => Some(#Normal({mean: n1.mean -. n2, stdev: n1.stdev})) - | #Multiply => Some(#Normal({mean: n1.mean *. n2, stdev: n1.stdev *. n2})) - | #Divide => Some(#Normal({mean: n1.mean /. n2, stdev: n1.stdev /. n2})) + | #Multiply => Some(#Normal({mean: n1.mean *. n2, stdev: n1.stdev *. Js.Math.abs_float(n2)})) + | #Divide => Some(#Normal({mean: n1.mean /. n2, stdev: n1.stdev /. Js.Math.abs_float(n2)})) | _ => None } }