tweak: Add tests for combineAlongSupportOfSecondArgument
This commit is contained in:
parent
cc3db79a2a
commit
5dd272fb0c
|
@ -6,10 +6,10 @@ describe("kl divergence", () => {
|
||||||
let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
|
let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
|
||||||
exception KlFailed
|
exception KlFailed
|
||||||
test("of two uniforms is equal to the analytic expression", () => {
|
test("of two uniforms is equal to the analytic expression", () => {
|
||||||
let lowAnswer = 2.3526e0
|
let lowAnswer = 0.0
|
||||||
let highAnswer = 8.5382e0
|
let highAnswer = 1.0
|
||||||
let lowPrediction = 2.3526e0
|
let lowPrediction = 0.0
|
||||||
let highPrediction = 1.2345e1
|
let highPrediction = 2.0
|
||||||
let answer =
|
let answer =
|
||||||
uniformMakeR(lowAnswer, highAnswer)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
|
uniformMakeR(lowAnswer, highAnswer)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
|
||||||
let prediction =
|
let prediction =
|
||||||
|
@ -19,6 +19,8 @@ describe("kl divergence", () => {
|
||||||
// integral along the support of the answer of answer.pdf(x) times log of prediction.pdf(x) divided by answer.pdf(x) dx
|
// integral along the support of the answer of answer.pdf(x) times log of prediction.pdf(x) divided by answer.pdf(x) dx
|
||||||
let analyticalKl = Js.Math.log((highPrediction -. lowPrediction) /. (highAnswer -. lowAnswer))
|
let analyticalKl = Js.Math.log((highPrediction -. lowPrediction) /. (highAnswer -. lowAnswer))
|
||||||
let kl = E.R.liftJoin2(klDivergence, prediction, answer)
|
let kl = E.R.liftJoin2(klDivergence, prediction, answer)
|
||||||
|
Js.Console.log2("Analytical: ", analyticalKl)
|
||||||
|
Js.Console.log2("Computed: ", kl)
|
||||||
switch kl {
|
switch kl {
|
||||||
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
|
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
|
||||||
| Error(err) => {
|
| Error(err) => {
|
||||||
|
@ -32,7 +34,7 @@ describe("kl divergence", () => {
|
||||||
let mean1 = 4.0
|
let mean1 = 4.0
|
||||||
let mean2 = 1.0
|
let mean2 = 1.0
|
||||||
let stdev1 = 1.0
|
let stdev1 = 1.0
|
||||||
let stdev2 = 1.0
|
let stdev2 = 4.0
|
||||||
|
|
||||||
let prediction =
|
let prediction =
|
||||||
normalMakeR(mean1, stdev1)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
|
normalMakeR(mean1, stdev1)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
|
||||||
|
@ -42,6 +44,10 @@ describe("kl divergence", () => {
|
||||||
stdev1 ** 2.0 /. 2.0 /. stdev2 ** 2.0 +.
|
stdev1 ** 2.0 /. 2.0 /. stdev2 ** 2.0 +.
|
||||||
(mean1 -. mean2) ** 2.0 /. 2.0 /. stdev2 ** 2.0 -. 0.5
|
(mean1 -. mean2) ** 2.0 /. 2.0 /. stdev2 ** 2.0 -. 0.5
|
||||||
let kl = E.R.liftJoin2(klDivergence, prediction, answer)
|
let kl = E.R.liftJoin2(klDivergence, prediction, answer)
|
||||||
|
|
||||||
|
Js.Console.log2("Analytical: ", analyticalKl)
|
||||||
|
Js.Console.log2("Computed: ", kl)
|
||||||
|
|
||||||
switch kl {
|
switch kl {
|
||||||
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
|
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
|
||||||
| Error(err) => {
|
| Error(err) => {
|
||||||
|
@ -51,3 +57,33 @@ describe("kl divergence", () => {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
describe("combine along support test", () => {
|
||||||
|
let combineAlongSupportOfSecondArgument = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument
|
||||||
|
let lowAnswer = 0.0
|
||||||
|
let highAnswer = 1.0
|
||||||
|
let lowPrediction = -1.0
|
||||||
|
let highPrediction = 2.0
|
||||||
|
|
||||||
|
let answer =
|
||||||
|
uniformMakeR(lowAnswer, highAnswer)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
|
||||||
|
let prediction =
|
||||||
|
uniformMakeR(lowPrediction, highPrediction)->E.R2.errMap(s => DistributionTypes.ArgumentError(
|
||||||
|
s,
|
||||||
|
))
|
||||||
|
let answerWrapped = E.R.fmap(a => run(FromDist(ToDist(ToPointSet), a)), answer)
|
||||||
|
let predictionWrapped = E.R.fmap(a => run(FromDist(ToDist(ToPointSet), a)), prediction)
|
||||||
|
|
||||||
|
let interpolator = XYShape.XtoY.continuousInterpolator(#Stepwise, #UseZero)
|
||||||
|
let integrand = PointSetDist_Scoring.KLDivergence.integrand
|
||||||
|
|
||||||
|
let result = switch (answerWrapped, predictionWrapped) {
|
||||||
|
| (Ok(Dist(PointSet(Continuous(a)))), Ok(Dist(PointSet(Continuous(b))))) =>
|
||||||
|
Some(combineAlongSupportOfSecondArgument(integrand, interpolator, a.xyShape, b.xyShape))
|
||||||
|
| _ => None
|
||||||
|
}
|
||||||
|
test("combine along support test", _ => {
|
||||||
|
Js.Console.log2("combineAlongSupportOfSecondArgument", result)
|
||||||
|
false->expect->toBe(true)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
|
@ -10,6 +10,6 @@ module KLDivergence = {
|
||||||
Error(Operation.NegativeInfinityError)
|
Error(Operation.NegativeInfinityError)
|
||||||
} else {
|
} else {
|
||||||
let quot = predictionElement /. answerElement
|
let quot = predictionElement /. answerElement
|
||||||
quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(answerElement *. logFn(quot))
|
quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answerElement *. logFn(quot))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -391,7 +391,7 @@ module PointwiseCombination = {
|
||||||
`)
|
`)
|
||||||
|
|
||||||
// This function is used for kl divergence
|
// This function is used for kl divergence
|
||||||
let combineAlongSupportOfSecondArgument: (
|
let combineAlongSupportOfSecondArgument0: (
|
||||||
(float, float) => result<float, Operation.Error.t>,
|
(float, float) => result<float, Operation.Error.t>,
|
||||||
interpolator,
|
interpolator,
|
||||||
T.t,
|
T.t,
|
||||||
|
@ -489,12 +489,12 @@ module PointwiseCombination = {
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
let combineAlongSupportOfSecondArgument2: (
|
let combineAlongSupportOfSecondArgument: (
|
||||||
(float, float) => result<float, Operation.Error.t>,
|
(float, float) => result<float, Operation.Error.t>,
|
||||||
|
interpolator,
|
||||||
T.t,
|
T.t,
|
||||||
T.t,
|
T.t,
|
||||||
) => result<T.t, Operation.Error.t> = (fn, prediction, answer) => {
|
) => result<T.t, Operation.Error.t> = (fn, interpolator, prediction, answer) => {
|
||||||
let newXs = answer.xs
|
|
||||||
let combineWithFn = (x: float, i: int) => {
|
let combineWithFn = (x: float, i: int) => {
|
||||||
let answerX = x
|
let answerX = x
|
||||||
let answerY = answer.ys[i]
|
let answerY = answer.ys[i]
|
||||||
|
@ -506,9 +506,15 @@ module PointwiseCombination = {
|
||||||
}
|
}
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
let newYs = Js.Array.mapi((x, i) => combineWithFn(x, i), answer.xs)
|
let newYsWithError = Js.Array.mapi((x, i) => combineWithFn(x, i), answer.xs)
|
||||||
|
let newYsOrError = E.A.R.firstErrorOrOpen(newYsWithError)
|
||||||
|
let result = switch newYsOrError {
|
||||||
|
| Ok(a) => Ok({xs: answer.xs, ys: a})
|
||||||
|
| Error(b) => Error(b)
|
||||||
|
}
|
||||||
|
|
||||||
T.filterOkYs(newXs, newYs)->Ok
|
// T.filterOkYs(newXs, newYs)->Ok
|
||||||
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
let addCombine = (interpolator: interpolator, t1: T.t, t2: T.t): T.t =>
|
let addCombine = (interpolator: interpolator, t1: T.t, t2: T.t): T.t =>
|
||||||
|
|
Loading…
Reference in New Issue
Block a user