squiggle/src/distPlus/symbolic/MathJsParser.re

353 lines
12 KiB
ReasonML
Raw Normal View History

2020-04-05 06:36:14 +00:00
// TODO: rename to SymbolicParser
2020-03-24 17:48:46 +00:00
module MathJsonToMathJsAdt = {
  /* Typed mirror of the subset of mathjs's JSON node tree that the
     distribution parser understands. */
  type arg =
    | Symbol(string)
    | Value(float)
    | Fn(fn)
    | Array(array(arg))
    | Object(Js.Dict.t(arg))
  and fn = {
    name: string,
    args: array(arg),
  };

  /* Returns the unwrapped array when every element is Some, and None as
     soon as any element is None. Dropping the failed elements instead would
     silently shift argument positions (e.g. turn a two-argument call into a
     one-argument call, or shrink a mixture). */
  let allOrNothing = (items: array(option('a))): option(array('a)) => {
    let parsed = items |> E.A.O.concatSomes;
    E.A.length(parsed) == E.A.length(items) ? Some(parsed) : None;
  };

  /* Converts one mathjs JSON node, dispatched on its "mathjs" tag, into an
     `arg`. Node types not handled here are logged and yield None; a node
     whose children do not all convert also yields None. */
  let rec run = (j: Js.Json.t): option(arg) =>
    Json.Decode.(
      switch (field("mathjs", string, j)) {
      | "FunctionNode" =>
        let args = j |> field("args", array(run)) |> allOrNothing;
        /* For a FunctionNode, "fn" is itself a node; its "name" is the callee. */
        let name = j |> field("fn", field("name", string));
        args |> E.O.fmap(args => Fn({name, args}));
      | "OperatorNode" =>
        let args = j |> field("args", array(run)) |> allOrNothing;
        /* For an OperatorNode, "fn" is the operator's function name directly. */
        let name = j |> field("fn", string);
        args |> E.O.fmap(args => Fn({name, args}));
      | "ConstantNode" =>
        optional(field("value", Json.Decode.float), j)
        |> E.O.fmap(r => Value(r))
      | "ParenthesisNode" => j |> field("content", run)
      | "ObjectNode" =>
        j
        |> field("properties", dict(run))
        |> Js.Dict.entries
        |> E.A.fmap(((key, value)) => value |> E.O.fmap(v => (key, v)))
        |> allOrNothing
        |> E.O.fmap(entries => Object(Js.Dict.fromArray(entries)))
      | "ArrayNode" =>
        field("items", array(run), j)
        |> allOrNothing
        |> E.O.fmap(items => Array(items))
      | "SymbolNode" => Some(Symbol(field("name", string, j)))
      | n =>
        Js.log3("Couldn't parse mathjs node", j, n);
        None;
      }
    );
};
2020-03-24 17:48:46 +00:00
module MathAdtToDistDst = {
  open MathJsonToMathJsAdt;

  /* Simplifies an `arg` tree before it is interpreted: folds magnitude
     suffixes (`5k` -> 5000.) and negated constants into plain values. */
  module MathAdtCleaner = {
    /* Multiplier for a magnitude-suffix symbol (K/M/B/T, either case), or
       None when `s` is not one of them. */
    let symbolMultiplier = (s: string): option(float) =>
      switch (s) {
      | "K"
      | "k" => Some(1000.)
      | "M"
      | "m" => Some(1000000.)
      | "B"
      | "b" => Some(1000000000.)
      | "T"
      | "t" => Some(1000000000000.)
      | _ => None
      };

    /* Scales `f` by the suffix `s`; an unrecognised suffix leaves `f`
       unchanged. */
    let transformWithSymbol = (f: float, s: string): float =>
      switch (symbolMultiplier(s)) {
      | Some(multiplier) => f *. multiplier
      | None => f
      };

    /* Rewrites the tree bottom-up, so that constants produced by cleaning a
       child (e.g. the `5k` inside `-(5k)`) can still be folded by the
       parent. A multiplication by a symbol that is not a magnitude suffix
       is kept as a multiplication instead of silently dropping the symbol. */
    let rec run = (a: arg): arg =>
      switch (a) {
      | Fn({name, args}) =>
        let args = args |> E.A.fmap(run);
        switch (name, args) {
        | ("multiply", [|Value(f), Symbol(s)|]) =>
          switch (symbolMultiplier(s)) {
          | Some(multiplier) => Value(f *. multiplier)
          | None => Fn({name, args})
          }
        | ("unaryMinus", [|Value(f)|]) => Value((-1.0) *. f)
        | _ => Fn({name, args})
        };
      | Array(items) => Array(items |> E.A.fmap(run))
      | Symbol(s) => Symbol(s)
      | Value(v) => Value(v)
      | Object(dict) =>
        Object(
          dict
          |> Js.Dict.entries
          |> E.A.fmap(((key, value)) => (key, run(value)))
          |> Js.Dict.fromArray,
        )
      };
  };

  /* normal(mean, stdev) */
  let normal: array(arg) => result(SymbolicDist.distTree, string) =
    fun
    | [|Value(mean), Value(stdev)|] => Ok(`Simple(`Normal({mean, stdev})))
    | _ => Error("Wrong number of variables in normal distribution");

  /* lognormal(mu, sigma), or a single object argument with either
     mean/stdev or mu/sigma fields; mean/stdev wins when both pairs are
     present. */
  let lognormal: array(arg) => result(SymbolicDist.distTree, string) =
    fun
    | [|Value(mu), Value(sigma)|] => Ok(`Simple(`Lognormal({mu, sigma})))
    | [|Object(o)|] => {
        let g = Js.Dict.get(o);
        switch (g("mean"), g("stdev"), g("mu"), g("sigma")) {
        | (Some(Value(mean)), Some(Value(stdev)), _, _) =>
          Ok(`Simple(SymbolicDist.Lognormal.fromMeanAndStdev(mean, stdev)))
        | (_, _, Some(Value(mu)), Some(Value(sigma))) =>
          Ok(`Simple(`Lognormal({mu, sigma})))
        | _ => Error("Lognormal distribution would need mean and stdev")
        };
      }
    | _ => Error("Wrong number of variables in lognormal distribution");

  /* to(low, high): a distribution fitted to the 90% interval [low, high].
     A Normal is used when low <= 0 (a Lognormal cannot reach it), otherwise
     a Lognormal. */
  let to_: array(arg) => result(SymbolicDist.distTree, string) =
    fun
    | [|Value(low), Value(high)|] when low <= 0.0 && low < high =>
      Ok(`Simple(SymbolicDist.Normal.from90PercentCI(low, high)))
    | [|Value(low), Value(high)|] when low < high =>
      Ok(`Simple(SymbolicDist.Lognormal.from90PercentCI(low, high)))
    | [|Value(_), Value(_)|] =>
      Error("Low value must be less than high value.")
    | _ => Error("Wrong number of variables in 'to' distribution");

  /* uniform(low, high) */
  let uniform: array(arg) => result(SymbolicDist.distTree, string) =
    fun
    | [|Value(low), Value(high)|] => Ok(`Simple(`Uniform({low, high})))
    | _ => Error("Wrong number of variables in uniform distribution");

  /* beta(alpha, beta) */
  let beta: array(arg) => result(SymbolicDist.distTree, string) =
    fun
    | [|Value(alpha), Value(beta)|] => Ok(`Simple(`Beta({alpha, beta})))
    | _ => Error("Wrong number of variables in beta distribution");

  /* exponential(rate) */
  let exponential: array(arg) => result(SymbolicDist.distTree, string) =
    fun
    | [|Value(rate)|] => Ok(`Simple(`Exponential({rate: rate})))
    | _ => Error("Wrong number of variables in Exponential distribution");

  /* cauchy(local, scale) */
  let cauchy: array(arg) => result(SymbolicDist.distTree, string) =
    fun
    | [|Value(local), Value(scale)|] =>
      Ok(`Simple(`Cauchy({local, scale})))
    | _ => Error("Wrong number of variables in cauchy distribution");

  /* triangular(low, medium, high) */
  let triangular: array(arg) => result(SymbolicDist.distTree, string) =
    fun
    | [|Value(low), Value(medium), Value(high)|] =>
      Ok(`Simple(`Triangular({low, medium, high})))
    | _ => Error("Wrong number of variables in triangle distribution");

  /* Builds a weighted mixture. Each component is vertically scaled by its
     weight (a missing weight defaults to 1.0); the scaled components are
     summed pointwise and the sum is normalized. The first error in `args`
     (or in the shape filter below) is returned instead, if any. */
  let multiModal =
      (
        args: array(result(SymbolicDist.distTree, string)),
        weights: option(array(float)),
      ) => {
    let weights = weights |> E.O.default([||]);
    /* Only these tree shapes are accepted as mixture components; anything
       else becomes an "Unexpected dist" error. */
    let dists =
      args
      |> E.A.fmap(
           fun
           | Ok(`Simple(d)) => Ok(`Simple(d))
           | Ok(`Combination(t1, t2, op)) => Ok(`Combination(t1, t2, op))
           | Ok(`PointwiseSum(t1, t2)) => Ok(`PointwiseSum(t1, t2))
           | Ok(`PointwiseProduct(t1, t2)) => Ok(`PointwiseProduct(t1, t2))
           | Ok(`Normalize(t)) => Ok(`Normalize(t))
           | Ok(`LeftTruncate(t, x)) => Ok(`LeftTruncate(t, x))
           | Ok(`RightTruncate(t, x)) => Ok(`RightTruncate(t, x))
           | Ok(`Render(t)) => Ok(`Render(t))
           | Error(e) => Error(e)
           | _ => Error("Unexpected dist"),
         );
    let firstWithError = dists |> Belt.Array.getBy(_, Belt.Result.isError);
    let withoutErrors = dists |> E.A.fmap(E.R.toOption) |> E.A.O.concatSomes;
    switch (firstWithError) {
    | Some(Error(e)) => Error(e)
    | None when E.A.length(withoutErrors) == 0 =>
      Error("Multimodals need at least one input")
    | _ =>
      let components =
        withoutErrors
        |> E.A.fmapi((index, t) => {
             let w = weights |> E.A.get(_, index) |> E.O.default(1.0);
             `VerticalScaling(t, `Simple(`Float(w)));
           });
      /* Safe: `components` is non-empty here (checked above). */
      let pointwiseSum =
        components
        |> Js.Array.sliceFrom(1)
        |> E.A.fold_left(
             (acc, x) => `PointwiseSum(acc, x),
             E.A.unsafe_get(components, 0),
           );
      Ok(`Normalize(pointwiseSum));
    };
  };

  /* Treats an array literal as samples and renders them to a continuous
     shape normalized to integrate to 1. Non-numeric entries are ignored. */
  let arrayParser =
      (args: array(arg)): result(SymbolicDist.distTree, string) => {
    let samples =
      args
      |> E.A.fmap(
           fun
           | Value(n) => Some(n)
           | _ => None,
         )
      |> E.A.O.concatSomes;
    let outputs = Samples.T.fromSamples(samples);
    let pdf =
      outputs.shape |> E.O.bind(_, Distributions.Shape.T.toContinuous);
    let shape =
      pdf
      |> E.O.fmap(pdf => {
           let normalized =
             Distributions.Continuous.T.scaleToIntegralSum(
               ~cache=None,
               ~intendedSum=1.0,
               pdf,
             );
           let cdf = Distributions.Continuous.T.integral(~cache=None, normalized);
           SymbolicDist.ContinuousShape.make(normalized, cdf);
         });
    switch (shape) {
    | Some(s) => Ok(`Simple(`ContinuousShape(s)))
    | None => Error("Rendering did not work")
    };
  };

  /* Interprets a cleaned `arg` tree as a distribution expression.
     Distribution constructors and `mm` are dispatched by function name,
     arithmetic operators become `Combination`s, and
     leftTruncate/rightTruncate require a constant cutoff. */
  let rec functionParser = (r: arg): result(SymbolicDist.distTree, string) => {
    /* Parses exactly two operands and hands them to `make`. An error from
       either operand is propagated unchanged rather than being hidden
       behind the arity message; any other arity yields `arityError`. */
    let binary = (args, arityError, make) =>
      switch (args |> E.A.fmap(functionParser)) {
      | [|Ok(left), Ok(right)|] => make(left, right)
      | [|Error(e), _|]
      | [|_, Error(e)|] => Error(e)
      | _ => Error(arityError)
      };
    switch (r) {
    | Fn({name: "normal", args}) => normal(args)
    | Fn({name: "lognormal", args}) => lognormal(args)
    | Fn({name: "uniform", args}) => uniform(args)
    | Fn({name: "beta", args}) => beta(args)
    | Fn({name: "to", args}) => to_(args)
    | Fn({name: "exponential", args}) => exponential(args)
    | Fn({name: "cauchy", args}) => cauchy(args)
    | Fn({name: "triangular", args}) => triangular(args)
    | Value(f) => Ok(`Simple(`Float(f)))
    | Fn({name: "mm", args}) =>
      /* A trailing array literal, if present, holds the component weights;
         its non-numeric entries are ignored. */
      let weights =
        args
        |> E.A.last
        |> E.O.bind(
             _,
             fun
             | Array(values) => Some(values)
             | _ => None,
           )
        |> E.O.fmap(values =>
             values
             |> E.A.fmap(
                  fun
                  | Value(w) => Some(w)
                  | _ => None,
                )
             |> E.A.O.concatSomes
           );
      let possibleDists =
        E.O.isSome(weights)
          ? Belt.Array.slice(args, ~offset=0, ~len=E.A.length(args) - 1)
          : args;
      multiModal(possibleDists |> E.A.fmap(functionParser), weights);
    | Fn({name: "add", args}) =>
      binary(args, "Addition needs two operands", (l, r) =>
        Ok(`Combination(l, r, `AddOperation))
      )
    | Fn({name: "subtract", args}) =>
      binary(args, "Subtraction needs two operands", (l, r) =>
        Ok(`Combination(l, r, `SubtractOperation))
      )
    | Fn({name: "multiply", args}) =>
      binary(args, "Multiplication needs two operands", (l, r) =>
        Ok(`Combination(l, r, `MultiplyOperation))
      )
    | Fn({name: "divide", args}) =>
      binary(args, "Division needs two operands", (l, r) =>
        switch (r) {
        | `Simple(`Float(0.0)) => Error("Division by zero")
        | _ => Ok(`Combination(l, r, `DivideOperation))
        }
      )
    | Fn({name: "pow", args}) =>
      binary(args, "Exponentiation needs two operands", (l, r) =>
        Ok(`Combination(l, r, `ExponentiateOperation))
      )
    | Fn({name: "leftTruncate", args}) =>
      let arityError = "leftTruncate needs two arguments: the expression and the cutoff";
      binary(args, arityError, (dist, cutoff) =>
        switch (cutoff) {
        | `Simple(`Float(x)) => Ok(`LeftTruncate(dist, x))
        | _ => Error(arityError)
        }
      );
    | Fn({name: "rightTruncate", args}) =>
      let arityError = "rightTruncate needs two arguments: the expression and the cutoff";
      binary(args, arityError, (dist, cutoff) =>
        switch (cutoff) {
        | `Simple(`Float(x)) => Ok(`RightTruncate(dist, x))
        | _ => Error(arityError)
        }
      );
    | Fn({name, _}) => Error(name ++ ": function not supported")
    | _ => Error("This type not currently supported")
    };
  };

  /* Entry point for a cleaned tree: function calls go through
     functionParser, a bare number becomes a point value, and an array
     literal is treated as samples. */
  let topLevel = (r: arg): result(SymbolicDist.distTree, string) =>
    switch (r) {
    | Fn(_) => functionParser(r)
    | Value(f) => Ok(`Simple(`Float(f)))
    | Array(items) => arrayParser(items)
    | Symbol(_) => Error("Symbol not valid as top level")
    | Object(_) => Error("Object not valid as top level")
    };

  /* Cleans the tree, then interprets it. */
  let run = (r: arg): result(SymbolicDist.distTree, string) =>
    r |> MathAdtCleaner.run |> topLevel;
};
let fromString = str => {
  /* Parses a user-typed expression into a SymbolicDist.distTree in three
     stages, short-circuiting on the first error:
     1. Mathjs.parseMath turns the string into mathjs's JSON node tree.
     2. MathJsonToMathJsAdt.run converts that JSON into the typed `arg` tree;
        if it returns None, the error is reported as "MathJsParse Error".
     3. MathAdtToDistDst.run is applied once to the whole tree: it cleans it
        and interprets it as a distribution expression. */
  let mathJsParse =
    E.R.bind(Mathjs.parseMath(str), json =>
      switch (MathJsonToMathJsAdt.run(json)) {
      | Some(adt) => Ok(adt)
      | None => Error("MathJsParse Error")
      }
    );
  E.R.bind(mathJsParse, MathAdtToDistDst.run);
};