Gained precision
Value: [1e-3 to 2e-2]
This commit is contained in:
parent
26afc96495
commit
e1e5e3305d
|
@ -11,6 +11,7 @@ let triangularDist: DistributionTypes.genericDist = Symbolic(
|
||||||
)
|
)
|
||||||
let exponentialDist: DistributionTypes.genericDist = Symbolic(#Exponential({rate: 2.0}))
|
let exponentialDist: DistributionTypes.genericDist = Symbolic(#Exponential({rate: 2.0}))
|
||||||
let uniformDist: DistributionTypes.genericDist = Symbolic(#Uniform({low: 9.0, high: 10.0}))
|
let uniformDist: DistributionTypes.genericDist = Symbolic(#Uniform({low: 9.0, high: 10.0}))
|
||||||
|
let uniformDist2: DistributionTypes.genericDist = Symbolic(#Uniform({low: 8.0, high: 11.0}))
|
||||||
let floatDist: DistributionTypes.genericDist = Symbolic(#Float(1e1))
|
let floatDist: DistributionTypes.genericDist = Symbolic(#Float(1e1))
|
||||||
|
|
||||||
exception KlFailed
|
exception KlFailed
|
||||||
|
|
|
@ -121,8 +121,13 @@ describe("klDivergence: mixed -> mixed -> float", () => {
|
||||||
let mixture = a => DistributionTypes.DistributionOperation.Mixture(a)
|
let mixture = a => DistributionTypes.DistributionOperation.Mixture(a)
|
||||||
let a' = [(point1, 1e0), (uniformDist, 1e0)]->mixture->run
|
let a' = [(point1, 1e0), (uniformDist, 1e0)]->mixture->run
|
||||||
let b' = [(point1, 1e0), (floatDist, 1e0), (normalDist10, 1e0)]->mixture->run
|
let b' = [(point1, 1e0), (floatDist, 1e0), (normalDist10, 1e0)]->mixture->run
|
||||||
let (a, b) = switch (a', b') {
|
let c' = [(point1, 1e0), (point2, 1e0), (point3, 1e0), (uniformDist, 1e0)]->mixture->run
|
||||||
| (Dist(a''), Dist(b'')) => (a'', b'')
|
let d' =
|
||||||
|
[(point1, 1e0), (point2, 1e0), (point3, 1e0), (floatDist, 1e0), (uniformDist2, 1e0)]
|
||||||
|
->mixture
|
||||||
|
->run
|
||||||
|
let (a, b, c, d) = switch (a', b', c', d') {
|
||||||
|
| (Dist(a''), Dist(b''), Dist(c''), Dist(d'')) => (a'', b'', c'', d'')
|
||||||
| _ => raise(MixtureFailed)
|
| _ => raise(MixtureFailed)
|
||||||
}
|
}
|
||||||
test("finite klDivergence return is correct", () => {
|
test("finite klDivergence return is correct", () => {
|
||||||
|
@ -130,11 +135,11 @@ describe("klDivergence: mixed -> mixed -> float", () => {
|
||||||
let answer = a
|
let answer = a
|
||||||
let kl = klDivergence(prediction, answer)
|
let kl = klDivergence(prediction, answer)
|
||||||
// high = 10; low = 9; mean = 10; stdev = 2
|
// high = 10; low = 9; mean = 10; stdev = 2
|
||||||
let analyticalKlContinuousPart = klNormalUniform(10.0, 2.0, 9.0, 10.0)
|
let analyticalKlContinuousPart = klNormalUniform(10.0, 2.0, 9.0, 10.0) /. 2.0
|
||||||
let analyticalKlDiscretePart = Js.Math.log(2.0 /. 3.0) /. 2.0
|
let analyticalKlDiscretePart = 1.0 /. 2.0 *. Js.Math.log(2.0 /. 1.0)
|
||||||
switch kl {
|
switch kl {
|
||||||
| Ok(kl') =>
|
| Ok(kl') =>
|
||||||
kl'->expect->toBeSoCloseTo(analyticalKlContinuousPart +. analyticalKlDiscretePart, ~digits=0)
|
kl'->expect->toBeSoCloseTo(analyticalKlContinuousPart +. analyticalKlDiscretePart, ~digits=1)
|
||||||
| Error(err) =>
|
| Error(err) =>
|
||||||
Js.Console.log(DistributionTypes.Error.toString(err))
|
Js.Console.log(DistributionTypes.Error.toString(err))
|
||||||
raise(KlFailed)
|
raise(KlFailed)
|
||||||
|
@ -151,6 +156,20 @@ describe("klDivergence: mixed -> mixed -> float", () => {
|
||||||
raise(KlFailed)
|
raise(KlFailed)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
test("finite klDivergence return is correct", () => {
|
||||||
|
let prediction = d
|
||||||
|
let answer = c
|
||||||
|
let kl = klDivergence(prediction, answer)
|
||||||
|
let analyticalKlContinuousPart = Js.Math.log((11.0 -. 8.0) /. (10.0 -. 9.0)) /. 4.0 // 4 = length of c' array
|
||||||
|
let analyticalKlDiscretePart = 3.0 /. 4.0 *. Js.Math.log(4.0 /. 3.0)
|
||||||
|
switch kl {
|
||||||
|
| Ok(kl') =>
|
||||||
|
kl'->expect->toBeSoCloseTo(analyticalKlContinuousPart +. analyticalKlDiscretePart, ~digits=1)
|
||||||
|
| Error(err) =>
|
||||||
|
Js.Console.log(DistributionTypes.Error.toString(err))
|
||||||
|
raise(KlFailed)
|
||||||
|
}
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
describe("combineAlongSupportOfSecondArgument0", () => {
|
describe("combineAlongSupportOfSecondArgument0", () => {
|
||||||
|
|
Loading…
Reference in New Issue
Block a user