Moved data to evaluationParams
This commit is contained in:
parent
da52444e8e
commit
9d0ecda297
|
@ -3,7 +3,7 @@ open ExpressionTypes.ExpressionTree;
|
||||||
let toShape = (sampleCount: int, node: node) => {
|
let toShape = (sampleCount: int, node: node) => {
|
||||||
let renderResult =
|
let renderResult =
|
||||||
`Render(`Normalize(node))
|
`Render(`Normalize(node))
|
||||||
|> ExpressionTreeEvaluator.toLeaf({sampleCount: sampleCount});
|
|> ExpressionTreeEvaluator.toLeaf({sampleCount: sampleCount, evaluateNode: ExpressionTreeEvaluator.toLeaf});
|
||||||
|
|
||||||
switch (renderResult) {
|
switch (renderResult) {
|
||||||
| Ok(`RenderedDist(rs)) =>
|
| Ok(`RenderedDist(rs)) =>
|
||||||
|
|
|
@ -22,8 +22,9 @@ module AlgebraicCombination = {
|
||||||
| _ => Ok(`AlgebraicCombination((operation, t1, t2)))
|
| _ => Ok(`AlgebraicCombination((operation, t1, t2)))
|
||||||
};
|
};
|
||||||
|
|
||||||
let combineAsShapes = (toLeaf, renderParams, algebraicOp, t1, t2) => {
|
let combineAsShapes =
|
||||||
let renderShape = r => toLeaf(renderParams, `Render(r));
|
(evaluationParams: evaluationParams, algebraicOp, t1, t2) => {
|
||||||
|
let renderShape = render(evaluationParams);
|
||||||
switch (renderShape(t1), renderShape(t2)) {
|
switch (renderShape(t1), renderShape(t2)) {
|
||||||
| (Ok(`RenderedDist(s1)), Ok(`RenderedDist(s2))) =>
|
| (Ok(`RenderedDist(s1)), Ok(`RenderedDist(s2))) =>
|
||||||
Ok(
|
Ok(
|
||||||
|
@ -39,8 +40,7 @@ module AlgebraicCombination = {
|
||||||
|
|
||||||
let operationToLeaf =
|
let operationToLeaf =
|
||||||
(
|
(
|
||||||
toLeaf,
|
evaluationParams: evaluationParams,
|
||||||
renderParams: renderParams,
|
|
||||||
algebraicOp: ExpressionTypes.algebraicOperation,
|
algebraicOp: ExpressionTypes.algebraicOperation,
|
||||||
t1: t,
|
t1: t,
|
||||||
t2: t,
|
t2: t,
|
||||||
|
@ -52,16 +52,17 @@ module AlgebraicCombination = {
|
||||||
_,
|
_,
|
||||||
fun
|
fun
|
||||||
| `SymbolicDist(d) as t => Ok(t)
|
| `SymbolicDist(d) as t => Ok(t)
|
||||||
| _ => combineAsShapes(toLeaf, renderParams, algebraicOp, t1, t2),
|
| _ => combineAsShapes(evaluationParams, algebraicOp, t1, t2),
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
module VerticalScaling = {
|
module VerticalScaling = {
|
||||||
let operationToLeaf = (toLeaf, renderParams, scaleOp, t, scaleBy) => {
|
let operationToLeaf =
|
||||||
|
(evaluationParams: evaluationParams, scaleOp, t, scaleBy) => {
|
||||||
// scaleBy has to be a single float, otherwise we'll return an error.
|
// scaleBy has to be a single float, otherwise we'll return an error.
|
||||||
let fn = Operation.Scale.toFn(scaleOp);
|
let fn = Operation.Scale.toFn(scaleOp);
|
||||||
let knownIntegralSumFn = Operation.Scale.toKnownIntegralSumFn(scaleOp);
|
let knownIntegralSumFn = Operation.Scale.toKnownIntegralSumFn(scaleOp);
|
||||||
let renderedShape = toLeaf(renderParams, `Render(t));
|
let renderedShape = render(evaluationParams, t);
|
||||||
|
|
||||||
switch (renderedShape, scaleBy) {
|
switch (renderedShape, scaleBy) {
|
||||||
| (Ok(`RenderedDist(rs)), `SymbolicDist(`Float(sm))) =>
|
| (Ok(`RenderedDist(rs)), `SymbolicDist(`Float(sm))) =>
|
||||||
|
@ -81,9 +82,8 @@ module VerticalScaling = {
|
||||||
};
|
};
|
||||||
|
|
||||||
module PointwiseCombination = {
|
module PointwiseCombination = {
|
||||||
let pointwiseAdd = (toLeaf, renderParams, t1, t2) => {
|
let pointwiseAdd = (evaluationParams: evaluationParams, t1, t2) => {
|
||||||
let renderShape = r => toLeaf(renderParams, `Render(r));
|
switch (render(evaluationParams, t1), render(evaluationParams, t2)) {
|
||||||
switch (renderShape(t1), renderShape(t2)) {
|
|
||||||
| (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) =>
|
| (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) =>
|
||||||
Ok(
|
Ok(
|
||||||
`RenderedDist(
|
`RenderedDist(
|
||||||
|
@ -101,7 +101,7 @@ module PointwiseCombination = {
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
let pointwiseMultiply = (toLeaf, renderParams, t1, t2) => {
|
let pointwiseMultiply = (evaluationParams: evaluationParams, t1, t2) => {
|
||||||
// TODO: construct a function that we can easily sample from, to construct
|
// TODO: construct a function that we can easily sample from, to construct
|
||||||
// a RenderedDist. Use the xMin and xMax of the rendered shapes to tell the sampling function where to look.
|
// a RenderedDist. Use the xMin and xMax of the rendered shapes to tell the sampling function where to look.
|
||||||
Error(
|
Error(
|
||||||
|
@ -109,10 +109,11 @@ module PointwiseCombination = {
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
let operationToLeaf = (toLeaf, renderParams, pointwiseOp, t1, t2) => {
|
let operationToLeaf =
|
||||||
|
(evaluationParams: evaluationParams, pointwiseOp, t1, t2) => {
|
||||||
switch (pointwiseOp) {
|
switch (pointwiseOp) {
|
||||||
| `Add => pointwiseAdd(toLeaf, renderParams, t1, t2)
|
| `Add => pointwiseAdd(evaluationParams, t1, t2)
|
||||||
| `Multiply => pointwiseMultiply(toLeaf, renderParams, t1, t2)
|
| `Multiply => pointwiseMultiply(evaluationParams, t1, t2)
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
@ -133,24 +134,23 @@ module Truncate = {
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
let truncateAsShape = (toLeaf, renderParams, leftCutoff, rightCutoff, t) => {
|
let truncateAsShape =
|
||||||
|
(evaluationParams: evaluationParams, leftCutoff, rightCutoff, t) => {
|
||||||
// TODO: use named args in renderToShape; if we're lucky we can at least get the tail
|
// TODO: use named args in renderToShape; if we're lucky we can at least get the tail
|
||||||
// of a distribution we otherwise wouldn't get at all
|
// of a distribution we otherwise wouldn't get at all
|
||||||
let renderedShape = toLeaf(renderParams, `Render(t));
|
switch (render(evaluationParams, t)) {
|
||||||
switch (renderedShape) {
|
|
||||||
| Ok(`RenderedDist(rs)) =>
|
| Ok(`RenderedDist(rs)) =>
|
||||||
let truncatedShape =
|
let truncatedShape =
|
||||||
rs |> Distributions.Shape.T.truncate(leftCutoff, rightCutoff);
|
rs |> Distributions.Shape.T.truncate(leftCutoff, rightCutoff);
|
||||||
Ok(`RenderedDist(truncatedShape));
|
Ok(`RenderedDist(truncatedShape));
|
||||||
| Error(e1) => Error(e1)
|
| Error(e) => Error(e)
|
||||||
| _ => Error("Could not truncate distribution.")
|
| _ => Error("Could not truncate distribution.")
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
let operationToLeaf =
|
let operationToLeaf =
|
||||||
(
|
(
|
||||||
toLeaf,
|
evaluationParams,
|
||||||
renderParams,
|
|
||||||
leftCutoff: option(float),
|
leftCutoff: option(float),
|
||||||
rightCutoff: option(float),
|
rightCutoff: option(float),
|
||||||
t: node,
|
t: node,
|
||||||
|
@ -163,62 +163,59 @@ module Truncate = {
|
||||||
| `Solution(t) => Ok(t)
|
| `Solution(t) => Ok(t)
|
||||||
| `Error(e) => Error(e)
|
| `Error(e) => Error(e)
|
||||||
| `NoSolution =>
|
| `NoSolution =>
|
||||||
truncateAsShape(toLeaf, renderParams, leftCutoff, rightCutoff, t)
|
truncateAsShape(evaluationParams, leftCutoff, rightCutoff, t)
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
module Normalize = {
|
module Normalize = {
|
||||||
let rec operationToLeaf =
|
let rec operationToLeaf = (evaluationParams, t: node): result(node, string) => {
|
||||||
(toLeaf, renderParams, t: node): result(node, string) => {
|
|
||||||
switch (t) {
|
switch (t) {
|
||||||
| `RenderedDist(s) =>
|
| `RenderedDist(s) =>
|
||||||
Ok(`RenderedDist(Distributions.Shape.T.normalize(s)))
|
Ok(`RenderedDist(Distributions.Shape.T.normalize(s)))
|
||||||
| `SymbolicDist(_) => Ok(t)
|
| `SymbolicDist(_) => Ok(t)
|
||||||
| _ =>
|
| _ =>
|
||||||
t
|
t
|
||||||
|> toLeaf(renderParams)
|
|> evaluateNode(evaluationParams)
|
||||||
|> E.R.bind(_, operationToLeaf(toLeaf, renderParams))
|
|> E.R.bind(_, operationToLeaf(evaluationParams))
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
module FloatFromDist = {
|
module FloatFromDist = {
|
||||||
let symbolicToLeaf = (distToFloatOp: distToFloatOperation, s) => {
|
|
||||||
SymbolicDist.T.operate(distToFloatOp, s)
|
|
||||||
|> E.R.bind(_, v => Ok(`SymbolicDist(`Float(v))));
|
|
||||||
};
|
|
||||||
let renderedToLeaf =
|
|
||||||
(distToFloatOp: distToFloatOperation, rs: DistTypes.shape)
|
|
||||||
: result(node, string) => {
|
|
||||||
Distributions.Shape.operate(distToFloatOp, rs)
|
|
||||||
|> (v => Ok(`SymbolicDist(`Float(v))));
|
|
||||||
};
|
|
||||||
let rec operationToLeaf =
|
let rec operationToLeaf =
|
||||||
(toLeaf, renderParams, distToFloatOp: distToFloatOperation, t: node)
|
(evaluationParams, distToFloatOp: distToFloatOperation, t: node)
|
||||||
: result(node, string) => {
|
: result(node, string) => {
|
||||||
switch (t) {
|
switch (t) {
|
||||||
| `SymbolicDist(s) => symbolicToLeaf(distToFloatOp, s)
|
| `SymbolicDist(s) =>
|
||||||
| `RenderedDist(rs) => renderedToLeaf(distToFloatOp, rs)
|
SymbolicDist.T.operate(distToFloatOp, s)
|
||||||
|
|> E.R.bind(_, v => Ok(`SymbolicDist(`Float(v))))
|
||||||
|
| `RenderedDist(rs) =>
|
||||||
|
Distributions.Shape.operate(distToFloatOp, rs)
|
||||||
|
|> (v => Ok(`SymbolicDist(`Float(v))))
|
||||||
| _ =>
|
| _ =>
|
||||||
t
|
t
|
||||||
|> toLeaf(renderParams)
|
|> evaluateNode(evaluationParams)
|
||||||
|> E.R.bind(_, operationToLeaf(toLeaf, renderParams, distToFloatOp))
|
|> E.R.bind(_, operationToLeaf(evaluationParams, distToFloatOp))
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
module Render = {
|
module Render = {
|
||||||
let rec operationToLeaf =
|
let rec operationToLeaf =
|
||||||
(toLeaf, renderParams, t: node): result(t, string) => {
|
(evaluationParams: evaluationParams, t: node): result(t, string) => {
|
||||||
switch (t) {
|
switch (t) {
|
||||||
| `SymbolicDist(d) =>
|
| `SymbolicDist(d) =>
|
||||||
Ok(`RenderedDist(SymbolicDist.T.toShape(renderParams.sampleCount, d)))
|
Ok(
|
||||||
|
`RenderedDist(
|
||||||
|
SymbolicDist.T.toShape(evaluationParams.sampleCount, d),
|
||||||
|
),
|
||||||
|
)
|
||||||
| `RenderedDist(_) as t => Ok(t) // already a rendered shape, we're done here
|
| `RenderedDist(_) as t => Ok(t) // already a rendered shape, we're done here
|
||||||
| _ =>
|
| _ =>
|
||||||
t
|
t
|
||||||
|> toLeaf(renderParams)
|
|> evaluateNode(evaluationParams)
|
||||||
|> E.R.bind(_, operationToLeaf(toLeaf, renderParams))
|
|> E.R.bind(_, operationToLeaf(evaluationParams))
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
@ -229,35 +226,38 @@ module Render = {
|
||||||
but most often it will produce a RenderedDist.
|
but most often it will produce a RenderedDist.
|
||||||
This function is used mainly to turn a parse tree into a single RenderedDist
|
This function is used mainly to turn a parse tree into a single RenderedDist
|
||||||
that can then be displayed to the user. */
|
that can then be displayed to the user. */
|
||||||
let rec toLeaf = (renderParams, node: t): result(t, string) => {
|
let toLeaf =
|
||||||
|
(
|
||||||
|
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
|
||||||
|
node: t,
|
||||||
|
)
|
||||||
|
: result(t, string) => {
|
||||||
switch (node) {
|
switch (node) {
|
||||||
// Leaf nodes just stay leaf nodes
|
// Leaf nodes just stay leaf nodes
|
||||||
| `SymbolicDist(_)
|
| `SymbolicDist(_)
|
||||||
| `RenderedDist(_) => Ok(node)
|
| `RenderedDist(_) => Ok(node)
|
||||||
// Operations need to be turned into leaves
|
// Operations nevaluationParamsd to be turned into leaves
|
||||||
| `AlgebraicCombination(algebraicOp, t1, t2) =>
|
| `AlgebraicCombination(algebraicOp, t1, t2) =>
|
||||||
AlgebraicCombination.operationToLeaf(
|
AlgebraicCombination.operationToLeaf(
|
||||||
toLeaf,
|
evaluationParams,
|
||||||
renderParams,
|
|
||||||
algebraicOp,
|
algebraicOp,
|
||||||
t1,
|
t1,
|
||||||
t2,
|
t2,
|
||||||
)
|
)
|
||||||
| `PointwiseCombination(pointwiseOp, t1, t2) =>
|
| `PointwiseCombination(pointwiseOp, t1, t2) =>
|
||||||
PointwiseCombination.operationToLeaf(
|
PointwiseCombination.operationToLeaf(
|
||||||
toLeaf,
|
evaluationParams,
|
||||||
renderParams,
|
|
||||||
pointwiseOp,
|
pointwiseOp,
|
||||||
t1,
|
t1,
|
||||||
t2,
|
t2,
|
||||||
)
|
)
|
||||||
| `VerticalScaling(scaleOp, t, scaleBy) =>
|
| `VerticalScaling(scaleOp, t, scaleBy) =>
|
||||||
VerticalScaling.operationToLeaf(toLeaf, renderParams, scaleOp, t, scaleBy)
|
VerticalScaling.operationToLeaf(evaluationParams, scaleOp, t, scaleBy)
|
||||||
| `Truncate(leftCutoff, rightCutoff, t) =>
|
| `Truncate(leftCutoff, rightCutoff, t) =>
|
||||||
Truncate.operationToLeaf(toLeaf, renderParams, leftCutoff, rightCutoff, t)
|
Truncate.operationToLeaf(evaluationParams, leftCutoff, rightCutoff, t)
|
||||||
| `FloatFromDist(distToFloatOp, t) =>
|
| `FloatFromDist(distToFloatOp, t) =>
|
||||||
FloatFromDist.operationToLeaf(toLeaf, renderParams, distToFloatOp, t)
|
FloatFromDist.operationToLeaf(evaluationParams, distToFloatOp, t)
|
||||||
| `Normalize(t) => Normalize.operationToLeaf(toLeaf, renderParams, t)
|
| `Normalize(t) => Normalize.operationToLeaf(evaluationParams, t)
|
||||||
| `Render(t) => Render.operationToLeaf(toLeaf, renderParams, t)
|
| `Render(t) => Render.operationToLeaf(evaluationParams, t)
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
|
@ -5,10 +5,8 @@ type distToFloatOperation = [ | `Pdf(float) | `Inv(float) | `Mean | `Sample];
|
||||||
|
|
||||||
module ExpressionTree = {
|
module ExpressionTree = {
|
||||||
type node = [
|
type node = [
|
||||||
// leaf nodes:
|
|
||||||
| `SymbolicDist(SymbolicTypes.symbolicDist)
|
| `SymbolicDist(SymbolicTypes.symbolicDist)
|
||||||
| `RenderedDist(DistTypes.shape)
|
| `RenderedDist(DistTypes.shape)
|
||||||
// operations:
|
|
||||||
| `AlgebraicCombination(algebraicOperation, node, node)
|
| `AlgebraicCombination(algebraicOperation, node, node)
|
||||||
| `PointwiseCombination(pointwiseOperation, node, node)
|
| `PointwiseCombination(pointwiseOperation, node, node)
|
||||||
| `VerticalScaling(scaleOperation, node, node)
|
| `VerticalScaling(scaleOperation, node, node)
|
||||||
|
@ -17,6 +15,22 @@ module ExpressionTree = {
|
||||||
| `Normalize(node)
|
| `Normalize(node)
|
||||||
| `FloatFromDist(distToFloatOperation, node)
|
| `FloatFromDist(distToFloatOperation, node)
|
||||||
];
|
];
|
||||||
|
|
||||||
|
type dist = [
|
||||||
|
| `SymbolicDist(SymbolicTypes.symbolicDist)
|
||||||
|
| `RenderedDist(DistTypes.shape)
|
||||||
|
]
|
||||||
|
|
||||||
|
type evaluationParams = {
|
||||||
|
sampleCount: int,
|
||||||
|
evaluateNode: (evaluationParams, node) => Belt.Result.t(node, string),
|
||||||
|
};
|
||||||
|
|
||||||
|
let evaluateNode = (evaluationParams: evaluationParams) =>
|
||||||
|
evaluationParams.evaluateNode(evaluationParams);
|
||||||
|
|
||||||
|
let render = (evaluationParams: evaluationParams, r) =>
|
||||||
|
evaluateNode(evaluationParams, `Render(r));
|
||||||
};
|
};
|
||||||
|
|
||||||
type simplificationResult = [
|
type simplificationResult = [
|
||||||
|
|
Loading…
Reference in New Issue
Block a user