diff --git a/packages/squiggle-lang/__tests__/GenericDist/GenericOperation__Test.res b/packages/squiggle-lang/__tests__/GenericDist/GenericOperation__Test.res index a5c1011f..4e3f207c 100644 --- a/packages/squiggle-lang/__tests__/GenericDist/GenericOperation__Test.res +++ b/packages/squiggle-lang/__tests__/GenericDist/GenericOperation__Test.res @@ -6,10 +6,10 @@ let env: GenericDist_GenericOperation.env = { xyPointLength: 100, } -let normalDist: GenericDist_Types.genericDist = #Symbolic(#Normal({mean: 5.0, stdev: 2.0})) -let normalDist10: GenericDist_Types.genericDist = #Symbolic(#Normal({mean: 10.0, stdev: 2.0})) -let normalDist20: GenericDist_Types.genericDist = #Symbolic(#Normal({mean: 20.0, stdev: 2.0})) -let uniformDist: GenericDist_Types.genericDist = #Symbolic(#Uniform({low: 9.0, high: 10.0})) +let normalDist: GenericDist_Types.genericDist = Symbolic(#Normal({mean: 5.0, stdev: 2.0})) +let normalDist10: GenericDist_Types.genericDist = Symbolic(#Normal({mean: 10.0, stdev: 2.0})) +let normalDist20: GenericDist_Types.genericDist = Symbolic(#Normal({mean: 20.0, stdev: 2.0})) +let uniformDist: GenericDist_Types.genericDist = Symbolic(#Uniform({low: 9.0, high: 10.0})) let {toFloat, toDist, toString, toError} = module(GenericDist_GenericOperation.Output) let {run} = module(GenericDist_GenericOperation) @@ -57,7 +57,7 @@ describe("toPointSet", () => { test("on sample set distribution with under 4 points", () => { let result = - run(#fromDist(#toDist(#toPointSet), #SampleSet([0.0, 1.0, 2.0, 3.0])))->outputMap( + run(#fromDist(#toDist(#toPointSet), SampleSet([0.0, 1.0, 2.0, 3.0])))->outputMap( #fromDist(#toFloat(#Mean)), ) expect(result)->toEqual(GenDistError(Other("Converting sampleSet to pointSet failed"))) diff --git a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist.res b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist.res index ba293541..bb2f8d71 100644 --- a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist.res @@ -8,25 +8,25 @@ type pointwiseAddFn = (t, t) => result let sampleN = (t: t, n) => switch t { - | #PointSet(r) => Ok(PointSetDist.sampleNRendered(n, r)) - | #Symbolic(r) => Ok(SymbolicDist.T.sampleN(n, r)) - | #SampleSet(_) => Error(GenericDist_Types.NotYetImplemented) + | PointSet(r) => Ok(PointSetDist.sampleNRendered(n, r)) + | Symbolic(r) => Ok(SymbolicDist.T.sampleN(n, r)) + | SampleSet(_) => Error(GenericDist_Types.NotYetImplemented) } -let fromFloat = (f: float) => #Symbolic(SymbolicDist.Float.make(f)) +let fromFloat = (f: float): t => Symbolic(SymbolicDist.Float.make(f)) let toString = (t: t) => switch t { - | #PointSet(_) => "Point Set Distribution" - | #Symbolic(r) => SymbolicDist.T.toString(r) - | #SampleSet(_) => "Sample Set Distribution" + | PointSet(_) => "Point Set Distribution" + | Symbolic(r) => SymbolicDist.T.toString(r) + | SampleSet(_) => "Sample Set Distribution" } -let normalize = (t: t) => +let normalize = (t: t): t => switch t { - | #PointSet(r) => #PointSet(PointSetDist.T.normalize(r)) - | #Symbolic(_) => t - | #SampleSet(_) => t + | PointSet(r) => PointSet(PointSetDist.T.normalize(r)) + | Symbolic(_) => t + | SampleSet(_) => t } let toFloatOperation = ( @@ -34,8 +34,8 @@ let toFloatOperation = ( ~toPointSetFn: toPointSetFn, ~distToFloatOperation: Operation.distToFloatOperation, ) => { - let symbolicSolution = switch t { - | #Symbolic(r) => + let symbolicSolution = switch (t: t) { + | Symbolic(r) => switch SymbolicDist.T.operate(distToFloatOperation, r) { | Ok(f) => Some(f) | _ => None @@ -53,10 +53,10 @@ let toFloatOperation = ( // This is tricky because the case of discrete distributions. // Also, change the outputXYPoints/pointSetDistLength details let toPointSet = (~xyPointLength, ~sampleCount, t): result => { - switch t { - | #PointSet(pointSet) => Ok(pointSet) - | #Symbolic(r) => Ok(SymbolicDist.T.toPointSetDist(xyPointLength, r)) - | #SampleSet(r) => { + switch (t: t) { + | PointSet(pointSet) => Ok(pointSet) + | Symbolic(r) => Ok(SymbolicDist.T.toPointSetDist(xyPointLength, r)) + | SampleSet(r) => { let response = SampleSet.toPointSetDist( ~samples=r, ~samplingInputs={ @@ -76,11 +76,11 @@ let toPointSet = (~xyPointLength, ~sampleCount, t): result => + let trySymbolicSimplification = (leftCutoff, rightCutoff, t: t): option => switch (leftCutoff, rightCutoff, t) { | (None, None, _) => None - | (lc, rc, #Symbolic(#Uniform(u))) if lc < rc => - Some(#Symbolic(#Uniform(SymbolicDist.Uniform.truncate(lc, rc, u)))) + | (lc, rc, Symbolic(#Uniform(u))) if lc < rc => + Some(Symbolic(#Uniform(SymbolicDist.Uniform.truncate(lc, rc, u)))) | _ => None } @@ -98,9 +98,9 @@ module Truncate = { switch trySymbolicSimplification(leftCutoff, rightCutoff, t) { | Some(r) => Ok(r) | None => - toPointSetFn(t)->E.R2.fmap(t => - #PointSet(PointSetDist.T.truncate(leftCutoff, rightCutoff, t)) - ) + toPointSetFn(t)->E.R2.fmap(t => { + GenericDist_Types.PointSet(PointSetDist.T.truncate(leftCutoff, rightCutoff, t)) + }) } } } @@ -122,7 +122,7 @@ module AlgebraicCombination = { t2: t, ): option> => switch (arithmeticOperation, t1, t2) { - | (arithmeticOperation, #Symbolic(d1), #Symbolic(d2)) => + | (arithmeticOperation, Symbolic(d1), Symbolic(d2)) => switch SymbolicDist.T.tryAnalyticalSimplification(d1, d2, arithmeticOperation) { | #AnalyticalSolution(symbolicDist) => Some(Ok(symbolicDist)) | #Error(er) => Some(Error(er)) @@ -156,11 +156,11 @@ module AlgebraicCombination = { //I'm (Ozzie) really just guessing here, very little idea what's best let expectedConvolutionCost: t => int = x => switch x { - | #Symbolic(#Float(_)) => 1 - | #Symbolic(_) => 1000 - | #PointSet(Discrete(m)) => m.xyShape->XYShape.T.length - | #PointSet(Mixed(_)) => 1000 - | #PointSet(Continuous(_)) => 1000 + | Symbolic(#Float(_)) => 1 + | Symbolic(_) => 1000 + | PointSet(Discrete(m)) => m.xyShape->XYShape.T.length + | PointSet(Mixed(_)) => 1000 + | PointSet(Continuous(_)) => 1000 | _ => 1000 } @@ -177,14 +177,24 @@ module AlgebraicCombination = { ~t2: t, ): result => { switch tryAnalyticalSimplification(arithmeticOperation, t1, t2) { - | Some(Ok(symbolicDist)) => Ok(#Symbolic(symbolicDist)) + | Some(Ok(symbolicDist)) => Ok(Symbolic(symbolicDist)) | Some(Error(e)) => Error(Other(e)) | None => switch chooseConvolutionOrMonteCarlo(t1, t2) { | #CalculateWithMonteCarlo => - runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2)->E.R2.fmap(r => #SampleSet(r)) + runMonteCarlo( + toSampleSetFn, + arithmeticOperation, + t1, + t2, + )->E.R2.fmap(r => GenericDist_Types.SampleSet(r)) | #CalculateWithConvolution => - runConvolution(toPointSetFn, arithmeticOperation, t1, t2)->E.R2.fmap(r => #PointSet(r)) + runConvolution( + toPointSetFn, + arithmeticOperation, + t1, + t2, + )->E.R2.fmap(r => GenericDist_Types.PointSet(r)) } } } @@ -207,7 +217,7 @@ let pointwiseCombination = ( t2, ) ) - ->E.R2.fmap(r => #PointSet(r)) + ->E.R2.fmap(r => GenericDist_Types.PointSet(r)) } let pointwiseCombinationFloat = ( @@ -232,7 +242,7 @@ let pointwiseCombinationFloat = ( ) }) } - m->E.R2.fmap(r => #PointSet(r)) + m->E.R2.fmap(r => GenericDist_Types.PointSet(r)) } //Note: The result should always cumulatively sum to 1. This would be good to test. diff --git a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res index 51878a7d..55f6c621 100644 --- a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res +++ b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res @@ -71,14 +71,14 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => { let toPointSetFn = r => { switch reCall(~functionCallInfo=#fromDist(#toDist(#toPointSet), r), ()) { - | Dist(#PointSet(p)) => Ok(p) + | Dist(PointSet(p)) => Ok(p) | e => Error(OutputLocal.toErrorOrUnreachable(e)) } } let toSampleSetFn = r => { switch reCall(~functionCallInfo=#fromDist(#toDist(#toSampleSet(sampleCount)), r), ()) { - | Dist(#SampleSet(p)) => Ok(p) + | Dist(SampleSet(p)) => Ok(p) | e => Error(OutputLocal.toErrorOrUnreachable(e)) } } @@ -114,10 +114,10 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => { | #toDist(#toPointSet) => dist ->GenericDist.toPointSet(~xyPointLength, ~sampleCount) - ->E.R2.fmap(r => Dist(#PointSet(r))) + ->E.R2.fmap(r => Dist(PointSet(r))) ->OutputLocal.fromResult | #toDist(#toSampleSet(n)) => - dist->GenericDist.sampleN(n)->E.R2.fmap(r => Dist(#SampleSet(r)))->OutputLocal.fromResult + dist->GenericDist.sampleN(n)->E.R2.fmap(r => Dist(SampleSet(r)))->OutputLocal.fromResult | #toDistCombination(#Algebraic, _, #Float(_)) => GenDistError(NotYetImplemented) | #toDistCombination(#Algebraic, arithmeticOperation, #Dist(t2)) => dist diff --git a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.resi b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.resi index c9e26058..abbd713e 100644 --- a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.resi +++ b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.resi @@ -15,7 +15,11 @@ let runFromDist: ( ~functionCallInfo: GenericDist_Types.Operation.fromDist, GenericDist_Types.genericDist, ) => outputType -let runFromFloat: (~env: env, ~functionCallInfo: GenericDist_Types.Operation.fromDist, float) => outputType +let runFromFloat: ( + ~env: env, + ~functionCallInfo: GenericDist_Types.Operation.fromDist, + float, +) => outputType module Output: { type t = outputType diff --git a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_Types.res b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_Types.res index e8b9a0c0..bc79cfc1 100644 --- a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_Types.res +++ b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_Types.res @@ -1,8 +1,7 @@ -type genericDist = [ - | #PointSet(PointSetTypes.pointSetDist) - | #SampleSet(array) - | #Symbolic(SymbolicDistTypes.symbolicDist) -] +type genericDist = + | PointSet(PointSetTypes.pointSetDist) + | SampleSet(array) + | Symbolic(SymbolicDistTypes.symbolicDist) type error = | NotYetImplemented