diff --git a/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res b/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res index bc95f160..4e2f9aaa 100644 --- a/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res +++ b/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res @@ -23,7 +23,7 @@ describe("eval on distribution functions", () => { testEval("-normal(5,2)", "Ok(Normal(-5,2))") }) describe("to", () => { - testEval("5 to 2", "Error(Distribution Math Error: Low value must be less than high value.)") + testEval("5 to 2", "Error(TODO: Low value must be less than high value.)") testEval("to(2,5)", "Ok(Lognormal(1.1512925464970227,0.27853260523016377))") testEval("to(-2,2)", "Ok(Normal(0,1.2159136638235384))") }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res index 1df10240..d6920a4e 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res @@ -31,6 +31,8 @@ let sampleN = (t: t, n) => | SampleSet(r) => SampleSetDist.sampleN(r, n) } +let sample = (t: t) => sampleN(t, 1)->E.A.first |> E.O.toExn("Should not have happened") + let toSampleSetDist = (t: t, n) => SampleSetDist.make(sampleN(t, n))->E.R2.errMap(DistributionTypes.Error.sampleErrorToDistErr) diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi index 79fb54ab..fd9afa58 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi @@ -6,6 +6,7 @@ type scaleMultiplyFn = (t, float) => result type pointwiseAddFn = (t, t) => result let sampleN: (t, int) => array +let sample: t => float let toSampleSetDist: (t, int) => Belt.Result.t diff --git a/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Core.res b/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Core.res new file mode 100644 index 00000000..99ecc78f --- /dev/null +++ b/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Core.res @@ -0,0 +1,301 @@ +type expressionValue = ReducerInterface_ExpressionValue.expressionValue + +/* + Function Registry "Type". A type, without any other information. + Like, #Float +*/ +type rec frType = + | FRTypeNumber + | FRTypeNumeric + | FRTypeDistOrNumber + | FRTypeRecord(frTypeRecord) + | FRTypeArray(array) + | FRTypeOption(frType) +and frTypeRecord = array +and frTypeRecordParam = (string, frType) + +/* + Function Registry "Value". A type, with the information of that type. + Like, #Float(40.0) +*/ +type rec frValue = + | FRValueNumber(float) + | FRValueDist(DistributionTypes.genericDist) + | FRValueOption(option) + | FRValueDistOrNumber(frValueDistOrNumber) + | FRValueRecord(frValueRecord) +and frValueRecord = array +and frValueRecordParam = (string, frValue) +and frValueDistOrNumber = FRValueNumber(float) | FRValueDist(DistributionTypes.genericDist) + +type fnDefinition = { + name: string, + inputs: array, + run: (array, DistributionOperation.env) => result, +} + +type function = { + name: string, + definitions: array, +} + +type registry = array + +module FRType = { + type t = frType + let rec toString = (t: t) => + switch t { + | FRTypeNumber => "number" + | FRTypeNumeric => "numeric" + | FRTypeDistOrNumber => "frValueDistOrNumber" + | FRTypeRecord(r) => { + let input = ((name, frType): frTypeRecordParam) => `${name}: ${toString(frType)}` + `record({${r->E.A2.fmap(input)->E.A2.joinWith(", ")}})` + } + | FRTypeArray(r) => `record(${r->E.A2.fmap(toString)->E.A2.joinWith(", ")})` + | FRTypeOption(v) => `option(${toString(v)})` + } + + let rec matchWithExpressionValue = (t: t, r: expressionValue): option => + switch (t, r) { + | (FRTypeNumber, EvNumber(f)) => Some(FRValueNumber(f)) + | (FRTypeDistOrNumber, EvNumber(f)) => Some(FRValueDistOrNumber(FRValueNumber(f))) + | (FRTypeDistOrNumber, EvDistribution(Symbolic(#Float(f)))) => + Some(FRValueDistOrNumber(FRValueNumber(f))) + | (FRTypeDistOrNumber, EvDistribution(f)) => Some(FRValueDistOrNumber(FRValueDist(f))) + | (FRTypeNumeric, EvNumber(f)) => Some(FRValueNumber(f)) + | (FRTypeNumeric, EvDistribution(Symbolic(#Float(f)))) => Some(FRValueNumber(f)) + | (FRTypeOption(v), _) => Some(FRValueOption(matchWithExpressionValue(v, r))) + | (FRTypeRecord(recordParams), EvRecord(record)) => { + let getAndMatch = (name, input) => + E.Dict.get(record, name)->E.O.bind(matchWithExpressionValue(input)) + //All names in the type must be present. If any are missing, the corresponding + //value will be None, and this function would return None. + let namesAndValues: array> = + recordParams->E.A2.fmap(((name, input)) => + getAndMatch(name, input)->E.O2.fmap(match => (name, match)) + ) + namesAndValues->E.A.O.openIfAllSome->E.O2.fmap(r => FRValueRecord(r)) + } + | _ => None + } + + let matchWithExpressionValueArray = (inputs: array, args: array): option< + array, + > => { + let isSameLength = E.A.length(inputs) == E.A.length(args) + if !isSameLength { + None + } else { + E.A.zip(inputs, args) + ->E.A2.fmap(((input, arg)) => matchWithExpressionValue(input, arg)) + ->E.A.O.openIfAllSome + } + } +} + +/* + This module, Matcher, is fairly lengthy. However, only two functions from it + are meant to be used outside of it. These are findMatches and matchToDef in Matches.Registry. + The rest of it is just called from those two functions. +*/ +module Matcher = { + module MatchSimple = { + type t = DifferentName | SameNameDifferentArguments | FullMatch + + let isFullMatch = (match: t) => + switch match { + | FullMatch => true + | _ => false + } + + let isNameMatchOnly = (match: t) => + switch match { + | SameNameDifferentArguments => true + | _ => false + } + } + + module Match = { + type t<'a, 'b> = DifferentName | SameNameDifferentArguments('a) | FullMatch('b) + + let isFullMatch = (match: t<'a, 'b>): bool => + switch match { + | FullMatch(_) => true + | _ => false + } + + let isNameMatchOnly = (match: t<'a, 'b>) => + switch match { + | SameNameDifferentArguments(_) => true + | _ => false + } + } + + module FnDefinition = { + let matchAssumingSameName = (f: fnDefinition, args: array) => { + switch FRType.matchWithExpressionValueArray(f.inputs, args) { + | Some(_) => MatchSimple.FullMatch + | None => MatchSimple.SameNameDifferentArguments + } + } + + let match = (f: fnDefinition, fnName: string, args: array) => { + if f.name !== fnName { + MatchSimple.DifferentName + } else { + matchAssumingSameName(f, args) + } + } + } + + module Function = { + type definitionId = int + type match = Match.t, definitionId> + + let match = (f: function, fnName: string, args: array): match => { + let matchedDefinition = () => + E.A.getIndexBy(f.definitions, r => + MatchSimple.isFullMatch(FnDefinition.match(r, fnName, args)) + ) |> E.O.fmap(r => Match.FullMatch(r)) + let getMatchedNameOnlyDefinition = () => { + let nameMatchIndexes = + f.definitions + ->E.A2.fmapi((index, r) => + MatchSimple.isNameMatchOnly(FnDefinition.match(r, fnName, args)) ? Some(index) : None + ) + ->E.A.O.concatSomes + switch nameMatchIndexes { + | [] => None + | elements => Some(Match.SameNameDifferentArguments(elements)) + } + } + + E.A.O.firstSomeFnWithDefault( + [matchedDefinition, getMatchedNameOnlyDefinition], + Match.DifferentName, + ) + } + } + + module RegistryMatch = { + type match = { + fnName: string, + inputIndex: int, + } + let makeMatch = (fnName: string, inputIndex: int) => {fnName: fnName, inputIndex: inputIndex} + } + + module Registry = { + let _findExactMatches = (r: registry, fnName: string, args: array) => { + let functionMatchPairs = r->E.A2.fmap(l => (l, Function.match(l, fnName, args))) + let fullMatch = functionMatchPairs->E.A.getBy(((_, match)) => Match.isFullMatch(match)) + fullMatch->E.O.bind(((fn, match)) => + switch match { + | FullMatch(index) => Some(RegistryMatch.makeMatch(fn.name, index)) + | _ => None + } + ) + } + + let _findNameMatches = (r: registry, fnName: string, args: array) => { + let functionMatchPairs = r->E.A2.fmap(l => (l, Function.match(l, fnName, args))) + let getNameMatches = + functionMatchPairs + ->E.A2.fmap(((fn, match)) => Match.isNameMatchOnly(match) ? Some((fn, match)) : None) + ->E.A.O.concatSomes + let matches = + getNameMatches + ->E.A2.fmap(((fn, match)) => + switch match { + | SameNameDifferentArguments(indexes) => + indexes->E.A2.fmap(index => RegistryMatch.makeMatch(fn.name, index)) + | _ => [] + } + ) + ->Belt.Array.concatMany + E.A.toNoneIfEmpty(matches) + } + + let findMatches = (r: registry, fnName: string, args: array) => { + switch _findExactMatches(r, fnName, args) { + | Some(r) => Match.FullMatch(r) + | None => + switch _findNameMatches(r, fnName, args) { + | Some(r) => Match.SameNameDifferentArguments(r) + | None => Match.DifferentName + } + } + } + + let matchToDef = (registry: registry, {fnName, inputIndex}: RegistryMatch.match): option< + fnDefinition, + > => + registry + ->E.A.getBy(fn => fn.name === fnName) + ->E.O.bind(fn => E.A.get(fn.definitions, inputIndex)) + } +} + +module FnDefinition = { + type t = fnDefinition + + let toString = (t: t) => { + let inputs = t.inputs->E.A2.fmap(FRType.toString)->E.A2.joinWith(", ") + t.name ++ `(${inputs})` + } + + let run = (t: t, args: array, env: DistributionOperation.env) => { + let argValues = FRType.matchWithExpressionValueArray(t.inputs, args) + switch argValues { + | Some(values) => t.run(values, env) + | None => Error("Incorrect Types") + } + } + + let make = (~name, ~inputs, ~run): t => { + name: name, + inputs: inputs, + run: run, + } +} + +module Function = { + type t = function + + let make = (~name, ~definitions): t => { + name: name, + definitions: definitions, + } +} + +module Registry = { + /* + There's a (potential+minor) bug here: If a function definition is called outside of the calls + to the registry, then it's possible that there could be a match after the registry is + called. However, for now, we could just call the registry last. + */ + let matchAndRun = ( + ~registry: registry, + ~fnName: string, + ~args: array, + ~env: DistributionOperation.env, + ) => { + let matchToDef = m => Matcher.Registry.matchToDef(registry, m) + let showNameMatchDefinitions = matches => { + let defs = + matches + ->E.A2.fmap(matchToDef) + ->E.A.O.concatSomes + ->E.A2.fmap(FnDefinition.toString) + ->E.A2.fmap(r => `[${r}]`) + ->E.A2.joinWith("; ") + `There are function matches for ${fnName}(), but with different arguments: ${defs}` + } + switch Matcher.Registry.findMatches(registry, fnName, args) { + | Matcher.Match.FullMatch(match) => match->matchToDef->E.O2.fmap(FnDefinition.run(_, args, env)) + | SameNameDifferentArguments(m) => Some(Error(showNameMatchDefinitions(m))) + | _ => None + } + } +} diff --git a/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Core.resi b/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Core.resi new file mode 100644 index 00000000..5ca8c708 --- /dev/null +++ b/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Core.resi @@ -0,0 +1,58 @@ +type expressionValue = ReducerInterface_ExpressionValue.expressionValue + +type rec frType = + | FRTypeNumber + | FRTypeNumeric + | FRTypeDistOrNumber + | FRTypeRecord(frTypeRecord) + | FRTypeArray(array) + | FRTypeOption(frType) +and frTypeRecord = array +and frTypeRecordParam = (string, frType) + +type rec frValue = + | FRValueNumber(float) + | FRValueDist(DistributionTypes.genericDist) + | FRValueOption(option) + | FRValueDistOrNumber(frValueDistOrNumber) + | FRValueRecord(frValueRecord) +and frValueRecord = array +and frValueRecordParam = (string, frValue) +and frValueDistOrNumber = FRValueNumber(float) | FRValueDist(DistributionTypes.genericDist) + +type fnDefinition = { + name: string, + inputs: array, + run: (array, DistributionOperation.env) => result, +} + +type function = { + name: string, + definitions: array, +} + +type registry = array + +// Note: The function "name" is just used for documentation purposes +module Function: { + type t = function + let make: (~name: string, ~definitions: array) => t +} + +module FnDefinition: { + type t = fnDefinition + let make: ( + ~name: string, + ~inputs: array, + ~run: (array, DistributionOperation.env) => result, + ) => t +} + +module Registry: { + let matchAndRun: ( + ~registry: registry, + ~fnName: string, + ~args: array, + ~env: QuriSquiggleLang.DistributionOperation.env, + ) => option> +} diff --git a/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Helpers.res b/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Helpers.res new file mode 100644 index 00000000..118a15d2 --- /dev/null +++ b/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Helpers.res @@ -0,0 +1,156 @@ +open FunctionRegistry_Core + +let impossibleError = "Wrong inputs / Logically impossible" + +module Wrappers = { + let symbolic = r => DistributionTypes.Symbolic(r) + let evDistribution = r => ReducerInterface_ExpressionValue.EvDistribution(r) + let symbolicEvDistribution = r => r->DistributionTypes.Symbolic->evDistribution +} + +module Prepare = { + type ts = array + type err = string + + module ToValueArray = { + module Record = { + let twoArgs = (inputs: ts): result => + switch inputs { + | [FRValueRecord([(_, n1), (_, n2)])] => Ok([n1, n2]) + | _ => Error(impossibleError) + } + } + } + + module ToValueTuple = { + let twoDistOrNumber = (values: ts): result<(frValueDistOrNumber, frValueDistOrNumber), err> => { + switch values { + | [FRValueDistOrNumber(a1), FRValueDistOrNumber(a2)] => Ok(a1, a2) + | _ => Error(impossibleError) + } + } + + let oneDistOrNumber = (values: ts): result => { + switch values { + | [FRValueDistOrNumber(a1)] => Ok(a1) + | _ => Error(impossibleError) + } + } + + module Record = { + let twoDistOrNumber = (values: ts): result<(frValueDistOrNumber, frValueDistOrNumber), err> => + values->ToValueArray.Record.twoArgs->E.R.bind(twoDistOrNumber) + } + } +} + +module Process = { + module DistOrNumberToDist = { + module Helpers = { + let toSampleSet = (r, env: DistributionOperation.env) => + GenericDist.toSampleSetDist(r, env.sampleCount) + + let mapFnResult = r => + switch r { + | Ok(r) => Ok(GenericDist.sample(r)) + | Error(r) => Error(Operation.Other(r)) + } + + let wrapSymbolic = (fn, r) => r->fn->E.R2.fmap(Wrappers.symbolic) + + let singleVarSample = (dist, fn, env) => { + switch toSampleSet(dist, env) { + | Ok(dist) => + switch SampleSetDist.samplesMap(~fn=f => fn(f)->mapFnResult, dist) { + | Ok(r) => Ok(DistributionTypes.SampleSet(r)) + | Error(r) => Error(DistributionTypes.Error.toString(DistributionTypes.SampleSetError(r))) + } + | Error(r) => Error(DistributionTypes.Error.toString(r)) + } + } + + let twoVarSample = (dist1, dist2, fn, env) => { + let altFn = (a, b) => fn((a, b))->mapFnResult + switch E.R.merge(toSampleSet(dist1, env), toSampleSet(dist2, env)) { + | Ok((t1, t2)) => + switch SampleSetDist.map2(~fn=altFn, ~t1, ~t2) { + | Ok(r) => Ok(DistributionTypes.SampleSet(r)) + | Error(r) => Error(Operation.Error.toString(r)) + } + | Error(r) => Error(DistributionTypes.Error.toString(r)) + } + } + } + + let oneValue = ( + ~fn: float => result, + ~value: frValueDistOrNumber, + ~env: DistributionOperation.env, + ): result => { + switch value { + | FRValueNumber(a1) => fn(a1) + | FRValueDist(a1) => Helpers.singleVarSample(a1, r => fn(r), env) + } + } + + let oneValueUsingSymbolicDist = (~fn, ~value) => oneValue(~fn=Helpers.wrapSymbolic(fn), ~value) + + let twoValues = ( + ~fn: ((float, float)) => result, + ~values: (frValueDistOrNumber, frValueDistOrNumber), + ~env: DistributionOperation.env, + ): result => { + switch values { + | (FRValueNumber(a1), FRValueNumber(a2)) => fn((a1, a2)) + | (FRValueDist(a1), FRValueNumber(a2)) => Helpers.singleVarSample(a1, r => fn((r, a2)), env) + | (FRValueNumber(a1), FRValueDist(a2)) => Helpers.singleVarSample(a2, r => fn((a1, r)), env) + | (FRValueDist(a1), FRValueDist(a2)) => Helpers.twoVarSample(a1, a2, fn, env) + } + } + + let twoValuesUsingSymbolicDist = (~fn, ~values) => + twoValues(~fn=Helpers.wrapSymbolic(fn), ~values) + } +} + +module TwoArgDist = { + let process = (~fn, ~env, r) => + r + ->E.R.bind(Process.DistOrNumberToDist.twoValuesUsingSymbolicDist(~fn, ~values=_, ~env)) + ->E.R2.fmap(Wrappers.evDistribution) + + let make = (name, fn) => { + FnDefinition.make(~name, ~inputs=[FRTypeDistOrNumber, FRTypeDistOrNumber], ~run=(inputs, env) => + inputs->Prepare.ToValueTuple.twoDistOrNumber->process(~fn, ~env) + ) + } + + let makeRecordP5P95 = (name, fn) => { + FnDefinition.make( + ~name, + ~inputs=[FRTypeRecord([("p5", FRTypeDistOrNumber), ("p95", FRTypeDistOrNumber)])], + ~run=(inputs, env) => inputs->Prepare.ToValueTuple.Record.twoDistOrNumber->process(~fn, ~env), + ) + } + + let makeRecordMeanStdev = (name, fn) => { + FnDefinition.make( + ~name, + ~inputs=[FRTypeRecord([("mean", FRTypeDistOrNumber), ("stdev", FRTypeDistOrNumber)])], + ~run=(inputs, env) => inputs->Prepare.ToValueTuple.Record.twoDistOrNumber->process(~fn, ~env), + ) + } +} + +module OneArgDist = { + let process = (~fn, ~env, r) => + r + ->E.R.bind(Process.DistOrNumberToDist.oneValueUsingSymbolicDist(~fn, ~value=_, ~env)) + ->E.R2.fmap(Wrappers.evDistribution) + + let make = (name, fn) => { + FnDefinition.make(~name, ~inputs=[FRTypeDistOrNumber], ~run=(inputs, env) => + inputs->Prepare.ToValueTuple.oneDistOrNumber->process(~fn, ~env) + ) + } +} diff --git a/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Library.res b/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Library.res new file mode 100644 index 00000000..43be7118 --- /dev/null +++ b/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Library.res @@ -0,0 +1,65 @@ +open FunctionRegistry_Core +open FunctionRegistry_Helpers + +let twoArgs = E.Tuple2.toFnCall + +let registry = [ + Function.make( + ~name="Normal", + ~definitions=[ + TwoArgDist.make("normal", twoArgs(SymbolicDist.Normal.make)), + TwoArgDist.makeRecordP5P95("normal", r => + twoArgs(SymbolicDist.Normal.from90PercentCI, r)->Ok + ), + TwoArgDist.makeRecordMeanStdev("normal", twoArgs(SymbolicDist.Normal.make)), + ], + ), + Function.make( + ~name="Lognormal", + ~definitions=[ + TwoArgDist.make("lognormal", twoArgs(SymbolicDist.Lognormal.make)), + TwoArgDist.makeRecordP5P95("lognormal", r => + twoArgs(SymbolicDist.Lognormal.from90PercentCI, r)->Ok + ), + TwoArgDist.makeRecordMeanStdev("lognormal", twoArgs(SymbolicDist.Lognormal.fromMeanAndStdev)), + ], + ), + Function.make( + ~name="Uniform", + ~definitions=[TwoArgDist.make("uniform", twoArgs(SymbolicDist.Uniform.make))], + ), + Function.make( + ~name="Beta", + ~definitions=[TwoArgDist.make("beta", twoArgs(SymbolicDist.Beta.make))], + ), + Function.make( + ~name="Cauchy", + ~definitions=[TwoArgDist.make("cauchy", twoArgs(SymbolicDist.Cauchy.make))], + ), + Function.make( + ~name="Gamma", + ~definitions=[TwoArgDist.make("gamma", twoArgs(SymbolicDist.Gamma.make))], + ), + Function.make( + ~name="Logistic", + ~definitions=[TwoArgDist.make("logistic", twoArgs(SymbolicDist.Logistic.make))], + ), + Function.make( + ~name="To", + ~definitions=[ + TwoArgDist.make("to", twoArgs(SymbolicDist.From90thPercentile.make)), + TwoArgDist.make( + "credibleIntervalToDistribution", + twoArgs(SymbolicDist.From90thPercentile.make), + ), + ], + ), + Function.make( + ~name="Exponential", + ~definitions=[OneArgDist.make("exponential", SymbolicDist.Exponential.make)], + ), + Function.make( + ~name="Bernoulli", + ~definitions=[OneArgDist.make("bernoulli", SymbolicDist.Bernoulli.make)], + ), +] diff --git a/packages/squiggle-lang/src/rescript/FunctionRegistry/README.md b/packages/squiggle-lang/src/rescript/FunctionRegistry/README.md new file mode 100644 index 00000000..e974b189 --- /dev/null +++ b/packages/squiggle-lang/src/rescript/FunctionRegistry/README.md @@ -0,0 +1,46 @@ +# Function Registry + +The function registry is a library for organizing function definitions. + +The main interface is fairly constrained. Basically, write functions like the following, and add them to a big array. + +```rescript + Function.make( + ~name="Normal", + ~definitions=[ + FnDefinition.make( + ~name="Normal", + ~definitions=[ + FnDefinition.make(~name="normal", ~inputs=[FRTypeDistOrNumber, FRTypeDistOrNumber], ~run=( + inputs, + env, + ) => + inputs + ->Prepare.ToValueTuple.twoDistOrNumber + ->E.R.bind( + Process.twoDistsOrNumbersToDistUsingSymbolicDist( + ~fn=E.Tuple2.toFnCall(SymbolicDist.Normal.make), + ~env, + ~values=_, + ), + ) + ->E.R2.fmap(Wrappers.evDistribution) + ), + ], + ) + ], + ) +``` + +The Function name is just there for future documentation. The function defintions + +## Key Files + +**FunctionRegistry_Core** +Key types, internal functionality, and a `Registry` module with a `matchAndRun` function to call function definitions. + +**FunctionRegistry_Library** +A list of all the Functions defined in the Function Registry. + +**FunctionRegistry_Helpers** +A list of helper functions for the FunctionRegistry_Library. diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res index fa152c6f..dc827805 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res @@ -179,27 +179,6 @@ module Helpers = { } module SymbolicConstructors = { - let oneFloat = name => - switch name { - | "exponential" => Ok(SymbolicDist.Exponential.make) - | "bernoulli" => Ok(SymbolicDist.Bernoulli.make) - | _ => Error("Unreachable state") - } - - let twoFloat = name => - switch name { - | "beta" => Ok(SymbolicDist.Beta.make) - | "cauchy" => Ok(SymbolicDist.Cauchy.make) - | "credibleIntervalToDistribution" => Ok(SymbolicDist.From90thPercentile.make) - | "gamma" => Ok(SymbolicDist.Gamma.make) - | "logistic" => Ok(SymbolicDist.Logistic.make) - | "lognormal" => Ok(SymbolicDist.Lognormal.make) - | "normal" => Ok(SymbolicDist.Normal.make) - | "to" => Ok(SymbolicDist.From90thPercentile.make) // as credibleIntervalToDistribution is defined "to" might be redundant - | "uniform" => Ok(SymbolicDist.Uniform.make) - | _ => Error("Unreachable state") - } - let threeFloat = name => switch name { | "triangular" => Ok(SymbolicDist.Triangular.make) @@ -221,27 +200,8 @@ let dispatchToGenericOutput = ( ): option => { let (fnName, args) = call switch (fnName, args) { - | (("exponential" | "bernoulli") as fnName, [EvNumber(f)]) => - SymbolicConstructors.oneFloat(fnName) - ->E.R.bind(r => r(f)) - ->SymbolicConstructors.symbolicResultToOutput | ("delta", [EvNumber(f)]) => SymbolicDist.Float.makeSafe(f)->SymbolicConstructors.symbolicResultToOutput - | ( - ("normal" - | "uniform" - | "beta" - | "lognormal" - | "cauchy" - | "gamma" - | "credibleIntervalToDistribution" - | "to" - | "logistic") as fnName, - [EvNumber(f1), EvNumber(f2)], - ) => - SymbolicConstructors.twoFloat(fnName) - ->E.R.bind(r => r(f1, f2)) - ->SymbolicConstructors.symbolicResultToOutput | ("triangular" as fnName, [EvNumber(f1), EvNumber(f2), EvNumber(f3)]) => SymbolicConstructors.threeFloat(fnName) ->E.R.bind(r => r(f1, f2, f3)) @@ -390,6 +350,20 @@ let genericOutputToReducerValue = (o: DistributionOperation.outputType): result< | GenDistError(err) => Error(REDistributionError(err)) } -let dispatch = (call, environment) => { - dispatchToGenericOutput(call, environment)->E.O2.fmap(genericOutputToReducerValue) +// I expect that it's important to build this first, so it doesn't get recalculated for each tryRegistry() call. +let registry = FunctionRegistry_Library.registry + +let tryRegistry = ((fnName, args): ExpressionValue.functionCall, env) => { + FunctionRegistry_Core.Registry.matchAndRun(~registry, ~fnName, ~args, ~env)->E.O2.fmap( + E.R2.errMap(_, s => Reducer_ErrorValue.RETodo(s)), + ) +} + +let dispatch = (call: ExpressionValue.functionCall, environment) => { + let regularDispatch = + dispatchToGenericOutput(call, environment)->E.O2.fmap(genericOutputToReducerValue) + switch regularDispatch { + | Some(x) => Some(x) + | None => tryRegistry(call, environment) + } } diff --git a/packages/squiggle-lang/src/rescript/Utility/E.res b/packages/squiggle-lang/src/rescript/Utility/E.res index b08754ef..3357f4f4 100644 --- a/packages/squiggle-lang/src/rescript/Utility/E.res +++ b/packages/squiggle-lang/src/rescript/Utility/E.res @@ -2,6 +2,9 @@ Some functions from modules `L`, `O`, and `R` below were copied directly from running `rescript convert -all` on Rationale https://github.com/jonlaing/rationale */ + +let equals = (a, b) => a === b + module FloatFloatMap = { module Id = Belt.Id.MakeComparable({ type t = float @@ -49,6 +52,7 @@ module Tuple2 = { let (_, b) = v b } + let toFnCall = (fn, (a1, a2)) => fn(a1, a2) } module O = { @@ -525,6 +529,7 @@ module A = { let unsafe_get = Array.unsafe_get let get = Belt.Array.get let getBy = Belt.Array.getBy + let getIndexBy = Belt.Array.getIndexBy let last = a => get(a, length(a) - 1) let first = get(_, 0) let hasBy = (r, fn) => Belt.Array.getBy(r, fn) |> O.isSome @@ -538,6 +543,7 @@ module A = { let reducei = Belt.Array.reduceWithIndex let isEmpty = r => length(r) < 1 let stableSortBy = Belt.SortArray.stableSortBy + let toNoneIfEmpty = r => isEmpty(r) ? None : Some(r) let toRanges = (a: array<'a>) => switch a |> Belt.Array.length { | 0 @@ -552,6 +558,12 @@ module A = { |> (x => Ok(x)) } + let getByOpen = (a, op, bin) => + switch getBy(a, r => bin(op(r))) { + | Some(r) => Some(op(r)) + | None => None + } + let tail = Belt.Array.sliceToEnd(_, 1) let zip = Belt.Array.zip @@ -636,6 +648,19 @@ module A = { } } let firstSome = x => Belt.Array.getBy(x, O.isSome) + + let firstSomeFn = (r: array option<'a>>): option<'a> => + O.flatten(getByOpen(r, l => l(), O.isSome)) + + let firstSomeFnWithDefault = (r, default) => firstSomeFn(r)->O2.default(default) + + let openIfAllSome = (optionals: array>): option> => { + if all(O.isSome, optionals) { + Some(optionals |> fmap(O.toExn("Warning: This should not have happened"))) + } else { + None + } + } } module R = { @@ -822,6 +847,7 @@ module A = { module A2 = { let fmap = (a, b) => A.fmap(b, a) + let fmapi = (a, b) => A.fmapi(b, a) let joinWith = (a, b) => A.joinWith(b, a) let filter = (a, b) => A.filter(b, a) } @@ -833,3 +859,9 @@ module JsArray = { |> Js.Array.map(O.toExn("Warning: This should not have happened")) let filter = Js.Array.filter } + +module Dict = { + type t<'a> = Js.Dict.t<'a> + let get = Js.Dict.get + let keys = Js.Dict.keys +} diff --git a/packages/squiggle-lang/src/rescript/Utility/Operation.res b/packages/squiggle-lang/src/rescript/Utility/Operation.res index 7972b2fa..6476850d 100644 --- a/packages/squiggle-lang/src/rescript/Utility/Operation.res +++ b/packages/squiggle-lang/src/rescript/Utility/Operation.res @@ -58,6 +58,7 @@ type operationError = | SampleMapNeedsNtoNFunction | PdfInvalidError | NotYetImplemented // should be removed when `klDivergence` for mixed and discrete is implemented. + | Other(string) @genType module Error = { @@ -73,6 +74,7 @@ module Error = { | SampleMapNeedsNtoNFunction => "SampleMap needs a function that converts a number to a number" | PdfInvalidError => "This Pdf is invalid" | NotYetImplemented => "This pathway is not yet implemented" + | Other(t) => t } }