Gained precision

Value: [1e-3 to 2e-2]
Quinn Dougherty 2022-05-11 15:46:57 -04:00
parent 26afc96495
commit e1e5e3305d
2 changed files with 25 additions and 5 deletions

View File

@@ -11,6 +11,7 @@ let triangularDist: DistributionTypes.genericDist = Symbolic(
 )
 let exponentialDist: DistributionTypes.genericDist = Symbolic(#Exponential({rate: 2.0}))
 let uniformDist: DistributionTypes.genericDist = Symbolic(#Uniform({low: 9.0, high: 10.0}))
+let uniformDist2: DistributionTypes.genericDist = Symbolic(#Uniform({low: 8.0, high: 11.0}))
 let floatDist: DistributionTypes.genericDist = Symbolic(#Float(1e1))
 exception KlFailed

View File

@@ -121,8 +121,13 @@ describe("klDivergence: mixed -> mixed -> float", () => {
   let mixture = a => DistributionTypes.DistributionOperation.Mixture(a)
   let a' = [(point1, 1e0), (uniformDist, 1e0)]->mixture->run
   let b' = [(point1, 1e0), (floatDist, 1e0), (normalDist10, 1e0)]->mixture->run
-  let (a, b) = switch (a', b') {
-  | (Dist(a''), Dist(b'')) => (a'', b'')
+  let c' = [(point1, 1e0), (point2, 1e0), (point3, 1e0), (uniformDist, 1e0)]->mixture->run
+  let d' =
+    [(point1, 1e0), (point2, 1e0), (point3, 1e0), (floatDist, 1e0), (uniformDist2, 1e0)]
+    ->mixture
+    ->run
+  let (a, b, c, d) = switch (a', b', c', d') {
+  | (Dist(a''), Dist(b''), Dist(c''), Dist(d'')) => (a'', b'', c'', d'')
   | _ => raise(MixtureFailed)
   }
   test("finite klDivergence return is correct", () => {
@@ -130,11 +135,11 @@ describe("klDivergence: mixed -> mixed -> float", () => {
     let answer = a
     let kl = klDivergence(prediction, answer)
     // high = 10; low = 9; mean = 10; stdev = 2
-    let analyticalKlContinuousPart = klNormalUniform(10.0, 2.0, 9.0, 10.0)
-    let analyticalKlDiscretePart = Js.Math.log(2.0 /. 3.0) /. 2.0
+    let analyticalKlContinuousPart = klNormalUniform(10.0, 2.0, 9.0, 10.0) /. 2.0
+    let analyticalKlDiscretePart = 1.0 /. 2.0 *. Js.Math.log(2.0 /. 1.0)
     switch kl {
     | Ok(kl') =>
-      kl'->expect->toBeSoCloseTo(analyticalKlContinuousPart +. analyticalKlDiscretePart, ~digits=0)
+      kl'->expect->toBeSoCloseTo(analyticalKlContinuousPart +. analyticalKlDiscretePart, ~digits=1)
     | Error(err) =>
       Js.Console.log(DistributionTypes.Error.toString(err))
       raise(KlFailed)
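Note: the updated expected value can be checked by hand. A minimal sketch, assuming klDivergence compares the discrete and continuous parts separately, normalising each part and weighting it by the answer's mass on that part (a convention inferred from the expected values above, not from the implementation), and assuming klNormalUniform(mean, stdev, low, high) denotes D_{KL}(\mathrm{Uniform}(low, high) \,\|\, \mathrm{Normal}(mean, stdev)):

discrete part: the answer a has a single atom (point1, treated here as a point mass) with normalised weight 1 and overall discrete mass 1/2, while the prediction b splits its discrete mass evenly between point1 and Float(10), so

\frac{1}{2} \cdot 1 \cdot \log\frac{1}{1/2} = \frac{1}{2}\log 2

continuous part: the answer's continuous component Uniform(9, 10) carries mass 1/2, giving

\frac{1}{2}\, D_{KL}\big(\mathrm{Uniform}(9, 10) \,\|\, \mathrm{Normal}(10, 2)\big) = \frac{1}{2}\,\mathtt{klNormalUniform}(10.0,\, 2.0,\, 9.0,\, 10.0)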
@@ -151,6 +156,20 @@ describe("klDivergence: mixed -> mixed -> float", () => {
       raise(KlFailed)
     }
   })
+  test("finite klDivergence return is correct", () => {
+    let prediction = d
+    let answer = c
+    let kl = klDivergence(prediction, answer)
+    let analyticalKlContinuousPart = Js.Math.log((11.0 -. 8.0) /. (10.0 -. 9.0)) /. 4.0 // 4 = length of c' array
+    let analyticalKlDiscretePart = 3.0 /. 4.0 *. Js.Math.log(4.0 /. 3.0)
+    switch kl {
+    | Ok(kl') =>
+      kl'->expect->toBeSoCloseTo(analyticalKlContinuousPart +. analyticalKlDiscretePart, ~digits=1)
+    | Error(err) =>
+      Js.Console.log(DistributionTypes.Error.toString(err))
+      raise(KlFailed)
+    }
+  })
 })
 describe("combineAlongSupportOfSecondArgument0", () => {