diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res index 9af9917d..61b5cd6b 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res @@ -150,7 +150,12 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => { ->E.R2.fmap(r => Float(r)) ->OutputLocal.fromResult | ToScore(LogScore(answer, prior)) => - GenericDist.Score.logScoreWithPointResolution(dist, answer, prior, ~toPointSetFn) + GenericDist.Score.logScoreWithPointResolution( + ~prediction=dist, + ~answer, + ~prior, + ~toPointSetFn, + ) ->E.R2.fmap(r => Float(r)) ->OutputLocal.fromResult | ToBool(IsNormalized) => dist->GenericDist.isNormalized->Bool @@ -267,8 +272,12 @@ module Constructors = { let normalize = (~env, dist) => C.normalize(dist)->run(~env)->toDistR let isNormalized = (~env, dist) => C.isNormalized(dist)->run(~env)->toBoolR let klDivergence = (~env, dist1, dist2) => C.klDivergence(dist1, dist2)->run(~env)->toFloatR - let logScoreWithPointResolution = (~env, prediction, answer, prior) => - C.logScoreWithPointResolution(prediction, answer, prior)->run(~env)->toFloatR + let logScoreWithPointResolution = ( + ~env, + ~prediction: DistributionTypes.genericDist, + ~answer: float, + ~prior: option, + ) => C.logScoreWithPointResolution(~prediction, ~answer, ~prior)->run(~env)->toFloatR let toPointSet = (~env, dist) => C.toPointSet(dist)->run(~env)->toDistR let toSampleSet = (~env, dist, n) => C.toSampleSet(dist, n)->run(~env)->toDistR let fromSamples = (~env, xs) => C.fromSamples(xs)->run(~env)->toDistR diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi index 7941489e..aa006c06 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi @@ -65,9 +65,9 @@ module Constructors: { @genType let logScoreWithPointResolution: ( ~env: env, - genericDist, - float, - option, + ~prediction: genericDist, + ~answer: float, + ~prior: option, ) => result @genType let toPointSet: (~env: env, genericDist) => result diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res index f377a616..2bb409ad 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res @@ -162,9 +162,9 @@ module Constructors = { let truncate = (dist, left, right): t => FromDist(ToDist(Truncate(left, right)), dist) let inspect = (dist): t => FromDist(ToDist(Inspect), dist) let klDivergence = (dist1, dist2): t => FromDist(ToScore(KLDivergence(dist2)), dist1) - let logScoreWithPointResolution = (prior, prediction, answer): t => FromDist( - ToScore(LogScore(prediction, answer)), - prior, + let logScoreWithPointResolution = (~prediction, ~answer, ~prior): t => FromDist( + ToScore(LogScore(answer, prior)), + prediction, ) let scalePower = (dist, n): t => FromDist(ToDist(Scale(#Power, n)), dist) let scaleLogarithm = (dist, n): t => FromDist(ToDist(Scale(#Logarithm, n)), dist) diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res index c2b03474..1df10240 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res @@ -68,18 +68,21 @@ module Score = { } let logScoreWithPointResolution = ( - prediction, - answer, - prior, + ~prediction: DistributionTypes.genericDist, + ~answer: float, + ~prior: option, ~toPointSetFn: toPointSetFn, ): result => { switch prior { | Some(prior') => - E.R.merge(toPointSetFn(prior'), toPointSetFn(prediction))->E.R.bind(((a, b)) => + E.R.merge(toPointSetFn(prior'), toPointSetFn(prediction))->E.R.bind((( + prior'', + prediction'', + )) => PointSetDist.T.logScoreWithPointResolution( - b, - answer, - a->Some, + ~prediction=prediction'', + ~answer, + ~prior=prior''->Some, )->E.R2.errMap(x => DistributionTypes.OperationError(x)) ) | None => @@ -87,9 +90,9 @@ module Score = { ->toPointSetFn ->E.R.bind(x => PointSetDist.T.logScoreWithPointResolution( - x, - answer, - None, + ~prediction=x, + ~answer, + ~prior=None, )->E.R2.errMap(x => DistributionTypes.OperationError(x)) ) } diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi index ea9a4110..79fb54ab 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi @@ -26,9 +26,9 @@ let toFloatOperation: ( module Score: { let klDivergence: (t, t, ~toPointSetFn: toPointSetFn) => result let logScoreWithPointResolution: ( - t, - float, - option, + ~prediction: t, + ~answer: float, + ~prior: option, ~toPointSetFn: toPointSetFn, ) => result } diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res index c1d87946..3661a531 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res @@ -279,7 +279,7 @@ module T = Dist({ ) newShape->E.R2.fmap(x => x->make->integralEndY) } - let logScoreWithPointResolution = (prediction: t, answer: float, prior: option) => { + let logScoreWithPointResolution = (~prediction: t, ~answer: float, ~prior: option) => { let priorPdf = prior->E.O2.fmap((shape, x) => XYShape.XtoY.linear(x, shape.xyShape)) let predictionPdf = x => XYShape.XtoY.linear(x, prediction.xyShape) PointSetDist_Scoring.LogScoreWithPointResolution.score(~priorPdf, ~predictionPdf, ~answer) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res index c0e3f3a8..fea5db6f 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res @@ -229,7 +229,7 @@ module T = Dist({ answer, )->E.R2.fmap(integralEndY) } - let logScoreWithPointResolution = (prediction: t, answer: float, prior: option) => { + let logScoreWithPointResolution = (~prediction: t, ~answer: float, ~prior: option) => { Error(Operation.NotYetImplemented) } }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res index f28b6369..2d0358ec 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res @@ -34,7 +34,11 @@ module type dist = { let mean: t => float let variance: t => float let klDivergence: (t, t) => result - let logScoreWithPointResolution: (t, float, option) => result + let logScoreWithPointResolution: ( + ~prediction: t, + ~answer: float, + ~prior: option, + ) => result } module Dist = (T: dist) => { diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res index d3f09798..42a88909 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res @@ -306,7 +306,7 @@ module T = Dist({ let klContinuousPart = Continuous.T.klDivergence(prediction.continuous, answer.continuous) E.R.merge(klDiscretePart, klContinuousPart)->E.R2.fmap(t => fst(t) +. snd(t)) } - let logScoreWithPointResolution = (prediction: t, answer: float, prior: option) => { + let logScoreWithPointResolution = (~prediction: t, ~answer: float, ~prior: option) => { Error(Operation.NotYetImplemented) } }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res index cdeaef5a..d21a7383 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res @@ -203,11 +203,12 @@ module T = Dist({ | (m1, m2) => Mixed.T.klDivergence(m1->toMixed, m2->toMixed) } - let logScoreWithPointResolution = (prediction: t, answer: float, prior: option) => { + let logScoreWithPointResolution = (~prediction: t, ~answer: float, ~prior: option) => { switch (prior, prediction) { | (Some(Continuous(t1)), Continuous(t2)) => - Continuous.T.logScoreWithPointResolution(t2, answer, t1->Some) - | (None, Continuous(t2)) => Continuous.T.logScoreWithPointResolution(t2, answer, None) + Continuous.T.logScoreWithPointResolution(~prediction=t2, ~answer, ~prior=t1->Some) + | (None, Continuous(t2)) => + Continuous.T.logScoreWithPointResolution(~prediction=t2, ~answer, ~prior=None) | _ => Error(Operation.NotYetImplemented) } } diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res index bfc40071..532bc76c 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res @@ -24,7 +24,7 @@ module LogScoreWithPointResolution = { ): result => { let numerator = answer->predictionPdf if numerator < 0.0 { - Operation.ComplexNumberError->Error + Operation.PdfInvalidError->Error } else if numerator == 0.0 { infinity->Ok } else { diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res index 4021649f..0f29f9db 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res @@ -162,24 +162,6 @@ module Helpers = { } } } - let constructNonNormalizedPointSet = ( - ~supportOf: DistributionTypes.genericDist, - fn: float => float, - env: DistributionOperation.env, - ): DistributionTypes.genericDist => { - let cdf = x => toFloatFn(#Cdf(x), supportOf, ~env) - let leftEndpoint = cdf(MagicNumbers.Epsilon.ten) - let rightEndpoint = cdf(1.0 -. MagicNumbers.Epsilon.ten) - let xs = switch (leftEndpoint, rightEndpoint) { - | (Some(Float(a)), Some(Float(b))) => - E.A.Floats.range(a, b, MagicNumbers.Environment.defaultXYPointLength) - | _ => [] - } - {xs: xs, ys: E.A.fmap(fn, xs)} - ->Continuous.make - ->PointSetTypes.Continuous - ->DistributionTypes.PointSet - } let klDivergenceWithPrior = ( prediction: DistributionTypes.genericDist, diff --git a/packages/squiggle-lang/src/rescript/Utility/E.res b/packages/squiggle-lang/src/rescript/Utility/E.res index 35439230..64729324 100644 --- a/packages/squiggle-lang/src/rescript/Utility/E.res +++ b/packages/squiggle-lang/src/rescript/Utility/E.res @@ -620,6 +620,7 @@ module A = { | Some(o) => o | None => [] } + // REturns `None` there are no non-`None` elements let rec arrSomeToSomeArr = (optionals: array>): option> => { let optionals' = optionals->Belt.List.fromArray switch optionals' { @@ -631,17 +632,7 @@ module A = { } } } - let rec firstSome = (optionals: array>): option<'a> => { - let optionals' = optionals->Belt.List.fromArray - switch optionals' { - | list{} => None - | list{x, ...xs} => - switch x { - | Some(_) => x - | None => xs->Belt.List.toArray->firstSome - } - } - } + let firstSome = x => Belt.Array.getBy(x, O.isSome) } module R = { diff --git a/packages/squiggle-lang/src/rescript/Utility/Operation.res b/packages/squiggle-lang/src/rescript/Utility/Operation.res index cfa18925..3f56493b 100644 --- a/packages/squiggle-lang/src/rescript/Utility/Operation.res +++ b/packages/squiggle-lang/src/rescript/Utility/Operation.res @@ -55,7 +55,7 @@ type operationError = | ComplexNumberError | InfinityError | NegativeInfinityError - | LogicallyInconsistentPathwayError + | PdfInvalidError | NotYetImplemented // should be removed when `klDivergence` for mixed and discrete is implemented. @genType @@ -69,7 +69,7 @@ module Error = { | ComplexNumberError => "Operation returned complex result" | InfinityError => "Operation returned positive infinity" | NegativeInfinityError => "Operation returned negative infinity" - | LogicallyInconsistentPathwayError => "This pathway should have been logically unreachable" + | PdfInvalidError => "This Pdf is invalid" | NotYetImplemented => "This pathway is not yet implemented" } }