klDivergence
with prior
Value: [1e-4 to 5e-23]
This commit is contained in:
parent
d00b82807c
commit
81b2c74ac8
|
@ -180,6 +180,20 @@ module Helpers = {
|
|||
->PointSetTypes.Continuous
|
||||
->DistributionTypes.PointSet
|
||||
}
|
||||
|
||||
let klDivergenceWithPrior = (
|
||||
prediction: DistributionTypes.genericDist,
|
||||
answer: DistributionTypes.genericDist,
|
||||
prior: DistributionTypes.genericDist,
|
||||
env: DistributionOperation.env,
|
||||
) => {
|
||||
let term1 = DistributionOperation.Constructors.klDivergence(~env, prediction, answer)
|
||||
let term2 = DistributionOperation.Constructors.klDivergence(~env, prior, answer)
|
||||
switch E.R.merge(term1, term2)->E.R2.fmap(((a, b)) => a -. b) {
|
||||
| Ok(x) => x->DistributionOperation.Float->Some
|
||||
| Error(_) => None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
module SymbolicConstructors = {
|
||||
|
@ -268,8 +282,10 @@ let dispatchToGenericOutput = (
|
|||
~env,
|
||||
)->Some
|
||||
| ("normalize", [EvDistribution(dist)]) => Helpers.toDistFn(Normalize, dist, ~env)
|
||||
| ("klDivergence", [EvDistribution(a), EvDistribution(b)]) =>
|
||||
Some(DistributionOperation.run(FromDist(ToScore(KLDivergence(b)), a), ~env))
|
||||
| ("klDivergence", [EvDistribution(prediction), EvDistribution(answer)]) =>
|
||||
Some(DistributionOperation.run(FromDist(ToScore(KLDivergence(answer)), prediction), ~env))
|
||||
| ("klDivergence", [EvDistribution(prediction), EvDistribution(answer), EvDistribution(prior)]) =>
|
||||
Helpers.klDivergenceWithPrior(prediction, answer, prior, env)
|
||||
| (
|
||||
"logScoreWithPointAnswer",
|
||||
[EvDistribution(prediction), EvNumber(answer), EvDistribution(prior)],
|
||||
|
|
Loading…
Reference in New Issue
Block a user