Factored continuous part of normal and uniform KL divergence into its
own function. Value: [1e-4 to 1e-3]
parent c13f49a7bc
commit 95fe117ef0
@@ -3,6 +3,12 @@ open Expect
 open TestHelpers
 open GenericDist_Fixtures
 
+let klNormalUniform = (mean, stdev, low, high): float =>
+  -.Js.Math.log((high -. low) /. Js.Math.sqrt(2.0 *. MagicNumbers.Math.pi *. stdev ** 2.0)) +.
+  1.0 /.
+  stdev ** 2.0 *.
+  (mean ** 2.0 -. (high +. low) *. mean +. (low ** 2.0 +. high *. low +. high ** 2.0) /. 3.0)
+
 describe("klDivergence: continuous -> continuous -> float", () => {
   let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
 
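For reference, a sketch of where the new helper's terms come from, writing the prediction as N(mu, sigma^2) and the answer as U(l, h) as in the tests below. The cubic polynomial is the second moment of the uniform answer about the normal's mean:

    \mathbb{E}_{X \sim U(l,h)}\big[(X - \mu)^2\big]
      = \frac{1}{h - l} \int_l^h (x - \mu)^2 \, dx
      = \mu^2 - (l + h)\mu + \frac{l^2 + lh + h^2}{3}

and the leading term, -\log\big((h - l) / \sqrt{2\pi\sigma^2}\big), is the normal's log-normalization constant minus the uniform's log-width. Note that the textbook closed form for \mathrm{KL}\big(U(l,h) \,\|\, N(\mu,\sigma^2)\big) scales the second moment by 1/(2\sigma^2); the helper keeps the 1/\sigma^2 scaling of the inline expressions it replaces.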
@@ -63,11 +69,7 @@ describe("klDivergence: continuous -> continuous -> float", () => {
     let prediction = normalDist10
     let answer = uniformDist
     let kl = klDivergence(prediction, answer)
-    let analyticalKl =
-      -.Js.Math.log((10.0 -. 9.0) /. Js.Math.sqrt(2.0 *. MagicNumbers.Math.pi *. 2.0 ** 2.0)) +.
-      1.0 /.
-      2.0 ** 2.0 *.
-      (10.0 ** 2.0 -. (10.0 +. 9.0) *. 10.0 +. (9.0 ** 2.0 +. 10.0 *. 9.0 +. 10.0 ** 2.0) /. 3.0)
+    let analyticalKl = klNormalUniform(10.0, 2.0, 9.0, 10.0)
     switch kl {
     | Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=3)
     | Error(err) => {
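To make the refactor concrete, here is a self-contained ReScript sketch (hypothetical name klNormalUniformSketch; Js.Math._PI stands in for the repo's MagicNumbers.Math.pi) that evaluates the same closed form at this test's parameters:

    // Standalone sketch, not repo code: same closed form as klNormalUniform,
    // with Js.Math._PI substituted for MagicNumbers.Math.pi.
    let klNormalUniformSketch = (mean, stdev, low, high): float =>
      -.Js.Math.log((high -. low) /. Js.Math.sqrt(2.0 *. Js.Math._PI *. stdev ** 2.0)) +.
      1.0 /.
      stdev ** 2.0 *.
      (mean ** 2.0 -. (high +. low) *. mean +. (low ** 2.0 +. high *. low +. high ** 2.0) /. 3.0)

    // With mean = 10, stdev = 2, low = 9, high = 10 as in the test above:
    // log(sqrt(8 * pi)) + (1 / 4) * (100 - 190 + 271 / 3) ~= 1.6954
    Js.log(klNormalUniformSketch(10.0, 2.0, 9.0, 10.0))

So analyticalKl is roughly 1.6954, which the test compares against the numerical klDivergence result to three digits.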
@@ -122,17 +124,13 @@ describe("klDivergence: mixed -> mixed -> float", () => {
   | (Dist(a''), Dist(b'')) => (a'', b'')
   | _ => raise(MixtureFailed)
   }
-  test("finite klDivergence returns is correct", () => {
+  test("finite klDivergence return is correct", () => {
     let prediction = b
     let answer = a
     let kl = klDivergence(prediction, answer)
     // high = 10; low = 9; mean = 10; stdev = 2
-    let analyticalKlContinuousPart =
-      Js.Math.log(Js.Math.sqrt(2.0 *. MagicNumbers.Math.pi *. 2.0 ** 2.0) /. (10.0 -. 9.0)) +.
-      1.0 /.
-      2.0 ** 2.0 *.
-      (10.0 ** 2.0 -. (10.0 +. 9.0) *. 10.0 +. (9.0 ** 2.0 +. 10.0 *. 9.0 +. 10.0 ** 2.0) /. 3.0)
-    let analyticalKlDiscretePart = -2.0 /. 3.0 *. Js.Math.log(3.0 /. 2.0)
+    let analyticalKlContinuousPart = klNormalUniform(10.0, 2.0, 9.0, 10.0)
+    let analyticalKlDiscretePart = 2.0 /. 3.0 *. Js.Math.log(2.0 /. 3.0)
     switch kl {
     | Ok(kl') =>
       kl'->expect->toBeSoCloseTo(analyticalKlContinuousPart +. analyticalKlDiscretePart, ~digits=0)
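The rewritten discrete part is algebraically identical to the line it replaces, since \log(2/3) = -\log(3/2):

    -\frac{2}{3}\log\frac{3}{2} = \frac{2}{3}\log\frac{2}{3} \approx -0.2703

so the expected total is roughly 1.6954 - 0.2703 ≈ 1.4251, which this test only checks to zero digits (~digits=0).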