Cleaned up Scoring
file: no dispatch yet
Value: [1e-4 to 6e-2]
This commit is contained in:
parent
7b865c95f5
commit
bafcb4f7b8
|
@ -1,46 +1,118 @@
|
||||||
module KLDivergence = {
|
type t = PointSetDist.pointSetDist
|
||||||
let logFn = Js.Math.log // base e
|
type continuousShape = PointSetTypes.continuousShape
|
||||||
let integrand = (predictionElement: float, answerElement: float): result<
|
type discreteShape = PointSetTypes.discreteShape
|
||||||
|
type mixedShape = PointSetTypes.mixedShape
|
||||||
|
type scalar = float
|
||||||
|
type abstractScoreArgs<'a, 'b> = {estimate: 'a, answer: 'b, prior: option<'a>}
|
||||||
|
type scoreArgs =
|
||||||
|
| DistEstimateDistAnswer(abstractScoreArgs<t, t>)
|
||||||
|
| DistEstimateScalarAnswer(abstractScoreArgs<t, scalar>)
|
||||||
|
| ScalarEstimateDistAnswer(abstractScoreArgs<scalar, t>)
|
||||||
|
| ScalarEstimateScalarAnswer(abstractScoreArgs<scalar, scalar>)
|
||||||
|
let logFn = Js.Math.log // base e
|
||||||
|
let minusScaledLogOfQuot = (~esti, ~answ): result<float, Operation.Error.t> => {
|
||||||
|
let quot = esti /. answ
|
||||||
|
quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answ *. logFn(quot))
|
||||||
|
}
|
||||||
|
|
||||||
|
module WithDistAnswer = {
|
||||||
|
// The Kullback-Leibler divergence
|
||||||
|
let integrand = (estimateElement: float, answerElement: float): result<
|
||||||
float,
|
float,
|
||||||
Operation.Error.t,
|
Operation.Error.t,
|
||||||
> =>
|
> =>
|
||||||
// We decided that negative infinity, not an error at answerElement = 0.0, is a desirable value.
|
// We decided that negative infinity, not an error at answerElement = 0.0, is a desirable value.
|
||||||
if answerElement == 0.0 {
|
if answerElement == 0.0 {
|
||||||
Ok(0.0)
|
Ok(0.0)
|
||||||
} else if predictionElement == 0.0 {
|
} else if estimateElement == 0.0 {
|
||||||
Ok(infinity)
|
Ok(infinity)
|
||||||
} else {
|
} else {
|
||||||
let quot = predictionElement /. answerElement
|
minusScaledLogOfQuot(~esti=estimateElement, ~answ=answerElement)
|
||||||
quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answerElement *. logFn(quot))
|
}
|
||||||
|
|
||||||
|
let sum = (~estimate: t, ~answer: t, ~integrateFn) =>
|
||||||
|
PointSetDist.combinePointwise(integrand, estimate, answer)->E.R2.fmap(integrateFn)
|
||||||
|
|
||||||
|
let sumWithPrior = (~estimate: t, ~answer: t, ~prior: t, ~integrateFn): result<
|
||||||
|
float,
|
||||||
|
Operation.Error.t,
|
||||||
|
> => {
|
||||||
|
let kl1 = sum(~estimate, ~answer, ~integrateFn)
|
||||||
|
let kl2 = sum(~estimate=prior, ~answer, ~integrateFn)
|
||||||
|
E.R.merge(kl1, kl2)->E.R2.fmap(((k1', k2')) => kl1' -. kl2')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
module WithScalarAnswer = {
|
||||||
|
let score' = (~estimatePdf: float => float, ~answer: float): result<float, Operation.Error.t> => {
|
||||||
|
let density = answer->estimatePdf
|
||||||
|
if density < 0.0 {
|
||||||
|
Operation.PdfInvalidError->Error
|
||||||
|
} else if density == 0.0 {
|
||||||
|
infinity->Ok
|
||||||
|
} else {
|
||||||
|
density->logFn->(x => -.x)->Ok
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let scoreWithPrior' = (
|
||||||
|
~estimatePdf: float => float,
|
||||||
|
~answer: float,
|
||||||
|
~priorPdf: float => float,
|
||||||
|
): result<float, Operation.Error.t> => {
|
||||||
|
let numerator = answer->estimatePdf
|
||||||
|
let priorDensityOfAnswer = answer->priorPdf
|
||||||
|
if numerator < 0.0 || priorDensityOfAnswer < 0.0 {
|
||||||
|
Operation.PdfInvalidError->Error
|
||||||
|
} else if numerator == 0.0 || priorDensityOfAnswer == 0.0 {
|
||||||
|
infinity->Ok
|
||||||
|
} else {
|
||||||
|
minusScaledLogOfQuot(~esti=numerator, ~answ=priorDensityOfAnswer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let score = (~estimate: t, ~answer: t): result<float, Operation.Error.t> => {
|
||||||
|
let estimatePdf = x => XYShape.XtoY.linear(x, estimate.xyShape)
|
||||||
|
score'(~estimatePdf, ~answer)
|
||||||
|
}
|
||||||
|
let scoreWithPrior = (~estimate: t, ~answer: t, ~prior: t): result<float, Operation.Error.t> => {
|
||||||
|
let estimatePdf = x => XYShape.XtoY.linear(x, estimate.xyShape)
|
||||||
|
let priorPdf = x => XYShape.XtoY.linear(x, prior.xyShape)
|
||||||
|
scoreWithPrior'(~estimatePdf, ~answer, ~priorPdf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
module TwoScalars = {
|
||||||
|
let score = (~estimate: float, ~answer: float) =>
|
||||||
|
if answer == 0.0 {
|
||||||
|
0.0->Ok
|
||||||
|
} else if estimate == 0.0 {
|
||||||
|
infinity->Ok
|
||||||
|
} else {
|
||||||
|
minusScaledLogOfQuot(~esti=estimate, ~answ=answer)
|
||||||
|
}
|
||||||
|
|
||||||
|
let scoreWithPrior = (~estimate: float, ~answer: float, ~prior: float) =>
|
||||||
|
if answer == 0.0 {
|
||||||
|
0.0->Ok
|
||||||
|
} else if estimate == 0.0 || prior == 0.0 {
|
||||||
|
infinity->Ok
|
||||||
|
} else {
|
||||||
|
minusScaledLogOfQuot(~esti=estimate /. prior, ~answ=answer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
module LogScoreWithPointResolution = {
|
let logScore = (args: scoreArgs, ~integrateFn): result<float, Operation.Error.t> =>
|
||||||
let logFn = Js.Math.log
|
switch args {
|
||||||
let score = (
|
| DistEstimateDistAnswer({estimate, answer, prior: None}) =>
|
||||||
~priorPdf: option<float => float>,
|
WithDistAnswer.sum(~estimate, ~answer, ~integrateFn)
|
||||||
~predictionPdf: float => float,
|
| DistEstimateDistAnswer({estimate, answer, prior: Some(prior)}) =>
|
||||||
~answer: float,
|
WithDistAnswer.sumWithPrior(~estimate, ~answer, ~prior, ~integrateFn)
|
||||||
): result<float, Operation.Error.t> => {
|
| DistEstimateScalarAnswer({estimate, answer, prior: None}) =>
|
||||||
let numerator = answer->predictionPdf
|
WithScalarAnswer.score(~estimate, ~answer)
|
||||||
if numerator < 0.0 {
|
| DistEstimateScalarAnswer({estimate, answer, prior: Some(prior)}) =>
|
||||||
Operation.PdfInvalidError->Error
|
WithScalarAnswer.scoreWithPrior(~estimate, ~answer, ~prior)
|
||||||
} else if numerator == 0.0 {
|
| ScalarEstimateDistAnswer(_) => Operation.NotYetImplemented->Error
|
||||||
infinity->Ok
|
| ScalarEstimateScalarAnswer({estimate, answer, prior: None}) =>
|
||||||
} else {
|
TwoScalars.score(~estimate, ~answer)
|
||||||
-.(
|
| ScalarEstimateScalarAnswer({estimate, answer, prior: Some(prior)}) =>
|
||||||
switch priorPdf {
|
TwoScalars.scoreWithPrior(~estimate, ~answer, ~prior)
|
||||||
| None => numerator->logFn
|
|
||||||
| Some(f) => {
|
|
||||||
let priorDensityOfAnswer = f(answer)
|
|
||||||
if priorDensityOfAnswer == 0.0 {
|
|
||||||
neg_infinity
|
|
||||||
} else {
|
|
||||||
(numerator /. priorDensityOfAnswer)->logFn
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)->Ok
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user