slghtly more robust solution to mixed kldivergences (and removed a

warning)

Value: [1e-5 to 1e-2]
This commit is contained in:
Quinn Dougherty 2022-05-25 13:51:01 -04:00
parent a266b8ed09
commit 3aaad14f11
2 changed files with 63 additions and 29 deletions

View File

@ -201,7 +201,12 @@ module T = Dist({
})
let logScore = (args: PointSetDist_Scoring.scoreArgs): result<float, Operation.Error.t> =>
PointSetDist_Scoring.logScore(args, ~combineFn=combinePointwise, ~integrateFn=T.Integral.sum)
PointSetDist_Scoring.logScore(
args,
~combineFn=combinePointwise,
~integrateFn=T.Integral.sum,
~toMixedFn=toMixed,
)
let pdf = (f: float, t: t) => {
let mixedPoint: PointSetTypes.mixedPoint = T.xToY(f, t)

View File

@ -29,33 +29,59 @@ module WithDistAnswer = {
minusScaledLogOfQuot(~esti=estimateElement, ~answ=answerElement)
}
let rec sum = (~estimate: t, ~answer: t, ~combineFn, ~integrateFn) =>
switch (estimate, answer) {
| ((Continuous(_) | Discrete(_)) as esti, (Continuous(_) | Discrete(_)) as answ) =>
combineFn(integrand, esti, answ)->E.R2.fmap(integrateFn)
| (Mixed(esti), Mixed(answ)) =>
E.R.merge(
sum(
~estimate=Discrete(esti.discrete),
~answer=Discrete(answ.discrete),
~combineFn,
~integrateFn,
),
sum(
~estimate=Continuous(esti.continuous),
~answer=Continuous(answ.continuous),
~combineFn,
~integrateFn,
),
)->E.R2.fmap(((discretePart, continuousPart)) => discretePart +. continuousPart)
}
let sumWithPrior = (~estimate: t, ~answer: t, ~prior: t, ~combineFn, ~integrateFn): result<
let rec sum = (~estimate: t, ~answer: t, ~combineFn, ~integrateFn, ~toMixedFn): result<
float,
Operation.Error.t,
> => {
let kl1 = sum(~estimate, ~answer, ~combineFn, ~integrateFn)
let kl2 = sum(~estimate=prior, ~answer, ~combineFn, ~integrateFn)
> =>
switch (estimate, answer) {
| (Continuous(_), Continuous(_)) =>
combineFn(integrand, estimate, answer)->E.R2.fmap(integrateFn)
| (Discrete(_), Discrete(_)) => combineFn(integrand, estimate, answer)->E.R2.fmap(integrateFn)
| (_, _) =>
let esti = estimate->toMixedFn
let answ = answer->toMixedFn
switch (
Mixed.T.toContinuous(esti),
Mixed.T.toDiscrete(esti),
Mixed.T.toContinuous(answ),
Mixed.T.toDiscrete(answ),
) {
| (
Some(estiContinuousPart),
Some(estiDiscretePart),
Some(answContinuousPart),
Some(answDiscretePart),
) =>
E.R.merge(
sum(
~estimate=Discrete(estiDiscretePart),
~answer=Discrete(answDiscretePart),
~combineFn,
~integrateFn,
~toMixedFn,
),
sum(
~estimate=Continuous(estiContinuousPart),
~answer=Continuous(answContinuousPart),
~combineFn,
~integrateFn,
~toMixedFn,
),
)->E.R2.fmap(((discretePart, continuousPart)) => discretePart +. continuousPart)
| (_, _, _, _) => `unreachable state`->Operation.Other->Error
}
}
let sumWithPrior = (
~estimate: t,
~answer: t,
~prior: t,
~combineFn,
~integrateFn,
~toMixedFn,
): result<float, Operation.Error.t> => {
let kl1 = sum(~estimate, ~answer, ~combineFn, ~integrateFn, ~toMixedFn)
let kl2 = sum(~estimate=prior, ~answer, ~combineFn, ~integrateFn, ~toMixedFn)
E.R.merge(kl1, kl2)->E.R2.fmap(((kl1', kl2')) => kl1' -. kl2')
}
}
@ -138,12 +164,15 @@ module TwoScalars = {
}
}
let logScore = (args: scoreArgs, ~combineFn, ~integrateFn): result<float, Operation.Error.t> =>
let logScore = (args: scoreArgs, ~combineFn, ~integrateFn, ~toMixedFn): result<
float,
Operation.Error.t,
> =>
switch args {
| DistEstimateDistAnswer({estimate, answer, prior: None}) =>
WithDistAnswer.sum(~estimate, ~answer, ~integrateFn, ~combineFn)
WithDistAnswer.sum(~estimate, ~answer, ~integrateFn, ~combineFn, ~toMixedFn)
| DistEstimateDistAnswer({estimate, answer, prior: Some(prior)}) =>
WithDistAnswer.sumWithPrior(~estimate, ~answer, ~prior, ~integrateFn, ~combineFn)
WithDistAnswer.sumWithPrior(~estimate, ~answer, ~prior, ~integrateFn, ~combineFn, ~toMixedFn)
| DistEstimateScalarAnswer({estimate, answer, prior: None}) =>
WithScalarAnswer.score(~estimate, ~answer)
| DistEstimateScalarAnswer({estimate, answer, prior: Some(prior)}) =>