Big refactor towards proper distTree, still slow and untested
parent bc271a090b
commit f6c1918b12
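For orientation: given the multiModal and distTree changes below, a call such as mm(normal(0, 1), normal(1, 2)) is now expected to parse into a tree roughly like the following sketch (constructor names are the ones introduced in this commit; the 1.0 weights are the multiModal defaults):

    `Normalize(
      `PointwiseSum(
        `VerticalScaling(`Distribution(`Normal({mean: 0., stdev: 1.})), `Distribution(`Float(1.0))),
        `VerticalScaling(`Distribution(`Normal({mean: 1., stdev: 2.})), `Distribution(`Float(1.0))),
      ),
    )
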
@@ -74,6 +74,21 @@ module Continuous = {
 (fn, {xyShape, interpolation}: t): option(DistTypes.continuousShape) =>
 fn(xyShape) |> E.O.fmap(make(interpolation));
 
+let empty: DistTypes.continuousShape = {xyShape: XYShape.T.empty, interpolation: `Linear};
+let combine =
+(fn, t1: DistTypes.continuousShape, t2: DistTypes.continuousShape)
+: DistTypes.continuousShape => {
+make(`Linear, XYShape.Combine.combine(
+~xsSelection=ALL_XS,
+~xToYSelection=XYShape.XtoY.linear,
+~fn,
+t1.xyShape,
+t2.xyShape,
+));
+};
+let reduce = (fn, items) =>
+items |> E.A.fold_left(combine(fn), empty);
+
 let toLinear = (t: t): option(t) => {
 switch (t) {
 | {interpolation: `Stepwise, xyShape} =>
@@ -166,6 +181,7 @@ module Discrete = {
 let sortedByX = (t: DistTypes.discreteShape) =>
 t |> XYShape.T.zip |> XYShape.Zipped.sortByX;
 let empty = XYShape.T.empty;
+let make = (s: DistTypes.discreteShape) => s;
 let combine =
 (fn, t1: DistTypes.discreteShape, t2: DistTypes.discreteShape)
 : DistTypes.discreteShape => {
@@ -179,16 +179,25 @@ module Combine = {
 t1: T.t,
 t2: T.t,
 ) => {
-let allXs =
-switch (xsSelection) {
-| ALL_XS => Ts.allXs([|t1, t2|])
-| XS_EVENLY_DIVIDED(sampleCount) =>
-Ts.equallyDividedXs([|t1, t2|], sampleCount)
-};
 
-let allYs =
-allXs |> E.A.fmap(x => fn(xToYSelection(x, t1), xToYSelection(x, t2)));
-T.fromArrays(allXs, allYs);
+switch ((E.A.length(t1.xs), E.A.length(t2.xs))) {
+| (0, 0) => T.empty
+| (0, _) => t2
+| (_, 0) => t1
+| (_, _) => {
+let allXs =
+switch (xsSelection) {
+| ALL_XS => Ts.allXs([|t1, t2|])
+| XS_EVENLY_DIVIDED(sampleCount) =>
+Ts.equallyDividedXs([|t1, t2|], sampleCount)
+};
+
+let allYs =
+allXs |> E.A.fmap(x => fn(xToYSelection(x, t1), xToYSelection(x, t2)));
+
+T.fromArrays(allXs, allYs);
+}
+}
 };
 
 let combineLinear = combine(~xToYSelection=XtoY.linear);
@@ -43,7 +43,7 @@ module ShapeRenderer = {
 module Symbolic = {
 type inputs = {length: int};
 type outputs = {
-graph: SymbolicDist.bigDist,
+graph: SymbolicDist.distTree,
 shape: DistTypes.shape,
 };
 let make = (graph, shape) => {graph, shape};
@@ -88,73 +88,69 @@ module MathAdtToDistDst = {
 );
 };
 
-let normal: array(arg) => result(SymbolicDist.bigDist, string) =
+let normal: array(arg) => result(SymbolicDist.distTree, string) =
 fun
 | [|Value(mean), Value(stdev)|] =>
-Ok(`Simple(`Normal({mean, stdev})))
+Ok(`Distribution(`Normal({mean, stdev})))
 | _ => Error("Wrong number of variables in normal distribution");
 
-let lognormal: array(arg) => result(SymbolicDist.bigDist, string) =
+let lognormal: array(arg) => result(SymbolicDist.distTree, string) =
 fun
-| [|Value(mu), Value(sigma)|] => Ok(`Simple(`Lognormal({mu, sigma})))
+| [|Value(mu), Value(sigma)|] => Ok(`Distribution(`Lognormal({mu, sigma})))
 | [|Object(o)|] => {
 let g = Js.Dict.get(o);
 switch (g("mean"), g("stdev"), g("mu"), g("sigma")) {
 | (Some(Value(mean)), Some(Value(stdev)), _, _) =>
-Ok(`Simple(SymbolicDist.Lognormal.fromMeanAndStdev(mean, stdev)))
+Ok(`Distribution(SymbolicDist.Lognormal.fromMeanAndStdev(mean, stdev)))
 | (_, _, Some(Value(mu)), Some(Value(sigma))) =>
-Ok(`Simple(`Lognormal({mu, sigma})))
+Ok(`Distribution(`Lognormal({mu, sigma})))
 | _ => Error("Lognormal distribution would need mean and stdev")
 };
 }
 | _ => Error("Wrong number of variables in lognormal distribution");
 
-let to_: array(arg) => result(SymbolicDist.bigDist, string) =
+let to_: array(arg) => result(SymbolicDist.distTree, string) =
 fun
 | [|Value(low), Value(high)|] when low <= 0.0 && low < high=> {
-Ok(`Simple(SymbolicDist.Normal.from90PercentCI(low, high)));
+Ok(`Distribution(SymbolicDist.Normal.from90PercentCI(low, high)));
 }
 | [|Value(low), Value(high)|] when low < high => {
-Ok(`Simple(SymbolicDist.Lognormal.from90PercentCI(low, high)));
+Ok(`Distribution(SymbolicDist.Lognormal.from90PercentCI(low, high)));
 }
 | [|Value(_), Value(_)|] =>
 Error("Low value must be less than high value.")
 | _ => Error("Wrong number of variables in lognormal distribution");
 
-let uniform: array(arg) => result(SymbolicDist.bigDist, string) =
+let uniform: array(arg) => result(SymbolicDist.distTree, string) =
 fun
-| [|Value(low), Value(high)|] => Ok(`Simple(`Uniform({low, high})))
+| [|Value(low), Value(high)|] => Ok(`Distribution(`Uniform({low, high})))
 | _ => Error("Wrong number of variables in lognormal distribution");
 
-let beta: array(arg) => result(SymbolicDist.bigDist, string) =
+let beta: array(arg) => result(SymbolicDist.distTree, string) =
 fun
-| [|Value(alpha), Value(beta)|] => Ok(`Simple(`Beta({alpha, beta})))
+| [|Value(alpha), Value(beta)|] => Ok(`Distribution(`Beta({alpha, beta})))
 | _ => Error("Wrong number of variables in lognormal distribution");
 
-let exponential: array(arg) => result(SymbolicDist.bigDist, string) =
+let exponential: array(arg) => result(SymbolicDist.distTree, string) =
 fun
-| [|Value(rate)|] => Ok(`Simple(`Exponential({rate: rate})))
+| [|Value(rate)|] => Ok(`Distribution(`Exponential({rate: rate})))
 | _ => Error("Wrong number of variables in Exponential distribution");
 
-let cauchy: array(arg) => result(SymbolicDist.bigDist, string) =
+let cauchy: array(arg) => result(SymbolicDist.distTree, string) =
 fun
 | [|Value(local), Value(scale)|] =>
-Ok(`Simple(`Cauchy({local, scale})))
+Ok(`Distribution(`Cauchy({local, scale})))
 | _ => Error("Wrong number of variables in cauchy distribution");
 
-let triangular: array(arg) => result(SymbolicDist.bigDist, string) =
+let triangular: array(arg) => result(SymbolicDist.distTree, string) =
 fun
 | [|Value(low), Value(medium), Value(high)|] =>
-Ok(`Simple(`Triangular({low, medium, high})))
+Ok(`Distribution(`Triangular({low, medium, high})))
 | _ => Error("Wrong number of variables in triangle distribution");
 
-/*let add: array(arg) => result(SymbolicDist.bigDist, string) =
-fun
-| */
-
 let multiModal =
 (
-args: array(result(SymbolicDist.bigDist, string)),
+args: array(result(SymbolicDist.distTree, string)),
 weights: option(array(float)),
 ) => {
 let weights = weights |> E.O.default([||]);
@@ -162,8 +158,14 @@ module MathAdtToDistDst = {
 args
 |> E.A.fmap(
 fun
-| Ok(`Simple(d)) => Ok(`Simple(d))
-| Ok(`PointwiseCombination(dists)) => Ok(`PointwiseCombination(dists))
+| Ok(`Distribution(d)) => Ok(`Distribution(d))
+| Ok(`Combination(t1, t2, op)) => Ok(`Combination(t1, t2, op))
+| Ok(`PointwiseSum(t1, t2)) => Ok(`PointwiseSum(t1, t2))
+| Ok(`PointwiseProduct(t1, t2)) => Ok(`PointwiseProduct(t1, t2))
+| Ok(`Normalize(t)) => Ok(`Normalize(t))
+| Ok(`LeftTruncate(t, x)) => Ok(`LeftTruncate(t, x))
+| Ok(`RightTruncate(t, x)) => Ok(`RightTruncate(t, x))
+| Ok(`Render(t)) => Ok(`Render(t))
 | Error(e) => Error(e)
 | _ => Error("Unexpected dist")
 );
@@ -175,16 +177,26 @@ module MathAdtToDistDst = {
 | Some(Error(e)) => Error(e)
 | None when withoutErrors |> E.A.length == 0 =>
 Error("Multimodals need at least one input")
-| _ =>
-withoutErrors
-|> E.A.fmapi((index, item) =>
-(item, weights |> E.A.get(_, index) |> E.O.default(1.0))
-)
-|> (r => Ok(`PointwiseCombination(r)))
+| _ => {
+let components = withoutErrors
+|> E.A.fmapi((index, t) => {
+let w = weights |> E.A.get(_, index) |> E.O.default(1.0);
+`VerticalScaling(t, `Distribution(`Float(w)))
+});
+
+let pointwiseSum = components
+|> Js.Array.sliceFrom(1)
+|> E.A.fold_left((acc, x) => {
+`PointwiseSum(acc, x)
+}, E.A.unsafe_get(components, 0))
+
+Ok(`Normalize(pointwiseSum))
+}
 };
 };
 
-let arrayParser = (args:array(arg)):result(SymbolicDist.bigDist, string) => {
+let arrayParser = (args:array(arg)):result(SymbolicDist.distTree, string) => {
 let samples = args
 |> E.A.fmap(
 fun
@@ -200,13 +212,13 @@ module MathAdtToDistDst = {
 SymbolicDist.ContinuousShape.make(_pdf, cdf)
 });
 switch(shape){
-| Some(s) => Ok(`Simple(`ContinuousShape(s)))
+| Some(s) => Ok(`Distribution(`ContinuousShape(s)))
 | None => Error("Rendering did not work")
 }
 }
 
 
-let rec functionParser = (r): result(SymbolicDist.bigDist, string) =>
+let rec functionParser = (r): result(SymbolicDist.distTree, string) =>
 r
 |> (
 fun
@@ -218,7 +230,7 @@ module MathAdtToDistDst = {
 | Fn({name: "exponential", args}) => exponential(args)
 | Fn({name: "cauchy", args}) => cauchy(args)
 | Fn({name: "triangular", args}) => triangular(args)
-| Value(f) => Ok(`Simple(`Float(f)))
+| Value(f) => Ok(`Distribution(`Float(f)))
 | Fn({name: "mm", args}) => {
 let weights =
 args
@@ -245,25 +257,54 @@ module MathAdtToDistDst = {
 let dists = possibleDists |> E.A.fmap(functionParser);
 multiModal(dists, weights);
 }
-//| Fn({name: "add", args}) => add(args)
+| Fn({name: "add", args}) => {
+args
+|> E.A.fmap(functionParser)
+|> (fun
+| [|Ok(l), Ok(r)|] => Ok(`Combination(l, r, `AddOperation))
+| _ => Error("Addition needs two operands"))
+}
+| Fn({name: "subtract", args}) => {
+args
+|> E.A.fmap(functionParser)
+|> (fun
+| [|Ok(l), Ok(r)|] => Ok(`Combination(l, r, `SubtractOperation))
+| _ => Error("Subtraction needs two operands"))
+}
+| Fn({name: "multiply", args}) => {
+args
+|> E.A.fmap(functionParser)
+|> (fun
+| [|Ok(l), Ok(r)|] => Ok(`Combination(l, r, `MultiplyOperation))
+| _ => Error("Multiplication needs two operands"))
+}
+| Fn({name: "divide", args}) => {
+args
+|> E.A.fmap(functionParser)
+|> (fun
+| [|Ok(l), Ok(`Distribution(`Float(0.0)))|] => Error("Division by zero")
+| [|Ok(l), Ok(r)|] => Ok(`Combination(l, r, `DivideOperation))
+| _ => Error("Division needs two operands"))
+}
 | Fn({name}) => Error(name ++ ": function not supported")
 | _ => {
 Error("This type not currently supported");
 }
 );
 
-let topLevel = (r): result(SymbolicDist.bigDist, string) =>
+let topLevel = (r): result(SymbolicDist.distTree, string) =>
 r
 |> (
 fun
 | Fn(_) => functionParser(r)
-| Value(r) => Ok(`Simple(`Float(r)))
+| Value(r) => Ok(`Distribution(`Float(r)))
 | Array(r) => arrayParser(r)
 | Symbol(_) => Error("Symbol not valid as top level")
 | Object(_) => Error("Object not valid as top level")
 );
 
-let run = (r): result(SymbolicDist.bigDist, string) =>
+let run = (r): result(SymbolicDist.distTree, string) =>
 r |> MathAdtCleaner.run |> topLevel;
 };
@@ -50,41 +50,28 @@ type dist = [
 | `Float(float) // Dirac delta at x. Practically useful only in the context of multimodals.
 ];
 
-/* Build a tree.
-
-Multiple operations possible:
-
-- PointwiseSum(Scalar, Scalar)
-- PointwiseSum(WeightedDist, WeightedDist)
-- PointwiseProduct(Scalar, Scalar)
-- PointwiseProduct(Scalar, WeightedDist)
-- PointwiseProduct(WeightedDist, WeightedDist)
-
-- IndependentVariableSum(WeightedDist, WeightedDist) [i.e., convolution]
-- IndependentVariableProduct(WeightedDist, WeightedDist) [i.e. distribution product]
-*/
-
-/*type weightedDist = (float, dist);
-
-type bigDistTree =
-/* | DistLeaf(dist) */
-/* | ScalarLeaf(float) */
-/* | PointwiseScalarDistProduct(DistLeaf(d), ScalarLeaf(s)) */
-| WeightedDistLeaf(weightedDist)
-| PointwiseNormalizedDistSum(array(bigDistTree));
-
-let rec treeIntegral = item => {
-switch (item) {
-| WeightedDistLeaf((w, d)) => w
-| PointwiseNormalizedDistSum(childTrees) =>
-childTrees |> E.A.fmap(treeIntegral) |> E.A.Floats.sum
-};
-};*/
-
-/* bigDist can either be a single distribution, or a
-PointwiseCombination, i.e. an array of (dist, weight) tuples */
-type bigDist = [ | `Simple(dist) | `PointwiseCombination(pointwiseAdd)]
-and pointwiseAdd = array((bigDist, float));
+type integral = float;
+type cutoffX = float;
+type operation = [
+| `AddOperation
+| `SubtractOperation
+| `MultiplyOperation
+| `DivideOperation
+| `ExponentiateOperation
+];
+
+type distTree = [
+| `Distribution(dist)
+| `Combination(distTree, distTree, operation)
+| `PointwiseSum(distTree, distTree)
+| `PointwiseProduct(distTree, distTree)
+| `VerticalScaling(distTree, distTree)
+| `Normalize(distTree)
+| `LeftTruncate(distTree, cutoffX)
+| `RightTruncate(distTree, cutoffX)
+| `Render(distTree)
+]
+and weightedDists = array((distTree, float));
 
 module ContinuousShape = {
 type t = continuousShape;
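A sketch of how functionParser is expected to use these constructors for arithmetic on distributions: the expression multiply(normal(5, 2), 3) would become

    let product: distTree =
      `Combination(
        `Distribution(`Normal({mean: 5., stdev: 2.})),
        `Distribution(`Float(3.0)),
        `MultiplyOperation,
      );

In the DistTree module below, evaluateNode maps `MultiplyOperation to ( *. ) via funcFromOp and, since there is no closed-form rule for this pair, falls through to the XYShape convolution path.
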
@@ -326,138 +313,331 @@ module GenericSimple = {
 };
 };
 
-module PointwiseAddDistributionsWeighted = {
-type t = pointwiseAdd;
-let normalizeWeights = (weightedDists: t) => {
-let total = weightedDists |> E.A.fmap(snd) |> E.A.Floats.sum;
-weightedDists |> E.A.fmap(((d, w)) => (d, w /. total));
+module DistTree = {
+type nodeResult = [
+| `Distribution(dist)
+// RenderedShape: continuous xyShape, discrete xyShape, total value.
+| `RenderedShape(DistTypes.continuousShape, DistTypes.discreteShape, integral)
+];
+
+let evaluateDistribution = (d: dist): nodeResult => {
+// certain distributions we may want to evaluate to RenderedShapes right away, e.g. discrete
+`Distribution(d)
 };
 
-let rec pdf = (x: float, weightedNormalizedDists: t) =>
-weightedNormalizedDists
-|> E.A.fmap(((d, w)) => {
-switch (d) {
-| `PointwiseCombination(ts) => pdf(x, ts) *. w
-| `Simple(d) => GenericSimple.pdf(x, d) *. w
-}
-})
-|> E.A.Floats.sum;
+// This is a performance bottleneck!
+// Using raw JS here so we can use native for loops and access array elements
+// directly, without option checks.
+let jsCombinationConvolve: (array(float), array(float), array(float), array(float), float => float => float) => (array(float), array(float)) = [%bs.raw
+{|
+function (s1xs, s1ys, s2xs, s2ys, func) {
+const r = new Map();
 
-// TODO: perhaps rename into minCdfX?
-// TODO: how should nonexistent min values be handled? They should never happen
-let rec min = (dists: t) =>
-dists
-|> E.A.fmap(((d, w)) => {
-switch (d) {
-| `PointwiseCombination(ts) => E.O.toExn("Dist has no min", min(ts))
-| `Simple(d) => GenericSimple.min(d)
-}
-})
-|> E.A.min;
+// To convolve, add the xs and multiply the ys:
+for (let i = 0; i < s1xs.length; i++) {
+for (let j = 0; j < s2xs.length; j++) {
+const x = func(s1xs[i], s2xs[j]);
+const cv = r.get(x) | 0;
+r.set(x, cv + s1ys[i] * s2ys[j]); // add up the ys, if same x
+}
+}
 
-// TODO: perhaps rename into minCdfX?
-let rec max = (dists: t) =>
-dists
-|> E.A.fmap(((d, w)) => {
-switch (d) {
-| `PointwiseCombination(ts) => E.O.toExn("Dist has no max", max(ts))
-| `Simple(d) => GenericSimple.max(d)
-}
-})
-|> E.A.max;
+const rxys = [...r.entries()];
+rxys.sort(([x1, y1], [x2, y2]) => x1 - x2);
 
+const rxs = new Array(rxys.length);
+const rys = new Array(rxys.length);
 
-/*let rec discreteShape = (dists: t, sampleCount: int) => {
-let discrete =
-dists
-|> E.A.fmap(((x, w)) => {
-switch (d) {
-| `Float(d) => Some((d, w)) // if the distribution is just a number, then the weight is considered the y
-| _ => None
-}
-})
-|> E.A.O.concatSomes
-|> E.A.fmap(((x, y)) =>
-({xs: [|x|], ys: [|y|]}: DistTypes.xyShape)
-)
-// take an array of xyShapes and combine them together
-//* r
-|> (
-fun
-| `Float(r) => Some((r, e))
-| _ => None
-)
-)*/
-|> Distributions.Discrete.reduce((+.));
-discrete;
-};*/
+for (let i = 0; i < rxys.length; i++) {
+rxs[i] = rxys[i][0];
+rys[i] = rxys[i][1];
+}
 
-let rec findContinuousXs = (dists: t, sampleCount: int) => {
-// we need to go through the tree of distributions and, for the continuous ones, find the xs at which
-// later, all distributions will get evaluated.
+return [rxs, rys];
+}
+|}];
 
-// we want to accumulate a set of xs.
-let xs: array(float) =
-dists
-|> E.A.fold_left((accXs, (d, w)) => {
-switch (d) {
-| `Simple(t) when (GenericSimple.contType(t) == `Discrete) => accXs
-| `Simple(d) => {
-let xs = GenericSimple.interpolateXs(~xSelection=`ByWeight, d, sampleCount)
-E.A.append(accXs, xs)
-}
-| `PointwiseCombination(ts) => {
-let xs = findContinuousXs(ts, sampleCount);
-E.A.append(accXs, xs)
-}
-}
-}, [||]);
-xs
+let funcFromOp = (op: operation) => {
+switch (op) {
+| `AddOperation => (+.)
+| `SubtractOperation => (-.)
+| `MultiplyOperation => (*.)
+| `DivideOperation => (/.)
+| `ExponentiateOperation => (**)
+}
+}
+
+let renderDistributionToXYShape = (d: dist, sampleCount: int): (DistTypes.continuousShape, DistTypes.discreteShape) => {
+// render the distribution into an XY shape
+switch (d) {
+| `Float(v) => (Distributions.Continuous.empty, {xs: [|v|], ys: [|1.0|]})
+| _ => {
+let xs = GenericSimple.interpolateXs(~xSelection=`ByWeight, d, sampleCount);
+let ys = xs |> E.A.fmap(x => GenericSimple.pdf(x, d));
+(Distributions.Continuous.make(`Linear, {xs: xs, ys: ys}), XYShape.T.empty)
+}
+}
 };
 
-/* Accumulate (accContShapes, accDistShapes), each of which is an array of {xs, ys} shapes. */
-let rec accumulateContAndDiscShapes = (dists: t, continuousXs: array(float), currentWeight) => {
-let normalized = normalizeWeights(dists);
-normalized
-|> E.A.fold_left(((accContShapes: array(DistTypes.xyShape), accDiscShapes: array(DistTypes.xyShape)), (d, w)) => {
-switch (d) {
-| `Simple(`Float(x)) => {
-let ds: DistTypes.xyShape = {xs: [|x|], ys: [|w *. currentWeight|]};
-(accContShapes, E.A.append(accDiscShapes, [|ds|]))
-}
-| `Simple(d) when (GenericSimple.contType(d) == `Continuous) => {
-let ys = continuousXs |> E.A.fmap(x => GenericSimple.pdf(x, d) *. w *. currentWeight);
-let cs = XYShape.T.fromArrays(continuousXs, ys);
-(E.A.append(accContShapes, [|cs|]), accDiscShapes)
-}
-| `Simple(d) => (accContShapes, accDiscShapes) // default -- should never happen
-| `PointwiseCombination(ts) => {
-let (cs, ds) = accumulateContAndDiscShapes(ts, continuousXs, w *. currentWeight);
-(E.A.append(accContShapes, cs), E.A.append(accDiscShapes, ds))
-}
+let combinationDistributionOfXYShapes = (sc1: DistTypes.continuousShape, // continuous shape
+sd1: DistTypes.discreteShape, // discrete shape
+sc2: DistTypes.continuousShape,
+sd2: DistTypes.discreteShape, func): (DistTypes.continuousShape, DistTypes.discreteShape) => {
+
+let (ccxs, ccys) = jsCombinationConvolve(sc1.xyShape.xs, sc1.xyShape.ys, sc2.xyShape.xs, sc2.xyShape.ys, func);
+let (dcxs, dcys) = jsCombinationConvolve(sd1.xs, sd1.ys, sc2.xyShape.xs, sc2.xyShape.ys, func);
+let (cdxs, cdys) = jsCombinationConvolve(sc1.xyShape.xs, sc1.xyShape.ys, sd2.xs, sd2.ys, func);
+let (ddxs, ddys) = jsCombinationConvolve(sd1.xs, sd1.ys, sd2.xs, sd2.ys, func);
+
+let ccxy = Distributions.Continuous.make(`Linear, {xs: ccxs, ys: ccys});
+let dcxy = Distributions.Continuous.make(`Linear, {xs: dcxs, ys: dcys});
+let cdxy = Distributions.Continuous.make(`Linear, {xs: cdxs, ys: cdys});
+// the continuous parts are added up; only the discrete-discrete sum is discrete
+let continuousShapeSum = Distributions.Continuous.reduce((+.), [|ccxy, dcxy, cdxy|]);
+
+let ddxy: DistTypes.discreteShape = {xs: cdxs, ys: cdys};
+
+(continuousShapeSum, ddxy)
+};
+
+let evaluateCombinationDistribution = (et1: nodeResult, et2: nodeResult, op: operation, sampleCount: int) => {
+/* return either a Distribution or a RenderedShape. Must integrate to 1. */
+
+let func = funcFromOp(op);
+switch ((et1, et2, op)) {
+/* Known cases: replace symbolic with symbolic distribution */
+| (`Distribution(`Float(v1)), `Distribution(`Float(v2)), _) => {
+`Distribution(`Float(func(v1, v2)))
+}
+
+| (`Distribution(`Float(v1)), `Distribution(`Normal(n2)), `AddOperation) => {
+let n: normal = {mean: v1 +. n2.mean, stdev: n2.stdev};
+`Distribution(`Normal(n))
+}
+
+| (`Distribution(`Normal(n1)), `Distribution(`Normal(n2)), `AddOperation) => {
+let n: normal = {mean: n1.mean +. n2.mean, stdev: sqrt(n1.stdev ** 2. +. n2.stdev ** 2.)};
+`Distribution(`Normal(n));
+}
+
+/* General cases: convolve the XYShapes */
+| (`Distribution(d1), `Distribution(d2), _) => {
+let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount);
+let (sc2, sd2) = renderDistributionToXYShape(d2, sampleCount);
+let (sc, sd) = combinationDistributionOfXYShapes(sc1, sd1, sc2, sd2, func);
+`RenderedShape(sc, sd, 1.0)
+}
+| (`Distribution(d1), `RenderedShape(sc2, sd2, i2), _) => {
+let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount);
+let (sc, sd) = combinationDistributionOfXYShapes(sc1, sd1, sc2, sd2, func);
+`RenderedShape(sc, sd, i2)
+}
+| (`RenderedShape(sc1, sd1, i1), `Distribution(d2), _) => {
+let (sc2, sd2) = renderDistributionToXYShape(d2, sampleCount);
+let (sc, sd) = combinationDistributionOfXYShapes(sc1, sd1, sc2, sd2, func);
+`RenderedShape(sc, sd, i1);
+}
+| (`RenderedShape(sc1, sd1, i1), `RenderedShape(sc2, sd2, i2), _) => {
+// sum of two multimodals that have a continuous and discrete each.
+let (sc, sd) = combinationDistributionOfXYShapes(sc1, sd1, sc2, sd2, func);
+`RenderedShape(sc, sd, i1);
+}
+}
+};
+
+let evaluatePointwiseSum = (et1: nodeResult, et2: nodeResult, sampleCount: int) => {
+switch ((et1, et2)) {
+/* Known cases: */
+| (`Distribution(`Float(v1)), `Distribution(`Float(v2))) => {
+v1 == v2
+? `RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.make({xs: [|v1|], ys: [|2.|]}), 2.)
+: `RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.) // TODO: add warning: shouldn't pointwise add scalars.
+}
+| (`Distribution(`Float(v1)), `Distribution(d2)) => {
+let sd1: DistTypes.xyShape = {xs: [|v1|], ys: [|1.|]};
+let (sc2, sd2) = renderDistributionToXYShape(d2, sampleCount);
+`RenderedShape(sc2, Distributions.Discrete.reduce((+.), [|sd1, sd2|]), 2.)
+}
+| (`Distribution(d1), `Distribution(`Float(v2))) => {
+let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount);
+let sd2: DistTypes.xyShape = {xs: [|v2|], ys: [|1.|]};
+`RenderedShape(sc1, Distributions.Discrete.reduce((+.), [|sd1, sd2|]), 2.)
+}
+| (`Distribution(d1), `Distribution(d2)) => {
+let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount);
+let (sc2, sd2) = renderDistributionToXYShape(d2, sampleCount);
+`RenderedShape(Distributions.Continuous.reduce((+.), [|sc1, sc2|]), Distributions.Discrete.reduce((+.), [|sd1, sd2|]), 2.)
+}
+| (`Distribution(d1), `RenderedShape(sc2, sd2, i2))
+| (`RenderedShape(sc2, sd2, i2), `Distribution(d1)) => {
+let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount);
+`RenderedShape(Distributions.Continuous.reduce((+.), [|sc1, sc2|]), Distributions.Discrete.reduce((+.), [|sd1, sd2|]), 1. +. i2)
+}
+| (`RenderedShape(sc1, sd1, i1), `RenderedShape(sc2, sd2, i2)) => {
+Js.log3("Reducing continuous rr", sc1, sc2);
+Js.log2("Continuous reduction:", Distributions.Continuous.reduce((+.), [|sc1, sc2|]));
+Js.log2("Discrete reduction:", Distributions.Discrete.reduce((+.), [|sd1, sd2|]));
+`RenderedShape(Distributions.Continuous.reduce((+.), [|sc1, sc2|]), Distributions.Discrete.reduce((+.), [|sd1, sd2|]), i1 +. i2)
+}
+}
+};
+
+let evaluatePointwiseProduct = (et1: nodeResult, et2: nodeResult, sampleCount: int) => {
+switch ((et1, et2)) {
+/* Known cases: */
+| (`Distribution(`Float(v1)), `Distribution(`Float(v2))) => {
+v1 == v2
+? `RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.make({xs: [|v1|], ys: [|1.|]}), 1.)
+: `RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.) // TODO: add warning: shouldn't pointwise multiply scalars.
+}
+| (`Distribution(`Float(v1)), `Distribution(d2)) => {
+// evaluate d2 at v1
+let y = GenericSimple.pdf(v1, d2);
+`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.make({xs: [|v1|], ys: [|y|]}), y)
+}
+| (`Distribution(d1), `Distribution(`Float(v2))) => {
+// evaluate d1 at v2
+let y = GenericSimple.pdf(v2, d1);
+`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.make({xs: [|v2|], ys: [|y|]}), y)
+}
+| (`Distribution(`Normal(n1)), `Distribution(`Normal(n2))) => {
+let mean = (n1.mean *. n2.stdev**2. +. n2.mean *. n1.stdev**2.) /. (n1.stdev**2. +. n2.stdev**2.);
+let stdev = 1. /. ((1. /. n1.stdev**2.) +. (1. /. n2.stdev**2.));
+let integral = 0; // TODO
+`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.)
+}
+/* General cases */
+| (`Distribution(d1), `Distribution(d2)) => {
+// NOT IMPLEMENTED YET
+// TODO: evaluate integral properly
+let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount);
+let (sc2, sd2) = renderDistributionToXYShape(d2, sampleCount);
+`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.)
+}
+| (`Distribution(d1), `RenderedShape(sc2, sd2, i2)) => {
+// NOT IMPLEMENTED YET
+// TODO: evaluate integral properly
+let (sc1, sd1) = renderDistributionToXYShape(d1, sampleCount);
+`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.)
+}
+| (`RenderedShape(sc1, sd1, i1), `Distribution(d1)) => {
+// NOT IMPLEMENTED YET
+// TODO: evaluate integral properly
+let (sc2, sd2) = renderDistributionToXYShape(d1, sampleCount);
+`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.)
+}
+| (`RenderedShape(sc1, sd1, i1), `RenderedShape(sc2, sd2, i2)) => {
+`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.)
+}
+}
+};
+
+let evaluateNormalize = (et: nodeResult, sampleCount: int) => {
+// just divide everything by the integral.
+switch (et) {
+| `RenderedShape(sc, sd, i) => {
+// loop through all ys and divide them by i
+let normalize = (s: DistTypes.xyShape): DistTypes.xyShape => {xs: s.xs, ys: s.ys |> E.A.fmap(y => y /. i)};
+let scn = sc |> Distributions.Continuous.shapeMap(normalize);
+let sdn = sd |> normalize;
+`RenderedShape(scn, sdn, 1.)
+}
+| `Distribution(d) => `Distribution(d) // any kind of atomic dist should already be normalized -- TODO: THIS IS ACTUALLY FALSE! E.g. pointwise product of normal * normal
+}
+};
+
+let evaluateTruncate = (et: nodeResult, xc: cutoffX, compareFunc: (float, float) => bool, sampleCount: int) => {
+let cut = (s: DistTypes.xyShape): DistTypes.xyShape => {
+let (xs, ys) = s.ys
+|> Belt.Array.zip(s.xs)
+|> E.A.filter(((x, y)) => compareFunc(x, xc))
+|> Belt.Array.unzip
+let cutShape: DistTypes.xyShape = {xs, ys};
+cutShape;
+};
+
+switch (et) {
+| `Distribution(d) => {
+let (sc, sd) = renderDistributionToXYShape(d, sampleCount);
+let scc = sc |> Distributions.Continuous.shapeMap(cut);
+let sdc = sd |> cut;
+let newIntegral = 1.; // TODO
+`RenderedShape(scc, sdc, newIntegral);
+}
+| `RenderedShape(sc, sd, i) => {
+let scc = sc |> Distributions.Continuous.shapeMap(cut);
+let sdc = sd |> cut;
+let newIntegral = 1.; // TODO
+`RenderedShape(scc, sdc, newIntegral);
+}
+}
+};
+
+let evaluateVerticalScaling = (et1: nodeResult, et2: nodeResult, sampleCount: int) => {
+let scale = (i: float, s: DistTypes.xyShape): DistTypes.xyShape => {xs: s.xs, ys: s.ys |> E.A.fmap(y => y *. i)};
+switch ((et1, et2)) {
+| (`Distribution(`Float(v)), `Distribution(d))
+| (`Distribution(d), `Distribution(`Float(v))) => {
+let (sc, sd) = renderDistributionToXYShape(d, sampleCount);
+let scc = sc |> Distributions.Continuous.shapeMap(scale(v));
+let sdc = sd |> scale(v);
+let newIntegral = v; // TODO
+`RenderedShape(scc, sdc, newIntegral);
+}
+| (`Distribution(`Float(v)), `RenderedShape(sc, sd, i))
+| (`RenderedShape(sc, sd, i), `Distribution(`Float(v))) => {
+let scc = sc |> Distributions.Continuous.shapeMap(scale(v));
+let sdc = sd |> scale(v);
+let newIntegral = v; // TODO
+`RenderedShape(scc, sdc, newIntegral);
 }
-}, ([||]: array(DistTypes.xyShape), [||]: array(DistTypes.xyShape)))
+| _ => `RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.) // TODO: give warning
+}
+}
+
+let renderNode = (et: nodeResult, sampleCount: int) => {
+switch (et) {
+| `Distribution(d) => {
+let (sc, sd) = renderDistributionToXYShape(d, sampleCount);
+`RenderedShape(sc, sd, 1.0);
+}
+| s => s
+}
+}
+
+let rec evaluateNode = (treeNode: distTree, sampleCount: int): nodeResult => {
+// returns either a new symbolic distribution
+switch (treeNode) {
+| `Distribution(d) => evaluateDistribution(d)
+| `Combination(t1, t2, op) => evaluateCombinationDistribution(evaluateNode(t1, sampleCount), evaluateNode(t2, sampleCount), op, sampleCount)
+| `PointwiseSum(t1, t2) => evaluatePointwiseSum(evaluateNode(t1, sampleCount), evaluateNode(t2, sampleCount), sampleCount)
+| `PointwiseProduct(t1, t2) => evaluatePointwiseProduct(evaluateNode(t1, sampleCount), evaluateNode(t2, sampleCount), sampleCount)
+| `VerticalScaling(t1, t2) => evaluateVerticalScaling(evaluateNode(t1, sampleCount), evaluateNode(t2, sampleCount), sampleCount)
+| `Normalize(t) => evaluateNormalize(evaluateNode(t, sampleCount), sampleCount)
+| `LeftTruncate(t, x) => evaluateTruncate(evaluateNode(t, sampleCount), x, (<=), sampleCount)
+| `RightTruncate(t, x) => evaluateTruncate(evaluateNode(t, sampleCount), x, (>=), sampleCount)
+| `Render(t) => renderNode(evaluateNode(t, sampleCount), sampleCount)
+}
 };
 
-/*
-We will assume that each dist (of t) in the multimodal has a total of one.
-We can therefore normalize the weights of the parts.
-
-However, a multimodal can consist of both discrete and continuous shapes.
-These need to be added and collected individually.
-*/
-let toShape = (dists: t, sampleCount: int) => {
-let continuousXs = findContinuousXs(dists, sampleCount);
+let toShape = (treeNode: distTree, sampleCount: int) => {
+/*let continuousXs = findContinuousXs(dists, sampleCount);
 continuousXs |> Array.fast_sort(compare);
 
 let (contShapes, distShapes) = accumulateContAndDiscShapes(dists, continuousXs, 1.0);
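A quick sanity check on the closed-form `AddOperation case above: adding Normal(mean = 1, stdev = 3) to Normal(mean = 2, stdev = 4) should yield Normal(mean = 3, stdev = 5), since the means add and the standard deviation is sqrt(3^2 + 4^2) = 5; pairs without such a rule fall through to the jsCombinationConvolve path instead.
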
@@ -469,60 +649,42 @@ module PointwiseAddDistributionsWeighted = {
 }, {xs: continuousXs, ys: Array.make(Array.length(continuousXs), 0.0)})
 |> Distributions.Continuous.make(`Linear);
 
-let combinedDiscrete = Distributions.Discrete.reduce((+.), distShapes)
+let combinedDiscrete = Distributions.Discrete.reduce((+.), distShapes)*/
 
-let shape = MixedShapeBuilder.buildSimple(~continuous=Some(combinedContinuous), ~discrete=combinedDiscrete);
+let treeShape = evaluateNode(`Render(`Normalize(treeNode)), sampleCount);
+switch (treeShape) {
+| `Distribution(_) => E.O.toExn("No shape found!", None)
+| `RenderedShape(sc, sd, _) => {
+let shape = MixedShapeBuilder.buildSimple(~continuous=Some(sc), ~discrete=sd);
+
 shape |> E.O.toExt("");
+}
+}
 };
 
-let rec toString = (dists: t): string => {
-let distString =
-dists
-|> E.A.fmap(((d, _)) =>
-switch (d) {
-| `Simple(d) => GenericSimple.toString(d)
-| `PointwiseCombination(ts: t) => ts |> toString
-}
-)
-|> Js.Array.joinWith(",");
+let rec toString = (treeNode: distTree): string => {
+let stringFromOp = op => switch (op) {
+| `AddOperation => " + "
+| `SubtractOperation => " - "
+| `MultiplyOperation => " * "
+| `DivideOperation => " / "
+| `ExponentiateOperation => "^"
+};
 
-// mm(normal(0,1), normal(1,2)) => "multimodal(normal(0,1), normal(1,2), )
-let weights =
-dists
-|> E.A.fmap(((_, w)) =>
-Js.Float.toPrecisionWithPrecision(w, ~digits=2)
-)
-|> Js.Array.joinWith(",");
-{j|multimodal($distString, [$weights])|j};
+switch (treeNode) {
+| `Distribution(d) => GenericSimple.toString(d)
+| `Combination(t1, t2, op) => toString(t1) ++ stringFromOp(op) ++ toString(t2)
+| `PointwiseSum(t1, t2) => toString(t1) ++ " .+ " ++ toString(t2)
+| `PointwiseProduct(t1, t2) => toString(t1) ++ " .* " ++ toString(t2)
+| `VerticalScaling(t1, t2) => toString(t1) ++ " @ " ++ toString(t2)
+| `Normalize(t) => "normalize(" ++ toString(t) ++ ")"
+| `LeftTruncate(t, x) => "leftTruncate(" ++ toString(t) ++ ", " ++ string_of_float(x) ++ ")"
+| `RightTruncate(t, x) => "rightTruncate(" ++ toString(t) ++ ", " ++ string_of_float(x) ++ ")"
+}
 };
 };
 
-// assume that recursive pointwiseNormalizedDistSums are the only type of operation there is.
-// in the original, it was a list of (dist, weight) tuples. Now, it's a tree of (dist, weight) tuples, just that every
-// dist can be either a GenericSimple or another PointwiseAdd.
+let toString = (treeNode: distTree) => DistTree.toString(treeNode)
 
-/*let toString = (r: bigDistTree) => {
-switch (r) {
-| WeightedDistLeaf((w, d)) => GenericWeighted.toString(w) // "normal "
-| PointwiseNormalizedDistSum(childTrees) => childTrees |> E.A.fmap(toString) |> Js.Array.joinWith("")
-}
-}*/
-
-let toString = (r: bigDist) =>
-// we need to recursively create the string representation of the tree.
-r
-|> (
-fun
-| `Simple(d) => GenericSimple.toString(d)
-| `PointwiseCombination(d) =>
-PointwiseAddDistributionsWeighted.toString(d)
-);
-
-let toShape = n =>
-fun
-| `Simple(d) => GenericSimple.toShape(~xSelection=`ByWeight, d, n)
-| `PointwiseCombination(d) =>
-PointwiseAddDistributionsWeighted.toShape(d, n);
+let toShape = (sampleCount: int, treeNode: distTree) =>
+DistTree.toShape(treeNode, sampleCount) //~xSelection=`ByWeight,
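A usage sketch of the refactored entry points, assuming they are exposed from SymbolicDist as the ShapeRenderer hunk suggests (the variable names and the printed string are illustrative only):

    let tree: SymbolicDist.distTree =
      `Combination(
        `Distribution(`Normal({mean: 0., stdev: 1.})),
        `Distribution(`Lognormal({mu: 1., sigma: 0.5})),
        `AddOperation,
      );

    Js.log(SymbolicDist.toString(tree)); /* human-readable form of the tree */
    let shape = SymbolicDist.toShape(1000, tree); /* evaluates `Render(`Normalize(tree)) with 1000 sample points */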