Moved code from Functions.re

This commit is contained in:
Ozzie Gooen 2020-11-08 18:02:23 -08:00
parent a6051d8371
commit d2e7605ffd
6 changed files with 62 additions and 61 deletions

View File

@ -215,18 +215,46 @@ module Normalize = {
};
};
module FunctionCall = {
let _runHardcodedFunction = (name, evaluationParams, args) =>
TypeSystem.Function.Ts.findByNameAndRun(
Fns.functions,
name,
evaluationParams,
args,
);
let _runLocalFunction = (name, evaluationParams: evaluationParams, args) => {
Environment.getFunction(evaluationParams.environment, name)
|> E.R.bind(_, ((argNames, fn)) =>
PTypes.Function.run(evaluationParams, args, (argNames, fn))
);
};
let _runWithEvaluatedInputs =
(
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
name,
args: array(ExpressionTypes.ExpressionTree.node),
) => {
_runHardcodedFunction(name, evaluationParams, args)
|> E.O.default(_runLocalFunction(name, evaluationParams, args));
};
// TODO: This forces things to be floats
let callableFunction = (evaluationParams, name, args) => {
args
|> E.A.fmap(a => evaluationParams.evaluateNode(evaluationParams, a))
|> E.A.R.firstErrorOrOpen
|> E.R.bind(_, Functions.fnn(evaluationParams, name));
let run = (evaluationParams, name, args) => {
args
|> E.A.fmap(a => evaluationParams.evaluateNode(evaluationParams, a))
|> E.A.R.firstErrorOrOpen
|> E.R.bind(_, _runWithEvaluatedInputs(evaluationParams, name));
};
};
module Render = {
let rec operationToLeaf =
(evaluationParams: evaluationParams, t: node): result(t, string) => {
Js.log2("rendering", t);
Js.log2("rendering", t);
switch (t) {
| `Function(_) => Error("Cannot render a function")
| `SymbolicDist(d) =>
@ -305,7 +333,7 @@ let rec toLeaf =
|> E.R.bind(_, toLeaf(evaluationParams));
| `FunctionCall(name, args) =>
Js.log3("In function call", name, args);
callableFunction(evaluationParams, name, args)
FunctionCall.run(evaluationParams, name, args)
|> E.R.bind(_, toLeaf(evaluationParams));
};
};

View File

@ -99,6 +99,11 @@ module ExpressionTree = {
);
let update = (t, str, fn) => MS.update(t, str, fn);
let get = (t: t, str) => MS.get(t, str);
let getFunction = (t: t, str) =>
switch (get(t, str)) {
| Some(`Function(argNames, fn)) => Ok((argNames, fn))
| _ => Error("Function " ++ str ++ " not found")
};
};
type evaluationParams = {

View File

@ -1,33 +0,0 @@
type node = ExpressionTypes.ExpressionTree.node;
let toOkSym = r => Ok(`SymbolicDist(r));
let getFloat = ExpressionTypes.ExpressionTree.getFloat;
let fnn =
(
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
name,
args: array(node),
) => {
// let foundFn =
// TypeSystem.Function.Ts.findByName(Fns.functions, name) |> E.O.toResult("Function " ++ name ++ " not found");
// let ran = foundFn |> E.R.bind(_,TypeSystem.Function.T.run(evaluationParams,args))
let foundFn =
TypeSystem.Function.Ts.findByNameAndRun(Fns.functions, name, evaluationParams, args);
switch (foundFn) {
| Some(r) => r
| None =>
switch (
name,
ExpressionTypes.ExpressionTree.Environment.get(
evaluationParams.environment,
name,
),
) {
| (_, Some(`Function(argNames, tt))) =>
Js.log("Fundction found: " ++ name);
PTypes.Function.run(evaluationParams, args, (argNames, tt))
| _ => Error("Function " ++ name ++ " not found")
}
};
};

View File

@ -33,6 +33,7 @@ module Function = {
} else {
Error("Wrong number of variables");
};
};
module Primative = {

View File

@ -112,7 +112,7 @@ module Multimodal = {
let _paramsToDistsAndWeights = (r: array(typedValue)) =>
switch (r) {
| [|`Named(r)|] =>
| [|`Hash(r)|] =>
let dists =
getByNameResult(r, "dists")
->E.R.bind(TypeSystem.TypedValue.toArray)
@ -171,7 +171,7 @@ module Multimodal = {
~name="multimodal",
~outputType=`SamplingDistribution,
~inputTypes=[|
`Named([|
`Hash([|
("dists", `Array(`SamplingDistribution)),
("weights", `Array(`Float)),
|]),

View File

@ -6,20 +6,22 @@ type samplingDist = [
| `RenderedDist(DistTypes.shape)
];
type _type = [
type hashType = array((string, _type))
and _type = [
| `Float
| `SamplingDistribution
| `RenderedDistribution
| `Array(_type)
| `Named(array((string, _type)))
| `Hash(hashType)
];
type typedValue = [
type hashTypedValue = array((string, typedValue))
and typedValue = [
| `Float(float)
| `RenderedDist(DistTypes.shape)
| `SamplingDist(samplingDist)
| `Array(array(typedValue))
| `Named(array((string, typedValue)))
| `Hash(hashTypedValue)
];
type _function = {
@ -48,7 +50,7 @@ module TypedValue = {
hash
|> E.A.fmap(((name, t)) => fromNode(t) |> E.R.fmap(r => (name, r)))
|> E.A.R.firstErrorOrOpen
|> E.R.fmap(r => `Named(r))
|> E.R.fmap(r => `Hash(r))
| _ => Error("Wrong type")
};
@ -74,8 +76,8 @@ module TypedValue = {
|> E.A.fmap(fromNodeWithTypeCoercion(evaluationParams, _type))
|> E.A.R.firstErrorOrOpen
|> E.R.fmap(r => `Array(r))
| (`Named(named), `Hash(r)) =>
let foo =
| (`Hash(named), `Hash(r)) =>
let keyValues =
named
|> E.A.fmap(((name, intendedType)) =>
(
@ -84,8 +86,8 @@ module TypedValue = {
ExpressionTypes.ExpressionTree.Hash.getByName(r, name),
)
);
let bar =
foo
let typedHash =
keyValues
|> E.A.fmap(((name, intendedType, optionNode)) =>
switch (optionNode) {
| Some(node) =>
@ -95,34 +97,33 @@ module TypedValue = {
}
)
|> E.A.R.firstErrorOrOpen
|> E.R.fmap(r => `Named(r));
bar;
|> E.R.fmap(r => `Hash(r));
typedHash;
| _ => Error("fromNodeWithTypeCoercion error, sorry.")
};
};
let toFloat =
let toFloat: typedValue => result(float,string) =
fun
| `Float(x) => Ok(x)
| _ => Error("Not a float");
let toArray =
let toArray: typedValue => result(array('a),string) =
fun
| `Array(x) => Ok(x)
| _ => Error("Not an array");
let toNamed =
let toNamed: typedValue => result(hashTypedValue, string) =
fun
| `Named(x) => Ok(x)
| `Hash(x) => Ok(x)
| _ => Error("Not a named item");
let toDist =
fun
| `SamplingDist(`SymbolicDist(c)) => Ok(`SymbolicDist(c))
| `SamplingDist(`RenderedDist(c)) => Ok(`RenderedDist(c))
| `Float(x) =>
Ok(`RenderedDist(SymbolicDist.T.toShape(1000, `Float(x))))
| _ => Error("");
| `Float(x) => Ok(`SymbolicDist(`Float(x)))
| _ => Error("Cannot be converted into a distribution");
};
module Function = {
@ -188,7 +189,6 @@ module Function = {
inputNodes: inputNodes,
t: t,
) => {
Js.log("Running!");
inputsToTypedValues(evaluationParams, inputNodes, t)->E.R.bind(t.run)
|> (
fun