Gained precision
Value: [1e-3 to 2e-2]
This commit is contained in:
parent
26afc96495
commit
e1e5e3305d
|
@ -11,6 +11,7 @@ let triangularDist: DistributionTypes.genericDist = Symbolic(
|
|||
)
|
||||
let exponentialDist: DistributionTypes.genericDist = Symbolic(#Exponential({rate: 2.0}))
|
||||
let uniformDist: DistributionTypes.genericDist = Symbolic(#Uniform({low: 9.0, high: 10.0}))
|
||||
let uniformDist2: DistributionTypes.genericDist = Symbolic(#Uniform({low: 8.0, high: 11.0}))
|
||||
let floatDist: DistributionTypes.genericDist = Symbolic(#Float(1e1))
|
||||
|
||||
exception KlFailed
|
||||
|
|
|
@ -121,8 +121,13 @@ describe("klDivergence: mixed -> mixed -> float", () => {
|
|||
let mixture = a => DistributionTypes.DistributionOperation.Mixture(a)
|
||||
let a' = [(point1, 1e0), (uniformDist, 1e0)]->mixture->run
|
||||
let b' = [(point1, 1e0), (floatDist, 1e0), (normalDist10, 1e0)]->mixture->run
|
||||
let (a, b) = switch (a', b') {
|
||||
| (Dist(a''), Dist(b'')) => (a'', b'')
|
||||
let c' = [(point1, 1e0), (point2, 1e0), (point3, 1e0), (uniformDist, 1e0)]->mixture->run
|
||||
let d' =
|
||||
[(point1, 1e0), (point2, 1e0), (point3, 1e0), (floatDist, 1e0), (uniformDist2, 1e0)]
|
||||
->mixture
|
||||
->run
|
||||
let (a, b, c, d) = switch (a', b', c', d') {
|
||||
| (Dist(a''), Dist(b''), Dist(c''), Dist(d'')) => (a'', b'', c'', d'')
|
||||
| _ => raise(MixtureFailed)
|
||||
}
|
||||
test("finite klDivergence return is correct", () => {
|
||||
|
@ -130,11 +135,11 @@ describe("klDivergence: mixed -> mixed -> float", () => {
|
|||
let answer = a
|
||||
let kl = klDivergence(prediction, answer)
|
||||
// high = 10; low = 9; mean = 10; stdev = 2
|
||||
let analyticalKlContinuousPart = klNormalUniform(10.0, 2.0, 9.0, 10.0)
|
||||
let analyticalKlDiscretePart = Js.Math.log(2.0 /. 3.0) /. 2.0
|
||||
let analyticalKlContinuousPart = klNormalUniform(10.0, 2.0, 9.0, 10.0) /. 2.0
|
||||
let analyticalKlDiscretePart = 1.0 /. 2.0 *. Js.Math.log(2.0 /. 1.0)
|
||||
switch kl {
|
||||
| Ok(kl') =>
|
||||
kl'->expect->toBeSoCloseTo(analyticalKlContinuousPart +. analyticalKlDiscretePart, ~digits=0)
|
||||
kl'->expect->toBeSoCloseTo(analyticalKlContinuousPart +. analyticalKlDiscretePart, ~digits=1)
|
||||
| Error(err) =>
|
||||
Js.Console.log(DistributionTypes.Error.toString(err))
|
||||
raise(KlFailed)
|
||||
|
@ -151,6 +156,20 @@ describe("klDivergence: mixed -> mixed -> float", () => {
|
|||
raise(KlFailed)
|
||||
}
|
||||
})
|
||||
test("finite klDivergence return is correct", () => {
|
||||
let prediction = d
|
||||
let answer = c
|
||||
let kl = klDivergence(prediction, answer)
|
||||
let analyticalKlContinuousPart = Js.Math.log((11.0 -. 8.0) /. (10.0 -. 9.0)) /. 4.0 // 4 = length of c' array
|
||||
let analyticalKlDiscretePart = 3.0 /. 4.0 *. Js.Math.log(4.0 /. 3.0)
|
||||
switch kl {
|
||||
| Ok(kl') =>
|
||||
kl'->expect->toBeSoCloseTo(analyticalKlContinuousPart +. analyticalKlDiscretePart, ~digits=1)
|
||||
| Error(err) =>
|
||||
Js.Console.log(DistributionTypes.Error.toString(err))
|
||||
raise(KlFailed)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe("combineAlongSupportOfSecondArgument0", () => {
|
||||
|
|
Loading…
Reference in New Issue
Block a user