Change back to use combinationByRendering for this PR

This commit is contained in:
Ozzie Gooen 2020-07-17 14:14:26 +01:00
parent 1dfea41ab8
commit f09caaae5e
2 changed files with 80 additions and 21 deletions

View File

@ -20,7 +20,7 @@ module AlgebraicCombination = {
| _ => Ok(`AlgebraicCombination((operation, t1, t2))) | _ => Ok(`AlgebraicCombination((operation, t1, t2)))
}; };
let tryCombination = (n, algebraicOp, t1: node, t2: node) => { let combinationBySampling = (n, algebraicOp, t1: node, t2: node) => {
let sampleN = let sampleN =
mapRenderable(Shape.sampleNRendered(n), SymbolicDist.T.sampleN(n)); mapRenderable(Shape.sampleNRendered(n), SymbolicDist.T.sampleN(n));
switch (sampleN(t1), sampleN(t2)) { switch (sampleN(t1), sampleN(t2)) {
@ -33,24 +33,28 @@ module AlgebraicCombination = {
}; };
}; };
let renderIfNotRendered = (params, t) => let combinationByRendering =
!renderable(t) (evaluationParams, algebraicOp, t1: node, t2: node)
? switch (render(params, t)) { : result(node, string) => {
| Ok(r) => Ok(r) E.R.merge(
| Error(e) => Error(e) renderAndGetShape(evaluationParams, t1),
} renderAndGetShape(evaluationParams, t2),
: Ok(t); )
|> E.R.fmap(((a, b)) =>
`RenderedDist(Shape.combineAlgebraically(algebraicOp, a, b))
);
};
let combineAsShapes = let combineShapesUsingSampling =
(evaluationParams: evaluationParams, algebraicOp, t1: node, t2: node) => { (evaluationParams: evaluationParams, algebraicOp, t1: node, t2: node) => {
let i1 = renderIfNotRendered(evaluationParams, t1); let i1 = renderIfNotRenderable(evaluationParams, t1);
let i2 = renderIfNotRendered(evaluationParams, t2); let i2 = renderIfNotRenderable(evaluationParams, t2);
E.R.merge(i1, i2) E.R.merge(i1, i2)
|> E.R.bind( |> E.R.bind(
_, _,
((a, b)) => { ((a, b)) => {
let samples = let samples =
tryCombination( combinationBySampling(
evaluationParams.samplingInputs.sampleCount, evaluationParams.samplingInputs.sampleCount,
algebraicOp, algebraicOp,
a, a,
@ -92,7 +96,7 @@ module AlgebraicCombination = {
_, _,
fun fun
| `SymbolicDist(d) as t => Ok(t) | `SymbolicDist(d) as t => Ok(t)
| _ => combineAsShapes(evaluationParams, algebraicOp, t1, t2), | _ => combinationByRendering(evaluationParams, algebraicOp, t1, t2),
); );
}; };
@ -131,7 +135,16 @@ module PointwiseCombination = {
`RenderedDist( `RenderedDist(
Shape.combinePointwise( Shape.combinePointwise(
~integralSumCachesFn=(a, b) => Some(a +. b), ~integralSumCachesFn=(a, b) => Some(a +. b),
~integralCachesFn=(a, b) => Some(Continuous.combinePointwise(~extrapolation=`UseOutermostPoints, (+.), a, b)), ~integralCachesFn=
(a, b) =>
Some(
Continuous.combinePointwise(
~extrapolation=`UseOutermostPoints,
(+.),
a,
b,
),
),
(+.), (+.),
rs1, rs1,
rs2, rs2,
@ -153,7 +166,12 @@ module PointwiseCombination = {
}; };
let operationToLeaf = let operationToLeaf =
(evaluationParams: evaluationParams, pointwiseOp: pointwiseOperation, t1: t, t2: t) => { (
evaluationParams: evaluationParams,
pointwiseOp: pointwiseOperation,
t1: t,
t2: t,
) => {
switch (pointwiseOp) { switch (pointwiseOp) {
| `Add => pointwiseAdd(evaluationParams, t1, t2) | `Add => pointwiseAdd(evaluationParams, t1, t2)
| `Multiply => pointwiseMultiply(evaluationParams, t1, t2) | `Multiply => pointwiseMultiply(evaluationParams, t1, t2)

View File

@ -1,7 +1,13 @@
type algebraicOperation = [ | `Add | `Multiply | `Subtract | `Divide]; type algebraicOperation = [ | `Add | `Multiply | `Subtract | `Divide];
type pointwiseOperation = [ | `Add | `Multiply]; type pointwiseOperation = [ | `Add | `Multiply];
type scaleOperation = [ | `Multiply | `Exponentiate | `Log]; type scaleOperation = [ | `Multiply | `Exponentiate | `Log];
type distToFloatOperation = [ | `Pdf(float) | `Cdf(float) | `Inv(float) | `Mean | `Sample]; type distToFloatOperation = [
| `Pdf(float)
| `Cdf(float)
| `Inv(float)
| `Mean
| `Sample
];
module ExpressionTree = { module ExpressionTree = {
type node = [ type node = [
@ -34,17 +40,52 @@ module ExpressionTree = {
let render = (evaluationParams: evaluationParams, r) => let render = (evaluationParams: evaluationParams, r) =>
evaluateNode(evaluationParams, `Render(r)); evaluateNode(evaluationParams, `Render(r));
let evaluateAndRetry = (evaluationParams, fn, node) =>
node
|> evaluationParams.evaluateNode(evaluationParams)
|> E.R.bind(_, fn(evaluationParams));
let renderable = let renderable =
fun fun
| `SymbolicDist(_) => true | `SymbolicDist(_) => true
| `RenderedDist(_) => true | `RenderedDist(_) => true
| _ => false; | _ => false;
let renderIfNotRenderable = (params, t) =>
!renderable(t)
? switch (render(params, t)) {
| Ok(r) => Ok(r)
| Error(e) => Error(e)
}
: Ok(t);
let renderIfNotRendered = (params, t) =>
switch (t) {
| `RenderedDist(_) => Ok(t)
| _ =>
switch (render(params, t)) {
| Ok(r) => Ok(r)
| Error(e) => Error(e)
}
};
let evaluateAndRetry = (evaluationParams, fn, node) =>
node
|> evaluationParams.evaluateNode(evaluationParams)
|> E.R.bind(_, fn(evaluationParams));
let renderedShape = (item: node) =>
switch (item) {
| `RenderedDist(r) => Some(r)
| _ => None
};
let renderAndGetShape = (params, t) =>
switch (renderIfNotRendered(params, t)) {
| Ok(`RenderedDist(r)) => Ok(r)
| Error(r) =>
Js.log(r);
Error(r);
| Ok(l) =>
Js.log(l);
Error("Did not render as requested");
};
let mapRenderable = (renderedFn, symFn, item: node) => let mapRenderable = (renderedFn, symFn, item: node) =>
switch (item) { switch (item) {
| `SymbolicDist(s) => Some(symFn(s)) | `SymbolicDist(s) => Some(symFn(s))