Fleshed out AlgebraicCombination

This commit is contained in:
Ozzie Gooen 2022-03-26 16:56:56 -04:00
parent d490af38f0
commit c5afb2d867
3 changed files with 96 additions and 26 deletions

View File

@ -115,6 +115,7 @@ let combineShapesContinuousContinuous = (
| #Multiply => (m1, m2) => m1 *. m2 | #Multiply => (m1, m2) => m1 *. m2
| #Divide => (m1, mInv2) => m1 *. mInv2 | #Divide => (m1, mInv2) => m1 *. mInv2
| #Exponentiate => (m1, mInv2) => m1 ** mInv2 | #Exponentiate => (m1, mInv2) => m1 ** mInv2
| #Log => (m1, m2) => log(m1) /. log(m2)
} // note: here, mInv2 = mean(1 / t2) ~= 1 / mean(t2) } // note: here, mInv2 = mean(1 / t2) ~= 1 / mean(t2)
// TODO: I don't know what the variances are for exponentatiation // TODO: I don't know what the variances are for exponentatiation
@ -232,6 +233,7 @@ let combineShapesContinuousDiscrete = (
} }
| #Multiply | #Multiply
| #Exponentiate | #Exponentiate
| #Log
| #Divide => | #Divide =>
for j in 0 to t2n - 1 { for j in 0 to t2n - 1 {
// creates a new continuous shape for each one of the discrete points, and collects them in outXYShapes. // creates a new continuous shape for each one of the discrete points, and collects them in outXYShapes.

View File

@ -41,7 +41,8 @@ let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t): t =
| (Continuous(m1), Discrete(m2)) | (Continuous(m1), Discrete(m2))
| (Discrete(m2), Continuous(m1)) => | (Discrete(m2), Continuous(m1)) =>
Continuous.combineAlgebraicallyWithDiscrete(op, m1, m2) |> Continuous.T.toPointSetDist Continuous.combineAlgebraicallyWithDiscrete(op, m1, m2) |> Continuous.T.toPointSetDist
| (Discrete(m1), Discrete(m2)) => Discrete.combineAlgebraically(op, m1, m2) |> Discrete.T.toPointSetDist | (Discrete(m1), Discrete(m2)) =>
Discrete.combineAlgebraically(op, m1, m2) |> Discrete.T.toPointSetDist
| (m1, m2) => Mixed.combineAlgebraically(op, toMixed(m1), toMixed(m2)) |> Mixed.T.toPointSetDist | (m1, m2) => Mixed.combineAlgebraically(op, toMixed(m1), toMixed(m2)) |> Mixed.T.toPointSetDist
} }

View File

@ -55,12 +55,6 @@ module OperationType = {
| #Sample(int) | #Sample(int)
] ]
type scale = [
| #Multiply
| #Exponentiate
| #Log
]
type t = [ type t = [
| #toFloat(toFloat) | #toFloat(toFloat)
| #toDist(toDist) | #toDist(toDist)
@ -73,6 +67,7 @@ type operation = OperationType.t
module T = { module T = {
type t = genericDist type t = genericDist
type toPointSetFn = genericDist => result<PointSetTypes.pointSetDist, error> type toPointSetFn = genericDist => result<PointSetTypes.pointSetDist, error>
type toSampleSetFn = genericDist => result<array<float>, error>
let sampleN = (n, t: t) => { let sampleN = (n, t: t) => {
switch t { switch t {
| #PointSet(r) => Ok(PointSetDist.sampleNRendered(n, r)) | #PointSet(r) => Ok(PointSetDist.sampleNRendered(n, r))
@ -81,7 +76,7 @@ module T = {
} }
} }
let toFloat = (toPointSet: toPointSetFn, fnName, t: genericDist) => { let toFloat = (toPointSet: toPointSetFn, fnName, t: genericDist): result<float, error> => {
switch t { switch t {
| #Symbolic(r) if Belt.Result.isOk(SymbolicDist.T.operate(fnName, r)) => | #Symbolic(r) if Belt.Result.isOk(SymbolicDist.T.operate(fnName, r)) =>
switch SymbolicDist.T.operate(fnName, r) { switch SymbolicDist.T.operate(fnName, r) {
@ -104,7 +99,7 @@ module T = {
kernelWidth: None, kernelWidth: None,
} }
let toPointSet = (xyPointLength, t: t) => { let toPointSet = (xyPointLength, t: t): result<PointSetTypes.pointSetDist, error> => {
switch t { switch t {
| #PointSet(pointSet) => Ok(pointSet) | #PointSet(pointSet) => Ok(pointSet)
| #Symbolic(r) => Ok(SymbolicDist.T.toPointSetDist(xyPointLength, r)) | #Symbolic(r) => Ok(SymbolicDist.T.toPointSetDist(xyPointLength, r))
@ -122,16 +117,82 @@ module T = {
} }
} }
let algebraicCombination = (operation, sampleCount, dist1: t, dist2: t) => { module AlgebraicCombination = {
let dist1 = sampleN(sampleCount, dist1) let tryAnalyticalSimplification = (operation: OperationType.combination, t1: t, t2: t): option<
let dist2 = sampleN(sampleCount, dist2) result<SymbolicDistTypes.symbolicDist, string>,
let samples = E.R.merge(dist1, dist2) |> E.R.fmap(((d1, d2)) => { > =>
Belt.Array.zip(d1, d2) |> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(operation, a, b)) switch (operation, t1, t2) {
}) | (operation, #Symbolic(d1), #Symbolic(d2)) =>
samples |> E.R.fmap(r => #SampleSet(r)) switch SymbolicDist.T.tryAnalyticalSimplification(d1, d2, operation) {
| #AnalyticalSolution(symbolicDist) => Some(Ok(symbolicDist))
| #Error(er) => Some(Error(er))
| #NoSolution => None
}
| _ => None
}
let runConvolution = (
toPointSet: toPointSetFn,
operation: OperationType.combination,
t1: t,
t2: t,
) =>
E.R.merge(toPointSet(t1), toPointSet(t2)) |> E.R.fmap(((a, b)) =>
PointSetDist.combineAlgebraically(operation, a, b)
)
let runMonteCarlo = (
toSampleSet: toSampleSetFn,
operation: OperationType.combination,
t1: t,
t2: t,
) => {
E.R.merge(toSampleSet(t1), toSampleSet(t2)) |> E.R.fmap(((a, b)) => {
Belt.Array.zip(a, b) |> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(operation, a, b))
})
}
//I'm (Ozzie) really just guessing here, very little idea what's best
let expectedConvolutionCost: t => int = x =>
switch x {
| #Symbolic(#Float(_)) => 1
| #Symbolic(_) => 1000
| #PointSet(Discrete(m)) => m.xyShape |> XYShape.T.length
| #PointSet(Mixed(_)) => 1000
| #PointSet(Continuous(_)) => 1000
| _ => 1000
}
let chooseConvolutionOrMonteCarlo = (t1: t, t2: t) =>
expectedConvolutionCost(t1) * expectedConvolutionCost(t2) > 10000
? #CalculateWithMonteCarlo
: #CalculateWithConvolution
let run = (
toPointSet: toPointSetFn,
toSampleSet: toSampleSetFn,
algebraicOp,
t1: t,
t2: t,
): result<t, error> => {
switch tryAnalyticalSimplification(algebraicOp, t1, t2) {
| Some(Ok(symbolicDist)) => Ok(#Symbolic(symbolicDist))
| Some(Error(e)) => Error(Other(e))
| None =>
switch chooseConvolutionOrMonteCarlo(t1, t2) {
| #CalculateWithMonteCarlo =>
runMonteCarlo(toSampleSet, algebraicOp, t1, t2) |> E.R.fmap(r => #SampleSet(r))
| #CalculateWithConvolution =>
runConvolution(toPointSet, algebraicOp, t1, t2) |> E.R.fmap(r => #PointSet(r))
}
}
}
} }
let pointwiseCombination = (toPointSet: toPointSetFn, operation, t1: t, t2: t) => { let pointwiseCombination = (toPointSet: toPointSetFn, operation, t1: t, t2: t): result<
t,
error,
> => {
E.R.merge(toPointSet(t1), toPointSet(t2)) E.R.merge(toPointSet(t1), toPointSet(t2))
|> E.R.fmap(((t1, t2)) => |> E.R.fmap(((t1, t2)) =>
PointSetDist.combinePointwise(OperationType.combinationToFn(operation), t1, t2) PointSetDist.combinePointwise(OperationType.combinationToFn(operation), t1, t2)
@ -144,11 +205,12 @@ module T = {
operation: OperationType.combination, operation: OperationType.combination,
t: t, t: t,
f: float, f: float,
) => { ): result<t, error> => {
switch operation { switch operation {
| #Add | #Subtract => Error(DistributionVerticalShiftIsInvalid) | #Add | #Subtract => Error(DistributionVerticalShiftIsInvalid)
| (#Multiply | #Divide | #Exponentiate | #Log) as operation => | (#Multiply | #Divide | #Exponentiate | #Log) as operation =>
toPointSet(t) |> E.R.fmap(t => { toPointSet(t) |> E.R.fmap(t => {
//TODO: Move to PointSet codebase
let fn = (secondary, main) => Operation.Scale.toFn(operation, main, secondary) let fn = (secondary, main) => Operation.Scale.toFn(operation, main, secondary)
let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(operation) let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(operation)
let integralCacheFn = Operation.Scale.toIntegralCacheFn(operation) let integralCacheFn = Operation.Scale.toIntegralCacheFn(operation)
@ -159,7 +221,7 @@ module T = {
t, t,
) )
}) })
} } |> E.R.fmap(r => #PointSet(r))
} }
} }
@ -188,10 +250,10 @@ module OmniRunner = {
| Error(e) => #Error(e) | Error(e) => #Error(e)
} }
let rec applyFnInternal = (wrapped: wrapped, fnName: operation): outputType => { let rec run = (wrapped: wrapped, fnName: operation): outputType => {
let (value, {sampleCount, xyPointLength} as extra) = wrapped let (value, {sampleCount, xyPointLength} as extra) = wrapped
let reCall = (~value=value, ~extra=extra, ~fnName=fnName, ()) => { let reCall = (~value=value, ~extra=extra, ~fnName=fnName, ()) => {
applyFnInternal((value, extra), fnName) run((value, extra), fnName)
} }
let toPointSet = r => { let toPointSet = r => {
switch reCall(~value=r, ~fnName=#toDist(#toPointSet), ()) { switch reCall(~value=r, ~fnName=#toDist(#toPointSet), ()) {
@ -200,8 +262,14 @@ module OmniRunner = {
| _ => Error(Other("Impossible error")) | _ => Error(Other("Impossible error"))
} }
} }
let toPointSetAndReCall = v => toPointSet(v) |> E.R.fmap(r => reCall(~value=#PointSet(r), ())) let toSampleSet = r => {
let newVal: outputType = switch (fnName, value) { switch reCall(~value=r, ~fnName=#toDist(#toSampleSet(sampleCount)), ()) {
| #Dist(#SampleSet(p)) => Ok(p)
| #Error(r) => Error(r)
| _ => Error(Other("Impossible error"))
}
}
switch (fnName, value) {
// | (#toFloat(n), v) => toFloat(toPointSet, v, n) // | (#toFloat(n), v) => toFloat(toPointSet, v, n)
| (#toFloat(fnName), _) => | (#toFloat(fnName), _) =>
T.toFloat(toPointSet, fnName, value) |> E.R.fmap(r => #Float(r)) |> fromResult T.toFloat(toPointSet, fnName, value) |> E.R.fmap(r => #Float(r)) |> fromResult
@ -214,17 +282,16 @@ module OmniRunner = {
value |> T.sampleN(n) |> E.R.fmap(r => #Dist(#SampleSet(r))) |> fromResult value |> T.sampleN(n) |> E.R.fmap(r => #Dist(#SampleSet(r))) |> fromResult
| (#toDistCombination(#Algebraic, _, #Float(_)), _) => #Error(NotYetImplemented) | (#toDistCombination(#Algebraic, _, #Float(_)), _) => #Error(NotYetImplemented)
| (#toDistCombination(#Algebraic, operation, #Dist(p2)), p1) => | (#toDistCombination(#Algebraic, operation, #Dist(p2)), p1) =>
T.algebraicCombination(operation, sampleCount, p1, p2) T.AlgebraicCombination.run(toPointSet, toSampleSet, operation, p1, p2)
|> E.R.fmap(r => #Dist(r)) |> E.R.fmap(r => #Dist(r))
|> fromResult |> fromResult
| (#toDistCombination(#Pointwise, operation, #Dist(p2)), p1) => | (#toDistCombination(#Pointwise, operation, #Dist(p2)), p1) =>
T.pointwiseCombination(toPointSet, operation, p1, p2) |> E.R.fmap(r => #Dist(r)) |> fromResult T.pointwiseCombination(toPointSet, operation, p1, p2) |> E.R.fmap(r => #Dist(r)) |> fromResult
| (#toDistCombination(#Pointwise, operation, #Float(f)), _) => | (#toDistCombination(#Pointwise, operation, #Float(f)), _) =>
T.pointwiseCombinationFloat(toPointSet, operation, value, f) T.pointwiseCombinationFloat(toPointSet, operation, value, f)
|> E.R.fmap(r => #Dist(#PointSet(r))) |> E.R.fmap(r => #Dist(r))
|> fromResult |> fromResult
} }
newVal
} }
} }