Moved code from Functions.re

This commit is contained in:
Ozzie Gooen 2020-11-08 18:02:23 -08:00
parent a6051d8371
commit d2e7605ffd
6 changed files with 62 additions and 61 deletions

View File

@ -215,18 +215,46 @@ module Normalize = {
}; };
}; };
module FunctionCall = {
let _runHardcodedFunction = (name, evaluationParams, args) =>
TypeSystem.Function.Ts.findByNameAndRun(
Fns.functions,
name,
evaluationParams,
args,
);
let _runLocalFunction = (name, evaluationParams: evaluationParams, args) => {
Environment.getFunction(evaluationParams.environment, name)
|> E.R.bind(_, ((argNames, fn)) =>
PTypes.Function.run(evaluationParams, args, (argNames, fn))
);
};
let _runWithEvaluatedInputs =
(
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
name,
args: array(ExpressionTypes.ExpressionTree.node),
) => {
_runHardcodedFunction(name, evaluationParams, args)
|> E.O.default(_runLocalFunction(name, evaluationParams, args));
};
// TODO: This forces things to be floats // TODO: This forces things to be floats
let callableFunction = (evaluationParams, name, args) => { let run = (evaluationParams, name, args) => {
args args
|> E.A.fmap(a => evaluationParams.evaluateNode(evaluationParams, a)) |> E.A.fmap(a => evaluationParams.evaluateNode(evaluationParams, a))
|> E.A.R.firstErrorOrOpen |> E.A.R.firstErrorOrOpen
|> E.R.bind(_, Functions.fnn(evaluationParams, name)); |> E.R.bind(_, _runWithEvaluatedInputs(evaluationParams, name));
};
}; };
module Render = { module Render = {
let rec operationToLeaf = let rec operationToLeaf =
(evaluationParams: evaluationParams, t: node): result(t, string) => { (evaluationParams: evaluationParams, t: node): result(t, string) => {
Js.log2("rendering", t); Js.log2("rendering", t);
switch (t) { switch (t) {
| `Function(_) => Error("Cannot render a function") | `Function(_) => Error("Cannot render a function")
| `SymbolicDist(d) => | `SymbolicDist(d) =>
@ -305,7 +333,7 @@ let rec toLeaf =
|> E.R.bind(_, toLeaf(evaluationParams)); |> E.R.bind(_, toLeaf(evaluationParams));
| `FunctionCall(name, args) => | `FunctionCall(name, args) =>
Js.log3("In function call", name, args); Js.log3("In function call", name, args);
callableFunction(evaluationParams, name, args) FunctionCall.run(evaluationParams, name, args)
|> E.R.bind(_, toLeaf(evaluationParams)); |> E.R.bind(_, toLeaf(evaluationParams));
}; };
}; };

View File

@ -99,6 +99,11 @@ module ExpressionTree = {
); );
let update = (t, str, fn) => MS.update(t, str, fn); let update = (t, str, fn) => MS.update(t, str, fn);
let get = (t: t, str) => MS.get(t, str); let get = (t: t, str) => MS.get(t, str);
let getFunction = (t: t, str) =>
switch (get(t, str)) {
| Some(`Function(argNames, fn)) => Ok((argNames, fn))
| _ => Error("Function " ++ str ++ " not found")
};
}; };
type evaluationParams = { type evaluationParams = {

View File

@ -1,33 +0,0 @@
type node = ExpressionTypes.ExpressionTree.node;
let toOkSym = r => Ok(`SymbolicDist(r));
let getFloat = ExpressionTypes.ExpressionTree.getFloat;
let fnn =
(
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
name,
args: array(node),
) => {
// let foundFn =
// TypeSystem.Function.Ts.findByName(Fns.functions, name) |> E.O.toResult("Function " ++ name ++ " not found");
// let ran = foundFn |> E.R.bind(_,TypeSystem.Function.T.run(evaluationParams,args))
let foundFn =
TypeSystem.Function.Ts.findByNameAndRun(Fns.functions, name, evaluationParams, args);
switch (foundFn) {
| Some(r) => r
| None =>
switch (
name,
ExpressionTypes.ExpressionTree.Environment.get(
evaluationParams.environment,
name,
),
) {
| (_, Some(`Function(argNames, tt))) =>
Js.log("Fundction found: " ++ name);
PTypes.Function.run(evaluationParams, args, (argNames, tt))
| _ => Error("Function " ++ name ++ " not found")
}
};
};

View File

@ -33,6 +33,7 @@ module Function = {
} else { } else {
Error("Wrong number of variables"); Error("Wrong number of variables");
}; };
}; };
module Primative = { module Primative = {

View File

@ -112,7 +112,7 @@ module Multimodal = {
let _paramsToDistsAndWeights = (r: array(typedValue)) => let _paramsToDistsAndWeights = (r: array(typedValue)) =>
switch (r) { switch (r) {
| [|`Named(r)|] => | [|`Hash(r)|] =>
let dists = let dists =
getByNameResult(r, "dists") getByNameResult(r, "dists")
->E.R.bind(TypeSystem.TypedValue.toArray) ->E.R.bind(TypeSystem.TypedValue.toArray)
@ -171,7 +171,7 @@ module Multimodal = {
~name="multimodal", ~name="multimodal",
~outputType=`SamplingDistribution, ~outputType=`SamplingDistribution,
~inputTypes=[| ~inputTypes=[|
`Named([| `Hash([|
("dists", `Array(`SamplingDistribution)), ("dists", `Array(`SamplingDistribution)),
("weights", `Array(`Float)), ("weights", `Array(`Float)),
|]), |]),

View File

@ -6,20 +6,22 @@ type samplingDist = [
| `RenderedDist(DistTypes.shape) | `RenderedDist(DistTypes.shape)
]; ];
type _type = [ type hashType = array((string, _type))
and _type = [
| `Float | `Float
| `SamplingDistribution | `SamplingDistribution
| `RenderedDistribution | `RenderedDistribution
| `Array(_type) | `Array(_type)
| `Named(array((string, _type))) | `Hash(hashType)
]; ];
type typedValue = [ type hashTypedValue = array((string, typedValue))
and typedValue = [
| `Float(float) | `Float(float)
| `RenderedDist(DistTypes.shape) | `RenderedDist(DistTypes.shape)
| `SamplingDist(samplingDist) | `SamplingDist(samplingDist)
| `Array(array(typedValue)) | `Array(array(typedValue))
| `Named(array((string, typedValue))) | `Hash(hashTypedValue)
]; ];
type _function = { type _function = {
@ -48,7 +50,7 @@ module TypedValue = {
hash hash
|> E.A.fmap(((name, t)) => fromNode(t) |> E.R.fmap(r => (name, r))) |> E.A.fmap(((name, t)) => fromNode(t) |> E.R.fmap(r => (name, r)))
|> E.A.R.firstErrorOrOpen |> E.A.R.firstErrorOrOpen
|> E.R.fmap(r => `Named(r)) |> E.R.fmap(r => `Hash(r))
| _ => Error("Wrong type") | _ => Error("Wrong type")
}; };
@ -74,8 +76,8 @@ module TypedValue = {
|> E.A.fmap(fromNodeWithTypeCoercion(evaluationParams, _type)) |> E.A.fmap(fromNodeWithTypeCoercion(evaluationParams, _type))
|> E.A.R.firstErrorOrOpen |> E.A.R.firstErrorOrOpen
|> E.R.fmap(r => `Array(r)) |> E.R.fmap(r => `Array(r))
| (`Named(named), `Hash(r)) => | (`Hash(named), `Hash(r)) =>
let foo = let keyValues =
named named
|> E.A.fmap(((name, intendedType)) => |> E.A.fmap(((name, intendedType)) =>
( (
@ -84,8 +86,8 @@ module TypedValue = {
ExpressionTypes.ExpressionTree.Hash.getByName(r, name), ExpressionTypes.ExpressionTree.Hash.getByName(r, name),
) )
); );
let bar = let typedHash =
foo keyValues
|> E.A.fmap(((name, intendedType, optionNode)) => |> E.A.fmap(((name, intendedType, optionNode)) =>
switch (optionNode) { switch (optionNode) {
| Some(node) => | Some(node) =>
@ -95,34 +97,33 @@ module TypedValue = {
} }
) )
|> E.A.R.firstErrorOrOpen |> E.A.R.firstErrorOrOpen
|> E.R.fmap(r => `Named(r)); |> E.R.fmap(r => `Hash(r));
bar; typedHash;
| _ => Error("fromNodeWithTypeCoercion error, sorry.") | _ => Error("fromNodeWithTypeCoercion error, sorry.")
}; };
}; };
let toFloat = let toFloat: typedValue => result(float,string) =
fun fun
| `Float(x) => Ok(x) | `Float(x) => Ok(x)
| _ => Error("Not a float"); | _ => Error("Not a float");
let toArray = let toArray: typedValue => result(array('a),string) =
fun fun
| `Array(x) => Ok(x) | `Array(x) => Ok(x)
| _ => Error("Not an array"); | _ => Error("Not an array");
let toNamed = let toNamed: typedValue => result(hashTypedValue, string) =
fun fun
| `Named(x) => Ok(x) | `Hash(x) => Ok(x)
| _ => Error("Not a named item"); | _ => Error("Not a named item");
let toDist = let toDist =
fun fun
| `SamplingDist(`SymbolicDist(c)) => Ok(`SymbolicDist(c)) | `SamplingDist(`SymbolicDist(c)) => Ok(`SymbolicDist(c))
| `SamplingDist(`RenderedDist(c)) => Ok(`RenderedDist(c)) | `SamplingDist(`RenderedDist(c)) => Ok(`RenderedDist(c))
| `Float(x) => | `Float(x) => Ok(`SymbolicDist(`Float(x)))
Ok(`RenderedDist(SymbolicDist.T.toShape(1000, `Float(x)))) | _ => Error("Cannot be converted into a distribution");
| _ => Error("");
}; };
module Function = { module Function = {
@ -188,7 +189,6 @@ module Function = {
inputNodes: inputNodes, inputNodes: inputNodes,
t: t, t: t,
) => { ) => {
Js.log("Running!");
inputsToTypedValues(evaluationParams, inputNodes, t)->E.R.bind(t.run) inputsToTypedValues(evaluationParams, inputNodes, t)->E.R.bind(t.run)
|> ( |> (
fun fun