From 15534b10ce4374e3d0d10c860773c1adf8be476a Mon Sep 17 00:00:00 2001 From: Ozzie Gooen Date: Thu, 31 Mar 2022 14:51:42 -0400 Subject: [PATCH] Converted most of Operation to not be polymorphic --- .../GenericDist/GenericOperation__Test.res | 24 ++--- .../src/rescript/GenericDist/GenericDist.resi | 2 +- .../GenericDist_GenericOperation.res | 53 ++++++----- .../GenericDist/GenericDist_Types.res | 88 +++++++++---------- 4 files changed, 82 insertions(+), 85 deletions(-) diff --git a/packages/squiggle-lang/__tests__/GenericDist/GenericOperation__Test.res b/packages/squiggle-lang/__tests__/GenericDist/GenericOperation__Test.res index 4e3f207c..90d5a67c 100644 --- a/packages/squiggle-lang/__tests__/GenericDist/GenericOperation__Test.res +++ b/packages/squiggle-lang/__tests__/GenericDist/GenericOperation__Test.res @@ -22,14 +22,14 @@ let toExt: option<'a> => 'a = E.O.toExt( describe("normalize", () => { test("has no impact on normal dist", () => { - let result = run(#fromDist(#toDist(#normalize), normalDist)) + let result = run(FromDist(ToDist(Normalize), normalDist)) expect(result)->toEqual(Dist(normalDist)) }) }) describe("mean", () => { test("for a normal distribution", () => { - let result = GenericDist_GenericOperation.run(~env, #fromDist(#toFloat(#Mean), normalDist)) + let result = GenericDist_GenericOperation.run(~env, FromDist(ToFloat(#Mean), normalDist)) expect(result)->toEqual(Float(5.0)) }) }) @@ -37,8 +37,8 @@ describe("mean", () => { describe("mixture", () => { test("on two normal distributions", () => { let result = - run(#mixture([(normalDist10, 0.5), (normalDist20, 0.5)])) - ->outputMap(#fromDist(#toFloat(#Mean))) + run(Mixture([(normalDist10, 0.5), (normalDist20, 0.5)])) + ->outputMap(FromDist(ToFloat(#Mean))) ->toFloat ->toExt expect(result)->toBeCloseTo(15.28) @@ -48,8 +48,8 @@ describe("mixture", () => { describe("toPointSet", () => { test("on symbolic normal distribution", () => { let result = - run(#fromDist(#toDist(#toPointSet), normalDist)) - ->outputMap(#fromDist(#toFloat(#Mean))) + run(FromDist(ToDist(ToPointSet), normalDist)) + ->outputMap(FromDist(ToFloat(#Mean))) ->toFloat ->toExt expect(result)->toBeCloseTo(5.09) @@ -57,18 +57,18 @@ describe("toPointSet", () => { test("on sample set distribution with under 4 points", () => { let result = - run(#fromDist(#toDist(#toPointSet), SampleSet([0.0, 1.0, 2.0, 3.0])))->outputMap( - #fromDist(#toFloat(#Mean)), + run(FromDist(ToDist(ToPointSet), SampleSet([0.0, 1.0, 2.0, 3.0])))->outputMap( + FromDist(ToFloat(#Mean)), ) expect(result)->toEqual(GenDistError(Other("Converting sampleSet to pointSet failed"))) }) Skip.test("on sample set", () => { let result = - run(#fromDist(#toDist(#toPointSet), normalDist)) - ->outputMap(#fromDist(#toDist(#toSampleSet(1000)))) - ->outputMap(#fromDist(#toDist(#toPointSet))) - ->outputMap(#fromDist(#toFloat(#Mean))) + run(FromDist(ToDist(ToPointSet), normalDist)) + ->outputMap(FromDist(ToDist(ToSampleSet(1000)))) + ->outputMap(FromDist(ToDist(ToPointSet))) + ->outputMap(FromDist(ToFloat(#Mean))) ->toFloat ->toExt expect(result)->toBeCloseTo(5.09) diff --git a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist.resi b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist.resi index f61a983f..f567f6be 100644 --- a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist.resi +++ b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist.resi @@ -59,4 +59,4 @@ let mixture: ( array<(t, float)>, ~scaleMultiplyFn: scaleMultiplyFn, ~pointwiseAddFn: pointwiseAddFn, -) => result +) => result \ No newline at end of file diff --git a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res index 55f6c621..67db34e1 100644 --- a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res +++ b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res @@ -70,14 +70,14 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => { } let toPointSetFn = r => { - switch reCall(~functionCallInfo=#fromDist(#toDist(#toPointSet), r), ()) { + switch reCall(~functionCallInfo=FromDist(ToDist(ToPointSet), r), ()) { | Dist(PointSet(p)) => Ok(p) | e => Error(OutputLocal.toErrorOrUnreachable(e)) } } let toSampleSetFn = r => { - switch reCall(~functionCallInfo=#fromDist(#toDist(#toSampleSet(sampleCount)), r), ()) { + switch reCall(~functionCallInfo=FromDist(ToDist(ToSampleSet(sampleCount)), r), ()) { | Dist(SampleSet(p)) => Ok(p) | e => Error(OutputLocal.toErrorOrUnreachable(e)) } @@ -85,51 +85,51 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => { let scaleMultiply = (r, weight) => reCall( - ~functionCallInfo=#fromDist(#toDistCombination(#Pointwise, #Multiply, #Float(weight)), r), + ~functionCallInfo=FromDist(ToDistCombination(Pointwise, #Multiply, #Float(weight)), r), (), )->OutputLocal.toDistR let pointwiseAdd = (r1, r2) => reCall( - ~functionCallInfo=#fromDist(#toDistCombination(#Pointwise, #Add, #Dist(r2)), r1), + ~functionCallInfo=FromDist(ToDistCombination(Pointwise, #Add, #Dist(r2)), r1), (), )->OutputLocal.toDistR let fromDistFn = (subFnName: GenericDist_Types.Operation.fromDist, dist: genericDist) => switch subFnName { - | #toFloat(distToFloatOperation) => + | ToFloat(distToFloatOperation) => GenericDist.toFloatOperation(dist, ~toPointSetFn, ~distToFloatOperation) ->E.R2.fmap(r => Float(r)) ->OutputLocal.fromResult - | #toString => dist->GenericDist.toString->String - | #toDist(#inspect) => { + | ToString => dist->GenericDist.toString->String + | ToDist(Inspect) => { Js.log2("Console log requested: ", dist) Dist(dist) } - | #toDist(#normalize) => dist->GenericDist.normalize->Dist - | #toDist(#truncate(leftCutoff, rightCutoff)) => + | ToDist(Normalize) => dist->GenericDist.normalize->Dist + | ToDist(Truncate(leftCutoff, rightCutoff)) => GenericDist.truncate(~toPointSetFn, ~leftCutoff, ~rightCutoff, dist, ()) ->E.R2.fmap(r => Dist(r)) ->OutputLocal.fromResult - | #toDist(#toPointSet) => + | ToDist(ToSampleSet(n)) => + dist->GenericDist.sampleN(n)->E.R2.fmap(r => Dist(SampleSet(r)))->OutputLocal.fromResult + | ToDist(ToPointSet) => dist ->GenericDist.toPointSet(~xyPointLength, ~sampleCount) ->E.R2.fmap(r => Dist(PointSet(r))) ->OutputLocal.fromResult - | #toDist(#toSampleSet(n)) => - dist->GenericDist.sampleN(n)->E.R2.fmap(r => Dist(SampleSet(r)))->OutputLocal.fromResult - | #toDistCombination(#Algebraic, _, #Float(_)) => GenDistError(NotYetImplemented) - | #toDistCombination(#Algebraic, arithmeticOperation, #Dist(t2)) => + | ToDistCombination(Algebraic, _, #Float(_)) => GenDistError(NotYetImplemented) + | ToDistCombination(Algebraic, arithmeticOperation, #Dist(t2)) => dist ->GenericDist.algebraicCombination(~toPointSetFn, ~toSampleSetFn, ~arithmeticOperation, ~t2) ->E.R2.fmap(r => Dist(r)) ->OutputLocal.fromResult - | #toDistCombination(#Pointwise, arithmeticOperation, #Dist(t2)) => + | ToDistCombination(Pointwise, arithmeticOperation, #Dist(t2)) => dist ->GenericDist.pointwiseCombination(~toPointSetFn, ~arithmeticOperation, ~t2) ->E.R2.fmap(r => Dist(r)) ->OutputLocal.fromResult - | #toDistCombination(#Pointwise, arithmeticOperation, #Float(float)) => + | ToDistCombination(Pointwise, arithmeticOperation, #Float(float)) => dist ->GenericDist.pointwiseCombinationFloat(~toPointSetFn, ~arithmeticOperation, ~float) ->E.R2.fmap(r => Dist(r)) @@ -137,10 +137,10 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => { } switch functionCallInfo { - | #fromDist(subFnName, dist) => fromDistFn(subFnName, dist) - | #fromFloat(subFnName, float) => - reCall(~functionCallInfo=#fromDist(subFnName, GenericDist.fromFloat(float)), ()) - | #mixture(dists) => + | FromDist(subFnName, dist) => fromDistFn(subFnName, dist) + | FromFloat(subFnName, float) => + reCall(~functionCallInfo=FromDist(subFnName, GenericDist.fromFloat(float)), ()) + | Mixture(dists) => dists ->GenericDist.mixture(~scaleMultiplyFn=scaleMultiply, ~pointwiseAddFn=pointwiseAdd) ->E.R2.fmap(r => Dist(r)) @@ -148,9 +148,8 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => { } } -let runFromDist = (~env, ~functionCallInfo, dist) => run(~env, #fromDist(functionCallInfo, dist)) -let runFromFloat = (~env, ~functionCallInfo, float) => - run(~env, #fromFloat(functionCallInfo, float)) +let runFromDist = (~env, ~functionCallInfo, dist) => run(~env, FromDist(functionCallInfo, dist)) +let runFromFloat = (~env, ~functionCallInfo, float) => run(~env, FromFloat(functionCallInfo, float)) module Output = { include OutputLocal @@ -161,11 +160,11 @@ module Output = { functionCallInfo: GenericDist_Types.Operation.singleParamaterFunction, ): outputType => { let newFnCall: result = switch (functionCallInfo, input) { - | (#fromDist(fromDist), Dist(o)) => Ok(#fromDist(fromDist, o)) - | (#fromFloat(fromDist), Float(o)) => Ok(#fromFloat(fromDist, o)) + | (FromDist(fromDist), Dist(o)) => Ok(FromDist(fromDist, o)) + | (FromFloat(fromDist), Float(o)) => Ok(FromFloat(fromDist, o)) | (_, GenDistError(r)) => Error(r) - | (#fromDist(_), _) => Error(Other("Expected dist, got something else")) - | (#fromFloat(_), _) => Error(Other("Expected float, got something else")) + | (FromDist(_), _) => Error(Other("Expected dist, got something else")) + | (FromFloat(_), _) => Error(Other("Expected float, got something else")) } newFnCall->E.R2.fmap(run(~env))->OutputLocal.fromResult } diff --git a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_Types.res b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_Types.res index bc79cfc1..98c0da25 100644 --- a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_Types.res +++ b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_Types.res @@ -10,10 +10,9 @@ type error = | Other(string) module Operation = { - type direction = [ - | #Algebraic - | #Pointwise - ] + type direction = + | Algebraic + | Pointwise type arithmeticOperation = [ | #Add @@ -42,51 +41,50 @@ module Operation = { | #Sample ] - type toDist = [ - | #normalize - | #toPointSet - | #toSampleSet(int) - | #truncate(option, option) - | #inspect - ] + type toDist = + | Normalize + | ToPointSet + | ToSampleSet(int) + | Truncate(option, option) + | Inspect - type toFloatArray = [ - | #Sample(int) - ] + type toFloatArray = Sample(int) - type fromDist = [ - | #toFloat(toFloat) - | #toDist(toDist) - | #toDistCombination(direction, arithmeticOperation, [#Dist(genericDist) | #Float(float)]) - | #toString - ] + type fromDist = + | ToFloat(toFloat) + | ToDist(toDist) + | ToDistCombination(direction, arithmeticOperation, [#Dist(genericDist) | #Float(float)]) + | ToString - type singleParamaterFunction = [ - | #fromDist(fromDist) - | #fromFloat(fromDist) - ] + type singleParamaterFunction = + | FromDist(fromDist) + | FromFloat(fromDist) - type genericFunctionCallInfo = [ - | #fromDist(fromDist, genericDist) - | #fromFloat(fromDist, float) - | #mixture(array<(genericDist, float)>) - ] + type genericFunctionCallInfo = + | FromDist(fromDist, genericDist) + | FromFloat(fromDist, float) + | Mixture(array<(genericDist, float)>) - //TODO: Should support all genericFunctionCallInfo types - let toString = (distFunction: fromDist): string => + let distCallToString = (distFunction: fromDist): string => switch distFunction { - | #toFloat(#Cdf(r)) => `cdf(${E.Float.toFixed(r)})` - | #toFloat(#Inv(r)) => `inv(${E.Float.toFixed(r)})` - | #toFloat(#Mean) => `mean` - | #toFloat(#Pdf(r)) => `pdf(${E.Float.toFixed(r)})` - | #toFloat(#Sample) => `sample` - | #toDist(#normalize) => `normalize` - | #toDist(#toPointSet) => `toPointSet` - | #toDist(#toSampleSet(r)) => `toSampleSet(${E.I.toString(r)})` - | #toDist(#truncate(_, _)) => `truncate` - | #toDist(#inspect) => `inspect` - | #toString => `toString` - | #toDistCombination(#Algebraic, _, _) => `algebraic` - | #toDistCombination(#Pointwise, _, _) => `pointwise` + | ToFloat(#Cdf(r)) => `cdf(${E.Float.toFixed(r)})` + | ToFloat(#Inv(r)) => `inv(${E.Float.toFixed(r)})` + | ToFloat(#Mean) => `mean` + | ToFloat(#Pdf(r)) => `pdf(${E.Float.toFixed(r)})` + | ToFloat(#Sample) => `sample` + | ToDist(Normalize) => `normalize` + | ToDist(ToPointSet) => `toPointSet` + | ToDist(ToSampleSet(r)) => `toSampleSet(${E.I.toString(r)})` + | ToDist(Truncate(_, _)) => `truncate` + | ToDist(Inspect) => `inspect` + | ToString => `toString` + | ToDistCombination(Algebraic, _, _) => `algebraic` + | ToDistCombination(Pointwise, _, _) => `pointwise` } -} \ No newline at end of file + + let toString = (d: genericFunctionCallInfo): string => + switch d { + | FromDist(f, _) | FromFloat(f, _) => distCallToString(f) + | Mixture(_) => `mixture` + } +}