diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/AlgebraicShapeCombination.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/AlgebraicShapeCombination.res index 1e412d5a..080f324a 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/AlgebraicShapeCombination.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/AlgebraicShapeCombination.res @@ -202,9 +202,11 @@ let combineShapesContinuousDiscrete = ( op: Operation.algebraicOperation, continuousShape: PointSetTypes.xyShape, discreteShape: PointSetTypes.xyShape, + discreteFirst: bool, ): PointSetTypes.xyShape => { // each x pair is added/subtracted - let fn = Operation.Algebraic.toFn(op) + let opFunc = Operation.Algebraic.toFn(op) + let fn = discreteFirst ? (a, b) => opFunc(b, a) : opFunc let discretePoints = Belt.Array.zip(discreteShape.xs, discreteShape.ys) let continuousPoints = Belt.Array.zip(continuousShape.xs, continuousShape.ys) @@ -234,40 +236,3 @@ let combineShapesContinuousDiscrete = ( XYShape.T.empty, ) } - -let combineShapesDiscreteContinuous = ( - op: Operation.algebraicOperation, - discreteShape: PointSetTypes.xyShape, - continuousShape: PointSetTypes.xyShape, -): PointSetTypes.xyShape => { - // each x pair is added/subtracted - let fn = Operation.Algebraic.toFn(op) - - let discretePoints = Belt.Array.zip(discreteShape.xs, discreteShape.ys) - let continuousPoints = Belt.Array.zip(continuousShape.xs, continuousShape.ys) - - let outXYShapes = switch op { - | #Add - | #Subtract => - discretePoints->E.A2.fmap(((dx, dy)) => - continuousPoints->E.A2.fmap(((cx, cy)) => (fn(dx, cx), dy *. cy)) - ) - | #Multiply - | #Power - | #Logarithm - | #Divide => - discretePoints->E.A2.fmap(((dx, dy)) => - continuousPoints->E.A2.fmap(((cx, cy)) => (fn(dx, cx), dy *. cy /. dx)) - ) - } - - outXYShapes - |> E.A.fmap(XYShape.T.fromZippedArray) - |> E.A.fold_left( - XYShape.PointwiseCombination.combine( - \"+.", - XYShape.XtoY.continuousInterpolator(#Linear, #UseZero), - ), - XYShape.T.empty, - ) -} diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res index f7d1e179..d000cb39 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res @@ -246,10 +246,11 @@ let downsampleEquallyOverX = (length, t): t => /* This simply creates multiple copies of the continuous distribution, scaled and shifted according to each discrete data point, and then adds them all together. */ -let combineAlgebraicallyWithDiscreteSecond = ( +let combineAlgebraicallyWithDiscrete = ( op: Operation.algebraicOperation, t1: t, t2: PointSetTypes.discreteShape, + discreteFirst: bool, ) => { let t1s = t1 |> getShape let t2s = t2.xyShape // TODO would like to use Discrete.getShape here, but current file structure doesn't allow for that @@ -266,6 +267,7 @@ let combineAlgebraicallyWithDiscreteSecond = ( op, continuousAsLinear |> getShape, t2s, + discreteFirst, ) let combinedIntegralSum = switch op { @@ -280,40 +282,6 @@ let combineAlgebraicallyWithDiscreteSecond = ( } } -let combineAlgebraicallyWithDiscreteFirst = ( - op: Operation.algebraicOperation, - t1: PointSetTypes.discreteShape, - t2: t, -) => { - let t1s = t1.xyShape - let t2s = t2->getShape - - if XYShape.T.isEmpty(t1s) || XYShape.T.isEmpty(t2s) { - empty - } else { - let continuousAsLinear = switch t2.interpolation { - | #Linear => t2 - | #Stepwise => stepwiseToLinear(t2) - } - - let combinedShape = AlgebraicShapeCombination.combineShapesDiscreteContinuous( - op, - t1s, - continuousAsLinear |> getShape, - ) - - let combinedIntegralSum = switch op { - | #Multiply - | #Divide => - Common.combineIntegralSums((a, b) => Some(a *. b), t1.integralSumCache, t2.integralSumCache) - | _ => None - } - - // TODO: It could make sense to automatically transform the integrals here (shift or scale) - make(~interpolation=t2.interpolation, ~integralSumCache=combinedIntegralSum, combinedShape) - } -} - let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t) => { let s1 = t1 |> getShape let s2 = t2 |> getShape diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res index f621f6a5..c5ba67e5 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res @@ -242,15 +242,17 @@ let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t): t = // continuous (*) continuous => continuous, but also // discrete (*) continuous => continuous (and vice versa). We have to take care of all combos and then combine them: let ccConvResult = Continuous.combineAlgebraically(op, t1.continuous, t2.continuous) - let dcConvResult = Continuous.combineAlgebraicallyWithDiscreteFirst( + let dcConvResult = Continuous.combineAlgebraicallyWithDiscrete( op, - t1.discrete, t2.continuous, + t1.discrete, + true, ) - let cdConvResult = Continuous.combineAlgebraicallyWithDiscreteSecond( + let cdConvResult = Continuous.combineAlgebraicallyWithDiscrete( op, t1.continuous, t2.discrete, + false, ) let continuousConvResult = Continuous.reduce(\"+.", [ccConvResult, dcConvResult, cdConvResult]) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res index f552a028..b9be8ed8 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res @@ -40,9 +40,9 @@ let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t): t = | (Continuous(m1), Continuous(m2)) => Continuous.combineAlgebraically(op, m1, m2) |> Continuous.T.toPointSetDist | (Discrete(m1), Continuous(m2)) => - Continuous.combineAlgebraicallyWithDiscreteFirst(op, m1, m2) |> Continuous.T.toPointSetDist + Continuous.combineAlgebraicallyWithDiscrete(op, m2, m1, true) |> Continuous.T.toPointSetDist | (Continuous(m1), Discrete(m2)) => - Continuous.combineAlgebraicallyWithDiscreteSecond(op, m1, m2) |> Continuous.T.toPointSetDist + Continuous.combineAlgebraicallyWithDiscrete(op, m1, m2, false) |> Continuous.T.toPointSetDist | (Discrete(m1), Discrete(m2)) => Discrete.combineAlgebraically(op, m1, m2) |> Discrete.T.toPointSetDist | (m1, m2) => Mixed.combineAlgebraically(op, toMixed(m1), toMixed(m2)) |> Mixed.T.toPointSetDist