diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res index 1b2fb0e9..9f8c7549 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res @@ -146,7 +146,7 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => { } | #ToDist(Normalize) => dist->GenericDist.normalize->Dist | #ToScore(LogScore(answer, prior)) => - GenericDist.Score.logScore(~estimate=Score_Dist(dist), ~answer, ~prior) + GenericDist.Score.logScore(~estimate=dist, ~answer, ~prior) ->E.R2.fmap(s => Float(s)) ->OutputLocal.fromResult | #ToBool(IsNormalized) => dist->GenericDist.isNormalized->Bool diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi index c39dab7f..d5ca7326 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi @@ -72,7 +72,7 @@ module Constructors: { ~env: env, genericDist, genericDist, - DistributionTypes.DistributionOperation.genericDistOrScalar, + genericDist, ) => result @genType let distEstimateScalarAnswer: (~env: env, genericDist, float) => result @@ -81,7 +81,7 @@ module Constructors: { ~env: env, genericDist, float, - DistributionTypes.DistributionOperation.genericDistOrScalar, + genericDist, ) => result } @genType diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res index b607b8e4..0c119ea4 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res @@ -100,7 +100,7 @@ module DistributionOperation = { type genericDistOrScalar = Score_Dist(genericDist) | Score_Scalar(float) - type toScore = LogScore(genericDistOrScalar, option) + type toScore = LogScore(genericDistOrScalar, option) type fromFloat = [ | #ToFloat(toFloat) diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res index 17adf244..db2a559a 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res @@ -133,13 +133,11 @@ let toPointSet = ( module Score = { type genericDistOrScalar = DistributionTypes.DistributionOperation.genericDistOrScalar - type pointSet_ScoreDistOrScalar = PSDist(PointSetTypes.pointSetDist) | PSScalar(float) - let argsMake = ( - ~esti: genericDistOrScalar, - ~answ: genericDistOrScalar, - ~prior: option, - ): result => { + let argsMake = (~esti: t, ~answ: genericDistOrScalar, ~prior: option): result< + PointSetDist_Scoring.scoreArgs, + error, + > => { let toPointSetFn = t => toPointSet( t, @@ -148,21 +146,20 @@ module Score = { ~xSelection=#ByWeight, (), ) - let prior': option> = switch prior { + let prior': option> = switch prior { | None => None - | Some(Score_Dist(d)) => toPointSetFn(d)->E.R.bind(x => x->PSDist->Ok)->Some - | Some(Score_Scalar(s)) => s->PSScalar->Ok->Some + | Some(d) => toPointSetFn(d)->Some } let twoDists = (~toPointSetFn, esti': t, answ': t): result< (PointSetTypes.pointSetDist, PointSetTypes.pointSetDist), error, > => E.R.merge(toPointSetFn(esti'), toPointSetFn(answ')) switch (esti, answ, prior') { - | (Score_Dist(esti'), Score_Dist(answ'), None) => + | (esti', Score_Dist(answ'), None) => twoDists(~toPointSetFn, esti', answ')->E.R2.fmap(((esti'', answ'')) => {estimate: esti'', answer: answ'', prior: None}->PointSetDist_Scoring.DistAnswer ) - | (Score_Dist(esti'), Score_Dist(answ'), Some(Ok(PSDist(prior'')))) => + | (esti', Score_Dist(answ'), Some(Ok(prior''))) => twoDists(~toPointSetFn, esti', answ')->E.R2.fmap(((esti'', answ'')) => { estimate: esti'', @@ -170,8 +167,7 @@ module Score = { prior: Some(prior''), }->PointSetDist_Scoring.DistAnswer ) - | (Score_Dist(_), _, Some(Ok(PSScalar(_)))) => DistributionTypes.Unreachable->Error - | (Score_Dist(esti'), Score_Scalar(answ'), None) => + | (esti', Score_Scalar(answ'), None) => toPointSetFn(esti')->E.R2.fmap(esti'' => { estimate: esti'', @@ -179,7 +175,7 @@ module Score = { prior: None, }->PointSetDist_Scoring.ScalarAnswer ) - | (Score_Dist(esti'), Score_Scalar(answ'), Some(Ok(PSDist(prior'')))) => + | (esti', Score_Scalar(answ'), Some(Ok(prior''))) => toPointSetFn(esti')->E.R2.fmap(esti'' => { estimate: esti'', @@ -187,19 +183,14 @@ module Score = { prior: Some(prior''), }->PointSetDist_Scoring.ScalarAnswer ) - | (Score_Scalar(_), Score_Dist(_), None) => NotYetImplemented->Error - | (Score_Scalar(_), Score_Dist(_), Some(Ok(PSScalar(_)))) => NotYetImplemented->Error - | (Score_Scalar(_), _, Some(Ok(PSDist(_)))) => DistributionTypes.Unreachable->Error - | (Score_Scalar(_), Score_Scalar(_), _) => NotYetImplemented->Error | (_, _, Some(Error(err))) => err->Error } } - let logScore = ( - ~estimate: genericDistOrScalar, - ~answer: genericDistOrScalar, - ~prior: option, - ): result => + let logScore = (~estimate: t, ~answer: genericDistOrScalar, ~prior: option): result< + float, + error, + > => argsMake(~esti=estimate, ~answ=answer, ~prior)->E.R.bind(x => x->PointSetDist.logScore->E.R2.errMap(y => DistributionTypes.OperationError(y)) ) diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi index 432ae847..d4f8933d 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi @@ -26,9 +26,9 @@ let toFloatOperation: ( module Score: { let logScore: ( - ~estimate: DistributionTypes.DistributionOperation.genericDistOrScalar, + ~estimate: t, ~answer: DistributionTypes.DistributionOperation.genericDistOrScalar, - ~prior: option, + ~prior: option, ) => result } diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res index 5e8eb489..8e5d71a4 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res @@ -19,7 +19,7 @@ module WithDistAnswer = { float, Operation.Error.t, > => - // We decided that negative infinity, not an error at answerElement = 0.0, is a desirable value. + // We decided that 0.0, not an error at answerElement = 0.0, is a desirable value. if answerElement == 0.0 { Ok(0.0) } else if estimateElement == 0.0 { diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res index b73f84e0..e66fba9b 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res @@ -231,10 +231,7 @@ let dispatchToGenericOutput = (call: IEV.functionCall, env: DistributionOperatio DistributionOperation.run( FromDist( #ToScore( - LogScore( - DistributionTypes.DistributionOperation.Score_Dist(answer), - Some(DistributionTypes.DistributionOperation.Score_Dist(prior)), - ), + LogScore(DistributionTypes.DistributionOperation.Score_Dist(answer), Some(prior)), ), prediction, ), @@ -256,10 +253,7 @@ let dispatchToGenericOutput = (call: IEV.functionCall, env: DistributionOperatio DistributionOperation.run( FromDist( #ToScore( - LogScore( - DistributionTypes.DistributionOperation.Score_Scalar(answer), - DistributionTypes.DistributionOperation.Score_Dist(prior)->Some, - ), + LogScore(DistributionTypes.DistributionOperation.Score_Scalar(answer), prior->Some), ), prediction, ),