Simple node score refactor for analytical/sampling decisions

This commit is contained in:
Ozzie Gooen 2020-07-19 23:48:37 +01:00
parent ac7b1ee9d5
commit b72dbab863

View File

@ -32,24 +32,19 @@ module AlgebraicCombination = {
);
};
let nodeScore: node => int =
fun
| `SymbolicDist(`Float(_)) => 1
| `SymbolicDist(_) => 1000
| `RenderedDist(Discrete(m)) => m.xyShape |> XYShape.T.length
| `RenderedDist(Mixed(_)) => 1000
| `RenderedDist(Continuous(_)) => 1000
| _ => 1000;
let choose = (t1: node, t2: node) => {
let dLength = (r: DistTypes.discreteShape) =>
r.xyShape |> XYShape.T.length;
switch (t1, t2) {
| (`SymbolicDist(`Float(_)), _)
| (_, `SymbolicDist(`Float(_))) => `Analytical
| (`RenderedDist(Continuous(_)), `RenderedDist(Continuous(_))) => `Sampling
| (`RenderedDist(Discrete(m1)), `RenderedDist(Discrete(m2)))
when dLength(m1) * dLength(m2) > 1000 => `Analytical
| (`RenderedDist(Discrete(_)), `RenderedDist(Discrete(_))) => `Sampling
| (`RenderedDist(Discrete(d)), `SymbolicDist(_))
| (`SymbolicDist(_), `RenderedDist(Discrete(d)))
| (`RenderedDist(Discrete(d)), `RenderedDist(Continuous(_)))
| (`RenderedDist(Continuous(_)), `RenderedDist(Discrete(d))) =>
dLength(d) > 10 ? `Sampling : `Analytical
| _ => `Sampling
};
nodeScore(t1) * nodeScore(t2) > 10000 ? `Sampling : `Analytical;
};
let combine =
(evaluationParams, algebraicOp, t1: node, t2: node)
: result(node, string) => {
@ -159,6 +154,7 @@ module PointwiseCombination = {
let pointwiseMultiply = (evaluationParams: evaluationParams, t1: t, t2: t) => {
// TODO: construct a function that we can easily sample from, to construct
// a RenderedDist. Use the xMin and xMax of the rendered shapes to tell the sampling function where to look.
// TODO: This should work for symbolic distributions too!
switch (
Render.render(evaluationParams, t1),
Render.render(evaluationParams, t2),
@ -189,7 +185,7 @@ module Truncate = {
let trySimplification = (leftCutoff, rightCutoff, t): simplificationResult => {
switch (leftCutoff, rightCutoff, t) {
| (None, None, t) => `Solution(t)
| (Some(lc), Some(rc), t) when lc > rc =>
| (Some(lc), Some(rc), _) when lc > rc =>
`Error(
"Left truncation bound must be smaller than right truncation bound.",
)