Corrected log score
Value: [5e-5 to 2e-2]
This commit is contained in:
parent
cec4bbd334
commit
d5c9705811
|
@ -271,6 +271,15 @@ module T = Dist({
|
||||||
XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)
|
XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)
|
||||||
|
|
||||||
let logScore = (base: t, reference: t) => {
|
let logScore = (base: t, reference: t) => {
|
||||||
|
let referenceIsZero = switch Distributions.Common.isZeroEverywhere(
|
||||||
|
PointSetTypes.Continuous(reference),
|
||||||
|
) {
|
||||||
|
| Continuous(b) => b
|
||||||
|
| _ => false
|
||||||
|
}
|
||||||
|
if referenceIsZero {
|
||||||
|
Ok(0.0)
|
||||||
|
} else {
|
||||||
E.R2.bind(
|
E.R2.bind(
|
||||||
combinePointwise(PointSetDist_Scoring.LogScoring.multiply, reference),
|
combinePointwise(PointSetDist_Scoring.LogScoring.multiply, reference),
|
||||||
combinePointwise(PointSetDist_Scoring.LogScoring.logScore, base, reference),
|
combinePointwise(PointSetDist_Scoring.LogScoring.logScore, base, reference),
|
||||||
|
@ -278,6 +287,7 @@ module T = Dist({
|
||||||
|> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
|
|> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
|
||||||
|> E.R.fmap(integralEndY)
|
|> E.R.fmap(integralEndY)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
let isNormalized = (t: t): bool => {
|
let isNormalized = (t: t): bool => {
|
||||||
|
|
|
@ -230,8 +230,19 @@ module T = Dist({
|
||||||
}
|
}
|
||||||
|
|
||||||
let logScore = (base: t, reference: t) => {
|
let logScore = (base: t, reference: t) => {
|
||||||
combinePointwise(~fn=PointSetDist_Scoring.LogScoring.logScore, base, reference) |> E.R2.bind(
|
let referenceIsZero = switch Distributions.Common.isZeroEverywhere(
|
||||||
integralEndYResult,
|
PointSetTypes.Discrete(reference),
|
||||||
)
|
) {
|
||||||
|
| Discrete(b) => b
|
||||||
|
| _ => false
|
||||||
|
}
|
||||||
|
if referenceIsZero {
|
||||||
|
Ok(0.0)
|
||||||
|
} else {
|
||||||
|
E.R2.bind(
|
||||||
|
combinePointwise(~fn=PointSetDist_Scoring.LogScoring.multiply, reference),
|
||||||
|
combinePointwise(~fn=PointSetDist_Scoring.LogScoring.logScore, base, reference),
|
||||||
|
) |> E.R2.bind(integralEndYResult)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -96,4 +96,18 @@ module Common = {
|
||||||
None
|
None
|
||||||
| (Some(s1), Some(s2)) => combineFn(s1, s2)
|
| (Some(s1), Some(s2)) => combineFn(s1, s2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let isZeroEverywhere = (d: PointSetTypes.pointSetDist) => {
|
||||||
|
let isZero = (x: float): bool => x == 0.0
|
||||||
|
PointSetTypes.ShapeMonad.fmap(
|
||||||
|
d,
|
||||||
|
(
|
||||||
|
mixed =>
|
||||||
|
E.A.all(isZero, mixed.continuous.xyShape.ys) &&
|
||||||
|
E.A.all(isZero, mixed.discrete.xyShape.ys),
|
||||||
|
disc => E.A.all(isZero, disc.xyShape.ys),
|
||||||
|
cont => E.A.all(isZero, cont.xyShape.ys),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -302,9 +302,20 @@ module T = Dist({
|
||||||
}
|
}
|
||||||
|
|
||||||
let logScore = (base: t, reference: t) => {
|
let logScore = (base: t, reference: t) => {
|
||||||
combinePointwise(PointSetDist_Scoring.LogScoring.logScore, base, reference) |> E.R.fmap(
|
let referenceIsZero = switch Distributions.Common.isZeroEverywhere(
|
||||||
integralEndY,
|
PointSetTypes.Mixed(reference),
|
||||||
)
|
) {
|
||||||
|
| Mixed(b) => b
|
||||||
|
| _ => false
|
||||||
|
}
|
||||||
|
if referenceIsZero {
|
||||||
|
Ok(0.0)
|
||||||
|
} else {
|
||||||
|
E.R2.bind(
|
||||||
|
combinePointwise(PointSetDist_Scoring.LogScoring.multiply, reference),
|
||||||
|
combinePointwise(PointSetDist_Scoring.LogScoring.logScore, base, reference),
|
||||||
|
) |> E.R.fmap(integralEndY)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -606,6 +606,9 @@ module A = {
|
||||||
let filter = Js.Array.filter
|
let filter = Js.Array.filter
|
||||||
let joinWith = Js.Array.joinWith
|
let joinWith = Js.Array.joinWith
|
||||||
|
|
||||||
|
let all = (p: 'a => bool, xs: array<'a>): bool => length(filter(p, xs)) == length(xs)
|
||||||
|
let any = (p: 'a => bool, xs: array<'a>): bool => length(filter(p, xs)) > 0
|
||||||
|
|
||||||
module O = {
|
module O = {
|
||||||
let concatSomes = (optionals: array<option<'a>>): array<'a> =>
|
let concatSomes = (optionals: array<option<'a>>): array<'a> =>
|
||||||
optionals
|
optionals
|
||||||
|
|
Loading…
Reference in New Issue
Block a user