From 516f4fa39d609d0d26a7a1f088d5574d5e2a3985 Mon Sep 17 00:00:00 2001 From: Ozzie Gooen Date: Sat, 20 Aug 2022 17:10:08 -0700 Subject: [PATCH] Added truncate for SampleSet distribution --- .../ReducerInterface_Distribution_test.res | 1 + .../rescript/Distributions/GenericDist.res | 18 ++++++++---- .../SampleSetDist/SampleSetDist.res | 9 ++++++ packages/website/docs/Api/Dist.mdx | 29 +++++++++++++++---- 4 files changed, 46 insertions(+), 11 deletions(-) diff --git a/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res b/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res index 3083b71b..47f1bc8a 100644 --- a/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res +++ b/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res @@ -74,6 +74,7 @@ describe("eval on distribution functions", () => { testEval("truncateLeft(normal(5,2), 3)", "Ok(Point Set Distribution)") testEval("truncateRight(normal(5,2), 3)", "Ok(Point Set Distribution)") testEval("truncate(normal(5,2), 3, 8)", "Ok(Point Set Distribution)") + testEval("truncate(normal(5,2) |> SampleSet.fromDist, 3, 8)", "Ok(Sample Set Distribution)") testEval("isNormalized(truncate(normal(5,2), 3, 8))", "Ok(true)") }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res index f536d54d..14abd1b6 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res @@ -242,11 +242,19 @@ module Truncate = { switch trySymbolicSimplification(leftCutoff, rightCutoff, t) { | Some(r) => Ok(r) | None => - toPointSetFn(t)->E.R2.fmap(t => { - DistributionTypes.PointSet( - PointSetDist.T.truncate(leftCutoff, rightCutoff, t)->PointSetDist.T.normalize, - ) - }) + switch t { + | SampleSet(t) => + switch SampleSetDist.truncate(t, ~leftCutoff, ~rightCutoff) { + | Ok(r) => Ok(SampleSet(r)) + | Error(err) => Error(DistributionTypes.SampleSetError(err)) + } + | _ => + toPointSetFn(t)->E.R2.fmap(t => { + DistributionTypes.PointSet( + PointSetDist.T.truncate(leftCutoff, rightCutoff, t)->PointSetDist.T.normalize, + ) + }) + } } } } diff --git a/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist.res b/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist.res index dc15f7a1..f0fbff99 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist.res @@ -131,3 +131,12 @@ let max = t => T.get(t)->E.A.Floats.max let stdev = t => T.get(t)->E.A.Floats.stdev let variance = t => T.get(t)->E.A.Floats.variance let percentile = (t, f) => T.get(t)->E.A.Floats.percentile(f) + +let truncateLeft = (t, f) => T.get(t)->E.A2.filter(x => x >= f)->T.make +let truncateRight = (t, f) => T.get(t)->E.A2.filter(x => x <= f)->T.make + +let truncate = (t, ~leftCutoff: option, ~rightCutoff: option) => { + let withTruncatedLeft = t => leftCutoff |> E.O.dimap(left => truncateLeft(t, left), _ => Ok(t)) + let withTruncatedRight = t => rightCutoff |> E.O.dimap(left => truncateRight(t, left), _ => Ok(t)) + t->withTruncatedLeft |> E.R2.bind(withTruncatedRight) +} diff --git a/packages/website/docs/Api/Dist.mdx b/packages/website/docs/Api/Dist.mdx index e37c6f75..6bde09c9 100644 --- a/packages/website/docs/Api/Dist.mdx +++ b/packages/website/docs/Api/Dist.mdx @@ -290,12 +290,29 @@ quantile: (distribution, number) => number quantile(normal(5, 2), 0.5); ``` -### truncateLeft +### truncate -Truncates the left side of a distribution. Returns either a pointSet distribution or a symbolic distribution. +Truncates both the left side and the right side of a distribution. ``` -truncateLeft: (distribution, l => number) => distribution +truncate: (distribution, left: number, right: number) => distribution +``` + + +

+ Sample set distributions are truncated by filtering samples, but point set + distributions are truncated using direct geometric manipulation. Uniform + distributions are truncated symbolically. Symbolic but non-uniform + distributions get converted to Point Set distributions. +

+
+ +### truncateLeft + +Truncates the left side of a distribution. + +``` +truncateLeft: (distribution, left: number) => distribution ``` **Examples** @@ -306,10 +323,10 @@ truncateLeft(normal(5, 2), 3); ### truncateRight -Truncates the right side of a distribution. Returns either a pointSet distribution or a symbolic distribution. +Truncates the right side of a distribution. ``` -truncateRight: (distribution, r => number) => distribution +truncateRight: (distribution, right: number) => distribution ``` **Examples** @@ -388,7 +405,7 @@ The only functions that do not return normalized distributions are the pointwise ### normalize -Normalize a distribution. This means scaling it appropriately so that it's cumulative sum is equal to 1. This only impacts Pointset distributions, because those are the only ones that can be non-normlized. +Normalize a distribution. This means scaling it appropriately so that it's cumulative sum is equal to 1. This only impacts Point Set distributions, because those are the only ones that can be non-normlized. ``` normalize: (distribution) => distribution