diff --git a/packages/squiggle-lang/__tests__/Distributions/AlgebraicShapeCombination_test.res b/packages/squiggle-lang/__tests__/Distributions/AlgebraicShapeCombination_test.res index 7b84ee99..702b67a4 100644 --- a/packages/squiggle-lang/__tests__/Distributions/AlgebraicShapeCombination_test.res +++ b/packages/squiggle-lang/__tests__/Distributions/AlgebraicShapeCombination_test.res @@ -9,6 +9,7 @@ describe("Combining Continuous and Discrete Distributions", () => { #Multiply, {xs: [0., 1.], ys: [1., 1.]}, {xs: [-1.], ys: [1.]}, + ~discretePosition=Second, ), ), // Multiply distribution by -1 true, diff --git a/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res b/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res index 605797b9..a4569a88 100644 --- a/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res +++ b/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res @@ -54,6 +54,7 @@ describe("eval on distribution functions", () => { describe("subtract", () => { testEval("10 - normal(5, 1)", "Ok(Normal(5,1))") testEval("normal(5, 1) - 10", "Ok(Normal(-5,1))") + testEval("mean(1 - toPointSet(normal(5, 2)))", "Ok(-4.002309896304692)") }) describe("multiply", () => { testEval("normal(10, 2) * 2", "Ok(Normal(20,4))") diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/AlgebraicShapeCombination.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/AlgebraicShapeCombination.res index b544fb33..63600e43 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/AlgebraicShapeCombination.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/AlgebraicShapeCombination.res @@ -187,16 +187,20 @@ let toDiscretePointMassesFromDiscrete = (s: PointSetTypes.xyShape): pointMassesW {n: n, masses: masses, means: means, variances: variances} } +type argumentPosition = First | Second + let combineShapesContinuousDiscrete = ( op: Operation.convolutionOperation, continuousShape: PointSetTypes.xyShape, discreteShape: PointSetTypes.xyShape, + ~discretePosition: argumentPosition, ): PointSetTypes.xyShape => { let t1n = continuousShape |> XYShape.T.length let t2n = discreteShape |> XYShape.T.length // each x pair is added/subtracted - let fn = Operation.Convolution.toFn(op) + let opFunc = Operation.Convolution.toFn(op) + let fn = discretePosition == First ? (a, b) => opFunc(b, a) : opFunc let outXYShapes: array> = Belt.Array.makeUninitializedUnsafe(t2n) @@ -207,9 +211,13 @@ let combineShapesContinuousDiscrete = ( // creates a new continuous shape for each one of the discrete points, and collects them in outXYShapes. let dxyShape: array<(float, float)> = Belt.Array.makeUninitializedUnsafe(t1n) for i in 0 to t1n - 1 { + // When this operation is flipped (like 1 - normal(5, 2)) then the + // x axis coordinates would all come out the wrong order. So we need + // to fill them out in the opposite direction + let index = discretePosition == First ? t1n - 1 - i : i Belt.Array.set( dxyShape, - i, + index, ( fn(continuousShape.xs[i], discreteShape.xs[j]), continuousShape.ys[i] *. discreteShape.ys[j], diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res index a8542bae..5e44f900 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res @@ -278,6 +278,7 @@ let combineAlgebraicallyWithDiscrete = ( op: Operation.convolutionOperation, t1: t, t2: PointSetTypes.discreteShape, + ~discretePosition: AlgebraicShapeCombination.argumentPosition, ) => { let t1s = t1 |> getShape let t2s = t2.xyShape // TODO would like to use Discrete.getShape here, but current file structure doesn't allow for that @@ -294,6 +295,7 @@ let combineAlgebraicallyWithDiscrete = ( op, continuousAsLinear |> getShape, t2s, + ~discretePosition, ) let combinedIntegralSum = switch op { diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res index 9961b51d..4ce2bdd6 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res @@ -277,8 +277,18 @@ let combineAlgebraically = (op: Operation.convolutionOperation, t1: t, t2: t): t // continuous (*) continuous => continuous, but also // discrete (*) continuous => continuous (and vice versa). We have to take care of all combos and then combine them: let ccConvResult = Continuous.combineAlgebraically(op, t1.continuous, t2.continuous) - let dcConvResult = Continuous.combineAlgebraicallyWithDiscrete(op, t2.continuous, t1.discrete) - let cdConvResult = Continuous.combineAlgebraicallyWithDiscrete(op, t1.continuous, t2.discrete) + let dcConvResult = Continuous.combineAlgebraicallyWithDiscrete( + op, + t2.continuous, + t1.discrete, + ~discretePosition=First, + ) + let cdConvResult = Continuous.combineAlgebraicallyWithDiscrete( + op, + t1.continuous, + t2.discrete, + ~discretePosition=Second, + ) let continuousConvResult = Continuous.sum([ccConvResult, dcConvResult, cdConvResult]) // ... finally, discrete (*) discrete => discrete, obviously: diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res index 6dacc40f..12aa5477 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res @@ -46,9 +46,20 @@ let combineAlgebraically = (op: Operation.convolutionOperation, t1: t, t2: t): t switch (t1, t2) { | (Continuous(m1), Continuous(m2)) => Continuous.combineAlgebraically(op, m1, m2) |> Continuous.T.toPointSetDist - | (Continuous(m1), Discrete(m2)) - | (Discrete(m2), Continuous(m1)) => - Continuous.combineAlgebraicallyWithDiscrete(op, m1, m2) |> Continuous.T.toPointSetDist + | (Discrete(m1), Continuous(m2)) => + Continuous.combineAlgebraicallyWithDiscrete( + op, + m2, + m1, + ~discretePosition=First, + ) |> Continuous.T.toPointSetDist + | (Continuous(m1), Discrete(m2)) => + Continuous.combineAlgebraicallyWithDiscrete( + op, + m1, + m2, + ~discretePosition=Second, + ) |> Continuous.T.toPointSetDist | (Discrete(m1), Discrete(m2)) => Discrete.combineAlgebraically(op, m1, m2) |> Discrete.T.toPointSetDist | (m1, m2) => Mixed.combineAlgebraically(op, toMixed(m1), toMixed(m2)) |> Mixed.T.toPointSetDist