From 95fe117ef00d5a1fdbe413649ee55d21a0b14e7c Mon Sep 17 00:00:00 2001
From: Quinn Dougherty
Date: Wed, 11 May 2022 14:14:10 -0400
Subject: [PATCH] Factored continuous part of normal and uniform kldivergence
 into its own function

Value: [1e-4 to 1e-3]
---
 .../Distributions/KlDivergence_test.res       | 22 +++++++++----------
 1 file changed, 10 insertions(+), 12 deletions(-)

diff --git a/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res b/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res
index e3efe446..03652738 100644
--- a/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res
+++ b/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res
@@ -3,6 +3,12 @@ open Expect
 open TestHelpers
 open GenericDist_Fixtures
 
+let klNormalUniform = (mean, stdev, low, high): float =>
+  -.Js.Math.log((high -. low) /. Js.Math.sqrt(2.0 *. MagicNumbers.Math.pi *. stdev ** 2.0)) +.
+  1.0 /.
+  stdev ** 2.0 *.
+  (mean ** 2.0 -. (high +. low) *. mean +. (low ** 2.0 +. high *. low +. high ** 2.0) /. 3.0)
+
 describe("klDivergence: continuous -> continuous -> float", () => {
   let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
 
@@ -63,11 +69,7 @@ describe("klDivergence: continuous -> continuous -> float", () => {
     let prediction = normalDist10
     let answer = uniformDist
     let kl = klDivergence(prediction, answer)
-    let analyticalKl =
-      -.Js.Math.log((10.0 -. 9.0) /. Js.Math.sqrt(2.0 *. MagicNumbers.Math.pi *. 2.0 ** 2.0)) +.
-      1.0 /.
-      2.0 ** 2.0 *.
-      (10.0 ** 2.0 -. (10.0 +. 9.0) *. 10.0 +. (9.0 ** 2.0 +. 10.0 *. 9.0 +. 10.0 ** 2.0) /. 3.0)
+    let analyticalKl = klNormalUniform(10.0, 2.0, 9.0, 10.0)
     switch kl {
     | Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=3)
     | Error(err) => {
@@ -122,17 +124,13 @@ describe("klDivergence: mixed -> mixed -> float", () => {
   | (Dist(a''), Dist(b'')) => (a'', b'')
   | _ => raise(MixtureFailed)
   }
-  test("finite klDivergence returns is correct", () => {
+  test("finite klDivergence return is correct", () => {
     let prediction = b
     let answer = a
     let kl = klDivergence(prediction, answer)
     // high = 10; low = 9; mean = 10; stdev = 2
-    let analyticalKlContinuousPart =
-      Js.Math.log(Js.Math.sqrt(2.0 *. MagicNumbers.Math.pi *. 2.0 ** 2.0) /. (10.0 -. 9.0)) +.
-      1.0 /.
-      2.0 ** 2.0 *.
-      (10.0 ** 2.0 -. (10.0 +. 9.0) *. 10.0 +. (9.0 ** 2.0 +. 10.0 *. 9.0 +. 10.0 ** 2.0) /. 3.0)
-    let analyticalKlDiscretePart = -2.0 /. 3.0 *. Js.Math.log(3.0 /. 2.0)
+    let analyticalKlContinuousPart = klNormalUniform(10.0, 2.0, 9.0, 10.0)
+    let analyticalKlDiscretePart = 2.0 /. 3.0 *. Js.Math.log(2.0 /. 3.0)
     switch kl {
     | Ok(kl') =>
       kl'->expect->toBeSoCloseTo(analyticalKlContinuousPart +. analyticalKlDiscretePart, ~digits=0)
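
Note (editorial sketch, not part of the patch): the closed form that klNormalUniform
encodes can be derived directly. squiggle's klDivergence(prediction, answer) integrates
answer(x) * log(answer(x) / prediction(x)), so with answer U(l, h) and prediction
N(mu, sigma^2), writing l = low, h = high, mu = mean, sigma = stdev:

  D_{KL}\bigl(U(l,h)\,\|\,\mathcal{N}(\mu,\sigma^2)\bigr)
    = -\log\frac{h - l}{\sigma\sqrt{2\pi}}
      + \frac{1}{2\sigma^2}\left(\mu^2 - (h + l)\mu + \frac{l^2 + hl + h^2}{3}\right),

using E[(X - mu)^2] = mu^2 - (h + l) mu + (l^2 + hl + h^2)/3 for X ~ U(l, h). The
textbook coefficient on the quadratic term is 1/(2 sigma^2), whereas the helper, like
the inline expressions it replaces, multiplies by 1/sigma^2; the refactor is
value-preserving either way, but that factor of two may be worth double-checking
against the numerical klDivergence result.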
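
Note (editorial, not part of the patch): the rewritten discrete constant in the third
hunk changes only the form, not the value, since -log(3/2) = log(2/3):

  -\frac{2}{3}\log\frac{3}{2} \;=\; \frac{2}{3}\log\frac{2}{3} \approx -0.27.

The new form matches the p log(p/q) shape of a discrete KL term (for example, answer
mass p = 2/3 against prediction mass q = 1 at a shared point); the actual mixture
weights come from GenericDist_Fixtures rather than from this identity.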
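
A possible follow-up (hypothetical sketch, not in this patch) is to pin klNormalUniform
directly to the expression it replaced, so later edits to the helper cannot silently
drift from the original inline formula. This assumes the same opens (Jest, Expect,
TestHelpers) and the toBeSoCloseTo helper already used in this file:

  // Hypothetical regression test: klNormalUniform(10.0, 2.0, 9.0, 10.0) should
  // reproduce the literal expression deleted in the second hunk.
  test("klNormalUniform reproduces the previously inlined expression", () => {
    let expected =
      -.Js.Math.log((10.0 -. 9.0) /. Js.Math.sqrt(2.0 *. MagicNumbers.Math.pi *. 2.0 ** 2.0)) +.
      1.0 /.
      2.0 ** 2.0 *.
      (10.0 ** 2.0 -. (10.0 +. 9.0) *. 10.0 +. (9.0 ** 2.0 +. 10.0 *. 9.0 +. 10.0 ** 2.0) /. 3.0)
    klNormalUniform(10.0, 2.0, 9.0, 10.0)->expect->toBeSoCloseTo(expected, ~digits=7)
  })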