klDivergence is now LogarithmWithThreshold

This commit is contained in:
Quinn Dougherty 2022-05-04 13:53:32 -04:00
parent cfa83e552d
commit 898547f3a3
4 changed files with 26 additions and 8 deletions

View File

@@ -280,7 +280,11 @@ module T = Dist({
     if referenceIsZero {
       Ok(0.0)
     } else {
-      combinePointwise(PointSetDist_Scoring.KLDivergence.logScore, base, reference)
+      combinePointwise(
+        PointSetDist_Scoring.KLDivergence.logScore(~eps=MagicNumbers.Epsilon.seven),
+        base,
+        reference,
+      )
       |> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
       |> E.R.fmap(integralEndY)
     }
@@ -289,7 +293,8 @@ module T = Dist({
   let isNormalized = (t: t): bool => {
     let areaUnderIntegral = t |> updateIntegralCache(Some(T.integral(t))) |> T.integralEndY
-    areaUnderIntegral < 1. +. 1e-7 && areaUnderIntegral > 1. -. 1e-7
+    areaUnderIntegral < 1. +. MagicNumbers.Epsilon.seven &&
+    areaUnderIntegral > 1. -. MagicNumbers.Epsilon.seven
   }
   let downsampleEquallyOverX = (length, t): t =>

View File

@@ -240,7 +240,7 @@ module T = Dist({
       Ok(0.0)
     } else {
       combinePointwise(
-        ~fn=PointSetDist_Scoring.KLDivergence.logScore,
+        ~fn=PointSetDist_Scoring.KLDivergence.logScore(~eps=MagicNumbers.Epsilon.ten),
         base,
         reference,
       ) |> E.R2.bind(integralEndYResult)

View File

@@ -311,9 +311,11 @@ module T = Dist({
     if referenceIsZero {
       Ok(0.0)
     } else {
-      combinePointwise(PointSetDist_Scoring.KLDivergence.logScore, base, reference) |> E.R.fmap(
-        integralEndY,
-      )
+      combinePointwise(
+        PointSetDist_Scoring.KLDivergence.logScore(~eps=MagicNumbers.Epsilon.ten),
+        base,
+        reference,
+      ) |> E.R.fmap(integralEndY)
     }
   }
 })

View File

@@ -1,7 +1,8 @@
 module KLDivergence = {
   let logFn = Js.Math.log
   let subtraction = (a, b) => Ok(a -. b)
-  let logScore = (a: float, b: float): result<float, Operation.Error.t> =>
+  let multiply = (a: float, b: float): result<float, Operation.Error.t> => Ok(a *. b)
+  let logScoreDirect = (a: float, b: float): result<float, Operation.Error.t> =>
     if a == 0.0 {
       Error(Operation.NegativeInfinityError)
     } else if b == 0.0 {
@@ -10,5 +11,15 @@ module KLDivergence = {
       let quot = a /. b
       quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(b *. logFn(quot))
     }
-  let multiply = (a: float, b: float): result<float, Operation.Error.t> => Ok(a *. b)
+  let logScoreWithThreshold = (~eps: float, a: float, b: float): result<float, Operation.Error.t> =>
+    if abs_float(a) < eps {
+      Ok(0.0)
+    } else {
+      logScoreDirect(a, b)
+    }
+  let logScore = (~eps: option<float>=?, a: float, b: float): result<float, Operation.Error.t> =>
+    switch eps {
+    | None => logScoreDirect(a, b)
+    | Some(eps') => logScoreWithThreshold(~eps=eps', a, b)
+    }
 }