diff --git a/src/distPlus/expressionTree/ExpressionTree.re b/src/distPlus/expressionTree/ExpressionTree.re index 6cbfe898..86e65fbb 100644 --- a/src/distPlus/expressionTree/ExpressionTree.re +++ b/src/distPlus/expressionTree/ExpressionTree.re @@ -8,8 +8,6 @@ let rec toString: node => string = Operation.Algebraic.format(op, toString(t1), toString(t2)) | `PointwiseCombination(op, t1, t2) => Operation.Pointwise.format(op, toString(t1), toString(t2)) - | `VerticalScaling(scaleOp, t, scaleBy) => - Operation.Scale.format(scaleOp, toString(t), toString(scaleBy)) | `Normalize(t) => "normalize(k" ++ toString(t) ++ ")" | `Truncate(lc, rc, t) => Operation.T.truncateToString(lc, rc, toString(t)) diff --git a/src/distPlus/expressionTree/ExpressionTreeEvaluator.re b/src/distPlus/expressionTree/ExpressionTreeEvaluator.re index 14500828..c2f8b90c 100644 --- a/src/distPlus/expressionTree/ExpressionTreeEvaluator.re +++ b/src/distPlus/expressionTree/ExpressionTreeEvaluator.re @@ -91,36 +91,6 @@ module AlgebraicCombination = { ); }; -module VerticalScaling = { - let operationToLeaf = - (evaluationParams: evaluationParams, scaleOp, t, scaleBy) => { - // scaleBy has to be a single float, otherwise we'll return an error. - let fn = (secondary, main) => - Operation.Scale.toFn(scaleOp, main, secondary); - let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(scaleOp); - let integralCacheFn = Operation.Scale.toIntegralCacheFn(scaleOp); - let renderedShape = Render.render(evaluationParams, t); - - let s = - switch (renderedShape, scaleBy) { - | (Ok(`RenderedDist(rs)), `SymbolicDist(`Float(scaleBy))) => - Ok( - `RenderedDist( - Shape.T.mapY( - ~integralSumCacheFn=integralSumCacheFn(scaleBy), - ~integralCacheFn=integralCacheFn(scaleBy), - ~fn=fn(scaleBy), - rs, - ), - ), - ) - | (Error(e1), _) => Error(e1) - | (_, _) => Error("Can only scale by float values.") - }; - s; - }; -}; - module PointwiseCombination = { let pointwiseAdd = (evaluationParams: evaluationParams, t1: t, t2: t) => { switch ( @@ -309,8 +279,6 @@ let rec toLeaf = t1, t2, ) - | `VerticalScaling(scaleOp, t, scaleBy) => - VerticalScaling.operationToLeaf(evaluationParams, scaleOp, t, scaleBy) | `Truncate(leftCutoff, rightCutoff, t) => Truncate.operationToLeaf(evaluationParams, leftCutoff, rightCutoff, t) | `Normalize(t) => Normalize.operationToLeaf(evaluationParams, t) @@ -336,12 +304,7 @@ let rec toLeaf = let components = r |> E.A.fmap(((dist, weight)) => - `VerticalScaling(( - `Multiply, - dist, - `SymbolicDist(`Float(weight)), - )) - ); + `FunctionCall("scaleExp", [|dist, `SymbolicDist(`Float(weight))|])); let pointwiseSum = components |> Js.Array.sliceFrom(1) diff --git a/src/distPlus/expressionTree/ExpressionTypes.re b/src/distPlus/expressionTree/ExpressionTypes.re index c9e1a63a..d40a92d3 100644 --- a/src/distPlus/expressionTree/ExpressionTypes.re +++ b/src/distPlus/expressionTree/ExpressionTypes.re @@ -25,7 +25,6 @@ module ExpressionTree = { | `Function(array(string), node) | `AlgebraicCombination(algebraicOperation, node, node) | `PointwiseCombination(pointwiseOperation, node, node) - | `VerticalScaling(scaleOperation, node, node) | `Normalize(node) | `Render(node) | `Truncate(option(float), option(float), node) diff --git a/src/distPlus/expressionTree/Fns.re b/src/distPlus/expressionTree/Fns.re index dc74caf2..4947a1d4 100644 --- a/src/distPlus/expressionTree/Fns.re +++ b/src/distPlus/expressionTree/Fns.re @@ -1,6 +1,9 @@ open TypeSystem; -let wrongInputsError = (r) => {Js.log2("Wrong inputs", r); Error("Wrong inputs")}; +let wrongInputsError = r => { + Js.log2("Wrong inputs", r); + Error("Wrong inputs"); +}; let to_: (float, float) => result(node, string) = (low, high) => @@ -20,7 +23,7 @@ let makeSymbolicFromTwoFloats = (name, fn) => ~run= fun | [|`Float(a), `Float(b)|] => Ok(`SymbolicDist(fn(a, b))) - | e => wrongInputsError(e) + | e => wrongInputsError(e), ); let makeSymbolicFromOneFloat = (name, fn) => @@ -31,21 +34,32 @@ let makeSymbolicFromOneFloat = (name, fn) => ~run= fun | [|`Float(a)|] => Ok(`SymbolicDist(fn(a))) - | e => wrongInputsError(e) + | e => wrongInputsError(e), ); -let makeDistFloat = (name, fn) => +let makeDistFloat = (name, fn) => Function.make( ~name, ~output=`SamplingDistribution, ~inputs=[|`SamplingDistribution, `Float|], ~run= fun - | [|`SamplingDist(a), `Float(b)|] => (fn(a,b)) - | e => wrongInputsError(e) + | [|`SamplingDist(a), `Float(b)|] => fn(a, b) + | e => wrongInputsError(e), ); -let makeDist = (name, fn) => +let makeRenderedDistFloat = (name, fn) => + Function.make( + ~name, + ~output=`RenderedDistribution, + ~inputs=[|`RenderedDistribution, `Float|], + ~run= + fun + | [|`RenderedDist(a), `Float(b)|] => fn(a, b) + | e => wrongInputsError(e), + ); + +let makeDist = (name, fn) => Function.make( ~name, ~output=`SamplingDistribution, @@ -53,7 +67,7 @@ let makeDist = (name, fn) => ~run= fun | [|`SamplingDist(a)|] => fn(a) - | e => wrongInputsError(e) + | e => wrongInputsError(e), ); let floatFromDist = @@ -71,6 +85,22 @@ let floatFromDist = }; }; +let verticalScaling = (scaleOp, rs, scaleBy) => { + // scaleBy has to be a single float, otherwise we'll return an error. + let fn = (secondary, main) => + Operation.Scale.toFn(scaleOp, main, secondary); + let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(scaleOp); + let integralCacheFn = Operation.Scale.toIntegralCacheFn(scaleOp); + Ok(`RenderedDist( + Shape.T.mapY( + ~integralSumCacheFn=integralSumCacheFn(scaleBy), + ~integralCacheFn=integralCacheFn(scaleBy), + ~fn=fn(scaleBy), + rs, + ), + )); +}; + let functions = [| makeSymbolicFromTwoFloats("normal", SymbolicDist.Normal.make), makeSymbolicFromTwoFloats("uniform", SymbolicDist.Uniform.make), @@ -87,8 +117,8 @@ let functions = [| ~inputs=[|`Float, `Float|], ~run= fun - | [|`Float(a), `Float(b)|] => to_(a,b) - | e => wrongInputsError(e) + | [|`Float(a), `Float(b)|] => to_(a, b) + | e => wrongInputsError(e), ), Function.make( ~name="triangular", @@ -99,11 +129,39 @@ let functions = [| | [|`Float(a), `Float(b), `Float(c)|] => SymbolicDist.Triangular.make(a, b, c) |> E.R.fmap(r => `SymbolicDist(r)) - | e => wrongInputsError(e) + | e => wrongInputsError(e), ), makeDistFloat("pdf", (dist, float) => floatFromDist(`Pdf(float), dist)), makeDistFloat("inv", (dist, float) => floatFromDist(`Inv(float), dist)), makeDistFloat("cdf", (dist, float) => floatFromDist(`Cdf(float), dist)), - makeDist("mean", (dist) => floatFromDist(`Mean, dist)), - makeDist("sample", (dist) => floatFromDist(`Sample, dist)) + makeDist("mean", dist => floatFromDist(`Mean, dist)), + makeDist("sample", dist => floatFromDist(`Sample, dist)), + Function.make( + ~name="render", + ~output=`RenderedDistribution, + ~inputs=[|`RenderedDistribution|], + ~run= + fun + | [|`RenderedDist(c)|] => Ok(`RenderedDist(c)) + | e => wrongInputsError(e), + ), + Function.make( + ~name="normalize", + ~output=`SamplingDistribution, + ~inputs=[|`SamplingDistribution|], + ~run= + fun + | [|`SamplingDist(`SymbolicDist(c))|] => Ok(`SymbolicDist(c)) + | [|`SamplingDist(`RenderedDist(c))|] => Ok(`RenderedDist(Shape.T.normalize(c))) + | e => wrongInputsError(e), + ), + makeRenderedDistFloat("scaleExp", (dist, float) => + verticalScaling(`Exponentiate, dist, float) + ), + makeRenderedDistFloat("scaleMultiply", (dist, float) => + verticalScaling(`Multiply, dist, float) + ), + makeRenderedDistFloat("scaleLog", (dist, float) => + verticalScaling(`Log, dist, float) + ), |]; diff --git a/src/distPlus/expressionTree/MathJsParser.re b/src/distPlus/expressionTree/MathJsParser.re index 2182dbb1..aa757892 100644 --- a/src/distPlus/expressionTree/MathJsParser.re +++ b/src/distPlus/expressionTree/MathJsParser.re @@ -193,14 +193,6 @@ module MathAdtToDistDst = { Error( "truncate needs three arguments: the expression and both cutoffs", ) - | ("scaleMultiply", [|d, `SymbolicDist(`Float(v))|]) => - Ok(`VerticalScaling((`Multiply, d, `SymbolicDist(`Float(v))))) - | ("scaleExp", [|d, `SymbolicDist(`Float(v))|]) => - Ok( - `VerticalScaling((`Exponentiate, d, `SymbolicDist(`Float(v)))), - ) - | ("scaleLog", [|d, `SymbolicDist(`Float(v))|]) => - Ok(`VerticalScaling((`Log, d, `SymbolicDist(`Float(v))))) | _ => Error("This type not currently supported") } }); @@ -245,9 +237,6 @@ module MathAdtToDistDst = { | "pow" | "leftTruncate" | "rightTruncate" - | "scaleMultiply" - | "scaleExp" - | "scaleLog" | "truncate" => operationParser(name, parseArgs()) | name => parseArgs()