Refactored Operation.res to live out of the interpreter
This commit is contained in:
parent
45017f3145
commit
24a7d0eedf
|
@ -1,6 +1,6 @@
|
|||
open ASTTypes.AST
|
||||
|
||||
let toString = ASTBasic.toString
|
||||
let toString = ASTTypes.Node.toString
|
||||
|
||||
let envs = (samplingInputs, environment) => {
|
||||
samplingInputs: samplingInputs,
|
||||
|
|
|
@ -1,27 +0,0 @@
|
|||
open ASTTypes.AST
|
||||
// This file exists to manage a dependency cycle. It would be good to refactor later.
|
||||
|
||||
let rec toString: node => string = x =>
|
||||
switch x {
|
||||
| #SymbolicDist(d) => SymbolicDist.T.toString(d)
|
||||
| #RenderedDist(_) => "[renderedShape]"
|
||||
| #AlgebraicCombination(op, t1, t2) => Operation.Algebraic.format(op, toString(t1), toString(t2))
|
||||
| #PointwiseCombination(op, t1, t2) => Operation.Pointwise.format(op, toString(t1), toString(t2))
|
||||
| #Normalize(t) => "normalize(k" ++ (toString(t) ++ ")")
|
||||
| #Truncate(lc, rc, t) => Operation.T.truncateToString(lc, rc, toString(t))
|
||||
| #Render(t) => toString(t)
|
||||
| #Symbol(t) => "Symbol: " ++ t
|
||||
| #FunctionCall(name, args) =>
|
||||
"[Function call: (" ++
|
||||
(name ++
|
||||
((args |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ ")]"))
|
||||
| #Function(args, internal) =>
|
||||
"[Function: (" ++ ((args |> Js.String.concatMany(_, ",")) ++ (toString(internal) ++ ")]"))
|
||||
| #Array(a) => "[" ++ ((a |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ "]")
|
||||
| #Hash(h) =>
|
||||
"{" ++
|
||||
((h
|
||||
|> E.A.fmap(((name, value)) => name ++ (":" ++ toString(value)))
|
||||
|> Js.String.concatMany(_, ",")) ++
|
||||
"}")
|
||||
}
|
|
@ -56,7 +56,7 @@ module AlgebraicCombination = {
|
|||
|
||||
let operationToLeaf = (
|
||||
evaluationParams: evaluationParams,
|
||||
algebraicOp: ASTTypes.algebraicOperation,
|
||||
algebraicOp: Operation.algebraicOperation,
|
||||
t1: t,
|
||||
t2: t,
|
||||
): result<node, string> =>
|
||||
|
@ -106,7 +106,7 @@ module PointwiseCombination = {
|
|||
|
||||
let operationToLeaf = (
|
||||
evaluationParams: evaluationParams,
|
||||
pointwiseOp: pointwiseOperation,
|
||||
pointwiseOp: Operation.pointwiseOperation,
|
||||
t1: t,
|
||||
t2: t,
|
||||
) =>
|
||||
|
|
|
@ -1,20 +1,3 @@
|
|||
type algebraicOperation = [
|
||||
| #Add
|
||||
| #Multiply
|
||||
| #Subtract
|
||||
| #Divide
|
||||
| #Exponentiate
|
||||
]
|
||||
type pointwiseOperation = [#Add | #Multiply | #Exponentiate]
|
||||
type scaleOperation = [#Multiply | #Exponentiate | #Log]
|
||||
type distToFloatOperation = [
|
||||
| #Pdf(float)
|
||||
| #Cdf(float)
|
||||
| #Inv(float)
|
||||
| #Mean
|
||||
| #Sample
|
||||
]
|
||||
|
||||
module AST = {
|
||||
type rec hash = array<(string, node)>
|
||||
and node = [
|
||||
|
@ -24,8 +7,8 @@ module AST = {
|
|||
| #Hash(hash)
|
||||
| #Array(array<node>)
|
||||
| #Function(array<string>, node)
|
||||
| #AlgebraicCombination(algebraicOperation, node, node)
|
||||
| #PointwiseCombination(pointwiseOperation, node, node)
|
||||
| #AlgebraicCombination(Operation.algebraicOperation, node, node)
|
||||
| #PointwiseCombination(Operation.pointwiseOperation, node, node)
|
||||
| #Normalize(node)
|
||||
| #Render(node)
|
||||
| #Truncate(option<float>, option<float>, node)
|
||||
|
@ -172,3 +155,32 @@ module Program = {
|
|||
]
|
||||
type program = array<statement>
|
||||
}
|
||||
|
||||
module Node = {
|
||||
let rec toString: AST.node => string = x =>
|
||||
switch x {
|
||||
| #SymbolicDist(d) => SymbolicDist.T.toString(d)
|
||||
| #RenderedDist(_) => "[renderedShape]"
|
||||
| #AlgebraicCombination(op, t1, t2) =>
|
||||
Operation.Algebraic.format(op, toString(t1), toString(t2))
|
||||
| #PointwiseCombination(op, t1, t2) =>
|
||||
Operation.Pointwise.format(op, toString(t1), toString(t2))
|
||||
| #Normalize(t) => "normalize(k" ++ (toString(t) ++ ")")
|
||||
| #Truncate(lc, rc, t) => Operation.Truncate.toString(lc, rc, toString(t))
|
||||
| #Render(t) => toString(t)
|
||||
| #Symbol(t) => "Symbol: " ++ t
|
||||
| #FunctionCall(name, args) =>
|
||||
"[Function call: (" ++
|
||||
(name ++
|
||||
((args |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ ")]"))
|
||||
| #Function(args, internal) =>
|
||||
"[Function: (" ++ ((args |> Js.String.concatMany(_, ",")) ++ (toString(internal) ++ ")]"))
|
||||
| #Array(a) => "[" ++ ((a |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ "]")
|
||||
| #Hash(h) =>
|
||||
"{" ++
|
||||
((h
|
||||
|> E.A.fmap(((name, value)) => name ++ (":" ++ toString(value)))
|
||||
|> Js.String.concatMany(_, ",")) ++
|
||||
"}")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -84,7 +84,7 @@ let makeDist = (name, fn) =>
|
|||
)
|
||||
|
||||
let floatFromDist = (
|
||||
distToFloatOp: ASTTypes.distToFloatOperation,
|
||||
distToFloatOp: Operation.distToFloatOperation,
|
||||
t: TypeSystem.samplingDist,
|
||||
): result<node, string> =>
|
||||
switch t {
|
||||
|
|
|
@ -61,7 +61,7 @@ module TypedValue = {
|
|||
|> E.A.fmap(((name, t)) => fromNode(t) |> E.R.fmap(r => (name, r)))
|
||||
|> E.A.R.firstErrorOrOpen
|
||||
|> E.R.fmap(r => #Hash(r))
|
||||
| e => Error("Wrong type: " ++ ASTBasic.toString(e))
|
||||
| e => Error("Wrong type: " ++ ASTTypes.Node.toString(e))
|
||||
}
|
||||
|
||||
// todo: Arrays and hashes
|
||||
|
|
|
@ -96,7 +96,7 @@ let toDiscretePointMassesFromTriangulars = (
|
|||
}
|
||||
|
||||
let combineShapesContinuousContinuous = (
|
||||
op: ASTTypes.algebraicOperation,
|
||||
op: Operation.algebraicOperation,
|
||||
s1: PointSetTypes.xyShape,
|
||||
s2: PointSetTypes.xyShape,
|
||||
): PointSetTypes.xyShape => {
|
||||
|
@ -200,7 +200,7 @@ let toDiscretePointMassesFromDiscrete = (s: PointSetTypes.xyShape): pointMassesW
|
|||
}
|
||||
|
||||
let combineShapesContinuousDiscrete = (
|
||||
op: ASTTypes.algebraicOperation,
|
||||
op: Operation.algebraicOperation,
|
||||
continuousShape: PointSetTypes.xyShape,
|
||||
discreteShape: PointSetTypes.xyShape,
|
||||
): PointSetTypes.xyShape => {
|
||||
|
|
|
@ -211,7 +211,7 @@ module T = Dist({
|
|||
/* This simply creates multiple copies of the continuous distribution, scaled and shifted according to
|
||||
each discrete data point, and then adds them all together. */
|
||||
let combineAlgebraicallyWithDiscrete = (
|
||||
op: ASTTypes.algebraicOperation,
|
||||
op: Operation.algebraicOperation,
|
||||
t1: t,
|
||||
t2: PointSetTypes.discreteShape,
|
||||
) => {
|
||||
|
@ -244,7 +244,7 @@ let combineAlgebraicallyWithDiscrete = (
|
|||
}
|
||||
}
|
||||
|
||||
let combineAlgebraically = (op: ASTTypes.algebraicOperation, t1: t, t2: t) => {
|
||||
let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t) => {
|
||||
let s1 = t1 |> getShape
|
||||
let s2 = t2 |> getShape
|
||||
let t1n = s1 |> XYShape.T.length
|
||||
|
|
|
@ -85,7 +85,7 @@ let updateIntegralCache = (integralCache, t: t): t => {
|
|||
|
||||
/* This multiples all of the data points together and creates a new discrete distribution from the results.
|
||||
Data points at the same xs get added together. It may be a good idea to downsample t1 and t2 before and/or the result after. */
|
||||
let combineAlgebraically = (op: ASTTypes.algebraicOperation, t1: t, t2: t): t => {
|
||||
let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t): t => {
|
||||
let t1s = t1 |> getShape
|
||||
let t2s = t2 |> getShape
|
||||
let t1n = t1s |> XYShape.T.length
|
||||
|
|
|
@ -227,7 +227,7 @@ module T = Dist({
|
|||
}
|
||||
})
|
||||
|
||||
let combineAlgebraically = (op: ASTTypes.algebraicOperation, t1: t, t2: t): t => {
|
||||
let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t): t => {
|
||||
// Discrete convolution can cause a huge increase in the number of samples,
|
||||
// so we'll first downsample.
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ let toMixed = mapToAll((
|
|||
),
|
||||
))
|
||||
|
||||
let combineAlgebraically = (op: ASTTypes.algebraicOperation, t1: t, t2: t): t =>
|
||||
let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t): t =>
|
||||
switch (t1, t2) {
|
||||
| (Continuous(m1), Continuous(m2)) =>
|
||||
Continuous.combineAlgebraically(op, m1, m2) |> Continuous.T.toPointSetDist
|
||||
|
@ -197,7 +197,7 @@ let sampleNRendered = (n, dist) => {
|
|||
doN(n, () => sample(distWithUpdatedIntegralCache))
|
||||
}
|
||||
|
||||
let operate = (distToFloatOp: ASTTypes.distToFloatOperation, s): float =>
|
||||
let operate = (distToFloatOp: Operation.distToFloatOperation, s): float =>
|
||||
switch distToFloatOp {
|
||||
| #Pdf(f) => pdf(f, s)
|
||||
| #Cdf(f) => pdf(f, s)
|
||||
|
|
|
@ -272,7 +272,7 @@ module T = {
|
|||
| #Float(n) => Float.mean(n)
|
||||
}
|
||||
|
||||
let operate = (distToFloatOp: ASTTypes.distToFloatOperation, s) =>
|
||||
let operate = (distToFloatOp: Operation.distToFloatOperation, s) =>
|
||||
switch distToFloatOp {
|
||||
| #Cdf(f) => Ok(cdf(f, s))
|
||||
| #Pdf(f) => Ok(pdf(f, s))
|
||||
|
@ -302,7 +302,7 @@ module T = {
|
|||
let tryAnalyticalSimplification = (
|
||||
d1: symbolicDist,
|
||||
d2: symbolicDist,
|
||||
op: ASTTypes.algebraicOperation,
|
||||
op: Operation.algebraicOperation,
|
||||
): analyticalSimplificationResult =>
|
||||
switch (d1, d2) {
|
||||
| (#Float(v1), #Float(v2)) =>
|
||||
|
|
|
@ -1,4 +1,21 @@
|
|||
open ASTTypes
|
||||
// This file has no dependencies. It's used outside of the interpreter, but the interpreter depends on it.
|
||||
|
||||
type algebraicOperation = [
|
||||
| #Add
|
||||
| #Multiply
|
||||
| #Subtract
|
||||
| #Divide
|
||||
| #Exponentiate
|
||||
]
|
||||
type pointwiseOperation = [#Add | #Multiply | #Exponentiate]
|
||||
type scaleOperation = [#Multiply | #Exponentiate | #Log]
|
||||
type distToFloatOperation = [
|
||||
| #Pdf(float)
|
||||
| #Cdf(float)
|
||||
| #Inv(float)
|
||||
| #Mean
|
||||
| #Sample
|
||||
]
|
||||
|
||||
module Algebraic = {
|
||||
type t = algebraicOperation
|
||||
|
@ -86,22 +103,10 @@ module Scale = {
|
|||
}
|
||||
}
|
||||
|
||||
module T = {
|
||||
let truncateToString = (left: option<float>, right: option<float>, nodeToString) => {
|
||||
module Truncate = {
|
||||
let toString = (left: option<float>, right: option<float>, nodeToString) => {
|
||||
let left = left |> E.O.dimap(Js.Float.toString, () => "-inf")
|
||||
let right = right |> E.O.dimap(Js.Float.toString, () => "inf")
|
||||
j`truncate($nodeToString, $left, $right)`
|
||||
}
|
||||
let toString = (nodeToString, x) =>
|
||||
switch x {
|
||||
| #AlgebraicCombination(op, t1, t2) => Algebraic.format(op, nodeToString(t1), nodeToString(t2))
|
||||
| #PointwiseCombination(op, t1, t2) => Pointwise.format(op, nodeToString(t1), nodeToString(t2))
|
||||
| #VerticalScaling(scaleOp, t, scaleBy) =>
|
||||
Scale.format(scaleOp, nodeToString(t), nodeToString(scaleBy))
|
||||
| #Normalize(t) => "normalize(k" ++ (nodeToString(t) ++ ")")
|
||||
| #FloatFromDist(floatFromDistOp, t) => DistToFloat.format(floatFromDistOp, nodeToString(t))
|
||||
| #Truncate(lc, rc, t) => truncateToString(lc, rc, nodeToString(t))
|
||||
| #Render(t) => nodeToString(t)
|
||||
| _ => ""
|
||||
} // SymbolicDist and RenderedDist are handled in AST.toString.
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user