Moving function creation into ExpressionTreeEvaluator as functions
This commit is contained in:
		
							parent
							
								
									102a147b97
								
							
						
					
					
						commit
						d3766f7a7f
					
				|  | @ -216,6 +216,11 @@ let sample = (t: t): float => { | |||
|   bar; | ||||
| }; | ||||
| 
 | ||||
| let isFloat = (t:t) => switch(t){ | ||||
| | Discrete({xyShape: {xs: [|_|], ys: [|1.0|]}}) => true | ||||
| | _ => false | ||||
| } | ||||
| 
 | ||||
| let sampleNRendered = (n, dist) => { | ||||
|   let integralCache = T.Integral.get(dist); | ||||
|   let distWithUpdatedIntegralCache = T.updateIntegralCache(Some(integralCache), dist); | ||||
|  |  | |||
|  | @ -259,10 +259,22 @@ module FloatFromDist = { | |||
|   }; | ||||
| }; | ||||
| 
 | ||||
| let callableFunction = (evaluationParams, name, args) => { | ||||
|   let b = | ||||
|     args | ||||
|     |> E.A.fmap(a => | ||||
|          Render.render(evaluationParams, a) | ||||
|          |> E.R.bind(_, Render.toFloat) | ||||
|        ) | ||||
|     |> E.A.R.firstErrorOrOpen; | ||||
|   b |> E.R.bind(_, Functions.fnn("normal")); | ||||
| }; | ||||
| 
 | ||||
| module Render = { | ||||
|   let rec operationToLeaf = | ||||
|           (evaluationParams: evaluationParams, t: node): result(t, string) => { | ||||
|     switch (t) { | ||||
|     | `Function(_) => Error("Cannot render a function") | ||||
|     | `SymbolicDist(d) => | ||||
|       Ok( | ||||
|         `RenderedDist( | ||||
|  | @ -275,6 +287,13 @@ module Render = { | |||
|   }; | ||||
| }; | ||||
| 
 | ||||
| let run = (node, fnNode) => { | ||||
|   switch (fnNode) { | ||||
|   | `Function(r) => Ok(r(node)) | ||||
|   | _ => Error("Not a function") | ||||
|   }; | ||||
| }; | ||||
| 
 | ||||
| /* This function recursively goes through the nodes of the parse tree, | ||||
|    replacing each Operation node and its subtree with a Data node. | ||||
|    Whenever possible, the replacement produces a new Symbolic Data node, | ||||
|  | @ -314,5 +333,8 @@ let toLeaf = | |||
|     FloatFromDist.operationToLeaf(evaluationParams, distToFloatOp, t) | ||||
|   | `Normalize(t) => Normalize.operationToLeaf(evaluationParams, t) | ||||
|   | `Render(t) => Render.operationToLeaf(evaluationParams, t) | ||||
|   | `Function(t) => Ok(`Function(t)) | ||||
|   | `CallableFunction(name, args) => | ||||
|     callableFunction(evaluationParams, name, args) | ||||
|   }; | ||||
| }; | ||||
|  |  | |||
|  | @ -20,6 +20,8 @@ module ExpressionTree = { | |||
|     | `Truncate(option(float), option(float), node) | ||||
|     | `Normalize(node) | ||||
|     | `FloatFromDist(distToFloatOperation, node) | ||||
|     | `Function(node => result(node, string)) | ||||
|     | `CallableFunction(string, array(node)) | ||||
|   ]; | ||||
| 
 | ||||
|   type samplingInputs = { | ||||
|  | @ -71,8 +73,14 @@ module ExpressionTree = { | |||
|       | `RenderedDist(r) => Some(r) | ||||
|       | _ => None | ||||
|       }; | ||||
|   }; | ||||
| 
 | ||||
|     let _toFloat = (t:DistTypes.shape) => switch(t){ | ||||
|     | Discrete({xyShape: {xs: [|x|], ys: [|1.0|]}}) => Some(`SymbolicDist(`Float(x))) | ||||
|     | _ => None | ||||
|     } | ||||
| 
 | ||||
|     let toFloat = (item:node):result(node, string) => item |> getShape |> E.O.bind(_,_toFloat) |> E.O.toResult("Not valid shape") | ||||
|   }; | ||||
| }; | ||||
| 
 | ||||
| type simplificationResult = [ | ||||
|  |  | |||
							
								
								
									
										27
									
								
								src/distPlus/expressionTree/Functions.re
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								src/distPlus/expressionTree/Functions.re
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,27 @@ | |||
| type node = ExpressionTypes.ExpressionTree.node; | ||||
| 
 | ||||
| let toOkSym = r => Ok(`SymbolicDist(r)); | ||||
| 
 | ||||
| let twoFloats = (fn, n1: node, n2: node): result(node, string) => | ||||
|   switch (n1, n2) { | ||||
|   | (`SymbolicDist(`Float(a)), `SymbolicDist(`Float(b))) => fn(a, b) | ||||
|   | _ => Error("Variables have wrong type") | ||||
|   }; | ||||
| 
 | ||||
| let twoFloatsToOkSym = fn => twoFloats((f1, f2) => fn(f1, f2) |> toOkSym); | ||||
| 
 | ||||
| let apply2 = (fn, args): result(node, string) => | ||||
|   switch (args) { | ||||
|   | [|a, b|] => fn(a, b) | ||||
|   | _ => Error("Needs 2 args") | ||||
|   }; | ||||
| 
 | ||||
| let fnn = (name, args: array(node)) => { | ||||
|   switch (name) { | ||||
|   | "normal" => apply2(twoFloatsToOkSym(SymbolicDist.Normal.make), args) | ||||
|   | "uniform" => apply2(twoFloatsToOkSym(SymbolicDist.Uniform.make), args) | ||||
|   | "beta" => apply2(twoFloatsToOkSym(SymbolicDist.Beta.make), args) | ||||
|   | "cauchy" => apply2(twoFloatsToOkSym(SymbolicDist.Cauchy.make), args) | ||||
|   | _ => Error("Function not found") | ||||
|   }; | ||||
| }; | ||||
|  | @ -55,7 +55,7 @@ module MathAdtToDistDst = { | |||
|   let handleSymbol = (inputVars: inputVars, sym) => { | ||||
|     switch (Belt.Map.String.get(inputVars, sym)) { | ||||
|     | Some(s) => Ok(s) | ||||
|     | None => Error("Couldn't find.") | ||||
|     | None => Error("Couldn't find symbol " ++ sym) | ||||
|     }; | ||||
|   }; | ||||
| 
 | ||||
|  | @ -93,13 +93,6 @@ module MathAdtToDistDst = { | |||
|         ); | ||||
|   }; | ||||
| 
 | ||||
|   let normal: | ||||
|     array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = | ||||
|     fun | ||||
|     | [|Value(mean), Value(stdev)|] => | ||||
|       Ok(`SymbolicDist(`Normal({mean, stdev}))) | ||||
|     | _ => Error("Wrong number of variables in normal distribution"); | ||||
| 
 | ||||
|   let lognormal: | ||||
|     array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = | ||||
|     fun | ||||
|  | @ -135,32 +128,12 @@ module MathAdtToDistDst = { | |||
|       Error("Low value must be less than high value.") | ||||
|     | _ => Error("Wrong number of variables in lognormal distribution"); | ||||
| 
 | ||||
|   let uniform: | ||||
|     array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = | ||||
|     fun | ||||
|     | [|Value(low), Value(high)|] => | ||||
|       Ok(`SymbolicDist(`Uniform({low, high}))) | ||||
|     | _ => Error("Wrong number of variables in lognormal distribution"); | ||||
| 
 | ||||
|   let beta: array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = | ||||
|     fun | ||||
|     | [|Value(alpha), Value(beta)|] => | ||||
|       Ok(`SymbolicDist(`Beta({alpha, beta}))) | ||||
|     | _ => Error("Wrong number of variables in lognormal distribution"); | ||||
| 
 | ||||
|   let exponential: | ||||
|     array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = | ||||
|     fun | ||||
|     | [|Value(rate)|] => Ok(`SymbolicDist(`Exponential({rate: rate}))) | ||||
|     | _ => Error("Wrong number of variables in Exponential distribution"); | ||||
| 
 | ||||
|   let cauchy: | ||||
|     array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = | ||||
|     fun | ||||
|     | [|Value(local), Value(scale)|] => | ||||
|       Ok(`SymbolicDist(`Cauchy({local, scale}))) | ||||
|     | _ => Error("Wrong number of variables in cauchy distribution"); | ||||
| 
 | ||||
|   let triangular: | ||||
|     array(arg) => result(ExpressionTypes.ExpressionTree.node, string) = | ||||
|     fun | ||||
|  | @ -214,43 +187,16 @@ module MathAdtToDistDst = { | |||
|     }; | ||||
|   }; | ||||
| 
 | ||||
|   // let arrayParser = | ||||
|   //     (args: array(arg)) | ||||
|   //     : result(ExpressionTypes.ExpressionTree.node, string) => { | ||||
|   //   let samples = | ||||
|   //     args | ||||
|   //     |> E.A.fmap( | ||||
|   //          fun | ||||
|   //          | Value(n) => Some(n) | ||||
|   //          | _ => None, | ||||
|   //        ) | ||||
|   //     |> E.A.O.concatSomes; | ||||
|   //   let outputs = Samples.T.fromSamples(samples); | ||||
|   //   let pdf = | ||||
|   //     outputs.shape |> E.O.bind(_, Shape.T.toContinuous); | ||||
|   //   let shape = | ||||
|   //     pdf | ||||
|   //     |> E.O.fmap(pdf => { | ||||
|   //          let _pdf = Continuous.T.normalize(pdf); | ||||
|   //          let cdf = Continuous.T.integral(~cache=None, _pdf); | ||||
|   //          SymbolicDist.ContinuousShape.make(_pdf, cdf); | ||||
|   //        }); | ||||
|   //   switch (shape) { | ||||
|   //   | Some(s) => Ok(`SymbolicDist(`ContinuousShape(s))) | ||||
|   //   | None => Error("Rendering did not work") | ||||
|   //   }; | ||||
|   // }; | ||||
| 
 | ||||
|   let operationParser = | ||||
|       ( | ||||
|         name: string, | ||||
|         args: array(result(ExpressionTypes.ExpressionTree.node, string)), | ||||
|         args: result(array(ExpressionTypes.ExpressionTree.node), string), | ||||
|       ) => { | ||||
|     let toOkAlgebraic = r => Ok(`AlgebraicCombination(r)); | ||||
|     let toOkPointwise = r => Ok(`PointwiseCombination(r)); | ||||
|     let toOkTruncate = r => Ok(`Truncate(r)); | ||||
|     let toOkFloatFromDist = r => Ok(`FloatFromDist(r)); | ||||
|     E.A.R.firstErrorOrOpen(args) | ||||
|     args | ||||
|     |> E.R.bind(_, args => { | ||||
|          switch (name, args) { | ||||
|          | ("add", [|l, r|]) => toOkAlgebraic((`Add, l, r)) | ||||
|  | @ -303,17 +249,17 @@ module MathAdtToDistDst = { | |||
|   }; | ||||
| 
 | ||||
|   let functionParser = (nodeParser, name, args) => { | ||||
|     let parseArgs = () => args |> E.A.fmap(nodeParser); | ||||
|     Js.log2("Parseargs", parseArgs); | ||||
|     let parseArgs = () => | ||||
|       args |> E.A.fmap(nodeParser) |> E.A.R.firstErrorOrOpen; | ||||
|     switch (name) { | ||||
|     | "normal" => normal(args) | ||||
|     | "lognormal" => lognormal(args) | ||||
|     | "uniform" => uniform(args) | ||||
|     | "beta" => beta(args) | ||||
|     | "to" => to_(args) | ||||
|     | "exponential" => exponential(args) | ||||
|     | "cauchy" => cauchy(args) | ||||
|     | "triangular" => triangular(args) | ||||
|     | "normal" | "uniform" | "beta" | "caucy" => | ||||
|       parseArgs() | ||||
|       |> E.R.fmap( | ||||
|            ( | ||||
|              args: array(ExpressionTypes.ExpressionTree.node), | ||||
|            ) => | ||||
|            `CallableFunction((name, args)) | ||||
|          ) | ||||
|     | "mm" => | ||||
|       let weights = | ||||
|         args | ||||
|  | @ -358,14 +304,18 @@ module MathAdtToDistDst = { | |||
|     }; | ||||
|   }; | ||||
| 
 | ||||
|   let rec nodeParser = inputVars => | ||||
|     fun | ||||
|     | Value(f) => Ok(`SymbolicDist(`Float(f))) | ||||
|     | Symbol(s) => handleSymbol(inputVars, s) | ||||
|     | Fn({name, args}) => functionParser(nodeParser(inputVars), name, args) | ||||
|     | _ => { | ||||
|         Error("This type not currently supported"); | ||||
|       }; | ||||
|   let rec nodeParser: | ||||
|     (inputVars, MathJsonToMathJsAdt.arg) => | ||||
|     result(ExpressionTypes.ExpressionTree.node, string) = | ||||
|     inputVars => | ||||
|       fun | ||||
|       | Value(f) => Ok(`SymbolicDist(`Float(f))) | ||||
|       | Symbol(s) => handleSymbol(inputVars, s) | ||||
|       | Fn({name, args}) => | ||||
|         functionParser(nodeParser(inputVars), name, args) | ||||
|       | _ => { | ||||
|           Error("This type not currently supported"); | ||||
|         }; | ||||
| 
 | ||||
|   let topLevel = inputVars => | ||||
|     fun | ||||
|  | @ -405,7 +355,6 @@ let fromString2 = (inputVars: inputVars, str) => { | |||
|       } | ||||
|     }); | ||||
| 
 | ||||
|   Js.log(mathJsParse); | ||||
|   let value = E.R.bind(mathJsParse, MathAdtToDistDst.run(inputVars)); | ||||
|   value; | ||||
| }; | ||||
|  |  | |||
|  | @ -2,6 +2,8 @@ open SymbolicTypes; | |||
| 
 | ||||
| module Exponential = { | ||||
|   type t = exponential; | ||||
|   let make = (rate): symbolicDist => | ||||
|     `Exponential({rate}); | ||||
|   let pdf = (x, t: t) => Jstat.exponential##pdf(x, t.rate); | ||||
|   let cdf = (x, t: t) => Jstat.exponential##cdf(x, t.rate); | ||||
|   let inv = (p, t: t) => Jstat.exponential##inv(p, t.rate); | ||||
|  | @ -12,6 +14,8 @@ module Exponential = { | |||
| 
 | ||||
| module Cauchy = { | ||||
|   type t = cauchy; | ||||
|   let make = (local, scale): symbolicDist => | ||||
|     `Cauchy({local,scale}); | ||||
|   let pdf = (x, t: t) => Jstat.cauchy##pdf(x, t.local, t.scale); | ||||
|   let cdf = (x, t: t) => Jstat.cauchy##cdf(x, t.local, t.scale); | ||||
|   let inv = (p, t: t) => Jstat.cauchy##inv(p, t.local, t.scale); | ||||
|  | @ -22,6 +26,8 @@ module Cauchy = { | |||
| 
 | ||||
| module Triangular = { | ||||
|   type t = triangular; | ||||
|   let make = (low, medium, high): symbolicDist => | ||||
|     `Triangular({low, medium, high}); | ||||
|   let pdf = (x, t: t) => Jstat.triangular##pdf(x, t.low, t.high, t.medium); | ||||
|   let cdf = (x, t: t) => Jstat.triangular##cdf(x, t.low, t.high, t.medium); | ||||
|   let inv = (p, t: t) => Jstat.triangular##inv(p, t.low, t.high, t.medium); | ||||
|  | @ -32,7 +38,7 @@ module Triangular = { | |||
| 
 | ||||
| module Normal = { | ||||
|   type t = normal; | ||||
|   let make = (mean, stdev):t => {mean, stdev}; | ||||
|   let make = (mean, stdev): symbolicDist => `Normal({mean, stdev}); | ||||
|   let pdf = (x, t: t) => Jstat.normal##pdf(x, t.mean, t.stdev); | ||||
|   let cdf = (x, t: t) => Jstat.normal##cdf(x, t.mean, t.stdev); | ||||
| 
 | ||||
|  | @ -76,6 +82,7 @@ module Normal = { | |||
| 
 | ||||
| module Beta = { | ||||
|   type t = beta; | ||||
|   let make = (alpha, beta) => `Beta({alpha, beta}) | ||||
|   let pdf = (x, t: t) => Jstat.beta##pdf(x, t.alpha, t.beta); | ||||
|   let cdf = (x, t: t) => Jstat.beta##cdf(x, t.alpha, t.beta); | ||||
|   let inv = (p, t: t) => Jstat.beta##inv(p, t.alpha, t.beta); | ||||
|  | @ -86,6 +93,7 @@ module Beta = { | |||
| 
 | ||||
| module Lognormal = { | ||||
|   type t = lognormal; | ||||
|   let make = (mu, sigma) => `Lognormal({mu, sigma}) | ||||
|   let pdf = (x, t: t) => Jstat.lognormal##pdf(x, t.mu, t.sigma); | ||||
|   let cdf = (x, t: t) => Jstat.lognormal##cdf(x, t.mu, t.sigma); | ||||
|   let inv = (p, t: t) => Jstat.lognormal##inv(p, t.mu, t.sigma); | ||||
|  | @ -132,6 +140,7 @@ module Lognormal = { | |||
| 
 | ||||
| module Uniform = { | ||||
|   type t = uniform; | ||||
|   let make = (low, high) => `Uniform({low, high}) | ||||
|   let pdf = (x, t: t) => Jstat.uniform##pdf(x, t.low, t.high); | ||||
|   let cdf = (x, t: t) => Jstat.uniform##cdf(x, t.low, t.high); | ||||
|   let inv = (p, t: t) => Jstat.uniform##inv(p, t.low, t.high); | ||||
|  | @ -147,6 +156,7 @@ module Uniform = { | |||
| 
 | ||||
| module Float = { | ||||
|   type t = float; | ||||
|   let make = t => `Float(t) | ||||
|   let pdf = (x, t: t) => x == t ? 1.0 : 0.0; | ||||
|   let cdf = (x, t: t) => x >= t ? 1.0 : 0.0; | ||||
|   let inv = (p, t: t) => p < t ? 0.0 : 1.0; | ||||
|  | @ -318,13 +328,14 @@ module T = { | |||
|     switch (d) { | ||||
|     | `Float(v) => | ||||
|       Discrete( | ||||
|         Discrete.make(~integralSumCache=Some(1.0), {xs: [|v|], ys: [|1.0|]}), | ||||
|         Discrete.make( | ||||
|           ~integralSumCache=Some(1.0), | ||||
|           {xs: [|v|], ys: [|1.0|]}, | ||||
|         ), | ||||
|       ) | ||||
|     | _ => | ||||
|       let xs = interpolateXs(~xSelection=`ByWeight, d, sampleCount); | ||||
|       let ys = xs |> E.A.fmap(x => pdf(x, d)); | ||||
|       Continuous( | ||||
|         Continuous.make(~integralSumCache=Some(1.0), {xs, ys}), | ||||
|       ); | ||||
|       Continuous(Continuous.make(~integralSumCache=Some(1.0), {xs, ys})); | ||||
|     }; | ||||
| }; | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user