combineAlongSupportOfSecondArgument implemented, tests still failing

Value: [1e-4 to 4e-2]

parent b49865d3aa
commit dcf56d7bc6
@@ -4,11 +4,11 @@ open TestHelpers
 describe("kl divergence", () => {
   let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
-  test("", () => {
-    let lowAnswer = 4.3526e0
+  exception KlFailed
+  test("of two uniforms is equal to the analytic expression", () => {
+    let lowAnswer = 2.3526e0
     let highAnswer = 8.5382e0
-    let lowPrediction = 4.3526e0
+    let lowPrediction = 2.3526e0
     let highPrediction = 1.2345e1
     let answer =
       uniformMakeR(lowAnswer, highAnswer)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
@@ -17,11 +17,30 @@ describe("kl divergence", () => {
         s,
       ))
-    let analyticalKl =
-      -1.0 /.
-      (highAnswer -. lowAnswer) *.
-      Js.Math.log((highAnswer -. lowAnswer) /. (highPrediction -. lowPrediction)) *.
-      (highAnswer -. lowAnswer)
+    // integral along the support of the answer of answer.pdf(x) times log of prediction.pdf(x) divided by answer.pdf(x) dx
+    let analyticalKl = Js.Math.log((highPrediction -. lowPrediction) /. (highAnswer -. lowAnswer))
     let kl = E.R.liftJoin2(klDivergence, prediction, answer)
     switch kl {
     | Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
     | Error(err) => {
+        Js.Console.log(DistributionTypes.Error.toString(err))
+        raise(KlFailed)
+      }
+    }
+  })
+  test("of two normals is equal to the formula", () => {
+    // This test case comes via Nuño https://github.com/quantified-uncertainty/squiggle/issues/433
+    let mean1 = 4.0
+    let mean2 = 1.0
+    let stdev1 = 1.0
+    let stdev2 = 1.0
+
+    let prediction =
+      normalMakeR(mean1, stdev1)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
+    let answer = normalMakeR(mean2, stdev2)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
+    let analyticalKl =
+      Js.Math.log(stdev2 /. stdev1) +.
+      stdev1 ** 2.0 /. 2.0 /. stdev2 ** 2.0 +.
+      (mean1 -. mean2) ** 2.0 /. 2.0 /. stdev2 ** 2.0 -. 0.5
+    let kl = E.R.liftJoin2(klDivergence, prediction, answer)
+    switch kl {
+    | Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
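For reference, the expected values in these tests correspond to the standard closed-form KL divergences, stated here with generic symbols; the uniform case assumes the answer's support [l_a, h_a] lies inside the prediction's [l_p, h_p], as in the test values above:

  D_{KL}(\text{answer} \parallel \text{prediction}) = \int \text{answer}(x) \, \ln\frac{\text{answer}(x)}{\text{prediction}(x)} \, dx

  Uniforms, with [l_a, h_a] \subseteq [l_p, h_p]:
  D_{KL} = \ln\frac{h_p - l_p}{h_a - l_a}

  Normals:
  D_{KL}(\mathcal{N}(\mu_1, \sigma_1^2) \parallel \mathcal{N}(\mu_2, \sigma_2^2)) = \ln\frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2 \sigma_2^2} - \frac{1}{2}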
@@ -38,19 +38,6 @@ describe("XYShapes", () => {
     )
   })

-  describe("logScorePoint", () => {
-    makeTest("When identical", XYShape.logScorePoint(30, pointSetDist1, pointSetDist1), Some(0.0))
-    makeTest(
-      "When similar",
-      XYShape.logScorePoint(30, pointSetDist1, pointSetDist2),
-      Some(1.658971191043856),
-    )
-    makeTest(
-      "When very different",
-      XYShape.logScorePoint(30, pointSetDist1, pointSetDist3),
-      Some(210.3721280423322),
-    )
-  })
   describe("integrateWithTriangles", () =>
     makeTest(
       "integrates correctly",
@@ -86,6 +86,7 @@ let stepwiseToLinear = (t: t): t =>

 // Note: This results in a distribution with as many points as the sum of those in t1 and t2.
 let combinePointwise = (
+  ~combiner=XYShape.PointwiseCombination.combine,
   ~integralSumCachesFn=(_, _) => None,
   ~distributionType: PointSetTypes.distributionType=#PDF,
   fn: (float, float) => result<float, Operation.Error.t>,
@@ -119,7 +120,7 @@ let combinePointwise = (
   let interpolator = XYShape.XtoY.continuousInterpolator(t1.interpolation, extrapolation)

-  XYShape.PointwiseCombination.combine(fn, interpolator, t1.xyShape, t2.xyShape)->E.R2.fmap(x =>
+  combiner(fn, interpolator, t1.xyShape, t2.xyShape)->E.R2.fmap(x =>
     make(~integralSumCache=combinedIntegralSum, x)
   )
 }
@@ -271,7 +272,12 @@ module T = Dist({
     XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)

   let klDivergence = (prediction: t, answer: t) => {
-    combinePointwise(PointSetDist_Scoring.KLDivergence.integrand, prediction, answer)
+    combinePointwise(
+      ~combiner=XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument,
+      PointSetDist_Scoring.KLDivergence.integrand,
+      prediction,
+      answer,
+    )
+    |> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
     |> E.R.fmap(integralEndY)
   }
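In outline, the rewritten continuous klDivergence builds the pointwise shape x ↦ integrand(prediction(x), answer(x)) restricted to the answer's support, drops non-finite y values with XYShape.T.filterYValues(Js.Float.isFinite), and integrates the result. Assuming integralEndY returns the total area under the combined shape, the returned value is approximately

  \int_{\operatorname{supp}(\text{answer})} \text{integrand}(\text{prediction}(x), \text{answer}(x)) \, dx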
@@ -33,6 +33,7 @@ let shapeFn = (fn, t: t) => t |> getShape |> fn
 let lastY = (t: t) => t |> getShape |> XYShape.T.lastY

 let combinePointwise = (
+  ~combiner=XYShape.PointwiseCombination.combine,
   ~integralSumCachesFn=(_, _) => None,
   ~fn=(a, b) => Ok(a +. b),
   t1: PointSetTypes.discreteShape,
@@ -48,12 +49,10 @@ let combinePointwise = (
   // It could be done for pointwise additions, but is that ever needed?

   make(
-    XYShape.PointwiseCombination.combine(
-      fn,
-      XYShape.XtoY.discreteInterpolator,
-      t1.xyShape,
-      t2.xyShape,
-    )->E.R.toExn("Addition operation should never fail", _),
+    combiner(fn, XYShape.XtoY.discreteInterpolator, t1.xyShape, t2.xyShape)->E.R.toExn(
+      "Addition operation should never fail",
+      _,
+    ),
   )->Ok
 }
@@ -231,6 +230,7 @@ module T = Dist({

   let klDivergence = (prediction: t, answer: t) => {
     combinePointwise(
+      ~combiner=XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument,
       ~fn=PointSetDist_Scoring.KLDivergence.integrand,
       prediction,
       answer,
@@ -48,7 +48,7 @@ let combinePointwise = (
     |> E.A.fmap(toDiscrete)
     |> E.A.O.concatSomes
     |> Discrete.reduce(~integralSumCachesFn, fn)
-    |> E.R.toExn("foo")
+    |> E.R.toExn("Theoretically unreachable state")

   let reducedContinuous =
     [t1, t2]
@@ -1,5 +1,5 @@
 module KLDivergence = {
-  let logFn = Js.Math.log
+  let logFn = Js.Math.log // base e
   let integrand = (predictionElement: float, answerElement: float): result<
     float,
     Operation.Error.t,
@@ -7,9 +7,9 @@ module KLDivergence = {
     if predictionElement == 0.0 {
       Error(Operation.NegativeInfinityError)
     } else if answerElement == 0.0 {
-      Ok(answerElement)
+      Ok(0.0)
     } else {
       let quot = predictionElement /. answerElement
-      quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answerElement *. logFn(quot))
+      quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(answerElement *. logFn(quot))
     }
   }
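With these fixes the integrand is f(p, a) = a \ln(p / a), and the branches mirror its limits: p = 0 is rejected because \ln 0 = -\infty; a = 0 now returns 0, which matches \lim_{a \to 0^+} a \ln(p / a) = 0; and a negative quotient p / a would require a complex logarithm, hence ComplexNumberError.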
@@ -97,7 +97,20 @@ module T = {
   let equallyDividedXs = (t: t, newLength) => E.A.Floats.range(minX(t), maxX(t), newLength)
   let toJs = (t: t) => {"xs": t.xs, "ys": t.ys}
+  let filterYValues = (fn, t: t): t => t |> zip |> E.A.filter(((_, y)) => fn(y)) |> fromZippedArray
+
+  let filterOkYs = (xs: array<float>, ys: array<result<float, 'b>>): t => {
+    let n = E.A.length(xs) // Assume length(xs) == length(ys)
+    let newXs = []
+    let newYs = []
+    for i in 0 to n - 1 {
+      switch ys[i] {
+      | Ok(y) =>
+        let _ = Js.Array.push(xs[i], newXs)
+        let _ = Js.Array.push(y, newYs)
+      | Error(_) => ()
+      }
+    }
+    {xs: newXs, ys: newYs}
+  }
 module Validator = {
   let fnName = "XYShape validate"
   let notSortedError = (p: string): error => NotSorted(p)
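A minimal illustration of the new helper, using invented values (the error constructor is only an example; any error type fits the 'b parameter):

  let xs = [1., 2., 3.]
  let ys = [Ok(5.), Error(Operation.ComplexNumberError), Ok(7.)]
  let filtered = XYShape.T.filterOkYs(xs, ys) // {xs: [1., 3.], ys: [5., 7.]}

Pairs whose y is an Error are dropped, and xs stays aligned with ys because both arrays are pushed in the same branch.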
@@ -377,6 +390,64 @@ module PointwiseCombination = {
     }
   `)

+  // This function is used for kl divergence
+  let combineAlongSupportOfSecondArgument: (
+    (float, float) => result<float, Operation.Error.t>,
+    interpolator,
+    T.t,
+    T.t,
+  ) => result<T.t, Operation.Error.t> = (fn, interpolator, t1, t2) => {
+    let newYs = []
+    let newXs = []
+    let (l1, l2) = (E.A.length(t1.xs), E.A.length(t2.xs))
+    let (i, j) = (ref(0), ref(0))
+    let minX = t2.xs[0]
+    let maxX = t2.xs[l2 - 1]
+    while j.contents < l2 - 1 && i.contents < l1 - 1 {
+      let (x, y1, y2) = {
+        let x1 = t1.xs[i.contents + 1]
+        let x2 = t2.xs[j.contents + 1]
+        /* if t1 has to catch up to t2 */ if (
+          i.contents < l1 - 1 && j.contents < l2 && x1 < x2 && minX <= x1 && x2 <= maxX
+        ) {
+          i := i.contents + 1
+          let x = x1
+          let y1 = t1.ys[i.contents]
+          let y2 = interpolator(t2, j.contents, x)
+          (x, y1, y2)
+        } else if (
+          /* if t2 has to catch up to t1 */
+          i.contents < l1 && j.contents < l2 - 1 && x1 > x2 && x2 >= minX && maxX >= x1
+        ) {
+          j := j.contents + 1
+          let x = x2
+          let y1 = interpolator(t1, i.contents, x)
+          let y2 = t2.ys[j.contents]
+          (x, y1, y2)
+        } else if (
+          /* move both ahead if they are equal */
+          i.contents < l1 - 1 && j.contents < l2 - 1 && x1 == x2 && x1 >= minX && maxX >= x2
+        ) {
+          i := i.contents + 1
+          j := j.contents + 1
+          let x = x1
+          let y1 = t1.ys[i.contents]
+          let y2 = t2.ys[j.contents]
+          (x, y1, y2)
+        } else {
+          i := i.contents + 1
+          (0.0, 0.0, 0.0) // for the function I have in mind, this will error out
+          // exception PointwiseCombinationError
+          // raise(PointwiseCombinationError)
+        }
+      }
+      // Js.Console.log(newYs)
+      let _ = Js.Array.push(fn(y1, y2), newYs)
+      let _ = Js.Array.push(x, newXs)
+    }
+    T.filterOkYs(newXs, newYs)->Ok
+  }
+
   let addCombine = (interpolator: interpolator, t1: T.t, t2: T.t): T.t =>
     combine((a, b) => Ok(a +. b), interpolator, t1, t2)->E.R.toExn(
       "Add operation should never fail",
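In outline, combineAlongSupportOfSecondArgument is a two-pointer merge over t1.xs and t2.xs, clipped to [t2.xs[0], t2.xs[l2 - 1]]: at each step, whichever shape lacks a point at the chosen x is filled in by the interpolator, so roughly

  X_{\text{out}} \subseteq (X_1 \cup X_2) \cap [\min X_2, \max X_2], \qquad y_{\text{out}}(x) = fn(\hat{y}_1(x), \hat{y}_2(x))

and x values whose fn result is an Error are discarded afterwards by T.filterOkYs. The final else branch pushes (0.0, 0.0, 0.0), which, as the comment notes, is expected to error out for the integrand this was written for.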
@@ -490,25 +561,6 @@ module Range = {
     }
   }

-let pointLogScore = (prediction, answer) =>
-  switch answer {
-  | 0. => 0.0
-  | answer => answer *. Js.Math.log2(Js.Math.abs_float(prediction /. answer))
-  }
-
-let logScorePoint = (sampleCount, t1, t2) =>
-  PointwiseCombination.combineEvenXs(
-    ~fn=pointLogScore,
-    ~xToYSelection=XtoY.linear,
-    sampleCount,
-    t1,
-    t2,
-  )
-  |> Range.integrateWithTriangles
-  |> E.O.fmap(T.accumulateYs(\"+."))
-  |> E.O.fmap(Pairs.last)
-  |> E.O.fmap(Pairs.y)
-
 module Analysis = {
   let getVarianceDangerously = (t: 't, mean: 't => float, getMeanOfSquares: 't => float): float => {
     let meanSquared = mean(t) ** 2.0