From 9090c442846e65c4647c0079da44e56c80bff9e7 Mon Sep 17 00:00:00 2001 From: Sam Nolan Date: Sat, 23 Apr 2022 15:58:42 -0400 Subject: [PATCH] Fix subtraction not commuting under pointsets --- .../PointSetDist/AlgebraicShapeCombination.res | 4 +++- .../Distributions/PointSetDist/Continuous.res | 2 ++ .../rescript/Distributions/PointSetDist/Mixed.res | 14 ++++++++++++-- .../Distributions/PointSetDist/PointSetDist.res | 7 ++++--- 4 files changed, 21 insertions(+), 6 deletions(-) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/AlgebraicShapeCombination.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/AlgebraicShapeCombination.res index fd7871cd..d68a9623 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/AlgebraicShapeCombination.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/AlgebraicShapeCombination.res @@ -191,12 +191,14 @@ let combineShapesContinuousDiscrete = ( op: Operation.convolutionOperation, continuousShape: PointSetTypes.xyShape, discreteShape: PointSetTypes.xyShape, + flip: bool, ): PointSetTypes.xyShape => { let t1n = continuousShape |> XYShape.T.length let t2n = discreteShape |> XYShape.T.length // each x pair is added/subtracted - let fn = Operation.Convolution.toFn(op) + let opFunc = Operation.Convolution.toFn(op) + let fn = flip ? (a, b) => opFunc(b, a) : opFunc let outXYShapes: array> = Belt.Array.makeUninitializedUnsafe(t2n) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res index a8542bae..de5d65ed 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res @@ -278,6 +278,7 @@ let combineAlgebraicallyWithDiscrete = ( op: Operation.convolutionOperation, t1: t, t2: PointSetTypes.discreteShape, + flip: bool, ) => { let t1s = t1 |> getShape let t2s = t2.xyShape // TODO would like to use Discrete.getShape here, but current file structure doesn't allow for that @@ -294,6 +295,7 @@ let combineAlgebraicallyWithDiscrete = ( op, continuousAsLinear |> getShape, t2s, + flip, ) let combinedIntegralSum = switch op { diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res index 9961b51d..da604262 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res @@ -277,8 +277,18 @@ let combineAlgebraically = (op: Operation.convolutionOperation, t1: t, t2: t): t // continuous (*) continuous => continuous, but also // discrete (*) continuous => continuous (and vice versa). We have to take care of all combos and then combine them: let ccConvResult = Continuous.combineAlgebraically(op, t1.continuous, t2.continuous) - let dcConvResult = Continuous.combineAlgebraicallyWithDiscrete(op, t2.continuous, t1.discrete) - let cdConvResult = Continuous.combineAlgebraicallyWithDiscrete(op, t1.continuous, t2.discrete) + let dcConvResult = Continuous.combineAlgebraicallyWithDiscrete( + op, + t2.continuous, + t1.discrete, + true, + ) + let cdConvResult = Continuous.combineAlgebraicallyWithDiscrete( + op, + t1.continuous, + t2.discrete, + false, + ) let continuousConvResult = Continuous.sum([ccConvResult, dcConvResult, cdConvResult]) // ... finally, discrete (*) discrete => discrete, obviously: diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res index 6dacc40f..641532d6 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res @@ -46,9 +46,10 @@ let combineAlgebraically = (op: Operation.convolutionOperation, t1: t, t2: t): t switch (t1, t2) { | (Continuous(m1), Continuous(m2)) => Continuous.combineAlgebraically(op, m1, m2) |> Continuous.T.toPointSetDist - | (Continuous(m1), Discrete(m2)) - | (Discrete(m2), Continuous(m1)) => - Continuous.combineAlgebraicallyWithDiscrete(op, m1, m2) |> Continuous.T.toPointSetDist + | (Discrete(m1), Continuous(m2)) => + Continuous.combineAlgebraicallyWithDiscrete(op, m2, m1, true) |> Continuous.T.toPointSetDist + | (Continuous(m1), Discrete(m2)) => + Continuous.combineAlgebraicallyWithDiscrete(op, m1, m2, false) |> Continuous.T.toPointSetDist | (Discrete(m1), Discrete(m2)) => Discrete.combineAlgebraically(op, m1, m2) |> Discrete.T.toPointSetDist | (m1, m2) => Mixed.combineAlgebraically(op, toMixed(m1), toMixed(m2)) |> Mixed.T.toPointSetDist