Merge pull request #501 from quantified-uncertainty/kldivergence-discrete
`klDivergence` on discrete distributions
This commit is contained in:
commit
396bf5bf00
|
@ -12,3 +12,13 @@ let triangularDist: DistributionTypes.genericDist = Symbolic(
|
||||||
let exponentialDist: DistributionTypes.genericDist = Symbolic(#Exponential({rate: 2.0}))
|
let exponentialDist: DistributionTypes.genericDist = Symbolic(#Exponential({rate: 2.0}))
|
||||||
let uniformDist: DistributionTypes.genericDist = Symbolic(#Uniform({low: 9.0, high: 10.0}))
|
let uniformDist: DistributionTypes.genericDist = Symbolic(#Uniform({low: 9.0, high: 10.0}))
|
||||||
let floatDist: DistributionTypes.genericDist = Symbolic(#Float(1e1))
|
let floatDist: DistributionTypes.genericDist = Symbolic(#Float(1e1))
|
||||||
|
|
||||||
|
exception KlFailed
|
||||||
|
exception MixtureFailed
|
||||||
|
let float1 = 1.0
|
||||||
|
let float2 = 2.0
|
||||||
|
let float3 = 3.0
|
||||||
|
let {mkDelta} = module(TestHelpers)
|
||||||
|
let point1 = mkDelta(float1)
|
||||||
|
let point2 = mkDelta(float2)
|
||||||
|
let point3 = mkDelta(float3)
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
open Jest
|
open Jest
|
||||||
open Expect
|
open Expect
|
||||||
open TestHelpers
|
open TestHelpers
|
||||||
|
open GenericDist_Fixtures
|
||||||
|
|
||||||
describe("kl divergence", () => {
|
describe("klDivergence: continuous -> continuous -> float", () => {
|
||||||
let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
|
let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
|
||||||
exception KlFailed
|
exception KlFailed
|
||||||
|
|
||||||
|
@ -19,7 +20,7 @@ describe("kl divergence", () => {
|
||||||
let analyticalKl = Js.Math.log((highPrediction -. lowPrediction) /. (highAnswer -. lowAnswer))
|
let analyticalKl = Js.Math.log((highPrediction -. lowPrediction) /. (highAnswer -. lowAnswer))
|
||||||
let kl = E.R.liftJoin2(klDivergence, prediction, answer)
|
let kl = E.R.liftJoin2(klDivergence, prediction, answer)
|
||||||
switch kl {
|
switch kl {
|
||||||
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
|
| Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=7)
|
||||||
| Error(err) => {
|
| Error(err) => {
|
||||||
Js.Console.log(DistributionTypes.Error.toString(err))
|
Js.Console.log(DistributionTypes.Error.toString(err))
|
||||||
raise(KlFailed)
|
raise(KlFailed)
|
||||||
|
@ -51,7 +52,7 @@ describe("kl divergence", () => {
|
||||||
let kl = E.R.liftJoin2(klDivergence, prediction, answer)
|
let kl = E.R.liftJoin2(klDivergence, prediction, answer)
|
||||||
|
|
||||||
switch kl {
|
switch kl {
|
||||||
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
|
| Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=3)
|
||||||
| Error(err) => {
|
| Error(err) => {
|
||||||
Js.Console.log(DistributionTypes.Error.toString(err))
|
Js.Console.log(DistributionTypes.Error.toString(err))
|
||||||
raise(KlFailed)
|
raise(KlFailed)
|
||||||
|
@ -60,9 +61,44 @@ describe("kl divergence", () => {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
describe("combine along support test", () => {
|
describe("klDivergence: discrete -> discrete -> float", () => {
|
||||||
|
let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
|
||||||
|
let mixture = a => DistributionTypes.DistributionOperation.Mixture(a)
|
||||||
|
let a' = [(point1, 1e0), (point2, 1e0)]->mixture->run
|
||||||
|
let b' = [(point1, 1e0), (point2, 1e0), (point3, 1e0)]->mixture->run
|
||||||
|
let (a, b) = switch (a', b') {
|
||||||
|
| (Dist(a''), Dist(b'')) => (a'', b'')
|
||||||
|
| _ => raise(MixtureFailed)
|
||||||
|
}
|
||||||
|
test("agrees with analytical answer when finite", () => {
|
||||||
|
let prediction = b
|
||||||
|
let answer = a
|
||||||
|
let kl = klDivergence(prediction, answer)
|
||||||
|
// Sigma_{i \in 1..2} 0.5 * log(0.5 / 0.33333)
|
||||||
|
let analyticalKl = Js.Math.log(3.0 /. 2.0)
|
||||||
|
switch kl {
|
||||||
|
| Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=7)
|
||||||
|
| Error(err) =>
|
||||||
|
Js.Console.log(DistributionTypes.Error.toString(err))
|
||||||
|
raise(KlFailed)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
test("returns infinity when infinite", () => {
|
||||||
|
let prediction = a
|
||||||
|
let answer = b
|
||||||
|
let kl = klDivergence(prediction, answer)
|
||||||
|
switch kl {
|
||||||
|
| Ok(kl') => kl'->expect->toEqual(infinity)
|
||||||
|
| Error(err) =>
|
||||||
|
Js.Console.log(DistributionTypes.Error.toString(err))
|
||||||
|
raise(KlFailed)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe("combineAlongSupportOfSecondArgument0", () => {
|
||||||
// This tests the version of the function that we're NOT using. Haven't deleted the test in case we use the code later.
|
// This tests the version of the function that we're NOT using. Haven't deleted the test in case we use the code later.
|
||||||
test("combine along support test", _ => {
|
test("test on two uniforms", _ => {
|
||||||
let combineAlongSupportOfSecondArgument = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument0
|
let combineAlongSupportOfSecondArgument = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument0
|
||||||
let lowAnswer = 0.0
|
let lowAnswer = 0.0
|
||||||
let highAnswer = 1.0
|
let highAnswer = 1.0
|
||||||
|
@ -97,6 +133,7 @@ describe("combine along support test", () => {
|
||||||
2.0 *. MagicNumbers.Epsilon.ten,
|
2.0 *. MagicNumbers.Epsilon.ten,
|
||||||
1.0 -. MagicNumbers.Epsilon.ten,
|
1.0 -. MagicNumbers.Epsilon.ten,
|
||||||
1.0,
|
1.0,
|
||||||
|
1.0 +. MagicNumbers.Epsilon.ten,
|
||||||
],
|
],
|
||||||
ys: [
|
ys: [
|
||||||
-0.34657359027997264,
|
-0.34657359027997264,
|
||||||
|
@ -104,6 +141,7 @@ describe("combine along support test", () => {
|
||||||
-0.34657359027997264,
|
-0.34657359027997264,
|
||||||
-0.34657359027997264,
|
-0.34657359027997264,
|
||||||
-0.34657359027997264,
|
-0.34657359027997264,
|
||||||
|
infinity,
|
||||||
],
|
],
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
|
|
|
@ -51,6 +51,7 @@ let mkExponential = rate => DistributionTypes.Symbolic(#Exponential({rate: rate}
|
||||||
let mkUniform = (low, high) => DistributionTypes.Symbolic(#Uniform({low: low, high: high}))
|
let mkUniform = (low, high) => DistributionTypes.Symbolic(#Uniform({low: low, high: high}))
|
||||||
let mkCauchy = (local, scale) => DistributionTypes.Symbolic(#Cauchy({local: local, scale: scale}))
|
let mkCauchy = (local, scale) => DistributionTypes.Symbolic(#Cauchy({local: local, scale: scale}))
|
||||||
let mkLognormal = (mu, sigma) => DistributionTypes.Symbolic(#Lognormal({mu: mu, sigma: sigma}))
|
let mkLognormal = (mu, sigma) => DistributionTypes.Symbolic(#Lognormal({mu: mu, sigma: sigma}))
|
||||||
|
let mkDelta = x => DistributionTypes.Symbolic(#Float(x))
|
||||||
|
|
||||||
let normalMake = SymbolicDist.Normal.make
|
let normalMake = SymbolicDist.Normal.make
|
||||||
let betaMake = SymbolicDist.Beta.make
|
let betaMake = SymbolicDist.Beta.make
|
||||||
|
|
|
@ -48,12 +48,7 @@ let combinePointwise = (
|
||||||
// TODO: does it ever make sense to pointwise combine the integrals here?
|
// TODO: does it ever make sense to pointwise combine the integrals here?
|
||||||
// It could be done for pointwise additions, but is that ever needed?
|
// It could be done for pointwise additions, but is that ever needed?
|
||||||
|
|
||||||
make(
|
combiner(fn, XYShape.XtoY.discreteInterpolator, t1.xyShape, t2.xyShape)->E.R2.fmap(make)
|
||||||
combiner(fn, XYShape.XtoY.discreteInterpolator, t1.xyShape, t2.xyShape)->E.R.toExn(
|
|
||||||
"Addition operation should never fail",
|
|
||||||
_,
|
|
||||||
),
|
|
||||||
)->Ok
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let reduce = (
|
let reduce = (
|
||||||
|
@ -163,7 +158,6 @@ module T = Dist({
|
||||||
}
|
}
|
||||||
|
|
||||||
let integralEndY = (t: t) => t.integralSumCache |> E.O.default(t |> integral |> Continuous.lastY)
|
let integralEndY = (t: t) => t.integralSumCache |> E.O.default(t |> integral |> Continuous.lastY)
|
||||||
let integralEndYResult = (t: t) => t->integralEndY->Ok
|
|
||||||
let minX = shapeFn(XYShape.T.minX)
|
let minX = shapeFn(XYShape.T.minX)
|
||||||
let maxX = shapeFn(XYShape.T.maxX)
|
let maxX = shapeFn(XYShape.T.maxX)
|
||||||
let toDiscreteProbabilityMassFraction = _ => 1.0
|
let toDiscreteProbabilityMassFraction = _ => 1.0
|
||||||
|
@ -230,10 +224,9 @@ module T = Dist({
|
||||||
|
|
||||||
let klDivergence = (prediction: t, answer: t) => {
|
let klDivergence = (prediction: t, answer: t) => {
|
||||||
combinePointwise(
|
combinePointwise(
|
||||||
~combiner=XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument0,
|
|
||||||
~fn=PointSetDist_Scoring.KLDivergence.integrand,
|
~fn=PointSetDist_Scoring.KLDivergence.integrand,
|
||||||
prediction,
|
prediction,
|
||||||
answer,
|
answer,
|
||||||
) |> E.R2.bind(integralEndYResult)
|
)->E.R2.fmap(integralEndY)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -302,9 +302,10 @@ module T = Dist({
|
||||||
}
|
}
|
||||||
|
|
||||||
let klDivergence = (prediction: t, answer: t) => {
|
let klDivergence = (prediction: t, answer: t) => {
|
||||||
combinePointwise(PointSetDist_Scoring.KLDivergence.integrand, prediction, answer) |> E.R.fmap(
|
Error(Operation.NotYetImplemented)
|
||||||
integralEndY,
|
// combinePointwise(PointSetDist_Scoring.KLDivergence.integrand, prediction, answer) |> E.R.fmap(
|
||||||
)
|
// integralEndY,
|
||||||
|
// )
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -199,6 +199,7 @@ module T = Dist({
|
||||||
let klDivergence = (t1: t, t2: t) =>
|
let klDivergence = (t1: t, t2: t) =>
|
||||||
switch (t1, t2) {
|
switch (t1, t2) {
|
||||||
| (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2)
|
| (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2)
|
||||||
|
| (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2)
|
||||||
| _ => Error(NotYetImplemented)
|
| _ => Error(NotYetImplemented)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -4,10 +4,11 @@ module KLDivergence = {
|
||||||
float,
|
float,
|
||||||
Operation.Error.t,
|
Operation.Error.t,
|
||||||
> =>
|
> =>
|
||||||
|
// We decided that negative infinity, not an error at answerElement = 0.0, is a desirable value.
|
||||||
if answerElement == 0.0 {
|
if answerElement == 0.0 {
|
||||||
Ok(0.0)
|
Ok(0.0)
|
||||||
} else if predictionElement == 0.0 {
|
} else if predictionElement == 0.0 {
|
||||||
Error(Operation.NegativeInfinityError)
|
Ok(infinity)
|
||||||
} else {
|
} else {
|
||||||
let quot = predictionElement /. answerElement
|
let quot = predictionElement /. answerElement
|
||||||
quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answerElement *. logFn(quot))
|
quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answerElement *. logFn(quot))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user