diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res index da4c010a..61b5cd6b 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res @@ -146,7 +146,16 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => { } | ToDist(Normalize) => dist->GenericDist.normalize->Dist | ToScore(KLDivergence(t2)) => - GenericDist.klDivergence(dist, t2, ~toPointSetFn) + GenericDist.Score.klDivergence(dist, t2, ~toPointSetFn) + ->E.R2.fmap(r => Float(r)) + ->OutputLocal.fromResult + | ToScore(LogScore(answer, prior)) => + GenericDist.Score.logScoreWithPointResolution( + ~prediction=dist, + ~answer, + ~prior, + ~toPointSetFn, + ) ->E.R2.fmap(r => Float(r)) ->OutputLocal.fromResult | ToBool(IsNormalized) => dist->GenericDist.isNormalized->Bool @@ -263,6 +272,12 @@ module Constructors = { let normalize = (~env, dist) => C.normalize(dist)->run(~env)->toDistR let isNormalized = (~env, dist) => C.isNormalized(dist)->run(~env)->toBoolR let klDivergence = (~env, dist1, dist2) => C.klDivergence(dist1, dist2)->run(~env)->toFloatR + let logScoreWithPointResolution = ( + ~env, + ~prediction: DistributionTypes.genericDist, + ~answer: float, + ~prior: option, + ) => C.logScoreWithPointResolution(~prediction, ~answer, ~prior)->run(~env)->toFloatR let toPointSet = (~env, dist) => C.toPointSet(dist)->run(~env)->toDistR let toSampleSet = (~env, dist, n) => C.toSampleSet(dist, n)->run(~env)->toDistR let fromSamples = (~env, xs) => C.fromSamples(xs)->run(~env)->toDistR diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi index a8d61a9a..aa006c06 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi @@ -63,6 +63,13 @@ module Constructors: { @genType let klDivergence: (~env: env, genericDist, genericDist) => result @genType + let logScoreWithPointResolution: ( + ~env: env, + ~prediction: genericDist, + ~answer: float, + ~prior: option, + ) => result + @genType let toPointSet: (~env: env, genericDist) => result @genType let toSampleSet: (~env: env, genericDist, int) => result diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res index a9f7dfbe..2bb409ad 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res @@ -91,7 +91,7 @@ module DistributionOperation = { | ToString | ToSparkline(int) - type toScore = KLDivergence(genericDist) + type toScore = KLDivergence(genericDist) | LogScore(float, option) type fromDist = | ToFloat(toFloat) @@ -120,6 +120,7 @@ module DistributionOperation = { | ToFloat(#Sample) => `sample` | ToFloat(#IntegralSum) => `integralSum` | ToScore(KLDivergence(_)) => `klDivergence` + | ToScore(LogScore(x, _)) => `logScore against ${E.Float.toFixed(x)}` | ToDist(Normalize) => `normalize` | ToDist(ToPointSet) => `toPointSet` | ToDist(ToSampleSet(r)) => `toSampleSet(${E.I.toString(r)})` @@ -161,6 +162,10 @@ module Constructors = { let truncate = (dist, left, right): t => FromDist(ToDist(Truncate(left, right)), dist) let inspect = (dist): t => FromDist(ToDist(Inspect), dist) let klDivergence = (dist1, dist2): t => FromDist(ToScore(KLDivergence(dist2)), dist1) + let logScoreWithPointResolution = (~prediction, ~answer, ~prior): t => FromDist( + ToScore(LogScore(answer, prior)), + prediction, + ) let scalePower = (dist, n): t => FromDist(ToDist(Scale(#Power, n)), dist) let scaleLogarithm = (dist, n): t => FromDist(ToDist(Scale(#Logarithm, n)), dist) let scaleLogarithmWithThreshold = (dist, n, eps): t => FromDist( diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res index 2085d72c..1df10240 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res @@ -59,11 +59,44 @@ let integralEndY = (t: t): float => let isNormalized = (t: t): bool => Js.Math.abs_float(integralEndY(t) -. 1.0) < 1e-7 -let klDivergence = (t1, t2, ~toPointSetFn: toPointSetFn): result => { - let pointSets = E.R.merge(toPointSetFn(t1), toPointSetFn(t2)) - pointSets |> E.R2.bind(((a, b)) => - PointSetDist.T.klDivergence(a, b)->E.R2.errMap(x => DistributionTypes.OperationError(x)) - ) +module Score = { + let klDivergence = (prediction, answer, ~toPointSetFn: toPointSetFn): result => { + let pointSets = E.R.merge(toPointSetFn(prediction), toPointSetFn(answer)) + pointSets |> E.R2.bind(((predi, ans)) => + PointSetDist.T.klDivergence(predi, ans)->E.R2.errMap(x => DistributionTypes.OperationError(x)) + ) + } + + let logScoreWithPointResolution = ( + ~prediction: DistributionTypes.genericDist, + ~answer: float, + ~prior: option, + ~toPointSetFn: toPointSetFn, + ): result => { + switch prior { + | Some(prior') => + E.R.merge(toPointSetFn(prior'), toPointSetFn(prediction))->E.R.bind((( + prior'', + prediction'', + )) => + PointSetDist.T.logScoreWithPointResolution( + ~prediction=prediction'', + ~answer, + ~prior=prior''->Some, + )->E.R2.errMap(x => DistributionTypes.OperationError(x)) + ) + | None => + prediction + ->toPointSetFn + ->E.R.bind(x => + PointSetDist.T.logScoreWithPointResolution( + ~prediction=x, + ~answer, + ~prior=None, + )->E.R2.errMap(x => DistributionTypes.OperationError(x)) + ) + } + } } let toFloatOperation = ( diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi index 03bc5fe8..79fb54ab 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi @@ -23,7 +23,15 @@ let toFloatOperation: ( ~distToFloatOperation: DistributionTypes.DistributionOperation.toFloat, ) => result -let klDivergence: (t, t, ~toPointSetFn: toPointSetFn) => result +module Score: { + let klDivergence: (t, t, ~toPointSetFn: toPointSetFn) => result + let logScoreWithPointResolution: ( + ~prediction: t, + ~answer: float, + ~prior: option, + ~toPointSetFn: toPointSetFn, + ) => result +} @genType let toPointSet: ( diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res index 3aca0c66..3661a531 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res @@ -277,13 +277,12 @@ module T = Dist({ prediction.xyShape, answer.xyShape, ) - let xyShapeToContinuous: XYShape.xyShape => t = xyShape => { - xyShape: xyShape, - interpolation: #Linear, - integralSumCache: None, - integralCache: None, - } - newShape->E.R2.fmap(x => x->xyShapeToContinuous->integralEndY) + newShape->E.R2.fmap(x => x->make->integralEndY) + } + let logScoreWithPointResolution = (~prediction: t, ~answer: float, ~prior: option) => { + let priorPdf = prior->E.O2.fmap((shape, x) => XYShape.XtoY.linear(x, shape.xyShape)) + let predictionPdf = x => XYShape.XtoY.linear(x, prediction.xyShape) + PointSetDist_Scoring.LogScoreWithPointResolution.score(~priorPdf, ~predictionPdf, ~answer) } }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res index abb6b793..fea5db6f 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res @@ -229,4 +229,7 @@ module T = Dist({ answer, )->E.R2.fmap(integralEndY) } + let logScoreWithPointResolution = (~prediction: t, ~answer: float, ~prior: option) => { + Error(Operation.NotYetImplemented) + } }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res index 85ffe4b1..2d0358ec 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res @@ -34,6 +34,11 @@ module type dist = { let mean: t => float let variance: t => float let klDivergence: (t, t) => result + let logScoreWithPointResolution: ( + ~prediction: t, + ~answer: float, + ~prior: option, + ) => result } module Dist = (T: dist) => { @@ -57,6 +62,7 @@ module Dist = (T: dist) => { let variance = T.variance let integralEndY = T.integralEndY let klDivergence = T.klDivergence + let logScoreWithPointResolution = T.logScoreWithPointResolution let updateIntegralCache = T.updateIntegralCache diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res index 7bbe2065..42a88909 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res @@ -306,6 +306,9 @@ module T = Dist({ let klContinuousPart = Continuous.T.klDivergence(prediction.continuous, answer.continuous) E.R.merge(klDiscretePart, klContinuousPart)->E.R2.fmap(t => fst(t) +. snd(t)) } + let logScoreWithPointResolution = (~prediction: t, ~answer: float, ~prior: option) => { + Error(Operation.NotYetImplemented) + } }) let combineAlgebraically = (op: Operation.convolutionOperation, t1: t, t2: t): t => { diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res index db47d1e1..d21a7383 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res @@ -196,13 +196,22 @@ module T = Dist({ | Continuous(m) => Continuous.T.variance(m) } - let klDivergence = (t1: t, t2: t) => - switch (t1, t2) { + let klDivergence = (prediction: t, answer: t) => + switch (prediction, answer) { | (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2) | (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2) - | (Mixed(t1), Mixed(t2)) => Mixed.T.klDivergence(t1, t2) - | _ => Error(NotYetImplemented) + | (m1, m2) => Mixed.T.klDivergence(m1->toMixed, m2->toMixed) } + + let logScoreWithPointResolution = (~prediction: t, ~answer: float, ~prior: option) => { + switch (prior, prediction) { + | (Some(Continuous(t1)), Continuous(t2)) => + Continuous.T.logScoreWithPointResolution(~prediction=t2, ~answer, ~prior=t1->Some) + | (None, Continuous(t2)) => + Continuous.T.logScoreWithPointResolution(~prediction=t2, ~answer, ~prior=None) + | _ => Error(Operation.NotYetImplemented) + } + } }) let pdf = (f: float, t: t) => { diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res index b22883df..532bc76c 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res @@ -14,3 +14,33 @@ module KLDivergence = { quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answerElement *. logFn(quot)) } } + +module LogScoreWithPointResolution = { + let logFn = Js.Math.log + let score = ( + ~priorPdf: option float>, + ~predictionPdf: float => float, + ~answer: float, + ): result => { + let numerator = answer->predictionPdf + if numerator < 0.0 { + Operation.PdfInvalidError->Error + } else if numerator == 0.0 { + infinity->Ok + } else { + -.( + switch priorPdf { + | None => numerator->logFn + | Some(f) => { + let priorDensityOfAnswer = f(answer) + if priorDensityOfAnswer == 0.0 { + neg_infinity + } else { + (numerator /. priorDensityOfAnswer)->logFn + } + } + } + )->Ok + } + } +} diff --git a/packages/squiggle-lang/src/rescript/MagicNumbers.res b/packages/squiggle-lang/src/rescript/MagicNumbers.res index 13beafe4..b859421c 100644 --- a/packages/squiggle-lang/src/rescript/MagicNumbers.res +++ b/packages/squiggle-lang/src/rescript/MagicNumbers.res @@ -12,6 +12,7 @@ module Epsilon = { module Environment = { let defaultXYPointLength = 1000 let defaultSampleCount = 10000 + let sparklineLength = 20 } module OpCost = { diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res index 1f1291e9..0f29f9db 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res @@ -1,5 +1,5 @@ module ExpressionValue = ReducerInterface_ExpressionValue -type expressionValue = ReducerInterface_ExpressionValue.expressionValue +type expressionValue = ExpressionValue.expressionValue module Helpers = { let arithmeticMap = r => @@ -162,6 +162,20 @@ module Helpers = { } } } + + let klDivergenceWithPrior = ( + prediction: DistributionTypes.genericDist, + answer: DistributionTypes.genericDist, + prior: DistributionTypes.genericDist, + env: DistributionOperation.env, + ) => { + let term1 = DistributionOperation.Constructors.klDivergence(~env, prediction, answer) + let term2 = DistributionOperation.Constructors.klDivergence(~env, prior, answer) + switch E.R.merge(term1, term2)->E.R2.fmap(((a, b)) => a -. b) { + | Ok(x) => x->DistributionOperation.Float->Some + | Error(_) => None + } + } } module SymbolicConstructors = { @@ -236,7 +250,8 @@ let dispatchToGenericOutput = ( | ("mean", [EvDistribution(dist)]) => Helpers.toFloatFn(#Mean, dist, ~env) | ("integralSum", [EvDistribution(dist)]) => Helpers.toFloatFn(#IntegralSum, dist, ~env) | ("toString", [EvDistribution(dist)]) => Helpers.toStringFn(ToString, dist, ~env) - | ("toSparkline", [EvDistribution(dist)]) => Helpers.toStringFn(ToSparkline(20), dist, ~env) + | ("toSparkline", [EvDistribution(dist)]) => + Helpers.toStringFn(ToSparkline(MagicNumbers.Environment.sparklineLength), dist, ~env) | ("toSparkline", [EvDistribution(dist), EvNumber(n)]) => Helpers.toStringFn(ToSparkline(Belt.Float.toInt(n)), dist, ~env) | ("exp", [EvDistribution(a)]) => @@ -249,8 +264,28 @@ let dispatchToGenericOutput = ( ~env, )->Some | ("normalize", [EvDistribution(dist)]) => Helpers.toDistFn(Normalize, dist, ~env) - | ("klDivergence", [EvDistribution(a), EvDistribution(b)]) => - Some(DistributionOperation.run(FromDist(ToScore(KLDivergence(b)), a), ~env)) + | ("klDivergence", [EvDistribution(prediction), EvDistribution(answer)]) => + Some(DistributionOperation.run(FromDist(ToScore(KLDivergence(answer)), prediction), ~env)) + | ("klDivergence", [EvDistribution(prediction), EvDistribution(answer), EvDistribution(prior)]) => + Helpers.klDivergenceWithPrior(prediction, answer, prior, env) + | ( + "logScoreWithPointAnswer", + [EvDistribution(prediction), EvNumber(answer), EvDistribution(prior)], + ) + | ( + "logScoreWithPointAnswer", + [EvDistribution(prediction), EvDistribution(Symbolic(#Float(answer))), EvDistribution(prior)], + ) => + DistributionOperation.run( + FromDist(ToScore(LogScore(answer, prior->Some)), prediction), + ~env, + )->Some + | ("logScoreWithPointAnswer", [EvDistribution(prediction), EvNumber(answer)]) + | ( + "logScoreWithPointAnswer", + [EvDistribution(prediction), EvDistribution(Symbolic(#Float(answer)))], + ) => + DistributionOperation.run(FromDist(ToScore(LogScore(answer, None)), prediction), ~env)->Some | ("isNormalized", [EvDistribution(dist)]) => Helpers.toBoolFn(IsNormalized, dist, ~env) | ("toPointSet", [EvDistribution(dist)]) => Helpers.toDistFn(ToPointSet, dist, ~env) | ("scaleLog", [EvDistribution(dist)]) => diff --git a/packages/squiggle-lang/src/rescript/Utility/E.res b/packages/squiggle-lang/src/rescript/Utility/E.res index 15678e1a..64729324 100644 --- a/packages/squiggle-lang/src/rescript/Utility/E.res +++ b/packages/squiggle-lang/src/rescript/Utility/E.res @@ -620,6 +620,19 @@ module A = { | Some(o) => o | None => [] } + // REturns `None` there are no non-`None` elements + let rec arrSomeToSomeArr = (optionals: array>): option> => { + let optionals' = optionals->Belt.List.fromArray + switch optionals' { + | list{} => []->Some + | list{x, ...xs} => + switch x { + | Some(_) => xs->Belt.List.toArray->arrSomeToSomeArr + | None => None + } + } + } + let firstSome = x => Belt.Array.getBy(x, O.isSome) } module R = { diff --git a/packages/squiggle-lang/src/rescript/Utility/Operation.res b/packages/squiggle-lang/src/rescript/Utility/Operation.res index cfa18925..3f56493b 100644 --- a/packages/squiggle-lang/src/rescript/Utility/Operation.res +++ b/packages/squiggle-lang/src/rescript/Utility/Operation.res @@ -55,7 +55,7 @@ type operationError = | ComplexNumberError | InfinityError | NegativeInfinityError - | LogicallyInconsistentPathwayError + | PdfInvalidError | NotYetImplemented // should be removed when `klDivergence` for mixed and discrete is implemented. @genType @@ -69,7 +69,7 @@ module Error = { | ComplexNumberError => "Operation returned complex result" | InfinityError => "Operation returned positive infinity" | NegativeInfinityError => "Operation returned negative infinity" - | LogicallyInconsistentPathwayError => "This pathway should have been logically unreachable" + | PdfInvalidError => "This Pdf is invalid" | NotYetImplemented => "This pathway is not yet implemented" } }