Organized AlgebraicCombination functionality into submodules

This commit is contained in:
Ozzie Gooen 2022-04-27 12:48:46 -04:00
parent d104494f02
commit 6045fe5e62
2 changed files with 157 additions and 129 deletions

View File

@ -147,129 +147,136 @@ let truncate = Truncate.run
TODO: It would be useful to be able to pass in a paramater to get this to run either with convolution or monte carlo. TODO: It would be useful to be able to pass in a paramater to get this to run either with convolution or monte carlo.
*/ */
module AlgebraicCombination = { module AlgebraicCombination = {
let runConvolution = ( module InputValidator = {
toPointSet: toPointSetFn, /*
arithmeticOperation: Operation.convolutionOperation,
t1: t,
t2: t,
) =>
E.R.merge(toPointSet(t1), toPointSet(t2))->E.R2.fmap(((a, b)) =>
PointSetDist.combineAlgebraically(arithmeticOperation, a, b)
)
let runMonteCarlo = (
toSampleSet: toSampleSetFn,
arithmeticOperation: Operation.algebraicOperation,
t1: t,
t2: t,
): result<t, error> => {
let fn = Operation.Algebraic.toFn(arithmeticOperation)
E.R.merge(toSampleSet(t1), toSampleSet(t2))
->E.R.bind(((t1, t2)) => {
SampleSetDist.map2(~fn, ~t1, ~t2)->E.R2.errMap(x => DistributionTypes.OperationError(x))
})
->E.R2.fmap(r => DistributionTypes.SampleSet(r))
}
/*
It would be good to also do a check to make sure that probability mass for the second It would be good to also do a check to make sure that probability mass for the second
operand, at value 1.0, is 0 (or approximately 0). However, we'd ideally want to check operand, at value 1.0, is 0 (or approximately 0). However, we'd ideally want to check
that both the probability mass and the probability density are greater than zero. that both the probability mass and the probability density are greater than zero.
Right now we don't yet have a way of getting probability mass, so I'll leave this for later. Right now we don't yet have a way of getting probability mass, so I'll leave this for later.
*/ */
let getLogarithmInputError = (t1: t, t2: t, ~toPointSetFn: toPointSetFn): option<error> => { let getLogarithmInputError = (t1: t, t2: t, ~toPointSetFn: toPointSetFn): option<error> => {
let firstOperandIsGreaterThanZero = let firstOperandIsGreaterThanZero =
toFloatOperation(t1, ~toPointSetFn, ~distToFloatOperation=#Cdf(1e-10)) |> E.R.fmap(r => toFloatOperation(t1, ~toPointSetFn, ~distToFloatOperation=#Cdf(1e-10)) |> E.R.fmap(r =>
r > 0. r > 0.
) )
let secondOperandIsGreaterThanZero = let secondOperandIsGreaterThanZero =
toFloatOperation(t2, ~toPointSetFn, ~distToFloatOperation=#Cdf(1e-10)) |> E.R.fmap(r => toFloatOperation(t2, ~toPointSetFn, ~distToFloatOperation=#Cdf(1e-10)) |> E.R.fmap(r =>
r > 0. r > 0.
) )
let items = E.A.R.firstErrorOrOpen([ let items = E.A.R.firstErrorOrOpen([
firstOperandIsGreaterThanZero, firstOperandIsGreaterThanZero,
secondOperandIsGreaterThanZero, secondOperandIsGreaterThanZero,
]) ])
switch items { switch items {
| Error(r) => Some(r) | Error(r) => Some(r)
| Ok([true, _]) => | Ok([true, _]) =>
Some(LogarithmOfDistributionError("First input must completely greater than 0")) Some(LogarithmOfDistributionError("First input must completely greater than 0"))
| Ok([false, true]) => | Ok([false, true]) =>
Some(LogarithmOfDistributionError("Second input must completely greater than 0")) Some(LogarithmOfDistributionError("Second input must completely greater than 0"))
| Ok([false, false]) => None | Ok([false, false]) => None
| Ok(_) => Some(Unreachable) | Ok(_) => Some(Unreachable)
}
}
let run = (t1: t, t2: t, ~toPointSetFn: toPointSetFn, ~arithmeticOperation): option<error> => {
if arithmeticOperation == #Logarithm {
getLogarithmInputError(t1, t2, ~toPointSetFn)
} else {
None
}
} }
} }
let getInvalidOperationError = ( module StrategyCallOnValidatedInputs = {
t1: t, let convolution = (
t2: t, toPointSet: toPointSetFn,
~toPointSetFn: toPointSetFn, arithmeticOperation: Operation.convolutionOperation,
t1: t,
t2: t,
): result<t, error> =>
E.R.merge(toPointSet(t1), toPointSet(t2))
->E.R2.fmap(((a, b)) => PointSetDist.combineAlgebraically(arithmeticOperation, a, b))
->E.R2.fmap(r => DistributionTypes.PointSet(r))
let monteCarlo = (
toSampleSet: toSampleSetFn,
arithmeticOperation: Operation.algebraicOperation,
t1: t,
t2: t,
): result<t, error> => {
let fn = Operation.Algebraic.toFn(arithmeticOperation)
E.R.merge(toSampleSet(t1), toSampleSet(t2))
->E.R.bind(((t1, t2)) => {
SampleSetDist.map2(~fn, ~t1, ~t2)->E.R2.errMap(x => DistributionTypes.OperationError(x))
})
->E.R2.fmap(r => DistributionTypes.SampleSet(r))
}
let symbolic = (
arithmeticOperation: Operation.algebraicOperation,
t1: t,
t2: t,
): SymbolicDistTypes.analyticalSimplificationResult => {
switch (t1, t2) {
| (DistributionTypes.Symbolic(d1), DistributionTypes.Symbolic(d2)) =>
SymbolicDist.T.tryAnalyticalSimplification(d1, d2, arithmeticOperation)
| _ => #NoSolution
}
}
}
module StrategyChooser = {
type specificStrategy = [#AsSymbolic | #AsMonteCarlo | #AsConvolution]
//I'm (Ozzie) really just guessing here, very little idea what's best
let expectedConvolutionCost: t => int = x =>
switch x {
| Symbolic(#Float(_)) => 1
| Symbolic(_) => 1000
| PointSet(Discrete(m)) => m.xyShape->XYShape.T.length
| PointSet(Mixed(_)) => 1000
| PointSet(Continuous(_)) => 1000
| _ => 1000
}
let run = (~t1: t, ~t2: t, ~arithmeticOperation): specificStrategy => {
switch StrategyCallOnValidatedInputs.symbolic(arithmeticOperation, t1, t2) {
| #AnalyticalSolution(_)
| #Error(_) =>
#AsSymbolic
| #NoSolution =>
if Operation.Convolution.canDoAlgebraicOperation(arithmeticOperation) {
expectedConvolutionCost(t1) * expectedConvolutionCost(t2) > 10000
? #AsMonteCarlo
: #AsConvolution
} else {
#AsMonteCarlo
}
}
}
}
let runStrategyOnValidatedInputs = (
~t1: t,
~t2: t,
~arithmeticOperation, ~arithmeticOperation,
): option<error> => { ~strategy: StrategyChooser.specificStrategy,
if arithmeticOperation == #Logarithm {
getLogarithmInputError(t1, t2, ~toPointSetFn)
} else {
None
}
}
//I'm (Ozzie) really just guessing here, very little idea what's best
let expectedConvolutionCost: t => int = x =>
switch x {
| Symbolic(#Float(_)) => 1
| Symbolic(_) => 1000
| PointSet(Discrete(m)) => m.xyShape->XYShape.T.length
| PointSet(Mixed(_)) => 1000
| PointSet(Continuous(_)) => 1000
| _ => 1000
}
type calculationStrategy = MonteCarloStrat | ConvolutionStrat(Operation.convolutionOperation)
let chooseConvolutionOrMonteCarloDefault = (
op: Operation.algebraicOperation,
t2: t,
t1: t,
): calculationStrategy =>
switch op {
| #Divide
| #Power
| #Logarithm =>
MonteCarloStrat
| (#Add | #Subtract | #Multiply) as convOp =>
expectedConvolutionCost(t1) * expectedConvolutionCost(t2) > 10000
? MonteCarloStrat
: ConvolutionStrat(convOp)
}
let tryAnalyticalSimplification = (
arithmeticOperation: Operation.algebraicOperation,
t1: t,
t2: t,
): SymbolicDistTypes.analyticalSimplificationResult => {
switch (t1, t2) {
| (DistributionTypes.Symbolic(d1), DistributionTypes.Symbolic(d2)) =>
SymbolicDist.T.tryAnalyticalSimplification(d1, d2, arithmeticOperation)
| _ => #NoSolution
}
}
let runDefault = (
t1: t,
~toPointSetFn: toPointSetFn, ~toPointSetFn: toPointSetFn,
~toSampleSetFn: toSampleSetFn, ~toSampleSetFn: toSampleSetFn,
~arithmeticOperation,
~t2: t,
): result<t, error> => { ): result<t, error> => {
switch tryAnalyticalSimplification(arithmeticOperation, t1, t2) { switch strategy {
| #AnalyticalSolution(symbolicDist) => Ok(Symbolic(symbolicDist)) | #AsMonteCarlo =>
| #Error(e) => Error(OperationError(e)) StrategyCallOnValidatedInputs.monteCarlo(toSampleSetFn, arithmeticOperation, t1, t2)
| #NoSolution => | #AsSymbolic =>
switch chooseConvolutionOrMonteCarloDefault(arithmeticOperation, t1, t2) { switch StrategyCallOnValidatedInputs.symbolic(arithmeticOperation, t1, t2) {
| MonteCarloStrat => runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2) | #AnalyticalSolution(symbolicDist) => Ok(Symbolic(symbolicDist))
| ConvolutionStrat(convOp) => | #Error(e) => Error(OperationError(e))
runConvolution(toPointSetFn, convOp, t1, t2)->E.R2.fmap(r => DistributionTypes.PointSet(r)) | #NoSolution => Error(Unreachable)
}
| #AsConvolution =>
switch Operation.Convolution.fromAlgebraicOperation(arithmeticOperation) {
| Some(convOp) => StrategyCallOnValidatedInputs.convolution(toPointSetFn, convOp, t1, t2)
| None => Error(Unreachable)
} }
} }
} }
@ -282,29 +289,38 @@ module AlgebraicCombination = {
~arithmeticOperation: Operation.algebraicOperation, ~arithmeticOperation: Operation.algebraicOperation,
~t2: t, ~t2: t,
): result<t, error> => { ): result<t, error> => {
let invalidOperationError = getInvalidOperationError( let invalidOperationError = InputValidator.run(t1, t2, ~arithmeticOperation, ~toPointSetFn)
t1, switch (invalidOperationError, strategy) {
t2, | (Some(e), _) => Error(e)
~toPointSetFn, | (None, AsDefault) => {
~arithmeticOperation, let chooseStrategy = StrategyChooser.run(~arithmeticOperation, ~t1, ~t2)
) runStrategyOnValidatedInputs(
switch (invalidOperationError, strategy, arithmeticOperation) { ~t1,
| (Some(e), _, _) => Error(e) ~t2,
| (None, AsDefault, _) => ~strategy=chooseStrategy,
runDefault(t1, ~toPointSetFn, ~toSampleSetFn, ~arithmeticOperation, ~t2) ~arithmeticOperation,
| (None, AsMonteCarlo, _) => runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2) ~toPointSetFn,
| (None, AsSymbolic, _) => ~toSampleSetFn,
switch tryAnalyticalSimplification(arithmeticOperation, t1, t2) { )
}
| (None, AsMonteCarlo) =>
StrategyCallOnValidatedInputs.monteCarlo(toSampleSetFn, arithmeticOperation, t1, t2)
| (None, AsSymbolic) =>
switch StrategyCallOnValidatedInputs.symbolic(arithmeticOperation, t1, t2) {
| #AnalyticalSolution(symbolicDist) => Ok(Symbolic(symbolicDist)) | #AnalyticalSolution(symbolicDist) => Ok(Symbolic(symbolicDist))
| #NoSolution => Error(RequestedStrategyInvalidError(`No analytic solution for inputs`)) | #NoSolution => Error(RequestedStrategyInvalidError(`No analytic solution for inputs`))
| #Error(err) => Error(OperationError(err)) | #Error(err) => Error(OperationError(err))
} }
| (None, AsConvolution, (#Divide | #Power | #Logarithm) as convOp) => { | (None, AsConvolution) =>
let errString = `Can't convolve on ${Operation.Algebraic.toString(convOp)}` switch Operation.Convolution.fromAlgebraicOperation(arithmeticOperation) {
Error(RequestedStrategyInvalidError(errString)) | None => {
let errString = `Convolution not supported for ${Operation.Algebraic.toString(
arithmeticOperation,
)}`
Error(RequestedStrategyInvalidError(errString))
}
| Some(convOp) => StrategyCallOnValidatedInputs.convolution(toPointSetFn, convOp, t1, t2)
} }
| (None, AsConvolution, (#Add | #Subtract | #Multiply) as convOp) =>
runConvolution(toPointSetFn, convOp, t1, t2)->E.R2.fmap(r => DistributionTypes.PointSet(r))
} }
} }
} }

View File

@ -29,6 +29,18 @@ type distToFloatOperation = [
module Convolution = { module Convolution = {
type t = convolutionOperation type t = convolutionOperation
//Only a selection of operations are supported by convolution.
let fromAlgebraicOperation = (op: algebraicOperation): option<convolutionOperation> =>
switch op {
| #Add => Some(#Add)
| #Subtract => Some(#Subtract)
| #Multiply => Some(#Multiply)
| #Divide | #Power | #Logarithm => None
}
let canDoAlgebraicOperation = (op: algebraicOperation): bool =>
fromAlgebraicOperation(op)->E.O.isSome
let toFn: (t, float, float) => float = x => let toFn: (t, float, float) => float = x =>
switch x { switch x {
| #Add => \"+." | #Add => \"+."