Moving function creation into ExpressionTreeEvaluator as functions
This commit is contained in:
parent
102a147b97
commit
d3766f7a7f
|
@ -216,6 +216,11 @@ let sample = (t: t): float => {
|
|||
bar;
|
||||
};
|
||||
|
||||
let isFloat = (t:t) => switch(t){
|
||||
| Discrete({xyShape: {xs: [|_|], ys: [|1.0|]}}) => true
|
||||
| _ => false
|
||||
}
|
||||
|
||||
let sampleNRendered = (n, dist) => {
|
||||
let integralCache = T.Integral.get(dist);
|
||||
let distWithUpdatedIntegralCache = T.updateIntegralCache(Some(integralCache), dist);
|
||||
|
|
|
@ -259,10 +259,22 @@ module FloatFromDist = {
|
|||
};
|
||||
};
|
||||
|
||||
let callableFunction = (evaluationParams, name, args) => {
|
||||
let b =
|
||||
args
|
||||
|> E.A.fmap(a =>
|
||||
Render.render(evaluationParams, a)
|
||||
|> E.R.bind(_, Render.toFloat)
|
||||
)
|
||||
|> E.A.R.firstErrorOrOpen;
|
||||
b |> E.R.bind(_, Functions.fnn("normal"));
|
||||
};
|
||||
|
||||
module Render = {
|
||||
let rec operationToLeaf =
|
||||
(evaluationParams: evaluationParams, t: node): result(t, string) => {
|
||||
switch (t) {
|
||||
| `Function(_) => Error("Cannot render a function")
|
||||
| `SymbolicDist(d) =>
|
||||
Ok(
|
||||
`RenderedDist(
|
||||
|
@ -275,6 +287,13 @@ module Render = {
|
|||
};
|
||||
};
|
||||
|
||||
let run = (node, fnNode) => {
|
||||
switch (fnNode) {
|
||||
| `Function(r) => Ok(r(node))
|
||||
| _ => Error("Not a function")
|
||||
};
|
||||
};
|
||||
|
||||
/* This function recursively goes through the nodes of the parse tree,
|
||||
replacing each Operation node and its subtree with a Data node.
|
||||
Whenever possible, the replacement produces a new Symbolic Data node,
|
||||
|
@ -314,5 +333,8 @@ let toLeaf =
|
|||
FloatFromDist.operationToLeaf(evaluationParams, distToFloatOp, t)
|
||||
| `Normalize(t) => Normalize.operationToLeaf(evaluationParams, t)
|
||||
| `Render(t) => Render.operationToLeaf(evaluationParams, t)
|
||||
| `Function(t) => Ok(`Function(t))
|
||||
| `CallableFunction(name, args) =>
|
||||
callableFunction(evaluationParams, name, args)
|
||||
};
|
||||
};
|
||||
|
|
|
@ -20,6 +20,8 @@ module ExpressionTree = {
|
|||
| `Truncate(option(float), option(float), node)
|
||||
| `Normalize(node)
|
||||
| `FloatFromDist(distToFloatOperation, node)
|
||||
| `Function(node => result(node, string))
|
||||
| `CallableFunction(string, array(node))
|
||||
];
|
||||
|
||||
type samplingInputs = {
|
||||
|
@ -71,8 +73,14 @@ module ExpressionTree = {
|
|||
| `RenderedDist(r) => Some(r)
|
||||
| _ => None
|
||||
};
|
||||
};
|
||||
|
||||
let _toFloat = (t:DistTypes.shape) => switch(t){
|
||||
| Discrete({xyShape: {xs: [|x|], ys: [|1.0|]}}) => Some(`SymbolicDist(`Float(x)))
|
||||
| _ => None
|
||||
}
|
||||
|
||||
let toFloat = (item:node):result(node, string) => item |> getShape |> E.O.bind(_,_toFloat) |> E.O.toResult("Not valid shape")
|
||||
};
|
||||
};
|
||||
|
||||
type simplificationResult = [
|
||||
|
|
27
src/distPlus/expressionTree/Functions.re
Normal file
27
src/distPlus/expressionTree/Functions.re
Normal file
|
@ -0,0 +1,27 @@
|
|||
type node = ExpressionTypes.ExpressionTree.node;
|
||||
|
||||
let toOkSym = r => Ok(`SymbolicDist(r));
|
||||
|
||||
let twoFloats = (fn, n1: node, n2: node): result(node, string) =>
|
||||
switch (n1, n2) {
|
||||
| (`SymbolicDist(`Float(a)), `SymbolicDist(`Float(b))) => fn(a, b)
|
||||
| _ => Error("Variables have wrong type")
|
||||
};
|
||||
|
||||
let twoFloatsToOkSym = fn => twoFloats((f1, f2) => fn(f1, f2) |> toOkSym);
|
||||
|
||||
let apply2 = (fn, args): result(node, string) =>
|
||||
switch (args) {
|
||||
| [|a, b|] => fn(a, b)
|
||||
| _ => Error("Needs 2 args")
|
||||
};
|
||||
|
||||
let fnn = (name, args: array(node)) => {
|
||||
switch (name) {
|
||||
| "normal" => apply2(twoFloatsToOkSym(SymbolicDist.Normal.make), args)
|
||||
| "uniform" => apply2(twoFloatsToOkSym(SymbolicDist.Uniform.make), args)
|
||||
| "beta" => apply2(twoFloatsToOkSym(SymbolicDist.Beta.make), args)
|
||||
| "cauchy" => apply2(twoFloatsToOkSym(SymbolicDist.Cauchy.make), args)
|
||||
| _ => Error("Function not found")
|
||||
};
|
||||
};
|
|
@ -55,7 +55,7 @@ module MathAdtToDistDst = {
|
|||
let handleSymbol = (inputVars: inputVars, sym) => {
|
||||
switch (Belt.Map.String.get(inputVars, sym)) {
|
||||
| Some(s) => Ok(s)
|
||||
| None => Error("Couldn't find.")
|
||||
| None => Error("Couldn't find symbol " ++ sym)
|
||||
};
|
||||
};
|
||||
|
||||
|
@ -93,13 +93,6 @@ module MathAdtToDistDst = {
|
|||
);
|
||||
};
|
||||
|
||||
let normal:
|
||||
array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
|
||||
fun
|
||||
| [|Value(mean), Value(stdev)|] =>
|
||||
Ok(`SymbolicDist(`Normal({mean, stdev})))
|
||||
| _ => Error("Wrong number of variables in normal distribution");
|
||||
|
||||
let lognormal:
|
||||
array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
|
||||
fun
|
||||
|
@ -135,32 +128,12 @@ module MathAdtToDistDst = {
|
|||
Error("Low value must be less than high value.")
|
||||
| _ => Error("Wrong number of variables in lognormal distribution");
|
||||
|
||||
let uniform:
|
||||
array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
|
||||
fun
|
||||
| [|Value(low), Value(high)|] =>
|
||||
Ok(`SymbolicDist(`Uniform({low, high})))
|
||||
| _ => Error("Wrong number of variables in lognormal distribution");
|
||||
|
||||
let beta: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
|
||||
fun
|
||||
| [|Value(alpha), Value(beta)|] =>
|
||||
Ok(`SymbolicDist(`Beta({alpha, beta})))
|
||||
| _ => Error("Wrong number of variables in lognormal distribution");
|
||||
|
||||
let exponential:
|
||||
array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
|
||||
fun
|
||||
| [|Value(rate)|] => Ok(`SymbolicDist(`Exponential({rate: rate})))
|
||||
| _ => Error("Wrong number of variables in Exponential distribution");
|
||||
|
||||
let cauchy:
|
||||
array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
|
||||
fun
|
||||
| [|Value(local), Value(scale)|] =>
|
||||
Ok(`SymbolicDist(`Cauchy({local, scale})))
|
||||
| _ => Error("Wrong number of variables in cauchy distribution");
|
||||
|
||||
let triangular:
|
||||
array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
|
||||
fun
|
||||
|
@ -214,43 +187,16 @@ module MathAdtToDistDst = {
|
|||
};
|
||||
};
|
||||
|
||||
// let arrayParser =
|
||||
// (args: array(arg))
|
||||
// : result(ExpressionTypes.ExpressionTree.node, string) => {
|
||||
// let samples =
|
||||
// args
|
||||
// |> E.A.fmap(
|
||||
// fun
|
||||
// | Value(n) => Some(n)
|
||||
// | _ => None,
|
||||
// )
|
||||
// |> E.A.O.concatSomes;
|
||||
// let outputs = Samples.T.fromSamples(samples);
|
||||
// let pdf =
|
||||
// outputs.shape |> E.O.bind(_, Shape.T.toContinuous);
|
||||
// let shape =
|
||||
// pdf
|
||||
// |> E.O.fmap(pdf => {
|
||||
// let _pdf = Continuous.T.normalize(pdf);
|
||||
// let cdf = Continuous.T.integral(~cache=None, _pdf);
|
||||
// SymbolicDist.ContinuousShape.make(_pdf, cdf);
|
||||
// });
|
||||
// switch (shape) {
|
||||
// | Some(s) => Ok(`SymbolicDist(`ContinuousShape(s)))
|
||||
// | None => Error("Rendering did not work")
|
||||
// };
|
||||
// };
|
||||
|
||||
let operationParser =
|
||||
(
|
||||
name: string,
|
||||
args: array(result(ExpressionTypes.ExpressionTree.node, string)),
|
||||
args: result(array(ExpressionTypes.ExpressionTree.node), string),
|
||||
) => {
|
||||
let toOkAlgebraic = r => Ok(`AlgebraicCombination(r));
|
||||
let toOkPointwise = r => Ok(`PointwiseCombination(r));
|
||||
let toOkTruncate = r => Ok(`Truncate(r));
|
||||
let toOkFloatFromDist = r => Ok(`FloatFromDist(r));
|
||||
E.A.R.firstErrorOrOpen(args)
|
||||
args
|
||||
|> E.R.bind(_, args => {
|
||||
switch (name, args) {
|
||||
| ("add", [|l, r|]) => toOkAlgebraic((`Add, l, r))
|
||||
|
@ -303,17 +249,17 @@ module MathAdtToDistDst = {
|
|||
};
|
||||
|
||||
let functionParser = (nodeParser, name, args) => {
|
||||
let parseArgs = () => args |> E.A.fmap(nodeParser);
|
||||
Js.log2("Parseargs", parseArgs);
|
||||
let parseArgs = () =>
|
||||
args |> E.A.fmap(nodeParser) |> E.A.R.firstErrorOrOpen;
|
||||
switch (name) {
|
||||
| "normal" => normal(args)
|
||||
| "lognormal" => lognormal(args)
|
||||
| "uniform" => uniform(args)
|
||||
| "beta" => beta(args)
|
||||
| "to" => to_(args)
|
||||
| "exponential" => exponential(args)
|
||||
| "cauchy" => cauchy(args)
|
||||
| "triangular" => triangular(args)
|
||||
| "normal" | "uniform" | "beta" | "caucy" =>
|
||||
parseArgs()
|
||||
|> E.R.fmap(
|
||||
(
|
||||
args: array(ExpressionTypes.ExpressionTree.node),
|
||||
) =>
|
||||
`CallableFunction((name, args))
|
||||
)
|
||||
| "mm" =>
|
||||
let weights =
|
||||
args
|
||||
|
@ -358,14 +304,18 @@ module MathAdtToDistDst = {
|
|||
};
|
||||
};
|
||||
|
||||
let rec nodeParser = inputVars =>
|
||||
fun
|
||||
| Value(f) => Ok(`SymbolicDist(`Float(f)))
|
||||
| Symbol(s) => handleSymbol(inputVars, s)
|
||||
| Fn({name, args}) => functionParser(nodeParser(inputVars), name, args)
|
||||
| _ => {
|
||||
Error("This type not currently supported");
|
||||
};
|
||||
let rec nodeParser:
|
||||
(inputVars, MathJsonToMathJsAdt.arg) =>
|
||||
result(ExpressionTypes.ExpressionTree.node, string) =
|
||||
inputVars =>
|
||||
fun
|
||||
| Value(f) => Ok(`SymbolicDist(`Float(f)))
|
||||
| Symbol(s) => handleSymbol(inputVars, s)
|
||||
| Fn({name, args}) =>
|
||||
functionParser(nodeParser(inputVars), name, args)
|
||||
| _ => {
|
||||
Error("This type not currently supported");
|
||||
};
|
||||
|
||||
let topLevel = inputVars =>
|
||||
fun
|
||||
|
@ -405,7 +355,6 @@ let fromString2 = (inputVars: inputVars, str) => {
|
|||
}
|
||||
});
|
||||
|
||||
Js.log(mathJsParse);
|
||||
let value = E.R.bind(mathJsParse, MathAdtToDistDst.run(inputVars));
|
||||
value;
|
||||
};
|
||||
|
|
|
@ -2,6 +2,8 @@ open SymbolicTypes;
|
|||
|
||||
module Exponential = {
|
||||
type t = exponential;
|
||||
let make = (rate): symbolicDist =>
|
||||
`Exponential({rate});
|
||||
let pdf = (x, t: t) => Jstat.exponential##pdf(x, t.rate);
|
||||
let cdf = (x, t: t) => Jstat.exponential##cdf(x, t.rate);
|
||||
let inv = (p, t: t) => Jstat.exponential##inv(p, t.rate);
|
||||
|
@ -12,6 +14,8 @@ module Exponential = {
|
|||
|
||||
module Cauchy = {
|
||||
type t = cauchy;
|
||||
let make = (local, scale): symbolicDist =>
|
||||
`Cauchy({local,scale});
|
||||
let pdf = (x, t: t) => Jstat.cauchy##pdf(x, t.local, t.scale);
|
||||
let cdf = (x, t: t) => Jstat.cauchy##cdf(x, t.local, t.scale);
|
||||
let inv = (p, t: t) => Jstat.cauchy##inv(p, t.local, t.scale);
|
||||
|
@ -22,6 +26,8 @@ module Cauchy = {
|
|||
|
||||
module Triangular = {
|
||||
type t = triangular;
|
||||
let make = (low, medium, high): symbolicDist =>
|
||||
`Triangular({low, medium, high});
|
||||
let pdf = (x, t: t) => Jstat.triangular##pdf(x, t.low, t.high, t.medium);
|
||||
let cdf = (x, t: t) => Jstat.triangular##cdf(x, t.low, t.high, t.medium);
|
||||
let inv = (p, t: t) => Jstat.triangular##inv(p, t.low, t.high, t.medium);
|
||||
|
@ -32,7 +38,7 @@ module Triangular = {
|
|||
|
||||
module Normal = {
|
||||
type t = normal;
|
||||
let make = (mean, stdev):t => {mean, stdev};
|
||||
let make = (mean, stdev): symbolicDist => `Normal({mean, stdev});
|
||||
let pdf = (x, t: t) => Jstat.normal##pdf(x, t.mean, t.stdev);
|
||||
let cdf = (x, t: t) => Jstat.normal##cdf(x, t.mean, t.stdev);
|
||||
|
||||
|
@ -76,6 +82,7 @@ module Normal = {
|
|||
|
||||
module Beta = {
|
||||
type t = beta;
|
||||
let make = (alpha, beta) => `Beta({alpha, beta})
|
||||
let pdf = (x, t: t) => Jstat.beta##pdf(x, t.alpha, t.beta);
|
||||
let cdf = (x, t: t) => Jstat.beta##cdf(x, t.alpha, t.beta);
|
||||
let inv = (p, t: t) => Jstat.beta##inv(p, t.alpha, t.beta);
|
||||
|
@ -86,6 +93,7 @@ module Beta = {
|
|||
|
||||
module Lognormal = {
|
||||
type t = lognormal;
|
||||
let make = (mu, sigma) => `Lognormal({mu, sigma})
|
||||
let pdf = (x, t: t) => Jstat.lognormal##pdf(x, t.mu, t.sigma);
|
||||
let cdf = (x, t: t) => Jstat.lognormal##cdf(x, t.mu, t.sigma);
|
||||
let inv = (p, t: t) => Jstat.lognormal##inv(p, t.mu, t.sigma);
|
||||
|
@ -132,6 +140,7 @@ module Lognormal = {
|
|||
|
||||
module Uniform = {
|
||||
type t = uniform;
|
||||
let make = (low, high) => `Uniform({low, high})
|
||||
let pdf = (x, t: t) => Jstat.uniform##pdf(x, t.low, t.high);
|
||||
let cdf = (x, t: t) => Jstat.uniform##cdf(x, t.low, t.high);
|
||||
let inv = (p, t: t) => Jstat.uniform##inv(p, t.low, t.high);
|
||||
|
@ -147,6 +156,7 @@ module Uniform = {
|
|||
|
||||
module Float = {
|
||||
type t = float;
|
||||
let make = t => `Float(t)
|
||||
let pdf = (x, t: t) => x == t ? 1.0 : 0.0;
|
||||
let cdf = (x, t: t) => x >= t ? 1.0 : 0.0;
|
||||
let inv = (p, t: t) => p < t ? 0.0 : 1.0;
|
||||
|
@ -318,13 +328,14 @@ module T = {
|
|||
switch (d) {
|
||||
| `Float(v) =>
|
||||
Discrete(
|
||||
Discrete.make(~integralSumCache=Some(1.0), {xs: [|v|], ys: [|1.0|]}),
|
||||
Discrete.make(
|
||||
~integralSumCache=Some(1.0),
|
||||
{xs: [|v|], ys: [|1.0|]},
|
||||
),
|
||||
)
|
||||
| _ =>
|
||||
let xs = interpolateXs(~xSelection=`ByWeight, d, sampleCount);
|
||||
let ys = xs |> E.A.fmap(x => pdf(x, d));
|
||||
Continuous(
|
||||
Continuous.make(~integralSumCache=Some(1.0), {xs, ys}),
|
||||
);
|
||||
Continuous(Continuous.make(~integralSumCache=Some(1.0), {xs, ys}));
|
||||
};
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue
Block a user