Added .+ and .* support to parser

This commit is contained in:
Ozzie Gooen 2020-07-19 14:21:47 +01:00
parent b2c6ef7e5e
commit 6bf350039b
3 changed files with 36 additions and 13 deletions

View File

@ -78,7 +78,7 @@ let combinePointwise =
make( make(
~integralSumCache=combinedIntegralSum, ~integralSumCache=combinedIntegralSum,
XYShape.PointwiseCombination.combine( XYShape.PointwiseCombination.combine(
(+.), fn,
interpolator, interpolator,
t1.xyShape, t1.xyShape,
t2.xyShape, t2.xyShape,

View File

@ -79,7 +79,10 @@ module VerticalScaling = {
module PointwiseCombination = { module PointwiseCombination = {
let pointwiseAdd = (evaluationParams: evaluationParams, t1: t, t2: t) => { let pointwiseAdd = (evaluationParams: evaluationParams, t1: t, t2: t) => {
switch (Render.render(evaluationParams, t1), Render.render(evaluationParams, t2)) { switch (
Render.render(evaluationParams, t1),
Render.render(evaluationParams, t2),
) {
| (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) => | (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) =>
Ok( Ok(
`RenderedDist( `RenderedDist(
@ -110,9 +113,16 @@ module PointwiseCombination = {
let pointwiseMultiply = (evaluationParams: evaluationParams, t1: t, t2: t) => { let pointwiseMultiply = (evaluationParams: evaluationParams, t1: t, t2: t) => {
// TODO: construct a function that we can easily sample from, to construct // TODO: construct a function that we can easily sample from, to construct
// a RenderedDist. Use the xMin and xMax of the rendered shapes to tell the sampling function where to look. // a RenderedDist. Use the xMin and xMax of the rendered shapes to tell the sampling function where to look.
Error( switch (
"Pointwise multiplication not yet supported.", Render.render(evaluationParams, t1),
); Render.render(evaluationParams, t2),
) {
| (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) =>
Ok(`RenderedDist(Shape.combinePointwise(( *. ), rs1, rs2)))
| (Error(e1), _) => Error(e1)
| (_, Error(e2)) => Error(e2)
| _ => Error("Pointwise combination: rendering failed.")
};
}; };
let operationToLeaf = let operationToLeaf =
@ -134,7 +144,9 @@ module Truncate = {
switch (leftCutoff, rightCutoff, t) { switch (leftCutoff, rightCutoff, t) {
| (None, None, t) => `Solution(t) | (None, None, t) => `Solution(t)
| (Some(lc), Some(rc), t) when lc > rc => | (Some(lc), Some(rc), t) when lc > rc =>
`Error("Left truncation bound must be smaller than right truncation bound.") `Error(
"Left truncation bound must be smaller than right truncation bound.",
)
| (lc, rc, `SymbolicDist(`Uniform(u))) => | (lc, rc, `SymbolicDist(`Uniform(u))) =>
`Solution( `Solution(
`SymbolicDist(`Uniform(SymbolicDist.Uniform.truncate(lc, rc, u))), `SymbolicDist(`Uniform(SymbolicDist.Uniform.truncate(lc, rc, u))),

View File

@ -15,12 +15,8 @@ module MathJsonToMathJsAdt = {
switch (field("mathjs", string, j)) { switch (field("mathjs", string, j)) {
| "FunctionNode" => | "FunctionNode" =>
let args = j |> field("args", array(run)); let args = j |> field("args", array(run));
Some( let name = j |> optional(field("fn", field("name", string)));
Fn({ name |> E.O.fmap(name => Fn({name, args: args |> E.A.O.concatSomes}));
name: j |> field("fn", field("name", string)),
args: args |> E.A.O.concatSomes,
}),
);
| "OperatorNode" => | "OperatorNode" =>
let args = j |> field("args", array(run)); let args = j |> field("args", array(run));
Some( Some(
@ -240,6 +236,7 @@ module MathAdtToDistDst = {
args: array(result(ExpressionTypes.ExpressionTree.node, string)), args: array(result(ExpressionTypes.ExpressionTree.node, string)),
) => { ) => {
let toOkAlgebraic = r => Ok(`AlgebraicCombination(r)); let toOkAlgebraic = r => Ok(`AlgebraicCombination(r));
let toOkPointwise = r => Ok(`PointwiseCombination(r));
let toOkTruncate = r => Ok(`Truncate(r)); let toOkTruncate = r => Ok(`Truncate(r));
let toOkFloatFromDist = r => Ok(`FloatFromDist(r)); let toOkFloatFromDist = r => Ok(`FloatFromDist(r));
switch (name, args) { switch (name, args) {
@ -249,6 +246,11 @@ module MathAdtToDistDst = {
| ("subtract", _) => Error("Subtraction needs two operands") | ("subtract", _) => Error("Subtraction needs two operands")
| ("multiply", [|Ok(l), Ok(r)|]) => toOkAlgebraic((`Multiply, l, r)) | ("multiply", [|Ok(l), Ok(r)|]) => toOkAlgebraic((`Multiply, l, r))
| ("multiply", _) => Error("Multiplication needs two operands") | ("multiply", _) => Error("Multiplication needs two operands")
| ("dotMultiply", [|Ok(l), Ok(r)|]) => toOkPointwise((`Multiply, l, r))
| ("dotMultiply", _) =>
Error("Dotwise multiplication needs two operands")
| ("rightLogShift", [|Ok(l), Ok(r)|]) => toOkPointwise((`Add, l, r))
| ("rightLogShift", _) => Error("Dotwise addition needs two operands")
| ("divide", [|Ok(l), Ok(r)|]) => toOkAlgebraic((`Divide, l, r)) | ("divide", [|Ok(l), Ok(r)|]) => toOkAlgebraic((`Divide, l, r))
| ("divide", _) => Error("Division needs two operands") | ("divide", _) => Error("Division needs two operands")
| ("pow", _) => Error("Exponentiation is not yet supported.") | ("pow", _) => Error("Exponentiation is not yet supported.")
@ -324,6 +326,8 @@ module MathAdtToDistDst = {
| "add" | "add"
| "subtract" | "subtract"
| "multiply" | "multiply"
| "dotMultiply"
| "rightLogShift"
| "divide" | "divide"
| "pow" | "pow"
| "leftTruncate" | "leftTruncate"
@ -358,6 +362,13 @@ module MathAdtToDistDst = {
r |> MathAdtCleaner.run |> topLevel; r |> MathAdtCleaner.run |> topLevel;
}; };
/* The MathJs parser doesn't support '.+' syntax, but we want it because it
would make sense with '.*'. Our workaround is to change this to >>>, which is
logShift in mathJS. We don't expect to use logShift anytime soon, so this tradeoff
seems fine.
*/
let pointwiseToRightLogShift = Js.String.replaceByRe([%re "/\.\+/g"], ">>>");
let fromString = str => { let fromString = str => {
/* We feed the user-typed string into Mathjs.parseMath, /* We feed the user-typed string into Mathjs.parseMath,
which returns a JSON with (hopefully) a single-element array. which returns a JSON with (hopefully) a single-element array.
@ -367,7 +378,7 @@ let fromString = str => {
The function MathJsonToMathJsAdt then recursively unpacks this JSON into a typed data structure we can use. The function MathJsonToMathJsAdt then recursively unpacks this JSON into a typed data structure we can use.
Inside of this function, MathAdtToDistDst is called whenever a distribution function is encountered. Inside of this function, MathAdtToDistDst is called whenever a distribution function is encountered.
*/ */
let mathJsToJson = Mathjs.parseMath(str); let mathJsToJson = str |> pointwiseToRightLogShift |> Mathjs.parseMath;
let mathJsParse = let mathJsParse =
E.R.bind(mathJsToJson, r => { E.R.bind(mathJsToJson, r => {
switch (MathJsonToMathJsAdt.run(r)) { switch (MathJsonToMathJsAdt.run(r)) {