Remove Leaf and Operation wrapper types

This commit is contained in:
Sebastian Kosch 2020-07-03 14:55:27 -07:00
parent a649a6bca2
commit ca9f725ae7
8 changed files with 193 additions and 229 deletions

View File

@ -383,9 +383,9 @@ describe("Shape", () => {
let numSamples = 10000; let numSamples = 10000;
open Distributions.Shape; open Distributions.Shape;
let normal: SymbolicTypes.symbolicDist = `Normal({mean, stdev}); let normal: SymbolicTypes.symbolicDist = `Normal({mean, stdev});
let normalShape = ExpressionTree.toShape(numSamples, `Leaf(`SymbolicDist(normal))); let normalShape = ExpressionTree.toShape(numSamples, `SymbolicDist(normal));
let lognormal = SymbolicDist.Lognormal.fromMeanAndStdev(mean, stdev); let lognormal = SymbolicDist.Lognormal.fromMeanAndStdev(mean, stdev);
let lognormalShape = ExpressionTree.toShape(numSamples, `Leaf(`SymbolicDist(lognormal))); let lognormalShape = ExpressionTree.toShape(numSamples, `SymbolicDist(lognormal));
makeTestCloseEquality( makeTestCloseEquality(
"Mean of a normal", "Mean of a normal",

View File

@ -389,7 +389,7 @@ module Draw = {
let numSamples = 3000; let numSamples = 3000;
let normal: SymbolicTypes.symbolicDist = `Normal({mean, stdev}); let normal: SymbolicTypes.symbolicDist = `Normal({mean, stdev});
let normalShape = ExpressionTree.toShape(numSamples, `Leaf(`SymbolicDist(normal))); let normalShape = ExpressionTree.toShape(numSamples, `SymbolicDist(normal));
let xyShape: Types.xyShape = let xyShape: Types.xyShape =
switch (normalShape) { switch (normalShape) {
| Mixed(_) => {xs: [||], ys: [||]} | Mixed(_) => {xs: [||], ys: [||]}

View File

@ -2,10 +2,11 @@ open ExpressionTypes.ExpressionTree;
let toShape = (sampleCount: int, node: node) => { let toShape = (sampleCount: int, node: node) => {
let renderResult = let renderResult =
ExpressionTreeEvaluator.toLeaf(`Operation(`Render(node)), sampleCount); `Render(`Normalize(node))
|> ExpressionTreeEvaluator.toLeaf({sampleCount: sampleCount});
switch (renderResult) { switch (renderResult) {
| Ok(`Leaf(`RenderedDist(rs))) => | Ok(`RenderedDist(rs)) =>
let continuous = Distributions.Shape.T.toContinuous(rs); let continuous = Distributions.Shape.T.toContinuous(rs);
let discrete = Distributions.Shape.T.toDiscrete(rs); let discrete = Distributions.Shape.T.toDiscrete(rs);
let shape = MixedShapeBuilder.buildSimple(~continuous, ~discrete); let shape = MixedShapeBuilder.buildSimple(~continuous, ~discrete);
@ -17,6 +18,6 @@ let toShape = (sampleCount: int, node: node) => {
let rec toString = let rec toString =
fun fun
| `Leaf(`SymbolicDist(d)) => SymbolicDist.T.toString(d) | `SymbolicDist(d) => SymbolicDist.T.toString(d)
| `Leaf(`RenderedDist(_)) => "[shape]" | `RenderedDist(_) => "[shape]"
| `Operation(op) => Operation.T.toString(toString, op); | op => Operation.T.toString(toString, op);

View File

@ -1,84 +1,77 @@
/* This module represents a tree node. */
open ExpressionTypes; open ExpressionTypes;
open ExpressionTypes.ExpressionTree; open ExpressionTypes.ExpressionTree;
type t = node; type t = node;
type tResult = node => result(node, string); type tResult = node => result(node, string);
type renderParams = {
sampleCount: int,
};
/* Given two random variables A and B, this returns the distribution /* Given two random variables A and B, this returns the distribution
of a new variable that is the result of the operation on A and B. of a new variable that is the result of the operation on A and B.
For instance, normal(0, 1) + normal(1, 1) -> normal(1, 2). For instance, normal(0, 1) + normal(1, 1) -> normal(1, 2).
In general, this is implemented via convolution. */ In general, this is implemented via convolution. */
module AlgebraicCombination = { module AlgebraicCombination = {
let toTreeNode = (op, t1, t2) => let tryAnalyticalSimplification = (operation, t1: t, t2: t) =>
`Operation(`AlgebraicCombination((op, t1, t2))); switch (operation, t1, t2) {
let tryAnalyticalSolution = | (operation,
fun `SymbolicDist(d1),
| `Operation( `SymbolicDist(d2),
`AlgebraicCombination( ) =>
operation, switch (SymbolicDist.T.tryAnalyticalSimplification(d1, d2, operation)) {
`Leaf(`SymbolicDist(d1)), | `AnalyticalSolution(symbolicDist) => Ok(`SymbolicDist(symbolicDist))
`Leaf(`SymbolicDist(d2)),
),
) as t =>
switch (SymbolicDist.T.attemptAnalyticalOperation(d1, d2, operation)) {
| `AnalyticalSolution(symbolicDist) =>
Ok(`Leaf(`SymbolicDist(symbolicDist)))
| `Error(er) => Error(er) | `Error(er) => Error(er)
| `NoSolution => Ok(t) | `NoSolution => Ok(`AlgebraicCombination(operation, t1, t2))
} }
| t => Ok(t); | _ => Ok(`AlgebraicCombination(operation, t1, t2))
};
// todo: I don't like the name evaluateNumerically that much, if this renders and does it algebraically. It's tricky. let combineAsShapes = (toLeaf, renderParams, algebraicOp, t1, t2) => {
let evaluateNumerically = (algebraicOp, operationToLeaf, t1, t2) => { let renderShape = r => toLeaf(renderParams, `Render(r));
// force rendering into shapes
let renderShape = r => operationToLeaf(`Render(r));
switch (renderShape(t1), renderShape(t2)) { switch (renderShape(t1), renderShape(t2)) {
| (Ok(`Leaf(`RenderedDist(s1))), Ok(`Leaf(`RenderedDist(s2)))) => | (Ok(`RenderedDist(s1)), Ok(`RenderedDist(s2))) =>
Ok( Ok(
`Leaf(
`RenderedDist( `RenderedDist(
Distributions.Shape.combineAlgebraically(algebraicOp, s1, s2), Distributions.Shape.combineAlgebraically(algebraicOp, s1, s2),
), ),
),
) )
| (Error(e1), _) => Error(e1) | (Error(e1), _) => Error(e1)
| (_, Error(e2)) => Error(e2) | (_, Error(e2)) => Error(e2)
| _ => Error("Could not render shapes.") | _ => Error("Algebraic combination: rendering failed.")
}; };
}; };
let toLeaf = let operationToLeaf =
( (
operationToLeaf, toLeaf,
renderParams: renderParams,
algebraicOp: ExpressionTypes.algebraicOperation, algebraicOp: ExpressionTypes.algebraicOperation,
t1: t, t1: t,
t2: t, t2: t,
) )
: result(node, string) => : result(node, string) =>
toTreeNode(algebraicOp, t1, t2)
|> tryAnalyticalSolution algebraicOp
|> tryAnalyticalSimplification(_, t1, t2)
|> E.R.bind( |> E.R.bind(
_, _,
fun fun
| `Leaf(d) => Ok(`Leaf(d)) // the analytical simplifaction worked, nice! | `SymbolicDist(d) as t => Ok(t)
| `Operation(_) => | _ => combineAsShapes(toLeaf, renderParams, algebraicOp, t1, t2)
// if not, run the convolution
evaluateNumerically(algebraicOp, operationToLeaf, t1, t2),
); );
}; };
module VerticalScaling = { module VerticalScaling = {
let toLeaf = (operationToLeaf, scaleOp, t, scaleBy) => { let operationToLeaf = (toLeaf, renderParams, scaleOp, t, scaleBy) => {
// scaleBy has to be a single float, otherwise we'll return an error. // scaleBy has to be a single float, otherwise we'll return an error.
let fn = Operation.Scale.toFn(scaleOp); let fn = Operation.Scale.toFn(scaleOp);
let knownIntegralSumFn = Operation.Scale.toKnownIntegralSumFn(scaleOp); let knownIntegralSumFn = Operation.Scale.toKnownIntegralSumFn(scaleOp);
let renderedShape = operationToLeaf(`Render(t)); let renderedShape = toLeaf(renderParams, `Render(t));
switch (renderedShape, scaleBy) { switch (renderedShape, scaleBy) {
| (Ok(`Leaf(`RenderedDist(rs))), `Leaf(`SymbolicDist(`Float(sm)))) => | (Ok(`RenderedDist(rs)), `SymbolicDist(`Float(sm))) =>
Ok( Ok(
`Leaf(
`RenderedDist( `RenderedDist(
Distributions.Shape.T.mapY( Distributions.Shape.T.mapY(
~knownIntegralSumFn=knownIntegralSumFn(sm), ~knownIntegralSumFn=knownIntegralSumFn(sm),
@ -86,7 +79,6 @@ module VerticalScaling = {
rs, rs,
), ),
), ),
),
) )
| (Error(e1), _) => Error(e1) | (Error(e1), _) => Error(e1)
| (_, _) => Error("Can only scale by float values.") | (_, _) => Error("Can only scale by float values.")
@ -95,14 +87,11 @@ module VerticalScaling = {
}; };
module PointwiseCombination = { module PointwiseCombination = {
let pointwiseAdd = (operationToLeaf, t1, t2) => { let pointwiseAdd = (toLeaf, renderParams, t1, t2) => {
let renderedShape1 = operationToLeaf(`Render(t1)); let renderShape = r => toLeaf(renderParams, `Render(r));
let renderedShape2 = operationToLeaf(`Render(t2)); switch (renderShape(t1), renderShape(t2)) {
| (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) =>
switch (renderedShape1, renderedShape2) {
| (Ok(`Leaf(`RenderedDist(rs1))), Ok(`Leaf(`RenderedDist(rs2)))) =>
Ok( Ok(
`Leaf(
`RenderedDist( `RenderedDist(
Distributions.Shape.combinePointwise( Distributions.Shape.combinePointwise(
~knownIntegralSumsFn=(a, b) => Some(a +. b), ~knownIntegralSumsFn=(a, b) => Some(a +. b),
@ -111,15 +100,14 @@ module PointwiseCombination = {
rs2, rs2,
), ),
), ),
),
) )
| (Error(e1), _) => Error(e1) | (Error(e1), _) => Error(e1)
| (_, Error(e2)) => Error(e2) | (_, Error(e2)) => Error(e2)
| _ => Error("Could not perform pointwise addition.") | _ => Error("Pointwise combination: rendering failed.")
}; };
}; };
let pointwiseMultiply = (operationToLeaf, t1, t2) => { let pointwiseMultiply = (toLeaf, renderParams, t1, t2) => {
// TODO: construct a function that we can easily sample from, to construct // TODO: construct a function that we can easily sample from, to construct
// a RenderedDist. Use the xMin and xMax of the rendered shapes to tell the sampling function where to look. // a RenderedDist. Use the xMin and xMax of the rendered shapes to tell the sampling function where to look.
Error( Error(
@ -127,84 +115,72 @@ module PointwiseCombination = {
); );
}; };
let toLeaf = (operationToLeaf, pointwiseOp, t1, t2) => { let operationToLeaf = (toLeaf, renderParams, pointwiseOp, t1, t2) => {
switch (pointwiseOp) { switch (pointwiseOp) {
| `Add => pointwiseAdd(operationToLeaf, t1, t2) | `Add => pointwiseAdd(toLeaf, renderParams, t1, t2)
| `Multiply => pointwiseMultiply(operationToLeaf, t1, t2) | `Multiply => pointwiseMultiply(toLeaf, renderParams, t1, t2)
}; };
}; };
}; };
module Truncate = { module Truncate = {
module Simplify = { let trySimplification = (leftCutoff, rightCutoff, t) => {
let tryTruncatingNothing: tResult = switch (leftCutoff, rightCutoff, t) {
fun | (None, None, t) => Ok(t)
| `Operation(`Truncate(None, None, `Leaf(d))) => Ok(`Leaf(d)) | (lc, rc, `SymbolicDist(`Uniform(u))) => {
| t => Ok(t);
let tryTruncatingUniform: tResult =
fun
| `Operation(`Truncate(lc, rc, `Leaf(`SymbolicDist(`Uniform(u))))) => {
// just create a new Uniform distribution // just create a new Uniform distribution
let newLow = max(E.O.default(neg_infinity, lc), u.low); let nu: SymbolicTypes.uniform = u;
let newHigh = min(E.O.default(infinity, rc), u.high); let newLow = max(E.O.default(neg_infinity, lc), nu.low);
Ok(`Leaf(`SymbolicDist(`Uniform({low: newLow, high: newHigh})))); let newHigh = min(E.O.default(infinity, rc), nu.high);
Ok(`SymbolicDist(`Uniform({low: newLow, high: newHigh})));
} }
| t => Ok(t); | (_, _, t) => Ok(t)
let attempt = (leftCutoff, rightCutoff, t): result(node, string) => {
let originalTreeNode =
`Operation(`Truncate((leftCutoff, rightCutoff, t)));
originalTreeNode
|> tryTruncatingNothing
|> E.R.bind(_, tryTruncatingUniform);
}; };
}; };
let evaluateNumerically = (leftCutoff, rightCutoff, operationToLeaf, t) => { let truncateAsShape = (toLeaf, renderParams, leftCutoff, rightCutoff, t) => {
// TODO: use named args in renderToShape; if we're lucky we can at least get the tail // TODO: use named args in renderToShape; if we're lucky we can at least get the tail
// of a distribution we otherwise wouldn't get at all // of a distribution we otherwise wouldn't get at all
let renderedShape = operationToLeaf(`Render(t)); let renderedShape = toLeaf(renderParams, `Render(t));
switch (renderedShape) { switch (renderedShape) {
| Ok(`Leaf(`RenderedDist(rs))) => | Ok(`RenderedDist(rs)) => {
let truncatedShape = let truncatedShape =
rs |> Distributions.Shape.T.truncate(leftCutoff, rightCutoff); rs |> Distributions.Shape.T.truncate(leftCutoff, rightCutoff);
Ok(`Leaf(`RenderedDist(rs))); Ok(`RenderedDist(rs));
}
| Error(e1) => Error(e1) | Error(e1) => Error(e1)
| _ => Error("Could not truncate distribution.") | _ => Error("Could not truncate distribution.")
}; };
}; };
let toLeaf = let operationToLeaf =
( (
operationToLeaf, toLeaf,
renderParams,
leftCutoff: option(float), leftCutoff: option(float),
rightCutoff: option(float), rightCutoff: option(float),
t: node, t: node,
) )
: result(node, string) => { : result(node, string) => {
t t
|> Simplify.attempt(leftCutoff, rightCutoff) |> trySimplification(leftCutoff, rightCutoff)
|> E.R.bind( |> E.R.bind(
_, _,
fun fun
| `Leaf(d) => Ok(`Leaf(d)) // the analytical simplifaction worked, nice! | `SymbolicDist(d) as t => Ok(t)
| `Operation(_) => | _ => truncateAsShape(toLeaf, renderParams, leftCutoff, rightCutoff, t),
evaluateNumerically(leftCutoff, rightCutoff, operationToLeaf, t), );
); // if not, run the convolution
}; };
}; };
module Normalize = { module Normalize = {
let rec toLeaf = (operationToLeaf, t: node): result(node, string) => { let rec operationToLeaf = (toLeaf, renderParams, t: node): result(node, string) => {
switch (t) { switch (t) {
| `Leaf(`RenderedDist(s)) => | `RenderedDist(s) =>
Ok(`Leaf(`RenderedDist(Distributions.Shape.T.normalize(s)))) Ok(`RenderedDist(Distributions.Shape.T.normalize(s)))
| `Leaf(`SymbolicDist(_)) => Ok(t) | `SymbolicDist(_) => Ok(t)
| `Operation(op) => | _ => t |> toLeaf(renderParams) |> E.R.bind(_, operationToLeaf(toLeaf, renderParams))
operationToLeaf(op) |> E.R.bind(_, toLeaf(operationToLeaf))
}; };
}; };
}; };
@ -212,83 +188,79 @@ module Normalize = {
module FloatFromDist = { module FloatFromDist = {
let symbolicToLeaf = (distToFloatOp: distToFloatOperation, s) => { let symbolicToLeaf = (distToFloatOp: distToFloatOperation, s) => {
SymbolicDist.T.operate(distToFloatOp, s) SymbolicDist.T.operate(distToFloatOp, s)
|> E.R.bind(_, v => Ok(`Leaf(`SymbolicDist(`Float(v))))); |> E.R.bind(_, v => Ok(`SymbolicDist(`Float(v))));
}; };
let renderedToLeaf = let renderedToLeaf =
(distToFloatOp: distToFloatOperation, rs: DistTypes.shape) (distToFloatOp: distToFloatOperation, rs: DistTypes.shape)
: result(node, string) => { : result(node, string) => {
Distributions.Shape.operate(distToFloatOp, rs) Distributions.Shape.operate(distToFloatOp, rs)
|> (v => Ok(`Leaf(`SymbolicDist(`Float(v))))); |> (v => Ok(`SymbolicDist(`Float(v))));
}; };
let rec toLeaf = let rec operationToLeaf =
(operationToLeaf, distToFloatOp: distToFloatOperation, t: node) (toLeaf, renderParams, distToFloatOp: distToFloatOperation, t: node)
: result(node, string) => { : result(node, string) => {
switch (t) { switch (t) {
| `Leaf(`SymbolicDist(s)) => symbolicToLeaf(distToFloatOp, s) // we want to evaluate the distToFloatOp on the symbolic dist | `SymbolicDist(s) => symbolicToLeaf(distToFloatOp, s)
| `Leaf(`RenderedDist(rs)) => renderedToLeaf(distToFloatOp, rs) | `RenderedDist(rs) => renderedToLeaf(distToFloatOp, rs)
| `Operation(op) => | _ => t |> toLeaf(renderParams) |> E.R.bind(_, operationToLeaf(toLeaf, renderParams, distToFloatOp))
E.R.bind(operationToLeaf(op), toLeaf(operationToLeaf, distToFloatOp))
}; };
}; };
}; };
module Render = { module Render = {
let rec toLeaf = let rec operationToLeaf =
( (
operationToLeaf: operation => result(t, string), toLeaf,
sampleCount: int, renderParams,
t: node, t: node,
) )
: result(t, string) => { : result(t, string) => {
switch (t) { switch (t) {
| `Leaf(`SymbolicDist(d)) => | `SymbolicDist(d) =>
Ok(`Leaf(`RenderedDist(SymbolicDist.T.toShape(sampleCount, d)))) Ok(`RenderedDist(SymbolicDist.T.toShape(renderParams.sampleCount, d)))
| `Leaf(`RenderedDist(_)) as t => Ok(t) // already a rendered shape, we're done here | `RenderedDist(_) as t => Ok(t) // already a rendered shape, we're done here
| `Operation(op) => | _ => t |> toLeaf(renderParams) |> E.R.bind(_, operationToLeaf(toLeaf, renderParams))
E.R.bind(operationToLeaf(op), toLeaf(operationToLeaf, sampleCount))
}; };
}; };
}; };
let rec operationToLeaf =
(sampleCount: int, op: operation): result(t, string) => {
// the functions that convert the Operation nodes to Leaf nodes need to
// have a way to call this function on their children, if their children are themselves Operation nodes.
switch (op) {
| `AlgebraicCombination(algebraicOp, t1, t2) =>
AlgebraicCombination.toLeaf(
operationToLeaf(sampleCount),
algebraicOp,
t1,
t2 // we want to give it the option to render or simply leave it as is
)
| `PointwiseCombination(pointwiseOp, t1, t2) =>
PointwiseCombination.toLeaf(
operationToLeaf(sampleCount),
pointwiseOp,
t1,
t2,
)
| `VerticalScaling(scaleOp, t, scaleBy) =>
VerticalScaling.toLeaf(operationToLeaf(sampleCount), scaleOp, t, scaleBy)
| `Truncate(leftCutoff, rightCutoff, t) =>
Truncate.toLeaf(operationToLeaf(sampleCount), leftCutoff, rightCutoff, t)
| `FloatFromDist(distToFloatOp, t) =>
FloatFromDist.toLeaf(operationToLeaf(sampleCount), distToFloatOp, t)
| `Normalize(t) => Normalize.toLeaf(operationToLeaf(sampleCount), t)
| `Render(t) => Render.toLeaf(operationToLeaf(sampleCount), sampleCount, t)
};
};
/* This function recursively goes through the nodes of the parse tree, /* This function recursively goes through the nodes of the parse tree,
replacing each Operation node and its subtree with a Data node. replacing each Operation node and its subtree with a Data node.
Whenever possible, the replacement produces a new Symbolic Data node, Whenever possible, the replacement produces a new Symbolic Data node,
but most often it will produce a RenderedDist. but most often it will produce a RenderedDist.
This function is used mainly to turn a parse tree into a single RenderedDist This function is used mainly to turn a parse tree into a single RenderedDist
that can then be displayed to the user. */ that can then be displayed to the user. */
let toLeaf = (node: t, sampleCount: int): result(t, string) => { let rec toLeaf = (renderParams, node: t): result(t, string) => {
switch (node) { switch (node) {
| `Leaf(d) => Ok(`Leaf(d)) // Leaf nodes just stay leaf nodes
| `Operation(op) => operationToLeaf(sampleCount, op) | `SymbolicDist(_)
| `RenderedDist(_) => Ok(node)
// Operations need to be turned into leaves
| `AlgebraicCombination(algebraicOp, t1, t2) =>
AlgebraicCombination.operationToLeaf(
toLeaf,
renderParams,
algebraicOp,
t1,
t2
)
| `PointwiseCombination(pointwiseOp, t1, t2) =>
PointwiseCombination.operationToLeaf(
toLeaf,
renderParams,
pointwiseOp,
t1,
t2,
)
| `VerticalScaling(scaleOp, t, scaleBy) =>
VerticalScaling.operationToLeaf(
toLeaf, renderParams, scaleOp, t, scaleBy
)
| `Truncate(leftCutoff, rightCutoff, t) =>
Truncate.operationToLeaf(toLeaf, renderParams, leftCutoff, rightCutoff, t)
| `FloatFromDist(distToFloatOp, t) =>
FloatFromDist.operationToLeaf(toLeaf, renderParams, distToFloatOp, t)
| `Normalize(t) => Normalize.operationToLeaf(toLeaf, renderParams, t)
| `Render(t) => Render.operationToLeaf(toLeaf, renderParams, t)
}; };
}; };

View File

@ -3,22 +3,18 @@ type pointwiseOperation = [ | `Add | `Multiply];
type scaleOperation = [ | `Multiply | `Exponentiate | `Log]; type scaleOperation = [ | `Multiply | `Exponentiate | `Log];
type distToFloatOperation = [ | `Pdf(float) | `Inv(float) | `Mean | `Sample]; type distToFloatOperation = [ | `Pdf(float) | `Inv(float) | `Mean | `Sample];
type abstractOperation('a) = [
| `AlgebraicCombination(algebraicOperation, 'a, 'a)
| `PointwiseCombination(pointwiseOperation, 'a, 'a)
| `VerticalScaling(scaleOperation, 'a, 'a)
| `Render('a)
| `Truncate(option(float), option(float), 'a)
| `Normalize('a)
| `FloatFromDist(distToFloatOperation, 'a)
];
module ExpressionTree = { module ExpressionTree = {
type leaf = [ type node = [
// leaf nodes:
| `SymbolicDist(SymbolicTypes.symbolicDist) | `SymbolicDist(SymbolicTypes.symbolicDist)
| `RenderedDist(DistTypes.shape) | `RenderedDist(DistTypes.shape)
// operations:
| `AlgebraicCombination(algebraicOperation, node, node)
| `PointwiseCombination(pointwiseOperation, node, node)
| `VerticalScaling(scaleOperation, node, node)
| `Render(node)
| `Truncate(option(float), option(float), node)
| `Normalize(node)
| `FloatFromDist(distToFloatOperation, node)
]; ];
type node = [ | `Leaf(leaf) | `Operation(operation)]
and operation = abstractOperation(node);
}; };

View File

@ -86,29 +86,29 @@ module MathAdtToDistDst = {
); );
}; };
let normal: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = let normal:
array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
fun fun
| [|Value(mean), Value(stdev)|] => | [|Value(mean), Value(stdev)|] =>
Ok(`Leaf(`SymbolicDist(`Normal({mean, stdev})))) Ok(`SymbolicDist(`Normal({mean, stdev})))
| _ => Error("Wrong number of variables in normal distribution"); | _ => Error("Wrong number of variables in normal distribution");
let lognormal: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = let lognormal:
array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
fun fun
| [|Value(mu), Value(sigma)|] => | [|Value(mu), Value(sigma)|] =>
Ok(`Leaf(`SymbolicDist(`Lognormal({mu, sigma})))) Ok(`SymbolicDist(`Lognormal({mu, sigma})))
| [|Object(o)|] => { | [|Object(o)|] => {
let g = Js.Dict.get(o); let g = Js.Dict.get(o);
switch (g("mean"), g("stdev"), g("mu"), g("sigma")) { switch (g("mean"), g("stdev"), g("mu"), g("sigma")) {
| (Some(Value(mean)), Some(Value(stdev)), _, _) => | (Some(Value(mean)), Some(Value(stdev)), _, _) =>
Ok( Ok(
`Leaf(
`SymbolicDist( `SymbolicDist(
SymbolicDist.Lognormal.fromMeanAndStdev(mean, stdev), SymbolicDist.Lognormal.fromMeanAndStdev(mean, stdev),
), ),
),
) )
| (_, _, Some(Value(mu)), Some(Value(sigma))) => | (_, _, Some(Value(mu)), Some(Value(sigma))) =>
Ok(`Leaf(`SymbolicDist(`Lognormal({mu, sigma})))) Ok(`SymbolicDist(`Lognormal({mu, sigma})))
| _ => Error("Lognormal distribution would need mean and stdev") | _ => Error("Lognormal distribution would need mean and stdev")
}; };
} }
@ -117,51 +117,48 @@ module MathAdtToDistDst = {
let to_: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = let to_: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
fun fun
| [|Value(low), Value(high)|] when low <= 0.0 && low < high => { | [|Value(low), Value(high)|] when low <= 0.0 && low < high => {
Ok( Ok(`SymbolicDist(SymbolicDist.Normal.from90PercentCI(low, high)));
`Leaf(
`SymbolicDist(SymbolicDist.Normal.from90PercentCI(low, high)),
),
);
} }
| [|Value(low), Value(high)|] when low < high => { | [|Value(low), Value(high)|] when low < high => {
Ok( Ok(
`Leaf(
`SymbolicDist(SymbolicDist.Lognormal.from90PercentCI(low, high)), `SymbolicDist(SymbolicDist.Lognormal.from90PercentCI(low, high)),
),
); );
} }
| [|Value(_), Value(_)|] => | [|Value(_), Value(_)|] =>
Error("Low value must be less than high value.") Error("Low value must be less than high value.")
| _ => Error("Wrong number of variables in lognormal distribution"); | _ => Error("Wrong number of variables in lognormal distribution");
let uniform: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = let uniform:
array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
fun fun
| [|Value(low), Value(high)|] => | [|Value(low), Value(high)|] =>
Ok(`Leaf(`SymbolicDist(`Uniform({low, high})))) Ok(`SymbolicDist(`Uniform({low, high})))
| _ => Error("Wrong number of variables in lognormal distribution"); | _ => Error("Wrong number of variables in lognormal distribution");
let beta: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = let beta: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
fun fun
| [|Value(alpha), Value(beta)|] => | [|Value(alpha), Value(beta)|] =>
Ok(`Leaf(`SymbolicDist(`Beta({alpha, beta})))) Ok(`SymbolicDist(`Beta({alpha, beta})))
| _ => Error("Wrong number of variables in lognormal distribution"); | _ => Error("Wrong number of variables in lognormal distribution");
let exponential: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = let exponential:
array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
fun fun
| [|Value(rate)|] => | [|Value(rate)|] => Ok(`SymbolicDist(`Exponential({rate: rate})))
Ok(`Leaf(`SymbolicDist(`Exponential({rate: rate}))))
| _ => Error("Wrong number of variables in Exponential distribution"); | _ => Error("Wrong number of variables in Exponential distribution");
let cauchy: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = let cauchy:
array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
fun fun
| [|Value(local), Value(scale)|] => | [|Value(local), Value(scale)|] =>
Ok(`Leaf(`SymbolicDist(`Cauchy({local, scale})))) Ok(`SymbolicDist(`Cauchy({local, scale})))
| _ => Error("Wrong number of variables in cauchy distribution"); | _ => Error("Wrong number of variables in cauchy distribution");
let triangular: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = let triangular:
array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
fun fun
| [|Value(low), Value(medium), Value(high)|] => | [|Value(low), Value(medium), Value(high)|] =>
Ok(`Leaf(`SymbolicDist(`Triangular({low, medium, high})))) Ok(`SymbolicDist(`Triangular({low, medium, high})))
| _ => Error("Wrong number of variables in triangle distribution"); | _ => Error("Wrong number of variables in triangle distribution");
let multiModal = let multiModal =
@ -192,30 +189,24 @@ module MathAdtToDistDst = {
|> E.A.fmapi((index, t) => { |> E.A.fmapi((index, t) => {
let w = weights |> E.A.get(_, index) |> E.O.default(1.0); let w = weights |> E.A.get(_, index) |> E.O.default(1.0);
`Operation( `VerticalScaling((`Multiply, t, `SymbolicDist(`Float(w))));
`VerticalScaling((
`Multiply,
t,
`Leaf(`SymbolicDist(`Float(w))),
)),
);
}); });
let pointwiseSum = let pointwiseSum =
components components
|> Js.Array.sliceFrom(1) |> Js.Array.sliceFrom(1)
|> E.A.fold_left( |> E.A.fold_left(
(acc, x) => { (acc, x) => {`PointwiseCombination((`Add, acc, x))},
`Operation(`PointwiseCombination((`Add, acc, x)))
},
E.A.unsafe_get(components, 0), E.A.unsafe_get(components, 0),
); );
Ok(`Operation(`Normalize(pointwiseSum))); Ok(`Normalize(pointwiseSum));
}; };
}; };
let arrayParser = (args: array(arg)): result(ExpressionTypes.ExpressionTree.node, string) => { let arrayParser =
(args: array(arg))
: result(ExpressionTypes.ExpressionTree.node, string) => {
let samples = let samples =
args args
|> E.A.fmap( |> E.A.fmap(
@ -235,15 +226,18 @@ module MathAdtToDistDst = {
SymbolicDist.ContinuousShape.make(_pdf, cdf); SymbolicDist.ContinuousShape.make(_pdf, cdf);
}); });
switch (shape) { switch (shape) {
| Some(s) => Ok(`Leaf(`SymbolicDist(`ContinuousShape(s)))) | Some(s) => Ok(`SymbolicDist(`ContinuousShape(s)))
| None => Error("Rendering did not work") | None => Error("Rendering did not work")
}; };
}; };
let operationParser = let operationParser =
(name: string, args: array(result(ExpressionTypes.ExpressionTree.node, string))) => { (
let toOkAlgebraic = r => Ok(`Operation(`AlgebraicCombination(r))); name: string,
let toOkTrunctate = r => Ok(`Operation(`Truncate(r))); args: array(result(ExpressionTypes.ExpressionTree.node, string)),
) => {
let toOkAlgebraic = r => Ok(`AlgebraicCombination(r));
let toOkTrunctate = r => Ok(`Truncate(r));
switch (name, args) { switch (name, args) {
| ("add", [|Ok(l), Ok(r)|]) => toOkAlgebraic((`Add, l, r)) | ("add", [|Ok(l), Ok(r)|]) => toOkAlgebraic((`Add, l, r))
| ("add", _) => Error("Addition needs two operands") | ("add", _) => Error("Addition needs two operands")
@ -254,11 +248,11 @@ module MathAdtToDistDst = {
| ("divide", [|Ok(l), Ok(r)|]) => toOkAlgebraic((`Divide, l, r)) | ("divide", [|Ok(l), Ok(r)|]) => toOkAlgebraic((`Divide, l, r))
| ("divide", _) => Error("Division needs two operands") | ("divide", _) => Error("Division needs two operands")
| ("pow", _) => Error("Exponentiation is not yet supported.") | ("pow", _) => Error("Exponentiation is not yet supported.")
| ("leftTruncate", [|Ok(d), Ok(`Leaf(`SymbolicDist(`Float(lc))))|]) => | ("leftTruncate", [|Ok(d), Ok(`SymbolicDist(`Float(lc)))|]) =>
toOkTrunctate((Some(lc), None, d)) toOkTrunctate((Some(lc), None, d))
| ("leftTruncate", _) => | ("leftTruncate", _) =>
Error("leftTruncate needs two arguments: the expression and the cutoff") Error("leftTruncate needs two arguments: the expression and the cutoff")
| ("rightTruncate", [|Ok(d), Ok(`Leaf(`SymbolicDist(`Float(rc))))|]) => | ("rightTruncate", [|Ok(d), Ok(`SymbolicDist(`Float(rc)))|]) =>
toOkTrunctate((None, Some(rc), d)) toOkTrunctate((None, Some(rc), d))
| ("rightTruncate", _) => | ("rightTruncate", _) =>
Error( Error(
@ -268,8 +262,8 @@ module MathAdtToDistDst = {
"truncate", "truncate",
[| [|
Ok(d), Ok(d),
Ok(`Leaf(`SymbolicDist(`Float(lc)))), Ok(`SymbolicDist(`Float(lc))),
Ok(`Leaf(`SymbolicDist(`Float(rc)))), Ok(`SymbolicDist(`Float(rc))),
|], |],
) => ) =>
toOkTrunctate((Some(lc), Some(rc), d)) toOkTrunctate((Some(lc), Some(rc), d))
@ -333,7 +327,7 @@ module MathAdtToDistDst = {
let rec nodeParser = let rec nodeParser =
fun fun
| Value(f) => Ok(`Leaf(`SymbolicDist(`Float(f)))) | Value(f) => Ok(`SymbolicDist(`Float(f)))
| Fn({name, args}) => functionParser(nodeParser, name, args) | Fn({name, args}) => functionParser(nodeParser, name, args)
| _ => { | _ => {
Error("This type not currently supported"); Error("This type not currently supported");

View File

@ -89,5 +89,6 @@ module T = {
| `FloatFromDist(floatFromDistOp, t) => | `FloatFromDist(floatFromDistOp, t) =>
DistToFloat.format(floatFromDistOp, nodeToString(t)) DistToFloat.format(floatFromDistOp, nodeToString(t))
| `Truncate(lc, rc, t) => truncateToString(lc, rc, nodeToString(t)) | `Truncate(lc, rc, t) => truncateToString(lc, rc, nodeToString(t))
| `Render(t) => nodeToString(t); | `Render(t) => nodeToString(t)
| _ => ""; // SymbolicDist and RenderedDist are handled in ExpressionTree.toString.
}; };

View File

@ -269,23 +269,23 @@ module T = {
}; };
}; };
/* This returns an optional that wraps a result. If the optional is None, /* Calling e.g. "Normal.operate" returns an optional that wraps a result.
there is no valid analytic solution. If it Some, it If the optional is None, there is no valid analytic solution. If it Some, it
can still return an error if there is a serious problem, can still return an error if there is a serious problem,
like in the casea of a divide by 0. like in the case of a divide by 0.
*/ */
type analyticalSolutionAttempt = [ type analyticalSimplificationResult = [
| `AnalyticalSolution(SymbolicTypes.symbolicDist) | `AnalyticalSolution(SymbolicTypes.symbolicDist)
| `Error(string) | `Error(string)
| `NoSolution | `NoSolution
]; ];
let attemptAnalyticalOperation = let tryAnalyticalSimplification =
( (
d1: symbolicDist, d1: symbolicDist,
d2: symbolicDist, d2: symbolicDist,
op: ExpressionTypes.algebraicOperation, op: ExpressionTypes.algebraicOperation,
) )
: analyticalSolutionAttempt => : analyticalSimplificationResult =>
switch (d1, d2) { switch (d1, d2) {
| (`Float(v1), `Float(v2)) => | (`Float(v1), `Float(v2)) =>
switch (Operation.Algebraic.applyFn(op, v1, v2)) { switch (Operation.Algebraic.applyFn(op, v1, v2)) {