Minor refactors
This commit is contained in:
parent
6a9179d4b8
commit
239abbdcf8
|
@ -128,26 +128,27 @@ module Score = {
|
||||||
~answ: scoreDistOrScalar,
|
~answ: scoreDistOrScalar,
|
||||||
~prior: option<scoreDistOrScalar>,
|
~prior: option<scoreDistOrScalar>,
|
||||||
): result<PointSetDist_Scoring.scoreArgs, error> => {
|
): result<PointSetDist_Scoring.scoreArgs, error> => {
|
||||||
let toPointSetFn = toPointSet(
|
let toPointSetFn = t =>
|
||||||
|
toPointSet(
|
||||||
|
t,
|
||||||
~xyPointLength=MagicNumbers.Environment.defaultXYPointLength,
|
~xyPointLength=MagicNumbers.Environment.defaultXYPointLength,
|
||||||
~sampleCount=MagicNumbers.Environment.defaultSampleCount,
|
~sampleCount=MagicNumbers.Environment.defaultSampleCount,
|
||||||
~xSelection=#ByWeight,
|
~xSelection=#ByWeight,
|
||||||
|
(),
|
||||||
)
|
)
|
||||||
let prior': option<result<pointSet_ScoreDistOrScalar, error>> = switch prior {
|
let prior': option<result<pointSet_ScoreDistOrScalar, error>> = switch prior {
|
||||||
| None => None
|
| None => None
|
||||||
| Some(Score_Dist(d)) => toPointSetFn(d, ())->E.R.bind(x => x->D->Ok)->Some
|
| Some(Score_Dist(d)) => toPointSetFn(d)->E.R.bind(x => x->D->Ok)->Some
|
||||||
| Some(Score_Scalar(s)) => s->S->Ok->Some
|
| Some(Score_Scalar(s)) => s->S->Ok->Some
|
||||||
}
|
}
|
||||||
let twoDists = (esti': t, answ': t): result<
|
let twoDists = (esti': t, answ': t): result<
|
||||||
(PointSetTypes.pointSetDist, PointSetTypes.pointSetDist),
|
(PointSetTypes.pointSetDist, PointSetTypes.pointSetDist),
|
||||||
error,
|
error,
|
||||||
> => E.R.merge(toPointSetFn(esti', ()), toPointSetFn(answ', ()))
|
> => E.R.merge(toPointSetFn(esti'), toPointSetFn(answ'))
|
||||||
switch (esti, answ, prior') {
|
switch (esti, answ, prior') {
|
||||||
| (Score_Dist(esti'), Score_Dist(answ'), None) =>
|
| (Score_Dist(esti'), Score_Dist(answ'), None) =>
|
||||||
twoDists(esti', answ')->E.R.bind(((esti'', answ'')) =>
|
twoDists(esti', answ')->E.R2.fmap(((esti'', answ'')) =>
|
||||||
{estimate: esti'', answer: answ'', prior: None}
|
{estimate: esti'', answer: answ'', prior: None}->PointSetDist_Scoring.DistEstimateDistAnswer
|
||||||
->PointSetDist_Scoring.DistEstimateDistAnswer
|
|
||||||
->Ok
|
|
||||||
)
|
)
|
||||||
| (Score_Dist(esti'), Score_Dist(answ'), Some(Ok(D(prior'')))) =>
|
| (Score_Dist(esti'), Score_Dist(answ'), Some(Ok(D(prior'')))) =>
|
||||||
twoDists(esti', answ')->E.R.bind(((esti'', answ'')) =>
|
twoDists(esti', answ')->E.R.bind(((esti'', answ'')) =>
|
||||||
|
@ -157,25 +158,25 @@ module Score = {
|
||||||
)
|
)
|
||||||
| (Score_Dist(_), _, Some(Ok(S(_)))) => DistributionTypes.Unreachable->Error
|
| (Score_Dist(_), _, Some(Ok(S(_)))) => DistributionTypes.Unreachable->Error
|
||||||
| (Score_Dist(esti'), Score_Scalar(answ'), None) =>
|
| (Score_Dist(esti'), Score_Scalar(answ'), None) =>
|
||||||
toPointSetFn(esti', ())->E.R.bind(esti'' =>
|
toPointSetFn(esti')->E.R.bind(esti'' =>
|
||||||
{estimate: esti'', answer: answ', prior: None}
|
{estimate: esti'', answer: answ', prior: None}
|
||||||
->PointSetDist_Scoring.DistEstimateScalarAnswer
|
->PointSetDist_Scoring.DistEstimateScalarAnswer
|
||||||
->Ok
|
->Ok
|
||||||
)
|
)
|
||||||
| (Score_Dist(esti'), Score_Scalar(answ'), Some(Ok(D(prior'')))) =>
|
| (Score_Dist(esti'), Score_Scalar(answ'), Some(Ok(D(prior'')))) =>
|
||||||
toPointSetFn(esti', ())->E.R.bind(esti'' =>
|
toPointSetFn(esti')->E.R.bind(esti'' =>
|
||||||
{estimate: esti'', answer: answ', prior: Some(prior'')}
|
{estimate: esti'', answer: answ', prior: Some(prior'')}
|
||||||
->PointSetDist_Scoring.DistEstimateScalarAnswer
|
->PointSetDist_Scoring.DistEstimateScalarAnswer
|
||||||
->Ok
|
->Ok
|
||||||
)
|
)
|
||||||
| (Score_Scalar(esti'), Score_Dist(answ'), None) =>
|
| (Score_Scalar(esti'), Score_Dist(answ'), None) =>
|
||||||
toPointSetFn(answ', ())->E.R.bind(answ'' =>
|
toPointSetFn(answ')->E.R.bind(answ'' =>
|
||||||
{estimate: esti', answer: answ'', prior: None}
|
{estimate: esti', answer: answ'', prior: None}
|
||||||
->PointSetDist_Scoring.ScalarEstimateDistAnswer
|
->PointSetDist_Scoring.ScalarEstimateDistAnswer
|
||||||
->Ok
|
->Ok
|
||||||
)
|
)
|
||||||
| (Score_Scalar(esti'), Score_Dist(answ'), Some(Ok(S(prior'')))) =>
|
| (Score_Scalar(esti'), Score_Dist(answ'), Some(Ok(S(prior'')))) =>
|
||||||
toPointSetFn(answ', ())->E.R.bind(answ'' =>
|
toPointSetFn(answ')->E.R.bind(answ'' =>
|
||||||
{estimate: esti', answer: answ'', prior: Some(prior'')}
|
{estimate: esti', answer: answ'', prior: Some(prior'')}
|
||||||
->PointSetDist_Scoring.ScalarEstimateDistAnswer
|
->PointSetDist_Scoring.ScalarEstimateDistAnswer
|
||||||
->Ok
|
->Ok
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
type t = PointSetTypes.pointSetDist
|
type pointSetDist = PointSetTypes.pointSetDist
|
||||||
|
|
||||||
type scalar = float
|
type scalar = float
|
||||||
type abstractScoreArgs<'a, 'b> = {estimate: 'a, answer: 'b, prior: option<'a>}
|
type abstractScoreArgs<'a, 'b> = {estimate: 'a, answer: 'b, prior: option<'a>}
|
||||||
type scoreArgs =
|
type scoreArgs =
|
||||||
| DistEstimateDistAnswer(abstractScoreArgs<t, t>)
|
| DistEstimateDistAnswer(abstractScoreArgs<pointSetDist, pointSetDist>)
|
||||||
| DistEstimateScalarAnswer(abstractScoreArgs<t, scalar>)
|
| DistEstimateScalarAnswer(abstractScoreArgs<pointSetDist, scalar>)
|
||||||
| ScalarEstimateDistAnswer(abstractScoreArgs<scalar, t>)
|
| ScalarEstimateDistAnswer(abstractScoreArgs<scalar, pointSetDist>)
|
||||||
| ScalarEstimateScalarAnswer(abstractScoreArgs<scalar, scalar>)
|
| ScalarEstimateScalarAnswer(abstractScoreArgs<scalar, scalar>)
|
||||||
|
|
||||||
let logFn = Js.Math.log // base e
|
let logFn = Js.Math.log // base e
|
||||||
|
@ -29,15 +29,17 @@ module WithDistAnswer = {
|
||||||
minusScaledLogOfQuot(~esti=estimateElement, ~answ=answerElement)
|
minusScaledLogOfQuot(~esti=estimateElement, ~answ=answerElement)
|
||||||
}
|
}
|
||||||
|
|
||||||
let rec sum = (~estimate: t, ~answer: t, ~combineFn, ~integrateFn, ~toMixedFn): result<
|
let sum = (
|
||||||
float,
|
~estimate: pointSetDist,
|
||||||
Operation.Error.t,
|
~answer: pointSetDist,
|
||||||
> =>
|
~combineFn,
|
||||||
switch (estimate, answer) {
|
~integrateFn,
|
||||||
| (Continuous(_), Continuous(_))
|
~toMixedFn,
|
||||||
| (Discrete(_), Discrete(_)) =>
|
): result<float, Operation.Error.t> => {
|
||||||
|
let combineAndIntegrate = (estimate, answer) =>
|
||||||
combineFn(integrand, estimate, answer)->E.R2.fmap(integrateFn)
|
combineFn(integrand, estimate, answer)->E.R2.fmap(integrateFn)
|
||||||
| (_, _) =>
|
|
||||||
|
let getMixedSums = (estimate: pointSetDist, answer: pointSetDist) => {
|
||||||
let esti = estimate->toMixedFn
|
let esti = estimate->toMixedFn
|
||||||
let answ = answer->toMixedFn
|
let answ = answer->toMixedFn
|
||||||
switch (
|
switch (
|
||||||
|
@ -53,29 +55,31 @@ module WithDistAnswer = {
|
||||||
Some(answDiscretePart),
|
Some(answDiscretePart),
|
||||||
) =>
|
) =>
|
||||||
E.R.merge(
|
E.R.merge(
|
||||||
sum(
|
combineAndIntegrate(
|
||||||
~estimate=Discrete(estiDiscretePart),
|
PointSetTypes.Discrete(estiDiscretePart),
|
||||||
~answer=Discrete(answDiscretePart),
|
PointSetTypes.Discrete(answDiscretePart),
|
||||||
~combineFn,
|
|
||||||
~integrateFn,
|
|
||||||
~toMixedFn,
|
|
||||||
),
|
),
|
||||||
sum(
|
combineAndIntegrate(Continuous(estiContinuousPart), Continuous(answContinuousPart)),
|
||||||
~estimate=Continuous(estiContinuousPart),
|
)
|
||||||
~answer=Continuous(answContinuousPart),
|
|
||||||
~combineFn,
|
|
||||||
~integrateFn,
|
|
||||||
~toMixedFn,
|
|
||||||
),
|
|
||||||
)->E.R2.fmap(((discretePart, continuousPart)) => discretePart +. continuousPart)
|
|
||||||
| (_, _, _, _) => `unreachable state`->Operation.Other->Error
|
| (_, _, _, _) => `unreachable state`->Operation.Other->Error
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
switch (estimate, answer) {
|
||||||
|
| (Continuous(_), Continuous(_))
|
||||||
|
| (Discrete(_), Discrete(_)) =>
|
||||||
|
combineAndIntegrate(estimate, answer)
|
||||||
|
| (_, _) =>
|
||||||
|
getMixedSums(estimate, answer)->E.R2.fmap(((discretePart, continuousPart)) =>
|
||||||
|
discretePart +. continuousPart
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let sumWithPrior = (
|
let sumWithPrior = (
|
||||||
~estimate: t,
|
~estimate: pointSetDist,
|
||||||
~answer: t,
|
~answer: pointSetDist,
|
||||||
~prior: t,
|
~prior: pointSetDist,
|
||||||
~combineFn,
|
~combineFn,
|
||||||
~integrateFn,
|
~integrateFn,
|
||||||
~toMixedFn,
|
~toMixedFn,
|
||||||
|
@ -87,7 +91,12 @@ module WithDistAnswer = {
|
||||||
}
|
}
|
||||||
|
|
||||||
module WithScalarAnswer = {
|
module WithScalarAnswer = {
|
||||||
let score' = (~estimatePdf: float => float, ~answer: float): result<float, Operation.Error.t> => {
|
let sum = (mp: PointSetTypes.MixedPoint.t): float => mp.continuous +. mp.discrete
|
||||||
|
let score = (~estimate: pointSetDist, ~answer: scalar): result<float, Operation.Error.t> => {
|
||||||
|
let _score = (~estimatePdf: float => float, ~answer: float): result<
|
||||||
|
float,
|
||||||
|
Operation.Error.t,
|
||||||
|
> => {
|
||||||
let density = answer->estimatePdf
|
let density = answer->estimatePdf
|
||||||
if density < 0.0 {
|
if density < 0.0 {
|
||||||
Operation.PdfInvalidError->Error
|
Operation.PdfInvalidError->Error
|
||||||
|
@ -97,7 +106,21 @@ module WithScalarAnswer = {
|
||||||
density->logFn->(x => -.x)->Ok
|
density->logFn->(x => -.x)->Ok
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let scoreWithPrior' = (
|
|
||||||
|
let estimatePdf = x =>
|
||||||
|
switch estimate {
|
||||||
|
| Continuous(esti) => Continuous.T.xToY(x, esti)->sum
|
||||||
|
| Discrete(esti) => Discrete.T.xToY(x, esti)->sum
|
||||||
|
| Mixed(esti) => Mixed.T.xToY(x, esti)->sum
|
||||||
|
}
|
||||||
|
_score(~estimatePdf, ~answer)
|
||||||
|
}
|
||||||
|
|
||||||
|
let scoreWithPrior = (~estimate: pointSetDist, ~answer: scalar, ~prior: pointSetDist): result<
|
||||||
|
float,
|
||||||
|
Operation.Error.t,
|
||||||
|
> => {
|
||||||
|
let _scoreWithPrior = (
|
||||||
~estimatePdf: float => float,
|
~estimatePdf: float => float,
|
||||||
~answer: scalar,
|
~answer: scalar,
|
||||||
~priorPdf: float => float,
|
~priorPdf: float => float,
|
||||||
|
@ -113,21 +136,6 @@ module WithScalarAnswer = {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let sum = (mp: PointSetTypes.MixedPoint.t): float => mp.continuous +. mp.discrete
|
|
||||||
let score = (~estimate: t, ~answer: scalar): result<float, Operation.Error.t> => {
|
|
||||||
let estimatePdf = x =>
|
|
||||||
switch estimate {
|
|
||||||
| Continuous(esti) => Continuous.T.xToY(x, esti)->sum
|
|
||||||
| Discrete(esti) => Discrete.T.xToY(x, esti)->sum
|
|
||||||
| Mixed(esti) => Mixed.T.xToY(x, esti)->sum
|
|
||||||
}
|
|
||||||
|
|
||||||
score'(~estimatePdf, ~answer)
|
|
||||||
}
|
|
||||||
let scoreWithPrior = (~estimate: t, ~answer: scalar, ~prior: t): result<
|
|
||||||
float,
|
|
||||||
Operation.Error.t,
|
|
||||||
> => {
|
|
||||||
let estimatePdf = x =>
|
let estimatePdf = x =>
|
||||||
switch estimate {
|
switch estimate {
|
||||||
| Continuous(esti) => Continuous.T.xToY(x, esti)->sum
|
| Continuous(esti) => Continuous.T.xToY(x, esti)->sum
|
||||||
|
@ -140,7 +148,7 @@ module WithScalarAnswer = {
|
||||||
| Discrete(prio) => Discrete.T.xToY(x, prio)->sum
|
| Discrete(prio) => Discrete.T.xToY(x, prio)->sum
|
||||||
| Mixed(prio) => Mixed.T.xToY(x, prio)->sum
|
| Mixed(prio) => Mixed.T.xToY(x, prio)->sum
|
||||||
}
|
}
|
||||||
scoreWithPrior'(~estimatePdf, ~answer, ~priorPdf)
|
_scoreWithPrior(~estimatePdf, ~answer, ~priorPdf)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user