Clean up for CR

Value: [1e-7 to 43-4]
This commit is contained in:
Quinn Dougherty 2022-05-12 09:51:20 -04:00
parent b90a0e7a1a
commit f5e3701a79
3 changed files with 12 additions and 15 deletions

View File

@ -119,18 +119,18 @@ describe("klDivergence: discrete -> discrete -> float", () => {
describe("klDivergence: mixed -> mixed -> float", () => {
let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
let mixture = a => DistributionTypes.DistributionOperation.Mixture(a)
let a' = [(point1, 1e0), (uniformDist, 1e0)]->mixture->run
let b' = [(point1, 1e0), (floatDist, 1e0), (normalDist10, 1e0)]->mixture->run
let c' = [(point1, 1e0), (point2, 1e0), (point3, 1e0), (uniformDist, 1e0)]->mixture->run
let a' = [(point1, 1.0), (uniformDist, 1.0)]->mixture->run
let b' = [(point1, 1.0), (floatDist, 1.0), (normalDist10, 1.0)]->mixture->run
let c' = [(point1, 1.0), (point2, 1.0), (point3, 1.0), (uniformDist, 1.0)]->mixture->run
let d' =
[(point1, 1e0), (point2, 1e0), (point3, 1e0), (floatDist, 1e0), (uniformDist2, 1e0)]
[(point1, 1.0), (point2, 1.0), (point3, 1.0), (floatDist, 1.0), (uniformDist2, 1.0)]
->mixture
->run
let (a, b, c, d) = switch (a', b', c', d') {
| (Dist(a''), Dist(b''), Dist(c''), Dist(d'')) => (a'', b'', c'', d'')
| _ => raise(MixtureFailed)
}
test("finite klDivergence return is correct", () => {
test("finite klDivergence produces correct answer", () => {
let prediction = b
let answer = a
let kl = klDivergence(prediction, answer)
@ -156,7 +156,7 @@ describe("klDivergence: mixed -> mixed -> float", () => {
raise(KlFailed)
}
})
test("finite klDivergence return is correct", () => {
test("finite klDivergence produces correct answer", () => {
let prediction = d
let answer = c
let kl = klDivergence(prediction, answer)

View File

@ -272,14 +272,10 @@ module T = Dist({
XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)
let klDivergence = (prediction: t, answer: t) => {
let enrich = true
let enrichedAnswer = enrich
? XYShape.PointwiseCombination.enrichXyShape(answer.xyShape)
: answer.xyShape //
let newShape = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument(
PointSetDist_Scoring.KLDivergence.integrand,
prediction.xyShape,
enrichedAnswer,
answer.xyShape,
)
let xyShapeToContinuous: XYShape.xyShape => t = xyShape => {
xyShape: xyShape,

View File

@ -453,14 +453,15 @@ module PointwiseCombination = {
T.filterOkYs(newXs, newYs)->Ok
}
// Nuño wrote this function to try to increase precision, but it didn't work.
// *Dead code*: Nuño wrote this function to try to increase precision, but it didn't work.
// If another traveler comes through with a similar idea, we hope this implementation will help them.
let enrichXyShape = (t: T.t): T.t => {
let enrichmentFactor = 10
let defaultEnrichmentFactor = 10
let length = E.A.length(t.xs)
let points =
length < MagicNumbers.Environment.defaultXYPointLength
? enrichmentFactor * MagicNumbers.Environment.defaultXYPointLength / length
: enrichmentFactor
? defaultEnrichmentFactor * MagicNumbers.Environment.defaultXYPointLength / length
: defaultEnrichmentFactor
let getInBetween = (x1: float, x2: float): array<float> => {
if abs_float(x1 -. x2) < 2.0 *. MagicNumbers.Epsilon.seven {