Very simple functions working

This commit is contained in:
Ozzie Gooen 2020-07-31 11:27:16 +01:00
parent 7ba49f7219
commit 4f08533055
3 changed files with 69 additions and 20 deletions

View File

@ -20,7 +20,7 @@ module ExpressionTree = {
| `Truncate(option(float), option(float), node) | `Truncate(option(float), option(float), node)
| `Normalize(node) | `Normalize(node)
| `FloatFromDist(distToFloatOperation, node) | `FloatFromDist(distToFloatOperation, node)
| `Function(node => result(node, string)) | `Function(array(string), node)
| `CallableFunction(string, array(node)) | `CallableFunction(string, array(node))
| `Symbol(string) | `Symbol(string)
]; ];
@ -36,9 +36,16 @@ module ExpressionTree = {
module Environment = { module Environment = {
type t = environment type t = environment
let empty:t = [||]->Belt.Map.String.fromArray module MS = Belt.Map.String;
let update = (t,str, fn) => Belt.Map.String.update(t, str, fn) let fromArray = MS.fromArray
let get = (t:t,str) => Belt.Map.String.get(t, str) let empty:t = [||]->fromArray;
let mergeKeepSecond = (a:t,b:t) => MS.merge(a,b, (_,a,b) =>switch(a,b){
| (_, Some(b)) => Some(b)
| (Some(a), _) => Some(a)
| _ => None
})
let update = (t,str, fn) => MS.update(t, str, fn)
let get = (t:t,str) => MS.get(t, str)
} }
type evaluationParams = { type evaluationParams = {

View File

@ -49,21 +49,62 @@ let to_: array(node) => result(node, string) =
Error("Low value must be less than high value.") Error("Low value must be less than high value.")
| _ => Error("Requires 2 variables"); | _ => Error("Requires 2 variables");
let fnn = (evaluationParams:ExpressionTypes.ExpressionTree.evaluationParams, name, args: array(node)) => { let processCustomFn =
switch (name, ExpressionTypes.ExpressionTree.Environment.get(evaluationParams.environment, name)) { (
| (_, Some(`Function(t))) => t(`Function(t)); evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
| ("normal", _) => apply2(twoFloatsToOkSym(SymbolicDist.Normal.make), args) args: array(node),
| ("uniform", _) => apply2(twoFloatsToOkSym(SymbolicDist.Uniform.make), args) argNames: array(string),
fnResult: node,
) =>
if (E.A.length(args) == E.A.length(argNames)) {
let newEnvironment =
Belt.Array.zip(argNames, args)
|> ExpressionTypes.ExpressionTree.Environment.fromArray;
let newEvaluationParams: ExpressionTypes.ExpressionTree.evaluationParams = {
samplingInputs: evaluationParams.samplingInputs,
environment:
ExpressionTypes.ExpressionTree.Environment.mergeKeepSecond(
evaluationParams.environment,
newEnvironment,
),
evaluateNode: evaluationParams.evaluateNode,
};
Js.log4("HI", newEnvironment, newEvaluationParams, args);
evaluationParams.evaluateNode(newEvaluationParams, fnResult);
} else {
Error("Failure");
};
let fnn =
(
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
name,
args: array(node),
) => {
Js.log3("Trying function", name, evaluationParams.environment);
switch (
name,
ExpressionTypes.ExpressionTree.Environment.get(
evaluationParams.environment,
name,
),
) {
| (_, Some(`Function(argNames, tt))) => processCustomFn(evaluationParams, args, argNames, tt)
| ("normal", _) =>
apply2(twoFloatsToOkSym(SymbolicDist.Normal.make), args)
| ("uniform", _) =>
apply2(twoFloatsToOkSym(SymbolicDist.Uniform.make), args)
| ("beta", _) => apply2(twoFloatsToOkSym(SymbolicDist.Beta.make), args) | ("beta", _) => apply2(twoFloatsToOkSym(SymbolicDist.Beta.make), args)
| ("cauchy", _) => apply2(twoFloatsToOkSym(SymbolicDist.Cauchy.make), args) | ("cauchy", _) =>
| ("lognormal", _) => apply2(twoFloatsToOkSym(SymbolicDist.Lognormal.make), args) apply2(twoFloatsToOkSym(SymbolicDist.Cauchy.make), args)
| ("lognormalFromMeanAndStdDev", _) => apply2(twoFloatsToOkSym(SymbolicDist.Lognormal.fromMeanAndStdev), args) | ("lognormal", _) =>
apply2(twoFloatsToOkSym(SymbolicDist.Lognormal.make), args)
| ("lognormalFromMeanAndStdDev", _) =>
apply2(twoFloatsToOkSym(SymbolicDist.Lognormal.fromMeanAndStdev), args)
| ("exponential", _) => | ("exponential", _) =>
switch (args) { switch (args) {
| [| | [|`SymbolicDist(`Float(a))|] =>
`SymbolicDist(`Float(a)), Ok(`SymbolicDist(SymbolicDist.Exponential.make(a)))
|] =>
Ok(`SymbolicDist(SymbolicDist.Exponential.make(a)));
| _ => Error("Needs 3 valid arguments") | _ => Error("Needs 3 valid arguments")
} }
| ("triangular", _) => | ("triangular", _) =>

View File

@ -123,7 +123,7 @@ module MathAdtToDistDst = {
switch (args) { switch (args) {
| [|Object(o)|] => | [|Object(o)|] =>
let g = s => let g = s =>
Js.Dict.get(o, s) |> E.O.toResult("") |> E.R.bind(_, nodeParser); Js.Dict.get(o, s) |> E.O.toResult("Variable was empty") |> E.R.bind(_, nodeParser);
switch (g("mean"), g("stdev"), g("mu"), g("sigma")) { switch (g("mean"), g("stdev"), g("mu"), g("sigma")) {
| (Ok(mean), Ok(stdev), _, _) => | (Ok(mean), Ok(stdev), _, _) =>
Ok( Ok(
@ -326,8 +326,8 @@ module MathAdtToDistDst = {
switch (r) { switch (r) {
| FunctionAssignment({name, args, expression}) => | FunctionAssignment({name, args, expression}) =>
switch (nodeParser(inputVars, expression)) { switch (nodeParser(inputVars, expression)) {
| Ok(r) => Ok([|`Assignment((name, `Function(_ => Ok(r))))|]) | Ok(r) => Ok([|`Assignment((name, `Function(args, r)))|])
| _ => Error("") | Error(r) => Error(r)
} }
| Value(_) as r => | Value(_) as r =>
nodeParser(inputVars, r) |> E.R.fmap(r => [|`Expression(r)|]) nodeParser(inputVars, r) |> E.R.fmap(r => [|`Expression(r)|])
@ -382,6 +382,7 @@ let fromString2 = (inputVars: inputVars, str) => {
}); });
let value = E.R.bind(mathJsParse, MathAdtToDistDst.run(inputVars)); let value = E.R.bind(mathJsParse, MathAdtToDistDst.run(inputVars));
Js.log3("Parsed", mathJsParse, value);
value; value;
}; };