Minor cleanup
This commit is contained in:
parent
f72a3ddde6
commit
3147e6d9c3
|
@ -289,28 +289,6 @@ module Render = {
|
|||
};
|
||||
};
|
||||
|
||||
let run = (node, fnNode) => {
|
||||
switch (fnNode) {
|
||||
| `Function(r) => Ok(r(node))
|
||||
| _ => Error("Not a function")
|
||||
};
|
||||
};
|
||||
|
||||
// let outputType = (t:t) => t |> fun
|
||||
// | `SymbolicDist(_) => `RenderedDist
|
||||
// | `RenderedDist(_) => `RenderedDist
|
||||
// | `Bool(_) => `RenderedDist
|
||||
// | `AlgebraicCombination(_) => `RenderedDist
|
||||
// | `PointwiseCombination(_) => `RenderedDist
|
||||
// | `VerticalScaling(_) => `RenderedDist
|
||||
// | `Truncate(_) => `RenderedDist
|
||||
// | `FloatFromDist(_) => `RenderedDist
|
||||
// | `Normalize(_) => `RenderedDist
|
||||
// | `Render(_) => `RenderedDist
|
||||
// | `Function(_) => `Function
|
||||
// | `FunctionCall(_) => `Any
|
||||
// | `Symbol(_) => `Any
|
||||
|
||||
/* This function recursively goes through the nodes of the parse tree,
|
||||
replacing each Operation node and its subtree with a Data node.
|
||||
Whenever possible, the replacement produces a new Symbolic Data node,
|
||||
|
|
|
@ -18,12 +18,13 @@ module ExpressionTree = {
|
|||
| `AlgebraicCombination(algebraicOperation, node, node)
|
||||
| `PointwiseCombination(pointwiseOperation, node, node)
|
||||
| `VerticalScaling(scaleOperation, node, node)
|
||||
| `Normalize(node)
|
||||
| `Render(node)
|
||||
| `Truncate(option(float), option(float), node)
|
||||
| `Normalize(node)
|
||||
| `FloatFromDist(distToFloatOperation, node)
|
||||
| `FunctionCall(string, array(node))
|
||||
];
|
||||
// Have nil as option
|
||||
|
||||
type samplingInputs = {
|
||||
sampleCount: int,
|
||||
|
|
|
@ -49,31 +49,6 @@ let to_: array(node) => result(node, string) =
|
|||
Error("Low value must be less than high value.")
|
||||
| _ => Error("Requires 2 variables");
|
||||
|
||||
let processCustomFn =
|
||||
(
|
||||
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
|
||||
args: array(node),
|
||||
argNames: array(string),
|
||||
fnResult: node,
|
||||
) =>
|
||||
if (E.A.length(args) == E.A.length(argNames)) {
|
||||
let newEnvironment =
|
||||
Belt.Array.zip(argNames, args)
|
||||
|> ExpressionTypes.ExpressionTree.Environment.fromArray;
|
||||
let newEvaluationParams: ExpressionTypes.ExpressionTree.evaluationParams = {
|
||||
samplingInputs: evaluationParams.samplingInputs,
|
||||
environment:
|
||||
ExpressionTypes.ExpressionTree.Environment.mergeKeepSecond(
|
||||
evaluationParams.environment,
|
||||
newEnvironment,
|
||||
),
|
||||
evaluateNode: evaluationParams.evaluateNode,
|
||||
};
|
||||
evaluationParams.evaluateNode(newEvaluationParams, fnResult);
|
||||
} else {
|
||||
Error("Failure");
|
||||
};
|
||||
|
||||
let fnn =
|
||||
(
|
||||
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
|
||||
|
@ -88,7 +63,7 @@ let fnn =
|
|||
),
|
||||
) {
|
||||
| (_, Some(`Function(argNames, tt))) =>
|
||||
processCustomFn(evaluationParams, args, argNames, tt)
|
||||
PTypes.Function.run(evaluationParams, args, (argNames, tt))
|
||||
| ("normal", _) =>
|
||||
apply2(twoFloatsToOkSym(SymbolicDist.Normal.make), args)
|
||||
| ("uniform", _) =>
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
[%%debugger.chrome]
|
||||
module MathJsonToMathJsAdt = {
|
||||
type arg =
|
||||
| Symbol(string)
|
||||
|
|
|
@ -1,12 +1,80 @@
|
|||
open ExpressionTypes.ExpressionTree;
|
||||
|
||||
module Function = {
|
||||
type t = (array(string), node);
|
||||
let fromNode: node => option(t) =
|
||||
node =>
|
||||
switch (node) {
|
||||
| `Function(r) => Some(r)
|
||||
| _ => None
|
||||
};
|
||||
let argumentNames = ((a, _): t) => a;
|
||||
let internals = ((_, b): t) => b;
|
||||
let run =
|
||||
(
|
||||
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
|
||||
args: array(node),
|
||||
t: t,
|
||||
) =>
|
||||
if (E.A.length(args) == E.A.length(argumentNames(t))) {
|
||||
let newEnvironment =
|
||||
Belt.Array.zip(argumentNames(t), args)
|
||||
|> ExpressionTypes.ExpressionTree.Environment.fromArray;
|
||||
let newEvaluationParams: ExpressionTypes.ExpressionTree.evaluationParams = {
|
||||
samplingInputs: evaluationParams.samplingInputs,
|
||||
environment:
|
||||
ExpressionTypes.ExpressionTree.Environment.mergeKeepSecond(
|
||||
evaluationParams.environment,
|
||||
newEnvironment,
|
||||
),
|
||||
evaluateNode: evaluationParams.evaluateNode,
|
||||
};
|
||||
evaluationParams.evaluateNode(newEvaluationParams, internals(t));
|
||||
} else {
|
||||
Error("Failure");
|
||||
};
|
||||
};
|
||||
|
||||
module Primative = {
|
||||
type t = [
|
||||
| `SymbolicDist(SymbolicTypes.symbolicDist)
|
||||
| `RenderedDist(DistTypes.shape)
|
||||
| `Function(array(string), node)
|
||||
];
|
||||
|
||||
let isPrimative: node => bool =
|
||||
fun
|
||||
| `SymbolicDist(_)
|
||||
| `RenderedDist(_)
|
||||
| `Function(_) => true
|
||||
| _ => false;
|
||||
|
||||
let fromNode: node => option(t) =
|
||||
fun
|
||||
| `SymbolicDist(_) as n
|
||||
| `RenderedDist(_) as n
|
||||
| `Function(_) as n => Some(n)
|
||||
| _ => None;
|
||||
};
|
||||
|
||||
module SamplingDistribution = {
|
||||
type t = [
|
||||
| `SymbolicDist(SymbolicTypes.symbolicDist)
|
||||
| `RenderedDist(DistTypes.shape)
|
||||
];
|
||||
|
||||
let isSamplingDistribution: node => bool =
|
||||
fun
|
||||
| `SymbolicDist(_) => true
|
||||
| `RenderedDist(_) => true
|
||||
| _ => false;
|
||||
|
||||
let fromNode: node => result(t, string) =
|
||||
fun
|
||||
| `SymbolicDist(n) => Ok(`SymbolicDist(n))
|
||||
| `RenderedDist(n) => Ok(`RenderedDist(n))
|
||||
| _ => Error("Not valid type");
|
||||
|
||||
let renderIfIsNotSamplingDistribution = (params, t): result(node, string) =>
|
||||
!isSamplingDistribution(t)
|
||||
? switch (Render.render(params, t)) {
|
||||
|
|
Loading…
Reference in New Issue
Block a user