diff --git a/packages/squiggle-lang/__tests__/Distributions/DistributionOperation_test.res b/packages/squiggle-lang/__tests__/Distributions/DistributionOperation_test.res index f206b31d..34a8dd6e 100644 --- a/packages/squiggle-lang/__tests__/Distributions/DistributionOperation_test.res +++ b/packages/squiggle-lang/__tests__/Distributions/DistributionOperation_test.res @@ -90,15 +90,6 @@ describe("toPointSet", () => { expect(result)->toBeSoCloseTo(5.0, ~digits=0) }) - test("on sample set distribution with under 4 points", () => { - let sampleSet = SampleSetDist.make([0.0, 1.0, 2.0, 3.0]) -> E.R.toExn; - let result = - run(FromDist(ToDist(ToPointSet), SampleSet(sampleSet)))->outputMap( - FromDist(ToFloat(#Mean)), - ) - expect(result)->toEqual(GenDistError(Other("Converting sampleSet to pointSet failed"))) - }) - test("on sample set", () => { let result = run(FromDist(ToDist(ToPointSet), normalDist5)) diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res index f87b3a22..a69d7f5d 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res @@ -65,7 +65,7 @@ let toPointSet = ( | PointSet(pointSet) => Ok(pointSet) | Symbolic(r) => Ok(SymbolicDist.T.toPointSetDist(~xSelection, xyPointLength, r)) | SampleSet(r) => - SampleSetDist.toPointSetDist2( + SampleSetDist.toPointSetDist( ~samples=r, ~samplingInputs={ sampleCount: sampleCount, diff --git a/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist.res b/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist.res index 264b32ea..ccf1b775 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist.res @@ -15,14 +15,18 @@ module T: { include T -let length = (t:t) => get(t) |> E.A.length; +let length = (t: t) => get(t) |> E.A.length -// TODO: Refactor to raise correct error when not enough samples -let toPointSetDist = (~samples: t, ~samplingInputs: SamplingInputs.samplingInputs, ()) => - SampleSetDist_ToPointSet.toPointSetDist(~samples=get(samples), ~samplingInputs, ()) - -let toPointSetDist2 = (~samples: t, ~samplingInputs: SamplingInputs.samplingInputs, ()) => - SampleSetDist_ToPointSet.toPointSetDist(~samples=get(samples), ~samplingInputs, ()).pointSetDist |> E.O.toResult("Failed to convert to PointSetDist") +// TODO: Refactor to get error in the toPointSetDist function, instead of adding at very end. +let toPointSetDist = (~samples: t, ~samplingInputs: SamplingInputs.samplingInputs, ()): result< + PointSetTypes.pointSetDist, + string, +> => + SampleSetDist_ToPointSet.toPointSetDist( + ~samples=get(samples), + ~samplingInputs, + (), + ).pointSetDist |> E.O.toResult("Failed to convert to PointSetDist") //Randomly get one sample from the distribution let sample = (t: t): float => { diff --git a/packages/squiggle-lang/src/rescript/OldInterpreter/ASTTypes.res b/packages/squiggle-lang/src/rescript/OldInterpreter/ASTTypes.res index 57b4577c..3c14343a 100644 --- a/packages/squiggle-lang/src/rescript/OldInterpreter/ASTTypes.res +++ b/packages/squiggle-lang/src/rescript/OldInterpreter/ASTTypes.res @@ -222,11 +222,10 @@ module SamplingDistribution = { let sampleSetDist = samples -> E.R.bind(SampleSetDist.make) - let pointSetDist = + let pointSetDist = sampleSetDist - -> E.R2.fmap(r => - SampleSetDist.toPointSetDist(~samplingInputs=evaluationParams.samplingInputs, ~samples=r, ())) - -> E.R.bind(r => r.pointSetDist |> E.O.toResult("combineShapesUsingSampling Error")) + -> E.R.bind(r => + SampleSetDist.toPointSetDist(~samplingInputs=evaluationParams.samplingInputs, ~samples=r, ())); pointSetDist |> E.R.fmap(r => #Normalize(#RenderedDist(r))) }) }