good evening, not working yet, but out of time for the night
Value: [1e-6 to 1e-4]
This commit is contained in:
parent
b2d80eef86
commit
ccd55ef8f1
|
@ -19,7 +19,7 @@ describe("kl divergence on continuous distributions", () => {
|
|||
let analyticalKl = Js.Math.log((highPrediction -. lowPrediction) /. (highAnswer -. lowAnswer))
|
||||
let kl = E.R.liftJoin2(klDivergence, prediction, answer)
|
||||
switch kl {
|
||||
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
|
||||
| Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=7)
|
||||
| Error(err) => {
|
||||
Js.Console.log(DistributionTypes.Error.toString(err))
|
||||
raise(KlFailed)
|
||||
|
@ -51,7 +51,7 @@ describe("kl divergence on continuous distributions", () => {
|
|||
let kl = E.R.liftJoin2(klDivergence, prediction, answer)
|
||||
|
||||
switch kl {
|
||||
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
|
||||
| Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=3)
|
||||
| Error(err) => {
|
||||
Js.Console.log(DistributionTypes.Error.toString(err))
|
||||
raise(KlFailed)
|
||||
|
@ -78,9 +78,9 @@ describe("kl divergence on discrete distributions", () => {
|
|||
| (Dist(prediction'), Dist(answer')) => klDivergence(prediction', answer')
|
||||
| _ => raise(MixtureFailed)
|
||||
}
|
||||
let analyticalKl = Js.Math.log(2.0 /. 3.0)
|
||||
let analyticalKl = Js.Math.log(3.0 /. 2.0)
|
||||
switch kl {
|
||||
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
|
||||
| Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=7)
|
||||
| Error(err) =>
|
||||
Js.Console.log(DistributionTypes.Error.toString(err))
|
||||
raise(KlFailed)
|
||||
|
|
|
@ -51,6 +51,7 @@ let mkExponential = rate => DistributionTypes.Symbolic(#Exponential({rate: rate}
|
|||
let mkUniform = (low, high) => DistributionTypes.Symbolic(#Uniform({low: low, high: high}))
|
||||
let mkCauchy = (local, scale) => DistributionTypes.Symbolic(#Cauchy({local: local, scale: scale}))
|
||||
let mkLognormal = (mu, sigma) => DistributionTypes.Symbolic(#Lognormal({mu: mu, sigma: sigma}))
|
||||
let mkDirac = x => DistributionTypes.Symbolic(#Float(x))
|
||||
|
||||
let normalMake = SymbolicDist.Normal.make
|
||||
let betaMake = SymbolicDist.Beta.make
|
||||
|
|
|
@ -229,11 +229,25 @@ module T = Dist({
|
|||
}
|
||||
|
||||
let klDivergence = (prediction: t, answer: t) => {
|
||||
combinePointwise(
|
||||
~combiner=XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument0,
|
||||
~fn=PointSetDist_Scoring.KLDivergence.integrand,
|
||||
prediction,
|
||||
answer,
|
||||
) |> E.R2.bind(integralEndYResult)
|
||||
let massOrZero = (t: t, x: float): float => {
|
||||
let i = E.A.findIndex(x' => x' == x, t.xyShape.xs)
|
||||
switch i {
|
||||
| None => 0.0
|
||||
| Some(i') => t.xyShape.ys[i']
|
||||
}
|
||||
}
|
||||
let predictionNewYs = E.A.fmap(massOrZero(answer), prediction.xyShape.xs)
|
||||
let integrand = XYShape.PointwiseCombination.combine(
|
||||
PointSetDist_Scoring.KLDivergence.integrand,
|
||||
XYShape.XtoY.continuousInterpolator(#Stepwise, #UseZero),
|
||||
{XYShape.xs: answer.xyShape.xs, XYShape.ys: predictionNewYs},
|
||||
answer.xyShape,
|
||||
)
|
||||
let xyShapeToDiscrete: XYShape.xyShape => t = xyShape => {
|
||||
xyShape: xyShape,
|
||||
integralSumCache: None,
|
||||
integralCache: None,
|
||||
}
|
||||
integrand->E.R2.fmap(x => x->xyShapeToDiscrete->integralEndY)
|
||||
}
|
||||
})
|
||||
|
|
|
@ -199,6 +199,7 @@ module T = Dist({
|
|||
let klDivergence = (t1: t, t2: t) =>
|
||||
switch (t1, t2) {
|
||||
| (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2)
|
||||
| (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2)
|
||||
| _ => Error(NotYetImplemented)
|
||||
}
|
||||
})
|
||||
|
|
Loading…
Reference in New Issue
Block a user