Moved data to evaluationParams

This commit is contained in:
Ozzie Gooen 2020-07-08 11:39:03 +01:00
parent da52444e8e
commit 9d0ecda297
3 changed files with 72 additions and 58 deletions

View File

@ -3,7 +3,7 @@ open ExpressionTypes.ExpressionTree;
let toShape = (sampleCount: int, node: node) => { let toShape = (sampleCount: int, node: node) => {
let renderResult = let renderResult =
`Render(`Normalize(node)) `Render(`Normalize(node))
|> ExpressionTreeEvaluator.toLeaf({sampleCount: sampleCount}); |> ExpressionTreeEvaluator.toLeaf({sampleCount: sampleCount, evaluateNode: ExpressionTreeEvaluator.toLeaf});
switch (renderResult) { switch (renderResult) {
| Ok(`RenderedDist(rs)) => | Ok(`RenderedDist(rs)) =>

View File

@ -22,8 +22,9 @@ module AlgebraicCombination = {
| _ => Ok(`AlgebraicCombination((operation, t1, t2))) | _ => Ok(`AlgebraicCombination((operation, t1, t2)))
}; };
let combineAsShapes = (toLeaf, renderParams, algebraicOp, t1, t2) => { let combineAsShapes =
let renderShape = r => toLeaf(renderParams, `Render(r)); (evaluationParams: evaluationParams, algebraicOp, t1, t2) => {
let renderShape = render(evaluationParams);
switch (renderShape(t1), renderShape(t2)) { switch (renderShape(t1), renderShape(t2)) {
| (Ok(`RenderedDist(s1)), Ok(`RenderedDist(s2))) => | (Ok(`RenderedDist(s1)), Ok(`RenderedDist(s2))) =>
Ok( Ok(
@ -39,8 +40,7 @@ module AlgebraicCombination = {
let operationToLeaf = let operationToLeaf =
( (
toLeaf, evaluationParams: evaluationParams,
renderParams: renderParams,
algebraicOp: ExpressionTypes.algebraicOperation, algebraicOp: ExpressionTypes.algebraicOperation,
t1: t, t1: t,
t2: t, t2: t,
@ -52,16 +52,17 @@ module AlgebraicCombination = {
_, _,
fun fun
| `SymbolicDist(d) as t => Ok(t) | `SymbolicDist(d) as t => Ok(t)
| _ => combineAsShapes(toLeaf, renderParams, algebraicOp, t1, t2), | _ => combineAsShapes(evaluationParams, algebraicOp, t1, t2),
); );
}; };
module VerticalScaling = { module VerticalScaling = {
let operationToLeaf = (toLeaf, renderParams, scaleOp, t, scaleBy) => { let operationToLeaf =
(evaluationParams: evaluationParams, scaleOp, t, scaleBy) => {
// scaleBy has to be a single float, otherwise we'll return an error. // scaleBy has to be a single float, otherwise we'll return an error.
let fn = Operation.Scale.toFn(scaleOp); let fn = Operation.Scale.toFn(scaleOp);
let knownIntegralSumFn = Operation.Scale.toKnownIntegralSumFn(scaleOp); let knownIntegralSumFn = Operation.Scale.toKnownIntegralSumFn(scaleOp);
let renderedShape = toLeaf(renderParams, `Render(t)); let renderedShape = render(evaluationParams, t);
switch (renderedShape, scaleBy) { switch (renderedShape, scaleBy) {
| (Ok(`RenderedDist(rs)), `SymbolicDist(`Float(sm))) => | (Ok(`RenderedDist(rs)), `SymbolicDist(`Float(sm))) =>
@ -81,9 +82,8 @@ module VerticalScaling = {
}; };
module PointwiseCombination = { module PointwiseCombination = {
let pointwiseAdd = (toLeaf, renderParams, t1, t2) => { let pointwiseAdd = (evaluationParams: evaluationParams, t1, t2) => {
let renderShape = r => toLeaf(renderParams, `Render(r)); switch (render(evaluationParams, t1), render(evaluationParams, t2)) {
switch (renderShape(t1), renderShape(t2)) {
| (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) => | (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) =>
Ok( Ok(
`RenderedDist( `RenderedDist(
@ -101,7 +101,7 @@ module PointwiseCombination = {
}; };
}; };
let pointwiseMultiply = (toLeaf, renderParams, t1, t2) => { let pointwiseMultiply = (evaluationParams: evaluationParams, t1, t2) => {
// TODO: construct a function that we can easily sample from, to construct // TODO: construct a function that we can easily sample from, to construct
// a RenderedDist. Use the xMin and xMax of the rendered shapes to tell the sampling function where to look. // a RenderedDist. Use the xMin and xMax of the rendered shapes to tell the sampling function where to look.
Error( Error(
@ -109,10 +109,11 @@ module PointwiseCombination = {
); );
}; };
let operationToLeaf = (toLeaf, renderParams, pointwiseOp, t1, t2) => { let operationToLeaf =
(evaluationParams: evaluationParams, pointwiseOp, t1, t2) => {
switch (pointwiseOp) { switch (pointwiseOp) {
| `Add => pointwiseAdd(toLeaf, renderParams, t1, t2) | `Add => pointwiseAdd(evaluationParams, t1, t2)
| `Multiply => pointwiseMultiply(toLeaf, renderParams, t1, t2) | `Multiply => pointwiseMultiply(evaluationParams, t1, t2)
}; };
}; };
}; };
@ -133,24 +134,23 @@ module Truncate = {
}; };
}; };
let truncateAsShape = (toLeaf, renderParams, leftCutoff, rightCutoff, t) => { let truncateAsShape =
(evaluationParams: evaluationParams, leftCutoff, rightCutoff, t) => {
// TODO: use named args in renderToShape; if we're lucky we can at least get the tail // TODO: use named args in renderToShape; if we're lucky we can at least get the tail
// of a distribution we otherwise wouldn't get at all // of a distribution we otherwise wouldn't get at all
let renderedShape = toLeaf(renderParams, `Render(t)); switch (render(evaluationParams, t)) {
switch (renderedShape) {
| Ok(`RenderedDist(rs)) => | Ok(`RenderedDist(rs)) =>
let truncatedShape = let truncatedShape =
rs |> Distributions.Shape.T.truncate(leftCutoff, rightCutoff); rs |> Distributions.Shape.T.truncate(leftCutoff, rightCutoff);
Ok(`RenderedDist(truncatedShape)); Ok(`RenderedDist(truncatedShape));
| Error(e1) => Error(e1) | Error(e) => Error(e)
| _ => Error("Could not truncate distribution.") | _ => Error("Could not truncate distribution.")
}; };
}; };
let operationToLeaf = let operationToLeaf =
( (
toLeaf, evaluationParams,
renderParams,
leftCutoff: option(float), leftCutoff: option(float),
rightCutoff: option(float), rightCutoff: option(float),
t: node, t: node,
@ -163,62 +163,59 @@ module Truncate = {
| `Solution(t) => Ok(t) | `Solution(t) => Ok(t)
| `Error(e) => Error(e) | `Error(e) => Error(e)
| `NoSolution => | `NoSolution =>
truncateAsShape(toLeaf, renderParams, leftCutoff, rightCutoff, t) truncateAsShape(evaluationParams, leftCutoff, rightCutoff, t)
); );
}; };
}; };
module Normalize = { module Normalize = {
let rec operationToLeaf = let rec operationToLeaf = (evaluationParams, t: node): result(node, string) => {
(toLeaf, renderParams, t: node): result(node, string) => {
switch (t) { switch (t) {
| `RenderedDist(s) => | `RenderedDist(s) =>
Ok(`RenderedDist(Distributions.Shape.T.normalize(s))) Ok(`RenderedDist(Distributions.Shape.T.normalize(s)))
| `SymbolicDist(_) => Ok(t) | `SymbolicDist(_) => Ok(t)
| _ => | _ =>
t t
|> toLeaf(renderParams) |> evaluateNode(evaluationParams)
|> E.R.bind(_, operationToLeaf(toLeaf, renderParams)) |> E.R.bind(_, operationToLeaf(evaluationParams))
}; };
}; };
}; };
module FloatFromDist = { module FloatFromDist = {
let symbolicToLeaf = (distToFloatOp: distToFloatOperation, s) => {
SymbolicDist.T.operate(distToFloatOp, s)
|> E.R.bind(_, v => Ok(`SymbolicDist(`Float(v))));
};
let renderedToLeaf =
(distToFloatOp: distToFloatOperation, rs: DistTypes.shape)
: result(node, string) => {
Distributions.Shape.operate(distToFloatOp, rs)
|> (v => Ok(`SymbolicDist(`Float(v))));
};
let rec operationToLeaf = let rec operationToLeaf =
(toLeaf, renderParams, distToFloatOp: distToFloatOperation, t: node) (evaluationParams, distToFloatOp: distToFloatOperation, t: node)
: result(node, string) => { : result(node, string) => {
switch (t) { switch (t) {
| `SymbolicDist(s) => symbolicToLeaf(distToFloatOp, s) | `SymbolicDist(s) =>
| `RenderedDist(rs) => renderedToLeaf(distToFloatOp, rs) SymbolicDist.T.operate(distToFloatOp, s)
|> E.R.bind(_, v => Ok(`SymbolicDist(`Float(v))))
| `RenderedDist(rs) =>
Distributions.Shape.operate(distToFloatOp, rs)
|> (v => Ok(`SymbolicDist(`Float(v))))
| _ => | _ =>
t t
|> toLeaf(renderParams) |> evaluateNode(evaluationParams)
|> E.R.bind(_, operationToLeaf(toLeaf, renderParams, distToFloatOp)) |> E.R.bind(_, operationToLeaf(evaluationParams, distToFloatOp))
}; };
}; };
}; };
module Render = { module Render = {
let rec operationToLeaf = let rec operationToLeaf =
(toLeaf, renderParams, t: node): result(t, string) => { (evaluationParams: evaluationParams, t: node): result(t, string) => {
switch (t) { switch (t) {
| `SymbolicDist(d) => | `SymbolicDist(d) =>
Ok(`RenderedDist(SymbolicDist.T.toShape(renderParams.sampleCount, d))) Ok(
`RenderedDist(
SymbolicDist.T.toShape(evaluationParams.sampleCount, d),
),
)
| `RenderedDist(_) as t => Ok(t) // already a rendered shape, we're done here | `RenderedDist(_) as t => Ok(t) // already a rendered shape, we're done here
| _ => | _ =>
t t
|> toLeaf(renderParams) |> evaluateNode(evaluationParams)
|> E.R.bind(_, operationToLeaf(toLeaf, renderParams)) |> E.R.bind(_, operationToLeaf(evaluationParams))
}; };
}; };
}; };
@ -229,35 +226,38 @@ module Render = {
but most often it will produce a RenderedDist. but most often it will produce a RenderedDist.
This function is used mainly to turn a parse tree into a single RenderedDist This function is used mainly to turn a parse tree into a single RenderedDist
that can then be displayed to the user. */ that can then be displayed to the user. */
let rec toLeaf = (renderParams, node: t): result(t, string) => { let toLeaf =
(
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
node: t,
)
: result(t, string) => {
switch (node) { switch (node) {
// Leaf nodes just stay leaf nodes // Leaf nodes just stay leaf nodes
| `SymbolicDist(_) | `SymbolicDist(_)
| `RenderedDist(_) => Ok(node) | `RenderedDist(_) => Ok(node)
// Operations need to be turned into leaves // Operations nevaluationParamsd to be turned into leaves
| `AlgebraicCombination(algebraicOp, t1, t2) => | `AlgebraicCombination(algebraicOp, t1, t2) =>
AlgebraicCombination.operationToLeaf( AlgebraicCombination.operationToLeaf(
toLeaf, evaluationParams,
renderParams,
algebraicOp, algebraicOp,
t1, t1,
t2, t2,
) )
| `PointwiseCombination(pointwiseOp, t1, t2) => | `PointwiseCombination(pointwiseOp, t1, t2) =>
PointwiseCombination.operationToLeaf( PointwiseCombination.operationToLeaf(
toLeaf, evaluationParams,
renderParams,
pointwiseOp, pointwiseOp,
t1, t1,
t2, t2,
) )
| `VerticalScaling(scaleOp, t, scaleBy) => | `VerticalScaling(scaleOp, t, scaleBy) =>
VerticalScaling.operationToLeaf(toLeaf, renderParams, scaleOp, t, scaleBy) VerticalScaling.operationToLeaf(evaluationParams, scaleOp, t, scaleBy)
| `Truncate(leftCutoff, rightCutoff, t) => | `Truncate(leftCutoff, rightCutoff, t) =>
Truncate.operationToLeaf(toLeaf, renderParams, leftCutoff, rightCutoff, t) Truncate.operationToLeaf(evaluationParams, leftCutoff, rightCutoff, t)
| `FloatFromDist(distToFloatOp, t) => | `FloatFromDist(distToFloatOp, t) =>
FloatFromDist.operationToLeaf(toLeaf, renderParams, distToFloatOp, t) FloatFromDist.operationToLeaf(evaluationParams, distToFloatOp, t)
| `Normalize(t) => Normalize.operationToLeaf(toLeaf, renderParams, t) | `Normalize(t) => Normalize.operationToLeaf(evaluationParams, t)
| `Render(t) => Render.operationToLeaf(toLeaf, renderParams, t) | `Render(t) => Render.operationToLeaf(evaluationParams, t)
}; };
}; };

View File

@ -5,10 +5,8 @@ type distToFloatOperation = [ | `Pdf(float) | `Inv(float) | `Mean | `Sample];
module ExpressionTree = { module ExpressionTree = {
type node = [ type node = [
// leaf nodes:
| `SymbolicDist(SymbolicTypes.symbolicDist) | `SymbolicDist(SymbolicTypes.symbolicDist)
| `RenderedDist(DistTypes.shape) | `RenderedDist(DistTypes.shape)
// operations:
| `AlgebraicCombination(algebraicOperation, node, node) | `AlgebraicCombination(algebraicOperation, node, node)
| `PointwiseCombination(pointwiseOperation, node, node) | `PointwiseCombination(pointwiseOperation, node, node)
| `VerticalScaling(scaleOperation, node, node) | `VerticalScaling(scaleOperation, node, node)
@ -17,6 +15,22 @@ module ExpressionTree = {
| `Normalize(node) | `Normalize(node)
| `FloatFromDist(distToFloatOperation, node) | `FloatFromDist(distToFloatOperation, node)
]; ];
type dist = [
| `SymbolicDist(SymbolicTypes.symbolicDist)
| `RenderedDist(DistTypes.shape)
]
type evaluationParams = {
sampleCount: int,
evaluateNode: (evaluationParams, node) => Belt.Result.t(node, string),
};
let evaluateNode = (evaluationParams: evaluationParams) =>
evaluationParams.evaluateNode(evaluationParams);
let render = (evaluationParams: evaluationParams, r) =>
evaluateNode(evaluationParams, `Render(r));
}; };
type simplificationResult = [ type simplificationResult = [