First pass at nested multimodals, still needs lots of cleanup

This commit is contained in:
Sebastian Kosch 2020-06-09 21:28:03 -07:00
parent a9d52e2c5c
commit eb0ffdc6c3
2 changed files with 239 additions and 88 deletions

View File

@ -148,6 +148,10 @@ module MathAdtToDistDst = {
Ok(`Simple(`Triangular({low, medium, high}))) Ok(`Simple(`Triangular({low, medium, high})))
| _ => Error("Wrong number of variables in triangle distribution"); | _ => Error("Wrong number of variables in triangle distribution");
/*let add: array(arg) => result(SymbolicDist.bigDist, string) =
fun
| */
let multiModal = let multiModal =
( (
args: array(result(SymbolicDist.bigDist, string)), args: array(result(SymbolicDist.bigDist, string)),
@ -158,22 +162,25 @@ module MathAdtToDistDst = {
args args
|> E.A.fmap( |> E.A.fmap(
fun fun
| Ok(`Simple(n)) => Ok(n) | Ok(`Simple(d)) => Ok(`Simple(d))
| Ok(`PointwiseCombination(dists)) => Ok(`PointwiseCombination(dists))
| Error(e) => Error(e) | Error(e) => Error(e)
| Ok(k) => Error(SymbolicDist.toString(k)), | _ => Error("Unexpected dist")
); );
let firstWithError = dists |> Belt.Array.getBy(_, Belt.Result.isError); let firstWithError = dists |> Belt.Array.getBy(_, Belt.Result.isError);
let withoutErrors = dists |> E.A.fmap(E.R.toOption) |> E.A.O.concatSomes; let withoutErrors = dists |> E.A.fmap(E.R.toOption) |> E.A.O.concatSomes;
switch (firstWithError) { switch (firstWithError) {
| Some(Error(e)) => Error(e) | Some(Error(e)) => Error(e)
| None when withoutErrors |> E.A.length == 0 => | None when withoutErrors |> E.A.length == 0 =>
Error("Multimodals need at least one input") Error("Multimodals need at least one input")
| _ => | _ =>
withoutErrors withoutErrors
|> E.A.fmapi((index, item) => |> E.A.fmapi((index, item) =>
(item, weights |> E.A.get(_, index) |> E.O.default(1.0)) (item, weights |> E.A.get(_, index) |> E.O.default(1.0))
) )
|> (r => Ok(`PointwiseCombination(r))) |> (r => Ok(`PointwiseCombination(r)))
}; };
}; };
@ -186,12 +193,12 @@ module MathAdtToDistDst = {
) )
|> E.A.O.concatSomes |> E.A.O.concatSomes
let outputs = Samples.T.fromSamples(samples); let outputs = Samples.T.fromSamples(samples);
let pdf = outputs.shape |> E.O.bind(_,Distributions.Shape.T.toContinuous) let pdf = outputs.shape |> E.O.bind(_,Distributions.Shape.T.toContinuous);
let shape = pdf |> E.O.fmap(pdf => { let shape = pdf |> E.O.fmap(pdf => {
let _pdf = Distributions.Continuous.T.scaleToIntegralSum(~cache=None, ~intendedSum=1.0, pdf); let _pdf = Distributions.Continuous.T.scaleToIntegralSum(~cache=None, ~intendedSum=1.0, pdf);
let cdf = Distributions.Continuous.T.integral(~cache=None, _pdf); let cdf = Distributions.Continuous.T.integral(~cache=None, _pdf);
SymbolicDist.ContinuousShape.make(_pdf, cdf) SymbolicDist.ContinuousShape.make(_pdf, cdf)
}) });
switch(shape){ switch(shape){
| Some(s) => Ok(`Simple(`ContinuousShape(s))) | Some(s) => Ok(`Simple(`ContinuousShape(s)))
| None => Error("Rendering did not work") | None => Error("Rendering did not work")
@ -238,6 +245,7 @@ module MathAdtToDistDst = {
let dists = possibleDists |> E.A.fmap(functionParser); let dists = possibleDists |> E.A.fmap(functionParser);
multiModal(dists, weights); multiModal(dists, weights);
} }
//| Fn({name: "add", args}) => add(args)
| Fn({name}) => Error(name ++ ": function not supported") | Fn({name}) => Error(name ++ ": function not supported")
| _ => { | _ => {
Error("This type not currently supported"); Error("This type not currently supported");
@ -255,19 +263,32 @@ module MathAdtToDistDst = {
| Object(_) => Error("Object not valid as top level") | Object(_) => Error("Object not valid as top level")
); );
let run = (r): result(SymbolicDist.bigDist, string) => let run = (r): result(SymbolicDist.bigDist, string) => {
r |> MathAdtCleaner.run |> topLevel; let o = r |> MathAdtCleaner.run |> topLevel;
Js.log2("parser output", o);
o
};
}; };
let fromString = str => { let fromString = str => {
/* We feed the user-typed string into Mathjs.parseMath,
which returns a JSON with (hopefully) a single-element array.
This array element is the top-level node of a nested-object tree
representing the functions/arguments/values/etc. in the string.
The function MathJsonToMathJsAdt then recursively unpacks this JSON into a typed data structure we can use.
Inside of this function, MathAdtToDistDst is called whenever a distribution function is encountered.
*/
let mathJsToJson = Mathjs.parseMath(str); let mathJsToJson = Mathjs.parseMath(str);
let mathJsParse = let mathJsParse =
E.R.bind(mathJsToJson, r => E.R.bind(mathJsToJson, r => {
Js.log2("parsed", r);
switch (MathJsonToMathJsAdt.run(r)) { switch (MathJsonToMathJsAdt.run(r)) {
| Some(r) => Ok(r) | Some(r) => Ok(r)
| None => Error("MathJsParse Error") | None => Error("MathJsParse Error")
} }
); });
let value = E.R.bind(mathJsParse, MathAdtToDistDst.run); let value = E.R.bind(mathJsParse, MathAdtToDistDst.run);
value; value;
}; };

View File

@ -47,12 +47,44 @@ type dist = [
| `Cauchy(cauchy) | `Cauchy(cauchy)
| `Triangular(triangular) | `Triangular(triangular)
| `ContinuousShape(continuousShape) | `ContinuousShape(continuousShape)
| `Float(float) | `Float(float) // Dirac delta at x. Practically useful only in the context of multimodals.
]; ];
type pointwiseAdd = array((dist, float)); /* Build a tree.
type bigDist = [ | `Simple(dist) | `PointwiseCombination(pointwiseAdd)]; Multiple operations possible:
- PointwiseSum(Scalar, Scalar)
- PointwiseSum(WeightedDist, WeightedDist)
- PointwiseProduct(Scalar, Scalar)
- PointwiseProduct(Scalar, WeightedDist)
- PointwiseProduct(WeightedDist, WeightedDist)
- IndependentVariableSum(WeightedDist, WeightedDist) [i.e., convolution]
- IndependentVariableProduct(WeightedDist, WeightedDist) [i.e. distribution product]
*/
type weightedDist = (float, dist);
type bigDistTree =
/* | DistLeaf(dist) */
/* | ScalarLeaf(float) */
/* | PointwiseScalarDistProduct(DistLeaf(d), ScalarLeaf(s)) */
| WeightedDistLeaf(weightedDist)
| PointwiseNormalizedDistSum(array(bigDistTree));
/* Total weight carried by a bigDistTree.
   - A WeightedDistLeaf contributes its own weight; the dist itself is not inspected.
   - A PointwiseNormalizedDistSum contributes the sum of treeIntegral over all of
     its child trees (child weights are added as-is, without rescaling). */
let rec treeIntegral =
  fun
  | WeightedDistLeaf((weight, _)) => weight
  | PointwiseNormalizedDistSum(children) => {
      let childIntegrals = children |> E.A.fmap(treeIntegral);
      childIntegrals |> E.A.Floats.sum;
    };
/* bigDist can either be a single distribution, or a
PointwiseCombination, i.e. an array of (dist, weight) tuples */
type bigDist = [ | `Simple(dist) | `PointwiseCombination(pointwiseAdd)]
and pointwiseAdd = array((bigDist, float));
module ContinuousShape = { module ContinuousShape = {
type t = continuousShape; type t = continuousShape;
@ -255,29 +287,27 @@ module GenericSimple = {
| `Uniform({high}) => high | `Uniform({high}) => high
| `Float(n) => n; | `Float(n) => n;
/* This function returns a list of x's at which to evaluate the overall distribution (for rendering). /* This function returns a list of x's at which to evaluate the overall distribution (for rendering).
This function is called separately for each individual distribution. This function is called separately for each individual distribution.
When called with xSelection=`Linear, this function will return (sampleCount) x's, evenly When called with xSelection=`Linear, this function will return (sampleCount) x's, evenly
distributed between the min and max of the distribution (whatever those are defined to be above). distributed between the min and max of the distribution (whatever those are defined to be above).
When called with xSelection=`ByWeight, this function will distribute the x's such as to When called with xSelection=`ByWeight, this function will distribute the x's such as to
match the cumulative shape of the distribution. This is slower but may give better results. match the cumulative shape of the distribution. This is slower but may give better results.
*/ */
let interpolateXs = let interpolateXs =
(~xSelection: [ | `Linear | `ByWeight]=`Linear, dist: dist, sampleCount) => { (~xSelection: [ | `Linear | `ByWeight]=`Linear, dist: dist, sampleCount) => {
switch (xSelection, dist) { switch (xSelection, dist) {
| (`Linear, _) => E.A.Floats.range(min(dist), max(dist), sampleCount) | (`Linear, _) => E.A.Floats.range(min(dist), max(dist), sampleCount)
| (`ByWeight, `Uniform(n)) => | (`ByWeight, `Uniform(n)) =>
// In `ByWeight mode, uniform distributions get special treatment because we need two x's // In `ByWeight mode, uniform distributions get special treatment because we need two x's
// on either side for proper rendering (just left and right of the discontinuities). // on either side for proper rendering (just left and right of the discontinuities).
let dx = 0.00001 *. (n.high -. n.low); let dx = 0.00001 *. (n.high -. n.low);
[|n.low -. dx, n.low +. dx, n.high -. dx, n.high +. dx|] [|n.low -. dx, n.low +. dx, n.high -. dx, n.high +. dx|];
| (`ByWeight, _) => | (`ByWeight, _) =>
let ys = E.A.Floats.range(minCdfValue, maxCdfValue, sampleCount) let ys = E.A.Floats.range(minCdfValue, maxCdfValue, sampleCount);
ys |> E.A.fmap(y => inv(y, dist)) ys |> E.A.fmap(y => inv(y, dist));
}; };
}; };
@ -299,90 +329,190 @@ module GenericSimple = {
module PointwiseAddDistributionsWeighted = { module PointwiseAddDistributionsWeighted = {
type t = pointwiseAdd; type t = pointwiseAdd;
let normalizeWeights = (dists: t) => { let normalizeWeights = (weightedDists: t) => {
let total = dists |> E.A.fmap(snd) |> E.A.Floats.sum; let total = weightedDists |> E.A.fmap(snd) |> E.A.Floats.sum;
dists |> E.A.fmap(((a, b)) => (a, b /. total)); weightedDists |> E.A.fmap(((d, w)) => (d, w /. total));
}; };
let pdf = (x: float, dists: t) => let rec pdf = (x: float, weightedNormalizedDists: t) =>
dists weightedNormalizedDists
|> E.A.fmap(((e, w)) => GenericSimple.pdf(x, e) *. w) |> E.A.fmap(((d, w)) => {
switch (d) {
| `PointwiseCombination(ts) => pdf(x, ts) *. w
| `Simple(d) => GenericSimple.pdf(x, d) *. w
}
})
|> E.A.Floats.sum; |> E.A.Floats.sum;
let min = (dists: t) => // TODO: perhaps rename into minCdfX?
dists |> E.A.fmap(d => d |> fst |> GenericSimple.min) |> E.A.min; // TODO: how should nonexistent min values be handled? They should never happen
/* Lower bound of a (possibly nested) weighted mixture.
   Every component is mapped to its own minimum -- GenericSimple.min for a
   `Simple dist, a recursive call for a nested `PointwiseCombination -- and the
   resulting array is handed to E.A.min. Weights are ignored.
   A nested combination whose recursive result is empty is unwrapped with
   E.O.toExn("Dist has no min", ...); presumably this raises -- confirm against E.O. */
let rec min = (dists: t) => {
  let componentMin = ((component, _weight)) =>
    switch (component) {
    | `Simple(simple) => GenericSimple.min(simple)
    | `PointwiseCombination(nested) =>
      E.O.toExn("Dist has no min", min(nested))
    };
  dists |> E.A.fmap(componentMin) |> E.A.min;
};
let max = (dists: t) => // TODO: perhaps rename into maxCdfX?
dists |> E.A.fmap(d => d |> fst |> GenericSimple.max) |> E.A.max; let rec max = (dists: t) =>
dists
|> E.A.fmap(((d, w)) => {
switch (d) {
| `PointwiseCombination(ts) => E.O.toExn("Dist has no max", max(ts))
| `Simple(d) => GenericSimple.max(d)
}
})
|> E.A.max;
let discreteShape = (dists: t, sampleCount: int) => {
/*let rec discreteShape = (dists: t, sampleCount: int) => {
let discrete = let discrete =
dists dists
|> E.A.fmap(((r, e)) => |> E.A.fmap(((x, w)) => {
r switch (d) {
|> ( | `Float(d) => Some((d, w)) // if the distribution is just a number, then the weight is considered the y
fun | _ => None
| `Float(r) => Some((r, e)) }
| _ => None })
)
)
|> E.A.O.concatSomes |> E.A.O.concatSomes
|> E.A.fmap(((x, y)) => |> E.A.fmap(((x, y)) =>
({xs: [|x|], ys: [|y|]}: DistTypes.xyShape) ({xs: [|x|], ys: [|y|]}: DistTypes.xyShape)
) )
// take an array of xyShapes and combine them together
//* r
|> (
fun
| `Float(r) => Some((r, e))
| _ => None
)
)*/
|> Distributions.Discrete.reduce((+.)); |> Distributions.Discrete.reduce((+.));
discrete; discrete;
};*/
/* Walks the tree of weighted distributions and accumulates the xs at which,
   later, all continuous distributions will be evaluated.
   - `Simple dists whose GenericSimple.contType is `Discrete contribute no xs.
   - Every other `Simple dist contributes the xs returned by
     GenericSimple.interpolateXs(~xSelection=`ByWeight, ...), called with the
     full sampleCount (it is not divided among the components).
   - A nested `PointwiseCombination contributes the xs found by recursing into it.
   Weights are ignored here. The xs are appended in traversal order; this
   function does not sort or de-duplicate them. */
let rec findContinuousXs = (dists: t, sampleCount: int) => {
  let xs: array(float) =
    dists
    |> E.A.fold_left(
         (accXs, (d, _weight)) =>
           switch (d) {
           | `Simple(t) when GenericSimple.contType(t) == `Discrete => accXs
           | `Simple(simple) =>
             E.A.append(
               accXs,
               GenericSimple.interpolateXs(
                 ~xSelection=`ByWeight,
                 simple,
                 sampleCount,
               ),
             )
           | `PointwiseCombination(ts) =>
             E.A.append(accXs, findContinuousXs(ts, sampleCount))
           },
         [||],
       );
  xs;
};
let continuousShape = (dists: t, sampleCount: int) => { /* Accumulate (accContShapes, accDiscShapes), each of which is an array of {xs, ys} shapes. */
let xs = let rec accumulateContAndDiscShapes = (dists: t, continuousXs: array(float), currentWeight) => {
dists
|> E.A.fmap(r =>
r
|> fst
|> GenericSimple.interpolateXs(
~xSelection=`ByWeight,
_,
sampleCount / (dists |> E.A.length),
)
)
|> E.A.concatMany;
xs |> Array.fast_sort(compare);
let ys = xs |> E.A.fmap(pdf(_, dists));
XYShape.T.fromArrays(xs, ys) |> Distributions.Continuous.make(`Linear, _);
};
let toShape = (dists: t, sampleCount: int) => {
let normalized = normalizeWeights(dists); let normalized = normalizeWeights(dists);
let continuous =
normalized normalized
|> E.A.filter(((r, _)) => GenericSimple.contType(r) == `Continuous) |> E.A.fold_left(((accContShapes: array(DistTypes.xyShape), accDiscShapes: array(DistTypes.xyShape)), (d, w)) => {
|> continuousShape(_, sampleCount); switch (d) {
let discrete =
normalized | `Simple(`Float(x)) => {
|> E.A.filter(((r, _)) => GenericSimple.contType(r) == `Discrete) let ds: DistTypes.xyShape = {xs: [|x|], ys: [|w *. currentWeight|]};
|> discreteShape(_, sampleCount); (accContShapes, E.A.append(accDiscShapes, [|ds|]))
let shape = }
MixedShapeBuilder.buildSimple(~continuous=Some(continuous), ~discrete);
| `Simple(d) when (GenericSimple.contType(d) == `Continuous) => {
let ys = continuousXs |> E.A.fmap(x => GenericSimple.pdf(x, d) *. w *. currentWeight);
let cs = XYShape.T.fromArrays(continuousXs, ys);
(E.A.append(accContShapes, [|cs|]), accDiscShapes)
}
| `Simple(d) => (accContShapes, accDiscShapes) // default -- should never happen
| `PointwiseCombination(ts) => {
let (cs, ds) = accumulateContAndDiscShapes(ts, continuousXs, w *. currentWeight);
(E.A.append(accContShapes, cs), E.A.append(accDiscShapes, ds))
}
}
}, ([||]: array(DistTypes.xyShape), [||]: array(DistTypes.xyShape)))
};
/*
We will assume that each dist (of t) in the multimodal has a total probability mass of one.
We can therefore normalize the weights of the parts.
However, a multimodal can consist of both discrete and continuous shapes.
These need to be added and collected individually.
*/
let toShape = (dists: t, sampleCount: int) => {
let continuousXs = findContinuousXs(dists, sampleCount);
continuousXs |> Array.fast_sort(compare);
let (contShapes, distShapes) = accumulateContAndDiscShapes(dists, continuousXs, 1.0);
let combinedContinuous = contShapes
|> E.A.fold_left((shapeAcc: DistTypes.xyShape, shape: DistTypes.xyShape) => {
let ys = E.A.fmapi((i, y) => y +. shape.ys[i], shapeAcc.ys);
{xs: continuousXs, ys: ys}
}, {xs: continuousXs, ys: Array.make(Array.length(continuousXs), 0.0)})
|> Distributions.Continuous.make(`Linear);
let combinedDiscrete = Distributions.Discrete.reduce((+.), distShapes)
let shape = MixedShapeBuilder.buildSimple(~continuous=Some(combinedContinuous), ~discrete=combinedDiscrete);
shape |> E.O.toExt(""); shape |> E.O.toExt("");
}; };
let toString = (dists: t) => { let rec toString = (dists: t): string => {
let distString = let distString =
dists dists
|> E.A.fmap(d => GenericSimple.toString(fst(d))) |> E.A.fmap(((d, _)) =>
|> Js.Array.joinWith(","); switch (d) {
let weights = | `Simple(d) => GenericSimple.toString(d)
dists | `PointwiseCombination(ts: t) => ts |> toString
|> E.A.fmap(d => }
snd(d) |> Js.Float.toPrecisionWithPrecision(~digits=2)
) )
|> Js.Array.joinWith(","); |> Js.Array.joinWith(",");
// e.g. mm(normal(0,1), normal(1,2)) => "multimodal(<dist1>,<dist2>, [<weight1>,<weight2>])"
let weights =
dists
|> E.A.fmap(((_, w)) =>
Js.Float.toPrecisionWithPrecision(w, ~digits=2)
)
|> Js.Array.joinWith(",");
{j|multimodal($distString, [$weights])|j}; {j|multimodal($distString, [$weights])|j};
}; };
}; };
// assume that recursive pointwiseNormalizedDistSums are the only type of operation there is.
// in the original, it was a list of (dist, weight) tuples. Now, it's a tree of (dist, weight) tuples, just that every
// dist can be either a GenericSimple or another PointwiseAdd.
/*let toString = (r: bigDistTree) => {
switch (r) {
| WeightedDistLeaf((w, d)) => GenericWeighted.toString(w) // "normal "
| PointwiseNormalizedDistSum(childTrees) => childTrees |> E.A.fmap(toString) |> Js.Array.joinWith("")
}
}*/
let toString = (r: bigDist) => let toString = (r: bigDist) =>
// we need to recursively create the string representation of the tree.
r r
|> ( |> (
fun fun