Slightly cleaned up tree evaluation
This commit is contained in:
parent
9b10452156
commit
8827650da3
|
@ -197,6 +197,7 @@ module MathAdtToDistDst = {
|
||||||
};
|
};
|
||||||
|
|
||||||
let arrayParser = (args:array(arg)):result(SymbolicDist.distTree, string) => {
|
let arrayParser = (args:array(arg)):result(SymbolicDist.distTree, string) => {
|
||||||
|
Js.log2("SAMPLING NOW!", args);
|
||||||
let samples = args
|
let samples = args
|
||||||
|> E.A.fmap(
|
|> E.A.fmap(
|
||||||
fun
|
fun
|
||||||
|
@ -287,6 +288,27 @@ module MathAdtToDistDst = {
|
||||||
| [|Ok(l), Ok(r)|] => Ok(`Combination(l, r, `DivideOperation))
|
| [|Ok(l), Ok(r)|] => Ok(`Combination(l, r, `DivideOperation))
|
||||||
| _ => Error("Division needs two operands"))
|
| _ => Error("Division needs two operands"))
|
||||||
}
|
}
|
||||||
|
| Fn({name: "pow", args}) => {
|
||||||
|
args
|
||||||
|
|> E.A.fmap(functionParser)
|
||||||
|
|> (fun
|
||||||
|
| [|Ok(l), Ok(r)|] => Ok(`Combination(l, r, `ExponentiateOperation))
|
||||||
|
| _ => Error("Exponentiations needs two operands"))
|
||||||
|
}
|
||||||
|
| Fn({name: "leftTruncate", args}) => {
|
||||||
|
args
|
||||||
|
|> E.A.fmap(functionParser)
|
||||||
|
|> (fun
|
||||||
|
| [|Ok(l), Ok(`Distribution(`Float(r)))|] => Ok(`LeftTruncate(l, r))
|
||||||
|
| _ => Error("leftTruncate needs two arguments: the expression and the cutoff"))
|
||||||
|
}
|
||||||
|
| Fn({name: "rightTruncate", args}) => {
|
||||||
|
args
|
||||||
|
|> E.A.fmap(functionParser)
|
||||||
|
|> (fun
|
||||||
|
| [|Ok(l), Ok(`Distribution(`Float(r)))|] => Ok(`RightTruncate(l, r))
|
||||||
|
| _ => Error("rightTruncate needs two arguments: the expression and the cutoff"))
|
||||||
|
}
|
||||||
| Fn({name}) => Error(name ++ ": function not supported")
|
| Fn({name}) => Error(name ++ ": function not supported")
|
||||||
| _ => {
|
| _ => {
|
||||||
Error("This type not currently supported");
|
Error("This type not currently supported");
|
||||||
|
|
|
@ -277,34 +277,34 @@ module GenericSimple = {
|
||||||
/* This function returns a list of x's at which to evaluate the overall distribution (for rendering).
|
/* This function returns a list of x's at which to evaluate the overall distribution (for rendering).
|
||||||
This function is called separately for each individual distribution.
|
This function is called separately for each individual distribution.
|
||||||
|
|
||||||
When called with xSelection=`Linear, this function will return (sampleCount) x's, evenly
|
When called with xSelection=`Linear, this function will return (n) x's, evenly
|
||||||
distributed between the min and max of the distribution (whatever those are defined to be above).
|
distributed between the min and max of the distribution (whatever those are defined to be above).
|
||||||
|
|
||||||
When called with xSelection=`ByWeight, this function will distribute the x's such as to
|
When called with xSelection=`ByWeight, this function will distribute the x's such as to
|
||||||
match the cumulative shape of the distribution. This is slower but may give better results.
|
match the cumulative shape of the distribution. This is slower but may give better results.
|
||||||
*/
|
*/
|
||||||
let interpolateXs =
|
let interpolateXs =
|
||||||
(~xSelection: [ | `Linear | `ByWeight]=`Linear, dist: dist, sampleCount) => {
|
(~xSelection: [ | `Linear | `ByWeight]=`Linear, dist: dist, n) => {
|
||||||
switch (xSelection, dist) {
|
switch (xSelection, dist) {
|
||||||
| (`Linear, _) => E.A.Floats.range(min(dist), max(dist), sampleCount)
|
| (`Linear, _) => E.A.Floats.range(min(dist), max(dist), n)
|
||||||
| (`ByWeight, `Uniform(n)) =>
|
| (`ByWeight, `Uniform(n)) =>
|
||||||
// In `ByWeight mode, uniform distributions get special treatment because we need two x's
|
// In `ByWeight mode, uniform distributions get special treatment because we need two x's
|
||||||
// on either side for proper rendering (just left and right of the discontinuities).
|
// on either side for proper rendering (just left and right of the discontinuities).
|
||||||
let dx = 0.00001 *. (n.high -. n.low);
|
let dx = 0.00001 *. (n.high -. n.low);
|
||||||
[|n.low -. dx, n.low +. dx, n.high -. dx, n.high +. dx|];
|
[|n.low -. dx, n.low +. dx, n.high -. dx, n.high +. dx|];
|
||||||
| (`ByWeight, _) =>
|
| (`ByWeight, _) =>
|
||||||
let ys = E.A.Floats.range(minCdfValue, maxCdfValue, sampleCount);
|
let ys = E.A.Floats.range(minCdfValue, maxCdfValue, n);
|
||||||
ys |> E.A.fmap(y => inv(y, dist));
|
ys |> E.A.fmap(y => inv(y, dist));
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
let toShape =
|
let toShape =
|
||||||
(~xSelection: [ | `Linear | `ByWeight]=`Linear, dist: dist, sampleCount)
|
(~xSelection: [ | `Linear | `ByWeight]=`Linear, dist: dist, n)
|
||||||
: DistTypes.shape => {
|
: DistTypes.shape => {
|
||||||
switch (dist) {
|
switch (dist) {
|
||||||
| `ContinuousShape(n) => n.pdf |> Distributions.Continuous.T.toShape
|
| `ContinuousShape(n) => n.pdf |> Distributions.Continuous.T.toShape
|
||||||
| dist =>
|
| dist =>
|
||||||
let xs = interpolateXs(~xSelection, dist, sampleCount);
|
let xs = interpolateXs(~xSelection, dist, n);
|
||||||
let ys = xs |> E.A.fmap(r => pdf(r, dist));
|
let ys = xs |> E.A.fmap(r => pdf(r, dist));
|
||||||
XYShape.T.fromArrays(xs, ys)
|
XYShape.T.fromArrays(xs, ys)
|
||||||
|> Distributions.Continuous.make(`Linear, _)
|
|> Distributions.Continuous.make(`Linear, _)
|
||||||
|
@ -321,23 +321,43 @@ module DistTree = {
|
||||||
];
|
];
|
||||||
|
|
||||||
let evaluateDistribution = (d: dist): nodeResult => {
|
let evaluateDistribution = (d: dist): nodeResult => {
|
||||||
// certain distributions we may want to evaluate to RenderedShapes right away, e.g. discrete
|
|
||||||
`Distribution(d)
|
`Distribution(d)
|
||||||
};
|
};
|
||||||
|
|
||||||
// This is a performance bottleneck!
|
// This is a performance bottleneck!
|
||||||
// Using raw JS here so we can use native for loops and access array elements
|
// Using raw JS here so we can use native for loops and access array elements
|
||||||
// directly, without option checks.
|
// directly, without option checks.
|
||||||
let jsCombinationConvolve: (array(float), array(float), array(float), array(float), float => float => float) => (array(float), array(float)) = [%bs.raw
|
let jsContinuousCombinationConvolve: (array(float), array(float), array(float), array(float), float => float => float) => array(array((float, float))) = [%bs.raw
|
||||||
|
{|
|
||||||
|
function (s1xs, s1ys, s2xs, s2ys, func) {
|
||||||
|
// For continuous-continuous convolution, use linear interpolation.
|
||||||
|
// Let's assume we got downsampled distributions
|
||||||
|
|
||||||
|
const outXYShapes = new Array(s1xs.length);
|
||||||
|
for (let i = 0; i < s1xs.length; i++) {
|
||||||
|
// create a new distribution
|
||||||
|
const dxyShape = new Array(s2xs.length);
|
||||||
|
for (let j = 0; j < s2xs.length; j++) {
|
||||||
|
dxyShape[j] = [func(s1xs[i], s2xs[j]), (s1ys[i] * s2ys[j])];
|
||||||
|
}
|
||||||
|
outXYShapes[i] = dxyShape;
|
||||||
|
}
|
||||||
|
|
||||||
|
return outXYShapes;
|
||||||
|
}
|
||||||
|
|}];
|
||||||
|
|
||||||
|
let jsDiscreteCombinationConvolve: (array(float), array(float), array(float), array(float), float => float => float) => (array(float), array(float)) = [%bs.raw
|
||||||
{|
|
{|
|
||||||
function (s1xs, s1ys, s2xs, s2ys, func) {
|
function (s1xs, s1ys, s2xs, s2ys, func) {
|
||||||
const r = new Map();
|
const r = new Map();
|
||||||
|
|
||||||
// To convolve, add the xs and multiply the ys:
|
|
||||||
for (let i = 0; i < s1xs.length; i++) {
|
for (let i = 0; i < s1xs.length; i++) {
|
||||||
for (let j = 0; j < s2xs.length; j++) {
|
for (let j = 0; j < s2xs.length; j++) {
|
||||||
|
|
||||||
const x = func(s1xs[i], s2xs[j]);
|
const x = func(s1xs[i], s2xs[j]);
|
||||||
const cv = r.get(x) | 0;
|
const cv = r.get(x) | 0;
|
||||||
|
|
||||||
r.set(x, cv + s1ys[i] * s2ys[j]); // add up the ys, if same x
|
r.set(x, cv + s1ys[i] * s2ys[j]); // add up the ys, if same x
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -367,12 +387,12 @@ module DistTree = {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let renderDistributionToXYShape = (d: dist, sampleCount: int): (DistTypes.continuousShape, DistTypes.discreteShape) => {
|
let renderDistributionToXYShape = (d: dist, n: int): (DistTypes.continuousShape, DistTypes.discreteShape) => {
|
||||||
// render the distribution into an XY shape
|
// render the distribution into an XY shape
|
||||||
switch (d) {
|
switch (d) {
|
||||||
| `Float(v) => (Distributions.Continuous.empty, {xs: [|v|], ys: [|1.0|]})
|
| `Float(v) => (Distributions.Continuous.empty, {xs: [|v|], ys: [|1.0|]})
|
||||||
| _ => {
|
| _ => {
|
||||||
let xs = GenericSimple.interpolateXs(~xSelection=`ByWeight, d, sampleCount);
|
let xs = GenericSimple.interpolateXs(~xSelection=`ByWeight, d, n);
|
||||||
let ys = xs |> E.A.fmap(x => GenericSimple.pdf(x, d));
|
let ys = xs |> E.A.fmap(x => GenericSimple.pdf(x, d));
|
||||||
(Distributions.Continuous.make(`Linear, {xs: xs, ys: ys}), XYShape.T.empty)
|
(Distributions.Continuous.make(`Linear, {xs: xs, ys: ys}), XYShape.T.empty)
|
||||||
}
|
}
|
||||||
|
@ -384,23 +404,37 @@ module DistTree = {
|
||||||
sc2: DistTypes.continuousShape,
|
sc2: DistTypes.continuousShape,
|
||||||
sd2: DistTypes.discreteShape, func): (DistTypes.continuousShape, DistTypes.discreteShape) => {
|
sd2: DistTypes.discreteShape, func): (DistTypes.continuousShape, DistTypes.discreteShape) => {
|
||||||
|
|
||||||
let (ccxs, ccys) = jsCombinationConvolve(sc1.xyShape.xs, sc1.xyShape.ys, sc2.xyShape.xs, sc2.xyShape.ys, func);
|
// First, deal with the discrete-discrete convolution:
|
||||||
let (dcxs, dcys) = jsCombinationConvolve(sd1.xs, sd1.ys, sc2.xyShape.xs, sc2.xyShape.ys, func);
|
let (ddxs, ddys) = jsDiscreteCombinationConvolve(sd1.xs, sd1.ys, sd2.xs, sd2.ys, func);
|
||||||
let (cdxs, cdys) = jsCombinationConvolve(sc1.xyShape.xs, sc1.xyShape.ys, sd2.xs, sd2.ys, func);
|
|
||||||
let (ddxs, ddys) = jsCombinationConvolve(sd1.xs, sd1.ys, sd2.xs, sd2.ys, func);
|
|
||||||
|
|
||||||
let ccxy = Distributions.Continuous.make(`Linear, {xs: ccxs, ys: ccys});
|
|
||||||
let dcxy = Distributions.Continuous.make(`Linear, {xs: dcxs, ys: dcys});
|
|
||||||
let cdxy = Distributions.Continuous.make(`Linear, {xs: cdxs, ys: cdys});
|
|
||||||
// the continuous parts are added up; only the discrete-discrete sum is discrete
|
|
||||||
let continuousShapeSum = Distributions.Continuous.reduce((+.), [|ccxy, dcxy, cdxy|]);
|
|
||||||
|
|
||||||
let ddxy: DistTypes.discreteShape = {xs: cdxs, ys: cdys};
|
let ddxy: DistTypes.discreteShape = {xs: cdxs, ys: cdys};
|
||||||
|
|
||||||
|
// Then, do the other three:
|
||||||
|
let downsample = (sc: DistTypes.continuousShape) => {
|
||||||
|
let scLength = E.A.length(sc.xyShape.xs);
|
||||||
|
let scSqLength = sqrt(float_of_int(scLength));
|
||||||
|
scSqLength > 10. ? Distributions.Continuous.T.truncate(int_of_float(scSqLength), sc) : sc;
|
||||||
|
};
|
||||||
|
|
||||||
|
let combinePointConvolutionResults = ccs
|
||||||
|
|> E.A.fmap(s => {
|
||||||
|
// s is an array of (x, y) objects
|
||||||
|
let (xs, ys) = Belt.Array.unzip(s);
|
||||||
|
Distributions.Continuous.make(`Linear, {xs, ys});
|
||||||
|
})
|
||||||
|
|> Distributions.Continuous.reduce((+.));
|
||||||
|
|
||||||
|
let sc1d = downsample(sc1);
|
||||||
|
let sc2d = downsample(sc2);
|
||||||
|
|
||||||
|
let ccxy = jsContinuousCombinationConvolve(sc1d.xyShape.xs, sc1d.xyShape.ys, sc2d.xyShape.xs, sc2d.xyShape.ys, func) |> combinePointConvolutionResults;
|
||||||
|
let dcxy = jsContinuousCombinationConvolve(sc1d.xyShape.xs, sc1d.xyShape.ys, sc2d.xyShape.xs, sc2d.xyShape.ys, func) |> combinePointConvolutionResults;
|
||||||
|
let cdxy = jsContinuousCombinationConvolve(sc1d.xyShape.xs, sc1d.xyShape.ys, sc2d.xyShape.xs, sc2d.xyShape.ys, func) |> combinePointConvolutionResults;
|
||||||
|
let continuousShapeSum = Distributions.Continuous.reduce((+.), [|ccxy, dcxy, cdxy|]);
|
||||||
|
|
||||||
(continuousShapeSum, ddxy)
|
(continuousShapeSum, ddxy)
|
||||||
};
|
};
|
||||||
|
|
||||||
let evaluateCombinationDistribution = (et1: nodeResult, et2: nodeResult, op: operation, sampleCount: int) => {
|
let evaluateCombinationDistribution = (et1: nodeResult, et2: nodeResult, op: operation, n: int) => {
|
||||||
/* return either a Distribution or a RenderedShape. Must integrate to 1. */
|
/* return either a Distribution or a RenderedShape. Must integrate to 1. */
|
||||||
|
|
||||||
let func = funcFromOp(op);
|
let func = funcFromOp(op);
|
||||||
|
@ -439,31 +473,26 @@ module DistTree = {
|
||||||
|
|
||||||
/* General cases: convolve the XYShapes */
|
/* General cases: convolve the XYShapes */
|
||||||
| (`Distribution(d1), `Distribution(d2), _) => {
|
| (`Distribution(d1), `Distribution(d2), _) => {
|
||||||
let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount);
|
let (sc1, sd1) = renderDistributionToXYShape(d1, n);
|
||||||
let (sc2, sd2) = renderDistributionToXYShape(d2, sampleCount);
|
let (sc2, sd2) = renderDistributionToXYShape(d2, n);
|
||||||
let (sc, sd) = combinationDistributionOfXYShapes(sc1, sd1, sc2, sd2, func);
|
let (sc, sd) = combinationDistributionOfXYShapes(sc1, sd1, sc2, sd2, func);
|
||||||
`RenderedShape(sc, sd, 1.0)
|
`RenderedShape(sc, sd, 1.0)
|
||||||
}
|
}
|
||||||
| (`Distribution(d1), `RenderedShape(sc2, sd2, i2), _) => {
|
| (`Distribution(d2), `RenderedShape(sc1, sd1, i1), _)
|
||||||
let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount);
|
| (`RenderedShape(sc1, sd1, i1), `Distribution(d2), _) => {
|
||||||
|
let (sc1, sd1) = renderDistributionToXYShape(d1, n);
|
||||||
let (sc, sd) = combinationDistributionOfXYShapes(sc1, sd1, sc2, sd2, func);
|
let (sc, sd) = combinationDistributionOfXYShapes(sc1, sd1, sc2, sd2, func);
|
||||||
`RenderedShape(sc, sd, i2)
|
`RenderedShape(sc, sd, i2)
|
||||||
}
|
}
|
||||||
| (`RenderedShape(sc1, sd1, i1), `Distribution(d2), _) => {
|
|
||||||
let (sc2, sd2) = renderDistributionToXYShape(d2, sampleCount);
|
|
||||||
let (sc, sd) = combinationDistributionOfXYShapes(sc1, sd1, sc2, sd2, func);
|
|
||||||
`RenderedShape(sc, sd, i1);
|
|
||||||
}
|
|
||||||
| (`RenderedShape(sc1, sd1, i1), `RenderedShape(sc2, sd2, i2), _) => {
|
| (`RenderedShape(sc1, sd1, i1), `RenderedShape(sc2, sd2, i2), _) => {
|
||||||
// sum of two multimodals that have a continuous and discrete each.
|
// sum of two multimodals that have a continuous and discrete each.
|
||||||
let (sc, sd) = combinationDistributionOfXYShapes(sc1, sd1, sc2, sd2, func);
|
let (sc, sd) = combinationDistributionOfXYShapes(sc1, sd1, sc2, sd2, func);
|
||||||
|
|
||||||
`RenderedShape(sc, sd, i1);
|
`RenderedShape(sc, sd, i1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let evaluatePointwiseSum = (et1: nodeResult, et2: nodeResult, sampleCount: int) => {
|
let evaluatePointwiseSum = (et1: nodeResult, et2: nodeResult, n: int) => {
|
||||||
switch ((et1, et2)) {
|
switch ((et1, et2)) {
|
||||||
/* Known cases: */
|
/* Known cases: */
|
||||||
| (`Distribution(`Float(v1)), `Distribution(`Float(v2))) => {
|
| (`Distribution(`Float(v1)), `Distribution(`Float(v2))) => {
|
||||||
|
@ -471,36 +500,29 @@ module DistTree = {
|
||||||
? `RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.make({xs: [|v1|], ys: [|2.|]}), 2.)
|
? `RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.make({xs: [|v1|], ys: [|2.|]}), 2.)
|
||||||
: `RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.) // TODO: add warning: shouldn't pointwise add scalars.
|
: `RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.) // TODO: add warning: shouldn't pointwise add scalars.
|
||||||
}
|
}
|
||||||
| (`Distribution(`Float(v1)), `Distribution(d2)) => {
|
| (`Distribution(`Float(v1)), `Distribution(d2))
|
||||||
|
| (`Distribution(d2), `Distribution(`Float(v1))) => {
|
||||||
let sd1: DistTypes.xyShape = {xs: [|v1|], ys: [|1.|]};
|
let sd1: DistTypes.xyShape = {xs: [|v1|], ys: [|1.|]};
|
||||||
let (sc2, sd2) = renderDistributionToXYShape(d2, sampleCount);
|
let (sc2, sd2) = renderDistributionToXYShape(d2, n);
|
||||||
`RenderedShape(sc2, Distributions.Discrete.reduce((+.), [|sd1, sd2|]), 2.)
|
`RenderedShape(sc2, Distributions.Discrete.reduce((+.), [|sd1, sd2|]), 2.)
|
||||||
}
|
}
|
||||||
| (`Distribution(d1), `Distribution(`Float(v2))) => {
|
|
||||||
let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount);
|
|
||||||
let sd2: DistTypes.xyShape = {xs: [|v2|], ys: [|1.|]};
|
|
||||||
`RenderedShape(sc1, Distributions.Discrete.reduce((+.), [|sd1, sd2|]), 2.)
|
|
||||||
}
|
|
||||||
| (`Distribution(d1), `Distribution(d2)) => {
|
| (`Distribution(d1), `Distribution(d2)) => {
|
||||||
let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount);
|
let (sc1, sd1) = renderDistributionToXYShape(d1, n);
|
||||||
let (sc2, sd2) = renderDistributionToXYShape(d2, sampleCount);
|
let (sc2, sd2) = renderDistributionToXYShape(d2, n);
|
||||||
`RenderedShape(Distributions.Continuous.reduce((+.), [|sc1, sc2|]), Distributions.Discrete.reduce((+.), [|sd1, sd2|]), 2.)
|
`RenderedShape(Distributions.Continuous.reduce((+.), [|sc1, sc2|]), Distributions.Discrete.reduce((+.), [|sd1, sd2|]), 2.)
|
||||||
}
|
}
|
||||||
| (`Distribution(d1), `RenderedShape(sc2, sd2, i2))
|
| (`Distribution(d1), `RenderedShape(sc2, sd2, i2))
|
||||||
| (`RenderedShape(sc2, sd2, i2), `Distribution(d1)) => {
|
| (`RenderedShape(sc2, sd2, i2), `Distribution(d1)) => {
|
||||||
let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount);
|
let (sc1, sd1) = renderDistributionToXYShape(d1, n);
|
||||||
`RenderedShape(Distributions.Continuous.reduce((+.), [|sc1, sc2|]), Distributions.Discrete.reduce((+.), [|sd1, sd2|]), 1. +. i2)
|
`RenderedShape(Distributions.Continuous.reduce((+.), [|sc1, sc2|]), Distributions.Discrete.reduce((+.), [|sd1, sd2|]), 1. +. i2)
|
||||||
}
|
}
|
||||||
| (`RenderedShape(sc1, sd1, i1), `RenderedShape(sc2, sd2, i2)) => {
|
| (`RenderedShape(sc1, sd1, i1), `RenderedShape(sc2, sd2, i2)) => {
|
||||||
Js.log3("Reducing continuous rr", sc1, sc2);
|
|
||||||
Js.log2("Continuous reduction:", Distributions.Continuous.reduce((+.), [|sc1, sc2|]));
|
|
||||||
Js.log2("Discrete reduction:", Distributions.Discrete.reduce((+.), [|sd1, sd2|]));
|
|
||||||
`RenderedShape(Distributions.Continuous.reduce((+.), [|sc1, sc2|]), Distributions.Discrete.reduce((+.), [|sd1, sd2|]), i1 +. i2)
|
`RenderedShape(Distributions.Continuous.reduce((+.), [|sc1, sc2|]), Distributions.Discrete.reduce((+.), [|sd1, sd2|]), i1 +. i2)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let evaluatePointwiseProduct = (et1: nodeResult, et2: nodeResult, sampleCount: int) => {
|
let evaluatePointwiseProduct = (et1: nodeResult, et2: nodeResult, n: int) => {
|
||||||
switch ((et1, et2)) {
|
switch ((et1, et2)) {
|
||||||
/* Known cases: */
|
/* Known cases: */
|
||||||
| (`Distribution(`Float(v1)), `Distribution(`Float(v2))) => {
|
| (`Distribution(`Float(v1)), `Distribution(`Float(v2))) => {
|
||||||
|
@ -528,20 +550,20 @@ module DistTree = {
|
||||||
| (`Distribution(d1), `Distribution(d2)) => {
|
| (`Distribution(d1), `Distribution(d2)) => {
|
||||||
// NOT IMPLEMENTED YET
|
// NOT IMPLEMENTED YET
|
||||||
// TODO: evaluate integral properly
|
// TODO: evaluate integral properly
|
||||||
let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount);
|
let (sc1, sd1) = renderDistributionToXYShape(d1, n);
|
||||||
let (sc2, sd2) = renderDistributionToXYShape(d2, sampleCount);
|
let (sc2, sd2) = renderDistributionToXYShape(d2, n);
|
||||||
`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.)
|
`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.)
|
||||||
}
|
}
|
||||||
| (`Distribution(d1), `RenderedShape(sc2, sd2, i2)) => {
|
| (`Distribution(d1), `RenderedShape(sc2, sd2, i2)) => {
|
||||||
// NOT IMPLEMENTED YET
|
// NOT IMPLEMENTED YET
|
||||||
// TODO: evaluate integral properly
|
// TODO: evaluate integral properly
|
||||||
let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount);
|
let (sc1, sd1) = renderDistributionToXYShape(d1, n);
|
||||||
`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.)
|
`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.)
|
||||||
}
|
}
|
||||||
| (`RenderedShape(sc1, sd1, i1), `Distribution(d1)) => {
|
| (`RenderedShape(sc1, sd1, i1), `Distribution(d1)) => {
|
||||||
// NOT IMPLEMENTED YET
|
// NOT IMPLEMENTED YET
|
||||||
// TODO: evaluate integral properly
|
// TODO: evaluate integral properly
|
||||||
let (sc2, sd2) = renderDistributionToXYShape(d1, sampleCount);
|
let (sc2, sd2) = renderDistributionToXYShape(d1, n);
|
||||||
`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.)
|
`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.)
|
||||||
}
|
}
|
||||||
| (`RenderedShape(sc1, sd1, i1), `RenderedShape(sc2, sd2, i2)) => {
|
| (`RenderedShape(sc1, sd1, i1), `RenderedShape(sc2, sd2, i2)) => {
|
||||||
|
@ -551,7 +573,7 @@ module DistTree = {
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
let evaluateNormalize = (et: nodeResult, sampleCount: int) => {
|
let evaluateNormalize = (et: nodeResult, n: int) => {
|
||||||
// just divide everything by the integral.
|
// just divide everything by the integral.
|
||||||
switch (et) {
|
switch (et) {
|
||||||
| `RenderedShape(sc, sd, 0.) => {
|
| `RenderedShape(sc, sd, 0.) => {
|
||||||
|
@ -570,7 +592,7 @@ module DistTree = {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let evaluateTruncate = (et: nodeResult, xc: cutoffX, compareFunc: (float, float) => bool, sampleCount: int) => {
|
let evaluateTruncate = (et: nodeResult, xc: cutoffX, compareFunc: (float, float) => bool, n: int) => {
|
||||||
let cut = (s: DistTypes.xyShape): DistTypes.xyShape => {
|
let cut = (s: DistTypes.xyShape): DistTypes.xyShape => {
|
||||||
let (xs, ys) = s.ys
|
let (xs, ys) = s.ys
|
||||||
|> Belt.Array.zip(s.xs)
|
|> Belt.Array.zip(s.xs)
|
||||||
|
@ -583,7 +605,7 @@ module DistTree = {
|
||||||
|
|
||||||
switch (et) {
|
switch (et) {
|
||||||
| `Distribution(d) => {
|
| `Distribution(d) => {
|
||||||
let (sc, sd) = renderDistributionToXYShape(d, sampleCount);
|
let (sc, sd) = renderDistributionToXYShape(d, n);
|
||||||
|
|
||||||
let scc = sc |> Distributions.Continuous.shapeMap(cut);
|
let scc = sc |> Distributions.Continuous.shapeMap(cut);
|
||||||
let sdc = sd |> cut;
|
let sdc = sd |> cut;
|
||||||
|
@ -603,13 +625,13 @@ module DistTree = {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let evaluateVerticalScaling = (et1: nodeResult, et2: nodeResult, sampleCount: int) => {
|
let evaluateVerticalScaling = (et1: nodeResult, et2: nodeResult, n: int) => {
|
||||||
let scale = (i: float, s: DistTypes.xyShape): DistTypes.xyShape => {xs: s.xs, ys: s.ys |> E.A.fmap(y => y *. i)};
|
let scale = (i: float, s: DistTypes.xyShape): DistTypes.xyShape => {xs: s.xs, ys: s.ys |> E.A.fmap(y => y *. i)};
|
||||||
|
|
||||||
switch ((et1, et2)) {
|
switch ((et1, et2)) {
|
||||||
| (`Distribution(`Float(v)), `Distribution(d))
|
| (`Distribution(`Float(v)), `Distribution(d))
|
||||||
| (`Distribution(d), `Distribution(`Float(v))) => {
|
| (`Distribution(d), `Distribution(`Float(v))) => {
|
||||||
let (sc, sd) = renderDistributionToXYShape(d, sampleCount);
|
let (sc, sd) = renderDistributionToXYShape(d, n);
|
||||||
|
|
||||||
let scc = sc |> Distributions.Continuous.shapeMap(scale(v));
|
let scc = sc |> Distributions.Continuous.shapeMap(scale(v));
|
||||||
let sdc = sd |> scale(v);
|
let sdc = sd |> scale(v);
|
||||||
|
@ -631,47 +653,34 @@ module DistTree = {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let renderNode = (et: nodeResult, sampleCount: int) => {
|
let renderNode = (et: nodeResult, n: int) => {
|
||||||
switch (et) {
|
switch (et) {
|
||||||
| `Distribution(d) => {
|
| `Distribution(d) => {
|
||||||
let (sc, sd) = renderDistributionToXYShape(d, sampleCount);
|
let (sc, sd) = renderDistributionToXYShape(d, n);
|
||||||
`RenderedShape(sc, sd, 1.0);
|
`RenderedShape(sc, sd, 1.0);
|
||||||
}
|
}
|
||||||
| s => s
|
| s => s
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let rec evaluateNode = (treeNode: distTree, sampleCount: int): nodeResult => {
|
let rec evaluateNode = (treeNode: distTree, n: int): nodeResult => {
|
||||||
// returns either a new symbolic distribution
|
// returns either a new symbolic distribution
|
||||||
switch (treeNode) {
|
switch (treeNode) {
|
||||||
| `Distribution(d) => evaluateDistribution(d)
|
| `Distribution(d) => evaluateDistribution(d)
|
||||||
| `Combination(t1, t2, op) => evaluateCombinationDistribution(evaluateNode(t1, sampleCount), evaluateNode(t2, sampleCount), op, sampleCount)
|
| `Combination(t1, t2, op) => evaluateCombinationDistribution(evaluateNode(t1, n), evaluateNode(t2, n), op, n)
|
||||||
| `PointwiseSum(t1, t2) => evaluatePointwiseSum(evaluateNode(t1, sampleCount), evaluateNode(t2, sampleCount), sampleCount)
|
| `PointwiseSum(t1, t2) => evaluatePointwiseSum(evaluateNode(t1, n), evaluateNode(t2, n), n)
|
||||||
| `PointwiseProduct(t1, t2) => evaluatePointwiseProduct(evaluateNode(t1, sampleCount), evaluateNode(t2, sampleCount), sampleCount)
|
| `PointwiseProduct(t1, t2) => evaluatePointwiseProduct(evaluateNode(t1, n), evaluateNode(t2, n), n)
|
||||||
| `VerticalScaling(t1, t2) => evaluateVerticalScaling(evaluateNode(t1, sampleCount), evaluateNode(t2, sampleCount), sampleCount)
|
| `VerticalScaling(t1, t2) => evaluateVerticalScaling(evaluateNode(t1, n), evaluateNode(t2, n), n)
|
||||||
| `Normalize(t) => evaluateNormalize(evaluateNode(t, sampleCount), sampleCount)
|
| `Normalize(t) => evaluateNormalize(evaluateNode(t, n), n)
|
||||||
| `LeftTruncate(t, x) => evaluateTruncate(evaluateNode(t, sampleCount), x, (<=), sampleCount)
|
| `LeftTruncate(t, x) => evaluateTruncate(evaluateNode(t, n), x, (>=), n)
|
||||||
| `RightTruncate(t, x) => evaluateTruncate(evaluateNode(t, sampleCount), x, (>=), sampleCount)
|
| `RightTruncate(t, x) => evaluateTruncate(evaluateNode(t, n), x, (<=), n)
|
||||||
| `Render(t) => renderNode(evaluateNode(t, sampleCount), sampleCount)
|
| `Render(t) => renderNode(evaluateNode(t, n), n)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let toShape = (treeNode: distTree, sampleCount: int) => {
|
let toShape = (treeNode: distTree, n: int) => {
|
||||||
/*let continuousXs = findContinuousXs(dists, sampleCount);
|
let treeShape = evaluateNode(`Render(`Normalize(treeNode)), n);
|
||||||
continuousXs |> Array.fast_sort(compare);
|
|
||||||
|
|
||||||
let (contShapes, distShapes) = accumulateContAndDiscShapes(dists, continuousXs, 1.0);
|
|
||||||
|
|
||||||
let combinedContinuous = contShapes
|
|
||||||
|> E.A.fold_left((shapeAcc: DistTypes.xyShape, shape: DistTypes.xyShape) => {
|
|
||||||
let ys = E.A.fmapi((i, y) => y +. shape.ys[i], shapeAcc.ys);
|
|
||||||
{xs: continuousXs, ys: ys}
|
|
||||||
}, {xs: continuousXs, ys: Array.make(Array.length(continuousXs), 0.0)})
|
|
||||||
|> Distributions.Continuous.make(`Linear);
|
|
||||||
|
|
||||||
let combinedDiscrete = Distributions.Discrete.reduce((+.), distShapes)*/
|
|
||||||
|
|
||||||
let treeShape = evaluateNode(`Render(`Normalize(treeNode)), sampleCount);
|
|
||||||
switch (treeShape) {
|
switch (treeShape) {
|
||||||
| `Distribution(_) => E.O.toExn("No shape found!", None)
|
| `Distribution(_) => E.O.toExn("No shape found!", None)
|
||||||
| `RenderedShape(sc, sd, _) => {
|
| `RenderedShape(sc, sd, _) => {
|
||||||
|
@ -700,6 +709,7 @@ module DistTree = {
|
||||||
| `Normalize(t) => "normalize(" ++ toString(t) ++ ")"
|
| `Normalize(t) => "normalize(" ++ toString(t) ++ ")"
|
||||||
| `LeftTruncate(t, x) => "leftTruncate(" ++ toString(t) ++ ", " ++ string_of_float(x) ++ ")"
|
| `LeftTruncate(t, x) => "leftTruncate(" ++ toString(t) ++ ", " ++ string_of_float(x) ++ ")"
|
||||||
| `RightTruncate(t, x) => "rightTruncate(" ++ toString(t) ++ ", " ++ string_of_float(x) ++ ")"
|
| `RightTruncate(t, x) => "rightTruncate(" ++ toString(t) ++ ", " ++ string_of_float(x) ++ ")"
|
||||||
|
| `Render(t) => toString(t)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
Loading…
Reference in New Issue
Block a user