Merge pull request #402 from quantified-uncertainty/algebraic-combination-refactor
Algebraic combination refactor
This commit is contained in:
		
						commit
						20685ea8cb
					
				|  | @ -92,11 +92,11 @@ describe("eval on distribution functions", () => { | ||||||
|     testEval("log(2, uniform(5,8))", "Ok(Sample Set Distribution)") |     testEval("log(2, uniform(5,8))", "Ok(Sample Set Distribution)") | ||||||
|     testEval( |     testEval( | ||||||
|       "log(normal(5,2), 3)", |       "log(normal(5,2), 3)", | ||||||
|       "Error(Distribution Math Error: Logarithm of input error: First input must completely greater than 0)", |       "Error(Distribution Math Error: Logarithm of input error: First input must be completely greater than 0)", | ||||||
|     ) |     ) | ||||||
|     testEval( |     testEval( | ||||||
|       "log(normal(5,2), normal(10,1))", |       "log(normal(5,2), normal(10,1))", | ||||||
|       "Error(Distribution Math Error: Logarithm of input error: First input must completely greater than 0)", |       "Error(Distribution Math Error: Logarithm of input error: First input must be completely greater than 0)", | ||||||
|     ) |     ) | ||||||
|     testEval("log(uniform(5,8))", "Ok(Sample Set Distribution)") |     testEval("log(uniform(5,8))", "Ok(Sample Set Distribution)") | ||||||
|     testEval("log10(uniform(5,8))", "Ok(Sample Set Distribution)") |     testEval("log10(uniform(5,8))", "Ok(Sample Set Distribution)") | ||||||
|  |  | ||||||
|  | @ -150,144 +150,143 @@ let truncate = Truncate.run | ||||||
|    of a new variable that is the result of the operation on A and B. |    of a new variable that is the result of the operation on A and B. | ||||||
|    For instance, normal(0, 1) + normal(1, 1) -> normal(1, 2). |    For instance, normal(0, 1) + normal(1, 1) -> normal(1, 2). | ||||||
|    In general, this is implemented via convolution. |    In general, this is implemented via convolution. | ||||||
| 
 |  | ||||||
|   TODO: It would be useful to be able to pass in a paramater to get this to run either with convolution or monte carlo. |  | ||||||
| */ | */ | ||||||
| module AlgebraicCombination = { | module AlgebraicCombination = { | ||||||
|   let runConvolution = ( |   module InputValidator = { | ||||||
|     toPointSet: toPointSetFn, |     /* | ||||||
|     arithmeticOperation: Operation.convolutionOperation, |  | ||||||
|     t1: t, |  | ||||||
|     t2: t, |  | ||||||
|   ) => |  | ||||||
|     E.R.merge(toPointSet(t1), toPointSet(t2))->E.R2.fmap(((a, b)) => |  | ||||||
|       PointSetDist.combineAlgebraically(arithmeticOperation, a, b) |  | ||||||
|     ) |  | ||||||
| 
 |  | ||||||
|   let runMonteCarlo = ( |  | ||||||
|     toSampleSet: toSampleSetFn, |  | ||||||
|     arithmeticOperation: Operation.algebraicOperation, |  | ||||||
|     t1: t, |  | ||||||
|     t2: t, |  | ||||||
|   ): result<t, error> => { |  | ||||||
|     let fn = Operation.Algebraic.toFn(arithmeticOperation) |  | ||||||
|     E.R.merge(toSampleSet(t1), toSampleSet(t2)) |  | ||||||
|     ->E.R.bind(((t1, t2)) => { |  | ||||||
|       SampleSetDist.map2(~fn, ~t1, ~t2)->E.R2.errMap(x => DistributionTypes.OperationError(x)) |  | ||||||
|     }) |  | ||||||
|     ->E.R2.fmap(r => DistributionTypes.SampleSet(r)) |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   /* |  | ||||||
|      It would be good to also do a check to make sure that probability mass for the second |      It would be good to also do a check to make sure that probability mass for the second | ||||||
|      operand, at value 1.0, is 0 (or approximately 0). However, we'd ideally want to check  |      operand, at value 1.0, is 0 (or approximately 0). However, we'd ideally want to check  | ||||||
|      that both the probability mass and the probability density are greater than zero. |      that both the probability mass and the probability density are greater than zero. | ||||||
|      Right now we don't yet have a way of getting probability mass, so I'll leave this for later. |      Right now we don't yet have a way of getting probability mass, so I'll leave this for later. | ||||||
|  */ |  */ | ||||||
|   let getLogarithmInputError = (t1: t, t2: t, ~toPointSetFn: toPointSetFn): option<error> => { |     let getLogarithmInputError = (t1: t, t2: t, ~toPointSetFn: toPointSetFn): option<error> => { | ||||||
|     let firstOperandIsGreaterThanZero = |       let firstOperandIsGreaterThanZero = | ||||||
|       toFloatOperation( |         toFloatOperation( | ||||||
|         t1, |           t1, | ||||||
|         ~toPointSetFn, |           ~toPointSetFn, | ||||||
|         ~distToFloatOperation=#Cdf(MagicNumbers.Epsilon.ten), |           ~distToFloatOperation=#Cdf(MagicNumbers.Epsilon.ten), | ||||||
|       ) |> E.R.fmap(r => r > 0.) |         ) |> E.R.fmap(r => r > 0.) | ||||||
|     let secondOperandIsGreaterThanZero = |       let secondOperandIsGreaterThanZero = | ||||||
|       toFloatOperation( |         toFloatOperation( | ||||||
|         t2, |           t2, | ||||||
|         ~toPointSetFn, |           ~toPointSetFn, | ||||||
|         ~distToFloatOperation=#Cdf(MagicNumbers.Epsilon.ten), |           ~distToFloatOperation=#Cdf(MagicNumbers.Epsilon.ten), | ||||||
|       ) |> E.R.fmap(r => r > 0.) |         ) |> E.R.fmap(r => r > 0.) | ||||||
|     let items = E.A.R.firstErrorOrOpen([ |       let items = E.A.R.firstErrorOrOpen([ | ||||||
|       firstOperandIsGreaterThanZero, |         firstOperandIsGreaterThanZero, | ||||||
|       secondOperandIsGreaterThanZero, |         secondOperandIsGreaterThanZero, | ||||||
|     ]) |       ]) | ||||||
|     switch items { |       switch items { | ||||||
|     | Error(r) => Some(r) |       | Error(r) => Some(r) | ||||||
|     | Ok([true, _]) => |       | Ok([true, _]) => | ||||||
|       Some(LogarithmOfDistributionError("First input must completely greater than 0")) |         Some(LogarithmOfDistributionError("First input must be completely greater than 0")) | ||||||
|     | Ok([false, true]) => |       | Ok([false, true]) => | ||||||
|       Some(LogarithmOfDistributionError("Second input must completely greater than 0")) |         Some(LogarithmOfDistributionError("Second input must be completely greater than 0")) | ||||||
|     | Ok([false, false]) => None |       | Ok([false, false]) => None | ||||||
|     | Ok(_) => Some(Unreachable) |       | Ok(_) => Some(Unreachable) | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     let run = (t1: t, t2: t, ~toPointSetFn: toPointSetFn, ~arithmeticOperation): option<error> => { | ||||||
|  |       if arithmeticOperation == #Logarithm { | ||||||
|  |         getLogarithmInputError(t1, t2, ~toPointSetFn) | ||||||
|  |       } else { | ||||||
|  |         None | ||||||
|  |       } | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   let getInvalidOperationError = ( |   module StrategyCallOnValidatedInputs = { | ||||||
|     t1: t, |     let convolution = ( | ||||||
|     t2: t, |       toPointSet: toPointSetFn, | ||||||
|     ~toPointSetFn: toPointSetFn, |       arithmeticOperation: Operation.convolutionOperation, | ||||||
|  |       t1: t, | ||||||
|  |       t2: t, | ||||||
|  |     ): result<t, error> => | ||||||
|  |       E.R.merge(toPointSet(t1), toPointSet(t2)) | ||||||
|  |       ->E.R2.fmap(((a, b)) => PointSetDist.combineAlgebraically(arithmeticOperation, a, b)) | ||||||
|  |       ->E.R2.fmap(r => DistributionTypes.PointSet(r)) | ||||||
|  | 
 | ||||||
|  |     let monteCarlo = ( | ||||||
|  |       toSampleSet: toSampleSetFn, | ||||||
|  |       arithmeticOperation: Operation.algebraicOperation, | ||||||
|  |       t1: t, | ||||||
|  |       t2: t, | ||||||
|  |     ): result<t, error> => { | ||||||
|  |       let fn = Operation.Algebraic.toFn(arithmeticOperation) | ||||||
|  |       E.R.merge(toSampleSet(t1), toSampleSet(t2)) | ||||||
|  |       ->E.R.bind(((t1, t2)) => { | ||||||
|  |         SampleSetDist.map2(~fn, ~t1, ~t2)->E.R2.errMap(x => DistributionTypes.OperationError(x)) | ||||||
|  |       }) | ||||||
|  |       ->E.R2.fmap(r => DistributionTypes.SampleSet(r)) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     let symbolic = ( | ||||||
|  |       arithmeticOperation: Operation.algebraicOperation, | ||||||
|  |       t1: t, | ||||||
|  |       t2: t, | ||||||
|  |     ): SymbolicDistTypes.analyticalSimplificationResult => { | ||||||
|  |       switch (t1, t2) { | ||||||
|  |       | (DistributionTypes.Symbolic(d1), DistributionTypes.Symbolic(d2)) => | ||||||
|  |         SymbolicDist.T.tryAnalyticalSimplification(d1, d2, arithmeticOperation) | ||||||
|  |       | _ => #NoSolution | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   module StrategyChooser = { | ||||||
|  |     type specificStrategy = [#AsSymbolic | #AsMonteCarlo | #AsConvolution] | ||||||
|  | 
 | ||||||
|  |     //I'm (Ozzie) really just guessing here, very little idea what's best | ||||||
|  |     let expectedConvolutionCost: t => int = x => | ||||||
|  |       switch x { | ||||||
|  |       | Symbolic(#Float(_)) => MagicNumbers.OpCost.floatCost | ||||||
|  |       | Symbolic(_) => MagicNumbers.OpCost.symbolicCost | ||||||
|  |       | PointSet(Discrete(m)) => m.xyShape->XYShape.T.length | ||||||
|  |       | PointSet(Mixed(_)) => MagicNumbers.OpCost.mixedCost | ||||||
|  |       | PointSet(Continuous(_)) => MagicNumbers.OpCost.continuousCost | ||||||
|  |       | _ => MagicNumbers.OpCost.wildcardCost | ||||||
|  |       } | ||||||
|  | 
 | ||||||
|  |     let run = (~t1: t, ~t2: t, ~arithmeticOperation): specificStrategy => { | ||||||
|  |       switch StrategyCallOnValidatedInputs.symbolic(arithmeticOperation, t1, t2) { | ||||||
|  |       | #AnalyticalSolution(_) | ||||||
|  |       | #Error(_) => | ||||||
|  |         #AsSymbolic | ||||||
|  |       | #NoSolution => | ||||||
|  |         if Operation.Convolution.canDoAlgebraicOperation(arithmeticOperation) { | ||||||
|  |           expectedConvolutionCost(t1) * expectedConvolutionCost(t2) > | ||||||
|  |             MagicNumbers.OpCost.monteCarloCost | ||||||
|  |             ? #AsMonteCarlo | ||||||
|  |             : #AsConvolution | ||||||
|  |         } else { | ||||||
|  |           #AsMonteCarlo | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   let runStrategyOnValidatedInputs = ( | ||||||
|  |     ~t1: t, | ||||||
|  |     ~t2: t, | ||||||
|     ~arithmeticOperation, |     ~arithmeticOperation, | ||||||
|   ): option<error> => { |     ~strategy: StrategyChooser.specificStrategy, | ||||||
|     if arithmeticOperation == #Logarithm { |  | ||||||
|       getLogarithmInputError(t1, t2, ~toPointSetFn) |  | ||||||
|     } else { |  | ||||||
|       None |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   //I'm (Ozzie) really just guessing here, very little idea what's best |  | ||||||
|   let expectedConvolutionCost: t => int = x => |  | ||||||
|     switch x { |  | ||||||
|     | Symbolic(#Float(_)) => MagicNumbers.OpCost.floatCost |  | ||||||
|     | Symbolic(_) => MagicNumbers.OpCost.symbolicCost |  | ||||||
|     | PointSet(Discrete(m)) => m.xyShape->XYShape.T.length |  | ||||||
|     | PointSet(Mixed(_)) => MagicNumbers.OpCost.mixedCost |  | ||||||
|     | PointSet(Continuous(_)) => MagicNumbers.OpCost.continuousCost |  | ||||||
|     | _ => MagicNumbers.OpCost.wildcardCost |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|   type calculationStrategy = MonteCarloStrat | ConvolutionStrat(Operation.convolutionOperation) |  | ||||||
| 
 |  | ||||||
|   let chooseConvolutionOrMonteCarloDefault = ( |  | ||||||
|     op: Operation.algebraicOperation, |  | ||||||
|     t2: t, |  | ||||||
|     t1: t, |  | ||||||
|   ): calculationStrategy => |  | ||||||
|     switch op { |  | ||||||
|     | #Divide |  | ||||||
|     | #Power |  | ||||||
|     | #Logarithm => |  | ||||||
|       MonteCarloStrat |  | ||||||
|     | (#Add | #Subtract | #Multiply) as convOp => |  | ||||||
|       expectedConvolutionCost(t1) * expectedConvolutionCost(t2) > MagicNumbers.OpCost.monteCarloCost |  | ||||||
|         ? MonteCarloStrat |  | ||||||
|         : ConvolutionStrat(convOp) |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|   let tryAnalyticalSimplification = ( |  | ||||||
|     arithmeticOperation: Operation.algebraicOperation, |  | ||||||
|     t1: t, |  | ||||||
|     t2: t, |  | ||||||
|   ): option<SymbolicDistTypes.analyticalSimplificationResult> => { |  | ||||||
|     switch (t1, t2) { |  | ||||||
|     | (DistributionTypes.Symbolic(d1), DistributionTypes.Symbolic(d2)) => |  | ||||||
|       Some(SymbolicDist.T.tryAnalyticalSimplification(d1, d2, arithmeticOperation)) |  | ||||||
|     | _ => None |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   let runDefault = ( |  | ||||||
|     t1: t, |  | ||||||
|     ~toPointSetFn: toPointSetFn, |     ~toPointSetFn: toPointSetFn, | ||||||
|     ~toSampleSetFn: toSampleSetFn, |     ~toSampleSetFn: toSampleSetFn, | ||||||
|     ~arithmeticOperation, |  | ||||||
|     ~t2: t, |  | ||||||
|   ): result<t, error> => { |   ): result<t, error> => { | ||||||
|     switch tryAnalyticalSimplification(arithmeticOperation, t1, t2) { |     switch strategy { | ||||||
|     | Some(#AnalyticalSolution(symbolicDist)) => Ok(Symbolic(symbolicDist)) |     | #AsMonteCarlo => | ||||||
|     | Some(#Error(e)) => Error(OperationError(e)) |       StrategyCallOnValidatedInputs.monteCarlo(toSampleSetFn, arithmeticOperation, t1, t2) | ||||||
|     | Some(#NoSolution) |     | #AsSymbolic => | ||||||
|     | None => |       switch StrategyCallOnValidatedInputs.symbolic(arithmeticOperation, t1, t2) { | ||||||
|       switch getInvalidOperationError(t1, t2, ~toPointSetFn, ~arithmeticOperation) { |       | #AnalyticalSolution(symbolicDist) => Ok(Symbolic(symbolicDist)) | ||||||
|       | Some(e) => Error(e) |       | #Error(e) => Error(OperationError(e)) | ||||||
|       | None => |       | #NoSolution => Error(Unreachable) | ||||||
|         switch chooseConvolutionOrMonteCarloDefault(arithmeticOperation, t1, t2) { |       } | ||||||
|         | MonteCarloStrat => runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2) |     | #AsConvolution => | ||||||
|         | ConvolutionStrat(convOp) => |       switch Operation.Convolution.fromAlgebraicOperation(arithmeticOperation) { | ||||||
|           runConvolution(toPointSetFn, convOp, t1, t2)->E.R2.fmap(r => DistributionTypes.PointSet( |       | Some(convOp) => StrategyCallOnValidatedInputs.convolution(toPointSetFn, convOp, t1, t2) | ||||||
|             r, |       | None => Error(Unreachable) | ||||||
|           )) |  | ||||||
|         } |  | ||||||
|       } |       } | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|  | @ -300,27 +299,38 @@ module AlgebraicCombination = { | ||||||
|     ~arithmeticOperation: Operation.algebraicOperation, |     ~arithmeticOperation: Operation.algebraicOperation, | ||||||
|     ~t2: t, |     ~t2: t, | ||||||
|   ): result<t, error> => { |   ): result<t, error> => { | ||||||
|     switch strategy { |     let invalidOperationError = InputValidator.run(t1, t2, ~arithmeticOperation, ~toPointSetFn) | ||||||
|     | AsDefault => runDefault(t1, ~toPointSetFn, ~toSampleSetFn, ~arithmeticOperation, ~t2) |     switch (invalidOperationError, strategy) { | ||||||
|     | AsSymbolic => |     | (Some(e), _) => Error(e) | ||||||
|       switch tryAnalyticalSimplification(arithmeticOperation, t1, t2) { |     | (None, AsDefault) => { | ||||||
|       | Some(#AnalyticalSolution(symbolicDist)) => Ok(Symbolic(symbolicDist)) |         let chooseStrategy = StrategyChooser.run(~arithmeticOperation, ~t1, ~t2) | ||||||
|       | Some(#NoSolution) => Error(RequestedStrategyInvalidError(`No analytical solution`)) |         runStrategyOnValidatedInputs( | ||||||
|       | None => Error(RequestedStrategyInvalidError("Inputs were not even symbolic")) |           ~t1, | ||||||
|       | Some(#Error(err)) => Error(OperationError(err)) |           ~t2, | ||||||
|  |           ~strategy=chooseStrategy, | ||||||
|  |           ~arithmeticOperation, | ||||||
|  |           ~toPointSetFn, | ||||||
|  |           ~toSampleSetFn, | ||||||
|  |         ) | ||||||
|       } |       } | ||||||
|     | AsConvolution => { |     | (None, AsMonteCarlo) => | ||||||
|         let errString = opString => `Can't convolve on ${opString}` |       StrategyCallOnValidatedInputs.monteCarlo(toSampleSetFn, arithmeticOperation, t1, t2) | ||||||
|         switch arithmeticOperation { |     | (None, AsSymbolic) => | ||||||
|         | (#Add | #Subtract | #Multiply) as convOp => |       switch StrategyCallOnValidatedInputs.symbolic(arithmeticOperation, t1, t2) { | ||||||
|           runConvolution(toPointSetFn, convOp, t1, t2)->E.R2.fmap(r => DistributionTypes.PointSet( |       | #AnalyticalSolution(symbolicDist) => Ok(Symbolic(symbolicDist)) | ||||||
|             r, |       | #NoSolution => Error(RequestedStrategyInvalidError(`No analytic solution for inputs`)) | ||||||
|           )) |       | #Error(err) => Error(OperationError(err)) | ||||||
|         | (#Divide | #Power | #Logarithm) as op => |       } | ||||||
|           op->Operation.Algebraic.toString->errString->RequestedStrategyInvalidError->Error |     | (None, AsConvolution) => | ||||||
|  |       switch Operation.Convolution.fromAlgebraicOperation(arithmeticOperation) { | ||||||
|  |       | None => { | ||||||
|  |           let errString = `Convolution not supported for ${Operation.Algebraic.toString( | ||||||
|  |               arithmeticOperation, | ||||||
|  |             )}` | ||||||
|  |           Error(RequestedStrategyInvalidError(errString)) | ||||||
|         } |         } | ||||||
|  |       | Some(convOp) => StrategyCallOnValidatedInputs.convolution(toPointSetFn, convOp, t1, t2) | ||||||
|       } |       } | ||||||
|     | AsMonteCarlo => runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2) |  | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -29,6 +29,18 @@ type distToFloatOperation = [ | ||||||
| 
 | 
 | ||||||
| module Convolution = { | module Convolution = { | ||||||
|   type t = convolutionOperation |   type t = convolutionOperation | ||||||
|  |   //Only a selection of operations are supported by convolution. | ||||||
|  |   let fromAlgebraicOperation = (op: algebraicOperation): option<convolutionOperation> => | ||||||
|  |     switch op { | ||||||
|  |     | #Add => Some(#Add) | ||||||
|  |     | #Subtract => Some(#Subtract) | ||||||
|  |     | #Multiply => Some(#Multiply) | ||||||
|  |     | #Divide | #Power | #Logarithm => None | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |   let canDoAlgebraicOperation = (op: algebraicOperation): bool => | ||||||
|  |     fromAlgebraicOperation(op)->E.O.isSome | ||||||
|  | 
 | ||||||
|   let toFn: (t, float, float) => float = x => |   let toFn: (t, float, float) => float = x => | ||||||
|     switch x { |     switch x { | ||||||
|     | #Add => \"+." |     | #Add => \"+." | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user