Response to CR

Value: [1e-6 to 1e-4]
This commit is contained in:
Quinn Dougherty 2022-05-10 14:03:42 -04:00
parent f7690c33e0
commit 29c1956e88
2 changed files with 13 additions and 10 deletions

View File

@ -12,3 +12,13 @@ let triangularDist: DistributionTypes.genericDist = Symbolic(
let exponentialDist: DistributionTypes.genericDist = Symbolic(#Exponential({rate: 2.0})) let exponentialDist: DistributionTypes.genericDist = Symbolic(#Exponential({rate: 2.0}))
let uniformDist: DistributionTypes.genericDist = Symbolic(#Uniform({low: 9.0, high: 10.0})) let uniformDist: DistributionTypes.genericDist = Symbolic(#Uniform({low: 9.0, high: 10.0}))
let floatDist: DistributionTypes.genericDist = Symbolic(#Float(1e1)) let floatDist: DistributionTypes.genericDist = Symbolic(#Float(1e1))
exception KlFailed
exception MixtureFailed
let float1 = 1.0
let float2 = 2.0
let float3 = 3.0
let {mkDelta} = module(TestHelpers)
let point1 = mkDelta(float1)
let point2 = mkDelta(float2)
let point3 = mkDelta(float3)

View File

@ -1,6 +1,7 @@
open Jest open Jest
open Expect open Expect
open TestHelpers open TestHelpers
open GenericDist_Fixtures
describe("klDivergence: continuous -> continuous -> float", () => { describe("klDivergence: continuous -> continuous -> float", () => {
let klDivergence = DistributionOperation.Constructors.klDivergence(~env) let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
@ -63,21 +64,13 @@ describe("klDivergence: continuous -> continuous -> float", () => {
describe("klDivergence: discrete -> discrete -> float", () => { describe("klDivergence: discrete -> discrete -> float", () => {
let klDivergence = DistributionOperation.Constructors.klDivergence(~env) let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
let mixture = a => DistributionTypes.DistributionOperation.Mixture(a) let mixture = a => DistributionTypes.DistributionOperation.Mixture(a)
exception KlFailed
exception MixtureFailed
let float1 = 1.0
let float2 = 2.0
let float3 = 3.0
let point1 = mkDelta(float1)
let point2 = mkDelta(float2)
let point3 = mkDelta(float3)
let a' = [(point1, 1e0), (point2, 1e0)]->mixture->run let a' = [(point1, 1e0), (point2, 1e0)]->mixture->run
let b' = [(point1, 1e0), (point2, 1e0), (point3, 1e0)]->mixture->run let b' = [(point1, 1e0), (point2, 1e0), (point3, 1e0)]->mixture->run
let (a, b) = switch (a', b') { let (a, b) = switch (a', b') {
| (Dist(a''), Dist(b'')) => (a'', b'') | (Dist(a''), Dist(b'')) => (a'', b'')
| _ => raise(MixtureFailed) | _ => raise(MixtureFailed)
} }
test("is finite", () => { test("agrees with analytical answer when finite", () => {
let prediction = b let prediction = b
let answer = a let answer = a
let kl = klDivergence(prediction, answer) let kl = klDivergence(prediction, answer)
@ -90,7 +83,7 @@ describe("klDivergence: discrete -> discrete -> float", () => {
raise(KlFailed) raise(KlFailed)
} }
}) })
test("is infinite", () => { test("returns infinity when infinite", () => {
let prediction = a let prediction = a
let answer = b let answer = b
let kl = klDivergence(prediction, answer) let kl = klDivergence(prediction, answer)