From 61aaca3e2ff86a1e80e22b8c2654f72eea388490 Mon Sep 17 00:00:00 2001 From: Ozzie Gooen Date: Sat, 9 Apr 2022 18:10:06 -0400 Subject: [PATCH] Gave SampleSetDist a private type --- .../DistributionOperation_test.res | 3 +- .../__tests__/Distributions/Samples_test.res | 8 ++-- .../DistributionOperation.res | 2 +- .../Distributions/GenericDist/GenericDist.res | 31 +++++++----- .../GenericDist/GenericDist.resi | 3 +- .../GenericDist/GenericDist_Types.res | 2 +- .../SampleSetDist/SampleSetDist.res | 48 +++++++++++++++++++ ...leSet.res => SampleSetDist_ToPointSet.res} | 27 ----------- .../src/rescript/OldInterpreter/ASTTypes.res | 14 +++--- 9 files changed, 83 insertions(+), 55 deletions(-) create mode 100644 packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist.res rename packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/{SampleSet.res => SampleSetDist_ToPointSet.res} (85%) diff --git a/packages/squiggle-lang/__tests__/Distributions/DistributionOperation_test.res b/packages/squiggle-lang/__tests__/Distributions/DistributionOperation_test.res index cd22d512..f206b31d 100644 --- a/packages/squiggle-lang/__tests__/Distributions/DistributionOperation_test.res +++ b/packages/squiggle-lang/__tests__/Distributions/DistributionOperation_test.res @@ -91,8 +91,9 @@ describe("toPointSet", () => { }) test("on sample set distribution with under 4 points", () => { + let sampleSet = SampleSetDist.make([0.0, 1.0, 2.0, 3.0]) -> E.R.toExn; let result = - run(FromDist(ToDist(ToPointSet), SampleSet([0.0, 1.0, 2.0, 3.0])))->outputMap( + run(FromDist(ToDist(ToPointSet), SampleSet(sampleSet)))->outputMap( FromDist(ToFloat(#Mean)), ) expect(result)->toEqual(GenDistError(Other("Converting sampleSet to pointSet failed"))) diff --git a/packages/squiggle-lang/__tests__/Distributions/Samples_test.res b/packages/squiggle-lang/__tests__/Distributions/Samples_test.res index db80f9f7..5a48dd80 100644 --- a/packages/squiggle-lang/__tests__/Distributions/Samples_test.res +++ b/packages/squiggle-lang/__tests__/Distributions/Samples_test.res @@ -4,12 +4,12 @@ open TestHelpers describe("Continuous and discrete splits", () => { makeTest( "splits (1)", - SampleSet.Internals.T.splitContinuousAndDiscrete([1.432, 1.33455, 2.0]), + SampleSetDist_ToPointSet.Internals.T.splitContinuousAndDiscrete([1.432, 1.33455, 2.0]), ([1.432, 1.33455, 2.0], E.FloatFloatMap.empty()), ) makeTest( "splits (2)", - SampleSet.Internals.T.splitContinuousAndDiscrete([ + SampleSetDist_ToPointSet.Internals.T.splitContinuousAndDiscrete([ 1.432, 1.33455, 2.0, @@ -26,13 +26,13 @@ describe("Continuous and discrete splits", () => { E.A.concatMany([sorted, sorted, sorted, sorted]) |> Belt.SortArray.stableSortBy(_, compare) } - let (_, discrete1) = SampleSet.Internals.T.splitContinuousAndDiscrete( + let (_, discrete1) = SampleSetDist_ToPointSet.Internals.T.splitContinuousAndDiscrete( makeDuplicatedArray(10), ) let toArr1 = discrete1 |> E.FloatFloatMap.toArray makeTest("splitMedium at count=10", toArr1 |> Belt.Array.length, 10) - let (_c, discrete2) = SampleSet.Internals.T.splitContinuousAndDiscrete( + let (_c, discrete2) = SampleSetDist_ToPointSet.Internals.T.splitContinuousAndDiscrete( makeDuplicatedArray(500), ) let toArr2 = discrete2 |> E.FloatFloatMap.toArray diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation/DistributionOperation.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation/DistributionOperation.res index 69a9a4ed..bb5f4f1a 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation/DistributionOperation.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation/DistributionOperation.res @@ -128,7 +128,7 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => { ->E.R2.fmap(r => Dist(r)) ->OutputLocal.fromResult | ToDist(ToSampleSet(n)) => - dist->GenericDist.sampleN(n)->E.R2.fmap(r => Dist(SampleSet(r)))->OutputLocal.fromResult + dist->GenericDist.toSampleSetDist(n)->E.R2.fmap(r => Dist(SampleSet(r)))->OutputLocal.fromResult | ToDist(ToPointSet) => dist ->GenericDist.toPointSet(~xyPointLength, ~sampleCount, ()) diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res index ba19e60b..90381fba 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res @@ -2,16 +2,18 @@ type t = GenericDist_Types.genericDist type error = GenericDist_Types.error type toPointSetFn = t => result -type toSampleSetFn = t => result, error> +type toSampleSetFn = t => result type scaleMultiplyFn = (t, float) => result type pointwiseAddFn = (t, t) => result - +let mapStringErrors = n => n->E.R2.errMap(r => Error(GenericDist_Types.Other(r))) let sampleN = (t: t, n) => switch t { | PointSet(r) => Ok(PointSetDist.sampleNRendered(n, r)) | Symbolic(r) => Ok(SymbolicDist.T.sampleN(n, r)) - | SampleSet(r) => Ok(SampleSet.sampleN(r, n)) + | SampleSet(r) => Ok(SampleSetDist.sampleN(r, n)) } +let toSampleSetDist = (t: t, n) => sampleN(t, n)->E.R.bind(SampleSetDist.make)->mapStringErrors +let mapStringErrors = n => n->E.R2.errMap(r => Error(GenericDist_Types.Other(r))) let fromFloat = (f: float): t => Symbolic(SymbolicDist.Float.make(f)) @@ -63,7 +65,7 @@ let toPointSet = ( | PointSet(pointSet) => Ok(pointSet) | Symbolic(r) => Ok(SymbolicDist.T.toPointSetDist(~xSelection, xyPointLength, r)) | SampleSet(r) => { - let response = SampleSet.toPointSetDist( + let response = SampleSetDist.toPointSetDist( ~samples=r, ~samplingInputs={ sampleCount: sampleCount, @@ -167,8 +169,9 @@ module AlgebraicCombination = { t2: t, ) => { let arithmeticOperation = Operation.Algebraic.toFn(arithmeticOperation) - E.R.merge(toSampleSet(t1), toSampleSet(t2))->E.R2.fmap(((a, b)) => { - Belt.Array.zip(a, b)->E.A2.fmap(((a, b)) => arithmeticOperation(a, b)) + E.R.merge(toSampleSet(t1), toSampleSet(t2))->E.R.bind(((a, b)) => { + SampleSetDist.runMonteCarlo(arithmeticOperation, a, b) + ->mapStringErrors }) } @@ -200,13 +203,15 @@ module AlgebraicCombination = { | Some(Error(e)) => Error(Other(e)) | None => switch chooseConvolutionOrMonteCarlo(t1, t2) { - | #CalculateWithMonteCarlo => - runMonteCarlo( - toSampleSetFn, - arithmeticOperation, - t1, - t2, - )->E.R2.fmap(r => GenericDist_Types.SampleSet(r)) + | #CalculateWithMonteCarlo => { + let sampleSetDist: result = runMonteCarlo( + toSampleSetFn, + arithmeticOperation, + t1, + t2, + ) + sampleSetDist->E.R2.fmap(r => GenericDist_Types.SampleSet(r)) + } | #CalculateWithConvolution => runConvolution( toPointSetFn, diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.resi b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.resi index adf3b9d4..b65489e3 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.resi +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.resi @@ -1,11 +1,12 @@ type t = GenericDist_Types.genericDist type error = GenericDist_Types.error type toPointSetFn = t => result -type toSampleSetFn = t => result, error> +type toSampleSetFn = t => result type scaleMultiplyFn = (t, float) => result type pointwiseAddFn = (t, t) => result let sampleN: (t, int) => result, error> +let toSampleSetDist: (t, int) => Belt.Result.t let fromFloat: float => t diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist_Types.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist_Types.res index 70754807..c3c923a4 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist_Types.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist_Types.res @@ -1,6 +1,6 @@ type genericDist = | PointSet(PointSetTypes.pointSetDist) - | SampleSet(SampleSet.t) + | SampleSet(SampleSetDist.t) | Symbolic(SymbolicDistTypes.symbolicDist) @genType diff --git a/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist.res b/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist.res new file mode 100644 index 00000000..cea650f8 --- /dev/null +++ b/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist.res @@ -0,0 +1,48 @@ +module T: { + type t + let make: array => result + let get: t => array +} = { + type t = array + let make = (a: array) => + if E.A.length(a) > 5 { + Ok(a) + } else { + Error("too small") + } + let get = (a: t) => a +} + +include T + +// TODO: Refactor to raise correct error when not enough samples + +let toPointSetDist = (~samples: t, ~samplingInputs: SamplingInputs.samplingInputs, ()) => + SampleSetDist_ToPointSet.toPointSetDist(~samples=get(samples), ~samplingInputs, ()) + +//Randomly get one sample from the distribution +let sample = (t: t): float => { + let i = E.Int.random(~min=0, ~max=E.A.length(get(t)) - 1) + E.A.unsafe_get(get(t), i) +} + +/* +If asked for a length of samples shorter or equal the length of the distribution, +return this first n samples of this distribution. +Else, return n random samples of the distribution. +The former helps in cases where multiple distributions are correlated. +However, if n > length(t), then there's no clear right answer, so we just randomly +sample everything. +*/ +let sampleN = (t: t, n) => { + if n <= E.A.length(get(t)) { + E.A.slice(get(t), ~offset=0, ~len=n) + } else { + Belt.Array.makeBy(n, _ => sample(t)) + } +} + +let runMonteCarlo = (fn: (float, float) => float, t1: t, t2: t) => { + let samples = Belt.Array.zip(get(t1), get(t2))->E.A2.fmap(((a, b)) => fn(a, b)) + make(samples) +} diff --git a/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSet.res b/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist_ToPointSet.res similarity index 85% rename from packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSet.res rename to packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist_ToPointSet.res index f8bce6f6..c8880f7a 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSet.res +++ b/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist_ToPointSet.res @@ -1,8 +1,3 @@ -@genType -type t = array - -// TODO: Refactor to raise correct error when not enough samples - module Internals = { module Types = { type samplingStats = { @@ -145,25 +140,3 @@ let toPointSetDist = ( samplesParse } - -//Randomly get one sample from the distribution -let sample = (t: t): float => { - let i = E.Int.random(~min=0, ~max=E.A.length(t) - 1) - E.A.unsafe_get(t, i) -} - -/* -If asked for a length of samples shorter or equal the length of the distribution, -return this first n samples of this distribution. -Else, return n random samples of the distribution. -The former helps in cases where multiple distributions are correlated. -However, if n > length(t), then there's no clear right answer, so we just randomly -sample everything. -*/ -let sampleN = (t: t, n) => { - if n <= E.A.length(t) { - E.A.slice(t, ~offset=0, ~len=n) - } else { - Belt.Array.makeBy(n, _ => sample(t)) - } -} diff --git a/packages/squiggle-lang/src/rescript/OldInterpreter/ASTTypes.res b/packages/squiggle-lang/src/rescript/OldInterpreter/ASTTypes.res index 31217374..57b4577c 100644 --- a/packages/squiggle-lang/src/rescript/OldInterpreter/ASTTypes.res +++ b/packages/squiggle-lang/src/rescript/OldInterpreter/ASTTypes.res @@ -218,15 +218,15 @@ module SamplingDistribution = { algebraicOp, a, b, - ) + ) |> E.O.toResult("Could not get samples") + + let sampleSetDist = samples -> E.R.bind(SampleSetDist.make) let pointSetDist = - samples - |> E.O.fmap(r => - SampleSet.toPointSetDist(~samplingInputs=evaluationParams.samplingInputs, ~samples=r, ()) - ) - |> E.O.bind(_, r => r.pointSetDist) - |> E.O.toResult("No response") + sampleSetDist + -> E.R2.fmap(r => + SampleSetDist.toPointSetDist(~samplingInputs=evaluationParams.samplingInputs, ~samples=r, ())) + -> E.R.bind(r => r.pointSetDist |> E.O.toResult("combineShapesUsingSampling Error")) pointSetDist |> E.R.fmap(r => #Normalize(#RenderedDist(r))) }) }