diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res index 3dd9981a..2a806c9c 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res @@ -128,26 +128,27 @@ module Score = { ~answ: scoreDistOrScalar, ~prior: option, ): result => { - let toPointSetFn = toPointSet( - ~xyPointLength=MagicNumbers.Environment.defaultXYPointLength, - ~sampleCount=MagicNumbers.Environment.defaultSampleCount, - ~xSelection=#ByWeight, - ) + let toPointSetFn = t => + toPointSet( + t, + ~xyPointLength=MagicNumbers.Environment.defaultXYPointLength, + ~sampleCount=MagicNumbers.Environment.defaultSampleCount, + ~xSelection=#ByWeight, + (), + ) let prior': option> = switch prior { | None => None - | Some(Score_Dist(d)) => toPointSetFn(d, ())->E.R.bind(x => x->D->Ok)->Some + | Some(Score_Dist(d)) => toPointSetFn(d)->E.R.bind(x => x->D->Ok)->Some | Some(Score_Scalar(s)) => s->S->Ok->Some } let twoDists = (esti': t, answ': t): result< (PointSetTypes.pointSetDist, PointSetTypes.pointSetDist), error, - > => E.R.merge(toPointSetFn(esti', ()), toPointSetFn(answ', ())) + > => E.R.merge(toPointSetFn(esti'), toPointSetFn(answ')) switch (esti, answ, prior') { | (Score_Dist(esti'), Score_Dist(answ'), None) => - twoDists(esti', answ')->E.R.bind(((esti'', answ'')) => - {estimate: esti'', answer: answ'', prior: None} - ->PointSetDist_Scoring.DistEstimateDistAnswer - ->Ok + twoDists(esti', answ')->E.R2.fmap(((esti'', answ'')) => + {estimate: esti'', answer: answ'', prior: None}->PointSetDist_Scoring.DistEstimateDistAnswer ) | (Score_Dist(esti'), Score_Dist(answ'), Some(Ok(D(prior'')))) => twoDists(esti', answ')->E.R.bind(((esti'', answ'')) => @@ -157,25 +158,25 @@ module Score = { ) | (Score_Dist(_), _, Some(Ok(S(_)))) => DistributionTypes.Unreachable->Error | (Score_Dist(esti'), Score_Scalar(answ'), None) => - toPointSetFn(esti', ())->E.R.bind(esti'' => + toPointSetFn(esti')->E.R.bind(esti'' => {estimate: esti'', answer: answ', prior: None} ->PointSetDist_Scoring.DistEstimateScalarAnswer ->Ok ) | (Score_Dist(esti'), Score_Scalar(answ'), Some(Ok(D(prior'')))) => - toPointSetFn(esti', ())->E.R.bind(esti'' => + toPointSetFn(esti')->E.R.bind(esti'' => {estimate: esti'', answer: answ', prior: Some(prior'')} ->PointSetDist_Scoring.DistEstimateScalarAnswer ->Ok ) | (Score_Scalar(esti'), Score_Dist(answ'), None) => - toPointSetFn(answ', ())->E.R.bind(answ'' => + toPointSetFn(answ')->E.R.bind(answ'' => {estimate: esti', answer: answ'', prior: None} ->PointSetDist_Scoring.ScalarEstimateDistAnswer ->Ok ) | (Score_Scalar(esti'), Score_Dist(answ'), Some(Ok(S(prior'')))) => - toPointSetFn(answ', ())->E.R.bind(answ'' => + toPointSetFn(answ')->E.R.bind(answ'' => {estimate: esti', answer: answ'', prior: Some(prior'')} ->PointSetDist_Scoring.ScalarEstimateDistAnswer ->Ok diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res index a228fcdc..284914c2 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res @@ -1,11 +1,11 @@ -type t = PointSetTypes.pointSetDist +type pointSetDist = PointSetTypes.pointSetDist type scalar = float type abstractScoreArgs<'a, 'b> = {estimate: 'a, answer: 'b, prior: option<'a>} type scoreArgs = - | DistEstimateDistAnswer(abstractScoreArgs) - | DistEstimateScalarAnswer(abstractScoreArgs) - | ScalarEstimateDistAnswer(abstractScoreArgs) + | DistEstimateDistAnswer(abstractScoreArgs) + | DistEstimateScalarAnswer(abstractScoreArgs) + | ScalarEstimateDistAnswer(abstractScoreArgs) | ScalarEstimateScalarAnswer(abstractScoreArgs) let logFn = Js.Math.log // base e @@ -29,15 +29,17 @@ module WithDistAnswer = { minusScaledLogOfQuot(~esti=estimateElement, ~answ=answerElement) } - let rec sum = (~estimate: t, ~answer: t, ~combineFn, ~integrateFn, ~toMixedFn): result< - float, - Operation.Error.t, - > => - switch (estimate, answer) { - | (Continuous(_), Continuous(_)) - | (Discrete(_), Discrete(_)) => + let sum = ( + ~estimate: pointSetDist, + ~answer: pointSetDist, + ~combineFn, + ~integrateFn, + ~toMixedFn, + ): result => { + let combineAndIntegrate = (estimate, answer) => combineFn(integrand, estimate, answer)->E.R2.fmap(integrateFn) - | (_, _) => + + let getMixedSums = (estimate: pointSetDist, answer: pointSetDist) => { let esti = estimate->toMixedFn let answ = answer->toMixedFn switch ( @@ -53,29 +55,31 @@ module WithDistAnswer = { Some(answDiscretePart), ) => E.R.merge( - sum( - ~estimate=Discrete(estiDiscretePart), - ~answer=Discrete(answDiscretePart), - ~combineFn, - ~integrateFn, - ~toMixedFn, + combineAndIntegrate( + PointSetTypes.Discrete(estiDiscretePart), + PointSetTypes.Discrete(answDiscretePart), ), - sum( - ~estimate=Continuous(estiContinuousPart), - ~answer=Continuous(answContinuousPart), - ~combineFn, - ~integrateFn, - ~toMixedFn, - ), - )->E.R2.fmap(((discretePart, continuousPart)) => discretePart +. continuousPart) + combineAndIntegrate(Continuous(estiContinuousPart), Continuous(answContinuousPart)), + ) | (_, _, _, _) => `unreachable state`->Operation.Other->Error } } + switch (estimate, answer) { + | (Continuous(_), Continuous(_)) + | (Discrete(_), Discrete(_)) => + combineAndIntegrate(estimate, answer) + | (_, _) => + getMixedSums(estimate, answer)->E.R2.fmap(((discretePart, continuousPart)) => + discretePart +. continuousPart + ) + } + } + let sumWithPrior = ( - ~estimate: t, - ~answer: t, - ~prior: t, + ~estimate: pointSetDist, + ~answer: pointSetDist, + ~prior: pointSetDist, ~combineFn, ~integrateFn, ~toMixedFn, @@ -87,47 +91,51 @@ module WithDistAnswer = { } module WithScalarAnswer = { - let score' = (~estimatePdf: float => float, ~answer: float): result => { - let density = answer->estimatePdf - if density < 0.0 { - Operation.PdfInvalidError->Error - } else if density == 0.0 { - infinity->Ok - } else { - density->logFn->(x => -.x)->Ok - } - } - let scoreWithPrior' = ( - ~estimatePdf: float => float, - ~answer: scalar, - ~priorPdf: float => float, - ): result => { - let numerator = answer->estimatePdf - let priorDensityOfAnswer = answer->priorPdf - if numerator < 0.0 || priorDensityOfAnswer < 0.0 { - Operation.PdfInvalidError->Error - } else if numerator == 0.0 || priorDensityOfAnswer == 0.0 { - infinity->Ok - } else { - minusScaledLogOfQuot(~esti=numerator, ~answ=priorDensityOfAnswer) - } - } - let sum = (mp: PointSetTypes.MixedPoint.t): float => mp.continuous +. mp.discrete - let score = (~estimate: t, ~answer: scalar): result => { + let score = (~estimate: pointSetDist, ~answer: scalar): result => { + let _score = (~estimatePdf: float => float, ~answer: float): result< + float, + Operation.Error.t, + > => { + let density = answer->estimatePdf + if density < 0.0 { + Operation.PdfInvalidError->Error + } else if density == 0.0 { + infinity->Ok + } else { + density->logFn->(x => -.x)->Ok + } + } + let estimatePdf = x => switch estimate { | Continuous(esti) => Continuous.T.xToY(x, esti)->sum | Discrete(esti) => Discrete.T.xToY(x, esti)->sum | Mixed(esti) => Mixed.T.xToY(x, esti)->sum } - - score'(~estimatePdf, ~answer) + _score(~estimatePdf, ~answer) } - let scoreWithPrior = (~estimate: t, ~answer: scalar, ~prior: t): result< + + let scoreWithPrior = (~estimate: pointSetDist, ~answer: scalar, ~prior: pointSetDist): result< float, Operation.Error.t, > => { + let _scoreWithPrior = ( + ~estimatePdf: float => float, + ~answer: scalar, + ~priorPdf: float => float, + ): result => { + let numerator = answer->estimatePdf + let priorDensityOfAnswer = answer->priorPdf + if numerator < 0.0 || priorDensityOfAnswer < 0.0 { + Operation.PdfInvalidError->Error + } else if numerator == 0.0 || priorDensityOfAnswer == 0.0 { + infinity->Ok + } else { + minusScaledLogOfQuot(~esti=numerator, ~answ=priorDensityOfAnswer) + } + } + let estimatePdf = x => switch estimate { | Continuous(esti) => Continuous.T.xToY(x, esti)->sum @@ -140,7 +148,7 @@ module WithScalarAnswer = { | Discrete(prio) => Discrete.T.xToY(x, prio)->sum | Mixed(prio) => Mixed.T.xToY(x, prio)->sum } - scoreWithPrior'(~estimatePdf, ~answer, ~priorPdf) + _scoreWithPrior(~estimatePdf, ~answer, ~priorPdf) } }