From 898547f3a3f7e05af9cc6618a1cfc867f8f85989 Mon Sep 17 00:00:00 2001 From: Quinn Dougherty Date: Wed, 4 May 2022 13:53:32 -0400 Subject: [PATCH] `klDivergence` is now `LogarithmWithThreshold` --- .../Distributions/PointSetDist/Continuous.res | 9 +++++++-- .../Distributions/PointSetDist/Discrete.res | 2 +- .../rescript/Distributions/PointSetDist/Mixed.res | 8 +++++--- .../PointSetDist/PointSetDist_Scoring.res | 15 +++++++++++++-- 4 files changed, 26 insertions(+), 8 deletions(-) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res index 93f6c1c9..c713900f 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res @@ -280,7 +280,11 @@ module T = Dist({ if referenceIsZero { Ok(0.0) } else { - combinePointwise(PointSetDist_Scoring.KLDivergence.logScore, base, reference) + combinePointwise( + PointSetDist_Scoring.KLDivergence.logScore(~eps=MagicNumbers.Epsilon.seven), + base, + reference, + ) |> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite))) |> E.R.fmap(integralEndY) } @@ -289,7 +293,8 @@ module T = Dist({ let isNormalized = (t: t): bool => { let areaUnderIntegral = t |> updateIntegralCache(Some(T.integral(t))) |> T.integralEndY - areaUnderIntegral < 1. +. 1e-7 && areaUnderIntegral > 1. -. 1e-7 + areaUnderIntegral < 1. +. MagicNumbers.Epsilon.seven && + areaUnderIntegral > 1. -. MagicNumbers.Epsilon.seven } let downsampleEquallyOverX = (length, t): t => diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res index 8ec9410f..372eb3d1 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res @@ -240,7 +240,7 @@ module T = Dist({ Ok(0.0) } else { combinePointwise( - ~fn=PointSetDist_Scoring.KLDivergence.logScore, + ~fn=PointSetDist_Scoring.KLDivergence.logScore(~eps=MagicNumbers.Epsilon.ten), base, reference, ) |> E.R2.bind(integralEndYResult) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res index 66f353f6..1521a13c 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res @@ -311,9 +311,11 @@ module T = Dist({ if referenceIsZero { Ok(0.0) } else { - combinePointwise(PointSetDist_Scoring.KLDivergence.logScore, base, reference) |> E.R.fmap( - integralEndY, - ) + combinePointwise( + PointSetDist_Scoring.KLDivergence.logScore(~eps=MagicNumbers.Epsilon.ten), + base, + reference, + ) |> E.R.fmap(integralEndY) } } }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res index 4a607281..d43aeba6 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res @@ -1,7 +1,8 @@ module KLDivergence = { let logFn = Js.Math.log let subtraction = (a, b) => Ok(a -. b) - let logScore = (a: float, b: float): result => + let multiply = (a: float, b: float): result => Ok(a *. b) + let logScoreDirect = (a: float, b: float): result => if a == 0.0 { Error(Operation.NegativeInfinityError) } else if b == 0.0 { @@ -10,5 +11,15 @@ module KLDivergence = { let quot = a /. b quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(b *. logFn(quot)) } - let multiply = (a: float, b: float): result => Ok(a *. b) + let logScoreWithThreshold = (~eps: float, a: float, b: float): result => + if abs_float(a) < eps { + Ok(0.0) + } else { + logScoreDirect(a, b) + } + let logScore = (~eps: option=?, a: float, b: float): result => + switch eps { + | None => logScoreDirect(a, b) + | Some(eps') => logScoreWithThreshold(~eps=eps', a, b) + } }