Moving operations functionality into new SymbolicTypes.re file
This commit is contained in:
parent
acdd3dfe7a
commit
baaff19750
|
@ -1,5 +1,3 @@
|
|||
type algebraicOperation = [ | `Add | `Multiply | `Subtract | `Divide];
|
||||
|
||||
type pointMassesWithMoments = {
|
||||
n: int,
|
||||
masses: array(float),
|
||||
|
@ -7,23 +5,6 @@ type pointMassesWithMoments = {
|
|||
variances: array(float),
|
||||
};
|
||||
|
||||
module Operation = {
|
||||
type t = algebraicOperation;
|
||||
let toFn: (t, float, float) => float =
|
||||
fun
|
||||
| `Add => (+.)
|
||||
| `Subtract => (-.)
|
||||
| `Multiply => ( *. )
|
||||
| `Divide => (/.);
|
||||
|
||||
let toString =
|
||||
fun
|
||||
| `Add => " + "
|
||||
| `Subtract => " - "
|
||||
| `Multiply => " * "
|
||||
| `Divide => " / ";
|
||||
};
|
||||
|
||||
/* This function takes a continuous distribution and efficiently approximates it as
|
||||
point masses that have variances associated with them.
|
||||
We estimate the means and variances from overlapping triangular distributions which we imagine are making up the
|
||||
|
@ -129,7 +110,7 @@ let toDiscretePointMassesFromTriangulars =
|
|||
};
|
||||
|
||||
let combineShapesContinuousContinuous =
|
||||
(op: algebraicOperation, s1: DistTypes.xyShape, s2: DistTypes.xyShape)
|
||||
(op: SymbolicTypes.algebraicOperation, s1: DistTypes.xyShape, s2: DistTypes.xyShape)
|
||||
: DistTypes.xyShape => {
|
||||
let t1n = s1 |> XYShape.T.length;
|
||||
let t2n = s2 |> XYShape.T.length;
|
||||
|
@ -216,4 +197,4 @@ let combineShapesContinuousContinuous =
|
|||
};
|
||||
|
||||
{xs: outputXs, ys: outputYs};
|
||||
};
|
||||
};
|
|
@ -149,7 +149,7 @@ module Continuous = {
|
|||
continuousShapes
|
||||
|> E.A.fold_left(combinePointwise(~knownIntegralSumsFn, fn), empty);
|
||||
|
||||
let mapY = (~knownIntegralSumFn=(_ => None), fn, t: t) => {
|
||||
let mapY = (~knownIntegralSumFn=_ => None, fn, t: t) => {
|
||||
let u = E.O.bind(_, knownIntegralSumFn);
|
||||
let yMapFn = shapeMap(XYShape.T.mapY(fn));
|
||||
|
||||
|
@ -164,7 +164,6 @@ module Continuous = {
|
|||
);
|
||||
};
|
||||
|
||||
|
||||
module T =
|
||||
Dist({
|
||||
type t = DistTypes.continuousShape;
|
||||
|
@ -194,9 +193,9 @@ module Continuous = {
|
|||
|> getShape
|
||||
|> XYShape.T.zip
|
||||
|> XYShape.Zipped.filterByX(x =>
|
||||
x >= E.O.default(neg_infinity, leftCutoff)
|
||||
|| x <= E.O.default(infinity, rightCutoff)
|
||||
);
|
||||
x >= E.O.default(neg_infinity, leftCutoff)
|
||||
|| x <= E.O.default(infinity, rightCutoff)
|
||||
);
|
||||
|
||||
let eps = (t |> getShape |> XYShape.T.xTotalRange) *. 0.0001;
|
||||
|
||||
|
@ -206,7 +205,11 @@ module Continuous = {
|
|||
rightCutoff |> E.O.dimap(rc => [|(rc +. eps, 0.)|], _ => [||]);
|
||||
|
||||
let truncatedZippedPairsWithNewPoints =
|
||||
E.A.concatMany([|leftNewPoint, truncatedZippedPairs, rightNewPoint|]);
|
||||
E.A.concatMany([|
|
||||
leftNewPoint,
|
||||
truncatedZippedPairs,
|
||||
rightNewPoint,
|
||||
|]);
|
||||
let truncatedShape =
|
||||
XYShape.T.fromZippedArray(truncatedZippedPairsWithNewPoints);
|
||||
|
||||
|
@ -214,22 +217,20 @@ module Continuous = {
|
|||
};
|
||||
|
||||
// TODO: This should work with stepwise plots.
|
||||
let integral = (~cache, t) => {
|
||||
|
||||
if ((t |> getShape |> XYShape.T.length) > 0) {
|
||||
switch (cache) {
|
||||
| Some(cache) => cache
|
||||
| None =>
|
||||
t
|
||||
|> getShape
|
||||
|> XYShape.Range.integrateWithTriangles
|
||||
|> E.O.toExt("This should not have happened")
|
||||
|> make(`Linear, _, None)
|
||||
};
|
||||
let integral = (~cache, t) =>
|
||||
if (t |> getShape |> XYShape.T.length > 0) {
|
||||
switch (cache) {
|
||||
| Some(cache) => cache
|
||||
| None =>
|
||||
t
|
||||
|> getShape
|
||||
|> XYShape.Range.integrateWithTriangles
|
||||
|> E.O.toExt("This should not have happened")
|
||||
|> make(`Linear, _, None)
|
||||
};
|
||||
} else {
|
||||
make(`Linear, {xs: [|neg_infinity|], ys: [|0.0|]}, None);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
let downsample = (~cache=None, length, t): t =>
|
||||
t
|
||||
|
@ -276,23 +277,31 @@ module Continuous = {
|
|||
);
|
||||
});
|
||||
|
||||
|
||||
/* This simply creates multiple copies of the continuous distribution, scaled and shifted according to
|
||||
each discrete data point, and then adds them all together. */
|
||||
let combineAlgebraicallyWithDiscrete = (~downsample=false, op: AlgebraicCombinations.algebraicOperation, t1: t, t2: DistTypes.discreteShape) => {
|
||||
let combineAlgebraicallyWithDiscrete =
|
||||
(
|
||||
~downsample=false,
|
||||
op: SymbolicTypes.algebraicOperation,
|
||||
t1: t,
|
||||
t2: DistTypes.discreteShape,
|
||||
) => {
|
||||
let t1s = t1 |> getShape;
|
||||
let t2s = t2.xyShape; // would like to use Discrete.getShape here, but current file structure doesn't allow for that
|
||||
let t1n = t1s |> XYShape.T.length;
|
||||
let t2n = t2s |> XYShape.T.length;
|
||||
|
||||
let fn = AlgebraicCombinations.Operation.toFn(op);
|
||||
let fn = SymbolicTypes.Algebraic.toFn(op);
|
||||
|
||||
let outXYShapes: array(array((float, float))) =
|
||||
Belt.Array.makeUninitializedUnsafe(t2n);
|
||||
|
||||
for (j in 0 to t2n - 1) { // for each one of the discrete points
|
||||
for (j in 0 to t2n - 1) {
|
||||
// for each one of the discrete points
|
||||
// create a new distribution, as long as the original continuous one
|
||||
let dxyShape: array((float, float)) = Belt.Array.makeUninitializedUnsafe(t1n);
|
||||
|
||||
let dxyShape: array((float, float)) =
|
||||
Belt.Array.makeUninitializedUnsafe(t1n);
|
||||
for (i in 0 to t1n - 1) {
|
||||
let _ =
|
||||
Belt.Array.set(
|
||||
|
@ -307,7 +316,12 @@ module Continuous = {
|
|||
();
|
||||
};
|
||||
|
||||
let combinedIntegralSum = Common.combineIntegralSums((a, b) => Some(a *. b), t1.knownIntegralSum, t2.knownIntegralSum);
|
||||
let combinedIntegralSum =
|
||||
Common.combineIntegralSums(
|
||||
(a, b) => Some(a *. b),
|
||||
t1.knownIntegralSum,
|
||||
t2.knownIntegralSum,
|
||||
);
|
||||
|
||||
outXYShapes
|
||||
|> E.A.fmap(s => {
|
||||
|
@ -318,7 +332,13 @@ module Continuous = {
|
|||
|> updateKnownIntegralSum(combinedIntegralSum);
|
||||
};
|
||||
|
||||
let combineAlgebraically = (~downsample=false, op: AlgebraicCombinations.algebraicOperation, t1: t, t2: t) => {
|
||||
let combineAlgebraically =
|
||||
(
|
||||
~downsample=false,
|
||||
op: SymbolicTypes.algebraicOperation,
|
||||
t1: t,
|
||||
t2: t,
|
||||
) => {
|
||||
let s1 = t1 |> getShape;
|
||||
let s2 = t2 |> getShape;
|
||||
let t1n = s1 |> XYShape.T.length;
|
||||
|
@ -326,8 +346,14 @@ module Continuous = {
|
|||
if (t1n == 0 || t2n == 0) {
|
||||
empty;
|
||||
} else {
|
||||
let combinedShape = AlgebraicCombinations.combineShapesContinuousContinuous(op, s1, s2);
|
||||
let combinedIntegralSum = Common.combineIntegralSums((a, b) => Some(a *. b), t1.knownIntegralSum, t2.knownIntegralSum);
|
||||
let combinedShape =
|
||||
AlgebraicCombinations.combineShapesContinuousContinuous(op, s1, s2);
|
||||
let combinedIntegralSum =
|
||||
Common.combineIntegralSums(
|
||||
(a, b) => Some(a *. b),
|
||||
t1.knownIntegralSum,
|
||||
t2.knownIntegralSum,
|
||||
);
|
||||
// return a new Continuous distribution
|
||||
make(`Linear, combinedShape, combinedIntegralSum);
|
||||
};
|
||||
|
@ -370,7 +396,7 @@ module Discrete = {
|
|||
XYShape.PointwiseCombination.combine(
|
||||
~xsSelection=ALL_XS,
|
||||
~xToYSelection=XYShape.XtoY.stepwiseIfAtX,
|
||||
~fn=((a, b) => fn(E.O.default(0.0, a), E.O.default(0.0, b))), // stepwiseIfAtX returns option(float), so this fn needs to handle None
|
||||
~fn=(a, b) => fn(E.O.default(0.0, a), E.O.default(0.0, b)), // stepwiseIfAtX returns option(float), so this fn needs to handle None
|
||||
t1.xyShape,
|
||||
t2.xyShape,
|
||||
),
|
||||
|
@ -378,7 +404,9 @@ module Discrete = {
|
|||
);
|
||||
};
|
||||
|
||||
let reduce = (~knownIntegralSumsFn=(_, _) => None, fn, discreteShapes): DistTypes.discreteShape =>
|
||||
let reduce =
|
||||
(~knownIntegralSumsFn=(_, _) => None, fn, discreteShapes)
|
||||
: DistTypes.discreteShape =>
|
||||
discreteShapes
|
||||
|> E.A.fold_left(combinePointwise(~knownIntegralSumsFn, fn), empty);
|
||||
|
||||
|
@ -389,7 +417,8 @@ module Discrete = {
|
|||
|
||||
/* This multiples all of the data points together and creates a new discrete distribution from the results.
|
||||
Data points at the same xs get added together. It may be a good idea to downsample t1 and t2 before and/or the result after. */
|
||||
let combineAlgebraically = (op: AlgebraicCombinations.algebraicOperation, t1: t, t2: t) => {
|
||||
let combineAlgebraically =
|
||||
(op: SymbolicTypes.algebraicOperation, t1: t, t2: t) => {
|
||||
let t1s = t1 |> getShape;
|
||||
let t2s = t2 |> getShape;
|
||||
let t1n = t1s |> XYShape.T.length;
|
||||
|
@ -402,7 +431,7 @@ module Discrete = {
|
|||
t2.knownIntegralSum,
|
||||
);
|
||||
|
||||
let fn = AlgebraicCombinations.Operation.toFn(op);
|
||||
let fn = SymbolicTypes.Algebraic.toFn(op);
|
||||
let xToYMap = E.FloatFloatMap.empty();
|
||||
|
||||
for (i in 0 to t1n - 1) {
|
||||
|
@ -441,8 +470,8 @@ module Discrete = {
|
|||
Dist({
|
||||
type t = DistTypes.discreteShape;
|
||||
type integral = DistTypes.continuousShape;
|
||||
let integral = (~cache, t) => {
|
||||
if ((t |> getShape |> XYShape.T.length) > 0) {
|
||||
let integral = (~cache, t) =>
|
||||
if (t |> getShape |> XYShape.T.length > 0) {
|
||||
switch (cache) {
|
||||
| Some(c) => c
|
||||
| None =>
|
||||
|
@ -453,9 +482,13 @@ module Discrete = {
|
|||
)
|
||||
};
|
||||
} else {
|
||||
Continuous.make(`Stepwise, {xs: [|neg_infinity|], ys: [|0.0|]}, None);
|
||||
}};
|
||||
|
||||
Continuous.make(
|
||||
`Stepwise,
|
||||
{xs: [|neg_infinity|], ys: [|0.0|]},
|
||||
None,
|
||||
);
|
||||
};
|
||||
|
||||
let integralEndY = (~cache, t: t) =>
|
||||
t.knownIntegralSum
|
||||
|> E.O.default(t |> integral(~cache) |> Continuous.lastY);
|
||||
|
@ -495,7 +528,7 @@ module Discrete = {
|
|||
make(clippedShape, None); // if someone needs the sum, they'll have to recompute it
|
||||
} else {
|
||||
t;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
let truncate =
|
||||
|
@ -505,9 +538,9 @@ module Discrete = {
|
|||
|> getShape
|
||||
|> XYShape.T.zip
|
||||
|> XYShape.Zipped.filterByX(x =>
|
||||
x >= E.O.default(neg_infinity, leftCutoff)
|
||||
|| x <= E.O.default(infinity, rightCutoff)
|
||||
)
|
||||
x >= E.O.default(neg_infinity, leftCutoff)
|
||||
|| x <= E.O.default(infinity, rightCutoff)
|
||||
)
|
||||
|> XYShape.T.fromZippedArray;
|
||||
|
||||
make(truncatedShape, None);
|
||||
|
@ -601,8 +634,10 @@ module Mixed = {
|
|||
rightCutoff: option(float),
|
||||
{discrete, continuous}: t,
|
||||
) => {
|
||||
let truncatedContinuous = Continuous.T.truncate(leftCutoff, rightCutoff, continuous);
|
||||
let truncatedDiscrete = Discrete.T.truncate(leftCutoff, rightCutoff, discrete);
|
||||
let truncatedContinuous =
|
||||
Continuous.T.truncate(leftCutoff, rightCutoff, continuous);
|
||||
let truncatedDiscrete =
|
||||
Discrete.T.truncate(leftCutoff, rightCutoff, discrete);
|
||||
|
||||
make(~discrete=truncatedDiscrete, ~continuous=truncatedContinuous);
|
||||
};
|
||||
|
@ -809,7 +844,14 @@ module Mixed = {
|
|||
};
|
||||
});
|
||||
|
||||
let combineAlgebraically = (~downsample=false, op: AlgebraicCombinations.algebraicOperation, t1: t, t2: t): t => {
|
||||
let combineAlgebraically =
|
||||
(
|
||||
~downsample=false,
|
||||
op: SymbolicTypes.algebraicOperation,
|
||||
t1: t,
|
||||
t2: t,
|
||||
)
|
||||
: t => {
|
||||
// Discrete convolution can cause a huge increase in the number of samples,
|
||||
// so we'll first downsample.
|
||||
|
||||
|
@ -827,11 +869,26 @@ module Mixed = {
|
|||
// continuous (*) continuous => continuous, but also
|
||||
// discrete (*) continuous => continuous (and vice versa). We have to take care of all combos and then combine them:
|
||||
let ccConvResult =
|
||||
Continuous.combineAlgebraically(~downsample=false, op, t1d.continuous, t2d.continuous);
|
||||
Continuous.combineAlgebraically(
|
||||
~downsample=false,
|
||||
op,
|
||||
t1d.continuous,
|
||||
t2d.continuous,
|
||||
);
|
||||
let dcConvResult =
|
||||
Continuous.combineAlgebraicallyWithDiscrete(~downsample=false, op, t2d.continuous, t1d.discrete);
|
||||
Continuous.combineAlgebraicallyWithDiscrete(
|
||||
~downsample=false,
|
||||
op,
|
||||
t2d.continuous,
|
||||
t1d.discrete,
|
||||
);
|
||||
let cdConvResult =
|
||||
Continuous.combineAlgebraicallyWithDiscrete(~downsample=false, op, t1d.continuous, t2d.discrete);
|
||||
Continuous.combineAlgebraicallyWithDiscrete(
|
||||
~downsample=false,
|
||||
op,
|
||||
t1d.continuous,
|
||||
t2d.discrete,
|
||||
);
|
||||
let continuousConvResult =
|
||||
Continuous.reduce((+.), [|ccConvResult, dcConvResult, cdConvResult|]);
|
||||
|
||||
|
@ -866,23 +923,47 @@ module Shape = {
|
|||
c => Mixed.make(~discrete=Discrete.empty, ~continuous=c),
|
||||
));
|
||||
|
||||
let combineAlgebraically = (op: AlgebraicCombinations.algebraicOperation, t1: t, t2: t): t => {
|
||||
switch ((t1, t2)) {
|
||||
| (Continuous(m1), Continuous(m2)) => DistTypes.Continuous(Continuous.combineAlgebraically(~downsample=true, op, m1, m2))
|
||||
| (Discrete(m1), Discrete(m2)) => DistTypes.Discrete(Discrete.combineAlgebraically(op, m1, m2))
|
||||
| (m1, m2) => {
|
||||
DistTypes.Mixed(Mixed.combineAlgebraically(~downsample=true, op, toMixed(m1), toMixed(m2)))
|
||||
}
|
||||
let combineAlgebraically =
|
||||
(op: SymbolicTypes.algebraicOperation, t1: t, t2: t): t => {
|
||||
switch (t1, t2) {
|
||||
| (Continuous(m1), Continuous(m2)) =>
|
||||
DistTypes.Continuous(
|
||||
Continuous.combineAlgebraically(~downsample=true, op, m1, m2),
|
||||
)
|
||||
| (Discrete(m1), Discrete(m2)) =>
|
||||
DistTypes.Discrete(Discrete.combineAlgebraically(op, m1, m2))
|
||||
| (m1, m2) =>
|
||||
DistTypes.Mixed(
|
||||
Mixed.combineAlgebraically(
|
||||
~downsample=true,
|
||||
op,
|
||||
toMixed(m1),
|
||||
toMixed(m2),
|
||||
),
|
||||
)
|
||||
};
|
||||
};
|
||||
|
||||
let combinePointwise = (~knownIntegralSumsFn=(_, _) => None, fn, t1: t, t2: t) =>
|
||||
switch ((t1, t2)) {
|
||||
| (Continuous(m1), Continuous(m2)) => DistTypes.Continuous(Continuous.combinePointwise(~knownIntegralSumsFn, fn, m1, m2))
|
||||
| (Discrete(m1), Discrete(m2)) => DistTypes.Discrete(Discrete.combinePointwise(~knownIntegralSumsFn, fn, m1, m2))
|
||||
| (m1, m2) => {
|
||||
DistTypes.Mixed(Mixed.combinePointwise(~knownIntegralSumsFn, fn, toMixed(m1), toMixed(m2)))
|
||||
}
|
||||
let combinePointwise =
|
||||
(~knownIntegralSumsFn=(_, _) => None, fn, t1: t, t2: t) =>
|
||||
switch (t1, t2) {
|
||||
| (Continuous(m1), Continuous(m2)) =>
|
||||
DistTypes.Continuous(
|
||||
Continuous.combinePointwise(~knownIntegralSumsFn, fn, m1, m2),
|
||||
)
|
||||
| (Discrete(m1), Discrete(m2)) =>
|
||||
DistTypes.Discrete(
|
||||
Discrete.combinePointwise(~knownIntegralSumsFn, fn, m1, m2),
|
||||
)
|
||||
| (m1, m2) =>
|
||||
DistTypes.Mixed(
|
||||
Mixed.combinePointwise(
|
||||
~knownIntegralSumsFn,
|
||||
fn,
|
||||
toMixed(m1),
|
||||
toMixed(m2),
|
||||
),
|
||||
)
|
||||
};
|
||||
|
||||
// TODO: implement these functions
|
||||
|
@ -915,7 +996,6 @@ module Shape = {
|
|||
let toContinuous = t => None;
|
||||
let toDiscrete = t => None;
|
||||
|
||||
|
||||
let downsample = (~cache=None, i, t) =>
|
||||
fmap(
|
||||
(
|
||||
|
@ -938,7 +1018,11 @@ module Shape = {
|
|||
|
||||
let toDiscreteProbabilityMassFraction = t => 0.0;
|
||||
let normalize =
|
||||
fmap((Mixed.T.normalize, Discrete.T.normalize, Continuous.T.normalize));
|
||||
fmap((
|
||||
Mixed.T.normalize,
|
||||
Discrete.T.normalize,
|
||||
Continuous.T.normalize,
|
||||
));
|
||||
let toContinuous =
|
||||
mapToAll((
|
||||
Mixed.T.toContinuous,
|
||||
|
@ -1089,7 +1173,8 @@ module DistPlus = {
|
|||
};
|
||||
|
||||
let truncate = (leftCutoff, rightCutoff, t: t): t => {
|
||||
let truncatedShape = t |> toShape |> Shape.T.truncate(leftCutoff, rightCutoff);
|
||||
let truncatedShape =
|
||||
t |> toShape |> Shape.T.truncate(leftCutoff, rightCutoff);
|
||||
|
||||
t |> updateShape(truncatedShape);
|
||||
};
|
||||
|
@ -1153,9 +1238,9 @@ module DistPlus = {
|
|||
let integralYtoX = (~cache as _, f, t: t) => {
|
||||
Shape.T.Integral.yToX(~cache=Some(t.integralCache), f, toShape(t));
|
||||
};
|
||||
let mean = (t: t) => {
|
||||
Shape.T.mean(t.shape);
|
||||
};
|
||||
let mean = (t: t) => {
|
||||
Shape.T.mean(t.shape);
|
||||
};
|
||||
let variance = (t: t) => Shape.T.variance(t.shape);
|
||||
});
|
||||
};
|
||||
|
|
|
@ -123,6 +123,12 @@ module Normal = {
|
|||
let stdev = 1. /. (1. /. n1.stdev ** 2. +. 1. /. n2.stdev ** 2.);
|
||||
`Normal({mean, stdev});
|
||||
};
|
||||
|
||||
let operate = (operation: SymbolicTypes.Algebraic.t, n1: t, n2: t) => switch(operation){
|
||||
| `Add => Some(add(n1, n2))
|
||||
| `Subtract => Some(subtract(n1, n2))
|
||||
| _ => None
|
||||
}
|
||||
};
|
||||
|
||||
module Beta = {
|
||||
|
@ -171,6 +177,11 @@ module Lognormal = {
|
|||
let sigma = l1.sigma +. l2.sigma;
|
||||
`Lognormal({mu, sigma});
|
||||
};
|
||||
let operate = (operation: SymbolicTypes.Algebraic.t, n1: t, n2: t) => switch(operation){
|
||||
| `Multiply => Some(multiply(n1, n2))
|
||||
| `Divide => Some(divide(n1, n2))
|
||||
| _ => None
|
||||
}
|
||||
};
|
||||
|
||||
module Uniform = {
|
||||
|
|
67
src/distPlus/symbolic/SymbolicTypes.re
Normal file
67
src/distPlus/symbolic/SymbolicTypes.re
Normal file
|
@ -0,0 +1,67 @@
|
|||
type pointwiseOperation = [ | `Add | `Multiply];
|
||||
type scaleOperation = [ | `Multiply | `Exponentiate | `Log];
|
||||
type distToFloatOperation = [ | `Pdf(float) | `Inv(float) | `Mean | `Sample];
|
||||
type algebraicOperation = [ | `Add | `Multiply | `Subtract | `Divide];
|
||||
|
||||
module Algebraic = {
|
||||
type t = algebraicOperation;
|
||||
let toFn: (t, float, float) => float =
|
||||
fun
|
||||
| `Add => (+.)
|
||||
| `Subtract => (-.)
|
||||
| `Multiply => ( *. )
|
||||
| `Divide => (/.);
|
||||
|
||||
let applyFn = (t, f1, f2) => {
|
||||
switch (t, f1, f2) {
|
||||
| (`Divide, _, 0.) => Error("Cannot divide $v1 by zero.")
|
||||
| _ => Ok(toFn(t, f1, f2))
|
||||
};
|
||||
};
|
||||
|
||||
let toString =
|
||||
fun
|
||||
| `Add => "+"
|
||||
| `Subtract => "-"
|
||||
| `Multiply => "*"
|
||||
| `Divide => "/";
|
||||
|
||||
let format = (a, b, c) => b ++ " " ++ toString(a) ++ " " ++ c;
|
||||
};
|
||||
|
||||
module Pointwise = {
|
||||
type t = pointwiseOperation;
|
||||
let toString =
|
||||
fun
|
||||
| `Add => "+"
|
||||
| `Multiply => "*";
|
||||
|
||||
let format = (a, b, c) => b ++ " " ++ toString(a) ++ " " ++ c;
|
||||
};
|
||||
|
||||
module DistToFloat = {
|
||||
type t = distToFloatOperation;
|
||||
|
||||
let stringFromFloatFromDistOperation =
|
||||
fun
|
||||
| `Pdf(f) => {j|pdf(x=$f, |j}
|
||||
| `Inv(f) => {j|inv(x=$f, |j}
|
||||
| `Sample => "sample("
|
||||
| `Mean => "mean(";
|
||||
let format = (a, b) => stringFromFloatFromDistOperation(a) ++ b ++ ")";
|
||||
};
|
||||
|
||||
module Scale = {
|
||||
type t = scaleOperation;
|
||||
let toFn =
|
||||
fun
|
||||
| `Multiply => ( *. )
|
||||
| `Exponentiate => ( ** )
|
||||
| `Log => ((a, b) => log(a) /. log(b));
|
||||
|
||||
let toKnownIntegralSumFn =
|
||||
fun
|
||||
| `Multiply => ((a, b) => Some(a *. b))
|
||||
| `Exponentiate => ((_, _) => None)
|
||||
| `Log => ((_, _) => None);
|
||||
}
|
|
@ -1,86 +1,41 @@
|
|||
/* This module represents a tree node. */
|
||||
open SymbolicTypes;
|
||||
|
||||
// todo: Symbolic already has an arbitrary continuousShape option. It seems messy to have both.
|
||||
type distData = [
|
||||
| `Symbolic(SymbolicDist.dist)
|
||||
| `RenderedShape(DistTypes.shape)
|
||||
];
|
||||
|
||||
type pointwiseOperation = [ | `Add | `Multiply];
|
||||
type scaleOperation = [ | `Multiply | `Exponentiate | `Log];
|
||||
type distToFloatOperation = [ | `Pdf(float) | `Inv(float) | `Mean | `Sample];
|
||||
|
||||
/* TreeNodes are either Data (i.e. symbolic or rendered distributions) or Operations. */
|
||||
type treeNode = [
|
||||
| `DistData(distData) // a leaf node that describes a distribution
|
||||
| `Operation(operation) // an operation on two child nodes
|
||||
]
|
||||
/* TreeNodes are either Data (i.e. symbolic or rendered distributions) or Operations. Operations always refer to two child nodes.*/
|
||||
type treeNode = [ | `DistData(distData) | `Operation(operation)]
|
||||
and operation = [
|
||||
| // binary operations
|
||||
`AlgebraicCombination(
|
||||
AlgebraicCombinations.algebraicOperation,
|
||||
treeNode,
|
||||
treeNode,
|
||||
)
|
||||
// unary operations
|
||||
| `PointwiseCombination(pointwiseOperation, treeNode, treeNode) // always evaluates to `DistData(`RenderedShape(...))
|
||||
| `VerticalScaling(scaleOperation, treeNode, treeNode) // always evaluates to `DistData(`RenderedShape(...))
|
||||
| `Render(treeNode) // always evaluates to `DistData(`RenderedShape(...))
|
||||
| `Truncate // always evaluates to `DistData(`RenderedShape(...))
|
||||
(
|
||||
option(float),
|
||||
option(float),
|
||||
treeNode,
|
||||
) // leftCutoff and rightCutoff
|
||||
| `Normalize // always evaluates to `DistData(`RenderedShape(...))
|
||||
// leftCutoff and rightCutoff
|
||||
(
|
||||
treeNode,
|
||||
)
|
||||
| `FloatFromDist // always evaluates to `DistData(`RenderedShape(...))
|
||||
// leftCutoff and rightCutoff
|
||||
(
|
||||
distToFloatOperation,
|
||||
treeNode,
|
||||
)
|
||||
| `AlgebraicCombination(algebraicOperation, treeNode, treeNode)
|
||||
| `PointwiseCombination(pointwiseOperation, treeNode, treeNode)
|
||||
| `VerticalScaling(scaleOperation, treeNode, treeNode)
|
||||
| `Render(treeNode)
|
||||
| `Truncate(option(float), option(float), treeNode)
|
||||
| `Normalize(treeNode)
|
||||
| `FloatFromDist(distToFloatOperation, treeNode)
|
||||
];
|
||||
|
||||
module TreeNode = {
|
||||
type t = treeNode;
|
||||
type tResult = treeNode => result(treeNode, string);
|
||||
|
||||
let rec toString = (t: t): string => {
|
||||
let stringFromAlgebraicCombination =
|
||||
fun
|
||||
| `Add => " + "
|
||||
| `Subtract => " - "
|
||||
| `Multiply => " * "
|
||||
| `Divide => " / "
|
||||
|
||||
let stringFromPointwiseCombination =
|
||||
fun
|
||||
| `Add => " .+ "
|
||||
| `Multiply => " .* ";
|
||||
|
||||
let stringFromFloatFromDistOperation =
|
||||
fun
|
||||
| `Pdf(f) => {j|pdf(x=$f, |j}
|
||||
| `Inv(f) => {j|inv(x=$f, |j}
|
||||
| `Sample => "sample("
|
||||
| `Mean => "mean(";
|
||||
|
||||
switch (t) {
|
||||
let rec toString =
|
||||
fun
|
||||
| `DistData(`Symbolic(d)) =>
|
||||
SymbolicDist.GenericDistFunctions.toString(d)
|
||||
| `DistData(`RenderedShape(_)) => "[shape]"
|
||||
| `Operation(`AlgebraicCombination(op, t1, t2)) =>
|
||||
toString(t1) ++ stringFromAlgebraicCombination(op) ++ toString(t2)
|
||||
SymbolicTypes.Algebraic.format(op, toString(t1), toString(t2))
|
||||
| `Operation(`PointwiseCombination(op, t1, t2)) =>
|
||||
toString(t1) ++ stringFromPointwiseCombination(op) ++ toString(t2)
|
||||
SymbolicTypes.Pointwise.format(op, toString(t1), toString(t2))
|
||||
| `Operation(`VerticalScaling(_scaleOp, t, scaleBy)) =>
|
||||
toString(t) ++ " @ " ++ toString(scaleBy)
|
||||
| `Operation(`Normalize(t)) => "normalize(" ++ toString(t) ++ ")"
|
||||
| `Operation(`FloatFromDist(floatFromDistOp, t)) => stringFromFloatFromDistOperation(floatFromDistOp) ++ toString(t) ++ ")"
|
||||
| `Operation(`FloatFromDist(floatFromDistOp, t)) =>
|
||||
SymbolicTypes.DistToFloat.format(floatFromDistOp, toString(t))
|
||||
| `Operation(`Truncate(lc, rc, t)) =>
|
||||
"truncate("
|
||||
++ toString(t)
|
||||
|
@ -89,9 +44,7 @@ module TreeNode = {
|
|||
++ ", "
|
||||
++ E.O.dimap(Js.Float.toString, () => "inf", rc)
|
||||
++ ")"
|
||||
| `Operation(`Render(t)) => toString(t)
|
||||
};
|
||||
};
|
||||
| `Operation(`Render(t)) => toString(t);
|
||||
|
||||
/* The following modules encapsulate everything we can do with
|
||||
* different kinds of operations. */
|
||||
|
@ -104,88 +57,72 @@ module TreeNode = {
|
|||
let simplify = (algebraicOp, t1: t, t2: t): result(treeNode, string) => {
|
||||
let tryCombiningFloats: tResult =
|
||||
fun
|
||||
| `Operation(
|
||||
`AlgebraicCombination(
|
||||
`Divide,
|
||||
`DistData(`Symbolic(`Float(_))),
|
||||
`DistData(`Symbolic(`Float(0.))),
|
||||
),
|
||||
) =>
|
||||
Error("Cannot divide $v1 by zero.")
|
||||
| `Operation(
|
||||
`AlgebraicCombination(
|
||||
algebraicOp,
|
||||
`DistData(`Symbolic(`Float(v1))),
|
||||
`DistData(`Symbolic(`Float(v2))),
|
||||
),
|
||||
) => {
|
||||
let func = AlgebraicCombinations.Operation.toFn(algebraicOp);
|
||||
Ok(`DistData(`Symbolic(`Float(func(v1, v2)))));
|
||||
}
|
||||
) =>
|
||||
SymbolicTypes.Algebraic.applyFn(algebraicOp, v1, v2)
|
||||
|> E.R.fmap(r => `DistData(`Symbolic(`Float(r))))
|
||||
| t => Ok(t);
|
||||
|
||||
let optionToSymbolicResult = (t, o) =>
|
||||
o
|
||||
|> E.O.dimap(r => `DistData(`Symbolic(r)), () => t)
|
||||
|> (r => Ok(r));
|
||||
|
||||
let tryCombiningNormals: tResult =
|
||||
fun
|
||||
| `Operation(
|
||||
`AlgebraicCombination(
|
||||
`Add,
|
||||
operation,
|
||||
`DistData(`Symbolic(`Normal(n1))),
|
||||
`DistData(`Symbolic(`Normal(n2))),
|
||||
),
|
||||
) =>
|
||||
Ok(`DistData(`Symbolic(SymbolicDist.Normal.add(n1, n2))))
|
||||
| `Operation(
|
||||
`AlgebraicCombination(
|
||||
`Subtract,
|
||||
`DistData(`Symbolic(`Normal(n1))),
|
||||
`DistData(`Symbolic(`Normal(n2))),
|
||||
),
|
||||
) =>
|
||||
Ok(`DistData(`Symbolic(SymbolicDist.Normal.subtract(n1, n2))))
|
||||
) as t =>
|
||||
SymbolicDist.Normal.operate(operation, n1, n2)
|
||||
|> optionToSymbolicResult(t)
|
||||
| t => Ok(t);
|
||||
|
||||
let tryCombiningLognormals: tResult =
|
||||
fun
|
||||
| `Operation(
|
||||
`AlgebraicCombination(
|
||||
`Multiply,
|
||||
`DistData(`Symbolic(`Lognormal(l1))),
|
||||
`DistData(`Symbolic(`Lognormal(l2))),
|
||||
operation,
|
||||
`DistData(`Symbolic(`Lognormal(n1))),
|
||||
`DistData(`Symbolic(`Lognormal(n2))),
|
||||
),
|
||||
) =>
|
||||
Ok(`DistData(`Symbolic(SymbolicDist.Lognormal.multiply(l1, l2))))
|
||||
| `Operation(
|
||||
`AlgebraicCombination(
|
||||
`Divide,
|
||||
`DistData(`Symbolic(`Lognormal(l1))),
|
||||
`DistData(`Symbolic(`Lognormal(l2))),
|
||||
),
|
||||
) =>
|
||||
Ok(`DistData(`Symbolic(SymbolicDist.Lognormal.divide(l1, l2))))
|
||||
) as t =>
|
||||
SymbolicDist.Lognormal.operate(operation, n1, n2)
|
||||
|> optionToSymbolicResult(t)
|
||||
| t => Ok(t);
|
||||
|
||||
let originalTreeNode =
|
||||
`Operation(`AlgebraicCombination((algebraicOp, t1, t2)));
|
||||
`Operation(`AlgebraicCombination((algebraicOp, t1, t2)));
|
||||
|
||||
// Feedback: I like this pattern, kudos
|
||||
originalTreeNode
|
||||
|> tryCombiningFloats
|
||||
|> E.R.bind(_, tryCombiningNormals)
|
||||
|> E.R.bind(_, tryCombiningLognormals);
|
||||
};
|
||||
|
||||
// todo: I don't like the name evaluateNumerically that much, if this renders and does it algebraically. It's tricky.
|
||||
let evaluateNumerically = (algebraicOp, operationToDistData, t1, t2) => {
|
||||
// force rendering into shapes
|
||||
let renderedShape1 = operationToDistData(`Render(t1));
|
||||
let renderedShape2 = operationToDistData(`Render(t2));
|
||||
|
||||
switch (renderedShape1, renderedShape2) {
|
||||
let renderShape = r => operationToDistData(`Render(r));
|
||||
switch (renderShape(t1), renderShape(t2)) {
|
||||
| (
|
||||
Ok(`DistData(`RenderedShape(s1))),
|
||||
Ok(`DistData(`RenderedShape(s2))),
|
||||
) =>
|
||||
Ok(
|
||||
`DistData(
|
||||
`RenderedShape(Distributions.Shape.combineAlgebraically(algebraicOp, s1, s2)),
|
||||
`RenderedShape(
|
||||
Distributions.Shape.combineAlgebraically(algebraicOp, s1, s2),
|
||||
),
|
||||
),
|
||||
)
|
||||
| (Error(e1), _) => Error(e1)
|
||||
|
@ -195,7 +132,12 @@ module TreeNode = {
|
|||
};
|
||||
|
||||
let evaluateToDistData =
|
||||
(algebraicOp: AlgebraicCombinations.algebraicOperation, operationToDistData, t1: t, t2: t)
|
||||
(
|
||||
algebraicOp: SymbolicTypes.algebraicOperation,
|
||||
operationToDistData,
|
||||
t1: t,
|
||||
t2: t,
|
||||
)
|
||||
: result(treeNode, string) =>
|
||||
algebraicOp
|
||||
|> simplify(_, t1, t2)
|
||||
|
@ -210,27 +152,13 @@ module TreeNode = {
|
|||
};
|
||||
|
||||
module VerticalScaling = {
|
||||
let fnFromOp =
|
||||
fun
|
||||
| `Multiply => ( *. )
|
||||
| `Exponentiate => ( ** )
|
||||
| `Log => ((a, b) => log(a) /. log(b));
|
||||
|
||||
let knownIntegralSumFnFromOp =
|
||||
fun
|
||||
| `Multiply => ((a, b) => Some(a *. b))
|
||||
| `Exponentiate => ((_, _) => None)
|
||||
| `Log => ((_, _) => None);
|
||||
|
||||
let evaluateToDistData = (scaleOp, operationToDistData, t, scaleBy) => {
|
||||
// scaleBy has to be a single float, otherwise we'll return an error.
|
||||
let fn = fnFromOp(scaleOp);
|
||||
let knownIntegralSumFn = knownIntegralSumFnFromOp(scaleOp);
|
||||
|
||||
let fn = SymbolicTypes.Scale.toFn(scaleOp);
|
||||
let knownIntegralSumFn = SymbolicTypes.Scale.toKnownIntegralSumFn(scaleOp);
|
||||
let renderedShape = operationToDistData(`Render(t));
|
||||
|
||||
switch (renderedShape, scaleBy) {
|
||||
| (Error(e1), _) => Error(e1)
|
||||
| (
|
||||
Ok(`DistData(`RenderedShape(rs))),
|
||||
`DistData(`Symbolic(`Float(sm))),
|
||||
|
@ -246,6 +174,7 @@ module TreeNode = {
|
|||
),
|
||||
),
|
||||
)
|
||||
| (Error(e1), _) => Error(e1)
|
||||
| (_, _) => Error("Can only scale by float values.")
|
||||
};
|
||||
};
|
||||
|
@ -253,14 +182,28 @@ module TreeNode = {
|
|||
|
||||
module PointwiseCombination = {
|
||||
let pointwiseAdd = (operationToDistData, t1, t2) => {
|
||||
let renderedShape1 = operationToDistData(`Render(t1));
|
||||
let renderedShape2 = operationToDistData(`Render(t2));
|
||||
let renderedShape1 = operationToDistData(`Render(t1));
|
||||
let renderedShape2 = operationToDistData(`Render(t2));
|
||||
|
||||
switch ((renderedShape1, renderedShape2)) {
|
||||
switch (renderedShape1, renderedShape2) {
|
||||
| (
|
||||
Ok(`DistData(`RenderedShape(rs1))),
|
||||
Ok(`DistData(`RenderedShape(rs2))),
|
||||
) =>
|
||||
Ok(
|
||||
`DistData(
|
||||
`RenderedShape(
|
||||
Distributions.Shape.combinePointwise(
|
||||
~knownIntegralSumsFn=(a, b) => Some(a +. b),
|
||||
(+.),
|
||||
rs1,
|
||||
rs2,
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
| (Error(e1), _) => Error(e1)
|
||||
| (_, Error(e2)) => Error(e2)
|
||||
| (Ok(`DistData(`RenderedShape(rs1))), Ok(`DistData(`RenderedShape(rs2)))) =>
|
||||
Ok(`DistData(`RenderedShape(Distributions.Shape.combinePointwise(~knownIntegralSumsFn=(a, b) => Some(a +. b), (+.), rs1, rs2))))
|
||||
| _ => Error("Could not perform pointwise addition.")
|
||||
};
|
||||
};
|
||||
|
@ -268,14 +211,16 @@ module TreeNode = {
|
|||
let pointwiseMultiply = (operationToDistData, t1, t2) => {
|
||||
// TODO: construct a function that we can easily sample from, to construct
|
||||
// a RenderedShape. Use the xMin and xMax of the rendered shapes to tell the sampling function where to look.
|
||||
Error("Pointwise multiplication not yet supported.");
|
||||
Error(
|
||||
"Pointwise multiplication not yet supported.",
|
||||
);
|
||||
};
|
||||
|
||||
let evaluateToDistData = (pointwiseOp, operationToDistData, t1, t2) => {
|
||||
switch (pointwiseOp) {
|
||||
| `Add => pointwiseAdd(operationToDistData, t1, t2)
|
||||
| `Multiply => pointwiseMultiply(operationToDistData, t1, t2)
|
||||
}
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
|
@ -378,7 +323,9 @@ module TreeNode = {
|
|||
};
|
||||
E.R.bind(value, v => Ok(`DistData(`Symbolic(`Float(v)))));
|
||||
};
|
||||
let evaluateFromRenderedShape = (distToFloatOp: distToFloatOperation, rs: DistTypes.shape) : result(treeNode, string) => {
|
||||
let evaluateFromRenderedShape =
|
||||
(distToFloatOp: distToFloatOperation, rs: DistTypes.shape)
|
||||
: result(treeNode, string) => {
|
||||
let value =
|
||||
switch (distToFloatOp) {
|
||||
| `Pdf(f) => Ok(Distributions.Shape.pdf(f, rs))
|
||||
|
@ -410,8 +357,12 @@ module TreeNode = {
|
|||
|
||||
module Render = {
|
||||
let rec evaluateToRenderedShape =
|
||||
(operationToDistData: operation => result(t, string), sampleCount: int, t: treeNode)
|
||||
: result(t, string) => {
|
||||
(
|
||||
operationToDistData: operation => result(t, string),
|
||||
sampleCount: int,
|
||||
t: treeNode,
|
||||
)
|
||||
: result(t, string) => {
|
||||
switch (t) {
|
||||
| `DistData(`RenderedShape(s)) => Ok(`DistData(`RenderedShape(s))) // already a rendered shape, we're done here
|
||||
| `DistData(`Symbolic(d)) =>
|
||||
|
@ -495,10 +446,19 @@ module TreeNode = {
|
|||
t,
|
||||
)
|
||||
| `FloatFromDist(distToFloatOp, t) =>
|
||||
FloatFromDist.evaluateToDistData(distToFloatOp, operationToDistData(sampleCount), t)
|
||||
| `Normalize(t) => Normalize.evaluateToDistData(operationToDistData(sampleCount), t)
|
||||
FloatFromDist.evaluateToDistData(
|
||||
distToFloatOp,
|
||||
operationToDistData(sampleCount),
|
||||
t,
|
||||
)
|
||||
| `Normalize(t) =>
|
||||
Normalize.evaluateToDistData(operationToDistData(sampleCount), t)
|
||||
| `Render(t) =>
|
||||
Render.evaluateToRenderedShape(operationToDistData(sampleCount), sampleCount, t)
|
||||
Render.evaluateToRenderedShape(
|
||||
operationToDistData(sampleCount),
|
||||
sampleCount,
|
||||
t,
|
||||
)
|
||||
};
|
||||
};
|
||||
|
||||
|
@ -531,5 +491,4 @@ let toShape = (sampleCount: int, treeNode: treeNode) => {
|
|||
};
|
||||
};
|
||||
|
||||
let toString = (treeNode: treeNode) =>
|
||||
TreeNode.toString(treeNode);
|
||||
let toString = (treeNode: treeNode) => TreeNode.toString(treeNode);
|
||||
|
|
Loading…
Reference in New Issue
Block a user