good evening, not working yet, but out of time for the night
Value: [1e-6 to 1e-4]
This commit is contained in:
parent
b2d80eef86
commit
ccd55ef8f1
|
@ -19,7 +19,7 @@ describe("kl divergence on continuous distributions", () => {
|
||||||
let analyticalKl = Js.Math.log((highPrediction -. lowPrediction) /. (highAnswer -. lowAnswer))
|
let analyticalKl = Js.Math.log((highPrediction -. lowPrediction) /. (highAnswer -. lowAnswer))
|
||||||
let kl = E.R.liftJoin2(klDivergence, prediction, answer)
|
let kl = E.R.liftJoin2(klDivergence, prediction, answer)
|
||||||
switch kl {
|
switch kl {
|
||||||
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
|
| Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=7)
|
||||||
| Error(err) => {
|
| Error(err) => {
|
||||||
Js.Console.log(DistributionTypes.Error.toString(err))
|
Js.Console.log(DistributionTypes.Error.toString(err))
|
||||||
raise(KlFailed)
|
raise(KlFailed)
|
||||||
|
@ -51,7 +51,7 @@ describe("kl divergence on continuous distributions", () => {
|
||||||
let kl = E.R.liftJoin2(klDivergence, prediction, answer)
|
let kl = E.R.liftJoin2(klDivergence, prediction, answer)
|
||||||
|
|
||||||
switch kl {
|
switch kl {
|
||||||
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
|
| Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=3)
|
||||||
| Error(err) => {
|
| Error(err) => {
|
||||||
Js.Console.log(DistributionTypes.Error.toString(err))
|
Js.Console.log(DistributionTypes.Error.toString(err))
|
||||||
raise(KlFailed)
|
raise(KlFailed)
|
||||||
|
@ -78,9 +78,9 @@ describe("kl divergence on discrete distributions", () => {
|
||||||
| (Dist(prediction'), Dist(answer')) => klDivergence(prediction', answer')
|
| (Dist(prediction'), Dist(answer')) => klDivergence(prediction', answer')
|
||||||
| _ => raise(MixtureFailed)
|
| _ => raise(MixtureFailed)
|
||||||
}
|
}
|
||||||
let analyticalKl = Js.Math.log(2.0 /. 3.0)
|
let analyticalKl = Js.Math.log(3.0 /. 2.0)
|
||||||
switch kl {
|
switch kl {
|
||||||
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
|
| Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=7)
|
||||||
| Error(err) =>
|
| Error(err) =>
|
||||||
Js.Console.log(DistributionTypes.Error.toString(err))
|
Js.Console.log(DistributionTypes.Error.toString(err))
|
||||||
raise(KlFailed)
|
raise(KlFailed)
|
||||||
|
|
|
@ -51,6 +51,7 @@ let mkExponential = rate => DistributionTypes.Symbolic(#Exponential({rate: rate}
|
||||||
let mkUniform = (low, high) => DistributionTypes.Symbolic(#Uniform({low: low, high: high}))
|
let mkUniform = (low, high) => DistributionTypes.Symbolic(#Uniform({low: low, high: high}))
|
||||||
let mkCauchy = (local, scale) => DistributionTypes.Symbolic(#Cauchy({local: local, scale: scale}))
|
let mkCauchy = (local, scale) => DistributionTypes.Symbolic(#Cauchy({local: local, scale: scale}))
|
||||||
let mkLognormal = (mu, sigma) => DistributionTypes.Symbolic(#Lognormal({mu: mu, sigma: sigma}))
|
let mkLognormal = (mu, sigma) => DistributionTypes.Symbolic(#Lognormal({mu: mu, sigma: sigma}))
|
||||||
|
let mkDirac = x => DistributionTypes.Symbolic(#Float(x))
|
||||||
|
|
||||||
let normalMake = SymbolicDist.Normal.make
|
let normalMake = SymbolicDist.Normal.make
|
||||||
let betaMake = SymbolicDist.Beta.make
|
let betaMake = SymbolicDist.Beta.make
|
||||||
|
|
|
@ -229,11 +229,25 @@ module T = Dist({
|
||||||
}
|
}
|
||||||
|
|
||||||
let klDivergence = (prediction: t, answer: t) => {
|
let klDivergence = (prediction: t, answer: t) => {
|
||||||
combinePointwise(
|
let massOrZero = (t: t, x: float): float => {
|
||||||
~combiner=XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument0,
|
let i = E.A.findIndex(x' => x' == x, t.xyShape.xs)
|
||||||
~fn=PointSetDist_Scoring.KLDivergence.integrand,
|
switch i {
|
||||||
prediction,
|
| None => 0.0
|
||||||
answer,
|
| Some(i') => t.xyShape.ys[i']
|
||||||
) |> E.R2.bind(integralEndYResult)
|
}
|
||||||
|
}
|
||||||
|
let predictionNewYs = E.A.fmap(massOrZero(answer), prediction.xyShape.xs)
|
||||||
|
let integrand = XYShape.PointwiseCombination.combine(
|
||||||
|
PointSetDist_Scoring.KLDivergence.integrand,
|
||||||
|
XYShape.XtoY.continuousInterpolator(#Stepwise, #UseZero),
|
||||||
|
{XYShape.xs: answer.xyShape.xs, XYShape.ys: predictionNewYs},
|
||||||
|
answer.xyShape,
|
||||||
|
)
|
||||||
|
let xyShapeToDiscrete: XYShape.xyShape => t = xyShape => {
|
||||||
|
xyShape: xyShape,
|
||||||
|
integralSumCache: None,
|
||||||
|
integralCache: None,
|
||||||
|
}
|
||||||
|
integrand->E.R2.fmap(x => x->xyShapeToDiscrete->integralEndY)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -199,6 +199,7 @@ module T = Dist({
|
||||||
let klDivergence = (t1: t, t2: t) =>
|
let klDivergence = (t1: t, t2: t) =>
|
||||||
switch (t1, t2) {
|
switch (t1, t2) {
|
||||||
| (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2)
|
| (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2)
|
||||||
|
| (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2)
|
||||||
| _ => Error(NotYetImplemented)
|
| _ => Error(NotYetImplemented)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
Loading…
Reference in New Issue
Block a user