From d119db88a17eefd7d449b2d0240f53ea8bcb423c Mon Sep 17 00:00:00 2001 From: Sebastian Kosch Date: Sun, 12 Jul 2020 15:53:53 -0700 Subject: [PATCH] Allow algebraic multiplication --- .../distribution/AlgebraicShapeCombination.re | 162 +++++------------- src/distPlus/distribution/Continuous.re | 50 ++---- src/distPlus/distribution/Shape.re | 5 + 3 files changed, 62 insertions(+), 155 deletions(-) diff --git a/src/distPlus/distribution/AlgebraicShapeCombination.re b/src/distPlus/distribution/AlgebraicShapeCombination.re index a4239200..9ece3541 100644 --- a/src/distPlus/distribution/AlgebraicShapeCombination.re +++ b/src/distPlus/distribution/AlgebraicShapeCombination.re @@ -245,9 +245,10 @@ let toDiscretePointMassesFromDiscrete = (s: DistTypes.xyShape): pointMassesWithM {n, masses, means, variances}; }; -let combineShapesContinuousDiscreteAdd = +let combineShapesContinuousDiscrete = (op: ExpressionTypes.algebraicOperation, s1: DistTypes.xyShape, s2: DistTypes.xyShape) - : DistTypes.xyShape => { + : DistTypes.xyShape => { + let t1n = s1 |> XYShape.T.length; let t2n = s2 |> XYShape.T.length; @@ -257,126 +258,51 @@ let combineShapesContinuousDiscreteAdd = let outXYShapes: array(array((float, float))) = Belt.Array.makeUninitializedUnsafe(t2n); - for (j in 0 to t2n - 1) { - // for each one of the discrete points - // create a new distribution, as long as the original continuous one + switch (op) { + | `Add + | `Subtract => { + for (j in 0 to t2n - 1) { + // for each one of the discrete points + // create a new distribution, as long as the original continuous one + let dxyShape: array((float, float)) = + Belt.Array.makeUninitializedUnsafe(t1n); - let dxyShape: array((float, float)) = - Belt.Array.makeUninitializedUnsafe(t1n); - for (i in 0 to t1n - 1) { - let _ = - Belt.Array.set( - dxyShape, - i, - (fn(s1.xs[i], s2.xs[j]), s1.ys[i] *. s2.ys[j]), - ); + for (i in 0 to t1n - 1) { + let _ = + Belt.Array.set( + dxyShape, + i, + (fn(s1.xs[i], s2.xs[j]), s1.ys[i] *. s2.ys[j]), + ); + (); + }; + let _ = Belt.Array.set(outXYShapes, j, dxyShape); (); - }; - - let _ = Belt.Array.set(outXYShapes, j, dxyShape); - (); + } + } + | `Multiply + | `Divide => { + for (j in 0 to t2n - 1) { + // for each one of the discrete points + // create a new distribution, as long as the original continuous one + let dxyShape: array((float, float)) = + Belt.Array.makeUninitializedUnsafe(t1n); + for (i in 0 to t1n - 1) { + let _ = + Belt.Array.set( + dxyShape, + i, + (fn(s1.xs[i], s2.xs[j]), s1.ys[i] *. s2.ys[j] /. s2.xs[j]), + ); + (); + }; + let _ = Belt.Array.set(outXYShapes, j, dxyShape); + (); + } + } }; outXYShapes + |> E.A.fmap(XYShape.T.fromZippedArray) |> E.A.fold_left(XYShape.PointwiseCombination.combineLinear((+.)), XYShape.T.empty); }; - -let combineShapesContinuousDiscreteMultiply = - (op: ExpressionTypes.algebraicOperation, s1: DistTypes.xyShape, s2: DistTypes.xyShape) - : DistTypes.xyShape => { - let t1n = s1 |> XYShape.T.length; - let t2n = s2 |> XYShape.T.length; - - let t1m = toDiscretePointMassesFromTriangulars(s1); - let t2m = toDiscretePointMassesFromDiscrete(s2); - - let combineMeansFn = - switch (op) { - | `Add => ((m1, m2) => m1 +. m2) - | `Subtract => ((m1, m2) => m1 -. m2) - | `Multiply => ((m1, m2) => m1 *. m2) - | `Divide => ((m1, m2) => m1 /. m2) - }; - - let combineVariancesFn = - switch (op) { - | `Add - | `Subtract => ((v1, v2, _, _) => v1 +. v2) - | `Multiply - | `Divide => ( - (v1, v2, m1, m2) => v1 *. m2 ** 2. - ) - }; - - let outputMinX: ref(float) = ref(infinity); - let outputMaxX: ref(float) = ref(neg_infinity); - let masses: array(float) = - Belt.Array.makeUninitializedUnsafe(t1m.n * t2m.n); - let means: array(float) = - Belt.Array.makeUninitializedUnsafe(t1m.n * t2m.n); - let variances: array(float) = - Belt.Array.makeUninitializedUnsafe(t1m.n * t2m.n); - // then convolve the two sets of pointMassesWithMoments - for (i in 0 to t1m.n - 1) { - for (j in 0 to t2m.n - 1) { - let k = i * t2m.n + j; - let _ = Belt.Array.set(masses, k, t1m.masses[i] *. t2m.masses[j]); - - let mean = combineMeansFn(t1m.means[i], t2m.means[j]); - let variance = - combineVariancesFn( - t1m.variances[i], - t2m.variances[j], - t1m.means[i], - t2m.means[j], - ); - let _ = Belt.Array.set(means, k, mean); - let _ = Belt.Array.set(variances, k, variance); - - // update bounds - let minX = mean -. 2. *. sqrt(variance) *. 1.644854; - let maxX = mean +. 2. *. sqrt(variance) *. 1.644854; - if (minX < outputMinX^) { - outputMinX := minX; - }; - if (maxX > outputMaxX^) { - outputMaxX := maxX; - }; - }; - }; - - - // we now want to create a set of target points. For now, let's just evenly distribute 200 points between - // between the outputMinX and outputMaxX - let nOut = 300; - let outputXs: array(float) = E.A.Floats.range(outputMinX^, outputMaxX^, nOut); - let outputYs: array(float) = Belt.Array.make(nOut, 0.0); - // now, for each of the outputYs, accumulate from a Gaussian kernel over each input point. - for (j in 0 to E.A.length(masses) - 1) { // go through all of the result points - let _ = if (variances[j] > 0. && masses[j] > 0.) { - for (i in 0 to E.A.length(outputXs) - 1) { // go through all of the target points - let dx = outputXs[i] -. means[j]; - let contribution = masses[j] *. exp(-. (dx ** 2.) /. (2. *. variances[j])) /. (sqrt(2. *. 3.14159276 *. variances[j])); - let _ = Belt.Array.set(outputYs, i, outputYs[i] +. contribution); - (); - }; - (); - }; - (); - }; - - {xs: outputXs, ys: outputYs}; -}; - -let combineShapesContinuousDiscrete = - (op: ExpressionTypes.algebraicOperation, s1: DistTypes.xyShape, s2: DistTypes.xyShape) - : DistTypes.xyShape => { - - switch (op) { - | `Add - | `Subtract => combineShapesContinuousDiscreteAdd(op, s1, s2); - | `Multiply - | `Divide => combineShapesContinuousDiscreteMultiply(op, s1, s2); - }; - -}; diff --git a/src/distPlus/distribution/Continuous.re b/src/distPlus/distribution/Continuous.re index 6e7b68c7..108c96ba 100644 --- a/src/distPlus/distribution/Continuous.re +++ b/src/distPlus/distribution/Continuous.re @@ -208,48 +208,24 @@ let combineAlgebraicallyWithDiscrete = ) => { let t1s = t1 |> getShape; let t2s = t2.xyShape; // would like to use Discrete.getShape here, but current file structure doesn't allow for that + let t1n = t1s |> XYShape.T.length; let t2n = t2s |> XYShape.T.length; - let fn = Operation.Algebraic.toFn(op); + if (t1n > 0 && t2n > 0) { + let combinedShape = AlgebraicShapeCombination.combineShapesContinuousDiscrete(op, t1s, t2s); - let outXYShapes: array(array((float, float))) = - Belt.Array.makeUninitializedUnsafe(t2n); + let combinedIntegralSum = + Common.combineIntegralSums( + (a, b) => Some(a *. b), + t1.knownIntegralSum, + t2.knownIntegralSum, + ); - for (j in 0 to t2n - 1) { - // for each one of the discrete points - // create a new distribution, as long as the original continuous one - - let dxyShape: array((float, float)) = - Belt.Array.makeUninitializedUnsafe(t1n); - for (i in 0 to t1n - 1) { - let _ = - Belt.Array.set( - dxyShape, - i, - (fn(t1s.xs[i], t2s.xs[j]), t1s.ys[i] *. t2s.ys[j]), - ); - (); - }; - - let _ = Belt.Array.set(outXYShapes, j, dxyShape); - (); - }; - - let combinedIntegralSum = - Common.combineIntegralSums( - (a, b) => Some(a *. b), - t1.knownIntegralSum, - t2.knownIntegralSum, - ); - - outXYShapes - |> E.A.fmap(s => { - let xyShape = XYShape.T.fromZippedArray(s); - make(`Linear, xyShape, None); - }) - |> reduce((+.)) - |> updateKnownIntegralSum(combinedIntegralSum); + make(`Linear, combinedShape, combinedIntegralSum); + } else { + empty; + } }; let combineAlgebraically = diff --git a/src/distPlus/distribution/Shape.re b/src/distPlus/distribution/Shape.re index 219e6534..e5039b11 100644 --- a/src/distPlus/distribution/Shape.re +++ b/src/distPlus/distribution/Shape.re @@ -29,6 +29,11 @@ let combineAlgebraically = DistTypes.Continuous( Continuous.combineAlgebraically(~downsample=true, op, m1, m2), ) + | (Continuous(m1), Discrete(m2)) + | (Discrete(m2), Continuous(m1)) => + DistTypes.Continuous( + Continuous.combineAlgebraicallyWithDiscrete(~downsample=false, op, m1, m2), + ) | (Discrete(m1), Discrete(m2)) => DistTypes.Discrete(Discrete.combineAlgebraically(op, m1, m2)) | (m1, m2) =>