Added extra multiplicative factor in logScore integrand
Value: [8e-2 to 7e-1] migrated intregrand from `log(predicted / answer)` to `answer * log(predicted / answer)`
This commit is contained in:
parent
d595285078
commit
db3acbf96c
|
@ -0,0 +1,3 @@
|
|||
open Jest
|
||||
open Expect
|
||||
open TestHelpers
|
19
packages/squiggle-lang/__tests__/TS/Score_test.ts
Normal file
19
packages/squiggle-lang/__tests__/TS/Score_test.ts
Normal file
|
@ -0,0 +1,19 @@
|
|||
import { testRun } from "./TestHelpers";
|
||||
|
||||
describe("KL divergence", () => {
|
||||
test("by integral solver agrees with analytical", () => {
|
||||
let squiggleStringKL = `prediction=normal(4, 1)
|
||||
answer=normal(1,1)
|
||||
logSubtraction=dotSubtract(scaleLog(answer),scaleLog(prediction))
|
||||
klintegrand=dotMultiply(logSubtraction, answer)
|
||||
klintegral = integralSum(klintegrand)
|
||||
analyticalKl = log(1 / 1) + 1 ^ 2 / (2 * 1 ^ 2) + ((4 - 1) * (1 - 4) / (2 * 1 * 1)) - 1 / 2
|
||||
klintegral - analyticalKl`;
|
||||
let squiggleResultKL = testRun(squiggleStringKL);
|
||||
expect(squiggleResultKL.value).toBeCloseTo(0);
|
||||
});
|
||||
});
|
||||
|
||||
let squiggleStringLS = `prediction=normal(4,1)
|
||||
answer=normal(1,1)
|
||||
logScore(prediction, answer)`;
|
|
@ -271,7 +271,10 @@ module T = Dist({
|
|||
XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)
|
||||
|
||||
let logScore = (base: t, reference: t) => {
|
||||
combinePointwise(PointSetDist_Scoring.LogScoring.logScore, base, reference)
|
||||
E.R2.bind(
|
||||
combinePointwise(PointSetDist_Scoring.LogScoring.multiply, reference),
|
||||
combinePointwise(PointSetDist_Scoring.LogScoring.logScore, base, reference),
|
||||
)
|
||||
|> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite)))
|
||||
|> E.R.fmap(integralEndY)
|
||||
}
|
||||
|
|
|
@ -4,4 +4,5 @@ module LogScoring = {
|
|||
let logScore = (a: float, b: float): result<float, Operation.Error.t> => Ok(
|
||||
Js.Math.log2(Js.Math.abs_float(a /. b)),
|
||||
)
|
||||
let multiply = (a: float, b: float): result<float, Operation.Error.t> => Ok(a *. b)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user