Simple functionality without weights
This commit is contained in:
parent
cf36594b4a
commit
8d55bba2ca
|
@ -233,7 +233,7 @@ let make = () => {
|
|||
~onSubmit=({state}) => {None},
|
||||
~initialState={
|
||||
//guesstimatorString: "mm(normal(-10, 2), uniform(18, 25), lognormal({mean: 10, stdev: 8}), triangular(31,40,50))",
|
||||
guesstimatorString: "mm(3,4)",
|
||||
guesstimatorString: "mm(normal(5,2))",
|
||||
domainType: "Complete",
|
||||
xPoint: "50.0",
|
||||
xPoint2: "60.0",
|
||||
|
|
|
@ -206,6 +206,7 @@ module Truncate = {
|
|||
|
||||
module Normalize = {
|
||||
let rec operationToLeaf = (evaluationParams, t: node): result(node, string) => {
|
||||
Js.log2("normalize", t);
|
||||
switch (t) {
|
||||
| `RenderedDist(s) => Ok(`RenderedDist(Shape.T.normalize(s)))
|
||||
| `SymbolicDist(_) => Ok(t)
|
||||
|
@ -225,6 +226,7 @@ let callableFunction = (evaluationParams, name, args) => {
|
|||
module Render = {
|
||||
let rec operationToLeaf =
|
||||
(evaluationParams: evaluationParams, t: node): result(t, string) => {
|
||||
Js.log2("rendering", t);
|
||||
switch (t) {
|
||||
| `Function(_) => Error("Cannot render a function")
|
||||
| `SymbolicDist(d) =>
|
||||
|
@ -254,17 +256,18 @@ let rec toLeaf =
|
|||
node: t,
|
||||
)
|
||||
: result(t, string) => {
|
||||
Js.log2("node",node);
|
||||
Js.log2("node", node);
|
||||
switch (node) {
|
||||
// Leaf nodes just stay leaf nodes
|
||||
| `SymbolicDist(_)
|
||||
| `Function(_)
|
||||
| `RenderedDist(_) => Ok(node)
|
||||
| `Array(args) =>
|
||||
Js.log2("Array!", args);
|
||||
args
|
||||
|> E.A.fmap(toLeaf(evaluationParams))
|
||||
|> E.A.R.firstErrorOrOpen
|
||||
|> E.R.fmap(r => `Array(r))
|
||||
|> E.R.fmap(r => `Array(r));
|
||||
// Operations nevaluationParamsd to be turned into leaves
|
||||
| `AlgebraicCombination(algebraicOp, t1, t2) =>
|
||||
AlgebraicCombination.operationToLeaf(
|
||||
|
@ -285,34 +288,25 @@ let rec toLeaf =
|
|||
| `Normalize(t) => Normalize.operationToLeaf(evaluationParams, t)
|
||||
| `Render(t) => Render.operationToLeaf(evaluationParams, t)
|
||||
| `Hash(t) =>
|
||||
Js.log("In hash");
|
||||
t
|
||||
|> E.A.fmap(((name: string, node: node)) =>
|
||||
toLeaf(evaluationParams, node) |> E.R.fmap(r => (name, r))
|
||||
)
|
||||
|> E.A.R.firstErrorOrOpen
|
||||
|> E.R.fmap(r => `Hash(r))
|
||||
|> E.R.fmap(r => `Hash(r));
|
||||
| `Symbol(r) =>
|
||||
Js.log("Symbol");
|
||||
ExpressionTypes.ExpressionTree.Environment.get(
|
||||
evaluationParams.environment,
|
||||
r,
|
||||
)
|
||||
|> E.O.toResult("Undeclared variable " ++ r)
|
||||
|> E.R.bind(_, toLeaf(evaluationParams))
|
||||
|> E.R.bind(_, toLeaf(evaluationParams));
|
||||
| `FunctionCall(name, args) =>
|
||||
Js.log3("In function call", name, args);
|
||||
callableFunction(evaluationParams, name, args)
|
||||
|> E.R.bind(_, toLeaf(evaluationParams))
|
||||
| `MultiModal(r) =>
|
||||
let components =
|
||||
r
|
||||
|> E.A.fmap(((dist, weight)) =>
|
||||
`FunctionCall("scaleMultiply", [|dist, `SymbolicDist(`Float(weight))|]));
|
||||
let pointwiseSum =
|
||||
components
|
||||
|> Js.Array.sliceFrom(1)
|
||||
|> E.A.fold_left(
|
||||
(acc, x) => {`PointwiseCombination((`Add, acc, x))},
|
||||
E.A.unsafe_get(components, 0),
|
||||
);
|
||||
Ok(`Normalize(pointwiseSum)) |> E.R.bind(_, toLeaf(evaluationParams));
|
||||
|> E.R.bind(_, toLeaf(evaluationParams));
|
||||
| `MultiModal(r) => Error("Multimodal?")
|
||||
};
|
||||
};
|
||||
|
|
|
@ -16,11 +16,12 @@ type distToFloatOperation = [
|
|||
];
|
||||
|
||||
module ExpressionTree = {
|
||||
type node = [
|
||||
type hash = array((string, node))
|
||||
and node = [
|
||||
| `SymbolicDist(SymbolicTypes.symbolicDist)
|
||||
| `RenderedDist(DistTypes.shape)
|
||||
| `Symbol(string)
|
||||
| `Hash(array((string, node)))
|
||||
| `Hash(hash)
|
||||
| `Array(array(node))
|
||||
| `Function(array(string), node)
|
||||
| `AlgebraicCombination(algebraicOperation, node, node)
|
||||
|
@ -31,16 +32,31 @@ module ExpressionTree = {
|
|||
| `FunctionCall(string, array(node))
|
||||
| `MultiModal(array((node, float)))
|
||||
];
|
||||
// Have nil as option
|
||||
let getFloat = (node:node) => node |> fun
|
||||
| `RenderedDist(Discrete({xyShape: {xs: [|x|], ys: [|1.0|]}})) => Some(x)
|
||||
| `SymbolicDist(`Float(x)) => Some(x)
|
||||
| _ => None
|
||||
|
||||
let toFloatIfNeeded = (node:node) => switch(node |> getFloat){
|
||||
module Hash = {
|
||||
type t('a) = array((string, 'a));
|
||||
let getByName = (t:t('a), name) =>
|
||||
E.A.getBy(t, ((n, _)) => n == name) |> E.O.fmap(((_, r)) => r);
|
||||
|
||||
let getByNames = (hash: t('a), names:array(string)) =>
|
||||
names |> E.A.fmap(name => (name, getByName(hash, name)))
|
||||
};
|
||||
// Have nil as option
|
||||
let getFloat = (node: node) =>
|
||||
node
|
||||
|> (
|
||||
fun
|
||||
| `RenderedDist(Discrete({xyShape: {xs: [|x|], ys: [|1.0|]}})) =>
|
||||
Some(x)
|
||||
| `SymbolicDist(`Float(x)) => Some(x)
|
||||
| _ => None
|
||||
);
|
||||
|
||||
let toFloatIfNeeded = (node: node) =>
|
||||
switch (node |> getFloat) {
|
||||
| Some(float) => `SymbolicDist(`Float(float))
|
||||
| None => node
|
||||
}
|
||||
};
|
||||
|
||||
type samplingInputs = {
|
||||
sampleCount: int,
|
||||
|
@ -97,7 +113,6 @@ module ExpressionTree = {
|
|||
|> evaluationParams.evaluateNode(evaluationParams)
|
||||
|> E.R.bind(_, fn(evaluationParams));
|
||||
|
||||
|
||||
module Render = {
|
||||
type t = node;
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ let to_: (float, float) => result(node, string) =
|
|||
};
|
||||
|
||||
let makeSymbolicFromTwoFloats = (name, fn) =>
|
||||
Function.make(
|
||||
Function.T.make(
|
||||
~name,
|
||||
~outputType=`SamplingDistribution,
|
||||
~inputTypes=[|`Float, `Float|],
|
||||
|
@ -23,10 +23,11 @@ let makeSymbolicFromTwoFloats = (name, fn) =>
|
|||
fun
|
||||
| [|`Float(a), `Float(b)|] => Ok(`SymbolicDist(fn(a, b)))
|
||||
| e => wrongInputsError(e),
|
||||
(),
|
||||
);
|
||||
|
||||
let makeSymbolicFromOneFloat = (name, fn) =>
|
||||
Function.make(
|
||||
Function.T.make(
|
||||
~name,
|
||||
~outputType=`SamplingDistribution,
|
||||
~inputTypes=[|`Float|],
|
||||
|
@ -34,10 +35,11 @@ let makeSymbolicFromOneFloat = (name, fn) =>
|
|||
fun
|
||||
| [|`Float(a)|] => Ok(`SymbolicDist(fn(a)))
|
||||
| e => wrongInputsError(e),
|
||||
(),
|
||||
);
|
||||
|
||||
let makeDistFloat = (name, fn) =>
|
||||
Function.make(
|
||||
Function.T.make(
|
||||
~name,
|
||||
~outputType=`SamplingDistribution,
|
||||
~inputTypes=[|`SamplingDistribution, `Float|],
|
||||
|
@ -45,21 +47,23 @@ let makeDistFloat = (name, fn) =>
|
|||
fun
|
||||
| [|`SamplingDist(a), `Float(b)|] => fn(a, b)
|
||||
| e => wrongInputsError(e),
|
||||
(),
|
||||
);
|
||||
|
||||
let makeRenderedDistFloat = (name, fn) =>
|
||||
Function.make(
|
||||
Function.T.make(
|
||||
~name,
|
||||
~outputType=`RenderedDistribution,
|
||||
~inputTypes=[|`RenderedDistribution, `Float|],
|
||||
~run=
|
||||
fun
|
||||
| [|`RenderedDist(a), `Float(b)|] => fn(a, b)
|
||||
| e => wrongInputsError(e)
|
||||
| e => wrongInputsError(e),
|
||||
(),
|
||||
);
|
||||
|
||||
let makeDist = (name, fn) =>
|
||||
Function.make(
|
||||
Function.T.make(
|
||||
~name,
|
||||
~outputType=`SamplingDistribution,
|
||||
~inputTypes=[|`SamplingDistribution|],
|
||||
|
@ -67,6 +71,7 @@ let makeDist = (name, fn) =>
|
|||
fun
|
||||
| [|`SamplingDist(a)|] => fn(a)
|
||||
| e => wrongInputsError(e),
|
||||
(),
|
||||
);
|
||||
|
||||
let floatFromDist =
|
||||
|
@ -112,7 +117,7 @@ let functions = [|
|
|||
SymbolicDist.Lognormal.fromMeanAndStdev,
|
||||
),
|
||||
makeSymbolicFromOneFloat("exponential", SymbolicDist.Exponential.make),
|
||||
Function.make(
|
||||
Function.T.make(
|
||||
~name="to",
|
||||
~outputType=`SamplingDistribution,
|
||||
~inputTypes=[|`Float, `Float|],
|
||||
|
@ -120,8 +125,9 @@ let functions = [|
|
|||
fun
|
||||
| [|`Float(a), `Float(b)|] => to_(a, b)
|
||||
| e => wrongInputsError(e),
|
||||
(),
|
||||
),
|
||||
Function.make(
|
||||
Function.T.make(
|
||||
~name="triangular",
|
||||
~outputType=`SamplingDistribution,
|
||||
~inputTypes=[|`Float, `Float, `Float|],
|
||||
|
@ -131,13 +137,14 @@ let functions = [|
|
|||
SymbolicDist.Triangular.make(a, b, c)
|
||||
|> E.R.fmap(r => `SymbolicDist(r))
|
||||
| e => wrongInputsError(e),
|
||||
(),
|
||||
),
|
||||
makeDistFloat("pdf", (dist, float) => floatFromDist(`Pdf(float), dist)),
|
||||
makeDistFloat("inv", (dist, float) => floatFromDist(`Inv(float), dist)),
|
||||
makeDistFloat("cdf", (dist, float) => floatFromDist(`Cdf(float), dist)),
|
||||
makeDist("mean", dist => floatFromDist(`Mean, dist)),
|
||||
makeDist("sample", dist => floatFromDist(`Sample, dist)),
|
||||
Function.make(
|
||||
Function.T.make(
|
||||
~name="render",
|
||||
~outputType=`RenderedDistribution,
|
||||
~inputTypes=[|`RenderedDistribution|],
|
||||
|
@ -145,8 +152,9 @@ let functions = [|
|
|||
fun
|
||||
| [|`RenderedDist(c)|] => Ok(`RenderedDist(c))
|
||||
| e => wrongInputsError(e),
|
||||
(),
|
||||
),
|
||||
Function.make(
|
||||
Function.T.make(
|
||||
~name="normalize",
|
||||
~outputType=`SamplingDistribution,
|
||||
~inputTypes=[|`SamplingDistribution|],
|
||||
|
@ -156,6 +164,7 @@ let functions = [|
|
|||
| [|`SamplingDist(`RenderedDist(c))|] =>
|
||||
Ok(`RenderedDist(Shape.T.normalize(c)))
|
||||
| e => wrongInputsError(e),
|
||||
(),
|
||||
),
|
||||
makeRenderedDistFloat("scaleExp", (dist, float) =>
|
||||
verticalScaling(`Exponentiate, dist, float)
|
||||
|
@ -166,4 +175,42 @@ let functions = [|
|
|||
makeRenderedDistFloat("scaleLog", (dist, float) =>
|
||||
verticalScaling(`Log, dist, float)
|
||||
),
|
||||
Function.T.make(
|
||||
~name="multimodal",
|
||||
~outputType=`SamplingDistribution,
|
||||
~inputTypes=[|
|
||||
`Named([|
|
||||
("dists", `Array(`SamplingDistribution)),
|
||||
("weights", `Array(`Float)),
|
||||
|]),
|
||||
|],
|
||||
~run=
|
||||
fun
|
||||
| [|`Named(r)|] => {
|
||||
let foo =
|
||||
(r: TypeSystem.typedValue)
|
||||
: result(ExpressionTypes.ExpressionTree.node, string) =>
|
||||
switch (r) {
|
||||
| `SamplingDist(`SymbolicDist(c)) => Ok(`SymbolicDist(c))
|
||||
| `SamplingDist(`RenderedDist(c)) => Ok(`RenderedDist(c))
|
||||
| _ => Error("")
|
||||
};
|
||||
switch (ExpressionTypes.ExpressionTree.Hash.getByName(r, "dists")) {
|
||||
| Some(`Array(r)) =>
|
||||
r
|
||||
|> E.A.fmap(foo)
|
||||
|> E.A.R.firstErrorOrOpen
|
||||
|> E.R.fmap(distributions => {
|
||||
distributions
|
||||
|> E.A.fold_left(
|
||||
(acc, x) => {`PointwiseCombination((`Add, acc, x))},
|
||||
E.A.unsafe_get(distributions, 0),
|
||||
)
|
||||
})
|
||||
| _ => Error("")
|
||||
};
|
||||
}
|
||||
| _ => Error(""),
|
||||
(),
|
||||
),
|
||||
|];
|
||||
|
|
|
@ -8,9 +8,13 @@ let fnn =
|
|||
name,
|
||||
args: array(node),
|
||||
) => {
|
||||
let trySomeFns =
|
||||
TypeSystem.getAndRun(Fns.functions, name, evaluationParams, args);
|
||||
switch (trySomeFns) {
|
||||
// let foundFn =
|
||||
// TypeSystem.Function.Ts.findByName(Fns.functions, name) |> E.O.toResult("Function " ++ name ++ " not found");
|
||||
// let ran = foundFn |> E.R.bind(_,TypeSystem.Function.T.run(evaluationParams,args))
|
||||
|
||||
let foundFn =
|
||||
TypeSystem.Function.Ts.findByNameAndRun(Fns.functions, name, evaluationParams, args);
|
||||
switch (foundFn) {
|
||||
| Some(r) => r
|
||||
| None =>
|
||||
switch (
|
||||
|
@ -21,32 +25,8 @@ let fnn =
|
|||
),
|
||||
) {
|
||||
| (_, Some(`Function(argNames, tt))) =>
|
||||
Js.log("Fundction found: " ++ name);
|
||||
PTypes.Function.run(evaluationParams, args, (argNames, tt))
|
||||
| ("multimodal", _) =>
|
||||
switch (args |> E.A.to_list) {
|
||||
| [`Array(weights), ...dists] =>
|
||||
let withWeights =
|
||||
dists
|
||||
|> E.L.toArray
|
||||
|> E.A.fmapi((index, t) => {
|
||||
let w =
|
||||
weights
|
||||
|> E.A.get(_, index)
|
||||
|> E.O.bind(_, getFloat)
|
||||
|> E.O.default(1.0);
|
||||
(t, w);
|
||||
});
|
||||
Ok(`MultiModal(withWeights));
|
||||
| dists when E.L.length(dists) > 0 =>
|
||||
Ok(
|
||||
`MultiModal(
|
||||
dists
|
||||
|> E.L.toArray
|
||||
|> E.A.fmap(r => (r, 1.0)),
|
||||
),
|
||||
)
|
||||
| _ => Error("Needs at least one distribution")
|
||||
}
|
||||
| _ => Error("Function " ++ name ++ " not found")
|
||||
}
|
||||
};
|
||||
|
|
|
@ -198,13 +198,36 @@ module MathAdtToDistDst = {
|
|||
});
|
||||
};
|
||||
|
||||
let functionParser = (nodeParser, name, args) => {
|
||||
let functionParser =
|
||||
(
|
||||
nodeParser:
|
||||
MathJsonToMathJsAdt.arg =>
|
||||
Belt.Result.t(
|
||||
ProbExample.ExpressionTypes.ExpressionTree.node,
|
||||
string,
|
||||
),
|
||||
name: string,
|
||||
args: array(MathJsonToMathJsAdt.arg),
|
||||
)
|
||||
: result(ExpressionTypes.ExpressionTree.node, string) => {
|
||||
let parseArray = ags =>
|
||||
ags |> E.A.fmap(nodeParser) |> E.A.R.firstErrorOrOpen;
|
||||
let parseArgs = () => parseArray(args);
|
||||
switch (name) {
|
||||
| "lognormal" => lognormal(args, parseArgs, nodeParser)
|
||||
| "multimodal"
|
||||
| "add"
|
||||
| "subtract"
|
||||
| "multiply"
|
||||
| "unaryMinus"
|
||||
| "dotMultiply"
|
||||
| "dotPow"
|
||||
| "rightLogShift"
|
||||
| "divide"
|
||||
| "pow"
|
||||
| "leftTruncate"
|
||||
| "rightTruncate"
|
||||
| "truncate" => operationParser(name, parseArgs())
|
||||
| "mm" =>
|
||||
let weights =
|
||||
args
|
||||
|
@ -223,24 +246,25 @@ module MathAdtToDistDst = {
|
|||
switch (weights, dists) {
|
||||
| (Some(Error(r)), _) => Error(r)
|
||||
| (_, Error(r)) => Error(r)
|
||||
| (None, Ok(dists)) => Ok(`FunctionCall(("multimodal", dists)))
|
||||
| (Some(Ok(r)), Ok(dists)) =>
|
||||
Ok(
|
||||
`FunctionCall(("multimodal", E.A.append([|`Array(r)|], dists))),
|
||||
)
|
||||
| (None, Ok(dists)) =>
|
||||
let hash: ExpressionTypes.ExpressionTree.node =
|
||||
`FunctionCall(("multimodal", [|`Hash(
|
||||
[|
|
||||
("dists", `Array(dists)),
|
||||
("weights", `Array([||]))
|
||||
|]
|
||||
)|]));
|
||||
Ok(hash);
|
||||
| (Some(Ok(weights)), Ok(dists)) =>
|
||||
let hash: ExpressionTypes.ExpressionTree.node =
|
||||
`FunctionCall(("multimodal", [|`Hash(
|
||||
[|
|
||||
("dists", `Array(dists)),
|
||||
("weights", `Array(weights))
|
||||
|]
|
||||
)|]));
|
||||
Ok(hash);
|
||||
};
|
||||
| "add"
|
||||
| "subtract"
|
||||
| "multiply"
|
||||
| "unaryMinus"
|
||||
| "dotMultiply"
|
||||
| "dotPow"
|
||||
| "rightLogShift"
|
||||
| "divide"
|
||||
| "pow"
|
||||
| "leftTruncate"
|
||||
| "rightTruncate"
|
||||
| "truncate" => operationParser(name, parseArgs())
|
||||
| name =>
|
||||
parseArgs()
|
||||
|> E.R.fmap((args: array(ExpressionTypes.ExpressionTree.node)) =>
|
||||
|
|
|
@ -6,127 +6,186 @@ type samplingDist = [
|
|||
| `RenderedDist(DistTypes.shape)
|
||||
];
|
||||
|
||||
type t = [
|
||||
type _type = [
|
||||
| `Float
|
||||
| `SamplingDistribution
|
||||
| `RenderedDistribution
|
||||
| `Array(t)
|
||||
| `Named(array((string, t)))
|
||||
| `Array(_type)
|
||||
| `Named(array((string, _type)))
|
||||
];
|
||||
type tx = [
|
||||
|
||||
type typedValue = [
|
||||
| `Float(float)
|
||||
| `RenderedDist(DistTypes.shape)
|
||||
| `SamplingDist(samplingDist)
|
||||
| `Array(array(tx))
|
||||
| `Named(array((string, tx)))
|
||||
| `Array(array(typedValue))
|
||||
| `Named(array((string, typedValue)))
|
||||
];
|
||||
|
||||
type fn = {
|
||||
type _function = {
|
||||
name: string,
|
||||
inputTypes: array(t),
|
||||
outputType: t,
|
||||
run: array(tx) => result(node, string),
|
||||
inputTypes: array(_type),
|
||||
outputType: _type,
|
||||
run: array(typedValue) => result(node, string),
|
||||
shouldCoerceTypes: bool,
|
||||
};
|
||||
|
||||
type functions = array(_function);
|
||||
type inputNodes = array(node);
|
||||
|
||||
module TypedValue = {
|
||||
let rec fromNode = (node: node): result(typedValue, string) =>
|
||||
switch (ExpressionTypes.ExpressionTree.toFloatIfNeeded(node)) {
|
||||
| `SymbolicDist(`Float(r)) => Ok(`Float(r))
|
||||
| `SymbolicDist(s) => Ok(`SamplingDist(`SymbolicDist(s)))
|
||||
| `RenderedDist(s) => Ok(`RenderedDist(s))
|
||||
| `Array(r) =>
|
||||
r
|
||||
|> E.A.fmap(fromNode)
|
||||
|> E.A.R.firstErrorOrOpen
|
||||
|> E.R.fmap(r => `Array(r))
|
||||
| `Hash(hash) =>
|
||||
hash
|
||||
|> E.A.fmap(((name, t)) => fromNode(t) |> E.R.fmap(r => (name, r)))
|
||||
|> E.A.R.firstErrorOrOpen
|
||||
|> E.R.fmap(r => `Named(r))
|
||||
| _ => Error("Wrong type")
|
||||
};
|
||||
|
||||
// todo: Arrays and hashes
|
||||
let rec fromNodeWithTypeCoercion = (evaluationParams, _type: _type, node) => {
|
||||
Js.log3("With Coersion!", _type, node);
|
||||
switch (_type, node) {
|
||||
| (`Float, _) =>
|
||||
switch (getFloat(node)) {
|
||||
| Some(a) => Ok(`Float(a))
|
||||
| _ => Error("Type Error: Expected float.")
|
||||
}
|
||||
| (`SamplingDistribution, _) =>
|
||||
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
|
||||
evaluationParams,
|
||||
node,
|
||||
)
|
||||
|> E.R.bind(_, fromNode)
|
||||
| (`RenderedDistribution, _) =>
|
||||
ExpressionTypes.ExpressionTree.Render.render(evaluationParams, node)
|
||||
|> E.R.bind(_, fromNode)
|
||||
| (`Array(_type), `Array(b)) =>
|
||||
b
|
||||
|> E.A.fmap(fromNodeWithTypeCoercion(evaluationParams, _type))
|
||||
|> E.A.R.firstErrorOrOpen
|
||||
|> E.R.fmap(r => `Array(r))
|
||||
| (`Named(named), `Hash(r)) =>
|
||||
Js.log3("Named", named, r);
|
||||
let foo =
|
||||
named
|
||||
|> E.A.fmap(((name, intendedType)) =>
|
||||
(
|
||||
name,
|
||||
intendedType,
|
||||
ExpressionTypes.ExpressionTree.Hash.getByName(r, name),
|
||||
)
|
||||
);
|
||||
Js.log("Named: part 2");
|
||||
let bar =
|
||||
foo
|
||||
|> E.A.fmap(((name, intendedType, optionNode)) =>
|
||||
switch (optionNode) {
|
||||
| Some(node) =>
|
||||
fromNodeWithTypeCoercion(evaluationParams, intendedType, node)
|
||||
|> E.R.fmap(node => (name, node))
|
||||
| None => Error("Hash parameter not present in hash.")
|
||||
}
|
||||
)
|
||||
|> E.A.R.firstErrorOrOpen
|
||||
|> E.R.fmap(r => `Named(r));
|
||||
Js.log3("Named!", foo, bar);
|
||||
bar;
|
||||
| _ => Error("fromNodeWithTypeCoercion error, sorry.")
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
module Function = {
|
||||
let make = (~name, ~inputTypes, ~outputType, ~run): fn => {
|
||||
name,
|
||||
inputTypes,
|
||||
outputType,
|
||||
run,
|
||||
};
|
||||
};
|
||||
type t = _function;
|
||||
type ts = functions;
|
||||
|
||||
type fns = array(fn);
|
||||
type inputTypes = array(node);
|
||||
module T = {
|
||||
let make =
|
||||
(~name, ~inputTypes, ~outputType, ~run, ~shouldCoerceTypes=true, _): t => {
|
||||
name,
|
||||
inputTypes,
|
||||
outputType,
|
||||
run,
|
||||
shouldCoerceTypes,
|
||||
};
|
||||
|
||||
let rec fromNodeDirect = (node: node): result(tx, string) =>
|
||||
switch (ExpressionTypes.ExpressionTree.toFloatIfNeeded(node)) {
|
||||
| `SymbolicDist(`Float(r)) => Ok(`Float(r))
|
||||
| `SymbolicDist(s) => Ok(`SamplingDist(`SymbolicDist(s)))
|
||||
| `RenderedDist(s) => Ok(`RenderedDist(s))
|
||||
| `Array(r) =>
|
||||
r
|
||||
|> E.A.fmap(fromNodeDirect)
|
||||
|> E.A.R.firstErrorOrOpen
|
||||
|> E.R.fmap(r => `Array(r))
|
||||
| `Hash(hash) =>
|
||||
hash
|
||||
|> E.A.fmap(((name, t)) =>
|
||||
fromNodeDirect(t) |> E.R.fmap(r => (name, r))
|
||||
)
|
||||
|> E.A.R.firstErrorOrOpen
|
||||
|> E.R.fmap(r => `Named(r))
|
||||
| _ => Error("Wrong type")
|
||||
};
|
||||
let _inputLengthCheck = (inputNodes: inputNodes, t: t) => {
|
||||
let expectedLength = E.A.length(t.inputTypes);
|
||||
let actualLength = E.A.length(inputNodes);
|
||||
expectedLength == actualLength
|
||||
? Ok(inputNodes)
|
||||
: Error(
|
||||
"Wrong number of inputs. Expected"
|
||||
++ (expectedLength |> E.I.toString)
|
||||
++ ". Got:"
|
||||
++ (actualLength |> E.I.toString),
|
||||
);
|
||||
};
|
||||
|
||||
let compareInput = (evaluationParams, t: t, node) =>
|
||||
switch (t) {
|
||||
| `Float =>
|
||||
switch (getFloat(node)) {
|
||||
| Some(a) => Ok(`Float(a))
|
||||
| _ => Error("Type Error: Expected float.")
|
||||
}
|
||||
| `SamplingDistribution =>
|
||||
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
|
||||
evaluationParams,
|
||||
node,
|
||||
)
|
||||
|> E.R.bind(_, fromNodeDirect)
|
||||
| `RenderedDistribution =>
|
||||
ExpressionTypes.ExpressionTree.Render.render(evaluationParams, node)
|
||||
|> E.R.bind(_, fromNodeDirect)
|
||||
| _ => {
|
||||
Js.log4("Type error: Expected ", t, ", got ", node);
|
||||
Error("Bad input, sorry.")}
|
||||
};
|
||||
|
||||
let sanatizeInputs =
|
||||
(
|
||||
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
|
||||
inputTypes: inputTypes,
|
||||
t: fn,
|
||||
) => {
|
||||
E.A.length(t.inputTypes) == E.A.length(inputTypes)
|
||||
? Belt.Array.zip(t.inputTypes, inputTypes)
|
||||
let _coerceInputNodes =
|
||||
(evaluationParams, inputTypes, shouldCoerce, inputNodes) =>
|
||||
Belt.Array.zip(inputTypes, inputNodes)
|
||||
|> E.A.fmap(((def, input)) =>
|
||||
compareInput(evaluationParams, def, input)
|
||||
shouldCoerce
|
||||
? TypedValue.fromNodeWithTypeCoercion(
|
||||
evaluationParams,
|
||||
def,
|
||||
input,
|
||||
)
|
||||
: TypedValue.fromNode(input)
|
||||
)
|
||||
|> (r => {Js.log2("Inputs", r); r})
|
||||
|> E.A.R.firstErrorOrOpen
|
||||
: Error(
|
||||
"Wrong number of inputs. Expected"
|
||||
++ (E.A.length(t.inputTypes) |> E.I.toString)
|
||||
++ ". Got:"
|
||||
++ (E.A.length(inputTypes) |> E.I.toString),
|
||||
);
|
||||
};
|
||||
|> E.A.R.firstErrorOrOpen;
|
||||
|
||||
let run =
|
||||
(
|
||||
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
|
||||
inputTypes: inputTypes,
|
||||
t: fn,
|
||||
) => {
|
||||
let _sanitizedInputs = sanatizeInputs(evaluationParams, inputTypes, t);
|
||||
_sanitizedInputs |> E.R.bind(_,t.run)
|
||||
|> (
|
||||
fun
|
||||
| Ok(i) => Ok(i)
|
||||
| Error(r) => {
|
||||
Js.log4(
|
||||
"Error",
|
||||
inputTypes,
|
||||
t,
|
||||
_sanitizedInputs
|
||||
let inputsToTypedValues =
|
||||
(
|
||||
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
|
||||
inputNodes: inputNodes,
|
||||
t: t,
|
||||
) => {
|
||||
_inputLengthCheck(inputNodes, t)
|
||||
->E.R.bind(
|
||||
_coerceInputNodes(
|
||||
evaluationParams,
|
||||
t.inputTypes,
|
||||
t.shouldCoerceTypes,
|
||||
),
|
||||
);
|
||||
Error("Function " ++ t.name ++ " error: " ++ r);
|
||||
}
|
||||
);
|
||||
};
|
||||
|
||||
let run =
|
||||
(
|
||||
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
|
||||
inputNodes: inputNodes,
|
||||
t: t,
|
||||
) => {
|
||||
Js.log("Running!");
|
||||
inputsToTypedValues(evaluationParams, inputNodes, t)->E.R.bind(t.run)
|
||||
|> (
|
||||
fun
|
||||
| Ok(i) => Ok(i)
|
||||
| Error(r) => {
|
||||
Error("Function " ++ t.name ++ " error: " ++ r);
|
||||
}
|
||||
);
|
||||
};
|
||||
};
|
||||
|
||||
module Ts = {
|
||||
let findByName = (ts: ts, n: string) =>
|
||||
ts |> Belt.Array.getBy(_, ({name}) => name == n);
|
||||
|
||||
let findByNameAndRun = (ts: ts, n: string, evaluationParams, inputTypes) =>
|
||||
findByName(ts, n) |> E.O.fmap(T.run(evaluationParams, inputTypes));
|
||||
};
|
||||
};
|
||||
|
||||
let getFn = (fns: fns, n: string) =>
|
||||
fns |> Belt.Array.getBy(_, ({name}) => name == n);
|
||||
|
||||
let getAndRun = (fns: fns, n: string, evaluationParams, inputTypes) =>
|
||||
getFn(fns, n) |> E.O.fmap(run(evaluationParams, inputTypes));
|
||||
|
|
Loading…
Reference in New Issue
Block a user