progress on klDivergence (still working)

Value: [1e-5 to 1e-2]
This commit is contained in:
Quinn Dougherty 2022-05-05 15:37:28 -04:00
parent 32a881d06a
commit b49865d3aa
12 changed files with 68 additions and 93 deletions

View File

@ -0,0 +1,34 @@
open Jest
open Expect
open TestHelpers
describe("kl divergence", () => {
let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
test("", () => {
exception KlFailed
let lowAnswer = 4.3526e0
let highAnswer = 8.5382e0
let lowPrediction = 4.3526e0
let highPrediction = 1.2345e1
let answer =
uniformMakeR(lowAnswer, highAnswer)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
let prediction =
uniformMakeR(lowPrediction, highPrediction)->E.R2.errMap(s => DistributionTypes.ArgumentError(
s,
))
// integral along the support of the answer of answer.pdf(x) times log of prediction.pdf(x) divided by answer.pdf(x) dx
let analyticalKl =
-1.0 /.
(highAnswer -. lowAnswer) *.
Js.Math.log((highAnswer -. lowAnswer) /. (highPrediction -. lowPrediction)) *.
(highAnswer -. lowAnswer)
let kl = E.R.liftJoin2(klDivergence, prediction, answer)
switch kl {
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
| Error(err) => {
Js.Console.log(DistributionTypes.Error.toString(err))
raise(KlFailed)
}
}
})
})

View File

@ -10,8 +10,8 @@ type env = {
} }
let defaultEnv = { let defaultEnv = {
sampleCount: 10000, sampleCount: MagicNumbers.Environment.defaultSampleCount,
xyPointLength: 10000, xyPointLength: MagicNumbers.Environment.defaultXYPointLength,
} }
type outputType = type outputType =
@ -128,7 +128,7 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
let fromDistFn = ( let fromDistFn = (
subFnName: DistributionTypes.DistributionOperation.fromDist, subFnName: DistributionTypes.DistributionOperation.fromDist,
dist: genericDist, dist: genericDist,
) => { ): outputType => {
let response = switch subFnName { let response = switch subFnName {
| ToFloat(distToFloatOperation) => | ToFloat(distToFloatOperation) =>
GenericDist.toFloatOperation(dist, ~toPointSetFn, ~distToFloatOperation) GenericDist.toFloatOperation(dist, ~toPointSetFn, ~distToFloatOperation)
@ -261,7 +261,7 @@ module Constructors = {
let pdf = (~env, dist, f) => C.pdf(dist, f)->run(~env)->toFloatR let pdf = (~env, dist, f) => C.pdf(dist, f)->run(~env)->toFloatR
let normalize = (~env, dist) => C.normalize(dist)->run(~env)->toDistR let normalize = (~env, dist) => C.normalize(dist)->run(~env)->toDistR
let isNormalized = (~env, dist) => C.isNormalized(dist)->run(~env)->toBoolR let isNormalized = (~env, dist) => C.isNormalized(dist)->run(~env)->toBoolR
let logScore = (~env, dist1, dist2) => C.logScore(dist1, dist2)->run(~env)->toFloatR let klDivergence = (~env, dist1, dist2) => C.klDivergence(dist1, dist2)->run(~env)->toFloatR
let toPointSet = (~env, dist) => C.toPointSet(dist)->run(~env)->toDistR let toPointSet = (~env, dist) => C.toPointSet(dist)->run(~env)->toDistR
let toSampleSet = (~env, dist, n) => C.toSampleSet(dist, n)->run(~env)->toDistR let toSampleSet = (~env, dist, n) => C.toSampleSet(dist, n)->run(~env)->toDistR
let fromSamples = (~env, xs) => C.fromSamples(xs)->run(~env)->toDistR let fromSamples = (~env, xs) => C.fromSamples(xs)->run(~env)->toDistR

View File

@ -60,7 +60,7 @@ module Constructors: {
@genType @genType
let isNormalized: (~env: env, genericDist) => result<bool, error> let isNormalized: (~env: env, genericDist) => result<bool, error>
@genType @genType
let logScore: (~env: env, genericDist, genericDist) => result<float, error> let klDivergence: (~env: env, genericDist, genericDist) => result<float, error>
@genType @genType
let toPointSet: (~env: env, genericDist) => result<genericDist, error> let toPointSet: (~env: env, genericDist) => result<genericDist, error>
@genType @genType

View File

@ -160,7 +160,7 @@ module Constructors = {
let fromSamples = (xs): t => FromSamples(xs) let fromSamples = (xs): t => FromSamples(xs)
let truncate = (dist, left, right): t => FromDist(ToDist(Truncate(left, right)), dist) let truncate = (dist, left, right): t => FromDist(ToDist(Truncate(left, right)), dist)
let inspect = (dist): t => FromDist(ToDist(Inspect), dist) let inspect = (dist): t => FromDist(ToDist(Inspect), dist)
let logScore = (dist1, dist2): t => FromDist(ToScore(KLDivergence(dist2)), dist1) let klDivergence = (dist1, dist2): t => FromDist(ToScore(KLDivergence(dist2)), dist1)
let scalePower = (dist, n): t => FromDist(ToDist(Scale(#Power, n)), dist) let scalePower = (dist, n): t => FromDist(ToDist(Scale(#Power, n)), dist)
let scaleLogarithm = (dist, n): t => FromDist(ToDist(Scale(#Logarithm, n)), dist) let scaleLogarithm = (dist, n): t => FromDist(ToDist(Scale(#Logarithm, n)), dist)
let scaleLogarithmWithThreshold = (dist, n, eps): t => FromDist( let scaleLogarithmWithThreshold = (dist, n, eps): t => FromDist(

View File

@ -270,24 +270,10 @@ module T = Dist({
let variance = (t: t): float => let variance = (t: t): float =>
XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares) XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)
let klDivergence = (base: t, reference: t) => { let klDivergence = (prediction: t, answer: t) => {
let referenceIsZero = switch Distributions.Common.isZeroEverywhere( combinePointwise(PointSetDist_Scoring.KLDivergence.integrand, prediction, answer)
PointSetTypes.Continuous(reference), |> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
) { |> E.R.fmap(integralEndY)
| Continuous(b) => b
| _ => false
}
if referenceIsZero {
Ok(0.0)
} else {
combinePointwise(
PointSetDist_Scoring.KLDivergence.logScore(~eps=MagicNumbers.Epsilon.ten),
base,
reference,
)
|> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
|> E.R.fmap(integralEndY)
}
} }
}) })

View File

@ -229,21 +229,11 @@ module T = Dist({
XYShape.Analysis.getVarianceDangerously(t, mean, getMeanOfSquares) XYShape.Analysis.getVarianceDangerously(t, mean, getMeanOfSquares)
} }
let klDivergence = (base: t, reference: t) => { let klDivergence = (prediction: t, answer: t) => {
let referenceIsZero = switch Distributions.Common.isZeroEverywhere( combinePointwise(
PointSetTypes.Discrete(reference), ~fn=PointSetDist_Scoring.KLDivergence.integrand,
) { prediction,
| Discrete(b) => b answer,
| _ => false ) |> E.R2.bind(integralEndYResult)
}
if referenceIsZero {
Ok(0.0)
} else {
combinePointwise(
~fn=PointSetDist_Scoring.KLDivergence.logScore(~eps=MagicNumbers.Epsilon.ten),
base,
reference,
) |> E.R2.bind(integralEndYResult)
}
} }
}) })

View File

@ -96,18 +96,4 @@ module Common = {
None None
| (Some(s1), Some(s2)) => combineFn(s1, s2) | (Some(s1), Some(s2)) => combineFn(s1, s2)
} }
let isZeroEverywhere = (d: PointSetTypes.pointSetDist) => {
let isZero = (x: float): bool => x == 0.0
PointSetTypes.ShapeMonad.fmap(
d,
(
mixed =>
E.A.all(isZero, mixed.continuous.xyShape.ys) &&
E.A.all(isZero, mixed.discrete.xyShape.ys),
disc => E.A.all(isZero, disc.xyShape.ys),
cont => E.A.all(isZero, cont.xyShape.ys),
),
)
}
} }

View File

@ -301,22 +301,10 @@ module T = Dist({
} }
} }
let klDivergence = (base: t, reference: t) => { let klDivergence = (prediction: t, answer: t) => {
let referenceIsZero = switch Distributions.Common.isZeroEverywhere( combinePointwise(PointSetDist_Scoring.KLDivergence.integrand, prediction, answer) |> E.R.fmap(
PointSetTypes.Mixed(reference), integralEndY,
) { )
| Mixed(b) => b
| _ => false
}
if referenceIsZero {
Ok(0.0)
} else {
combinePointwise(
PointSetDist_Scoring.KLDivergence.logScore(~eps=MagicNumbers.Epsilon.ten),
base,
reference,
) |> E.R.fmap(integralEndY)
}
} }
}) })

View File

@ -202,9 +202,9 @@ module T = Dist({
| (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2) | (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2)
| (Mixed(t1), Mixed(t2)) => Mixed.T.klDivergence(t1, t2) | (Mixed(t1), Mixed(t2)) => Mixed.T.klDivergence(t1, t2)
| _ => { | _ => {
let t1 = toMixed(t1) let t1' = toMixed(t1)
let t2 = toMixed(t2) let t2' = toMixed(t2)
Mixed.T.klDivergence(t1, t2) Mixed.T.klDivergence(t1', t2')
} }
} }
}) })

View File

@ -1,25 +1,15 @@
module KLDivergence = { module KLDivergence = {
let logFn = Js.Math.log let logFn = Js.Math.log
let subtraction = (a, b) => Ok(a -. b) let integrand = (predictionElement: float, answerElement: float): result<
let multiply = (a: float, b: float): result<float, Operation.Error.t> => Ok(a *. b) float,
let logScoreDirect = (a: float, b: float): result<float, Operation.Error.t> => Operation.Error.t,
if a == 0.0 { > =>
if predictionElement == 0.0 {
Error(Operation.NegativeInfinityError) Error(Operation.NegativeInfinityError)
} else if b == 0.0 { } else if answerElement == 0.0 {
Ok(b) Ok(answerElement)
} else { } else {
let quot = a /. b let quot = predictionElement /. answerElement
quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(b *. logFn(quot)) quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answerElement *. logFn(quot))
}
let logScoreWithThreshold = (~eps: float, a: float, b: float): result<float, Operation.Error.t> =>
if abs_float(a) < eps {
Ok(0.0)
} else {
logScoreDirect(a, b)
}
let logScore = (~eps: option<float>=?, a: float, b: float): result<float, Operation.Error.t> =>
switch eps {
| None => logScoreDirect(a, b)
| Some(eps') => logScoreWithThreshold(~eps=eps', a, b)
} }
} }

View File

@ -6,6 +6,7 @@ module Math = {
module Epsilon = { module Epsilon = {
let ten = 1e-10 let ten = 1e-10
let seven = 1e-7 let seven = 1e-7
let five = 1e-5
} }
module Environment = { module Environment = {

View File

@ -468,7 +468,7 @@ module Range = {
// TODO: I think this isn't needed by any functions anymore. // TODO: I think this isn't needed by any functions anymore.
let stepsToContinuous = t => { let stepsToContinuous = t => {
// TODO: It would be nicer if this the diff didn't change the first element, and also maybe if there were a more elegant way of doing this. // TODO: It would be nicer if this the diff didn't change the first element, and also maybe if there were a more elegant way of doing this.
let diff = T.xTotalRange(t) |> (r => r *. 0.00001) let diff = T.xTotalRange(t) |> (r => r *. MagicNumbers.Epsilon.five)
let items = switch E.A.toRanges(Belt.Array.zip(t.xs, t.ys)) { let items = switch E.A.toRanges(Belt.Array.zip(t.xs, t.ys)) {
| Ok(items) => | Ok(items) =>
Some( Some(