From 51310819a12a510fc0b539a599d013bc9863e3b5 Mon Sep 17 00:00:00 2001 From: Quinn Dougherty Date: Thu, 12 May 2022 14:16:52 -0400 Subject: [PATCH] `logScore` now in interface. Value: [1e-4 to 1e-1] --- .../Distributions/DistributionOperation.res | 8 ++++- .../Distributions/DistributionOperation.resi | 2 ++ .../Distributions/DistributionTypes.res | 7 +++- .../rescript/Distributions/GenericDist.res | 19 ++++++++--- .../rescript/Distributions/GenericDist.resi | 5 ++- .../Distributions/PointSetDist/Continuous.res | 4 --- .../Distributions/PointSetDist/Discrete.res | 3 -- .../PointSetDist/Distributions.res | 2 -- .../Distributions/PointSetDist/Mixed.res | 3 -- .../src/rescript/MagicNumbers.res | 1 + .../ReducerInterface_GenericDistribution.res | 32 ++++++++++++++++++- 11 files changed, 65 insertions(+), 21 deletions(-) diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res index 80a53eb8..97085133 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res @@ -145,7 +145,11 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => { } | ToDist(Normalize) => dist->GenericDist.normalize->Dist | ToScore(KLDivergence(t2)) => - GenericDist.klDivergence(dist, t2, ~toPointSetFn) + GenericDist.Score.klDivergence(dist, t2, ~toPointSetFn) + ->E.R2.fmap(r => Float(r)) + ->OutputLocal.fromResult + | ToScore(LogScore(prediction, answer)) => + GenericDist.Score.logScore(dist, prediction, answer, ~toPointSetFn) ->E.R2.fmap(r => Float(r)) ->OutputLocal.fromResult | ToBool(IsNormalized) => dist->GenericDist.isNormalized->Bool @@ -262,6 +266,8 @@ module Constructors = { let normalize = (~env, dist) => C.normalize(dist)->run(~env)->toDistR let isNormalized = (~env, dist) => C.isNormalized(dist)->run(~env)->toBoolR let klDivergence = (~env, dist1, dist2) => C.klDivergence(dist1, dist2)->run(~env)->toFloatR + let logScore = (~env, prior, prediction, answer) => + C.logScore(prior, prediction, answer)->run(~env)->toFloatR let toPointSet = (~env, dist) => C.toPointSet(dist)->run(~env)->toDistR let toSampleSet = (~env, dist, n) => C.toSampleSet(dist, n)->run(~env)->toDistR let fromSamples = (~env, xs) => C.fromSamples(xs)->run(~env)->toDistR diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi index 200be7d7..c3d14014 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi @@ -62,6 +62,8 @@ module Constructors: { @genType let klDivergence: (~env: env, genericDist, genericDist) => result @genType + let logScore: (~env: env, genericDist, genericDist, float) => result + @genType let toPointSet: (~env: env, genericDist) => result @genType let toSampleSet: (~env: env, genericDist, int) => result diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res index a9f7dfbe..35e1b1a7 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res @@ -91,7 +91,7 @@ module DistributionOperation = { | ToString | ToSparkline(int) - type toScore = KLDivergence(genericDist) + type toScore = KLDivergence(genericDist) | LogScore(genericDist, float) type fromDist = | ToFloat(toFloat) @@ -120,6 +120,7 @@ module DistributionOperation = { | ToFloat(#Sample) => `sample` | ToFloat(#IntegralSum) => `integralSum` | ToScore(KLDivergence(_)) => `klDivergence` + | ToScore(LogScore(_, x)) => `logScore against ${E.Float.toFixed(x)}` | ToDist(Normalize) => `normalize` | ToDist(ToPointSet) => `toPointSet` | ToDist(ToSampleSet(r)) => `toSampleSet(${E.I.toString(r)})` @@ -161,6 +162,10 @@ module Constructors = { let truncate = (dist, left, right): t => FromDist(ToDist(Truncate(left, right)), dist) let inspect = (dist): t => FromDist(ToDist(Inspect), dist) let klDivergence = (dist1, dist2): t => FromDist(ToScore(KLDivergence(dist2)), dist1) + let logScore = (prior, prediction, answer): t => FromDist( + ToScore(LogScore(prediction, answer)), + prior, + ) let scalePower = (dist, n): t => FromDist(ToDist(Scale(#Power, n)), dist) let scaleLogarithm = (dist, n): t => FromDist(ToDist(Scale(#Logarithm, n)), dist) let scaleLogarithmWithThreshold = (dist, n, eps): t => FromDist( diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res index 2085d72c..1995ef34 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res @@ -59,11 +59,20 @@ let integralEndY = (t: t): float => let isNormalized = (t: t): bool => Js.Math.abs_float(integralEndY(t) -. 1.0) < 1e-7 -let klDivergence = (t1, t2, ~toPointSetFn: toPointSetFn): result => { - let pointSets = E.R.merge(toPointSetFn(t1), toPointSetFn(t2)) - pointSets |> E.R2.bind(((a, b)) => - PointSetDist.T.klDivergence(a, b)->E.R2.errMap(x => DistributionTypes.OperationError(x)) - ) +module Score = { + let klDivergence = (t1, t2, ~toPointSetFn: toPointSetFn): result => { + let pointSets = E.R.merge(toPointSetFn(t1), toPointSetFn(t2)) + pointSets |> E.R2.bind(((a, b)) => + PointSetDist.T.klDivergence(a, b)->E.R2.errMap(x => DistributionTypes.OperationError(x)) + ) + } + + let logScore = (prior, prediction, answer, ~toPointSetFn: toPointSetFn): result => { + let pointSets = E.R.merge(toPointSetFn(prior), toPointSetFn(prediction)) + pointSets |> E.R2.bind(((a, b)) => + PointSetDist.T.logScore(a, b, answer)->E.R2.errMap(x => DistributionTypes.OperationError(x)) + ) + } } let toFloatOperation = ( diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi index 03bc5fe8..45f1e8f8 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi @@ -23,7 +23,10 @@ let toFloatOperation: ( ~distToFloatOperation: DistributionTypes.DistributionOperation.toFloat, ) => result -let klDivergence: (t, t, ~toPointSetFn: toPointSetFn) => result +module Score: { + let klDivergence: (t, t, ~toPointSetFn: toPointSetFn) => result + let logScore: (t, t, float, ~toPointSetFn: toPointSetFn) => result +} @genType let toPointSet: ( diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res index 4151254f..7bf9c874 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res @@ -287,10 +287,6 @@ module T = Dist({ ) newShape->E.R2.fmap(x => x->make->integralEndY) } - let logScoreAgainstImproperPrior = (prediction: t, answer: float) => { - let prior = make({xs: prediction.xyShape.xs, ys: E.A.fmap(_ => 1.0, prediction.xyShape.xs)}) - logScore(prior, prediction, answer) - } }) let isNormalized = (t: t): bool => { diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res index 9bc274e7..e10ed981 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res @@ -232,7 +232,4 @@ module T = Dist({ let logScore = (prior: t, prediction: t, answer: float) => { Error(Operation.NotYetImplemented) } - let logScoreAgainstImproperPrior = (prediction: t, answer: float) => { - Error(Operation.NotYetImplemented) - } }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res index 9fb7e689..aa44da10 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res @@ -35,7 +35,6 @@ module type dist = { let variance: t => float let klDivergence: (t, t) => result let logScore: (t, t, float) => result - let logScoreAgainstImproperPrior: (t, float) => result } module Dist = (T: dist) => { @@ -60,7 +59,6 @@ module Dist = (T: dist) => { let integralEndY = T.integralEndY let klDivergence = T.klDivergence let logScore = T.logScore - let logScoreAgainstImproperPrior = T.logScoreAgainstImproperPrior let updateIntegralCache = T.updateIntegralCache diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res index 50e8a419..2fd046ee 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res @@ -309,9 +309,6 @@ module T = Dist({ let logScore = (prior: t, prediction: t, answer: float) => { Error(Operation.NotYetImplemented) } - let logScoreAgainstImproperPrior = (prediction: t, answer: float) => { - Error(Operation.NotYetImplemented) - } }) let combineAlgebraically = (op: Operation.convolutionOperation, t1: t, t2: t): t => { diff --git a/packages/squiggle-lang/src/rescript/MagicNumbers.res b/packages/squiggle-lang/src/rescript/MagicNumbers.res index 13beafe4..b859421c 100644 --- a/packages/squiggle-lang/src/rescript/MagicNumbers.res +++ b/packages/squiggle-lang/src/rescript/MagicNumbers.res @@ -12,6 +12,7 @@ module Epsilon = { module Environment = { let defaultXYPointLength = 1000 let defaultSampleCount = 10000 + let sparklineLength = 20 } module OpCost = { diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res index 8dae2586..7e5709bd 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res @@ -157,6 +157,20 @@ module Helpers = { } } } + let constructNonNormalizedPointSet = ( + ~supportOf: DistributionTypes.genericDist, + fn: float => float, + ): option => { + switch supportOf { + | PointSet(Continuous(dist)) => + {xs: dist.xyShape.xs, ys: E.A.fmap(fn, dist.xyShape.xs)} + ->Continuous.make + ->Continuous + ->PointSet + ->Some + | _ => None + } + } } module SymbolicConstructors = { @@ -219,7 +233,8 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment) | ("mean", [EvDistribution(dist)]) => Helpers.toFloatFn(#Mean, dist) | ("integralSum", [EvDistribution(dist)]) => Helpers.toFloatFn(#IntegralSum, dist) | ("toString", [EvDistribution(dist)]) => Helpers.toStringFn(ToString, dist) - | ("toSparkline", [EvDistribution(dist)]) => Helpers.toStringFn(ToSparkline(20), dist) + | ("toSparkline", [EvDistribution(dist)]) => + Helpers.toStringFn(ToSparkline(MagicNumbers.Environment.sparklineLength), dist) | ("toSparkline", [EvDistribution(dist), EvNumber(n)]) => Helpers.toStringFn(ToSparkline(Belt.Float.toInt(n)), dist) | ("exp", [EvDistribution(a)]) => @@ -233,6 +248,21 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment) | ("normalize", [EvDistribution(dist)]) => Helpers.toDistFn(Normalize, dist) | ("klDivergence", [EvDistribution(a), EvDistribution(b)]) => Some(runGenericOperation(FromDist(ToScore(KLDivergence(b)), a))) + | ("logScore", [EvDistribution(prior), EvDistribution(prediction), EvNumber(answer)]) + | ( + "logScore", + [EvDistribution(prior), EvDistribution(prediction), EvDistribution(Symbolic(#Float(answer)))], + ) => + Some(runGenericOperation(FromDist(ToScore(LogScore(prediction, answer)), prior))) + | ("logScoreAgainstImproperPrior", [EvDistribution(prediction), EvNumber(answer)]) + | ( + "logScoreAgainstImproperPrior", + [EvDistribution(prediction), EvDistribution(Symbolic(#Float(answer)))], + ) => + E.O.fmap( + d => runGenericOperation(FromDist(ToScore(LogScore(prediction, answer)), d)), + Helpers.constructNonNormalizedPointSet(~supportOf=prediction, _ => 1.0), + ) | ("isNormalized", [EvDistribution(dist)]) => Helpers.toBoolFn(IsNormalized, dist) | ("toPointSet", [EvDistribution(dist)]) => Helpers.toDistFn(ToPointSet, dist) | ("scaleLog", [EvDistribution(dist)]) =>