Starting simple type system

This commit is contained in:
Ozzie Gooen 2020-08-14 21:48:41 +01:00
parent 6d36284283
commit c002100461
5 changed files with 281 additions and 182 deletions

View File

@ -1,11 +1,5 @@
open ExpressionTypes.ExpressionTree;
let envs = (samplingInputs, environment) => {
{samplingInputs, environment, evaluateNode: ExpressionTreeEvaluator.toLeaf};
};
let toLeaf = (samplingInputs, environment, node: node) =>
ExpressionTreeEvaluator.toLeaf(envs(samplingInputs, environment), node);
let rec toString: node => string =
fun
| `SymbolicDist(d) => SymbolicDist.T.toString(d)
@ -33,9 +27,16 @@ let rec toString: node => string =
++ (args |> Js.String.concatMany(_, ","))
++ toString(internal)
++ ")]"
| `Array(args) => "Array"
| `MultiModal(args) => "Multimodal"
| `Array(_) => "Array"
| `MultiModal(_) => "Multimodal"
| `Hash(_) => "Hash"
let envs = (samplingInputs, environment) => {
{samplingInputs, environment, evaluateNode: ExpressionTreeEvaluator.toLeaf};
};
let toLeaf = (samplingInputs, environment, node: node) =>
ExpressionTreeEvaluator.toLeaf(envs(samplingInputs, environment), node);
let toShape = (samplingInputs, environment, node: node) => {
switch (toLeaf(samplingInputs, environment, node)) {
| Ok(`RenderedDist(shape)) => Ok(shape)

View File

@ -21,6 +21,7 @@ module ExpressionTree = {
| `RenderedDist(DistTypes.shape)
| `Symbol(string)
| `Hash(array((string, node)))
| `Array(array(node))
| `Function(array(string), node)
| `AlgebraicCombination(algebraicOperation, node, node)
| `PointwiseCombination(pointwiseOperation, node, node)
@ -30,7 +31,6 @@ module ExpressionTree = {
| `Truncate(option(float), option(float), node)
| `FloatFromDist(distToFloatOperation, node)
| `FunctionCall(string, array(node))
| `Array(array(node))
| `MultiModal(array((node, float)))
];
// Have nil as option

View File

@ -0,0 +1,109 @@
open TypeSystem;
let wrongInputsError = (r) => {Js.log2("Wrong inputs", r); Error("Wrong inputs")};
let to_: (float, float) => result(node, string) =
(low, high) =>
switch (low, high) {
| (low, high) when low <= 0.0 && low < high =>
Ok(`SymbolicDist(SymbolicDist.Normal.from90PercentCI(low, high)))
| (low, high) when low < high =>
Ok(`SymbolicDist(SymbolicDist.Lognormal.from90PercentCI(low, high)))
| (_, _) => Error("Low value must be less than high value.")
};
let makeSymbolicFromTwoFloats = (name, fn) =>
Function.make(
~name,
~output=`SamplingDistribution,
~inputs=[|`Float, `Float|],
~run=
fun
| [|`Float(a), `Float(b)|] => Ok(`SymbolicDist(fn(a, b)))
| e => wrongInputsError(e)
);
let makeSymbolicFromOneFloat = (name, fn) =>
Function.make(
~name,
~output=`SamplingDistribution,
~inputs=[|`Float|],
~run=
fun
| [|`Float(a)|] => Ok(`SymbolicDist(fn(a)))
| e => wrongInputsError(e)
);
let makeDistFloat = (name, fn) =>
Function.make(
~name,
~output=`SamplingDistribution,
~inputs=[|`SamplingDistribution, `Float|],
~run=
fun
| [|`SamplingDist(a), `Float(b)|] => (fn(a,b))
| e => wrongInputsError(e)
);
let makeDist = (name, fn) =>
Function.make(
~name,
~output=`SamplingDistribution,
~inputs=[|`SamplingDistribution|],
~run=
fun
| [|`SamplingDist(a)|] => fn(a)
| e => wrongInputsError(e)
);
let floatFromDist =
(
distToFloatOp: ExpressionTypes.distToFloatOperation,
t: TypeSystem.samplingDist,
)
: result(node, string) => {
switch (t) {
| `SymbolicDist(s) =>
SymbolicDist.T.operate(distToFloatOp, s)
|> E.R.bind(_, v => Ok(`SymbolicDist(`Float(v))))
| `RenderedDist(rs) =>
Shape.operate(distToFloatOp, rs) |> (v => Ok(`SymbolicDist(`Float(v))))
};
};
let functions = [|
makeSymbolicFromTwoFloats("normal", SymbolicDist.Normal.make),
makeSymbolicFromTwoFloats("uniform", SymbolicDist.Uniform.make),
makeSymbolicFromTwoFloats("beta", SymbolicDist.Beta.make),
makeSymbolicFromTwoFloats("lognormal", SymbolicDist.Lognormal.make),
makeSymbolicFromTwoFloats(
"lognormalFromMeanAndStdDev",
SymbolicDist.Lognormal.fromMeanAndStdev,
),
makeSymbolicFromOneFloat("exponential", SymbolicDist.Exponential.make),
Function.make(
~name="to",
~output=`SamplingDistribution,
~inputs=[|`Float, `Float|],
~run=
fun
| [|`Float(a), `Float(b)|] => to_(a,b)
| e => wrongInputsError(e)
),
Function.make(
~name="triangular",
~output=`SamplingDistribution,
~inputs=[|`Float, `Float, `Float|],
~run=
fun
| [|`Float(a), `Float(b), `Float(c)|] =>
SymbolicDist.Triangular.make(a, b, c)
|> E.R.fmap(r => `SymbolicDist(r))
| e => wrongInputsError(e)
),
makeDistFloat("pdf", (dist, float) => floatFromDist(`Pdf(float), dist)),
makeDistFloat("inv", (dist, float) => floatFromDist(`Inv(float), dist)),
makeDistFloat("cdf", (dist, float) => floatFromDist(`Cdf(float), dist)),
makeDist("mean", (dist) => floatFromDist(`Mean, dist)),
makeDist("sample", (dist) => floatFromDist(`Sample, dist))
|];

View File

@ -2,185 +2,47 @@ type node = ExpressionTypes.ExpressionTree.node;
let toOkSym = r => Ok(`SymbolicDist(r));
let getFloat = ExpressionTypes.ExpressionTree.getFloat;
let twoFloats = (fn, n1: node, n2: node): result(node, string) =>
switch (getFloat(n1), getFloat(n2)) {
| (Some(a), Some(b)) => fn(a, b)
| _ => Error("Function needed two floats, missing them.")
};
let threeFloats = (fn, n1: node, n2: node, n3: node): result(node, string) =>
switch (getFloat(n1), getFloat(n2), getFloat(n3)) {
| (Some(a), Some(b), Some(c)) => fn(a, b, c)
| _ => Error("Variables have wrong type")
};
let twoFloatsToOkSym = fn => twoFloats((f1, f2) => fn(f1, f2) |> toOkSym);
let threeFloats = fn => threeFloats((f1, f2, f3) => fn(f1, f2, f3));
let apply2 = (fn, args): result(node, string) =>
switch (args) {
| [|a, b|] => fn(a, b)
| _ => Error("Needs 2 args")
};
let apply3 = (fn, args: array(node)): result(node, string) =>
switch (args) {
| [|a, b, c|] => fn(a, b, c)
| _ => Error("Needs 3 args")
};
let to_: (float, float) => result(node, string) =
(low, high) =>
switch (low, high) {
| (low, high) when low <= 0.0 && low < high =>
Ok(`SymbolicDist(SymbolicDist.Normal.from90PercentCI(low, high)))
| (low, high) when low < high =>
Ok(`SymbolicDist(SymbolicDist.Lognormal.from90PercentCI(low, high)))
| (low, high) => Error("Low value must be less than high value.")
};
// Possible setup:
// let normal = {"inputs": [`float, `float], "outputs": [`float]};
// let render = {"inputs": [`dist], "outputs": [`renderedDist]};
// let render = {"inputs": [`distRenderedDist], "outputs": [`renderedDist]};
// type types = [`Float| `Dist];
// type def = {types};
let fnn =
(
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
name,
args: array(node),
) =>
switch (
name,
ExpressionTypes.ExpressionTree.Environment.get(
evaluationParams.environment,
) => {
let trySomeFns =
TypeSystem.getAndRun(Fns.functions, name, evaluationParams, args);
switch (trySomeFns) {
| Some(r) => r
| None =>
switch (
name,
),
) {
| (_, Some(`Function(argNames, tt))) =>
PTypes.Function.run(evaluationParams, args, (argNames, tt))
| ("normal", _) =>
apply2(twoFloatsToOkSym(SymbolicDist.Normal.make), args)
| ("uniform", _) =>
apply2(twoFloatsToOkSym(SymbolicDist.Uniform.make), args)
| ("beta", _) => apply2(twoFloatsToOkSym(SymbolicDist.Beta.make), args)
| ("cauchy", _) =>
apply2(twoFloatsToOkSym(SymbolicDist.Cauchy.make), args)
| ("lognormal", _) =>
apply2(twoFloatsToOkSym(SymbolicDist.Lognormal.make), args)
| ("lognormalFromMeanAndStdDev", _) =>
apply2(twoFloatsToOkSym(SymbolicDist.Lognormal.fromMeanAndStdev), args)
| ("exponential", _) =>
switch (args) {
| [|`SymbolicDist(`Float(a))|] =>
Ok(`SymbolicDist(SymbolicDist.Exponential.make(a)))
| _ => Error("Needs 3 valid arguments")
}
| ("triangular", _) =>
switch (args |> E.A.fmap(getFloat)) {
| [|Some(a), Some(b), Some(c)|] =>
SymbolicDist.Triangular.make(a, b, c)
|> E.R.fmap(r => `SymbolicDist(r))
| _ => Error("Needs 3 valid arguments")
}
| ("to", _) => apply2(twoFloats(to_), args)
| ("pdf", _) =>
switch (args) {
| [|fst, snd|] =>
switch (
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
evaluationParams,
fst,
),
getFloat(snd),
) {
| (Ok(fst), Some(flt)) => Ok(`FloatFromDist((`Pdf(flt), fst)))
| _ => Error("Incorrect arguments")
ExpressionTypes.ExpressionTree.Environment.get(
evaluationParams.environment,
name,
),
) {
| (_, Some(`Function(argNames, tt))) =>
PTypes.Function.run(evaluationParams, args, (argNames, tt))
| ("mm", _)
| ("multimodal", _) =>
switch (args |> E.A.to_list) {
| [`Array(weights), ...dists] =>
let withWeights =
dists
|> E.L.toArray
|> E.A.fmapi((index, t) => {
let w =
weights
|> E.A.get(_, index)
|> E.O.bind(_, getFloat)
|> E.O.default(1.0);
(t, w);
});
Ok(`MultiModal(withWeights));
| dists when E.L.length(dists) > 0 =>
Ok(`MultiModal(dists |> E.L.toArray |> E.A.fmap(r => (r, 1.0))))
| _ => Error("Needs at least one distribution")
}
| _ => Error("Needs two args")
| _ => Error("Function " ++ name ++ " not found")
}
| ("inv", _) =>
switch (args) {
| [|fst, snd|] =>
switch (
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
evaluationParams,
fst,
),
getFloat(snd),
) {
| (Ok(fst), Some(flt)) => Ok(`FloatFromDist((`Inv(flt), fst)))
| _ => Error("Incorrect arguments")
}
| _ => Error("Needs two args")
}
| ("cdf", _) =>
switch (args) {
| [|fst, snd|] =>
switch (
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
evaluationParams,
fst,
),
getFloat(snd),
) {
| (Ok(fst), Some(flt)) => Ok(`FloatFromDist((`Cdf(flt), fst)))
| _ => Error("Incorrect arguments")
}
| _ => Error("Needs two args")
}
| ("mean", _) =>
switch (args) {
| [|fst|] =>
switch (
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
evaluationParams,
fst,
)
) {
| Ok(fst) => Ok(`FloatFromDist((`Mean, fst)))
| _ => Error("Incorrect arguments")
}
| _ => Error("Needs two args")
}
| ("sample", _) =>
switch (args) {
| [|fst|] =>
switch (
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
evaluationParams,
fst,
)
) {
| Ok(fst) => Ok(`FloatFromDist((`Sample, fst)))
| _ => Error("Incorrect arguments")
}
| _ => Error("Needs two args")
}
| ("mm", _)
| ("multimodal", _) =>
switch (args |> E.A.to_list) {
| [`Array(weights), ...dists] =>
let withWeights =
dists
|> E.L.toArray
|> E.A.fmapi((index, t) => {
let w =
weights
|> E.A.get(_, index)
|> E.O.bind(_, getFloat)
|> E.O.default(1.0);
(t, w);
});
Ok(`MultiModal(withWeights));
| dists when E.L.length(dists) > 0 =>
Ok(`MultiModal(dists |> E.L.toArray |> E.A.fmap(r => (r, 1.0))))
| _ => Error("Needs at least one distribution")
}
| _ => Error("Function " ++ name ++ " not found")
};
};

View File

@ -0,0 +1,127 @@
type node = ExpressionTypes.ExpressionTree.node;
let getFloat = ExpressionTypes.ExpressionTree.getFloat;
type samplingDist = [
| `SymbolicDist(SymbolicTypes.symbolicDist)
| `RenderedDist(DistTypes.shape)
];
type t = [
| `Float
| `SamplingDistribution
| `RenderedDistribution
| `Array(t)
| `Named(array((string, t)))
];
type tx = [
| `Float(float)
| `RenderedDist(DistTypes.shape)
| `SamplingDist(samplingDist)
| `Array(array(tx))
| `Named(array((string, tx)))
];
type fn = {
name: string,
inputs: array(t),
output: t,
run: array(tx) => result(node, string),
};
module Function = {
let make = (~name, ~inputs, ~output, ~run): fn => {
name,
inputs,
output,
run,
};
};
type fns = array(fn);
type inputs = array(node);
let rec fromNodeDirect = (node: node): result(tx, string) =>
switch (ExpressionTypes.ExpressionTree.toFloatIfNeeded(node)) {
| `SymbolicDist(`Float(r)) => Ok(`Float(r))
| `SymbolicDist(s) => Ok(`SamplingDist(`SymbolicDist(s)))
| `RenderedDist(s) => Ok(`RenderedDist(s))
| `Array(r) =>
r
|> E.A.fmap(fromNodeDirect)
|> E.A.R.firstErrorOrOpen
|> E.R.fmap(r => `Array(r))
| `Hash(hash) =>
hash
|> E.A.fmap(((name, t)) =>
fromNodeDirect(t) |> E.R.fmap(r => (name, r))
)
|> E.A.R.firstErrorOrOpen
|> E.R.fmap(r => `Named(r))
| _ => Error("Wrong type")
};
let compareInput = (evaluationParams, t: t, node) =>
switch (t) {
| `Float =>
switch (getFloat(node)) {
| Some(a) => Ok(`Float(a))
| _ =>
Error(
"Type Error: Expected float."
)
}
| `SamplingDistribution =>
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
evaluationParams,
node,
)
|> E.R.bind(_, fromNodeDirect)
| `RenderedDistribution =>
ExpressionTypes.ExpressionTree.Render.render(evaluationParams, node)
|> E.R.bind(_, fromNodeDirect)
| _ => Error("Bad input, sorry.")
};
let sanatizeInputs =
(
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
inputs: inputs,
t: fn,
) => {
E.A.length(t.inputs) == E.A.length(inputs)
? Belt.Array.zip(t.inputs, inputs)
|> E.A.fmap(((def, input)) =>
compareInput(evaluationParams, def, input)
)
|> E.A.R.firstErrorOrOpen
: Error(
"Wrong number of inputs. Expected"
++ (E.A.length(t.inputs) |> E.I.toString)
++ ". Got:"
++ (E.A.length(inputs) |> E.I.toString),
);
};
let run =
(
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
inputs: inputs,
t: fn,
) =>{
(
switch (sanatizeInputs(evaluationParams, inputs, t)) {
| Ok(inputs) => t.run(inputs)
| Error(r) => Error(r)
}
)
|> (
fun
| Ok(i) => Ok(i)
| Error(r) => {Js.log4("Error", inputs, t, sanatizeInputs(evaluationParams, inputs, t), ); Error("Function " ++ t.name ++ " error: " ++ r)}
);
}
let getFn = (fns: fns, n: string) =>
fns |> Belt.Array.getBy(_, ({name}) => name == n);
let getAndRun = (fns: fns, n: string, evaluationParams, inputs) =>
getFn(fns, n) |> E.O.fmap(run(evaluationParams, inputs));