Fixed logScoreAgainstImproperPrior by finding how it was None

Value: [1e-4 to 8e-2]
This commit is contained in:
Quinn Dougherty 2022-05-12 15:26:51 -04:00
parent 51310819a1
commit 65751e590a

View File

@ -160,16 +160,16 @@ module Helpers = {
let constructNonNormalizedPointSet = (
~supportOf: DistributionTypes.genericDist,
fn: float => float,
): option<DistributionTypes.genericDist> => {
switch supportOf {
| PointSet(Continuous(dist)) =>
{xs: dist.xyShape.xs, ys: E.A.fmap(fn, dist.xyShape.xs)}
->Continuous.make
->Continuous
->PointSet
->Some
| _ => None
): DistributionTypes.genericDist => {
let cdf = x => toFloatFn(#Cdf(x), supportOf)
let leftEndpoint = cdf(MagicNumbers.Epsilon.ten)
let rightEndpoint = cdf(1.0 -. MagicNumbers.Epsilon.ten)
let xs = switch (leftEndpoint, rightEndpoint) {
| (Some(Float(a)), Some(Float(b))) =>
E.A.Floats.range(a, b, MagicNumbers.Environment.defaultXYPointLength)
| _ => []
}
{xs: xs, ys: E.A.fmap(fn, xs)}->Continuous.make->Continuous->DistributionTypes.PointSet
}
}
@ -253,16 +253,18 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment)
"logScore",
[EvDistribution(prior), EvDistribution(prediction), EvDistribution(Symbolic(#Float(answer)))],
) =>
Some(runGenericOperation(FromDist(ToScore(LogScore(prediction, answer)), prior)))
runGenericOperation(FromDist(ToScore(LogScore(prediction, answer)), prior))->Some
| ("logScoreAgainstImproperPrior", [EvDistribution(prediction), EvNumber(answer)])
| (
"logScoreAgainstImproperPrior",
[EvDistribution(prediction), EvDistribution(Symbolic(#Float(answer)))],
) =>
E.O.fmap(
d => runGenericOperation(FromDist(ToScore(LogScore(prediction, answer)), d)),
Helpers.constructNonNormalizedPointSet(~supportOf=prediction, _ => 1.0),
)
runGenericOperation(
FromDist(
ToScore(LogScore(prediction, answer)),
Helpers.constructNonNormalizedPointSet(~supportOf=prediction, _ => 1.0),
),
)->Some
| ("isNormalized", [EvDistribution(dist)]) => Helpers.toBoolFn(IsNormalized, dist)
| ("toPointSet", [EvDistribution(dist)]) => Helpers.toDistFn(ToPointSet, dist)
| ("scaleLog", [EvDistribution(dist)]) =>