diff --git a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist.res b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist.res index 56a086ca..48588afe 100644 --- a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist.res @@ -41,7 +41,7 @@ let operationToFloat = (t, toPointSet: toPointSetFn, fnName) => { switch symbolicSolution { | Some(r) => Ok(r) - | None => toPointSet(t)->E.R.fmap2(PointSetDist.operate(fnName)) + | None => toPointSet(t)->E.R2.fmap(PointSetDist.operate(fnName)) } } @@ -95,7 +95,7 @@ module Truncate = { switch trySymbolicSimplification(leftCutoff, rightCutoff, t) { | Some(r) => Ok(r) | None => - toPointSet(t)->E.R.fmap2(t => + toPointSet(t)->E.R2.fmap(t => #PointSet(PointSetDist.T.truncate(leftCutoff, rightCutoff, t)) ) } @@ -134,7 +134,7 @@ module AlgebraicCombination = { t1: t, t2: t, ) => - E.R.merge(toPointSet(t1), toPointSet(t2))->E.R.fmap2(((a, b)) => + E.R.merge(toPointSet(t1), toPointSet(t2))->E.R2.fmap(((a, b)) => PointSetDist.combineAlgebraically(operation, a, b) ) @@ -145,8 +145,8 @@ module AlgebraicCombination = { t2: t, ) => { let operation = Operation.Algebraic.toFn(operation) - E.R.merge(toSampleSet(t1), toSampleSet(t2)) -> E.R.fmap2(((a, b)) => { - Belt.Array.zip(a, b) -> E.A.fmap2(((a, b)) => operation(a, b)) + E.R.merge(toSampleSet(t1), toSampleSet(t2))->E.R2.fmap(((a, b)) => { + Belt.Array.zip(a, b)->E.A2.fmap(((a, b)) => operation(a, b)) }) } @@ -155,7 +155,7 @@ module AlgebraicCombination = { switch x { | #Symbolic(#Float(_)) => 1 | #Symbolic(_) => 1000 - | #PointSet(Discrete(m)) => m.xyShape -> XYShape.T.length + | #PointSet(Discrete(m)) => m.xyShape->XYShape.T.length | #PointSet(Mixed(_)) => 1000 | #PointSet(Continuous(_)) => 1000 | _ => 1000 @@ -179,9 +179,9 @@ module AlgebraicCombination = { | None => switch chooseConvolutionOrMonteCarlo(t1, t2) { | #CalculateWithMonteCarlo => - runMonteCarlo(toSampleSet, algebraicOp, t1, t2)->E.R.fmap2(r => #SampleSet(r)) + runMonteCarlo(toSampleSet, algebraicOp, t1, t2)->E.R2.fmap(r => #SampleSet(r)) | #CalculateWithConvolution => - runConvolution(toPointSet, algebraicOp, t1, t2)->E.R.fmap2(r => #PointSet(r)) + runConvolution(toPointSet, algebraicOp, t1, t2)->E.R2.fmap(r => #PointSet(r)) } } } @@ -195,10 +195,10 @@ let pointwiseCombination = (t1: t, toPointSet: toPointSetFn, operation, t2: t): error, > => { E.R.merge(toPointSet(t1), toPointSet(t2)) - ->E.R.fmap2(((t1, t2)) => + ->E.R2.fmap(((t1, t2)) => PointSetDist.combinePointwise(GenericDist_Types.Operation.arithmeticToFn(operation), t1, t2) ) - ->E.R.fmap2(r => #PointSet(r)) + ->E.R2.fmap(r => #PointSet(r)) } let pointwiseCombinationFloat = ( @@ -210,7 +210,7 @@ let pointwiseCombinationFloat = ( switch operation { | #Add | #Subtract => Error(GenericDist_Types.DistributionVerticalShiftIsInvalid) | (#Multiply | #Divide | #Exponentiate | #Log) as operation => - toPointSet(t)->E.R.fmap2(t => { + toPointSet(t)->E.R2.fmap(t => { //TODO: Move to PointSet codebase let fn = (secondary, main) => Operation.Scale.toFn(operation, main, secondary) let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(operation) @@ -222,10 +222,10 @@ let pointwiseCombinationFloat = ( t, ) }) - }->E.R.fmap2(r => #PointSet(r)) + }->E.R2.fmap(r => #PointSet(r)) } -//Note: The result should always cumulatively sum to 1. +//Note: The result should always cumulatively sum to 1. This would be good to test. let mixture = ( values: array<(t, float)>, scaleMultiply: scaleMultiplyFn, @@ -234,10 +234,10 @@ let mixture = ( if E.A.length(values) == 0 { Error(GenericDist_Types.Other("mixture must have at least 1 element")) } else { - let totalWeight = values->E.A.fmap2(E.Tuple2.second)->E.A.Floats.sum + let totalWeight = values->E.A2.fmap(E.Tuple2.second)->E.A.Floats.sum let properlyWeightedValues = values - ->E.A.fmap2(((dist, weight)) => scaleMultiply(dist, weight /. totalWeight)) + ->E.A2.fmap(((dist, weight)) => scaleMultiply(dist, weight /. totalWeight)) ->E.A.R.firstErrorOrOpen properlyWeightedValues->E.R.bind(values => { values diff --git a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res index c3aa71f8..43aec78f 100644 --- a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res +++ b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res @@ -93,7 +93,7 @@ let rec run = (extra, fnName: operation): outputType => { let fromDistFn = (subFn: GenericDist_Types.Operation.fromDist, dist: genericDist) => switch subFn { | #toFloat(fnName) => - GenericDist.operationToFloat(dist, toPointSet, fnName)->E.R.fmap2(r => #Float(r))->fromResult + GenericDist.operationToFloat(dist, toPointSet, fnName)->E.R2.fmap(r => #Float(r))->fromResult | #toString => dist->GenericDist.toString->(r => #String(r)) | #toDist(#consoleLog) => { Js.log2("Console log requested: ", dist) @@ -101,26 +101,26 @@ let rec run = (extra, fnName: operation): outputType => { } | #toDist(#normalize) => dist->GenericDist.normalize->(r => #Dist(r)) | #toDist(#truncate(left, right)) => - dist->GenericDist.truncate(toPointSet, left, right)->E.R.fmap2(r => #Dist(r))->fromResult + dist->GenericDist.truncate(toPointSet, left, right)->E.R2.fmap(r => #Dist(r))->fromResult | #toDist(#toPointSet) => - dist->GenericDist.toPointSet(xyPointLength)->E.R.fmap2(r => #Dist(#PointSet(r)))->fromResult + dist->GenericDist.toPointSet(xyPointLength)->E.R2.fmap(r => #Dist(#PointSet(r)))->fromResult | #toDist(#toSampleSet(n)) => - dist->GenericDist.sampleN(n)->E.R.fmap2(r => #Dist(#SampleSet(r)))->fromResult + dist->GenericDist.sampleN(n)->E.R2.fmap(r => #Dist(#SampleSet(r)))->fromResult | #toDistCombination(#Algebraic, _, #Float(_)) => #Error(NotYetImplemented) | #toDistCombination(#Algebraic, operation, #Dist(dist2)) => dist ->GenericDist.algebraicCombination(toPointSet, toSampleSet, operation, dist2) - ->E.R.fmap2(r => #Dist(r)) + ->E.R2.fmap(r => #Dist(r)) ->fromResult | #toDistCombination(#Pointwise, operation, #Dist(dist2)) => dist ->GenericDist.pointwiseCombination(toPointSet, operation, dist2) - ->E.R.fmap2(r => #Dist(r)) + ->E.R2.fmap(r => #Dist(r)) ->fromResult | #toDistCombination(#Pointwise, operation, #Float(f)) => dist ->GenericDist.pointwiseCombinationFloat(toPointSet, operation, f) - ->E.R.fmap2(r => #Dist(r)) + ->E.R2.fmap(r => #Dist(r)) ->fromResult } @@ -128,7 +128,7 @@ let rec run = (extra, fnName: operation): outputType => { | #fromDist(subFn, dist) => fromDistFn(subFn, dist) | #fromFloat(subFn, float) => reCall(~fnName=#fromDist(subFn, GenericDist.fromFloat(float)), ()) | #mixture(dists) => - dists->GenericDist.mixture(scaleMultiply, pointwiseAdd)->E.R.fmap2(r => #Dist(r))->fromResult + dists->GenericDist.mixture(scaleMultiply, pointwiseAdd)->E.R2.fmap(r => #Dist(r))->fromResult } } @@ -147,5 +147,5 @@ let fmap = ( | (#fromDist(_), _) => Error(Other("Expected dist, got something else")) | (#fromFloat(_), _) => Error(Other("Expected float, got something else")) } - newFnCall->E.R.fmap2(r => run(extra, r))->fromResult + newFnCall->E.R2.fmap(r => run(extra, r))->fromResult } diff --git a/packages/squiggle-lang/src/rescript/utility/E.res b/packages/squiggle-lang/src/rescript/utility/E.res index 69d8ecc1..121ecb91 100644 --- a/packages/squiggle-lang/src/rescript/utility/E.res +++ b/packages/squiggle-lang/src/rescript/utility/E.res @@ -148,7 +148,6 @@ module R = { let result = Rationale.Result.result let id = e => e |> result(U.id, U.id) let fmap = Rationale.Result.fmap - let fmap2 = (a,b) => Rationale.Result.fmap(b,a) let bind = Rationale.Result.bind let toExn = Belt.Result.getExn let default = (default, res: Belt.Result.t<'a, 'b>) => @@ -172,6 +171,10 @@ module R = { errorCondition(r) ? Error(errorMessage) : Ok(r) } +module R2 = { + let fmap = (a,b) => R.fmap(b,a) +} + let safe_fn_of_string = (fn, s: string): option<'a> => try Some(fn(s)) catch { | _ => None @@ -245,7 +248,6 @@ module L = { /* A for Array */ module A = { let fmap = Array.map - let fmap2 = (a,b) => Array.map(b,a) let fmapi = Array.mapi let to_list = Array.to_list let of_list = Array.of_list @@ -448,6 +450,10 @@ module A = { } } +module A2 = { + let fmap = (a,b) => A.fmap(b,a) +} + module JsArray = { let concatSomes = (optionals: Js.Array.t>): Js.Array.t<'a> => optionals