Corrected log score

Value: [5e-5 to 2e-2]
This commit is contained in:
Quinn Dougherty 2022-05-03 14:00:34 -04:00
parent cec4bbd334
commit d5c9705811
5 changed files with 61 additions and 12 deletions

View File

@ -271,12 +271,22 @@ module T = Dist({
XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)
let logScore = (base: t, reference: t) => {
E.R2.bind(
combinePointwise(PointSetDist_Scoring.LogScoring.multiply, reference),
combinePointwise(PointSetDist_Scoring.LogScoring.logScore, base, reference),
)
|> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
|> E.R.fmap(integralEndY)
let referenceIsZero = switch Distributions.Common.isZeroEverywhere(
PointSetTypes.Continuous(reference),
) {
| Continuous(b) => b
| _ => false
}
if referenceIsZero {
Ok(0.0)
} else {
E.R2.bind(
combinePointwise(PointSetDist_Scoring.LogScoring.multiply, reference),
combinePointwise(PointSetDist_Scoring.LogScoring.logScore, base, reference),
)
|> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
|> E.R.fmap(integralEndY)
}
}
})

View File

@ -230,8 +230,19 @@ module T = Dist({
}
let logScore = (base: t, reference: t) => {
combinePointwise(~fn=PointSetDist_Scoring.LogScoring.logScore, base, reference) |> E.R2.bind(
integralEndYResult,
)
let referenceIsZero = switch Distributions.Common.isZeroEverywhere(
PointSetTypes.Discrete(reference),
) {
| Discrete(b) => b
| _ => false
}
if referenceIsZero {
Ok(0.0)
} else {
E.R2.bind(
combinePointwise(~fn=PointSetDist_Scoring.LogScoring.multiply, reference),
combinePointwise(~fn=PointSetDist_Scoring.LogScoring.logScore, base, reference),
) |> E.R2.bind(integralEndYResult)
}
}
})

View File

@ -96,4 +96,18 @@ module Common = {
None
| (Some(s1), Some(s2)) => combineFn(s1, s2)
}
let isZeroEverywhere = (d: PointSetTypes.pointSetDist) => {
let isZero = (x: float): bool => x == 0.0
PointSetTypes.ShapeMonad.fmap(
d,
(
mixed =>
E.A.all(isZero, mixed.continuous.xyShape.ys) &&
E.A.all(isZero, mixed.discrete.xyShape.ys),
disc => E.A.all(isZero, disc.xyShape.ys),
cont => E.A.all(isZero, cont.xyShape.ys),
),
)
}
}

View File

@ -302,9 +302,20 @@ module T = Dist({
}
let logScore = (base: t, reference: t) => {
combinePointwise(PointSetDist_Scoring.LogScoring.logScore, base, reference) |> E.R.fmap(
integralEndY,
)
let referenceIsZero = switch Distributions.Common.isZeroEverywhere(
PointSetTypes.Mixed(reference),
) {
| Mixed(b) => b
| _ => false
}
if referenceIsZero {
Ok(0.0)
} else {
E.R2.bind(
combinePointwise(PointSetDist_Scoring.LogScoring.multiply, reference),
combinePointwise(PointSetDist_Scoring.LogScoring.logScore, base, reference),
) |> E.R.fmap(integralEndY)
}
}
})

View File

@ -606,6 +606,9 @@ module A = {
let filter = Js.Array.filter
let joinWith = Js.Array.joinWith
let all = (p: 'a => bool, xs: array<'a>): bool => length(filter(p, xs)) == length(xs)
let any = (p: 'a => bool, xs: array<'a>): bool => length(filter(p, xs)) > 0
module O = {
let concatSomes = (optionals: array<option<'a>>): array<'a> =>
optionals