diff --git a/packages/squiggle-lang/__tests__/GenericDist/GenericOperation__Test.res b/packages/squiggle-lang/__tests__/GenericDist/GenericOperation__Test.res index 52ab6c19..1e5b5397 100644 --- a/packages/squiggle-lang/__tests__/GenericDist/GenericOperation__Test.res +++ b/packages/squiggle-lang/__tests__/GenericDist/GenericOperation__Test.res @@ -12,9 +12,10 @@ let normalDist20: GenericDist_Types.genericDist = #Symbolic(#Normal({mean: 20.0, let uniformDist: GenericDist_Types.genericDist = #Symbolic(#Uniform({low: 9.0, high: 10.0})) let {toFloat, toDist, toString, toError} = module(GenericDist_GenericOperation.Output) -let {run, outputMap} = module(GenericDist_GenericOperation) +let {run} = module(GenericDist_GenericOperation) +let {fmap} = module(GenericDist_GenericOperation.Output) let run = run(params) -let outputMap = outputMap(params) +let outputMap = fmap(params) let toExt: option<'a> => 'a = E.O.toExt( "Should be impossible to reach (This error is in test file)", ) diff --git a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res index e744543e..9cc93949 100644 --- a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res +++ b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res @@ -9,57 +9,59 @@ type params = { xyPointLength: int, } -type outputType = +type outputType = | Dist(GenericDist_Types.genericDist) | Float(float) | String(string) | GenDistError(GenericDist_Types.error) -module Output = { - let toDist = (o: outputType) => - switch o { +/* +We're going to add another function to this module later, so first define a +local version, which is not exported. +*/ +module OutputLocal = { + type t = outputType + + let toError = (t: outputType) => + switch t { + | GenDistError(d) => Some(d) + | _ => None + } + + let toErrorOrUnreachable = (t: t): error => t->toError->E.O2.default((Unreachable: error)) + + let toDistR = (t: t): result => + switch t { + | Dist(r) => Ok(r) + | e => Error(toErrorOrUnreachable(e)) + } + + let toDist = (t: t) => + switch t { | Dist(d) => Some(d) | _ => None } - let toFloat = (o: outputType) => - switch o { + let toFloat = (t: t) => + switch t { | Float(d) => Some(d) | _ => None } - let toString = (o: outputType) => - switch o { + let toString = (t: t) => + switch t { | String(d) => Some(d) | _ => None } - let toError = (o: outputType) => - switch o { - | GenDistError(d) => Some(d) - | _ => None + //This is used to catch errors in other switch statements. + let fromResult = (r: result): outputType => + switch r { + | Ok(t) => t + | Error(e) => GenDistError(e) } } -let fromResult = (r: result): outputType => - switch r { - | Ok(o) => o - | Error(e) => GenDistError(e) - } - -//This is used to catch errors in other switch statements. -let _errorMap = (o: outputType): error => - switch o { - | GenDistError(r) => r - | _ => Unreachable - } - -let outputToDistResult = (o: outputType): result => - switch o { - | Dist(r) => Ok(r) - | r => Error(_errorMap(r)) - } - let rec run = (extra, fnName: operation): outputType => { let {sampleCount, xyPointLength} = extra @@ -70,14 +72,14 @@ let rec run = (extra, fnName: operation): outputType => { let toPointSetFn = r => { switch reCall(~fnName=#fromDist(#toDist(#toPointSet), r), ()) { | Dist(#PointSet(p)) => Ok(p) - | r => Error(_errorMap(r)) + | e => Error(OutputLocal.toErrorOrUnreachable(e)) } } let toSampleSetFn = r => { switch reCall(~fnName=#fromDist(#toDist(#toSampleSet(sampleCount)), r), ()) { | Dist(#SampleSet(p)) => Ok(p) - | r => Error(_errorMap(r)) + | e => Error(OutputLocal.toErrorOrUnreachable(e)) } } @@ -85,20 +87,20 @@ let rec run = (extra, fnName: operation): outputType => { reCall( ~fnName=#fromDist(#toDistCombination(#Pointwise, #Multiply, #Float(weight)), r), (), - )->outputToDistResult + )->OutputLocal.toDistR let pointwiseAdd = (r1, r2) => reCall( ~fnName=#fromDist(#toDistCombination(#Pointwise, #Add, #Dist(r2)), r1), (), - )->outputToDistResult + )->OutputLocal.toDistR let fromDistFn = (subFnName: GenericDist_Types.Operation.fromDist, dist: genericDist) => switch subFnName { | #toFloat(fnName) => GenericDist.operationToFloat(dist, ~toPointSetFn, ~operation=fnName) ->E.R2.fmap(r => Float(r)) - ->fromResult + ->OutputLocal.fromResult | #toString => dist->GenericDist.toString->String | #toDist(#inspect) => { Js.log2("Console log requested: ", dist) @@ -108,27 +110,30 @@ let rec run = (extra, fnName: operation): outputType => { | #toDist(#truncate(leftCutoff, rightCutoff)) => GenericDist.truncate(~toPointSetFn, ~leftCutoff, ~rightCutoff, dist, ()) ->E.R2.fmap(r => Dist(r)) - ->fromResult + ->OutputLocal.fromResult | #toDist(#toPointSet) => - dist->GenericDist.toPointSet(xyPointLength)->E.R2.fmap(r => Dist(#PointSet(r)))->fromResult + dist + ->GenericDist.toPointSet(xyPointLength) + ->E.R2.fmap(r => Dist(#PointSet(r))) + ->OutputLocal.fromResult | #toDist(#toSampleSet(n)) => - dist->GenericDist.sampleN(n)->E.R2.fmap(r => Dist(#SampleSet(r)))->fromResult + dist->GenericDist.sampleN(n)->E.R2.fmap(r => Dist(#SampleSet(r)))->OutputLocal.fromResult | #toDistCombination(#Algebraic, _, #Float(_)) => GenDistError(NotYetImplemented) | #toDistCombination(#Algebraic, operation, #Dist(t2)) => dist ->GenericDist.algebraicCombination(~toPointSetFn, ~toSampleSetFn, ~operation, ~t2) ->E.R2.fmap(r => Dist(r)) - ->fromResult + ->OutputLocal.fromResult | #toDistCombination(#Pointwise, operation, #Dist(t2)) => dist ->GenericDist.pointwiseCombination(~toPointSetFn, ~operation, ~t2) ->E.R2.fmap(r => Dist(r)) - ->fromResult + ->OutputLocal.fromResult | #toDistCombination(#Pointwise, operation, #Float(float)) => dist ->GenericDist.pointwiseCombinationFloat(~toPointSetFn, ~operation, ~float) ->E.R2.fmap(r => Dist(r)) - ->fromResult + ->OutputLocal.fromResult } switch fnName { @@ -139,24 +144,28 @@ let rec run = (extra, fnName: operation): outputType => { dists ->GenericDist.mixture(~scaleMultiplyFn=scaleMultiply, ~pointwiseAddFn=pointwiseAdd) ->E.R2.fmap(r => Dist(r)) - ->fromResult + ->OutputLocal.fromResult } } let runFromDist = (extra, fnName, dist) => run(extra, #fromDist(fnName, dist)) let runFromFloat = (extra, fnName, float) => run(extra, #fromFloat(fnName, float)) -let outputMap = ( - extra, - input: outputType, - fn: GenericDist_Types.Operation.singleParamaterFunction, -): outputType => { - let newFnCall: result = switch (fn, input) { - | (#fromDist(fromDist), Dist(o)) => Ok(#fromDist(fromDist, o)) - | (#fromFloat(fromDist), Float(o)) => Ok(#fromFloat(fromDist, o)) - | (_, GenDistError(r)) => Error(r) - | (#fromDist(_), _) => Error(Other("Expected dist, got something else")) - | (#fromFloat(_), _) => Error(Other("Expected float, got something else")) +module Output = { + include OutputLocal + + let fmap = ( + extra, + input: outputType, + fn: GenericDist_Types.Operation.singleParamaterFunction, + ): outputType => { + let newFnCall: result = switch (fn, input) { + | (#fromDist(fromDist), Dist(o)) => Ok(#fromDist(fromDist, o)) + | (#fromFloat(fromDist), Float(o)) => Ok(#fromFloat(fromDist, o)) + | (_, GenDistError(r)) => Error(r) + | (#fromDist(_), _) => Error(Other("Expected dist, got something else")) + | (#fromFloat(_), _) => Error(Other("Expected float, got something else")) + } + newFnCall->E.R2.fmap(r => run(extra, r))->OutputLocal.fromResult } - newFnCall->E.R2.fmap(r => run(extra, r))->fromResult } diff --git a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.resi b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.resi index 22864ed3..2769a505 100644 --- a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.resi +++ b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.resi @@ -16,15 +16,13 @@ let runFromDist: ( GenericDist_Types.genericDist, ) => outputType let runFromFloat: (params, GenericDist_Types.Operation.fromDist, float) => outputType -let outputMap: ( - params, - outputType, - GenericDist_Types.Operation.singleParamaterFunction, -) => outputType module Output: { - let toDist: outputType => option - let toFloat: outputType => option - let toString: outputType => option - let toError: outputType => option -} + type t = outputType + let toDist: t => option + let toDistR: t => result + let toFloat: t => option + let toString: t => option + let toError: t => option + let fmap: (params, t, GenericDist_Types.Operation.singleParamaterFunction) => t +} \ No newline at end of file diff --git a/packages/squiggle-lang/src/rescript/utility/E.res b/packages/squiggle-lang/src/rescript/utility/E.res index 121ecb91..9c6c2a73 100644 --- a/packages/squiggle-lang/src/rescript/utility/E.res +++ b/packages/squiggle-lang/src/rescript/utility/E.res @@ -98,6 +98,10 @@ module O = { let max = compare(\">") } +module O2 = { + let default = (a,b) => O.default(b,a) +} + /* Functions */ module F = { let apply = (a, e) => a |> e