Simple functionality without weights
This commit is contained in:
parent
cf36594b4a
commit
8d55bba2ca
|
@ -233,7 +233,7 @@ let make = () => {
|
||||||
~onSubmit=({state}) => {None},
|
~onSubmit=({state}) => {None},
|
||||||
~initialState={
|
~initialState={
|
||||||
//guesstimatorString: "mm(normal(-10, 2), uniform(18, 25), lognormal({mean: 10, stdev: 8}), triangular(31,40,50))",
|
//guesstimatorString: "mm(normal(-10, 2), uniform(18, 25), lognormal({mean: 10, stdev: 8}), triangular(31,40,50))",
|
||||||
guesstimatorString: "mm(3,4)",
|
guesstimatorString: "mm(normal(5,2))",
|
||||||
domainType: "Complete",
|
domainType: "Complete",
|
||||||
xPoint: "50.0",
|
xPoint: "50.0",
|
||||||
xPoint2: "60.0",
|
xPoint2: "60.0",
|
||||||
|
|
|
@ -206,6 +206,7 @@ module Truncate = {
|
||||||
|
|
||||||
module Normalize = {
|
module Normalize = {
|
||||||
let rec operationToLeaf = (evaluationParams, t: node): result(node, string) => {
|
let rec operationToLeaf = (evaluationParams, t: node): result(node, string) => {
|
||||||
|
Js.log2("normalize", t);
|
||||||
switch (t) {
|
switch (t) {
|
||||||
| `RenderedDist(s) => Ok(`RenderedDist(Shape.T.normalize(s)))
|
| `RenderedDist(s) => Ok(`RenderedDist(Shape.T.normalize(s)))
|
||||||
| `SymbolicDist(_) => Ok(t)
|
| `SymbolicDist(_) => Ok(t)
|
||||||
|
@ -225,6 +226,7 @@ let callableFunction = (evaluationParams, name, args) => {
|
||||||
module Render = {
|
module Render = {
|
||||||
let rec operationToLeaf =
|
let rec operationToLeaf =
|
||||||
(evaluationParams: evaluationParams, t: node): result(t, string) => {
|
(evaluationParams: evaluationParams, t: node): result(t, string) => {
|
||||||
|
Js.log2("rendering", t);
|
||||||
switch (t) {
|
switch (t) {
|
||||||
| `Function(_) => Error("Cannot render a function")
|
| `Function(_) => Error("Cannot render a function")
|
||||||
| `SymbolicDist(d) =>
|
| `SymbolicDist(d) =>
|
||||||
|
@ -254,17 +256,18 @@ let rec toLeaf =
|
||||||
node: t,
|
node: t,
|
||||||
)
|
)
|
||||||
: result(t, string) => {
|
: result(t, string) => {
|
||||||
Js.log2("node",node);
|
Js.log2("node", node);
|
||||||
switch (node) {
|
switch (node) {
|
||||||
// Leaf nodes just stay leaf nodes
|
// Leaf nodes just stay leaf nodes
|
||||||
| `SymbolicDist(_)
|
| `SymbolicDist(_)
|
||||||
| `Function(_)
|
| `Function(_)
|
||||||
| `RenderedDist(_) => Ok(node)
|
| `RenderedDist(_) => Ok(node)
|
||||||
| `Array(args) =>
|
| `Array(args) =>
|
||||||
|
Js.log2("Array!", args);
|
||||||
args
|
args
|
||||||
|> E.A.fmap(toLeaf(evaluationParams))
|
|> E.A.fmap(toLeaf(evaluationParams))
|
||||||
|> E.A.R.firstErrorOrOpen
|
|> E.A.R.firstErrorOrOpen
|
||||||
|> E.R.fmap(r => `Array(r))
|
|> E.R.fmap(r => `Array(r));
|
||||||
// Operations nevaluationParamsd to be turned into leaves
|
// Operations nevaluationParamsd to be turned into leaves
|
||||||
| `AlgebraicCombination(algebraicOp, t1, t2) =>
|
| `AlgebraicCombination(algebraicOp, t1, t2) =>
|
||||||
AlgebraicCombination.operationToLeaf(
|
AlgebraicCombination.operationToLeaf(
|
||||||
|
@ -285,34 +288,25 @@ let rec toLeaf =
|
||||||
| `Normalize(t) => Normalize.operationToLeaf(evaluationParams, t)
|
| `Normalize(t) => Normalize.operationToLeaf(evaluationParams, t)
|
||||||
| `Render(t) => Render.operationToLeaf(evaluationParams, t)
|
| `Render(t) => Render.operationToLeaf(evaluationParams, t)
|
||||||
| `Hash(t) =>
|
| `Hash(t) =>
|
||||||
|
Js.log("In hash");
|
||||||
t
|
t
|
||||||
|> E.A.fmap(((name: string, node: node)) =>
|
|> E.A.fmap(((name: string, node: node)) =>
|
||||||
toLeaf(evaluationParams, node) |> E.R.fmap(r => (name, r))
|
toLeaf(evaluationParams, node) |> E.R.fmap(r => (name, r))
|
||||||
)
|
)
|
||||||
|> E.A.R.firstErrorOrOpen
|
|> E.A.R.firstErrorOrOpen
|
||||||
|> E.R.fmap(r => `Hash(r))
|
|> E.R.fmap(r => `Hash(r));
|
||||||
| `Symbol(r) =>
|
| `Symbol(r) =>
|
||||||
|
Js.log("Symbol");
|
||||||
ExpressionTypes.ExpressionTree.Environment.get(
|
ExpressionTypes.ExpressionTree.Environment.get(
|
||||||
evaluationParams.environment,
|
evaluationParams.environment,
|
||||||
r,
|
r,
|
||||||
)
|
)
|
||||||
|> E.O.toResult("Undeclared variable " ++ r)
|
|> E.O.toResult("Undeclared variable " ++ r)
|
||||||
|> E.R.bind(_, toLeaf(evaluationParams))
|
|> E.R.bind(_, toLeaf(evaluationParams));
|
||||||
| `FunctionCall(name, args) =>
|
| `FunctionCall(name, args) =>
|
||||||
|
Js.log3("In function call", name, args);
|
||||||
callableFunction(evaluationParams, name, args)
|
callableFunction(evaluationParams, name, args)
|
||||||
|> E.R.bind(_, toLeaf(evaluationParams))
|
|> E.R.bind(_, toLeaf(evaluationParams));
|
||||||
| `MultiModal(r) =>
|
| `MultiModal(r) => Error("Multimodal?")
|
||||||
let components =
|
|
||||||
r
|
|
||||||
|> E.A.fmap(((dist, weight)) =>
|
|
||||||
`FunctionCall("scaleMultiply", [|dist, `SymbolicDist(`Float(weight))|]));
|
|
||||||
let pointwiseSum =
|
|
||||||
components
|
|
||||||
|> Js.Array.sliceFrom(1)
|
|
||||||
|> E.A.fold_left(
|
|
||||||
(acc, x) => {`PointwiseCombination((`Add, acc, x))},
|
|
||||||
E.A.unsafe_get(components, 0),
|
|
||||||
);
|
|
||||||
Ok(`Normalize(pointwiseSum)) |> E.R.bind(_, toLeaf(evaluationParams));
|
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
|
@ -16,11 +16,12 @@ type distToFloatOperation = [
|
||||||
];
|
];
|
||||||
|
|
||||||
module ExpressionTree = {
|
module ExpressionTree = {
|
||||||
type node = [
|
type hash = array((string, node))
|
||||||
|
and node = [
|
||||||
| `SymbolicDist(SymbolicTypes.symbolicDist)
|
| `SymbolicDist(SymbolicTypes.symbolicDist)
|
||||||
| `RenderedDist(DistTypes.shape)
|
| `RenderedDist(DistTypes.shape)
|
||||||
| `Symbol(string)
|
| `Symbol(string)
|
||||||
| `Hash(array((string, node)))
|
| `Hash(hash)
|
||||||
| `Array(array(node))
|
| `Array(array(node))
|
||||||
| `Function(array(string), node)
|
| `Function(array(string), node)
|
||||||
| `AlgebraicCombination(algebraicOperation, node, node)
|
| `AlgebraicCombination(algebraicOperation, node, node)
|
||||||
|
@ -31,16 +32,31 @@ module ExpressionTree = {
|
||||||
| `FunctionCall(string, array(node))
|
| `FunctionCall(string, array(node))
|
||||||
| `MultiModal(array((node, float)))
|
| `MultiModal(array((node, float)))
|
||||||
];
|
];
|
||||||
|
|
||||||
|
module Hash = {
|
||||||
|
type t('a) = array((string, 'a));
|
||||||
|
let getByName = (t:t('a), name) =>
|
||||||
|
E.A.getBy(t, ((n, _)) => n == name) |> E.O.fmap(((_, r)) => r);
|
||||||
|
|
||||||
|
let getByNames = (hash: t('a), names:array(string)) =>
|
||||||
|
names |> E.A.fmap(name => (name, getByName(hash, name)))
|
||||||
|
};
|
||||||
// Have nil as option
|
// Have nil as option
|
||||||
let getFloat = (node:node) => node |> fun
|
let getFloat = (node: node) =>
|
||||||
| `RenderedDist(Discrete({xyShape: {xs: [|x|], ys: [|1.0|]}})) => Some(x)
|
node
|
||||||
|
|> (
|
||||||
|
fun
|
||||||
|
| `RenderedDist(Discrete({xyShape: {xs: [|x|], ys: [|1.0|]}})) =>
|
||||||
|
Some(x)
|
||||||
| `SymbolicDist(`Float(x)) => Some(x)
|
| `SymbolicDist(`Float(x)) => Some(x)
|
||||||
| _ => None
|
| _ => None
|
||||||
|
);
|
||||||
|
|
||||||
let toFloatIfNeeded = (node:node) => switch(node |> getFloat){
|
let toFloatIfNeeded = (node: node) =>
|
||||||
|
switch (node |> getFloat) {
|
||||||
| Some(float) => `SymbolicDist(`Float(float))
|
| Some(float) => `SymbolicDist(`Float(float))
|
||||||
| None => node
|
| None => node
|
||||||
}
|
};
|
||||||
|
|
||||||
type samplingInputs = {
|
type samplingInputs = {
|
||||||
sampleCount: int,
|
sampleCount: int,
|
||||||
|
@ -97,7 +113,6 @@ module ExpressionTree = {
|
||||||
|> evaluationParams.evaluateNode(evaluationParams)
|
|> evaluationParams.evaluateNode(evaluationParams)
|
||||||
|> E.R.bind(_, fn(evaluationParams));
|
|> E.R.bind(_, fn(evaluationParams));
|
||||||
|
|
||||||
|
|
||||||
module Render = {
|
module Render = {
|
||||||
type t = node;
|
type t = node;
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,7 @@ let to_: (float, float) => result(node, string) =
|
||||||
};
|
};
|
||||||
|
|
||||||
let makeSymbolicFromTwoFloats = (name, fn) =>
|
let makeSymbolicFromTwoFloats = (name, fn) =>
|
||||||
Function.make(
|
Function.T.make(
|
||||||
~name,
|
~name,
|
||||||
~outputType=`SamplingDistribution,
|
~outputType=`SamplingDistribution,
|
||||||
~inputTypes=[|`Float, `Float|],
|
~inputTypes=[|`Float, `Float|],
|
||||||
|
@ -23,10 +23,11 @@ let makeSymbolicFromTwoFloats = (name, fn) =>
|
||||||
fun
|
fun
|
||||||
| [|`Float(a), `Float(b)|] => Ok(`SymbolicDist(fn(a, b)))
|
| [|`Float(a), `Float(b)|] => Ok(`SymbolicDist(fn(a, b)))
|
||||||
| e => wrongInputsError(e),
|
| e => wrongInputsError(e),
|
||||||
|
(),
|
||||||
);
|
);
|
||||||
|
|
||||||
let makeSymbolicFromOneFloat = (name, fn) =>
|
let makeSymbolicFromOneFloat = (name, fn) =>
|
||||||
Function.make(
|
Function.T.make(
|
||||||
~name,
|
~name,
|
||||||
~outputType=`SamplingDistribution,
|
~outputType=`SamplingDistribution,
|
||||||
~inputTypes=[|`Float|],
|
~inputTypes=[|`Float|],
|
||||||
|
@ -34,10 +35,11 @@ let makeSymbolicFromOneFloat = (name, fn) =>
|
||||||
fun
|
fun
|
||||||
| [|`Float(a)|] => Ok(`SymbolicDist(fn(a)))
|
| [|`Float(a)|] => Ok(`SymbolicDist(fn(a)))
|
||||||
| e => wrongInputsError(e),
|
| e => wrongInputsError(e),
|
||||||
|
(),
|
||||||
);
|
);
|
||||||
|
|
||||||
let makeDistFloat = (name, fn) =>
|
let makeDistFloat = (name, fn) =>
|
||||||
Function.make(
|
Function.T.make(
|
||||||
~name,
|
~name,
|
||||||
~outputType=`SamplingDistribution,
|
~outputType=`SamplingDistribution,
|
||||||
~inputTypes=[|`SamplingDistribution, `Float|],
|
~inputTypes=[|`SamplingDistribution, `Float|],
|
||||||
|
@ -45,21 +47,23 @@ let makeDistFloat = (name, fn) =>
|
||||||
fun
|
fun
|
||||||
| [|`SamplingDist(a), `Float(b)|] => fn(a, b)
|
| [|`SamplingDist(a), `Float(b)|] => fn(a, b)
|
||||||
| e => wrongInputsError(e),
|
| e => wrongInputsError(e),
|
||||||
|
(),
|
||||||
);
|
);
|
||||||
|
|
||||||
let makeRenderedDistFloat = (name, fn) =>
|
let makeRenderedDistFloat = (name, fn) =>
|
||||||
Function.make(
|
Function.T.make(
|
||||||
~name,
|
~name,
|
||||||
~outputType=`RenderedDistribution,
|
~outputType=`RenderedDistribution,
|
||||||
~inputTypes=[|`RenderedDistribution, `Float|],
|
~inputTypes=[|`RenderedDistribution, `Float|],
|
||||||
~run=
|
~run=
|
||||||
fun
|
fun
|
||||||
| [|`RenderedDist(a), `Float(b)|] => fn(a, b)
|
| [|`RenderedDist(a), `Float(b)|] => fn(a, b)
|
||||||
| e => wrongInputsError(e)
|
| e => wrongInputsError(e),
|
||||||
|
(),
|
||||||
);
|
);
|
||||||
|
|
||||||
let makeDist = (name, fn) =>
|
let makeDist = (name, fn) =>
|
||||||
Function.make(
|
Function.T.make(
|
||||||
~name,
|
~name,
|
||||||
~outputType=`SamplingDistribution,
|
~outputType=`SamplingDistribution,
|
||||||
~inputTypes=[|`SamplingDistribution|],
|
~inputTypes=[|`SamplingDistribution|],
|
||||||
|
@ -67,6 +71,7 @@ let makeDist = (name, fn) =>
|
||||||
fun
|
fun
|
||||||
| [|`SamplingDist(a)|] => fn(a)
|
| [|`SamplingDist(a)|] => fn(a)
|
||||||
| e => wrongInputsError(e),
|
| e => wrongInputsError(e),
|
||||||
|
(),
|
||||||
);
|
);
|
||||||
|
|
||||||
let floatFromDist =
|
let floatFromDist =
|
||||||
|
@ -112,7 +117,7 @@ let functions = [|
|
||||||
SymbolicDist.Lognormal.fromMeanAndStdev,
|
SymbolicDist.Lognormal.fromMeanAndStdev,
|
||||||
),
|
),
|
||||||
makeSymbolicFromOneFloat("exponential", SymbolicDist.Exponential.make),
|
makeSymbolicFromOneFloat("exponential", SymbolicDist.Exponential.make),
|
||||||
Function.make(
|
Function.T.make(
|
||||||
~name="to",
|
~name="to",
|
||||||
~outputType=`SamplingDistribution,
|
~outputType=`SamplingDistribution,
|
||||||
~inputTypes=[|`Float, `Float|],
|
~inputTypes=[|`Float, `Float|],
|
||||||
|
@ -120,8 +125,9 @@ let functions = [|
|
||||||
fun
|
fun
|
||||||
| [|`Float(a), `Float(b)|] => to_(a, b)
|
| [|`Float(a), `Float(b)|] => to_(a, b)
|
||||||
| e => wrongInputsError(e),
|
| e => wrongInputsError(e),
|
||||||
|
(),
|
||||||
),
|
),
|
||||||
Function.make(
|
Function.T.make(
|
||||||
~name="triangular",
|
~name="triangular",
|
||||||
~outputType=`SamplingDistribution,
|
~outputType=`SamplingDistribution,
|
||||||
~inputTypes=[|`Float, `Float, `Float|],
|
~inputTypes=[|`Float, `Float, `Float|],
|
||||||
|
@ -131,13 +137,14 @@ let functions = [|
|
||||||
SymbolicDist.Triangular.make(a, b, c)
|
SymbolicDist.Triangular.make(a, b, c)
|
||||||
|> E.R.fmap(r => `SymbolicDist(r))
|
|> E.R.fmap(r => `SymbolicDist(r))
|
||||||
| e => wrongInputsError(e),
|
| e => wrongInputsError(e),
|
||||||
|
(),
|
||||||
),
|
),
|
||||||
makeDistFloat("pdf", (dist, float) => floatFromDist(`Pdf(float), dist)),
|
makeDistFloat("pdf", (dist, float) => floatFromDist(`Pdf(float), dist)),
|
||||||
makeDistFloat("inv", (dist, float) => floatFromDist(`Inv(float), dist)),
|
makeDistFloat("inv", (dist, float) => floatFromDist(`Inv(float), dist)),
|
||||||
makeDistFloat("cdf", (dist, float) => floatFromDist(`Cdf(float), dist)),
|
makeDistFloat("cdf", (dist, float) => floatFromDist(`Cdf(float), dist)),
|
||||||
makeDist("mean", dist => floatFromDist(`Mean, dist)),
|
makeDist("mean", dist => floatFromDist(`Mean, dist)),
|
||||||
makeDist("sample", dist => floatFromDist(`Sample, dist)),
|
makeDist("sample", dist => floatFromDist(`Sample, dist)),
|
||||||
Function.make(
|
Function.T.make(
|
||||||
~name="render",
|
~name="render",
|
||||||
~outputType=`RenderedDistribution,
|
~outputType=`RenderedDistribution,
|
||||||
~inputTypes=[|`RenderedDistribution|],
|
~inputTypes=[|`RenderedDistribution|],
|
||||||
|
@ -145,8 +152,9 @@ let functions = [|
|
||||||
fun
|
fun
|
||||||
| [|`RenderedDist(c)|] => Ok(`RenderedDist(c))
|
| [|`RenderedDist(c)|] => Ok(`RenderedDist(c))
|
||||||
| e => wrongInputsError(e),
|
| e => wrongInputsError(e),
|
||||||
|
(),
|
||||||
),
|
),
|
||||||
Function.make(
|
Function.T.make(
|
||||||
~name="normalize",
|
~name="normalize",
|
||||||
~outputType=`SamplingDistribution,
|
~outputType=`SamplingDistribution,
|
||||||
~inputTypes=[|`SamplingDistribution|],
|
~inputTypes=[|`SamplingDistribution|],
|
||||||
|
@ -156,6 +164,7 @@ let functions = [|
|
||||||
| [|`SamplingDist(`RenderedDist(c))|] =>
|
| [|`SamplingDist(`RenderedDist(c))|] =>
|
||||||
Ok(`RenderedDist(Shape.T.normalize(c)))
|
Ok(`RenderedDist(Shape.T.normalize(c)))
|
||||||
| e => wrongInputsError(e),
|
| e => wrongInputsError(e),
|
||||||
|
(),
|
||||||
),
|
),
|
||||||
makeRenderedDistFloat("scaleExp", (dist, float) =>
|
makeRenderedDistFloat("scaleExp", (dist, float) =>
|
||||||
verticalScaling(`Exponentiate, dist, float)
|
verticalScaling(`Exponentiate, dist, float)
|
||||||
|
@ -166,4 +175,42 @@ let functions = [|
|
||||||
makeRenderedDistFloat("scaleLog", (dist, float) =>
|
makeRenderedDistFloat("scaleLog", (dist, float) =>
|
||||||
verticalScaling(`Log, dist, float)
|
verticalScaling(`Log, dist, float)
|
||||||
),
|
),
|
||||||
|
Function.T.make(
|
||||||
|
~name="multimodal",
|
||||||
|
~outputType=`SamplingDistribution,
|
||||||
|
~inputTypes=[|
|
||||||
|
`Named([|
|
||||||
|
("dists", `Array(`SamplingDistribution)),
|
||||||
|
("weights", `Array(`Float)),
|
||||||
|
|]),
|
||||||
|
|],
|
||||||
|
~run=
|
||||||
|
fun
|
||||||
|
| [|`Named(r)|] => {
|
||||||
|
let foo =
|
||||||
|
(r: TypeSystem.typedValue)
|
||||||
|
: result(ExpressionTypes.ExpressionTree.node, string) =>
|
||||||
|
switch (r) {
|
||||||
|
| `SamplingDist(`SymbolicDist(c)) => Ok(`SymbolicDist(c))
|
||||||
|
| `SamplingDist(`RenderedDist(c)) => Ok(`RenderedDist(c))
|
||||||
|
| _ => Error("")
|
||||||
|
};
|
||||||
|
switch (ExpressionTypes.ExpressionTree.Hash.getByName(r, "dists")) {
|
||||||
|
| Some(`Array(r)) =>
|
||||||
|
r
|
||||||
|
|> E.A.fmap(foo)
|
||||||
|
|> E.A.R.firstErrorOrOpen
|
||||||
|
|> E.R.fmap(distributions => {
|
||||||
|
distributions
|
||||||
|
|> E.A.fold_left(
|
||||||
|
(acc, x) => {`PointwiseCombination((`Add, acc, x))},
|
||||||
|
E.A.unsafe_get(distributions, 0),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
| _ => Error("")
|
||||||
|
};
|
||||||
|
}
|
||||||
|
| _ => Error(""),
|
||||||
|
(),
|
||||||
|
),
|
||||||
|];
|
|];
|
||||||
|
|
|
@ -8,9 +8,13 @@ let fnn =
|
||||||
name,
|
name,
|
||||||
args: array(node),
|
args: array(node),
|
||||||
) => {
|
) => {
|
||||||
let trySomeFns =
|
// let foundFn =
|
||||||
TypeSystem.getAndRun(Fns.functions, name, evaluationParams, args);
|
// TypeSystem.Function.Ts.findByName(Fns.functions, name) |> E.O.toResult("Function " ++ name ++ " not found");
|
||||||
switch (trySomeFns) {
|
// let ran = foundFn |> E.R.bind(_,TypeSystem.Function.T.run(evaluationParams,args))
|
||||||
|
|
||||||
|
let foundFn =
|
||||||
|
TypeSystem.Function.Ts.findByNameAndRun(Fns.functions, name, evaluationParams, args);
|
||||||
|
switch (foundFn) {
|
||||||
| Some(r) => r
|
| Some(r) => r
|
||||||
| None =>
|
| None =>
|
||||||
switch (
|
switch (
|
||||||
|
@ -21,32 +25,8 @@ let fnn =
|
||||||
),
|
),
|
||||||
) {
|
) {
|
||||||
| (_, Some(`Function(argNames, tt))) =>
|
| (_, Some(`Function(argNames, tt))) =>
|
||||||
|
Js.log("Fundction found: " ++ name);
|
||||||
PTypes.Function.run(evaluationParams, args, (argNames, tt))
|
PTypes.Function.run(evaluationParams, args, (argNames, tt))
|
||||||
| ("multimodal", _) =>
|
|
||||||
switch (args |> E.A.to_list) {
|
|
||||||
| [`Array(weights), ...dists] =>
|
|
||||||
let withWeights =
|
|
||||||
dists
|
|
||||||
|> E.L.toArray
|
|
||||||
|> E.A.fmapi((index, t) => {
|
|
||||||
let w =
|
|
||||||
weights
|
|
||||||
|> E.A.get(_, index)
|
|
||||||
|> E.O.bind(_, getFloat)
|
|
||||||
|> E.O.default(1.0);
|
|
||||||
(t, w);
|
|
||||||
});
|
|
||||||
Ok(`MultiModal(withWeights));
|
|
||||||
| dists when E.L.length(dists) > 0 =>
|
|
||||||
Ok(
|
|
||||||
`MultiModal(
|
|
||||||
dists
|
|
||||||
|> E.L.toArray
|
|
||||||
|> E.A.fmap(r => (r, 1.0)),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
| _ => Error("Needs at least one distribution")
|
|
||||||
}
|
|
||||||
| _ => Error("Function " ++ name ++ " not found")
|
| _ => Error("Function " ++ name ++ " not found")
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -198,13 +198,36 @@ module MathAdtToDistDst = {
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
let functionParser = (nodeParser, name, args) => {
|
let functionParser =
|
||||||
|
(
|
||||||
|
nodeParser:
|
||||||
|
MathJsonToMathJsAdt.arg =>
|
||||||
|
Belt.Result.t(
|
||||||
|
ProbExample.ExpressionTypes.ExpressionTree.node,
|
||||||
|
string,
|
||||||
|
),
|
||||||
|
name: string,
|
||||||
|
args: array(MathJsonToMathJsAdt.arg),
|
||||||
|
)
|
||||||
|
: result(ExpressionTypes.ExpressionTree.node, string) => {
|
||||||
let parseArray = ags =>
|
let parseArray = ags =>
|
||||||
ags |> E.A.fmap(nodeParser) |> E.A.R.firstErrorOrOpen;
|
ags |> E.A.fmap(nodeParser) |> E.A.R.firstErrorOrOpen;
|
||||||
let parseArgs = () => parseArray(args);
|
let parseArgs = () => parseArray(args);
|
||||||
switch (name) {
|
switch (name) {
|
||||||
| "lognormal" => lognormal(args, parseArgs, nodeParser)
|
| "lognormal" => lognormal(args, parseArgs, nodeParser)
|
||||||
| "multimodal"
|
| "multimodal"
|
||||||
|
| "add"
|
||||||
|
| "subtract"
|
||||||
|
| "multiply"
|
||||||
|
| "unaryMinus"
|
||||||
|
| "dotMultiply"
|
||||||
|
| "dotPow"
|
||||||
|
| "rightLogShift"
|
||||||
|
| "divide"
|
||||||
|
| "pow"
|
||||||
|
| "leftTruncate"
|
||||||
|
| "rightTruncate"
|
||||||
|
| "truncate" => operationParser(name, parseArgs())
|
||||||
| "mm" =>
|
| "mm" =>
|
||||||
let weights =
|
let weights =
|
||||||
args
|
args
|
||||||
|
@ -223,24 +246,25 @@ module MathAdtToDistDst = {
|
||||||
switch (weights, dists) {
|
switch (weights, dists) {
|
||||||
| (Some(Error(r)), _) => Error(r)
|
| (Some(Error(r)), _) => Error(r)
|
||||||
| (_, Error(r)) => Error(r)
|
| (_, Error(r)) => Error(r)
|
||||||
| (None, Ok(dists)) => Ok(`FunctionCall(("multimodal", dists)))
|
| (None, Ok(dists)) =>
|
||||||
| (Some(Ok(r)), Ok(dists)) =>
|
let hash: ExpressionTypes.ExpressionTree.node =
|
||||||
Ok(
|
`FunctionCall(("multimodal", [|`Hash(
|
||||||
`FunctionCall(("multimodal", E.A.append([|`Array(r)|], dists))),
|
[|
|
||||||
)
|
("dists", `Array(dists)),
|
||||||
|
("weights", `Array([||]))
|
||||||
|
|]
|
||||||
|
)|]));
|
||||||
|
Ok(hash);
|
||||||
|
| (Some(Ok(weights)), Ok(dists)) =>
|
||||||
|
let hash: ExpressionTypes.ExpressionTree.node =
|
||||||
|
`FunctionCall(("multimodal", [|`Hash(
|
||||||
|
[|
|
||||||
|
("dists", `Array(dists)),
|
||||||
|
("weights", `Array(weights))
|
||||||
|
|]
|
||||||
|
)|]));
|
||||||
|
Ok(hash);
|
||||||
};
|
};
|
||||||
| "add"
|
|
||||||
| "subtract"
|
|
||||||
| "multiply"
|
|
||||||
| "unaryMinus"
|
|
||||||
| "dotMultiply"
|
|
||||||
| "dotPow"
|
|
||||||
| "rightLogShift"
|
|
||||||
| "divide"
|
|
||||||
| "pow"
|
|
||||||
| "leftTruncate"
|
|
||||||
| "rightTruncate"
|
|
||||||
| "truncate" => operationParser(name, parseArgs())
|
|
||||||
| name =>
|
| name =>
|
||||||
parseArgs()
|
parseArgs()
|
||||||
|> E.R.fmap((args: array(ExpressionTypes.ExpressionTree.node)) =>
|
|> E.R.fmap((args: array(ExpressionTypes.ExpressionTree.node)) =>
|
||||||
|
|
|
@ -6,127 +6,186 @@ type samplingDist = [
|
||||||
| `RenderedDist(DistTypes.shape)
|
| `RenderedDist(DistTypes.shape)
|
||||||
];
|
];
|
||||||
|
|
||||||
type t = [
|
type _type = [
|
||||||
| `Float
|
| `Float
|
||||||
| `SamplingDistribution
|
| `SamplingDistribution
|
||||||
| `RenderedDistribution
|
| `RenderedDistribution
|
||||||
| `Array(t)
|
| `Array(_type)
|
||||||
| `Named(array((string, t)))
|
| `Named(array((string, _type)))
|
||||||
];
|
];
|
||||||
type tx = [
|
|
||||||
|
type typedValue = [
|
||||||
| `Float(float)
|
| `Float(float)
|
||||||
| `RenderedDist(DistTypes.shape)
|
| `RenderedDist(DistTypes.shape)
|
||||||
| `SamplingDist(samplingDist)
|
| `SamplingDist(samplingDist)
|
||||||
| `Array(array(tx))
|
| `Array(array(typedValue))
|
||||||
| `Named(array((string, tx)))
|
| `Named(array((string, typedValue)))
|
||||||
];
|
];
|
||||||
|
|
||||||
type fn = {
|
type _function = {
|
||||||
name: string,
|
name: string,
|
||||||
inputTypes: array(t),
|
inputTypes: array(_type),
|
||||||
outputType: t,
|
outputType: _type,
|
||||||
run: array(tx) => result(node, string),
|
run: array(typedValue) => result(node, string),
|
||||||
|
shouldCoerceTypes: bool,
|
||||||
};
|
};
|
||||||
|
|
||||||
module Function = {
|
type functions = array(_function);
|
||||||
let make = (~name, ~inputTypes, ~outputType, ~run): fn => {
|
type inputNodes = array(node);
|
||||||
name,
|
|
||||||
inputTypes,
|
|
||||||
outputType,
|
|
||||||
run,
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
type fns = array(fn);
|
module TypedValue = {
|
||||||
type inputTypes = array(node);
|
let rec fromNode = (node: node): result(typedValue, string) =>
|
||||||
|
|
||||||
let rec fromNodeDirect = (node: node): result(tx, string) =>
|
|
||||||
switch (ExpressionTypes.ExpressionTree.toFloatIfNeeded(node)) {
|
switch (ExpressionTypes.ExpressionTree.toFloatIfNeeded(node)) {
|
||||||
| `SymbolicDist(`Float(r)) => Ok(`Float(r))
|
| `SymbolicDist(`Float(r)) => Ok(`Float(r))
|
||||||
| `SymbolicDist(s) => Ok(`SamplingDist(`SymbolicDist(s)))
|
| `SymbolicDist(s) => Ok(`SamplingDist(`SymbolicDist(s)))
|
||||||
| `RenderedDist(s) => Ok(`RenderedDist(s))
|
| `RenderedDist(s) => Ok(`RenderedDist(s))
|
||||||
| `Array(r) =>
|
| `Array(r) =>
|
||||||
r
|
r
|
||||||
|> E.A.fmap(fromNodeDirect)
|
|> E.A.fmap(fromNode)
|
||||||
|> E.A.R.firstErrorOrOpen
|
|> E.A.R.firstErrorOrOpen
|
||||||
|> E.R.fmap(r => `Array(r))
|
|> E.R.fmap(r => `Array(r))
|
||||||
| `Hash(hash) =>
|
| `Hash(hash) =>
|
||||||
hash
|
hash
|
||||||
|> E.A.fmap(((name, t)) =>
|
|> E.A.fmap(((name, t)) => fromNode(t) |> E.R.fmap(r => (name, r)))
|
||||||
fromNodeDirect(t) |> E.R.fmap(r => (name, r))
|
|
||||||
)
|
|
||||||
|> E.A.R.firstErrorOrOpen
|
|> E.A.R.firstErrorOrOpen
|
||||||
|> E.R.fmap(r => `Named(r))
|
|> E.R.fmap(r => `Named(r))
|
||||||
| _ => Error("Wrong type")
|
| _ => Error("Wrong type")
|
||||||
};
|
};
|
||||||
|
|
||||||
let compareInput = (evaluationParams, t: t, node) =>
|
// todo: Arrays and hashes
|
||||||
switch (t) {
|
let rec fromNodeWithTypeCoercion = (evaluationParams, _type: _type, node) => {
|
||||||
| `Float =>
|
Js.log3("With Coersion!", _type, node);
|
||||||
|
switch (_type, node) {
|
||||||
|
| (`Float, _) =>
|
||||||
switch (getFloat(node)) {
|
switch (getFloat(node)) {
|
||||||
| Some(a) => Ok(`Float(a))
|
| Some(a) => Ok(`Float(a))
|
||||||
| _ => Error("Type Error: Expected float.")
|
| _ => Error("Type Error: Expected float.")
|
||||||
}
|
}
|
||||||
| `SamplingDistribution =>
|
| (`SamplingDistribution, _) =>
|
||||||
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
|
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
|
||||||
evaluationParams,
|
evaluationParams,
|
||||||
node,
|
node,
|
||||||
)
|
)
|
||||||
|> E.R.bind(_, fromNodeDirect)
|
|> E.R.bind(_, fromNode)
|
||||||
| `RenderedDistribution =>
|
| (`RenderedDistribution, _) =>
|
||||||
ExpressionTypes.ExpressionTree.Render.render(evaluationParams, node)
|
ExpressionTypes.ExpressionTree.Render.render(evaluationParams, node)
|
||||||
|> E.R.bind(_, fromNodeDirect)
|
|> E.R.bind(_, fromNode)
|
||||||
| _ => {
|
| (`Array(_type), `Array(b)) =>
|
||||||
Js.log4("Type error: Expected ", t, ", got ", node);
|
b
|
||||||
Error("Bad input, sorry.")}
|
|> E.A.fmap(fromNodeWithTypeCoercion(evaluationParams, _type))
|
||||||
};
|
|
||||||
|
|
||||||
let sanatizeInputs =
|
|
||||||
(
|
|
||||||
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
|
|
||||||
inputTypes: inputTypes,
|
|
||||||
t: fn,
|
|
||||||
) => {
|
|
||||||
E.A.length(t.inputTypes) == E.A.length(inputTypes)
|
|
||||||
? Belt.Array.zip(t.inputTypes, inputTypes)
|
|
||||||
|> E.A.fmap(((def, input)) =>
|
|
||||||
compareInput(evaluationParams, def, input)
|
|
||||||
)
|
|
||||||
|> (r => {Js.log2("Inputs", r); r})
|
|
||||||
|> E.A.R.firstErrorOrOpen
|
|> E.A.R.firstErrorOrOpen
|
||||||
: Error(
|
|> E.R.fmap(r => `Array(r))
|
||||||
"Wrong number of inputs. Expected"
|
| (`Named(named), `Hash(r)) =>
|
||||||
++ (E.A.length(t.inputTypes) |> E.I.toString)
|
Js.log3("Named", named, r);
|
||||||
++ ". Got:"
|
let foo =
|
||||||
++ (E.A.length(inputTypes) |> E.I.toString),
|
named
|
||||||
|
|> E.A.fmap(((name, intendedType)) =>
|
||||||
|
(
|
||||||
|
name,
|
||||||
|
intendedType,
|
||||||
|
ExpressionTypes.ExpressionTree.Hash.getByName(r, name),
|
||||||
|
)
|
||||||
);
|
);
|
||||||
|
Js.log("Named: part 2");
|
||||||
|
let bar =
|
||||||
|
foo
|
||||||
|
|> E.A.fmap(((name, intendedType, optionNode)) =>
|
||||||
|
switch (optionNode) {
|
||||||
|
| Some(node) =>
|
||||||
|
fromNodeWithTypeCoercion(evaluationParams, intendedType, node)
|
||||||
|
|> E.R.fmap(node => (name, node))
|
||||||
|
| None => Error("Hash parameter not present in hash.")
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|> E.A.R.firstErrorOrOpen
|
||||||
|
|> E.R.fmap(r => `Named(r));
|
||||||
|
Js.log3("Named!", foo, bar);
|
||||||
|
bar;
|
||||||
|
| _ => Error("fromNodeWithTypeCoercion error, sorry.")
|
||||||
|
};
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
let run =
|
module Function = {
|
||||||
|
type t = _function;
|
||||||
|
type ts = functions;
|
||||||
|
|
||||||
|
module T = {
|
||||||
|
let make =
|
||||||
|
(~name, ~inputTypes, ~outputType, ~run, ~shouldCoerceTypes=true, _): t => {
|
||||||
|
name,
|
||||||
|
inputTypes,
|
||||||
|
outputType,
|
||||||
|
run,
|
||||||
|
shouldCoerceTypes,
|
||||||
|
};
|
||||||
|
|
||||||
|
let _inputLengthCheck = (inputNodes: inputNodes, t: t) => {
|
||||||
|
let expectedLength = E.A.length(t.inputTypes);
|
||||||
|
let actualLength = E.A.length(inputNodes);
|
||||||
|
expectedLength == actualLength
|
||||||
|
? Ok(inputNodes)
|
||||||
|
: Error(
|
||||||
|
"Wrong number of inputs. Expected"
|
||||||
|
++ (expectedLength |> E.I.toString)
|
||||||
|
++ ". Got:"
|
||||||
|
++ (actualLength |> E.I.toString),
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
let _coerceInputNodes =
|
||||||
|
(evaluationParams, inputTypes, shouldCoerce, inputNodes) =>
|
||||||
|
Belt.Array.zip(inputTypes, inputNodes)
|
||||||
|
|> E.A.fmap(((def, input)) =>
|
||||||
|
shouldCoerce
|
||||||
|
? TypedValue.fromNodeWithTypeCoercion(
|
||||||
|
evaluationParams,
|
||||||
|
def,
|
||||||
|
input,
|
||||||
|
)
|
||||||
|
: TypedValue.fromNode(input)
|
||||||
|
)
|
||||||
|
|> E.A.R.firstErrorOrOpen;
|
||||||
|
|
||||||
|
let inputsToTypedValues =
|
||||||
(
|
(
|
||||||
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
|
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
|
||||||
inputTypes: inputTypes,
|
inputNodes: inputNodes,
|
||||||
t: fn,
|
t: t,
|
||||||
) => {
|
) => {
|
||||||
let _sanitizedInputs = sanatizeInputs(evaluationParams, inputTypes, t);
|
_inputLengthCheck(inputNodes, t)
|
||||||
_sanitizedInputs |> E.R.bind(_,t.run)
|
->E.R.bind(
|
||||||
|
_coerceInputNodes(
|
||||||
|
evaluationParams,
|
||||||
|
t.inputTypes,
|
||||||
|
t.shouldCoerceTypes,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
let run =
|
||||||
|
(
|
||||||
|
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
|
||||||
|
inputNodes: inputNodes,
|
||||||
|
t: t,
|
||||||
|
) => {
|
||||||
|
Js.log("Running!");
|
||||||
|
inputsToTypedValues(evaluationParams, inputNodes, t)->E.R.bind(t.run)
|
||||||
|> (
|
|> (
|
||||||
fun
|
fun
|
||||||
| Ok(i) => Ok(i)
|
| Ok(i) => Ok(i)
|
||||||
| Error(r) => {
|
| Error(r) => {
|
||||||
Js.log4(
|
|
||||||
"Error",
|
|
||||||
inputTypes,
|
|
||||||
t,
|
|
||||||
_sanitizedInputs
|
|
||||||
);
|
|
||||||
Error("Function " ++ t.name ++ " error: " ++ r);
|
Error("Function " ++ t.name ++ " error: " ++ r);
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
module Ts = {
|
||||||
|
let findByName = (ts: ts, n: string) =>
|
||||||
|
ts |> Belt.Array.getBy(_, ({name}) => name == n);
|
||||||
|
|
||||||
|
let findByNameAndRun = (ts: ts, n: string, evaluationParams, inputTypes) =>
|
||||||
|
findByName(ts, n) |> E.O.fmap(T.run(evaluationParams, inputTypes));
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
let getFn = (fns: fns, n: string) =>
|
|
||||||
fns |> Belt.Array.getBy(_, ({name}) => name == n);
|
|
||||||
|
|
||||||
let getAndRun = (fns: fns, n: string, evaluationParams, inputTypes) =>
|
|
||||||
getFn(fns, n) |> E.O.fmap(run(evaluationParams, inputTypes));
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user