diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res index e63490ea..02163092 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res @@ -183,18 +183,8 @@ module Score = { ->PointSetDist_Scoring.DistEstimateScalarAnswer ->Ok ) - | (Score_Scalar(esti'), Score_Dist(answ'), None) => - toPointSetFn(answ')->E.R.bind(answ'' => - {estimate: esti', answer: answ'', prior: None} - ->PointSetDist_Scoring.ScalarEstimateDistAnswer - ->Ok - ) - | (Score_Scalar(esti'), Score_Dist(answ'), Some(Ok(PSScalar(prior'')))) => - toPointSetFn(answ')->E.R.bind(answ'' => - {estimate: esti', answer: answ'', prior: Some(prior'')} - ->PointSetDist_Scoring.ScalarEstimateDistAnswer - ->Ok - ) + | (Score_Scalar(_), Score_Dist(_), None) => NotYetImplemented->Error + | (Score_Scalar(_), Score_Dist(_), Some(Ok(PSScalar(_)))) => NotYetImplemented->Error | (Score_Scalar(_), _, Some(Ok(PSDist(_)))) => DistributionTypes.Unreachable->Error | (Score_Scalar(esti'), Score_Scalar(answ'), None) => {estimate: esti', answer: answ', prior: None} diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res index a5468b26..50f8faa9 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res @@ -1,11 +1,11 @@ type pointSetDist = PointSetTypes.pointSetDist type scalar = float +type score = float type abstractScoreArgs<'a, 'b> = {estimate: 'a, answer: 'b, prior: option<'a>} type scoreArgs = | DistEstimateDistAnswer(abstractScoreArgs) | DistEstimateScalarAnswer(abstractScoreArgs) - | ScalarEstimateDistAnswer(abstractScoreArgs) | ScalarEstimateScalarAnswer(abstractScoreArgs) let logFn = Js.Math.log // base e @@ -35,7 +35,7 @@ module WithDistAnswer = { ~combineFn, ~integrateFn, ~toMixedFn, - ): result => { + ): result => { let combineAndIntegrate = (estimate, answer) => combineFn(integrand, estimate, answer)->E.R2.fmap(integrateFn) @@ -83,7 +83,7 @@ module WithDistAnswer = { ~combineFn, ~integrateFn, ~toMixedFn, - ): result => { + ): result => { let kl1 = sum(~estimate, ~answer, ~combineFn, ~integrateFn, ~toMixedFn) let kl2 = sum(~estimate=prior, ~answer, ~combineFn, ~integrateFn, ~toMixedFn) E.R.merge(kl1, kl2)->E.R2.fmap(((kl1', kl2')) => kl1' -. kl2') @@ -92,9 +92,9 @@ module WithDistAnswer = { module WithScalarAnswer = { let sum = (mp: PointSetTypes.MixedPoint.t): float => mp.continuous +. mp.discrete - let score = (~estimate: pointSetDist, ~answer: scalar): result => { + let score = (~estimate: pointSetDist, ~answer: scalar): result => { let _score = (~estimatePdf: float => float, ~answer: float): result< - float, + score, Operation.Error.t, > => { let density = answer->estimatePdf @@ -117,7 +117,7 @@ module WithScalarAnswer = { } let scoreWithPrior = (~estimate: pointSetDist, ~answer: scalar, ~prior: pointSetDist): result< - float, + score, Operation.Error.t, > => { let _scoreWithPrior = ( @@ -178,7 +178,7 @@ let twoGenericDistsToTwoPointSetDists = (~toPointSetFn, estimate, answer): resul > => E.R.merge(toPointSetFn(estimate, ()), toPointSetFn(answer, ())) let logScore = (args: scoreArgs, ~combineFn, ~integrateFn, ~toMixedFn): result< - float, + score, Operation.Error.t, > => switch args { @@ -190,7 +190,6 @@ let logScore = (args: scoreArgs, ~combineFn, ~integrateFn, ~toMixedFn): result< WithScalarAnswer.score(~estimate, ~answer) | DistEstimateScalarAnswer({estimate, answer, prior: Some(prior)}) => WithScalarAnswer.scoreWithPrior(~estimate, ~answer, ~prior) - | ScalarEstimateDistAnswer(_) => Operation.NotYetImplemented->Error | ScalarEstimateScalarAnswer({estimate, answer, prior: None}) => TwoScalars.score(~estimate, ~answer) | ScalarEstimateScalarAnswer({estimate, answer, prior: Some(prior)}) =>