diff --git a/packages/squiggle-lang/src/rescript/FunctionRegistry.res b/packages/squiggle-lang/src/rescript/FunctionRegistry.res index 766e3aaf..84e5eccb 100644 --- a/packages/squiggle-lang/src/rescript/FunctionRegistry.res +++ b/packages/squiggle-lang/src/rescript/FunctionRegistry.res @@ -28,6 +28,8 @@ type function = { definitions: array, } +type registry = array + module Function = { let make = (name, definitions): function => { name: name, @@ -79,23 +81,27 @@ let isNameMatchOnly = (match: match) => | _ => false } +let matchSingleSameName = (f: fnDefinition, args: array) => { + let inputTypes = f.inputs + if E.A.length(f.inputs) !== E.A.length(args) { + SameNameDifferentArguments(f.name) + } else { + let foo = + E.A.zip(inputTypes, args) + ->E.A2.fmap(((input, arg)) => matchInput(input, arg)) + ->E.A.O.arrSomeToSomeArr + switch foo { + | Some(r) => Match(f.name, r) + | None => SameNameDifferentArguments(f.name) + } + } +} + let matchSingle = (f: fnDefinition, fnName: string, args: array) => { if f.name !== fnName { DifferentName } else { - let inputTypes = f.inputs - if E.A.length(f.inputs) !== E.A.length(args) { - SameNameDifferentArguments(f.name) - } else { - let foo = - E.A.zip(inputTypes, args) - ->E.A2.fmap(((input, arg)) => matchInput(input, arg)) - ->E.A.O.arrSomeToSomeArr - switch foo { - | Some(r) => Match(f.name, r) - | None => SameNameDifferentArguments(f.name) - } - } + matchSingleSameName(f, args) } } @@ -107,6 +113,123 @@ let match = (f: function, fnName: string, args: array) => { E.A.O.firstSomeFnWithDefault([matchedDefinition, getMatchedNameOnlyDefinition], DifferentName) } +module IndexMatch = { + type t = [#FullMatch(int) | #NameMatchOnly(array) | #NoMatch] + let isFullMatch = (t: t) => + switch t { + | #FullMatch(_) => true + | _ => false + } + let isNameOnlyMatch = (t: t) => + switch t { + | #NameMatchOnly(_) => true + | _ => false + } +} + +let match2 = (f: function, fnName: string, args: array) => { + let matchedDefinition = () => + E.A.getIndexBy(f.definitions, r => isFullMatch(matchSingle(r, fnName, args))) |> E.O.fmap(r => + #FullMatch(r) + ) + let getMatchedNameOnlyDefinition = () => { + let nameMatchIndexes = + f.definitions + ->E.A2.fmapi((index, r) => isNameMatchOnly(matchSingle(r, fnName, args)) ? Some(index) : None) + ->E.A.O.concatSomes + switch nameMatchIndexes { + | [] => None + | elements => Some(#NameMatchOnly(elements)) + } + } + + E.A.O.firstSomeFnWithDefault([matchedDefinition, getMatchedNameOnlyDefinition], #NoMatch) +} + +module IndexMatch2 = { + type match = { + fnName: string, + inputIndex: int, + } + type t = [#FullMatch(match) | #NameMatchOnly(array) | #NoMatch] + let makeMatch = (fnName: string, inputIndex: int) => {fnName: fnName, inputIndex: inputIndex} + let isFullMatch = (t: t) => + switch t { + | #FullMatch(_) => true + | _ => false + } + let isNameOnlyMatch = (t: t) => + switch t { + | #NameMatchOnly(_) => true + | _ => false + } +} + +module Registry = { + let findExactMatches = (r: registry, fnName: string, args: array) => { + let functionMatchPairs = r->E.A2.fmap(l => (l, match2(l, fnName, args))) + let getFullMatch = E.A.getBy(functionMatchPairs, ((_, match)) => IndexMatch.isFullMatch(match)) + let fullMatch: option = getFullMatch->E.O.bind(((fn, match)) => + switch match { + | #FullMatch(index) => Some(IndexMatch2.makeMatch(fn.name, index)) + | _ => None + } + ) + fullMatch + } + + let findNameMatches = (r: registry, fnName: string, args: array) => { + let functionMatchPairs = r->E.A2.fmap(l => (l, match2(l, fnName, args))) + let getNameMatches = + functionMatchPairs + ->E.A2.fmap(((fn, match)) => IndexMatch.isNameOnlyMatch(match) ? Some((fn, match)) : None) + ->E.A.O.concatSomes + let matches = + getNameMatches + ->E.A2.fmap(((fn, match)) => + switch match { + | #NameMatchOnly(indexes) => + indexes->E.A2.fmap(index => IndexMatch2.makeMatch(fn.name, index)) + | _ => [] + } + ) + ->Belt.Array.concatMany + E.A.toNoneIfEmpty(matches) + } + + let findMatches = (r: registry, fnName: string, args: array) => { + switch findExactMatches(r, fnName, args) { + | Some(r) => #FullMatch(r) + | None => + switch findNameMatches(r, fnName, args) { + | Some(r) => #NameMatchOnly(r) + | None => #NoMatch + } + } + } + + let fullMatchToDef = (registry: registry, {fnName, inputIndex}: IndexMatch2.match): option< + fnDefinition, + > => + registry + ->E.A.getBy(fn => fn.name === fnName) + ->E.O.bind(fn => E.A.get(fn.definitions, inputIndex)) + + let runDef = (fnDefinition: fnDefinition, args: array) => { + switch matchSingleSameName(fnDefinition, args) { + | Match(_, values) => fnDefinition.run(values) + | _ => Error("Impossible") + } + } + + let matchAndRun = (r: registry, fnName: string, args: array) => { + switch findMatches(r, fnName, args) { + | #FullMatch(m) => fullMatchToDef(r, m)->E.O2.fmap(runDef(_, args)) + | _ => None + } + } +} + let twoNumberInputs = (inputs: array) => switch inputs { | [Number(n1), Number(n2)] => Ok(n1, n2) diff --git a/packages/squiggle-lang/src/rescript/Utility/E.res b/packages/squiggle-lang/src/rescript/Utility/E.res index c3cf855d..06d8d4e9 100644 --- a/packages/squiggle-lang/src/rescript/Utility/E.res +++ b/packages/squiggle-lang/src/rescript/Utility/E.res @@ -522,6 +522,7 @@ module A = { let unsafe_get = Array.unsafe_get let get = Belt.Array.get let getBy = Belt.Array.getBy + let getIndexBy = Belt.Array.getIndexBy let last = a => get(a, length(a) - 1) let first = get(_, 0) let hasBy = (r, fn) => Belt.Array.getBy(r, fn) |> O.isSome @@ -535,6 +536,7 @@ module A = { let reducei = Belt.Array.reduceWithIndex let isEmpty = r => length(r) < 1 let stableSortBy = Belt.SortArray.stableSortBy + let toNoneIfEmpty = r => isEmpty(r) ? None : Some(r) let toRanges = (a: array<'a>) => switch a |> Belt.Array.length { | 0 @@ -830,6 +832,7 @@ module A = { module A2 = { let fmap = (a, b) => A.fmap(b, a) + let fmapi = (a, b) => A.fmapi(b, a) let joinWith = (a, b) => A.joinWith(b, a) let filter = (a, b) => A.filter(b, a) }