From b4a1137019731994dc1a915f28740158d4299228 Mon Sep 17 00:00:00 2001 From: Quinn Dougherty Date: Fri, 13 May 2022 15:43:59 -0400 Subject: [PATCH] Implemented correct math underlying `logScoreWithPointResolution` Value: [1e-2 to 7e-1] Realized that I need to switch argument order, put `prior` last maybe. --- .../Distributions/DistributionOperation.res | 4 +-- .../Distributions/DistributionTypes.res | 2 +- .../rescript/Distributions/GenericDist.res | 31 +++++++++++++++--- .../rescript/Distributions/GenericDist.resi | 7 +++- .../Distributions/PointSetDist/Continuous.res | 11 +++---- .../Distributions/PointSetDist/Discrete.res | 2 +- .../PointSetDist/Distributions.res | 4 +-- .../Distributions/PointSetDist/Mixed.res | 2 +- .../PointSetDist/PointSetDist.res | 6 ++-- .../PointSetDist/PointSetDist_Scoring.res | 32 ++++++++++++++----- 10 files changed, 71 insertions(+), 30 deletions(-) diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res index 97085133..4c4291eb 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res @@ -149,7 +149,7 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => { ->E.R2.fmap(r => Float(r)) ->OutputLocal.fromResult | ToScore(LogScore(prediction, answer)) => - GenericDist.Score.logScore(dist, prediction, answer, ~toPointSetFn) + GenericDist.Score.logScoreWithPointResolution(Some(dist), prediction, answer, ~toPointSetFn) ->E.R2.fmap(r => Float(r)) ->OutputLocal.fromResult | ToBool(IsNormalized) => dist->GenericDist.isNormalized->Bool @@ -267,7 +267,7 @@ module Constructors = { let isNormalized = (~env, dist) => C.isNormalized(dist)->run(~env)->toBoolR let klDivergence = (~env, dist1, dist2) => C.klDivergence(dist1, dist2)->run(~env)->toFloatR let logScore = (~env, prior, prediction, answer) => - 
C.logScore(prior, prediction, answer)->run(~env)->toFloatR + C.logScoreWithPointResolution(prior, prediction, answer)->run(~env)->toFloatR let toPointSet = (~env, dist) => C.toPointSet(dist)->run(~env)->toDistR let toSampleSet = (~env, dist, n) => C.toSampleSet(dist, n)->run(~env)->toDistR let fromSamples = (~env, xs) => C.fromSamples(xs)->run(~env)->toDistR diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res index 35e1b1a7..480775bf 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res @@ -162,7 +162,7 @@ module Constructors = { let truncate = (dist, left, right): t => FromDist(ToDist(Truncate(left, right)), dist) let inspect = (dist): t => FromDist(ToDist(Inspect), dist) let klDivergence = (dist1, dist2): t => FromDist(ToScore(KLDivergence(dist2)), dist1) - let logScore = (prior, prediction, answer): t => FromDist( + let logScoreWithPointResolution = (prior, prediction, answer): t => FromDist( ToScore(LogScore(prediction, answer)), prior, ) diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res index 3e067c51..3357556f 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res @@ -67,11 +67,32 @@ module Score = { ) } - let logScore = (prior, prediction, answer, ~toPointSetFn: toPointSetFn): result => { - let pointSets = E.R.merge(toPointSetFn(prior), toPointSetFn(prediction)) - pointSets |> E.R2.bind(((a, b)) => - PointSetDist.T.logScore(a, b, answer)->E.R2.errMap(x => DistributionTypes.OperationError(x)) - ) + let logScoreWithPointResolution = ( + prior, + prediction, + answer, + ~toPointSetFn: toPointSetFn, + ): result => { + switch prior { + | Some(prior') => + 
E.R.merge(toPointSetFn(prior'), toPointSetFn(prediction))->E.R.bind(((a, b)) => + PointSetDist.T.logScoreWithPointResolution( + a->Some, + b, + answer, + )->E.R2.errMap(x => DistributionTypes.OperationError(x)) + ) + | None => + prediction + ->toPointSetFn + ->E.R.bind(x => + PointSetDist.T.logScoreWithPointResolution( + None, + x, + answer, + )->E.R2.errMap(x => DistributionTypes.OperationError(x)) + ) + } } } diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi index 45f1e8f8..712e38b9 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi @@ -25,7 +25,12 @@ let toFloatOperation: ( module Score: { let klDivergence: (t, t, ~toPointSetFn: toPointSetFn) => result - let logScore: (t, t, float, ~toPointSetFn: toPointSetFn) => result + let logScoreWithPointResolution: ( + option, + t, + float, + ~toPointSetFn: toPointSetFn, + ) => result } @genType diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res index 7bf9c874..22aaaf3b 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res @@ -279,13 +279,10 @@ module T = Dist({ ) newShape->E.R2.fmap(x => x->make->integralEndY) } - let logScore = (prior: t, prediction: t, answer: float) => { - let newShape = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument( - PointSetDist_Scoring.LogScore.integrand(~answer), - prior.xyShape, - prediction.xyShape, - ) - newShape->E.R2.fmap(x => x->make->integralEndY) + let logScoreWithPointResolution = (prior: option, prediction: t, answer: float) => { + let priorPdf = prior->E.O2.fmap((shape, x) => XYShape.XtoY.linear(x, shape.xyShape)) + let predictionPdf 
= x => XYShape.XtoY.linear(x, prediction.xyShape) + PointSetDist_Scoring.LogScoreWithPointResolution.score(~priorPdf, ~predictionPdf, ~answer) } }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res index e10ed981..26e15f6e 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res @@ -229,7 +229,7 @@ module T = Dist({ answer, )->E.R2.fmap(integralEndY) } - let logScore = (prior: t, prediction: t, answer: float) => { + let logScoreWithPointResolution = (prior: option, prediction: t, answer: float) => { Error(Operation.NotYetImplemented) } }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res index aa44da10..014c0668 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res @@ -34,7 +34,7 @@ module type dist = { let mean: t => float let variance: t => float let klDivergence: (t, t) => result - let logScore: (t, t, float) => result + let logScoreWithPointResolution: (option, t, float) => result } module Dist = (T: dist) => { @@ -58,7 +58,7 @@ module Dist = (T: dist) => { let variance = T.variance let integralEndY = T.integralEndY let klDivergence = T.klDivergence - let logScore = T.logScore + let logScoreWithPointResolution = T.logScoreWithPointResolution let updateIntegralCache = T.updateIntegralCache diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res index 2fd046ee..4f864856 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res +++ 
b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res @@ -306,7 +306,7 @@ module T = Dist({ let klContinuousPart = Continuous.T.klDivergence(prediction.continuous, answer.continuous) E.R.merge(klDiscretePart, klContinuousPart)->E.R2.fmap(t => fst(t) +. snd(t)) } - let logScore = (prior: t, prediction: t, answer: float) => { + let logScoreWithPointResolution = (prior: option, prediction: t, answer: float) => { Error(Operation.NotYetImplemented) } }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res index 2f8ebee3..7f87fd01 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res @@ -203,9 +203,11 @@ module T = Dist({ | (m1, m2) => Mixed.T.klDivergence(m1->toMixed, m2->toMixed) } - let logScore = (prior: t, prediction: t, answer: float) => { + let logScoreWithPointResolution = (prior: option, prediction: t, answer: float) => { switch (prior, prediction) { - | (Continuous(t1), Continuous(t2)) => Continuous.T.logScore(t1, t2, answer) + | (Some(Continuous(t1)), Continuous(t2)) => + Continuous.T.logScoreWithPointResolution(t1->Some, t2, answer) + | (None, Continuous(t2)) => Continuous.T.logScoreWithPointResolution(None, t2, answer) | _ => Error(Operation.NotYetImplemented) } } diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res index 8daf260c..ddf5207d 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res @@ -18,16 +18,32 @@ module KLDivergence = { /* */ -module LogScore = { +module LogScoreWithPointResolution = { let logFn = 
Js.Math.log - let integrand = (priorElement: float, predictionElement: float, ~answer: float) => { - if answer == 0.0 { - Ok(0.0) - } else if predictionElement == 0.0 { - Ok(infinity) + let score = ( + ~priorPdf: option<float => float>, + ~predictionPdf: float => float, + ~answer: float, + ): result => { + let numer = answer->predictionPdf + if numer < 0.0 { + Operation.ComplexNumberError->Error + } else if numer == 0.0 { + infinity->Ok } else { - let quot = predictionElement /. priorElement - quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answer *. logFn(quot /. answer)) + -.( + switch priorPdf { + | None => numer->logFn + | Some(f) => { + let priorDensityOfAnswer = f(answer) + if priorDensityOfAnswer == 0.0 { + neg_infinity + } else { + (numer /. priorDensityOfAnswer)->logFn + } + } + } + )->Ok } } }