diff --git a/packages/squiggle-lang/__tests__/Distributions/Scoring/KlDivergence_test.res b/packages/squiggle-lang/__tests__/Distributions/Scoring/KlDivergence_test.res index f1e6c23b..e281b9de 100644 --- a/packages/squiggle-lang/__tests__/Distributions/Scoring/KlDivergence_test.res +++ b/packages/squiggle-lang/__tests__/Distributions/Scoring/KlDivergence_test.res @@ -57,7 +57,7 @@ describe("klDivergence: continuous -> continuous -> float", () => { let kl = E.R.liftJoin2(klDivergence, prediction, answer) switch kl { - | Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=3) + | Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=2) | Error(err) => { Js.Console.log(DistributionTypes.Error.toString(err)) raise(KlFailed) diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res index dfa728bc..3dd9981a 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res @@ -155,6 +155,7 @@ module Score = { ->PointSetDist_Scoring.DistEstimateDistAnswer ->Ok ) + | (Score_Dist(_), _, Some(Ok(S(_)))) => DistributionTypes.Unreachable->Error | (Score_Dist(esti'), Score_Scalar(answ'), None) => toPointSetFn(esti', ())->E.R.bind(esti'' => {estimate: esti'', answer: answ', prior: None} @@ -179,6 +180,7 @@ module Score = { ->PointSetDist_Scoring.ScalarEstimateDistAnswer ->Ok ) + | (Score_Scalar(_), _, Some(Ok(D(_)))) => DistributionTypes.Unreachable->Error | (Score_Scalar(esti'), Score_Scalar(answ'), None) => {estimate: esti', answer: answ', prior: None} ->PointSetDist_Scoring.ScalarEstimateScalarAnswer @@ -199,44 +201,6 @@ module Score = { argsMake(~esti=estimate, ~answ=answer, ~prior)->E.R.bind(x => x->PointSetDist.logScore->E.R2.errMap(y => DistributionTypes.OperationError(y)) ) - - // let klDivergence = (prediction, answer, ~toPointSetFn: toPointSetFn): result => { - // let pointSets = E.R.merge(toPointSetFn(prediction), toPointSetFn(answer)) - // pointSets |> E.R2.bind(((predi, ans)) => - // PointSetDist.T.klDivergence(predi, ans)->E.R2.errMap(x => DistributionTypes.OperationError(x)) - // ) - // } - // - // let logScoreWithPointResolution = ( - // ~prediction: DistributionTypes.genericDist, - // ~answer: float, - // ~prior: option, - // ~toPointSetFn: toPointSetFn, - // ): result => { - // switch prior { - // | Some(prior') => - // E.R.merge(toPointSetFn(prior'), toPointSetFn(prediction))->E.R.bind((( - // prior'', - // prediction'', - // )) => - // PointSetDist.T.logScoreWithPointResolution( - // ~prediction=prediction'', - // ~answer, - // ~prior=prior''->Some, - // )->E.R2.errMap(x => DistributionTypes.OperationError(x)) - // ) - // | None => - // prediction - // ->toPointSetFn - // ->E.R.bind(x => - // PointSetDist.T.logScoreWithPointResolution( - // ~prediction=x, - // ~answer, - // ~prior=None, - // )->E.R2.errMap(x => DistributionTypes.OperationError(x)) - // ) - // } - // } } /* PointSetDist.toSparkline calls "downsampleEquallyOverX", which downsamples it to n=bucketCount. diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res index 658c6b8a..4abb389d 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res @@ -270,20 +270,6 @@ module T = Dist({ } let variance = (t: t): float => XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares) - - // let klDivergence = (prediction: t, answer: t) => { - // let newShape = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument( - // PointSetDist_Scoring.KLDivergence.integrand, - // prediction.xyShape, - // answer.xyShape, - // ) - // newShape->E.R2.fmap(x => x->make->integralEndY) - // } - // let logScoreWithPointResolution = (~prediction: t, ~answer: float, ~prior: option) => { - // let priorPdf = prior->E.O2.fmap((shape, x) => XYShape.XtoY.linear(x, shape.xyShape)) - // let predictionPdf = x => XYShape.XtoY.linear(x, prediction.xyShape) - // PointSetDist_Scoring.LogScoreWithPointResolution.score(~priorPdf, ~predictionPdf, ~answer) - // } }) let isNormalized = (t: t): bool => { diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res index fe848107..3a35d57b 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res @@ -33,12 +33,6 @@ module type dist = { let mean: t => float let variance: t => float - // let klDivergence: (t, t) => result - // let logScoreWithPointResolution: ( - // ~prediction: t, - // ~answer: float, - // ~prior: option, - // ) => result } module Dist = (T: dist) => { @@ -61,9 +55,6 @@ module Dist = (T: dist) => { let mean = T.mean let variance = T.variance let integralEndY = T.integralEndY - // let klDivergence = T.klDivergence - // let logScoreWithPointResolution = T.logScoreWithPointResolution - let updateIntegralCache = T.updateIntegralCache module Integral = { diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res index a49163a1..b743c7bf 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res @@ -195,27 +195,10 @@ module T = Dist({ | Discrete(m) => Discrete.T.variance(m) | Continuous(m) => Continuous.T.variance(m) } - - // let klDivergence = (prediction: t, answer: t) => - // switch (prediction, answer) { - // | (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2) - // | (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2) - // | (m1, m2) => Mixed.T.klDivergence(m1->toMixed, m2->toMixed) - // } - // - // let logScoreWithPointResolution = (~prediction: t, ~answer: float, ~prior: option) => { - // switch (prior, prediction) { - // | (Some(Continuous(t1)), Continuous(t2)) => - // Continuous.T.logScoreWithPointResolution(~prediction=t2, ~answer, ~prior=t1->Some) - // | (None, Continuous(t2)) => - // Continuous.T.logScoreWithPointResolution(~prediction=t2, ~answer, ~prior=None) - // | _ => Error(Operation.NotYetImplemented) - // } - // } }) let logScore = (args: PointSetDist_Scoring.scoreArgs): result => - PointSetDist_Scoring.logScore(args, ~combineFn=combinePointwise, ~integrateFn=T.integralEndY) + PointSetDist_Scoring.logScore(args, ~combineFn=combinePointwise, ~integrateFn=T.Integral.sum) let pdf = (f: float, t: t) => { let mixedPoint: PointSetTypes.mixedPoint = T.xToY(f, t) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res index 95ebc89c..b302d83e 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res @@ -1,7 +1,4 @@ type t = PointSetTypes.pointSetDist -type continuousShape = PointSetTypes.continuousShape -type discreteShape = PointSetTypes.discreteShape -type mixedShape = PointSetTypes.mixedShape type scalar = float type abstractScoreArgs<'a, 'b> = {estimate: 'a, answer: 'b, prior: option<'a>} @@ -71,14 +68,14 @@ module WithScalarAnswer = { minusScaledLogOfQuot(~esti=numerator, ~answ=priorDensityOfAnswer) } } + + let sum = (mp: PointSetTypes.MixedPoint.t): float => mp.continuous +. mp.discrete let score = (~estimate: t, ~answer: scalar): result => { let estimatePdf = x => switch estimate { - | Continuous(esti) => XYShape.XtoY.linear(x, esti.xyShape) - | Discrete(esti) => XYShape.XtoY.linear(x, esti.xyShape) - | Mixed(esti) => - XYShape.XtoY.linear(x, esti.continuous.xyShape) +. - XYShape.XtoY.linear(x, esti.discrete.xyShape) + | Continuous(esti) => Continuous.T.xToY(x, esti)->sum + | Discrete(esti) => Discrete.T.xToY(x, esti)->sum + | Mixed(esti) => Mixed.T.xToY(x, esti)->sum } score'(~estimatePdf, ~answer) @@ -89,19 +86,15 @@ module WithScalarAnswer = { > => { let estimatePdf = x => switch estimate { - | Continuous(esti) => XYShape.XtoY.linear(x, esti.xyShape) - | Discrete(esti) => XYShape.XtoY.linear(x, esti.xyShape) - | Mixed(esti) => - XYShape.XtoY.linear(x, esti.continuous.xyShape) +. - XYShape.XtoY.linear(x, esti.discrete.xyShape) + | Continuous(esti) => Continuous.T.xToY(x, esti)->sum + | Discrete(esti) => Discrete.T.xToY(x, esti)->sum + | Mixed(esti) => Mixed.T.xToY(x, esti)->sum } let priorPdf = x => switch prior { - | Continuous(prio) => XYShape.XtoY.linear(x, prio.xyShape) - | Discrete(prio) => XYShape.XtoY.linear(x, prio.xyShape) - | Mixed(prio) => - XYShape.XtoY.linear(x, prio.continuous.xyShape) +. - XYShape.XtoY.linear(x, prio.discrete.xyShape) + | Continuous(prio) => Continuous.T.xToY(x, prio)->sum + | Discrete(prio) => Discrete.T.xToY(x, prio)->sum + | Mixed(prio) => Mixed.T.xToY(x, prio)->sum } scoreWithPrior'(~estimatePdf, ~answer, ~priorPdf) } diff --git a/packages/squiggle-lang/src/rescript/Utility/E.res b/packages/squiggle-lang/src/rescript/Utility/E.res index 3357f4f4..a3bcb911 100644 --- a/packages/squiggle-lang/src/rescript/Utility/E.res +++ b/packages/squiggle-lang/src/rescript/Utility/E.res @@ -541,6 +541,7 @@ module A = { let init = Array.init let reduce = Belt.Array.reduce let reducei = Belt.Array.reduceWithIndex + let some = Belt.Array.some let isEmpty = r => length(r) < 1 let stableSortBy = Belt.SortArray.stableSortBy let toNoneIfEmpty = r => isEmpty(r) ? None : Some(r)