WIP: continuous/continuous convolution
This commit is contained in:
parent
f5ce4354ab
commit
f1e2458bca
|
@ -351,6 +351,83 @@ module Continuous = {
|
||||||
|> updateKnownIntegralSum(combinedIntegralSum);
|
|> updateKnownIntegralSum(combinedIntegralSum);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/* This function takes a continuous distribution and efficiently approximates it as
|
||||||
|
point masses that have variances associated with them.
|
||||||
|
We estimate the means and variances from overlapping triangular distributions which we imagine are making up the
|
||||||
|
XYShape.
|
||||||
|
We can then use the algebra of random variables to "convolve" the point masses and their variances,
|
||||||
|
and finally reconstruct a new distribution from them, e.g. using a Fast Gauss Transform or Raykar et al. (2007). */
|
||||||
|
type pointMassesWithMoments = {
|
||||||
|
n: int,
|
||||||
|
masses: array(float),
|
||||||
|
means: array(float),
|
||||||
|
variances: array(float)
|
||||||
|
};
|
||||||
|
let toDiscretePointMassesFromTriangulars = (~inverse=False, t: t): pointMassesWithMoments => {
|
||||||
|
// TODO: what if there is only one point in the distribution?
|
||||||
|
let s = t |> getShape;
|
||||||
|
let n = s |> XYShape.T.length;
|
||||||
|
// first, double up the leftmost and rightmost points:
|
||||||
|
let {xs, ys}: XYShape.T.t = s;
|
||||||
|
let _ = Js.Array.unshift(xs[0], xs);
|
||||||
|
let _ = Js.Array.unshift(ys[0], ys);
|
||||||
|
let _ = Js.Array.push(xs[n - 1], xs);
|
||||||
|
let _ = Js.Array.push(ys[n - 1], ys);
|
||||||
|
let n = E.A.length(xs);
|
||||||
|
// squares and neighbourly products of the xs
|
||||||
|
let xsSq: array(float) = Belt.Array.makeUninitializedUnsafe(n);
|
||||||
|
let xsProdN1: array(float) = Belt.Array.makeUninitializedUnsafe(n - 1);
|
||||||
|
let xsProdN2: array(float) = Belt.Array.makeUninitializedUnsafe(n - 2);
|
||||||
|
for (i in 0 to n - 1) {
|
||||||
|
let _ = Belt.Array.set(xsSq, i, xs[i] *. xs[i]); ();
|
||||||
|
};
|
||||||
|
for (i in 0 to n - 2) {
|
||||||
|
let _ = Belt.Array.set(xsProdN1, i, xs[i] *. xs[i + 1]); ();
|
||||||
|
};
|
||||||
|
for (i in 0 to n - 3) {
|
||||||
|
let _ = Belt.Array.set(xsProdN2, i, xs[i] *. xs[i + 2]); ();
|
||||||
|
};
|
||||||
|
// means and variances
|
||||||
|
let masses: array(float) = Belt.Array.makeUninitializedUnsafe(n);
|
||||||
|
let means: array(float) = Belt.Array.makeUninitializedUnsafe(n);
|
||||||
|
let variances: array(float) = Belt.Array.makeUninitializedUnsafe(n);
|
||||||
|
|
||||||
|
if (inverse) {
|
||||||
|
for (i in 1 to n - 2) {
|
||||||
|
let _ = Belt.Array.set(masses, i - 1, (xs[i + 1] -. xs[i - 1]) *. ys[i] /. 2.);
|
||||||
|
|
||||||
|
// this only works when the whole triange is either on the left or on the right of zero
|
||||||
|
let a = xs[i - 1];
|
||||||
|
let c = xs[i];
|
||||||
|
let b = xs[i + 1];
|
||||||
|
|
||||||
|
// These are the moments of the reciprocal of a triangular distribution, as symbolically integrated by Mathematica.
|
||||||
|
// They're probably pretty close to invMean ~ 1/mean = 3/(a+b+c) and invVar. But I haven't worked out
|
||||||
|
// the worst case error, so for now let's use these monster equations
|
||||||
|
let inverseMean = 2. *. ((a *. log(a/.c) /. (a-.c)) +. ((b *. log(c/.b))/.(b-.c))) /. (a -. b);
|
||||||
|
let inverseVar = 2. *. ((log(c/.a) /. (a-.c)) +. ((b *. log(b/.c))/.(b-.c))) /. (a -. b) - inverseMean ** 2.;
|
||||||
|
|
||||||
|
let _ = Belt.Array.set(means, i - 1, inverseMean);
|
||||||
|
|
||||||
|
let _ = Belt.Array.set(variances, i - 1, inverseVar);
|
||||||
|
();
|
||||||
|
};
|
||||||
|
|
||||||
|
{n, masses, means, variances};
|
||||||
|
} else {
|
||||||
|
for (i in 1 to n - 2) {
|
||||||
|
let _ = Belt.Array.set(masses, i - 1, (xs[i + 1] -. xs[i - 1]) *. ys[i] /. 2.);
|
||||||
|
let _ = Belt.Array.set(means, i - 1, (xs[i - 1] +. xs[i] +. xs[i + 1]) /. 3.);
|
||||||
|
|
||||||
|
let _ = Belt.Array.set(variances, i - 1,
|
||||||
|
(xsSq[i-1] +. xsSq[i] +. xsSq[i+1] -. xsProdN1[i-1] -. xsProdN1[i] -. xsProdN2[i-1]) /. 18.);
|
||||||
|
();
|
||||||
|
};
|
||||||
|
{n, masses, means, variances};
|
||||||
|
};
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
let convolve = (~downsample=false, fn, t1: t, t2: t) => {
|
let convolve = (~downsample=false, fn, t1: t, t2: t) => {
|
||||||
let downsampleIfTooLarge = (t: t) => {
|
let downsampleIfTooLarge = (t: t) => {
|
||||||
let sqtl = sqrt(float_of_int(t |> getShape |> XYShape.T.length));
|
let sqtl = sqrt(float_of_int(t |> getShape |> XYShape.T.length));
|
||||||
|
@ -360,14 +437,39 @@ module Continuous = {
|
||||||
let t1d = downsampleIfTooLarge(t1);
|
let t1d = downsampleIfTooLarge(t1);
|
||||||
let t2d = downsampleIfTooLarge(t2);
|
let t2d = downsampleIfTooLarge(t2);
|
||||||
|
|
||||||
let t1m = toDiscretePointMasses(t1);
|
// if we add the two distributions, we should probably use normal filters.
|
||||||
let t2m = toDiscretePointMasses(t2);
|
// if we multiply the two distributions, we should probably use lognormal filters.
|
||||||
|
let t1m = toDiscretePointMassesFromTriangulars(t1);
|
||||||
|
let t2m = toDiscretePointMassesFromTriangulars(t2);
|
||||||
|
|
||||||
// then convolve the two as discrete distributions
|
let convolveMeansFn = (TreeNode.standardOp) => fun
|
||||||
let c = Discrete.convolve(fn, t1m, t2m);
|
| `Add => (m1, m2) => m1 +. m2
|
||||||
|
| `Subtract => (m1, m2) => m1 -. m2
|
||||||
|
| `Multiply => (m1, m2) => m1 *. m2
|
||||||
|
| `Divide => (m1, mInv2) => m1 *. mInv2; // note: here, mInv2 = mean(1 / t2)
|
||||||
|
|
||||||
|
// converts the variances and means of the two inputs into the variance of the output
|
||||||
|
let convolveVariancesFn = (TreeNode.standardOp) => fun
|
||||||
|
| `Add => (v1, v2, m1, m2) => v1 +. v2
|
||||||
|
| `Subtract => (v1, v2, m1, m2) => v1 +. v2
|
||||||
|
| `Multiply => (v1, v2, m1, m2) => (v1 *. v2) +. (v1 *. m1**2.) +. (v2 *. m1**2.)
|
||||||
|
| `Divide => (v1, vInv2, m1, mInv2) => (v1 *. vInv2) +. (v1 *. mInv2**2.) +. (vInv2 *. m1**2.);
|
||||||
|
|
||||||
|
let masses: array(float) = Belt.Array.makeUninitializedUnsafe(t1m.n * t2m.n);
|
||||||
|
let means: array(float) = Belt.Array.makeUninitializedUnsafe(t1m.n * t2m.n);
|
||||||
|
let variances: array(float) = Belt.Array.makeUninitializedUnsafe(t1m.n * t2m.n);
|
||||||
|
// then convolve the two sets of pointMassesWithMoments
|
||||||
|
for (i in 0 to t1m.n - 1) {
|
||||||
|
for (j in 0 to t2m.n - 1) {
|
||||||
|
let k = i * t2m.n + j;
|
||||||
|
let _ = Belt.Array.set(masses, k, t1m.masses[i] *. t2m.masses[j]);
|
||||||
|
let _ = Belt.Array.set(means, k, convolveMeansFn(t1m.means[i], t2m.means[j]));
|
||||||
|
let _ = Belt.Array.set(variances, k, convolveMeansFn(t1m.variances[i], t2m.variances[j], t1m.means[i], t2m.means[j]));
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
// now, run a Fast Gauss transform to estimate the new distribution:
|
||||||
|
|
||||||
// then convert back to an approximate pdf
|
|
||||||
// TODO: find an efficient way to do this (kernel densities? trapezoids?)
|
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user