First working prototype of algebraic combinations

This commit is contained in:
Sebastian Kosch 2020-06-29 22:29:15 -07:00
parent f1e2458bca
commit 502481e345
6 changed files with 331 additions and 287 deletions

View File

@ -172,7 +172,7 @@ let make = () => {
~onSubmit=({state}) => {None}, ~onSubmit=({state}) => {None},
~initialState={ ~initialState={
//guesstimatorString: "mm(normal(-10, 2), uniform(18, 25), lognormal({mean: 10, stdev: 8}), triangular(31,40,50))", //guesstimatorString: "mm(normal(-10, 2), uniform(18, 25), lognormal({mean: 10, stdev: 8}), triangular(31,40,50))",
guesstimatorString: "uniform(0, 1) * normal(1, 2)", guesstimatorString: "uniform(0, 1) * normal(1, 2) - 1",
domainType: "Complete", domainType: "Complete",
xPoint: "50.0", xPoint: "50.0",
xPoint2: "60.0", xPoint2: "60.0",

View File

@ -0,0 +1,161 @@
// The binary arithmetic operations that can be used to combine two distributions
// algebraically. Each tag is mapped to its float operator by operationToFn.
type algebraicOperation = [
| `Add
| `Multiply
| `Subtract
| `Divide
];
// A continuous shape approximated as a set of weighted points, each carrying the first
// two moments of the piece of the distribution it stands for.
// The three arrays are parallel: index k in each describes the same point mass.
type pointMassesWithMoments = {
n: int, // number of point masses, i.e. the length of each array below
masses: array(float), // weight of each point mass
means: array(float), // mean associated with each point mass
variances: array(float) // variance associated with each point mass
};
// Translates an algebraic operation tag into the corresponding float arithmetic.
// Can be partially applied with just the tag to obtain a (float, float) => float.
let operationToFn: (algebraicOperation, float, float) => float =
  (operation, left, right) =>
    switch (operation) {
    | `Add => left +. right
    | `Subtract => left -. right
    | `Multiply => left *. right
    | `Divide => left /. right
    };
/* Approximates a continuous distribution (given as an XYShape) by point masses that
   have means and variances associated with them.

   Every point of the shape is treated as the peak c of a triangular distribution that
   spans from its left neighbour a to its right neighbour b. The leftmost and rightmost
   points are doubled up so that the end points get (one-sided) triangles as well.
   The result holds one point mass per point of the input shape:
     - mass:     (b - a) * y / 2, the area of the triangle of height y
     - mean:     mean of the triangular distribution
     - variance: variance of the triangular distribution

   With ~inverse=true the mean and variance are those of 1/X instead of X, which is
   what a division needs for its right-hand operand. The reciprocal moments are only
   meaningful when the whole triangle lies strictly on one side of zero; otherwise the
   logarithms below are undefined and the moments come out as NaN.

   The point masses can then be combined with the algebra of random variables and
   turned back into a distribution, e.g. using a Fast Gauss Transform or
   Raykar et al. (2007).

   The input shape is never modified. An empty shape yields an empty result. */
let toDiscretePointMassesFromTriangulars = (~inverse=false, s: XYShape.T.t): pointMassesWithMoments => {
  let emptyResult: pointMassesWithMoments = {n: 0, masses: [||], means: [||], variances: [||]};
  let pointCount = s |> XYShape.T.length;
  if (pointCount == 0) {
    emptyResult;
  } else {
    let {xs: shapeXs, ys: shapeYs}: XYShape.T.t = s;
    // Double up the leftmost and rightmost points. This is done on fresh copies:
    // pushing onto the arrays of the shape itself would grow the distribution of the
    // caller on every call.
    let xs = Belt.Array.concatMany([|[|shapeXs[0]|], shapeXs, [|shapeXs[pointCount - 1]|]|]);
    let ys = Belt.Array.concatMany([|[|shapeYs[0]|], shapeYs, [|shapeYs[pointCount - 1]|]|]);
    // the padded first and last points do not get point masses of their own
    let count = E.A.length(xs) - 2;

    // (ln q - ln p) / (q - p), i.e. the average of 1/x over [p, q].
    // For p == q this is taken as its limit 1/p; the padded end points always hit this case.
    let meanReciprocalBetween = (p, q) => p == q ? 1. /. p : log(q /. p) /. (q -. p);

    // (mean, variance) of the triangle with minimum a, peak c and maximum b
    let momentsOf = (a, c, b) =>
      if (inverse) {
        // Moments of the reciprocal of a triangular distribution, obtained by integrating
        // the density against 1/x and 1/x^2:
        //   E[1/X]   = 2 * (b * L(c,b) - a * L(a,c)) / (b - a)
        //   E[1/X^2] = 2 * (L(a,c) - L(c,b)) / (b - a)
        // where L = meanReciprocalBetween.
        // NOTE(review): a == b (zero-width triangle) still yields NaN here; such a
        // triangle has zero mass.
        let left = meanReciprocalBetween(a, c);
        let right = meanReciprocalBetween(c, b);
        let inverseMean = 2. *. (b *. right -. a *. left) /. (b -. a);
        let inverseSecondMoment = 2. *. (left -. right) /. (b -. a);
        (inverseMean, inverseSecondMoment -. inverseMean ** 2.);
      } else {
        (
          (a +. c +. b) /. 3.,
          (a *. a +. c *. c +. b *. b -. a *. c -. c *. b -. a *. b) /. 18.,
        );
      };

    // index k of the output corresponds to index k + 1 of the padded arrays
    let masses = Belt.Array.makeBy(count, k => (xs[k + 2] -. xs[k]) *. ys[k + 1] /. 2.);
    let moments = Belt.Array.makeBy(count, k => momentsOf(xs[k], xs[k + 1], xs[k + 2]));
    let means = Belt.Array.map(moments, ((mean, _)) => mean);
    let variances = Belt.Array.map(moments, ((_, variance)) => variance);
    {n: count, masses, means, variances};
  };
};
/* Combines two continuous shapes algebraically (add, subtract, multiply or divide the
   underlying random variables) and returns the shape of the result.

   Both shapes are first approximated as point masses with means and variances. Every
   pair of point masses is then combined using the algebra of independent random
   variables, and the output density is reconstructed on an evenly spaced grid by
   summing one normalized Gaussian kernel per combined point mass.

   For `Divide the second shape is converted with reciprocal moments, so a division
   is computed as the first operand times the reciprocal of the second.

   Returns an empty shape if either input is empty, or if no combined point mass has
   a usable (positive, non-NaN) variance. */
let combineShapesContinuousContinuous = (op: algebraicOperation, s1: DistTypes.xyShape, s2: DistTypes.xyShape): DistTypes.xyShape => {
  // number of evenly spaced points on which the output is evaluated
  let outputPointCount = 200;
  // z-score of the 95th percentile of the standard normal distribution
  let tailZScore = 1.644854;
  let emptyShape: DistTypes.xyShape = {xs: [||], ys: [||]};

  let t1n = s1 |> XYShape.T.length;
  let t2n = s2 |> XYShape.T.length;
  if (t1n == 0 || t2n == 0) {
    emptyShape;
  } else {
    // if we add the two distributions, we should probably use normal filters.
    // if we multiply the two distributions, we should probably use lognormal filters.
    let useReciprocalMoments =
      switch (op) {
      | `Divide => true
      | `Add
      | `Subtract
      | `Multiply => false
      };
    let t1m = toDiscretePointMassesFromTriangulars(s1);
    let t2m = toDiscretePointMassesFromTriangulars(~inverse=useReciprocalMoments, s2);

    // note: for `Divide, the second argument is mInv2 = mean(1 / t2)
    let combineMeansFn =
      switch (op) {
      | `Add => ((m1, m2) => m1 +. m2)
      | `Subtract => ((m1, m2) => m1 -. m2)
      | `Multiply => ((m1, m2) => m1 *. m2)
      | `Divide => ((m1, mInv2) => m1 *. mInv2)
      };

    // converts the variances and means of the two inputs into the variance of the output;
    // for independent X and Y: Var(XY) = Var(X)Var(Y) + Var(X)E[Y]^2 + Var(Y)E[X]^2
    let combineVariancesFn =
      switch (op) {
      | `Add => ((v1, v2, _m1, _m2) => v1 +. v2)
      | `Subtract => ((v1, v2, _m1, _m2) => v1 +. v2)
      | `Multiply => ((v1, v2, m1, m2) => v1 *. v2 +. v1 *. m2 ** 2. +. v2 *. m1 ** 2.)
      | `Divide => ((v1, vInv2, m1, mInv2) => v1 *. vInv2 +. v1 *. mInv2 ** 2. +. vInv2 *. m1 ** 2.)
      };

    let combinedCount = t1m.n * t2m.n;
    let masses: array(float) = Belt.Array.makeUninitializedUnsafe(combinedCount);
    let means: array(float) = Belt.Array.makeUninitializedUnsafe(combinedCount);
    let variances: array(float) = Belt.Array.makeUninitializedUnsafe(combinedCount);
    let outputMinX: ref(float) = ref(infinity);
    let outputMaxX: ref(float) = ref(neg_infinity);

    // combine every pair of point masses
    for (i in 0 to t1m.n - 1) {
      for (j in 0 to t2m.n - 1) {
        let k = i * t2m.n + j;
        let mean = combineMeansFn(t1m.means[i], t2m.means[j]);
        let variance = combineVariancesFn(t1m.variances[i], t2m.variances[j], t1m.means[i], t2m.means[j]);
        Belt.Array.set(masses, k, t1m.masses[i] *. t2m.masses[j]) |> ignore;
        Belt.Array.set(means, k, mean) |> ignore;
        Belt.Array.set(variances, k, variance) |> ignore;

        // Update the bounds: mean +/- z * standard deviation. Entries with a zero,
        // negative or NaN variance cannot be drawn as a kernel and are left out.
        if (variance > 0.0) {
          let halfWidth = sqrt(variance) *. tailZScore;
          if (mean -. halfWidth < outputMinX^) {
            outputMinX := mean -. halfWidth;
          };
          if (mean +. halfWidth > outputMaxX^) {
            outputMaxX := mean +. halfWidth;
          };
        };
      };
    };

    if (outputMinX^ > outputMaxX^) {
      // not a single usable point mass: there is nothing to draw
      emptyShape;
    } else {
      // For now, the target points are evenly distributed between outputMinX and outputMaxX.
      let outputXs: array(float) = E.A.Floats.range(outputMinX^, outputMaxX^, outputPointCount);
      // For each target point, accumulate the normalized Gaussian kernel of every point mass.
      let outputYs: array(float) =
        Belt.Array.map(outputXs, x => {
          let density = ref(0.0);
          for (k in 0 to combinedCount - 1) {
            let variance = variances[k];
            if (variance > 0.0) {
              let dx = x -. means[k];
              density :=
                density^
                +. masses[k]
                *. exp(-. (dx ** 2.) /. (2. *. variance))
                /. sqrt(2. *. Js.Math._PI *. variance);
            };
          };
          density^;
        });
      {xs: outputXs, ys: outputYs};
    };
  };
};

View File

@ -95,7 +95,7 @@ module Continuous = {
interpolation: `Linear, interpolation: `Linear,
knownIntegralSum: Some(0.0), knownIntegralSum: Some(0.0),
}; };
let combine = let combinePointwise =
( (
~knownIntegralSumsFn, ~knownIntegralSumsFn,
fn, fn,
@ -114,7 +114,7 @@ module Continuous = {
make( make(
`Linear, `Linear,
XYShape.Combine.combine( XYShape.PointwiseCombination.combine(
~xsSelection=ALL_XS, ~xsSelection=ALL_XS,
~xToYSelection=XYShape.XtoY.linear, ~xToYSelection=XYShape.XtoY.linear,
~fn, ~fn,
@ -147,42 +147,9 @@ module Continuous = {
continuousShapes, continuousShapes,
) => ) =>
continuousShapes continuousShapes
|> E.A.fold_left(combine(~knownIntegralSumsFn, fn), empty); |> E.A.fold_left(combinePointwise(~knownIntegralSumsFn, fn), empty);
// Contracts every point in the continuous xyShape into a single dirac-Delta-like point, let mapY = (~knownIntegralSumFn=(_ => None), fn, t: t) => {
// using the centerpoints between adjacent xs and the area under each trapezoid.
// This is essentially like integrateWithTriangles, without the accumulation.
let toDiscretePointMasses = (t: t): DistTypes.discreteShape => {
let tl = t |> getShape |> XYShape.T.length;
let pointMassesX: array(float) = Belt.Array.make(tl - 1, 0.0);
let pointMassesY: array(float) = Belt.Array.make(tl - 1, 0.0);
let {xs, ys}: XYShape.T.t = t |> getShape;
for (x in 0 to E.A.length(xs) - 2) {
let _ =
Belt.Array.set(
pointMassesY,
x,
(xs[x + 1] -. xs[x]) *. ((ys[x] +. ys[x + 1]) /. 2.),
); // = dx * (1/2) * (avgY)
let _ =
Belt.Array.set(
pointMassesX,
x,
(xs[x] +. xs[x + 1]) /. 2.,
); // midpoints
();
};
{
xyShape: {
xs: pointMassesX,
ys: pointMassesY,
},
knownIntegralSum: t.knownIntegralSum,
};
};
let mapY = (~knownIntegralSumFn=previousKnownIntegralSum => None, fn, t: t) => {
let u = E.O.bind(_, knownIntegralSumFn); let u = E.O.bind(_, knownIntegralSumFn);
let yMapFn = shapeMap(XYShape.T.mapY(fn)); let yMapFn = shapeMap(XYShape.T.mapY(fn));
@ -247,7 +214,9 @@ module Continuous = {
}; };
// TODO: This should work with stepwise plots. // TODO: This should work with stepwise plots.
let integral = (~cache, t) => let integral = (~cache, t) => {
if ((t |> getShape |> XYShape.T.length) > 0) {
switch (cache) { switch (cache) {
| Some(cache) => cache | Some(cache) => cache
| None => | None =>
@ -257,6 +226,11 @@ module Continuous = {
|> E.O.toExt("This should not have happened") |> E.O.toExt("This should not have happened")
|> make(`Linear, _, None) |> make(`Linear, _, None)
}; };
} else {
make(`Linear, {xs: [|neg_infinity|], ys: [|0.0|]}, None);
}
};
let downsample = (~cache=None, length, t): t => let downsample = (~cache=None, length, t): t =>
t t
|> shapeMap( |> shapeMap(
@ -287,6 +261,7 @@ module Continuous = {
let indefiniteIntegralStepwise = (p, h1) => h1 *. p ** 2.0 /. 2.0; let indefiniteIntegralStepwise = (p, h1) => h1 *. p ** 2.0 /. 2.0;
let indefiniteIntegralLinear = (p, a, b) => let indefiniteIntegralLinear = (p, a, b) =>
a *. p ** 2.0 /. 2.0 +. b *. p ** 3.0 /. 3.0; a *. p ** 2.0 /. 2.0 +. b *. p ** 3.0 /. 3.0;
XYShape.Analysis.integrateContinuousShape( XYShape.Analysis.integrateContinuousShape(
~indefiniteIntegralStepwise, ~indefiniteIntegralStepwise,
~indefiniteIntegralLinear, ~indefiniteIntegralLinear,
@ -302,24 +277,16 @@ module Continuous = {
}); });
/* Performs a discrete convolution between two continuous distributions A and B. /* This simply creates multiple copies of the continuous distribution, scaled and shifted according to
* It is an extremely good idea to downsample the distributions beforehand, each discrete data point, and then adds them all together. */
* because the number of samples in the convolution can be up to length(A) * length(B). let combineAlgebraicallyWithDiscrete = (~downsample=false, op: AlgebraicCombinations.algebraicOperation, t1: t, t2: DistTypes.discreteShape) => {
*
* Conventional convolution uses fn = (+.), but we also allow other operations to combine the xs.
*
* In practice, the convolution works by multiplying the ys for each possible combo of points of
* the two shapes. This creates a new shape for each point of A. These new shapes are then combined
* linearly. This may not always be the most efficient way, but it is probably the most robust for now.
*
* In the future, it may be possible to use a non-uniform fast Fourier transform instead (although only for addition).
*/
let convolveWithDiscrete = (~downsample=false, fn, t1: t, t2: DistTypes.discreteShape) => {
let t1s = t1 |> getShape; let t1s = t1 |> getShape;
let t2s = t2.xyShape; // would like to use Discrete.getShape here, but current file structure doesn't allow for that let t2s = t2.xyShape; // would like to use Discrete.getShape here, but current file structure doesn't allow for that
let t1n = t1s |> XYShape.T.length; let t1n = t1s |> XYShape.T.length;
let t2n = t2s |> XYShape.T.length; let t2n = t2s |> XYShape.T.length;
let fn = AlgebraicCombinations.operationToFn(op);
let outXYShapes: array(array((float, float))) = let outXYShapes: array(array((float, float))) =
Belt.Array.makeUninitializedUnsafe(t2n); Belt.Array.makeUninitializedUnsafe(t2n);
@ -351,125 +318,19 @@ module Continuous = {
|> updateKnownIntegralSum(combinedIntegralSum); |> updateKnownIntegralSum(combinedIntegralSum);
}; };
/* This function takes a continuous distribution and efficiently approximates it as let combineAlgebraically = (~downsample=false, op: AlgebraicCombinations.algebraicOperation, t1: t, t2: t) => {
point masses that have variances associated with them. let s1 = t1 |> getShape;
We estimate the means and variances from overlapping triangular distributions which we imagine are making up the let s2 = t2 |> getShape;
XYShape. let t1n = s1 |> XYShape.T.length;
We can then use the algebra of random variables to "convolve" the point masses and their variances, let t2n = s2 |> XYShape.T.length;
and finally reconstruct a new distribution from them, e.g. using a Fast Gauss Transform or Raykar et al. (2007). */ if (t1n == 0 || t2n == 0) {
type pointMassesWithMoments = { empty;
n: int,
masses: array(float),
means: array(float),
variances: array(float)
};
let toDiscretePointMassesFromTriangulars = (~inverse=False, t: t): pointMassesWithMoments => {
// TODO: what if there is only one point in the distribution?
let s = t |> getShape;
let n = s |> XYShape.T.length;
// first, double up the leftmost and rightmost points:
let {xs, ys}: XYShape.T.t = s;
let _ = Js.Array.unshift(xs[0], xs);
let _ = Js.Array.unshift(ys[0], ys);
let _ = Js.Array.push(xs[n - 1], xs);
let _ = Js.Array.push(ys[n - 1], ys);
let n = E.A.length(xs);
// squares and neighbourly products of the xs
let xsSq: array(float) = Belt.Array.makeUninitializedUnsafe(n);
let xsProdN1: array(float) = Belt.Array.makeUninitializedUnsafe(n - 1);
let xsProdN2: array(float) = Belt.Array.makeUninitializedUnsafe(n - 2);
for (i in 0 to n - 1) {
let _ = Belt.Array.set(xsSq, i, xs[i] *. xs[i]); ();
};
for (i in 0 to n - 2) {
let _ = Belt.Array.set(xsProdN1, i, xs[i] *. xs[i + 1]); ();
};
for (i in 0 to n - 3) {
let _ = Belt.Array.set(xsProdN2, i, xs[i] *. xs[i + 2]); ();
};
// means and variances
let masses: array(float) = Belt.Array.makeUninitializedUnsafe(n);
let means: array(float) = Belt.Array.makeUninitializedUnsafe(n);
let variances: array(float) = Belt.Array.makeUninitializedUnsafe(n);
if (inverse) {
for (i in 1 to n - 2) {
let _ = Belt.Array.set(masses, i - 1, (xs[i + 1] -. xs[i - 1]) *. ys[i] /. 2.);
// this only works when the whole triange is either on the left or on the right of zero
let a = xs[i - 1];
let c = xs[i];
let b = xs[i + 1];
// These are the moments of the reciprocal of a triangular distribution, as symbolically integrated by Mathematica.
// They're probably pretty close to invMean ~ 1/mean = 3/(a+b+c) and invVar. But I haven't worked out
// the worst case error, so for now let's use these monster equations
let inverseMean = 2. *. ((a *. log(a/.c) /. (a-.c)) +. ((b *. log(c/.b))/.(b-.c))) /. (a -. b);
let inverseVar = 2. *. ((log(c/.a) /. (a-.c)) +. ((b *. log(b/.c))/.(b-.c))) /. (a -. b) - inverseMean ** 2.;
let _ = Belt.Array.set(means, i - 1, inverseMean);
let _ = Belt.Array.set(variances, i - 1, inverseVar);
();
};
{n, masses, means, variances};
} else { } else {
for (i in 1 to n - 2) { let combinedShape = AlgebraicCombinations.combineShapesContinuousContinuous(op, s1, s2);
let _ = Belt.Array.set(masses, i - 1, (xs[i + 1] -. xs[i - 1]) *. ys[i] /. 2.); let combinedIntegralSum = Common.combineIntegralSums((a, b) => Some(a *. b), t1.knownIntegralSum, t2.knownIntegralSum);
let _ = Belt.Array.set(means, i - 1, (xs[i - 1] +. xs[i] +. xs[i + 1]) /. 3.); // return a new Continuous distribution
make(`Linear, combinedShape, combinedIntegralSum);
let _ = Belt.Array.set(variances, i - 1,
(xsSq[i-1] +. xsSq[i] +. xsSq[i+1] -. xsProdN1[i-1] -. xsProdN1[i] -. xsProdN2[i-1]) /. 18.);
();
}; };
{n, masses, means, variances};
};
};
let convolve = (~downsample=false, fn, t1: t, t2: t) => {
let downsampleIfTooLarge = (t: t) => {
let sqtl = sqrt(float_of_int(t |> getShape |> XYShape.T.length));
sqtl > 10. && downsample && false ? T.downsample(int_of_float(sqtl), t) : t;
};
let t1d = downsampleIfTooLarge(t1);
let t2d = downsampleIfTooLarge(t2);
// if we add the two distributions, we should probably use normal filters.
// if we multiply the two distributions, we should probably use lognormal filters.
let t1m = toDiscretePointMassesFromTriangulars(t1);
let t2m = toDiscretePointMassesFromTriangulars(t2);
let convolveMeansFn = (TreeNode.standardOp) => fun
| `Add => (m1, m2) => m1 +. m2
| `Subtract => (m1, m2) => m1 -. m2
| `Multiply => (m1, m2) => m1 *. m2
| `Divide => (m1, mInv2) => m1 *. mInv2; // note: here, mInv2 = mean(1 / t2)
// converts the variances and means of the two inputs into the variance of the output
let convolveVariancesFn = (TreeNode.standardOp) => fun
| `Add => (v1, v2, m1, m2) => v1 +. v2
| `Subtract => (v1, v2, m1, m2) => v1 +. v2
| `Multiply => (v1, v2, m1, m2) => (v1 *. v2) +. (v1 *. m1**2.) +. (v2 *. m1**2.)
| `Divide => (v1, vInv2, m1, mInv2) => (v1 *. vInv2) +. (v1 *. mInv2**2.) +. (vInv2 *. m1**2.);
let masses: array(float) = Belt.Array.makeUninitializedUnsafe(t1m.n * t2m.n);
let means: array(float) = Belt.Array.makeUninitializedUnsafe(t1m.n * t2m.n);
let variances: array(float) = Belt.Array.makeUninitializedUnsafe(t1m.n * t2m.n);
// then convolve the two sets of pointMassesWithMoments
for (i in 0 to t1m.n - 1) {
for (j in 0 to t2m.n - 1) {
let k = i * t2m.n + j;
let _ = Belt.Array.set(masses, k, t1m.masses[i] *. t2m.masses[j]);
let _ = Belt.Array.set(means, k, convolveMeansFn(t1m.means[i], t2m.means[j]));
let _ = Belt.Array.set(variances, k, convolveMeansFn(t1m.variances[i], t2m.variances[j], t1m.means[i], t2m.means[j]));
};
};
// now, run a Fast Gauss transform to estimate the new distribution:
}; };
}; };
@ -490,7 +351,7 @@ module Discrete = {
let lastY = (t: t) => t |> getShape |> XYShape.T.lastY; let lastY = (t: t) => t |> getShape |> XYShape.T.lastY;
let combine = let combinePointwise =
( (
~knownIntegralSumsFn, ~knownIntegralSumsFn,
fn, fn,
@ -506,7 +367,7 @@ module Discrete = {
); );
make( make(
XYShape.Combine.combine( XYShape.PointwiseCombination.combine(
~xsSelection=ALL_XS, ~xsSelection=ALL_XS,
~xToYSelection=XYShape.XtoY.stepwiseIfAtX, ~xToYSelection=XYShape.XtoY.stepwiseIfAtX,
~fn=((a, b) => fn(E.O.default(0.0, a), E.O.default(0.0, b))), // stepwiseIfAtX returns option(float), so this fn needs to handle None ~fn=((a, b) => fn(E.O.default(0.0, a), E.O.default(0.0, b))), // stepwiseIfAtX returns option(float), so this fn needs to handle None
@ -519,14 +380,16 @@ module Discrete = {
let reduce = (~knownIntegralSumsFn=(_, _) => None, fn, discreteShapes): DistTypes.discreteShape => let reduce = (~knownIntegralSumsFn=(_, _) => None, fn, discreteShapes): DistTypes.discreteShape =>
discreteShapes discreteShapes
|> E.A.fold_left(combine(~knownIntegralSumsFn, fn), empty); |> E.A.fold_left(combinePointwise(~knownIntegralSumsFn, fn), empty);
let updateKnownIntegralSum = (knownIntegralSum, t: t): t => { let updateKnownIntegralSum = (knownIntegralSum, t: t): t => {
...t, ...t,
knownIntegralSum, knownIntegralSum,
}; };
let convolve = (fn, t1: t, t2: t) => { /* This multiples all of the data points together and creates a new discrete distribution from the results.
Data points at the same xs get added together. It may be a good idea to downsample t1 and t2 before and/or the result after. */
let combineAlgebraically = (op: AlgebraicCombinations.algebraicOperation, t1: t, t2: t) => {
let t1s = t1 |> getShape; let t1s = t1 |> getShape;
let t2s = t2 |> getShape; let t2s = t2 |> getShape;
let t1n = t1s |> XYShape.T.length; let t1n = t1s |> XYShape.T.length;
@ -539,6 +402,7 @@ module Discrete = {
t2.knownIntegralSum, t2.knownIntegralSum,
); );
let fn = AlgebraicCombinations.operationToFn(op);
let xToYMap = E.FloatFloatMap.empty(); let xToYMap = E.FloatFloatMap.empty();
for (i in 0 to t1n - 1) { for (i in 0 to t1n - 1) {
@ -553,9 +417,9 @@ module Discrete = {
let rxys = xToYMap |> E.FloatFloatMap.toArray |> XYShape.Zipped.sortByX; let rxys = xToYMap |> E.FloatFloatMap.toArray |> XYShape.Zipped.sortByX;
let convolvedShape = XYShape.T.fromZippedArray(rxys); let combinedShape = XYShape.T.fromZippedArray(rxys);
make(convolvedShape, combinedIntegralSum); make(combinedShape, combinedIntegralSum);
}; };
let mapY = (~knownIntegralSumFn=previousKnownIntegralSum => None, fn, t: t) => { let mapY = (~knownIntegralSumFn=previousKnownIntegralSum => None, fn, t: t) => {
@ -577,7 +441,8 @@ module Discrete = {
Dist({ Dist({
type t = DistTypes.discreteShape; type t = DistTypes.discreteShape;
type integral = DistTypes.continuousShape; type integral = DistTypes.continuousShape;
let integral = (~cache, t) => let integral = (~cache, t) => {
if ((t |> getShape |> XYShape.T.length) > 0) {
switch (cache) { switch (cache) {
| Some(c) => c | Some(c) => c
| None => | None =>
@ -587,6 +452,10 @@ module Discrete = {
None, None,
) )
}; };
} else {
Continuous.make(`Stepwise, {xs: [|neg_infinity|], ys: [|0.0|]}, None);
}};
let integralEndY = (~cache, t: t) => let integralEndY = (~cache, t: t) =>
t.knownIntegralSum t.knownIntegralSum
|> E.O.default(t |> integral(~cache) |> Continuous.lastY); |> E.O.default(t |> integral(~cache) |> Continuous.lastY);
@ -612,7 +481,7 @@ module Discrete = {
// The best we can do is to clip off the smallest values. // The best we can do is to clip off the smallest values.
let currentLength = t |> getShape |> XYShape.T.length; let currentLength = t |> getShape |> XYShape.T.length;
if (i < currentLength) { if (i < currentLength && i >= 1 && currentLength > 1) {
let clippedShape = let clippedShape =
t t
|> getShape |> getShape
@ -696,7 +565,7 @@ module Mixed = {
let toContinuous = ({continuous}: t) => Some(continuous); let toContinuous = ({continuous}: t) => Some(continuous);
let toDiscrete = ({discrete}: t) => Some(discrete); let toDiscrete = ({discrete}: t) => Some(discrete);
let combine = (~knownIntegralSumsFn, fn, t1: t, t2: t) => { let combinePointwise = (~knownIntegralSumsFn, fn, t1: t, t2: t) => {
let reducedDiscrete = let reducedDiscrete =
[|t1, t2|] [|t1, t2|]
|> E.A.fmap(toDiscrete) |> E.A.fmap(toDiscrete)
@ -829,7 +698,7 @@ module Mixed = {
Continuous.make( Continuous.make(
`Linear, `Linear,
XYShape.Combine.combineLinear( XYShape.PointwiseCombination.combineLinear(
~fn=(+.), ~fn=(+.),
Continuous.getShape(continuousIntegral), Continuous.getShape(continuousIntegral),
Continuous.getShape(discreteIntegral), Continuous.getShape(discreteIntegral),
@ -940,7 +809,7 @@ module Mixed = {
}; };
}); });
let convolve = (~downsample=false, fn: (float, float) => float, t1: t, t2: t): t => { let combineAlgebraically = (~downsample=false, op: AlgebraicCombinations.algebraicOperation, t1: t, t2: t): t => {
// Discrete convolution can cause a huge increase in the number of samples, // Discrete convolution can cause a huge increase in the number of samples,
// so we'll first downsample. // so we'll first downsample.
@ -958,17 +827,17 @@ module Mixed = {
// continuous (*) continuous => continuous, but also // continuous (*) continuous => continuous, but also
// discrete (*) continuous => continuous (and vice versa). We have to take care of all combos and then combine them: // discrete (*) continuous => continuous (and vice versa). We have to take care of all combos and then combine them:
let ccConvResult = let ccConvResult =
Continuous.convolve(~downsample=false, fn, t1d.continuous, t2d.continuous); Continuous.combineAlgebraically(~downsample=false, op, t1d.continuous, t2d.continuous);
let dcConvResult = let dcConvResult =
Continuous.convolveWithDiscrete(~downsample=false, fn, t2d.continuous, t1d.discrete); Continuous.combineAlgebraicallyWithDiscrete(~downsample=false, op, t2d.continuous, t1d.discrete);
let cdConvResult = let cdConvResult =
Continuous.convolveWithDiscrete(~downsample=false, fn, t1d.continuous, t2d.discrete); Continuous.combineAlgebraicallyWithDiscrete(~downsample=false, op, t1d.continuous, t2d.discrete);
let continuousConvResult = let continuousConvResult =
Continuous.reduce((+.), [|ccConvResult, dcConvResult, cdConvResult|]); Continuous.reduce((+.), [|ccConvResult, dcConvResult, cdConvResult|]);
// ... finally, discrete (*) discrete => discrete, obviously: // ... finally, discrete (*) discrete => discrete, obviously:
let discreteConvResult = let discreteConvResult =
Discrete.convolve(fn, t1d.discrete, t2d.discrete); Discrete.combineAlgebraically(op, t1d.discrete, t2d.discrete);
{discrete: discreteConvResult, continuous: continuousConvResult}; {discrete: discreteConvResult, continuous: continuousConvResult};
}; };
@ -997,25 +866,38 @@ module Shape = {
c => Mixed.make(~discrete=Discrete.empty, ~continuous=c), c => Mixed.make(~discrete=Discrete.empty, ~continuous=c),
)); ));
let convolve = (fn, t1: t, t2: t): t => { let combineAlgebraically = (op: AlgebraicCombinations.algebraicOperation, t1: t, t2: t): t => {
switch ((t1, t2)) { switch ((t1, t2)) {
| (Continuous(m1), Continuous(m2)) => DistTypes.Continuous(Continuous.convolve(~downsample=true, fn, m1, m2)) | (Continuous(m1), Continuous(m2)) => DistTypes.Continuous(Continuous.combineAlgebraically(~downsample=true, op, m1, m2))
| (Discrete(m1), Discrete(m2)) => DistTypes.Discrete(Discrete.convolve(fn, m1, m2)) | (Discrete(m1), Discrete(m2)) => DistTypes.Discrete(Discrete.combineAlgebraically(op, m1, m2))
| (m1, m2) => { | (m1, m2) => {
DistTypes.Mixed(Mixed.convolve(~downsample=true, fn, toMixed(m1), toMixed(m2))) DistTypes.Mixed(Mixed.combineAlgebraically(~downsample=true, op, toMixed(m1), toMixed(m2)))
} }
}; };
}; };
let combine = (~knownIntegralSumsFn=(_, _) => None, fn, t1: t, t2: t) => let combinePointwise = (~knownIntegralSumsFn=(_, _) => None, fn, t1: t, t2: t) =>
switch ((t1, t2)) { switch ((t1, t2)) {
| (Continuous(m1), Continuous(m2)) => DistTypes.Continuous(Continuous.combine(~knownIntegralSumsFn, fn, m1, m2)) | (Continuous(m1), Continuous(m2)) => DistTypes.Continuous(Continuous.combinePointwise(~knownIntegralSumsFn, fn, m1, m2))
| (Discrete(m1), Discrete(m2)) => DistTypes.Discrete(Discrete.combine(~knownIntegralSumsFn, fn, m1, m2)) | (Discrete(m1), Discrete(m2)) => DistTypes.Discrete(Discrete.combinePointwise(~knownIntegralSumsFn, fn, m1, m2))
| (m1, m2) => { | (m1, m2) => {
DistTypes.Mixed(Mixed.combine(~knownIntegralSumsFn, fn, toMixed(m1), toMixed(m2))) DistTypes.Mixed(Mixed.combinePointwise(~knownIntegralSumsFn, fn, toMixed(m1), toMixed(m2)))
} }
}; };
// TODO: implement these functions
let pdf = (f: float, t: t): float => {
0.0;
};
let inv = (f: float, t: t): float => {
0.0;
};
let sample = (t: t): float => {
0.0;
};
module T = module T =
Dist({ Dist({
type t = DistTypes.shape; type t = DistTypes.shape;
@ -1271,7 +1153,9 @@ module DistPlus = {
let integralYtoX = (~cache as _, f, t: t) => { let integralYtoX = (~cache as _, f, t: t) => {
Shape.T.Integral.yToX(~cache=Some(t.integralCache), f, toShape(t)); Shape.T.Integral.yToX(~cache=Some(t.integralCache), f, toShape(t));
}; };
let mean = (t: t) => Shape.T.mean(t.shape); let mean = (t: t) => {
Shape.T.mean(t.shape);
};
let variance = (t: t) => Shape.T.variance(t.shape); let variance = (t: t) => Shape.T.variance(t.shape);
}); });
}; };

View File

@ -170,7 +170,7 @@ module Zipped = {
let filterByX = (testFn: (float => bool), t: zipped) => t |> E.A.filter(((x, _)) => testFn(x)); let filterByX = (testFn: (float => bool), t: zipped) => t |> E.A.filter(((x, _)) => testFn(x));
}; };
module Combine = { module PointwiseCombination = {
type xsSelection = type xsSelection =
| ALL_XS | ALL_XS
| XS_EVENLY_DIVIDED(int); | XS_EVENLY_DIVIDED(int);
@ -278,7 +278,7 @@ module Range = {
items items
|> Belt.Array.map(_, rangePointAssumingSteps) |> Belt.Array.map(_, rangePointAssumingSteps)
|> T.fromZippedArray |> T.fromZippedArray
|> Combine.intersperse(t |> T.mapX(e => e +. diff)), |> PointwiseCombination.intersperse(t |> T.mapX(e => e +. diff)),
) )
| _ => Some(t) | _ => Some(t)
}; };
@ -300,7 +300,7 @@ let pointLogScore = (prediction, answer) =>
}; };
let logScorePoint = (sampleCount, t1, t2) => let logScorePoint = (sampleCount, t1, t2) =>
Combine.combine( PointwiseCombination.combine(
~xsSelection=XS_EVENLY_DIVIDED(sampleCount), ~xsSelection=XS_EVENLY_DIVIDED(sampleCount),
~xToYSelection=XtoY.linear, ~xToYSelection=XtoY.linear,
~fn=pointLogScore, ~fn=pointLogScore,
@ -328,6 +328,7 @@ module Analysis = {
0.0, 0.0,
(acc, _x, i) => { (acc, _x, i) => {
let areaUnderIntegral = let areaUnderIntegral =
// TODO Take this switch statement out of the loop body
switch (t.interpolation, i) { switch (t.interpolation, i) {
| (_, 0) => 0.0 | (_, 0) => 0.0
| (`Stepwise, _) => | (`Stepwise, _) =>
@ -336,6 +337,9 @@ module Analysis = {
| (`Linear, _) => | (`Linear, _) =>
let x1 = xs[i - 1]; let x1 = xs[i - 1];
let x2 = xs[i]; let x2 = xs[i];
if (x1 == x2) {
0.0
} else {
let h1 = ys[i - 1]; let h1 = ys[i - 1];
let h2 = ys[i]; let h2 = ys[i];
let b = (h1 -. h2) /. (x1 -. x2); let b = (h1 -. h2) /. (x1 -. x2);
@ -343,6 +347,7 @@ module Analysis = {
indefiniteIntegralLinear(x2, a, b) indefiniteIntegralLinear(x2, a, b)
-. indefiniteIntegralLinear(x1, a, b); -. indefiniteIntegralLinear(x1, a, b);
}; };
};
acc +. areaUnderIntegral; acc +. areaUnderIntegral;
}, },
); );

View File

@ -175,13 +175,13 @@ module MathAdtToDistDst = {
|> E.A.fmapi((index, t) => { |> E.A.fmapi((index, t) => {
let w = weights |> E.A.get(_, index) |> E.O.default(1.0); let w = weights |> E.A.get(_, index) |> E.O.default(1.0);
`Operation(`ScaleOperation(`Multiply, t, `DistData(`Symbolic(`Float(w))))) `Operation(`VerticalScaling(`Multiply, t, `DistData(`Symbolic(`Float(w)))))
}); });
let pointwiseSum = components let pointwiseSum = components
|> Js.Array.sliceFrom(1) |> Js.Array.sliceFrom(1)
|> E.A.fold_left((acc, x) => { |> E.A.fold_left((acc, x) => {
`Operation(`PointwiseOperation(`Add, acc, x)) `Operation(`PointwiseCombination(`Add, acc, x))
}, E.A.unsafe_get(components, 0)) }, E.A.unsafe_get(components, 0))
Ok(`Operation(`Normalize(pointwiseSum))) Ok(`Operation(`Normalize(pointwiseSum)))
@ -251,25 +251,31 @@ module MathAdtToDistDst = {
multiModal(dists, weights); multiModal(dists, weights);
} }
// TODO: wire up these FloatFromDist operations
| Fn({name: "mean", args}) => Error("mean(...) not yet implemented.")
| Fn({name: "inv", args}) => Error("inv(...) not yet implemented.")
| Fn({name: "sample", args}) => Error("sample(...) not yet implemented.")
| Fn({name: "pdf", args}) => Error("pdf(...) not yet implemented.")
| Fn({name: "add", args}) => { | Fn({name: "add", args}) => {
args args
|> E.A.fmap(functionParser) |> E.A.fmap(functionParser)
|> (fun |> (fun
| [|Ok(l), Ok(r)|] => Ok(`Operation(`StandardOperation(`Add, l, r))) | [|Ok(l), Ok(r)|] => Ok(`Operation(`AlgebraicCombination(`Add, l, r)))
| _ => Error("Addition needs two operands")) | _ => Error("Addition needs two operands"))
} }
| Fn({name: "subtract", args}) => { | Fn({name: "subtract", args}) => {
args args
|> E.A.fmap(functionParser) |> E.A.fmap(functionParser)
|> (fun |> (fun
| [|Ok(l), Ok(r)|] => Ok(`Operation(`StandardOperation(`Subtract, l, r))) | [|Ok(l), Ok(r)|] => Ok(`Operation(`AlgebraicCombination(`Subtract, l, r)))
| _ => Error("Subtraction needs two operands")) | _ => Error("Subtraction needs two operands"))
} }
| Fn({name: "multiply", args}) => { | Fn({name: "multiply", args}) => {
args args
|> E.A.fmap(functionParser) |> E.A.fmap(functionParser)
|> (fun |> (fun
| [|Ok(l), Ok(r)|] => Ok(`Operation(`StandardOperation(`Multiply, l, r))) | [|Ok(l), Ok(r)|] => Ok(`Operation(`AlgebraicCombination(`Multiply, l, r)))
| _ => Error("Multiplication needs two operands")) | _ => Error("Multiplication needs two operands"))
} }
| Fn({name: "divide", args}) => { | Fn({name: "divide", args}) => {
@ -277,16 +283,18 @@ module MathAdtToDistDst = {
|> E.A.fmap(functionParser) |> E.A.fmap(functionParser)
|> (fun |> (fun
| [|Ok(l), Ok(`DistData(`Symbolic(`Float(0.0))))|] => Error("Division by zero") | [|Ok(l), Ok(`DistData(`Symbolic(`Float(0.0))))|] => Error("Division by zero")
| [|Ok(l), Ok(r)|] => Ok(`Operation(`StandardOperation(`Divide, l, r))) | [|Ok(l), Ok(r)|] => Ok(`Operation(`AlgebraicCombination(`Divide, l, r)))
| _ => Error("Division needs two operands")) | _ => Error("Division needs two operands"))
} }
// TODO: Figure out how to implement meaningful exponentiation
| Fn({name: "pow", args}) => { | Fn({name: "pow", args}) => {
args args
|> E.A.fmap(functionParser) |> E.A.fmap(functionParser)
|> (fun |> (fun
| [|Ok(l), Ok(r)|] => Ok(`Operation(`StandardOperation(`Exponentiate, l, r))) //| [|Ok(l), Ok(r)|] => Ok(`Operation(`AlgebraicCombination(`Exponentiate, l, r)))
| _ => Error("Division needs two operands") //| _ => Error("Exponentiations needs two operands"))
| _ => Error("Exponentiations needs two operands")) | _ => Error("Exponentiation is not yet supported.")
)
} }
| Fn({name: "leftTruncate", args}) => { | Fn({name: "leftTruncate", args}) => {
args args

View File

@ -5,13 +5,6 @@ type distData = [
| `RenderedShape(DistTypes.shape) | `RenderedShape(DistTypes.shape)
]; ];
// Closed set of binary operator tags: the four arithmetic operators plus `Exponentiate.
// NOTE(review): in this side-by-side view the definition appears in the left (old) column
// only; the right column's `operation` type refers to AlgebraicCombinations.algebraicOperation
// instead — presumably this type is deleted by the change; confirm against the raw patch.
type standardOperation = [
| `Add
| `Multiply
| `Subtract
| `Divide
| `Exponentiate
];
// Operator tags for the pointwise node of `operation` (`PointwiseOperation on the left,
// `PointwiseCombination on the right): only `Add and `Multiply are available.
type pointwiseOperation = [ | `Add | `Multiply]; type pointwiseOperation = [ | `Add | `Multiply];
// Operator tags for the scaling node of `operation` (`ScaleOperation on the left,
// `VerticalScaling on the right).
type scaleOperation = [ | `Multiply | `Exponentiate | `Log]; type scaleOperation = [ | `Multiply | `Exponentiate | `Log];
// Requests that reduce a distribution to a single float. `Pdf and `Inv carry the float
// argument that is forwarded to the pdf / inv lookup; `Mean and `Sample carry no payload.
type distToFloatOperation = [ | `Pdf(float) | `Inv(float) | `Mean | `Sample]; type distToFloatOperation = [ | `Pdf(float) | `Inv(float) | `Mean | `Sample];
@ -23,14 +16,14 @@ type treeNode = [
] ]
and operation = [ and operation = [
| // binary operations | // binary operations
`StandardOperation( `AlgebraicCombination(
standardOperation, AlgebraicCombinations.algebraicOperation,
treeNode, treeNode,
treeNode, treeNode,
) )
// unary operations // unary operations
| `PointwiseOperation(pointwiseOperation, treeNode, treeNode) // always evaluates to `DistData(`RenderedShape(...)) | `PointwiseCombination(pointwiseOperation, treeNode, treeNode) // always evaluates to `DistData(`RenderedShape(...))
| `ScaleOperation(scaleOperation, treeNode, treeNode) // always evaluates to `DistData(`RenderedShape(...)) | `VerticalScaling(scaleOperation, treeNode, treeNode) // always evaluates to `DistData(`RenderedShape(...))
| `Render(treeNode) // always evaluates to `DistData(`RenderedShape(...)) | `Render(treeNode) // always evaluates to `DistData(`RenderedShape(...))
| `Truncate // always evaluates to `DistData(`RenderedShape(...)) | `Truncate // always evaluates to `DistData(`RenderedShape(...))
( (
@ -56,15 +49,14 @@ module TreeNode = {
type simplifier = treeNode => result(treeNode, string); type simplifier = treeNode => result(treeNode, string);
let rec toString = (t: t): string => { let rec toString = (t: t): string => {
let stringFromStandardOperation = let stringFromAlgebraicCombination =
fun fun
| `Add => " + " | `Add => " + "
| `Subtract => " - " | `Subtract => " - "
| `Multiply => " * " | `Multiply => " * "
| `Divide => " / " | `Divide => " / "
| `Exponentiate => "^";
let stringFromPointwiseOperation = let stringFromPointwiseCombination =
fun fun
| `Add => " .+ " | `Add => " .+ "
| `Multiply => " .* "; | `Multiply => " .* ";
@ -81,11 +73,11 @@ module TreeNode = {
| `DistData(`Symbolic(d)) => | `DistData(`Symbolic(d)) =>
SymbolicDist.GenericDistFunctions.toString(d) SymbolicDist.GenericDistFunctions.toString(d)
| `DistData(`RenderedShape(s)) => "[shape]" | `DistData(`RenderedShape(s)) => "[shape]"
| `Operation(`StandardOperation(op, t1, t2)) => | `Operation(`AlgebraicCombination(op, t1, t2)) =>
toString(t1) ++ stringFromStandardOperation(op) ++ toString(t2) toString(t1) ++ stringFromAlgebraicCombination(op) ++ toString(t2)
| `Operation(`PointwiseOperation(op, t1, t2)) => | `Operation(`PointwiseCombination(op, t1, t2)) =>
toString(t1) ++ stringFromPointwiseOperation(op) ++ toString(t2) toString(t1) ++ stringFromPointwiseCombination(op) ++ toString(t2)
| `Operation(`ScaleOperation(_scaleOp, t, scaleBy)) => | `Operation(`VerticalScaling(_scaleOp, t, scaleBy)) =>
toString(t) ++ " @ " ++ toString(scaleBy) toString(t) ++ " @ " ++ toString(scaleBy)
| `Operation(`Normalize(t)) => "normalize(" ++ toString(t) ++ ")" | `Operation(`Normalize(t)) => "normalize(" ++ toString(t) ++ ")"
| `Operation(`FloatFromDist(floatFromDistOp, t)) => stringFromFloatFromDistOperation(floatFromDistOp) ++ toString(t) ++ ")" | `Operation(`FloatFromDist(floatFromDistOp, t)) => stringFromFloatFromDistOperation(floatFromDistOp) ++ toString(t) ++ ")"
@ -108,20 +100,12 @@ module TreeNode = {
of a new variable that is the result of the operation on A and B. of a new variable that is the result of the operation on A and B.
For instance, normal(0, 1) + normal(1, 1) -> normal(1, 2). For instance, normal(0, 1) + normal(1, 1) -> normal(1, 2).
In general, this is implemented via convolution. */ In general, this is implemented via convolution. */
module StandardOperation = { module AlgebraicCombination = {
let funcFromOp: (standardOperation, float, float) => float = let simplify = (algebraicOp, t1: t, t2: t): result(treeNode, string) => {
fun
| `Add => (+.)
| `Subtract => (-.)
| `Multiply => ( *. )
| `Divide => (/.)
| `Exponentiate => ( ** );
module Simplify = {
let tryCombiningFloats: simplifier = let tryCombiningFloats: simplifier =
fun fun
| `Operation( | `Operation(
`StandardOperation( `AlgebraicCombination(
`Divide, `Divide,
`DistData(`Symbolic(`Float(v1))), `DistData(`Symbolic(`Float(v1))),
`DistData(`Symbolic(`Float(0.))), `DistData(`Symbolic(`Float(0.))),
@ -129,13 +113,13 @@ module TreeNode = {
) => ) =>
Error("Cannot divide $v1 by zero.") Error("Cannot divide $v1 by zero.")
| `Operation( | `Operation(
`StandardOperation( `AlgebraicCombination(
standardOp, algebraicOp,
`DistData(`Symbolic(`Float(v1))), `DistData(`Symbolic(`Float(v1))),
`DistData(`Symbolic(`Float(v2))), `DistData(`Symbolic(`Float(v2))),
), ),
) => { ) => {
let func = funcFromOp(standardOp); let func = AlgebraicCombinations.operationToFn(algebraicOp);
Ok(`DistData(`Symbolic(`Float(func(v1, v2))))); Ok(`DistData(`Symbolic(`Float(func(v1, v2)))));
} }
| t => Ok(t); | t => Ok(t);
@ -143,7 +127,7 @@ module TreeNode = {
let tryCombiningNormals: simplifier = let tryCombiningNormals: simplifier =
fun fun
| `Operation( | `Operation(
`StandardOperation( `AlgebraicCombination(
`Add, `Add,
`DistData(`Symbolic(`Normal(n1))), `DistData(`Symbolic(`Normal(n1))),
`DistData(`Symbolic(`Normal(n2))), `DistData(`Symbolic(`Normal(n2))),
@ -151,7 +135,7 @@ module TreeNode = {
) => ) =>
Ok(`DistData(`Symbolic(SymbolicDist.Normal.add(n1, n2)))) Ok(`DistData(`Symbolic(SymbolicDist.Normal.add(n1, n2))))
| `Operation( | `Operation(
`StandardOperation( `AlgebraicCombination(
`Subtract, `Subtract,
`DistData(`Symbolic(`Normal(n1))), `DistData(`Symbolic(`Normal(n1))),
`DistData(`Symbolic(`Normal(n2))), `DistData(`Symbolic(`Normal(n2))),
@ -163,7 +147,7 @@ module TreeNode = {
let tryCombiningLognormals: simplifier = let tryCombiningLognormals: simplifier =
fun fun
| `Operation( | `Operation(
`StandardOperation( `AlgebraicCombination(
`Multiply, `Multiply,
`DistData(`Symbolic(`Lognormal(l1))), `DistData(`Symbolic(`Lognormal(l1))),
`DistData(`Symbolic(`Lognormal(l2))), `DistData(`Symbolic(`Lognormal(l2))),
@ -171,7 +155,7 @@ module TreeNode = {
) => ) =>
Ok(`DistData(`Symbolic(SymbolicDist.Lognormal.multiply(l1, l2)))) Ok(`DistData(`Symbolic(SymbolicDist.Lognormal.multiply(l1, l2))))
| `Operation( | `Operation(
`StandardOperation( `AlgebraicCombination(
`Divide, `Divide,
`DistData(`Symbolic(`Lognormal(l1))), `DistData(`Symbolic(`Lognormal(l1))),
`DistData(`Symbolic(`Lognormal(l2))), `DistData(`Symbolic(`Lognormal(l2))),
@ -180,20 +164,16 @@ module TreeNode = {
Ok(`DistData(`Symbolic(SymbolicDist.Lognormal.divide(l1, l2)))) Ok(`DistData(`Symbolic(SymbolicDist.Lognormal.divide(l1, l2))))
| t => Ok(t); | t => Ok(t);
let attempt = (standardOp, t1: t, t2: t): result(treeNode, string) => {
let originalTreeNode = let originalTreeNode =
`Operation(`StandardOperation((standardOp, t1, t2))); `Operation(`AlgebraicCombination((algebraicOp, t1, t2)));
originalTreeNode originalTreeNode
|> tryCombiningFloats |> tryCombiningFloats
|> E.R.bind(_, tryCombiningNormals) |> E.R.bind(_, tryCombiningNormals)
|> E.R.bind(_, tryCombiningLognormals); |> E.R.bind(_, tryCombiningLognormals);
}; };
};
let evaluateNumerically = (standardOp, operationToDistData, t1, t2) => {
let func = funcFromOp(standardOp);
let evaluateNumerically = (algebraicOp, operationToDistData, t1, t2) => {
// force rendering into shapes // force rendering into shapes
let renderedShape1 = operationToDistData(`Render(t1)); let renderedShape1 = operationToDistData(`Render(t1));
let renderedShape2 = operationToDistData(`Render(t2)); let renderedShape2 = operationToDistData(`Render(t2));
@ -205,7 +185,7 @@ module TreeNode = {
) => ) =>
Ok( Ok(
`DistData( `DistData(
`RenderedShape(Distributions.Shape.convolve(func, s1, s2)), `RenderedShape(Distributions.Shape.combineAlgebraically(algebraicOp, s1, s2)),
), ),
) )
| (Error(e1), _) => Error(e1) | (Error(e1), _) => Error(e1)
@ -215,21 +195,21 @@ module TreeNode = {
}; };
// Evaluates a binary combination of t1 and t2 down to a `DistData node.
//
// Step 1: try the analytical shortcuts (Simplify.attempt in the left column, simplify in
//         the right column). If that yields a `DistData node, it is returned unchanged.
// Step 2: if the node is still an `Operation, fall back to evaluateNumerically with the
//         same operator and the original operands t1 and t2.
//
// operationToDistData is only passed through to evaluateNumerically here.
// NOTE(review): presumably E.R.bind short-circuits, so an Error from the simplification
// step is returned as-is and the numeric fallback never runs — confirm against E.R.
let evaluateToDistData = let evaluateToDistData =
(standardOp: standardOperation, operationToDistData, t1: t, t2: t) (algebraicOp: AlgebraicCombinations.algebraicOperation, operationToDistData, t1: t, t2: t)
: result(treeNode, string) => : result(treeNode, string) =>
standardOp algebraicOp
|> Simplify.attempt(_, t1, t2) |> simplify(_, t1, t2)
|> E.R.bind( |> E.R.bind(
_, _,
fun fun
| `DistData(d) => Ok(`DistData(d)) // the analytical simplification worked, nice! | `DistData(d) => Ok(`DistData(d)) // the analytical simplification worked, nice!
| `Operation(_) => | `Operation(_) =>
// if not, run the convolution // if not, run the convolution
evaluateNumerically(standardOp, operationToDistData, t1, t2), evaluateNumerically(algebraicOp, operationToDistData, t1, t2),
); );
}; };
module ScaleOperation = { module VerticalScaling = {
let fnFromOp = let fnFromOp =
fun fun
| `Multiply => ( *. ) | `Multiply => ( *. )
@ -271,7 +251,7 @@ module TreeNode = {
}; };
}; };
module PointwiseOperation = { module PointwiseCombination = {
let pointwiseAdd = (operationToDistData, t1, t2) => { let pointwiseAdd = (operationToDistData, t1, t2) => {
let renderedShape1 = operationToDistData(`Render(t1)); let renderedShape1 = operationToDistData(`Render(t1));
let renderedShape2 = operationToDistData(`Render(t2)); let renderedShape2 = operationToDistData(`Render(t2));
@ -279,7 +259,8 @@ module TreeNode = {
switch ((renderedShape1, renderedShape2)) { switch ((renderedShape1, renderedShape2)) {
| (Error(e1), _) => Error(e1) | (Error(e1), _) => Error(e1)
| (_, Error(e2)) => Error(e2) | (_, Error(e2)) => Error(e2)
| (Ok(`DistData(`RenderedShape(rs1))), Ok(`DistData(`RenderedShape(rs2)))) => Ok(`DistData(`RenderedShape(Distributions.Shape.combine(~knownIntegralSumsFn=(a, b) => Some(a +. b), (+.), rs1, rs2)))) | (Ok(`DistData(`RenderedShape(rs1))), Ok(`DistData(`RenderedShape(rs2)))) =>
Ok(`DistData(`RenderedShape(Distributions.Shape.combinePointwise(~knownIntegralSumsFn=(a, b) => Some(a +. b), (+.), rs1, rs2))))
| _ => Error("Could not perform pointwise addition.") | _ => Error("Could not perform pointwise addition.")
}; };
}; };
@ -397,10 +378,15 @@ module TreeNode = {
}; };
E.R.bind(value, v => Ok(`DistData(`Symbolic(`Float(v))))); E.R.bind(value, v => Ok(`DistData(`Symbolic(`Float(v)))));
}; };
// Reduces an already-rendered shape `rs` to a single float and wraps it as
// `DistData(`Symbolic(`Float(v))).
//
// Left column (old): distToFloatOp is ignored; the result is always
//   Distributions.Shape.T.mean(rs), whichever operation was requested.
// Right column (new): dispatches on distToFloatOp —
//   `Pdf(f) -> Distributions.Shape.pdf(f, rs)
//   `Inv(f) -> Distributions.Shape.inv(f, rs)
//   `Sample -> Distributions.Shape.sample(rs)
//   `Mean   -> Distributions.Shape.T.mean(rs)
// Every branch of the switch is wrapped in Ok, so the switch itself never yields an Error.
let evaluateFromRenderedShape = let evaluateFromRenderedShape = (distToFloatOp: distToFloatOperation, rs: DistTypes.shape) : result(treeNode, string) => {
(distToFloatOp: distToFloatOperation, rs: DistTypes.shape) let value =
: result(treeNode, string) => { switch (distToFloatOp) {
Ok(`DistData(`Symbolic(`Float(Distributions.Shape.T.mean(rs))))); | `Pdf(f) => Ok(Distributions.Shape.pdf(f, rs))
| `Inv(f) => Ok(Distributions.Shape.inv(f, rs)) // TODO: this is tricky for discrete distributions, because they have a stepwise CDF
| `Sample => Ok(Distributions.Shape.sample(rs))
| `Mean => Ok(Distributions.Shape.T.mean(rs))
};
E.R.bind(value, v => Ok(`DistData(`Symbolic(`Float(v)))));
}; };
let rec evaluateToDistData = let rec evaluateToDistData =
( (
@ -480,22 +466,22 @@ module TreeNode = {
// the functions that convert the Operation nodes to DistData nodes need to // the functions that convert the Operation nodes to DistData nodes need to
// have a way to call this function on their children, if their children are themselves Operation nodes. // have a way to call this function on their children, if their children are themselves Operation nodes.
switch (op) { switch (op) {
| `StandardOperation(standardOp, t1, t2) => | `AlgebraicCombination(algebraicOp, t1, t2) =>
StandardOperation.evaluateToDistData( AlgebraicCombination.evaluateToDistData(
standardOp, algebraicOp,
operationToDistData(sampleCount), operationToDistData(sampleCount),
t1, t1,
t2 // we want to give it the option to render or simply leave it as is t2 // we want to give it the option to render or simply leave it as is
) )
| `PointwiseOperation(pointwiseOp, t1, t2) => | `PointwiseCombination(pointwiseOp, t1, t2) =>
PointwiseOperation.evaluateToDistData( PointwiseCombination.evaluateToDistData(
pointwiseOp, pointwiseOp,
operationToDistData(sampleCount), operationToDistData(sampleCount),
t1, t1,
t2, t2,
) )
| `ScaleOperation(scaleOp, t, scaleBy) => | `VerticalScaling(scaleOp, t, scaleBy) =>
ScaleOperation.evaluateToDistData( VerticalScaling.evaluateToDistData(
scaleOp, scaleOp,
operationToDistData(sampleCount), operationToDistData(sampleCount),
t, t,