Moving operations functionality into new SymbolicTypes.re file

This commit is contained in:
Ozzie Gooen 2020-07-01 22:01:58 +01:00
parent acdd3dfe7a
commit baaff19750
5 changed files with 330 additions and 227 deletions

View File

@ -1,5 +1,3 @@
/* Tags for the four binary arithmetic operations; this is the type of the `op`
   argument accepted by combineShapesContinuousContinuous. */
type algebraicOperation = [ | `Add | `Multiply | `Subtract | `Divide];
type pointMassesWithMoments = { type pointMassesWithMoments = {
n: int, n: int,
masses: array(float), masses: array(float),
@ -7,23 +5,6 @@ type pointMassesWithMoments = {
variances: array(float), variances: array(float),
}; };
/* Helpers that turn an algebraicOperation tag into the float operator it
   denotes, or into the infix text used when printing it. */
module Operation = {
  type t = algebraicOperation;

  /* Returns the float operator that corresponds to the given tag. */
  let toFn: (t, float, float) => float =
    op =>
      switch (op) {
      | `Add => (+.)
      | `Subtract => (-.)
      | `Multiply => ( *. )
      | `Divide => (/.)
      };

  /* Returns the infix symbol for the tag, padded with one space on each side. */
  let toString = op =>
    switch (op) {
    | `Add => " + "
    | `Subtract => " - "
    | `Multiply => " * "
    | `Divide => " / "
    };
};
/* This function takes a continuous distribution and efficiently approximates it as /* This function takes a continuous distribution and efficiently approximates it as
point masses that have variances associated with them. point masses that have variances associated with them.
We estimate the means and variances from overlapping triangular distributions which we imagine are making up the We estimate the means and variances from overlapping triangular distributions which we imagine are making up the
@ -129,7 +110,7 @@ let toDiscretePointMassesFromTriangulars =
}; };
let combineShapesContinuousContinuous = let combineShapesContinuousContinuous =
(op: algebraicOperation, s1: DistTypes.xyShape, s2: DistTypes.xyShape) (op: SymbolicTypes.algebraicOperation, s1: DistTypes.xyShape, s2: DistTypes.xyShape)
: DistTypes.xyShape => { : DistTypes.xyShape => {
let t1n = s1 |> XYShape.T.length; let t1n = s1 |> XYShape.T.length;
let t2n = s2 |> XYShape.T.length; let t2n = s2 |> XYShape.T.length;
@ -216,4 +197,4 @@ let combineShapesContinuousContinuous =
}; };
{xs: outputXs, ys: outputYs}; {xs: outputXs, ys: outputYs};
}; };

View File

@ -149,7 +149,7 @@ module Continuous = {
continuousShapes continuousShapes
|> E.A.fold_left(combinePointwise(~knownIntegralSumsFn, fn), empty); |> E.A.fold_left(combinePointwise(~knownIntegralSumsFn, fn), empty);
let mapY = (~knownIntegralSumFn=(_ => None), fn, t: t) => { let mapY = (~knownIntegralSumFn=_ => None, fn, t: t) => {
let u = E.O.bind(_, knownIntegralSumFn); let u = E.O.bind(_, knownIntegralSumFn);
let yMapFn = shapeMap(XYShape.T.mapY(fn)); let yMapFn = shapeMap(XYShape.T.mapY(fn));
@ -164,7 +164,6 @@ module Continuous = {
); );
}; };
module T = module T =
Dist({ Dist({
type t = DistTypes.continuousShape; type t = DistTypes.continuousShape;
@ -194,9 +193,9 @@ module Continuous = {
|> getShape |> getShape
|> XYShape.T.zip |> XYShape.T.zip
|> XYShape.Zipped.filterByX(x => |> XYShape.Zipped.filterByX(x =>
x >= E.O.default(neg_infinity, leftCutoff) x >= E.O.default(neg_infinity, leftCutoff)
|| x <= E.O.default(infinity, rightCutoff) || x <= E.O.default(infinity, rightCutoff)
); );
let eps = (t |> getShape |> XYShape.T.xTotalRange) *. 0.0001; let eps = (t |> getShape |> XYShape.T.xTotalRange) *. 0.0001;
@ -206,7 +205,11 @@ module Continuous = {
rightCutoff |> E.O.dimap(rc => [|(rc +. eps, 0.)|], _ => [||]); rightCutoff |> E.O.dimap(rc => [|(rc +. eps, 0.)|], _ => [||]);
let truncatedZippedPairsWithNewPoints = let truncatedZippedPairsWithNewPoints =
E.A.concatMany([|leftNewPoint, truncatedZippedPairs, rightNewPoint|]); E.A.concatMany([|
leftNewPoint,
truncatedZippedPairs,
rightNewPoint,
|]);
let truncatedShape = let truncatedShape =
XYShape.T.fromZippedArray(truncatedZippedPairsWithNewPoints); XYShape.T.fromZippedArray(truncatedZippedPairsWithNewPoints);
@ -214,22 +217,20 @@ module Continuous = {
}; };
// TODO: This should work with stepwise plots. // TODO: This should work with stepwise plots.
let integral = (~cache, t) => { let integral = (~cache, t) =>
if (t |> getShape |> XYShape.T.length > 0) {
if ((t |> getShape |> XYShape.T.length) > 0) { switch (cache) {
switch (cache) { | Some(cache) => cache
| Some(cache) => cache | None =>
| None => t
t |> getShape
|> getShape |> XYShape.Range.integrateWithTriangles
|> XYShape.Range.integrateWithTriangles |> E.O.toExt("This should not have happened")
|> E.O.toExt("This should not have happened") |> make(`Linear, _, None)
|> make(`Linear, _, None) };
};
} else { } else {
make(`Linear, {xs: [|neg_infinity|], ys: [|0.0|]}, None); make(`Linear, {xs: [|neg_infinity|], ys: [|0.0|]}, None);
} };
};
let downsample = (~cache=None, length, t): t => let downsample = (~cache=None, length, t): t =>
t t
@ -276,23 +277,31 @@ module Continuous = {
); );
}); });
/* This simply creates multiple copies of the continuous distribution, scaled and shifted according to /* This simply creates multiple copies of the continuous distribution, scaled and shifted according to
each discrete data point, and then adds them all together. */ each discrete data point, and then adds them all together. */
let combineAlgebraicallyWithDiscrete = (~downsample=false, op: AlgebraicCombinations.algebraicOperation, t1: t, t2: DistTypes.discreteShape) => { let combineAlgebraicallyWithDiscrete =
(
~downsample=false,
op: SymbolicTypes.algebraicOperation,
t1: t,
t2: DistTypes.discreteShape,
) => {
let t1s = t1 |> getShape; let t1s = t1 |> getShape;
let t2s = t2.xyShape; // would like to use Discrete.getShape here, but current file structure doesn't allow for that let t2s = t2.xyShape; // would like to use Discrete.getShape here, but current file structure doesn't allow for that
let t1n = t1s |> XYShape.T.length; let t1n = t1s |> XYShape.T.length;
let t2n = t2s |> XYShape.T.length; let t2n = t2s |> XYShape.T.length;
let fn = AlgebraicCombinations.Operation.toFn(op); let fn = SymbolicTypes.Algebraic.toFn(op);
let outXYShapes: array(array((float, float))) = let outXYShapes: array(array((float, float))) =
Belt.Array.makeUninitializedUnsafe(t2n); Belt.Array.makeUninitializedUnsafe(t2n);
for (j in 0 to t2n - 1) { // for each one of the discrete points for (j in 0 to t2n - 1) {
// for each one of the discrete points
// create a new distribution, as long as the original continuous one // create a new distribution, as long as the original continuous one
let dxyShape: array((float, float)) = Belt.Array.makeUninitializedUnsafe(t1n);
let dxyShape: array((float, float)) =
Belt.Array.makeUninitializedUnsafe(t1n);
for (i in 0 to t1n - 1) { for (i in 0 to t1n - 1) {
let _ = let _ =
Belt.Array.set( Belt.Array.set(
@ -307,7 +316,12 @@ module Continuous = {
(); ();
}; };
let combinedIntegralSum = Common.combineIntegralSums((a, b) => Some(a *. b), t1.knownIntegralSum, t2.knownIntegralSum); let combinedIntegralSum =
Common.combineIntegralSums(
(a, b) => Some(a *. b),
t1.knownIntegralSum,
t2.knownIntegralSum,
);
outXYShapes outXYShapes
|> E.A.fmap(s => { |> E.A.fmap(s => {
@ -318,7 +332,13 @@ module Continuous = {
|> updateKnownIntegralSum(combinedIntegralSum); |> updateKnownIntegralSum(combinedIntegralSum);
}; };
let combineAlgebraically = (~downsample=false, op: AlgebraicCombinations.algebraicOperation, t1: t, t2: t) => { let combineAlgebraically =
(
~downsample=false,
op: SymbolicTypes.algebraicOperation,
t1: t,
t2: t,
) => {
let s1 = t1 |> getShape; let s1 = t1 |> getShape;
let s2 = t2 |> getShape; let s2 = t2 |> getShape;
let t1n = s1 |> XYShape.T.length; let t1n = s1 |> XYShape.T.length;
@ -326,8 +346,14 @@ module Continuous = {
if (t1n == 0 || t2n == 0) { if (t1n == 0 || t2n == 0) {
empty; empty;
} else { } else {
let combinedShape = AlgebraicCombinations.combineShapesContinuousContinuous(op, s1, s2); let combinedShape =
let combinedIntegralSum = Common.combineIntegralSums((a, b) => Some(a *. b), t1.knownIntegralSum, t2.knownIntegralSum); AlgebraicCombinations.combineShapesContinuousContinuous(op, s1, s2);
let combinedIntegralSum =
Common.combineIntegralSums(
(a, b) => Some(a *. b),
t1.knownIntegralSum,
t2.knownIntegralSum,
);
// return a new Continuous distribution // return a new Continuous distribution
make(`Linear, combinedShape, combinedIntegralSum); make(`Linear, combinedShape, combinedIntegralSum);
}; };
@ -370,7 +396,7 @@ module Discrete = {
XYShape.PointwiseCombination.combine( XYShape.PointwiseCombination.combine(
~xsSelection=ALL_XS, ~xsSelection=ALL_XS,
~xToYSelection=XYShape.XtoY.stepwiseIfAtX, ~xToYSelection=XYShape.XtoY.stepwiseIfAtX,
~fn=((a, b) => fn(E.O.default(0.0, a), E.O.default(0.0, b))), // stepwiseIfAtX returns option(float), so this fn needs to handle None ~fn=(a, b) => fn(E.O.default(0.0, a), E.O.default(0.0, b)), // stepwiseIfAtX returns option(float), so this fn needs to handle None
t1.xyShape, t1.xyShape,
t2.xyShape, t2.xyShape,
), ),
@ -378,7 +404,9 @@ module Discrete = {
); );
}; };
let reduce = (~knownIntegralSumsFn=(_, _) => None, fn, discreteShapes): DistTypes.discreteShape => let reduce =
(~knownIntegralSumsFn=(_, _) => None, fn, discreteShapes)
: DistTypes.discreteShape =>
discreteShapes discreteShapes
|> E.A.fold_left(combinePointwise(~knownIntegralSumsFn, fn), empty); |> E.A.fold_left(combinePointwise(~knownIntegralSumsFn, fn), empty);
@ -389,7 +417,8 @@ module Discrete = {
/* This multiplies all of the data points together and creates a new discrete distribution from the results. /* This multiplies all of the data points together and creates a new discrete distribution from the results.
Data points at the same xs get added together. It may be a good idea to downsample t1 and t2 before and/or the result after. */ Data points at the same xs get added together. It may be a good idea to downsample t1 and t2 before and/or the result after. */
let combineAlgebraically = (op: AlgebraicCombinations.algebraicOperation, t1: t, t2: t) => { let combineAlgebraically =
(op: SymbolicTypes.algebraicOperation, t1: t, t2: t) => {
let t1s = t1 |> getShape; let t1s = t1 |> getShape;
let t2s = t2 |> getShape; let t2s = t2 |> getShape;
let t1n = t1s |> XYShape.T.length; let t1n = t1s |> XYShape.T.length;
@ -402,7 +431,7 @@ module Discrete = {
t2.knownIntegralSum, t2.knownIntegralSum,
); );
let fn = AlgebraicCombinations.Operation.toFn(op); let fn = SymbolicTypes.Algebraic.toFn(op);
let xToYMap = E.FloatFloatMap.empty(); let xToYMap = E.FloatFloatMap.empty();
for (i in 0 to t1n - 1) { for (i in 0 to t1n - 1) {
@ -441,8 +470,8 @@ module Discrete = {
Dist({ Dist({
type t = DistTypes.discreteShape; type t = DistTypes.discreteShape;
type integral = DistTypes.continuousShape; type integral = DistTypes.continuousShape;
let integral = (~cache, t) => { let integral = (~cache, t) =>
if ((t |> getShape |> XYShape.T.length) > 0) { if (t |> getShape |> XYShape.T.length > 0) {
switch (cache) { switch (cache) {
| Some(c) => c | Some(c) => c
| None => | None =>
@ -453,9 +482,13 @@ module Discrete = {
) )
}; };
} else { } else {
Continuous.make(`Stepwise, {xs: [|neg_infinity|], ys: [|0.0|]}, None); Continuous.make(
}}; `Stepwise,
{xs: [|neg_infinity|], ys: [|0.0|]},
None,
);
};
let integralEndY = (~cache, t: t) => let integralEndY = (~cache, t: t) =>
t.knownIntegralSum t.knownIntegralSum
|> E.O.default(t |> integral(~cache) |> Continuous.lastY); |> E.O.default(t |> integral(~cache) |> Continuous.lastY);
@ -495,7 +528,7 @@ module Discrete = {
make(clippedShape, None); // if someone needs the sum, they'll have to recompute it make(clippedShape, None); // if someone needs the sum, they'll have to recompute it
} else { } else {
t; t;
} };
}; };
let truncate = let truncate =
@ -505,9 +538,9 @@ module Discrete = {
|> getShape |> getShape
|> XYShape.T.zip |> XYShape.T.zip
|> XYShape.Zipped.filterByX(x => |> XYShape.Zipped.filterByX(x =>
x >= E.O.default(neg_infinity, leftCutoff) x >= E.O.default(neg_infinity, leftCutoff)
|| x <= E.O.default(infinity, rightCutoff) || x <= E.O.default(infinity, rightCutoff)
) )
|> XYShape.T.fromZippedArray; |> XYShape.T.fromZippedArray;
make(truncatedShape, None); make(truncatedShape, None);
@ -601,8 +634,10 @@ module Mixed = {
rightCutoff: option(float), rightCutoff: option(float),
{discrete, continuous}: t, {discrete, continuous}: t,
) => { ) => {
let truncatedContinuous = Continuous.T.truncate(leftCutoff, rightCutoff, continuous); let truncatedContinuous =
let truncatedDiscrete = Discrete.T.truncate(leftCutoff, rightCutoff, discrete); Continuous.T.truncate(leftCutoff, rightCutoff, continuous);
let truncatedDiscrete =
Discrete.T.truncate(leftCutoff, rightCutoff, discrete);
make(~discrete=truncatedDiscrete, ~continuous=truncatedContinuous); make(~discrete=truncatedDiscrete, ~continuous=truncatedContinuous);
}; };
@ -809,7 +844,14 @@ module Mixed = {
}; };
}); });
let combineAlgebraically = (~downsample=false, op: AlgebraicCombinations.algebraicOperation, t1: t, t2: t): t => { let combineAlgebraically =
(
~downsample=false,
op: SymbolicTypes.algebraicOperation,
t1: t,
t2: t,
)
: t => {
// Discrete convolution can cause a huge increase in the number of samples, // Discrete convolution can cause a huge increase in the number of samples,
// so we'll first downsample. // so we'll first downsample.
@ -827,11 +869,26 @@ module Mixed = {
// continuous (*) continuous => continuous, but also // continuous (*) continuous => continuous, but also
// discrete (*) continuous => continuous (and vice versa). We have to take care of all combos and then combine them: // discrete (*) continuous => continuous (and vice versa). We have to take care of all combos and then combine them:
let ccConvResult = let ccConvResult =
Continuous.combineAlgebraically(~downsample=false, op, t1d.continuous, t2d.continuous); Continuous.combineAlgebraically(
~downsample=false,
op,
t1d.continuous,
t2d.continuous,
);
let dcConvResult = let dcConvResult =
Continuous.combineAlgebraicallyWithDiscrete(~downsample=false, op, t2d.continuous, t1d.discrete); Continuous.combineAlgebraicallyWithDiscrete(
~downsample=false,
op,
t2d.continuous,
t1d.discrete,
);
let cdConvResult = let cdConvResult =
Continuous.combineAlgebraicallyWithDiscrete(~downsample=false, op, t1d.continuous, t2d.discrete); Continuous.combineAlgebraicallyWithDiscrete(
~downsample=false,
op,
t1d.continuous,
t2d.discrete,
);
let continuousConvResult = let continuousConvResult =
Continuous.reduce((+.), [|ccConvResult, dcConvResult, cdConvResult|]); Continuous.reduce((+.), [|ccConvResult, dcConvResult, cdConvResult|]);
@ -866,23 +923,47 @@ module Shape = {
c => Mixed.make(~discrete=Discrete.empty, ~continuous=c), c => Mixed.make(~discrete=Discrete.empty, ~continuous=c),
)); ));
let combineAlgebraically = (op: AlgebraicCombinations.algebraicOperation, t1: t, t2: t): t => { let combineAlgebraically =
switch ((t1, t2)) { (op: SymbolicTypes.algebraicOperation, t1: t, t2: t): t => {
| (Continuous(m1), Continuous(m2)) => DistTypes.Continuous(Continuous.combineAlgebraically(~downsample=true, op, m1, m2)) switch (t1, t2) {
| (Discrete(m1), Discrete(m2)) => DistTypes.Discrete(Discrete.combineAlgebraically(op, m1, m2)) | (Continuous(m1), Continuous(m2)) =>
| (m1, m2) => { DistTypes.Continuous(
DistTypes.Mixed(Mixed.combineAlgebraically(~downsample=true, op, toMixed(m1), toMixed(m2))) Continuous.combineAlgebraically(~downsample=true, op, m1, m2),
} )
| (Discrete(m1), Discrete(m2)) =>
DistTypes.Discrete(Discrete.combineAlgebraically(op, m1, m2))
| (m1, m2) =>
DistTypes.Mixed(
Mixed.combineAlgebraically(
~downsample=true,
op,
toMixed(m1),
toMixed(m2),
),
)
}; };
}; };
let combinePointwise = (~knownIntegralSumsFn=(_, _) => None, fn, t1: t, t2: t) => let combinePointwise =
switch ((t1, t2)) { (~knownIntegralSumsFn=(_, _) => None, fn, t1: t, t2: t) =>
| (Continuous(m1), Continuous(m2)) => DistTypes.Continuous(Continuous.combinePointwise(~knownIntegralSumsFn, fn, m1, m2)) switch (t1, t2) {
| (Discrete(m1), Discrete(m2)) => DistTypes.Discrete(Discrete.combinePointwise(~knownIntegralSumsFn, fn, m1, m2)) | (Continuous(m1), Continuous(m2)) =>
| (m1, m2) => { DistTypes.Continuous(
DistTypes.Mixed(Mixed.combinePointwise(~knownIntegralSumsFn, fn, toMixed(m1), toMixed(m2))) Continuous.combinePointwise(~knownIntegralSumsFn, fn, m1, m2),
} )
| (Discrete(m1), Discrete(m2)) =>
DistTypes.Discrete(
Discrete.combinePointwise(~knownIntegralSumsFn, fn, m1, m2),
)
| (m1, m2) =>
DistTypes.Mixed(
Mixed.combinePointwise(
~knownIntegralSumsFn,
fn,
toMixed(m1),
toMixed(m2),
),
)
}; };
// TODO: implement these functions // TODO: implement these functions
@ -915,7 +996,6 @@ module Shape = {
let toContinuous = t => None; let toContinuous = t => None;
let toDiscrete = t => None; let toDiscrete = t => None;
let downsample = (~cache=None, i, t) => let downsample = (~cache=None, i, t) =>
fmap( fmap(
( (
@ -938,7 +1018,11 @@ module Shape = {
let toDiscreteProbabilityMassFraction = t => 0.0; let toDiscreteProbabilityMassFraction = t => 0.0;
let normalize = let normalize =
fmap((Mixed.T.normalize, Discrete.T.normalize, Continuous.T.normalize)); fmap((
Mixed.T.normalize,
Discrete.T.normalize,
Continuous.T.normalize,
));
let toContinuous = let toContinuous =
mapToAll(( mapToAll((
Mixed.T.toContinuous, Mixed.T.toContinuous,
@ -1089,7 +1173,8 @@ module DistPlus = {
}; };
let truncate = (leftCutoff, rightCutoff, t: t): t => { let truncate = (leftCutoff, rightCutoff, t: t): t => {
let truncatedShape = t |> toShape |> Shape.T.truncate(leftCutoff, rightCutoff); let truncatedShape =
t |> toShape |> Shape.T.truncate(leftCutoff, rightCutoff);
t |> updateShape(truncatedShape); t |> updateShape(truncatedShape);
}; };
@ -1153,9 +1238,9 @@ module DistPlus = {
let integralYtoX = (~cache as _, f, t: t) => { let integralYtoX = (~cache as _, f, t: t) => {
Shape.T.Integral.yToX(~cache=Some(t.integralCache), f, toShape(t)); Shape.T.Integral.yToX(~cache=Some(t.integralCache), f, toShape(t));
}; };
let mean = (t: t) => { let mean = (t: t) => {
Shape.T.mean(t.shape); Shape.T.mean(t.shape);
}; };
let variance = (t: t) => Shape.T.variance(t.shape); let variance = (t: t) => Shape.T.variance(t.shape);
}); });
}; };

View File

@ -123,6 +123,12 @@ module Normal = {
let stdev = 1. /. (1. /. n1.stdev ** 2. +. 1. /. n2.stdev ** 2.); let stdev = 1. /. (1. /. n1.stdev ** 2. +. 1. /. n2.stdev ** 2.);
`Normal({mean, stdev}); `Normal({mean, stdev});
}; };
/* Tries to apply an algebraic operation to two normals symbolically.
   Returns Some(add(n1, n2)) for `Add and Some(subtract(n1, n2)) for `Subtract;
   returns None for `Multiply and `Divide, which this function does not handle. */
let operate = (operation: SymbolicTypes.Algebraic.t, n1: t, n2: t) =>
  switch (operation) {
  | `Add => Some(add(n1, n2))
  | `Subtract => Some(subtract(n1, n2))
  // The unsupported cases are spelled out instead of using a wildcard so the
  // compiler reports this switch if a new algebraic operation is ever added.
  | `Multiply
  | `Divide => None
  };
}; };
module Beta = { module Beta = {
@ -171,6 +177,11 @@ module Lognormal = {
let sigma = l1.sigma +. l2.sigma; let sigma = l1.sigma +. l2.sigma;
`Lognormal({mu, sigma}); `Lognormal({mu, sigma});
}; };
/* Tries to apply an algebraic operation to two lognormals symbolically.
   Returns Some(multiply(n1, n2)) for `Multiply and Some(divide(n1, n2)) for
   `Divide; returns None for `Add and `Subtract, which this function does not
   handle. */
let operate = (operation: SymbolicTypes.Algebraic.t, n1: t, n2: t) =>
  switch (operation) {
  | `Multiply => Some(multiply(n1, n2))
  | `Divide => Some(divide(n1, n2))
  // The unsupported cases are spelled out instead of using a wildcard so the
  // compiler reports this switch if a new algebraic operation is ever added.
  | `Add
  | `Subtract => None
  };
}; };
module Uniform = { module Uniform = {

View File

@ -0,0 +1,67 @@
/* Operation tags carried by `PointwiseCombination tree nodes. */
type pointwiseOperation = [ | `Add | `Multiply];
/* Operation tags carried by `VerticalScaling tree nodes. */
type scaleOperation = [ | `Multiply | `Exponentiate | `Log];
/* Operation tags carried by `FloatFromDist tree nodes; `Pdf and `Inv carry a float argument. */
type distToFloatOperation = [ | `Pdf(float) | `Inv(float) | `Mean | `Sample];
/* Operation tags carried by `AlgebraicCombination tree nodes: the four binary arithmetic operations. */
type algebraicOperation = [ | `Add | `Multiply | `Subtract | `Divide];
/* Helpers for the four algebraic (binary arithmetic) operations: conversion to
   a float function, checked application, and string formatting. */
module Algebraic = {
  type t = algebraicOperation;

  /* Returns the float operator that corresponds to the given operation. */
  let toFn: (t, float, float) => float =
    fun
    | `Add => (+.)
    | `Subtract => (-.)
    | `Multiply => ( *. )
    | `Divide => (/.);

  /* Applies operation `t` to f1 and f2.
     Returns Error(message) when dividing by a divisor equal to 0., and
     Ok(result) in every other case. */
  let applyFn = (t, f1, f2) => {
    switch (t, f2) {
    // An interpolated {j|...|j} string is required here: in an ordinary string
    // literal a "$name" placeholder is emitted verbatim rather than replaced.
    | (`Divide, 0.) => Error({j|Cannot divide $f1 by zero.|j})
    | _ => Ok(toFn(t, f1, f2))
    };
  };

  /* Returns the bare infix symbol for the operation (no surrounding spaces). */
  let toString =
    fun
    | `Add => "+"
    | `Subtract => "-"
    | `Multiply => "*"
    | `Divide => "/";

  /* Renders "b <symbol> c", where <symbol> is toString(a), with single spaces
     around the symbol. */
  let format = (a, b, c) => b ++ " " ++ toString(a) ++ " " ++ c;
};
/* String helpers for pointwise operations. */
module Pointwise = {
  type t = pointwiseOperation;

  /* Returns the bare symbol for the operation (no surrounding spaces). */
  let toString = op =>
    switch (op) {
    | `Add => "+"
    | `Multiply => "*"
    };

  /* Joins the two operand strings with the operation's symbol, separated by
     single spaces. */
  let format = (op, left, right) => {
    let symbol = toString(op);
    left ++ " " ++ symbol ++ " " ++ right;
  };
};
/* String helpers for operations that reduce a distribution to a float. */
module DistToFloat = {
  type t = distToFloatOperation;

  /* Returns the opening part of the call syntax for the operation, up to but
     not including the distribution argument and the closing parenthesis. */
  let stringFromFloatFromDistOperation = op =>
    switch (op) {
    | `Pdf(f) => {j|pdf(x=$f, |j}
    | `Inv(f) => {j|inv(x=$f, |j}
    | `Sample => "sample("
    | `Mean => "mean("
    };

  /* Wraps the already-rendered argument string in the operation's call syntax. */
  let format = (op, arg) => {
    let opening = stringFromFloatFromDistOperation(op);
    opening ++ arg ++ ")";
  };
};
/* Helpers for scale operations. */
module Scale = {
  type t = scaleOperation;

  /* Returns the two-argument float function for the operation:
     product for `Multiply, power for `Exponentiate, and
     log(a) /. log(b) for `Log. */
  let toFn = op =>
    switch (op) {
    | `Multiply => ( *. )
    | `Exponentiate => ( ** )
    | `Log => ((value, base) => log(value) /. log(base))
    };

  /* Returns the function used to derive a new known integral sum from two
     floats: their product for `Multiply, and always None for the other
     operations. */
  let toKnownIntegralSumFn = op =>
    switch (op) {
    | `Multiply => ((a, b) => Some(a *. b))
    | `Exponentiate
    | `Log => ((_, _) => None)
    };
}

View File

@ -1,86 +1,41 @@
/* This module represents a tree node. */ /* This module represents a tree node. */
open SymbolicTypes;
// todo: Symbolic already has an arbitrary continuousShape option. It seems messy to have both. // todo: Symbolic already has an arbitrary continuousShape option. It seems messy to have both.
type distData = [ type distData = [
| `Symbolic(SymbolicDist.dist) | `Symbolic(SymbolicDist.dist)
| `RenderedShape(DistTypes.shape) | `RenderedShape(DistTypes.shape)
]; ];
/* TreeNodes are either Data (i.e. symbolic or rendered distributions) or Operations. Operations refer to one or two child nodes. */
type pointwiseOperation = [ | `Add | `Multiply]; type treeNode = [ | `DistData(distData) | `Operation(operation)]
type scaleOperation = [ | `Multiply | `Exponentiate | `Log];
type distToFloatOperation = [ | `Pdf(float) | `Inv(float) | `Mean | `Sample];
/* TreeNodes are either Data (i.e. symbolic or rendered distributions) or Operations. */
type treeNode = [
| `DistData(distData) // a leaf node that describes a distribution
| `Operation(operation) // an operation on two child nodes
]
and operation = [ and operation = [
| // binary operations | `AlgebraicCombination(algebraicOperation, treeNode, treeNode)
`AlgebraicCombination( | `PointwiseCombination(pointwiseOperation, treeNode, treeNode)
AlgebraicCombinations.algebraicOperation, | `VerticalScaling(scaleOperation, treeNode, treeNode)
treeNode, | `Render(treeNode)
treeNode, | `Truncate(option(float), option(float), treeNode)
) | `Normalize(treeNode)
// unary operations | `FloatFromDist(distToFloatOperation, treeNode)
| `PointwiseCombination(pointwiseOperation, treeNode, treeNode) // always evaluates to `DistData(`RenderedShape(...))
| `VerticalScaling(scaleOperation, treeNode, treeNode) // always evaluates to `DistData(`RenderedShape(...))
| `Render(treeNode) // always evaluates to `DistData(`RenderedShape(...))
| `Truncate // always evaluates to `DistData(`RenderedShape(...))
(
option(float),
option(float),
treeNode,
) // leftCutoff and rightCutoff
| `Normalize // always evaluates to `DistData(`RenderedShape(...))
// leftCutoff and rightCutoff
(
treeNode,
)
| `FloatFromDist // always evaluates to `DistData(`RenderedShape(...))
// leftCutoff and rightCutoff
(
distToFloatOperation,
treeNode,
)
]; ];
module TreeNode = { module TreeNode = {
type t = treeNode; type t = treeNode;
type tResult = treeNode => result(treeNode, string); type tResult = treeNode => result(treeNode, string);
let rec toString = (t: t): string => { let rec toString =
let stringFromAlgebraicCombination = fun
fun
| `Add => " + "
| `Subtract => " - "
| `Multiply => " * "
| `Divide => " / "
let stringFromPointwiseCombination =
fun
| `Add => " .+ "
| `Multiply => " .* ";
let stringFromFloatFromDistOperation =
fun
| `Pdf(f) => {j|pdf(x=$f, |j}
| `Inv(f) => {j|inv(x=$f, |j}
| `Sample => "sample("
| `Mean => "mean(";
switch (t) {
| `DistData(`Symbolic(d)) => | `DistData(`Symbolic(d)) =>
SymbolicDist.GenericDistFunctions.toString(d) SymbolicDist.GenericDistFunctions.toString(d)
| `DistData(`RenderedShape(_)) => "[shape]" | `DistData(`RenderedShape(_)) => "[shape]"
| `Operation(`AlgebraicCombination(op, t1, t2)) => | `Operation(`AlgebraicCombination(op, t1, t2)) =>
toString(t1) ++ stringFromAlgebraicCombination(op) ++ toString(t2) SymbolicTypes.Algebraic.format(op, toString(t1), toString(t2))
| `Operation(`PointwiseCombination(op, t1, t2)) => | `Operation(`PointwiseCombination(op, t1, t2)) =>
toString(t1) ++ stringFromPointwiseCombination(op) ++ toString(t2) SymbolicTypes.Pointwise.format(op, toString(t1), toString(t2))
| `Operation(`VerticalScaling(_scaleOp, t, scaleBy)) => | `Operation(`VerticalScaling(_scaleOp, t, scaleBy)) =>
toString(t) ++ " @ " ++ toString(scaleBy) toString(t) ++ " @ " ++ toString(scaleBy)
| `Operation(`Normalize(t)) => "normalize(" ++ toString(t) ++ ")" | `Operation(`Normalize(t)) => "normalize(" ++ toString(t) ++ ")"
| `Operation(`FloatFromDist(floatFromDistOp, t)) => stringFromFloatFromDistOperation(floatFromDistOp) ++ toString(t) ++ ")" | `Operation(`FloatFromDist(floatFromDistOp, t)) =>
SymbolicTypes.DistToFloat.format(floatFromDistOp, toString(t))
| `Operation(`Truncate(lc, rc, t)) => | `Operation(`Truncate(lc, rc, t)) =>
"truncate(" "truncate("
++ toString(t) ++ toString(t)
@ -89,9 +44,7 @@ module TreeNode = {
++ ", " ++ ", "
++ E.O.dimap(Js.Float.toString, () => "inf", rc) ++ E.O.dimap(Js.Float.toString, () => "inf", rc)
++ ")" ++ ")"
| `Operation(`Render(t)) => toString(t) | `Operation(`Render(t)) => toString(t);
};
};
/* The following modules encapsulate everything we can do with /* The following modules encapsulate everything we can do with
* different kinds of operations. */ * different kinds of operations. */
@ -104,88 +57,72 @@ module TreeNode = {
let simplify = (algebraicOp, t1: t, t2: t): result(treeNode, string) => { let simplify = (algebraicOp, t1: t, t2: t): result(treeNode, string) => {
let tryCombiningFloats: tResult = let tryCombiningFloats: tResult =
fun fun
| `Operation(
`AlgebraicCombination(
`Divide,
`DistData(`Symbolic(`Float(_))),
`DistData(`Symbolic(`Float(0.))),
),
) =>
Error("Cannot divide $v1 by zero.")
| `Operation( | `Operation(
`AlgebraicCombination( `AlgebraicCombination(
algebraicOp, algebraicOp,
`DistData(`Symbolic(`Float(v1))), `DistData(`Symbolic(`Float(v1))),
`DistData(`Symbolic(`Float(v2))), `DistData(`Symbolic(`Float(v2))),
), ),
) => { ) =>
let func = AlgebraicCombinations.Operation.toFn(algebraicOp); SymbolicTypes.Algebraic.applyFn(algebraicOp, v1, v2)
Ok(`DistData(`Symbolic(`Float(func(v1, v2))))); |> E.R.fmap(r => `DistData(`Symbolic(`Float(r))))
}
| t => Ok(t); | t => Ok(t);
let optionToSymbolicResult = (t, o) =>
o
|> E.O.dimap(r => `DistData(`Symbolic(r)), () => t)
|> (r => Ok(r));
let tryCombiningNormals: tResult = let tryCombiningNormals: tResult =
fun fun
| `Operation( | `Operation(
`AlgebraicCombination( `AlgebraicCombination(
`Add, operation,
`DistData(`Symbolic(`Normal(n1))), `DistData(`Symbolic(`Normal(n1))),
`DistData(`Symbolic(`Normal(n2))), `DistData(`Symbolic(`Normal(n2))),
), ),
) => ) as t =>
Ok(`DistData(`Symbolic(SymbolicDist.Normal.add(n1, n2)))) SymbolicDist.Normal.operate(operation, n1, n2)
| `Operation( |> optionToSymbolicResult(t)
`AlgebraicCombination(
`Subtract,
`DistData(`Symbolic(`Normal(n1))),
`DistData(`Symbolic(`Normal(n2))),
),
) =>
Ok(`DistData(`Symbolic(SymbolicDist.Normal.subtract(n1, n2))))
| t => Ok(t); | t => Ok(t);
let tryCombiningLognormals: tResult = let tryCombiningLognormals: tResult =
fun fun
| `Operation( | `Operation(
`AlgebraicCombination( `AlgebraicCombination(
`Multiply, operation,
`DistData(`Symbolic(`Lognormal(l1))), `DistData(`Symbolic(`Lognormal(n1))),
`DistData(`Symbolic(`Lognormal(l2))), `DistData(`Symbolic(`Lognormal(n2))),
), ),
) => ) as t =>
Ok(`DistData(`Symbolic(SymbolicDist.Lognormal.multiply(l1, l2)))) SymbolicDist.Lognormal.operate(operation, n1, n2)
| `Operation( |> optionToSymbolicResult(t)
`AlgebraicCombination(
`Divide,
`DistData(`Symbolic(`Lognormal(l1))),
`DistData(`Symbolic(`Lognormal(l2))),
),
) =>
Ok(`DistData(`Symbolic(SymbolicDist.Lognormal.divide(l1, l2))))
| t => Ok(t); | t => Ok(t);
let originalTreeNode = let originalTreeNode =
`Operation(`AlgebraicCombination((algebraicOp, t1, t2))); `Operation(`AlgebraicCombination((algebraicOp, t1, t2)));
// Feedback: I like this pattern, kudos
originalTreeNode originalTreeNode
|> tryCombiningFloats |> tryCombiningFloats
|> E.R.bind(_, tryCombiningNormals) |> E.R.bind(_, tryCombiningNormals)
|> E.R.bind(_, tryCombiningLognormals); |> E.R.bind(_, tryCombiningLognormals);
}; };
// todo: I don't like the name evaluateNumerically that much, if this renders and does it algebraically. It's tricky.
let evaluateNumerically = (algebraicOp, operationToDistData, t1, t2) => { let evaluateNumerically = (algebraicOp, operationToDistData, t1, t2) => {
// force rendering into shapes // force rendering into shapes
let renderedShape1 = operationToDistData(`Render(t1)); let renderShape = r => operationToDistData(`Render(r));
let renderedShape2 = operationToDistData(`Render(t2)); switch (renderShape(t1), renderShape(t2)) {
switch (renderedShape1, renderedShape2) {
| ( | (
Ok(`DistData(`RenderedShape(s1))), Ok(`DistData(`RenderedShape(s1))),
Ok(`DistData(`RenderedShape(s2))), Ok(`DistData(`RenderedShape(s2))),
) => ) =>
Ok( Ok(
`DistData( `DistData(
`RenderedShape(Distributions.Shape.combineAlgebraically(algebraicOp, s1, s2)), `RenderedShape(
Distributions.Shape.combineAlgebraically(algebraicOp, s1, s2),
),
), ),
) )
| (Error(e1), _) => Error(e1) | (Error(e1), _) => Error(e1)
@ -195,7 +132,12 @@ module TreeNode = {
}; };
let evaluateToDistData = let evaluateToDistData =
(algebraicOp: AlgebraicCombinations.algebraicOperation, operationToDistData, t1: t, t2: t) (
algebraicOp: SymbolicTypes.algebraicOperation,
operationToDistData,
t1: t,
t2: t,
)
: result(treeNode, string) => : result(treeNode, string) =>
algebraicOp algebraicOp
|> simplify(_, t1, t2) |> simplify(_, t1, t2)
@ -210,27 +152,13 @@ module TreeNode = {
}; };
module VerticalScaling = { module VerticalScaling = {
/* Returns the two-argument float function for a scale operation:
   product for `Multiply, power for `Exponentiate, and
   log(a) /. log(b) for `Log. */
let fnFromOp = op =>
  switch (op) {
  | `Multiply => ( *. )
  | `Exponentiate => ( ** )
  | `Log => ((value, base) => log(value) /. log(base))
  };
/* Returns the function used to derive a new known integral sum from two
   floats: their product for `Multiply, and always None for `Exponentiate
   and `Log. */
let knownIntegralSumFnFromOp = op =>
  switch (op) {
  | `Multiply => ((a, b) => Some(a *. b))
  | `Exponentiate
  | `Log => ((_, _) => None)
  };
let evaluateToDistData = (scaleOp, operationToDistData, t, scaleBy) => { let evaluateToDistData = (scaleOp, operationToDistData, t, scaleBy) => {
// scaleBy has to be a single float, otherwise we'll return an error. // scaleBy has to be a single float, otherwise we'll return an error.
let fn = fnFromOp(scaleOp); let fn = SymbolicTypes.Scale.toFn(scaleOp);
let knownIntegralSumFn = knownIntegralSumFnFromOp(scaleOp); let knownIntegralSumFn = SymbolicTypes.Scale.toKnownIntegralSumFn(scaleOp);
let renderedShape = operationToDistData(`Render(t)); let renderedShape = operationToDistData(`Render(t));
switch (renderedShape, scaleBy) { switch (renderedShape, scaleBy) {
| (Error(e1), _) => Error(e1)
| ( | (
Ok(`DistData(`RenderedShape(rs))), Ok(`DistData(`RenderedShape(rs))),
`DistData(`Symbolic(`Float(sm))), `DistData(`Symbolic(`Float(sm))),
@ -246,6 +174,7 @@ module TreeNode = {
), ),
), ),
) )
| (Error(e1), _) => Error(e1)
| (_, _) => Error("Can only scale by float values.") | (_, _) => Error("Can only scale by float values.")
}; };
}; };
@ -253,14 +182,28 @@ module TreeNode = {
module PointwiseCombination = { module PointwiseCombination = {
let pointwiseAdd = (operationToDistData, t1, t2) => { let pointwiseAdd = (operationToDistData, t1, t2) => {
let renderedShape1 = operationToDistData(`Render(t1)); let renderedShape1 = operationToDistData(`Render(t1));
let renderedShape2 = operationToDistData(`Render(t2)); let renderedShape2 = operationToDistData(`Render(t2));
switch ((renderedShape1, renderedShape2)) { switch (renderedShape1, renderedShape2) {
| (
Ok(`DistData(`RenderedShape(rs1))),
Ok(`DistData(`RenderedShape(rs2))),
) =>
Ok(
`DistData(
`RenderedShape(
Distributions.Shape.combinePointwise(
~knownIntegralSumsFn=(a, b) => Some(a +. b),
(+.),
rs1,
rs2,
),
),
),
)
| (Error(e1), _) => Error(e1) | (Error(e1), _) => Error(e1)
| (_, Error(e2)) => Error(e2) | (_, Error(e2)) => Error(e2)
| (Ok(`DistData(`RenderedShape(rs1))), Ok(`DistData(`RenderedShape(rs2)))) =>
Ok(`DistData(`RenderedShape(Distributions.Shape.combinePointwise(~knownIntegralSumsFn=(a, b) => Some(a +. b), (+.), rs1, rs2))))
| _ => Error("Could not perform pointwise addition.") | _ => Error("Could not perform pointwise addition.")
}; };
}; };
@ -268,14 +211,16 @@ module TreeNode = {
let pointwiseMultiply = (operationToDistData, t1, t2) => { let pointwiseMultiply = (operationToDistData, t1, t2) => {
// TODO: construct a function that we can easily sample from, to construct // TODO: construct a function that we can easily sample from, to construct
// a RenderedShape. Use the xMin and xMax of the rendered shapes to tell the sampling function where to look. // a RenderedShape. Use the xMin and xMax of the rendered shapes to tell the sampling function where to look.
Error("Pointwise multiplication not yet supported."); Error(
"Pointwise multiplication not yet supported.",
);
}; };
let evaluateToDistData = (pointwiseOp, operationToDistData, t1, t2) => { let evaluateToDistData = (pointwiseOp, operationToDistData, t1, t2) => {
switch (pointwiseOp) { switch (pointwiseOp) {
| `Add => pointwiseAdd(operationToDistData, t1, t2) | `Add => pointwiseAdd(operationToDistData, t1, t2)
| `Multiply => pointwiseMultiply(operationToDistData, t1, t2) | `Multiply => pointwiseMultiply(operationToDistData, t1, t2)
} };
}; };
}; };
@ -378,7 +323,9 @@ module TreeNode = {
}; };
E.R.bind(value, v => Ok(`DistData(`Symbolic(`Float(v))))); E.R.bind(value, v => Ok(`DistData(`Symbolic(`Float(v)))));
}; };
let evaluateFromRenderedShape = (distToFloatOp: distToFloatOperation, rs: DistTypes.shape) : result(treeNode, string) => { let evaluateFromRenderedShape =
(distToFloatOp: distToFloatOperation, rs: DistTypes.shape)
: result(treeNode, string) => {
let value = let value =
switch (distToFloatOp) { switch (distToFloatOp) {
| `Pdf(f) => Ok(Distributions.Shape.pdf(f, rs)) | `Pdf(f) => Ok(Distributions.Shape.pdf(f, rs))
@ -410,8 +357,12 @@ module TreeNode = {
module Render = { module Render = {
let rec evaluateToRenderedShape = let rec evaluateToRenderedShape =
(operationToDistData: operation => result(t, string), sampleCount: int, t: treeNode) (
: result(t, string) => { operationToDistData: operation => result(t, string),
sampleCount: int,
t: treeNode,
)
: result(t, string) => {
switch (t) { switch (t) {
| `DistData(`RenderedShape(s)) => Ok(`DistData(`RenderedShape(s))) // already a rendered shape, we're done here | `DistData(`RenderedShape(s)) => Ok(`DistData(`RenderedShape(s))) // already a rendered shape, we're done here
| `DistData(`Symbolic(d)) => | `DistData(`Symbolic(d)) =>
@ -495,10 +446,19 @@ module TreeNode = {
t, t,
) )
| `FloatFromDist(distToFloatOp, t) => | `FloatFromDist(distToFloatOp, t) =>
FloatFromDist.evaluateToDistData(distToFloatOp, operationToDistData(sampleCount), t) FloatFromDist.evaluateToDistData(
| `Normalize(t) => Normalize.evaluateToDistData(operationToDistData(sampleCount), t) distToFloatOp,
operationToDistData(sampleCount),
t,
)
| `Normalize(t) =>
Normalize.evaluateToDistData(operationToDistData(sampleCount), t)
| `Render(t) => | `Render(t) =>
Render.evaluateToRenderedShape(operationToDistData(sampleCount), sampleCount, t) Render.evaluateToRenderedShape(
operationToDistData(sampleCount),
sampleCount,
t,
)
}; };
}; };
@ -531,5 +491,4 @@ let toShape = (sampleCount: int, treeNode: treeNode) => {
}; };
}; };
let toString = (treeNode: treeNode) => let toString = (treeNode: treeNode) => TreeNode.toString(treeNode);
TreeNode.toString(treeNode);