good evening, not working yet, but out of time for the night

Value: [1e-6 to 1e-4]
Quinn Dougherty 2022-05-09 19:17:27 -04:00
parent b2d80eef86
commit ccd55ef8f1
4 changed files with 26 additions and 10 deletions


@@ -19,7 +19,7 @@ describe("kl divergence on continuous distributions", () => {
     let analyticalKl = Js.Math.log((highPrediction -. lowPrediction) /. (highAnswer -. lowAnswer))
     let kl = E.R.liftJoin2(klDivergence, prediction, answer)
     switch kl {
-    | Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
+    | Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=7)
     | Error(err) => {
       Js.Console.log(DistributionTypes.Error.toString(err))
       raise(KlFailed)
@@ -51,7 +51,7 @@ describe("kl divergence on continuous distributions", () => {
     let kl = E.R.liftJoin2(klDivergence, prediction, answer)
     switch kl {
-    | Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
+    | Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=3)
     | Error(err) => {
       Js.Console.log(DistributionTypes.Error.toString(err))
       raise(KlFailed)
@@ -78,9 +78,9 @@ describe("kl divergence on discrete distributions", () => {
     | (Dist(prediction'), Dist(answer')) => klDivergence(prediction', answer')
     | _ => raise(MixtureFailed)
     }
-    let analyticalKl = Js.Math.log(2.0 /. 3.0)
+    let analyticalKl = Js.Math.log(3.0 /. 2.0)
     switch kl {
-    | Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
+    | Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=7)
     | Error(err) =>
       Js.Console.log(DistributionTypes.Error.toString(err))
       raise(KlFailed)
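For reference (an editorial note, not part of the commit): the continuous test's analytical value follows from the standard definition of KL divergence with the answer as the reference distribution,

\[
D_{\mathrm{KL}}(\mathrm{answer}\,\|\,\mathrm{prediction}) = \int \mathrm{answer}(x)\,\log\frac{\mathrm{answer}(x)}{\mathrm{prediction}(x)}\,dx,
\]

which, for a uniform answer on [lowAnswer, highAnswer] nested inside a uniform prediction on [lowPrediction, highPrediction], evaluates to

\[
\log\frac{\mathit{highPrediction} - \mathit{lowPrediction}}{\mathit{highAnswer} - \mathit{lowAnswer}},
\]

matching the analyticalKl in the first hunk. In the discrete hunk, the corrected constant reverses the ratio inside the logarithm, since \(\log(3/2) = -\log(2/3)\); presumably this flips which distribution's mass sits in the numerator.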


@@ -51,6 +51,7 @@ let mkExponential = rate => DistributionTypes.Symbolic(#Exponential({rate: rate}))
 let mkUniform = (low, high) => DistributionTypes.Symbolic(#Uniform({low: low, high: high}))
 let mkCauchy = (local, scale) => DistributionTypes.Symbolic(#Cauchy({local: local, scale: scale}))
 let mkLognormal = (mu, sigma) => DistributionTypes.Symbolic(#Lognormal({mu: mu, sigma: sigma}))
+let mkDirac = x => DistributionTypes.Symbolic(#Float(x))
 let normalMake = SymbolicDist.Normal.make
 let betaMake = SymbolicDist.Beta.make
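A brief illustrative note (not from the commit): #Float appears to be the symbolic point-mass constructor, so the new helper builds a Dirac-style distribution concentrated at a single value, the natural building block for the discrete KL tests above. A hypothetical usage:

// Hypothetical usage of the new helper (illustration only):
// a distribution with all of its mass at 1.0.
let pointMassAtOne = mkDirac(1.0)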


@@ -229,11 +229,25 @@ module T = Dist({
   }
   let klDivergence = (prediction: t, answer: t) => {
-    combinePointwise(
-      ~combiner=XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument0,
-      ~fn=PointSetDist_Scoring.KLDivergence.integrand,
-      prediction,
-      answer,
-    ) |> E.R2.bind(integralEndYResult)
+    let massOrZero = (t: t, x: float): float => {
+      let i = E.A.findIndex(x' => x' == x, t.xyShape.xs)
+      switch i {
+      | None => 0.0
+      | Some(i') => t.xyShape.ys[i']
+      }
+    }
+    let predictionNewYs = E.A.fmap(massOrZero(answer), prediction.xyShape.xs)
+    let integrand = XYShape.PointwiseCombination.combine(
+      PointSetDist_Scoring.KLDivergence.integrand,
+      XYShape.XtoY.continuousInterpolator(#Stepwise, #UseZero),
+      {XYShape.xs: answer.xyShape.xs, XYShape.ys: predictionNewYs},
+      answer.xyShape,
+    )
+    let xyShapeToDiscrete: XYShape.xyShape => t = xyShape => {
+      xyShape: xyShape,
+      integralSumCache: None,
+      integralCache: None,
+    }
+    integrand->E.R2.fmap(x => x->xyShapeToDiscrete->integralEndY)
   }
 })
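For orientation, here is a minimal standalone sketch (illustrative only, not code from the repo) of the discrete divergence this klDivergence appears to be working toward: sum answer(x) * log(answer(x) / prediction(x)) over the answer's support, treating x-values missing from the prediction as zero mass. It assumes PointSetDist_Scoring.KLDivergence.integrand computes that pointwise term, which the commit itself does not show; the name klDiscreteSketch is made up here.

// Illustrative, standalone sketch (not the repo's code): discrete KL divergence
// D(answer || prediction) over point masses given as (x, mass) pairs.
let klDiscreteSketch = (
  prediction: array<(float, float)>,
  answer: array<(float, float)>,
): float => {
  // Mass the prediction assigns to x, or 0.0 if x is missing from its support.
  let predictionMassAt = x =>
    switch Js.Array2.find(prediction, ((x', _)) => x' == x) {
    | Some((_, mass)) => mass
    | None => 0.0
    }
  // Sum answer(x) * log(answer(x) / prediction(x)) over the answer's support.
  // A zero prediction mass makes the term (and so the total) infinite, as expected for KL.
  answer->Js.Array2.reduce(
    (acc, (x, answerMass)) => acc +. answerMass *. Js.Math.log(answerMass /. predictionMassAt(x)),
    0.0,
  )
}

Summing only over the answer's support is consistent with the KL definition (points where the answer has zero mass contribute nothing), and is presumably what the removed combineAlongSupportOfSecondArgument0 call did with answer as its second argument.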


@@ -199,6 +199,7 @@ module T = Dist({
   let klDivergence = (t1: t, t2: t) =>
     switch (t1, t2) {
     | (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2)
+    | (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2)
     | _ => Error(NotYetImplemented)
     }
 })