Implemented correct math underlying logScoreWithPointResolution

Value: [1e-2 to 7e-1]

Realized that I need to switch the argument order, maybe putting `prior` last.
Quinn Dougherty 2022-05-13 15:43:59 -04:00
parent 78def2d3d2
commit b4a1137019
10 changed files with 71 additions and 30 deletions
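
For reference, the rule implemented by the new PointSetDist_Scoring.LogScoreWithPointResolution.score (final hunk below) is, with prediction density p, optional prior density q, and point answer a:

    S(p, a \mid q) = -\log\!\left(\frac{p(a)}{q(a)}\right), \qquad S(p, a) = -\log p(a) \quad \text{when no prior is given.}

Edge cases in the code: a negative p(a) yields ComplexNumberError, while p(a) = 0 or q(a) = 0 yields +infinity. A prediction that puts more density on the realized answer than the prior does therefore scores below zero.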

View File

@@ -149,7 +149,7 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
     ->E.R2.fmap(r => Float(r))
     ->OutputLocal.fromResult
   | ToScore(LogScore(prediction, answer)) =>
-    GenericDist.Score.logScore(dist, prediction, answer, ~toPointSetFn)
+    GenericDist.Score.logScoreWithPointResolution(Some(dist), prediction, answer, ~toPointSetFn)
     ->E.R2.fmap(r => Float(r))
     ->OutputLocal.fromResult
   | ToBool(IsNormalized) => dist->GenericDist.isNormalized->Bool
@@ -267,7 +267,7 @@ module Constructors = {
   let isNormalized = (~env, dist) => C.isNormalized(dist)->run(~env)->toBoolR
   let klDivergence = (~env, dist1, dist2) => C.klDivergence(dist1, dist2)->run(~env)->toFloatR
   let logScore = (~env, prior, prediction, answer) =>
-    C.logScore(prior, prediction, answer)->run(~env)->toFloatR
+    C.logScoreWithPointResolution(prior, prediction, answer)->run(~env)->toFloatR
   let toPointSet = (~env, dist) => C.toPointSet(dist)->run(~env)->toDistR
   let toSampleSet = (~env, dist, n) => C.toSampleSet(dist, n)->run(~env)->toDistR
   let fromSamples = (~env, xs) => C.fromSamples(xs)->run(~env)->toDistR

View File

@@ -162,7 +162,7 @@ module Constructors = {
   let truncate = (dist, left, right): t => FromDist(ToDist(Truncate(left, right)), dist)
   let inspect = (dist): t => FromDist(ToDist(Inspect), dist)
   let klDivergence = (dist1, dist2): t => FromDist(ToScore(KLDivergence(dist2)), dist1)
-  let logScore = (prior, prediction, answer): t => FromDist(
+  let logScoreWithPointResolution = (prior, prediction, answer): t => FromDist(
     ToScore(LogScore(prediction, answer)),
     prior,
   )

View File

@@ -67,11 +67,32 @@ module Score = {
       )
     }
-  let logScore = (prior, prediction, answer, ~toPointSetFn: toPointSetFn): result<float, error> => {
-    let pointSets = E.R.merge(toPointSetFn(prior), toPointSetFn(prediction))
-    pointSets |> E.R2.bind(((a, b)) =>
-      PointSetDist.T.logScore(a, b, answer)->E.R2.errMap(x => DistributionTypes.OperationError(x))
-    )
+  let logScoreWithPointResolution = (
+    prior,
+    prediction,
+    answer,
+    ~toPointSetFn: toPointSetFn,
+  ): result<float, error> => {
+    switch prior {
+    | Some(prior') =>
+      E.R.merge(toPointSetFn(prior'), toPointSetFn(prediction))->E.R.bind(((a, b)) =>
+        PointSetDist.T.logScoreWithPointResolution(
+          a->Some,
+          b,
+          answer,
+        )->E.R2.errMap(x => DistributionTypes.OperationError(x))
+      )
+    | None =>
+      prediction
+      ->toPointSetFn
+      ->E.R.bind(x =>
+        PointSetDist.T.logScoreWithPointResolution(
+          None,
+          x,
+          answer,
+        )->E.R2.errMap(x => DistributionTypes.OperationError(x))
+      )
+    }
   }
 }

View File

@@ -25,7 +25,12 @@ let toFloatOperation: (
 module Score: {
   let klDivergence: (t, t, ~toPointSetFn: toPointSetFn) => result<float, error>
-  let logScore: (t, t, float, ~toPointSetFn: toPointSetFn) => result<float, error>
+  let logScoreWithPointResolution: (
+    option<t>,
+    t,
+    float,
+    ~toPointSetFn: toPointSetFn,
+  ) => result<float, error>
 }
 @genType

View File

@@ -279,13 +279,10 @@ module T = Dist({
     )
     newShape->E.R2.fmap(x => x->make->integralEndY)
   }
-  let logScore = (prior: t, prediction: t, answer: float) => {
-    let newShape = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument(
-      PointSetDist_Scoring.LogScore.integrand(~answer),
-      prior.xyShape,
-      prediction.xyShape,
-    )
-    newShape->E.R2.fmap(x => x->make->integralEndY)
+  let logScoreWithPointResolution = (prior: option<t>, prediction: t, answer: float) => {
+    let priorPdf = prior->E.O2.fmap((shape, x) => XYShape.XtoY.linear(x, shape.xyShape))
+    let predictionPdf = x => XYShape.XtoY.linear(x, prediction.xyShape)
+    PointSetDist_Scoring.LogScoreWithPointResolution.score(~priorPdf, ~predictionPdf, ~answer)
   }
 })
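
One subtle line in this hunk: `prior->E.O2.fmap((shape, x) => ...)` relies on currying, so mapping a two-parameter lambda over the option yields an `option<float => float>`, i.e. an optional pdf closure. Below is a standalone sketch of the same pattern, using Belt.Option.map in place of the codebase's E.O2.fmap and a toy uniform density in place of XYShape.XtoY.linear; both stand-ins, and the names toyDensity and makePriorPdf, are assumptions for illustration only.

// Hypothetical stand-in for XYShape.XtoY.linear: a uniform density over (lo, hi).
let toyDensity = (shape: (float, float), x: float): float => {
  let (lo, hi) = shape
  if x >= lo && x <= hi {
    1.0 /. (hi -. lo)
  } else {
    0.0
  }
}

// Written with the closure made explicit (shape => x => ...); the diff's
// (shape, x) => ... form means the same thing via currying.
let makePriorPdf = (prior: option<(float, float)>): option<float => float> =>
  prior->Belt.Option.map(shape => x => toyDensity(shape, x))

let _ = switch makePriorPdf(Some((0.0, 2.0))) {
| Some(pdf) => Js.log(pdf(1.0)) // logs 0.5
| None => Js.log("no prior supplied")
}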

View File

@@ -229,7 +229,7 @@ module T = Dist({
       answer,
     )->E.R2.fmap(integralEndY)
   }
-  let logScore = (prior: t, prediction: t, answer: float) => {
+  let logScoreWithPointResolution = (prior: option<t>, prediction: t, answer: float) => {
     Error(Operation.NotYetImplemented)
   }
 })

View File

@@ -34,7 +34,7 @@ module type dist = {
   let mean: t => float
   let variance: t => float
   let klDivergence: (t, t) => result<float, Operation.Error.t>
-  let logScore: (t, t, float) => result<float, Operation.Error.t>
+  let logScoreWithPointResolution: (option<t>, t, float) => result<float, Operation.Error.t>
 }
 module Dist = (T: dist) => {
@@ -58,7 +58,7 @@ module Dist = (T: dist) => {
   let variance = T.variance
   let integralEndY = T.integralEndY
   let klDivergence = T.klDivergence
-  let logScore = T.logScore
+  let logScoreWithPointResolution = T.logScoreWithPointResolution
   let updateIntegralCache = T.updateIntegralCache

View File

@@ -306,7 +306,7 @@ module T = Dist({
     let klContinuousPart = Continuous.T.klDivergence(prediction.continuous, answer.continuous)
     E.R.merge(klDiscretePart, klContinuousPart)->E.R2.fmap(t => fst(t) +. snd(t))
   }
-  let logScore = (prior: t, prediction: t, answer: float) => {
+  let logScoreWithPointResolution = (prior: option<t>, prediction: t, answer: float) => {
     Error(Operation.NotYetImplemented)
   }
 })

View File

@@ -203,9 +203,11 @@ module T = Dist({
   | (m1, m2) => Mixed.T.klDivergence(m1->toMixed, m2->toMixed)
   }
-  let logScore = (prior: t, prediction: t, answer: float) => {
+  let logScoreWithPointResolution = (prior: option<t>, prediction: t, answer: float) => {
     switch (prior, prediction) {
-    | (Continuous(t1), Continuous(t2)) => Continuous.T.logScore(t1, t2, answer)
+    | (Some(Continuous(t1)), Continuous(t2)) =>
+      Continuous.T.logScoreWithPointResolution(t1->Some, t2, answer)
+    | (None, Continuous(t2)) => Continuous.T.logScoreWithPointResolution(None, t2, answer)
     | _ => Error(Operation.NotYetImplemented)
     }
   }

View File

@@ -18,16 +18,32 @@ module KLDivergence = {
 /*
 */
-module LogScore = {
+module LogScoreWithPointResolution = {
   let logFn = Js.Math.log
-  let integrand = (priorElement: float, predictionElement: float, ~answer: float) => {
-    if answer == 0.0 {
-      Ok(0.0)
-    } else if predictionElement == 0.0 {
-      Ok(infinity)
+  let score = (
+    ~priorPdf: option<float => float>,
+    ~predictionPdf: float => float,
+    ~answer: float,
+  ): result<float, Operation.Error.t> => {
+    let numer = answer->predictionPdf
+    if numer < 0.0 {
+      Operation.ComplexNumberError->Error
+    } else if numer == 0.0 {
+      infinity->Ok
     } else {
-      let quot = predictionElement /. priorElement
-      quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answer *. logFn(quot /. answer))
+      -.(
+        switch priorPdf {
+        | None => numer->logFn
+        | Some(f) => {
+            let priorDensityOfAnswer = f(answer)
+            if priorDensityOfAnswer == 0.0 {
+              neg_infinity
+            } else {
+              (numer /. priorDensityOfAnswer)->logFn
+            }
+          }
+        }
+      )->Ok
    }
  }
}
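
As a sanity check on the math, here is a standalone restatement of the scoring rule against plain ReScript functions, evaluated for a prediction that puts twice the prior's density on the answer. The helper name logScoreSketch is hypothetical, and the zero- and negative-density edge cases handled by the real code above are omitted.

// Minimal sketch of the score above; assumes only Js.Math.log.
let logScoreSketch = (
  ~priorPdf: option<float => float>,
  ~predictionPdf: float => float,
  ~answer: float,
): float => {
  let p = predictionPdf(answer)
  switch priorPdf {
  | None => -.(Js.Math.log(p))
  | Some(q) => -.(Js.Math.log(p /. q(answer)))
  }
}

// Prediction density 0.4 vs. prior density 0.2 at the answer:
// score = -log(0.4 / 0.2) = -log 2 ≈ -0.693, i.e. better than the prior.
let _ = Js.log(
  logScoreSketch(~priorPdf=Some(_ => 0.2), ~predictionPdf=(_ => 0.4), ~answer=1.0),
)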