From 43704743af8f3bd8f51b6425a329e1f889b40824 Mon Sep 17 00:00:00 2001 From: Ozzie Gooen Date: Thu, 30 Jul 2020 13:54:35 +0100 Subject: [PATCH] New functions use new syntax --- .../expressionTree/ExpressionTreeEvaluator.re | 3 +- src/distPlus/expressionTree/Functions.re | 55 +++++++++++ src/distPlus/expressionTree/MathJsParser.re | 96 ++++++------------- src/distPlus/symbolic/SymbolicDist.re | 25 +++-- 4 files changed, 103 insertions(+), 76 deletions(-) diff --git a/src/distPlus/expressionTree/ExpressionTreeEvaluator.re b/src/distPlus/expressionTree/ExpressionTreeEvaluator.re index 7c19d6fc..91c1ec54 100644 --- a/src/distPlus/expressionTree/ExpressionTreeEvaluator.re +++ b/src/distPlus/expressionTree/ExpressionTreeEvaluator.re @@ -259,6 +259,7 @@ module FloatFromDist = { }; }; +// TODO: This forces things to be floats let callableFunction = (evaluationParams, name, args) => { let b = args @@ -267,7 +268,7 @@ let callableFunction = (evaluationParams, name, args) => { |> E.R.bind(_, Render.toFloat) ) |> E.A.R.firstErrorOrOpen; - b |> E.R.bind(_, Functions.fnn("normal")); + b |> E.R.bind(_, Functions.fnn(name)); }; module Render = { diff --git a/src/distPlus/expressionTree/Functions.re b/src/distPlus/expressionTree/Functions.re index 5dea5db7..f23d0738 100644 --- a/src/distPlus/expressionTree/Functions.re +++ b/src/distPlus/expressionTree/Functions.re @@ -8,20 +8,75 @@ let twoFloats = (fn, n1: node, n2: node): result(node, string) => | _ => Error("Variables have wrong type") }; +let threeFloats = (fn, n1: node, n2: node, n3: node): result(node, string) => + switch (n1, n2, n3) { + | ( + `SymbolicDist(`Float(a)), + `SymbolicDist(`Float(b)), + `SymbolicDist(`Float(c)), + ) => + fn(a, b, c) + | _ => Error("Variables have wrong type") + }; + let twoFloatsToOkSym = fn => twoFloats((f1, f2) => fn(f1, f2) |> toOkSym); +let threeFloats = fn => threeFloats((f1, f2, f3) => fn(f1, f2, f3)); + let apply2 = (fn, args): result(node, string) => switch (args) { | [|a, b|] => fn(a, b) | _ => Error("Needs 2 args") }; +let apply3 = (fn, args: array(node)): result(node, string) => + switch (args) { + | [|a, b, c|] => fn(a, b, c) + | _ => Error("Needs 3 args") + }; + +let to_: array(node) => result(node, string) = + fun + | [|`SymbolicDist(`Float(low)), `SymbolicDist(`Float(high))|] + when low <= 0.0 && low < high => { + Ok(`SymbolicDist(SymbolicDist.Normal.from90PercentCI(low, high))); + } + | [|`SymbolicDist(`Float(low)), `SymbolicDist(`Float(high))|] + when low < high => { + Ok(`SymbolicDist(SymbolicDist.Lognormal.from90PercentCI(low, high))); + } + | [|`SymbolicDist(`Float(_)), `SymbolicDist(_)|] => + Error("Low value must be less than high value.") + | _ => Error("Requires 2 variables"); + let fnn = (name, args: array(node)) => { switch (name) { | "normal" => apply2(twoFloatsToOkSym(SymbolicDist.Normal.make), args) | "uniform" => apply2(twoFloatsToOkSym(SymbolicDist.Uniform.make), args) | "beta" => apply2(twoFloatsToOkSym(SymbolicDist.Beta.make), args) | "cauchy" => apply2(twoFloatsToOkSym(SymbolicDist.Cauchy.make), args) + | "lognormal" => apply2(twoFloatsToOkSym(SymbolicDist.Lognormal.make), args) + | "lognormalFromMeanAndStdDev" => apply2(twoFloatsToOkSym(SymbolicDist.Lognormal.fromMeanAndStdev), args) + | "exponential" => + switch (args) { + | [| + `SymbolicDist(`Float(a)), + |] => + Ok(`SymbolicDist(SymbolicDist.Exponential.make(a))); + | _ => Error("Needs 3 valid arguments") + } + | "triangular" => + switch (args) { + | [| + `SymbolicDist(`Float(a)), + `SymbolicDist(`Float(b)), + `SymbolicDist(`Float(c)), + |] => + SymbolicDist.Triangular.make(a, b, c) + |> E.R.fmap(r => `SymbolicDist(r)) + | _ => Error("Needs 3 valid arguments") + } + | "to" => to_(args) | _ => Error("Function not found") }; }; diff --git a/src/distPlus/expressionTree/MathJsParser.re b/src/distPlus/expressionTree/MathJsParser.re index 41263c31..e0aa100e 100644 --- a/src/distPlus/expressionTree/MathJsParser.re +++ b/src/distPlus/expressionTree/MathJsParser.re @@ -93,56 +93,26 @@ module MathAdtToDistDst = { ); }; - let lognormal: - array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = - fun - | [|Value(mu), Value(sigma)|] => - Ok(`SymbolicDist(`Lognormal({mu, sigma}))) - | [|Object(o)|] => { - let g = Js.Dict.get(o); - switch (g("mean"), g("stdev"), g("mu"), g("sigma")) { - | (Some(Value(mean)), Some(Value(stdev)), _, _) => - Ok( - `SymbolicDist( - SymbolicDist.Lognormal.fromMeanAndStdev(mean, stdev), - ), - ) - | (_, _, Some(Value(mu)), Some(Value(sigma))) => - Ok(`SymbolicDist(`Lognormal({mu, sigma}))) - | _ => Error("Lognormal distribution would need mean and stdev") - }; - } - | _ => Error("Wrong number of variables in lognormal distribution"); - - let to_: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = - fun - | [|Value(low), Value(high)|] when low <= 0.0 && low < high => { - Ok(`SymbolicDist(SymbolicDist.Normal.from90PercentCI(low, high))); - } - | [|Value(low), Value(high)|] when low < high => { + let lognormal = (args, parseArgs, nodeParser) => + switch (args) { + | [|Object(o)|] => + let g = s => + Js.Dict.get(o, s) |> E.O.toResult("") |> E.R.bind(_, nodeParser); + switch (g("mean"), g("stdev"), g("mu"), g("sigma")) { + | (Ok(mean), Ok(stdev), _, _) => Ok( - `SymbolicDist(SymbolicDist.Lognormal.from90PercentCI(low, high)), - ); - } - | [|Value(_), Value(_)|] => - Error("Low value must be less than high value.") - | _ => Error("Wrong number of variables in lognormal distribution"); - - let exponential: - array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = - fun - | [|Value(rate)|] => Ok(`SymbolicDist(`Exponential({rate: rate}))) - | _ => Error("Wrong number of variables in Exponential distribution"); - - let triangular: - array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = - fun - | [|Value(low), Value(medium), Value(high)|] - when low < medium && medium < high => - Ok(`SymbolicDist(`Triangular({low, medium, high}))) - | [|Value(_), Value(_), Value(_)|] => - Error("Triangular values must be increasing order") - | _ => Error("Wrong number of variables in triangle distribution"); + `CallableFunction(("lognormalFromMeanAndStdDev", [|mean, stdev|])), + ) + | (_, _, Ok(mu), Ok(sigma)) => + Ok(`CallableFunction(("lognormal", [|mu, sigma|]))) + | _ => Error("Lognormal distribution needs either mean and stdev or mu and sigma") + }; + | _ => + parseArgs() + |> E.R.fmap((args: array(ExpressionTypes.ExpressionTree.node)) => + `CallableFunction(("lognormal", args)) + ) + }; let multiModal = ( @@ -150,15 +120,6 @@ module MathAdtToDistDst = { weights: option(array(float)), ) => { let weights = weights |> E.O.default([||]); - - /*let dists: = - args - |> E.A.fmap( - fun - | Ok(a) => a - | Error(e) => Error(e) - );*/ - let firstWithError = args |> Belt.Array.getBy(_, Belt.Result.isError); let withoutErrors = args |> E.A.fmap(E.R.toOption) |> E.A.O.concatSomes; @@ -249,17 +210,22 @@ module MathAdtToDistDst = { }; let functionParser = (nodeParser, name, args) => { - let parseArgs = () => - args |> E.A.fmap(nodeParser) |> E.A.R.firstErrorOrOpen; + let parseArray = ags => + ags |> E.A.fmap(nodeParser) |> E.A.R.firstErrorOrOpen; + let parseArgs = () => parseArray(args); switch (name) { - | "normal" | "uniform" | "beta" | "caucy" => + | "normal" + | "uniform" + | "beta" + | "triangular" + | "to" + | "exponential" + | "cauchy" => parseArgs() - |> E.R.fmap( - ( - args: array(ExpressionTypes.ExpressionTree.node), - ) => + |> E.R.fmap((args: array(ExpressionTypes.ExpressionTree.node)) => `CallableFunction((name, args)) ) + | "lognormal" => lognormal(args, parseArgs, nodeParser) | "mm" => let weights = args diff --git a/src/distPlus/symbolic/SymbolicDist.re b/src/distPlus/symbolic/SymbolicDist.re index 6ad02704..33942187 100644 --- a/src/distPlus/symbolic/SymbolicDist.re +++ b/src/distPlus/symbolic/SymbolicDist.re @@ -2,8 +2,12 @@ open SymbolicTypes; module Exponential = { type t = exponential; - let make = (rate): symbolicDist => - `Exponential({rate}); + let make = (rate:float): symbolicDist => + `Exponential( + { + rate:rate + }, + ); let pdf = (x, t: t) => Jstat.exponential##pdf(x, t.rate); let cdf = (x, t: t) => Jstat.exponential##cdf(x, t.rate); let inv = (p, t: t) => Jstat.exponential##inv(p, t.rate); @@ -14,8 +18,7 @@ module Exponential = { module Cauchy = { type t = cauchy; - let make = (local, scale): symbolicDist => - `Cauchy({local,scale}); + let make = (local, scale): symbolicDist => `Cauchy({local, scale}); let pdf = (x, t: t) => Jstat.cauchy##pdf(x, t.local, t.scale); let cdf = (x, t: t) => Jstat.cauchy##cdf(x, t.local, t.scale); let inv = (p, t: t) => Jstat.cauchy##inv(p, t.local, t.scale); @@ -26,8 +29,10 @@ module Cauchy = { module Triangular = { type t = triangular; - let make = (low, medium, high): symbolicDist => - `Triangular({low, medium, high}); + let make = (low, medium, high): result(symbolicDist, string) => + low < medium && medium < high + ? Ok(`Triangular({low, medium, high})) + : Error("Triangular values must be increasing order"); let pdf = (x, t: t) => Jstat.triangular##pdf(x, t.low, t.high, t.medium); let cdf = (x, t: t) => Jstat.triangular##cdf(x, t.low, t.high, t.medium); let inv = (p, t: t) => Jstat.triangular##inv(p, t.low, t.high, t.medium); @@ -82,7 +87,7 @@ module Normal = { module Beta = { type t = beta; - let make = (alpha, beta) => `Beta({alpha, beta}) + let make = (alpha, beta) => `Beta({alpha, beta}); let pdf = (x, t: t) => Jstat.beta##pdf(x, t.alpha, t.beta); let cdf = (x, t: t) => Jstat.beta##cdf(x, t.alpha, t.beta); let inv = (p, t: t) => Jstat.beta##inv(p, t.alpha, t.beta); @@ -93,7 +98,7 @@ module Beta = { module Lognormal = { type t = lognormal; - let make = (mu, sigma) => `Lognormal({mu, sigma}) + let make = (mu, sigma) => `Lognormal({mu, sigma}); let pdf = (x, t: t) => Jstat.lognormal##pdf(x, t.mu, t.sigma); let cdf = (x, t: t) => Jstat.lognormal##cdf(x, t.mu, t.sigma); let inv = (p, t: t) => Jstat.lognormal##inv(p, t.mu, t.sigma); @@ -140,7 +145,7 @@ module Lognormal = { module Uniform = { type t = uniform; - let make = (low, high) => `Uniform({low, high}) + let make = (low, high) => `Uniform({low, high}); let pdf = (x, t: t) => Jstat.uniform##pdf(x, t.low, t.high); let cdf = (x, t: t) => Jstat.uniform##cdf(x, t.low, t.high); let inv = (p, t: t) => Jstat.uniform##inv(p, t.low, t.high); @@ -156,7 +161,7 @@ module Uniform = { module Float = { type t = float; - let make = t => `Float(t) + let make = t => `Float(t); let pdf = (x, t: t) => x == t ? 1.0 : 0.0; let cdf = (x, t: t) => x >= t ? 1.0 : 0.0; let inv = (p, t: t) => p < t ? 0.0 : 1.0;