From f7690c33e0226754a91a83ff71539c8a5079bbb7 Mon Sep 17 00:00:00 2001 From: Quinn Dougherty Date: Tue, 10 May 2022 11:56:13 -0400 Subject: [PATCH] Some cleanup Value: [1e-4 to 1e-2] --- .../Distributions/KlDivergence_test.res | 35 ++++++++++--------- .../Distributions/PointSetDist/Discrete.res | 7 +--- 2 files changed, 19 insertions(+), 23 deletions(-) diff --git a/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res b/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res index 3ed1788b..8eb9974c 100644 --- a/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res +++ b/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res @@ -2,7 +2,7 @@ open Jest open Expect open TestHelpers -describe("kl divergence on continuous distributions", () => { +describe("klDivergence: continuous -> continuous -> float", () => { let klDivergence = DistributionOperation.Constructors.klDivergence(~env) exception KlFailed @@ -60,7 +60,7 @@ describe("kl divergence on continuous distributions", () => { }) }) -describe("kl divergence on discrete distributions", () => { +describe("klDivergence: discrete -> discrete -> float", () => { let klDivergence = DistributionOperation.Constructors.klDivergence(~env) let mixture = a => DistributionTypes.DistributionOperation.Mixture(a) exception KlFailed @@ -71,13 +71,17 @@ describe("kl divergence on discrete distributions", () => { let point1 = mkDelta(float1) let point2 = mkDelta(float2) let point3 = mkDelta(float3) - test("finite kl divergence", () => { - let answer = [(point1, 1e0), (point2, 1e0)]->mixture->run - let prediction = [(point1, 1e0), (point2, 1e0), (point3, 1e0)]->mixture->run - let kl = switch (prediction, answer) { - | (Dist(prediction'), Dist(answer')) => klDivergence(prediction', answer') - | _ => raise(MixtureFailed) - } + let a' = [(point1, 1e0), (point2, 1e0)]->mixture->run + let b' = [(point1, 1e0), (point2, 1e0), (point3, 1e0)]->mixture->run + let (a, b) = switch (a', b') { + | (Dist(a''), Dist(b'')) => (a'', b'') + | _ => raise(MixtureFailed) + } + test("is finite", () => { + let prediction = b + let answer = a + let kl = klDivergence(prediction, answer) + // Sigma_{i \in 1..2} 0.5 * log(0.5 / 0.33333) let analyticalKl = Js.Math.log(3.0 /. 2.0) switch kl { | Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=7) @@ -86,13 +90,10 @@ describe("kl divergence on discrete distributions", () => { raise(KlFailed) } }) - test("infinite kl divergence", () => { - let prediction = [(point1, 1e0), (point2, 1e0)]->mixture->run - let answer = [(point1, 1e0), (point2, 1e0), (point3, 1e0)]->mixture->run - let kl = switch (prediction, answer) { - | (Dist(prediction'), Dist(answer')) => klDivergence(prediction', answer') - | _ => raise(MixtureFailed) - } + test("is infinite", () => { + let prediction = a + let answer = b + let kl = klDivergence(prediction, answer) switch kl { | Ok(kl') => kl'->expect->toEqual(infinity) | Error(err) => @@ -102,7 +103,7 @@ describe("kl divergence on discrete distributions", () => { }) }) -describe("combineAlongSupportOfSecondArgument", () => { +describe("combineAlongSupportOfSecondArgument0", () => { // This tests the version of the function that we're NOT using. Haven't deleted the test in case we use the code later. test("test on two uniforms", _ => { let combineAlongSupportOfSecondArgument = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument0 diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res index 04265d7c..abb6b793 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res @@ -48,12 +48,7 @@ let combinePointwise = ( // TODO: does it ever make sense to pointwise combine the integrals here? // It could be done for pointwise additions, but is that ever needed? - make( - combiner(fn, XYShape.XtoY.discreteInterpolator, t1.xyShape, t2.xyShape)->E.R.toExn( - "Logically unreachable?", - _, - ), - )->Ok + combiner(fn, XYShape.XtoY.discreteInterpolator, t1.xyShape, t2.xyShape)->E.R2.fmap(make) } let reduce = (