diff --git a/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res b/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res index 31f03ae2..e3efe446 100644 --- a/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res +++ b/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res @@ -69,7 +69,7 @@ describe("klDivergence: continuous -> continuous -> float", () => { 2.0 ** 2.0 *. (10.0 ** 2.0 -. (10.0 +. 9.0) *. 10.0 +. (9.0 ** 2.0 +. 10.0 *. 9.0 +. 10.0 ** 2.0) /. 3.0) switch kl { - | Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=1) + | Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=3) | Error(err) => { Js.Console.log(DistributionTypes.Error.toString(err)) raise(KlFailed) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res index 3aca0c66..c6bac10b 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res @@ -272,10 +272,11 @@ module T = Dist({ XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares) let klDivergence = (prediction: t, answer: t) => { + let enrichedAnswer = XYShape.PointwiseCombination.enrichXyShape(answer.xyShape) // answer.xyShape // let newShape = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument( PointSetDist_Scoring.KLDivergence.integrand, prediction.xyShape, - answer.xyShape, + enrichedAnswer, ) let xyShapeToContinuous: XYShape.xyShape => t = xyShape => { xyShape: xyShape, diff --git a/packages/squiggle-lang/src/rescript/MagicNumbers.res b/packages/squiggle-lang/src/rescript/MagicNumbers.res index 13beafe4..cd8cfd2f 100644 --- a/packages/squiggle-lang/src/rescript/MagicNumbers.res +++ b/packages/squiggle-lang/src/rescript/MagicNumbers.res @@ -12,6 +12,7 @@ module Epsilon = { module Environment = { let defaultXYPointLength = 1000 let defaultSampleCount = 10000 + let enrichmentFactor = 10 } module OpCost = { diff --git a/packages/squiggle-lang/src/rescript/Utility/XYShape.res b/packages/squiggle-lang/src/rescript/Utility/XYShape.res index 60d0bbde..16ad64ab 100644 --- a/packages/squiggle-lang/src/rescript/Utility/XYShape.res +++ b/packages/squiggle-lang/src/rescript/Utility/XYShape.res @@ -453,6 +453,44 @@ module PointwiseCombination = { T.filterOkYs(newXs, newYs)->Ok } + let enrichXyShape = (t: T.t): T.t => { + let length = E.A.length(t.xs) + Js.Console.log(length) + let points = switch length < MagicNumbers.Environment.defaultXYPointLength { + | true => + Belt.Int.fromFloat( + Belt.Float.fromInt( + MagicNumbers.Environment.enrichmentFactor * MagicNumbers.Environment.defaultXYPointLength, + ) /. + Belt.Float.fromInt(length), + ) + | false => MagicNumbers.Environment.enrichmentFactor + } + + let getInBetween = (x1: float, x2: float): array => { + let newPointsArray = Belt.Array.makeBy(points - 1, i => i) + // don't repeat the x2 point, it will be gotten in the next iteration. + let result = Js.Array.mapi((pos, i) => + switch i { + | 0 => x1 + | _ => + x1 *. + (Belt.Float.fromInt(points) -. Belt.Float.fromInt(pos)) /. + Belt.Float.fromInt(points) +. x2 *. Belt.Float.fromInt(pos) /. Belt.Float.fromInt(points) + } + , newPointsArray) + result + } + let newXsUnflattened = Js.Array.mapi((x, i) => + switch i < length - 1 { + | true => getInBetween(x, t.xs[i + 1]) + | false => [x] + } + , t.xs) + let newXs = Belt.Array.concatMany(newXsUnflattened) + let newYs = E.A.fmap(x => XtoY.linear(x, t), newXs) //XtoY.linear(newXs) + {xs: newXs, ys: newYs} + } // This function is used for klDivergence let combineAlongSupportOfSecondArgument: ( (float, float) => result,