fix: PointSetDist_Scoring.WithScalarAnswer.scoreWithPrior
Done in pair coding with Quinn. Value::[0.3 to 0.9]
This commit is contained in:
parent
7e859c7823
commit
4b1c226173
|
@ -0,0 +1,81 @@
|
||||||
|
// Bring up a discrete distribution
|
||||||
|
open Jest
|
||||||
|
open Expect
|
||||||
|
open TestHelpers
|
||||||
|
open GenericDist_Fixtures
|
||||||
|
|
||||||
|
// WithDistAnswer -> in the KL divergence test file.
|
||||||
|
|
||||||
|
// WithScalarAnswer
|
||||||
|
describe("WithScalarAnswer: discrete -> discrete -> float", () => {
|
||||||
|
let mixture = a => DistributionTypes.DistributionOperation.Mixture(a)
|
||||||
|
let pointA = mkDelta(3.0)
|
||||||
|
let pointB = mkDelta(2.0)
|
||||||
|
let pointC = mkDelta(1.0)
|
||||||
|
let pointD = mkDelta(0.0)
|
||||||
|
|
||||||
|
test("score: agrees with analytical answer when finite", () => {
|
||||||
|
let prediction' = [(pointA, 0.25), (pointB, 0.25), (pointC, 0.25), (pointD, 0.25)]->mixture->run
|
||||||
|
let prediction = switch prediction' {
|
||||||
|
| Dist(PointSet(a'')) => a''
|
||||||
|
| _ => raise(MixtureFailed)
|
||||||
|
}
|
||||||
|
|
||||||
|
let answer = 2.0 // So this is: assigning 100% probability to 2.0
|
||||||
|
let result = PointSetDist_Scoring.WithScalarAnswer.score(~estimate=prediction, ~answer)
|
||||||
|
switch result {
|
||||||
|
| Ok(x) => x->expect->toEqual(-.Js.Math.log(0.25 /. 1.0))
|
||||||
|
| _ => raise(MixtureFailed)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
test("score: agrees with analytical answer when finite", () => {
|
||||||
|
let prediction' = [(pointA, 0.75), (pointB, 0.25)]->mixture->run
|
||||||
|
let prediction = switch prediction' {
|
||||||
|
| Dist(PointSet(a'')) => a''
|
||||||
|
| _ => raise(MixtureFailed)
|
||||||
|
}
|
||||||
|
let answer = 3.0 // So this is: assigning 100% probability to 2.0
|
||||||
|
let result = PointSetDist_Scoring.WithScalarAnswer.score(~estimate=prediction, ~answer)
|
||||||
|
switch result {
|
||||||
|
| Ok(x) => x->expect->toEqual(-.Js.Math.log(0.75 /. 1.0))
|
||||||
|
| _ => raise(MixtureFailed)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
test("scoreWithPrior: ", () => {
|
||||||
|
let prior' = [(pointA, 0.5), (pointB, 0.5)]->mixture->run
|
||||||
|
let prediction' = [(pointA, 0.75), (pointB, 0.25)]->mixture->run
|
||||||
|
|
||||||
|
let prediction = switch prediction' {
|
||||||
|
| Dist(PointSet(a'')) => a''
|
||||||
|
| _ => raise(MixtureFailed)
|
||||||
|
}
|
||||||
|
|
||||||
|
let prior = switch prior' {
|
||||||
|
| Dist(PointSet(a'')) => a''
|
||||||
|
| _ => raise(MixtureFailed)
|
||||||
|
}
|
||||||
|
|
||||||
|
let answer = 3.0 // So this is: assigning 100% probability to 2.0
|
||||||
|
let result = PointSetDist_Scoring.WithScalarAnswer.scoreWithPrior(
|
||||||
|
~estimate=prediction,
|
||||||
|
~answer,
|
||||||
|
~prior,
|
||||||
|
)
|
||||||
|
switch result {
|
||||||
|
| Ok(x) => x->expect->toEqual(-.Js.Math.log(0.75 /. 1.0) -. -.Js.Math.log(0.5 /. 1.0))
|
||||||
|
| _ => raise(MixtureFailed)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// WithDistAnswer
|
||||||
|
/*
|
||||||
|
describe("WithScalarAnswer: discrete -> discrete -> float", () => {
|
||||||
|
})
|
||||||
|
|
||||||
|
// TwoScalars
|
||||||
|
describe("WithScalarAnswer: discrete -> discrete -> float", () => {
|
||||||
|
})
|
||||||
|
*/
|
|
@ -115,11 +115,20 @@ module WithScalarAnswer = {
|
||||||
}
|
}
|
||||||
_score(~estimatePdf, ~answer)
|
_score(~estimatePdf, ~answer)
|
||||||
}
|
}
|
||||||
|
/*
|
||||||
|
let score1 = (~estimate: pointSetDist, ~answer: scalar): result<score, Operation.Error.t> => {
|
||||||
|
let probabilityAssignedToAnswer = Ok(1.0)
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
let scoreWithPrior = (~estimate: pointSetDist, ~answer: scalar, ~prior: pointSetDist): result<
|
let scoreWithPrior = (~estimate: pointSetDist, ~answer: scalar, ~prior: pointSetDist): result<
|
||||||
score,
|
score,
|
||||||
Operation.Error.t,
|
Operation.Error.t,
|
||||||
> => {
|
> => {
|
||||||
|
E.R.merge(score(~estimate, ~answer), score(~estimate=prior, ~answer))->E.R2.fmap(((s1, s2)) =>
|
||||||
|
s1 -. s2
|
||||||
|
)
|
||||||
|
/*
|
||||||
let _scoreWithPrior = (
|
let _scoreWithPrior = (
|
||||||
~estimatePdf: float => float,
|
~estimatePdf: float => float,
|
||||||
~answer: scalar,
|
~answer: scalar,
|
||||||
|
@ -132,7 +141,6 @@ module WithScalarAnswer = {
|
||||||
} else if numerator == 0.0 || priorDensityOfAnswer == 0.0 {
|
} else if numerator == 0.0 || priorDensityOfAnswer == 0.0 {
|
||||||
infinity->Ok
|
infinity->Ok
|
||||||
} else {
|
} else {
|
||||||
minusScaledLogOfQuotient(~esti=numerator, ~answ=priorDensityOfAnswer)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -149,6 +157,7 @@ module WithScalarAnswer = {
|
||||||
| Mixed(prio) => Mixed.T.xToY(x, prio)->sum
|
| Mixed(prio) => Mixed.T.xToY(x, prio)->sum
|
||||||
}
|
}
|
||||||
_scoreWithPrior(~estimatePdf, ~answer, ~priorPdf)
|
_scoreWithPrior(~estimatePdf, ~answer, ~priorPdf)
|
||||||
|
*/
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user