From ffc622fb6db06fca8e335c7f4452f39678e6fcce Mon Sep 17 00:00:00 2001 From: Ozzie Gooen Date: Tue, 29 Mar 2022 21:28:14 -0400 Subject: [PATCH] Responded to two simple CR comments --- .../GenericDist/GenericOperation__Test.res | 2 +- .../src/rescript/GenericDist/GenericDist.res | 6 +++-- .../GenericDist_GenericOperation.res | 26 +++++++++---------- .../GenericDist_GenericOperation.resi | 4 +-- .../GenericDist/GenericDist_Types.res | 5 ++-- 5 files changed, 23 insertions(+), 20 deletions(-) diff --git a/packages/squiggle-lang/__tests__/GenericDist/GenericOperation__Test.res b/packages/squiggle-lang/__tests__/GenericDist/GenericOperation__Test.res index 6263155b..9dfadbf8 100644 --- a/packages/squiggle-lang/__tests__/GenericDist/GenericOperation__Test.res +++ b/packages/squiggle-lang/__tests__/GenericDist/GenericOperation__Test.res @@ -59,7 +59,7 @@ describe("toPointSet", () => { run(#fromDist(#toDist(#toPointSet), #SampleSet([0.0, 1.0, 2.0, 3.0])))->fmap( #fromDist(#toFloat(#Mean)), ) - expect(result)->toEqual(#Error(Other("Converting sampleSet to pointSet failed"))) + expect(result)->toEqual(#GenDistError(Other("Converting sampleSet to pointSet failed"))) }) test("on sample set", () => { diff --git a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist.res b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist.res index 48588afe..91e9daf6 100644 --- a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist.res @@ -207,7 +207,7 @@ let pointwiseCombinationFloat = ( operation: GenericDist_Types.Operation.arithmeticOperation, f: float, ): result => { - switch operation { + let m = switch operation { | #Add | #Subtract => Error(GenericDist_Types.DistributionVerticalShiftIsInvalid) | (#Multiply | #Divide | #Exponentiate | #Log) as operation => toPointSet(t)->E.R2.fmap(t => { @@ -222,10 +222,12 @@ let pointwiseCombinationFloat = ( t, ) }) - }->E.R2.fmap(r => #PointSet(r)) + } + m->E.R2.fmap(r => #PointSet(r)) } //Note: The result should always cumulatively sum to 1. This would be good to test. +//Note: If the inputs are not normalized, this will return poor results. The weights probably refer to the post-normalized forms. It would be good to apply a catch to this. let mixture = ( values: array<(t, float)>, scaleMultiply: scaleMultiplyFn, diff --git a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res index 43aec78f..5a893150 100644 --- a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res +++ b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res @@ -1,4 +1,4 @@ -type operation = GenericDist_Types.Operation.genericFunctionCall +type operation = GenericDist_Types.Operation.genericFunctionCallInfo type genericDist = GenericDist_Types.genericDist type error = GenericDist_Types.error @@ -11,9 +11,9 @@ type params = { type outputType = [ | #Dist(genericDist) - | #Error(error) | #Float(float) | #String(string) + | #GenDistError(error) ] module Output = { @@ -37,7 +37,7 @@ module Output = { let toError = (o: outputType) => switch o { - | #Error(d) => Some(d) + | #GenDistError(d) => Some(d) | _ => None } } @@ -45,14 +45,14 @@ module Output = { let fromResult = (r: result): outputType => switch r { | Ok(o) => o - | Error(e) => #Error(e) + | Error(e) => #GenDistError(e) } let outputToDistResult = (b: outputType): result => switch b { | #Dist(r) => Ok(r) - | #Error(r) => Error(r) - | _ => Error(ImpossiblePath) + | #GenDistError(r) => Error(r) + | _ => Error(Unreachable) } let rec run = (extra, fnName: operation): outputType => { @@ -65,16 +65,16 @@ let rec run = (extra, fnName: operation): outputType => { let toPointSet = r => { switch reCall(~fnName=#fromDist(#toDist(#toPointSet), r), ()) { | #Dist(#PointSet(p)) => Ok(p) - | #Error(r) => Error(r) - | _ => Error(ImpossiblePath) + | #GenDistError(r) => Error(r) + | _ => Error(Unreachable) } } let toSampleSet = r => { switch reCall(~fnName=#fromDist(#toDist(#toSampleSet(sampleCount)), r), ()) { | #Dist(#SampleSet(p)) => Ok(p) - | #Error(r) => Error(r) - | _ => Error(ImpossiblePath) + | #GenDistError(r) => Error(r) + | _ => Error(Unreachable) } } @@ -106,7 +106,7 @@ let rec run = (extra, fnName: operation): outputType => { dist->GenericDist.toPointSet(xyPointLength)->E.R2.fmap(r => #Dist(#PointSet(r)))->fromResult | #toDist(#toSampleSet(n)) => dist->GenericDist.sampleN(n)->E.R2.fmap(r => #Dist(#SampleSet(r)))->fromResult - | #toDistCombination(#Algebraic, _, #Float(_)) => #Error(NotYetImplemented) + | #toDistCombination(#Algebraic, _, #Float(_)) => #GenDistError(NotYetImplemented) | #toDistCombination(#Algebraic, operation, #Dist(dist2)) => dist ->GenericDist.algebraicCombination(toPointSet, toSampleSet, operation, dist2) @@ -143,9 +143,9 @@ let fmap = ( let newFnCall: result = switch (fn, input) { | (#fromDist(fromDist), #Dist(o)) => Ok(#fromDist(fromDist, o)) | (#fromFloat(fromDist), #Float(o)) => Ok(#fromFloat(fromDist, o)) - | (_, #Error(r)) => Error(r) + | (_, #GenDistError(r)) => Error(r) | (#fromDist(_), _) => Error(Other("Expected dist, got something else")) | (#fromFloat(_), _) => Error(Other("Expected float, got something else")) } newFnCall->E.R2.fmap(r => run(extra, r))->fromResult -} +} \ No newline at end of file diff --git a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.resi b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.resi index 53e1463a..f8acdc42 100644 --- a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.resi +++ b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.resi @@ -5,12 +5,12 @@ type params = { type outputType = [ | #Dist(GenericDist_Types.genericDist) - | #Error(GenericDist_Types.error) + | #GenDistError(GenericDist_Types.error) | #Float(float) | #String(string) ] -let run: (params, GenericDist_Types.Operation.genericFunctionCall) => outputType +let run: (params, GenericDist_Types.Operation.genericFunctionCallInfo) => outputType let runFromDist: ( params, GenericDist_Types.Operation.fromDist, diff --git a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_Types.res b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_Types.res index c2fb64d7..3a55ee63 100644 --- a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_Types.res +++ b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_Types.res @@ -6,7 +6,7 @@ type genericDist = [ type error = | NotYetImplemented - | ImpossiblePath + | Unreachable | DistributionVerticalShiftIsInvalid | Other(string) @@ -67,12 +67,13 @@ module Operation = { | #fromFloat(fromDist) ] - type genericFunctionCall = [ + type genericFunctionCallInfo = [ | #fromDist(fromDist, genericDist) | #fromFloat(fromDist, float) | #mixture(array<(genericDist, float)>) ] + //TODO: Should support all genericFunctionCallInfo types let toString = (distFunction: fromDist): string => switch distFunction { | #toFloat(#Cdf(r)) => `cdf(${E.Float.toFixed(r)})`