From a6051d8371f53c9ac308cc26e4a277605a897b23 Mon Sep 17 00:00:00 2001 From: Ozzie Gooen Date: Sat, 7 Nov 2020 23:35:05 -0800 Subject: [PATCH] Further cleanup of MultiModal functionality --- src/distPlus/expressionTree/ExpressionTree.re | 1 - .../expressionTree/ExpressionTreeEvaluator.re | 1 - .../expressionTree/ExpressionTypes.re | 10 +- src/distPlus/expressionTree/Fns.re | 169 ++++++++---------- src/distPlus/expressionTree/TypeSystem.re | 27 ++- src/distPlus/renderers/DistPlusRenderer.re | 1 - src/distPlus/utility/E.re | 22 +++ 7 files changed, 126 insertions(+), 105 deletions(-) diff --git a/src/distPlus/expressionTree/ExpressionTree.re b/src/distPlus/expressionTree/ExpressionTree.re index 86e65fbb..2611a432 100644 --- a/src/distPlus/expressionTree/ExpressionTree.re +++ b/src/distPlus/expressionTree/ExpressionTree.re @@ -24,7 +24,6 @@ let rec toString: node => string = ++ toString(internal) ++ ")]" | `Array(_) => "Array" - | `MultiModal(_) => "Multimodal" | `Hash(_) => "Hash" let envs = (samplingInputs, environment) => { diff --git a/src/distPlus/expressionTree/ExpressionTreeEvaluator.re b/src/distPlus/expressionTree/ExpressionTreeEvaluator.re index 9f3702e9..82eb99c2 100644 --- a/src/distPlus/expressionTree/ExpressionTreeEvaluator.re +++ b/src/distPlus/expressionTree/ExpressionTreeEvaluator.re @@ -307,6 +307,5 @@ let rec toLeaf = Js.log3("In function call", name, args); callableFunction(evaluationParams, name, args) |> E.R.bind(_, toLeaf(evaluationParams)); - | `MultiModal(r) => Error("Multimodal?") }; }; diff --git a/src/distPlus/expressionTree/ExpressionTypes.re b/src/distPlus/expressionTree/ExpressionTypes.re index 33d1d0d1..07428c85 100644 --- a/src/distPlus/expressionTree/ExpressionTypes.re +++ b/src/distPlus/expressionTree/ExpressionTypes.re @@ -30,16 +30,18 @@ module ExpressionTree = { | `Render(node) | `Truncate(option(float), option(float), node) | `FunctionCall(string, array(node)) - | `MultiModal(array((node, float))) ]; module Hash = { type t('a) = array((string, 'a)); - let getByName = (t:t('a), name) => + let getByName = (t: t('a), name) => E.A.getBy(t, ((n, _)) => n == name) |> E.O.fmap(((_, r)) => r); - let getByNames = (hash: t('a), names:array(string)) => - names |> E.A.fmap(name => (name, getByName(hash, name))) + let getByNameResult = (t: t('a), name) => + getByName(t, name) |> E.O.toResult(name ++ " expected and not found"); + + let getByNames = (hash: t('a), names: array(string)) => + names |> E.A.fmap(name => (name, getByName(hash, name))); }; // Have nil as option let getFloat = (node: node) => diff --git a/src/distPlus/expressionTree/Fns.re b/src/distPlus/expressionTree/Fns.re index 4fdbc519..9d5983c5 100644 --- a/src/distPlus/expressionTree/Fns.re +++ b/src/distPlus/expressionTree/Fns.re @@ -107,6 +107,80 @@ let verticalScaling = (scaleOp, rs, scaleBy) => { ); }; +module Multimodal = { + let getByNameResult = ExpressionTypes.ExpressionTree.Hash.getByNameResult; + + let _paramsToDistsAndWeights = (r: array(typedValue)) => + switch (r) { + | [|`Named(r)|] => + let dists = + getByNameResult(r, "dists") + ->E.R.bind(TypeSystem.TypedValue.toArray) + ->E.R.bind(r => + r + |> E.A.fmap(TypeSystem.TypedValue.toDist) + |> E.A.R.firstErrorOrOpen + ); + let weights = + getByNameResult(r, "weights") + ->E.R.bind(TypeSystem.TypedValue.toArray) + ->E.R.bind(r => + r + |> E.A.fmap(TypeSystem.TypedValue.toFloat) + |> E.A.R.firstErrorOrOpen + ); + + E.R.merge(dists, weights) + |> E.R.fmap(((a, b)) => + E.A.zipMaxLength(a, b) + |> E.A.fmap(((a, b)) => + (a |> E.O.toExn(""), b |> E.O.default(1.0)) + ) + ); + | _ => Error("Needs items") + }; + let _runner: array(typedValue) => result(node, string) = + r => { + let paramsToDistsAndWeights = + _paramsToDistsAndWeights(r) + |> E.R.fmap( + E.A.fmap(((dist, weight)) => + `FunctionCall(( + "scaleMultiply", + [|dist, `SymbolicDist(`Float(weight))|], + )) + ), + ); + let pointwiseSum: result(node, string) = + paramsToDistsAndWeights->E.R.bind( + E.R.errorIfCondition(E.A.isEmpty, "Needs one input"), + ) + |> E.R.fmap(r => + r + |> Js.Array.sliceFrom(1) + |> E.A.fold_left( + (acc, x) => {`PointwiseCombination((`Add, acc, x))}, + E.A.unsafe_get(r, 0), + ) + ); + pointwiseSum; + }; + + let _function = + Function.T.make( + ~name="multimodal", + ~outputType=`SamplingDistribution, + ~inputTypes=[| + `Named([| + ("dists", `Array(`SamplingDistribution)), + ("weights", `Array(`Float)), + |]), + |], + ~run=_runner, + (), + ); +}; + let functions = [| makeSymbolicFromTwoFloats("normal", SymbolicDist.Normal.make), makeSymbolicFromTwoFloats("uniform", SymbolicDist.Uniform.make), @@ -175,98 +249,5 @@ let functions = [| makeRenderedDistFloat("scaleLog", (dist, float) => verticalScaling(`Log, dist, float) ), - Function.T.make( - ~name="multimodal", - ~outputType=`SamplingDistribution, - ~inputTypes=[| - `Named([| - ("dists", `Array(`SamplingDistribution)), - ("weights", `Array(`Float)), - |]), - |], - ~run= - fun - | [|`Named(r)|] => { - let foo = - (r: TypeSystem.typedValue) - : result(ExpressionTypes.ExpressionTree.node, string) => - switch (r) { - | `SamplingDist(`SymbolicDist(c)) => Ok(`SymbolicDist(c)) - | `SamplingDist(`RenderedDist(c)) => Ok(`RenderedDist(c)) - | `Float(x) => - Ok(`RenderedDist(SymbolicDist.T.toShape(1000, `Float(x)))) - | _ => Error("") - }; - let weight = (r: TypeSystem.typedValue) => - switch (r) { - | `Float(x) => Ok(x) - | _ => Error("Wrong Type") - }; - let dists = - switch (ExpressionTypes.ExpressionTree.Hash.getByName(r, "dists")) { - | Some(`Array(r)) => r |> E.A.fmap(foo) |> E.A.R.firstErrorOrOpen - | _ => Error("") - }; - let weights = - ( - switch ( - ExpressionTypes.ExpressionTree.Hash.getByName(r, "weights") - ) { - | Some(`Array(r)) => - r |> E.A.fmap(weight) |> E.A.R.firstErrorOrOpen - | _ => Error("") - } - ) - |> ( - fun - | Ok(r) => r - | _ => [||] - ); - let withWeights = - dists - |> E.R.fmap(d => { - let iis = - d |> E.A.length |> Belt.Array.makeUninitializedUnsafe; - for (i in 0 to (d |> E.A.length) - 1) { - Belt.Array.set( - iis, - i, - ( - E.A.unsafe_get(d, i), - E.A.get(weights, i) |> E.O.default(1.0), - ), - ) - |> ignore; - }; - iis; - }); - let components: result(array(node), string) = - withWeights - |> E.R.fmap( - E.A.fmap(((dist, weight)) => - `FunctionCall(( - "scaleMultiply", - [|dist, `SymbolicDist(`Float(weight))|], - )) - ), - ); - let pointwiseSum = - components - |> E.R.bind(_, r => { - E.A.length(r) > 0 - ? Ok(r) : Error("Invalid argument length") - }) - |> E.R.fmap(r => - r - |> Js.Array.sliceFrom(1) - |> E.A.fold_left( - (acc, x) => {`PointwiseCombination((`Add, acc, x))}, - E.A.unsafe_get(r, 0), - ) - ); - pointwiseSum; - } - | _ => Error(""), - (), - ), + Multimodal._function |]; diff --git a/src/distPlus/expressionTree/TypeSystem.re b/src/distPlus/expressionTree/TypeSystem.re index cda9ad27..4e39e753 100644 --- a/src/distPlus/expressionTree/TypeSystem.re +++ b/src/distPlus/expressionTree/TypeSystem.re @@ -54,7 +54,6 @@ module TypedValue = { // todo: Arrays and hashes let rec fromNodeWithTypeCoercion = (evaluationParams, _type: _type, node) => { - Js.log3("With Coersion!", _type, node); switch (_type, node) { | (`Float, _) => switch (getFloat(node)) { @@ -76,7 +75,6 @@ module TypedValue = { |> E.A.R.firstErrorOrOpen |> E.R.fmap(r => `Array(r)) | (`Named(named), `Hash(r)) => - Js.log3("Named", named, r); let foo = named |> E.A.fmap(((name, intendedType)) => @@ -86,7 +84,6 @@ module TypedValue = { ExpressionTypes.ExpressionTree.Hash.getByName(r, name), ) ); - Js.log("Named: part 2"); let bar = foo |> E.A.fmap(((name, intendedType, optionNode)) => @@ -99,11 +96,33 @@ module TypedValue = { ) |> E.A.R.firstErrorOrOpen |> E.R.fmap(r => `Named(r)); - Js.log3("Named!", foo, bar); bar; | _ => Error("fromNodeWithTypeCoercion error, sorry.") }; }; + + let toFloat = + fun + | `Float(x) => Ok(x) + | _ => Error("Not a float"); + + let toArray = + fun + | `Array(x) => Ok(x) + | _ => Error("Not an array"); + + let toNamed = + fun + | `Named(x) => Ok(x) + | _ => Error("Not a named item"); + + let toDist = + fun + | `SamplingDist(`SymbolicDist(c)) => Ok(`SymbolicDist(c)) + | `SamplingDist(`RenderedDist(c)) => Ok(`RenderedDist(c)) + | `Float(x) => + Ok(`RenderedDist(SymbolicDist.T.toShape(1000, `Float(x)))) + | _ => Error(""); }; module Function = { diff --git a/src/distPlus/renderers/DistPlusRenderer.re b/src/distPlus/renderers/DistPlusRenderer.re index 70f720e0..10c37ea7 100644 --- a/src/distPlus/renderers/DistPlusRenderer.re +++ b/src/distPlus/renderers/DistPlusRenderer.re @@ -144,7 +144,6 @@ let renderIfNeeded = node |> ( fun - | `MultiModal(_) as n | `Normalize(_) as n | `SymbolicDist(_) as n => { `Render(n) diff --git a/src/distPlus/utility/E.re b/src/distPlus/utility/E.re index 175c2847..505b1d0d 100644 --- a/src/distPlus/utility/E.re +++ b/src/distPlus/utility/E.re @@ -26,6 +26,9 @@ module FloatFloatMap = { let fmap = (fn, t: t) => Belt.MutableMap.map(t, fn); }; +module Int = { + let max = (i1: int, i2: int) => i1 > i2 ? i1 : i2; +}; /* Utils */ module U = { let isEqual = (a, b) => a == b; @@ -146,6 +149,11 @@ module R = { let fmap = Rationale.Result.fmap; let bind = Rationale.Result.bind; let toExn = Belt.Result.getExn; + let default = (default, res: Belt.Result.t('a, 'b)) => + switch (res) { + | Ok(r) => r + | Error(_) => default + }; let merge = (a, b) => switch (a, b) { | (Error(e), _) => Error(e) @@ -157,6 +165,9 @@ module R = { | Ok(r) => Some(r) | Error(_) => None }; + + let errorIfCondition = (errorCondition, errorMessage, r) => + errorCondition(r) ? Error(errorMessage) : Ok(r); }; let safe_fn_of_string = (fn, s: string): option('a) => @@ -263,6 +274,7 @@ module A = { let init = Array.init; let reduce = Belt.Array.reduce; let reducei = Belt.Array.reduceWithIndex; + let isEmpty = r => length(r) < 1; let min = a => get(a, 0) |> O.fmap(first => Belt.Array.reduce(a, first, (i, j) => i < j ? i : j)); @@ -285,6 +297,16 @@ module A = { |> Rationale.Result.return }; + // This zips while taking the longest elements of each array. + let zipMaxLength = (array1, array2) => { + let maxLength = Int.max(length(array1), length(array2)); + let result = maxLength |> Belt.Array.makeUninitializedUnsafe; + for (i in 0 to maxLength - 1) { + Belt.Array.set(result, i, (get(array1, i), get(array2, i))) |> ignore; + }; + result; + }; + let asList = (f: list('a) => list('a), r: array('a)) => r |> to_list |> f |> of_list; /* TODO: Is there a better way of doing this? */