From 4f08533055a81e418caeacaea65ddfd3cc211bdc Mon Sep 17 00:00:00 2001 From: Ozzie Gooen Date: Fri, 31 Jul 2020 11:27:16 +0100 Subject: [PATCH] Very simple functions working --- .../expressionTree/ExpressionTypes.re | 15 +++-- src/distPlus/expressionTree/Functions.re | 67 +++++++++++++++---- src/distPlus/expressionTree/MathJsParser.re | 7 +- 3 files changed, 69 insertions(+), 20 deletions(-) diff --git a/src/distPlus/expressionTree/ExpressionTypes.re b/src/distPlus/expressionTree/ExpressionTypes.re index c3a93a1f..7c0ae0da 100644 --- a/src/distPlus/expressionTree/ExpressionTypes.re +++ b/src/distPlus/expressionTree/ExpressionTypes.re @@ -20,7 +20,7 @@ module ExpressionTree = { | `Truncate(option(float), option(float), node) | `Normalize(node) | `FloatFromDist(distToFloatOperation, node) - | `Function(node => result(node, string)) + | `Function(array(string), node) | `CallableFunction(string, array(node)) | `Symbol(string) ]; @@ -36,9 +36,16 @@ module ExpressionTree = { module Environment = { type t = environment - let empty:t = [||]->Belt.Map.String.fromArray - let update = (t,str, fn) => Belt.Map.String.update(t, str, fn) - let get = (t:t,str) => Belt.Map.String.get(t, str) + module MS = Belt.Map.String; + let fromArray = MS.fromArray + let empty:t = [||]->fromArray; + let mergeKeepSecond = (a:t,b:t) => MS.merge(a,b, (_,a,b) =>switch(a,b){ + | (_, Some(b)) => Some(b) + | (Some(a), _) => Some(a) + | _ => None + }) + let update = (t,str, fn) => MS.update(t, str, fn) + let get = (t:t,str) => MS.get(t, str) } type evaluationParams = { diff --git a/src/distPlus/expressionTree/Functions.re b/src/distPlus/expressionTree/Functions.re index dc5d4b4f..ccd3e96c 100644 --- a/src/distPlus/expressionTree/Functions.re +++ b/src/distPlus/expressionTree/Functions.re @@ -49,21 +49,62 @@ let to_: array(node) => result(node, string) = Error("Low value must be less than high value.") | _ => Error("Requires 2 variables"); -let fnn = (evaluationParams:ExpressionTypes.ExpressionTree.evaluationParams, name, args: array(node)) => { - switch (name, ExpressionTypes.ExpressionTree.Environment.get(evaluationParams.environment, name)) { - | (_, Some(`Function(t))) => t(`Function(t)); - | ("normal", _) => apply2(twoFloatsToOkSym(SymbolicDist.Normal.make), args) - | ("uniform", _) => apply2(twoFloatsToOkSym(SymbolicDist.Uniform.make), args) +let processCustomFn = + ( + evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams, + args: array(node), + argNames: array(string), + fnResult: node, + ) => + if (E.A.length(args) == E.A.length(argNames)) { + let newEnvironment = + Belt.Array.zip(argNames, args) + |> ExpressionTypes.ExpressionTree.Environment.fromArray; + let newEvaluationParams: ExpressionTypes.ExpressionTree.evaluationParams = { + samplingInputs: evaluationParams.samplingInputs, + environment: + ExpressionTypes.ExpressionTree.Environment.mergeKeepSecond( + evaluationParams.environment, + newEnvironment, + ), + evaluateNode: evaluationParams.evaluateNode, + }; + Js.log4("HI", newEnvironment, newEvaluationParams, args); + evaluationParams.evaluateNode(newEvaluationParams, fnResult); + } else { + Error("Failure"); + }; + +let fnn = + ( + evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams, + name, + args: array(node), + ) => { + Js.log3("Trying function", name, evaluationParams.environment); + switch ( + name, + ExpressionTypes.ExpressionTree.Environment.get( + evaluationParams.environment, + name, + ), + ) { + | (_, Some(`Function(argNames, tt))) => processCustomFn(evaluationParams, args, argNames, tt) + | ("normal", _) => + apply2(twoFloatsToOkSym(SymbolicDist.Normal.make), args) + | ("uniform", _) => + apply2(twoFloatsToOkSym(SymbolicDist.Uniform.make), args) | ("beta", _) => apply2(twoFloatsToOkSym(SymbolicDist.Beta.make), args) - | ("cauchy", _) => apply2(twoFloatsToOkSym(SymbolicDist.Cauchy.make), args) - | ("lognormal", _) => apply2(twoFloatsToOkSym(SymbolicDist.Lognormal.make), args) - | ("lognormalFromMeanAndStdDev", _) => apply2(twoFloatsToOkSym(SymbolicDist.Lognormal.fromMeanAndStdev), args) - | ("exponential", _) => + | ("cauchy", _) => + apply2(twoFloatsToOkSym(SymbolicDist.Cauchy.make), args) + | ("lognormal", _) => + apply2(twoFloatsToOkSym(SymbolicDist.Lognormal.make), args) + | ("lognormalFromMeanAndStdDev", _) => + apply2(twoFloatsToOkSym(SymbolicDist.Lognormal.fromMeanAndStdev), args) + | ("exponential", _) => switch (args) { - | [| - `SymbolicDist(`Float(a)), - |] => - Ok(`SymbolicDist(SymbolicDist.Exponential.make(a))); + | [|`SymbolicDist(`Float(a))|] => + Ok(`SymbolicDist(SymbolicDist.Exponential.make(a))) | _ => Error("Needs 3 valid arguments") } | ("triangular", _) => diff --git a/src/distPlus/expressionTree/MathJsParser.re b/src/distPlus/expressionTree/MathJsParser.re index 3eca937c..a7efdfcd 100644 --- a/src/distPlus/expressionTree/MathJsParser.re +++ b/src/distPlus/expressionTree/MathJsParser.re @@ -123,7 +123,7 @@ module MathAdtToDistDst = { switch (args) { | [|Object(o)|] => let g = s => - Js.Dict.get(o, s) |> E.O.toResult("") |> E.R.bind(_, nodeParser); + Js.Dict.get(o, s) |> E.O.toResult("Variable was empty") |> E.R.bind(_, nodeParser); switch (g("mean"), g("stdev"), g("mu"), g("sigma")) { | (Ok(mean), Ok(stdev), _, _) => Ok( @@ -326,8 +326,8 @@ module MathAdtToDistDst = { switch (r) { | FunctionAssignment({name, args, expression}) => switch (nodeParser(inputVars, expression)) { - | Ok(r) => Ok([|`Assignment((name, `Function(_ => Ok(r))))|]) - | _ => Error("") + | Ok(r) => Ok([|`Assignment((name, `Function(args, r)))|]) + | Error(r) => Error(r) } | Value(_) as r => nodeParser(inputVars, r) |> E.R.fmap(r => [|`Expression(r)|]) @@ -382,6 +382,7 @@ let fromString2 = (inputVars: inputVars, str) => { }); let value = E.R.bind(mathJsParse, MathAdtToDistDst.run(inputVars)); + Js.log3("Parsed", mathJsParse, value); value; };