From dcf56d7bc6ea803a6ddf384300e2ec01be74a566 Mon Sep 17 00:00:00 2001 From: Quinn Dougherty Date: Thu, 5 May 2022 20:02:12 -0400 Subject: [PATCH] `combineAlongSupportOfSecondArgument` implemented, tests still failing Value: [1e-4 to 4e-2] --- .../Distributions/KlDivergence_test.res | 35 +++++-- .../squiggle-lang/__tests__/XYShape_test.res | 13 --- .../Distributions/PointSetDist/Continuous.res | 10 +- .../Distributions/PointSetDist/Discrete.res | 12 +-- .../Distributions/PointSetDist/Mixed.res | 2 +- .../PointSetDist/PointSetDist_Scoring.res | 6 +- .../src/rescript/Utility/XYShape.res | 92 +++++++++++++++---- 7 files changed, 117 insertions(+), 53 deletions(-) diff --git a/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res b/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res index b1162124..659f4b69 100644 --- a/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res +++ b/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res @@ -4,11 +4,11 @@ open TestHelpers describe("kl divergence", () => { let klDivergence = DistributionOperation.Constructors.klDivergence(~env) - test("", () => { - exception KlFailed - let lowAnswer = 4.3526e0 + exception KlFailed + test("of two uniforms is equal to the analytic expression", () => { + let lowAnswer = 2.3526e0 let highAnswer = 8.5382e0 - let lowPrediction = 4.3526e0 + let lowPrediction = 2.3526e0 let highPrediction = 1.2345e1 let answer = uniformMakeR(lowAnswer, highAnswer)->E.R2.errMap(s => DistributionTypes.ArgumentError(s)) @@ -17,11 +17,30 @@ describe("kl divergence", () => { s, )) // integral along the support of the answer of answer.pdf(x) times log of prediction.pdf(x) divided by answer.pdf(x) dx + let analyticalKl = Js.Math.log((highPrediction -. lowPrediction) /. (highAnswer -. lowAnswer)) + let kl = E.R.liftJoin2(klDivergence, prediction, answer) + switch kl { + | Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl) + | Error(err) => { + Js.Console.log(DistributionTypes.Error.toString(err)) + raise(KlFailed) + } + } + }) + test("of two normals is equal to the formula", () => { + // This test case comes via Nuño https://github.com/quantified-uncertainty/squiggle/issues/433 + let mean1 = 4.0 + let mean2 = 1.0 + let stdev1 = 1.0 + let stdev2 = 1.0 + + let prediction = + normalMakeR(mean1, stdev1)->E.R2.errMap(s => DistributionTypes.ArgumentError(s)) + let answer = normalMakeR(mean2, stdev2)->E.R2.errMap(s => DistributionTypes.ArgumentError(s)) let analyticalKl = - -1.0 /. - (highAnswer -. lowAnswer) *. - Js.Math.log((highAnswer -. lowAnswer) /. (highPrediction -. lowPrediction)) *. - (highAnswer -. lowAnswer) + Js.Math.log(stdev2 /. stdev1) +. + stdev1 ** 2.0 /. 2.0 /. stdev2 ** 2.0 +. + (mean1 -. mean2) ** 2.0 /. 2.0 /. stdev2 ** 2.0 -. 0.5 let kl = E.R.liftJoin2(klDivergence, prediction, answer) switch kl { | Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl) diff --git a/packages/squiggle-lang/__tests__/XYShape_test.res b/packages/squiggle-lang/__tests__/XYShape_test.res index 38535020..60d716ca 100644 --- a/packages/squiggle-lang/__tests__/XYShape_test.res +++ b/packages/squiggle-lang/__tests__/XYShape_test.res @@ -38,19 +38,6 @@ describe("XYShapes", () => { ) }) - describe("logScorePoint", () => { - makeTest("When identical", XYShape.logScorePoint(30, pointSetDist1, pointSetDist1), Some(0.0)) - makeTest( - "When similar", - XYShape.logScorePoint(30, pointSetDist1, pointSetDist2), - Some(1.658971191043856), - ) - makeTest( - "When very different", - XYShape.logScorePoint(30, pointSetDist1, pointSetDist3), - Some(210.3721280423322), - ) - }) describe("integrateWithTriangles", () => makeTest( "integrates correctly", diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res index f0614a27..b5eb0330 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res @@ -86,6 +86,7 @@ let stepwiseToLinear = (t: t): t => // Note: This results in a distribution with as many points as the sum of those in t1 and t2. let combinePointwise = ( + ~combiner=XYShape.PointwiseCombination.combine, ~integralSumCachesFn=(_, _) => None, ~distributionType: PointSetTypes.distributionType=#PDF, fn: (float, float) => result, @@ -119,7 +120,7 @@ let combinePointwise = ( let interpolator = XYShape.XtoY.continuousInterpolator(t1.interpolation, extrapolation) - XYShape.PointwiseCombination.combine(fn, interpolator, t1.xyShape, t2.xyShape)->E.R2.fmap(x => + combiner(fn, interpolator, t1.xyShape, t2.xyShape)->E.R2.fmap(x => make(~integralSumCache=combinedIntegralSum, x) ) } @@ -271,7 +272,12 @@ module T = Dist({ XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares) let klDivergence = (prediction: t, answer: t) => { - combinePointwise(PointSetDist_Scoring.KLDivergence.integrand, prediction, answer) + combinePointwise( + ~combiner=XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument, + PointSetDist_Scoring.KLDivergence.integrand, + prediction, + answer, + ) |> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite))) |> E.R.fmap(integralEndY) } diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res index 53a8f45c..3aa92230 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res @@ -33,6 +33,7 @@ let shapeFn = (fn, t: t) => t |> getShape |> fn let lastY = (t: t) => t |> getShape |> XYShape.T.lastY let combinePointwise = ( + ~combiner=XYShape.PointwiseCombination.combine, ~integralSumCachesFn=(_, _) => None, ~fn=(a, b) => Ok(a +. b), t1: PointSetTypes.discreteShape, @@ -48,12 +49,10 @@ let combinePointwise = ( // It could be done for pointwise additions, but is that ever needed? make( - XYShape.PointwiseCombination.combine( - fn, - XYShape.XtoY.discreteInterpolator, - t1.xyShape, - t2.xyShape, - )->E.R.toExn("Addition operation should never fail", _), + combiner(fn, XYShape.XtoY.discreteInterpolator, t1.xyShape, t2.xyShape)->E.R.toExn( + "Addition operation should never fail", + _, + ), )->Ok } @@ -231,6 +230,7 @@ module T = Dist({ let klDivergence = (prediction: t, answer: t) => { combinePointwise( + ~combiner=XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument, ~fn=PointSetDist_Scoring.KLDivergence.integrand, prediction, answer, diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res index 743ad231..50cd8939 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res @@ -48,7 +48,7 @@ let combinePointwise = ( |> E.A.fmap(toDiscrete) |> E.A.O.concatSomes |> Discrete.reduce(~integralSumCachesFn, fn) - |> E.R.toExn("foo") + |> E.R.toExn("Theoretically unreachable state") let reducedContinuous = [t1, t2] diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res index 4b50c725..7a1f6f61 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res @@ -1,5 +1,5 @@ module KLDivergence = { - let logFn = Js.Math.log + let logFn = Js.Math.log // base e let integrand = (predictionElement: float, answerElement: float): result< float, Operation.Error.t, @@ -7,9 +7,9 @@ module KLDivergence = { if predictionElement == 0.0 { Error(Operation.NegativeInfinityError) } else if answerElement == 0.0 { - Ok(answerElement) + Ok(0.0) } else { let quot = predictionElement /. answerElement - quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answerElement *. logFn(quot)) + quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(answerElement *. logFn(quot)) } } diff --git a/packages/squiggle-lang/src/rescript/Utility/XYShape.res b/packages/squiggle-lang/src/rescript/Utility/XYShape.res index c0263927..8c3e7d92 100644 --- a/packages/squiggle-lang/src/rescript/Utility/XYShape.res +++ b/packages/squiggle-lang/src/rescript/Utility/XYShape.res @@ -97,7 +97,20 @@ module T = { let equallyDividedXs = (t: t, newLength) => E.A.Floats.range(minX(t), maxX(t), newLength) let toJs = (t: t) => {"xs": t.xs, "ys": t.ys} let filterYValues = (fn, t: t): t => t |> zip |> E.A.filter(((_, y)) => fn(y)) |> fromZippedArray - + let filterOkYs = (xs: array, ys: array>): t => { + let n = E.A.length(xs) // Assume length(xs) == length(ys) + let newXs = [] + let newYs = [] + for i in 0 to n - 1 { + switch ys[i] { + | Ok(y) => + let _ = Js.Array.push(xs[i], newXs) + let _ = Js.Array.push(y, newYs) + | Error(_) => () + } + } + {xs: newXs, ys: newYs} + } module Validator = { let fnName = "XYShape validate" let notSortedError = (p: string): error => NotSorted(p) @@ -377,6 +390,64 @@ module PointwiseCombination = { } `) + // This function is used for kl divergence + let combineAlongSupportOfSecondArgument: ( + (float, float) => result, + interpolator, + T.t, + T.t, + ) => result = (fn, interpolator, t1, t2) => { + let newYs = [] + let newXs = [] + let (l1, l2) = (E.A.length(t1.xs), E.A.length(t2.xs)) + let (i, j) = (ref(0), ref(0)) + let minX = t2.xs[0] + let maxX = t2.xs[l2 - 1] + while j.contents < l2 - 1 && i.contents < l1 - 1 { + let (x, y1, y2) = { + let x1 = t1.xs[i.contents + 1] + let x2 = t2.xs[j.contents + 1] + /* if t1 has to catch up to t2 */ if ( + i.contents < l1 - 1 && j.contents < l2 && x1 < x2 && minX <= x1 && x2 <= maxX + ) { + i := i.contents + 1 + let x = x1 + let y1 = t1.ys[i.contents] + let y2 = interpolator(t2, j.contents, x) + (x, y1, y2) + } else if ( + /* if t2 has to catch up to t1 */ + i.contents < l1 && j.contents < l2 - 1 && x1 > x2 && x2 >= minX && maxX >= x1 + ) { + j := j.contents + 1 + let x = x2 + let y1 = interpolator(t1, i.contents, x) + let y2 = t2.ys[j.contents] + (x, y1, y2) + } else if ( + /* move both ahead if they are equal */ + i.contents < l1 - 1 && j.contents < l2 - 1 && x1 == x2 && x1 >= minX && maxX >= x2 + ) { + i := i.contents + 1 + j := j.contents + 1 + let x = x1 + let y1 = t1.ys[i.contents] + let y2 = t2.ys[j.contents] + (x, y1, y2) + } else { + i := i.contents + 1 + (0.0, 0.0, 0.0) // for the function I have in mind, this will error out + // exception PointwiseCombinationError + // raise(PointwiseCombinationError) + } + } + // Js.Console.log(newYs) + let _ = Js.Array.push(fn(y1, y2), newYs) + let _ = Js.Array.push(x, newXs) + } + T.filterOkYs(newXs, newYs)->Ok + } + let addCombine = (interpolator: interpolator, t1: T.t, t2: T.t): T.t => combine((a, b) => Ok(a +. b), interpolator, t1, t2)->E.R.toExn( "Add operation should never fail", @@ -490,25 +561,6 @@ module Range = { } } -let pointLogScore = (prediction, answer) => - switch answer { - | 0. => 0.0 - | answer => answer *. Js.Math.log2(Js.Math.abs_float(prediction /. answer)) - } - -let logScorePoint = (sampleCount, t1, t2) => - PointwiseCombination.combineEvenXs( - ~fn=pointLogScore, - ~xToYSelection=XtoY.linear, - sampleCount, - t1, - t2, - ) - |> Range.integrateWithTriangles - |> E.O.fmap(T.accumulateYs(\"+.")) - |> E.O.fmap(Pairs.last) - |> E.O.fmap(Pairs.y) - module Analysis = { let getVarianceDangerously = (t: 't, mean: 't => float, getMeanOfSquares: 't => float): float => { let meanSquared = mean(t) ** 2.0