From 8deb79682014ab83fa352a069bf1f8067e0ae710 Mon Sep 17 00:00:00 2001 From: Sam Nolan Date: Mon, 18 Apr 2022 16:42:11 +1000 Subject: [PATCH] Fix non-commutative pointwise combinations --- README.md | 1 + .../AlgebraicShapeCombination.res | 85 ++++++++++--------- .../Distributions/PointSetDist/Continuous.res | 36 +++++++- .../Distributions/PointSetDist/Mixed.res | 12 ++- .../PointSetDist/PointSetDist.res | 7 +- 5 files changed, 96 insertions(+), 45 deletions(-) diff --git a/README.md b/README.md index ee4d00c8..2f1de7f7 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,7 @@ The playground depends on the components library which then depends on the langu # Develop For any project in the repo, begin by running `yarn` in the top level + ```sh yarn ``` diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/AlgebraicShapeCombination.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/AlgebraicShapeCombination.res index f8740de9..1e412d5a 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/AlgebraicShapeCombination.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/AlgebraicShapeCombination.res @@ -203,55 +203,62 @@ let combineShapesContinuousDiscrete = ( continuousShape: PointSetTypes.xyShape, discreteShape: PointSetTypes.xyShape, ): PointSetTypes.xyShape => { - let t1n = continuousShape |> XYShape.T.length - let t2n = discreteShape |> XYShape.T.length - // each x pair is added/subtracted let fn = Operation.Algebraic.toFn(op) - let outXYShapes: array> = Belt.Array.makeUninitializedUnsafe(t2n) + let discretePoints = Belt.Array.zip(discreteShape.xs, discreteShape.ys) + let continuousPoints = Belt.Array.zip(continuousShape.xs, continuousShape.ys) - switch op { + let outXYShapes = switch op { | #Add | #Subtract => - for j in 0 to t2n - 1 { - // creates a new continuous shape for each one of the discrete points, and collects them in outXYShapes. - let dxyShape: array<(float, float)> = Belt.Array.makeUninitializedUnsafe(t1n) - for i in 0 to t1n - 1 { - Belt.Array.set( - dxyShape, - i, - ( - fn(continuousShape.xs[i], discreteShape.xs[j]), - continuousShape.ys[i] *. discreteShape.ys[j], - ), - ) |> ignore - () - } - Belt.Array.set(outXYShapes, j, dxyShape) |> ignore - () - } + discretePoints->E.A2.fmap(((dx, dy)) => + continuousPoints->E.A2.fmap(((cx, cy)) => (fn(cx, dx), cy *. dy)) + ) | #Multiply | #Power | #Logarithm | #Divide => - for j in 0 to t2n - 1 { - // creates a new continuous shape for each one of the discrete points, and collects them in outXYShapes. - let dxyShape: array<(float, float)> = Belt.Array.makeUninitializedUnsafe(t1n) - for i in 0 to t1n - 1 { - Belt.Array.set( - dxyShape, - i, - ( - fn(continuousShape.xs[i], discreteShape.xs[j]), - continuousShape.ys[i] *. discreteShape.ys[j] /. discreteShape.xs[j], - ), - ) |> ignore - () - } - Belt.Array.set(outXYShapes, j, dxyShape) |> ignore - () - } + discretePoints->E.A2.fmap(((dx, dy)) => + continuousPoints->E.A2.fmap(((cx, cy)) => (fn(cx, dx), cy *. dy /. dx)) + ) + } + + outXYShapes + |> E.A.fmap(XYShape.T.fromZippedArray) + |> E.A.fold_left( + XYShape.PointwiseCombination.combine( + \"+.", + XYShape.XtoY.continuousInterpolator(#Linear, #UseZero), + ), + XYShape.T.empty, + ) +} + +let combineShapesDiscreteContinuous = ( + op: Operation.algebraicOperation, + discreteShape: PointSetTypes.xyShape, + continuousShape: PointSetTypes.xyShape, +): PointSetTypes.xyShape => { + // each x pair is added/subtracted + let fn = Operation.Algebraic.toFn(op) + + let discretePoints = Belt.Array.zip(discreteShape.xs, discreteShape.ys) + let continuousPoints = Belt.Array.zip(continuousShape.xs, continuousShape.ys) + + let outXYShapes = switch op { + | #Add + | #Subtract => + discretePoints->E.A2.fmap(((dx, dy)) => + continuousPoints->E.A2.fmap(((cx, cy)) => (fn(dx, cx), dy *. cy)) + ) + | #Multiply + | #Power + | #Logarithm + | #Divide => + discretePoints->E.A2.fmap(((dx, dy)) => + continuousPoints->E.A2.fmap(((cx, cy)) => (fn(dx, cx), dy *. cy /. dx)) + ) } outXYShapes diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res index 4c48df70..f7d1e179 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res @@ -246,7 +246,7 @@ let downsampleEquallyOverX = (length, t): t => /* This simply creates multiple copies of the continuous distribution, scaled and shifted according to each discrete data point, and then adds them all together. */ -let combineAlgebraicallyWithDiscrete = ( +let combineAlgebraicallyWithDiscreteSecond = ( op: Operation.algebraicOperation, t1: t, t2: PointSetTypes.discreteShape, @@ -280,6 +280,40 @@ let combineAlgebraicallyWithDiscrete = ( } } +let combineAlgebraicallyWithDiscreteFirst = ( + op: Operation.algebraicOperation, + t1: PointSetTypes.discreteShape, + t2: t, +) => { + let t1s = t1.xyShape + let t2s = t2->getShape + + if XYShape.T.isEmpty(t1s) || XYShape.T.isEmpty(t2s) { + empty + } else { + let continuousAsLinear = switch t2.interpolation { + | #Linear => t2 + | #Stepwise => stepwiseToLinear(t2) + } + + let combinedShape = AlgebraicShapeCombination.combineShapesDiscreteContinuous( + op, + t1s, + continuousAsLinear |> getShape, + ) + + let combinedIntegralSum = switch op { + | #Multiply + | #Divide => + Common.combineIntegralSums((a, b) => Some(a *. b), t1.integralSumCache, t2.integralSumCache) + | _ => None + } + + // TODO: It could make sense to automatically transform the integrals here (shift or scale) + make(~interpolation=t2.interpolation, ~integralSumCache=combinedIntegralSum, combinedShape) + } +} + let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t) => { let s1 = t1 |> getShape let s2 = t2 |> getShape diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res index 9e27057a..f621f6a5 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res @@ -242,8 +242,16 @@ let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t): t = // continuous (*) continuous => continuous, but also // discrete (*) continuous => continuous (and vice versa). We have to take care of all combos and then combine them: let ccConvResult = Continuous.combineAlgebraically(op, t1.continuous, t2.continuous) - let dcConvResult = Continuous.combineAlgebraicallyWithDiscrete(op, t2.continuous, t1.discrete) - let cdConvResult = Continuous.combineAlgebraicallyWithDiscrete(op, t1.continuous, t2.discrete) + let dcConvResult = Continuous.combineAlgebraicallyWithDiscreteFirst( + op, + t1.discrete, + t2.continuous, + ) + let cdConvResult = Continuous.combineAlgebraicallyWithDiscreteSecond( + op, + t1.continuous, + t2.discrete, + ) let continuousConvResult = Continuous.reduce(\"+.", [ccConvResult, dcConvResult, cdConvResult]) // ... finally, discrete (*) discrete => discrete, obviously: diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res index d0668a57..f552a028 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res @@ -39,9 +39,10 @@ let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t): t = switch (t1, t2) { | (Continuous(m1), Continuous(m2)) => Continuous.combineAlgebraically(op, m1, m2) |> Continuous.T.toPointSetDist - | (Continuous(m1), Discrete(m2)) - | (Discrete(m2), Continuous(m1)) => - Continuous.combineAlgebraicallyWithDiscrete(op, m1, m2) |> Continuous.T.toPointSetDist + | (Discrete(m1), Continuous(m2)) => + Continuous.combineAlgebraicallyWithDiscreteFirst(op, m1, m2) |> Continuous.T.toPointSetDist + | (Continuous(m1), Discrete(m2)) => + Continuous.combineAlgebraicallyWithDiscreteSecond(op, m1, m2) |> Continuous.T.toPointSetDist | (Discrete(m1), Discrete(m2)) => Discrete.combineAlgebraically(op, m1, m2) |> Discrete.T.toPointSetDist | (m1, m2) => Mixed.combineAlgebraically(op, toMixed(m1), toMixed(m2)) |> Mixed.T.toPointSetDist