// squiggle/src/distPlus/expressionTree/ExpressionTreeEvaluator.re
open ExpressionTypes;
open ExpressionTypes.ExpressionTree;

/* An expression-tree node (the `node` type from ExpressionTypes.ExpressionTree). */
type t = node;

/* Signature of a function mapping a node either to a new node or to an error message. */
type tResult = node => result(node, string);

/* Parameters used when turning symbolic distributions into rendered shapes.
   sampleCount is passed to SymbolicDist.T.toShape by the Render module. */
type renderParams = {sampleCount: int};
/* Given two independent random variables A and B, this returns the distribution
   of a new variable that is the result of the operation on A and B.
   For instance, with normal(mean, stdev): normal(0, 1) + normal(1, 1) -> normal(1, sqrt(2)),
   because the means add and the variances add (1^2 + 1^2 = 2).
   In general, this is implemented via convolution. */
module AlgebraicCombination = {
  /* If both operands are symbolic distributions, ask SymbolicDist for a
     closed-form result of `operation`. If there is none, or either operand is
     not symbolic, return the combination unevaluated as an
     `AlgebraicCombination node. Errors reported by SymbolicDist are passed
     through as Error. */
  let tryAnalyticalSimplification = (operation, t1: t, t2: t) =>
    switch (operation, t1, t2) {
    | (operation, `SymbolicDist(d1), `SymbolicDist(d2)) =>
      switch (SymbolicDist.T.tryAnalyticalSimplification(d1, d2, operation)) {
      | `AnalyticalSolution(symbolicDist) => Ok(`SymbolicDist(symbolicDist))
      | `Error(er) => Error(er)
      | `NoSolution => Ok(`AlgebraicCombination((operation, t1, t2)))
      }
    | _ => Ok(`AlgebraicCombination((operation, t1, t2)))
    };

  /* Render both operands to shapes and combine them numerically with
     Distributions.Shape.combineAlgebraically. A rendering error from the left
     operand takes precedence over one from the right. */
  let combineAsShapes = (toLeaf, renderParams, algebraicOp, t1, t2) => {
    let renderShape = r => toLeaf(renderParams, `Render(r));
    switch (renderShape(t1), renderShape(t2)) {
    | (Ok(`RenderedDist(s1)), Ok(`RenderedDist(s2))) =>
      Ok(
        `RenderedDist(
          Distributions.Shape.combineAlgebraically(algebraicOp, s1, s2),
        ),
      )
    | (Error(e1), _) => Error(e1)
    | (_, Error(e2)) => Error(e2)
    | _ => Error("Algebraic combination: rendering failed.")
    };
  };

  /* Evaluate `algebraicOp(t1, t2)` down to a leaf. An exact symbolic result
     is preferred; otherwise both operands are rendered and combined as shapes. */
  let operationToLeaf =
      (
        toLeaf,
        renderParams: renderParams,
        algebraicOp: ExpressionTypes.algebraicOperation,
        t1: t,
        t2: t,
      )
      : result(node, string) =>
    switch (tryAnalyticalSimplification(algebraicOp, t1, t2)) {
    | Ok(`SymbolicDist(_) as simplified) => Ok(simplified)
    // Any other Ok means no closed form was found: fall back to numeric combination.
    | Ok(_) => combineAsShapes(toLeaf, renderParams, algebraicOp, t1, t2)
    | Error(e) => Error(e)
    };
};
module VerticalScaling = {
  /* Scale the y-values of `t`'s rendered shape by a constant using `scaleOp`.

     `scaleBy` is first reduced with `toLeaf`, so it may be any expression that
     evaluates to a single float (for example a FloatFromDist node). A leaf
     `SymbolicDist(`Float) is returned unchanged by toLeaf, so existing callers
     behave as before. Anything that does not reduce to a float is rejected
     with "Can only scale by float values.". */
  let operationToLeaf = (toLeaf, renderParams, scaleOp, t, scaleBy) => {
    let fn = Operation.Scale.toFn(scaleOp);
    let knownIntegralSumFn = Operation.Scale.toKnownIntegralSumFn(scaleOp);
    // Resolve the scale factor before rendering `t`, so an invalid factor does
    // not trigger an unnecessary render.
    switch (toLeaf(renderParams, scaleBy)) {
    | Error(e) => Error(e)
    | Ok(`SymbolicDist(`Float(sm))) =>
      switch (toLeaf(renderParams, `Render(t))) {
      | Ok(`RenderedDist(rs)) =>
        Ok(
          `RenderedDist(
            Distributions.Shape.T.mapY(
              ~knownIntegralSumFn=knownIntegralSumFn(sm),
              fn(sm),
              rs,
            ),
          ),
        )
      | Error(e1) => Error(e1)
      // Rendering "succeeded" but did not yield a shape; previously this was
      // misreported as a bad scale factor.
      | Ok(_) => Error("Vertical scaling: rendering failed.")
      }
    | Ok(_) => Error("Can only scale by float values.")
    };
  };
};
module PointwiseCombination = {
  /* Render both operands and add their shapes point by point. The integral of
     a pointwise sum is the sum of the operands' integrals, which is supplied
     as the known integral sum. A rendering error from the left operand takes
     precedence over one from the right. */
  let pointwiseAdd = (toLeaf, renderParams, t1, t2) => {
    let renderShape = r => toLeaf(renderParams, `Render(r));
    switch (renderShape(t1), renderShape(t2)) {
    | (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) =>
      Ok(
        `RenderedDist(
          Distributions.Shape.combinePointwise(
            ~knownIntegralSumsFn=(a, b) => Some(a +. b),
            (+.),
            rs1,
            rs2,
          ),
        ),
      )
    | (Error(e1), _) => Error(e1)
    | (_, Error(e2)) => Error(e2)
    | _ => Error("Pointwise combination: rendering failed.")
    };
  };

  /* Not implemented yet: always returns an error. The arguments are accepted
     only to match pointwiseAdd's signature. */
  let pointwiseMultiply = (_toLeaf, _renderParams, _t1, _t2) => {
    // TODO: construct a function that we can easily sample from, to construct
    // a RenderedDist. Use the xMin and xMax of the rendered shapes to tell the sampling function where to look.
    Error("Pointwise multiplication not yet supported.");
  };

  /* Dispatch a pointwise operation on two nodes to its implementation. */
  let operationToLeaf = (toLeaf, renderParams, pointwiseOp, t1, t2) => {
    switch (pointwiseOp) {
    | `Add => pointwiseAdd(toLeaf, renderParams, t1, t2)
    | `Multiply => pointwiseMultiply(toLeaf, renderParams, t1, t2)
    };
  };
};
module Truncate = {
  /* Try to truncate `t` to [leftCutoff, rightCutoff] without rendering it.
     - No cutoffs: `t` is returned unchanged.
     - Both cutoffs given with left > right: error.
     - Uniform: the result is the Uniform on the intersection of its support
       with the cutoffs. If that intersection has zero width, the truncated
       distribution has no mass and an error is returned.
     - Anything else: `NoSolution, so the caller truncates a rendered shape. */
  let trySimplification = (leftCutoff, rightCutoff, t): simplificationResult => {
    switch (leftCutoff, rightCutoff, t) {
    | (None, None, t) => `Solution(t)
    | (Some(lc), Some(rc), _) when lc > rc =>
      `Error("Left truncation bound must be smaller than right bound.")
    | (lc, rc, `SymbolicDist(`Uniform(u))) =>
      // just create a new Uniform distribution over the clipped range
      let nu: SymbolicTypes.uniform = u;
      let newLow = max(E.O.default(neg_infinity, lc), nu.low);
      let newHigh = min(E.O.default(infinity, rc), nu.high);
      // Cutoffs that miss the support (or pinch it to a point) would otherwise
      // produce a Uniform with low >= high.
      if (newLow >= newHigh) {
        `Error(
          "Truncation bounds leave no probability mass in the uniform distribution.",
        );
      } else {
        `Solution(`SymbolicDist(`Uniform({low: newLow, high: newHigh})));
      };
    | _ => `NoSolution
    };
  };

  /* Render `t` and truncate the resulting shape with
     Distributions.Shape.T.truncate. */
  let truncateAsShape = (toLeaf, renderParams, leftCutoff, rightCutoff, t) => {
    // TODO: use named args in renderToShape; if we're lucky we can at least get the tail
    // of a distribution we otherwise wouldn't get at all
    let renderedShape = toLeaf(renderParams, `Render(t));
    switch (renderedShape) {
    | Ok(`RenderedDist(rs)) =>
      let truncatedShape =
        rs |> Distributions.Shape.T.truncate(leftCutoff, rightCutoff);
      Ok(`RenderedDist(truncatedShape));
    | Error(e1) => Error(e1)
    | _ => Error("Could not truncate distribution.")
    };
  };

  /* Truncate `t` to [leftCutoff, rightCutoff] (None = unbounded on that side).
     A symbolic simplification is used when one exists; otherwise a rendered
     shape is truncated. */
  let operationToLeaf =
      (
        toLeaf,
        renderParams,
        leftCutoff: option(float),
        rightCutoff: option(float),
        t: node,
      )
      : result(node, string) =>
    switch (trySimplification(leftCutoff, rightCutoff, t)) {
    | `Solution(simplified) => Ok(simplified)
    | `Error(e) => Error(e)
    | `NoSolution =>
      truncateAsShape(toLeaf, renderParams, leftCutoff, rightCutoff, t)
    };
};
module Normalize = {
  /* Normalize the distribution denoted by `t`.
     - A rendered shape is normalized with Distributions.Shape.T.normalize.
     - A symbolic distribution is returned unchanged (NOTE(review): this
       presumes symbolic dists are already normalized — confirm).
     - Any other node is first reduced with `toLeaf`, and the result is
       normalized by recursing on it. */
  let rec operationToLeaf =
          (toLeaf, renderParams, t: node): result(node, string) => {
    switch (t) {
    | `RenderedDist(s) => Ok(`RenderedDist(Distributions.Shape.T.normalize(s)))
    | `SymbolicDist(_) => Ok(t)
    | _ =>
      t
      |> toLeaf(renderParams)
      |> E.R.bind(_, operationToLeaf(toLeaf, renderParams))
    };
  };
};
module FloatFromDist = {
  /* Apply `distToFloatOp` to a symbolic distribution and wrap the resulting
     float as a `SymbolicDist(`Float) leaf. Errors from SymbolicDist.T.operate
     are passed through. */
  let symbolicToLeaf = (distToFloatOp: distToFloatOperation, s) =>
    switch (SymbolicDist.T.operate(distToFloatOp, s)) {
    | Ok(v) => Ok(`SymbolicDist(`Float(v)))
    | Error(e) => Error(e)
    };

  /* Apply `distToFloatOp` to a rendered shape. Distributions.Shape.operate
     returns a plain float here, so this always succeeds. */
  let renderedToLeaf =
      (distToFloatOp: distToFloatOperation, rs: DistTypes.shape)
      : result(node, string) =>
    Ok(`SymbolicDist(`Float(Distributions.Shape.operate(distToFloatOp, rs))));

  /* Reduce `t` to a single float, returned as a `SymbolicDist(`Float) node.
     Leaves are handled directly. Any other node is first reduced with `toLeaf`,
     and the operation is then applied to the result by recursion. */
  let rec operationToLeaf =
          (toLeaf, renderParams, distToFloatOp: distToFloatOperation, t: node)
          : result(node, string) => {
    switch (t) {
    | `SymbolicDist(s) => symbolicToLeaf(distToFloatOp, s)
    | `RenderedDist(rs) => renderedToLeaf(distToFloatOp, rs)
    | _ =>
      t
      |> toLeaf(renderParams)
      |> E.R.bind(_, operationToLeaf(toLeaf, renderParams, distToFloatOp))
    };
  };
};
module Render = {
  /* Turn `t` into a `RenderedDist leaf.
     - A symbolic distribution is rendered with SymbolicDist.T.toShape, using
       renderParams.sampleCount.
     - An already-rendered shape is returned as-is.
     - Any other node is first reduced with `toLeaf`, and the result is
       rendered by recursing on it. */
  let rec operationToLeaf =
          (toLeaf, renderParams, t: node): result(t, string) => {
    switch (t) {
    | `SymbolicDist(d) =>
      Ok(`RenderedDist(SymbolicDist.T.toShape(renderParams.sampleCount, d)))
    | `RenderedDist(_) as t => Ok(t) // already a rendered shape, we're done here
    | _ =>
      t
      |> toLeaf(renderParams)
      |> E.R.bind(_, operationToLeaf(toLeaf, renderParams))
    };
  };
};
/* Recursively evaluate an expression tree down to a single leaf node.

   Leaf nodes (`SymbolicDist and `RenderedDist) are returned unchanged. Each
   operation node is handed to the module implementing it, together with
   `toLeaf` itself, so that module can evaluate its operands. Where an exact
   symbolic result exists, the result stays a `SymbolicDist. Examples are an
   algebraic combination with a closed form, truncating a Uniform, or
   FloatFromDist, which yields `SymbolicDist(`Float). Otherwise operands are
   rendered and the result is a `RenderedDist, which can then be displayed to
   the user.

   renderParams is threaded through to every sub-evaluation; its sampleCount
   is used when symbolic distributions are rendered. The first error from any
   sub-evaluation is returned as Error(message). */
let rec toLeaf = (renderParams, node: t): result(t, string) => {
  switch (node) {
  // Leaf nodes just stay leaf nodes
  | `SymbolicDist(_)
  | `RenderedDist(_) => Ok(node)
  // Operations need to be turned into leaves
  | `AlgebraicCombination(algebraicOp, t1, t2) =>
    AlgebraicCombination.operationToLeaf(
      toLeaf,
      renderParams,
      algebraicOp,
      t1,
      t2,
    )
  | `PointwiseCombination(pointwiseOp, t1, t2) =>
    PointwiseCombination.operationToLeaf(
      toLeaf,
      renderParams,
      pointwiseOp,
      t1,
      t2,
    )
  | `VerticalScaling(scaleOp, t, scaleBy) =>
    VerticalScaling.operationToLeaf(toLeaf, renderParams, scaleOp, t, scaleBy)
  | `Truncate(leftCutoff, rightCutoff, t) =>
    Truncate.operationToLeaf(toLeaf, renderParams, leftCutoff, rightCutoff, t)
  | `FloatFromDist(distToFloatOp, t) =>
    FloatFromDist.operationToLeaf(toLeaf, renderParams, distToFloatOp, t)
  | `Normalize(t) => Normalize.operationToLeaf(toLeaf, renderParams, t)
  | `Render(t) => Render.operationToLeaf(toLeaf, renderParams, t)
  };
};