Moving operations functionality into new SymbolicTypes.re file
This commit is contained in:
parent
acdd3dfe7a
commit
baaff19750
|
@ -1,5 +1,3 @@
|
||||||
type algebraicOperation = [ | `Add | `Multiply | `Subtract | `Divide];
|
|
||||||
|
|
||||||
type pointMassesWithMoments = {
|
type pointMassesWithMoments = {
|
||||||
n: int,
|
n: int,
|
||||||
masses: array(float),
|
masses: array(float),
|
||||||
|
@ -7,23 +5,6 @@ type pointMassesWithMoments = {
|
||||||
variances: array(float),
|
variances: array(float),
|
||||||
};
|
};
|
||||||
|
|
||||||
module Operation = {
|
|
||||||
type t = algebraicOperation;
|
|
||||||
let toFn: (t, float, float) => float =
|
|
||||||
fun
|
|
||||||
| `Add => (+.)
|
|
||||||
| `Subtract => (-.)
|
|
||||||
| `Multiply => ( *. )
|
|
||||||
| `Divide => (/.);
|
|
||||||
|
|
||||||
let toString =
|
|
||||||
fun
|
|
||||||
| `Add => " + "
|
|
||||||
| `Subtract => " - "
|
|
||||||
| `Multiply => " * "
|
|
||||||
| `Divide => " / ";
|
|
||||||
};
|
|
||||||
|
|
||||||
/* This function takes a continuous distribution and efficiently approximates it as
|
/* This function takes a continuous distribution and efficiently approximates it as
|
||||||
point masses that have variances associated with them.
|
point masses that have variances associated with them.
|
||||||
We estimate the means and variances from overlapping triangular distributions which we imagine are making up the
|
We estimate the means and variances from overlapping triangular distributions which we imagine are making up the
|
||||||
|
@ -129,7 +110,7 @@ let toDiscretePointMassesFromTriangulars =
|
||||||
};
|
};
|
||||||
|
|
||||||
let combineShapesContinuousContinuous =
|
let combineShapesContinuousContinuous =
|
||||||
(op: algebraicOperation, s1: DistTypes.xyShape, s2: DistTypes.xyShape)
|
(op: SymbolicTypes.algebraicOperation, s1: DistTypes.xyShape, s2: DistTypes.xyShape)
|
||||||
: DistTypes.xyShape => {
|
: DistTypes.xyShape => {
|
||||||
let t1n = s1 |> XYShape.T.length;
|
let t1n = s1 |> XYShape.T.length;
|
||||||
let t2n = s2 |> XYShape.T.length;
|
let t2n = s2 |> XYShape.T.length;
|
||||||
|
@ -216,4 +197,4 @@ let combineShapesContinuousContinuous =
|
||||||
};
|
};
|
||||||
|
|
||||||
{xs: outputXs, ys: outputYs};
|
{xs: outputXs, ys: outputYs};
|
||||||
};
|
};
|
|
@ -149,7 +149,7 @@ module Continuous = {
|
||||||
continuousShapes
|
continuousShapes
|
||||||
|> E.A.fold_left(combinePointwise(~knownIntegralSumsFn, fn), empty);
|
|> E.A.fold_left(combinePointwise(~knownIntegralSumsFn, fn), empty);
|
||||||
|
|
||||||
let mapY = (~knownIntegralSumFn=(_ => None), fn, t: t) => {
|
let mapY = (~knownIntegralSumFn=_ => None, fn, t: t) => {
|
||||||
let u = E.O.bind(_, knownIntegralSumFn);
|
let u = E.O.bind(_, knownIntegralSumFn);
|
||||||
let yMapFn = shapeMap(XYShape.T.mapY(fn));
|
let yMapFn = shapeMap(XYShape.T.mapY(fn));
|
||||||
|
|
||||||
|
@ -164,7 +164,6 @@ module Continuous = {
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
module T =
|
module T =
|
||||||
Dist({
|
Dist({
|
||||||
type t = DistTypes.continuousShape;
|
type t = DistTypes.continuousShape;
|
||||||
|
@ -194,9 +193,9 @@ module Continuous = {
|
||||||
|> getShape
|
|> getShape
|
||||||
|> XYShape.T.zip
|
|> XYShape.T.zip
|
||||||
|> XYShape.Zipped.filterByX(x =>
|
|> XYShape.Zipped.filterByX(x =>
|
||||||
x >= E.O.default(neg_infinity, leftCutoff)
|
x >= E.O.default(neg_infinity, leftCutoff)
|
||||||
|| x <= E.O.default(infinity, rightCutoff)
|
|| x <= E.O.default(infinity, rightCutoff)
|
||||||
);
|
);
|
||||||
|
|
||||||
let eps = (t |> getShape |> XYShape.T.xTotalRange) *. 0.0001;
|
let eps = (t |> getShape |> XYShape.T.xTotalRange) *. 0.0001;
|
||||||
|
|
||||||
|
@ -206,7 +205,11 @@ module Continuous = {
|
||||||
rightCutoff |> E.O.dimap(rc => [|(rc +. eps, 0.)|], _ => [||]);
|
rightCutoff |> E.O.dimap(rc => [|(rc +. eps, 0.)|], _ => [||]);
|
||||||
|
|
||||||
let truncatedZippedPairsWithNewPoints =
|
let truncatedZippedPairsWithNewPoints =
|
||||||
E.A.concatMany([|leftNewPoint, truncatedZippedPairs, rightNewPoint|]);
|
E.A.concatMany([|
|
||||||
|
leftNewPoint,
|
||||||
|
truncatedZippedPairs,
|
||||||
|
rightNewPoint,
|
||||||
|
|]);
|
||||||
let truncatedShape =
|
let truncatedShape =
|
||||||
XYShape.T.fromZippedArray(truncatedZippedPairsWithNewPoints);
|
XYShape.T.fromZippedArray(truncatedZippedPairsWithNewPoints);
|
||||||
|
|
||||||
|
@ -214,22 +217,20 @@ module Continuous = {
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO: This should work with stepwise plots.
|
// TODO: This should work with stepwise plots.
|
||||||
let integral = (~cache, t) => {
|
let integral = (~cache, t) =>
|
||||||
|
if (t |> getShape |> XYShape.T.length > 0) {
|
||||||
if ((t |> getShape |> XYShape.T.length) > 0) {
|
switch (cache) {
|
||||||
switch (cache) {
|
| Some(cache) => cache
|
||||||
| Some(cache) => cache
|
| None =>
|
||||||
| None =>
|
t
|
||||||
t
|
|> getShape
|
||||||
|> getShape
|
|> XYShape.Range.integrateWithTriangles
|
||||||
|> XYShape.Range.integrateWithTriangles
|
|> E.O.toExt("This should not have happened")
|
||||||
|> E.O.toExt("This should not have happened")
|
|> make(`Linear, _, None)
|
||||||
|> make(`Linear, _, None)
|
};
|
||||||
};
|
|
||||||
} else {
|
} else {
|
||||||
make(`Linear, {xs: [|neg_infinity|], ys: [|0.0|]}, None);
|
make(`Linear, {xs: [|neg_infinity|], ys: [|0.0|]}, None);
|
||||||
}
|
};
|
||||||
};
|
|
||||||
|
|
||||||
let downsample = (~cache=None, length, t): t =>
|
let downsample = (~cache=None, length, t): t =>
|
||||||
t
|
t
|
||||||
|
@ -276,23 +277,31 @@ module Continuous = {
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
||||||
/* This simply creates multiple copies of the continuous distribution, scaled and shifted according to
|
/* This simply creates multiple copies of the continuous distribution, scaled and shifted according to
|
||||||
each discrete data point, and then adds them all together. */
|
each discrete data point, and then adds them all together. */
|
||||||
let combineAlgebraicallyWithDiscrete = (~downsample=false, op: AlgebraicCombinations.algebraicOperation, t1: t, t2: DistTypes.discreteShape) => {
|
let combineAlgebraicallyWithDiscrete =
|
||||||
|
(
|
||||||
|
~downsample=false,
|
||||||
|
op: SymbolicTypes.algebraicOperation,
|
||||||
|
t1: t,
|
||||||
|
t2: DistTypes.discreteShape,
|
||||||
|
) => {
|
||||||
let t1s = t1 |> getShape;
|
let t1s = t1 |> getShape;
|
||||||
let t2s = t2.xyShape; // would like to use Discrete.getShape here, but current file structure doesn't allow for that
|
let t2s = t2.xyShape; // would like to use Discrete.getShape here, but current file structure doesn't allow for that
|
||||||
let t1n = t1s |> XYShape.T.length;
|
let t1n = t1s |> XYShape.T.length;
|
||||||
let t2n = t2s |> XYShape.T.length;
|
let t2n = t2s |> XYShape.T.length;
|
||||||
|
|
||||||
let fn = AlgebraicCombinations.Operation.toFn(op);
|
let fn = SymbolicTypes.Algebraic.toFn(op);
|
||||||
|
|
||||||
let outXYShapes: array(array((float, float))) =
|
let outXYShapes: array(array((float, float))) =
|
||||||
Belt.Array.makeUninitializedUnsafe(t2n);
|
Belt.Array.makeUninitializedUnsafe(t2n);
|
||||||
|
|
||||||
for (j in 0 to t2n - 1) { // for each one of the discrete points
|
for (j in 0 to t2n - 1) {
|
||||||
|
// for each one of the discrete points
|
||||||
// create a new distribution, as long as the original continuous one
|
// create a new distribution, as long as the original continuous one
|
||||||
let dxyShape: array((float, float)) = Belt.Array.makeUninitializedUnsafe(t1n);
|
|
||||||
|
let dxyShape: array((float, float)) =
|
||||||
|
Belt.Array.makeUninitializedUnsafe(t1n);
|
||||||
for (i in 0 to t1n - 1) {
|
for (i in 0 to t1n - 1) {
|
||||||
let _ =
|
let _ =
|
||||||
Belt.Array.set(
|
Belt.Array.set(
|
||||||
|
@ -307,7 +316,12 @@ module Continuous = {
|
||||||
();
|
();
|
||||||
};
|
};
|
||||||
|
|
||||||
let combinedIntegralSum = Common.combineIntegralSums((a, b) => Some(a *. b), t1.knownIntegralSum, t2.knownIntegralSum);
|
let combinedIntegralSum =
|
||||||
|
Common.combineIntegralSums(
|
||||||
|
(a, b) => Some(a *. b),
|
||||||
|
t1.knownIntegralSum,
|
||||||
|
t2.knownIntegralSum,
|
||||||
|
);
|
||||||
|
|
||||||
outXYShapes
|
outXYShapes
|
||||||
|> E.A.fmap(s => {
|
|> E.A.fmap(s => {
|
||||||
|
@ -318,7 +332,13 @@ module Continuous = {
|
||||||
|> updateKnownIntegralSum(combinedIntegralSum);
|
|> updateKnownIntegralSum(combinedIntegralSum);
|
||||||
};
|
};
|
||||||
|
|
||||||
let combineAlgebraically = (~downsample=false, op: AlgebraicCombinations.algebraicOperation, t1: t, t2: t) => {
|
let combineAlgebraically =
|
||||||
|
(
|
||||||
|
~downsample=false,
|
||||||
|
op: SymbolicTypes.algebraicOperation,
|
||||||
|
t1: t,
|
||||||
|
t2: t,
|
||||||
|
) => {
|
||||||
let s1 = t1 |> getShape;
|
let s1 = t1 |> getShape;
|
||||||
let s2 = t2 |> getShape;
|
let s2 = t2 |> getShape;
|
||||||
let t1n = s1 |> XYShape.T.length;
|
let t1n = s1 |> XYShape.T.length;
|
||||||
|
@ -326,8 +346,14 @@ module Continuous = {
|
||||||
if (t1n == 0 || t2n == 0) {
|
if (t1n == 0 || t2n == 0) {
|
||||||
empty;
|
empty;
|
||||||
} else {
|
} else {
|
||||||
let combinedShape = AlgebraicCombinations.combineShapesContinuousContinuous(op, s1, s2);
|
let combinedShape =
|
||||||
let combinedIntegralSum = Common.combineIntegralSums((a, b) => Some(a *. b), t1.knownIntegralSum, t2.knownIntegralSum);
|
AlgebraicCombinations.combineShapesContinuousContinuous(op, s1, s2);
|
||||||
|
let combinedIntegralSum =
|
||||||
|
Common.combineIntegralSums(
|
||||||
|
(a, b) => Some(a *. b),
|
||||||
|
t1.knownIntegralSum,
|
||||||
|
t2.knownIntegralSum,
|
||||||
|
);
|
||||||
// return a new Continuous distribution
|
// return a new Continuous distribution
|
||||||
make(`Linear, combinedShape, combinedIntegralSum);
|
make(`Linear, combinedShape, combinedIntegralSum);
|
||||||
};
|
};
|
||||||
|
@ -370,7 +396,7 @@ module Discrete = {
|
||||||
XYShape.PointwiseCombination.combine(
|
XYShape.PointwiseCombination.combine(
|
||||||
~xsSelection=ALL_XS,
|
~xsSelection=ALL_XS,
|
||||||
~xToYSelection=XYShape.XtoY.stepwiseIfAtX,
|
~xToYSelection=XYShape.XtoY.stepwiseIfAtX,
|
||||||
~fn=((a, b) => fn(E.O.default(0.0, a), E.O.default(0.0, b))), // stepwiseIfAtX returns option(float), so this fn needs to handle None
|
~fn=(a, b) => fn(E.O.default(0.0, a), E.O.default(0.0, b)), // stepwiseIfAtX returns option(float), so this fn needs to handle None
|
||||||
t1.xyShape,
|
t1.xyShape,
|
||||||
t2.xyShape,
|
t2.xyShape,
|
||||||
),
|
),
|
||||||
|
@ -378,7 +404,9 @@ module Discrete = {
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
let reduce = (~knownIntegralSumsFn=(_, _) => None, fn, discreteShapes): DistTypes.discreteShape =>
|
let reduce =
|
||||||
|
(~knownIntegralSumsFn=(_, _) => None, fn, discreteShapes)
|
||||||
|
: DistTypes.discreteShape =>
|
||||||
discreteShapes
|
discreteShapes
|
||||||
|> E.A.fold_left(combinePointwise(~knownIntegralSumsFn, fn), empty);
|
|> E.A.fold_left(combinePointwise(~knownIntegralSumsFn, fn), empty);
|
||||||
|
|
||||||
|
@ -389,7 +417,8 @@ module Discrete = {
|
||||||
|
|
||||||
/* This multiples all of the data points together and creates a new discrete distribution from the results.
|
/* This multiples all of the data points together and creates a new discrete distribution from the results.
|
||||||
Data points at the same xs get added together. It may be a good idea to downsample t1 and t2 before and/or the result after. */
|
Data points at the same xs get added together. It may be a good idea to downsample t1 and t2 before and/or the result after. */
|
||||||
let combineAlgebraically = (op: AlgebraicCombinations.algebraicOperation, t1: t, t2: t) => {
|
let combineAlgebraically =
|
||||||
|
(op: SymbolicTypes.algebraicOperation, t1: t, t2: t) => {
|
||||||
let t1s = t1 |> getShape;
|
let t1s = t1 |> getShape;
|
||||||
let t2s = t2 |> getShape;
|
let t2s = t2 |> getShape;
|
||||||
let t1n = t1s |> XYShape.T.length;
|
let t1n = t1s |> XYShape.T.length;
|
||||||
|
@ -402,7 +431,7 @@ module Discrete = {
|
||||||
t2.knownIntegralSum,
|
t2.knownIntegralSum,
|
||||||
);
|
);
|
||||||
|
|
||||||
let fn = AlgebraicCombinations.Operation.toFn(op);
|
let fn = SymbolicTypes.Algebraic.toFn(op);
|
||||||
let xToYMap = E.FloatFloatMap.empty();
|
let xToYMap = E.FloatFloatMap.empty();
|
||||||
|
|
||||||
for (i in 0 to t1n - 1) {
|
for (i in 0 to t1n - 1) {
|
||||||
|
@ -441,8 +470,8 @@ module Discrete = {
|
||||||
Dist({
|
Dist({
|
||||||
type t = DistTypes.discreteShape;
|
type t = DistTypes.discreteShape;
|
||||||
type integral = DistTypes.continuousShape;
|
type integral = DistTypes.continuousShape;
|
||||||
let integral = (~cache, t) => {
|
let integral = (~cache, t) =>
|
||||||
if ((t |> getShape |> XYShape.T.length) > 0) {
|
if (t |> getShape |> XYShape.T.length > 0) {
|
||||||
switch (cache) {
|
switch (cache) {
|
||||||
| Some(c) => c
|
| Some(c) => c
|
||||||
| None =>
|
| None =>
|
||||||
|
@ -453,9 +482,13 @@ module Discrete = {
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
} else {
|
} else {
|
||||||
Continuous.make(`Stepwise, {xs: [|neg_infinity|], ys: [|0.0|]}, None);
|
Continuous.make(
|
||||||
}};
|
`Stepwise,
|
||||||
|
{xs: [|neg_infinity|], ys: [|0.0|]},
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
let integralEndY = (~cache, t: t) =>
|
let integralEndY = (~cache, t: t) =>
|
||||||
t.knownIntegralSum
|
t.knownIntegralSum
|
||||||
|> E.O.default(t |> integral(~cache) |> Continuous.lastY);
|
|> E.O.default(t |> integral(~cache) |> Continuous.lastY);
|
||||||
|
@ -495,7 +528,7 @@ module Discrete = {
|
||||||
make(clippedShape, None); // if someone needs the sum, they'll have to recompute it
|
make(clippedShape, None); // if someone needs the sum, they'll have to recompute it
|
||||||
} else {
|
} else {
|
||||||
t;
|
t;
|
||||||
}
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
let truncate =
|
let truncate =
|
||||||
|
@ -505,9 +538,9 @@ module Discrete = {
|
||||||
|> getShape
|
|> getShape
|
||||||
|> XYShape.T.zip
|
|> XYShape.T.zip
|
||||||
|> XYShape.Zipped.filterByX(x =>
|
|> XYShape.Zipped.filterByX(x =>
|
||||||
x >= E.O.default(neg_infinity, leftCutoff)
|
x >= E.O.default(neg_infinity, leftCutoff)
|
||||||
|| x <= E.O.default(infinity, rightCutoff)
|
|| x <= E.O.default(infinity, rightCutoff)
|
||||||
)
|
)
|
||||||
|> XYShape.T.fromZippedArray;
|
|> XYShape.T.fromZippedArray;
|
||||||
|
|
||||||
make(truncatedShape, None);
|
make(truncatedShape, None);
|
||||||
|
@ -601,8 +634,10 @@ module Mixed = {
|
||||||
rightCutoff: option(float),
|
rightCutoff: option(float),
|
||||||
{discrete, continuous}: t,
|
{discrete, continuous}: t,
|
||||||
) => {
|
) => {
|
||||||
let truncatedContinuous = Continuous.T.truncate(leftCutoff, rightCutoff, continuous);
|
let truncatedContinuous =
|
||||||
let truncatedDiscrete = Discrete.T.truncate(leftCutoff, rightCutoff, discrete);
|
Continuous.T.truncate(leftCutoff, rightCutoff, continuous);
|
||||||
|
let truncatedDiscrete =
|
||||||
|
Discrete.T.truncate(leftCutoff, rightCutoff, discrete);
|
||||||
|
|
||||||
make(~discrete=truncatedDiscrete, ~continuous=truncatedContinuous);
|
make(~discrete=truncatedDiscrete, ~continuous=truncatedContinuous);
|
||||||
};
|
};
|
||||||
|
@ -809,7 +844,14 @@ module Mixed = {
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
|
|
||||||
let combineAlgebraically = (~downsample=false, op: AlgebraicCombinations.algebraicOperation, t1: t, t2: t): t => {
|
let combineAlgebraically =
|
||||||
|
(
|
||||||
|
~downsample=false,
|
||||||
|
op: SymbolicTypes.algebraicOperation,
|
||||||
|
t1: t,
|
||||||
|
t2: t,
|
||||||
|
)
|
||||||
|
: t => {
|
||||||
// Discrete convolution can cause a huge increase in the number of samples,
|
// Discrete convolution can cause a huge increase in the number of samples,
|
||||||
// so we'll first downsample.
|
// so we'll first downsample.
|
||||||
|
|
||||||
|
@ -827,11 +869,26 @@ module Mixed = {
|
||||||
// continuous (*) continuous => continuous, but also
|
// continuous (*) continuous => continuous, but also
|
||||||
// discrete (*) continuous => continuous (and vice versa). We have to take care of all combos and then combine them:
|
// discrete (*) continuous => continuous (and vice versa). We have to take care of all combos and then combine them:
|
||||||
let ccConvResult =
|
let ccConvResult =
|
||||||
Continuous.combineAlgebraically(~downsample=false, op, t1d.continuous, t2d.continuous);
|
Continuous.combineAlgebraically(
|
||||||
|
~downsample=false,
|
||||||
|
op,
|
||||||
|
t1d.continuous,
|
||||||
|
t2d.continuous,
|
||||||
|
);
|
||||||
let dcConvResult =
|
let dcConvResult =
|
||||||
Continuous.combineAlgebraicallyWithDiscrete(~downsample=false, op, t2d.continuous, t1d.discrete);
|
Continuous.combineAlgebraicallyWithDiscrete(
|
||||||
|
~downsample=false,
|
||||||
|
op,
|
||||||
|
t2d.continuous,
|
||||||
|
t1d.discrete,
|
||||||
|
);
|
||||||
let cdConvResult =
|
let cdConvResult =
|
||||||
Continuous.combineAlgebraicallyWithDiscrete(~downsample=false, op, t1d.continuous, t2d.discrete);
|
Continuous.combineAlgebraicallyWithDiscrete(
|
||||||
|
~downsample=false,
|
||||||
|
op,
|
||||||
|
t1d.continuous,
|
||||||
|
t2d.discrete,
|
||||||
|
);
|
||||||
let continuousConvResult =
|
let continuousConvResult =
|
||||||
Continuous.reduce((+.), [|ccConvResult, dcConvResult, cdConvResult|]);
|
Continuous.reduce((+.), [|ccConvResult, dcConvResult, cdConvResult|]);
|
||||||
|
|
||||||
|
@ -866,23 +923,47 @@ module Shape = {
|
||||||
c => Mixed.make(~discrete=Discrete.empty, ~continuous=c),
|
c => Mixed.make(~discrete=Discrete.empty, ~continuous=c),
|
||||||
));
|
));
|
||||||
|
|
||||||
let combineAlgebraically = (op: AlgebraicCombinations.algebraicOperation, t1: t, t2: t): t => {
|
let combineAlgebraically =
|
||||||
switch ((t1, t2)) {
|
(op: SymbolicTypes.algebraicOperation, t1: t, t2: t): t => {
|
||||||
| (Continuous(m1), Continuous(m2)) => DistTypes.Continuous(Continuous.combineAlgebraically(~downsample=true, op, m1, m2))
|
switch (t1, t2) {
|
||||||
| (Discrete(m1), Discrete(m2)) => DistTypes.Discrete(Discrete.combineAlgebraically(op, m1, m2))
|
| (Continuous(m1), Continuous(m2)) =>
|
||||||
| (m1, m2) => {
|
DistTypes.Continuous(
|
||||||
DistTypes.Mixed(Mixed.combineAlgebraically(~downsample=true, op, toMixed(m1), toMixed(m2)))
|
Continuous.combineAlgebraically(~downsample=true, op, m1, m2),
|
||||||
}
|
)
|
||||||
|
| (Discrete(m1), Discrete(m2)) =>
|
||||||
|
DistTypes.Discrete(Discrete.combineAlgebraically(op, m1, m2))
|
||||||
|
| (m1, m2) =>
|
||||||
|
DistTypes.Mixed(
|
||||||
|
Mixed.combineAlgebraically(
|
||||||
|
~downsample=true,
|
||||||
|
op,
|
||||||
|
toMixed(m1),
|
||||||
|
toMixed(m2),
|
||||||
|
),
|
||||||
|
)
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
let combinePointwise = (~knownIntegralSumsFn=(_, _) => None, fn, t1: t, t2: t) =>
|
let combinePointwise =
|
||||||
switch ((t1, t2)) {
|
(~knownIntegralSumsFn=(_, _) => None, fn, t1: t, t2: t) =>
|
||||||
| (Continuous(m1), Continuous(m2)) => DistTypes.Continuous(Continuous.combinePointwise(~knownIntegralSumsFn, fn, m1, m2))
|
switch (t1, t2) {
|
||||||
| (Discrete(m1), Discrete(m2)) => DistTypes.Discrete(Discrete.combinePointwise(~knownIntegralSumsFn, fn, m1, m2))
|
| (Continuous(m1), Continuous(m2)) =>
|
||||||
| (m1, m2) => {
|
DistTypes.Continuous(
|
||||||
DistTypes.Mixed(Mixed.combinePointwise(~knownIntegralSumsFn, fn, toMixed(m1), toMixed(m2)))
|
Continuous.combinePointwise(~knownIntegralSumsFn, fn, m1, m2),
|
||||||
}
|
)
|
||||||
|
| (Discrete(m1), Discrete(m2)) =>
|
||||||
|
DistTypes.Discrete(
|
||||||
|
Discrete.combinePointwise(~knownIntegralSumsFn, fn, m1, m2),
|
||||||
|
)
|
||||||
|
| (m1, m2) =>
|
||||||
|
DistTypes.Mixed(
|
||||||
|
Mixed.combinePointwise(
|
||||||
|
~knownIntegralSumsFn,
|
||||||
|
fn,
|
||||||
|
toMixed(m1),
|
||||||
|
toMixed(m2),
|
||||||
|
),
|
||||||
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO: implement these functions
|
// TODO: implement these functions
|
||||||
|
@ -915,7 +996,6 @@ module Shape = {
|
||||||
let toContinuous = t => None;
|
let toContinuous = t => None;
|
||||||
let toDiscrete = t => None;
|
let toDiscrete = t => None;
|
||||||
|
|
||||||
|
|
||||||
let downsample = (~cache=None, i, t) =>
|
let downsample = (~cache=None, i, t) =>
|
||||||
fmap(
|
fmap(
|
||||||
(
|
(
|
||||||
|
@ -938,7 +1018,11 @@ module Shape = {
|
||||||
|
|
||||||
let toDiscreteProbabilityMassFraction = t => 0.0;
|
let toDiscreteProbabilityMassFraction = t => 0.0;
|
||||||
let normalize =
|
let normalize =
|
||||||
fmap((Mixed.T.normalize, Discrete.T.normalize, Continuous.T.normalize));
|
fmap((
|
||||||
|
Mixed.T.normalize,
|
||||||
|
Discrete.T.normalize,
|
||||||
|
Continuous.T.normalize,
|
||||||
|
));
|
||||||
let toContinuous =
|
let toContinuous =
|
||||||
mapToAll((
|
mapToAll((
|
||||||
Mixed.T.toContinuous,
|
Mixed.T.toContinuous,
|
||||||
|
@ -1089,7 +1173,8 @@ module DistPlus = {
|
||||||
};
|
};
|
||||||
|
|
||||||
let truncate = (leftCutoff, rightCutoff, t: t): t => {
|
let truncate = (leftCutoff, rightCutoff, t: t): t => {
|
||||||
let truncatedShape = t |> toShape |> Shape.T.truncate(leftCutoff, rightCutoff);
|
let truncatedShape =
|
||||||
|
t |> toShape |> Shape.T.truncate(leftCutoff, rightCutoff);
|
||||||
|
|
||||||
t |> updateShape(truncatedShape);
|
t |> updateShape(truncatedShape);
|
||||||
};
|
};
|
||||||
|
@ -1153,9 +1238,9 @@ module DistPlus = {
|
||||||
let integralYtoX = (~cache as _, f, t: t) => {
|
let integralYtoX = (~cache as _, f, t: t) => {
|
||||||
Shape.T.Integral.yToX(~cache=Some(t.integralCache), f, toShape(t));
|
Shape.T.Integral.yToX(~cache=Some(t.integralCache), f, toShape(t));
|
||||||
};
|
};
|
||||||
let mean = (t: t) => {
|
let mean = (t: t) => {
|
||||||
Shape.T.mean(t.shape);
|
Shape.T.mean(t.shape);
|
||||||
};
|
};
|
||||||
let variance = (t: t) => Shape.T.variance(t.shape);
|
let variance = (t: t) => Shape.T.variance(t.shape);
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
|
@ -123,6 +123,12 @@ module Normal = {
|
||||||
let stdev = 1. /. (1. /. n1.stdev ** 2. +. 1. /. n2.stdev ** 2.);
|
let stdev = 1. /. (1. /. n1.stdev ** 2. +. 1. /. n2.stdev ** 2.);
|
||||||
`Normal({mean, stdev});
|
`Normal({mean, stdev});
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let operate = (operation: SymbolicTypes.Algebraic.t, n1: t, n2: t) => switch(operation){
|
||||||
|
| `Add => Some(add(n1, n2))
|
||||||
|
| `Subtract => Some(subtract(n1, n2))
|
||||||
|
| _ => None
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
module Beta = {
|
module Beta = {
|
||||||
|
@ -171,6 +177,11 @@ module Lognormal = {
|
||||||
let sigma = l1.sigma +. l2.sigma;
|
let sigma = l1.sigma +. l2.sigma;
|
||||||
`Lognormal({mu, sigma});
|
`Lognormal({mu, sigma});
|
||||||
};
|
};
|
||||||
|
let operate = (operation: SymbolicTypes.Algebraic.t, n1: t, n2: t) => switch(operation){
|
||||||
|
| `Multiply => Some(multiply(n1, n2))
|
||||||
|
| `Divide => Some(divide(n1, n2))
|
||||||
|
| _ => None
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
module Uniform = {
|
module Uniform = {
|
||||||
|
|
67
src/distPlus/symbolic/SymbolicTypes.re
Normal file
67
src/distPlus/symbolic/SymbolicTypes.re
Normal file
|
@ -0,0 +1,67 @@
|
||||||
|
type pointwiseOperation = [ | `Add | `Multiply];
|
||||||
|
type scaleOperation = [ | `Multiply | `Exponentiate | `Log];
|
||||||
|
type distToFloatOperation = [ | `Pdf(float) | `Inv(float) | `Mean | `Sample];
|
||||||
|
type algebraicOperation = [ | `Add | `Multiply | `Subtract | `Divide];
|
||||||
|
|
||||||
|
module Algebraic = {
|
||||||
|
type t = algebraicOperation;
|
||||||
|
let toFn: (t, float, float) => float =
|
||||||
|
fun
|
||||||
|
| `Add => (+.)
|
||||||
|
| `Subtract => (-.)
|
||||||
|
| `Multiply => ( *. )
|
||||||
|
| `Divide => (/.);
|
||||||
|
|
||||||
|
let applyFn = (t, f1, f2) => {
|
||||||
|
switch (t, f1, f2) {
|
||||||
|
| (`Divide, _, 0.) => Error("Cannot divide $v1 by zero.")
|
||||||
|
| _ => Ok(toFn(t, f1, f2))
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
let toString =
|
||||||
|
fun
|
||||||
|
| `Add => "+"
|
||||||
|
| `Subtract => "-"
|
||||||
|
| `Multiply => "*"
|
||||||
|
| `Divide => "/";
|
||||||
|
|
||||||
|
let format = (a, b, c) => b ++ " " ++ toString(a) ++ " " ++ c;
|
||||||
|
};
|
||||||
|
|
||||||
|
module Pointwise = {
|
||||||
|
type t = pointwiseOperation;
|
||||||
|
let toString =
|
||||||
|
fun
|
||||||
|
| `Add => "+"
|
||||||
|
| `Multiply => "*";
|
||||||
|
|
||||||
|
let format = (a, b, c) => b ++ " " ++ toString(a) ++ " " ++ c;
|
||||||
|
};
|
||||||
|
|
||||||
|
module DistToFloat = {
|
||||||
|
type t = distToFloatOperation;
|
||||||
|
|
||||||
|
let stringFromFloatFromDistOperation =
|
||||||
|
fun
|
||||||
|
| `Pdf(f) => {j|pdf(x=$f, |j}
|
||||||
|
| `Inv(f) => {j|inv(x=$f, |j}
|
||||||
|
| `Sample => "sample("
|
||||||
|
| `Mean => "mean(";
|
||||||
|
let format = (a, b) => stringFromFloatFromDistOperation(a) ++ b ++ ")";
|
||||||
|
};
|
||||||
|
|
||||||
|
module Scale = {
|
||||||
|
type t = scaleOperation;
|
||||||
|
let toFn =
|
||||||
|
fun
|
||||||
|
| `Multiply => ( *. )
|
||||||
|
| `Exponentiate => ( ** )
|
||||||
|
| `Log => ((a, b) => log(a) /. log(b));
|
||||||
|
|
||||||
|
let toKnownIntegralSumFn =
|
||||||
|
fun
|
||||||
|
| `Multiply => ((a, b) => Some(a *. b))
|
||||||
|
| `Exponentiate => ((_, _) => None)
|
||||||
|
| `Log => ((_, _) => None);
|
||||||
|
}
|
|
@ -1,86 +1,41 @@
|
||||||
/* This module represents a tree node. */
|
/* This module represents a tree node. */
|
||||||
|
open SymbolicTypes;
|
||||||
|
|
||||||
// todo: Symbolic already has an arbitrary continuousShape option. It seems messy to have both.
|
// todo: Symbolic already has an arbitrary continuousShape option. It seems messy to have both.
|
||||||
type distData = [
|
type distData = [
|
||||||
| `Symbolic(SymbolicDist.dist)
|
| `Symbolic(SymbolicDist.dist)
|
||||||
| `RenderedShape(DistTypes.shape)
|
| `RenderedShape(DistTypes.shape)
|
||||||
];
|
];
|
||||||
|
/* TreeNodes are either Data (i.e. symbolic or rendered distributions) or Operations. Operations always refer to two child nodes.*/
|
||||||
type pointwiseOperation = [ | `Add | `Multiply];
|
type treeNode = [ | `DistData(distData) | `Operation(operation)]
|
||||||
type scaleOperation = [ | `Multiply | `Exponentiate | `Log];
|
|
||||||
type distToFloatOperation = [ | `Pdf(float) | `Inv(float) | `Mean | `Sample];
|
|
||||||
|
|
||||||
/* TreeNodes are either Data (i.e. symbolic or rendered distributions) or Operations. */
|
|
||||||
type treeNode = [
|
|
||||||
| `DistData(distData) // a leaf node that describes a distribution
|
|
||||||
| `Operation(operation) // an operation on two child nodes
|
|
||||||
]
|
|
||||||
and operation = [
|
and operation = [
|
||||||
| // binary operations
|
| `AlgebraicCombination(algebraicOperation, treeNode, treeNode)
|
||||||
`AlgebraicCombination(
|
| `PointwiseCombination(pointwiseOperation, treeNode, treeNode)
|
||||||
AlgebraicCombinations.algebraicOperation,
|
| `VerticalScaling(scaleOperation, treeNode, treeNode)
|
||||||
treeNode,
|
| `Render(treeNode)
|
||||||
treeNode,
|
| `Truncate(option(float), option(float), treeNode)
|
||||||
)
|
| `Normalize(treeNode)
|
||||||
// unary operations
|
| `FloatFromDist(distToFloatOperation, treeNode)
|
||||||
| `PointwiseCombination(pointwiseOperation, treeNode, treeNode) // always evaluates to `DistData(`RenderedShape(...))
|
|
||||||
| `VerticalScaling(scaleOperation, treeNode, treeNode) // always evaluates to `DistData(`RenderedShape(...))
|
|
||||||
| `Render(treeNode) // always evaluates to `DistData(`RenderedShape(...))
|
|
||||||
| `Truncate // always evaluates to `DistData(`RenderedShape(...))
|
|
||||||
(
|
|
||||||
option(float),
|
|
||||||
option(float),
|
|
||||||
treeNode,
|
|
||||||
) // leftCutoff and rightCutoff
|
|
||||||
| `Normalize // always evaluates to `DistData(`RenderedShape(...))
|
|
||||||
// leftCutoff and rightCutoff
|
|
||||||
(
|
|
||||||
treeNode,
|
|
||||||
)
|
|
||||||
| `FloatFromDist // always evaluates to `DistData(`RenderedShape(...))
|
|
||||||
// leftCutoff and rightCutoff
|
|
||||||
(
|
|
||||||
distToFloatOperation,
|
|
||||||
treeNode,
|
|
||||||
)
|
|
||||||
];
|
];
|
||||||
|
|
||||||
module TreeNode = {
|
module TreeNode = {
|
||||||
type t = treeNode;
|
type t = treeNode;
|
||||||
type tResult = treeNode => result(treeNode, string);
|
type tResult = treeNode => result(treeNode, string);
|
||||||
|
|
||||||
let rec toString = (t: t): string => {
|
let rec toString =
|
||||||
let stringFromAlgebraicCombination =
|
fun
|
||||||
fun
|
|
||||||
| `Add => " + "
|
|
||||||
| `Subtract => " - "
|
|
||||||
| `Multiply => " * "
|
|
||||||
| `Divide => " / "
|
|
||||||
|
|
||||||
let stringFromPointwiseCombination =
|
|
||||||
fun
|
|
||||||
| `Add => " .+ "
|
|
||||||
| `Multiply => " .* ";
|
|
||||||
|
|
||||||
let stringFromFloatFromDistOperation =
|
|
||||||
fun
|
|
||||||
| `Pdf(f) => {j|pdf(x=$f, |j}
|
|
||||||
| `Inv(f) => {j|inv(x=$f, |j}
|
|
||||||
| `Sample => "sample("
|
|
||||||
| `Mean => "mean(";
|
|
||||||
|
|
||||||
switch (t) {
|
|
||||||
| `DistData(`Symbolic(d)) =>
|
| `DistData(`Symbolic(d)) =>
|
||||||
SymbolicDist.GenericDistFunctions.toString(d)
|
SymbolicDist.GenericDistFunctions.toString(d)
|
||||||
| `DistData(`RenderedShape(_)) => "[shape]"
|
| `DistData(`RenderedShape(_)) => "[shape]"
|
||||||
| `Operation(`AlgebraicCombination(op, t1, t2)) =>
|
| `Operation(`AlgebraicCombination(op, t1, t2)) =>
|
||||||
toString(t1) ++ stringFromAlgebraicCombination(op) ++ toString(t2)
|
SymbolicTypes.Algebraic.format(op, toString(t1), toString(t2))
|
||||||
| `Operation(`PointwiseCombination(op, t1, t2)) =>
|
| `Operation(`PointwiseCombination(op, t1, t2)) =>
|
||||||
toString(t1) ++ stringFromPointwiseCombination(op) ++ toString(t2)
|
SymbolicTypes.Pointwise.format(op, toString(t1), toString(t2))
|
||||||
| `Operation(`VerticalScaling(_scaleOp, t, scaleBy)) =>
|
| `Operation(`VerticalScaling(_scaleOp, t, scaleBy)) =>
|
||||||
toString(t) ++ " @ " ++ toString(scaleBy)
|
toString(t) ++ " @ " ++ toString(scaleBy)
|
||||||
| `Operation(`Normalize(t)) => "normalize(" ++ toString(t) ++ ")"
|
| `Operation(`Normalize(t)) => "normalize(" ++ toString(t) ++ ")"
|
||||||
| `Operation(`FloatFromDist(floatFromDistOp, t)) => stringFromFloatFromDistOperation(floatFromDistOp) ++ toString(t) ++ ")"
|
| `Operation(`FloatFromDist(floatFromDistOp, t)) =>
|
||||||
|
SymbolicTypes.DistToFloat.format(floatFromDistOp, toString(t))
|
||||||
| `Operation(`Truncate(lc, rc, t)) =>
|
| `Operation(`Truncate(lc, rc, t)) =>
|
||||||
"truncate("
|
"truncate("
|
||||||
++ toString(t)
|
++ toString(t)
|
||||||
|
@ -89,9 +44,7 @@ module TreeNode = {
|
||||||
++ ", "
|
++ ", "
|
||||||
++ E.O.dimap(Js.Float.toString, () => "inf", rc)
|
++ E.O.dimap(Js.Float.toString, () => "inf", rc)
|
||||||
++ ")"
|
++ ")"
|
||||||
| `Operation(`Render(t)) => toString(t)
|
| `Operation(`Render(t)) => toString(t);
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
/* The following modules encapsulate everything we can do with
|
/* The following modules encapsulate everything we can do with
|
||||||
* different kinds of operations. */
|
* different kinds of operations. */
|
||||||
|
@ -104,88 +57,72 @@ module TreeNode = {
|
||||||
let simplify = (algebraicOp, t1: t, t2: t): result(treeNode, string) => {
|
let simplify = (algebraicOp, t1: t, t2: t): result(treeNode, string) => {
|
||||||
let tryCombiningFloats: tResult =
|
let tryCombiningFloats: tResult =
|
||||||
fun
|
fun
|
||||||
| `Operation(
|
|
||||||
`AlgebraicCombination(
|
|
||||||
`Divide,
|
|
||||||
`DistData(`Symbolic(`Float(_))),
|
|
||||||
`DistData(`Symbolic(`Float(0.))),
|
|
||||||
),
|
|
||||||
) =>
|
|
||||||
Error("Cannot divide $v1 by zero.")
|
|
||||||
| `Operation(
|
| `Operation(
|
||||||
`AlgebraicCombination(
|
`AlgebraicCombination(
|
||||||
algebraicOp,
|
algebraicOp,
|
||||||
`DistData(`Symbolic(`Float(v1))),
|
`DistData(`Symbolic(`Float(v1))),
|
||||||
`DistData(`Symbolic(`Float(v2))),
|
`DistData(`Symbolic(`Float(v2))),
|
||||||
),
|
),
|
||||||
) => {
|
) =>
|
||||||
let func = AlgebraicCombinations.Operation.toFn(algebraicOp);
|
SymbolicTypes.Algebraic.applyFn(algebraicOp, v1, v2)
|
||||||
Ok(`DistData(`Symbolic(`Float(func(v1, v2)))));
|
|> E.R.fmap(r => `DistData(`Symbolic(`Float(r))))
|
||||||
}
|
|
||||||
| t => Ok(t);
|
| t => Ok(t);
|
||||||
|
|
||||||
|
let optionToSymbolicResult = (t, o) =>
|
||||||
|
o
|
||||||
|
|> E.O.dimap(r => `DistData(`Symbolic(r)), () => t)
|
||||||
|
|> (r => Ok(r));
|
||||||
|
|
||||||
let tryCombiningNormals: tResult =
|
let tryCombiningNormals: tResult =
|
||||||
fun
|
fun
|
||||||
| `Operation(
|
| `Operation(
|
||||||
`AlgebraicCombination(
|
`AlgebraicCombination(
|
||||||
`Add,
|
operation,
|
||||||
`DistData(`Symbolic(`Normal(n1))),
|
`DistData(`Symbolic(`Normal(n1))),
|
||||||
`DistData(`Symbolic(`Normal(n2))),
|
`DistData(`Symbolic(`Normal(n2))),
|
||||||
),
|
),
|
||||||
) =>
|
) as t =>
|
||||||
Ok(`DistData(`Symbolic(SymbolicDist.Normal.add(n1, n2))))
|
SymbolicDist.Normal.operate(operation, n1, n2)
|
||||||
| `Operation(
|
|> optionToSymbolicResult(t)
|
||||||
`AlgebraicCombination(
|
|
||||||
`Subtract,
|
|
||||||
`DistData(`Symbolic(`Normal(n1))),
|
|
||||||
`DistData(`Symbolic(`Normal(n2))),
|
|
||||||
),
|
|
||||||
) =>
|
|
||||||
Ok(`DistData(`Symbolic(SymbolicDist.Normal.subtract(n1, n2))))
|
|
||||||
| t => Ok(t);
|
| t => Ok(t);
|
||||||
|
|
||||||
let tryCombiningLognormals: tResult =
|
let tryCombiningLognormals: tResult =
|
||||||
fun
|
fun
|
||||||
| `Operation(
|
| `Operation(
|
||||||
`AlgebraicCombination(
|
`AlgebraicCombination(
|
||||||
`Multiply,
|
operation,
|
||||||
`DistData(`Symbolic(`Lognormal(l1))),
|
`DistData(`Symbolic(`Lognormal(n1))),
|
||||||
`DistData(`Symbolic(`Lognormal(l2))),
|
`DistData(`Symbolic(`Lognormal(n2))),
|
||||||
),
|
),
|
||||||
) =>
|
) as t =>
|
||||||
Ok(`DistData(`Symbolic(SymbolicDist.Lognormal.multiply(l1, l2))))
|
SymbolicDist.Lognormal.operate(operation, n1, n2)
|
||||||
| `Operation(
|
|> optionToSymbolicResult(t)
|
||||||
`AlgebraicCombination(
|
|
||||||
`Divide,
|
|
||||||
`DistData(`Symbolic(`Lognormal(l1))),
|
|
||||||
`DistData(`Symbolic(`Lognormal(l2))),
|
|
||||||
),
|
|
||||||
) =>
|
|
||||||
Ok(`DistData(`Symbolic(SymbolicDist.Lognormal.divide(l1, l2))))
|
|
||||||
| t => Ok(t);
|
| t => Ok(t);
|
||||||
|
|
||||||
let originalTreeNode =
|
let originalTreeNode =
|
||||||
`Operation(`AlgebraicCombination((algebraicOp, t1, t2)));
|
`Operation(`AlgebraicCombination((algebraicOp, t1, t2)));
|
||||||
|
|
||||||
|
// Feedback: I like this pattern, kudos
|
||||||
originalTreeNode
|
originalTreeNode
|
||||||
|> tryCombiningFloats
|
|> tryCombiningFloats
|
||||||
|> E.R.bind(_, tryCombiningNormals)
|
|> E.R.bind(_, tryCombiningNormals)
|
||||||
|> E.R.bind(_, tryCombiningLognormals);
|
|> E.R.bind(_, tryCombiningLognormals);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// todo: I don't like the name evaluateNumerically that much, if this renders and does it algebraically. It's tricky.
|
||||||
let evaluateNumerically = (algebraicOp, operationToDistData, t1, t2) => {
|
let evaluateNumerically = (algebraicOp, operationToDistData, t1, t2) => {
|
||||||
// force rendering into shapes
|
// force rendering into shapes
|
||||||
let renderedShape1 = operationToDistData(`Render(t1));
|
let renderShape = r => operationToDistData(`Render(r));
|
||||||
let renderedShape2 = operationToDistData(`Render(t2));
|
switch (renderShape(t1), renderShape(t2)) {
|
||||||
|
|
||||||
switch (renderedShape1, renderedShape2) {
|
|
||||||
| (
|
| (
|
||||||
Ok(`DistData(`RenderedShape(s1))),
|
Ok(`DistData(`RenderedShape(s1))),
|
||||||
Ok(`DistData(`RenderedShape(s2))),
|
Ok(`DistData(`RenderedShape(s2))),
|
||||||
) =>
|
) =>
|
||||||
Ok(
|
Ok(
|
||||||
`DistData(
|
`DistData(
|
||||||
`RenderedShape(Distributions.Shape.combineAlgebraically(algebraicOp, s1, s2)),
|
`RenderedShape(
|
||||||
|
Distributions.Shape.combineAlgebraically(algebraicOp, s1, s2),
|
||||||
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
| (Error(e1), _) => Error(e1)
|
| (Error(e1), _) => Error(e1)
|
||||||
|
@ -195,7 +132,12 @@ module TreeNode = {
|
||||||
};
|
};
|
||||||
|
|
||||||
let evaluateToDistData =
|
let evaluateToDistData =
|
||||||
(algebraicOp: AlgebraicCombinations.algebraicOperation, operationToDistData, t1: t, t2: t)
|
(
|
||||||
|
algebraicOp: SymbolicTypes.algebraicOperation,
|
||||||
|
operationToDistData,
|
||||||
|
t1: t,
|
||||||
|
t2: t,
|
||||||
|
)
|
||||||
: result(treeNode, string) =>
|
: result(treeNode, string) =>
|
||||||
algebraicOp
|
algebraicOp
|
||||||
|> simplify(_, t1, t2)
|
|> simplify(_, t1, t2)
|
||||||
|
@ -210,27 +152,13 @@ module TreeNode = {
|
||||||
};
|
};
|
||||||
|
|
||||||
module VerticalScaling = {
|
module VerticalScaling = {
|
||||||
let fnFromOp =
|
|
||||||
fun
|
|
||||||
| `Multiply => ( *. )
|
|
||||||
| `Exponentiate => ( ** )
|
|
||||||
| `Log => ((a, b) => log(a) /. log(b));
|
|
||||||
|
|
||||||
let knownIntegralSumFnFromOp =
|
|
||||||
fun
|
|
||||||
| `Multiply => ((a, b) => Some(a *. b))
|
|
||||||
| `Exponentiate => ((_, _) => None)
|
|
||||||
| `Log => ((_, _) => None);
|
|
||||||
|
|
||||||
let evaluateToDistData = (scaleOp, operationToDistData, t, scaleBy) => {
|
let evaluateToDistData = (scaleOp, operationToDistData, t, scaleBy) => {
|
||||||
// scaleBy has to be a single float, otherwise we'll return an error.
|
// scaleBy has to be a single float, otherwise we'll return an error.
|
||||||
let fn = fnFromOp(scaleOp);
|
let fn = SymbolicTypes.Scale.toFn(scaleOp);
|
||||||
let knownIntegralSumFn = knownIntegralSumFnFromOp(scaleOp);
|
let knownIntegralSumFn = SymbolicTypes.Scale.toKnownIntegralSumFn(scaleOp);
|
||||||
|
|
||||||
let renderedShape = operationToDistData(`Render(t));
|
let renderedShape = operationToDistData(`Render(t));
|
||||||
|
|
||||||
switch (renderedShape, scaleBy) {
|
switch (renderedShape, scaleBy) {
|
||||||
| (Error(e1), _) => Error(e1)
|
|
||||||
| (
|
| (
|
||||||
Ok(`DistData(`RenderedShape(rs))),
|
Ok(`DistData(`RenderedShape(rs))),
|
||||||
`DistData(`Symbolic(`Float(sm))),
|
`DistData(`Symbolic(`Float(sm))),
|
||||||
|
@ -246,6 +174,7 @@ module TreeNode = {
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
| (Error(e1), _) => Error(e1)
|
||||||
| (_, _) => Error("Can only scale by float values.")
|
| (_, _) => Error("Can only scale by float values.")
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
@ -253,14 +182,28 @@ module TreeNode = {
|
||||||
|
|
||||||
module PointwiseCombination = {
|
module PointwiseCombination = {
|
||||||
let pointwiseAdd = (operationToDistData, t1, t2) => {
|
let pointwiseAdd = (operationToDistData, t1, t2) => {
|
||||||
let renderedShape1 = operationToDistData(`Render(t1));
|
let renderedShape1 = operationToDistData(`Render(t1));
|
||||||
let renderedShape2 = operationToDistData(`Render(t2));
|
let renderedShape2 = operationToDistData(`Render(t2));
|
||||||
|
|
||||||
switch ((renderedShape1, renderedShape2)) {
|
switch (renderedShape1, renderedShape2) {
|
||||||
|
| (
|
||||||
|
Ok(`DistData(`RenderedShape(rs1))),
|
||||||
|
Ok(`DistData(`RenderedShape(rs2))),
|
||||||
|
) =>
|
||||||
|
Ok(
|
||||||
|
`DistData(
|
||||||
|
`RenderedShape(
|
||||||
|
Distributions.Shape.combinePointwise(
|
||||||
|
~knownIntegralSumsFn=(a, b) => Some(a +. b),
|
||||||
|
(+.),
|
||||||
|
rs1,
|
||||||
|
rs2,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
| (Error(e1), _) => Error(e1)
|
| (Error(e1), _) => Error(e1)
|
||||||
| (_, Error(e2)) => Error(e2)
|
| (_, Error(e2)) => Error(e2)
|
||||||
| (Ok(`DistData(`RenderedShape(rs1))), Ok(`DistData(`RenderedShape(rs2)))) =>
|
|
||||||
Ok(`DistData(`RenderedShape(Distributions.Shape.combinePointwise(~knownIntegralSumsFn=(a, b) => Some(a +. b), (+.), rs1, rs2))))
|
|
||||||
| _ => Error("Could not perform pointwise addition.")
|
| _ => Error("Could not perform pointwise addition.")
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
@ -268,14 +211,16 @@ module TreeNode = {
|
||||||
let pointwiseMultiply = (operationToDistData, t1, t2) => {
|
let pointwiseMultiply = (operationToDistData, t1, t2) => {
|
||||||
// TODO: construct a function that we can easily sample from, to construct
|
// TODO: construct a function that we can easily sample from, to construct
|
||||||
// a RenderedShape. Use the xMin and xMax of the rendered shapes to tell the sampling function where to look.
|
// a RenderedShape. Use the xMin and xMax of the rendered shapes to tell the sampling function where to look.
|
||||||
Error("Pointwise multiplication not yet supported.");
|
Error(
|
||||||
|
"Pointwise multiplication not yet supported.",
|
||||||
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
let evaluateToDistData = (pointwiseOp, operationToDistData, t1, t2) => {
|
let evaluateToDistData = (pointwiseOp, operationToDistData, t1, t2) => {
|
||||||
switch (pointwiseOp) {
|
switch (pointwiseOp) {
|
||||||
| `Add => pointwiseAdd(operationToDistData, t1, t2)
|
| `Add => pointwiseAdd(operationToDistData, t1, t2)
|
||||||
| `Multiply => pointwiseMultiply(operationToDistData, t1, t2)
|
| `Multiply => pointwiseMultiply(operationToDistData, t1, t2)
|
||||||
}
|
};
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -378,7 +323,9 @@ module TreeNode = {
|
||||||
};
|
};
|
||||||
E.R.bind(value, v => Ok(`DistData(`Symbolic(`Float(v)))));
|
E.R.bind(value, v => Ok(`DistData(`Symbolic(`Float(v)))));
|
||||||
};
|
};
|
||||||
let evaluateFromRenderedShape = (distToFloatOp: distToFloatOperation, rs: DistTypes.shape) : result(treeNode, string) => {
|
let evaluateFromRenderedShape =
|
||||||
|
(distToFloatOp: distToFloatOperation, rs: DistTypes.shape)
|
||||||
|
: result(treeNode, string) => {
|
||||||
let value =
|
let value =
|
||||||
switch (distToFloatOp) {
|
switch (distToFloatOp) {
|
||||||
| `Pdf(f) => Ok(Distributions.Shape.pdf(f, rs))
|
| `Pdf(f) => Ok(Distributions.Shape.pdf(f, rs))
|
||||||
|
@ -410,8 +357,12 @@ module TreeNode = {
|
||||||
|
|
||||||
module Render = {
|
module Render = {
|
||||||
let rec evaluateToRenderedShape =
|
let rec evaluateToRenderedShape =
|
||||||
(operationToDistData: operation => result(t, string), sampleCount: int, t: treeNode)
|
(
|
||||||
: result(t, string) => {
|
operationToDistData: operation => result(t, string),
|
||||||
|
sampleCount: int,
|
||||||
|
t: treeNode,
|
||||||
|
)
|
||||||
|
: result(t, string) => {
|
||||||
switch (t) {
|
switch (t) {
|
||||||
| `DistData(`RenderedShape(s)) => Ok(`DistData(`RenderedShape(s))) // already a rendered shape, we're done here
|
| `DistData(`RenderedShape(s)) => Ok(`DistData(`RenderedShape(s))) // already a rendered shape, we're done here
|
||||||
| `DistData(`Symbolic(d)) =>
|
| `DistData(`Symbolic(d)) =>
|
||||||
|
@ -495,10 +446,19 @@ module TreeNode = {
|
||||||
t,
|
t,
|
||||||
)
|
)
|
||||||
| `FloatFromDist(distToFloatOp, t) =>
|
| `FloatFromDist(distToFloatOp, t) =>
|
||||||
FloatFromDist.evaluateToDistData(distToFloatOp, operationToDistData(sampleCount), t)
|
FloatFromDist.evaluateToDistData(
|
||||||
| `Normalize(t) => Normalize.evaluateToDistData(operationToDistData(sampleCount), t)
|
distToFloatOp,
|
||||||
|
operationToDistData(sampleCount),
|
||||||
|
t,
|
||||||
|
)
|
||||||
|
| `Normalize(t) =>
|
||||||
|
Normalize.evaluateToDistData(operationToDistData(sampleCount), t)
|
||||||
| `Render(t) =>
|
| `Render(t) =>
|
||||||
Render.evaluateToRenderedShape(operationToDistData(sampleCount), sampleCount, t)
|
Render.evaluateToRenderedShape(
|
||||||
|
operationToDistData(sampleCount),
|
||||||
|
sampleCount,
|
||||||
|
t,
|
||||||
|
)
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -531,5 +491,4 @@ let toShape = (sampleCount: int, treeNode: treeNode) => {
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
let toString = (treeNode: treeNode) =>
|
let toString = (treeNode: treeNode) => TreeNode.toString(treeNode);
|
||||||
TreeNode.toString(treeNode);
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user