First attempt at sampling
This commit is contained in:
parent
84b6d7176c
commit
e5f38af43e
|
@ -26,18 +26,12 @@ let combineAlgebraically =
|
||||||
(op: ExpressionTypes.algebraicOperation, t1: t, t2: t): t => {
|
(op: ExpressionTypes.algebraicOperation, t1: t, t2: t): t => {
|
||||||
switch (t1, t2) {
|
switch (t1, t2) {
|
||||||
| (Continuous(m1), Continuous(m2)) =>
|
| (Continuous(m1), Continuous(m2)) =>
|
||||||
DistTypes.Continuous(
|
DistTypes.Continuous(Continuous.combineAlgebraically(op, m1, m2))
|
||||||
Continuous.combineAlgebraically(op, m1, m2),
|
|
||||||
)
|
|
||||||
| (Discrete(m1), Discrete(m2)) =>
|
| (Discrete(m1), Discrete(m2)) =>
|
||||||
DistTypes.Discrete(Discrete.combineAlgebraically(op, m1, m2))
|
DistTypes.Discrete(Discrete.combineAlgebraically(op, m1, m2))
|
||||||
| (m1, m2) =>
|
| (m1, m2) =>
|
||||||
DistTypes.Mixed(
|
DistTypes.Mixed(
|
||||||
Mixed.combineAlgebraically(
|
Mixed.combineAlgebraically(op, toMixed(m1), toMixed(m2)),
|
||||||
op,
|
|
||||||
toMixed(m1),
|
|
||||||
toMixed(m2),
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
@ -73,10 +67,6 @@ let inv = (f: float, t: t): float => {
|
||||||
0.0;
|
0.0;
|
||||||
};
|
};
|
||||||
|
|
||||||
let sample = (t: t): float => {
|
|
||||||
0.0;
|
|
||||||
};
|
|
||||||
|
|
||||||
module T =
|
module T =
|
||||||
Dist({
|
Dist({
|
||||||
type t = DistTypes.shape;
|
type t = DistTypes.shape;
|
||||||
|
@ -199,10 +189,30 @@ module T =
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
|
|
||||||
|
let doN = (n, fn) => {
|
||||||
|
let items = Belt.Array.make(n, 0.0);
|
||||||
|
for (x in 0 to n - 1) {
|
||||||
|
let _ = Belt.Array.set(items, x, fn());
|
||||||
|
();
|
||||||
|
};
|
||||||
|
items;
|
||||||
|
};
|
||||||
|
|
||||||
|
let sample = (cache, t: t): float => {
|
||||||
|
let randomItem = Random.float(1.);
|
||||||
|
let bar = T.Integral.yToX(~cache, randomItem, t);
|
||||||
|
bar;
|
||||||
|
};
|
||||||
|
|
||||||
|
let sampleNRendered = (n, dist) => {
|
||||||
|
let integralCache = T.Integral.get(~cache=None, dist);
|
||||||
|
doN(n, () => sample(Some(integralCache), dist));
|
||||||
|
};
|
||||||
|
|
||||||
let operate = (distToFloatOp: ExpressionTypes.distToFloatOperation, s) =>
|
let operate = (distToFloatOp: ExpressionTypes.distToFloatOperation, s) =>
|
||||||
switch (distToFloatOp) {
|
switch (distToFloatOp) {
|
||||||
| `Pdf(f) => pdf(f, s)
|
| `Pdf(f) => pdf(f, s)
|
||||||
| `Inv(f) => inv(f, s)
|
| `Inv(f) => inv(f, s)
|
||||||
| `Sample => sample(s)
|
| `Sample => sample(None, s)
|
||||||
| `Mean => T.mean(s)
|
| `Mean => T.mean(s)
|
||||||
};
|
};
|
||||||
|
|
|
@ -22,22 +22,56 @@ module AlgebraicCombination = {
|
||||||
| _ => Ok(`AlgebraicCombination((operation, t1, t2)))
|
| _ => Ok(`AlgebraicCombination((operation, t1, t2)))
|
||||||
};
|
};
|
||||||
|
|
||||||
let combineAsShapes =
|
let tryCombination = (n, algebraicOp, t1: node, t2: node) => {
|
||||||
(evaluationParams: evaluationParams, algebraicOp, t1, t2) => {
|
let sampleN =
|
||||||
let renderShape = render(evaluationParams);
|
mapRenderable(Shape.sampleNRendered(n), SymbolicDist.T.sampleN(n));
|
||||||
switch (renderShape(t1), renderShape(t2)) {
|
switch (sampleN(t1), sampleN(t2)) {
|
||||||
| (Ok(`RenderedDist(s1)), Ok(`RenderedDist(s2))) =>
|
| (Some(a), Some(b)) =>
|
||||||
Ok(
|
Some(
|
||||||
`RenderedDist(
|
Belt.Array.zip(a, b)
|
||||||
Shape.combineAlgebraically(algebraicOp, s1, s2),
|
|> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(algebraicOp, a, b)),
|
||||||
),
|
|
||||||
)
|
)
|
||||||
| (Error(e1), _) => Error(e1)
|
| _ => None
|
||||||
| (_, Error(e2)) => Error(e2)
|
|
||||||
| _ => Error("Algebraic combination: rendering failed.")
|
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let renderIfNotRendered = (params, t) =>
|
||||||
|
!renderable(t)
|
||||||
|
? switch (render(params, t)) {
|
||||||
|
| Ok(r) => Ok(r)
|
||||||
|
| Error(e) => Error(e)
|
||||||
|
}
|
||||||
|
: Ok(t);
|
||||||
|
|
||||||
|
let combineAsShapes =
|
||||||
|
(evaluationParams: evaluationParams, algebraicOp, t1: node, t2: node) => {
|
||||||
|
let i1 = renderIfNotRendered(evaluationParams, t1);
|
||||||
|
let i2 = renderIfNotRendered(evaluationParams, t2);
|
||||||
|
E.R.merge(i1, i2)
|
||||||
|
|> E.R.bind(
|
||||||
|
_,
|
||||||
|
((a, b)) => {
|
||||||
|
let samples = tryCombination(evaluationParams.sampleCount, algebraicOp, a, b);
|
||||||
|
let shape =
|
||||||
|
samples
|
||||||
|
|> E.O.fmap(
|
||||||
|
Samples.T.fromSamples(
|
||||||
|
~samplingInputs={
|
||||||
|
sampleCount: Some(evaluationParams.sampleCount),
|
||||||
|
outputXYPoints: None,
|
||||||
|
kernelWidth: None,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|> E.O.bind(_, (r: RenderTypes.ShapeRenderer.Sampling.outputs) =>
|
||||||
|
r.shape
|
||||||
|
)
|
||||||
|
|> E.O.toResult("No response");
|
||||||
|
shape |> E.R.fmap(r => `Normalize(`RenderedDist(r)));
|
||||||
|
},
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
let operationToLeaf =
|
let operationToLeaf =
|
||||||
(
|
(
|
||||||
evaluationParams: evaluationParams,
|
evaluationParams: evaluationParams,
|
||||||
|
@ -125,11 +159,9 @@ module Truncate = {
|
||||||
| (Some(lc), Some(rc), t) when lc > rc =>
|
| (Some(lc), Some(rc), t) when lc > rc =>
|
||||||
`Error("Left truncation bound must be smaller than right bound.")
|
`Error("Left truncation bound must be smaller than right bound.")
|
||||||
| (lc, rc, `SymbolicDist(`Uniform(u))) =>
|
| (lc, rc, `SymbolicDist(`Uniform(u))) =>
|
||||||
// just create a new Uniform distribution
|
`Solution(
|
||||||
let nu: SymbolicTypes.uniform = u;
|
`SymbolicDist(`Uniform(SymbolicDist.Uniform.truncate(lc, rc, u))),
|
||||||
let newLow = max(E.O.default(neg_infinity, lc), nu.low);
|
)
|
||||||
let newHigh = min(E.O.default(infinity, rc), nu.high);
|
|
||||||
`Solution(`SymbolicDist(`Uniform({low: newLow, high: newHigh})));
|
|
||||||
| _ => `NoSolution
|
| _ => `NoSolution
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
@ -140,9 +172,7 @@ module Truncate = {
|
||||||
// of a distribution we otherwise wouldn't get at all
|
// of a distribution we otherwise wouldn't get at all
|
||||||
switch (render(evaluationParams, t)) {
|
switch (render(evaluationParams, t)) {
|
||||||
| Ok(`RenderedDist(rs)) =>
|
| Ok(`RenderedDist(rs)) =>
|
||||||
let truncatedShape =
|
Ok(`RenderedDist(Shape.T.truncate(leftCutoff, rightCutoff, rs)))
|
||||||
rs |> Shape.T.truncate(leftCutoff, rightCutoff);
|
|
||||||
Ok(`RenderedDist(truncatedShape));
|
|
||||||
| Error(e) => Error(e)
|
| Error(e) => Error(e)
|
||||||
| _ => Error("Could not truncate distribution.")
|
| _ => Error("Could not truncate distribution.")
|
||||||
};
|
};
|
||||||
|
@ -171,8 +201,7 @@ module Truncate = {
|
||||||
module Normalize = {
|
module Normalize = {
|
||||||
let rec operationToLeaf = (evaluationParams, t: node): result(node, string) => {
|
let rec operationToLeaf = (evaluationParams, t: node): result(node, string) => {
|
||||||
switch (t) {
|
switch (t) {
|
||||||
| `RenderedDist(s) =>
|
| `RenderedDist(s) => Ok(`RenderedDist(Shape.T.normalize(s)))
|
||||||
Ok(`RenderedDist(Shape.T.normalize(s)))
|
|
||||||
| `SymbolicDist(_) => Ok(t)
|
| `SymbolicDist(_) => Ok(t)
|
||||||
| _ => evaluateAndRetry(evaluationParams, operationToLeaf, t)
|
| _ => evaluateAndRetry(evaluationParams, operationToLeaf, t)
|
||||||
};
|
};
|
||||||
|
|
|
@ -29,6 +29,17 @@ module ExpressionTree = {
|
||||||
|
|
||||||
let evaluateAndRetry = (evaluationParams, fn, node) =>
|
let evaluateAndRetry = (evaluationParams, fn, node) =>
|
||||||
node |> evaluationParams.evaluateNode(evaluationParams) |> E.R.bind(_, fn(evaluationParams));
|
node |> evaluationParams.evaluateNode(evaluationParams) |> E.R.bind(_, fn(evaluationParams));
|
||||||
|
|
||||||
|
let renderable = fun
|
||||||
|
| `SymbolicDist(_) => true
|
||||||
|
| `RenderedDist(_) => true
|
||||||
|
| _ => false;
|
||||||
|
|
||||||
|
let mapRenderable = (renderedFn, symFn, item: node) => switch(item) {
|
||||||
|
| `SymbolicDist(s) => Some(symFn(s))
|
||||||
|
| `RenderedDist(r) => Some(renderedFn(r))
|
||||||
|
| _ => None
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
type simplificationResult = [
|
type simplificationResult = [
|
||||||
|
|
|
@ -130,6 +130,11 @@ module Uniform = {
|
||||||
let sample = (t: t) => Jstat.uniform##sample(t.low, t.high);
|
let sample = (t: t) => Jstat.uniform##sample(t.low, t.high);
|
||||||
let mean = (t: t) => Ok(Jstat.uniform##mean(t.low, t.high));
|
let mean = (t: t) => Ok(Jstat.uniform##mean(t.low, t.high));
|
||||||
let toString = ({low, high}: t) => {j|Uniform($low,$high)|j};
|
let toString = ({low, high}: t) => {j|Uniform($low,$high)|j};
|
||||||
|
let truncate = (low, high, t: t): t => {
|
||||||
|
let newLow = max(E.O.default(neg_infinity, low), t.low);
|
||||||
|
let newHigh = min(E.O.default(infinity, high), t.high);
|
||||||
|
{low: newLow, high: newHigh};
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
module Float = {
|
module Float = {
|
||||||
|
@ -178,7 +183,20 @@ module T = {
|
||||||
| `Lognormal(n) => Lognormal.sample(n)
|
| `Lognormal(n) => Lognormal.sample(n)
|
||||||
| `Uniform(n) => Uniform.sample(n)
|
| `Uniform(n) => Uniform.sample(n)
|
||||||
| `Beta(n) => Beta.sample(n)
|
| `Beta(n) => Beta.sample(n)
|
||||||
| `Float(n) => Float.sample(n)
|
| `Float(n) => Float.sample(n);
|
||||||
|
|
||||||
|
let doN = (n, fn) => {
|
||||||
|
let items = Belt.Array.make(n, 0.0);
|
||||||
|
for (x in 0 to n - 1) {
|
||||||
|
let _ = Belt.Array.set(items, x, fn());
|
||||||
|
();
|
||||||
|
};
|
||||||
|
items;
|
||||||
|
};
|
||||||
|
|
||||||
|
let sampleN = (n, dist) => {
|
||||||
|
doN(n, () => sample(dist));
|
||||||
|
};
|
||||||
|
|
||||||
let toString: symbolicDist => string =
|
let toString: symbolicDist => string =
|
||||||
fun
|
fun
|
||||||
|
@ -189,7 +207,7 @@ module T = {
|
||||||
| `Lognormal(n) => Lognormal.toString(n)
|
| `Lognormal(n) => Lognormal.toString(n)
|
||||||
| `Uniform(n) => Uniform.toString(n)
|
| `Uniform(n) => Uniform.toString(n)
|
||||||
| `Beta(n) => Beta.toString(n)
|
| `Beta(n) => Beta.toString(n)
|
||||||
| `Float(n) => Float.toString(n)
|
| `Float(n) => Float.toString(n);
|
||||||
|
|
||||||
let min: symbolicDist => float =
|
let min: symbolicDist => float =
|
||||||
fun
|
fun
|
||||||
|
@ -237,10 +255,10 @@ module T = {
|
||||||
switch (xSelection, dist) {
|
switch (xSelection, dist) {
|
||||||
| (`Linear, _) => E.A.Floats.range(min(dist), max(dist), n)
|
| (`Linear, _) => E.A.Floats.range(min(dist), max(dist), n)
|
||||||
| (`ByWeight, `Uniform(n)) =>
|
| (`ByWeight, `Uniform(n)) =>
|
||||||
// In `ByWeight mode, uniform distributions get special treatment because we need two x's
|
// In `ByWeight mode, uniform distributions get special treatment because we need two x's
|
||||||
// on either side for proper rendering (just left and right of the discontinuities).
|
// on either side for proper rendering (just left and right of the discontinuities).
|
||||||
let dx = 0.00001 *. (n.high -. n.low);
|
let dx = 0.00001 *. (n.high -. n.low);
|
||||||
[|n.low -. dx, n.low +. dx, n.high -. dx, n.high +. dx|];
|
[|n.low -. dx, n.low +. dx, n.high -. dx, n.high +. dx|];
|
||||||
| (`ByWeight, _) =>
|
| (`ByWeight, _) =>
|
||||||
let ys = E.A.Floats.range(minCdfValue, maxCdfValue, n);
|
let ys = E.A.Floats.range(minCdfValue, maxCdfValue, n);
|
||||||
ys |> E.A.fmap(y => inv(y, dist));
|
ys |> E.A.fmap(y => inv(y, dist));
|
||||||
|
@ -277,14 +295,10 @@ module T = {
|
||||||
let toShape = (sampleCount, d: symbolicDist): DistTypes.shape =>
|
let toShape = (sampleCount, d: symbolicDist): DistTypes.shape =>
|
||||||
switch (d) {
|
switch (d) {
|
||||||
| `Float(v) =>
|
| `Float(v) =>
|
||||||
Discrete(
|
Discrete(Discrete.make({xs: [|v|], ys: [|1.0|]}, Some(1.0)))
|
||||||
Discrete.make({xs: [|v|], ys: [|1.0|]}, Some(1.0)),
|
|
||||||
)
|
|
||||||
| _ =>
|
| _ =>
|
||||||
let xs = interpolateXs(~xSelection=`ByWeight, d, sampleCount);
|
let xs = interpolateXs(~xSelection=`ByWeight, d, sampleCount);
|
||||||
let ys = xs |> E.A.fmap(x => pdf(x, d));
|
let ys = xs |> E.A.fmap(x => pdf(x, d));
|
||||||
Continuous(
|
Continuous(Continuous.make(`Linear, {xs, ys}, Some(1.0)));
|
||||||
Continuous.make(`Linear, {xs, ys}, Some(1.0)),
|
|
||||||
);
|
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
|
@ -145,6 +145,12 @@ module R = {
|
||||||
let id = e => e |> result(U.id, U.id);
|
let id = e => e |> result(U.id, U.id);
|
||||||
let fmap = Rationale.Result.fmap;
|
let fmap = Rationale.Result.fmap;
|
||||||
let bind = Rationale.Result.bind;
|
let bind = Rationale.Result.bind;
|
||||||
|
let merge = (a, b) =>
|
||||||
|
switch (a, b) {
|
||||||
|
| (Error(e), _) => Error(e)
|
||||||
|
| (_, Error(e)) => Error(e)
|
||||||
|
| (Ok(a), Ok(b)) => Ok((a, b))
|
||||||
|
};
|
||||||
let toOption = (e: Belt.Result.t('a, 'b)) =>
|
let toOption = (e: Belt.Result.t('a, 'b)) =>
|
||||||
switch (e) {
|
switch (e) {
|
||||||
| Ok(r) => Some(r)
|
| Ok(r) => Some(r)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user