Initialized logScore and logScoreAgainstImproperPrior

Value: [1e-5 to 6e-3]
This commit is contained in:
Quinn Dougherty 2022-05-12 13:11:51 -04:00
parent 937458cd05
commit 978e149913
6 changed files with 59 additions and 11 deletions

View File

@ -277,13 +277,19 @@ module T = Dist({
prediction.xyShape, prediction.xyShape,
answer.xyShape, answer.xyShape,
) )
let xyShapeToContinuous: XYShape.xyShape => t = xyShape => { newShape->E.R2.fmap(x => x->make->integralEndY)
xyShape: xyShape,
interpolation: #Linear,
integralSumCache: None,
integralCache: None,
} }
newShape->E.R2.fmap(x => x->xyShapeToContinuous->integralEndY) let logScore = (prior: t, prediction: t, answer: float) => {
let newShape = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument(
PointSetDist_Scoring.LogScore.integrand(~answer),
prior.xyShape,
prediction.xyShape,
)
newShape->E.R2.fmap(x => x->make->integralEndY)
}
let logScoreAgainstImproperPrior = (prediction: t, answer: float) => {
let prior = make({xs: prediction.xyShape.xs, ys: E.A.fmap(_ => 1.0, prediction.xyShape.xs)})
logScore(prior, prediction, answer)
} }
}) })

View File

@ -229,4 +229,10 @@ module T = Dist({
answer, answer,
)->E.R2.fmap(integralEndY) )->E.R2.fmap(integralEndY)
} }
let logScore = (prior: t, prediction: t, answer: float) => {
Error(Operation.NotYetImplemented)
}
let logScoreAgainstImproperPrior = (prediction: t, answer: float) => {
Error(Operation.NotYetImplemented)
}
}) })

View File

@ -34,6 +34,8 @@ module type dist = {
let mean: t => float let mean: t => float
let variance: t => float let variance: t => float
let klDivergence: (t, t) => result<float, Operation.Error.t> let klDivergence: (t, t) => result<float, Operation.Error.t>
let logScore: (t, t, float) => result<float, Operation.Error.t>
let logScoreAgainstImproperPrior: (t, float) => result<float, Operation.Error.t>
} }
module Dist = (T: dist) => { module Dist = (T: dist) => {
@ -57,6 +59,8 @@ module Dist = (T: dist) => {
let variance = T.variance let variance = T.variance
let integralEndY = T.integralEndY let integralEndY = T.integralEndY
let klDivergence = T.klDivergence let klDivergence = T.klDivergence
let logScore = T.logScore
let logScoreAgainstImproperPrior = T.logScoreAgainstImproperPrior
let updateIntegralCache = T.updateIntegralCache let updateIntegralCache = T.updateIntegralCache

View File

@ -306,6 +306,12 @@ module T = Dist({
let klContinuousPart = Continuous.T.klDivergence(prediction.continuous, answer.continuous) let klContinuousPart = Continuous.T.klDivergence(prediction.continuous, answer.continuous)
E.R.merge(klDiscretePart, klContinuousPart)->E.R2.fmap(t => fst(t) +. snd(t)) E.R.merge(klDiscretePart, klContinuousPart)->E.R2.fmap(t => fst(t) +. snd(t))
} }
let logScore = (prior: t, prediction: t, answer: float) => {
Error(Operation.NotYetImplemented)
}
let logScoreAgainstImproperPrior = (prediction: t, answer: float) => {
Error(Operation.NotYetImplemented)
}
}) })
let combineAlgebraically = (op: Operation.convolutionOperation, t1: t, t2: t): t => { let combineAlgebraically = (op: Operation.convolutionOperation, t1: t, t2: t): t => {

View File

@ -196,12 +196,21 @@ module T = Dist({
| Continuous(m) => Continuous.T.variance(m) | Continuous(m) => Continuous.T.variance(m)
} }
let klDivergence = (t1: t, t2: t) => let klDivergence = (prediction: t, answer: t) =>
switch (t1, t2) { switch (prediction, answer) {
| (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2) | (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2)
| (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2) | (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2)
| (Mixed(t1), Mixed(t2)) => Mixed.T.klDivergence(t1, t2) | (m1, m2) => Mixed.T.klDivergence(m1->toMixed, m2->toMixed)
| _ => Error(NotYetImplemented) }
let logScore = (prior: t, prediction: t, answer: float) => {
switch (prior, prediction) {
| (Continuous(t1), Continuous(t2)) => Continuous.T.logScore(t1, t2, answer)
| _ => Error(Operation.NotYetImplemented)
}
}
let logScoreAgainstImproperPrior = (prediction: t, answer: float) => {
Error(Operation.NotYetImplemented)
} }
}) })

View File

@ -14,3 +14,20 @@ module KLDivergence = {
quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answerElement *. logFn(quot)) quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answerElement *. logFn(quot))
} }
} }
/*
*/
module LogScore = {
let logFn = Js.Math.log
let integrand = (priorElement: float, predictionElement: float, ~answer: float) => {
if answer == 0.0 {
Ok(0.0)
} else if predictionElement == 0.0 {
Ok(infinity)
} else {
let quot = predictionElement /. priorElement
quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answer *. logFn(quot /. answer))
}
}
}