Fixed some discrete functions

This commit is contained in:
Ozzie Gooen 2020-02-23 19:40:38 +00:00
parent 029431a449
commit 8d7a6f7f6c
4 changed files with 55 additions and 68 deletions

View File

@ -49,35 +49,46 @@ describe("Shape", () => {
describe("Discrete", () => { describe("Discrete", () => {
open Distributions.Discrete; open Distributions.Discrete;
let shape: DistTypes.xyShape = {xs: [|1., 4., 8.|], ys: [|8., 9., 2.|]}; let shape: DistTypes.xyShape = {
xs: [|1., 4., 8.|],
ys: [|0.3, 0.5, 0.2|],
};
let discrete = shape; let discrete = shape;
makeTest("minX", T.minX(discrete), Some(1.0)); makeTest("minX", T.minX(discrete), Some(1.0));
makeTest("maxX", T.maxX(discrete), Some(8.0)); makeTest("maxX", T.maxX(discrete), Some(8.0));
makeTest( makeTest(
"pointwiseFmap", "pointwiseFmap",
T.pointwiseFmap(r => r *. 2.0, discrete) |> (r => r.ys), T.pointwiseFmap(r => r *. 2.0, discrete) |> (r => r.ys),
[|16., 18.0, 4.0|], [|0.6, 1.0, 0.4|],
); );
makeTest( makeTest(
"xToY at 4.0", "xToY at 4.0",
T.xToY(4., discrete), T.xToY(4., discrete),
{discrete: 9.0, continuous: 0.0}, {discrete: 0.5, continuous: 0.0},
); );
makeTest( makeTest(
"xToY at 0.0", "xToY at 0.0",
T.xToY(0., discrete), T.xToY(0., discrete),
{discrete: 8.0, continuous: 0.0}, {discrete: 0.0, continuous: 0.0},
); );
makeTest( makeTest(
"xToY at 5.0", "xToY at 5.0",
T.xToY(5., discrete), T.xToY(5., discrete),
{discrete: 7.25, continuous: 0.0}, {discrete: 0.0, continuous: 0.0},
);
makeTest(
"integral",
T.Integral.get(~cache=None, discrete),
Distributions.Continuous.make(
{xs: [|1., 4., 8.|], ys: [|0.3, 0.8, 1.0|]},
`Stepwise,
),
); );
makeTest( makeTest(
"integralXToY", "integralXToY",
T.Integral.xToY(~cache=None, 2.0, discrete), T.Integral.xToY(~cache=None, 6.0, discrete),
11.0, 0.9,
); );
makeTest("integralSum", T.Integral.sum(~cache=None, discrete), 19.0); makeTest("integralSum", T.Integral.sum(~cache=None, discrete), 1.0);
}); });
}); });

View File

@ -122,6 +122,8 @@ module Continuous = {
}); });
}; };
// |> XYShape.Range.stepsToContinuous
// |> E.O.toExt("ERROR"),
module Discrete = { module Discrete = {
module T = module T =
Dist({ Dist({
@ -133,12 +135,7 @@ module Discrete = {
cache cache
|> E.O.default( |> E.O.default(
{ {
Continuous.make( Continuous.make(XYShape.accumulateYs(t), `Stepwise);
XYShape.accumulateYs(t)
|> XYShape.Range.stepsToContinuous
|> E.O.toExt("ERROR"),
`Stepwise,
);
}, },
); );
// todo: Fix this with last element // todo: Fix this with last element
@ -151,13 +148,19 @@ module Discrete = {
let toDiscrete = t => Some(t); let toDiscrete = t => Some(t);
let toScaledContinuous = _ => None; let toScaledContinuous = _ => None;
let toScaledDiscrete = t => Some(t); let toScaledDiscrete = t => Some(t);
// todo: Fix this with code that work find recent value and use that instead.
let xToY = (f, t) => let xToY = (f, t) => {
CdfLibrary.Distribution.findY(f, t) XYShape.getBy(t, ((x, _)) => x == f)
|> E.O.fmap(((_, y)) => y)
|> E.O.default(0.0)
|> DistTypes.MixedPoint.makeDiscrete; |> DistTypes.MixedPoint.makeDiscrete;
};
// todo: This should use cache and/or same code as above. FindingY is more complex, should use interpolationType. // todo: This should use cache and/or same code as above. FindingY is more complex, should use interpolationType.
let integralXtoY = (~cache, f, t) => let integralXtoY = (~cache, f, t) =>
t |> XYShape.accumulateYs |> CdfLibrary.Distribution.findY(f); t
|> integral(~cache)
|> Continuous.getShape
|> CdfLibrary.Distribution.findY(f);
}); });
}; };

View File

@ -1,29 +1,31 @@
open DistTypes; open DistTypes;
let _lastElement = (a: array('a)) =>
switch (Belt.Array.size(a)) {
| 0 => None
| n => Belt.Array.get(a, n - 1)
};
type t = xyShape; type t = xyShape;
let toJs = (t: t) => { let toJs = (t: t) => {
{"xs": t.xs, "ys": t.ys}; {"xs": t.xs, "ys": t.ys};
}; };
let minX = (t: t) => t.xs |> E.A.first;
let minX = (t: t) => t.xs |> E.A.get(_, 0); let maxX = (t: t) => t.xs |> E.A.last;
// TODO: Check if this actually gets the last element, I'm not sure it does. let first = (t: t) =>
let maxX = (t: t) => t.xs |> (r => E.A.get(r, E.A.length(r) - 1)); switch (t.xs |> E.A.first, t.ys |> E.A.first) {
| (Some(x), Some(y)) => Some((x, y))
| _ => None
};
let last = (t: t) =>
switch (t.xs |> E.A.last, t.ys |> E.A.last) {
| (Some(x), Some(y)) => Some((x, y))
| _ => None
};
let zip = t => Belt.Array.zip(t.xs, t.ys); let zip = t => Belt.Array.zip(t.xs, t.ys);
let getBy = (t: t, fn) => t |> zip |> Belt.Array.getBy(_, fn);
let fmap = (t: t, y): t => {xs: t.xs, ys: t.ys |> E.A.fmap(y)}; let pointwiseMap = (fn, t: t): t => {xs: t.xs, ys: t.ys |> E.A.fmap(fn)};
let fromArray = ((xs, ys)): t => {xs, ys}; let fromArray = ((xs, ys)): t => {xs, ys};
let fromArrays = (xs, ys): t => {xs, ys}; let fromArrays = (xs, ys): t => {xs, ys};
let pointwiseMap = (fn, t: t): t => {xs: t.xs, ys: t.ys |> E.A.fmap(fn)};
let compare = (a: float, b: float) => a > b ? 1 : (-1); // todo: maybe not needed?
// let comparePoint = (a: float, b: float) => a > b ? 1 : (-1);
let comparePoints = ((x1: float, y1: float), (x2: float, y2: float)) => let comparePoints = ((x1: float, y1: float), (x2: float, y2: float)) =>
switch (x1 == x2, y1 == y2) { switch (x1 == x2, y1 == y2) {
@ -32,6 +34,7 @@ let comparePoints = ((x1: float, y1: float), (x2: float, y2: float)) =>
| (true, true) => (-1) | (true, true) => (-1)
}; };
// todo: This is broken :(
let combine = (t1: t, t2: t) => { let combine = (t1: t, t2: t) => {
let totalLength = E.A.length(t1.xs) + E.A.length(t2.xs); let totalLength = E.A.length(t1.xs) + E.A.length(t2.xs);
let array = Belt.Array.concat(zip(t1), zip(t2)); let array = Belt.Array.concat(zip(t1), zip(t2));
@ -53,14 +56,6 @@ let intersperce = (t1: t, t2: t) => {
items^ |> Belt.Array.unzip |> fromArray; items^ |> Belt.Array.unzip |> fromArray;
}; };
let scaleCdfTo = (~scaleTo=1., t: t) =>
switch (_lastElement(t.ys)) {
| Some(n) =>
let scaleBy = scaleTo /. n;
fmap(t, r => r *. scaleBy);
| None => t
};
let yFold = (fn, t: t) => { let yFold = (fn, t: t) => {
E.A.fold_left(fn, 0., t.ys); E.A.fold_left(fn, 0., t.ys);
}; };
@ -69,9 +64,8 @@ let ySum = yFold((a, b) => a +. b);
let _transverse = fn => let _transverse = fn =>
Belt.Array.reduce(_, [||], (items, (x, y)) => Belt.Array.reduce(_, [||], (items, (x, y)) =>
switch (_lastElement(items)) { switch (E.A.last(items)) {
| Some((xLast, yLast)) => | Some((_, yLast)) => Belt.Array.concat(items, [|(x, fn(y, yLast))|])
Belt.Array.concat(items, [|(x, fn(y, yLast))|])
| None => [|(x, y)|] | None => [|(x, y)|]
} }
); );
@ -89,6 +83,7 @@ let subtractYs = _transverseShape((aCurrent, aLast) => aCurrent -. aLast);
let findY = CdfLibrary.Distribution.findY; let findY = CdfLibrary.Distribution.findY;
let findX = CdfLibrary.Distribution.findX; let findX = CdfLibrary.Distribution.findX;
// I'm really not sure this part is actually what we want at this point.
module Range = { module Range = {
// ((lastX, lastY), (nextX, nextY)) // ((lastX, lastY), (nextX, nextY))
type zippedRange = ((float, float), (float, float)); type zippedRange = ((float, float), (float, float));

View File

@ -122,7 +122,7 @@ module R = {
}; };
let safe_fn_of_string = (fn, s: string): option('a) => let safe_fn_of_string = (fn, s: string): option('a) =>
try (Some(fn(s))) { try(Some(fn(s))) {
| _ => None | _ => None
}; };
@ -216,6 +216,8 @@ module A = {
let unsafe_get = Array.unsafe_get; let unsafe_get = Array.unsafe_get;
let get = Belt.Array.get; let get = Belt.Array.get;
let getBy = Belt.Array.getBy; let getBy = Belt.Array.getBy;
let last = a => get(a, length(a) - 1);
let first = get(_, 0);
let hasBy = (r, fn) => Belt.Array.getBy(r, fn) |> O.isSome; let hasBy = (r, fn) => Belt.Array.getBy(r, fn) |> O.isSome;
let fold_left = Array.fold_left; let fold_left = Array.fold_left;
let fold_right = Array.fold_right; let fold_right = Array.fold_right;
@ -294,28 +296,4 @@ module JsArray = {
Rationale.Option.toExn("Warning: This should not have happened"), Rationale.Option.toExn("Warning: This should not have happened"),
); );
let filter = Js.Array.filter; let filter = Js.Array.filter;
};
module NonZeroInt = {
type t = int;
let make = (i: int) => i < 0 ? None : Some(i);
let fmap = (fn, a: t) => make(fn(a));
let increment = fmap(I.increment);
let decrement = fmap(I.decrement);
};
module BoundedInt = {
type t = int;
let make = (i: int, limit: int) => {
let lessThan0 = r => r < 0;
let greaterThanLimit = r => r > limit;
if (lessThan0(i) || greaterThanLimit(i)) {
None;
} else {
Some(i);
};
};
let fmap = (fn, a: t, l) => make(fn(a), l);
let increment = fmap(I.increment);
let decrement = fmap(I.decrement);
}; };