Clean up for CR

Value: [1e-7 to 43-4]
This commit is contained in:
Quinn Dougherty 2022-05-12 09:51:20 -04:00
parent b90a0e7a1a
commit f5e3701a79
3 changed files with 12 additions and 15 deletions

View File

@ -119,18 +119,18 @@ describe("klDivergence: discrete -> discrete -> float", () => {
describe("klDivergence: mixed -> mixed -> float", () => { describe("klDivergence: mixed -> mixed -> float", () => {
let klDivergence = DistributionOperation.Constructors.klDivergence(~env) let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
let mixture = a => DistributionTypes.DistributionOperation.Mixture(a) let mixture = a => DistributionTypes.DistributionOperation.Mixture(a)
let a' = [(point1, 1e0), (uniformDist, 1e0)]->mixture->run let a' = [(point1, 1.0), (uniformDist, 1.0)]->mixture->run
let b' = [(point1, 1e0), (floatDist, 1e0), (normalDist10, 1e0)]->mixture->run let b' = [(point1, 1.0), (floatDist, 1.0), (normalDist10, 1.0)]->mixture->run
let c' = [(point1, 1e0), (point2, 1e0), (point3, 1e0), (uniformDist, 1e0)]->mixture->run let c' = [(point1, 1.0), (point2, 1.0), (point3, 1.0), (uniformDist, 1.0)]->mixture->run
let d' = let d' =
[(point1, 1e0), (point2, 1e0), (point3, 1e0), (floatDist, 1e0), (uniformDist2, 1e0)] [(point1, 1.0), (point2, 1.0), (point3, 1.0), (floatDist, 1.0), (uniformDist2, 1.0)]
->mixture ->mixture
->run ->run
let (a, b, c, d) = switch (a', b', c', d') { let (a, b, c, d) = switch (a', b', c', d') {
| (Dist(a''), Dist(b''), Dist(c''), Dist(d'')) => (a'', b'', c'', d'') | (Dist(a''), Dist(b''), Dist(c''), Dist(d'')) => (a'', b'', c'', d'')
| _ => raise(MixtureFailed) | _ => raise(MixtureFailed)
} }
test("finite klDivergence return is correct", () => { test("finite klDivergence produces correct answer", () => {
let prediction = b let prediction = b
let answer = a let answer = a
let kl = klDivergence(prediction, answer) let kl = klDivergence(prediction, answer)
@ -156,7 +156,7 @@ describe("klDivergence: mixed -> mixed -> float", () => {
raise(KlFailed) raise(KlFailed)
} }
}) })
test("finite klDivergence return is correct", () => { test("finite klDivergence produces correct answer", () => {
let prediction = d let prediction = d
let answer = c let answer = c
let kl = klDivergence(prediction, answer) let kl = klDivergence(prediction, answer)

View File

@ -272,14 +272,10 @@ module T = Dist({
XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares) XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)
let klDivergence = (prediction: t, answer: t) => { let klDivergence = (prediction: t, answer: t) => {
let enrich = true
let enrichedAnswer = enrich
? XYShape.PointwiseCombination.enrichXyShape(answer.xyShape)
: answer.xyShape //
let newShape = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument( let newShape = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument(
PointSetDist_Scoring.KLDivergence.integrand, PointSetDist_Scoring.KLDivergence.integrand,
prediction.xyShape, prediction.xyShape,
enrichedAnswer, answer.xyShape,
) )
let xyShapeToContinuous: XYShape.xyShape => t = xyShape => { let xyShapeToContinuous: XYShape.xyShape => t = xyShape => {
xyShape: xyShape, xyShape: xyShape,

View File

@ -453,14 +453,15 @@ module PointwiseCombination = {
T.filterOkYs(newXs, newYs)->Ok T.filterOkYs(newXs, newYs)->Ok
} }
// Nuño wrote this function to try to increase precision, but it didn't work. // *Dead code*: Nuño wrote this function to try to increase precision, but it didn't work.
// If another traveler comes through with a similar idea, we hope this implementation will help them.
let enrichXyShape = (t: T.t): T.t => { let enrichXyShape = (t: T.t): T.t => {
let enrichmentFactor = 10 let defaultEnrichmentFactor = 10
let length = E.A.length(t.xs) let length = E.A.length(t.xs)
let points = let points =
length < MagicNumbers.Environment.defaultXYPointLength length < MagicNumbers.Environment.defaultXYPointLength
? enrichmentFactor * MagicNumbers.Environment.defaultXYPointLength / length ? defaultEnrichmentFactor * MagicNumbers.Environment.defaultXYPointLength / length
: enrichmentFactor : defaultEnrichmentFactor
let getInBetween = (x1: float, x2: float): array<float> => { let getInBetween = (x1: float, x2: float): array<float> => {
if abs_float(x1 -. x2) < 2.0 *. MagicNumbers.Epsilon.seven { if abs_float(x1 -. x2) < 2.0 *. MagicNumbers.Epsilon.seven {