Added extra multiplicative factor in logScore integrand

Value: [8e-2 to 7e-1]

migrated intregrand from `log(predicted / answer)` to `answer *
log(predicted / answer)`
This commit is contained in:
Quinn Dougherty 2022-05-02 13:40:34 -04:00
parent d595285078
commit db3acbf96c
4 changed files with 27 additions and 1 deletions

View File

@ -0,0 +1,3 @@
open Jest
open Expect
open TestHelpers

View File

@ -0,0 +1,19 @@
import { testRun } from "./TestHelpers";
describe("KL divergence", () => {
test("by integral solver agrees with analytical", () => {
let squiggleStringKL = `prediction=normal(4, 1)
answer=normal(1,1)
logSubtraction=dotSubtract(scaleLog(answer),scaleLog(prediction))
klintegrand=dotMultiply(logSubtraction, answer)
klintegral = integralSum(klintegrand)
analyticalKl = log(1 / 1) + 1 ^ 2 / (2 * 1 ^ 2) + ((4 - 1) * (1 - 4) / (2 * 1 * 1)) - 1 / 2
klintegral - analyticalKl`;
let squiggleResultKL = testRun(squiggleStringKL);
expect(squiggleResultKL.value).toBeCloseTo(0);
});
});
let squiggleStringLS = `prediction=normal(4,1)
answer=normal(1,1)
logScore(prediction, answer)`;

View File

@ -271,7 +271,10 @@ module T = Dist({
XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares) XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)
let logScore = (base: t, reference: t) => { let logScore = (base: t, reference: t) => {
combinePointwise(PointSetDist_Scoring.LogScoring.logScore, base, reference) E.R2.bind(
combinePointwise(PointSetDist_Scoring.LogScoring.multiply, reference),
combinePointwise(PointSetDist_Scoring.LogScoring.logScore, base, reference),
)
|> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite))) |> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
|> E.R.fmap(integralEndY) |> E.R.fmap(integralEndY)
} }

View File

@ -4,4 +4,5 @@ module LogScoring = {
let logScore = (a: float, b: float): result<float, Operation.Error.t> => Ok( let logScore = (a: float, b: float): result<float, Operation.Error.t> => Ok(
Js.Math.log2(Js.Math.abs_float(a /. b)), Js.Math.log2(Js.Math.abs_float(a /. b)),
) )
let multiply = (a: float, b: float): result<float, Operation.Error.t> => Ok(a *. b)
} }