Shape -> PointSetDist

This commit is contained in:
Ozzie Gooen 2022-02-15 16:13:33 -05:00
parent eb5f5245b6
commit 336a5fb57f
17 changed files with 60 additions and 60 deletions

View File

@ -1,7 +1,7 @@
open Jest
open Expect
let shape: PointSetTypes.xyShape = {xs: [1., 4., 8.], ys: [8., 9., 2.]}
// let PointSetDist: PointSetTypes.xyPointSetDist = {xs: [1., 4., 8.], ys: [8., 9., 2.]}
// let makeTest = (~only=false, str, item1, item2) =>
// only
@ -21,15 +21,15 @@ let shape: PointSetTypes.xyShape = {xs: [1., 4., 8.], ys: [8., 9., 2.]}
// expect(item1) |> toBeSoCloseTo(item2, ~digits)
// );
// describe("Shape", () => {
// describe("PointSetDist", () => {
// describe("Continuous", () => {
// open Continuous;
// let continuous = make(`Linear, shape, None);
// let continuous = make(`Linear, PointSetDist, None);
// makeTest("minX", T.minX(continuous), 1.0);
// makeTest("maxX", T.maxX(continuous), 8.0);
// makeTest(
// "mapY",
// T.mapY(r => r *. 2.0, continuous) |> getShape |> (r => r.ys),
// T.mapY(r => r *. 2.0, continuous) |> getPointSetDist |> (r => r.ys),
// [|16., 18.0, 4.0|],
// );
// describe("xToY", () => {
@ -57,7 +57,7 @@ let shape: PointSetTypes.xyShape = {xs: [1., 4., 8.], ys: [8., 9., 2.]}
// );
// });
// describe("when Stepwise", () => {
// let continuous = make(`Stepwise, shape, None);
// let continuous = make(`Stepwise, PointSetDist, None);
// makeTest(
// "at 4.0",
// T.xToY(4., continuous),
@ -82,7 +82,7 @@ let shape: PointSetTypes.xyShape = {xs: [1., 4., 8.], ys: [8., 9., 2.]}
// });
// makeTest(
// "integral",
// T.Integral.get(~cache=None, continuous) |> getShape,
// T.Integral.get(~cache=None, continuous) |> getPointSetDist,
// {xs: [|1.0, 4.0, 8.0|], ys: [|0.0, 25.5, 47.5|]},
// );
// makeTest(
@ -90,7 +90,7 @@ let shape: PointSetTypes.xyShape = {xs: [1., 4., 8.], ys: [8., 9., 2.]}
// {
// let continuous =
// make(`Stepwise, {xs: [|1., 4., 8.|], ys: [|0.1, 5., 1.0|]}, None);
// continuous |> toLinear |> E.O.fmap(getShape);
// continuous |> toLinear |> E.O.fmap(getPointSetDist);
// },
// Some({
// xs: [|1.00007, 1.00007, 4.0, 4.00007, 8.0, 8.00007|],
@ -101,7 +101,7 @@ let shape: PointSetTypes.xyShape = {xs: [1., 4., 8.], ys: [8., 9., 2.]}
// "toLinear",
// {
// let continuous = make(`Stepwise, {xs: [|0.0|], ys: [|0.3|]}, None);
// continuous |> toLinear |> E.O.fmap(getShape);
// continuous |> toLinear |> E.O.fmap(getPointSetDist);
// },
// Some({xs: [|0.0|], ys: [|0.3|]}),
// );
@ -131,16 +131,16 @@ let shape: PointSetTypes.xyShape = {xs: [1., 4., 8.], ys: [8., 9., 2.]}
// describe("Discrete", () => {
// open Discrete;
// let shape: PointSetTypes.xyShape = {
// let PointSetDist: PointSetTypes.xyPointSetDist = {
// xs: [|1., 4., 8.|],
// ys: [|0.3, 0.5, 0.2|],
// };
// let discrete = make(shape, None);
// let discrete = make(PointSetDist, None);
// makeTest("minX", T.minX(discrete), 1.0);
// makeTest("maxX", T.maxX(discrete), 8.0);
// makeTest(
// "mapY",
// T.mapY(r => r *. 2.0, discrete) |> (r => getShape(r).ys),
// T.mapY(r => r *. 2.0, discrete) |> (r => getPointSetDist(r).ys),
// [|0.6, 1.0, 0.4|],
// );
// makeTest(
@ -209,11 +209,11 @@ let shape: PointSetTypes.xyShape = {xs: [1., 4., 8.], ys: [8., 9., 2.]}
// describe("Mixed", () => {
// open Distributions.Mixed;
// let discreteShape: PointSetTypes.xyShape = {
// let discretePointSetDist: PointSetTypes.xyPointSetDist = {
// xs: [|1., 4., 8.|],
// ys: [|0.3, 0.5, 0.2|],
// };
// let discrete = Discrete.make(discreteShape, None);
// let discrete = Discrete.make(discretePointSetDist, None);
// let continuous =
// Continuous.make(
// `Linear,
@ -309,11 +309,11 @@ let shape: PointSetTypes.xyShape = {xs: [1., 4., 8.], ys: [8., 9., 2.]}
// describe("Distplus", () => {
// open DistPlus;
// let discreteShape: PointSetTypes.xyShape = {
// let discretePointSetDist: PointSetTypes.xyPointSetDist = {
// xs: [|1., 4., 8.|],
// ys: [|0.3, 0.5, 0.2|],
// };
// let discrete = Discrete.make(discreteShape, None);
// let discrete = Discrete.make(discretePointSetDist, None);
// let continuous =
// Continuous.make(
// `Linear,
@ -328,7 +328,7 @@ let shape: PointSetTypes.xyShape = {xs: [1., 4., 8.], ys: [8., 9., 2.]}
// );
// let distPlus =
// DistPlus.make(
// ~shape=Mixed(mixed),
// ~PointSetDist=Mixed(mixed),
// ~squiggleString=None,
// (),
// );
@ -376,38 +376,38 @@ let shape: PointSetTypes.xyShape = {xs: [1., 4., 8.], ys: [8., 9., 2.]}
// );
// });
// describe("Shape", () => {
// describe("PointSetDist", () => {
// let mean = 10.0;
// let stdev = 4.0;
// let variance = stdev ** 2.0;
// let numSamples = 10000;
// open Distributions.Shape;
// open Distributions.PointSetDist;
// let normal: SymbolicDistTypes.symbolicDist = `Normal({mean, stdev});
// let normalShape = AST.toShape(numSamples, `SymbolicDist(normal));
// let normalPointSetDist = AST.toPointSetDist(numSamples, `SymbolicDist(normal));
// let lognormal = SymbolicDist.Lognormal.fromMeanAndStdev(mean, stdev);
// let lognormalShape = AST.toShape(numSamples, `SymbolicDist(lognormal));
// let lognormalPointSetDist = AST.toPointSetDist(numSamples, `SymbolicDist(lognormal));
// makeTestCloseEquality(
// "Mean of a normal",
// T.mean(normalShape),
// T.mean(normalPointSetDist),
// mean,
// ~digits=2,
// );
// makeTestCloseEquality(
// "Variance of a normal",
// T.variance(normalShape),
// T.variance(normalPointSetDist),
// variance,
// ~digits=1,
// );
// makeTestCloseEquality(
// "Mean of a lognormal",
// T.mean(lognormalShape),
// T.mean(lognormalPointSetDist),
// mean,
// ~digits=2,
// );
// makeTestCloseEquality(
// "Variance of a lognormal",
// T.variance(lognormalShape),
// T.variance(lognormalPointSetDist),
// variance,
// ~digits=0,
// );

View File

@ -15,7 +15,7 @@ let evalParams: ASTTypes.AST.evaluationParams = {
sampleCount: 1000,
outputXYPoints: 10000,
kernelWidth: None,
shapeLength: 1000,
PointSetDistLength: 1000,
},
environment:
[|
@ -28,9 +28,9 @@ let evalParams: ASTTypes.AST.evaluationParams = {
evaluateNode: ASTEvaluator.toLeaf,
};
let shape1: PointSetTypes.xyShape = {xs: [|1., 4., 8.|], ys: [|0.2, 0.4, 0.8|]};
let PointSetDist1: PointSetTypes.xyPointSetDist = {xs: [|1., 4., 8.|], ys: [|0.2, 0.4, 0.8|]};
describe("XYShapes", () => {
describe("XYPointSetDists", () => {
describe("logScorePoint", () => {
makeTest(
"When identical",

View File

@ -6,25 +6,25 @@ let makeTest = (~only=false, str, item1, item2) =>
? Only.test(str, () => expect(item1) -> toEqual(item2))
: test(str, () => expect(item1) -> toEqual(item2))
let shape1: PointSetTypes.xyShape = {xs: [1., 4., 8.], ys: [0.2, 0.4, 0.8]}
let pointSetDist1: PointSetTypes.xyShape = {xs: [1., 4., 8.], ys: [0.2, 0.4, 0.8]}
let shape2: PointSetTypes.xyShape = {
let pointSetDist2: PointSetTypes.xyShape = {
xs: [1., 5., 10.],
ys: [0.2, 0.5, 0.8],
}
let shape3: PointSetTypes.xyShape = {
let pointSetDist3: PointSetTypes.xyShape = {
xs: [1., 20., 50.],
ys: [0.2, 0.5, 0.8],
}
describe("XYShapes", () => {
describe("logScorePoint", () => {
makeTest("When identical", XYShape.logScorePoint(30, shape1, shape1), Some(0.0))
makeTest("When similar", XYShape.logScorePoint(30, shape1, shape2), Some(1.658971191043856))
makeTest("When identical", XYShape.logScorePoint(30, pointSetDist1, pointSetDist1), Some(0.0))
makeTest("When similar", XYShape.logScorePoint(30, pointSetDist1, pointSetDist2), Some(1.658971191043856))
makeTest(
"When very different",
XYShape.logScorePoint(30, shape1, shape3),
XYShape.logScorePoint(30, pointSetDist1, pointSetDist3),
Some(210.3721280423322),
)
})
@ -41,7 +41,7 @@ describe("XYShapes", () => {
describe("integrateWithTriangles", () =>
makeTest(
"integrates correctly",
XYShape.Range.integrateWithTriangles(shape1),
XYShape.Range.integrateWithTriangles(pointSetDist1),
Some({
xs: [1., 4., 8.],
ys: [0.0, 0.9000000000000001, 3.3000000000000007],

View File

@ -27,7 +27,7 @@ module AlgebraicCombination = {
E.R.merge(
Render.ensureIsRenderedAndGetShape(evaluationParams, t1),
Render.ensureIsRenderedAndGetShape(evaluationParams, t2),
) |> E.R.fmap(((a, b)) => #RenderedDist(Shape.combineAlgebraically(algebraicOp, a, b)))
) |> E.R.fmap(((a, b)) => #RenderedDist(PointSetDist.combineAlgebraically(algebraicOp, a, b)))
let nodeScore: node => int = x =>
switch x {
@ -76,7 +76,7 @@ module PointwiseCombination = {
| (Ok(#RenderedDist(rs1)), Ok(#RenderedDist(rs2))) =>
Ok(
#RenderedDist(
Shape.combinePointwise(
PointSetDist.combinePointwise(
~integralSumCachesFn=(a, b) => Some(a +. b),
~integralCachesFn=(a, b) => Some(
Continuous.combinePointwise(~distributionType=#CDF, \"+.", a, b),
@ -98,7 +98,7 @@ module PointwiseCombination = {
// TODO: This should work for symbolic distributions too!
(Render.render(evaluationParams, t1), Render.render(evaluationParams, t2)) {
| (Ok(#RenderedDist(rs1)), Ok(#RenderedDist(rs2))) =>
Ok(#RenderedDist(Shape.combinePointwise(fn, rs1, rs2)))
Ok(#RenderedDist(PointSetDist.combinePointwise(fn, rs1, rs2)))
| (Error(e1), _) => Error(e1)
| (_, Error(e2)) => Error(e2)
| _ => Error("Pointwise combination: rendering failed.")
@ -132,7 +132,7 @@ module Truncate = {
switch // TODO: use named args for xMin/xMax in renderToShape; if we're lucky we can at least get the tail
// of a distribution we otherwise wouldn't get at all
Render.ensureIsRendered(evaluationParams, t) {
| Ok(#RenderedDist(rs)) => Ok(#RenderedDist(Shape.T.truncate(leftCutoff, rightCutoff, rs)))
| Ok(#RenderedDist(rs)) => Ok(#RenderedDist(PointSetDist.T.truncate(leftCutoff, rightCutoff, rs)))
| Error(e) => Error(e)
| _ => Error("Could not truncate distribution.")
}
@ -158,7 +158,7 @@ module Truncate = {
module Normalize = {
let rec operationToLeaf = (evaluationParams, t: node): result<node, string> =>
switch t {
| #RenderedDist(s) => Ok(#RenderedDist(Shape.T.normalize(s)))
| #RenderedDist(s) => Ok(#RenderedDist(PointSetDist.T.normalize(s)))
| #SymbolicDist(_) => Ok(t)
| _ => evaluateAndRetry(evaluationParams, operationToLeaf, t)
}

View File

@ -98,7 +98,7 @@ module SamplingDistribution = {
)
let sampleN = n =>
map(~renderedDistFn=Shape.sampleNRendered(n), ~symbolicDistFn=SymbolicDist.T.sampleN(n))
map(~renderedDistFn=PointSetDist.sampleNRendered(n), ~symbolicDistFn=SymbolicDist.T.sampleN(n))
let getCombinationSamples = (n, algebraicOp, t1: node, t2: node) =>
switch (sampleN(n, t1), sampleN(n, t2)) {

View File

@ -90,7 +90,7 @@ let floatFromDist = (
switch t {
| #SymbolicDist(s) =>
SymbolicDist.T.operate(distToFloatOp, s) |> E.R.bind(_, v => Ok(#SymbolicDist(#Float(v))))
| #RenderedDist(rs) => Shape.operate(distToFloatOp, rs) |> (v => Ok(#SymbolicDist(#Float(v))))
| #RenderedDist(rs) => PointSetDist.operate(distToFloatOp, rs) |> (v => Ok(#SymbolicDist(#Float(v))))
}
let verticalScaling = (scaleOp, rs, scaleBy) => {
@ -100,7 +100,7 @@ let verticalScaling = (scaleOp, rs, scaleBy) => {
let integralCacheFn = Operation.Scale.toIntegralCacheFn(scaleOp)
Ok(
#RenderedDist(
Shape.T.mapY(
PointSetDist.T.mapY(
~integralSumCacheFn=integralSumCacheFn(scaleBy),
~integralCacheFn=integralCacheFn(scaleBy),
~fn=fn(scaleBy),
@ -209,7 +209,7 @@ let all = [
~run=x =>
switch x {
| [#SamplingDist(#SymbolicDist(c))] => Ok(#SymbolicDist(c))
| [#SamplingDist(#RenderedDist(c))] => Ok(#RenderedDist(Shape.T.normalize(c)))
| [#SamplingDist(#RenderedDist(c))] => Ok(#RenderedDist(PointSetDist.T.normalize(c)))
| e => wrongInputsError(e)
},
(),

View File

@ -39,7 +39,7 @@ module TypedValue = {
let rec toString: typedValue => string = x =>
switch x {
| #SamplingDist(_) => "[sampling dist]"
| #RenderedDist(_) => "[rendered Shape]"
| #RenderedDist(_) => "[rendered PointSetDist]"
| #Float(f) => "Float: " ++ Js.Float.toString(f)
| #Array(a) => "[" ++ ((a |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ "]")
| #Hash(v) =>

View File

@ -2,7 +2,7 @@ open PointSetTypes;
type t = PointSetTypes.distPlus;
let shapeIntegral = shape => Shape.T.Integral.get(shape);
let shapeIntegral = shape => PointSetDist.T.Integral.get(shape);
let make =
(
~shape,
@ -52,11 +52,11 @@ module T =
type t = PointSetTypes.distPlus;
type integral = PointSetTypes.distPlus;
let toShape = toShape;
let toContinuous = shapeFn(Shape.T.toContinuous);
let toDiscrete = shapeFn(Shape.T.toDiscrete);
let toContinuous = shapeFn(PointSetDist.T.toContinuous);
let toDiscrete = shapeFn(PointSetDist.T.toDiscrete);
let normalize = (t: t): t => {
let normalizedShape = t |> toShape |> Shape.T.normalize;
let normalizedShape = t |> toShape |> PointSetDist.T.normalize;
t |> updateShape(normalizedShape);
};
@ -64,7 +64,7 @@ module T =
let truncatedShape =
t
|> toShape
|> Shape.T.truncate(leftCutoff, rightCutoff);
|> PointSetDist.T.truncate(leftCutoff, rightCutoff);
t |> updateShape(truncatedShape);
};
@ -72,13 +72,13 @@ module T =
let xToY = (f, t: t) =>
t
|> toShape
|> Shape.T.xToY(f)
|> PointSetDist.T.xToY(f)
|> MixedPoint.fmap(domainIncludedProbabilityMassAdjustment(t));
let minX = shapeFn(Shape.T.minX);
let maxX = shapeFn(Shape.T.maxX);
let minX = shapeFn(PointSetDist.T.minX);
let maxX = shapeFn(PointSetDist.T.maxX);
let toDiscreteProbabilityMassFraction =
shapeFn(Shape.T.toDiscreteProbabilityMassFraction);
shapeFn(PointSetDist.T.toDiscreteProbabilityMassFraction);
// This bit is kind of awkward, could probably use rethinking.
let integral = (t: t) =>
@ -88,7 +88,7 @@ module T =
update(~integralCache=E.O.default(t.integralCache, integralCache), t);
let downsample = (i, t): t =>
updateShape(t |> toShape |> Shape.T.downsample(i), t);
updateShape(t |> toShape |> PointSetDist.T.downsample(i), t);
// todo: adjust for limit, maybe?
let mapY =
(
@ -98,19 +98,19 @@ module T =
{shape, _} as t: t,
)
: t =>
Shape.T.mapY(~integralSumCacheFn, ~fn, shape)
PointSetDist.T.mapY(~integralSumCacheFn, ~fn, shape)
|> updateShape(_, t);
// get the total of everything
let integralEndY = (t: t) => {
Shape.T.Integral.sum(
PointSetDist.T.Integral.sum(
toShape(t),
);
};
// TODO: Fix this below, obviously. Adjust for limits
let integralXtoY = (f, t: t) => {
Shape.T.Integral.xToY(
PointSetDist.T.Integral.xToY(
f,
toShape(t),
)
@ -119,11 +119,11 @@ module T =
// TODO: This part is broken when there is a limit, if this is supposed to be taken into account.
let integralYtoX = (f, t: t) => {
Shape.T.Integral.yToX(f, toShape(t));
PointSetDist.T.Integral.yToX(f, toShape(t));
};
let mean = (t: t) => {
Shape.T.mean(t.shape);
PointSetDist.T.mean(t.shape);
};
let variance = (t: t) => Shape.T.variance(t.shape);
let variance = (t: t) => PointSetDist.T.variance(t.shape);
});