Simple node score refactor for analytical/sampling decisions
This commit is contained in:
parent
ac7b1ee9d5
commit
b72dbab863
|
@ -32,24 +32,19 @@ module AlgebraicCombination = {
|
|||
);
|
||||
};
|
||||
|
||||
let nodeScore: node => int =
|
||||
fun
|
||||
| `SymbolicDist(`Float(_)) => 1
|
||||
| `SymbolicDist(_) => 1000
|
||||
| `RenderedDist(Discrete(m)) => m.xyShape |> XYShape.T.length
|
||||
| `RenderedDist(Mixed(_)) => 1000
|
||||
| `RenderedDist(Continuous(_)) => 1000
|
||||
| _ => 1000;
|
||||
|
||||
let choose = (t1: node, t2: node) => {
|
||||
let dLength = (r: DistTypes.discreteShape) =>
|
||||
r.xyShape |> XYShape.T.length;
|
||||
switch (t1, t2) {
|
||||
| (`SymbolicDist(`Float(_)), _)
|
||||
| (_, `SymbolicDist(`Float(_))) => `Analytical
|
||||
| (`RenderedDist(Continuous(_)), `RenderedDist(Continuous(_))) => `Sampling
|
||||
| (`RenderedDist(Discrete(m1)), `RenderedDist(Discrete(m2)))
|
||||
when dLength(m1) * dLength(m2) > 1000 => `Analytical
|
||||
| (`RenderedDist(Discrete(_)), `RenderedDist(Discrete(_))) => `Sampling
|
||||
| (`RenderedDist(Discrete(d)), `SymbolicDist(_))
|
||||
| (`SymbolicDist(_), `RenderedDist(Discrete(d)))
|
||||
| (`RenderedDist(Discrete(d)), `RenderedDist(Continuous(_)))
|
||||
| (`RenderedDist(Continuous(_)), `RenderedDist(Discrete(d))) =>
|
||||
dLength(d) > 10 ? `Sampling : `Analytical
|
||||
| _ => `Sampling
|
||||
};
|
||||
nodeScore(t1) * nodeScore(t2) > 10000 ? `Sampling : `Analytical;
|
||||
};
|
||||
|
||||
let combine =
|
||||
(evaluationParams, algebraicOp, t1: node, t2: node)
|
||||
: result(node, string) => {
|
||||
|
@ -159,6 +154,7 @@ module PointwiseCombination = {
|
|||
let pointwiseMultiply = (evaluationParams: evaluationParams, t1: t, t2: t) => {
|
||||
// TODO: construct a function that we can easily sample from, to construct
|
||||
// a RenderedDist. Use the xMin and xMax of the rendered shapes to tell the sampling function where to look.
|
||||
// TODO: This should work for symbolic distributions too!
|
||||
switch (
|
||||
Render.render(evaluationParams, t1),
|
||||
Render.render(evaluationParams, t2),
|
||||
|
@ -189,7 +185,7 @@ module Truncate = {
|
|||
let trySimplification = (leftCutoff, rightCutoff, t): simplificationResult => {
|
||||
switch (leftCutoff, rightCutoff, t) {
|
||||
| (None, None, t) => `Solution(t)
|
||||
| (Some(lc), Some(rc), t) when lc > rc =>
|
||||
| (Some(lc), Some(rc), _) when lc > rc =>
|
||||
`Error(
|
||||
"Left truncation bound must be smaller than right truncation bound.",
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue
Block a user