Quick bug fixes for pow, log, and toPointSet
This commit is contained in:
parent
cf9c12f786
commit
81e478ba49
|
@ -83,7 +83,8 @@ let toPointSet = (
|
|||
pointSetDistLength: xyPointLength,
|
||||
kernelWidth: None,
|
||||
},
|
||||
)->GenericDist_Types.Error.resultStringToResultError
|
||||
)
|
||||
->GenericDist_Types.Error.resultStringToResultError
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -161,10 +162,12 @@ module AlgebraicCombination = {
|
|||
arithmeticOperation: GenericDist_Types.Operation.arithmeticOperation,
|
||||
t1: t,
|
||||
t2: t,
|
||||
) =>
|
||||
) => {
|
||||
let normalize = PointSetDist.T.normalize
|
||||
E.R.merge(toPointSet(t1), toPointSet(t2))->E.R2.fmap(((a, b)) =>
|
||||
PointSetDist.combineAlgebraically(arithmeticOperation, a, b)
|
||||
PointSetDist.combineAlgebraically(arithmeticOperation, normalize(a), normalize(b))->normalize
|
||||
)
|
||||
}
|
||||
|
||||
let runMonteCarlo = (
|
||||
toSampleSet: toSampleSetFn,
|
||||
|
@ -196,6 +199,50 @@ module AlgebraicCombination = {
|
|||
? #CalculateWithMonteCarlo
|
||||
: #CalculateWithConvolution
|
||||
|
||||
let getLogarithmInputError = (t1: t, t2: t, ~toPointSetFn: toPointSetFn): option<error> => {
|
||||
let firstOperandIsGreaterThanZero =
|
||||
toFloatOperation(t1, ~toPointSetFn, ~distToFloatOperation=#Cdf(1e-10)) |> E.R.fmap(r =>
|
||||
r > 0.
|
||||
)
|
||||
let secondOperandIsGreaterThanZero =
|
||||
toFloatOperation(t2, ~toPointSetFn, ~distToFloatOperation=#Cdf(1e-10)) |> E.R.fmap(r =>
|
||||
r > 0.
|
||||
)
|
||||
let secondOperandHasMassAt1 =
|
||||
toFloatOperation(t2, ~toPointSetFn, ~distToFloatOperation=#Pdf(1.0)) |> E.R.fmap(r =>
|
||||
r >= 1e-10
|
||||
)
|
||||
let items = E.A.R.firstErrorOrOpen([
|
||||
firstOperandIsGreaterThanZero,
|
||||
secondOperandIsGreaterThanZero,
|
||||
secondOperandHasMassAt1,
|
||||
])
|
||||
Js.log2("PMASS", toFloatOperation(t2, ~toPointSetFn, ~distToFloatOperation=#Pdf(1.0)))
|
||||
Js.log4("HIHI", items, t1, t2)
|
||||
switch items {
|
||||
| Error(r) => Some(r)
|
||||
| Ok([true, _, _]) => Some(Other("First input of logarithm must be fully greater than 0"))
|
||||
| Ok([false, true, _]) => Some(Other("Second input of logarithm must be fully greater than 0"))
|
||||
| Ok([false, false, true]) =>
|
||||
Some(Other("Second input of logarithm cannot have probability mass at 1.0"))
|
||||
| Ok([false, false, false]) => None
|
||||
| Ok(_) => Some(Unreachable)
|
||||
}
|
||||
}
|
||||
|
||||
let getInvalidOperationError = (
|
||||
t1: t,
|
||||
t2: t,
|
||||
~toPointSetFn: toPointSetFn,
|
||||
~arithmeticOperation,
|
||||
): option<error> => {
|
||||
if arithmeticOperation == #Logarithm {
|
||||
getLogarithmInputError(t1, t2, ~toPointSetFn)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
let run = (
|
||||
t1: t,
|
||||
~toPointSetFn: toPointSetFn,
|
||||
|
@ -206,6 +253,9 @@ module AlgebraicCombination = {
|
|||
switch tryAnalyticalSimplification(arithmeticOperation, t1, t2) {
|
||||
| Some(Ok(symbolicDist)) => Ok(Symbolic(symbolicDist))
|
||||
| Some(Error(e)) => Error(Other(e))
|
||||
| None =>
|
||||
switch getInvalidOperationError(t1, t2, ~toPointSetFn, ~arithmeticOperation) {
|
||||
| Some(e) => Error(e)
|
||||
| None =>
|
||||
switch chooseConvolutionOrMonteCarlo(t1, t2) {
|
||||
| #CalculateWithMonteCarlo => runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2)
|
||||
|
@ -215,7 +265,8 @@ module AlgebraicCombination = {
|
|||
arithmeticOperation,
|
||||
t1,
|
||||
t2,
|
||||
)->E.R2.fmap(r => DistributionTypes.PointSet(r))
|
||||
)->E.R2.fmap(r => DistributionTypes.PointSet(PointSetDist.T.normalize(r)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -93,7 +93,20 @@ module T = Dist({
|
|||
t,
|
||||
)
|
||||
|
||||
let normalize = fmap((Mixed.T.normalize, Discrete.T.normalize, Continuous.T.normalize))
|
||||
let integralEndY = mapToAll((
|
||||
Mixed.T.Integral.sum,
|
||||
Discrete.T.Integral.sum,
|
||||
Continuous.T.Integral.sum,
|
||||
))
|
||||
|
||||
let isNormalized = t => integralEndY(t) == 1.0
|
||||
|
||||
let normalize = (t: t): t =>
|
||||
if isNormalized(t) {
|
||||
t
|
||||
} else {
|
||||
t |> fmap((Mixed.T.normalize, Discrete.T.normalize, Continuous.T.normalize))
|
||||
}
|
||||
|
||||
let updateIntegralCache = (integralCache, t: t): t =>
|
||||
fmap(
|
||||
|
@ -124,11 +137,6 @@ module T = Dist({
|
|||
Discrete.T.Integral.get,
|
||||
Continuous.T.Integral.get,
|
||||
))
|
||||
let integralEndY = mapToAll((
|
||||
Mixed.T.Integral.sum,
|
||||
Discrete.T.Integral.sum,
|
||||
Continuous.T.Integral.sum,
|
||||
))
|
||||
let integralXtoY = f =>
|
||||
mapToAll((Mixed.T.Integral.xToY(f), Discrete.T.Integral.xToY(f), Continuous.T.Integral.xToY(f)))
|
||||
let integralYtoX = f =>
|
||||
|
|
|
@ -64,5 +64,10 @@ let sampleN = (t: t, n) => {
|
|||
//TODO: Figure out what to do if distributions are different lengths. ``zip`` is kind of inelegant for this.
|
||||
let map2 = (~fn: (float, float) => float, ~t1: t, ~t2: t) => {
|
||||
let samples = Belt.Array.zip(get(t1), get(t2))->E.A2.fmap(((a, b)) => fn(a, b))
|
||||
let has_invalid_results = Belt.Array.some(samples, a => Js.Float.isNaN(a))
|
||||
if has_invalid_results {
|
||||
Error("Distribution combination produced invalid results")
|
||||
} else {
|
||||
make(samples)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -133,9 +133,12 @@ let toPointSetDist = (
|
|||
~discrete=Some(discrete),
|
||||
)
|
||||
|
||||
//The latter doesn't always produce a normalized result, so we need to normalize it.
|
||||
let normalized = pointSetDist->E.O2.fmap(PointSetDist.T.normalize)
|
||||
|
||||
let samplesParse: Internals.Types.outputs = {
|
||||
continuousParseParams: pdf |> E.O.fmap(snd),
|
||||
pointSetDist: pointSetDist,
|
||||
pointSetDist: normalized,
|
||||
}
|
||||
|
||||
samplesParse
|
||||
|
|
Loading…
Reference in New Issue
Block a user