Algebraic Strategy should use MC when inputs include sample set dists

This commit is contained in:
Ozzie Gooen 2022-04-28 09:08:53 -04:00
parent 20685ea8cb
commit 934ce78399

View File

@ -6,6 +6,24 @@ type toSampleSetFn = t => result<SampleSetDist.t, error>
type scaleMultiplyFn = (t, float) => result<t, error> type scaleMultiplyFn = (t, float) => result<t, error>
type pointwiseAddFn = (t, t) => result<t, error> type pointwiseAddFn = (t, t) => result<t, error>
let isPointSet = (t: t) =>
switch t {
| PointSet(_) => true
| _ => false
}
let isSampleSetSet = (t: t) =>
switch t {
| SampleSet(_) => true
| _ => false
}
let isSymbolic = (t: t) =>
switch t {
| Symbolic(_) => true
| _ => false
}
let sampleN = (t: t, n) => let sampleN = (t: t, n) =>
switch t { switch t {
| PointSet(r) => PointSetDist.sampleNRendered(n, r) | PointSet(r) => PointSetDist.sampleNRendered(n, r)
@ -248,20 +266,24 @@ module AlgebraicCombination = {
| _ => MagicNumbers.OpCost.wildcardCost | _ => MagicNumbers.OpCost.wildcardCost
} }
let hasSampleSetDist = (t1: t, t2: t): bool => isSampleSetSet(t1) || isSampleSetSet(t2)
let convolutionIsFasterThanMonteCarlo = (t1: t, t2: t): bool =>
expectedConvolutionCost(t1) * expectedConvolutionCost(t2) < MagicNumbers.OpCost.monteCarloCost
let preferConvolutionToMonteCarlo = (t1, t2, arithmeticOperation) => {
!hasSampleSetDist(t1, t2) &&
Operation.Convolution.canDoAlgebraicOperation(arithmeticOperation) &&
convolutionIsFasterThanMonteCarlo(t1, t2)
}
let run = (~t1: t, ~t2: t, ~arithmeticOperation): specificStrategy => { let run = (~t1: t, ~t2: t, ~arithmeticOperation): specificStrategy => {
switch StrategyCallOnValidatedInputs.symbolic(arithmeticOperation, t1, t2) { switch StrategyCallOnValidatedInputs.symbolic(arithmeticOperation, t1, t2) {
| #AnalyticalSolution(_) | #AnalyticalSolution(_)
| #Error(_) => | #Error(_) =>
#AsSymbolic #AsSymbolic
| #NoSolution => | #NoSolution =>
if Operation.Convolution.canDoAlgebraicOperation(arithmeticOperation) { preferConvolutionToMonteCarlo(t1, t2, arithmeticOperation) ? #AsConvolution : #AsMonteCarlo
expectedConvolutionCost(t1) * expectedConvolutionCost(t2) >
MagicNumbers.OpCost.monteCarloCost
? #AsMonteCarlo
: #AsConvolution
} else {
#AsMonteCarlo
}
} }
} }
} }