Some cleanup

Value: [1e-4 to 1e-2]
This commit is contained in:
Quinn Dougherty 2022-05-10 11:56:13 -04:00
parent 15f1ebb429
commit f7690c33e0
2 changed files with 19 additions and 23 deletions

View File

@ -2,7 +2,7 @@ open Jest
open Expect open Expect
open TestHelpers open TestHelpers
describe("kl divergence on continuous distributions", () => { describe("klDivergence: continuous -> continuous -> float", () => {
let klDivergence = DistributionOperation.Constructors.klDivergence(~env) let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
exception KlFailed exception KlFailed
@ -60,7 +60,7 @@ describe("kl divergence on continuous distributions", () => {
}) })
}) })
describe("kl divergence on discrete distributions", () => { describe("klDivergence: discrete -> discrete -> float", () => {
let klDivergence = DistributionOperation.Constructors.klDivergence(~env) let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
let mixture = a => DistributionTypes.DistributionOperation.Mixture(a) let mixture = a => DistributionTypes.DistributionOperation.Mixture(a)
exception KlFailed exception KlFailed
@ -71,13 +71,17 @@ describe("kl divergence on discrete distributions", () => {
let point1 = mkDelta(float1) let point1 = mkDelta(float1)
let point2 = mkDelta(float2) let point2 = mkDelta(float2)
let point3 = mkDelta(float3) let point3 = mkDelta(float3)
test("finite kl divergence", () => { let a' = [(point1, 1e0), (point2, 1e0)]->mixture->run
let answer = [(point1, 1e0), (point2, 1e0)]->mixture->run let b' = [(point1, 1e0), (point2, 1e0), (point3, 1e0)]->mixture->run
let prediction = [(point1, 1e0), (point2, 1e0), (point3, 1e0)]->mixture->run let (a, b) = switch (a', b') {
let kl = switch (prediction, answer) { | (Dist(a''), Dist(b'')) => (a'', b'')
| (Dist(prediction'), Dist(answer')) => klDivergence(prediction', answer')
| _ => raise(MixtureFailed) | _ => raise(MixtureFailed)
} }
test("is finite", () => {
let prediction = b
let answer = a
let kl = klDivergence(prediction, answer)
// Sigma_{i \in 1..2} 0.5 * log(0.5 / 0.33333)
let analyticalKl = Js.Math.log(3.0 /. 2.0) let analyticalKl = Js.Math.log(3.0 /. 2.0)
switch kl { switch kl {
| Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=7) | Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=7)
@ -86,13 +90,10 @@ describe("kl divergence on discrete distributions", () => {
raise(KlFailed) raise(KlFailed)
} }
}) })
test("infinite kl divergence", () => { test("is infinite", () => {
let prediction = [(point1, 1e0), (point2, 1e0)]->mixture->run let prediction = a
let answer = [(point1, 1e0), (point2, 1e0), (point3, 1e0)]->mixture->run let answer = b
let kl = switch (prediction, answer) { let kl = klDivergence(prediction, answer)
| (Dist(prediction'), Dist(answer')) => klDivergence(prediction', answer')
| _ => raise(MixtureFailed)
}
switch kl { switch kl {
| Ok(kl') => kl'->expect->toEqual(infinity) | Ok(kl') => kl'->expect->toEqual(infinity)
| Error(err) => | Error(err) =>
@ -102,7 +103,7 @@ describe("kl divergence on discrete distributions", () => {
}) })
}) })
describe("combineAlongSupportOfSecondArgument", () => { describe("combineAlongSupportOfSecondArgument0", () => {
// This tests the version of the function that we're NOT using. Haven't deleted the test in case we use the code later. // This tests the version of the function that we're NOT using. Haven't deleted the test in case we use the code later.
test("test on two uniforms", _ => { test("test on two uniforms", _ => {
let combineAlongSupportOfSecondArgument = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument0 let combineAlongSupportOfSecondArgument = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument0

View File

@ -48,12 +48,7 @@ let combinePointwise = (
// TODO: does it ever make sense to pointwise combine the integrals here? // TODO: does it ever make sense to pointwise combine the integrals here?
// It could be done for pointwise additions, but is that ever needed? // It could be done for pointwise additions, but is that ever needed?
make( combiner(fn, XYShape.XtoY.discreteInterpolator, t1.xyShape, t2.xyShape)->E.R2.fmap(make)
combiner(fn, XYShape.XtoY.discreteInterpolator, t1.xyShape, t2.xyShape)->E.R.toExn(
"Logically unreachable?",
_,
),
)->Ok
} }
let reduce = ( let reduce = (