From ffaf349e0aa2f998b023f4c112d4d93aa1032981 Mon Sep 17 00:00:00 2001 From: Sam Nolan Date: Tue, 19 Jul 2022 11:33:11 +1000 Subject: [PATCH] Basic mapSampleN support --- .../Reducer_Dispatch_BuiltIn_test.res | 4 +++ .../SampleSetDist/SampleSetDist.res | 5 ++++ .../Reducer_Dispatch_BuiltIn.res | 26 +++++++++++++++++-- .../squiggle-lang/src/rescript/Utility/E.res | 13 ++++++++++ 4 files changed, 46 insertions(+), 2 deletions(-) diff --git a/packages/squiggle-lang/__tests__/Reducer/Reducer_Dispatch/Reducer_Dispatch_BuiltIn_test.res b/packages/squiggle-lang/__tests__/Reducer/Reducer_Dispatch/Reducer_Dispatch_BuiltIn_test.res index 98192d31..65048ebb 100644 --- a/packages/squiggle-lang/__tests__/Reducer/Reducer_Dispatch/Reducer_Dispatch_BuiltIn_test.res +++ b/packages/squiggle-lang/__tests__/Reducer/Reducer_Dispatch/Reducer_Dispatch_BuiltIn_test.res @@ -21,6 +21,10 @@ describe("builtin", () => { "addOne(t)=t+1; toList(mapSamples(fromSamples([1,2,3,4,5,6]), addOne))", "Ok([2,3,4,5,6,7])", ) + testEval( + "toList(mapSamplesN([fromSamples([1,2,3,4,5,6]), fromSamples([6, 5, 4, 3, 2, 1])], {|x| x[0] > x[1] ? x[0] : x[1]}))", + "Ok([6,5,4,4,5,6])", + ) }) describe("builtin exception", () => { diff --git a/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist.res b/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist.res index bfbaa795..dc15f7a1 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist.res @@ -117,6 +117,11 @@ let map3 = ( ): result => E.A.zip3(get(t1), get(t2), get(t3))->E.A2.fmap(E.Tuple3.toFnCall(fn))->_fromSampleResultArray +let mapN = (~fn: array => result, ~t1: array): result< + t, + sampleSetError, +> => E.A.transpose(E.A.fmap(get, t1))->E.A2.fmap(fn)->_fromSampleResultArray + let mean = t => T.get(t)->E.A.Floats.mean let geomean = t => T.get(t)->E.A.Floats.geomean let mode = t => T.get(t)->E.A.Floats.mode diff --git a/packages/squiggle-lang/src/rescript/Reducer/Reducer_Dispatch/Reducer_Dispatch_BuiltIn.res b/packages/squiggle-lang/src/rescript/Reducer/Reducer_Dispatch/Reducer_Dispatch_BuiltIn.res index 484b0acb..c7ab817e 100644 --- a/packages/squiggle-lang/src/rescript/Reducer/Reducer_Dispatch/Reducer_Dispatch_BuiltIn.res +++ b/packages/squiggle-lang/src/rescript/Reducer/Reducer_Dispatch/Reducer_Dispatch_BuiltIn.res @@ -19,6 +19,15 @@ open Reducer_ErrorValue exception TestRescriptException +let parseSampleSetArray = (arr: array): option> => { + let parseSampleSet = (value: internalExpressionValue): option => + switch value { + | IEvDistribution(SampleSet(dist)) => Some(dist) + | _ => None + } + E.A.O.openIfAllSome(E.A.fmap(parseSampleSet, arr)) +} + let callInternal = (call: functionCall, environment, reducer: ExpressionT.reducerFn): result< 'b, errorValue, @@ -149,6 +158,11 @@ let callInternal = (call: functionCall, environment, reducer: ExpressionT.reduce doLambdaCall(aLambdaValue, list{IEvNumber(a), IEvNumber(b), IEvNumber(c)}) SampleSetDist.map3(~fn, ~t1, ~t2, ~t3)->toType } + + let mapN = (t1: array, aLambdaValue) => { + let fn = a => doLambdaCall(aLambdaValue, list{IEvArray(E.A.fmap(x => IEvNumber(x), a))}) + SampleSetDist.mapN(~fn, ~t1)->toType + } } let doReduceArray = (aValueArray, initialValue, aLambdaValue) => { @@ -230,6 +244,12 @@ let callInternal = (call: functionCall, environment, reducer: ExpressionT.reduce ], ) => SampleMap.map3(dist1, dist2, dist3, aLambdaValue) + | ("mapSamplesN", [IEvArray(aValueArray), IEvLambda(aLambdaValue)]) => + switch parseSampleSetArray(aValueArray) { + | Some(sampleSetArr) => SampleMap.mapN(sampleSetArr, aLambdaValue) + | None => + Error(REFunctionNotFound(call->functionCallToCallSignature->functionCallSignatureToString)) + } | ("reduce", [IEvArray(aValueArray), initialValue, IEvLambda(aLambdaValue)]) => doReduceArray(aValueArray, initialValue, aLambdaValue) | ("reduceReverse", [IEvArray(aValueArray), initialValue, IEvLambda(aLambdaValue)]) => @@ -246,7 +266,6 @@ let callInternal = (call: functionCall, environment, reducer: ExpressionT.reduce Error(REFunctionNotFound(call->functionCallToCallSignature->functionCallSignatureToString)) // Report full type signature as error } } - /* Reducer uses Result monad while reducing expressions */ @@ -262,5 +281,8 @@ let dispatch = (call: functionCall, environment, reducer: ExpressionT.reducerFn) ExternalLibrary.dispatch((Js.String.make(fn), args), environment, callInternalWithReducer) } catch { | Js.Exn.Error(obj) => REJavaScriptExn(Js.Exn.message(obj), Js.Exn.name(obj))->Error - | _ => RETodo("unhandled rescript exception")->Error + | err => { + Js.Console.log(err) + RETodo("unhandled rescript exception")->Error + } } diff --git a/packages/squiggle-lang/src/rescript/Utility/E.res b/packages/squiggle-lang/src/rescript/Utility/E.res index fd328a1c..f593cd48 100644 --- a/packages/squiggle-lang/src/rescript/Utility/E.res +++ b/packages/squiggle-lang/src/rescript/Utility/E.res @@ -631,6 +631,19 @@ module A = { ) let filter = Js.Array.filter let joinWith = Js.Array.joinWith + let transpose = (xs: array>): array> => { + let arr: array> = [] + for i in 0 to length(xs) - 1 { + for j in 0 to length(xs[i]) - 1 { + if Js.Array.length(arr) <= j { + ignore(Js.Array.push([xs[i][j]], arr)) + } else { + ignore(Js.Array.push(xs[i][j], arr[j])) + } + } + } + arr + } let all = (p: 'a => bool, xs: array<'a>): bool => length(filter(p, xs)) == length(xs) let any = (p: 'a => bool, xs: array<'a>): bool => length(filter(p, xs)) > 0