klDivergence
is now LogarithmWithThreshold
This commit is contained in:
parent
cfa83e552d
commit
898547f3a3
|
@ -280,7 +280,11 @@ module T = Dist({
|
||||||
if referenceIsZero {
|
if referenceIsZero {
|
||||||
Ok(0.0)
|
Ok(0.0)
|
||||||
} else {
|
} else {
|
||||||
combinePointwise(PointSetDist_Scoring.KLDivergence.logScore, base, reference)
|
combinePointwise(
|
||||||
|
PointSetDist_Scoring.KLDivergence.logScore(~eps=MagicNumbers.Epsilon.seven),
|
||||||
|
base,
|
||||||
|
reference,
|
||||||
|
)
|
||||||
|> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
|
|> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
|
||||||
|> E.R.fmap(integralEndY)
|
|> E.R.fmap(integralEndY)
|
||||||
}
|
}
|
||||||
|
@ -289,7 +293,8 @@ module T = Dist({
|
||||||
|
|
||||||
let isNormalized = (t: t): bool => {
|
let isNormalized = (t: t): bool => {
|
||||||
let areaUnderIntegral = t |> updateIntegralCache(Some(T.integral(t))) |> T.integralEndY
|
let areaUnderIntegral = t |> updateIntegralCache(Some(T.integral(t))) |> T.integralEndY
|
||||||
areaUnderIntegral < 1. +. 1e-7 && areaUnderIntegral > 1. -. 1e-7
|
areaUnderIntegral < 1. +. MagicNumbers.Epsilon.seven &&
|
||||||
|
areaUnderIntegral > 1. -. MagicNumbers.Epsilon.seven
|
||||||
}
|
}
|
||||||
|
|
||||||
let downsampleEquallyOverX = (length, t): t =>
|
let downsampleEquallyOverX = (length, t): t =>
|
||||||
|
|
|
@ -240,7 +240,7 @@ module T = Dist({
|
||||||
Ok(0.0)
|
Ok(0.0)
|
||||||
} else {
|
} else {
|
||||||
combinePointwise(
|
combinePointwise(
|
||||||
~fn=PointSetDist_Scoring.KLDivergence.logScore,
|
~fn=PointSetDist_Scoring.KLDivergence.logScore(~eps=MagicNumbers.Epsilon.ten),
|
||||||
base,
|
base,
|
||||||
reference,
|
reference,
|
||||||
) |> E.R2.bind(integralEndYResult)
|
) |> E.R2.bind(integralEndYResult)
|
||||||
|
|
|
@ -311,9 +311,11 @@ module T = Dist({
|
||||||
if referenceIsZero {
|
if referenceIsZero {
|
||||||
Ok(0.0)
|
Ok(0.0)
|
||||||
} else {
|
} else {
|
||||||
combinePointwise(PointSetDist_Scoring.KLDivergence.logScore, base, reference) |> E.R.fmap(
|
combinePointwise(
|
||||||
integralEndY,
|
PointSetDist_Scoring.KLDivergence.logScore(~eps=MagicNumbers.Epsilon.ten),
|
||||||
)
|
base,
|
||||||
|
reference,
|
||||||
|
) |> E.R.fmap(integralEndY)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
module KLDivergence = {
|
module KLDivergence = {
|
||||||
let logFn = Js.Math.log
|
let logFn = Js.Math.log
|
||||||
let subtraction = (a, b) => Ok(a -. b)
|
let subtraction = (a, b) => Ok(a -. b)
|
||||||
let logScore = (a: float, b: float): result<float, Operation.Error.t> =>
|
let multiply = (a: float, b: float): result<float, Operation.Error.t> => Ok(a *. b)
|
||||||
|
let logScoreDirect = (a: float, b: float): result<float, Operation.Error.t> =>
|
||||||
if a == 0.0 {
|
if a == 0.0 {
|
||||||
Error(Operation.NegativeInfinityError)
|
Error(Operation.NegativeInfinityError)
|
||||||
} else if b == 0.0 {
|
} else if b == 0.0 {
|
||||||
|
@ -10,5 +11,15 @@ module KLDivergence = {
|
||||||
let quot = a /. b
|
let quot = a /. b
|
||||||
quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(b *. logFn(quot))
|
quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(b *. logFn(quot))
|
||||||
}
|
}
|
||||||
let multiply = (a: float, b: float): result<float, Operation.Error.t> => Ok(a *. b)
|
let logScoreWithThreshold = (~eps: float, a: float, b: float): result<float, Operation.Error.t> =>
|
||||||
|
if abs_float(a) < eps {
|
||||||
|
Ok(0.0)
|
||||||
|
} else {
|
||||||
|
logScoreDirect(a, b)
|
||||||
|
}
|
||||||
|
let logScore = (~eps: option<float>=?, a: float, b: float): result<float, Operation.Error.t> =>
|
||||||
|
switch eps {
|
||||||
|
| None => logScoreDirect(a, b)
|
||||||
|
| Some(eps') => logScoreWithThreshold(~eps=eps', a, b)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user