diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res index c16ea1b2..e27a138d 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res @@ -16,7 +16,7 @@ type error = | OperationError(Operation.Error.t) | PointSetConversionError(SampleSetDist.pointsetConversionError) | SparklineError(PointSetTypes.sparklineError) // This type of error is for when we find a sparkline of a discrete distribution. This should probably at some point be actually implemented - | RequestedStrategyInvalidError + | RequestedStrategyInvalidError(string) | LogarithmOfDistributionError(string) | OtherError(string) @@ -38,7 +38,7 @@ module Error = { | OperationError(err) => Operation.Error.toString(err) | PointSetConversionError(err) => SampleSetDist.pointsetConversionErrorToString(err) | SparklineError(err) => PointSetTypes.sparklineErrorToString(err) - | RequestedStrategyInvalidError => `Requested strategy invalid` + | RequestedStrategyInvalidError(err) => `Requested strategy invalid: ${err}` | OtherError(s) => s } diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res index 25169b2f..a83bc8c2 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res @@ -225,55 +225,22 @@ module AlgebraicCombination = { | _ => 1000 } - type calculationMethod = MonteCarlo | Convolution(Operation.convolutionOperation) + type calculationStrategy = MonteCarloStrat | ConvolutionStrat(Operation.convolutionOperation) let chooseConvolutionOrMonteCarloDefault = ( op: Operation.algebraicOperation, t2: t, t1: t, - ): calculationMethod => + ): calculationStrategy => switch op { | #Divide | #Power | #Logarithm => - MonteCarlo + MonteCarloStrat | (#Add | #Subtract | #Multiply) as convOp => expectedConvolutionCost(t1) * expectedConvolutionCost(t2) > 10000 - ? MonteCarlo - : Convolution(convOp) - } - - let chooseConvolutionOrMonteCarlo = ( - ~strat: DistributionTypes.asAlgebraicCombinationStrategy, - op: Operation.algebraicOperation, - t2: t, - t1: t, - ): result => { - switch strat { - | AsDefault => Ok(chooseConvolutionOrMonteCarloDefault(op, t2, t1)) - | AsConvolution => - switch op { - | #Divide | #Power | #Logarithm => Error(RequestedStrategyInvalidError) - | (#Add | #Subtract | #Multiply) as convOp => Ok(Convolution(convOp)) - } - | AsMonteCarlo => Ok(MonteCarlo) - | AsSymbolic => Error(RequestedStrategyInvalidError) - } - } - - let tryAnalyticalSimplificationDefault = ( - arithmeticOperation: Operation.algebraicOperation, - t1: t, - t2: t, - ): option> => - switch (t1, t2) { - | (Symbolic(d1), Symbolic(d2)) => - switch SymbolicDist.T.tryAnalyticalSimplification(d1, d2, arithmeticOperation) { - | #AnalyticalSolution(symbolicDist) => Some(Ok(symbolicDist)) - | #Error(er) => Some(Error(er)) - | #NoSolution => None - } - | _ => None + ? MonteCarloStrat + : ConvolutionStrat(convOp) } let tryAnalyticalSimplification = ( @@ -295,16 +262,17 @@ module AlgebraicCombination = { ~arithmeticOperation, ~t2: t, ): result => { - switch tryAnalyticalSimplificationDefault(arithmeticOperation, t1, t2) { - | Some(Ok(symbolicDist)) => Ok(Symbolic(symbolicDist)) - | Some(Error(e)) => Error(OperationError(e)) + switch tryAnalyticalSimplification(arithmeticOperation, t1, t2) { + | Some(#AnalyticalSolution(symbolicDist)) => Ok(Symbolic(symbolicDist)) + | Some(#Error(e)) => Error(OperationError(e)) + | Some(#NoSolution) | None => switch getInvalidOperationError(t1, t2, ~toPointSetFn, ~arithmeticOperation) { | Some(e) => Error(e) | None => switch chooseConvolutionOrMonteCarloDefault(arithmeticOperation, t1, t2) { - | MonteCarlo => runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2) - | Convolution(convOp) => + | MonteCarloStrat => runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2) + | ConvolutionStrat(convOp) => runConvolution(toPointSetFn, convOp, t1, t2)->E.R2.fmap(r => DistributionTypes.PointSet( r, )) @@ -318,7 +286,7 @@ module AlgebraicCombination = { t1: t, ~toPointSetFn: toPointSetFn, ~toSampleSetFn: toSampleSetFn, - ~arithmeticOperation, + ~arithmeticOperation: Operation.algebraicOperation, ~t2: t, ): result => { switch strategy { @@ -326,20 +294,23 @@ module AlgebraicCombination = { | AsSymbolic => switch tryAnalyticalSimplification(arithmeticOperation, t1, t2) { | Some(#AnalyticalSolution(symbolicDist)) => Ok(Symbolic(symbolicDist)) - | Some(#NoSolution) - | None => - Error(RequestedStrategyInvalidError) + | Some(#NoSolution) => Error(RequestedStrategyInvalidError(`No analytical solution`)) + | None => Error(RequestedStrategyInvalidError("Inputs were not even symbolic")) | Some(#Error(err)) => Error(OperationError(err)) } - | AsConvolution - | AsMonteCarlo => - switch chooseConvolutionOrMonteCarlo(~strat=strategy, arithmeticOperation, t1, t2) { - | Ok(MonteCarlo) => runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2) - | Ok(Convolution(convOp)) => - runConvolution(toPointSetFn, convOp, t1, t2)->E.R2.fmap(r => DistributionTypes.PointSet(r)) - | Error(RequestedStrategyInvalidError) => Error(RequestedStrategyInvalidError) - | Error(err) => Error(err) + | AsConvolution => { + let errString = opString => `Can't convolve on ${opString}` + switch arithmeticOperation { + | (#Add | #Subtract | #Multiply) as convOp => + runConvolution(toPointSetFn, convOp, t1, t2)->E.R2.fmap(r => DistributionTypes.PointSet( + r, + )) + | #Divide => "divide"->errString->RequestedStrategyInvalidError->Error + | #Power => "power"->errString->RequestedStrategyInvalidError->Error + | #Logarithm => "logarithm"->errString->RequestedStrategyInvalidError->Error + } } + | AsMonteCarlo => runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2) } } }