progress on klDivergence (still working)
Value: [1e-5 to 1e-2]
This commit is contained in:
parent
32a881d06a
commit
b49865d3aa
|
@ -0,0 +1,34 @@
|
|||
open Jest
|
||||
open Expect
|
||||
open TestHelpers
|
||||
|
||||
describe("kl divergence", () => {
|
||||
let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
|
||||
test("", () => {
|
||||
exception KlFailed
|
||||
let lowAnswer = 4.3526e0
|
||||
let highAnswer = 8.5382e0
|
||||
let lowPrediction = 4.3526e0
|
||||
let highPrediction = 1.2345e1
|
||||
let answer =
|
||||
uniformMakeR(lowAnswer, highAnswer)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
|
||||
let prediction =
|
||||
uniformMakeR(lowPrediction, highPrediction)->E.R2.errMap(s => DistributionTypes.ArgumentError(
|
||||
s,
|
||||
))
|
||||
// integral along the support of the answer of answer.pdf(x) times log of prediction.pdf(x) divided by answer.pdf(x) dx
|
||||
let analyticalKl =
|
||||
-1.0 /.
|
||||
(highAnswer -. lowAnswer) *.
|
||||
Js.Math.log((highAnswer -. lowAnswer) /. (highPrediction -. lowPrediction)) *.
|
||||
(highAnswer -. lowAnswer)
|
||||
let kl = E.R.liftJoin2(klDivergence, prediction, answer)
|
||||
switch kl {
|
||||
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
|
||||
| Error(err) => {
|
||||
Js.Console.log(DistributionTypes.Error.toString(err))
|
||||
raise(KlFailed)
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
|
@ -10,8 +10,8 @@ type env = {
|
|||
}
|
||||
|
||||
let defaultEnv = {
|
||||
sampleCount: 10000,
|
||||
xyPointLength: 10000,
|
||||
sampleCount: MagicNumbers.Environment.defaultSampleCount,
|
||||
xyPointLength: MagicNumbers.Environment.defaultXYPointLength,
|
||||
}
|
||||
|
||||
type outputType =
|
||||
|
@ -128,7 +128,7 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
|
|||
let fromDistFn = (
|
||||
subFnName: DistributionTypes.DistributionOperation.fromDist,
|
||||
dist: genericDist,
|
||||
) => {
|
||||
): outputType => {
|
||||
let response = switch subFnName {
|
||||
| ToFloat(distToFloatOperation) =>
|
||||
GenericDist.toFloatOperation(dist, ~toPointSetFn, ~distToFloatOperation)
|
||||
|
@ -261,7 +261,7 @@ module Constructors = {
|
|||
let pdf = (~env, dist, f) => C.pdf(dist, f)->run(~env)->toFloatR
|
||||
let normalize = (~env, dist) => C.normalize(dist)->run(~env)->toDistR
|
||||
let isNormalized = (~env, dist) => C.isNormalized(dist)->run(~env)->toBoolR
|
||||
let logScore = (~env, dist1, dist2) => C.logScore(dist1, dist2)->run(~env)->toFloatR
|
||||
let klDivergence = (~env, dist1, dist2) => C.klDivergence(dist1, dist2)->run(~env)->toFloatR
|
||||
let toPointSet = (~env, dist) => C.toPointSet(dist)->run(~env)->toDistR
|
||||
let toSampleSet = (~env, dist, n) => C.toSampleSet(dist, n)->run(~env)->toDistR
|
||||
let fromSamples = (~env, xs) => C.fromSamples(xs)->run(~env)->toDistR
|
||||
|
|
|
@ -60,7 +60,7 @@ module Constructors: {
|
|||
@genType
|
||||
let isNormalized: (~env: env, genericDist) => result<bool, error>
|
||||
@genType
|
||||
let logScore: (~env: env, genericDist, genericDist) => result<float, error>
|
||||
let klDivergence: (~env: env, genericDist, genericDist) => result<float, error>
|
||||
@genType
|
||||
let toPointSet: (~env: env, genericDist) => result<genericDist, error>
|
||||
@genType
|
||||
|
|
|
@ -160,7 +160,7 @@ module Constructors = {
|
|||
let fromSamples = (xs): t => FromSamples(xs)
|
||||
let truncate = (dist, left, right): t => FromDist(ToDist(Truncate(left, right)), dist)
|
||||
let inspect = (dist): t => FromDist(ToDist(Inspect), dist)
|
||||
let logScore = (dist1, dist2): t => FromDist(ToScore(KLDivergence(dist2)), dist1)
|
||||
let klDivergence = (dist1, dist2): t => FromDist(ToScore(KLDivergence(dist2)), dist1)
|
||||
let scalePower = (dist, n): t => FromDist(ToDist(Scale(#Power, n)), dist)
|
||||
let scaleLogarithm = (dist, n): t => FromDist(ToDist(Scale(#Logarithm, n)), dist)
|
||||
let scaleLogarithmWithThreshold = (dist, n, eps): t => FromDist(
|
||||
|
|
|
@ -270,25 +270,11 @@ module T = Dist({
|
|||
let variance = (t: t): float =>
|
||||
XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)
|
||||
|
||||
let klDivergence = (base: t, reference: t) => {
|
||||
let referenceIsZero = switch Distributions.Common.isZeroEverywhere(
|
||||
PointSetTypes.Continuous(reference),
|
||||
) {
|
||||
| Continuous(b) => b
|
||||
| _ => false
|
||||
}
|
||||
if referenceIsZero {
|
||||
Ok(0.0)
|
||||
} else {
|
||||
combinePointwise(
|
||||
PointSetDist_Scoring.KLDivergence.logScore(~eps=MagicNumbers.Epsilon.ten),
|
||||
base,
|
||||
reference,
|
||||
)
|
||||
let klDivergence = (prediction: t, answer: t) => {
|
||||
combinePointwise(PointSetDist_Scoring.KLDivergence.integrand, prediction, answer)
|
||||
|> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
|
||||
|> E.R.fmap(integralEndY)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
let isNormalized = (t: t): bool => {
|
||||
|
|
|
@ -229,21 +229,11 @@ module T = Dist({
|
|||
XYShape.Analysis.getVarianceDangerously(t, mean, getMeanOfSquares)
|
||||
}
|
||||
|
||||
let klDivergence = (base: t, reference: t) => {
|
||||
let referenceIsZero = switch Distributions.Common.isZeroEverywhere(
|
||||
PointSetTypes.Discrete(reference),
|
||||
) {
|
||||
| Discrete(b) => b
|
||||
| _ => false
|
||||
}
|
||||
if referenceIsZero {
|
||||
Ok(0.0)
|
||||
} else {
|
||||
let klDivergence = (prediction: t, answer: t) => {
|
||||
combinePointwise(
|
||||
~fn=PointSetDist_Scoring.KLDivergence.logScore(~eps=MagicNumbers.Epsilon.ten),
|
||||
base,
|
||||
reference,
|
||||
~fn=PointSetDist_Scoring.KLDivergence.integrand,
|
||||
prediction,
|
||||
answer,
|
||||
) |> E.R2.bind(integralEndYResult)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
|
|
@ -96,18 +96,4 @@ module Common = {
|
|||
None
|
||||
| (Some(s1), Some(s2)) => combineFn(s1, s2)
|
||||
}
|
||||
|
||||
let isZeroEverywhere = (d: PointSetTypes.pointSetDist) => {
|
||||
let isZero = (x: float): bool => x == 0.0
|
||||
PointSetTypes.ShapeMonad.fmap(
|
||||
d,
|
||||
(
|
||||
mixed =>
|
||||
E.A.all(isZero, mixed.continuous.xyShape.ys) &&
|
||||
E.A.all(isZero, mixed.discrete.xyShape.ys),
|
||||
disc => E.A.all(isZero, disc.xyShape.ys),
|
||||
cont => E.A.all(isZero, cont.xyShape.ys),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -301,22 +301,10 @@ module T = Dist({
|
|||
}
|
||||
}
|
||||
|
||||
let klDivergence = (base: t, reference: t) => {
|
||||
let referenceIsZero = switch Distributions.Common.isZeroEverywhere(
|
||||
PointSetTypes.Mixed(reference),
|
||||
) {
|
||||
| Mixed(b) => b
|
||||
| _ => false
|
||||
}
|
||||
if referenceIsZero {
|
||||
Ok(0.0)
|
||||
} else {
|
||||
combinePointwise(
|
||||
PointSetDist_Scoring.KLDivergence.logScore(~eps=MagicNumbers.Epsilon.ten),
|
||||
base,
|
||||
reference,
|
||||
) |> E.R.fmap(integralEndY)
|
||||
}
|
||||
let klDivergence = (prediction: t, answer: t) => {
|
||||
combinePointwise(PointSetDist_Scoring.KLDivergence.integrand, prediction, answer) |> E.R.fmap(
|
||||
integralEndY,
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
|
|
|
@ -202,9 +202,9 @@ module T = Dist({
|
|||
| (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2)
|
||||
| (Mixed(t1), Mixed(t2)) => Mixed.T.klDivergence(t1, t2)
|
||||
| _ => {
|
||||
let t1 = toMixed(t1)
|
||||
let t2 = toMixed(t2)
|
||||
Mixed.T.klDivergence(t1, t2)
|
||||
let t1' = toMixed(t1)
|
||||
let t2' = toMixed(t2)
|
||||
Mixed.T.klDivergence(t1', t2')
|
||||
}
|
||||
}
|
||||
})
|
||||
|
|
|
@ -1,25 +1,15 @@
|
|||
module KLDivergence = {
|
||||
let logFn = Js.Math.log
|
||||
let subtraction = (a, b) => Ok(a -. b)
|
||||
let multiply = (a: float, b: float): result<float, Operation.Error.t> => Ok(a *. b)
|
||||
let logScoreDirect = (a: float, b: float): result<float, Operation.Error.t> =>
|
||||
if a == 0.0 {
|
||||
let integrand = (predictionElement: float, answerElement: float): result<
|
||||
float,
|
||||
Operation.Error.t,
|
||||
> =>
|
||||
if predictionElement == 0.0 {
|
||||
Error(Operation.NegativeInfinityError)
|
||||
} else if b == 0.0 {
|
||||
Ok(b)
|
||||
} else if answerElement == 0.0 {
|
||||
Ok(answerElement)
|
||||
} else {
|
||||
let quot = a /. b
|
||||
quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(b *. logFn(quot))
|
||||
}
|
||||
let logScoreWithThreshold = (~eps: float, a: float, b: float): result<float, Operation.Error.t> =>
|
||||
if abs_float(a) < eps {
|
||||
Ok(0.0)
|
||||
} else {
|
||||
logScoreDirect(a, b)
|
||||
}
|
||||
let logScore = (~eps: option<float>=?, a: float, b: float): result<float, Operation.Error.t> =>
|
||||
switch eps {
|
||||
| None => logScoreDirect(a, b)
|
||||
| Some(eps') => logScoreWithThreshold(~eps=eps', a, b)
|
||||
let quot = predictionElement /. answerElement
|
||||
quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answerElement *. logFn(quot))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ module Math = {
|
|||
module Epsilon = {
|
||||
let ten = 1e-10
|
||||
let seven = 1e-7
|
||||
let five = 1e-5
|
||||
}
|
||||
|
||||
module Environment = {
|
||||
|
|
|
@ -468,7 +468,7 @@ module Range = {
|
|||
// TODO: I think this isn't needed by any functions anymore.
|
||||
let stepsToContinuous = t => {
|
||||
// TODO: It would be nicer if this the diff didn't change the first element, and also maybe if there were a more elegant way of doing this.
|
||||
let diff = T.xTotalRange(t) |> (r => r *. 0.00001)
|
||||
let diff = T.xTotalRange(t) |> (r => r *. MagicNumbers.Epsilon.five)
|
||||
let items = switch E.A.toRanges(Belt.Array.zip(t.xs, t.ys)) {
|
||||
| Ok(items) =>
|
||||
Some(
|
||||
|
|
Loading…
Reference in New Issue
Block a user