Fix subtraction not commuting under pointsets

This commit is contained in:
Sam Nolan 2022-04-23 15:58:42 -04:00
parent dfd2f83c9d
commit 9090c44284
4 changed files with 21 additions and 6 deletions

View File

@ -191,12 +191,14 @@ let combineShapesContinuousDiscrete = (
op: Operation.convolutionOperation,
continuousShape: PointSetTypes.xyShape,
discreteShape: PointSetTypes.xyShape,
flip: bool,
): PointSetTypes.xyShape => {
let t1n = continuousShape |> XYShape.T.length
let t2n = discreteShape |> XYShape.T.length
// each x pair is added/subtracted
let fn = Operation.Convolution.toFn(op)
let opFunc = Operation.Convolution.toFn(op)
let fn = flip ? (a, b) => opFunc(b, a) : opFunc
let outXYShapes: array<array<(float, float)>> = Belt.Array.makeUninitializedUnsafe(t2n)

View File

@ -278,6 +278,7 @@ let combineAlgebraicallyWithDiscrete = (
op: Operation.convolutionOperation,
t1: t,
t2: PointSetTypes.discreteShape,
flip: bool,
) => {
let t1s = t1 |> getShape
let t2s = t2.xyShape // TODO would like to use Discrete.getShape here, but current file structure doesn't allow for that
@ -294,6 +295,7 @@ let combineAlgebraicallyWithDiscrete = (
op,
continuousAsLinear |> getShape,
t2s,
flip,
)
let combinedIntegralSum = switch op {

View File

@ -277,8 +277,18 @@ let combineAlgebraically = (op: Operation.convolutionOperation, t1: t, t2: t): t
// continuous (*) continuous => continuous, but also
// discrete (*) continuous => continuous (and vice versa). We have to take care of all combos and then combine them:
let ccConvResult = Continuous.combineAlgebraically(op, t1.continuous, t2.continuous)
let dcConvResult = Continuous.combineAlgebraicallyWithDiscrete(op, t2.continuous, t1.discrete)
let cdConvResult = Continuous.combineAlgebraicallyWithDiscrete(op, t1.continuous, t2.discrete)
let dcConvResult = Continuous.combineAlgebraicallyWithDiscrete(
op,
t2.continuous,
t1.discrete,
true,
)
let cdConvResult = Continuous.combineAlgebraicallyWithDiscrete(
op,
t1.continuous,
t2.discrete,
false,
)
let continuousConvResult = Continuous.sum([ccConvResult, dcConvResult, cdConvResult])
// ... finally, discrete (*) discrete => discrete, obviously:

View File

@ -46,9 +46,10 @@ let combineAlgebraically = (op: Operation.convolutionOperation, t1: t, t2: t): t
switch (t1, t2) {
| (Continuous(m1), Continuous(m2)) =>
Continuous.combineAlgebraically(op, m1, m2) |> Continuous.T.toPointSetDist
| (Continuous(m1), Discrete(m2))
| (Discrete(m2), Continuous(m1)) =>
Continuous.combineAlgebraicallyWithDiscrete(op, m1, m2) |> Continuous.T.toPointSetDist
| (Discrete(m1), Continuous(m2)) =>
Continuous.combineAlgebraicallyWithDiscrete(op, m2, m1, true) |> Continuous.T.toPointSetDist
| (Continuous(m1), Discrete(m2)) =>
Continuous.combineAlgebraicallyWithDiscrete(op, m1, m2, false) |> Continuous.T.toPointSetDist
| (Discrete(m1), Discrete(m2)) =>
Discrete.combineAlgebraically(op, m1, m2) |> Discrete.T.toPointSetDist
| (m1, m2) => Mixed.combineAlgebraically(op, toMixed(m1), toMixed(m2)) |> Mixed.T.toPointSetDist