diff --git a/__tests__/Distributions__Test.re b/__tests__/Distributions__Test.re index 341ef8a4..0b2e30e6 100644 --- a/__tests__/Distributions__Test.re +++ b/__tests__/Distributions__Test.re @@ -383,9 +383,9 @@ describe("Shape", () => { let numSamples = 10000; open Distributions.Shape; let normal: SymbolicTypes.symbolicDist = `Normal({mean, stdev}); - let normalShape = ExpressionTree.toShape(numSamples, `Leaf(`SymbolicDist(normal))); + let normalShape = ExpressionTree.toShape(numSamples, `SymbolicDist(normal)); let lognormal = SymbolicDist.Lognormal.fromMeanAndStdev(mean, stdev); - let lognormalShape = ExpressionTree.toShape(numSamples, `Leaf(`SymbolicDist(lognormal))); + let lognormalShape = ExpressionTree.toShape(numSamples, `SymbolicDist(lognormal)); makeTestCloseEquality( "Mean of a normal", diff --git a/src/components/Drawer.re b/src/components/Drawer.re index 8a0f2cfa..f9ae5ddb 100644 --- a/src/components/Drawer.re +++ b/src/components/Drawer.re @@ -389,7 +389,7 @@ module Draw = { let numSamples = 3000; let normal: SymbolicTypes.symbolicDist = `Normal({mean, stdev}); - let normalShape = ExpressionTree.toShape(numSamples, `Leaf(`SymbolicDist(normal))); + let normalShape = ExpressionTree.toShape(numSamples, `SymbolicDist(normal)); let xyShape: Types.xyShape = switch (normalShape) { | Mixed(_) => {xs: [||], ys: [||]} diff --git a/src/distPlus/expressionTree/ExpressionTree.re b/src/distPlus/expressionTree/ExpressionTree.re index 2ceb783b..bd162bbf 100644 --- a/src/distPlus/expressionTree/ExpressionTree.re +++ b/src/distPlus/expressionTree/ExpressionTree.re @@ -2,10 +2,11 @@ open ExpressionTypes.ExpressionTree; let toShape = (sampleCount: int, node: node) => { let renderResult = - ExpressionTreeEvaluator.toLeaf(`Operation(`Render(node)), sampleCount); + `Render(`Normalize(node)) + |> ExpressionTreeEvaluator.toLeaf({sampleCount: sampleCount}); switch (renderResult) { - | Ok(`Leaf(`RenderedDist(rs))) => + | Ok(`RenderedDist(rs)) => let continuous = Distributions.Shape.T.toContinuous(rs); let discrete = Distributions.Shape.T.toDiscrete(rs); let shape = MixedShapeBuilder.buildSimple(~continuous, ~discrete); @@ -17,6 +18,6 @@ let toShape = (sampleCount: int, node: node) => { let rec toString = fun - | `Leaf(`SymbolicDist(d)) => SymbolicDist.T.toString(d) - | `Leaf(`RenderedDist(_)) => "[shape]" - | `Operation(op) => Operation.T.toString(toString, op); + | `SymbolicDist(d) => SymbolicDist.T.toString(d) + | `RenderedDist(_) => "[shape]" + | op => Operation.T.toString(toString, op); diff --git a/src/distPlus/expressionTree/ExpressionTreeEvaluator.re b/src/distPlus/expressionTree/ExpressionTreeEvaluator.re index 348f91ef..6c5210f8 100644 --- a/src/distPlus/expressionTree/ExpressionTreeEvaluator.re +++ b/src/distPlus/expressionTree/ExpressionTreeEvaluator.re @@ -1,91 +1,83 @@ -/* This module represents a tree node. */ open ExpressionTypes; open ExpressionTypes.ExpressionTree; type t = node; type tResult = node => result(node, string); +type renderParams = { + sampleCount: int, +}; + /* Given two random variables A and B, this returns the distribution of a new variable that is the result of the operation on A and B. For instance, normal(0, 1) + normal(1, 1) -> normal(1, 2). In general, this is implemented via convolution. */ module AlgebraicCombination = { - let toTreeNode = (op, t1, t2) => - `Operation(`AlgebraicCombination((op, t1, t2))); - let tryAnalyticalSolution = - fun - | `Operation( - `AlgebraicCombination( - operation, - `Leaf(`SymbolicDist(d1)), - `Leaf(`SymbolicDist(d2)), - ), - ) as t => - switch (SymbolicDist.T.attemptAnalyticalOperation(d1, d2, operation)) { - | `AnalyticalSolution(symbolicDist) => - Ok(`Leaf(`SymbolicDist(symbolicDist))) + let tryAnalyticalSimplification = (operation, t1: t, t2: t) => + switch (operation, t1, t2) { + | (operation, + `SymbolicDist(d1), + `SymbolicDist(d2), + ) => + switch (SymbolicDist.T.tryAnalyticalSimplification(d1, d2, operation)) { + | `AnalyticalSolution(symbolicDist) => Ok(`SymbolicDist(symbolicDist)) | `Error(er) => Error(er) - | `NoSolution => Ok(t) + | `NoSolution => Ok(`AlgebraicCombination(operation, t1, t2)) } - | t => Ok(t); + | _ => Ok(`AlgebraicCombination(operation, t1, t2)) + }; - // todo: I don't like the name evaluateNumerically that much, if this renders and does it algebraically. It's tricky. - let evaluateNumerically = (algebraicOp, operationToLeaf, t1, t2) => { - // force rendering into shapes - let renderShape = r => operationToLeaf(`Render(r)); + let combineAsShapes = (toLeaf, renderParams, algebraicOp, t1, t2) => { + let renderShape = r => toLeaf(renderParams, `Render(r)); switch (renderShape(t1), renderShape(t2)) { - | (Ok(`Leaf(`RenderedDist(s1))), Ok(`Leaf(`RenderedDist(s2)))) => + | (Ok(`RenderedDist(s1)), Ok(`RenderedDist(s2))) => Ok( - `Leaf( `RenderedDist( Distributions.Shape.combineAlgebraically(algebraicOp, s1, s2), ), - ), ) | (Error(e1), _) => Error(e1) | (_, Error(e2)) => Error(e2) - | _ => Error("Could not render shapes.") + | _ => Error("Algebraic combination: rendering failed.") }; }; - let toLeaf = + let operationToLeaf = ( - operationToLeaf, + toLeaf, + renderParams: renderParams, algebraicOp: ExpressionTypes.algebraicOperation, t1: t, t2: t, ) : result(node, string) => - toTreeNode(algebraicOp, t1, t2) - |> tryAnalyticalSolution + + algebraicOp + |> tryAnalyticalSimplification(_, t1, t2) |> E.R.bind( _, fun - | `Leaf(d) => Ok(`Leaf(d)) // the analytical simplifaction worked, nice! - | `Operation(_) => - // if not, run the convolution - evaluateNumerically(algebraicOp, operationToLeaf, t1, t2), + | `SymbolicDist(d) as t => Ok(t) + | _ => combineAsShapes(toLeaf, renderParams, algebraicOp, t1, t2) ); }; module VerticalScaling = { - let toLeaf = (operationToLeaf, scaleOp, t, scaleBy) => { + let operationToLeaf = (toLeaf, renderParams, scaleOp, t, scaleBy) => { // scaleBy has to be a single float, otherwise we'll return an error. let fn = Operation.Scale.toFn(scaleOp); let knownIntegralSumFn = Operation.Scale.toKnownIntegralSumFn(scaleOp); - let renderedShape = operationToLeaf(`Render(t)); + let renderedShape = toLeaf(renderParams, `Render(t)); switch (renderedShape, scaleBy) { - | (Ok(`Leaf(`RenderedDist(rs))), `Leaf(`SymbolicDist(`Float(sm)))) => + | (Ok(`RenderedDist(rs)), `SymbolicDist(`Float(sm))) => Ok( - `Leaf( `RenderedDist( Distributions.Shape.T.mapY( ~knownIntegralSumFn=knownIntegralSumFn(sm), fn(sm), rs, ), - ), ), ) | (Error(e1), _) => Error(e1) @@ -95,31 +87,27 @@ module VerticalScaling = { }; module PointwiseCombination = { - let pointwiseAdd = (operationToLeaf, t1, t2) => { - let renderedShape1 = operationToLeaf(`Render(t1)); - let renderedShape2 = operationToLeaf(`Render(t2)); - - switch (renderedShape1, renderedShape2) { - | (Ok(`Leaf(`RenderedDist(rs1))), Ok(`Leaf(`RenderedDist(rs2)))) => + let pointwiseAdd = (toLeaf, renderParams, t1, t2) => { + let renderShape = r => toLeaf(renderParams, `Render(r)); + switch (renderShape(t1), renderShape(t2)) { + | (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) => Ok( - `Leaf( - `RenderedDist( - Distributions.Shape.combinePointwise( - ~knownIntegralSumsFn=(a, b) => Some(a +. b), - (+.), - rs1, - rs2, - ), + `RenderedDist( + Distributions.Shape.combinePointwise( + ~knownIntegralSumsFn=(a, b) => Some(a +. b), + (+.), + rs1, + rs2, ), ), ) | (Error(e1), _) => Error(e1) | (_, Error(e2)) => Error(e2) - | _ => Error("Could not perform pointwise addition.") + | _ => Error("Pointwise combination: rendering failed.") }; }; - let pointwiseMultiply = (operationToLeaf, t1, t2) => { + let pointwiseMultiply = (toLeaf, renderParams, t1, t2) => { // TODO: construct a function that we can easily sample from, to construct // a RenderedDist. Use the xMin and xMax of the rendered shapes to tell the sampling function where to look. Error( @@ -127,84 +115,72 @@ module PointwiseCombination = { ); }; - let toLeaf = (operationToLeaf, pointwiseOp, t1, t2) => { + let operationToLeaf = (toLeaf, renderParams, pointwiseOp, t1, t2) => { switch (pointwiseOp) { - | `Add => pointwiseAdd(operationToLeaf, t1, t2) - | `Multiply => pointwiseMultiply(operationToLeaf, t1, t2) + | `Add => pointwiseAdd(toLeaf, renderParams, t1, t2) + | `Multiply => pointwiseMultiply(toLeaf, renderParams, t1, t2) }; }; }; module Truncate = { - module Simplify = { - let tryTruncatingNothing: tResult = - fun - | `Operation(`Truncate(None, None, `Leaf(d))) => Ok(`Leaf(d)) - | t => Ok(t); - - let tryTruncatingUniform: tResult = - fun - | `Operation(`Truncate(lc, rc, `Leaf(`SymbolicDist(`Uniform(u))))) => { - // just create a new Uniform distribution - let newLow = max(E.O.default(neg_infinity, lc), u.low); - let newHigh = min(E.O.default(infinity, rc), u.high); - Ok(`Leaf(`SymbolicDist(`Uniform({low: newLow, high: newHigh})))); - } - | t => Ok(t); - - let attempt = (leftCutoff, rightCutoff, t): result(node, string) => { - let originalTreeNode = - `Operation(`Truncate((leftCutoff, rightCutoff, t))); - - originalTreeNode - |> tryTruncatingNothing - |> E.R.bind(_, tryTruncatingUniform); + let trySimplification = (leftCutoff, rightCutoff, t) => { + switch (leftCutoff, rightCutoff, t) { + | (None, None, t) => Ok(t) + | (lc, rc, `SymbolicDist(`Uniform(u))) => { + // just create a new Uniform distribution + let nu: SymbolicTypes.uniform = u; + let newLow = max(E.O.default(neg_infinity, lc), nu.low); + let newHigh = min(E.O.default(infinity, rc), nu.high); + Ok(`SymbolicDist(`Uniform({low: newLow, high: newHigh}))); + } + | (_, _, t) => Ok(t) }; }; - let evaluateNumerically = (leftCutoff, rightCutoff, operationToLeaf, t) => { + let truncateAsShape = (toLeaf, renderParams, leftCutoff, rightCutoff, t) => { // TODO: use named args in renderToShape; if we're lucky we can at least get the tail // of a distribution we otherwise wouldn't get at all - let renderedShape = operationToLeaf(`Render(t)); + let renderedShape = toLeaf(renderParams, `Render(t)); switch (renderedShape) { - | Ok(`Leaf(`RenderedDist(rs))) => + | Ok(`RenderedDist(rs)) => { let truncatedShape = rs |> Distributions.Shape.T.truncate(leftCutoff, rightCutoff); - Ok(`Leaf(`RenderedDist(rs))); + Ok(`RenderedDist(rs)); + } | Error(e1) => Error(e1) | _ => Error("Could not truncate distribution.") }; }; - let toLeaf = - ( - operationToLeaf, - leftCutoff: option(float), - rightCutoff: option(float), - t: node, - ) - : result(node, string) => { + let operationToLeaf = + ( + toLeaf, + renderParams, + leftCutoff: option(float), + rightCutoff: option(float), + t: node, + ) + : result(node, string) => { t - |> Simplify.attempt(leftCutoff, rightCutoff) + |> trySimplification(leftCutoff, rightCutoff) |> E.R.bind( _, fun - | `Leaf(d) => Ok(`Leaf(d)) // the analytical simplifaction worked, nice! - | `Operation(_) => - evaluateNumerically(leftCutoff, rightCutoff, operationToLeaf, t), - ); // if not, run the convolution + | `SymbolicDist(d) as t => Ok(t) + | _ => truncateAsShape(toLeaf, renderParams, leftCutoff, rightCutoff, t), + ); }; }; module Normalize = { - let rec toLeaf = (operationToLeaf, t: node): result(node, string) => { + let rec operationToLeaf = (toLeaf, renderParams, t: node): result(node, string) => { switch (t) { - | `Leaf(`RenderedDist(s)) => - Ok(`Leaf(`RenderedDist(Distributions.Shape.T.normalize(s)))) - | `Leaf(`SymbolicDist(_)) => Ok(t) - | `Operation(op) => - operationToLeaf(op) |> E.R.bind(_, toLeaf(operationToLeaf)) + | `RenderedDist(s) => + Ok(`RenderedDist(Distributions.Shape.T.normalize(s))) + | `SymbolicDist(_) => Ok(t) + | _ => t |> toLeaf(renderParams) |> E.R.bind(_, operationToLeaf(toLeaf, renderParams)) }; }; }; @@ -212,83 +188,79 @@ module Normalize = { module FloatFromDist = { let symbolicToLeaf = (distToFloatOp: distToFloatOperation, s) => { SymbolicDist.T.operate(distToFloatOp, s) - |> E.R.bind(_, v => Ok(`Leaf(`SymbolicDist(`Float(v))))); + |> E.R.bind(_, v => Ok(`SymbolicDist(`Float(v)))); }; let renderedToLeaf = (distToFloatOp: distToFloatOperation, rs: DistTypes.shape) : result(node, string) => { Distributions.Shape.operate(distToFloatOp, rs) - |> (v => Ok(`Leaf(`SymbolicDist(`Float(v))))); + |> (v => Ok(`SymbolicDist(`Float(v)))); }; - let rec toLeaf = - (operationToLeaf, distToFloatOp: distToFloatOperation, t: node) + let rec operationToLeaf = + (toLeaf, renderParams, distToFloatOp: distToFloatOperation, t: node) : result(node, string) => { switch (t) { - | `Leaf(`SymbolicDist(s)) => symbolicToLeaf(distToFloatOp, s) // we want to evaluate the distToFloatOp on the symbolic dist - | `Leaf(`RenderedDist(rs)) => renderedToLeaf(distToFloatOp, rs) - | `Operation(op) => - E.R.bind(operationToLeaf(op), toLeaf(operationToLeaf, distToFloatOp)) + | `SymbolicDist(s) => symbolicToLeaf(distToFloatOp, s) + | `RenderedDist(rs) => renderedToLeaf(distToFloatOp, rs) + | _ => t |> toLeaf(renderParams) |> E.R.bind(_, operationToLeaf(toLeaf, renderParams, distToFloatOp)) }; }; }; module Render = { - let rec toLeaf = + let rec operationToLeaf = ( - operationToLeaf: operation => result(t, string), - sampleCount: int, + toLeaf, + renderParams, t: node, ) : result(t, string) => { switch (t) { - | `Leaf(`SymbolicDist(d)) => - Ok(`Leaf(`RenderedDist(SymbolicDist.T.toShape(sampleCount, d)))) - | `Leaf(`RenderedDist(_)) as t => Ok(t) // already a rendered shape, we're done here - | `Operation(op) => - E.R.bind(operationToLeaf(op), toLeaf(operationToLeaf, sampleCount)) + | `SymbolicDist(d) => + Ok(`RenderedDist(SymbolicDist.T.toShape(renderParams.sampleCount, d))) + | `RenderedDist(_) as t => Ok(t) // already a rendered shape, we're done here + | _ => t |> toLeaf(renderParams) |> E.R.bind(_, operationToLeaf(toLeaf, renderParams)) }; }; }; -let rec operationToLeaf = - (sampleCount: int, op: operation): result(t, string) => { - // the functions that convert the Operation nodes to Leaf nodes need to - // have a way to call this function on their children, if their children are themselves Operation nodes. - switch (op) { - | `AlgebraicCombination(algebraicOp, t1, t2) => - AlgebraicCombination.toLeaf( - operationToLeaf(sampleCount), - algebraicOp, - t1, - t2 // we want to give it the option to render or simply leave it as is - ) - | `PointwiseCombination(pointwiseOp, t1, t2) => - PointwiseCombination.toLeaf( - operationToLeaf(sampleCount), - pointwiseOp, - t1, - t2, - ) - | `VerticalScaling(scaleOp, t, scaleBy) => - VerticalScaling.toLeaf(operationToLeaf(sampleCount), scaleOp, t, scaleBy) - | `Truncate(leftCutoff, rightCutoff, t) => - Truncate.toLeaf(operationToLeaf(sampleCount), leftCutoff, rightCutoff, t) - | `FloatFromDist(distToFloatOp, t) => - FloatFromDist.toLeaf(operationToLeaf(sampleCount), distToFloatOp, t) - | `Normalize(t) => Normalize.toLeaf(operationToLeaf(sampleCount), t) - | `Render(t) => Render.toLeaf(operationToLeaf(sampleCount), sampleCount, t) - }; -}; - /* This function recursively goes through the nodes of the parse tree, replacing each Operation node and its subtree with a Data node. Whenever possible, the replacement produces a new Symbolic Data node, but most often it will produce a RenderedDist. This function is used mainly to turn a parse tree into a single RenderedDist that can then be displayed to the user. */ -let toLeaf = (node: t, sampleCount: int): result(t, string) => { +let rec toLeaf = (renderParams, node: t): result(t, string) => { switch (node) { - | `Leaf(d) => Ok(`Leaf(d)) - | `Operation(op) => operationToLeaf(sampleCount, op) + // Leaf nodes just stay leaf nodes + | `SymbolicDist(_) + | `RenderedDist(_) => Ok(node) + // Operations need to be turned into leaves + | `AlgebraicCombination(algebraicOp, t1, t2) => + AlgebraicCombination.operationToLeaf( + toLeaf, + renderParams, + algebraicOp, + t1, + t2 + ) + | `PointwiseCombination(pointwiseOp, t1, t2) => + PointwiseCombination.operationToLeaf( + toLeaf, + renderParams, + pointwiseOp, + t1, + t2, + ) + | `VerticalScaling(scaleOp, t, scaleBy) => + VerticalScaling.operationToLeaf( + toLeaf, renderParams, scaleOp, t, scaleBy + ) + | `Truncate(leftCutoff, rightCutoff, t) => + Truncate.operationToLeaf(toLeaf, renderParams, leftCutoff, rightCutoff, t) + | `FloatFromDist(distToFloatOp, t) => + FloatFromDist.operationToLeaf(toLeaf, renderParams, distToFloatOp, t) + | `Normalize(t) => Normalize.operationToLeaf(toLeaf, renderParams, t) + | `Render(t) => Render.operationToLeaf(toLeaf, renderParams, t) }; }; diff --git a/src/distPlus/expressionTree/ExpressionTypes.re b/src/distPlus/expressionTree/ExpressionTypes.re index 730a228b..06be9967 100644 --- a/src/distPlus/expressionTree/ExpressionTypes.re +++ b/src/distPlus/expressionTree/ExpressionTypes.re @@ -3,22 +3,18 @@ type pointwiseOperation = [ | `Add | `Multiply]; type scaleOperation = [ | `Multiply | `Exponentiate | `Log]; type distToFloatOperation = [ | `Pdf(float) | `Inv(float) | `Mean | `Sample]; -type abstractOperation('a) = [ - | `AlgebraicCombination(algebraicOperation, 'a, 'a) - | `PointwiseCombination(pointwiseOperation, 'a, 'a) - | `VerticalScaling(scaleOperation, 'a, 'a) - | `Render('a) - | `Truncate(option(float), option(float), 'a) - | `Normalize('a) - | `FloatFromDist(distToFloatOperation, 'a) -]; - module ExpressionTree = { - type leaf = [ + type node = [ + // leaf nodes: | `SymbolicDist(SymbolicTypes.symbolicDist) | `RenderedDist(DistTypes.shape) + // operations: + | `AlgebraicCombination(algebraicOperation, node, node) + | `PointwiseCombination(pointwiseOperation, node, node) + | `VerticalScaling(scaleOperation, node, node) + | `Render(node) + | `Truncate(option(float), option(float), node) + | `Normalize(node) + | `FloatFromDist(distToFloatOperation, node) ]; - - type node = [ | `Leaf(leaf) | `Operation(operation)] - and operation = abstractOperation(node); }; diff --git a/src/distPlus/expressionTree/MathJsParser.re b/src/distPlus/expressionTree/MathJsParser.re index 92227736..42ebb3ec 100644 --- a/src/distPlus/expressionTree/MathJsParser.re +++ b/src/distPlus/expressionTree/MathJsParser.re @@ -86,29 +86,29 @@ module MathAdtToDistDst = { ); }; - let normal: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = + let normal: + array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = fun | [|Value(mean), Value(stdev)|] => - Ok(`Leaf(`SymbolicDist(`Normal({mean, stdev})))) + Ok(`SymbolicDist(`Normal({mean, stdev}))) | _ => Error("Wrong number of variables in normal distribution"); - let lognormal: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = + let lognormal: + array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = fun | [|Value(mu), Value(sigma)|] => - Ok(`Leaf(`SymbolicDist(`Lognormal({mu, sigma})))) + Ok(`SymbolicDist(`Lognormal({mu, sigma}))) | [|Object(o)|] => { let g = Js.Dict.get(o); switch (g("mean"), g("stdev"), g("mu"), g("sigma")) { | (Some(Value(mean)), Some(Value(stdev)), _, _) => Ok( - `Leaf( - `SymbolicDist( - SymbolicDist.Lognormal.fromMeanAndStdev(mean, stdev), - ), + `SymbolicDist( + SymbolicDist.Lognormal.fromMeanAndStdev(mean, stdev), ), ) | (_, _, Some(Value(mu)), Some(Value(sigma))) => - Ok(`Leaf(`SymbolicDist(`Lognormal({mu, sigma})))) + Ok(`SymbolicDist(`Lognormal({mu, sigma}))) | _ => Error("Lognormal distribution would need mean and stdev") }; } @@ -117,51 +117,48 @@ module MathAdtToDistDst = { let to_: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = fun | [|Value(low), Value(high)|] when low <= 0.0 && low < high => { - Ok( - `Leaf( - `SymbolicDist(SymbolicDist.Normal.from90PercentCI(low, high)), - ), - ); + Ok(`SymbolicDist(SymbolicDist.Normal.from90PercentCI(low, high))); } | [|Value(low), Value(high)|] when low < high => { Ok( - `Leaf( - `SymbolicDist(SymbolicDist.Lognormal.from90PercentCI(low, high)), - ), + `SymbolicDist(SymbolicDist.Lognormal.from90PercentCI(low, high)), ); } | [|Value(_), Value(_)|] => Error("Low value must be less than high value.") | _ => Error("Wrong number of variables in lognormal distribution"); - let uniform: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = + let uniform: + array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = fun | [|Value(low), Value(high)|] => - Ok(`Leaf(`SymbolicDist(`Uniform({low, high})))) + Ok(`SymbolicDist(`Uniform({low, high}))) | _ => Error("Wrong number of variables in lognormal distribution"); let beta: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = fun | [|Value(alpha), Value(beta)|] => - Ok(`Leaf(`SymbolicDist(`Beta({alpha, beta})))) + Ok(`SymbolicDist(`Beta({alpha, beta}))) | _ => Error("Wrong number of variables in lognormal distribution"); - let exponential: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = + let exponential: + array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = fun - | [|Value(rate)|] => - Ok(`Leaf(`SymbolicDist(`Exponential({rate: rate})))) + | [|Value(rate)|] => Ok(`SymbolicDist(`Exponential({rate: rate}))) | _ => Error("Wrong number of variables in Exponential distribution"); - let cauchy: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = + let cauchy: + array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = fun | [|Value(local), Value(scale)|] => - Ok(`Leaf(`SymbolicDist(`Cauchy({local, scale})))) + Ok(`SymbolicDist(`Cauchy({local, scale}))) | _ => Error("Wrong number of variables in cauchy distribution"); - let triangular: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = + let triangular: + array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = fun | [|Value(low), Value(medium), Value(high)|] => - Ok(`Leaf(`SymbolicDist(`Triangular({low, medium, high})))) + Ok(`SymbolicDist(`Triangular({low, medium, high}))) | _ => Error("Wrong number of variables in triangle distribution"); let multiModal = @@ -192,30 +189,24 @@ module MathAdtToDistDst = { |> E.A.fmapi((index, t) => { let w = weights |> E.A.get(_, index) |> E.O.default(1.0); - `Operation( - `VerticalScaling(( - `Multiply, - t, - `Leaf(`SymbolicDist(`Float(w))), - )), - ); + `VerticalScaling((`Multiply, t, `SymbolicDist(`Float(w)))); }); let pointwiseSum = components |> Js.Array.sliceFrom(1) |> E.A.fold_left( - (acc, x) => { - `Operation(`PointwiseCombination((`Add, acc, x))) - }, + (acc, x) => {`PointwiseCombination((`Add, acc, x))}, E.A.unsafe_get(components, 0), ); - Ok(`Operation(`Normalize(pointwiseSum))); + Ok(`Normalize(pointwiseSum)); }; }; - let arrayParser = (args: array(arg)): result(ExpressionTypes.ExpressionTree.node, string) => { + let arrayParser = + (args: array(arg)) + : result(ExpressionTypes.ExpressionTree.node, string) => { let samples = args |> E.A.fmap( @@ -235,15 +226,18 @@ module MathAdtToDistDst = { SymbolicDist.ContinuousShape.make(_pdf, cdf); }); switch (shape) { - | Some(s) => Ok(`Leaf(`SymbolicDist(`ContinuousShape(s)))) + | Some(s) => Ok(`SymbolicDist(`ContinuousShape(s))) | None => Error("Rendering did not work") }; }; let operationParser = - (name: string, args: array(result(ExpressionTypes.ExpressionTree.node, string))) => { - let toOkAlgebraic = r => Ok(`Operation(`AlgebraicCombination(r))); - let toOkTrunctate = r => Ok(`Operation(`Truncate(r))); + ( + name: string, + args: array(result(ExpressionTypes.ExpressionTree.node, string)), + ) => { + let toOkAlgebraic = r => Ok(`AlgebraicCombination(r)); + let toOkTrunctate = r => Ok(`Truncate(r)); switch (name, args) { | ("add", [|Ok(l), Ok(r)|]) => toOkAlgebraic((`Add, l, r)) | ("add", _) => Error("Addition needs two operands") @@ -254,11 +248,11 @@ module MathAdtToDistDst = { | ("divide", [|Ok(l), Ok(r)|]) => toOkAlgebraic((`Divide, l, r)) | ("divide", _) => Error("Division needs two operands") | ("pow", _) => Error("Exponentiation is not yet supported.") - | ("leftTruncate", [|Ok(d), Ok(`Leaf(`SymbolicDist(`Float(lc))))|]) => + | ("leftTruncate", [|Ok(d), Ok(`SymbolicDist(`Float(lc)))|]) => toOkTrunctate((Some(lc), None, d)) | ("leftTruncate", _) => Error("leftTruncate needs two arguments: the expression and the cutoff") - | ("rightTruncate", [|Ok(d), Ok(`Leaf(`SymbolicDist(`Float(rc))))|]) => + | ("rightTruncate", [|Ok(d), Ok(`SymbolicDist(`Float(rc)))|]) => toOkTrunctate((None, Some(rc), d)) | ("rightTruncate", _) => Error( @@ -268,8 +262,8 @@ module MathAdtToDistDst = { "truncate", [| Ok(d), - Ok(`Leaf(`SymbolicDist(`Float(lc)))), - Ok(`Leaf(`SymbolicDist(`Float(rc)))), + Ok(`SymbolicDist(`Float(lc))), + Ok(`SymbolicDist(`Float(rc))), |], ) => toOkTrunctate((Some(lc), Some(rc), d)) @@ -333,7 +327,7 @@ module MathAdtToDistDst = { let rec nodeParser = fun - | Value(f) => Ok(`Leaf(`SymbolicDist(`Float(f)))) + | Value(f) => Ok(`SymbolicDist(`Float(f))) | Fn({name, args}) => functionParser(nodeParser, name, args) | _ => { Error("This type not currently supported"); diff --git a/src/distPlus/expressionTree/Operation.re b/src/distPlus/expressionTree/Operation.re index 29cee28b..33c05461 100644 --- a/src/distPlus/expressionTree/Operation.re +++ b/src/distPlus/expressionTree/Operation.re @@ -89,5 +89,6 @@ module T = { | `FloatFromDist(floatFromDistOp, t) => DistToFloat.format(floatFromDistOp, nodeToString(t)) | `Truncate(lc, rc, t) => truncateToString(lc, rc, nodeToString(t)) - | `Render(t) => nodeToString(t); + | `Render(t) => nodeToString(t) + | _ => ""; // SymbolicDist and RenderedDist are handled in ExpressionTree.toString. }; diff --git a/src/distPlus/symbolic/SymbolicDist.re b/src/distPlus/symbolic/SymbolicDist.re index 94e513d6..96ecf0c1 100644 --- a/src/distPlus/symbolic/SymbolicDist.re +++ b/src/distPlus/symbolic/SymbolicDist.re @@ -269,23 +269,23 @@ module T = { }; }; - /* This returns an optional that wraps a result. If the optional is None, - there is no valid analytic solution. If it Some, it + /* Calling e.g. "Normal.operate" returns an optional that wraps a result. + If the optional is None, there is no valid analytic solution. If it Some, it can still return an error if there is a serious problem, - like in the casea of a divide by 0. + like in the case of a divide by 0. */ - type analyticalSolutionAttempt = [ + type analyticalSimplificationResult = [ | `AnalyticalSolution(SymbolicTypes.symbolicDist) | `Error(string) | `NoSolution ]; - let attemptAnalyticalOperation = + let tryAnalyticalSimplification = ( d1: symbolicDist, d2: symbolicDist, op: ExpressionTypes.algebraicOperation, ) - : analyticalSolutionAttempt => + : analyticalSimplificationResult => switch (d1, d2) { | (`Float(v1), `Float(v2)) => switch (Operation.Algebraic.applyFn(op, v1, v2)) {