diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res index c9d46efb..71310156 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res @@ -306,9 +306,6 @@ module T = Dist({ // let klContinuousPart = Continuous.T.klDivergence(prediction.continuous, answer.continuous) // E.R.merge(klDiscretePart, klContinuousPart)->E.R2.fmap(t => fst(t) +. snd(t)) // } - // let logScoreWithPointResolution = (~prediction: t, ~answer: float, ~prior: option) => { - // Error(Operation.NotYetImplemented) - // } }) let combineAlgebraically = (op: Operation.convolutionOperation, t1: t, t2: t): t => { diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res index b302d83e..202397c5 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res @@ -29,8 +29,26 @@ module WithDistAnswer = { minusScaledLogOfQuot(~esti=estimateElement, ~answ=answerElement) } - let sum = (~estimate: t, ~answer: t, ~combineFn, ~integrateFn) => - combineFn(integrand, estimate, answer)->E.R2.fmap(integrateFn) + let rec sum = (~estimate: t, ~answer: t, ~combineFn, ~integrateFn) => + switch (estimate, answer) { + | ((Continuous(_) | Discrete(_)) as esti, (Continuous(_) | Discrete(_)) as answ) => + combineFn(integrand, esti, answ)->E.R2.fmap(integrateFn) + | (Mixed(esti), Mixed(answ)) => + E.R.merge( + sum( + ~estimate=Discrete(esti.discrete), + ~answer=Discrete(answ.discrete), + ~combineFn, + ~integrateFn, + ), + sum( + ~estimate=Continuous(esti.continuous), + ~answer=Continuous(answ.continuous), + ~combineFn, + ~integrateFn, + ), + )->E.R2.fmap(((discretePart, continuousPart)) => discretePart +. continuousPart) + } let sumWithPrior = (~estimate: t, ~answer: t, ~prior: t, ~combineFn, ~integrateFn): result< float,