From 4b3f24b38daeb53db7f2117438146dcf7b74511d Mon Sep 17 00:00:00 2001 From: Ozzie Gooen Date: Thu, 31 Mar 2022 14:07:39 -0400 Subject: [PATCH] Converted params to env, named several arguments --- .../GenericDist/GenericOperation__Test.res | 8 +- .../src/rescript/GenericDist/GenericDist.res | 80 ++++++++++--------- .../src/rescript/GenericDist/GenericDist.resi | 16 ++-- .../GenericDist_GenericOperation.res | 55 ++++++------- .../GenericDist_GenericOperation.resi | 14 ++-- 5 files changed, 91 insertions(+), 82 deletions(-) diff --git a/packages/squiggle-lang/__tests__/GenericDist/GenericOperation__Test.res b/packages/squiggle-lang/__tests__/GenericDist/GenericOperation__Test.res index 1e5b5397..a5c1011f 100644 --- a/packages/squiggle-lang/__tests__/GenericDist/GenericOperation__Test.res +++ b/packages/squiggle-lang/__tests__/GenericDist/GenericOperation__Test.res @@ -1,7 +1,7 @@ open Jest open Expect -let params: GenericDist_GenericOperation.params = { +let env: GenericDist_GenericOperation.env = { sampleCount: 100, xyPointLength: 100, } @@ -14,8 +14,8 @@ let uniformDist: GenericDist_Types.genericDist = #Symbolic(#Uniform({low: 9.0, h let {toFloat, toDist, toString, toError} = module(GenericDist_GenericOperation.Output) let {run} = module(GenericDist_GenericOperation) let {fmap} = module(GenericDist_GenericOperation.Output) -let run = run(params) -let outputMap = fmap(params) +let run = run(~env) +let outputMap = fmap(~env) let toExt: option<'a> => 'a = E.O.toExt( "Should be impossible to reach (This error is in test file)", ) @@ -29,7 +29,7 @@ describe("normalize", () => { describe("mean", () => { test("for a normal distribution", () => { - let result = GenericDist_GenericOperation.run(params, #fromDist(#toFloat(#Mean), normalDist)) + let result = GenericDist_GenericOperation.run(~env, #fromDist(#toFloat(#Mean), normalDist)) expect(result)->toEqual(Float(5.0)) }) }) diff --git a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist.res b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist.res index bb48c36e..ba293541 100644 --- a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist.res @@ -29,14 +29,14 @@ let normalize = (t: t) => | #SampleSet(_) => t } -let operationToFloat = ( +let toFloatOperation = ( t, ~toPointSetFn: toPointSetFn, - ~operation: Operation.distToFloatOperation, + ~distToFloatOperation: Operation.distToFloatOperation, ) => { let symbolicSolution = switch t { | #Symbolic(r) => - switch SymbolicDist.T.operate(operation, r) { + switch SymbolicDist.T.operate(distToFloatOperation, r) { | Ok(f) => Some(f) | _ => None } @@ -45,28 +45,26 @@ let operationToFloat = ( switch symbolicSolution { | Some(r) => Ok(r) - | None => toPointSetFn(t)->E.R2.fmap(PointSetDist.operate(operation)) + | None => toPointSetFn(t)->E.R2.fmap(PointSetDist.operate(distToFloatOperation)) } } -//TODO: Refactor this bit. -let defaultSamplingInputs: SamplingInputs.samplingInputs = { - sampleCount: 10000, - outputXYPoints: 10000, - pointSetDistLength: 1000, - kernelWidth: None, -} - //Todo: If it's a pointSet, but the xyPointLenght is different from what it has, it should change. // This is tricky because the case of discrete distributions. -let toPointSet = (t, xyPointLength): result => { +// Also, change the outputXYPoints/pointSetDistLength details +let toPointSet = (~xyPointLength, ~sampleCount, t): result => { switch t { | #PointSet(pointSet) => Ok(pointSet) | #Symbolic(r) => Ok(SymbolicDist.T.toPointSetDist(xyPointLength, r)) | #SampleSet(r) => { let response = SampleSet.toPointSetDist( ~samples=r, - ~samplingInputs=defaultSamplingInputs, + ~samplingInputs={ + sampleCount: sampleCount, + outputXYPoints: xyPointLength, + pointSetDistLength: xyPointLength, + kernelWidth: None, + }, (), ).pointSetDist switch response { @@ -119,13 +117,13 @@ let truncate = Truncate.run */ module AlgebraicCombination = { let tryAnalyticalSimplification = ( - operation: GenericDist_Types.Operation.arithmeticOperation, + arithmeticOperation: GenericDist_Types.Operation.arithmeticOperation, t1: t, t2: t, ): option> => - switch (operation, t1, t2) { - | (operation, #Symbolic(d1), #Symbolic(d2)) => - switch SymbolicDist.T.tryAnalyticalSimplification(d1, d2, operation) { + switch (arithmeticOperation, t1, t2) { + | (arithmeticOperation, #Symbolic(d1), #Symbolic(d2)) => + switch SymbolicDist.T.tryAnalyticalSimplification(d1, d2, arithmeticOperation) { | #AnalyticalSolution(symbolicDist) => Some(Ok(symbolicDist)) | #Error(er) => Some(Error(er)) | #NoSolution => None @@ -135,23 +133,23 @@ module AlgebraicCombination = { let runConvolution = ( toPointSet: toPointSetFn, - operation: GenericDist_Types.Operation.arithmeticOperation, + arithmeticOperation: GenericDist_Types.Operation.arithmeticOperation, t1: t, t2: t, ) => E.R.merge(toPointSet(t1), toPointSet(t2))->E.R2.fmap(((a, b)) => - PointSetDist.combineAlgebraically(operation, a, b) + PointSetDist.combineAlgebraically(arithmeticOperation, a, b) ) let runMonteCarlo = ( toSampleSet: toSampleSetFn, - operation: GenericDist_Types.Operation.arithmeticOperation, + arithmeticOperation: GenericDist_Types.Operation.arithmeticOperation, t1: t, t2: t, ) => { - let operation = Operation.Algebraic.toFn(operation) + let arithmeticOperation = Operation.Algebraic.toFn(arithmeticOperation) E.R.merge(toSampleSet(t1), toSampleSet(t2))->E.R2.fmap(((a, b)) => { - Belt.Array.zip(a, b)->E.A2.fmap(((a, b)) => operation(a, b)) + Belt.Array.zip(a, b)->E.A2.fmap(((a, b)) => arithmeticOperation(a, b)) }) } @@ -175,18 +173,18 @@ module AlgebraicCombination = { t1: t, ~toPointSetFn: toPointSetFn, ~toSampleSetFn: toSampleSetFn, - ~operation, + ~arithmeticOperation, ~t2: t, ): result => { - switch tryAnalyticalSimplification(operation, t1, t2) { + switch tryAnalyticalSimplification(arithmeticOperation, t1, t2) { | Some(Ok(symbolicDist)) => Ok(#Symbolic(symbolicDist)) | Some(Error(e)) => Error(Other(e)) | None => switch chooseConvolutionOrMonteCarlo(t1, t2) { | #CalculateWithMonteCarlo => - runMonteCarlo(toSampleSetFn, operation, t1, t2)->E.R2.fmap(r => #SampleSet(r)) + runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2)->E.R2.fmap(r => #SampleSet(r)) | #CalculateWithConvolution => - runConvolution(toPointSetFn, operation, t1, t2)->E.R2.fmap(r => #PointSet(r)) + runConvolution(toPointSetFn, arithmeticOperation, t1, t2)->E.R2.fmap(r => #PointSet(r)) } } } @@ -195,13 +193,19 @@ module AlgebraicCombination = { let algebraicCombination = AlgebraicCombination.run //TODO: Add faster pointwiseCombine fn -let pointwiseCombination = (t1: t, ~toPointSetFn: toPointSetFn, ~operation, ~t2: t): result< - t, - error, -> => { +let pointwiseCombination = ( + t1: t, + ~toPointSetFn: toPointSetFn, + ~arithmeticOperation, + ~t2: t, +): result => { E.R.merge(toPointSetFn(t1), toPointSetFn(t2)) ->E.R2.fmap(((t1, t2)) => - PointSetDist.combinePointwise(GenericDist_Types.Operation.arithmeticToFn(operation), t1, t2) + PointSetDist.combinePointwise( + GenericDist_Types.Operation.arithmeticToFn(arithmeticOperation), + t1, + t2, + ) ) ->E.R2.fmap(r => #PointSet(r)) } @@ -209,17 +213,17 @@ let pointwiseCombination = (t1: t, ~toPointSetFn: toPointSetFn, ~operation, ~t2: let pointwiseCombinationFloat = ( t: t, ~toPointSetFn: toPointSetFn, - ~operation: GenericDist_Types.Operation.arithmeticOperation, + ~arithmeticOperation: GenericDist_Types.Operation.arithmeticOperation, ~float: float, ): result => { - let m = switch operation { + let m = switch arithmeticOperation { | #Add | #Subtract => Error(GenericDist_Types.DistributionVerticalShiftIsInvalid) - | (#Multiply | #Divide | #Exponentiate | #Log) as operation => + | (#Multiply | #Divide | #Exponentiate | #Log) as arithmeticOperation => toPointSetFn(t)->E.R2.fmap(t => { //TODO: Move to PointSet codebase - let fn = (secondary, main) => Operation.Scale.toFn(operation, main, secondary) - let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(operation) - let integralCacheFn = Operation.Scale.toIntegralCacheFn(operation) + let fn = (secondary, main) => Operation.Scale.toFn(arithmeticOperation, main, secondary) + let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(arithmeticOperation) + let integralCacheFn = Operation.Scale.toIntegralCacheFn(arithmeticOperation) PointSetDist.T.mapY( ~integralSumCacheFn=integralSumCacheFn(float), ~integralCacheFn=integralCacheFn(float), diff --git a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist.resi b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist.resi index 92830670..f61a983f 100644 --- a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist.resi +++ b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist.resi @@ -13,13 +13,17 @@ let toString: t => string let normalize: t => t -let operationToFloat: ( +let toFloatOperation: ( t, ~toPointSetFn: toPointSetFn, - ~operation: Operation.distToFloatOperation, + ~distToFloatOperation: Operation.distToFloatOperation, ) => result -let toPointSet: (t, int) => result +let toPointSet: ( + ~xyPointLength: int, + ~sampleCount: int, + t, +) => result let truncate: ( t, @@ -33,21 +37,21 @@ let algebraicCombination: ( t, ~toPointSetFn: toPointSetFn, ~toSampleSetFn: toSampleSetFn, - ~operation: GenericDist_Types.Operation.arithmeticOperation, + ~arithmeticOperation: GenericDist_Types.Operation.arithmeticOperation, ~t2: t, ) => result let pointwiseCombination: ( t, ~toPointSetFn: toPointSetFn, - ~operation: GenericDist_Types.Operation.arithmeticOperation, + ~arithmeticOperation: GenericDist_Types.Operation.arithmeticOperation, ~t2: t, ) => result let pointwiseCombinationFloat: ( t, ~toPointSetFn: toPointSetFn, - ~operation: GenericDist_Types.Operation.arithmeticOperation, + ~arithmeticOperation: GenericDist_Types.Operation.arithmeticOperation, ~float: float, ) => result diff --git a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res index 9cc93949..51878a7d 100644 --- a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res +++ b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res @@ -1,10 +1,10 @@ -type operation = GenericDist_Types.Operation.genericFunctionCallInfo +type functionCallInfo = GenericDist_Types.Operation.genericFunctionCallInfo type genericDist = GenericDist_Types.genericDist type error = GenericDist_Types.error // TODO: It could be great to use a cache for some calculations (basically, do memoization). Also, better analytics/tracking could go a long way. -type params = { +type env = { sampleCount: int, xyPointLength: int, } @@ -62,22 +62,22 @@ module OutputLocal = { } } -let rec run = (extra, fnName: operation): outputType => { - let {sampleCount, xyPointLength} = extra +let rec run = (~env, functionCallInfo: functionCallInfo): outputType => { + let {sampleCount, xyPointLength} = env - let reCall = (~extra=extra, ~fnName=fnName, ()) => { - run(extra, fnName) + let reCall = (~env=env, ~functionCallInfo=functionCallInfo, ()) => { + run(~env, functionCallInfo) } let toPointSetFn = r => { - switch reCall(~fnName=#fromDist(#toDist(#toPointSet), r), ()) { + switch reCall(~functionCallInfo=#fromDist(#toDist(#toPointSet), r), ()) { | Dist(#PointSet(p)) => Ok(p) | e => Error(OutputLocal.toErrorOrUnreachable(e)) } } let toSampleSetFn = r => { - switch reCall(~fnName=#fromDist(#toDist(#toSampleSet(sampleCount)), r), ()) { + switch reCall(~functionCallInfo=#fromDist(#toDist(#toSampleSet(sampleCount)), r), ()) { | Dist(#SampleSet(p)) => Ok(p) | e => Error(OutputLocal.toErrorOrUnreachable(e)) } @@ -85,20 +85,20 @@ let rec run = (extra, fnName: operation): outputType => { let scaleMultiply = (r, weight) => reCall( - ~fnName=#fromDist(#toDistCombination(#Pointwise, #Multiply, #Float(weight)), r), + ~functionCallInfo=#fromDist(#toDistCombination(#Pointwise, #Multiply, #Float(weight)), r), (), )->OutputLocal.toDistR let pointwiseAdd = (r1, r2) => reCall( - ~fnName=#fromDist(#toDistCombination(#Pointwise, #Add, #Dist(r2)), r1), + ~functionCallInfo=#fromDist(#toDistCombination(#Pointwise, #Add, #Dist(r2)), r1), (), )->OutputLocal.toDistR let fromDistFn = (subFnName: GenericDist_Types.Operation.fromDist, dist: genericDist) => switch subFnName { - | #toFloat(fnName) => - GenericDist.operationToFloat(dist, ~toPointSetFn, ~operation=fnName) + | #toFloat(distToFloatOperation) => + GenericDist.toFloatOperation(dist, ~toPointSetFn, ~distToFloatOperation) ->E.R2.fmap(r => Float(r)) ->OutputLocal.fromResult | #toString => dist->GenericDist.toString->String @@ -113,33 +113,33 @@ let rec run = (extra, fnName: operation): outputType => { ->OutputLocal.fromResult | #toDist(#toPointSet) => dist - ->GenericDist.toPointSet(xyPointLength) + ->GenericDist.toPointSet(~xyPointLength, ~sampleCount) ->E.R2.fmap(r => Dist(#PointSet(r))) ->OutputLocal.fromResult | #toDist(#toSampleSet(n)) => dist->GenericDist.sampleN(n)->E.R2.fmap(r => Dist(#SampleSet(r)))->OutputLocal.fromResult | #toDistCombination(#Algebraic, _, #Float(_)) => GenDistError(NotYetImplemented) - | #toDistCombination(#Algebraic, operation, #Dist(t2)) => + | #toDistCombination(#Algebraic, arithmeticOperation, #Dist(t2)) => dist - ->GenericDist.algebraicCombination(~toPointSetFn, ~toSampleSetFn, ~operation, ~t2) + ->GenericDist.algebraicCombination(~toPointSetFn, ~toSampleSetFn, ~arithmeticOperation, ~t2) ->E.R2.fmap(r => Dist(r)) ->OutputLocal.fromResult - | #toDistCombination(#Pointwise, operation, #Dist(t2)) => + | #toDistCombination(#Pointwise, arithmeticOperation, #Dist(t2)) => dist - ->GenericDist.pointwiseCombination(~toPointSetFn, ~operation, ~t2) + ->GenericDist.pointwiseCombination(~toPointSetFn, ~arithmeticOperation, ~t2) ->E.R2.fmap(r => Dist(r)) ->OutputLocal.fromResult - | #toDistCombination(#Pointwise, operation, #Float(float)) => + | #toDistCombination(#Pointwise, arithmeticOperation, #Float(float)) => dist - ->GenericDist.pointwiseCombinationFloat(~toPointSetFn, ~operation, ~float) + ->GenericDist.pointwiseCombinationFloat(~toPointSetFn, ~arithmeticOperation, ~float) ->E.R2.fmap(r => Dist(r)) ->OutputLocal.fromResult } - switch fnName { + switch functionCallInfo { | #fromDist(subFnName, dist) => fromDistFn(subFnName, dist) | #fromFloat(subFnName, float) => - reCall(~fnName=#fromDist(subFnName, GenericDist.fromFloat(float)), ()) + reCall(~functionCallInfo=#fromDist(subFnName, GenericDist.fromFloat(float)), ()) | #mixture(dists) => dists ->GenericDist.mixture(~scaleMultiplyFn=scaleMultiply, ~pointwiseAddFn=pointwiseAdd) @@ -148,24 +148,25 @@ let rec run = (extra, fnName: operation): outputType => { } } -let runFromDist = (extra, fnName, dist) => run(extra, #fromDist(fnName, dist)) -let runFromFloat = (extra, fnName, float) => run(extra, #fromFloat(fnName, float)) +let runFromDist = (~env, ~functionCallInfo, dist) => run(~env, #fromDist(functionCallInfo, dist)) +let runFromFloat = (~env, ~functionCallInfo, float) => + run(~env, #fromFloat(functionCallInfo, float)) module Output = { include OutputLocal let fmap = ( - extra, + ~env, input: outputType, - fn: GenericDist_Types.Operation.singleParamaterFunction, + functionCallInfo: GenericDist_Types.Operation.singleParamaterFunction, ): outputType => { - let newFnCall: result = switch (fn, input) { + let newFnCall: result = switch (functionCallInfo, input) { | (#fromDist(fromDist), Dist(o)) => Ok(#fromDist(fromDist, o)) | (#fromFloat(fromDist), Float(o)) => Ok(#fromFloat(fromDist, o)) | (_, GenDistError(r)) => Error(r) | (#fromDist(_), _) => Error(Other("Expected dist, got something else")) | (#fromFloat(_), _) => Error(Other("Expected float, got something else")) } - newFnCall->E.R2.fmap(r => run(extra, r))->OutputLocal.fromResult + newFnCall->E.R2.fmap(run(~env))->OutputLocal.fromResult } } diff --git a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.resi b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.resi index 2769a505..c9e26058 100644 --- a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.resi +++ b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.resi @@ -1,4 +1,4 @@ -type params = { +type env = { sampleCount: int, xyPointLength: int, } @@ -9,13 +9,13 @@ type outputType = | String(string) | GenDistError(GenericDist_Types.error) -let run: (params, GenericDist_Types.Operation.genericFunctionCallInfo) => outputType +let run: (~env: env, GenericDist_Types.Operation.genericFunctionCallInfo) => outputType let runFromDist: ( - params, - GenericDist_Types.Operation.fromDist, + ~env: env, + ~functionCallInfo: GenericDist_Types.Operation.fromDist, GenericDist_Types.genericDist, ) => outputType -let runFromFloat: (params, GenericDist_Types.Operation.fromDist, float) => outputType +let runFromFloat: (~env: env, ~functionCallInfo: GenericDist_Types.Operation.fromDist, float) => outputType module Output: { type t = outputType @@ -24,5 +24,5 @@ module Output: { let toFloat: t => option let toString: t => option let toError: t => option - let fmap: (params, t, GenericDist_Types.Operation.singleParamaterFunction) => t -} \ No newline at end of file + let fmap: (~env: env, t, GenericDist_Types.Operation.singleParamaterFunction) => t +}