combineAlongSupportOfSecondArgument
implemented, tests still failing
Value: [1e-4 to 4e-2]
This commit is contained in:
parent
b49865d3aa
commit
dcf56d7bc6
|
@ -4,11 +4,11 @@ open TestHelpers
|
||||||
|
|
||||||
describe("kl divergence", () => {
|
describe("kl divergence", () => {
|
||||||
let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
|
let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
|
||||||
test("", () => {
|
exception KlFailed
|
||||||
exception KlFailed
|
test("of two uniforms is equal to the analytic expression", () => {
|
||||||
let lowAnswer = 4.3526e0
|
let lowAnswer = 2.3526e0
|
||||||
let highAnswer = 8.5382e0
|
let highAnswer = 8.5382e0
|
||||||
let lowPrediction = 4.3526e0
|
let lowPrediction = 2.3526e0
|
||||||
let highPrediction = 1.2345e1
|
let highPrediction = 1.2345e1
|
||||||
let answer =
|
let answer =
|
||||||
uniformMakeR(lowAnswer, highAnswer)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
|
uniformMakeR(lowAnswer, highAnswer)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
|
||||||
|
@ -17,11 +17,30 @@ describe("kl divergence", () => {
|
||||||
s,
|
s,
|
||||||
))
|
))
|
||||||
// integral along the support of the answer of answer.pdf(x) times log of prediction.pdf(x) divided by answer.pdf(x) dx
|
// integral along the support of the answer of answer.pdf(x) times log of prediction.pdf(x) divided by answer.pdf(x) dx
|
||||||
|
let analyticalKl = Js.Math.log((highPrediction -. lowPrediction) /. (highAnswer -. lowAnswer))
|
||||||
|
let kl = E.R.liftJoin2(klDivergence, prediction, answer)
|
||||||
|
switch kl {
|
||||||
|
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
|
||||||
|
| Error(err) => {
|
||||||
|
Js.Console.log(DistributionTypes.Error.toString(err))
|
||||||
|
raise(KlFailed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
test("of two normals is equal to the formula", () => {
|
||||||
|
// This test case comes via Nuño https://github.com/quantified-uncertainty/squiggle/issues/433
|
||||||
|
let mean1 = 4.0
|
||||||
|
let mean2 = 1.0
|
||||||
|
let stdev1 = 1.0
|
||||||
|
let stdev2 = 1.0
|
||||||
|
|
||||||
|
let prediction =
|
||||||
|
normalMakeR(mean1, stdev1)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
|
||||||
|
let answer = normalMakeR(mean2, stdev2)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
|
||||||
let analyticalKl =
|
let analyticalKl =
|
||||||
-1.0 /.
|
Js.Math.log(stdev2 /. stdev1) +.
|
||||||
(highAnswer -. lowAnswer) *.
|
stdev1 ** 2.0 /. 2.0 /. stdev2 ** 2.0 +.
|
||||||
Js.Math.log((highAnswer -. lowAnswer) /. (highPrediction -. lowPrediction)) *.
|
(mean1 -. mean2) ** 2.0 /. 2.0 /. stdev2 ** 2.0 -. 0.5
|
||||||
(highAnswer -. lowAnswer)
|
|
||||||
let kl = E.R.liftJoin2(klDivergence, prediction, answer)
|
let kl = E.R.liftJoin2(klDivergence, prediction, answer)
|
||||||
switch kl {
|
switch kl {
|
||||||
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
|
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
|
||||||
|
|
|
@ -38,19 +38,6 @@ describe("XYShapes", () => {
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
describe("logScorePoint", () => {
|
|
||||||
makeTest("When identical", XYShape.logScorePoint(30, pointSetDist1, pointSetDist1), Some(0.0))
|
|
||||||
makeTest(
|
|
||||||
"When similar",
|
|
||||||
XYShape.logScorePoint(30, pointSetDist1, pointSetDist2),
|
|
||||||
Some(1.658971191043856),
|
|
||||||
)
|
|
||||||
makeTest(
|
|
||||||
"When very different",
|
|
||||||
XYShape.logScorePoint(30, pointSetDist1, pointSetDist3),
|
|
||||||
Some(210.3721280423322),
|
|
||||||
)
|
|
||||||
})
|
|
||||||
describe("integrateWithTriangles", () =>
|
describe("integrateWithTriangles", () =>
|
||||||
makeTest(
|
makeTest(
|
||||||
"integrates correctly",
|
"integrates correctly",
|
||||||
|
|
|
@ -86,6 +86,7 @@ let stepwiseToLinear = (t: t): t =>
|
||||||
|
|
||||||
// Note: This results in a distribution with as many points as the sum of those in t1 and t2.
|
// Note: This results in a distribution with as many points as the sum of those in t1 and t2.
|
||||||
let combinePointwise = (
|
let combinePointwise = (
|
||||||
|
~combiner=XYShape.PointwiseCombination.combine,
|
||||||
~integralSumCachesFn=(_, _) => None,
|
~integralSumCachesFn=(_, _) => None,
|
||||||
~distributionType: PointSetTypes.distributionType=#PDF,
|
~distributionType: PointSetTypes.distributionType=#PDF,
|
||||||
fn: (float, float) => result<float, Operation.Error.t>,
|
fn: (float, float) => result<float, Operation.Error.t>,
|
||||||
|
@ -119,7 +120,7 @@ let combinePointwise = (
|
||||||
|
|
||||||
let interpolator = XYShape.XtoY.continuousInterpolator(t1.interpolation, extrapolation)
|
let interpolator = XYShape.XtoY.continuousInterpolator(t1.interpolation, extrapolation)
|
||||||
|
|
||||||
XYShape.PointwiseCombination.combine(fn, interpolator, t1.xyShape, t2.xyShape)->E.R2.fmap(x =>
|
combiner(fn, interpolator, t1.xyShape, t2.xyShape)->E.R2.fmap(x =>
|
||||||
make(~integralSumCache=combinedIntegralSum, x)
|
make(~integralSumCache=combinedIntegralSum, x)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -271,7 +272,12 @@ module T = Dist({
|
||||||
XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)
|
XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)
|
||||||
|
|
||||||
let klDivergence = (prediction: t, answer: t) => {
|
let klDivergence = (prediction: t, answer: t) => {
|
||||||
combinePointwise(PointSetDist_Scoring.KLDivergence.integrand, prediction, answer)
|
combinePointwise(
|
||||||
|
~combiner=XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument,
|
||||||
|
PointSetDist_Scoring.KLDivergence.integrand,
|
||||||
|
prediction,
|
||||||
|
answer,
|
||||||
|
)
|
||||||
|> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
|
|> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
|
||||||
|> E.R.fmap(integralEndY)
|
|> E.R.fmap(integralEndY)
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,6 +33,7 @@ let shapeFn = (fn, t: t) => t |> getShape |> fn
|
||||||
let lastY = (t: t) => t |> getShape |> XYShape.T.lastY
|
let lastY = (t: t) => t |> getShape |> XYShape.T.lastY
|
||||||
|
|
||||||
let combinePointwise = (
|
let combinePointwise = (
|
||||||
|
~combiner=XYShape.PointwiseCombination.combine,
|
||||||
~integralSumCachesFn=(_, _) => None,
|
~integralSumCachesFn=(_, _) => None,
|
||||||
~fn=(a, b) => Ok(a +. b),
|
~fn=(a, b) => Ok(a +. b),
|
||||||
t1: PointSetTypes.discreteShape,
|
t1: PointSetTypes.discreteShape,
|
||||||
|
@ -48,12 +49,10 @@ let combinePointwise = (
|
||||||
// It could be done for pointwise additions, but is that ever needed?
|
// It could be done for pointwise additions, but is that ever needed?
|
||||||
|
|
||||||
make(
|
make(
|
||||||
XYShape.PointwiseCombination.combine(
|
combiner(fn, XYShape.XtoY.discreteInterpolator, t1.xyShape, t2.xyShape)->E.R.toExn(
|
||||||
fn,
|
"Addition operation should never fail",
|
||||||
XYShape.XtoY.discreteInterpolator,
|
_,
|
||||||
t1.xyShape,
|
),
|
||||||
t2.xyShape,
|
|
||||||
)->E.R.toExn("Addition operation should never fail", _),
|
|
||||||
)->Ok
|
)->Ok
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -231,6 +230,7 @@ module T = Dist({
|
||||||
|
|
||||||
let klDivergence = (prediction: t, answer: t) => {
|
let klDivergence = (prediction: t, answer: t) => {
|
||||||
combinePointwise(
|
combinePointwise(
|
||||||
|
~combiner=XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument,
|
||||||
~fn=PointSetDist_Scoring.KLDivergence.integrand,
|
~fn=PointSetDist_Scoring.KLDivergence.integrand,
|
||||||
prediction,
|
prediction,
|
||||||
answer,
|
answer,
|
||||||
|
|
|
@ -48,7 +48,7 @@ let combinePointwise = (
|
||||||
|> E.A.fmap(toDiscrete)
|
|> E.A.fmap(toDiscrete)
|
||||||
|> E.A.O.concatSomes
|
|> E.A.O.concatSomes
|
||||||
|> Discrete.reduce(~integralSumCachesFn, fn)
|
|> Discrete.reduce(~integralSumCachesFn, fn)
|
||||||
|> E.R.toExn("foo")
|
|> E.R.toExn("Theoretically unreachable state")
|
||||||
|
|
||||||
let reducedContinuous =
|
let reducedContinuous =
|
||||||
[t1, t2]
|
[t1, t2]
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
module KLDivergence = {
|
module KLDivergence = {
|
||||||
let logFn = Js.Math.log
|
let logFn = Js.Math.log // base e
|
||||||
let integrand = (predictionElement: float, answerElement: float): result<
|
let integrand = (predictionElement: float, answerElement: float): result<
|
||||||
float,
|
float,
|
||||||
Operation.Error.t,
|
Operation.Error.t,
|
||||||
|
@ -7,9 +7,9 @@ module KLDivergence = {
|
||||||
if predictionElement == 0.0 {
|
if predictionElement == 0.0 {
|
||||||
Error(Operation.NegativeInfinityError)
|
Error(Operation.NegativeInfinityError)
|
||||||
} else if answerElement == 0.0 {
|
} else if answerElement == 0.0 {
|
||||||
Ok(answerElement)
|
Ok(0.0)
|
||||||
} else {
|
} else {
|
||||||
let quot = predictionElement /. answerElement
|
let quot = predictionElement /. answerElement
|
||||||
quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answerElement *. logFn(quot))
|
quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(answerElement *. logFn(quot))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -97,7 +97,20 @@ module T = {
|
||||||
let equallyDividedXs = (t: t, newLength) => E.A.Floats.range(minX(t), maxX(t), newLength)
|
let equallyDividedXs = (t: t, newLength) => E.A.Floats.range(minX(t), maxX(t), newLength)
|
||||||
let toJs = (t: t) => {"xs": t.xs, "ys": t.ys}
|
let toJs = (t: t) => {"xs": t.xs, "ys": t.ys}
|
||||||
let filterYValues = (fn, t: t): t => t |> zip |> E.A.filter(((_, y)) => fn(y)) |> fromZippedArray
|
let filterYValues = (fn, t: t): t => t |> zip |> E.A.filter(((_, y)) => fn(y)) |> fromZippedArray
|
||||||
|
let filterOkYs = (xs: array<float>, ys: array<result<float, 'b>>): t => {
|
||||||
|
let n = E.A.length(xs) // Assume length(xs) == length(ys)
|
||||||
|
let newXs = []
|
||||||
|
let newYs = []
|
||||||
|
for i in 0 to n - 1 {
|
||||||
|
switch ys[i] {
|
||||||
|
| Ok(y) =>
|
||||||
|
let _ = Js.Array.push(xs[i], newXs)
|
||||||
|
let _ = Js.Array.push(y, newYs)
|
||||||
|
| Error(_) => ()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
{xs: newXs, ys: newYs}
|
||||||
|
}
|
||||||
module Validator = {
|
module Validator = {
|
||||||
let fnName = "XYShape validate"
|
let fnName = "XYShape validate"
|
||||||
let notSortedError = (p: string): error => NotSorted(p)
|
let notSortedError = (p: string): error => NotSorted(p)
|
||||||
|
@ -377,6 +390,64 @@ module PointwiseCombination = {
|
||||||
}
|
}
|
||||||
`)
|
`)
|
||||||
|
|
||||||
|
// This function is used for kl divergence
|
||||||
|
let combineAlongSupportOfSecondArgument: (
|
||||||
|
(float, float) => result<float, Operation.Error.t>,
|
||||||
|
interpolator,
|
||||||
|
T.t,
|
||||||
|
T.t,
|
||||||
|
) => result<T.t, Operation.Error.t> = (fn, interpolator, t1, t2) => {
|
||||||
|
let newYs = []
|
||||||
|
let newXs = []
|
||||||
|
let (l1, l2) = (E.A.length(t1.xs), E.A.length(t2.xs))
|
||||||
|
let (i, j) = (ref(0), ref(0))
|
||||||
|
let minX = t2.xs[0]
|
||||||
|
let maxX = t2.xs[l2 - 1]
|
||||||
|
while j.contents < l2 - 1 && i.contents < l1 - 1 {
|
||||||
|
let (x, y1, y2) = {
|
||||||
|
let x1 = t1.xs[i.contents + 1]
|
||||||
|
let x2 = t2.xs[j.contents + 1]
|
||||||
|
/* if t1 has to catch up to t2 */ if (
|
||||||
|
i.contents < l1 - 1 && j.contents < l2 && x1 < x2 && minX <= x1 && x2 <= maxX
|
||||||
|
) {
|
||||||
|
i := i.contents + 1
|
||||||
|
let x = x1
|
||||||
|
let y1 = t1.ys[i.contents]
|
||||||
|
let y2 = interpolator(t2, j.contents, x)
|
||||||
|
(x, y1, y2)
|
||||||
|
} else if (
|
||||||
|
/* if t2 has to catch up to t1 */
|
||||||
|
i.contents < l1 && j.contents < l2 - 1 && x1 > x2 && x2 >= minX && maxX >= x1
|
||||||
|
) {
|
||||||
|
j := j.contents + 1
|
||||||
|
let x = x2
|
||||||
|
let y1 = interpolator(t1, i.contents, x)
|
||||||
|
let y2 = t2.ys[j.contents]
|
||||||
|
(x, y1, y2)
|
||||||
|
} else if (
|
||||||
|
/* move both ahead if they are equal */
|
||||||
|
i.contents < l1 - 1 && j.contents < l2 - 1 && x1 == x2 && x1 >= minX && maxX >= x2
|
||||||
|
) {
|
||||||
|
i := i.contents + 1
|
||||||
|
j := j.contents + 1
|
||||||
|
let x = x1
|
||||||
|
let y1 = t1.ys[i.contents]
|
||||||
|
let y2 = t2.ys[j.contents]
|
||||||
|
(x, y1, y2)
|
||||||
|
} else {
|
||||||
|
i := i.contents + 1
|
||||||
|
(0.0, 0.0, 0.0) // for the function I have in mind, this will error out
|
||||||
|
// exception PointwiseCombinationError
|
||||||
|
// raise(PointwiseCombinationError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Js.Console.log(newYs)
|
||||||
|
let _ = Js.Array.push(fn(y1, y2), newYs)
|
||||||
|
let _ = Js.Array.push(x, newXs)
|
||||||
|
}
|
||||||
|
T.filterOkYs(newXs, newYs)->Ok
|
||||||
|
}
|
||||||
|
|
||||||
let addCombine = (interpolator: interpolator, t1: T.t, t2: T.t): T.t =>
|
let addCombine = (interpolator: interpolator, t1: T.t, t2: T.t): T.t =>
|
||||||
combine((a, b) => Ok(a +. b), interpolator, t1, t2)->E.R.toExn(
|
combine((a, b) => Ok(a +. b), interpolator, t1, t2)->E.R.toExn(
|
||||||
"Add operation should never fail",
|
"Add operation should never fail",
|
||||||
|
@ -490,25 +561,6 @@ module Range = {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let pointLogScore = (prediction, answer) =>
|
|
||||||
switch answer {
|
|
||||||
| 0. => 0.0
|
|
||||||
| answer => answer *. Js.Math.log2(Js.Math.abs_float(prediction /. answer))
|
|
||||||
}
|
|
||||||
|
|
||||||
let logScorePoint = (sampleCount, t1, t2) =>
|
|
||||||
PointwiseCombination.combineEvenXs(
|
|
||||||
~fn=pointLogScore,
|
|
||||||
~xToYSelection=XtoY.linear,
|
|
||||||
sampleCount,
|
|
||||||
t1,
|
|
||||||
t2,
|
|
||||||
)
|
|
||||||
|> Range.integrateWithTriangles
|
|
||||||
|> E.O.fmap(T.accumulateYs(\"+."))
|
|
||||||
|> E.O.fmap(Pairs.last)
|
|
||||||
|> E.O.fmap(Pairs.y)
|
|
||||||
|
|
||||||
module Analysis = {
|
module Analysis = {
|
||||||
let getVarianceDangerously = (t: 't, mean: 't => float, getMeanOfSquares: 't => float): float => {
|
let getVarianceDangerously = (t: 't, mean: 't => float, getMeanOfSquares: 't => float): float => {
|
||||||
let meanSquared = mean(t) ** 2.0
|
let meanSquared = mean(t) ** 2.0
|
||||||
|
|
Loading…
Reference in New Issue
Block a user