Ensure render in DistPlusRenderer

This commit is contained in:
Ozzie Gooen 2020-08-08 22:33:36 +01:00
parent d8c1aa6693
commit b064001a51
4 changed files with 66 additions and 30 deletions

View File

@ -160,7 +160,7 @@ module PointwiseCombination = {
Render.render(evaluationParams, t2),
) {
| (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) =>
Ok(`RenderedDist(Shape.combinePointwise(( *. ), rs1, rs2)))
Ok(`RenderedDist(Shape.combinePointwise(fn, rs1, rs2)))
| (Error(e1), _) => Error(e1)
| (_, Error(e2)) => Error(e2)
| _ => Error("Pointwise combination: rendering failed.")
@ -177,7 +177,7 @@ module PointwiseCombination = {
switch (pointwiseOp) {
| `Add => pointwiseAdd(evaluationParams, t1, t2)
| `Multiply => pointwiseCombine(( *. ),evaluationParams, t1, t2)
| `Exponentiate => pointwiseCombine(( *. ),evaluationParams, t1, t2)
| `Exponentiate => pointwiseCombine(( ** ),evaluationParams, t1, t2)
};
};
};

View File

@ -49,6 +49,11 @@ let to_: array(node) => result(node, string) =
Error("Low value must be less than high value.")
| _ => Error("Requires 2 variables");
// Possible setup:
// let normal = {"inputs": [`float, `float], "outputs": [`float]};
// let render = {"inputs": [`dist], "outputs": [`renderedDist]};
// let render = {"inputs": [`distRenderedDist], "outputs": [`renderedDist]};
let fnn =
(
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,

View File

@ -201,6 +201,9 @@ module MathAdtToDistDst = {
| ("dotMultiply", [|l, r|]) => toOkPointwise((`Multiply, l, r))
| ("dotMultiply", _) =>
Error("Dotwise multiplication needs two operands")
| ("dotPow", [|l, r|]) => toOkPointwise((`Exponentiate, l, r))
| ("dotPow", _) =>
Error("Dotwise exponentiation needs two operands")
| ("rightLogShift", [|l, r|]) => toOkPointwise((`Add, l, r))
| ("rightLogShift", _) =>
Error("Dotwise addition needs two operands")
@ -227,6 +230,12 @@ module MathAdtToDistDst = {
Error(
"truncate needs three arguments: the expression and both cutoffs",
)
| ("scaleMultiply", [|d, `SymbolicDist(`Float(v))|]) =>
Ok(`VerticalScaling(`Multiply, d, `SymbolicDist(`Float(v))))
| ("scaleExp", [|d, `SymbolicDist(`Float(v))|]) =>
Ok(`VerticalScaling(`Exponentiate, d, `SymbolicDist(`Float(v))))
| ("scaleLog", [|d, `SymbolicDist(`Float(v))|]) =>
Ok(`VerticalScaling(`Log, d, `SymbolicDist(`Float(v))))
| ("pdf", [|d, `SymbolicDist(`Float(v))|]) =>
toOkFloatFromDist((`Pdf(v), d))
| ("cdf", [|d, `SymbolicDist(`Float(v))|]) =>
@ -275,11 +284,15 @@ module MathAdtToDistDst = {
| "subtract"
| "multiply"
| "dotMultiply"
| "dotPow"
| "rightLogShift"
| "divide"
| "pow"
| "leftTruncate"
| "rightTruncate"
| "scaleMultiply"
| "scaleExp"
| "scaleLog"
| "truncate"
| "mean"
| "inv"
@ -355,6 +368,7 @@ let fromString2 = str => {
Inside of this function, MathAdtToDistDst is called whenever a distribution function is encountered.
*/
let mathJsToJson = str |> pointwiseToRightLogShift |> Mathjs.parseMath;
let mathJsParse =
E.R.bind(mathJsToJson, r => {
switch (MathJsonToMathJsAdt.run(r)) {
@ -364,6 +378,7 @@ let fromString2 = str => {
});
let value = E.R.bind(mathJsParse, MathAdtToDistDst.run);
Js.log2(mathJsParse, value);
value;
};

View File

@ -91,18 +91,16 @@ module Internals = {
};
let makeOutputs = (graph, shape): outputs => {graph, shape};
let runNode = (inputs, node) => {
ExpressionTree.toLeaf(
{
let makeInputs = (inputs): ExpressionTypes.ExpressionTree.samplingInputs => {
sampleCount: inputs.samplingInputs.sampleCount |> E.O.default(10000),
outputXYPoints:
inputs.samplingInputs.outputXYPoints |> E.O.default(10000),
kernelWidth: inputs.samplingInputs.kernelWidth,
shapeLength: inputs.samplingInputs.shapeLength |> E.O.default(10000),
},
inputs.environment,
node,
);
};
let runNode = (inputs, node) => {
ExpressionTree.toLeaf(makeInputs(inputs), inputs.environment, node);
};
let runProgram = (inputs: inputs, p: ExpressionTypes.Program.program) => {
@ -154,14 +152,39 @@ let run = (inputs: Inputs.inputs) => {
|> E.R.fmap(Internals.outputToDistPlus(inputs));
};
let exportDistPlus = inputs =>
let renderIfNeeded =
(inputs, node: ExpressionTypes.ExpressionTree.node)
: result(ExpressionTypes.ExpressionTree.node, string) =>
node
|> (
fun
| `RenderedDist(n) => Ok(`DistPlus(Internals.outputToDistPlus(inputs, n)))
| `SymbolicDist(n) => {
`Render(`SymbolicDist(n))
|> Internals.runNode(Internals.distPlusRenderInputsToInputs(inputs))
|> (
fun
| Ok(`RenderedDist(n)) => Ok(`RenderedDist(n))
| Error(r) => Error(r)
| _ => Error("Didn't render, but intended to")
);
}
| n => Ok(n)
);
let exportDistPlus = (inputs, node: ExpressionTypes.ExpressionTree.node) =>
node
|> renderIfNeeded(inputs)
|> E.R.bind(
_,
fun
| `RenderedDist(n) =>
Ok(`DistPlus(Internals.outputToDistPlus(inputs, n)))
| `Function(n) => Ok(`Function(n))
| n =>
Error(
"Didn't output a rendered distribution. Format:"
++ ExpressionTree.toString(n),
),
);
let run2 = (inputs: Inputs.inputs) => {
@ -177,17 +200,10 @@ let runFunction =
fn: (array(string), ExpressionTypes.ExpressionTree.node),
fnInputs,
) => {
let (_, fns) = fn;
let inputs = ins |> Internals.distPlusRenderInputsToInputs;
let output =
ExpressionTree.runFunction(
{
sampleCount: inputs.samplingInputs.sampleCount |> E.O.default(10000),
outputXYPoints:
inputs.samplingInputs.outputXYPoints |> E.O.default(10000),
kernelWidth: inputs.samplingInputs.kernelWidth,
shapeLength: inputs.samplingInputs.shapeLength |> E.O.default(10000),
},
Internals.makeInputs(inputs),
inputs.environment,
fnInputs,
fn,