Added .+ and .* support to parser
This commit is contained in:
parent
b2c6ef7e5e
commit
6bf350039b
|
@ -78,7 +78,7 @@ let combinePointwise =
|
||||||
make(
|
make(
|
||||||
~integralSumCache=combinedIntegralSum,
|
~integralSumCache=combinedIntegralSum,
|
||||||
XYShape.PointwiseCombination.combine(
|
XYShape.PointwiseCombination.combine(
|
||||||
(+.),
|
fn,
|
||||||
interpolator,
|
interpolator,
|
||||||
t1.xyShape,
|
t1.xyShape,
|
||||||
t2.xyShape,
|
t2.xyShape,
|
||||||
|
|
|
@ -79,7 +79,10 @@ module VerticalScaling = {
|
||||||
|
|
||||||
module PointwiseCombination = {
|
module PointwiseCombination = {
|
||||||
let pointwiseAdd = (evaluationParams: evaluationParams, t1: t, t2: t) => {
|
let pointwiseAdd = (evaluationParams: evaluationParams, t1: t, t2: t) => {
|
||||||
switch (Render.render(evaluationParams, t1), Render.render(evaluationParams, t2)) {
|
switch (
|
||||||
|
Render.render(evaluationParams, t1),
|
||||||
|
Render.render(evaluationParams, t2),
|
||||||
|
) {
|
||||||
| (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) =>
|
| (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) =>
|
||||||
Ok(
|
Ok(
|
||||||
`RenderedDist(
|
`RenderedDist(
|
||||||
|
@ -110,9 +113,16 @@ module PointwiseCombination = {
|
||||||
let pointwiseMultiply = (evaluationParams: evaluationParams, t1: t, t2: t) => {
|
let pointwiseMultiply = (evaluationParams: evaluationParams, t1: t, t2: t) => {
|
||||||
// TODO: construct a function that we can easily sample from, to construct
|
// TODO: construct a function that we can easily sample from, to construct
|
||||||
// a RenderedDist. Use the xMin and xMax of the rendered shapes to tell the sampling function where to look.
|
// a RenderedDist. Use the xMin and xMax of the rendered shapes to tell the sampling function where to look.
|
||||||
Error(
|
switch (
|
||||||
"Pointwise multiplication not yet supported.",
|
Render.render(evaluationParams, t1),
|
||||||
);
|
Render.render(evaluationParams, t2),
|
||||||
|
) {
|
||||||
|
| (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) =>
|
||||||
|
Ok(`RenderedDist(Shape.combinePointwise(( *. ), rs1, rs2)))
|
||||||
|
| (Error(e1), _) => Error(e1)
|
||||||
|
| (_, Error(e2)) => Error(e2)
|
||||||
|
| _ => Error("Pointwise combination: rendering failed.")
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
let operationToLeaf =
|
let operationToLeaf =
|
||||||
|
@ -134,7 +144,9 @@ module Truncate = {
|
||||||
switch (leftCutoff, rightCutoff, t) {
|
switch (leftCutoff, rightCutoff, t) {
|
||||||
| (None, None, t) => `Solution(t)
|
| (None, None, t) => `Solution(t)
|
||||||
| (Some(lc), Some(rc), t) when lc > rc =>
|
| (Some(lc), Some(rc), t) when lc > rc =>
|
||||||
`Error("Left truncation bound must be smaller than right truncation bound.")
|
`Error(
|
||||||
|
"Left truncation bound must be smaller than right truncation bound.",
|
||||||
|
)
|
||||||
| (lc, rc, `SymbolicDist(`Uniform(u))) =>
|
| (lc, rc, `SymbolicDist(`Uniform(u))) =>
|
||||||
`Solution(
|
`Solution(
|
||||||
`SymbolicDist(`Uniform(SymbolicDist.Uniform.truncate(lc, rc, u))),
|
`SymbolicDist(`Uniform(SymbolicDist.Uniform.truncate(lc, rc, u))),
|
||||||
|
|
|
@ -15,12 +15,8 @@ module MathJsonToMathJsAdt = {
|
||||||
switch (field("mathjs", string, j)) {
|
switch (field("mathjs", string, j)) {
|
||||||
| "FunctionNode" =>
|
| "FunctionNode" =>
|
||||||
let args = j |> field("args", array(run));
|
let args = j |> field("args", array(run));
|
||||||
Some(
|
let name = j |> optional(field("fn", field("name", string)));
|
||||||
Fn({
|
name |> E.O.fmap(name => Fn({name, args: args |> E.A.O.concatSomes}));
|
||||||
name: j |> field("fn", field("name", string)),
|
|
||||||
args: args |> E.A.O.concatSomes,
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
| "OperatorNode" =>
|
| "OperatorNode" =>
|
||||||
let args = j |> field("args", array(run));
|
let args = j |> field("args", array(run));
|
||||||
Some(
|
Some(
|
||||||
|
@ -240,6 +236,7 @@ module MathAdtToDistDst = {
|
||||||
args: array(result(ExpressionTypes.ExpressionTree.node, string)),
|
args: array(result(ExpressionTypes.ExpressionTree.node, string)),
|
||||||
) => {
|
) => {
|
||||||
let toOkAlgebraic = r => Ok(`AlgebraicCombination(r));
|
let toOkAlgebraic = r => Ok(`AlgebraicCombination(r));
|
||||||
|
let toOkPointwise = r => Ok(`PointwiseCombination(r));
|
||||||
let toOkTruncate = r => Ok(`Truncate(r));
|
let toOkTruncate = r => Ok(`Truncate(r));
|
||||||
let toOkFloatFromDist = r => Ok(`FloatFromDist(r));
|
let toOkFloatFromDist = r => Ok(`FloatFromDist(r));
|
||||||
switch (name, args) {
|
switch (name, args) {
|
||||||
|
@ -249,6 +246,11 @@ module MathAdtToDistDst = {
|
||||||
| ("subtract", _) => Error("Subtraction needs two operands")
|
| ("subtract", _) => Error("Subtraction needs two operands")
|
||||||
| ("multiply", [|Ok(l), Ok(r)|]) => toOkAlgebraic((`Multiply, l, r))
|
| ("multiply", [|Ok(l), Ok(r)|]) => toOkAlgebraic((`Multiply, l, r))
|
||||||
| ("multiply", _) => Error("Multiplication needs two operands")
|
| ("multiply", _) => Error("Multiplication needs two operands")
|
||||||
|
| ("dotMultiply", [|Ok(l), Ok(r)|]) => toOkPointwise((`Multiply, l, r))
|
||||||
|
| ("dotMultiply", _) =>
|
||||||
|
Error("Dotwise multiplication needs two operands")
|
||||||
|
| ("rightLogShift", [|Ok(l), Ok(r)|]) => toOkPointwise((`Add, l, r))
|
||||||
|
| ("rightLogShift", _) => Error("Dotwise addition needs two operands")
|
||||||
| ("divide", [|Ok(l), Ok(r)|]) => toOkAlgebraic((`Divide, l, r))
|
| ("divide", [|Ok(l), Ok(r)|]) => toOkAlgebraic((`Divide, l, r))
|
||||||
| ("divide", _) => Error("Division needs two operands")
|
| ("divide", _) => Error("Division needs two operands")
|
||||||
| ("pow", _) => Error("Exponentiation is not yet supported.")
|
| ("pow", _) => Error("Exponentiation is not yet supported.")
|
||||||
|
@ -324,6 +326,8 @@ module MathAdtToDistDst = {
|
||||||
| "add"
|
| "add"
|
||||||
| "subtract"
|
| "subtract"
|
||||||
| "multiply"
|
| "multiply"
|
||||||
|
| "dotMultiply"
|
||||||
|
| "rightLogShift"
|
||||||
| "divide"
|
| "divide"
|
||||||
| "pow"
|
| "pow"
|
||||||
| "leftTruncate"
|
| "leftTruncate"
|
||||||
|
@ -358,6 +362,13 @@ module MathAdtToDistDst = {
|
||||||
r |> MathAdtCleaner.run |> topLevel;
|
r |> MathAdtCleaner.run |> topLevel;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/* The MathJs parser doesn't support '.+' syntax, but we want it because it
|
||||||
|
would make sense with '.*'. Our workaround is to change this to >>>, which is
|
||||||
|
logShift in mathJS. We don't expect to use logShift anytime soon, so this tradeoff
|
||||||
|
seems fine.
|
||||||
|
*/
|
||||||
|
let pointwiseToRightLogShift = Js.String.replaceByRe([%re "/\.\+/g"], ">>>");
|
||||||
|
|
||||||
let fromString = str => {
|
let fromString = str => {
|
||||||
/* We feed the user-typed string into Mathjs.parseMath,
|
/* We feed the user-typed string into Mathjs.parseMath,
|
||||||
which returns a JSON with (hopefully) a single-element array.
|
which returns a JSON with (hopefully) a single-element array.
|
||||||
|
@ -367,7 +378,7 @@ let fromString = str => {
|
||||||
The function MathJsonToMathJsAdt then recursively unpacks this JSON into a typed data structure we can use.
|
The function MathJsonToMathJsAdt then recursively unpacks this JSON into a typed data structure we can use.
|
||||||
Inside of this function, MathAdtToDistDst is called whenever a distribution function is encountered.
|
Inside of this function, MathAdtToDistDst is called whenever a distribution function is encountered.
|
||||||
*/
|
*/
|
||||||
let mathJsToJson = Mathjs.parseMath(str);
|
let mathJsToJson = str |> pointwiseToRightLogShift |> Mathjs.parseMath;
|
||||||
let mathJsParse =
|
let mathJsParse =
|
||||||
E.R.bind(mathJsToJson, r => {
|
E.R.bind(mathJsToJson, r => {
|
||||||
switch (MathJsonToMathJsAdt.run(r)) {
|
switch (MathJsonToMathJsAdt.run(r)) {
|
||||||
|
|
Loading…
Reference in New Issue
Block a user