Add normal distribution analytical simplifications

This commit is contained in:
Sam Nolan 2022-04-13 14:36:30 +10:00
parent c6e78a1fd4
commit 348b1c9ac6
3 changed files with 43 additions and 6 deletions

View File

@ -19,7 +19,7 @@ describe("eval on distribution functions", () => {
testEval("lognormal(5,2)", "Ok(Lognormal(5,2))") testEval("lognormal(5,2)", "Ok(Lognormal(5,2))")
}) })
describe("unaryMinus", () => { describe("unaryMinus", () => {
testEval("mean(-normal(5,2))", "Ok(-5.002887370380851)") testEval("mean(-normal(5,2))", "Ok(-5)")
}) })
describe("to", () => { describe("to", () => {
testEval("5 to 2", "Error(TODO: Low value must be less than high value.)") testEval("5 to 2", "Error(TODO: Low value must be less than high value.)")
@ -45,10 +45,10 @@ describe("eval on distribution functions", () => {
describe("add", () => { describe("add", () => {
testEval("add(normal(5,2), normal(10,2))", "Ok(Normal(15,2.8284271247461903))") testEval("add(normal(5,2), normal(10,2))", "Ok(Normal(15,2.8284271247461903))")
testEval("add(normal(5,2), lognormal(10,2))", "Ok(Sample Set Distribution)") testEval("add(normal(5,2), lognormal(10,2))", "Ok(Sample Set Distribution)")
testEval("add(normal(5,2), 3)", "Ok(Point Set Distribution)") testEval("add(normal(5,2), 3)", "Ok(Normal(8,2))")
testEval("add(3, normal(5,2))", "Ok(Point Set Distribution)") testEval("add(3, normal(5,2))", "Ok(Normal(8,2))")
testEval("3+normal(5,2)", "Ok(Point Set Distribution)") testEval("3+normal(5,2)", "Ok(Normal(8,2))")
testEval("normal(5,2)+3", "Ok(Point Set Distribution)") testEval("normal(5,2)+3", "Ok(Normal(8,2))")
}) })
describe("truncate", () => { describe("truncate", () => {
testEval("truncateLeft(normal(5,2), 3)", "Ok(Point Set Distribution)") testEval("truncateLeft(normal(5,2), 3)", "Ok(Point Set Distribution)")
@ -93,6 +93,11 @@ describe("eval on distribution functions", () => {
testEval("mx(normal(5,2), normal(10,1), normal(15, 1))", "Ok(Point Set Distribution)") testEval("mx(normal(5,2), normal(10,1), normal(15, 1))", "Ok(Point Set Distribution)")
testEval("mixture(normal(5,2), normal(10,1), [0.2, 0.4])", "Ok(Point Set Distribution)") testEval("mixture(normal(5,2), normal(10,1), [0.2, 0.4])", "Ok(Point Set Distribution)")
}) })
describe("subtract", () => {
testEval("10 - normal(5, 1)", "Ok(Normal(5,1))")
testEval("normal(5, 1) - 10", "Ok(Normal(-5,1))")
})
}) })
describe("parse on distribution functions", () => { describe("parse on distribution functions", () => {
@ -101,6 +106,10 @@ describe("parse on distribution functions", () => {
testParse("3 ^ normal(5,1)", "Ok((:pow 3 (:normal 5 1)))") testParse("3 ^ normal(5,1)", "Ok((:pow 3 (:normal 5 1)))")
testParse("normal(5,2) ^ 3", "Ok((:pow (:normal 5 2) 3))") testParse("normal(5,2) ^ 3", "Ok((:pow (:normal 5 2) 3))")
}) })
describe("subtraction", () => {
testParse("10 - normal(5,1)", "Ok((:subtract 10 (:normal 5 1)))")
testParse("normal(5,1) - 10", "Ok((:subtract (:normal 5 1) 10))")
})
describe("pointwise arithmetic expressions", () => { describe("pointwise arithmetic expressions", () => {
testParse(~skip=true, "normal(5,2) .+ normal(5,1)", "Ok((:dotAdd (:normal 5 2) (:normal 5 1)))") testParse(~skip=true, "normal(5,2) .+ normal(5,1)", "Ok((:dotAdd (:normal 5 2) (:normal 5 1)))")
testParse( testParse(

View File

@ -44,6 +44,23 @@ module Normal = {
| #Subtract => Some(subtract(n1, n2)) | #Subtract => Some(subtract(n1, n2))
| _ => None | _ => None
} }
let operateFloatFirst = (operation: Operation.Algebraic.t, n1: float, n2: t) =>
switch operation {
| #Add => Some(#Normal({mean: n1 +. n2.mean, stdev: n2.stdev}))
| #Subtract => Some(#Normal({mean: n1 -. n2.mean, stdev: n2.stdev}))
| #Multiply => Some(#Normal({mean: n1 *. n2.mean, stdev: n1 *. n2.stdev}))
| _ => None
}
let operateFloatSecond = (operation: Operation.Algebraic.t, n1: t, n2: float) =>
switch operation {
| #Add => Some(#Normal({mean: n1.mean +. n2, stdev: n1.stdev}))
| #Subtract => Some(#Normal({mean: n1.mean -. n2, stdev: n1.stdev}))
| #Multiply => Some(#Normal({mean: n1.mean *. n2, stdev: n1.stdev}))
| #Divide => Some(#Normal({mean: n1.mean /. n2, stdev: n1.stdev /. n2}))
| _ => None
}
} }
module Exponential = { module Exponential = {
@ -341,6 +358,16 @@ module T = {
} }
| (#Normal(v1), #Normal(v2)) => | (#Normal(v1), #Normal(v2)) =>
Normal.operate(op, v1, v2) |> E.O.dimap(r => #AnalyticalSolution(r), () => #NoSolution) Normal.operate(op, v1, v2) |> E.O.dimap(r => #AnalyticalSolution(r), () => #NoSolution)
| (#Normal(v1), #Float(v2)) =>
Normal.operateFloatSecond(op, v1, v2) |> E.O.dimap(
r => #AnalyticalSolution(r),
() => #NoSolution,
)
| (#Float(v1), #Normal(v2)) =>
Normal.operateFloatFirst(op, v1, v2) |> E.O.dimap(
r => #AnalyticalSolution(r),
() => #NoSolution,
)
| (#Lognormal(v1), #Lognormal(v2)) => | (#Lognormal(v1), #Lognormal(v2)) =>
Lognormal.operate(op, v1, v2) |> E.O.dimap(r => #AnalyticalSolution(r), () => #NoSolution) Lognormal.operate(op, v1, v2) |> E.O.dimap(r => #AnalyticalSolution(r), () => #NoSolution)
| _ => #NoSolution | _ => #NoSolution

View File

@ -117,7 +117,8 @@ module Helpers = {
| Error(err) => GenDistError(ArgumentError(err)) | Error(err) => GenDistError(ArgumentError(err))
} }
} }
| Some(EvDistribution(b)) => switch parseDistributionArray(args) { | Some(EvDistribution(b)) =>
switch parseDistributionArray(args) {
| Ok(distributions) => mixtureWithDefaultWeights(distributions) | Ok(distributions) => mixtureWithDefaultWeights(distributions)
| Error(err) => GenDistError(ArgumentError(err)) | Error(err) => GenDistError(ArgumentError(err))
} }