Further cleanup of MultiModal functionality

This commit is contained in:
Ozzie Gooen 2020-11-07 23:35:05 -08:00
parent b71d037180
commit a6051d8371
7 changed files with 126 additions and 105 deletions

View File

@ -24,7 +24,6 @@ let rec toString: node => string =
++ toString(internal) ++ toString(internal)
++ ")]" ++ ")]"
| `Array(_) => "Array" | `Array(_) => "Array"
| `MultiModal(_) => "Multimodal"
| `Hash(_) => "Hash" | `Hash(_) => "Hash"
let envs = (samplingInputs, environment) => { let envs = (samplingInputs, environment) => {

View File

@ -307,6 +307,5 @@ let rec toLeaf =
Js.log3("In function call", name, args); Js.log3("In function call", name, args);
callableFunction(evaluationParams, name, args) callableFunction(evaluationParams, name, args)
|> E.R.bind(_, toLeaf(evaluationParams)); |> E.R.bind(_, toLeaf(evaluationParams));
| `MultiModal(r) => Error("Multimodal?")
}; };
}; };

View File

@ -30,16 +30,18 @@ module ExpressionTree = {
| `Render(node) | `Render(node)
| `Truncate(option(float), option(float), node) | `Truncate(option(float), option(float), node)
| `FunctionCall(string, array(node)) | `FunctionCall(string, array(node))
| `MultiModal(array((node, float)))
]; ];
module Hash = { module Hash = {
type t('a) = array((string, 'a)); type t('a) = array((string, 'a));
let getByName = (t:t('a), name) => let getByName = (t: t('a), name) =>
E.A.getBy(t, ((n, _)) => n == name) |> E.O.fmap(((_, r)) => r); E.A.getBy(t, ((n, _)) => n == name) |> E.O.fmap(((_, r)) => r);
let getByNames = (hash: t('a), names:array(string)) => let getByNameResult = (t: t('a), name) =>
names |> E.A.fmap(name => (name, getByName(hash, name))) getByName(t, name) |> E.O.toResult(name ++ " expected and not found");
let getByNames = (hash: t('a), names: array(string)) =>
names |> E.A.fmap(name => (name, getByName(hash, name)));
}; };
// Have nil as option // Have nil as option
let getFloat = (node: node) => let getFloat = (node: node) =>

View File

@ -107,6 +107,80 @@ let verticalScaling = (scaleOp, rs, scaleBy) => {
); );
}; };
module Multimodal = {
let getByNameResult = ExpressionTypes.ExpressionTree.Hash.getByNameResult;
let _paramsToDistsAndWeights = (r: array(typedValue)) =>
switch (r) {
| [|`Named(r)|] =>
let dists =
getByNameResult(r, "dists")
->E.R.bind(TypeSystem.TypedValue.toArray)
->E.R.bind(r =>
r
|> E.A.fmap(TypeSystem.TypedValue.toDist)
|> E.A.R.firstErrorOrOpen
);
let weights =
getByNameResult(r, "weights")
->E.R.bind(TypeSystem.TypedValue.toArray)
->E.R.bind(r =>
r
|> E.A.fmap(TypeSystem.TypedValue.toFloat)
|> E.A.R.firstErrorOrOpen
);
E.R.merge(dists, weights)
|> E.R.fmap(((a, b)) =>
E.A.zipMaxLength(a, b)
|> E.A.fmap(((a, b)) =>
(a |> E.O.toExn(""), b |> E.O.default(1.0))
)
);
| _ => Error("Needs items")
};
let _runner: array(typedValue) => result(node, string) =
r => {
let paramsToDistsAndWeights =
_paramsToDistsAndWeights(r)
|> E.R.fmap(
E.A.fmap(((dist, weight)) =>
`FunctionCall((
"scaleMultiply",
[|dist, `SymbolicDist(`Float(weight))|],
))
),
);
let pointwiseSum: result(node, string) =
paramsToDistsAndWeights->E.R.bind(
E.R.errorIfCondition(E.A.isEmpty, "Needs one input"),
)
|> E.R.fmap(r =>
r
|> Js.Array.sliceFrom(1)
|> E.A.fold_left(
(acc, x) => {`PointwiseCombination((`Add, acc, x))},
E.A.unsafe_get(r, 0),
)
);
pointwiseSum;
};
let _function =
Function.T.make(
~name="multimodal",
~outputType=`SamplingDistribution,
~inputTypes=[|
`Named([|
("dists", `Array(`SamplingDistribution)),
("weights", `Array(`Float)),
|]),
|],
~run=_runner,
(),
);
};
let functions = [| let functions = [|
makeSymbolicFromTwoFloats("normal", SymbolicDist.Normal.make), makeSymbolicFromTwoFloats("normal", SymbolicDist.Normal.make),
makeSymbolicFromTwoFloats("uniform", SymbolicDist.Uniform.make), makeSymbolicFromTwoFloats("uniform", SymbolicDist.Uniform.make),
@ -175,98 +249,5 @@ let functions = [|
makeRenderedDistFloat("scaleLog", (dist, float) => makeRenderedDistFloat("scaleLog", (dist, float) =>
verticalScaling(`Log, dist, float) verticalScaling(`Log, dist, float)
), ),
Function.T.make( Multimodal._function
~name="multimodal",
~outputType=`SamplingDistribution,
~inputTypes=[|
`Named([|
("dists", `Array(`SamplingDistribution)),
("weights", `Array(`Float)),
|]),
|],
~run=
fun
| [|`Named(r)|] => {
let foo =
(r: TypeSystem.typedValue)
: result(ExpressionTypes.ExpressionTree.node, string) =>
switch (r) {
| `SamplingDist(`SymbolicDist(c)) => Ok(`SymbolicDist(c))
| `SamplingDist(`RenderedDist(c)) => Ok(`RenderedDist(c))
| `Float(x) =>
Ok(`RenderedDist(SymbolicDist.T.toShape(1000, `Float(x))))
| _ => Error("")
};
let weight = (r: TypeSystem.typedValue) =>
switch (r) {
| `Float(x) => Ok(x)
| _ => Error("Wrong Type")
};
let dists =
switch (ExpressionTypes.ExpressionTree.Hash.getByName(r, "dists")) {
| Some(`Array(r)) => r |> E.A.fmap(foo) |> E.A.R.firstErrorOrOpen
| _ => Error("")
};
let weights =
(
switch (
ExpressionTypes.ExpressionTree.Hash.getByName(r, "weights")
) {
| Some(`Array(r)) =>
r |> E.A.fmap(weight) |> E.A.R.firstErrorOrOpen
| _ => Error("")
}
)
|> (
fun
| Ok(r) => r
| _ => [||]
);
let withWeights =
dists
|> E.R.fmap(d => {
let iis =
d |> E.A.length |> Belt.Array.makeUninitializedUnsafe;
for (i in 0 to (d |> E.A.length) - 1) {
Belt.Array.set(
iis,
i,
(
E.A.unsafe_get(d, i),
E.A.get(weights, i) |> E.O.default(1.0),
),
)
|> ignore;
};
iis;
});
let components: result(array(node), string) =
withWeights
|> E.R.fmap(
E.A.fmap(((dist, weight)) =>
`FunctionCall((
"scaleMultiply",
[|dist, `SymbolicDist(`Float(weight))|],
))
),
);
let pointwiseSum =
components
|> E.R.bind(_, r => {
E.A.length(r) > 0
? Ok(r) : Error("Invalid argument length")
})
|> E.R.fmap(r =>
r
|> Js.Array.sliceFrom(1)
|> E.A.fold_left(
(acc, x) => {`PointwiseCombination((`Add, acc, x))},
E.A.unsafe_get(r, 0),
)
);
pointwiseSum;
}
| _ => Error(""),
(),
),
|]; |];

View File

@ -54,7 +54,6 @@ module TypedValue = {
// todo: Arrays and hashes // todo: Arrays and hashes
let rec fromNodeWithTypeCoercion = (evaluationParams, _type: _type, node) => { let rec fromNodeWithTypeCoercion = (evaluationParams, _type: _type, node) => {
Js.log3("With Coersion!", _type, node);
switch (_type, node) { switch (_type, node) {
| (`Float, _) => | (`Float, _) =>
switch (getFloat(node)) { switch (getFloat(node)) {
@ -76,7 +75,6 @@ module TypedValue = {
|> E.A.R.firstErrorOrOpen |> E.A.R.firstErrorOrOpen
|> E.R.fmap(r => `Array(r)) |> E.R.fmap(r => `Array(r))
| (`Named(named), `Hash(r)) => | (`Named(named), `Hash(r)) =>
Js.log3("Named", named, r);
let foo = let foo =
named named
|> E.A.fmap(((name, intendedType)) => |> E.A.fmap(((name, intendedType)) =>
@ -86,7 +84,6 @@ module TypedValue = {
ExpressionTypes.ExpressionTree.Hash.getByName(r, name), ExpressionTypes.ExpressionTree.Hash.getByName(r, name),
) )
); );
Js.log("Named: part 2");
let bar = let bar =
foo foo
|> E.A.fmap(((name, intendedType, optionNode)) => |> E.A.fmap(((name, intendedType, optionNode)) =>
@ -99,11 +96,33 @@ module TypedValue = {
) )
|> E.A.R.firstErrorOrOpen |> E.A.R.firstErrorOrOpen
|> E.R.fmap(r => `Named(r)); |> E.R.fmap(r => `Named(r));
Js.log3("Named!", foo, bar);
bar; bar;
| _ => Error("fromNodeWithTypeCoercion error, sorry.") | _ => Error("fromNodeWithTypeCoercion error, sorry.")
}; };
}; };
let toFloat =
fun
| `Float(x) => Ok(x)
| _ => Error("Not a float");
let toArray =
fun
| `Array(x) => Ok(x)
| _ => Error("Not an array");
let toNamed =
fun
| `Named(x) => Ok(x)
| _ => Error("Not a named item");
let toDist =
fun
| `SamplingDist(`SymbolicDist(c)) => Ok(`SymbolicDist(c))
| `SamplingDist(`RenderedDist(c)) => Ok(`RenderedDist(c))
| `Float(x) =>
Ok(`RenderedDist(SymbolicDist.T.toShape(1000, `Float(x))))
| _ => Error("");
}; };
module Function = { module Function = {

View File

@ -144,7 +144,6 @@ let renderIfNeeded =
node node
|> ( |> (
fun fun
| `MultiModal(_) as n
| `Normalize(_) as n | `Normalize(_) as n
| `SymbolicDist(_) as n => { | `SymbolicDist(_) as n => {
`Render(n) `Render(n)

View File

@ -26,6 +26,9 @@ module FloatFloatMap = {
let fmap = (fn, t: t) => Belt.MutableMap.map(t, fn); let fmap = (fn, t: t) => Belt.MutableMap.map(t, fn);
}; };
module Int = {
let max = (i1: int, i2: int) => i1 > i2 ? i1 : i2;
};
/* Utils */ /* Utils */
module U = { module U = {
let isEqual = (a, b) => a == b; let isEqual = (a, b) => a == b;
@ -146,6 +149,11 @@ module R = {
let fmap = Rationale.Result.fmap; let fmap = Rationale.Result.fmap;
let bind = Rationale.Result.bind; let bind = Rationale.Result.bind;
let toExn = Belt.Result.getExn; let toExn = Belt.Result.getExn;
let default = (default, res: Belt.Result.t('a, 'b)) =>
switch (res) {
| Ok(r) => r
| Error(_) => default
};
let merge = (a, b) => let merge = (a, b) =>
switch (a, b) { switch (a, b) {
| (Error(e), _) => Error(e) | (Error(e), _) => Error(e)
@ -157,6 +165,9 @@ module R = {
| Ok(r) => Some(r) | Ok(r) => Some(r)
| Error(_) => None | Error(_) => None
}; };
let errorIfCondition = (errorCondition, errorMessage, r) =>
errorCondition(r) ? Error(errorMessage) : Ok(r);
}; };
let safe_fn_of_string = (fn, s: string): option('a) => let safe_fn_of_string = (fn, s: string): option('a) =>
@ -263,6 +274,7 @@ module A = {
let init = Array.init; let init = Array.init;
let reduce = Belt.Array.reduce; let reduce = Belt.Array.reduce;
let reducei = Belt.Array.reduceWithIndex; let reducei = Belt.Array.reduceWithIndex;
let isEmpty = r => length(r) < 1;
let min = a => let min = a =>
get(a, 0) get(a, 0)
|> O.fmap(first => Belt.Array.reduce(a, first, (i, j) => i < j ? i : j)); |> O.fmap(first => Belt.Array.reduce(a, first, (i, j) => i < j ? i : j));
@ -285,6 +297,16 @@ module A = {
|> Rationale.Result.return |> Rationale.Result.return
}; };
// This zips while taking the longest elements of each array.
let zipMaxLength = (array1, array2) => {
let maxLength = Int.max(length(array1), length(array2));
let result = maxLength |> Belt.Array.makeUninitializedUnsafe;
for (i in 0 to maxLength - 1) {
Belt.Array.set(result, i, (get(array1, i), get(array2, i))) |> ignore;
};
result;
};
let asList = (f: list('a) => list('a), r: array('a)) => let asList = (f: list('a) => list('a), r: array('a)) =>
r |> to_list |> f |> of_list; r |> to_list |> f |> of_list;
/* TODO: Is there a better way of doing this? */ /* TODO: Is there a better way of doing this? */