diff --git a/packages/squiggle-lang/__tests__/Distributions/Score_test.res b/packages/squiggle-lang/__tests__/Distributions/Score_test.res new file mode 100644 index 00000000..eb096950 --- /dev/null +++ b/packages/squiggle-lang/__tests__/Distributions/Score_test.res @@ -0,0 +1,3 @@ +open Jest +open Expect +open TestHelpers diff --git a/packages/squiggle-lang/__tests__/TS/Score_test.ts b/packages/squiggle-lang/__tests__/TS/Score_test.ts new file mode 100644 index 00000000..c83133a7 --- /dev/null +++ b/packages/squiggle-lang/__tests__/TS/Score_test.ts @@ -0,0 +1,19 @@ +import { testRun } from "./TestHelpers"; + +describe("KL divergence", () => { + test("by integral solver agrees with analytical", () => { + let squiggleStringKL = `prediction=normal(4, 1) + answer=normal(1,1) + logSubtraction=dotSubtract(scaleLog(answer),scaleLog(prediction)) + klintegrand=dotMultiply(logSubtraction, answer) + klintegral = integralSum(klintegrand) + analyticalKl = log(1 / 1) + 1 ^ 2 / (2 * 1 ^ 2) + ((4 - 1) * (1 - 4) / (2 * 1 * 1)) - 1 / 2 + klintegral - analyticalKl`; + let squiggleResultKL = testRun(squiggleStringKL); + expect(squiggleResultKL.value).toBeCloseTo(0); + }); +}); + +let squiggleStringLS = `prediction=normal(4,1) + answer=normal(1,1) + logScore(prediction, answer)`; diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res index 6bbc95ae..1467239e 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res @@ -271,7 +271,10 @@ module T = Dist({ XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares) let logScore = (base: t, reference: t) => { - combinePointwise(PointSetDist_Scoring.LogScoring.logScore, base, reference) + E.R2.bind( + combinePointwise(PointSetDist_Scoring.LogScoring.multiply, reference), + combinePointwise(PointSetDist_Scoring.LogScoring.logScore, base, reference), + ) |> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite))) |> E.R.fmap(integralEndY) } diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res index 40ead2ce..157c3e23 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res @@ -4,4 +4,5 @@ module LogScoring = { let logScore = (a: float, b: float): result => Ok( Js.Math.log2(Js.Math.abs_float(a /. b)), ) + let multiply = (a: float, b: float): result => Ok(a *. b) }