Cleaned up Scoring file: no dispatch yet

Value: [1e-4 to 6e-2]
This commit is contained in:
Quinn Dougherty 2022-05-23 10:46:25 -04:00
parent 7b865c95f5
commit bafcb4f7b8

View File

@ -1,46 +1,118 @@
module KLDivergence = {
let logFn = Js.Math.log // base e
let integrand = (predictionElement: float, answerElement: float): result<
type t = PointSetDist.pointSetDist
type continuousShape = PointSetTypes.continuousShape
type discreteShape = PointSetTypes.discreteShape
type mixedShape = PointSetTypes.mixedShape
type scalar = float
type abstractScoreArgs<'a, 'b> = {estimate: 'a, answer: 'b, prior: option<'a>}
type scoreArgs =
| DistEstimateDistAnswer(abstractScoreArgs<t, t>)
| DistEstimateScalarAnswer(abstractScoreArgs<t, scalar>)
| ScalarEstimateDistAnswer(abstractScoreArgs<scalar, t>)
| ScalarEstimateScalarAnswer(abstractScoreArgs<scalar, scalar>)
let logFn = Js.Math.log // base e
let minusScaledLogOfQuot = (~esti, ~answ): result<float, Operation.Error.t> => {
let quot = esti /. answ
quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answ *. logFn(quot))
}
module WithDistAnswer = {
// The Kullback-Leibler divergence
let integrand = (estimateElement: float, answerElement: float): result<
float,
Operation.Error.t,
> =>
// We decided that negative infinity, not an error at answerElement = 0.0, is a desirable value.
if answerElement == 0.0 {
Ok(0.0)
} else if predictionElement == 0.0 {
} else if estimateElement == 0.0 {
Ok(infinity)
} else {
let quot = predictionElement /. answerElement
quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answerElement *. logFn(quot))
minusScaledLogOfQuot(~esti=estimateElement, ~answ=answerElement)
}
let sum = (~estimate: t, ~answer: t, ~integrateFn) =>
PointSetDist.combinePointwise(integrand, estimate, answer)->E.R2.fmap(integrateFn)
let sumWithPrior = (~estimate: t, ~answer: t, ~prior: t, ~integrateFn): result<
float,
Operation.Error.t,
> => {
let kl1 = sum(~estimate, ~answer, ~integrateFn)
let kl2 = sum(~estimate=prior, ~answer, ~integrateFn)
E.R.merge(kl1, kl2)->E.R2.fmap(((k1', k2')) => kl1' -. kl2')
}
}
module WithScalarAnswer = {
let score' = (~estimatePdf: float => float, ~answer: float): result<float, Operation.Error.t> => {
let density = answer->estimatePdf
if density < 0.0 {
Operation.PdfInvalidError->Error
} else if density == 0.0 {
infinity->Ok
} else {
density->logFn->(x => -.x)->Ok
}
}
let scoreWithPrior' = (
~estimatePdf: float => float,
~answer: float,
~priorPdf: float => float,
): result<float, Operation.Error.t> => {
let numerator = answer->estimatePdf
let priorDensityOfAnswer = answer->priorPdf
if numerator < 0.0 || priorDensityOfAnswer < 0.0 {
Operation.PdfInvalidError->Error
} else if numerator == 0.0 || priorDensityOfAnswer == 0.0 {
infinity->Ok
} else {
minusScaledLogOfQuot(~esti=numerator, ~answ=priorDensityOfAnswer)
}
}
let score = (~estimate: t, ~answer: t): result<float, Operation.Error.t> => {
let estimatePdf = x => XYShape.XtoY.linear(x, estimate.xyShape)
score'(~estimatePdf, ~answer)
}
let scoreWithPrior = (~estimate: t, ~answer: t, ~prior: t): result<float, Operation.Error.t> => {
let estimatePdf = x => XYShape.XtoY.linear(x, estimate.xyShape)
let priorPdf = x => XYShape.XtoY.linear(x, prior.xyShape)
scoreWithPrior'(~estimatePdf, ~answer, ~priorPdf)
}
}
module TwoScalars = {
let score = (~estimate: float, ~answer: float) =>
if answer == 0.0 {
0.0->Ok
} else if estimate == 0.0 {
infinity->Ok
} else {
minusScaledLogOfQuot(~esti=estimate, ~answ=answer)
}
let scoreWithPrior = (~estimate: float, ~answer: float, ~prior: float) =>
if answer == 0.0 {
0.0->Ok
} else if estimate == 0.0 || prior == 0.0 {
infinity->Ok
} else {
minusScaledLogOfQuot(~esti=estimate /. prior, ~answ=answer)
}
}
module LogScoreWithPointResolution = {
let logFn = Js.Math.log
let score = (
~priorPdf: option<float => float>,
~predictionPdf: float => float,
~answer: float,
): result<float, Operation.Error.t> => {
let numerator = answer->predictionPdf
if numerator < 0.0 {
Operation.PdfInvalidError->Error
} else if numerator == 0.0 {
infinity->Ok
} else {
-.(
switch priorPdf {
| None => numerator->logFn
| Some(f) => {
let priorDensityOfAnswer = f(answer)
if priorDensityOfAnswer == 0.0 {
neg_infinity
} else {
(numerator /. priorDensityOfAnswer)->logFn
}
}
}
)->Ok
}
let logScore = (args: scoreArgs, ~integrateFn): result<float, Operation.Error.t> =>
switch args {
| DistEstimateDistAnswer({estimate, answer, prior: None}) =>
WithDistAnswer.sum(~estimate, ~answer, ~integrateFn)
| DistEstimateDistAnswer({estimate, answer, prior: Some(prior)}) =>
WithDistAnswer.sumWithPrior(~estimate, ~answer, ~prior, ~integrateFn)
| DistEstimateScalarAnswer({estimate, answer, prior: None}) =>
WithScalarAnswer.score(~estimate, ~answer)
| DistEstimateScalarAnswer({estimate, answer, prior: Some(prior)}) =>
WithScalarAnswer.scoreWithPrior(~estimate, ~answer, ~prior)
| ScalarEstimateDistAnswer(_) => Operation.NotYetImplemented->Error
| ScalarEstimateScalarAnswer({estimate, answer, prior: None}) =>
TwoScalars.score(~estimate, ~answer)
| ScalarEstimateScalarAnswer({estimate, answer, prior: Some(prior)}) =>
TwoScalars.scoreWithPrior(~estimate, ~answer, ~prior)
}
}