KLDivergence on discretes is passing

Value: [1e-3 to 2e-1]
This commit is contained in:
Quinn Dougherty 2022-05-10 11:27:59 -04:00
parent ccd55ef8f1
commit 15f1ebb429
5 changed files with 19 additions and 32 deletions

View File

@@ -68,9 +68,9 @@ describe("kl divergence on discrete distributions", () => {
let float1 = 1.0 let float1 = 1.0
let float2 = 2.0 let float2 = 2.0
let float3 = 3.0 let float3 = 3.0
let point1 = mkDirac(float1) let point1 = mkDelta(float1)
let point2 = mkDirac(float2) let point2 = mkDelta(float2)
let point3 = mkDirac(float3) let point3 = mkDelta(float3)
test("finite kl divergence", () => { test("finite kl divergence", () => {
let answer = [(point1, 1e0), (point2, 1e0)]->mixture->run let answer = [(point1, 1e0), (point2, 1e0)]->mixture->run
let prediction = [(point1, 1e0), (point2, 1e0), (point3, 1e0)]->mixture->run let prediction = [(point1, 1e0), (point2, 1e0), (point3, 1e0)]->mixture->run
@@ -94,7 +94,7 @@ describe("kl divergence on discrete distributions", () => {
| _ => raise(MixtureFailed) | _ => raise(MixtureFailed)
} }
switch kl { switch kl {
| Ok(kl') => kl'->expect->toEqual(neg_infinity) | Ok(kl') => kl'->expect->toEqual(infinity)
| Error(err) => | Error(err) =>
Js.Console.log(DistributionTypes.Error.toString(err)) Js.Console.log(DistributionTypes.Error.toString(err))
raise(KlFailed) raise(KlFailed)
@@ -102,9 +102,9 @@ describe("kl divergence on discrete distributions", () => {
}) })
}) })
describe("combine along support test", () => { describe("combineAlongSupportOfSecondArgument", () => {
// This tests the version of the function that we're NOT using. Haven't deleted the test in case we use the code later. // This tests the version of the function that we're NOT using. Haven't deleted the test in case we use the code later.
test("combine along support test", _ => { test("test on two uniforms", _ => {
let combineAlongSupportOfSecondArgument = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument0 let combineAlongSupportOfSecondArgument = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument0
let lowAnswer = 0.0 let lowAnswer = 0.0
let highAnswer = 1.0 let highAnswer = 1.0

View File

@@ -51,7 +51,7 @@ let mkExponential = rate => DistributionTypes.Symbolic(#Exponential({rate: rate}
let mkUniform = (low, high) => DistributionTypes.Symbolic(#Uniform({low: low, high: high})) let mkUniform = (low, high) => DistributionTypes.Symbolic(#Uniform({low: low, high: high}))
let mkCauchy = (local, scale) => DistributionTypes.Symbolic(#Cauchy({local: local, scale: scale})) let mkCauchy = (local, scale) => DistributionTypes.Symbolic(#Cauchy({local: local, scale: scale}))
let mkLognormal = (mu, sigma) => DistributionTypes.Symbolic(#Lognormal({mu: mu, sigma: sigma})) let mkLognormal = (mu, sigma) => DistributionTypes.Symbolic(#Lognormal({mu: mu, sigma: sigma}))
let mkDirac = x => DistributionTypes.Symbolic(#Float(x)) let mkDelta = x => DistributionTypes.Symbolic(#Float(x))
let normalMake = SymbolicDist.Normal.make let normalMake = SymbolicDist.Normal.make
let betaMake = SymbolicDist.Beta.make let betaMake = SymbolicDist.Beta.make

View File

@@ -50,7 +50,7 @@ let combinePointwise = (
make( make(
combiner(fn, XYShape.XtoY.discreteInterpolator, t1.xyShape, t2.xyShape)->E.R.toExn( combiner(fn, XYShape.XtoY.discreteInterpolator, t1.xyShape, t2.xyShape)->E.R.toExn(
"Addition operation should never fail", "Logically unreachable?",
_, _,
), ),
)->Ok )->Ok
@@ -163,7 +163,6 @@ module T = Dist({
} }
let integralEndY = (t: t) => t.integralSumCache |> E.O.default(t |> integral |> Continuous.lastY) let integralEndY = (t: t) => t.integralSumCache |> E.O.default(t |> integral |> Continuous.lastY)
let integralEndYResult = (t: t) => t->integralEndY->Ok
let minX = shapeFn(XYShape.T.minX) let minX = shapeFn(XYShape.T.minX)
let maxX = shapeFn(XYShape.T.maxX) let maxX = shapeFn(XYShape.T.maxX)
let toDiscreteProbabilityMassFraction = _ => 1.0 let toDiscreteProbabilityMassFraction = _ => 1.0
@@ -229,25 +228,10 @@ module T = Dist({
} }
let klDivergence = (prediction: t, answer: t) => { let klDivergence = (prediction: t, answer: t) => {
let massOrZero = (t: t, x: float): float => { combinePointwise(
let i = E.A.findIndex(x' => x' == x, t.xyShape.xs) ~fn=PointSetDist_Scoring.KLDivergence.integrand,
switch i { prediction,
| None => 0.0 answer,
| Some(i') => t.xyShape.ys[i'] )->E.R2.fmap(integralEndY)
}
}
let predictionNewYs = E.A.fmap(massOrZero(answer), prediction.xyShape.xs)
let integrand = XYShape.PointwiseCombination.combine(
PointSetDist_Scoring.KLDivergence.integrand,
XYShape.XtoY.continuousInterpolator(#Stepwise, #UseZero),
{XYShape.xs: answer.xyShape.xs, XYShape.ys: predictionNewYs},
answer.xyShape,
)
let xyShapeToDiscrete: XYShape.xyShape => t = xyShape => {
xyShape: xyShape,
integralSumCache: None,
integralCache: None,
}
integrand->E.R2.fmap(x => x->xyShapeToDiscrete->integralEndY)
} }
}) })

View File

@@ -302,9 +302,10 @@ module T = Dist({
} }
let klDivergence = (prediction: t, answer: t) => { let klDivergence = (prediction: t, answer: t) => {
combinePointwise(PointSetDist_Scoring.KLDivergence.integrand, prediction, answer) |> E.R.fmap( Error(Operation.NotYetImplemented)
integralEndY, // combinePointwise(PointSetDist_Scoring.KLDivergence.integrand, prediction, answer) |> E.R.fmap(
) // integralEndY,
// )
} }
}) })

View File

@@ -7,6 +7,8 @@ module KLDivergence = {
// We decided that answerElement = 0.0 contributes 0.0 (not an error), and predictionElement = 0.0 yields infinity. // We decided that answerElement = 0.0 contributes 0.0 (not an error), and predictionElement = 0.0 yields infinity.
if answerElement == 0.0 { if answerElement == 0.0 {
Ok(0.0) Ok(0.0)
} else if predictionElement == 0.0 {
Ok(infinity)
} else { } else {
let quot = predictionElement /. answerElement let quot = predictionElement /. answerElement
quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answerElement *. logFn(quot)) quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answerElement *. logFn(quot))