Clean up for CR
Value: [1e-7 to 43-4]
This commit is contained in:
parent
b90a0e7a1a
commit
f5e3701a79
|
@ -119,18 +119,18 @@ describe("klDivergence: discrete -> discrete -> float", () => {
|
||||||
describe("klDivergence: mixed -> mixed -> float", () => {
|
describe("klDivergence: mixed -> mixed -> float", () => {
|
||||||
let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
|
let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
|
||||||
let mixture = a => DistributionTypes.DistributionOperation.Mixture(a)
|
let mixture = a => DistributionTypes.DistributionOperation.Mixture(a)
|
||||||
let a' = [(point1, 1e0), (uniformDist, 1e0)]->mixture->run
|
let a' = [(point1, 1.0), (uniformDist, 1.0)]->mixture->run
|
||||||
let b' = [(point1, 1e0), (floatDist, 1e0), (normalDist10, 1e0)]->mixture->run
|
let b' = [(point1, 1.0), (floatDist, 1.0), (normalDist10, 1.0)]->mixture->run
|
||||||
let c' = [(point1, 1e0), (point2, 1e0), (point3, 1e0), (uniformDist, 1e0)]->mixture->run
|
let c' = [(point1, 1.0), (point2, 1.0), (point3, 1.0), (uniformDist, 1.0)]->mixture->run
|
||||||
let d' =
|
let d' =
|
||||||
[(point1, 1e0), (point2, 1e0), (point3, 1e0), (floatDist, 1e0), (uniformDist2, 1e0)]
|
[(point1, 1.0), (point2, 1.0), (point3, 1.0), (floatDist, 1.0), (uniformDist2, 1.0)]
|
||||||
->mixture
|
->mixture
|
||||||
->run
|
->run
|
||||||
let (a, b, c, d) = switch (a', b', c', d') {
|
let (a, b, c, d) = switch (a', b', c', d') {
|
||||||
| (Dist(a''), Dist(b''), Dist(c''), Dist(d'')) => (a'', b'', c'', d'')
|
| (Dist(a''), Dist(b''), Dist(c''), Dist(d'')) => (a'', b'', c'', d'')
|
||||||
| _ => raise(MixtureFailed)
|
| _ => raise(MixtureFailed)
|
||||||
}
|
}
|
||||||
test("finite klDivergence return is correct", () => {
|
test("finite klDivergence produces correct answer", () => {
|
||||||
let prediction = b
|
let prediction = b
|
||||||
let answer = a
|
let answer = a
|
||||||
let kl = klDivergence(prediction, answer)
|
let kl = klDivergence(prediction, answer)
|
||||||
|
@ -156,7 +156,7 @@ describe("klDivergence: mixed -> mixed -> float", () => {
|
||||||
raise(KlFailed)
|
raise(KlFailed)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
test("finite klDivergence return is correct", () => {
|
test("finite klDivergence produces correct answer", () => {
|
||||||
let prediction = d
|
let prediction = d
|
||||||
let answer = c
|
let answer = c
|
||||||
let kl = klDivergence(prediction, answer)
|
let kl = klDivergence(prediction, answer)
|
||||||
|
|
|
@ -272,14 +272,10 @@ module T = Dist({
|
||||||
XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)
|
XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)
|
||||||
|
|
||||||
let klDivergence = (prediction: t, answer: t) => {
|
let klDivergence = (prediction: t, answer: t) => {
|
||||||
let enrich = true
|
|
||||||
let enrichedAnswer = enrich
|
|
||||||
? XYShape.PointwiseCombination.enrichXyShape(answer.xyShape)
|
|
||||||
: answer.xyShape //
|
|
||||||
let newShape = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument(
|
let newShape = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument(
|
||||||
PointSetDist_Scoring.KLDivergence.integrand,
|
PointSetDist_Scoring.KLDivergence.integrand,
|
||||||
prediction.xyShape,
|
prediction.xyShape,
|
||||||
enrichedAnswer,
|
answer.xyShape,
|
||||||
)
|
)
|
||||||
let xyShapeToContinuous: XYShape.xyShape => t = xyShape => {
|
let xyShapeToContinuous: XYShape.xyShape => t = xyShape => {
|
||||||
xyShape: xyShape,
|
xyShape: xyShape,
|
||||||
|
|
|
@ -453,14 +453,15 @@ module PointwiseCombination = {
|
||||||
T.filterOkYs(newXs, newYs)->Ok
|
T.filterOkYs(newXs, newYs)->Ok
|
||||||
}
|
}
|
||||||
|
|
||||||
// Nuño wrote this function to try to increase precision, but it didn't work.
|
// *Dead code*: Nuño wrote this function to try to increase precision, but it didn't work.
|
||||||
|
// If another traveler comes through with a similar idea, we hope this implementation will help them.
|
||||||
let enrichXyShape = (t: T.t): T.t => {
|
let enrichXyShape = (t: T.t): T.t => {
|
||||||
let enrichmentFactor = 10
|
let defaultEnrichmentFactor = 10
|
||||||
let length = E.A.length(t.xs)
|
let length = E.A.length(t.xs)
|
||||||
let points =
|
let points =
|
||||||
length < MagicNumbers.Environment.defaultXYPointLength
|
length < MagicNumbers.Environment.defaultXYPointLength
|
||||||
? enrichmentFactor * MagicNumbers.Environment.defaultXYPointLength / length
|
? defaultEnrichmentFactor * MagicNumbers.Environment.defaultXYPointLength / length
|
||||||
: enrichmentFactor
|
: defaultEnrichmentFactor
|
||||||
|
|
||||||
let getInBetween = (x1: float, x2: float): array<float> => {
|
let getInBetween = (x1: float, x2: float): array<float> => {
|
||||||
if abs_float(x1 -. x2) < 2.0 *. MagicNumbers.Epsilon.seven {
|
if abs_float(x1 -. x2) < 2.0 *. MagicNumbers.Epsilon.seven {
|
||||||
|
|
Loading…
Reference in New Issue
Block a user