Moved code from Functions.re
This commit is contained in:
parent
a6051d8371
commit
d2e7605ffd
|
@ -215,18 +215,46 @@ module Normalize = {
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
module FunctionCall = {
|
||||||
|
let _runHardcodedFunction = (name, evaluationParams, args) =>
|
||||||
|
TypeSystem.Function.Ts.findByNameAndRun(
|
||||||
|
Fns.functions,
|
||||||
|
name,
|
||||||
|
evaluationParams,
|
||||||
|
args,
|
||||||
|
);
|
||||||
|
|
||||||
|
let _runLocalFunction = (name, evaluationParams: evaluationParams, args) => {
|
||||||
|
Environment.getFunction(evaluationParams.environment, name)
|
||||||
|
|> E.R.bind(_, ((argNames, fn)) =>
|
||||||
|
PTypes.Function.run(evaluationParams, args, (argNames, fn))
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
let _runWithEvaluatedInputs =
|
||||||
|
(
|
||||||
|
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
|
||||||
|
name,
|
||||||
|
args: array(ExpressionTypes.ExpressionTree.node),
|
||||||
|
) => {
|
||||||
|
_runHardcodedFunction(name, evaluationParams, args)
|
||||||
|
|> E.O.default(_runLocalFunction(name, evaluationParams, args));
|
||||||
|
};
|
||||||
|
|
||||||
// TODO: This forces things to be floats
|
// TODO: This forces things to be floats
|
||||||
let callableFunction = (evaluationParams, name, args) => {
|
let run = (evaluationParams, name, args) => {
|
||||||
args
|
args
|
||||||
|> E.A.fmap(a => evaluationParams.evaluateNode(evaluationParams, a))
|
|> E.A.fmap(a => evaluationParams.evaluateNode(evaluationParams, a))
|
||||||
|> E.A.R.firstErrorOrOpen
|
|> E.A.R.firstErrorOrOpen
|
||||||
|> E.R.bind(_, Functions.fnn(evaluationParams, name));
|
|> E.R.bind(_, _runWithEvaluatedInputs(evaluationParams, name));
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
module Render = {
|
module Render = {
|
||||||
let rec operationToLeaf =
|
let rec operationToLeaf =
|
||||||
(evaluationParams: evaluationParams, t: node): result(t, string) => {
|
(evaluationParams: evaluationParams, t: node): result(t, string) => {
|
||||||
Js.log2("rendering", t);
|
Js.log2("rendering", t);
|
||||||
switch (t) {
|
switch (t) {
|
||||||
| `Function(_) => Error("Cannot render a function")
|
| `Function(_) => Error("Cannot render a function")
|
||||||
| `SymbolicDist(d) =>
|
| `SymbolicDist(d) =>
|
||||||
|
@ -305,7 +333,7 @@ let rec toLeaf =
|
||||||
|> E.R.bind(_, toLeaf(evaluationParams));
|
|> E.R.bind(_, toLeaf(evaluationParams));
|
||||||
| `FunctionCall(name, args) =>
|
| `FunctionCall(name, args) =>
|
||||||
Js.log3("In function call", name, args);
|
Js.log3("In function call", name, args);
|
||||||
callableFunction(evaluationParams, name, args)
|
FunctionCall.run(evaluationParams, name, args)
|
||||||
|> E.R.bind(_, toLeaf(evaluationParams));
|
|> E.R.bind(_, toLeaf(evaluationParams));
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
|
@ -99,6 +99,11 @@ module ExpressionTree = {
|
||||||
);
|
);
|
||||||
let update = (t, str, fn) => MS.update(t, str, fn);
|
let update = (t, str, fn) => MS.update(t, str, fn);
|
||||||
let get = (t: t, str) => MS.get(t, str);
|
let get = (t: t, str) => MS.get(t, str);
|
||||||
|
let getFunction = (t: t, str) =>
|
||||||
|
switch (get(t, str)) {
|
||||||
|
| Some(`Function(argNames, fn)) => Ok((argNames, fn))
|
||||||
|
| _ => Error("Function " ++ str ++ " not found")
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
type evaluationParams = {
|
type evaluationParams = {
|
||||||
|
|
|
@ -1,33 +0,0 @@
|
||||||
type node = ExpressionTypes.ExpressionTree.node;
|
|
||||||
|
|
||||||
let toOkSym = r => Ok(`SymbolicDist(r));
|
|
||||||
let getFloat = ExpressionTypes.ExpressionTree.getFloat;
|
|
||||||
let fnn =
|
|
||||||
(
|
|
||||||
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
|
|
||||||
name,
|
|
||||||
args: array(node),
|
|
||||||
) => {
|
|
||||||
// let foundFn =
|
|
||||||
// TypeSystem.Function.Ts.findByName(Fns.functions, name) |> E.O.toResult("Function " ++ name ++ " not found");
|
|
||||||
// let ran = foundFn |> E.R.bind(_,TypeSystem.Function.T.run(evaluationParams,args))
|
|
||||||
|
|
||||||
let foundFn =
|
|
||||||
TypeSystem.Function.Ts.findByNameAndRun(Fns.functions, name, evaluationParams, args);
|
|
||||||
switch (foundFn) {
|
|
||||||
| Some(r) => r
|
|
||||||
| None =>
|
|
||||||
switch (
|
|
||||||
name,
|
|
||||||
ExpressionTypes.ExpressionTree.Environment.get(
|
|
||||||
evaluationParams.environment,
|
|
||||||
name,
|
|
||||||
),
|
|
||||||
) {
|
|
||||||
| (_, Some(`Function(argNames, tt))) =>
|
|
||||||
Js.log("Fundction found: " ++ name);
|
|
||||||
PTypes.Function.run(evaluationParams, args, (argNames, tt))
|
|
||||||
| _ => Error("Function " ++ name ++ " not found")
|
|
||||||
}
|
|
||||||
};
|
|
||||||
};
|
|
|
@ -33,6 +33,7 @@ module Function = {
|
||||||
} else {
|
} else {
|
||||||
Error("Wrong number of variables");
|
Error("Wrong number of variables");
|
||||||
};
|
};
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
module Primative = {
|
module Primative = {
|
||||||
|
|
|
@ -112,7 +112,7 @@ module Multimodal = {
|
||||||
|
|
||||||
let _paramsToDistsAndWeights = (r: array(typedValue)) =>
|
let _paramsToDistsAndWeights = (r: array(typedValue)) =>
|
||||||
switch (r) {
|
switch (r) {
|
||||||
| [|`Named(r)|] =>
|
| [|`Hash(r)|] =>
|
||||||
let dists =
|
let dists =
|
||||||
getByNameResult(r, "dists")
|
getByNameResult(r, "dists")
|
||||||
->E.R.bind(TypeSystem.TypedValue.toArray)
|
->E.R.bind(TypeSystem.TypedValue.toArray)
|
||||||
|
@ -171,7 +171,7 @@ module Multimodal = {
|
||||||
~name="multimodal",
|
~name="multimodal",
|
||||||
~outputType=`SamplingDistribution,
|
~outputType=`SamplingDistribution,
|
||||||
~inputTypes=[|
|
~inputTypes=[|
|
||||||
`Named([|
|
`Hash([|
|
||||||
("dists", `Array(`SamplingDistribution)),
|
("dists", `Array(`SamplingDistribution)),
|
||||||
("weights", `Array(`Float)),
|
("weights", `Array(`Float)),
|
||||||
|]),
|
|]),
|
|
@ -6,20 +6,22 @@ type samplingDist = [
|
||||||
| `RenderedDist(DistTypes.shape)
|
| `RenderedDist(DistTypes.shape)
|
||||||
];
|
];
|
||||||
|
|
||||||
type _type = [
|
type hashType = array((string, _type))
|
||||||
|
and _type = [
|
||||||
| `Float
|
| `Float
|
||||||
| `SamplingDistribution
|
| `SamplingDistribution
|
||||||
| `RenderedDistribution
|
| `RenderedDistribution
|
||||||
| `Array(_type)
|
| `Array(_type)
|
||||||
| `Named(array((string, _type)))
|
| `Hash(hashType)
|
||||||
];
|
];
|
||||||
|
|
||||||
type typedValue = [
|
type hashTypedValue = array((string, typedValue))
|
||||||
|
and typedValue = [
|
||||||
| `Float(float)
|
| `Float(float)
|
||||||
| `RenderedDist(DistTypes.shape)
|
| `RenderedDist(DistTypes.shape)
|
||||||
| `SamplingDist(samplingDist)
|
| `SamplingDist(samplingDist)
|
||||||
| `Array(array(typedValue))
|
| `Array(array(typedValue))
|
||||||
| `Named(array((string, typedValue)))
|
| `Hash(hashTypedValue)
|
||||||
];
|
];
|
||||||
|
|
||||||
type _function = {
|
type _function = {
|
||||||
|
@ -48,7 +50,7 @@ module TypedValue = {
|
||||||
hash
|
hash
|
||||||
|> E.A.fmap(((name, t)) => fromNode(t) |> E.R.fmap(r => (name, r)))
|
|> E.A.fmap(((name, t)) => fromNode(t) |> E.R.fmap(r => (name, r)))
|
||||||
|> E.A.R.firstErrorOrOpen
|
|> E.A.R.firstErrorOrOpen
|
||||||
|> E.R.fmap(r => `Named(r))
|
|> E.R.fmap(r => `Hash(r))
|
||||||
| _ => Error("Wrong type")
|
| _ => Error("Wrong type")
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -74,8 +76,8 @@ module TypedValue = {
|
||||||
|> E.A.fmap(fromNodeWithTypeCoercion(evaluationParams, _type))
|
|> E.A.fmap(fromNodeWithTypeCoercion(evaluationParams, _type))
|
||||||
|> E.A.R.firstErrorOrOpen
|
|> E.A.R.firstErrorOrOpen
|
||||||
|> E.R.fmap(r => `Array(r))
|
|> E.R.fmap(r => `Array(r))
|
||||||
| (`Named(named), `Hash(r)) =>
|
| (`Hash(named), `Hash(r)) =>
|
||||||
let foo =
|
let keyValues =
|
||||||
named
|
named
|
||||||
|> E.A.fmap(((name, intendedType)) =>
|
|> E.A.fmap(((name, intendedType)) =>
|
||||||
(
|
(
|
||||||
|
@ -84,8 +86,8 @@ module TypedValue = {
|
||||||
ExpressionTypes.ExpressionTree.Hash.getByName(r, name),
|
ExpressionTypes.ExpressionTree.Hash.getByName(r, name),
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
let bar =
|
let typedHash =
|
||||||
foo
|
keyValues
|
||||||
|> E.A.fmap(((name, intendedType, optionNode)) =>
|
|> E.A.fmap(((name, intendedType, optionNode)) =>
|
||||||
switch (optionNode) {
|
switch (optionNode) {
|
||||||
| Some(node) =>
|
| Some(node) =>
|
||||||
|
@ -95,34 +97,33 @@ module TypedValue = {
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|> E.A.R.firstErrorOrOpen
|
|> E.A.R.firstErrorOrOpen
|
||||||
|> E.R.fmap(r => `Named(r));
|
|> E.R.fmap(r => `Hash(r));
|
||||||
bar;
|
typedHash;
|
||||||
| _ => Error("fromNodeWithTypeCoercion error, sorry.")
|
| _ => Error("fromNodeWithTypeCoercion error, sorry.")
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
let toFloat =
|
let toFloat: typedValue => result(float,string) =
|
||||||
fun
|
fun
|
||||||
| `Float(x) => Ok(x)
|
| `Float(x) => Ok(x)
|
||||||
| _ => Error("Not a float");
|
| _ => Error("Not a float");
|
||||||
|
|
||||||
let toArray =
|
let toArray: typedValue => result(array('a),string) =
|
||||||
fun
|
fun
|
||||||
| `Array(x) => Ok(x)
|
| `Array(x) => Ok(x)
|
||||||
| _ => Error("Not an array");
|
| _ => Error("Not an array");
|
||||||
|
|
||||||
let toNamed =
|
let toNamed: typedValue => result(hashTypedValue, string) =
|
||||||
fun
|
fun
|
||||||
| `Named(x) => Ok(x)
|
| `Hash(x) => Ok(x)
|
||||||
| _ => Error("Not a named item");
|
| _ => Error("Not a named item");
|
||||||
|
|
||||||
let toDist =
|
let toDist =
|
||||||
fun
|
fun
|
||||||
| `SamplingDist(`SymbolicDist(c)) => Ok(`SymbolicDist(c))
|
| `SamplingDist(`SymbolicDist(c)) => Ok(`SymbolicDist(c))
|
||||||
| `SamplingDist(`RenderedDist(c)) => Ok(`RenderedDist(c))
|
| `SamplingDist(`RenderedDist(c)) => Ok(`RenderedDist(c))
|
||||||
| `Float(x) =>
|
| `Float(x) => Ok(`SymbolicDist(`Float(x)))
|
||||||
Ok(`RenderedDist(SymbolicDist.T.toShape(1000, `Float(x))))
|
| _ => Error("Cannot be converted into a distribution");
|
||||||
| _ => Error("");
|
|
||||||
};
|
};
|
||||||
|
|
||||||
module Function = {
|
module Function = {
|
||||||
|
@ -188,7 +189,6 @@ module Function = {
|
||||||
inputNodes: inputNodes,
|
inputNodes: inputNodes,
|
||||||
t: t,
|
t: t,
|
||||||
) => {
|
) => {
|
||||||
Js.log("Running!");
|
|
||||||
inputsToTypedValues(evaluationParams, inputNodes, t)->E.R.bind(t.run)
|
inputsToTypedValues(evaluationParams, inputNodes, t)->E.R.bind(t.run)
|
||||||
|> (
|
|> (
|
||||||
fun
|
fun
|
Loading…
Reference in New Issue
Block a user