Further cleanup of MultiModal functionality
This commit is contained in:
		
							parent
							
								
									b71d037180
								
							
						
					
					
						commit
						a6051d8371
					
				|  | @ -24,7 +24,6 @@ let rec toString: node => string = | |||
|     ++ toString(internal) | ||||
|     ++ ")]" | ||||
|   | `Array(_) => "Array" | ||||
|   | `MultiModal(_) => "Multimodal" | ||||
|   | `Hash(_) => "Hash" | ||||
| 
 | ||||
| let envs = (samplingInputs, environment) => { | ||||
|  |  | |||
|  | @ -307,6 +307,5 @@ let rec toLeaf = | |||
|     Js.log3("In function call", name, args); | ||||
|     callableFunction(evaluationParams, name, args) | ||||
|     |> E.R.bind(_, toLeaf(evaluationParams)); | ||||
|   | `MultiModal(r) => Error("Multimodal?") | ||||
|   }; | ||||
| }; | ||||
|  |  | |||
|  | @ -30,16 +30,18 @@ module ExpressionTree = { | |||
|     | `Render(node) | ||||
|     | `Truncate(option(float), option(float), node) | ||||
|     | `FunctionCall(string, array(node)) | ||||
|     | `MultiModal(array((node, float))) | ||||
|   ]; | ||||
| 
 | ||||
|   module Hash = { | ||||
|     type t('a) = array((string, 'a)); | ||||
|     let getByName = (t:t('a), name) => | ||||
|     let getByName = (t: t('a), name) => | ||||
|       E.A.getBy(t, ((n, _)) => n == name) |> E.O.fmap(((_, r)) => r); | ||||
| 
 | ||||
|     let getByNames = (hash: t('a), names:array(string)) => | ||||
|       names |> E.A.fmap(name => (name, getByName(hash, name))) | ||||
|     let getByNameResult = (t: t('a), name) => | ||||
|       getByName(t, name) |> E.O.toResult(name ++ " expected and not found"); | ||||
| 
 | ||||
|     let getByNames = (hash: t('a), names: array(string)) => | ||||
|       names |> E.A.fmap(name => (name, getByName(hash, name))); | ||||
|   }; | ||||
|   // Have nil as option | ||||
|   let getFloat = (node: node) => | ||||
|  |  | |||
|  | @ -107,6 +107,80 @@ let verticalScaling = (scaleOp, rs, scaleBy) => { | |||
|   ); | ||||
| }; | ||||
| 
 | ||||
| module Multimodal = { | ||||
|   let getByNameResult = ExpressionTypes.ExpressionTree.Hash.getByNameResult; | ||||
| 
 | ||||
|   let _paramsToDistsAndWeights = (r: array(typedValue)) => | ||||
|     switch (r) { | ||||
|     | [|`Named(r)|] => | ||||
|       let dists = | ||||
|         getByNameResult(r, "dists") | ||||
|         ->E.R.bind(TypeSystem.TypedValue.toArray) | ||||
|         ->E.R.bind(r => | ||||
|             r | ||||
|             |> E.A.fmap(TypeSystem.TypedValue.toDist) | ||||
|             |> E.A.R.firstErrorOrOpen | ||||
|           ); | ||||
|       let weights = | ||||
|         getByNameResult(r, "weights") | ||||
|         ->E.R.bind(TypeSystem.TypedValue.toArray) | ||||
|         ->E.R.bind(r => | ||||
|             r | ||||
|             |> E.A.fmap(TypeSystem.TypedValue.toFloat) | ||||
|             |> E.A.R.firstErrorOrOpen | ||||
|           ); | ||||
| 
 | ||||
|       E.R.merge(dists, weights) | ||||
|       |> E.R.fmap(((a, b)) => | ||||
|            E.A.zipMaxLength(a, b) | ||||
|            |> E.A.fmap(((a, b)) => | ||||
|                 (a |> E.O.toExn(""), b |> E.O.default(1.0)) | ||||
|               ) | ||||
|          ); | ||||
|     | _ => Error("Needs items") | ||||
|     }; | ||||
|   let _runner: array(typedValue) => result(node, string) = | ||||
|     r => { | ||||
|       let paramsToDistsAndWeights = | ||||
|         _paramsToDistsAndWeights(r) | ||||
|         |> E.R.fmap( | ||||
|              E.A.fmap(((dist, weight)) => | ||||
|                `FunctionCall(( | ||||
|                  "scaleMultiply", | ||||
|                  [|dist, `SymbolicDist(`Float(weight))|], | ||||
|                )) | ||||
|              ), | ||||
|            ); | ||||
|       let pointwiseSum: result(node, string) = | ||||
|         paramsToDistsAndWeights->E.R.bind( | ||||
|           E.R.errorIfCondition(E.A.isEmpty, "Needs one input"), | ||||
|         ) | ||||
|         |> E.R.fmap(r => | ||||
|              r | ||||
|              |> Js.Array.sliceFrom(1) | ||||
|              |> E.A.fold_left( | ||||
|                   (acc, x) => {`PointwiseCombination((`Add, acc, x))}, | ||||
|                   E.A.unsafe_get(r, 0), | ||||
|                 ) | ||||
|            ); | ||||
|       pointwiseSum; | ||||
|     }; | ||||
| 
 | ||||
|   let _function = | ||||
|     Function.T.make( | ||||
|       ~name="multimodal", | ||||
|       ~outputType=`SamplingDistribution, | ||||
|       ~inputTypes=[| | ||||
|         `Named([| | ||||
|           ("dists", `Array(`SamplingDistribution)), | ||||
|           ("weights", `Array(`Float)), | ||||
|         |]), | ||||
|       |], | ||||
|       ~run=_runner, | ||||
|       (), | ||||
|     ); | ||||
| }; | ||||
| 
 | ||||
| let functions = [| | ||||
|   makeSymbolicFromTwoFloats("normal", SymbolicDist.Normal.make), | ||||
|   makeSymbolicFromTwoFloats("uniform", SymbolicDist.Uniform.make), | ||||
|  | @ -175,98 +249,5 @@ let functions = [| | |||
|   makeRenderedDistFloat("scaleLog", (dist, float) => | ||||
|     verticalScaling(`Log, dist, float) | ||||
|   ), | ||||
|   Function.T.make( | ||||
|     ~name="multimodal", | ||||
|     ~outputType=`SamplingDistribution, | ||||
|     ~inputTypes=[| | ||||
|       `Named([| | ||||
|         ("dists", `Array(`SamplingDistribution)), | ||||
|         ("weights", `Array(`Float)), | ||||
|       |]), | ||||
|     |], | ||||
|     ~run= | ||||
|       fun | ||||
|       | [|`Named(r)|] => { | ||||
|           let foo = | ||||
|               (r: TypeSystem.typedValue) | ||||
|               : result(ExpressionTypes.ExpressionTree.node, string) => | ||||
|             switch (r) { | ||||
|             | `SamplingDist(`SymbolicDist(c)) => Ok(`SymbolicDist(c)) | ||||
|             | `SamplingDist(`RenderedDist(c)) => Ok(`RenderedDist(c)) | ||||
|             | `Float(x) => | ||||
|               Ok(`RenderedDist(SymbolicDist.T.toShape(1000, `Float(x)))) | ||||
|             | _ => Error("") | ||||
|             }; | ||||
|           let weight = (r: TypeSystem.typedValue) => | ||||
|             switch (r) { | ||||
|             | `Float(x) => Ok(x) | ||||
|             | _ => Error("Wrong Type") | ||||
|             }; | ||||
|           let dists = | ||||
|             switch (ExpressionTypes.ExpressionTree.Hash.getByName(r, "dists")) { | ||||
|             | Some(`Array(r)) => r |> E.A.fmap(foo) |> E.A.R.firstErrorOrOpen | ||||
|             | _ => Error("") | ||||
|             }; | ||||
|           let weights = | ||||
|             ( | ||||
|               switch ( | ||||
|                 ExpressionTypes.ExpressionTree.Hash.getByName(r, "weights") | ||||
|               ) { | ||||
|               | Some(`Array(r)) => | ||||
|                 r |> E.A.fmap(weight) |> E.A.R.firstErrorOrOpen | ||||
|               | _ => Error("") | ||||
|               } | ||||
|             ) | ||||
|             |> ( | ||||
|               fun | ||||
|               | Ok(r) => r | ||||
|               | _ => [||] | ||||
|             ); | ||||
|           let withWeights = | ||||
|             dists | ||||
|             |> E.R.fmap(d => { | ||||
|                  let iis = | ||||
|                    d |> E.A.length |> Belt.Array.makeUninitializedUnsafe; | ||||
|                  for (i in 0 to (d |> E.A.length) - 1) { | ||||
|                    Belt.Array.set( | ||||
|                      iis, | ||||
|                      i, | ||||
|                      ( | ||||
|                        E.A.unsafe_get(d, i), | ||||
|                        E.A.get(weights, i) |> E.O.default(1.0), | ||||
|                      ), | ||||
|                    ) | ||||
|                    |> ignore; | ||||
|                  }; | ||||
|                  iis; | ||||
|                }); | ||||
|           let components: result(array(node), string) = | ||||
|             withWeights | ||||
|             |> E.R.fmap( | ||||
|                  E.A.fmap(((dist, weight)) => | ||||
|                    `FunctionCall(( | ||||
|                      "scaleMultiply", | ||||
|                      [|dist, `SymbolicDist(`Float(weight))|], | ||||
|                    )) | ||||
|                  ), | ||||
|                ); | ||||
|           let pointwiseSum = | ||||
|             components | ||||
|             |> E.R.bind(_, r => { | ||||
|                  E.A.length(r) > 0 | ||||
|                    ? Ok(r) : Error("Invalid argument length") | ||||
|                }) | ||||
|             |> E.R.fmap(r => | ||||
|                  r | ||||
|                  |> Js.Array.sliceFrom(1) | ||||
|                  |> E.A.fold_left( | ||||
|                       (acc, x) => {`PointwiseCombination((`Add, acc, x))}, | ||||
|                       E.A.unsafe_get(r, 0), | ||||
|                     ) | ||||
|                ); | ||||
|           pointwiseSum; | ||||
|         } | ||||
|       | _ => Error(""), | ||||
|     (), | ||||
|   ), | ||||
|   Multimodal._function | ||||
| |]; | ||||
|  |  | |||
|  | @ -54,7 +54,6 @@ module TypedValue = { | |||
| 
 | ||||
|   // todo: Arrays and hashes | ||||
|   let rec fromNodeWithTypeCoercion = (evaluationParams, _type: _type, node) => { | ||||
|     Js.log3("With Coersion!", _type, node); | ||||
|     switch (_type, node) { | ||||
|     | (`Float, _) => | ||||
|       switch (getFloat(node)) { | ||||
|  | @ -76,7 +75,6 @@ module TypedValue = { | |||
|       |> E.A.R.firstErrorOrOpen | ||||
|       |> E.R.fmap(r => `Array(r)) | ||||
|     | (`Named(named), `Hash(r)) => | ||||
|       Js.log3("Named", named, r); | ||||
|       let foo = | ||||
|         named | ||||
|         |> E.A.fmap(((name, intendedType)) => | ||||
|  | @ -86,7 +84,6 @@ module TypedValue = { | |||
|                ExpressionTypes.ExpressionTree.Hash.getByName(r, name), | ||||
|              ) | ||||
|            ); | ||||
|       Js.log("Named: part 2"); | ||||
|       let bar = | ||||
|         foo | ||||
|         |> E.A.fmap(((name, intendedType, optionNode)) => | ||||
|  | @ -99,11 +96,33 @@ module TypedValue = { | |||
|            ) | ||||
|         |> E.A.R.firstErrorOrOpen | ||||
|         |> E.R.fmap(r => `Named(r)); | ||||
|       Js.log3("Named!", foo, bar); | ||||
|       bar; | ||||
|     | _ => Error("fromNodeWithTypeCoercion error, sorry.") | ||||
|     }; | ||||
|   }; | ||||
| 
 | ||||
|   let toFloat = | ||||
|     fun | ||||
|     | `Float(x) => Ok(x) | ||||
|     | _ => Error("Not a float"); | ||||
| 
 | ||||
|   let toArray = | ||||
|     fun | ||||
|     | `Array(x) => Ok(x) | ||||
|     | _ => Error("Not an array"); | ||||
| 
 | ||||
|   let toNamed = | ||||
|     fun | ||||
|     | `Named(x) => Ok(x) | ||||
|     | _ => Error("Not a named item"); | ||||
| 
 | ||||
|   let toDist = | ||||
|     fun | ||||
|     | `SamplingDist(`SymbolicDist(c)) => Ok(`SymbolicDist(c)) | ||||
|     | `SamplingDist(`RenderedDist(c)) => Ok(`RenderedDist(c)) | ||||
|     | `Float(x) => | ||||
|       Ok(`RenderedDist(SymbolicDist.T.toShape(1000, `Float(x)))) | ||||
|     | _ => Error(""); | ||||
| }; | ||||
| 
 | ||||
| module Function = { | ||||
|  |  | |||
|  | @ -144,7 +144,6 @@ let renderIfNeeded = | |||
|   node | ||||
|   |> ( | ||||
|     fun | ||||
|     | `MultiModal(_) as n | ||||
|     | `Normalize(_) as n | ||||
|     | `SymbolicDist(_) as n => { | ||||
|         `Render(n) | ||||
|  |  | |||
|  | @ -26,6 +26,9 @@ module FloatFloatMap = { | |||
|   let fmap = (fn, t: t) => Belt.MutableMap.map(t, fn); | ||||
| }; | ||||
| 
 | ||||
| module Int = { | ||||
|   let max = (i1: int, i2: int) => i1 > i2 ? i1 : i2; | ||||
| }; | ||||
| /* Utils */ | ||||
| module U = { | ||||
|   let isEqual = (a, b) => a == b; | ||||
|  | @ -146,6 +149,11 @@ module R = { | |||
|   let fmap = Rationale.Result.fmap; | ||||
|   let bind = Rationale.Result.bind; | ||||
|   let toExn = Belt.Result.getExn; | ||||
|   let default = (default, res: Belt.Result.t('a, 'b)) => | ||||
|     switch (res) { | ||||
|     | Ok(r) => r | ||||
|     | Error(_) => default | ||||
|     }; | ||||
|   let merge = (a, b) => | ||||
|     switch (a, b) { | ||||
|     | (Error(e), _) => Error(e) | ||||
|  | @ -157,6 +165,9 @@ module R = { | |||
|     | Ok(r) => Some(r) | ||||
|     | Error(_) => None | ||||
|     }; | ||||
| 
 | ||||
|   let errorIfCondition = (errorCondition, errorMessage, r) => | ||||
|     errorCondition(r) ? Error(errorMessage) : Ok(r); | ||||
| }; | ||||
| 
 | ||||
| let safe_fn_of_string = (fn, s: string): option('a) => | ||||
|  | @ -263,6 +274,7 @@ module A = { | |||
|   let init = Array.init; | ||||
|   let reduce = Belt.Array.reduce; | ||||
|   let reducei = Belt.Array.reduceWithIndex; | ||||
|   let isEmpty = r => length(r) < 1; | ||||
|   let min = a => | ||||
|     get(a, 0) | ||||
|     |> O.fmap(first => Belt.Array.reduce(a, first, (i, j) => i < j ? i : j)); | ||||
|  | @ -285,6 +297,16 @@ module A = { | |||
|       |> Rationale.Result.return | ||||
|     }; | ||||
| 
 | ||||
|   // This zips while taking the longest elements of each array. | ||||
|   let zipMaxLength = (array1, array2) => { | ||||
|     let maxLength = Int.max(length(array1), length(array2)); | ||||
|     let result = maxLength |> Belt.Array.makeUninitializedUnsafe; | ||||
|     for (i in 0 to maxLength - 1) { | ||||
|       Belt.Array.set(result, i, (get(array1, i), get(array2, i))) |> ignore; | ||||
|     }; | ||||
|     result; | ||||
|   }; | ||||
| 
 | ||||
|   let asList = (f: list('a) => list('a), r: array('a)) => | ||||
|     r |> to_list |> f |> of_list; | ||||
|   /* TODO: Is there a better way of doing this? */ | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user