diff --git a/src/distPlus/expressionTree/ExpressionTree.re b/src/distPlus/expressionTree/ExpressionTree.re index 2611a432..ca57c14f 100644 --- a/src/distPlus/expressionTree/ExpressionTree.re +++ b/src/distPlus/expressionTree/ExpressionTree.re @@ -1,31 +1,6 @@ open ExpressionTypes.ExpressionTree; -let rec toString: node => string = - fun - | `SymbolicDist(d) => SymbolicDist.T.toString(d) - | `RenderedDist(_) => "[renderedShape]" - | `AlgebraicCombination(op, t1, t2) => - Operation.Algebraic.format(op, toString(t1), toString(t2)) - | `PointwiseCombination(op, t1, t2) => - Operation.Pointwise.format(op, toString(t1), toString(t2)) - | `Normalize(t) => "normalize(k" ++ toString(t) ++ ")" - | `Truncate(lc, rc, t) => - Operation.T.truncateToString(lc, rc, toString(t)) - | `Render(t) => toString(t) - | `Symbol(t) => "Symbol: " ++ t - | `FunctionCall(name, args) => - "[Function call: (" - ++ name - ++ (args |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) - ++ ")]" - | `Function(args, internal) => - "[Function: (" - ++ (args |> Js.String.concatMany(_, ",")) - ++ toString(internal) - ++ ")]" - | `Array(_) => "Array" - | `Hash(_) => "Hash" - +let toString = ExpressionTreeBasic.toString; let envs = (samplingInputs, environment) => { {samplingInputs, environment, evaluateNode: ExpressionTreeEvaluator.toLeaf}; }; @@ -42,5 +17,5 @@ let toShape = (samplingInputs, environment, node: node) => { let runFunction = (samplingInputs, environment, inputs, fn: PTypes.Function.t) => { let params = envs(samplingInputs, environment); - PTypes.Function.run(params, inputs, fn) -} + PTypes.Function.run(params, inputs, fn); +}; diff --git a/src/distPlus/expressionTree/ExpressionTreeBasic.re b/src/distPlus/expressionTree/ExpressionTreeBasic.re new file mode 100644 index 00000000..27d5aab4 --- /dev/null +++ b/src/distPlus/expressionTree/ExpressionTreeBasic.re @@ -0,0 +1,35 @@ +open ExpressionTypes.ExpressionTree; + +let rec toString: node => string = + fun + | `SymbolicDist(d) => SymbolicDist.T.toString(d) + | `RenderedDist(_) => "[renderedShape]" + | `AlgebraicCombination(op, t1, t2) => + Operation.Algebraic.format(op, toString(t1), toString(t2)) + | `PointwiseCombination(op, t1, t2) => + Operation.Pointwise.format(op, toString(t1), toString(t2)) + | `Normalize(t) => "normalize(k" ++ toString(t) ++ ")" + | `Truncate(lc, rc, t) => + Operation.T.truncateToString(lc, rc, toString(t)) + | `Render(t) => toString(t) + | `Symbol(t) => "Symbol: " ++ t + | `FunctionCall(name, args) => + "[Function call: (" + ++ name + ++ (args |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) + ++ ")]" + | `Function(args, internal) => + "[Function: (" + ++ (args |> Js.String.concatMany(_, ",")) + ++ toString(internal) + ++ ")]" + | `Array(a) => + "[" ++ (a |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ "]" + | `Hash(h) => + "{" + ++ ( + h + |> E.A.fmap(((name, value)) => name ++ ":" ++ toString(value)) + |> Js.String.concatMany(_, ",") + ) + ++ "}"; diff --git a/src/distPlus/typeSystem/Fns.re b/src/distPlus/typeSystem/Fns.re index e49f434c..5c5fb426 100644 --- a/src/distPlus/typeSystem/Fns.re +++ b/src/distPlus/typeSystem/Fns.re @@ -1,7 +1,9 @@ open TypeSystem; -let wrongInputsError = r => { - Error("Wrong inputs"); +let wrongInputsError = (r: array(typedValue)) => { + let inputs = r |> E.A.fmap(TypedValue.toString) |>Js.String.concatMany(_, ","); + Js.log3("Inputs were", inputs, r); + Error("Wrong inputs. The inputs were:" ++ inputs); }; let to_: (float, float) => result(node, string) = @@ -46,6 +48,7 @@ let makeDistFloat = (name, fn) => ~run= fun | [|`SamplingDist(a), `Float(b)|] => fn(a, b) + | [|`RenderedDist(a), `Float(b)|] => fn(`RenderedDist(a), b) | e => wrongInputsError(e), (), ); @@ -55,6 +58,7 @@ let makeRenderedDistFloat = (name, fn) => ~name, ~outputType=`RenderedDistribution, ~inputTypes=[|`RenderedDistribution, `Float|], + ~shouldCoerceTypes=true, ~run= fun | [|`RenderedDist(a), `Float(b)|] => fn(a, b) @@ -70,6 +74,7 @@ let makeDist = (name, fn) => ~run= fun | [|`SamplingDist(a)|] => fn(a) + | [|`RenderedDist(a)|] => fn(`RenderedDist(a)) | e => wrongInputsError(e), (), ); diff --git a/src/distPlus/typeSystem/TypeSystem.re b/src/distPlus/typeSystem/TypeSystem.re index b6f54b46..654d3e98 100644 --- a/src/distPlus/typeSystem/TypeSystem.re +++ b/src/distPlus/typeSystem/TypeSystem.re @@ -36,6 +36,22 @@ type functions = array(_function); type inputNodes = array(node); module TypedValue = { + let rec toString: typedValue => string = + fun + | `SamplingDist(_) => "[sampling dist]" + | `RenderedDist(_) => "[rendered Shape]" + | `Float(f) => "Float: " ++ Js.Float.toString(f) + | `Array(a) => + "[" ++ (a |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ "]" + | `Hash(v) => + "{" + ++ ( + v + |> E.A.fmap(((name, value)) => name ++ ":" ++ toString(value)) + |> Js.String.concatMany(_, ",") + ) + ++ "}"; + let rec fromNode = (node: node): result(typedValue, string) => switch (ExpressionTypes.ExpressionTree.toFloatIfNeeded(node)) { | `SymbolicDist(`Float(r)) => Ok(`Float(r)) @@ -51,7 +67,7 @@ module TypedValue = { |> E.A.fmap(((name, t)) => fromNode(t) |> E.R.fmap(r => (name, r))) |> E.A.R.firstErrorOrOpen |> E.R.fmap(r => `Hash(r)) - | _ => Error("Wrong type") + | e => Error("Wrong type: " ++ ExpressionTreeBasic.toString(e)) }; // todo: Arrays and hashes @@ -103,12 +119,12 @@ module TypedValue = { }; }; - let toFloat: typedValue => result(float,string) = + let toFloat: typedValue => result(float, string) = fun | `Float(x) => Ok(x) | _ => Error("Not a float"); - let toArray: typedValue => result(array('a),string) = + let toArray: typedValue => result(array('a), string) = fun | `Array(x) => Ok(x) | _ => Error("Not an array"); @@ -118,12 +134,13 @@ module TypedValue = { | `Hash(x) => Ok(x) | _ => Error("Not a named item"); - let toDist = + let toDist: typedValue => result(node,string) = fun | `SamplingDist(`SymbolicDist(c)) => Ok(`SymbolicDist(c)) | `SamplingDist(`RenderedDist(c)) => Ok(`RenderedDist(c)) + | `RenderedDist(c) => Ok(`RenderedDist(c)) | `Float(x) => Ok(`SymbolicDist(`Float(x))) - | _ => Error("Cannot be converted into a distribution"); + | x => Error("Cannot be converted into a distribution: " ++ toString(x)); }; module Function = {