Slightly cleaned up tree evaluation

This commit is contained in:
Sebastian Kosch 2020-06-13 18:46:38 -07:00
parent 9b10452156
commit 8827650da3
2 changed files with 115 additions and 83 deletions

View File

@ -197,6 +197,7 @@ module MathAdtToDistDst = {
}; };
let arrayParser = (args:array(arg)):result(SymbolicDist.distTree, string) => { let arrayParser = (args:array(arg)):result(SymbolicDist.distTree, string) => {
Js.log2("SAMPLING NOW!", args);
let samples = args let samples = args
|> E.A.fmap( |> E.A.fmap(
fun fun
@ -287,6 +288,27 @@ module MathAdtToDistDst = {
| [|Ok(l), Ok(r)|] => Ok(`Combination(l, r, `DivideOperation)) | [|Ok(l), Ok(r)|] => Ok(`Combination(l, r, `DivideOperation))
| _ => Error("Division needs two operands")) | _ => Error("Division needs two operands"))
} }
| Fn({name: "pow", args}) => {
args
|> E.A.fmap(functionParser)
|> (fun
| [|Ok(l), Ok(r)|] => Ok(`Combination(l, r, `ExponentiateOperation))
| _ => Error("Exponentiations needs two operands"))
}
| Fn({name: "leftTruncate", args}) => {
args
|> E.A.fmap(functionParser)
|> (fun
| [|Ok(l), Ok(`Distribution(`Float(r)))|] => Ok(`LeftTruncate(l, r))
| _ => Error("leftTruncate needs two arguments: the expression and the cutoff"))
}
| Fn({name: "rightTruncate", args}) => {
args
|> E.A.fmap(functionParser)
|> (fun
| [|Ok(l), Ok(`Distribution(`Float(r)))|] => Ok(`RightTruncate(l, r))
| _ => Error("rightTruncate needs two arguments: the expression and the cutoff"))
}
| Fn({name}) => Error(name ++ ": function not supported") | Fn({name}) => Error(name ++ ": function not supported")
| _ => { | _ => {
Error("This type not currently supported"); Error("This type not currently supported");

View File

@ -277,34 +277,34 @@ module GenericSimple = {
/* This function returns a list of x's at which to evaluate the overall distribution (for rendering). /* This function returns a list of x's at which to evaluate the overall distribution (for rendering).
This function is called separately for each individual distribution. This function is called separately for each individual distribution.
When called with xSelection=`Linear, this function will return (sampleCount) x's, evenly When called with xSelection=`Linear, this function will return (n) x's, evenly
distributed between the min and max of the distribution (whatever those are defined to be above). distributed between the min and max of the distribution (whatever those are defined to be above).
When called with xSelection=`ByWeight, this function will distribute the x's such as to When called with xSelection=`ByWeight, this function will distribute the x's such as to
match the cumulative shape of the distribution. This is slower but may give better results. match the cumulative shape of the distribution. This is slower but may give better results.
*/ */
let interpolateXs = let interpolateXs =
(~xSelection: [ | `Linear | `ByWeight]=`Linear, dist: dist, sampleCount) => { (~xSelection: [ | `Linear | `ByWeight]=`Linear, dist: dist, n) => {
switch (xSelection, dist) { switch (xSelection, dist) {
| (`Linear, _) => E.A.Floats.range(min(dist), max(dist), sampleCount) | (`Linear, _) => E.A.Floats.range(min(dist), max(dist), n)
| (`ByWeight, `Uniform(n)) => | (`ByWeight, `Uniform(n)) =>
// In `ByWeight mode, uniform distributions get special treatment because we need two x's // In `ByWeight mode, uniform distributions get special treatment because we need two x's
// on either side for proper rendering (just left and right of the discontinuities). // on either side for proper rendering (just left and right of the discontinuities).
let dx = 0.00001 *. (n.high -. n.low); let dx = 0.00001 *. (n.high -. n.low);
[|n.low -. dx, n.low +. dx, n.high -. dx, n.high +. dx|]; [|n.low -. dx, n.low +. dx, n.high -. dx, n.high +. dx|];
| (`ByWeight, _) => | (`ByWeight, _) =>
let ys = E.A.Floats.range(minCdfValue, maxCdfValue, sampleCount); let ys = E.A.Floats.range(minCdfValue, maxCdfValue, n);
ys |> E.A.fmap(y => inv(y, dist)); ys |> E.A.fmap(y => inv(y, dist));
}; };
}; };
let toShape = let toShape =
(~xSelection: [ | `Linear | `ByWeight]=`Linear, dist: dist, sampleCount) (~xSelection: [ | `Linear | `ByWeight]=`Linear, dist: dist, n)
: DistTypes.shape => { : DistTypes.shape => {
switch (dist) { switch (dist) {
| `ContinuousShape(n) => n.pdf |> Distributions.Continuous.T.toShape | `ContinuousShape(n) => n.pdf |> Distributions.Continuous.T.toShape
| dist => | dist =>
let xs = interpolateXs(~xSelection, dist, sampleCount); let xs = interpolateXs(~xSelection, dist, n);
let ys = xs |> E.A.fmap(r => pdf(r, dist)); let ys = xs |> E.A.fmap(r => pdf(r, dist));
XYShape.T.fromArrays(xs, ys) XYShape.T.fromArrays(xs, ys)
|> Distributions.Continuous.make(`Linear, _) |> Distributions.Continuous.make(`Linear, _)
@ -321,23 +321,43 @@ module DistTree = {
]; ];
let evaluateDistribution = (d: dist): nodeResult => { let evaluateDistribution = (d: dist): nodeResult => {
// certain distributions we may want to evaluate to RenderedShapes right away, e.g. discrete
`Distribution(d) `Distribution(d)
}; };
// This is a performance bottleneck! // This is a performance bottleneck!
// Using raw JS here so we can use native for loops and access array elements // Using raw JS here so we can use native for loops and access array elements
// directly, without option checks. // directly, without option checks.
let jsCombinationConvolve: (array(float), array(float), array(float), array(float), float => float => float) => (array(float), array(float)) = [%bs.raw let jsContinuousCombinationConvolve: (array(float), array(float), array(float), array(float), float => float => float) => array(array((float, float))) = [%bs.raw
{|
function (s1xs, s1ys, s2xs, s2ys, func) {
// For continuous-continuous convolution, use linear interpolation.
// Let's assume we got downsampled distributions
const outXYShapes = new Array(s1xs.length);
for (let i = 0; i < s1xs.length; i++) {
// create a new distribution
const dxyShape = new Array(s2xs.length);
for (let j = 0; j < s2xs.length; j++) {
dxyShape[j] = [func(s1xs[i], s2xs[j]), (s1ys[i] * s2ys[j])];
}
outXYShapes[i] = dxyShape;
}
return outXYShapes;
}
|}];
let jsDiscreteCombinationConvolve: (array(float), array(float), array(float), array(float), float => float => float) => (array(float), array(float)) = [%bs.raw
{| {|
function (s1xs, s1ys, s2xs, s2ys, func) { function (s1xs, s1ys, s2xs, s2ys, func) {
const r = new Map(); const r = new Map();
// To convolve, add the xs and multiply the ys:
for (let i = 0; i < s1xs.length; i++) { for (let i = 0; i < s1xs.length; i++) {
for (let j = 0; j < s2xs.length; j++) { for (let j = 0; j < s2xs.length; j++) {
const x = func(s1xs[i], s2xs[j]); const x = func(s1xs[i], s2xs[j]);
const cv = r.get(x) | 0; const cv = r.get(x) | 0;
r.set(x, cv + s1ys[i] * s2ys[j]); // add up the ys, if same x r.set(x, cv + s1ys[i] * s2ys[j]); // add up the ys, if same x
} }
} }
@ -367,12 +387,12 @@ module DistTree = {
} }
} }
let renderDistributionToXYShape = (d: dist, sampleCount: int): (DistTypes.continuousShape, DistTypes.discreteShape) => { let renderDistributionToXYShape = (d: dist, n: int): (DistTypes.continuousShape, DistTypes.discreteShape) => {
// render the distribution into an XY shape // render the distribution into an XY shape
switch (d) { switch (d) {
| `Float(v) => (Distributions.Continuous.empty, {xs: [|v|], ys: [|1.0|]}) | `Float(v) => (Distributions.Continuous.empty, {xs: [|v|], ys: [|1.0|]})
| _ => { | _ => {
let xs = GenericSimple.interpolateXs(~xSelection=`ByWeight, d, sampleCount); let xs = GenericSimple.interpolateXs(~xSelection=`ByWeight, d, n);
let ys = xs |> E.A.fmap(x => GenericSimple.pdf(x, d)); let ys = xs |> E.A.fmap(x => GenericSimple.pdf(x, d));
(Distributions.Continuous.make(`Linear, {xs: xs, ys: ys}), XYShape.T.empty) (Distributions.Continuous.make(`Linear, {xs: xs, ys: ys}), XYShape.T.empty)
} }
@ -384,23 +404,37 @@ module DistTree = {
sc2: DistTypes.continuousShape, sc2: DistTypes.continuousShape,
sd2: DistTypes.discreteShape, func): (DistTypes.continuousShape, DistTypes.discreteShape) => { sd2: DistTypes.discreteShape, func): (DistTypes.continuousShape, DistTypes.discreteShape) => {
let (ccxs, ccys) = jsCombinationConvolve(sc1.xyShape.xs, sc1.xyShape.ys, sc2.xyShape.xs, sc2.xyShape.ys, func); // First, deal with the discrete-discrete convolution:
let (dcxs, dcys) = jsCombinationConvolve(sd1.xs, sd1.ys, sc2.xyShape.xs, sc2.xyShape.ys, func); let (ddxs, ddys) = jsDiscreteCombinationConvolve(sd1.xs, sd1.ys, sd2.xs, sd2.ys, func);
let (cdxs, cdys) = jsCombinationConvolve(sc1.xyShape.xs, sc1.xyShape.ys, sd2.xs, sd2.ys, func);
let (ddxs, ddys) = jsCombinationConvolve(sd1.xs, sd1.ys, sd2.xs, sd2.ys, func);
let ccxy = Distributions.Continuous.make(`Linear, {xs: ccxs, ys: ccys});
let dcxy = Distributions.Continuous.make(`Linear, {xs: dcxs, ys: dcys});
let cdxy = Distributions.Continuous.make(`Linear, {xs: cdxs, ys: cdys});
// the continuous parts are added up; only the discrete-discrete sum is discrete
let continuousShapeSum = Distributions.Continuous.reduce((+.), [|ccxy, dcxy, cdxy|]);
let ddxy: DistTypes.discreteShape = {xs: cdxs, ys: cdys}; let ddxy: DistTypes.discreteShape = {xs: cdxs, ys: cdys};
// Then, do the other three:
let downsample = (sc: DistTypes.continuousShape) => {
let scLength = E.A.length(sc.xyShape.xs);
let scSqLength = sqrt(float_of_int(scLength));
scSqLength > 10. ? Distributions.Continuous.T.truncate(int_of_float(scSqLength), sc) : sc;
};
let combinePointConvolutionResults = ccs
|> E.A.fmap(s => {
// s is an array of (x, y) objects
let (xs, ys) = Belt.Array.unzip(s);
Distributions.Continuous.make(`Linear, {xs, ys});
})
|> Distributions.Continuous.reduce((+.));
let sc1d = downsample(sc1);
let sc2d = downsample(sc2);
let ccxy = jsContinuousCombinationConvolve(sc1d.xyShape.xs, sc1d.xyShape.ys, sc2d.xyShape.xs, sc2d.xyShape.ys, func) |> combinePointConvolutionResults;
let dcxy = jsContinuousCombinationConvolve(sc1d.xyShape.xs, sc1d.xyShape.ys, sc2d.xyShape.xs, sc2d.xyShape.ys, func) |> combinePointConvolutionResults;
let cdxy = jsContinuousCombinationConvolve(sc1d.xyShape.xs, sc1d.xyShape.ys, sc2d.xyShape.xs, sc2d.xyShape.ys, func) |> combinePointConvolutionResults;
let continuousShapeSum = Distributions.Continuous.reduce((+.), [|ccxy, dcxy, cdxy|]);
(continuousShapeSum, ddxy) (continuousShapeSum, ddxy)
}; };
let evaluateCombinationDistribution = (et1: nodeResult, et2: nodeResult, op: operation, sampleCount: int) => { let evaluateCombinationDistribution = (et1: nodeResult, et2: nodeResult, op: operation, n: int) => {
/* return either a Distribution or a RenderedShape. Must integrate to 1. */ /* return either a Distribution or a RenderedShape. Must integrate to 1. */
let func = funcFromOp(op); let func = funcFromOp(op);
@ -439,31 +473,26 @@ module DistTree = {
/* General cases: convolve the XYShapes */ /* General cases: convolve the XYShapes */
| (`Distribution(d1), `Distribution(d2), _) => { | (`Distribution(d1), `Distribution(d2), _) => {
let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount); let (sc1, sd1) = renderDistributionToXYShape(d1, n);
let (sc2, sd2) = renderDistributionToXYShape(d2, sampleCount); let (sc2, sd2) = renderDistributionToXYShape(d2, n);
let (sc, sd) = combinationDistributionOfXYShapes(sc1, sd1, sc2, sd2, func); let (sc, sd) = combinationDistributionOfXYShapes(sc1, sd1, sc2, sd2, func);
`RenderedShape(sc, sd, 1.0) `RenderedShape(sc, sd, 1.0)
} }
| (`Distribution(d1), `RenderedShape(sc2, sd2, i2), _) => { | (`Distribution(d2), `RenderedShape(sc1, sd1, i1), _)
let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount); | (`RenderedShape(sc1, sd1, i1), `Distribution(d2), _) => {
let (sc1, sd1) = renderDistributionToXYShape(d1, n);
let (sc, sd) = combinationDistributionOfXYShapes(sc1, sd1, sc2, sd2, func); let (sc, sd) = combinationDistributionOfXYShapes(sc1, sd1, sc2, sd2, func);
`RenderedShape(sc, sd, i2) `RenderedShape(sc, sd, i2)
} }
| (`RenderedShape(sc1, sd1, i1), `Distribution(d2), _) => {
let (sc2, sd2) = renderDistributionToXYShape(d2, sampleCount);
let (sc, sd) = combinationDistributionOfXYShapes(sc1, sd1, sc2, sd2, func);
`RenderedShape(sc, sd, i1);
}
| (`RenderedShape(sc1, sd1, i1), `RenderedShape(sc2, sd2, i2), _) => { | (`RenderedShape(sc1, sd1, i1), `RenderedShape(sc2, sd2, i2), _) => {
// sum of two multimodals that have a continuous and discrete each. // sum of two multimodals that have a continuous and discrete each.
let (sc, sd) = combinationDistributionOfXYShapes(sc1, sd1, sc2, sd2, func); let (sc, sd) = combinationDistributionOfXYShapes(sc1, sd1, sc2, sd2, func);
`RenderedShape(sc, sd, i1); `RenderedShape(sc, sd, i1);
} }
} }
}; };
let evaluatePointwiseSum = (et1: nodeResult, et2: nodeResult, sampleCount: int) => { let evaluatePointwiseSum = (et1: nodeResult, et2: nodeResult, n: int) => {
switch ((et1, et2)) { switch ((et1, et2)) {
/* Known cases: */ /* Known cases: */
| (`Distribution(`Float(v1)), `Distribution(`Float(v2))) => { | (`Distribution(`Float(v1)), `Distribution(`Float(v2))) => {
@ -471,36 +500,29 @@ module DistTree = {
? `RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.make({xs: [|v1|], ys: [|2.|]}), 2.) ? `RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.make({xs: [|v1|], ys: [|2.|]}), 2.)
: `RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.) // TODO: add warning: shouldn't pointwise add scalars. : `RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.) // TODO: add warning: shouldn't pointwise add scalars.
} }
| (`Distribution(`Float(v1)), `Distribution(d2)) => { | (`Distribution(`Float(v1)), `Distribution(d2))
| (`Distribution(d2), `Distribution(`Float(v1))) => {
let sd1: DistTypes.xyShape = {xs: [|v1|], ys: [|1.|]}; let sd1: DistTypes.xyShape = {xs: [|v1|], ys: [|1.|]};
let (sc2, sd2) = renderDistributionToXYShape(d2, sampleCount); let (sc2, sd2) = renderDistributionToXYShape(d2, n);
`RenderedShape(sc2, Distributions.Discrete.reduce((+.), [|sd1, sd2|]), 2.) `RenderedShape(sc2, Distributions.Discrete.reduce((+.), [|sd1, sd2|]), 2.)
} }
| (`Distribution(d1), `Distribution(`Float(v2))) => {
let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount);
let sd2: DistTypes.xyShape = {xs: [|v2|], ys: [|1.|]};
`RenderedShape(sc1, Distributions.Discrete.reduce((+.), [|sd1, sd2|]), 2.)
}
| (`Distribution(d1), `Distribution(d2)) => { | (`Distribution(d1), `Distribution(d2)) => {
let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount); let (sc1, sd1) = renderDistributionToXYShape(d1, n);
let (sc2, sd2) = renderDistributionToXYShape(d2, sampleCount); let (sc2, sd2) = renderDistributionToXYShape(d2, n);
`RenderedShape(Distributions.Continuous.reduce((+.), [|sc1, sc2|]), Distributions.Discrete.reduce((+.), [|sd1, sd2|]), 2.) `RenderedShape(Distributions.Continuous.reduce((+.), [|sc1, sc2|]), Distributions.Discrete.reduce((+.), [|sd1, sd2|]), 2.)
} }
| (`Distribution(d1), `RenderedShape(sc2, sd2, i2)) | (`Distribution(d1), `RenderedShape(sc2, sd2, i2))
| (`RenderedShape(sc2, sd2, i2), `Distribution(d1)) => { | (`RenderedShape(sc2, sd2, i2), `Distribution(d1)) => {
let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount); let (sc1, sd1) = renderDistributionToXYShape(d1, n);
`RenderedShape(Distributions.Continuous.reduce((+.), [|sc1, sc2|]), Distributions.Discrete.reduce((+.), [|sd1, sd2|]), 1. +. i2) `RenderedShape(Distributions.Continuous.reduce((+.), [|sc1, sc2|]), Distributions.Discrete.reduce((+.), [|sd1, sd2|]), 1. +. i2)
} }
| (`RenderedShape(sc1, sd1, i1), `RenderedShape(sc2, sd2, i2)) => { | (`RenderedShape(sc1, sd1, i1), `RenderedShape(sc2, sd2, i2)) => {
Js.log3("Reducing continuous rr", sc1, sc2);
Js.log2("Continuous reduction:", Distributions.Continuous.reduce((+.), [|sc1, sc2|]));
Js.log2("Discrete reduction:", Distributions.Discrete.reduce((+.), [|sd1, sd2|]));
`RenderedShape(Distributions.Continuous.reduce((+.), [|sc1, sc2|]), Distributions.Discrete.reduce((+.), [|sd1, sd2|]), i1 +. i2) `RenderedShape(Distributions.Continuous.reduce((+.), [|sc1, sc2|]), Distributions.Discrete.reduce((+.), [|sd1, sd2|]), i1 +. i2)
} }
} }
}; };
let evaluatePointwiseProduct = (et1: nodeResult, et2: nodeResult, sampleCount: int) => { let evaluatePointwiseProduct = (et1: nodeResult, et2: nodeResult, n: int) => {
switch ((et1, et2)) { switch ((et1, et2)) {
/* Known cases: */ /* Known cases: */
| (`Distribution(`Float(v1)), `Distribution(`Float(v2))) => { | (`Distribution(`Float(v1)), `Distribution(`Float(v2))) => {
@ -528,20 +550,20 @@ module DistTree = {
| (`Distribution(d1), `Distribution(d2)) => { | (`Distribution(d1), `Distribution(d2)) => {
// NOT IMPLEMENTED YET // NOT IMPLEMENTED YET
// TODO: evaluate integral properly // TODO: evaluate integral properly
let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount); let (sc1, sd1) = renderDistributionToXYShape(d1, n);
let (sc2, sd2) = renderDistributionToXYShape(d2, sampleCount); let (sc2, sd2) = renderDistributionToXYShape(d2, n);
`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.) `RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.)
} }
| (`Distribution(d1), `RenderedShape(sc2, sd2, i2)) => { | (`Distribution(d1), `RenderedShape(sc2, sd2, i2)) => {
// NOT IMPLEMENTED YET // NOT IMPLEMENTED YET
// TODO: evaluate integral properly // TODO: evaluate integral properly
let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount); let (sc1, sd1) = renderDistributionToXYShape(d1, n);
`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.) `RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.)
} }
| (`RenderedShape(sc1, sd1, i1), `Distribution(d1)) => { | (`RenderedShape(sc1, sd1, i1), `Distribution(d1)) => {
// NOT IMPLEMENTED YET // NOT IMPLEMENTED YET
// TODO: evaluate integral properly // TODO: evaluate integral properly
let (sc2, sd2) = renderDistributionToXYShape(d1, sampleCount); let (sc2, sd2) = renderDistributionToXYShape(d1, n);
`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.) `RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.)
} }
| (`RenderedShape(sc1, sd1, i1), `RenderedShape(sc2, sd2, i2)) => { | (`RenderedShape(sc1, sd1, i1), `RenderedShape(sc2, sd2, i2)) => {
@ -551,7 +573,7 @@ module DistTree = {
}; };
let evaluateNormalize = (et: nodeResult, sampleCount: int) => { let evaluateNormalize = (et: nodeResult, n: int) => {
// just divide everything by the integral. // just divide everything by the integral.
switch (et) { switch (et) {
| `RenderedShape(sc, sd, 0.) => { | `RenderedShape(sc, sd, 0.) => {
@ -570,7 +592,7 @@ module DistTree = {
} }
}; };
let evaluateTruncate = (et: nodeResult, xc: cutoffX, compareFunc: (float, float) => bool, sampleCount: int) => { let evaluateTruncate = (et: nodeResult, xc: cutoffX, compareFunc: (float, float) => bool, n: int) => {
let cut = (s: DistTypes.xyShape): DistTypes.xyShape => { let cut = (s: DistTypes.xyShape): DistTypes.xyShape => {
let (xs, ys) = s.ys let (xs, ys) = s.ys
|> Belt.Array.zip(s.xs) |> Belt.Array.zip(s.xs)
@ -583,7 +605,7 @@ module DistTree = {
switch (et) { switch (et) {
| `Distribution(d) => { | `Distribution(d) => {
let (sc, sd) = renderDistributionToXYShape(d, sampleCount); let (sc, sd) = renderDistributionToXYShape(d, n);
let scc = sc |> Distributions.Continuous.shapeMap(cut); let scc = sc |> Distributions.Continuous.shapeMap(cut);
let sdc = sd |> cut; let sdc = sd |> cut;
@ -603,13 +625,13 @@ module DistTree = {
} }
}; };
let evaluateVerticalScaling = (et1: nodeResult, et2: nodeResult, sampleCount: int) => { let evaluateVerticalScaling = (et1: nodeResult, et2: nodeResult, n: int) => {
let scale = (i: float, s: DistTypes.xyShape): DistTypes.xyShape => {xs: s.xs, ys: s.ys |> E.A.fmap(y => y *. i)}; let scale = (i: float, s: DistTypes.xyShape): DistTypes.xyShape => {xs: s.xs, ys: s.ys |> E.A.fmap(y => y *. i)};
switch ((et1, et2)) { switch ((et1, et2)) {
| (`Distribution(`Float(v)), `Distribution(d)) | (`Distribution(`Float(v)), `Distribution(d))
| (`Distribution(d), `Distribution(`Float(v))) => { | (`Distribution(d), `Distribution(`Float(v))) => {
let (sc, sd) = renderDistributionToXYShape(d, sampleCount); let (sc, sd) = renderDistributionToXYShape(d, n);
let scc = sc |> Distributions.Continuous.shapeMap(scale(v)); let scc = sc |> Distributions.Continuous.shapeMap(scale(v));
let sdc = sd |> scale(v); let sdc = sd |> scale(v);
@ -631,47 +653,34 @@ module DistTree = {
} }
} }
let renderNode = (et: nodeResult, sampleCount: int) => { let renderNode = (et: nodeResult, n: int) => {
switch (et) { switch (et) {
| `Distribution(d) => { | `Distribution(d) => {
let (sc, sd) = renderDistributionToXYShape(d, sampleCount); let (sc, sd) = renderDistributionToXYShape(d, n);
`RenderedShape(sc, sd, 1.0); `RenderedShape(sc, sd, 1.0);
} }
| s => s | s => s
} }
} }
let rec evaluateNode = (treeNode: distTree, sampleCount: int): nodeResult => { let rec evaluateNode = (treeNode: distTree, n: int): nodeResult => {
// returns either a new symbolic distribution // returns either a new symbolic distribution
switch (treeNode) { switch (treeNode) {
| `Distribution(d) => evaluateDistribution(d) | `Distribution(d) => evaluateDistribution(d)
| `Combination(t1, t2, op) => evaluateCombinationDistribution(evaluateNode(t1, sampleCount), evaluateNode(t2, sampleCount), op, sampleCount) | `Combination(t1, t2, op) => evaluateCombinationDistribution(evaluateNode(t1, n), evaluateNode(t2, n), op, n)
| `PointwiseSum(t1, t2) => evaluatePointwiseSum(evaluateNode(t1, sampleCount), evaluateNode(t2, sampleCount), sampleCount) | `PointwiseSum(t1, t2) => evaluatePointwiseSum(evaluateNode(t1, n), evaluateNode(t2, n), n)
| `PointwiseProduct(t1, t2) => evaluatePointwiseProduct(evaluateNode(t1, sampleCount), evaluateNode(t2, sampleCount), sampleCount) | `PointwiseProduct(t1, t2) => evaluatePointwiseProduct(evaluateNode(t1, n), evaluateNode(t2, n), n)
| `VerticalScaling(t1, t2) => evaluateVerticalScaling(evaluateNode(t1, sampleCount), evaluateNode(t2, sampleCount), sampleCount) | `VerticalScaling(t1, t2) => evaluateVerticalScaling(evaluateNode(t1, n), evaluateNode(t2, n), n)
| `Normalize(t) => evaluateNormalize(evaluateNode(t, sampleCount), sampleCount) | `Normalize(t) => evaluateNormalize(evaluateNode(t, n), n)
| `LeftTruncate(t, x) => evaluateTruncate(evaluateNode(t, sampleCount), x, (<=), sampleCount) | `LeftTruncate(t, x) => evaluateTruncate(evaluateNode(t, n), x, (>=), n)
| `RightTruncate(t, x) => evaluateTruncate(evaluateNode(t, sampleCount), x, (>=), sampleCount) | `RightTruncate(t, x) => evaluateTruncate(evaluateNode(t, n), x, (<=), n)
| `Render(t) => renderNode(evaluateNode(t, sampleCount), sampleCount) | `Render(t) => renderNode(evaluateNode(t, n), n)
} }
}; };
let toShape = (treeNode: distTree, sampleCount: int) => { let toShape = (treeNode: distTree, n: int) => {
/*let continuousXs = findContinuousXs(dists, sampleCount); let treeShape = evaluateNode(`Render(`Normalize(treeNode)), n);
continuousXs |> Array.fast_sort(compare);
let (contShapes, distShapes) = accumulateContAndDiscShapes(dists, continuousXs, 1.0);
let combinedContinuous = contShapes
|> E.A.fold_left((shapeAcc: DistTypes.xyShape, shape: DistTypes.xyShape) => {
let ys = E.A.fmapi((i, y) => y +. shape.ys[i], shapeAcc.ys);
{xs: continuousXs, ys: ys}
}, {xs: continuousXs, ys: Array.make(Array.length(continuousXs), 0.0)})
|> Distributions.Continuous.make(`Linear);
let combinedDiscrete = Distributions.Discrete.reduce((+.), distShapes)*/
let treeShape = evaluateNode(`Render(`Normalize(treeNode)), sampleCount);
switch (treeShape) { switch (treeShape) {
| `Distribution(_) => E.O.toExn("No shape found!", None) | `Distribution(_) => E.O.toExn("No shape found!", None)
| `RenderedShape(sc, sd, _) => { | `RenderedShape(sc, sd, _) => {
@ -700,6 +709,7 @@ module DistTree = {
| `Normalize(t) => "normalize(" ++ toString(t) ++ ")" | `Normalize(t) => "normalize(" ++ toString(t) ++ ")"
| `LeftTruncate(t, x) => "leftTruncate(" ++ toString(t) ++ ", " ++ string_of_float(x) ++ ")" | `LeftTruncate(t, x) => "leftTruncate(" ++ toString(t) ++ ", " ++ string_of_float(x) ++ ")"
| `RightTruncate(t, x) => "rightTruncate(" ++ toString(t) ++ ", " ++ string_of_float(x) ++ ")" | `RightTruncate(t, x) => "rightTruncate(" ++ toString(t) ++ ", " ++ string_of_float(x) ++ ")"
| `Render(t) => toString(t)
} }
}; };
}; };