slghtly more robust solution to mixed kldivergences (and removed a
warning) Value: [1e-5 to 1e-2]
This commit is contained in:
parent
a266b8ed09
commit
3aaad14f11
|
@ -201,7 +201,12 @@ module T = Dist({
|
||||||
})
|
})
|
||||||
|
|
||||||
let logScore = (args: PointSetDist_Scoring.scoreArgs): result<float, Operation.Error.t> =>
|
let logScore = (args: PointSetDist_Scoring.scoreArgs): result<float, Operation.Error.t> =>
|
||||||
PointSetDist_Scoring.logScore(args, ~combineFn=combinePointwise, ~integrateFn=T.Integral.sum)
|
PointSetDist_Scoring.logScore(
|
||||||
|
args,
|
||||||
|
~combineFn=combinePointwise,
|
||||||
|
~integrateFn=T.Integral.sum,
|
||||||
|
~toMixedFn=toMixed,
|
||||||
|
)
|
||||||
|
|
||||||
let pdf = (f: float, t: t) => {
|
let pdf = (f: float, t: t) => {
|
||||||
let mixedPoint: PointSetTypes.mixedPoint = T.xToY(f, t)
|
let mixedPoint: PointSetTypes.mixedPoint = T.xToY(f, t)
|
||||||
|
|
|
@ -29,33 +29,59 @@ module WithDistAnswer = {
|
||||||
minusScaledLogOfQuot(~esti=estimateElement, ~answ=answerElement)
|
minusScaledLogOfQuot(~esti=estimateElement, ~answ=answerElement)
|
||||||
}
|
}
|
||||||
|
|
||||||
let rec sum = (~estimate: t, ~answer: t, ~combineFn, ~integrateFn) =>
|
let rec sum = (~estimate: t, ~answer: t, ~combineFn, ~integrateFn, ~toMixedFn): result<
|
||||||
switch (estimate, answer) {
|
|
||||||
| ((Continuous(_) | Discrete(_)) as esti, (Continuous(_) | Discrete(_)) as answ) =>
|
|
||||||
combineFn(integrand, esti, answ)->E.R2.fmap(integrateFn)
|
|
||||||
| (Mixed(esti), Mixed(answ)) =>
|
|
||||||
E.R.merge(
|
|
||||||
sum(
|
|
||||||
~estimate=Discrete(esti.discrete),
|
|
||||||
~answer=Discrete(answ.discrete),
|
|
||||||
~combineFn,
|
|
||||||
~integrateFn,
|
|
||||||
),
|
|
||||||
sum(
|
|
||||||
~estimate=Continuous(esti.continuous),
|
|
||||||
~answer=Continuous(answ.continuous),
|
|
||||||
~combineFn,
|
|
||||||
~integrateFn,
|
|
||||||
),
|
|
||||||
)->E.R2.fmap(((discretePart, continuousPart)) => discretePart +. continuousPart)
|
|
||||||
}
|
|
||||||
|
|
||||||
let sumWithPrior = (~estimate: t, ~answer: t, ~prior: t, ~combineFn, ~integrateFn): result<
|
|
||||||
float,
|
float,
|
||||||
Operation.Error.t,
|
Operation.Error.t,
|
||||||
> => {
|
> =>
|
||||||
let kl1 = sum(~estimate, ~answer, ~combineFn, ~integrateFn)
|
switch (estimate, answer) {
|
||||||
let kl2 = sum(~estimate=prior, ~answer, ~combineFn, ~integrateFn)
|
| (Continuous(_), Continuous(_)) =>
|
||||||
|
combineFn(integrand, estimate, answer)->E.R2.fmap(integrateFn)
|
||||||
|
| (Discrete(_), Discrete(_)) => combineFn(integrand, estimate, answer)->E.R2.fmap(integrateFn)
|
||||||
|
| (_, _) =>
|
||||||
|
let esti = estimate->toMixedFn
|
||||||
|
let answ = answer->toMixedFn
|
||||||
|
switch (
|
||||||
|
Mixed.T.toContinuous(esti),
|
||||||
|
Mixed.T.toDiscrete(esti),
|
||||||
|
Mixed.T.toContinuous(answ),
|
||||||
|
Mixed.T.toDiscrete(answ),
|
||||||
|
) {
|
||||||
|
| (
|
||||||
|
Some(estiContinuousPart),
|
||||||
|
Some(estiDiscretePart),
|
||||||
|
Some(answContinuousPart),
|
||||||
|
Some(answDiscretePart),
|
||||||
|
) =>
|
||||||
|
E.R.merge(
|
||||||
|
sum(
|
||||||
|
~estimate=Discrete(estiDiscretePart),
|
||||||
|
~answer=Discrete(answDiscretePart),
|
||||||
|
~combineFn,
|
||||||
|
~integrateFn,
|
||||||
|
~toMixedFn,
|
||||||
|
),
|
||||||
|
sum(
|
||||||
|
~estimate=Continuous(estiContinuousPart),
|
||||||
|
~answer=Continuous(answContinuousPart),
|
||||||
|
~combineFn,
|
||||||
|
~integrateFn,
|
||||||
|
~toMixedFn,
|
||||||
|
),
|
||||||
|
)->E.R2.fmap(((discretePart, continuousPart)) => discretePart +. continuousPart)
|
||||||
|
| (_, _, _, _) => `unreachable state`->Operation.Other->Error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let sumWithPrior = (
|
||||||
|
~estimate: t,
|
||||||
|
~answer: t,
|
||||||
|
~prior: t,
|
||||||
|
~combineFn,
|
||||||
|
~integrateFn,
|
||||||
|
~toMixedFn,
|
||||||
|
): result<float, Operation.Error.t> => {
|
||||||
|
let kl1 = sum(~estimate, ~answer, ~combineFn, ~integrateFn, ~toMixedFn)
|
||||||
|
let kl2 = sum(~estimate=prior, ~answer, ~combineFn, ~integrateFn, ~toMixedFn)
|
||||||
E.R.merge(kl1, kl2)->E.R2.fmap(((kl1', kl2')) => kl1' -. kl2')
|
E.R.merge(kl1, kl2)->E.R2.fmap(((kl1', kl2')) => kl1' -. kl2')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -138,12 +164,15 @@ module TwoScalars = {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let logScore = (args: scoreArgs, ~combineFn, ~integrateFn): result<float, Operation.Error.t> =>
|
let logScore = (args: scoreArgs, ~combineFn, ~integrateFn, ~toMixedFn): result<
|
||||||
|
float,
|
||||||
|
Operation.Error.t,
|
||||||
|
> =>
|
||||||
switch args {
|
switch args {
|
||||||
| DistEstimateDistAnswer({estimate, answer, prior: None}) =>
|
| DistEstimateDistAnswer({estimate, answer, prior: None}) =>
|
||||||
WithDistAnswer.sum(~estimate, ~answer, ~integrateFn, ~combineFn)
|
WithDistAnswer.sum(~estimate, ~answer, ~integrateFn, ~combineFn, ~toMixedFn)
|
||||||
| DistEstimateDistAnswer({estimate, answer, prior: Some(prior)}) =>
|
| DistEstimateDistAnswer({estimate, answer, prior: Some(prior)}) =>
|
||||||
WithDistAnswer.sumWithPrior(~estimate, ~answer, ~prior, ~integrateFn, ~combineFn)
|
WithDistAnswer.sumWithPrior(~estimate, ~answer, ~prior, ~integrateFn, ~combineFn, ~toMixedFn)
|
||||||
| DistEstimateScalarAnswer({estimate, answer, prior: None}) =>
|
| DistEstimateScalarAnswer({estimate, answer, prior: None}) =>
|
||||||
WithScalarAnswer.score(~estimate, ~answer)
|
WithScalarAnswer.score(~estimate, ~answer)
|
||||||
| DistEstimateScalarAnswer({estimate, answer, prior: Some(prior)}) =>
|
| DistEstimateScalarAnswer({estimate, answer, prior: Some(prior)}) =>
|
||||||
|
|
Loading…
Reference in New Issue
Block a user