Corrected log score
Value: [5e-5 to 2e-2]
This commit is contained in:
parent
cec4bbd334
commit
d5c9705811
|
@ -271,6 +271,15 @@ module T = Dist({
|
|||
XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)
|
||||
|
||||
let logScore = (base: t, reference: t) => {
|
||||
let referenceIsZero = switch Distributions.Common.isZeroEverywhere(
|
||||
PointSetTypes.Continuous(reference),
|
||||
) {
|
||||
| Continuous(b) => b
|
||||
| _ => false
|
||||
}
|
||||
if referenceIsZero {
|
||||
Ok(0.0)
|
||||
} else {
|
||||
E.R2.bind(
|
||||
combinePointwise(PointSetDist_Scoring.LogScoring.multiply, reference),
|
||||
combinePointwise(PointSetDist_Scoring.LogScoring.logScore, base, reference),
|
||||
|
@ -278,6 +287,7 @@ module T = Dist({
|
|||
|> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
|
||||
|> E.R.fmap(integralEndY)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
let isNormalized = (t: t): bool => {
|
||||
|
|
|
@ -230,8 +230,19 @@ module T = Dist({
|
|||
}
|
||||
|
||||
let logScore = (base: t, reference: t) => {
|
||||
combinePointwise(~fn=PointSetDist_Scoring.LogScoring.logScore, base, reference) |> E.R2.bind(
|
||||
integralEndYResult,
|
||||
)
|
||||
let referenceIsZero = switch Distributions.Common.isZeroEverywhere(
|
||||
PointSetTypes.Discrete(reference),
|
||||
) {
|
||||
| Discrete(b) => b
|
||||
| _ => false
|
||||
}
|
||||
if referenceIsZero {
|
||||
Ok(0.0)
|
||||
} else {
|
||||
E.R2.bind(
|
||||
combinePointwise(~fn=PointSetDist_Scoring.LogScoring.multiply, reference),
|
||||
combinePointwise(~fn=PointSetDist_Scoring.LogScoring.logScore, base, reference),
|
||||
) |> E.R2.bind(integralEndYResult)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
|
|
@ -96,4 +96,18 @@ module Common = {
|
|||
None
|
||||
| (Some(s1), Some(s2)) => combineFn(s1, s2)
|
||||
}
|
||||
|
||||
let isZeroEverywhere = (d: PointSetTypes.pointSetDist) => {
|
||||
let isZero = (x: float): bool => x == 0.0
|
||||
PointSetTypes.ShapeMonad.fmap(
|
||||
d,
|
||||
(
|
||||
mixed =>
|
||||
E.A.all(isZero, mixed.continuous.xyShape.ys) &&
|
||||
E.A.all(isZero, mixed.discrete.xyShape.ys),
|
||||
disc => E.A.all(isZero, disc.xyShape.ys),
|
||||
cont => E.A.all(isZero, cont.xyShape.ys),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -302,9 +302,20 @@ module T = Dist({
|
|||
}
|
||||
|
||||
let logScore = (base: t, reference: t) => {
|
||||
combinePointwise(PointSetDist_Scoring.LogScoring.logScore, base, reference) |> E.R.fmap(
|
||||
integralEndY,
|
||||
)
|
||||
let referenceIsZero = switch Distributions.Common.isZeroEverywhere(
|
||||
PointSetTypes.Mixed(reference),
|
||||
) {
|
||||
| Mixed(b) => b
|
||||
| _ => false
|
||||
}
|
||||
if referenceIsZero {
|
||||
Ok(0.0)
|
||||
} else {
|
||||
E.R2.bind(
|
||||
combinePointwise(PointSetDist_Scoring.LogScoring.multiply, reference),
|
||||
combinePointwise(PointSetDist_Scoring.LogScoring.logScore, base, reference),
|
||||
) |> E.R.fmap(integralEndY)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
|
|
|
@ -606,6 +606,9 @@ module A = {
|
|||
let filter = Js.Array.filter
|
||||
let joinWith = Js.Array.joinWith
|
||||
|
||||
let all = (p: 'a => bool, xs: array<'a>): bool => length(filter(p, xs)) == length(xs)
|
||||
let any = (p: 'a => bool, xs: array<'a>): bool => length(filter(p, xs)) > 0
|
||||
|
||||
module O = {
|
||||
let concatSomes = (optionals: array<option<'a>>): array<'a> =>
|
||||
optionals
|
||||
|
|
Loading…
Reference in New Issue
Block a user