Working on code reorganization, doesn't compile yet

This commit is contained in:
Sebastian Kosch 2020-06-25 23:38:14 -07:00
parent 214f3b9e58
commit bd528571af
17 changed files with 1142 additions and 843 deletions

View File

@ -386,10 +386,9 @@ describe("Shape", () => {
let numSamples = 10000; let numSamples = 10000;
open Distributions.Shape; open Distributions.Shape;
let normal: SymbolicDist.dist = `Normal({mean, stdev}); let normal: SymbolicDist.dist = `Normal({mean, stdev});
let normalShape = SymbolicDist.GenericSimple.toShape(normal, numSamples); let normalShape = TreeNode.toShape(numSamples, normal);
let lognormal = SymbolicDist.Lognormal.fromMeanAndStdev(mean, stdev); let lognormal = SymbolicDist.Lognormal.fromMeanAndStdev(mean, stdev);
let lognormalShape = let lognormalShape = TreeNode.toShape(numSamples, lognormal);
SymbolicDist.GenericSimple.toShape(lognormal, numSamples);
makeTestCloseEquality( makeTestCloseEquality(
"Mean of a normal", "Mean of a normal",

View File

@ -44,14 +44,14 @@ module DemoDist = {
Distributions.DistPlus.make( Distributions.DistPlus.make(
~shape= ~shape=
Continuous( Continuous(
Distributions.Continuous.make(`Linear, {xs, ys}), Distributions.Continuous.make(`Linear, {xs, ys}, None),
), ),
~domain=Complete, ~domain=Complete,
~unit=UnspecifiedDistribution, ~unit=UnspecifiedDistribution,
~guesstimatorString=None, ~guesstimatorString=None,
(), (),
) )
|> Distributions.DistPlus.T.scaleToIntegralSum(~intendedSum=1.0); |> Distributions.DistPlus.T.normalize;
<DistPlusPlot distPlus />; <DistPlusPlot distPlus />;
}; };
<Antd.Card title={"Distribution" |> R.ste}> <Antd.Card title={"Distribution" |> R.ste}>

View File

@ -37,7 +37,7 @@ module DemoDist = {
let parsed1 = MathJsParser.fromString(guesstimatorString); let parsed1 = MathJsParser.fromString(guesstimatorString);
let shape = let shape =
switch (parsed1) { switch (parsed1) {
| Ok(r) => Some(SymbolicDist.toShape(10000, r)) | Ok(r) => Some(TreeNode.toShape(10000, r))
| _ => None | _ => None
}; };

View File

@ -177,6 +177,7 @@ module Convert = {
let continuousShape: Types.continuousShape = { let continuousShape: Types.continuousShape = {
xyShape, xyShape,
interpolation: `Linear, interpolation: `Linear,
knownIntegralSum: None,
}; };
let integral = XYShape.Analysis.integrateContinuousShape(continuousShape); let integral = XYShape.Analysis.integrateContinuousShape(continuousShape);
@ -188,6 +189,7 @@ module Convert = {
ys, ys,
}, },
interpolation: `Linear, interpolation: `Linear,
knownIntegralSum: Some(1.0),
}; };
continuousShape; continuousShape;
}; };
@ -387,7 +389,7 @@ module Draw = {
let numSamples = 3000; let numSamples = 3000;
let normal: SymbolicDist.dist = `Normal({mean, stdev}); let normal: SymbolicDist.dist = `Normal({mean, stdev});
let normalShape = SymbolicDist.GenericSimple.toShape(normal, numSamples); let normalShape = TreeNode.toShape(numSamples, `DistData(`Symbolic(normal)));
let xyShape: Types.xyShape = let xyShape: Types.xyShape =
switch (normalShape) { switch (normalShape) {
| Mixed(_) => {xs: [||], ys: [||]} | Mixed(_) => {xs: [||], ys: [||]}
@ -667,9 +669,7 @@ module State = {
/* create a cdf from a pdf */ /* create a cdf from a pdf */
let _pdf = let _pdf =
Distributions.Continuous.T.scaleToIntegralSum( Distributions.Continuous.T.normalize(
~cache=None,
~intendedSum=1.0,
pdf, pdf,
); );

View File

@ -95,7 +95,7 @@ let table = (distPlus, x) => {
</td> </td>
<td className="px-4 py-2 border "> <td className="px-4 py-2 border ">
{distPlus {distPlus
|> Distributions.DistPlus.T.toScaledContinuous |> Distributions.DistPlus.T.normalizedToContinuous
|> E.O.fmap( |> E.O.fmap(
Distributions.Continuous.T.Integral.sum(~cache=None), Distributions.Continuous.T.Integral.sum(~cache=None),
) )
@ -113,7 +113,7 @@ let table = (distPlus, x) => {
</td> </td>
<td className="px-4 py-2 border "> <td className="px-4 py-2 border ">
{distPlus {distPlus
|> Distributions.DistPlus.T.toScaledDiscrete |> Distributions.DistPlus.T.normalizedToDiscrete
|> E.O.fmap(Distributions.Discrete.T.Integral.sum(~cache=None)) |> E.O.fmap(Distributions.Discrete.T.Integral.sum(~cache=None))
|> E.O.fmap(E.Float.with2DigitsPrecision) |> E.O.fmap(E.Float.with2DigitsPrecision)
|> E.O.default("") |> E.O.default("")
@ -211,15 +211,13 @@ let percentiles = distPlus => {
</div>; </div>;
}; };
let adjustBoth = discreteProbabilityMass => { let adjustBoth = discreteProbabilityMassFraction => {
let yMaxDiscreteDomainFactor = discreteProbabilityMass; let yMaxDiscreteDomainFactor = discreteProbabilityMassFraction;
let yMaxContinuousDomainFactor = 1.0 -. discreteProbabilityMass; let yMaxContinuousDomainFactor = 1.0 -. discreteProbabilityMassFraction;
let yMax = let yMax = (yMaxDiscreteDomainFactor > 0.5 ? yMaxDiscreteDomainFactor : yMaxContinuousDomainFactor);
yMaxDiscreteDomainFactor > yMaxContinuousDomainFactor
? yMaxDiscreteDomainFactor : yMaxContinuousDomainFactor;
( (
1.0 /. (yMaxDiscreteDomainFactor /. yMax), yMax /. yMaxDiscreteDomainFactor,
1.0 /. (yMaxContinuousDomainFactor /. yMax), yMax /. yMaxContinuousDomainFactor,
); );
}; };
@ -227,10 +225,10 @@ module DistPlusChart = {
[@react.component] [@react.component]
let make = (~distPlus: DistTypes.distPlus, ~config: chartConfig, ~onHover) => { let make = (~distPlus: DistTypes.distPlus, ~config: chartConfig, ~onHover) => {
open Distributions.DistPlus; open Distributions.DistPlus;
let discrete = distPlus |> T.toScaledDiscrete; let discrete = distPlus |> T.normalizedToDiscrete |> E.O.fmap(Distributions.Discrete.getShape);
let continuous = let continuous =
distPlus distPlus
|> T.toScaledContinuous |> T.normalizedToContinuous
|> E.O.fmap(Distributions.Continuous.getShape); |> E.O.fmap(Distributions.Continuous.getShape);
let range = T.xTotalRange(distPlus); let range = T.xTotalRange(distPlus);
@ -254,10 +252,10 @@ module DistPlusChart = {
}; };
let timeScale = distPlus.unit |> DistTypes.DistributionUnit.toJson; let timeScale = distPlus.unit |> DistTypes.DistributionUnit.toJson;
let toDiscreteProbabilityMass = let discreteProbabilityMassFraction =
distPlus |> Distributions.DistPlus.T.toDiscreteProbabilityMass; distPlus |> Distributions.DistPlus.T.toDiscreteProbabilityMassFraction;
let (yMaxDiscreteDomainFactor, yMaxContinuousDomainFactor) = let (yMaxDiscreteDomainFactor, yMaxContinuousDomainFactor) =
adjustBoth(toDiscreteProbabilityMass); adjustBoth(discreteProbabilityMassFraction);
<DistributionPlot <DistributionPlot
xScale={config.xLog ? "log" : "linear"} xScale={config.xLog ? "log" : "linear"}
yScale={config.yLog ? "log" : "linear"} yScale={config.yLog ? "log" : "linear"}

View File

@ -17,14 +17,18 @@ type xyShape = {
type continuousShape = { type continuousShape = {
xyShape, xyShape,
interpolation: [ | `Stepwise | `Linear], interpolation: [ | `Stepwise | `Linear],
knownIntegralSum: option(float),
}; };
type discreteShape = xyShape; type discreteShape = {
xyShape,
knownIntegralSum: option(float),
};
type mixedShape = { type mixedShape = {
continuous: continuousShape, continuous: continuousShape,
discrete: discreteShape, discrete: discreteShape,
discreteProbabilityMassFraction: float, // discreteProbabilityMassFraction: float,
}; };
type shapeMonad('a, 'b, 'c) = type shapeMonad('a, 'b, 'c) =

File diff suppressed because it is too large Load Diff

View File

@ -8,14 +8,15 @@ type assumptions = {
discreteProbabilityMass: option(float), discreteProbabilityMass: option(float),
}; };
let buildSimple = (~continuous: option(DistTypes.continuousShape), ~discrete): option(DistTypes.shape) => { let buildSimple = (~continuous: option(DistTypes.continuousShape), ~discrete: option(DistTypes.discreteShape)): option(DistTypes.shape) => {
let continuous = continuous |> E.O.default(Distributions.Continuous.make(`Linear, {xs: [||], ys: [||]})) let continuous = continuous |> E.O.default(Distributions.Continuous.make(`Linear, {xs: [||], ys: [||]}, Some(0.0)));
let discrete = discrete |> E.O.default(Distributions.Discrete.make({xs: [||], ys: [||]}, Some(0.0)));
let cLength = let cLength =
continuous continuous
|> Distributions.Continuous.getShape |> Distributions.Continuous.getShape
|> XYShape.T.xs |> XYShape.T.xs
|> E.A.length; |> E.A.length;
let dLength = discrete |> XYShape.T.xs |> E.A.length; let dLength = discrete |> Distributions.Discrete.getShape |> XYShape.T.xs |> E.A.length;
switch (cLength, dLength) { switch (cLength, dLength) {
| (0 | 1, 0) => None | (0 | 1, 0) => None
| (0 | 1, _) => Some(Discrete(discrete)) | (0 | 1, _) => Some(Discrete(discrete))
@ -23,18 +24,12 @@ let buildSimple = (~continuous: option(DistTypes.continuousShape), ~discrete): o
| (_, _) => | (_, _) =>
let discreteProbabilityMassFraction = let discreteProbabilityMassFraction =
Distributions.Discrete.T.Integral.sum(~cache=None, discrete); Distributions.Discrete.T.Integral.sum(~cache=None, discrete);
let discrete = let discrete = Distributions.Discrete.T.normalize(discrete);
Distributions.Discrete.T.scaleToIntegralSum(~intendedSum=1.0, discrete); let continuous = Distributions.Continuous.T.normalize(continuous);
let continuous =
Distributions.Continuous.T.scaleToIntegralSum(
~intendedSum=1.0,
continuous,
);
let mixedDist = let mixedDist =
Distributions.Mixed.make( Distributions.Mixed.make(
~continuous, ~continuous,
~discrete, ~discrete
~discreteProbabilityMassFraction,
); );
Some(Mixed(mixedDist)); Some(Mixed(mixedDist));
}; };
@ -42,7 +37,7 @@ let buildSimple = (~continuous: option(DistTypes.continuousShape), ~discrete): o
// TODO: Delete, only being used in tests // TODO: Delete, only being used in tests
let build = (~continuous, ~discrete, ~assumptions) => /*let build = (~continuous, ~discrete, ~assumptions) =>
switch (assumptions) { switch (assumptions) {
| { | {
continuous: ADDS_TO_CORRECT_PROBABILITY, continuous: ADDS_TO_CORRECT_PROBABILITY,
@ -102,4 +97,4 @@ let build = (~continuous, ~discrete, ~assumptions) =>
), ),
); );
| _ => None | _ => None
}; };*/

View File

@ -17,6 +17,7 @@ module T = {
type ts = array(xyShape); type ts = array(xyShape);
let xs = (t: t) => t.xs; let xs = (t: t) => t.xs;
let ys = (t: t) => t.ys; let ys = (t: t) => t.ys;
let length = (t: t) => E.A.length(t.xs);
let empty = {xs: [||], ys: [||]}; let empty = {xs: [||], ys: [||]};
let minX = (t: t) => t |> xs |> E.A.Sorted.min |> extImp; let minX = (t: t) => t |> xs |> E.A.Sorted.min |> extImp;
let maxX = (t: t) => t |> xs |> E.A.Sorted.max |> extImp; let maxX = (t: t) => t |> xs |> E.A.Sorted.max |> extImp;
@ -154,7 +155,9 @@ module XsConversion = {
let proportionByProbabilityMass = let proportionByProbabilityMass =
(newLength: int, integral: T.t, t: T.t): T.t => { (newLength: int, integral: T.t, t: T.t): T.t => {
equallyDivideXByMass(newLength, integral) |> _replaceWithXs(_, t); integral
|> equallyDivideXByMass(newLength) // creates a new set of xs at evenly spaced percentiles
|> _replaceWithXs(_, t); // linearly interpolates new ys for the new xs
}; };
}; };
@ -164,6 +167,7 @@ module Zipped = {
let compareXs = ((x1, _), (x2, _)) => x1 > x2 ? 1 : 0; let compareXs = ((x1, _), (x2, _)) => x1 > x2 ? 1 : 0;
let sortByY = (t: zipped) => t |> E.A.stableSortBy(_, compareYs); let sortByY = (t: zipped) => t |> E.A.stableSortBy(_, compareYs);
let sortByX = (t: zipped) => t |> E.A.stableSortBy(_, compareXs); let sortByX = (t: zipped) => t |> E.A.stableSortBy(_, compareXs);
let filterByX = (testFn: (float => bool), t: zipped) => t |> E.A.filter(((x, _)) => testFn(x));
}; };
module Combine = { module Combine = {
@ -253,8 +257,8 @@ module Range = {
Belt.Array.set( Belt.Array.set(
cumulativeY, cumulativeY,
x + 1, x + 1,
(xs[x + 1] -. xs[x]) (xs[x + 1] -. xs[x]) // dx
*. ((ys[x] +. ys[x + 1]) /. 2.) *. ((ys[x] +. ys[x + 1]) /. 2.) // (1/2) * (avgY)
+. cumulativeY[x], +. cumulativeY[x],
); );
(); ();

View File

@ -43,7 +43,7 @@ module ShapeRenderer = {
module Symbolic = { module Symbolic = {
type inputs = {length: int}; type inputs = {length: int};
type outputs = { type outputs = {
graph: SymbolicDist.distTree, graph: TreeNode.treeNode,
shape: DistTypes.shape, shape: DistTypes.shape,
}; };
let make = (graph, shape) => {graph, shape}; let make = (graph, shape) => {graph, shape};

View File

@ -21,7 +21,7 @@ let runSymbolic = (guesstimatorString, length) => {
|> E.R.fmap(g => |> E.R.fmap(g =>
RenderTypes.ShapeRenderer.Symbolic.make( RenderTypes.ShapeRenderer.Symbolic.make(
g, g,
SymbolicDist.toShape(length, g), TreeNode.toShape(length, g),
) )
); );
}; };

View File

@ -4,10 +4,10 @@ type discrete = {
ys: array(float), ys: array(float),
}; };
let jsToDistDiscrete = (d: discrete): DistTypes.discreteShape => { let jsToDistDiscrete = (d: discrete): DistTypes.discreteShape => {xyShape: {
xs: xsGet(d), xs: xsGet(d),
ys: ysGet(d), ys: ysGet(d),
}; }, knownIntegralSum: None};
[@bs.module "./GuesstimatorLibrary.js"] [@bs.module "./GuesstimatorLibrary.js"]
external stringToSamples: (string, int) => array(float) = "stringToSamples"; external stringToSamples: (string, int) => array(float) = "stringToSamples";

View File

@ -115,11 +115,12 @@ module T = {
Array.fast_sort(compare, samples); Array.fast_sort(compare, samples);
let (continuousPart, discretePart) = E.A.Sorted.Floats.split(samples); let (continuousPart, discretePart) = E.A.Sorted.Floats.split(samples);
let length = samples |> E.A.length |> float_of_int; let length = samples |> E.A.length |> float_of_int;
let discrete: DistTypes.xyShape = let discrete: DistTypes.discreteShape =
discretePart discretePart
|> E.FloatFloatMap.fmap(r => r /. length) |> E.FloatFloatMap.fmap(r => r /. length)
|> E.FloatFloatMap.toArray |> E.FloatFloatMap.toArray
|> XYShape.T.fromZippedArray; |> XYShape.T.fromZippedArray
|> Distributions.Discrete.make(_, None);
let pdf = let pdf =
continuousPart |> E.A.length > 5 continuousPart |> E.A.length > 5
@ -149,14 +150,14 @@ module T = {
~outputXYPoints=samplingInputs.outputXYPoints, ~outputXYPoints=samplingInputs.outputXYPoints,
formatUnitWidth(usedUnitWidth), formatUnitWidth(usedUnitWidth),
) )
|> Distributions.Continuous.make(`Linear) |> Distributions.Continuous.make(`Linear, _, None)
|> (r => Some((r, foo))); |> (r => Some((r, foo)));
} }
: None; : None;
let shape = let shape =
MixedShapeBuilder.buildSimple( MixedShapeBuilder.buildSimple(
~continuous=pdf |> E.O.fmap(fst), ~continuous=pdf |> E.O.fmap(fst),
~discrete, ~discrete=Some(discrete),
); );
let samplesParse: RenderTypes.ShapeRenderer.Sampling.outputs = { let samplesParse: RenderTypes.ShapeRenderer.Sampling.outputs = {
continuousParseParams: pdf |> E.O.fmap(snd), continuousParseParams: pdf |> E.O.fmap(snd),

View File

@ -88,69 +88,69 @@ module MathAdtToDistDst = {
); );
}; };
let normal: array(arg) => result(SymbolicDist.distTree, string) = let normal: array(arg) => result(TreeNode.treeNode, string) =
fun fun
| [|Value(mean), Value(stdev)|] => | [|Value(mean), Value(stdev)|] =>
Ok(`Simple(`Normal({mean, stdev}))) Ok(`DistData(`Symbolic(`Normal({mean, stdev}))))
| _ => Error("Wrong number of variables in normal distribution"); | _ => Error("Wrong number of variables in normal distribution");
let lognormal: array(arg) => result(SymbolicDist.distTree, string) = let lognormal: array(arg) => result(TreeNode.treeNode, string) =
fun fun
| [|Value(mu), Value(sigma)|] => Ok(`Simple(`Lognormal({mu, sigma}))) | [|Value(mu), Value(sigma)|] => Ok(`DistData(`Symbolic(`Lognormal({mu, sigma}))))
| [|Object(o)|] => { | [|Object(o)|] => {
let g = Js.Dict.get(o); let g = Js.Dict.get(o);
switch (g("mean"), g("stdev"), g("mu"), g("sigma")) { switch (g("mean"), g("stdev"), g("mu"), g("sigma")) {
| (Some(Value(mean)), Some(Value(stdev)), _, _) => | (Some(Value(mean)), Some(Value(stdev)), _, _) =>
Ok(`Simple(SymbolicDist.Lognormal.fromMeanAndStdev(mean, stdev))) Ok(`DistData(`Symbolic(SymbolicDist.Lognormal.fromMeanAndStdev(mean, stdev))))
| (_, _, Some(Value(mu)), Some(Value(sigma))) => | (_, _, Some(Value(mu)), Some(Value(sigma))) =>
Ok(`Simple(`Lognormal({mu, sigma}))) Ok(`DistData(`Symbolic(`Lognormal({mu, sigma}))))
| _ => Error("Lognormal distribution would need mean and stdev") | _ => Error("Lognormal distribution would need mean and stdev")
}; };
} }
| _ => Error("Wrong number of variables in lognormal distribution"); | _ => Error("Wrong number of variables in lognormal distribution");
let to_: array(arg) => result(SymbolicDist.distTree, string) = let to_: array(arg) => result(TreeNode.treeNode, string) =
fun fun
| [|Value(low), Value(high)|] when low <= 0.0 && low < high=> { | [|Value(low), Value(high)|] when low <= 0.0 && low < high=> {
Ok(`Simple(SymbolicDist.Normal.from90PercentCI(low, high))); Ok(`DistData(`Symbolic(SymbolicDist.Normal.from90PercentCI(low, high))));
} }
| [|Value(low), Value(high)|] when low < high => { | [|Value(low), Value(high)|] when low < high => {
Ok(`Simple(SymbolicDist.Lognormal.from90PercentCI(low, high))); Ok(`DistData(`Symbolic(SymbolicDist.Lognormal.from90PercentCI(low, high))));
} }
| [|Value(_), Value(_)|] => | [|Value(_), Value(_)|] =>
Error("Low value must be less than high value.") Error("Low value must be less than high value.")
| _ => Error("Wrong number of variables in lognormal distribution"); | _ => Error("Wrong number of variables in lognormal distribution");
let uniform: array(arg) => result(SymbolicDist.distTree, string) = let uniform: array(arg) => result(TreeNode.treeNode, string) =
fun fun
| [|Value(low), Value(high)|] => Ok(`Simple(`Uniform({low, high}))) | [|Value(low), Value(high)|] => Ok(`DistData(`Symbolic(`Uniform({low, high}))))
| _ => Error("Wrong number of variables in lognormal distribution"); | _ => Error("Wrong number of variables in lognormal distribution");
let beta: array(arg) => result(SymbolicDist.distTree, string) = let beta: array(arg) => result(TreeNode.treeNode, string) =
fun fun
| [|Value(alpha), Value(beta)|] => Ok(`Simple(`Beta({alpha, beta}))) | [|Value(alpha), Value(beta)|] => Ok(`DistData(`Symbolic(`Beta({alpha, beta}))))
| _ => Error("Wrong number of variables in lognormal distribution"); | _ => Error("Wrong number of variables in lognormal distribution");
let exponential: array(arg) => result(SymbolicDist.distTree, string) = let exponential: array(arg) => result(TreeNode.treeNode, string) =
fun fun
| [|Value(rate)|] => Ok(`Simple(`Exponential({rate: rate}))) | [|Value(rate)|] => Ok(`DistData(`Symbolic(`Exponential({rate: rate}))))
| _ => Error("Wrong number of variables in Exponential distribution"); | _ => Error("Wrong number of variables in Exponential distribution");
let cauchy: array(arg) => result(SymbolicDist.distTree, string) = let cauchy: array(arg) => result(TreeNode.treeNode, string) =
fun fun
| [|Value(local), Value(scale)|] => | [|Value(local), Value(scale)|] =>
Ok(`Simple(`Cauchy({local, scale}))) Ok(`DistData(`Symbolic(`Cauchy({local, scale}))))
| _ => Error("Wrong number of variables in cauchy distribution"); | _ => Error("Wrong number of variables in cauchy distribution");
let triangular: array(arg) => result(SymbolicDist.distTree, string) = let triangular: array(arg) => result(TreeNode.treeNode, string) =
fun fun
| [|Value(low), Value(medium), Value(high)|] => | [|Value(low), Value(medium), Value(high)|] =>
Ok(`Simple(`Triangular({low, medium, high}))) Ok(`DistData(`Symbolic(`Triangular({low, medium, high}))))
| _ => Error("Wrong number of variables in triangle distribution"); | _ => Error("Wrong number of variables in triangle distribution");
let multiModal = let multiModal =
( (
args: array(result(SymbolicDist.distTree, string)), args: array(result(TreeNode.treeNode, string)),
weights: option(array(float)), weights: option(array(float)),
) => { ) => {
let weights = weights |> E.O.default([||]); let weights = weights |> E.O.default([||]);
@ -158,16 +158,8 @@ module MathAdtToDistDst = {
args args
|> E.A.fmap( |> E.A.fmap(
fun fun
| Ok(`Simple(d)) => Ok(`Simple(d)) | Ok(a) => a
| Ok(`Combination(t1, t2, op)) => Ok(`Combination(t1, t2, op))
| Ok(`PointwiseSum(t1, t2)) => Ok(`PointwiseSum(t1, t2))
| Ok(`PointwiseProduct(t1, t2)) => Ok(`PointwiseProduct(t1, t2))
| Ok(`Normalize(t)) => Ok(`Normalize(t))
| Ok(`LeftTruncate(t, x)) => Ok(`LeftTruncate(t, x))
| Ok(`RightTruncate(t, x)) => Ok(`RightTruncate(t, x))
| Ok(`Render(t)) => Ok(`Render(t))
| Error(e) => Error(e) | Error(e) => Error(e)
| _ => Error("Unexpected dist")
); );
let firstWithError = dists |> Belt.Array.getBy(_, Belt.Result.isError); let firstWithError = dists |> Belt.Array.getBy(_, Belt.Result.isError);
@ -182,7 +174,7 @@ module MathAdtToDistDst = {
|> E.A.fmapi((index, t) => { |> E.A.fmapi((index, t) => {
let w = weights |> E.A.get(_, index) |> E.O.default(1.0); let w = weights |> E.A.get(_, index) |> E.O.default(1.0);
`VerticalScaling(t, `Simple(`Float(w))) `Operation(`ScaleBy(`Multiply, t, `DistData(`Symbolic(`Float(w)))))
}); });
let pointwiseSum = components let pointwiseSum = components
@ -196,7 +188,7 @@ module MathAdtToDistDst = {
}; };
}; };
let arrayParser = (args:array(arg)):result(SymbolicDist.distTree, string) => { let arrayParser = (args:array(arg)):result(TreeNode.treeNode, string) => {
let samples = args let samples = args
|> E.A.fmap( |> E.A.fmap(
fun fun
@ -207,18 +199,18 @@ module MathAdtToDistDst = {
let outputs = Samples.T.fromSamples(samples); let outputs = Samples.T.fromSamples(samples);
let pdf = outputs.shape |> E.O.bind(_,Distributions.Shape.T.toContinuous); let pdf = outputs.shape |> E.O.bind(_,Distributions.Shape.T.toContinuous);
let shape = pdf |> E.O.fmap(pdf => { let shape = pdf |> E.O.fmap(pdf => {
let _pdf = Distributions.Continuous.T.scaleToIntegralSum(~cache=None, ~intendedSum=1.0, pdf); let _pdf = Distributions.Continuous.T.normalize(pdf);
let cdf = Distributions.Continuous.T.integral(~cache=None, _pdf); let cdf = Distributions.Continuous.T.integral(~cache=None, _pdf);
SymbolicDist.ContinuousShape.make(_pdf, cdf) SymbolicDist.ContinuousShape.make(_pdf, cdf)
}); });
switch(shape){ switch(shape){
| Some(s) => Ok(`Simple(`ContinuousShape(s))) | Some(s) => Ok(`DistData(`Symbolic(`ContinuousShape(s))))
| None => Error("Rendering did not work") | None => Error("Rendering did not work")
} }
} }
let rec functionParser = (r): result(SymbolicDist.distTree, string) => let rec functionParser = (r): result(TreeNode.treeNode, string) =>
r r
|> ( |> (
fun fun
@ -230,7 +222,7 @@ module MathAdtToDistDst = {
| Fn({name: "exponential", args}) => exponential(args) | Fn({name: "exponential", args}) => exponential(args)
| Fn({name: "cauchy", args}) => cauchy(args) | Fn({name: "cauchy", args}) => cauchy(args)
| Fn({name: "triangular", args}) => triangular(args) | Fn({name: "triangular", args}) => triangular(args)
| Value(f) => Ok(`Simple(`Float(f))) | Value(f) => Ok(`DistData(`Symbolic(`Float(f))))
| Fn({name: "mm", args}) => { | Fn({name: "mm", args}) => {
let weights = let weights =
args args
@ -283,7 +275,7 @@ module MathAdtToDistDst = {
args args
|> E.A.fmap(functionParser) |> E.A.fmap(functionParser)
|> (fun |> (fun
| [|Ok(l), Ok(`Simple(`Float(0.0)))|] => Error("Division by zero") | [|Ok(l), Ok(`DistData(`Symbolic(`Float(0.0))))|] => Error("Division by zero")
| [|Ok(l), Ok(r)|] => Ok(`Combination(l, r, `DivideOperation)) | [|Ok(l), Ok(r)|] => Ok(`Combination(l, r, `DivideOperation))
| _ => Error("Division needs two operands")) | _ => Error("Division needs two operands"))
} }
@ -298,14 +290,14 @@ module MathAdtToDistDst = {
args args
|> E.A.fmap(functionParser) |> E.A.fmap(functionParser)
|> (fun |> (fun
| [|Ok(l), Ok(`Simple(`Float(r)))|] => Ok(`LeftTruncate(l, r)) | [|Ok(l), Ok(`DistData(`Symbolic(`Float(r))))|] => Ok(`LeftTruncate(l, r))
| _ => Error("leftTruncate needs two arguments: the expression and the cutoff")) | _ => Error("leftTruncate needs two arguments: the expression and the cutoff"))
} }
| Fn({name: "rightTruncate", args}) => { | Fn({name: "rightTruncate", args}) => {
args args
|> E.A.fmap(functionParser) |> E.A.fmap(functionParser)
|> (fun |> (fun
| [|Ok(l), Ok(`Simple(`Float(r)))|] => Ok(`RightTruncate(l, r)) | [|Ok(l), Ok(`DistData(`Symbolic(`Float(r))))|] => Ok(`RightTruncate(l, r))
| _ => Error("rightTruncate needs two arguments: the expression and the cutoff")) | _ => Error("rightTruncate needs two arguments: the expression and the cutoff"))
} }
| Fn({name}) => Error(name ++ ": function not supported") | Fn({name}) => Error(name ++ ": function not supported")
@ -314,18 +306,18 @@ module MathAdtToDistDst = {
} }
); );
let topLevel = (r): result(SymbolicDist.distTree, string) => let topLevel = (r): result(TreeNode.treeNode, string) =>
r r
|> ( |> (
fun fun
| Fn(_) => functionParser(r) | Fn(_) => functionParser(r)
| Value(r) => Ok(`Simple(`Float(r))) | Value(r) => Ok(`DistData(`Symbolic(`Float(r))))
| Array(r) => arrayParser(r) | Array(r) => arrayParser(r)
| Symbol(_) => Error("Symbol not valid as top level") | Symbol(_) => Error("Symbol not valid as top level")
| Object(_) => Error("Object not valid as top level") | Object(_) => Error("Object not valid as top level")
); );
let run = (r): result(SymbolicDist.distTree, string) => let run = (r): result(TreeNode.treeNode, string) =>
r |> MathAdtCleaner.run |> topLevel; r |> MathAdtCleaner.run |> topLevel;
}; };

View File

@ -36,7 +36,6 @@ type continuousShape = {
cdf: DistTypes.continuousShape, cdf: DistTypes.continuousShape,
}; };
type contType = [ | `Continuous | `Discrete];
type dist = [ type dist = [
| `Normal(normal) | `Normal(normal)
@ -50,29 +49,6 @@ type dist = [
| `Float(float) // Dirac delta at x. Practically useful only in the context of multimodals. | `Float(float) // Dirac delta at x. Practically useful only in the context of multimodals.
]; ];
type integral = float;
type cutoffX = float;
type operation = [
| `AddOperation
| `SubtractOperation
| `MultiplyOperation
| `DivideOperation
| `ExponentiateOperation
];
type distTree = [
| `Simple(dist)
| `Combination(distTree, distTree, operation)
| `PointwiseSum(distTree, distTree)
| `PointwiseProduct(distTree, distTree)
| `VerticalScaling(distTree, distTree)
| `Normalize(distTree)
| `LeftTruncate(distTree, cutoffX)
| `RightTruncate(distTree, cutoffX)
| `Render(distTree)
]
and weightedDists = array((distTree, float));
module ContinuousShape = { module ContinuousShape = {
type t = continuousShape; type t = continuousShape;
let make = (pdf, cdf): t => {pdf, cdf}; let make = (pdf, cdf): t => {pdf, cdf};
@ -82,8 +58,9 @@ module ContinuousShape = {
Distributions.Continuous.T.xToY(p, t.pdf).continuous; Distributions.Continuous.T.xToY(p, t.pdf).continuous;
// TODO: Fix the sampling, to have it work correctly. // TODO: Fix the sampling, to have it work correctly.
let sample = (t: t) => 3.0; let sample = (t: t) => 3.0;
// TODO: Fix the mean, to have it work correctly.
let mean = (t: t) => Ok(0.0);
let toString = t => {j|CustomContinuousShape|j}; let toString = t => {j|CustomContinuousShape|j};
let contType: contType = `Continuous;
}; };
module Exponential = { module Exponential = {
@ -91,8 +68,8 @@ module Exponential = {
let pdf = (x, t: t) => Jstat.exponential##pdf(x, t.rate); let pdf = (x, t: t) => Jstat.exponential##pdf(x, t.rate);
let inv = (p, t: t) => Jstat.exponential##inv(p, t.rate); let inv = (p, t: t) => Jstat.exponential##inv(p, t.rate);
let sample = (t: t) => Jstat.exponential##sample(t.rate); let sample = (t: t) => Jstat.exponential##sample(t.rate);
let mean = (t: t) => Ok(Jstat.exponential##mean(t.rate));
let toString = ({rate}: t) => {j|Exponential($rate)|j}; let toString = ({rate}: t) => {j|Exponential($rate)|j};
let contType: contType = `Continuous;
}; };
module Cauchy = { module Cauchy = {
@ -100,8 +77,8 @@ module Cauchy = {
let pdf = (x, t: t) => Jstat.cauchy##pdf(x, t.local, t.scale); let pdf = (x, t: t) => Jstat.cauchy##pdf(x, t.local, t.scale);
let inv = (p, t: t) => Jstat.cauchy##inv(p, t.local, t.scale); let inv = (p, t: t) => Jstat.cauchy##inv(p, t.local, t.scale);
let sample = (t: t) => Jstat.cauchy##sample(t.local, t.scale); let sample = (t: t) => Jstat.cauchy##sample(t.local, t.scale);
let mean = (t: t) => Error("Cauchy distributions have no mean value.")
let toString = ({local, scale}: t) => {j|Cauchy($local, $scale)|j}; let toString = ({local, scale}: t) => {j|Cauchy($local, $scale)|j};
let contType: contType = `Continuous;
}; };
module Triangular = { module Triangular = {
@ -109,8 +86,8 @@ module Triangular = {
let pdf = (x, t: t) => Jstat.triangular##pdf(x, t.low, t.high, t.medium); let pdf = (x, t: t) => Jstat.triangular##pdf(x, t.low, t.high, t.medium);
let inv = (p, t: t) => Jstat.triangular##inv(p, t.low, t.high, t.medium); let inv = (p, t: t) => Jstat.triangular##inv(p, t.low, t.high, t.medium);
let sample = (t: t) => Jstat.triangular##sample(t.low, t.high, t.medium); let sample = (t: t) => Jstat.triangular##sample(t.low, t.high, t.medium);
let mean = (t: t) => Ok(Jstat.triangular##mean(t.low, t.high, t.medium));
let toString = ({low, medium, high}: t) => {j|Triangular($low, $medium, $high)|j}; let toString = ({low, medium, high}: t) => {j|Triangular($low, $medium, $high)|j};
let contType: contType = `Continuous;
}; };
module Normal = { module Normal = {
@ -124,8 +101,26 @@ module Normal = {
}; };
let inv = (p, t: t) => Jstat.normal##inv(p, t.mean, t.stdev); let inv = (p, t: t) => Jstat.normal##inv(p, t.mean, t.stdev);
let sample = (t: t) => Jstat.normal##sample(t.mean, t.stdev); let sample = (t: t) => Jstat.normal##sample(t.mean, t.stdev);
let mean = (t: t) => Ok(Jstat.normal##mean(t.mean, t.stdev));
let toString = ({mean, stdev}: t) => {j|Normal($mean,$stdev)|j}; let toString = ({mean, stdev}: t) => {j|Normal($mean,$stdev)|j};
let contType: contType = `Continuous;
let add = (n1: t, n2: t) => {
let mean = n1.mean +. n2.mean;
let stdev = sqrt(n1.stdev ** 2. +. n2.stdev ** 2.);
`Normal({mean, stdev});
};
let subtract = (n1: t, n2: t) => {
let mean = n1.mean -. n2.mean;
let stdev = sqrt(n1.stdev ** 2. +. n2.stdev ** 2.);
`Normal({mean, stdev});
};
// TODO: is this useful here at all? would need the integral as well ...
let pointwiseProduct = (n1: t, n2: t) => {
let mean = (n1.mean *. n2.stdev**2. +. n2.mean *. n1.stdev**2.) /. (n1.stdev**2. +. n2.stdev**2.);
let stdev = 1. /. ((1. /. n1.stdev**2.) +. (1. /. n2.stdev**2.));
`Normal({mean, stdev});
};
}; };
module Beta = { module Beta = {
@ -133,17 +128,17 @@ module Beta = {
let pdf = (x, t: t) => Jstat.beta##pdf(x, t.alpha, t.beta); let pdf = (x, t: t) => Jstat.beta##pdf(x, t.alpha, t.beta);
let inv = (p, t: t) => Jstat.beta##inv(p, t.alpha, t.beta); let inv = (p, t: t) => Jstat.beta##inv(p, t.alpha, t.beta);
let sample = (t: t) => Jstat.beta##sample(t.alpha, t.beta); let sample = (t: t) => Jstat.beta##sample(t.alpha, t.beta);
let mean = (t: t) => Ok(Jstat.beta##mean(t.alpha, t.beta));
let toString = ({alpha, beta}: t) => {j|Beta($alpha,$beta)|j}; let toString = ({alpha, beta}: t) => {j|Beta($alpha,$beta)|j};
let contType: contType = `Continuous;
}; };
module Lognormal = { module Lognormal = {
type t = lognormal; type t = lognormal;
let pdf = (x, t: t) => Jstat.lognormal##pdf(x, t.mu, t.sigma); let pdf = (x, t: t) => Jstat.lognormal##pdf(x, t.mu, t.sigma);
let inv = (p, t: t) => Jstat.lognormal##inv(p, t.mu, t.sigma); let inv = (p, t: t) => Jstat.lognormal##inv(p, t.mu, t.sigma);
let mean = (t: t) => Ok(Jstat.lognormal##mean(t.mu, t.sigma));
let sample = (t: t) => Jstat.lognormal##sample(t.mu, t.sigma); let sample = (t: t) => Jstat.lognormal##sample(t.mu, t.sigma);
let toString = ({mu, sigma}: t) => {j|Lognormal($mu,$sigma)|j}; let toString = ({mu, sigma}: t) => {j|Lognormal($mu,$sigma)|j};
let contType: contType = `Continuous;
let from90PercentCI = (low, high) => { let from90PercentCI = (low, high) => {
let logLow = Js.Math.log(low); let logLow = Js.Math.log(low);
let logHigh = Js.Math.log(high); let logHigh = Js.Math.log(high);
@ -163,6 +158,17 @@ module Lognormal = {
); );
`Lognormal({mu, sigma}); `Lognormal({mu, sigma});
}; };
let multiply = (l1, l2) => {
let mu = l1.mu +. l2.mu;
let sigma = l1.sigma +. l2.sigma;
`Lognormal({mu, sigma})
};
let divide = (l1, l2) => {
let mu = l1.mu -. l2.mu;
let sigma = l1.sigma +. l2.sigma;
`Lognormal({mu, sigma})
};
}; };
module Uniform = { module Uniform = {
@ -170,20 +176,20 @@ module Uniform = {
let pdf = (x, t: t) => Jstat.uniform##pdf(x, t.low, t.high); let pdf = (x, t: t) => Jstat.uniform##pdf(x, t.low, t.high);
let inv = (p, t: t) => Jstat.uniform##inv(p, t.low, t.high); let inv = (p, t: t) => Jstat.uniform##inv(p, t.low, t.high);
let sample = (t: t) => Jstat.uniform##sample(t.low, t.high); let sample = (t: t) => Jstat.uniform##sample(t.low, t.high);
let mean = (t: t) => Ok(Jstat.uniform##mean(t.low, t.high));
let toString = ({low, high}: t) => {j|Uniform($low,$high)|j}; let toString = ({low, high}: t) => {j|Uniform($low,$high)|j};
let contType: contType = `Continuous;
}; };
module Float = { module Float = {
type t = float; type t = float;
let pdf = (x, t: t) => x == t ? 1.0 : 0.0; let pdf = (x, t: t) => x == t ? 1.0 : 0.0;
let inv = (p, t: t) => p < t ? 0.0 : 1.0; let inv = (p, t: t) => p < t ? 0.0 : 1.0;
let mean = (t: t) => Ok(t);
let sample = (t: t) => t; let sample = (t: t) => t;
let toString = Js.Float.toString; let toString = Js.Float.toString;
let contType: contType = `Discrete;
}; };
module GenericSimple = { module GenericDistFunctions = {
let minCdfValue = 0.0001; let minCdfValue = 0.0001;
let maxCdfValue = 0.9999; let maxCdfValue = 0.9999;
@ -200,19 +206,6 @@ module GenericSimple = {
| `ContinuousShape(n) => ContinuousShape.pdf(x, n) | `ContinuousShape(n) => ContinuousShape.pdf(x, n)
}; };
let contType = (dist: dist): contType =>
switch (dist) {
| `Normal(_) => Normal.contType
| `Triangular(_) => Triangular.contType
| `Exponential(_) => Exponential.contType
| `Cauchy(_) => Cauchy.contType
| `Lognormal(_) => Lognormal.contType
| `Uniform(_) => Uniform.contType
| `Beta(_) => Beta.contType
| `Float(_) => Float.contType
| `ContinuousShape(_) => ContinuousShape.contType
};
let inv = (x, dist) => let inv = (x, dist) =>
switch (dist) { switch (dist) {
| `Normal(n) => Normal.inv(x, n) | `Normal(n) => Normal.inv(x, n)
@ -274,15 +267,18 @@ module GenericSimple = {
| `Uniform({high}) => high | `Uniform({high}) => high
| `Float(n) => n; | `Float(n) => n;
/* This function returns a list of x's at which to evaluate the overall distribution (for rendering). let mean: dist => result(float, string) =
This function is called separately for each individual distribution. fun
| `Triangular(n) => Triangular.mean(n)
| `Exponential(n) => Exponential.mean(n)
| `Cauchy(n) => Cauchy.mean(n)
| `Normal(n) => Normal.mean(n)
| `Lognormal(n) => Lognormal.mean(n)
| `Beta(n) => Beta.mean(n)
| `ContinuousShape(n) => ContinuousShape.mean(n)
| `Uniform(n) => Uniform.mean(n)
| `Float(n) => Float.mean(n)
When called with xSelection=`Linear, this function will return (n) x's, evenly
distributed between the min and max of the distribution (whatever those are defined to be above).
When called with xSelection=`ByWeight, this function will distribute the x's such as to
match the cumulative shape of the distribution. This is slower but may give better results.
*/
let interpolateXs = let interpolateXs =
(~xSelection: [ | `Linear | `ByWeight]=`Linear, dist: dist, n) => { (~xSelection: [ | `Linear | `ByWeight]=`Linear, dist: dist, n) => {
switch (xSelection, dist) { switch (xSelection, dist) {
@ -297,423 +293,5 @@ module GenericSimple = {
ys |> E.A.fmap(y => inv(y, dist)); ys |> E.A.fmap(y => inv(y, dist));
}; };
}; };
let toShape =
(~xSelection: [ | `Linear | `ByWeight]=`Linear, dist: dist, n)
: DistTypes.shape => {
switch (dist) {
| `ContinuousShape(n) => n.pdf |> Distributions.Continuous.T.toShape
| dist =>
let xs = interpolateXs(~xSelection, dist, n);
let ys = xs |> E.A.fmap(r => pdf(r, dist));
XYShape.T.fromArrays(xs, ys)
|> Distributions.Continuous.make(`Linear, _)
|> Distributions.Continuous.T.toShape;
};
};
}; };
module DistTree = {
type nodeResult = [
| `Simple(dist)
// RenderedShape: continuous xyShape, discrete xyShape, total value.
| `RenderedShape(DistTypes.continuousShape, DistTypes.discreteShape, integral)
];
let evaluateDistribution = (d: dist): nodeResult => {
`Simple(d)
};
// This is a performance bottleneck!
// Using raw JS here so we can use native for loops and access array elements
// directly, without option checks.
let jsContinuousCombinationConvolve: (array(float), array(float), array(float), array(float), float => float => float) => array(array((float, float))) = [%bs.raw
{|
function (s1xs, s1ys, s2xs, s2ys, func) {
// For continuous-continuous convolution, use linear interpolation.
// Let's assume we got downsampled distributions
const outXYShapes = new Array(s1xs.length);
for (let i = 0; i < s1xs.length; i++) {
// create a new distribution
const dxyShape = new Array(s2xs.length);
for (let j = 0; j < s2xs.length; j++) {
dxyShape[j] = [func(s1xs[i], s2xs[j]), (s1ys[i] * s2ys[j])];
}
outXYShapes[i] = dxyShape;
}
return outXYShapes;
}
|}];
let jsDiscreteCombinationConvolve: (array(float), array(float), array(float), array(float), float => float => float) => (array(float), array(float)) = [%bs.raw
{|
function (s1xs, s1ys, s2xs, s2ys, func) {
const r = new Map();
for (let i = 0; i < s1xs.length; i++) {
for (let j = 0; j < s2xs.length; j++) {
const x = func(s1xs[i], s2xs[j]);
const cv = r.get(x) | 0;
r.set(x, cv + s1ys[i] * s2ys[j]); // add up the ys, if same x
}
}
const rxys = [...r.entries()];
rxys.sort(([x1, y1], [x2, y2]) => x1 - x2);
const rxs = new Array(rxys.length);
const rys = new Array(rxys.length);
for (let i = 0; i < rxys.length; i++) {
rxs[i] = rxys[i][0];
rys[i] = rxys[i][1];
}
return [rxs, rys];
}
|}];
let funcFromOp = (op: operation) => {
switch (op) {
| `AddOperation => (+.)
| `SubtractOperation => (-.)
| `MultiplyOperation => (*.)
| `DivideOperation => (/.)
| `ExponentiateOperation => (**)
}
}
let renderDistributionToXYShape = (d: dist, n: int): (DistTypes.continuousShape, DistTypes.discreteShape) => {
// render the distribution into an XY shape
switch (d) {
| `Float(v) => (Distributions.Continuous.empty, {xs: [|v|], ys: [|1.0|]})
| _ => {
let xs = GenericSimple.interpolateXs(~xSelection=`ByWeight, d, n);
let ys = xs |> E.A.fmap(x => GenericSimple.pdf(x, d));
(Distributions.Continuous.make(`Linear, {xs: xs, ys: ys}), XYShape.T.empty)
}
}
};
let combinationDistributionOfXYShapes = (sc1: DistTypes.continuousShape, // continuous shape
sd1: DistTypes.discreteShape, // discrete shape
sc2: DistTypes.continuousShape,
sd2: DistTypes.discreteShape, func): (DistTypes.continuousShape, DistTypes.discreteShape) => {
// First, deal with the discrete-discrete convolution:
let (ddxs, ddys) = jsDiscreteCombinationConvolve(sd1.xs, sd1.ys, sd2.xs, sd2.ys, func);
let ddxy: DistTypes.discreteShape = {xs: ddxs, ys: ddys};
// Then, do the other three:
let downsample = (sc: DistTypes.continuousShape) => {
let scLength = E.A.length(sc.xyShape.xs);
let scSqLength = sqrt(float_of_int(scLength));
scSqLength > 10. ? Distributions.Continuous.T.truncate(int_of_float(scSqLength), sc) : sc;
};
let combinePointConvolutionResults = ca => ca |> E.A.fmap(s => {
// s is an array of (x, y) objects
let (xs, ys) = Belt.Array.unzip(s);
Distributions.Continuous.make(`Linear, {xs, ys});
})
|> Distributions.Continuous.reduce((+.));
let sc1d = downsample(sc1);
let sc2d = downsample(sc2);
let ccxy = jsContinuousCombinationConvolve(sc1d.xyShape.xs, sc1d.xyShape.ys, sc2d.xyShape.xs, sc2d.xyShape.ys, func) |> combinePointConvolutionResults;
let dcxy = jsContinuousCombinationConvolve(sc1d.xyShape.xs, sc1d.xyShape.ys, sc2d.xyShape.xs, sc2d.xyShape.ys, func) |> combinePointConvolutionResults;
let cdxy = jsContinuousCombinationConvolve(sc1d.xyShape.xs, sc1d.xyShape.ys, sc2d.xyShape.xs, sc2d.xyShape.ys, func) |> combinePointConvolutionResults;
let continuousShapeSum = Distributions.Continuous.reduce((+.), [|ccxy, dcxy, cdxy|]);
(continuousShapeSum, ddxy)
};
let evaluateCombinationDistribution = (et1: nodeResult, et2: nodeResult, op: operation, n: int) => {
/* return either a Distribution or a RenderedShape. Must integrate to 1. */
let func = funcFromOp(op);
switch ((et1, et2, op)) {
/* Known cases: replace symbolic with symbolic distribution */
| (`Simple(`Float(v1)), `Simple(`Float(v2)), _) => {
`Simple(`Float(func(v1, v2)))
}
| (`Simple(`Normal(n2)), `Simple(`Float(v1)), `AddOperation)
| (`Simple(`Float(v1)), `Simple(`Normal(n2)), `AddOperation) => {
let n: normal = {mean: v1 +. n2.mean, stdev: n2.stdev};
`Simple(`Normal(n))
}
| (`Simple(`Normal(n1)), `Simple(`Normal(n2)), `AddOperation) => {
let n: normal = {mean: n1.mean +. n2.mean, stdev: sqrt(n1.stdev ** 2. +. n2.stdev ** 2.)};
`Simple(`Normal(n));
}
| (`Simple(`Normal(n1)), `Simple(`Normal(n2)), `SubtractOperation) => {
let n: normal = {mean: n1.mean -. n2.mean, stdev: sqrt(n1.stdev ** 2. +. n2.stdev ** 2.)};
`Simple(`Normal(n));
}
| (`Simple(`Lognormal(l1)), `Simple(`Lognormal(l2)), `MultiplyOperation) => {
let l: lognormal = {mu: l1.mu +. l2.mu, sigma: l1.sigma +. l2.sigma};
`Simple(`Lognormal(l));
}
| (`Simple(`Lognormal(l1)), `Simple(`Lognormal(l2)), `DivideOperation) => {
let l: lognormal = {mu: l1.mu -. l2.mu, sigma: l1.sigma +. l2.sigma};
`Simple(`Lognormal(l));
}
/* General cases: convolve the XYShapes */
| (`Simple(d1), `Simple(d2), _) => {
let (sc1, sd1) = renderDistributionToXYShape(d1, n);
let (sc2, sd2) = renderDistributionToXYShape(d2, n);
let (sc, sd) = combinationDistributionOfXYShapes(sc1, sd1, sc2, sd2, func);
`RenderedShape(sc, sd, 1.0)
}
| (`Simple(d1), `RenderedShape(sc2, sd2, i2), _)
| (`RenderedShape(sc2, sd2, i2), `Simple(d1), _) => {
let (sc1, sd1) = renderDistributionToXYShape(d1, n);
let (sc, sd) = combinationDistributionOfXYShapes(sc1, sd1, sc2, sd2, func);
`RenderedShape(sc, sd, i2)
}
| (`RenderedShape(sc1, sd1, i1), `RenderedShape(sc2, sd2, i2), _) => {
// sum of two multimodals that have a continuous and discrete each.
let (sc, sd) = combinationDistributionOfXYShapes(sc1, sd1, sc2, sd2, func);
`RenderedShape(sc, sd, i1);
}
}
};
let evaluatePointwiseSum = (et1: nodeResult, et2: nodeResult, n: int) => {
switch ((et1, et2)) {
/* Known cases: */
| (`Simple(`Float(v1)), `Simple(`Float(v2))) => {
v1 == v2
? `RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.make({xs: [|v1|], ys: [|2.|]}), 2.)
: `RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.) // TODO: add warning: shouldn't pointwise add scalars.
}
| (`Simple(`Float(v1)), `Simple(d2))
| (`Simple(d2), `Simple(`Float(v1))) => {
let sd1: DistTypes.xyShape = {xs: [|v1|], ys: [|1.|]};
let (sc2, sd2) = renderDistributionToXYShape(d2, n);
`RenderedShape(sc2, Distributions.Discrete.reduce((+.), [|sd1, sd2|]), 2.)
}
| (`Simple(d1), `Simple(d2)) => {
let (sc1, sd1) = renderDistributionToXYShape(d1, n);
let (sc2, sd2) = renderDistributionToXYShape(d2, n);
`RenderedShape(Distributions.Continuous.reduce((+.), [|sc1, sc2|]), Distributions.Discrete.reduce((+.), [|sd1, sd2|]), 2.)
}
| (`Simple(d1), `RenderedShape(sc2, sd2, i2))
| (`RenderedShape(sc2, sd2, i2), `Simple(d1)) => {
let (sc1, sd1) = renderDistributionToXYShape(d1, n);
`RenderedShape(Distributions.Continuous.reduce((+.), [|sc1, sc2|]), Distributions.Discrete.reduce((+.), [|sd1, sd2|]), 1. +. i2)
}
| (`RenderedShape(sc1, sd1, i1), `RenderedShape(sc2, sd2, i2)) => {
`RenderedShape(Distributions.Continuous.reduce((+.), [|sc1, sc2|]), Distributions.Discrete.reduce((+.), [|sd1, sd2|]), i1 +. i2)
}
}
};
let evaluatePointwiseProduct = (et1: nodeResult, et2: nodeResult, n: int) => {
switch ((et1, et2)) {
/* Known cases: */
| (`Simple(`Float(v1)), `Simple(`Float(v2))) => {
v1 == v2
? `RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.make({xs: [|v1|], ys: [|1.|]}), 1.)
: `RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.) // TODO: add warning: shouldn't pointwise multiply scalars.
}
| (`Simple(`Float(v1)), `Simple(d2)) => {
// evaluate d2 at v1
let y = GenericSimple.pdf(v1, d2);
`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.make({xs: [|v1|], ys: [|y|]}), y)
}
| (`Simple(d1), `Simple(`Float(v2))) => {
// evaluate d1 at v2
let y = GenericSimple.pdf(v2, d1);
`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.make({xs: [|v2|], ys: [|y|]}), y)
}
| (`Simple(`Normal(n1)), `Simple(`Normal(n2))) => {
let mean = (n1.mean *. n2.stdev**2. +. n2.mean *. n1.stdev**2.) /. (n1.stdev**2. +. n2.stdev**2.);
let stdev = 1. /. ((1. /. n1.stdev**2.) +. (1. /. n2.stdev**2.));
let integral = 0; // TODO
`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.)
}
/* General cases */
| (`Simple(d1), `Simple(d2)) => {
// NOT IMPLEMENTED YET
// TODO: evaluate integral properly
let (sc1, sd1) = renderDistributionToXYShape(d1, n);
let (sc2, sd2) = renderDistributionToXYShape(d2, n);
`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.)
}
| (`Simple(d1), `RenderedShape(sc2, sd2, i2)) => {
// NOT IMPLEMENTED YET
// TODO: evaluate integral properly
let (sc1, sd1) = renderDistributionToXYShape(d1, n);
`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.)
}
| (`RenderedShape(sc1, sd1, i1), `Simple(d1)) => {
// NOT IMPLEMENTED YET
// TODO: evaluate integral properly
let (sc2, sd2) = renderDistributionToXYShape(d1, n);
`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.)
}
| (`RenderedShape(sc1, sd1, i1), `RenderedShape(sc2, sd2, i2)) => {
`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.)
}
}
};
let evaluateNormalize = (et: nodeResult, n: int) => {
// just divide everything by the integral.
switch (et) {
| `RenderedShape(sc, sd, 0.) => {
`RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.)
}
| `RenderedShape(sc, sd, i) => {
// loop through all ys and divide them by i
let normalize = (s: DistTypes.xyShape): DistTypes.xyShape => {xs: s.xs, ys: s.ys |> E.A.fmap(y => y /. i)};
let scn = sc |> Distributions.Continuous.shapeMap(normalize);
let sdn = sd |> normalize;
`RenderedShape(scn, sdn, 1.)
}
| `Simple(d) => `Simple(d) // any kind of atomic dist should already be normalized -- TODO: THIS IS ACTUALLY FALSE! E.g. pointwise product of normal * normal
}
};
let evaluateTruncate = (et: nodeResult, xc: cutoffX, compareFunc: (float, float) => bool, n: int) => {
let cut = (s: DistTypes.xyShape): DistTypes.xyShape => {
let (xs, ys) = s.ys
|> Belt.Array.zip(s.xs)
|> E.A.filter(((x, y)) => compareFunc(x, xc))
|> Belt.Array.unzip
let cutShape: DistTypes.xyShape = {xs, ys};
cutShape;
};
switch (et) {
| `Simple(d) => {
let (sc, sd) = renderDistributionToXYShape(d, n);
let scc = sc |> Distributions.Continuous.shapeMap(cut);
let sdc = sd |> cut;
let newIntegral = 1.; // TODO
`RenderedShape(scc, sdc, newIntegral);
}
| `RenderedShape(sc, sd, i) => {
let scc = sc |> Distributions.Continuous.shapeMap(cut);
let sdc = sd |> cut;
let newIntegral = 1.; // TODO
`RenderedShape(scc, sdc, newIntegral);
}
}
};
let evaluateVerticalScaling = (et1: nodeResult, et2: nodeResult, n: int) => {
let scale = (i: float, s: DistTypes.xyShape): DistTypes.xyShape => {xs: s.xs, ys: s.ys |> E.A.fmap(y => y *. i)};
switch ((et1, et2)) {
| (`Simple(`Float(v)), `Simple(d))
| (`Simple(d), `Simple(`Float(v))) => {
let (sc, sd) = renderDistributionToXYShape(d, n);
let scc = sc |> Distributions.Continuous.shapeMap(scale(v));
let sdc = sd |> scale(v);
let newIntegral = v; // TODO
`RenderedShape(scc, sdc, newIntegral);
}
| (`Simple(`Float(v)), `RenderedShape(sc, sd, i))
| (`RenderedShape(sc, sd, i), `Simple(`Float(v))) => {
let scc = sc |> Distributions.Continuous.shapeMap(scale(v));
let sdc = sd |> scale(v);
let newIntegral = v; // TODO
`RenderedShape(scc, sdc, newIntegral);
}
| _ => `RenderedShape(Distributions.Continuous.empty, Distributions.Discrete.empty, 0.) // TODO: give warning
}
}
let renderNode = (et: nodeResult, n: int) => {
switch (et) {
| `Simple(d) => {
let (sc, sd) = renderDistributionToXYShape(d, n);
`RenderedShape(sc, sd, 1.0);
}
| s => s
}
}
let rec evaluateNode = (treeNode: distTree, n: int): nodeResult => {
// returns either a new symbolic distribution
switch (treeNode) {
| `Simple(d) => evaluateDistribution(d)
| `Combination(t1, t2, op) => evaluateCombinationDistribution(evaluateNode(t1, n), evaluateNode(t2, n), op, n)
| `PointwiseSum(t1, t2) => evaluatePointwiseSum(evaluateNode(t1, n), evaluateNode(t2, n), n)
| `PointwiseProduct(t1, t2) => evaluatePointwiseProduct(evaluateNode(t1, n), evaluateNode(t2, n), n)
| `VerticalScaling(t1, t2) => evaluateVerticalScaling(evaluateNode(t1, n), evaluateNode(t2, n), n)
| `Normalize(t) => evaluateNormalize(evaluateNode(t, n), n)
| `LeftTruncate(t, x) => evaluateTruncate(evaluateNode(t, n), x, (>=), n)
| `RightTruncate(t, x) => evaluateTruncate(evaluateNode(t, n), x, (<=), n)
| `Render(t) => renderNode(evaluateNode(t, n), n)
}
};
let toShape = (treeNode: distTree, n: int) => {
let treeShape = evaluateNode(`Render(`Normalize(treeNode)), n);
switch (treeShape) {
| `Simple(_) => E.O.toExn("No shape found!", None)
| `RenderedShape(sc, sd, _) => {
let shape = MixedShapeBuilder.buildSimple(~continuous=Some(sc), ~discrete=sd);
shape |> E.O.toExt("");
}
}
};
let rec toString = (treeNode: distTree): string => {
let stringFromOp = op => switch (op) {
| `AddOperation => " + "
| `SubtractOperation => " - "
| `MultiplyOperation => " * "
| `DivideOperation => " / "
| `ExponentiateOperation => "^"
};
switch (treeNode) {
| `Simple(d) => GenericSimple.toString(d)
| `Combination(t1, t2, op) => toString(t1) ++ stringFromOp(op) ++ toString(t2)
| `PointwiseSum(t1, t2) => toString(t1) ++ " .+ " ++ toString(t2)
| `PointwiseProduct(t1, t2) => toString(t1) ++ " .* " ++ toString(t2)
| `VerticalScaling(t1, t2) => toString(t1) ++ " @ " ++ toString(t2)
| `Normalize(t) => "normalize(" ++ toString(t) ++ ")"
| `LeftTruncate(t, x) => "leftTruncate(" ++ toString(t) ++ ", " ++ string_of_float(x) ++ ")"
| `RightTruncate(t, x) => "rightTruncate(" ++ toString(t) ++ ", " ++ string_of_float(x) ++ ")"
| `Render(t) => toString(t)
}
};
};
let toString = (treeNode: distTree) => DistTree.toString(treeNode)
let toShape = (sampleCount: int, treeNode: distTree) =>
DistTree.toShape(treeNode, sampleCount) //~xSelection=`ByWeight,

View File

@ -0,0 +1,414 @@
/* This module represents a tree node. */
/* TreeNodes are either Data (i.e. symbolic or rendered distributions) or Operations. */
type treeNode = [
| `DistData(distData)
| `Operation(operation)
] and distData = [
| `Symbolic(SymbolicDist.dist)
| `RenderedShape(DistTypes.shape)
] and operation = [
// binary operations
| `StandardOperation(standardOperation, treeNode, treeNode)
| `PointwiseOperation(pointwiseOperation, treeNode, treeNode)
| `ScaleOperation(scaleOperation, treeNode, scaleBy)
// unary operations
| `Render(treeNode) // always evaluates to `DistData(`RenderedShape(...))
| `Truncate(leftCutoff, rightCutoff, treeNode)
| `Normalize(treeNode)
// direct evaluations of dists (e.g. cdf, sample)
| `FloatFromDist(distToFloatOperation, treeNode)
] and standardOperation = [
| `Add
| `Multiply
| `Subtract
| `Divide
| `Exponentiate
] and pointwiseOperation = [
| `Add
| `Multiply
] and scaleOperation = [
| `Multiply
| `Log
]
and scaleBy = treeNode and leftCutoff = option(float) and rightCutoff = option(float)
and distToFloatOperation = [
| `Pdf(float)
| `Cdf(float)
| `Inv(float)
| `Sample
];
module TreeNode = {
type t = treeNode;
type simplifier = treeNode => result(treeNode, string);
type renderParams = {
operationToDistData: (int, operation) => result(t, string),
sampleCount: int,
}
let rec renderToShape = (renderParams, t: t): result(DistTypes.shape, string) => {
switch (t) {
| `DistData(`RenderedShape(s)) => Ok(s) // already a rendered shape, we're done here
| `DistData(`Symbolic(d)) =>
switch (d) {
| `Float(v) =>
Ok(Discrete(Distributions.Discrete.make({xs: [|v|], ys: [|1.0|]}, Some(1.0))));
| _ =>
let xs = SymbolicDist.GenericDistFunctions.interpolateXs(~xSelection=`ByWeight, d, renderParams.sampleCount);
let ys = xs |> E.A.fmap(x => SymbolicDist.GenericDistFunctions.pdf(x, d));
Ok(Continuous(Distributions.Continuous.make(`Linear, {xs, ys}, Some(1.0))));
}
| `Operation(op) => E.R.bind(renderParams.operationToDistData(renderParams.sampleCount, op), renderToShape(renderParams))
};
};
/* The following modules encapsulate everything we can do with
* different kinds of operations. */
/* Given two random variables A and B, this returns the distribution
of a new variable that is the result of the operation on A and B.
For instance, normal(0, 1) + normal(1, 1) -> normal(1, 2).
In general, this is implemented via convolution. */
module StandardOperation = {
let funcFromOp: (standardOperation, float, float) => float =
fun
| `Add => (+.)
| `Subtract => (-.)
| `Multiply => ( *. )
| `Divide => (/.)
| `Exponentiate => ( ** );
module Simplify = {
let tryCombiningFloats: simplifier =
fun
| `Operation(
`StandardOperation(
`Divide,
`DistData(`Symbolic(`Float(v1))),
`DistData(`Symbolic(`Float(0.))),
),
) =>
Error("Cannot divide $v1 by zero.")
| `Operation(
`StandardOperation(
standardOp,
`DistData(`Symbolic(`Float(v1))),
`DistData(`Symbolic(`Float(v2))),
),
) => {
let func = funcFromOp(standardOp);
Ok(`DistData(`Symbolic(`Float(func(v1, v2)))));
}
| t => Ok(t);
let tryCombiningNormals: simplifier =
fun
| `Operation(
`StandardOperation(
`Add,
`DistData(`Symbolic(`Normal(n1))),
`DistData(`Symbolic(`Normal(n2))),
),
) =>
Ok(`DistData(`Symbolic(SymbolicDist.Normal.add(n1, n2))))
| `Operation(
`StandardOperation(
`Subtract,
`DistData(`Symbolic(`Normal(n1))),
`DistData(`Symbolic(`Normal(n2))),
),
) =>
Ok(`DistData(`Symbolic(SymbolicDist.Normal.subtract(n1, n2))))
| t => Ok(t);
let tryCombiningLognormals: simplifier =
fun
| `Operation(
`StandardOperation(
`Multiply,
`DistData(`Symbolic(`Lognormal(l1))),
`DistData(`Symbolic(`Lognormal(l2))),
),
) =>
Ok(`DistData(`Symbolic(SymbolicDist.Lognormal.multiply(l1, l2))))
| `Operation(
`StandardOperation(
`Divide,
`DistData(`Symbolic(`Lognormal(l1))),
`DistData(`Symbolic(`Lognormal(l2))),
),
) =>
Ok(`DistData(`Symbolic(SymbolicDist.Lognormal.divide(l1, l2))))
| t => Ok(t);
let attempt = (standardOp, t1: t, t2: t): result(treeNode, string) => {
let originalTreeNode =
`Operation(`StandardOperation((standardOp, t1, t2)));
originalTreeNode
|> tryCombiningFloats
|> E.R.bind(_, tryCombiningNormals)
|> E.R.bind(_, tryCombiningLognormals);
};
};
let evaluateNumerically = (standardOp, renderParams, t1, t2) => {
let func = funcFromOp(standardOp);
// TODO: downsample the two shapes
let renderedShape1 = t1 |> renderToShape(renderParams);
let renderedShape2 = t2 |> renderToShape(renderParams);
// This will most likely require a mixed
switch ((renderedShape1, renderedShape2)) {
| (Error(e1), _) => Error(e1)
| (_, Error(e2)) => Error(e2)
| (Ok(s1), Ok(s2)) => Ok(`DistData(`RenderedShape(Distributions.Shape.convolve(func, s1, s2))))
};
};
let evaluateToDistData =
(standardOp: standardOperation, renderParams, t1: t, t2: t): result(treeNode, string) =>
standardOp
|> Simplify.attempt(_, t1, t2)
|> E.R.bind(
_,
fun
| `DistData(d) => Ok(`DistData(d)) // the analytical simplifaction worked, nice!
| `Operation(_) => // if not, run the convolution
evaluateNumerically(standardOp, renderParams, t1, t2),
);
};
module ScaleOperation = {
let rec mean = (renderParams, t: t): result(float, string) => {
switch (t) {
| `DistData(`RenderedShape(s)) => Ok(Distributions.Shape.T.mean(s))
| `DistData(`Symbolic(s)) => SymbolicDist.GenericDistFunctions.mean(s)
// evaluating the operation returns result(treeNode(distData)). We then want to make sure
| `Operation(op) => E.R.bind(renderParams.operationToDistData(renderParams.sampleCount, op), mean(renderParams))
}
};
let fnFromOp =
fun
| `Multiply => (*.)
| `Log => ((a, b) => ( log(a) /. log(b) ));
let knownIntegralSumFnFromOp =
fun
| `Multiply => (a, b) => Some(a *. b)
| `Log => ((_, _) => None);
let evaluateToDistData = (scaleOp, renderParams, t, scaleBy) => {
let fn = fnFromOp(scaleOp);
let knownIntegralSumFn = knownIntegralSumFnFromOp(scaleOp);
let renderedShape = t |> renderToShape(renderParams);
let scaleByMeanValue = mean(renderParams, scaleBy);
switch ((renderedShape, scaleByMeanValue)) {
| (Error(e1), _) => Error(e1)
| (_, Error(e2)) => Error(e2)
| (Ok(rs), Ok(sm)) =>
Ok(`DistData(`RenderedShape(Distributions.Shape.T.mapY(~knownIntegralSumFn=knownIntegralSumFn(sm), fn(sm), rs))))
}
};
};
module PointwiseOperation = {
let funcFromOp: (pointwiseOperation => ((float, float) => float)) =
fun
| `Add => (+.)
| `Multiply => ( *. );
let evaluateToDistData = (pointwiseOp, renderParams, t1, t2) => {
let func = funcFromOp(pointwiseOp);
let renderedShape1 = t1 |> renderToShape(renderParams);
let renderedShape2 = t2 |> renderToShape(renderParams);
// TODO: figure out integral, diff between pointwiseAdd and pointwiseProduct and other stuff
// Distributions.Shape.reduce(func, renderedShape1, renderedShape2);
Error("Pointwise operations currently not supported.")
};
};
module Truncate = {
module Simplify = {
let tryTruncatingNothing: simplifier = fun
| `Operation(`Truncate(None, None, `DistData(d))) => Ok(`DistData(d))
| t => Ok(t);
let tryTruncatingUniform: simplifier = fun
| `Operation(`Truncate(lc, rc, `DistData(`Symbolic(`Uniform(u))))) => {
// just create a new Uniform distribution
let newLow = max(E.O.default(neg_infinity, lc), u.low);
let newHigh = min(E.O.default(infinity, rc), u.high);
Ok(`DistData(`Symbolic(`Uniform({low: newLow, high: newHigh}))));
}
| t => Ok(t);
let attempt = (leftCutoff, rightCutoff, t): result(treeNode, string) => {
let originalTreeNode = `Operation(`Truncate(leftCutoff, rightCutoff, t));
originalTreeNode
|> tryTruncatingNothing
|> E.R.bind(_, tryTruncatingUniform);
};
};
let evaluateNumerically = (leftCutoff, rightCutoff, renderParams, t) => {
// TODO: use named args in renderToShape; if we're lucky we can at least get the tail
// of a distribution we otherwise wouldn't get at all
let renderedShape = t |> renderToShape(renderParams);
E.R.bind(renderedShape, rs => {
let truncatedShape = rs |> Distributions.Shape.truncate(leftCutoff, rightCutoff);
Ok(`DistData(`RenderedShape(rs)));
});
};
let evaluateToDistData = (leftCutoff: option(float), rightCutoff: option(float), renderParams, t: treeNode): result(treeNode, string) => {
t
|> Simplify.attempt(leftCutoff, rightCutoff)
|> E.R.bind(
_,
fun
| `DistData(d) => Ok(`DistData(d)) // the analytical simplifaction worked, nice!
| `Operation(_) => evaluateNumerically(leftCutoff, rightCutoff, renderParams, t),
); // if not, run the convolution
};
};
module Normalize = {
let rec evaluateToDistData = (renderParams, t: treeNode): result(treeNode, string) => {
switch (t) {
| `DistData(`Symbolic(_)) => Ok(t)
| `DistData(`RenderedShape(s)) => {
let normalized = Distributions.Shape.normalize(s);
Ok(`DistData(`RenderedShape(normalized)));
}
| `Operation(op) => E.R.bind(renderParams.operationToDistData(renderParams.sampleCount, op), evaluateToDistData(renderParams))
}
}
};
module FloatFromDist = {
let evaluateFromSymbolic = (distToFloatOp: distToFloatOperation, s) => {
let value = switch (distToFloatOp) {
| `Pdf(f) => SymbolicDist.GenericDistFunctions.pdf(f, s)
| `Cdf(f) => 0.0
| `Inv(f) => SymbolicDist.GenericDistFunctions.inv(f, s)
| `Sample => SymbolicDist.GenericDistFunctions.sample(s)
}
Ok(`DistData(`Symbolic(`Float(value))));
};
let evaluateFromRenderedShape = (distToFloatOp: distToFloatOperation, rs: DistTypes.shape): result(treeNode, string) => {
// evaluate the pdf, cdf, get sample, etc. from the renderedShape rs
// Should be a float like Ok(`DistData(`Symbolic(Float(0.0))));
Error("Float from dist is not yet implemented.");
};
let rec evaluateToDistData = (distToFloatOp: distToFloatOperation, renderParams, t: treeNode): result(treeNode, string) => {
switch (t) {
| `DistData(`Symbolic(s)) => evaluateFromSymbolic(distToFloatOp, s) // we want to evaluate the distToFloatOp on the symbolic dist
| `DistData(`RenderedShape(rs)) => evaluateFromRenderedShape(distToFloatOp, rs)
| `Operation(op) => E.R.bind(renderParams.operationToDistData(renderParams.sampleCount, op), evaluateToDistData(distToFloatOp, renderParams))
}
}
};
module Render = {
let evaluateToRenderedShape = (renderParams, t: treeNode): result(t, string) => {
E.R.bind(renderToShape(renderParams, t), rs => Ok(`DistData(`RenderedShape(rs))));
}
};
let rec operationToDistData =
(sampleCount: int, op: operation): result(t, string) => {
// the functions that convert the Operation nodes to DistData nodes need to
// have a way to call this function on their children, if their children are themselves Operation nodes.
let renderParams: renderParams = {
operationToDistData: operationToDistData,
sampleCount: sampleCount,
};
switch (op) {
| `StandardOperation(standardOp, t1, t2) =>
StandardOperation.evaluateToDistData(
standardOp, renderParams, t1, t2 // we want to give it the option to render or simply leave it as is
)
| `PointwiseOperation(pointwiseOp, t1, t2) =>
PointwiseOperation.evaluateToDistData(
pointwiseOp,
renderParams,
t1,
t2,
)
| `ScaleOperation(scaleOp, t, scaleBy) =>
ScaleOperation.evaluateToDistData(scaleOp, renderParams, t, scaleBy)
| `Truncate(leftCutoff, rightCutoff, t) => Truncate.evaluateToDistData(leftCutoff, rightCutoff, renderParams, t)
| `FloatFromDist(distToFloatOp, t) => FloatFromDist.evaluateToDistData(distToFloatOp, renderParams, t)
| `Normalize(t) => Normalize.evaluateToDistData(renderParams, t)
| `Render(t) => Render.evaluateToRenderedShape(renderParams, t)
};
};
/* This function recursively goes through the nodes of the parse tree,
replacing each Operation node and its subtree with a Data node.
Whenever possible, the replacement produces a new Symbolic Data node,
but most often it will produce a RenderedShape.
This function is used mainly to turn a parse tree into a single RenderedShape
that can then be displayed to the user. */
let rec toDistData = (treeNode: t, sampleCount: int): result(t, string) => {
switch (treeNode) {
| `DistData(d) => Ok(`DistData(d))
| `Operation(op) => operationToDistData(sampleCount, op)
};
};
let rec toString = (t: t): string => {
let stringFromStandardOperation = fun
| `Add => " + "
| `Subtract => " - "
| `Multiply => " * "
| `Divide => " / "
| `Exponentiate => "^";
let stringFromPointwiseOperation =
fun
| `Add => " .+ "
| `Multiply => " .* ";
switch (t) {
| `DistData(`Symbolic(d)) => SymbolicDist.GenericDistFunctions.toString(d)
| `DistData(`RenderedShape(s)) => "[shape]"
| `Operation(`StandardOperation(op, t1, t2)) => toString(t1) ++ stringFromStandardOperation(op) ++ toString(t2)
| `Operation(`PointwiseOperation(op, t1, t2)) => toString(t1) ++ stringFromPointwiseOperation(op) ++ toString(t2)
| `Operation(`ScaleOperation(_scaleOp, t, scaleBy)) => toString(t) ++ " @ " ++ toString(scaleBy)
| `Operation(`Normalize(t)) => "normalize(" ++ toString(t) ++ ")"
| `Operation(`Truncate(lc, rc, t)) => "truncate(" ++ toString(t) ++ ", " ++ E.O.dimap(string_of_float, () => "-inf", lc) ++ ", " ++ E.O.dimap(string_of_float, () => "inf", rc) ++ ")"
| `Operation(`Render(t)) => toString(t)
}
};
};
let toShape = (sampleCount: int, treeNode: treeNode) => {
let renderResult = TreeNode.toDistData(`Operation(`Render(treeNode)), sampleCount);
switch (renderResult) {
| Ok(`DistData(`RenderedShape(rs))) => {
let continuous = Distributions.Shape.T.toContinuous(rs);
let discrete = Distributions.Shape.T.toDiscrete(rs);
let shape = MixedShapeBuilder.buildSimple(~continuous, ~discrete);
shape |> E.O.toExt("");
}
| Ok(_) => E.O.toExn("Rendering failed.", None)
| Error(message) => E.O.toExn("No shape found!", None)
}
};

View File

@ -5,6 +5,7 @@ type normal = {
[@bs.meth] "cdf": (float, float, float) => float, [@bs.meth] "cdf": (float, float, float) => float,
[@bs.meth] "inv": (float, float, float) => float, [@bs.meth] "inv": (float, float, float) => float,
[@bs.meth] "sample": (float, float) => float, [@bs.meth] "sample": (float, float) => float,
[@bs.meth] "mean": (float, float) => float,
}; };
type lognormal = { type lognormal = {
. .
@ -12,6 +13,7 @@ type lognormal = {
[@bs.meth] "cdf": (float, float, float) => float, [@bs.meth] "cdf": (float, float, float) => float,
[@bs.meth] "inv": (float, float, float) => float, [@bs.meth] "inv": (float, float, float) => float,
[@bs.meth] "sample": (float, float) => float, [@bs.meth] "sample": (float, float) => float,
[@bs.meth] "mean": (float, float) => float,
}; };
type uniform = { type uniform = {
. .
@ -19,6 +21,7 @@ type uniform = {
[@bs.meth] "cdf": (float, float, float) => float, [@bs.meth] "cdf": (float, float, float) => float,
[@bs.meth] "inv": (float, float, float) => float, [@bs.meth] "inv": (float, float, float) => float,
[@bs.meth] "sample": (float, float) => float, [@bs.meth] "sample": (float, float) => float,
[@bs.meth] "mean": (float, float) => float,
}; };
type beta = { type beta = {
. .
@ -26,6 +29,7 @@ type beta = {
[@bs.meth] "cdf": (float, float, float) => float, [@bs.meth] "cdf": (float, float, float) => float,
[@bs.meth] "inv": (float, float, float) => float, [@bs.meth] "inv": (float, float, float) => float,
[@bs.meth] "sample": (float, float) => float, [@bs.meth] "sample": (float, float) => float,
[@bs.meth] "mean": (float, float) => float,
}; };
type exponential = { type exponential = {
. .
@ -33,6 +37,7 @@ type exponential = {
[@bs.meth] "cdf": (float, float) => float, [@bs.meth] "cdf": (float, float) => float,
[@bs.meth] "inv": (float, float) => float, [@bs.meth] "inv": (float, float) => float,
[@bs.meth] "sample": float => float, [@bs.meth] "sample": float => float,
[@bs.meth] "mean": float => float,
}; };
type cauchy = { type cauchy = {
. .
@ -47,6 +52,7 @@ type triangular = {
[@bs.meth] "cdf": (float, float, float, float) => float, [@bs.meth] "cdf": (float, float, float, float) => float,
[@bs.meth] "inv": (float, float, float, float) => float, [@bs.meth] "inv": (float, float, float, float) => float,
[@bs.meth] "sample": (float, float, float) => float, [@bs.meth] "sample": (float, float, float) => float,
[@bs.meth] "mean": (float, float, float) => float,
}; };
// Pareto doesn't have sample for some reason // Pareto doesn't have sample for some reason
@ -61,6 +67,7 @@ type poisson = {
[@bs.meth] "pdf": (float, float) => float, [@bs.meth] "pdf": (float, float) => float,
[@bs.meth] "cdf": (float, float) => float, [@bs.meth] "cdf": (float, float) => float,
[@bs.meth] "sample": float => float, [@bs.meth] "sample": float => float,
[@bs.meth] "mean": float => float,
}; };
type weibull = { type weibull = {
. .
@ -68,6 +75,7 @@ type weibull = {
[@bs.meth] "cdf": (float, float, float) => float, [@bs.meth] "cdf": (float, float, float) => float,
[@bs.meth] "inv": (float, float, float) => float, [@bs.meth] "inv": (float, float, float) => float,
[@bs.meth] "sample": (float, float) => float, [@bs.meth] "sample": (float, float) => float,
[@bs.meth] "mean": (float, float) => float,
}; };
type binomial = { type binomial = {
. .