From 978e149913886207b222f7c9efdb0ad65114ca7b Mon Sep 17 00:00:00 2001 From: Quinn Dougherty Date: Thu, 12 May 2022 13:11:51 -0400 Subject: [PATCH 01/12] Initialized `logScore` and `logScoreAgainstImproperPrior` Value: [1e-5 to 6e-3] --- .../Distributions/PointSetDist/Continuous.res | 20 ++++++++++++------- .../Distributions/PointSetDist/Discrete.res | 6 ++++++ .../PointSetDist/Distributions.res | 4 ++++ .../Distributions/PointSetDist/Mixed.res | 6 ++++++ .../PointSetDist/PointSetDist.res | 17 ++++++++++++---- .../PointSetDist/PointSetDist_Scoring.res | 17 ++++++++++++++++ 6 files changed, 59 insertions(+), 11 deletions(-) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res index 3aca0c66..4151254f 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res @@ -277,13 +277,19 @@ module T = Dist({ prediction.xyShape, answer.xyShape, ) - let xyShapeToContinuous: XYShape.xyShape => t = xyShape => { - xyShape: xyShape, - interpolation: #Linear, - integralSumCache: None, - integralCache: None, - } - newShape->E.R2.fmap(x => x->xyShapeToContinuous->integralEndY) + newShape->E.R2.fmap(x => x->make->integralEndY) + } + let logScore = (prior: t, prediction: t, answer: float) => { + let newShape = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument( + PointSetDist_Scoring.LogScore.integrand(~answer), + prior.xyShape, + prediction.xyShape, + ) + newShape->E.R2.fmap(x => x->make->integralEndY) + } + let logScoreAgainstImproperPrior = (prediction: t, answer: float) => { + let prior = make({xs: prediction.xyShape.xs, ys: E.A.fmap(_ => 1.0, prediction.xyShape.xs)}) + logScore(prior, prediction, answer) } }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res 
b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res index abb6b793..9bc274e7 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res @@ -229,4 +229,10 @@ module T = Dist({ answer, )->E.R2.fmap(integralEndY) } + let logScore = (prior: t, prediction: t, answer: float) => { + Error(Operation.NotYetImplemented) + } + let logScoreAgainstImproperPrior = (prediction: t, answer: float) => { + Error(Operation.NotYetImplemented) + } }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res index 85ffe4b1..9fb7e689 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res @@ -34,6 +34,8 @@ module type dist = { let mean: t => float let variance: t => float let klDivergence: (t, t) => result + let logScore: (t, t, float) => result + let logScoreAgainstImproperPrior: (t, float) => result } module Dist = (T: dist) => { @@ -57,6 +59,8 @@ module Dist = (T: dist) => { let variance = T.variance let integralEndY = T.integralEndY let klDivergence = T.klDivergence + let logScore = T.logScore + let logScoreAgainstImproperPrior = T.logScoreAgainstImproperPrior let updateIntegralCache = T.updateIntegralCache diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res index 7bbe2065..50e8a419 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res @@ -306,6 +306,12 @@ module T = Dist({ let klContinuousPart = Continuous.T.klDivergence(prediction.continuous, answer.continuous) E.R.merge(klDiscretePart, 
klContinuousPart)->E.R2.fmap(t => fst(t) +. snd(t)) } + let logScore = (prior: t, prediction: t, answer: float) => { + Error(Operation.NotYetImplemented) + } + let logScoreAgainstImproperPrior = (prediction: t, answer: float) => { + Error(Operation.NotYetImplemented) + } }) let combineAlgebraically = (op: Operation.convolutionOperation, t1: t, t2: t): t => { diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res index db47d1e1..05e79830 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res @@ -196,13 +196,22 @@ module T = Dist({ | Continuous(m) => Continuous.T.variance(m) } - let klDivergence = (t1: t, t2: t) => - switch (t1, t2) { + let klDivergence = (prediction: t, answer: t) => + switch (prediction, answer) { | (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2) | (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2) - | (Mixed(t1), Mixed(t2)) => Mixed.T.klDivergence(t1, t2) - | _ => Error(NotYetImplemented) + | (m1, m2) => Mixed.T.klDivergence(m1->toMixed, m2->toMixed) } + + let logScore = (prior: t, prediction: t, answer: float) => { + switch (prior, prediction) { + | (Continuous(t1), Continuous(t2)) => Continuous.T.logScore(t1, t2, answer) + | _ => Error(Operation.NotYetImplemented) + } + } + let logScoreAgainstImproperPrior = (prediction: t, answer: float) => { + Error(Operation.NotYetImplemented) + } }) let pdf = (f: float, t: t) => { diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res index b22883df..8daf260c 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res +++ 
b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res @@ -14,3 +14,20 @@ module KLDivergence = { quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answerElement *. logFn(quot)) } } + +/* + +*/ +module LogScore = { + let logFn = Js.Math.log + let integrand = (priorElement: float, predictionElement: float, ~answer: float) => { + if answer == 0.0 { + Ok(0.0) + } else if predictionElement == 0.0 { + Ok(infinity) + } else { + let quot = predictionElement /. priorElement + quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answer *. logFn(quot /. answer)) + } + } +} From 51310819a12a510fc0b539a599d013bc9863e3b5 Mon Sep 17 00:00:00 2001 From: Quinn Dougherty Date: Thu, 12 May 2022 14:16:52 -0400 Subject: [PATCH 02/12] `logScore` now in interface. Value: [1e-4 to 1e-1] --- .../Distributions/DistributionOperation.res | 8 ++++- .../Distributions/DistributionOperation.resi | 2 ++ .../Distributions/DistributionTypes.res | 7 +++- .../rescript/Distributions/GenericDist.res | 19 ++++++++--- .../rescript/Distributions/GenericDist.resi | 5 ++- .../Distributions/PointSetDist/Continuous.res | 4 --- .../Distributions/PointSetDist/Discrete.res | 3 -- .../PointSetDist/Distributions.res | 2 -- .../Distributions/PointSetDist/Mixed.res | 3 -- .../src/rescript/MagicNumbers.res | 1 + .../ReducerInterface_GenericDistribution.res | 32 ++++++++++++++++++- 11 files changed, 65 insertions(+), 21 deletions(-) diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res index 80a53eb8..97085133 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res @@ -145,7 +145,11 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => { } | ToDist(Normalize) => dist->GenericDist.normalize->Dist | ToScore(KLDivergence(t2)) => - 
GenericDist.klDivergence(dist, t2, ~toPointSetFn) + GenericDist.Score.klDivergence(dist, t2, ~toPointSetFn) + ->E.R2.fmap(r => Float(r)) + ->OutputLocal.fromResult + | ToScore(LogScore(prediction, answer)) => + GenericDist.Score.logScore(dist, prediction, answer, ~toPointSetFn) ->E.R2.fmap(r => Float(r)) ->OutputLocal.fromResult | ToBool(IsNormalized) => dist->GenericDist.isNormalized->Bool @@ -262,6 +266,8 @@ module Constructors = { let normalize = (~env, dist) => C.normalize(dist)->run(~env)->toDistR let isNormalized = (~env, dist) => C.isNormalized(dist)->run(~env)->toBoolR let klDivergence = (~env, dist1, dist2) => C.klDivergence(dist1, dist2)->run(~env)->toFloatR + let logScore = (~env, prior, prediction, answer) => + C.logScore(prior, prediction, answer)->run(~env)->toFloatR let toPointSet = (~env, dist) => C.toPointSet(dist)->run(~env)->toDistR let toSampleSet = (~env, dist, n) => C.toSampleSet(dist, n)->run(~env)->toDistR let fromSamples = (~env, xs) => C.fromSamples(xs)->run(~env)->toDistR diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi index 200be7d7..c3d14014 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi @@ -62,6 +62,8 @@ module Constructors: { @genType let klDivergence: (~env: env, genericDist, genericDist) => result @genType + let logScore: (~env: env, genericDist, genericDist, float) => result + @genType let toPointSet: (~env: env, genericDist) => result @genType let toSampleSet: (~env: env, genericDist, int) => result diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res index a9f7dfbe..35e1b1a7 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res +++ 
b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res @@ -91,7 +91,7 @@ module DistributionOperation = { | ToString | ToSparkline(int) - type toScore = KLDivergence(genericDist) + type toScore = KLDivergence(genericDist) | LogScore(genericDist, float) type fromDist = | ToFloat(toFloat) @@ -120,6 +120,7 @@ module DistributionOperation = { | ToFloat(#Sample) => `sample` | ToFloat(#IntegralSum) => `integralSum` | ToScore(KLDivergence(_)) => `klDivergence` + | ToScore(LogScore(_, x)) => `logScore against ${E.Float.toFixed(x)}` | ToDist(Normalize) => `normalize` | ToDist(ToPointSet) => `toPointSet` | ToDist(ToSampleSet(r)) => `toSampleSet(${E.I.toString(r)})` @@ -161,6 +162,10 @@ module Constructors = { let truncate = (dist, left, right): t => FromDist(ToDist(Truncate(left, right)), dist) let inspect = (dist): t => FromDist(ToDist(Inspect), dist) let klDivergence = (dist1, dist2): t => FromDist(ToScore(KLDivergence(dist2)), dist1) + let logScore = (prior, prediction, answer): t => FromDist( + ToScore(LogScore(prediction, answer)), + prior, + ) let scalePower = (dist, n): t => FromDist(ToDist(Scale(#Power, n)), dist) let scaleLogarithm = (dist, n): t => FromDist(ToDist(Scale(#Logarithm, n)), dist) let scaleLogarithmWithThreshold = (dist, n, eps): t => FromDist( diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res index 2085d72c..1995ef34 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res @@ -59,11 +59,20 @@ let integralEndY = (t: t): float => let isNormalized = (t: t): bool => Js.Math.abs_float(integralEndY(t) -. 
1.0) < 1e-7 -let klDivergence = (t1, t2, ~toPointSetFn: toPointSetFn): result => { - let pointSets = E.R.merge(toPointSetFn(t1), toPointSetFn(t2)) - pointSets |> E.R2.bind(((a, b)) => - PointSetDist.T.klDivergence(a, b)->E.R2.errMap(x => DistributionTypes.OperationError(x)) - ) +module Score = { + let klDivergence = (t1, t2, ~toPointSetFn: toPointSetFn): result => { + let pointSets = E.R.merge(toPointSetFn(t1), toPointSetFn(t2)) + pointSets |> E.R2.bind(((a, b)) => + PointSetDist.T.klDivergence(a, b)->E.R2.errMap(x => DistributionTypes.OperationError(x)) + ) + } + + let logScore = (prior, prediction, answer, ~toPointSetFn: toPointSetFn): result => { + let pointSets = E.R.merge(toPointSetFn(prior), toPointSetFn(prediction)) + pointSets |> E.R2.bind(((a, b)) => + PointSetDist.T.logScore(a, b, answer)->E.R2.errMap(x => DistributionTypes.OperationError(x)) + ) + } } let toFloatOperation = ( diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi index 03bc5fe8..45f1e8f8 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi @@ -23,7 +23,10 @@ let toFloatOperation: ( ~distToFloatOperation: DistributionTypes.DistributionOperation.toFloat, ) => result -let klDivergence: (t, t, ~toPointSetFn: toPointSetFn) => result +module Score: { + let klDivergence: (t, t, ~toPointSetFn: toPointSetFn) => result + let logScore: (t, t, float, ~toPointSetFn: toPointSetFn) => result +} @genType let toPointSet: ( diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res index 4151254f..7bf9c874 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res @@ -287,10 +287,6 @@ module 
T = Dist({ ) newShape->E.R2.fmap(x => x->make->integralEndY) } - let logScoreAgainstImproperPrior = (prediction: t, answer: float) => { - let prior = make({xs: prediction.xyShape.xs, ys: E.A.fmap(_ => 1.0, prediction.xyShape.xs)}) - logScore(prior, prediction, answer) - } }) let isNormalized = (t: t): bool => { diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res index 9bc274e7..e10ed981 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res @@ -232,7 +232,4 @@ module T = Dist({ let logScore = (prior: t, prediction: t, answer: float) => { Error(Operation.NotYetImplemented) } - let logScoreAgainstImproperPrior = (prediction: t, answer: float) => { - Error(Operation.NotYetImplemented) - } }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res index 9fb7e689..aa44da10 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res @@ -35,7 +35,6 @@ module type dist = { let variance: t => float let klDivergence: (t, t) => result let logScore: (t, t, float) => result - let logScoreAgainstImproperPrior: (t, float) => result } module Dist = (T: dist) => { @@ -60,7 +59,6 @@ module Dist = (T: dist) => { let integralEndY = T.integralEndY let klDivergence = T.klDivergence let logScore = T.logScore - let logScoreAgainstImproperPrior = T.logScoreAgainstImproperPrior let updateIntegralCache = T.updateIntegralCache diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res index 50e8a419..2fd046ee 100644 --- 
a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res @@ -309,9 +309,6 @@ module T = Dist({ let logScore = (prior: t, prediction: t, answer: float) => { Error(Operation.NotYetImplemented) } - let logScoreAgainstImproperPrior = (prediction: t, answer: float) => { - Error(Operation.NotYetImplemented) - } }) let combineAlgebraically = (op: Operation.convolutionOperation, t1: t, t2: t): t => { diff --git a/packages/squiggle-lang/src/rescript/MagicNumbers.res b/packages/squiggle-lang/src/rescript/MagicNumbers.res index 13beafe4..b859421c 100644 --- a/packages/squiggle-lang/src/rescript/MagicNumbers.res +++ b/packages/squiggle-lang/src/rescript/MagicNumbers.res @@ -12,6 +12,7 @@ module Epsilon = { module Environment = { let defaultXYPointLength = 1000 let defaultSampleCount = 10000 + let sparklineLength = 20 } module OpCost = { diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res index 8dae2586..7e5709bd 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res @@ -157,6 +157,20 @@ module Helpers = { } } } + let constructNonNormalizedPointSet = ( + ~supportOf: DistributionTypes.genericDist, + fn: float => float, + ): option => { + switch supportOf { + | PointSet(Continuous(dist)) => + {xs: dist.xyShape.xs, ys: E.A.fmap(fn, dist.xyShape.xs)} + ->Continuous.make + ->Continuous + ->PointSet + ->Some + | _ => None + } + } } module SymbolicConstructors = { @@ -219,7 +233,8 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment) | ("mean", [EvDistribution(dist)]) => Helpers.toFloatFn(#Mean, dist) | ("integralSum", [EvDistribution(dist)]) => 
Helpers.toFloatFn(#IntegralSum, dist) | ("toString", [EvDistribution(dist)]) => Helpers.toStringFn(ToString, dist) - | ("toSparkline", [EvDistribution(dist)]) => Helpers.toStringFn(ToSparkline(20), dist) + | ("toSparkline", [EvDistribution(dist)]) => + Helpers.toStringFn(ToSparkline(MagicNumbers.Environment.sparklineLength), dist) | ("toSparkline", [EvDistribution(dist), EvNumber(n)]) => Helpers.toStringFn(ToSparkline(Belt.Float.toInt(n)), dist) | ("exp", [EvDistribution(a)]) => @@ -233,6 +248,21 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment) | ("normalize", [EvDistribution(dist)]) => Helpers.toDistFn(Normalize, dist) | ("klDivergence", [EvDistribution(a), EvDistribution(b)]) => Some(runGenericOperation(FromDist(ToScore(KLDivergence(b)), a))) + | ("logScore", [EvDistribution(prior), EvDistribution(prediction), EvNumber(answer)]) + | ( + "logScore", + [EvDistribution(prior), EvDistribution(prediction), EvDistribution(Symbolic(#Float(answer)))], + ) => + Some(runGenericOperation(FromDist(ToScore(LogScore(prediction, answer)), prior))) + | ("logScoreAgainstImproperPrior", [EvDistribution(prediction), EvNumber(answer)]) + | ( + "logScoreAgainstImproperPrior", + [EvDistribution(prediction), EvDistribution(Symbolic(#Float(answer)))], + ) => + E.O.fmap( + d => runGenericOperation(FromDist(ToScore(LogScore(prediction, answer)), d)), + Helpers.constructNonNormalizedPointSet(~supportOf=prediction, _ => 1.0), + ) | ("isNormalized", [EvDistribution(dist)]) => Helpers.toBoolFn(IsNormalized, dist) | ("toPointSet", [EvDistribution(dist)]) => Helpers.toDistFn(ToPointSet, dist) | ("scaleLog", [EvDistribution(dist)]) => From 65751e590a7c240680617e3a1e8d82c4a1aec8d3 Mon Sep 17 00:00:00 2001 From: Quinn Dougherty Date: Thu, 12 May 2022 15:26:51 -0400 Subject: [PATCH 03/12] Fixed `logScoreAgainstImproperPrior` by finding how it was `None` Value: [1e-4 to 8e-2] --- .../ReducerInterface_GenericDistribution.res | 30 ++++++++++--------- 1 file changed, 
16 insertions(+), 14 deletions(-) diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res index 7e5709bd..b7f5f824 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res @@ -160,16 +160,16 @@ module Helpers = { let constructNonNormalizedPointSet = ( ~supportOf: DistributionTypes.genericDist, fn: float => float, - ): option => { - switch supportOf { - | PointSet(Continuous(dist)) => - {xs: dist.xyShape.xs, ys: E.A.fmap(fn, dist.xyShape.xs)} - ->Continuous.make - ->Continuous - ->PointSet - ->Some - | _ => None + ): DistributionTypes.genericDist => { + let cdf = x => toFloatFn(#Cdf(x), supportOf) + let leftEndpoint = cdf(MagicNumbers.Epsilon.ten) + let rightEndpoint = cdf(1.0 -. MagicNumbers.Epsilon.ten) + let xs = switch (leftEndpoint, rightEndpoint) { + | (Some(Float(a)), Some(Float(b))) => + E.A.Floats.range(a, b, MagicNumbers.Environment.defaultXYPointLength) + | _ => [] } + {xs: xs, ys: E.A.fmap(fn, xs)}->Continuous.make->Continuous->DistributionTypes.PointSet } } @@ -253,16 +253,18 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment) "logScore", [EvDistribution(prior), EvDistribution(prediction), EvDistribution(Symbolic(#Float(answer)))], ) => - Some(runGenericOperation(FromDist(ToScore(LogScore(prediction, answer)), prior))) + runGenericOperation(FromDist(ToScore(LogScore(prediction, answer)), prior))->Some | ("logScoreAgainstImproperPrior", [EvDistribution(prediction), EvNumber(answer)]) | ( "logScoreAgainstImproperPrior", [EvDistribution(prediction), EvDistribution(Symbolic(#Float(answer)))], ) => - E.O.fmap( - d => runGenericOperation(FromDist(ToScore(LogScore(prediction, answer)), d)), - 
Helpers.constructNonNormalizedPointSet(~supportOf=prediction, _ => 1.0), - ) + runGenericOperation( + FromDist( + ToScore(LogScore(prediction, answer)), + Helpers.constructNonNormalizedPointSet(~supportOf=prediction, _ => 1.0), + ), + )->Some | ("isNormalized", [EvDistribution(dist)]) => Helpers.toBoolFn(IsNormalized, dist) | ("toPointSet", [EvDistribution(dist)]) => Helpers.toDistFn(ToPointSet, dist) | ("scaleLog", [EvDistribution(dist)]) => From 2ab395b4e5ff727f08826b3c470b024b51ef3104 Mon Sep 17 00:00:00 2001 From: Quinn Dougherty Date: Thu, 12 May 2022 16:03:29 -0400 Subject: [PATCH 04/12] Some minor CR Value: [1e-10 to 1e-4] --- .../src/rescript/Distributions/GenericDist.res | 8 ++++---- .../rescript/Distributions/PointSetDist/PointSetDist.res | 3 --- .../ReducerInterface_GenericDistribution.res | 5 ++++- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res index 1995ef34..3e067c51 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res @@ -60,10 +60,10 @@ let integralEndY = (t: t): float => let isNormalized = (t: t): bool => Js.Math.abs_float(integralEndY(t) -. 
1.0) < 1e-7 module Score = { - let klDivergence = (t1, t2, ~toPointSetFn: toPointSetFn): result => { - let pointSets = E.R.merge(toPointSetFn(t1), toPointSetFn(t2)) - pointSets |> E.R2.bind(((a, b)) => - PointSetDist.T.klDivergence(a, b)->E.R2.errMap(x => DistributionTypes.OperationError(x)) + let klDivergence = (prediction, answer, ~toPointSetFn: toPointSetFn): result => { + let pointSets = E.R.merge(toPointSetFn(prediction), toPointSetFn(answer)) + pointSets |> E.R2.bind(((predi, ans)) => + PointSetDist.T.klDivergence(predi, ans)->E.R2.errMap(x => DistributionTypes.OperationError(x)) ) } diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res index 05e79830..2f8ebee3 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res @@ -209,9 +209,6 @@ module T = Dist({ | _ => Error(Operation.NotYetImplemented) } } - let logScoreAgainstImproperPrior = (prediction: t, answer: float) => { - Error(Operation.NotYetImplemented) - } }) let pdf = (f: float, t: t) => { diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res index b7f5f824..fb26189f 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res @@ -169,7 +169,10 @@ module Helpers = { E.A.Floats.range(a, b, MagicNumbers.Environment.defaultXYPointLength) | _ => [] } - {xs: xs, ys: E.A.fmap(fn, xs)}->Continuous.make->Continuous->DistributionTypes.PointSet + {xs: xs, ys: E.A.fmap(fn, xs)} + ->Continuous.make + ->PointSetTypes.Continuous + ->DistributionTypes.PointSet } } From 
3eef57f8556aebc519cb4d895ca456135a1aa0d5 Mon Sep 17 00:00:00 2001 From: Quinn Dougherty Date: Fri, 13 May 2022 13:18:52 -0400 Subject: [PATCH 05/12] proof of concept for records as arguments Value: [1e-3 to 8e-1] --- .../ReducerInterface_GenericDistribution.res | 18 ++++++++++++++++-- .../squiggle-lang/src/rescript/Utility/E.res | 11 +++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res index fb26189f..b2bf8a32 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res @@ -1,5 +1,5 @@ module ExpressionValue = ReducerInterface_ExpressionValue -type expressionValue = ReducerInterface_ExpressionValue.expressionValue +type expressionValue = ExpressionValue.expressionValue let defaultEnv: DistributionOperation.env = { sampleCount: MagicNumbers.Environment.defaultSampleCount, @@ -210,7 +210,7 @@ module SymbolicConstructors = { } } -let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment): option< +let rec dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment): option< DistributionOperation.outputType, > => { let (fnName, args) = call @@ -257,6 +257,10 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment) [EvDistribution(prior), EvDistribution(prediction), EvDistribution(Symbolic(#Float(answer)))], ) => runGenericOperation(FromDist(ToScore(LogScore(prediction, answer)), prior))->Some + | ("logScore", [EvRecord(r)]) => + recurRecordArgs("logScore", ["prior", "prediction", "answer"], r, _environment) + | ("increment", [EvNumber(x)]) => (x +. 
1.0)->DistributionOperation.Float->Some + | ("increment", [EvRecord(r)]) => recurRecordArgs("increment", ["incrementee"], r, _environment) | ("logScoreAgainstImproperPrior", [EvDistribution(prediction), EvNumber(answer)]) | ( "logScoreAgainstImproperPrior", @@ -340,6 +344,16 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment) | _ => None } } +and recurRecordArgs = ( + fnName: string, + argNames: array, + args: ExpressionValue.record, + _environment: 'a, +): option => + // argNames -> E.A2.fmap(x => Js.Dict.get(args, x)) -> E.A.O.arrSomeToSomeArr -> E.O.bind(a => dispatchToGenericOutput((fnName, a), _environment)) + argNames + ->E.A2.fmap(x => Js.Dict.unsafeGet(args, x)) + ->(a => dispatchToGenericOutput((fnName, a), _environment)) let genericOutputToReducerValue = (o: DistributionOperation.outputType): result< expressionValue, diff --git a/packages/squiggle-lang/src/rescript/Utility/E.res b/packages/squiggle-lang/src/rescript/Utility/E.res index 15678e1a..63999ce6 100644 --- a/packages/squiggle-lang/src/rescript/Utility/E.res +++ b/packages/squiggle-lang/src/rescript/Utility/E.res @@ -620,6 +620,17 @@ module A = { | Some(o) => o | None => [] } + let rec arrSomeToSomeArr = (optionals: array>): option> => { + let optionals' = optionals->Belt.List.fromArray + switch optionals' { + | list{} => []->Some + | list{x, ...xs} => + switch x { + | Some(_) => xs->Belt.List.toArray->arrSomeToSomeArr + | None => None + } + } + } } module R = { From b4a1137019731994dc1a915f28740158d4299228 Mon Sep 17 00:00:00 2001 From: Quinn Dougherty Date: Fri, 13 May 2022 15:43:59 -0400 Subject: [PATCH 06/12] Implemented correct math underlying `logScoreWithPointResolution` Value: [1e-2 to 7e-1] Realized that I need to switch argument order, put `prior` last maybe. 
--- .../Distributions/DistributionOperation.res | 4 +-- .../Distributions/DistributionTypes.res | 2 +- .../rescript/Distributions/GenericDist.res | 31 +++++++++++++++--- .../rescript/Distributions/GenericDist.resi | 7 +++- .../Distributions/PointSetDist/Continuous.res | 11 +++---- .../Distributions/PointSetDist/Discrete.res | 2 +- .../PointSetDist/Distributions.res | 4 +-- .../Distributions/PointSetDist/Mixed.res | 2 +- .../PointSetDist/PointSetDist.res | 6 ++-- .../PointSetDist/PointSetDist_Scoring.res | 32 ++++++++++++++----- 10 files changed, 71 insertions(+), 30 deletions(-) diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res index 97085133..4c4291eb 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res @@ -149,7 +149,7 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => { ->E.R2.fmap(r => Float(r)) ->OutputLocal.fromResult | ToScore(LogScore(prediction, answer)) => - GenericDist.Score.logScore(dist, prediction, answer, ~toPointSetFn) + GenericDist.Score.logScoreWithPointResolution(Some(dist), prediction, answer, ~toPointSetFn) ->E.R2.fmap(r => Float(r)) ->OutputLocal.fromResult | ToBool(IsNormalized) => dist->GenericDist.isNormalized->Bool @@ -267,7 +267,7 @@ module Constructors = { let isNormalized = (~env, dist) => C.isNormalized(dist)->run(~env)->toBoolR let klDivergence = (~env, dist1, dist2) => C.klDivergence(dist1, dist2)->run(~env)->toFloatR let logScore = (~env, prior, prediction, answer) => - C.logScore(prior, prediction, answer)->run(~env)->toFloatR + C.logScoreWithPointResolution(prior, prediction, answer)->run(~env)->toFloatR let toPointSet = (~env, dist) => C.toPointSet(dist)->run(~env)->toDistR let toSampleSet = (~env, dist, n) => C.toSampleSet(dist, n)->run(~env)->toDistR let fromSamples = 
(~env, xs) => C.fromSamples(xs)->run(~env)->toDistR diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res index 35e1b1a7..480775bf 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res @@ -162,7 +162,7 @@ module Constructors = { let truncate = (dist, left, right): t => FromDist(ToDist(Truncate(left, right)), dist) let inspect = (dist): t => FromDist(ToDist(Inspect), dist) let klDivergence = (dist1, dist2): t => FromDist(ToScore(KLDivergence(dist2)), dist1) - let logScore = (prior, prediction, answer): t => FromDist( + let logScoreWithPointResolution = (prior, prediction, answer): t => FromDist( ToScore(LogScore(prediction, answer)), prior, ) diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res index 3e067c51..3357556f 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res @@ -67,11 +67,32 @@ module Score = { ) } - let logScore = (prior, prediction, answer, ~toPointSetFn: toPointSetFn): result => { - let pointSets = E.R.merge(toPointSetFn(prior), toPointSetFn(prediction)) - pointSets |> E.R2.bind(((a, b)) => - PointSetDist.T.logScore(a, b, answer)->E.R2.errMap(x => DistributionTypes.OperationError(x)) - ) + let logScoreWithPointResolution = ( + prior, + prediction, + answer, + ~toPointSetFn: toPointSetFn, + ): result => { + switch prior { + | Some(prior') => + E.R.merge(toPointSetFn(prior'), toPointSetFn(prediction))->E.R.bind(((a, b)) => + PointSetDist.T.logScoreWithPointResolution( + a->Some, + b, + answer, + )->E.R2.errMap(x => DistributionTypes.OperationError(x)) + ) + | None => + prediction + ->toPointSetFn + ->E.R.bind(x => + 
PointSetDist.T.logScoreWithPointResolution( + None, + x, + answer, + )->E.R2.errMap(x => DistributionTypes.OperationError(x)) + ) + } } } diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi index 45f1e8f8..712e38b9 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi @@ -25,7 +25,12 @@ let toFloatOperation: ( module Score: { let klDivergence: (t, t, ~toPointSetFn: toPointSetFn) => result - let logScore: (t, t, float, ~toPointSetFn: toPointSetFn) => result + let logScoreWithPointResolution: ( + option, + t, + float, + ~toPointSetFn: toPointSetFn, + ) => result } @genType diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res index 7bf9c874..22aaaf3b 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res @@ -279,13 +279,10 @@ module T = Dist({ ) newShape->E.R2.fmap(x => x->make->integralEndY) } - let logScore = (prior: t, prediction: t, answer: float) => { - let newShape = XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument( - PointSetDist_Scoring.LogScore.integrand(~answer), - prior.xyShape, - prediction.xyShape, - ) - newShape->E.R2.fmap(x => x->make->integralEndY) + let logScoreWithPointResolution = (prior: option, prediction: t, answer: float) => { + let priorPdf = prior->E.O2.fmap((shape, x) => XYShape.XtoY.linear(x, shape.xyShape)) + let predictionPdf = x => XYShape.XtoY.linear(x, prediction.xyShape) + PointSetDist_Scoring.LogScoreWithPointResolution.score(~priorPdf, ~predictionPdf, ~answer) } }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res 
b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res index e10ed981..26e15f6e 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res @@ -229,7 +229,7 @@ module T = Dist({ answer, )->E.R2.fmap(integralEndY) } - let logScore = (prior: t, prediction: t, answer: float) => { + let logScoreWithPointResolution = (prior: option, prediction: t, answer: float) => { Error(Operation.NotYetImplemented) } }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res index aa44da10..014c0668 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res @@ -34,7 +34,7 @@ module type dist = { let mean: t => float let variance: t => float let klDivergence: (t, t) => result - let logScore: (t, t, float) => result + let logScoreWithPointResolution: (option, t, float) => result } module Dist = (T: dist) => { @@ -58,7 +58,7 @@ module Dist = (T: dist) => { let variance = T.variance let integralEndY = T.integralEndY let klDivergence = T.klDivergence - let logScore = T.logScore + let logScoreWithPointResolution = T.logScoreWithPointResolution let updateIntegralCache = T.updateIntegralCache diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res index 2fd046ee..4f864856 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res @@ -306,7 +306,7 @@ module T = Dist({ let klContinuousPart = Continuous.T.klDivergence(prediction.continuous, answer.continuous) E.R.merge(klDiscretePart, klContinuousPart)->E.R2.fmap(t => 
fst(t) +. snd(t)) } - let logScore = (prior: t, prediction: t, answer: float) => { + let logScoreWithPointResolution = (prior: option, prediction: t, answer: float) => { Error(Operation.NotYetImplemented) } }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res index 2f8ebee3..7f87fd01 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res @@ -203,9 +203,11 @@ module T = Dist({ | (m1, m2) => Mixed.T.klDivergence(m1->toMixed, m2->toMixed) } - let logScore = (prior: t, prediction: t, answer: float) => { + let logScoreWithPointResolution = (prior: option, prediction: t, answer: float) => { switch (prior, prediction) { - | (Continuous(t1), Continuous(t2)) => Continuous.T.logScore(t1, t2, answer) + | (Some(Continuous(t1)), Continuous(t2)) => + Continuous.T.logScoreWithPointResolution(t1->Some, t2, answer) + | (None, Continuous(t2)) => Continuous.T.logScoreWithPointResolution(None, t2, answer) | _ => Error(Operation.NotYetImplemented) } } diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res index 8daf260c..ddf5207d 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res @@ -18,16 +18,32 @@ module KLDivergence = { /* */ -module LogScore = { +module LogScoreWithPointResolution = { let logFn = Js.Math.log - let integrand = (priorElement: float, predictionElement: float, ~answer: float) => { - if answer == 0.0 { - Ok(0.0) - } else if predictionElement == 0.0 { - Ok(infinity) + let score = ( + ~priorPdf: option float>, + ~predictionPdf: float => float, + 
~answer: float, + ): result => { + let numer = answer->predictionPdf + if numer < 0.0 { + Operation.ComplexNumberError->Error + } else if numer == 0.0 { + infinity->Ok } else { - let quot = predictionElement /. priorElement - quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answer *. logFn(quot /. answer)) + -.( + switch priorPdf { + | None => numer->logFn + | Some(f) => { + let priorDensityOfAnswer = f(answer) + if priorDensityOfAnswer == 0.0 { + neg_infinity + } else { + (numer /. priorDensityOfAnswer)->logFn + } + } + } + )->Ok } } } From bdbb86aa9ed4423cd988efbaf29cb807060c634c Mon Sep 17 00:00:00 2001 From: Quinn Dougherty Date: Fri, 13 May 2022 16:15:04 -0400 Subject: [PATCH 07/12] `logScore` on records now interprets almost every which way we're interested in Value: [1e-3 to 9e-1] --- .../Distributions/DistributionOperation.res | 8 ++-- .../Distributions/DistributionOperation.resi | 7 +++- .../Distributions/DistributionTypes.res | 4 +- .../rescript/Distributions/GenericDist.res | 6 +-- .../rescript/Distributions/GenericDist.resi | 2 +- .../Distributions/PointSetDist/Continuous.res | 2 +- .../Distributions/PointSetDist/Discrete.res | 2 +- .../PointSetDist/Distributions.res | 2 +- .../Distributions/PointSetDist/Mixed.res | 2 +- .../PointSetDist/PointSetDist.res | 6 +-- .../ReducerInterface_GenericDistribution.res | 37 +++++++++++-------- .../squiggle-lang/src/rescript/Utility/E.res | 11 ++++++ 12 files changed, 56 insertions(+), 33 deletions(-) diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res index 4c4291eb..1e2ff872 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res @@ -148,8 +148,8 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => { GenericDist.Score.klDivergence(dist, t2, ~toPointSetFn) 
->E.R2.fmap(r => Float(r)) ->OutputLocal.fromResult - | ToScore(LogScore(prediction, answer)) => - GenericDist.Score.logScoreWithPointResolution(Some(dist), prediction, answer, ~toPointSetFn) + | ToScore(LogScore(answer, prior)) => + GenericDist.Score.logScoreWithPointResolution(dist, answer, prior, ~toPointSetFn) ->E.R2.fmap(r => Float(r)) ->OutputLocal.fromResult | ToBool(IsNormalized) => dist->GenericDist.isNormalized->Bool @@ -266,8 +266,8 @@ module Constructors = { let normalize = (~env, dist) => C.normalize(dist)->run(~env)->toDistR let isNormalized = (~env, dist) => C.isNormalized(dist)->run(~env)->toBoolR let klDivergence = (~env, dist1, dist2) => C.klDivergence(dist1, dist2)->run(~env)->toFloatR - let logScore = (~env, prior, prediction, answer) => - C.logScoreWithPointResolution(prior, prediction, answer)->run(~env)->toFloatR + let logScoreWithPointResolution = (~env, prediction, answer, prior) => + C.logScoreWithPointResolution(prediction, answer, prior)->run(~env)->toFloatR let toPointSet = (~env, dist) => C.toPointSet(dist)->run(~env)->toDistR let toSampleSet = (~env, dist, n) => C.toSampleSet(dist, n)->run(~env)->toDistR let fromSamples = (~env, xs) => C.fromSamples(xs)->run(~env)->toDistR diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi index c3d14014..fffd011b 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi @@ -62,7 +62,12 @@ module Constructors: { @genType let klDivergence: (~env: env, genericDist, genericDist) => result @genType - let logScore: (~env: env, genericDist, genericDist, float) => result + let logScoreWithPointResolution: ( + ~env: env, + genericDist, + float, + option, + ) => result @genType let toPointSet: (~env: env, genericDist) => result @genType diff --git 
a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res index 480775bf..f377a616 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res @@ -91,7 +91,7 @@ module DistributionOperation = { | ToString | ToSparkline(int) - type toScore = KLDivergence(genericDist) | LogScore(genericDist, float) + type toScore = KLDivergence(genericDist) | LogScore(float, option) type fromDist = | ToFloat(toFloat) @@ -120,7 +120,7 @@ module DistributionOperation = { | ToFloat(#Sample) => `sample` | ToFloat(#IntegralSum) => `integralSum` | ToScore(KLDivergence(_)) => `klDivergence` - | ToScore(LogScore(_, x)) => `logScore against ${E.Float.toFixed(x)}` + | ToScore(LogScore(x, _)) => `logScore against ${E.Float.toFixed(x)}` | ToDist(Normalize) => `normalize` | ToDist(ToPointSet) => `toPointSet` | ToDist(ToSampleSet(r)) => `toSampleSet(${E.I.toString(r)})` diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res index 3357556f..c2b03474 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res @@ -68,18 +68,18 @@ module Score = { } let logScoreWithPointResolution = ( - prior, prediction, answer, + prior, ~toPointSetFn: toPointSetFn, ): result => { switch prior { | Some(prior') => E.R.merge(toPointSetFn(prior'), toPointSetFn(prediction))->E.R.bind(((a, b)) => PointSetDist.T.logScoreWithPointResolution( - a->Some, b, answer, + a->Some, )->E.R2.errMap(x => DistributionTypes.OperationError(x)) ) | None => @@ -87,9 +87,9 @@ module Score = { ->toPointSetFn ->E.R.bind(x => PointSetDist.T.logScoreWithPointResolution( - None, x, answer, + None, )->E.R2.errMap(x => DistributionTypes.OperationError(x)) ) } diff --git 
a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi index 712e38b9..ea9a4110 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi @@ -26,9 +26,9 @@ let toFloatOperation: ( module Score: { let klDivergence: (t, t, ~toPointSetFn: toPointSetFn) => result let logScoreWithPointResolution: ( - option, t, float, + option, ~toPointSetFn: toPointSetFn, ) => result } diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res index 22aaaf3b..c1d87946 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res @@ -279,7 +279,7 @@ module T = Dist({ ) newShape->E.R2.fmap(x => x->make->integralEndY) } - let logScoreWithPointResolution = (prior: option, prediction: t, answer: float) => { + let logScoreWithPointResolution = (prediction: t, answer: float, prior: option) => { let priorPdf = prior->E.O2.fmap((shape, x) => XYShape.XtoY.linear(x, shape.xyShape)) let predictionPdf = x => XYShape.XtoY.linear(x, prediction.xyShape) PointSetDist_Scoring.LogScoreWithPointResolution.score(~priorPdf, ~predictionPdf, ~answer) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res index 26e15f6e..c0e3f3a8 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res @@ -229,7 +229,7 @@ module T = Dist({ answer, )->E.R2.fmap(integralEndY) } - let logScoreWithPointResolution = (prior: option, prediction: t, answer: float) => { + let logScoreWithPointResolution = (prediction: 
t, answer: float, prior: option) => { Error(Operation.NotYetImplemented) } }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res index 014c0668..f28b6369 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res @@ -34,7 +34,7 @@ module type dist = { let mean: t => float let variance: t => float let klDivergence: (t, t) => result - let logScoreWithPointResolution: (option, t, float) => result + let logScoreWithPointResolution: (t, float, option) => result } module Dist = (T: dist) => { diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res index 4f864856..d3f09798 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res @@ -306,7 +306,7 @@ module T = Dist({ let klContinuousPart = Continuous.T.klDivergence(prediction.continuous, answer.continuous) E.R.merge(klDiscretePart, klContinuousPart)->E.R2.fmap(t => fst(t) +. 
snd(t)) } - let logScoreWithPointResolution = (prior: option, prediction: t, answer: float) => { + let logScoreWithPointResolution = (prediction: t, answer: float, prior: option) => { Error(Operation.NotYetImplemented) } }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res index 7f87fd01..cdeaef5a 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res @@ -203,11 +203,11 @@ module T = Dist({ | (m1, m2) => Mixed.T.klDivergence(m1->toMixed, m2->toMixed) } - let logScoreWithPointResolution = (prior: option, prediction: t, answer: float) => { + let logScoreWithPointResolution = (prediction: t, answer: float, prior: option) => { switch (prior, prediction) { | (Some(Continuous(t1)), Continuous(t2)) => - Continuous.T.logScoreWithPointResolution(t1->Some, t2, answer) - | (None, Continuous(t2)) => Continuous.T.logScoreWithPointResolution(None, t2, answer) + Continuous.T.logScoreWithPointResolution(t2, answer, t1->Some) + | (None, Continuous(t2)) => Continuous.T.logScoreWithPointResolution(t2, answer, None) | _ => Error(Operation.NotYetImplemented) } } diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res index b2bf8a32..bee0e3fe 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res @@ -251,27 +251,34 @@ let rec dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environm | ("normalize", [EvDistribution(dist)]) => Helpers.toDistFn(Normalize, dist) | ("klDivergence", [EvDistribution(a), EvDistribution(b)]) => 
Some(runGenericOperation(FromDist(ToScore(KLDivergence(b)), a))) - | ("logScore", [EvDistribution(prior), EvDistribution(prediction), EvNumber(answer)]) | ( - "logScore", - [EvDistribution(prior), EvDistribution(prediction), EvDistribution(Symbolic(#Float(answer)))], + "logScoreWithPointResolution", + [EvDistribution(prediction), EvNumber(answer), EvDistribution(prior)], + ) + | ( + "logScoreWithPointResolution", + [EvDistribution(prediction), EvDistribution(Symbolic(#Float(answer))), EvDistribution(prior)], ) => - runGenericOperation(FromDist(ToScore(LogScore(prediction, answer)), prior))->Some - | ("logScore", [EvRecord(r)]) => - recurRecordArgs("logScore", ["prior", "prediction", "answer"], r, _environment) - | ("increment", [EvNumber(x)]) => (x +. 1.0)->DistributionOperation.Float->Some - | ("increment", [EvRecord(r)]) => recurRecordArgs("increment", ["incrementee"], r, _environment) - | ("logScoreAgainstImproperPrior", [EvDistribution(prediction), EvNumber(answer)]) + runGenericOperation(FromDist(ToScore(LogScore(answer, prior->Some)), prediction))->Some + | ("logScoreWithPointResolution", [EvDistribution(prediction), EvNumber(answer)]) | ( - "logScoreAgainstImproperPrior", + "logScoreWithPointResolution", [EvDistribution(prediction), EvDistribution(Symbolic(#Float(answer)))], ) => - runGenericOperation( - FromDist( - ToScore(LogScore(prediction, answer)), - Helpers.constructNonNormalizedPointSet(~supportOf=prediction, _ => 1.0), + runGenericOperation(FromDist(ToScore(LogScore(answer, None)), prediction))->Some + | ("logScore", [EvRecord(r)]) => + [ + recurRecordArgs( + "logScoreWithPointResolution", + ["estimate", "answer", "prior"], + r, + _environment, ), - )->Some + recurRecordArgs("klDivergence", ["estimate", "answer"], r, _environment), + recurRecordArgs("logScoreWithPointResolution", ["estimate", "answer"], r, _environment), + ]->E.A.O.firstSome + | ("increment", [EvNumber(x)]) => (x +. 
1.0)->DistributionOperation.Float->Some // this tests recurRecordArgs function + | ("increment", [EvRecord(r)]) => recurRecordArgs("increment", ["incrementee"], r, _environment) | ("isNormalized", [EvDistribution(dist)]) => Helpers.toBoolFn(IsNormalized, dist) | ("toPointSet", [EvDistribution(dist)]) => Helpers.toDistFn(ToPointSet, dist) | ("scaleLog", [EvDistribution(dist)]) => diff --git a/packages/squiggle-lang/src/rescript/Utility/E.res b/packages/squiggle-lang/src/rescript/Utility/E.res index 63999ce6..35439230 100644 --- a/packages/squiggle-lang/src/rescript/Utility/E.res +++ b/packages/squiggle-lang/src/rescript/Utility/E.res @@ -631,6 +631,17 @@ module A = { } } } + let rec firstSome = (optionals: array>): option<'a> => { + let optionals' = optionals->Belt.List.fromArray + switch optionals' { + | list{} => None + | list{x, ...xs} => + switch x { + | Some(_) => x + | None => xs->Belt.List.toArray->firstSome + } + } + } } module R = { From 30ab62e9b8839ab949c677884dfea3b0a713e917 Mon Sep 17 00:00:00 2001 From: Quinn Dougherty Date: Mon, 16 May 2022 12:03:37 -0400 Subject: [PATCH 08/12] backed out of mutually recursive dispatch Value: [1e-5 to 1e-3] --- .../ReducerInterface_GenericDistribution.res | 35 +++++-------------- 1 file changed, 8 insertions(+), 27 deletions(-) diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res index 2e99994f..2484fa62 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res @@ -165,7 +165,7 @@ module Helpers = { let constructNonNormalizedPointSet = ( ~supportOf: DistributionTypes.genericDist, fn: float => float, - env: DistributionOperation.env + env: DistributionOperation.env, ): DistributionTypes.genericDist => { let cdf = x => 
toFloatFn(#Cdf(x), supportOf, ~env) let leftEndpoint = cdf(MagicNumbers.Epsilon.ten) @@ -216,7 +216,7 @@ module SymbolicConstructors = { } } -let rec dispatchToGenericOutput = ( +let dispatchToGenericOutput = ( call: ExpressionValue.functionCall, env: DistributionOperation.env, ): option => { @@ -243,7 +243,8 @@ let rec dispatchToGenericOutput = ( | ("mean", [EvDistribution(dist)]) => Helpers.toFloatFn(#Mean, dist, ~env) | ("integralSum", [EvDistribution(dist)]) => Helpers.toFloatFn(#IntegralSum, dist, ~env) | ("toString", [EvDistribution(dist)]) => Helpers.toStringFn(ToString, dist, ~env) - | ("toSparkline", [EvDistribution(dist)]) => Helpers.toStringFn(ToSparkline(MagicNumbers.Environment.sparklineLength), dist, ~env) + | ("toSparkline", [EvDistribution(dist)]) => + Helpers.toStringFn(ToSparkline(MagicNumbers.Environment.sparklineLength), dist, ~env) | ("toSparkline", [EvDistribution(dist), EvNumber(n)]) => Helpers.toStringFn(ToSparkline(Belt.Float.toInt(n)), dist, ~env) | ("exp", [EvDistribution(a)]) => @@ -266,26 +267,16 @@ let rec dispatchToGenericOutput = ( "logScoreWithPointResolution", [EvDistribution(prediction), EvDistribution(Symbolic(#Float(answer))), EvDistribution(prior)], ) => - DistributionOperation.run(FromDist(ToScore(LogScore(answer, prior->Some)), prediction), ~env)->Some + DistributionOperation.run( + FromDist(ToScore(LogScore(answer, prior->Some)), prediction), + ~env, + )->Some | ("logScoreWithPointResolution", [EvDistribution(prediction), EvNumber(answer)]) | ( "logScoreWithPointResolution", [EvDistribution(prediction), EvDistribution(Symbolic(#Float(answer)))], ) => DistributionOperation.run(FromDist(ToScore(LogScore(answer, None)), prediction), ~env)->Some - | ("logScore", [EvRecord(r)]) => - [ - recurRecordArgs( - "logScoreWithPointResolution", - ["estimate", "answer", "prior"], - r, - env, - ), - recurRecordArgs("klDivergence", ["estimate", "answer"], r, env), - recurRecordArgs("logScoreWithPointResolution", ["estimate", "answer"], r, 
env), - ]->E.A.O.firstSome - | ("increment", [EvNumber(x)]) => (x +. 1.0)->DistributionOperation.Float->Some // this tests recurRecordArgs function - | ("increment", [EvRecord(r)]) => recurRecordArgs("increment", ["incrementee"], r, env) | ("isNormalized", [EvDistribution(dist)]) => Helpers.toBoolFn(IsNormalized, dist, ~env) | ("toPointSet", [EvDistribution(dist)]) => Helpers.toDistFn(ToPointSet, dist, ~env) | ("scaleLog", [EvDistribution(dist)]) => @@ -372,16 +363,6 @@ let rec dispatchToGenericOutput = ( | _ => None } } -and recurRecordArgs = ( - fnName: string, - argNames: array, - args: ExpressionValue.record, - env: DistributionOperation.env, -): option => - // argNames -> E.A2.fmap(x => Js.Dict.get(args, x)) -> E.A.O.arrSomeToSomeArr -> E.O.bind(a => dispatchToGenericOutput((fnName, a), _environment)) - argNames - ->E.A2.fmap(x => Js.Dict.unsafeGet(args, x)) - ->(a => dispatchToGenericOutput((fnName, a), env)) let genericOutputToReducerValue = (o: DistributionOperation.outputType): result< expressionValue, From 3c3c88fb7bd15e356f05a5babcf48843a27b9439 Mon Sep 17 00:00:00 2001 From: Quinn Dougherty Date: Mon, 16 May 2022 12:06:21 -0400 Subject: [PATCH 09/12] `...Resolution` => `..Answer` --- .../ReducerInterface_GenericDistribution.res | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res index 2484fa62..4df742be 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res @@ -260,20 +260,20 @@ let dispatchToGenericOutput = ( | ("klDivergence", [EvDistribution(a), EvDistribution(b)]) => Some(DistributionOperation.run(FromDist(ToScore(KLDivergence(b)), a), ~env)) | ( - "logScoreWithPointResolution", + 
"logScoreWithPointAnswer", [EvDistribution(prediction), EvNumber(answer), EvDistribution(prior)], ) | ( - "logScoreWithPointResolution", + "logScoreWithPointAnswer", [EvDistribution(prediction), EvDistribution(Symbolic(#Float(answer))), EvDistribution(prior)], ) => DistributionOperation.run( FromDist(ToScore(LogScore(answer, prior->Some)), prediction), ~env, )->Some - | ("logScoreWithPointResolution", [EvDistribution(prediction), EvNumber(answer)]) + | ("logScoreWithPointAnswer", [EvDistribution(prediction), EvNumber(answer)]) | ( - "logScoreWithPointResolution", + "logScoreWithPointAnswer", [EvDistribution(prediction), EvDistribution(Symbolic(#Float(answer)))], ) => DistributionOperation.run(FromDist(ToScore(LogScore(answer, None)), prediction), ~env)->Some From 81b2c74ac8d4b889f894ee4393c4ed6c1d930a1c Mon Sep 17 00:00:00 2001 From: Quinn Dougherty Date: Mon, 16 May 2022 13:18:01 -0400 Subject: [PATCH 10/12] `klDivergence` with prior Value: [1e-4 to 5e-23] --- .../ReducerInterface_GenericDistribution.res | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res index 616ad39c..4021649f 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res @@ -180,6 +180,20 @@ module Helpers = { ->PointSetTypes.Continuous ->DistributionTypes.PointSet } + + let klDivergenceWithPrior = ( + prediction: DistributionTypes.genericDist, + answer: DistributionTypes.genericDist, + prior: DistributionTypes.genericDist, + env: DistributionOperation.env, + ) => { + let term1 = DistributionOperation.Constructors.klDivergence(~env, prediction, answer) + let term2 = DistributionOperation.Constructors.klDivergence(~env, prior, 
answer) + switch E.R.merge(term1, term2)->E.R2.fmap(((a, b)) => a -. b) { + | Ok(x) => x->DistributionOperation.Float->Some + | Error(_) => None + } + } } module SymbolicConstructors = { @@ -268,8 +282,10 @@ let dispatchToGenericOutput = ( ~env, )->Some | ("normalize", [EvDistribution(dist)]) => Helpers.toDistFn(Normalize, dist, ~env) - | ("klDivergence", [EvDistribution(a), EvDistribution(b)]) => - Some(DistributionOperation.run(FromDist(ToScore(KLDivergence(b)), a), ~env)) + | ("klDivergence", [EvDistribution(prediction), EvDistribution(answer)]) => + Some(DistributionOperation.run(FromDist(ToScore(KLDivergence(answer)), prediction), ~env)) + | ("klDivergence", [EvDistribution(prediction), EvDistribution(answer), EvDistribution(prior)]) => + Helpers.klDivergenceWithPrior(prediction, answer, prior, env) | ( "logScoreWithPointAnswer", [EvDistribution(prediction), EvNumber(answer), EvDistribution(prior)], From 1d2bb556de10028c98ed1ea6ede5549ff1f6c156 Mon Sep 17 00:00:00 2001 From: Quinn Dougherty Date: Mon, 16 May 2022 15:39:40 -0400 Subject: [PATCH 11/12] Minor CR comments Value: [1e-6 to 1e-3] --- .../PointSetDist/PointSetDist_Scoring.res | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res index ddf5207d..bfc40071 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res @@ -15,9 +15,6 @@ module KLDivergence = { } } -/* - -*/ module LogScoreWithPointResolution = { let logFn = Js.Math.log let score = ( @@ -25,21 +22,21 @@ module LogScoreWithPointResolution = { ~predictionPdf: float => float, ~answer: float, ): result => { - let numer = answer->predictionPdf - if numer < 0.0 { + let numerator = answer->predictionPdf + if 
numerator < 0.0 { Operation.ComplexNumberError->Error - } else if numer == 0.0 { + } else if numerator == 0.0 { infinity->Ok } else { -.( switch priorPdf { - | None => numer->logFn + | None => numerator->logFn | Some(f) => { let priorDensityOfAnswer = f(answer) if priorDensityOfAnswer == 0.0 { neg_infinity } else { - (numer /. priorDensityOfAnswer)->logFn + (numerator /. priorDensityOfAnswer)->logFn } } } From 9e7319ed5735dca757c90ebdc50cc5dab8ce5e24 Mon Sep 17 00:00:00 2001 From: Quinn Dougherty Date: Mon, 16 May 2022 18:06:14 -0400 Subject: [PATCH 12/12] More substantial CR; more named args Value: [1e-6 to 1e-2] --- .../Distributions/DistributionOperation.res | 15 +++++++++--- .../Distributions/DistributionOperation.resi | 6 ++--- .../Distributions/DistributionTypes.res | 6 ++--- .../rescript/Distributions/GenericDist.res | 23 +++++++++++-------- .../rescript/Distributions/GenericDist.resi | 6 ++--- .../Distributions/PointSetDist/Continuous.res | 2 +- .../Distributions/PointSetDist/Discrete.res | 2 +- .../PointSetDist/Distributions.res | 6 ++++- .../Distributions/PointSetDist/Mixed.res | 2 +- .../PointSetDist/PointSetDist.res | 7 +++--- .../PointSetDist/PointSetDist_Scoring.res | 2 +- .../ReducerInterface_GenericDistribution.res | 18 --------------- .../squiggle-lang/src/rescript/Utility/E.res | 13 ++--------- .../src/rescript/Utility/Operation.res | 4 ++-- 14 files changed, 51 insertions(+), 61 deletions(-) diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res index 9af9917d..61b5cd6b 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res @@ -150,7 +150,12 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => { ->E.R2.fmap(r => Float(r)) ->OutputLocal.fromResult | ToScore(LogScore(answer, prior)) => - 
GenericDist.Score.logScoreWithPointResolution(dist, answer, prior, ~toPointSetFn) + GenericDist.Score.logScoreWithPointResolution( + ~prediction=dist, + ~answer, + ~prior, + ~toPointSetFn, + ) ->E.R2.fmap(r => Float(r)) ->OutputLocal.fromResult | ToBool(IsNormalized) => dist->GenericDist.isNormalized->Bool @@ -267,8 +272,12 @@ module Constructors = { let normalize = (~env, dist) => C.normalize(dist)->run(~env)->toDistR let isNormalized = (~env, dist) => C.isNormalized(dist)->run(~env)->toBoolR let klDivergence = (~env, dist1, dist2) => C.klDivergence(dist1, dist2)->run(~env)->toFloatR - let logScoreWithPointResolution = (~env, prediction, answer, prior) => - C.logScoreWithPointResolution(prediction, answer, prior)->run(~env)->toFloatR + let logScoreWithPointResolution = ( + ~env, + ~prediction: DistributionTypes.genericDist, + ~answer: float, + ~prior: option, + ) => C.logScoreWithPointResolution(~prediction, ~answer, ~prior)->run(~env)->toFloatR let toPointSet = (~env, dist) => C.toPointSet(dist)->run(~env)->toDistR let toSampleSet = (~env, dist, n) => C.toSampleSet(dist, n)->run(~env)->toDistR let fromSamples = (~env, xs) => C.fromSamples(xs)->run(~env)->toDistR diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi index 7941489e..aa006c06 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi @@ -65,9 +65,9 @@ module Constructors: { @genType let logScoreWithPointResolution: ( ~env: env, - genericDist, - float, - option, + ~prediction: genericDist, + ~answer: float, + ~prior: option, ) => result @genType let toPointSet: (~env: env, genericDist) => result diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res index f377a616..2bb409ad 
100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res @@ -162,9 +162,9 @@ module Constructors = { let truncate = (dist, left, right): t => FromDist(ToDist(Truncate(left, right)), dist) let inspect = (dist): t => FromDist(ToDist(Inspect), dist) let klDivergence = (dist1, dist2): t => FromDist(ToScore(KLDivergence(dist2)), dist1) - let logScoreWithPointResolution = (prior, prediction, answer): t => FromDist( - ToScore(LogScore(prediction, answer)), - prior, + let logScoreWithPointResolution = (~prediction, ~answer, ~prior): t => FromDist( + ToScore(LogScore(answer, prior)), + prediction, ) let scalePower = (dist, n): t => FromDist(ToDist(Scale(#Power, n)), dist) let scaleLogarithm = (dist, n): t => FromDist(ToDist(Scale(#Logarithm, n)), dist) diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res index c2b03474..1df10240 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res @@ -68,18 +68,21 @@ module Score = { } let logScoreWithPointResolution = ( - prediction, - answer, - prior, + ~prediction: DistributionTypes.genericDist, + ~answer: float, + ~prior: option, ~toPointSetFn: toPointSetFn, ): result => { switch prior { | Some(prior') => - E.R.merge(toPointSetFn(prior'), toPointSetFn(prediction))->E.R.bind(((a, b)) => + E.R.merge(toPointSetFn(prior'), toPointSetFn(prediction))->E.R.bind((( + prior'', + prediction'', + )) => PointSetDist.T.logScoreWithPointResolution( - b, - answer, - a->Some, + ~prediction=prediction'', + ~answer, + ~prior=prior''->Some, )->E.R2.errMap(x => DistributionTypes.OperationError(x)) ) | None => @@ -87,9 +90,9 @@ module Score = { ->toPointSetFn ->E.R.bind(x => PointSetDist.T.logScoreWithPointResolution( - x, - answer, - None, + ~prediction=x, + 
~answer, + ~prior=None, )->E.R2.errMap(x => DistributionTypes.OperationError(x)) ) } diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi index ea9a4110..79fb54ab 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi @@ -26,9 +26,9 @@ let toFloatOperation: ( module Score: { let klDivergence: (t, t, ~toPointSetFn: toPointSetFn) => result let logScoreWithPointResolution: ( - t, - float, - option, + ~prediction: t, + ~answer: float, + ~prior: option, ~toPointSetFn: toPointSetFn, ) => result } diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res index c1d87946..3661a531 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res @@ -279,7 +279,7 @@ module T = Dist({ ) newShape->E.R2.fmap(x => x->make->integralEndY) } - let logScoreWithPointResolution = (prediction: t, answer: float, prior: option) => { + let logScoreWithPointResolution = (~prediction: t, ~answer: float, ~prior: option) => { let priorPdf = prior->E.O2.fmap((shape, x) => XYShape.XtoY.linear(x, shape.xyShape)) let predictionPdf = x => XYShape.XtoY.linear(x, prediction.xyShape) PointSetDist_Scoring.LogScoreWithPointResolution.score(~priorPdf, ~predictionPdf, ~answer) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res index c0e3f3a8..fea5db6f 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res @@ -229,7 +229,7 @@ module T = Dist({ answer, 
)->E.R2.fmap(integralEndY) } - let logScoreWithPointResolution = (prediction: t, answer: float, prior: option) => { + let logScoreWithPointResolution = (~prediction: t, ~answer: float, ~prior: option) => { Error(Operation.NotYetImplemented) } }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res index f28b6369..2d0358ec 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res @@ -34,7 +34,11 @@ module type dist = { let mean: t => float let variance: t => float let klDivergence: (t, t) => result - let logScoreWithPointResolution: (t, float, option) => result + let logScoreWithPointResolution: ( + ~prediction: t, + ~answer: float, + ~prior: option, + ) => result } module Dist = (T: dist) => { diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res index d3f09798..42a88909 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res @@ -306,7 +306,7 @@ module T = Dist({ let klContinuousPart = Continuous.T.klDivergence(prediction.continuous, answer.continuous) E.R.merge(klDiscretePart, klContinuousPart)->E.R2.fmap(t => fst(t) +. 
snd(t)) } - let logScoreWithPointResolution = (prediction: t, answer: float, prior: option) => { + let logScoreWithPointResolution = (~prediction: t, ~answer: float, ~prior: option) => { Error(Operation.NotYetImplemented) } }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res index cdeaef5a..d21a7383 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res @@ -203,11 +203,12 @@ module T = Dist({ | (m1, m2) => Mixed.T.klDivergence(m1->toMixed, m2->toMixed) } - let logScoreWithPointResolution = (prediction: t, answer: float, prior: option) => { + let logScoreWithPointResolution = (~prediction: t, ~answer: float, ~prior: option) => { switch (prior, prediction) { | (Some(Continuous(t1)), Continuous(t2)) => - Continuous.T.logScoreWithPointResolution(t2, answer, t1->Some) - | (None, Continuous(t2)) => Continuous.T.logScoreWithPointResolution(t2, answer, None) + Continuous.T.logScoreWithPointResolution(~prediction=t2, ~answer, ~prior=t1->Some) + | (None, Continuous(t2)) => + Continuous.T.logScoreWithPointResolution(~prediction=t2, ~answer, ~prior=None) | _ => Error(Operation.NotYetImplemented) } } diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res index bfc40071..532bc76c 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res @@ -24,7 +24,7 @@ module LogScoreWithPointResolution = { ): result => { let numerator = answer->predictionPdf if numerator < 0.0 { - Operation.ComplexNumberError->Error + Operation.PdfInvalidError->Error } else if numerator == 0.0 { 
infinity->Ok } else { diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res index 4021649f..0f29f9db 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res @@ -162,24 +162,6 @@ module Helpers = { } } } - let constructNonNormalizedPointSet = ( - ~supportOf: DistributionTypes.genericDist, - fn: float => float, - env: DistributionOperation.env, - ): DistributionTypes.genericDist => { - let cdf = x => toFloatFn(#Cdf(x), supportOf, ~env) - let leftEndpoint = cdf(MagicNumbers.Epsilon.ten) - let rightEndpoint = cdf(1.0 -. MagicNumbers.Epsilon.ten) - let xs = switch (leftEndpoint, rightEndpoint) { - | (Some(Float(a)), Some(Float(b))) => - E.A.Floats.range(a, b, MagicNumbers.Environment.defaultXYPointLength) - | _ => [] - } - {xs: xs, ys: E.A.fmap(fn, xs)} - ->Continuous.make - ->PointSetTypes.Continuous - ->DistributionTypes.PointSet - } let klDivergenceWithPrior = ( prediction: DistributionTypes.genericDist, diff --git a/packages/squiggle-lang/src/rescript/Utility/E.res b/packages/squiggle-lang/src/rescript/Utility/E.res index 35439230..64729324 100644 --- a/packages/squiggle-lang/src/rescript/Utility/E.res +++ b/packages/squiggle-lang/src/rescript/Utility/E.res @@ -620,6 +620,7 @@ module A = { | Some(o) => o | None => [] } + // Returns `None` if there are no non-`None` elements let rec arrSomeToSomeArr = (optionals: array>): option> => { let optionals' = optionals->Belt.List.fromArray switch optionals' { @@ -631,17 +632,7 @@ module A = { } } } - let rec firstSome = (optionals: array>): option<'a> => { - let optionals' = optionals->Belt.List.fromArray - switch optionals' { - | list{} => None - | list{x, ...xs} => - switch x { - | Some(_) => x - | None =>
xs->Belt.List.toArray->firstSome - } - } - } + let firstSome = x => Belt.Array.getBy(x, O.isSome) } module R = { diff --git a/packages/squiggle-lang/src/rescript/Utility/Operation.res b/packages/squiggle-lang/src/rescript/Utility/Operation.res index cfa18925..3f56493b 100644 --- a/packages/squiggle-lang/src/rescript/Utility/Operation.res +++ b/packages/squiggle-lang/src/rescript/Utility/Operation.res @@ -55,7 +55,7 @@ type operationError = | ComplexNumberError | InfinityError | NegativeInfinityError - | LogicallyInconsistentPathwayError + | PdfInvalidError | NotYetImplemented // should be removed when `klDivergence` for mixed and discrete is implemented. @genType @@ -69,7 +69,7 @@ module Error = { | ComplexNumberError => "Operation returned complex result" | InfinityError => "Operation returned positive infinity" | NegativeInfinityError => "Operation returned negative infinity" - | LogicallyInconsistentPathwayError => "This pathway should have been logically unreachable" + | PdfInvalidError => "This Pdf is invalid" | NotYetImplemented => "This pathway is not yet implemented" } }