diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res index e94f8421..c16ea1b2 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res @@ -4,7 +4,7 @@ type genericDist = | SampleSet(SampleSetDist.t) | Symbolic(SymbolicDistTypes.symbolicDist) -type asAlgebraicCombinationStrategy = AsDefault | AsSymbolic | AsMontecarlo | AsConvolution +type asAlgebraicCombinationStrategy = AsDefault | AsSymbolic | AsMonteCarlo | AsConvolution @genType type error = @@ -38,7 +38,7 @@ module Error = { | OperationError(err) => Operation.Error.toString(err) | PointSetConversionError(err) => SampleSetDist.pointsetConversionErrorToString(err) | SparklineError(err) => PointSetTypes.sparklineErrorToString(err) - | RequestedStrategyInvalidError => `Requested mode invalid` + | RequestedStrategyInvalidError => `Requested strategy invalid` | OtherError(s) => s } diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res index c5308ad3..25169b2f 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res @@ -147,21 +147,6 @@ let truncate = Truncate.run TODO: It would be useful to be able to pass in a paramater to get this to run either with convolution or monte carlo. */ module AlgebraicCombination = { - let tryAnalyticalSimplification = ( - arithmeticOperation: Operation.algebraicOperation, - t1: t, - t2: t, - ): option> => - switch (arithmeticOperation, t1, t2) { - | (arithmeticOperation, Symbolic(d1), Symbolic(d2)) => - switch SymbolicDist.T.tryAnalyticalSimplification(d1, d2, arithmeticOperation) { - | #AnalyticalSolution(symbolicDist) => Some(Ok(symbolicDist)) - | #Error(er) => Some(Error(er)) - | #NoSolution => None - } - | _ => None - } - let runConvolution = ( toPointSet: toPointSetFn, arithmeticOperation: Operation.convolutionOperation, @@ -271,11 +256,63 @@ module AlgebraicCombination = { | #Divide | #Power | #Logarithm => Error(RequestedStrategyInvalidError) | (#Add | #Subtract | #Multiply) as convOp => Ok(Convolution(convOp)) } - | AsMontecarlo => Ok(MonteCarlo) + | AsMonteCarlo => Ok(MonteCarlo) | AsSymbolic => Error(RequestedStrategyInvalidError) } } + let tryAnalyticalSimplificationDefault = ( + arithmeticOperation: Operation.algebraicOperation, + t1: t, + t2: t, + ): option> => + switch (t1, t2) { + | (Symbolic(d1), Symbolic(d2)) => + switch SymbolicDist.T.tryAnalyticalSimplification(d1, d2, arithmeticOperation) { + | #AnalyticalSolution(symbolicDist) => Some(Ok(symbolicDist)) + | #Error(er) => Some(Error(er)) + | #NoSolution => None + } + | _ => None + } + + let tryAnalyticalSimplification = ( + arithmeticOperation: Operation.algebraicOperation, + t1: t, + t2: t, + ): option => { + switch (t1, t2) { + | (DistributionTypes.Symbolic(d1), DistributionTypes.Symbolic(d2)) => + Some(SymbolicDist.T.tryAnalyticalSimplification(d1, d2, arithmeticOperation)) + | _ => None + } + } + + let runDefault = ( + t1: t, + ~toPointSetFn: toPointSetFn, + ~toSampleSetFn: toSampleSetFn, + ~arithmeticOperation, + ~t2: t, + ): result => { + switch tryAnalyticalSimplificationDefault(arithmeticOperation, t1, t2) { + | Some(Ok(symbolicDist)) => Ok(Symbolic(symbolicDist)) + | Some(Error(e)) => Error(OperationError(e)) + | None => + switch getInvalidOperationError(t1, t2, ~toPointSetFn, ~arithmeticOperation) { + | Some(e) => Error(e) + | None => + switch chooseConvolutionOrMonteCarloDefault(arithmeticOperation, t1, t2) { + | MonteCarlo => runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2) + | Convolution(convOp) => + runConvolution(toPointSetFn, convOp, t1, t2)->E.R2.fmap(r => DistributionTypes.PointSet( + r, + )) + } + } + } + } + let run = ( ~strategy: DistributionTypes.asAlgebraicCombinationStrategy, t1: t, @@ -284,22 +321,24 @@ module AlgebraicCombination = { ~arithmeticOperation, ~t2: t, ): result => { - switch tryAnalyticalSimplification(arithmeticOperation, t1, t2) { - | Some(Ok(symbolicDist)) => Ok(Symbolic(symbolicDist)) - | Some(Error(e)) => Error(OperationError(e)) - | None => - switch getInvalidOperationError(t1, t2, ~toPointSetFn, ~arithmeticOperation) { - | Some(e) => Error(e) + switch strategy { + | AsDefault => runDefault(t1, ~toPointSetFn, ~toSampleSetFn, ~arithmeticOperation, ~t2) + | AsSymbolic => + switch tryAnalyticalSimplification(arithmeticOperation, t1, t2) { + | Some(#AnalyticalSolution(symbolicDist)) => Ok(Symbolic(symbolicDist)) + | Some(#NoSolution) | None => - switch chooseConvolutionOrMonteCarlo(~strat=strategy, arithmeticOperation, t1, t2) { - | Ok(MonteCarlo) => runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2) - | Ok(Convolution(convOp)) => - runConvolution(toPointSetFn, convOp, t1, t2)->E.R2.fmap(r => DistributionTypes.PointSet( - r, - )) - | Error(RequestedStrategyInvalidError) => Error(RequestedStrategyInvalidError) - | Error(err) => Error(err) - } + Error(RequestedStrategyInvalidError) + | Some(#Error(err)) => Error(OperationError(err)) + } + | AsConvolution + | AsMonteCarlo => + switch chooseConvolutionOrMonteCarlo(~strat=strategy, arithmeticOperation, t1, t2) { + | Ok(MonteCarlo) => runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2) + | Ok(Convolution(convOp)) => + runConvolution(toPointSetFn, convOp, t1, t2)->E.R2.fmap(r => DistributionTypes.PointSet(r)) + | Error(RequestedStrategyInvalidError) => Error(RequestedStrategyInvalidError) + | Error(err) => Error(err) } } }