Added extra multiplicative factor in logScore integrand
Value: [8e-2 to 7e-1] migrated intregrand from `log(predicted / answer)` to `answer * log(predicted / answer)`
This commit is contained in:
parent
d595285078
commit
db3acbf96c
|
@ -0,0 +1,3 @@
|
||||||
|
open Jest
|
||||||
|
open Expect
|
||||||
|
open TestHelpers
|
19
packages/squiggle-lang/__tests__/TS/Score_test.ts
Normal file
19
packages/squiggle-lang/__tests__/TS/Score_test.ts
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
import { testRun } from "./TestHelpers";
|
||||||
|
|
||||||
|
describe("KL divergence", () => {
|
||||||
|
test("by integral solver agrees with analytical", () => {
|
||||||
|
let squiggleStringKL = `prediction=normal(4, 1)
|
||||||
|
answer=normal(1,1)
|
||||||
|
logSubtraction=dotSubtract(scaleLog(answer),scaleLog(prediction))
|
||||||
|
klintegrand=dotMultiply(logSubtraction, answer)
|
||||||
|
klintegral = integralSum(klintegrand)
|
||||||
|
analyticalKl = log(1 / 1) + 1 ^ 2 / (2 * 1 ^ 2) + ((4 - 1) * (1 - 4) / (2 * 1 * 1)) - 1 / 2
|
||||||
|
klintegral - analyticalKl`;
|
||||||
|
let squiggleResultKL = testRun(squiggleStringKL);
|
||||||
|
expect(squiggleResultKL.value).toBeCloseTo(0);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
let squiggleStringLS = `prediction=normal(4,1)
|
||||||
|
answer=normal(1,1)
|
||||||
|
logScore(prediction, answer)`;
|
|
@ -271,7 +271,10 @@ module T = Dist({
|
||||||
XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)
|
XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)
|
||||||
|
|
||||||
let logScore = (base: t, reference: t) => {
|
let logScore = (base: t, reference: t) => {
|
||||||
combinePointwise(PointSetDist_Scoring.LogScoring.logScore, base, reference)
|
E.R2.bind(
|
||||||
|
combinePointwise(PointSetDist_Scoring.LogScoring.multiply, reference),
|
||||||
|
combinePointwise(PointSetDist_Scoring.LogScoring.logScore, base, reference),
|
||||||
|
)
|
||||||
|> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
|
|> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
|
||||||
|> E.R.fmap(integralEndY)
|
|> E.R.fmap(integralEndY)
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,4 +4,5 @@ module LogScoring = {
|
||||||
let logScore = (a: float, b: float): result<float, Operation.Error.t> => Ok(
|
let logScore = (a: float, b: float): result<float, Operation.Error.t> => Ok(
|
||||||
Js.Math.log2(Js.Math.abs_float(a /. b)),
|
Js.Math.log2(Js.Math.abs_float(a /. b)),
|
||||||
)
|
)
|
||||||
|
let multiply = (a: float, b: float): result<float, Operation.Error.t> => Ok(a *. b)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user