diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res index fb26189f..b2bf8a32 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res @@ -1,5 +1,5 @@ module ExpressionValue = ReducerInterface_ExpressionValue -type expressionValue = ReducerInterface_ExpressionValue.expressionValue +type expressionValue = ExpressionValue.expressionValue let defaultEnv: DistributionOperation.env = { sampleCount: MagicNumbers.Environment.defaultSampleCount, @@ -210,7 +210,7 @@ module SymbolicConstructors = { } } -let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment): option< +let rec dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment): option< DistributionOperation.outputType, > => { let (fnName, args) = call @@ -257,6 +257,10 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment) [EvDistribution(prior), EvDistribution(prediction), EvDistribution(Symbolic(#Float(answer)))], ) => runGenericOperation(FromDist(ToScore(LogScore(prediction, answer)), prior))->Some + | ("logScore", [EvRecord(r)]) => + recurRecordArgs("logScore", ["prior", "prediction", "answer"], r, _environment) + | ("increment", [EvNumber(x)]) => (x +. 1.0)->DistributionOperation.Float->Some + | ("increment", [EvRecord(r)]) => recurRecordArgs("increment", ["incrementee"], r, _environment) | ("logScoreAgainstImproperPrior", [EvDistribution(prediction), EvNumber(answer)]) | ( "logScoreAgainstImproperPrior", @@ -340,6 +344,16 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment) | _ => None } } +and recurRecordArgs = ( + fnName: string, + argNames: array, + args: ExpressionValue.record, + _environment: 'a, +): option => + // argNames -> E.A2.fmap(x => Js.Dict.get(args, x)) -> E.A.O.arrSomeToSomeArr -> E.O.bind(a => dispatchToGenericOutput((fnName, a), _environment)) + argNames + ->E.A2.fmap(x => Js.Dict.unsafeGet(args, x)) + ->(a => dispatchToGenericOutput((fnName, a), _environment)) let genericOutputToReducerValue = (o: DistributionOperation.outputType): result< expressionValue, diff --git a/packages/squiggle-lang/src/rescript/Utility/E.res b/packages/squiggle-lang/src/rescript/Utility/E.res index 15678e1a..63999ce6 100644 --- a/packages/squiggle-lang/src/rescript/Utility/E.res +++ b/packages/squiggle-lang/src/rescript/Utility/E.res @@ -620,6 +620,17 @@ module A = { | Some(o) => o | None => [] } + let rec arrSomeToSomeArr = (optionals: array>): option> => { + let optionals' = optionals->Belt.List.fromArray + switch optionals' { + | list{} => []->Some + | list{x, ...xs} => + switch x { + | Some(_) => xs->Belt.List.toArray->arrSomeToSomeArr + | None => None + } + } + } } module R = {