Simple code version of multimodal, cdf, pdf, sample
This commit is contained in:
parent
f3841a961c
commit
31e4f97820
|
@ -32,7 +32,9 @@ let rec toString: node => string =
|
|||
"[Function: ("
|
||||
++ (args |> Js.String.concatMany(_, ","))
|
||||
++ toString(internal)
|
||||
++ ")]";
|
||||
++ ")]"
|
||||
| `Array(args) => "Array"
|
||||
| `MultiModal(args) => "Multimodal"
|
||||
|
||||
let toShape = (samplingInputs, environment, node: node) => {
|
||||
switch (toLeaf(samplingInputs, environment, node)) {
|
||||
|
|
|
@ -309,6 +309,11 @@ let rec toLeaf =
|
|||
| `SymbolicDist(_)
|
||||
| `Function(_)
|
||||
| `RenderedDist(_) => Ok(node)
|
||||
| `Array(args) =>
|
||||
args
|
||||
|> E.A.fmap(toLeaf(evaluationParams))
|
||||
|> E.A.R.firstErrorOrOpen
|
||||
|> E.R.fmap(r => `Array(r))
|
||||
// Operations nevaluationParamsd to be turned into leaves
|
||||
| `AlgebraicCombination(algebraicOp, t1, t2) =>
|
||||
AlgebraicCombination.operationToLeaf(
|
||||
|
@ -341,5 +346,23 @@ let rec toLeaf =
|
|||
|> E.R.bind(_, toLeaf(evaluationParams))
|
||||
| `FunctionCall(name, args) =>
|
||||
callableFunction(evaluationParams, name, args)
|
||||
| `MultiModal(r) =>
|
||||
let components =
|
||||
r
|
||||
|> E.A.fmap(((dist, weight)) =>
|
||||
`VerticalScaling((
|
||||
`Multiply,
|
||||
dist,
|
||||
`SymbolicDist(`Float(weight)),
|
||||
))
|
||||
);
|
||||
let pointwiseSum =
|
||||
components
|
||||
|> Js.Array.sliceFrom(1)
|
||||
|> E.A.fold_left(
|
||||
(acc, x) => {`PointwiseCombination((`Add, acc, x))},
|
||||
E.A.unsafe_get(components, 0),
|
||||
);
|
||||
Ok(`Render(`Normalize(pointwiseSum)));
|
||||
};
|
||||
};
|
||||
|
|
|
@ -29,6 +29,8 @@ module ExpressionTree = {
|
|||
| `Truncate(option(float), option(float), node)
|
||||
| `FloatFromDist(distToFloatOperation, node)
|
||||
| `FunctionCall(string, array(node))
|
||||
| `Array(array(node))
|
||||
| `MultiModal(array((node, float)))
|
||||
];
|
||||
// Have nil as option
|
||||
let getFloat = (node:node) => node |> fun
|
||||
|
|
|
@ -86,50 +86,98 @@ let fnn =
|
|||
| _ => Error("Needs 3 valid arguments")
|
||||
}
|
||||
| ("to", _) => apply2(twoFloats(to_), args)
|
||||
| ("pdf", _) => switch(args){
|
||||
| [|fst,snd|] => {
|
||||
switch(PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams,fst),getFloat(snd)){
|
||||
| (Ok(fst), Some(flt)) => Ok(`FloatFromDist(`Pdf(flt), fst))
|
||||
| _ => Error("Incorrect arguments")
|
||||
| ("pdf", _) =>
|
||||
switch (args) {
|
||||
| [|fst, snd|] =>
|
||||
switch (
|
||||
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
|
||||
evaluationParams,
|
||||
fst,
|
||||
),
|
||||
getFloat(snd),
|
||||
) {
|
||||
| (Ok(fst), Some(flt)) => Ok(`FloatFromDist((`Pdf(flt), fst)))
|
||||
| _ => Error("Incorrect arguments")
|
||||
}
|
||||
}
|
||||
| _ => Error("Needs two args")
|
||||
}
|
||||
| ("inv", _) => switch(args){
|
||||
| [|fst,snd|] => {
|
||||
switch(PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams,fst),getFloat(snd)){
|
||||
| (Ok(fst), Some(flt)) => Ok(`FloatFromDist(`Inv(flt), fst))
|
||||
| _ => Error("Incorrect arguments")
|
||||
}
|
||||
| ("inv", _) =>
|
||||
switch (args) {
|
||||
| [|fst, snd|] =>
|
||||
switch (
|
||||
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
|
||||
evaluationParams,
|
||||
fst,
|
||||
),
|
||||
getFloat(snd),
|
||||
) {
|
||||
| (Ok(fst), Some(flt)) => Ok(`FloatFromDist((`Inv(flt), fst)))
|
||||
| _ => Error("Incorrect arguments")
|
||||
}
|
||||
}
|
||||
| _ => Error("Needs two args")
|
||||
}
|
||||
| ("cdf", _) => switch(args){
|
||||
| [|fst,snd|] => {
|
||||
switch(PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams,fst),getFloat(snd)){
|
||||
| (Ok(fst), Some(flt)) => Ok(`FloatFromDist(`Cdf(flt), fst))
|
||||
| _ => Error("Incorrect arguments")
|
||||
}
|
||||
| ("cdf", _) =>
|
||||
switch (args) {
|
||||
| [|fst, snd|] =>
|
||||
switch (
|
||||
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
|
||||
evaluationParams,
|
||||
fst,
|
||||
),
|
||||
getFloat(snd),
|
||||
) {
|
||||
| (Ok(fst), Some(flt)) => Ok(`FloatFromDist((`Cdf(flt), fst)))
|
||||
| _ => Error("Incorrect arguments")
|
||||
}
|
||||
}
|
||||
| _ => Error("Needs two args")
|
||||
}
|
||||
| ("mean", _) => switch(args){
|
||||
| [|fst|] => {
|
||||
switch(PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams,fst)){
|
||||
| (Ok(fst)) => Ok(`FloatFromDist(`Mean,fst))
|
||||
| _ => Error("Incorrect arguments")
|
||||
}
|
||||
| ("mean", _) =>
|
||||
switch (args) {
|
||||
| [|fst|] =>
|
||||
switch (
|
||||
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
|
||||
evaluationParams,
|
||||
fst,
|
||||
)
|
||||
) {
|
||||
| Ok(fst) => Ok(`FloatFromDist((`Mean, fst)))
|
||||
| _ => Error("Incorrect arguments")
|
||||
}
|
||||
}
|
||||
| _ => Error("Needs two args")
|
||||
}
|
||||
| ("sample", _) => switch(args){
|
||||
| [|fst|] => {
|
||||
switch(PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams,fst)){
|
||||
| (Ok(fst)) => Ok(`FloatFromDist(`Sample,fst))
|
||||
| _ => Error("Incorrect arguments")
|
||||
}
|
||||
| ("sample", _) =>
|
||||
switch (args) {
|
||||
| [|fst|] =>
|
||||
switch (
|
||||
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
|
||||
evaluationParams,
|
||||
fst,
|
||||
)
|
||||
) {
|
||||
| Ok(fst) => Ok(`FloatFromDist((`Sample, fst)))
|
||||
| _ => Error("Incorrect arguments")
|
||||
}
|
||||
}
|
||||
| _ => Error("Needs two args")
|
||||
}
|
||||
}
|
||||
| ("mm", _)
|
||||
| ("multimodal", _) =>
|
||||
switch (args |> E.A.to_list) {
|
||||
| [`Array(weights), ...dists] =>
|
||||
let withWeights =
|
||||
dists
|
||||
|> E.L.toArray
|
||||
|> E.A.fmapi((index, t) => {
|
||||
let w =
|
||||
weights
|
||||
|> E.A.get(_, index)
|
||||
|> E.O.bind(_, getFloat)
|
||||
|> E.O.default(1.0);
|
||||
(t, w);
|
||||
});
|
||||
Ok(`MultiModal(withWeights));
|
||||
| dists when E.L.length(dists) > 0 =>
|
||||
Ok(`MultiModal(dists |> E.L.toArray |> E.A.fmap(r => (r, 1.0))))
|
||||
| _ => Error("Needs at least one distribution")
|
||||
}
|
||||
| _ => Error("Function " ++ name ++ " not found")
|
||||
};
|
||||
|
|
|
@ -138,40 +138,6 @@ module MathAdtToDistDst = {
|
|||
)
|
||||
};
|
||||
|
||||
let multiModal =
|
||||
(
|
||||
args: array(result(ExpressionTypes.ExpressionTree.node, string)),
|
||||
weights: option(array(float)),
|
||||
) => {
|
||||
let weights = weights |> E.O.default([||]);
|
||||
let firstWithError = args |> Belt.Array.getBy(_, Belt.Result.isError);
|
||||
let withoutErrors = args |> E.A.fmap(E.R.toOption) |> E.A.O.concatSomes;
|
||||
|
||||
switch (firstWithError) {
|
||||
| Some(Error(e)) => Error(e)
|
||||
| None when withoutErrors |> E.A.length == 0 =>
|
||||
Error("Multimodals need at least one input")
|
||||
| _ =>
|
||||
let components =
|
||||
withoutErrors
|
||||
|> E.A.fmapi((index, t) => {
|
||||
let w = weights |> E.A.get(_, index) |> E.O.default(1.0);
|
||||
|
||||
`VerticalScaling((`Multiply, t, `SymbolicDist(`Float(w))));
|
||||
});
|
||||
|
||||
let pointwiseSum =
|
||||
components
|
||||
|> Js.Array.sliceFrom(1)
|
||||
|> E.A.fold_left(
|
||||
(acc, x) => {`PointwiseCombination((`Add, acc, x))},
|
||||
E.A.unsafe_get(components, 0),
|
||||
);
|
||||
|
||||
Ok(`Normalize(pointwiseSum));
|
||||
};
|
||||
};
|
||||
|
||||
// Error("Dotwise exponentiation needs two operands")
|
||||
let operationParser =
|
||||
(
|
||||
|
@ -182,7 +148,6 @@ module MathAdtToDistDst = {
|
|||
let toOkAlgebraic = r => Ok(`AlgebraicCombination(r));
|
||||
let toOkPointwise = r => Ok(`PointwiseCombination(r));
|
||||
let toOkTruncate = r => Ok(`Truncate(r));
|
||||
let toOkFloatFromDist = r => Ok(`FloatFromDist(r));
|
||||
args
|
||||
|> E.R.bind(_, args => {
|
||||
switch (name, args) {
|
||||
|
@ -247,31 +212,28 @@ module MathAdtToDistDst = {
|
|||
let parseArgs = () => parseArray(args);
|
||||
switch (name) {
|
||||
| "lognormal" => lognormal(args, parseArgs, nodeParser)
|
||||
| "mm" =>
|
||||
| "mm" =>{
|
||||
let weights =
|
||||
args
|
||||
|> E.A.last
|
||||
|> E.O.bind(
|
||||
_,
|
||||
fun
|
||||
| Array(values) => Some(values)
|
||||
| _ => None,
|
||||
)
|
||||
|> E.O.fmap(o =>
|
||||
o
|
||||
|> E.A.fmap(
|
||||
fun
|
||||
| Value(r) => Some(r)
|
||||
| _ => None,
|
||||
)
|
||||
|> E.A.O.concatSomes
|
||||
| Array(values) => Some(parseArray(values))
|
||||
| _ => None
|
||||
);
|
||||
let possibleDists =
|
||||
E.O.isSome(weights)
|
||||
? Belt.Array.slice(args, ~offset=0, ~len=E.A.length(args) - 1)
|
||||
: args;
|
||||
let dists = possibleDists |> E.A.fmap(nodeParser);
|
||||
multiModal(dists, weights);
|
||||
let dists = parseArray(possibleDists);
|
||||
switch(weights, dists){
|
||||
| (Some(Error(r)), _) => Error(r)
|
||||
| (_, Error(r)) => Error(r)
|
||||
| (None, Ok(dists)) => Ok(`FunctionCall("multimodal", dists))
|
||||
| (Some(Ok(r)), Ok(dists)) => Ok(`FunctionCall("multimodal", E.A.append([|`Array(r)|], dists)))
|
||||
}
|
||||
}
|
||||
| "add"
|
||||
| "subtract"
|
||||
| "multiply"
|
||||
|
|
|
@ -141,8 +141,9 @@ let renderIfNeeded =
|
|||
node
|
||||
|> (
|
||||
fun
|
||||
| `SymbolicDist(n) => {
|
||||
`Render(`SymbolicDist(n))
|
||||
| `MultiModal(_) as n
|
||||
| `SymbolicDist(_) as n => {
|
||||
`Render(n)
|
||||
|> Internals.runNode(Internals.distPlusRenderInputsToInputs(inputs))
|
||||
|> (
|
||||
fun
|
||||
|
|
Loading…
Reference in New Issue
Block a user