Minor fix for multimodals
This commit is contained in:
		
							parent
							
								
									4c33561b7c
								
							
						
					
					
						commit
						7566c59fef
					
				| 
						 | 
					@ -304,7 +304,7 @@ let rec toLeaf =
 | 
				
			||||||
    let components =
 | 
					    let components =
 | 
				
			||||||
      r
 | 
					      r
 | 
				
			||||||
      |> E.A.fmap(((dist, weight)) =>
 | 
					      |> E.A.fmap(((dist, weight)) =>
 | 
				
			||||||
      `FunctionCall("scaleExp", [|dist, `SymbolicDist(`Float(weight))|]));
 | 
					      `FunctionCall("scaleMultiply", [|dist, `SymbolicDist(`Float(weight))|]));
 | 
				
			||||||
    let pointwiseSum =
 | 
					    let pointwiseSum =
 | 
				
			||||||
      components
 | 
					      components
 | 
				
			||||||
      |> Js.Array.sliceFrom(1)
 | 
					      |> Js.Array.sliceFrom(1)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -91,14 +91,16 @@ let verticalScaling = (scaleOp, rs, scaleBy) => {
 | 
				
			||||||
    Operation.Scale.toFn(scaleOp, main, secondary);
 | 
					    Operation.Scale.toFn(scaleOp, main, secondary);
 | 
				
			||||||
  let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(scaleOp);
 | 
					  let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(scaleOp);
 | 
				
			||||||
  let integralCacheFn = Operation.Scale.toIntegralCacheFn(scaleOp);
 | 
					  let integralCacheFn = Operation.Scale.toIntegralCacheFn(scaleOp);
 | 
				
			||||||
  Ok(`RenderedDist(
 | 
					  Ok(
 | 
				
			||||||
    Shape.T.mapY(
 | 
					    `RenderedDist(
 | 
				
			||||||
      ~integralSumCacheFn=integralSumCacheFn(scaleBy),
 | 
					      Shape.T.mapY(
 | 
				
			||||||
      ~integralCacheFn=integralCacheFn(scaleBy),
 | 
					        ~integralSumCacheFn=integralSumCacheFn(scaleBy),
 | 
				
			||||||
      ~fn=fn(scaleBy),
 | 
					        ~integralCacheFn=integralCacheFn(scaleBy),
 | 
				
			||||||
      rs,
 | 
					        ~fn=fn(scaleBy),
 | 
				
			||||||
 | 
					        rs,
 | 
				
			||||||
 | 
					      ),
 | 
				
			||||||
    ),
 | 
					    ),
 | 
				
			||||||
  ));
 | 
					  );
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
let functions = [|
 | 
					let functions = [|
 | 
				
			||||||
| 
						 | 
					@ -152,7 +154,8 @@ let functions = [|
 | 
				
			||||||
    ~run=
 | 
					    ~run=
 | 
				
			||||||
      fun
 | 
					      fun
 | 
				
			||||||
      | [|`SamplingDist(`SymbolicDist(c))|] => Ok(`SymbolicDist(c))
 | 
					      | [|`SamplingDist(`SymbolicDist(c))|] => Ok(`SymbolicDist(c))
 | 
				
			||||||
      | [|`SamplingDist(`RenderedDist(c))|] => Ok(`RenderedDist(Shape.T.normalize(c)))
 | 
					      | [|`SamplingDist(`RenderedDist(c))|] =>
 | 
				
			||||||
 | 
					        Ok(`RenderedDist(Shape.T.normalize(c)))
 | 
				
			||||||
      | e => wrongInputsError(e),
 | 
					      | e => wrongInputsError(e),
 | 
				
			||||||
  ),
 | 
					  ),
 | 
				
			||||||
  makeRenderedDistFloat("scaleExp", (dist, float) =>
 | 
					  makeRenderedDistFloat("scaleExp", (dist, float) =>
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -22,7 +22,6 @@ let fnn =
 | 
				
			||||||
    ) {
 | 
					    ) {
 | 
				
			||||||
    | (_, Some(`Function(argNames, tt))) =>
 | 
					    | (_, Some(`Function(argNames, tt))) =>
 | 
				
			||||||
      PTypes.Function.run(evaluationParams, args, (argNames, tt))
 | 
					      PTypes.Function.run(evaluationParams, args, (argNames, tt))
 | 
				
			||||||
    | ("mm", _)
 | 
					 | 
				
			||||||
    | ("multimodal", _) =>
 | 
					    | ("multimodal", _) =>
 | 
				
			||||||
      switch (args |> E.A.to_list) {
 | 
					      switch (args |> E.A.to_list) {
 | 
				
			||||||
      | [`Array(weights), ...dists] =>
 | 
					      | [`Array(weights), ...dists] =>
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -204,7 +204,8 @@ module MathAdtToDistDst = {
 | 
				
			||||||
    let parseArgs = () => parseArray(args);
 | 
					    let parseArgs = () => parseArray(args);
 | 
				
			||||||
    switch (name) {
 | 
					    switch (name) {
 | 
				
			||||||
    | "lognormal" => lognormal(args, parseArgs, nodeParser)
 | 
					    | "lognormal" => lognormal(args, parseArgs, nodeParser)
 | 
				
			||||||
    | "mm" =>{
 | 
					    | "multimodal"
 | 
				
			||||||
 | 
					    | "mm" =>
 | 
				
			||||||
      let weights =
 | 
					      let weights =
 | 
				
			||||||
        args
 | 
					        args
 | 
				
			||||||
        |> E.A.last
 | 
					        |> E.A.last
 | 
				
			||||||
| 
						 | 
					@ -212,20 +213,22 @@ module MathAdtToDistDst = {
 | 
				
			||||||
             _,
 | 
					             _,
 | 
				
			||||||
             fun
 | 
					             fun
 | 
				
			||||||
             | Array(values) => Some(parseArray(values))
 | 
					             | Array(values) => Some(parseArray(values))
 | 
				
			||||||
             | _ => None
 | 
					             | _ => None,
 | 
				
			||||||
           );
 | 
					           );
 | 
				
			||||||
      let possibleDists =
 | 
					      let possibleDists =
 | 
				
			||||||
        E.O.isSome(weights)
 | 
					        E.O.isSome(weights)
 | 
				
			||||||
          ? Belt.Array.slice(args, ~offset=0, ~len=E.A.length(args) - 1)
 | 
					          ? Belt.Array.slice(args, ~offset=0, ~len=E.A.length(args) - 1)
 | 
				
			||||||
          : args;
 | 
					          : args;
 | 
				
			||||||
      let dists = parseArray(possibleDists);
 | 
					      let dists = parseArray(possibleDists);
 | 
				
			||||||
      switch(weights, dists){
 | 
					      switch (weights, dists) {
 | 
				
			||||||
        | (Some(Error(r)), _) => Error(r)
 | 
					      | (Some(Error(r)), _) => Error(r)
 | 
				
			||||||
        | (_, Error(r)) => Error(r)
 | 
					      | (_, Error(r)) => Error(r)
 | 
				
			||||||
        | (None, Ok(dists)) => Ok(`FunctionCall("multimodal", dists))
 | 
					      | (None, Ok(dists)) => Ok(`FunctionCall(("multimodal", dists)))
 | 
				
			||||||
        | (Some(Ok(r)), Ok(dists)) => Ok(`FunctionCall("multimodal", E.A.append([|`Array(r)|], dists)))
 | 
					      | (Some(Ok(r)), Ok(dists)) =>
 | 
				
			||||||
      }
 | 
					        Ok(
 | 
				
			||||||
    }
 | 
					          `FunctionCall(("multimodal", E.A.append([|`Array(r)|], dists))),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					      };
 | 
				
			||||||
    | "add"
 | 
					    | "add"
 | 
				
			||||||
    | "subtract"
 | 
					    | "subtract"
 | 
				
			||||||
    | "multiply"
 | 
					    | "multiply"
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user