fix: "Enrich" (add more x points) when integrating
in order to get more numerical precision. Note: not complete yet. Value: [1e-3 to 3e-1]
This commit is contained in:
parent
0b8da034c6
commit
4df4597ed3
|
@ -69,7 +69,7 @@ describe("klDivergence: continuous -> continuous -> float", () => {
|
||||||
2.0 ** 2.0 *.
|
2.0 ** 2.0 *.
|
||||||
(10.0 ** 2.0 -. (10.0 +. 9.0) *. 10.0 +. (9.0 ** 2.0 +. 10.0 *. 9.0 +. 10.0 ** 2.0) /. 3.0)
|
(10.0 ** 2.0 -. (10.0 +. 9.0) *. 10.0 +. (9.0 ** 2.0 +. 10.0 *. 9.0 +. 10.0 ** 2.0) /. 3.0)
|
||||||
switch kl {
|
switch kl {
|
||||||
| Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=1)
|
| Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=3)
|
||||||
| Error(err) => {
|
| Error(err) => {
|
||||||
Js.Console.log(DistributionTypes.Error.toString(err))
|
Js.Console.log(DistributionTypes.Error.toString(err))
|
||||||
raise(KlFailed)
|
raise(KlFailed)
|
||||||
|
|
|
@ -272,10 +272,11 @@ module T = Dist({
|
||||||
XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)
|
XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)
|
||||||
|
|
||||||
let klDivergence = (prediction: t, answer: t) => {
|
let klDivergence = (prediction: t, answer: t) => {
|
||||||
|
let enrichedAnswer = XYShape.PointwiseCombination.enrichXyShape(answer.xyShape) // answer.xyShape //
|
||||||
let newShape = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument(
|
let newShape = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument(
|
||||||
PointSetDist_Scoring.KLDivergence.integrand,
|
PointSetDist_Scoring.KLDivergence.integrand,
|
||||||
prediction.xyShape,
|
prediction.xyShape,
|
||||||
answer.xyShape,
|
enrichedAnswer,
|
||||||
)
|
)
|
||||||
let xyShapeToContinuous: XYShape.xyShape => t = xyShape => {
|
let xyShapeToContinuous: XYShape.xyShape => t = xyShape => {
|
||||||
xyShape: xyShape,
|
xyShape: xyShape,
|
||||||
|
|
|
@ -12,6 +12,7 @@ module Epsilon = {
|
||||||
module Environment = {
|
module Environment = {
|
||||||
let defaultXYPointLength = 1000
|
let defaultXYPointLength = 1000
|
||||||
let defaultSampleCount = 10000
|
let defaultSampleCount = 10000
|
||||||
|
let enrichmentFactor = 10
|
||||||
}
|
}
|
||||||
|
|
||||||
module OpCost = {
|
module OpCost = {
|
||||||
|
|
|
@ -453,6 +453,44 @@ module PointwiseCombination = {
|
||||||
T.filterOkYs(newXs, newYs)->Ok
|
T.filterOkYs(newXs, newYs)->Ok
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let enrichXyShape = (t: T.t): T.t => {
|
||||||
|
let length = E.A.length(t.xs)
|
||||||
|
Js.Console.log(length)
|
||||||
|
let points = switch length < MagicNumbers.Environment.defaultXYPointLength {
|
||||||
|
| true =>
|
||||||
|
Belt.Int.fromFloat(
|
||||||
|
Belt.Float.fromInt(
|
||||||
|
MagicNumbers.Environment.enrichmentFactor * MagicNumbers.Environment.defaultXYPointLength,
|
||||||
|
) /.
|
||||||
|
Belt.Float.fromInt(length),
|
||||||
|
)
|
||||||
|
| false => MagicNumbers.Environment.enrichmentFactor
|
||||||
|
}
|
||||||
|
|
||||||
|
let getInBetween = (x1: float, x2: float): array<float> => {
|
||||||
|
let newPointsArray = Belt.Array.makeBy(points - 1, i => i)
|
||||||
|
// don't repeat the x2 point, it will be gotten in the next iteration.
|
||||||
|
let result = Js.Array.mapi((pos, i) =>
|
||||||
|
switch i {
|
||||||
|
| 0 => x1
|
||||||
|
| _ =>
|
||||||
|
x1 *.
|
||||||
|
(Belt.Float.fromInt(points) -. Belt.Float.fromInt(pos)) /.
|
||||||
|
Belt.Float.fromInt(points) +. x2 *. Belt.Float.fromInt(pos) /. Belt.Float.fromInt(points)
|
||||||
|
}
|
||||||
|
, newPointsArray)
|
||||||
|
result
|
||||||
|
}
|
||||||
|
let newXsUnflattened = Js.Array.mapi((x, i) =>
|
||||||
|
switch i < length - 1 {
|
||||||
|
| true => getInBetween(x, t.xs[i + 1])
|
||||||
|
| false => [x]
|
||||||
|
}
|
||||||
|
, t.xs)
|
||||||
|
let newXs = Belt.Array.concatMany(newXsUnflattened)
|
||||||
|
let newYs = E.A.fmap(x => XtoY.linear(x, t), newXs) //XtoY.linear(newXs)
|
||||||
|
{xs: newXs, ys: newYs}
|
||||||
|
}
|
||||||
// This function is used for klDivergence
|
// This function is used for klDivergence
|
||||||
let combineAlongSupportOfSecondArgument: (
|
let combineAlongSupportOfSecondArgument: (
|
||||||
(float, float) => result<float, Operation.Error.t>,
|
(float, float) => result<float, Operation.Error.t>,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user