Further cleanup of MultiModal functionality
This commit is contained in:
parent
b71d037180
commit
a6051d8371
|
@ -24,7 +24,6 @@ let rec toString: node => string =
|
|||
++ toString(internal)
|
||||
++ ")]"
|
||||
| `Array(_) => "Array"
|
||||
| `MultiModal(_) => "Multimodal"
|
||||
| `Hash(_) => "Hash"
|
||||
|
||||
let envs = (samplingInputs, environment) => {
|
||||
|
|
|
@ -307,6 +307,5 @@ let rec toLeaf =
|
|||
Js.log3("In function call", name, args);
|
||||
callableFunction(evaluationParams, name, args)
|
||||
|> E.R.bind(_, toLeaf(evaluationParams));
|
||||
| `MultiModal(r) => Error("Multimodal?")
|
||||
};
|
||||
};
|
||||
|
|
|
@ -30,16 +30,18 @@ module ExpressionTree = {
|
|||
| `Render(node)
|
||||
| `Truncate(option(float), option(float), node)
|
||||
| `FunctionCall(string, array(node))
|
||||
| `MultiModal(array((node, float)))
|
||||
];
|
||||
|
||||
module Hash = {
|
||||
type t('a) = array((string, 'a));
|
||||
let getByName = (t:t('a), name) =>
|
||||
let getByName = (t: t('a), name) =>
|
||||
E.A.getBy(t, ((n, _)) => n == name) |> E.O.fmap(((_, r)) => r);
|
||||
|
||||
let getByNames = (hash: t('a), names:array(string)) =>
|
||||
names |> E.A.fmap(name => (name, getByName(hash, name)))
|
||||
let getByNameResult = (t: t('a), name) =>
|
||||
getByName(t, name) |> E.O.toResult(name ++ " expected and not found");
|
||||
|
||||
let getByNames = (hash: t('a), names: array(string)) =>
|
||||
names |> E.A.fmap(name => (name, getByName(hash, name)));
|
||||
};
|
||||
// Have nil as option
|
||||
let getFloat = (node: node) =>
|
||||
|
|
|
@ -107,6 +107,80 @@ let verticalScaling = (scaleOp, rs, scaleBy) => {
|
|||
);
|
||||
};
|
||||
|
||||
module Multimodal = {
|
||||
let getByNameResult = ExpressionTypes.ExpressionTree.Hash.getByNameResult;
|
||||
|
||||
let _paramsToDistsAndWeights = (r: array(typedValue)) =>
|
||||
switch (r) {
|
||||
| [|`Named(r)|] =>
|
||||
let dists =
|
||||
getByNameResult(r, "dists")
|
||||
->E.R.bind(TypeSystem.TypedValue.toArray)
|
||||
->E.R.bind(r =>
|
||||
r
|
||||
|> E.A.fmap(TypeSystem.TypedValue.toDist)
|
||||
|> E.A.R.firstErrorOrOpen
|
||||
);
|
||||
let weights =
|
||||
getByNameResult(r, "weights")
|
||||
->E.R.bind(TypeSystem.TypedValue.toArray)
|
||||
->E.R.bind(r =>
|
||||
r
|
||||
|> E.A.fmap(TypeSystem.TypedValue.toFloat)
|
||||
|> E.A.R.firstErrorOrOpen
|
||||
);
|
||||
|
||||
E.R.merge(dists, weights)
|
||||
|> E.R.fmap(((a, b)) =>
|
||||
E.A.zipMaxLength(a, b)
|
||||
|> E.A.fmap(((a, b)) =>
|
||||
(a |> E.O.toExn(""), b |> E.O.default(1.0))
|
||||
)
|
||||
);
|
||||
| _ => Error("Needs items")
|
||||
};
|
||||
let _runner: array(typedValue) => result(node, string) =
|
||||
r => {
|
||||
let paramsToDistsAndWeights =
|
||||
_paramsToDistsAndWeights(r)
|
||||
|> E.R.fmap(
|
||||
E.A.fmap(((dist, weight)) =>
|
||||
`FunctionCall((
|
||||
"scaleMultiply",
|
||||
[|dist, `SymbolicDist(`Float(weight))|],
|
||||
))
|
||||
),
|
||||
);
|
||||
let pointwiseSum: result(node, string) =
|
||||
paramsToDistsAndWeights->E.R.bind(
|
||||
E.R.errorIfCondition(E.A.isEmpty, "Needs one input"),
|
||||
)
|
||||
|> E.R.fmap(r =>
|
||||
r
|
||||
|> Js.Array.sliceFrom(1)
|
||||
|> E.A.fold_left(
|
||||
(acc, x) => {`PointwiseCombination((`Add, acc, x))},
|
||||
E.A.unsafe_get(r, 0),
|
||||
)
|
||||
);
|
||||
pointwiseSum;
|
||||
};
|
||||
|
||||
let _function =
|
||||
Function.T.make(
|
||||
~name="multimodal",
|
||||
~outputType=`SamplingDistribution,
|
||||
~inputTypes=[|
|
||||
`Named([|
|
||||
("dists", `Array(`SamplingDistribution)),
|
||||
("weights", `Array(`Float)),
|
||||
|]),
|
||||
|],
|
||||
~run=_runner,
|
||||
(),
|
||||
);
|
||||
};
|
||||
|
||||
let functions = [|
|
||||
makeSymbolicFromTwoFloats("normal", SymbolicDist.Normal.make),
|
||||
makeSymbolicFromTwoFloats("uniform", SymbolicDist.Uniform.make),
|
||||
|
@ -175,98 +249,5 @@ let functions = [|
|
|||
makeRenderedDistFloat("scaleLog", (dist, float) =>
|
||||
verticalScaling(`Log, dist, float)
|
||||
),
|
||||
Function.T.make(
|
||||
~name="multimodal",
|
||||
~outputType=`SamplingDistribution,
|
||||
~inputTypes=[|
|
||||
`Named([|
|
||||
("dists", `Array(`SamplingDistribution)),
|
||||
("weights", `Array(`Float)),
|
||||
|]),
|
||||
|],
|
||||
~run=
|
||||
fun
|
||||
| [|`Named(r)|] => {
|
||||
let foo =
|
||||
(r: TypeSystem.typedValue)
|
||||
: result(ExpressionTypes.ExpressionTree.node, string) =>
|
||||
switch (r) {
|
||||
| `SamplingDist(`SymbolicDist(c)) => Ok(`SymbolicDist(c))
|
||||
| `SamplingDist(`RenderedDist(c)) => Ok(`RenderedDist(c))
|
||||
| `Float(x) =>
|
||||
Ok(`RenderedDist(SymbolicDist.T.toShape(1000, `Float(x))))
|
||||
| _ => Error("")
|
||||
};
|
||||
let weight = (r: TypeSystem.typedValue) =>
|
||||
switch (r) {
|
||||
| `Float(x) => Ok(x)
|
||||
| _ => Error("Wrong Type")
|
||||
};
|
||||
let dists =
|
||||
switch (ExpressionTypes.ExpressionTree.Hash.getByName(r, "dists")) {
|
||||
| Some(`Array(r)) => r |> E.A.fmap(foo) |> E.A.R.firstErrorOrOpen
|
||||
| _ => Error("")
|
||||
};
|
||||
let weights =
|
||||
(
|
||||
switch (
|
||||
ExpressionTypes.ExpressionTree.Hash.getByName(r, "weights")
|
||||
) {
|
||||
| Some(`Array(r)) =>
|
||||
r |> E.A.fmap(weight) |> E.A.R.firstErrorOrOpen
|
||||
| _ => Error("")
|
||||
}
|
||||
)
|
||||
|> (
|
||||
fun
|
||||
| Ok(r) => r
|
||||
| _ => [||]
|
||||
);
|
||||
let withWeights =
|
||||
dists
|
||||
|> E.R.fmap(d => {
|
||||
let iis =
|
||||
d |> E.A.length |> Belt.Array.makeUninitializedUnsafe;
|
||||
for (i in 0 to (d |> E.A.length) - 1) {
|
||||
Belt.Array.set(
|
||||
iis,
|
||||
i,
|
||||
(
|
||||
E.A.unsafe_get(d, i),
|
||||
E.A.get(weights, i) |> E.O.default(1.0),
|
||||
),
|
||||
)
|
||||
|> ignore;
|
||||
};
|
||||
iis;
|
||||
});
|
||||
let components: result(array(node), string) =
|
||||
withWeights
|
||||
|> E.R.fmap(
|
||||
E.A.fmap(((dist, weight)) =>
|
||||
`FunctionCall((
|
||||
"scaleMultiply",
|
||||
[|dist, `SymbolicDist(`Float(weight))|],
|
||||
))
|
||||
),
|
||||
);
|
||||
let pointwiseSum =
|
||||
components
|
||||
|> E.R.bind(_, r => {
|
||||
E.A.length(r) > 0
|
||||
? Ok(r) : Error("Invalid argument length")
|
||||
})
|
||||
|> E.R.fmap(r =>
|
||||
r
|
||||
|> Js.Array.sliceFrom(1)
|
||||
|> E.A.fold_left(
|
||||
(acc, x) => {`PointwiseCombination((`Add, acc, x))},
|
||||
E.A.unsafe_get(r, 0),
|
||||
)
|
||||
);
|
||||
pointwiseSum;
|
||||
}
|
||||
| _ => Error(""),
|
||||
(),
|
||||
),
|
||||
Multimodal._function
|
||||
|];
|
||||
|
|
|
@ -54,7 +54,6 @@ module TypedValue = {
|
|||
|
||||
// todo: Arrays and hashes
|
||||
let rec fromNodeWithTypeCoercion = (evaluationParams, _type: _type, node) => {
|
||||
Js.log3("With Coersion!", _type, node);
|
||||
switch (_type, node) {
|
||||
| (`Float, _) =>
|
||||
switch (getFloat(node)) {
|
||||
|
@ -76,7 +75,6 @@ module TypedValue = {
|
|||
|> E.A.R.firstErrorOrOpen
|
||||
|> E.R.fmap(r => `Array(r))
|
||||
| (`Named(named), `Hash(r)) =>
|
||||
Js.log3("Named", named, r);
|
||||
let foo =
|
||||
named
|
||||
|> E.A.fmap(((name, intendedType)) =>
|
||||
|
@ -86,7 +84,6 @@ module TypedValue = {
|
|||
ExpressionTypes.ExpressionTree.Hash.getByName(r, name),
|
||||
)
|
||||
);
|
||||
Js.log("Named: part 2");
|
||||
let bar =
|
||||
foo
|
||||
|> E.A.fmap(((name, intendedType, optionNode)) =>
|
||||
|
@ -99,11 +96,33 @@ module TypedValue = {
|
|||
)
|
||||
|> E.A.R.firstErrorOrOpen
|
||||
|> E.R.fmap(r => `Named(r));
|
||||
Js.log3("Named!", foo, bar);
|
||||
bar;
|
||||
| _ => Error("fromNodeWithTypeCoercion error, sorry.")
|
||||
};
|
||||
};
|
||||
|
||||
let toFloat =
|
||||
fun
|
||||
| `Float(x) => Ok(x)
|
||||
| _ => Error("Not a float");
|
||||
|
||||
let toArray =
|
||||
fun
|
||||
| `Array(x) => Ok(x)
|
||||
| _ => Error("Not an array");
|
||||
|
||||
let toNamed =
|
||||
fun
|
||||
| `Named(x) => Ok(x)
|
||||
| _ => Error("Not a named item");
|
||||
|
||||
let toDist =
|
||||
fun
|
||||
| `SamplingDist(`SymbolicDist(c)) => Ok(`SymbolicDist(c))
|
||||
| `SamplingDist(`RenderedDist(c)) => Ok(`RenderedDist(c))
|
||||
| `Float(x) =>
|
||||
Ok(`RenderedDist(SymbolicDist.T.toShape(1000, `Float(x))))
|
||||
| _ => Error("");
|
||||
};
|
||||
|
||||
module Function = {
|
||||
|
|
|
@ -144,7 +144,6 @@ let renderIfNeeded =
|
|||
node
|
||||
|> (
|
||||
fun
|
||||
| `MultiModal(_) as n
|
||||
| `Normalize(_) as n
|
||||
| `SymbolicDist(_) as n => {
|
||||
`Render(n)
|
||||
|
|
|
@ -26,6 +26,9 @@ module FloatFloatMap = {
|
|||
let fmap = (fn, t: t) => Belt.MutableMap.map(t, fn);
|
||||
};
|
||||
|
||||
module Int = {
|
||||
let max = (i1: int, i2: int) => i1 > i2 ? i1 : i2;
|
||||
};
|
||||
/* Utils */
|
||||
module U = {
|
||||
let isEqual = (a, b) => a == b;
|
||||
|
@ -146,6 +149,11 @@ module R = {
|
|||
let fmap = Rationale.Result.fmap;
|
||||
let bind = Rationale.Result.bind;
|
||||
let toExn = Belt.Result.getExn;
|
||||
let default = (default, res: Belt.Result.t('a, 'b)) =>
|
||||
switch (res) {
|
||||
| Ok(r) => r
|
||||
| Error(_) => default
|
||||
};
|
||||
let merge = (a, b) =>
|
||||
switch (a, b) {
|
||||
| (Error(e), _) => Error(e)
|
||||
|
@ -157,6 +165,9 @@ module R = {
|
|||
| Ok(r) => Some(r)
|
||||
| Error(_) => None
|
||||
};
|
||||
|
||||
let errorIfCondition = (errorCondition, errorMessage, r) =>
|
||||
errorCondition(r) ? Error(errorMessage) : Ok(r);
|
||||
};
|
||||
|
||||
let safe_fn_of_string = (fn, s: string): option('a) =>
|
||||
|
@ -263,6 +274,7 @@ module A = {
|
|||
let init = Array.init;
|
||||
let reduce = Belt.Array.reduce;
|
||||
let reducei = Belt.Array.reduceWithIndex;
|
||||
let isEmpty = r => length(r) < 1;
|
||||
let min = a =>
|
||||
get(a, 0)
|
||||
|> O.fmap(first => Belt.Array.reduce(a, first, (i, j) => i < j ? i : j));
|
||||
|
@ -285,6 +297,16 @@ module A = {
|
|||
|> Rationale.Result.return
|
||||
};
|
||||
|
||||
// This zips while taking the longest elements of each array.
|
||||
let zipMaxLength = (array1, array2) => {
|
||||
let maxLength = Int.max(length(array1), length(array2));
|
||||
let result = maxLength |> Belt.Array.makeUninitializedUnsafe;
|
||||
for (i in 0 to maxLength - 1) {
|
||||
Belt.Array.set(result, i, (get(array1, i), get(array2, i))) |> ignore;
|
||||
};
|
||||
result;
|
||||
};
|
||||
|
||||
let asList = (f: list('a) => list('a), r: array('a)) =>
|
||||
r |> to_list |> f |> of_list;
|
||||
/* TODO: Is there a better way of doing this? */
|
||||
|
|
Loading…
Reference in New Issue
Block a user