response to CR
Value: [1e-5 to 5e-3]
parent aff2c622f7
commit 06352357a2
@@ -27,8 +27,11 @@ describe("kl divergence", () => {
      }
    })
  }
  // The pair on the right (the answer) can be wider than the pair on the left (the prediction), but not the other way around.
  testUniform(0.0, 1.0, -1.0, 2.0)
  testUniform(0.0, 1.0, 0.0, 2.0)
  testUniform(0.0, 1.0, 0.0, 2.0) // equal left endpoints
  testUniform(0.0, 1.0, -1.0, 1.0) // equal right endpoints
  testUniform(0.0, 1e1, 0.0, 1e1) // equal (klDivergence = 0)
  // testUniform(-1.0, 1.0, 0.0, 2.0)

  test("of two normals is equal to the formula", () => {
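An aside on the support comment above: KL divergence is only finite when the distribution in the log's denominator covers the support of the one being integrated. On the reading these tests suggest (the narrower prediction scored against a wider-or-equal answer; the integrand itself is not visible in this hunk), the uniform cases collapse to a closed form:

    D_{\mathrm{KL}}\!\left(\mathrm{Uniform}(a,b)\,\middle\|\,\mathrm{Uniform}(c,d)\right)
      = \int_a^b \frac{1}{b-a}\,\log\frac{1/(b-a)}{1/(d-c)}\,dx
      = \log\frac{d-c}{b-a}
      \qquad \text{for } [a,b] \subseteq [c,d]

This is 0 when the endpoints coincide (the `klDivergence = 0` case above) and is undefined once the prediction has mass outside the answer's support, which is presumably why the reversed-order case stays commented out.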
@@ -58,6 +61,7 @@ describe("kl divergence", () => {
  })

  describe("combine along support test", () => {
    // This tests the version of the function that we're NOT using. Haven't deleted the test in case we use the code later.
    test("combine along support test", _ => {
      let combineAlongSupportOfSecondArgument = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument0
      let lowAnswer = 0.0
@@ -199,13 +199,7 @@ module T = Dist({
  let klDivergence = (t1: t, t2: t) =>
    switch (t1, t2) {
    | (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2)
    | (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2)
    | (Mixed(t1), Mixed(t2)) => Mixed.T.klDivergence(t1, t2)
    | _ => {
        let t1' = toMixed(t1)
        let t2' = toMixed(t2)
        Mixed.T.klDivergence(t1', t2')
      }
    | _ => Error(NotYetImplemented)
    }
})
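Reading this hunk as replacing the toMixed fallback with an explicit error (which the new NotYetImplemented variant below supports), the cross-kind behaviour changes: any pair that is not Continuous/Continuous or Discrete/Discrete now comes back as an error value that callers must handle. A self-contained toy reproduction of that shape, with toy types and a stand-in computation rather than the project's:

type toyDist = Continuous(array<float>) | Discrete(array<float>) | Mixed(array<float>)
type toyError = NotYetImplemented

// Matching kinds are computed; anything else is reported, not coerced.
let klDivergenceSketch = (t1, t2): result<float, toyError> =>
  switch (t1, t2) {
  | (Continuous(_), Continuous(_))
  | (Discrete(_), Discrete(_)) =>
    Ok(0.0) // stand-in for the real per-kind computation
  | _ => Error(NotYetImplemented)
  }

switch klDivergenceSketch(Continuous([]), Mixed([])) {
| Ok(x) => Js.log(x)
| Error(NotYetImplemented) => Js.log("klDivergence: not implemented for this pair yet")
}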
@@ -56,6 +56,7 @@ type operationError =
  | InfinityError
  | NegativeInfinityError
  | LogicallyInconsistentPathwayError
  | NotYetImplemented // should be removed when `klDivergence` for mixed and discrete is implemented.

@genType
module Error = {
@@ -69,6 +70,7 @@ module Error = {
  | InfinityError => "Operation returned positive infinity"
  | NegativeInfinityError => "Operation returned negative infinity"
  | LogicallyInconsistentPathwayError => "This pathway should have been logically unreachable"
  | NotYetImplemented => "This pathway is not yet implemented"
  }
}
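This hunk follows from the previous one: ReScript checks switches over a variant for exhaustiveness, so adding NotYetImplemented to operationError triggers a warning (an error in most configurations) in every switch that does not yet handle it, including this message table. A toy reproduction of that pressure:

type toyError = InfinityError | NotYetImplemented

let toMessage = err =>
  switch err {
  | InfinityError => "Operation returned positive infinity"
  | NotYetImplemented => "This pathway is not yet implemented" // dropping this branch trips the exhaustiveness check
  }

Js.log(toMessage(NotYetImplemented))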
@@ -440,9 +440,6 @@ module PointwiseCombination = {
    } else {
      i := i.contents + 1
      None
      // (0.0, 0.0, 0.0) // for the function I have in mind, this will error out
      // exception PointwiseCombinationError
      // raise(PointwiseCombinationError)
    }
  }
  switch someTuple {
@@ -456,40 +453,6 @@ module PointwiseCombination = {
    T.filterOkYs(newXs, newYs)->Ok
  }

  let approximatePdfOfContinuousDistributionAtPoint: (xyShape, float) => option<float> = (
    distShape,
    point,
  ) => {
    let closestFromBelowIndex = E.A.reducei(distShape.xs, None, (accumulator, item, index) =>
      item < point ? Some(index) : accumulator
    ) // This could be made more efficient by taking advantage of the fact that these are ordered
    let closestFromAboveIndexOption = Belt.Array.getIndexBy(distShape.xs, item => item > point)

    let weightedMean = (
      point: float,
      closestFromBelow: float,
      closestFromAbove: float,
      valueclosestFromBelow,
      valueclosestFromAbove,
    ): float => {
      let distance = closestFromAbove -. closestFromBelow
      let w1 = (point -. closestFromBelow) /. distance
      let w2 = (closestFromAbove -. point) /. distance
      let result = w1 *. valueclosestFromAbove +. w2 *. valueclosestFromBelow
      result
    }

    let result = switch (closestFromBelowIndex, closestFromAboveIndexOption) {
    | (None, None) => None // none are smaller and none are larger
    | (None, Some(_)) => Some(0.0) // none are smaller, all are larger
    | (Some(_), None) => Some(0.0) // all are smaller, none are larger
    | (Some(i), Some(j)) =>
      Some(weightedMean(point, distShape.xs[i], distShape.xs[j], distShape.ys[i], distShape.ys[j])) // there is a lowerBound and an upperBound.
    }

    result
  }

  // This function is used for klDivergence
  let combineAlongSupportOfSecondArgument: (
    (float, float) => result<float, Operation.Error.t>,
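The weightedMean inside the deleted helper is ordinary linear interpolation between the two bracketing samples, apparently the same job the next hunk hands off to XtoY.linear. A standalone sketch of that arithmetic, with hypothetical names:

// Linearly interpolate the y-value at `point` from the nearest samples below and above it.
let interpolate = (point, xBelow, xAbove, yBelow, yAbove) => {
  let distance = xAbove -. xBelow
  let weightBelow = (xAbove -. point) /. distance
  let weightAbove = (point -. xBelow) /. distance
  weightBelow *. yBelow +. weightAbove *. yAbove
}

Js.log(interpolate(1.5, 1.0, 2.0, 10.0, 20.0)) // 15, halfway between the two samples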
@@ -498,12 +461,8 @@ module PointwiseCombination = {
  ) => result<T.t, Operation.Error.t> = (fn, prediction, answer) => {
    let combineWithFn = (answerX: float, i: int) => {
      let answerY = answer.ys[i]
      let predictionY = approximatePdfOfContinuousDistributionAtPoint(prediction, answerX)
      let wrappedResult = E.O.fmap(x => fn(x, answerY), predictionY)
      switch wrappedResult {
      | Some(x) => x
      | None => Error(Operation.LogicallyInconsistentPathwayError)
      }
      let predictionY = XtoY.linear(answerX, prediction)
      fn(predictionY, answerY)
    }
    let newYsWithError = Js.Array.mapi((x, i) => combineWithFn(x, i), answer.xs)
    let newYsOrError = E.A.R.firstErrorOrOpen(newYsWithError)
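After mapping every x in the answer's support through combineWithFn, the code holds an array of results and needs one result for the whole shape. E.A.R.firstErrorOrOpen (the name is taken from the diff; its definition is not shown here) presumably collapses array<result<'a, 'e>> into result<array<'a>, 'e>, keeping the first error encountered. A self-contained toy version of that collapse:

let firstErrorOrOpen = (results: array<result<float, string>>): result<array<float>, string> =>
  Belt.Array.reduce(results, Ok([]), (acc, r) =>
    switch (acc, r) {
    | (Error(e), _) => Error(e) // an earlier error wins
    | (_, Error(e)) => Error(e)
    | (Ok(xs), Ok(x)) => Ok(Belt.Array.concat(xs, [x]))
    }
  )

Js.log(firstErrorOrOpen([Ok(1.0), Error("no prediction mass here"), Ok(3.0)])) // Error("no prediction mass here")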