From 31e4f978207d4ffe5b6ac0678a236804ed32da4a Mon Sep 17 00:00:00 2001 From: Ozzie Gooen Date: Mon, 10 Aug 2020 23:35:21 +0100 Subject: [PATCH] Simple code version of multimodal, cdf, pdf, sample --- src/distPlus/expressionTree/ExpressionTree.re | 4 +- .../expressionTree/ExpressionTreeEvaluator.re | 23 ++++ .../expressionTree/ExpressionTypes.re | 2 + src/distPlus/expressionTree/Functions.re | 118 ++++++++++++------ src/distPlus/expressionTree/MathJsParser.re | 60 ++------- src/distPlus/renderers/DistPlusRenderer.re | 5 +- 6 files changed, 125 insertions(+), 87 deletions(-) diff --git a/src/distPlus/expressionTree/ExpressionTree.re b/src/distPlus/expressionTree/ExpressionTree.re index 3e7193ab..02b33262 100644 --- a/src/distPlus/expressionTree/ExpressionTree.re +++ b/src/distPlus/expressionTree/ExpressionTree.re @@ -32,7 +32,9 @@ let rec toString: node => string = "[Function: (" ++ (args |> Js.String.concatMany(_, ",")) ++ toString(internal) - ++ ")]"; + ++ ")]" + | `Array(args) => "Array" + | `MultiModal(args) => "Multimodal" let toShape = (samplingInputs, environment, node: node) => { switch (toLeaf(samplingInputs, environment, node)) { diff --git a/src/distPlus/expressionTree/ExpressionTreeEvaluator.re b/src/distPlus/expressionTree/ExpressionTreeEvaluator.re index 996d9863..82182e95 100644 --- a/src/distPlus/expressionTree/ExpressionTreeEvaluator.re +++ b/src/distPlus/expressionTree/ExpressionTreeEvaluator.re @@ -309,6 +309,11 @@ let rec toLeaf = | `SymbolicDist(_) | `Function(_) | `RenderedDist(_) => Ok(node) + | `Array(args) => + args + |> E.A.fmap(toLeaf(evaluationParams)) + |> E.A.R.firstErrorOrOpen + |> E.R.fmap(r => `Array(r)) // Operations nevaluationParamsd to be turned into leaves | `AlgebraicCombination(algebraicOp, t1, t2) => AlgebraicCombination.operationToLeaf( @@ -341,5 +346,23 @@ let rec toLeaf = |> E.R.bind(_, toLeaf(evaluationParams)) | `FunctionCall(name, args) => callableFunction(evaluationParams, name, args) + | `MultiModal(r) => + let components = + r + |> E.A.fmap(((dist, weight)) => + `VerticalScaling(( + `Multiply, + dist, + `SymbolicDist(`Float(weight)), + )) + ); + let pointwiseSum = + components + |> Js.Array.sliceFrom(1) + |> E.A.fold_left( + (acc, x) => {`PointwiseCombination((`Add, acc, x))}, + E.A.unsafe_get(components, 0), + ); + Ok(`Render(`Normalize(pointwiseSum))); }; }; diff --git a/src/distPlus/expressionTree/ExpressionTypes.re b/src/distPlus/expressionTree/ExpressionTypes.re index ec359a4c..a8404eb7 100644 --- a/src/distPlus/expressionTree/ExpressionTypes.re +++ b/src/distPlus/expressionTree/ExpressionTypes.re @@ -29,6 +29,8 @@ module ExpressionTree = { | `Truncate(option(float), option(float), node) | `FloatFromDist(distToFloatOperation, node) | `FunctionCall(string, array(node)) + | `Array(array(node)) + | `MultiModal(array((node, float))) ]; // Have nil as option let getFloat = (node:node) => node |> fun diff --git a/src/distPlus/expressionTree/Functions.re b/src/distPlus/expressionTree/Functions.re index cddc698b..1cdcd73e 100644 --- a/src/distPlus/expressionTree/Functions.re +++ b/src/distPlus/expressionTree/Functions.re @@ -86,50 +86,98 @@ let fnn = | _ => Error("Needs 3 valid arguments") } | ("to", _) => apply2(twoFloats(to_), args) - | ("pdf", _) => switch(args){ - | [|fst,snd|] => { - switch(PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams,fst),getFloat(snd)){ - | (Ok(fst), Some(flt)) => Ok(`FloatFromDist(`Pdf(flt), fst)) - | _ => Error("Incorrect arguments") + | ("pdf", _) => + switch (args) { + | [|fst, snd|] => + switch ( + PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution( + evaluationParams, + fst, + ), + getFloat(snd), + ) { + | (Ok(fst), Some(flt)) => Ok(`FloatFromDist((`Pdf(flt), fst))) + | _ => Error("Incorrect arguments") } - } | _ => Error("Needs two args") - } - | ("inv", _) => switch(args){ - | [|fst,snd|] => { - switch(PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams,fst),getFloat(snd)){ - | (Ok(fst), Some(flt)) => Ok(`FloatFromDist(`Inv(flt), fst)) - | _ => Error("Incorrect arguments") + } + | ("inv", _) => + switch (args) { + | [|fst, snd|] => + switch ( + PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution( + evaluationParams, + fst, + ), + getFloat(snd), + ) { + | (Ok(fst), Some(flt)) => Ok(`FloatFromDist((`Inv(flt), fst))) + | _ => Error("Incorrect arguments") } - } | _ => Error("Needs two args") - } - | ("cdf", _) => switch(args){ - | [|fst,snd|] => { - switch(PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams,fst),getFloat(snd)){ - | (Ok(fst), Some(flt)) => Ok(`FloatFromDist(`Cdf(flt), fst)) - | _ => Error("Incorrect arguments") + } + | ("cdf", _) => + switch (args) { + | [|fst, snd|] => + switch ( + PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution( + evaluationParams, + fst, + ), + getFloat(snd), + ) { + | (Ok(fst), Some(flt)) => Ok(`FloatFromDist((`Cdf(flt), fst))) + | _ => Error("Incorrect arguments") } - } | _ => Error("Needs two args") - } - | ("mean", _) => switch(args){ - | [|fst|] => { - switch(PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams,fst)){ - | (Ok(fst)) => Ok(`FloatFromDist(`Mean,fst)) - | _ => Error("Incorrect arguments") + } + | ("mean", _) => + switch (args) { + | [|fst|] => + switch ( + PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution( + evaluationParams, + fst, + ) + ) { + | Ok(fst) => Ok(`FloatFromDist((`Mean, fst))) + | _ => Error("Incorrect arguments") } - } | _ => Error("Needs two args") - } - | ("sample", _) => switch(args){ - | [|fst|] => { - switch(PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams,fst)){ - | (Ok(fst)) => Ok(`FloatFromDist(`Sample,fst)) - | _ => Error("Incorrect arguments") + } + | ("sample", _) => + switch (args) { + | [|fst|] => + switch ( + PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution( + evaluationParams, + fst, + ) + ) { + | Ok(fst) => Ok(`FloatFromDist((`Sample, fst))) + | _ => Error("Incorrect arguments") } - } | _ => Error("Needs two args") - } + } + | ("mm", _) + | ("multimodal", _) => + switch (args |> E.A.to_list) { + | [`Array(weights), ...dists] => + let withWeights = + dists + |> E.L.toArray + |> E.A.fmapi((index, t) => { + let w = + weights + |> E.A.get(_, index) + |> E.O.bind(_, getFloat) + |> E.O.default(1.0); + (t, w); + }); + Ok(`MultiModal(withWeights)); + | dists when E.L.length(dists) > 0 => + Ok(`MultiModal(dists |> E.L.toArray |> E.A.fmap(r => (r, 1.0)))) + | _ => Error("Needs at least one distribution") + } | _ => Error("Function " ++ name ++ " not found") }; diff --git a/src/distPlus/expressionTree/MathJsParser.re b/src/distPlus/expressionTree/MathJsParser.re index 2ae5fef2..2182dbb1 100644 --- a/src/distPlus/expressionTree/MathJsParser.re +++ b/src/distPlus/expressionTree/MathJsParser.re @@ -138,40 +138,6 @@ module MathAdtToDistDst = { ) }; - let multiModal = - ( - args: array(result(ExpressionTypes.ExpressionTree.node, string)), - weights: option(array(float)), - ) => { - let weights = weights |> E.O.default([||]); - let firstWithError = args |> Belt.Array.getBy(_, Belt.Result.isError); - let withoutErrors = args |> E.A.fmap(E.R.toOption) |> E.A.O.concatSomes; - - switch (firstWithError) { - | Some(Error(e)) => Error(e) - | None when withoutErrors |> E.A.length == 0 => - Error("Multimodals need at least one input") - | _ => - let components = - withoutErrors - |> E.A.fmapi((index, t) => { - let w = weights |> E.A.get(_, index) |> E.O.default(1.0); - - `VerticalScaling((`Multiply, t, `SymbolicDist(`Float(w)))); - }); - - let pointwiseSum = - components - |> Js.Array.sliceFrom(1) - |> E.A.fold_left( - (acc, x) => {`PointwiseCombination((`Add, acc, x))}, - E.A.unsafe_get(components, 0), - ); - - Ok(`Normalize(pointwiseSum)); - }; - }; - // Error("Dotwise exponentiation needs two operands") let operationParser = ( @@ -182,7 +148,6 @@ module MathAdtToDistDst = { let toOkAlgebraic = r => Ok(`AlgebraicCombination(r)); let toOkPointwise = r => Ok(`PointwiseCombination(r)); let toOkTruncate = r => Ok(`Truncate(r)); - let toOkFloatFromDist = r => Ok(`FloatFromDist(r)); args |> E.R.bind(_, args => { switch (name, args) { @@ -247,31 +212,28 @@ module MathAdtToDistDst = { let parseArgs = () => parseArray(args); switch (name) { | "lognormal" => lognormal(args, parseArgs, nodeParser) - | "mm" => + | "mm" =>{ let weights = args |> E.A.last |> E.O.bind( _, fun - | Array(values) => Some(values) - | _ => None, - ) - |> E.O.fmap(o => - o - |> E.A.fmap( - fun - | Value(r) => Some(r) - | _ => None, - ) - |> E.A.O.concatSomes + | Array(values) => Some(parseArray(values)) + | _ => None ); let possibleDists = E.O.isSome(weights) ? Belt.Array.slice(args, ~offset=0, ~len=E.A.length(args) - 1) : args; - let dists = possibleDists |> E.A.fmap(nodeParser); - multiModal(dists, weights); + let dists = parseArray(possibleDists); + switch(weights, dists){ + | (Some(Error(r)), _) => Error(r) + | (_, Error(r)) => Error(r) + | (None, Ok(dists)) => Ok(`FunctionCall("multimodal", dists)) + | (Some(Ok(r)), Ok(dists)) => Ok(`FunctionCall("multimodal", E.A.append([|`Array(r)|], dists))) + } + } | "add" | "subtract" | "multiply" diff --git a/src/distPlus/renderers/DistPlusRenderer.re b/src/distPlus/renderers/DistPlusRenderer.re index 4218dc43..5dae9610 100644 --- a/src/distPlus/renderers/DistPlusRenderer.re +++ b/src/distPlus/renderers/DistPlusRenderer.re @@ -141,8 +141,9 @@ let renderIfNeeded = node |> ( fun - | `SymbolicDist(n) => { - `Render(`SymbolicDist(n)) + | `MultiModal(_) as n + | `SymbolicDist(_) as n => { + `Render(n) |> Internals.runNode(Internals.distPlusRenderInputsToInputs(inputs)) |> ( fun