diff --git a/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res b/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res new file mode 100644 index 00000000..b1162124 --- /dev/null +++ b/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res @@ -0,0 +1,34 @@ +open Jest +open Expect +open TestHelpers + +describe("kl divergence", () => { + let klDivergence = DistributionOperation.Constructors.klDivergence(~env) + test("", () => { + exception KlFailed + let lowAnswer = 4.3526e0 + let highAnswer = 8.5382e0 + let lowPrediction = 4.3526e0 + let highPrediction = 1.2345e1 + let answer = + uniformMakeR(lowAnswer, highAnswer)->E.R2.errMap(s => DistributionTypes.ArgumentError(s)) + let prediction = + uniformMakeR(lowPrediction, highPrediction)->E.R2.errMap(s => DistributionTypes.ArgumentError( + s, + )) + // integral along the support of the answer of answer.pdf(x) times log of prediction.pdf(x) divided by answer.pdf(x) dx + let analyticalKl = + -1.0 /. + (highAnswer -. lowAnswer) *. + Js.Math.log((highAnswer -. lowAnswer) /. (highPrediction -. lowPrediction)) *. + (highAnswer -. lowAnswer) + let kl = E.R.liftJoin2(klDivergence, prediction, answer) + switch kl { + | Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl) + | Error(err) => { + Js.Console.log(DistributionTypes.Error.toString(err)) + raise(KlFailed) + } + } + }) +}) diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation/DistributionOperation.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation/DistributionOperation.res index 6872136e..80a53eb8 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation/DistributionOperation.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation/DistributionOperation.res @@ -10,8 +10,8 @@ type env = { } let defaultEnv = { - sampleCount: 10000, - xyPointLength: 10000, + sampleCount: MagicNumbers.Environment.defaultSampleCount, + xyPointLength: MagicNumbers.Environment.defaultXYPointLength, } type outputType = @@ -128,7 +128,7 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => { let fromDistFn = ( subFnName: DistributionTypes.DistributionOperation.fromDist, dist: genericDist, - ) => { + ): outputType => { let response = switch subFnName { | ToFloat(distToFloatOperation) => GenericDist.toFloatOperation(dist, ~toPointSetFn, ~distToFloatOperation) @@ -261,7 +261,7 @@ module Constructors = { let pdf = (~env, dist, f) => C.pdf(dist, f)->run(~env)->toFloatR let normalize = (~env, dist) => C.normalize(dist)->run(~env)->toDistR let isNormalized = (~env, dist) => C.isNormalized(dist)->run(~env)->toBoolR - let logScore = (~env, dist1, dist2) => C.logScore(dist1, dist2)->run(~env)->toFloatR + let klDivergence = (~env, dist1, dist2) => C.klDivergence(dist1, dist2)->run(~env)->toFloatR let toPointSet = (~env, dist) => C.toPointSet(dist)->run(~env)->toDistR let toSampleSet = (~env, dist, n) => C.toSampleSet(dist, n)->run(~env)->toDistR let fromSamples = (~env, xs) => C.fromSamples(xs)->run(~env)->toDistR diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation/DistributionOperation.resi b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation/DistributionOperation.resi index 7476b619..200be7d7 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation/DistributionOperation.resi +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation/DistributionOperation.resi @@ -60,7 +60,7 @@ module Constructors: { @genType let isNormalized: (~env: env, genericDist) => result @genType - let logScore: (~env: env, genericDist, genericDist) => result + let klDivergence: (~env: env, genericDist, genericDist) => result @genType let toPointSet: (~env: env, genericDist) => result @genType diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res index 03210270..a9f7dfbe 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res @@ -160,7 +160,7 @@ module Constructors = { let fromSamples = (xs): t => FromSamples(xs) let truncate = (dist, left, right): t => FromDist(ToDist(Truncate(left, right)), dist) let inspect = (dist): t => FromDist(ToDist(Inspect), dist) - let logScore = (dist1, dist2): t => FromDist(ToScore(KLDivergence(dist2)), dist1) + let klDivergence = (dist1, dist2): t => FromDist(ToScore(KLDivergence(dist2)), dist1) let scalePower = (dist, n): t => FromDist(ToDist(Scale(#Power, n)), dist) let scaleLogarithm = (dist, n): t => FromDist(ToDist(Scale(#Logarithm, n)), dist) let scaleLogarithmWithThreshold = (dist, n, eps): t => FromDist( diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res index b9b6b985..f0614a27 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res @@ -270,24 +270,10 @@ module T = Dist({ let variance = (t: t): float => XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares) - let klDivergence = (base: t, reference: t) => { - let referenceIsZero = switch Distributions.Common.isZeroEverywhere( - PointSetTypes.Continuous(reference), - ) { - | Continuous(b) => b - | _ => false - } - if referenceIsZero { - Ok(0.0) - } else { - combinePointwise( - PointSetDist_Scoring.KLDivergence.logScore(~eps=MagicNumbers.Epsilon.ten), - base, - reference, - ) - |> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite))) - |> E.R.fmap(integralEndY) - } + let klDivergence = (prediction: t, answer: t) => { + combinePointwise(PointSetDist_Scoring.KLDivergence.integrand, prediction, answer) + |> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite))) + |> E.R.fmap(integralEndY) } }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res index 372eb3d1..53a8f45c 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res @@ -229,21 +229,11 @@ module T = Dist({ XYShape.Analysis.getVarianceDangerously(t, mean, getMeanOfSquares) } - let klDivergence = (base: t, reference: t) => { - let referenceIsZero = switch Distributions.Common.isZeroEverywhere( - PointSetTypes.Discrete(reference), - ) { - | Discrete(b) => b - | _ => false - } - if referenceIsZero { - Ok(0.0) - } else { - combinePointwise( - ~fn=PointSetDist_Scoring.KLDivergence.logScore(~eps=MagicNumbers.Epsilon.ten), - base, - reference, - ) |> E.R2.bind(integralEndYResult) - } + let klDivergence = (prediction: t, answer: t) => { + combinePointwise( + ~fn=PointSetDist_Scoring.KLDivergence.integrand, + prediction, + answer, + ) |> E.R2.bind(integralEndYResult) } }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res index 5c9ee1aa..85ffe4b1 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res @@ -96,18 +96,4 @@ module Common = { None | (Some(s1), Some(s2)) => combineFn(s1, s2) } - - let isZeroEverywhere = (d: PointSetTypes.pointSetDist) => { - let isZero = (x: float): bool => x == 0.0 - PointSetTypes.ShapeMonad.fmap( - d, - ( - mixed => - E.A.all(isZero, mixed.continuous.xyShape.ys) && - E.A.all(isZero, mixed.discrete.xyShape.ys), - disc => E.A.all(isZero, disc.xyShape.ys), - cont => E.A.all(isZero, cont.xyShape.ys), - ), - ) - } } diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res index 1521a13c..743ad231 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res @@ -301,22 +301,10 @@ module T = Dist({ } } - let klDivergence = (base: t, reference: t) => { - let referenceIsZero = switch Distributions.Common.isZeroEverywhere( - PointSetTypes.Mixed(reference), - ) { - | Mixed(b) => b - | _ => false - } - if referenceIsZero { - Ok(0.0) - } else { - combinePointwise( - PointSetDist_Scoring.KLDivergence.logScore(~eps=MagicNumbers.Epsilon.ten), - base, - reference, - ) |> E.R.fmap(integralEndY) - } + let klDivergence = (prediction: t, answer: t) => { + combinePointwise(PointSetDist_Scoring.KLDivergence.integrand, prediction, answer) |> E.R.fmap( + integralEndY, + ) } }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res index 7b905316..c5cf466a 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res @@ -202,9 +202,9 @@ module T = Dist({ | (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2) | (Mixed(t1), Mixed(t2)) => Mixed.T.klDivergence(t1, t2) | _ => { - let t1 = toMixed(t1) - let t2 = toMixed(t2) - Mixed.T.klDivergence(t1, t2) + let t1' = toMixed(t1) + let t2' = toMixed(t2) + Mixed.T.klDivergence(t1', t2') } } }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res index 34113edb..4b50c725 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res @@ -1,25 +1,15 @@ module KLDivergence = { let logFn = Js.Math.log - let subtraction = (a, b) => Ok(a -. b) - let multiply = (a: float, b: float): result => Ok(a *. b) - let logScoreDirect = (a: float, b: float): result => - if a == 0.0 { + let integrand = (predictionElement: float, answerElement: float): result< + float, + Operation.Error.t, + > => + if predictionElement == 0.0 { Error(Operation.NegativeInfinityError) - } else if b == 0.0 { - Ok(b) + } else if answerElement == 0.0 { + Ok(answerElement) } else { - let quot = a /. b - quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(b *. logFn(quot)) - } - let logScoreWithThreshold = (~eps: float, a: float, b: float): result => - if abs_float(a) < eps { - Ok(0.0) - } else { - logScoreDirect(a, b) - } - let logScore = (~eps: option=?, a: float, b: float): result => - switch eps { - | None => logScoreDirect(a, b) - | Some(eps') => logScoreWithThreshold(~eps=eps', a, b) + let quot = predictionElement /. answerElement + quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answerElement *. logFn(quot)) } } diff --git a/packages/squiggle-lang/src/rescript/MagicNumbers.res b/packages/squiggle-lang/src/rescript/MagicNumbers.res index 0f059c03..13beafe4 100644 --- a/packages/squiggle-lang/src/rescript/MagicNumbers.res +++ b/packages/squiggle-lang/src/rescript/MagicNumbers.res @@ -6,6 +6,7 @@ module Math = { module Epsilon = { let ten = 1e-10 let seven = 1e-7 + let five = 1e-5 } module Environment = { diff --git a/packages/squiggle-lang/src/rescript/Utility/XYShape.res b/packages/squiggle-lang/src/rescript/Utility/XYShape.res index a90c68e5..c0263927 100644 --- a/packages/squiggle-lang/src/rescript/Utility/XYShape.res +++ b/packages/squiggle-lang/src/rescript/Utility/XYShape.res @@ -468,7 +468,7 @@ module Range = { // TODO: I think this isn't needed by any functions anymore. let stepsToContinuous = t => { // TODO: It would be nicer if this the diff didn't change the first element, and also maybe if there were a more elegant way of doing this. - let diff = T.xTotalRange(t) |> (r => r *. 0.00001) + let diff = T.xTotalRange(t) |> (r => r *. MagicNumbers.Epsilon.five) let items = switch E.A.toRanges(Belt.Array.zip(t.xs, t.ys)) { | Ok(items) => Some(