Merge pull request #512 from quantified-uncertainty/kldivergence-mixed
`klDivergence` on mixed distributions
Commit 937458cd05
@@ -11,6 +11,7 @@ let triangularDist: DistributionTypes.genericDist = Symbolic(
)
let exponentialDist: DistributionTypes.genericDist = Symbolic(#Exponential({rate: 2.0}))
let uniformDist: DistributionTypes.genericDist = Symbolic(#Uniform({low: 9.0, high: 10.0}))
let uniformDist2: DistributionTypes.genericDist = Symbolic(#Uniform({low: 8.0, high: 11.0}))
let floatDist: DistributionTypes.genericDist = Symbolic(#Float(1e1))

exception KlFailed
@@ -3,9 +3,15 @@ open Expect
open TestHelpers
open GenericDist_Fixtures

// integral from low to high of 1 / (high - low) log((1 / (high - low)) / normal(mean, stdev)(x)) dx
let klNormalUniform = (mean, stdev, low, high): float =>
  -.Js.Math.log((high -. low) /. Js.Math.sqrt(2.0 *. MagicNumbers.Math.pi *. stdev ** 2.0)) +.
  1.0 /.
  stdev ** 2.0 *.
  (mean ** 2.0 -. (high +. low) *. mean +. (low ** 2.0 +. high *. low +. high ** 2.0) /. 3.0)

describe("klDivergence: continuous -> continuous -> float", () => {
  let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
  exception KlFailed

  let testUniform = (lowAnswer, highAnswer, lowPrediction, highPrediction) => {
    test("of two uniforms is equal to the analytic expression", () => {
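For reference, the closed form that `klNormalUniform` encodes is the standard KL divergence of the uniform answer from the normal prediction. The derivation below is supplied here for clarity and is not part of the PR; l, h, mu, sigma stand for low, high, mean, stdev, and $\varphi_{\mu,\sigma}$ is the normal density.

$$
D_{\mathrm{KL}}\big(U(l,h)\,\|\,\mathcal{N}(\mu,\sigma)\big)
  = \int_l^h \frac{1}{h-l}\,\ln\frac{1/(h-l)}{\varphi_{\mu,\sigma}(x)}\,dx
  = -\ln\frac{h-l}{\sqrt{2\pi\sigma^2}}
    + \frac{1}{2\sigma^2}\left(\mu^2-(h+l)\,\mu+\frac{l^2+lh+h^2}{3}\right)
$$

using $\int_l^h (x-\mu)^2\,dx = \big[(h-\mu)^3-(l-\mu)^3\big]/3$. For the values used in the test below (mean = 10, stdev = 2, low = 9, high = 10) this evaluates to roughly 1.65.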
@@ -59,6 +65,20 @@ describe("klDivergence: continuous -> continuous -> float", () => {
      }
    }
  })

  test("of a normal and a uniform is equal to the formula", () => {
    let prediction = normalDist10
    let answer = uniformDist
    let kl = klDivergence(prediction, answer)
    let analyticalKl = klNormalUniform(10.0, 2.0, 9.0, 10.0)
    switch kl {
    | Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=1)
    | Error(err) => {
        Js.Console.log(DistributionTypes.Error.toString(err))
        raise(KlFailed)
      }
    }
  })
})

describe("klDivergence: discrete -> discrete -> float", () => {
@@ -96,6 +116,64 @@ describe("klDivergence: discrete -> discrete -> float", () => {
  })
})

describe("klDivergence: mixed -> mixed -> float", () => {
  let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
  let mixture' = a => DistributionTypes.DistributionOperation.Mixture(a)
  let mixture = a => {
    let dist' = a->mixture'->run
    switch dist' {
    | Dist(dist) => dist
    | _ => raise(MixtureFailed)
    }
  }
  let a = [(point1, 1.0), (uniformDist, 1.0)]->mixture
  let b = [(point1, 1.0), (floatDist, 1.0), (normalDist10, 1.0)]->mixture
  let c = [(point1, 1.0), (point2, 1.0), (point3, 1.0), (uniformDist, 1.0)]->mixture
  let d =
    [(point1, 1.0), (point2, 1.0), (point3, 1.0), (floatDist, 1.0), (uniformDist2, 1.0)]->mixture

  test("finite klDivergence produces correct answer", () => {
    let prediction = b
    let answer = a
    let kl = klDivergence(prediction, answer)
    // high = 10; low = 9; mean = 10; stdev = 2
    let analyticalKlContinuousPart = klNormalUniform(10.0, 2.0, 9.0, 10.0) /. 2.0
    let analyticalKlDiscretePart = 1.0 /. 2.0 *. Js.Math.log(2.0 /. 1.0)
    switch kl {
    | Ok(kl') =>
      kl'->expect->toBeSoCloseTo(analyticalKlContinuousPart +. analyticalKlDiscretePart, ~digits=1)
    | Error(err) =>
      Js.Console.log(DistributionTypes.Error.toString(err))
      raise(KlFailed)
    }
  })
  test("returns infinity when infinite", () => {
    let prediction = a
    let answer = b
    let kl = klDivergence(prediction, answer)
    switch kl {
    | Ok(kl') => kl'->expect->toEqual(infinity)
    | Error(err) =>
      Js.Console.log(DistributionTypes.Error.toString(err))
      raise(KlFailed)
    }
  })
  test("finite klDivergence produces correct answer", () => {
    let prediction = d
    let answer = c
    let kl = klDivergence(prediction, answer)
    let analyticalKlContinuousPart = Js.Math.log((11.0 -. 8.0) /. (10.0 -. 9.0)) /. 4.0 // 4 = length of c' array
    let analyticalKlDiscretePart = 3.0 /. 4.0 *. Js.Math.log(4.0 /. 3.0)
    switch kl {
    | Ok(kl') =>
      kl'->expect->toBeSoCloseTo(analyticalKlContinuousPart +. analyticalKlDiscretePart, ~digits=1)
    | Error(err) =>
      Js.Console.log(DistributionTypes.Error.toString(err))
      raise(KlFailed)
    }
  })
})

describe("combineAlongSupportOfSecondArgument0", () => {
  // This tests the version of the function that we're NOT using. Haven't deleted the test in case we use the code later.
  test("test on two uniforms", _ => {
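A note on the `Js.Math.log((11.0 -. 8.0) /. (10.0 -. 9.0))` term in the last test above: when the answer's support is contained in the prediction's, the KL divergence of two uniforms has a simple closed form (stated here for reference, not taken from the PR):

$$
D_{\mathrm{KL}}\big(U(l_a,h_a)\,\|\,U(l_p,h_p)\big) = \ln\frac{h_p-l_p}{h_a-l_a},
\qquad l_p \le l_a \le h_a \le h_p,
$$

so $D_{\mathrm{KL}}(U(9,10)\,\|\,U(8,11)) = \ln 3 \approx 1.10$. The test then divides by 4, the number of equally weighted components in the answer mixture `c`, i.e. the normalized weight of its uniform component.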
@@ -302,10 +302,9 @@ module T = Dist({
  }

  let klDivergence = (prediction: t, answer: t) => {
    Error(Operation.NotYetImplemented)
    // combinePointwise(PointSetDist_Scoring.KLDivergence.integrand, prediction, answer) |> E.R.fmap(
    //   integralEndY,
    // )
    let klDiscretePart = Discrete.T.klDivergence(prediction.discrete, answer.discrete)
    let klContinuousPart = Continuous.T.klDivergence(prediction.continuous, answer.continuous)
    E.R.merge(klDiscretePart, klContinuousPart)->E.R2.fmap(t => fst(t) +. snd(t))
  }
})
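The Mixed implementation above computes the discrete and continuous contributions separately and then sums them, propagating an error if either part fails. Below is a minimal self-contained sketch of that merge-and-sum step, written against ReScript's built-in `result` type; the library's `E.R.merge` / `E.R2.fmap` helpers are assumed here to behave analogously, and `mergeAndSum` is a hypothetical name, not the library's API.

// Sketch: combine two result-typed KL parts, summing on success and
// propagating an error otherwise (assumed equivalent of E.R.merge + fmap).
let mergeAndSum = (a: result<float, 'e>, b: result<float, 'e>): result<float, 'e> =>
  switch (a, b) {
  | (Ok(x), Ok(y)) => Ok(x +. y)
  | (Error(e), _) | (_, Error(e)) => Error(e)
  }

// mergeAndSum(Ok(0.3), Ok(0.2)) == Ok(0.5)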
@@ -200,6 +200,7 @@ module T = Dist({
    switch (t1, t2) {
    | (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2)
    | (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2)
    | (Mixed(t1), Mixed(t2)) => Mixed.T.klDivergence(t1, t2)
    | _ => Error(NotYetImplemented)
    }
})
@@ -453,6 +453,44 @@ module PointwiseCombination = {
    T.filterOkYs(newXs, newYs)->Ok
  }

  /* *Dead code*: Nuño wrote this function to try to increase precision, but it didn't work.
     If another traveler comes through with a similar idea, we hope this implementation will help them.
     By "enrich" we mean to increase granularity.
  */
  let enrichXyShape = (t: T.t): T.t => {
    let defaultEnrichmentFactor = 10
    let length = E.A.length(t.xs)
    let points =
      length < MagicNumbers.Environment.defaultXYPointLength
        ? defaultEnrichmentFactor * MagicNumbers.Environment.defaultXYPointLength / length
        : defaultEnrichmentFactor

    let getInBetween = (x1: float, x2: float): array<float> => {
      if abs_float(x1 -. x2) < 2.0 *. MagicNumbers.Epsilon.seven {
        [x1]
      } else {
        let newPointsArray = Belt.Array.makeBy(points - 1, i => i)
        // don't repeat the x2 point; it will be added by the next iteration.
        let result = Js.Array.mapi((pos, i) =>
          if i == 0 {
            x1
          } else {
            let points' = Belt.Float.fromInt(points)
            let pos' = Belt.Float.fromInt(pos)
            x1 *. (points' -. pos') /. points' +. x2 *. pos' /. points'
          }
        , newPointsArray)
        result
      }
    }
    let newXsUnflattened = Js.Array.mapi(
      (x, i) => i < length - 2 ? getInBetween(x, t.xs[i + 1]) : [x],
      t.xs,
    )
    let newXs = Belt.Array.concatMany(newXsUnflattened)
    let newYs = E.A.fmap(x => XtoY.linear(x, t), newXs)
    {xs: newXs, ys: newYs}
  }
  // This function is used for klDivergence
  let combineAlongSupportOfSecondArgument: (
    (float, float) => result<float, Operation.Error.t>,
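To make the "increase granularity" idea above concrete, here is a rough standalone sketch of the point-insertion step using plain arrays only. The `enrich` helper is hypothetical and written for illustration; the library's `E.A`, `MagicNumbers`, and `XtoY` helpers are deliberately not used.

// Hypothetical sketch: insert k - 1 evenly spaced points between each pair of
// consecutive xs; the last original x is kept as-is.
let enrich = (xs: array<float>, k: int): array<float> => {
  let n = Belt.Array.length(xs)
  xs
  ->Js.Array2.mapi((x, i) =>
    i >= n - 1
      ? [x]
      : Belt.Array.makeBy(k, j =>
          x +. (xs[i + 1] -. x) *. Belt.Float.fromInt(j) /. Belt.Float.fromInt(k)
        )
  )
  ->Belt.Array.concatMany
}

// enrich([0.0, 1.0, 2.0], 4) == [0.0, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0]

Applied to an XYShape, the ys would then be re-read off the original shape by linear interpolation, which is what enrichXyShape does with XtoY.linear.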