Merge pull request #512 from quantified-uncertainty/kldivergence-mixed

`klDivergence` on mixed distributions
This commit is contained in:
Quinn 2022-05-12 11:26:01 -04:00 committed by GitHub
commit 937458cd05
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 122 additions and 5 deletions

View File

@ -11,6 +11,7 @@ let triangularDist: DistributionTypes.genericDist = Symbolic(
)
let exponentialDist: DistributionTypes.genericDist = Symbolic(#Exponential({rate: 2.0}))
let uniformDist: DistributionTypes.genericDist = Symbolic(#Uniform({low: 9.0, high: 10.0}))
let uniformDist2: DistributionTypes.genericDist = Symbolic(#Uniform({low: 8.0, high: 11.0}))
let floatDist: DistributionTypes.genericDist = Symbolic(#Float(1e1))
exception KlFailed

View File

@ -3,9 +3,15 @@ open Expect
open TestHelpers
open GenericDist_Fixtures
// integral from low to high of 1 / (high - low) log(normal(mean, stdev)(x) / (1 / (high - low))) dx
let klNormalUniform = (mean, stdev, low, high): float =>
-.Js.Math.log((high -. low) /. Js.Math.sqrt(2.0 *. MagicNumbers.Math.pi *. stdev ** 2.0)) +.
1.0 /.
stdev ** 2.0 *.
(mean ** 2.0 -. (high +. low) *. mean +. (low ** 2.0 +. high *. low +. high ** 2.0) /. 3.0)
describe("klDivergence: continuous -> continuous -> float", () => {
let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
exception KlFailed
let testUniform = (lowAnswer, highAnswer, lowPrediction, highPrediction) => {
test("of two uniforms is equal to the analytic expression", () => {
@ -59,6 +65,20 @@ describe("klDivergence: continuous -> continuous -> float", () => {
}
}
})
test("of a normal and a uniform is equal to the formula", () => {
let prediction = normalDist10
let answer = uniformDist
let kl = klDivergence(prediction, answer)
let analyticalKl = klNormalUniform(10.0, 2.0, 9.0, 10.0)
switch kl {
| Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=1)
| Error(err) => {
Js.Console.log(DistributionTypes.Error.toString(err))
raise(KlFailed)
}
}
})
})
describe("klDivergence: discrete -> discrete -> float", () => {
@ -96,6 +116,64 @@ describe("klDivergence: discrete -> discrete -> float", () => {
})
})
describe("klDivergence: mixed -> mixed -> float", () => {
let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
let mixture' = a => DistributionTypes.DistributionOperation.Mixture(a)
let mixture = a => {
let dist' = a->mixture'->run
switch dist' {
| Dist(dist) => dist
| _ => raise(MixtureFailed)
}
}
let a = [(point1, 1.0), (uniformDist, 1.0)]->mixture
let b = [(point1, 1.0), (floatDist, 1.0), (normalDist10, 1.0)]->mixture
let c = [(point1, 1.0), (point2, 1.0), (point3, 1.0), (uniformDist, 1.0)]->mixture
let d =
[(point1, 1.0), (point2, 1.0), (point3, 1.0), (floatDist, 1.0), (uniformDist2, 1.0)]->mixture
test("finite klDivergence produces correct answer", () => {
let prediction = b
let answer = a
let kl = klDivergence(prediction, answer)
// high = 10; low = 9; mean = 10; stdev = 2
let analyticalKlContinuousPart = klNormalUniform(10.0, 2.0, 9.0, 10.0) /. 2.0
let analyticalKlDiscretePart = 1.0 /. 2.0 *. Js.Math.log(2.0 /. 1.0)
switch kl {
| Ok(kl') =>
kl'->expect->toBeSoCloseTo(analyticalKlContinuousPart +. analyticalKlDiscretePart, ~digits=1)
| Error(err) =>
Js.Console.log(DistributionTypes.Error.toString(err))
raise(KlFailed)
}
})
test("returns infinity when infinite", () => {
let prediction = a
let answer = b
let kl = klDivergence(prediction, answer)
switch kl {
| Ok(kl') => kl'->expect->toEqual(infinity)
| Error(err) =>
Js.Console.log(DistributionTypes.Error.toString(err))
raise(KlFailed)
}
})
test("finite klDivergence produces correct answer", () => {
let prediction = d
let answer = c
let kl = klDivergence(prediction, answer)
let analyticalKlContinuousPart = Js.Math.log((11.0 -. 8.0) /. (10.0 -. 9.0)) /. 4.0 // 4 = length of c' array
let analyticalKlDiscretePart = 3.0 /. 4.0 *. Js.Math.log(4.0 /. 3.0)
switch kl {
| Ok(kl') =>
kl'->expect->toBeSoCloseTo(analyticalKlContinuousPart +. analyticalKlDiscretePart, ~digits=1)
| Error(err) =>
Js.Console.log(DistributionTypes.Error.toString(err))
raise(KlFailed)
}
})
})
describe("combineAlongSupportOfSecondArgument0", () => {
// This tests the version of the function that we're NOT using. Haven't deleted the test in case we use the code later.
test("test on two uniforms", _ => {

View File

@ -302,10 +302,9 @@ module T = Dist({
}
let klDivergence = (prediction: t, answer: t) => {
Error(Operation.NotYetImplemented)
// combinePointwise(PointSetDist_Scoring.KLDivergence.integrand, prediction, answer) |> E.R.fmap(
// integralEndY,
// )
let klDiscretePart = Discrete.T.klDivergence(prediction.discrete, answer.discrete)
let klContinuousPart = Continuous.T.klDivergence(prediction.continuous, answer.continuous)
E.R.merge(klDiscretePart, klContinuousPart)->E.R2.fmap(t => fst(t) +. snd(t))
}
})

View File

@ -200,6 +200,7 @@ module T = Dist({
switch (t1, t2) {
| (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2)
| (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2)
| (Mixed(t1), Mixed(t2)) => Mixed.T.klDivergence(t1, t2)
| _ => Error(NotYetImplemented)
}
})

View File

@ -453,6 +453,44 @@ module PointwiseCombination = {
T.filterOkYs(newXs, newYs)->Ok
}
/* *Dead code*: Nuño wrote this function to try to increase precision, but it didn't work.
If another traveler comes through with a similar idea, we hope this implementation will help them.
By "enrich" we mean to increase granularity.
*/
let enrichXyShape = (t: T.t): T.t => {
let defaultEnrichmentFactor = 10
let length = E.A.length(t.xs)
let points =
length < MagicNumbers.Environment.defaultXYPointLength
? defaultEnrichmentFactor * MagicNumbers.Environment.defaultXYPointLength / length
: defaultEnrichmentFactor
let getInBetween = (x1: float, x2: float): array<float> => {
if abs_float(x1 -. x2) < 2.0 *. MagicNumbers.Epsilon.seven {
[x1]
} else {
let newPointsArray = Belt.Array.makeBy(points - 1, i => i)
// don't repeat the x2 point, it will be gotten in the next iteration.
let result = Js.Array.mapi((pos, i) =>
if i == 0 {
x1
} else {
let points' = Belt.Float.fromInt(points)
let pos' = Belt.Float.fromInt(pos)
x1 *. (points' -. pos') /. points' +. x2 *. pos' /. points'
}
, newPointsArray)
result
}
}
let newXsUnflattened = Js.Array.mapi(
(x, i) => i < length - 2 ? getInBetween(x, t.xs[i + 1]) : [x],
t.xs,
)
let newXs = Belt.Array.concatMany(newXsUnflattened)
let newYs = E.A.fmap(x => XtoY.linear(x, t), newXs)
{xs: newXs, ys: newYs}
}
// This function is used for klDivergence
let combineAlongSupportOfSecondArgument: (
(float, float) => result<float, Operation.Error.t>,