Minor refactors

This commit is contained in:
Ozzie Gooen 2022-05-25 18:10:05 -04:00
parent 6a9179d4b8
commit 239abbdcf8
2 changed files with 84 additions and 75 deletions

View File

@ -128,26 +128,27 @@ module Score = {
~answ: scoreDistOrScalar, ~answ: scoreDistOrScalar,
~prior: option<scoreDistOrScalar>, ~prior: option<scoreDistOrScalar>,
): result<PointSetDist_Scoring.scoreArgs, error> => { ): result<PointSetDist_Scoring.scoreArgs, error> => {
let toPointSetFn = toPointSet( let toPointSetFn = t =>
toPointSet(
t,
~xyPointLength=MagicNumbers.Environment.defaultXYPointLength, ~xyPointLength=MagicNumbers.Environment.defaultXYPointLength,
~sampleCount=MagicNumbers.Environment.defaultSampleCount, ~sampleCount=MagicNumbers.Environment.defaultSampleCount,
~xSelection=#ByWeight, ~xSelection=#ByWeight,
(),
) )
let prior': option<result<pointSet_ScoreDistOrScalar, error>> = switch prior { let prior': option<result<pointSet_ScoreDistOrScalar, error>> = switch prior {
| None => None | None => None
| Some(Score_Dist(d)) => toPointSetFn(d, ())->E.R.bind(x => x->D->Ok)->Some | Some(Score_Dist(d)) => toPointSetFn(d)->E.R.bind(x => x->D->Ok)->Some
| Some(Score_Scalar(s)) => s->S->Ok->Some | Some(Score_Scalar(s)) => s->S->Ok->Some
} }
let twoDists = (esti': t, answ': t): result< let twoDists = (esti': t, answ': t): result<
(PointSetTypes.pointSetDist, PointSetTypes.pointSetDist), (PointSetTypes.pointSetDist, PointSetTypes.pointSetDist),
error, error,
> => E.R.merge(toPointSetFn(esti', ()), toPointSetFn(answ', ())) > => E.R.merge(toPointSetFn(esti'), toPointSetFn(answ'))
switch (esti, answ, prior') { switch (esti, answ, prior') {
| (Score_Dist(esti'), Score_Dist(answ'), None) => | (Score_Dist(esti'), Score_Dist(answ'), None) =>
twoDists(esti', answ')->E.R.bind(((esti'', answ'')) => twoDists(esti', answ')->E.R2.fmap(((esti'', answ'')) =>
{estimate: esti'', answer: answ'', prior: None} {estimate: esti'', answer: answ'', prior: None}->PointSetDist_Scoring.DistEstimateDistAnswer
->PointSetDist_Scoring.DistEstimateDistAnswer
->Ok
) )
| (Score_Dist(esti'), Score_Dist(answ'), Some(Ok(D(prior'')))) => | (Score_Dist(esti'), Score_Dist(answ'), Some(Ok(D(prior'')))) =>
twoDists(esti', answ')->E.R.bind(((esti'', answ'')) => twoDists(esti', answ')->E.R.bind(((esti'', answ'')) =>
@ -157,25 +158,25 @@ module Score = {
) )
| (Score_Dist(_), _, Some(Ok(S(_)))) => DistributionTypes.Unreachable->Error | (Score_Dist(_), _, Some(Ok(S(_)))) => DistributionTypes.Unreachable->Error
| (Score_Dist(esti'), Score_Scalar(answ'), None) => | (Score_Dist(esti'), Score_Scalar(answ'), None) =>
toPointSetFn(esti', ())->E.R.bind(esti'' => toPointSetFn(esti')->E.R.bind(esti'' =>
{estimate: esti'', answer: answ', prior: None} {estimate: esti'', answer: answ', prior: None}
->PointSetDist_Scoring.DistEstimateScalarAnswer ->PointSetDist_Scoring.DistEstimateScalarAnswer
->Ok ->Ok
) )
| (Score_Dist(esti'), Score_Scalar(answ'), Some(Ok(D(prior'')))) => | (Score_Dist(esti'), Score_Scalar(answ'), Some(Ok(D(prior'')))) =>
toPointSetFn(esti', ())->E.R.bind(esti'' => toPointSetFn(esti')->E.R.bind(esti'' =>
{estimate: esti'', answer: answ', prior: Some(prior'')} {estimate: esti'', answer: answ', prior: Some(prior'')}
->PointSetDist_Scoring.DistEstimateScalarAnswer ->PointSetDist_Scoring.DistEstimateScalarAnswer
->Ok ->Ok
) )
| (Score_Scalar(esti'), Score_Dist(answ'), None) => | (Score_Scalar(esti'), Score_Dist(answ'), None) =>
toPointSetFn(answ', ())->E.R.bind(answ'' => toPointSetFn(answ')->E.R.bind(answ'' =>
{estimate: esti', answer: answ'', prior: None} {estimate: esti', answer: answ'', prior: None}
->PointSetDist_Scoring.ScalarEstimateDistAnswer ->PointSetDist_Scoring.ScalarEstimateDistAnswer
->Ok ->Ok
) )
| (Score_Scalar(esti'), Score_Dist(answ'), Some(Ok(S(prior'')))) => | (Score_Scalar(esti'), Score_Dist(answ'), Some(Ok(S(prior'')))) =>
toPointSetFn(answ', ())->E.R.bind(answ'' => toPointSetFn(answ')->E.R.bind(answ'' =>
{estimate: esti', answer: answ'', prior: Some(prior'')} {estimate: esti', answer: answ'', prior: Some(prior'')}
->PointSetDist_Scoring.ScalarEstimateDistAnswer ->PointSetDist_Scoring.ScalarEstimateDistAnswer
->Ok ->Ok

View File

@ -1,11 +1,11 @@
type t = PointSetTypes.pointSetDist type pointSetDist = PointSetTypes.pointSetDist
type scalar = float type scalar = float
type abstractScoreArgs<'a, 'b> = {estimate: 'a, answer: 'b, prior: option<'a>} type abstractScoreArgs<'a, 'b> = {estimate: 'a, answer: 'b, prior: option<'a>}
type scoreArgs = type scoreArgs =
| DistEstimateDistAnswer(abstractScoreArgs<t, t>) | DistEstimateDistAnswer(abstractScoreArgs<pointSetDist, pointSetDist>)
| DistEstimateScalarAnswer(abstractScoreArgs<t, scalar>) | DistEstimateScalarAnswer(abstractScoreArgs<pointSetDist, scalar>)
| ScalarEstimateDistAnswer(abstractScoreArgs<scalar, t>) | ScalarEstimateDistAnswer(abstractScoreArgs<scalar, pointSetDist>)
| ScalarEstimateScalarAnswer(abstractScoreArgs<scalar, scalar>) | ScalarEstimateScalarAnswer(abstractScoreArgs<scalar, scalar>)
let logFn = Js.Math.log // base e let logFn = Js.Math.log // base e
@ -29,15 +29,17 @@ module WithDistAnswer = {
minusScaledLogOfQuot(~esti=estimateElement, ~answ=answerElement) minusScaledLogOfQuot(~esti=estimateElement, ~answ=answerElement)
} }
let rec sum = (~estimate: t, ~answer: t, ~combineFn, ~integrateFn, ~toMixedFn): result< let sum = (
float, ~estimate: pointSetDist,
Operation.Error.t, ~answer: pointSetDist,
> => ~combineFn,
switch (estimate, answer) { ~integrateFn,
| (Continuous(_), Continuous(_)) ~toMixedFn,
| (Discrete(_), Discrete(_)) => ): result<float, Operation.Error.t> => {
let combineAndIntegrate = (estimate, answer) =>
combineFn(integrand, estimate, answer)->E.R2.fmap(integrateFn) combineFn(integrand, estimate, answer)->E.R2.fmap(integrateFn)
| (_, _) =>
let getMixedSums = (estimate: pointSetDist, answer: pointSetDist) => {
let esti = estimate->toMixedFn let esti = estimate->toMixedFn
let answ = answer->toMixedFn let answ = answer->toMixedFn
switch ( switch (
@ -53,29 +55,31 @@ module WithDistAnswer = {
Some(answDiscretePart), Some(answDiscretePart),
) => ) =>
E.R.merge( E.R.merge(
sum( combineAndIntegrate(
~estimate=Discrete(estiDiscretePart), PointSetTypes.Discrete(estiDiscretePart),
~answer=Discrete(answDiscretePart), PointSetTypes.Discrete(answDiscretePart),
~combineFn,
~integrateFn,
~toMixedFn,
), ),
sum( combineAndIntegrate(Continuous(estiContinuousPart), Continuous(answContinuousPart)),
~estimate=Continuous(estiContinuousPart), )
~answer=Continuous(answContinuousPart),
~combineFn,
~integrateFn,
~toMixedFn,
),
)->E.R2.fmap(((discretePart, continuousPart)) => discretePart +. continuousPart)
| (_, _, _, _) => `unreachable state`->Operation.Other->Error | (_, _, _, _) => `unreachable state`->Operation.Other->Error
} }
} }
switch (estimate, answer) {
| (Continuous(_), Continuous(_))
| (Discrete(_), Discrete(_)) =>
combineAndIntegrate(estimate, answer)
| (_, _) =>
getMixedSums(estimate, answer)->E.R2.fmap(((discretePart, continuousPart)) =>
discretePart +. continuousPart
)
}
}
let sumWithPrior = ( let sumWithPrior = (
~estimate: t, ~estimate: pointSetDist,
~answer: t, ~answer: pointSetDist,
~prior: t, ~prior: pointSetDist,
~combineFn, ~combineFn,
~integrateFn, ~integrateFn,
~toMixedFn, ~toMixedFn,
@ -87,7 +91,12 @@ module WithDistAnswer = {
} }
module WithScalarAnswer = { module WithScalarAnswer = {
let score' = (~estimatePdf: float => float, ~answer: float): result<float, Operation.Error.t> => { let sum = (mp: PointSetTypes.MixedPoint.t): float => mp.continuous +. mp.discrete
let score = (~estimate: pointSetDist, ~answer: scalar): result<float, Operation.Error.t> => {
let _score = (~estimatePdf: float => float, ~answer: float): result<
float,
Operation.Error.t,
> => {
let density = answer->estimatePdf let density = answer->estimatePdf
if density < 0.0 { if density < 0.0 {
Operation.PdfInvalidError->Error Operation.PdfInvalidError->Error
@ -97,7 +106,21 @@ module WithScalarAnswer = {
density->logFn->(x => -.x)->Ok density->logFn->(x => -.x)->Ok
} }
} }
let scoreWithPrior' = (
let estimatePdf = x =>
switch estimate {
| Continuous(esti) => Continuous.T.xToY(x, esti)->sum
| Discrete(esti) => Discrete.T.xToY(x, esti)->sum
| Mixed(esti) => Mixed.T.xToY(x, esti)->sum
}
_score(~estimatePdf, ~answer)
}
let scoreWithPrior = (~estimate: pointSetDist, ~answer: scalar, ~prior: pointSetDist): result<
float,
Operation.Error.t,
> => {
let _scoreWithPrior = (
~estimatePdf: float => float, ~estimatePdf: float => float,
~answer: scalar, ~answer: scalar,
~priorPdf: float => float, ~priorPdf: float => float,
@ -113,21 +136,6 @@ module WithScalarAnswer = {
} }
} }
let sum = (mp: PointSetTypes.MixedPoint.t): float => mp.continuous +. mp.discrete
let score = (~estimate: t, ~answer: scalar): result<float, Operation.Error.t> => {
let estimatePdf = x =>
switch estimate {
| Continuous(esti) => Continuous.T.xToY(x, esti)->sum
| Discrete(esti) => Discrete.T.xToY(x, esti)->sum
| Mixed(esti) => Mixed.T.xToY(x, esti)->sum
}
score'(~estimatePdf, ~answer)
}
let scoreWithPrior = (~estimate: t, ~answer: scalar, ~prior: t): result<
float,
Operation.Error.t,
> => {
let estimatePdf = x => let estimatePdf = x =>
switch estimate { switch estimate {
| Continuous(esti) => Continuous.T.xToY(x, esti)->sum | Continuous(esti) => Continuous.T.xToY(x, esti)->sum
@ -140,7 +148,7 @@ module WithScalarAnswer = {
| Discrete(prio) => Discrete.T.xToY(x, prio)->sum | Discrete(prio) => Discrete.T.xToY(x, prio)->sum
| Mixed(prio) => Mixed.T.xToY(x, prio)->sum | Mixed(prio) => Mixed.T.xToY(x, prio)->sum
} }
scoreWithPrior'(~estimatePdf, ~answer, ~priorPdf) _scoreWithPrior(~estimatePdf, ~answer, ~priorPdf)
} }
} }