First attempt at sampling

This commit is contained in:
Ozzie Gooen 2020-07-11 16:41:35 +01:00
parent 84b6d7176c
commit e5f38af43e
6 changed files with 119 additions and 49 deletions

View File

@ -245,4 +245,4 @@ let combineAlgebraically =
// return a new Continuous distribution
make(`Linear, combinedShape, combinedIntegralSum);
};
};
};

View File

@ -26,18 +26,12 @@ let combineAlgebraically =
(op: ExpressionTypes.algebraicOperation, t1: t, t2: t): t => {
switch (t1, t2) {
| (Continuous(m1), Continuous(m2)) =>
DistTypes.Continuous(
Continuous.combineAlgebraically(op, m1, m2),
)
DistTypes.Continuous(Continuous.combineAlgebraically(op, m1, m2))
| (Discrete(m1), Discrete(m2)) =>
DistTypes.Discrete(Discrete.combineAlgebraically(op, m1, m2))
| (m1, m2) =>
DistTypes.Mixed(
Mixed.combineAlgebraically(
op,
toMixed(m1),
toMixed(m2),
),
Mixed.combineAlgebraically(op, toMixed(m1), toMixed(m2)),
)
};
};
@ -73,10 +67,6 @@ let inv = (f: float, t: t): float => {
0.0;
};
let sample = (t: t): float => {
0.0;
};
module T =
Dist({
type t = DistTypes.shape;
@ -199,10 +189,30 @@ module T =
};
});
let doN = (n, fn) => {
let items = Belt.Array.make(n, 0.0);
for (x in 0 to n - 1) {
let _ = Belt.Array.set(items, x, fn());
();
};
items;
};
let sample = (cache, t: t): float => {
let randomItem = Random.float(1.);
let bar = T.Integral.yToX(~cache, randomItem, t);
bar;
};
let sampleNRendered = (n, dist) => {
let integralCache = T.Integral.get(~cache=None, dist);
doN(n, () => sample(Some(integralCache), dist));
};
let operate = (distToFloatOp: ExpressionTypes.distToFloatOperation, s) =>
switch (distToFloatOp) {
| `Pdf(f) => pdf(f, s)
| `Inv(f) => inv(f, s)
| `Sample => sample(s)
| `Sample => sample(None, s)
| `Mean => T.mean(s)
};

View File

@ -22,22 +22,56 @@ module AlgebraicCombination = {
| _ => Ok(`AlgebraicCombination((operation, t1, t2)))
};
let combineAsShapes =
(evaluationParams: evaluationParams, algebraicOp, t1, t2) => {
let renderShape = render(evaluationParams);
switch (renderShape(t1), renderShape(t2)) {
| (Ok(`RenderedDist(s1)), Ok(`RenderedDist(s2))) =>
Ok(
`RenderedDist(
Shape.combineAlgebraically(algebraicOp, s1, s2),
),
let tryCombination = (n, algebraicOp, t1: node, t2: node) => {
let sampleN =
mapRenderable(Shape.sampleNRendered(n), SymbolicDist.T.sampleN(n));
switch (sampleN(t1), sampleN(t2)) {
| (Some(a), Some(b)) =>
Some(
Belt.Array.zip(a, b)
|> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(algebraicOp, a, b)),
)
| (Error(e1), _) => Error(e1)
| (_, Error(e2)) => Error(e2)
| _ => Error("Algebraic combination: rendering failed.")
| _ => None
};
};
let renderIfNotRendered = (params, t) =>
!renderable(t)
? switch (render(params, t)) {
| Ok(r) => Ok(r)
| Error(e) => Error(e)
}
: Ok(t);
let combineAsShapes =
(evaluationParams: evaluationParams, algebraicOp, t1: node, t2: node) => {
let i1 = renderIfNotRendered(evaluationParams, t1);
let i2 = renderIfNotRendered(evaluationParams, t2);
E.R.merge(i1, i2)
|> E.R.bind(
_,
((a, b)) => {
let samples = tryCombination(evaluationParams.sampleCount, algebraicOp, a, b);
let shape =
samples
|> E.O.fmap(
Samples.T.fromSamples(
~samplingInputs={
sampleCount: Some(evaluationParams.sampleCount),
outputXYPoints: None,
kernelWidth: None,
},
),
)
|> E.O.bind(_, (r: RenderTypes.ShapeRenderer.Sampling.outputs) =>
r.shape
)
|> E.O.toResult("No response");
shape |> E.R.fmap(r => `Normalize(`RenderedDist(r)));
},
);
};
let operationToLeaf =
(
evaluationParams: evaluationParams,
@ -125,11 +159,9 @@ module Truncate = {
| (Some(lc), Some(rc), t) when lc > rc =>
`Error("Left truncation bound must be smaller than right bound.")
| (lc, rc, `SymbolicDist(`Uniform(u))) =>
// just create a new Uniform distribution
let nu: SymbolicTypes.uniform = u;
let newLow = max(E.O.default(neg_infinity, lc), nu.low);
let newHigh = min(E.O.default(infinity, rc), nu.high);
`Solution(`SymbolicDist(`Uniform({low: newLow, high: newHigh})));
`Solution(
`SymbolicDist(`Uniform(SymbolicDist.Uniform.truncate(lc, rc, u))),
)
| _ => `NoSolution
};
};
@ -140,9 +172,7 @@ module Truncate = {
// of a distribution we otherwise wouldn't get at all
switch (render(evaluationParams, t)) {
| Ok(`RenderedDist(rs)) =>
let truncatedShape =
rs |> Shape.T.truncate(leftCutoff, rightCutoff);
Ok(`RenderedDist(truncatedShape));
Ok(`RenderedDist(Shape.T.truncate(leftCutoff, rightCutoff, rs)))
| Error(e) => Error(e)
| _ => Error("Could not truncate distribution.")
};
@ -171,8 +201,7 @@ module Truncate = {
module Normalize = {
let rec operationToLeaf = (evaluationParams, t: node): result(node, string) => {
switch (t) {
| `RenderedDist(s) =>
Ok(`RenderedDist(Shape.T.normalize(s)))
| `RenderedDist(s) => Ok(`RenderedDist(Shape.T.normalize(s)))
| `SymbolicDist(_) => Ok(t)
| _ => evaluateAndRetry(evaluationParams, operationToLeaf, t)
};

View File

@ -29,6 +29,17 @@ module ExpressionTree = {
let evaluateAndRetry = (evaluationParams, fn, node) =>
node |> evaluationParams.evaluateNode(evaluationParams) |> E.R.bind(_, fn(evaluationParams));
let renderable = fun
| `SymbolicDist(_) => true
| `RenderedDist(_) => true
| _ => false;
let mapRenderable = (renderedFn, symFn, item: node) => switch(item) {
| `SymbolicDist(s) => Some(symFn(s))
| `RenderedDist(r) => Some(renderedFn(r))
| _ => None
}
};
type simplificationResult = [

View File

@ -130,6 +130,11 @@ module Uniform = {
let sample = (t: t) => Jstat.uniform##sample(t.low, t.high);
let mean = (t: t) => Ok(Jstat.uniform##mean(t.low, t.high));
let toString = ({low, high}: t) => {j|Uniform($low,$high)|j};
let truncate = (low, high, t: t): t => {
let newLow = max(E.O.default(neg_infinity, low), t.low);
let newHigh = min(E.O.default(infinity, high), t.high);
{low: newLow, high: newHigh};
};
};
module Float = {
@ -178,7 +183,20 @@ module T = {
| `Lognormal(n) => Lognormal.sample(n)
| `Uniform(n) => Uniform.sample(n)
| `Beta(n) => Beta.sample(n)
| `Float(n) => Float.sample(n)
| `Float(n) => Float.sample(n);
let doN = (n, fn) => {
let items = Belt.Array.make(n, 0.0);
for (x in 0 to n - 1) {
let _ = Belt.Array.set(items, x, fn());
();
};
items;
};
let sampleN = (n, dist) => {
doN(n, () => sample(dist));
};
let toString: symbolicDist => string =
fun
@ -189,7 +207,7 @@ module T = {
| `Lognormal(n) => Lognormal.toString(n)
| `Uniform(n) => Uniform.toString(n)
| `Beta(n) => Beta.toString(n)
| `Float(n) => Float.toString(n)
| `Float(n) => Float.toString(n);
let min: symbolicDist => float =
fun
@ -237,10 +255,10 @@ module T = {
switch (xSelection, dist) {
| (`Linear, _) => E.A.Floats.range(min(dist), max(dist), n)
| (`ByWeight, `Uniform(n)) =>
// In `ByWeight mode, uniform distributions get special treatment because we need two x's
// on either side for proper rendering (just left and right of the discontinuities).
let dx = 0.00001 *. (n.high -. n.low);
[|n.low -. dx, n.low +. dx, n.high -. dx, n.high +. dx|];
// In `ByWeight mode, uniform distributions get special treatment because we need two x's
// on either side for proper rendering (just left and right of the discontinuities).
let dx = 0.00001 *. (n.high -. n.low);
[|n.low -. dx, n.low +. dx, n.high -. dx, n.high +. dx|];
| (`ByWeight, _) =>
let ys = E.A.Floats.range(minCdfValue, maxCdfValue, n);
ys |> E.A.fmap(y => inv(y, dist));
@ -277,14 +295,10 @@ module T = {
let toShape = (sampleCount, d: symbolicDist): DistTypes.shape =>
switch (d) {
| `Float(v) =>
Discrete(
Discrete.make({xs: [|v|], ys: [|1.0|]}, Some(1.0)),
)
Discrete(Discrete.make({xs: [|v|], ys: [|1.0|]}, Some(1.0)))
| _ =>
let xs = interpolateXs(~xSelection=`ByWeight, d, sampleCount);
let ys = xs |> E.A.fmap(x => pdf(x, d));
Continuous(
Continuous.make(`Linear, {xs, ys}, Some(1.0)),
);
Continuous(Continuous.make(`Linear, {xs, ys}, Some(1.0)));
};
};

View File

@ -145,6 +145,12 @@ module R = {
let id = e => e |> result(U.id, U.id);
let fmap = Rationale.Result.fmap;
let bind = Rationale.Result.bind;
let merge = (a, b) =>
switch (a, b) {
| (Error(e), _) => Error(e)
| (_, Error(e)) => Error(e)
| (Ok(a), Ok(b)) => Ok((a, b))
};
let toOption = (e: Belt.Result.t('a, 'b)) =>
switch (e) {
| Ok(r) => Some(r)
@ -464,4 +470,4 @@ module JsArray = {
Rationale.Option.toExn("Warning: This should not have happened"),
);
let filter = Js.Array.filter;
};
};