feat: Get KL divergence working except in case of numerical errors ()

- Quinn was of great help here.
- I also left some dead code, which still has to be cleaned up
- There are still very annoying numerical errors, so I left one test
failing. These are due to how the interpolation is done
- Quinn to pick up from here

Value: [0.6 to 2]
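
For reference, the analytic values the updated tests check, derived from the standard KL formulas (a worked example added here for context, not part of the commit). For uniforms, with answer U(a, b) and prediction U(a', b') where [a, b] ⊆ [a', b']:

\[
D_{\mathrm{KL}}\!\left(U(a,b) \,\middle\|\, U(a',b')\right)
= \int_a^b \frac{1}{b-a} \log \frac{1/(b-a)}{1/(b'-a')} \, dx
= \log \frac{b'-a'}{b-a}
\]

which is log 2 ≈ 0.693 for the U(0,1)-vs-U(0,2) pair in the failing test below. For normals, with the answer as the first argument:

\[
D_{\mathrm{KL}}\!\left(\mathcal{N}(\mu_2,\sigma_2^2) \,\middle\|\, \mathcal{N}(\mu_1,\sigma_1^2)\right)
= \log \frac{\sigma_1}{\sigma_2} + \frac{\sigma_2^2 + (\mu_2-\mu_1)^2}{2\sigma_1^2} - \frac{1}{2}
\]

which is what the corrected analyticalKl in the normals test computes.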
NunoSempere 2022-05-06 12:26:51 -04:00
parent 5dd272fb0c
commit d9a40c973a
4 changed files with 86 additions and 32 deletions

View File

@@ -8,7 +8,7 @@ describe("kl divergence", () => {
   test("of two uniforms is equal to the analytic expression", () => {
     let lowAnswer = 0.0
     let highAnswer = 1.0
-    let lowPrediction = 0.0
+    let lowPrediction = -1.0
     let highPrediction = 2.0
     let answer =
       uniformMakeR(lowAnswer, highAnswer)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
@@ -29,20 +29,50 @@ describe("kl divergence", () => {
       }
     }
   })
+  test(
+    "of two uniforms is equal to the analytic expression, part 2 (annoying numerical errors)",
+    () => {
+      Js.Console.log(
+        "This will fail because of extremely annoying numerical errors. Will not fail if the two uniforms are a bit different. Very annoying",
+      )
+      let lowAnswer = 0.0
+      let highAnswer = 1.0
+      let lowPrediction = 0.0
+      let highPrediction = 2.0
+      let answer =
+        uniformMakeR(lowAnswer, highAnswer)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
+      let prediction =
+        uniformMakeR(
+          lowPrediction,
+          highPrediction,
+        )->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
+      // integral along the support of the answer of answer.pdf(x) times log of answer.pdf(x) divided by prediction.pdf(x) dx
+      let analyticalKl = Js.Math.log((highPrediction -. lowPrediction) /. (highAnswer -. lowAnswer))
+      let kl = E.R.liftJoin2(klDivergence, prediction, answer)
+      Js.Console.log2("Analytical: ", analyticalKl)
+      Js.Console.log2("Computed: ", kl)
+      switch kl {
+      | Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
+      | Error(err) => {
+          Js.Console.log(DistributionTypes.Error.toString(err))
+          raise(KlFailed)
+        }
+      }
+    },
+  )
   test("of two normals is equal to the formula", () => {
     // This test case comes via Nuño https://github.com/quantified-uncertainty/squiggle/issues/433
     let mean1 = 4.0
     let mean2 = 1.0
-    let stdev1 = 1.0
-    let stdev2 = 4.0
+    let stdev1 = 4.0
+    let stdev2 = 1.0
     let prediction =
       normalMakeR(mean1, stdev1)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
     let answer = normalMakeR(mean2, stdev2)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
     let analyticalKl =
-      Js.Math.log(stdev2 /. stdev1) +.
-      stdev1 ** 2.0 /. 2.0 /. stdev2 ** 2.0 +.
-      (mean1 -. mean2) ** 2.0 /. 2.0 /. stdev2 ** 2.0 -. 0.5
+      Js.Math.log(stdev1 /. stdev2) +.
+      (stdev2 ** 2.0 +. (mean2 -. mean1) ** 2.0) /. (2.0 *. stdev1 ** 2.0) -. 0.5
     let kl = E.R.liftJoin2(klDivergence, prediction, answer)
     Js.Console.log2("Analytical: ", analyticalKl)
@@ -59,30 +89,31 @@ describe("kl divergence", () => {
   })
 describe("combine along support test", () => {
-  let combineAlongSupportOfSecondArgument = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument
-  let lowAnswer = 0.0
-  let highAnswer = 1.0
-  let lowPrediction = -1.0
-  let highPrediction = 2.0
+  Skip.test("combine along support test", _ => {
+    // doesn't matter
+    let combineAlongSupportOfSecondArgument = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument
+    let lowAnswer = 0.0
+    let highAnswer = 1.0
+    let lowPrediction = 0.0
+    let highPrediction = 2.0
     let answer =
       uniformMakeR(lowAnswer, highAnswer)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
     let prediction =
       uniformMakeR(lowPrediction, highPrediction)->E.R2.errMap(s => DistributionTypes.ArgumentError(
         s,
       ))
     let answerWrapped = E.R.fmap(a => run(FromDist(ToDist(ToPointSet), a)), answer)
     let predictionWrapped = E.R.fmap(a => run(FromDist(ToDist(ToPointSet), a)), prediction)
     let interpolator = XYShape.XtoY.continuousInterpolator(#Stepwise, #UseZero)
     let integrand = PointSetDist_Scoring.KLDivergence.integrand
     let result = switch (answerWrapped, predictionWrapped) {
     | (Ok(Dist(PointSet(Continuous(a)))), Ok(Dist(PointSet(Continuous(b))))) =>
       Some(combineAlongSupportOfSecondArgument(integrand, interpolator, a.xyShape, b.xyShape))
     | _ => None
     }
-  test("combine along support test", _ => {
     Js.Console.log2("combineAlongSupportOfSecondArgument", result)
     false->expect->toBe(true)
   })
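
To make the "annoying numerical errors" concrete, here is a minimal TypeScript sketch of the pointwise KL integrand (an illustration with assumed semantics, not the library's code). The failure mode is the shared support edge: a #Stepwise/#UseZero interpolator can evaluate the prediction's density as 0 at a point where the answer's density is positive, which makes the integrand infinite. This plausibly explains why the log message says the test would pass if the two uniforms were slightly different: then no grid point of the answer falls on the prediction's boundary.

    // Pointwise KL integrand: answer * log(answer / prediction).
    // Convention: the integrand is 0 wherever the answer's density is 0.
    const klIntegrand = (predictionY: number, answerY: number): number =>
      answerY === 0 ? 0 : answerY * Math.log(answerY / predictionY);

    console.log(klIntegrand(0.5, 1.0)); // ≈ 0.693 = log 2, the analytic value for U(0,1) vs U(0,2)
    console.log(klIntegrand(0.0, 1.0)); // Infinity: the shared-edge case behind the failing test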

View File

@@ -15,6 +15,7 @@
     "test": "jest",
     "test:ts": "jest __tests__/TS/",
     "test:rescript": "jest --modulePathIgnorePatterns=__tests__/TS/*",
+    "test:kldivergence": "jest __tests__/Distributions/KlDivergence_test.*",
     "test:watch": "jest --watchAll",
     "coverage:rescript": "rm -f *.coverage; yarn clean; BISECT_ENABLE=yes yarn build; yarn test:rescript; bisect-ppx-report html",
     "coverage:ts": "yarn clean; yarn build; nyc --reporter=lcov yarn test:ts",

View File

@@ -271,7 +271,7 @@ module T = Dist({
   let variance = (t: t): float =>
     XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)
-  let klDivergence = (prediction: t, answer: t) => {
+  let klDivergence0 = (prediction: t, answer: t) => {
     combinePointwise(
       ~combiner=XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument,
       PointSetDist_Scoring.KLDivergence.integrand,
@@ -281,6 +281,29 @@ module T = Dist({
     |> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
     |> E.R.fmap(integralEndY)
   }
+  let klDivergence = (prediction: t, answer: t) => {
+    let newShape = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument2(
+      PointSetDist_Scoring.KLDivergence.integrand,
+      prediction.xyShape,
+      answer.xyShape,
+    )
+    let generateContinuousDistFromXYShape: XYShape.xyShape => t = xyShape => {
+      xyShape: xyShape,
+      interpolation: #Linear,
+      integralSumCache: None,
+      integralCache: None,
+    }
+    let _ = Js.Console.log2("prediction", prediction)
+    let _ = Js.Console.log2("answer", answer)
+    let _ = Js.Console.log2("newShape", newShape)
+    switch newShape {
+    | Ok(tshape) => Ok(integralEndY(generateContinuousDistFromXYShape(tshape)))
+    | Error(errormessage) => Error(errormessage)
+    }
+    //|> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
+    //|> E.R.fmap(integralEndY)
+  }
 })
 let isNormalized = (t: t): bool => {
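
In outline, the new klDivergence builds the integrand curve over the answer's support with combineAlongSupportOfSecondArgument2, wraps it as a continuous dist, and takes its total integral via integralEndY. The commented-out filterYValues(Js.Float.isFinite) step, which previously discarded infinite y-values, appears deliberately skipped, so the boundary infinities now surface. A rough TypeScript paraphrase under assumed types (a sketch, not the ReScript source; a trapezoid rule stands in for integralEndY):

    type XYShape = { xs: number[]; ys: number[] };
    type Result<T> = { ok: true; value: T } | { ok: false; error: string };

    // Trapezoid-rule integral of an XY curve, standing in for integralEndY.
    const integrate = (shape: XYShape): number => {
      let total = 0;
      for (let i = 1; i < shape.xs.length; i++) {
        total += ((shape.ys[i] + shape.ys[i - 1]) / 2) * (shape.xs[i] - shape.xs[i - 1]);
      }
      return total;
    };

    // klDivergence in outline: integrate the combined integrand curve,
    // passing any error from the combination step straight through.
    const klDivergence = (combined: Result<XYShape>): Result<number> =>
      combined.ok ? { ok: true, value: integrate(combined.value) } : combined;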

View File

@@ -391,7 +391,7 @@ module PointwiseCombination = {
     `)
   // This function is used for kl divergence
-  let combineAlongSupportOfSecondArgument0: (
+  let combineAlongSupportOfSecondArgument: (
     (float, float) => result<float, Operation.Error.t>,
     interpolator,
     T.t,
@@ -489,12 +489,11 @@ module PointwiseCombination = {
     result
   }
-  let combineAlongSupportOfSecondArgument: (
+  let combineAlongSupportOfSecondArgument2: (
     (float, float) => result<float, Operation.Error.t>,
-    interpolator,
     T.t,
     T.t,
-  ) => result<T.t, Operation.Error.t> = (fn, interpolator, prediction, answer) => {
+  ) => result<T.t, Operation.Error.t> = (fn, prediction, answer) => {
     let combineWithFn = (x: float, i: int) => {
       let answerX = x
       let answerY = answer.ys[i]
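
Finally, a sketch of what combining along the support of the second argument plausibly amounts to (TypeScript with assumed semantics; the real implementation continues past the lines shown): walk the answer's x-grid, interpolate the prediction's y at each x, and apply the integrand to each pair. Dropping the interpolator parameter in combineAlongSupportOfSecondArgument2 bakes a single interpolation rule in, and it is the behavior at shared support endpoints, where the interpolated prediction density can come out as 0, that produces the remaining numerical errors.

    type XYShape = { xs: number[]; ys: number[] };

    // Linear interpolation of a shape's y at x; 0 outside the sampled support.
    // The delicate part is exactly at the endpoints: tie-breaking to 0 at a
    // shared edge feeds a zero prediction density into the KL integrand.
    const interpolateY = (shape: XYShape, x: number): number => {
      const n = shape.xs.length;
      if (n === 0 || x < shape.xs[0] || x > shape.xs[n - 1]) return 0;
      if (n === 1) return shape.ys[0];
      let i = 1;
      while (i < n - 1 && shape.xs[i] < x) i++;
      const t = (x - shape.xs[i - 1]) / (shape.xs[i] - shape.xs[i - 1]);
      return shape.ys[i - 1] + t * (shape.ys[i] - shape.ys[i - 1]);
    };

    const combineAlongSupportOfSecond = (
      fn: (predictionY: number, answerY: number) => number,
      prediction: XYShape,
      answer: XYShape,
    ): XYShape => ({
      xs: [...answer.xs],
      ys: answer.xs.map((x, i) => fn(interpolateY(prediction, x), answer.ys[i])),
    });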