All three tests pass

- The `toPointSet` method for `uniform` distributions has been changed to
improve numerical stability.

Value: [1e-1 to 1.75e0]
This commit is contained in:
Quinn Dougherty 2022-05-06 13:58:15 -04:00
parent d9a40c973a
commit 722bfc6366
5 changed files with 34 additions and 43 deletions

View File

@ -70,13 +70,14 @@ describe("kl divergence", () => {
let prediction = let prediction =
normalMakeR(mean1, stdev1)->E.R2.errMap(s => DistributionTypes.ArgumentError(s)) normalMakeR(mean1, stdev1)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
let answer = normalMakeR(mean2, stdev2)->E.R2.errMap(s => DistributionTypes.ArgumentError(s)) let answer = normalMakeR(mean2, stdev2)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
// https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians
let analyticalKl = let analyticalKl =
Js.Math.log(stdev1 /. stdev2) +. Js.Math.log(stdev1 /. stdev2) +.
(stdev2 ** 2.0 +. (mean2 -. mean1) ** 2.0) /. (2.0 *. stdev1 ** 2.0) -. 0.5 (stdev2 ** 2.0 +. (mean2 -. mean1) ** 2.0) /. (2.0 *. stdev1 ** 2.0) -. 0.5
let kl = E.R.liftJoin2(klDivergence, prediction, answer) let kl = E.R.liftJoin2(klDivergence, prediction, answer)
Js.Console.log2("Analytical: ", analyticalKl) // Js.Console.log2("Analytical: ", analyticalKl)
Js.Console.log2("Computed: ", kl) // Js.Console.log2("Computed: ", kl)
switch kl { switch kl {
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl) | Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
@ -91,7 +92,7 @@ describe("kl divergence", () => {
describe("combine along support test", () => { describe("combine along support test", () => {
Skip.test("combine along support test", _ => { Skip.test("combine along support test", _ => {
// doesn't matter // doesn't matter
let combineAlongSupportOfSecondArgument = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument let combineAlongSupportOfSecondArgument = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument0
let lowAnswer = 0.0 let lowAnswer = 0.0
let highAnswer = 1.0 let highAnswer = 1.0
let lowPrediction = 0.0 let lowPrediction = 0.0

View File

@ -271,38 +271,30 @@ module T = Dist({
let variance = (t: t): float => let variance = (t: t): float =>
XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares) XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)
let klDivergence0 = (prediction: t, answer: t) => { // let klDivergence0 = (prediction: t, answer: t) => {
combinePointwise( // combinePointwise(
~combiner=XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument, // ~combiner=XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument,
PointSetDist_Scoring.KLDivergence.integrand, // PointSetDist_Scoring.KLDivergence.integrand,
prediction, // prediction,
answer, // answer,
) // )
|> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite))) // |> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
|> E.R.fmap(integralEndY) // |> E.R.fmap(integralEndY)
} // }
let klDivergence = (prediction: t, answer: t) => { let klDivergence = (prediction: t, answer: t) => {
let newShape = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument2( let newShape = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument(
PointSetDist_Scoring.KLDivergence.integrand, PointSetDist_Scoring.KLDivergence.integrand,
prediction.xyShape, prediction.xyShape,
answer.xyShape, answer.xyShape,
) )
let generateContinuousDistFromXYShape: XYShape.xyShape => t = xyShape => { let xyShapeToContinuous: XYShape.xyShape => t = xyShape => {
xyShape: xyShape, xyShape: xyShape,
interpolation: #Linear, interpolation: #Linear,
integralSumCache: None, integralSumCache: None,
integralCache: None, integralCache: None,
} }
let _ = Js.Console.log2("prediction", prediction) newShape->E.R2.fmap(x => x->xyShapeToContinuous->integralEndY)
let _ = Js.Console.log2("answer", answer)
let _ = Js.Console.log2("newShape", newShape)
switch newShape {
| Ok(tshape) => Ok(integralEndY(generateContinuousDistFromXYShape(tshape)))
| Error(errormessage) => Error(errormessage)
}
//|> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
//|> E.R.fmap(integralEndY)
} }
}) })

View File

@ -230,7 +230,7 @@ module T = Dist({
let klDivergence = (prediction: t, answer: t) => { let klDivergence = (prediction: t, answer: t) => {
combinePointwise( combinePointwise(
~combiner=XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument, ~combiner=XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument0,
~fn=PointSetDist_Scoring.KLDivergence.integrand, ~fn=PointSetDist_Scoring.KLDivergence.integrand,
prediction, prediction,
answer, answer,

View File

@ -396,8 +396,9 @@ module T = {
| (#ByWeight, #Uniform(n)) => | (#ByWeight, #Uniform(n)) =>
// In `ByWeight mode, uniform distributions get special treatment because we need two x's // In `ByWeight mode, uniform distributions get special treatment because we need two x's
// on either side for proper rendering (just left and right of the discontinuities). // on either side for proper rendering (just left and right of the discontinuities).
let dx = 0.00001 *. (n.high -. n.low) let distance = n.high -. n.low
[n.low -. dx, n.low +. dx, n.high -. dx, n.high +. dx] let dx = MagicNumbers.Epsilon.ten *. distance
[n.low -. dx, n.low, n.low +. dx, n.high -. dx, n.high, n.high +. dx]
| (#ByWeight, _) => | (#ByWeight, _) =>
let ys = E.A.Floats.range(minCdfValue, maxCdfValue, n) let ys = E.A.Floats.range(minCdfValue, maxCdfValue, n)
ys |> E.A.fmap(y => inv(y, dist)) ys |> E.A.fmap(y => inv(y, dist))

View File

@ -391,7 +391,7 @@ module PointwiseCombination = {
`) `)
// This function is used for kl divergence // This function is used for kl divergence
let combineAlongSupportOfSecondArgument: ( let combineAlongSupportOfSecondArgument0: (
(float, float) => result<float, Operation.Error.t>, (float, float) => result<float, Operation.Error.t>,
interpolator, interpolator,
T.t, T.t,
@ -455,14 +455,14 @@ module PointwiseCombination = {
T.filterOkYs(newXs, newYs)->Ok T.filterOkYs(newXs, newYs)->Ok
} }
let getApproximatePdfOfContinuousDistributionAtPoint: (xyShape, float) => option<float> = ( let approximatePdfOfContinuousDistributionAtPoint: (xyShape, float) => option<float> = (
dist: xyShape, distShape,
point: float, point,
) => { ) => {
let closestFromBelowIndex = E.A.reducei(dist.xs, None, (accumulator, item, index) => let closestFromBelowIndex = E.A.reducei(distShape.xs, None, (accumulator, item, index) =>
item < point ? Some(index) : accumulator item < point ? Some(index) : accumulator
) // This could be made more efficient by taking advantage of the fact that these are ordered ) // This could be made more efficient by taking advantage of the fact that these are ordered
let closestFromAboveIndexOption = Belt.Array.getIndexBy(dist.xs, item => item > point) let closestFromAboveIndexOption = Belt.Array.getIndexBy(distShape.xs, item => item > point)
let weightedMean = ( let weightedMean = (
point: float, point: float,
@ -480,30 +480,28 @@ module PointwiseCombination = {
let result = switch (closestFromBelowIndex, closestFromAboveIndexOption) { let result = switch (closestFromBelowIndex, closestFromAboveIndexOption) {
| (None, None) => None // all are smaller, and all are larger | (None, None) => None // all are smaller, and all are larger
| (None, Some(i)) => Some(0.0) // none are smaller, all are larger | (None, Some(_)) => Some(0.0) // none are smaller, all are larger
| (Some(i), None) => Some(0.0) // all are smaller, none are larger | (Some(_), None) => Some(0.0) // all are smaller, none are larger
| (Some(i), Some(j)) => | (Some(i), Some(j)) =>
Some(weightedMean(point, dist.xs[i], dist.xs[j], dist.ys[i], dist.ys[j])) // there is a lowerBound and an upperBound. Some(weightedMean(point, distShape.xs[i], distShape.xs[j], distShape.ys[i], distShape.ys[j])) // there is a lowerBound and an upperBound.
} }
result result
} }
let combineAlongSupportOfSecondArgument2: ( let combineAlongSupportOfSecondArgument: (
(float, float) => result<float, Operation.Error.t>, (float, float) => result<float, Operation.Error.t>,
T.t, T.t,
T.t, T.t,
) => result<T.t, Operation.Error.t> = (fn, prediction, answer) => { ) => result<T.t, Operation.Error.t> = (fn, prediction, answer) => {
let combineWithFn = (x: float, i: int) => { let combineWithFn = (answerX: float, i: int) => {
let answerX = x
let answerY = answer.ys[i] let answerY = answer.ys[i]
let predictionY = getApproximatePdfOfContinuousDistributionAtPoint(prediction, answerX) let predictionY = approximatePdfOfContinuousDistributionAtPoint(prediction, answerX)
let wrappedResult = E.O.fmap(x => fn(x, answerY), predictionY) let wrappedResult = E.O.fmap(x => fn(x, answerY), predictionY)
let result = switch wrappedResult { switch wrappedResult {
| Some(x) => x | Some(x) => x
| None => Error(Operation.LogicallyInconsistentPathwayError) | None => Error(Operation.LogicallyInconsistentPathwayError)
} }
result
} }
let newYsWithError = Js.Array.mapi((x, i) => combineWithFn(x, i), answer.xs) let newYsWithError = Js.Array.mapi((x, i) => combineWithFn(x, i), answer.xs)
let newYsOrError = E.A.R.firstErrorOrOpen(newYsWithError) let newYsOrError = E.A.R.firstErrorOrOpen(newYsWithError)
@ -512,7 +510,6 @@ module PointwiseCombination = {
| Error(b) => Error(b) | Error(b) => Error(b)
} }
// T.filterOkYs(newXs, newYs)->Ok
result result
} }