progress on klDivergence (still working)
Value: [1e-5 to 1e-2]
Parent: 32a881d06a
Commit: b49865d3aa
@@ -0,0 +1,34 @@
+open Jest
+open Expect
+open TestHelpers
+
+describe("kl divergence", () => {
+  let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
+  test("", () => {
+    exception KlFailed
+    let lowAnswer = 4.3526e0
+    let highAnswer = 8.5382e0
+    let lowPrediction = 4.3526e0
+    let highPrediction = 1.2345e1
+    let answer =
+      uniformMakeR(lowAnswer, highAnswer)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
+    let prediction =
+      uniformMakeR(lowPrediction, highPrediction)->E.R2.errMap(s => DistributionTypes.ArgumentError(
+        s,
+      ))
+    // integral along the support of the answer of answer.pdf(x) times log of prediction.pdf(x) divided by answer.pdf(x) dx
+    let analyticalKl =
+      -1.0 /.
+      (highAnswer -. lowAnswer) *.
+      Js.Math.log((highAnswer -. lowAnswer) /. (highPrediction -. lowPrediction)) *.
+      (highAnswer -. lowAnswer)
+    let kl = E.R.liftJoin2(klDivergence, prediction, answer)
+    switch kl {
+    | Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
+    | Error(err) => {
+        Js.Console.log(DistributionTypes.Error.toString(err))
+        raise(KlFailed)
+      }
+    }
+  })
+})
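For reference, a sketch of the closed form that analyticalKl encodes (derived here, not part of the diff): for uniform distributions whose answer support [l_a, h_a] lies inside the prediction support [l_p, h_p],

    D_{KL}(\text{answer} \,\|\, \text{prediction})
      = \int_{l_a}^{h_a} \frac{1}{h_a - l_a}
        \log\!\frac{1/(h_a - l_a)}{1/(h_p - l_p)}\, dx
      = \log\frac{h_p - l_p}{h_a - l_a}

With the test's numbers this is log((12.345 - 4.3526) / (8.5382 - 4.3526)), roughly 0.647, which is what the analyticalKl expression reduces to.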
@@ -10,8 +10,8 @@ type env = {
 }

 let defaultEnv = {
-  sampleCount: 10000,
-  xyPointLength: 10000,
+  sampleCount: MagicNumbers.Environment.defaultSampleCount,
+  xyPointLength: MagicNumbers.Environment.defaultXYPointLength,
 }

 type outputType =
@@ -128,7 +128,7 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
   let fromDistFn = (
     subFnName: DistributionTypes.DistributionOperation.fromDist,
     dist: genericDist,
-  ) => {
+  ): outputType => {
     let response = switch subFnName {
     | ToFloat(distToFloatOperation) =>
       GenericDist.toFloatOperation(dist, ~toPointSetFn, ~distToFloatOperation)
@@ -261,7 +261,7 @@ module Constructors = {
   let pdf = (~env, dist, f) => C.pdf(dist, f)->run(~env)->toFloatR
   let normalize = (~env, dist) => C.normalize(dist)->run(~env)->toDistR
   let isNormalized = (~env, dist) => C.isNormalized(dist)->run(~env)->toBoolR
-  let logScore = (~env, dist1, dist2) => C.logScore(dist1, dist2)->run(~env)->toFloatR
+  let klDivergence = (~env, dist1, dist2) => C.klDivergence(dist1, dist2)->run(~env)->toFloatR
   let toPointSet = (~env, dist) => C.toPointSet(dist)->run(~env)->toDistR
   let toSampleSet = (~env, dist, n) => C.toSampleSet(dist, n)->run(~env)->toDistR
   let fromSamples = (~env, xs) => C.fromSamples(xs)->run(~env)->toDistR
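A hedged usage sketch for the newly exposed constructor, modeled directly on the test added in this commit; env, uniformMakeR, and the E helpers come from the test helpers, not from this hunk, and the concrete bounds are placeholders:

    // Partially apply the scorer with an evaluation environment.
    let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
    // Build two uniform distributions; uniformMakeR returns a result, so map its error.
    let prediction = uniformMakeR(0.0, 2.0)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
    let answer = uniformMakeR(0.0, 1.0)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
    // Lift the two-argument scorer over the two results.
    let kl = E.R.liftJoin2(klDivergence, prediction, answer)
    // Expected, following the test's analytical convention: Ok(x) with x close to log(2.0),
    // since the answer's support is half as wide as the prediction's.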
@@ -60,7 +60,7 @@ module Constructors: {
   @genType
   let isNormalized: (~env: env, genericDist) => result<bool, error>
   @genType
-  let logScore: (~env: env, genericDist, genericDist) => result<float, error>
+  let klDivergence: (~env: env, genericDist, genericDist) => result<float, error>
   @genType
   let toPointSet: (~env: env, genericDist) => result<genericDist, error>
   @genType
@@ -160,7 +160,7 @@ module Constructors = {
   let fromSamples = (xs): t => FromSamples(xs)
   let truncate = (dist, left, right): t => FromDist(ToDist(Truncate(left, right)), dist)
   let inspect = (dist): t => FromDist(ToDist(Inspect), dist)
-  let logScore = (dist1, dist2): t => FromDist(ToScore(KLDivergence(dist2)), dist1)
+  let klDivergence = (dist1, dist2): t => FromDist(ToScore(KLDivergence(dist2)), dist1)
   let scalePower = (dist, n): t => FromDist(ToDist(Scale(#Power, n)), dist)
   let scaleLogarithm = (dist, n): t => FromDist(ToDist(Scale(#Logarithm, n)), dist)
   let scaleLogarithmWithThreshold = (dist, n, eps): t => FromDist(
@@ -270,24 +270,10 @@ module T = Dist({
   let variance = (t: t): float =>
     XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)

-  let klDivergence = (base: t, reference: t) => {
-    let referenceIsZero = switch Distributions.Common.isZeroEverywhere(
-      PointSetTypes.Continuous(reference),
-    ) {
-    | Continuous(b) => b
-    | _ => false
-    }
-    if referenceIsZero {
-      Ok(0.0)
-    } else {
-      combinePointwise(
-        PointSetDist_Scoring.KLDivergence.logScore(~eps=MagicNumbers.Epsilon.ten),
-        base,
-        reference,
-      )
-      |> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
-      |> E.R.fmap(integralEndY)
-    }
+  let klDivergence = (prediction: t, answer: t) => {
+    combinePointwise(PointSetDist_Scoring.KLDivergence.integrand, prediction, answer)
+    |> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
+    |> E.R.fmap(integralEndY)
   }
 })

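A sketch of what the simplified continuous path now computes, reading the new pipeline: combinePointwise evaluates the scoring integrand at each shared x, filterYValues(Js.Float.isFinite) drops non-finite values, and integralEndY integrates what remains, so the result is approximately

    D_{KL}(\text{answer} \,\|\, \text{prediction})
      \approx \int f\big(p(x), a(x)\big)\, dx,
    \qquad f(p, a) = -\,a \log\frac{p}{a} = a \log\frac{a}{p}

with p and a the prediction and answer densities; the old isZeroEverywhere shortcut and the ~eps threshold are dropped, presumably because the integrand now handles zeros itself.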
@@ -229,21 +229,11 @@ module T = Dist({
     XYShape.Analysis.getVarianceDangerously(t, mean, getMeanOfSquares)
   }

-  let klDivergence = (base: t, reference: t) => {
-    let referenceIsZero = switch Distributions.Common.isZeroEverywhere(
-      PointSetTypes.Discrete(reference),
-    ) {
-    | Discrete(b) => b
-    | _ => false
-    }
-    if referenceIsZero {
-      Ok(0.0)
-    } else {
-      combinePointwise(
-        ~fn=PointSetDist_Scoring.KLDivergence.logScore(~eps=MagicNumbers.Epsilon.ten),
-        base,
-        reference,
-      ) |> E.R2.bind(integralEndYResult)
-    }
+  let klDivergence = (prediction: t, answer: t) => {
+    combinePointwise(
+      ~fn=PointSetDist_Scoring.KLDivergence.integrand,
+      prediction,
+      answer,
+    ) |> E.R2.bind(integralEndYResult)
   }
 })
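The discrete version follows the same pattern, summing the integrand over the shared support instead of integrating it, roughly

    D_{KL}(\text{answer} \,\|\, \text{prediction})
      \approx \sum_i -\,a_i \log\frac{p_i}{a_i}

where p_i and a_i are the prediction and answer masses at each shared x-value.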
@@ -96,18 +96,4 @@ module Common = {
       None
     | (Some(s1), Some(s2)) => combineFn(s1, s2)
     }
-
-  let isZeroEverywhere = (d: PointSetTypes.pointSetDist) => {
-    let isZero = (x: float): bool => x == 0.0
-    PointSetTypes.ShapeMonad.fmap(
-      d,
-      (
-        mixed =>
-          E.A.all(isZero, mixed.continuous.xyShape.ys) &&
-          E.A.all(isZero, mixed.discrete.xyShape.ys),
-        disc => E.A.all(isZero, disc.xyShape.ys),
-        cont => E.A.all(isZero, cont.xyShape.ys),
-      ),
-    )
-  }
 }
@@ -301,22 +301,10 @@ module T = Dist({
     }
   }

-  let klDivergence = (base: t, reference: t) => {
-    let referenceIsZero = switch Distributions.Common.isZeroEverywhere(
-      PointSetTypes.Mixed(reference),
-    ) {
-    | Mixed(b) => b
-    | _ => false
-    }
-    if referenceIsZero {
-      Ok(0.0)
-    } else {
-      combinePointwise(
-        PointSetDist_Scoring.KLDivergence.logScore(~eps=MagicNumbers.Epsilon.ten),
-        base,
-        reference,
-      ) |> E.R.fmap(integralEndY)
-    }
+  let klDivergence = (prediction: t, answer: t) => {
+    combinePointwise(PointSetDist_Scoring.KLDivergence.integrand, prediction, answer) |> E.R.fmap(
+      integralEndY,
+    )
   }
 })

@@ -202,9 +202,9 @@ module T = Dist({
     | (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2)
     | (Mixed(t1), Mixed(t2)) => Mixed.T.klDivergence(t1, t2)
     | _ => {
-        let t1 = toMixed(t1)
-        let t2 = toMixed(t2)
-        Mixed.T.klDivergence(t1, t2)
+        let t1' = toMixed(t1)
+        let t2' = toMixed(t2)
+        Mixed.T.klDivergence(t1', t2')
       }
     }
   })
@@ -1,25 +1,15 @@
 module KLDivergence = {
   let logFn = Js.Math.log
-  let subtraction = (a, b) => Ok(a -. b)
-  let multiply = (a: float, b: float): result<float, Operation.Error.t> => Ok(a *. b)
-  let logScoreDirect = (a: float, b: float): result<float, Operation.Error.t> =>
-    if a == 0.0 {
+  let integrand = (predictionElement: float, answerElement: float): result<
+    float,
+    Operation.Error.t,
+  > =>
+    if predictionElement == 0.0 {
       Error(Operation.NegativeInfinityError)
-    } else if b == 0.0 {
-      Ok(b)
+    } else if answerElement == 0.0 {
+      Ok(answerElement)
     } else {
-      let quot = a /. b
-      quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(b *. logFn(quot))
+      let quot = predictionElement /. answerElement
+      quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answerElement *. logFn(quot))
     }
-  let logScoreWithThreshold = (~eps: float, a: float, b: float): result<float, Operation.Error.t> =>
-    if abs_float(a) < eps {
-      Ok(0.0)
-    } else {
-      logScoreDirect(a, b)
-    }
-  let logScore = (~eps: option<float>=?, a: float, b: float): result<float, Operation.Error.t> =>
-    switch eps {
-    | None => logScoreDirect(a, b)
-    | Some(eps') => logScoreWithThreshold(~eps=eps', a, b)
-    }
 }
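Read as math, the new integrand is the pointwise KL term with its failure modes made explicit (an interpretation of the diff, not an authoritative spec):

    f(p, a) = \begin{cases}
      \text{NegativeInfinityError} & p = 0 \\
      0 & a = 0 \\
      \text{ComplexNumberError} & p / a < 0 \\
      -\,a \log(p / a) & \text{otherwise}
    \end{cases}

Integrating or summing f over the support gives D_{KL}(answer ∥ prediction); the added minus sign, compared with the old logScore's b *. logFn(quot), is what makes the accumulated result the non-negative divergence rather than its negation.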
@@ -6,6 +6,7 @@ module Math = {
 module Epsilon = {
   let ten = 1e-10
   let seven = 1e-7
+  let five = 1e-5
 }

 module Environment = {
@@ -468,7 +468,7 @@ module Range = {
   // TODO: I think this isn't needed by any functions anymore.
   let stepsToContinuous = t => {
     // TODO: It would be nicer if this the diff didn't change the first element, and also maybe if there were a more elegant way of doing this.
-    let diff = T.xTotalRange(t) |> (r => r *. 0.00001)
+    let diff = T.xTotalRange(t) |> (r => r *. MagicNumbers.Epsilon.five)
     let items = switch E.A.toRanges(Belt.Array.zip(t.xs, t.ys)) {
     | Ok(items) =>
       Some(