Further cleanup of MultiModal functionality
This commit is contained in:
parent
b71d037180
commit
a6051d8371
|
@ -24,7 +24,6 @@ let rec toString: node => string =
|
||||||
++ toString(internal)
|
++ toString(internal)
|
||||||
++ ")]"
|
++ ")]"
|
||||||
| `Array(_) => "Array"
|
| `Array(_) => "Array"
|
||||||
| `MultiModal(_) => "Multimodal"
|
|
||||||
| `Hash(_) => "Hash"
|
| `Hash(_) => "Hash"
|
||||||
|
|
||||||
let envs = (samplingInputs, environment) => {
|
let envs = (samplingInputs, environment) => {
|
||||||
|
|
|
@ -307,6 +307,5 @@ let rec toLeaf =
|
||||||
Js.log3("In function call", name, args);
|
Js.log3("In function call", name, args);
|
||||||
callableFunction(evaluationParams, name, args)
|
callableFunction(evaluationParams, name, args)
|
||||||
|> E.R.bind(_, toLeaf(evaluationParams));
|
|> E.R.bind(_, toLeaf(evaluationParams));
|
||||||
| `MultiModal(r) => Error("Multimodal?")
|
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
|
@ -30,16 +30,18 @@ module ExpressionTree = {
|
||||||
| `Render(node)
|
| `Render(node)
|
||||||
| `Truncate(option(float), option(float), node)
|
| `Truncate(option(float), option(float), node)
|
||||||
| `FunctionCall(string, array(node))
|
| `FunctionCall(string, array(node))
|
||||||
| `MultiModal(array((node, float)))
|
|
||||||
];
|
];
|
||||||
|
|
||||||
module Hash = {
|
module Hash = {
|
||||||
type t('a) = array((string, 'a));
|
type t('a) = array((string, 'a));
|
||||||
let getByName = (t:t('a), name) =>
|
let getByName = (t: t('a), name) =>
|
||||||
E.A.getBy(t, ((n, _)) => n == name) |> E.O.fmap(((_, r)) => r);
|
E.A.getBy(t, ((n, _)) => n == name) |> E.O.fmap(((_, r)) => r);
|
||||||
|
|
||||||
let getByNames = (hash: t('a), names:array(string)) =>
|
let getByNameResult = (t: t('a), name) =>
|
||||||
names |> E.A.fmap(name => (name, getByName(hash, name)))
|
getByName(t, name) |> E.O.toResult(name ++ " expected and not found");
|
||||||
|
|
||||||
|
let getByNames = (hash: t('a), names: array(string)) =>
|
||||||
|
names |> E.A.fmap(name => (name, getByName(hash, name)));
|
||||||
};
|
};
|
||||||
// Have nil as option
|
// Have nil as option
|
||||||
let getFloat = (node: node) =>
|
let getFloat = (node: node) =>
|
||||||
|
|
|
@ -107,6 +107,80 @@ let verticalScaling = (scaleOp, rs, scaleBy) => {
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
module Multimodal = {
|
||||||
|
let getByNameResult = ExpressionTypes.ExpressionTree.Hash.getByNameResult;
|
||||||
|
|
||||||
|
let _paramsToDistsAndWeights = (r: array(typedValue)) =>
|
||||||
|
switch (r) {
|
||||||
|
| [|`Named(r)|] =>
|
||||||
|
let dists =
|
||||||
|
getByNameResult(r, "dists")
|
||||||
|
->E.R.bind(TypeSystem.TypedValue.toArray)
|
||||||
|
->E.R.bind(r =>
|
||||||
|
r
|
||||||
|
|> E.A.fmap(TypeSystem.TypedValue.toDist)
|
||||||
|
|> E.A.R.firstErrorOrOpen
|
||||||
|
);
|
||||||
|
let weights =
|
||||||
|
getByNameResult(r, "weights")
|
||||||
|
->E.R.bind(TypeSystem.TypedValue.toArray)
|
||||||
|
->E.R.bind(r =>
|
||||||
|
r
|
||||||
|
|> E.A.fmap(TypeSystem.TypedValue.toFloat)
|
||||||
|
|> E.A.R.firstErrorOrOpen
|
||||||
|
);
|
||||||
|
|
||||||
|
E.R.merge(dists, weights)
|
||||||
|
|> E.R.fmap(((a, b)) =>
|
||||||
|
E.A.zipMaxLength(a, b)
|
||||||
|
|> E.A.fmap(((a, b)) =>
|
||||||
|
(a |> E.O.toExn(""), b |> E.O.default(1.0))
|
||||||
|
)
|
||||||
|
);
|
||||||
|
| _ => Error("Needs items")
|
||||||
|
};
|
||||||
|
let _runner: array(typedValue) => result(node, string) =
|
||||||
|
r => {
|
||||||
|
let paramsToDistsAndWeights =
|
||||||
|
_paramsToDistsAndWeights(r)
|
||||||
|
|> E.R.fmap(
|
||||||
|
E.A.fmap(((dist, weight)) =>
|
||||||
|
`FunctionCall((
|
||||||
|
"scaleMultiply",
|
||||||
|
[|dist, `SymbolicDist(`Float(weight))|],
|
||||||
|
))
|
||||||
|
),
|
||||||
|
);
|
||||||
|
let pointwiseSum: result(node, string) =
|
||||||
|
paramsToDistsAndWeights->E.R.bind(
|
||||||
|
E.R.errorIfCondition(E.A.isEmpty, "Needs one input"),
|
||||||
|
)
|
||||||
|
|> E.R.fmap(r =>
|
||||||
|
r
|
||||||
|
|> Js.Array.sliceFrom(1)
|
||||||
|
|> E.A.fold_left(
|
||||||
|
(acc, x) => {`PointwiseCombination((`Add, acc, x))},
|
||||||
|
E.A.unsafe_get(r, 0),
|
||||||
|
)
|
||||||
|
);
|
||||||
|
pointwiseSum;
|
||||||
|
};
|
||||||
|
|
||||||
|
let _function =
|
||||||
|
Function.T.make(
|
||||||
|
~name="multimodal",
|
||||||
|
~outputType=`SamplingDistribution,
|
||||||
|
~inputTypes=[|
|
||||||
|
`Named([|
|
||||||
|
("dists", `Array(`SamplingDistribution)),
|
||||||
|
("weights", `Array(`Float)),
|
||||||
|
|]),
|
||||||
|
|],
|
||||||
|
~run=_runner,
|
||||||
|
(),
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
let functions = [|
|
let functions = [|
|
||||||
makeSymbolicFromTwoFloats("normal", SymbolicDist.Normal.make),
|
makeSymbolicFromTwoFloats("normal", SymbolicDist.Normal.make),
|
||||||
makeSymbolicFromTwoFloats("uniform", SymbolicDist.Uniform.make),
|
makeSymbolicFromTwoFloats("uniform", SymbolicDist.Uniform.make),
|
||||||
|
@ -175,98 +249,5 @@ let functions = [|
|
||||||
makeRenderedDistFloat("scaleLog", (dist, float) =>
|
makeRenderedDistFloat("scaleLog", (dist, float) =>
|
||||||
verticalScaling(`Log, dist, float)
|
verticalScaling(`Log, dist, float)
|
||||||
),
|
),
|
||||||
Function.T.make(
|
Multimodal._function
|
||||||
~name="multimodal",
|
|
||||||
~outputType=`SamplingDistribution,
|
|
||||||
~inputTypes=[|
|
|
||||||
`Named([|
|
|
||||||
("dists", `Array(`SamplingDistribution)),
|
|
||||||
("weights", `Array(`Float)),
|
|
||||||
|]),
|
|
||||||
|],
|
|
||||||
~run=
|
|
||||||
fun
|
|
||||||
| [|`Named(r)|] => {
|
|
||||||
let foo =
|
|
||||||
(r: TypeSystem.typedValue)
|
|
||||||
: result(ExpressionTypes.ExpressionTree.node, string) =>
|
|
||||||
switch (r) {
|
|
||||||
| `SamplingDist(`SymbolicDist(c)) => Ok(`SymbolicDist(c))
|
|
||||||
| `SamplingDist(`RenderedDist(c)) => Ok(`RenderedDist(c))
|
|
||||||
| `Float(x) =>
|
|
||||||
Ok(`RenderedDist(SymbolicDist.T.toShape(1000, `Float(x))))
|
|
||||||
| _ => Error("")
|
|
||||||
};
|
|
||||||
let weight = (r: TypeSystem.typedValue) =>
|
|
||||||
switch (r) {
|
|
||||||
| `Float(x) => Ok(x)
|
|
||||||
| _ => Error("Wrong Type")
|
|
||||||
};
|
|
||||||
let dists =
|
|
||||||
switch (ExpressionTypes.ExpressionTree.Hash.getByName(r, "dists")) {
|
|
||||||
| Some(`Array(r)) => r |> E.A.fmap(foo) |> E.A.R.firstErrorOrOpen
|
|
||||||
| _ => Error("")
|
|
||||||
};
|
|
||||||
let weights =
|
|
||||||
(
|
|
||||||
switch (
|
|
||||||
ExpressionTypes.ExpressionTree.Hash.getByName(r, "weights")
|
|
||||||
) {
|
|
||||||
| Some(`Array(r)) =>
|
|
||||||
r |> E.A.fmap(weight) |> E.A.R.firstErrorOrOpen
|
|
||||||
| _ => Error("")
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|> (
|
|
||||||
fun
|
|
||||||
| Ok(r) => r
|
|
||||||
| _ => [||]
|
|
||||||
);
|
|
||||||
let withWeights =
|
|
||||||
dists
|
|
||||||
|> E.R.fmap(d => {
|
|
||||||
let iis =
|
|
||||||
d |> E.A.length |> Belt.Array.makeUninitializedUnsafe;
|
|
||||||
for (i in 0 to (d |> E.A.length) - 1) {
|
|
||||||
Belt.Array.set(
|
|
||||||
iis,
|
|
||||||
i,
|
|
||||||
(
|
|
||||||
E.A.unsafe_get(d, i),
|
|
||||||
E.A.get(weights, i) |> E.O.default(1.0),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|> ignore;
|
|
||||||
};
|
|
||||||
iis;
|
|
||||||
});
|
|
||||||
let components: result(array(node), string) =
|
|
||||||
withWeights
|
|
||||||
|> E.R.fmap(
|
|
||||||
E.A.fmap(((dist, weight)) =>
|
|
||||||
`FunctionCall((
|
|
||||||
"scaleMultiply",
|
|
||||||
[|dist, `SymbolicDist(`Float(weight))|],
|
|
||||||
))
|
|
||||||
),
|
|
||||||
);
|
|
||||||
let pointwiseSum =
|
|
||||||
components
|
|
||||||
|> E.R.bind(_, r => {
|
|
||||||
E.A.length(r) > 0
|
|
||||||
? Ok(r) : Error("Invalid argument length")
|
|
||||||
})
|
|
||||||
|> E.R.fmap(r =>
|
|
||||||
r
|
|
||||||
|> Js.Array.sliceFrom(1)
|
|
||||||
|> E.A.fold_left(
|
|
||||||
(acc, x) => {`PointwiseCombination((`Add, acc, x))},
|
|
||||||
E.A.unsafe_get(r, 0),
|
|
||||||
)
|
|
||||||
);
|
|
||||||
pointwiseSum;
|
|
||||||
}
|
|
||||||
| _ => Error(""),
|
|
||||||
(),
|
|
||||||
),
|
|
||||||
|];
|
|];
|
||||||
|
|
|
@ -54,7 +54,6 @@ module TypedValue = {
|
||||||
|
|
||||||
// todo: Arrays and hashes
|
// todo: Arrays and hashes
|
||||||
let rec fromNodeWithTypeCoercion = (evaluationParams, _type: _type, node) => {
|
let rec fromNodeWithTypeCoercion = (evaluationParams, _type: _type, node) => {
|
||||||
Js.log3("With Coersion!", _type, node);
|
|
||||||
switch (_type, node) {
|
switch (_type, node) {
|
||||||
| (`Float, _) =>
|
| (`Float, _) =>
|
||||||
switch (getFloat(node)) {
|
switch (getFloat(node)) {
|
||||||
|
@ -76,7 +75,6 @@ module TypedValue = {
|
||||||
|> E.A.R.firstErrorOrOpen
|
|> E.A.R.firstErrorOrOpen
|
||||||
|> E.R.fmap(r => `Array(r))
|
|> E.R.fmap(r => `Array(r))
|
||||||
| (`Named(named), `Hash(r)) =>
|
| (`Named(named), `Hash(r)) =>
|
||||||
Js.log3("Named", named, r);
|
|
||||||
let foo =
|
let foo =
|
||||||
named
|
named
|
||||||
|> E.A.fmap(((name, intendedType)) =>
|
|> E.A.fmap(((name, intendedType)) =>
|
||||||
|
@ -86,7 +84,6 @@ module TypedValue = {
|
||||||
ExpressionTypes.ExpressionTree.Hash.getByName(r, name),
|
ExpressionTypes.ExpressionTree.Hash.getByName(r, name),
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
Js.log("Named: part 2");
|
|
||||||
let bar =
|
let bar =
|
||||||
foo
|
foo
|
||||||
|> E.A.fmap(((name, intendedType, optionNode)) =>
|
|> E.A.fmap(((name, intendedType, optionNode)) =>
|
||||||
|
@ -99,11 +96,33 @@ module TypedValue = {
|
||||||
)
|
)
|
||||||
|> E.A.R.firstErrorOrOpen
|
|> E.A.R.firstErrorOrOpen
|
||||||
|> E.R.fmap(r => `Named(r));
|
|> E.R.fmap(r => `Named(r));
|
||||||
Js.log3("Named!", foo, bar);
|
|
||||||
bar;
|
bar;
|
||||||
| _ => Error("fromNodeWithTypeCoercion error, sorry.")
|
| _ => Error("fromNodeWithTypeCoercion error, sorry.")
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let toFloat =
|
||||||
|
fun
|
||||||
|
| `Float(x) => Ok(x)
|
||||||
|
| _ => Error("Not a float");
|
||||||
|
|
||||||
|
let toArray =
|
||||||
|
fun
|
||||||
|
| `Array(x) => Ok(x)
|
||||||
|
| _ => Error("Not an array");
|
||||||
|
|
||||||
|
let toNamed =
|
||||||
|
fun
|
||||||
|
| `Named(x) => Ok(x)
|
||||||
|
| _ => Error("Not a named item");
|
||||||
|
|
||||||
|
let toDist =
|
||||||
|
fun
|
||||||
|
| `SamplingDist(`SymbolicDist(c)) => Ok(`SymbolicDist(c))
|
||||||
|
| `SamplingDist(`RenderedDist(c)) => Ok(`RenderedDist(c))
|
||||||
|
| `Float(x) =>
|
||||||
|
Ok(`RenderedDist(SymbolicDist.T.toShape(1000, `Float(x))))
|
||||||
|
| _ => Error("");
|
||||||
};
|
};
|
||||||
|
|
||||||
module Function = {
|
module Function = {
|
||||||
|
|
|
@ -144,7 +144,6 @@ let renderIfNeeded =
|
||||||
node
|
node
|
||||||
|> (
|
|> (
|
||||||
fun
|
fun
|
||||||
| `MultiModal(_) as n
|
|
||||||
| `Normalize(_) as n
|
| `Normalize(_) as n
|
||||||
| `SymbolicDist(_) as n => {
|
| `SymbolicDist(_) as n => {
|
||||||
`Render(n)
|
`Render(n)
|
||||||
|
|
|
@ -26,6 +26,9 @@ module FloatFloatMap = {
|
||||||
let fmap = (fn, t: t) => Belt.MutableMap.map(t, fn);
|
let fmap = (fn, t: t) => Belt.MutableMap.map(t, fn);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
module Int = {
|
||||||
|
let max = (i1: int, i2: int) => i1 > i2 ? i1 : i2;
|
||||||
|
};
|
||||||
/* Utils */
|
/* Utils */
|
||||||
module U = {
|
module U = {
|
||||||
let isEqual = (a, b) => a == b;
|
let isEqual = (a, b) => a == b;
|
||||||
|
@ -146,6 +149,11 @@ module R = {
|
||||||
let fmap = Rationale.Result.fmap;
|
let fmap = Rationale.Result.fmap;
|
||||||
let bind = Rationale.Result.bind;
|
let bind = Rationale.Result.bind;
|
||||||
let toExn = Belt.Result.getExn;
|
let toExn = Belt.Result.getExn;
|
||||||
|
let default = (default, res: Belt.Result.t('a, 'b)) =>
|
||||||
|
switch (res) {
|
||||||
|
| Ok(r) => r
|
||||||
|
| Error(_) => default
|
||||||
|
};
|
||||||
let merge = (a, b) =>
|
let merge = (a, b) =>
|
||||||
switch (a, b) {
|
switch (a, b) {
|
||||||
| (Error(e), _) => Error(e)
|
| (Error(e), _) => Error(e)
|
||||||
|
@ -157,6 +165,9 @@ module R = {
|
||||||
| Ok(r) => Some(r)
|
| Ok(r) => Some(r)
|
||||||
| Error(_) => None
|
| Error(_) => None
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let errorIfCondition = (errorCondition, errorMessage, r) =>
|
||||||
|
errorCondition(r) ? Error(errorMessage) : Ok(r);
|
||||||
};
|
};
|
||||||
|
|
||||||
let safe_fn_of_string = (fn, s: string): option('a) =>
|
let safe_fn_of_string = (fn, s: string): option('a) =>
|
||||||
|
@ -263,6 +274,7 @@ module A = {
|
||||||
let init = Array.init;
|
let init = Array.init;
|
||||||
let reduce = Belt.Array.reduce;
|
let reduce = Belt.Array.reduce;
|
||||||
let reducei = Belt.Array.reduceWithIndex;
|
let reducei = Belt.Array.reduceWithIndex;
|
||||||
|
let isEmpty = r => length(r) < 1;
|
||||||
let min = a =>
|
let min = a =>
|
||||||
get(a, 0)
|
get(a, 0)
|
||||||
|> O.fmap(first => Belt.Array.reduce(a, first, (i, j) => i < j ? i : j));
|
|> O.fmap(first => Belt.Array.reduce(a, first, (i, j) => i < j ? i : j));
|
||||||
|
@ -285,6 +297,16 @@ module A = {
|
||||||
|> Rationale.Result.return
|
|> Rationale.Result.return
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// This zips while taking the longest elements of each array.
|
||||||
|
let zipMaxLength = (array1, array2) => {
|
||||||
|
let maxLength = Int.max(length(array1), length(array2));
|
||||||
|
let result = maxLength |> Belt.Array.makeUninitializedUnsafe;
|
||||||
|
for (i in 0 to maxLength - 1) {
|
||||||
|
Belt.Array.set(result, i, (get(array1, i), get(array2, i))) |> ignore;
|
||||||
|
};
|
||||||
|
result;
|
||||||
|
};
|
||||||
|
|
||||||
let asList = (f: list('a) => list('a), r: array('a)) =>
|
let asList = (f: list('a) => list('a), r: array('a)) =>
|
||||||
r |> to_list |> f |> of_list;
|
r |> to_list |> f |> of_list;
|
||||||
/* TODO: Is there a better way of doing this? */
|
/* TODO: Is there a better way of doing this? */
|
||||||
|
|
Loading…
Reference in New Issue
Block a user