Refactored Operation.res to live out of the interpreter

This commit is contained in:
Ozzie Gooen 2022-02-16 14:57:46 -05:00
parent 45017f3145
commit 24a7d0eedf
13 changed files with 67 additions and 77 deletions

View File

@ -1,6 +1,6 @@
open ASTTypes.AST
let toString = ASTBasic.toString
let toString = ASTTypes.Node.toString
let envs = (samplingInputs, environment) => {
samplingInputs: samplingInputs,

View File

@ -1,27 +0,0 @@
open ASTTypes.AST
// This file exists to manage a dependency cycle. It would be good to refactor later.
let rec toString: node => string = x =>
switch x {
| #SymbolicDist(d) => SymbolicDist.T.toString(d)
| #RenderedDist(_) => "[renderedShape]"
| #AlgebraicCombination(op, t1, t2) => Operation.Algebraic.format(op, toString(t1), toString(t2))
| #PointwiseCombination(op, t1, t2) => Operation.Pointwise.format(op, toString(t1), toString(t2))
| #Normalize(t) => "normalize(k" ++ (toString(t) ++ ")")
| #Truncate(lc, rc, t) => Operation.T.truncateToString(lc, rc, toString(t))
| #Render(t) => toString(t)
| #Symbol(t) => "Symbol: " ++ t
| #FunctionCall(name, args) =>
"[Function call: (" ++
(name ++
((args |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ ")]"))
| #Function(args, internal) =>
"[Function: (" ++ ((args |> Js.String.concatMany(_, ",")) ++ (toString(internal) ++ ")]"))
| #Array(a) => "[" ++ ((a |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ "]")
| #Hash(h) =>
"{" ++
((h
|> E.A.fmap(((name, value)) => name ++ (":" ++ toString(value)))
|> Js.String.concatMany(_, ",")) ++
"}")
}

View File

@ -56,7 +56,7 @@ module AlgebraicCombination = {
let operationToLeaf = (
evaluationParams: evaluationParams,
algebraicOp: ASTTypes.algebraicOperation,
algebraicOp: Operation.algebraicOperation,
t1: t,
t2: t,
): result<node, string> =>
@ -106,7 +106,7 @@ module PointwiseCombination = {
let operationToLeaf = (
evaluationParams: evaluationParams,
pointwiseOp: pointwiseOperation,
pointwiseOp: Operation.pointwiseOperation,
t1: t,
t2: t,
) =>

View File

@ -1,20 +1,3 @@
type algebraicOperation = [
| #Add
| #Multiply
| #Subtract
| #Divide
| #Exponentiate
]
type pointwiseOperation = [#Add | #Multiply | #Exponentiate]
type scaleOperation = [#Multiply | #Exponentiate | #Log]
type distToFloatOperation = [
| #Pdf(float)
| #Cdf(float)
| #Inv(float)
| #Mean
| #Sample
]
module AST = {
type rec hash = array<(string, node)>
and node = [
@ -24,8 +7,8 @@ module AST = {
| #Hash(hash)
| #Array(array<node>)
| #Function(array<string>, node)
| #AlgebraicCombination(algebraicOperation, node, node)
| #PointwiseCombination(pointwiseOperation, node, node)
| #AlgebraicCombination(Operation.algebraicOperation, node, node)
| #PointwiseCombination(Operation.pointwiseOperation, node, node)
| #Normalize(node)
| #Render(node)
| #Truncate(option<float>, option<float>, node)
@ -172,3 +155,32 @@ module Program = {
]
type program = array<statement>
}
module Node = {
let rec toString: AST.node => string = x =>
switch x {
| #SymbolicDist(d) => SymbolicDist.T.toString(d)
| #RenderedDist(_) => "[renderedShape]"
| #AlgebraicCombination(op, t1, t2) =>
Operation.Algebraic.format(op, toString(t1), toString(t2))
| #PointwiseCombination(op, t1, t2) =>
Operation.Pointwise.format(op, toString(t1), toString(t2))
| #Normalize(t) => "normalize(k" ++ (toString(t) ++ ")")
| #Truncate(lc, rc, t) => Operation.Truncate.toString(lc, rc, toString(t))
| #Render(t) => toString(t)
| #Symbol(t) => "Symbol: " ++ t
| #FunctionCall(name, args) =>
"[Function call: (" ++
(name ++
((args |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ ")]"))
| #Function(args, internal) =>
"[Function: (" ++ ((args |> Js.String.concatMany(_, ",")) ++ (toString(internal) ++ ")]"))
| #Array(a) => "[" ++ ((a |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ "]")
| #Hash(h) =>
"{" ++
((h
|> E.A.fmap(((name, value)) => name ++ (":" ++ toString(value)))
|> Js.String.concatMany(_, ",")) ++
"}")
}
}

View File

@ -84,7 +84,7 @@ let makeDist = (name, fn) =>
)
let floatFromDist = (
distToFloatOp: ASTTypes.distToFloatOperation,
distToFloatOp: Operation.distToFloatOperation,
t: TypeSystem.samplingDist,
): result<node, string> =>
switch t {

View File

@ -61,7 +61,7 @@ module TypedValue = {
|> E.A.fmap(((name, t)) => fromNode(t) |> E.R.fmap(r => (name, r)))
|> E.A.R.firstErrorOrOpen
|> E.R.fmap(r => #Hash(r))
| e => Error("Wrong type: " ++ ASTBasic.toString(e))
| e => Error("Wrong type: " ++ ASTTypes.Node.toString(e))
}
// todo: Arrays and hashes

View File

@ -96,7 +96,7 @@ let toDiscretePointMassesFromTriangulars = (
}
let combineShapesContinuousContinuous = (
op: ASTTypes.algebraicOperation,
op: Operation.algebraicOperation,
s1: PointSetTypes.xyShape,
s2: PointSetTypes.xyShape,
): PointSetTypes.xyShape => {
@ -200,7 +200,7 @@ let toDiscretePointMassesFromDiscrete = (s: PointSetTypes.xyShape): pointMassesW
}
let combineShapesContinuousDiscrete = (
op: ASTTypes.algebraicOperation,
op: Operation.algebraicOperation,
continuousShape: PointSetTypes.xyShape,
discreteShape: PointSetTypes.xyShape,
): PointSetTypes.xyShape => {

View File

@ -211,7 +211,7 @@ module T = Dist({
/* This simply creates multiple copies of the continuous distribution, scaled and shifted according to
each discrete data point, and then adds them all together. */
let combineAlgebraicallyWithDiscrete = (
op: ASTTypes.algebraicOperation,
op: Operation.algebraicOperation,
t1: t,
t2: PointSetTypes.discreteShape,
) => {
@ -244,7 +244,7 @@ let combineAlgebraicallyWithDiscrete = (
}
}
let combineAlgebraically = (op: ASTTypes.algebraicOperation, t1: t, t2: t) => {
let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t) => {
let s1 = t1 |> getShape
let s2 = t2 |> getShape
let t1n = s1 |> XYShape.T.length

View File

@ -85,7 +85,7 @@ let updateIntegralCache = (integralCache, t: t): t => {
/* This multiples all of the data points together and creates a new discrete distribution from the results.
Data points at the same xs get added together. It may be a good idea to downsample t1 and t2 before and/or the result after. */
let combineAlgebraically = (op: ASTTypes.algebraicOperation, t1: t, t2: t): t => {
let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t): t => {
let t1s = t1 |> getShape
let t2s = t2 |> getShape
let t1n = t1s |> XYShape.T.length

View File

@ -227,7 +227,7 @@ module T = Dist({
}
})
let combineAlgebraically = (op: ASTTypes.algebraicOperation, t1: t, t2: t): t => {
let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t): t => {
// Discrete convolution can cause a huge increase in the number of samples,
// so we'll first downsample.

View File

@ -33,7 +33,7 @@ let toMixed = mapToAll((
),
))
let combineAlgebraically = (op: ASTTypes.algebraicOperation, t1: t, t2: t): t =>
let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t): t =>
switch (t1, t2) {
| (Continuous(m1), Continuous(m2)) =>
Continuous.combineAlgebraically(op, m1, m2) |> Continuous.T.toPointSetDist
@ -197,7 +197,7 @@ let sampleNRendered = (n, dist) => {
doN(n, () => sample(distWithUpdatedIntegralCache))
}
let operate = (distToFloatOp: ASTTypes.distToFloatOperation, s): float =>
let operate = (distToFloatOp: Operation.distToFloatOperation, s): float =>
switch distToFloatOp {
| #Pdf(f) => pdf(f, s)
| #Cdf(f) => pdf(f, s)

View File

@ -272,7 +272,7 @@ module T = {
| #Float(n) => Float.mean(n)
}
let operate = (distToFloatOp: ASTTypes.distToFloatOperation, s) =>
let operate = (distToFloatOp: Operation.distToFloatOperation, s) =>
switch distToFloatOp {
| #Cdf(f) => Ok(cdf(f, s))
| #Pdf(f) => Ok(pdf(f, s))
@ -302,7 +302,7 @@ module T = {
let tryAnalyticalSimplification = (
d1: symbolicDist,
d2: symbolicDist,
op: ASTTypes.algebraicOperation,
op: Operation.algebraicOperation,
): analyticalSimplificationResult =>
switch (d1, d2) {
| (#Float(v1), #Float(v2)) =>

View File

@ -1,4 +1,21 @@
open ASTTypes
// This file has no dependencies. It's used outside of the interpreter, but the interpreter depends on it.
type algebraicOperation = [
| #Add
| #Multiply
| #Subtract
| #Divide
| #Exponentiate
]
type pointwiseOperation = [#Add | #Multiply | #Exponentiate]
type scaleOperation = [#Multiply | #Exponentiate | #Log]
type distToFloatOperation = [
| #Pdf(float)
| #Cdf(float)
| #Inv(float)
| #Mean
| #Sample
]
module Algebraic = {
type t = algebraicOperation
@ -86,22 +103,10 @@ module Scale = {
}
}
module T = {
let truncateToString = (left: option<float>, right: option<float>, nodeToString) => {
module Truncate = {
let toString = (left: option<float>, right: option<float>, nodeToString) => {
let left = left |> E.O.dimap(Js.Float.toString, () => "-inf")
let right = right |> E.O.dimap(Js.Float.toString, () => "inf")
j`truncate($nodeToString, $left, $right)`
}
let toString = (nodeToString, x) =>
switch x {
| #AlgebraicCombination(op, t1, t2) => Algebraic.format(op, nodeToString(t1), nodeToString(t2))
| #PointwiseCombination(op, t1, t2) => Pointwise.format(op, nodeToString(t1), nodeToString(t2))
| #VerticalScaling(scaleOp, t, scaleBy) =>
Scale.format(scaleOp, nodeToString(t), nodeToString(scaleBy))
| #Normalize(t) => "normalize(k" ++ (nodeToString(t) ++ ")")
| #FloatFromDist(floatFromDistOp, t) => DistToFloat.format(floatFromDistOp, nodeToString(t))
| #Truncate(lc, rc, t) => truncateToString(lc, rc, nodeToString(t))
| #Render(t) => nodeToString(t)
| _ => ""
} // SymbolicDist and RenderedDist are handled in AST.toString.
}