Moving function creation into ExpressionTreeEvaluator as functions

This commit is contained in:
Ozzie Gooen 2020-07-30 11:44:01 +01:00
parent 102a147b97
commit d3766f7a7f
6 changed files with 104 additions and 82 deletions

View File

@ -216,6 +216,11 @@ let sample = (t: t): float => {
bar; bar;
}; };
let isFloat = (t:t) => switch(t){
| Discrete({xyShape: {xs: [|_|], ys: [|1.0|]}}) => true
| _ => false
}
let sampleNRendered = (n, dist) => { let sampleNRendered = (n, dist) => {
let integralCache = T.Integral.get(dist); let integralCache = T.Integral.get(dist);
let distWithUpdatedIntegralCache = T.updateIntegralCache(Some(integralCache), dist); let distWithUpdatedIntegralCache = T.updateIntegralCache(Some(integralCache), dist);

View File

@ -259,10 +259,22 @@ module FloatFromDist = {
}; };
}; };
let callableFunction = (evaluationParams, name, args) => {
let b =
args
|> E.A.fmap(a =>
Render.render(evaluationParams, a)
|> E.R.bind(_, Render.toFloat)
)
|> E.A.R.firstErrorOrOpen;
b |> E.R.bind(_, Functions.fnn("normal"));
};
module Render = { module Render = {
let rec operationToLeaf = let rec operationToLeaf =
(evaluationParams: evaluationParams, t: node): result(t, string) => { (evaluationParams: evaluationParams, t: node): result(t, string) => {
switch (t) { switch (t) {
| `Function(_) => Error("Cannot render a function")
| `SymbolicDist(d) => | `SymbolicDist(d) =>
Ok( Ok(
`RenderedDist( `RenderedDist(
@ -275,6 +287,13 @@ module Render = {
}; };
}; };
let run = (node, fnNode) => {
switch (fnNode) {
| `Function(r) => Ok(r(node))
| _ => Error("Not a function")
};
};
/* This function recursively goes through the nodes of the parse tree, /* This function recursively goes through the nodes of the parse tree,
replacing each Operation node and its subtree with a Data node. replacing each Operation node and its subtree with a Data node.
Whenever possible, the replacement produces a new Symbolic Data node, Whenever possible, the replacement produces a new Symbolic Data node,
@ -314,5 +333,8 @@ let toLeaf =
FloatFromDist.operationToLeaf(evaluationParams, distToFloatOp, t) FloatFromDist.operationToLeaf(evaluationParams, distToFloatOp, t)
| `Normalize(t) => Normalize.operationToLeaf(evaluationParams, t) | `Normalize(t) => Normalize.operationToLeaf(evaluationParams, t)
| `Render(t) => Render.operationToLeaf(evaluationParams, t) | `Render(t) => Render.operationToLeaf(evaluationParams, t)
| `Function(t) => Ok(`Function(t))
| `CallableFunction(name, args) =>
callableFunction(evaluationParams, name, args)
}; };
}; };

View File

@ -20,6 +20,8 @@ module ExpressionTree = {
| `Truncate(option(float), option(float), node) | `Truncate(option(float), option(float), node)
| `Normalize(node) | `Normalize(node)
| `FloatFromDist(distToFloatOperation, node) | `FloatFromDist(distToFloatOperation, node)
| `Function(node => result(node, string))
| `CallableFunction(string, array(node))
]; ];
type samplingInputs = { type samplingInputs = {
@ -71,8 +73,14 @@ module ExpressionTree = {
| `RenderedDist(r) => Some(r) | `RenderedDist(r) => Some(r)
| _ => None | _ => None
}; };
};
let _toFloat = (t:DistTypes.shape) => switch(t){
| Discrete({xyShape: {xs: [|x|], ys: [|1.0|]}}) => Some(`SymbolicDist(`Float(x)))
| _ => None
}
let toFloat = (item:node):result(node, string) => item |> getShape |> E.O.bind(_,_toFloat) |> E.O.toResult("Not valid shape")
};
}; };
type simplificationResult = [ type simplificationResult = [

View File

@ -0,0 +1,27 @@
type node = ExpressionTypes.ExpressionTree.node;
let toOkSym = r => Ok(`SymbolicDist(r));
let twoFloats = (fn, n1: node, n2: node): result(node, string) =>
switch (n1, n2) {
| (`SymbolicDist(`Float(a)), `SymbolicDist(`Float(b))) => fn(a, b)
| _ => Error("Variables have wrong type")
};
let twoFloatsToOkSym = fn => twoFloats((f1, f2) => fn(f1, f2) |> toOkSym);
let apply2 = (fn, args): result(node, string) =>
switch (args) {
| [|a, b|] => fn(a, b)
| _ => Error("Needs 2 args")
};
let fnn = (name, args: array(node)) => {
switch (name) {
| "normal" => apply2(twoFloatsToOkSym(SymbolicDist.Normal.make), args)
| "uniform" => apply2(twoFloatsToOkSym(SymbolicDist.Uniform.make), args)
| "beta" => apply2(twoFloatsToOkSym(SymbolicDist.Beta.make), args)
| "cauchy" => apply2(twoFloatsToOkSym(SymbolicDist.Cauchy.make), args)
| _ => Error("Function not found")
};
};

View File

@ -55,7 +55,7 @@ module MathAdtToDistDst = {
let handleSymbol = (inputVars: inputVars, sym) => { let handleSymbol = (inputVars: inputVars, sym) => {
switch (Belt.Map.String.get(inputVars, sym)) { switch (Belt.Map.String.get(inputVars, sym)) {
| Some(s) => Ok(s) | Some(s) => Ok(s)
| None => Error("Couldn't find.") | None => Error("Couldn't find symbol " ++ sym)
}; };
}; };
@ -93,13 +93,6 @@ module MathAdtToDistDst = {
); );
}; };
let normal:
array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
fun
| [|Value(mean), Value(stdev)|] =>
Ok(`SymbolicDist(`Normal({mean, stdev})))
| _ => Error("Wrong number of variables in normal distribution");
let lognormal: let lognormal:
array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
fun fun
@ -135,32 +128,12 @@ module MathAdtToDistDst = {
Error("Low value must be less than high value.") Error("Low value must be less than high value.")
| _ => Error("Wrong number of variables in lognormal distribution"); | _ => Error("Wrong number of variables in lognormal distribution");
let uniform:
array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
fun
| [|Value(low), Value(high)|] =>
Ok(`SymbolicDist(`Uniform({low, high})))
| _ => Error("Wrong number of variables in lognormal distribution");
let beta: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
fun
| [|Value(alpha), Value(beta)|] =>
Ok(`SymbolicDist(`Beta({alpha, beta})))
| _ => Error("Wrong number of variables in lognormal distribution");
let exponential: let exponential:
array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
fun fun
| [|Value(rate)|] => Ok(`SymbolicDist(`Exponential({rate: rate}))) | [|Value(rate)|] => Ok(`SymbolicDist(`Exponential({rate: rate})))
| _ => Error("Wrong number of variables in Exponential distribution"); | _ => Error("Wrong number of variables in Exponential distribution");
let cauchy:
array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
fun
| [|Value(local), Value(scale)|] =>
Ok(`SymbolicDist(`Cauchy({local, scale})))
| _ => Error("Wrong number of variables in cauchy distribution");
let triangular: let triangular:
array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
fun fun
@ -214,43 +187,16 @@ module MathAdtToDistDst = {
}; };
}; };
// let arrayParser =
// (args: array(arg))
// : result(ExpressionTypes.ExpressionTree.node, string) => {
// let samples =
// args
// |> E.A.fmap(
// fun
// | Value(n) => Some(n)
// | _ => None,
// )
// |> E.A.O.concatSomes;
// let outputs = Samples.T.fromSamples(samples);
// let pdf =
// outputs.shape |> E.O.bind(_, Shape.T.toContinuous);
// let shape =
// pdf
// |> E.O.fmap(pdf => {
// let _pdf = Continuous.T.normalize(pdf);
// let cdf = Continuous.T.integral(~cache=None, _pdf);
// SymbolicDist.ContinuousShape.make(_pdf, cdf);
// });
// switch (shape) {
// | Some(s) => Ok(`SymbolicDist(`ContinuousShape(s)))
// | None => Error("Rendering did not work")
// };
// };
let operationParser = let operationParser =
( (
name: string, name: string,
args: array(result(ExpressionTypes.ExpressionTree.node, string)), args: result(array(ExpressionTypes.ExpressionTree.node), string),
) => { ) => {
let toOkAlgebraic = r => Ok(`AlgebraicCombination(r)); let toOkAlgebraic = r => Ok(`AlgebraicCombination(r));
let toOkPointwise = r => Ok(`PointwiseCombination(r)); let toOkPointwise = r => Ok(`PointwiseCombination(r));
let toOkTruncate = r => Ok(`Truncate(r)); let toOkTruncate = r => Ok(`Truncate(r));
let toOkFloatFromDist = r => Ok(`FloatFromDist(r)); let toOkFloatFromDist = r => Ok(`FloatFromDist(r));
E.A.R.firstErrorOrOpen(args) args
|> E.R.bind(_, args => { |> E.R.bind(_, args => {
switch (name, args) { switch (name, args) {
| ("add", [|l, r|]) => toOkAlgebraic((`Add, l, r)) | ("add", [|l, r|]) => toOkAlgebraic((`Add, l, r))
@ -303,17 +249,17 @@ module MathAdtToDistDst = {
}; };
let functionParser = (nodeParser, name, args) => { let functionParser = (nodeParser, name, args) => {
let parseArgs = () => args |> E.A.fmap(nodeParser); let parseArgs = () =>
Js.log2("Parseargs", parseArgs); args |> E.A.fmap(nodeParser) |> E.A.R.firstErrorOrOpen;
switch (name) { switch (name) {
| "normal" => normal(args) | "normal" | "uniform" | "beta" | "caucy" =>
| "lognormal" => lognormal(args) parseArgs()
| "uniform" => uniform(args) |> E.R.fmap(
| "beta" => beta(args) (
| "to" => to_(args) args: array(ExpressionTypes.ExpressionTree.node),
| "exponential" => exponential(args) ) =>
| "cauchy" => cauchy(args) `CallableFunction((name, args))
| "triangular" => triangular(args) )
| "mm" => | "mm" =>
let weights = let weights =
args args
@ -358,14 +304,18 @@ module MathAdtToDistDst = {
}; };
}; };
let rec nodeParser = inputVars => let rec nodeParser:
fun (inputVars, MathJsonToMathJsAdt.arg) =>
| Value(f) => Ok(`SymbolicDist(`Float(f))) result(ExpressionTypes.ExpressionTree.node, string) =
| Symbol(s) => handleSymbol(inputVars, s) inputVars =>
| Fn({name, args}) => functionParser(nodeParser(inputVars), name, args) fun
| _ => { | Value(f) => Ok(`SymbolicDist(`Float(f)))
Error("This type not currently supported"); | Symbol(s) => handleSymbol(inputVars, s)
}; | Fn({name, args}) =>
functionParser(nodeParser(inputVars), name, args)
| _ => {
Error("This type not currently supported");
};
let topLevel = inputVars => let topLevel = inputVars =>
fun fun
@ -405,7 +355,6 @@ let fromString2 = (inputVars: inputVars, str) => {
} }
}); });
Js.log(mathJsParse);
let value = E.R.bind(mathJsParse, MathAdtToDistDst.run(inputVars)); let value = E.R.bind(mathJsParse, MathAdtToDistDst.run(inputVars));
value; value;
}; };

View File

@ -2,6 +2,8 @@ open SymbolicTypes;
module Exponential = { module Exponential = {
type t = exponential; type t = exponential;
let make = (rate): symbolicDist =>
`Exponential({rate});
let pdf = (x, t: t) => Jstat.exponential##pdf(x, t.rate); let pdf = (x, t: t) => Jstat.exponential##pdf(x, t.rate);
let cdf = (x, t: t) => Jstat.exponential##cdf(x, t.rate); let cdf = (x, t: t) => Jstat.exponential##cdf(x, t.rate);
let inv = (p, t: t) => Jstat.exponential##inv(p, t.rate); let inv = (p, t: t) => Jstat.exponential##inv(p, t.rate);
@ -12,6 +14,8 @@ module Exponential = {
module Cauchy = { module Cauchy = {
type t = cauchy; type t = cauchy;
let make = (local, scale): symbolicDist =>
`Cauchy({local,scale});
let pdf = (x, t: t) => Jstat.cauchy##pdf(x, t.local, t.scale); let pdf = (x, t: t) => Jstat.cauchy##pdf(x, t.local, t.scale);
let cdf = (x, t: t) => Jstat.cauchy##cdf(x, t.local, t.scale); let cdf = (x, t: t) => Jstat.cauchy##cdf(x, t.local, t.scale);
let inv = (p, t: t) => Jstat.cauchy##inv(p, t.local, t.scale); let inv = (p, t: t) => Jstat.cauchy##inv(p, t.local, t.scale);
@ -22,6 +26,8 @@ module Cauchy = {
module Triangular = { module Triangular = {
type t = triangular; type t = triangular;
let make = (low, medium, high): symbolicDist =>
`Triangular({low, medium, high});
let pdf = (x, t: t) => Jstat.triangular##pdf(x, t.low, t.high, t.medium); let pdf = (x, t: t) => Jstat.triangular##pdf(x, t.low, t.high, t.medium);
let cdf = (x, t: t) => Jstat.triangular##cdf(x, t.low, t.high, t.medium); let cdf = (x, t: t) => Jstat.triangular##cdf(x, t.low, t.high, t.medium);
let inv = (p, t: t) => Jstat.triangular##inv(p, t.low, t.high, t.medium); let inv = (p, t: t) => Jstat.triangular##inv(p, t.low, t.high, t.medium);
@ -32,7 +38,7 @@ module Triangular = {
module Normal = { module Normal = {
type t = normal; type t = normal;
let make = (mean, stdev):t => {mean, stdev}; let make = (mean, stdev): symbolicDist => `Normal({mean, stdev});
let pdf = (x, t: t) => Jstat.normal##pdf(x, t.mean, t.stdev); let pdf = (x, t: t) => Jstat.normal##pdf(x, t.mean, t.stdev);
let cdf = (x, t: t) => Jstat.normal##cdf(x, t.mean, t.stdev); let cdf = (x, t: t) => Jstat.normal##cdf(x, t.mean, t.stdev);
@ -76,6 +82,7 @@ module Normal = {
module Beta = { module Beta = {
type t = beta; type t = beta;
let make = (alpha, beta) => `Beta({alpha, beta})
let pdf = (x, t: t) => Jstat.beta##pdf(x, t.alpha, t.beta); let pdf = (x, t: t) => Jstat.beta##pdf(x, t.alpha, t.beta);
let cdf = (x, t: t) => Jstat.beta##cdf(x, t.alpha, t.beta); let cdf = (x, t: t) => Jstat.beta##cdf(x, t.alpha, t.beta);
let inv = (p, t: t) => Jstat.beta##inv(p, t.alpha, t.beta); let inv = (p, t: t) => Jstat.beta##inv(p, t.alpha, t.beta);
@ -86,6 +93,7 @@ module Beta = {
module Lognormal = { module Lognormal = {
type t = lognormal; type t = lognormal;
let make = (mu, sigma) => `Lognormal({mu, sigma})
let pdf = (x, t: t) => Jstat.lognormal##pdf(x, t.mu, t.sigma); let pdf = (x, t: t) => Jstat.lognormal##pdf(x, t.mu, t.sigma);
let cdf = (x, t: t) => Jstat.lognormal##cdf(x, t.mu, t.sigma); let cdf = (x, t: t) => Jstat.lognormal##cdf(x, t.mu, t.sigma);
let inv = (p, t: t) => Jstat.lognormal##inv(p, t.mu, t.sigma); let inv = (p, t: t) => Jstat.lognormal##inv(p, t.mu, t.sigma);
@ -132,6 +140,7 @@ module Lognormal = {
module Uniform = { module Uniform = {
type t = uniform; type t = uniform;
let make = (low, high) => `Uniform({low, high})
let pdf = (x, t: t) => Jstat.uniform##pdf(x, t.low, t.high); let pdf = (x, t: t) => Jstat.uniform##pdf(x, t.low, t.high);
let cdf = (x, t: t) => Jstat.uniform##cdf(x, t.low, t.high); let cdf = (x, t: t) => Jstat.uniform##cdf(x, t.low, t.high);
let inv = (p, t: t) => Jstat.uniform##inv(p, t.low, t.high); let inv = (p, t: t) => Jstat.uniform##inv(p, t.low, t.high);
@ -147,6 +156,7 @@ module Uniform = {
module Float = { module Float = {
type t = float; type t = float;
let make = t => `Float(t)
let pdf = (x, t: t) => x == t ? 1.0 : 0.0; let pdf = (x, t: t) => x == t ? 1.0 : 0.0;
let cdf = (x, t: t) => x >= t ? 1.0 : 0.0; let cdf = (x, t: t) => x >= t ? 1.0 : 0.0;
let inv = (p, t: t) => p < t ? 0.0 : 1.0; let inv = (p, t: t) => p < t ? 0.0 : 1.0;
@ -318,13 +328,14 @@ module T = {
switch (d) { switch (d) {
| `Float(v) => | `Float(v) =>
Discrete( Discrete(
Discrete.make(~integralSumCache=Some(1.0), {xs: [|v|], ys: [|1.0|]}), Discrete.make(
~integralSumCache=Some(1.0),
{xs: [|v|], ys: [|1.0|]},
),
) )
| _ => | _ =>
let xs = interpolateXs(~xSelection=`ByWeight, d, sampleCount); let xs = interpolateXs(~xSelection=`ByWeight, d, sampleCount);
let ys = xs |> E.A.fmap(x => pdf(x, d)); let ys = xs |> E.A.fmap(x => pdf(x, d));
Continuous( Continuous(Continuous.make(~integralSumCache=Some(1.0), {xs, ys}));
Continuous.make(~integralSumCache=Some(1.0), {xs, ys}),
);
}; };
}; };