2022-05-05 19:37:28 +00:00
|
|
|
open Jest
|
|
|
|
open Expect
|
|
|
|
open TestHelpers
|
2022-05-10 18:03:42 +00:00
|
|
|
open GenericDist_Fixtures
|
2022-05-05 19:37:28 +00:00
|
|
|
|
2022-05-10 15:56:13 +00:00
|
|
|
describe("klDivergence: continuous -> continuous -> float", () => {
  let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
  exception KlFailed

  // Unwraps a successful KL result. On failure, logs the distribution error and aborts the
  // current test by raising KlFailed.
  let getKlOrFail = result =>
    switch result {
    | Ok(value) => value
    | Error(err) => {
        Js.Console.log(DistributionTypes.Error.toString(err))
        raise(KlFailed)
      }
    }

  // Formats an interval for use in a test title, e.g. "[0, 1]".
  let showInterval = (low, high) =>
    "[" ++ Js.Float.toString(low) ++ ", " ++ Js.Float.toString(high) ++ "]"

  // Registers one test comparing the computed KL divergence of two uniforms with its closed form.
  // The first two arguments are the answer's endpoints; the last two are the prediction's.
  let testUniform = (lowAnswer, highAnswer, lowPrediction, highPrediction) => {
    // Include the endpoints so each registered case is distinguishable in Jest's report.
    let title =
      "of two uniforms is equal to the analytic expression (answer " ++
      showInterval(lowAnswer, highAnswer) ++
      ", prediction " ++
      showInterval(lowPrediction, highPrediction) ++ ")"
    test(title, () => {
      let answer =
        uniformMakeR(lowAnswer, highAnswer)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
      let prediction =
        uniformMakeR(lowPrediction, highPrediction)->E.R2.errMap(s => DistributionTypes.ArgumentError(
          s,
        ))
      // KL(answer || prediction) is the integral, over the support of the answer, of
      // answer.pdf(x) times log of answer.pdf(x) divided by prediction.pdf(x), dx.
      // When the prediction's support covers the answer's, both densities are constant there,
      // so this reduces to log(predictionWidth / answerWidth).
      let analyticalKl = Js.Math.log((highPrediction -. lowPrediction) /. (highAnswer -. lowAnswer))
      let kl = E.R.liftJoin2(klDivergence, prediction, answer)->getKlOrFail
      kl->expect->toBeSoCloseTo(analyticalKl, ~digits=7)
    })
  }

  // The prediction's interval (last two arguments) may be wider than the answer's interval
  // (first two arguments), but not the other way around: wherever the answer has mass and the
  // prediction has none, the KL divergence is infinite.
  testUniform(0.0, 1.0, -1.0, 2.0)
  testUniform(0.0, 1.0, 0.0, 2.0) // equal left endpoints
  testUniform(0.0, 1.0, -1.0, 1.0) // equal right endpoints
  testUniform(0.0, 1e1, 0.0, 1e1) // equal (klDivergence = 0)
  // testUniform(-1.0, 1.0, 0.0, 2.0)

  test("of two normals is equal to the formula", () => {
    // This test case comes via Nuño https://github.com/quantified-uncertainty/squiggle/issues/433
    let mean1 = 4.0
    let mean2 = 1.0
    let stdev1 = 4.0
    let stdev2 = 1.0

    let prediction =
      normalMakeR(mean1, stdev1)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
    let answer = normalMakeR(mean2, stdev2)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
    // Closed form of KL(answer || prediction) for answer = N(mean2, stdev2) and
    // prediction = N(mean1, stdev1):
    // https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians
    let analyticalKl =
      Js.Math.log(stdev1 /. stdev2) +.
      (stdev2 ** 2.0 +. (mean2 -. mean1) ** 2.0) /. (2.0 *. stdev1 ** 2.0) -. 0.5
    let kl = E.R.liftJoin2(klDivergence, prediction, answer)->getKlOrFail
    // NOTE(review): looser tolerance than the uniform cases, presumably because the numeric
    // result for normals is only approximate — confirm.
    kl->expect->toBeSoCloseTo(analyticalKl, ~digits=3)
  })
})
|
2022-05-06 15:45:11 +00:00
|
|
|
|
2022-05-10 15:56:13 +00:00
|
|
|
describe("klDivergence: discrete -> discrete -> float", () => {
  let klDivergence = DistributionOperation.Constructors.klDivergence(~env)

  // Runs the Mixture operation on the given (distribution, weight) pairs.
  let runMixture = weighted => run(DistributionTypes.DistributionOperation.Mixture(weighted))

  // Both mixtures are evaluated before either is inspected.
  let twoPointRun = runMixture([(point1, 1e0), (point2, 1e0)])
  let threePointRun = runMixture([(point1, 1e0), (point2, 1e0), (point3, 1e0)])
  let (twoPoint, threePoint) = switch (twoPointRun, threePointRun) {
  | (Dist(first), Dist(second)) => (first, second)
  | _ => raise(MixtureFailed)
  }

  // Unwraps a KL result. On failure, logs the distribution error and raises KlFailed.
  let klOrFail = outcome =>
    switch outcome {
    | Ok(value) => value
    | Error(err) =>
      Js.Console.log(DistributionTypes.Error.toString(err))
      raise(KlFailed)
    }

  test("agrees with analytical answer when finite", () => {
    // Prediction: the three-point mixture. Answer: the two-point mixture.
    // Sigma_{i \in 1..2} 0.5 * log(0.5 / 0.33333) = log(3 / 2)
    let divergence = klDivergence(threePoint, twoPoint)
    let expected = Js.Math.log(3.0 /. 2.0)
    divergence->klOrFail->expect->toBeSoCloseTo(expected, ~digits=7)
  })

  test("returns infinity when infinite", () => {
    // Prediction: the two-point mixture. Answer: the three-point mixture, which also includes
    // point3, where the prediction has no mass.
    let divergence = klDivergence(twoPoint, threePoint)
    divergence->klOrFail->expect->toEqual(infinity)
  })
})
|
|
|
|
|
2022-05-10 15:56:13 +00:00
|
|
|
describe("combineAlongSupportOfSecondArgument0", () => {
  // This exercises the variant of the combinator that is NOT currently used; the test is kept
  // in case that code is used later.
  test("test on two uniforms", _ => {
    let combine = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument0

    // Builds a uniform distribution, lifting a construction failure into an ArgumentError.
    let makeUniform = (low, high) =>
      uniformMakeR(low, high)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
    // Converts a distribution to its point-set representation via the operation runner.
    let toPointSet = dist => run(FromDist(ToDist(ToPointSet), dist))

    // Answer: uniform on [0, 1]. Prediction: uniform on [0, 2].
    let answer = makeUniform(0.0, 1.0)
    let prediction = makeUniform(0.0, 2.0)
    let answerPointSet = E.R.fmap(toPointSet, answer)
    let predictionPointSet = E.R.fmap(toPointSet, prediction)

    let interpolator = XYShape.XtoY.continuousInterpolator(#Stepwise, #UseZero)
    let integrand = PointSetDist_Scoring.KLDivergence.integrand

    // Only continuous point sets are combined; anything else yields None.
    let combined = switch (answerPointSet, predictionPointSet) {
    | (
        Ok(Dist(PointSet(Continuous(answerShape)))),
        Ok(Dist(PointSet(Continuous(predictionShape)))),
      ) =>
      Some(combine(integrand, interpolator, answerShape.xyShape, predictionShape.xyShape))
    | _ => None
    }

    // Expected shape: five x values in [0, 1] all mapping to the same value, then one x just
    // above 1 mapping to infinity.
    let eps = MagicNumbers.Epsilon.ten
    let inside = -0.34657359027997264
    combined
    ->expect
    ->toEqual(
      Some(
        Ok({
          xs: [0.0, eps, 2.0 *. eps, 1.0 -. eps, 1.0, 1.0 +. eps],
          ys: [inside, inside, inside, inside, inside, infinity],
        }),
      ),
    )
  })
})
|