diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res index 810eb5cb..71cec481 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res @@ -191,10 +191,23 @@ module AlgebraicCombination = { | _ => 1000 } - let chooseConvolutionOrMonteCarlo = (t2: t, t1: t) => - expectedConvolutionCost(t1) * expectedConvolutionCost(t2) > 10000 - ? #CalculateWithMonteCarlo - : #CalculateWithConvolution + type calculationMethod = MonteCarlo | Convolution(Operation.convolutionOperation) + + let chooseConvolutionOrMonteCarlo = ( + op: Operation.algebraicOperation, + t2: t, + t1: t, + ): calculationMethod => + switch op { + | #Divide + | #Power + | #Logarithm => + MonteCarlo + | (#Add | #Subtract | #Multiply) as convOp => + expectedConvolutionCost(t1) * expectedConvolutionCost(t2) > 10000 + ? MonteCarlo + : Convolution(convOp) + } let run = ( t1: t, @@ -207,17 +220,10 @@ module AlgebraicCombination = { | Some(Ok(symbolicDist)) => Ok(Symbolic(symbolicDist)) | Some(Error(e)) => Error(Other(e)) | None => - switch arithmeticOperation { - | #Divide - | #Power - | #Logarithm => - runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2) - | (#Add | #Subtract | #Multiply) as op => - switch chooseConvolutionOrMonteCarlo(t1, t2) { - | #CalculateWithMonteCarlo => runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2) - | #CalculateWithConvolution => - runConvolution(toPointSetFn, op, t1, t2)->E.R2.fmap(r => DistributionTypes.PointSet(r)) - } + switch chooseConvolutionOrMonteCarlo(arithmeticOperation, t1, t2) { + | MonteCarlo => runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2) + | Convolution(convOp) => + runConvolution(toPointSetFn, convOp, t1, t2)->E.R2.fmap(r => DistributionTypes.PointSet(r)) } } }