295 lines
9.5 KiB
ReasonML
295 lines
9.5 KiB
ReasonML
|
/* This module represents a tree node. */
|
||
|
open ExpressionTypes;
|
||
|
open ExpressionTypes.ExpressionTree;
|
||
|
|
||
|
/* Primary type of this module: an expression-tree node.
   `node` is not declared in this file; presumably it comes from the opened
   ExpressionTypes modules above (confirm). */
type t = node;
|
||
|
/* Signature of a tree-rewriting step: takes a node and returns either a
   (possibly rewritten) node or an error message. */
type tResult = node => result(node, string);
|
||
|
|
||
|
/* Given two random variables A and B, this returns the distribution
|
||
|
of a new variable that is the result of the operation on A and B.
|
||
|
For instance, normal(0, 1) + normal(1, 1) -> normal(1, 2).
|
||
|
In general, this is implemented via convolution. */
|
||
|
module AlgebraicCombination = {
  /* Builds the (unevaluated) `Operation node for `op` applied to two subtrees. */
  let toTreeNode = (op, t1, t2) => {
    let combination = `AlgebraicCombination((op, t1, t2));
    `Operation(combination);
  };

  /* Attempts a closed-form result. Only fires when both operands are already
     symbolic leaves; any other node is handed back untouched. A symbolic
     attempt that reports `NoSolution also hands the node back untouched,
     while `Error aborts with the reported message. */
  let tryAnalyticalSolution = node =>
    switch (node) {
    | `Operation(
        `AlgebraicCombination(
          operation,
          `Leaf(`SymbolicDist(left)),
          `Leaf(`SymbolicDist(right)),
        ),
      ) =>
      switch (
        SymbolicDist.T.attemptAnalyticalOperation(left, right, operation)
      ) {
      | `AnalyticalSolution(solved) => Ok(`Leaf(`SymbolicDist(solved)))
      | `Error(message) => Error(message)
      | `NoSolution => Ok(node)
      }
    | _ => Ok(node)
    };

  // todo: I don't like the name evaluateNumerically that much, if this renders and does it algebraically. It's tricky.
  /* Renders both operands through `operationToLeaf` and combines the two
     resulting shapes with `algebraicOp`. The error of the first operand wins
     over the error of the second one. */
  let evaluateNumerically = (algebraicOp, operationToLeaf, t1, t2) => {
    // wrapping a subtree in `Render forces it into a shape
    let render = subtree => operationToLeaf(`Render(subtree));

    switch (render(t1), render(t2)) {
    | (Error(firstError), _) => Error(firstError)
    | (_, Error(secondError)) => Error(secondError)
    | (Ok(`Leaf(`RenderedDist(shape1))), Ok(`Leaf(`RenderedDist(shape2)))) =>
      let combined =
        Distributions.Shape.combineAlgebraically(algebraicOp, shape1, shape2);
      Ok(`Leaf(`RenderedDist(combined)));
    | _ => Error("Could not render shapes.")
    };
  };

  /* Reduces `algebraicOp(t1, t2)` to a leaf: the analytical route is tried
     first, and the numerical route is used only if the analytical attempt
     left an `Operation node behind. */
  let toLeaf =
      (
        operationToLeaf,
        algebraicOp: ExpressionTypes.algebraicOperation,
        t1: t,
        t2: t,
      )
      : result(node, string) => {
    let simplified = tryAnalyticalSolution(toTreeNode(algebraicOp, t1, t2));

    E.R.bind(simplified, outcome =>
      switch (outcome) {
      // the analytical simplification produced a leaf, nothing left to do
      | `Leaf(dist) => Ok(`Leaf(dist))
      // still an operation: fall back to rendering and combining the shapes
      | `Operation(_) =>
        evaluateNumerically(algebraicOp, operationToLeaf, t1, t2)
      }
    );
  };
};
|
||
|
|
||
|
module VerticalScaling = {
  /* Renders `t` and maps the y-values of the resulting shape with the
     function selected by `scaleOp`, applied to the scaling factor.
     `scaleBy` must be a `Float symbolic leaf; anything else is an error. */
  let toLeaf = (operationToLeaf, scaleOp, t, scaleBy) => {
    let yTransform = Operation.Scale.toFn(scaleOp);
    let integralSumFor = Operation.Scale.toKnownIntegralSumFn(scaleOp);
    let rendered = operationToLeaf(`Render(t));

    switch (rendered, scaleBy) {
    // a rendering failure is reported before the scaling factor is looked at
    | (Error(renderError), _) => Error(renderError)
    | (
        Ok(`Leaf(`RenderedDist(shape))),
        `Leaf(`SymbolicDist(`Float(factor))),
      ) =>
      let scaled =
        Distributions.Shape.T.mapY(
          ~knownIntegralSumFn=integralSumFor(factor),
          yTransform(factor),
          shape,
        );
      Ok(`Leaf(`RenderedDist(scaled)));
    | _ => Error("Can only scale by float values.")
    };
  };
};
|
||
|
|
||
|
module PointwiseCombination = {
  /* Renders both subtrees (first `t1`, then `t2`) and adds the two shapes
     point by point. The known integral sums of the operands are added too. */
  let pointwiseAdd = (operationToLeaf, t1, t2) => {
    let firstRendered = operationToLeaf(`Render(t1));
    let secondRendered = operationToLeaf(`Render(t2));
    let addIntegralSums = (a, b) => Some(a +. b);

    switch (firstRendered, secondRendered) {
    | (Error(firstError), _) => Error(firstError)
    | (_, Error(secondError)) => Error(secondError)
    | (Ok(`Leaf(`RenderedDist(shape1))), Ok(`Leaf(`RenderedDist(shape2)))) =>
      let summed =
        Distributions.Shape.combinePointwise(
          ~knownIntegralSumsFn=addIntegralSums,
          (+.),
          shape1,
          shape2,
        );
      Ok(`Leaf(`RenderedDist(summed)));
    | _ => Error("Could not perform pointwise addition.")
    };
  };

  /* Placeholder: always fails, the arguments are not used yet. */
  // TODO: construct a function that we can easily sample from, to construct
  // a RenderedDist. Use the xMin and xMax of the rendered shapes to tell the sampling function where to look.
  let pointwiseMultiply = (operationToLeaf, t1, t2) =>
    Error("Pointwise multiplication not yet supported.");

  /* Picks the pointwise implementation matching `pointwiseOp` and applies it. */
  let toLeaf = (operationToLeaf, pointwiseOp, t1, t2) => {
    let combine =
      switch (pointwiseOp) {
      | `Add => pointwiseAdd
      | `Multiply => pointwiseMultiply
      };
    combine(operationToLeaf, t1, t2);
  };
};
|
||
|
|
||
|
module Truncate = {
  /* Rewrites that can resolve a `Truncate operation without rendering. */
  module Simplify = {
    /* With no cutoff on either side, truncating a leaf is the identity:
       return the leaf itself. Every other node is passed through unchanged. */
    let tryTruncatingNothing: tResult =
      fun
      | `Operation(`Truncate(None, None, `Leaf(d))) => Ok(`Leaf(d))
      | t => Ok(t);

    /* Truncating a symbolic uniform distribution yields another uniform
       whose bounds are clamped to the cutoffs. Every other node is passed
       through unchanged.
       NOTE(review): there is no check that newLow <= newHigh, so cutoffs
       lying entirely outside [u.low, u.high] produce an inverted range —
       confirm how downstream code treats that. */
    let tryTruncatingUniform: tResult =
      fun
      | `Operation(`Truncate(lc, rc, `Leaf(`SymbolicDist(`Uniform(u))))) => {
          // a missing cutoff behaves like -infinity / +infinity
          let newLow = max(E.O.default(neg_infinity, lc), u.low);
          let newHigh = min(E.O.default(infinity, rc), u.high);
          Ok(`Leaf(`SymbolicDist(`Uniform({low: newLow, high: newHigh}))));
        }
      | t => Ok(t);

    /* Wraps `t` in a `Truncate operation and runs the simplifications in
       order. The result is a leaf when one of them applied, otherwise the
       untouched `Operation node. */
    let attempt = (leftCutoff, rightCutoff, t): result(node, string) => {
      let originalTreeNode =
        `Operation(`Truncate((leftCutoff, rightCutoff, t)));

      originalTreeNode
      |> tryTruncatingNothing
      |> E.R.bind(_, tryTruncatingUniform);
    };
  };

  /* Renders `t` to a shape and truncates that shape at the given cutoffs.
     Returns the rendering error if rendering fails, and a generic error if
     rendering succeeds but does not produce a rendered leaf. */
  let evaluateNumerically = (leftCutoff, rightCutoff, operationToLeaf, t) => {
    // TODO: use named args in renderToShape; if we're lucky we can at least get the tail
    // of a distribution we otherwise wouldn't get at all
    let renderedShape = operationToLeaf(`Render(t));

    switch (renderedShape) {
    | Ok(`Leaf(`RenderedDist(rs))) =>
      let truncatedShape =
        rs |> Distributions.Shape.T.truncate(leftCutoff, rightCutoff);
      // Return the truncated shape. The previous code returned `rs` here,
      // which discarded the truncation and made the operation a no-op.
      Ok(`Leaf(`RenderedDist(truncatedShape)));
    | Error(e1) => Error(e1)
    | _ => Error("Could not truncate distribution.")
    };
  };

  /* Reduces a truncation of `t` to a leaf: symbolic simplification first,
     numerical truncation of the rendered shape as the fallback. */
  let toLeaf =
      (
        operationToLeaf,
        leftCutoff: option(float),
        rightCutoff: option(float),
        t: node,
      )
      : result(node, string) => {
    t
    |> Simplify.attempt(leftCutoff, rightCutoff)
    |> E.R.bind(
         _,
         fun
         | `Leaf(d) => Ok(`Leaf(d)) // the analytical simplification worked, nice!
         | `Operation(_) =>
           // if not, render the distribution and truncate the shape
           evaluateNumerically(leftCutoff, rightCutoff, operationToLeaf, t),
       );
  };
};
|
||
|
|
||
|
module Normalize = {
  /* Normalizes a node. Rendered shapes go through
     Distributions.Shape.T.normalize, symbolic leaves are returned as they
     are, and operations are first evaluated with `operationToLeaf` and the
     result is then normalized in turn. */
  let rec toLeaf = (operationToLeaf, t: node): result(node, string) =>
    switch (t) {
    | `Operation(op) =>
      E.R.bind(operationToLeaf(op), evaluated =>
        toLeaf(operationToLeaf, evaluated)
      )
    | `Leaf(`SymbolicDist(_)) => Ok(t)
    | `Leaf(`RenderedDist(shape)) =>
      let normalized = shape |> Distributions.Shape.T.normalize;
      Ok(`Leaf(`RenderedDist(normalized)));
    };
};
|
||
|
|
||
|
module FloatFromDist = {
  /* Applies `distToFloatOp` to a symbolic distribution and, on success,
     wraps the resulting float in a `Float symbolic leaf. */
  let symbolicToLeaf = (distToFloatOp: distToFloatOperation, s) => {
    let wrapFloat = value => Ok(`Leaf(`SymbolicDist(`Float(value))));
    E.R.bind(SymbolicDist.T.operate(distToFloatOp, s), wrapFloat);
  };

  /* Applies `distToFloatOp` to a rendered shape and wraps the resulting
     float in a `Float symbolic leaf. */
  let renderedToLeaf =
      (distToFloatOp: distToFloatOperation, rs: DistTypes.shape)
      : result(node, string) => {
    let value = Distributions.Shape.operate(distToFloatOp, rs);
    Ok(`Leaf(`SymbolicDist(`Float(value))));
  };

  /* Evaluates `distToFloatOp` on a node. Leaves are dispatched on their
     representation; operations are evaluated with `operationToLeaf` first
     and the result is fed back into this function. */
  let rec toLeaf =
          (operationToLeaf, distToFloatOp: distToFloatOperation, t: node)
          : result(node, string) =>
    switch (t) {
    | `Operation(op) =>
      operationToLeaf(op)
      |> E.R.bind(_, evaluated =>
           toLeaf(operationToLeaf, distToFloatOp, evaluated)
         )
    | `Leaf(`RenderedDist(shape)) => renderedToLeaf(distToFloatOp, shape)
    | `Leaf(`SymbolicDist(dist)) => symbolicToLeaf(distToFloatOp, dist)
    };
};
|
||
|
|
||
|
module Render = {
  /* Turns a node into a rendered shape. Operations are evaluated with
     `operationToLeaf` first and the result is rendered in turn; symbolic
     leaves are converted with SymbolicDist.T.toShape using `sampleCount`;
     leaves that are already rendered are returned as they are. */
  let rec toLeaf =
          (
            operationToLeaf: operation => result(t, string),
            sampleCount: int,
            t: node,
          )
          : result(t, string) =>
    switch (t) {
    | `Operation(op) =>
      operationToLeaf(op)
      |> E.R.bind(_, evaluated =>
           toLeaf(operationToLeaf, sampleCount, evaluated)
         )
    // already a rendered shape, nothing to do
    | `Leaf(`RenderedDist(_)) => Ok(t)
    | `Leaf(`SymbolicDist(dist)) =>
      let shape = SymbolicDist.T.toShape(sampleCount, dist);
      Ok(`Leaf(`RenderedDist(shape)));
    };
};
|
||
|
|
||
|
/* Evaluates a single operation down to a leaf by dispatching to the module
   that implements it. */
let rec operationToLeaf =
        (sampleCount: int, op: operation): result(t, string) => {
  // Each handler receives this callback so that it can evaluate children
  // that are themselves Operation nodes, using the same sample count.
  let evaluateChild = childOp => operationToLeaf(sampleCount, childOp);

  switch (op) {
  | `Render(t) => Render.toLeaf(evaluateChild, sampleCount, t)
  | `Normalize(t) => Normalize.toLeaf(evaluateChild, t)
  | `FloatFromDist(distToFloatOp, t) =>
    FloatFromDist.toLeaf(evaluateChild, distToFloatOp, t)
  | `Truncate(leftCutoff, rightCutoff, t) =>
    Truncate.toLeaf(evaluateChild, leftCutoff, rightCutoff, t)
  | `VerticalScaling(scaleOp, t, scaleBy) =>
    VerticalScaling.toLeaf(evaluateChild, scaleOp, t, scaleBy)
  | `PointwiseCombination(pointwiseOp, t1, t2) =>
    PointwiseCombination.toLeaf(evaluateChild, pointwiseOp, t1, t2)
  | `AlgebraicCombination(algebraicOp, t1, t2) =>
    // the handler decides itself whether to render or to stay symbolic
    AlgebraicCombination.toLeaf(evaluateChild, algebraicOp, t1, t2)
  };
};
|
||
|
|
||
|
/* This function recursively goes through the nodes of the parse tree,
|
||
|
replacing each Operation node and its subtree with a Data node.
|
||
|
Whenever possible, the replacement produces a new Symbolic Data node,
|
||
|
but most often it will produce a RenderedDist.
|
||
|
This function is used mainly to turn a parse tree into a single RenderedDist
|
||
|
that can then be displayed to the user. */
|
||
|
/* Entry point: a leaf is returned as it is, an operation is evaluated with
   `operationToLeaf` using the given sample count. */
let toLeaf = (node: t, sampleCount: int): result(t, string) =>
  switch (node) {
  | `Operation(op) => operationToLeaf(sampleCount, op)
  | `Leaf(_) => Ok(node)
  };
|