Implemented correct math underlying logScoreWithPointResolution
Value: [1e-2 to 7e-1] Realized that I need to switch argument order, put `prior` last maybe.
This commit is contained in:
parent
78def2d3d2
commit
b4a1137019
|
@ -149,7 +149,7 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
|
||||||
->E.R2.fmap(r => Float(r))
|
->E.R2.fmap(r => Float(r))
|
||||||
->OutputLocal.fromResult
|
->OutputLocal.fromResult
|
||||||
| ToScore(LogScore(prediction, answer)) =>
|
| ToScore(LogScore(prediction, answer)) =>
|
||||||
GenericDist.Score.logScore(dist, prediction, answer, ~toPointSetFn)
|
GenericDist.Score.logScoreWithPointResolution(Some(dist), prediction, answer, ~toPointSetFn)
|
||||||
->E.R2.fmap(r => Float(r))
|
->E.R2.fmap(r => Float(r))
|
||||||
->OutputLocal.fromResult
|
->OutputLocal.fromResult
|
||||||
| ToBool(IsNormalized) => dist->GenericDist.isNormalized->Bool
|
| ToBool(IsNormalized) => dist->GenericDist.isNormalized->Bool
|
||||||
|
@ -267,7 +267,7 @@ module Constructors = {
|
||||||
let isNormalized = (~env, dist) => C.isNormalized(dist)->run(~env)->toBoolR
|
let isNormalized = (~env, dist) => C.isNormalized(dist)->run(~env)->toBoolR
|
||||||
let klDivergence = (~env, dist1, dist2) => C.klDivergence(dist1, dist2)->run(~env)->toFloatR
|
let klDivergence = (~env, dist1, dist2) => C.klDivergence(dist1, dist2)->run(~env)->toFloatR
|
||||||
let logScore = (~env, prior, prediction, answer) =>
|
let logScore = (~env, prior, prediction, answer) =>
|
||||||
C.logScore(prior, prediction, answer)->run(~env)->toFloatR
|
C.logScoreWithPointResolution(prior, prediction, answer)->run(~env)->toFloatR
|
||||||
let toPointSet = (~env, dist) => C.toPointSet(dist)->run(~env)->toDistR
|
let toPointSet = (~env, dist) => C.toPointSet(dist)->run(~env)->toDistR
|
||||||
let toSampleSet = (~env, dist, n) => C.toSampleSet(dist, n)->run(~env)->toDistR
|
let toSampleSet = (~env, dist, n) => C.toSampleSet(dist, n)->run(~env)->toDistR
|
||||||
let fromSamples = (~env, xs) => C.fromSamples(xs)->run(~env)->toDistR
|
let fromSamples = (~env, xs) => C.fromSamples(xs)->run(~env)->toDistR
|
||||||
|
|
|
@ -162,7 +162,7 @@ module Constructors = {
|
||||||
let truncate = (dist, left, right): t => FromDist(ToDist(Truncate(left, right)), dist)
|
let truncate = (dist, left, right): t => FromDist(ToDist(Truncate(left, right)), dist)
|
||||||
let inspect = (dist): t => FromDist(ToDist(Inspect), dist)
|
let inspect = (dist): t => FromDist(ToDist(Inspect), dist)
|
||||||
let klDivergence = (dist1, dist2): t => FromDist(ToScore(KLDivergence(dist2)), dist1)
|
let klDivergence = (dist1, dist2): t => FromDist(ToScore(KLDivergence(dist2)), dist1)
|
||||||
let logScore = (prior, prediction, answer): t => FromDist(
|
let logScoreWithPointResolution = (prior, prediction, answer): t => FromDist(
|
||||||
ToScore(LogScore(prediction, answer)),
|
ToScore(LogScore(prediction, answer)),
|
||||||
prior,
|
prior,
|
||||||
)
|
)
|
||||||
|
|
|
@ -67,11 +67,32 @@ module Score = {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
let logScore = (prior, prediction, answer, ~toPointSetFn: toPointSetFn): result<float, error> => {
|
let logScoreWithPointResolution = (
|
||||||
let pointSets = E.R.merge(toPointSetFn(prior), toPointSetFn(prediction))
|
prior,
|
||||||
pointSets |> E.R2.bind(((a, b)) =>
|
prediction,
|
||||||
PointSetDist.T.logScore(a, b, answer)->E.R2.errMap(x => DistributionTypes.OperationError(x))
|
answer,
|
||||||
)
|
~toPointSetFn: toPointSetFn,
|
||||||
|
): result<float, error> => {
|
||||||
|
switch prior {
|
||||||
|
| Some(prior') =>
|
||||||
|
E.R.merge(toPointSetFn(prior'), toPointSetFn(prediction))->E.R.bind(((a, b)) =>
|
||||||
|
PointSetDist.T.logScoreWithPointResolution(
|
||||||
|
a->Some,
|
||||||
|
b,
|
||||||
|
answer,
|
||||||
|
)->E.R2.errMap(x => DistributionTypes.OperationError(x))
|
||||||
|
)
|
||||||
|
| None =>
|
||||||
|
prediction
|
||||||
|
->toPointSetFn
|
||||||
|
->E.R.bind(x =>
|
||||||
|
PointSetDist.T.logScoreWithPointResolution(
|
||||||
|
None,
|
||||||
|
x,
|
||||||
|
answer,
|
||||||
|
)->E.R2.errMap(x => DistributionTypes.OperationError(x))
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,12 @@ let toFloatOperation: (
|
||||||
|
|
||||||
module Score: {
|
module Score: {
|
||||||
let klDivergence: (t, t, ~toPointSetFn: toPointSetFn) => result<float, error>
|
let klDivergence: (t, t, ~toPointSetFn: toPointSetFn) => result<float, error>
|
||||||
let logScore: (t, t, float, ~toPointSetFn: toPointSetFn) => result<float, error>
|
let logScoreWithPointResolution: (
|
||||||
|
option<t>,
|
||||||
|
t,
|
||||||
|
float,
|
||||||
|
~toPointSetFn: toPointSetFn,
|
||||||
|
) => result<float, error>
|
||||||
}
|
}
|
||||||
|
|
||||||
@genType
|
@genType
|
||||||
|
|
|
@ -279,13 +279,10 @@ module T = Dist({
|
||||||
)
|
)
|
||||||
newShape->E.R2.fmap(x => x->make->integralEndY)
|
newShape->E.R2.fmap(x => x->make->integralEndY)
|
||||||
}
|
}
|
||||||
let logScore = (prior: t, prediction: t, answer: float) => {
|
let logScoreWithPointResolution = (prior: option<t>, prediction: t, answer: float) => {
|
||||||
let newShape = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument(
|
let priorPdf = prior->E.O2.fmap((shape, x) => XYShape.XtoY.linear(x, shape.xyShape))
|
||||||
PointSetDist_Scoring.LogScore.integrand(~answer),
|
let predictionPdf = x => XYShape.XtoY.linear(x, prediction.xyShape)
|
||||||
prior.xyShape,
|
PointSetDist_Scoring.LogScoreWithPointResolution.score(~priorPdf, ~predictionPdf, ~answer)
|
||||||
prediction.xyShape,
|
|
||||||
)
|
|
||||||
newShape->E.R2.fmap(x => x->make->integralEndY)
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -229,7 +229,7 @@ module T = Dist({
|
||||||
answer,
|
answer,
|
||||||
)->E.R2.fmap(integralEndY)
|
)->E.R2.fmap(integralEndY)
|
||||||
}
|
}
|
||||||
let logScore = (prior: t, prediction: t, answer: float) => {
|
let logScoreWithPointResolution = (prior: option<t>, prediction: t, answer: float) => {
|
||||||
Error(Operation.NotYetImplemented)
|
Error(Operation.NotYetImplemented)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -34,7 +34,7 @@ module type dist = {
|
||||||
let mean: t => float
|
let mean: t => float
|
||||||
let variance: t => float
|
let variance: t => float
|
||||||
let klDivergence: (t, t) => result<float, Operation.Error.t>
|
let klDivergence: (t, t) => result<float, Operation.Error.t>
|
||||||
let logScore: (t, t, float) => result<float, Operation.Error.t>
|
let logScoreWithPointResolution: (option<t>, t, float) => result<float, Operation.Error.t>
|
||||||
}
|
}
|
||||||
|
|
||||||
module Dist = (T: dist) => {
|
module Dist = (T: dist) => {
|
||||||
|
@ -58,7 +58,7 @@ module Dist = (T: dist) => {
|
||||||
let variance = T.variance
|
let variance = T.variance
|
||||||
let integralEndY = T.integralEndY
|
let integralEndY = T.integralEndY
|
||||||
let klDivergence = T.klDivergence
|
let klDivergence = T.klDivergence
|
||||||
let logScore = T.logScore
|
let logScoreWithPointResolution = T.logScoreWithPointResolution
|
||||||
|
|
||||||
let updateIntegralCache = T.updateIntegralCache
|
let updateIntegralCache = T.updateIntegralCache
|
||||||
|
|
||||||
|
|
|
@ -306,7 +306,7 @@ module T = Dist({
|
||||||
let klContinuousPart = Continuous.T.klDivergence(prediction.continuous, answer.continuous)
|
let klContinuousPart = Continuous.T.klDivergence(prediction.continuous, answer.continuous)
|
||||||
E.R.merge(klDiscretePart, klContinuousPart)->E.R2.fmap(t => fst(t) +. snd(t))
|
E.R.merge(klDiscretePart, klContinuousPart)->E.R2.fmap(t => fst(t) +. snd(t))
|
||||||
}
|
}
|
||||||
let logScore = (prior: t, prediction: t, answer: float) => {
|
let logScoreWithPointResolution = (prior: option<t>, prediction: t, answer: float) => {
|
||||||
Error(Operation.NotYetImplemented)
|
Error(Operation.NotYetImplemented)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -203,9 +203,11 @@ module T = Dist({
|
||||||
| (m1, m2) => Mixed.T.klDivergence(m1->toMixed, m2->toMixed)
|
| (m1, m2) => Mixed.T.klDivergence(m1->toMixed, m2->toMixed)
|
||||||
}
|
}
|
||||||
|
|
||||||
let logScore = (prior: t, prediction: t, answer: float) => {
|
let logScoreWithPointResolution = (prior: option<t>, prediction: t, answer: float) => {
|
||||||
switch (prior, prediction) {
|
switch (prior, prediction) {
|
||||||
| (Continuous(t1), Continuous(t2)) => Continuous.T.logScore(t1, t2, answer)
|
| (Some(Continuous(t1)), Continuous(t2)) =>
|
||||||
|
Continuous.T.logScoreWithPointResolution(t1->Some, t2, answer)
|
||||||
|
| (None, Continuous(t2)) => Continuous.T.logScoreWithPointResolution(None, t2, answer)
|
||||||
| _ => Error(Operation.NotYetImplemented)
|
| _ => Error(Operation.NotYetImplemented)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,16 +18,32 @@ module KLDivergence = {
|
||||||
/*
|
/*
|
||||||
|
|
||||||
*/
|
*/
|
||||||
module LogScore = {
|
module LogScoreWithPointResolution = {
|
||||||
let logFn = Js.Math.log
|
let logFn = Js.Math.log
|
||||||
let integrand = (priorElement: float, predictionElement: float, ~answer: float) => {
|
let score = (
|
||||||
if answer == 0.0 {
|
~priorPdf: option<float => float>,
|
||||||
Ok(0.0)
|
~predictionPdf: float => float,
|
||||||
} else if predictionElement == 0.0 {
|
~answer: float,
|
||||||
Ok(infinity)
|
): result<float, Operation.Error.t> => {
|
||||||
|
let numer = answer->predictionPdf
|
||||||
|
if numer < 0.0 {
|
||||||
|
Operation.ComplexNumberError->Error
|
||||||
|
} else if numer == 0.0 {
|
||||||
|
infinity->Ok
|
||||||
} else {
|
} else {
|
||||||
let quot = predictionElement /. priorElement
|
-.(
|
||||||
quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answer *. logFn(quot /. answer))
|
switch priorPdf {
|
||||||
|
| None => numer->logFn
|
||||||
|
| Some(f) => {
|
||||||
|
let priorDensityOfAnswer = f(answer)
|
||||||
|
if priorDensityOfAnswer == 0.0 {
|
||||||
|
neg_infinity
|
||||||
|
} else {
|
||||||
|
(numer /. priorDensityOfAnswer)->logFn
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)->Ok
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user