From d3766f7a7f3b593dfbbfd7363b89412fe13a7900 Mon Sep 17 00:00:00 2001 From: Ozzie Gooen Date: Thu, 30 Jul 2020 11:44:01 +0100 Subject: [PATCH] Moving function creation into ExpressionTreeEvaluator as functions --- src/distPlus/distribution/Shape.re | 5 + .../expressionTree/ExpressionTreeEvaluator.re | 22 ++++ .../expressionTree/ExpressionTypes.re | 10 +- src/distPlus/expressionTree/Functions.re | 27 +++++ src/distPlus/expressionTree/MathJsParser.re | 101 +++++------------- src/distPlus/symbolic/SymbolicDist.re | 21 +++- 6 files changed, 104 insertions(+), 82 deletions(-) create mode 100644 src/distPlus/expressionTree/Functions.re diff --git a/src/distPlus/distribution/Shape.re b/src/distPlus/distribution/Shape.re index 01685c3c..fa0be2b8 100644 --- a/src/distPlus/distribution/Shape.re +++ b/src/distPlus/distribution/Shape.re @@ -216,6 +216,11 @@ let sample = (t: t): float => { bar; }; +let isFloat = (t:t) => switch(t){ +| Discrete({xyShape: {xs: [|_|], ys: [|1.0|]}}) => true +| _ => false +} + let sampleNRendered = (n, dist) => { let integralCache = T.Integral.get(dist); let distWithUpdatedIntegralCache = T.updateIntegralCache(Some(integralCache), dist); diff --git a/src/distPlus/expressionTree/ExpressionTreeEvaluator.re b/src/distPlus/expressionTree/ExpressionTreeEvaluator.re index 2ad84adc..7c19d6fc 100644 --- a/src/distPlus/expressionTree/ExpressionTreeEvaluator.re +++ b/src/distPlus/expressionTree/ExpressionTreeEvaluator.re @@ -259,10 +259,22 @@ module FloatFromDist = { }; }; +let callableFunction = (evaluationParams, name, args) => { + let b = + args + |> E.A.fmap(a => + Render.render(evaluationParams, a) + |> E.R.bind(_, Render.toFloat) + ) + |> E.A.R.firstErrorOrOpen; + b |> E.R.bind(_, Functions.fnn("normal")); +}; + module Render = { let rec operationToLeaf = (evaluationParams: evaluationParams, t: node): result(t, string) => { switch (t) { + | `Function(_) => Error("Cannot render a function") | `SymbolicDist(d) => Ok( `RenderedDist( @@ -275,6 +287,13 @@ module Render = { }; }; +let run = (node, fnNode) => { + switch (fnNode) { + | `Function(r) => Ok(r(node)) + | _ => Error("Not a function") + }; +}; + /* This function recursively goes through the nodes of the parse tree, replacing each Operation node and its subtree with a Data node. Whenever possible, the replacement produces a new Symbolic Data node, @@ -314,5 +333,8 @@ let toLeaf = FloatFromDist.operationToLeaf(evaluationParams, distToFloatOp, t) | `Normalize(t) => Normalize.operationToLeaf(evaluationParams, t) | `Render(t) => Render.operationToLeaf(evaluationParams, t) + | `Function(t) => Ok(`Function(t)) + | `CallableFunction(name, args) => + callableFunction(evaluationParams, name, args) }; }; diff --git a/src/distPlus/expressionTree/ExpressionTypes.re b/src/distPlus/expressionTree/ExpressionTypes.re index a76aee74..e7dbabcf 100644 --- a/src/distPlus/expressionTree/ExpressionTypes.re +++ b/src/distPlus/expressionTree/ExpressionTypes.re @@ -20,6 +20,8 @@ module ExpressionTree = { | `Truncate(option(float), option(float), node) | `Normalize(node) | `FloatFromDist(distToFloatOperation, node) + | `Function(node => result(node, string)) + | `CallableFunction(string, array(node)) ]; type samplingInputs = { @@ -71,8 +73,14 @@ module ExpressionTree = { | `RenderedDist(r) => Some(r) | _ => None }; - }; + let _toFloat = (t:DistTypes.shape) => switch(t){ + | Discrete({xyShape: {xs: [|x|], ys: [|1.0|]}}) => Some(`SymbolicDist(`Float(x))) + | _ => None + } + + let toFloat = (item:node):result(node, string) => item |> getShape |> E.O.bind(_,_toFloat) |> E.O.toResult("Not valid shape") + }; }; type simplificationResult = [ diff --git a/src/distPlus/expressionTree/Functions.re b/src/distPlus/expressionTree/Functions.re new file mode 100644 index 00000000..5dea5db7 --- /dev/null +++ b/src/distPlus/expressionTree/Functions.re @@ -0,0 +1,27 @@ +type node = ExpressionTypes.ExpressionTree.node; + +let toOkSym = r => Ok(`SymbolicDist(r)); + +let twoFloats = (fn, n1: node, n2: node): result(node, string) => + switch (n1, n2) { + | (`SymbolicDist(`Float(a)), `SymbolicDist(`Float(b))) => fn(a, b) + | _ => Error("Variables have wrong type") + }; + +let twoFloatsToOkSym = fn => twoFloats((f1, f2) => fn(f1, f2) |> toOkSym); + +let apply2 = (fn, args): result(node, string) => + switch (args) { + | [|a, b|] => fn(a, b) + | _ => Error("Needs 2 args") + }; + +let fnn = (name, args: array(node)) => { + switch (name) { + | "normal" => apply2(twoFloatsToOkSym(SymbolicDist.Normal.make), args) + | "uniform" => apply2(twoFloatsToOkSym(SymbolicDist.Uniform.make), args) + | "beta" => apply2(twoFloatsToOkSym(SymbolicDist.Beta.make), args) + | "cauchy" => apply2(twoFloatsToOkSym(SymbolicDist.Cauchy.make), args) + | _ => Error("Function not found") + }; +}; diff --git a/src/distPlus/expressionTree/MathJsParser.re b/src/distPlus/expressionTree/MathJsParser.re index f168479d..41263c31 100644 --- a/src/distPlus/expressionTree/MathJsParser.re +++ b/src/distPlus/expressionTree/MathJsParser.re @@ -55,7 +55,7 @@ module MathAdtToDistDst = { let handleSymbol = (inputVars: inputVars, sym) => { switch (Belt.Map.String.get(inputVars, sym)) { | Some(s) => Ok(s) - | None => Error("Couldn't find.") + | None => Error("Couldn't find symbol " ++ sym) }; }; @@ -93,13 +93,6 @@ module MathAdtToDistDst = { ); }; - let normal: - array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = - fun - | [|Value(mean), Value(stdev)|] => - Ok(`SymbolicDist(`Normal({mean, stdev}))) - | _ => Error("Wrong number of variables in normal distribution"); - let lognormal: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = fun @@ -135,32 +128,12 @@ module MathAdtToDistDst = { Error("Low value must be less than high value.") | _ => Error("Wrong number of variables in lognormal distribution"); - let uniform: - array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = - fun - | [|Value(low), Value(high)|] => - Ok(`SymbolicDist(`Uniform({low, high}))) - | _ => Error("Wrong number of variables in lognormal distribution"); - - let beta: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = - fun - | [|Value(alpha), Value(beta)|] => - Ok(`SymbolicDist(`Beta({alpha, beta}))) - | _ => Error("Wrong number of variables in lognormal distribution"); - let exponential: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = fun | [|Value(rate)|] => Ok(`SymbolicDist(`Exponential({rate: rate}))) | _ => Error("Wrong number of variables in Exponential distribution"); - let cauchy: - array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = - fun - | [|Value(local), Value(scale)|] => - Ok(`SymbolicDist(`Cauchy({local, scale}))) - | _ => Error("Wrong number of variables in cauchy distribution"); - let triangular: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = fun @@ -214,43 +187,16 @@ module MathAdtToDistDst = { }; }; - // let arrayParser = - // (args: array(arg)) - // : result(ExpressionTypes.ExpressionTree.node, string) => { - // let samples = - // args - // |> E.A.fmap( - // fun - // | Value(n) => Some(n) - // | _ => None, - // ) - // |> E.A.O.concatSomes; - // let outputs = Samples.T.fromSamples(samples); - // let pdf = - // outputs.shape |> E.O.bind(_, Shape.T.toContinuous); - // let shape = - // pdf - // |> E.O.fmap(pdf => { - // let _pdf = Continuous.T.normalize(pdf); - // let cdf = Continuous.T.integral(~cache=None, _pdf); - // SymbolicDist.ContinuousShape.make(_pdf, cdf); - // }); - // switch (shape) { - // | Some(s) => Ok(`SymbolicDist(`ContinuousShape(s))) - // | None => Error("Rendering did not work") - // }; - // }; - let operationParser = ( name: string, - args: array(result(ExpressionTypes.ExpressionTree.node, string)), + args: result(array(ExpressionTypes.ExpressionTree.node), string), ) => { let toOkAlgebraic = r => Ok(`AlgebraicCombination(r)); let toOkPointwise = r => Ok(`PointwiseCombination(r)); let toOkTruncate = r => Ok(`Truncate(r)); let toOkFloatFromDist = r => Ok(`FloatFromDist(r)); - E.A.R.firstErrorOrOpen(args) + args |> E.R.bind(_, args => { switch (name, args) { | ("add", [|l, r|]) => toOkAlgebraic((`Add, l, r)) @@ -303,17 +249,17 @@ module MathAdtToDistDst = { }; let functionParser = (nodeParser, name, args) => { - let parseArgs = () => args |> E.A.fmap(nodeParser); - Js.log2("Parseargs", parseArgs); + let parseArgs = () => + args |> E.A.fmap(nodeParser) |> E.A.R.firstErrorOrOpen; switch (name) { - | "normal" => normal(args) - | "lognormal" => lognormal(args) - | "uniform" => uniform(args) - | "beta" => beta(args) - | "to" => to_(args) - | "exponential" => exponential(args) - | "cauchy" => cauchy(args) - | "triangular" => triangular(args) + | "normal" | "uniform" | "beta" | "caucy" => + parseArgs() + |> E.R.fmap( + ( + args: array(ExpressionTypes.ExpressionTree.node), + ) => + `CallableFunction((name, args)) + ) | "mm" => let weights = args @@ -358,14 +304,18 @@ module MathAdtToDistDst = { }; }; - let rec nodeParser = inputVars => - fun - | Value(f) => Ok(`SymbolicDist(`Float(f))) - | Symbol(s) => handleSymbol(inputVars, s) - | Fn({name, args}) => functionParser(nodeParser(inputVars), name, args) - | _ => { - Error("This type not currently supported"); - }; + let rec nodeParser: + (inputVars, MathJsonToMathJsAdt.arg) => + result(ExpressionTypes.ExpressionTree.node, string) = + inputVars => + fun + | Value(f) => Ok(`SymbolicDist(`Float(f))) + | Symbol(s) => handleSymbol(inputVars, s) + | Fn({name, args}) => + functionParser(nodeParser(inputVars), name, args) + | _ => { + Error("This type not currently supported"); + }; let topLevel = inputVars => fun @@ -405,7 +355,6 @@ let fromString2 = (inputVars: inputVars, str) => { } }); - Js.log(mathJsParse); let value = E.R.bind(mathJsParse, MathAdtToDistDst.run(inputVars)); value; }; diff --git a/src/distPlus/symbolic/SymbolicDist.re b/src/distPlus/symbolic/SymbolicDist.re index 43c5377b..6ad02704 100644 --- a/src/distPlus/symbolic/SymbolicDist.re +++ b/src/distPlus/symbolic/SymbolicDist.re @@ -2,6 +2,8 @@ open SymbolicTypes; module Exponential = { type t = exponential; + let make = (rate): symbolicDist => + `Exponential({rate}); let pdf = (x, t: t) => Jstat.exponential##pdf(x, t.rate); let cdf = (x, t: t) => Jstat.exponential##cdf(x, t.rate); let inv = (p, t: t) => Jstat.exponential##inv(p, t.rate); @@ -12,6 +14,8 @@ module Exponential = { module Cauchy = { type t = cauchy; + let make = (local, scale): symbolicDist => + `Cauchy({local,scale}); let pdf = (x, t: t) => Jstat.cauchy##pdf(x, t.local, t.scale); let cdf = (x, t: t) => Jstat.cauchy##cdf(x, t.local, t.scale); let inv = (p, t: t) => Jstat.cauchy##inv(p, t.local, t.scale); @@ -22,6 +26,8 @@ module Cauchy = { module Triangular = { type t = triangular; + let make = (low, medium, high): symbolicDist => + `Triangular({low, medium, high}); let pdf = (x, t: t) => Jstat.triangular##pdf(x, t.low, t.high, t.medium); let cdf = (x, t: t) => Jstat.triangular##cdf(x, t.low, t.high, t.medium); let inv = (p, t: t) => Jstat.triangular##inv(p, t.low, t.high, t.medium); @@ -32,7 +38,7 @@ module Triangular = { module Normal = { type t = normal; - let make = (mean, stdev):t => {mean, stdev}; + let make = (mean, stdev): symbolicDist => `Normal({mean, stdev}); let pdf = (x, t: t) => Jstat.normal##pdf(x, t.mean, t.stdev); let cdf = (x, t: t) => Jstat.normal##cdf(x, t.mean, t.stdev); @@ -76,6 +82,7 @@ module Normal = { module Beta = { type t = beta; + let make = (alpha, beta) => `Beta({alpha, beta}) let pdf = (x, t: t) => Jstat.beta##pdf(x, t.alpha, t.beta); let cdf = (x, t: t) => Jstat.beta##cdf(x, t.alpha, t.beta); let inv = (p, t: t) => Jstat.beta##inv(p, t.alpha, t.beta); @@ -86,6 +93,7 @@ module Beta = { module Lognormal = { type t = lognormal; + let make = (mu, sigma) => `Lognormal({mu, sigma}) let pdf = (x, t: t) => Jstat.lognormal##pdf(x, t.mu, t.sigma); let cdf = (x, t: t) => Jstat.lognormal##cdf(x, t.mu, t.sigma); let inv = (p, t: t) => Jstat.lognormal##inv(p, t.mu, t.sigma); @@ -132,6 +140,7 @@ module Lognormal = { module Uniform = { type t = uniform; + let make = (low, high) => `Uniform({low, high}) let pdf = (x, t: t) => Jstat.uniform##pdf(x, t.low, t.high); let cdf = (x, t: t) => Jstat.uniform##cdf(x, t.low, t.high); let inv = (p, t: t) => Jstat.uniform##inv(p, t.low, t.high); @@ -147,6 +156,7 @@ module Uniform = { module Float = { type t = float; + let make = t => `Float(t) let pdf = (x, t: t) => x == t ? 1.0 : 0.0; let cdf = (x, t: t) => x >= t ? 1.0 : 0.0; let inv = (p, t: t) => p < t ? 0.0 : 1.0; @@ -318,13 +328,14 @@ module T = { switch (d) { | `Float(v) => Discrete( - Discrete.make(~integralSumCache=Some(1.0), {xs: [|v|], ys: [|1.0|]}), + Discrete.make( + ~integralSumCache=Some(1.0), + {xs: [|v|], ys: [|1.0|]}, + ), ) | _ => let xs = interpolateXs(~xSelection=`ByWeight, d, sampleCount); let ys = xs |> E.A.fmap(x => pdf(x, d)); - Continuous( - Continuous.make(~integralSumCache=Some(1.0), {xs, ys}), - ); + Continuous(Continuous.make(~integralSumCache=Some(1.0), {xs, ys})); }; };