diff --git a/packages/squiggle-lang/__tests__/Distributions/DistributionOperation_test.res b/packages/squiggle-lang/__tests__/Distributions/DistributionOperation_test.res index 54ce40a8..5af6f8e0 100644 --- a/packages/squiggle-lang/__tests__/Distributions/DistributionOperation_test.res +++ b/packages/squiggle-lang/__tests__/Distributions/DistributionOperation_test.res @@ -1,7 +1,7 @@ open Jest open Expect -let env: DistributionOperation.env = { +let env: GenericDist.env = { sampleCount: 100, xyPointLength: 100, } diff --git a/packages/squiggle-lang/__tests__/TestHelpers.res b/packages/squiggle-lang/__tests__/TestHelpers.res index 71805c70..898bb3a9 100644 --- a/packages/squiggle-lang/__tests__/TestHelpers.res +++ b/packages/squiggle-lang/__tests__/TestHelpers.res @@ -29,7 +29,7 @@ let {toFloat, toDist, toString, toError, fmap} = module(DistributionOperation.Ou let fnImage = (theFn, inps) => Js.Array.map(theFn, inps) -let env: DistributionOperation.env = { +let env: GenericDist.env = { sampleCount: MagicNumbers.Environment.defaultSampleCount, xyPointLength: MagicNumbers.Environment.defaultXYPointLength, } diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res index 1b2fb0e9..6df75749 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res @@ -4,12 +4,9 @@ type error = DistributionTypes.error // TODO: It could be great to use a cache for some calculations (basically, do memoization). Also, better analytics/tracking could go a long way. -type env = { - sampleCount: int, - xyPointLength: int, -} +type env = GenericDist.env -let defaultEnv = { +let defaultEnv:env = { sampleCount: MagicNumbers.Environment.defaultSampleCount, xyPointLength: MagicNumbers.Environment.defaultXYPointLength, } @@ -93,7 +90,7 @@ module OutputLocal = { } } -let rec run = (~env, functionCallInfo: functionCallInfo): outputType => { +let rec run = (~env:env, functionCallInfo: functionCallInfo): outputType => { let {sampleCount, xyPointLength} = env let reCall = (~env=env, ~functionCallInfo=functionCallInfo, ()) => { @@ -146,7 +143,7 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => { } | #ToDist(Normalize) => dist->GenericDist.normalize->Dist | #ToScore(LogScore(answer, prior)) => - GenericDist.Score.logScore(~estimate=Score_Dist(dist), ~answer, ~prior) + GenericDist.Score.logScore(~estimate=dist, ~answer, ~prior, ~env) ->E.R2.fmap(s => Float(s)) ->OutputLocal.fromResult | #ToBool(IsNormalized) => dist->GenericDist.isNormalized->Bool diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi index c39dab7f..68da9534 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi @@ -1,11 +1,5 @@ @genType -type env = { - sampleCount: int, - xyPointLength: int, -} - -@genType -let defaultEnv: env +let defaultEnv: GenericDist.env open DistributionTypes @@ -19,14 +13,17 @@ type outputType = | GenDistError(error) @genType -let run: (~env: env, DistributionTypes.DistributionOperation.genericFunctionCallInfo) => outputType +let run: ( + ~env: GenericDist.env, + DistributionTypes.DistributionOperation.genericFunctionCallInfo, +) => outputType let runFromDist: ( - ~env: env, + ~env: GenericDist.env, ~functionCallInfo: DistributionTypes.DistributionOperation.fromDist, genericDist, ) => outputType let runFromFloat: ( - ~env: env, + ~env: GenericDist.env, ~functionCallInfo: DistributionTypes.DistributionOperation.fromFloat, float, ) => outputType @@ -42,90 +39,147 @@ module Output: { let toBool: t => option let toBoolR: t => result let toError: t => option - let fmap: (~env: env, t, DistributionTypes.DistributionOperation.singleParamaterFunction) => t + let fmap: ( + ~env: GenericDist.env, + t, + DistributionTypes.DistributionOperation.singleParamaterFunction, + ) => t } module Constructors: { @genType - let mean: (~env: env, genericDist) => result + let mean: (~env: GenericDist.env, genericDist) => result @genType - let stdev: (~env: env, genericDist) => result + let stdev: (~env: GenericDist.env, genericDist) => result @genType - let variance: (~env: env, genericDist) => result + let variance: (~env: GenericDist.env, genericDist) => result @genType - let sample: (~env: env, genericDist) => result + let sample: (~env: GenericDist.env, genericDist) => result @genType - let cdf: (~env: env, genericDist, float) => result + let cdf: (~env: GenericDist.env, genericDist, float) => result @genType - let inv: (~env: env, genericDist, float) => result + let inv: (~env: GenericDist.env, genericDist, float) => result @genType - let pdf: (~env: env, genericDist, float) => result + let pdf: (~env: GenericDist.env, genericDist, float) => result @genType - let normalize: (~env: env, genericDist) => result + let normalize: (~env: GenericDist.env, genericDist) => result @genType - let isNormalized: (~env: env, genericDist) => result + let isNormalized: (~env: GenericDist.env, genericDist) => result module LogScore: { @genType - let distEstimateDistAnswer: (~env: env, genericDist, genericDist) => result - @genType - let distEstimateDistAnswerWithPrior: ( - ~env: env, + let distEstimateDistAnswer: ( + ~env: GenericDist.env, genericDist, genericDist, - DistributionTypes.DistributionOperation.genericDistOrScalar, ) => result @genType - let distEstimateScalarAnswer: (~env: env, genericDist, float) => result + let distEstimateDistAnswerWithPrior: ( + ~env: GenericDist.env, + genericDist, + genericDist, + genericDist, + ) => result @genType - let distEstimateScalarAnswerWithPrior: ( - ~env: env, + let distEstimateScalarAnswer: ( + ~env: GenericDist.env, genericDist, float, - DistributionTypes.DistributionOperation.genericDistOrScalar, + ) => result + @genType + let distEstimateScalarAnswerWithPrior: ( + ~env: GenericDist.env, + genericDist, + float, + genericDist, ) => result } @genType - let toPointSet: (~env: env, genericDist) => result + let toPointSet: (~env: GenericDist.env, genericDist) => result @genType - let toSampleSet: (~env: env, genericDist, int) => result + let toSampleSet: (~env: GenericDist.env, genericDist, int) => result @genType - let fromSamples: (~env: env, SampleSetDist.t) => result + let fromSamples: (~env: GenericDist.env, SampleSetDist.t) => result @genType - let truncate: (~env: env, genericDist, option, option) => result + let truncate: ( + ~env: GenericDist.env, + genericDist, + option, + option, + ) => result @genType - let inspect: (~env: env, genericDist) => result + let inspect: (~env: GenericDist.env, genericDist) => result @genType - let toString: (~env: env, genericDist) => result + let toString: (~env: GenericDist.env, genericDist) => result @genType - let toSparkline: (~env: env, genericDist, int) => result + let toSparkline: (~env: GenericDist.env, genericDist, int) => result @genType - let algebraicAdd: (~env: env, genericDist, genericDist) => result + let algebraicAdd: (~env: GenericDist.env, genericDist, genericDist) => result @genType - let algebraicMultiply: (~env: env, genericDist, genericDist) => result + let algebraicMultiply: ( + ~env: GenericDist.env, + genericDist, + genericDist, + ) => result @genType - let algebraicDivide: (~env: env, genericDist, genericDist) => result + let algebraicDivide: ( + ~env: GenericDist.env, + genericDist, + genericDist, + ) => result @genType - let algebraicSubtract: (~env: env, genericDist, genericDist) => result + let algebraicSubtract: ( + ~env: GenericDist.env, + genericDist, + genericDist, + ) => result @genType - let algebraicLogarithm: (~env: env, genericDist, genericDist) => result + let algebraicLogarithm: ( + ~env: GenericDist.env, + genericDist, + genericDist, + ) => result @genType - let algebraicPower: (~env: env, genericDist, genericDist) => result + let algebraicPower: ( + ~env: GenericDist.env, + genericDist, + genericDist, + ) => result @genType - let scaleLogarithm: (~env: env, genericDist, float) => result + let scaleLogarithm: (~env: GenericDist.env, genericDist, float) => result @genType - let scaleMultiply: (~env: env, genericDist, float) => result + let scaleMultiply: (~env: GenericDist.env, genericDist, float) => result @genType - let scalePower: (~env: env, genericDist, float) => result + let scalePower: (~env: GenericDist.env, genericDist, float) => result @genType - let pointwiseAdd: (~env: env, genericDist, genericDist) => result + let pointwiseAdd: (~env: GenericDist.env, genericDist, genericDist) => result @genType - let pointwiseMultiply: (~env: env, genericDist, genericDist) => result + let pointwiseMultiply: ( + ~env: GenericDist.env, + genericDist, + genericDist, + ) => result @genType - let pointwiseDivide: (~env: env, genericDist, genericDist) => result + let pointwiseDivide: ( + ~env: GenericDist.env, + genericDist, + genericDist, + ) => result @genType - let pointwiseSubtract: (~env: env, genericDist, genericDist) => result + let pointwiseSubtract: ( + ~env: GenericDist.env, + genericDist, + genericDist, + ) => result @genType - let pointwiseLogarithm: (~env: env, genericDist, genericDist) => result + let pointwiseLogarithm: ( + ~env: GenericDist.env, + genericDist, + genericDist, + ) => result @genType - let pointwisePower: (~env: env, genericDist, genericDist) => result + let pointwisePower: ( + ~env: GenericDist.env, + genericDist, + genericDist, + ) => result } diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res index b607b8e4..0c119ea4 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res @@ -100,7 +100,7 @@ module DistributionOperation = { type genericDistOrScalar = Score_Dist(genericDist) | Score_Scalar(float) - type toScore = LogScore(genericDistOrScalar, option) + type toScore = LogScore(genericDistOrScalar, option) type fromFloat = [ | #ToFloat(toFloat) diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res index 17adf244..2d59831a 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res @@ -6,6 +6,11 @@ type toSampleSetFn = t => result type scaleMultiplyFn = (t, float) => result type pointwiseAddFn = (t, t) => result +type env = { + sampleCount: int, + xyPointLength: int, +} + let isPointSet = (t: t) => switch t { | PointSet(_) => true @@ -133,36 +138,33 @@ let toPointSet = ( module Score = { type genericDistOrScalar = DistributionTypes.DistributionOperation.genericDistOrScalar - type pointSet_ScoreDistOrScalar = PSDist(PointSetTypes.pointSetDist) | PSScalar(float) - let argsMake = ( - ~esti: genericDistOrScalar, - ~answ: genericDistOrScalar, - ~prior: option, - ): result => { + let argsMake = (~esti: t, ~answ: genericDistOrScalar, ~prior: option, ~env: env): result< + PointSetDist_Scoring.scoreArgs, + error, + > => { let toPointSetFn = t => toPointSet( t, - ~xyPointLength=MagicNumbers.Environment.defaultXYPointLength, - ~sampleCount=MagicNumbers.Environment.defaultSampleCount, + ~xyPointLength=env.xyPointLength, + ~sampleCount=env.sampleCount, ~xSelection=#ByWeight, (), ) - let prior': option> = switch prior { + let prior': option> = switch prior { | None => None - | Some(Score_Dist(d)) => toPointSetFn(d)->E.R.bind(x => x->PSDist->Ok)->Some - | Some(Score_Scalar(s)) => s->PSScalar->Ok->Some + | Some(d) => toPointSetFn(d)->Some } let twoDists = (~toPointSetFn, esti': t, answ': t): result< (PointSetTypes.pointSetDist, PointSetTypes.pointSetDist), error, > => E.R.merge(toPointSetFn(esti'), toPointSetFn(answ')) switch (esti, answ, prior') { - | (Score_Dist(esti'), Score_Dist(answ'), None) => + | (esti', Score_Dist(answ'), None) => twoDists(~toPointSetFn, esti', answ')->E.R2.fmap(((esti'', answ'')) => {estimate: esti'', answer: answ'', prior: None}->PointSetDist_Scoring.DistAnswer ) - | (Score_Dist(esti'), Score_Dist(answ'), Some(Ok(PSDist(prior'')))) => + | (esti', Score_Dist(answ'), Some(Ok(prior''))) => twoDists(~toPointSetFn, esti', answ')->E.R2.fmap(((esti'', answ'')) => { estimate: esti'', @@ -170,8 +172,7 @@ module Score = { prior: Some(prior''), }->PointSetDist_Scoring.DistAnswer ) - | (Score_Dist(_), _, Some(Ok(PSScalar(_)))) => DistributionTypes.Unreachable->Error - | (Score_Dist(esti'), Score_Scalar(answ'), None) => + | (esti', Score_Scalar(answ'), None) => toPointSetFn(esti')->E.R2.fmap(esti'' => { estimate: esti'', @@ -179,7 +180,7 @@ module Score = { prior: None, }->PointSetDist_Scoring.ScalarAnswer ) - | (Score_Dist(esti'), Score_Scalar(answ'), Some(Ok(PSDist(prior'')))) => + | (esti', Score_Scalar(answ'), Some(Ok(prior''))) => toPointSetFn(esti')->E.R2.fmap(esti'' => { estimate: esti'', @@ -187,20 +188,17 @@ module Score = { prior: Some(prior''), }->PointSetDist_Scoring.ScalarAnswer ) - | (Score_Scalar(_), Score_Dist(_), None) => NotYetImplemented->Error - | (Score_Scalar(_), Score_Dist(_), Some(Ok(PSScalar(_)))) => NotYetImplemented->Error - | (Score_Scalar(_), _, Some(Ok(PSDist(_)))) => DistributionTypes.Unreachable->Error - | (Score_Scalar(_), Score_Scalar(_), _) => NotYetImplemented->Error | (_, _, Some(Error(err))) => err->Error } } let logScore = ( - ~estimate: genericDistOrScalar, + ~estimate: t, ~answer: genericDistOrScalar, - ~prior: option, + ~prior: option, + ~env: env, ): result => - argsMake(~esti=estimate, ~answ=answer, ~prior)->E.R.bind(x => + argsMake(~esti=estimate, ~answ=answer, ~prior, ~env)->E.R.bind(x => x->PointSetDist.logScore->E.R2.errMap(y => DistributionTypes.OperationError(y)) ) } diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi index 432ae847..24faabe0 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi @@ -5,6 +5,9 @@ type toSampleSetFn = t => result type scaleMultiplyFn = (t, float) => result type pointwiseAddFn = (t, t) => result +@genType +type env = {sampleCount: int, xyPointLength: int} + let sampleN: (t, int) => array let sample: t => float @@ -26,9 +29,10 @@ let toFloatOperation: ( module Score: { let logScore: ( - ~estimate: DistributionTypes.DistributionOperation.genericDistOrScalar, + ~estimate: t, ~answer: DistributionTypes.DistributionOperation.genericDistOrScalar, - ~prior: option, + ~prior: option, + ~env: env ) => result } diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res index 5e8eb489..8e5d71a4 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res @@ -19,7 +19,7 @@ module WithDistAnswer = { float, Operation.Error.t, > => - // We decided that negative infinity, not an error at answerElement = 0.0, is a desirable value. + // We decided that 0.0, not an error at answerElement = 0.0, is a desirable value. if answerElement == 0.0 { Ok(0.0) } else if estimateElement == 0.0 { diff --git a/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Core.res b/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Core.res index bbca7bc7..c69075d9 100644 --- a/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Core.res +++ b/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Core.res @@ -8,6 +8,7 @@ type rec frType = | FRTypeNumber | FRTypeNumeric | FRTypeDistOrNumber + | FRTypeDist | FRTypeLambda | FRTypeRecord(frTypeRecord) | FRTypeDict(frType) @@ -41,7 +42,7 @@ and frValueDistOrNumber = FRValueNumber(float) | FRValueDist(DistributionTypes.g type fnDefinition = { name: string, inputs: array, - run: (array, DistributionOperation.env) => result, + run: (array, GenericDist.env) => result, } type function = { @@ -60,6 +61,7 @@ module FRType = { switch t { | FRTypeNumber => "number" | FRTypeNumeric => "numeric" + | FRTypeDist => "distribution" | FRTypeDistOrNumber => "distribution|number" | FRTypeRecord(r) => { let input = ((name, frType): frTypeRecordParam) => `${name}: ${toString(frType)}` @@ -98,6 +100,7 @@ module FRType = { | (FRTypeDistOrNumber, IEvDistribution(Symbolic(#Float(f)))) => Some(FRValueDistOrNumber(FRValueNumber(f))) | (FRTypeDistOrNumber, IEvDistribution(f)) => Some(FRValueDistOrNumber(FRValueDist(f))) + | (FRTypeDist, IEvDistribution(f)) => Some(FRValueDist(f)) | (FRTypeNumeric, IEvNumber(f)) => Some(FRValueNumber(f)) | (FRTypeNumeric, IEvDistribution(Symbolic(#Float(f)))) => Some(FRValueNumber(f)) | (FRTypeLambda, IEvLambda(f)) => Some(FRValueLambda(f)) @@ -319,7 +322,7 @@ module FnDefinition = { t.name ++ `(${inputs})` } - let run = (t: t, args: array, env: DistributionOperation.env) => { + let run = (t: t, args: array, env: GenericDist.env) => { let argValues = FRType.matchWithExpressionValueArray(t.inputs, args) switch argValues { | Some(values) => t.run(values, env) @@ -374,7 +377,7 @@ module Registry = { ~registry: registry, ~fnName: string, ~args: array, - ~env: DistributionOperation.env, + ~env: GenericDist.env, ) => { let matchToDef = m => Matcher.Registry.matchToDef(registry, m) //Js.log(toSimple(registry)) diff --git a/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Helpers.res b/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Helpers.res index 46ae18f9..357725e1 100644 --- a/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Helpers.res +++ b/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Helpers.res @@ -27,6 +27,12 @@ module Prepare = { | _ => Error(impossibleError) } + let threeArgs = (inputs: ts): result => + switch inputs { + | [FRValueRecord([(_, n1), (_, n2), (_, n3)])] => Ok([n1, n2, n3]) + | _ => Error(impossibleError) + } + let toArgs = (inputs: ts): result => switch inputs { | [FRValueRecord(args)] => args->E.A2.fmap(((_, b)) => b)->Ok @@ -57,6 +63,13 @@ module Prepare = { } } + let twoDist = (values: ts): result<(DistributionTypes.genericDist, DistributionTypes.genericDist), err> => { + switch values { + | [FRValueDist(a1), FRValueDist(a2)] => Ok(a1, a2) + | _ => Error(impossibleError) + } + } + let twoNumbers = (values: ts): result<(float, float), err> => { switch values { | [FRValueNumber(a1), FRValueNumber(a2)] => Ok(a1, a2) @@ -81,6 +94,9 @@ module Prepare = { module Record = { let twoDistOrNumber = (values: ts): result<(frValueDistOrNumber, frValueDistOrNumber), err> => values->ToValueArray.Record.twoArgs->E.R.bind(twoDistOrNumber) + + let twoDist = (values: ts): result<(DistributionTypes.genericDist, DistributionTypes.genericDist), err> => + values->ToValueArray.Record.twoArgs->E.R.bind(twoDist) } } @@ -128,7 +144,7 @@ module Prepare = { module Process = { module DistOrNumberToDist = { module Helpers = { - let toSampleSet = (r, env: DistributionOperation.env) => + let toSampleSet = (r, env: GenericDist.env) => GenericDist.toSampleSetDist(r, env.sampleCount) let mapFnResult = r => @@ -166,7 +182,7 @@ module Process = { let oneValue = ( ~fn: float => result, ~value: frValueDistOrNumber, - ~env: DistributionOperation.env, + ~env: GenericDist.env, ): result => { switch value { | FRValueNumber(a1) => fn(a1) @@ -179,7 +195,7 @@ module Process = { let twoValues = ( ~fn: ((float, float)) => result, ~values: (frValueDistOrNumber, frValueDistOrNumber), - ~env: DistributionOperation.env, + ~env: GenericDist.env, ): result => { switch values { | (FRValueNumber(a1), FRValueNumber(a2)) => fn((a1, a2)) diff --git a/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Library.res b/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Library.res index a37b8dc4..add5cfa3 100644 --- a/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Library.res +++ b/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Library.res @@ -49,7 +49,7 @@ let inputsTodist = (inputs: array, makeDist) => { expressionValue } -let registry = [ +let registryStart = [ Function.make( ~name="toContinuousPointSet", ~definitions=[ @@ -510,3 +510,67 @@ to(5,10) (), ), ] + +let runScoring = (estimate, answer, prior, env) => { + GenericDist.Score.logScore(~estimate, ~answer, ~prior, ~env) + ->E.R2.fmap(FunctionRegistry_Helpers.Wrappers.evNumber) + ->E.R2.errMap(DistributionTypes.Error.toString) +} + +let scoreFunctions = [ + Function.make( + ~name="Score", + ~definitions=[ + FnDefinition.make( + ~name="logScore", + ~inputs=[ + FRTypeRecord([ + ("estimate", FRTypeDist), + ("answer", FRTypeDistOrNumber), + ("prior", FRTypeDist), + ]), + ], + ~run=(inputs, env) => { + switch FunctionRegistry_Helpers.Prepare.ToValueArray.Record.threeArgs(inputs) { + | Ok([FRValueDist(estimate), FRValueDistOrNumber(FRValueDist(d)), FRValueDist(prior)]) => + runScoring(estimate, Score_Dist(d), Some(prior), env) + | Ok([ + FRValueDist(estimate), + FRValueDistOrNumber(FRValueNumber(d)), + FRValueDist(prior), + ]) => + runScoring(estimate, Score_Scalar(d), Some(prior), env) + | Error(e) => Error(e) + | _ => Error(FunctionRegistry_Helpers.impossibleError) + } + }, + ), + FnDefinition.make( + ~name="logScore", + ~inputs=[FRTypeRecord([("estimate", FRTypeDist), ("answer", FRTypeDistOrNumber)])], + ~run=(inputs, env) => { + switch FunctionRegistry_Helpers.Prepare.ToValueArray.Record.twoArgs(inputs) { + | Ok([FRValueDist(estimate), FRValueDistOrNumber(FRValueDist(d))]) => + runScoring(estimate, Score_Dist(d), None, env) + | Ok([FRValueDist(estimate), FRValueDistOrNumber(FRValueNumber(d))]) => + runScoring(estimate, Score_Scalar(d), None, env) + | Error(e) => Error(e) + | _ => Error(FunctionRegistry_Helpers.impossibleError) + } + }, + ), + FnDefinition.make(~name="klDivergence", ~inputs=[FRTypeDist, FRTypeDist], ~run=( + inputs, + env, + ) => { + switch inputs { + | [FRValueDist(estimate), FRValueDist(d)] => runScoring(estimate, Score_Dist(d), None, env) + | _ => Error(FunctionRegistry_Helpers.impossibleError) + } + }), + ], + (), + ), +] + +let registry = E.A.append(registryStart, scoreFunctions) diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_Date.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_Date.res index e84c26f4..3396d0f3 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_Date.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_Date.res @@ -1,7 +1,7 @@ module IEV = ReducerInterface_InternalExpressionValue type internalExpressionValue = IEV.t -let dispatch = (call: IEV.functionCall, _: DistributionOperation.env): option< +let dispatch = (call: IEV.functionCall, _: GenericDist.env): option< result, > => { switch call { diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_Duration.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_Duration.res index 838e4375..f9e06de4 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_Duration.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_Duration.res @@ -1,7 +1,7 @@ module IEV = ReducerInterface_InternalExpressionValue type internalExpressionValue = IEV.t -let dispatch = (call: IEV.functionCall, _: DistributionOperation.env): option< +let dispatch = (call: IEV.functionCall, _: GenericDist.env): option< result, > => { switch call { diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_ExternalExpressionValue.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_ExternalExpressionValue.res index 9bd356d4..b21ba3c6 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_ExternalExpressionValue.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_ExternalExpressionValue.res @@ -86,7 +86,7 @@ let toStringResult = x => } @genType -type environment = DistributionOperation.env +type environment = GenericDist.env @genType let defaultEnvironment: environment = DistributionOperation.defaultEnv diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res index b73f84e0..a65edf2f 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res @@ -32,7 +32,7 @@ module Helpers = { let toFloatFn = ( fnCall: DistributionTypes.DistributionOperation.toFloat, dist: DistributionTypes.genericDist, - ~env: DistributionOperation.env, + ~env: GenericDist.env, ) => { FromDist(#ToFloat(fnCall), dist)->DistributionOperation.run(~env)->Some } @@ -40,7 +40,7 @@ module Helpers = { let toStringFn = ( fnCall: DistributionTypes.DistributionOperation.toString, dist: DistributionTypes.genericDist, - ~env: DistributionOperation.env, + ~env: GenericDist.env, ) => { FromDist(#ToString(fnCall), dist)->DistributionOperation.run(~env)->Some } @@ -48,7 +48,7 @@ module Helpers = { let toBoolFn = ( fnCall: DistributionTypes.DistributionOperation.toBool, dist: DistributionTypes.genericDist, - ~env: DistributionOperation.env, + ~env: GenericDist.env, ) => { FromDist(#ToBool(fnCall), dist)->DistributionOperation.run(~env)->Some } @@ -56,12 +56,12 @@ module Helpers = { let toDistFn = ( fnCall: DistributionTypes.DistributionOperation.toDist, dist, - ~env: DistributionOperation.env, + ~env: GenericDist.env, ) => { FromDist(#ToDist(fnCall), dist)->DistributionOperation.run(~env)->Some } - let twoDiststoDistFn = (direction, arithmetic, dist1, dist2, ~env: DistributionOperation.env) => { + let twoDiststoDistFn = (direction, arithmetic, dist1, dist2, ~env: GenericDist.env) => { FromDist( #ToDistCombination(direction, arithmeticMap(arithmetic), #Dist(dist2)), dist1, @@ -97,7 +97,7 @@ module Helpers = { let mixtureWithGivenWeights = ( distributions: array, weights: array, - ~env: DistributionOperation.env, + ~env: GenericDist.env, ): DistributionOperation.outputType => E.A.length(distributions) == E.A.length(weights) ? Mixture(Belt.Array.zip(distributions, weights))->DistributionOperation.run(~env) @@ -107,7 +107,7 @@ module Helpers = { let mixtureWithDefaultWeights = ( distributions: array, - ~env: DistributionOperation.env, + ~env: GenericDist.env, ): DistributionOperation.outputType => { let length = E.A.length(distributions) let weights = Belt.Array.make(length, 1.0 /. Belt.Int.toFloat(length)) @@ -116,7 +116,7 @@ module Helpers = { let mixture = ( args: array, - ~env: DistributionOperation.env, + ~env: GenericDist.env, ): DistributionOperation.outputType => { let error = (err: string): DistributionOperation.outputType => err->DistributionTypes.ArgumentError->GenDistError @@ -173,7 +173,7 @@ module SymbolicConstructors = { } } -let dispatchToGenericOutput = (call: IEV.functionCall, env: DistributionOperation.env): option< +let dispatchToGenericOutput = (call: IEV.functionCall, env: GenericDist.env): option< DistributionOperation.outputType, > => { let (fnName, args) = call @@ -213,70 +213,6 @@ let dispatchToGenericOutput = (call: IEV.functionCall, env: DistributionOperatio ~env, )->Some | ("normalize", [IEvDistribution(dist)]) => Helpers.toDistFn(Normalize, dist, ~env) - | ("klDivergence", [IEvDistribution(prediction), IEvDistribution(answer)]) => - Some( - DistributionOperation.run( - FromDist( - #ToScore(LogScore(DistributionTypes.DistributionOperation.Score_Dist(answer), None)), - prediction, - ), - ~env, - ), - ) - | ( - "klDivergence", - [IEvDistribution(prediction), IEvDistribution(answer), IEvDistribution(prior)], - ) => - Some( - DistributionOperation.run( - FromDist( - #ToScore( - LogScore( - DistributionTypes.DistributionOperation.Score_Dist(answer), - Some(DistributionTypes.DistributionOperation.Score_Dist(prior)), - ), - ), - prediction, - ), - ~env, - ), - ) - | ( - "logScoreWithPointAnswer", - [IEvDistribution(prediction), IEvNumber(answer), IEvDistribution(prior)], - ) - | ( - "logScoreWithPointAnswer", - [ - IEvDistribution(prediction), - IEvDistribution(Symbolic(#Float(answer))), - IEvDistribution(prior), - ], - ) => - DistributionOperation.run( - FromDist( - #ToScore( - LogScore( - DistributionTypes.DistributionOperation.Score_Scalar(answer), - DistributionTypes.DistributionOperation.Score_Dist(prior)->Some, - ), - ), - prediction, - ), - ~env, - )->Some - | ("logScoreWithPointAnswer", [IEvDistribution(prediction), IEvNumber(answer)]) - | ( - "logScoreWithPointAnswer", - [IEvDistribution(prediction), IEvDistribution(Symbolic(#Float(answer)))], - ) => - DistributionOperation.run( - FromDist( - #ToScore(LogScore(DistributionTypes.DistributionOperation.Score_Scalar(answer), None)), - prediction, - ), - ~env, - )->Some | ("isNormalized", [IEvDistribution(dist)]) => Helpers.toBoolFn(IsNormalized, dist, ~env) | ("toPointSet", [IEvDistribution(dist)]) => Helpers.toDistFn(ToPointSet, dist, ~env) | ("scaleLog", [IEvDistribution(dist)]) => diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_Number.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_Number.res index 900e6a2f..f22df39b 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_Number.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_Number.res @@ -24,7 +24,7 @@ module ScientificUnit = { } } -let dispatch = (call: IEV.functionCall, _: DistributionOperation.env): option< +let dispatch = (call: IEV.functionCall, _: GenericDist.env): option< result, > => { switch call { diff --git a/packages/squiggle-lang/src/rescript/TypescriptInterface.res b/packages/squiggle-lang/src/rescript/TypescriptInterface.res index 3d2ce160..a1f5afe6 100644 --- a/packages/squiggle-lang/src/rescript/TypescriptInterface.res +++ b/packages/squiggle-lang/src/rescript/TypescriptInterface.res @@ -8,7 +8,7 @@ The below few seem to work fine. In the future there's definitely more work to d */ @genType -type samplingParams = DistributionOperation.env +type samplingParams = GenericDist.env @genType type genericDist = DistributionTypes.genericDist