Clean up for CR
Value: [1e-7 to 43-4]
This commit is contained in:
parent
b90a0e7a1a
commit
f5e3701a79
|
@ -119,18 +119,18 @@ describe("klDivergence: discrete -> discrete -> float", () => {
|
|||
describe("klDivergence: mixed -> mixed -> float", () => {
|
||||
let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
|
||||
let mixture = a => DistributionTypes.DistributionOperation.Mixture(a)
|
||||
let a' = [(point1, 1e0), (uniformDist, 1e0)]->mixture->run
|
||||
let b' = [(point1, 1e0), (floatDist, 1e0), (normalDist10, 1e0)]->mixture->run
|
||||
let c' = [(point1, 1e0), (point2, 1e0), (point3, 1e0), (uniformDist, 1e0)]->mixture->run
|
||||
let a' = [(point1, 1.0), (uniformDist, 1.0)]->mixture->run
|
||||
let b' = [(point1, 1.0), (floatDist, 1.0), (normalDist10, 1.0)]->mixture->run
|
||||
let c' = [(point1, 1.0), (point2, 1.0), (point3, 1.0), (uniformDist, 1.0)]->mixture->run
|
||||
let d' =
|
||||
[(point1, 1e0), (point2, 1e0), (point3, 1e0), (floatDist, 1e0), (uniformDist2, 1e0)]
|
||||
[(point1, 1.0), (point2, 1.0), (point3, 1.0), (floatDist, 1.0), (uniformDist2, 1.0)]
|
||||
->mixture
|
||||
->run
|
||||
let (a, b, c, d) = switch (a', b', c', d') {
|
||||
| (Dist(a''), Dist(b''), Dist(c''), Dist(d'')) => (a'', b'', c'', d'')
|
||||
| _ => raise(MixtureFailed)
|
||||
}
|
||||
test("finite klDivergence return is correct", () => {
|
||||
test("finite klDivergence produces correct answer", () => {
|
||||
let prediction = b
|
||||
let answer = a
|
||||
let kl = klDivergence(prediction, answer)
|
||||
|
@ -156,7 +156,7 @@ describe("klDivergence: mixed -> mixed -> float", () => {
|
|||
raise(KlFailed)
|
||||
}
|
||||
})
|
||||
test("finite klDivergence return is correct", () => {
|
||||
test("finite klDivergence produces correct answer", () => {
|
||||
let prediction = d
|
||||
let answer = c
|
||||
let kl = klDivergence(prediction, answer)
|
||||
|
|
|
@ -272,14 +272,10 @@ module T = Dist({
|
|||
XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)
|
||||
|
||||
let klDivergence = (prediction: t, answer: t) => {
|
||||
let enrich = true
|
||||
let enrichedAnswer = enrich
|
||||
? XYShape.PointwiseCombination.enrichXyShape(answer.xyShape)
|
||||
: answer.xyShape //
|
||||
let newShape = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument(
|
||||
PointSetDist_Scoring.KLDivergence.integrand,
|
||||
prediction.xyShape,
|
||||
enrichedAnswer,
|
||||
answer.xyShape,
|
||||
)
|
||||
let xyShapeToContinuous: XYShape.xyShape => t = xyShape => {
|
||||
xyShape: xyShape,
|
||||
|
|
|
@ -453,14 +453,15 @@ module PointwiseCombination = {
|
|||
T.filterOkYs(newXs, newYs)->Ok
|
||||
}
|
||||
|
||||
// Nuño wrote this function to try to increase precision, but it didn't work.
|
||||
// *Dead code*: Nuño wrote this function to try to increase precision, but it didn't work.
|
||||
// If another traveler comes through with a similar idea, we hope this implementation will help them.
|
||||
let enrichXyShape = (t: T.t): T.t => {
|
||||
let enrichmentFactor = 10
|
||||
let defaultEnrichmentFactor = 10
|
||||
let length = E.A.length(t.xs)
|
||||
let points =
|
||||
length < MagicNumbers.Environment.defaultXYPointLength
|
||||
? enrichmentFactor * MagicNumbers.Environment.defaultXYPointLength / length
|
||||
: enrichmentFactor
|
||||
? defaultEnrichmentFactor * MagicNumbers.Environment.defaultXYPointLength / length
|
||||
: defaultEnrichmentFactor
|
||||
|
||||
let getInBetween = (x1: float, x2: float): array<float> => {
|
||||
if abs_float(x1 -. x2) < 2.0 *. MagicNumbers.Epsilon.seven {
|
||||
|
|
Loading…
Reference in New Issue
Block a user