Initialized logScore
and logScoreAgainstImproperPrior
Value: [1e-5 to 6e-3]
This commit is contained in:
parent
937458cd05
commit
978e149913
|
@ -277,13 +277,19 @@ module T = Dist({
|
|||
prediction.xyShape,
|
||||
answer.xyShape,
|
||||
)
|
||||
let xyShapeToContinuous: XYShape.xyShape => t = xyShape => {
|
||||
xyShape: xyShape,
|
||||
interpolation: #Linear,
|
||||
integralSumCache: None,
|
||||
integralCache: None,
|
||||
newShape->E.R2.fmap(x => x->make->integralEndY)
|
||||
}
|
||||
newShape->E.R2.fmap(x => x->xyShapeToContinuous->integralEndY)
|
||||
let logScore = (prior: t, prediction: t, answer: float) => {
|
||||
let newShape = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument(
|
||||
PointSetDist_Scoring.LogScore.integrand(~answer),
|
||||
prior.xyShape,
|
||||
prediction.xyShape,
|
||||
)
|
||||
newShape->E.R2.fmap(x => x->make->integralEndY)
|
||||
}
|
||||
let logScoreAgainstImproperPrior = (prediction: t, answer: float) => {
|
||||
let prior = make({xs: prediction.xyShape.xs, ys: E.A.fmap(_ => 1.0, prediction.xyShape.xs)})
|
||||
logScore(prior, prediction, answer)
|
||||
}
|
||||
})
|
||||
|
||||
|
|
|
@ -229,4 +229,10 @@ module T = Dist({
|
|||
answer,
|
||||
)->E.R2.fmap(integralEndY)
|
||||
}
|
||||
let logScore = (prior: t, prediction: t, answer: float) => {
|
||||
Error(Operation.NotYetImplemented)
|
||||
}
|
||||
let logScoreAgainstImproperPrior = (prediction: t, answer: float) => {
|
||||
Error(Operation.NotYetImplemented)
|
||||
}
|
||||
})
|
||||
|
|
|
@ -34,6 +34,8 @@ module type dist = {
|
|||
let mean: t => float
|
||||
let variance: t => float
|
||||
let klDivergence: (t, t) => result<float, Operation.Error.t>
|
||||
let logScore: (t, t, float) => result<float, Operation.Error.t>
|
||||
let logScoreAgainstImproperPrior: (t, float) => result<float, Operation.Error.t>
|
||||
}
|
||||
|
||||
module Dist = (T: dist) => {
|
||||
|
@ -57,6 +59,8 @@ module Dist = (T: dist) => {
|
|||
let variance = T.variance
|
||||
let integralEndY = T.integralEndY
|
||||
let klDivergence = T.klDivergence
|
||||
let logScore = T.logScore
|
||||
let logScoreAgainstImproperPrior = T.logScoreAgainstImproperPrior
|
||||
|
||||
let updateIntegralCache = T.updateIntegralCache
|
||||
|
||||
|
|
|
@ -306,6 +306,12 @@ module T = Dist({
|
|||
let klContinuousPart = Continuous.T.klDivergence(prediction.continuous, answer.continuous)
|
||||
E.R.merge(klDiscretePart, klContinuousPart)->E.R2.fmap(t => fst(t) +. snd(t))
|
||||
}
|
||||
let logScore = (prior: t, prediction: t, answer: float) => {
|
||||
Error(Operation.NotYetImplemented)
|
||||
}
|
||||
let logScoreAgainstImproperPrior = (prediction: t, answer: float) => {
|
||||
Error(Operation.NotYetImplemented)
|
||||
}
|
||||
})
|
||||
|
||||
let combineAlgebraically = (op: Operation.convolutionOperation, t1: t, t2: t): t => {
|
||||
|
|
|
@ -196,12 +196,21 @@ module T = Dist({
|
|||
| Continuous(m) => Continuous.T.variance(m)
|
||||
}
|
||||
|
||||
let klDivergence = (t1: t, t2: t) =>
|
||||
switch (t1, t2) {
|
||||
let klDivergence = (prediction: t, answer: t) =>
|
||||
switch (prediction, answer) {
|
||||
| (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2)
|
||||
| (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2)
|
||||
| (Mixed(t1), Mixed(t2)) => Mixed.T.klDivergence(t1, t2)
|
||||
| _ => Error(NotYetImplemented)
|
||||
| (m1, m2) => Mixed.T.klDivergence(m1->toMixed, m2->toMixed)
|
||||
}
|
||||
|
||||
let logScore = (prior: t, prediction: t, answer: float) => {
|
||||
switch (prior, prediction) {
|
||||
| (Continuous(t1), Continuous(t2)) => Continuous.T.logScore(t1, t2, answer)
|
||||
| _ => Error(Operation.NotYetImplemented)
|
||||
}
|
||||
}
|
||||
let logScoreAgainstImproperPrior = (prediction: t, answer: float) => {
|
||||
Error(Operation.NotYetImplemented)
|
||||
}
|
||||
})
|
||||
|
||||
|
|
|
@ -14,3 +14,20 @@ module KLDivergence = {
|
|||
quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answerElement *. logFn(quot))
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
|
||||
*/
|
||||
module LogScore = {
|
||||
let logFn = Js.Math.log
|
||||
let integrand = (priorElement: float, predictionElement: float, ~answer: float) => {
|
||||
if answer == 0.0 {
|
||||
Ok(0.0)
|
||||
} else if predictionElement == 0.0 {
|
||||
Ok(infinity)
|
||||
} else {
|
||||
let quot = predictionElement /. priorElement
|
||||
quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answer *. logFn(quot /. answer))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user