Some Cleanup

Btw, Nuño gets a lot of credit for the last commit!

Value: [1e-2 to 8e-2]
This commit is contained in:
Quinn Dougherty 2022-05-06 14:21:53 -04:00
parent 722bfc6366
commit bcf620337a
2 changed files with 32 additions and 43 deletions

View File

@ -5,40 +5,9 @@ open TestHelpers
describe("kl divergence", () => {
let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
exception KlFailed
test("of two uniforms is equal to the analytic expression", () => {
let lowAnswer = 0.0
let highAnswer = 1.0
let lowPrediction = -1.0
let highPrediction = 2.0
let answer =
uniformMakeR(lowAnswer, highAnswer)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
let prediction =
uniformMakeR(lowPrediction, highPrediction)->E.R2.errMap(s => DistributionTypes.ArgumentError(
s,
))
// integral along the support of the answer of answer.pdf(x) times log of prediction.pdf(x) divided by answer.pdf(x) dx
let analyticalKl = Js.Math.log((highPrediction -. lowPrediction) /. (highAnswer -. lowAnswer))
let kl = E.R.liftJoin2(klDivergence, prediction, answer)
Js.Console.log2("Analytical: ", analyticalKl)
Js.Console.log2("Computed: ", kl)
switch kl {
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
| Error(err) => {
Js.Console.log(DistributionTypes.Error.toString(err))
raise(KlFailed)
}
}
})
test(
"of two uniforms is equal to the analytic expression, part 2 (annoying numerical errors)",
() => {
Js.Console.log(
"This will fait because of extremely annoying numerical errors. Will not fail if the two uniforms are a bit different. Very annoying",
)
let lowAnswer = 0.0
let highAnswer = 1.0
let lowPrediction = 0.0
let highPrediction = 2.0
let testUniform = (lowAnswer, highAnswer, lowPrediction, highPrediction) => {
test("of two uniforms is equal to the analytic expression", () => {
let answer =
uniformMakeR(lowAnswer, highAnswer)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
let prediction =
@ -49,8 +18,6 @@ describe("kl divergence", () => {
// integral along the support of the answer of answer.pdf(x) times log of prediction.pdf(x) divided by answer.pdf(x) dx
let analyticalKl = Js.Math.log((highPrediction -. lowPrediction) /. (highAnswer -. lowAnswer))
let kl = E.R.liftJoin2(klDivergence, prediction, answer)
Js.Console.log2("Analytical: ", analyticalKl)
Js.Console.log2("Computed: ", kl)
switch kl {
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
| Error(err) => {
@ -58,8 +25,12 @@ describe("kl divergence", () => {
raise(KlFailed)
}
}
},
)
})
}
testUniform(0.0, 1.0, -1.0, 2.0)
testUniform(0.0, 1.0, 0.0, 2.0)
// testUniform(-1.0, 1.0, 0.0, 2.0)
test("of two normals is equal to the formula", () => {
// This test case comes via Nuño https://github.com/quantified-uncertainty/squiggle/issues/433
let mean1 = 4.0
@ -90,8 +61,7 @@ describe("kl divergence", () => {
})
describe("combine along support test", () => {
Skip.test("combine along support test", _ => {
// doesn't matter
test("combine along support test", _ => {
let combineAlongSupportOfSecondArgument = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument0
let lowAnswer = 0.0
let highAnswer = 1.0
@ -115,7 +85,27 @@ describe("combine along support test", () => {
Some(combineAlongSupportOfSecondArgument(integrand, interpolator, a.xyShape, b.xyShape))
| _ => None
}
Js.Console.log2("combineAlongSupportOfSecondArgument", result)
false->expect->toBe(true)
result
->expect
->toEqual(
Some(
Ok({
xs: [
0.0,
MagicNumbers.Epsilon.ten,
2.0 *. MagicNumbers.Epsilon.ten,
1.0 -. MagicNumbers.Epsilon.ten,
1.0,
],
ys: [
-0.34657359027997264,
-0.34657359027997264,
-0.34657359027997264,
-0.34657359027997264,
-0.34657359027997264,
],
}),
),
)
})
})

View File

@ -15,7 +15,6 @@
"test": "jest",
"test:ts": "jest __tests__/TS/",
"test:rescript": "jest --modulePathIgnorePatterns=__tests__/TS/*",
"test:kldivergence": "jest __tests__/Distributions/KlDivergence_test.*",
"test:watch": "jest --watchAll",
"coverage:rescript": "rm -f *.coverage; yarn clean; BISECT_ENABLE=yes yarn build; yarn test:rescript; bisect-ppx-report html",
"coverage:ts": "yarn clean; yarn build; nyc --reporter=lcov yarn test:ts",