Slightly cleaned up tree evaluation

This commit is contained in:
Sebastian Kosch 2020-06-13 18:46:38 -07:00
parent 9b10452156
commit 8827650da3
2 changed files with 115 additions and 83 deletions

View File

@ -197,6 +197,7 @@ module MathAdtToDistDst = {
};
let arrayParser = (args:array(arg)):result(SymbolicDist.distTree, string) => {
Js.log2("SAMPLING NOW!", args);
let samples = args
|> E.A.fmap(
fun
@ -287,6 +288,27 @@ module MathAdtToDistDst = {
| [|Ok(l), Ok(r)|] => Ok(`Combination(l, r, `DivideOperation))
| _ => Error("Division needs two operands"))
}
| Fn({name: "pow", args}) => {
args
|> E.A.fmap(functionParser)
|> (fun
| [|Ok(l), Ok(r)|] => Ok(`Combination(l, r, `ExponentiateOperation))
| _ => Error("Exponentiations needs two operands"))
}
| Fn({name: "leftTruncate", args}) => {
args
|> E.A.fmap(functionParser)
|> (fun
| [|Ok(l), Ok(`Distribution(`Float(r)))|] => Ok(`LeftTruncate(l, r))
| _ => Error("leftTruncate needs two arguments: the expression and the cutoff"))
}
| Fn({name: "rightTruncate", args}) => {
args
|> E.A.fmap(functionParser)
|> (fun
| [|Ok(l), Ok(`Distribution(`Float(r)))|] => Ok(`RightTruncate(l, r))
| _ => Error("rightTruncate needs two arguments: the expression and the cutoff"))
}
| Fn({name}) => Error(name ++ ": function not supported")
| _ => {
Error("This type not currently supported");

View File

@ -277,34 +277,34 @@ module GenericSimple = {
/* This function returns a list of x's at which to evaluate the overall distribution (for rendering).
This function is called separately for each individual distribution.
When called with xSelection=`Linear, this function will return (sampleCount) x's, evenly
When called with xSelection=`Linear, this function will return (n) x's, evenly
distributed between the min and max of the distribution (whatever those are defined to be above).
When called with xSelection=`ByWeight, this function will distribute the x's such as to
match the cumulative shape of the distribution. This is slower but may give better results.
*/
let interpolateXs =
(~xSelection: [ | `Linear | `ByWeight]=`Linear, dist: dist, sampleCount) => {
(~xSelection: [ | `Linear | `ByWeight]=`Linear, dist: dist, n) => {
switch (xSelection, dist) {
| (`Linear, _) => E.A.Floats.range(min(dist), max(dist), sampleCount)
| (`Linear, _) => E.A.Floats.range(min(dist), max(dist), n)
| (`ByWeight, `Uniform(n)) =>
// In `ByWeight mode, uniform distributions get special treatment because we need two x's
// on either side for proper rendering (just left and right of the discontinuities).
let dx = 0.00001 *. (n.high -. n.low);
[|n.low -. dx, n.low +. dx, n.high -. dx, n.high +. dx|];
| (`ByWeight, _) =>
let ys = E.A.Floats.range(minCdfValue, maxCdfValue, sampleCount);
let ys = E.A.Floats.range(minCdfValue, maxCdfValue, n);
ys |> E.A.fmap(y => inv(y, dist));
};
};
let toShape =
(~xSelection: [ | `Linear | `ByWeight]=`Linear, dist: dist, sampleCount)
(~xSelection: [ | `Linear | `ByWeight]=`Linear, dist: dist, n)
: DistTypes.shape => {
switch (dist) {
| `ContinuousShape(n) => n.pdf |> Distributions.Continuous.T.toShape
| dist =>
let xs = interpolateXs(~xSelection, dist, sampleCount);
let xs = interpolateXs(~xSelection, dist, n);
let ys = xs |> E.A.fmap(r => pdf(r, dist));
XYShape.T.fromArrays(xs, ys)
|> Distributions.Continuous.make(`Linear, _)
@ -321,23 +321,43 @@ module DistTree = {
];
let evaluateDistribution = (d: dist): nodeResult => {
// certain distributions we may want to evaluate to RenderedShapes right away, e.g. discrete
`Distribution(d)
};
// This is a performance bottleneck!
// Using raw JS here so we can use native for loops and access array elements
// directly, without option checks.
let jsCombinationConvolve: (array(float), array(float), array(float), array(float), float => float => float) => (array(float), array(float)) = [%bs.raw
let jsContinuousCombinationConvolve: (array(float), array(float), array(float), array(float), float => float => float) => array(array((float, float))) = [%bs.raw
{|
function (s1xs, s1ys, s2xs, s2ys, func) {
// For continuous-continuous convolution, use linear interpolation.
// Let's assume we got downsampled distributions
const outXYShapes = new Array(s1xs.length);
for (let i = 0; i < s1xs.length; i++) {
// create a new distribution
const dxyShape = new Array(s2xs.length);
for (let j = 0; j < s2xs.length; j++) {
dxyShape[j] = [func(s1xs[i], s2xs[j]), (s1ys[i] * s2ys[j])];
}
outXYShapes[i] = dxyShape;
}
return outXYShapes;
}
|}];
let jsDiscreteCombinationConvolve: (array(float), array(float), array(float), array(float), float => float => float) => (array(float), array(float)) = [%bs.raw
{|
function (s1xs, s1ys, s2xs, s2ys, func) {
const r = new Map();
// To convolve, add the xs and multiply the ys:
for (let i = 0; i < s1xs.length; i++) {
for (let j = 0; j < s2xs.length; j++) {
const x = func(s1xs[i], s2xs[j]);
const cv = r.get(x) | 0;
r.set(x, cv + s1ys[i] * s2ys[j]); // add up the ys, if same x
}
}
@ -367,12 +387,12 @@ module DistTree = {
}
}
let renderDistributionToXYShape = (d: dist, sampleCount: int): (DistTypes.continuousShape, DistTypes.discreteShape) => {
let renderDistributionToXYShape = (d: dist, n: int): (DistTypes.continuousShape, DistTypes.discreteShape) => {
// render the distribution into an XY shape
switch (d) {
| `Float(v) => (Distributions.Continuous.empty, {xs: [|v|], ys: [|1.0|]})
| _ => {
let xs = GenericSimple.interpolateXs(~xSelection=`ByWeight, d, sampleCount);
let xs = GenericSimple.interpolateXs(~xSelection=`ByWeight, d, n);
let ys = xs |> E.A.fmap(x => GenericSimple.pdf(x, d));
(Distributions.Continuous.make(`Linear, {xs: xs, ys: ys}), XYShape.T.empty)
}
@ -384,23 +404,37 @@ module DistTree = {
sc2: DistTypes.continuousShape,
sd2: DistTypes.discreteShape, func): (DistTypes.continuousShape, DistTypes.discreteShape) => {
let (ccxs, ccys) = jsCombinationConvolve(sc1.xyShape.xs, sc1.xyShape.ys, sc2.xyShape.xs, sc2.xyShape.ys, func);
let (dcxs, dcys) = jsCombinationConvolve(sd1.xs, sd1.ys, sc2.xyShape.xs, sc2.xyShape.ys, func);
let (cdxs, cdys) = jsCombinationConvolve(sc1.xyShape.xs, sc1.xyShape.ys, sd2.xs, sd2.ys, func);
let (ddxs, ddys) = jsCombinationConvolve(sd1.xs, sd1.ys, sd2.xs, sd2.ys, func);
let ccxy = Distributions.Continuous.make(`Linear, {xs: ccxs, ys: ccys});
let dcxy = Distributions.Continuous.make(`Linear, {xs: dcxs, ys: dcys});
let cdxy = Distributions.Continuous.make(`Linear, {xs: cdxs, ys: cdys});
// the continuous parts are added up; only the discrete-discrete sum is discrete
let continuousShapeSum = Distributions.Continuous.reduce((+.), [|ccxy, dcxy, cdxy|]);
// First, deal with the discrete-discrete convolution:
let (ddxs, ddys) = jsDiscreteCombinationConvolve(sd1.xs, sd1.ys, sd2.xs, sd2.ys, func);
let ddxy: DistTypes.discreteShape = {xs: cdxs, ys: cdys};
// Then, do the other three:
let downsample = (sc: DistTypes.continuousShape) => {
let scLength = E.A.length(sc.xyShape.xs);
let scSqLength = sqrt(float_of_int(scLength));
scSqLength > 10. ? Distributions.Continuous.T.truncate(int_of_float(scSqLength), sc) : sc;
};
let combinePointConvolutionResults = ccs
|> E.A.fmap(s => {
// s is an array of (x, y) objects
let (xs, ys) = Belt.Array.unzip(s);
Distributions.Continuous.make(`Linear, {xs, ys});
})
|> Distributions.Continuous.reduce((+.));
let sc1d = downsample(sc1);
let sc2d = downsample(sc2);
let ccxy = jsContinuousCombinationConvolve(sc1d.xyShape.xs, sc1d.xyShape.ys, sc2d.xyShape.xs, sc2d.xyShape.ys, func) |> combinePointConvolutionResults;
let dcxy = jsContinuousCombinationConvolve(sc1d.xyShape.xs, sc1d.xyShape.ys, sc2d.xyShape.xs, sc2d.xyShape.ys, func) |> combinePointConvolutionResults;
let cdxy = jsContinuousCombinationConvolve(sc1d.xyShape.xs, sc1d.xyShape.ys, sc2d.xyShape.xs, sc2d.xyShape.ys, func) |> combinePointConvolutionResults;
let continuousShapeSum = Distributions.Continuous.reduce((+.), [|ccxy, dcxy, cdxy|]);
(continuousShapeSum, ddxy)
};
let evaluateCombinationDistribution = (et1: nodeResult, et2: nodeResult, op: operation, sampleCount: int) => {
let evaluateCombinationDistribution = (et1: nodeResult, et2: nodeResult, op: operation, n: int) => {
/* return either a Distribution or a RenderedShape. Must integrate to 1. */
let func = funcFromOp(op);
@ -439,31 +473,26 @@ module DistTree = {
/* General cases: convolve the XYShapes */
| (`Distribution(d1), `Distribution(d2), _) => {
let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount);
let (sc2, sd2) = renderDistributionToXYShape(d2, sampleCount);
let (sc1, sd1) = renderDistributionToXYShape(d1, n);
let (sc2, sd2) = renderDistributionToXYShape(d2, n);
let (sc, sd) = combinationDistributionOfXYShapes(sc1, sd1, sc2, sd2, func);
`RenderedShape(sc, sd, 1.0)
}
| (`Distribution(d1), `RenderedShape(sc2, sd2, i2), _) => {
let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount);
| (`Distribution(d2), `RenderedShape(sc1, sd1, i1), _)
| (`RenderedShape(sc1, sd1, i1), `Distribution(d2), _) => {
let (sc1, sd1) = renderDistributionToXYShape(d1, n);
let (sc, sd) = combinationDistributionOfXYShapes(sc1, sd1, sc2, sd2, func);
`RenderedShape(sc, sd, i2)
}
| (`RenderedShape(sc1, sd1, i1), `Distribution(d2), _) => {
let (sc2, sd2) = renderDistributionToXYShape(d2, sampleCount);
let (sc, sd) = combinationDistributionOfXYShapes(sc1, sd1, sc2, sd2, func);
`RenderedShape(sc, sd, i1);
}
| (`RenderedShape(sc1, sd1, i1), `RenderedShape(sc2, sd2, i2), _) => {
// sum of two multimodals that have a continuous and discrete each.
let (sc, sd) = combinationDistributionOfXYShapes(sc1, sd1, sc2, sd2, func);
`RenderedShape(sc, sd, i1);
}
}
};
let evaluatePointwiseSum = (et1: nodeResult, et2: nodeResult, sampleCount: int) => {
let evaluatePointwiseSum = (et1: nodeResult, et2: nodeResult, n: int) => {
switch ((et1, et2)) {
/* Known cases: */
| (`Distribution(`Float(v1)), `Distribution(`Float(v2))) => {
@ -471,36 +500,29 @@ module DistTree = {
? `RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.make({xs: [|v1|], ys: [|2.|]}), 2.)
: `RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.) // TODO: add warning: shouldn't pointwise add scalars.
}
| (`Distribution(`Float(v1)), `Distribution(d2)) => {
| (`Distribution(`Float(v1)), `Distribution(d2))
| (`Distribution(d2), `Distribution(`Float(v1))) => {
let sd1: DistTypes.xyShape = {xs: [|v1|], ys: [|1.|]};
let (sc2, sd2) = renderDistributionToXYShape(d2, sampleCount);
let (sc2, sd2) = renderDistributionToXYShape(d2, n);
`RenderedShape(sc2, Distributions.Discrete.reduce((+.), [|sd1, sd2|]), 2.)
}
| (`Distribution(d1), `Distribution(`Float(v2))) => {
let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount);
let sd2: DistTypes.xyShape = {xs: [|v2|], ys: [|1.|]};
`RenderedShape(sc1, Distributions.Discrete.reduce((+.), [|sd1, sd2|]), 2.)
}
| (`Distribution(d1), `Distribution(d2)) => {
let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount);
let (sc2, sd2) = renderDistributionToXYShape(d2, sampleCount);
let (sc1, sd1) = renderDistributionToXYShape(d1, n);
let (sc2, sd2) = renderDistributionToXYShape(d2, n);
`RenderedShape(Distributions.Continuous.reduce((+.), [|sc1, sc2|]), Distributions.Discrete.reduce((+.), [|sd1, sd2|]), 2.)
}
| (`Distribution(d1), `RenderedShape(sc2, sd2, i2))
| (`RenderedShape(sc2, sd2, i2), `Distribution(d1)) => {
let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount);
let (sc1, sd1) = renderDistributionToXYShape(d1, n);
`RenderedShape(Distributions.Continuous.reduce((+.), [|sc1, sc2|]), Distributions.Discrete.reduce((+.), [|sd1, sd2|]), 1. +. i2)
}
| (`RenderedShape(sc1, sd1, i1), `RenderedShape(sc2, sd2, i2)) => {
Js.log3("Reducing continuous rr", sc1, sc2);
Js.log2("Continuous reduction:", Distributions.Continuous.reduce((+.), [|sc1, sc2|]));
Js.log2("Discrete reduction:", Distributions.Discrete.reduce((+.), [|sd1, sd2|]));
`RenderedShape(Distributions.Continuous.reduce((+.), [|sc1, sc2|]), Distributions.Discrete.reduce((+.), [|sd1, sd2|]), i1 +. i2)
}
}
};
let evaluatePointwiseProduct = (et1: nodeResult, et2: nodeResult, sampleCount: int) => {
let evaluatePointwiseProduct = (et1: nodeResult, et2: nodeResult, n: int) => {
switch ((et1, et2)) {
/* Known cases: */
| (`Distribution(`Float(v1)), `Distribution(`Float(v2))) => {
@ -528,20 +550,20 @@ module DistTree = {
| (`Distribution(d1), `Distribution(d2)) => {
// NOT IMPLEMENTED YET
// TODO: evaluate integral properly
let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount);
let (sc2, sd2) = renderDistributionToXYShape(d2, sampleCount);
let (sc1, sd1) = renderDistributionToXYShape(d1, n);
let (sc2, sd2) = renderDistributionToXYShape(d2, n);
`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.)
}
| (`Distribution(d1), `RenderedShape(sc2, sd2, i2)) => {
// NOT IMPLEMENTED YET
// TODO: evaluate integral properly
let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount);
let (sc1, sd1) = renderDistributionToXYShape(d1, n);
`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.)
}
| (`RenderedShape(sc1, sd1, i1), `Distribution(d1)) => {
// NOT IMPLEMENTED YET
// TODO: evaluate integral properly
let (sc2, sd2) = renderDistributionToXYShape(d1, sampleCount);
let (sc2, sd2) = renderDistributionToXYShape(d1, n);
`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.)
}
| (`RenderedShape(sc1, sd1, i1), `RenderedShape(sc2, sd2, i2)) => {
@ -551,7 +573,7 @@ module DistTree = {
};
let evaluateNormalize = (et: nodeResult, sampleCount: int) => {
let evaluateNormalize = (et: nodeResult, n: int) => {
// just divide everything by the integral.
switch (et) {
| `RenderedShape(sc, sd, 0.) => {
@ -570,7 +592,7 @@ module DistTree = {
}
};
let evaluateTruncate = (et: nodeResult, xc: cutoffX, compareFunc: (float, float) => bool, sampleCount: int) => {
let evaluateTruncate = (et: nodeResult, xc: cutoffX, compareFunc: (float, float) => bool, n: int) => {
let cut = (s: DistTypes.xyShape): DistTypes.xyShape => {
let (xs, ys) = s.ys
|> Belt.Array.zip(s.xs)
@ -583,7 +605,7 @@ module DistTree = {
switch (et) {
| `Distribution(d) => {
let (sc, sd) = renderDistributionToXYShape(d, sampleCount);
let (sc, sd) = renderDistributionToXYShape(d, n);
let scc = sc |> Distributions.Continuous.shapeMap(cut);
let sdc = sd |> cut;
@ -603,13 +625,13 @@ module DistTree = {
}
};
let evaluateVerticalScaling = (et1: nodeResult, et2: nodeResult, sampleCount: int) => {
let evaluateVerticalScaling = (et1: nodeResult, et2: nodeResult, n: int) => {
let scale = (i: float, s: DistTypes.xyShape): DistTypes.xyShape => {xs: s.xs, ys: s.ys |> E.A.fmap(y => y *. i)};
switch ((et1, et2)) {
| (`Distribution(`Float(v)), `Distribution(d))
| (`Distribution(d), `Distribution(`Float(v))) => {
let (sc, sd) = renderDistributionToXYShape(d, sampleCount);
let (sc, sd) = renderDistributionToXYShape(d, n);
let scc = sc |> Distributions.Continuous.shapeMap(scale(v));
let sdc = sd |> scale(v);
@ -631,47 +653,34 @@ module DistTree = {
}
}
let renderNode = (et: nodeResult, sampleCount: int) => {
let renderNode = (et: nodeResult, n: int) => {
switch (et) {
| `Distribution(d) => {
let (sc, sd) = renderDistributionToXYShape(d, sampleCount);
let (sc, sd) = renderDistributionToXYShape(d, n);
`RenderedShape(sc, sd, 1.0);
}
| s => s
}
}
let rec evaluateNode = (treeNode: distTree, sampleCount: int): nodeResult => {
let rec evaluateNode = (treeNode: distTree, n: int): nodeResult => {
// returns either a new symbolic distribution
switch (treeNode) {
| `Distribution(d) => evaluateDistribution(d)
| `Combination(t1, t2, op) => evaluateCombinationDistribution(evaluateNode(t1, sampleCount), evaluateNode(t2, sampleCount), op, sampleCount)
| `PointwiseSum(t1, t2) => evaluatePointwiseSum(evaluateNode(t1, sampleCount), evaluateNode(t2, sampleCount), sampleCount)
| `PointwiseProduct(t1, t2) => evaluatePointwiseProduct(evaluateNode(t1, sampleCount), evaluateNode(t2, sampleCount), sampleCount)
| `VerticalScaling(t1, t2) => evaluateVerticalScaling(evaluateNode(t1, sampleCount), evaluateNode(t2, sampleCount), sampleCount)
| `Normalize(t) => evaluateNormalize(evaluateNode(t, sampleCount), sampleCount)
| `LeftTruncate(t, x) => evaluateTruncate(evaluateNode(t, sampleCount), x, (<=), sampleCount)
| `RightTruncate(t, x) => evaluateTruncate(evaluateNode(t, sampleCount), x, (>=), sampleCount)
| `Render(t) => renderNode(evaluateNode(t, sampleCount), sampleCount)
| `Combination(t1, t2, op) => evaluateCombinationDistribution(evaluateNode(t1, n), evaluateNode(t2, n), op, n)
| `PointwiseSum(t1, t2) => evaluatePointwiseSum(evaluateNode(t1, n), evaluateNode(t2, n), n)
| `PointwiseProduct(t1, t2) => evaluatePointwiseProduct(evaluateNode(t1, n), evaluateNode(t2, n), n)
| `VerticalScaling(t1, t2) => evaluateVerticalScaling(evaluateNode(t1, n), evaluateNode(t2, n), n)
| `Normalize(t) => evaluateNormalize(evaluateNode(t, n), n)
| `LeftTruncate(t, x) => evaluateTruncate(evaluateNode(t, n), x, (>=), n)
| `RightTruncate(t, x) => evaluateTruncate(evaluateNode(t, n), x, (<=), n)
| `Render(t) => renderNode(evaluateNode(t, n), n)
}
};
let toShape = (treeNode: distTree, sampleCount: int) => {
/*let continuousXs = findContinuousXs(dists, sampleCount);
continuousXs |> Array.fast_sort(compare);
let toShape = (treeNode: distTree, n: int) => {
let treeShape = evaluateNode(`Render(`Normalize(treeNode)), n);
let (contShapes, distShapes) = accumulateContAndDiscShapes(dists, continuousXs, 1.0);
let combinedContinuous = contShapes
|> E.A.fold_left((shapeAcc: DistTypes.xyShape, shape: DistTypes.xyShape) => {
let ys = E.A.fmapi((i, y) => y +. shape.ys[i], shapeAcc.ys);
{xs: continuousXs, ys: ys}
}, {xs: continuousXs, ys: Array.make(Array.length(continuousXs), 0.0)})
|> Distributions.Continuous.make(`Linear);
let combinedDiscrete = Distributions.Discrete.reduce((+.), distShapes)*/
let treeShape = evaluateNode(`Render(`Normalize(treeNode)), sampleCount);
switch (treeShape) {
| `Distribution(_) => E.O.toExn("No shape found!", None)
| `RenderedShape(sc, sd, _) => {
@ -700,6 +709,7 @@ module DistTree = {
| `Normalize(t) => "normalize(" ++ toString(t) ++ ")"
| `LeftTruncate(t, x) => "leftTruncate(" ++ toString(t) ++ ", " ++ string_of_float(x) ++ ")"
| `RightTruncate(t, x) => "rightTruncate(" ++ toString(t) ++ ", " ++ string_of_float(x) ++ ")"
| `Render(t) => toString(t)
}
};
};