Prior can't be a point

This commit is contained in:
Ozzie Gooen 2022-07-12 09:45:41 -07:00
parent e9968288fd
commit 652394f535
7 changed files with 23 additions and 38 deletions

View File

@ -146,7 +146,7 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
} }
| #ToDist(Normalize) => dist->GenericDist.normalize->Dist | #ToDist(Normalize) => dist->GenericDist.normalize->Dist
| #ToScore(LogScore(answer, prior)) => | #ToScore(LogScore(answer, prior)) =>
GenericDist.Score.logScore(~estimate=Score_Dist(dist), ~answer, ~prior) GenericDist.Score.logScore(~estimate=dist, ~answer, ~prior)
->E.R2.fmap(s => Float(s)) ->E.R2.fmap(s => Float(s))
->OutputLocal.fromResult ->OutputLocal.fromResult
| #ToBool(IsNormalized) => dist->GenericDist.isNormalized->Bool | #ToBool(IsNormalized) => dist->GenericDist.isNormalized->Bool

View File

@ -72,7 +72,7 @@ module Constructors: {
~env: env, ~env: env,
genericDist, genericDist,
genericDist, genericDist,
DistributionTypes.DistributionOperation.genericDistOrScalar, genericDist,
) => result<float, error> ) => result<float, error>
@genType @genType
let distEstimateScalarAnswer: (~env: env, genericDist, float) => result<float, error> let distEstimateScalarAnswer: (~env: env, genericDist, float) => result<float, error>
@ -81,7 +81,7 @@ module Constructors: {
~env: env, ~env: env,
genericDist, genericDist,
float, float,
DistributionTypes.DistributionOperation.genericDistOrScalar, genericDist,
) => result<float, error> ) => result<float, error>
} }
@genType @genType

View File

@ -100,7 +100,7 @@ module DistributionOperation = {
type genericDistOrScalar = Score_Dist(genericDist) | Score_Scalar(float) type genericDistOrScalar = Score_Dist(genericDist) | Score_Scalar(float)
type toScore = LogScore(genericDistOrScalar, option<genericDistOrScalar>) type toScore = LogScore(genericDistOrScalar, option<genericDist>)
type fromFloat = [ type fromFloat = [
| #ToFloat(toFloat) | #ToFloat(toFloat)

View File

@ -133,13 +133,11 @@ let toPointSet = (
module Score = { module Score = {
type genericDistOrScalar = DistributionTypes.DistributionOperation.genericDistOrScalar type genericDistOrScalar = DistributionTypes.DistributionOperation.genericDistOrScalar
type pointSet_ScoreDistOrScalar = PSDist(PointSetTypes.pointSetDist) | PSScalar(float)
let argsMake = ( let argsMake = (~esti: t, ~answ: genericDistOrScalar, ~prior: option<t>): result<
~esti: genericDistOrScalar, PointSetDist_Scoring.scoreArgs,
~answ: genericDistOrScalar, error,
~prior: option<genericDistOrScalar>, > => {
): result<PointSetDist_Scoring.scoreArgs, error> => {
let toPointSetFn = t => let toPointSetFn = t =>
toPointSet( toPointSet(
t, t,
@ -148,21 +146,20 @@ module Score = {
~xSelection=#ByWeight, ~xSelection=#ByWeight,
(), (),
) )
let prior': option<result<pointSet_ScoreDistOrScalar, error>> = switch prior { let prior': option<result<PointSetTypes.pointSetDist, error>> = switch prior {
| None => None | None => None
| Some(Score_Dist(d)) => toPointSetFn(d)->E.R.bind(x => x->PSDist->Ok)->Some | Some(d) => toPointSetFn(d)->Some
| Some(Score_Scalar(s)) => s->PSScalar->Ok->Some
} }
let twoDists = (~toPointSetFn, esti': t, answ': t): result< let twoDists = (~toPointSetFn, esti': t, answ': t): result<
(PointSetTypes.pointSetDist, PointSetTypes.pointSetDist), (PointSetTypes.pointSetDist, PointSetTypes.pointSetDist),
error, error,
> => E.R.merge(toPointSetFn(esti'), toPointSetFn(answ')) > => E.R.merge(toPointSetFn(esti'), toPointSetFn(answ'))
switch (esti, answ, prior') { switch (esti, answ, prior') {
| (Score_Dist(esti'), Score_Dist(answ'), None) => | (esti', Score_Dist(answ'), None) =>
twoDists(~toPointSetFn, esti', answ')->E.R2.fmap(((esti'', answ'')) => twoDists(~toPointSetFn, esti', answ')->E.R2.fmap(((esti'', answ'')) =>
{estimate: esti'', answer: answ'', prior: None}->PointSetDist_Scoring.DistAnswer {estimate: esti'', answer: answ'', prior: None}->PointSetDist_Scoring.DistAnswer
) )
| (Score_Dist(esti'), Score_Dist(answ'), Some(Ok(PSDist(prior'')))) => | (esti', Score_Dist(answ'), Some(Ok(prior''))) =>
twoDists(~toPointSetFn, esti', answ')->E.R2.fmap(((esti'', answ'')) => twoDists(~toPointSetFn, esti', answ')->E.R2.fmap(((esti'', answ'')) =>
{ {
estimate: esti'', estimate: esti'',
@ -170,8 +167,7 @@ module Score = {
prior: Some(prior''), prior: Some(prior''),
}->PointSetDist_Scoring.DistAnswer }->PointSetDist_Scoring.DistAnswer
) )
| (Score_Dist(_), _, Some(Ok(PSScalar(_)))) => DistributionTypes.Unreachable->Error | (esti', Score_Scalar(answ'), None) =>
| (Score_Dist(esti'), Score_Scalar(answ'), None) =>
toPointSetFn(esti')->E.R2.fmap(esti'' => toPointSetFn(esti')->E.R2.fmap(esti'' =>
{ {
estimate: esti'', estimate: esti'',
@ -179,7 +175,7 @@ module Score = {
prior: None, prior: None,
}->PointSetDist_Scoring.ScalarAnswer }->PointSetDist_Scoring.ScalarAnswer
) )
| (Score_Dist(esti'), Score_Scalar(answ'), Some(Ok(PSDist(prior'')))) => | (esti', Score_Scalar(answ'), Some(Ok(prior''))) =>
toPointSetFn(esti')->E.R2.fmap(esti'' => toPointSetFn(esti')->E.R2.fmap(esti'' =>
{ {
estimate: esti'', estimate: esti'',
@ -187,19 +183,14 @@ module Score = {
prior: Some(prior''), prior: Some(prior''),
}->PointSetDist_Scoring.ScalarAnswer }->PointSetDist_Scoring.ScalarAnswer
) )
| (Score_Scalar(_), Score_Dist(_), None) => NotYetImplemented->Error
| (Score_Scalar(_), Score_Dist(_), Some(Ok(PSScalar(_)))) => NotYetImplemented->Error
| (Score_Scalar(_), _, Some(Ok(PSDist(_)))) => DistributionTypes.Unreachable->Error
| (Score_Scalar(_), Score_Scalar(_), _) => NotYetImplemented->Error
| (_, _, Some(Error(err))) => err->Error | (_, _, Some(Error(err))) => err->Error
} }
} }
let logScore = ( let logScore = (~estimate: t, ~answer: genericDistOrScalar, ~prior: option<t>): result<
~estimate: genericDistOrScalar, float,
~answer: genericDistOrScalar, error,
~prior: option<genericDistOrScalar>, > =>
): result<float, error> =>
argsMake(~esti=estimate, ~answ=answer, ~prior)->E.R.bind(x => argsMake(~esti=estimate, ~answ=answer, ~prior)->E.R.bind(x =>
x->PointSetDist.logScore->E.R2.errMap(y => DistributionTypes.OperationError(y)) x->PointSetDist.logScore->E.R2.errMap(y => DistributionTypes.OperationError(y))
) )

View File

@ -26,9 +26,9 @@ let toFloatOperation: (
module Score: { module Score: {
let logScore: ( let logScore: (
~estimate: DistributionTypes.DistributionOperation.genericDistOrScalar, ~estimate: t,
~answer: DistributionTypes.DistributionOperation.genericDistOrScalar, ~answer: DistributionTypes.DistributionOperation.genericDistOrScalar,
~prior: option<DistributionTypes.DistributionOperation.genericDistOrScalar>, ~prior: option<t>,
) => result<float, error> ) => result<float, error>
} }

View File

@ -19,7 +19,7 @@ module WithDistAnswer = {
float, float,
Operation.Error.t, Operation.Error.t,
> => > =>
// We decided that negative infinity, not an error at answerElement = 0.0, is a desirable value. // We decided that 0.0, not an error at answerElement = 0.0, is a desirable value.
if answerElement == 0.0 { if answerElement == 0.0 {
Ok(0.0) Ok(0.0)
} else if estimateElement == 0.0 { } else if estimateElement == 0.0 {

View File

@ -231,10 +231,7 @@ let dispatchToGenericOutput = (call: IEV.functionCall, env: DistributionOperatio
DistributionOperation.run( DistributionOperation.run(
FromDist( FromDist(
#ToScore( #ToScore(
LogScore( LogScore(DistributionTypes.DistributionOperation.Score_Dist(answer), Some(prior)),
DistributionTypes.DistributionOperation.Score_Dist(answer),
Some(DistributionTypes.DistributionOperation.Score_Dist(prior)),
),
), ),
prediction, prediction,
), ),
@ -256,10 +253,7 @@ let dispatchToGenericOutput = (call: IEV.functionCall, env: DistributionOperatio
DistributionOperation.run( DistributionOperation.run(
FromDist( FromDist(
#ToScore( #ToScore(
LogScore( LogScore(DistributionTypes.DistributionOperation.Score_Scalar(answer), prior->Some),
DistributionTypes.DistributionOperation.Score_Scalar(answer),
DistributionTypes.DistributionOperation.Score_Dist(prior)->Some,
),
), ),
prediction, prediction,
), ),