2020-07-02 17:12:03 +00:00
|
|
|
open ExpressionTypes;
|
|
|
|
open ExpressionTypes.ExpressionTree;
|
|
|
|
|
|
|
|
/* Local alias for the expression-tree `node` type brought into scope by the
   `open`s at the top of this file. */
type t = node;
|
|
|
|
/* A node transformation that either produces a new node or fails with an
   error message.
   NOTE(review): not referenced by the evaluators in this file's visible code —
   confirm whether it is still needed. */
type tResult = node => result(node, string);
|
|
|
|
|
|
|
|
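/* Given two random variables A and B, this returns the distribution
   of a new variable that is the result of the operation on A and B.
   For instance, with normal(mean, stdev):
   normal(0, 1) + normal(1, 1) -> normal(1, sqrt(2)).
   The means add and, for independent variables, the variances add.
   For addition the exact result is in general a convolution. When no
   analytical solution is available, this module approximates the result
   by sampling both inputs and combining the samples pairwise. */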
|
|
|
|
module AlgebraicCombination = {
  /* Tries to combine two operands in closed form.
     Only a pair of `SymbolicDist operands is handed to
     SymbolicDist.T.tryAnalyticalSimplification:
     - `AnalyticalSolution(d) becomes Ok(`SymbolicDist(d));
     - `Error(msg) becomes Error(msg);
     - `NoSolution, like every non-symbolic pair, yields Ok of the unchanged
       `AlgebraicCombination node, meaning "no shortcut available". */
  let tryAnalyticalSimplification = (operation, t1: t, t2: t) => {
    let unsolved = Ok(`AlgebraicCombination((operation, t1, t2)));
    switch (t1, t2) {
    | (`SymbolicDist(d1), `SymbolicDist(d2)) =>
      switch (SymbolicDist.T.tryAnalyticalSimplification(d1, d2, operation)) {
      | `AnalyticalSolution(symbolicDist) => Ok(`SymbolicDist(symbolicDist))
      | `Error(er) => Error(er)
      | `NoSolution => unsolved
      }
    | _ => unsolved
    };
  };

  /* Draws `n` samples from each operand, via mapRenderable given
     Shape.sampleNRendered and SymbolicDist.T.sampleN. It then applies the
     algebraic operation to each pair of samples (zipped by index).
     Returns None when either operand could not be sampled. */
  let tryCombination = (n, algebraicOp, t1: node, t2: node) => {
    let sampleN =
      mapRenderable(Shape.sampleNRendered(n), SymbolicDist.T.sampleN(n));
    switch (sampleN(t1), sampleN(t2)) {
    | (Some(samples1), Some(samples2)) =>
      let combined =
        Belt.Array.zip(samples1, samples2)
        |> E.A.fmap(((x, y)) => Operation.Algebraic.toFn(algebraicOp, x, y));
      Some(combined);
    | _ => None
    };
  };

  /* Returns `t` untouched when `renderable(t)` holds; otherwise returns the
     result of rendering it with `params`. */
  let renderIfNotRendered = (params, t) =>
    renderable(t) ? Ok(t) : render(params, t);

  /* Sampling fallback used when no analytical solution exists.
     1. Renders both operands if needed (t1 first, then t2); a render error
        is propagated.
     2. Combines evaluationParams.samplingInputs.sampleCount samples of each
        operand pairwise (tryCombination).
     3. Builds a shape from the combined samples with Samples.T.fromSamples.
     Fails with "No response" when sampling was impossible or the sampler
     produced no shape.
     NOTE(review): the resulting RenderedDist is wrapped in `Normalize, so the
     returned node is not yet a leaf. It is presumably normalized when it is
     evaluated again — confirm. */
  let combineAsShapes =
      (evaluationParams: evaluationParams, algebraicOp, t1: node, t2: node) => {
    let rendered1 = renderIfNotRendered(evaluationParams, t1);
    let rendered2 = renderIfNotRendered(evaluationParams, t2);
    let combineRendered = ((a, b)) => {
      let samples =
        tryCombination(
          evaluationParams.samplingInputs.sampleCount,
          algebraicOp,
          a,
          b,
        );
      let samplesToShape =
        Samples.T.fromSamples(
          ~samplingInputs={
            sampleCount: Some(evaluationParams.samplingInputs.sampleCount),
            outputXYPoints:
              Some(evaluationParams.samplingInputs.outputXYPoints),
            kernelWidth: evaluationParams.samplingInputs.kernelWidth,
          },
        );
      samples
      |> E.O.fmap(samplesToShape)
      |> E.O.bind(_, (output: RenderTypes.ShapeRenderer.Sampling.outputs) =>
           output.shape
         )
      |> E.O.toResult("No response")
      |> E.R.fmap(shape => `Normalize(`RenderedDist(shape)));
    };
    E.R.merge(rendered1, rendered2) |> E.R.bind(_, combineRendered);
  };

  /* Evaluates an algebraic combination of t1 and t2.
     An analytical `SymbolicDist answer is returned directly. Any other
     successful simplification result (the unevaluated combination) falls back
     to sampling through combineAsShapes on the original operands. Errors from
     the simplification step are propagated. */
  let operationToLeaf =
      (
        evaluationParams: evaluationParams,
        algebraicOp: ExpressionTypes.algebraicOperation,
        t1: t,
        t2: t,
      )
      : result(node, string) =>
    switch (tryAnalyticalSimplification(algebraicOp, t1, t2)) {
    | Ok(`SymbolicDist(_) as solved) => Ok(solved)
    | Ok(_) => combineAsShapes(evaluationParams, algebraicOp, t1, t2)
    | Error(e) => Error(e)
    };
};
|
|
|
|
|
|
|
|
module VerticalScaling = {
  /* Scales the y-values of the distribution `t` by the constant `scaleBy`,
     using the function selected by `scaleOp` (Operation.Scale.toFn).

     - `scaleBy` has to be a single float (`SymbolicDist(`Float(_))); anything
       else yields Error("Can only scale by float values."). This check runs
       BEFORE `t` is rendered, so an invalid request no longer pays for a
       (potentially expensive) render that can only end in an error.
     - A render error for `t` is propagated unchanged.
     - If rendering succeeds but does not produce a `RenderedDist, the error
       names the render step instead of wrongly blaming `scaleBy`.
     - The shape's cached integral data are updated by passing the scale
       operation's cache functions (Operation.Scale.toIntegralSumCacheFn /
       toIntegralCacheFn) to Shape.T.mapY. */
  let operationToLeaf =
      (evaluationParams: evaluationParams, scaleOp, t, scaleBy) =>
    switch (scaleBy) {
    | `SymbolicDist(`Float(sm)) =>
      let fn = Operation.Scale.toFn(scaleOp);
      let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(scaleOp);
      let integralCacheFn = Operation.Scale.toIntegralCacheFn(scaleOp);
      switch (render(evaluationParams, t)) {
      | Ok(`RenderedDist(rs)) =>
        Ok(
          `RenderedDist(
            Shape.T.mapY(
              ~integralSumCacheFn=integralSumCacheFn(sm),
              ~integralCacheFn=integralCacheFn(sm),
              fn(sm),
              rs,
            ),
          ),
        )
      | Error(e1) => Error(e1)
      | Ok(_) => Error("Could not render distribution for scaling.")
      };
    | _ => Error("Can only scale by float values.")
    };
};
|
|
|
|
|
|
|
|
module PointwiseCombination = {
  /* Adds two distributions point by point.
     Both operands are rendered first. A render error is propagated, and t1's
     error is reported when both fail. Cached integral sums are added, and
     cached integrals are combined pointwise with `UseOutermostPoints
     extrapolation. */
  let pointwiseAdd = (evaluationParams: evaluationParams, t1: t, t2: t) => {
    let addIntegralSums = (a, b) => Some(a +. b);
    let addIntegrals = (a, b) =>
      Some(
        Continuous.combinePointwise(
          ~extrapolation=`UseOutermostPoints,
          (+.),
          a,
          b,
        ),
      );
    switch (render(evaluationParams, t1), render(evaluationParams, t2)) {
    | (Error(e1), _) => Error(e1)
    | (_, Error(e2)) => Error(e2)
    | (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) =>
      Ok(
        `RenderedDist(
          Shape.combinePointwise(
            ~integralSumCachesFn=addIntegralSums,
            ~integralCachesFn=addIntegrals,
            (+.),
            rs1,
            rs2,
          ),
        ),
      )
    | _ => Error("Pointwise combination: rendering failed.")
    };
  };

  /* Pointwise multiplication is not implemented yet; this always fails. */
  let pointwiseMultiply = (evaluationParams: evaluationParams, t1: t, t2: t) => {
    // TODO: construct a function that we can easily sample from, to construct
    // a RenderedDist. Use the xMin and xMax of the rendered shapes to tell the
    // sampling function where to look.
    Error("Pointwise multiplication not yet supported.");
  };

  /* Dispatches a pointwise operation to its implementation. */
  let operationToLeaf =
      (
        evaluationParams: evaluationParams,
        pointwiseOp: pointwiseOperation,
        t1: t,
        t2: t,
      ) => {
    let combine =
      switch (pointwiseOp) {
      | `Multiply => pointwiseMultiply
      | `Add => pointwiseAdd
      };
    combine(evaluationParams, t1, t2);
  };
};
|
|
|
|
|
|
|
|
module Truncate = {
  /* Tries to truncate `t` without rendering it.
     - No cutoffs at all: `t` is already the answer.
     - Both cutoffs given with left > right: an `Error.
     - A symbolic uniform distribution: truncated in closed form.
     - Anything else: `NoSolution, meaning the caller must render. */
  let trySimplification = (leftCutoff, rightCutoff, t): simplificationResult =>
    switch (leftCutoff, rightCutoff, t) {
    | (None, None, _) => `Solution(t)
    | (Some(lc), Some(rc), _) when lc > rc =>
      `Error("Left truncation bound must be smaller than right bound.")
    | (lc, rc, `SymbolicDist(`Uniform(u))) =>
      let truncated = SymbolicDist.Uniform.truncate(lc, rc, u);
      `Solution(`SymbolicDist(`Uniform(truncated)));
    | _ => `NoSolution
    };

  /* Renders `t` and truncates the resulting shape to the given cutoffs.
     A render error is propagated; any other non-rendered result is reported
     as a failure to truncate. */
  let truncateAsShape =
      (evaluationParams: evaluationParams, leftCutoff, rightCutoff, t) =>
    // TODO: use named args for xMin/xMax in renderToShape; if we're lucky we
    // can at least get the tail of a distribution we otherwise wouldn't get at all
    switch (render(evaluationParams, t)) {
    | Error(e) => Error(e)
    | Ok(`RenderedDist(rs)) =>
      Ok(`RenderedDist(Shape.T.truncate(leftCutoff, rightCutoff, rs)))
    | _ => Error("Could not truncate distribution.")
    };

  /* Truncates `t`, preferring the closed-form shortcut and falling back to
     rendering when no shortcut applies. */
  let operationToLeaf =
      (
        evaluationParams,
        leftCutoff: option(float),
        rightCutoff: option(float),
        t: node,
      )
      : result(node, string) =>
    switch (trySimplification(leftCutoff, rightCutoff, t)) {
    | `Solution(solved) => Ok(solved)
    | `Error(message) => Error(message)
    | `NoSolution =>
      truncateAsShape(evaluationParams, leftCutoff, rightCutoff, t)
    };
};
|
|
|
|
|
|
|
|
module Normalize = {
  /* Normalizes a distribution node.
     Symbolic distributions are returned unchanged, and rendered shapes go
     through Shape.T.normalize. Any other node is handed to evaluateAndRetry
     together with this function. evaluateAndRetry is defined elsewhere; it
     presumably evaluates the node and then retries — confirm. */
  let rec operationToLeaf = (evaluationParams, t: node): result(node, string) =>
    switch (t) {
    | `SymbolicDist(_) => Ok(t)
    | `RenderedDist(shape) => Ok(`RenderedDist(Shape.T.normalize(shape)))
    | _ => evaluateAndRetry(evaluationParams, operationToLeaf, t)
    };
};
|
|
|
|
|
|
|
|
module FloatFromDist = {
  /* Reduces a distribution to a single float (as `SymbolicDist(`Float(_)))
     using `distToFloatOp`.
     - Symbolic distributions use SymbolicDist.T.operate, which may fail.
     - Rendered shapes use Shape.operate.
     - Any other node is handed to evaluateAndRetry (defined elsewhere).
       The callback receives the evaluation params and re-enters this
       function with the same operation. */
  let rec operationToLeaf =
      (evaluationParams, distToFloatOp: distToFloatOperation, t: node)
      : result(node, string) =>
    switch (t) {
    | `SymbolicDist(dist) =>
      dist
      |> SymbolicDist.T.operate(distToFloatOp)
      |> E.R.fmap(value => `SymbolicDist(`Float(value)))
    | `RenderedDist(shape) =>
      let value = Shape.operate(distToFloatOp, shape);
      Ok(`SymbolicDist(`Float(value)));
    | _ =>
      evaluateAndRetry(
        evaluationParams,
        params => operationToLeaf(params, distToFloatOp),
        t,
      )
    };
};
|
|
|
|
|
|
|
|
module Render = {
  /* Turns a node into a `RenderedDist.
     An already-rendered shape is returned as-is. A symbolic distribution is
     converted with SymbolicDist.T.toShape, using
     evaluationParams.intendedShapeLength. Any other node is handed to
     evaluateAndRetry (defined elsewhere) together with this function. */
  let rec operationToLeaf =
      (evaluationParams: evaluationParams, t: node): result(t, string) =>
    switch (t) {
    | `RenderedDist(_) => Ok(t)
    | `SymbolicDist(d) =>
      let shape =
        SymbolicDist.T.toShape(evaluationParams.intendedShapeLength, d);
      Ok(`RenderedDist(shape));
    | _ => evaluateAndRetry(evaluationParams, operationToLeaf, t)
    };
};
|
|
|
|
|
2020-07-03 21:55:27 +00:00
|
|
|
/* This function recursively goes through the nodes of the parse tree,
   replacing each operation node and its subtree with a leaf node
   (a `SymbolicDist or a `RenderedDist).
   Whenever possible, the replacement is a new symbolic node,
   but most often it will be a RenderedDist.
   Some operations return a node that still needs evaluation. For example,
   algebraic combinations computed by sampling come back wrapped in
   `Normalize.
   This function is used mainly to turn a parse tree into a single RenderedDist
   that can then be displayed to the user. */
|
2020-07-08 10:39:03 +00:00
|
|
|
let toLeaf =
    (
      evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
      node: t,
    )
    : result(t, string) =>
  switch (node) {
  // Leaf nodes are already evaluated: return them unchanged.
  | `SymbolicDist(_)
  | `RenderedDist(_) => Ok(node)
  // Operation nodes need to be turned into leaves; each kind is delegated to
  // the module of the same name.
  | `Render(inner) => Render.operationToLeaf(evaluationParams, inner)
  | `Normalize(inner) => Normalize.operationToLeaf(evaluationParams, inner)
  | `FloatFromDist(distToFloatOp, inner) =>
    FloatFromDist.operationToLeaf(evaluationParams, distToFloatOp, inner)
  | `Truncate(leftCutoff, rightCutoff, inner) =>
    Truncate.operationToLeaf(evaluationParams, leftCutoff, rightCutoff, inner)
  | `VerticalScaling(scaleOp, inner, scaleBy) =>
    VerticalScaling.operationToLeaf(evaluationParams, scaleOp, inner, scaleBy)
  | `PointwiseCombination(pointwiseOp, left, right) =>
    PointwiseCombination.operationToLeaf(
      evaluationParams,
      pointwiseOp,
      left,
      right,
    )
  | `AlgebraicCombination(algebraicOp, left, right) =>
    AlgebraicCombination.operationToLeaf(
      evaluationParams,
      algebraicOp,
      left,
      right,
    )
  };
|