All three tests pass
- `uniform` `toPointSet` method has been changed for numerical stability. Value: [1e-1 to 1.75e0]
This commit is contained in:
parent
d9a40c973a
commit
722bfc6366
|
@ -70,13 +70,14 @@ describe("kl divergence", () => {
|
|||
let prediction =
|
||||
normalMakeR(mean1, stdev1)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
|
||||
let answer = normalMakeR(mean2, stdev2)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
|
||||
// https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians
|
||||
let analyticalKl =
|
||||
Js.Math.log(stdev1 /. stdev2) +.
|
||||
(stdev2 ** 2.0 +. (mean2 -. mean1) ** 2.0) /. (2.0 *. stdev1 ** 2.0) -. 0.5
|
||||
let kl = E.R.liftJoin2(klDivergence, prediction, answer)
|
||||
|
||||
Js.Console.log2("Analytical: ", analyticalKl)
|
||||
Js.Console.log2("Computed: ", kl)
|
||||
// Js.Console.log2("Analytical: ", analyticalKl)
|
||||
// Js.Console.log2("Computed: ", kl)
|
||||
|
||||
switch kl {
|
||||
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
|
||||
|
@ -91,7 +92,7 @@ describe("kl divergence", () => {
|
|||
describe("combine along support test", () => {
|
||||
Skip.test("combine along support test", _ => {
|
||||
// doesn't matter
|
||||
let combineAlongSupportOfSecondArgument = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument
|
||||
let combineAlongSupportOfSecondArgument = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument0
|
||||
let lowAnswer = 0.0
|
||||
let highAnswer = 1.0
|
||||
let lowPrediction = 0.0
|
||||
|
|
|
@ -271,38 +271,30 @@ module T = Dist({
|
|||
let variance = (t: t): float =>
|
||||
XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)
|
||||
|
||||
let klDivergence0 = (prediction: t, answer: t) => {
|
||||
combinePointwise(
|
||||
~combiner=XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument,
|
||||
PointSetDist_Scoring.KLDivergence.integrand,
|
||||
prediction,
|
||||
answer,
|
||||
)
|
||||
|> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
|
||||
|> E.R.fmap(integralEndY)
|
||||
}
|
||||
// let klDivergence0 = (prediction: t, answer: t) => {
|
||||
// combinePointwise(
|
||||
// ~combiner=XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument,
|
||||
// PointSetDist_Scoring.KLDivergence.integrand,
|
||||
// prediction,
|
||||
// answer,
|
||||
// )
|
||||
// |> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
|
||||
// |> E.R.fmap(integralEndY)
|
||||
// }
|
||||
|
||||
let klDivergence = (prediction: t, answer: t) => {
|
||||
let newShape = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument2(
|
||||
let newShape = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument(
|
||||
PointSetDist_Scoring.KLDivergence.integrand,
|
||||
prediction.xyShape,
|
||||
answer.xyShape,
|
||||
)
|
||||
let generateContinuousDistFromXYShape: XYShape.xyShape => t = xyShape => {
|
||||
let xyShapeToContinuous: XYShape.xyShape => t = xyShape => {
|
||||
xyShape: xyShape,
|
||||
interpolation: #Linear,
|
||||
integralSumCache: None,
|
||||
integralCache: None,
|
||||
}
|
||||
let _ = Js.Console.log2("prediction", prediction)
|
||||
let _ = Js.Console.log2("answer", answer)
|
||||
let _ = Js.Console.log2("newShape", newShape)
|
||||
switch newShape {
|
||||
| Ok(tshape) => Ok(integralEndY(generateContinuousDistFromXYShape(tshape)))
|
||||
| Error(errormessage) => Error(errormessage)
|
||||
}
|
||||
//|> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
|
||||
//|> E.R.fmap(integralEndY)
|
||||
newShape->E.R2.fmap(x => x->xyShapeToContinuous->integralEndY)
|
||||
}
|
||||
})
|
||||
|
||||
|
|
|
@ -230,7 +230,7 @@ module T = Dist({
|
|||
|
||||
let klDivergence = (prediction: t, answer: t) => {
|
||||
combinePointwise(
|
||||
~combiner=XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument,
|
||||
~combiner=XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument0,
|
||||
~fn=PointSetDist_Scoring.KLDivergence.integrand,
|
||||
prediction,
|
||||
answer,
|
||||
|
|
|
@ -396,8 +396,9 @@ module T = {
|
|||
| (#ByWeight, #Uniform(n)) =>
|
||||
// In `ByWeight mode, uniform distributions get special treatment because we need two x's
|
||||
// on either side for proper rendering (just left and right of the discontinuities).
|
||||
let dx = 0.00001 *. (n.high -. n.low)
|
||||
[n.low -. dx, n.low +. dx, n.high -. dx, n.high +. dx]
|
||||
let distance = n.high -. n.low
|
||||
let dx = MagicNumbers.Epsilon.ten *. distance
|
||||
[n.low -. dx, n.low, n.low +. dx, n.high -. dx, n.high, n.high +. dx]
|
||||
| (#ByWeight, _) =>
|
||||
let ys = E.A.Floats.range(minCdfValue, maxCdfValue, n)
|
||||
ys |> E.A.fmap(y => inv(y, dist))
|
||||
|
|
|
@ -391,7 +391,7 @@ module PointwiseCombination = {
|
|||
`)
|
||||
|
||||
// This function is used for kl divergence
|
||||
let combineAlongSupportOfSecondArgument: (
|
||||
let combineAlongSupportOfSecondArgument0: (
|
||||
(float, float) => result<float, Operation.Error.t>,
|
||||
interpolator,
|
||||
T.t,
|
||||
|
@ -455,14 +455,14 @@ module PointwiseCombination = {
|
|||
T.filterOkYs(newXs, newYs)->Ok
|
||||
}
|
||||
|
||||
let getApproximatePdfOfContinuousDistributionAtPoint: (xyShape, float) => option<float> = (
|
||||
dist: xyShape,
|
||||
point: float,
|
||||
let approximatePdfOfContinuousDistributionAtPoint: (xyShape, float) => option<float> = (
|
||||
distShape,
|
||||
point,
|
||||
) => {
|
||||
let closestFromBelowIndex = E.A.reducei(dist.xs, None, (accumulator, item, index) =>
|
||||
let closestFromBelowIndex = E.A.reducei(distShape.xs, None, (accumulator, item, index) =>
|
||||
item < point ? Some(index) : accumulator
|
||||
) // This could be made more efficient by taking advantage of the fact that these are ordered
|
||||
let closestFromAboveIndexOption = Belt.Array.getIndexBy(dist.xs, item => item > point)
|
||||
let closestFromAboveIndexOption = Belt.Array.getIndexBy(distShape.xs, item => item > point)
|
||||
|
||||
let weightedMean = (
|
||||
point: float,
|
||||
|
@ -480,30 +480,28 @@ module PointwiseCombination = {
|
|||
|
||||
let result = switch (closestFromBelowIndex, closestFromAboveIndexOption) {
|
||||
| (None, None) => None // all are smaller, and all are larger
|
||||
| (None, Some(i)) => Some(0.0) // none are smaller, all are larger
|
||||
| (Some(i), None) => Some(0.0) // all are smaller, none are larger
|
||||
| (None, Some(_)) => Some(0.0) // none are smaller, all are larger
|
||||
| (Some(_), None) => Some(0.0) // all are smaller, none are larger
|
||||
| (Some(i), Some(j)) =>
|
||||
Some(weightedMean(point, dist.xs[i], dist.xs[j], dist.ys[i], dist.ys[j])) // there is a lowerBound and an upperBound.
|
||||
Some(weightedMean(point, distShape.xs[i], distShape.xs[j], distShape.ys[i], distShape.ys[j])) // there is a lowerBound and an upperBound.
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
let combineAlongSupportOfSecondArgument2: (
|
||||
let combineAlongSupportOfSecondArgument: (
|
||||
(float, float) => result<float, Operation.Error.t>,
|
||||
T.t,
|
||||
T.t,
|
||||
) => result<T.t, Operation.Error.t> = (fn, prediction, answer) => {
|
||||
let combineWithFn = (x: float, i: int) => {
|
||||
let answerX = x
|
||||
let combineWithFn = (answerX: float, i: int) => {
|
||||
let answerY = answer.ys[i]
|
||||
let predictionY = getApproximatePdfOfContinuousDistributionAtPoint(prediction, answerX)
|
||||
let predictionY = approximatePdfOfContinuousDistributionAtPoint(prediction, answerX)
|
||||
let wrappedResult = E.O.fmap(x => fn(x, answerY), predictionY)
|
||||
let result = switch wrappedResult {
|
||||
switch wrappedResult {
|
||||
| Some(x) => x
|
||||
| None => Error(Operation.LogicallyInconsistentPathwayError)
|
||||
}
|
||||
result
|
||||
}
|
||||
let newYsWithError = Js.Array.mapi((x, i) => combineWithFn(x, i), answer.xs)
|
||||
let newYsOrError = E.A.R.firstErrorOrOpen(newYsWithError)
|
||||
|
@ -512,7 +510,6 @@ module PointwiseCombination = {
|
|||
| Error(b) => Error(b)
|
||||
}
|
||||
|
||||
// T.filterOkYs(newXs, newYs)->Ok
|
||||
result
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user