squiggle/src/distPlus/expressionTree/MathJsParser.re

382 lines
12 KiB
ReasonML
Raw Normal View History

2020-03-24 17:48:46 +00:00
module MathJsonToMathJsAdt = {
  /* Typed mirror of the MathJs JSON node kinds this parser understands. */
  type arg =
    | Symbol(string)
    | Value(float)
    | Fn(fn)
    | Array(array(arg))
    | Blocks(array(arg))
    | Object(Js.Dict.t(arg))
    | Assignment(arg, arg)
    | FunctionAssignment(fnAssignment)
  and fn = {
    name: string,
    args: array(arg),
  }
  and fnAssignment = {
    name: string,
    args: array(string),
    expression: arg,
  };

  /* Decodes one MathJs JSON node into an `arg`, dispatching on the node's
     "mathjs" field. Returns None when the node kind is not handled (it is then
     logged with Js.log3) or when a required child decodes to None. Children of
     argument lists, arrays, objects and blocks that decode to None are
     silently dropped. */
  let rec run = (j: Js.Json.t) => {
    /* Decodes the node array stored under `key`, keeping only parsed children. */
    let children = key =>
      j |> Json.Decode.field(key, Json.Decode.array(run)) |> E.A.O.concatSomes;
    Json.Decode.(
      switch (j |> field("mathjs", string)) {
      | "FunctionNode" =>
        /* Arguments are decoded before the name, as the name lookup is optional. */
        let args = children("args");
        switch (j |> optional(field("fn", field("name", string)))) {
        | Some(name) => Some(Fn({name, args}))
        | None => None
        };
      | "OperatorNode" =>
        let args = children("args");
        Some(Fn({name: j |> field("fn", string), args}));
      | "ConstantNode" =>
        switch (j |> optional(field("value", float))) {
        | Some(v) => Some(Value(v))
        | None => None
        }
      | "ParenthesisNode" => field("content", run, j)
      | "ObjectNode" =>
        let parsedEntries =
          j
          |> field("properties", dict(run))
          |> Js.Dict.entries
          |> E.A.fmap(((key, value)) =>
               switch (value) {
               | Some(v) => Some((key, v))
               | None => None
               }
             )
          |> E.A.O.concatSomes;
        Some(Object(Js.Dict.fromArray(parsedEntries)));
      | "ArrayNode" => Some(Array(children("items")))
      | "SymbolNode" => Some(Symbol(j |> field("name", string)))
      | "AssignmentNode" =>
        /* Both sides are decoded up front; the assignment exists only if both parsed. */
        let target = field("object", run, j);
        let assigned = field("value", run, j);
        target
        |> E.O.bind(_, o => assigned |> E.O.fmap(v => Assignment(o, v)));
      | "BlockNode" =>
        let parsedBlocks =
          j
          |> field("blocks", array(block => block |> field("node", run)))
          |> E.A.O.concatSomes;
        Some(Blocks(parsedBlocks));
      | "FunctionAssignmentNode" =>
        let name = field("name", string, j);
        let args = field("params", array(field("name", string)), j);
        switch (field("expr", run, j)) {
        | Some(expression) =>
          Some(FunctionAssignment({name, args, expression}))
        | None => None
        };
      | n =>
        Js.log3("Couldn't parse mathjs node", j, n);
        None;
      }
    );
  };
};
2020-03-24 17:48:46 +00:00
module MathAdtToDistDst = {
open MathJsonToMathJsAdt;
2020-07-31 10:43:56 +00:00
/* Wraps a bare symbol name as a successful `Symbol node. */
let handleSymbol = sym => Ok(`Symbol(sym));
2020-03-24 17:48:46 +00:00
module MathAdtCleaner = {
  /* Scales `f` by the magnitude suffix `s`: K/k = 1e3, M/m = 1e6, B/b = 1e9,
     T/t = 1e12. Returns None when `s` is not one of these suffixes. */
  let transformWithSymbol = (f: float, s: string) =>
    switch (s) {
    | "K"
    | "k" => Some(f *. 1000.)
    | "M"
    | "m" => Some(f *. 1000000.)
    | "B"
    | "b" => Some(f *. 1000000000.)
    | "T"
    | "t" => Some(f *. 1000000000000.)
    | _ => None
    };

  /* Applies the local rewrites to a single node whose children have already
     been cleaned:
     - multiply(Value, Symbol) with a known magnitude suffix becomes a Value;
     - unaryMinus(Value) becomes the negated Value.
     Any other node is returned unchanged. */
  let simplify =
    fun
    | Fn({name: "multiply", args: [|Value(f), Symbol(s)|]}) as unchanged =>
      transformWithSymbol(f, s)
      |> E.O.fmap(r => Value(r))
      |> E.O.default(unchanged)
    | Fn({name: "unaryMinus", args: [|Value(f)|]}) => Value((-1.0) *. f)
    | other => other;

  /* Recursively folds numeric literals in the tree.
     Children are cleaned BEFORE the rewrite rules are tried on their parent,
     so nodes whose arguments only become plain values after cleaning are
     simplified too, e.g. unaryMinus(multiply(3, k)) and
     multiply(unaryMinus(3), k) both end up as Value(-3000.).
     The left-hand side of an Assignment and the whole FunctionAssignment
     record are passed through untouched. */
  let rec run =
    fun
    | Fn({name, args}) => simplify(Fn({name, args: args |> E.A.fmap(run)}))
    | Array(args) => Array(args |> E.A.fmap(run))
    | Symbol(s) => Symbol(s)
    | Value(v) => Value(v)
    | Blocks(args) => Blocks(args |> E.A.fmap(run))
    | Assignment(a, b) => Assignment(a, run(b))
    | FunctionAssignment(a) => FunctionAssignment(a)
    | Object(v) =>
      Object(
        v
        |> Js.Dict.entries
        |> E.A.fmap(((key, value)) => (key, run(value)))
        |> Js.Dict.fromArray,
      );
};
2020-03-24 00:04:48 +00:00
2020-07-30 12:54:35 +00:00
/* Builds the node for a `lognormal(...)` call.
   - With a single object argument, the distribution is described either by
     `mean` and `stdev` (checked first) or by `mu` and `sigma`; each entry is
     parsed with `nodeParser`.
   - Otherwise the positional arguments produced by `parseArgs()` are passed
     straight to a "lognormal" function call. */
let lognormal = (args, parseArgs, nodeParser) =>
  switch (args) {
  | [|Object(fields)|] =>
    /* Looks a key up in the object literal and parses its value. */
    let lookup = key =>
      switch (Js.Dict.get(fields, key)) {
      | Some(node) => nodeParser(node)
      | None => Error("Variable was empty")
      };
    switch (lookup("mean"), lookup("stdev")) {
    | (Ok(mean), Ok(stdev)) =>
      Ok(`FunctionCall(("lognormalFromMeanAndStdDev", [|mean, stdev|])))
    | _ =>
      switch (lookup("mu"), lookup("sigma")) {
      | (Ok(mu), Ok(sigma)) =>
        Ok(`FunctionCall(("lognormal", [|mu, sigma|])))
      | _ =>
        Error(
          "Lognormal distribution needs either mean and stdev or mu and sigma",
        )
      }
    };
  | _ =>
    parseArgs()
    |> E.R.fmap((parsed: array(ExpressionTypes.ExpressionTree.node)) =>
         `FunctionCall(("lognormal", parsed))
       )
  };
2020-03-26 16:01:52 +00:00
2020-03-25 15:12:39 +00:00
/* Combines already-parsed components into a normalized weighted mixture.
   - Returns the first parse error found in `args`, if any.
   - Fails when there are no components.
   - The i-th component is vertically scaled by the i-th weight (1.0 when the
     weight is missing); the scaled components are summed pointwise and the
     sum is wrapped in `Normalize. */
let multiModal =
    (
      args: array(result(ExpressionTypes.ExpressionTree.node, string)),
      weights: option(array(float)),
    ) => {
  let weightList = E.O.default([||], weights);
  let parsed = args |> E.A.fmap(E.R.toOption) |> E.A.O.concatSomes;
  switch (Belt.Array.getBy(args, Belt.Result.isError)) {
  | Some(Error(message)) => Error(message)
  | None when E.A.length(parsed) == 0 =>
    Error("Multimodals need at least one input")
  | _ =>
    let scaled =
      parsed
      |> E.A.fmapi((i, dist) => {
           let weight = E.A.get(weightList, i) |> E.O.default(1.0);
           `VerticalScaling((`Multiply, dist, `SymbolicDist(`Float(weight))));
         });
    /* `scaled` is non-empty here, so taking element 0 is safe. */
    let first = E.A.unsafe_get(scaled, 0);
    let rest = Js.Array.sliceFrom(1, scaled);
    let summed =
      E.A.fold_left(
        (acc, next) => `PointwiseCombination((`Add, acc, next)),
        first,
        rest,
      );
    Ok(`Normalize(summed));
  };
};
2020-07-01 22:05:35 +00:00
/* Turns a built-in operation `name` applied to already-parsed `args` into an
   expression-tree node.
   - add / subtract / multiply / divide      -> `AlgebraicCombination
   - dotMultiply / rightLogShift             -> `PointwiseCombination
     (rightLogShift stands for pointwise addition)
   - leftTruncate / rightTruncate / truncate -> `Truncate (cutoffs must be
     literal floats)
   - pdf / cdf / inv / mean / sample         -> `FloatFromDist
   An Error in `args` is propagated unchanged. Every known operation reports
   its own error when called with the wrong argument shape; only unknown names
   reach the generic fallback. */
let operationParser =
    (
      name: string,
      args: result(array(ExpressionTypes.ExpressionTree.node), string),
    ) => {
  let toOkAlgebraic = r => Ok(`AlgebraicCombination(r));
  let toOkPointwise = r => Ok(`PointwiseCombination(r));
  let toOkTruncate = r => Ok(`Truncate(r));
  let toOkFloatFromDist = r => Ok(`FloatFromDist(r));
  args
  |> E.R.bind(_, args => {
       switch (name, args) {
       | ("add", [|l, r|]) => toOkAlgebraic((`Add, l, r))
       | ("add", _) => Error("Addition needs two operands")
       | ("subtract", [|l, r|]) => toOkAlgebraic((`Subtract, l, r))
       | ("subtract", _) => Error("Subtraction needs two operands")
       | ("multiply", [|l, r|]) => toOkAlgebraic((`Multiply, l, r))
       | ("multiply", _) => Error("Multiplication needs two operands")
       | ("dotMultiply", [|l, r|]) => toOkPointwise((`Multiply, l, r))
       | ("dotMultiply", _) =>
         Error("Dotwise multiplication needs two operands")
       | ("rightLogShift", [|l, r|]) => toOkPointwise((`Add, l, r))
       | ("rightLogShift", _) =>
         Error("Dotwise addition needs two operands")
       | ("divide", [|l, r|]) => toOkAlgebraic((`Divide, l, r))
       | ("divide", _) => Error("Division needs two operands")
       | ("pow", _) => Error("Exponentiation is not yet supported.")
       | ("leftTruncate", [|d, `SymbolicDist(`Float(lc))|]) =>
         toOkTruncate((Some(lc), None, d))
       | ("leftTruncate", _) =>
         Error(
           "leftTruncate needs two arguments: the expression and the cutoff",
         )
       | ("rightTruncate", [|d, `SymbolicDist(`Float(rc))|]) =>
         toOkTruncate((None, Some(rc), d))
       | ("rightTruncate", _) =>
         Error(
           "rightTruncate needs two arguments: the expression and the cutoff",
         )
       | (
           "truncate",
           [|d, `SymbolicDist(`Float(lc)), `SymbolicDist(`Float(rc))|],
         ) =>
         toOkTruncate((Some(lc), Some(rc), d))
       | ("truncate", _) =>
         Error(
           "truncate needs three arguments: the expression and both cutoffs",
         )
       | ("pdf", [|d, `SymbolicDist(`Float(v))|]) =>
         toOkFloatFromDist((`Pdf(v), d))
       | ("pdf", _) =>
         Error("pdf needs two arguments: the expression and the value")
       | ("cdf", [|d, `SymbolicDist(`Float(v))|]) =>
         toOkFloatFromDist((`Cdf(v), d))
       | ("cdf", _) =>
         Error("cdf needs two arguments: the expression and the value")
       | ("inv", [|d, `SymbolicDist(`Float(v))|]) =>
         toOkFloatFromDist((`Inv(v), d))
       | ("inv", _) =>
         Error("inv needs two arguments: the expression and the value")
       | ("mean", [|d|]) => toOkFloatFromDist((`Mean, d))
       | ("mean", _) => Error("mean needs one argument: the expression")
       | ("sample", [|d|]) => toOkFloatFromDist((`Sample, d))
       | ("sample", _) => Error("sample needs one argument: the expression")
       | _ => Error("This type not currently supported")
       }
     });
};
2020-07-01 22:05:35 +00:00
2020-07-01 22:47:49 +00:00
/* Dispatches a function-call node by name.
   - Known distribution constructors and any unrecognised name become a plain
     `FunctionCall with the parsed arguments.
   - "lognormal" additionally accepts an object-literal argument.
   - "mm" builds a multimodal; a trailing array argument is read as weights.
   - Arithmetic, truncation and float-extraction names go to operationParser.
   `nodeParser` is passed in explicitly and is used to parse every argument. */
let functionParser = (nodeParser, name, args) => {
  /* Parses every argument, yielding the first error or all parsed nodes. */
  let parseArgs = () =>
    args |> E.A.fmap(nodeParser) |> E.A.R.firstErrorOrOpen;
  /* Generic call node: the given name applied to the parsed arguments. */
  let asFunctionCall = fnName =>
    parseArgs()
    |> E.R.fmap((parsed: array(ExpressionTypes.ExpressionTree.node)) =>
         `FunctionCall((fnName, parsed))
       );
  switch (name) {
  | "normal"
  | "uniform"
  | "beta"
  | "triangular"
  | "to"
  | "exponential"
  | "cauchy" => asFunctionCall(name)
  | "lognormal" => lognormal(args, parseArgs, nodeParser)
  | "mm" =>
    /* A trailing array literal supplies the weights; entries that are not
       plain values are dropped from it. */
    let weights =
      switch (E.A.last(args)) {
      | Some(Array(values)) =>
        Some(
          values
          |> E.A.fmap(
               fun
               | Value(r) => Some(r)
               | _ => None,
             )
          |> E.A.O.concatSomes,
        )
      | _ => None
      };
    /* When weights are present the last argument is not a component. */
    let componentArgs =
      switch (weights) {
      | Some(_) =>
        Belt.Array.slice(args, ~offset=0, ~len=E.A.length(args) - 1)
      | None => args
      };
    multiModal(componentArgs |> E.A.fmap(nodeParser), weights);
  | "add"
  | "subtract"
  | "multiply"
  | "dotMultiply"
  | "rightLogShift"
  | "divide"
  | "pow"
  | "leftTruncate"
  | "rightTruncate"
  | "truncate"
  | "mean"
  | "inv"
  | "sample"
  | "cdf"
  | "pdf" => operationParser(name, parseArgs())
  | other => asFunctionCall(other)
  };
};
2020-03-24 17:48:46 +00:00
/* Converts a cleaned MathJs ADT node into an expression-tree node.
   Only values, symbols and function calls are accepted at this level; every
   other node kind is reported as unsupported. */
let rec nodeParser =
        (node: MathJsonToMathJsAdt.arg)
        : result(ExpressionTypes.ExpressionTree.node, string) =>
  switch (node) {
  | Value(f) => Ok(`SymbolicDist(`Float(f)))
  | Symbol(sym) => Ok(`Symbol(sym))
  | Fn({name, args}) => functionParser(nodeParser, name, args)
  | Array(_)
  | Blocks(_)
  | Object(_)
  | Assignment(_, _)
  | FunctionAssignment(_) => Error("This type not currently supported")
  };
2020-07-01 22:47:49 +00:00
2020-07-31 09:41:34 +00:00
// | FunctionAssignment({name, args, expression}) => {
// let evaluatedExpression = run(expression);
// `Function(_ => Ok(evaluatedExpression));
// }
2020-07-31 10:43:56 +00:00
/* Converts a top-level ADT node into a program (an array of statements).
   - Values, function calls and symbols become a single `Expression.
   - `x = expr` becomes an `Assignment; the left-hand side must be a symbol.
   - `f(a, b) = expr` becomes an `Assignment of a `Function.
   - Blocks are converted statement by statement and concatenated; the first
     error wins.
   - Arrays and objects are rejected at the top level. */
let rec topLevel = (r): result(ExpressionTypes.Program.program, string) => {
  /* Wraps a successfully parsed node as a one-statement program. */
  let asExpression = parsed =>
    parsed |> E.R.fmap(node => [|`Expression(node)|]);
  switch (r) {
  | FunctionAssignment({name, args, expression}) =>
    nodeParser(expression)
    |> E.R.fmap(body => [|`Assignment((name, `Function((args, body))))|])
  | (Value(_) | Fn(_)) as node => asExpression(nodeParser(node))
  | Array(_) => Error("Array not valid as top level")
  | Symbol(s) => asExpression(handleSymbol(s))
  | Object(_) => Error("Object not valid as top level")
  | Assignment(Symbol(symbol), value) =>
    nodeParser(value)
    |> E.R.fmap(parsed => [|`Assignment((symbol, parsed))|])
  | Assignment(_, _) => Error("Symbol not a string")
  | Blocks(blocks) =>
    blocks
    |> E.A.fmap(topLevel)
    |> E.A.R.firstErrorOrOpen
    |> E.R.fmap(E.A.concatMany)
  };
};
2020-03-24 00:04:48 +00:00
2020-07-31 10:43:56 +00:00
/* Entry point of this module: cleans the ADT with MathAdtCleaner, then
   converts it into a program. */
let run = (r): result(ExpressionTypes.Program.program, string) =>
  topLevel(MathAdtCleaner.run(r));
2020-03-24 17:48:46 +00:00
};
2020-07-19 13:21:47 +00:00
/* MathJs has no '.+' operator, but we want one as the counterpart of '.*'.
   As a workaround every '.+' in the input is rewritten to '>>>', which MathJs
   parses as its right-log-shift operator; the parser above then treats
   "rightLogShift" as pointwise addition. We do not expect to need the real
   log shift any time soon, so the trade-off seems fine. */
let pointwiseToRightLogShift = str =>
  str |> Js.String.replaceByRe([%re "/\.\+/g"], ">>>");
2020-07-31 10:43:56 +00:00
/* Parses a user-typed string into a program, in four chained steps:
   1. rewrite '.+' so that MathJs can parse it (pointwiseToRightLogShift);
   2. hand the string to Mathjs.parseMath, which yields the JSON tree of the
      expression or an error;
   3. decode that JSON into the typed ADT with MathJsonToMathJsAdt.run
      (a None result is reported as "MathJsParse Error");
   4. convert the ADT into a program with MathAdtToDistDst.run.
   The first failing step determines the returned Error. */
let fromString2 = str => {
  let jsonToAdt = json =>
    switch (MathJsonToMathJsAdt.run(json)) {
    | Some(adt) => Ok(adt)
    | None => Error("MathJsParse Error")
    };
  str
  |> pointwiseToRightLogShift
  |> Mathjs.parseMath
  |> E.R.bind(_, jsonToAdt)
  |> E.R.bind(_, MathAdtToDistDst.run);
};
2020-07-31 10:43:56 +00:00
/* Public entry point; currently a thin alias of fromString2. */
let fromString = str => str |> fromString2;