From 6045fe5e6281f704e362bd70af413c7518a36cde Mon Sep 17 00:00:00 2001 From: Ozzie Gooen Date: Wed, 27 Apr 2022 12:48:46 -0400 Subject: [PATCH] Organized AlgebraicCombination functionality into submodules --- .../Distributions/GenericDist/GenericDist.res | 274 +++++++++--------- .../src/rescript/Utility/Operation.res | 12 + 2 files changed, 157 insertions(+), 129 deletions(-) diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res index b27c3c13..15d6333b 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res @@ -147,129 +147,136 @@ let truncate = Truncate.run TODO: It would be useful to be able to pass in a paramater to get this to run either with convolution or monte carlo. */ module AlgebraicCombination = { - let runConvolution = ( - toPointSet: toPointSetFn, - arithmeticOperation: Operation.convolutionOperation, - t1: t, - t2: t, - ) => - E.R.merge(toPointSet(t1), toPointSet(t2))->E.R2.fmap(((a, b)) => - PointSetDist.combineAlgebraically(arithmeticOperation, a, b) - ) - - let runMonteCarlo = ( - toSampleSet: toSampleSetFn, - arithmeticOperation: Operation.algebraicOperation, - t1: t, - t2: t, - ): result => { - let fn = Operation.Algebraic.toFn(arithmeticOperation) - E.R.merge(toSampleSet(t1), toSampleSet(t2)) - ->E.R.bind(((t1, t2)) => { - SampleSetDist.map2(~fn, ~t1, ~t2)->E.R2.errMap(x => DistributionTypes.OperationError(x)) - }) - ->E.R2.fmap(r => DistributionTypes.SampleSet(r)) - } - - /* + module InputValidator = { + /* It would be good to also do a check to make sure that probability mass for the second operand, at value 1.0, is 0 (or approximately 0). However, we'd ideally want to check that both the probability mass and the probability density are greater than zero. Right now we don't yet have a way of getting probability mass, so I'll leave this for later. */ - let getLogarithmInputError = (t1: t, t2: t, ~toPointSetFn: toPointSetFn): option => { - let firstOperandIsGreaterThanZero = - toFloatOperation(t1, ~toPointSetFn, ~distToFloatOperation=#Cdf(1e-10)) |> E.R.fmap(r => - r > 0. - ) - let secondOperandIsGreaterThanZero = - toFloatOperation(t2, ~toPointSetFn, ~distToFloatOperation=#Cdf(1e-10)) |> E.R.fmap(r => - r > 0. - ) - let items = E.A.R.firstErrorOrOpen([ - firstOperandIsGreaterThanZero, - secondOperandIsGreaterThanZero, - ]) - switch items { - | Error(r) => Some(r) - | Ok([true, _]) => - Some(LogarithmOfDistributionError("First input must completely greater than 0")) - | Ok([false, true]) => - Some(LogarithmOfDistributionError("Second input must completely greater than 0")) - | Ok([false, false]) => None - | Ok(_) => Some(Unreachable) + let getLogarithmInputError = (t1: t, t2: t, ~toPointSetFn: toPointSetFn): option => { + let firstOperandIsGreaterThanZero = + toFloatOperation(t1, ~toPointSetFn, ~distToFloatOperation=#Cdf(1e-10)) |> E.R.fmap(r => + r > 0. + ) + let secondOperandIsGreaterThanZero = + toFloatOperation(t2, ~toPointSetFn, ~distToFloatOperation=#Cdf(1e-10)) |> E.R.fmap(r => + r > 0. + ) + let items = E.A.R.firstErrorOrOpen([ + firstOperandIsGreaterThanZero, + secondOperandIsGreaterThanZero, + ]) + switch items { + | Error(r) => Some(r) + | Ok([true, _]) => + Some(LogarithmOfDistributionError("First input must completely greater than 0")) + | Ok([false, true]) => + Some(LogarithmOfDistributionError("Second input must completely greater than 0")) + | Ok([false, false]) => None + | Ok(_) => Some(Unreachable) + } + } + + let run = (t1: t, t2: t, ~toPointSetFn: toPointSetFn, ~arithmeticOperation): option => { + if arithmeticOperation == #Logarithm { + getLogarithmInputError(t1, t2, ~toPointSetFn) + } else { + None + } } } - let getInvalidOperationError = ( - t1: t, - t2: t, - ~toPointSetFn: toPointSetFn, + module StrategyCallOnValidatedInputs = { + let convolution = ( + toPointSet: toPointSetFn, + arithmeticOperation: Operation.convolutionOperation, + t1: t, + t2: t, + ): result => + E.R.merge(toPointSet(t1), toPointSet(t2)) + ->E.R2.fmap(((a, b)) => PointSetDist.combineAlgebraically(arithmeticOperation, a, b)) + ->E.R2.fmap(r => DistributionTypes.PointSet(r)) + + let monteCarlo = ( + toSampleSet: toSampleSetFn, + arithmeticOperation: Operation.algebraicOperation, + t1: t, + t2: t, + ): result => { + let fn = Operation.Algebraic.toFn(arithmeticOperation) + E.R.merge(toSampleSet(t1), toSampleSet(t2)) + ->E.R.bind(((t1, t2)) => { + SampleSetDist.map2(~fn, ~t1, ~t2)->E.R2.errMap(x => DistributionTypes.OperationError(x)) + }) + ->E.R2.fmap(r => DistributionTypes.SampleSet(r)) + } + + let symbolic = ( + arithmeticOperation: Operation.algebraicOperation, + t1: t, + t2: t, + ): SymbolicDistTypes.analyticalSimplificationResult => { + switch (t1, t2) { + | (DistributionTypes.Symbolic(d1), DistributionTypes.Symbolic(d2)) => + SymbolicDist.T.tryAnalyticalSimplification(d1, d2, arithmeticOperation) + | _ => #NoSolution + } + } + } + + module StrategyChooser = { + type specificStrategy = [#AsSymbolic | #AsMonteCarlo | #AsConvolution] + + //I'm (Ozzie) really just guessing here, very little idea what's best + let expectedConvolutionCost: t => int = x => + switch x { + | Symbolic(#Float(_)) => 1 + | Symbolic(_) => 1000 + | PointSet(Discrete(m)) => m.xyShape->XYShape.T.length + | PointSet(Mixed(_)) => 1000 + | PointSet(Continuous(_)) => 1000 + | _ => 1000 + } + + let run = (~t1: t, ~t2: t, ~arithmeticOperation): specificStrategy => { + switch StrategyCallOnValidatedInputs.symbolic(arithmeticOperation, t1, t2) { + | #AnalyticalSolution(_) + | #Error(_) => + #AsSymbolic + | #NoSolution => + if Operation.Convolution.canDoAlgebraicOperation(arithmeticOperation) { + expectedConvolutionCost(t1) * expectedConvolutionCost(t2) > 10000 + ? #AsMonteCarlo + : #AsConvolution + } else { + #AsMonteCarlo + } + } + } + } + + let runStrategyOnValidatedInputs = ( + ~t1: t, + ~t2: t, ~arithmeticOperation, - ): option => { - if arithmeticOperation == #Logarithm { - getLogarithmInputError(t1, t2, ~toPointSetFn) - } else { - None - } - } - - //I'm (Ozzie) really just guessing here, very little idea what's best - let expectedConvolutionCost: t => int = x => - switch x { - | Symbolic(#Float(_)) => 1 - | Symbolic(_) => 1000 - | PointSet(Discrete(m)) => m.xyShape->XYShape.T.length - | PointSet(Mixed(_)) => 1000 - | PointSet(Continuous(_)) => 1000 - | _ => 1000 - } - - type calculationStrategy = MonteCarloStrat | ConvolutionStrat(Operation.convolutionOperation) - - let chooseConvolutionOrMonteCarloDefault = ( - op: Operation.algebraicOperation, - t2: t, - t1: t, - ): calculationStrategy => - switch op { - | #Divide - | #Power - | #Logarithm => - MonteCarloStrat - | (#Add | #Subtract | #Multiply) as convOp => - expectedConvolutionCost(t1) * expectedConvolutionCost(t2) > 10000 - ? MonteCarloStrat - : ConvolutionStrat(convOp) - } - - let tryAnalyticalSimplification = ( - arithmeticOperation: Operation.algebraicOperation, - t1: t, - t2: t, - ): SymbolicDistTypes.analyticalSimplificationResult => { - switch (t1, t2) { - | (DistributionTypes.Symbolic(d1), DistributionTypes.Symbolic(d2)) => - SymbolicDist.T.tryAnalyticalSimplification(d1, d2, arithmeticOperation) - | _ => #NoSolution - } - } - - let runDefault = ( - t1: t, + ~strategy: StrategyChooser.specificStrategy, ~toPointSetFn: toPointSetFn, ~toSampleSetFn: toSampleSetFn, - ~arithmeticOperation, - ~t2: t, ): result => { - switch tryAnalyticalSimplification(arithmeticOperation, t1, t2) { - | #AnalyticalSolution(symbolicDist) => Ok(Symbolic(symbolicDist)) - | #Error(e) => Error(OperationError(e)) - | #NoSolution => - switch chooseConvolutionOrMonteCarloDefault(arithmeticOperation, t1, t2) { - | MonteCarloStrat => runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2) - | ConvolutionStrat(convOp) => - runConvolution(toPointSetFn, convOp, t1, t2)->E.R2.fmap(r => DistributionTypes.PointSet(r)) + switch strategy { + | #AsMonteCarlo => + StrategyCallOnValidatedInputs.monteCarlo(toSampleSetFn, arithmeticOperation, t1, t2) + | #AsSymbolic => + switch StrategyCallOnValidatedInputs.symbolic(arithmeticOperation, t1, t2) { + | #AnalyticalSolution(symbolicDist) => Ok(Symbolic(symbolicDist)) + | #Error(e) => Error(OperationError(e)) + | #NoSolution => Error(Unreachable) + } + | #AsConvolution => + switch Operation.Convolution.fromAlgebraicOperation(arithmeticOperation) { + | Some(convOp) => StrategyCallOnValidatedInputs.convolution(toPointSetFn, convOp, t1, t2) + | None => Error(Unreachable) } } } @@ -282,29 +289,38 @@ module AlgebraicCombination = { ~arithmeticOperation: Operation.algebraicOperation, ~t2: t, ): result => { - let invalidOperationError = getInvalidOperationError( - t1, - t2, - ~toPointSetFn, - ~arithmeticOperation, - ) - switch (invalidOperationError, strategy, arithmeticOperation) { - | (Some(e), _, _) => Error(e) - | (None, AsDefault, _) => - runDefault(t1, ~toPointSetFn, ~toSampleSetFn, ~arithmeticOperation, ~t2) - | (None, AsMonteCarlo, _) => runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2) - | (None, AsSymbolic, _) => - switch tryAnalyticalSimplification(arithmeticOperation, t1, t2) { + let invalidOperationError = InputValidator.run(t1, t2, ~arithmeticOperation, ~toPointSetFn) + switch (invalidOperationError, strategy) { + | (Some(e), _) => Error(e) + | (None, AsDefault) => { + let chooseStrategy = StrategyChooser.run(~arithmeticOperation, ~t1, ~t2) + runStrategyOnValidatedInputs( + ~t1, + ~t2, + ~strategy=chooseStrategy, + ~arithmeticOperation, + ~toPointSetFn, + ~toSampleSetFn, + ) + } + | (None, AsMonteCarlo) => + StrategyCallOnValidatedInputs.monteCarlo(toSampleSetFn, arithmeticOperation, t1, t2) + | (None, AsSymbolic) => + switch StrategyCallOnValidatedInputs.symbolic(arithmeticOperation, t1, t2) { | #AnalyticalSolution(symbolicDist) => Ok(Symbolic(symbolicDist)) | #NoSolution => Error(RequestedStrategyInvalidError(`No analytic solution for inputs`)) | #Error(err) => Error(OperationError(err)) } - | (None, AsConvolution, (#Divide | #Power | #Logarithm) as convOp) => { - let errString = `Can't convolve on ${Operation.Algebraic.toString(convOp)}` - Error(RequestedStrategyInvalidError(errString)) + | (None, AsConvolution) => + switch Operation.Convolution.fromAlgebraicOperation(arithmeticOperation) { + | None => { + let errString = `Convolution not supported for ${Operation.Algebraic.toString( + arithmeticOperation, + )}` + Error(RequestedStrategyInvalidError(errString)) + } + | Some(convOp) => StrategyCallOnValidatedInputs.convolution(toPointSetFn, convOp, t1, t2) } - | (None, AsConvolution, (#Add | #Subtract | #Multiply) as convOp) => - runConvolution(toPointSetFn, convOp, t1, t2)->E.R2.fmap(r => DistributionTypes.PointSet(r)) } } } diff --git a/packages/squiggle-lang/src/rescript/Utility/Operation.res b/packages/squiggle-lang/src/rescript/Utility/Operation.res index ac83ceea..4a1ef91a 100644 --- a/packages/squiggle-lang/src/rescript/Utility/Operation.res +++ b/packages/squiggle-lang/src/rescript/Utility/Operation.res @@ -29,6 +29,18 @@ type distToFloatOperation = [ module Convolution = { type t = convolutionOperation + //Only a selection of operations are supported by convolution. + let fromAlgebraicOperation = (op: algebraicOperation): option => + switch op { + | #Add => Some(#Add) + | #Subtract => Some(#Subtract) + | #Multiply => Some(#Multiply) + | #Divide | #Power | #Logarithm => None + } + + let canDoAlgebraicOperation = (op: algebraicOperation): bool => + fromAlgebraicOperation(op)->E.O.isSome + let toFn: (t, float, float) => float = x => switch x { | #Add => \"+."