diff --git a/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res b/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res index 5355f14d..6fb73366 100644 --- a/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res +++ b/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res @@ -119,18 +119,18 @@ describe("klDivergence: discrete -> discrete -> float", () => { describe("klDivergence: mixed -> mixed -> float", () => { let klDivergence = DistributionOperation.Constructors.klDivergence(~env) let mixture = a => DistributionTypes.DistributionOperation.Mixture(a) - let a' = [(point1, 1e0), (uniformDist, 1e0)]->mixture->run - let b' = [(point1, 1e0), (floatDist, 1e0), (normalDist10, 1e0)]->mixture->run - let c' = [(point1, 1e0), (point2, 1e0), (point3, 1e0), (uniformDist, 1e0)]->mixture->run + let a' = [(point1, 1.0), (uniformDist, 1.0)]->mixture->run + let b' = [(point1, 1.0), (floatDist, 1.0), (normalDist10, 1.0)]->mixture->run + let c' = [(point1, 1.0), (point2, 1.0), (point3, 1.0), (uniformDist, 1.0)]->mixture->run let d' = - [(point1, 1e0), (point2, 1e0), (point3, 1e0), (floatDist, 1e0), (uniformDist2, 1e0)] + [(point1, 1.0), (point2, 1.0), (point3, 1.0), (floatDist, 1.0), (uniformDist2, 1.0)] ->mixture ->run let (a, b, c, d) = switch (a', b', c', d') { | (Dist(a''), Dist(b''), Dist(c''), Dist(d'')) => (a'', b'', c'', d'') | _ => raise(MixtureFailed) } - test("finite klDivergence return is correct", () => { + test("finite klDivergence produces correct answer", () => { let prediction = b let answer = a let kl = klDivergence(prediction, answer) @@ -156,7 +156,7 @@ describe("klDivergence: mixed -> mixed -> float", () => { raise(KlFailed) } }) - test("finite klDivergence return is correct", () => { + test("finite klDivergence produces correct answer", () => { let prediction = d let answer = c let kl = klDivergence(prediction, answer) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res index e8049688..3aca0c66 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res @@ -272,14 +272,10 @@ module T = Dist({ XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares) let klDivergence = (prediction: t, answer: t) => { - let enrich = true - let enrichedAnswer = enrich - ? XYShape.PointwiseCombination.enrichXyShape(answer.xyShape) - : answer.xyShape // let newShape = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument( PointSetDist_Scoring.KLDivergence.integrand, prediction.xyShape, - enrichedAnswer, + answer.xyShape, ) let xyShapeToContinuous: XYShape.xyShape => t = xyShape => { xyShape: xyShape, diff --git a/packages/squiggle-lang/src/rescript/Utility/XYShape.res b/packages/squiggle-lang/src/rescript/Utility/XYShape.res index e3bd9016..5c3055f5 100644 --- a/packages/squiggle-lang/src/rescript/Utility/XYShape.res +++ b/packages/squiggle-lang/src/rescript/Utility/XYShape.res @@ -453,14 +453,15 @@ module PointwiseCombination = { T.filterOkYs(newXs, newYs)->Ok } - // Nuño wrote this function to try to increase precision, but it didn't work. + // *Dead code*: Nuño wrote this function to try to increase precision, but it didn't work. + // If another traveler comes through with a similar idea, we hope this implementation will help them. let enrichXyShape = (t: T.t): T.t => { - let enrichmentFactor = 10 + let defaultEnrichmentFactor = 10 let length = E.A.length(t.xs) let points = length < MagicNumbers.Environment.defaultXYPointLength - ? enrichmentFactor * MagicNumbers.Environment.defaultXYPointLength / length - : enrichmentFactor + ? defaultEnrichmentFactor * MagicNumbers.Environment.defaultXYPointLength / length + : defaultEnrichmentFactor let getInBetween = (x1: float, x2: float): array => { if abs_float(x1 -. x2) < 2.0 *. MagicNumbers.Epsilon.seven {