klDivergence
with prior
Value: [1e-4 to 5e-23]
This commit is contained in:
parent
d00b82807c
commit
81b2c74ac8
|
@ -180,6 +180,20 @@ module Helpers = {
|
||||||
->PointSetTypes.Continuous
|
->PointSetTypes.Continuous
|
||||||
->DistributionTypes.PointSet
|
->DistributionTypes.PointSet
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let klDivergenceWithPrior = (
|
||||||
|
prediction: DistributionTypes.genericDist,
|
||||||
|
answer: DistributionTypes.genericDist,
|
||||||
|
prior: DistributionTypes.genericDist,
|
||||||
|
env: DistributionOperation.env,
|
||||||
|
) => {
|
||||||
|
let term1 = DistributionOperation.Constructors.klDivergence(~env, prediction, answer)
|
||||||
|
let term2 = DistributionOperation.Constructors.klDivergence(~env, prior, answer)
|
||||||
|
switch E.R.merge(term1, term2)->E.R2.fmap(((a, b)) => a -. b) {
|
||||||
|
| Ok(x) => x->DistributionOperation.Float->Some
|
||||||
|
| Error(_) => None
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
module SymbolicConstructors = {
|
module SymbolicConstructors = {
|
||||||
|
@ -268,8 +282,10 @@ let dispatchToGenericOutput = (
|
||||||
~env,
|
~env,
|
||||||
)->Some
|
)->Some
|
||||||
| ("normalize", [EvDistribution(dist)]) => Helpers.toDistFn(Normalize, dist, ~env)
|
| ("normalize", [EvDistribution(dist)]) => Helpers.toDistFn(Normalize, dist, ~env)
|
||||||
| ("klDivergence", [EvDistribution(a), EvDistribution(b)]) =>
|
| ("klDivergence", [EvDistribution(prediction), EvDistribution(answer)]) =>
|
||||||
Some(DistributionOperation.run(FromDist(ToScore(KLDivergence(b)), a), ~env))
|
Some(DistributionOperation.run(FromDist(ToScore(KLDivergence(answer)), prediction), ~env))
|
||||||
|
| ("klDivergence", [EvDistribution(prediction), EvDistribution(answer), EvDistribution(prior)]) =>
|
||||||
|
Helpers.klDivergenceWithPrior(prediction, answer, prior, env)
|
||||||
| (
|
| (
|
||||||
"logScoreWithPointAnswer",
|
"logScoreWithPointAnswer",
|
||||||
[EvDistribution(prediction), EvNumber(answer), EvDistribution(prior)],
|
[EvDistribution(prediction), EvNumber(answer), EvDistribution(prior)],
|
||||||
|
|
Loading…
Reference in New Issue
Block a user