Merge pull request #512 from quantified-uncertainty/kldivergence-mixed
`klDivergence` on mixed distributions
commit 937458cd05

@@ -11,6 +11,7 @@ let triangularDist: DistributionTypes.genericDist = Symbolic(
 )
 let exponentialDist: DistributionTypes.genericDist = Symbolic(#Exponential({rate: 2.0}))
 let uniformDist: DistributionTypes.genericDist = Symbolic(#Uniform({low: 9.0, high: 10.0}))
+let uniformDist2: DistributionTypes.genericDist = Symbolic(#Uniform({low: 8.0, high: 11.0}))
 let floatDist: DistributionTypes.genericDist = Symbolic(#Float(1e1))
 
 exception KlFailed

@@ -3,9 +3,15 @@ open Expect
 open TestHelpers
 open GenericDist_Fixtures
 
+// integral from low to high of 1 / (high - low) log(normal(mean, stdev)(x) / (1 / (high - low))) dx
+let klNormalUniform = (mean, stdev, low, high): float =>
+  -.Js.Math.log((high -. low) /. Js.Math.sqrt(2.0 *. MagicNumbers.Math.pi *. stdev ** 2.0)) +.
+  1.0 /.
+  stdev ** 2.0 *.
+  (mean ** 2.0 -. (high +. low) *. mean +. (low ** 2.0 +. high *. low +. high ** 2.0) /. 3.0)
+
 describe("klDivergence: continuous -> continuous -> float", () => {
   let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
-  exception KlFailed
 
   let testUniform = (lowAnswer, highAnswer, lowPrediction, highPrediction) => {
     test("of two uniforms is equal to the analytic expression", () => {

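A note on the formula above (standard notation, not taken from the diff): for an answer density a and a prediction density p, the continuous Kullback-Leibler divergence is

    D_{KL}(a \parallel p) = \int a(x) \, \ln \frac{a(x)}{p(x)} \, dx

`klNormalUniform` is used by the tests below as the analytic reference value for a uniform answer on [low, high] scored against a normal prediction N(mean, stdev); the polynomial factor on its last line is the uniform distribution's second moment about `mean`:

    E_{U(low, high)}[(x - mean)^2] = mean^2 - (high + low) \cdot mean + (low^2 + low \cdot high + high^2) / 3
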
@@ -59,6 +65,20 @@ describe("klDivergence: continuous -> continuous -> float", () => {
       }
     }
   })
+
+  test("of a normal and a uniform is equal to the formula", () => {
+    let prediction = normalDist10
+    let answer = uniformDist
+    let kl = klDivergence(prediction, answer)
+    let analyticalKl = klNormalUniform(10.0, 2.0, 9.0, 10.0)
+    switch kl {
+    | Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=1)
+    | Error(err) => {
+        Js.Console.log(DistributionTypes.Error.toString(err))
+        raise(KlFailed)
+      }
+    }
+  })
 })
 
 describe("klDivergence: discrete -> discrete -> float", () => {

@@ -96,6 +116,64 @@ describe("klDivergence: discrete -> discrete -> float", () => {
   })
 })
 
+describe("klDivergence: mixed -> mixed -> float", () => {
+  let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
+  let mixture' = a => DistributionTypes.DistributionOperation.Mixture(a)
+  let mixture = a => {
+    let dist' = a->mixture'->run
+    switch dist' {
+    | Dist(dist) => dist
+    | _ => raise(MixtureFailed)
+    }
+  }
+  let a = [(point1, 1.0), (uniformDist, 1.0)]->mixture
+  let b = [(point1, 1.0), (floatDist, 1.0), (normalDist10, 1.0)]->mixture
+  let c = [(point1, 1.0), (point2, 1.0), (point3, 1.0), (uniformDist, 1.0)]->mixture
+  let d =
+    [(point1, 1.0), (point2, 1.0), (point3, 1.0), (floatDist, 1.0), (uniformDist2, 1.0)]->mixture
+
+  test("finite klDivergence produces correct answer", () => {
+    let prediction = b
+    let answer = a
+    let kl = klDivergence(prediction, answer)
+    // high = 10; low = 9; mean = 10; stdev = 2
+    let analyticalKlContinuousPart = klNormalUniform(10.0, 2.0, 9.0, 10.0) /. 2.0
+    let analyticalKlDiscretePart = 1.0 /. 2.0 *. Js.Math.log(2.0 /. 1.0)
+    switch kl {
+    | Ok(kl') =>
+      kl'->expect->toBeSoCloseTo(analyticalKlContinuousPart +. analyticalKlDiscretePart, ~digits=1)
+    | Error(err) =>
+      Js.Console.log(DistributionTypes.Error.toString(err))
+      raise(KlFailed)
+    }
+  })
+  test("returns infinity when infinite", () => {
+    let prediction = a
+    let answer = b
+    let kl = klDivergence(prediction, answer)
+    switch kl {
+    | Ok(kl') => kl'->expect->toEqual(infinity)
+    | Error(err) =>
+      Js.Console.log(DistributionTypes.Error.toString(err))
+      raise(KlFailed)
+    }
+  })
+  test("finite klDivergence produces correct answer", () => {
+    let prediction = d
+    let answer = c
+    let kl = klDivergence(prediction, answer)
+    let analyticalKlContinuousPart = Js.Math.log((11.0 -. 8.0) /. (10.0 -. 9.0)) /. 4.0 // 4 = length of c' array
+    let analyticalKlDiscretePart = 3.0 /. 4.0 *. Js.Math.log(4.0 /. 3.0)
+    switch kl {
+    | Ok(kl') =>
+      kl'->expect->toBeSoCloseTo(analyticalKlContinuousPart +. analyticalKlDiscretePart, ~digits=1)
+    | Error(err) =>
+      Js.Console.log(DistributionTypes.Error.toString(err))
+      raise(KlFailed)
+    }
+  })
+})
+
 describe("combineAlongSupportOfSecondArgument0", () => {
   // This tests the version of the function that we're NOT using. Haven't deleted the test in case we use the code later.
   test("test on two uniforms", _ => {

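Worked out (my arithmetic; the diff only asserts the expectations themselves), the analytic values the two finite mixed tests check against are

    klNormalUniform(10, 2, 9, 10) / 2 + (1/2) ln 2 ≈ 0.85 + 0.35 ≈ 1.19

for prediction `b` against answer `a`, and

    ln(3) / 4 + (3/4) ln(4/3) ≈ 0.27 + 0.22 ≈ 0.49

for prediction `d` against answer `c`. In both tests each component's divergence is weighted by the answer mixture's share of mass on that component (1/2 and 1/2 in the first, 1/4 continuous and 3/4 discrete in the second), which is consistent with the component-wise sum in the implementation hunk below.
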
@@ -302,10 +302,9 @@ module T = Dist({
   }
 
   let klDivergence = (prediction: t, answer: t) => {
-    Error(Operation.NotYetImplemented)
-    // combinePointwise(PointSetDist_Scoring.KLDivergence.integrand, prediction, answer) |> E.R.fmap(
-    //   integralEndY,
-    // )
+    let klDiscretePart = Discrete.T.klDivergence(prediction.discrete, answer.discrete)
+    let klContinuousPart = Continuous.T.klDivergence(prediction.continuous, answer.continuous)
+    E.R.merge(klDiscretePart, klContinuousPart)->E.R2.fmap(t => fst(t) +. snd(t))
   }
 })
 
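This hunk is the core of the change: `Mixed.T.klDivergence` now scores the discrete and continuous components separately and adds the results. As a sketch of the quantity involved (notation mine): if the answer a and prediction p each consist of point masses a_i, p_i plus densities a_c(x), p_c(x), the divergence decomposes as

    D_{KL}(a \parallel p) = \sum_i a_i \ln \frac{a_i}{p_i} + \int a_c(x) \, \ln \frac{a_c(x)}{p_c(x)} \, dx

`E.R.merge` pairs the two results, so the final `E.R2.fmap(t => fst(t) +. snd(t))` adds them only when both `Discrete.T.klDivergence` and `Continuous.T.klDivergence` return `Ok`; otherwise an `Error` is propagated instead.
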
@@ -200,6 +200,7 @@ module T = Dist({
     switch (t1, t2) {
     | (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2)
     | (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2)
+    | (Mixed(t1), Mixed(t2)) => Mixed.T.klDivergence(t1, t2)
     | _ => Error(NotYetImplemented)
     }
   })

@@ -453,6 +453,44 @@ module PointwiseCombination = {
     T.filterOkYs(newXs, newYs)->Ok
   }
 
+  /* *Dead code*: Nuño wrote this function to try to increase precision, but it didn't work.
+     If another traveler comes through with a similar idea, we hope this implementation will help them.
+     By "enrich" we mean to increase granularity.
+  */
+  let enrichXyShape = (t: T.t): T.t => {
+    let defaultEnrichmentFactor = 10
+    let length = E.A.length(t.xs)
+    let points =
+      length < MagicNumbers.Environment.defaultXYPointLength
+        ? defaultEnrichmentFactor * MagicNumbers.Environment.defaultXYPointLength / length
+        : defaultEnrichmentFactor
+
+    let getInBetween = (x1: float, x2: float): array<float> => {
+      if abs_float(x1 -. x2) < 2.0 *. MagicNumbers.Epsilon.seven {
+        [x1]
+      } else {
+        let newPointsArray = Belt.Array.makeBy(points - 1, i => i)
+        // don't repeat the x2 point, it will be gotten in the next iteration.
+        let result = Js.Array.mapi((pos, i) =>
+          if i == 0 {
+            x1
+          } else {
+            let points' = Belt.Float.fromInt(points)
+            let pos' = Belt.Float.fromInt(pos)
+            x1 *. (points' -. pos') /. points' +. x2 *. pos' /. points'
+          }
+        , newPointsArray)
+        result
+      }
+    }
+    let newXsUnflattened = Js.Array.mapi(
+      (x, i) => i < length - 2 ? getInBetween(x, t.xs[i + 1]) : [x],
+      t.xs,
+    )
+    let newXs = Belt.Array.concatMany(newXsUnflattened)
+    let newYs = E.A.fmap(x => XtoY.linear(x, t), newXs)
+    {xs: newXs, ys: newYs}
+  }
   // This function is used for klDivergence
   let combineAlongSupportOfSecondArgument: (
     (float, float) => result<float, Operation.Error.t>,

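As I read the dead code above (the comment itself marks it as unused): for each consecutive pair (x1, x2) in `t.xs`, `getInBetween` emits `points - 1` evenly spaced x-values

    x1 + (k / points) * (x2 - x1),   k = 0, ..., points - 2

deliberately stopping short of x2, which the next pair supplies; the final two xs are passed through unchanged, and the new ys are re-read from the original shape by linear interpolation (`XtoY.linear`), so the result traces the same piecewise-linear graph on a denser grid.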