Simple code version of multimodal, cdf, pdf, sample

This commit is contained in:
Ozzie Gooen 2020-08-10 23:35:21 +01:00
parent f3841a961c
commit 31e4f97820
6 changed files with 125 additions and 87 deletions

View File

@ -32,7 +32,9 @@ let rec toString: node => string =
"[Function: ("
++ (args |> Js.String.concatMany(_, ","))
++ toString(internal)
++ ")]";
++ ")]"
| `Array(args) => "Array"
| `MultiModal(args) => "Multimodal"
let toShape = (samplingInputs, environment, node: node) => {
switch (toLeaf(samplingInputs, environment, node)) {

View File

@ -309,6 +309,11 @@ let rec toLeaf =
| `SymbolicDist(_)
| `Function(_)
| `RenderedDist(_) => Ok(node)
| `Array(args) =>
args
|> E.A.fmap(toLeaf(evaluationParams))
|> E.A.R.firstErrorOrOpen
|> E.R.fmap(r => `Array(r))
// Operations nevaluationParamsd to be turned into leaves
| `AlgebraicCombination(algebraicOp, t1, t2) =>
AlgebraicCombination.operationToLeaf(
@ -341,5 +346,23 @@ let rec toLeaf =
|> E.R.bind(_, toLeaf(evaluationParams))
| `FunctionCall(name, args) =>
callableFunction(evaluationParams, name, args)
| `MultiModal(r) =>
let components =
r
|> E.A.fmap(((dist, weight)) =>
`VerticalScaling((
`Multiply,
dist,
`SymbolicDist(`Float(weight)),
))
);
let pointwiseSum =
components
|> Js.Array.sliceFrom(1)
|> E.A.fold_left(
(acc, x) => {`PointwiseCombination((`Add, acc, x))},
E.A.unsafe_get(components, 0),
);
Ok(`Render(`Normalize(pointwiseSum)));
};
};

View File

@ -29,6 +29,8 @@ module ExpressionTree = {
| `Truncate(option(float), option(float), node)
| `FloatFromDist(distToFloatOperation, node)
| `FunctionCall(string, array(node))
| `Array(array(node))
| `MultiModal(array((node, float)))
];
// Have nil as option
let getFloat = (node:node) => node |> fun

View File

@ -86,50 +86,98 @@ let fnn =
| _ => Error("Needs 3 valid arguments")
}
| ("to", _) => apply2(twoFloats(to_), args)
| ("pdf", _) => switch(args){
| [|fst,snd|] => {
switch(PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams,fst),getFloat(snd)){
| (Ok(fst), Some(flt)) => Ok(`FloatFromDist(`Pdf(flt), fst))
| _ => Error("Incorrect arguments")
| ("pdf", _) =>
switch (args) {
| [|fst, snd|] =>
switch (
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
evaluationParams,
fst,
),
getFloat(snd),
) {
| (Ok(fst), Some(flt)) => Ok(`FloatFromDist((`Pdf(flt), fst)))
| _ => Error("Incorrect arguments")
}
}
| _ => Error("Needs two args")
}
| ("inv", _) => switch(args){
| [|fst,snd|] => {
switch(PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams,fst),getFloat(snd)){
| (Ok(fst), Some(flt)) => Ok(`FloatFromDist(`Inv(flt), fst))
| _ => Error("Incorrect arguments")
}
| ("inv", _) =>
switch (args) {
| [|fst, snd|] =>
switch (
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
evaluationParams,
fst,
),
getFloat(snd),
) {
| (Ok(fst), Some(flt)) => Ok(`FloatFromDist((`Inv(flt), fst)))
| _ => Error("Incorrect arguments")
}
}
| _ => Error("Needs two args")
}
| ("cdf", _) => switch(args){
| [|fst,snd|] => {
switch(PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams,fst),getFloat(snd)){
| (Ok(fst), Some(flt)) => Ok(`FloatFromDist(`Cdf(flt), fst))
| _ => Error("Incorrect arguments")
}
| ("cdf", _) =>
switch (args) {
| [|fst, snd|] =>
switch (
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
evaluationParams,
fst,
),
getFloat(snd),
) {
| (Ok(fst), Some(flt)) => Ok(`FloatFromDist((`Cdf(flt), fst)))
| _ => Error("Incorrect arguments")
}
}
| _ => Error("Needs two args")
}
| ("mean", _) => switch(args){
| [|fst|] => {
switch(PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams,fst)){
| (Ok(fst)) => Ok(`FloatFromDist(`Mean,fst))
| _ => Error("Incorrect arguments")
}
| ("mean", _) =>
switch (args) {
| [|fst|] =>
switch (
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
evaluationParams,
fst,
)
) {
| Ok(fst) => Ok(`FloatFromDist((`Mean, fst)))
| _ => Error("Incorrect arguments")
}
}
| _ => Error("Needs two args")
}
| ("sample", _) => switch(args){
| [|fst|] => {
switch(PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams,fst)){
| (Ok(fst)) => Ok(`FloatFromDist(`Sample,fst))
| _ => Error("Incorrect arguments")
}
| ("sample", _) =>
switch (args) {
| [|fst|] =>
switch (
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
evaluationParams,
fst,
)
) {
| Ok(fst) => Ok(`FloatFromDist((`Sample, fst)))
| _ => Error("Incorrect arguments")
}
}
| _ => Error("Needs two args")
}
}
| ("mm", _)
| ("multimodal", _) =>
switch (args |> E.A.to_list) {
| [`Array(weights), ...dists] =>
let withWeights =
dists
|> E.L.toArray
|> E.A.fmapi((index, t) => {
let w =
weights
|> E.A.get(_, index)
|> E.O.bind(_, getFloat)
|> E.O.default(1.0);
(t, w);
});
Ok(`MultiModal(withWeights));
| dists when E.L.length(dists) > 0 =>
Ok(`MultiModal(dists |> E.L.toArray |> E.A.fmap(r => (r, 1.0))))
| _ => Error("Needs at least one distribution")
}
| _ => Error("Function " ++ name ++ " not found")
};

View File

@ -138,40 +138,6 @@ module MathAdtToDistDst = {
)
};
let multiModal =
(
args: array(result(ExpressionTypes.ExpressionTree.node, string)),
weights: option(array(float)),
) => {
let weights = weights |> E.O.default([||]);
let firstWithError = args |> Belt.Array.getBy(_, Belt.Result.isError);
let withoutErrors = args |> E.A.fmap(E.R.toOption) |> E.A.O.concatSomes;
switch (firstWithError) {
| Some(Error(e)) => Error(e)
| None when withoutErrors |> E.A.length == 0 =>
Error("Multimodals need at least one input")
| _ =>
let components =
withoutErrors
|> E.A.fmapi((index, t) => {
let w = weights |> E.A.get(_, index) |> E.O.default(1.0);
`VerticalScaling((`Multiply, t, `SymbolicDist(`Float(w))));
});
let pointwiseSum =
components
|> Js.Array.sliceFrom(1)
|> E.A.fold_left(
(acc, x) => {`PointwiseCombination((`Add, acc, x))},
E.A.unsafe_get(components, 0),
);
Ok(`Normalize(pointwiseSum));
};
};
// Error("Dotwise exponentiation needs two operands")
let operationParser =
(
@ -182,7 +148,6 @@ module MathAdtToDistDst = {
let toOkAlgebraic = r => Ok(`AlgebraicCombination(r));
let toOkPointwise = r => Ok(`PointwiseCombination(r));
let toOkTruncate = r => Ok(`Truncate(r));
let toOkFloatFromDist = r => Ok(`FloatFromDist(r));
args
|> E.R.bind(_, args => {
switch (name, args) {
@ -247,31 +212,28 @@ module MathAdtToDistDst = {
let parseArgs = () => parseArray(args);
switch (name) {
| "lognormal" => lognormal(args, parseArgs, nodeParser)
| "mm" =>
| "mm" =>{
let weights =
args
|> E.A.last
|> E.O.bind(
_,
fun
| Array(values) => Some(values)
| _ => None,
)
|> E.O.fmap(o =>
o
|> E.A.fmap(
fun
| Value(r) => Some(r)
| _ => None,
)
|> E.A.O.concatSomes
| Array(values) => Some(parseArray(values))
| _ => None
);
let possibleDists =
E.O.isSome(weights)
? Belt.Array.slice(args, ~offset=0, ~len=E.A.length(args) - 1)
: args;
let dists = possibleDists |> E.A.fmap(nodeParser);
multiModal(dists, weights);
let dists = parseArray(possibleDists);
switch(weights, dists){
| (Some(Error(r)), _) => Error(r)
| (_, Error(r)) => Error(r)
| (None, Ok(dists)) => Ok(`FunctionCall("multimodal", dists))
| (Some(Ok(r)), Ok(dists)) => Ok(`FunctionCall("multimodal", E.A.append([|`Array(r)|], dists)))
}
}
| "add"
| "subtract"
| "multiply"

View File

@ -141,8 +141,9 @@ let renderIfNeeded =
node
|> (
fun
| `SymbolicDist(n) => {
`Render(`SymbolicDist(n))
| `MultiModal(_) as n
| `SymbolicDist(_) as n => {
`Render(n)
|> Internals.runNode(Internals.distPlusRenderInputsToInputs(inputs))
|> (
fun