Made more dists using new format
This commit is contained in:
		
							parent
							
								
									c57cc3144e
								
							
						
					
					
						commit
						479fdbb491
					
				| 
						 | 
				
			
			@ -8,8 +8,6 @@ let rec toString: node => string =
 | 
			
		|||
    Operation.Algebraic.format(op, toString(t1), toString(t2))
 | 
			
		||||
  | `PointwiseCombination(op, t1, t2) =>
 | 
			
		||||
    Operation.Pointwise.format(op, toString(t1), toString(t2))
 | 
			
		||||
  | `VerticalScaling(scaleOp, t, scaleBy) =>
 | 
			
		||||
    Operation.Scale.format(scaleOp, toString(t), toString(scaleBy))
 | 
			
		||||
  | `Normalize(t) => "normalize(k" ++ toString(t) ++ ")"
 | 
			
		||||
  | `Truncate(lc, rc, t) =>
 | 
			
		||||
    Operation.T.truncateToString(lc, rc, toString(t))
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -91,36 +91,6 @@ module AlgebraicCombination = {
 | 
			
		|||
       );
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
module VerticalScaling = {
 | 
			
		||||
  let operationToLeaf =
 | 
			
		||||
      (evaluationParams: evaluationParams, scaleOp, t, scaleBy) => {
 | 
			
		||||
    // scaleBy has to be a single float, otherwise we'll return an error.
 | 
			
		||||
    let fn = (secondary, main) =>
 | 
			
		||||
      Operation.Scale.toFn(scaleOp, main, secondary);
 | 
			
		||||
    let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(scaleOp);
 | 
			
		||||
    let integralCacheFn = Operation.Scale.toIntegralCacheFn(scaleOp);
 | 
			
		||||
    let renderedShape = Render.render(evaluationParams, t);
 | 
			
		||||
 | 
			
		||||
    let s =
 | 
			
		||||
      switch (renderedShape, scaleBy) {
 | 
			
		||||
      | (Ok(`RenderedDist(rs)), `SymbolicDist(`Float(scaleBy))) =>
 | 
			
		||||
        Ok(
 | 
			
		||||
          `RenderedDist(
 | 
			
		||||
            Shape.T.mapY(
 | 
			
		||||
              ~integralSumCacheFn=integralSumCacheFn(scaleBy),
 | 
			
		||||
              ~integralCacheFn=integralCacheFn(scaleBy),
 | 
			
		||||
              ~fn=fn(scaleBy),
 | 
			
		||||
              rs,
 | 
			
		||||
            ),
 | 
			
		||||
          ),
 | 
			
		||||
        )
 | 
			
		||||
      | (Error(e1), _) => Error(e1)
 | 
			
		||||
      | (_, _) => Error("Can only scale by float values.")
 | 
			
		||||
      };
 | 
			
		||||
    s;
 | 
			
		||||
  };
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
module PointwiseCombination = {
 | 
			
		||||
  let pointwiseAdd = (evaluationParams: evaluationParams, t1: t, t2: t) => {
 | 
			
		||||
    switch (
 | 
			
		||||
| 
						 | 
				
			
			@ -309,8 +279,6 @@ let rec toLeaf =
 | 
			
		|||
      t1,
 | 
			
		||||
      t2,
 | 
			
		||||
    )
 | 
			
		||||
  | `VerticalScaling(scaleOp, t, scaleBy) =>
 | 
			
		||||
    VerticalScaling.operationToLeaf(evaluationParams, scaleOp, t, scaleBy)
 | 
			
		||||
  | `Truncate(leftCutoff, rightCutoff, t) =>
 | 
			
		||||
    Truncate.operationToLeaf(evaluationParams, leftCutoff, rightCutoff, t)
 | 
			
		||||
  | `Normalize(t) => Normalize.operationToLeaf(evaluationParams, t)
 | 
			
		||||
| 
						 | 
				
			
			@ -336,12 +304,7 @@ let rec toLeaf =
 | 
			
		|||
    let components =
 | 
			
		||||
      r
 | 
			
		||||
      |> E.A.fmap(((dist, weight)) =>
 | 
			
		||||
           `VerticalScaling((
 | 
			
		||||
             `Multiply,
 | 
			
		||||
             dist,
 | 
			
		||||
             `SymbolicDist(`Float(weight)),
 | 
			
		||||
           ))
 | 
			
		||||
         );
 | 
			
		||||
      `FunctionCall("scaleExp", [|dist, `SymbolicDist(`Float(weight))|]));
 | 
			
		||||
    let pointwiseSum =
 | 
			
		||||
      components
 | 
			
		||||
      |> Js.Array.sliceFrom(1)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -25,7 +25,6 @@ module ExpressionTree = {
 | 
			
		|||
    | `Function(array(string), node)
 | 
			
		||||
    | `AlgebraicCombination(algebraicOperation, node, node)
 | 
			
		||||
    | `PointwiseCombination(pointwiseOperation, node, node)
 | 
			
		||||
    | `VerticalScaling(scaleOperation, node, node)
 | 
			
		||||
    | `Normalize(node)
 | 
			
		||||
    | `Render(node)
 | 
			
		||||
    | `Truncate(option(float), option(float), node)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,6 +1,9 @@
 | 
			
		|||
open TypeSystem;
 | 
			
		||||
 | 
			
		||||
let wrongInputsError = (r) => {Js.log2("Wrong inputs", r); Error("Wrong inputs")};
 | 
			
		||||
let wrongInputsError = r => {
 | 
			
		||||
  Js.log2("Wrong inputs", r);
 | 
			
		||||
  Error("Wrong inputs");
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
let to_: (float, float) => result(node, string) =
 | 
			
		||||
  (low, high) =>
 | 
			
		||||
| 
						 | 
				
			
			@ -20,7 +23,7 @@ let makeSymbolicFromTwoFloats = (name, fn) =>
 | 
			
		|||
    ~run=
 | 
			
		||||
      fun
 | 
			
		||||
      | [|`Float(a), `Float(b)|] => Ok(`SymbolicDist(fn(a, b)))
 | 
			
		||||
      | e => wrongInputsError(e)
 | 
			
		||||
      | e => wrongInputsError(e),
 | 
			
		||||
  );
 | 
			
		||||
 | 
			
		||||
let makeSymbolicFromOneFloat = (name, fn) =>
 | 
			
		||||
| 
						 | 
				
			
			@ -31,7 +34,7 @@ let makeSymbolicFromOneFloat = (name, fn) =>
 | 
			
		|||
    ~run=
 | 
			
		||||
      fun
 | 
			
		||||
      | [|`Float(a)|] => Ok(`SymbolicDist(fn(a)))
 | 
			
		||||
      | e => wrongInputsError(e)
 | 
			
		||||
      | e => wrongInputsError(e),
 | 
			
		||||
  );
 | 
			
		||||
 | 
			
		||||
let makeDistFloat = (name, fn) =>
 | 
			
		||||
| 
						 | 
				
			
			@ -41,8 +44,19 @@ let makeDistFloat = (name, fn) =>
 | 
			
		|||
    ~inputs=[|`SamplingDistribution, `Float|],
 | 
			
		||||
    ~run=
 | 
			
		||||
      fun
 | 
			
		||||
      | [|`SamplingDist(a), `Float(b)|] => (fn(a,b))
 | 
			
		||||
      | e => wrongInputsError(e)
 | 
			
		||||
      | [|`SamplingDist(a), `Float(b)|] => fn(a, b)
 | 
			
		||||
      | e => wrongInputsError(e),
 | 
			
		||||
  );
 | 
			
		||||
 | 
			
		||||
let makeRenderedDistFloat = (name, fn) =>
 | 
			
		||||
  Function.make(
 | 
			
		||||
    ~name,
 | 
			
		||||
    ~output=`RenderedDistribution,
 | 
			
		||||
    ~inputs=[|`RenderedDistribution, `Float|],
 | 
			
		||||
    ~run=
 | 
			
		||||
      fun
 | 
			
		||||
      | [|`RenderedDist(a), `Float(b)|] => fn(a, b)
 | 
			
		||||
      | e => wrongInputsError(e),
 | 
			
		||||
  );
 | 
			
		||||
 | 
			
		||||
let makeDist = (name, fn) =>
 | 
			
		||||
| 
						 | 
				
			
			@ -53,7 +67,7 @@ let makeDist = (name, fn) =>
 | 
			
		|||
    ~run=
 | 
			
		||||
      fun
 | 
			
		||||
      | [|`SamplingDist(a)|] => fn(a)
 | 
			
		||||
      | e => wrongInputsError(e)
 | 
			
		||||
      | e => wrongInputsError(e),
 | 
			
		||||
  );
 | 
			
		||||
 | 
			
		||||
let floatFromDist =
 | 
			
		||||
| 
						 | 
				
			
			@ -71,6 +85,22 @@ let floatFromDist =
 | 
			
		|||
  };
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
let verticalScaling = (scaleOp, rs, scaleBy) => {
 | 
			
		||||
  // scaleBy has to be a single float, otherwise we'll return an error.
 | 
			
		||||
  let fn = (secondary, main) =>
 | 
			
		||||
    Operation.Scale.toFn(scaleOp, main, secondary);
 | 
			
		||||
  let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(scaleOp);
 | 
			
		||||
  let integralCacheFn = Operation.Scale.toIntegralCacheFn(scaleOp);
 | 
			
		||||
  Ok(`RenderedDist(
 | 
			
		||||
    Shape.T.mapY(
 | 
			
		||||
      ~integralSumCacheFn=integralSumCacheFn(scaleBy),
 | 
			
		||||
      ~integralCacheFn=integralCacheFn(scaleBy),
 | 
			
		||||
      ~fn=fn(scaleBy),
 | 
			
		||||
      rs,
 | 
			
		||||
    ),
 | 
			
		||||
  ));
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
let functions = [|
 | 
			
		||||
  makeSymbolicFromTwoFloats("normal", SymbolicDist.Normal.make),
 | 
			
		||||
  makeSymbolicFromTwoFloats("uniform", SymbolicDist.Uniform.make),
 | 
			
		||||
| 
						 | 
				
			
			@ -87,8 +117,8 @@ let functions = [|
 | 
			
		|||
    ~inputs=[|`Float, `Float|],
 | 
			
		||||
    ~run=
 | 
			
		||||
      fun
 | 
			
		||||
      | [|`Float(a), `Float(b)|] => to_(a,b)
 | 
			
		||||
      | e => wrongInputsError(e)
 | 
			
		||||
      | [|`Float(a), `Float(b)|] => to_(a, b)
 | 
			
		||||
      | e => wrongInputsError(e),
 | 
			
		||||
  ),
 | 
			
		||||
  Function.make(
 | 
			
		||||
    ~name="triangular",
 | 
			
		||||
| 
						 | 
				
			
			@ -99,11 +129,39 @@ let functions = [|
 | 
			
		|||
      | [|`Float(a), `Float(b), `Float(c)|] =>
 | 
			
		||||
        SymbolicDist.Triangular.make(a, b, c)
 | 
			
		||||
        |> E.R.fmap(r => `SymbolicDist(r))
 | 
			
		||||
      | e => wrongInputsError(e)
 | 
			
		||||
      | e => wrongInputsError(e),
 | 
			
		||||
  ),
 | 
			
		||||
  makeDistFloat("pdf", (dist, float) => floatFromDist(`Pdf(float), dist)),
 | 
			
		||||
  makeDistFloat("inv", (dist, float) => floatFromDist(`Inv(float), dist)),
 | 
			
		||||
  makeDistFloat("cdf", (dist, float) => floatFromDist(`Cdf(float), dist)),
 | 
			
		||||
  makeDist("mean", (dist) => floatFromDist(`Mean, dist)),
 | 
			
		||||
  makeDist("sample", (dist) => floatFromDist(`Sample, dist))
 | 
			
		||||
  makeDist("mean", dist => floatFromDist(`Mean, dist)),
 | 
			
		||||
  makeDist("sample", dist => floatFromDist(`Sample, dist)),
 | 
			
		||||
  Function.make(
 | 
			
		||||
    ~name="render",
 | 
			
		||||
    ~output=`RenderedDistribution,
 | 
			
		||||
    ~inputs=[|`RenderedDistribution|],
 | 
			
		||||
    ~run=
 | 
			
		||||
      fun
 | 
			
		||||
      | [|`RenderedDist(c)|] => Ok(`RenderedDist(c))
 | 
			
		||||
      | e => wrongInputsError(e),
 | 
			
		||||
  ),
 | 
			
		||||
  Function.make(
 | 
			
		||||
    ~name="normalize",
 | 
			
		||||
    ~output=`SamplingDistribution,
 | 
			
		||||
    ~inputs=[|`SamplingDistribution|],
 | 
			
		||||
    ~run=
 | 
			
		||||
      fun
 | 
			
		||||
      | [|`SamplingDist(`SymbolicDist(c))|] => Ok(`SymbolicDist(c))
 | 
			
		||||
      | [|`SamplingDist(`RenderedDist(c))|] => Ok(`RenderedDist(Shape.T.normalize(c)))
 | 
			
		||||
      | e => wrongInputsError(e),
 | 
			
		||||
  ),
 | 
			
		||||
  makeRenderedDistFloat("scaleExp", (dist, float) =>
 | 
			
		||||
    verticalScaling(`Exponentiate, dist, float)
 | 
			
		||||
  ),
 | 
			
		||||
  makeRenderedDistFloat("scaleMultiply", (dist, float) =>
 | 
			
		||||
    verticalScaling(`Multiply, dist, float)
 | 
			
		||||
  ),
 | 
			
		||||
  makeRenderedDistFloat("scaleLog", (dist, float) =>
 | 
			
		||||
    verticalScaling(`Log, dist, float)
 | 
			
		||||
  ),
 | 
			
		||||
|];
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -193,14 +193,6 @@ module MathAdtToDistDst = {
 | 
			
		|||
           Error(
 | 
			
		||||
             "truncate needs three arguments: the expression and both cutoffs",
 | 
			
		||||
           )
 | 
			
		||||
         | ("scaleMultiply", [|d, `SymbolicDist(`Float(v))|]) =>
 | 
			
		||||
           Ok(`VerticalScaling((`Multiply, d, `SymbolicDist(`Float(v)))))
 | 
			
		||||
         | ("scaleExp", [|d, `SymbolicDist(`Float(v))|]) =>
 | 
			
		||||
           Ok(
 | 
			
		||||
             `VerticalScaling((`Exponentiate, d, `SymbolicDist(`Float(v)))),
 | 
			
		||||
           )
 | 
			
		||||
         | ("scaleLog", [|d, `SymbolicDist(`Float(v))|]) =>
 | 
			
		||||
           Ok(`VerticalScaling((`Log, d, `SymbolicDist(`Float(v)))))
 | 
			
		||||
         | _ => Error("This type not currently supported")
 | 
			
		||||
         }
 | 
			
		||||
       });
 | 
			
		||||
| 
						 | 
				
			
			@ -245,9 +237,6 @@ module MathAdtToDistDst = {
 | 
			
		|||
    | "pow"
 | 
			
		||||
    | "leftTruncate"
 | 
			
		||||
    | "rightTruncate"
 | 
			
		||||
    | "scaleMultiply"
 | 
			
		||||
    | "scaleExp"
 | 
			
		||||
    | "scaleLog"
 | 
			
		||||
    | "truncate" => operationParser(name, parseArgs())
 | 
			
		||||
    | name =>
 | 
			
		||||
      parseArgs()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user