Refactored Operation.res to live out of the interpreter

This commit is contained in:
Ozzie Gooen 2022-02-16 14:57:46 -05:00
parent 45017f3145
commit 24a7d0eedf
13 changed files with 67 additions and 77 deletions

View File

@ -1,6 +1,6 @@
open ASTTypes.AST open ASTTypes.AST
let toString = ASTBasic.toString let toString = ASTTypes.Node.toString
let envs = (samplingInputs, environment) => { let envs = (samplingInputs, environment) => {
samplingInputs: samplingInputs, samplingInputs: samplingInputs,

View File

@ -1,27 +0,0 @@
open ASTTypes.AST
// This file exists to manage a dependency cycle. It would be good to refactor later.
let rec toString: node => string = x =>
switch x {
| #SymbolicDist(d) => SymbolicDist.T.toString(d)
| #RenderedDist(_) => "[renderedShape]"
| #AlgebraicCombination(op, t1, t2) => Operation.Algebraic.format(op, toString(t1), toString(t2))
| #PointwiseCombination(op, t1, t2) => Operation.Pointwise.format(op, toString(t1), toString(t2))
| #Normalize(t) => "normalize(k" ++ (toString(t) ++ ")")
| #Truncate(lc, rc, t) => Operation.T.truncateToString(lc, rc, toString(t))
| #Render(t) => toString(t)
| #Symbol(t) => "Symbol: " ++ t
| #FunctionCall(name, args) =>
"[Function call: (" ++
(name ++
((args |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ ")]"))
| #Function(args, internal) =>
"[Function: (" ++ ((args |> Js.String.concatMany(_, ",")) ++ (toString(internal) ++ ")]"))
| #Array(a) => "[" ++ ((a |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ "]")
| #Hash(h) =>
"{" ++
((h
|> E.A.fmap(((name, value)) => name ++ (":" ++ toString(value)))
|> Js.String.concatMany(_, ",")) ++
"}")
}

View File

@ -56,7 +56,7 @@ module AlgebraicCombination = {
let operationToLeaf = ( let operationToLeaf = (
evaluationParams: evaluationParams, evaluationParams: evaluationParams,
algebraicOp: ASTTypes.algebraicOperation, algebraicOp: Operation.algebraicOperation,
t1: t, t1: t,
t2: t, t2: t,
): result<node, string> => ): result<node, string> =>
@ -106,7 +106,7 @@ module PointwiseCombination = {
let operationToLeaf = ( let operationToLeaf = (
evaluationParams: evaluationParams, evaluationParams: evaluationParams,
pointwiseOp: pointwiseOperation, pointwiseOp: Operation.pointwiseOperation,
t1: t, t1: t,
t2: t, t2: t,
) => ) =>

View File

@ -1,20 +1,3 @@
type algebraicOperation = [
| #Add
| #Multiply
| #Subtract
| #Divide
| #Exponentiate
]
type pointwiseOperation = [#Add | #Multiply | #Exponentiate]
type scaleOperation = [#Multiply | #Exponentiate | #Log]
type distToFloatOperation = [
| #Pdf(float)
| #Cdf(float)
| #Inv(float)
| #Mean
| #Sample
]
module AST = { module AST = {
type rec hash = array<(string, node)> type rec hash = array<(string, node)>
and node = [ and node = [
@ -24,8 +7,8 @@ module AST = {
| #Hash(hash) | #Hash(hash)
| #Array(array<node>) | #Array(array<node>)
| #Function(array<string>, node) | #Function(array<string>, node)
| #AlgebraicCombination(algebraicOperation, node, node) | #AlgebraicCombination(Operation.algebraicOperation, node, node)
| #PointwiseCombination(pointwiseOperation, node, node) | #PointwiseCombination(Operation.pointwiseOperation, node, node)
| #Normalize(node) | #Normalize(node)
| #Render(node) | #Render(node)
| #Truncate(option<float>, option<float>, node) | #Truncate(option<float>, option<float>, node)
@ -172,3 +155,32 @@ module Program = {
] ]
type program = array<statement> type program = array<statement>
} }
module Node = {
let rec toString: AST.node => string = x =>
switch x {
| #SymbolicDist(d) => SymbolicDist.T.toString(d)
| #RenderedDist(_) => "[renderedShape]"
| #AlgebraicCombination(op, t1, t2) =>
Operation.Algebraic.format(op, toString(t1), toString(t2))
| #PointwiseCombination(op, t1, t2) =>
Operation.Pointwise.format(op, toString(t1), toString(t2))
| #Normalize(t) => "normalize(k" ++ (toString(t) ++ ")")
| #Truncate(lc, rc, t) => Operation.Truncate.toString(lc, rc, toString(t))
| #Render(t) => toString(t)
| #Symbol(t) => "Symbol: " ++ t
| #FunctionCall(name, args) =>
"[Function call: (" ++
(name ++
((args |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ ")]"))
| #Function(args, internal) =>
"[Function: (" ++ ((args |> Js.String.concatMany(_, ",")) ++ (toString(internal) ++ ")]"))
| #Array(a) => "[" ++ ((a |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ "]")
| #Hash(h) =>
"{" ++
((h
|> E.A.fmap(((name, value)) => name ++ (":" ++ toString(value)))
|> Js.String.concatMany(_, ",")) ++
"}")
}
}

View File

@ -84,7 +84,7 @@ let makeDist = (name, fn) =>
) )
let floatFromDist = ( let floatFromDist = (
distToFloatOp: ASTTypes.distToFloatOperation, distToFloatOp: Operation.distToFloatOperation,
t: TypeSystem.samplingDist, t: TypeSystem.samplingDist,
): result<node, string> => ): result<node, string> =>
switch t { switch t {

View File

@ -61,7 +61,7 @@ module TypedValue = {
|> E.A.fmap(((name, t)) => fromNode(t) |> E.R.fmap(r => (name, r))) |> E.A.fmap(((name, t)) => fromNode(t) |> E.R.fmap(r => (name, r)))
|> E.A.R.firstErrorOrOpen |> E.A.R.firstErrorOrOpen
|> E.R.fmap(r => #Hash(r)) |> E.R.fmap(r => #Hash(r))
| e => Error("Wrong type: " ++ ASTBasic.toString(e)) | e => Error("Wrong type: " ++ ASTTypes.Node.toString(e))
} }
// todo: Arrays and hashes // todo: Arrays and hashes

View File

@ -96,7 +96,7 @@ let toDiscretePointMassesFromTriangulars = (
} }
let combineShapesContinuousContinuous = ( let combineShapesContinuousContinuous = (
op: ASTTypes.algebraicOperation, op: Operation.algebraicOperation,
s1: PointSetTypes.xyShape, s1: PointSetTypes.xyShape,
s2: PointSetTypes.xyShape, s2: PointSetTypes.xyShape,
): PointSetTypes.xyShape => { ): PointSetTypes.xyShape => {
@ -200,7 +200,7 @@ let toDiscretePointMassesFromDiscrete = (s: PointSetTypes.xyShape): pointMassesW
} }
let combineShapesContinuousDiscrete = ( let combineShapesContinuousDiscrete = (
op: ASTTypes.algebraicOperation, op: Operation.algebraicOperation,
continuousShape: PointSetTypes.xyShape, continuousShape: PointSetTypes.xyShape,
discreteShape: PointSetTypes.xyShape, discreteShape: PointSetTypes.xyShape,
): PointSetTypes.xyShape => { ): PointSetTypes.xyShape => {

View File

@ -211,7 +211,7 @@ module T = Dist({
/* This simply creates multiple copies of the continuous distribution, scaled and shifted according to /* This simply creates multiple copies of the continuous distribution, scaled and shifted according to
each discrete data point, and then adds them all together. */ each discrete data point, and then adds them all together. */
let combineAlgebraicallyWithDiscrete = ( let combineAlgebraicallyWithDiscrete = (
op: ASTTypes.algebraicOperation, op: Operation.algebraicOperation,
t1: t, t1: t,
t2: PointSetTypes.discreteShape, t2: PointSetTypes.discreteShape,
) => { ) => {
@ -244,7 +244,7 @@ let combineAlgebraicallyWithDiscrete = (
} }
} }
let combineAlgebraically = (op: ASTTypes.algebraicOperation, t1: t, t2: t) => { let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t) => {
let s1 = t1 |> getShape let s1 = t1 |> getShape
let s2 = t2 |> getShape let s2 = t2 |> getShape
let t1n = s1 |> XYShape.T.length let t1n = s1 |> XYShape.T.length

View File

@ -85,7 +85,7 @@ let updateIntegralCache = (integralCache, t: t): t => {
/* This multiples all of the data points together and creates a new discrete distribution from the results. /* This multiples all of the data points together and creates a new discrete distribution from the results.
Data points at the same xs get added together. It may be a good idea to downsample t1 and t2 before and/or the result after. */ Data points at the same xs get added together. It may be a good idea to downsample t1 and t2 before and/or the result after. */
let combineAlgebraically = (op: ASTTypes.algebraicOperation, t1: t, t2: t): t => { let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t): t => {
let t1s = t1 |> getShape let t1s = t1 |> getShape
let t2s = t2 |> getShape let t2s = t2 |> getShape
let t1n = t1s |> XYShape.T.length let t1n = t1s |> XYShape.T.length

View File

@ -227,7 +227,7 @@ module T = Dist({
} }
}) })
let combineAlgebraically = (op: ASTTypes.algebraicOperation, t1: t, t2: t): t => { let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t): t => {
// Discrete convolution can cause a huge increase in the number of samples, // Discrete convolution can cause a huge increase in the number of samples,
// so we'll first downsample. // so we'll first downsample.

View File

@ -33,7 +33,7 @@ let toMixed = mapToAll((
), ),
)) ))
let combineAlgebraically = (op: ASTTypes.algebraicOperation, t1: t, t2: t): t => let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t): t =>
switch (t1, t2) { switch (t1, t2) {
| (Continuous(m1), Continuous(m2)) => | (Continuous(m1), Continuous(m2)) =>
Continuous.combineAlgebraically(op, m1, m2) |> Continuous.T.toPointSetDist Continuous.combineAlgebraically(op, m1, m2) |> Continuous.T.toPointSetDist
@ -197,7 +197,7 @@ let sampleNRendered = (n, dist) => {
doN(n, () => sample(distWithUpdatedIntegralCache)) doN(n, () => sample(distWithUpdatedIntegralCache))
} }
let operate = (distToFloatOp: ASTTypes.distToFloatOperation, s): float => let operate = (distToFloatOp: Operation.distToFloatOperation, s): float =>
switch distToFloatOp { switch distToFloatOp {
| #Pdf(f) => pdf(f, s) | #Pdf(f) => pdf(f, s)
| #Cdf(f) => pdf(f, s) | #Cdf(f) => pdf(f, s)

View File

@ -272,7 +272,7 @@ module T = {
| #Float(n) => Float.mean(n) | #Float(n) => Float.mean(n)
} }
let operate = (distToFloatOp: ASTTypes.distToFloatOperation, s) => let operate = (distToFloatOp: Operation.distToFloatOperation, s) =>
switch distToFloatOp { switch distToFloatOp {
| #Cdf(f) => Ok(cdf(f, s)) | #Cdf(f) => Ok(cdf(f, s))
| #Pdf(f) => Ok(pdf(f, s)) | #Pdf(f) => Ok(pdf(f, s))
@ -302,7 +302,7 @@ module T = {
let tryAnalyticalSimplification = ( let tryAnalyticalSimplification = (
d1: symbolicDist, d1: symbolicDist,
d2: symbolicDist, d2: symbolicDist,
op: ASTTypes.algebraicOperation, op: Operation.algebraicOperation,
): analyticalSimplificationResult => ): analyticalSimplificationResult =>
switch (d1, d2) { switch (d1, d2) {
| (#Float(v1), #Float(v2)) => | (#Float(v1), #Float(v2)) =>

View File

@ -1,4 +1,21 @@
open ASTTypes // This file has no dependencies. It's used outside of the interpreter, but the interpreter depends on it.
type algebraicOperation = [
| #Add
| #Multiply
| #Subtract
| #Divide
| #Exponentiate
]
type pointwiseOperation = [#Add | #Multiply | #Exponentiate]
type scaleOperation = [#Multiply | #Exponentiate | #Log]
type distToFloatOperation = [
| #Pdf(float)
| #Cdf(float)
| #Inv(float)
| #Mean
| #Sample
]
module Algebraic = { module Algebraic = {
type t = algebraicOperation type t = algebraicOperation
@ -86,22 +103,10 @@ module Scale = {
} }
} }
module T = { module Truncate = {
let truncateToString = (left: option<float>, right: option<float>, nodeToString) => { let toString = (left: option<float>, right: option<float>, nodeToString) => {
let left = left |> E.O.dimap(Js.Float.toString, () => "-inf") let left = left |> E.O.dimap(Js.Float.toString, () => "-inf")
let right = right |> E.O.dimap(Js.Float.toString, () => "inf") let right = right |> E.O.dimap(Js.Float.toString, () => "inf")
j`truncate($nodeToString, $left, $right)` j`truncate($nodeToString, $left, $right)`
} }
let toString = (nodeToString, x) => }
switch x {
| #AlgebraicCombination(op, t1, t2) => Algebraic.format(op, nodeToString(t1), nodeToString(t2))
| #PointwiseCombination(op, t1, t2) => Pointwise.format(op, nodeToString(t1), nodeToString(t2))
| #VerticalScaling(scaleOp, t, scaleBy) =>
Scale.format(scaleOp, nodeToString(t), nodeToString(scaleBy))
| #Normalize(t) => "normalize(k" ++ (nodeToString(t) ++ ")")
| #FloatFromDist(floatFromDistOp, t) => DistToFloat.format(floatFromDistOp, nodeToString(t))
| #Truncate(lc, rc, t) => truncateToString(lc, rc, nodeToString(t))
| #Render(t) => nodeToString(t)
| _ => ""
} // SymbolicDist and RenderedDist are handled in AST.toString.
}