Some Cleanup

Btw, Nuño gets a lot of credit for the last commit!

Value: [1e-2 to 8e-2]
This commit is contained in:
Quinn Dougherty 2022-05-06 14:21:53 -04:00
parent 722bfc6366
commit bcf620337a
2 changed files with 32 additions and 43 deletions

View File

@ -5,40 +5,9 @@ open TestHelpers
describe("kl divergence", () => { describe("kl divergence", () => {
let klDivergence = DistributionOperation.Constructors.klDivergence(~env) let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
exception KlFailed exception KlFailed
test("of two uniforms is equal to the analytic expression", () => {
let lowAnswer = 0.0 let testUniform = (lowAnswer, highAnswer, lowPrediction, highPrediction) => {
let highAnswer = 1.0 test("of two uniforms is equal to the analytic expression", () => {
let lowPrediction = -1.0
let highPrediction = 2.0
let answer =
uniformMakeR(lowAnswer, highAnswer)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
let prediction =
uniformMakeR(lowPrediction, highPrediction)->E.R2.errMap(s => DistributionTypes.ArgumentError(
s,
))
// integral along the support of the answer of answer.pdf(x) times log of prediction.pdf(x) divided by answer.pdf(x) dx
let analyticalKl = Js.Math.log((highPrediction -. lowPrediction) /. (highAnswer -. lowAnswer))
let kl = E.R.liftJoin2(klDivergence, prediction, answer)
Js.Console.log2("Analytical: ", analyticalKl)
Js.Console.log2("Computed: ", kl)
switch kl {
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
| Error(err) => {
Js.Console.log(DistributionTypes.Error.toString(err))
raise(KlFailed)
}
}
})
test(
"of two uniforms is equal to the analytic expression, part 2 (annoying numerical errors)",
() => {
Js.Console.log(
"This will fait because of extremely annoying numerical errors. Will not fail if the two uniforms are a bit different. Very annoying",
)
let lowAnswer = 0.0
let highAnswer = 1.0
let lowPrediction = 0.0
let highPrediction = 2.0
let answer = let answer =
uniformMakeR(lowAnswer, highAnswer)->E.R2.errMap(s => DistributionTypes.ArgumentError(s)) uniformMakeR(lowAnswer, highAnswer)->E.R2.errMap(s => DistributionTypes.ArgumentError(s))
let prediction = let prediction =
@ -49,8 +18,6 @@ describe("kl divergence", () => {
// integral along the support of the answer of answer.pdf(x) times log of prediction.pdf(x) divided by answer.pdf(x) dx // integral along the support of the answer of answer.pdf(x) times log of prediction.pdf(x) divided by answer.pdf(x) dx
let analyticalKl = Js.Math.log((highPrediction -. lowPrediction) /. (highAnswer -. lowAnswer)) let analyticalKl = Js.Math.log((highPrediction -. lowPrediction) /. (highAnswer -. lowAnswer))
let kl = E.R.liftJoin2(klDivergence, prediction, answer) let kl = E.R.liftJoin2(klDivergence, prediction, answer)
Js.Console.log2("Analytical: ", analyticalKl)
Js.Console.log2("Computed: ", kl)
switch kl { switch kl {
| Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl) | Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl)
| Error(err) => { | Error(err) => {
@ -58,8 +25,12 @@ describe("kl divergence", () => {
raise(KlFailed) raise(KlFailed)
} }
} }
}, })
) }
testUniform(0.0, 1.0, -1.0, 2.0)
testUniform(0.0, 1.0, 0.0, 2.0)
// testUniform(-1.0, 1.0, 0.0, 2.0)
test("of two normals is equal to the formula", () => { test("of two normals is equal to the formula", () => {
// This test case comes via Nuño https://github.com/quantified-uncertainty/squiggle/issues/433 // This test case comes via Nuño https://github.com/quantified-uncertainty/squiggle/issues/433
let mean1 = 4.0 let mean1 = 4.0
@ -90,8 +61,7 @@ describe("kl divergence", () => {
}) })
describe("combine along support test", () => { describe("combine along support test", () => {
Skip.test("combine along support test", _ => { test("combine along support test", _ => {
// doesn't matter
let combineAlongSupportOfSecondArgument = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument0 let combineAlongSupportOfSecondArgument = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument0
let lowAnswer = 0.0 let lowAnswer = 0.0
let highAnswer = 1.0 let highAnswer = 1.0
@ -115,7 +85,27 @@ describe("combine along support test", () => {
Some(combineAlongSupportOfSecondArgument(integrand, interpolator, a.xyShape, b.xyShape)) Some(combineAlongSupportOfSecondArgument(integrand, interpolator, a.xyShape, b.xyShape))
| _ => None | _ => None
} }
Js.Console.log2("combineAlongSupportOfSecondArgument", result) result
false->expect->toBe(true) ->expect
->toEqual(
Some(
Ok({
xs: [
0.0,
MagicNumbers.Epsilon.ten,
2.0 *. MagicNumbers.Epsilon.ten,
1.0 -. MagicNumbers.Epsilon.ten,
1.0,
],
ys: [
-0.34657359027997264,
-0.34657359027997264,
-0.34657359027997264,
-0.34657359027997264,
-0.34657359027997264,
],
}),
),
)
}) })
}) })

View File

@ -15,7 +15,6 @@
"test": "jest", "test": "jest",
"test:ts": "jest __tests__/TS/", "test:ts": "jest __tests__/TS/",
"test:rescript": "jest --modulePathIgnorePatterns=__tests__/TS/*", "test:rescript": "jest --modulePathIgnorePatterns=__tests__/TS/*",
"test:kldivergence": "jest __tests__/Distributions/KlDivergence_test.*",
"test:watch": "jest --watchAll", "test:watch": "jest --watchAll",
"coverage:rescript": "rm -f *.coverage; yarn clean; BISECT_ENABLE=yes yarn build; yarn test:rescript; bisect-ppx-report html", "coverage:rescript": "rm -f *.coverage; yarn clean; BISECT_ENABLE=yes yarn build; yarn test:rescript; bisect-ppx-report html",
"coverage:ts": "yarn clean; yarn build; nyc --reporter=lcov yarn test:ts", "coverage:ts": "yarn clean; yarn build; nyc --reporter=lcov yarn test:ts",