klDivergence with prior

Value: [1e-4 to 5e-23]
This commit is contained in:
Quinn Dougherty 2022-05-16 13:18:01 -04:00
parent d00b82807c
commit 81b2c74ac8

View File

@ -180,6 +180,20 @@ module Helpers = {
->PointSetTypes.Continuous ->PointSetTypes.Continuous
->DistributionTypes.PointSet ->DistributionTypes.PointSet
} }
let klDivergenceWithPrior = (
prediction: DistributionTypes.genericDist,
answer: DistributionTypes.genericDist,
prior: DistributionTypes.genericDist,
env: DistributionOperation.env,
) => {
let term1 = DistributionOperation.Constructors.klDivergence(~env, prediction, answer)
let term2 = DistributionOperation.Constructors.klDivergence(~env, prior, answer)
switch E.R.merge(term1, term2)->E.R2.fmap(((a, b)) => a -. b) {
| Ok(x) => x->DistributionOperation.Float->Some
| Error(_) => None
}
}
} }
module SymbolicConstructors = { module SymbolicConstructors = {
@ -268,8 +282,10 @@ let dispatchToGenericOutput = (
~env, ~env,
)->Some )->Some
| ("normalize", [EvDistribution(dist)]) => Helpers.toDistFn(Normalize, dist, ~env) | ("normalize", [EvDistribution(dist)]) => Helpers.toDistFn(Normalize, dist, ~env)
| ("klDivergence", [EvDistribution(a), EvDistribution(b)]) => | ("klDivergence", [EvDistribution(prediction), EvDistribution(answer)]) =>
Some(DistributionOperation.run(FromDist(ToScore(KLDivergence(b)), a), ~env)) Some(DistributionOperation.run(FromDist(ToScore(KLDivergence(answer)), prediction), ~env))
| ("klDivergence", [EvDistribution(prediction), EvDistribution(answer), EvDistribution(prior)]) =>
Helpers.klDivergenceWithPrior(prediction, answer, prior, env)
| ( | (
"logScoreWithPointAnswer", "logScoreWithPointAnswer",
[EvDistribution(prediction), EvNumber(answer), EvDistribution(prior)], [EvDistribution(prediction), EvNumber(answer), EvDistribution(prior)],