Minor cleanup

This commit is contained in:
Ozzie Gooen 2020-08-07 10:29:47 +01:00
parent f72a3ddde6
commit 3147e6d9c3
5 changed files with 71 additions and 50 deletions

View File

@ -289,28 +289,6 @@ module Render = {
};
};
let run = (node, fnNode) => {
switch (fnNode) {
| `Function(r) => Ok(r(node))
| _ => Error("Not a function")
};
};
// let outputType = (t:t) => t |> fun
// | `SymbolicDist(_) => `RenderedDist
// | `RenderedDist(_) => `RenderedDist
// | `Bool(_) => `RenderedDist
// | `AlgebraicCombination(_) => `RenderedDist
// | `PointwiseCombination(_) => `RenderedDist
// | `VerticalScaling(_) => `RenderedDist
// | `Truncate(_) => `RenderedDist
// | `FloatFromDist(_) => `RenderedDist
// | `Normalize(_) => `RenderedDist
// | `Render(_) => `RenderedDist
// | `Function(_) => `Function
// | `FunctionCall(_) => `Any
// | `Symbol(_) => `Any
/* This function recursively goes through the nodes of the parse tree,
replacing each Operation node and its subtree with a Data node.
Whenever possible, the replacement produces a new Symbolic Data node,

View File

@ -18,12 +18,13 @@ module ExpressionTree = {
| `AlgebraicCombination(algebraicOperation, node, node)
| `PointwiseCombination(pointwiseOperation, node, node)
| `VerticalScaling(scaleOperation, node, node)
| `Normalize(node)
| `Render(node)
| `Truncate(option(float), option(float), node)
| `Normalize(node)
| `FloatFromDist(distToFloatOperation, node)
| `FunctionCall(string, array(node))
];
// Have nil as option
type samplingInputs = {
sampleCount: int,

View File

@ -49,31 +49,6 @@ let to_: array(node) => result(node, string) =
Error("Low value must be less than high value.")
| _ => Error("Requires 2 variables");
let processCustomFn =
(
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
args: array(node),
argNames: array(string),
fnResult: node,
) =>
if (E.A.length(args) == E.A.length(argNames)) {
let newEnvironment =
Belt.Array.zip(argNames, args)
|> ExpressionTypes.ExpressionTree.Environment.fromArray;
let newEvaluationParams: ExpressionTypes.ExpressionTree.evaluationParams = {
samplingInputs: evaluationParams.samplingInputs,
environment:
ExpressionTypes.ExpressionTree.Environment.mergeKeepSecond(
evaluationParams.environment,
newEnvironment,
),
evaluateNode: evaluationParams.evaluateNode,
};
evaluationParams.evaluateNode(newEvaluationParams, fnResult);
} else {
Error("Failure");
};
let fnn =
(
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
@ -88,7 +63,7 @@ let fnn =
),
) {
| (_, Some(`Function(argNames, tt))) =>
processCustomFn(evaluationParams, args, argNames, tt)
PTypes.Function.run(evaluationParams, args, (argNames, tt))
| ("normal", _) =>
apply2(twoFloatsToOkSym(SymbolicDist.Normal.make), args)
| ("uniform", _) =>

View File

@ -1,4 +1,3 @@
[%%debugger.chrome]
module MathJsonToMathJsAdt = {
type arg =
| Symbol(string)

View File

@ -1,12 +1,80 @@
open ExpressionTypes.ExpressionTree;
module Function = {
type t = (array(string), node);
let fromNode: node => option(t) =
node =>
switch (node) {
| `Function(r) => Some(r)
| _ => None
};
let argumentNames = ((a, _): t) => a;
let internals = ((_, b): t) => b;
let run =
(
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
args: array(node),
t: t,
) =>
if (E.A.length(args) == E.A.length(argumentNames(t))) {
let newEnvironment =
Belt.Array.zip(argumentNames(t), args)
|> ExpressionTypes.ExpressionTree.Environment.fromArray;
let newEvaluationParams: ExpressionTypes.ExpressionTree.evaluationParams = {
samplingInputs: evaluationParams.samplingInputs,
environment:
ExpressionTypes.ExpressionTree.Environment.mergeKeepSecond(
evaluationParams.environment,
newEnvironment,
),
evaluateNode: evaluationParams.evaluateNode,
};
evaluationParams.evaluateNode(newEvaluationParams, internals(t));
} else {
Error("Failure");
};
};
module Primative = {
type t = [
| `SymbolicDist(SymbolicTypes.symbolicDist)
| `RenderedDist(DistTypes.shape)
| `Function(array(string), node)
];
let isPrimative: node => bool =
fun
| `SymbolicDist(_)
| `RenderedDist(_)
| `Function(_) => true
| _ => false;
let fromNode: node => option(t) =
fun
| `SymbolicDist(_) as n
| `RenderedDist(_) as n
| `Function(_) as n => Some(n)
| _ => None;
};
module SamplingDistribution = {
type t = [
| `SymbolicDist(SymbolicTypes.symbolicDist)
| `RenderedDist(DistTypes.shape)
];
let isSamplingDistribution: node => bool =
fun
| `SymbolicDist(_) => true
| `RenderedDist(_) => true
| _ => false;
let fromNode: node => result(t, string) =
fun
| `SymbolicDist(n) => Ok(`SymbolicDist(n))
| `RenderedDist(n) => Ok(`RenderedDist(n))
| _ => Error("Not valid type");
let renderIfIsNotSamplingDistribution = (params, t): result(node, string) =>
!isSamplingDistribution(t)
? switch (Render.render(params, t)) {