Removed negative infinity error handling
Value: [1e-5 to 1e-3]
This commit is contained in:
parent
b07c6e6e9c
commit
b2d80eef86
|
@ -2,7 +2,7 @@ open Jest
|
|||
open Expect
|
||||
open TestHelpers
|
||||
|
||||
describe("kl divergence", () => {
|
||||
describe("kl divergence on continuous distributions", () => {
|
||||
let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
|
||||
exception KlFailed
|
||||
|
||||
|
@ -60,6 +60,48 @@ describe("kl divergence", () => {
|
|||
})
|
||||
})
|
||||
|
||||
describe("kl divergence on discrete distributions", () => {
|
||||
let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
|
||||
let mixture = a => DistributionTypes.DistributionOperation.Mixture(a)
|
||||
exception KlFailed
|
||||
exception MixtureFailed
|
||||
let float1 = 1.0
|
||||
let float2 = 2.0
|
||||
let float3 = 3.0
|
||||
let point1 = mkDirac(float1)
|
||||
let point2 = mkDirac(float2)
|
||||
let point3 = mkDirac(float3)
|
||||
test("finite kl divergence", () => {
|
||||
let answer = [(point1, 1e0), (point2, 1e0)]->mixture->run
|
||||
let prediction = [(point1, 1e0), (point2, 1e0), (point3, 1e0)]->mixture->run
|
||||
let kl = switch (prediction, answer) {
|
||||
| (Dist(prediction'), Dist(answer')) => klDivergence(prediction', answer')
|
||||
| _ => raise(MixtureFailed)
|
||||
}
|
||||
let analyticalKl = Js.Math.log(2.0 /. 3.0)
|
||||
switch kl {
|
||||
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
|
||||
| Error(err) =>
|
||||
Js.Console.log(DistributionTypes.Error.toString(err))
|
||||
raise(KlFailed)
|
||||
}
|
||||
})
|
||||
test("infinite kl divergence", () => {
|
||||
let prediction = [(point1, 1e0), (point2, 1e0)]->mixture->run
|
||||
let answer = [(point1, 1e0), (point2, 1e0), (point3, 1e0)]->mixture->run
|
||||
let kl = switch (prediction, answer) {
|
||||
| (Dist(prediction'), Dist(answer')) => klDivergence(prediction', answer')
|
||||
| _ => raise(MixtureFailed)
|
||||
}
|
||||
switch kl {
|
||||
| Ok(kl') => kl'->expect->toEqual(neg_infinity)
|
||||
| Error(err) =>
|
||||
Js.Console.log(DistributionTypes.Error.toString(err))
|
||||
raise(KlFailed)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe("combine along support test", () => {
|
||||
// This tests the version of the function that we're NOT using. Haven't deleted the test in case we use the code later.
|
||||
test("combine along support test", _ => {
|
||||
|
@ -97,6 +139,7 @@ describe("combine along support test", () => {
|
|||
2.0 *. MagicNumbers.Epsilon.ten,
|
||||
1.0 -. MagicNumbers.Epsilon.ten,
|
||||
1.0,
|
||||
1.0 +. MagicNumbers.Epsilon.ten,
|
||||
],
|
||||
ys: [
|
||||
-0.34657359027997264,
|
||||
|
@ -104,6 +147,7 @@ describe("combine along support test", () => {
|
|||
-0.34657359027997264,
|
||||
-0.34657359027997264,
|
||||
-0.34657359027997264,
|
||||
infinity,
|
||||
],
|
||||
}),
|
||||
),
|
||||
|
|
|
@ -4,10 +4,9 @@ module KLDivergence = {
|
|||
float,
|
||||
Operation.Error.t,
|
||||
> =>
|
||||
// We decided that negative infinity, not an error at answerElement = 0.0, is a desirable value.
|
||||
if answerElement == 0.0 {
|
||||
Ok(0.0)
|
||||
} else if predictionElement == 0.0 {
|
||||
Error(Operation.NegativeInfinityError)
|
||||
} else {
|
||||
let quot = predictionElement /. answerElement
|
||||
quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answerElement *. logFn(quot))
|
||||
|
|
Loading…
Reference in New Issue
Block a user