Moved code from Functions.re
This commit is contained in:
parent
a6051d8371
commit
d2e7605ffd
|
@ -215,18 +215,46 @@ module Normalize = {
|
|||
};
|
||||
};
|
||||
|
||||
module FunctionCall = {
|
||||
let _runHardcodedFunction = (name, evaluationParams, args) =>
|
||||
TypeSystem.Function.Ts.findByNameAndRun(
|
||||
Fns.functions,
|
||||
name,
|
||||
evaluationParams,
|
||||
args,
|
||||
);
|
||||
|
||||
let _runLocalFunction = (name, evaluationParams: evaluationParams, args) => {
|
||||
Environment.getFunction(evaluationParams.environment, name)
|
||||
|> E.R.bind(_, ((argNames, fn)) =>
|
||||
PTypes.Function.run(evaluationParams, args, (argNames, fn))
|
||||
);
|
||||
};
|
||||
|
||||
let _runWithEvaluatedInputs =
|
||||
(
|
||||
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
|
||||
name,
|
||||
args: array(ExpressionTypes.ExpressionTree.node),
|
||||
) => {
|
||||
_runHardcodedFunction(name, evaluationParams, args)
|
||||
|> E.O.default(_runLocalFunction(name, evaluationParams, args));
|
||||
};
|
||||
|
||||
// TODO: This forces things to be floats
|
||||
let callableFunction = (evaluationParams, name, args) => {
|
||||
args
|
||||
|> E.A.fmap(a => evaluationParams.evaluateNode(evaluationParams, a))
|
||||
|> E.A.R.firstErrorOrOpen
|
||||
|> E.R.bind(_, Functions.fnn(evaluationParams, name));
|
||||
let run = (evaluationParams, name, args) => {
|
||||
args
|
||||
|> E.A.fmap(a => evaluationParams.evaluateNode(evaluationParams, a))
|
||||
|> E.A.R.firstErrorOrOpen
|
||||
|> E.R.bind(_, _runWithEvaluatedInputs(evaluationParams, name));
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
module Render = {
|
||||
let rec operationToLeaf =
|
||||
(evaluationParams: evaluationParams, t: node): result(t, string) => {
|
||||
Js.log2("rendering", t);
|
||||
Js.log2("rendering", t);
|
||||
switch (t) {
|
||||
| `Function(_) => Error("Cannot render a function")
|
||||
| `SymbolicDist(d) =>
|
||||
|
@ -305,7 +333,7 @@ let rec toLeaf =
|
|||
|> E.R.bind(_, toLeaf(evaluationParams));
|
||||
| `FunctionCall(name, args) =>
|
||||
Js.log3("In function call", name, args);
|
||||
callableFunction(evaluationParams, name, args)
|
||||
FunctionCall.run(evaluationParams, name, args)
|
||||
|> E.R.bind(_, toLeaf(evaluationParams));
|
||||
};
|
||||
};
|
||||
|
|
|
@ -99,6 +99,11 @@ module ExpressionTree = {
|
|||
);
|
||||
let update = (t, str, fn) => MS.update(t, str, fn);
|
||||
let get = (t: t, str) => MS.get(t, str);
|
||||
let getFunction = (t: t, str) =>
|
||||
switch (get(t, str)) {
|
||||
| Some(`Function(argNames, fn)) => Ok((argNames, fn))
|
||||
| _ => Error("Function " ++ str ++ " not found")
|
||||
};
|
||||
};
|
||||
|
||||
type evaluationParams = {
|
||||
|
|
|
@ -1,33 +0,0 @@
|
|||
type node = ExpressionTypes.ExpressionTree.node;
|
||||
|
||||
let toOkSym = r => Ok(`SymbolicDist(r));
|
||||
let getFloat = ExpressionTypes.ExpressionTree.getFloat;
|
||||
let fnn =
|
||||
(
|
||||
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
|
||||
name,
|
||||
args: array(node),
|
||||
) => {
|
||||
// let foundFn =
|
||||
// TypeSystem.Function.Ts.findByName(Fns.functions, name) |> E.O.toResult("Function " ++ name ++ " not found");
|
||||
// let ran = foundFn |> E.R.bind(_,TypeSystem.Function.T.run(evaluationParams,args))
|
||||
|
||||
let foundFn =
|
||||
TypeSystem.Function.Ts.findByNameAndRun(Fns.functions, name, evaluationParams, args);
|
||||
switch (foundFn) {
|
||||
| Some(r) => r
|
||||
| None =>
|
||||
switch (
|
||||
name,
|
||||
ExpressionTypes.ExpressionTree.Environment.get(
|
||||
evaluationParams.environment,
|
||||
name,
|
||||
),
|
||||
) {
|
||||
| (_, Some(`Function(argNames, tt))) =>
|
||||
Js.log("Fundction found: " ++ name);
|
||||
PTypes.Function.run(evaluationParams, args, (argNames, tt))
|
||||
| _ => Error("Function " ++ name ++ " not found")
|
||||
}
|
||||
};
|
||||
};
|
|
@ -33,6 +33,7 @@ module Function = {
|
|||
} else {
|
||||
Error("Wrong number of variables");
|
||||
};
|
||||
|
||||
};
|
||||
|
||||
module Primative = {
|
||||
|
|
|
@ -112,7 +112,7 @@ module Multimodal = {
|
|||
|
||||
let _paramsToDistsAndWeights = (r: array(typedValue)) =>
|
||||
switch (r) {
|
||||
| [|`Named(r)|] =>
|
||||
| [|`Hash(r)|] =>
|
||||
let dists =
|
||||
getByNameResult(r, "dists")
|
||||
->E.R.bind(TypeSystem.TypedValue.toArray)
|
||||
|
@ -171,7 +171,7 @@ module Multimodal = {
|
|||
~name="multimodal",
|
||||
~outputType=`SamplingDistribution,
|
||||
~inputTypes=[|
|
||||
`Named([|
|
||||
`Hash([|
|
||||
("dists", `Array(`SamplingDistribution)),
|
||||
("weights", `Array(`Float)),
|
||||
|]),
|
|
@ -6,20 +6,22 @@ type samplingDist = [
|
|||
| `RenderedDist(DistTypes.shape)
|
||||
];
|
||||
|
||||
type _type = [
|
||||
type hashType = array((string, _type))
|
||||
and _type = [
|
||||
| `Float
|
||||
| `SamplingDistribution
|
||||
| `RenderedDistribution
|
||||
| `Array(_type)
|
||||
| `Named(array((string, _type)))
|
||||
| `Hash(hashType)
|
||||
];
|
||||
|
||||
type typedValue = [
|
||||
type hashTypedValue = array((string, typedValue))
|
||||
and typedValue = [
|
||||
| `Float(float)
|
||||
| `RenderedDist(DistTypes.shape)
|
||||
| `SamplingDist(samplingDist)
|
||||
| `Array(array(typedValue))
|
||||
| `Named(array((string, typedValue)))
|
||||
| `Hash(hashTypedValue)
|
||||
];
|
||||
|
||||
type _function = {
|
||||
|
@ -48,7 +50,7 @@ module TypedValue = {
|
|||
hash
|
||||
|> E.A.fmap(((name, t)) => fromNode(t) |> E.R.fmap(r => (name, r)))
|
||||
|> E.A.R.firstErrorOrOpen
|
||||
|> E.R.fmap(r => `Named(r))
|
||||
|> E.R.fmap(r => `Hash(r))
|
||||
| _ => Error("Wrong type")
|
||||
};
|
||||
|
||||
|
@ -74,8 +76,8 @@ module TypedValue = {
|
|||
|> E.A.fmap(fromNodeWithTypeCoercion(evaluationParams, _type))
|
||||
|> E.A.R.firstErrorOrOpen
|
||||
|> E.R.fmap(r => `Array(r))
|
||||
| (`Named(named), `Hash(r)) =>
|
||||
let foo =
|
||||
| (`Hash(named), `Hash(r)) =>
|
||||
let keyValues =
|
||||
named
|
||||
|> E.A.fmap(((name, intendedType)) =>
|
||||
(
|
||||
|
@ -84,8 +86,8 @@ module TypedValue = {
|
|||
ExpressionTypes.ExpressionTree.Hash.getByName(r, name),
|
||||
)
|
||||
);
|
||||
let bar =
|
||||
foo
|
||||
let typedHash =
|
||||
keyValues
|
||||
|> E.A.fmap(((name, intendedType, optionNode)) =>
|
||||
switch (optionNode) {
|
||||
| Some(node) =>
|
||||
|
@ -95,34 +97,33 @@ module TypedValue = {
|
|||
}
|
||||
)
|
||||
|> E.A.R.firstErrorOrOpen
|
||||
|> E.R.fmap(r => `Named(r));
|
||||
bar;
|
||||
|> E.R.fmap(r => `Hash(r));
|
||||
typedHash;
|
||||
| _ => Error("fromNodeWithTypeCoercion error, sorry.")
|
||||
};
|
||||
};
|
||||
|
||||
let toFloat =
|
||||
let toFloat: typedValue => result(float,string) =
|
||||
fun
|
||||
| `Float(x) => Ok(x)
|
||||
| _ => Error("Not a float");
|
||||
|
||||
let toArray =
|
||||
let toArray: typedValue => result(array('a),string) =
|
||||
fun
|
||||
| `Array(x) => Ok(x)
|
||||
| _ => Error("Not an array");
|
||||
|
||||
let toNamed =
|
||||
let toNamed: typedValue => result(hashTypedValue, string) =
|
||||
fun
|
||||
| `Named(x) => Ok(x)
|
||||
| `Hash(x) => Ok(x)
|
||||
| _ => Error("Not a named item");
|
||||
|
||||
let toDist =
|
||||
fun
|
||||
| `SamplingDist(`SymbolicDist(c)) => Ok(`SymbolicDist(c))
|
||||
| `SamplingDist(`RenderedDist(c)) => Ok(`RenderedDist(c))
|
||||
| `Float(x) =>
|
||||
Ok(`RenderedDist(SymbolicDist.T.toShape(1000, `Float(x))))
|
||||
| _ => Error("");
|
||||
| `Float(x) => Ok(`SymbolicDist(`Float(x)))
|
||||
| _ => Error("Cannot be converted into a distribution");
|
||||
};
|
||||
|
||||
module Function = {
|
||||
|
@ -188,7 +189,6 @@ module Function = {
|
|||
inputNodes: inputNodes,
|
||||
t: t,
|
||||
) => {
|
||||
Js.log("Running!");
|
||||
inputsToTypedValues(evaluationParams, inputNodes, t)->E.R.bind(t.run)
|
||||
|> (
|
||||
fun
|
Loading…
Reference in New Issue
Block a user