Removed negative infinity error handling

Value: [1e-5 to 1e-3]
This commit is contained in:
Quinn Dougherty 2022-05-09 18:28:35 -04:00
parent b07c6e6e9c
commit b2d80eef86
2 changed files with 46 additions and 3 deletions

View File

@ -2,7 +2,7 @@ open Jest
open Expect open Expect
open TestHelpers open TestHelpers
describe("kl divergence", () => { describe("kl divergence on continuous distributions", () => {
let klDivergence = DistributionOperation.Constructors.klDivergence(~env) let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
exception KlFailed exception KlFailed
@ -60,6 +60,48 @@ describe("kl divergence", () => {
}) })
}) })
describe("kl divergence on discrete distributions", () => {
let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
let mixture = a => DistributionTypes.DistributionOperation.Mixture(a)
exception KlFailed
exception MixtureFailed
let float1 = 1.0
let float2 = 2.0
let float3 = 3.0
let point1 = mkDirac(float1)
let point2 = mkDirac(float2)
let point3 = mkDirac(float3)
test("finite kl divergence", () => {
let answer = [(point1, 1e0), (point2, 1e0)]->mixture->run
let prediction = [(point1, 1e0), (point2, 1e0), (point3, 1e0)]->mixture->run
let kl = switch (prediction, answer) {
| (Dist(prediction'), Dist(answer')) => klDivergence(prediction', answer')
| _ => raise(MixtureFailed)
}
let analyticalKl = Js.Math.log(2.0 /. 3.0)
switch kl {
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
| Error(err) =>
Js.Console.log(DistributionTypes.Error.toString(err))
raise(KlFailed)
}
})
test("infinite kl divergence", () => {
let prediction = [(point1, 1e0), (point2, 1e0)]->mixture->run
let answer = [(point1, 1e0), (point2, 1e0), (point3, 1e0)]->mixture->run
let kl = switch (prediction, answer) {
| (Dist(prediction'), Dist(answer')) => klDivergence(prediction', answer')
| _ => raise(MixtureFailed)
}
switch kl {
| Ok(kl') => kl'->expect->toEqual(neg_infinity)
| Error(err) =>
Js.Console.log(DistributionTypes.Error.toString(err))
raise(KlFailed)
}
})
})
describe("combine along support test", () => { describe("combine along support test", () => {
// This tests the version of the function that we're NOT using. Haven't deleted the test in case we use the code later. // This tests the version of the function that we're NOT using. Haven't deleted the test in case we use the code later.
test("combine along support test", _ => { test("combine along support test", _ => {
@ -97,6 +139,7 @@ describe("combine along support test", () => {
2.0 *. MagicNumbers.Epsilon.ten, 2.0 *. MagicNumbers.Epsilon.ten,
1.0 -. MagicNumbers.Epsilon.ten, 1.0 -. MagicNumbers.Epsilon.ten,
1.0, 1.0,
1.0 +. MagicNumbers.Epsilon.ten,
], ],
ys: [ ys: [
-0.34657359027997264, -0.34657359027997264,
@ -104,6 +147,7 @@ describe("combine along support test", () => {
-0.34657359027997264, -0.34657359027997264,
-0.34657359027997264, -0.34657359027997264,
-0.34657359027997264, -0.34657359027997264,
infinity,
], ],
}), }),
), ),

View File

@ -4,10 +4,9 @@ module KLDivergence = {
float, float,
Operation.Error.t, Operation.Error.t,
> => > =>
// We decided that negative infinity, not an error at answerElement = 0.0, is a desirable value.
if answerElement == 0.0 { if answerElement == 0.0 {
Ok(0.0) Ok(0.0)
} else if predictionElement == 0.0 {
Error(Operation.NegativeInfinityError)
} else { } else {
let quot = predictionElement /. answerElement let quot = predictionElement /. answerElement
quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answerElement *. logFn(quot)) quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answerElement *. logFn(quot))