Merge pull request #402 from quantified-uncertainty/algebraic-combination-refactor
Algebraic combination refactor
This commit is contained in:
commit
20685ea8cb
|
@ -92,11 +92,11 @@ describe("eval on distribution functions", () => {
|
|||
testEval("log(2, uniform(5,8))", "Ok(Sample Set Distribution)")
|
||||
testEval(
|
||||
"log(normal(5,2), 3)",
|
||||
"Error(Distribution Math Error: Logarithm of input error: First input must completely greater than 0)",
|
||||
"Error(Distribution Math Error: Logarithm of input error: First input must be completely greater than 0)",
|
||||
)
|
||||
testEval(
|
||||
"log(normal(5,2), normal(10,1))",
|
||||
"Error(Distribution Math Error: Logarithm of input error: First input must completely greater than 0)",
|
||||
"Error(Distribution Math Error: Logarithm of input error: First input must be completely greater than 0)",
|
||||
)
|
||||
testEval("log(uniform(5,8))", "Ok(Sample Set Distribution)")
|
||||
testEval("log10(uniform(5,8))", "Ok(Sample Set Distribution)")
|
||||
|
|
|
@ -150,34 +150,9 @@ let truncate = Truncate.run
|
|||
of a new variable that is the result of the operation on A and B.
|
||||
For instance, normal(0, 1) + normal(1, 1) -> normal(1, 2).
|
||||
In general, this is implemented via convolution.
|
||||
|
||||
TODO: It would be useful to be able to pass in a paramater to get this to run either with convolution or monte carlo.
|
||||
*/
|
||||
module AlgebraicCombination = {
|
||||
let runConvolution = (
|
||||
toPointSet: toPointSetFn,
|
||||
arithmeticOperation: Operation.convolutionOperation,
|
||||
t1: t,
|
||||
t2: t,
|
||||
) =>
|
||||
E.R.merge(toPointSet(t1), toPointSet(t2))->E.R2.fmap(((a, b)) =>
|
||||
PointSetDist.combineAlgebraically(arithmeticOperation, a, b)
|
||||
)
|
||||
|
||||
let runMonteCarlo = (
|
||||
toSampleSet: toSampleSetFn,
|
||||
arithmeticOperation: Operation.algebraicOperation,
|
||||
t1: t,
|
||||
t2: t,
|
||||
): result<t, error> => {
|
||||
let fn = Operation.Algebraic.toFn(arithmeticOperation)
|
||||
E.R.merge(toSampleSet(t1), toSampleSet(t2))
|
||||
->E.R.bind(((t1, t2)) => {
|
||||
SampleSetDist.map2(~fn, ~t1, ~t2)->E.R2.errMap(x => DistributionTypes.OperationError(x))
|
||||
})
|
||||
->E.R2.fmap(r => DistributionTypes.SampleSet(r))
|
||||
}
|
||||
|
||||
module InputValidator = {
|
||||
/*
|
||||
It would be good to also do a check to make sure that probability mass for the second
|
||||
operand, at value 1.0, is 0 (or approximately 0). However, we'd ideally want to check
|
||||
|
@ -204,26 +179,63 @@ module AlgebraicCombination = {
|
|||
switch items {
|
||||
| Error(r) => Some(r)
|
||||
| Ok([true, _]) =>
|
||||
Some(LogarithmOfDistributionError("First input must completely greater than 0"))
|
||||
Some(LogarithmOfDistributionError("First input must be completely greater than 0"))
|
||||
| Ok([false, true]) =>
|
||||
Some(LogarithmOfDistributionError("Second input must completely greater than 0"))
|
||||
Some(LogarithmOfDistributionError("Second input must be completely greater than 0"))
|
||||
| Ok([false, false]) => None
|
||||
| Ok(_) => Some(Unreachable)
|
||||
}
|
||||
}
|
||||
|
||||
let getInvalidOperationError = (
|
||||
t1: t,
|
||||
t2: t,
|
||||
~toPointSetFn: toPointSetFn,
|
||||
~arithmeticOperation,
|
||||
): option<error> => {
|
||||
let run = (t1: t, t2: t, ~toPointSetFn: toPointSetFn, ~arithmeticOperation): option<error> => {
|
||||
if arithmeticOperation == #Logarithm {
|
||||
getLogarithmInputError(t1, t2, ~toPointSetFn)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
module StrategyCallOnValidatedInputs = {
|
||||
let convolution = (
|
||||
toPointSet: toPointSetFn,
|
||||
arithmeticOperation: Operation.convolutionOperation,
|
||||
t1: t,
|
||||
t2: t,
|
||||
): result<t, error> =>
|
||||
E.R.merge(toPointSet(t1), toPointSet(t2))
|
||||
->E.R2.fmap(((a, b)) => PointSetDist.combineAlgebraically(arithmeticOperation, a, b))
|
||||
->E.R2.fmap(r => DistributionTypes.PointSet(r))
|
||||
|
||||
let monteCarlo = (
|
||||
toSampleSet: toSampleSetFn,
|
||||
arithmeticOperation: Operation.algebraicOperation,
|
||||
t1: t,
|
||||
t2: t,
|
||||
): result<t, error> => {
|
||||
let fn = Operation.Algebraic.toFn(arithmeticOperation)
|
||||
E.R.merge(toSampleSet(t1), toSampleSet(t2))
|
||||
->E.R.bind(((t1, t2)) => {
|
||||
SampleSetDist.map2(~fn, ~t1, ~t2)->E.R2.errMap(x => DistributionTypes.OperationError(x))
|
||||
})
|
||||
->E.R2.fmap(r => DistributionTypes.SampleSet(r))
|
||||
}
|
||||
|
||||
let symbolic = (
|
||||
arithmeticOperation: Operation.algebraicOperation,
|
||||
t1: t,
|
||||
t2: t,
|
||||
): SymbolicDistTypes.analyticalSimplificationResult => {
|
||||
switch (t1, t2) {
|
||||
| (DistributionTypes.Symbolic(d1), DistributionTypes.Symbolic(d2)) =>
|
||||
SymbolicDist.T.tryAnalyticalSimplification(d1, d2, arithmeticOperation)
|
||||
| _ => #NoSolution
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
module StrategyChooser = {
|
||||
type specificStrategy = [#AsSymbolic | #AsMonteCarlo | #AsConvolution]
|
||||
|
||||
//I'm (Ozzie) really just guessing here, very little idea what's best
|
||||
let expectedConvolutionCost: t => int = x =>
|
||||
|
@ -236,58 +248,45 @@ module AlgebraicCombination = {
|
|||
| _ => MagicNumbers.OpCost.wildcardCost
|
||||
}
|
||||
|
||||
type calculationStrategy = MonteCarloStrat | ConvolutionStrat(Operation.convolutionOperation)
|
||||
|
||||
let chooseConvolutionOrMonteCarloDefault = (
|
||||
op: Operation.algebraicOperation,
|
||||
t2: t,
|
||||
t1: t,
|
||||
): calculationStrategy =>
|
||||
switch op {
|
||||
| #Divide
|
||||
| #Power
|
||||
| #Logarithm =>
|
||||
MonteCarloStrat
|
||||
| (#Add | #Subtract | #Multiply) as convOp =>
|
||||
expectedConvolutionCost(t1) * expectedConvolutionCost(t2) > MagicNumbers.OpCost.monteCarloCost
|
||||
? MonteCarloStrat
|
||||
: ConvolutionStrat(convOp)
|
||||
let run = (~t1: t, ~t2: t, ~arithmeticOperation): specificStrategy => {
|
||||
switch StrategyCallOnValidatedInputs.symbolic(arithmeticOperation, t1, t2) {
|
||||
| #AnalyticalSolution(_)
|
||||
| #Error(_) =>
|
||||
#AsSymbolic
|
||||
| #NoSolution =>
|
||||
if Operation.Convolution.canDoAlgebraicOperation(arithmeticOperation) {
|
||||
expectedConvolutionCost(t1) * expectedConvolutionCost(t2) >
|
||||
MagicNumbers.OpCost.monteCarloCost
|
||||
? #AsMonteCarlo
|
||||
: #AsConvolution
|
||||
} else {
|
||||
#AsMonteCarlo
|
||||
}
|
||||
}
|
||||
|
||||
let tryAnalyticalSimplification = (
|
||||
arithmeticOperation: Operation.algebraicOperation,
|
||||
t1: t,
|
||||
t2: t,
|
||||
): option<SymbolicDistTypes.analyticalSimplificationResult> => {
|
||||
switch (t1, t2) {
|
||||
| (DistributionTypes.Symbolic(d1), DistributionTypes.Symbolic(d2)) =>
|
||||
Some(SymbolicDist.T.tryAnalyticalSimplification(d1, d2, arithmeticOperation))
|
||||
| _ => None
|
||||
}
|
||||
}
|
||||
|
||||
let runDefault = (
|
||||
t1: t,
|
||||
let runStrategyOnValidatedInputs = (
|
||||
~t1: t,
|
||||
~t2: t,
|
||||
~arithmeticOperation,
|
||||
~strategy: StrategyChooser.specificStrategy,
|
||||
~toPointSetFn: toPointSetFn,
|
||||
~toSampleSetFn: toSampleSetFn,
|
||||
~arithmeticOperation,
|
||||
~t2: t,
|
||||
): result<t, error> => {
|
||||
switch tryAnalyticalSimplification(arithmeticOperation, t1, t2) {
|
||||
| Some(#AnalyticalSolution(symbolicDist)) => Ok(Symbolic(symbolicDist))
|
||||
| Some(#Error(e)) => Error(OperationError(e))
|
||||
| Some(#NoSolution)
|
||||
| None =>
|
||||
switch getInvalidOperationError(t1, t2, ~toPointSetFn, ~arithmeticOperation) {
|
||||
| Some(e) => Error(e)
|
||||
| None =>
|
||||
switch chooseConvolutionOrMonteCarloDefault(arithmeticOperation, t1, t2) {
|
||||
| MonteCarloStrat => runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2)
|
||||
| ConvolutionStrat(convOp) =>
|
||||
runConvolution(toPointSetFn, convOp, t1, t2)->E.R2.fmap(r => DistributionTypes.PointSet(
|
||||
r,
|
||||
))
|
||||
switch strategy {
|
||||
| #AsMonteCarlo =>
|
||||
StrategyCallOnValidatedInputs.monteCarlo(toSampleSetFn, arithmeticOperation, t1, t2)
|
||||
| #AsSymbolic =>
|
||||
switch StrategyCallOnValidatedInputs.symbolic(arithmeticOperation, t1, t2) {
|
||||
| #AnalyticalSolution(symbolicDist) => Ok(Symbolic(symbolicDist))
|
||||
| #Error(e) => Error(OperationError(e))
|
||||
| #NoSolution => Error(Unreachable)
|
||||
}
|
||||
| #AsConvolution =>
|
||||
switch Operation.Convolution.fromAlgebraicOperation(arithmeticOperation) {
|
||||
| Some(convOp) => StrategyCallOnValidatedInputs.convolution(toPointSetFn, convOp, t1, t2)
|
||||
| None => Error(Unreachable)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -300,27 +299,38 @@ module AlgebraicCombination = {
|
|||
~arithmeticOperation: Operation.algebraicOperation,
|
||||
~t2: t,
|
||||
): result<t, error> => {
|
||||
switch strategy {
|
||||
| AsDefault => runDefault(t1, ~toPointSetFn, ~toSampleSetFn, ~arithmeticOperation, ~t2)
|
||||
| AsSymbolic =>
|
||||
switch tryAnalyticalSimplification(arithmeticOperation, t1, t2) {
|
||||
| Some(#AnalyticalSolution(symbolicDist)) => Ok(Symbolic(symbolicDist))
|
||||
| Some(#NoSolution) => Error(RequestedStrategyInvalidError(`No analytical solution`))
|
||||
| None => Error(RequestedStrategyInvalidError("Inputs were not even symbolic"))
|
||||
| Some(#Error(err)) => Error(OperationError(err))
|
||||
let invalidOperationError = InputValidator.run(t1, t2, ~arithmeticOperation, ~toPointSetFn)
|
||||
switch (invalidOperationError, strategy) {
|
||||
| (Some(e), _) => Error(e)
|
||||
| (None, AsDefault) => {
|
||||
let chooseStrategy = StrategyChooser.run(~arithmeticOperation, ~t1, ~t2)
|
||||
runStrategyOnValidatedInputs(
|
||||
~t1,
|
||||
~t2,
|
||||
~strategy=chooseStrategy,
|
||||
~arithmeticOperation,
|
||||
~toPointSetFn,
|
||||
~toSampleSetFn,
|
||||
)
|
||||
}
|
||||
| AsConvolution => {
|
||||
let errString = opString => `Can't convolve on ${opString}`
|
||||
switch arithmeticOperation {
|
||||
| (#Add | #Subtract | #Multiply) as convOp =>
|
||||
runConvolution(toPointSetFn, convOp, t1, t2)->E.R2.fmap(r => DistributionTypes.PointSet(
|
||||
r,
|
||||
))
|
||||
| (#Divide | #Power | #Logarithm) as op =>
|
||||
op->Operation.Algebraic.toString->errString->RequestedStrategyInvalidError->Error
|
||||
| (None, AsMonteCarlo) =>
|
||||
StrategyCallOnValidatedInputs.monteCarlo(toSampleSetFn, arithmeticOperation, t1, t2)
|
||||
| (None, AsSymbolic) =>
|
||||
switch StrategyCallOnValidatedInputs.symbolic(arithmeticOperation, t1, t2) {
|
||||
| #AnalyticalSolution(symbolicDist) => Ok(Symbolic(symbolicDist))
|
||||
| #NoSolution => Error(RequestedStrategyInvalidError(`No analytic solution for inputs`))
|
||||
| #Error(err) => Error(OperationError(err))
|
||||
}
|
||||
| (None, AsConvolution) =>
|
||||
switch Operation.Convolution.fromAlgebraicOperation(arithmeticOperation) {
|
||||
| None => {
|
||||
let errString = `Convolution not supported for ${Operation.Algebraic.toString(
|
||||
arithmeticOperation,
|
||||
)}`
|
||||
Error(RequestedStrategyInvalidError(errString))
|
||||
}
|
||||
| Some(convOp) => StrategyCallOnValidatedInputs.convolution(toPointSetFn, convOp, t1, t2)
|
||||
}
|
||||
| AsMonteCarlo => runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -29,6 +29,18 @@ type distToFloatOperation = [
|
|||
|
||||
module Convolution = {
|
||||
type t = convolutionOperation
|
||||
//Only a selection of operations are supported by convolution.
|
||||
let fromAlgebraicOperation = (op: algebraicOperation): option<convolutionOperation> =>
|
||||
switch op {
|
||||
| #Add => Some(#Add)
|
||||
| #Subtract => Some(#Subtract)
|
||||
| #Multiply => Some(#Multiply)
|
||||
| #Divide | #Power | #Logarithm => None
|
||||
}
|
||||
|
||||
let canDoAlgebraicOperation = (op: algebraicOperation): bool =>
|
||||
fromAlgebraicOperation(op)->E.O.isSome
|
||||
|
||||
let toFn: (t, float, float) => float = x =>
|
||||
switch x {
|
||||
| #Add => \"+."
|
||||
|
|
Loading…
Reference in New Issue
Block a user