Starting simple type system
This commit is contained in:
parent
6d36284283
commit
c002100461
|
@ -1,11 +1,5 @@
|
|||
open ExpressionTypes.ExpressionTree;
|
||||
|
||||
let envs = (samplingInputs, environment) => {
|
||||
{samplingInputs, environment, evaluateNode: ExpressionTreeEvaluator.toLeaf};
|
||||
};
|
||||
let toLeaf = (samplingInputs, environment, node: node) =>
|
||||
ExpressionTreeEvaluator.toLeaf(envs(samplingInputs, environment), node);
|
||||
|
||||
let rec toString: node => string =
|
||||
fun
|
||||
| `SymbolicDist(d) => SymbolicDist.T.toString(d)
|
||||
|
@ -33,9 +27,16 @@ let rec toString: node => string =
|
|||
++ (args |> Js.String.concatMany(_, ","))
|
||||
++ toString(internal)
|
||||
++ ")]"
|
||||
| `Array(args) => "Array"
|
||||
| `MultiModal(args) => "Multimodal"
|
||||
| `Array(_) => "Array"
|
||||
| `MultiModal(_) => "Multimodal"
|
||||
| `Hash(_) => "Hash"
|
||||
|
||||
let envs = (samplingInputs, environment) => {
|
||||
{samplingInputs, environment, evaluateNode: ExpressionTreeEvaluator.toLeaf};
|
||||
};
|
||||
|
||||
let toLeaf = (samplingInputs, environment, node: node) =>
|
||||
ExpressionTreeEvaluator.toLeaf(envs(samplingInputs, environment), node);
|
||||
let toShape = (samplingInputs, environment, node: node) => {
|
||||
switch (toLeaf(samplingInputs, environment, node)) {
|
||||
| Ok(`RenderedDist(shape)) => Ok(shape)
|
||||
|
|
|
@ -21,6 +21,7 @@ module ExpressionTree = {
|
|||
| `RenderedDist(DistTypes.shape)
|
||||
| `Symbol(string)
|
||||
| `Hash(array((string, node)))
|
||||
| `Array(array(node))
|
||||
| `Function(array(string), node)
|
||||
| `AlgebraicCombination(algebraicOperation, node, node)
|
||||
| `PointwiseCombination(pointwiseOperation, node, node)
|
||||
|
@ -30,7 +31,6 @@ module ExpressionTree = {
|
|||
| `Truncate(option(float), option(float), node)
|
||||
| `FloatFromDist(distToFloatOperation, node)
|
||||
| `FunctionCall(string, array(node))
|
||||
| `Array(array(node))
|
||||
| `MultiModal(array((node, float)))
|
||||
];
|
||||
// Have nil as option
|
||||
|
|
109
src/distPlus/expressionTree/Fns.re
Normal file
109
src/distPlus/expressionTree/Fns.re
Normal file
|
@ -0,0 +1,109 @@
|
|||
open TypeSystem;
|
||||
|
||||
let wrongInputsError = (r) => {Js.log2("Wrong inputs", r); Error("Wrong inputs")};
|
||||
|
||||
let to_: (float, float) => result(node, string) =
|
||||
(low, high) =>
|
||||
switch (low, high) {
|
||||
| (low, high) when low <= 0.0 && low < high =>
|
||||
Ok(`SymbolicDist(SymbolicDist.Normal.from90PercentCI(low, high)))
|
||||
| (low, high) when low < high =>
|
||||
Ok(`SymbolicDist(SymbolicDist.Lognormal.from90PercentCI(low, high)))
|
||||
| (_, _) => Error("Low value must be less than high value.")
|
||||
};
|
||||
|
||||
let makeSymbolicFromTwoFloats = (name, fn) =>
|
||||
Function.make(
|
||||
~name,
|
||||
~output=`SamplingDistribution,
|
||||
~inputs=[|`Float, `Float|],
|
||||
~run=
|
||||
fun
|
||||
| [|`Float(a), `Float(b)|] => Ok(`SymbolicDist(fn(a, b)))
|
||||
| e => wrongInputsError(e)
|
||||
);
|
||||
|
||||
let makeSymbolicFromOneFloat = (name, fn) =>
|
||||
Function.make(
|
||||
~name,
|
||||
~output=`SamplingDistribution,
|
||||
~inputs=[|`Float|],
|
||||
~run=
|
||||
fun
|
||||
| [|`Float(a)|] => Ok(`SymbolicDist(fn(a)))
|
||||
| e => wrongInputsError(e)
|
||||
);
|
||||
|
||||
let makeDistFloat = (name, fn) =>
|
||||
Function.make(
|
||||
~name,
|
||||
~output=`SamplingDistribution,
|
||||
~inputs=[|`SamplingDistribution, `Float|],
|
||||
~run=
|
||||
fun
|
||||
| [|`SamplingDist(a), `Float(b)|] => (fn(a,b))
|
||||
| e => wrongInputsError(e)
|
||||
);
|
||||
|
||||
let makeDist = (name, fn) =>
|
||||
Function.make(
|
||||
~name,
|
||||
~output=`SamplingDistribution,
|
||||
~inputs=[|`SamplingDistribution|],
|
||||
~run=
|
||||
fun
|
||||
| [|`SamplingDist(a)|] => fn(a)
|
||||
| e => wrongInputsError(e)
|
||||
);
|
||||
|
||||
let floatFromDist =
|
||||
(
|
||||
distToFloatOp: ExpressionTypes.distToFloatOperation,
|
||||
t: TypeSystem.samplingDist,
|
||||
)
|
||||
: result(node, string) => {
|
||||
switch (t) {
|
||||
| `SymbolicDist(s) =>
|
||||
SymbolicDist.T.operate(distToFloatOp, s)
|
||||
|> E.R.bind(_, v => Ok(`SymbolicDist(`Float(v))))
|
||||
| `RenderedDist(rs) =>
|
||||
Shape.operate(distToFloatOp, rs) |> (v => Ok(`SymbolicDist(`Float(v))))
|
||||
};
|
||||
};
|
||||
|
||||
let functions = [|
|
||||
makeSymbolicFromTwoFloats("normal", SymbolicDist.Normal.make),
|
||||
makeSymbolicFromTwoFloats("uniform", SymbolicDist.Uniform.make),
|
||||
makeSymbolicFromTwoFloats("beta", SymbolicDist.Beta.make),
|
||||
makeSymbolicFromTwoFloats("lognormal", SymbolicDist.Lognormal.make),
|
||||
makeSymbolicFromTwoFloats(
|
||||
"lognormalFromMeanAndStdDev",
|
||||
SymbolicDist.Lognormal.fromMeanAndStdev,
|
||||
),
|
||||
makeSymbolicFromOneFloat("exponential", SymbolicDist.Exponential.make),
|
||||
Function.make(
|
||||
~name="to",
|
||||
~output=`SamplingDistribution,
|
||||
~inputs=[|`Float, `Float|],
|
||||
~run=
|
||||
fun
|
||||
| [|`Float(a), `Float(b)|] => to_(a,b)
|
||||
| e => wrongInputsError(e)
|
||||
),
|
||||
Function.make(
|
||||
~name="triangular",
|
||||
~output=`SamplingDistribution,
|
||||
~inputs=[|`Float, `Float, `Float|],
|
||||
~run=
|
||||
fun
|
||||
| [|`Float(a), `Float(b), `Float(c)|] =>
|
||||
SymbolicDist.Triangular.make(a, b, c)
|
||||
|> E.R.fmap(r => `SymbolicDist(r))
|
||||
| e => wrongInputsError(e)
|
||||
),
|
||||
makeDistFloat("pdf", (dist, float) => floatFromDist(`Pdf(float), dist)),
|
||||
makeDistFloat("inv", (dist, float) => floatFromDist(`Inv(float), dist)),
|
||||
makeDistFloat("cdf", (dist, float) => floatFromDist(`Cdf(float), dist)),
|
||||
makeDist("mean", (dist) => floatFromDist(`Mean, dist)),
|
||||
makeDist("sample", (dist) => floatFromDist(`Sample, dist))
|
||||
|];
|
|
@ -2,185 +2,47 @@ type node = ExpressionTypes.ExpressionTree.node;
|
|||
|
||||
let toOkSym = r => Ok(`SymbolicDist(r));
|
||||
let getFloat = ExpressionTypes.ExpressionTree.getFloat;
|
||||
|
||||
let twoFloats = (fn, n1: node, n2: node): result(node, string) =>
|
||||
switch (getFloat(n1), getFloat(n2)) {
|
||||
| (Some(a), Some(b)) => fn(a, b)
|
||||
| _ => Error("Function needed two floats, missing them.")
|
||||
};
|
||||
|
||||
let threeFloats = (fn, n1: node, n2: node, n3: node): result(node, string) =>
|
||||
switch (getFloat(n1), getFloat(n2), getFloat(n3)) {
|
||||
| (Some(a), Some(b), Some(c)) => fn(a, b, c)
|
||||
| _ => Error("Variables have wrong type")
|
||||
};
|
||||
|
||||
let twoFloatsToOkSym = fn => twoFloats((f1, f2) => fn(f1, f2) |> toOkSym);
|
||||
|
||||
let threeFloats = fn => threeFloats((f1, f2, f3) => fn(f1, f2, f3));
|
||||
|
||||
let apply2 = (fn, args): result(node, string) =>
|
||||
switch (args) {
|
||||
| [|a, b|] => fn(a, b)
|
||||
| _ => Error("Needs 2 args")
|
||||
};
|
||||
|
||||
let apply3 = (fn, args: array(node)): result(node, string) =>
|
||||
switch (args) {
|
||||
| [|a, b, c|] => fn(a, b, c)
|
||||
| _ => Error("Needs 3 args")
|
||||
};
|
||||
|
||||
let to_: (float, float) => result(node, string) =
|
||||
(low, high) =>
|
||||
switch (low, high) {
|
||||
| (low, high) when low <= 0.0 && low < high =>
|
||||
Ok(`SymbolicDist(SymbolicDist.Normal.from90PercentCI(low, high)))
|
||||
| (low, high) when low < high =>
|
||||
Ok(`SymbolicDist(SymbolicDist.Lognormal.from90PercentCI(low, high)))
|
||||
| (low, high) => Error("Low value must be less than high value.")
|
||||
};
|
||||
|
||||
// Possible setup:
|
||||
// let normal = {"inputs": [`float, `float], "outputs": [`float]};
|
||||
// let render = {"inputs": [`dist], "outputs": [`renderedDist]};
|
||||
// let render = {"inputs": [`distRenderedDist], "outputs": [`renderedDist]};
|
||||
|
||||
// type types = [`Float| `Dist];
|
||||
// type def = {types};
|
||||
|
||||
let fnn =
|
||||
(
|
||||
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
|
||||
name,
|
||||
args: array(node),
|
||||
) =>
|
||||
switch (
|
||||
name,
|
||||
ExpressionTypes.ExpressionTree.Environment.get(
|
||||
evaluationParams.environment,
|
||||
) => {
|
||||
let trySomeFns =
|
||||
TypeSystem.getAndRun(Fns.functions, name, evaluationParams, args);
|
||||
switch (trySomeFns) {
|
||||
| Some(r) => r
|
||||
| None =>
|
||||
switch (
|
||||
name,
|
||||
),
|
||||
) {
|
||||
| (_, Some(`Function(argNames, tt))) =>
|
||||
PTypes.Function.run(evaluationParams, args, (argNames, tt))
|
||||
| ("normal", _) =>
|
||||
apply2(twoFloatsToOkSym(SymbolicDist.Normal.make), args)
|
||||
| ("uniform", _) =>
|
||||
apply2(twoFloatsToOkSym(SymbolicDist.Uniform.make), args)
|
||||
| ("beta", _) => apply2(twoFloatsToOkSym(SymbolicDist.Beta.make), args)
|
||||
| ("cauchy", _) =>
|
||||
apply2(twoFloatsToOkSym(SymbolicDist.Cauchy.make), args)
|
||||
| ("lognormal", _) =>
|
||||
apply2(twoFloatsToOkSym(SymbolicDist.Lognormal.make), args)
|
||||
| ("lognormalFromMeanAndStdDev", _) =>
|
||||
apply2(twoFloatsToOkSym(SymbolicDist.Lognormal.fromMeanAndStdev), args)
|
||||
| ("exponential", _) =>
|
||||
switch (args) {
|
||||
| [|`SymbolicDist(`Float(a))|] =>
|
||||
Ok(`SymbolicDist(SymbolicDist.Exponential.make(a)))
|
||||
| _ => Error("Needs 3 valid arguments")
|
||||
}
|
||||
| ("triangular", _) =>
|
||||
switch (args |> E.A.fmap(getFloat)) {
|
||||
| [|Some(a), Some(b), Some(c)|] =>
|
||||
SymbolicDist.Triangular.make(a, b, c)
|
||||
|> E.R.fmap(r => `SymbolicDist(r))
|
||||
| _ => Error("Needs 3 valid arguments")
|
||||
}
|
||||
| ("to", _) => apply2(twoFloats(to_), args)
|
||||
| ("pdf", _) =>
|
||||
switch (args) {
|
||||
| [|fst, snd|] =>
|
||||
switch (
|
||||
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
|
||||
evaluationParams,
|
||||
fst,
|
||||
),
|
||||
getFloat(snd),
|
||||
) {
|
||||
| (Ok(fst), Some(flt)) => Ok(`FloatFromDist((`Pdf(flt), fst)))
|
||||
| _ => Error("Incorrect arguments")
|
||||
ExpressionTypes.ExpressionTree.Environment.get(
|
||||
evaluationParams.environment,
|
||||
name,
|
||||
),
|
||||
) {
|
||||
| (_, Some(`Function(argNames, tt))) =>
|
||||
PTypes.Function.run(evaluationParams, args, (argNames, tt))
|
||||
| ("mm", _)
|
||||
| ("multimodal", _) =>
|
||||
switch (args |> E.A.to_list) {
|
||||
| [`Array(weights), ...dists] =>
|
||||
let withWeights =
|
||||
dists
|
||||
|> E.L.toArray
|
||||
|> E.A.fmapi((index, t) => {
|
||||
let w =
|
||||
weights
|
||||
|> E.A.get(_, index)
|
||||
|> E.O.bind(_, getFloat)
|
||||
|> E.O.default(1.0);
|
||||
(t, w);
|
||||
});
|
||||
Ok(`MultiModal(withWeights));
|
||||
| dists when E.L.length(dists) > 0 =>
|
||||
Ok(`MultiModal(dists |> E.L.toArray |> E.A.fmap(r => (r, 1.0))))
|
||||
| _ => Error("Needs at least one distribution")
|
||||
}
|
||||
| _ => Error("Needs two args")
|
||||
| _ => Error("Function " ++ name ++ " not found")
|
||||
}
|
||||
| ("inv", _) =>
|
||||
switch (args) {
|
||||
| [|fst, snd|] =>
|
||||
switch (
|
||||
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
|
||||
evaluationParams,
|
||||
fst,
|
||||
),
|
||||
getFloat(snd),
|
||||
) {
|
||||
| (Ok(fst), Some(flt)) => Ok(`FloatFromDist((`Inv(flt), fst)))
|
||||
| _ => Error("Incorrect arguments")
|
||||
}
|
||||
| _ => Error("Needs two args")
|
||||
}
|
||||
| ("cdf", _) =>
|
||||
switch (args) {
|
||||
| [|fst, snd|] =>
|
||||
switch (
|
||||
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
|
||||
evaluationParams,
|
||||
fst,
|
||||
),
|
||||
getFloat(snd),
|
||||
) {
|
||||
| (Ok(fst), Some(flt)) => Ok(`FloatFromDist((`Cdf(flt), fst)))
|
||||
| _ => Error("Incorrect arguments")
|
||||
}
|
||||
| _ => Error("Needs two args")
|
||||
}
|
||||
| ("mean", _) =>
|
||||
switch (args) {
|
||||
| [|fst|] =>
|
||||
switch (
|
||||
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
|
||||
evaluationParams,
|
||||
fst,
|
||||
)
|
||||
) {
|
||||
| Ok(fst) => Ok(`FloatFromDist((`Mean, fst)))
|
||||
| _ => Error("Incorrect arguments")
|
||||
}
|
||||
| _ => Error("Needs two args")
|
||||
}
|
||||
| ("sample", _) =>
|
||||
switch (args) {
|
||||
| [|fst|] =>
|
||||
switch (
|
||||
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
|
||||
evaluationParams,
|
||||
fst,
|
||||
)
|
||||
) {
|
||||
| Ok(fst) => Ok(`FloatFromDist((`Sample, fst)))
|
||||
| _ => Error("Incorrect arguments")
|
||||
}
|
||||
| _ => Error("Needs two args")
|
||||
}
|
||||
| ("mm", _)
|
||||
| ("multimodal", _) =>
|
||||
switch (args |> E.A.to_list) {
|
||||
| [`Array(weights), ...dists] =>
|
||||
let withWeights =
|
||||
dists
|
||||
|> E.L.toArray
|
||||
|> E.A.fmapi((index, t) => {
|
||||
let w =
|
||||
weights
|
||||
|> E.A.get(_, index)
|
||||
|> E.O.bind(_, getFloat)
|
||||
|> E.O.default(1.0);
|
||||
(t, w);
|
||||
});
|
||||
Ok(`MultiModal(withWeights));
|
||||
| dists when E.L.length(dists) > 0 =>
|
||||
Ok(`MultiModal(dists |> E.L.toArray |> E.A.fmap(r => (r, 1.0))))
|
||||
| _ => Error("Needs at least one distribution")
|
||||
}
|
||||
| _ => Error("Function " ++ name ++ " not found")
|
||||
};
|
||||
};
|
||||
|
|
127
src/distPlus/expressionTree/TypeSystem.re
Normal file
127
src/distPlus/expressionTree/TypeSystem.re
Normal file
|
@ -0,0 +1,127 @@
|
|||
type node = ExpressionTypes.ExpressionTree.node;
|
||||
let getFloat = ExpressionTypes.ExpressionTree.getFloat;
|
||||
|
||||
type samplingDist = [
|
||||
| `SymbolicDist(SymbolicTypes.symbolicDist)
|
||||
| `RenderedDist(DistTypes.shape)
|
||||
];
|
||||
|
||||
type t = [
|
||||
| `Float
|
||||
| `SamplingDistribution
|
||||
| `RenderedDistribution
|
||||
| `Array(t)
|
||||
| `Named(array((string, t)))
|
||||
];
|
||||
type tx = [
|
||||
| `Float(float)
|
||||
| `RenderedDist(DistTypes.shape)
|
||||
| `SamplingDist(samplingDist)
|
||||
| `Array(array(tx))
|
||||
| `Named(array((string, tx)))
|
||||
];
|
||||
type fn = {
|
||||
name: string,
|
||||
inputs: array(t),
|
||||
output: t,
|
||||
run: array(tx) => result(node, string),
|
||||
};
|
||||
|
||||
module Function = {
|
||||
let make = (~name, ~inputs, ~output, ~run): fn => {
|
||||
name,
|
||||
inputs,
|
||||
output,
|
||||
run,
|
||||
};
|
||||
};
|
||||
|
||||
type fns = array(fn);
|
||||
type inputs = array(node);
|
||||
|
||||
let rec fromNodeDirect = (node: node): result(tx, string) =>
|
||||
switch (ExpressionTypes.ExpressionTree.toFloatIfNeeded(node)) {
|
||||
| `SymbolicDist(`Float(r)) => Ok(`Float(r))
|
||||
| `SymbolicDist(s) => Ok(`SamplingDist(`SymbolicDist(s)))
|
||||
| `RenderedDist(s) => Ok(`RenderedDist(s))
|
||||
| `Array(r) =>
|
||||
r
|
||||
|> E.A.fmap(fromNodeDirect)
|
||||
|> E.A.R.firstErrorOrOpen
|
||||
|> E.R.fmap(r => `Array(r))
|
||||
| `Hash(hash) =>
|
||||
hash
|
||||
|> E.A.fmap(((name, t)) =>
|
||||
fromNodeDirect(t) |> E.R.fmap(r => (name, r))
|
||||
)
|
||||
|> E.A.R.firstErrorOrOpen
|
||||
|> E.R.fmap(r => `Named(r))
|
||||
| _ => Error("Wrong type")
|
||||
};
|
||||
|
||||
let compareInput = (evaluationParams, t: t, node) =>
|
||||
switch (t) {
|
||||
| `Float =>
|
||||
switch (getFloat(node)) {
|
||||
| Some(a) => Ok(`Float(a))
|
||||
| _ =>
|
||||
Error(
|
||||
"Type Error: Expected float."
|
||||
)
|
||||
}
|
||||
| `SamplingDistribution =>
|
||||
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
|
||||
evaluationParams,
|
||||
node,
|
||||
)
|
||||
|> E.R.bind(_, fromNodeDirect)
|
||||
| `RenderedDistribution =>
|
||||
ExpressionTypes.ExpressionTree.Render.render(evaluationParams, node)
|
||||
|> E.R.bind(_, fromNodeDirect)
|
||||
| _ => Error("Bad input, sorry.")
|
||||
};
|
||||
|
||||
let sanatizeInputs =
|
||||
(
|
||||
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
|
||||
inputs: inputs,
|
||||
t: fn,
|
||||
) => {
|
||||
E.A.length(t.inputs) == E.A.length(inputs)
|
||||
? Belt.Array.zip(t.inputs, inputs)
|
||||
|> E.A.fmap(((def, input)) =>
|
||||
compareInput(evaluationParams, def, input)
|
||||
)
|
||||
|> E.A.R.firstErrorOrOpen
|
||||
: Error(
|
||||
"Wrong number of inputs. Expected"
|
||||
++ (E.A.length(t.inputs) |> E.I.toString)
|
||||
++ ". Got:"
|
||||
++ (E.A.length(inputs) |> E.I.toString),
|
||||
);
|
||||
};
|
||||
|
||||
let run =
|
||||
(
|
||||
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
|
||||
inputs: inputs,
|
||||
t: fn,
|
||||
) =>{
|
||||
(
|
||||
switch (sanatizeInputs(evaluationParams, inputs, t)) {
|
||||
| Ok(inputs) => t.run(inputs)
|
||||
| Error(r) => Error(r)
|
||||
}
|
||||
)
|
||||
|> (
|
||||
fun
|
||||
| Ok(i) => Ok(i)
|
||||
| Error(r) => {Js.log4("Error", inputs, t, sanatizeInputs(evaluationParams, inputs, t), ); Error("Function " ++ t.name ++ " error: " ++ r)}
|
||||
);
|
||||
}
|
||||
|
||||
let getFn = (fns: fns, n: string) =>
|
||||
fns |> Belt.Array.getBy(_, ({name}) => name == n);
|
||||
|
||||
let getAndRun = (fns: fns, n: string, evaluationParams, inputs) =>
|
||||
getFn(fns, n) |> E.O.fmap(run(evaluationParams, inputs));
|
Loading…
Reference in New Issue
Block a user