Minor fix for multimodals

This commit is contained in:
Ozzie Gooen 2020-11-05 17:04:50 -08:00
parent 4c33561b7c
commit 7566c59fef
4 changed files with 24 additions and 19 deletions

View File

@ -304,7 +304,7 @@ let rec toLeaf =
let components = let components =
r r
|> E.A.fmap(((dist, weight)) => |> E.A.fmap(((dist, weight)) =>
`FunctionCall("scaleExp", [|dist, `SymbolicDist(`Float(weight))|])); `FunctionCall("scaleMultiply", [|dist, `SymbolicDist(`Float(weight))|]));
let pointwiseSum = let pointwiseSum =
components components
|> Js.Array.sliceFrom(1) |> Js.Array.sliceFrom(1)

View File

@ -91,14 +91,16 @@ let verticalScaling = (scaleOp, rs, scaleBy) => {
Operation.Scale.toFn(scaleOp, main, secondary); Operation.Scale.toFn(scaleOp, main, secondary);
let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(scaleOp); let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(scaleOp);
let integralCacheFn = Operation.Scale.toIntegralCacheFn(scaleOp); let integralCacheFn = Operation.Scale.toIntegralCacheFn(scaleOp);
Ok(`RenderedDist( Ok(
Shape.T.mapY( `RenderedDist(
~integralSumCacheFn=integralSumCacheFn(scaleBy), Shape.T.mapY(
~integralCacheFn=integralCacheFn(scaleBy), ~integralSumCacheFn=integralSumCacheFn(scaleBy),
~fn=fn(scaleBy), ~integralCacheFn=integralCacheFn(scaleBy),
rs, ~fn=fn(scaleBy),
rs,
),
), ),
)); );
}; };
let functions = [| let functions = [|
@ -152,7 +154,8 @@ let functions = [|
~run= ~run=
fun fun
| [|`SamplingDist(`SymbolicDist(c))|] => Ok(`SymbolicDist(c)) | [|`SamplingDist(`SymbolicDist(c))|] => Ok(`SymbolicDist(c))
| [|`SamplingDist(`RenderedDist(c))|] => Ok(`RenderedDist(Shape.T.normalize(c))) | [|`SamplingDist(`RenderedDist(c))|] =>
Ok(`RenderedDist(Shape.T.normalize(c)))
| e => wrongInputsError(e), | e => wrongInputsError(e),
), ),
makeRenderedDistFloat("scaleExp", (dist, float) => makeRenderedDistFloat("scaleExp", (dist, float) =>

View File

@ -22,7 +22,6 @@ let fnn =
) { ) {
| (_, Some(`Function(argNames, tt))) => | (_, Some(`Function(argNames, tt))) =>
PTypes.Function.run(evaluationParams, args, (argNames, tt)) PTypes.Function.run(evaluationParams, args, (argNames, tt))
| ("mm", _)
| ("multimodal", _) => | ("multimodal", _) =>
switch (args |> E.A.to_list) { switch (args |> E.A.to_list) {
| [`Array(weights), ...dists] => | [`Array(weights), ...dists] =>

View File

@ -204,7 +204,8 @@ module MathAdtToDistDst = {
let parseArgs = () => parseArray(args); let parseArgs = () => parseArray(args);
switch (name) { switch (name) {
| "lognormal" => lognormal(args, parseArgs, nodeParser) | "lognormal" => lognormal(args, parseArgs, nodeParser)
| "mm" =>{ | "multimodal"
| "mm" =>
let weights = let weights =
args args
|> E.A.last |> E.A.last
@ -212,20 +213,22 @@ module MathAdtToDistDst = {
_, _,
fun fun
| Array(values) => Some(parseArray(values)) | Array(values) => Some(parseArray(values))
| _ => None | _ => None,
); );
let possibleDists = let possibleDists =
E.O.isSome(weights) E.O.isSome(weights)
? Belt.Array.slice(args, ~offset=0, ~len=E.A.length(args) - 1) ? Belt.Array.slice(args, ~offset=0, ~len=E.A.length(args) - 1)
: args; : args;
let dists = parseArray(possibleDists); let dists = parseArray(possibleDists);
switch(weights, dists){ switch (weights, dists) {
| (Some(Error(r)), _) => Error(r) | (Some(Error(r)), _) => Error(r)
| (_, Error(r)) => Error(r) | (_, Error(r)) => Error(r)
| (None, Ok(dists)) => Ok(`FunctionCall("multimodal", dists)) | (None, Ok(dists)) => Ok(`FunctionCall(("multimodal", dists)))
| (Some(Ok(r)), Ok(dists)) => Ok(`FunctionCall("multimodal", E.A.append([|`Array(r)|], dists))) | (Some(Ok(r)), Ok(dists)) =>
} Ok(
} `FunctionCall(("multimodal", E.A.append([|`Array(r)|], dists))),
)
};
| "add" | "add"
| "subtract" | "subtract"
| "multiply" | "multiply"