Added simple truncate logic

This commit is contained in:
Ozzie Gooen 2020-03-15 00:30:18 +00:00
parent b4b9f2cc9f
commit cebca141bd
2 changed files with 65 additions and 0 deletions

View File

@ -19,6 +19,7 @@ module type dist = {
let minX: t => option(float); let minX: t => option(float);
let maxX: t => option(float); let maxX: t => option(float);
let pointwiseFmap: (float => float, t) => t; let pointwiseFmap: (float => float, t) => t;
let truncate: (int, t) => t;
let xToY: (float, t) => DistTypes.mixedPoint; let xToY: (float, t) => DistTypes.mixedPoint;
let toShape: t => DistTypes.shape; let toShape: t => DistTypes.shape;
let toContinuous: t => option(DistTypes.continuousShape); let toContinuous: t => option(DistTypes.continuousShape);
@ -46,6 +47,7 @@ module Dist = (T: dist) => {
}; };
let pointwiseFmap = T.pointwiseFmap; let pointwiseFmap = T.pointwiseFmap;
let xToY = T.xToY; let xToY = T.xToY;
let truncate = T.truncate;
let toShape = T.toShape; let toShape = T.toShape;
let toDiscreteProbabilityMass = T.toDiscreteProbabilityMass; let toDiscreteProbabilityMass = T.toDiscreteProbabilityMass;
let toContinuous = T.toContinuous; let toContinuous = T.toContinuous;
@ -110,6 +112,8 @@ module Continuous = {
let toDiscreteProbabilityMass = _ => 0.0; let toDiscreteProbabilityMass = _ => 0.0;
let pointwiseFmap = (fn, t: t) => let pointwiseFmap = (fn, t: t) =>
t |> xyShape |> XYShape.T.pointwiseMap(fn) |> fromShape; t |> xyShape |> XYShape.T.pointwiseMap(fn) |> fromShape;
let truncate = i =>
shapeMap(CdfLibrary.Distribution.convertToNewLength(i));
let toShape = (t: t): DistTypes.shape => Continuous(t); let toShape = (t: t): DistTypes.shape => Continuous(t);
let xToY = (f, {interpolation, xyShape}: t) => let xToY = (f, {interpolation, xyShape}: t) =>
switch (interpolation) { switch (interpolation) {
@ -154,6 +158,14 @@ module Continuous = {
}; };
module Discrete = { module Discrete = {
let sortedByY = (t: DistTypes.discreteShape) =>
t
|> XYShape.T.zip
|> E.A.stableSortBy(_, ((_, y1), (_, y2)) => y1 > y2 ? 1 : 0);
let sortedByX = (t: DistTypes.discreteShape) =>
t
|> XYShape.T.zip
|> E.A.stableSortBy(_, ((x1, _), (x2, _)) => x1 > x2 ? 1 : 0);
module T = module T =
Dist({ Dist({
type t = DistTypes.discreteShape; type t = DistTypes.discreteShape;
@ -174,6 +186,13 @@ module Discrete = {
let toDiscrete = t => Some(t); let toDiscrete = t => Some(t);
let toScaledContinuous = _ => None; let toScaledContinuous = _ => None;
let toScaledDiscrete = t => Some(t); let toScaledDiscrete = t => Some(t);
let truncate = (i, t: t): DistTypes.discreteShape =>
t
|> XYShape.T.zip
|> XYShape.T.Zipped.sortByY
|> Belt.Array.slice(_, ~offset=0, ~len=i)
|> XYShape.T.Zipped.sortByX
|> XYShape.T.fromZippedArray;
let xToY = (f, t) => { let xToY = (f, t) => {
XYShape.T.XtoY.stepwiseIfAtX(f, t) XYShape.T.XtoY.stepwiseIfAtX(f, t)
@ -252,6 +271,32 @@ module Mixed = {
DistTypes.MixedPoint.add(c, d); DistTypes.MixedPoint.add(c, d);
}; };
let truncate =
(
count,
{discrete, continuous, discreteProbabilityMassFraction} as t: t,
)
: t => {
{
discrete:
Discrete.T.truncate(
int_of_float(
float_of_int(count) *. discreteProbabilityMassFraction,
),
discrete,
),
continuous:
Continuous.T.truncate(
int_of_float(
float_of_int(count)
*. (1.0 -. discreteProbabilityMassFraction),
),
continuous,
),
discreteProbabilityMassFraction,
};
};
let toScaledContinuous = ({continuous} as t: t) => let toScaledContinuous = ({continuous} as t: t) =>
Some(scaleContinuous(t, continuous)); Some(scaleContinuous(t, continuous));
@ -381,6 +426,16 @@ module Shape = {
), ),
); );
let truncate = (i, t: t) =>
fmap(
t,
(
Mixed.T.truncate(i),
Discrete.T.truncate(i),
Continuous.T.truncate(i),
),
);
let toDiscreteProbabilityMass = (t: t) => let toDiscreteProbabilityMass = (t: t) =>
mapToAll( mapToAll(
t, t,
@ -560,6 +615,8 @@ module DistPlus = {
let integral = (~cache, t: t) => let integral = (~cache, t: t) =>
updateShape(Continuous(t.integralCache), t); updateShape(Continuous(t.integralCache), t);
let truncate = (i, t) =>
updateShape(t |> toShape |> Shape.T.truncate(i), t);
// todo: adjust for limit, maybe? // todo: adjust for limit, maybe?
let pointwiseFmap = (fn, {shape, _} as t: t): t => let pointwiseFmap = (fn, {shape, _} as t: t): t =>
Shape.T.pointwiseFmap(fn, shape) |> updateShape(_, t); Shape.T.pointwiseFmap(fn, shape) |> updateShape(_, t);

View File

@ -64,6 +64,14 @@ module T = {
let fromZippedArray = (is: array((float, float))): t => let fromZippedArray = (is: array((float, float))): t =>
is |> Belt.Array.unzip |> fromArray; is |> Belt.Array.unzip |> fromArray;
module Zipped = {
type zipped = array((float, float));
let sortByY = (t: zipped) =>
t |> E.A.stableSortBy(_, ((_, y1), (_, y2)) => y1 > y2 ? 1 : 0);
let sortByX = (t: zipped) =>
t |> E.A.stableSortBy(_, ((x1, _), (x2, _)) => x1 > x2 ? 1 : 0);
};
module Combine = { module Combine = {
let combineLinear = (t1: t, t2: t, fn: (float, float) => float) => { let combineLinear = (t1: t, t2: t, fn: (float, float) => float) => {
let allXs = Belt.Array.concat(xs(t1), xs(t2)); let allXs = Belt.Array.concat(xs(t1), xs(t2));