Cleaned up Scoring
file: no dispatch yet
Value: [1e-4 to 6e-2]
This commit is contained in:
parent
7b865c95f5
commit
bafcb4f7b8
|
@ -1,46 +1,118 @@
|
|||
module KLDivergence = {
|
||||
let logFn = Js.Math.log // base e
|
||||
let integrand = (predictionElement: float, answerElement: float): result<
|
||||
type t = PointSetDist.pointSetDist
|
||||
type continuousShape = PointSetTypes.continuousShape
|
||||
type discreteShape = PointSetTypes.discreteShape
|
||||
type mixedShape = PointSetTypes.mixedShape
|
||||
type scalar = float
|
||||
type abstractScoreArgs<'a, 'b> = {estimate: 'a, answer: 'b, prior: option<'a>}
|
||||
type scoreArgs =
|
||||
| DistEstimateDistAnswer(abstractScoreArgs<t, t>)
|
||||
| DistEstimateScalarAnswer(abstractScoreArgs<t, scalar>)
|
||||
| ScalarEstimateDistAnswer(abstractScoreArgs<scalar, t>)
|
||||
| ScalarEstimateScalarAnswer(abstractScoreArgs<scalar, scalar>)
|
||||
let logFn = Js.Math.log // base e
|
||||
let minusScaledLogOfQuot = (~esti, ~answ): result<float, Operation.Error.t> => {
|
||||
let quot = esti /. answ
|
||||
quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answ *. logFn(quot))
|
||||
}
|
||||
|
||||
module WithDistAnswer = {
|
||||
// The Kullback-Leibler divergence
|
||||
let integrand = (estimateElement: float, answerElement: float): result<
|
||||
float,
|
||||
Operation.Error.t,
|
||||
> =>
|
||||
// We decided that negative infinity, not an error at answerElement = 0.0, is a desirable value.
|
||||
if answerElement == 0.0 {
|
||||
Ok(0.0)
|
||||
} else if predictionElement == 0.0 {
|
||||
} else if estimateElement == 0.0 {
|
||||
Ok(infinity)
|
||||
} else {
|
||||
let quot = predictionElement /. answerElement
|
||||
quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answerElement *. logFn(quot))
|
||||
minusScaledLogOfQuot(~esti=estimateElement, ~answ=answerElement)
|
||||
}
|
||||
|
||||
let sum = (~estimate: t, ~answer: t, ~integrateFn) =>
|
||||
PointSetDist.combinePointwise(integrand, estimate, answer)->E.R2.fmap(integrateFn)
|
||||
|
||||
let sumWithPrior = (~estimate: t, ~answer: t, ~prior: t, ~integrateFn): result<
|
||||
float,
|
||||
Operation.Error.t,
|
||||
> => {
|
||||
let kl1 = sum(~estimate, ~answer, ~integrateFn)
|
||||
let kl2 = sum(~estimate=prior, ~answer, ~integrateFn)
|
||||
E.R.merge(kl1, kl2)->E.R2.fmap(((k1', k2')) => kl1' -. kl2')
|
||||
}
|
||||
}
|
||||
|
||||
module WithScalarAnswer = {
|
||||
let score' = (~estimatePdf: float => float, ~answer: float): result<float, Operation.Error.t> => {
|
||||
let density = answer->estimatePdf
|
||||
if density < 0.0 {
|
||||
Operation.PdfInvalidError->Error
|
||||
} else if density == 0.0 {
|
||||
infinity->Ok
|
||||
} else {
|
||||
density->logFn->(x => -.x)->Ok
|
||||
}
|
||||
}
|
||||
let scoreWithPrior' = (
|
||||
~estimatePdf: float => float,
|
||||
~answer: float,
|
||||
~priorPdf: float => float,
|
||||
): result<float, Operation.Error.t> => {
|
||||
let numerator = answer->estimatePdf
|
||||
let priorDensityOfAnswer = answer->priorPdf
|
||||
if numerator < 0.0 || priorDensityOfAnswer < 0.0 {
|
||||
Operation.PdfInvalidError->Error
|
||||
} else if numerator == 0.0 || priorDensityOfAnswer == 0.0 {
|
||||
infinity->Ok
|
||||
} else {
|
||||
minusScaledLogOfQuot(~esti=numerator, ~answ=priorDensityOfAnswer)
|
||||
}
|
||||
}
|
||||
let score = (~estimate: t, ~answer: t): result<float, Operation.Error.t> => {
|
||||
let estimatePdf = x => XYShape.XtoY.linear(x, estimate.xyShape)
|
||||
score'(~estimatePdf, ~answer)
|
||||
}
|
||||
let scoreWithPrior = (~estimate: t, ~answer: t, ~prior: t): result<float, Operation.Error.t> => {
|
||||
let estimatePdf = x => XYShape.XtoY.linear(x, estimate.xyShape)
|
||||
let priorPdf = x => XYShape.XtoY.linear(x, prior.xyShape)
|
||||
scoreWithPrior'(~estimatePdf, ~answer, ~priorPdf)
|
||||
}
|
||||
}
|
||||
|
||||
module TwoScalars = {
|
||||
let score = (~estimate: float, ~answer: float) =>
|
||||
if answer == 0.0 {
|
||||
0.0->Ok
|
||||
} else if estimate == 0.0 {
|
||||
infinity->Ok
|
||||
} else {
|
||||
minusScaledLogOfQuot(~esti=estimate, ~answ=answer)
|
||||
}
|
||||
|
||||
let scoreWithPrior = (~estimate: float, ~answer: float, ~prior: float) =>
|
||||
if answer == 0.0 {
|
||||
0.0->Ok
|
||||
} else if estimate == 0.0 || prior == 0.0 {
|
||||
infinity->Ok
|
||||
} else {
|
||||
minusScaledLogOfQuot(~esti=estimate /. prior, ~answ=answer)
|
||||
}
|
||||
}
|
||||
|
||||
module LogScoreWithPointResolution = {
|
||||
let logFn = Js.Math.log
|
||||
let score = (
|
||||
~priorPdf: option<float => float>,
|
||||
~predictionPdf: float => float,
|
||||
~answer: float,
|
||||
): result<float, Operation.Error.t> => {
|
||||
let numerator = answer->predictionPdf
|
||||
if numerator < 0.0 {
|
||||
Operation.PdfInvalidError->Error
|
||||
} else if numerator == 0.0 {
|
||||
infinity->Ok
|
||||
} else {
|
||||
-.(
|
||||
switch priorPdf {
|
||||
| None => numerator->logFn
|
||||
| Some(f) => {
|
||||
let priorDensityOfAnswer = f(answer)
|
||||
if priorDensityOfAnswer == 0.0 {
|
||||
neg_infinity
|
||||
} else {
|
||||
(numerator /. priorDensityOfAnswer)->logFn
|
||||
}
|
||||
}
|
||||
}
|
||||
)->Ok
|
||||
}
|
||||
let logScore = (args: scoreArgs, ~integrateFn): result<float, Operation.Error.t> =>
|
||||
switch args {
|
||||
| DistEstimateDistAnswer({estimate, answer, prior: None}) =>
|
||||
WithDistAnswer.sum(~estimate, ~answer, ~integrateFn)
|
||||
| DistEstimateDistAnswer({estimate, answer, prior: Some(prior)}) =>
|
||||
WithDistAnswer.sumWithPrior(~estimate, ~answer, ~prior, ~integrateFn)
|
||||
| DistEstimateScalarAnswer({estimate, answer, prior: None}) =>
|
||||
WithScalarAnswer.score(~estimate, ~answer)
|
||||
| DistEstimateScalarAnswer({estimate, answer, prior: Some(prior)}) =>
|
||||
WithScalarAnswer.scoreWithPrior(~estimate, ~answer, ~prior)
|
||||
| ScalarEstimateDistAnswer(_) => Operation.NotYetImplemented->Error
|
||||
| ScalarEstimateScalarAnswer({estimate, answer, prior: None}) =>
|
||||
TwoScalars.score(~estimate, ~answer)
|
||||
| ScalarEstimateScalarAnswer({estimate, answer, prior: Some(prior)}) =>
|
||||
TwoScalars.scoreWithPrior(~estimate, ~answer, ~prior)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user