Formatted ExpressionTreeEvaluator.re

This commit is contained in:
Ozzie Gooen 2020-07-06 19:50:22 +01:00
parent 56a9bda82a
commit 4cf7a69d3e

View File

@ -4,9 +4,7 @@ open ExpressionTypes.ExpressionTree;
type t = node; type t = node;
type tResult = node => result(node, string); type tResult = node => result(node, string);
type renderParams = { type renderParams = {sampleCount: int};
sampleCount: int,
};
/* Given two random variables A and B, this returns the distribution /* Given two random variables A and B, this returns the distribution
of a new variable that is the result of the operation on A and B. of a new variable that is the result of the operation on A and B.
@ -15,26 +13,23 @@ type renderParams = {
module AlgebraicCombination = { module AlgebraicCombination = {
let tryAnalyticalSimplification = (operation, t1: t, t2: t) => let tryAnalyticalSimplification = (operation, t1: t, t2: t) =>
switch (operation, t1, t2) { switch (operation, t1, t2) {
| (operation, | (operation, `SymbolicDist(d1), `SymbolicDist(d2)) =>
`SymbolicDist(d1),
`SymbolicDist(d2),
) =>
switch (SymbolicDist.T.tryAnalyticalSimplification(d1, d2, operation)) { switch (SymbolicDist.T.tryAnalyticalSimplification(d1, d2, operation)) {
| `AnalyticalSolution(symbolicDist) => Ok(`SymbolicDist(symbolicDist)) | `AnalyticalSolution(symbolicDist) => Ok(`SymbolicDist(symbolicDist))
| `Error(er) => Error(er) | `Error(er) => Error(er)
| `NoSolution => Ok(`AlgebraicCombination(operation, t1, t2)) | `NoSolution => Ok(`AlgebraicCombination((operation, t1, t2)))
} }
| _ => Ok(`AlgebraicCombination(operation, t1, t2)) | _ => Ok(`AlgebraicCombination((operation, t1, t2)))
}; };
let combineAsShapes = (toLeaf, renderParams, algebraicOp, t1, t2) => { let combineAsShapes = (toLeaf, renderParams, algebraicOp, t1, t2) => {
let renderShape = r => toLeaf(renderParams, `Render(r)); let renderShape = r => toLeaf(renderParams, `Render(r));
switch (renderShape(t1), renderShape(t2)) { switch (renderShape(t1), renderShape(t2)) {
| (Ok(`RenderedDist(s1)), Ok(`RenderedDist(s2))) => | (Ok(`RenderedDist(s1)), Ok(`RenderedDist(s2))) =>
Ok( Ok(
`RenderedDist( `RenderedDist(
Distributions.Shape.combineAlgebraically(algebraicOp, s1, s2), Distributions.Shape.combineAlgebraically(algebraicOp, s1, s2),
), ),
) )
| (Error(e1), _) => Error(e1) | (Error(e1), _) => Error(e1)
| (_, Error(e2)) => Error(e2) | (_, Error(e2)) => Error(e2)
@ -51,14 +46,13 @@ module AlgebraicCombination = {
t2: t, t2: t,
) )
: result(node, string) => : result(node, string) =>
algebraicOp algebraicOp
|> tryAnalyticalSimplification(_, t1, t2) |> tryAnalyticalSimplification(_, t1, t2)
|> E.R.bind( |> E.R.bind(
_, _,
fun fun
| `SymbolicDist(d) as t => Ok(t) | `SymbolicDist(d) as t => Ok(t)
| _ => combineAsShapes(toLeaf, renderParams, algebraicOp, t1, t2) | _ => combineAsShapes(toLeaf, renderParams, algebraicOp, t1, t2),
); );
}; };
@ -72,12 +66,12 @@ module VerticalScaling = {
switch (renderedShape, scaleBy) { switch (renderedShape, scaleBy) {
| (Ok(`RenderedDist(rs)), `SymbolicDist(`Float(sm))) => | (Ok(`RenderedDist(rs)), `SymbolicDist(`Float(sm))) =>
Ok( Ok(
`RenderedDist( `RenderedDist(
Distributions.Shape.T.mapY( Distributions.Shape.T.mapY(
~knownIntegralSumFn=knownIntegralSumFn(sm), ~knownIntegralSumFn=knownIntegralSumFn(sm),
fn(sm), fn(sm),
rs, rs,
), ),
), ),
) )
| (Error(e1), _) => Error(e1) | (Error(e1), _) => Error(e1)
@ -127,13 +121,12 @@ module Truncate = {
let trySimplification = (leftCutoff, rightCutoff, t) => { let trySimplification = (leftCutoff, rightCutoff, t) => {
switch (leftCutoff, rightCutoff, t) { switch (leftCutoff, rightCutoff, t) {
| (None, None, t) => Ok(t) | (None, None, t) => Ok(t)
| (lc, rc, `SymbolicDist(`Uniform(u))) => { | (lc, rc, `SymbolicDist(`Uniform(u))) =>
// just create a new Uniform distribution // just create a new Uniform distribution
let nu: SymbolicTypes.uniform = u; let nu: SymbolicTypes.uniform = u;
let newLow = max(E.O.default(neg_infinity, lc), nu.low); let newLow = max(E.O.default(neg_infinity, lc), nu.low);
let newHigh = min(E.O.default(infinity, rc), nu.high); let newHigh = min(E.O.default(infinity, rc), nu.high);
Ok(`SymbolicDist(`Uniform({low: newLow, high: newHigh}))); Ok(`SymbolicDist(`Uniform({low: newLow, high: newHigh})));
}
| (_, _, t) => Ok(t) | (_, _, t) => Ok(t)
}; };
}; };
@ -144,43 +137,47 @@ module Truncate = {
let renderedShape = toLeaf(renderParams, `Render(t)); let renderedShape = toLeaf(renderParams, `Render(t));
switch (renderedShape) { switch (renderedShape) {
| Ok(`RenderedDist(rs)) => { | Ok(`RenderedDist(rs)) =>
let truncatedShape = let truncatedShape =
rs |> Distributions.Shape.T.truncate(leftCutoff, rightCutoff); rs |> Distributions.Shape.T.truncate(leftCutoff, rightCutoff);
Ok(`RenderedDist(truncatedShape)); Ok(`RenderedDist(truncatedShape));
}
| Error(e1) => Error(e1) | Error(e1) => Error(e1)
| _ => Error("Could not truncate distribution.") | _ => Error("Could not truncate distribution.")
}; };
}; };
let operationToLeaf = let operationToLeaf =
( (
toLeaf, toLeaf,
renderParams, renderParams,
leftCutoff: option(float), leftCutoff: option(float),
rightCutoff: option(float), rightCutoff: option(float),
t: node, t: node,
) )
: result(node, string) => { : result(node, string) => {
t t
|> trySimplification(leftCutoff, rightCutoff) |> trySimplification(leftCutoff, rightCutoff)
|> E.R.bind( |> E.R.bind(
_, _,
fun fun
| `SymbolicDist(d) as t => Ok(t) | `SymbolicDist(d) as t => Ok(t)
| _ => truncateAsShape(toLeaf, renderParams, leftCutoff, rightCutoff, t), | _ =>
truncateAsShape(toLeaf, renderParams, leftCutoff, rightCutoff, t),
); );
}; };
}; };
module Normalize = { module Normalize = {
let rec operationToLeaf = (toLeaf, renderParams, t: node): result(node, string) => { let rec operationToLeaf =
(toLeaf, renderParams, t: node): result(node, string) => {
switch (t) { switch (t) {
| `RenderedDist(s) => | `RenderedDist(s) =>
Ok(`RenderedDist(Distributions.Shape.T.normalize(s))) Ok(`RenderedDist(Distributions.Shape.T.normalize(s)))
| `SymbolicDist(_) => Ok(t) | `SymbolicDist(_) => Ok(t)
| _ => t |> toLeaf(renderParams) |> E.R.bind(_, operationToLeaf(toLeaf, renderParams)) | _ =>
t
|> toLeaf(renderParams)
|> E.R.bind(_, operationToLeaf(toLeaf, renderParams))
}; };
}; };
}; };
@ -202,24 +199,25 @@ module FloatFromDist = {
switch (t) { switch (t) {
| `SymbolicDist(s) => symbolicToLeaf(distToFloatOp, s) | `SymbolicDist(s) => symbolicToLeaf(distToFloatOp, s)
| `RenderedDist(rs) => renderedToLeaf(distToFloatOp, rs) | `RenderedDist(rs) => renderedToLeaf(distToFloatOp, rs)
| _ => t |> toLeaf(renderParams) |> E.R.bind(_, operationToLeaf(toLeaf, renderParams, distToFloatOp)) | _ =>
t
|> toLeaf(renderParams)
|> E.R.bind(_, operationToLeaf(toLeaf, renderParams, distToFloatOp))
}; };
}; };
}; };
module Render = { module Render = {
let rec operationToLeaf = let rec operationToLeaf =
( (toLeaf, renderParams, t: node): result(t, string) => {
toLeaf,
renderParams,
t: node,
)
: result(t, string) => {
switch (t) { switch (t) {
| `SymbolicDist(d) => | `SymbolicDist(d) =>
Ok(`RenderedDist(SymbolicDist.T.toShape(renderParams.sampleCount, d))) Ok(`RenderedDist(SymbolicDist.T.toShape(renderParams.sampleCount, d)))
| `RenderedDist(_) as t => Ok(t) // already a rendered shape, we're done here | `RenderedDist(_) as t => Ok(t) // already a rendered shape, we're done here
| _ => t |> toLeaf(renderParams) |> E.R.bind(_, operationToLeaf(toLeaf, renderParams)) | _ =>
t
|> toLeaf(renderParams)
|> E.R.bind(_, operationToLeaf(toLeaf, renderParams))
}; };
}; };
}; };
@ -242,7 +240,7 @@ let rec toLeaf = (renderParams, node: t): result(t, string) => {
renderParams, renderParams,
algebraicOp, algebraicOp,
t1, t1,
t2 t2,
) )
| `PointwiseCombination(pointwiseOp, t1, t2) => | `PointwiseCombination(pointwiseOp, t1, t2) =>
PointwiseCombination.operationToLeaf( PointwiseCombination.operationToLeaf(
@ -253,9 +251,7 @@ let rec toLeaf = (renderParams, node: t): result(t, string) => {
t2, t2,
) )
| `VerticalScaling(scaleOp, t, scaleBy) => | `VerticalScaling(scaleOp, t, scaleBy) =>
VerticalScaling.operationToLeaf( VerticalScaling.operationToLeaf(toLeaf, renderParams, scaleOp, t, scaleBy)
toLeaf, renderParams, scaleOp, t, scaleBy
)
| `Truncate(leftCutoff, rightCutoff, t) => | `Truncate(leftCutoff, rightCutoff, t) =>
Truncate.operationToLeaf(toLeaf, renderParams, leftCutoff, rightCutoff, t) Truncate.operationToLeaf(toLeaf, renderParams, leftCutoff, rightCutoff, t)
| `FloatFromDist(distToFloatOp, t) => | `FloatFromDist(distToFloatOp, t) =>