First pass at nested multimodals, still needs lots of cleanup

This commit is contained in:
Sebastian Kosch 2020-06-09 21:28:03 -07:00
parent a9d52e2c5c
commit eb0ffdc6c3
2 changed files with 239 additions and 88 deletions

View File

@ -148,6 +148,10 @@ module MathAdtToDistDst = {
Ok(`Simple(`Triangular({low, medium, high})))
| _ => Error("Wrong number of variables in triangle distribution");
/*let add: array(arg) => result(SymbolicDist.bigDist, string) =
fun
| */
let multiModal =
(
args: array(result(SymbolicDist.bigDist, string)),
@ -158,22 +162,25 @@ module MathAdtToDistDst = {
args
|> E.A.fmap(
fun
| Ok(`Simple(n)) => Ok(n)
| Ok(`Simple(d)) => Ok(`Simple(d))
| Ok(`PointwiseCombination(dists)) => Ok(`PointwiseCombination(dists))
| Error(e) => Error(e)
| Ok(k) => Error(SymbolicDist.toString(k)),
| _ => Error("Unexpected dist")
);
let firstWithError = dists |> Belt.Array.getBy(_, Belt.Result.isError);
let withoutErrors = dists |> E.A.fmap(E.R.toOption) |> E.A.O.concatSomes;
switch (firstWithError) {
| Some(Error(e)) => Error(e)
| None when withoutErrors |> E.A.length == 0 =>
Error("Multimodals need at least one input")
| _ =>
withoutErrors
|> E.A.fmapi((index, item) =>
(item, weights |> E.A.get(_, index) |> E.O.default(1.0))
)
|> (r => Ok(`PointwiseCombination(r)))
| Some(Error(e)) => Error(e)
| None when withoutErrors |> E.A.length == 0 =>
Error("Multimodals need at least one input")
| _ =>
withoutErrors
|> E.A.fmapi((index, item) =>
(item, weights |> E.A.get(_, index) |> E.O.default(1.0))
)
|> (r => Ok(`PointwiseCombination(r)))
};
};
@ -186,12 +193,12 @@ module MathAdtToDistDst = {
)
|> E.A.O.concatSomes
let outputs = Samples.T.fromSamples(samples);
let pdf = outputs.shape |> E.O.bind(_,Distributions.Shape.T.toContinuous)
let pdf = outputs.shape |> E.O.bind(_,Distributions.Shape.T.toContinuous);
let shape = pdf |> E.O.fmap(pdf => {
let _pdf = Distributions.Continuous.T.scaleToIntegralSum(~cache=None, ~intendedSum=1.0, pdf);
let cdf = Distributions.Continuous.T.integral(~cache=None, _pdf);
SymbolicDist.ContinuousShape.make(_pdf, cdf)
})
});
switch(shape){
| Some(s) => Ok(`Simple(`ContinuousShape(s)))
| None => Error("Rendering did not work")
@ -238,6 +245,7 @@ module MathAdtToDistDst = {
let dists = possibleDists |> E.A.fmap(functionParser);
multiModal(dists, weights);
}
//| Fn({name: "add", args}) => add(args)
| Fn({name}) => Error(name ++ ": function not supported")
| _ => {
Error("This type not currently supported");
@ -255,19 +263,32 @@ module MathAdtToDistDst = {
| Object(_) => Error("Object not valid as top level")
);
let run = (r): result(SymbolicDist.bigDist, string) =>
r |> MathAdtCleaner.run |> topLevel;
let run = (r): result(SymbolicDist.bigDist, string) => {
let o = r |> MathAdtCleaner.run |> topLevel;
Js.log2("parser output", o);
o
};
};
let fromString = str => {
/* We feed the user-typed string into Mathjs.parseMath,
which returns a JSON with (hopefully) a single-element array.
This array element is the top-level node of a nested-object tree
representing the functions/arguments/values/etc. in the string.
The function MathJsonToMathJsAdt then recursively unpacks this JSON into a typed data structure we can use.
Inside of this function, MathAdtToDistDst is called whenever a distribution function is encountered.
*/
let mathJsToJson = Mathjs.parseMath(str);
let mathJsParse =
E.R.bind(mathJsToJson, r =>
E.R.bind(mathJsToJson, r => {
Js.log2("parsed", r);
switch (MathJsonToMathJsAdt.run(r)) {
| Some(r) => Ok(r)
| None => Error("MathJsParse Error")
}
);
});
let value = E.R.bind(mathJsParse, MathAdtToDistDst.run);
value;
};
};

View File

@ -47,12 +47,44 @@ type dist = [
| `Cauchy(cauchy)
| `Triangular(triangular)
| `ContinuousShape(continuousShape)
| `Float(float)
| `Float(float) // Dirac delta at x. Practically useful only in the context of multimodals.
];
type pointwiseAdd = array((dist, float));
/* Build a tree.
type bigDist = [ | `Simple(dist) | `PointwiseCombination(pointwiseAdd)];
Multiple operations possible:
- PointwiseSum(Scalar, Scalar)
- PointwiseSum(WeightedDist, WeightedDist)
- PointwiseProduct(Scalar, Scalar)
- PointwiseProduct(Scalar, WeightedDist)
- PointwiseProduct(WeightedDist, WeightedDist)
- IndependentVariableSum(WeightedDist, WeightedDist) [i.e., convolution]
- IndependentVariableProduct(WeightedDist, WeightedDist) [i.e. distribution product]
*/
type weightedDist = (float, dist);
type bigDistTree =
/* | DistLeaf(dist) */
/* | ScalarLeaf(float) */
/* | PointwiseScalarDistProduct(DistLeaf(d), ScalarLeaf(s)) */
| WeightedDistLeaf(weightedDist)
| PointwiseNormalizedDistSum(array(bigDistTree));
let rec treeIntegral = item => {
switch (item) {
| WeightedDistLeaf((w, d)) => w
| PointwiseNormalizedDistSum(childTrees) =>
childTrees |> E.A.fmap(treeIntegral) |> E.A.Floats.sum
};
};
/* bigDist can either be a single distribution, or a
PointwiseCombination, i.e. an array of (dist, weight) tuples */
type bigDist = [ | `Simple(dist) | `PointwiseCombination(pointwiseAdd)]
and pointwiseAdd = array((bigDist, float));
module ContinuousShape = {
type t = continuousShape;
@ -255,29 +287,27 @@ module GenericSimple = {
| `Uniform({high}) => high
| `Float(n) => n;
/* This function returns a list of x's at which to evaluate the overall distribution (for rendering).
This function is called separately for each individual distribution.
This function is called separately for each individual distribution.
When called with xSelection=`Linear, this function will return (sampleCount) x's, evenly
distributed between the min and max of the distribution (whatever those are defined to be above).
When called with xSelection=`Linear, this function will return (sampleCount) x's, evenly
distributed between the min and max of the distribution (whatever those are defined to be above).
When called with xSelection=`ByWeight, this function will distribute the x's such as to
match the cumulative shape of the distribution. This is slower but may give better results.
*/
When called with xSelection=`ByWeight, this function will distribute the x's such as to
match the cumulative shape of the distribution. This is slower but may give better results.
*/
let interpolateXs =
(~xSelection: [ | `Linear | `ByWeight]=`Linear, dist: dist, sampleCount) => {
switch (xSelection, dist) {
| (`Linear, _) => E.A.Floats.range(min(dist), max(dist), sampleCount)
| (`ByWeight, `Uniform(n)) =>
// In `ByWeight mode, uniform distributions get special treatment because we need two x's
// on either side for proper rendering (just left and right of the discontinuities).
let dx = 0.00001 *. (n.high -. n.low);
[|n.low -. dx, n.low +. dx, n.high -. dx, n.high +. dx|]
| (`ByWeight, _) =>
let ys = E.A.Floats.range(minCdfValue, maxCdfValue, sampleCount)
ys |> E.A.fmap(y => inv(y, dist))
[|n.low -. dx, n.low +. dx, n.high -. dx, n.high +. dx|];
| (`ByWeight, _) =>
let ys = E.A.Floats.range(minCdfValue, maxCdfValue, sampleCount);
ys |> E.A.fmap(y => inv(y, dist));
};
};
@ -299,90 +329,190 @@ module GenericSimple = {
module PointwiseAddDistributionsWeighted = {
type t = pointwiseAdd;
let normalizeWeights = (dists: t) => {
let total = dists |> E.A.fmap(snd) |> E.A.Floats.sum;
dists |> E.A.fmap(((a, b)) => (a, b /. total));
let normalizeWeights = (weightedDists: t) => {
let total = weightedDists |> E.A.fmap(snd) |> E.A.Floats.sum;
weightedDists |> E.A.fmap(((d, w)) => (d, w /. total));
};
let pdf = (x: float, dists: t) =>
dists
|> E.A.fmap(((e, w)) => GenericSimple.pdf(x, e) *. w)
let rec pdf = (x: float, weightedNormalizedDists: t) =>
weightedNormalizedDists
|> E.A.fmap(((d, w)) => {
switch (d) {
| `PointwiseCombination(ts) => pdf(x, ts) *. w
| `Simple(d) => GenericSimple.pdf(x, d) *. w
}
})
|> E.A.Floats.sum;
let min = (dists: t) =>
dists |> E.A.fmap(d => d |> fst |> GenericSimple.min) |> E.A.min;
// TODO: perhaps rename into minCdfX?
// TODO: how should nonexistent min values be handled? They should never happen
let rec min = (dists: t) =>
dists
|> E.A.fmap(((d, w)) => {
switch (d) {
| `PointwiseCombination(ts) => E.O.toExn("Dist has no min", min(ts))
| `Simple(d) => GenericSimple.min(d)
}
})
|> E.A.min;
let max = (dists: t) =>
dists |> E.A.fmap(d => d |> fst |> GenericSimple.max) |> E.A.max;
// TODO: perhaps rename into minCdfX?
let rec max = (dists: t) =>
dists
|> E.A.fmap(((d, w)) => {
switch (d) {
| `PointwiseCombination(ts) => E.O.toExn("Dist has no max", max(ts))
| `Simple(d) => GenericSimple.max(d)
}
})
|> E.A.max;
let discreteShape = (dists: t, sampleCount: int) => {
/*let rec discreteShape = (dists: t, sampleCount: int) => {
let discrete =
dists
|> E.A.fmap(((r, e)) =>
r
|> (
fun
| `Float(r) => Some((r, e))
| _ => None
)
)
|> E.A.fmap(((x, w)) => {
switch (d) {
| `Float(d) => Some((d, w)) // if the distribution is just a number, then the weight is considered the y
| _ => None
}
})
|> E.A.O.concatSomes
|> E.A.fmap(((x, y)) =>
({xs: [|x|], ys: [|y|]}: DistTypes.xyShape)
)
// take an array of xyShapes and combine them together
//* r
|> (
fun
| `Float(r) => Some((r, e))
| _ => None
)
)*/
|> Distributions.Discrete.reduce((+.));
discrete;
};*/
let rec findContinuousXs = (dists: t, sampleCount: int) => {
// we need to go through the tree of distributions and, for the continuous ones, find the xs at which
// later, all distributions will get evaluated.
// we want to accumulate a set of xs.
let xs: array(float) =
dists
|> E.A.fold_left((accXs, (d, w)) => {
switch (d) {
| `Simple(t) when (GenericSimple.contType(t) == `Discrete) => accXs
| `Simple(d) => {
let xs = GenericSimple.interpolateXs(~xSelection=`ByWeight, d, sampleCount)
E.A.append(accXs, xs)
}
| `PointwiseCombination(ts) => {
let xs = findContinuousXs(ts, sampleCount);
E.A.append(accXs, xs)
}
}
}, [||]);
xs
};
let continuousShape = (dists: t, sampleCount: int) => {
let xs =
dists
|> E.A.fmap(r =>
r
|> fst
|> GenericSimple.interpolateXs(
~xSelection=`ByWeight,
_,
sampleCount / (dists |> E.A.length),
)
)
|> E.A.concatMany;
xs |> Array.fast_sort(compare);
let ys = xs |> E.A.fmap(pdf(_, dists));
XYShape.T.fromArrays(xs, ys) |> Distributions.Continuous.make(`Linear, _);
};
let toShape = (dists: t, sampleCount: int) => {
/* Accumulate (accContShapes, accDistShapes), each of which is an array of {xs, ys} shapes. */
let rec accumulateContAndDiscShapes = (dists: t, continuousXs: array(float), currentWeight) => {
let normalized = normalizeWeights(dists);
let continuous =
normalized
|> E.A.filter(((r, _)) => GenericSimple.contType(r) == `Continuous)
|> continuousShape(_, sampleCount);
let discrete =
normalized
|> E.A.filter(((r, _)) => GenericSimple.contType(r) == `Discrete)
|> discreteShape(_, sampleCount);
let shape =
MixedShapeBuilder.buildSimple(~continuous=Some(continuous), ~discrete);
normalized
|> E.A.fold_left(((accContShapes: array(DistTypes.xyShape), accDiscShapes: array(DistTypes.xyShape)), (d, w)) => {
switch (d) {
| `Simple(`Float(x)) => {
let ds: DistTypes.xyShape = {xs: [|x|], ys: [|w *. currentWeight|]};
(accContShapes, E.A.append(accDiscShapes, [|ds|]))
}
| `Simple(d) when (GenericSimple.contType(d) == `Continuous) => {
let ys = continuousXs |> E.A.fmap(x => GenericSimple.pdf(x, d) *. w *. currentWeight);
let cs = XYShape.T.fromArrays(continuousXs, ys);
(E.A.append(accContShapes, [|cs|]), accDiscShapes)
}
| `Simple(d) => (accContShapes, accDiscShapes) // default -- should never happen
| `PointwiseCombination(ts) => {
let (cs, ds) = accumulateContAndDiscShapes(ts, continuousXs, w *. currentWeight);
(E.A.append(accContShapes, cs), E.A.append(accDiscShapes, ds))
}
}
}, ([||]: array(DistTypes.xyShape), [||]: array(DistTypes.xyShape)))
};
/*
We will assume that each dist (of t) in the multimodal has a total of one.
We can therefore normalize the weights of the parts.
However, a multimodal can consist of both discrete and continuous shapes.
These need to be added and collected individually.
*/
let toShape = (dists: t, sampleCount: int) => {
let continuousXs = findContinuousXs(dists, sampleCount);
continuousXs |> Array.fast_sort(compare);
let (contShapes, distShapes) = accumulateContAndDiscShapes(dists, continuousXs, 1.0);
let combinedContinuous = contShapes
|> E.A.fold_left((shapeAcc: DistTypes.xyShape, shape: DistTypes.xyShape) => {
let ys = E.A.fmapi((i, y) => y +. shape.ys[i], shapeAcc.ys);
{xs: continuousXs, ys: ys}
}, {xs: continuousXs, ys: Array.make(Array.length(continuousXs), 0.0)})
|> Distributions.Continuous.make(`Linear);
let combinedDiscrete = Distributions.Discrete.reduce((+.), distShapes)
let shape = MixedShapeBuilder.buildSimple(~continuous=Some(combinedContinuous), ~discrete=combinedDiscrete);
shape |> E.O.toExt("");
};
let toString = (dists: t) => {
let rec toString = (dists: t): string => {
let distString =
dists
|> E.A.fmap(d => GenericSimple.toString(fst(d)))
|> Js.Array.joinWith(",");
let weights =
dists
|> E.A.fmap(d =>
snd(d) |> Js.Float.toPrecisionWithPrecision(~digits=2)
|> E.A.fmap(((d, _)) =>
switch (d) {
| `Simple(d) => GenericSimple.toString(d)
| `PointwiseCombination(ts: t) => ts |> toString
}
)
|> Js.Array.joinWith(",");
// mm(normal(0,1), normal(1,2)) => "multimodal(normal(0,1), normal(1,2), )
let weights =
dists
|> E.A.fmap(((_, w)) =>
Js.Float.toPrecisionWithPrecision(w, ~digits=2)
)
|> Js.Array.joinWith(",");
{j|multimodal($distString, [$weights])|j};
};
};
// assume that recursive pointwiseNormalizedDistSums are the only type of operation there is.
// in the original, it was a list of (dist, weight) tuples. Now, it's a tree of (dist, weight) tuples, just that every
// dist can be either a GenericSimple or another PointwiseAdd.
/*let toString = (r: bigDistTree) => {
switch (r) {
| WeightedDistLeaf((w, d)) => GenericWeighted.toString(w) // "normal "
| PointwiseNormalizedDistSum(childTrees) => childTrees |> E.A.fmap(toString) |> Js.Array.joinWith("")
}
}*/
let toString = (r: bigDist) =>
// we need to recursively create the string representation of the tree.
r
|> (
fun