First attempt at sampling
This commit is contained in:
parent
84b6d7176c
commit
e5f38af43e
|
@ -26,18 +26,12 @@ let combineAlgebraically =
|
|||
(op: ExpressionTypes.algebraicOperation, t1: t, t2: t): t => {
|
||||
switch (t1, t2) {
|
||||
| (Continuous(m1), Continuous(m2)) =>
|
||||
DistTypes.Continuous(
|
||||
Continuous.combineAlgebraically(op, m1, m2),
|
||||
)
|
||||
DistTypes.Continuous(Continuous.combineAlgebraically(op, m1, m2))
|
||||
| (Discrete(m1), Discrete(m2)) =>
|
||||
DistTypes.Discrete(Discrete.combineAlgebraically(op, m1, m2))
|
||||
| (m1, m2) =>
|
||||
DistTypes.Mixed(
|
||||
Mixed.combineAlgebraically(
|
||||
op,
|
||||
toMixed(m1),
|
||||
toMixed(m2),
|
||||
),
|
||||
Mixed.combineAlgebraically(op, toMixed(m1), toMixed(m2)),
|
||||
)
|
||||
};
|
||||
};
|
||||
|
@ -73,10 +67,6 @@ let inv = (f: float, t: t): float => {
|
|||
0.0;
|
||||
};
|
||||
|
||||
let sample = (t: t): float => {
|
||||
0.0;
|
||||
};
|
||||
|
||||
module T =
|
||||
Dist({
|
||||
type t = DistTypes.shape;
|
||||
|
@ -199,10 +189,30 @@ module T =
|
|||
};
|
||||
});
|
||||
|
||||
let doN = (n, fn) => {
|
||||
let items = Belt.Array.make(n, 0.0);
|
||||
for (x in 0 to n - 1) {
|
||||
let _ = Belt.Array.set(items, x, fn());
|
||||
();
|
||||
};
|
||||
items;
|
||||
};
|
||||
|
||||
let sample = (cache, t: t): float => {
|
||||
let randomItem = Random.float(1.);
|
||||
let bar = T.Integral.yToX(~cache, randomItem, t);
|
||||
bar;
|
||||
};
|
||||
|
||||
let sampleNRendered = (n, dist) => {
|
||||
let integralCache = T.Integral.get(~cache=None, dist);
|
||||
doN(n, () => sample(Some(integralCache), dist));
|
||||
};
|
||||
|
||||
let operate = (distToFloatOp: ExpressionTypes.distToFloatOperation, s) =>
|
||||
switch (distToFloatOp) {
|
||||
| `Pdf(f) => pdf(f, s)
|
||||
| `Inv(f) => inv(f, s)
|
||||
| `Sample => sample(s)
|
||||
| `Sample => sample(None, s)
|
||||
| `Mean => T.mean(s)
|
||||
};
|
||||
|
|
|
@ -22,22 +22,56 @@ module AlgebraicCombination = {
|
|||
| _ => Ok(`AlgebraicCombination((operation, t1, t2)))
|
||||
};
|
||||
|
||||
let combineAsShapes =
|
||||
(evaluationParams: evaluationParams, algebraicOp, t1, t2) => {
|
||||
let renderShape = render(evaluationParams);
|
||||
switch (renderShape(t1), renderShape(t2)) {
|
||||
| (Ok(`RenderedDist(s1)), Ok(`RenderedDist(s2))) =>
|
||||
Ok(
|
||||
`RenderedDist(
|
||||
Shape.combineAlgebraically(algebraicOp, s1, s2),
|
||||
),
|
||||
let tryCombination = (n, algebraicOp, t1: node, t2: node) => {
|
||||
let sampleN =
|
||||
mapRenderable(Shape.sampleNRendered(n), SymbolicDist.T.sampleN(n));
|
||||
switch (sampleN(t1), sampleN(t2)) {
|
||||
| (Some(a), Some(b)) =>
|
||||
Some(
|
||||
Belt.Array.zip(a, b)
|
||||
|> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(algebraicOp, a, b)),
|
||||
)
|
||||
| (Error(e1), _) => Error(e1)
|
||||
| (_, Error(e2)) => Error(e2)
|
||||
| _ => Error("Algebraic combination: rendering failed.")
|
||||
| _ => None
|
||||
};
|
||||
};
|
||||
|
||||
let renderIfNotRendered = (params, t) =>
|
||||
!renderable(t)
|
||||
? switch (render(params, t)) {
|
||||
| Ok(r) => Ok(r)
|
||||
| Error(e) => Error(e)
|
||||
}
|
||||
: Ok(t);
|
||||
|
||||
let combineAsShapes =
|
||||
(evaluationParams: evaluationParams, algebraicOp, t1: node, t2: node) => {
|
||||
let i1 = renderIfNotRendered(evaluationParams, t1);
|
||||
let i2 = renderIfNotRendered(evaluationParams, t2);
|
||||
E.R.merge(i1, i2)
|
||||
|> E.R.bind(
|
||||
_,
|
||||
((a, b)) => {
|
||||
let samples = tryCombination(evaluationParams.sampleCount, algebraicOp, a, b);
|
||||
let shape =
|
||||
samples
|
||||
|> E.O.fmap(
|
||||
Samples.T.fromSamples(
|
||||
~samplingInputs={
|
||||
sampleCount: Some(evaluationParams.sampleCount),
|
||||
outputXYPoints: None,
|
||||
kernelWidth: None,
|
||||
},
|
||||
),
|
||||
)
|
||||
|> E.O.bind(_, (r: RenderTypes.ShapeRenderer.Sampling.outputs) =>
|
||||
r.shape
|
||||
)
|
||||
|> E.O.toResult("No response");
|
||||
shape |> E.R.fmap(r => `Normalize(`RenderedDist(r)));
|
||||
},
|
||||
);
|
||||
};
|
||||
|
||||
let operationToLeaf =
|
||||
(
|
||||
evaluationParams: evaluationParams,
|
||||
|
@ -125,11 +159,9 @@ module Truncate = {
|
|||
| (Some(lc), Some(rc), t) when lc > rc =>
|
||||
`Error("Left truncation bound must be smaller than right bound.")
|
||||
| (lc, rc, `SymbolicDist(`Uniform(u))) =>
|
||||
// just create a new Uniform distribution
|
||||
let nu: SymbolicTypes.uniform = u;
|
||||
let newLow = max(E.O.default(neg_infinity, lc), nu.low);
|
||||
let newHigh = min(E.O.default(infinity, rc), nu.high);
|
||||
`Solution(`SymbolicDist(`Uniform({low: newLow, high: newHigh})));
|
||||
`Solution(
|
||||
`SymbolicDist(`Uniform(SymbolicDist.Uniform.truncate(lc, rc, u))),
|
||||
)
|
||||
| _ => `NoSolution
|
||||
};
|
||||
};
|
||||
|
@ -140,9 +172,7 @@ module Truncate = {
|
|||
// of a distribution we otherwise wouldn't get at all
|
||||
switch (render(evaluationParams, t)) {
|
||||
| Ok(`RenderedDist(rs)) =>
|
||||
let truncatedShape =
|
||||
rs |> Shape.T.truncate(leftCutoff, rightCutoff);
|
||||
Ok(`RenderedDist(truncatedShape));
|
||||
Ok(`RenderedDist(Shape.T.truncate(leftCutoff, rightCutoff, rs)))
|
||||
| Error(e) => Error(e)
|
||||
| _ => Error("Could not truncate distribution.")
|
||||
};
|
||||
|
@ -171,8 +201,7 @@ module Truncate = {
|
|||
module Normalize = {
|
||||
let rec operationToLeaf = (evaluationParams, t: node): result(node, string) => {
|
||||
switch (t) {
|
||||
| `RenderedDist(s) =>
|
||||
Ok(`RenderedDist(Shape.T.normalize(s)))
|
||||
| `RenderedDist(s) => Ok(`RenderedDist(Shape.T.normalize(s)))
|
||||
| `SymbolicDist(_) => Ok(t)
|
||||
| _ => evaluateAndRetry(evaluationParams, operationToLeaf, t)
|
||||
};
|
||||
|
|
|
@ -29,6 +29,17 @@ module ExpressionTree = {
|
|||
|
||||
let evaluateAndRetry = (evaluationParams, fn, node) =>
|
||||
node |> evaluationParams.evaluateNode(evaluationParams) |> E.R.bind(_, fn(evaluationParams));
|
||||
|
||||
let renderable = fun
|
||||
| `SymbolicDist(_) => true
|
||||
| `RenderedDist(_) => true
|
||||
| _ => false;
|
||||
|
||||
let mapRenderable = (renderedFn, symFn, item: node) => switch(item) {
|
||||
| `SymbolicDist(s) => Some(symFn(s))
|
||||
| `RenderedDist(r) => Some(renderedFn(r))
|
||||
| _ => None
|
||||
}
|
||||
};
|
||||
|
||||
type simplificationResult = [
|
||||
|
|
|
@ -130,6 +130,11 @@ module Uniform = {
|
|||
let sample = (t: t) => Jstat.uniform##sample(t.low, t.high);
|
||||
let mean = (t: t) => Ok(Jstat.uniform##mean(t.low, t.high));
|
||||
let toString = ({low, high}: t) => {j|Uniform($low,$high)|j};
|
||||
let truncate = (low, high, t: t): t => {
|
||||
let newLow = max(E.O.default(neg_infinity, low), t.low);
|
||||
let newHigh = min(E.O.default(infinity, high), t.high);
|
||||
{low: newLow, high: newHigh};
|
||||
};
|
||||
};
|
||||
|
||||
module Float = {
|
||||
|
@ -178,7 +183,20 @@ module T = {
|
|||
| `Lognormal(n) => Lognormal.sample(n)
|
||||
| `Uniform(n) => Uniform.sample(n)
|
||||
| `Beta(n) => Beta.sample(n)
|
||||
| `Float(n) => Float.sample(n)
|
||||
| `Float(n) => Float.sample(n);
|
||||
|
||||
let doN = (n, fn) => {
|
||||
let items = Belt.Array.make(n, 0.0);
|
||||
for (x in 0 to n - 1) {
|
||||
let _ = Belt.Array.set(items, x, fn());
|
||||
();
|
||||
};
|
||||
items;
|
||||
};
|
||||
|
||||
let sampleN = (n, dist) => {
|
||||
doN(n, () => sample(dist));
|
||||
};
|
||||
|
||||
let toString: symbolicDist => string =
|
||||
fun
|
||||
|
@ -189,7 +207,7 @@ module T = {
|
|||
| `Lognormal(n) => Lognormal.toString(n)
|
||||
| `Uniform(n) => Uniform.toString(n)
|
||||
| `Beta(n) => Beta.toString(n)
|
||||
| `Float(n) => Float.toString(n)
|
||||
| `Float(n) => Float.toString(n);
|
||||
|
||||
let min: symbolicDist => float =
|
||||
fun
|
||||
|
@ -237,10 +255,10 @@ module T = {
|
|||
switch (xSelection, dist) {
|
||||
| (`Linear, _) => E.A.Floats.range(min(dist), max(dist), n)
|
||||
| (`ByWeight, `Uniform(n)) =>
|
||||
// In `ByWeight mode, uniform distributions get special treatment because we need two x's
|
||||
// on either side for proper rendering (just left and right of the discontinuities).
|
||||
let dx = 0.00001 *. (n.high -. n.low);
|
||||
[|n.low -. dx, n.low +. dx, n.high -. dx, n.high +. dx|];
|
||||
// In `ByWeight mode, uniform distributions get special treatment because we need two x's
|
||||
// on either side for proper rendering (just left and right of the discontinuities).
|
||||
let dx = 0.00001 *. (n.high -. n.low);
|
||||
[|n.low -. dx, n.low +. dx, n.high -. dx, n.high +. dx|];
|
||||
| (`ByWeight, _) =>
|
||||
let ys = E.A.Floats.range(minCdfValue, maxCdfValue, n);
|
||||
ys |> E.A.fmap(y => inv(y, dist));
|
||||
|
@ -277,14 +295,10 @@ module T = {
|
|||
let toShape = (sampleCount, d: symbolicDist): DistTypes.shape =>
|
||||
switch (d) {
|
||||
| `Float(v) =>
|
||||
Discrete(
|
||||
Discrete.make({xs: [|v|], ys: [|1.0|]}, Some(1.0)),
|
||||
)
|
||||
Discrete(Discrete.make({xs: [|v|], ys: [|1.0|]}, Some(1.0)))
|
||||
| _ =>
|
||||
let xs = interpolateXs(~xSelection=`ByWeight, d, sampleCount);
|
||||
let ys = xs |> E.A.fmap(x => pdf(x, d));
|
||||
Continuous(
|
||||
Continuous.make(`Linear, {xs, ys}, Some(1.0)),
|
||||
);
|
||||
Continuous(Continuous.make(`Linear, {xs, ys}, Some(1.0)));
|
||||
};
|
||||
};
|
||||
|
|
|
@ -145,6 +145,12 @@ module R = {
|
|||
let id = e => e |> result(U.id, U.id);
|
||||
let fmap = Rationale.Result.fmap;
|
||||
let bind = Rationale.Result.bind;
|
||||
let merge = (a, b) =>
|
||||
switch (a, b) {
|
||||
| (Error(e), _) => Error(e)
|
||||
| (_, Error(e)) => Error(e)
|
||||
| (Ok(a), Ok(b)) => Ok((a, b))
|
||||
};
|
||||
let toOption = (e: Belt.Result.t('a, 'b)) =>
|
||||
switch (e) {
|
||||
| Ok(r) => Some(r)
|
||||
|
|
Loading…
Reference in New Issue
Block a user