diff --git a/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res b/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res index bb21ba7f..2ba5100e 100644 --- a/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res +++ b/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res @@ -19,7 +19,7 @@ describe("kl divergence on continuous distributions", () => { let analyticalKl = Js.Math.log((highPrediction -. lowPrediction) /. (highAnswer -. lowAnswer)) let kl = E.R.liftJoin2(klDivergence, prediction, answer) switch kl { - | Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl) + | Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=7) | Error(err) => { Js.Console.log(DistributionTypes.Error.toString(err)) raise(KlFailed) @@ -51,7 +51,7 @@ describe("kl divergence on continuous distributions", () => { let kl = E.R.liftJoin2(klDivergence, prediction, answer) switch kl { - | Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl) + | Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=3) | Error(err) => { Js.Console.log(DistributionTypes.Error.toString(err)) raise(KlFailed) @@ -78,9 +78,9 @@ describe("kl divergence on discrete distributions", () => { | (Dist(prediction'), Dist(answer')) => klDivergence(prediction', answer') | _ => raise(MixtureFailed) } - let analyticalKl = Js.Math.log(2.0 /. 3.0) + let analyticalKl = Js.Math.log(3.0 /. 2.0) switch kl { - | Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl) + | Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=7) | Error(err) => Js.Console.log(DistributionTypes.Error.toString(err)) raise(KlFailed) diff --git a/packages/squiggle-lang/__tests__/TestHelpers.res b/packages/squiggle-lang/__tests__/TestHelpers.res index 54c4c814..21f3c51c 100644 --- a/packages/squiggle-lang/__tests__/TestHelpers.res +++ b/packages/squiggle-lang/__tests__/TestHelpers.res @@ -51,6 +51,7 @@ let mkExponential = rate => DistributionTypes.Symbolic(#Exponential({rate: rate} let mkUniform = (low, high) => DistributionTypes.Symbolic(#Uniform({low: low, high: high})) let mkCauchy = (local, scale) => DistributionTypes.Symbolic(#Cauchy({local: local, scale: scale})) let mkLognormal = (mu, sigma) => DistributionTypes.Symbolic(#Lognormal({mu: mu, sigma: sigma})) +let mkDirac = x => DistributionTypes.Symbolic(#Float(x)) let normalMake = SymbolicDist.Normal.make let betaMake = SymbolicDist.Beta.make diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res index 31694232..658b7dc9 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res @@ -229,11 +229,25 @@ module T = Dist({ } let klDivergence = (prediction: t, answer: t) => { - combinePointwise( - ~combiner=XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument0, - ~fn=PointSetDist_Scoring.KLDivergence.integrand, - prediction, - answer, - ) |> E.R2.bind(integralEndYResult) + let massOrZero = (t: t, x: float): float => { + let i = E.A.findIndex(x' => x' == x, t.xyShape.xs) + switch i { + | None => 0.0 + | Some(i') => t.xyShape.ys[i'] + } + } + let predictionNewYs = E.A.fmap(massOrZero(answer), prediction.xyShape.xs) + let integrand = XYShape.PointwiseCombination.combine( + PointSetDist_Scoring.KLDivergence.integrand, + XYShape.XtoY.continuousInterpolator(#Stepwise, #UseZero), + {XYShape.xs: answer.xyShape.xs, XYShape.ys: predictionNewYs}, + answer.xyShape, + ) + let xyShapeToDiscrete: XYShape.xyShape => t = xyShape => { + xyShape: xyShape, + integralSumCache: None, + integralCache: None, + } + integrand->E.R2.fmap(x => x->xyShapeToDiscrete->integralEndY) } }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res index 0a8d2987..1879ebdd 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res @@ -199,6 +199,7 @@ module T = Dist({ let klDivergence = (t1: t, t2: t) => switch (t1, t2) { | (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2) + | (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2) | _ => Error(NotYetImplemented) } })