Simple node score refactor for analytical/sampling decisions
This commit is contained in:
parent
ac7b1ee9d5
commit
b72dbab863
|
@ -32,24 +32,19 @@ module AlgebraicCombination = {
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let nodeScore: node => int =
|
||||||
|
fun
|
||||||
|
| `SymbolicDist(`Float(_)) => 1
|
||||||
|
| `SymbolicDist(_) => 1000
|
||||||
|
| `RenderedDist(Discrete(m)) => m.xyShape |> XYShape.T.length
|
||||||
|
| `RenderedDist(Mixed(_)) => 1000
|
||||||
|
| `RenderedDist(Continuous(_)) => 1000
|
||||||
|
| _ => 1000;
|
||||||
|
|
||||||
let choose = (t1: node, t2: node) => {
|
let choose = (t1: node, t2: node) => {
|
||||||
let dLength = (r: DistTypes.discreteShape) =>
|
nodeScore(t1) * nodeScore(t2) > 10000 ? `Sampling : `Analytical;
|
||||||
r.xyShape |> XYShape.T.length;
|
|
||||||
switch (t1, t2) {
|
|
||||||
| (`SymbolicDist(`Float(_)), _)
|
|
||||||
| (_, `SymbolicDist(`Float(_))) => `Analytical
|
|
||||||
| (`RenderedDist(Continuous(_)), `RenderedDist(Continuous(_))) => `Sampling
|
|
||||||
| (`RenderedDist(Discrete(m1)), `RenderedDist(Discrete(m2)))
|
|
||||||
when dLength(m1) * dLength(m2) > 1000 => `Analytical
|
|
||||||
| (`RenderedDist(Discrete(_)), `RenderedDist(Discrete(_))) => `Sampling
|
|
||||||
| (`RenderedDist(Discrete(d)), `SymbolicDist(_))
|
|
||||||
| (`SymbolicDist(_), `RenderedDist(Discrete(d)))
|
|
||||||
| (`RenderedDist(Discrete(d)), `RenderedDist(Continuous(_)))
|
|
||||||
| (`RenderedDist(Continuous(_)), `RenderedDist(Discrete(d))) =>
|
|
||||||
dLength(d) > 10 ? `Sampling : `Analytical
|
|
||||||
| _ => `Sampling
|
|
||||||
};
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let combine =
|
let combine =
|
||||||
(evaluationParams, algebraicOp, t1: node, t2: node)
|
(evaluationParams, algebraicOp, t1: node, t2: node)
|
||||||
: result(node, string) => {
|
: result(node, string) => {
|
||||||
|
@ -159,6 +154,7 @@ module PointwiseCombination = {
|
||||||
let pointwiseMultiply = (evaluationParams: evaluationParams, t1: t, t2: t) => {
|
let pointwiseMultiply = (evaluationParams: evaluationParams, t1: t, t2: t) => {
|
||||||
// TODO: construct a function that we can easily sample from, to construct
|
// TODO: construct a function that we can easily sample from, to construct
|
||||||
// a RenderedDist. Use the xMin and xMax of the rendered shapes to tell the sampling function where to look.
|
// a RenderedDist. Use the xMin and xMax of the rendered shapes to tell the sampling function where to look.
|
||||||
|
// TODO: This should work for symbolic distributions too!
|
||||||
switch (
|
switch (
|
||||||
Render.render(evaluationParams, t1),
|
Render.render(evaluationParams, t1),
|
||||||
Render.render(evaluationParams, t2),
|
Render.render(evaluationParams, t2),
|
||||||
|
@ -189,7 +185,7 @@ module Truncate = {
|
||||||
let trySimplification = (leftCutoff, rightCutoff, t): simplificationResult => {
|
let trySimplification = (leftCutoff, rightCutoff, t): simplificationResult => {
|
||||||
switch (leftCutoff, rightCutoff, t) {
|
switch (leftCutoff, rightCutoff, t) {
|
||||||
| (None, None, t) => `Solution(t)
|
| (None, None, t) => `Solution(t)
|
||||||
| (Some(lc), Some(rc), t) when lc > rc =>
|
| (Some(lc), Some(rc), _) when lc > rc =>
|
||||||
`Error(
|
`Error(
|
||||||
"Left truncation bound must be smaller than right truncation bound.",
|
"Left truncation bound must be smaller than right truncation bound.",
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user