Added evaluateAndRetry function

This commit is contained in:
Ozzie Gooen 2020-07-08 13:52:47 +01:00
parent 9d0ecda297
commit 248545ee34
2 changed files with 14 additions and 16 deletions

View File

@ -174,10 +174,7 @@ module Normalize = {
| `RenderedDist(s) => | `RenderedDist(s) =>
Ok(`RenderedDist(Distributions.Shape.T.normalize(s))) Ok(`RenderedDist(Distributions.Shape.T.normalize(s)))
| `SymbolicDist(_) => Ok(t) | `SymbolicDist(_) => Ok(t)
| _ => | _ => evaluateAndRetry(evaluationParams, operationToLeaf, t)
t
|> evaluateNode(evaluationParams)
|> E.R.bind(_, operationToLeaf(evaluationParams))
}; };
}; };
}; };
@ -195,8 +192,9 @@ module FloatFromDist = {
|> (v => Ok(`SymbolicDist(`Float(v)))) |> (v => Ok(`SymbolicDist(`Float(v))))
| _ => | _ =>
t t
|> evaluateNode(evaluationParams) |> evaluateAndRetry(evaluationParams, r =>
|> E.R.bind(_, operationToLeaf(evaluationParams, distToFloatOp)) operationToLeaf(r, distToFloatOp)
)
}; };
}; };
}; };
@ -212,10 +210,7 @@ module Render = {
), ),
) )
| `RenderedDist(_) as t => Ok(t) // already a rendered shape, we're done here | `RenderedDist(_) as t => Ok(t) // already a rendered shape, we're done here
| _ => | _ => evaluateAndRetry(evaluationParams, operationToLeaf, t)
t
|> evaluateNode(evaluationParams)
|> E.R.bind(_, operationToLeaf(evaluationParams))
}; };
}; };
}; };

View File

@ -19,7 +19,7 @@ module ExpressionTree = {
type dist = [ type dist = [
| `SymbolicDist(SymbolicTypes.symbolicDist) | `SymbolicDist(SymbolicTypes.symbolicDist)
| `RenderedDist(DistTypes.shape) | `RenderedDist(DistTypes.shape)
] ];
type evaluationParams = { type evaluationParams = {
sampleCount: int, sampleCount: int,
@ -31,6 +31,9 @@ module ExpressionTree = {
let render = (evaluationParams: evaluationParams, r) => let render = (evaluationParams: evaluationParams, r) =>
evaluateNode(evaluationParams, `Render(r)); evaluateNode(evaluationParams, `Render(r));
let evaluateAndRetry = (evaluationParams, fn, node) =>
node |> evaluationParams.evaluateNode(evaluationParams) |> E.R.bind(_, fn(evaluationParams));
}; };
type simplificationResult = [ type simplificationResult = [