Minor fix for multimodals
This commit is contained in:
parent
4c33561b7c
commit
7566c59fef
|
@ -304,7 +304,7 @@ let rec toLeaf =
|
||||||
let components =
|
let components =
|
||||||
r
|
r
|
||||||
|> E.A.fmap(((dist, weight)) =>
|
|> E.A.fmap(((dist, weight)) =>
|
||||||
`FunctionCall("scaleExp", [|dist, `SymbolicDist(`Float(weight))|]));
|
`FunctionCall("scaleMultiply", [|dist, `SymbolicDist(`Float(weight))|]));
|
||||||
let pointwiseSum =
|
let pointwiseSum =
|
||||||
components
|
components
|
||||||
|> Js.Array.sliceFrom(1)
|
|> Js.Array.sliceFrom(1)
|
||||||
|
|
|
@ -91,14 +91,16 @@ let verticalScaling = (scaleOp, rs, scaleBy) => {
|
||||||
Operation.Scale.toFn(scaleOp, main, secondary);
|
Operation.Scale.toFn(scaleOp, main, secondary);
|
||||||
let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(scaleOp);
|
let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(scaleOp);
|
||||||
let integralCacheFn = Operation.Scale.toIntegralCacheFn(scaleOp);
|
let integralCacheFn = Operation.Scale.toIntegralCacheFn(scaleOp);
|
||||||
Ok(`RenderedDist(
|
Ok(
|
||||||
Shape.T.mapY(
|
`RenderedDist(
|
||||||
~integralSumCacheFn=integralSumCacheFn(scaleBy),
|
Shape.T.mapY(
|
||||||
~integralCacheFn=integralCacheFn(scaleBy),
|
~integralSumCacheFn=integralSumCacheFn(scaleBy),
|
||||||
~fn=fn(scaleBy),
|
~integralCacheFn=integralCacheFn(scaleBy),
|
||||||
rs,
|
~fn=fn(scaleBy),
|
||||||
|
rs,
|
||||||
|
),
|
||||||
),
|
),
|
||||||
));
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
let functions = [|
|
let functions = [|
|
||||||
|
@ -152,7 +154,8 @@ let functions = [|
|
||||||
~run=
|
~run=
|
||||||
fun
|
fun
|
||||||
| [|`SamplingDist(`SymbolicDist(c))|] => Ok(`SymbolicDist(c))
|
| [|`SamplingDist(`SymbolicDist(c))|] => Ok(`SymbolicDist(c))
|
||||||
| [|`SamplingDist(`RenderedDist(c))|] => Ok(`RenderedDist(Shape.T.normalize(c)))
|
| [|`SamplingDist(`RenderedDist(c))|] =>
|
||||||
|
Ok(`RenderedDist(Shape.T.normalize(c)))
|
||||||
| e => wrongInputsError(e),
|
| e => wrongInputsError(e),
|
||||||
),
|
),
|
||||||
makeRenderedDistFloat("scaleExp", (dist, float) =>
|
makeRenderedDistFloat("scaleExp", (dist, float) =>
|
||||||
|
|
|
@ -22,7 +22,6 @@ let fnn =
|
||||||
) {
|
) {
|
||||||
| (_, Some(`Function(argNames, tt))) =>
|
| (_, Some(`Function(argNames, tt))) =>
|
||||||
PTypes.Function.run(evaluationParams, args, (argNames, tt))
|
PTypes.Function.run(evaluationParams, args, (argNames, tt))
|
||||||
| ("mm", _)
|
|
||||||
| ("multimodal", _) =>
|
| ("multimodal", _) =>
|
||||||
switch (args |> E.A.to_list) {
|
switch (args |> E.A.to_list) {
|
||||||
| [`Array(weights), ...dists] =>
|
| [`Array(weights), ...dists] =>
|
||||||
|
|
|
@ -204,7 +204,8 @@ module MathAdtToDistDst = {
|
||||||
let parseArgs = () => parseArray(args);
|
let parseArgs = () => parseArray(args);
|
||||||
switch (name) {
|
switch (name) {
|
||||||
| "lognormal" => lognormal(args, parseArgs, nodeParser)
|
| "lognormal" => lognormal(args, parseArgs, nodeParser)
|
||||||
| "mm" =>{
|
| "multimodal"
|
||||||
|
| "mm" =>
|
||||||
let weights =
|
let weights =
|
||||||
args
|
args
|
||||||
|> E.A.last
|
|> E.A.last
|
||||||
|
@ -212,20 +213,22 @@ module MathAdtToDistDst = {
|
||||||
_,
|
_,
|
||||||
fun
|
fun
|
||||||
| Array(values) => Some(parseArray(values))
|
| Array(values) => Some(parseArray(values))
|
||||||
| _ => None
|
| _ => None,
|
||||||
);
|
);
|
||||||
let possibleDists =
|
let possibleDists =
|
||||||
E.O.isSome(weights)
|
E.O.isSome(weights)
|
||||||
? Belt.Array.slice(args, ~offset=0, ~len=E.A.length(args) - 1)
|
? Belt.Array.slice(args, ~offset=0, ~len=E.A.length(args) - 1)
|
||||||
: args;
|
: args;
|
||||||
let dists = parseArray(possibleDists);
|
let dists = parseArray(possibleDists);
|
||||||
switch(weights, dists){
|
switch (weights, dists) {
|
||||||
| (Some(Error(r)), _) => Error(r)
|
| (Some(Error(r)), _) => Error(r)
|
||||||
| (_, Error(r)) => Error(r)
|
| (_, Error(r)) => Error(r)
|
||||||
| (None, Ok(dists)) => Ok(`FunctionCall("multimodal", dists))
|
| (None, Ok(dists)) => Ok(`FunctionCall(("multimodal", dists)))
|
||||||
| (Some(Ok(r)), Ok(dists)) => Ok(`FunctionCall("multimodal", E.A.append([|`Array(r)|], dists)))
|
| (Some(Ok(r)), Ok(dists)) =>
|
||||||
}
|
Ok(
|
||||||
}
|
`FunctionCall(("multimodal", E.A.append([|`Array(r)|], dists))),
|
||||||
|
)
|
||||||
|
};
|
||||||
| "add"
|
| "add"
|
||||||
| "subtract"
|
| "subtract"
|
||||||
| "multiply"
|
| "multiply"
|
||||||
|
|
Loading…
Reference in New Issue
Block a user