good evening, not working yet, but out of time for the night

Value: [1e-6 to 1e-4]
This commit is contained in:
Quinn Dougherty 2022-05-09 19:17:27 -04:00
parent b2d80eef86
commit ccd55ef8f1
4 changed files with 26 additions and 10 deletions

View File

@ -19,7 +19,7 @@ describe("kl divergence on continuous distributions", () => {
let analyticalKl = Js.Math.log((highPrediction -. lowPrediction) /. (highAnswer -. lowAnswer)) let analyticalKl = Js.Math.log((highPrediction -. lowPrediction) /. (highAnswer -. lowAnswer))
let kl = E.R.liftJoin2(klDivergence, prediction, answer) let kl = E.R.liftJoin2(klDivergence, prediction, answer)
switch kl { switch kl {
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl) | Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=7)
| Error(err) => { | Error(err) => {
Js.Console.log(DistributionTypes.Error.toString(err)) Js.Console.log(DistributionTypes.Error.toString(err))
raise(KlFailed) raise(KlFailed)
@ -51,7 +51,7 @@ describe("kl divergence on continuous distributions", () => {
let kl = E.R.liftJoin2(klDivergence, prediction, answer) let kl = E.R.liftJoin2(klDivergence, prediction, answer)
switch kl { switch kl {
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl) | Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=3)
| Error(err) => { | Error(err) => {
Js.Console.log(DistributionTypes.Error.toString(err)) Js.Console.log(DistributionTypes.Error.toString(err))
raise(KlFailed) raise(KlFailed)
@ -78,9 +78,9 @@ describe("kl divergence on discrete distributions", () => {
| (Dist(prediction'), Dist(answer')) => klDivergence(prediction', answer') | (Dist(prediction'), Dist(answer')) => klDivergence(prediction', answer')
| _ => raise(MixtureFailed) | _ => raise(MixtureFailed)
} }
let analyticalKl = Js.Math.log(2.0 /. 3.0) let analyticalKl = Js.Math.log(3.0 /. 2.0)
switch kl { switch kl {
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl) | Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=7)
| Error(err) => | Error(err) =>
Js.Console.log(DistributionTypes.Error.toString(err)) Js.Console.log(DistributionTypes.Error.toString(err))
raise(KlFailed) raise(KlFailed)

View File

@ -51,6 +51,7 @@ let mkExponential = rate => DistributionTypes.Symbolic(#Exponential({rate: rate}
let mkUniform = (low, high) => DistributionTypes.Symbolic(#Uniform({low: low, high: high})) let mkUniform = (low, high) => DistributionTypes.Symbolic(#Uniform({low: low, high: high}))
let mkCauchy = (local, scale) => DistributionTypes.Symbolic(#Cauchy({local: local, scale: scale})) let mkCauchy = (local, scale) => DistributionTypes.Symbolic(#Cauchy({local: local, scale: scale}))
let mkLognormal = (mu, sigma) => DistributionTypes.Symbolic(#Lognormal({mu: mu, sigma: sigma})) let mkLognormal = (mu, sigma) => DistributionTypes.Symbolic(#Lognormal({mu: mu, sigma: sigma}))
let mkDirac = x => DistributionTypes.Symbolic(#Float(x))
let normalMake = SymbolicDist.Normal.make let normalMake = SymbolicDist.Normal.make
let betaMake = SymbolicDist.Beta.make let betaMake = SymbolicDist.Beta.make

View File

@ -229,11 +229,25 @@ module T = Dist({
} }
let klDivergence = (prediction: t, answer: t) => { let klDivergence = (prediction: t, answer: t) => {
combinePointwise( let massOrZero = (t: t, x: float): float => {
~combiner=XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument0, let i = E.A.findIndex(x' => x' == x, t.xyShape.xs)
~fn=PointSetDist_Scoring.KLDivergence.integrand, switch i {
prediction, | None => 0.0
answer, | Some(i') => t.xyShape.ys[i']
) |> E.R2.bind(integralEndYResult) }
}
let predictionNewYs = E.A.fmap(massOrZero(answer), prediction.xyShape.xs)
let integrand = XYShape.PointwiseCombination.combine(
PointSetDist_Scoring.KLDivergence.integrand,
XYShape.XtoY.continuousInterpolator(#Stepwise, #UseZero),
{XYShape.xs: answer.xyShape.xs, XYShape.ys: predictionNewYs},
answer.xyShape,
)
let xyShapeToDiscrete: XYShape.xyShape => t = xyShape => {
xyShape: xyShape,
integralSumCache: None,
integralCache: None,
}
integrand->E.R2.fmap(x => x->xyShapeToDiscrete->integralEndY)
} }
}) })

View File

@ -199,6 +199,7 @@ module T = Dist({
let klDivergence = (t1: t, t2: t) => let klDivergence = (t1: t, t2: t) =>
switch (t1, t2) { switch (t1, t2) {
| (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2) | (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2)
| (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2)
| _ => Error(NotYetImplemented) | _ => Error(NotYetImplemented)
} }
}) })