Simple functionality without weights

This commit is contained in:
Ozzie Gooen 2020-11-06 19:31:38 -08:00
parent cf36594b4a
commit 8d55bba2ca
7 changed files with 307 additions and 188 deletions

View File

@ -233,7 +233,7 @@ let make = () => {
~onSubmit=({state}) => {None}, ~onSubmit=({state}) => {None},
~initialState={ ~initialState={
//guesstimatorString: "mm(normal(-10, 2), uniform(18, 25), lognormal({mean: 10, stdev: 8}), triangular(31,40,50))", //guesstimatorString: "mm(normal(-10, 2), uniform(18, 25), lognormal({mean: 10, stdev: 8}), triangular(31,40,50))",
guesstimatorString: "mm(3,4)", guesstimatorString: "mm(normal(5,2))",
domainType: "Complete", domainType: "Complete",
xPoint: "50.0", xPoint: "50.0",
xPoint2: "60.0", xPoint2: "60.0",

View File

@ -206,6 +206,7 @@ module Truncate = {
module Normalize = { module Normalize = {
let rec operationToLeaf = (evaluationParams, t: node): result(node, string) => { let rec operationToLeaf = (evaluationParams, t: node): result(node, string) => {
Js.log2("normalize", t);
switch (t) { switch (t) {
| `RenderedDist(s) => Ok(`RenderedDist(Shape.T.normalize(s))) | `RenderedDist(s) => Ok(`RenderedDist(Shape.T.normalize(s)))
| `SymbolicDist(_) => Ok(t) | `SymbolicDist(_) => Ok(t)
@ -225,6 +226,7 @@ let callableFunction = (evaluationParams, name, args) => {
module Render = { module Render = {
let rec operationToLeaf = let rec operationToLeaf =
(evaluationParams: evaluationParams, t: node): result(t, string) => { (evaluationParams: evaluationParams, t: node): result(t, string) => {
Js.log2("rendering", t);
switch (t) { switch (t) {
| `Function(_) => Error("Cannot render a function") | `Function(_) => Error("Cannot render a function")
| `SymbolicDist(d) => | `SymbolicDist(d) =>
@ -254,17 +256,18 @@ let rec toLeaf =
node: t, node: t,
) )
: result(t, string) => { : result(t, string) => {
Js.log2("node",node); Js.log2("node", node);
switch (node) { switch (node) {
// Leaf nodes just stay leaf nodes // Leaf nodes just stay leaf nodes
| `SymbolicDist(_) | `SymbolicDist(_)
| `Function(_) | `Function(_)
| `RenderedDist(_) => Ok(node) | `RenderedDist(_) => Ok(node)
| `Array(args) => | `Array(args) =>
Js.log2("Array!", args);
args args
|> E.A.fmap(toLeaf(evaluationParams)) |> E.A.fmap(toLeaf(evaluationParams))
|> E.A.R.firstErrorOrOpen |> E.A.R.firstErrorOrOpen
|> E.R.fmap(r => `Array(r)) |> E.R.fmap(r => `Array(r));
// Operations nevaluationParamsd to be turned into leaves // Operations nevaluationParamsd to be turned into leaves
| `AlgebraicCombination(algebraicOp, t1, t2) => | `AlgebraicCombination(algebraicOp, t1, t2) =>
AlgebraicCombination.operationToLeaf( AlgebraicCombination.operationToLeaf(
@ -285,34 +288,25 @@ let rec toLeaf =
| `Normalize(t) => Normalize.operationToLeaf(evaluationParams, t) | `Normalize(t) => Normalize.operationToLeaf(evaluationParams, t)
| `Render(t) => Render.operationToLeaf(evaluationParams, t) | `Render(t) => Render.operationToLeaf(evaluationParams, t)
| `Hash(t) => | `Hash(t) =>
Js.log("In hash");
t t
|> E.A.fmap(((name: string, node: node)) => |> E.A.fmap(((name: string, node: node)) =>
toLeaf(evaluationParams, node) |> E.R.fmap(r => (name, r)) toLeaf(evaluationParams, node) |> E.R.fmap(r => (name, r))
) )
|> E.A.R.firstErrorOrOpen |> E.A.R.firstErrorOrOpen
|> E.R.fmap(r => `Hash(r)) |> E.R.fmap(r => `Hash(r));
| `Symbol(r) => | `Symbol(r) =>
Js.log("Symbol");
ExpressionTypes.ExpressionTree.Environment.get( ExpressionTypes.ExpressionTree.Environment.get(
evaluationParams.environment, evaluationParams.environment,
r, r,
) )
|> E.O.toResult("Undeclared variable " ++ r) |> E.O.toResult("Undeclared variable " ++ r)
|> E.R.bind(_, toLeaf(evaluationParams)) |> E.R.bind(_, toLeaf(evaluationParams));
| `FunctionCall(name, args) => | `FunctionCall(name, args) =>
Js.log3("In function call", name, args);
callableFunction(evaluationParams, name, args) callableFunction(evaluationParams, name, args)
|> E.R.bind(_, toLeaf(evaluationParams)) |> E.R.bind(_, toLeaf(evaluationParams));
| `MultiModal(r) => | `MultiModal(r) => Error("Multimodal?")
let components =
r
|> E.A.fmap(((dist, weight)) =>
`FunctionCall("scaleMultiply", [|dist, `SymbolicDist(`Float(weight))|]));
let pointwiseSum =
components
|> Js.Array.sliceFrom(1)
|> E.A.fold_left(
(acc, x) => {`PointwiseCombination((`Add, acc, x))},
E.A.unsafe_get(components, 0),
);
Ok(`Normalize(pointwiseSum)) |> E.R.bind(_, toLeaf(evaluationParams));
}; };
}; };

View File

@ -16,11 +16,12 @@ type distToFloatOperation = [
]; ];
module ExpressionTree = { module ExpressionTree = {
type node = [ type hash = array((string, node))
and node = [
| `SymbolicDist(SymbolicTypes.symbolicDist) | `SymbolicDist(SymbolicTypes.symbolicDist)
| `RenderedDist(DistTypes.shape) | `RenderedDist(DistTypes.shape)
| `Symbol(string) | `Symbol(string)
| `Hash(array((string, node))) | `Hash(hash)
| `Array(array(node)) | `Array(array(node))
| `Function(array(string), node) | `Function(array(string), node)
| `AlgebraicCombination(algebraicOperation, node, node) | `AlgebraicCombination(algebraicOperation, node, node)
@ -31,16 +32,31 @@ module ExpressionTree = {
| `FunctionCall(string, array(node)) | `FunctionCall(string, array(node))
| `MultiModal(array((node, float))) | `MultiModal(array((node, float)))
]; ];
// Have nil as option
let getFloat = (node:node) => node |> fun
| `RenderedDist(Discrete({xyShape: {xs: [|x|], ys: [|1.0|]}})) => Some(x)
| `SymbolicDist(`Float(x)) => Some(x)
| _ => None
let toFloatIfNeeded = (node:node) => switch(node |> getFloat){ module Hash = {
type t('a) = array((string, 'a));
let getByName = (t:t('a), name) =>
E.A.getBy(t, ((n, _)) => n == name) |> E.O.fmap(((_, r)) => r);
let getByNames = (hash: t('a), names:array(string)) =>
names |> E.A.fmap(name => (name, getByName(hash, name)))
};
// Have nil as option
let getFloat = (node: node) =>
node
|> (
fun
| `RenderedDist(Discrete({xyShape: {xs: [|x|], ys: [|1.0|]}})) =>
Some(x)
| `SymbolicDist(`Float(x)) => Some(x)
| _ => None
);
let toFloatIfNeeded = (node: node) =>
switch (node |> getFloat) {
| Some(float) => `SymbolicDist(`Float(float)) | Some(float) => `SymbolicDist(`Float(float))
| None => node | None => node
} };
type samplingInputs = { type samplingInputs = {
sampleCount: int, sampleCount: int,
@ -97,7 +113,6 @@ module ExpressionTree = {
|> evaluationParams.evaluateNode(evaluationParams) |> evaluationParams.evaluateNode(evaluationParams)
|> E.R.bind(_, fn(evaluationParams)); |> E.R.bind(_, fn(evaluationParams));
module Render = { module Render = {
type t = node; type t = node;

View File

@ -15,7 +15,7 @@ let to_: (float, float) => result(node, string) =
}; };
let makeSymbolicFromTwoFloats = (name, fn) => let makeSymbolicFromTwoFloats = (name, fn) =>
Function.make( Function.T.make(
~name, ~name,
~outputType=`SamplingDistribution, ~outputType=`SamplingDistribution,
~inputTypes=[|`Float, `Float|], ~inputTypes=[|`Float, `Float|],
@ -23,10 +23,11 @@ let makeSymbolicFromTwoFloats = (name, fn) =>
fun fun
| [|`Float(a), `Float(b)|] => Ok(`SymbolicDist(fn(a, b))) | [|`Float(a), `Float(b)|] => Ok(`SymbolicDist(fn(a, b)))
| e => wrongInputsError(e), | e => wrongInputsError(e),
(),
); );
let makeSymbolicFromOneFloat = (name, fn) => let makeSymbolicFromOneFloat = (name, fn) =>
Function.make( Function.T.make(
~name, ~name,
~outputType=`SamplingDistribution, ~outputType=`SamplingDistribution,
~inputTypes=[|`Float|], ~inputTypes=[|`Float|],
@ -34,10 +35,11 @@ let makeSymbolicFromOneFloat = (name, fn) =>
fun fun
| [|`Float(a)|] => Ok(`SymbolicDist(fn(a))) | [|`Float(a)|] => Ok(`SymbolicDist(fn(a)))
| e => wrongInputsError(e), | e => wrongInputsError(e),
(),
); );
let makeDistFloat = (name, fn) => let makeDistFloat = (name, fn) =>
Function.make( Function.T.make(
~name, ~name,
~outputType=`SamplingDistribution, ~outputType=`SamplingDistribution,
~inputTypes=[|`SamplingDistribution, `Float|], ~inputTypes=[|`SamplingDistribution, `Float|],
@ -45,21 +47,23 @@ let makeDistFloat = (name, fn) =>
fun fun
| [|`SamplingDist(a), `Float(b)|] => fn(a, b) | [|`SamplingDist(a), `Float(b)|] => fn(a, b)
| e => wrongInputsError(e), | e => wrongInputsError(e),
(),
); );
let makeRenderedDistFloat = (name, fn) => let makeRenderedDistFloat = (name, fn) =>
Function.make( Function.T.make(
~name, ~name,
~outputType=`RenderedDistribution, ~outputType=`RenderedDistribution,
~inputTypes=[|`RenderedDistribution, `Float|], ~inputTypes=[|`RenderedDistribution, `Float|],
~run= ~run=
fun fun
| [|`RenderedDist(a), `Float(b)|] => fn(a, b) | [|`RenderedDist(a), `Float(b)|] => fn(a, b)
| e => wrongInputsError(e) | e => wrongInputsError(e),
(),
); );
let makeDist = (name, fn) => let makeDist = (name, fn) =>
Function.make( Function.T.make(
~name, ~name,
~outputType=`SamplingDistribution, ~outputType=`SamplingDistribution,
~inputTypes=[|`SamplingDistribution|], ~inputTypes=[|`SamplingDistribution|],
@ -67,6 +71,7 @@ let makeDist = (name, fn) =>
fun fun
| [|`SamplingDist(a)|] => fn(a) | [|`SamplingDist(a)|] => fn(a)
| e => wrongInputsError(e), | e => wrongInputsError(e),
(),
); );
let floatFromDist = let floatFromDist =
@ -112,7 +117,7 @@ let functions = [|
SymbolicDist.Lognormal.fromMeanAndStdev, SymbolicDist.Lognormal.fromMeanAndStdev,
), ),
makeSymbolicFromOneFloat("exponential", SymbolicDist.Exponential.make), makeSymbolicFromOneFloat("exponential", SymbolicDist.Exponential.make),
Function.make( Function.T.make(
~name="to", ~name="to",
~outputType=`SamplingDistribution, ~outputType=`SamplingDistribution,
~inputTypes=[|`Float, `Float|], ~inputTypes=[|`Float, `Float|],
@ -120,8 +125,9 @@ let functions = [|
fun fun
| [|`Float(a), `Float(b)|] => to_(a, b) | [|`Float(a), `Float(b)|] => to_(a, b)
| e => wrongInputsError(e), | e => wrongInputsError(e),
(),
), ),
Function.make( Function.T.make(
~name="triangular", ~name="triangular",
~outputType=`SamplingDistribution, ~outputType=`SamplingDistribution,
~inputTypes=[|`Float, `Float, `Float|], ~inputTypes=[|`Float, `Float, `Float|],
@ -131,13 +137,14 @@ let functions = [|
SymbolicDist.Triangular.make(a, b, c) SymbolicDist.Triangular.make(a, b, c)
|> E.R.fmap(r => `SymbolicDist(r)) |> E.R.fmap(r => `SymbolicDist(r))
| e => wrongInputsError(e), | e => wrongInputsError(e),
(),
), ),
makeDistFloat("pdf", (dist, float) => floatFromDist(`Pdf(float), dist)), makeDistFloat("pdf", (dist, float) => floatFromDist(`Pdf(float), dist)),
makeDistFloat("inv", (dist, float) => floatFromDist(`Inv(float), dist)), makeDistFloat("inv", (dist, float) => floatFromDist(`Inv(float), dist)),
makeDistFloat("cdf", (dist, float) => floatFromDist(`Cdf(float), dist)), makeDistFloat("cdf", (dist, float) => floatFromDist(`Cdf(float), dist)),
makeDist("mean", dist => floatFromDist(`Mean, dist)), makeDist("mean", dist => floatFromDist(`Mean, dist)),
makeDist("sample", dist => floatFromDist(`Sample, dist)), makeDist("sample", dist => floatFromDist(`Sample, dist)),
Function.make( Function.T.make(
~name="render", ~name="render",
~outputType=`RenderedDistribution, ~outputType=`RenderedDistribution,
~inputTypes=[|`RenderedDistribution|], ~inputTypes=[|`RenderedDistribution|],
@ -145,8 +152,9 @@ let functions = [|
fun fun
| [|`RenderedDist(c)|] => Ok(`RenderedDist(c)) | [|`RenderedDist(c)|] => Ok(`RenderedDist(c))
| e => wrongInputsError(e), | e => wrongInputsError(e),
(),
), ),
Function.make( Function.T.make(
~name="normalize", ~name="normalize",
~outputType=`SamplingDistribution, ~outputType=`SamplingDistribution,
~inputTypes=[|`SamplingDistribution|], ~inputTypes=[|`SamplingDistribution|],
@ -156,6 +164,7 @@ let functions = [|
| [|`SamplingDist(`RenderedDist(c))|] => | [|`SamplingDist(`RenderedDist(c))|] =>
Ok(`RenderedDist(Shape.T.normalize(c))) Ok(`RenderedDist(Shape.T.normalize(c)))
| e => wrongInputsError(e), | e => wrongInputsError(e),
(),
), ),
makeRenderedDistFloat("scaleExp", (dist, float) => makeRenderedDistFloat("scaleExp", (dist, float) =>
verticalScaling(`Exponentiate, dist, float) verticalScaling(`Exponentiate, dist, float)
@ -166,4 +175,42 @@ let functions = [|
makeRenderedDistFloat("scaleLog", (dist, float) => makeRenderedDistFloat("scaleLog", (dist, float) =>
verticalScaling(`Log, dist, float) verticalScaling(`Log, dist, float)
), ),
Function.T.make(
~name="multimodal",
~outputType=`SamplingDistribution,
~inputTypes=[|
`Named([|
("dists", `Array(`SamplingDistribution)),
("weights", `Array(`Float)),
|]),
|],
~run=
fun
| [|`Named(r)|] => {
let foo =
(r: TypeSystem.typedValue)
: result(ExpressionTypes.ExpressionTree.node, string) =>
switch (r) {
| `SamplingDist(`SymbolicDist(c)) => Ok(`SymbolicDist(c))
| `SamplingDist(`RenderedDist(c)) => Ok(`RenderedDist(c))
| _ => Error("")
};
switch (ExpressionTypes.ExpressionTree.Hash.getByName(r, "dists")) {
| Some(`Array(r)) =>
r
|> E.A.fmap(foo)
|> E.A.R.firstErrorOrOpen
|> E.R.fmap(distributions => {
distributions
|> E.A.fold_left(
(acc, x) => {`PointwiseCombination((`Add, acc, x))},
E.A.unsafe_get(distributions, 0),
)
})
| _ => Error("")
};
}
| _ => Error(""),
(),
),
|]; |];

View File

@ -8,9 +8,13 @@ let fnn =
name, name,
args: array(node), args: array(node),
) => { ) => {
let trySomeFns = // let foundFn =
TypeSystem.getAndRun(Fns.functions, name, evaluationParams, args); // TypeSystem.Function.Ts.findByName(Fns.functions, name) |> E.O.toResult("Function " ++ name ++ " not found");
switch (trySomeFns) { // let ran = foundFn |> E.R.bind(_,TypeSystem.Function.T.run(evaluationParams,args))
let foundFn =
TypeSystem.Function.Ts.findByNameAndRun(Fns.functions, name, evaluationParams, args);
switch (foundFn) {
| Some(r) => r | Some(r) => r
| None => | None =>
switch ( switch (
@ -21,32 +25,8 @@ let fnn =
), ),
) { ) {
| (_, Some(`Function(argNames, tt))) => | (_, Some(`Function(argNames, tt))) =>
Js.log("Fundction found: " ++ name);
PTypes.Function.run(evaluationParams, args, (argNames, tt)) PTypes.Function.run(evaluationParams, args, (argNames, tt))
| ("multimodal", _) =>
switch (args |> E.A.to_list) {
| [`Array(weights), ...dists] =>
let withWeights =
dists
|> E.L.toArray
|> E.A.fmapi((index, t) => {
let w =
weights
|> E.A.get(_, index)
|> E.O.bind(_, getFloat)
|> E.O.default(1.0);
(t, w);
});
Ok(`MultiModal(withWeights));
| dists when E.L.length(dists) > 0 =>
Ok(
`MultiModal(
dists
|> E.L.toArray
|> E.A.fmap(r => (r, 1.0)),
),
)
| _ => Error("Needs at least one distribution")
}
| _ => Error("Function " ++ name ++ " not found") | _ => Error("Function " ++ name ++ " not found")
} }
}; };

View File

@ -198,13 +198,36 @@ module MathAdtToDistDst = {
}); });
}; };
let functionParser = (nodeParser, name, args) => { let functionParser =
(
nodeParser:
MathJsonToMathJsAdt.arg =>
Belt.Result.t(
ProbExample.ExpressionTypes.ExpressionTree.node,
string,
),
name: string,
args: array(MathJsonToMathJsAdt.arg),
)
: result(ExpressionTypes.ExpressionTree.node, string) => {
let parseArray = ags => let parseArray = ags =>
ags |> E.A.fmap(nodeParser) |> E.A.R.firstErrorOrOpen; ags |> E.A.fmap(nodeParser) |> E.A.R.firstErrorOrOpen;
let parseArgs = () => parseArray(args); let parseArgs = () => parseArray(args);
switch (name) { switch (name) {
| "lognormal" => lognormal(args, parseArgs, nodeParser) | "lognormal" => lognormal(args, parseArgs, nodeParser)
| "multimodal" | "multimodal"
| "add"
| "subtract"
| "multiply"
| "unaryMinus"
| "dotMultiply"
| "dotPow"
| "rightLogShift"
| "divide"
| "pow"
| "leftTruncate"
| "rightTruncate"
| "truncate" => operationParser(name, parseArgs())
| "mm" => | "mm" =>
let weights = let weights =
args args
@ -223,24 +246,25 @@ module MathAdtToDistDst = {
switch (weights, dists) { switch (weights, dists) {
| (Some(Error(r)), _) => Error(r) | (Some(Error(r)), _) => Error(r)
| (_, Error(r)) => Error(r) | (_, Error(r)) => Error(r)
| (None, Ok(dists)) => Ok(`FunctionCall(("multimodal", dists))) | (None, Ok(dists)) =>
| (Some(Ok(r)), Ok(dists)) => let hash: ExpressionTypes.ExpressionTree.node =
Ok( `FunctionCall(("multimodal", [|`Hash(
`FunctionCall(("multimodal", E.A.append([|`Array(r)|], dists))), [|
) ("dists", `Array(dists)),
("weights", `Array([||]))
|]
)|]));
Ok(hash);
| (Some(Ok(weights)), Ok(dists)) =>
let hash: ExpressionTypes.ExpressionTree.node =
`FunctionCall(("multimodal", [|`Hash(
[|
("dists", `Array(dists)),
("weights", `Array(weights))
|]
)|]));
Ok(hash);
}; };
| "add"
| "subtract"
| "multiply"
| "unaryMinus"
| "dotMultiply"
| "dotPow"
| "rightLogShift"
| "divide"
| "pow"
| "leftTruncate"
| "rightTruncate"
| "truncate" => operationParser(name, parseArgs())
| name => | name =>
parseArgs() parseArgs()
|> E.R.fmap((args: array(ExpressionTypes.ExpressionTree.node)) => |> E.R.fmap((args: array(ExpressionTypes.ExpressionTree.node)) =>

View File

@ -6,127 +6,186 @@ type samplingDist = [
| `RenderedDist(DistTypes.shape) | `RenderedDist(DistTypes.shape)
]; ];
type t = [ type _type = [
| `Float | `Float
| `SamplingDistribution | `SamplingDistribution
| `RenderedDistribution | `RenderedDistribution
| `Array(t) | `Array(_type)
| `Named(array((string, t))) | `Named(array((string, _type)))
]; ];
type tx = [
type typedValue = [
| `Float(float) | `Float(float)
| `RenderedDist(DistTypes.shape) | `RenderedDist(DistTypes.shape)
| `SamplingDist(samplingDist) | `SamplingDist(samplingDist)
| `Array(array(tx)) | `Array(array(typedValue))
| `Named(array((string, tx))) | `Named(array((string, typedValue)))
]; ];
type fn = { type _function = {
name: string, name: string,
inputTypes: array(t), inputTypes: array(_type),
outputType: t, outputType: _type,
run: array(tx) => result(node, string), run: array(typedValue) => result(node, string),
shouldCoerceTypes: bool,
};
type functions = array(_function);
type inputNodes = array(node);
module TypedValue = {
let rec fromNode = (node: node): result(typedValue, string) =>
switch (ExpressionTypes.ExpressionTree.toFloatIfNeeded(node)) {
| `SymbolicDist(`Float(r)) => Ok(`Float(r))
| `SymbolicDist(s) => Ok(`SamplingDist(`SymbolicDist(s)))
| `RenderedDist(s) => Ok(`RenderedDist(s))
| `Array(r) =>
r
|> E.A.fmap(fromNode)
|> E.A.R.firstErrorOrOpen
|> E.R.fmap(r => `Array(r))
| `Hash(hash) =>
hash
|> E.A.fmap(((name, t)) => fromNode(t) |> E.R.fmap(r => (name, r)))
|> E.A.R.firstErrorOrOpen
|> E.R.fmap(r => `Named(r))
| _ => Error("Wrong type")
};
// todo: Arrays and hashes
let rec fromNodeWithTypeCoercion = (evaluationParams, _type: _type, node) => {
Js.log3("With Coersion!", _type, node);
switch (_type, node) {
| (`Float, _) =>
switch (getFloat(node)) {
| Some(a) => Ok(`Float(a))
| _ => Error("Type Error: Expected float.")
}
| (`SamplingDistribution, _) =>
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
evaluationParams,
node,
)
|> E.R.bind(_, fromNode)
| (`RenderedDistribution, _) =>
ExpressionTypes.ExpressionTree.Render.render(evaluationParams, node)
|> E.R.bind(_, fromNode)
| (`Array(_type), `Array(b)) =>
b
|> E.A.fmap(fromNodeWithTypeCoercion(evaluationParams, _type))
|> E.A.R.firstErrorOrOpen
|> E.R.fmap(r => `Array(r))
| (`Named(named), `Hash(r)) =>
Js.log3("Named", named, r);
let foo =
named
|> E.A.fmap(((name, intendedType)) =>
(
name,
intendedType,
ExpressionTypes.ExpressionTree.Hash.getByName(r, name),
)
);
Js.log("Named: part 2");
let bar =
foo
|> E.A.fmap(((name, intendedType, optionNode)) =>
switch (optionNode) {
| Some(node) =>
fromNodeWithTypeCoercion(evaluationParams, intendedType, node)
|> E.R.fmap(node => (name, node))
| None => Error("Hash parameter not present in hash.")
}
)
|> E.A.R.firstErrorOrOpen
|> E.R.fmap(r => `Named(r));
Js.log3("Named!", foo, bar);
bar;
| _ => Error("fromNodeWithTypeCoercion error, sorry.")
};
};
}; };
module Function = { module Function = {
let make = (~name, ~inputTypes, ~outputType, ~run): fn => { type t = _function;
name, type ts = functions;
inputTypes,
outputType,
run,
};
};
type fns = array(fn); module T = {
type inputTypes = array(node); let make =
(~name, ~inputTypes, ~outputType, ~run, ~shouldCoerceTypes=true, _): t => {
name,
inputTypes,
outputType,
run,
shouldCoerceTypes,
};
let rec fromNodeDirect = (node: node): result(tx, string) => let _inputLengthCheck = (inputNodes: inputNodes, t: t) => {
switch (ExpressionTypes.ExpressionTree.toFloatIfNeeded(node)) { let expectedLength = E.A.length(t.inputTypes);
| `SymbolicDist(`Float(r)) => Ok(`Float(r)) let actualLength = E.A.length(inputNodes);
| `SymbolicDist(s) => Ok(`SamplingDist(`SymbolicDist(s))) expectedLength == actualLength
| `RenderedDist(s) => Ok(`RenderedDist(s)) ? Ok(inputNodes)
| `Array(r) => : Error(
r "Wrong number of inputs. Expected"
|> E.A.fmap(fromNodeDirect) ++ (expectedLength |> E.I.toString)
|> E.A.R.firstErrorOrOpen ++ ". Got:"
|> E.R.fmap(r => `Array(r)) ++ (actualLength |> E.I.toString),
| `Hash(hash) => );
hash };
|> E.A.fmap(((name, t)) =>
fromNodeDirect(t) |> E.R.fmap(r => (name, r))
)
|> E.A.R.firstErrorOrOpen
|> E.R.fmap(r => `Named(r))
| _ => Error("Wrong type")
};
let compareInput = (evaluationParams, t: t, node) => let _coerceInputNodes =
switch (t) { (evaluationParams, inputTypes, shouldCoerce, inputNodes) =>
| `Float => Belt.Array.zip(inputTypes, inputNodes)
switch (getFloat(node)) {
| Some(a) => Ok(`Float(a))
| _ => Error("Type Error: Expected float.")
}
| `SamplingDistribution =>
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
evaluationParams,
node,
)
|> E.R.bind(_, fromNodeDirect)
| `RenderedDistribution =>
ExpressionTypes.ExpressionTree.Render.render(evaluationParams, node)
|> E.R.bind(_, fromNodeDirect)
| _ => {
Js.log4("Type error: Expected ", t, ", got ", node);
Error("Bad input, sorry.")}
};
let sanatizeInputs =
(
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
inputTypes: inputTypes,
t: fn,
) => {
E.A.length(t.inputTypes) == E.A.length(inputTypes)
? Belt.Array.zip(t.inputTypes, inputTypes)
|> E.A.fmap(((def, input)) => |> E.A.fmap(((def, input)) =>
compareInput(evaluationParams, def, input) shouldCoerce
? TypedValue.fromNodeWithTypeCoercion(
evaluationParams,
def,
input,
)
: TypedValue.fromNode(input)
) )
|> (r => {Js.log2("Inputs", r); r}) |> E.A.R.firstErrorOrOpen;
|> E.A.R.firstErrorOrOpen
: Error(
"Wrong number of inputs. Expected"
++ (E.A.length(t.inputTypes) |> E.I.toString)
++ ". Got:"
++ (E.A.length(inputTypes) |> E.I.toString),
);
};
let run = let inputsToTypedValues =
( (
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams, evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
inputTypes: inputTypes, inputNodes: inputNodes,
t: fn, t: t,
) => { ) => {
let _sanitizedInputs = sanatizeInputs(evaluationParams, inputTypes, t); _inputLengthCheck(inputNodes, t)
_sanitizedInputs |> E.R.bind(_,t.run) ->E.R.bind(
|> ( _coerceInputNodes(
fun evaluationParams,
| Ok(i) => Ok(i) t.inputTypes,
| Error(r) => { t.shouldCoerceTypes,
Js.log4( ),
"Error",
inputTypes,
t,
_sanitizedInputs
); );
Error("Function " ++ t.name ++ " error: " ++ r); };
}
); let run =
(
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
inputNodes: inputNodes,
t: t,
) => {
Js.log("Running!");
inputsToTypedValues(evaluationParams, inputNodes, t)->E.R.bind(t.run)
|> (
fun
| Ok(i) => Ok(i)
| Error(r) => {
Error("Function " ++ t.name ++ " error: " ++ r);
}
);
};
};
module Ts = {
let findByName = (ts: ts, n: string) =>
ts |> Belt.Array.getBy(_, ({name}) => name == n);
let findByNameAndRun = (ts: ts, n: string, evaluationParams, inputTypes) =>
findByName(ts, n) |> E.O.fmap(T.run(evaluationParams, inputTypes));
};
}; };
let getFn = (fns: fns, n: string) =>
fns |> Belt.Array.getBy(_, ({name}) => name == n);
let getAndRun = (fns: fns, n: string, evaluationParams, inputTypes) =>
getFn(fns, n) |> E.O.fmap(run(evaluationParams, inputTypes));