Moving operations functionality into new SymbolicTypes.re file

This commit is contained in:
Ozzie Gooen 2020-07-01 22:01:58 +01:00
parent acdd3dfe7a
commit baaff19750
5 changed files with 330 additions and 227 deletions

View File

@ -1,5 +1,3 @@
type algebraicOperation = [ | `Add | `Multiply | `Subtract | `Divide];
type pointMassesWithMoments = {
n: int,
masses: array(float),
@ -7,23 +5,6 @@ type pointMassesWithMoments = {
variances: array(float),
};
module Operation = {
type t = algebraicOperation;
let toFn: (t, float, float) => float =
fun
| `Add => (+.)
| `Subtract => (-.)
| `Multiply => ( *. )
| `Divide => (/.);
let toString =
fun
| `Add => " + "
| `Subtract => " - "
| `Multiply => " * "
| `Divide => " / ";
};
/* This function takes a continuous distribution and efficiently approximates it as
point masses that have variances associated with them.
We estimate the means and variances from overlapping triangular distributions which we imagine are making up the
@ -129,7 +110,7 @@ let toDiscretePointMassesFromTriangulars =
};
let combineShapesContinuousContinuous =
(op: algebraicOperation, s1: DistTypes.xyShape, s2: DistTypes.xyShape)
(op: SymbolicTypes.algebraicOperation, s1: DistTypes.xyShape, s2: DistTypes.xyShape)
: DistTypes.xyShape => {
let t1n = s1 |> XYShape.T.length;
let t2n = s2 |> XYShape.T.length;
@ -216,4 +197,4 @@ let combineShapesContinuousContinuous =
};
{xs: outputXs, ys: outputYs};
};
};

View File

@ -149,7 +149,7 @@ module Continuous = {
continuousShapes
|> E.A.fold_left(combinePointwise(~knownIntegralSumsFn, fn), empty);
let mapY = (~knownIntegralSumFn=(_ => None), fn, t: t) => {
let mapY = (~knownIntegralSumFn=_ => None, fn, t: t) => {
let u = E.O.bind(_, knownIntegralSumFn);
let yMapFn = shapeMap(XYShape.T.mapY(fn));
@ -164,7 +164,6 @@ module Continuous = {
);
};
module T =
Dist({
type t = DistTypes.continuousShape;
@ -194,9 +193,9 @@ module Continuous = {
|> getShape
|> XYShape.T.zip
|> XYShape.Zipped.filterByX(x =>
x >= E.O.default(neg_infinity, leftCutoff)
|| x <= E.O.default(infinity, rightCutoff)
);
x >= E.O.default(neg_infinity, leftCutoff)
|| x <= E.O.default(infinity, rightCutoff)
);
let eps = (t |> getShape |> XYShape.T.xTotalRange) *. 0.0001;
@ -206,7 +205,11 @@ module Continuous = {
rightCutoff |> E.O.dimap(rc => [|(rc +. eps, 0.)|], _ => [||]);
let truncatedZippedPairsWithNewPoints =
E.A.concatMany([|leftNewPoint, truncatedZippedPairs, rightNewPoint|]);
E.A.concatMany([|
leftNewPoint,
truncatedZippedPairs,
rightNewPoint,
|]);
let truncatedShape =
XYShape.T.fromZippedArray(truncatedZippedPairsWithNewPoints);
@ -214,22 +217,20 @@ module Continuous = {
};
// TODO: This should work with stepwise plots.
let integral = (~cache, t) => {
if ((t |> getShape |> XYShape.T.length) > 0) {
switch (cache) {
| Some(cache) => cache
| None =>
t
|> getShape
|> XYShape.Range.integrateWithTriangles
|> E.O.toExt("This should not have happened")
|> make(`Linear, _, None)
};
let integral = (~cache, t) =>
if (t |> getShape |> XYShape.T.length > 0) {
switch (cache) {
| Some(cache) => cache
| None =>
t
|> getShape
|> XYShape.Range.integrateWithTriangles
|> E.O.toExt("This should not have happened")
|> make(`Linear, _, None)
};
} else {
make(`Linear, {xs: [|neg_infinity|], ys: [|0.0|]}, None);
}
};
};
let downsample = (~cache=None, length, t): t =>
t
@ -276,23 +277,31 @@ module Continuous = {
);
});
/* This simply creates multiple copies of the continuous distribution, scaled and shifted according to
each discrete data point, and then adds them all together. */
let combineAlgebraicallyWithDiscrete = (~downsample=false, op: AlgebraicCombinations.algebraicOperation, t1: t, t2: DistTypes.discreteShape) => {
let combineAlgebraicallyWithDiscrete =
(
~downsample=false,
op: SymbolicTypes.algebraicOperation,
t1: t,
t2: DistTypes.discreteShape,
) => {
let t1s = t1 |> getShape;
let t2s = t2.xyShape; // would like to use Discrete.getShape here, but current file structure doesn't allow for that
let t1n = t1s |> XYShape.T.length;
let t2n = t2s |> XYShape.T.length;
let fn = AlgebraicCombinations.Operation.toFn(op);
let fn = SymbolicTypes.Algebraic.toFn(op);
let outXYShapes: array(array((float, float))) =
Belt.Array.makeUninitializedUnsafe(t2n);
for (j in 0 to t2n - 1) { // for each one of the discrete points
for (j in 0 to t2n - 1) {
// for each one of the discrete points
// create a new distribution, as long as the original continuous one
let dxyShape: array((float, float)) = Belt.Array.makeUninitializedUnsafe(t1n);
let dxyShape: array((float, float)) =
Belt.Array.makeUninitializedUnsafe(t1n);
for (i in 0 to t1n - 1) {
let _ =
Belt.Array.set(
@ -307,7 +316,12 @@ module Continuous = {
();
};
let combinedIntegralSum = Common.combineIntegralSums((a, b) => Some(a *. b), t1.knownIntegralSum, t2.knownIntegralSum);
let combinedIntegralSum =
Common.combineIntegralSums(
(a, b) => Some(a *. b),
t1.knownIntegralSum,
t2.knownIntegralSum,
);
outXYShapes
|> E.A.fmap(s => {
@ -318,7 +332,13 @@ module Continuous = {
|> updateKnownIntegralSum(combinedIntegralSum);
};
let combineAlgebraically = (~downsample=false, op: AlgebraicCombinations.algebraicOperation, t1: t, t2: t) => {
let combineAlgebraically =
(
~downsample=false,
op: SymbolicTypes.algebraicOperation,
t1: t,
t2: t,
) => {
let s1 = t1 |> getShape;
let s2 = t2 |> getShape;
let t1n = s1 |> XYShape.T.length;
@ -326,8 +346,14 @@ module Continuous = {
if (t1n == 0 || t2n == 0) {
empty;
} else {
let combinedShape = AlgebraicCombinations.combineShapesContinuousContinuous(op, s1, s2);
let combinedIntegralSum = Common.combineIntegralSums((a, b) => Some(a *. b), t1.knownIntegralSum, t2.knownIntegralSum);
let combinedShape =
AlgebraicCombinations.combineShapesContinuousContinuous(op, s1, s2);
let combinedIntegralSum =
Common.combineIntegralSums(
(a, b) => Some(a *. b),
t1.knownIntegralSum,
t2.knownIntegralSum,
);
// return a new Continuous distribution
make(`Linear, combinedShape, combinedIntegralSum);
};
@ -370,7 +396,7 @@ module Discrete = {
XYShape.PointwiseCombination.combine(
~xsSelection=ALL_XS,
~xToYSelection=XYShape.XtoY.stepwiseIfAtX,
~fn=((a, b) => fn(E.O.default(0.0, a), E.O.default(0.0, b))), // stepwiseIfAtX returns option(float), so this fn needs to handle None
~fn=(a, b) => fn(E.O.default(0.0, a), E.O.default(0.0, b)), // stepwiseIfAtX returns option(float), so this fn needs to handle None
t1.xyShape,
t2.xyShape,
),
@ -378,7 +404,9 @@ module Discrete = {
);
};
let reduce = (~knownIntegralSumsFn=(_, _) => None, fn, discreteShapes): DistTypes.discreteShape =>
let reduce =
(~knownIntegralSumsFn=(_, _) => None, fn, discreteShapes)
: DistTypes.discreteShape =>
discreteShapes
|> E.A.fold_left(combinePointwise(~knownIntegralSumsFn, fn), empty);
@ -389,7 +417,8 @@ module Discrete = {
/* This multiples all of the data points together and creates a new discrete distribution from the results.
Data points at the same xs get added together. It may be a good idea to downsample t1 and t2 before and/or the result after. */
let combineAlgebraically = (op: AlgebraicCombinations.algebraicOperation, t1: t, t2: t) => {
let combineAlgebraically =
(op: SymbolicTypes.algebraicOperation, t1: t, t2: t) => {
let t1s = t1 |> getShape;
let t2s = t2 |> getShape;
let t1n = t1s |> XYShape.T.length;
@ -402,7 +431,7 @@ module Discrete = {
t2.knownIntegralSum,
);
let fn = AlgebraicCombinations.Operation.toFn(op);
let fn = SymbolicTypes.Algebraic.toFn(op);
let xToYMap = E.FloatFloatMap.empty();
for (i in 0 to t1n - 1) {
@ -441,8 +470,8 @@ module Discrete = {
Dist({
type t = DistTypes.discreteShape;
type integral = DistTypes.continuousShape;
let integral = (~cache, t) => {
if ((t |> getShape |> XYShape.T.length) > 0) {
let integral = (~cache, t) =>
if (t |> getShape |> XYShape.T.length > 0) {
switch (cache) {
| Some(c) => c
| None =>
@ -453,9 +482,13 @@ module Discrete = {
)
};
} else {
Continuous.make(`Stepwise, {xs: [|neg_infinity|], ys: [|0.0|]}, None);
}};
Continuous.make(
`Stepwise,
{xs: [|neg_infinity|], ys: [|0.0|]},
None,
);
};
let integralEndY = (~cache, t: t) =>
t.knownIntegralSum
|> E.O.default(t |> integral(~cache) |> Continuous.lastY);
@ -495,7 +528,7 @@ module Discrete = {
make(clippedShape, None); // if someone needs the sum, they'll have to recompute it
} else {
t;
}
};
};
let truncate =
@ -505,9 +538,9 @@ module Discrete = {
|> getShape
|> XYShape.T.zip
|> XYShape.Zipped.filterByX(x =>
x >= E.O.default(neg_infinity, leftCutoff)
|| x <= E.O.default(infinity, rightCutoff)
)
x >= E.O.default(neg_infinity, leftCutoff)
|| x <= E.O.default(infinity, rightCutoff)
)
|> XYShape.T.fromZippedArray;
make(truncatedShape, None);
@ -601,8 +634,10 @@ module Mixed = {
rightCutoff: option(float),
{discrete, continuous}: t,
) => {
let truncatedContinuous = Continuous.T.truncate(leftCutoff, rightCutoff, continuous);
let truncatedDiscrete = Discrete.T.truncate(leftCutoff, rightCutoff, discrete);
let truncatedContinuous =
Continuous.T.truncate(leftCutoff, rightCutoff, continuous);
let truncatedDiscrete =
Discrete.T.truncate(leftCutoff, rightCutoff, discrete);
make(~discrete=truncatedDiscrete, ~continuous=truncatedContinuous);
};
@ -809,7 +844,14 @@ module Mixed = {
};
});
let combineAlgebraically = (~downsample=false, op: AlgebraicCombinations.algebraicOperation, t1: t, t2: t): t => {
let combineAlgebraically =
(
~downsample=false,
op: SymbolicTypes.algebraicOperation,
t1: t,
t2: t,
)
: t => {
// Discrete convolution can cause a huge increase in the number of samples,
// so we'll first downsample.
@ -827,11 +869,26 @@ module Mixed = {
// continuous (*) continuous => continuous, but also
// discrete (*) continuous => continuous (and vice versa). We have to take care of all combos and then combine them:
let ccConvResult =
Continuous.combineAlgebraically(~downsample=false, op, t1d.continuous, t2d.continuous);
Continuous.combineAlgebraically(
~downsample=false,
op,
t1d.continuous,
t2d.continuous,
);
let dcConvResult =
Continuous.combineAlgebraicallyWithDiscrete(~downsample=false, op, t2d.continuous, t1d.discrete);
Continuous.combineAlgebraicallyWithDiscrete(
~downsample=false,
op,
t2d.continuous,
t1d.discrete,
);
let cdConvResult =
Continuous.combineAlgebraicallyWithDiscrete(~downsample=false, op, t1d.continuous, t2d.discrete);
Continuous.combineAlgebraicallyWithDiscrete(
~downsample=false,
op,
t1d.continuous,
t2d.discrete,
);
let continuousConvResult =
Continuous.reduce((+.), [|ccConvResult, dcConvResult, cdConvResult|]);
@ -866,23 +923,47 @@ module Shape = {
c => Mixed.make(~discrete=Discrete.empty, ~continuous=c),
));
let combineAlgebraically = (op: AlgebraicCombinations.algebraicOperation, t1: t, t2: t): t => {
switch ((t1, t2)) {
| (Continuous(m1), Continuous(m2)) => DistTypes.Continuous(Continuous.combineAlgebraically(~downsample=true, op, m1, m2))
| (Discrete(m1), Discrete(m2)) => DistTypes.Discrete(Discrete.combineAlgebraically(op, m1, m2))
| (m1, m2) => {
DistTypes.Mixed(Mixed.combineAlgebraically(~downsample=true, op, toMixed(m1), toMixed(m2)))
}
let combineAlgebraically =
(op: SymbolicTypes.algebraicOperation, t1: t, t2: t): t => {
switch (t1, t2) {
| (Continuous(m1), Continuous(m2)) =>
DistTypes.Continuous(
Continuous.combineAlgebraically(~downsample=true, op, m1, m2),
)
| (Discrete(m1), Discrete(m2)) =>
DistTypes.Discrete(Discrete.combineAlgebraically(op, m1, m2))
| (m1, m2) =>
DistTypes.Mixed(
Mixed.combineAlgebraically(
~downsample=true,
op,
toMixed(m1),
toMixed(m2),
),
)
};
};
let combinePointwise = (~knownIntegralSumsFn=(_, _) => None, fn, t1: t, t2: t) =>
switch ((t1, t2)) {
| (Continuous(m1), Continuous(m2)) => DistTypes.Continuous(Continuous.combinePointwise(~knownIntegralSumsFn, fn, m1, m2))
| (Discrete(m1), Discrete(m2)) => DistTypes.Discrete(Discrete.combinePointwise(~knownIntegralSumsFn, fn, m1, m2))
| (m1, m2) => {
DistTypes.Mixed(Mixed.combinePointwise(~knownIntegralSumsFn, fn, toMixed(m1), toMixed(m2)))
}
let combinePointwise =
(~knownIntegralSumsFn=(_, _) => None, fn, t1: t, t2: t) =>
switch (t1, t2) {
| (Continuous(m1), Continuous(m2)) =>
DistTypes.Continuous(
Continuous.combinePointwise(~knownIntegralSumsFn, fn, m1, m2),
)
| (Discrete(m1), Discrete(m2)) =>
DistTypes.Discrete(
Discrete.combinePointwise(~knownIntegralSumsFn, fn, m1, m2),
)
| (m1, m2) =>
DistTypes.Mixed(
Mixed.combinePointwise(
~knownIntegralSumsFn,
fn,
toMixed(m1),
toMixed(m2),
),
)
};
// TODO: implement these functions
@ -915,7 +996,6 @@ module Shape = {
let toContinuous = t => None;
let toDiscrete = t => None;
let downsample = (~cache=None, i, t) =>
fmap(
(
@ -938,7 +1018,11 @@ module Shape = {
let toDiscreteProbabilityMassFraction = t => 0.0;
let normalize =
fmap((Mixed.T.normalize, Discrete.T.normalize, Continuous.T.normalize));
fmap((
Mixed.T.normalize,
Discrete.T.normalize,
Continuous.T.normalize,
));
let toContinuous =
mapToAll((
Mixed.T.toContinuous,
@ -1089,7 +1173,8 @@ module DistPlus = {
};
let truncate = (leftCutoff, rightCutoff, t: t): t => {
let truncatedShape = t |> toShape |> Shape.T.truncate(leftCutoff, rightCutoff);
let truncatedShape =
t |> toShape |> Shape.T.truncate(leftCutoff, rightCutoff);
t |> updateShape(truncatedShape);
};
@ -1153,9 +1238,9 @@ module DistPlus = {
let integralYtoX = (~cache as _, f, t: t) => {
Shape.T.Integral.yToX(~cache=Some(t.integralCache), f, toShape(t));
};
let mean = (t: t) => {
Shape.T.mean(t.shape);
};
let mean = (t: t) => {
Shape.T.mean(t.shape);
};
let variance = (t: t) => Shape.T.variance(t.shape);
});
};

View File

@ -123,6 +123,12 @@ module Normal = {
let stdev = 1. /. (1. /. n1.stdev ** 2. +. 1. /. n2.stdev ** 2.);
`Normal({mean, stdev});
};
let operate = (operation: SymbolicTypes.Algebraic.t, n1: t, n2: t) => switch(operation){
| `Add => Some(add(n1, n2))
| `Subtract => Some(subtract(n1, n2))
| _ => None
}
};
module Beta = {
@ -171,6 +177,11 @@ module Lognormal = {
let sigma = l1.sigma +. l2.sigma;
`Lognormal({mu, sigma});
};
let operate = (operation: SymbolicTypes.Algebraic.t, n1: t, n2: t) => switch(operation){
| `Multiply => Some(multiply(n1, n2))
| `Divide => Some(divide(n1, n2))
| _ => None
}
};
module Uniform = {

View File

@ -0,0 +1,67 @@
type pointwiseOperation = [ | `Add | `Multiply];
type scaleOperation = [ | `Multiply | `Exponentiate | `Log];
type distToFloatOperation = [ | `Pdf(float) | `Inv(float) | `Mean | `Sample];
type algebraicOperation = [ | `Add | `Multiply | `Subtract | `Divide];
module Algebraic = {
type t = algebraicOperation;
let toFn: (t, float, float) => float =
fun
| `Add => (+.)
| `Subtract => (-.)
| `Multiply => ( *. )
| `Divide => (/.);
let applyFn = (t, f1, f2) => {
switch (t, f1, f2) {
| (`Divide, _, 0.) => Error("Cannot divide $v1 by zero.")
| _ => Ok(toFn(t, f1, f2))
};
};
let toString =
fun
| `Add => "+"
| `Subtract => "-"
| `Multiply => "*"
| `Divide => "/";
let format = (a, b, c) => b ++ " " ++ toString(a) ++ " " ++ c;
};
module Pointwise = {
type t = pointwiseOperation;
let toString =
fun
| `Add => "+"
| `Multiply => "*";
let format = (a, b, c) => b ++ " " ++ toString(a) ++ " " ++ c;
};
module DistToFloat = {
type t = distToFloatOperation;
let stringFromFloatFromDistOperation =
fun
| `Pdf(f) => {j|pdf(x=$f, |j}
| `Inv(f) => {j|inv(x=$f, |j}
| `Sample => "sample("
| `Mean => "mean(";
let format = (a, b) => stringFromFloatFromDistOperation(a) ++ b ++ ")";
};
module Scale = {
type t = scaleOperation;
let toFn =
fun
| `Multiply => ( *. )
| `Exponentiate => ( ** )
| `Log => ((a, b) => log(a) /. log(b));
let toKnownIntegralSumFn =
fun
| `Multiply => ((a, b) => Some(a *. b))
| `Exponentiate => ((_, _) => None)
| `Log => ((_, _) => None);
}

View File

@ -1,86 +1,41 @@
/* This module represents a tree node. */
open SymbolicTypes;
// todo: Symbolic already has an arbitrary continuousShape option. It seems messy to have both.
type distData = [
| `Symbolic(SymbolicDist.dist)
| `RenderedShape(DistTypes.shape)
];
type pointwiseOperation = [ | `Add | `Multiply];
type scaleOperation = [ | `Multiply | `Exponentiate | `Log];
type distToFloatOperation = [ | `Pdf(float) | `Inv(float) | `Mean | `Sample];
/* TreeNodes are either Data (i.e. symbolic or rendered distributions) or Operations. */
type treeNode = [
| `DistData(distData) // a leaf node that describes a distribution
| `Operation(operation) // an operation on two child nodes
]
/* TreeNodes are either Data (i.e. symbolic or rendered distributions) or Operations. Operations always refer to two child nodes.*/
type treeNode = [ | `DistData(distData) | `Operation(operation)]
and operation = [
| // binary operations
`AlgebraicCombination(
AlgebraicCombinations.algebraicOperation,
treeNode,
treeNode,
)
// unary operations
| `PointwiseCombination(pointwiseOperation, treeNode, treeNode) // always evaluates to `DistData(`RenderedShape(...))
| `VerticalScaling(scaleOperation, treeNode, treeNode) // always evaluates to `DistData(`RenderedShape(...))
| `Render(treeNode) // always evaluates to `DistData(`RenderedShape(...))
| `Truncate // always evaluates to `DistData(`RenderedShape(...))
(
option(float),
option(float),
treeNode,
) // leftCutoff and rightCutoff
| `Normalize // always evaluates to `DistData(`RenderedShape(...))
// leftCutoff and rightCutoff
(
treeNode,
)
| `FloatFromDist // always evaluates to `DistData(`RenderedShape(...))
// leftCutoff and rightCutoff
(
distToFloatOperation,
treeNode,
)
| `AlgebraicCombination(algebraicOperation, treeNode, treeNode)
| `PointwiseCombination(pointwiseOperation, treeNode, treeNode)
| `VerticalScaling(scaleOperation, treeNode, treeNode)
| `Render(treeNode)
| `Truncate(option(float), option(float), treeNode)
| `Normalize(treeNode)
| `FloatFromDist(distToFloatOperation, treeNode)
];
module TreeNode = {
type t = treeNode;
type tResult = treeNode => result(treeNode, string);
let rec toString = (t: t): string => {
let stringFromAlgebraicCombination =
fun
| `Add => " + "
| `Subtract => " - "
| `Multiply => " * "
| `Divide => " / "
let stringFromPointwiseCombination =
fun
| `Add => " .+ "
| `Multiply => " .* ";
let stringFromFloatFromDistOperation =
fun
| `Pdf(f) => {j|pdf(x=$f, |j}
| `Inv(f) => {j|inv(x=$f, |j}
| `Sample => "sample("
| `Mean => "mean(";
switch (t) {
let rec toString =
fun
| `DistData(`Symbolic(d)) =>
SymbolicDist.GenericDistFunctions.toString(d)
| `DistData(`RenderedShape(_)) => "[shape]"
| `Operation(`AlgebraicCombination(op, t1, t2)) =>
toString(t1) ++ stringFromAlgebraicCombination(op) ++ toString(t2)
SymbolicTypes.Algebraic.format(op, toString(t1), toString(t2))
| `Operation(`PointwiseCombination(op, t1, t2)) =>
toString(t1) ++ stringFromPointwiseCombination(op) ++ toString(t2)
SymbolicTypes.Pointwise.format(op, toString(t1), toString(t2))
| `Operation(`VerticalScaling(_scaleOp, t, scaleBy)) =>
toString(t) ++ " @ " ++ toString(scaleBy)
| `Operation(`Normalize(t)) => "normalize(" ++ toString(t) ++ ")"
| `Operation(`FloatFromDist(floatFromDistOp, t)) => stringFromFloatFromDistOperation(floatFromDistOp) ++ toString(t) ++ ")"
| `Operation(`FloatFromDist(floatFromDistOp, t)) =>
SymbolicTypes.DistToFloat.format(floatFromDistOp, toString(t))
| `Operation(`Truncate(lc, rc, t)) =>
"truncate("
++ toString(t)
@ -89,9 +44,7 @@ module TreeNode = {
++ ", "
++ E.O.dimap(Js.Float.toString, () => "inf", rc)
++ ")"
| `Operation(`Render(t)) => toString(t)
};
};
| `Operation(`Render(t)) => toString(t);
/* The following modules encapsulate everything we can do with
* different kinds of operations. */
@ -104,88 +57,72 @@ module TreeNode = {
let simplify = (algebraicOp, t1: t, t2: t): result(treeNode, string) => {
let tryCombiningFloats: tResult =
fun
| `Operation(
`AlgebraicCombination(
`Divide,
`DistData(`Symbolic(`Float(_))),
`DistData(`Symbolic(`Float(0.))),
),
) =>
Error("Cannot divide $v1 by zero.")
| `Operation(
`AlgebraicCombination(
algebraicOp,
`DistData(`Symbolic(`Float(v1))),
`DistData(`Symbolic(`Float(v2))),
),
) => {
let func = AlgebraicCombinations.Operation.toFn(algebraicOp);
Ok(`DistData(`Symbolic(`Float(func(v1, v2)))));
}
) =>
SymbolicTypes.Algebraic.applyFn(algebraicOp, v1, v2)
|> E.R.fmap(r => `DistData(`Symbolic(`Float(r))))
| t => Ok(t);
let optionToSymbolicResult = (t, o) =>
o
|> E.O.dimap(r => `DistData(`Symbolic(r)), () => t)
|> (r => Ok(r));
let tryCombiningNormals: tResult =
fun
| `Operation(
`AlgebraicCombination(
`Add,
operation,
`DistData(`Symbolic(`Normal(n1))),
`DistData(`Symbolic(`Normal(n2))),
),
) =>
Ok(`DistData(`Symbolic(SymbolicDist.Normal.add(n1, n2))))
| `Operation(
`AlgebraicCombination(
`Subtract,
`DistData(`Symbolic(`Normal(n1))),
`DistData(`Symbolic(`Normal(n2))),
),
) =>
Ok(`DistData(`Symbolic(SymbolicDist.Normal.subtract(n1, n2))))
) as t =>
SymbolicDist.Normal.operate(operation, n1, n2)
|> optionToSymbolicResult(t)
| t => Ok(t);
let tryCombiningLognormals: tResult =
fun
| `Operation(
`AlgebraicCombination(
`Multiply,
`DistData(`Symbolic(`Lognormal(l1))),
`DistData(`Symbolic(`Lognormal(l2))),
operation,
`DistData(`Symbolic(`Lognormal(n1))),
`DistData(`Symbolic(`Lognormal(n2))),
),
) =>
Ok(`DistData(`Symbolic(SymbolicDist.Lognormal.multiply(l1, l2))))
| `Operation(
`AlgebraicCombination(
`Divide,
`DistData(`Symbolic(`Lognormal(l1))),
`DistData(`Symbolic(`Lognormal(l2))),
),
) =>
Ok(`DistData(`Symbolic(SymbolicDist.Lognormal.divide(l1, l2))))
) as t =>
SymbolicDist.Lognormal.operate(operation, n1, n2)
|> optionToSymbolicResult(t)
| t => Ok(t);
let originalTreeNode =
`Operation(`AlgebraicCombination((algebraicOp, t1, t2)));
`Operation(`AlgebraicCombination((algebraicOp, t1, t2)));
// Feedback: I like this pattern, kudos
originalTreeNode
|> tryCombiningFloats
|> E.R.bind(_, tryCombiningNormals)
|> E.R.bind(_, tryCombiningLognormals);
};
// todo: I don't like the name evaluateNumerically that much, if this renders and does it algebraically. It's tricky.
let evaluateNumerically = (algebraicOp, operationToDistData, t1, t2) => {
// force rendering into shapes
let renderedShape1 = operationToDistData(`Render(t1));
let renderedShape2 = operationToDistData(`Render(t2));
switch (renderedShape1, renderedShape2) {
let renderShape = r => operationToDistData(`Render(r));
switch (renderShape(t1), renderShape(t2)) {
| (
Ok(`DistData(`RenderedShape(s1))),
Ok(`DistData(`RenderedShape(s2))),
) =>
Ok(
`DistData(
`RenderedShape(Distributions.Shape.combineAlgebraically(algebraicOp, s1, s2)),
`RenderedShape(
Distributions.Shape.combineAlgebraically(algebraicOp, s1, s2),
),
),
)
| (Error(e1), _) => Error(e1)
@ -195,7 +132,12 @@ module TreeNode = {
};
let evaluateToDistData =
(algebraicOp: AlgebraicCombinations.algebraicOperation, operationToDistData, t1: t, t2: t)
(
algebraicOp: SymbolicTypes.algebraicOperation,
operationToDistData,
t1: t,
t2: t,
)
: result(treeNode, string) =>
algebraicOp
|> simplify(_, t1, t2)
@ -210,27 +152,13 @@ module TreeNode = {
};
module VerticalScaling = {
let fnFromOp =
fun
| `Multiply => ( *. )
| `Exponentiate => ( ** )
| `Log => ((a, b) => log(a) /. log(b));
let knownIntegralSumFnFromOp =
fun
| `Multiply => ((a, b) => Some(a *. b))
| `Exponentiate => ((_, _) => None)
| `Log => ((_, _) => None);
let evaluateToDistData = (scaleOp, operationToDistData, t, scaleBy) => {
// scaleBy has to be a single float, otherwise we'll return an error.
let fn = fnFromOp(scaleOp);
let knownIntegralSumFn = knownIntegralSumFnFromOp(scaleOp);
let fn = SymbolicTypes.Scale.toFn(scaleOp);
let knownIntegralSumFn = SymbolicTypes.Scale.toKnownIntegralSumFn(scaleOp);
let renderedShape = operationToDistData(`Render(t));
switch (renderedShape, scaleBy) {
| (Error(e1), _) => Error(e1)
| (
Ok(`DistData(`RenderedShape(rs))),
`DistData(`Symbolic(`Float(sm))),
@ -246,6 +174,7 @@ module TreeNode = {
),
),
)
| (Error(e1), _) => Error(e1)
| (_, _) => Error("Can only scale by float values.")
};
};
@ -253,14 +182,28 @@ module TreeNode = {
module PointwiseCombination = {
let pointwiseAdd = (operationToDistData, t1, t2) => {
let renderedShape1 = operationToDistData(`Render(t1));
let renderedShape2 = operationToDistData(`Render(t2));
let renderedShape1 = operationToDistData(`Render(t1));
let renderedShape2 = operationToDistData(`Render(t2));
switch ((renderedShape1, renderedShape2)) {
switch (renderedShape1, renderedShape2) {
| (
Ok(`DistData(`RenderedShape(rs1))),
Ok(`DistData(`RenderedShape(rs2))),
) =>
Ok(
`DistData(
`RenderedShape(
Distributions.Shape.combinePointwise(
~knownIntegralSumsFn=(a, b) => Some(a +. b),
(+.),
rs1,
rs2,
),
),
),
)
| (Error(e1), _) => Error(e1)
| (_, Error(e2)) => Error(e2)
| (Ok(`DistData(`RenderedShape(rs1))), Ok(`DistData(`RenderedShape(rs2)))) =>
Ok(`DistData(`RenderedShape(Distributions.Shape.combinePointwise(~knownIntegralSumsFn=(a, b) => Some(a +. b), (+.), rs1, rs2))))
| _ => Error("Could not perform pointwise addition.")
};
};
@ -268,14 +211,16 @@ module TreeNode = {
let pointwiseMultiply = (operationToDistData, t1, t2) => {
// TODO: construct a function that we can easily sample from, to construct
// a RenderedShape. Use the xMin and xMax of the rendered shapes to tell the sampling function where to look.
Error("Pointwise multiplication not yet supported.");
Error(
"Pointwise multiplication not yet supported.",
);
};
let evaluateToDistData = (pointwiseOp, operationToDistData, t1, t2) => {
switch (pointwiseOp) {
| `Add => pointwiseAdd(operationToDistData, t1, t2)
| `Multiply => pointwiseMultiply(operationToDistData, t1, t2)
}
};
};
};
@ -378,7 +323,9 @@ module TreeNode = {
};
E.R.bind(value, v => Ok(`DistData(`Symbolic(`Float(v)))));
};
let evaluateFromRenderedShape = (distToFloatOp: distToFloatOperation, rs: DistTypes.shape) : result(treeNode, string) => {
let evaluateFromRenderedShape =
(distToFloatOp: distToFloatOperation, rs: DistTypes.shape)
: result(treeNode, string) => {
let value =
switch (distToFloatOp) {
| `Pdf(f) => Ok(Distributions.Shape.pdf(f, rs))
@ -410,8 +357,12 @@ module TreeNode = {
module Render = {
let rec evaluateToRenderedShape =
(operationToDistData: operation => result(t, string), sampleCount: int, t: treeNode)
: result(t, string) => {
(
operationToDistData: operation => result(t, string),
sampleCount: int,
t: treeNode,
)
: result(t, string) => {
switch (t) {
| `DistData(`RenderedShape(s)) => Ok(`DistData(`RenderedShape(s))) // already a rendered shape, we're done here
| `DistData(`Symbolic(d)) =>
@ -495,10 +446,19 @@ module TreeNode = {
t,
)
| `FloatFromDist(distToFloatOp, t) =>
FloatFromDist.evaluateToDistData(distToFloatOp, operationToDistData(sampleCount), t)
| `Normalize(t) => Normalize.evaluateToDistData(operationToDistData(sampleCount), t)
FloatFromDist.evaluateToDistData(
distToFloatOp,
operationToDistData(sampleCount),
t,
)
| `Normalize(t) =>
Normalize.evaluateToDistData(operationToDistData(sampleCount), t)
| `Render(t) =>
Render.evaluateToRenderedShape(operationToDistData(sampleCount), sampleCount, t)
Render.evaluateToRenderedShape(
operationToDistData(sampleCount),
sampleCount,
t,
)
};
};
@ -531,5 +491,4 @@ let toShape = (sampleCount: int, treeNode: treeNode) => {
};
};
let toString = (treeNode: treeNode) =>
TreeNode.toString(treeNode);
let toString = (treeNode: treeNode) => TreeNode.toString(treeNode);