diff --git a/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res b/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res index 96e95899..31f03ae2 100644 --- a/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res +++ b/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res @@ -5,7 +5,6 @@ open GenericDist_Fixtures describe("klDivergence: continuous -> continuous -> float", () => { let klDivergence = DistributionOperation.Constructors.klDivergence(~env) - exception KlFailed let testUniform = (lowAnswer, highAnswer, lowPrediction, highPrediction) => { test("of two uniforms is equal to the analytic expression", () => { @@ -59,6 +58,24 @@ describe("klDivergence: continuous -> continuous -> float", () => { } } }) + + test("of a normal and a uniform is equal to the formula", () => { + let prediction = normalDist10 + let answer = uniformDist + let kl = klDivergence(prediction, answer) + let analyticalKl = + -.Js.Math.log((10.0 -. 9.0) /. Js.Math.sqrt(2.0 *. MagicNumbers.Math.pi *. 2.0 ** 2.0)) +. + 1.0 /. + 2.0 ** 2.0 *. + (10.0 ** 2.0 -. (10.0 +. 9.0) *. 10.0 +. (9.0 ** 2.0 +. 10.0 *. 9.0 +. 10.0 ** 2.0) /. 3.0) + switch kl { + | Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=1) + | Error(err) => { + Js.Console.log(DistributionTypes.Error.toString(err)) + raise(KlFailed) + } + } + }) }) describe("klDivergence: discrete -> discrete -> float", () => { @@ -96,6 +113,47 @@ describe("klDivergence: discrete -> discrete -> float", () => { }) }) +describe("klDivergence: mixed -> mixed -> float", () => { + let klDivergence = DistributionOperation.Constructors.klDivergence(~env) + let mixture = a => DistributionTypes.DistributionOperation.Mixture(a) + let a' = [(floatDist, 1e0), (uniformDist, 1e0)]->mixture->run + let b' = [(point3, 1e0), (floatDist, 1e0), (normalDist10, 1e0)]->mixture->run + let (a, b) = switch (a', b') { + | (Dist(a''), Dist(b'')) => (a'', b'') + | _ => raise(MixtureFailed) + } + test("finite klDivergence returns is correct", () => { + let prediction = b + let answer = a + let kl = klDivergence(prediction, answer) + // high = 10; low = 9; mean = 10; stdev = 2 + let analyticalKlContinuousPart = + Js.Math.log(Js.Math.sqrt(2.0 *. MagicNumbers.Math.pi *. 2.0 ** 2.0) /. (10.0 -. 9.0)) +. + 1.0 /. + 2.0 ** 2.0 *. + (10.0 ** 2.0 -. (10.0 +. 9.0) *. 10.0 +. (9.0 ** 2.0 +. 10.0 *. 9.0 +. 10.0 ** 2.0) /. 3.0) + let analyticalKlDiscretePart = -2.0 /. 3.0 *. Js.Math.log(3.0 /. 2.0) + switch kl { + | Ok(kl') => + kl'->expect->toBeSoCloseTo(analyticalKlContinuousPart +. analyticalKlDiscretePart, ~digits=0) + | Error(err) => + Js.Console.log(DistributionTypes.Error.toString(err)) + raise(KlFailed) + } + }) + test("returns infinity when infinite", () => { + let prediction = a + let answer = b + let kl = klDivergence(prediction, answer) + switch kl { + | Ok(kl') => kl'->expect->toEqual(infinity) + | Error(err) => + Js.Console.log(DistributionTypes.Error.toString(err)) + raise(KlFailed) + } + }) +}) + describe("combineAlongSupportOfSecondArgument0", () => { // This tests the version of the function that we're NOT using. Haven't deleted the test in case we use the code later. test("test on two uniforms", _ => { diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res index 05465728..7bbe2065 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res @@ -302,10 +302,9 @@ module T = Dist({ } let klDivergence = (prediction: t, answer: t) => { - Error(Operation.NotYetImplemented) - // combinePointwise(PointSetDist_Scoring.KLDivergence.integrand, prediction, answer) |> E.R.fmap( - // integralEndY, - // ) + let klDiscretePart = Discrete.T.klDivergence(prediction.discrete, answer.discrete) + let klContinuousPart = Continuous.T.klDivergence(prediction.continuous, answer.continuous) + E.R.merge(klDiscretePart, klContinuousPart)->E.R2.fmap(t => fst(t) +. snd(t)) } }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res index 1879ebdd..db47d1e1 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res @@ -200,6 +200,7 @@ module T = Dist({ switch (t1, t2) { | (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2) | (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2) + | (Mixed(t1), Mixed(t2)) => Mixed.T.klDivergence(t1, t2) | _ => Error(NotYetImplemented) } })