Refactored AST file

This commit is contained in:
Ozzie Gooen 2022-02-16 17:10:48 -05:00
parent 24a7d0eedf
commit d8b37bb113
5 changed files with 49 additions and 58 deletions

View File

@ -1,6 +1,6 @@
open ASTTypes.AST open ASTTypes.AST
let toString = ASTTypes.Node.toString let toString = ASTTypes.AST.Node.toString
let envs = (samplingInputs, environment) => { let envs = (samplingInputs, environment) => {
samplingInputs: samplingInputs, samplingInputs: samplingInputs,

View File

@ -25,8 +25,8 @@ module AlgebraicCombination = {
string, string,
> => > =>
E.R.merge( E.R.merge(
Render.ensureIsRenderedAndGetShape(evaluationParams, t1), Node.ensureIsRenderedAndGetShape(evaluationParams, t1),
Render.ensureIsRenderedAndGetShape(evaluationParams, t2), Node.ensureIsRenderedAndGetShape(evaluationParams, t2),
) |> E.R.fmap(((a, b)) => #RenderedDist(PointSetDist.combineAlgebraically(algebraicOp, a, b))) ) |> E.R.fmap(((a, b)) => #RenderedDist(PointSetDist.combineAlgebraically(algebraicOp, a, b)))
let nodeScore: node => int = x => let nodeScore: node => int = x =>
@ -72,7 +72,7 @@ module AlgebraicCombination = {
module PointwiseCombination = { module PointwiseCombination = {
let pointwiseAdd = (evaluationParams: evaluationParams, t1: t, t2: t) => let pointwiseAdd = (evaluationParams: evaluationParams, t1: t, t2: t) =>
switch (Render.render(evaluationParams, t1), Render.render(evaluationParams, t2)) { switch (Node.render(evaluationParams, t1), Node.render(evaluationParams, t2)) {
| (Ok(#RenderedDist(rs1)), Ok(#RenderedDist(rs2))) => | (Ok(#RenderedDist(rs1)), Ok(#RenderedDist(rs2))) =>
Ok( Ok(
#RenderedDist( #RenderedDist(
@ -96,7 +96,7 @@ module PointwiseCombination = {
switch // TODO: construct a function that we can easily sample from, to construct switch // TODO: construct a function that we can easily sample from, to construct
// a RenderedDist. Use the xMin and xMax of the rendered pointSetDists to tell the sampling function where to look. // a RenderedDist. Use the xMin and xMax of the rendered pointSetDists to tell the sampling function where to look.
// TODO: This should work for symbolic distributions too! // TODO: This should work for symbolic distributions too!
(Render.render(evaluationParams, t1), Render.render(evaluationParams, t2)) { (Node.render(evaluationParams, t1), Node.render(evaluationParams, t2)) {
| (Ok(#RenderedDist(rs1)), Ok(#RenderedDist(rs2))) => | (Ok(#RenderedDist(rs1)), Ok(#RenderedDist(rs2))) =>
Ok(#RenderedDist(PointSetDist.combinePointwise(fn, rs1, rs2))) Ok(#RenderedDist(PointSetDist.combinePointwise(fn, rs1, rs2)))
| (Error(e1), _) => Error(e1) | (Error(e1), _) => Error(e1)
@ -131,7 +131,7 @@ module Truncate = {
let truncateAsShape = (evaluationParams: evaluationParams, leftCutoff, rightCutoff, t) => let truncateAsShape = (evaluationParams: evaluationParams, leftCutoff, rightCutoff, t) =>
switch // TODO: use named args for xMin/xMax in renderToShape; if we're lucky we can at least get the tail switch // TODO: use named args for xMin/xMax in renderToShape; if we're lucky we can at least get the tail
// of a distribution we otherwise wouldn't get at all // of a distribution we otherwise wouldn't get at all
Render.ensureIsRendered(evaluationParams, t) { Node.ensureIsRendered(evaluationParams, t) {
| Ok(#RenderedDist(rs)) => Ok(#RenderedDist(PointSetDist.T.truncate(leftCutoff, rightCutoff, rs))) | Ok(#RenderedDist(rs)) => Ok(#RenderedDist(PointSetDist.T.truncate(leftCutoff, rightCutoff, rs)))
| Error(e) => Error(e) | Error(e) => Error(e)
| _ => Error("Could not truncate distribution.") | _ => Error("Could not truncate distribution.")

View File

@ -27,21 +27,6 @@ module AST = {
names |> E.A.fmap(name => (name, getByName(hash, name))) names |> E.A.fmap(name => (name, getByName(hash, name)))
} }
// Have nil as option // Have nil as option
let getFloat = (node: node) =>
node |> (
x =>
switch x {
| #RenderedDist(Discrete({xyShape: {xs: [x], ys: [1.0]}})) => Some(x)
| #SymbolicDist(#Float(x)) => Some(x)
| _ => None
}
)
let toFloatIfNeeded = (node: node) =>
switch node |> getFloat {
| Some(float) => #SymbolicDist(#Float(float))
| None => node
}
type samplingInputs = { type samplingInputs = {
sampleCount: int, sampleCount: int,
@ -101,63 +86,18 @@ module AST = {
let evaluateAndRetry = (evaluationParams, fn, node) => let evaluateAndRetry = (evaluationParams, fn, node) =>
node |> evaluationParams.evaluateNode(evaluationParams) |> E.R.bind(_, fn(evaluationParams)) node |> evaluationParams.evaluateNode(evaluationParams) |> E.R.bind(_, fn(evaluationParams))
module Render = { module Node = {
type t = node let getFloat = (node: node) =>
node |> (
let render = (evaluationParams: evaluationParams, r) => x =>
#Render(r) |> evaluateNode(evaluationParams) switch x {
| #RenderedDist(Discrete({xyShape: {xs: [x], ys: [1.0]}})) => Some(x)
let ensureIsRendered = (params, t) => | #SymbolicDist(#Float(x)) => Some(x)
switch t {
| #RenderedDist(_) => Ok(t)
| _ =>
switch render(params, t) {
| Ok(#RenderedDist(r)) => Ok(#RenderedDist(r))
| Ok(_) => Error("Did not render as requested")
| Error(e) => Error(e)
}
}
let ensureIsRenderedAndGetShape = (params, t) =>
switch ensureIsRendered(params, t) {
| Ok(#RenderedDist(r)) => Ok(r)
| Ok(_) => Error("Did not render as requested")
| Error(e) => Error(e)
}
let getShape = (item: node) =>
switch item {
| #RenderedDist(r) => Some(r)
| _ => None | _ => None
} }
)
let _toFloat = (t: PointSetTypes.pointSetDist) => let rec toString: node => string = x =>
switch t {
| Discrete({xyShape: {xs: [x], ys: [1.0]}}) => Some(#SymbolicDist(#Float(x)))
| _ => None
}
let toFloat = (item: node): result<node, string> =>
item |> getShape |> E.O.bind(_, _toFloat) |> E.O.toResult("Not valid shape")
}
}
type simplificationResult = [
| #Solution(AST.node)
| #Error(string)
| #NoSolution
]
module Program = {
type statement = [
| #Assignment(string, AST.node)
| #Expression(AST.node)
]
type program = array<statement>
}
module Node = {
let rec toString: AST.node => string = x =>
switch x { switch x {
| #SymbolicDist(d) => SymbolicDist.T.toString(d) | #SymbolicDist(d) => SymbolicDist.T.toString(d)
| #RenderedDist(_) => "[renderedShape]" | #RenderedDist(_) => "[renderedShape]"
@ -183,4 +123,55 @@ module Node = {
|> Js.String.concatMany(_, ",")) ++ |> Js.String.concatMany(_, ",")) ++
"}") "}")
} }
let render = (evaluationParams: evaluationParams, r) =>
#Render(r) |> evaluateNode(evaluationParams)
let ensureIsRendered = (params, t) =>
switch t {
| #RenderedDist(_) => Ok(t)
| _ =>
switch render(params, t) {
| Ok(#RenderedDist(r)) => Ok(#RenderedDist(r))
| Ok(_) => Error("Did not render as requested")
| Error(e) => Error(e)
}
}
let ensureIsRenderedAndGetShape = (params, t) =>
switch ensureIsRendered(params, t) {
| Ok(#RenderedDist(r)) => Ok(r)
| Ok(_) => Error("Did not render as requested")
| Error(e) => Error(e)
}
let toPointSetDist = (item: node) =>
switch item {
| #RenderedDist(r) => Some(r)
| _ => None
}
let _toFloat = (t: PointSetTypes.pointSetDist) =>
switch t {
| Discrete({xyShape: {xs: [x], ys: [1.0]}}) => Some(#SymbolicDist(#Float(x)))
| _ => None
}
let toFloat = (item: node): result<node, string> =>
item |> toPointSetDist |> E.O.bind(_, _toFloat) |> E.O.toResult("Not valid shape")
}
}
type simplificationResult = [
| #Solution(AST.node)
| #Error(string)
| #NoSolution
]
module Program = {
type statement = [
| #Assignment(string, AST.node)
| #Expression(AST.node)
]
type program = array<statement>
} }

View File

@ -81,7 +81,7 @@ module SamplingDistribution = {
let renderIfIsNotSamplingDistribution = (params, t): result<node, string> => let renderIfIsNotSamplingDistribution = (params, t): result<node, string> =>
!isSamplingDistribution(t) !isSamplingDistribution(t)
? switch Render.render(params, t) { ? switch Node.render(params, t) {
| Ok(r) => Ok(r) | Ok(r) => Ok(r)
| Error(e) => Error(e) | Error(e) => Error(e)
} }

View File

@ -1,5 +1,5 @@
type node = ASTTypes.AST.node type node = ASTTypes.AST.node
let getFloat = ASTTypes.AST.getFloat let getFloat = ASTTypes.AST.Node.getFloat
type samplingDist = [ type samplingDist = [
| #SymbolicDist(SymbolicDistTypes.symbolicDist) | #SymbolicDist(SymbolicDistTypes.symbolicDist)
@ -61,7 +61,7 @@ module TypedValue = {
|> E.A.fmap(((name, t)) => fromNode(t) |> E.R.fmap(r => (name, r))) |> E.A.fmap(((name, t)) => fromNode(t) |> E.R.fmap(r => (name, r)))
|> E.A.R.firstErrorOrOpen |> E.A.R.firstErrorOrOpen
|> E.R.fmap(r => #Hash(r)) |> E.R.fmap(r => #Hash(r))
| e => Error("Wrong type: " ++ ASTTypes.Node.toString(e)) | e => Error("Wrong type: " ++ ASTTypes.AST.Node.toString(e))
} }
// todo: Arrays and hashes // todo: Arrays and hashes
@ -78,7 +78,7 @@ module TypedValue = {
node, node,
) |> E.R.bind(_, fromNode) ) |> E.R.bind(_, fromNode)
| (#RenderedDistribution, _) => | (#RenderedDistribution, _) =>
ASTTypes.AST.Render.render(evaluationParams, node) |> E.R.bind(_, fromNode) ASTTypes.AST.Node.render(evaluationParams, node) |> E.R.bind(_, fromNode)
| (#Array(_type), #Array(b)) => | (#Array(_type), #Array(b)) =>
b b
|> E.A.fmap(fromNodeWithTypeCoercion(evaluationParams, _type)) |> E.A.fmap(fromNodeWithTypeCoercion(evaluationParams, _type))