squiggle/packages/playground/src/distPlus/distribution/Shape.re
2022-02-06 14:48:21 -05:00

241 lines
6.3 KiB
ReasonML

open Distributions;
type t = DistTypes.shape;
let mapToAll = ((fn1, fn2, fn3), t: t) =>
switch (t) {
| Mixed(m) => fn1(m)
| Discrete(m) => fn2(m)
| Continuous(m) => fn3(m)
};
let fmap = ((fn1, fn2, fn3), t: t): t =>
switch (t) {
| Mixed(m) => Mixed(fn1(m))
| Discrete(m) => Discrete(fn2(m))
| Continuous(m) => Continuous(fn3(m))
};
let toMixed =
mapToAll((
m => m,
d => Mixed.make(~integralSumCache=d.integralSumCache, ~integralCache=d.integralCache, ~discrete=d, ~continuous=Continuous.empty),
c => Mixed.make(~integralSumCache=c.integralSumCache, ~integralCache=c.integralCache, ~discrete=Discrete.empty, ~continuous=c),
));
let combineAlgebraically =
(op: ExpressionTypes.algebraicOperation, t1: t, t2: t): t => {
switch (t1, t2) {
| (Continuous(m1), Continuous(m2)) =>
Continuous.combineAlgebraically(op, m1, m2) |> Continuous.T.toShape;
| (Continuous(m1), Discrete(m2))
| (Discrete(m2), Continuous(m1)) =>
Continuous.combineAlgebraicallyWithDiscrete(op, m1, m2) |> Continuous.T.toShape
| (Discrete(m1), Discrete(m2)) =>
Discrete.combineAlgebraically(op, m1, m2) |> Discrete.T.toShape
| (m1, m2) =>
Mixed.combineAlgebraically(
op,
toMixed(m1),
toMixed(m2),
)
|> Mixed.T.toShape
};
};
let combinePointwise =
(~integralSumCachesFn: (float, float) => option(float) = (_, _) => None,
~integralCachesFn: (DistTypes.continuousShape, DistTypes.continuousShape) => option(DistTypes.continuousShape) = (_, _) => None,
fn,
t1: t,
t2: t) =>
switch (t1, t2) {
| (Continuous(m1), Continuous(m2)) =>
DistTypes.Continuous(
Continuous.combinePointwise(~integralSumCachesFn, ~integralCachesFn, fn, m1, m2),
)
| (Discrete(m1), Discrete(m2)) =>
DistTypes.Discrete(
Discrete.combinePointwise(~integralSumCachesFn, ~integralCachesFn, fn, m1, m2),
)
| (m1, m2) =>
DistTypes.Mixed(
Mixed.combinePointwise(
~integralSumCachesFn,
~integralCachesFn,
fn,
toMixed(m1),
toMixed(m2),
),
)
};
module T =
Dist({
type t = DistTypes.shape;
type integral = DistTypes.continuousShape;
let xToY = (f: float) =>
mapToAll((
Mixed.T.xToY(f),
Discrete.T.xToY(f),
Continuous.T.xToY(f),
));
let toShape = (t: t) => t;
let toContinuous = t => None;
let toDiscrete = t => None;
let downsample = (i, t) =>
fmap(
(
Mixed.T.downsample(i),
Discrete.T.downsample(i),
Continuous.T.downsample(i),
),
t,
);
let truncate = (leftCutoff, rightCutoff, t): t =>
fmap(
(
Mixed.T.truncate(leftCutoff, rightCutoff),
Discrete.T.truncate(leftCutoff, rightCutoff),
Continuous.T.truncate(leftCutoff, rightCutoff),
),
t,
);
let toDiscreteProbabilityMassFraction = t => 0.0;
let normalize =
fmap((
Mixed.T.normalize,
Discrete.T.normalize,
Continuous.T.normalize
));
let updateIntegralCache = (integralCache, t: t): t =>
fmap((
Mixed.T.updateIntegralCache(integralCache),
Discrete.T.updateIntegralCache(integralCache),
Continuous.T.updateIntegralCache(integralCache),
), t);
let toContinuous =
mapToAll((
Mixed.T.toContinuous,
Discrete.T.toContinuous,
Continuous.T.toContinuous,
));
let toDiscrete =
mapToAll((
Mixed.T.toDiscrete,
Discrete.T.toDiscrete,
Continuous.T.toDiscrete,
));
let toDiscreteProbabilityMassFraction =
mapToAll((
Mixed.T.toDiscreteProbabilityMassFraction,
Discrete.T.toDiscreteProbabilityMassFraction,
Continuous.T.toDiscreteProbabilityMassFraction,
));
let minX = mapToAll((Mixed.T.minX, Discrete.T.minX, Continuous.T.minX));
let integral =
mapToAll((
Mixed.T.Integral.get,
Discrete.T.Integral.get,
Continuous.T.Integral.get,
));
let integralEndY =
mapToAll((
Mixed.T.Integral.sum,
Discrete.T.Integral.sum,
Continuous.T.Integral.sum,
));
let integralXtoY = (f) => {
mapToAll((
Mixed.T.Integral.xToY(f),
Discrete.T.Integral.xToY(f),
Continuous.T.Integral.xToY(f),
));
};
let integralYtoX = (f) => {
mapToAll((
Mixed.T.Integral.yToX(f),
Discrete.T.Integral.yToX(f),
Continuous.T.Integral.yToX(f),
));
};
let maxX = mapToAll((Mixed.T.maxX, Discrete.T.maxX, Continuous.T.maxX));
let mapY = (~integralSumCacheFn=previousIntegralSum => None, ~integralCacheFn=previousIntegral=>None, ~fn) =>{
fmap((
Mixed.T.mapY(~integralSumCacheFn, ~integralCacheFn, ~fn),
Discrete.T.mapY(~integralSumCacheFn, ~integralCacheFn, ~fn),
Continuous.T.mapY(~integralSumCacheFn, ~integralCacheFn, ~fn),
));
}
let mean = (t: t): float =>
switch (t) {
| Mixed(m) => Mixed.T.mean(m)
| Discrete(m) => Discrete.T.mean(m)
| Continuous(m) => Continuous.T.mean(m)
};
let variance = (t: t): float =>
switch (t) {
| Mixed(m) => Mixed.T.variance(m)
| Discrete(m) => Discrete.T.variance(m)
| Continuous(m) => Continuous.T.variance(m)
};
});
let pdf = (f: float, t: t) => {
let mixedPoint: DistTypes.mixedPoint = T.xToY(f, t);
mixedPoint.continuous +. mixedPoint.discrete;
};
let inv = T.Integral.yToX;
let cdf = T.Integral.xToY;
let doN = (n, fn) => {
let items = Belt.Array.make(n, 0.0);
for (x in 0 to n - 1) {
let _ = Belt.Array.set(items, x, fn());
();
};
items;
};
let sample = (t: t): float => {
let randomItem = Random.float(1.);
let bar = t |> T.Integral.yToX(randomItem);
bar;
};
let isFloat = (t:t) => switch(t){
| Discrete({xyShape: {xs: [|_|], ys: [|1.0|]}}) => true
| _ => false
}
let sampleNRendered = (n, dist) => {
let integralCache = T.Integral.get(dist);
let distWithUpdatedIntegralCache = T.updateIntegralCache(Some(integralCache), dist);
doN(n, () => sample(distWithUpdatedIntegralCache));
};
let operate = (distToFloatOp: ExpressionTypes.distToFloatOperation, s): float =>
switch (distToFloatOp) {
| `Pdf(f) => pdf(f, s)
| `Cdf(f) => pdf(f, s)
| `Inv(f) => inv(f, s)
| `Sample => sample(s)
| `Mean => T.mean(s)
};