Remove Leaf and Operation wrapper types
This commit is contained in:
parent
a649a6bca2
commit
ca9f725ae7
|
@ -383,9 +383,9 @@ describe("Shape", () => {
|
|||
let numSamples = 10000;
|
||||
open Distributions.Shape;
|
||||
let normal: SymbolicTypes.symbolicDist = `Normal({mean, stdev});
|
||||
let normalShape = ExpressionTree.toShape(numSamples, `Leaf(`SymbolicDist(normal)));
|
||||
let normalShape = ExpressionTree.toShape(numSamples, `SymbolicDist(normal));
|
||||
let lognormal = SymbolicDist.Lognormal.fromMeanAndStdev(mean, stdev);
|
||||
let lognormalShape = ExpressionTree.toShape(numSamples, `Leaf(`SymbolicDist(lognormal)));
|
||||
let lognormalShape = ExpressionTree.toShape(numSamples, `SymbolicDist(lognormal));
|
||||
|
||||
makeTestCloseEquality(
|
||||
"Mean of a normal",
|
||||
|
|
|
@ -389,7 +389,7 @@ module Draw = {
|
|||
let numSamples = 3000;
|
||||
|
||||
let normal: SymbolicTypes.symbolicDist = `Normal({mean, stdev});
|
||||
let normalShape = ExpressionTree.toShape(numSamples, `Leaf(`SymbolicDist(normal)));
|
||||
let normalShape = ExpressionTree.toShape(numSamples, `SymbolicDist(normal));
|
||||
let xyShape: Types.xyShape =
|
||||
switch (normalShape) {
|
||||
| Mixed(_) => {xs: [||], ys: [||]}
|
||||
|
|
|
@ -2,10 +2,11 @@ open ExpressionTypes.ExpressionTree;
|
|||
|
||||
let toShape = (sampleCount: int, node: node) => {
|
||||
let renderResult =
|
||||
ExpressionTreeEvaluator.toLeaf(`Operation(`Render(node)), sampleCount);
|
||||
`Render(`Normalize(node))
|
||||
|> ExpressionTreeEvaluator.toLeaf({sampleCount: sampleCount});
|
||||
|
||||
switch (renderResult) {
|
||||
| Ok(`Leaf(`RenderedDist(rs))) =>
|
||||
| Ok(`RenderedDist(rs)) =>
|
||||
let continuous = Distributions.Shape.T.toContinuous(rs);
|
||||
let discrete = Distributions.Shape.T.toDiscrete(rs);
|
||||
let shape = MixedShapeBuilder.buildSimple(~continuous, ~discrete);
|
||||
|
@ -17,6 +18,6 @@ let toShape = (sampleCount: int, node: node) => {
|
|||
|
||||
let rec toString =
|
||||
fun
|
||||
| `Leaf(`SymbolicDist(d)) => SymbolicDist.T.toString(d)
|
||||
| `Leaf(`RenderedDist(_)) => "[shape]"
|
||||
| `Operation(op) => Operation.T.toString(toString, op);
|
||||
| `SymbolicDist(d) => SymbolicDist.T.toString(d)
|
||||
| `RenderedDist(_) => "[shape]"
|
||||
| op => Operation.T.toString(toString, op);
|
||||
|
|
|
@ -1,84 +1,77 @@
|
|||
/* This module represents a tree node. */
|
||||
open ExpressionTypes;
|
||||
open ExpressionTypes.ExpressionTree;
|
||||
|
||||
type t = node;
|
||||
type tResult = node => result(node, string);
|
||||
|
||||
type renderParams = {
|
||||
sampleCount: int,
|
||||
};
|
||||
|
||||
/* Given two random variables A and B, this returns the distribution
|
||||
of a new variable that is the result of the operation on A and B.
|
||||
For instance, normal(0, 1) + normal(1, 1) -> normal(1, 2).
|
||||
In general, this is implemented via convolution. */
|
||||
module AlgebraicCombination = {
|
||||
let toTreeNode = (op, t1, t2) =>
|
||||
`Operation(`AlgebraicCombination((op, t1, t2)));
|
||||
let tryAnalyticalSolution =
|
||||
fun
|
||||
| `Operation(
|
||||
`AlgebraicCombination(
|
||||
operation,
|
||||
`Leaf(`SymbolicDist(d1)),
|
||||
`Leaf(`SymbolicDist(d2)),
|
||||
),
|
||||
) as t =>
|
||||
switch (SymbolicDist.T.attemptAnalyticalOperation(d1, d2, operation)) {
|
||||
| `AnalyticalSolution(symbolicDist) =>
|
||||
Ok(`Leaf(`SymbolicDist(symbolicDist)))
|
||||
let tryAnalyticalSimplification = (operation, t1: t, t2: t) =>
|
||||
switch (operation, t1, t2) {
|
||||
| (operation,
|
||||
`SymbolicDist(d1),
|
||||
`SymbolicDist(d2),
|
||||
) =>
|
||||
switch (SymbolicDist.T.tryAnalyticalSimplification(d1, d2, operation)) {
|
||||
| `AnalyticalSolution(symbolicDist) => Ok(`SymbolicDist(symbolicDist))
|
||||
| `Error(er) => Error(er)
|
||||
| `NoSolution => Ok(t)
|
||||
| `NoSolution => Ok(`AlgebraicCombination(operation, t1, t2))
|
||||
}
|
||||
| t => Ok(t);
|
||||
| _ => Ok(`AlgebraicCombination(operation, t1, t2))
|
||||
};
|
||||
|
||||
// todo: I don't like the name evaluateNumerically that much, if this renders and does it algebraically. It's tricky.
|
||||
let evaluateNumerically = (algebraicOp, operationToLeaf, t1, t2) => {
|
||||
// force rendering into shapes
|
||||
let renderShape = r => operationToLeaf(`Render(r));
|
||||
let combineAsShapes = (toLeaf, renderParams, algebraicOp, t1, t2) => {
|
||||
let renderShape = r => toLeaf(renderParams, `Render(r));
|
||||
switch (renderShape(t1), renderShape(t2)) {
|
||||
| (Ok(`Leaf(`RenderedDist(s1))), Ok(`Leaf(`RenderedDist(s2)))) =>
|
||||
| (Ok(`RenderedDist(s1)), Ok(`RenderedDist(s2))) =>
|
||||
Ok(
|
||||
`Leaf(
|
||||
`RenderedDist(
|
||||
Distributions.Shape.combineAlgebraically(algebraicOp, s1, s2),
|
||||
),
|
||||
),
|
||||
)
|
||||
| (Error(e1), _) => Error(e1)
|
||||
| (_, Error(e2)) => Error(e2)
|
||||
| _ => Error("Could not render shapes.")
|
||||
| _ => Error("Algebraic combination: rendering failed.")
|
||||
};
|
||||
};
|
||||
|
||||
let toLeaf =
|
||||
let operationToLeaf =
|
||||
(
|
||||
operationToLeaf,
|
||||
toLeaf,
|
||||
renderParams: renderParams,
|
||||
algebraicOp: ExpressionTypes.algebraicOperation,
|
||||
t1: t,
|
||||
t2: t,
|
||||
)
|
||||
: result(node, string) =>
|
||||
toTreeNode(algebraicOp, t1, t2)
|
||||
|> tryAnalyticalSolution
|
||||
|
||||
algebraicOp
|
||||
|> tryAnalyticalSimplification(_, t1, t2)
|
||||
|> E.R.bind(
|
||||
_,
|
||||
fun
|
||||
| `Leaf(d) => Ok(`Leaf(d)) // the analytical simplifaction worked, nice!
|
||||
| `Operation(_) =>
|
||||
// if not, run the convolution
|
||||
evaluateNumerically(algebraicOp, operationToLeaf, t1, t2),
|
||||
| `SymbolicDist(d) as t => Ok(t)
|
||||
| _ => combineAsShapes(toLeaf, renderParams, algebraicOp, t1, t2)
|
||||
);
|
||||
};
|
||||
|
||||
module VerticalScaling = {
|
||||
let toLeaf = (operationToLeaf, scaleOp, t, scaleBy) => {
|
||||
let operationToLeaf = (toLeaf, renderParams, scaleOp, t, scaleBy) => {
|
||||
// scaleBy has to be a single float, otherwise we'll return an error.
|
||||
let fn = Operation.Scale.toFn(scaleOp);
|
||||
let knownIntegralSumFn = Operation.Scale.toKnownIntegralSumFn(scaleOp);
|
||||
let renderedShape = operationToLeaf(`Render(t));
|
||||
let renderedShape = toLeaf(renderParams, `Render(t));
|
||||
|
||||
switch (renderedShape, scaleBy) {
|
||||
| (Ok(`Leaf(`RenderedDist(rs))), `Leaf(`SymbolicDist(`Float(sm)))) =>
|
||||
| (Ok(`RenderedDist(rs)), `SymbolicDist(`Float(sm))) =>
|
||||
Ok(
|
||||
`Leaf(
|
||||
`RenderedDist(
|
||||
Distributions.Shape.T.mapY(
|
||||
~knownIntegralSumFn=knownIntegralSumFn(sm),
|
||||
|
@ -86,7 +79,6 @@ module VerticalScaling = {
|
|||
rs,
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
| (Error(e1), _) => Error(e1)
|
||||
| (_, _) => Error("Can only scale by float values.")
|
||||
|
@ -95,14 +87,11 @@ module VerticalScaling = {
|
|||
};
|
||||
|
||||
module PointwiseCombination = {
|
||||
let pointwiseAdd = (operationToLeaf, t1, t2) => {
|
||||
let renderedShape1 = operationToLeaf(`Render(t1));
|
||||
let renderedShape2 = operationToLeaf(`Render(t2));
|
||||
|
||||
switch (renderedShape1, renderedShape2) {
|
||||
| (Ok(`Leaf(`RenderedDist(rs1))), Ok(`Leaf(`RenderedDist(rs2)))) =>
|
||||
let pointwiseAdd = (toLeaf, renderParams, t1, t2) => {
|
||||
let renderShape = r => toLeaf(renderParams, `Render(r));
|
||||
switch (renderShape(t1), renderShape(t2)) {
|
||||
| (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) =>
|
||||
Ok(
|
||||
`Leaf(
|
||||
`RenderedDist(
|
||||
Distributions.Shape.combinePointwise(
|
||||
~knownIntegralSumsFn=(a, b) => Some(a +. b),
|
||||
|
@ -111,15 +100,14 @@ module PointwiseCombination = {
|
|||
rs2,
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
| (Error(e1), _) => Error(e1)
|
||||
| (_, Error(e2)) => Error(e2)
|
||||
| _ => Error("Could not perform pointwise addition.")
|
||||
| _ => Error("Pointwise combination: rendering failed.")
|
||||
};
|
||||
};
|
||||
|
||||
let pointwiseMultiply = (operationToLeaf, t1, t2) => {
|
||||
let pointwiseMultiply = (toLeaf, renderParams, t1, t2) => {
|
||||
// TODO: construct a function that we can easily sample from, to construct
|
||||
// a RenderedDist. Use the xMin and xMax of the rendered shapes to tell the sampling function where to look.
|
||||
Error(
|
||||
|
@ -127,84 +115,72 @@ module PointwiseCombination = {
|
|||
);
|
||||
};
|
||||
|
||||
let toLeaf = (operationToLeaf, pointwiseOp, t1, t2) => {
|
||||
let operationToLeaf = (toLeaf, renderParams, pointwiseOp, t1, t2) => {
|
||||
switch (pointwiseOp) {
|
||||
| `Add => pointwiseAdd(operationToLeaf, t1, t2)
|
||||
| `Multiply => pointwiseMultiply(operationToLeaf, t1, t2)
|
||||
| `Add => pointwiseAdd(toLeaf, renderParams, t1, t2)
|
||||
| `Multiply => pointwiseMultiply(toLeaf, renderParams, t1, t2)
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
module Truncate = {
|
||||
module Simplify = {
|
||||
let tryTruncatingNothing: tResult =
|
||||
fun
|
||||
| `Operation(`Truncate(None, None, `Leaf(d))) => Ok(`Leaf(d))
|
||||
| t => Ok(t);
|
||||
|
||||
let tryTruncatingUniform: tResult =
|
||||
fun
|
||||
| `Operation(`Truncate(lc, rc, `Leaf(`SymbolicDist(`Uniform(u))))) => {
|
||||
let trySimplification = (leftCutoff, rightCutoff, t) => {
|
||||
switch (leftCutoff, rightCutoff, t) {
|
||||
| (None, None, t) => Ok(t)
|
||||
| (lc, rc, `SymbolicDist(`Uniform(u))) => {
|
||||
// just create a new Uniform distribution
|
||||
let newLow = max(E.O.default(neg_infinity, lc), u.low);
|
||||
let newHigh = min(E.O.default(infinity, rc), u.high);
|
||||
Ok(`Leaf(`SymbolicDist(`Uniform({low: newLow, high: newHigh}))));
|
||||
let nu: SymbolicTypes.uniform = u;
|
||||
let newLow = max(E.O.default(neg_infinity, lc), nu.low);
|
||||
let newHigh = min(E.O.default(infinity, rc), nu.high);
|
||||
Ok(`SymbolicDist(`Uniform({low: newLow, high: newHigh})));
|
||||
}
|
||||
| t => Ok(t);
|
||||
|
||||
let attempt = (leftCutoff, rightCutoff, t): result(node, string) => {
|
||||
let originalTreeNode =
|
||||
`Operation(`Truncate((leftCutoff, rightCutoff, t)));
|
||||
|
||||
originalTreeNode
|
||||
|> tryTruncatingNothing
|
||||
|> E.R.bind(_, tryTruncatingUniform);
|
||||
| (_, _, t) => Ok(t)
|
||||
};
|
||||
};
|
||||
|
||||
let evaluateNumerically = (leftCutoff, rightCutoff, operationToLeaf, t) => {
|
||||
let truncateAsShape = (toLeaf, renderParams, leftCutoff, rightCutoff, t) => {
|
||||
// TODO: use named args in renderToShape; if we're lucky we can at least get the tail
|
||||
// of a distribution we otherwise wouldn't get at all
|
||||
let renderedShape = operationToLeaf(`Render(t));
|
||||
let renderedShape = toLeaf(renderParams, `Render(t));
|
||||
|
||||
switch (renderedShape) {
|
||||
| Ok(`Leaf(`RenderedDist(rs))) =>
|
||||
| Ok(`RenderedDist(rs)) => {
|
||||
let truncatedShape =
|
||||
rs |> Distributions.Shape.T.truncate(leftCutoff, rightCutoff);
|
||||
Ok(`Leaf(`RenderedDist(rs)));
|
||||
Ok(`RenderedDist(rs));
|
||||
}
|
||||
| Error(e1) => Error(e1)
|
||||
| _ => Error("Could not truncate distribution.")
|
||||
};
|
||||
};
|
||||
|
||||
let toLeaf =
|
||||
let operationToLeaf =
|
||||
(
|
||||
operationToLeaf,
|
||||
toLeaf,
|
||||
renderParams,
|
||||
leftCutoff: option(float),
|
||||
rightCutoff: option(float),
|
||||
t: node,
|
||||
)
|
||||
: result(node, string) => {
|
||||
t
|
||||
|> Simplify.attempt(leftCutoff, rightCutoff)
|
||||
|> trySimplification(leftCutoff, rightCutoff)
|
||||
|> E.R.bind(
|
||||
_,
|
||||
fun
|
||||
| `Leaf(d) => Ok(`Leaf(d)) // the analytical simplifaction worked, nice!
|
||||
| `Operation(_) =>
|
||||
evaluateNumerically(leftCutoff, rightCutoff, operationToLeaf, t),
|
||||
); // if not, run the convolution
|
||||
| `SymbolicDist(d) as t => Ok(t)
|
||||
| _ => truncateAsShape(toLeaf, renderParams, leftCutoff, rightCutoff, t),
|
||||
);
|
||||
};
|
||||
};
|
||||
|
||||
module Normalize = {
|
||||
let rec toLeaf = (operationToLeaf, t: node): result(node, string) => {
|
||||
let rec operationToLeaf = (toLeaf, renderParams, t: node): result(node, string) => {
|
||||
switch (t) {
|
||||
| `Leaf(`RenderedDist(s)) =>
|
||||
Ok(`Leaf(`RenderedDist(Distributions.Shape.T.normalize(s))))
|
||||
| `Leaf(`SymbolicDist(_)) => Ok(t)
|
||||
| `Operation(op) =>
|
||||
operationToLeaf(op) |> E.R.bind(_, toLeaf(operationToLeaf))
|
||||
| `RenderedDist(s) =>
|
||||
Ok(`RenderedDist(Distributions.Shape.T.normalize(s)))
|
||||
| `SymbolicDist(_) => Ok(t)
|
||||
| _ => t |> toLeaf(renderParams) |> E.R.bind(_, operationToLeaf(toLeaf, renderParams))
|
||||
};
|
||||
};
|
||||
};
|
||||
|
@ -212,83 +188,79 @@ module Normalize = {
|
|||
module FloatFromDist = {
|
||||
let symbolicToLeaf = (distToFloatOp: distToFloatOperation, s) => {
|
||||
SymbolicDist.T.operate(distToFloatOp, s)
|
||||
|> E.R.bind(_, v => Ok(`Leaf(`SymbolicDist(`Float(v)))));
|
||||
|> E.R.bind(_, v => Ok(`SymbolicDist(`Float(v))));
|
||||
};
|
||||
let renderedToLeaf =
|
||||
(distToFloatOp: distToFloatOperation, rs: DistTypes.shape)
|
||||
: result(node, string) => {
|
||||
Distributions.Shape.operate(distToFloatOp, rs)
|
||||
|> (v => Ok(`Leaf(`SymbolicDist(`Float(v)))));
|
||||
|> (v => Ok(`SymbolicDist(`Float(v))));
|
||||
};
|
||||
let rec toLeaf =
|
||||
(operationToLeaf, distToFloatOp: distToFloatOperation, t: node)
|
||||
let rec operationToLeaf =
|
||||
(toLeaf, renderParams, distToFloatOp: distToFloatOperation, t: node)
|
||||
: result(node, string) => {
|
||||
switch (t) {
|
||||
| `Leaf(`SymbolicDist(s)) => symbolicToLeaf(distToFloatOp, s) // we want to evaluate the distToFloatOp on the symbolic dist
|
||||
| `Leaf(`RenderedDist(rs)) => renderedToLeaf(distToFloatOp, rs)
|
||||
| `Operation(op) =>
|
||||
E.R.bind(operationToLeaf(op), toLeaf(operationToLeaf, distToFloatOp))
|
||||
| `SymbolicDist(s) => symbolicToLeaf(distToFloatOp, s)
|
||||
| `RenderedDist(rs) => renderedToLeaf(distToFloatOp, rs)
|
||||
| _ => t |> toLeaf(renderParams) |> E.R.bind(_, operationToLeaf(toLeaf, renderParams, distToFloatOp))
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
module Render = {
|
||||
let rec toLeaf =
|
||||
let rec operationToLeaf =
|
||||
(
|
||||
operationToLeaf: operation => result(t, string),
|
||||
sampleCount: int,
|
||||
toLeaf,
|
||||
renderParams,
|
||||
t: node,
|
||||
)
|
||||
: result(t, string) => {
|
||||
switch (t) {
|
||||
| `Leaf(`SymbolicDist(d)) =>
|
||||
Ok(`Leaf(`RenderedDist(SymbolicDist.T.toShape(sampleCount, d))))
|
||||
| `Leaf(`RenderedDist(_)) as t => Ok(t) // already a rendered shape, we're done here
|
||||
| `Operation(op) =>
|
||||
E.R.bind(operationToLeaf(op), toLeaf(operationToLeaf, sampleCount))
|
||||
| `SymbolicDist(d) =>
|
||||
Ok(`RenderedDist(SymbolicDist.T.toShape(renderParams.sampleCount, d)))
|
||||
| `RenderedDist(_) as t => Ok(t) // already a rendered shape, we're done here
|
||||
| _ => t |> toLeaf(renderParams) |> E.R.bind(_, operationToLeaf(toLeaf, renderParams))
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
let rec operationToLeaf =
|
||||
(sampleCount: int, op: operation): result(t, string) => {
|
||||
// the functions that convert the Operation nodes to Leaf nodes need to
|
||||
// have a way to call this function on their children, if their children are themselves Operation nodes.
|
||||
switch (op) {
|
||||
| `AlgebraicCombination(algebraicOp, t1, t2) =>
|
||||
AlgebraicCombination.toLeaf(
|
||||
operationToLeaf(sampleCount),
|
||||
algebraicOp,
|
||||
t1,
|
||||
t2 // we want to give it the option to render or simply leave it as is
|
||||
)
|
||||
| `PointwiseCombination(pointwiseOp, t1, t2) =>
|
||||
PointwiseCombination.toLeaf(
|
||||
operationToLeaf(sampleCount),
|
||||
pointwiseOp,
|
||||
t1,
|
||||
t2,
|
||||
)
|
||||
| `VerticalScaling(scaleOp, t, scaleBy) =>
|
||||
VerticalScaling.toLeaf(operationToLeaf(sampleCount), scaleOp, t, scaleBy)
|
||||
| `Truncate(leftCutoff, rightCutoff, t) =>
|
||||
Truncate.toLeaf(operationToLeaf(sampleCount), leftCutoff, rightCutoff, t)
|
||||
| `FloatFromDist(distToFloatOp, t) =>
|
||||
FloatFromDist.toLeaf(operationToLeaf(sampleCount), distToFloatOp, t)
|
||||
| `Normalize(t) => Normalize.toLeaf(operationToLeaf(sampleCount), t)
|
||||
| `Render(t) => Render.toLeaf(operationToLeaf(sampleCount), sampleCount, t)
|
||||
};
|
||||
};
|
||||
|
||||
/* This function recursively goes through the nodes of the parse tree,
|
||||
replacing each Operation node and its subtree with a Data node.
|
||||
Whenever possible, the replacement produces a new Symbolic Data node,
|
||||
but most often it will produce a RenderedDist.
|
||||
This function is used mainly to turn a parse tree into a single RenderedDist
|
||||
that can then be displayed to the user. */
|
||||
let toLeaf = (node: t, sampleCount: int): result(t, string) => {
|
||||
let rec toLeaf = (renderParams, node: t): result(t, string) => {
|
||||
switch (node) {
|
||||
| `Leaf(d) => Ok(`Leaf(d))
|
||||
| `Operation(op) => operationToLeaf(sampleCount, op)
|
||||
// Leaf nodes just stay leaf nodes
|
||||
| `SymbolicDist(_)
|
||||
| `RenderedDist(_) => Ok(node)
|
||||
// Operations need to be turned into leaves
|
||||
| `AlgebraicCombination(algebraicOp, t1, t2) =>
|
||||
AlgebraicCombination.operationToLeaf(
|
||||
toLeaf,
|
||||
renderParams,
|
||||
algebraicOp,
|
||||
t1,
|
||||
t2
|
||||
)
|
||||
| `PointwiseCombination(pointwiseOp, t1, t2) =>
|
||||
PointwiseCombination.operationToLeaf(
|
||||
toLeaf,
|
||||
renderParams,
|
||||
pointwiseOp,
|
||||
t1,
|
||||
t2,
|
||||
)
|
||||
| `VerticalScaling(scaleOp, t, scaleBy) =>
|
||||
VerticalScaling.operationToLeaf(
|
||||
toLeaf, renderParams, scaleOp, t, scaleBy
|
||||
)
|
||||
| `Truncate(leftCutoff, rightCutoff, t) =>
|
||||
Truncate.operationToLeaf(toLeaf, renderParams, leftCutoff, rightCutoff, t)
|
||||
| `FloatFromDist(distToFloatOp, t) =>
|
||||
FloatFromDist.operationToLeaf(toLeaf, renderParams, distToFloatOp, t)
|
||||
| `Normalize(t) => Normalize.operationToLeaf(toLeaf, renderParams, t)
|
||||
| `Render(t) => Render.operationToLeaf(toLeaf, renderParams, t)
|
||||
};
|
||||
};
|
||||
|
|
|
@ -3,22 +3,18 @@ type pointwiseOperation = [ | `Add | `Multiply];
|
|||
type scaleOperation = [ | `Multiply | `Exponentiate | `Log];
|
||||
type distToFloatOperation = [ | `Pdf(float) | `Inv(float) | `Mean | `Sample];
|
||||
|
||||
type abstractOperation('a) = [
|
||||
| `AlgebraicCombination(algebraicOperation, 'a, 'a)
|
||||
| `PointwiseCombination(pointwiseOperation, 'a, 'a)
|
||||
| `VerticalScaling(scaleOperation, 'a, 'a)
|
||||
| `Render('a)
|
||||
| `Truncate(option(float), option(float), 'a)
|
||||
| `Normalize('a)
|
||||
| `FloatFromDist(distToFloatOperation, 'a)
|
||||
];
|
||||
|
||||
module ExpressionTree = {
|
||||
type leaf = [
|
||||
type node = [
|
||||
// leaf nodes:
|
||||
| `SymbolicDist(SymbolicTypes.symbolicDist)
|
||||
| `RenderedDist(DistTypes.shape)
|
||||
// operations:
|
||||
| `AlgebraicCombination(algebraicOperation, node, node)
|
||||
| `PointwiseCombination(pointwiseOperation, node, node)
|
||||
| `VerticalScaling(scaleOperation, node, node)
|
||||
| `Render(node)
|
||||
| `Truncate(option(float), option(float), node)
|
||||
| `Normalize(node)
|
||||
| `FloatFromDist(distToFloatOperation, node)
|
||||
];
|
||||
|
||||
type node = [ | `Leaf(leaf) | `Operation(operation)]
|
||||
and operation = abstractOperation(node);
|
||||
};
|
||||
|
|
|
@ -86,29 +86,29 @@ module MathAdtToDistDst = {
|
|||
);
|
||||
};
|
||||
|
||||
let normal: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
|
||||
let normal:
|
||||
array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
|
||||
fun
|
||||
| [|Value(mean), Value(stdev)|] =>
|
||||
Ok(`Leaf(`SymbolicDist(`Normal({mean, stdev}))))
|
||||
Ok(`SymbolicDist(`Normal({mean, stdev})))
|
||||
| _ => Error("Wrong number of variables in normal distribution");
|
||||
|
||||
let lognormal: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
|
||||
let lognormal:
|
||||
array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
|
||||
fun
|
||||
| [|Value(mu), Value(sigma)|] =>
|
||||
Ok(`Leaf(`SymbolicDist(`Lognormal({mu, sigma}))))
|
||||
Ok(`SymbolicDist(`Lognormal({mu, sigma})))
|
||||
| [|Object(o)|] => {
|
||||
let g = Js.Dict.get(o);
|
||||
switch (g("mean"), g("stdev"), g("mu"), g("sigma")) {
|
||||
| (Some(Value(mean)), Some(Value(stdev)), _, _) =>
|
||||
Ok(
|
||||
`Leaf(
|
||||
`SymbolicDist(
|
||||
SymbolicDist.Lognormal.fromMeanAndStdev(mean, stdev),
|
||||
),
|
||||
),
|
||||
)
|
||||
| (_, _, Some(Value(mu)), Some(Value(sigma))) =>
|
||||
Ok(`Leaf(`SymbolicDist(`Lognormal({mu, sigma}))))
|
||||
Ok(`SymbolicDist(`Lognormal({mu, sigma})))
|
||||
| _ => Error("Lognormal distribution would need mean and stdev")
|
||||
};
|
||||
}
|
||||
|
@ -117,51 +117,48 @@ module MathAdtToDistDst = {
|
|||
let to_: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
|
||||
fun
|
||||
| [|Value(low), Value(high)|] when low <= 0.0 && low < high => {
|
||||
Ok(
|
||||
`Leaf(
|
||||
`SymbolicDist(SymbolicDist.Normal.from90PercentCI(low, high)),
|
||||
),
|
||||
);
|
||||
Ok(`SymbolicDist(SymbolicDist.Normal.from90PercentCI(low, high)));
|
||||
}
|
||||
| [|Value(low), Value(high)|] when low < high => {
|
||||
Ok(
|
||||
`Leaf(
|
||||
`SymbolicDist(SymbolicDist.Lognormal.from90PercentCI(low, high)),
|
||||
),
|
||||
);
|
||||
}
|
||||
| [|Value(_), Value(_)|] =>
|
||||
Error("Low value must be less than high value.")
|
||||
| _ => Error("Wrong number of variables in lognormal distribution");
|
||||
|
||||
let uniform: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
|
||||
let uniform:
|
||||
array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
|
||||
fun
|
||||
| [|Value(low), Value(high)|] =>
|
||||
Ok(`Leaf(`SymbolicDist(`Uniform({low, high}))))
|
||||
Ok(`SymbolicDist(`Uniform({low, high})))
|
||||
| _ => Error("Wrong number of variables in lognormal distribution");
|
||||
|
||||
let beta: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
|
||||
fun
|
||||
| [|Value(alpha), Value(beta)|] =>
|
||||
Ok(`Leaf(`SymbolicDist(`Beta({alpha, beta}))))
|
||||
Ok(`SymbolicDist(`Beta({alpha, beta})))
|
||||
| _ => Error("Wrong number of variables in lognormal distribution");
|
||||
|
||||
let exponential: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
|
||||
let exponential:
|
||||
array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
|
||||
fun
|
||||
| [|Value(rate)|] =>
|
||||
Ok(`Leaf(`SymbolicDist(`Exponential({rate: rate}))))
|
||||
| [|Value(rate)|] => Ok(`SymbolicDist(`Exponential({rate: rate})))
|
||||
| _ => Error("Wrong number of variables in Exponential distribution");
|
||||
|
||||
let cauchy: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
|
||||
let cauchy:
|
||||
array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
|
||||
fun
|
||||
| [|Value(local), Value(scale)|] =>
|
||||
Ok(`Leaf(`SymbolicDist(`Cauchy({local, scale}))))
|
||||
Ok(`SymbolicDist(`Cauchy({local, scale})))
|
||||
| _ => Error("Wrong number of variables in cauchy distribution");
|
||||
|
||||
let triangular: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
|
||||
let triangular:
|
||||
array(arg) => result(ExpressionTypes.ExpressionTree.node, string) =
|
||||
fun
|
||||
| [|Value(low), Value(medium), Value(high)|] =>
|
||||
Ok(`Leaf(`SymbolicDist(`Triangular({low, medium, high}))))
|
||||
Ok(`SymbolicDist(`Triangular({low, medium, high})))
|
||||
| _ => Error("Wrong number of variables in triangle distribution");
|
||||
|
||||
let multiModal =
|
||||
|
@ -192,30 +189,24 @@ module MathAdtToDistDst = {
|
|||
|> E.A.fmapi((index, t) => {
|
||||
let w = weights |> E.A.get(_, index) |> E.O.default(1.0);
|
||||
|
||||
`Operation(
|
||||
`VerticalScaling((
|
||||
`Multiply,
|
||||
t,
|
||||
`Leaf(`SymbolicDist(`Float(w))),
|
||||
)),
|
||||
);
|
||||
`VerticalScaling((`Multiply, t, `SymbolicDist(`Float(w))));
|
||||
});
|
||||
|
||||
let pointwiseSum =
|
||||
components
|
||||
|> Js.Array.sliceFrom(1)
|
||||
|> E.A.fold_left(
|
||||
(acc, x) => {
|
||||
`Operation(`PointwiseCombination((`Add, acc, x)))
|
||||
},
|
||||
(acc, x) => {`PointwiseCombination((`Add, acc, x))},
|
||||
E.A.unsafe_get(components, 0),
|
||||
);
|
||||
|
||||
Ok(`Operation(`Normalize(pointwiseSum)));
|
||||
Ok(`Normalize(pointwiseSum));
|
||||
};
|
||||
};
|
||||
|
||||
let arrayParser = (args: array(arg)): result(ExpressionTypes.ExpressionTree.node, string) => {
|
||||
let arrayParser =
|
||||
(args: array(arg))
|
||||
: result(ExpressionTypes.ExpressionTree.node, string) => {
|
||||
let samples =
|
||||
args
|
||||
|> E.A.fmap(
|
||||
|
@ -235,15 +226,18 @@ module MathAdtToDistDst = {
|
|||
SymbolicDist.ContinuousShape.make(_pdf, cdf);
|
||||
});
|
||||
switch (shape) {
|
||||
| Some(s) => Ok(`Leaf(`SymbolicDist(`ContinuousShape(s))))
|
||||
| Some(s) => Ok(`SymbolicDist(`ContinuousShape(s)))
|
||||
| None => Error("Rendering did not work")
|
||||
};
|
||||
};
|
||||
|
||||
let operationParser =
|
||||
(name: string, args: array(result(ExpressionTypes.ExpressionTree.node, string))) => {
|
||||
let toOkAlgebraic = r => Ok(`Operation(`AlgebraicCombination(r)));
|
||||
let toOkTrunctate = r => Ok(`Operation(`Truncate(r)));
|
||||
(
|
||||
name: string,
|
||||
args: array(result(ExpressionTypes.ExpressionTree.node, string)),
|
||||
) => {
|
||||
let toOkAlgebraic = r => Ok(`AlgebraicCombination(r));
|
||||
let toOkTrunctate = r => Ok(`Truncate(r));
|
||||
switch (name, args) {
|
||||
| ("add", [|Ok(l), Ok(r)|]) => toOkAlgebraic((`Add, l, r))
|
||||
| ("add", _) => Error("Addition needs two operands")
|
||||
|
@ -254,11 +248,11 @@ module MathAdtToDistDst = {
|
|||
| ("divide", [|Ok(l), Ok(r)|]) => toOkAlgebraic((`Divide, l, r))
|
||||
| ("divide", _) => Error("Division needs two operands")
|
||||
| ("pow", _) => Error("Exponentiation is not yet supported.")
|
||||
| ("leftTruncate", [|Ok(d), Ok(`Leaf(`SymbolicDist(`Float(lc))))|]) =>
|
||||
| ("leftTruncate", [|Ok(d), Ok(`SymbolicDist(`Float(lc)))|]) =>
|
||||
toOkTrunctate((Some(lc), None, d))
|
||||
| ("leftTruncate", _) =>
|
||||
Error("leftTruncate needs two arguments: the expression and the cutoff")
|
||||
| ("rightTruncate", [|Ok(d), Ok(`Leaf(`SymbolicDist(`Float(rc))))|]) =>
|
||||
| ("rightTruncate", [|Ok(d), Ok(`SymbolicDist(`Float(rc)))|]) =>
|
||||
toOkTrunctate((None, Some(rc), d))
|
||||
| ("rightTruncate", _) =>
|
||||
Error(
|
||||
|
@ -268,8 +262,8 @@ module MathAdtToDistDst = {
|
|||
"truncate",
|
||||
[|
|
||||
Ok(d),
|
||||
Ok(`Leaf(`SymbolicDist(`Float(lc)))),
|
||||
Ok(`Leaf(`SymbolicDist(`Float(rc)))),
|
||||
Ok(`SymbolicDist(`Float(lc))),
|
||||
Ok(`SymbolicDist(`Float(rc))),
|
||||
|],
|
||||
) =>
|
||||
toOkTrunctate((Some(lc), Some(rc), d))
|
||||
|
@ -333,7 +327,7 @@ module MathAdtToDistDst = {
|
|||
|
||||
let rec nodeParser =
|
||||
fun
|
||||
| Value(f) => Ok(`Leaf(`SymbolicDist(`Float(f))))
|
||||
| Value(f) => Ok(`SymbolicDist(`Float(f)))
|
||||
| Fn({name, args}) => functionParser(nodeParser, name, args)
|
||||
| _ => {
|
||||
Error("This type not currently supported");
|
||||
|
|
|
@ -89,5 +89,6 @@ module T = {
|
|||
| `FloatFromDist(floatFromDistOp, t) =>
|
||||
DistToFloat.format(floatFromDistOp, nodeToString(t))
|
||||
| `Truncate(lc, rc, t) => truncateToString(lc, rc, nodeToString(t))
|
||||
| `Render(t) => nodeToString(t);
|
||||
| `Render(t) => nodeToString(t)
|
||||
| _ => ""; // SymbolicDist and RenderedDist are handled in ExpressionTree.toString.
|
||||
};
|
||||
|
|
|
@ -269,23 +269,23 @@ module T = {
|
|||
};
|
||||
};
|
||||
|
||||
/* This returns an optional that wraps a result. If the optional is None,
|
||||
there is no valid analytic solution. If it Some, it
|
||||
/* Calling e.g. "Normal.operate" returns an optional that wraps a result.
|
||||
If the optional is None, there is no valid analytic solution. If it Some, it
|
||||
can still return an error if there is a serious problem,
|
||||
like in the casea of a divide by 0.
|
||||
like in the case of a divide by 0.
|
||||
*/
|
||||
type analyticalSolutionAttempt = [
|
||||
type analyticalSimplificationResult = [
|
||||
| `AnalyticalSolution(SymbolicTypes.symbolicDist)
|
||||
| `Error(string)
|
||||
| `NoSolution
|
||||
];
|
||||
let attemptAnalyticalOperation =
|
||||
let tryAnalyticalSimplification =
|
||||
(
|
||||
d1: symbolicDist,
|
||||
d2: symbolicDist,
|
||||
op: ExpressionTypes.algebraicOperation,
|
||||
)
|
||||
: analyticalSolutionAttempt =>
|
||||
: analyticalSimplificationResult =>
|
||||
switch (d1, d2) {
|
||||
| (`Float(v1), `Float(v2)) =>
|
||||
switch (Operation.Algebraic.applyFn(op, v1, v2)) {
|
||||
|
|
Loading…
Reference in New Issue
Block a user