Initialized logScore
and logScoreAgainstImproperPrior
Value: [1e-5 to 6e-3]
This commit is contained in:
parent
937458cd05
commit
978e149913
|
@ -277,13 +277,19 @@ module T = Dist({
|
||||||
prediction.xyShape,
|
prediction.xyShape,
|
||||||
answer.xyShape,
|
answer.xyShape,
|
||||||
)
|
)
|
||||||
let xyShapeToContinuous: XYShape.xyShape => t = xyShape => {
|
newShape->E.R2.fmap(x => x->make->integralEndY)
|
||||||
xyShape: xyShape,
|
}
|
||||||
interpolation: #Linear,
|
let logScore = (prior: t, prediction: t, answer: float) => {
|
||||||
integralSumCache: None,
|
let newShape = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument(
|
||||||
integralCache: None,
|
PointSetDist_Scoring.LogScore.integrand(~answer),
|
||||||
}
|
prior.xyShape,
|
||||||
newShape->E.R2.fmap(x => x->xyShapeToContinuous->integralEndY)
|
prediction.xyShape,
|
||||||
|
)
|
||||||
|
newShape->E.R2.fmap(x => x->make->integralEndY)
|
||||||
|
}
|
||||||
|
let logScoreAgainstImproperPrior = (prediction: t, answer: float) => {
|
||||||
|
let prior = make({xs: prediction.xyShape.xs, ys: E.A.fmap(_ => 1.0, prediction.xyShape.xs)})
|
||||||
|
logScore(prior, prediction, answer)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -229,4 +229,10 @@ module T = Dist({
|
||||||
answer,
|
answer,
|
||||||
)->E.R2.fmap(integralEndY)
|
)->E.R2.fmap(integralEndY)
|
||||||
}
|
}
|
||||||
|
let logScore = (prior: t, prediction: t, answer: float) => {
|
||||||
|
Error(Operation.NotYetImplemented)
|
||||||
|
}
|
||||||
|
let logScoreAgainstImproperPrior = (prediction: t, answer: float) => {
|
||||||
|
Error(Operation.NotYetImplemented)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -34,6 +34,8 @@ module type dist = {
|
||||||
let mean: t => float
|
let mean: t => float
|
||||||
let variance: t => float
|
let variance: t => float
|
||||||
let klDivergence: (t, t) => result<float, Operation.Error.t>
|
let klDivergence: (t, t) => result<float, Operation.Error.t>
|
||||||
|
let logScore: (t, t, float) => result<float, Operation.Error.t>
|
||||||
|
let logScoreAgainstImproperPrior: (t, float) => result<float, Operation.Error.t>
|
||||||
}
|
}
|
||||||
|
|
||||||
module Dist = (T: dist) => {
|
module Dist = (T: dist) => {
|
||||||
|
@ -57,6 +59,8 @@ module Dist = (T: dist) => {
|
||||||
let variance = T.variance
|
let variance = T.variance
|
||||||
let integralEndY = T.integralEndY
|
let integralEndY = T.integralEndY
|
||||||
let klDivergence = T.klDivergence
|
let klDivergence = T.klDivergence
|
||||||
|
let logScore = T.logScore
|
||||||
|
let logScoreAgainstImproperPrior = T.logScoreAgainstImproperPrior
|
||||||
|
|
||||||
let updateIntegralCache = T.updateIntegralCache
|
let updateIntegralCache = T.updateIntegralCache
|
||||||
|
|
||||||
|
|
|
@ -306,6 +306,12 @@ module T = Dist({
|
||||||
let klContinuousPart = Continuous.T.klDivergence(prediction.continuous, answer.continuous)
|
let klContinuousPart = Continuous.T.klDivergence(prediction.continuous, answer.continuous)
|
||||||
E.R.merge(klDiscretePart, klContinuousPart)->E.R2.fmap(t => fst(t) +. snd(t))
|
E.R.merge(klDiscretePart, klContinuousPart)->E.R2.fmap(t => fst(t) +. snd(t))
|
||||||
}
|
}
|
||||||
|
let logScore = (prior: t, prediction: t, answer: float) => {
|
||||||
|
Error(Operation.NotYetImplemented)
|
||||||
|
}
|
||||||
|
let logScoreAgainstImproperPrior = (prediction: t, answer: float) => {
|
||||||
|
Error(Operation.NotYetImplemented)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
let combineAlgebraically = (op: Operation.convolutionOperation, t1: t, t2: t): t => {
|
let combineAlgebraically = (op: Operation.convolutionOperation, t1: t, t2: t): t => {
|
||||||
|
|
|
@ -196,13 +196,22 @@ module T = Dist({
|
||||||
| Continuous(m) => Continuous.T.variance(m)
|
| Continuous(m) => Continuous.T.variance(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
let klDivergence = (t1: t, t2: t) =>
|
let klDivergence = (prediction: t, answer: t) =>
|
||||||
switch (t1, t2) {
|
switch (prediction, answer) {
|
||||||
| (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2)
|
| (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2)
|
||||||
| (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2)
|
| (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2)
|
||||||
| (Mixed(t1), Mixed(t2)) => Mixed.T.klDivergence(t1, t2)
|
| (m1, m2) => Mixed.T.klDivergence(m1->toMixed, m2->toMixed)
|
||||||
| _ => Error(NotYetImplemented)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let logScore = (prior: t, prediction: t, answer: float) => {
|
||||||
|
switch (prior, prediction) {
|
||||||
|
| (Continuous(t1), Continuous(t2)) => Continuous.T.logScore(t1, t2, answer)
|
||||||
|
| _ => Error(Operation.NotYetImplemented)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let logScoreAgainstImproperPrior = (prediction: t, answer: float) => {
|
||||||
|
Error(Operation.NotYetImplemented)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
let pdf = (f: float, t: t) => {
|
let pdf = (f: float, t: t) => {
|
||||||
|
|
|
@ -14,3 +14,20 @@ module KLDivergence = {
|
||||||
quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answerElement *. logFn(quot))
|
quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answerElement *. logFn(quot))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
|
||||||
|
*/
|
||||||
|
module LogScore = {
|
||||||
|
let logFn = Js.Math.log
|
||||||
|
let integrand = (priorElement: float, predictionElement: float, ~answer: float) => {
|
||||||
|
if answer == 0.0 {
|
||||||
|
Ok(0.0)
|
||||||
|
} else if predictionElement == 0.0 {
|
||||||
|
Ok(infinity)
|
||||||
|
} else {
|
||||||
|
let quot = predictionElement /. priorElement
|
||||||
|
quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answer *. logFn(quot /. answer))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user