From e266e6557c4aed514ae6946fc13d283c5b28f9d9 Mon Sep 17 00:00:00 2001 From: Ozzie Gooen Date: Fri, 31 Jul 2020 14:02:29 +0100 Subject: [PATCH] Added simple exponentiation --- .../distribution/AlgebraicShapeCombination.re | 68 +++++++++++-------- .../expressionTree/ExpressionTreeEvaluator.re | 7 +- .../expressionTree/ExpressionTypes.re | 4 +- src/distPlus/expressionTree/MathJsParser.re | 6 +- src/distPlus/expressionTree/Operation.re | 2 + src/distPlus/renderers/DistPlusRenderer.re | 2 +- 6 files changed, 54 insertions(+), 35 deletions(-) diff --git a/src/distPlus/distribution/AlgebraicShapeCombination.re b/src/distPlus/distribution/AlgebraicShapeCombination.re index fd6f7453..631fb3a5 100644 --- a/src/distPlus/distribution/AlgebraicShapeCombination.re +++ b/src/distPlus/distribution/AlgebraicShapeCombination.re @@ -27,15 +27,15 @@ let toDiscretePointMassesFromTriangulars = let xsProdN1: array(float) = Belt.Array.makeUninitializedUnsafe(n - 1); let xsProdN2: array(float) = Belt.Array.makeUninitializedUnsafe(n - 2); for (i in 0 to n - 1) { - Belt.Array.set(xsSq, i, xs[i] *. xs[i]) |> ignore; + Belt.Array.set(xsSq, i, xs[i] *. xs[i]) |> ignore; (); }; for (i in 0 to n - 2) { - Belt.Array.set(xsProdN1, i, xs[i] *. xs[i + 1]) |> ignore; + Belt.Array.set(xsProdN1, i, xs[i] *. xs[i + 1]) |> ignore; (); }; for (i in 0 to n - 3) { - Belt.Array.set(xsProdN2, i, xs[i] *. xs[i + 2]) |> ignore; + Belt.Array.set(xsProdN2, i, xs[i] *. xs[i + 2]) |> ignore; (); }; // means and variances @@ -45,11 +45,8 @@ let toDiscretePointMassesFromTriangulars = if (inverse) { for (i in 1 to n - 2) { - Belt.Array.set( - masses, - i - 1, - (xs[i + 1] -. xs[i - 1]) *. ys[i] /. 2., - ) |> ignore; + Belt.Array.set(masses, i - 1, (xs[i + 1] -. xs[i - 1]) *. ys[i] /. 2.) + |> ignore; // this only works when the whole triange is either on the left or on the right of zero let a = xs[i - 1]; @@ -70,9 +67,9 @@ let toDiscretePointMassesFromTriangulars = -. inverseMean ** 2.; - Belt.Array.set(means, i - 1, inverseMean) |> ignore; + Belt.Array.set(means, i - 1, inverseMean) |> ignore; - Belt.Array.set(variances, i - 1, inverseVar) |> ignore; + Belt.Array.set(variances, i - 1, inverseVar) |> ignore; (); }; @@ -80,14 +77,12 @@ let toDiscretePointMassesFromTriangulars = } else { for (i in 1 to n - 2) { // area of triangle = width * height / 2 - Belt.Array.set( - masses, - i - 1, - (xs[i + 1] -. xs[i - 1]) *. ys[i] /. 2., - ) |> ignore; + Belt.Array.set(masses, i - 1, (xs[i + 1] -. xs[i - 1]) *. ys[i] /. 2.) + |> ignore; // means of triangle = (a + b + c) / 3 - Belt.Array.set(means, i - 1, (xs[i - 1] +. xs[i] +. xs[i + 1]) /. 3.) |> ignore; + Belt.Array.set(means, i - 1, (xs[i - 1] +. xs[i] +. xs[i + 1]) /. 3.) + |> ignore; // variance of triangle = (a^2 + b^2 + c^2 - ab - ac - bc) / 18 Belt.Array.set( @@ -102,7 +97,8 @@ let toDiscretePointMassesFromTriangulars = -. xsProdN2[i - 1] ) /. 18., - ) |> ignore; + ) + |> ignore; (); }; {n: n - 2, masses, means, variances}; @@ -134,8 +130,10 @@ let combineShapesContinuousContinuous = | `Subtract => ((m1, m2) => m1 -. m2) | `Multiply => ((m1, m2) => m1 *. m2) | `Divide => ((m1, mInv2) => m1 *. mInv2) + | `Exponentiate => ((m1, mInv2) => m1 ** mInv2) }; // note: here, mInv2 = mean(1 / t2) ~= 1 / mean(t2) + // TODO: I don't know what the variances are for exponentatiation // converts the variances and means of the two inputs into the variance of the output let combineVariancesFn = switch (op) { @@ -144,6 +142,8 @@ let combineShapesContinuousContinuous = | `Multiply => ( (v1, v2, m1, m2) => v1 *. v2 +. v1 *. m2 ** 2. +. v2 *. m1 ** 2. ) + | `Exponentiate => + ((v1, v2, m1, m2) => v1 *. v2 +. v1 *. m2 ** 2. +. v2 *. m1 ** 2.); | `Divide => ( (v1, vInv2, m1, mInv2) => v1 *. vInv2 +. v1 *. mInv2 ** 2. +. vInv2 *. m1 ** 2. @@ -225,9 +225,12 @@ let toDiscretePointMassesFromDiscrete = }; let combineShapesContinuousDiscrete = - (op: ExpressionTypes.algebraicOperation, continuousShape: DistTypes.xyShape, discreteShape: DistTypes.xyShape) + ( + op: ExpressionTypes.algebraicOperation, + continuousShape: DistTypes.xyShape, + discreteShape: DistTypes.xyShape, + ) : DistTypes.xyShape => { - let t1n = continuousShape |> XYShape.T.length; let t2n = discreteShape |> XYShape.T.length; @@ -248,15 +251,19 @@ let combineShapesContinuousDiscrete = Belt.Array.set( dxyShape, i, - (fn(continuousShape.xs[i], discreteShape.xs[j]), - continuousShape.ys[i] *. discreteShape.ys[j]), - ) |> ignore; + ( + fn(continuousShape.xs[i], discreteShape.xs[j]), + continuousShape.ys[i] *. discreteShape.ys[j], + ), + ) + |> ignore; (); }; Belt.Array.set(outXYShapes, j, dxyShape) |> ignore; (); } | `Multiply + | `Exponentiate | `Divide => for (j in 0 to t2n - 1) { // creates a new continuous shape for each one of the discrete points, and collects them in outXYShapes. @@ -266,8 +273,12 @@ let combineShapesContinuousDiscrete = Belt.Array.set( dxyShape, i, - (fn(continuousShape.xs[i], discreteShape.xs[j]), continuousShape.ys[i] *. discreteShape.ys[j] /. discreteShape.xs[j]), - ) |> ignore; + ( + fn(continuousShape.xs[i], discreteShape.xs[j]), + {continuousShape.ys[i] *. discreteShape.ys[j] /. discreteShape.xs[j]} + ), + ) + |> ignore; (); }; Belt.Array.set(outXYShapes, j, dxyShape) |> ignore; @@ -278,7 +289,10 @@ let combineShapesContinuousDiscrete = outXYShapes |> E.A.fmap(XYShape.T.fromZippedArray) |> E.A.fold_left( - XYShape.PointwiseCombination.combine((+.), - XYShape.XtoY.continuousInterpolator(`Linear, `UseZero)), - XYShape.T.empty); + XYShape.PointwiseCombination.combine( + (+.), + XYShape.XtoY.continuousInterpolator(`Linear, `UseZero), + ), + XYShape.T.empty, + ); }; diff --git a/src/distPlus/expressionTree/ExpressionTreeEvaluator.re b/src/distPlus/expressionTree/ExpressionTreeEvaluator.re index 9649cd08..27570b59 100644 --- a/src/distPlus/expressionTree/ExpressionTreeEvaluator.re +++ b/src/distPlus/expressionTree/ExpressionTreeEvaluator.re @@ -151,7 +151,7 @@ module PointwiseCombination = { }; }; - let pointwiseMultiply = (evaluationParams: evaluationParams, t1: t, t2: t) => { + let pointwiseCombine = (fn, evaluationParams: evaluationParams, t1: t, t2: t) => { // TODO: construct a function that we can easily sample from, to construct // a RenderedDist. Use the xMin and xMax of the rendered shapes to tell the sampling function where to look. // TODO: This should work for symbolic distributions too! @@ -176,7 +176,8 @@ module PointwiseCombination = { ) => { switch (pointwiseOp) { | `Add => pointwiseAdd(evaluationParams, t1, t2) - | `Multiply => pointwiseMultiply(evaluationParams, t1, t2) + | `Multiply => pointwiseCombine(( *. ),evaluationParams, t1, t2) + | `Exponentiate => pointwiseCombine(( *. ),evaluationParams, t1, t2) }; }; }; @@ -279,7 +280,7 @@ module Render = { | `SymbolicDist(d) => Ok( `RenderedDist( - SymbolicDist.T.toShape(1234, d), + SymbolicDist.T.toShape(evaluationParams.samplingInputs.shapeLength, d), ), ) | `RenderedDist(_) as t => Ok(t) // already a rendered shape, we're done here diff --git a/src/distPlus/expressionTree/ExpressionTypes.re b/src/distPlus/expressionTree/ExpressionTypes.re index 8becbbd0..6453667a 100644 --- a/src/distPlus/expressionTree/ExpressionTypes.re +++ b/src/distPlus/expressionTree/ExpressionTypes.re @@ -1,5 +1,5 @@ -type algebraicOperation = [ | `Add | `Multiply | `Subtract | `Divide]; -type pointwiseOperation = [ | `Add | `Multiply]; +type algebraicOperation = [ | `Add | `Multiply | `Subtract | `Divide | `Exponentiate]; +type pointwiseOperation = [ | `Add | `Multiply | `Exponentiate]; type scaleOperation = [ | `Multiply | `Exponentiate | `Log]; type distToFloatOperation = [ | `Pdf(float) diff --git a/src/distPlus/expressionTree/MathJsParser.re b/src/distPlus/expressionTree/MathJsParser.re index b15d4a77..76e1c221 100644 --- a/src/distPlus/expressionTree/MathJsParser.re +++ b/src/distPlus/expressionTree/MathJsParser.re @@ -177,11 +177,12 @@ module MathAdtToDistDst = { }; }; + // Error("Dotwise exponentiation needs two operands") let operationParser = ( name: string, args: result(array(ExpressionTypes.ExpressionTree.node), string), - ) => { + ):result(ExpressionTypes.ExpressionTree.node,string) => { let toOkAlgebraic = r => Ok(`AlgebraicCombination(r)); let toOkPointwise = r => Ok(`PointwiseCombination(r)); let toOkTruncate = r => Ok(`Truncate(r)); @@ -195,6 +196,8 @@ module MathAdtToDistDst = { | ("subtract", _) => Error("Subtraction needs two operands") | ("multiply", [|l, r|]) => toOkAlgebraic((`Multiply, l, r)) | ("multiply", _) => Error("Multiplication needs two operands") + | ("pow", [|l,r|]) => toOkAlgebraic((`Exponentiate, l, r)) + | ("pow", _) => Error("Exponentiation needs two operands") | ("dotMultiply", [|l, r|]) => toOkPointwise((`Multiply, l, r)) | ("dotMultiply", _) => Error("Dotwise multiplication needs two operands") @@ -203,7 +206,6 @@ module MathAdtToDistDst = { Error("Dotwise addition needs two operands") | ("divide", [|l, r|]) => toOkAlgebraic((`Divide, l, r)) | ("divide", _) => Error("Division needs two operands") - | ("pow", _) => Error("Exponentiation is not yet supported.") | ("leftTruncate", [|d, `SymbolicDist(`Float(lc))|]) => toOkTruncate((Some(lc), None, d)) | ("leftTruncate", _) => diff --git a/src/distPlus/expressionTree/Operation.re b/src/distPlus/expressionTree/Operation.re index 26826f3f..e415b5de 100644 --- a/src/distPlus/expressionTree/Operation.re +++ b/src/distPlus/expressionTree/Operation.re @@ -7,6 +7,7 @@ module Algebraic = { | `Add => (+.) | `Subtract => (-.) | `Multiply => ( *. ) + | `Exponentiate => ( ** ) | `Divide => (/.); let applyFn = (t, f1, f2) => { @@ -21,6 +22,7 @@ module Algebraic = { | `Add => "+" | `Subtract => "-" | `Multiply => "*" + | `Exponentiate => ( "**" ) | `Divide => "/"; let format = (a, b, c) => b ++ " " ++ toString(a) ++ " " ++ c; diff --git a/src/distPlus/renderers/DistPlusRenderer.re b/src/distPlus/renderers/DistPlusRenderer.re index 9af6f4e6..7c39e17c 100644 --- a/src/distPlus/renderers/DistPlusRenderer.re +++ b/src/distPlus/renderers/DistPlusRenderer.re @@ -125,7 +125,7 @@ module Internals = { let inputsToShape = (inputs: inputs) => { MathJsParser.fromString(inputs.guesstimatorString) |> E.R.bind(_, g => runProgram(inputs, g)) - |> E.R.bind(_, r => E.A.last(r) |> E.O.toResult("No rendered lines")); + |> E.R.bind(_, r => E.A.last(r) |> E.O.toResult("No rendered lines") |> E.R.fmap(Shape.T.normalize)); }; let outputToDistPlus = (inputs: Inputs.inputs, shape: DistTypes.shape) => {