Change back to use combinationByRendering for this PR
This commit is contained in:
parent
1dfea41ab8
commit
f09caaae5e
|
@ -20,7 +20,7 @@ module AlgebraicCombination = {
|
||||||
| _ => Ok(`AlgebraicCombination((operation, t1, t2)))
|
| _ => Ok(`AlgebraicCombination((operation, t1, t2)))
|
||||||
};
|
};
|
||||||
|
|
||||||
let tryCombination = (n, algebraicOp, t1: node, t2: node) => {
|
let combinationBySampling = (n, algebraicOp, t1: node, t2: node) => {
|
||||||
let sampleN =
|
let sampleN =
|
||||||
mapRenderable(Shape.sampleNRendered(n), SymbolicDist.T.sampleN(n));
|
mapRenderable(Shape.sampleNRendered(n), SymbolicDist.T.sampleN(n));
|
||||||
switch (sampleN(t1), sampleN(t2)) {
|
switch (sampleN(t1), sampleN(t2)) {
|
||||||
|
@ -33,24 +33,28 @@ module AlgebraicCombination = {
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
let renderIfNotRendered = (params, t) =>
|
let combinationByRendering =
|
||||||
!renderable(t)
|
(evaluationParams, algebraicOp, t1: node, t2: node)
|
||||||
? switch (render(params, t)) {
|
: result(node, string) => {
|
||||||
| Ok(r) => Ok(r)
|
E.R.merge(
|
||||||
| Error(e) => Error(e)
|
renderAndGetShape(evaluationParams, t1),
|
||||||
}
|
renderAndGetShape(evaluationParams, t2),
|
||||||
: Ok(t);
|
)
|
||||||
|
|> E.R.fmap(((a, b)) =>
|
||||||
|
`RenderedDist(Shape.combineAlgebraically(algebraicOp, a, b))
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
let combineAsShapes =
|
let combineShapesUsingSampling =
|
||||||
(evaluationParams: evaluationParams, algebraicOp, t1: node, t2: node) => {
|
(evaluationParams: evaluationParams, algebraicOp, t1: node, t2: node) => {
|
||||||
let i1 = renderIfNotRendered(evaluationParams, t1);
|
let i1 = renderIfNotRenderable(evaluationParams, t1);
|
||||||
let i2 = renderIfNotRendered(evaluationParams, t2);
|
let i2 = renderIfNotRenderable(evaluationParams, t2);
|
||||||
E.R.merge(i1, i2)
|
E.R.merge(i1, i2)
|
||||||
|> E.R.bind(
|
|> E.R.bind(
|
||||||
_,
|
_,
|
||||||
((a, b)) => {
|
((a, b)) => {
|
||||||
let samples =
|
let samples =
|
||||||
tryCombination(
|
combinationBySampling(
|
||||||
evaluationParams.samplingInputs.sampleCount,
|
evaluationParams.samplingInputs.sampleCount,
|
||||||
algebraicOp,
|
algebraicOp,
|
||||||
a,
|
a,
|
||||||
|
@ -92,7 +96,7 @@ module AlgebraicCombination = {
|
||||||
_,
|
_,
|
||||||
fun
|
fun
|
||||||
| `SymbolicDist(d) as t => Ok(t)
|
| `SymbolicDist(d) as t => Ok(t)
|
||||||
| _ => combineAsShapes(evaluationParams, algebraicOp, t1, t2),
|
| _ => combinationByRendering(evaluationParams, algebraicOp, t1, t2),
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -131,7 +135,16 @@ module PointwiseCombination = {
|
||||||
`RenderedDist(
|
`RenderedDist(
|
||||||
Shape.combinePointwise(
|
Shape.combinePointwise(
|
||||||
~integralSumCachesFn=(a, b) => Some(a +. b),
|
~integralSumCachesFn=(a, b) => Some(a +. b),
|
||||||
~integralCachesFn=(a, b) => Some(Continuous.combinePointwise(~extrapolation=`UseOutermostPoints, (+.), a, b)),
|
~integralCachesFn=
|
||||||
|
(a, b) =>
|
||||||
|
Some(
|
||||||
|
Continuous.combinePointwise(
|
||||||
|
~extrapolation=`UseOutermostPoints,
|
||||||
|
(+.),
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
),
|
||||||
|
),
|
||||||
(+.),
|
(+.),
|
||||||
rs1,
|
rs1,
|
||||||
rs2,
|
rs2,
|
||||||
|
@ -153,7 +166,12 @@ module PointwiseCombination = {
|
||||||
};
|
};
|
||||||
|
|
||||||
let operationToLeaf =
|
let operationToLeaf =
|
||||||
(evaluationParams: evaluationParams, pointwiseOp: pointwiseOperation, t1: t, t2: t) => {
|
(
|
||||||
|
evaluationParams: evaluationParams,
|
||||||
|
pointwiseOp: pointwiseOperation,
|
||||||
|
t1: t,
|
||||||
|
t2: t,
|
||||||
|
) => {
|
||||||
switch (pointwiseOp) {
|
switch (pointwiseOp) {
|
||||||
| `Add => pointwiseAdd(evaluationParams, t1, t2)
|
| `Add => pointwiseAdd(evaluationParams, t1, t2)
|
||||||
| `Multiply => pointwiseMultiply(evaluationParams, t1, t2)
|
| `Multiply => pointwiseMultiply(evaluationParams, t1, t2)
|
||||||
|
|
|
@ -1,7 +1,13 @@
|
||||||
type algebraicOperation = [ | `Add | `Multiply | `Subtract | `Divide];
|
type algebraicOperation = [ | `Add | `Multiply | `Subtract | `Divide];
|
||||||
type pointwiseOperation = [ | `Add | `Multiply];
|
type pointwiseOperation = [ | `Add | `Multiply];
|
||||||
type scaleOperation = [ | `Multiply | `Exponentiate | `Log];
|
type scaleOperation = [ | `Multiply | `Exponentiate | `Log];
|
||||||
type distToFloatOperation = [ | `Pdf(float) | `Cdf(float) | `Inv(float) | `Mean | `Sample];
|
type distToFloatOperation = [
|
||||||
|
| `Pdf(float)
|
||||||
|
| `Cdf(float)
|
||||||
|
| `Inv(float)
|
||||||
|
| `Mean
|
||||||
|
| `Sample
|
||||||
|
];
|
||||||
|
|
||||||
module ExpressionTree = {
|
module ExpressionTree = {
|
||||||
type node = [
|
type node = [
|
||||||
|
@ -34,17 +40,52 @@ module ExpressionTree = {
|
||||||
let render = (evaluationParams: evaluationParams, r) =>
|
let render = (evaluationParams: evaluationParams, r) =>
|
||||||
evaluateNode(evaluationParams, `Render(r));
|
evaluateNode(evaluationParams, `Render(r));
|
||||||
|
|
||||||
let evaluateAndRetry = (evaluationParams, fn, node) =>
|
|
||||||
node
|
|
||||||
|> evaluationParams.evaluateNode(evaluationParams)
|
|
||||||
|> E.R.bind(_, fn(evaluationParams));
|
|
||||||
|
|
||||||
let renderable =
|
let renderable =
|
||||||
fun
|
fun
|
||||||
| `SymbolicDist(_) => true
|
| `SymbolicDist(_) => true
|
||||||
| `RenderedDist(_) => true
|
| `RenderedDist(_) => true
|
||||||
| _ => false;
|
| _ => false;
|
||||||
|
|
||||||
|
let renderIfNotRenderable = (params, t) =>
|
||||||
|
!renderable(t)
|
||||||
|
? switch (render(params, t)) {
|
||||||
|
| Ok(r) => Ok(r)
|
||||||
|
| Error(e) => Error(e)
|
||||||
|
}
|
||||||
|
: Ok(t);
|
||||||
|
|
||||||
|
let renderIfNotRendered = (params, t) =>
|
||||||
|
switch (t) {
|
||||||
|
| `RenderedDist(_) => Ok(t)
|
||||||
|
| _ =>
|
||||||
|
switch (render(params, t)) {
|
||||||
|
| Ok(r) => Ok(r)
|
||||||
|
| Error(e) => Error(e)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let evaluateAndRetry = (evaluationParams, fn, node) =>
|
||||||
|
node
|
||||||
|
|> evaluationParams.evaluateNode(evaluationParams)
|
||||||
|
|> E.R.bind(_, fn(evaluationParams));
|
||||||
|
|
||||||
|
let renderedShape = (item: node) =>
|
||||||
|
switch (item) {
|
||||||
|
| `RenderedDist(r) => Some(r)
|
||||||
|
| _ => None
|
||||||
|
};
|
||||||
|
|
||||||
|
let renderAndGetShape = (params, t) =>
|
||||||
|
switch (renderIfNotRendered(params, t)) {
|
||||||
|
| Ok(`RenderedDist(r)) => Ok(r)
|
||||||
|
| Error(r) =>
|
||||||
|
Js.log(r);
|
||||||
|
Error(r);
|
||||||
|
| Ok(l) =>
|
||||||
|
Js.log(l);
|
||||||
|
Error("Did not render as requested");
|
||||||
|
};
|
||||||
|
|
||||||
let mapRenderable = (renderedFn, symFn, item: node) =>
|
let mapRenderable = (renderedFn, symFn, item: node) =>
|
||||||
switch (item) {
|
switch (item) {
|
||||||
| `SymbolicDist(s) => Some(symFn(s))
|
| `SymbolicDist(s) => Some(symFn(s))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user