Corrected log score

Value: [5e-5 to 2e-2]
This commit is contained in:
Quinn Dougherty 2022-05-03 14:00:34 -04:00
parent cec4bbd334
commit d5c9705811
5 changed files with 61 additions and 12 deletions

View File

@ -271,6 +271,15 @@ module T = Dist({
XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares) XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)
let logScore = (base: t, reference: t) => { let logScore = (base: t, reference: t) => {
let referenceIsZero = switch Distributions.Common.isZeroEverywhere(
PointSetTypes.Continuous(reference),
) {
| Continuous(b) => b
| _ => false
}
if referenceIsZero {
Ok(0.0)
} else {
E.R2.bind( E.R2.bind(
combinePointwise(PointSetDist_Scoring.LogScoring.multiply, reference), combinePointwise(PointSetDist_Scoring.LogScoring.multiply, reference),
combinePointwise(PointSetDist_Scoring.LogScoring.logScore, base, reference), combinePointwise(PointSetDist_Scoring.LogScoring.logScore, base, reference),
@ -278,6 +287,7 @@ module T = Dist({
|> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite))) |> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
|> E.R.fmap(integralEndY) |> E.R.fmap(integralEndY)
} }
}
}) })
let isNormalized = (t: t): bool => { let isNormalized = (t: t): bool => {

View File

@ -230,8 +230,19 @@ module T = Dist({
} }
let logScore = (base: t, reference: t) => { let logScore = (base: t, reference: t) => {
combinePointwise(~fn=PointSetDist_Scoring.LogScoring.logScore, base, reference) |> E.R2.bind( let referenceIsZero = switch Distributions.Common.isZeroEverywhere(
integralEndYResult, PointSetTypes.Discrete(reference),
) ) {
| Discrete(b) => b
| _ => false
}
if referenceIsZero {
Ok(0.0)
} else {
E.R2.bind(
combinePointwise(~fn=PointSetDist_Scoring.LogScoring.multiply, reference),
combinePointwise(~fn=PointSetDist_Scoring.LogScoring.logScore, base, reference),
) |> E.R2.bind(integralEndYResult)
}
} }
}) })

View File

@ -96,4 +96,18 @@ module Common = {
None None
| (Some(s1), Some(s2)) => combineFn(s1, s2) | (Some(s1), Some(s2)) => combineFn(s1, s2)
} }
let isZeroEverywhere = (d: PointSetTypes.pointSetDist) => {
let isZero = (x: float): bool => x == 0.0
PointSetTypes.ShapeMonad.fmap(
d,
(
mixed =>
E.A.all(isZero, mixed.continuous.xyShape.ys) &&
E.A.all(isZero, mixed.discrete.xyShape.ys),
disc => E.A.all(isZero, disc.xyShape.ys),
cont => E.A.all(isZero, cont.xyShape.ys),
),
)
}
} }

View File

@ -302,9 +302,20 @@ module T = Dist({
} }
let logScore = (base: t, reference: t) => { let logScore = (base: t, reference: t) => {
combinePointwise(PointSetDist_Scoring.LogScoring.logScore, base, reference) |> E.R.fmap( let referenceIsZero = switch Distributions.Common.isZeroEverywhere(
integralEndY, PointSetTypes.Mixed(reference),
) ) {
| Mixed(b) => b
| _ => false
}
if referenceIsZero {
Ok(0.0)
} else {
E.R2.bind(
combinePointwise(PointSetDist_Scoring.LogScoring.multiply, reference),
combinePointwise(PointSetDist_Scoring.LogScoring.logScore, base, reference),
) |> E.R.fmap(integralEndY)
}
} }
}) })

View File

@ -606,6 +606,9 @@ module A = {
let filter = Js.Array.filter let filter = Js.Array.filter
let joinWith = Js.Array.joinWith let joinWith = Js.Array.joinWith
let all = (p: 'a => bool, xs: array<'a>): bool => length(filter(p, xs)) == length(xs)
let any = (p: 'a => bool, xs: array<'a>): bool => length(filter(p, xs)) > 0
module O = { module O = {
let concatSomes = (optionals: array<option<'a>>): array<'a> => let concatSomes = (optionals: array<option<'a>>): array<'a> =>
optionals optionals