Ensure render in DistPlusRenderer
This commit is contained in:
parent
d8c1aa6693
commit
b064001a51
|
@ -160,7 +160,7 @@ module PointwiseCombination = {
|
|||
Render.render(evaluationParams, t2),
|
||||
) {
|
||||
| (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) =>
|
||||
Ok(`RenderedDist(Shape.combinePointwise(( *. ), rs1, rs2)))
|
||||
Ok(`RenderedDist(Shape.combinePointwise(fn, rs1, rs2)))
|
||||
| (Error(e1), _) => Error(e1)
|
||||
| (_, Error(e2)) => Error(e2)
|
||||
| _ => Error("Pointwise combination: rendering failed.")
|
||||
|
@ -177,7 +177,7 @@ module PointwiseCombination = {
|
|||
switch (pointwiseOp) {
|
||||
| `Add => pointwiseAdd(evaluationParams, t1, t2)
|
||||
| `Multiply => pointwiseCombine(( *. ),evaluationParams, t1, t2)
|
||||
| `Exponentiate => pointwiseCombine(( *. ),evaluationParams, t1, t2)
|
||||
| `Exponentiate => pointwiseCombine(( ** ),evaluationParams, t1, t2)
|
||||
};
|
||||
};
|
||||
};
|
||||
|
|
|
@ -49,6 +49,11 @@ let to_: array(node) => result(node, string) =
|
|||
Error("Low value must be less than high value.")
|
||||
| _ => Error("Requires 2 variables");
|
||||
|
||||
// Possible setup:
|
||||
// let normal = {"inputs": [`float, `float], "outputs": [`float]};
|
||||
// let render = {"inputs": [`dist], "outputs": [`renderedDist]};
|
||||
// let render = {"inputs": [`distRenderedDist], "outputs": [`renderedDist]};
|
||||
|
||||
let fnn =
|
||||
(
|
||||
evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams,
|
||||
|
|
|
@ -201,6 +201,9 @@ module MathAdtToDistDst = {
|
|||
| ("dotMultiply", [|l, r|]) => toOkPointwise((`Multiply, l, r))
|
||||
| ("dotMultiply", _) =>
|
||||
Error("Dotwise multiplication needs two operands")
|
||||
| ("dotPow", [|l, r|]) => toOkPointwise((`Exponentiate, l, r))
|
||||
| ("dotPow", _) =>
|
||||
Error("Dotwise exponentiation needs two operands")
|
||||
| ("rightLogShift", [|l, r|]) => toOkPointwise((`Add, l, r))
|
||||
| ("rightLogShift", _) =>
|
||||
Error("Dotwise addition needs two operands")
|
||||
|
@ -227,6 +230,12 @@ module MathAdtToDistDst = {
|
|||
Error(
|
||||
"truncate needs three arguments: the expression and both cutoffs",
|
||||
)
|
||||
| ("scaleMultiply", [|d, `SymbolicDist(`Float(v))|]) =>
|
||||
Ok(`VerticalScaling(`Multiply, d, `SymbolicDist(`Float(v))))
|
||||
| ("scaleExp", [|d, `SymbolicDist(`Float(v))|]) =>
|
||||
Ok(`VerticalScaling(`Exponentiate, d, `SymbolicDist(`Float(v))))
|
||||
| ("scaleLog", [|d, `SymbolicDist(`Float(v))|]) =>
|
||||
Ok(`VerticalScaling(`Log, d, `SymbolicDist(`Float(v))))
|
||||
| ("pdf", [|d, `SymbolicDist(`Float(v))|]) =>
|
||||
toOkFloatFromDist((`Pdf(v), d))
|
||||
| ("cdf", [|d, `SymbolicDist(`Float(v))|]) =>
|
||||
|
@ -275,11 +284,15 @@ module MathAdtToDistDst = {
|
|||
| "subtract"
|
||||
| "multiply"
|
||||
| "dotMultiply"
|
||||
| "dotPow"
|
||||
| "rightLogShift"
|
||||
| "divide"
|
||||
| "pow"
|
||||
| "leftTruncate"
|
||||
| "rightTruncate"
|
||||
| "scaleMultiply"
|
||||
| "scaleExp"
|
||||
| "scaleLog"
|
||||
| "truncate"
|
||||
| "mean"
|
||||
| "inv"
|
||||
|
@ -355,6 +368,7 @@ let fromString2 = str => {
|
|||
Inside of this function, MathAdtToDistDst is called whenever a distribution function is encountered.
|
||||
*/
|
||||
let mathJsToJson = str |> pointwiseToRightLogShift |> Mathjs.parseMath;
|
||||
|
||||
let mathJsParse =
|
||||
E.R.bind(mathJsToJson, r => {
|
||||
switch (MathJsonToMathJsAdt.run(r)) {
|
||||
|
@ -364,6 +378,7 @@ let fromString2 = str => {
|
|||
});
|
||||
|
||||
let value = E.R.bind(mathJsParse, MathAdtToDistDst.run);
|
||||
Js.log2(mathJsParse, value);
|
||||
value;
|
||||
};
|
||||
|
||||
|
|
|
@ -91,18 +91,16 @@ module Internals = {
|
|||
};
|
||||
let makeOutputs = (graph, shape): outputs => {graph, shape};
|
||||
|
||||
let makeInputs = (inputs): ExpressionTypes.ExpressionTree.samplingInputs => {
|
||||
sampleCount: inputs.samplingInputs.sampleCount |> E.O.default(10000),
|
||||
outputXYPoints:
|
||||
inputs.samplingInputs.outputXYPoints |> E.O.default(10000),
|
||||
kernelWidth: inputs.samplingInputs.kernelWidth,
|
||||
shapeLength: inputs.samplingInputs.shapeLength |> E.O.default(10000),
|
||||
};
|
||||
|
||||
let runNode = (inputs, node) => {
|
||||
ExpressionTree.toLeaf(
|
||||
{
|
||||
sampleCount: inputs.samplingInputs.sampleCount |> E.O.default(10000),
|
||||
outputXYPoints:
|
||||
inputs.samplingInputs.outputXYPoints |> E.O.default(10000),
|
||||
kernelWidth: inputs.samplingInputs.kernelWidth,
|
||||
shapeLength: inputs.samplingInputs.shapeLength |> E.O.default(10000),
|
||||
},
|
||||
inputs.environment,
|
||||
node,
|
||||
);
|
||||
ExpressionTree.toLeaf(makeInputs(inputs), inputs.environment, node);
|
||||
};
|
||||
|
||||
let runProgram = (inputs: inputs, p: ExpressionTypes.Program.program) => {
|
||||
|
@ -154,15 +152,40 @@ let run = (inputs: Inputs.inputs) => {
|
|||
|> E.R.fmap(Internals.outputToDistPlus(inputs));
|
||||
};
|
||||
|
||||
let exportDistPlus = inputs =>
|
||||
fun
|
||||
| `RenderedDist(n) => Ok(`DistPlus(Internals.outputToDistPlus(inputs, n)))
|
||||
| `Function(n) => Ok(`Function(n))
|
||||
| n =>
|
||||
Error(
|
||||
"Didn't output a rendered distribution. Format:"
|
||||
++ ExpressionTree.toString(n),
|
||||
);
|
||||
let renderIfNeeded =
|
||||
(inputs, node: ExpressionTypes.ExpressionTree.node)
|
||||
: result(ExpressionTypes.ExpressionTree.node, string) =>
|
||||
node
|
||||
|> (
|
||||
fun
|
||||
| `SymbolicDist(n) => {
|
||||
`Render(`SymbolicDist(n))
|
||||
|> Internals.runNode(Internals.distPlusRenderInputsToInputs(inputs))
|
||||
|> (
|
||||
fun
|
||||
| Ok(`RenderedDist(n)) => Ok(`RenderedDist(n))
|
||||
| Error(r) => Error(r)
|
||||
| _ => Error("Didn't render, but intended to")
|
||||
);
|
||||
}
|
||||
| n => Ok(n)
|
||||
);
|
||||
|
||||
let exportDistPlus = (inputs, node: ExpressionTypes.ExpressionTree.node) =>
|
||||
node
|
||||
|> renderIfNeeded(inputs)
|
||||
|> E.R.bind(
|
||||
_,
|
||||
fun
|
||||
| `RenderedDist(n) =>
|
||||
Ok(`DistPlus(Internals.outputToDistPlus(inputs, n)))
|
||||
| `Function(n) => Ok(`Function(n))
|
||||
| n =>
|
||||
Error(
|
||||
"Didn't output a rendered distribution. Format:"
|
||||
++ ExpressionTree.toString(n),
|
||||
),
|
||||
);
|
||||
|
||||
let run2 = (inputs: Inputs.inputs) => {
|
||||
inputs
|
||||
|
@ -177,17 +200,10 @@ let runFunction =
|
|||
fn: (array(string), ExpressionTypes.ExpressionTree.node),
|
||||
fnInputs,
|
||||
) => {
|
||||
let (_, fns) = fn;
|
||||
let inputs = ins |> Internals.distPlusRenderInputsToInputs;
|
||||
let output =
|
||||
ExpressionTree.runFunction(
|
||||
{
|
||||
sampleCount: inputs.samplingInputs.sampleCount |> E.O.default(10000),
|
||||
outputXYPoints:
|
||||
inputs.samplingInputs.outputXYPoints |> E.O.default(10000),
|
||||
kernelWidth: inputs.samplingInputs.kernelWidth,
|
||||
shapeLength: inputs.samplingInputs.shapeLength |> E.O.default(10000),
|
||||
},
|
||||
Internals.makeInputs(inputs),
|
||||
inputs.environment,
|
||||
fnInputs,
|
||||
fn,
|
||||
|
|
Loading…
Reference in New Issue
Block a user