From 24a7d0eedf340c2dcb221fbca103b470739a694b Mon Sep 17 00:00:00 2001 From: Ozzie Gooen Date: Wed, 16 Feb 2022 14:57:46 -0500 Subject: [PATCH] Refactored Operation.res to live out of the interpreter --- .../src/rescript/interpreter/AST.res | 2 +- .../src/rescript/interpreter/ASTBasic.res | 27 ---------- .../src/rescript/interpreter/ASTEvaluator.res | 4 +- .../src/rescript/interpreter/ASTTypes.res | 50 ++++++++++++------- .../typeSystem/HardcodedFunctions.res | 2 +- .../interpreter/typeSystem/TypeSystem.res | 2 +- .../AlgebraicShapeCombination.res | 4 +- .../src/rescript/pointSetDist/Continuous.res | 4 +- .../src/rescript/pointSetDist/Discrete.res | 2 +- .../src/rescript/pointSetDist/Mixed.res | 2 +- .../rescript/pointSetDist/PointSetDist.res | 4 +- .../rescript/symbolicDist/SymbolicDist.res | 4 +- .../{interpreter => utility}/Operation.res | 37 ++++++++------ 13 files changed, 67 insertions(+), 77 deletions(-) delete mode 100644 packages/squiggle-lang/src/rescript/interpreter/ASTBasic.res rename packages/squiggle-lang/src/rescript/{interpreter => utility}/Operation.res (72%) diff --git a/packages/squiggle-lang/src/rescript/interpreter/AST.res b/packages/squiggle-lang/src/rescript/interpreter/AST.res index 6680db90..2bc93459 100644 --- a/packages/squiggle-lang/src/rescript/interpreter/AST.res +++ b/packages/squiggle-lang/src/rescript/interpreter/AST.res @@ -1,6 +1,6 @@ open ASTTypes.AST -let toString = ASTBasic.toString +let toString = ASTTypes.Node.toString let envs = (samplingInputs, environment) => { samplingInputs: samplingInputs, diff --git a/packages/squiggle-lang/src/rescript/interpreter/ASTBasic.res b/packages/squiggle-lang/src/rescript/interpreter/ASTBasic.res deleted file mode 100644 index b0806c82..00000000 --- a/packages/squiggle-lang/src/rescript/interpreter/ASTBasic.res +++ /dev/null @@ -1,27 +0,0 @@ -open ASTTypes.AST -// This file exists to manage a dependency cycle. It would be good to refactor later. - -let rec toString: node => string = x => - switch x { - | #SymbolicDist(d) => SymbolicDist.T.toString(d) - | #RenderedDist(_) => "[renderedShape]" - | #AlgebraicCombination(op, t1, t2) => Operation.Algebraic.format(op, toString(t1), toString(t2)) - | #PointwiseCombination(op, t1, t2) => Operation.Pointwise.format(op, toString(t1), toString(t2)) - | #Normalize(t) => "normalize(k" ++ (toString(t) ++ ")") - | #Truncate(lc, rc, t) => Operation.T.truncateToString(lc, rc, toString(t)) - | #Render(t) => toString(t) - | #Symbol(t) => "Symbol: " ++ t - | #FunctionCall(name, args) => - "[Function call: (" ++ - (name ++ - ((args |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ ")]")) - | #Function(args, internal) => - "[Function: (" ++ ((args |> Js.String.concatMany(_, ",")) ++ (toString(internal) ++ ")]")) - | #Array(a) => "[" ++ ((a |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ "]") - | #Hash(h) => - "{" ++ - ((h - |> E.A.fmap(((name, value)) => name ++ (":" ++ toString(value))) - |> Js.String.concatMany(_, ",")) ++ - "}") - } \ No newline at end of file diff --git a/packages/squiggle-lang/src/rescript/interpreter/ASTEvaluator.res b/packages/squiggle-lang/src/rescript/interpreter/ASTEvaluator.res index ba2929f0..53123ee5 100644 --- a/packages/squiggle-lang/src/rescript/interpreter/ASTEvaluator.res +++ b/packages/squiggle-lang/src/rescript/interpreter/ASTEvaluator.res @@ -56,7 +56,7 @@ module AlgebraicCombination = { let operationToLeaf = ( evaluationParams: evaluationParams, - algebraicOp: ASTTypes.algebraicOperation, + algebraicOp: Operation.algebraicOperation, t1: t, t2: t, ): result => @@ -106,7 +106,7 @@ module PointwiseCombination = { let operationToLeaf = ( evaluationParams: evaluationParams, - pointwiseOp: pointwiseOperation, + pointwiseOp: Operation.pointwiseOperation, t1: t, t2: t, ) => diff --git a/packages/squiggle-lang/src/rescript/interpreter/ASTTypes.res b/packages/squiggle-lang/src/rescript/interpreter/ASTTypes.res index 5c76df13..b9a1049d 100644 --- a/packages/squiggle-lang/src/rescript/interpreter/ASTTypes.res +++ b/packages/squiggle-lang/src/rescript/interpreter/ASTTypes.res @@ -1,20 +1,3 @@ -type algebraicOperation = [ - | #Add - | #Multiply - | #Subtract - | #Divide - | #Exponentiate -] -type pointwiseOperation = [#Add | #Multiply | #Exponentiate] -type scaleOperation = [#Multiply | #Exponentiate | #Log] -type distToFloatOperation = [ - | #Pdf(float) - | #Cdf(float) - | #Inv(float) - | #Mean - | #Sample -] - module AST = { type rec hash = array<(string, node)> and node = [ @@ -24,8 +7,8 @@ module AST = { | #Hash(hash) | #Array(array) | #Function(array, node) - | #AlgebraicCombination(algebraicOperation, node, node) - | #PointwiseCombination(pointwiseOperation, node, node) + | #AlgebraicCombination(Operation.algebraicOperation, node, node) + | #PointwiseCombination(Operation.pointwiseOperation, node, node) | #Normalize(node) | #Render(node) | #Truncate(option, option, node) @@ -172,3 +155,32 @@ module Program = { ] type program = array } + +module Node = { + let rec toString: AST.node => string = x => + switch x { + | #SymbolicDist(d) => SymbolicDist.T.toString(d) + | #RenderedDist(_) => "[renderedShape]" + | #AlgebraicCombination(op, t1, t2) => + Operation.Algebraic.format(op, toString(t1), toString(t2)) + | #PointwiseCombination(op, t1, t2) => + Operation.Pointwise.format(op, toString(t1), toString(t2)) + | #Normalize(t) => "normalize(k" ++ (toString(t) ++ ")") + | #Truncate(lc, rc, t) => Operation.Truncate.toString(lc, rc, toString(t)) + | #Render(t) => toString(t) + | #Symbol(t) => "Symbol: " ++ t + | #FunctionCall(name, args) => + "[Function call: (" ++ + (name ++ + ((args |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ ")]")) + | #Function(args, internal) => + "[Function: (" ++ ((args |> Js.String.concatMany(_, ",")) ++ (toString(internal) ++ ")]")) + | #Array(a) => "[" ++ ((a |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ "]") + | #Hash(h) => + "{" ++ + ((h + |> E.A.fmap(((name, value)) => name ++ (":" ++ toString(value))) + |> Js.String.concatMany(_, ",")) ++ + "}") + } +} diff --git a/packages/squiggle-lang/src/rescript/interpreter/typeSystem/HardcodedFunctions.res b/packages/squiggle-lang/src/rescript/interpreter/typeSystem/HardcodedFunctions.res index 3e06ae71..5bf93c4d 100644 --- a/packages/squiggle-lang/src/rescript/interpreter/typeSystem/HardcodedFunctions.res +++ b/packages/squiggle-lang/src/rescript/interpreter/typeSystem/HardcodedFunctions.res @@ -84,7 +84,7 @@ let makeDist = (name, fn) => ) let floatFromDist = ( - distToFloatOp: ASTTypes.distToFloatOperation, + distToFloatOp: Operation.distToFloatOperation, t: TypeSystem.samplingDist, ): result => switch t { diff --git a/packages/squiggle-lang/src/rescript/interpreter/typeSystem/TypeSystem.res b/packages/squiggle-lang/src/rescript/interpreter/typeSystem/TypeSystem.res index 2989f1c0..9535e40e 100644 --- a/packages/squiggle-lang/src/rescript/interpreter/typeSystem/TypeSystem.res +++ b/packages/squiggle-lang/src/rescript/interpreter/typeSystem/TypeSystem.res @@ -61,7 +61,7 @@ module TypedValue = { |> E.A.fmap(((name, t)) => fromNode(t) |> E.R.fmap(r => (name, r))) |> E.A.R.firstErrorOrOpen |> E.R.fmap(r => #Hash(r)) - | e => Error("Wrong type: " ++ ASTBasic.toString(e)) + | e => Error("Wrong type: " ++ ASTTypes.Node.toString(e)) } // todo: Arrays and hashes diff --git a/packages/squiggle-lang/src/rescript/pointSetDist/AlgebraicShapeCombination.res b/packages/squiggle-lang/src/rescript/pointSetDist/AlgebraicShapeCombination.res index 44ab740f..6e22c349 100644 --- a/packages/squiggle-lang/src/rescript/pointSetDist/AlgebraicShapeCombination.res +++ b/packages/squiggle-lang/src/rescript/pointSetDist/AlgebraicShapeCombination.res @@ -96,7 +96,7 @@ let toDiscretePointMassesFromTriangulars = ( } let combineShapesContinuousContinuous = ( - op: ASTTypes.algebraicOperation, + op: Operation.algebraicOperation, s1: PointSetTypes.xyShape, s2: PointSetTypes.xyShape, ): PointSetTypes.xyShape => { @@ -200,7 +200,7 @@ let toDiscretePointMassesFromDiscrete = (s: PointSetTypes.xyShape): pointMassesW } let combineShapesContinuousDiscrete = ( - op: ASTTypes.algebraicOperation, + op: Operation.algebraicOperation, continuousShape: PointSetTypes.xyShape, discreteShape: PointSetTypes.xyShape, ): PointSetTypes.xyShape => { diff --git a/packages/squiggle-lang/src/rescript/pointSetDist/Continuous.res b/packages/squiggle-lang/src/rescript/pointSetDist/Continuous.res index 0dd0924d..92654b35 100644 --- a/packages/squiggle-lang/src/rescript/pointSetDist/Continuous.res +++ b/packages/squiggle-lang/src/rescript/pointSetDist/Continuous.res @@ -211,7 +211,7 @@ module T = Dist({ /* This simply creates multiple copies of the continuous distribution, scaled and shifted according to each discrete data point, and then adds them all together. */ let combineAlgebraicallyWithDiscrete = ( - op: ASTTypes.algebraicOperation, + op: Operation.algebraicOperation, t1: t, t2: PointSetTypes.discreteShape, ) => { @@ -244,7 +244,7 @@ let combineAlgebraicallyWithDiscrete = ( } } -let combineAlgebraically = (op: ASTTypes.algebraicOperation, t1: t, t2: t) => { +let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t) => { let s1 = t1 |> getShape let s2 = t2 |> getShape let t1n = s1 |> XYShape.T.length diff --git a/packages/squiggle-lang/src/rescript/pointSetDist/Discrete.res b/packages/squiggle-lang/src/rescript/pointSetDist/Discrete.res index 65828461..fd1cedd6 100644 --- a/packages/squiggle-lang/src/rescript/pointSetDist/Discrete.res +++ b/packages/squiggle-lang/src/rescript/pointSetDist/Discrete.res @@ -85,7 +85,7 @@ let updateIntegralCache = (integralCache, t: t): t => { /* This multiples all of the data points together and creates a new discrete distribution from the results. Data points at the same xs get added together. It may be a good idea to downsample t1 and t2 before and/or the result after. */ -let combineAlgebraically = (op: ASTTypes.algebraicOperation, t1: t, t2: t): t => { +let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t): t => { let t1s = t1 |> getShape let t2s = t2 |> getShape let t1n = t1s |> XYShape.T.length diff --git a/packages/squiggle-lang/src/rescript/pointSetDist/Mixed.res b/packages/squiggle-lang/src/rescript/pointSetDist/Mixed.res index dcdb2c2f..8c7b99ed 100644 --- a/packages/squiggle-lang/src/rescript/pointSetDist/Mixed.res +++ b/packages/squiggle-lang/src/rescript/pointSetDist/Mixed.res @@ -227,7 +227,7 @@ module T = Dist({ } }) -let combineAlgebraically = (op: ASTTypes.algebraicOperation, t1: t, t2: t): t => { +let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t): t => { // Discrete convolution can cause a huge increase in the number of samples, // so we'll first downsample. diff --git a/packages/squiggle-lang/src/rescript/pointSetDist/PointSetDist.res b/packages/squiggle-lang/src/rescript/pointSetDist/PointSetDist.res index caf0fa1f..41e84fee 100644 --- a/packages/squiggle-lang/src/rescript/pointSetDist/PointSetDist.res +++ b/packages/squiggle-lang/src/rescript/pointSetDist/PointSetDist.res @@ -33,7 +33,7 @@ let toMixed = mapToAll(( ), )) -let combineAlgebraically = (op: ASTTypes.algebraicOperation, t1: t, t2: t): t => +let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t): t => switch (t1, t2) { | (Continuous(m1), Continuous(m2)) => Continuous.combineAlgebraically(op, m1, m2) |> Continuous.T.toPointSetDist @@ -197,7 +197,7 @@ let sampleNRendered = (n, dist) => { doN(n, () => sample(distWithUpdatedIntegralCache)) } -let operate = (distToFloatOp: ASTTypes.distToFloatOperation, s): float => +let operate = (distToFloatOp: Operation.distToFloatOperation, s): float => switch distToFloatOp { | #Pdf(f) => pdf(f, s) | #Cdf(f) => pdf(f, s) diff --git a/packages/squiggle-lang/src/rescript/symbolicDist/SymbolicDist.res b/packages/squiggle-lang/src/rescript/symbolicDist/SymbolicDist.res index aa02f32a..3da6bd02 100644 --- a/packages/squiggle-lang/src/rescript/symbolicDist/SymbolicDist.res +++ b/packages/squiggle-lang/src/rescript/symbolicDist/SymbolicDist.res @@ -272,7 +272,7 @@ module T = { | #Float(n) => Float.mean(n) } - let operate = (distToFloatOp: ASTTypes.distToFloatOperation, s) => + let operate = (distToFloatOp: Operation.distToFloatOperation, s) => switch distToFloatOp { | #Cdf(f) => Ok(cdf(f, s)) | #Pdf(f) => Ok(pdf(f, s)) @@ -302,7 +302,7 @@ module T = { let tryAnalyticalSimplification = ( d1: symbolicDist, d2: symbolicDist, - op: ASTTypes.algebraicOperation, + op: Operation.algebraicOperation, ): analyticalSimplificationResult => switch (d1, d2) { | (#Float(v1), #Float(v2)) => diff --git a/packages/squiggle-lang/src/rescript/interpreter/Operation.res b/packages/squiggle-lang/src/rescript/utility/Operation.res similarity index 72% rename from packages/squiggle-lang/src/rescript/interpreter/Operation.res rename to packages/squiggle-lang/src/rescript/utility/Operation.res index f5a120a8..dc29bc71 100644 --- a/packages/squiggle-lang/src/rescript/interpreter/Operation.res +++ b/packages/squiggle-lang/src/rescript/utility/Operation.res @@ -1,4 +1,21 @@ -open ASTTypes +// This file has no dependencies. It's used outside of the interpreter, but the interpreter depends on it. + +type algebraicOperation = [ + | #Add + | #Multiply + | #Subtract + | #Divide + | #Exponentiate +] +type pointwiseOperation = [#Add | #Multiply | #Exponentiate] +type scaleOperation = [#Multiply | #Exponentiate | #Log] +type distToFloatOperation = [ + | #Pdf(float) + | #Cdf(float) + | #Inv(float) + | #Mean + | #Sample +] module Algebraic = { type t = algebraicOperation @@ -86,22 +103,10 @@ module Scale = { } } -module T = { - let truncateToString = (left: option, right: option, nodeToString) => { +module Truncate = { + let toString = (left: option, right: option, nodeToString) => { let left = left |> E.O.dimap(Js.Float.toString, () => "-inf") let right = right |> E.O.dimap(Js.Float.toString, () => "inf") j`truncate($nodeToString, $left, $right)` } - let toString = (nodeToString, x) => - switch x { - | #AlgebraicCombination(op, t1, t2) => Algebraic.format(op, nodeToString(t1), nodeToString(t2)) - | #PointwiseCombination(op, t1, t2) => Pointwise.format(op, nodeToString(t1), nodeToString(t2)) - | #VerticalScaling(scaleOp, t, scaleBy) => - Scale.format(scaleOp, nodeToString(t), nodeToString(scaleBy)) - | #Normalize(t) => "normalize(k" ++ (nodeToString(t) ++ ")") - | #FloatFromDist(floatFromDistOp, t) => DistToFloat.format(floatFromDistOp, nodeToString(t)) - | #Truncate(lc, rc, t) => truncateToString(lc, rc, nodeToString(t)) - | #Render(t) => nodeToString(t) - | _ => "" - } // SymbolicDist and RenderedDist are handled in AST.toString. -} +} \ No newline at end of file