Simple functionality without weights

This commit is contained in:
Ozzie Gooen 2020-11-06 19:31:38 -08:00
parent cf36594b4a
commit 8d55bba2ca
7 changed files with 307 additions and 188 deletions

View File

@ -233,7 +233,7 @@ let make = () => {
~onSubmit=({state}) => {None},
~initialState={
//guesstimatorString: "mm(normal(-10, 2), uniform(18, 25), lognormal({mean: 10, stdev: 8}), triangular(31,40,50))",
guesstimatorString: "mm(3,4)",
guesstimatorString: "mm(normal(5,2))",
domainType: "Complete",
xPoint: "50.0",
xPoint2: "60.0",

View File

@ -206,6 +206,7 @@ module Truncate = {
module Normalize = {
let rec operationToLeaf = (evaluationParams, t: node): result(node, string) => {
Js.log2("normalize", t);
switch (t) {
| `RenderedDist(s) => Ok(`RenderedDist(Shape.T.normalize(s)))
| `SymbolicDist(_) => Ok(t)
@ -225,6 +226,7 @@ let callableFunction = (evaluationParams, name, args) => {
module Render = {
let rec operationToLeaf =
(evaluationParams: evaluationParams, t: node): result(t, string) => {
Js.log2("rendering", t);
switch (t) {
| `Function(_) => Error("Cannot render a function")
| `SymbolicDist(d) =>
@ -254,17 +256,18 @@ let rec toLeaf =
node: t,
)
: result(t, string) => {
Js.log2("node",node);
Js.log2("node", node);
switch (node) {
// Leaf nodes just stay leaf nodes
| `SymbolicDist(_)
| `Function(_)
| `RenderedDist(_) => Ok(node)
| `Array(args) =>
Js.log2("Array!", args);
args
|> E.A.fmap(toLeaf(evaluationParams))
|> E.A.R.firstErrorOrOpen
|> E.R.fmap(r => `Array(r))
|> E.R.fmap(r => `Array(r));
// Operations nevaluationParamsd to be turned into leaves
| `AlgebraicCombination(algebraicOp, t1, t2) =>
AlgebraicCombination.operationToLeaf(
@ -285,34 +288,25 @@ let rec toLeaf =
| `Normalize(t) => Normalize.operationToLeaf(evaluationParams, t)
| `Render(t) => Render.operationToLeaf(evaluationParams, t)
| `Hash(t) =>
Js.log("In hash");
t
|> E.A.fmap(((name: string, node: node)) =>
toLeaf(evaluationParams, node) |> E.R.fmap(r => (name, r))
)
|> E.A.R.firstErrorOrOpen
|> E.R.fmap(r => `Hash(r))
|> E.R.fmap(r => `Hash(r));
| `Symbol(r) =>
Js.log("Symbol");
ExpressionTypes.ExpressionTree.Environment.get(
evaluationParams.environment,
r,
)
|> E.O.toResult("Undeclared variable " ++ r)
|> E.R.bind(_, toLeaf(evaluationParams))
|> E.R.bind(_, toLeaf(evaluationParams));
| `FunctionCall(name, args) =>
Js.log3("In function call", name, args);
callableFunction(evaluationParams, name, args)
|> E.R.bind(_, toLeaf(evaluationParams))
| `MultiModal(r) =>
let components =
r
|> E.A.fmap(((dist, weight)) =>
`FunctionCall("scaleMultiply", [|dist, `SymbolicDist(`Float(weight))|]));
let pointwiseSum =
components
|> Js.Array.sliceFrom(1)
|> E.A.fold_left(
(acc, x) => {`PointwiseCombination((`Add, acc, x))},
E.A.unsafe_get(components, 0),
);
Ok(`Normalize(pointwiseSum)) |> E.R.bind(_, toLeaf(evaluationParams));
|> E.R.bind(_, toLeaf(evaluationParams));
| `MultiModal(r) => Error("Multimodal?")
};
};

View File

@ -16,11 +16,12 @@ type distToFloatOperation = [
];
module ExpressionTree = {
type node = [
type hash = array((string, node))
and node = [
| `SymbolicDist(SymbolicTypes.symbolicDist)
| `RenderedDist(DistTypes.shape)
| `Symbol(string)
| `Hash(array((string, node)))
| `Hash(hash)
| `Array(array(node))
| `Function(array(string), node)
| `AlgebraicCombination(algebraicOperation, node, node)
@ -31,16 +32,31 @@ module ExpressionTree = {
| `FunctionCall(string, array(node))
| `MultiModal(array((node, float)))
];
// Have nil as option
let getFloat = (node:node) => node |> fun
| `RenderedDist(Discrete({xyShape: {xs: [|x|], ys: [|1.0|]}})) => Some(x)
| `SymbolicDist(`Float(x)) => Some(x)
| _ => None
let toFloatIfNeeded = (node:node) => switch(node |> getFloat){
module Hash = {
type t('a) = array((string, 'a));
let getByName = (t:t('a), name) =>
E.A.getBy(t, ((n, _)) => n == name) |> E.O.fmap(((_, r)) => r);
let getByNames = (hash: t('a), names:array(string)) =>
names |> E.A.fmap(name => (name, getByName(hash, name)))
};
// Have nil as option
let getFloat = (node: node) =>
node
|> (
fun
| `RenderedDist(Discrete({xyShape: {xs: [|x|], ys: [|1.0|]}})) =>
Some(x)
| `SymbolicDist(`Float(x)) => Some(x)
| _ => None
);
let toFloatIfNeeded = (node: node) =>
switch (node |> getFloat) {
| Some(float) => `SymbolicDist(`Float(float))
| None => node
}
};
type samplingInputs = {
sampleCount: int,
@ -97,7 +113,6 @@ module ExpressionTree = {
|> evaluationParams.evaluateNode(evaluationParams)
|> E.R.bind(_, fn(evaluationParams));
module Render = {
type t = node;

View File

@ -15,7 +15,7 @@ let to_: (float, float) => result(node, string) =
};
let makeSymbolicFromTwoFloats = (name, fn) =>
Function.make(
Function.T.make(
~name,
~outputType=`SamplingDistribution,
~inputTypes=[|`Float, `Float|],
@ -23,10 +23,11 @@ let makeSymbolicFromTwoFloats = (name, fn) =>
fun
| [|`Float(a), `Float(b)|] => Ok(`SymbolicDist(fn(a, b)))
| e => wrongInputsError(e),
(),
);
let makeSymbolicFromOneFloat = (name, fn) =>
Function.make(
Function.T.make(
~name,
~outputType=`SamplingDistribution,
~inputTypes=[|`Float|],
@ -34,10 +35,11 @@ let makeSymbolicFromOneFloat = (name, fn) =>
fun
| [|`Float(a)|] => Ok(`SymbolicDist(fn(a)))
| e => wrongInputsError(e),
(),
);
let makeDistFloat = (name, fn) =>
Function.make(
Function.T.make(
~name,
~outputType=`SamplingDistribution,
~inputTypes=[|`SamplingDistribution, `Float|],
@ -45,21 +47,23 @@ let makeDistFloat = (name, fn) =>
fun
| [|`SamplingDist(a), `Float(b)|] => fn(a, b)
| e => wrongInputsError(e),
(),
);
let makeRenderedDistFloat = (name, fn) =>
Function.make(
Function.T.make(
~name,
~outputType=`RenderedDistribution,
~inputTypes=[|`RenderedDistribution, `Float|],
~run=
fun
| [|`RenderedDist(a), `Float(b)|] => fn(a, b)
| e => wrongInputsError(e)
| e => wrongInputsError(e),
(),
);
let makeDist = (name, fn) =>
Function.make(
Function.T.make(
~name,
~outputType=`SamplingDistribution,
~inputTypes=[|`SamplingDistribution|],
@ -67,6 +71,7 @@ let makeDist = (name, fn) =>
fun
| [|`SamplingDist(a)|] => fn(a)
| e => wrongInputsError(e),
(),
);
let floatFromDist =
@ -112,7 +117,7 @@ let functions = [|
SymbolicDist.Lognormal.fromMeanAndStdev,
),
makeSymbolicFromOneFloat("exponential", SymbolicDist.Exponential.make),
Function.make(
Function.T.make(
~name="to",
~outputType=`SamplingDistribution,
~inputTypes=[|`Float, `Float|],
@ -120,8 +125,9 @@ let functions = [|
fun
| [|`Float(a), `Float(b)|] => to_(a, b)
| e => wrongInputsError(e),
(),
),
Function.make(
Function.T.make(
~name="triangular",
~outputType=`SamplingDistribution,
~inputTypes=[|`Float, `Float, `Float|],
@ -131,13 +137,14 @@ let functions = [|
SymbolicDist.Triangular.make(a, b, c)
|> E.R.fmap(r => `SymbolicDist(r))
| e => wrongInputsError(e),
(),
),
makeDistFloat("pdf", (dist, float) => floatFromDist(`Pdf(float), dist)),
makeDistFloat("inv", (dist, float) => floatFromDist(`Inv(float), dist)),
makeDistFloat("cdf", (dist, float) => floatFromDist(`Cdf(float), dist)),
makeDist("mean", dist => floatFromDist(`Mean, dist)),
makeDist("sample", dist => floatFromDist(`Sample, dist)),
Function.make(
Function.T.make(
~name="render",
~outputType=`RenderedDistribution,
~inputTypes=[|`RenderedDistribution|],
@ -145,8 +152,9 @@ let functions = [|
fun
| [|`RenderedDist(c)|] => Ok(`RenderedDist(c))
| e => wrongInputsError(e),
(),
),
Function.make(
Function.T.make(
~name="normalize",
~outputType=`SamplingDistribution,
~inputTypes=[|`SamplingDistribution|],
@ -156,6 +164,7 @@ let functions = [|
| [|`SamplingDist(`RenderedDist(c))|] =>
Ok(`RenderedDist(Shape.T.normalize(c)))
| e => wrongInputsError(e),
(),
),
makeRenderedDistFloat("scaleExp", (dist, float) =>
verticalScaling(`Exponentiate, dist, float)
@ -166,4 +175,42 @@ let functions = [|
makeRenderedDistFloat("scaleLog", (dist, float) =>
verticalScaling(`Log, dist, float)
),
Function.T.make(
~name="multimodal",
~outputType=`SamplingDistribution,
~inputTypes=[|
`Named([|
("dists", `Array(`SamplingDistribution)),
("weights", `Array(`Float)),
|]),
|],
~run=
fun
| [|`Named(r)|] => {
let foo =
(r: TypeSystem.typedValue)
: result(ExpressionTypes.ExpressionTree.node, string) =>
switch (r) {
| `SamplingDist(`SymbolicDist(c)) => Ok(`SymbolicDist(c))
| `SamplingDist(`RenderedDist(c)) => Ok(`RenderedDist(c))
| _ => Error("")
};
switch (ExpressionTypes.ExpressionTree.Hash.getByName(r, "dists")) {
| Some(`Array(r)) =>
r
|> E.A.fmap(foo)
|> E.A.R.firstErrorOrOpen
|> E.R.fmap(distributions => {
distributions
|> E.A.fold_left(
(acc, x) => {`PointwiseCombination((`Add, acc, x))},
E.A.unsafe_get(distributions, 0),
)
})
| _ => Error("")
};
}
| _ => Error(""),
(),
),
|];

View File

@ -8,9 +8,13 @@ let fnn =
name,
args: array(node),
) => {
let trySomeFns =
TypeSystem.getAndRun(Fns.functions, name, evaluationParams, args);
switch (trySomeFns) {
// let foundFn =
// TypeSystem.Function.Ts.findByName(Fns.functions, name) |> E.O.toResult("Function " ++ name ++ " not found");
// let ran = foundFn |> E.R.bind(_,TypeSystem.Function.T.run(evaluationParams,args))
let foundFn =
TypeSystem.Function.Ts.findByNameAndRun(Fns.functions, name, evaluationParams, args);
switch (foundFn) {
| Some(r) => r
| None =>
switch (
@ -21,32 +25,8 @@ let fnn =
),
) {
| (_, Some(`Function(argNames, tt))) =>
Js.log("Fundction found: " ++ name);
PTypes.Function.run(evaluationParams, args, (argNames, tt))
| ("multimodal", _) =>
switch (args |> E.A.to_list) {
| [`Array(weights), ...dists] =>
let withWeights =
dists
|> E.L.toArray
|> E.A.fmapi((index, t) => {
let w =
weights
|> E.A.get(_, index)
|> E.O.bind(_, getFloat)
|> E.O.default(1.0);
(t, w);
});
Ok(`MultiModal(withWeights));
| dists when E.L.length(dists) > 0 =>
Ok(
`MultiModal(
dists
|> E.L.toArray
|> E.A.fmap(r => (r, 1.0)),
),
)
| _ => Error("Needs at least one distribution")
}
| _ => Error("Function " ++ name ++ " not found")
}
};

View File

@ -198,13 +198,36 @@ module MathAdtToDistDst = {
});
};
let functionParser = (nodeParser, name, args) => {
let functionParser =
(
nodeParser:
MathJsonToMathJsAdt.arg =>
Belt.Result.t(
ProbExample.ExpressionTypes.ExpressionTree.node,
string,
),
name: string,
args: array(MathJsonToMathJsAdt.arg),
)
: result(ExpressionTypes.ExpressionTree.node, string) => {
let parseArray = ags =>
ags |> E.A.fmap(nodeParser) |> E.A.R.firstErrorOrOpen;
let parseArgs = () => parseArray(args);
switch (name) {
| "lognormal" => lognormal(args, parseArgs, nodeParser)
| "multimodal"
| "add"
| "subtract"
| "multiply"
| "unaryMinus"
| "dotMultiply"
| "dotPow"
| "rightLogShift"
| "divide"
| "pow"
| "leftTruncate"
| "rightTruncate"
| "truncate" => operationParser(name, parseArgs())
| "mm" =>
let weights =
args
@ -223,24 +246,25 @@ module MathAdtToDistDst = {
switch (weights, dists) {
| (Some(Error(r)), _) => Error(r)
| (_, Error(r)) => Error(r)
| (None, Ok(dists)) => Ok(`FunctionCall(("multimodal", dists)))
| (Some(Ok(r)), Ok(dists)) =>
Ok(
`FunctionCall(("multimodal", E.A.append([|`Array(r)|], dists))),
)
| (None, Ok(dists)) =>
let hash: ExpressionTypes.ExpressionTree.node =
`FunctionCall(("multimodal", [|`Hash(
[|
("dists", `Array(dists)),
("weights", `Array([||]))
|]
)|]));
Ok(hash);
| (Some(Ok(weights)), Ok(dists)) =>
let hash: ExpressionTypes.ExpressionTree.node =
`FunctionCall(("multimodal", [|`Hash(
[|
("dists", `Array(dists)),
("weights", `Array(weights))
|]
)|]));
Ok(hash);
};
| "add"
| "subtract"
| "multiply"
| "unaryMinus"
| "dotMultiply"
| "dotPow"
| "rightLogShift"
| "divide"
| "pow"
| "leftTruncate"
| "rightTruncate"
| "truncate" => operationParser(name, parseArgs())
| name =>
parseArgs()
|> E.R.fmap((args: array(ExpressionTypes.ExpressionTree.node)) =>

View File

@ -6,127 +6,186 @@ type samplingDist = [
| `RenderedDist(DistTypes.shape)
];
type t = [
type _type = [
| `Float
| `SamplingDistribution
| `RenderedDistribution
| `Array(t)
| `Named(array((string, t)))
| `Array(_type)
| `Named(array((string, _type)))
];
type tx = [
type typedValue = [
| `Float(float)
| `RenderedDist(DistTypes.shape)
| `SamplingDist(samplingDist)
| `Array(array(tx))
| `Named(array((string, tx)))
| `Array(array(typedValue))
| `Named(array((string, typedValue)))
];
type fn = {
type _function = {
name: string,
inputTypes: array(t),
outputType: t,
run: array(tx) => result(node, string),
inputTypes: array(_type),
outputType: _type,
run: array(typedValue) => result(node, string),
shouldCoerceTypes: bool,
};
type functions = array(_function);
type inputNodes = array(node);
module TypedValue = {
let rec fromNode = (node: node): result(typedValue, string) =>
switch (ExpressionTypes.ExpressionTree.toFloatIfNeeded(node)) {
| `SymbolicDist(`Float(r)) => Ok(`Float(r))
| `SymbolicDist(s) => Ok(`SamplingDist(`SymbolicDist(s)))
| `RenderedDist(s) => Ok(`RenderedDist(s))
| `Array(r) =>
r
|> E.A.fmap(fromNode)
|> E.A.R.firstErrorOrOpen
|> E.R.fmap(r => `Array(r))
| `Hash(hash) =>
hash
|> E.A.fmap(((name, t)) => fromNode(t) |> E.R.fmap(r => (name, r)))
|> E.A.R.firstErrorOrOpen
|> E.R.fmap(r => `Named(r))
| _ => Error("Wrong type")
};
// todo: Arrays and hashes
let rec fromNodeWithTypeCoercion = (evaluationParams, _type: _type, node) => {
Js.log3("With Coersion!", _type, node);
switch (_type, node) {
| (`Float, _) =>
switch (getFloat(node)) {
| Some(a) => Ok(`Float(a))
| _ => Error("Type Error: Expected float.")
}
| (`SamplingDistribution, _) =>
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
evaluationParams,
node,
)
|> E.R.bind(_, fromNode)
| (`RenderedDistribution, _) =>
ExpressionTypes.ExpressionTree.Render.render(evaluationParams, node)
|> E.R.bind(_, fromNode)
| (`Array(_type), `Array(b)) =>
b
|> E.A.fmap(fromNodeWithTypeCoercion(evaluationParams, _type))
|> E.A.R.firstErrorOrOpen
|> E.R.fmap(r => `Array(r))
| (`Named(named), `Hash(r)) =>
Js.log3("Named", named, r);
let foo =
named
|> E.A.fmap(((name, intendedType)) =>
(
name,
intendedType,
ExpressionTypes.ExpressionTree.Hash.getByName(r, name),
)
);
Js.log("Named: part 2");
let bar =
foo
|> E.A.fmap(((name, intendedType, optionNode)) =>
switch (optionNode) {
| Some(node) =>
fromNodeWithTypeCoercion(evaluationParams, intendedType, node)
|> E.R.fmap(node => (name, node))
| None => Error("Hash parameter not present in hash.")
}
)
|> E.A.R.firstErrorOrOpen
|> E.R.fmap(r => `Named(r));
Js.log3("Named!", foo, bar);
bar;
| _ => Error("fromNodeWithTypeCoercion error, sorry.")
};
};
};
module Function = {
let make = (~name, ~inputTypes, ~outputType, ~run): fn => {
name,
inputTypes,
outputType,
run,
};
};
type t = _function;
type ts = functions;
type fns = array(fn);
type inputTypes = array(node);
module T = {
let make =
(~name, ~inputTypes, ~outputType, ~run, ~shouldCoerceTypes=true, _): t => {
name,
inputTypes,
outputType,
run,
shouldCoerceTypes,
};
let rec fromNodeDirect = (node: node): result(tx, string) =>
switch (ExpressionTypes.ExpressionTree.toFloatIfNeeded(node)) {
| `SymbolicDist(`Float(r)) => Ok(`Float(r))
| `SymbolicDist(s) => Ok(`SamplingDist(`SymbolicDist(s)))
| `RenderedDist(s) => Ok(`RenderedDist(s))
| `Array(r) =>
r
|> E.A.fmap(fromNodeDirect)
|> E.A.R.firstErrorOrOpen
|> E.R.fmap(r => `Array(r))
| `Hash(hash) =>
hash
|> E.A.fmap(((name, t)) =>
fromNodeDirect(t) |> E.R.fmap(r => (name, r))
)
|> E.A.R.firstErrorOrOpen
|> E.R.fmap(r => `Named(r))
| _ => Error("Wrong type")
};
let _inputLengthCheck = (inputNodes: inputNodes, t: t) => {
let expectedLength = E.A.length(t.inputTypes);
let actualLength = E.A.length(inputNodes);
expectedLength == actualLength
? Ok(inputNodes)
: Error(
"Wrong number of inputs. Expected"
++ (expectedLength |> E.I.toString)
++ ". Got:"
++ (actualLength |> E.I.toString),
);
};
let compareInput = (evaluationParams, t: t, node) =>
switch (t) {
| `Float =>
switch (getFloat(node)) {
| Some(a) => Ok(`Float(a))
| _ => Error("Type Error: Expected float.")
}
| `SamplingDistribution =>
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
evaluationParams,
node,
)
|> E.R.bind(_, fromNodeDirect)
| `RenderedDistribution =>
ExpressionTypes.ExpressionTree.Render.render(evaluationParams, node)
|> E.R.bind(_, fromNodeDirect)
| _ => {
Js.log4("Type error: Expected ", t, ", got ", node);
Error("Bad input, sorry.")}
};
let sanatizeInputs =
(
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
inputTypes: inputTypes,
t: fn,
) => {
E.A.length(t.inputTypes) == E.A.length(inputTypes)
? Belt.Array.zip(t.inputTypes, inputTypes)
let _coerceInputNodes =
(evaluationParams, inputTypes, shouldCoerce, inputNodes) =>
Belt.Array.zip(inputTypes, inputNodes)
|> E.A.fmap(((def, input)) =>
compareInput(evaluationParams, def, input)
shouldCoerce
? TypedValue.fromNodeWithTypeCoercion(
evaluationParams,
def,
input,
)
: TypedValue.fromNode(input)
)
|> (r => {Js.log2("Inputs", r); r})
|> E.A.R.firstErrorOrOpen
: Error(
"Wrong number of inputs. Expected"
++ (E.A.length(t.inputTypes) |> E.I.toString)
++ ". Got:"
++ (E.A.length(inputTypes) |> E.I.toString),
);
};
|> E.A.R.firstErrorOrOpen;
let run =
(
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
inputTypes: inputTypes,
t: fn,
) => {
let _sanitizedInputs = sanatizeInputs(evaluationParams, inputTypes, t);
_sanitizedInputs |> E.R.bind(_,t.run)
|> (
fun
| Ok(i) => Ok(i)
| Error(r) => {
Js.log4(
"Error",
inputTypes,
t,
_sanitizedInputs
let inputsToTypedValues =
(
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
inputNodes: inputNodes,
t: t,
) => {
_inputLengthCheck(inputNodes, t)
->E.R.bind(
_coerceInputNodes(
evaluationParams,
t.inputTypes,
t.shouldCoerceTypes,
),
);
Error("Function " ++ t.name ++ " error: " ++ r);
}
);
};
let run =
(
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
inputNodes: inputNodes,
t: t,
) => {
Js.log("Running!");
inputsToTypedValues(evaluationParams, inputNodes, t)->E.R.bind(t.run)
|> (
fun
| Ok(i) => Ok(i)
| Error(r) => {
Error("Function " ++ t.name ++ " error: " ++ r);
}
);
};
};
module Ts = {
let findByName = (ts: ts, n: string) =>
ts |> Belt.Array.getBy(_, ({name}) => name == n);
let findByNameAndRun = (ts: ts, n: string, evaluationParams, inputTypes) =>
findByName(ts, n) |> E.O.fmap(T.run(evaluationParams, inputTypes));
};
};
let getFn = (fns: fns, n: string) =>
fns |> Belt.Array.getBy(_, ({name}) => name == n);
let getAndRun = (fns: fns, n: string, evaluationParams, inputTypes) =>
getFn(fns, n) |> E.O.fmap(run(evaluationParams, inputTypes));