Simple code version of multimodal, cdf, pdf, sample
This commit is contained in:
parent
f3841a961c
commit
31e4f97820
|
@ -32,7 +32,9 @@ let rec toString: node => string =
|
||||||
"[Function: ("
|
"[Function: ("
|
||||||
++ (args |> Js.String.concatMany(_, ","))
|
++ (args |> Js.String.concatMany(_, ","))
|
||||||
++ toString(internal)
|
++ toString(internal)
|
||||||
++ ")]";
|
++ ")]"
|
||||||
|
| `Array(args) => "Array"
|
||||||
|
| `MultiModal(args) => "Multimodal"
|
||||||
|
|
||||||
let toShape = (samplingInputs, environment, node: node) => {
|
let toShape = (samplingInputs, environment, node: node) => {
|
||||||
switch (toLeaf(samplingInputs, environment, node)) {
|
switch (toLeaf(samplingInputs, environment, node)) {
|
||||||
|
|
|
@ -309,6 +309,11 @@ let rec toLeaf =
|
||||||
| `SymbolicDist(_)
|
| `SymbolicDist(_)
|
||||||
| `Function(_)
|
| `Function(_)
|
||||||
| `RenderedDist(_) => Ok(node)
|
| `RenderedDist(_) => Ok(node)
|
||||||
|
| `Array(args) =>
|
||||||
|
args
|
||||||
|
|> E.A.fmap(toLeaf(evaluationParams))
|
||||||
|
|> E.A.R.firstErrorOrOpen
|
||||||
|
|> E.R.fmap(r => `Array(r))
|
||||||
// Operations nevaluationParamsd to be turned into leaves
|
// Operations nevaluationParamsd to be turned into leaves
|
||||||
| `AlgebraicCombination(algebraicOp, t1, t2) =>
|
| `AlgebraicCombination(algebraicOp, t1, t2) =>
|
||||||
AlgebraicCombination.operationToLeaf(
|
AlgebraicCombination.operationToLeaf(
|
||||||
|
@ -341,5 +346,23 @@ let rec toLeaf =
|
||||||
|> E.R.bind(_, toLeaf(evaluationParams))
|
|> E.R.bind(_, toLeaf(evaluationParams))
|
||||||
| `FunctionCall(name, args) =>
|
| `FunctionCall(name, args) =>
|
||||||
callableFunction(evaluationParams, name, args)
|
callableFunction(evaluationParams, name, args)
|
||||||
|
| `MultiModal(r) =>
|
||||||
|
let components =
|
||||||
|
r
|
||||||
|
|> E.A.fmap(((dist, weight)) =>
|
||||||
|
`VerticalScaling((
|
||||||
|
`Multiply,
|
||||||
|
dist,
|
||||||
|
`SymbolicDist(`Float(weight)),
|
||||||
|
))
|
||||||
|
);
|
||||||
|
let pointwiseSum =
|
||||||
|
components
|
||||||
|
|> Js.Array.sliceFrom(1)
|
||||||
|
|> E.A.fold_left(
|
||||||
|
(acc, x) => {`PointwiseCombination((`Add, acc, x))},
|
||||||
|
E.A.unsafe_get(components, 0),
|
||||||
|
);
|
||||||
|
Ok(`Render(`Normalize(pointwiseSum)));
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
|
@ -29,6 +29,8 @@ module ExpressionTree = {
|
||||||
| `Truncate(option(float), option(float), node)
|
| `Truncate(option(float), option(float), node)
|
||||||
| `FloatFromDist(distToFloatOperation, node)
|
| `FloatFromDist(distToFloatOperation, node)
|
||||||
| `FunctionCall(string, array(node))
|
| `FunctionCall(string, array(node))
|
||||||
|
| `Array(array(node))
|
||||||
|
| `MultiModal(array((node, float)))
|
||||||
];
|
];
|
||||||
// Have nil as option
|
// Have nil as option
|
||||||
let getFloat = (node:node) => node |> fun
|
let getFloat = (node:node) => node |> fun
|
||||||
|
|
|
@ -86,50 +86,98 @@ let fnn =
|
||||||
| _ => Error("Needs 3 valid arguments")
|
| _ => Error("Needs 3 valid arguments")
|
||||||
}
|
}
|
||||||
| ("to", _) => apply2(twoFloats(to_), args)
|
| ("to", _) => apply2(twoFloats(to_), args)
|
||||||
| ("pdf", _) => switch(args){
|
| ("pdf", _) =>
|
||||||
| [|fst,snd|] => {
|
switch (args) {
|
||||||
switch(PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams,fst),getFloat(snd)){
|
| [|fst, snd|] =>
|
||||||
| (Ok(fst), Some(flt)) => Ok(`FloatFromDist(`Pdf(flt), fst))
|
switch (
|
||||||
|
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
|
||||||
|
evaluationParams,
|
||||||
|
fst,
|
||||||
|
),
|
||||||
|
getFloat(snd),
|
||||||
|
) {
|
||||||
|
| (Ok(fst), Some(flt)) => Ok(`FloatFromDist((`Pdf(flt), fst)))
|
||||||
| _ => Error("Incorrect arguments")
|
| _ => Error("Incorrect arguments")
|
||||||
}
|
}
|
||||||
}
|
|
||||||
| _ => Error("Needs two args")
|
| _ => Error("Needs two args")
|
||||||
}
|
}
|
||||||
| ("inv", _) => switch(args){
|
| ("inv", _) =>
|
||||||
| [|fst,snd|] => {
|
switch (args) {
|
||||||
switch(PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams,fst),getFloat(snd)){
|
| [|fst, snd|] =>
|
||||||
| (Ok(fst), Some(flt)) => Ok(`FloatFromDist(`Inv(flt), fst))
|
switch (
|
||||||
|
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
|
||||||
|
evaluationParams,
|
||||||
|
fst,
|
||||||
|
),
|
||||||
|
getFloat(snd),
|
||||||
|
) {
|
||||||
|
| (Ok(fst), Some(flt)) => Ok(`FloatFromDist((`Inv(flt), fst)))
|
||||||
| _ => Error("Incorrect arguments")
|
| _ => Error("Incorrect arguments")
|
||||||
}
|
}
|
||||||
}
|
|
||||||
| _ => Error("Needs two args")
|
| _ => Error("Needs two args")
|
||||||
}
|
}
|
||||||
| ("cdf", _) => switch(args){
|
| ("cdf", _) =>
|
||||||
| [|fst,snd|] => {
|
switch (args) {
|
||||||
switch(PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams,fst),getFloat(snd)){
|
| [|fst, snd|] =>
|
||||||
| (Ok(fst), Some(flt)) => Ok(`FloatFromDist(`Cdf(flt), fst))
|
switch (
|
||||||
|
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
|
||||||
|
evaluationParams,
|
||||||
|
fst,
|
||||||
|
),
|
||||||
|
getFloat(snd),
|
||||||
|
) {
|
||||||
|
| (Ok(fst), Some(flt)) => Ok(`FloatFromDist((`Cdf(flt), fst)))
|
||||||
| _ => Error("Incorrect arguments")
|
| _ => Error("Incorrect arguments")
|
||||||
}
|
}
|
||||||
}
|
|
||||||
| _ => Error("Needs two args")
|
| _ => Error("Needs two args")
|
||||||
}
|
}
|
||||||
| ("mean", _) => switch(args){
|
| ("mean", _) =>
|
||||||
| [|fst|] => {
|
switch (args) {
|
||||||
switch(PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams,fst)){
|
| [|fst|] =>
|
||||||
| (Ok(fst)) => Ok(`FloatFromDist(`Mean,fst))
|
switch (
|
||||||
|
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
|
||||||
|
evaluationParams,
|
||||||
|
fst,
|
||||||
|
)
|
||||||
|
) {
|
||||||
|
| Ok(fst) => Ok(`FloatFromDist((`Mean, fst)))
|
||||||
| _ => Error("Incorrect arguments")
|
| _ => Error("Incorrect arguments")
|
||||||
}
|
}
|
||||||
}
|
|
||||||
| _ => Error("Needs two args")
|
| _ => Error("Needs two args")
|
||||||
}
|
}
|
||||||
| ("sample", _) => switch(args){
|
| ("sample", _) =>
|
||||||
| [|fst|] => {
|
switch (args) {
|
||||||
switch(PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams,fst)){
|
| [|fst|] =>
|
||||||
| (Ok(fst)) => Ok(`FloatFromDist(`Sample,fst))
|
switch (
|
||||||
|
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
|
||||||
|
evaluationParams,
|
||||||
|
fst,
|
||||||
|
)
|
||||||
|
) {
|
||||||
|
| Ok(fst) => Ok(`FloatFromDist((`Sample, fst)))
|
||||||
| _ => Error("Incorrect arguments")
|
| _ => Error("Incorrect arguments")
|
||||||
}
|
}
|
||||||
}
|
|
||||||
| _ => Error("Needs two args")
|
| _ => Error("Needs two args")
|
||||||
}
|
}
|
||||||
|
| ("mm", _)
|
||||||
|
| ("multimodal", _) =>
|
||||||
|
switch (args |> E.A.to_list) {
|
||||||
|
| [`Array(weights), ...dists] =>
|
||||||
|
let withWeights =
|
||||||
|
dists
|
||||||
|
|> E.L.toArray
|
||||||
|
|> E.A.fmapi((index, t) => {
|
||||||
|
let w =
|
||||||
|
weights
|
||||||
|
|> E.A.get(_, index)
|
||||||
|
|> E.O.bind(_, getFloat)
|
||||||
|
|> E.O.default(1.0);
|
||||||
|
(t, w);
|
||||||
|
});
|
||||||
|
Ok(`MultiModal(withWeights));
|
||||||
|
| dists when E.L.length(dists) > 0 =>
|
||||||
|
Ok(`MultiModal(dists |> E.L.toArray |> E.A.fmap(r => (r, 1.0))))
|
||||||
|
| _ => Error("Needs at least one distribution")
|
||||||
|
}
|
||||||
| _ => Error("Function " ++ name ++ " not found")
|
| _ => Error("Function " ++ name ++ " not found")
|
||||||
};
|
};
|
||||||
|
|
|
@ -138,40 +138,6 @@ module MathAdtToDistDst = {
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
let multiModal =
|
|
||||||
(
|
|
||||||
args: array(result(ExpressionTypes.ExpressionTree.node, string)),
|
|
||||||
weights: option(array(float)),
|
|
||||||
) => {
|
|
||||||
let weights = weights |> E.O.default([||]);
|
|
||||||
let firstWithError = args |> Belt.Array.getBy(_, Belt.Result.isError);
|
|
||||||
let withoutErrors = args |> E.A.fmap(E.R.toOption) |> E.A.O.concatSomes;
|
|
||||||
|
|
||||||
switch (firstWithError) {
|
|
||||||
| Some(Error(e)) => Error(e)
|
|
||||||
| None when withoutErrors |> E.A.length == 0 =>
|
|
||||||
Error("Multimodals need at least one input")
|
|
||||||
| _ =>
|
|
||||||
let components =
|
|
||||||
withoutErrors
|
|
||||||
|> E.A.fmapi((index, t) => {
|
|
||||||
let w = weights |> E.A.get(_, index) |> E.O.default(1.0);
|
|
||||||
|
|
||||||
`VerticalScaling((`Multiply, t, `SymbolicDist(`Float(w))));
|
|
||||||
});
|
|
||||||
|
|
||||||
let pointwiseSum =
|
|
||||||
components
|
|
||||||
|> Js.Array.sliceFrom(1)
|
|
||||||
|> E.A.fold_left(
|
|
||||||
(acc, x) => {`PointwiseCombination((`Add, acc, x))},
|
|
||||||
E.A.unsafe_get(components, 0),
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(`Normalize(pointwiseSum));
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
// Error("Dotwise exponentiation needs two operands")
|
// Error("Dotwise exponentiation needs two operands")
|
||||||
let operationParser =
|
let operationParser =
|
||||||
(
|
(
|
||||||
|
@ -182,7 +148,6 @@ module MathAdtToDistDst = {
|
||||||
let toOkAlgebraic = r => Ok(`AlgebraicCombination(r));
|
let toOkAlgebraic = r => Ok(`AlgebraicCombination(r));
|
||||||
let toOkPointwise = r => Ok(`PointwiseCombination(r));
|
let toOkPointwise = r => Ok(`PointwiseCombination(r));
|
||||||
let toOkTruncate = r => Ok(`Truncate(r));
|
let toOkTruncate = r => Ok(`Truncate(r));
|
||||||
let toOkFloatFromDist = r => Ok(`FloatFromDist(r));
|
|
||||||
args
|
args
|
||||||
|> E.R.bind(_, args => {
|
|> E.R.bind(_, args => {
|
||||||
switch (name, args) {
|
switch (name, args) {
|
||||||
|
@ -247,31 +212,28 @@ module MathAdtToDistDst = {
|
||||||
let parseArgs = () => parseArray(args);
|
let parseArgs = () => parseArray(args);
|
||||||
switch (name) {
|
switch (name) {
|
||||||
| "lognormal" => lognormal(args, parseArgs, nodeParser)
|
| "lognormal" => lognormal(args, parseArgs, nodeParser)
|
||||||
| "mm" =>
|
| "mm" =>{
|
||||||
let weights =
|
let weights =
|
||||||
args
|
args
|
||||||
|> E.A.last
|
|> E.A.last
|
||||||
|> E.O.bind(
|
|> E.O.bind(
|
||||||
_,
|
_,
|
||||||
fun
|
fun
|
||||||
| Array(values) => Some(values)
|
| Array(values) => Some(parseArray(values))
|
||||||
| _ => None,
|
| _ => None
|
||||||
)
|
|
||||||
|> E.O.fmap(o =>
|
|
||||||
o
|
|
||||||
|> E.A.fmap(
|
|
||||||
fun
|
|
||||||
| Value(r) => Some(r)
|
|
||||||
| _ => None,
|
|
||||||
)
|
|
||||||
|> E.A.O.concatSomes
|
|
||||||
);
|
);
|
||||||
let possibleDists =
|
let possibleDists =
|
||||||
E.O.isSome(weights)
|
E.O.isSome(weights)
|
||||||
? Belt.Array.slice(args, ~offset=0, ~len=E.A.length(args) - 1)
|
? Belt.Array.slice(args, ~offset=0, ~len=E.A.length(args) - 1)
|
||||||
: args;
|
: args;
|
||||||
let dists = possibleDists |> E.A.fmap(nodeParser);
|
let dists = parseArray(possibleDists);
|
||||||
multiModal(dists, weights);
|
switch(weights, dists){
|
||||||
|
| (Some(Error(r)), _) => Error(r)
|
||||||
|
| (_, Error(r)) => Error(r)
|
||||||
|
| (None, Ok(dists)) => Ok(`FunctionCall("multimodal", dists))
|
||||||
|
| (Some(Ok(r)), Ok(dists)) => Ok(`FunctionCall("multimodal", E.A.append([|`Array(r)|], dists)))
|
||||||
|
}
|
||||||
|
}
|
||||||
| "add"
|
| "add"
|
||||||
| "subtract"
|
| "subtract"
|
||||||
| "multiply"
|
| "multiply"
|
||||||
|
|
|
@ -141,8 +141,9 @@ let renderIfNeeded =
|
||||||
node
|
node
|
||||||
|> (
|
|> (
|
||||||
fun
|
fun
|
||||||
| `SymbolicDist(n) => {
|
| `MultiModal(_) as n
|
||||||
`Render(`SymbolicDist(n))
|
| `SymbolicDist(_) as n => {
|
||||||
|
`Render(n)
|
||||||
|> Internals.runNode(Internals.distPlusRenderInputsToInputs(inputs))
|
|> Internals.runNode(Internals.distPlusRenderInputsToInputs(inputs))
|
||||||
|> (
|
|> (
|
||||||
fun
|
fun
|
||||||
|
|
Loading…
Reference in New Issue
Block a user