Minor refactors

This commit is contained in:
Ozzie Gooen 2022-05-25 18:10:05 -04:00
parent 6a9179d4b8
commit 239abbdcf8
2 changed files with 84 additions and 75 deletions

View File

@ -128,26 +128,27 @@ module Score = {
~answ: scoreDistOrScalar,
~prior: option<scoreDistOrScalar>,
): result<PointSetDist_Scoring.scoreArgs, error> => {
let toPointSetFn = toPointSet(
~xyPointLength=MagicNumbers.Environment.defaultXYPointLength,
~sampleCount=MagicNumbers.Environment.defaultSampleCount,
~xSelection=#ByWeight,
)
let toPointSetFn = t =>
toPointSet(
t,
~xyPointLength=MagicNumbers.Environment.defaultXYPointLength,
~sampleCount=MagicNumbers.Environment.defaultSampleCount,
~xSelection=#ByWeight,
(),
)
let prior': option<result<pointSet_ScoreDistOrScalar, error>> = switch prior {
| None => None
| Some(Score_Dist(d)) => toPointSetFn(d, ())->E.R.bind(x => x->D->Ok)->Some
| Some(Score_Dist(d)) => toPointSetFn(d)->E.R.bind(x => x->D->Ok)->Some
| Some(Score_Scalar(s)) => s->S->Ok->Some
}
let twoDists = (esti': t, answ': t): result<
(PointSetTypes.pointSetDist, PointSetTypes.pointSetDist),
error,
> => E.R.merge(toPointSetFn(esti', ()), toPointSetFn(answ', ()))
> => E.R.merge(toPointSetFn(esti'), toPointSetFn(answ'))
switch (esti, answ, prior') {
| (Score_Dist(esti'), Score_Dist(answ'), None) =>
twoDists(esti', answ')->E.R.bind(((esti'', answ'')) =>
{estimate: esti'', answer: answ'', prior: None}
->PointSetDist_Scoring.DistEstimateDistAnswer
->Ok
twoDists(esti', answ')->E.R2.fmap(((esti'', answ'')) =>
{estimate: esti'', answer: answ'', prior: None}->PointSetDist_Scoring.DistEstimateDistAnswer
)
| (Score_Dist(esti'), Score_Dist(answ'), Some(Ok(D(prior'')))) =>
twoDists(esti', answ')->E.R.bind(((esti'', answ'')) =>
@ -157,25 +158,25 @@ module Score = {
)
| (Score_Dist(_), _, Some(Ok(S(_)))) => DistributionTypes.Unreachable->Error
| (Score_Dist(esti'), Score_Scalar(answ'), None) =>
toPointSetFn(esti', ())->E.R.bind(esti'' =>
toPointSetFn(esti')->E.R.bind(esti'' =>
{estimate: esti'', answer: answ', prior: None}
->PointSetDist_Scoring.DistEstimateScalarAnswer
->Ok
)
| (Score_Dist(esti'), Score_Scalar(answ'), Some(Ok(D(prior'')))) =>
toPointSetFn(esti', ())->E.R.bind(esti'' =>
toPointSetFn(esti')->E.R.bind(esti'' =>
{estimate: esti'', answer: answ', prior: Some(prior'')}
->PointSetDist_Scoring.DistEstimateScalarAnswer
->Ok
)
| (Score_Scalar(esti'), Score_Dist(answ'), None) =>
toPointSetFn(answ', ())->E.R.bind(answ'' =>
toPointSetFn(answ')->E.R.bind(answ'' =>
{estimate: esti', answer: answ'', prior: None}
->PointSetDist_Scoring.ScalarEstimateDistAnswer
->Ok
)
| (Score_Scalar(esti'), Score_Dist(answ'), Some(Ok(S(prior'')))) =>
toPointSetFn(answ', ())->E.R.bind(answ'' =>
toPointSetFn(answ')->E.R.bind(answ'' =>
{estimate: esti', answer: answ'', prior: Some(prior'')}
->PointSetDist_Scoring.ScalarEstimateDistAnswer
->Ok

View File

@ -1,11 +1,11 @@
type t = PointSetTypes.pointSetDist
type pointSetDist = PointSetTypes.pointSetDist
type scalar = float
type abstractScoreArgs<'a, 'b> = {estimate: 'a, answer: 'b, prior: option<'a>}
type scoreArgs =
| DistEstimateDistAnswer(abstractScoreArgs<t, t>)
| DistEstimateScalarAnswer(abstractScoreArgs<t, scalar>)
| ScalarEstimateDistAnswer(abstractScoreArgs<scalar, t>)
| DistEstimateDistAnswer(abstractScoreArgs<pointSetDist, pointSetDist>)
| DistEstimateScalarAnswer(abstractScoreArgs<pointSetDist, scalar>)
| ScalarEstimateDistAnswer(abstractScoreArgs<scalar, pointSetDist>)
| ScalarEstimateScalarAnswer(abstractScoreArgs<scalar, scalar>)
let logFn = Js.Math.log // base e
@ -29,15 +29,17 @@ module WithDistAnswer = {
minusScaledLogOfQuot(~esti=estimateElement, ~answ=answerElement)
}
let rec sum = (~estimate: t, ~answer: t, ~combineFn, ~integrateFn, ~toMixedFn): result<
float,
Operation.Error.t,
> =>
switch (estimate, answer) {
| (Continuous(_), Continuous(_))
| (Discrete(_), Discrete(_)) =>
let sum = (
~estimate: pointSetDist,
~answer: pointSetDist,
~combineFn,
~integrateFn,
~toMixedFn,
): result<float, Operation.Error.t> => {
let combineAndIntegrate = (estimate, answer) =>
combineFn(integrand, estimate, answer)->E.R2.fmap(integrateFn)
| (_, _) =>
let getMixedSums = (estimate: pointSetDist, answer: pointSetDist) => {
let esti = estimate->toMixedFn
let answ = answer->toMixedFn
switch (
@ -53,29 +55,31 @@ module WithDistAnswer = {
Some(answDiscretePart),
) =>
E.R.merge(
sum(
~estimate=Discrete(estiDiscretePart),
~answer=Discrete(answDiscretePart),
~combineFn,
~integrateFn,
~toMixedFn,
combineAndIntegrate(
PointSetTypes.Discrete(estiDiscretePart),
PointSetTypes.Discrete(answDiscretePart),
),
sum(
~estimate=Continuous(estiContinuousPart),
~answer=Continuous(answContinuousPart),
~combineFn,
~integrateFn,
~toMixedFn,
),
)->E.R2.fmap(((discretePart, continuousPart)) => discretePart +. continuousPart)
combineAndIntegrate(Continuous(estiContinuousPart), Continuous(answContinuousPart)),
)
| (_, _, _, _) => `unreachable state`->Operation.Other->Error
}
}
switch (estimate, answer) {
| (Continuous(_), Continuous(_))
| (Discrete(_), Discrete(_)) =>
combineAndIntegrate(estimate, answer)
| (_, _) =>
getMixedSums(estimate, answer)->E.R2.fmap(((discretePart, continuousPart)) =>
discretePart +. continuousPart
)
}
}
let sumWithPrior = (
~estimate: t,
~answer: t,
~prior: t,
~estimate: pointSetDist,
~answer: pointSetDist,
~prior: pointSetDist,
~combineFn,
~integrateFn,
~toMixedFn,
@ -87,47 +91,51 @@ module WithDistAnswer = {
}
module WithScalarAnswer = {
let score' = (~estimatePdf: float => float, ~answer: float): result<float, Operation.Error.t> => {
let density = answer->estimatePdf
if density < 0.0 {
Operation.PdfInvalidError->Error
} else if density == 0.0 {
infinity->Ok
} else {
density->logFn->(x => -.x)->Ok
}
}
let scoreWithPrior' = (
~estimatePdf: float => float,
~answer: scalar,
~priorPdf: float => float,
): result<float, Operation.Error.t> => {
let numerator = answer->estimatePdf
let priorDensityOfAnswer = answer->priorPdf
if numerator < 0.0 || priorDensityOfAnswer < 0.0 {
Operation.PdfInvalidError->Error
} else if numerator == 0.0 || priorDensityOfAnswer == 0.0 {
infinity->Ok
} else {
minusScaledLogOfQuot(~esti=numerator, ~answ=priorDensityOfAnswer)
}
}
let sum = (mp: PointSetTypes.MixedPoint.t): float => mp.continuous +. mp.discrete
let score = (~estimate: t, ~answer: scalar): result<float, Operation.Error.t> => {
let score = (~estimate: pointSetDist, ~answer: scalar): result<float, Operation.Error.t> => {
let _score = (~estimatePdf: float => float, ~answer: float): result<
float,
Operation.Error.t,
> => {
let density = answer->estimatePdf
if density < 0.0 {
Operation.PdfInvalidError->Error
} else if density == 0.0 {
infinity->Ok
} else {
density->logFn->(x => -.x)->Ok
}
}
let estimatePdf = x =>
switch estimate {
| Continuous(esti) => Continuous.T.xToY(x, esti)->sum
| Discrete(esti) => Discrete.T.xToY(x, esti)->sum
| Mixed(esti) => Mixed.T.xToY(x, esti)->sum
}
score'(~estimatePdf, ~answer)
_score(~estimatePdf, ~answer)
}
let scoreWithPrior = (~estimate: t, ~answer: scalar, ~prior: t): result<
let scoreWithPrior = (~estimate: pointSetDist, ~answer: scalar, ~prior: pointSetDist): result<
float,
Operation.Error.t,
> => {
let _scoreWithPrior = (
~estimatePdf: float => float,
~answer: scalar,
~priorPdf: float => float,
): result<float, Operation.Error.t> => {
let numerator = answer->estimatePdf
let priorDensityOfAnswer = answer->priorPdf
if numerator < 0.0 || priorDensityOfAnswer < 0.0 {
Operation.PdfInvalidError->Error
} else if numerator == 0.0 || priorDensityOfAnswer == 0.0 {
infinity->Ok
} else {
minusScaledLogOfQuot(~esti=numerator, ~answ=priorDensityOfAnswer)
}
}
let estimatePdf = x =>
switch estimate {
| Continuous(esti) => Continuous.T.xToY(x, esti)->sum
@ -140,7 +148,7 @@ module WithScalarAnswer = {
| Discrete(prio) => Discrete.T.xToY(x, prio)->sum
| Mixed(prio) => Mixed.T.xToY(x, prio)->sum
}
scoreWithPrior'(~estimatePdf, ~answer, ~priorPdf)
_scoreWithPrior(~estimatePdf, ~answer, ~priorPdf)
}
}