All three tests pass
- `uniform` `toPointSet` method has been changed for numerical stability. Value: [1e-1 to 1.75e0]
This commit is contained in:
parent
d9a40c973a
commit
722bfc6366
|
@ -70,13 +70,14 @@ describe("kl divergence", () => {
|
||||||
let prediction =
|
let prediction =
|
||||||
normalMakeR(mean1, stdev1)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
|
normalMakeR(mean1, stdev1)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
|
||||||
let answer = normalMakeR(mean2, stdev2)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
|
let answer = normalMakeR(mean2, stdev2)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
|
||||||
|
// https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians
|
||||||
let analyticalKl =
|
let analyticalKl =
|
||||||
Js.Math.log(stdev1 /. stdev2) +.
|
Js.Math.log(stdev1 /. stdev2) +.
|
||||||
(stdev2 ** 2.0 +. (mean2 -. mean1) ** 2.0) /. (2.0 *. stdev1 ** 2.0) -. 0.5
|
(stdev2 ** 2.0 +. (mean2 -. mean1) ** 2.0) /. (2.0 *. stdev1 ** 2.0) -. 0.5
|
||||||
let kl = E.R.liftJoin2(klDivergence, prediction, answer)
|
let kl = E.R.liftJoin2(klDivergence, prediction, answer)
|
||||||
|
|
||||||
Js.Console.log2("Analytical: ", analyticalKl)
|
// Js.Console.log2("Analytical: ", analyticalKl)
|
||||||
Js.Console.log2("Computed: ", kl)
|
// Js.Console.log2("Computed: ", kl)
|
||||||
|
|
||||||
switch kl {
|
switch kl {
|
||||||
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
|
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
|
||||||
|
@ -91,7 +92,7 @@ describe("kl divergence", () => {
|
||||||
describe("combine along support test", () => {
|
describe("combine along support test", () => {
|
||||||
Skip.test("combine along support test", _ => {
|
Skip.test("combine along support test", _ => {
|
||||||
// doesn't matter
|
// doesn't matter
|
||||||
let combineAlongSupportOfSecondArgument = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument
|
let combineAlongSupportOfSecondArgument = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument0
|
||||||
let lowAnswer = 0.0
|
let lowAnswer = 0.0
|
||||||
let highAnswer = 1.0
|
let highAnswer = 1.0
|
||||||
let lowPrediction = 0.0
|
let lowPrediction = 0.0
|
||||||
|
|
|
@ -271,38 +271,30 @@ module T = Dist({
|
||||||
let variance = (t: t): float =>
|
let variance = (t: t): float =>
|
||||||
XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)
|
XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)
|
||||||
|
|
||||||
let klDivergence0 = (prediction: t, answer: t) => {
|
// let klDivergence0 = (prediction: t, answer: t) => {
|
||||||
combinePointwise(
|
// combinePointwise(
|
||||||
~combiner=XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument,
|
// ~combiner=XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument,
|
||||||
PointSetDist_Scoring.KLDivergence.integrand,
|
// PointSetDist_Scoring.KLDivergence.integrand,
|
||||||
prediction,
|
// prediction,
|
||||||
answer,
|
// answer,
|
||||||
)
|
// )
|
||||||
|> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
|
// |> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
|
||||||
|> E.R.fmap(integralEndY)
|
// |> E.R.fmap(integralEndY)
|
||||||
}
|
// }
|
||||||
|
|
||||||
let klDivergence = (prediction: t, answer: t) => {
|
let klDivergence = (prediction: t, answer: t) => {
|
||||||
let newShape = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument2(
|
let newShape = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument(
|
||||||
PointSetDist_Scoring.KLDivergence.integrand,
|
PointSetDist_Scoring.KLDivergence.integrand,
|
||||||
prediction.xyShape,
|
prediction.xyShape,
|
||||||
answer.xyShape,
|
answer.xyShape,
|
||||||
)
|
)
|
||||||
let generateContinuousDistFromXYShape: XYShape.xyShape => t = xyShape => {
|
let xyShapeToContinuous: XYShape.xyShape => t = xyShape => {
|
||||||
xyShape: xyShape,
|
xyShape: xyShape,
|
||||||
interpolation: #Linear,
|
interpolation: #Linear,
|
||||||
integralSumCache: None,
|
integralSumCache: None,
|
||||||
integralCache: None,
|
integralCache: None,
|
||||||
}
|
}
|
||||||
let _ = Js.Console.log2("prediction", prediction)
|
newShape->E.R2.fmap(x => x->xyShapeToContinuous->integralEndY)
|
||||||
let _ = Js.Console.log2("answer", answer)
|
|
||||||
let _ = Js.Console.log2("newShape", newShape)
|
|
||||||
switch newShape {
|
|
||||||
| Ok(tshape) => Ok(integralEndY(generateContinuousDistFromXYShape(tshape)))
|
|
||||||
| Error(errormessage) => Error(errormessage)
|
|
||||||
}
|
|
||||||
//|> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
|
|
||||||
//|> E.R.fmap(integralEndY)
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -230,7 +230,7 @@ module T = Dist({
|
||||||
|
|
||||||
let klDivergence = (prediction: t, answer: t) => {
|
let klDivergence = (prediction: t, answer: t) => {
|
||||||
combinePointwise(
|
combinePointwise(
|
||||||
~combiner=XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument,
|
~combiner=XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument0,
|
||||||
~fn=PointSetDist_Scoring.KLDivergence.integrand,
|
~fn=PointSetDist_Scoring.KLDivergence.integrand,
|
||||||
prediction,
|
prediction,
|
||||||
answer,
|
answer,
|
||||||
|
|
|
@ -396,8 +396,9 @@ module T = {
|
||||||
| (#ByWeight, #Uniform(n)) =>
|
| (#ByWeight, #Uniform(n)) =>
|
||||||
// In `ByWeight mode, uniform distributions get special treatment because we need two x's
|
// In `ByWeight mode, uniform distributions get special treatment because we need two x's
|
||||||
// on either side for proper rendering (just left and right of the discontinuities).
|
// on either side for proper rendering (just left and right of the discontinuities).
|
||||||
let dx = 0.00001 *. (n.high -. n.low)
|
let distance = n.high -. n.low
|
||||||
[n.low -. dx, n.low +. dx, n.high -. dx, n.high +. dx]
|
let dx = MagicNumbers.Epsilon.ten *. distance
|
||||||
|
[n.low -. dx, n.low, n.low +. dx, n.high -. dx, n.high, n.high +. dx]
|
||||||
| (#ByWeight, _) =>
|
| (#ByWeight, _) =>
|
||||||
let ys = E.A.Floats.range(minCdfValue, maxCdfValue, n)
|
let ys = E.A.Floats.range(minCdfValue, maxCdfValue, n)
|
||||||
ys |> E.A.fmap(y => inv(y, dist))
|
ys |> E.A.fmap(y => inv(y, dist))
|
||||||
|
|
|
@ -391,7 +391,7 @@ module PointwiseCombination = {
|
||||||
`)
|
`)
|
||||||
|
|
||||||
// This function is used for kl divergence
|
// This function is used for kl divergence
|
||||||
let combineAlongSupportOfSecondArgument: (
|
let combineAlongSupportOfSecondArgument0: (
|
||||||
(float, float) => result<float, Operation.Error.t>,
|
(float, float) => result<float, Operation.Error.t>,
|
||||||
interpolator,
|
interpolator,
|
||||||
T.t,
|
T.t,
|
||||||
|
@ -455,14 +455,14 @@ module PointwiseCombination = {
|
||||||
T.filterOkYs(newXs, newYs)->Ok
|
T.filterOkYs(newXs, newYs)->Ok
|
||||||
}
|
}
|
||||||
|
|
||||||
let getApproximatePdfOfContinuousDistributionAtPoint: (xyShape, float) => option<float> = (
|
let approximatePdfOfContinuousDistributionAtPoint: (xyShape, float) => option<float> = (
|
||||||
dist: xyShape,
|
distShape,
|
||||||
point: float,
|
point,
|
||||||
) => {
|
) => {
|
||||||
let closestFromBelowIndex = E.A.reducei(dist.xs, None, (accumulator, item, index) =>
|
let closestFromBelowIndex = E.A.reducei(distShape.xs, None, (accumulator, item, index) =>
|
||||||
item < point ? Some(index) : accumulator
|
item < point ? Some(index) : accumulator
|
||||||
) // This could be made more efficient by taking advantage of the fact that these are ordered
|
) // This could be made more efficient by taking advantage of the fact that these are ordered
|
||||||
let closestFromAboveIndexOption = Belt.Array.getIndexBy(dist.xs, item => item > point)
|
let closestFromAboveIndexOption = Belt.Array.getIndexBy(distShape.xs, item => item > point)
|
||||||
|
|
||||||
let weightedMean = (
|
let weightedMean = (
|
||||||
point: float,
|
point: float,
|
||||||
|
@ -480,30 +480,28 @@ module PointwiseCombination = {
|
||||||
|
|
||||||
let result = switch (closestFromBelowIndex, closestFromAboveIndexOption) {
|
let result = switch (closestFromBelowIndex, closestFromAboveIndexOption) {
|
||||||
| (None, None) => None // all are smaller, and all are larger
|
| (None, None) => None // all are smaller, and all are larger
|
||||||
| (None, Some(i)) => Some(0.0) // none are smaller, all are larger
|
| (None, Some(_)) => Some(0.0) // none are smaller, all are larger
|
||||||
| (Some(i), None) => Some(0.0) // all are smaller, none are larger
|
| (Some(_), None) => Some(0.0) // all are smaller, none are larger
|
||||||
| (Some(i), Some(j)) =>
|
| (Some(i), Some(j)) =>
|
||||||
Some(weightedMean(point, dist.xs[i], dist.xs[j], dist.ys[i], dist.ys[j])) // there is a lowerBound and an upperBound.
|
Some(weightedMean(point, distShape.xs[i], distShape.xs[j], distShape.ys[i], distShape.ys[j])) // there is a lowerBound and an upperBound.
|
||||||
}
|
}
|
||||||
|
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
let combineAlongSupportOfSecondArgument2: (
|
let combineAlongSupportOfSecondArgument: (
|
||||||
(float, float) => result<float, Operation.Error.t>,
|
(float, float) => result<float, Operation.Error.t>,
|
||||||
T.t,
|
T.t,
|
||||||
T.t,
|
T.t,
|
||||||
) => result<T.t, Operation.Error.t> = (fn, prediction, answer) => {
|
) => result<T.t, Operation.Error.t> = (fn, prediction, answer) => {
|
||||||
let combineWithFn = (x: float, i: int) => {
|
let combineWithFn = (answerX: float, i: int) => {
|
||||||
let answerX = x
|
|
||||||
let answerY = answer.ys[i]
|
let answerY = answer.ys[i]
|
||||||
let predictionY = getApproximatePdfOfContinuousDistributionAtPoint(prediction, answerX)
|
let predictionY = approximatePdfOfContinuousDistributionAtPoint(prediction, answerX)
|
||||||
let wrappedResult = E.O.fmap(x => fn(x, answerY), predictionY)
|
let wrappedResult = E.O.fmap(x => fn(x, answerY), predictionY)
|
||||||
let result = switch wrappedResult {
|
switch wrappedResult {
|
||||||
| Some(x) => x
|
| Some(x) => x
|
||||||
| None => Error(Operation.LogicallyInconsistentPathwayError)
|
| None => Error(Operation.LogicallyInconsistentPathwayError)
|
||||||
}
|
}
|
||||||
result
|
|
||||||
}
|
}
|
||||||
let newYsWithError = Js.Array.mapi((x, i) => combineWithFn(x, i), answer.xs)
|
let newYsWithError = Js.Array.mapi((x, i) => combineWithFn(x, i), answer.xs)
|
||||||
let newYsOrError = E.A.R.firstErrorOrOpen(newYsWithError)
|
let newYsOrError = E.A.R.firstErrorOrOpen(newYsWithError)
|
||||||
|
@ -512,7 +510,6 @@ module PointwiseCombination = {
|
||||||
| Error(b) => Error(b)
|
| Error(b) => Error(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
// T.filterOkYs(newXs, newYs)->Ok
|
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user