Some Cleanup
Btw, Nuño gets a lot of credit for the last commit! Value: [1e-2 to 8e-2]
This commit is contained in:
parent
722bfc6366
commit
bcf620337a
|
@ -5,40 +5,9 @@ open TestHelpers
|
|||
describe("kl divergence", () => {
|
||||
let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
|
||||
exception KlFailed
|
||||
|
||||
let testUniform = (lowAnswer, highAnswer, lowPrediction, highPrediction) => {
|
||||
test("of two uniforms is equal to the analytic expression", () => {
|
||||
let lowAnswer = 0.0
|
||||
let highAnswer = 1.0
|
||||
let lowPrediction = -1.0
|
||||
let highPrediction = 2.0
|
||||
let answer =
|
||||
uniformMakeR(lowAnswer, highAnswer)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
|
||||
let prediction =
|
||||
uniformMakeR(lowPrediction, highPrediction)->E.R2.errMap(s => DistributionTypes.ArgumentError(
|
||||
s,
|
||||
))
|
||||
// integral along the support of the answer of answer.pdf(x) times log of prediction.pdf(x) divided by answer.pdf(x) dx
|
||||
let analyticalKl = Js.Math.log((highPrediction -. lowPrediction) /. (highAnswer -. lowAnswer))
|
||||
let kl = E.R.liftJoin2(klDivergence, prediction, answer)
|
||||
Js.Console.log2("Analytical: ", analyticalKl)
|
||||
Js.Console.log2("Computed: ", kl)
|
||||
switch kl {
|
||||
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
|
||||
| Error(err) => {
|
||||
Js.Console.log(DistributionTypes.Error.toString(err))
|
||||
raise(KlFailed)
|
||||
}
|
||||
}
|
||||
})
|
||||
test(
|
||||
"of two uniforms is equal to the analytic expression, part 2 (annoying numerical errors)",
|
||||
() => {
|
||||
Js.Console.log(
|
||||
"This will fait because of extremely annoying numerical errors. Will not fail if the two uniforms are a bit different. Very annoying",
|
||||
)
|
||||
let lowAnswer = 0.0
|
||||
let highAnswer = 1.0
|
||||
let lowPrediction = 0.0
|
||||
let highPrediction = 2.0
|
||||
let answer =
|
||||
uniformMakeR(lowAnswer, highAnswer)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
|
||||
let prediction =
|
||||
|
@ -49,8 +18,6 @@ describe("kl divergence", () => {
|
|||
// integral along the support of the answer of answer.pdf(x) times log of prediction.pdf(x) divided by answer.pdf(x) dx
|
||||
let analyticalKl = Js.Math.log((highPrediction -. lowPrediction) /. (highAnswer -. lowAnswer))
|
||||
let kl = E.R.liftJoin2(klDivergence, prediction, answer)
|
||||
Js.Console.log2("Analytical: ", analyticalKl)
|
||||
Js.Console.log2("Computed: ", kl)
|
||||
switch kl {
|
||||
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
|
||||
| Error(err) => {
|
||||
|
@ -58,8 +25,12 @@ describe("kl divergence", () => {
|
|||
raise(KlFailed)
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
})
|
||||
}
|
||||
testUniform(0.0, 1.0, -1.0, 2.0)
|
||||
testUniform(0.0, 1.0, 0.0, 2.0)
|
||||
// testUniform(-1.0, 1.0, 0.0, 2.0)
|
||||
|
||||
test("of two normals is equal to the formula", () => {
|
||||
// This test case comes via Nuño https://github.com/quantified-uncertainty/squiggle/issues/433
|
||||
let mean1 = 4.0
|
||||
|
@ -90,8 +61,7 @@ describe("kl divergence", () => {
|
|||
})
|
||||
|
||||
describe("combine along support test", () => {
|
||||
Skip.test("combine along support test", _ => {
|
||||
// doesn't matter
|
||||
test("combine along support test", _ => {
|
||||
let combineAlongSupportOfSecondArgument = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument0
|
||||
let lowAnswer = 0.0
|
||||
let highAnswer = 1.0
|
||||
|
@ -115,7 +85,27 @@ describe("combine along support test", () => {
|
|||
Some(combineAlongSupportOfSecondArgument(integrand, interpolator, a.xyShape, b.xyShape))
|
||||
| _ => None
|
||||
}
|
||||
Js.Console.log2("combineAlongSupportOfSecondArgument", result)
|
||||
false->expect->toBe(true)
|
||||
result
|
||||
->expect
|
||||
->toEqual(
|
||||
Some(
|
||||
Ok({
|
||||
xs: [
|
||||
0.0,
|
||||
MagicNumbers.Epsilon.ten,
|
||||
2.0 *. MagicNumbers.Epsilon.ten,
|
||||
1.0 -. MagicNumbers.Epsilon.ten,
|
||||
1.0,
|
||||
],
|
||||
ys: [
|
||||
-0.34657359027997264,
|
||||
-0.34657359027997264,
|
||||
-0.34657359027997264,
|
||||
-0.34657359027997264,
|
||||
-0.34657359027997264,
|
||||
],
|
||||
}),
|
||||
),
|
||||
)
|
||||
})
|
||||
})
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
"test": "jest",
|
||||
"test:ts": "jest __tests__/TS/",
|
||||
"test:rescript": "jest --modulePathIgnorePatterns=__tests__/TS/*",
|
||||
"test:kldivergence": "jest __tests__/Distributions/KlDivergence_test.*",
|
||||
"test:watch": "jest --watchAll",
|
||||
"coverage:rescript": "rm -f *.coverage; yarn clean; BISECT_ENABLE=yes yarn build; yarn test:rescript; bisect-ppx-report html",
|
||||
"coverage:ts": "yarn clean; yarn build; nyc --reporter=lcov yarn test:ts",
|
||||
|
|
Loading…
Reference in New Issue
Block a user