From 3d71d7bf4e5ad8cb174d604d8371cd2c3a2e856a Mon Sep 17 00:00:00 2001 From: Ozzie Gooen Date: Wed, 13 Jul 2022 08:56:22 -0700 Subject: [PATCH] Replaced some scoring functions with functionRegistry functions --- .../FunctionRegistry_Core.res | 3 + .../FunctionRegistry_Helpers.res | 16 ++++++ .../FunctionRegistry_Library.res | 57 ++++++++++++++++++- .../ReducerInterface_GenericDistribution.res | 48 ---------------- 4 files changed, 75 insertions(+), 49 deletions(-) diff --git a/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Core.res b/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Core.res index bbca7bc7..4eff7945 100644 --- a/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Core.res +++ b/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Core.res @@ -8,6 +8,7 @@ type rec frType = | FRTypeNumber | FRTypeNumeric | FRTypeDistOrNumber + | FRTypeDist | FRTypeLambda | FRTypeRecord(frTypeRecord) | FRTypeDict(frType) @@ -60,6 +61,7 @@ module FRType = { switch t { | FRTypeNumber => "number" | FRTypeNumeric => "numeric" + | FRTypeDist => "distribution" | FRTypeDistOrNumber => "distribution|number" | FRTypeRecord(r) => { let input = ((name, frType): frTypeRecordParam) => `${name}: ${toString(frType)}` @@ -98,6 +100,7 @@ module FRType = { | (FRTypeDistOrNumber, IEvDistribution(Symbolic(#Float(f)))) => Some(FRValueDistOrNumber(FRValueNumber(f))) | (FRTypeDistOrNumber, IEvDistribution(f)) => Some(FRValueDistOrNumber(FRValueDist(f))) + | (FRTypeDist, IEvDistribution(f)) => Some(FRValueDist(f)) | (FRTypeNumeric, IEvNumber(f)) => Some(FRValueNumber(f)) | (FRTypeNumeric, IEvDistribution(Symbolic(#Float(f)))) => Some(FRValueNumber(f)) | (FRTypeLambda, IEvLambda(f)) => Some(FRValueLambda(f)) diff --git a/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Helpers.res b/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Helpers.res index 46ae18f9..a1d530b1 100644 --- a/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Helpers.res +++ b/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Helpers.res @@ -27,6 +27,12 @@ module Prepare = { | _ => Error(impossibleError) } + let threeArgs = (inputs: ts): result => + switch inputs { + | [FRValueRecord([(_, n1), (_, n2), (_, n3)])] => Ok([n1, n2, n3]) + | _ => Error(impossibleError) + } + let toArgs = (inputs: ts): result => switch inputs { | [FRValueRecord(args)] => args->E.A2.fmap(((_, b)) => b)->Ok @@ -57,6 +63,13 @@ module Prepare = { } } + let twoDist = (values: ts): result<(DistributionTypes.genericDist, DistributionTypes.genericDist), err> => { + switch values { + | [FRValueDist(a1), FRValueDist(a2)] => Ok(a1, a2) + | _ => Error(impossibleError) + } + } + let twoNumbers = (values: ts): result<(float, float), err> => { switch values { | [FRValueNumber(a1), FRValueNumber(a2)] => Ok(a1, a2) @@ -81,6 +94,9 @@ module Prepare = { module Record = { let twoDistOrNumber = (values: ts): result<(frValueDistOrNumber, frValueDistOrNumber), err> => values->ToValueArray.Record.twoArgs->E.R.bind(twoDistOrNumber) + + let twoDist = (values: ts): result<(DistributionTypes.genericDist, DistributionTypes.genericDist), err> => + values->ToValueArray.Record.twoArgs->E.R.bind(twoDist) } } diff --git a/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Library.res b/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Library.res index a37b8dc4..9e8a2681 100644 --- a/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Library.res +++ b/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Library.res @@ -49,7 +49,7 @@ let inputsTodist = (inputs: array, makeDist) => { expressionValue } -let registry = [ +let registryStart = [ Function.make( ~name="toContinuousPointSet", ~definitions=[ @@ -510,3 +510,58 @@ to(5,10) (), ), ] + +let runScoring = (estimate, answer, prior) => { + GenericDist.Score.logScore(~estimate, ~answer, ~prior) + ->E.R2.fmap(FunctionRegistry_Helpers.Wrappers.evNumber) + ->E.R2.errMap(DistributionTypes.Error.toString) +} + +let scoreFunctions = [ + Function.make( + ~name="Score", + ~definitions=[ + FnDefinition.make( + ~name="logScore", + ~inputs=[ + FRTypeRecord([ + ("estimate", FRTypeDist), + ("answer", FRTypeDistOrNumber), + ("prior", FRTypeDist), + ]), + ], + ~run=(inputs, _) => { + switch FunctionRegistry_Helpers.Prepare.ToValueArray.Record.threeArgs(inputs) { + | Ok([FRValueDist(estimate), FRValueDistOrNumber(FRValueDist(d)), FRValueDist(prior)]) => + runScoring(estimate, Score_Dist(d), Some(prior)) + | Ok([ + FRValueDist(estimate), + FRValueDistOrNumber(FRValueNumber(d)), + FRValueDist(prior), + ]) => + runScoring(estimate, Score_Scalar(d), Some(prior)) + | Error(e) => Error(e) + | _ => Error(FunctionRegistry_Helpers.impossibleError) + } + }, + ), + FnDefinition.make( + ~name="logScore", + ~inputs=[FRTypeRecord([("estimate", FRTypeDist), ("answer", FRTypeDistOrNumber)])], + ~run=(inputs, _) => { + switch FunctionRegistry_Helpers.Prepare.ToValueArray.Record.twoArgs(inputs) { + | Ok([FRValueDist(estimate), FRValueDistOrNumber(FRValueDist(d))]) => + runScoring(estimate, Score_Dist(d), None) + | Ok([FRValueDist(estimate), FRValueDistOrNumber(FRValueNumber(d))]) => + runScoring(estimate, Score_Scalar(d), None) + | Error(e) => Error(e) + | _ => Error(FunctionRegistry_Helpers.impossibleError) + } + }, + ), + ], + (), + ), +] + +let registry = E.A.append(registryStart, scoreFunctions) diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res index e66fba9b..d17d5061 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res @@ -223,54 +223,6 @@ let dispatchToGenericOutput = (call: IEV.functionCall, env: DistributionOperatio ~env, ), ) - | ( - "klDivergence", - [IEvDistribution(prediction), IEvDistribution(answer), IEvDistribution(prior)], - ) => - Some( - DistributionOperation.run( - FromDist( - #ToScore( - LogScore(DistributionTypes.DistributionOperation.Score_Dist(answer), Some(prior)), - ), - prediction, - ), - ~env, - ), - ) - | ( - "logScoreWithPointAnswer", - [IEvDistribution(prediction), IEvNumber(answer), IEvDistribution(prior)], - ) - | ( - "logScoreWithPointAnswer", - [ - IEvDistribution(prediction), - IEvDistribution(Symbolic(#Float(answer))), - IEvDistribution(prior), - ], - ) => - DistributionOperation.run( - FromDist( - #ToScore( - LogScore(DistributionTypes.DistributionOperation.Score_Scalar(answer), prior->Some), - ), - prediction, - ), - ~env, - )->Some - | ("logScoreWithPointAnswer", [IEvDistribution(prediction), IEvNumber(answer)]) - | ( - "logScoreWithPointAnswer", - [IEvDistribution(prediction), IEvDistribution(Symbolic(#Float(answer)))], - ) => - DistributionOperation.run( - FromDist( - #ToScore(LogScore(DistributionTypes.DistributionOperation.Score_Scalar(answer), None)), - prediction, - ), - ~env, - )->Some | ("isNormalized", [IEvDistribution(dist)]) => Helpers.toBoolFn(IsNormalized, dist, ~env) | ("toPointSet", [IEvDistribution(dist)]) => Helpers.toDistFn(ToPointSet, dist, ~env) | ("scaleLog", [IEvDistribution(dist)]) =>