From 9d0ecda2979798c00c205c00daee1664e75f4cbd Mon Sep 17 00:00:00 2001 From: Ozzie Gooen Date: Wed, 8 Jul 2020 11:39:03 +0100 Subject: [PATCH] Moved data to evaluationParams --- src/distPlus/expressionTree/ExpressionTree.re | 2 +- .../expressionTree/ExpressionTreeEvaluator.re | 110 +++++++++--------- .../expressionTree/ExpressionTypes.re | 18 ++- 3 files changed, 72 insertions(+), 58 deletions(-) diff --git a/src/distPlus/expressionTree/ExpressionTree.re b/src/distPlus/expressionTree/ExpressionTree.re index c5e4e0a4..333801bf 100644 --- a/src/distPlus/expressionTree/ExpressionTree.re +++ b/src/distPlus/expressionTree/ExpressionTree.re @@ -3,7 +3,7 @@ open ExpressionTypes.ExpressionTree; let toShape = (sampleCount: int, node: node) => { let renderResult = `Render(`Normalize(node)) - |> ExpressionTreeEvaluator.toLeaf({sampleCount: sampleCount}); + |> ExpressionTreeEvaluator.toLeaf({sampleCount: sampleCount, evaluateNode: ExpressionTreeEvaluator.toLeaf}); switch (renderResult) { | Ok(`RenderedDist(rs)) => diff --git a/src/distPlus/expressionTree/ExpressionTreeEvaluator.re b/src/distPlus/expressionTree/ExpressionTreeEvaluator.re index 60379971..d91e2484 100644 --- a/src/distPlus/expressionTree/ExpressionTreeEvaluator.re +++ b/src/distPlus/expressionTree/ExpressionTreeEvaluator.re @@ -22,8 +22,9 @@ module AlgebraicCombination = { | _ => Ok(`AlgebraicCombination((operation, t1, t2))) }; - let combineAsShapes = (toLeaf, renderParams, algebraicOp, t1, t2) => { - let renderShape = r => toLeaf(renderParams, `Render(r)); + let combineAsShapes = + (evaluationParams: evaluationParams, algebraicOp, t1, t2) => { + let renderShape = render(evaluationParams); switch (renderShape(t1), renderShape(t2)) { | (Ok(`RenderedDist(s1)), Ok(`RenderedDist(s2))) => Ok( @@ -39,8 +40,7 @@ module AlgebraicCombination = { let operationToLeaf = ( - toLeaf, - renderParams: renderParams, + evaluationParams: evaluationParams, algebraicOp: ExpressionTypes.algebraicOperation, t1: t, t2: t, @@ -52,16 +52,17 @@ module AlgebraicCombination = { _, fun | `SymbolicDist(d) as t => Ok(t) - | _ => combineAsShapes(toLeaf, renderParams, algebraicOp, t1, t2), + | _ => combineAsShapes(evaluationParams, algebraicOp, t1, t2), ); }; module VerticalScaling = { - let operationToLeaf = (toLeaf, renderParams, scaleOp, t, scaleBy) => { + let operationToLeaf = + (evaluationParams: evaluationParams, scaleOp, t, scaleBy) => { // scaleBy has to be a single float, otherwise we'll return an error. let fn = Operation.Scale.toFn(scaleOp); let knownIntegralSumFn = Operation.Scale.toKnownIntegralSumFn(scaleOp); - let renderedShape = toLeaf(renderParams, `Render(t)); + let renderedShape = render(evaluationParams, t); switch (renderedShape, scaleBy) { | (Ok(`RenderedDist(rs)), `SymbolicDist(`Float(sm))) => @@ -81,9 +82,8 @@ module VerticalScaling = { }; module PointwiseCombination = { - let pointwiseAdd = (toLeaf, renderParams, t1, t2) => { - let renderShape = r => toLeaf(renderParams, `Render(r)); - switch (renderShape(t1), renderShape(t2)) { + let pointwiseAdd = (evaluationParams: evaluationParams, t1, t2) => { + switch (render(evaluationParams, t1), render(evaluationParams, t2)) { | (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) => Ok( `RenderedDist( @@ -101,7 +101,7 @@ module PointwiseCombination = { }; }; - let pointwiseMultiply = (toLeaf, renderParams, t1, t2) => { + let pointwiseMultiply = (evaluationParams: evaluationParams, t1, t2) => { // TODO: construct a function that we can easily sample from, to construct // a RenderedDist. Use the xMin and xMax of the rendered shapes to tell the sampling function where to look. Error( @@ -109,10 +109,11 @@ module PointwiseCombination = { ); }; - let operationToLeaf = (toLeaf, renderParams, pointwiseOp, t1, t2) => { + let operationToLeaf = + (evaluationParams: evaluationParams, pointwiseOp, t1, t2) => { switch (pointwiseOp) { - | `Add => pointwiseAdd(toLeaf, renderParams, t1, t2) - | `Multiply => pointwiseMultiply(toLeaf, renderParams, t1, t2) + | `Add => pointwiseAdd(evaluationParams, t1, t2) + | `Multiply => pointwiseMultiply(evaluationParams, t1, t2) }; }; }; @@ -133,24 +134,23 @@ module Truncate = { }; }; - let truncateAsShape = (toLeaf, renderParams, leftCutoff, rightCutoff, t) => { + let truncateAsShape = + (evaluationParams: evaluationParams, leftCutoff, rightCutoff, t) => { // TODO: use named args in renderToShape; if we're lucky we can at least get the tail // of a distribution we otherwise wouldn't get at all - let renderedShape = toLeaf(renderParams, `Render(t)); - switch (renderedShape) { + switch (render(evaluationParams, t)) { | Ok(`RenderedDist(rs)) => let truncatedShape = rs |> Distributions.Shape.T.truncate(leftCutoff, rightCutoff); Ok(`RenderedDist(truncatedShape)); - | Error(e1) => Error(e1) + | Error(e) => Error(e) | _ => Error("Could not truncate distribution.") }; }; let operationToLeaf = ( - toLeaf, - renderParams, + evaluationParams, leftCutoff: option(float), rightCutoff: option(float), t: node, @@ -163,62 +163,59 @@ module Truncate = { | `Solution(t) => Ok(t) | `Error(e) => Error(e) | `NoSolution => - truncateAsShape(toLeaf, renderParams, leftCutoff, rightCutoff, t) + truncateAsShape(evaluationParams, leftCutoff, rightCutoff, t) ); }; }; module Normalize = { - let rec operationToLeaf = - (toLeaf, renderParams, t: node): result(node, string) => { + let rec operationToLeaf = (evaluationParams, t: node): result(node, string) => { switch (t) { | `RenderedDist(s) => Ok(`RenderedDist(Distributions.Shape.T.normalize(s))) | `SymbolicDist(_) => Ok(t) | _ => t - |> toLeaf(renderParams) - |> E.R.bind(_, operationToLeaf(toLeaf, renderParams)) + |> evaluateNode(evaluationParams) + |> E.R.bind(_, operationToLeaf(evaluationParams)) }; }; }; module FloatFromDist = { - let symbolicToLeaf = (distToFloatOp: distToFloatOperation, s) => { - SymbolicDist.T.operate(distToFloatOp, s) - |> E.R.bind(_, v => Ok(`SymbolicDist(`Float(v)))); - }; - let renderedToLeaf = - (distToFloatOp: distToFloatOperation, rs: DistTypes.shape) - : result(node, string) => { - Distributions.Shape.operate(distToFloatOp, rs) - |> (v => Ok(`SymbolicDist(`Float(v)))); - }; let rec operationToLeaf = - (toLeaf, renderParams, distToFloatOp: distToFloatOperation, t: node) + (evaluationParams, distToFloatOp: distToFloatOperation, t: node) : result(node, string) => { switch (t) { - | `SymbolicDist(s) => symbolicToLeaf(distToFloatOp, s) - | `RenderedDist(rs) => renderedToLeaf(distToFloatOp, rs) + | `SymbolicDist(s) => + SymbolicDist.T.operate(distToFloatOp, s) + |> E.R.bind(_, v => Ok(`SymbolicDist(`Float(v)))) + | `RenderedDist(rs) => + Distributions.Shape.operate(distToFloatOp, rs) + |> (v => Ok(`SymbolicDist(`Float(v)))) | _ => t - |> toLeaf(renderParams) - |> E.R.bind(_, operationToLeaf(toLeaf, renderParams, distToFloatOp)) + |> evaluateNode(evaluationParams) + |> E.R.bind(_, operationToLeaf(evaluationParams, distToFloatOp)) }; }; }; module Render = { let rec operationToLeaf = - (toLeaf, renderParams, t: node): result(t, string) => { + (evaluationParams: evaluationParams, t: node): result(t, string) => { switch (t) { | `SymbolicDist(d) => - Ok(`RenderedDist(SymbolicDist.T.toShape(renderParams.sampleCount, d))) + Ok( + `RenderedDist( + SymbolicDist.T.toShape(evaluationParams.sampleCount, d), + ), + ) | `RenderedDist(_) as t => Ok(t) // already a rendered shape, we're done here | _ => t - |> toLeaf(renderParams) - |> E.R.bind(_, operationToLeaf(toLeaf, renderParams)) + |> evaluateNode(evaluationParams) + |> E.R.bind(_, operationToLeaf(evaluationParams)) }; }; }; @@ -229,35 +226,38 @@ module Render = { but most often it will produce a RenderedDist. This function is used mainly to turn a parse tree into a single RenderedDist that can then be displayed to the user. */ -let rec toLeaf = (renderParams, node: t): result(t, string) => { +let toLeaf = + ( + evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams, + node: t, + ) + : result(t, string) => { switch (node) { // Leaf nodes just stay leaf nodes | `SymbolicDist(_) | `RenderedDist(_) => Ok(node) - // Operations need to be turned into leaves + // Operations nevaluationParamsd to be turned into leaves | `AlgebraicCombination(algebraicOp, t1, t2) => AlgebraicCombination.operationToLeaf( - toLeaf, - renderParams, + evaluationParams, algebraicOp, t1, t2, ) | `PointwiseCombination(pointwiseOp, t1, t2) => PointwiseCombination.operationToLeaf( - toLeaf, - renderParams, + evaluationParams, pointwiseOp, t1, t2, ) | `VerticalScaling(scaleOp, t, scaleBy) => - VerticalScaling.operationToLeaf(toLeaf, renderParams, scaleOp, t, scaleBy) + VerticalScaling.operationToLeaf(evaluationParams, scaleOp, t, scaleBy) | `Truncate(leftCutoff, rightCutoff, t) => - Truncate.operationToLeaf(toLeaf, renderParams, leftCutoff, rightCutoff, t) + Truncate.operationToLeaf(evaluationParams, leftCutoff, rightCutoff, t) | `FloatFromDist(distToFloatOp, t) => - FloatFromDist.operationToLeaf(toLeaf, renderParams, distToFloatOp, t) - | `Normalize(t) => Normalize.operationToLeaf(toLeaf, renderParams, t) - | `Render(t) => Render.operationToLeaf(toLeaf, renderParams, t) + FloatFromDist.operationToLeaf(evaluationParams, distToFloatOp, t) + | `Normalize(t) => Normalize.operationToLeaf(evaluationParams, t) + | `Render(t) => Render.operationToLeaf(evaluationParams, t) }; }; diff --git a/src/distPlus/expressionTree/ExpressionTypes.re b/src/distPlus/expressionTree/ExpressionTypes.re index 8b6ece67..252bb301 100644 --- a/src/distPlus/expressionTree/ExpressionTypes.re +++ b/src/distPlus/expressionTree/ExpressionTypes.re @@ -5,10 +5,8 @@ type distToFloatOperation = [ | `Pdf(float) | `Inv(float) | `Mean | `Sample]; module ExpressionTree = { type node = [ - // leaf nodes: | `SymbolicDist(SymbolicTypes.symbolicDist) | `RenderedDist(DistTypes.shape) - // operations: | `AlgebraicCombination(algebraicOperation, node, node) | `PointwiseCombination(pointwiseOperation, node, node) | `VerticalScaling(scaleOperation, node, node) @@ -17,6 +15,22 @@ module ExpressionTree = { | `Normalize(node) | `FloatFromDist(distToFloatOperation, node) ]; + + type dist = [ + | `SymbolicDist(SymbolicTypes.symbolicDist) + | `RenderedDist(DistTypes.shape) + ] + + type evaluationParams = { + sampleCount: int, + evaluateNode: (evaluationParams, node) => Belt.Result.t(node, string), + }; + + let evaluateNode = (evaluationParams: evaluationParams) => + evaluationParams.evaluateNode(evaluationParams); + + let render = (evaluationParams: evaluationParams, r) => + evaluateNode(evaluationParams, `Render(r)); }; type simplificationResult = [