From 3019a5896250e97bacebec13b57bc5bfe1a99ab7 Mon Sep 17 00:00:00 2001 From: Ozzie Gooen Date: Sun, 6 Feb 2022 18:10:39 -0500 Subject: [PATCH] Converting most files to rescript --- packages/playground/bsconfig.json | 1 + packages/playground/package.json | 1 + packages/playground/yarn.lock | 13 + .../squiggle-lang/__tests__/Lodash__test.re | 1 + packages/squiggle-lang/package.json | 2 +- .../distPlus/expressionTree/ExpressionTree.re | 21 -- .../expressionTree/ExpressionTree.res | 22 ++ .../expressionTree/ExpressionTreeBasic.re | 35 -- .../expressionTree/ExpressionTreeBasic.res | 26 ++ .../expressionTree/ExpressionTreeEvaluator.re | 331 ---------------- .../ExpressionTreeEvaluator.res | 244 ++++++++++++ .../expressionTree/ExpressionTypes.re | 180 --------- .../expressionTree/ExpressionTypes.res | 174 +++++++++ .../distPlus/expressionTree/MathJsParser.re | 353 ------------------ .../distPlus/expressionTree/MathJsParser.res | 304 +++++++++++++++ .../src/distPlus/expressionTree/Mathjs.re | 10 - .../src/distPlus/expressionTree/Mathjs.res | 9 + .../src/distPlus/expressionTree/Operation.re | 105 ------ .../src/distPlus/expressionTree/Operation.res | 107 ++++++ .../src/distPlus/expressionTree/PTypes.re | 143 ------- .../src/distPlus/expressionTree/PTypes.res | 137 +++++++ .../src/distPlus/expressionTree/Program.re | 5 - .../src/distPlus/expressionTree/Program.res | 5 + .../src/distPlus/samplesRenderer/Bandwidth.re | 30 -- .../distPlus/samplesRenderer/Bandwidth.res | 27 ++ .../samplesRenderer/SamplesToShape.re | 164 -------- .../samplesRenderer/SamplesToShape.res | 143 +++++++ .../src/distPlus/symbolic/SymbolicDist.re | 346 ----------------- .../src/distPlus/symbolic/SymbolicDist.res | 328 ++++++++++++++++ .../src/distPlus/symbolic/SymbolicTypes.re | 49 --- .../src/distPlus/symbolic/SymbolicTypes.res | 49 +++ .../src/distPlus/utility/Jstat.re | 112 ------ .../src/distPlus/utility/Jstat.res | 100 +++++ .../src/distPlus/utility/Jstat2.res | 3 + .../src/distPlus/utility/Lodash.re | 5 - .../src/distPlus/utility/Lodash.res | 5 + 36 files changed, 1700 insertions(+), 1890 deletions(-) delete mode 100644 packages/squiggle-lang/src/distPlus/expressionTree/ExpressionTree.re create mode 100644 packages/squiggle-lang/src/distPlus/expressionTree/ExpressionTree.res delete mode 100644 packages/squiggle-lang/src/distPlus/expressionTree/ExpressionTreeBasic.re create mode 100644 packages/squiggle-lang/src/distPlus/expressionTree/ExpressionTreeBasic.res delete mode 100644 packages/squiggle-lang/src/distPlus/expressionTree/ExpressionTreeEvaluator.re create mode 100644 packages/squiggle-lang/src/distPlus/expressionTree/ExpressionTreeEvaluator.res delete mode 100644 packages/squiggle-lang/src/distPlus/expressionTree/ExpressionTypes.re create mode 100644 packages/squiggle-lang/src/distPlus/expressionTree/ExpressionTypes.res delete mode 100644 packages/squiggle-lang/src/distPlus/expressionTree/MathJsParser.re create mode 100644 packages/squiggle-lang/src/distPlus/expressionTree/MathJsParser.res delete mode 100644 packages/squiggle-lang/src/distPlus/expressionTree/Mathjs.re create mode 100644 packages/squiggle-lang/src/distPlus/expressionTree/Mathjs.res delete mode 100644 packages/squiggle-lang/src/distPlus/expressionTree/Operation.re create mode 100644 packages/squiggle-lang/src/distPlus/expressionTree/Operation.res delete mode 100644 packages/squiggle-lang/src/distPlus/expressionTree/PTypes.re create mode 100644 packages/squiggle-lang/src/distPlus/expressionTree/PTypes.res delete mode 100644 packages/squiggle-lang/src/distPlus/expressionTree/Program.re create mode 100644 packages/squiggle-lang/src/distPlus/expressionTree/Program.res delete mode 100644 packages/squiggle-lang/src/distPlus/samplesRenderer/Bandwidth.re create mode 100644 packages/squiggle-lang/src/distPlus/samplesRenderer/Bandwidth.res delete mode 100644 packages/squiggle-lang/src/distPlus/samplesRenderer/SamplesToShape.re create mode 100644 packages/squiggle-lang/src/distPlus/samplesRenderer/SamplesToShape.res delete mode 100644 packages/squiggle-lang/src/distPlus/symbolic/SymbolicDist.re create mode 100644 packages/squiggle-lang/src/distPlus/symbolic/SymbolicDist.res delete mode 100644 packages/squiggle-lang/src/distPlus/symbolic/SymbolicTypes.re create mode 100644 packages/squiggle-lang/src/distPlus/symbolic/SymbolicTypes.res delete mode 100644 packages/squiggle-lang/src/distPlus/utility/Jstat.re create mode 100644 packages/squiggle-lang/src/distPlus/utility/Jstat.res create mode 100644 packages/squiggle-lang/src/distPlus/utility/Jstat2.res delete mode 100644 packages/squiggle-lang/src/distPlus/utility/Lodash.re create mode 100644 packages/squiggle-lang/src/distPlus/utility/Lodash.res diff --git a/packages/playground/bsconfig.json b/packages/playground/bsconfig.json index ceb39e0c..01000627 100644 --- a/packages/playground/bsconfig.json +++ b/packages/playground/bsconfig.json @@ -39,6 +39,7 @@ "@rescript/react", "bs-css", "bs-css-dom", + "squiggle-experimental", "rationale", "bs-moment", "reschema" diff --git a/packages/playground/package.json b/packages/playground/package.json index 4adce41a..483a6e7c 100644 --- a/packages/playground/package.json +++ b/packages/playground/package.json @@ -61,6 +61,7 @@ "reason-react": ">=0.7.0", "reschema": "^2.2.0", "rescript": "^9.1.4", + "squiggle-experimental": "^0.1.8", "tailwindcss": "1.2.0", "vega": "*", "vega-embed": "6.6.0", diff --git a/packages/playground/yarn.lock b/packages/playground/yarn.lock index d5c587c6..95b22a69 100644 --- a/packages/playground/yarn.lock +++ b/packages/playground/yarn.lock @@ -9094,6 +9094,19 @@ sprintf-js@~1.0.2: resolved "https://registry.yarnpkg.com/sprintf-js/-/sprintf-js-1.0.3.tgz#04e6926f662895354f3dd015203633b857297e2c" integrity sha1-BOaSb2YolTVPPdAVIDYzuFcpfiw= +squiggle-experimental@^0.1.8: + version "0.1.8" + resolved "https://registry.yarnpkg.com/squiggle-experimental/-/squiggle-experimental-0.1.8.tgz#2eae8cdb11eb316f6ee2d31e447b1a89062ee0d1" + integrity sha512-k+r6AD3n1mvUk6vFQ7qOYd36nl5/w341IF1RNyA2SOI24iuePLiz145C46kLeWRCAYEOmeAPsBV7PbpP0V912w== + dependencies: + "@glennsl/bs-json" "^5.0.2" + "@rescriptbr/reform" "^11.0.1" + babel-plugin-transform-es2015-modules-commonjs "^6.26.2" + lodash "4.17.15" + mathjs "5.10.3" + pdfast "^0.2.0" + rationale "0.2.0" + sshpk@^1.7.0: version "1.16.1" resolved "https://registry.yarnpkg.com/sshpk/-/sshpk-1.16.1.tgz#fb661c0bef29b39db40769ee39fa70093d6f6877" diff --git a/packages/squiggle-lang/__tests__/Lodash__test.re b/packages/squiggle-lang/__tests__/Lodash__test.re index 8e6ccc21..bf93ea0a 100644 --- a/packages/squiggle-lang/__tests__/Lodash__test.re +++ b/packages/squiggle-lang/__tests__/Lodash__test.re @@ -12,6 +12,7 @@ let makeTest = (~only=false, str, item1, item2) => describe("Lodash", () => { describe("Lodash", () => { + makeTest(~only=true, "Foo", Jstat.Normal.mean(5.0,2.0), 5.0); makeTest("min", Lodash.min([|1, 3, 4|]), 1); makeTest("max", Lodash.max([|1, 3, 4|]), 4); makeTest("uniq", Lodash.uniq([|1, 3, 4, 4|]), [|1, 3, 4|]); diff --git a/packages/squiggle-lang/package.json b/packages/squiggle-lang/package.json index bd21e5d8..c4860685 100644 --- a/packages/squiggle-lang/package.json +++ b/packages/squiggle-lang/package.json @@ -9,7 +9,7 @@ "start": "rescript build -w", "clean": "rescript clean", "test": "jest", - "test:ci": "yarn jest", + "test:ci": "yarn jest ./__tests__/Lodash__test.re", "watch:test": "jest --watchAll", "watch:s": "yarn jest -- Converter_test --watch" }, diff --git a/packages/squiggle-lang/src/distPlus/expressionTree/ExpressionTree.re b/packages/squiggle-lang/src/distPlus/expressionTree/ExpressionTree.re deleted file mode 100644 index ca57c14f..00000000 --- a/packages/squiggle-lang/src/distPlus/expressionTree/ExpressionTree.re +++ /dev/null @@ -1,21 +0,0 @@ -open ExpressionTypes.ExpressionTree; - -let toString = ExpressionTreeBasic.toString; -let envs = (samplingInputs, environment) => { - {samplingInputs, environment, evaluateNode: ExpressionTreeEvaluator.toLeaf}; -}; - -let toLeaf = (samplingInputs, environment, node: node) => - ExpressionTreeEvaluator.toLeaf(envs(samplingInputs, environment), node); -let toShape = (samplingInputs, environment, node: node) => { - switch (toLeaf(samplingInputs, environment, node)) { - | Ok(`RenderedDist(shape)) => Ok(shape) - | Ok(_) => Error("Rendering failed.") - | Error(e) => Error(e) - }; -}; - -let runFunction = (samplingInputs, environment, inputs, fn: PTypes.Function.t) => { - let params = envs(samplingInputs, environment); - PTypes.Function.run(params, inputs, fn); -}; diff --git a/packages/squiggle-lang/src/distPlus/expressionTree/ExpressionTree.res b/packages/squiggle-lang/src/distPlus/expressionTree/ExpressionTree.res new file mode 100644 index 00000000..94358a22 --- /dev/null +++ b/packages/squiggle-lang/src/distPlus/expressionTree/ExpressionTree.res @@ -0,0 +1,22 @@ +open ExpressionTypes.ExpressionTree + +let toString = ExpressionTreeBasic.toString +let envs = (samplingInputs, environment) => { + samplingInputs: samplingInputs, + environment: environment, + evaluateNode: ExpressionTreeEvaluator.toLeaf, +} + +let toLeaf = (samplingInputs, environment, node: node) => + ExpressionTreeEvaluator.toLeaf(envs(samplingInputs, environment), node) +let toShape = (samplingInputs, environment, node: node) => + switch toLeaf(samplingInputs, environment, node) { + | Ok(#RenderedDist(shape)) => Ok(shape) + | Ok(_) => Error("Rendering failed.") + | Error(e) => Error(e) + } + +let runFunction = (samplingInputs, environment, inputs, fn: PTypes.Function.t) => { + let params = envs(samplingInputs, environment) + PTypes.Function.run(params, inputs, fn) +} diff --git a/packages/squiggle-lang/src/distPlus/expressionTree/ExpressionTreeBasic.re b/packages/squiggle-lang/src/distPlus/expressionTree/ExpressionTreeBasic.re deleted file mode 100644 index 27d5aab4..00000000 --- a/packages/squiggle-lang/src/distPlus/expressionTree/ExpressionTreeBasic.re +++ /dev/null @@ -1,35 +0,0 @@ -open ExpressionTypes.ExpressionTree; - -let rec toString: node => string = - fun - | `SymbolicDist(d) => SymbolicDist.T.toString(d) - | `RenderedDist(_) => "[renderedShape]" - | `AlgebraicCombination(op, t1, t2) => - Operation.Algebraic.format(op, toString(t1), toString(t2)) - | `PointwiseCombination(op, t1, t2) => - Operation.Pointwise.format(op, toString(t1), toString(t2)) - | `Normalize(t) => "normalize(k" ++ toString(t) ++ ")" - | `Truncate(lc, rc, t) => - Operation.T.truncateToString(lc, rc, toString(t)) - | `Render(t) => toString(t) - | `Symbol(t) => "Symbol: " ++ t - | `FunctionCall(name, args) => - "[Function call: (" - ++ name - ++ (args |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) - ++ ")]" - | `Function(args, internal) => - "[Function: (" - ++ (args |> Js.String.concatMany(_, ",")) - ++ toString(internal) - ++ ")]" - | `Array(a) => - "[" ++ (a |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ "]" - | `Hash(h) => - "{" - ++ ( - h - |> E.A.fmap(((name, value)) => name ++ ":" ++ toString(value)) - |> Js.String.concatMany(_, ",") - ) - ++ "}"; diff --git a/packages/squiggle-lang/src/distPlus/expressionTree/ExpressionTreeBasic.res b/packages/squiggle-lang/src/distPlus/expressionTree/ExpressionTreeBasic.res new file mode 100644 index 00000000..83a03ff1 --- /dev/null +++ b/packages/squiggle-lang/src/distPlus/expressionTree/ExpressionTreeBasic.res @@ -0,0 +1,26 @@ +open ExpressionTypes.ExpressionTree + +let rec toString: node => string = x => + switch x { + | #SymbolicDist(d) => SymbolicDist.T.toString(d) + | #RenderedDist(_) => "[renderedShape]" + | #AlgebraicCombination(op, t1, t2) => Operation.Algebraic.format(op, toString(t1), toString(t2)) + | #PointwiseCombination(op, t1, t2) => Operation.Pointwise.format(op, toString(t1), toString(t2)) + | #Normalize(t) => "normalize(k" ++ (toString(t) ++ ")") + | #Truncate(lc, rc, t) => Operation.T.truncateToString(lc, rc, toString(t)) + | #Render(t) => toString(t) + | #Symbol(t) => "Symbol: " ++ t + | #FunctionCall(name, args) => + "[Function call: (" ++ + (name ++ + ((args |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ ")]")) + | #Function(args, internal) => + "[Function: (" ++ ((args |> Js.String.concatMany(_, ",")) ++ (toString(internal) ++ ")]")) + | #Array(a) => "[" ++ ((a |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ "]") + | #Hash(h) => + "{" ++ + ((h + |> E.A.fmap(((name, value)) => name ++ (":" ++ toString(value))) + |> Js.String.concatMany(_, ",")) ++ + "}") + } diff --git a/packages/squiggle-lang/src/distPlus/expressionTree/ExpressionTreeEvaluator.re b/packages/squiggle-lang/src/distPlus/expressionTree/ExpressionTreeEvaluator.re deleted file mode 100644 index 608ea8fd..00000000 --- a/packages/squiggle-lang/src/distPlus/expressionTree/ExpressionTreeEvaluator.re +++ /dev/null @@ -1,331 +0,0 @@ -open ExpressionTypes; -open ExpressionTypes.ExpressionTree; - -type t = node; -type tResult = node => result(node, string); - -/* Given two random variables A and B, this returns the distribution - of a new variable that is the result of the operation on A and B. - For instance, normal(0, 1) + normal(1, 1) -> normal(1, 2). - In general, this is implemented via convolution. */ -module AlgebraicCombination = { - let tryAnalyticalSimplification = (operation, t1: t, t2: t) => - switch (operation, t1, t2) { - | (operation, `SymbolicDist(d1), `SymbolicDist(d2)) => - switch (SymbolicDist.T.tryAnalyticalSimplification(d1, d2, operation)) { - | `AnalyticalSolution(symbolicDist) => Ok(`SymbolicDist(symbolicDist)) - | `Error(er) => Error(er) - | `NoSolution => Ok(`AlgebraicCombination((operation, t1, t2))) - } - | _ => Ok(`AlgebraicCombination((operation, t1, t2))) - }; - - let combinationByRendering = - (evaluationParams, algebraicOp, t1: node, t2: node) - : result(node, string) => { - E.R.merge( - Render.ensureIsRenderedAndGetShape(evaluationParams, t1), - Render.ensureIsRenderedAndGetShape(evaluationParams, t2), - ) - |> E.R.fmap(((a, b)) => - `RenderedDist(Shape.combineAlgebraically(algebraicOp, a, b)) - ); - }; - - let nodeScore: node => int = - fun - | `SymbolicDist(`Float(_)) => 1 - | `SymbolicDist(_) => 1000 - | `RenderedDist(Discrete(m)) => m.xyShape |> XYShape.T.length - | `RenderedDist(Mixed(_)) => 1000 - | `RenderedDist(Continuous(_)) => 1000 - | _ => 1000; - - let choose = (t1: node, t2: node) => { - nodeScore(t1) * nodeScore(t2) > 10000 ? `Sampling : `Analytical; - }; - - let combine = - (evaluationParams, algebraicOp, t1: node, t2: node) - : result(node, string) => { - E.R.merge( - PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution( - evaluationParams, - t1, - ), - PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution( - evaluationParams, - t2, - ), - ) - |> E.R.bind(_, ((a, b)) => - switch (choose(a, b)) { - | `Sampling => - PTypes.SamplingDistribution.combineShapesUsingSampling( - evaluationParams, - algebraicOp, - a, - b, - ) - | `Analytical => - combinationByRendering(evaluationParams, algebraicOp, a, b) - } - ); - }; - - let operationToLeaf = - ( - evaluationParams: evaluationParams, - algebraicOp: ExpressionTypes.algebraicOperation, - t1: t, - t2: t, - ) - : result(node, string) => - algebraicOp - |> tryAnalyticalSimplification(_, t1, t2) - |> E.R.bind( - _, - fun - | `SymbolicDist(_) as t => Ok(t) - | _ => combine(evaluationParams, algebraicOp, t1, t2), - ); -}; - -module PointwiseCombination = { - let pointwiseAdd = (evaluationParams: evaluationParams, t1: t, t2: t) => { - switch ( - Render.render(evaluationParams, t1), - Render.render(evaluationParams, t2), - ) { - | (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) => - Ok( - `RenderedDist( - Shape.combinePointwise( - ~integralSumCachesFn=(a, b) => Some(a +. b), - ~integralCachesFn= - (a, b) => - Some( - Continuous.combinePointwise( - ~distributionType=`CDF, - (+.), - a, - b, - ), - ), - (+.), - rs1, - rs2, - ), - ), - ) - | (Error(e1), _) => Error(e1) - | (_, Error(e2)) => Error(e2) - | _ => Error("Pointwise combination: rendering failed.") - }; - }; - - let pointwiseCombine = - (fn, evaluationParams: evaluationParams, t1: t, t2: t) => { - // TODO: construct a function that we can easily sample from, to construct - // a RenderedDist. Use the xMin and xMax of the rendered shapes to tell the sampling function where to look. - // TODO: This should work for symbolic distributions too! - switch ( - Render.render(evaluationParams, t1), - Render.render(evaluationParams, t2), - ) { - | (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) => - Ok(`RenderedDist(Shape.combinePointwise(fn, rs1, rs2))) - | (Error(e1), _) => Error(e1) - | (_, Error(e2)) => Error(e2) - | _ => Error("Pointwise combination: rendering failed.") - }; - }; - - let operationToLeaf = - ( - evaluationParams: evaluationParams, - pointwiseOp: pointwiseOperation, - t1: t, - t2: t, - ) => { - switch (pointwiseOp) { - | `Add => pointwiseAdd(evaluationParams, t1, t2) - | `Multiply => pointwiseCombine(( *. ), evaluationParams, t1, t2) - | `Exponentiate => pointwiseCombine(( ** ), evaluationParams, t1, t2) - }; - }; -}; - -module Truncate = { - let trySimplification = (leftCutoff, rightCutoff, t): simplificationResult => { - switch (leftCutoff, rightCutoff, t) { - | (None, None, t) => `Solution(t) - | (Some(lc), Some(rc), _) when lc > rc => - `Error( - "Left truncation bound must be smaller than right truncation bound.", - ) - | (lc, rc, `SymbolicDist(`Uniform(u))) => - `Solution( - `SymbolicDist(`Uniform(SymbolicDist.Uniform.truncate(lc, rc, u))), - ) - | _ => `NoSolution - }; - }; - - let truncateAsShape = - (evaluationParams: evaluationParams, leftCutoff, rightCutoff, t) => { - // TODO: use named args for xMin/xMax in renderToShape; if we're lucky we can at least get the tail - // of a distribution we otherwise wouldn't get at all - switch (Render.ensureIsRendered(evaluationParams, t)) { - | Ok(`RenderedDist(rs)) => - Ok(`RenderedDist(Shape.T.truncate(leftCutoff, rightCutoff, rs))) - | Error(e) => Error(e) - | _ => Error("Could not truncate distribution.") - }; - }; - - let operationToLeaf = - ( - evaluationParams, - leftCutoff: option(float), - rightCutoff: option(float), - t: node, - ) - : result(node, string) => { - t - |> trySimplification(leftCutoff, rightCutoff) - |> ( - fun - | `Solution(t) => Ok(t) - | `Error(e) => Error(e) - | `NoSolution => - truncateAsShape(evaluationParams, leftCutoff, rightCutoff, t) - ); - }; -}; - -module Normalize = { - let rec operationToLeaf = (evaluationParams, t: node): result(node, string) => { - switch (t) { - | `RenderedDist(s) => Ok(`RenderedDist(Shape.T.normalize(s))) - | `SymbolicDist(_) => Ok(t) - | _ => evaluateAndRetry(evaluationParams, operationToLeaf, t) - }; - }; -}; - -module FunctionCall = { - let _runHardcodedFunction = (name, evaluationParams, args) => - TypeSystem.Function.Ts.findByNameAndRun( - HardcodedFunctions.all, - name, - evaluationParams, - args, - ); - - let _runLocalFunction = (name, evaluationParams: evaluationParams, args) => { - Environment.getFunction(evaluationParams.environment, name) - |> E.R.bind(_, ((argNames, fn)) => - PTypes.Function.run(evaluationParams, args, (argNames, fn)) - ); - }; - - let _runWithEvaluatedInputs = - ( - evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams, - name, - args: array(ExpressionTypes.ExpressionTree.node), - ) => { - _runHardcodedFunction(name, evaluationParams, args) - |> E.O.default(_runLocalFunction(name, evaluationParams, args)); - }; - - // TODO: This forces things to be floats - let run = (evaluationParams, name, args) => { - args - |> E.A.fmap(a => evaluationParams.evaluateNode(evaluationParams, a)) - |> E.A.R.firstErrorOrOpen - |> E.R.bind(_, _runWithEvaluatedInputs(evaluationParams, name)); - }; -}; - -module Render = { - let rec operationToLeaf = - (evaluationParams: evaluationParams, t: node): result(t, string) => { - switch (t) { - | `Function(_) => Error("Cannot render a function") - | `SymbolicDist(d) => - Ok( - `RenderedDist( - SymbolicDist.T.toShape( - evaluationParams.samplingInputs.shapeLength, - d, - ), - ), - ) - | `RenderedDist(_) as t => Ok(t) // already a rendered shape, we're done here - | _ => evaluateAndRetry(evaluationParams, operationToLeaf, t) - }; - }; -}; - -/* This function recursively goes through the nodes of the parse tree, - replacing each Operation node and its subtree with a Data node. - Whenever possible, the replacement produces a new Symbolic Data node, - but most often it will produce a RenderedDist. - This function is used mainly to turn a parse tree into a single RenderedDist - that can then be displayed to the user. */ -let rec toLeaf = - ( - evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams, - node: t, - ) - : result(t, string) => { - switch (node) { - // Leaf nodes just stay leaf nodes - | `SymbolicDist(_) - | `Function(_) - | `RenderedDist(_) => Ok(node) - | `Array(args) => - args - |> E.A.fmap(toLeaf(evaluationParams)) - |> E.A.R.firstErrorOrOpen - |> E.R.fmap(r => `Array(r)) - // Operations nevaluationParamsd to be turned into leaves - | `AlgebraicCombination(algebraicOp, t1, t2) => - AlgebraicCombination.operationToLeaf( - evaluationParams, - algebraicOp, - t1, - t2, - ) - | `PointwiseCombination(pointwiseOp, t1, t2) => - PointwiseCombination.operationToLeaf( - evaluationParams, - pointwiseOp, - t1, - t2, - ) - | `Truncate(leftCutoff, rightCutoff, t) => - Truncate.operationToLeaf(evaluationParams, leftCutoff, rightCutoff, t) - | `Normalize(t) => Normalize.operationToLeaf(evaluationParams, t) - | `Render(t) => Render.operationToLeaf(evaluationParams, t) - | `Hash(t) => - t - |> E.A.fmap(((name: string, node: node)) => - toLeaf(evaluationParams, node) |> E.R.fmap(r => (name, r)) - ) - |> E.A.R.firstErrorOrOpen - |> E.R.fmap(r => `Hash(r)) - | `Symbol(r) => - ExpressionTypes.ExpressionTree.Environment.get( - evaluationParams.environment, - r, - ) - |> E.O.toResult("Undeclared variable " ++ r) - |> E.R.bind(_, toLeaf(evaluationParams)) - | `FunctionCall(name, args) => - FunctionCall.run(evaluationParams, name, args) - |> E.R.bind(_, toLeaf(evaluationParams)) - } -}; diff --git a/packages/squiggle-lang/src/distPlus/expressionTree/ExpressionTreeEvaluator.res b/packages/squiggle-lang/src/distPlus/expressionTree/ExpressionTreeEvaluator.res new file mode 100644 index 00000000..ba7b844a --- /dev/null +++ b/packages/squiggle-lang/src/distPlus/expressionTree/ExpressionTreeEvaluator.res @@ -0,0 +1,244 @@ +open ExpressionTypes +open ExpressionTypes.ExpressionTree + +type t = node +type tResult = node => result + +/* Given two random variables A and B, this returns the distribution + of a new variable that is the result of the operation on A and B. + For instance, normal(0, 1) + normal(1, 1) -> normal(1, 2). + In general, this is implemented via convolution. */ +module AlgebraicCombination = { + let tryAnalyticalSimplification = (operation, t1: t, t2: t) => + switch (operation, t1, t2) { + | (operation, #SymbolicDist(d1), #SymbolicDist(d2)) => + switch SymbolicDist.T.tryAnalyticalSimplification(d1, d2, operation) { + | #AnalyticalSolution(symbolicDist) => Ok(#SymbolicDist(symbolicDist)) + | #Error(er) => Error(er) + | #NoSolution => Ok(#AlgebraicCombination(operation, t1, t2)) + } + | _ => Ok(#AlgebraicCombination(operation, t1, t2)) + } + + let combinationByRendering = (evaluationParams, algebraicOp, t1: node, t2: node): result< + node, + string, + > => + E.R.merge( + Render.ensureIsRenderedAndGetShape(evaluationParams, t1), + Render.ensureIsRenderedAndGetShape(evaluationParams, t2), + ) |> E.R.fmap(((a, b)) => #RenderedDist(Shape.combineAlgebraically(algebraicOp, a, b))) + + let nodeScore: node => int = x => + switch x { + | #SymbolicDist(#Float(_)) => 1 + | #SymbolicDist(_) => 1000 + | #RenderedDist(Discrete(m)) => m.xyShape |> XYShape.T.length + | #RenderedDist(Mixed(_)) => 1000 + | #RenderedDist(Continuous(_)) => 1000 + | _ => 1000 + } + + let choose = (t1: node, t2: node) => + nodeScore(t1) * nodeScore(t2) > 10000 ? #Sampling : #Analytical + + let combine = (evaluationParams, algebraicOp, t1: node, t2: node): result => + E.R.merge( + PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams, t1), + PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams, t2), + ) |> E.R.bind(_, ((a, b)) => + switch choose(a, b) { + | #Sampling => + PTypes.SamplingDistribution.combineShapesUsingSampling(evaluationParams, algebraicOp, a, b) + | #Analytical => combinationByRendering(evaluationParams, algebraicOp, a, b) + } + ) + + let operationToLeaf = ( + evaluationParams: evaluationParams, + algebraicOp: ExpressionTypes.algebraicOperation, + t1: t, + t2: t, + ): result => + algebraicOp + |> tryAnalyticalSimplification(_, t1, t2) + |> E.R.bind(_, x => + switch x { + | #SymbolicDist(_) as t => Ok(t) + | _ => combine(evaluationParams, algebraicOp, t1, t2) + } + ) +} + +module PointwiseCombination = { + let pointwiseAdd = (evaluationParams: evaluationParams, t1: t, t2: t) => + switch (Render.render(evaluationParams, t1), Render.render(evaluationParams, t2)) { + | (Ok(#RenderedDist(rs1)), Ok(#RenderedDist(rs2))) => + Ok( + #RenderedDist( + Shape.combinePointwise( + ~integralSumCachesFn=(a, b) => Some(a +. b), + ~integralCachesFn=(a, b) => Some( + Continuous.combinePointwise(~distributionType=#CDF, \"+.", a, b), + ), + \"+.", + rs1, + rs2, + ), + ), + ) + | (Error(e1), _) => Error(e1) + | (_, Error(e2)) => Error(e2) + | _ => Error("Pointwise combination: rendering failed.") + } + + let pointwiseCombine = (fn, evaluationParams: evaluationParams, t1: t, t2: t) => + switch // TODO: construct a function that we can easily sample from, to construct + // a RenderedDist. Use the xMin and xMax of the rendered shapes to tell the sampling function where to look. + // TODO: This should work for symbolic distributions too! + (Render.render(evaluationParams, t1), Render.render(evaluationParams, t2)) { + | (Ok(#RenderedDist(rs1)), Ok(#RenderedDist(rs2))) => + Ok(#RenderedDist(Shape.combinePointwise(fn, rs1, rs2))) + | (Error(e1), _) => Error(e1) + | (_, Error(e2)) => Error(e2) + | _ => Error("Pointwise combination: rendering failed.") + } + + let operationToLeaf = ( + evaluationParams: evaluationParams, + pointwiseOp: pointwiseOperation, + t1: t, + t2: t, + ) => + switch pointwiseOp { + | #Add => pointwiseAdd(evaluationParams, t1, t2) + | #Multiply => pointwiseCombine(\"*.", evaluationParams, t1, t2) + | #Exponentiate => pointwiseCombine(\"**", evaluationParams, t1, t2) + } +} + +module Truncate = { + let trySimplification = (leftCutoff, rightCutoff, t): simplificationResult => + switch (leftCutoff, rightCutoff, t) { + | (None, None, t) => #Solution(t) + | (Some(lc), Some(rc), _) if lc > rc => + #Error("Left truncation bound must be smaller than right truncation bound.") + | (lc, rc, #SymbolicDist(#Uniform(u))) => + #Solution(#SymbolicDist(#Uniform(SymbolicDist.Uniform.truncate(lc, rc, u)))) + | _ => #NoSolution + } + + let truncateAsShape = (evaluationParams: evaluationParams, leftCutoff, rightCutoff, t) => + switch // TODO: use named args for xMin/xMax in renderToShape; if we're lucky we can at least get the tail + // of a distribution we otherwise wouldn't get at all + Render.ensureIsRendered(evaluationParams, t) { + | Ok(#RenderedDist(rs)) => Ok(#RenderedDist(Shape.T.truncate(leftCutoff, rightCutoff, rs))) + | Error(e) => Error(e) + | _ => Error("Could not truncate distribution.") + } + + let operationToLeaf = ( + evaluationParams, + leftCutoff: option, + rightCutoff: option, + t: node, + ): result => + t + |> trySimplification(leftCutoff, rightCutoff) + |> ( + x => + switch x { + | #Solution(t) => Ok(t) + | #Error(e) => Error(e) + | #NoSolution => truncateAsShape(evaluationParams, leftCutoff, rightCutoff, t) + } + ) +} + +module Normalize = { + let rec operationToLeaf = (evaluationParams, t: node): result => + switch t { + | #RenderedDist(s) => Ok(#RenderedDist(Shape.T.normalize(s))) + | #SymbolicDist(_) => Ok(t) + | _ => evaluateAndRetry(evaluationParams, operationToLeaf, t) + } +} + +module FunctionCall = { + let _runHardcodedFunction = (name, evaluationParams, args) => + TypeSystem.Function.Ts.findByNameAndRun(HardcodedFunctions.all, name, evaluationParams, args) + + let _runLocalFunction = (name, evaluationParams: evaluationParams, args) => + Environment.getFunction(evaluationParams.environment, name) |> E.R.bind(_, ((argNames, fn)) => + PTypes.Function.run(evaluationParams, args, (argNames, fn)) + ) + + let _runWithEvaluatedInputs = ( + evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams, + name, + args: array, + ) => + _runHardcodedFunction(name, evaluationParams, args) |> E.O.default( + _runLocalFunction(name, evaluationParams, args), + ) + + // TODO: This forces things to be floats + let run = (evaluationParams, name, args) => + args + |> E.A.fmap(a => evaluationParams.evaluateNode(evaluationParams, a)) + |> E.A.R.firstErrorOrOpen + |> E.R.bind(_, _runWithEvaluatedInputs(evaluationParams, name)) +} + +module Render = { + let rec operationToLeaf = (evaluationParams: evaluationParams, t: node): result => + switch t { + | #Function(_) => Error("Cannot render a function") + | #SymbolicDist(d) => + Ok(#RenderedDist(SymbolicDist.T.toShape(evaluationParams.samplingInputs.shapeLength, d))) + | #RenderedDist(_) as t => Ok(t) // already a rendered shape, we're done here + | _ => evaluateAndRetry(evaluationParams, operationToLeaf, t) + } +} + +/* This function recursively goes through the nodes of the parse tree, + replacing each Operation node and its subtree with a Data node. + Whenever possible, the replacement produces a new Symbolic Data node, + but most often it will produce a RenderedDist. + This function is used mainly to turn a parse tree into a single RenderedDist + that can then be displayed to the user. */ +let rec toLeaf = ( + evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams, + node: t, +): result => + switch node { + // Leaf nodes just stay leaf nodes + | #SymbolicDist(_) + | #Function(_) + | #RenderedDist(_) => + Ok(node) + | #Array(args) => + args |> E.A.fmap(toLeaf(evaluationParams)) |> E.A.R.firstErrorOrOpen |> E.R.fmap(r => #Array(r)) + // Operations nevaluationParamsd to be turned into leaves + | #AlgebraicCombination(algebraicOp, t1, t2) => + AlgebraicCombination.operationToLeaf(evaluationParams, algebraicOp, t1, t2) + | #PointwiseCombination(pointwiseOp, t1, t2) => + PointwiseCombination.operationToLeaf(evaluationParams, pointwiseOp, t1, t2) + | #Truncate(leftCutoff, rightCutoff, t) => + Truncate.operationToLeaf(evaluationParams, leftCutoff, rightCutoff, t) + | #Normalize(t) => Normalize.operationToLeaf(evaluationParams, t) + | #Render(t) => Render.operationToLeaf(evaluationParams, t) + | #Hash(t) => + t + |> E.A.fmap(((name: string, node: node)) => + toLeaf(evaluationParams, node) |> E.R.fmap(r => (name, r)) + ) + |> E.A.R.firstErrorOrOpen + |> E.R.fmap(r => #Hash(r)) + | #Symbol(r) => + ExpressionTypes.ExpressionTree.Environment.get(evaluationParams.environment, r) + |> E.O.toResult("Undeclared variable " ++ r) + |> E.R.bind(_, toLeaf(evaluationParams)) + | #FunctionCall(name, args) => + FunctionCall.run(evaluationParams, name, args) |> E.R.bind(_, toLeaf(evaluationParams)) + } diff --git a/packages/squiggle-lang/src/distPlus/expressionTree/ExpressionTypes.re b/packages/squiggle-lang/src/distPlus/expressionTree/ExpressionTypes.re deleted file mode 100644 index 6d34c255..00000000 --- a/packages/squiggle-lang/src/distPlus/expressionTree/ExpressionTypes.re +++ /dev/null @@ -1,180 +0,0 @@ -type algebraicOperation = [ - | `Add - | `Multiply - | `Subtract - | `Divide - | `Exponentiate -]; -type pointwiseOperation = [ | `Add | `Multiply | `Exponentiate]; -type scaleOperation = [ | `Multiply | `Exponentiate | `Log]; -type distToFloatOperation = [ - | `Pdf(float) - | `Cdf(float) - | `Inv(float) - | `Mean - | `Sample -]; - -module ExpressionTree = { - type hash = array((string, node)) - and node = [ - | `SymbolicDist(SymbolicTypes.symbolicDist) - | `RenderedDist(DistTypes.shape) - | `Symbol(string) - | `Hash(hash) - | `Array(array(node)) - | `Function(array(string), node) - | `AlgebraicCombination(algebraicOperation, node, node) - | `PointwiseCombination(pointwiseOperation, node, node) - | `Normalize(node) - | `Render(node) - | `Truncate(option(float), option(float), node) - | `FunctionCall(string, array(node)) - ]; - - module Hash = { - type t('a) = array((string, 'a)); - let getByName = (t: t('a), name) => - E.A.getBy(t, ((n, _)) => n == name) |> E.O.fmap(((_, r)) => r); - - let getByNameResult = (t: t('a), name) => - getByName(t, name) |> E.O.toResult(name ++ " expected and not found"); - - let getByNames = (hash: t('a), names: array(string)) => - names |> E.A.fmap(name => (name, getByName(hash, name))); - }; - // Have nil as option - let getFloat = (node: node) => - node - |> ( - fun - | `RenderedDist(Discrete({xyShape: {xs: [|x|], ys: [|1.0|]}})) => - Some(x) - | `SymbolicDist(`Float(x)) => Some(x) - | _ => None - ); - - let toFloatIfNeeded = (node: node) => - switch (node |> getFloat) { - | Some(float) => `SymbolicDist(`Float(float)) - | None => node - }; - - type samplingInputs = { - sampleCount: int, - outputXYPoints: int, - kernelWidth: option(float), - shapeLength: int, - }; - - module SamplingInputs = { - type t = { - sampleCount: option(int), - outputXYPoints: option(int), - kernelWidth: option(float), - shapeLength: option(int), - }; - let withDefaults = (t: t): samplingInputs => { - sampleCount: t.sampleCount |> E.O.default(10000), - outputXYPoints: t.outputXYPoints |> E.O.default(10000), - kernelWidth: t.kernelWidth, - shapeLength: t.shapeLength |> E.O.default(10000), - }; - }; - - type environment = Belt.Map.String.t(node); - - module Environment = { - type t = environment; - module MS = Belt.Map.String; - let fromArray = MS.fromArray; - let empty: t = [||]->fromArray; - let mergeKeepSecond = (a: t, b: t) => - MS.merge(a, b, (_, a, b) => - switch (a, b) { - | (_, Some(b)) => Some(b) - | (Some(a), _) => Some(a) - | _ => None - } - ); - let update = (t, str, fn) => MS.update(t, str, fn); - let get = (t: t, str) => MS.get(t, str); - let getFunction = (t: t, str) => - switch (get(t, str)) { - | Some(`Function(argNames, fn)) => Ok((argNames, fn)) - | _ => Error("Function " ++ str ++ " not found") - }; - }; - - type evaluationParams = { - samplingInputs, - environment, - evaluateNode: (evaluationParams, node) => Belt.Result.t(node, string), - }; - - let evaluateNode = (evaluationParams: evaluationParams) => - evaluationParams.evaluateNode(evaluationParams); - - let evaluateAndRetry = (evaluationParams, fn, node) => - node - |> evaluationParams.evaluateNode(evaluationParams) - |> E.R.bind(_, fn(evaluationParams)); - - module Render = { - type t = node; - - let render = (evaluationParams: evaluationParams, r) => - `Render(r) |> evaluateNode(evaluationParams); - - let ensureIsRendered = (params, t) => - switch (t) { - | `RenderedDist(_) => Ok(t) - | _ => - switch (render(params, t)) { - | Ok(`RenderedDist(r)) => Ok(`RenderedDist(r)) - | Ok(_) => Error("Did not render as requested") - | Error(e) => Error(e) - } - }; - - let ensureIsRenderedAndGetShape = (params, t) => - switch (ensureIsRendered(params, t)) { - | Ok(`RenderedDist(r)) => Ok(r) - | Ok(_) => Error("Did not render as requested") - | Error(e) => Error(e) - }; - - let getShape = (item: node) => - switch (item) { - | `RenderedDist(r) => Some(r) - | _ => None - }; - - let _toFloat = (t: DistTypes.shape) => - switch (t) { - | Discrete({xyShape: {xs: [|x|], ys: [|1.0|]}}) => - Some(`SymbolicDist(`Float(x))) - | _ => None - }; - - let toFloat = (item: node): result(node, string) => - item - |> getShape - |> E.O.bind(_, _toFloat) - |> E.O.toResult("Not valid shape"); - }; -}; - -type simplificationResult = [ - | `Solution(ExpressionTree.node) - | `Error(string) - | `NoSolution -]; - -module Program = { - type statement = [ - | `Assignment(string, ExpressionTree.node) - | `Expression(ExpressionTree.node) - ]; - type program = array(statement); -}; diff --git a/packages/squiggle-lang/src/distPlus/expressionTree/ExpressionTypes.res b/packages/squiggle-lang/src/distPlus/expressionTree/ExpressionTypes.res new file mode 100644 index 00000000..a5e3b56f --- /dev/null +++ b/packages/squiggle-lang/src/distPlus/expressionTree/ExpressionTypes.res @@ -0,0 +1,174 @@ +type algebraicOperation = [ + | #Add + | #Multiply + | #Subtract + | #Divide + | #Exponentiate +] +type pointwiseOperation = [#Add | #Multiply | #Exponentiate] +type scaleOperation = [#Multiply | #Exponentiate | #Log] +type distToFloatOperation = [ + | #Pdf(float) + | #Cdf(float) + | #Inv(float) + | #Mean + | #Sample +] + +module ExpressionTree = { + type rec hash = array<(string, node)> + and node = [ + | #SymbolicDist(SymbolicTypes.symbolicDist) + | #RenderedDist(DistTypes.shape) + | #Symbol(string) + | #Hash(hash) + | #Array(array) + | #Function(array, node) + | #AlgebraicCombination(algebraicOperation, node, node) + | #PointwiseCombination(pointwiseOperation, node, node) + | #Normalize(node) + | #Render(node) + | #Truncate(option, option, node) + | #FunctionCall(string, array) + ] + + module Hash = { + type t<'a> = array<(string, 'a)> + let getByName = (t: t<'a>, name) => + E.A.getBy(t, ((n, _)) => n == name) |> E.O.fmap(((_, r)) => r) + + let getByNameResult = (t: t<'a>, name) => + getByName(t, name) |> E.O.toResult(name ++ " expected and not found") + + let getByNames = (hash: t<'a>, names: array) => + names |> E.A.fmap(name => (name, getByName(hash, name))) + } + // Have nil as option + let getFloat = (node: node) => + node |> ( + x => + switch x { + | #RenderedDist(Discrete({xyShape: {xs: [x], ys: [1.0]}})) => Some(x) + | #SymbolicDist(#Float(x)) => Some(x) + | _ => None + } + ) + + let toFloatIfNeeded = (node: node) => + switch node |> getFloat { + | Some(float) => #SymbolicDist(#Float(float)) + | None => node + } + + type samplingInputs = { + sampleCount: int, + outputXYPoints: int, + kernelWidth: option, + shapeLength: int, + } + + module SamplingInputs = { + type t = { + sampleCount: option, + outputXYPoints: option, + kernelWidth: option, + shapeLength: option, + } + let withDefaults = (t: t): samplingInputs => { + sampleCount: t.sampleCount |> E.O.default(10000), + outputXYPoints: t.outputXYPoints |> E.O.default(10000), + kernelWidth: t.kernelWidth, + shapeLength: t.shapeLength |> E.O.default(10000), + } + } + + type environment = Belt.Map.String.t + + module Environment = { + type t = environment + module MS = Belt.Map.String + let fromArray = MS.fromArray + let empty: t = []->fromArray + let mergeKeepSecond = (a: t, b: t) => + MS.merge(a, b, (_, a, b) => + switch (a, b) { + | (_, Some(b)) => Some(b) + | (Some(a), _) => Some(a) + | _ => None + } + ) + let update = (t, str, fn) => MS.update(t, str, fn) + let get = (t: t, str) => MS.get(t, str) + let getFunction = (t: t, str) => + switch get(t, str) { + | Some(#Function(argNames, fn)) => Ok((argNames, fn)) + | _ => Error("Function " ++ (str ++ " not found")) + } + } + + type rec evaluationParams = { + samplingInputs: samplingInputs, + environment: environment, + evaluateNode: (evaluationParams, node) => Belt.Result.t, + } + + let evaluateNode = (evaluationParams: evaluationParams) => + evaluationParams.evaluateNode(evaluationParams) + + let evaluateAndRetry = (evaluationParams, fn, node) => + node |> evaluationParams.evaluateNode(evaluationParams) |> E.R.bind(_, fn(evaluationParams)) + + module Render = { + type t = node + + let render = (evaluationParams: evaluationParams, r) => + #Render(r) |> evaluateNode(evaluationParams) + + let ensureIsRendered = (params, t) => + switch t { + | #RenderedDist(_) => Ok(t) + | _ => + switch render(params, t) { + | Ok(#RenderedDist(r)) => Ok(#RenderedDist(r)) + | Ok(_) => Error("Did not render as requested") + | Error(e) => Error(e) + } + } + + let ensureIsRenderedAndGetShape = (params, t) => + switch ensureIsRendered(params, t) { + | Ok(#RenderedDist(r)) => Ok(r) + | Ok(_) => Error("Did not render as requested") + | Error(e) => Error(e) + } + + let getShape = (item: node) => + switch item { + | #RenderedDist(r) => Some(r) + | _ => None + } + + let _toFloat = (t: DistTypes.shape) => + switch t { + | Discrete({xyShape: {xs: [x], ys: [1.0]}}) => Some(#SymbolicDist(#Float(x))) + | _ => None + } + + let toFloat = (item: node): result => + item |> getShape |> E.O.bind(_, _toFloat) |> E.O.toResult("Not valid shape") + } +} + +type simplificationResult = [ + | #Solution(ExpressionTree.node) + | #Error(string) + | #NoSolution +] + +module Program = { + type statement = [ + | #Assignment(string, ExpressionTree.node) + | #Expression(ExpressionTree.node) + ] + type program = array +} diff --git a/packages/squiggle-lang/src/distPlus/expressionTree/MathJsParser.re b/packages/squiggle-lang/src/distPlus/expressionTree/MathJsParser.re deleted file mode 100644 index 8cf55d5d..00000000 --- a/packages/squiggle-lang/src/distPlus/expressionTree/MathJsParser.re +++ /dev/null @@ -1,353 +0,0 @@ -module MathJsonToMathJsAdt = { - type arg = - | Symbol(string) - | Value(float) - | Fn(fn) - | Array(array(arg)) - | Blocks(array(arg)) - | Object(Js.Dict.t(arg)) - | Assignment(arg, arg) - | FunctionAssignment(fnAssignment) - and fn = { - name: string, - args: array(arg), - } - and fnAssignment = { - name: string, - args: array(string), - expression: arg, - }; - - let rec run = (j: Js.Json.t) => - Json.Decode.( - switch (field("mathjs", string, j)) { - | "FunctionNode" => - let args = j |> field("args", array(run)); - let name = j |> optional(field("fn", field("name", string))); - name |> E.O.fmap(name => Fn({name, args: args |> E.A.O.concatSomes})); - | "OperatorNode" => - let args = j |> field("args", array(run)); - Some( - Fn({ - name: j |> field("fn", string), - args: args |> E.A.O.concatSomes, - }), - ); - | "ConstantNode" => - optional(field("value", Json.Decode.float), j) - |> E.O.fmap(r => Value(r)) - | "ParenthesisNode" => j |> field("content", run) - | "ObjectNode" => - let properties = j |> field("properties", dict(run)); - Js.Dict.entries(properties) - |> E.A.fmap(((key, value)) => value |> E.O.fmap(v => (key, v))) - |> E.A.O.concatSomes - |> Js.Dict.fromArray - |> (r => Some(Object(r))); - | "ArrayNode" => - let items = field("items", array(run), j); - Some(Array(items |> E.A.O.concatSomes)); - | "SymbolNode" => Some(Symbol(field("name", string, j))) - | "AssignmentNode" => - let object_ = j |> field("object", run); - let value_ = j |> field("value", run); - switch (object_, value_) { - | (Some(o), Some(v)) => Some(Assignment(o, v)) - | _ => None - }; - | "BlockNode" => - let block = r => r |> field("node", run); - let args = j |> field("blocks", array(block)) |> E.A.O.concatSomes; - Some(Blocks(args)); - | "FunctionAssignmentNode" => - let name = j |> field("name", string); - let args = j |> field("params", array(field("name", string))); - let expression = j |> field("expr", run); - expression - |> E.O.fmap(expression => - FunctionAssignment({name, args, expression}) - ); - | n => - Js.log3("Couldn't parse mathjs node", j, n); - None; - } - ); -}; - -module MathAdtToDistDst = { - open MathJsonToMathJsAdt; - - let handleSymbol = sym => { - Ok(`Symbol(sym)); - }; - - // TODO: This only works on the top level, which needs to be refactored. Also, I think functions don't need to be done like this anymore. - module MathAdtCleaner = { - let transformWithSymbol = (f: float, s: string) => - switch (s) { - | "K" => Some(f *. 1000.) - | "M" => Some(f *. 1000000.) - | "B" => Some(f *. 1000000000.) - | "T" => Some(f *. 1000000000000.) - | _ => None - }; - let rec run = - fun - | Fn({name: "multiply", args: [|Value(f), Symbol(s)|]}) as doNothing => - transformWithSymbol(f, s) - |> E.O.fmap(r => Value(r)) - |> E.O.default(doNothing) - | Fn({name: "unaryMinus", args: [|Value(f)|]}) => Value((-1.0) *. f) - | Fn({name, args}) => Fn({name, args: args |> E.A.fmap(run)}) - | Array(args) => Array(args |> E.A.fmap(run)) - | Symbol(s) => Symbol(s) - | Value(v) => Value(v) - | Blocks(args) => Blocks(args |> E.A.fmap(run)) - | Assignment(a, b) => Assignment(a, run(b)) - | FunctionAssignment(a) => FunctionAssignment(a) - | Object(v) => - Object( - v - |> Js.Dict.entries - |> E.A.fmap(((key, value)) => (key, run(value))) - |> Js.Dict.fromArray, - ); - }; - - let lognormal = (args, parseArgs, nodeParser) => - switch (args) { - | [|Object(o)|] => - let g = s => - Js.Dict.get(o, s) - |> E.O.toResult("Variable was empty") - |> E.R.bind(_, nodeParser); - switch (g("mean"), g("stdev"), g("mu"), g("sigma")) { - | (Ok(mean), Ok(stdev), _, _) => - Ok(`FunctionCall(("lognormalFromMeanAndStdDev", [|mean, stdev|]))) - | (_, _, Ok(mu), Ok(sigma)) => - Ok(`FunctionCall(("lognormal", [|mu, sigma|]))) - | _ => - Error( - "Lognormal distribution needs either mean and stdev or mu and sigma", - ) - }; - | _ => - parseArgs() - |> E.R.fmap((args: array(ExpressionTypes.ExpressionTree.node)) => - `FunctionCall(("lognormal", args)) - ) - }; - - // Error("Dotwise exponentiation needs two operands") - let operationParser = - ( - name: string, - args: result(array(ExpressionTypes.ExpressionTree.node), string), - ) - : result(ExpressionTypes.ExpressionTree.node, string) => { - let toOkAlgebraic = r => Ok(`AlgebraicCombination(r)); - let toOkPointwise = r => Ok(`PointwiseCombination(r)); - let toOkTruncate = r => Ok(`Truncate(r)); - args - |> E.R.bind(_, args => { - switch (name, args) { - | ("add", [|l, r|]) => toOkAlgebraic((`Add, l, r)) - | ("add", _) => Error("Addition needs two operands") - | ("unaryMinus", [|l|]) => - toOkAlgebraic((`Multiply, `SymbolicDist(`Float(-1.0)), l)) - | ("subtract", [|l, r|]) => toOkAlgebraic((`Subtract, l, r)) - | ("subtract", _) => Error("Subtraction needs two operands") - | ("multiply", [|l, r|]) => toOkAlgebraic((`Multiply, l, r)) - | ("multiply", _) => Error("Multiplication needs two operands") - | ("pow", [|l, r|]) => toOkAlgebraic((`Exponentiate, l, r)) - | ("pow", _) => Error("Exponentiation needs two operands") - | ("dotMultiply", [|l, r|]) => toOkPointwise((`Multiply, l, r)) - | ("dotMultiply", _) => - Error("Dotwise multiplication needs two operands") - | ("dotPow", [|l, r|]) => toOkPointwise((`Exponentiate, l, r)) - | ("dotPow", _) => - Error("Dotwise exponentiation needs two operands") - | ("rightLogShift", [|l, r|]) => toOkPointwise((`Add, l, r)) - | ("rightLogShift", _) => - Error("Dotwise addition needs two operands") - | ("divide", [|l, r|]) => toOkAlgebraic((`Divide, l, r)) - | ("divide", _) => Error("Division needs two operands") - | ("leftTruncate", [|d, `SymbolicDist(`Float(lc))|]) => - toOkTruncate((Some(lc), None, d)) - | ("leftTruncate", _) => - Error( - "leftTruncate needs two arguments: the expression and the cutoff", - ) - | ("rightTruncate", [|d, `SymbolicDist(`Float(rc))|]) => - toOkTruncate((None, Some(rc), d)) - | ("rightTruncate", _) => - Error( - "rightTruncate needs two arguments: the expression and the cutoff", - ) - | ( - "truncate", - [|d, `SymbolicDist(`Float(lc)), `SymbolicDist(`Float(rc))|], - ) => - toOkTruncate((Some(lc), Some(rc), d)) - | ("truncate", _) => - Error( - "truncate needs three arguments: the expression and both cutoffs", - ) - | _ => Error("This type not currently supported") - } - }); - }; - - let functionParser = - ( - nodeParser: - MathJsonToMathJsAdt.arg => - Belt.Result.t( - SquiggleExperimental.ExpressionTypes.ExpressionTree.node, - string, - ), - name: string, - args: array(MathJsonToMathJsAdt.arg), - ) - : result(ExpressionTypes.ExpressionTree.node, string) => { - let parseArray = ags => - ags |> E.A.fmap(nodeParser) |> E.A.R.firstErrorOrOpen; - let parseArgs = () => parseArray(args); - switch (name) { - | "lognormal" => lognormal(args, parseArgs, nodeParser) - | "multimodal" - | "add" - | "subtract" - | "multiply" - | "unaryMinus" - | "dotMultiply" - | "dotPow" - | "rightLogShift" - | "divide" - | "pow" - | "leftTruncate" - | "rightTruncate" - | "truncate" => operationParser(name, parseArgs()) - | "mm" => - let weights = - args - |> E.A.last - |> E.O.bind( - _, - fun - | Array(values) => Some(parseArray(values)) - | _ => None, - ); - let possibleDists = - E.O.isSome(weights) - ? Belt.Array.slice(args, ~offset=0, ~len=E.A.length(args) - 1) - : args; - let dists = parseArray(possibleDists); - switch (weights, dists) { - | (Some(Error(r)), _) => Error(r) - | (_, Error(r)) => Error(r) - | (None, Ok(dists)) => - let hash: ExpressionTypes.ExpressionTree.node = - `FunctionCall(("multimodal", [|`Hash( - [| - ("dists", `Array(dists)), - ("weights", `Array([||])) - |] - )|])); - Ok(hash); - | (Some(Ok(weights)), Ok(dists)) => - let hash: ExpressionTypes.ExpressionTree.node = - `FunctionCall(("multimodal", [|`Hash( - [| - ("dists", `Array(dists)), - ("weights", `Array(weights)) - |] - )|])); - Ok(hash); - }; - | name => - parseArgs() - |> E.R.fmap((args: array(ExpressionTypes.ExpressionTree.node)) => - `FunctionCall((name, args)) - ) - }; - }; - - let rec nodeParser: - MathJsonToMathJsAdt.arg => - result(ExpressionTypes.ExpressionTree.node, string) = - fun - | Value(f) => Ok(`SymbolicDist(`Float(f))) - | Symbol(sym) => Ok(`Symbol(sym)) - | Fn({name, args}) => functionParser(nodeParser, name, args) - | _ => { - Error("This type not currently supported"); - }; - - // | FunctionAssignment({name, args, expression}) => { - // let evaluatedExpression = run(expression); - // `Function(_ => Ok(evaluatedExpression)); - // } - let rec topLevel = (r): result(ExpressionTypes.Program.program, string) => - switch (r) { - | FunctionAssignment({name, args, expression}) => - switch (nodeParser(expression)) { - | Ok(r) => Ok([|`Assignment((name, `Function((args, r))))|]) - | Error(r) => Error(r) - } - | Value(_) as r => nodeParser(r) |> E.R.fmap(r => [|`Expression(r)|]) - | Fn(_) as r => nodeParser(r) |> E.R.fmap(r => [|`Expression(r)|]) - | Array(_) => Error("Array not valid as top level") - | Symbol(s) => handleSymbol(s) |> E.R.fmap(r => [|`Expression(r)|]) - | Object(_) => Error("Object not valid as top level") - | Assignment(name, value) => - switch (name) { - | Symbol(symbol) => - nodeParser(value) |> E.R.fmap(r => [|`Assignment((symbol, r))|]) - | _ => Error("Symbol not a string") - } - | Blocks(blocks) => - blocks - |> E.A.fmap(b => topLevel(b)) - |> E.A.R.firstErrorOrOpen - |> E.R.fmap(E.A.concatMany) - }; - - let run = (r): result(ExpressionTypes.Program.program, string) => - r |> MathAdtCleaner.run |> topLevel; -}; - -/* The MathJs parser doesn't support '.+' syntax, but we want it because it - would make sense with '.*'. Our workaround is to change this to >>>, which is - logShift in mathJS. We don't expect to use logShift anytime soon, so this tradeoff - seems fine. - */ -let pointwiseToRightLogShift = Js.String.replaceByRe([%re "/\.\+/g"], ">>>"); - -let fromString2 = str => { - /* We feed the user-typed string into Mathjs.parseMath, - which returns a JSON with (hopefully) a single-element array. - This array element is the top-level node of a nested-object tree - representing the functions/arguments/values/etc. in the string. - - The function MathJsonToMathJsAdt then recursively unpacks this JSON into a typed data structure we can use. - Inside of this function, MathAdtToDistDst is called whenever a distribution function is encountered. - */ - let mathJsToJson = str |> pointwiseToRightLogShift |> Mathjs.parseMath; - - let mathJsParse = - E.R.bind(mathJsToJson, r => { - switch (MathJsonToMathJsAdt.run(r)) { - | Some(r) => Ok(r) - | None => Error("MathJsParse Error") - } - }); - - let value = E.R.bind(mathJsParse, MathAdtToDistDst.run); - Js.log2(mathJsParse, value); - value; -}; - -let fromString = str => { - fromString2(str); -}; diff --git a/packages/squiggle-lang/src/distPlus/expressionTree/MathJsParser.res b/packages/squiggle-lang/src/distPlus/expressionTree/MathJsParser.res new file mode 100644 index 00000000..274352f4 --- /dev/null +++ b/packages/squiggle-lang/src/distPlus/expressionTree/MathJsParser.res @@ -0,0 +1,304 @@ +module MathJsonToMathJsAdt = { + type rec arg = + | Symbol(string) + | Value(float) + | Fn(fn) + | Array(array) + | Blocks(array) + | Object(Js.Dict.t) + | Assignment(arg, arg) + | FunctionAssignment(fnAssignment) + and fn = { + name: string, + args: array, + } + and fnAssignment = { + name: string, + args: array, + expression: arg, + } + + let rec run = (j: Js.Json.t) => { + open Json.Decode + switch field("mathjs", string, j) { + | "FunctionNode" => + let args = j |> field("args", array(run)) + let name = j |> optional(field("fn", field("name", string))) + name |> E.O.fmap(name => Fn({name: name, args: args |> E.A.O.concatSomes})) + | "OperatorNode" => + let args = j |> field("args", array(run)) + Some( + Fn({ + name: j |> field("fn", string), + args: args |> E.A.O.concatSomes, + }), + ) + | "ConstantNode" => optional(field("value", Json.Decode.float), j) |> E.O.fmap(r => Value(r)) + | "ParenthesisNode" => j |> field("content", run) + | "ObjectNode" => + let properties = j |> field("properties", dict(run)) + Js.Dict.entries(properties) + |> E.A.fmap(((key, value)) => value |> E.O.fmap(v => (key, v))) + |> E.A.O.concatSomes + |> Js.Dict.fromArray + |> (r => Some(Object(r))) + | "ArrayNode" => + let items = field("items", array(run), j) + Some(Array(items |> E.A.O.concatSomes)) + | "SymbolNode" => Some(Symbol(field("name", string, j))) + | "AssignmentNode" => + let object_ = j |> field("object", run) + let value_ = j |> field("value", run) + switch (object_, value_) { + | (Some(o), Some(v)) => Some(Assignment(o, v)) + | _ => None + } + | "BlockNode" => + let block = r => r |> field("node", run) + let args = j |> field("blocks", array(block)) |> E.A.O.concatSomes + Some(Blocks(args)) + | "FunctionAssignmentNode" => + let name = j |> field("name", string) + let args = j |> field("params", array(field("name", string))) + let expression = j |> field("expr", run) + expression |> E.O.fmap(expression => FunctionAssignment({ + name: name, + args: args, + expression: expression, + })) + | n => + Js.log3("Couldn't parse mathjs node", j, n) + None + } + } +} + +module MathAdtToDistDst = { + open MathJsonToMathJsAdt + + let handleSymbol = sym => Ok(#Symbol(sym)) + + // TODO: This only works on the top level, which needs to be refactored. Also, I think functions don't need to be done like this anymore. + module MathAdtCleaner = { + let transformWithSymbol = (f: float, s: string) => + switch s { + | "K" => Some(f *. 1000.) + | "M" => Some(f *. 1000000.) + | "B" => Some(f *. 1000000000.) + | "T" => Some(f *. 1000000000000.) + | _ => None + } + let rec run = x => + switch x { + | Fn({name: "multiply", args: [Value(f), Symbol(s)]}) as doNothing => + transformWithSymbol(f, s) |> E.O.fmap(r => Value(r)) |> E.O.default(doNothing) + | Fn({name: "unaryMinus", args: [Value(f)]}) => Value(-1.0 *. f) + | Fn({name, args}) => Fn({name: name, args: args |> E.A.fmap(run)}) + | Array(args) => Array(args |> E.A.fmap(run)) + | Symbol(s) => Symbol(s) + | Value(v) => Value(v) + | Blocks(args) => Blocks(args |> E.A.fmap(run)) + | Assignment(a, b) => Assignment(a, run(b)) + | FunctionAssignment(a) => FunctionAssignment(a) + | Object(v) => + Object( + v + |> Js.Dict.entries + |> E.A.fmap(((key, value)) => (key, run(value))) + |> Js.Dict.fromArray, + ) + } + } + + let lognormal = (args, parseArgs, nodeParser) => + switch args { + | [Object(o)] => + let g = s => + Js.Dict.get(o, s) |> E.O.toResult("Variable was empty") |> E.R.bind(_, nodeParser) + switch (g("mean"), g("stdev"), g("mu"), g("sigma")) { + | (Ok(mean), Ok(stdev), _, _) => + Ok(#FunctionCall("lognormalFromMeanAndStdDev", [mean, stdev])) + | (_, _, Ok(mu), Ok(sigma)) => Ok(#FunctionCall("lognormal", [mu, sigma])) + | _ => Error("Lognormal distribution needs either mean and stdev or mu and sigma") + } + | _ => + parseArgs() |> E.R.fmap((args: array) => + #FunctionCall("lognormal", args) + ) + } + + // Error("Dotwise exponentiation needs two operands") + let operationParser = ( + name: string, + args: result, string>, + ): result => { + let toOkAlgebraic = r => Ok(#AlgebraicCombination(r)) + let toOkPointwise = r => Ok(#PointwiseCombination(r)) + let toOkTruncate = r => Ok(#Truncate(r)) + args |> E.R.bind(_, args => + switch (name, args) { + | ("add", [l, r]) => toOkAlgebraic((#Add, l, r)) + | ("add", _) => Error("Addition needs two operands") + | ("unaryMinus", [l]) => toOkAlgebraic((#Multiply, #SymbolicDist(#Float(-1.0)), l)) + | ("subtract", [l, r]) => toOkAlgebraic((#Subtract, l, r)) + | ("subtract", _) => Error("Subtraction needs two operands") + | ("multiply", [l, r]) => toOkAlgebraic((#Multiply, l, r)) + | ("multiply", _) => Error("Multiplication needs two operands") + | ("pow", [l, r]) => toOkAlgebraic((#Exponentiate, l, r)) + | ("pow", _) => Error("Exponentiation needs two operands") + | ("dotMultiply", [l, r]) => toOkPointwise((#Multiply, l, r)) + | ("dotMultiply", _) => Error("Dotwise multiplication needs two operands") + | ("dotPow", [l, r]) => toOkPointwise((#Exponentiate, l, r)) + | ("dotPow", _) => Error("Dotwise exponentiation needs two operands") + | ("rightLogShift", [l, r]) => toOkPointwise((#Add, l, r)) + | ("rightLogShift", _) => Error("Dotwise addition needs two operands") + | ("divide", [l, r]) => toOkAlgebraic((#Divide, l, r)) + | ("divide", _) => Error("Division needs two operands") + | ("leftTruncate", [d, #SymbolicDist(#Float(lc))]) => toOkTruncate((Some(lc), None, d)) + | ("leftTruncate", _) => + Error("leftTruncate needs two arguments: the expression and the cutoff") + | ("rightTruncate", [d, #SymbolicDist(#Float(rc))]) => toOkTruncate((None, Some(rc), d)) + | ("rightTruncate", _) => + Error("rightTruncate needs two arguments: the expression and the cutoff") + | ("truncate", [d, #SymbolicDist(#Float(lc)), #SymbolicDist(#Float(rc))]) => + toOkTruncate((Some(lc), Some(rc), d)) + | ("truncate", _) => Error("truncate needs three arguments: the expression and both cutoffs") + | _ => Error("This type not currently supported") + } + ) + } + + let functionParser = ( + nodeParser: MathJsonToMathJsAdt.arg => Belt.Result.t< + SquiggleExperimental.ExpressionTypes.ExpressionTree.node, + string, + >, + name: string, + args: array, + ): result => { + let parseArray = ags => ags |> E.A.fmap(nodeParser) |> E.A.R.firstErrorOrOpen + let parseArgs = () => parseArray(args) + switch name { + | "lognormal" => lognormal(args, parseArgs, nodeParser) + | "multimodal" + | "add" + | "subtract" + | "multiply" + | "unaryMinus" + | "dotMultiply" + | "dotPow" + | "rightLogShift" + | "divide" + | "pow" + | "leftTruncate" + | "rightTruncate" + | "truncate" => + operationParser(name, parseArgs()) + | "mm" => + let weights = + args + |> E.A.last + |> E.O.bind(_, x => + switch x { + | Array(values) => Some(parseArray(values)) + | _ => None + } + ) + let possibleDists = E.O.isSome(weights) + ? Belt.Array.slice(args, ~offset=0, ~len=E.A.length(args) - 1) + : args + let dists = parseArray(possibleDists) + switch (weights, dists) { + | (Some(Error(r)), _) => Error(r) + | (_, Error(r)) => Error(r) + | (None, Ok(dists)) => + let hash: ExpressionTypes.ExpressionTree.node = #FunctionCall( + "multimodal", + [#Hash([("dists", #Array(dists)), ("weights", #Array([]))])], + ) + Ok(hash) + | (Some(Ok(weights)), Ok(dists)) => + let hash: ExpressionTypes.ExpressionTree.node = #FunctionCall( + "multimodal", + [#Hash([("dists", #Array(dists)), ("weights", #Array(weights))])], + ) + Ok(hash) + } + | name => + parseArgs() |> E.R.fmap((args: array) => + #FunctionCall(name, args) + ) + } + } + + let rec nodeParser: MathJsonToMathJsAdt.arg => result< + ExpressionTypes.ExpressionTree.node, + string, + > = x => + switch x { + | Value(f) => Ok(#SymbolicDist(#Float(f))) + | Symbol(sym) => Ok(#Symbol(sym)) + | Fn({name, args}) => functionParser(nodeParser, name, args) + | _ => Error("This type not currently supported") + } + + // | FunctionAssignment({name, args, expression}) => { + // let evaluatedExpression = run(expression); + // `Function(_ => Ok(evaluatedExpression)); + // } + let rec topLevel = (r): result => + switch r { + | FunctionAssignment({name, args, expression}) => + switch nodeParser(expression) { + | Ok(r) => Ok([#Assignment(name, #Function(args, r))]) + | Error(r) => Error(r) + } + | Value(_) as r => nodeParser(r) |> E.R.fmap(r => [#Expression(r)]) + | Fn(_) as r => nodeParser(r) |> E.R.fmap(r => [#Expression(r)]) + | Array(_) => Error("Array not valid as top level") + | Symbol(s) => handleSymbol(s) |> E.R.fmap(r => [#Expression(r)]) + | Object(_) => Error("Object not valid as top level") + | Assignment(name, value) => + switch name { + | Symbol(symbol) => nodeParser(value) |> E.R.fmap(r => [#Assignment(symbol, r)]) + | _ => Error("Symbol not a string") + } + | Blocks(blocks) => + blocks |> E.A.fmap(b => topLevel(b)) |> E.A.R.firstErrorOrOpen |> E.R.fmap(E.A.concatMany) + } + + let run = (r): result => + r |> MathAdtCleaner.run |> topLevel +} + +/* The MathJs parser doesn't support '.+' syntax, but we want it because it + would make sense with '.*'. Our workaround is to change this to >>>, which is + logShift in mathJS. We don't expect to use logShift anytime soon, so this tradeoff + seems fine. + */ +let pointwiseToRightLogShift = Js.String.replaceByRe(%re("/\.\+/g"), ">>>") + +let fromString2 = str => { + /* We feed the user-typed string into Mathjs.parseMath, + which returns a JSON with (hopefully) a single-element array. + This array element is the top-level node of a nested-object tree + representing the functions/arguments/values/etc. in the string. + + The function MathJsonToMathJsAdt then recursively unpacks this JSON into a typed data structure we can use. + Inside of this function, MathAdtToDistDst is called whenever a distribution function is encountered. + */ + let mathJsToJson = str |> pointwiseToRightLogShift |> Mathjs.parseMath + + let mathJsParse = E.R.bind(mathJsToJson, r => + switch MathJsonToMathJsAdt.run(r) { + | Some(r) => Ok(r) + | None => Error("MathJsParse Error") + } + ) + + let value = E.R.bind(mathJsParse, MathAdtToDistDst.run) + Js.log2(mathJsParse, value) + value +} + +let fromString = str => fromString2(str) diff --git a/packages/squiggle-lang/src/distPlus/expressionTree/Mathjs.re b/packages/squiggle-lang/src/distPlus/expressionTree/Mathjs.re deleted file mode 100644 index 1b0c7bc7..00000000 --- a/packages/squiggle-lang/src/distPlus/expressionTree/Mathjs.re +++ /dev/null @@ -1,10 +0,0 @@ -[@bs.module "./MathjsWrapper.js"] -external parseMathExt: string => Js.Json.t = "parseMath"; - -let parseMath = (str: string): result(Js.Json.t, string) => - switch (parseMathExt(str)) { - | exception (Js.Exn.Error(err)) => - Error(Js.Exn.message(err) |> E.O.default("MathJS Parse Error")) - | exception _ => Error("MathJS Parse Error") - | j => Ok(j) - }; \ No newline at end of file diff --git a/packages/squiggle-lang/src/distPlus/expressionTree/Mathjs.res b/packages/squiggle-lang/src/distPlus/expressionTree/Mathjs.res new file mode 100644 index 00000000..5e1be176 --- /dev/null +++ b/packages/squiggle-lang/src/distPlus/expressionTree/Mathjs.res @@ -0,0 +1,9 @@ +@module("./MathjsWrapper.js") +external parseMathExt: string => Js.Json.t = "parseMath" + +let parseMath = (str: string): result => + switch parseMathExt(str) { + | exception Js.Exn.Error(err) => Error(Js.Exn.message(err) |> E.O.default("MathJS Parse Error")) + | exception _ => Error("MathJS Parse Error") + | j => Ok(j) + } diff --git a/packages/squiggle-lang/src/distPlus/expressionTree/Operation.re b/packages/squiggle-lang/src/distPlus/expressionTree/Operation.re deleted file mode 100644 index 6f93af1a..00000000 --- a/packages/squiggle-lang/src/distPlus/expressionTree/Operation.re +++ /dev/null @@ -1,105 +0,0 @@ -open ExpressionTypes; - -module Algebraic = { - type t = algebraicOperation; - let toFn: (t, float, float) => float = - fun - | `Add => (+.) - | `Subtract => (-.) - | `Multiply => ( *. ) - | `Exponentiate => ( ** ) - | `Divide => (/.); - - let applyFn = (t, f1, f2) => { - switch (t, f1, f2) { - | (`Divide, _, 0.) => Error("Cannot divide $v1 by zero.") - | _ => Ok(toFn(t, f1, f2)) - }; - }; - - let toString = - fun - | `Add => "+" - | `Subtract => "-" - | `Multiply => "*" - | `Exponentiate => ( "**" ) - | `Divide => "/"; - - let format = (a, b, c) => b ++ " " ++ toString(a) ++ " " ++ c; -}; - -module Pointwise = { - type t = pointwiseOperation; - let toString = - fun - | `Add => "+" - | `Exponentiate => "^" - | `Multiply => "*"; - - let format = (a, b, c) => b ++ " " ++ toString(a) ++ " " ++ c; -}; - -module DistToFloat = { - type t = distToFloatOperation; - - let format = (operation, value) => - switch (operation) { - | `Cdf(f) => {j|cdf(x=$f,$value)|j} - | `Pdf(f) => {j|pdf(x=$f,$value)|j} - | `Inv(f) => {j|inv(x=$f,$value)|j} - | `Sample => "sample($value)" - | `Mean => "mean($value)" - }; -}; - -// Note that different logarithms don't really do anything. -module Scale = { - type t = scaleOperation; - let toFn = - fun - | `Multiply => ( *. ) - | `Exponentiate => ( ** ) - | `Log => ((a, b) => log(a) /. log(b)); - - let format = (operation: t, value, scaleBy) => - switch (operation) { - | `Multiply => {j|verticalMultiply($value, $scaleBy) |j} - | `Exponentiate => {j|verticalExponentiate($value, $scaleBy) |j} - | `Log => {j|verticalLog($value, $scaleBy) |j} - }; - - let toIntegralSumCacheFn = - fun - | `Multiply => ((a, b) => Some(a *. b)) - | `Exponentiate => ((_, _) => None) - | `Log => ((_, _) => None); - - let toIntegralCacheFn = - fun - | `Multiply => ((a, b) => None) // TODO: this could probably just be multiplied out (using Continuous.scaleBy) - | `Exponentiate => ((_, _) => None) - | `Log => ((_, _) => None); -}; - -module T = { - let truncateToString = - (left: option(float), right: option(float), nodeToString) => { - let left = left |> E.O.dimap(Js.Float.toString, () => "-inf"); - let right = right |> E.O.dimap(Js.Float.toString, () => "inf"); - {j|truncate($nodeToString, $left, $right)|j}; - }; - let toString = nodeToString => - fun - | `AlgebraicCombination(op, t1, t2) => - Algebraic.format(op, nodeToString(t1), nodeToString(t2)) - | `PointwiseCombination(op, t1, t2) => - Pointwise.format(op, nodeToString(t1), nodeToString(t2)) - | `VerticalScaling(scaleOp, t, scaleBy) => - Scale.format(scaleOp, nodeToString(t), nodeToString(scaleBy)) - | `Normalize(t) => "normalize(k" ++ nodeToString(t) ++ ")" - | `FloatFromDist(floatFromDistOp, t) => - DistToFloat.format(floatFromDistOp, nodeToString(t)) - | `Truncate(lc, rc, t) => truncateToString(lc, rc, nodeToString(t)) - | `Render(t) => nodeToString(t) - | _ => ""; // SymbolicDist and RenderedDist are handled in ExpressionTree.toString. -}; diff --git a/packages/squiggle-lang/src/distPlus/expressionTree/Operation.res b/packages/squiggle-lang/src/distPlus/expressionTree/Operation.res new file mode 100644 index 00000000..0e0b51ed --- /dev/null +++ b/packages/squiggle-lang/src/distPlus/expressionTree/Operation.res @@ -0,0 +1,107 @@ +open ExpressionTypes + +module Algebraic = { + type t = algebraicOperation + let toFn: (t, float, float) => float = x => + switch x { + | #Add => \"+." + | #Subtract => \"-." + | #Multiply => \"*." + | #Exponentiate => \"**" + | #Divide => \"/." + } + + let applyFn = (t, f1, f2) => + switch (t, f1, f2) { + | (#Divide, _, 0.) => Error("Cannot divide $v1 by zero.") + | _ => Ok(toFn(t, f1, f2)) + } + + let toString = x => + switch x { + | #Add => "+" + | #Subtract => "-" + | #Multiply => "*" + | #Exponentiate => "**" + | #Divide => "/" + } + + let format = (a, b, c) => b ++ (" " ++ (toString(a) ++ (" " ++ c))) +} + +module Pointwise = { + type t = pointwiseOperation + let toString = x => + switch x { + | #Add => "+" + | #Exponentiate => "^" + | #Multiply => "*" + } + + let format = (a, b, c) => b ++ (" " ++ (toString(a) ++ (" " ++ c))) +} + +module DistToFloat = { + type t = distToFloatOperation + + let format = (operation, value) => + switch operation { + | #Cdf(f) => j`cdf(x=$f,$value)` + | #Pdf(f) => j`pdf(x=$f,$value)` + | #Inv(f) => j`inv(x=$f,$value)` + | #Sample => "sample($value)" + | #Mean => "mean($value)" + } +} + +// Note that different logarithms don't really do anything. +module Scale = { + type t = scaleOperation + let toFn = x => + switch x { + | #Multiply => \"*." + | #Exponentiate => \"**" + | #Log => (a, b) => log(a) /. log(b) + } + + let format = (operation: t, value, scaleBy) => + switch operation { + | #Multiply => j`verticalMultiply($value, $scaleBy) ` + | #Exponentiate => j`verticalExponentiate($value, $scaleBy) ` + | #Log => j`verticalLog($value, $scaleBy) ` + } + + let toIntegralSumCacheFn = x => + switch x { + | #Multiply => (a, b) => Some(a *. b) + | #Exponentiate => (_, _) => None + | #Log => (_, _) => None + } + + let toIntegralCacheFn = x => + switch x { + | #Multiply => (a, b) => None // TODO: this could probably just be multiplied out (using Continuous.scaleBy) + | #Exponentiate => (_, _) => None + | #Log => (_, _) => None + } +} + +module T = { + let truncateToString = (left: option, right: option, nodeToString) => { + let left = left |> E.O.dimap(Js.Float.toString, () => "-inf") + let right = right |> E.O.dimap(Js.Float.toString, () => "inf") + j`truncate($nodeToString, $left, $right)` + } + let toString = (nodeToString, x) => + switch x { + | #AlgebraicCombination(op, t1, t2) => Algebraic.format(op, nodeToString(t1), nodeToString(t2)) + | #PointwiseCombination(op, t1, t2) => Pointwise.format(op, nodeToString(t1), nodeToString(t2)) + | #VerticalScaling(scaleOp, t, scaleBy) => + Scale.format(scaleOp, nodeToString(t), nodeToString(scaleBy)) + | #Normalize(t) => "normalize(k" ++ (nodeToString(t) ++ ")") + | #FloatFromDist(floatFromDistOp, t) => DistToFloat.format(floatFromDistOp, nodeToString(t)) + | #Truncate(lc, rc, t) => truncateToString(lc, rc, nodeToString(t)) + | #Render(t) => nodeToString(t) + | _ => "" + } // SymbolicDist and RenderedDist are handled in ExpressionTree.toString. +} diff --git a/packages/squiggle-lang/src/distPlus/expressionTree/PTypes.re b/packages/squiggle-lang/src/distPlus/expressionTree/PTypes.re deleted file mode 100644 index 90c92dc0..00000000 --- a/packages/squiggle-lang/src/distPlus/expressionTree/PTypes.re +++ /dev/null @@ -1,143 +0,0 @@ -open ExpressionTypes.ExpressionTree; - -module Function = { - type t = (array(string), node); - let fromNode: node => option(t) = - node => - switch (node) { - | `Function(r) => Some(r) - | _ => None - }; - let argumentNames = ((a, _): t) => a; - let internals = ((_, b): t) => b; - let run = - ( - evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams, - args: array(node), - t: t, - ) => - if (E.A.length(args) == E.A.length(argumentNames(t))) { - let newEnvironment = - Belt.Array.zip(argumentNames(t), args) - |> ExpressionTypes.ExpressionTree.Environment.fromArray; - let newEvaluationParams: ExpressionTypes.ExpressionTree.evaluationParams = { - samplingInputs: evaluationParams.samplingInputs, - environment: - ExpressionTypes.ExpressionTree.Environment.mergeKeepSecond( - evaluationParams.environment, - newEnvironment, - ), - evaluateNode: evaluationParams.evaluateNode, - }; - evaluationParams.evaluateNode(newEvaluationParams, internals(t)); - } else { - Error("Wrong number of variables"); - }; - -}; - -module Primative = { - type t = [ - | `SymbolicDist(SymbolicTypes.symbolicDist) - | `RenderedDist(DistTypes.shape) - | `Function(array(string), node) - ]; - - let isPrimative: node => bool = - fun - | `SymbolicDist(_) - | `RenderedDist(_) - | `Function(_) => true - | _ => false; - - let fromNode: node => option(t) = - fun - | `SymbolicDist(_) as n - | `RenderedDist(_) as n - | `Function(_) as n => Some(n) - | _ => None; -}; - -module SamplingDistribution = { - type t = [ - | `SymbolicDist(SymbolicTypes.symbolicDist) - | `RenderedDist(DistTypes.shape) - ]; - - let isSamplingDistribution: node => bool = - fun - | `SymbolicDist(_) => true - | `RenderedDist(_) => true - | _ => false; - - let fromNode: node => result(t, string) = - fun - | `SymbolicDist(n) => Ok(`SymbolicDist(n)) - | `RenderedDist(n) => Ok(`RenderedDist(n)) - | _ => Error("Not valid type"); - - let renderIfIsNotSamplingDistribution = (params, t): result(node, string) => - !isSamplingDistribution(t) - ? switch (Render.render(params, t)) { - | Ok(r) => Ok(r) - | Error(e) => Error(e) - } - : Ok(t); - - let map = (~renderedDistFn, ~symbolicDistFn, node: node) => - node - |> ( - fun - | `RenderedDist(r) => Some(renderedDistFn(r)) - | `SymbolicDist(s) => Some(symbolicDistFn(s)) - | _ => None - ); - - let sampleN = n => - map( - ~renderedDistFn=Shape.sampleNRendered(n), - ~symbolicDistFn=SymbolicDist.T.sampleN(n), - ); - - let getCombinationSamples = (n, algebraicOp, t1: node, t2: node) => { - switch (sampleN(n, t1), sampleN(n, t2)) { - | (Some(a), Some(b)) => - Some( - Belt.Array.zip(a, b) - |> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(algebraicOp, a, b)), - ) - | _ => None - }; - }; - - let combineShapesUsingSampling = - (evaluationParams: evaluationParams, algebraicOp, t1: node, t2: node) => { - let i1 = renderIfIsNotSamplingDistribution(evaluationParams, t1); - let i2 = renderIfIsNotSamplingDistribution(evaluationParams, t2); - E.R.merge(i1, i2) - |> E.R.bind( - _, - ((a, b)) => { - let samples = - getCombinationSamples( - evaluationParams.samplingInputs.sampleCount, - algebraicOp, - a, - b, - ); - - // todo: This bottom part should probably be somewhere else. - let shape = - samples - |> E.O.fmap( - SamplesToShape.fromSamples( - ~samplingInputs=evaluationParams.samplingInputs, - ), - ) - |> E.O.bind(_, r => r.shape) - |> E.O.toResult("No response"); - shape |> E.R.fmap(r => `Normalize(`RenderedDist(r))); - }, - ); - }; -}; diff --git a/packages/squiggle-lang/src/distPlus/expressionTree/PTypes.res b/packages/squiggle-lang/src/distPlus/expressionTree/PTypes.res new file mode 100644 index 00000000..9d804de0 --- /dev/null +++ b/packages/squiggle-lang/src/distPlus/expressionTree/PTypes.res @@ -0,0 +1,137 @@ +open ExpressionTypes.ExpressionTree + +module Function = { + type t = (array, node) + let fromNode: node => option = node => + switch node { + | #Function(r) => Some(r) + | _ => None + } + let argumentNames = ((a, _): t) => a + let internals = ((_, b): t) => b + let run = ( + evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams, + args: array, + t: t, + ) => + if E.A.length(args) == E.A.length(argumentNames(t)) { + let newEnvironment = + Belt.Array.zip( + argumentNames(t), + args, + ) |> ExpressionTypes.ExpressionTree.Environment.fromArray + let newEvaluationParams: ExpressionTypes.ExpressionTree.evaluationParams = { + samplingInputs: evaluationParams.samplingInputs, + environment: ExpressionTypes.ExpressionTree.Environment.mergeKeepSecond( + evaluationParams.environment, + newEnvironment, + ), + evaluateNode: evaluationParams.evaluateNode, + } + evaluationParams.evaluateNode(newEvaluationParams, internals(t)) + } else { + Error("Wrong number of variables") + } +} + +module Primative = { + type t = [ + | #SymbolicDist(SymbolicTypes.symbolicDist) + | #RenderedDist(DistTypes.shape) + | #Function(array, node) + ] + + let isPrimative: node => bool = x => + switch x { + | #SymbolicDist(_) + | #RenderedDist(_) + | #Function(_) => true + | _ => false + } + + let fromNode: node => option = x => + switch x { + | #SymbolicDist(_) as n + | #RenderedDist(_) as n + | #Function(_) as n => + Some(n) + | _ => None + } +} + +module SamplingDistribution = { + type t = [ + | #SymbolicDist(SymbolicTypes.symbolicDist) + | #RenderedDist(DistTypes.shape) + ] + + let isSamplingDistribution: node => bool = x => + switch x { + | #SymbolicDist(_) => true + | #RenderedDist(_) => true + | _ => false + } + + let fromNode: node => result = x => + switch x { + | #SymbolicDist(n) => Ok(#SymbolicDist(n)) + | #RenderedDist(n) => Ok(#RenderedDist(n)) + | _ => Error("Not valid type") + } + + let renderIfIsNotSamplingDistribution = (params, t): result => + !isSamplingDistribution(t) + ? switch Render.render(params, t) { + | Ok(r) => Ok(r) + | Error(e) => Error(e) + } + : Ok(t) + + let map = (~renderedDistFn, ~symbolicDistFn, node: node) => + node |> ( + x => + switch x { + | #RenderedDist(r) => Some(renderedDistFn(r)) + | #SymbolicDist(s) => Some(symbolicDistFn(s)) + | _ => None + } + ) + + let sampleN = n => + map(~renderedDistFn=Shape.sampleNRendered(n), ~symbolicDistFn=SymbolicDist.T.sampleN(n)) + + let getCombinationSamples = (n, algebraicOp, t1: node, t2: node) => + switch (sampleN(n, t1), sampleN(n, t2)) { + | (Some(a), Some(b)) => + Some( + Belt.Array.zip(a, b) |> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(algebraicOp, a, b)), + ) + | _ => None + } + + let combineShapesUsingSampling = ( + evaluationParams: evaluationParams, + algebraicOp, + t1: node, + t2: node, + ) => { + let i1 = renderIfIsNotSamplingDistribution(evaluationParams, t1) + let i2 = renderIfIsNotSamplingDistribution(evaluationParams, t2) + E.R.merge(i1, i2) |> E.R.bind(_, ((a, b)) => { + let samples = getCombinationSamples( + evaluationParams.samplingInputs.sampleCount, + algebraicOp, + a, + b, + ) + + // todo: This bottom part should probably be somewhere else. + let shape = + samples + |> E.O.fmap(SamplesToShape.fromSamples(~samplingInputs=evaluationParams.samplingInputs)) + |> E.O.bind(_, r => r.shape) + |> E.O.toResult("No response") + shape |> E.R.fmap(r => #Normalize(#RenderedDist(r))) + }) + } +} diff --git a/packages/squiggle-lang/src/distPlus/expressionTree/Program.re b/packages/squiggle-lang/src/distPlus/expressionTree/Program.re deleted file mode 100644 index 7adfb3d8..00000000 --- a/packages/squiggle-lang/src/distPlus/expressionTree/Program.re +++ /dev/null @@ -1,5 +0,0 @@ -type t = ExpressionTypes.Program.program; - -let last = (r:t) => E.A.last(r) |> E.O.toResult("No rendered lines"); -// let run = (p:program) => p |> E.A.last |> E.O.fmap(r => -// ) \ No newline at end of file diff --git a/packages/squiggle-lang/src/distPlus/expressionTree/Program.res b/packages/squiggle-lang/src/distPlus/expressionTree/Program.res new file mode 100644 index 00000000..d3404712 --- /dev/null +++ b/packages/squiggle-lang/src/distPlus/expressionTree/Program.res @@ -0,0 +1,5 @@ +type t = ExpressionTypes.Program.program + +let last = (r: t) => E.A.last(r) |> E.O.toResult("No rendered lines") +// let run = (p:program) => p |> E.A.last |> E.O.fmap(r => +// ) diff --git a/packages/squiggle-lang/src/distPlus/samplesRenderer/Bandwidth.re b/packages/squiggle-lang/src/distPlus/samplesRenderer/Bandwidth.re deleted file mode 100644 index d2315639..00000000 --- a/packages/squiggle-lang/src/distPlus/samplesRenderer/Bandwidth.re +++ /dev/null @@ -1,30 +0,0 @@ -//The math here was taken from https://github.com/jasondavies/science.js/blob/master/src/stats/bandwidth.js - -let len = x => E.A.length(x) |> float_of_int; - -let iqr = x => { - Jstat.percentile(x, 0.75, true) -. Jstat.percentile(x, 0.25, true); -}; - -// Silverman, B. W. (1986) Density Estimation. London: Chapman and Hall. -let nrd0 = x => { - let hi = Js_math.sqrt(Jstat.variance(x)); - let lo = Js_math.minMany_float([|hi, iqr(x) /. 1.34|]); - let e = Js_math.abs_float(x[1]); - let lo' = - switch (lo, hi, e) { - | (lo, _, _) when !Js.Float.isNaN(lo) => lo - | (_, hi, _) when !Js.Float.isNaN(hi) => hi - | (_, _, e) when !Js.Float.isNaN(e) => e - | _ => 1.0 - }; - 0.9 *. lo' *. Js.Math.pow_float(~base=len(x), ~exp=-0.2); -}; - -// Scott, D. W. (1992) Multivariate Density Estimation: Theory, Practice, and Visualization. Wiley. -let nrd = x => { - let h = iqr(x) /. 1.34; - 1.06 - *. Js.Math.min_float(Js.Math.sqrt(Jstat.variance(x)), h) - *. Js.Math.pow_float(~base=len(x), ~exp=(-1.0) /. 5.0); -}; \ No newline at end of file diff --git a/packages/squiggle-lang/src/distPlus/samplesRenderer/Bandwidth.res b/packages/squiggle-lang/src/distPlus/samplesRenderer/Bandwidth.res new file mode 100644 index 00000000..6650b862 --- /dev/null +++ b/packages/squiggle-lang/src/distPlus/samplesRenderer/Bandwidth.res @@ -0,0 +1,27 @@ +//The math here was taken from https://github.com/jasondavies/science.js/blob/master/src/stats/bandwidth.js + +let len = x => E.A.length(x) |> float_of_int + +let iqr = x => Jstat.percentile(x, 0.75, true) -. Jstat.percentile(x, 0.25, true) + +// Silverman, B. W. (1986) Density Estimation. London: Chapman and Hall. +let nrd0 = x => { + let hi = Js_math.sqrt(Jstat.variance(x)) + let lo = Js_math.minMany_float([hi, iqr(x) /. 1.34]) + let e = Js_math.abs_float(x[1]) + let lo' = switch (lo, hi, e) { + | (lo, _, _) if !Js.Float.isNaN(lo) => lo + | (_, hi, _) if !Js.Float.isNaN(hi) => hi + | (_, _, e) if !Js.Float.isNaN(e) => e + | _ => 1.0 + } + 0.9 *. lo' *. Js.Math.pow_float(~base=len(x), ~exp=-0.2) +} + +// Scott, D. W. (1992) Multivariate Density Estimation: Theory, Practice, and Visualization. Wiley. +let nrd = x => { + let h = iqr(x) /. 1.34 + 1.06 *. + Js.Math.min_float(Js.Math.sqrt(Jstat.variance(x)), h) *. + Js.Math.pow_float(~base=len(x), ~exp=-1.0 /. 5.0) +} diff --git a/packages/squiggle-lang/src/distPlus/samplesRenderer/SamplesToShape.re b/packages/squiggle-lang/src/distPlus/samplesRenderer/SamplesToShape.re deleted file mode 100644 index 7b3c231f..00000000 --- a/packages/squiggle-lang/src/distPlus/samplesRenderer/SamplesToShape.re +++ /dev/null @@ -1,164 +0,0 @@ -module Internals = { - module Types = { - type samplingStats = { - sampleCount: int, - outputXYPoints: int, - bandwidthXSuggested: float, - bandwidthUnitSuggested: float, - bandwidthXImplemented: float, - bandwidthUnitImplemented: float, - }; - - type outputs = { - continuousParseParams: option(samplingStats), - shape: option(DistTypes.shape), - }; - }; - - module JS = { - [@bs.deriving abstract] - type distJs = { - xs: array(float), - ys: array(float), - }; - - let jsToDist = (d: distJs): DistTypes.xyShape => { - xs: xsGet(d), - ys: ysGet(d), - }; - - [@bs.module "./KdeLibrary.js"] - external samplesToContinuousPdf: (array(float), int, int) => distJs = - "samplesToContinuousPdf"; - }; - - module KDE = { - let normalSampling = (samples, outputXYPoints, kernelWidth) => { - samples - |> JS.samplesToContinuousPdf(_, outputXYPoints, kernelWidth) - |> JS.jsToDist; - }; - }; - - module T = { - type t = array(float); - - let splitContinuousAndDiscrete = (sortedArray: t) => { - let continuous = [||]; - let discrete = E.FloatFloatMap.empty(); - Belt.Array.forEachWithIndex( - sortedArray, - (index, element) => { - let maxIndex = (sortedArray |> Array.length) - 1; - let possiblySimilarElements = - ( - switch (index) { - | 0 => [|index + 1|] - | n when n == maxIndex => [|index - 1|] - | _ => [|index - 1, index + 1|] - } - ) - |> Belt.Array.map(_, r => sortedArray[r]); - let hasSimilarElement = - Belt.Array.some(possiblySimilarElements, r => r == element); - hasSimilarElement - ? E.FloatFloatMap.increment(element, discrete) - : { - let _ = Js.Array.push(element, continuous); - (); - }; - (); - }, - ); - (continuous, discrete); - }; - - let xWidthToUnitWidth = (samples, outputXYPoints, xWidth) => { - let xyPointRange = E.A.Sorted.range(samples) |> E.O.default(0.0); - let xyPointWidth = xyPointRange /. float_of_int(outputXYPoints); - xWidth /. xyPointWidth; - }; - - let formatUnitWidth = w => Jstat.max([|w, 1.0|]) |> int_of_float; - - let suggestedUnitWidth = (samples, outputXYPoints) => { - let suggestedXWidth = Bandwidth.nrd0(samples); - xWidthToUnitWidth(samples, outputXYPoints, suggestedXWidth); - }; - - let kde = (~samples, ~outputXYPoints, width) => { - KDE.normalSampling(samples, outputXYPoints, width); - }; - }; -}; - -let toShape = - ( - ~samples: Internals.T.t, - ~samplingInputs: ExpressionTypes.ExpressionTree.samplingInputs, - (), - ) => { - Array.fast_sort(compare, samples); - let (continuousPart, discretePart) = E.A.Sorted.Floats.split(samples); - let length = samples |> E.A.length |> float_of_int; - let discrete: DistTypes.discreteShape = - discretePart - |> E.FloatFloatMap.fmap(r => r /. length) - |> E.FloatFloatMap.toArray - |> XYShape.T.fromZippedArray - |> Discrete.make; - - let pdf = - continuousPart |> E.A.length > 5 - ? { - let _suggestedXWidth = Bandwidth.nrd0(continuousPart); - // todo: This does some recalculating from the last step. - let _suggestedUnitWidth = - Internals.T.suggestedUnitWidth( - continuousPart, - samplingInputs.outputXYPoints, - ); - let usedWidth = - samplingInputs.kernelWidth |> E.O.default(_suggestedXWidth); - let usedUnitWidth = - Internals.T.xWidthToUnitWidth( - samples, - samplingInputs.outputXYPoints, - usedWidth, - ); - let samplingStats: Internals.Types.samplingStats = { - sampleCount: samplingInputs.sampleCount, - outputXYPoints: samplingInputs.outputXYPoints, - bandwidthXSuggested: _suggestedXWidth, - bandwidthUnitSuggested: _suggestedUnitWidth, - bandwidthXImplemented: usedWidth, - bandwidthUnitImplemented: usedUnitWidth, - }; - continuousPart - |> Internals.T.kde( - ~samples=_, - ~outputXYPoints=samplingInputs.outputXYPoints, - Internals.T.formatUnitWidth(usedUnitWidth), - ) - |> Continuous.make - |> (r => Some((r, samplingStats))); - } - : None; - - let shape = - MixedShapeBuilder.buildSimple( - ~continuous=pdf |> E.O.fmap(fst), - ~discrete=Some(discrete), - ); - - let samplesParse: Internals.Types.outputs = { - continuousParseParams: pdf |> E.O.fmap(snd), - shape, - }; - - samplesParse; -}; - -let fromSamples = (~samplingInputs, samples) => { - toShape(~samples, ~samplingInputs, ()); -}; diff --git a/packages/squiggle-lang/src/distPlus/samplesRenderer/SamplesToShape.res b/packages/squiggle-lang/src/distPlus/samplesRenderer/SamplesToShape.res new file mode 100644 index 00000000..433ba91d --- /dev/null +++ b/packages/squiggle-lang/src/distPlus/samplesRenderer/SamplesToShape.res @@ -0,0 +1,143 @@ +module Internals = { + module Types = { + type samplingStats = { + sampleCount: int, + outputXYPoints: int, + bandwidthXSuggested: float, + bandwidthUnitSuggested: float, + bandwidthXImplemented: float, + bandwidthUnitImplemented: float, + } + + type outputs = { + continuousParseParams: option, + shape: option, + } + } + + module JS = { + @deriving(abstract) + type distJs = { + xs: array, + ys: array, + } + + let jsToDist = (d: distJs): DistTypes.xyShape => { + xs: xsGet(d), + ys: ysGet(d), + } + + @module("./KdeLibrary.js") + external samplesToContinuousPdf: (array, int, int) => distJs = "samplesToContinuousPdf" + } + + module KDE = { + let normalSampling = (samples, outputXYPoints, kernelWidth) => + samples |> JS.samplesToContinuousPdf(_, outputXYPoints, kernelWidth) |> JS.jsToDist + } + + module T = { + type t = array + + let splitContinuousAndDiscrete = (sortedArray: t) => { + let continuous = [] + let discrete = E.FloatFloatMap.empty() + Belt.Array.forEachWithIndex(sortedArray, (index, element) => { + let maxIndex = (sortedArray |> Array.length) - 1 + let possiblySimilarElements = switch index { + | 0 => [index + 1] + | n if n == maxIndex => [index - 1] + | _ => [index - 1, index + 1] + } |> Belt.Array.map(_, r => sortedArray[r]) + let hasSimilarElement = Belt.Array.some(possiblySimilarElements, r => r == element) + hasSimilarElement + ? E.FloatFloatMap.increment(element, discrete) + : { + let _ = Js.Array.push(element, continuous) + } + () + }) + (continuous, discrete) + } + + let xWidthToUnitWidth = (samples, outputXYPoints, xWidth) => { + let xyPointRange = E.A.Sorted.range(samples) |> E.O.default(0.0) + let xyPointWidth = xyPointRange /. float_of_int(outputXYPoints) + xWidth /. xyPointWidth + } + + let formatUnitWidth = w => Jstat.max([w, 1.0]) |> int_of_float + + let suggestedUnitWidth = (samples, outputXYPoints) => { + let suggestedXWidth = Bandwidth.nrd0(samples) + xWidthToUnitWidth(samples, outputXYPoints, suggestedXWidth) + } + + let kde = (~samples, ~outputXYPoints, width) => + KDE.normalSampling(samples, outputXYPoints, width) + } +} + +let toShape = ( + ~samples: Internals.T.t, + ~samplingInputs: ExpressionTypes.ExpressionTree.samplingInputs, + (), +) => { + Array.fast_sort(compare, samples) + let (continuousPart, discretePart) = E.A.Sorted.Floats.split(samples) + let length = samples |> E.A.length |> float_of_int + let discrete: DistTypes.discreteShape = + discretePart + |> E.FloatFloatMap.fmap(r => r /. length) + |> E.FloatFloatMap.toArray + |> XYShape.T.fromZippedArray + |> Discrete.make + + let pdf = + continuousPart |> E.A.length > 5 + ? { + let _suggestedXWidth = Bandwidth.nrd0(continuousPart) + // todo: This does some recalculating from the last step. + let _suggestedUnitWidth = Internals.T.suggestedUnitWidth( + continuousPart, + samplingInputs.outputXYPoints, + ) + let usedWidth = samplingInputs.kernelWidth |> E.O.default(_suggestedXWidth) + let usedUnitWidth = Internals.T.xWidthToUnitWidth( + samples, + samplingInputs.outputXYPoints, + usedWidth, + ) + let samplingStats: Internals.Types.samplingStats = { + sampleCount: samplingInputs.sampleCount, + outputXYPoints: samplingInputs.outputXYPoints, + bandwidthXSuggested: _suggestedXWidth, + bandwidthUnitSuggested: _suggestedUnitWidth, + bandwidthXImplemented: usedWidth, + bandwidthUnitImplemented: usedUnitWidth, + } + continuousPart + |> Internals.T.kde( + ~samples=_, + ~outputXYPoints=samplingInputs.outputXYPoints, + Internals.T.formatUnitWidth(usedUnitWidth), + ) + |> Continuous.make + |> (r => Some((r, samplingStats))) + } + : None + + let shape = MixedShapeBuilder.buildSimple( + ~continuous=pdf |> E.O.fmap(fst), + ~discrete=Some(discrete), + ) + + let samplesParse: Internals.Types.outputs = { + continuousParseParams: pdf |> E.O.fmap(snd), + shape: shape, + } + + samplesParse +} + +let fromSamples = (~samplingInputs, samples) => toShape(~samples, ~samplingInputs, ()) diff --git a/packages/squiggle-lang/src/distPlus/symbolic/SymbolicDist.re b/packages/squiggle-lang/src/distPlus/symbolic/SymbolicDist.re deleted file mode 100644 index 33942187..00000000 --- a/packages/squiggle-lang/src/distPlus/symbolic/SymbolicDist.re +++ /dev/null @@ -1,346 +0,0 @@ -open SymbolicTypes; - -module Exponential = { - type t = exponential; - let make = (rate:float): symbolicDist => - `Exponential( - { - rate:rate - }, - ); - let pdf = (x, t: t) => Jstat.exponential##pdf(x, t.rate); - let cdf = (x, t: t) => Jstat.exponential##cdf(x, t.rate); - let inv = (p, t: t) => Jstat.exponential##inv(p, t.rate); - let sample = (t: t) => Jstat.exponential##sample(t.rate); - let mean = (t: t) => Ok(Jstat.exponential##mean(t.rate)); - let toString = ({rate}: t) => {j|Exponential($rate)|j}; -}; - -module Cauchy = { - type t = cauchy; - let make = (local, scale): symbolicDist => `Cauchy({local, scale}); - let pdf = (x, t: t) => Jstat.cauchy##pdf(x, t.local, t.scale); - let cdf = (x, t: t) => Jstat.cauchy##cdf(x, t.local, t.scale); - let inv = (p, t: t) => Jstat.cauchy##inv(p, t.local, t.scale); - let sample = (t: t) => Jstat.cauchy##sample(t.local, t.scale); - let mean = (_: t) => Error("Cauchy distributions have no mean value."); - let toString = ({local, scale}: t) => {j|Cauchy($local, $scale)|j}; -}; - -module Triangular = { - type t = triangular; - let make = (low, medium, high): result(symbolicDist, string) => - low < medium && medium < high - ? Ok(`Triangular({low, medium, high})) - : Error("Triangular values must be increasing order"); - let pdf = (x, t: t) => Jstat.triangular##pdf(x, t.low, t.high, t.medium); - let cdf = (x, t: t) => Jstat.triangular##cdf(x, t.low, t.high, t.medium); - let inv = (p, t: t) => Jstat.triangular##inv(p, t.low, t.high, t.medium); - let sample = (t: t) => Jstat.triangular##sample(t.low, t.high, t.medium); - let mean = (t: t) => Ok(Jstat.triangular##mean(t.low, t.high, t.medium)); - let toString = ({low, medium, high}: t) => {j|Triangular($low, $medium, $high)|j}; -}; - -module Normal = { - type t = normal; - let make = (mean, stdev): symbolicDist => `Normal({mean, stdev}); - let pdf = (x, t: t) => Jstat.normal##pdf(x, t.mean, t.stdev); - let cdf = (x, t: t) => Jstat.normal##cdf(x, t.mean, t.stdev); - - let from90PercentCI = (low, high) => { - let mean = E.A.Floats.mean([|low, high|]); - let stdev = (high -. low) /. (2. *. 1.644854); - `Normal({mean, stdev}); - }; - let inv = (p, t: t) => Jstat.normal##inv(p, t.mean, t.stdev); - let sample = (t: t) => Jstat.normal##sample(t.mean, t.stdev); - let mean = (t: t) => Ok(Jstat.normal##mean(t.mean, t.stdev)); - let toString = ({mean, stdev}: t) => {j|Normal($mean,$stdev)|j}; - - let add = (n1: t, n2: t) => { - let mean = n1.mean +. n2.mean; - let stdev = sqrt(n1.stdev ** 2. +. n2.stdev ** 2.); - `Normal({mean, stdev}); - }; - let subtract = (n1: t, n2: t) => { - let mean = n1.mean -. n2.mean; - let stdev = sqrt(n1.stdev ** 2. +. n2.stdev ** 2.); - `Normal({mean, stdev}); - }; - - // TODO: is this useful here at all? would need the integral as well ... - let pointwiseProduct = (n1: t, n2: t) => { - let mean = - (n1.mean *. n2.stdev ** 2. +. n2.mean *. n1.stdev ** 2.) - /. (n1.stdev ** 2. +. n2.stdev ** 2.); - let stdev = 1. /. (1. /. n1.stdev ** 2. +. 1. /. n2.stdev ** 2.); - `Normal({mean, stdev}); - }; - - let operate = (operation: Operation.Algebraic.t, n1: t, n2: t) => - switch (operation) { - | `Add => Some(add(n1, n2)) - | `Subtract => Some(subtract(n1, n2)) - | _ => None - }; -}; - -module Beta = { - type t = beta; - let make = (alpha, beta) => `Beta({alpha, beta}); - let pdf = (x, t: t) => Jstat.beta##pdf(x, t.alpha, t.beta); - let cdf = (x, t: t) => Jstat.beta##cdf(x, t.alpha, t.beta); - let inv = (p, t: t) => Jstat.beta##inv(p, t.alpha, t.beta); - let sample = (t: t) => Jstat.beta##sample(t.alpha, t.beta); - let mean = (t: t) => Ok(Jstat.beta##mean(t.alpha, t.beta)); - let toString = ({alpha, beta}: t) => {j|Beta($alpha,$beta)|j}; -}; - -module Lognormal = { - type t = lognormal; - let make = (mu, sigma) => `Lognormal({mu, sigma}); - let pdf = (x, t: t) => Jstat.lognormal##pdf(x, t.mu, t.sigma); - let cdf = (x, t: t) => Jstat.lognormal##cdf(x, t.mu, t.sigma); - let inv = (p, t: t) => Jstat.lognormal##inv(p, t.mu, t.sigma); - let mean = (t: t) => Ok(Jstat.lognormal##mean(t.mu, t.sigma)); - let sample = (t: t) => Jstat.lognormal##sample(t.mu, t.sigma); - let toString = ({mu, sigma}: t) => {j|Lognormal($mu,$sigma)|j}; - let from90PercentCI = (low, high) => { - let logLow = Js.Math.log(low); - let logHigh = Js.Math.log(high); - let mu = E.A.Floats.mean([|logLow, logHigh|]); - let sigma = (logHigh -. logLow) /. (2.0 *. 1.645); - `Lognormal({mu, sigma}); - }; - let fromMeanAndStdev = (mean, stdev) => { - let variance = Js.Math.pow_float(~base=stdev, ~exp=2.0); - let meanSquared = Js.Math.pow_float(~base=mean, ~exp=2.0); - let mu = - Js.Math.log(mean) -. 0.5 *. Js.Math.log(variance /. meanSquared +. 1.0); - let sigma = - Js.Math.pow_float( - ~base=Js.Math.log(variance /. meanSquared +. 1.0), - ~exp=0.5, - ); - `Lognormal({mu, sigma}); - }; - - let multiply = (l1, l2) => { - let mu = l1.mu +. l2.mu; - let sigma = l1.sigma +. l2.sigma; - `Lognormal({mu, sigma}); - }; - let divide = (l1, l2) => { - let mu = l1.mu -. l2.mu; - let sigma = l1.sigma +. l2.sigma; - `Lognormal({mu, sigma}); - }; - let operate = (operation: Operation.Algebraic.t, n1: t, n2: t) => - switch (operation) { - | `Multiply => Some(multiply(n1, n2)) - | `Divide => Some(divide(n1, n2)) - | _ => None - }; -}; - -module Uniform = { - type t = uniform; - let make = (low, high) => `Uniform({low, high}); - let pdf = (x, t: t) => Jstat.uniform##pdf(x, t.low, t.high); - let cdf = (x, t: t) => Jstat.uniform##cdf(x, t.low, t.high); - let inv = (p, t: t) => Jstat.uniform##inv(p, t.low, t.high); - let sample = (t: t) => Jstat.uniform##sample(t.low, t.high); - let mean = (t: t) => Ok(Jstat.uniform##mean(t.low, t.high)); - let toString = ({low, high}: t) => {j|Uniform($low,$high)|j}; - let truncate = (low, high, t: t): t => { - let newLow = max(E.O.default(neg_infinity, low), t.low); - let newHigh = min(E.O.default(infinity, high), t.high); - {low: newLow, high: newHigh}; - }; -}; - -module Float = { - type t = float; - let make = t => `Float(t); - let pdf = (x, t: t) => x == t ? 1.0 : 0.0; - let cdf = (x, t: t) => x >= t ? 1.0 : 0.0; - let inv = (p, t: t) => p < t ? 0.0 : 1.0; - let mean = (t: t) => Ok(t); - let sample = (t: t) => t; - let toString = Js.Float.toString; -}; - -module T = { - let minCdfValue = 0.0001; - let maxCdfValue = 0.9999; - - let pdf = (x, dist) => - switch (dist) { - | `Normal(n) => Normal.pdf(x, n) - | `Triangular(n) => Triangular.pdf(x, n) - | `Exponential(n) => Exponential.pdf(x, n) - | `Cauchy(n) => Cauchy.pdf(x, n) - | `Lognormal(n) => Lognormal.pdf(x, n) - | `Uniform(n) => Uniform.pdf(x, n) - | `Beta(n) => Beta.pdf(x, n) - | `Float(n) => Float.pdf(x, n) - }; - - let cdf = (x, dist) => - switch (dist) { - | `Normal(n) => Normal.cdf(x, n) - | `Triangular(n) => Triangular.cdf(x, n) - | `Exponential(n) => Exponential.cdf(x, n) - | `Cauchy(n) => Cauchy.cdf(x, n) - | `Lognormal(n) => Lognormal.cdf(x, n) - | `Uniform(n) => Uniform.cdf(x, n) - | `Beta(n) => Beta.cdf(x, n) - | `Float(n) => Float.cdf(x, n) - }; - - let inv = (x, dist) => - switch (dist) { - | `Normal(n) => Normal.inv(x, n) - | `Triangular(n) => Triangular.inv(x, n) - | `Exponential(n) => Exponential.inv(x, n) - | `Cauchy(n) => Cauchy.inv(x, n) - | `Lognormal(n) => Lognormal.inv(x, n) - | `Uniform(n) => Uniform.inv(x, n) - | `Beta(n) => Beta.inv(x, n) - | `Float(n) => Float.inv(x, n) - }; - - let sample: symbolicDist => float = - fun - | `Normal(n) => Normal.sample(n) - | `Triangular(n) => Triangular.sample(n) - | `Exponential(n) => Exponential.sample(n) - | `Cauchy(n) => Cauchy.sample(n) - | `Lognormal(n) => Lognormal.sample(n) - | `Uniform(n) => Uniform.sample(n) - | `Beta(n) => Beta.sample(n) - | `Float(n) => Float.sample(n); - - let doN = (n, fn) => { - let items = Belt.Array.make(n, 0.0); - for (x in 0 to n - 1) { - let _ = Belt.Array.set(items, x, fn()); - (); - }; - items; - }; - - let sampleN = (n, dist) => { - doN(n, () => sample(dist)); - }; - - let toString: symbolicDist => string = - fun - | `Triangular(n) => Triangular.toString(n) - | `Exponential(n) => Exponential.toString(n) - | `Cauchy(n) => Cauchy.toString(n) - | `Normal(n) => Normal.toString(n) - | `Lognormal(n) => Lognormal.toString(n) - | `Uniform(n) => Uniform.toString(n) - | `Beta(n) => Beta.toString(n) - | `Float(n) => Float.toString(n); - - let min: symbolicDist => float = - fun - | `Triangular({low}) => low - | `Exponential(n) => Exponential.inv(minCdfValue, n) - | `Cauchy(n) => Cauchy.inv(minCdfValue, n) - | `Normal(n) => Normal.inv(minCdfValue, n) - | `Lognormal(n) => Lognormal.inv(minCdfValue, n) - | `Uniform({low}) => low - | `Beta(n) => Beta.inv(minCdfValue, n) - | `Float(n) => n; - - let max: symbolicDist => float = - fun - | `Triangular(n) => n.high - | `Exponential(n) => Exponential.inv(maxCdfValue, n) - | `Cauchy(n) => Cauchy.inv(maxCdfValue, n) - | `Normal(n) => Normal.inv(maxCdfValue, n) - | `Lognormal(n) => Lognormal.inv(maxCdfValue, n) - | `Beta(n) => Beta.inv(maxCdfValue, n) - | `Uniform({high}) => high - | `Float(n) => n; - - let mean: symbolicDist => result(float, string) = - fun - | `Triangular(n) => Triangular.mean(n) - | `Exponential(n) => Exponential.mean(n) - | `Cauchy(n) => Cauchy.mean(n) - | `Normal(n) => Normal.mean(n) - | `Lognormal(n) => Lognormal.mean(n) - | `Beta(n) => Beta.mean(n) - | `Uniform(n) => Uniform.mean(n) - | `Float(n) => Float.mean(n); - - let operate = (distToFloatOp: ExpressionTypes.distToFloatOperation, s) => - switch (distToFloatOp) { - | `Cdf(f) => Ok(cdf(f, s)) - | `Pdf(f) => Ok(pdf(f, s)) - | `Inv(f) => Ok(inv(f, s)) - | `Sample => Ok(sample(s)) - | `Mean => mean(s) - }; - - let interpolateXs = - (~xSelection: [ | `Linear | `ByWeight]=`Linear, dist: symbolicDist, n) => { - switch (xSelection, dist) { - | (`Linear, _) => E.A.Floats.range(min(dist), max(dist), n) - | (`ByWeight, `Uniform(n)) => - // In `ByWeight mode, uniform distributions get special treatment because we need two x's - // on either side for proper rendering (just left and right of the discontinuities). - let dx = 0.00001 *. (n.high -. n.low); - [|n.low -. dx, n.low +. dx, n.high -. dx, n.high +. dx|]; - | (`ByWeight, _) => - let ys = E.A.Floats.range(minCdfValue, maxCdfValue, n); - ys |> E.A.fmap(y => inv(y, dist)); - }; - }; - - /* Calling e.g. "Normal.operate" returns an optional that wraps a result. - If the optional is None, there is no valid analytic solution. If it Some, it - can still return an error if there is a serious problem, - like in the case of a divide by 0. - */ - let tryAnalyticalSimplification = - ( - d1: symbolicDist, - d2: symbolicDist, - op: ExpressionTypes.algebraicOperation, - ) - : analyticalSimplificationResult => - switch (d1, d2) { - | (`Float(v1), `Float(v2)) => - switch (Operation.Algebraic.applyFn(op, v1, v2)) { - | Ok(r) => `AnalyticalSolution(`Float(r)) - | Error(n) => `Error(n) - } - | (`Normal(v1), `Normal(v2)) => - Normal.operate(op, v1, v2) - |> E.O.dimap(r => `AnalyticalSolution(r), () => `NoSolution) - | (`Lognormal(v1), `Lognormal(v2)) => - Lognormal.operate(op, v1, v2) - |> E.O.dimap(r => `AnalyticalSolution(r), () => `NoSolution) - | _ => `NoSolution - }; - - let toShape = (sampleCount, d: symbolicDist): DistTypes.shape => - switch (d) { - | `Float(v) => - Discrete( - Discrete.make( - ~integralSumCache=Some(1.0), - {xs: [|v|], ys: [|1.0|]}, - ), - ) - | _ => - let xs = interpolateXs(~xSelection=`ByWeight, d, sampleCount); - let ys = xs |> E.A.fmap(x => pdf(x, d)); - Continuous(Continuous.make(~integralSumCache=Some(1.0), {xs, ys})); - }; -}; diff --git a/packages/squiggle-lang/src/distPlus/symbolic/SymbolicDist.res b/packages/squiggle-lang/src/distPlus/symbolic/SymbolicDist.res new file mode 100644 index 00000000..14915511 --- /dev/null +++ b/packages/squiggle-lang/src/distPlus/symbolic/SymbolicDist.res @@ -0,0 +1,328 @@ +open SymbolicTypes + +module Normal = { + type t = normal + let make = (mean, stdev): symbolicDist => #Normal({mean: mean, stdev: stdev}) + let pdf = (x, t: t) => Jstat.Normal.pdf(x, t.mean, t.stdev) + let cdf = (x, t: t) => Jstat.Normal.cdf(x, t.mean, t.stdev) + + let from90PercentCI = (low, high) => { + let mean = E.A.Floats.mean([low, high]) + let stdev = (high -. low) /. (2. *. 1.644854) + #Normal({mean: mean, stdev: stdev}) + } + let inv = (p, t: t) => Jstat.Normal.inv(p, t.mean, t.stdev) + let sample = (t: t) => Jstat.Normal.sample(t.mean, t.stdev) + let mean = (t: t) => Ok(Jstat.Normal.mean(t.mean, t.stdev)) + let toString = ({mean, stdev}: t) => j`Normal($mean,$stdev)` + + let add = (n1: t, n2: t) => { + let mean = n1.mean +. n2.mean + let stdev = sqrt(n1.stdev ** 2. +. n2.stdev ** 2.) + #Normal({mean: mean, stdev: stdev}) + } + let subtract = (n1: t, n2: t) => { + let mean = n1.mean -. n2.mean + let stdev = sqrt(n1.stdev ** 2. +. n2.stdev ** 2.) + #Normal({mean: mean, stdev: stdev}) + } + + // TODO: is this useful here at all? would need the integral as well ... + let pointwiseProduct = (n1: t, n2: t) => { + let mean = + (n1.mean *. n2.stdev ** 2. +. n2.mean *. n1.stdev ** 2.) /. (n1.stdev ** 2. +. n2.stdev ** 2.) + let stdev = 1. /. (1. /. n1.stdev ** 2. +. 1. /. n2.stdev ** 2.) + #Normal({mean: mean, stdev: stdev}) + } + + let operate = (operation: Operation.Algebraic.t, n1: t, n2: t) => + switch operation { + | #Add => Some(add(n1, n2)) + | #Subtract => Some(subtract(n1, n2)) + | _ => None + } +} + +module Exponential = { + type t = exponential + let make = (rate: float): symbolicDist => + #Exponential({ + rate: rate, + }) + let pdf = (x, t: t) => Jstat.Exponential.pdf(x, t.rate) + let cdf = (x, t: t) => Jstat.Exponential.cdf(x, t.rate) + let inv = (p, t: t) => Jstat.Exponential.inv(p, t.rate) + let sample = (t: t) => Jstat.Exponential.sample(t.rate) + let mean = (t: t) => Ok(Jstat.Exponential.mean(t.rate)) + let toString = ({rate}: t) => j`Exponential($rate)` +} + +module Cauchy = { + type t = cauchy + let make = (local, scale): symbolicDist => #Cauchy({local: local, scale: scale}) + let pdf = (x, t: t) => Jstat.Cauchy.pdf(x, t.local, t.scale) + let cdf = (x, t: t) => Jstat.Cauchy.cdf(x, t.local, t.scale) + let inv = (p, t: t) => Jstat.Cauchy.inv(p, t.local, t.scale) + let sample = (t: t) => Jstat.Cauchy.sample(t.local, t.scale) + let mean = (_: t) => Error("Cauchy distributions have no mean value.") + let toString = ({local, scale}: t) => j`Cauchy($local, $scale)` +} + +module Triangular = { + type t = triangular + let make = (low, medium, high): result => + low < medium && medium < high + ? Ok(#Triangular({low: low, medium: medium, high: high})) + : Error("Triangular values must be increasing order") + let pdf = (x, t: t) => Jstat.Triangular.pdf(x, t.low, t.high, t.medium) + let cdf = (x, t: t) => Jstat.Triangular.cdf(x, t.low, t.high, t.medium) + let inv = (p, t: t) => Jstat.Triangular.inv(p, t.low, t.high, t.medium) + let sample = (t: t) => Jstat.Triangular.sample(t.low, t.high, t.medium) + let mean = (t: t) => Ok(Jstat.Triangular.mean(t.low, t.high, t.medium)) + let toString = ({low, medium, high}: t) => j`Triangular($low, $medium, $high)` +} + +module Beta = { + type t = beta + let make = (alpha, beta) => #Beta({alpha: alpha, beta: beta}) + let pdf = (x, t: t) => Jstat.Beta.pdf(x, t.alpha, t.beta) + let cdf = (x, t: t) => Jstat.Beta.cdf(x, t.alpha, t.beta) + let inv = (p, t: t) => Jstat.Beta.inv(p, t.alpha, t.beta) + let sample = (t: t) => Jstat.Beta.sample(t.alpha, t.beta) + let mean = (t: t) => Ok(Jstat.Beta.mean(t.alpha, t.beta)) + let toString = ({alpha, beta}: t) => j`Beta($alpha,$beta)` +} + +module Lognormal = { + type t = lognormal + let make = (mu, sigma) => #Lognormal({mu: mu, sigma: sigma}) + let pdf = (x, t: t) => Jstat.Lognormal.pdf(x, t.mu, t.sigma) + let cdf = (x, t: t) => Jstat.Lognormal.cdf(x, t.mu, t.sigma) + let inv = (p, t: t) => Jstat.Lognormal.inv(p, t.mu, t.sigma) + let mean = (t: t) => Ok(Jstat.Lognormal.mean(t.mu, t.sigma)) + let sample = (t: t) => Jstat.Lognormal.sample(t.mu, t.sigma) + let toString = ({mu, sigma}: t) => j`Lognormal($mu,$sigma)` + let from90PercentCI = (low, high) => { + let logLow = Js.Math.log(low) + let logHigh = Js.Math.log(high) + let mu = E.A.Floats.mean([logLow, logHigh]) + let sigma = (logHigh -. logLow) /. (2.0 *. 1.645) + #Lognormal({mu: mu, sigma: sigma}) + } + let fromMeanAndStdev = (mean, stdev) => { + let variance = Js.Math.pow_float(~base=stdev, ~exp=2.0) + let meanSquared = Js.Math.pow_float(~base=mean, ~exp=2.0) + let mu = Js.Math.log(mean) -. 0.5 *. Js.Math.log(variance /. meanSquared +. 1.0) + let sigma = Js.Math.pow_float(~base=Js.Math.log(variance /. meanSquared +. 1.0), ~exp=0.5) + #Lognormal({mu: mu, sigma: sigma}) + } + + let multiply = (l1, l2) => { + let mu = l1.mu +. l2.mu + let sigma = l1.sigma +. l2.sigma + #Lognormal({mu: mu, sigma: sigma}) + } + let divide = (l1, l2) => { + let mu = l1.mu -. l2.mu + let sigma = l1.sigma +. l2.sigma + #Lognormal({mu: mu, sigma: sigma}) + } + let operate = (operation: Operation.Algebraic.t, n1: t, n2: t) => + switch operation { + | #Multiply => Some(multiply(n1, n2)) + | #Divide => Some(divide(n1, n2)) + | _ => None + } +} + +module Uniform = { + type t = uniform + let make = (low, high) => #Uniform({low: low, high: high}) + let pdf = (x, t: t) => Jstat.Uniform.pdf(x, t.low, t.high) + let cdf = (x, t: t) => Jstat.Uniform.cdf(x, t.low, t.high) + let inv = (p, t: t) => Jstat.Uniform.inv(p, t.low, t.high) + let sample = (t: t) => Jstat.Uniform.sample(t.low, t.high) + let mean = (t: t) => Ok(Jstat.Uniform.mean(t.low, t.high)) + let toString = ({low, high}: t) => j`Uniform($low,$high)` + let truncate = (low, high, t: t): t => { + let newLow = max(E.O.default(neg_infinity, low), t.low) + let newHigh = min(E.O.default(infinity, high), t.high) + {low: newLow, high: newHigh} + } +} + +module Float = { + type t = float + let make = t => #Float(t) + let pdf = (x, t: t) => x == t ? 1.0 : 0.0 + let cdf = (x, t: t) => x >= t ? 1.0 : 0.0 + let inv = (p, t: t) => p < t ? 0.0 : 1.0 + let mean = (t: t) => Ok(t) + let sample = (t: t) => t + let toString = Js.Float.toString +} + +module T = { + let minCdfValue = 0.0001 + let maxCdfValue = 0.9999 + + let pdf = (x, dist) => + switch dist { + | #Normal(n) => Normal.pdf(x, n) + | #Triangular(n) => Triangular.pdf(x, n) + | #Exponential(n) => Exponential.pdf(x, n) + | #Cauchy(n) => Cauchy.pdf(x, n) + | #Lognormal(n) => Lognormal.pdf(x, n) + | #Uniform(n) => Uniform.pdf(x, n) + | #Beta(n) => Beta.pdf(x, n) + | #Float(n) => Float.pdf(x, n) + } + + let cdf = (x, dist) => + switch dist { + | #Normal(n) => Normal.cdf(x, n) + | #Triangular(n) => Triangular.cdf(x, n) + | #Exponential(n) => Exponential.cdf(x, n) + | #Cauchy(n) => Cauchy.cdf(x, n) + | #Lognormal(n) => Lognormal.cdf(x, n) + | #Uniform(n) => Uniform.cdf(x, n) + | #Beta(n) => Beta.cdf(x, n) + | #Float(n) => Float.cdf(x, n) + } + + let inv = (x, dist) => + switch dist { + | #Normal(n) => Normal.inv(x, n) + | #Triangular(n) => Triangular.inv(x, n) + | #Exponential(n) => Exponential.inv(x, n) + | #Cauchy(n) => Cauchy.inv(x, n) + | #Lognormal(n) => Lognormal.inv(x, n) + | #Uniform(n) => Uniform.inv(x, n) + | #Beta(n) => Beta.inv(x, n) + | #Float(n) => Float.inv(x, n) + } + + let sample: symbolicDist => float = x => + switch x { + | #Normal(n) => Normal.sample(n) + | #Triangular(n) => Triangular.sample(n) + | #Exponential(n) => Exponential.sample(n) + | #Cauchy(n) => Cauchy.sample(n) + | #Lognormal(n) => Lognormal.sample(n) + | #Uniform(n) => Uniform.sample(n) + | #Beta(n) => Beta.sample(n) + | #Float(n) => Float.sample(n) + } + + let doN = (n, fn) => { + let items = Belt.Array.make(n, 0.0) + for x in 0 to n - 1 { + let _ = Belt.Array.set(items, x, fn()) + } + items + } + + let sampleN = (n, dist) => doN(n, () => sample(dist)) + + let toString: symbolicDist => string = x => + switch x { + | #Triangular(n) => Triangular.toString(n) + | #Exponential(n) => Exponential.toString(n) + | #Cauchy(n) => Cauchy.toString(n) + | #Normal(n) => Normal.toString(n) + | #Lognormal(n) => Lognormal.toString(n) + | #Uniform(n) => Uniform.toString(n) + | #Beta(n) => Beta.toString(n) + | #Float(n) => Float.toString(n) + } + + let min: symbolicDist => float = x => + switch x { + | #Triangular({low}) => low + | #Exponential(n) => Exponential.inv(minCdfValue, n) + | #Cauchy(n) => Cauchy.inv(minCdfValue, n) + | #Normal(n) => Normal.inv(minCdfValue, n) + | #Lognormal(n) => Lognormal.inv(minCdfValue, n) + | #Uniform({low}) => low + | #Beta(n) => Beta.inv(minCdfValue, n) + | #Float(n) => n + } + + let max: symbolicDist => float = x => + switch x { + | #Triangular(n) => n.high + | #Exponential(n) => Exponential.inv(maxCdfValue, n) + | #Cauchy(n) => Cauchy.inv(maxCdfValue, n) + | #Normal(n) => Normal.inv(maxCdfValue, n) + | #Lognormal(n) => Lognormal.inv(maxCdfValue, n) + | #Beta(n) => Beta.inv(maxCdfValue, n) + | #Uniform({high}) => high + | #Float(n) => n + } + + let mean: symbolicDist => result = x => + switch x { + | #Triangular(n) => Triangular.mean(n) + | #Exponential(n) => Exponential.mean(n) + | #Cauchy(n) => Cauchy.mean(n) + | #Normal(n) => Normal.mean(n) + | #Lognormal(n) => Lognormal.mean(n) + | #Beta(n) => Beta.mean(n) + | #Uniform(n) => Uniform.mean(n) + | #Float(n) => Float.mean(n) + } + + let operate = (distToFloatOp: ExpressionTypes.distToFloatOperation, s) => + switch distToFloatOp { + | #Cdf(f) => Ok(cdf(f, s)) + | #Pdf(f) => Ok(pdf(f, s)) + | #Inv(f) => Ok(inv(f, s)) + | #Sample => Ok(sample(s)) + | #Mean => mean(s) + } + + let interpolateXs = (~xSelection: [#Linear | #ByWeight]=#Linear, dist: symbolicDist, n) => + switch (xSelection, dist) { + | (#Linear, _) => E.A.Floats.range(min(dist), max(dist), n) + | (#ByWeight, #Uniform(n)) => + // In `ByWeight mode, uniform distributions get special treatment because we need two x's + // on either side for proper rendering (just left and right of the discontinuities). + let dx = 0.00001 *. (n.high -. n.low) + [n.low -. dx, n.low +. dx, n.high -. dx, n.high +. dx] + | (#ByWeight, _) => + let ys = E.A.Floats.range(minCdfValue, maxCdfValue, n) + ys |> E.A.fmap(y => inv(y, dist)) + } + + /* Calling e.g. "Normal.operate" returns an optional that wraps a result. + If the optional is None, there is no valid analytic solution. If it Some, it + can still return an error if there is a serious problem, + like in the case of a divide by 0. + */ + let tryAnalyticalSimplification = ( + d1: symbolicDist, + d2: symbolicDist, + op: ExpressionTypes.algebraicOperation, + ): analyticalSimplificationResult => + switch (d1, d2) { + | (#Float(v1), #Float(v2)) => + switch Operation.Algebraic.applyFn(op, v1, v2) { + | Ok(r) => #AnalyticalSolution(#Float(r)) + | Error(n) => #Error(n) + } + | (#Normal(v1), #Normal(v2)) => + Normal.operate(op, v1, v2) |> E.O.dimap(r => #AnalyticalSolution(r), () => #NoSolution) + | (#Lognormal(v1), #Lognormal(v2)) => + Lognormal.operate(op, v1, v2) |> E.O.dimap(r => #AnalyticalSolution(r), () => #NoSolution) + | _ => #NoSolution + } + + let toShape = (sampleCount, d: symbolicDist): DistTypes.shape => + switch d { + | #Float(v) => Discrete(Discrete.make(~integralSumCache=Some(1.0), {xs: [v], ys: [1.0]})) + | _ => + let xs = interpolateXs(~xSelection=#ByWeight, d, sampleCount) + let ys = xs |> E.A.fmap(x => pdf(x, d)) + Continuous(Continuous.make(~integralSumCache=Some(1.0), {xs: xs, ys: ys})) + } +} diff --git a/packages/squiggle-lang/src/distPlus/symbolic/SymbolicTypes.re b/packages/squiggle-lang/src/distPlus/symbolic/SymbolicTypes.re deleted file mode 100644 index 4d10899e..00000000 --- a/packages/squiggle-lang/src/distPlus/symbolic/SymbolicTypes.re +++ /dev/null @@ -1,49 +0,0 @@ -type normal = { - mean: float, - stdev: float, -}; - -type lognormal = { - mu: float, - sigma: float, -}; - -type uniform = { - low: float, - high: float, -}; - -type beta = { - alpha: float, - beta: float, -}; - -type exponential = {rate: float}; - -type cauchy = { - local: float, - scale: float, -}; - -type triangular = { - low: float, - medium: float, - high: float, -}; - -type symbolicDist = [ - | `Normal(normal) - | `Beta(beta) - | `Lognormal(lognormal) - | `Uniform(uniform) - | `Exponential(exponential) - | `Cauchy(cauchy) - | `Triangular(triangular) - | `Float(float) // Dirac delta at x. Practically useful only in the context of multimodals. -]; - -type analyticalSimplificationResult = [ - | `AnalyticalSolution(symbolicDist) - | `Error(string) - | `NoSolution -]; diff --git a/packages/squiggle-lang/src/distPlus/symbolic/SymbolicTypes.res b/packages/squiggle-lang/src/distPlus/symbolic/SymbolicTypes.res new file mode 100644 index 00000000..d2a6603e --- /dev/null +++ b/packages/squiggle-lang/src/distPlus/symbolic/SymbolicTypes.res @@ -0,0 +1,49 @@ +type normal = { + mean: float, + stdev: float, +} + +type lognormal = { + mu: float, + sigma: float, +} + +type uniform = { + low: float, + high: float, +} + +type beta = { + alpha: float, + beta: float, +} + +type exponential = {rate: float} + +type cauchy = { + local: float, + scale: float, +} + +type triangular = { + low: float, + medium: float, + high: float, +} + +type symbolicDist = [ + | #Normal(normal) + | #Beta(beta) + | #Lognormal(lognormal) + | #Uniform(uniform) + | #Exponential(exponential) + | #Cauchy(cauchy) + | #Triangular(triangular) + | #Float(float) +] + +type analyticalSimplificationResult = [ + | #AnalyticalSolution(symbolicDist) + | #Error(string) + | #NoSolution +] diff --git a/packages/squiggle-lang/src/distPlus/utility/Jstat.re b/packages/squiggle-lang/src/distPlus/utility/Jstat.re deleted file mode 100644 index 0a2cc13f..00000000 --- a/packages/squiggle-lang/src/distPlus/utility/Jstat.re +++ /dev/null @@ -1,112 +0,0 @@ -// Todo: Another way of doing this is with [@bs.scope "normal"], which may be more elegant -type normal = { - . - [@bs.meth] "pdf": (float, float, float) => float, - [@bs.meth] "cdf": (float, float, float) => float, - [@bs.meth] "inv": (float, float, float) => float, - [@bs.meth] "sample": (float, float) => float, - [@bs.meth] "mean": (float, float) => float, -}; -type lognormal = { - . - [@bs.meth] "pdf": (float, float, float) => float, - [@bs.meth] "cdf": (float, float, float) => float, - [@bs.meth] "inv": (float, float, float) => float, - [@bs.meth] "sample": (float, float) => float, - [@bs.meth] "mean": (float, float) => float, -}; -type uniform = { - . - [@bs.meth] "pdf": (float, float, float) => float, - [@bs.meth] "cdf": (float, float, float) => float, - [@bs.meth] "inv": (float, float, float) => float, - [@bs.meth] "sample": (float, float) => float, - [@bs.meth] "mean": (float, float) => float, -}; -type beta = { - . - [@bs.meth] "pdf": (float, float, float) => float, - [@bs.meth] "cdf": (float, float, float) => float, - [@bs.meth] "inv": (float, float, float) => float, - [@bs.meth] "sample": (float, float) => float, - [@bs.meth] "mean": (float, float) => float, -}; -type exponential = { - . - [@bs.meth] "pdf": (float, float) => float, - [@bs.meth] "cdf": (float, float) => float, - [@bs.meth] "inv": (float, float) => float, - [@bs.meth] "sample": float => float, - [@bs.meth] "mean": float => float, -}; -type cauchy = { - . - [@bs.meth] "pdf": (float, float, float) => float, - [@bs.meth] "cdf": (float, float, float) => float, - [@bs.meth] "inv": (float, float, float) => float, - [@bs.meth] "sample": (float, float) => float, -}; -type triangular = { - . - [@bs.meth] "pdf": (float, float, float, float) => float, - [@bs.meth] "cdf": (float, float, float, float) => float, - [@bs.meth] "inv": (float, float, float, float) => float, - [@bs.meth] "sample": (float, float, float) => float, - [@bs.meth] "mean": (float, float, float) => float, -}; - -// Pareto doesn't have sample for some reason -type pareto = { - . - [@bs.meth] "pdf": (float, float, float) => float, - [@bs.meth] "cdf": (float, float, float) => float, - [@bs.meth] "inv": (float, float, float) => float, -}; -type poisson = { - . - [@bs.meth] "pdf": (float, float) => float, - [@bs.meth] "cdf": (float, float) => float, - [@bs.meth] "sample": float => float, - [@bs.meth] "mean": float => float, -}; -type weibull = { - . - [@bs.meth] "pdf": (float, float, float) => float, - [@bs.meth] "cdf": (float, float, float) => float, - [@bs.meth] "inv": (float, float, float) => float, - [@bs.meth] "sample": (float, float) => float, - [@bs.meth] "mean": (float, float) => float, -}; -type binomial = { - . - [@bs.meth] "pdf": (float, float, float) => float, - [@bs.meth] "cdf": (float, float, float) => float, -}; -[@bs.module "jstat"] external normal: normal = "normal"; -[@bs.module "jstat"] external lognormal: lognormal = "lognormal"; -[@bs.module "jstat"] external uniform: uniform = "uniform"; -[@bs.module "jstat"] external beta: beta = "beta"; -[@bs.module "jstat"] external exponential: exponential = "exponential"; -[@bs.module "jstat"] external cauchy: cauchy = "cauchy"; -[@bs.module "jstat"] external triangular: triangular = "triangular"; -[@bs.module "jstat"] external poisson: poisson = "poisson"; -[@bs.module "jstat"] external pareto: pareto = "pareto"; -[@bs.module "jstat"] external weibull: weibull = "weibull"; -[@bs.module "jstat"] external binomial: binomial = "binomial"; - -[@bs.module "jstat"] external sum: array(float) => float = "sum"; -[@bs.module "jstat"] external product: array(float) => float = "product"; -[@bs.module "jstat"] external min: array(float) => float = "min"; -[@bs.module "jstat"] external max: array(float) => float = "max"; -[@bs.module "jstat"] external mean: array(float) => float = "mean"; -[@bs.module "jstat"] external geomean: array(float) => float = "geomean"; -[@bs.module "jstat"] external mode: array(float) => float = "mode"; -[@bs.module "jstat"] external variance: array(float) => float = "variance"; -[@bs.module "jstat"] external deviation: array(float) => float = "deviation"; -[@bs.module "jstat"] external stdev: array(float) => float = "stdev"; -[@bs.module "jstat"] -external quartiles: (array(float)) => array(float) = "quartiles"; -[@bs.module "jstat"] -external quantiles: (array(float), array(float)) => array(float) = "quantiles"; -[@bs.module "jstat"] -external percentile: (array(float), float, bool) => float = "percentile"; diff --git a/packages/squiggle-lang/src/distPlus/utility/Jstat.res b/packages/squiggle-lang/src/distPlus/utility/Jstat.res new file mode 100644 index 00000000..7ef574fc --- /dev/null +++ b/packages/squiggle-lang/src/distPlus/utility/Jstat.res @@ -0,0 +1,100 @@ +// Todo: Another way of doing this is with [@bs.scope "normal"], which may be more elegant +module Normal = { + @module("jStat") @scope("normal") external pdf: (float, float, float) => float = "pdf" + @module("jStat") @scope("normal") external cdf: (float, float, float) => float = "cdf" + @module("jStat") @scope("normal") external inv: (float, float, float) => float = "inv" + @module("jStat") @scope("normal") external sample: (float, float) => float = "sample" + @module("jStat") @scope("normal") external mean: (float, float) => float = "mean" +} + +module Lognormal = { + @module("jStat") @scope("lognormal") external pdf: (float, float, float) => float = "pdf" + @module("jStat") @scope("lognormal") external cdf: (float, float, float) => float = "cdf" + @module("jStat") @scope("lognormal") external inv: (float, float, float) => float = "inv" + @module("jStat") @scope("lognormal") external sample: (float, float) => float = "sample" + @module("jStat") @scope("lognormal") external mean: (float, float) => float = "mean" +} + +module Uniform = { + @module("jStat") @scope("uniform") external pdf: (float, float, float) => float = "pdf" + @module("jStat") @scope("uniform") external cdf: (float, float, float) => float = "cdf" + @module("jStat") @scope("uniform") external inv: (float, float, float) => float = "inv" + @module("jStat") @scope("uniform") external sample: (float, float) => float = "sample" + @module("jStat") @scope("uniform") external mean: (float, float) => float = "mean" +} + +type beta +module Beta = { + @module("jStat") @scope("uniform") external pdf: (float, float, float) => float = "pdf" + @module("jStat") @scope("uniform") external cdf: (float, float, float) => float = "cdf" + @module("jStat") @scope("uniform") external inv: (float, float, float) => float = "inv" + @module("jStat") @scope("uniform") external sample: (float, float) => float = "sample" + @module("jStat") @scope("uniform") external mean: (float, float) => float = "mean" +} + +module Exponential = { + @module("jStat") @scope("uniform") external pdf: (float, float) => float = "pdf" + @module("jStat") @scope("uniform") external cdf: (float, float) => float = "cdf" + @module("jStat") @scope("uniform") external inv: (float, float) => float = "inv" + @module("jStat") @scope("uniform") external sample: (float) => float = "sample" + @module("jStat") @scope("uniform") external mean: (float) => float = "mean" +} + +module Cauchy = { + @module("jStat") @scope("uniform") external pdf: (float, float, float) => float = "pdf" + @module("jStat") @scope("uniform") external cdf: (float, float, float) => float = "cdf" + @module("jStat") @scope("uniform") external inv: (float, float, float) => float = "inv" + @module("jStat") @scope("uniform") external sample: (float, float) => float = "sample" + @module("jStat") @scope("uniform") external mean: (float, float) => float = "mean" +} + +module Triangular = { + @module("jStat") @scope("uniform") external pdf: (float, float, float, float) => float = "pdf" + @module("jStat") @scope("uniform") external cdf: (float, float, float, float) => float = "cdf" + @module("jStat") @scope("uniform") external inv: (float, float, float, float) => float = "inv" + @module("jStat") @scope("uniform") external sample: (float, float, float) => float = "sample" + @module("jStat") @scope("uniform") external mean: (float, float, float) => float = "mean" +} + + +module Pareto = { + @module("jStat") @scope("uniform") external pdf: (float, float, float) => float = "pdf" + @module("jStat") @scope("uniform") external cdf: (float, float, float) => float = "cdf" + @module("jStat") @scope("uniform") external inv: (float, float, float) => float = "inv" +} + +module Poisson = { + @module("jStat") @scope("uniform") external pdf: (float, float) => float = "pdf" + @module("jStat") @scope("uniform") external cdf: (float, float) => float = "cdf" + @module("jStat") @scope("uniform") external sample: (float) => float = "sample" + @module("jStat") @scope("uniform") external mean: (float) => float = "mean" +} + +module Weibull = { + @module("jStat") @scope("uniform") external pdf: (float, float, float) => float = "pdf" + @module("jStat") @scope("uniform") external cdf: (float, float,float ) => float = "cdf" + @module("jStat") @scope("uniform") external sample: (float,float) => float = "sample" + @module("jStat") @scope("uniform") external mean: (float,float) => float = "mean" +} + +module Binomial = { + @module("jStat") @scope("uniform") external pdf: (float, float, float) => float = "pdf" + @module("jStat") @scope("uniform") external cdf: (float, float,float ) => float = "cdf" +} + +@module("jstat") external sum: array => float = "sum" +@module("jstat") external product: array => float = "product" +@module("jstat") external min: array => float = "min" +@module("jstat") external max: array => float = "max" +@module("jstat") external mean: array => float = "mean" +@module("jstat") external geomean: array => float = "geomean" +@module("jstat") external mode: array => float = "mode" +@module("jstat") external variance: array => float = "variance" +@module("jstat") external deviation: array => float = "deviation" +@module("jstat") external stdev: array => float = "stdev" +@module("jstat") +external quartiles: array => array = "quartiles" +@module("jstat") +external quantiles: (array, array) => array = "quantiles" +@module("jstat") +external percentile: (array, float, bool) => float = "percentile" diff --git a/packages/squiggle-lang/src/distPlus/utility/Jstat2.res b/packages/squiggle-lang/src/distPlus/utility/Jstat2.res new file mode 100644 index 00000000..f5df95bd --- /dev/null +++ b/packages/squiggle-lang/src/distPlus/utility/Jstat2.res @@ -0,0 +1,3 @@ +@module("jStat") @scope("normal") external mean: (float, float) => float = "mean" + +let foo = mean; \ No newline at end of file diff --git a/packages/squiggle-lang/src/distPlus/utility/Lodash.re b/packages/squiggle-lang/src/distPlus/utility/Lodash.re deleted file mode 100644 index 57eab32e..00000000 --- a/packages/squiggle-lang/src/distPlus/utility/Lodash.re +++ /dev/null @@ -1,5 +0,0 @@ -[@bs.module "lodash"] external min: array('a) => 'a = "min"; -[@bs.module "lodash"] external max: array('a) => 'a = "max"; -[@bs.module "lodash"] external uniq: array('a) => array('a) = "uniq"; -[@bs.module "lodash"] -external countBy: (array('a), 'a => 'b) => Js.Dict.t(int) = "countBy"; \ No newline at end of file diff --git a/packages/squiggle-lang/src/distPlus/utility/Lodash.res b/packages/squiggle-lang/src/distPlus/utility/Lodash.res new file mode 100644 index 00000000..d4359408 --- /dev/null +++ b/packages/squiggle-lang/src/distPlus/utility/Lodash.res @@ -0,0 +1,5 @@ +@module("lodash") external min: array<'a> => 'a = "min" +@module("lodash") external max: array<'a> => 'a = "max" +@module("lodash") external uniq: array<'a> => array<'a> = "uniq" +@module("lodash") +external countBy: (array<'a>, 'a => 'b) => Js.Dict.t = "countBy"