First attempt at sampling

This commit is contained in:
Ozzie Gooen 2020-07-11 16:41:35 +01:00
parent 84b6d7176c
commit e5f38af43e
6 changed files with 119 additions and 49 deletions

View File

@ -26,18 +26,12 @@ let combineAlgebraically =
(op: ExpressionTypes.algebraicOperation, t1: t, t2: t): t => { (op: ExpressionTypes.algebraicOperation, t1: t, t2: t): t => {
switch (t1, t2) { switch (t1, t2) {
| (Continuous(m1), Continuous(m2)) => | (Continuous(m1), Continuous(m2)) =>
DistTypes.Continuous( DistTypes.Continuous(Continuous.combineAlgebraically(op, m1, m2))
Continuous.combineAlgebraically(op, m1, m2),
)
| (Discrete(m1), Discrete(m2)) => | (Discrete(m1), Discrete(m2)) =>
DistTypes.Discrete(Discrete.combineAlgebraically(op, m1, m2)) DistTypes.Discrete(Discrete.combineAlgebraically(op, m1, m2))
| (m1, m2) => | (m1, m2) =>
DistTypes.Mixed( DistTypes.Mixed(
Mixed.combineAlgebraically( Mixed.combineAlgebraically(op, toMixed(m1), toMixed(m2)),
op,
toMixed(m1),
toMixed(m2),
),
) )
}; };
}; };
@ -73,10 +67,6 @@ let inv = (f: float, t: t): float => {
0.0; 0.0;
}; };
let sample = (t: t): float => {
0.0;
};
module T = module T =
Dist({ Dist({
type t = DistTypes.shape; type t = DistTypes.shape;
@ -199,10 +189,30 @@ module T =
}; };
}); });
let doN = (n, fn) => {
let items = Belt.Array.make(n, 0.0);
for (x in 0 to n - 1) {
let _ = Belt.Array.set(items, x, fn());
();
};
items;
};
let sample = (cache, t: t): float => {
let randomItem = Random.float(1.);
let bar = T.Integral.yToX(~cache, randomItem, t);
bar;
};
let sampleNRendered = (n, dist) => {
let integralCache = T.Integral.get(~cache=None, dist);
doN(n, () => sample(Some(integralCache), dist));
};
let operate = (distToFloatOp: ExpressionTypes.distToFloatOperation, s) => let operate = (distToFloatOp: ExpressionTypes.distToFloatOperation, s) =>
switch (distToFloatOp) { switch (distToFloatOp) {
| `Pdf(f) => pdf(f, s) | `Pdf(f) => pdf(f, s)
| `Inv(f) => inv(f, s) | `Inv(f) => inv(f, s)
| `Sample => sample(s) | `Sample => sample(None, s)
| `Mean => T.mean(s) | `Mean => T.mean(s)
}; };

View File

@ -22,20 +22,54 @@ module AlgebraicCombination = {
| _ => Ok(`AlgebraicCombination((operation, t1, t2))) | _ => Ok(`AlgebraicCombination((operation, t1, t2)))
}; };
let tryCombination = (n, algebraicOp, t1: node, t2: node) => {
let sampleN =
mapRenderable(Shape.sampleNRendered(n), SymbolicDist.T.sampleN(n));
switch (sampleN(t1), sampleN(t2)) {
| (Some(a), Some(b)) =>
Some(
Belt.Array.zip(a, b)
|> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(algebraicOp, a, b)),
)
| _ => None
};
};
let renderIfNotRendered = (params, t) =>
!renderable(t)
? switch (render(params, t)) {
| Ok(r) => Ok(r)
| Error(e) => Error(e)
}
: Ok(t);
let combineAsShapes = let combineAsShapes =
(evaluationParams: evaluationParams, algebraicOp, t1, t2) => { (evaluationParams: evaluationParams, algebraicOp, t1: node, t2: node) => {
let renderShape = render(evaluationParams); let i1 = renderIfNotRendered(evaluationParams, t1);
switch (renderShape(t1), renderShape(t2)) { let i2 = renderIfNotRendered(evaluationParams, t2);
| (Ok(`RenderedDist(s1)), Ok(`RenderedDist(s2))) => E.R.merge(i1, i2)
Ok( |> E.R.bind(
`RenderedDist( _,
Shape.combineAlgebraically(algebraicOp, s1, s2), ((a, b)) => {
let samples = tryCombination(evaluationParams.sampleCount, algebraicOp, a, b);
let shape =
samples
|> E.O.fmap(
Samples.T.fromSamples(
~samplingInputs={
sampleCount: Some(evaluationParams.sampleCount),
outputXYPoints: None,
kernelWidth: None,
},
), ),
) )
| (Error(e1), _) => Error(e1) |> E.O.bind(_, (r: RenderTypes.ShapeRenderer.Sampling.outputs) =>
| (_, Error(e2)) => Error(e2) r.shape
| _ => Error("Algebraic combination: rendering failed.") )
}; |> E.O.toResult("No response");
shape |> E.R.fmap(r => `Normalize(`RenderedDist(r)));
},
);
}; };
let operationToLeaf = let operationToLeaf =
@ -125,11 +159,9 @@ module Truncate = {
| (Some(lc), Some(rc), t) when lc > rc => | (Some(lc), Some(rc), t) when lc > rc =>
`Error("Left truncation bound must be smaller than right bound.") `Error("Left truncation bound must be smaller than right bound.")
| (lc, rc, `SymbolicDist(`Uniform(u))) => | (lc, rc, `SymbolicDist(`Uniform(u))) =>
// just create a new Uniform distribution `Solution(
let nu: SymbolicTypes.uniform = u; `SymbolicDist(`Uniform(SymbolicDist.Uniform.truncate(lc, rc, u))),
let newLow = max(E.O.default(neg_infinity, lc), nu.low); )
let newHigh = min(E.O.default(infinity, rc), nu.high);
`Solution(`SymbolicDist(`Uniform({low: newLow, high: newHigh})));
| _ => `NoSolution | _ => `NoSolution
}; };
}; };
@ -140,9 +172,7 @@ module Truncate = {
// of a distribution we otherwise wouldn't get at all // of a distribution we otherwise wouldn't get at all
switch (render(evaluationParams, t)) { switch (render(evaluationParams, t)) {
| Ok(`RenderedDist(rs)) => | Ok(`RenderedDist(rs)) =>
let truncatedShape = Ok(`RenderedDist(Shape.T.truncate(leftCutoff, rightCutoff, rs)))
rs |> Shape.T.truncate(leftCutoff, rightCutoff);
Ok(`RenderedDist(truncatedShape));
| Error(e) => Error(e) | Error(e) => Error(e)
| _ => Error("Could not truncate distribution.") | _ => Error("Could not truncate distribution.")
}; };
@ -171,8 +201,7 @@ module Truncate = {
module Normalize = { module Normalize = {
let rec operationToLeaf = (evaluationParams, t: node): result(node, string) => { let rec operationToLeaf = (evaluationParams, t: node): result(node, string) => {
switch (t) { switch (t) {
| `RenderedDist(s) => | `RenderedDist(s) => Ok(`RenderedDist(Shape.T.normalize(s)))
Ok(`RenderedDist(Shape.T.normalize(s)))
| `SymbolicDist(_) => Ok(t) | `SymbolicDist(_) => Ok(t)
| _ => evaluateAndRetry(evaluationParams, operationToLeaf, t) | _ => evaluateAndRetry(evaluationParams, operationToLeaf, t)
}; };

View File

@ -29,6 +29,17 @@ module ExpressionTree = {
let evaluateAndRetry = (evaluationParams, fn, node) => let evaluateAndRetry = (evaluationParams, fn, node) =>
node |> evaluationParams.evaluateNode(evaluationParams) |> E.R.bind(_, fn(evaluationParams)); node |> evaluationParams.evaluateNode(evaluationParams) |> E.R.bind(_, fn(evaluationParams));
let renderable = fun
| `SymbolicDist(_) => true
| `RenderedDist(_) => true
| _ => false;
let mapRenderable = (renderedFn, symFn, item: node) => switch(item) {
| `SymbolicDist(s) => Some(symFn(s))
| `RenderedDist(r) => Some(renderedFn(r))
| _ => None
}
}; };
type simplificationResult = [ type simplificationResult = [

View File

@ -130,6 +130,11 @@ module Uniform = {
let sample = (t: t) => Jstat.uniform##sample(t.low, t.high); let sample = (t: t) => Jstat.uniform##sample(t.low, t.high);
let mean = (t: t) => Ok(Jstat.uniform##mean(t.low, t.high)); let mean = (t: t) => Ok(Jstat.uniform##mean(t.low, t.high));
let toString = ({low, high}: t) => {j|Uniform($low,$high)|j}; let toString = ({low, high}: t) => {j|Uniform($low,$high)|j};
let truncate = (low, high, t: t): t => {
let newLow = max(E.O.default(neg_infinity, low), t.low);
let newHigh = min(E.O.default(infinity, high), t.high);
{low: newLow, high: newHigh};
};
}; };
module Float = { module Float = {
@ -178,7 +183,20 @@ module T = {
| `Lognormal(n) => Lognormal.sample(n) | `Lognormal(n) => Lognormal.sample(n)
| `Uniform(n) => Uniform.sample(n) | `Uniform(n) => Uniform.sample(n)
| `Beta(n) => Beta.sample(n) | `Beta(n) => Beta.sample(n)
| `Float(n) => Float.sample(n) | `Float(n) => Float.sample(n);
let doN = (n, fn) => {
let items = Belt.Array.make(n, 0.0);
for (x in 0 to n - 1) {
let _ = Belt.Array.set(items, x, fn());
();
};
items;
};
let sampleN = (n, dist) => {
doN(n, () => sample(dist));
};
let toString: symbolicDist => string = let toString: symbolicDist => string =
fun fun
@ -189,7 +207,7 @@ module T = {
| `Lognormal(n) => Lognormal.toString(n) | `Lognormal(n) => Lognormal.toString(n)
| `Uniform(n) => Uniform.toString(n) | `Uniform(n) => Uniform.toString(n)
| `Beta(n) => Beta.toString(n) | `Beta(n) => Beta.toString(n)
| `Float(n) => Float.toString(n) | `Float(n) => Float.toString(n);
let min: symbolicDist => float = let min: symbolicDist => float =
fun fun
@ -277,14 +295,10 @@ module T = {
let toShape = (sampleCount, d: symbolicDist): DistTypes.shape => let toShape = (sampleCount, d: symbolicDist): DistTypes.shape =>
switch (d) { switch (d) {
| `Float(v) => | `Float(v) =>
Discrete( Discrete(Discrete.make({xs: [|v|], ys: [|1.0|]}, Some(1.0)))
Discrete.make({xs: [|v|], ys: [|1.0|]}, Some(1.0)),
)
| _ => | _ =>
let xs = interpolateXs(~xSelection=`ByWeight, d, sampleCount); let xs = interpolateXs(~xSelection=`ByWeight, d, sampleCount);
let ys = xs |> E.A.fmap(x => pdf(x, d)); let ys = xs |> E.A.fmap(x => pdf(x, d));
Continuous( Continuous(Continuous.make(`Linear, {xs, ys}, Some(1.0)));
Continuous.make(`Linear, {xs, ys}, Some(1.0)),
);
}; };
}; };

View File

@ -145,6 +145,12 @@ module R = {
let id = e => e |> result(U.id, U.id); let id = e => e |> result(U.id, U.id);
let fmap = Rationale.Result.fmap; let fmap = Rationale.Result.fmap;
let bind = Rationale.Result.bind; let bind = Rationale.Result.bind;
let merge = (a, b) =>
switch (a, b) {
| (Error(e), _) => Error(e)
| (_, Error(e)) => Error(e)
| (Ok(a), Ok(b)) => Ok((a, b))
};
let toOption = (e: Belt.Result.t('a, 'b)) => let toOption = (e: Belt.Result.t('a, 'b)) =>
switch (e) { switch (e) {
| Ok(r) => Some(r) | Ok(r) => Some(r)