Removed negative infinity error handling
Value: [1e-5 to 1e-3]
This commit is contained in:
parent
b07c6e6e9c
commit
b2d80eef86
|
@ -2,7 +2,7 @@ open Jest
|
||||||
open Expect
|
open Expect
|
||||||
open TestHelpers
|
open TestHelpers
|
||||||
|
|
||||||
describe("kl divergence", () => {
|
describe("kl divergence on continuous distributions", () => {
|
||||||
let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
|
let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
|
||||||
exception KlFailed
|
exception KlFailed
|
||||||
|
|
||||||
|
@ -60,6 +60,48 @@ describe("kl divergence", () => {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
describe("kl divergence on discrete distributions", () => {
|
||||||
|
let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
|
||||||
|
let mixture = a => DistributionTypes.DistributionOperation.Mixture(a)
|
||||||
|
exception KlFailed
|
||||||
|
exception MixtureFailed
|
||||||
|
let float1 = 1.0
|
||||||
|
let float2 = 2.0
|
||||||
|
let float3 = 3.0
|
||||||
|
let point1 = mkDirac(float1)
|
||||||
|
let point2 = mkDirac(float2)
|
||||||
|
let point3 = mkDirac(float3)
|
||||||
|
test("finite kl divergence", () => {
|
||||||
|
let answer = [(point1, 1e0), (point2, 1e0)]->mixture->run
|
||||||
|
let prediction = [(point1, 1e0), (point2, 1e0), (point3, 1e0)]->mixture->run
|
||||||
|
let kl = switch (prediction, answer) {
|
||||||
|
| (Dist(prediction'), Dist(answer')) => klDivergence(prediction', answer')
|
||||||
|
| _ => raise(MixtureFailed)
|
||||||
|
}
|
||||||
|
let analyticalKl = Js.Math.log(2.0 /. 3.0)
|
||||||
|
switch kl {
|
||||||
|
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
|
||||||
|
| Error(err) =>
|
||||||
|
Js.Console.log(DistributionTypes.Error.toString(err))
|
||||||
|
raise(KlFailed)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
test("infinite kl divergence", () => {
|
||||||
|
let prediction = [(point1, 1e0), (point2, 1e0)]->mixture->run
|
||||||
|
let answer = [(point1, 1e0), (point2, 1e0), (point3, 1e0)]->mixture->run
|
||||||
|
let kl = switch (prediction, answer) {
|
||||||
|
| (Dist(prediction'), Dist(answer')) => klDivergence(prediction', answer')
|
||||||
|
| _ => raise(MixtureFailed)
|
||||||
|
}
|
||||||
|
switch kl {
|
||||||
|
| Ok(kl') => kl'->expect->toEqual(neg_infinity)
|
||||||
|
| Error(err) =>
|
||||||
|
Js.Console.log(DistributionTypes.Error.toString(err))
|
||||||
|
raise(KlFailed)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
describe("combine along support test", () => {
|
describe("combine along support test", () => {
|
||||||
// This tests the version of the function that we're NOT using. Haven't deleted the test in case we use the code later.
|
// This tests the version of the function that we're NOT using. Haven't deleted the test in case we use the code later.
|
||||||
test("combine along support test", _ => {
|
test("combine along support test", _ => {
|
||||||
|
@ -97,6 +139,7 @@ describe("combine along support test", () => {
|
||||||
2.0 *. MagicNumbers.Epsilon.ten,
|
2.0 *. MagicNumbers.Epsilon.ten,
|
||||||
1.0 -. MagicNumbers.Epsilon.ten,
|
1.0 -. MagicNumbers.Epsilon.ten,
|
||||||
1.0,
|
1.0,
|
||||||
|
1.0 +. MagicNumbers.Epsilon.ten,
|
||||||
],
|
],
|
||||||
ys: [
|
ys: [
|
||||||
-0.34657359027997264,
|
-0.34657359027997264,
|
||||||
|
@ -104,6 +147,7 @@ describe("combine along support test", () => {
|
||||||
-0.34657359027997264,
|
-0.34657359027997264,
|
||||||
-0.34657359027997264,
|
-0.34657359027997264,
|
||||||
-0.34657359027997264,
|
-0.34657359027997264,
|
||||||
|
infinity,
|
||||||
],
|
],
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
|
|
|
@ -4,10 +4,9 @@ module KLDivergence = {
|
||||||
float,
|
float,
|
||||||
Operation.Error.t,
|
Operation.Error.t,
|
||||||
> =>
|
> =>
|
||||||
|
// We decided that negative infinity, not an error at answerElement = 0.0, is a desirable value.
|
||||||
if answerElement == 0.0 {
|
if answerElement == 0.0 {
|
||||||
Ok(0.0)
|
Ok(0.0)
|
||||||
} else if predictionElement == 0.0 {
|
|
||||||
Error(Operation.NegativeInfinityError)
|
|
||||||
} else {
|
} else {
|
||||||
let quot = predictionElement /. answerElement
|
let quot = predictionElement /. answerElement
|
||||||
quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answerElement *. logFn(quot))
|
quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answerElement *. logFn(quot))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user