Big refactor towards proper distTree, still slow and untested

This commit is contained in:
Sebastian Kosch 2020-06-12 23:30:51 -07:00
parent bc271a090b
commit f6c1918b12
7 changed files with 475 additions and 247 deletions

View File

@ -74,6 +74,21 @@ module Continuous = {
(fn, {xyShape, interpolation}: t): option(DistTypes.continuousShape) => (fn, {xyShape, interpolation}: t): option(DistTypes.continuousShape) =>
fn(xyShape) |> E.O.fmap(make(interpolation)); fn(xyShape) |> E.O.fmap(make(interpolation));
let empty: DistTypes.continuousShape = {xyShape: XYShape.T.empty, interpolation: `Linear};
let combine =
(fn, t1: DistTypes.continuousShape, t2: DistTypes.continuousShape)
: DistTypes.continuousShape => {
make(`Linear, XYShape.Combine.combine(
~xsSelection=ALL_XS,
~xToYSelection=XYShape.XtoY.linear,
~fn,
t1.xyShape,
t2.xyShape,
));
};
let reduce = (fn, items) =>
items |> E.A.fold_left(combine(fn), empty);
let toLinear = (t: t): option(t) => { let toLinear = (t: t): option(t) => {
switch (t) { switch (t) {
| {interpolation: `Stepwise, xyShape} => | {interpolation: `Stepwise, xyShape} =>
@ -166,6 +181,7 @@ module Discrete = {
let sortedByX = (t: DistTypes.discreteShape) => let sortedByX = (t: DistTypes.discreteShape) =>
t |> XYShape.T.zip |> XYShape.Zipped.sortByX; t |> XYShape.T.zip |> XYShape.Zipped.sortByX;
let empty = XYShape.T.empty; let empty = XYShape.T.empty;
let make = (s: DistTypes.discreteShape) => s;
let combine = let combine =
(fn, t1: DistTypes.discreteShape, t2: DistTypes.discreteShape) (fn, t1: DistTypes.discreteShape, t2: DistTypes.discreteShape)
: DistTypes.discreteShape => { : DistTypes.discreteShape => {

View File

@ -179,16 +179,25 @@ module Combine = {
t1: T.t, t1: T.t,
t2: T.t, t2: T.t,
) => { ) => {
let allXs =
switch (xsSelection) {
| ALL_XS => Ts.allXs([|t1, t2|])
| XS_EVENLY_DIVIDED(sampleCount) =>
Ts.equallyDividedXs([|t1, t2|], sampleCount)
};
let allYs = switch ((E.A.length(t1.xs), E.A.length(t2.xs))) {
allXs |> E.A.fmap(x => fn(xToYSelection(x, t1), xToYSelection(x, t2))); | (0, 0) => T.empty
T.fromArrays(allXs, allYs); | (0, _) => t2
| (_, 0) => t1
| (_, _) => {
let allXs =
switch (xsSelection) {
| ALL_XS => Ts.allXs([|t1, t2|])
| XS_EVENLY_DIVIDED(sampleCount) =>
Ts.equallyDividedXs([|t1, t2|], sampleCount)
};
let allYs =
allXs |> E.A.fmap(x => fn(xToYSelection(x, t1), xToYSelection(x, t2)));
T.fromArrays(allXs, allYs);
}
}
}; };
let combineLinear = combine(~xToYSelection=XtoY.linear); let combineLinear = combine(~xToYSelection=XtoY.linear);

View File

@ -43,7 +43,7 @@ module ShapeRenderer = {
module Symbolic = { module Symbolic = {
type inputs = {length: int}; type inputs = {length: int};
type outputs = { type outputs = {
graph: SymbolicDist.bigDist, graph: SymbolicDist.distTree,
shape: DistTypes.shape, shape: DistTypes.shape,
}; };
let make = (graph, shape) => {graph, shape}; let make = (graph, shape) => {graph, shape};

View File

@ -88,73 +88,69 @@ module MathAdtToDistDst = {
); );
}; };
let normal: array(arg) => result(SymbolicDist.bigDist, string) = let normal: array(arg) => result(SymbolicDist.distTree, string) =
fun fun
| [|Value(mean), Value(stdev)|] => | [|Value(mean), Value(stdev)|] =>
Ok(`Simple(`Normal({mean, stdev}))) Ok(`Distribution(`Normal({mean, stdev})))
| _ => Error("Wrong number of variables in normal distribution"); | _ => Error("Wrong number of variables in normal distribution");
let lognormal: array(arg) => result(SymbolicDist.bigDist, string) = let lognormal: array(arg) => result(SymbolicDist.distTree, string) =
fun fun
| [|Value(mu), Value(sigma)|] => Ok(`Simple(`Lognormal({mu, sigma}))) | [|Value(mu), Value(sigma)|] => Ok(`Distribution(`Lognormal({mu, sigma})))
| [|Object(o)|] => { | [|Object(o)|] => {
let g = Js.Dict.get(o); let g = Js.Dict.get(o);
switch (g("mean"), g("stdev"), g("mu"), g("sigma")) { switch (g("mean"), g("stdev"), g("mu"), g("sigma")) {
| (Some(Value(mean)), Some(Value(stdev)), _, _) => | (Some(Value(mean)), Some(Value(stdev)), _, _) =>
Ok(`Simple(SymbolicDist.Lognormal.fromMeanAndStdev(mean, stdev))) Ok(`Distribution(SymbolicDist.Lognormal.fromMeanAndStdev(mean, stdev)))
| (_, _, Some(Value(mu)), Some(Value(sigma))) => | (_, _, Some(Value(mu)), Some(Value(sigma))) =>
Ok(`Simple(`Lognormal({mu, sigma}))) Ok(`Distribution(`Lognormal({mu, sigma})))
| _ => Error("Lognormal distribution would need mean and stdev") | _ => Error("Lognormal distribution would need mean and stdev")
}; };
} }
| _ => Error("Wrong number of variables in lognormal distribution"); | _ => Error("Wrong number of variables in lognormal distribution");
let to_: array(arg) => result(SymbolicDist.bigDist, string) = let to_: array(arg) => result(SymbolicDist.distTree, string) =
fun fun
| [|Value(low), Value(high)|] when low <= 0.0 && low < high=> { | [|Value(low), Value(high)|] when low <= 0.0 && low < high=> {
Ok(`Simple(SymbolicDist.Normal.from90PercentCI(low, high))); Ok(`Distribution(SymbolicDist.Normal.from90PercentCI(low, high)));
} }
| [|Value(low), Value(high)|] when low < high => { | [|Value(low), Value(high)|] when low < high => {
Ok(`Simple(SymbolicDist.Lognormal.from90PercentCI(low, high))); Ok(`Distribution(SymbolicDist.Lognormal.from90PercentCI(low, high)));
} }
| [|Value(_), Value(_)|] => | [|Value(_), Value(_)|] =>
Error("Low value must be less than high value.") Error("Low value must be less than high value.")
| _ => Error("Wrong number of variables in lognormal distribution"); | _ => Error("Wrong number of variables in lognormal distribution");
let uniform: array(arg) => result(SymbolicDist.bigDist, string) = let uniform: array(arg) => result(SymbolicDist.distTree, string) =
fun fun
| [|Value(low), Value(high)|] => Ok(`Simple(`Uniform({low, high}))) | [|Value(low), Value(high)|] => Ok(`Distribution(`Uniform({low, high})))
| _ => Error("Wrong number of variables in lognormal distribution"); | _ => Error("Wrong number of variables in lognormal distribution");
let beta: array(arg) => result(SymbolicDist.bigDist, string) = let beta: array(arg) => result(SymbolicDist.distTree, string) =
fun fun
| [|Value(alpha), Value(beta)|] => Ok(`Simple(`Beta({alpha, beta}))) | [|Value(alpha), Value(beta)|] => Ok(`Distribution(`Beta({alpha, beta})))
| _ => Error("Wrong number of variables in lognormal distribution"); | _ => Error("Wrong number of variables in lognormal distribution");
let exponential: array(arg) => result(SymbolicDist.bigDist, string) = let exponential: array(arg) => result(SymbolicDist.distTree, string) =
fun fun
| [|Value(rate)|] => Ok(`Simple(`Exponential({rate: rate}))) | [|Value(rate)|] => Ok(`Distribution(`Exponential({rate: rate})))
| _ => Error("Wrong number of variables in Exponential distribution"); | _ => Error("Wrong number of variables in Exponential distribution");
let cauchy: array(arg) => result(SymbolicDist.bigDist, string) = let cauchy: array(arg) => result(SymbolicDist.distTree, string) =
fun fun
| [|Value(local), Value(scale)|] => | [|Value(local), Value(scale)|] =>
Ok(`Simple(`Cauchy({local, scale}))) Ok(`Distribution(`Cauchy({local, scale})))
| _ => Error("Wrong number of variables in cauchy distribution"); | _ => Error("Wrong number of variables in cauchy distribution");
let triangular: array(arg) => result(SymbolicDist.bigDist, string) = let triangular: array(arg) => result(SymbolicDist.distTree, string) =
fun fun
| [|Value(low), Value(medium), Value(high)|] => | [|Value(low), Value(medium), Value(high)|] =>
Ok(`Simple(`Triangular({low, medium, high}))) Ok(`Distribution(`Triangular({low, medium, high})))
| _ => Error("Wrong number of variables in triangle distribution"); | _ => Error("Wrong number of variables in triangle distribution");
/*let add: array(arg) => result(SymbolicDist.bigDist, string) =
fun
| */
let multiModal = let multiModal =
( (
args: array(result(SymbolicDist.bigDist, string)), args: array(result(SymbolicDist.distTree, string)),
weights: option(array(float)), weights: option(array(float)),
) => { ) => {
let weights = weights |> E.O.default([||]); let weights = weights |> E.O.default([||]);
@ -162,8 +158,14 @@ module MathAdtToDistDst = {
args args
|> E.A.fmap( |> E.A.fmap(
fun fun
| Ok(`Simple(d)) => Ok(`Simple(d)) | Ok(`Distribution(d)) => Ok(`Distribution(d))
| Ok(`PointwiseCombination(dists)) => Ok(`PointwiseCombination(dists)) | Ok(`Combination(t1, t2, op)) => Ok(`Combination(t1, t2, op))
| Ok(`PointwiseSum(t1, t2)) => Ok(`PointwiseSum(t1, t2))
| Ok(`PointwiseProduct(t1, t2)) => Ok(`PointwiseProduct(t1, t2))
| Ok(`Normalize(t)) => Ok(`Normalize(t))
| Ok(`LeftTruncate(t, x)) => Ok(`LeftTruncate(t, x))
| Ok(`RightTruncate(t, x)) => Ok(`RightTruncate(t, x))
| Ok(`Render(t)) => Ok(`Render(t))
| Error(e) => Error(e) | Error(e) => Error(e)
| _ => Error("Unexpected dist") | _ => Error("Unexpected dist")
); );
@ -175,16 +177,26 @@ module MathAdtToDistDst = {
| Some(Error(e)) => Error(e) | Some(Error(e)) => Error(e)
| None when withoutErrors |> E.A.length == 0 => | None when withoutErrors |> E.A.length == 0 =>
Error("Multimodals need at least one input") Error("Multimodals need at least one input")
| _ => | _ => {
withoutErrors let components = withoutErrors
|> E.A.fmapi((index, item) => |> E.A.fmapi((index, t) => {
(item, weights |> E.A.get(_, index) |> E.O.default(1.0)) let w = weights |> E.A.get(_, index) |> E.O.default(1.0);
)
|> (r => Ok(`PointwiseCombination(r))) `VerticalScaling(t, `Distribution(`Float(w)))
});
let pointwiseSum = components
|> Js.Array.sliceFrom(1)
|> E.A.fold_left((acc, x) => {
`PointwiseSum(acc, x)
}, E.A.unsafe_get(components, 0))
Ok(`Normalize(pointwiseSum))
}
}; };
}; };
let arrayParser = (args:array(arg)):result(SymbolicDist.bigDist, string) => { let arrayParser = (args:array(arg)):result(SymbolicDist.distTree, string) => {
let samples = args let samples = args
|> E.A.fmap( |> E.A.fmap(
fun fun
@ -200,13 +212,13 @@ module MathAdtToDistDst = {
SymbolicDist.ContinuousShape.make(_pdf, cdf) SymbolicDist.ContinuousShape.make(_pdf, cdf)
}); });
switch(shape){ switch(shape){
| Some(s) => Ok(`Simple(`ContinuousShape(s))) | Some(s) => Ok(`Distribution(`ContinuousShape(s)))
| None => Error("Rendering did not work") | None => Error("Rendering did not work")
} }
} }
let rec functionParser = (r): result(SymbolicDist.bigDist, string) => let rec functionParser = (r): result(SymbolicDist.distTree, string) =>
r r
|> ( |> (
fun fun
@ -218,7 +230,7 @@ module MathAdtToDistDst = {
| Fn({name: "exponential", args}) => exponential(args) | Fn({name: "exponential", args}) => exponential(args)
| Fn({name: "cauchy", args}) => cauchy(args) | Fn({name: "cauchy", args}) => cauchy(args)
| Fn({name: "triangular", args}) => triangular(args) | Fn({name: "triangular", args}) => triangular(args)
| Value(f) => Ok(`Simple(`Float(f))) | Value(f) => Ok(`Distribution(`Float(f)))
| Fn({name: "mm", args}) => { | Fn({name: "mm", args}) => {
let weights = let weights =
args args
@ -245,25 +257,54 @@ module MathAdtToDistDst = {
let dists = possibleDists |> E.A.fmap(functionParser); let dists = possibleDists |> E.A.fmap(functionParser);
multiModal(dists, weights); multiModal(dists, weights);
} }
//| Fn({name: "add", args}) => add(args)
| Fn({name: "add", args}) => {
args
|> E.A.fmap(functionParser)
|> (fun
| [|Ok(l), Ok(r)|] => Ok(`Combination(l, r, `AddOperation))
| _ => Error("Addition needs two operands"))
}
| Fn({name: "subtract", args}) => {
args
|> E.A.fmap(functionParser)
|> (fun
| [|Ok(l), Ok(r)|] => Ok(`Combination(l, r, `SubtractOperation))
| _ => Error("Subtraction needs two operands"))
}
| Fn({name: "multiply", args}) => {
args
|> E.A.fmap(functionParser)
|> (fun
| [|Ok(l), Ok(r)|] => Ok(`Combination(l, r, `MultiplyOperation))
| _ => Error("Multiplication needs two operands"))
}
| Fn({name: "divide", args}) => {
args
|> E.A.fmap(functionParser)
|> (fun
| [|Ok(l), Ok(`Distribution(`Float(0.0)))|] => Error("Division by zero")
| [|Ok(l), Ok(r)|] => Ok(`Combination(l, r, `DivideOperation))
| _ => Error("Division needs two operands"))
}
| Fn({name}) => Error(name ++ ": function not supported") | Fn({name}) => Error(name ++ ": function not supported")
| _ => { | _ => {
Error("This type not currently supported"); Error("This type not currently supported");
} }
); );
let topLevel = (r): result(SymbolicDist.bigDist, string) => let topLevel = (r): result(SymbolicDist.distTree, string) =>
r r
|> ( |> (
fun fun
| Fn(_) => functionParser(r) | Fn(_) => functionParser(r)
| Value(r) => Ok(`Simple(`Float(r))) | Value(r) => Ok(`Distribution(`Float(r)))
| Array(r) => arrayParser(r) | Array(r) => arrayParser(r)
| Symbol(_) => Error("Symbol not valid as top level") | Symbol(_) => Error("Symbol not valid as top level")
| Object(_) => Error("Object not valid as top level") | Object(_) => Error("Object not valid as top level")
); );
let run = (r): result(SymbolicDist.bigDist, string) => let run = (r): result(SymbolicDist.distTree, string) =>
r |> MathAdtCleaner.run |> topLevel; r |> MathAdtCleaner.run |> topLevel;
}; };

View File

@ -50,41 +50,28 @@ type dist = [
| `Float(float) // Dirac delta at x. Practically useful only in the context of multimodals. | `Float(float) // Dirac delta at x. Practically useful only in the context of multimodals.
]; ];
/* Build a tree. type integral = float;
type cutoffX = float;
type operation = [
| `AddOperation
| `SubtractOperation
| `MultiplyOperation
| `DivideOperation
| `ExponentiateOperation
];
Multiple operations possible: type distTree = [
| `Distribution(dist)
- PointwiseSum(Scalar, Scalar) | `Combination(distTree, distTree, operation)
- PointwiseSum(WeightedDist, WeightedDist) | `PointwiseSum(distTree, distTree)
- PointwiseProduct(Scalar, Scalar) | `PointwiseProduct(distTree, distTree)
- PointwiseProduct(Scalar, WeightedDist) | `VerticalScaling(distTree, distTree)
- PointwiseProduct(WeightedDist, WeightedDist) | `Normalize(distTree)
| `LeftTruncate(distTree, cutoffX)
- IndependentVariableSum(WeightedDist, WeightedDist) [i.e., convolution] | `RightTruncate(distTree, cutoffX)
- IndependentVariableProduct(WeightedDist, WeightedDist) [i.e. distribution product] | `Render(distTree)
*/ ]
and weightedDists = array((distTree, float));
/*type weightedDist = (float, dist);
type bigDistTree =
/* | DistLeaf(dist) */
/* | ScalarLeaf(float) */
/* | PointwiseScalarDistProduct(DistLeaf(d), ScalarLeaf(s)) */
| WeightedDistLeaf(weightedDist)
| PointwiseNormalizedDistSum(array(bigDistTree));
let rec treeIntegral = item => {
switch (item) {
| WeightedDistLeaf((w, d)) => w
| PointwiseNormalizedDistSum(childTrees) =>
childTrees |> E.A.fmap(treeIntegral) |> E.A.Floats.sum
};
};*/
/* bigDist can either be a single distribution, or a
PointwiseCombination, i.e. an array of (dist, weight) tuples */
type bigDist = [ | `Simple(dist) | `PointwiseCombination(pointwiseAdd)]
and pointwiseAdd = array((bigDist, float));
module ContinuousShape = { module ContinuousShape = {
type t = continuousShape; type t = continuousShape;
@ -326,138 +313,331 @@ module GenericSimple = {
}; };
}; };
module PointwiseAddDistributionsWeighted = { module DistTree = {
type t = pointwiseAdd; type nodeResult = [
| `Distribution(dist)
// RenderedShape: continuous xyShape, discrete xyShape, total value.
| `RenderedShape(DistTypes.continuousShape, DistTypes.discreteShape, integral)
];
let normalizeWeights = (weightedDists: t) => { let evaluateDistribution = (d: dist): nodeResult => {
let total = weightedDists |> E.A.fmap(snd) |> E.A.Floats.sum; // certain distributions we may want to evaluate to RenderedShapes right away, e.g. discrete
weightedDists |> E.A.fmap(((d, w)) => (d, w /. total)); `Distribution(d)
}; };
let rec pdf = (x: float, weightedNormalizedDists: t) => // This is a performance bottleneck!
weightedNormalizedDists // Using raw JS here so we can use native for loops and access array elements
|> E.A.fmap(((d, w)) => { // directly, without option checks.
switch (d) { let jsCombinationConvolve: (array(float), array(float), array(float), array(float), float => float => float) => (array(float), array(float)) = [%bs.raw
| `PointwiseCombination(ts) => pdf(x, ts) *. w {|
| `Simple(d) => GenericSimple.pdf(x, d) *. w function (s1xs, s1ys, s2xs, s2ys, func) {
} const r = new Map();
})
|> E.A.Floats.sum;
// TODO: perhaps rename into minCdfX? // To convolve, add the xs and multiply the ys:
// TODO: how should nonexistent min values be handled? They should never happen for (let i = 0; i < s1xs.length; i++) {
let rec min = (dists: t) => for (let j = 0; j < s2xs.length; j++) {
dists const x = func(s1xs[i], s2xs[j]);
|> E.A.fmap(((d, w)) => { const cv = r.get(x) | 0;
switch (d) { r.set(x, cv + s1ys[i] * s2ys[j]); // add up the ys, if same x
| `PointwiseCombination(ts) => E.O.toExn("Dist has no min", min(ts)) }
| `Simple(d) => GenericSimple.min(d) }
}
})
|> E.A.min;
// TODO: perhaps rename into minCdfX? const rxys = [...r.entries()];
let rec max = (dists: t) => rxys.sort(([x1, y1], [x2, y2]) => x1 - x2);
dists
|> E.A.fmap(((d, w)) => {
switch (d) {
| `PointwiseCombination(ts) => E.O.toExn("Dist has no max", max(ts))
| `Simple(d) => GenericSimple.max(d)
}
})
|> E.A.max;
const rxs = new Array(rxys.length);
const rys = new Array(rxys.length);
/*let rec discreteShape = (dists: t, sampleCount: int) => { for (let i = 0; i < rxys.length; i++) {
let discrete = rxs[i] = rxys[i][0];
dists rys[i] = rxys[i][1];
|> E.A.fmap(((x, w)) => { }
switch (d) {
| `Float(d) => Some((d, w)) // if the distribution is just a number, then the weight is considered the y
| _ => None
}
})
|> E.A.O.concatSomes
|> E.A.fmap(((x, y)) =>
({xs: [|x|], ys: [|y|]}: DistTypes.xyShape)
)
// take an array of xyShapes and combine them together
//* r
|> (
fun
| `Float(r) => Some((r, e))
| _ => None
)
)*/
|> Distributions.Discrete.reduce((+.));
discrete;
};*/
return [rxs, rys];
}
|}];
let rec findContinuousXs = (dists: t, sampleCount: int) => { let funcFromOp = (op: operation) => {
// we need to go through the tree of distributions and, for the continuous ones, find the xs at which switch (op) {
// later, all distributions will get evaluated. | `AddOperation => (+.)
| `SubtractOperation => (-.)
| `MultiplyOperation => (*.)
| `DivideOperation => (/.)
| `ExponentiateOperation => (**)
}
}
// we want to accumulate a set of xs. let renderDistributionToXYShape = (d: dist, sampleCount: int): (DistTypes.continuousShape, DistTypes.discreteShape) => {
let xs: array(float) = // render the distribution into an XY shape
dists switch (d) {
|> E.A.fold_left((accXs, (d, w)) => { | `Float(v) => (Distributions.Continuous.empty, {xs: [|v|], ys: [|1.0|]})
switch (d) { | _ => {
| `Simple(t) when (GenericSimple.contType(t) == `Discrete) => accXs let xs = GenericSimple.interpolateXs(~xSelection=`ByWeight, d, sampleCount);
| `Simple(d) => { let ys = xs |> E.A.fmap(x => GenericSimple.pdf(x, d));
let xs = GenericSimple.interpolateXs(~xSelection=`ByWeight, d, sampleCount) (Distributions.Continuous.make(`Linear, {xs: xs, ys: ys}), XYShape.T.empty)
}
E.A.append(accXs, xs) }
}
| `PointwiseCombination(ts) => {
let xs = findContinuousXs(ts, sampleCount);
E.A.append(accXs, xs)
}
}
}, [||]);
xs
}; };
/* Accumulate (accContShapes, accDistShapes), each of which is an array of {xs, ys} shapes. */ let combinationDistributionOfXYShapes = (sc1: DistTypes.continuousShape, // continuous shape
let rec accumulateContAndDiscShapes = (dists: t, continuousXs: array(float), currentWeight) => { sd1: DistTypes.discreteShape, // discrete shape
let normalized = normalizeWeights(dists); sc2: DistTypes.continuousShape,
sd2: DistTypes.discreteShape, func): (DistTypes.continuousShape, DistTypes.discreteShape) => {
normalized let (ccxs, ccys) = jsCombinationConvolve(sc1.xyShape.xs, sc1.xyShape.ys, sc2.xyShape.xs, sc2.xyShape.ys, func);
|> E.A.fold_left(((accContShapes: array(DistTypes.xyShape), accDiscShapes: array(DistTypes.xyShape)), (d, w)) => { let (dcxs, dcys) = jsCombinationConvolve(sd1.xs, sd1.ys, sc2.xyShape.xs, sc2.xyShape.ys, func);
switch (d) { let (cdxs, cdys) = jsCombinationConvolve(sc1.xyShape.xs, sc1.xyShape.ys, sd2.xs, sd2.ys, func);
let (ddxs, ddys) = jsCombinationConvolve(sd1.xs, sd1.ys, sd2.xs, sd2.ys, func);
| `Simple(`Float(x)) => { let ccxy = Distributions.Continuous.make(`Linear, {xs: ccxs, ys: ccys});
let ds: DistTypes.xyShape = {xs: [|x|], ys: [|w *. currentWeight|]}; let dcxy = Distributions.Continuous.make(`Linear, {xs: dcxs, ys: dcys});
(accContShapes, E.A.append(accDiscShapes, [|ds|])) let cdxy = Distributions.Continuous.make(`Linear, {xs: cdxs, ys: cdys});
} // the continuous parts are added up; only the discrete-discrete sum is discrete
let continuousShapeSum = Distributions.Continuous.reduce((+.), [|ccxy, dcxy, cdxy|]);
| `Simple(d) when (GenericSimple.contType(d) == `Continuous) => { let ddxy: DistTypes.discreteShape = {xs: cdxs, ys: cdys};
let ys = continuousXs |> E.A.fmap(x => GenericSimple.pdf(x, d) *. w *. currentWeight);
let cs = XYShape.T.fromArrays(continuousXs, ys);
(E.A.append(accContShapes, [|cs|]), accDiscShapes) (continuousShapeSum, ddxy)
} };
| `Simple(d) => (accContShapes, accDiscShapes) // default -- should never happen let evaluateCombinationDistribution = (et1: nodeResult, et2: nodeResult, op: operation, sampleCount: int) => {
/* return either a Distribution or a RenderedShape. Must integrate to 1. */
| `PointwiseCombination(ts) => { let func = funcFromOp(op);
let (cs, ds) = accumulateContAndDiscShapes(ts, continuousXs, w *. currentWeight); switch ((et1, et2, op)) {
(E.A.append(accContShapes, cs), E.A.append(accDiscShapes, ds)) /* Known cases: replace symbolic with symbolic distribution */
} | (`Distribution(`Float(v1)), `Distribution(`Float(v2)), _) => {
`Distribution(`Float(func(v1, v2)))
}
| (`Distribution(`Float(v1)), `Distribution(`Normal(n2)), `AddOperation) => {
let n: normal = {mean: v1 +. n2.mean, stdev: n2.stdev};
`Distribution(`Normal(n))
}
| (`Distribution(`Normal(n1)), `Distribution(`Normal(n2)), `AddOperation) => {
let n: normal = {mean: n1.mean +. n2.mean, stdev: sqrt(n1.stdev ** 2. +. n2.stdev ** 2.)};
`Distribution(`Normal(n));
}
/* General cases: convolve the XYShapes */
| (`Distribution(d1), `Distribution(d2), _) => {
let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount);
let (sc2, sd2) = renderDistributionToXYShape(d2, sampleCount);
let (sc, sd) = combinationDistributionOfXYShapes(sc1, sd1, sc2, sd2, func);
`RenderedShape(sc, sd, 1.0)
}
| (`Distribution(d1), `RenderedShape(sc2, sd2, i2), _) => {
let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount);
let (sc, sd) = combinationDistributionOfXYShapes(sc1, sd1, sc2, sd2, func);
`RenderedShape(sc, sd, i2)
}
| (`RenderedShape(sc1, sd1, i1), `Distribution(d2), _) => {
let (sc2, sd2) = renderDistributionToXYShape(d2, sampleCount);
let (sc, sd) = combinationDistributionOfXYShapes(sc1, sd1, sc2, sd2, func);
`RenderedShape(sc, sd, i1);
}
| (`RenderedShape(sc1, sd1, i1), `RenderedShape(sc2, sd2, i2), _) => {
// sum of two multimodals that have a continuous and discrete each.
let (sc, sd) = combinationDistributionOfXYShapes(sc1, sd1, sc2, sd2, func);
`RenderedShape(sc, sd, i1);
}
}
};
let evaluatePointwiseSum = (et1: nodeResult, et2: nodeResult, sampleCount: int) => {
switch ((et1, et2)) {
/* Known cases: */
| (`Distribution(`Float(v1)), `Distribution(`Float(v2))) => {
v1 == v2
? `RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.make({xs: [|v1|], ys: [|2.|]}), 2.)
: `RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.) // TODO: add warning: shouldn't pointwise add scalars.
}
| (`Distribution(`Float(v1)), `Distribution(d2)) => {
let sd1: DistTypes.xyShape = {xs: [|v1|], ys: [|1.|]};
let (sc2, sd2) = renderDistributionToXYShape(d2, sampleCount);
`RenderedShape(sc2, Distributions.Discrete.reduce((+.), [|sd1, sd2|]), 2.)
}
| (`Distribution(d1), `Distribution(`Float(v2))) => {
let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount);
let sd2: DistTypes.xyShape = {xs: [|v2|], ys: [|1.|]};
`RenderedShape(sc1, Distributions.Discrete.reduce((+.), [|sd1, sd2|]), 2.)
}
| (`Distribution(d1), `Distribution(d2)) => {
let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount);
let (sc2, sd2) = renderDistributionToXYShape(d2, sampleCount);
`RenderedShape(Distributions.Continuous.reduce((+.), [|sc1, sc2|]), Distributions.Discrete.reduce((+.), [|sd1, sd2|]), 2.)
}
| (`Distribution(d1), `RenderedShape(sc2, sd2, i2))
| (`RenderedShape(sc2, sd2, i2), `Distribution(d1)) => {
let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount);
`RenderedShape(Distributions.Continuous.reduce((+.), [|sc1, sc2|]), Distributions.Discrete.reduce((+.), [|sd1, sd2|]), 1. +. i2)
}
| (`RenderedShape(sc1, sd1, i1), `RenderedShape(sc2, sd2, i2)) => {
Js.log3("Reducing continuous rr", sc1, sc2);
Js.log2("Continuous reduction:", Distributions.Continuous.reduce((+.), [|sc1, sc2|]));
Js.log2("Discrete reduction:", Distributions.Discrete.reduce((+.), [|sd1, sd2|]));
`RenderedShape(Distributions.Continuous.reduce((+.), [|sc1, sc2|]), Distributions.Discrete.reduce((+.), [|sd1, sd2|]), i1 +. i2)
}
}
};
let evaluatePointwiseProduct = (et1: nodeResult, et2: nodeResult, sampleCount: int) => {
switch ((et1, et2)) {
/* Known cases: */
| (`Distribution(`Float(v1)), `Distribution(`Float(v2))) => {
v1 == v2
? `RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.make({xs: [|v1|], ys: [|1.|]}), 1.)
: `RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.) // TODO: add warning: shouldn't pointwise multiply scalars.
}
| (`Distribution(`Float(v1)), `Distribution(d2)) => {
// evaluate d2 at v1
let y = GenericSimple.pdf(v1, d2);
`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.make({xs: [|v1|], ys: [|y|]}), y)
}
| (`Distribution(d1), `Distribution(`Float(v2))) => {
// evaluate d1 at v2
let y = GenericSimple.pdf(v2, d1);
`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.make({xs: [|v2|], ys: [|y|]}), y)
}
| (`Distribution(`Normal(n1)), `Distribution(`Normal(n2))) => {
let mean = (n1.mean *. n2.stdev**2. +. n2.mean *. n1.stdev**2.) /. (n1.stdev**2. +. n2.stdev**2.);
let stdev = 1. /. ((1. /. n1.stdev**2.) +. (1. /. n2.stdev**2.));
let integral = 0; // TODO
`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.)
}
/* General cases */
| (`Distribution(d1), `Distribution(d2)) => {
// NOT IMPLEMENTED YET
// TODO: evaluate integral properly
let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount);
let (sc2, sd2) = renderDistributionToXYShape(d2, sampleCount);
`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.)
}
| (`Distribution(d1), `RenderedShape(sc2, sd2, i2)) => {
// NOT IMPLEMENTED YET
// TODO: evaluate integral properly
let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount);
`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.)
}
| (`RenderedShape(sc1, sd1, i1), `Distribution(d1)) => {
// NOT IMPLEMENTED YET
// TODO: evaluate integral properly
let (sc2, sd2) = renderDistributionToXYShape(d1, sampleCount);
`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.)
}
| (`RenderedShape(sc1, sd1, i1), `RenderedShape(sc2, sd2, i2)) => {
`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.)
}
}
};
let evaluateNormalize = (et: nodeResult, sampleCount: int) => {
// just divide everything by the integral.
switch (et) {
| `RenderedShape(sc, sd, i) => {
// loop through all ys and divide them by i
let normalize = (s: DistTypes.xyShape): DistTypes.xyShape => {xs: s.xs, ys: s.ys |> E.A.fmap(y => y /. i)};
let scn = sc |> Distributions.Continuous.shapeMap(normalize);
let sdn = sd |> normalize;
`RenderedShape(scn, sdn, 1.)
}
| `Distribution(d) => `Distribution(d) // any kind of atomic dist should already be normalized -- TODO: THIS IS ACTUALLY FALSE! E.g. pointwise product of normal * normal
}
};
let evaluateTruncate = (et: nodeResult, xc: cutoffX, compareFunc: (float, float) => bool, sampleCount: int) => {
let cut = (s: DistTypes.xyShape): DistTypes.xyShape => {
let (xs, ys) = s.ys
|> Belt.Array.zip(s.xs)
|> E.A.filter(((x, y)) => compareFunc(x, xc))
|> Belt.Array.unzip
let cutShape: DistTypes.xyShape = {xs, ys};
cutShape;
};
switch (et) {
| `Distribution(d) => {
let (sc, sd) = renderDistributionToXYShape(d, sampleCount);
let scc = sc |> Distributions.Continuous.shapeMap(cut);
let sdc = sd |> cut;
let newIntegral = 1.; // TODO
`RenderedShape(scc, sdc, newIntegral);
}
| `RenderedShape(sc, sd, i) => {
let scc = sc |> Distributions.Continuous.shapeMap(cut);
let sdc = sd |> cut;
let newIntegral = 1.; // TODO
`RenderedShape(scc, sdc, newIntegral);
}
}
};
let evaluateVerticalScaling = (et1: nodeResult, et2: nodeResult, sampleCount: int) => {
let scale = (i: float, s: DistTypes.xyShape): DistTypes.xyShape => {xs: s.xs, ys: s.ys |> E.A.fmap(y => y *. i)};
switch ((et1, et2)) {
| (`Distribution(`Float(v)), `Distribution(d))
| (`Distribution(d), `Distribution(`Float(v))) => {
let (sc, sd) = renderDistributionToXYShape(d, sampleCount);
let scc = sc |> Distributions.Continuous.shapeMap(scale(v));
let sdc = sd |> scale(v);
let newIntegral = v; // TODO
`RenderedShape(scc, sdc, newIntegral);
}
| (`Distribution(`Float(v)), `RenderedShape(sc, sd, i))
| (`RenderedShape(sc, sd, i), `Distribution(`Float(v))) => {
let scc = sc |> Distributions.Continuous.shapeMap(scale(v));
let sdc = sd |> scale(v);
let newIntegral = v; // TODO
`RenderedShape(scc, sdc, newIntegral);
} }
| _ => `RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.) // TODO: give warning
}
}
}, ([||]: array(DistTypes.xyShape), [||]: array(DistTypes.xyShape))) let renderNode = (et: nodeResult, sampleCount: int) => {
switch (et) {
| `Distribution(d) => {
let (sc, sd) = renderDistributionToXYShape(d, sampleCount);
`RenderedShape(sc, sd, 1.0);
}
| s => s
}
}
let rec evaluateNode = (treeNode: distTree, sampleCount: int): nodeResult => {
// returns either a new symbolic distribution
switch (treeNode) {
| `Distribution(d) => evaluateDistribution(d)
| `Combination(t1, t2, op) => evaluateCombinationDistribution(evaluateNode(t1, sampleCount), evaluateNode(t2, sampleCount), op, sampleCount)
| `PointwiseSum(t1, t2) => evaluatePointwiseSum(evaluateNode(t1, sampleCount), evaluateNode(t2, sampleCount), sampleCount)
| `PointwiseProduct(t1, t2) => evaluatePointwiseProduct(evaluateNode(t1, sampleCount), evaluateNode(t2, sampleCount), sampleCount)
| `VerticalScaling(t1, t2) => evaluateVerticalScaling(evaluateNode(t1, sampleCount), evaluateNode(t2, sampleCount), sampleCount)
| `Normalize(t) => evaluateNormalize(evaluateNode(t, sampleCount), sampleCount)
| `LeftTruncate(t, x) => evaluateTruncate(evaluateNode(t, sampleCount), x, (<=), sampleCount)
| `RightTruncate(t, x) => evaluateTruncate(evaluateNode(t, sampleCount), x, (>=), sampleCount)
| `Render(t) => renderNode(evaluateNode(t, sampleCount), sampleCount)
}
}; };
/* let toShape = (treeNode: distTree, sampleCount: int) => {
We will assume that each dist (of t) in the multimodal has a total of one. /*let continuousXs = findContinuousXs(dists, sampleCount);
We can therefore normalize the weights of the parts.
However, a multimodal can consist of both discrete and continuous shapes.
These need to be added and collected individually.
*/
let toShape = (dists: t, sampleCount: int) => {
let continuousXs = findContinuousXs(dists, sampleCount);
continuousXs |> Array.fast_sort(compare); continuousXs |> Array.fast_sort(compare);
let (contShapes, distShapes) = accumulateContAndDiscShapes(dists, continuousXs, 1.0); let (contShapes, distShapes) = accumulateContAndDiscShapes(dists, continuousXs, 1.0);
@ -469,60 +649,42 @@ module PointwiseAddDistributionsWeighted = {
}, {xs: continuousXs, ys: Array.make(Array.length(continuousXs), 0.0)}) }, {xs: continuousXs, ys: Array.make(Array.length(continuousXs), 0.0)})
|> Distributions.Continuous.make(`Linear); |> Distributions.Continuous.make(`Linear);
let combinedDiscrete = Distributions.Discrete.reduce((+.), distShapes) let combinedDiscrete = Distributions.Discrete.reduce((+.), distShapes)*/
let shape = MixedShapeBuilder.buildSimple(~continuous=Some(combinedContinuous), ~discrete=combinedDiscrete); let treeShape = evaluateNode(`Render(`Normalize(treeNode)), sampleCount);
switch (treeShape) {
| `Distribution(_) => E.O.toExn("No shape found!", None)
| `RenderedShape(sc, sd, _) => {
let shape = MixedShapeBuilder.buildSimple(~continuous=Some(sc), ~discrete=sd);
shape |> E.O.toExt(""); shape |> E.O.toExt("");
}
}
}; };
let rec toString = (dists: t): string => { let rec toString = (treeNode: distTree): string => {
let distString = let stringFromOp = op => switch (op) {
dists | `AddOperation => " + "
|> E.A.fmap(((d, _)) => | `SubtractOperation => " - "
switch (d) { | `MultiplyOperation => " * "
| `Simple(d) => GenericSimple.toString(d) | `DivideOperation => " / "
| `PointwiseCombination(ts: t) => ts |> toString | `ExponentiateOperation => "^"
} };
)
|> Js.Array.joinWith(",");
// mm(normal(0,1), normal(1,2)) => "multimodal(normal(0,1), normal(1,2), ) switch (treeNode) {
| `Distribution(d) => GenericSimple.toString(d)
let weights = | `Combination(t1, t2, op) => toString(t1) ++ stringFromOp(op) ++ toString(t2)
dists | `PointwiseSum(t1, t2) => toString(t1) ++ " .+ " ++ toString(t2)
|> E.A.fmap(((_, w)) => | `PointwiseProduct(t1, t2) => toString(t1) ++ " .* " ++ toString(t2)
Js.Float.toPrecisionWithPrecision(w, ~digits=2) | `VerticalScaling(t1, t2) => toString(t1) ++ " @ " ++ toString(t2)
) | `Normalize(t) => "normalize(" ++ toString(t) ++ ")"
|> Js.Array.joinWith(","); | `LeftTruncate(t, x) => "leftTruncate(" ++ toString(t) ++ ", " ++ string_of_float(x) ++ ")"
| `RightTruncate(t, x) => "rightTruncate(" ++ toString(t) ++ ", " ++ string_of_float(x) ++ ")"
{j|multimodal($distString, [$weights])|j}; }
}; };
}; };
// assume that recursive pointwiseNormalizedDistSums are the only type of operation there is. let toString = (treeNode: distTree) => DistTree.toString(treeNode)
// in the original, it was a list of (dist, weight) tuples. Now, it's a tree of (dist, weight) tuples, just that every
// dist can be either a GenericSimple or another PointwiseAdd.
/*let toString = (r: bigDistTree) => { let toShape = (sampleCount: int, treeNode: distTree) =>
switch (r) { DistTree.toShape(treeNode, sampleCount) //~xSelection=`ByWeight,
| WeightedDistLeaf((w, d)) => GenericWeighted.toString(w) // "normal "
| PointwiseNormalizedDistSum(childTrees) => childTrees |> E.A.fmap(toString) |> Js.Array.joinWith("")
}
}*/
let toString = (r: bigDist) =>
// we need to recursively create the string representation of the tree.
r
|> (
fun
| `Simple(d) => GenericSimple.toString(d)
| `PointwiseCombination(d) =>
PointwiseAddDistributionsWeighted.toString(d)
);
let toShape = n =>
fun
| `Simple(d) => GenericSimple.toShape(~xSelection=`ByWeight, d, n)
| `PointwiseCombination(d) =>
PointwiseAddDistributionsWeighted.toShape(d, n);