klDivergence on mixed distributions works for one test case
Value: [1e-4 to 5e-2]
parent 4c68b1d9b8
commit 0b8da034c6
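For orientation (a sketch added alongside the commit, not text from it): writing the prediction p and the answer a as mixed measures, each with a discrete part and a continuous part, the quantity the new code computes is the sum of a discrete and a continuous Kullback–Leibler term, roughly

$$D_{KL}(a \parallel p) \;=\; \sum_x a_d(x)\,\ln\frac{a_d(x)}{p_d(x)} \;+\; \int a_c(x)\,\ln\frac{a_c(x)}{p_c(x)}\,dx,$$

up to how the component masses are normalized inside the library.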
@@ -5,7 +5,6 @@ open GenericDist_Fixtures
 
 describe("klDivergence: continuous -> continuous -> float", () => {
   let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
-  exception KlFailed
 
   let testUniform = (lowAnswer, highAnswer, lowPrediction, highPrediction) => {
     test("of two uniforms is equal to the analytic expression", () => {
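For reference, the analytic expression the two-uniform test refers to is the standard one (a sketch, not part of the diff): when the answer U(l_a, h_a) is supported inside the prediction U(l_p, h_p),

$$D_{KL}(\text{answer} \parallel \text{prediction}) = \ln\frac{h_p - l_p}{h_a - l_a},$$

and the divergence is infinite when the answer's support is not contained in the prediction's.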
@@ -59,6 +58,24 @@ describe("klDivergence: continuous -> continuous -> float", () => {
       }
     }
   })
+
+  test("of a normal and a uniform is equal to the formula", () => {
+    let prediction = normalDist10
+    let answer = uniformDist
+    let kl = klDivergence(prediction, answer)
+    let analyticalKl =
+      -.Js.Math.log((10.0 -. 9.0) /. Js.Math.sqrt(2.0 *. MagicNumbers.Math.pi *. 2.0 ** 2.0)) +.
+      1.0 /.
+      2.0 ** 2.0 *.
+      (10.0 ** 2.0 -. (10.0 +. 9.0) *. 10.0 +. (9.0 ** 2.0 +. 10.0 *. 9.0 +. 10.0 ** 2.0) /. 3.0)
+    switch kl {
+    | Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=1)
+    | Error(err) => {
+        Js.Console.log(DistributionTypes.Error.toString(err))
+        raise(KlFailed)
+      }
+    }
+  })
 })
 
 describe("klDivergence: discrete -> discrete -> float", () => {
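As a cross-check on the expression above (a sketch under the assumption that klDivergence(prediction, answer) computes D_KL(answer ‖ prediction)): with answer U(9, 10) and prediction N(10, 2²), the textbook closed form is

$$D_{KL}\big(U(l,h) \parallel N(\mu,\sigma^2)\big) = \ln\frac{\sigma\sqrt{2\pi}}{h-l} + \frac{1}{2\sigma^2}\Big(\mu^2 - (l+h)\mu + \frac{l^2 + lh + h^2}{3}\Big),$$

which for l = 9, h = 10, μ = 10, σ = 2 is about ln(2√(2π)) + 1/24 ≈ 1.65; the test only checks agreement to one decimal digit (~digits=1).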
@@ -96,6 +113,47 @@ describe("klDivergence: discrete -> discrete -> float", () => {
   })
 })
 
+describe("klDivergence: mixed -> mixed -> float", () => {
+  let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
+  let mixture = a => DistributionTypes.DistributionOperation.Mixture(a)
+  let a' = [(floatDist, 1e0), (uniformDist, 1e0)]->mixture->run
+  let b' = [(point3, 1e0), (floatDist, 1e0), (normalDist10, 1e0)]->mixture->run
+  let (a, b) = switch (a', b') {
+  | (Dist(a''), Dist(b'')) => (a'', b'')
+  | _ => raise(MixtureFailed)
+  }
+  test("finite klDivergence returns is correct", () => {
+    let prediction = b
+    let answer = a
+    let kl = klDivergence(prediction, answer)
+    // high = 10; low = 9; mean = 10; stdev = 2
+    let analyticalKlContinuousPart =
+      Js.Math.log(Js.Math.sqrt(2.0 *. MagicNumbers.Math.pi *. 2.0 ** 2.0) /. (10.0 -. 9.0)) +.
+      1.0 /.
+      2.0 ** 2.0 *.
+      (10.0 ** 2.0 -. (10.0 +. 9.0) *. 10.0 +. (9.0 ** 2.0 +. 10.0 *. 9.0 +. 10.0 ** 2.0) /. 3.0)
+    let analyticalKlDiscretePart = -2.0 /. 3.0 *. Js.Math.log(3.0 /. 2.0)
+    switch kl {
+    | Ok(kl') =>
+      kl'->expect->toBeSoCloseTo(analyticalKlContinuousPart +. analyticalKlDiscretePart, ~digits=0)
+    | Error(err) =>
+      Js.Console.log(DistributionTypes.Error.toString(err))
+      raise(KlFailed)
+    }
+  })
+  test("returns infinity when infinite", () => {
+    let prediction = a
+    let answer = b
+    let kl = klDivergence(prediction, answer)
+    switch kl {
+    | Ok(kl') => kl'->expect->toEqual(infinity)
+    | Error(err) =>
+      Js.Console.log(DistributionTypes.Error.toString(err))
+      raise(KlFailed)
+    }
+  })
+})
+
 describe("combineAlongSupportOfSecondArgument0", () => {
   // This tests the version of the function that we're NOT using. Haven't deleted the test in case we use the code later.
   test("test on two uniforms", _ => {
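A note on the second mixed test above (reasoning added here, not from the commit): with the roles reversed, the answer b presumably puts mass where the prediction a has none — the normal component extends outside the uniform's [9, 10] support, and point3 looks like a point the prediction's discrete part never hits — and Kullback–Leibler divergence is infinite in exactly that situation:

$$D_{KL}(b \parallel a) = \infty \quad \text{whenever } b \text{ assigns positive mass to a region where } a \text{ assigns none.}$$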
@@ -302,10 +302,9 @@ module T = Dist({
   }
 
   let klDivergence = (prediction: t, answer: t) => {
-    Error(Operation.NotYetImplemented)
-    // combinePointwise(PointSetDist_Scoring.KLDivergence.integrand, prediction, answer) |> E.R.fmap(
-    //   integralEndY,
-    // )
+    let klDiscretePart = Discrete.T.klDivergence(prediction.discrete, answer.discrete)
+    let klContinuousPart = Continuous.T.klDivergence(prediction.continuous, answer.continuous)
+    E.R.merge(klDiscretePart, klContinuousPart)->E.R2.fmap(t => fst(t) +. snd(t))
   }
 })
 
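The combination step above merges two result values and sums their payloads. E.R.merge and E.R2.fmap are squiggle-internal helpers; as a rough, assumed equivalent of what that line does (a sketch, not the library's definitions):

// Assumed behaviour of the merge-then-sum step: succeed with the sum of both
// KL parts, or propagate an error if either part failed.
let mergeAndSum = (a: result<float, 'e>, b: result<float, 'e>): result<float, 'e> =>
  switch (a, b) {
  | (Ok(x), Ok(y)) => Ok(x +. y)
  | (Error(e), _) | (_, Error(e)) => Error(e)
  }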
@@ -200,6 +200,7 @@ module T = Dist({
     switch (t1, t2) {
     | (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2)
     | (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2)
+    | (Mixed(t1), Mixed(t2)) => Mixed.T.klDivergence(t1, t2)
     | _ => Error(NotYetImplemented)
     }
   })