Added native logScoring and related functionality to XYShape

This commit is contained in:
Ozzie Gooen 2020-03-14 21:18:34 +00:00
parent f887af22ea
commit a99415e5bc
7 changed files with 258 additions and 173 deletions

View File

@ -0,0 +1,43 @@
open Jest;
open Expect;
// Registers a single Jest test named `str` that asserts `item1` equals
// `item2` via `toEqual`. With `~only=true` the test is registered through
// `Only.test` instead of `test`.
let makeTest = (~only=false, str, item1, item2) => {
  // The assertion is the same for both registration paths.
  let assertion = () => expect(item1) |> toEqual(item2);
  only ? Only.test(str, assertion) : test(str, assertion);
};
// Fixture shapes for the logScorePoint tests below. All three start at
// x = 1.; shape2 and shape3 share the same ys and differ only in how far
// their xs extend (up to 10. versus up to 50.).
let shape1: DistTypes.xyShape = {
  xs: [|1., 4., 8.|],
  ys: [|0.2, 0.4, 0.8|],
};

let shape2: DistTypes.xyShape = {
  xs: [|1., 5., 10.|],
  ys: [|0.2, 0.5, 0.8|],
};

let shape3: DistTypes.xyShape = {
  xs: [|1., 20., 50.|],
  ys: [|0.2, 0.5, 0.8|],
};
describe("XYShapes", () => {
  describe("logScorePoint", () => {
    // Every case passes 30 as the sample count and shape1 as the first
    // shape; only the second shape varies.
    let scoreAgainstShape1 = XYShape.logScorePoint(30, shape1);

    makeTest("When identical", scoreAgainstShape1(shape1), Some(0.0));
    makeTest(
      "When similar",
      scoreAgainstShape1(shape2),
      Some(1.658971191043856),
    );
    makeTest(
      "When very different",
      scoreAgainstShape1(shape3),
      Some(210.3721280423322),
    );
  })
});

View File

@ -134,12 +134,12 @@ let make =
?xScale ?xScale
?yScale ?yScale
?timeScale ?timeScale
discrete={discrete |> E.O.fmap(XYShape.toJs)} discrete={discrete |> E.O.fmap(XYShape.T.toJs)}
height height
marginBottom=50 marginBottom=50
marginTop=0 marginTop=0
onHover onHover
continuous={continuous |> E.O.fmap(XYShape.toJs)} continuous={continuous |> E.O.fmap(XYShape.T.toJs)}
showDistributionLines showDistributionLines
showDistributionYAxis showDistributionYAxis
showVerticalLine showVerticalLine

View File

@ -82,7 +82,7 @@ module Continuous = {
interpolation, interpolation,
}; };
let lastY = (t: t) => let lastY = (t: t) =>
t |> xyShape |> XYShape.unsafeLast |> (((_, y)) => y); t |> xyShape |> XYShape.T.unsafeLast |> (((_, y)) => y);
let oShapeMap = let oShapeMap =
(fn, {xyShape, interpolation}: t): option(DistTypes.continuousShape) => (fn, {xyShape, interpolation}: t): option(DistTypes.continuousShape) =>
fn(xyShape) |> E.O.fmap(make(_, interpolation)); fn(xyShape) |> E.O.fmap(make(_, interpolation));
@ -105,22 +105,22 @@ module Continuous = {
Dist({ Dist({
type t = DistTypes.continuousShape; type t = DistTypes.continuousShape;
type integral = DistTypes.continuousShape; type integral = DistTypes.continuousShape;
let minX = shapeFn(XYShape.minX); let minX = shapeFn(XYShape.T.minX);
let maxX = shapeFn(XYShape.maxX); let maxX = shapeFn(XYShape.T.maxX);
let toDiscreteProbabilityMass = _ => 0.0; let toDiscreteProbabilityMass = _ => 0.0;
let pointwiseFmap = (fn, t: t) => let pointwiseFmap = (fn, t: t) =>
t |> xyShape |> XYShape.pointwiseMap(fn) |> fromShape; t |> xyShape |> XYShape.T.pointwiseMap(fn) |> fromShape;
let toShape = (t: t): DistTypes.shape => Continuous(t); let toShape = (t: t): DistTypes.shape => Continuous(t);
let xToY = (f, {interpolation, xyShape}: t) => let xToY = (f, {interpolation, xyShape}: t) =>
switch (interpolation) { switch (interpolation) {
| `Stepwise => | `Stepwise =>
xyShape xyShape
|> XYShape.XtoY.stepwiseIncremental(f) |> XYShape.T.XtoY.stepwiseIncremental(f)
|> E.O.default(0.0) |> E.O.default(0.0)
|> DistTypes.MixedPoint.makeContinuous |> DistTypes.MixedPoint.makeContinuous
| `Linear => | `Linear =>
xyShape xyShape
|> XYShape.XtoY.linear(f) |> XYShape.T.XtoY.linear(f)
|> DistTypes.MixedPoint.makeContinuous |> DistTypes.MixedPoint.makeContinuous
}; };
@ -161,14 +161,14 @@ module Discrete = {
let integral = (~cache, t) => let integral = (~cache, t) =>
switch (cache) { switch (cache) {
| Some(c) => c | Some(c) => c
| None => Continuous.make(XYShape.accumulateYs(t), `Stepwise) | None => Continuous.make(XYShape.T.accumulateYs(t), `Stepwise)
}; };
let integralEndY = (~cache, t) => let integralEndY = (~cache, t) =>
t |> integral(~cache) |> Continuous.lastY; t |> integral(~cache) |> Continuous.lastY;
let minX = XYShape.minX; let minX = XYShape.T.minX;
let maxX = XYShape.maxX; let maxX = XYShape.T.maxX;
let toDiscreteProbabilityMass = t => 1.0; let toDiscreteProbabilityMass = t => 1.0;
let pointwiseFmap = XYShape.pointwiseMap; let pointwiseFmap = XYShape.T.pointwiseMap;
let toShape = (t: t): DistTypes.shape => Discrete(t); let toShape = (t: t): DistTypes.shape => Discrete(t);
let toContinuous = _ => None; let toContinuous = _ => None;
let toDiscrete = t => Some(t); let toDiscrete = t => Some(t);
@ -176,7 +176,7 @@ module Discrete = {
let toScaledDiscrete = t => Some(t); let toScaledDiscrete = t => Some(t);
let xToY = (f, t) => { let xToY = (f, t) => {
XYShape.XtoY.stepwiseIfAtX(f, t) XYShape.T.XtoY.stepwiseIfAtX(f, t)
|> E.O.default(0.0) |> E.O.default(0.0)
|> DistTypes.MixedPoint.makeDiscrete; |> DistTypes.MixedPoint.makeDiscrete;
}; };
@ -294,7 +294,7 @@ module Mixed = {
let result = let result =
Continuous.make( Continuous.make(
XYShape.Combine.combineLinear( XYShape.T.Combine.combineLinear(
Continuous.getShape(cont), Continuous.getShape(dist), (a, b) => Continuous.getShape(cont), Continuous.getShape(dist), (a, b) =>
a +. b a +. b
), ),

View File

@ -10,8 +10,11 @@ type assumptions = {
let buildSimple = (~continuous, ~discrete): option(DistTypes.shape) => { let buildSimple = (~continuous, ~discrete): option(DistTypes.shape) => {
let cLength = let cLength =
continuous |> Distributions.Continuous.getShape |> XYShape.xs |> E.A.length; continuous
let dLength = discrete |> XYShape.xs |> E.A.length; |> Distributions.Continuous.getShape
|> XYShape.T.xs
|> E.A.length;
let dLength = discrete |> XYShape.T.xs |> E.A.length;
switch (cLength, dLength) { switch (cLength, dLength) {
| (0 | 1, 0) => None | (0 | 1, 0) => None
| (0 | 1, _) => Some(Discrete(discrete)) | (0 | 1, _) => Some(Discrete(discrete))

View File

@ -1,36 +1,38 @@
open DistTypes; open DistTypes;
type t = xyShape; module T = {
type t = xyShape;
type ts = array(xyShape);
let toJs = (t: t) => { let toJs = (t: t) => {
{"xs": t.xs, "ys": t.ys}; {"xs": t.xs, "ys": t.ys};
}; };
let xs = (t: t) => t.xs; let xs = (t: t) => t.xs;
let minX = (t: t) => t |> xs |> E.A.first; let minX = (t: t) => t |> xs |> E.A.first;
let maxX = (t: t) => t |> xs |> E.A.last; let maxX = (t: t) => t |> xs |> E.A.last;
let xTotalRange = (t: t) => let xTotalRange = (t: t) =>
switch (minX(t), maxX(t)) { switch (minX(t), maxX(t)) {
| (Some(min), Some(max)) => Some(max -. min) | (Some(min), Some(max)) => Some(max -. min)
| _ => None | _ => None
}; };
let first = ({xs, ys}: t) => let first = ({xs, ys}: t) =>
switch (xs |> E.A.first, ys |> E.A.first) { switch (xs |> E.A.first, ys |> E.A.first) {
| (Some(x), Some(y)) => Some((x, y)) | (Some(x), Some(y)) => Some((x, y))
| _ => None | _ => None
}; };
let last = ({xs, ys}: t) => let last = ({xs, ys}: t) =>
switch (xs |> E.A.last, ys |> E.A.last) { switch (xs |> E.A.last, ys |> E.A.last) {
| (Some(x), Some(y)) => Some((x, y)) | (Some(x), Some(y)) => Some((x, y))
| _ => None | _ => None
}; };
let unsafeFirst = (t: t) => first(t) |> E.O.toExn("Unsafe operation"); let unsafeFirst = (t: t) => first(t) |> E.O.toExn("Unsafe operation");
let unsafeLast = (t: t) => last(t) |> E.O.toExn("Unsafe operation"); let unsafeLast = (t: t) => last(t) |> E.O.toExn("Unsafe operation");
let zip = ({xs, ys}: t) => Belt.Array.zip(xs, ys); let zip = ({xs, ys}: t) => Belt.Array.zip(xs, ys);
let getBy = (t: t, fn) => t |> zip |> Belt.Array.getBy(_, fn); let getBy = (t: t, fn) => t |> zip |> Belt.Array.getBy(_, fn);
let firstPairAtOrBeforeValue = (xValue, t: t) => { let firstPairAtOrBeforeValue = (xValue, t: t) => {
let zipped = zip(t); let zipped = zip(t);
let firstIndex = let firstIndex =
zipped |> Belt.Array.getIndexBy(_, ((x, y)) => x > xValue); zipped |> Belt.Array.getIndexBy(_, ((x, y)) => x > xValue);
@ -41,9 +43,9 @@ let firstPairAtOrBeforeValue = (xValue, t: t) => {
| Some(n) => Some(n - 1) | Some(n) => Some(n - 1)
}; };
previousIndex |> Belt.Option.flatMap(_, Belt.Array.get(zipped)); previousIndex |> Belt.Option.flatMap(_, Belt.Array.get(zipped));
}; };
module XtoY = { module XtoY = {
let stepwiseIncremental = (f, t: t) => let stepwiseIncremental = (f, t: t) =>
firstPairAtOrBeforeValue(f, t) |> E.O.fmap(((_, y)) => y); firstPairAtOrBeforeValue(f, t) |> E.O.fmap(((_, y)) => y);
@ -53,14 +55,14 @@ module XtoY = {
// TODO: When Roman's PR comes in, fix this bit. This depends on interpolation, obviously. // TODO: When Roman's PR comes in, fix this bit. This depends on interpolation, obviously.
let linear = (f, t: t) => t |> CdfLibrary.Distribution.findY(f); let linear = (f, t: t) => t |> CdfLibrary.Distribution.findY(f);
}; };
let pointwiseMap = (fn, t: t): t => {xs: t.xs, ys: t.ys |> E.A.fmap(fn)}; let pointwiseMap = (fn, t: t): t => {xs: t.xs, ys: t.ys |> E.A.fmap(fn)};
let xMap = (fn, t: t): t => {xs: E.A.fmap(fn, t.xs), ys: t.ys}; let xMap = (fn, t: t): t => {xs: E.A.fmap(fn, t.xs), ys: t.ys};
let fromArray = ((xs, ys)): t => {xs, ys}; let fromArray = ((xs, ys)): t => {xs, ys};
let fromArrays = (xs, ys): t => {xs, ys}; let fromArrays = (xs, ys): t => {xs, ys};
module Combine = { module Combine = {
let combineLinear = (t1: t, t2: t, fn: (float, float) => float) => { let combineLinear = (t1: t, t2: t, fn: (float, float) => float) => {
let allXs = Belt.Array.concat(xs(t1), xs(t2)); let allXs = Belt.Array.concat(xs(t1), xs(t2));
allXs |> Array.sort(compare); allXs |> Array.sort(compare);
@ -101,27 +103,26 @@ module Combine = {
}); });
fromArrays(allXs, allYs); fromArrays(allXs, allYs);
}; };
}; };
// todo: maybe not needed? // todo: maybe not needed?
// let comparePoint = (a: float, b: float) => a > b ? 1 : (-1); // let comparePoint = (a: float, b: float) => a > b ? 1 : (-1);
let comparePoints = ((x1: float, y1: float), (x2: float, y2: float)) => let comparePoints = ((x1: float, y1: float), (x2: float, y2: float)) =>
switch (x1 == x2, y1 == y2) { switch (x1 == x2, y1 == y2) {
| (false, _) => compare(x1, x2) | (false, _) => compare(x1, x2)
| (true, false) => compare(y1, y2) | (true, false) => compare(y1, y2)
| (true, true) => (-1) | (true, true) => (-1)
}; };
// todo: This is broken :( // todo: This is broken :(
let combine = (t1: t, t2: t) => { let combine = (t1: t, t2: t) => {
let totalLength = E.A.length(t1.xs) + E.A.length(t2.xs);
let array = Belt.Array.concat(zip(t1), zip(t2)); let array = Belt.Array.concat(zip(t1), zip(t2));
Array.sort(comparePoints, array); Array.sort(comparePoints, array);
array |> Belt.Array.unzip |> fromArray; array |> Belt.Array.unzip |> fromArray;
}; };
let intersperce = (t1: t, t2: t) => { let intersperce = (t1: t, t2: t) => {
let items: ref(array((float, float))) = ref([||]); let items: ref(array((float, float))) = ref([||]);
let t1 = zip(t1); let t1 = zip(t1);
let t2 = zip(t2); let t2 = zip(t2);
@ -133,37 +134,39 @@ let intersperce = (t1: t, t2: t) => {
} }
}); });
items^ |> Belt.Array.unzip |> fromArray; items^ |> Belt.Array.unzip |> fromArray;
}; };
let yFold = (fn, t: t) => { let yFold = (fn, t: t) => {
E.A.fold_left(fn, 0., t.ys); E.A.fold_left(fn, 0., t.ys);
}; };
let ySum = yFold((a, b) => a +. b); let ySum = yFold((a, b) => a +. b);
let _transverse = fn => let _transverse = fn =>
Belt.Array.reduce(_, [||], (items, (x, y)) => Belt.Array.reduce(_, [||], (items, (x, y)) =>
switch (E.A.last(items)) { switch (E.A.last(items)) {
| Some((_, yLast)) => Belt.Array.concat(items, [|(x, fn(y, yLast))|]) | Some((_, yLast)) =>
Belt.Array.concat(items, [|(x, fn(y, yLast))|])
| None => [|(x, y)|] | None => [|(x, y)|]
} }
); );
let _transverseShape = (fn, p: t) => { let _transverseShape = (fn, p: t) => {
Belt.Array.zip(p.xs, p.ys) Belt.Array.zip(p.xs, p.ys)
|> _transverse(fn) |> _transverse(fn)
|> Belt.Array.unzip |> Belt.Array.unzip
|> fromArray; |> fromArray;
}; };
let filter = (fn, t: t) => let filter = (fn, t: t) =>
t |> zip |> E.A.filter(fn) |> Belt.Array.unzip |> fromArray; t |> zip |> E.A.filter(fn) |> Belt.Array.unzip |> fromArray;
let accumulateYs = _transverseShape((aCurrent, aLast) => aCurrent +. aLast); let accumulateYs = _transverseShape((aCurrent, aLast) => aCurrent +. aLast);
let subtractYs = _transverseShape((aCurrent, aLast) => aCurrent -. aLast); let subtractYs = _transverseShape((aCurrent, aLast) => aCurrent -. aLast);
let findY = CdfLibrary.Distribution.findY; let findY = CdfLibrary.Distribution.findY;
let findX = CdfLibrary.Distribution.findX; let findX = CdfLibrary.Distribution.findX;
};
// I'm really not sure this part is actually what we want at this point. // I'm really not sure this part is actually what we want at this point.
module Range = { module Range = {
@ -171,7 +174,7 @@ module Range = {
type zippedRange = ((float, float), (float, float)); type zippedRange = ((float, float), (float, float));
let floatSum = Belt.Array.reduce(_, 0., (a, b) => a +. b); let floatSum = Belt.Array.reduce(_, 0., (a, b) => a +. b);
let toT = r => r |> Belt.Array.unzip |> fromArray; let toT = r => r |> Belt.Array.unzip |> T.fromArray;
let nextX = ((_, (nextX, _)): zippedRange) => nextX; let nextX = ((_, (nextX, _)): zippedRange) => nextX;
let rangePointAssumingSteps = let rangePointAssumingSteps =
@ -197,21 +200,21 @@ module Range = {
let integrateWithTriangles = z => { let integrateWithTriangles = z => {
let rangeItems = mapYsBasedOnRanges(rangeAreaAssumingTriangles, z); let rangeItems = mapYsBasedOnRanges(rangeAreaAssumingTriangles, z);
( (
switch (rangeItems, z |> first) { switch (rangeItems, z |> T.first) {
| (Some(r), Some((firstX, _))) => | (Some(r), Some((firstX, _))) =>
Some(Belt.Array.concat([|(firstX, 0.0)|], r)) Some(Belt.Array.concat([|(firstX, 0.0)|], r))
| _ => None | _ => None
} }
) )
|> E.O.fmap(toT) |> E.O.fmap(toT)
|> E.O.fmap(accumulateYs); |> E.O.fmap(T.accumulateYs);
}; };
let derivative = mapYsBasedOnRanges(delta_y_over_delta_x); let derivative = mapYsBasedOnRanges(delta_y_over_delta_x);
// TODO: It would be nicer if the diff didn't change the first element, and also maybe if there were a more elegant way of doing this. // TODO: It would be nicer if the diff didn't change the first element, and also maybe if there were a more elegant way of doing this.
let stepsToContinuous = t => { let stepsToContinuous = t => {
let diff = xTotalRange(t) |> E.O.fmap(r => r *. 0.00001); let diff = T.xTotalRange(t) |> E.O.fmap(r => r *. 0.00001);
let items = let items =
switch (diff, E.A.toRanges(Belt.Array.zip(t.xs, t.ys))) { switch (diff, E.A.toRanges(Belt.Array.zip(t.xs, t.ys))) {
| (Some(diff), Ok(items)) => | (Some(diff), Ok(items)) =>
@ -219,21 +222,57 @@ module Range = {
items items
|> Belt.Array.map(_, rangePointAssumingSteps) |> Belt.Array.map(_, rangePointAssumingSteps)
|> Belt.Array.unzip |> Belt.Array.unzip
|> fromArray |> T.fromArray
|> intersperce(t |> xMap(e => e +. diff)), |> T.intersperce(t |> T.xMap(e => e +. diff)),
) )
| _ => Some(t) | _ => Some(t)
}; };
let bar = items |> E.O.fmap(zip) |> E.O.bind(_, E.A.get(_, 0)); let bar = items |> E.O.fmap(T.zip) |> E.O.bind(_, E.A.get(_, 0));
let items = let items =
switch (items, bar) { switch (items, bar) {
| (Some(items), Some((0.0, _))) => Some(items) | (Some(items), Some((0.0, _))) => Some(items)
| (Some(items), Some((firstX, _))) => | (Some(items), Some((firstX, _))) =>
let all = E.A.append([|(firstX, 0.0)|], items |> zip); let all = E.A.append([|(firstX, 0.0)|], items |> T.zip);
let foo = all |> Belt.Array.unzip |> fromArray; let foo = all |> Belt.Array.unzip |> T.fromArray;
Some(foo); Some(foo);
| _ => None | _ => None
}; };
items; items;
}; };
}; };
// Helpers that operate on several shapes at once (`T.ts`, an array of
// xyShape).
module Ts = {
  type t = T.ts;

  // Applies `pickBound` (T.minX or T.maxX, each returning an option) to
  // every shape and passes the results through E.A.O.concatSomes
  // (presumably keeping only the Some values -- confirm in E).
  let _xBounds = (pickBound, shapes: t) =>
    shapes |> E.A.fmap(pickBound) |> E.A.O.concatSomes;

  // Functions.min / Functions.max reduce the collected per-shape bounds to
  // a single value (their behaviour on an empty array is not visible here).
  let minX = (t: t) => t |> _xBounds(T.minX) |> Functions.min;
  let maxX = (t: t) => t |> _xBounds(T.maxX) |> Functions.max;

  // TODO/Warning: This will break if the shapes are empty.
  // Builds `newLength` x positions between the overall min and max x of the
  // shapes via Functions.range (presumably evenly spaced -- confirm).
  let equallyDividedXs = (t: t, newLength) => {
    let lowest = minX(t);
    let highest = maxX(t);
    Functions.range(lowest, highest, newLength);
  };
};
// Builds a new shape by combining two shapes point by point.
// The x positions come from Ts.equallyDividedXs over both shapes with
// `sampleCount`; at each x, `fn` receives the T.XtoY.linear lookup of t1
// first and of t2 second, and its result becomes the y at that x.
let combinePointwise = (fn, sampleCount, t1: xyShape, t2: xyShape) => {
  let sharedXs = Ts.equallyDividedXs([|t1, t2|], sampleCount);
  let combinedYAt = x => fn(T.XtoY.linear(x, t1), T.XtoY.linear(x, t2));
  T.fromArrays(sharedXs, E.A.fmap(combinedYAt, sharedXs));
};
// Pointwise log score of two shapes, as a shape. Still expects
// (sampleCount, t1, t2): the first shape supplies `prediction` and the
// second supplies `answer`.
let logScoreDist = {
  // Score for one x position: 0 where the answer is exactly 0, otherwise
  // answer * log2(|prediction / answer|).
  let scoreAt = (prediction, answer) =>
    if (answer == 0.) {
      0.0;
    } else {
      answer *. Js.Math.log2(Js.Math.abs_float(prediction /. answer));
    };
  combinePointwise(scoreAt);
};
// Collapses the pointwise log score of two shapes into a single number:
// the score shape from logScoreDist is passed through
// Range.integrateWithTriangles, accumulated with T.accumulateYs, and the y
// of the last resulting point is returned. Yields None when any of those
// steps yields None.
// NOTE(review): Range.integrateWithTriangles already finishes with
// T.accumulateYs, so the T.accumulateYs below accumulates a second time --
// confirm this is intended before relying on the absolute values.
let logScorePoint = (sampleCount, t1, t2) => {
  let scoreShape = logScoreDist(sampleCount, t1, t2);
  let accumulated =
    scoreShape |> Range.integrateWithTriangles |> E.O.fmap(T.accumulateYs);
  switch (accumulated |> E.O.bind(_, T.last)) {
  | Some((_, lastY)) => Some(lastY)
  | None => None
  };
};