combineAlongSupportOfSecondArgument implemented, tests still failing

Value: [1e-4 to 4e-2]
This commit is contained in:
Quinn Dougherty 2022-05-05 20:02:12 -04:00
parent b49865d3aa
commit dcf56d7bc6
7 changed files with 117 additions and 53 deletions

View File

@ -4,11 +4,11 @@ open TestHelpers
describe("kl divergence", () => { describe("kl divergence", () => {
let klDivergence = DistributionOperation.Constructors.klDivergence(~env) let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
test("", () => {
exception KlFailed exception KlFailed
let lowAnswer = 4.3526e0 test("of two uniforms is equal to the analytic expression", () => {
let lowAnswer = 2.3526e0
let highAnswer = 8.5382e0 let highAnswer = 8.5382e0
let lowPrediction = 4.3526e0 let lowPrediction = 2.3526e0
let highPrediction = 1.2345e1 let highPrediction = 1.2345e1
let answer = let answer =
uniformMakeR(lowAnswer, highAnswer)->E.R2.errMap(s => DistributionTypes.ArgumentError(s)) uniformMakeR(lowAnswer, highAnswer)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
@ -17,11 +17,30 @@ describe("kl divergence", () => {
s, s,
)) ))
// integral along the support of the answer of answer.pdf(x) times log of prediction.pdf(x) divided by answer.pdf(x) dx // integral along the support of the answer of answer.pdf(x) times log of prediction.pdf(x) divided by answer.pdf(x) dx
let analyticalKl = Js.Math.log((highPrediction -. lowPrediction) /. (highAnswer -. lowAnswer))
let kl = E.R.liftJoin2(klDivergence, prediction, answer)
switch kl {
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
| Error(err) => {
Js.Console.log(DistributionTypes.Error.toString(err))
raise(KlFailed)
}
}
})
test("of two normals is equal to the formula", () => {
// This test case comes via Nuño https://github.com/quantified-uncertainty/squiggle/issues/433
let mean1 = 4.0
let mean2 = 1.0
let stdev1 = 1.0
let stdev2 = 1.0
let prediction =
normalMakeR(mean1, stdev1)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
let answer = normalMakeR(mean2, stdev2)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
let analyticalKl = let analyticalKl =
-1.0 /. Js.Math.log(stdev2 /. stdev1) +.
(highAnswer -. lowAnswer) *. stdev1 ** 2.0 /. 2.0 /. stdev2 ** 2.0 +.
Js.Math.log((highAnswer -. lowAnswer) /. (highPrediction -. lowPrediction)) *. (mean1 -. mean2) ** 2.0 /. 2.0 /. stdev2 ** 2.0 -. 0.5
(highAnswer -. lowAnswer)
let kl = E.R.liftJoin2(klDivergence, prediction, answer) let kl = E.R.liftJoin2(klDivergence, prediction, answer)
switch kl { switch kl {
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl) | Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)

View File

@ -38,19 +38,6 @@ describe("XYShapes", () => {
) )
}) })
describe("logScorePoint", () => {
makeTest("When identical", XYShape.logScorePoint(30, pointSetDist1, pointSetDist1), Some(0.0))
makeTest(
"When similar",
XYShape.logScorePoint(30, pointSetDist1, pointSetDist2),
Some(1.658971191043856),
)
makeTest(
"When very different",
XYShape.logScorePoint(30, pointSetDist1, pointSetDist3),
Some(210.3721280423322),
)
})
describe("integrateWithTriangles", () => describe("integrateWithTriangles", () =>
makeTest( makeTest(
"integrates correctly", "integrates correctly",

View File

@ -86,6 +86,7 @@ let stepwiseToLinear = (t: t): t =>
// Note: This results in a distribution with as many points as the sum of those in t1 and t2. // Note: This results in a distribution with as many points as the sum of those in t1 and t2.
let combinePointwise = ( let combinePointwise = (
~combiner=XYShape.PointwiseCombination.combine,
~integralSumCachesFn=(_, _) => None, ~integralSumCachesFn=(_, _) => None,
~distributionType: PointSetTypes.distributionType=#PDF, ~distributionType: PointSetTypes.distributionType=#PDF,
fn: (float, float) => result<float, Operation.Error.t>, fn: (float, float) => result<float, Operation.Error.t>,
@ -119,7 +120,7 @@ let combinePointwise = (
let interpolator = XYShape.XtoY.continuousInterpolator(t1.interpolation, extrapolation) let interpolator = XYShape.XtoY.continuousInterpolator(t1.interpolation, extrapolation)
XYShape.PointwiseCombination.combine(fn, interpolator, t1.xyShape, t2.xyShape)->E.R2.fmap(x => combiner(fn, interpolator, t1.xyShape, t2.xyShape)->E.R2.fmap(x =>
make(~integralSumCache=combinedIntegralSum, x) make(~integralSumCache=combinedIntegralSum, x)
) )
} }
@ -271,7 +272,12 @@ module T = Dist({
XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares) XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)
let klDivergence = (prediction: t, answer: t) => { let klDivergence = (prediction: t, answer: t) => {
combinePointwise(PointSetDist_Scoring.KLDivergence.integrand, prediction, answer) combinePointwise(
~combiner=XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument,
PointSetDist_Scoring.KLDivergence.integrand,
prediction,
answer,
)
|> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite))) |> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
|> E.R.fmap(integralEndY) |> E.R.fmap(integralEndY)
} }

View File

@ -33,6 +33,7 @@ let shapeFn = (fn, t: t) => t |> getShape |> fn
let lastY = (t: t) => t |> getShape |> XYShape.T.lastY let lastY = (t: t) => t |> getShape |> XYShape.T.lastY
let combinePointwise = ( let combinePointwise = (
~combiner=XYShape.PointwiseCombination.combine,
~integralSumCachesFn=(_, _) => None, ~integralSumCachesFn=(_, _) => None,
~fn=(a, b) => Ok(a +. b), ~fn=(a, b) => Ok(a +. b),
t1: PointSetTypes.discreteShape, t1: PointSetTypes.discreteShape,
@ -48,12 +49,10 @@ let combinePointwise = (
// It could be done for pointwise additions, but is that ever needed? // It could be done for pointwise additions, but is that ever needed?
make( make(
XYShape.PointwiseCombination.combine( combiner(fn, XYShape.XtoY.discreteInterpolator, t1.xyShape, t2.xyShape)->E.R.toExn(
fn, "Addition operation should never fail",
XYShape.XtoY.discreteInterpolator, _,
t1.xyShape, ),
t2.xyShape,
)->E.R.toExn("Addition operation should never fail", _),
)->Ok )->Ok
} }
@ -231,6 +230,7 @@ module T = Dist({
let klDivergence = (prediction: t, answer: t) => { let klDivergence = (prediction: t, answer: t) => {
combinePointwise( combinePointwise(
~combiner=XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument,
~fn=PointSetDist_Scoring.KLDivergence.integrand, ~fn=PointSetDist_Scoring.KLDivergence.integrand,
prediction, prediction,
answer, answer,

View File

@ -48,7 +48,7 @@ let combinePointwise = (
|> E.A.fmap(toDiscrete) |> E.A.fmap(toDiscrete)
|> E.A.O.concatSomes |> E.A.O.concatSomes
|> Discrete.reduce(~integralSumCachesFn, fn) |> Discrete.reduce(~integralSumCachesFn, fn)
|> E.R.toExn("foo") |> E.R.toExn("Theoretically unreachable state")
let reducedContinuous = let reducedContinuous =
[t1, t2] [t1, t2]

View File

@ -1,5 +1,5 @@
module KLDivergence = { module KLDivergence = {
let logFn = Js.Math.log let logFn = Js.Math.log // base e
let integrand = (predictionElement: float, answerElement: float): result< let integrand = (predictionElement: float, answerElement: float): result<
float, float,
Operation.Error.t, Operation.Error.t,
@ -7,9 +7,9 @@ module KLDivergence = {
if predictionElement == 0.0 { if predictionElement == 0.0 {
Error(Operation.NegativeInfinityError) Error(Operation.NegativeInfinityError)
} else if answerElement == 0.0 { } else if answerElement == 0.0 {
Ok(answerElement) Ok(0.0)
} else { } else {
let quot = predictionElement /. answerElement let quot = predictionElement /. answerElement
quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answerElement *. logFn(quot)) quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(answerElement *. logFn(quot))
} }
} }

View File

@ -97,7 +97,20 @@ module T = {
let equallyDividedXs = (t: t, newLength) => E.A.Floats.range(minX(t), maxX(t), newLength) let equallyDividedXs = (t: t, newLength) => E.A.Floats.range(minX(t), maxX(t), newLength)
let toJs = (t: t) => {"xs": t.xs, "ys": t.ys} let toJs = (t: t) => {"xs": t.xs, "ys": t.ys}
let filterYValues = (fn, t: t): t => t |> zip |> E.A.filter(((_, y)) => fn(y)) |> fromZippedArray let filterYValues = (fn, t: t): t => t |> zip |> E.A.filter(((_, y)) => fn(y)) |> fromZippedArray
let filterOkYs = (xs: array<float>, ys: array<result<float, 'b>>): t => {
let n = E.A.length(xs) // Assume length(xs) == length(ys)
let newXs = []
let newYs = []
for i in 0 to n - 1 {
switch ys[i] {
| Ok(y) =>
let _ = Js.Array.push(xs[i], newXs)
let _ = Js.Array.push(y, newYs)
| Error(_) => ()
}
}
{xs: newXs, ys: newYs}
}
module Validator = { module Validator = {
let fnName = "XYShape validate" let fnName = "XYShape validate"
let notSortedError = (p: string): error => NotSorted(p) let notSortedError = (p: string): error => NotSorted(p)
@ -377,6 +390,64 @@ module PointwiseCombination = {
} }
`) `)
// This function is used for kl divergence
let combineAlongSupportOfSecondArgument: (
(float, float) => result<float, Operation.Error.t>,
interpolator,
T.t,
T.t,
) => result<T.t, Operation.Error.t> = (fn, interpolator, t1, t2) => {
let newYs = []
let newXs = []
let (l1, l2) = (E.A.length(t1.xs), E.A.length(t2.xs))
let (i, j) = (ref(0), ref(0))
let minX = t2.xs[0]
let maxX = t2.xs[l2 - 1]
while j.contents < l2 - 1 && i.contents < l1 - 1 {
let (x, y1, y2) = {
let x1 = t1.xs[i.contents + 1]
let x2 = t2.xs[j.contents + 1]
/* if t1 has to catch up to t2 */ if (
i.contents < l1 - 1 && j.contents < l2 && x1 < x2 && minX <= x1 && x2 <= maxX
) {
i := i.contents + 1
let x = x1
let y1 = t1.ys[i.contents]
let y2 = interpolator(t2, j.contents, x)
(x, y1, y2)
} else if (
/* if t2 has to catch up to t1 */
i.contents < l1 && j.contents < l2 - 1 && x1 > x2 && x2 >= minX && maxX >= x1
) {
j := j.contents + 1
let x = x2
let y1 = interpolator(t1, i.contents, x)
let y2 = t2.ys[j.contents]
(x, y1, y2)
} else if (
/* move both ahead if they are equal */
i.contents < l1 - 1 && j.contents < l2 - 1 && x1 == x2 && x1 >= minX && maxX >= x2
) {
i := i.contents + 1
j := j.contents + 1
let x = x1
let y1 = t1.ys[i.contents]
let y2 = t2.ys[j.contents]
(x, y1, y2)
} else {
i := i.contents + 1
(0.0, 0.0, 0.0) // for the function I have in mind, this will error out
// exception PointwiseCombinationError
// raise(PointwiseCombinationError)
}
}
// Js.Console.log(newYs)
let _ = Js.Array.push(fn(y1, y2), newYs)
let _ = Js.Array.push(x, newXs)
}
T.filterOkYs(newXs, newYs)->Ok
}
let addCombine = (interpolator: interpolator, t1: T.t, t2: T.t): T.t => let addCombine = (interpolator: interpolator, t1: T.t, t2: T.t): T.t =>
combine((a, b) => Ok(a +. b), interpolator, t1, t2)->E.R.toExn( combine((a, b) => Ok(a +. b), interpolator, t1, t2)->E.R.toExn(
"Add operation should never fail", "Add operation should never fail",
@ -490,25 +561,6 @@ module Range = {
} }
} }
let pointLogScore = (prediction, answer) =>
switch answer {
| 0. => 0.0
| answer => answer *. Js.Math.log2(Js.Math.abs_float(prediction /. answer))
}
let logScorePoint = (sampleCount, t1, t2) =>
PointwiseCombination.combineEvenXs(
~fn=pointLogScore,
~xToYSelection=XtoY.linear,
sampleCount,
t1,
t2,
)
|> Range.integrateWithTriangles
|> E.O.fmap(T.accumulateYs(\"+."))
|> E.O.fmap(Pairs.last)
|> E.O.fmap(Pairs.y)
module Analysis = { module Analysis = {
let getVarianceDangerously = (t: 't, mean: 't => float, getMeanOfSquares: 't => float): float => { let getVarianceDangerously = (t: 't, mean: 't => float, getMeanOfSquares: 't => float): float => {
let meanSquared = mean(t) ** 2.0 let meanSquared = mean(t) ** 2.0