From d9a40c973ae836ffc0347a5c57f644692a3ef0dc Mon Sep 17 00:00:00 2001 From: NunoSempere Date: Fri, 6 May 2022 12:26:51 -0400 Subject: [PATCH] feat: Get KL divergence working except in case of numerical errors () - Quinn was of great help here. - I also left some dead code, which still has to be cleaned up - There are still very annoying numerical errors, so I left one test failing. These are due to how the interpolation is done - Quinn to pick up from here Value: [0.6 to 2] --- .../Distributions/KlDivergence_test.res | 85 +++++++++++++------ packages/squiggle-lang/package.json | 1 + .../Distributions/PointSetDist/Continuous.res | 25 +++++- .../src/rescript/Utility/XYShape.res | 7 +- 4 files changed, 86 insertions(+), 32 deletions(-) diff --git a/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res b/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res index b749eb11..be884727 100644 --- a/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res +++ b/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res @@ -8,7 +8,7 @@ describe("kl divergence", () => { test("of two uniforms is equal to the analytic expression", () => { let lowAnswer = 0.0 let highAnswer = 1.0 - let lowPrediction = 0.0 + let lowPrediction = -1.0 let highPrediction = 2.0 let answer = uniformMakeR(lowAnswer, highAnswer)->E.R2.errMap(s => DistributionTypes.ArgumentError(s)) @@ -29,20 +29,50 @@ describe("kl divergence", () => { } } }) + test( + "of two uniforms is equal to the analytic expression, part 2 (annoying numerical errors)", + () => { + Js.Console.log( + "This will fait because of extremely annoying numerical errors. Will not fail if the two uniforms are a bit different. Very annoying", + ) + let lowAnswer = 0.0 + let highAnswer = 1.0 + let lowPrediction = 0.0 + let highPrediction = 2.0 + let answer = + uniformMakeR(lowAnswer, highAnswer)->E.R2.errMap(s => DistributionTypes.ArgumentError(s)) + let prediction = + uniformMakeR( + lowPrediction, + highPrediction, + )->E.R2.errMap(s => DistributionTypes.ArgumentError(s)) + // integral along the support of the answer of answer.pdf(x) times log of prediction.pdf(x) divided by answer.pdf(x) dx + let analyticalKl = Js.Math.log((highPrediction -. lowPrediction) /. (highAnswer -. lowAnswer)) + let kl = E.R.liftJoin2(klDivergence, prediction, answer) + Js.Console.log2("Analytical: ", analyticalKl) + Js.Console.log2("Computed: ", kl) + switch kl { + | Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl) + | Error(err) => { + Js.Console.log(DistributionTypes.Error.toString(err)) + raise(KlFailed) + } + } + }, + ) test("of two normals is equal to the formula", () => { // This test case comes via Nuño https://github.com/quantified-uncertainty/squiggle/issues/433 let mean1 = 4.0 let mean2 = 1.0 - let stdev1 = 1.0 - let stdev2 = 4.0 + let stdev1 = 4.0 + let stdev2 = 1.0 let prediction = normalMakeR(mean1, stdev1)->E.R2.errMap(s => DistributionTypes.ArgumentError(s)) let answer = normalMakeR(mean2, stdev2)->E.R2.errMap(s => DistributionTypes.ArgumentError(s)) let analyticalKl = - Js.Math.log(stdev2 /. stdev1) +. - stdev1 ** 2.0 /. 2.0 /. stdev2 ** 2.0 +. - (mean1 -. mean2) ** 2.0 /. 2.0 /. stdev2 ** 2.0 -. 0.5 + Js.Math.log(stdev1 /. stdev2) +. + (stdev2 ** 2.0 +. (mean2 -. mean1) ** 2.0) /. (2.0 *. stdev1 ** 2.0) -. 0.5 let kl = E.R.liftJoin2(klDivergence, prediction, answer) Js.Console.log2("Analytical: ", analyticalKl) @@ -59,30 +89,31 @@ describe("kl divergence", () => { }) describe("combine along support test", () => { - let combineAlongSupportOfSecondArgument = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument - let lowAnswer = 0.0 - let highAnswer = 1.0 - let lowPrediction = -1.0 - let highPrediction = 2.0 + Skip.test("combine along support test", _ => { + // doesn't matter + let combineAlongSupportOfSecondArgument = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument + let lowAnswer = 0.0 + let highAnswer = 1.0 + let lowPrediction = 0.0 + let highPrediction = 2.0 - let answer = - uniformMakeR(lowAnswer, highAnswer)->E.R2.errMap(s => DistributionTypes.ArgumentError(s)) - let prediction = - uniformMakeR(lowPrediction, highPrediction)->E.R2.errMap(s => DistributionTypes.ArgumentError( - s, - )) - let answerWrapped = E.R.fmap(a => run(FromDist(ToDist(ToPointSet), a)), answer) - let predictionWrapped = E.R.fmap(a => run(FromDist(ToDist(ToPointSet), a)), prediction) + let answer = + uniformMakeR(lowAnswer, highAnswer)->E.R2.errMap(s => DistributionTypes.ArgumentError(s)) + let prediction = + uniformMakeR(lowPrediction, highPrediction)->E.R2.errMap(s => DistributionTypes.ArgumentError( + s, + )) + let answerWrapped = E.R.fmap(a => run(FromDist(ToDist(ToPointSet), a)), answer) + let predictionWrapped = E.R.fmap(a => run(FromDist(ToDist(ToPointSet), a)), prediction) - let interpolator = XYShape.XtoY.continuousInterpolator(#Stepwise, #UseZero) - let integrand = PointSetDist_Scoring.KLDivergence.integrand + let interpolator = XYShape.XtoY.continuousInterpolator(#Stepwise, #UseZero) + let integrand = PointSetDist_Scoring.KLDivergence.integrand - let result = switch (answerWrapped, predictionWrapped) { - | (Ok(Dist(PointSet(Continuous(a)))), Ok(Dist(PointSet(Continuous(b))))) => - Some(combineAlongSupportOfSecondArgument(integrand, interpolator, a.xyShape, b.xyShape)) - | _ => None - } - test("combine along support test", _ => { + let result = switch (answerWrapped, predictionWrapped) { + | (Ok(Dist(PointSet(Continuous(a)))), Ok(Dist(PointSet(Continuous(b))))) => + Some(combineAlongSupportOfSecondArgument(integrand, interpolator, a.xyShape, b.xyShape)) + | _ => None + } Js.Console.log2("combineAlongSupportOfSecondArgument", result) false->expect->toBe(true) }) diff --git a/packages/squiggle-lang/package.json b/packages/squiggle-lang/package.json index 97b710b6..8f3c04a8 100644 --- a/packages/squiggle-lang/package.json +++ b/packages/squiggle-lang/package.json @@ -15,6 +15,7 @@ "test": "jest", "test:ts": "jest __tests__/TS/", "test:rescript": "jest --modulePathIgnorePatterns=__tests__/TS/*", + "test:kldivergence": "jest __tests__/Distributions/KlDivergence_test.*", "test:watch": "jest --watchAll", "coverage:rescript": "rm -f *.coverage; yarn clean; BISECT_ENABLE=yes yarn build; yarn test:rescript; bisect-ppx-report html", "coverage:ts": "yarn clean; yarn build; nyc --reporter=lcov yarn test:ts", diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res index b5eb0330..09b8d6b1 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res @@ -271,7 +271,7 @@ module T = Dist({ let variance = (t: t): float => XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares) - let klDivergence = (prediction: t, answer: t) => { + let klDivergence0 = (prediction: t, answer: t) => { combinePointwise( ~combiner=XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument, PointSetDist_Scoring.KLDivergence.integrand, @@ -281,6 +281,29 @@ module T = Dist({ |> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite))) |> E.R.fmap(integralEndY) } + + let klDivergence = (prediction: t, answer: t) => { + let newShape = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument2( + PointSetDist_Scoring.KLDivergence.integrand, + prediction.xyShape, + answer.xyShape, + ) + let generateContinuousDistFromXYShape: XYShape.xyShape => t = xyShape => { + xyShape: xyShape, + interpolation: #Linear, + integralSumCache: None, + integralCache: None, + } + let _ = Js.Console.log2("prediction", prediction) + let _ = Js.Console.log2("answer", answer) + let _ = Js.Console.log2("newShape", newShape) + switch newShape { + | Ok(tshape) => Ok(integralEndY(generateContinuousDistFromXYShape(tshape))) + | Error(errormessage) => Error(errormessage) + } + //|> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite))) + //|> E.R.fmap(integralEndY) + } }) let isNormalized = (t: t): bool => { diff --git a/packages/squiggle-lang/src/rescript/Utility/XYShape.res b/packages/squiggle-lang/src/rescript/Utility/XYShape.res index 4af1082e..84853c40 100644 --- a/packages/squiggle-lang/src/rescript/Utility/XYShape.res +++ b/packages/squiggle-lang/src/rescript/Utility/XYShape.res @@ -391,7 +391,7 @@ module PointwiseCombination = { `) // This function is used for kl divergence - let combineAlongSupportOfSecondArgument0: ( + let combineAlongSupportOfSecondArgument: ( (float, float) => result, interpolator, T.t, @@ -489,12 +489,11 @@ module PointwiseCombination = { result } - let combineAlongSupportOfSecondArgument: ( + let combineAlongSupportOfSecondArgument2: ( (float, float) => result, - interpolator, T.t, T.t, - ) => result = (fn, interpolator, prediction, answer) => { + ) => result = (fn, prediction, answer) => { let combineWithFn = (x: float, i: int) => { let answerX = x let answerY = answer.ys[i]