Formatted ExpressionTreeEvaluator.re

This commit is contained in:
Ozzie Gooen 2020-07-06 19:50:22 +01:00
parent 56a9bda82a
commit 4cf7a69d3e

View File

@ -4,9 +4,7 @@ open ExpressionTypes.ExpressionTree;
type t = node;
type tResult = node => result(node, string);
type renderParams = {
sampleCount: int,
};
type renderParams = {sampleCount: int};
/* Given two random variables A and B, this returns the distribution
of a new variable that is the result of the operation on A and B.
@ -15,26 +13,23 @@ type renderParams = {
module AlgebraicCombination = {
let tryAnalyticalSimplification = (operation, t1: t, t2: t) =>
switch (operation, t1, t2) {
| (operation,
`SymbolicDist(d1),
`SymbolicDist(d2),
) =>
| (operation, `SymbolicDist(d1), `SymbolicDist(d2)) =>
switch (SymbolicDist.T.tryAnalyticalSimplification(d1, d2, operation)) {
| `AnalyticalSolution(symbolicDist) => Ok(`SymbolicDist(symbolicDist))
| `Error(er) => Error(er)
| `NoSolution => Ok(`AlgebraicCombination(operation, t1, t2))
| `NoSolution => Ok(`AlgebraicCombination((operation, t1, t2)))
}
| _ => Ok(`AlgebraicCombination(operation, t1, t2))
};
| _ => Ok(`AlgebraicCombination((operation, t1, t2)))
};
let combineAsShapes = (toLeaf, renderParams, algebraicOp, t1, t2) => {
let renderShape = r => toLeaf(renderParams, `Render(r));
switch (renderShape(t1), renderShape(t2)) {
| (Ok(`RenderedDist(s1)), Ok(`RenderedDist(s2))) =>
Ok(
`RenderedDist(
Distributions.Shape.combineAlgebraically(algebraicOp, s1, s2),
),
`RenderedDist(
Distributions.Shape.combineAlgebraically(algebraicOp, s1, s2),
),
)
| (Error(e1), _) => Error(e1)
| (_, Error(e2)) => Error(e2)
@ -51,14 +46,13 @@ module AlgebraicCombination = {
t2: t,
)
: result(node, string) =>
algebraicOp
|> tryAnalyticalSimplification(_, t1, t2)
|> E.R.bind(
_,
fun
| `SymbolicDist(d) as t => Ok(t)
| _ => combineAsShapes(toLeaf, renderParams, algebraicOp, t1, t2)
| `SymbolicDist(d) as t => Ok(t)
| _ => combineAsShapes(toLeaf, renderParams, algebraicOp, t1, t2),
);
};
@ -72,12 +66,12 @@ module VerticalScaling = {
switch (renderedShape, scaleBy) {
| (Ok(`RenderedDist(rs)), `SymbolicDist(`Float(sm))) =>
Ok(
`RenderedDist(
Distributions.Shape.T.mapY(
~knownIntegralSumFn=knownIntegralSumFn(sm),
fn(sm),
rs,
),
`RenderedDist(
Distributions.Shape.T.mapY(
~knownIntegralSumFn=knownIntegralSumFn(sm),
fn(sm),
rs,
),
),
)
| (Error(e1), _) => Error(e1)
@ -127,13 +121,12 @@ module Truncate = {
let trySimplification = (leftCutoff, rightCutoff, t) => {
switch (leftCutoff, rightCutoff, t) {
| (None, None, t) => Ok(t)
| (lc, rc, `SymbolicDist(`Uniform(u))) => {
// just create a new Uniform distribution
let nu: SymbolicTypes.uniform = u;
let newLow = max(E.O.default(neg_infinity, lc), nu.low);
let newHigh = min(E.O.default(infinity, rc), nu.high);
Ok(`SymbolicDist(`Uniform({low: newLow, high: newHigh})));
}
| (lc, rc, `SymbolicDist(`Uniform(u))) =>
// just create a new Uniform distribution
let nu: SymbolicTypes.uniform = u;
let newLow = max(E.O.default(neg_infinity, lc), nu.low);
let newHigh = min(E.O.default(infinity, rc), nu.high);
Ok(`SymbolicDist(`Uniform({low: newLow, high: newHigh})));
| (_, _, t) => Ok(t)
};
};
@ -144,43 +137,47 @@ module Truncate = {
let renderedShape = toLeaf(renderParams, `Render(t));
switch (renderedShape) {
| Ok(`RenderedDist(rs)) => {
| Ok(`RenderedDist(rs)) =>
let truncatedShape =
rs |> Distributions.Shape.T.truncate(leftCutoff, rightCutoff);
Ok(`RenderedDist(truncatedShape));
}
| Error(e1) => Error(e1)
| _ => Error("Could not truncate distribution.")
};
};
let operationToLeaf =
(
toLeaf,
renderParams,
leftCutoff: option(float),
rightCutoff: option(float),
t: node,
)
: result(node, string) => {
(
toLeaf,
renderParams,
leftCutoff: option(float),
rightCutoff: option(float),
t: node,
)
: result(node, string) => {
t
|> trySimplification(leftCutoff, rightCutoff)
|> E.R.bind(
_,
fun
| `SymbolicDist(d) as t => Ok(t)
| _ => truncateAsShape(toLeaf, renderParams, leftCutoff, rightCutoff, t),
| _ =>
truncateAsShape(toLeaf, renderParams, leftCutoff, rightCutoff, t),
);
};
};
module Normalize = {
let rec operationToLeaf = (toLeaf, renderParams, t: node): result(node, string) => {
let rec operationToLeaf =
(toLeaf, renderParams, t: node): result(node, string) => {
switch (t) {
| `RenderedDist(s) =>
Ok(`RenderedDist(Distributions.Shape.T.normalize(s)))
| `SymbolicDist(_) => Ok(t)
| _ => t |> toLeaf(renderParams) |> E.R.bind(_, operationToLeaf(toLeaf, renderParams))
| _ =>
t
|> toLeaf(renderParams)
|> E.R.bind(_, operationToLeaf(toLeaf, renderParams))
};
};
};
@ -202,24 +199,25 @@ module FloatFromDist = {
switch (t) {
| `SymbolicDist(s) => symbolicToLeaf(distToFloatOp, s)
| `RenderedDist(rs) => renderedToLeaf(distToFloatOp, rs)
| _ => t |> toLeaf(renderParams) |> E.R.bind(_, operationToLeaf(toLeaf, renderParams, distToFloatOp))
| _ =>
t
|> toLeaf(renderParams)
|> E.R.bind(_, operationToLeaf(toLeaf, renderParams, distToFloatOp))
};
};
};
module Render = {
let rec operationToLeaf =
(
toLeaf,
renderParams,
t: node,
)
: result(t, string) => {
(toLeaf, renderParams, t: node): result(t, string) => {
switch (t) {
| `SymbolicDist(d) =>
Ok(`RenderedDist(SymbolicDist.T.toShape(renderParams.sampleCount, d)))
| `RenderedDist(_) as t => Ok(t) // already a rendered shape, we're done here
| _ => t |> toLeaf(renderParams) |> E.R.bind(_, operationToLeaf(toLeaf, renderParams))
| _ =>
t
|> toLeaf(renderParams)
|> E.R.bind(_, operationToLeaf(toLeaf, renderParams))
};
};
};
@ -242,7 +240,7 @@ let rec toLeaf = (renderParams, node: t): result(t, string) => {
renderParams,
algebraicOp,
t1,
t2
t2,
)
| `PointwiseCombination(pointwiseOp, t1, t2) =>
PointwiseCombination.operationToLeaf(
@ -253,9 +251,7 @@ let rec toLeaf = (renderParams, node: t): result(t, string) => {
t2,
)
| `VerticalScaling(scaleOp, t, scaleBy) =>
VerticalScaling.operationToLeaf(
toLeaf, renderParams, scaleOp, t, scaleBy
)
VerticalScaling.operationToLeaf(toLeaf, renderParams, scaleOp, t, scaleBy)
| `Truncate(leftCutoff, rightCutoff, t) =>
Truncate.operationToLeaf(toLeaf, renderParams, leftCutoff, rightCutoff, t)
| `FloatFromDist(distToFloatOp, t) =>