Compare commits

...

3 Commits

Author SHA1 Message Date
Sam Nolan
9637d17099 Merge branch 'develop' into log-and-power-bug-fixes 2022-04-20 18:41:34 -04:00
Ozzie Gooen
41cacd2aae Minor refactors 2022-04-20 13:05:26 -04:00
Ozzie Gooen
81e478ba49 Quick bug fixes for pow, log, and toPointSet 2022-04-17 18:33:43 -04:00
5 changed files with 88 additions and 39 deletions

View File

@ -161,10 +161,12 @@ module AlgebraicCombination = {
arithmeticOperation: GenericDist_Types.Operation.arithmeticOperation,
t1: t,
t2: t,
) =>
) => {
let normalize = PointSetDist.T.normalize
E.R.merge(toPointSet(t1), toPointSet(t2))->E.R2.fmap(((a, b)) =>
PointSetDist.combineAlgebraically(arithmeticOperation, a, b)
PointSetDist.combineAlgebraically(arithmeticOperation, normalize(a), normalize(b))->normalize
)
}
let runMonteCarlo = (
toSampleSet: toSampleSetFn,
@ -196,6 +198,48 @@ module AlgebraicCombination = {
? #CalculateWithMonteCarlo
: #CalculateWithConvolution
let getLogarithmInputError = (t1: t, t2: t, ~toPointSetFn: toPointSetFn): option<error> => {
let firstOperandIsGreaterThanZero =
toFloatOperation(t1, ~toPointSetFn, ~distToFloatOperation=#Cdf(1e-10)) |> E.R.fmap(r =>
r > 0.
)
let secondOperandIsGreaterThanZero =
toFloatOperation(t2, ~toPointSetFn, ~distToFloatOperation=#Cdf(1e-10)) |> E.R.fmap(r =>
r > 0.
)
let secondOperandHasMassAt1 =
toFloatOperation(t2, ~toPointSetFn, ~distToFloatOperation=#Pdf(1.0)) |> E.R.fmap(r =>
r >= 1e-10
)
let items = E.A.R.firstErrorOrOpen([
firstOperandIsGreaterThanZero,
secondOperandIsGreaterThanZero,
secondOperandHasMassAt1,
])
switch items {
| Error(r) => Some(r)
| Ok([true, _, _]) => Some(Other("First input of logarithm must be fully greater than 0"))
| Ok([false, true, _]) => Some(Other("Second input of logarithm must be fully greater than 0"))
| Ok([false, false, true]) =>
Some(Other("Second input of logarithm cannot have probability mass at 1.0"))
| Ok([false, false, false]) => None
| Ok(_) => Some(Unreachable)
}
}
let getInvalidOperationError = (
t1: t,
t2: t,
~toPointSetFn: toPointSetFn,
~arithmeticOperation,
): option<error> => {
if arithmeticOperation == #Logarithm {
getLogarithmInputError(t1, t2, ~toPointSetFn)
} else {
None
}
}
let run = (
t1: t,
~toPointSetFn: toPointSetFn,
@ -207,15 +251,19 @@ module AlgebraicCombination = {
| Some(Ok(symbolicDist)) => Ok(Symbolic(symbolicDist))
| Some(Error(e)) => Error(Other(e))
| None =>
switch chooseConvolutionOrMonteCarlo(t1, t2) {
| #CalculateWithMonteCarlo => runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2)
| #CalculateWithConvolution =>
runConvolution(
toPointSetFn,
arithmeticOperation,
t1,
t2,
)->E.R2.fmap(r => DistributionTypes.PointSet(r))
switch getInvalidOperationError(t1, t2, ~toPointSetFn, ~arithmeticOperation) {
| Some(e) => Error(e)
| None =>
switch chooseConvolutionOrMonteCarlo(t1, t2) {
| #CalculateWithMonteCarlo => runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2)
| #CalculateWithConvolution =>
runConvolution(
toPointSetFn,
arithmeticOperation,
t1,
t2,
)->E.R2.fmap(r => DistributionTypes.PointSet(PointSetDist.T.normalize(r)))
}
}
}
}

View File

@ -28,24 +28,9 @@ module Operation = {
| Algebraic
| Pointwise
type arithmeticOperation = [
| #Add
| #Multiply
| #Subtract
| #Divide
| #Power
| #Logarithm
]
let arithmeticToFn = (arithmetic: arithmeticOperation) =>
switch arithmetic {
| #Add => \"+."
| #Multiply => \"*."
| #Subtract => \"-."
| #Power => \"**"
| #Divide => \"/."
| #Logarithm => (a, b) => log(a) /. log(b)
}
type arithmeticOperation = Operation.algebraicOperation
let arithmeticToFn = Operation.Algebraic.toFn
let arithmeticToString = Operation.Algebraic.toString
type toFloat = [
| #Cdf(float)
@ -105,8 +90,8 @@ module Operation = {
| ToString(ToString) => `toString`
| ToString(ToSparkline(n)) => `toSparkline(${E.I.toString(n)})`
| ToBool(IsNormalized) => `isNormalized`
| ToDistCombination(Algebraic, _, _) => `algebraic`
| ToDistCombination(Pointwise, _, _) => `pointwise`
| ToDistCombination(Algebraic, operation, _) => `algebraic-${arithmeticToString(operation)}`
| ToDistCombination(Pointwise, operation, _) => `pointwise-${arithmeticToString(operation)}`
}
let toString = (d: genericFunctionCallInfo): string =>

View File

@ -93,7 +93,20 @@ module T = Dist({
t,
)
let normalize = fmap((Mixed.T.normalize, Discrete.T.normalize, Continuous.T.normalize))
let integralEndY = mapToAll((
Mixed.T.Integral.sum,
Discrete.T.Integral.sum,
Continuous.T.Integral.sum,
))
let isNormalized = t => integralEndY(t) == 1.0
let normalize = (t: t): t =>
if isNormalized(t) {
t
} else {
t |> fmap((Mixed.T.normalize, Discrete.T.normalize, Continuous.T.normalize))
}
let updateIntegralCache = (integralCache, t: t): t =>
fmap(
@ -124,11 +137,6 @@ module T = Dist({
Discrete.T.Integral.get,
Continuous.T.Integral.get,
))
let integralEndY = mapToAll((
Mixed.T.Integral.sum,
Discrete.T.Integral.sum,
Continuous.T.Integral.sum,
))
let integralXtoY = f =>
mapToAll((Mixed.T.Integral.xToY(f), Discrete.T.Integral.xToY(f), Continuous.T.Integral.xToY(f)))
let integralYtoX = f =>

View File

@ -64,5 +64,10 @@ let sampleN = (t: t, n) => {
//TODO: Figure out what to do if distributions are different lengths. ``zip`` is kind of inelegant for this.
let map2 = (~fn: (float, float) => float, ~t1: t, ~t2: t) => {
let samples = Belt.Array.zip(get(t1), get(t2))->E.A2.fmap(((a, b)) => fn(a, b))
make(samples)
let has_invalid_results = Belt.Array.some(samples, a => Js.Float.isNaN(a))
if has_invalid_results {
Error("Distribution combination produced invalid results")
} else {
make(samples)
}
}

View File

@ -133,9 +133,12 @@ let toPointSetDist = (
~discrete=Some(discrete),
)
//The latter doesn't always produce a normalized result, so we need to normalize it.
let normalized = pointSetDist->E.O2.fmap(PointSetDist.T.normalize)
let samplesParse: Internals.Types.outputs = {
continuousParseParams: pdf |> E.O.fmap(snd),
pointSetDist: pointSetDist,
pointSetDist: normalized,
}
samplesParse