Ensure render in DistPlusRenderer
This commit is contained in:
parent
d8c1aa6693
commit
b064001a51
|
@ -160,7 +160,7 @@ module PointwiseCombination = {
|
||||||
Render.render(evaluationParams, t2),
|
Render.render(evaluationParams, t2),
|
||||||
) {
|
) {
|
||||||
| (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) =>
|
| (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) =>
|
||||||
Ok(`RenderedDist(Shape.combinePointwise(( *. ), rs1, rs2)))
|
Ok(`RenderedDist(Shape.combinePointwise(fn, rs1, rs2)))
|
||||||
| (Error(e1), _) => Error(e1)
|
| (Error(e1), _) => Error(e1)
|
||||||
| (_, Error(e2)) => Error(e2)
|
| (_, Error(e2)) => Error(e2)
|
||||||
| _ => Error("Pointwise combination: rendering failed.")
|
| _ => Error("Pointwise combination: rendering failed.")
|
||||||
|
@ -177,7 +177,7 @@ module PointwiseCombination = {
|
||||||
switch (pointwiseOp) {
|
switch (pointwiseOp) {
|
||||||
| `Add => pointwiseAdd(evaluationParams, t1, t2)
|
| `Add => pointwiseAdd(evaluationParams, t1, t2)
|
||||||
| `Multiply => pointwiseCombine(( *. ),evaluationParams, t1, t2)
|
| `Multiply => pointwiseCombine(( *. ),evaluationParams, t1, t2)
|
||||||
| `Exponentiate => pointwiseCombine(( *. ),evaluationParams, t1, t2)
|
| `Exponentiate => pointwiseCombine(( ** ),evaluationParams, t1, t2)
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
|
@ -49,6 +49,11 @@ let to_: array(node) => result(node, string) =
|
||||||
Error("Low value must be less than high value.")
|
Error("Low value must be less than high value.")
|
||||||
| _ => Error("Requires 2 variables");
|
| _ => Error("Requires 2 variables");
|
||||||
|
|
||||||
|
// Possible setup:
|
||||||
|
// let normal = {"inputs": [`float, `float], "outputs": [`float]};
|
||||||
|
// let render = {"inputs": [`dist], "outputs": [`renderedDist]};
|
||||||
|
// let render = {"inputs": [`distRenderedDist], "outputs": [`renderedDist]};
|
||||||
|
|
||||||
let fnn =
|
let fnn =
|
||||||
(
|
(
|
||||||
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
|
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
|
||||||
|
|
|
@ -201,6 +201,9 @@ module MathAdtToDistDst = {
|
||||||
| ("dotMultiply", [|l, r|]) => toOkPointwise((`Multiply, l, r))
|
| ("dotMultiply", [|l, r|]) => toOkPointwise((`Multiply, l, r))
|
||||||
| ("dotMultiply", _) =>
|
| ("dotMultiply", _) =>
|
||||||
Error("Dotwise multiplication needs two operands")
|
Error("Dotwise multiplication needs two operands")
|
||||||
|
| ("dotPow", [|l, r|]) => toOkPointwise((`Exponentiate, l, r))
|
||||||
|
| ("dotPow", _) =>
|
||||||
|
Error("Dotwise exponentiation needs two operands")
|
||||||
| ("rightLogShift", [|l, r|]) => toOkPointwise((`Add, l, r))
|
| ("rightLogShift", [|l, r|]) => toOkPointwise((`Add, l, r))
|
||||||
| ("rightLogShift", _) =>
|
| ("rightLogShift", _) =>
|
||||||
Error("Dotwise addition needs two operands")
|
Error("Dotwise addition needs two operands")
|
||||||
|
@ -227,6 +230,12 @@ module MathAdtToDistDst = {
|
||||||
Error(
|
Error(
|
||||||
"truncate needs three arguments: the expression and both cutoffs",
|
"truncate needs three arguments: the expression and both cutoffs",
|
||||||
)
|
)
|
||||||
|
| ("scaleMultiply", [|d, `SymbolicDist(`Float(v))|]) =>
|
||||||
|
Ok(`VerticalScaling(`Multiply, d, `SymbolicDist(`Float(v))))
|
||||||
|
| ("scaleExp", [|d, `SymbolicDist(`Float(v))|]) =>
|
||||||
|
Ok(`VerticalScaling(`Exponentiate, d, `SymbolicDist(`Float(v))))
|
||||||
|
| ("scaleLog", [|d, `SymbolicDist(`Float(v))|]) =>
|
||||||
|
Ok(`VerticalScaling(`Log, d, `SymbolicDist(`Float(v))))
|
||||||
| ("pdf", [|d, `SymbolicDist(`Float(v))|]) =>
|
| ("pdf", [|d, `SymbolicDist(`Float(v))|]) =>
|
||||||
toOkFloatFromDist((`Pdf(v), d))
|
toOkFloatFromDist((`Pdf(v), d))
|
||||||
| ("cdf", [|d, `SymbolicDist(`Float(v))|]) =>
|
| ("cdf", [|d, `SymbolicDist(`Float(v))|]) =>
|
||||||
|
@ -275,11 +284,15 @@ module MathAdtToDistDst = {
|
||||||
| "subtract"
|
| "subtract"
|
||||||
| "multiply"
|
| "multiply"
|
||||||
| "dotMultiply"
|
| "dotMultiply"
|
||||||
|
| "dotPow"
|
||||||
| "rightLogShift"
|
| "rightLogShift"
|
||||||
| "divide"
|
| "divide"
|
||||||
| "pow"
|
| "pow"
|
||||||
| "leftTruncate"
|
| "leftTruncate"
|
||||||
| "rightTruncate"
|
| "rightTruncate"
|
||||||
|
| "scaleMultiply"
|
||||||
|
| "scaleExp"
|
||||||
|
| "scaleLog"
|
||||||
| "truncate"
|
| "truncate"
|
||||||
| "mean"
|
| "mean"
|
||||||
| "inv"
|
| "inv"
|
||||||
|
@ -355,6 +368,7 @@ let fromString2 = str => {
|
||||||
Inside of this function, MathAdtToDistDst is called whenever a distribution function is encountered.
|
Inside of this function, MathAdtToDistDst is called whenever a distribution function is encountered.
|
||||||
*/
|
*/
|
||||||
let mathJsToJson = str |> pointwiseToRightLogShift |> Mathjs.parseMath;
|
let mathJsToJson = str |> pointwiseToRightLogShift |> Mathjs.parseMath;
|
||||||
|
|
||||||
let mathJsParse =
|
let mathJsParse =
|
||||||
E.R.bind(mathJsToJson, r => {
|
E.R.bind(mathJsToJson, r => {
|
||||||
switch (MathJsonToMathJsAdt.run(r)) {
|
switch (MathJsonToMathJsAdt.run(r)) {
|
||||||
|
@ -364,6 +378,7 @@ let fromString2 = str => {
|
||||||
});
|
});
|
||||||
|
|
||||||
let value = E.R.bind(mathJsParse, MathAdtToDistDst.run);
|
let value = E.R.bind(mathJsParse, MathAdtToDistDst.run);
|
||||||
|
Js.log2(mathJsParse, value);
|
||||||
value;
|
value;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -91,18 +91,16 @@ module Internals = {
|
||||||
};
|
};
|
||||||
let makeOutputs = (graph, shape): outputs => {graph, shape};
|
let makeOutputs = (graph, shape): outputs => {graph, shape};
|
||||||
|
|
||||||
let runNode = (inputs, node) => {
|
let makeInputs = (inputs): ExpressionTypes.ExpressionTree.samplingInputs => {
|
||||||
ExpressionTree.toLeaf(
|
|
||||||
{
|
|
||||||
sampleCount: inputs.samplingInputs.sampleCount |> E.O.default(10000),
|
sampleCount: inputs.samplingInputs.sampleCount |> E.O.default(10000),
|
||||||
outputXYPoints:
|
outputXYPoints:
|
||||||
inputs.samplingInputs.outputXYPoints |> E.O.default(10000),
|
inputs.samplingInputs.outputXYPoints |> E.O.default(10000),
|
||||||
kernelWidth: inputs.samplingInputs.kernelWidth,
|
kernelWidth: inputs.samplingInputs.kernelWidth,
|
||||||
shapeLength: inputs.samplingInputs.shapeLength |> E.O.default(10000),
|
shapeLength: inputs.samplingInputs.shapeLength |> E.O.default(10000),
|
||||||
},
|
};
|
||||||
inputs.environment,
|
|
||||||
node,
|
let runNode = (inputs, node) => {
|
||||||
);
|
ExpressionTree.toLeaf(makeInputs(inputs), inputs.environment, node);
|
||||||
};
|
};
|
||||||
|
|
||||||
let runProgram = (inputs: inputs, p: ExpressionTypes.Program.program) => {
|
let runProgram = (inputs: inputs, p: ExpressionTypes.Program.program) => {
|
||||||
|
@ -154,14 +152,39 @@ let run = (inputs: Inputs.inputs) => {
|
||||||
|> E.R.fmap(Internals.outputToDistPlus(inputs));
|
|> E.R.fmap(Internals.outputToDistPlus(inputs));
|
||||||
};
|
};
|
||||||
|
|
||||||
let exportDistPlus = inputs =>
|
let renderIfNeeded =
|
||||||
|
(inputs, node: ExpressionTypes.ExpressionTree.node)
|
||||||
|
: result(ExpressionTypes.ExpressionTree.node, string) =>
|
||||||
|
node
|
||||||
|
|> (
|
||||||
fun
|
fun
|
||||||
| `RenderedDist(n) => Ok(`DistPlus(Internals.outputToDistPlus(inputs, n)))
|
| `SymbolicDist(n) => {
|
||||||
|
`Render(`SymbolicDist(n))
|
||||||
|
|> Internals.runNode(Internals.distPlusRenderInputsToInputs(inputs))
|
||||||
|
|> (
|
||||||
|
fun
|
||||||
|
| Ok(`RenderedDist(n)) => Ok(`RenderedDist(n))
|
||||||
|
| Error(r) => Error(r)
|
||||||
|
| _ => Error("Didn't render, but intended to")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
| n => Ok(n)
|
||||||
|
);
|
||||||
|
|
||||||
|
let exportDistPlus = (inputs, node: ExpressionTypes.ExpressionTree.node) =>
|
||||||
|
node
|
||||||
|
|> renderIfNeeded(inputs)
|
||||||
|
|> E.R.bind(
|
||||||
|
_,
|
||||||
|
fun
|
||||||
|
| `RenderedDist(n) =>
|
||||||
|
Ok(`DistPlus(Internals.outputToDistPlus(inputs, n)))
|
||||||
| `Function(n) => Ok(`Function(n))
|
| `Function(n) => Ok(`Function(n))
|
||||||
| n =>
|
| n =>
|
||||||
Error(
|
Error(
|
||||||
"Didn't output a rendered distribution. Format:"
|
"Didn't output a rendered distribution. Format:"
|
||||||
++ ExpressionTree.toString(n),
|
++ ExpressionTree.toString(n),
|
||||||
|
),
|
||||||
);
|
);
|
||||||
|
|
||||||
let run2 = (inputs: Inputs.inputs) => {
|
let run2 = (inputs: Inputs.inputs) => {
|
||||||
|
@ -177,17 +200,10 @@ let runFunction =
|
||||||
fn: (array(string), ExpressionTypes.ExpressionTree.node),
|
fn: (array(string), ExpressionTypes.ExpressionTree.node),
|
||||||
fnInputs,
|
fnInputs,
|
||||||
) => {
|
) => {
|
||||||
let (_, fns) = fn;
|
|
||||||
let inputs = ins |> Internals.distPlusRenderInputsToInputs;
|
let inputs = ins |> Internals.distPlusRenderInputsToInputs;
|
||||||
let output =
|
let output =
|
||||||
ExpressionTree.runFunction(
|
ExpressionTree.runFunction(
|
||||||
{
|
Internals.makeInputs(inputs),
|
||||||
sampleCount: inputs.samplingInputs.sampleCount |> E.O.default(10000),
|
|
||||||
outputXYPoints:
|
|
||||||
inputs.samplingInputs.outputXYPoints |> E.O.default(10000),
|
|
||||||
kernelWidth: inputs.samplingInputs.kernelWidth,
|
|
||||||
shapeLength: inputs.samplingInputs.shapeLength |> E.O.default(10000),
|
|
||||||
},
|
|
||||||
inputs.environment,
|
inputs.environment,
|
||||||
fnInputs,
|
fnInputs,
|
||||||
fn,
|
fn,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user