diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res index 3aca0c66..4151254f 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res @@ -277,13 +277,19 @@ module T = Dist({ prediction.xyShape, answer.xyShape, ) - let xyShapeToContinuous: XYShape.xyShape => t = xyShape => { - xyShape: xyShape, - interpolation: #Linear, - integralSumCache: None, - integralCache: None, - } - newShape->E.R2.fmap(x => x->xyShapeToContinuous->integralEndY) + newShape->E.R2.fmap(x => x->make->integralEndY) + } + let logScore = (prior: t, prediction: t, answer: float) => { + let newShape = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument( + PointSetDist_Scoring.LogScore.integrand(~answer), + prior.xyShape, + prediction.xyShape, + ) + newShape->E.R2.fmap(x => x->make->integralEndY) + } + let logScoreAgainstImproperPrior = (prediction: t, answer: float) => { + let prior = make({xs: prediction.xyShape.xs, ys: E.A.fmap(_ => 1.0, prediction.xyShape.xs)}) + logScore(prior, prediction, answer) } }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res index abb6b793..9bc274e7 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res @@ -229,4 +229,10 @@ module T = Dist({ answer, )->E.R2.fmap(integralEndY) } + let logScore = (prior: t, prediction: t, answer: float) => { + Error(Operation.NotYetImplemented) + } + let logScoreAgainstImproperPrior = (prediction: t, answer: float) => { + Error(Operation.NotYetImplemented) + } }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res index 85ffe4b1..9fb7e689 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res @@ -34,6 +34,8 @@ module type dist = { let mean: t => float let variance: t => float let klDivergence: (t, t) => result + let logScore: (t, t, float) => result + let logScoreAgainstImproperPrior: (t, float) => result } module Dist = (T: dist) => { @@ -57,6 +59,8 @@ module Dist = (T: dist) => { let variance = T.variance let integralEndY = T.integralEndY let klDivergence = T.klDivergence + let logScore = T.logScore + let logScoreAgainstImproperPrior = T.logScoreAgainstImproperPrior let updateIntegralCache = T.updateIntegralCache diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res index 7bbe2065..50e8a419 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res @@ -306,6 +306,12 @@ module T = Dist({ let klContinuousPart = Continuous.T.klDivergence(prediction.continuous, answer.continuous) E.R.merge(klDiscretePart, klContinuousPart)->E.R2.fmap(t => fst(t) +. snd(t)) } + let logScore = (prior: t, prediction: t, answer: float) => { + Error(Operation.NotYetImplemented) + } + let logScoreAgainstImproperPrior = (prediction: t, answer: float) => { + Error(Operation.NotYetImplemented) + } }) let combineAlgebraically = (op: Operation.convolutionOperation, t1: t, t2: t): t => { diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res index db47d1e1..05e79830 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res @@ -196,13 +196,22 @@ module T = Dist({ | Continuous(m) => Continuous.T.variance(m) } - let klDivergence = (t1: t, t2: t) => - switch (t1, t2) { + let klDivergence = (prediction: t, answer: t) => + switch (prediction, answer) { | (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2) | (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2) - | (Mixed(t1), Mixed(t2)) => Mixed.T.klDivergence(t1, t2) - | _ => Error(NotYetImplemented) + | (m1, m2) => Mixed.T.klDivergence(m1->toMixed, m2->toMixed) } + + let logScore = (prior: t, prediction: t, answer: float) => { + switch (prior, prediction) { + | (Continuous(t1), Continuous(t2)) => Continuous.T.logScore(t1, t2, answer) + | _ => Error(Operation.NotYetImplemented) + } + } + let logScoreAgainstImproperPrior = (prediction: t, answer: float) => { + Error(Operation.NotYetImplemented) + } }) let pdf = (f: float, t: t) => { diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res index b22883df..8daf260c 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res @@ -14,3 +14,20 @@ module KLDivergence = { quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answerElement *. logFn(quot)) } } + +/* + +*/ +module LogScore = { + let logFn = Js.Math.log + let integrand = (priorElement: float, predictionElement: float, ~answer: float) => { + if answer == 0.0 { + Ok(0.0) + } else if predictionElement == 0.0 { + Ok(infinity) + } else { + let quot = predictionElement /. priorElement + quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answer *. logFn(quot /. answer)) + } + } +}