From b064001a512b024630df0df4a6405756dfacf3a3 Mon Sep 17 00:00:00 2001 From: Ozzie Gooen Date: Sat, 8 Aug 2020 22:33:36 +0100 Subject: [PATCH] Ensure render in DistPlusRenderer --- .../expressionTree/ExpressionTreeEvaluator.re | 4 +- src/distPlus/expressionTree/Functions.re | 5 ++ src/distPlus/expressionTree/MathJsParser.re | 15 ++++ src/distPlus/renderers/DistPlusRenderer.re | 72 +++++++++++-------- 4 files changed, 66 insertions(+), 30 deletions(-) diff --git a/src/distPlus/expressionTree/ExpressionTreeEvaluator.re b/src/distPlus/expressionTree/ExpressionTreeEvaluator.re index 836bdcc3..99689911 100644 --- a/src/distPlus/expressionTree/ExpressionTreeEvaluator.re +++ b/src/distPlus/expressionTree/ExpressionTreeEvaluator.re @@ -160,7 +160,7 @@ module PointwiseCombination = { Render.render(evaluationParams, t2), ) { | (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) => - Ok(`RenderedDist(Shape.combinePointwise(( *. ), rs1, rs2))) + Ok(`RenderedDist(Shape.combinePointwise(fn, rs1, rs2))) | (Error(e1), _) => Error(e1) | (_, Error(e2)) => Error(e2) | _ => Error("Pointwise combination: rendering failed.") @@ -177,7 +177,7 @@ module PointwiseCombination = { switch (pointwiseOp) { | `Add => pointwiseAdd(evaluationParams, t1, t2) | `Multiply => pointwiseCombine(( *. ),evaluationParams, t1, t2) - | `Exponentiate => pointwiseCombine(( *. ),evaluationParams, t1, t2) + | `Exponentiate => pointwiseCombine(( ** ),evaluationParams, t1, t2) }; }; }; diff --git a/src/distPlus/expressionTree/Functions.re b/src/distPlus/expressionTree/Functions.re index 1f7a6dd2..80fa81ca 100644 --- a/src/distPlus/expressionTree/Functions.re +++ b/src/distPlus/expressionTree/Functions.re @@ -49,6 +49,11 @@ let to_: array(node) => result(node, string) = Error("Low value must be less than high value.") | _ => Error("Requires 2 variables"); +// Possible setup: +// let normal = {"inputs": [`float, `float], "outputs": [`float]}; +// let render = {"inputs": [`dist], "outputs": [`renderedDist]}; +// let render = {"inputs": [`distRenderedDist], "outputs": [`renderedDist]}; + let fnn = ( evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams, diff --git a/src/distPlus/expressionTree/MathJsParser.re b/src/distPlus/expressionTree/MathJsParser.re index 52fb90b2..1b234dad 100644 --- a/src/distPlus/expressionTree/MathJsParser.re +++ b/src/distPlus/expressionTree/MathJsParser.re @@ -201,6 +201,9 @@ module MathAdtToDistDst = { | ("dotMultiply", [|l, r|]) => toOkPointwise((`Multiply, l, r)) | ("dotMultiply", _) => Error("Dotwise multiplication needs two operands") + | ("dotPow", [|l, r|]) => toOkPointwise((`Exponentiate, l, r)) + | ("dotPow", _) => + Error("Dotwise exponentiation needs two operands") | ("rightLogShift", [|l, r|]) => toOkPointwise((`Add, l, r)) | ("rightLogShift", _) => Error("Dotwise addition needs two operands") @@ -227,6 +230,12 @@ module MathAdtToDistDst = { Error( "truncate needs three arguments: the expression and both cutoffs", ) + | ("scaleMultiply", [|d, `SymbolicDist(`Float(v))|]) => + Ok(`VerticalScaling(`Multiply, d, `SymbolicDist(`Float(v)))) + | ("scaleExp", [|d, `SymbolicDist(`Float(v))|]) => + Ok(`VerticalScaling(`Exponentiate, d, `SymbolicDist(`Float(v)))) + | ("scaleLog", [|d, `SymbolicDist(`Float(v))|]) => + Ok(`VerticalScaling(`Log, d, `SymbolicDist(`Float(v)))) | ("pdf", [|d, `SymbolicDist(`Float(v))|]) => toOkFloatFromDist((`Pdf(v), d)) | ("cdf", [|d, `SymbolicDist(`Float(v))|]) => @@ -275,11 +284,15 @@ module MathAdtToDistDst = { | "subtract" | "multiply" | "dotMultiply" + | "dotPow" | "rightLogShift" | "divide" | "pow" | "leftTruncate" | "rightTruncate" + | "scaleMultiply" + | "scaleExp" + | "scaleLog" | "truncate" | "mean" | "inv" @@ -355,6 +368,7 @@ let fromString2 = str => { Inside of this function, MathAdtToDistDst is called whenever a distribution function is encountered. */ let mathJsToJson = str |> pointwiseToRightLogShift |> Mathjs.parseMath; + let mathJsParse = E.R.bind(mathJsToJson, r => { switch (MathJsonToMathJsAdt.run(r)) { @@ -364,6 +378,7 @@ let fromString2 = str => { }); let value = E.R.bind(mathJsParse, MathAdtToDistDst.run); + Js.log2(mathJsParse, value); value; }; diff --git a/src/distPlus/renderers/DistPlusRenderer.re b/src/distPlus/renderers/DistPlusRenderer.re index e8c97342..9981db52 100644 --- a/src/distPlus/renderers/DistPlusRenderer.re +++ b/src/distPlus/renderers/DistPlusRenderer.re @@ -91,18 +91,16 @@ module Internals = { }; let makeOutputs = (graph, shape): outputs => {graph, shape}; + let makeInputs = (inputs): ExpressionTypes.ExpressionTree.samplingInputs => { + sampleCount: inputs.samplingInputs.sampleCount |> E.O.default(10000), + outputXYPoints: + inputs.samplingInputs.outputXYPoints |> E.O.default(10000), + kernelWidth: inputs.samplingInputs.kernelWidth, + shapeLength: inputs.samplingInputs.shapeLength |> E.O.default(10000), + }; + let runNode = (inputs, node) => { - ExpressionTree.toLeaf( - { - sampleCount: inputs.samplingInputs.sampleCount |> E.O.default(10000), - outputXYPoints: - inputs.samplingInputs.outputXYPoints |> E.O.default(10000), - kernelWidth: inputs.samplingInputs.kernelWidth, - shapeLength: inputs.samplingInputs.shapeLength |> E.O.default(10000), - }, - inputs.environment, - node, - ); + ExpressionTree.toLeaf(makeInputs(inputs), inputs.environment, node); }; let runProgram = (inputs: inputs, p: ExpressionTypes.Program.program) => { @@ -154,15 +152,40 @@ let run = (inputs: Inputs.inputs) => { |> E.R.fmap(Internals.outputToDistPlus(inputs)); }; -let exportDistPlus = inputs => - fun - | `RenderedDist(n) => Ok(`DistPlus(Internals.outputToDistPlus(inputs, n))) - | `Function(n) => Ok(`Function(n)) - | n => - Error( - "Didn't output a rendered distribution. Format:" - ++ ExpressionTree.toString(n), - ); +let renderIfNeeded = + (inputs, node: ExpressionTypes.ExpressionTree.node) + : result(ExpressionTypes.ExpressionTree.node, string) => + node + |> ( + fun + | `SymbolicDist(n) => { + `Render(`SymbolicDist(n)) + |> Internals.runNode(Internals.distPlusRenderInputsToInputs(inputs)) + |> ( + fun + | Ok(`RenderedDist(n)) => Ok(`RenderedDist(n)) + | Error(r) => Error(r) + | _ => Error("Didn't render, but intended to") + ); + } + | n => Ok(n) + ); + +let exportDistPlus = (inputs, node: ExpressionTypes.ExpressionTree.node) => + node + |> renderIfNeeded(inputs) + |> E.R.bind( + _, + fun + | `RenderedDist(n) => + Ok(`DistPlus(Internals.outputToDistPlus(inputs, n))) + | `Function(n) => Ok(`Function(n)) + | n => + Error( + "Didn't output a rendered distribution. Format:" + ++ ExpressionTree.toString(n), + ), + ); let run2 = (inputs: Inputs.inputs) => { inputs @@ -177,17 +200,10 @@ let runFunction = fn: (array(string), ExpressionTypes.ExpressionTree.node), fnInputs, ) => { - let (_, fns) = fn; let inputs = ins |> Internals.distPlusRenderInputsToInputs; let output = ExpressionTree.runFunction( - { - sampleCount: inputs.samplingInputs.sampleCount |> E.O.default(10000), - outputXYPoints: - inputs.samplingInputs.outputXYPoints |> E.O.default(10000), - kernelWidth: inputs.samplingInputs.kernelWidth, - shapeLength: inputs.samplingInputs.shapeLength |> E.O.default(10000), - }, + Internals.makeInputs(inputs), inputs.environment, fnInputs, fn,