From d8b37bb11398cf5d2ec8a779c312b0fe2cd15230 Mon Sep 17 00:00:00 2001 From: Ozzie Gooen Date: Wed, 16 Feb 2022 17:10:48 -0500 Subject: [PATCH] Refactored AST file --- .../src/rescript/interpreter/AST.res | 2 +- .../src/rescript/interpreter/ASTEvaluator.res | 10 +-- .../src/rescript/interpreter/ASTTypes.res | 87 +++++++++---------- .../src/rescript/interpreter/PTypes.res | 2 +- .../interpreter/typeSystem/TypeSystem.res | 6 +- 5 files changed, 49 insertions(+), 58 deletions(-) diff --git a/packages/squiggle-lang/src/rescript/interpreter/AST.res b/packages/squiggle-lang/src/rescript/interpreter/AST.res index 2bc93459..a7f6619f 100644 --- a/packages/squiggle-lang/src/rescript/interpreter/AST.res +++ b/packages/squiggle-lang/src/rescript/interpreter/AST.res @@ -1,6 +1,6 @@ open ASTTypes.AST -let toString = ASTTypes.Node.toString +let toString = ASTTypes.AST.Node.toString let envs = (samplingInputs, environment) => { samplingInputs: samplingInputs, diff --git a/packages/squiggle-lang/src/rescript/interpreter/ASTEvaluator.res b/packages/squiggle-lang/src/rescript/interpreter/ASTEvaluator.res index 53123ee5..54cee830 100644 --- a/packages/squiggle-lang/src/rescript/interpreter/ASTEvaluator.res +++ b/packages/squiggle-lang/src/rescript/interpreter/ASTEvaluator.res @@ -25,8 +25,8 @@ module AlgebraicCombination = { string, > => E.R.merge( - Render.ensureIsRenderedAndGetShape(evaluationParams, t1), - Render.ensureIsRenderedAndGetShape(evaluationParams, t2), + Node.ensureIsRenderedAndGetShape(evaluationParams, t1), + Node.ensureIsRenderedAndGetShape(evaluationParams, t2), ) |> E.R.fmap(((a, b)) => #RenderedDist(PointSetDist.combineAlgebraically(algebraicOp, a, b))) let nodeScore: node => int = x => @@ -72,7 +72,7 @@ module AlgebraicCombination = { module PointwiseCombination = { let pointwiseAdd = (evaluationParams: evaluationParams, t1: t, t2: t) => - switch (Render.render(evaluationParams, t1), Render.render(evaluationParams, t2)) { + switch (Node.render(evaluationParams, t1), Node.render(evaluationParams, t2)) { | (Ok(#RenderedDist(rs1)), Ok(#RenderedDist(rs2))) => Ok( #RenderedDist( @@ -96,7 +96,7 @@ module PointwiseCombination = { switch // TODO: construct a function that we can easily sample from, to construct // a RenderedDist. Use the xMin and xMax of the rendered pointSetDists to tell the sampling function where to look. // TODO: This should work for symbolic distributions too! - (Render.render(evaluationParams, t1), Render.render(evaluationParams, t2)) { + (Node.render(evaluationParams, t1), Node.render(evaluationParams, t2)) { | (Ok(#RenderedDist(rs1)), Ok(#RenderedDist(rs2))) => Ok(#RenderedDist(PointSetDist.combinePointwise(fn, rs1, rs2))) | (Error(e1), _) => Error(e1) @@ -131,7 +131,7 @@ module Truncate = { let truncateAsShape = (evaluationParams: evaluationParams, leftCutoff, rightCutoff, t) => switch // TODO: use named args for xMin/xMax in renderToShape; if we're lucky we can at least get the tail // of a distribution we otherwise wouldn't get at all - Render.ensureIsRendered(evaluationParams, t) { + Node.ensureIsRendered(evaluationParams, t) { | Ok(#RenderedDist(rs)) => Ok(#RenderedDist(PointSetDist.T.truncate(leftCutoff, rightCutoff, rs))) | Error(e) => Error(e) | _ => Error("Could not truncate distribution.") diff --git a/packages/squiggle-lang/src/rescript/interpreter/ASTTypes.res b/packages/squiggle-lang/src/rescript/interpreter/ASTTypes.res index b9a1049d..85b5ac12 100644 --- a/packages/squiggle-lang/src/rescript/interpreter/ASTTypes.res +++ b/packages/squiggle-lang/src/rescript/interpreter/ASTTypes.res @@ -27,21 +27,6 @@ module AST = { names |> E.A.fmap(name => (name, getByName(hash, name))) } // Have nil as option - let getFloat = (node: node) => - node |> ( - x => - switch x { - | #RenderedDist(Discrete({xyShape: {xs: [x], ys: [1.0]}})) => Some(x) - | #SymbolicDist(#Float(x)) => Some(x) - | _ => None - } - ) - - let toFloatIfNeeded = (node: node) => - switch node |> getFloat { - | Some(float) => #SymbolicDist(#Float(float)) - | None => node - } type samplingInputs = { sampleCount: int, @@ -101,8 +86,43 @@ module AST = { let evaluateAndRetry = (evaluationParams, fn, node) => node |> evaluationParams.evaluateNode(evaluationParams) |> E.R.bind(_, fn(evaluationParams)) - module Render = { - type t = node + module Node = { + let getFloat = (node: node) => + node |> ( + x => + switch x { + | #RenderedDist(Discrete({xyShape: {xs: [x], ys: [1.0]}})) => Some(x) + | #SymbolicDist(#Float(x)) => Some(x) + | _ => None + } + ) + + let rec toString: node => string = x => + switch x { + | #SymbolicDist(d) => SymbolicDist.T.toString(d) + | #RenderedDist(_) => "[renderedShape]" + | #AlgebraicCombination(op, t1, t2) => + Operation.Algebraic.format(op, toString(t1), toString(t2)) + | #PointwiseCombination(op, t1, t2) => + Operation.Pointwise.format(op, toString(t1), toString(t2)) + | #Normalize(t) => "normalize(k" ++ (toString(t) ++ ")") + | #Truncate(lc, rc, t) => Operation.Truncate.toString(lc, rc, toString(t)) + | #Render(t) => toString(t) + | #Symbol(t) => "Symbol: " ++ t + | #FunctionCall(name, args) => + "[Function call: (" ++ + (name ++ + ((args |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ ")]")) + | #Function(args, internal) => + "[Function: (" ++ ((args |> Js.String.concatMany(_, ",")) ++ (toString(internal) ++ ")]")) + | #Array(a) => "[" ++ ((a |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ "]") + | #Hash(h) => + "{" ++ + ((h + |> E.A.fmap(((name, value)) => name ++ (":" ++ toString(value))) + |> Js.String.concatMany(_, ",")) ++ + "}") + } let render = (evaluationParams: evaluationParams, r) => #Render(r) |> evaluateNode(evaluationParams) @@ -125,7 +145,7 @@ module AST = { | Error(e) => Error(e) } - let getShape = (item: node) => + let toPointSetDist = (item: node) => switch item { | #RenderedDist(r) => Some(r) | _ => None @@ -138,7 +158,7 @@ module AST = { } let toFloat = (item: node): result => - item |> getShape |> E.O.bind(_, _toFloat) |> E.O.toResult("Not valid shape") + item |> toPointSetDist |> E.O.bind(_, _toFloat) |> E.O.toResult("Not valid shape") } } @@ -155,32 +175,3 @@ module Program = { ] type program = array } - -module Node = { - let rec toString: AST.node => string = x => - switch x { - | #SymbolicDist(d) => SymbolicDist.T.toString(d) - | #RenderedDist(_) => "[renderedShape]" - | #AlgebraicCombination(op, t1, t2) => - Operation.Algebraic.format(op, toString(t1), toString(t2)) - | #PointwiseCombination(op, t1, t2) => - Operation.Pointwise.format(op, toString(t1), toString(t2)) - | #Normalize(t) => "normalize(k" ++ (toString(t) ++ ")") - | #Truncate(lc, rc, t) => Operation.Truncate.toString(lc, rc, toString(t)) - | #Render(t) => toString(t) - | #Symbol(t) => "Symbol: " ++ t - | #FunctionCall(name, args) => - "[Function call: (" ++ - (name ++ - ((args |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ ")]")) - | #Function(args, internal) => - "[Function: (" ++ ((args |> Js.String.concatMany(_, ",")) ++ (toString(internal) ++ ")]")) - | #Array(a) => "[" ++ ((a |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ "]") - | #Hash(h) => - "{" ++ - ((h - |> E.A.fmap(((name, value)) => name ++ (":" ++ toString(value))) - |> Js.String.concatMany(_, ",")) ++ - "}") - } -} diff --git a/packages/squiggle-lang/src/rescript/interpreter/PTypes.res b/packages/squiggle-lang/src/rescript/interpreter/PTypes.res index b710c5b2..bc9ac14a 100644 --- a/packages/squiggle-lang/src/rescript/interpreter/PTypes.res +++ b/packages/squiggle-lang/src/rescript/interpreter/PTypes.res @@ -81,7 +81,7 @@ module SamplingDistribution = { let renderIfIsNotSamplingDistribution = (params, t): result => !isSamplingDistribution(t) - ? switch Render.render(params, t) { + ? switch Node.render(params, t) { | Ok(r) => Ok(r) | Error(e) => Error(e) } diff --git a/packages/squiggle-lang/src/rescript/interpreter/typeSystem/TypeSystem.res b/packages/squiggle-lang/src/rescript/interpreter/typeSystem/TypeSystem.res index 9535e40e..95321dee 100644 --- a/packages/squiggle-lang/src/rescript/interpreter/typeSystem/TypeSystem.res +++ b/packages/squiggle-lang/src/rescript/interpreter/typeSystem/TypeSystem.res @@ -1,5 +1,5 @@ type node = ASTTypes.AST.node -let getFloat = ASTTypes.AST.getFloat +let getFloat = ASTTypes.AST.Node.getFloat type samplingDist = [ | #SymbolicDist(SymbolicDistTypes.symbolicDist) @@ -61,7 +61,7 @@ module TypedValue = { |> E.A.fmap(((name, t)) => fromNode(t) |> E.R.fmap(r => (name, r))) |> E.A.R.firstErrorOrOpen |> E.R.fmap(r => #Hash(r)) - | e => Error("Wrong type: " ++ ASTTypes.Node.toString(e)) + | e => Error("Wrong type: " ++ ASTTypes.AST.Node.toString(e)) } // todo: Arrays and hashes @@ -78,7 +78,7 @@ module TypedValue = { node, ) |> E.R.bind(_, fromNode) | (#RenderedDistribution, _) => - ASTTypes.AST.Render.render(evaluationParams, node) |> E.R.bind(_, fromNode) + ASTTypes.AST.Node.render(evaluationParams, node) |> E.R.bind(_, fromNode) | (#Array(_type), #Array(b)) => b |> E.A.fmap(fromNodeWithTypeCoercion(evaluationParams, _type))