Refactored AST file
This commit is contained in:
parent
24a7d0eedf
commit
d8b37bb113
|
@ -1,6 +1,6 @@
|
||||||
open ASTTypes.AST
|
open ASTTypes.AST
|
||||||
|
|
||||||
let toString = ASTTypes.Node.toString
|
let toString = ASTTypes.AST.Node.toString
|
||||||
|
|
||||||
let envs = (samplingInputs, environment) => {
|
let envs = (samplingInputs, environment) => {
|
||||||
samplingInputs: samplingInputs,
|
samplingInputs: samplingInputs,
|
||||||
|
|
|
@ -25,8 +25,8 @@ module AlgebraicCombination = {
|
||||||
string,
|
string,
|
||||||
> =>
|
> =>
|
||||||
E.R.merge(
|
E.R.merge(
|
||||||
Render.ensureIsRenderedAndGetShape(evaluationParams, t1),
|
Node.ensureIsRenderedAndGetShape(evaluationParams, t1),
|
||||||
Render.ensureIsRenderedAndGetShape(evaluationParams, t2),
|
Node.ensureIsRenderedAndGetShape(evaluationParams, t2),
|
||||||
) |> E.R.fmap(((a, b)) => #RenderedDist(PointSetDist.combineAlgebraically(algebraicOp, a, b)))
|
) |> E.R.fmap(((a, b)) => #RenderedDist(PointSetDist.combineAlgebraically(algebraicOp, a, b)))
|
||||||
|
|
||||||
let nodeScore: node => int = x =>
|
let nodeScore: node => int = x =>
|
||||||
|
@ -72,7 +72,7 @@ module AlgebraicCombination = {
|
||||||
|
|
||||||
module PointwiseCombination = {
|
module PointwiseCombination = {
|
||||||
let pointwiseAdd = (evaluationParams: evaluationParams, t1: t, t2: t) =>
|
let pointwiseAdd = (evaluationParams: evaluationParams, t1: t, t2: t) =>
|
||||||
switch (Render.render(evaluationParams, t1), Render.render(evaluationParams, t2)) {
|
switch (Node.render(evaluationParams, t1), Node.render(evaluationParams, t2)) {
|
||||||
| (Ok(#RenderedDist(rs1)), Ok(#RenderedDist(rs2))) =>
|
| (Ok(#RenderedDist(rs1)), Ok(#RenderedDist(rs2))) =>
|
||||||
Ok(
|
Ok(
|
||||||
#RenderedDist(
|
#RenderedDist(
|
||||||
|
@ -96,7 +96,7 @@ module PointwiseCombination = {
|
||||||
switch // TODO: construct a function that we can easily sample from, to construct
|
switch // TODO: construct a function that we can easily sample from, to construct
|
||||||
// a RenderedDist. Use the xMin and xMax of the rendered pointSetDists to tell the sampling function where to look.
|
// a RenderedDist. Use the xMin and xMax of the rendered pointSetDists to tell the sampling function where to look.
|
||||||
// TODO: This should work for symbolic distributions too!
|
// TODO: This should work for symbolic distributions too!
|
||||||
(Render.render(evaluationParams, t1), Render.render(evaluationParams, t2)) {
|
(Node.render(evaluationParams, t1), Node.render(evaluationParams, t2)) {
|
||||||
| (Ok(#RenderedDist(rs1)), Ok(#RenderedDist(rs2))) =>
|
| (Ok(#RenderedDist(rs1)), Ok(#RenderedDist(rs2))) =>
|
||||||
Ok(#RenderedDist(PointSetDist.combinePointwise(fn, rs1, rs2)))
|
Ok(#RenderedDist(PointSetDist.combinePointwise(fn, rs1, rs2)))
|
||||||
| (Error(e1), _) => Error(e1)
|
| (Error(e1), _) => Error(e1)
|
||||||
|
@ -131,7 +131,7 @@ module Truncate = {
|
||||||
let truncateAsShape = (evaluationParams: evaluationParams, leftCutoff, rightCutoff, t) =>
|
let truncateAsShape = (evaluationParams: evaluationParams, leftCutoff, rightCutoff, t) =>
|
||||||
switch // TODO: use named args for xMin/xMax in renderToShape; if we're lucky we can at least get the tail
|
switch // TODO: use named args for xMin/xMax in renderToShape; if we're lucky we can at least get the tail
|
||||||
// of a distribution we otherwise wouldn't get at all
|
// of a distribution we otherwise wouldn't get at all
|
||||||
Render.ensureIsRendered(evaluationParams, t) {
|
Node.ensureIsRendered(evaluationParams, t) {
|
||||||
| Ok(#RenderedDist(rs)) => Ok(#RenderedDist(PointSetDist.T.truncate(leftCutoff, rightCutoff, rs)))
|
| Ok(#RenderedDist(rs)) => Ok(#RenderedDist(PointSetDist.T.truncate(leftCutoff, rightCutoff, rs)))
|
||||||
| Error(e) => Error(e)
|
| Error(e) => Error(e)
|
||||||
| _ => Error("Could not truncate distribution.")
|
| _ => Error("Could not truncate distribution.")
|
||||||
|
|
|
@ -27,21 +27,6 @@ module AST = {
|
||||||
names |> E.A.fmap(name => (name, getByName(hash, name)))
|
names |> E.A.fmap(name => (name, getByName(hash, name)))
|
||||||
}
|
}
|
||||||
// Have nil as option
|
// Have nil as option
|
||||||
let getFloat = (node: node) =>
|
|
||||||
node |> (
|
|
||||||
x =>
|
|
||||||
switch x {
|
|
||||||
| #RenderedDist(Discrete({xyShape: {xs: [x], ys: [1.0]}})) => Some(x)
|
|
||||||
| #SymbolicDist(#Float(x)) => Some(x)
|
|
||||||
| _ => None
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
let toFloatIfNeeded = (node: node) =>
|
|
||||||
switch node |> getFloat {
|
|
||||||
| Some(float) => #SymbolicDist(#Float(float))
|
|
||||||
| None => node
|
|
||||||
}
|
|
||||||
|
|
||||||
type samplingInputs = {
|
type samplingInputs = {
|
||||||
sampleCount: int,
|
sampleCount: int,
|
||||||
|
@ -101,8 +86,43 @@ module AST = {
|
||||||
let evaluateAndRetry = (evaluationParams, fn, node) =>
|
let evaluateAndRetry = (evaluationParams, fn, node) =>
|
||||||
node |> evaluationParams.evaluateNode(evaluationParams) |> E.R.bind(_, fn(evaluationParams))
|
node |> evaluationParams.evaluateNode(evaluationParams) |> E.R.bind(_, fn(evaluationParams))
|
||||||
|
|
||||||
module Render = {
|
module Node = {
|
||||||
type t = node
|
let getFloat = (node: node) =>
|
||||||
|
node |> (
|
||||||
|
x =>
|
||||||
|
switch x {
|
||||||
|
| #RenderedDist(Discrete({xyShape: {xs: [x], ys: [1.0]}})) => Some(x)
|
||||||
|
| #SymbolicDist(#Float(x)) => Some(x)
|
||||||
|
| _ => None
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
let rec toString: node => string = x =>
|
||||||
|
switch x {
|
||||||
|
| #SymbolicDist(d) => SymbolicDist.T.toString(d)
|
||||||
|
| #RenderedDist(_) => "[renderedShape]"
|
||||||
|
| #AlgebraicCombination(op, t1, t2) =>
|
||||||
|
Operation.Algebraic.format(op, toString(t1), toString(t2))
|
||||||
|
| #PointwiseCombination(op, t1, t2) =>
|
||||||
|
Operation.Pointwise.format(op, toString(t1), toString(t2))
|
||||||
|
| #Normalize(t) => "normalize(k" ++ (toString(t) ++ ")")
|
||||||
|
| #Truncate(lc, rc, t) => Operation.Truncate.toString(lc, rc, toString(t))
|
||||||
|
| #Render(t) => toString(t)
|
||||||
|
| #Symbol(t) => "Symbol: " ++ t
|
||||||
|
| #FunctionCall(name, args) =>
|
||||||
|
"[Function call: (" ++
|
||||||
|
(name ++
|
||||||
|
((args |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ ")]"))
|
||||||
|
| #Function(args, internal) =>
|
||||||
|
"[Function: (" ++ ((args |> Js.String.concatMany(_, ",")) ++ (toString(internal) ++ ")]"))
|
||||||
|
| #Array(a) => "[" ++ ((a |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ "]")
|
||||||
|
| #Hash(h) =>
|
||||||
|
"{" ++
|
||||||
|
((h
|
||||||
|
|> E.A.fmap(((name, value)) => name ++ (":" ++ toString(value)))
|
||||||
|
|> Js.String.concatMany(_, ",")) ++
|
||||||
|
"}")
|
||||||
|
}
|
||||||
|
|
||||||
let render = (evaluationParams: evaluationParams, r) =>
|
let render = (evaluationParams: evaluationParams, r) =>
|
||||||
#Render(r) |> evaluateNode(evaluationParams)
|
#Render(r) |> evaluateNode(evaluationParams)
|
||||||
|
@ -125,7 +145,7 @@ module AST = {
|
||||||
| Error(e) => Error(e)
|
| Error(e) => Error(e)
|
||||||
}
|
}
|
||||||
|
|
||||||
let getShape = (item: node) =>
|
let toPointSetDist = (item: node) =>
|
||||||
switch item {
|
switch item {
|
||||||
| #RenderedDist(r) => Some(r)
|
| #RenderedDist(r) => Some(r)
|
||||||
| _ => None
|
| _ => None
|
||||||
|
@ -138,7 +158,7 @@ module AST = {
|
||||||
}
|
}
|
||||||
|
|
||||||
let toFloat = (item: node): result<node, string> =>
|
let toFloat = (item: node): result<node, string> =>
|
||||||
item |> getShape |> E.O.bind(_, _toFloat) |> E.O.toResult("Not valid shape")
|
item |> toPointSetDist |> E.O.bind(_, _toFloat) |> E.O.toResult("Not valid shape")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -155,32 +175,3 @@ module Program = {
|
||||||
]
|
]
|
||||||
type program = array<statement>
|
type program = array<statement>
|
||||||
}
|
}
|
||||||
|
|
||||||
module Node = {
|
|
||||||
let rec toString: AST.node => string = x =>
|
|
||||||
switch x {
|
|
||||||
| #SymbolicDist(d) => SymbolicDist.T.toString(d)
|
|
||||||
| #RenderedDist(_) => "[renderedShape]"
|
|
||||||
| #AlgebraicCombination(op, t1, t2) =>
|
|
||||||
Operation.Algebraic.format(op, toString(t1), toString(t2))
|
|
||||||
| #PointwiseCombination(op, t1, t2) =>
|
|
||||||
Operation.Pointwise.format(op, toString(t1), toString(t2))
|
|
||||||
| #Normalize(t) => "normalize(k" ++ (toString(t) ++ ")")
|
|
||||||
| #Truncate(lc, rc, t) => Operation.Truncate.toString(lc, rc, toString(t))
|
|
||||||
| #Render(t) => toString(t)
|
|
||||||
| #Symbol(t) => "Symbol: " ++ t
|
|
||||||
| #FunctionCall(name, args) =>
|
|
||||||
"[Function call: (" ++
|
|
||||||
(name ++
|
|
||||||
((args |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ ")]"))
|
|
||||||
| #Function(args, internal) =>
|
|
||||||
"[Function: (" ++ ((args |> Js.String.concatMany(_, ",")) ++ (toString(internal) ++ ")]"))
|
|
||||||
| #Array(a) => "[" ++ ((a |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ "]")
|
|
||||||
| #Hash(h) =>
|
|
||||||
"{" ++
|
|
||||||
((h
|
|
||||||
|> E.A.fmap(((name, value)) => name ++ (":" ++ toString(value)))
|
|
||||||
|> Js.String.concatMany(_, ",")) ++
|
|
||||||
"}")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -81,7 +81,7 @@ module SamplingDistribution = {
|
||||||
|
|
||||||
let renderIfIsNotSamplingDistribution = (params, t): result<node, string> =>
|
let renderIfIsNotSamplingDistribution = (params, t): result<node, string> =>
|
||||||
!isSamplingDistribution(t)
|
!isSamplingDistribution(t)
|
||||||
? switch Render.render(params, t) {
|
? switch Node.render(params, t) {
|
||||||
| Ok(r) => Ok(r)
|
| Ok(r) => Ok(r)
|
||||||
| Error(e) => Error(e)
|
| Error(e) => Error(e)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
type node = ASTTypes.AST.node
|
type node = ASTTypes.AST.node
|
||||||
let getFloat = ASTTypes.AST.getFloat
|
let getFloat = ASTTypes.AST.Node.getFloat
|
||||||
|
|
||||||
type samplingDist = [
|
type samplingDist = [
|
||||||
| #SymbolicDist(SymbolicDistTypes.symbolicDist)
|
| #SymbolicDist(SymbolicDistTypes.symbolicDist)
|
||||||
|
@ -61,7 +61,7 @@ module TypedValue = {
|
||||||
|> E.A.fmap(((name, t)) => fromNode(t) |> E.R.fmap(r => (name, r)))
|
|> E.A.fmap(((name, t)) => fromNode(t) |> E.R.fmap(r => (name, r)))
|
||||||
|> E.A.R.firstErrorOrOpen
|
|> E.A.R.firstErrorOrOpen
|
||||||
|> E.R.fmap(r => #Hash(r))
|
|> E.R.fmap(r => #Hash(r))
|
||||||
| e => Error("Wrong type: " ++ ASTTypes.Node.toString(e))
|
| e => Error("Wrong type: " ++ ASTTypes.AST.Node.toString(e))
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo: Arrays and hashes
|
// todo: Arrays and hashes
|
||||||
|
@ -78,7 +78,7 @@ module TypedValue = {
|
||||||
node,
|
node,
|
||||||
) |> E.R.bind(_, fromNode)
|
) |> E.R.bind(_, fromNode)
|
||||||
| (#RenderedDistribution, _) =>
|
| (#RenderedDistribution, _) =>
|
||||||
ASTTypes.AST.Render.render(evaluationParams, node) |> E.R.bind(_, fromNode)
|
ASTTypes.AST.Node.render(evaluationParams, node) |> E.R.bind(_, fromNode)
|
||||||
| (#Array(_type), #Array(b)) =>
|
| (#Array(_type), #Array(b)) =>
|
||||||
b
|
b
|
||||||
|> E.A.fmap(fromNodeWithTypeCoercion(evaluationParams, _type))
|
|> E.A.fmap(fromNodeWithTypeCoercion(evaluationParams, _type))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user