From bdbb86aa9ed4423cd988efbaf29cb807060c634c Mon Sep 17 00:00:00 2001 From: Quinn Dougherty Date: Fri, 13 May 2022 16:15:04 -0400 Subject: [PATCH] `logScore` on records now interprets almost every which way we're interested in Value: [1e-3 to 9e-1] --- .../Distributions/DistributionOperation.res | 8 ++-- .../Distributions/DistributionOperation.resi | 7 +++- .../Distributions/DistributionTypes.res | 4 +- .../rescript/Distributions/GenericDist.res | 6 +-- .../rescript/Distributions/GenericDist.resi | 2 +- .../Distributions/PointSetDist/Continuous.res | 2 +- .../Distributions/PointSetDist/Discrete.res | 2 +- .../PointSetDist/Distributions.res | 2 +- .../Distributions/PointSetDist/Mixed.res | 2 +- .../PointSetDist/PointSetDist.res | 6 +-- .../ReducerInterface_GenericDistribution.res | 37 +++++++++++-------- .../squiggle-lang/src/rescript/Utility/E.res | 11 ++++++ 12 files changed, 56 insertions(+), 33 deletions(-) diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res index 4c4291eb..1e2ff872 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res @@ -148,8 +148,8 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => { GenericDist.Score.klDivergence(dist, t2, ~toPointSetFn) ->E.R2.fmap(r => Float(r)) ->OutputLocal.fromResult - | ToScore(LogScore(prediction, answer)) => - GenericDist.Score.logScoreWithPointResolution(Some(dist), prediction, answer, ~toPointSetFn) + | ToScore(LogScore(answer, prior)) => + GenericDist.Score.logScoreWithPointResolution(dist, answer, prior, ~toPointSetFn) ->E.R2.fmap(r => Float(r)) ->OutputLocal.fromResult | ToBool(IsNormalized) => dist->GenericDist.isNormalized->Bool @@ -266,8 +266,8 @@ module Constructors = { let normalize = (~env, dist) => C.normalize(dist)->run(~env)->toDistR let isNormalized 
= (~env, dist) => C.isNormalized(dist)->run(~env)->toBoolR let klDivergence = (~env, dist1, dist2) => C.klDivergence(dist1, dist2)->run(~env)->toFloatR - let logScore = (~env, prior, prediction, answer) => - C.logScoreWithPointResolution(prior, prediction, answer)->run(~env)->toFloatR + let logScoreWithPointResolution = (~env, prediction, answer, prior) => + C.logScoreWithPointResolution(prediction, answer, prior)->run(~env)->toFloatR let toPointSet = (~env, dist) => C.toPointSet(dist)->run(~env)->toDistR let toSampleSet = (~env, dist, n) => C.toSampleSet(dist, n)->run(~env)->toDistR let fromSamples = (~env, xs) => C.fromSamples(xs)->run(~env)->toDistR diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi index c3d14014..fffd011b 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi @@ -62,7 +62,12 @@ module Constructors: { @genType let klDivergence: (~env: env, genericDist, genericDist) => result @genType - let logScore: (~env: env, genericDist, genericDist, float) => result + let logScoreWithPointResolution: ( + ~env: env, + genericDist, + float, + option, + ) => result @genType let toPointSet: (~env: env, genericDist) => result @genType diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res index 480775bf..f377a616 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res @@ -91,7 +91,7 @@ module DistributionOperation = { | ToString | ToSparkline(int) - type toScore = KLDivergence(genericDist) | LogScore(genericDist, float) + type toScore = KLDivergence(genericDist) | LogScore(float, option) type fromDist = | ToFloat(toFloat) 
@@ -120,7 +120,7 @@ module DistributionOperation = { | ToFloat(#Sample) => `sample` | ToFloat(#IntegralSum) => `integralSum` | ToScore(KLDivergence(_)) => `klDivergence` - | ToScore(LogScore(_, x)) => `logScore against ${E.Float.toFixed(x)}` + | ToScore(LogScore(x, _)) => `logScore against ${E.Float.toFixed(x)}` | ToDist(Normalize) => `normalize` | ToDist(ToPointSet) => `toPointSet` | ToDist(ToSampleSet(r)) => `toSampleSet(${E.I.toString(r)})` diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res index 3357556f..c2b03474 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res @@ -68,18 +68,18 @@ module Score = { } let logScoreWithPointResolution = ( - prior, prediction, answer, + prior, ~toPointSetFn: toPointSetFn, ): result => { switch prior { | Some(prior') => E.R.merge(toPointSetFn(prior'), toPointSetFn(prediction))->E.R.bind(((a, b)) => PointSetDist.T.logScoreWithPointResolution( - a->Some, b, answer, + a->Some, )->E.R2.errMap(x => DistributionTypes.OperationError(x)) ) | None => @@ -87,9 +87,9 @@ module Score = { ->toPointSetFn ->E.R.bind(x => PointSetDist.T.logScoreWithPointResolution( - None, x, answer, + None, )->E.R2.errMap(x => DistributionTypes.OperationError(x)) ) } diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi index 712e38b9..ea9a4110 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi @@ -26,9 +26,9 @@ let toFloatOperation: ( module Score: { let klDivergence: (t, t, ~toPointSetFn: toPointSetFn) => result let logScoreWithPointResolution: ( - option, t, float, + option, ~toPointSetFn: toPointSetFn, ) => result } diff --git 
a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res index 22aaaf3b..c1d87946 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res @@ -279,7 +279,7 @@ module T = Dist({ ) newShape->E.R2.fmap(x => x->make->integralEndY) } - let logScoreWithPointResolution = (prior: option, prediction: t, answer: float) => { + let logScoreWithPointResolution = (prediction: t, answer: float, prior: option) => { let priorPdf = prior->E.O2.fmap((shape, x) => XYShape.XtoY.linear(x, shape.xyShape)) let predictionPdf = x => XYShape.XtoY.linear(x, prediction.xyShape) PointSetDist_Scoring.LogScoreWithPointResolution.score(~priorPdf, ~predictionPdf, ~answer) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res index 26e15f6e..c0e3f3a8 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res @@ -229,7 +229,7 @@ module T = Dist({ answer, )->E.R2.fmap(integralEndY) } - let logScoreWithPointResolution = (prior: option, prediction: t, answer: float) => { + let logScoreWithPointResolution = (prediction: t, answer: float, prior: option) => { Error(Operation.NotYetImplemented) } }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res index 014c0668..f28b6369 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res @@ -34,7 +34,7 @@ module type dist = { let mean: t => float let variance: t => float let 
klDivergence: (t, t) => result - let logScoreWithPointResolution: (option, t, float) => result + let logScoreWithPointResolution: (t, float, option) => result } module Dist = (T: dist) => { diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res index 4f864856..d3f09798 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res @@ -306,7 +306,7 @@ module T = Dist({ let klContinuousPart = Continuous.T.klDivergence(prediction.continuous, answer.continuous) E.R.merge(klDiscretePart, klContinuousPart)->E.R2.fmap(t => fst(t) +. snd(t)) } - let logScoreWithPointResolution = (prior: option, prediction: t, answer: float) => { + let logScoreWithPointResolution = (prediction: t, answer: float, prior: option) => { Error(Operation.NotYetImplemented) } }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res index 7f87fd01..cdeaef5a 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res @@ -203,11 +203,11 @@ module T = Dist({ | (m1, m2) => Mixed.T.klDivergence(m1->toMixed, m2->toMixed) } - let logScoreWithPointResolution = (prior: option, prediction: t, answer: float) => { + let logScoreWithPointResolution = (prediction: t, answer: float, prior: option) => { switch (prior, prediction) { | (Some(Continuous(t1)), Continuous(t2)) => - Continuous.T.logScoreWithPointResolution(t1->Some, t2, answer) - | (None, Continuous(t2)) => Continuous.T.logScoreWithPointResolution(None, t2, answer) + Continuous.T.logScoreWithPointResolution(t2, answer, t1->Some) + | (None, Continuous(t2)) => Continuous.T.logScoreWithPointResolution(t2, answer, 
None) | _ => Error(Operation.NotYetImplemented) } } diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res index b2bf8a32..bee0e3fe 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res @@ -251,27 +251,34 @@ let rec dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environm | ("normalize", [EvDistribution(dist)]) => Helpers.toDistFn(Normalize, dist) | ("klDivergence", [EvDistribution(a), EvDistribution(b)]) => Some(runGenericOperation(FromDist(ToScore(KLDivergence(b)), a))) - | ("logScore", [EvDistribution(prior), EvDistribution(prediction), EvNumber(answer)]) | ( - "logScore", - [EvDistribution(prior), EvDistribution(prediction), EvDistribution(Symbolic(#Float(answer)))], + "logScoreWithPointResolution", + [EvDistribution(prediction), EvNumber(answer), EvDistribution(prior)], + ) + | ( + "logScoreWithPointResolution", + [EvDistribution(prediction), EvDistribution(Symbolic(#Float(answer))), EvDistribution(prior)], ) => - runGenericOperation(FromDist(ToScore(LogScore(prediction, answer)), prior))->Some - | ("logScore", [EvRecord(r)]) => - recurRecordArgs("logScore", ["prior", "prediction", "answer"], r, _environment) - | ("increment", [EvNumber(x)]) => (x +. 
1.0)->DistributionOperation.Float->Some - | ("increment", [EvRecord(r)]) => recurRecordArgs("increment", ["incrementee"], r, _environment) - | ("logScoreAgainstImproperPrior", [EvDistribution(prediction), EvNumber(answer)]) + runGenericOperation(FromDist(ToScore(LogScore(answer, prior->Some)), prediction))->Some + | ("logScoreWithPointResolution", [EvDistribution(prediction), EvNumber(answer)]) | ( - "logScoreAgainstImproperPrior", + "logScoreWithPointResolution", [EvDistribution(prediction), EvDistribution(Symbolic(#Float(answer)))], ) => - runGenericOperation( - FromDist( - ToScore(LogScore(prediction, answer)), - Helpers.constructNonNormalizedPointSet(~supportOf=prediction, _ => 1.0), + runGenericOperation(FromDist(ToScore(LogScore(answer, None)), prediction))->Some + | ("logScore", [EvRecord(r)]) => + [ + recurRecordArgs( + "logScoreWithPointResolution", + ["estimate", "answer", "prior"], + r, + _environment, ), - )->Some + recurRecordArgs("klDivergence", ["estimate", "answer"], r, _environment), + recurRecordArgs("logScoreWithPointResolution", ["estimate", "answer"], r, _environment), + ]->E.A.O.firstSome + | ("increment", [EvNumber(x)]) => (x +. 
1.0)->DistributionOperation.Float->Some // this tests recurRecordArgs function
+  | ("increment", [EvRecord(r)]) => recurRecordArgs("increment", ["incrementee"], r, _environment)
   | ("isNormalized", [EvDistribution(dist)]) => Helpers.toBoolFn(IsNormalized, dist)
   | ("toPointSet", [EvDistribution(dist)]) => Helpers.toDistFn(ToPointSet, dist)
   | ("scaleLog", [EvDistribution(dist)]) =>
diff --git a/packages/squiggle-lang/src/rescript/Utility/E.res b/packages/squiggle-lang/src/rescript/Utility/E.res
index 63999ce6..35439230 100644
--- a/packages/squiggle-lang/src/rescript/Utility/E.res
+++ b/packages/squiggle-lang/src/rescript/Utility/E.res
@@ -631,6 +631,17 @@ module A = {
         }
       }
     }
+    // Returns the first `Some(_)` element of `optionals`, scanning from
+    // index 0 upward. The result is `None` when the array is empty or
+    // when every element is `None`.
+    // Implemented as a single `Belt.Array.getBy` pass, so no intermediate
+    // list or array copies are made while searching.
+    let firstSome = (optionals: array<option<'a>>): option<'a> => {
+      switch optionals->Belt.Array.getBy(Belt.Option.isSome) {
+      | Some(firstHit) => firstHit
+      | None => None
+      }
+    }
   }
 
   module R = {