Quick bug fixes for pow, log, and toPointSet

This commit is contained in:
Ozzie Gooen 2022-04-17 18:33:43 -04:00
parent cf9c12f786
commit 81e478ba49
4 changed files with 87 additions and 20 deletions

View File

@ -83,7 +83,8 @@ let toPointSet = (
pointSetDistLength: xyPointLength, pointSetDistLength: xyPointLength,
kernelWidth: None, kernelWidth: None,
}, },
)->GenericDist_Types.Error.resultStringToResultError )
->GenericDist_Types.Error.resultStringToResultError
} }
} }
@ -161,10 +162,12 @@ module AlgebraicCombination = {
arithmeticOperation: GenericDist_Types.Operation.arithmeticOperation, arithmeticOperation: GenericDist_Types.Operation.arithmeticOperation,
t1: t, t1: t,
t2: t, t2: t,
) => ) => {
let normalize = PointSetDist.T.normalize
E.R.merge(toPointSet(t1), toPointSet(t2))->E.R2.fmap(((a, b)) => E.R.merge(toPointSet(t1), toPointSet(t2))->E.R2.fmap(((a, b)) =>
PointSetDist.combineAlgebraically(arithmeticOperation, a, b) PointSetDist.combineAlgebraically(arithmeticOperation, normalize(a), normalize(b))->normalize
) )
}
let runMonteCarlo = ( let runMonteCarlo = (
toSampleSet: toSampleSetFn, toSampleSet: toSampleSetFn,
@ -196,6 +199,50 @@ module AlgebraicCombination = {
? #CalculateWithMonteCarlo ? #CalculateWithMonteCarlo
: #CalculateWithConvolution : #CalculateWithConvolution
let getLogarithmInputError = (t1: t, t2: t, ~toPointSetFn: toPointSetFn): option<error> => {
let firstOperandIsGreaterThanZero =
toFloatOperation(t1, ~toPointSetFn, ~distToFloatOperation=#Cdf(1e-10)) |> E.R.fmap(r =>
r > 0.
)
let secondOperandIsGreaterThanZero =
toFloatOperation(t2, ~toPointSetFn, ~distToFloatOperation=#Cdf(1e-10)) |> E.R.fmap(r =>
r > 0.
)
let secondOperandHasMassAt1 =
toFloatOperation(t2, ~toPointSetFn, ~distToFloatOperation=#Pdf(1.0)) |> E.R.fmap(r =>
r >= 1e-10
)
let items = E.A.R.firstErrorOrOpen([
firstOperandIsGreaterThanZero,
secondOperandIsGreaterThanZero,
secondOperandHasMassAt1,
])
Js.log2("PMASS", toFloatOperation(t2, ~toPointSetFn, ~distToFloatOperation=#Pdf(1.0)))
Js.log4("HIHI", items, t1, t2)
switch items {
| Error(r) => Some(r)
| Ok([true, _, _]) => Some(Other("First input of logarithm must be fully greater than 0"))
| Ok([false, true, _]) => Some(Other("Second input of logarithm must be fully greater than 0"))
| Ok([false, false, true]) =>
Some(Other("Second input of logarithm cannot have probability mass at 1.0"))
| Ok([false, false, false]) => None
| Ok(_) => Some(Unreachable)
}
}
let getInvalidOperationError = (
t1: t,
t2: t,
~toPointSetFn: toPointSetFn,
~arithmeticOperation,
): option<error> => {
if arithmeticOperation == #Logarithm {
getLogarithmInputError(t1, t2, ~toPointSetFn)
} else {
None
}
}
let run = ( let run = (
t1: t, t1: t,
~toPointSetFn: toPointSetFn, ~toPointSetFn: toPointSetFn,
@ -206,6 +253,9 @@ module AlgebraicCombination = {
switch tryAnalyticalSimplification(arithmeticOperation, t1, t2) { switch tryAnalyticalSimplification(arithmeticOperation, t1, t2) {
| Some(Ok(symbolicDist)) => Ok(Symbolic(symbolicDist)) | Some(Ok(symbolicDist)) => Ok(Symbolic(symbolicDist))
| Some(Error(e)) => Error(Other(e)) | Some(Error(e)) => Error(Other(e))
| None =>
switch getInvalidOperationError(t1, t2, ~toPointSetFn, ~arithmeticOperation) {
| Some(e) => Error(e)
| None => | None =>
switch chooseConvolutionOrMonteCarlo(t1, t2) { switch chooseConvolutionOrMonteCarlo(t1, t2) {
| #CalculateWithMonteCarlo => runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2) | #CalculateWithMonteCarlo => runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2)
@ -215,7 +265,8 @@ module AlgebraicCombination = {
arithmeticOperation, arithmeticOperation,
t1, t1,
t2, t2,
)->E.R2.fmap(r => DistributionTypes.PointSet(r)) )->E.R2.fmap(r => DistributionTypes.PointSet(PointSetDist.T.normalize(r)))
}
} }
} }
} }

View File

@ -93,7 +93,20 @@ module T = Dist({
t, t,
) )
let normalize = fmap((Mixed.T.normalize, Discrete.T.normalize, Continuous.T.normalize)) let integralEndY = mapToAll((
Mixed.T.Integral.sum,
Discrete.T.Integral.sum,
Continuous.T.Integral.sum,
))
let isNormalized = t => integralEndY(t) == 1.0
let normalize = (t: t): t =>
if isNormalized(t) {
t
} else {
t |> fmap((Mixed.T.normalize, Discrete.T.normalize, Continuous.T.normalize))
}
let updateIntegralCache = (integralCache, t: t): t => let updateIntegralCache = (integralCache, t: t): t =>
fmap( fmap(
@ -124,11 +137,6 @@ module T = Dist({
Discrete.T.Integral.get, Discrete.T.Integral.get,
Continuous.T.Integral.get, Continuous.T.Integral.get,
)) ))
let integralEndY = mapToAll((
Mixed.T.Integral.sum,
Discrete.T.Integral.sum,
Continuous.T.Integral.sum,
))
let integralXtoY = f => let integralXtoY = f =>
mapToAll((Mixed.T.Integral.xToY(f), Discrete.T.Integral.xToY(f), Continuous.T.Integral.xToY(f))) mapToAll((Mixed.T.Integral.xToY(f), Discrete.T.Integral.xToY(f), Continuous.T.Integral.xToY(f)))
let integralYtoX = f => let integralYtoX = f =>

View File

@ -64,5 +64,10 @@ let sampleN = (t: t, n) => {
//TODO: Figure out what to do if distributions are different lengths. ``zip`` is kind of inelegant for this. //TODO: Figure out what to do if distributions are different lengths. ``zip`` is kind of inelegant for this.
let map2 = (~fn: (float, float) => float, ~t1: t, ~t2: t) => { let map2 = (~fn: (float, float) => float, ~t1: t, ~t2: t) => {
let samples = Belt.Array.zip(get(t1), get(t2))->E.A2.fmap(((a, b)) => fn(a, b)) let samples = Belt.Array.zip(get(t1), get(t2))->E.A2.fmap(((a, b)) => fn(a, b))
let has_invalid_results = Belt.Array.some(samples, a => Js.Float.isNaN(a))
if has_invalid_results {
Error("Distribution combination produced invalid results")
} else {
make(samples) make(samples)
}
} }

View File

@ -133,9 +133,12 @@ let toPointSetDist = (
~discrete=Some(discrete), ~discrete=Some(discrete),
) )
//The latter doesn't always produce a normalized result, so we need to normalize it.
let normalized = pointSetDist->E.O2.fmap(PointSetDist.T.normalize)
let samplesParse: Internals.Types.outputs = { let samplesParse: Internals.Types.outputs = {
continuousParseParams: pdf |> E.O.fmap(snd), continuousParseParams: pdf |> E.O.fmap(snd),
pointSetDist: pointSetDist, pointSetDist: normalized,
} }
samplesParse samplesParse