From 934ce783994fe5d68136c2cfdd74e491729ddf7f Mon Sep 17 00:00:00 2001 From: Ozzie Gooen Date: Thu, 28 Apr 2022 09:08:53 -0400 Subject: [PATCH] Algebraic Strategy should use MC when inputs include sample set dists --- .../Distributions/GenericDist/GenericDist.res | 38 +++++++++++++++---- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res index d5dbe1cf..c19bdf7f 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res @@ -6,6 +6,24 @@ type toSampleSetFn = t => result type scaleMultiplyFn = (t, float) => result type pointwiseAddFn = (t, t) => result +let isPointSet = (t: t) => + switch t { + | PointSet(_) => true + | _ => false + } + +let isSampleSetSet = (t: t) => + switch t { + | SampleSet(_) => true + | _ => false + } + +let isSymbolic = (t: t) => + switch t { + | Symbolic(_) => true + | _ => false + } + let sampleN = (t: t, n) => switch t { | PointSet(r) => PointSetDist.sampleNRendered(n, r) @@ -248,20 +266,24 @@ module AlgebraicCombination = { | _ => MagicNumbers.OpCost.wildcardCost } + let hasSampleSetDist = (t1: t, t2: t): bool => isSampleSetSet(t1) || isSampleSetSet(t2) + + let convolutionIsFasterThanMonteCarlo = (t1: t, t2: t): bool => + expectedConvolutionCost(t1) * expectedConvolutionCost(t2) < MagicNumbers.OpCost.monteCarloCost + + let preferConvolutionToMonteCarlo = (t1, t2, arithmeticOperation) => { + !hasSampleSetDist(t1, t2) && + Operation.Convolution.canDoAlgebraicOperation(arithmeticOperation) && + convolutionIsFasterThanMonteCarlo(t1, t2) + } + let run = (~t1: t, ~t2: t, ~arithmeticOperation): specificStrategy => { switch StrategyCallOnValidatedInputs.symbolic(arithmeticOperation, t1, t2) { | #AnalyticalSolution(_) | #Error(_) => #AsSymbolic | #NoSolution => - if Operation.Convolution.canDoAlgebraicOperation(arithmeticOperation) { - expectedConvolutionCost(t1) * expectedConvolutionCost(t2) > - MagicNumbers.OpCost.monteCarloCost - ? #AsMonteCarlo - : #AsConvolution - } else { - #AsMonteCarlo - } + preferConvolutionToMonteCarlo(t1, t2, arithmeticOperation) ? #AsConvolution : #AsMonteCarlo } } }