Formatted ExpressionTreeEvaluator.re
This commit is contained in:
		
							parent
							
								
									56a9bda82a
								
							
						
					
					
						commit
						4cf7a69d3e
					
				| 
						 | 
				
			
			@ -4,9 +4,7 @@ open ExpressionTypes.ExpressionTree;
 | 
			
		|||
type t = node;
 | 
			
		||||
type tResult = node => result(node, string);
 | 
			
		||||
 | 
			
		||||
type renderParams = {
 | 
			
		||||
  sampleCount: int,
 | 
			
		||||
};
 | 
			
		||||
type renderParams = {sampleCount: int};
 | 
			
		||||
 | 
			
		||||
/* Given two random variables A and B, this returns the distribution
 | 
			
		||||
   of a new variable that is the result of the operation on A and B.
 | 
			
		||||
| 
						 | 
				
			
			@ -15,26 +13,23 @@ type renderParams = {
 | 
			
		|||
module AlgebraicCombination = {
 | 
			
		||||
  let tryAnalyticalSimplification = (operation, t1: t, t2: t) =>
 | 
			
		||||
    switch (operation, t1, t2) {
 | 
			
		||||
    | (operation,
 | 
			
		||||
          `SymbolicDist(d1),
 | 
			
		||||
          `SymbolicDist(d2),
 | 
			
		||||
        ) =>
 | 
			
		||||
    | (operation, `SymbolicDist(d1), `SymbolicDist(d2)) =>
 | 
			
		||||
      switch (SymbolicDist.T.tryAnalyticalSimplification(d1, d2, operation)) {
 | 
			
		||||
      | `AnalyticalSolution(symbolicDist) => Ok(`SymbolicDist(symbolicDist))
 | 
			
		||||
      | `Error(er) => Error(er)
 | 
			
		||||
      | `NoSolution => Ok(`AlgebraicCombination(operation, t1, t2))
 | 
			
		||||
      | `NoSolution => Ok(`AlgebraicCombination((operation, t1, t2)))
 | 
			
		||||
      }
 | 
			
		||||
    | _ => Ok(`AlgebraicCombination(operation, t1, t2))
 | 
			
		||||
  };
 | 
			
		||||
    | _ => Ok(`AlgebraicCombination((operation, t1, t2)))
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
  let combineAsShapes = (toLeaf, renderParams, algebraicOp, t1, t2) => {
 | 
			
		||||
    let renderShape = r => toLeaf(renderParams, `Render(r));
 | 
			
		||||
    switch (renderShape(t1), renderShape(t2)) {
 | 
			
		||||
    | (Ok(`RenderedDist(s1)), Ok(`RenderedDist(s2))) =>
 | 
			
		||||
      Ok(
 | 
			
		||||
          `RenderedDist(
 | 
			
		||||
            Distributions.Shape.combineAlgebraically(algebraicOp, s1, s2),
 | 
			
		||||
          ),
 | 
			
		||||
        `RenderedDist(
 | 
			
		||||
          Distributions.Shape.combineAlgebraically(algebraicOp, s1, s2),
 | 
			
		||||
        ),
 | 
			
		||||
      )
 | 
			
		||||
    | (Error(e1), _) => Error(e1)
 | 
			
		||||
    | (_, Error(e2)) => Error(e2)
 | 
			
		||||
| 
						 | 
				
			
			@ -51,14 +46,13 @@ module AlgebraicCombination = {
 | 
			
		|||
        t2: t,
 | 
			
		||||
      )
 | 
			
		||||
      : result(node, string) =>
 | 
			
		||||
 | 
			
		||||
    algebraicOp
 | 
			
		||||
    |> tryAnalyticalSimplification(_, t1, t2)
 | 
			
		||||
    |> E.R.bind(
 | 
			
		||||
         _,
 | 
			
		||||
         fun
 | 
			
		||||
        | `SymbolicDist(d) as t => Ok(t)
 | 
			
		||||
        | _ => combineAsShapes(toLeaf, renderParams, algebraicOp, t1, t2)
 | 
			
		||||
         | `SymbolicDist(d) as t => Ok(t)
 | 
			
		||||
         | _ => combineAsShapes(toLeaf, renderParams, algebraicOp, t1, t2),
 | 
			
		||||
       );
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -72,12 +66,12 @@ module VerticalScaling = {
 | 
			
		|||
    switch (renderedShape, scaleBy) {
 | 
			
		||||
    | (Ok(`RenderedDist(rs)), `SymbolicDist(`Float(sm))) =>
 | 
			
		||||
      Ok(
 | 
			
		||||
          `RenderedDist(
 | 
			
		||||
            Distributions.Shape.T.mapY(
 | 
			
		||||
              ~knownIntegralSumFn=knownIntegralSumFn(sm),
 | 
			
		||||
              fn(sm),
 | 
			
		||||
              rs,
 | 
			
		||||
            ),
 | 
			
		||||
        `RenderedDist(
 | 
			
		||||
          Distributions.Shape.T.mapY(
 | 
			
		||||
            ~knownIntegralSumFn=knownIntegralSumFn(sm),
 | 
			
		||||
            fn(sm),
 | 
			
		||||
            rs,
 | 
			
		||||
          ),
 | 
			
		||||
        ),
 | 
			
		||||
      )
 | 
			
		||||
    | (Error(e1), _) => Error(e1)
 | 
			
		||||
| 
						 | 
				
			
			@ -127,13 +121,12 @@ module Truncate = {
 | 
			
		|||
  let trySimplification = (leftCutoff, rightCutoff, t) => {
 | 
			
		||||
    switch (leftCutoff, rightCutoff, t) {
 | 
			
		||||
    | (None, None, t) => Ok(t)
 | 
			
		||||
    | (lc, rc, `SymbolicDist(`Uniform(u))) => {
 | 
			
		||||
        // just create a new Uniform distribution
 | 
			
		||||
        let nu: SymbolicTypes.uniform = u;
 | 
			
		||||
        let newLow = max(E.O.default(neg_infinity, lc), nu.low);
 | 
			
		||||
        let newHigh = min(E.O.default(infinity, rc), nu.high);
 | 
			
		||||
        Ok(`SymbolicDist(`Uniform({low: newLow, high: newHigh})));
 | 
			
		||||
      }
 | 
			
		||||
    | (lc, rc, `SymbolicDist(`Uniform(u))) =>
 | 
			
		||||
      // just create a new Uniform distribution
 | 
			
		||||
      let nu: SymbolicTypes.uniform = u;
 | 
			
		||||
      let newLow = max(E.O.default(neg_infinity, lc), nu.low);
 | 
			
		||||
      let newHigh = min(E.O.default(infinity, rc), nu.high);
 | 
			
		||||
      Ok(`SymbolicDist(`Uniform({low: newLow, high: newHigh})));
 | 
			
		||||
    | (_, _, t) => Ok(t)
 | 
			
		||||
    };
 | 
			
		||||
  };
 | 
			
		||||
| 
						 | 
				
			
			@ -144,43 +137,47 @@ module Truncate = {
 | 
			
		|||
    let renderedShape = toLeaf(renderParams, `Render(t));
 | 
			
		||||
 | 
			
		||||
    switch (renderedShape) {
 | 
			
		||||
    | Ok(`RenderedDist(rs)) => {
 | 
			
		||||
    | Ok(`RenderedDist(rs)) =>
 | 
			
		||||
      let truncatedShape =
 | 
			
		||||
        rs |> Distributions.Shape.T.truncate(leftCutoff, rightCutoff);
 | 
			
		||||
      Ok(`RenderedDist(truncatedShape));
 | 
			
		||||
    }
 | 
			
		||||
    | Error(e1) => Error(e1)
 | 
			
		||||
    | _ => Error("Could not truncate distribution.")
 | 
			
		||||
    };
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  let operationToLeaf =
 | 
			
		||||
  (
 | 
			
		||||
    toLeaf,
 | 
			
		||||
      renderParams,
 | 
			
		||||
    leftCutoff: option(float),
 | 
			
		||||
    rightCutoff: option(float),
 | 
			
		||||
    t: node,
 | 
			
		||||
  )
 | 
			
		||||
  : result(node, string) => {
 | 
			
		||||
      (
 | 
			
		||||
        toLeaf,
 | 
			
		||||
        renderParams,
 | 
			
		||||
        leftCutoff: option(float),
 | 
			
		||||
        rightCutoff: option(float),
 | 
			
		||||
        t: node,
 | 
			
		||||
      )
 | 
			
		||||
      : result(node, string) => {
 | 
			
		||||
    t
 | 
			
		||||
    |> trySimplification(leftCutoff, rightCutoff)
 | 
			
		||||
    |> E.R.bind(
 | 
			
		||||
         _,
 | 
			
		||||
         fun
 | 
			
		||||
         | `SymbolicDist(d) as t => Ok(t)
 | 
			
		||||
         | _ => truncateAsShape(toLeaf, renderParams, leftCutoff, rightCutoff, t),
 | 
			
		||||
         | _ =>
 | 
			
		||||
           truncateAsShape(toLeaf, renderParams, leftCutoff, rightCutoff, t),
 | 
			
		||||
       );
 | 
			
		||||
  };
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
module Normalize = {
 | 
			
		||||
  let rec operationToLeaf = (toLeaf, renderParams, t: node): result(node, string) => {
 | 
			
		||||
  let rec operationToLeaf =
 | 
			
		||||
          (toLeaf, renderParams, t: node): result(node, string) => {
 | 
			
		||||
    switch (t) {
 | 
			
		||||
    | `RenderedDist(s) =>
 | 
			
		||||
      Ok(`RenderedDist(Distributions.Shape.T.normalize(s)))
 | 
			
		||||
    | `SymbolicDist(_) => Ok(t)
 | 
			
		||||
    | _ => t |> toLeaf(renderParams) |> E.R.bind(_, operationToLeaf(toLeaf, renderParams))
 | 
			
		||||
    | _ =>
 | 
			
		||||
      t
 | 
			
		||||
      |> toLeaf(renderParams)
 | 
			
		||||
      |> E.R.bind(_, operationToLeaf(toLeaf, renderParams))
 | 
			
		||||
    };
 | 
			
		||||
  };
 | 
			
		||||
};
 | 
			
		||||
| 
						 | 
				
			
			@ -202,24 +199,25 @@ module FloatFromDist = {
 | 
			
		|||
    switch (t) {
 | 
			
		||||
    | `SymbolicDist(s) => symbolicToLeaf(distToFloatOp, s)
 | 
			
		||||
    | `RenderedDist(rs) => renderedToLeaf(distToFloatOp, rs)
 | 
			
		||||
    | _ => t |> toLeaf(renderParams) |> E.R.bind(_, operationToLeaf(toLeaf, renderParams, distToFloatOp))
 | 
			
		||||
    | _ =>
 | 
			
		||||
      t
 | 
			
		||||
      |> toLeaf(renderParams)
 | 
			
		||||
      |> E.R.bind(_, operationToLeaf(toLeaf, renderParams, distToFloatOp))
 | 
			
		||||
    };
 | 
			
		||||
  };
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
module Render = {
 | 
			
		||||
  let rec operationToLeaf =
 | 
			
		||||
          (
 | 
			
		||||
            toLeaf,
 | 
			
		||||
            renderParams,
 | 
			
		||||
            t: node,
 | 
			
		||||
          )
 | 
			
		||||
          : result(t, string) => {
 | 
			
		||||
          (toLeaf, renderParams, t: node): result(t, string) => {
 | 
			
		||||
    switch (t) {
 | 
			
		||||
    | `SymbolicDist(d) =>
 | 
			
		||||
      Ok(`RenderedDist(SymbolicDist.T.toShape(renderParams.sampleCount, d)))
 | 
			
		||||
    | `RenderedDist(_) as t => Ok(t) // already a rendered shape, we're done here
 | 
			
		||||
    | _ => t |> toLeaf(renderParams) |> E.R.bind(_, operationToLeaf(toLeaf, renderParams))
 | 
			
		||||
    | _ =>
 | 
			
		||||
      t
 | 
			
		||||
      |> toLeaf(renderParams)
 | 
			
		||||
      |> E.R.bind(_, operationToLeaf(toLeaf, renderParams))
 | 
			
		||||
    };
 | 
			
		||||
  };
 | 
			
		||||
};
 | 
			
		||||
| 
						 | 
				
			
			@ -242,7 +240,7 @@ let rec toLeaf = (renderParams, node: t): result(t, string) => {
 | 
			
		|||
      renderParams,
 | 
			
		||||
      algebraicOp,
 | 
			
		||||
      t1,
 | 
			
		||||
      t2
 | 
			
		||||
      t2,
 | 
			
		||||
    )
 | 
			
		||||
  | `PointwiseCombination(pointwiseOp, t1, t2) =>
 | 
			
		||||
    PointwiseCombination.operationToLeaf(
 | 
			
		||||
| 
						 | 
				
			
			@ -253,9 +251,7 @@ let rec toLeaf = (renderParams, node: t): result(t, string) => {
 | 
			
		|||
      t2,
 | 
			
		||||
    )
 | 
			
		||||
  | `VerticalScaling(scaleOp, t, scaleBy) =>
 | 
			
		||||
    VerticalScaling.operationToLeaf(
 | 
			
		||||
    toLeaf, renderParams, scaleOp, t, scaleBy
 | 
			
		||||
  )
 | 
			
		||||
    VerticalScaling.operationToLeaf(toLeaf, renderParams, scaleOp, t, scaleBy)
 | 
			
		||||
  | `Truncate(leftCutoff, rightCutoff, t) =>
 | 
			
		||||
    Truncate.operationToLeaf(toLeaf, renderParams, leftCutoff, rightCutoff, t)
 | 
			
		||||
  | `FloatFromDist(distToFloatOp, t) =>
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user