Minor renames, and moved attemptAlgebraicOperation to SymbolicDist

This commit is contained in:
Ozzie Gooen 2020-07-02 12:14:16 +01:00
parent 491ac15f7b
commit 101824e500
6 changed files with 241 additions and 262 deletions

View File

@ -382,10 +382,10 @@ describe("Shape", () => {
let variance = stdev ** 2.0; let variance = stdev ** 2.0;
let numSamples = 10000; let numSamples = 10000;
open Distributions.Shape; open Distributions.Shape;
let normal: SymbolicDist.dist = `Normal({mean, stdev}); let normal: SymbolicTypes.symbolicDist = `Normal({mean, stdev});
let normalShape = TreeNode.toShape(numSamples, `DistData(`Symbolic(normal))); let normalShape = TreeNode.toShape(numSamples, `Leaf(`SymbolicDist(normal)));
let lognormal = SymbolicDist.Lognormal.fromMeanAndStdev(mean, stdev); let lognormal = SymbolicDist.Lognormal.fromMeanAndStdev(mean, stdev);
let lognormalShape = TreeNode.toShape(numSamples, `DistData(`Symbolic(lognormal))); let lognormalShape = TreeNode.toShape(numSamples, `Leaf(`SymbolicDist(lognormal)));
makeTestCloseEquality( makeTestCloseEquality(
"Mean of a normal", "Mean of a normal",

View File

@ -388,8 +388,8 @@ module Draw = {
let stdev = 15.0; let stdev = 15.0;
let numSamples = 3000; let numSamples = 3000;
let normal: SymbolicDist.dist = `Normal({mean, stdev}); let normal: SymbolicTypes.symbolicDist = `Normal({mean, stdev});
let normalShape = TreeNode.toShape(numSamples, `DistData(`Symbolic(normal))); let normalShape = TreeNode.toShape(numSamples, `Leaf(`SymbolicDist(normal)));
let xyShape: Types.xyShape = let xyShape: Types.xyShape =
switch (normalShape) { switch (normalShape) {
| Mixed(_) => {xs: [||], ys: [||]} | Mixed(_) => {xs: [||], ys: [||]}
@ -398,9 +398,9 @@ module Draw = {
}; };
/* // To use a lognormal instead: /* // To use a lognormal instead:
let lognormal = SymbolicDist.Lognormal.fromMeanAndStdev(mean, stdev); let lognormal = SymbolicTypes.Lognormal.fromMeanAndStdev(mean, stdev);
let lognormalShape = let lognormalShape =
SymbolicDist.GenericSimple.toShape(lognormal, numSamples); SymbolicTypes.GenericSimple.toShape(lognormal, numSamples);
let lognormalXYShape: Types.xyShape = let lognormalXYShape: Types.xyShape =
switch (lognormalShape) { switch (lognormalShape) {
| Mixed(_) => {xs: [||], ys: [||]} | Mixed(_) => {xs: [||], ys: [||]}

View File

@ -89,26 +89,26 @@ module MathAdtToDistDst = {
let normal: array(arg) => result(TreeNode.treeNode, string) = let normal: array(arg) => result(TreeNode.treeNode, string) =
fun fun
| [|Value(mean), Value(stdev)|] => | [|Value(mean), Value(stdev)|] =>
Ok(`DistData(`Symbolic(`Normal({mean, stdev})))) Ok(`Leaf(`SymbolicDist(`Normal({mean, stdev}))))
| _ => Error("Wrong number of variables in normal distribution"); | _ => Error("Wrong number of variables in normal distribution");
let lognormal: array(arg) => result(TreeNode.treeNode, string) = let lognormal: array(arg) => result(TreeNode.treeNode, string) =
fun fun
| [|Value(mu), Value(sigma)|] => | [|Value(mu), Value(sigma)|] =>
Ok(`DistData(`Symbolic(`Lognormal({mu, sigma})))) Ok(`Leaf(`SymbolicDist(`Lognormal({mu, sigma}))))
| [|Object(o)|] => { | [|Object(o)|] => {
let g = Js.Dict.get(o); let g = Js.Dict.get(o);
switch (g("mean"), g("stdev"), g("mu"), g("sigma")) { switch (g("mean"), g("stdev"), g("mu"), g("sigma")) {
| (Some(Value(mean)), Some(Value(stdev)), _, _) => | (Some(Value(mean)), Some(Value(stdev)), _, _) =>
Ok( Ok(
`DistData( `Leaf(
`Symbolic( `SymbolicDist(
SymbolicDist.Lognormal.fromMeanAndStdev(mean, stdev), SymbolicDist.Lognormal.fromMeanAndStdev(mean, stdev),
), ),
), ),
) )
| (_, _, Some(Value(mu)), Some(Value(sigma))) => | (_, _, Some(Value(mu)), Some(Value(sigma))) =>
Ok(`DistData(`Symbolic(`Lognormal({mu, sigma})))) Ok(`Leaf(`SymbolicDist(`Lognormal({mu, sigma}))))
| _ => Error("Lognormal distribution would need mean and stdev") | _ => Error("Lognormal distribution would need mean and stdev")
}; };
} }
@ -118,15 +118,15 @@ module MathAdtToDistDst = {
fun fun
| [|Value(low), Value(high)|] when low <= 0.0 && low < high => { | [|Value(low), Value(high)|] when low <= 0.0 && low < high => {
Ok( Ok(
`DistData( `Leaf(
`Symbolic(SymbolicDist.Normal.from90PercentCI(low, high)), `SymbolicDist(SymbolicDist.Normal.from90PercentCI(low, high)),
), ),
); );
} }
| [|Value(low), Value(high)|] when low < high => { | [|Value(low), Value(high)|] when low < high => {
Ok( Ok(
`DistData( `Leaf(
`Symbolic(SymbolicDist.Lognormal.from90PercentCI(low, high)), `SymbolicDist(SymbolicDist.Lognormal.from90PercentCI(low, high)),
), ),
); );
} }
@ -137,31 +137,31 @@ module MathAdtToDistDst = {
let uniform: array(arg) => result(TreeNode.treeNode, string) = let uniform: array(arg) => result(TreeNode.treeNode, string) =
fun fun
| [|Value(low), Value(high)|] => | [|Value(low), Value(high)|] =>
Ok(`DistData(`Symbolic(`Uniform({low, high})))) Ok(`Leaf(`SymbolicDist(`Uniform({low, high}))))
| _ => Error("Wrong number of variables in lognormal distribution"); | _ => Error("Wrong number of variables in lognormal distribution");
let beta: array(arg) => result(TreeNode.treeNode, string) = let beta: array(arg) => result(TreeNode.treeNode, string) =
fun fun
| [|Value(alpha), Value(beta)|] => | [|Value(alpha), Value(beta)|] =>
Ok(`DistData(`Symbolic(`Beta({alpha, beta})))) Ok(`Leaf(`SymbolicDist(`Beta({alpha, beta}))))
| _ => Error("Wrong number of variables in lognormal distribution"); | _ => Error("Wrong number of variables in lognormal distribution");
let exponential: array(arg) => result(TreeNode.treeNode, string) = let exponential: array(arg) => result(TreeNode.treeNode, string) =
fun fun
| [|Value(rate)|] => | [|Value(rate)|] =>
Ok(`DistData(`Symbolic(`Exponential({rate: rate})))) Ok(`Leaf(`SymbolicDist(`Exponential({rate: rate}))))
| _ => Error("Wrong number of variables in Exponential distribution"); | _ => Error("Wrong number of variables in Exponential distribution");
let cauchy: array(arg) => result(TreeNode.treeNode, string) = let cauchy: array(arg) => result(TreeNode.treeNode, string) =
fun fun
| [|Value(local), Value(scale)|] => | [|Value(local), Value(scale)|] =>
Ok(`DistData(`Symbolic(`Cauchy({local, scale})))) Ok(`Leaf(`SymbolicDist(`Cauchy({local, scale}))))
| _ => Error("Wrong number of variables in cauchy distribution"); | _ => Error("Wrong number of variables in cauchy distribution");
let triangular: array(arg) => result(TreeNode.treeNode, string) = let triangular: array(arg) => result(TreeNode.treeNode, string) =
fun fun
| [|Value(low), Value(medium), Value(high)|] => | [|Value(low), Value(medium), Value(high)|] =>
Ok(`DistData(`Symbolic(`Triangular({low, medium, high})))) Ok(`Leaf(`SymbolicDist(`Triangular({low, medium, high}))))
| _ => Error("Wrong number of variables in triangle distribution"); | _ => Error("Wrong number of variables in triangle distribution");
let multiModal = let multiModal =
@ -196,7 +196,7 @@ module MathAdtToDistDst = {
`VerticalScaling(( `VerticalScaling((
`Multiply, `Multiply,
t, t,
`DistData(`Symbolic(`Float(w))), `Leaf(`SymbolicDist(`Float(w))),
)), )),
); );
}); });
@ -235,7 +235,7 @@ module MathAdtToDistDst = {
SymbolicDist.ContinuousShape.make(_pdf, cdf); SymbolicDist.ContinuousShape.make(_pdf, cdf);
}); });
switch (shape) { switch (shape) {
| Some(s) => Ok(`DistData(`Symbolic(`ContinuousShape(s)))) | Some(s) => Ok(`Leaf(`SymbolicDist(`ContinuousShape(s))))
| None => Error("Rendering did not work") | None => Error("Rendering did not work")
}; };
}; };
@ -254,11 +254,11 @@ module MathAdtToDistDst = {
| ("divide", [|Ok(l), Ok(r)|]) => toOkAlgebraic((`Divide, l, r)) | ("divide", [|Ok(l), Ok(r)|]) => toOkAlgebraic((`Divide, l, r))
| ("divide", _) => Error("Division needs two operands") | ("divide", _) => Error("Division needs two operands")
| ("pow", _) => Error("Exponentiation is not yet supported.") | ("pow", _) => Error("Exponentiation is not yet supported.")
| ("leftTruncate", [|Ok(d), Ok(`DistData(`Symbolic(`Float(lc))))|]) => | ("leftTruncate", [|Ok(d), Ok(`Leaf(`SymbolicDist(`Float(lc))))|]) =>
toOkTrunctate((Some(lc), None, d)) toOkTrunctate((Some(lc), None, d))
| ("leftTruncate", _) => | ("leftTruncate", _) =>
Error("leftTruncate needs two arguments: the expression and the cutoff") Error("leftTruncate needs two arguments: the expression and the cutoff")
| ("rightTruncate", [|Ok(d), Ok(`DistData(`Symbolic(`Float(rc))))|]) => | ("rightTruncate", [|Ok(d), Ok(`Leaf(`SymbolicDist(`Float(rc))))|]) =>
toOkTrunctate((None, Some(rc), d)) toOkTrunctate((None, Some(rc), d))
| ("rightTruncate", _) => | ("rightTruncate", _) =>
Error( Error(
@ -268,8 +268,8 @@ module MathAdtToDistDst = {
"truncate", "truncate",
[| [|
Ok(d), Ok(d),
Ok(`DistData(`Symbolic(`Float(lc)))), Ok(`Leaf(`SymbolicDist(`Float(lc)))),
Ok(`DistData(`Symbolic(`Float(rc)))), Ok(`Leaf(`SymbolicDist(`Float(rc)))),
|], |],
) => ) =>
toOkTrunctate((Some(lc), Some(rc), d)) toOkTrunctate((Some(lc), Some(rc), d))
@ -333,7 +333,7 @@ module MathAdtToDistDst = {
let rec nodeParser = let rec nodeParser =
fun fun
| Value(f) => Ok(`DistData(`Symbolic(`Float(f)))) | Value(f) => Ok(`Leaf(`SymbolicDist(`Float(f))))
| Fn({name, args}) => functionParser(nodeParser, name, args) | Fn({name, args}) => functionParser(nodeParser, name, args)
| _ => { | _ => {
Error("This type not currently supported"); Error("This type not currently supported");

View File

@ -1,52 +1,4 @@
type normal = { open SymbolicTypes;
mean: float,
stdev: float,
};
type lognormal = {
mu: float,
sigma: float,
};
type uniform = {
low: float,
high: float,
};
type beta = {
alpha: float,
beta: float,
};
type exponential = {rate: float};
type cauchy = {
local: float,
scale: float,
};
type triangular = {
low: float,
medium: float,
high: float,
};
type continuousShape = {
pdf: DistTypes.continuousShape,
cdf: DistTypes.continuousShape,
};
type dist = [
| `Normal(normal)
| `Beta(beta)
| `Lognormal(lognormal)
| `Uniform(uniform)
| `Exponential(exponential)
| `Cauchy(cauchy)
| `Triangular(triangular)
| `ContinuousShape(continuousShape)
| `Float(float) // Dirac delta at x. Practically useful only in the context of multimodals.
];
module ContinuousShape = { module ContinuousShape = {
type t = continuousShape; type t = continuousShape;
@ -124,11 +76,12 @@ module Normal = {
`Normal({mean, stdev}); `Normal({mean, stdev});
}; };
let operate = (operation: SymbolicTypes.Algebraic.t, n1: t, n2: t) => switch(operation){ let operate = (operation: SymbolicTypes.Algebraic.t, n1: t, n2: t) =>
switch (operation) {
| `Add => Some(add(n1, n2)) | `Add => Some(add(n1, n2))
| `Subtract => Some(subtract(n1, n2)) | `Subtract => Some(subtract(n1, n2))
| _ => None | _ => None
} };
}; };
module Beta = { module Beta = {
@ -177,11 +130,12 @@ module Lognormal = {
let sigma = l1.sigma +. l2.sigma; let sigma = l1.sigma +. l2.sigma;
`Lognormal({mu, sigma}); `Lognormal({mu, sigma});
}; };
let operate = (operation: SymbolicTypes.Algebraic.t, n1: t, n2: t) => switch(operation){ let operate = (operation: SymbolicTypes.Algebraic.t, n1: t, n2: t) =>
switch (operation) {
| `Multiply => Some(multiply(n1, n2)) | `Multiply => Some(multiply(n1, n2))
| `Divide => Some(divide(n1, n2)) | `Divide => Some(divide(n1, n2))
| _ => None | _ => None
} };
}; };
module Uniform = { module Uniform = {
@ -202,7 +156,7 @@ module Float = {
let toString = Js.Float.toString; let toString = Js.Float.toString;
}; };
module GenericDistFunctions = { module T = {
let minCdfValue = 0.0001; let minCdfValue = 0.0001;
let maxCdfValue = 0.9999; let maxCdfValue = 0.9999;
@ -232,7 +186,7 @@ module GenericDistFunctions = {
| `ContinuousShape(n) => ContinuousShape.inv(x, n) | `ContinuousShape(n) => ContinuousShape.inv(x, n)
}; };
let sample: dist => float = let sample: symbolicDist => float =
fun fun
| `Normal(n) => Normal.sample(n) | `Normal(n) => Normal.sample(n)
| `Triangular(n) => Triangular.sample(n) | `Triangular(n) => Triangular.sample(n)
@ -244,7 +198,7 @@ module GenericDistFunctions = {
| `Float(n) => Float.sample(n) | `Float(n) => Float.sample(n)
| `ContinuousShape(n) => ContinuousShape.sample(n); | `ContinuousShape(n) => ContinuousShape.sample(n);
let toString: dist => string = let toString: symbolicDist => string =
fun fun
| `Triangular(n) => Triangular.toString(n) | `Triangular(n) => Triangular.toString(n)
| `Exponential(n) => Exponential.toString(n) | `Exponential(n) => Exponential.toString(n)
@ -256,7 +210,7 @@ module GenericDistFunctions = {
| `Float(n) => Float.toString(n) | `Float(n) => Float.toString(n)
| `ContinuousShape(n) => ContinuousShape.toString(n); | `ContinuousShape(n) => ContinuousShape.toString(n);
let min: dist => float = let min: symbolicDist => float =
fun fun
| `Triangular({low}) => low | `Triangular({low}) => low
| `Exponential(n) => Exponential.inv(minCdfValue, n) | `Exponential(n) => Exponential.inv(minCdfValue, n)
@ -268,7 +222,7 @@ module GenericDistFunctions = {
| `ContinuousShape(n) => ContinuousShape.inv(minCdfValue, n) | `ContinuousShape(n) => ContinuousShape.inv(minCdfValue, n)
| `Float(n) => n; | `Float(n) => n;
let max: dist => float = let max: symbolicDist => float =
fun fun
| `Triangular(n) => n.high | `Triangular(n) => n.high
| `Exponential(n) => Exponential.inv(maxCdfValue, n) | `Exponential(n) => Exponential.inv(maxCdfValue, n)
@ -280,7 +234,7 @@ module GenericDistFunctions = {
| `Uniform({high}) => high | `Uniform({high}) => high
| `Float(n) => n; | `Float(n) => n;
let mean: dist => result(float, string) = let mean: symbolicDist => result(float, string) =
fun fun
| `Triangular(n) => Triangular.mean(n) | `Triangular(n) => Triangular.mean(n)
| `Exponential(n) => Exponential.mean(n) | `Exponential(n) => Exponential.mean(n)
@ -293,7 +247,7 @@ module GenericDistFunctions = {
| `Float(n) => Float.mean(n); | `Float(n) => Float.mean(n);
let interpolateXs = let interpolateXs =
(~xSelection: [ | `Linear | `ByWeight]=`Linear, dist: dist, n) => { (~xSelection: [ | `Linear | `ByWeight]=`Linear, dist: symbolicDist, n) => {
switch (xSelection, dist) { switch (xSelection, dist) {
| (`Linear, _) => E.A.Floats.range(min(dist), max(dist), n) | (`Linear, _) => E.A.Floats.range(min(dist), max(dist), n)
/* | (`ByWeight, `Uniform(n)) => /* | (`ByWeight, `Uniform(n)) =>
@ -306,4 +260,36 @@ module GenericDistFunctions = {
ys |> E.A.fmap(y => inv(y, dist)); ys |> E.A.fmap(y => inv(y, dist));
}; };
}; };
/* This returns an optional that wraps a result. If the optional is None,
there is no valid analytic solution. If it is Some, it
can still return an error if there is a serious problem,
like in the case of a divide by 0.
*/
type analyticalSolutionAttempt = [
| `AnalyticalSolution(SymbolicTypes.symbolicDist)
| `Error(string)
| `NoSolution
];
let attemptAlgebraicOperation =
(
d1: symbolicDist,
d2: symbolicDist,
op: SymbolicTypes.algebraicOperation,
)
: analyticalSolutionAttempt =>
switch (d1, d2) {
| (`Float(v1), `Float(v2)) =>
switch (SymbolicTypes.Algebraic.applyFn(op, v1, v2)) {
| Ok(r) => `AnalyticalSolution(`Float(r))
| Error(n) => `Error(n)
}
| (`Normal(v1), `Normal(v2)) =>
Normal.operate(op, v1, v2)
|> E.O.dimap(r => `AnalyticalSolution(r), () => `NoSolution)
| (`Lognormal(v1), `Lognormal(v2)) =>
Lognormal.operate(op, v1, v2)
|> E.O.dimap(r => `AnalyticalSolution(r), () => `NoSolution)
| _ => `NoSolution
};
}; };

View File

@ -1,7 +1,57 @@
type normal = {
mean: float,
stdev: float,
};
type lognormal = {
mu: float,
sigma: float,
};
type uniform = {
low: float,
high: float,
};
type beta = {
alpha: float,
beta: float,
};
type exponential = {rate: float};
type cauchy = {
local: float,
scale: float,
};
type triangular = {
low: float,
medium: float,
high: float,
};
type continuousShape = {
pdf: DistTypes.continuousShape,
cdf: DistTypes.continuousShape,
};
type symbolicDist = [
| `Normal(normal)
| `Beta(beta)
| `Lognormal(lognormal)
| `Uniform(uniform)
| `Exponential(exponential)
| `Cauchy(cauchy)
| `Triangular(triangular)
| `ContinuousShape(continuousShape)
| `Float(float) // Dirac delta at x. Practically useful only in the context of multimodals.
];
type algebraicOperation = [ | `Add | `Multiply | `Subtract | `Divide];
type pointwiseOperation = [ | `Add | `Multiply]; type pointwiseOperation = [ | `Add | `Multiply];
type scaleOperation = [ | `Multiply | `Exponentiate | `Log]; type scaleOperation = [ | `Multiply | `Exponentiate | `Log];
type distToFloatOperation = [ | `Pdf(float) | `Inv(float) | `Mean | `Sample]; type distToFloatOperation = [ | `Pdf(float) | `Inv(float) | `Mean | `Sample];
type algebraicOperation = [ | `Add | `Multiply | `Subtract | `Divide];
module Algebraic = { module Algebraic = {
type t = algebraicOperation; type t = algebraicOperation;

View File

@ -1,13 +1,12 @@
/* This module represents a tree node. */ /* This module represents a tree node. */
open SymbolicTypes; open SymbolicTypes;
// todo: Symbolic already has an arbitrary continuousShape option. It seems messy to have both. type leaf = [
type distData = [ | `SymbolicDist(SymbolicTypes.symbolicDist)
| `Symbolic(SymbolicDist.dist) | `RenderedDist(DistTypes.shape)
| `RenderedShape(DistTypes.shape)
]; ];
/* TreeNodes are either Data (i.e. symbolic or rendered distributions) or Operations. Operations always refer to two child nodes.*/ /* TreeNodes are either Data (i.e. symbolic or rendered distributions) or Operations. Operations always refer to two child nodes.*/
type treeNode = [ | `DistData(distData) | `Operation(operation)] type treeNode = [ | `Leaf(leaf) | `Operation(operation)]
and operation = [ and operation = [
| `AlgebraicCombination(algebraicOperation, treeNode, treeNode) | `AlgebraicCombination(algebraicOperation, treeNode, treeNode)
| `PointwiseCombination(pointwiseOperation, treeNode, treeNode) | `PointwiseCombination(pointwiseOperation, treeNode, treeNode)
@ -48,9 +47,8 @@ module TreeNode = {
let rec toString = let rec toString =
fun fun
| `DistData(`Symbolic(d)) => | `Leaf(`SymbolicDist(d)) => SymbolicDist.T.toString(d)
SymbolicDist.GenericDistFunctions.toString(d) | `Leaf(`RenderedDist(_)) => "[shape]"
| `DistData(`RenderedShape(_)) => "[shape]"
| `Operation(op) => Operation.toString(toString, op); | `Operation(op) => Operation.toString(toString, op);
/* The following modules encapsulate everything we can do with /* The following modules encapsulate everything we can do with
@ -61,73 +59,34 @@ module TreeNode = {
For instance, normal(0, 1) + normal(1, 1) -> normal(1, 2). For instance, normal(0, 1) + normal(1, 1) -> normal(1, 2).
In general, this is implemented via convolution. */ In general, this is implemented via convolution. */
module AlgebraicCombination = { module AlgebraicCombination = {
let simplify = (algebraicOp, t1: t, t2: t): result(treeNode, string) => { let toTreeNode = (op, t1, t2) =>
let tryCombiningFloats: tResult = `Operation(`AlgebraicCombination((op, t1, t2)));
fun let tryAnalyticalSolution =
| `Operation(
`AlgebraicCombination(
algebraicOp,
`DistData(`Symbolic(`Float(v1))),
`DistData(`Symbolic(`Float(v2))),
),
) =>
SymbolicTypes.Algebraic.applyFn(algebraicOp, v1, v2)
|> E.R.fmap(r => `DistData(`Symbolic(`Float(r))))
| t => Ok(t);
let optionToSymbolicResult = (t, o) =>
o
|> E.O.dimap(r => `DistData(`Symbolic(r)), () => t)
|> (r => Ok(r));
let tryCombiningNormals: tResult =
fun fun
| `Operation( | `Operation(
`AlgebraicCombination( `AlgebraicCombination(
operation, operation,
`DistData(`Symbolic(`Normal(n1))), `Leaf(`SymbolicDist(d1)),
`DistData(`Symbolic(`Normal(n2))), `Leaf(`SymbolicDist(d2)),
), ),
) as t => ) as t =>
SymbolicDist.Normal.operate(operation, n1, n2) switch (SymbolicDist.T.attemptAlgebraicOperation(d1, d2, operation)) {
|> optionToSymbolicResult(t) | `AnalyticalSolution(symbolicDist) =>
Ok(`Leaf(`SymbolicDist(symbolicDist)))
| `Error(er) => Error(er)
| `NoSolution => Ok(t)
}
| t => Ok(t); | t => Ok(t);
let tryCombiningLognormals: tResult =
fun
| `Operation(
`AlgebraicCombination(
operation,
`DistData(`Symbolic(`Lognormal(n1))),
`DistData(`Symbolic(`Lognormal(n2))),
),
) as t =>
SymbolicDist.Lognormal.operate(operation, n1, n2)
|> optionToSymbolicResult(t)
| t => Ok(t);
let originalTreeNode =
`Operation(`AlgebraicCombination((algebraicOp, t1, t2)));
// Feedback: I like this pattern, kudos
originalTreeNode
|> tryCombiningFloats
|> E.R.bind(_, tryCombiningNormals)
|> E.R.bind(_, tryCombiningLognormals);
};
// todo: I don't like the name evaluateNumerically that much, if this renders and does it algebraically. It's tricky. // todo: I don't like the name evaluateNumerically that much, if this renders and does it algebraically. It's tricky.
let evaluateNumerically = (algebraicOp, operationToDistData, t1, t2) => { let evaluateNumerically = (algebraicOp, operationToLeaf, t1, t2) => {
// force rendering into shapes // force rendering into shapes
let renderShape = r => operationToDistData(`Render(r)); let renderShape = r => operationToLeaf(`Render(r));
switch (renderShape(t1), renderShape(t2)) { switch (renderShape(t1), renderShape(t2)) {
| ( | (Ok(`Leaf(`RenderedDist(s1))), Ok(`Leaf(`RenderedDist(s2)))) =>
Ok(`DistData(`RenderedShape(s1))),
Ok(`DistData(`RenderedShape(s2))),
) =>
Ok( Ok(
`DistData( `Leaf(
`RenderedShape( `RenderedDist(
Distributions.Shape.combineAlgebraically(algebraicOp, s1, s2), Distributions.Shape.combineAlgebraically(algebraicOp, s1, s2),
), ),
), ),
@ -138,42 +97,40 @@ module TreeNode = {
}; };
}; };
let evaluateToDistData = let evaluateToLeaf =
( (
algebraicOp: SymbolicTypes.algebraicOperation, algebraicOp: SymbolicTypes.algebraicOperation,
operationToDistData, operationToLeaf,
t1: t, t1: t,
t2: t, t2: t,
) )
: result(treeNode, string) => : result(treeNode, string) =>
algebraicOp algebraicOp
|> simplify(_, t1, t2) |> toTreeNode(_, t1, t2)
|> tryAnalyticalSolution
|> E.R.bind( |> E.R.bind(
_, _,
fun fun
| `DistData(d) => Ok(`DistData(d)) // the analytical simplifaction worked, nice! | `Leaf(d) => Ok(`Leaf(d)) // the analytical simplifaction worked, nice!
| `Operation(_) => | `Operation(_) =>
// if not, run the convolution // if not, run the convolution
evaluateNumerically(algebraicOp, operationToDistData, t1, t2), evaluateNumerically(algebraicOp, operationToLeaf, t1, t2),
); );
}; };
module VerticalScaling = { module VerticalScaling = {
let evaluateToDistData = (scaleOp, operationToDistData, t, scaleBy) => { let evaluateToLeaf = (scaleOp, operationToLeaf, t, scaleBy) => {
// scaleBy has to be a single float, otherwise we'll return an error. // scaleBy has to be a single float, otherwise we'll return an error.
let fn = SymbolicTypes.Scale.toFn(scaleOp); let fn = SymbolicTypes.Scale.toFn(scaleOp);
let knownIntegralSumFn = let knownIntegralSumFn =
SymbolicTypes.Scale.toKnownIntegralSumFn(scaleOp); SymbolicTypes.Scale.toKnownIntegralSumFn(scaleOp);
let renderedShape = operationToDistData(`Render(t)); let renderedShape = operationToLeaf(`Render(t));
switch (renderedShape, scaleBy) { switch (renderedShape, scaleBy) {
| ( | (Ok(`Leaf(`RenderedDist(rs))), `Leaf(`SymbolicDist(`Float(sm)))) =>
Ok(`DistData(`RenderedShape(rs))),
`DistData(`Symbolic(`Float(sm))),
) =>
Ok( Ok(
`DistData( `Leaf(
`RenderedShape( `RenderedDist(
Distributions.Shape.T.mapY( Distributions.Shape.T.mapY(
~knownIntegralSumFn=knownIntegralSumFn(sm), ~knownIntegralSumFn=knownIntegralSumFn(sm),
fn(sm), fn(sm),
@ -189,18 +146,15 @@ module TreeNode = {
}; };
module PointwiseCombination = { module PointwiseCombination = {
let pointwiseAdd = (operationToDistData, t1, t2) => { let pointwiseAdd = (operationToLeaf, t1, t2) => {
let renderedShape1 = operationToDistData(`Render(t1)); let renderedShape1 = operationToLeaf(`Render(t1));
let renderedShape2 = operationToDistData(`Render(t2)); let renderedShape2 = operationToLeaf(`Render(t2));
switch (renderedShape1, renderedShape2) { switch (renderedShape1, renderedShape2) {
| ( | (Ok(`Leaf(`RenderedDist(rs1))), Ok(`Leaf(`RenderedDist(rs2)))) =>
Ok(`DistData(`RenderedShape(rs1))),
Ok(`DistData(`RenderedShape(rs2))),
) =>
Ok( Ok(
`DistData( `Leaf(
`RenderedShape( `RenderedDist(
Distributions.Shape.combinePointwise( Distributions.Shape.combinePointwise(
~knownIntegralSumsFn=(a, b) => Some(a +. b), ~knownIntegralSumsFn=(a, b) => Some(a +. b),
(+.), (+.),
@ -216,18 +170,18 @@ module TreeNode = {
}; };
}; };
let pointwiseMultiply = (operationToDistData, t1, t2) => { let pointwiseMultiply = (operationToLeaf, t1, t2) => {
// TODO: construct a function that we can easily sample from, to construct // TODO: construct a function that we can easily sample from, to construct
// a RenderedShape. Use the xMin and xMax of the rendered shapes to tell the sampling function where to look. // a RenderedDist. Use the xMin and xMax of the rendered shapes to tell the sampling function where to look.
Error( Error(
"Pointwise multiplication not yet supported.", "Pointwise multiplication not yet supported.",
); );
}; };
let evaluateToDistData = (pointwiseOp, operationToDistData, t1, t2) => { let evaluateToLeaf = (pointwiseOp, operationToLeaf, t1, t2) => {
switch (pointwiseOp) { switch (pointwiseOp) {
| `Add => pointwiseAdd(operationToDistData, t1, t2) | `Add => pointwiseAdd(operationToLeaf, t1, t2)
| `Multiply => pointwiseMultiply(operationToDistData, t1, t2) | `Multiply => pointwiseMultiply(operationToLeaf, t1, t2)
}; };
}; };
}; };
@ -236,18 +190,17 @@ module TreeNode = {
module Simplify = { module Simplify = {
let tryTruncatingNothing: tResult = let tryTruncatingNothing: tResult =
fun fun
| `Operation(`Truncate(None, None, `DistData(d))) => | `Operation(`Truncate(None, None, `Leaf(d))) => Ok(`Leaf(d))
Ok(`DistData(d))
| t => Ok(t); | t => Ok(t);
let tryTruncatingUniform: tResult = let tryTruncatingUniform: tResult =
fun fun
| `Operation(`Truncate(lc, rc, `DistData(`Symbolic(`Uniform(u))))) => { | `Operation(`Truncate(lc, rc, `Leaf(`SymbolicDist(`Uniform(u))))) => {
// just create a new Uniform distribution // just create a new Uniform distribution
let newLow = max(E.O.default(neg_infinity, lc), u.low); let newLow = max(E.O.default(neg_infinity, lc), u.low);
let newHigh = min(E.O.default(infinity, rc), u.high); let newHigh = min(E.O.default(infinity, rc), u.high);
Ok( Ok(
`DistData(`Symbolic(`Uniform({low: newLow, high: newHigh}))), `Leaf(`SymbolicDist(`Uniform({low: newLow, high: newHigh}))),
); );
} }
| t => Ok(t); | t => Ok(t);
@ -262,27 +215,26 @@ module TreeNode = {
}; };
}; };
let evaluateNumerically = let evaluateNumerically = (leftCutoff, rightCutoff, operationToLeaf, t) => {
(leftCutoff, rightCutoff, operationToDistData, t) => {
// TODO: use named args in renderToShape; if we're lucky we can at least get the tail // TODO: use named args in renderToShape; if we're lucky we can at least get the tail
// of a distribution we otherwise wouldn't get at all // of a distribution we otherwise wouldn't get at all
let renderedShape = operationToDistData(`Render(t)); let renderedShape = operationToLeaf(`Render(t));
switch (renderedShape) { switch (renderedShape) {
| Ok(`DistData(`RenderedShape(rs))) => | Ok(`Leaf(`RenderedDist(rs))) =>
let truncatedShape = let truncatedShape =
rs |> Distributions.Shape.T.truncate(leftCutoff, rightCutoff); rs |> Distributions.Shape.T.truncate(leftCutoff, rightCutoff);
Ok(`DistData(`RenderedShape(rs))); Ok(`Leaf(`RenderedDist(rs)));
| Error(e1) => Error(e1) | Error(e1) => Error(e1)
| _ => Error("Could not truncate distribution.") | _ => Error("Could not truncate distribution.")
}; };
}; };
let evaluateToDistData = let evaluateToLeaf =
( (
leftCutoff: option(float), leftCutoff: option(float),
rightCutoff: option(float), rightCutoff: option(float),
operationToDistData, operationToLeaf,
t: treeNode, t: treeNode,
) )
: result(treeNode, string) => { : result(treeNode, string) => {
@ -291,31 +243,23 @@ module TreeNode = {
|> E.R.bind( |> E.R.bind(
_, _,
fun fun
| `DistData(d) => Ok(`DistData(d)) // the analytical simplifaction worked, nice! | `Leaf(d) => Ok(`Leaf(d)) // the analytical simplifaction worked, nice!
| `Operation(_) => | `Operation(_) =>
evaluateNumerically( evaluateNumerically(leftCutoff, rightCutoff, operationToLeaf, t),
leftCutoff,
rightCutoff,
operationToDistData,
t,
),
); // if not, run the convolution ); // if not, run the convolution
}; };
}; };
module Normalize = { module Normalize = {
let rec evaluateToDistData = let rec evaluateToLeaf =
(operationToDistData, t: treeNode): result(treeNode, string) => { (operationToLeaf, t: treeNode): result(treeNode, string) => {
switch (t) { switch (t) {
| `DistData(`Symbolic(_)) => Ok(t) | `Leaf(`SymbolicDist(_)) => Ok(t)
| `DistData(`RenderedShape(s)) => | `Leaf(`RenderedDist(s)) =>
let normalized = Distributions.Shape.T.normalize(s); let normalized = Distributions.Shape.T.normalize(s);
Ok(`DistData(`RenderedShape(normalized))); Ok(`Leaf(`RenderedDist(normalized)));
| `Operation(op) => | `Operation(op) =>
E.R.bind( E.R.bind(operationToLeaf(op), evaluateToLeaf(operationToLeaf))
operationToDistData(op),
evaluateToDistData(operationToDistData),
)
}; };
}; };
}; };
@ -324,14 +268,14 @@ module TreeNode = {
let evaluateFromSymbolic = (distToFloatOp: distToFloatOperation, s) => { let evaluateFromSymbolic = (distToFloatOp: distToFloatOperation, s) => {
let value = let value =
switch (distToFloatOp) { switch (distToFloatOp) {
| `Pdf(f) => Ok(SymbolicDist.GenericDistFunctions.pdf(f, s)) | `Pdf(f) => Ok(SymbolicDist.T.pdf(f, s))
| `Inv(f) => Ok(SymbolicDist.GenericDistFunctions.inv(f, s)) | `Inv(f) => Ok(SymbolicDist.T.inv(f, s))
| `Sample => Ok(SymbolicDist.GenericDistFunctions.sample(s)) | `Sample => Ok(SymbolicDist.T.sample(s))
| `Mean => SymbolicDist.GenericDistFunctions.mean(s) | `Mean => SymbolicDist.T.mean(s)
}; };
E.R.bind(value, v => Ok(`DistData(`Symbolic(`Float(v))))); E.R.bind(value, v => Ok(`Leaf(`SymbolicDist(`Float(v)))));
}; };
let evaluateFromRenderedShape = let evaluateFromRenderedDist =
(distToFloatOp: distToFloatOperation, rs: DistTypes.shape) (distToFloatOp: distToFloatOperation, rs: DistTypes.shape)
: result(treeNode, string) => { : result(treeNode, string) => {
let value = let value =
@ -341,45 +285,45 @@ module TreeNode = {
| `Sample => Ok(Distributions.Shape.sample(rs)) | `Sample => Ok(Distributions.Shape.sample(rs))
| `Mean => Ok(Distributions.Shape.T.mean(rs)) | `Mean => Ok(Distributions.Shape.T.mean(rs))
}; };
E.R.bind(value, v => Ok(`DistData(`Symbolic(`Float(v))))); E.R.bind(value, v => Ok(`Leaf(`SymbolicDist(`Float(v)))));
}; };
let rec evaluateToDistData = let rec evaluateToLeaf =
( (
distToFloatOp: distToFloatOperation, distToFloatOp: distToFloatOperation,
operationToDistData, operationToLeaf,
t: treeNode, t: treeNode,
) )
: result(treeNode, string) => { : result(treeNode, string) => {
switch (t) { switch (t) {
| `DistData(`Symbolic(s)) => evaluateFromSymbolic(distToFloatOp, s) // we want to evaluate the distToFloatOp on the symbolic dist | `Leaf(`SymbolicDist(s)) => evaluateFromSymbolic(distToFloatOp, s) // we want to evaluate the distToFloatOp on the symbolic dist
| `DistData(`RenderedShape(rs)) => | `Leaf(`RenderedDist(rs)) =>
evaluateFromRenderedShape(distToFloatOp, rs) evaluateFromRenderedDist(distToFloatOp, rs)
| `Operation(op) => | `Operation(op) =>
E.R.bind( E.R.bind(
operationToDistData(op), operationToLeaf(op),
evaluateToDistData(distToFloatOp, operationToDistData), evaluateToLeaf(distToFloatOp, operationToLeaf),
) )
}; };
}; };
}; };
module Render = { module Render = {
let rec evaluateToRenderedShape = let rec evaluateToRenderedDist =
( (
operationToDistData: operation => result(t, string), operationToLeaf: operation => result(t, string),
sampleCount: int, sampleCount: int,
t: treeNode, t: treeNode,
) )
: result(t, string) => { : result(t, string) => {
switch (t) { switch (t) {
| `DistData(`RenderedShape(s)) => Ok(`DistData(`RenderedShape(s))) // already a rendered shape, we're done here | `Leaf(`RenderedDist(s)) => Ok(`Leaf(`RenderedDist(s))) // already a rendered shape, we're done here
| `DistData(`Symbolic(d)) => | `Leaf(`SymbolicDist(d)) =>
// todo: move to dist // todo: move to dist
switch (d) { switch (d) {
| `Float(v) => | `Float(v) =>
Ok( Ok(
`DistData( `Leaf(
`RenderedShape( `RenderedDist(
Discrete( Discrete(
Distributions.Discrete.make( Distributions.Discrete.make(
{xs: [|v|], ys: [|1.0|]}, {xs: [|v|], ys: [|1.0|]},
@ -391,16 +335,15 @@ module TreeNode = {
) )
| _ => | _ =>
let xs = let xs =
SymbolicDist.GenericDistFunctions.interpolateXs( SymbolicDist.T.interpolateXs(
~xSelection=`ByWeight, ~xSelection=`ByWeight,
d, d,
sampleCount, sampleCount,
); );
let ys = let ys = xs |> E.A.fmap(x => SymbolicDist.T.pdf(x, d));
xs |> E.A.fmap(x => SymbolicDist.GenericDistFunctions.pdf(x, d));
Ok( Ok(
`DistData( `Leaf(
`RenderedShape( `RenderedDist(
Continuous( Continuous(
Distributions.Continuous.make( Distributions.Continuous.make(
`Linear, `Linear,
@ -414,57 +357,57 @@ module TreeNode = {
} }
| `Operation(op) => | `Operation(op) =>
E.R.bind( E.R.bind(
operationToDistData(op), operationToLeaf(op),
evaluateToRenderedShape(operationToDistData, sampleCount), evaluateToRenderedDist(operationToLeaf, sampleCount),
) )
}; };
}; };
}; };
let rec operationToDistData = let rec operationToLeaf =
(sampleCount: int, op: operation): result(t, string) => { (sampleCount: int, op: operation): result(t, string) => {
// the functions that convert the Operation nodes to DistData nodes need to // the functions that convert the Operation nodes to Leaf nodes need to
// have a way to call this function on their children, if their children are themselves Operation nodes. // have a way to call this function on their children, if their children are themselves Operation nodes.
switch (op) { switch (op) {
| `AlgebraicCombination(algebraicOp, t1, t2) => | `AlgebraicCombination(algebraicOp, t1, t2) =>
AlgebraicCombination.evaluateToDistData( AlgebraicCombination.evaluateToLeaf(
algebraicOp, algebraicOp,
operationToDistData(sampleCount), operationToLeaf(sampleCount),
t1, t1,
t2 // we want to give it the option to render or simply leave it as is t2 // we want to give it the option to render or simply leave it as is
) )
| `PointwiseCombination(pointwiseOp, t1, t2) => | `PointwiseCombination(pointwiseOp, t1, t2) =>
PointwiseCombination.evaluateToDistData( PointwiseCombination.evaluateToLeaf(
pointwiseOp, pointwiseOp,
operationToDistData(sampleCount), operationToLeaf(sampleCount),
t1, t1,
t2, t2,
) )
| `VerticalScaling(scaleOp, t, scaleBy) => | `VerticalScaling(scaleOp, t, scaleBy) =>
VerticalScaling.evaluateToDistData( VerticalScaling.evaluateToLeaf(
scaleOp, scaleOp,
operationToDistData(sampleCount), operationToLeaf(sampleCount),
t, t,
scaleBy, scaleBy,
) )
| `Truncate(leftCutoff, rightCutoff, t) => | `Truncate(leftCutoff, rightCutoff, t) =>
Truncate.evaluateToDistData( Truncate.evaluateToLeaf(
leftCutoff, leftCutoff,
rightCutoff, rightCutoff,
operationToDistData(sampleCount), operationToLeaf(sampleCount),
t, t,
) )
| `FloatFromDist(distToFloatOp, t) => | `FloatFromDist(distToFloatOp, t) =>
FloatFromDist.evaluateToDistData( FloatFromDist.evaluateToLeaf(
distToFloatOp, distToFloatOp,
operationToDistData(sampleCount), operationToLeaf(sampleCount),
t, t,
) )
| `Normalize(t) => | `Normalize(t) =>
Normalize.evaluateToDistData(operationToDistData(sampleCount), t) Normalize.evaluateToLeaf(operationToLeaf(sampleCount), t)
| `Render(t) => | `Render(t) =>
Render.evaluateToRenderedShape( Render.evaluateToRenderedDist(
operationToDistData(sampleCount), operationToLeaf(sampleCount),
sampleCount, sampleCount,
t, t,
) )
@ -474,23 +417,23 @@ module TreeNode = {
/* This function recursively goes through the nodes of the parse tree, /* This function recursively goes through the nodes of the parse tree,
replacing each Operation node and its subtree with a Data node. replacing each Operation node and its subtree with a Leaf node.
Whenever possible, the replacement produces a new Symbolic Data node, Whenever possible, the replacement produces a new SymbolicDist Leaf node,
but most often it will produce a RenderedShape. but most often it will produce a RenderedDist.
This function is used mainly to turn a parse tree into a single RenderedShape This function is used mainly to turn a parse tree into a single RenderedDist
that can then be displayed to the user. */ that can then be displayed to the user. */
// Entry point for reducing a tree node.
//   treeNode:    the node to reduce.
//   sampleCount: forwarded unchanged to the operation reducer.
// A node that is already a leaf is returned as-is inside Ok; an `Operation` node is
// handed to operationToDistData (renamed operationToLeaf) together with sampleCount.
let toDistData = (treeNode: t, sampleCount: int): result(t, string) => { let toLeaf = (treeNode: t, sampleCount: int): result(t, string) => {
switch (treeNode) { switch (treeNode) {
| `DistData(d) => Ok(`DistData(d)) | `Leaf(d) => Ok(`Leaf(d))
| `Operation(op) => operationToDistData(sampleCount, op) | `Operation(op) => operationToLeaf(sampleCount, op)
}; };
}; };
}; };
let toShape = (sampleCount: int, treeNode: treeNode) => { let toShape = (sampleCount: int, treeNode: treeNode) => {
let renderResult = let renderResult =
TreeNode.toDistData(`Operation(`Render(treeNode)), sampleCount); TreeNode.toLeaf(`Operation(`Render(treeNode)), sampleCount);
switch (renderResult) { switch (renderResult) {
| Ok(`DistData(`RenderedShape(rs))) => | Ok(`Leaf(`RenderedDist(rs))) =>
let continuous = Distributions.Shape.T.toContinuous(rs); let continuous = Distributions.Shape.T.toContinuous(rs);
let discrete = Distributions.Shape.T.toDiscrete(rs); let discrete = Distributions.Shape.T.toDiscrete(rs);
let shape = MixedShapeBuilder.buildSimple(~continuous, ~discrete); let shape = MixedShapeBuilder.buildSimple(~continuous, ~discrete);