diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res index 616ad39c..4021649f 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res @@ -180,6 +180,20 @@ module Helpers = { ->PointSetTypes.Continuous ->DistributionTypes.PointSet } + + let klDivergenceWithPrior = ( + prediction: DistributionTypes.genericDist, + answer: DistributionTypes.genericDist, + prior: DistributionTypes.genericDist, + env: DistributionOperation.env, + ) => { + let term1 = DistributionOperation.Constructors.klDivergence(~env, prediction, answer) + let term2 = DistributionOperation.Constructors.klDivergence(~env, prior, answer) + switch E.R.merge(term1, term2)->E.R2.fmap(((a, b)) => a -. b) { + | Ok(x) => x->DistributionOperation.Float->Some + | Error(_) => None + } + } } module SymbolicConstructors = { @@ -268,8 +282,10 @@ let dispatchToGenericOutput = ( ~env, )->Some | ("normalize", [EvDistribution(dist)]) => Helpers.toDistFn(Normalize, dist, ~env) - | ("klDivergence", [EvDistribution(a), EvDistribution(b)]) => - Some(DistributionOperation.run(FromDist(ToScore(KLDivergence(b)), a), ~env)) + | ("klDivergence", [EvDistribution(prediction), EvDistribution(answer)]) => + Some(DistributionOperation.run(FromDist(ToScore(KLDivergence(answer)), prediction), ~env)) + | ("klDivergence", [EvDistribution(prediction), EvDistribution(answer), EvDistribution(prior)]) => + Helpers.klDivergenceWithPrior(prediction, answer, prior, env) | ( "logScoreWithPointAnswer", [EvDistribution(prediction), EvNumber(answer), EvDistribution(prior)],