From cebca141bd0348070eaf7bcebdfd1967951f9072 Mon Sep 17 00:00:00 2001
From: Ozzie Gooen
Date: Sun, 15 Mar 2020 00:30:18 +0000
Subject: [PATCH] Added simple truncate logic

---
Review notes (ignored by `git am`, not part of the commit message):
- Sort comparators now use `compare` instead of `a > b ? 1 : 0`, which never
  returned a negative value and so was not a valid comparator.
- Discrete.T.truncate now keeps the `i` points with the LARGEST y values; it
  previously sliced the front of an ascending-by-y array (the smallest ones).
- Discrete.sortedByY/sortedByX delegate to XYShape.T.Zipped rather than
  duplicating the sort logic; the unused `as t` in Mixed.T.truncate is gone.
- Hunk headers and the diffstat were recomputed for the added lines; the blob
  ids on the `index` lines still refer to the originally submitted revision.

 src/distributions/Distributions.re | 64 ++++++++++++++++++++++++++++++
 src/distributions/XYShape.re       | 10 +++++
 2 files changed, 74 insertions(+)

diff --git a/src/distributions/Distributions.re b/src/distributions/Distributions.re
index 7a89e472..1c778d45 100644
--- a/src/distributions/Distributions.re
+++ b/src/distributions/Distributions.re
@@ -19,6 +19,7 @@ module type dist = {
   let minX: t => option(float);
   let maxX: t => option(float);
   let pointwiseFmap: (float => float, t) => t;
+  let truncate: (int, t) => t;
   let xToY: (float, t) => DistTypes.mixedPoint;
   let toShape: t => DistTypes.shape;
   let toContinuous: t => option(DistTypes.continuousShape);
@@ -46,6 +47,7 @@ module Dist = (T: dist) => {
   };
   let pointwiseFmap = T.pointwiseFmap;
   let xToY = T.xToY;
+  let truncate = T.truncate;
   let toShape = T.toShape;
   let toDiscreteProbabilityMass = T.toDiscreteProbabilityMass;
   let toContinuous = T.toContinuous;
@@ -110,6 +112,8 @@ module Continuous = {
       let toDiscreteProbabilityMass = _ => 0.0;
       let pointwiseFmap = (fn, t: t) =>
         t |> xyShape |> XYShape.T.pointwiseMap(fn) |> fromShape;
+      let truncate = i =>
+        shapeMap(CdfLibrary.Distribution.convertToNewLength(i));
       let toShape = (t: t): DistTypes.shape => Continuous(t);
       let xToY = (f, {interpolation, xyShape}: t) =>
         switch (interpolation) {
@@ -154,6 +158,16 @@ module Continuous = {
 };
 
 module Discrete = {
+  // The shape's (x, y) pairs, ordered by ascending y.
+  let sortedByY = (t: DistTypes.discreteShape) =>
+    t
+    |> XYShape.T.zip
+    |> XYShape.T.Zipped.sortByY;
+  // The shape's (x, y) pairs, ordered by ascending x.
+  let sortedByX = (t: DistTypes.discreteShape) =>
+    t
+    |> XYShape.T.zip
+    |> XYShape.T.Zipped.sortByX;
   module T =
     Dist({
       type t = DistTypes.discreteShape;
@@ -174,6 +188,16 @@ module Discrete = {
       let toDiscrete = t => Some(t);
       let toScaledContinuous = _ => None;
       let toScaledDiscrete = t => Some(t);
+      // Keeps the `i` points with the largest y values, returned in x order.
+      // NOTE(review): the surviving masses are not renormalized here.
+      let truncate = (i, t: t): DistTypes.discreteShape =>
+        t
+        |> XYShape.T.zip
+        |> XYShape.T.Zipped.sortByY
+        |> Belt.Array.reverse
+        |> Belt.Array.slice(_, ~offset=0, ~len=i)
+        |> XYShape.T.Zipped.sortByX
+        |> XYShape.T.fromZippedArray;
 
       let xToY = (f, t) => {
         XYShape.T.XtoY.stepwiseIfAtX(f, t)
@@ -252,6 +276,34 @@ module Mixed = {
         DistTypes.MixedPoint.add(c, d);
       };
 
+      // Splits the `count` budget between the discrete and continuous parts
+      // by discreteProbabilityMassFraction; each share goes through int_of_float.
+      let truncate =
+          (
+            count,
+            {discrete, continuous, discreteProbabilityMassFraction}: t,
+          )
+          : t => {
+        {
+          discrete:
+            Discrete.T.truncate(
+              int_of_float(
+                float_of_int(count) *. discreteProbabilityMassFraction,
+              ),
+              discrete,
+            ),
+          continuous:
+            Continuous.T.truncate(
+              int_of_float(
+                float_of_int(count)
+                *. (1.0 -. discreteProbabilityMassFraction),
+              ),
+              continuous,
+            ),
+          discreteProbabilityMassFraction,
+        };
+      };
+
       let toScaledContinuous = ({continuous} as t: t) =>
         Some(scaleContinuous(t, continuous));
 
@@ -381,6 +433,16 @@ module Shape = {
         ),
       );
 
+      let truncate = (i, t: t) =>
+        fmap(
+          t,
+          (
+            Mixed.T.truncate(i),
+            Discrete.T.truncate(i),
+            Continuous.T.truncate(i),
+          ),
+        );
+
       let toDiscreteProbabilityMass = (t: t) =>
         mapToAll(
           t,
@@ -560,6 +622,8 @@ module DistPlus = {
       let integral = (~cache, t: t) =>
         updateShape(Continuous(t.integralCache), t);
 
+      let truncate = (i, t) =>
+        updateShape(t |> toShape |> Shape.T.truncate(i), t);
       // todo: adjust for limit, maybe?
       let pointwiseFmap = (fn, {shape, _} as t: t): t =>
         Shape.T.pointwiseFmap(fn, shape) |> updateShape(_, t);
diff --git a/src/distributions/XYShape.re b/src/distributions/XYShape.re
index 30bf0dc3..f299580e 100644
--- a/src/distributions/XYShape.re
+++ b/src/distributions/XYShape.re
@@ -64,6 +64,16 @@ module T = {
   let fromZippedArray = (is: array((float, float))): t =>
     is |> Belt.Array.unzip |> fromArray;
 
+  module Zipped = {
+    type zipped = array((float, float));
+    // Ascending sorts of (x, y) pairs. `compare` supplies the full
+    // negative/zero/positive result that a sort comparator must return.
+    let sortByY = (t: zipped) =>
+      t |> E.A.stableSortBy(_, ((_, y1), (_, y2)) => compare(y1, y2));
+    let sortByX = (t: zipped) =>
+      t |> E.A.stableSortBy(_, ((x1, _), (x2, _)) => compare(x1, x2));
+  };
+
   module Combine = {
     let combineLinear = (t1: t, t2: t, fn: (float, float) => float) => {
       let allXs = Belt.Array.concat(xs(t1), xs(t2));