response to CR

Value: [1e-5 to 5e-3]
This commit is contained in:
Quinn Dougherty 2022-05-09 11:14:33 -04:00
parent aff2c622f7
commit 06352357a2
4 changed files with 10 additions and 51 deletions

View File

@ -27,8 +27,11 @@ describe("kl divergence", () => {
} }
}) })
} }
// The pair on the right (the answer) can be wider than the pair on the left (the prediction), but not the other way around.
testUniform(0.0, 1.0, -1.0, 2.0) testUniform(0.0, 1.0, -1.0, 2.0)
testUniform(0.0, 1.0, 0.0, 2.0) testUniform(0.0, 1.0, 0.0, 2.0) // equal left endpoints
testUniform(0.0, 1.0, -1.0, 1.0) // equal rightendpoints
testUniform(0.0, 1e1, 0.0, 1e1) // equal (klDivergence = 0)
// testUniform(-1.0, 1.0, 0.0, 2.0) // testUniform(-1.0, 1.0, 0.0, 2.0)
test("of two normals is equal to the formula", () => { test("of two normals is equal to the formula", () => {
@ -58,6 +61,7 @@ describe("kl divergence", () => {
}) })
describe("combine along support test", () => { describe("combine along support test", () => {
// This tests the version of the function that we're NOT using. Haven't deleted the test in case we use the code later.
test("combine along support test", _ => { test("combine along support test", _ => {
let combineAlongSupportOfSecondArgument = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument0 let combineAlongSupportOfSecondArgument = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument0
let lowAnswer = 0.0 let lowAnswer = 0.0

View File

@ -199,13 +199,7 @@ module T = Dist({
let klDivergence = (t1: t, t2: t) => let klDivergence = (t1: t, t2: t) =>
switch (t1, t2) { switch (t1, t2) {
| (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2) | (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2)
| (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2) | _ => Error(NotYetImplemented)
| (Mixed(t1), Mixed(t2)) => Mixed.T.klDivergence(t1, t2)
| _ => {
let t1' = toMixed(t1)
let t2' = toMixed(t2)
Mixed.T.klDivergence(t1', t2')
}
} }
}) })

View File

@ -56,6 +56,7 @@ type operationError =
| InfinityError | InfinityError
| NegativeInfinityError | NegativeInfinityError
| LogicallyInconsistentPathwayError | LogicallyInconsistentPathwayError
| NotYetImplemented // should be removed when `klDivergence` for mixed and discrete is implemented.
@genType @genType
module Error = { module Error = {
@ -69,6 +70,7 @@ module Error = {
| InfinityError => "Operation returned positive infinity" | InfinityError => "Operation returned positive infinity"
| NegativeInfinityError => "Operation returned negative infinity" | NegativeInfinityError => "Operation returned negative infinity"
| LogicallyInconsistentPathwayError => "This pathway should have been logically unreachable" | LogicallyInconsistentPathwayError => "This pathway should have been logically unreachable"
| NotYetImplemented => "This pathway is not yet implemented"
} }
} }

View File

@ -440,9 +440,6 @@ module PointwiseCombination = {
} else { } else {
i := i.contents + 1 i := i.contents + 1
None None
// (0.0, 0.0, 0.0) // for the function I have in mind, this will error out
// exception PointwiseCombinationError
// raise(PointwiseCombinationError)
} }
} }
switch someTuple { switch someTuple {
@ -456,40 +453,6 @@ module PointwiseCombination = {
T.filterOkYs(newXs, newYs)->Ok T.filterOkYs(newXs, newYs)->Ok
} }
let approximatePdfOfContinuousDistributionAtPoint: (xyShape, float) => option<float> = (
distShape,
point,
) => {
let closestFromBelowIndex = E.A.reducei(distShape.xs, None, (accumulator, item, index) =>
item < point ? Some(index) : accumulator
) // This could be made more efficient by taking advantage of the fact that these are ordered
let closestFromAboveIndexOption = Belt.Array.getIndexBy(distShape.xs, item => item > point)
let weightedMean = (
point: float,
closestFromBelow: float,
closestFromAbove: float,
valueclosestFromBelow,
valueclosestFromAbove,
): float => {
let distance = closestFromAbove -. closestFromBelow
let w1 = (point -. closestFromBelow) /. distance
let w2 = (closestFromAbove -. point) /. distance
let result = w1 *. valueclosestFromAbove +. w2 *. valueclosestFromBelow
result
}
let result = switch (closestFromBelowIndex, closestFromAboveIndexOption) {
| (None, None) => None // all are smaller, and all are larger
| (None, Some(_)) => Some(0.0) // none are smaller, all are larger
| (Some(_), None) => Some(0.0) // all are smaller, none are larger
| (Some(i), Some(j)) =>
Some(weightedMean(point, distShape.xs[i], distShape.xs[j], distShape.ys[i], distShape.ys[j])) // there is a lowerBound and an upperBound.
}
result
}
// This function is used for klDivergence // This function is used for klDivergence
let combineAlongSupportOfSecondArgument: ( let combineAlongSupportOfSecondArgument: (
(float, float) => result<float, Operation.Error.t>, (float, float) => result<float, Operation.Error.t>,
@ -498,12 +461,8 @@ module PointwiseCombination = {
) => result<T.t, Operation.Error.t> = (fn, prediction, answer) => { ) => result<T.t, Operation.Error.t> = (fn, prediction, answer) => {
let combineWithFn = (answerX: float, i: int) => { let combineWithFn = (answerX: float, i: int) => {
let answerY = answer.ys[i] let answerY = answer.ys[i]
let predictionY = approximatePdfOfContinuousDistributionAtPoint(prediction, answerX) let predictionY = XtoY.linear(answerX, prediction)
let wrappedResult = E.O.fmap(x => fn(x, answerY), predictionY) fn(predictionY, answerY)
switch wrappedResult {
| Some(x) => x
| None => Error(Operation.LogicallyInconsistentPathwayError)
}
} }
let newYsWithError = Js.Array.mapi((x, i) => combineWithFn(x, i), answer.xs) let newYsWithError = Js.Array.mapi((x, i) => combineWithFn(x, i), answer.xs)
let newYsOrError = E.A.R.firstErrorOrOpen(newYsWithError) let newYsOrError = E.A.R.firstErrorOrOpen(newYsWithError)