Ensure render in DistPlusRenderer

This commit is contained in:
Ozzie Gooen 2020-08-08 22:33:36 +01:00
parent d8c1aa6693
commit b064001a51
4 changed files with 66 additions and 30 deletions

View File

@ -160,7 +160,7 @@ module PointwiseCombination = {
Render.render(evaluationParams, t2), Render.render(evaluationParams, t2),
) { ) {
| (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) => | (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) =>
Ok(`RenderedDist(Shape.combinePointwise(( *. ), rs1, rs2))) Ok(`RenderedDist(Shape.combinePointwise(fn, rs1, rs2)))
| (Error(e1), _) => Error(e1) | (Error(e1), _) => Error(e1)
| (_, Error(e2)) => Error(e2) | (_, Error(e2)) => Error(e2)
| _ => Error("Pointwise combination: rendering failed.") | _ => Error("Pointwise combination: rendering failed.")
@ -177,7 +177,7 @@ module PointwiseCombination = {
switch (pointwiseOp) { switch (pointwiseOp) {
| `Add => pointwiseAdd(evaluationParams, t1, t2) | `Add => pointwiseAdd(evaluationParams, t1, t2)
| `Multiply => pointwiseCombine(( *. ),evaluationParams, t1, t2) | `Multiply => pointwiseCombine(( *. ),evaluationParams, t1, t2)
| `Exponentiate => pointwiseCombine(( *. ),evaluationParams, t1, t2) | `Exponentiate => pointwiseCombine(( ** ),evaluationParams, t1, t2)
}; };
}; };
}; };

View File

@ -49,6 +49,11 @@ let to_: array(node) => result(node, string) =
Error("Low value must be less than high value.") Error("Low value must be less than high value.")
| _ => Error("Requires 2 variables"); | _ => Error("Requires 2 variables");
// Possible setup:
// let normal = {"inputs": [`float, `float], "outputs": [`float]};
// let render = {"inputs": [`dist], "outputs": [`renderedDist]};
// let render = {"inputs": [`distRenderedDist], "outputs": [`renderedDist]};
let fnn = let fnn =
( (
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams, evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,

View File

@ -201,6 +201,9 @@ module MathAdtToDistDst = {
| ("dotMultiply", [|l, r|]) => toOkPointwise((`Multiply, l, r)) | ("dotMultiply", [|l, r|]) => toOkPointwise((`Multiply, l, r))
| ("dotMultiply", _) => | ("dotMultiply", _) =>
Error("Dotwise multiplication needs two operands") Error("Dotwise multiplication needs two operands")
| ("dotPow", [|l, r|]) => toOkPointwise((`Exponentiate, l, r))
| ("dotPow", _) =>
Error("Dotwise exponentiation needs two operands")
| ("rightLogShift", [|l, r|]) => toOkPointwise((`Add, l, r)) | ("rightLogShift", [|l, r|]) => toOkPointwise((`Add, l, r))
| ("rightLogShift", _) => | ("rightLogShift", _) =>
Error("Dotwise addition needs two operands") Error("Dotwise addition needs two operands")
@ -227,6 +230,12 @@ module MathAdtToDistDst = {
Error( Error(
"truncate needs three arguments: the expression and both cutoffs", "truncate needs three arguments: the expression and both cutoffs",
) )
| ("scaleMultiply", [|d, `SymbolicDist(`Float(v))|]) =>
Ok(`VerticalScaling(`Multiply, d, `SymbolicDist(`Float(v))))
| ("scaleExp", [|d, `SymbolicDist(`Float(v))|]) =>
Ok(`VerticalScaling(`Exponentiate, d, `SymbolicDist(`Float(v))))
| ("scaleLog", [|d, `SymbolicDist(`Float(v))|]) =>
Ok(`VerticalScaling(`Log, d, `SymbolicDist(`Float(v))))
| ("pdf", [|d, `SymbolicDist(`Float(v))|]) => | ("pdf", [|d, `SymbolicDist(`Float(v))|]) =>
toOkFloatFromDist((`Pdf(v), d)) toOkFloatFromDist((`Pdf(v), d))
| ("cdf", [|d, `SymbolicDist(`Float(v))|]) => | ("cdf", [|d, `SymbolicDist(`Float(v))|]) =>
@ -275,11 +284,15 @@ module MathAdtToDistDst = {
| "subtract" | "subtract"
| "multiply" | "multiply"
| "dotMultiply" | "dotMultiply"
| "dotPow"
| "rightLogShift" | "rightLogShift"
| "divide" | "divide"
| "pow" | "pow"
| "leftTruncate" | "leftTruncate"
| "rightTruncate" | "rightTruncate"
| "scaleMultiply"
| "scaleExp"
| "scaleLog"
| "truncate" | "truncate"
| "mean" | "mean"
| "inv" | "inv"
@ -355,6 +368,7 @@ let fromString2 = str => {
Inside of this function, MathAdtToDistDst is called whenever a distribution function is encountered. Inside of this function, MathAdtToDistDst is called whenever a distribution function is encountered.
*/ */
let mathJsToJson = str |> pointwiseToRightLogShift |> Mathjs.parseMath; let mathJsToJson = str |> pointwiseToRightLogShift |> Mathjs.parseMath;
let mathJsParse = let mathJsParse =
E.R.bind(mathJsToJson, r => { E.R.bind(mathJsToJson, r => {
switch (MathJsonToMathJsAdt.run(r)) { switch (MathJsonToMathJsAdt.run(r)) {
@ -364,6 +378,7 @@ let fromString2 = str => {
}); });
let value = E.R.bind(mathJsParse, MathAdtToDistDst.run); let value = E.R.bind(mathJsParse, MathAdtToDistDst.run);
Js.log2(mathJsParse, value);
value; value;
}; };

View File

@ -91,18 +91,16 @@ module Internals = {
}; };
let makeOutputs = (graph, shape): outputs => {graph, shape}; let makeOutputs = (graph, shape): outputs => {graph, shape};
let runNode = (inputs, node) => { let makeInputs = (inputs): ExpressionTypes.ExpressionTree.samplingInputs => {
ExpressionTree.toLeaf(
{
sampleCount: inputs.samplingInputs.sampleCount |> E.O.default(10000), sampleCount: inputs.samplingInputs.sampleCount |> E.O.default(10000),
outputXYPoints: outputXYPoints:
inputs.samplingInputs.outputXYPoints |> E.O.default(10000), inputs.samplingInputs.outputXYPoints |> E.O.default(10000),
kernelWidth: inputs.samplingInputs.kernelWidth, kernelWidth: inputs.samplingInputs.kernelWidth,
shapeLength: inputs.samplingInputs.shapeLength |> E.O.default(10000), shapeLength: inputs.samplingInputs.shapeLength |> E.O.default(10000),
}, };
inputs.environment,
node, let runNode = (inputs, node) => {
); ExpressionTree.toLeaf(makeInputs(inputs), inputs.environment, node);
}; };
let runProgram = (inputs: inputs, p: ExpressionTypes.Program.program) => { let runProgram = (inputs: inputs, p: ExpressionTypes.Program.program) => {
@ -154,14 +152,39 @@ let run = (inputs: Inputs.inputs) => {
|> E.R.fmap(Internals.outputToDistPlus(inputs)); |> E.R.fmap(Internals.outputToDistPlus(inputs));
}; };
let exportDistPlus = inputs => let renderIfNeeded =
(inputs, node: ExpressionTypes.ExpressionTree.node)
: result(ExpressionTypes.ExpressionTree.node, string) =>
node
|> (
fun fun
| `RenderedDist(n) => Ok(`DistPlus(Internals.outputToDistPlus(inputs, n))) | `SymbolicDist(n) => {
`Render(`SymbolicDist(n))
|> Internals.runNode(Internals.distPlusRenderInputsToInputs(inputs))
|> (
fun
| Ok(`RenderedDist(n)) => Ok(`RenderedDist(n))
| Error(r) => Error(r)
| _ => Error("Didn't render, but intended to")
);
}
| n => Ok(n)
);
let exportDistPlus = (inputs, node: ExpressionTypes.ExpressionTree.node) =>
node
|> renderIfNeeded(inputs)
|> E.R.bind(
_,
fun
| `RenderedDist(n) =>
Ok(`DistPlus(Internals.outputToDistPlus(inputs, n)))
| `Function(n) => Ok(`Function(n)) | `Function(n) => Ok(`Function(n))
| n => | n =>
Error( Error(
"Didn't output a rendered distribution. Format:" "Didn't output a rendered distribution. Format:"
++ ExpressionTree.toString(n), ++ ExpressionTree.toString(n),
),
); );
let run2 = (inputs: Inputs.inputs) => { let run2 = (inputs: Inputs.inputs) => {
@ -177,17 +200,10 @@ let runFunction =
fn: (array(string), ExpressionTypes.ExpressionTree.node), fn: (array(string), ExpressionTypes.ExpressionTree.node),
fnInputs, fnInputs,
) => { ) => {
let (_, fns) = fn;
let inputs = ins |> Internals.distPlusRenderInputsToInputs; let inputs = ins |> Internals.distPlusRenderInputsToInputs;
let output = let output =
ExpressionTree.runFunction( ExpressionTree.runFunction(
{ Internals.makeInputs(inputs),
sampleCount: inputs.samplingInputs.sampleCount |> E.O.default(10000),
outputXYPoints:
inputs.samplingInputs.outputXYPoints |> E.O.default(10000),
kernelWidth: inputs.samplingInputs.kernelWidth,
shapeLength: inputs.samplingInputs.shapeLength |> E.O.default(10000),
},
inputs.environment, inputs.environment,
fnInputs, fnInputs,
fn, fn,