feat: Get KL divergence working except in case of numerical errors ()
- Quinn was of great help here.
- I also left some dead code, which still has to be cleaned up.
- There are still very annoying numerical errors, so I left one test failing. These are due to how the interpolation is done. Quinn to pick up from here.

Value: [0.6 to 2]
parent 5dd272fb0c
commit d9a40c973a
@@ -8,7 +8,7 @@ describe("kl divergence", () => {
   test("of two uniforms is equal to the analytic expression", () => {
     let lowAnswer = 0.0
     let highAnswer = 1.0
-    let lowPrediction = 0.0
+    let lowPrediction = -1.0
     let highPrediction = 2.0
     let answer =
       uniformMakeR(lowAnswer, highAnswer)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
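For reference, the analytic expression this test checks: for an answer $p \sim U(a_p, b_p)$ and a prediction $q \sim U(a_q, b_q)$ whose support contains the answer's,

$$D_{\mathrm{KL}}(p \parallel q) = \int_{a_p}^{b_p} \frac{1}{b_p - a_p}\,\log\frac{1/(b_p - a_p)}{1/(b_q - a_q)}\,dx = \log\frac{b_q - a_q}{b_p - a_p},$$

which is exactly the `Js.Math.log((highPrediction -. lowPrediction) /. (highAnswer -. lowAnswer))` used in the tests. With the widened prediction above, $U(-1, 2)$ against the answer $U(0, 1)$, the expected value is $\log 3$.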
@@ -29,20 +29,50 @@ describe("kl divergence", () => {
       }
     }
   })
+  test(
+    "of two uniforms is equal to the analytic expression, part 2 (annoying numerical errors)",
+    () => {
+      Js.Console.log(
+        "This will fail because of extremely annoying numerical errors. Will not fail if the two uniforms are a bit different. Very annoying",
+      )
+      let lowAnswer = 0.0
+      let highAnswer = 1.0
+      let lowPrediction = 0.0
+      let highPrediction = 2.0
+      let answer =
+        uniformMakeR(lowAnswer, highAnswer)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
+      let prediction =
+        uniformMakeR(
+          lowPrediction,
+          highPrediction,
+        )->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
+      // integral along the support of the answer of answer.pdf(x) times the log of answer.pdf(x) divided by prediction.pdf(x), dx
+      let analyticalKl = Js.Math.log((highPrediction -. lowPrediction) /. (highAnswer -. lowAnswer))
+      let kl = E.R.liftJoin2(klDivergence, prediction, answer)
+      Js.Console.log2("Analytical: ", analyticalKl)
+      Js.Console.log2("Computed: ", kl)
+      switch kl {
+      | Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
+      | Error(err) => {
+          Js.Console.log(DistributionTypes.Error.toString(err))
+          raise(KlFailed)
+        }
+      }
+    },
+  )
   test("of two normals is equal to the formula", () => {
     // This test case comes via Nuño https://github.com/quantified-uncertainty/squiggle/issues/433
     let mean1 = 4.0
     let mean2 = 1.0
-    let stdev1 = 1.0
-    let stdev2 = 4.0
+    let stdev1 = 4.0
+    let stdev2 = 1.0

     let prediction =
       normalMakeR(mean1, stdev1)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
     let answer = normalMakeR(mean2, stdev2)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
     let analyticalKl =
-      Js.Math.log(stdev2 /. stdev1) +.
-      stdev1 ** 2.0 /. 2.0 /. stdev2 ** 2.0 +.
-      (mean1 -. mean2) ** 2.0 /. 2.0 /. stdev2 ** 2.0 -. 0.5
+      Js.Math.log(stdev1 /. stdev2) +.
+      (stdev2 ** 2.0 +. (mean2 -. mean1) ** 2.0) /. (2.0 *. stdev1 ** 2.0) -. 0.5
     let kl = E.R.liftJoin2(klDivergence, prediction, answer)

     Js.Console.log2("Analytical: ", analyticalKl)
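The corrected `analyticalKl` is the standard closed form for the KL divergence between two normals, taken in the direction $D_{\mathrm{KL}}(\text{answer} \parallel \text{prediction})$ with answer $\mathcal{N}(\mu_2, \sigma_2^2)$ and prediction $\mathcal{N}(\mu_1, \sigma_1^2)$:

$$D_{\mathrm{KL}}\big(\mathcal{N}(\mu_2, \sigma_2^2) \parallel \mathcal{N}(\mu_1, \sigma_1^2)\big) = \log\frac{\sigma_1}{\sigma_2} + \frac{\sigma_2^2 + (\mu_2 - \mu_1)^2}{2\sigma_1^2} - \frac{1}{2}.$$

The old expression was the formula for the opposite direction; swapping the stdev values alongside the rewrite keeps the numeric target of the test the same.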
@@ -59,30 +89,31 @@ describe("kl divergence", () => {
   })

 describe("combine along support test", () => {
-  let combineAlongSupportOfSecondArgument = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument
-  let lowAnswer = 0.0
-  let highAnswer = 1.0
-  let lowPrediction = -1.0
-  let highPrediction = 2.0
+  Skip.test("combine along support test", _ => {
+    // doesn't matter
+    let combineAlongSupportOfSecondArgument = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument
+    let lowAnswer = 0.0
+    let highAnswer = 1.0
+    let lowPrediction = 0.0
+    let highPrediction = 2.0

     let answer =
       uniformMakeR(lowAnswer, highAnswer)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
     let prediction =
       uniformMakeR(lowPrediction, highPrediction)->E.R2.errMap(s => DistributionTypes.ArgumentError(
         s,
       ))
     let answerWrapped = E.R.fmap(a => run(FromDist(ToDist(ToPointSet), a)), answer)
     let predictionWrapped = E.R.fmap(a => run(FromDist(ToDist(ToPointSet), a)), prediction)

     let interpolator = XYShape.XtoY.continuousInterpolator(#Stepwise, #UseZero)
     let integrand = PointSetDist_Scoring.KLDivergence.integrand

     let result = switch (answerWrapped, predictionWrapped) {
     | (Ok(Dist(PointSet(Continuous(a)))), Ok(Dist(PointSet(Continuous(b))))) =>
       Some(combineAlongSupportOfSecondArgument(integrand, interpolator, a.xyShape, b.xyShape))
     | _ => None
     }
-  test("combine along support test", _ => {
     Js.Console.log2("combineAlongSupportOfSecondArgument", result)
     false->expect->toBe(true)
   })
@@ -15,6 +15,7 @@
     "test": "jest",
     "test:ts": "jest __tests__/TS/",
     "test:rescript": "jest --modulePathIgnorePatterns=__tests__/TS/*",
+    "test:kldivergence": "jest __tests__/Distributions/KlDivergence_test.*",
    "test:watch": "jest --watchAll",
     "coverage:rescript": "rm -f *.coverage; yarn clean; BISECT_ENABLE=yes yarn build; yarn test:rescript; bisect-ppx-report html",
     "coverage:ts": "yarn clean; yarn build; nyc --reporter=lcov yarn test:ts",
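This hunk is evidently from package.json: the added script lets `yarn test:kldivergence` run only the KL divergence test file, which is handy while the deliberately failing numerical-error test is being investigated.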
@@ -271,7 +271,7 @@ module T = Dist({
   let variance = (t: t): float =>
     XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)

-  let klDivergence = (prediction: t, answer: t) => {
+  let klDivergence0 = (prediction: t, answer: t) => {
     combinePointwise(
       ~combiner=XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument,
       PointSetDist_Scoring.KLDivergence.integrand,
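The old combinePointwise-based implementation is renamed to `klDivergence0` rather than deleted; this appears to be the dead code the commit message says still needs cleaning up.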
@@ -281,6 +281,29 @@ module T = Dist({
     |> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
     |> E.R.fmap(integralEndY)
   }

+  let klDivergence = (prediction: t, answer: t) => {
+    let newShape = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument2(
+      PointSetDist_Scoring.KLDivergence.integrand,
+      prediction.xyShape,
+      answer.xyShape,
+    )
+    let generateContinuousDistFromXYShape: XYShape.xyShape => t = xyShape => {
+      xyShape: xyShape,
+      interpolation: #Linear,
+      integralSumCache: None,
+      integralCache: None,
+    }
+    let _ = Js.Console.log2("prediction", prediction)
+    let _ = Js.Console.log2("answer", answer)
+    let _ = Js.Console.log2("newShape", newShape)
+    switch newShape {
+    | Ok(tshape) => Ok(integralEndY(generateContinuousDistFromXYShape(tshape)))
+    | Error(errormessage) => Error(errormessage)
+    }
+    //|> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
+    //|> E.R.fmap(integralEndY)
+  }
 })

 let isNormalized = (t: t): bool => {
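The new implementation delegates to `PointSetDist_Scoring.KLDivergence.integrand`, which the diff references but never shows. Per the XYShape signatures below, it has type `(float, float) => result<float, Operation.Error.t>`. A minimal sketch of what such an integrand computes, using hypothetical names and a stand-in error type rather than the repo's actual `Operation.Error.t`:

// Sketch only: the pointwise KL contribution p(x) * log(p(x) / q(x)),
// where q = prediction density and p = answer density at the same x.
type klError = DivisionByZero

let integrandSketch = (predictionElement: float, answerElement: float): result<float, klError> =>
  if answerElement == 0.0 {
    Ok(0.0) // lim p -> 0+ of p * log(p / q) is 0, so points without answer mass contribute nothing
  } else if predictionElement == 0.0 {
    Error(DivisionByZero) // the answer has mass where the prediction has none: the divergence blows up
  } else {
    Ok(answerElement *. Js.Math.log(answerElement /. predictionElement))
  }

Summing these contributions over the answer's support (the `integralEndY` step above) yields the KL divergence estimate.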
@@ -391,7 +391,7 @@ module PointwiseCombination = {
   `)

   // This function is used for kl divergence
-  let combineAlongSupportOfSecondArgument0: (
+  let combineAlongSupportOfSecondArgument: (
     (float, float) => result<float, Operation.Error.t>,
     interpolator,
     T.t,
|
@ -489,12 +489,11 @@ module PointwiseCombination = {
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
let combineAlongSupportOfSecondArgument: (
|
let combineAlongSupportOfSecondArgument2: (
|
||||||
(float, float) => result<float, Operation.Error.t>,
|
(float, float) => result<float, Operation.Error.t>,
|
||||||
interpolator,
|
|
||||||
T.t,
|
T.t,
|
||||||
T.t,
|
T.t,
|
||||||
) => result<T.t, Operation.Error.t> = (fn, interpolator, prediction, answer) => {
|
) => result<T.t, Operation.Error.t> = (fn, prediction, answer) => {
|
||||||
let combineWithFn = (x: float, i: int) => {
|
let combineWithFn = (x: float, i: int) => {
|
||||||
let answerX = x
|
let answerX = x
|
||||||
let answerY = answer.ys[i]
|
let answerY = answer.ys[i]
|
||||||
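Note that `combineAlongSupportOfSecondArgument2` drops the `interpolator` argument from both the type and the implementation, so the interpolation strategy must now be fixed inside the function. Since the commit message attributes the remaining numerical errors to how the interpolation is done, this is presumably where the follow-up work lands.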