Refactored Operation.res to live out of the interpreter
This commit is contained in:
parent
45017f3145
commit
24a7d0eedf
|
@ -1,6 +1,6 @@
|
||||||
open ASTTypes.AST
|
open ASTTypes.AST
|
||||||
|
|
||||||
let toString = ASTBasic.toString
|
let toString = ASTTypes.Node.toString
|
||||||
|
|
||||||
let envs = (samplingInputs, environment) => {
|
let envs = (samplingInputs, environment) => {
|
||||||
samplingInputs: samplingInputs,
|
samplingInputs: samplingInputs,
|
||||||
|
|
|
@ -1,27 +0,0 @@
|
||||||
open ASTTypes.AST
|
|
||||||
// This file exists to manage a dependency cycle. It would be good to refactor later.
|
|
||||||
|
|
||||||
let rec toString: node => string = x =>
|
|
||||||
switch x {
|
|
||||||
| #SymbolicDist(d) => SymbolicDist.T.toString(d)
|
|
||||||
| #RenderedDist(_) => "[renderedShape]"
|
|
||||||
| #AlgebraicCombination(op, t1, t2) => Operation.Algebraic.format(op, toString(t1), toString(t2))
|
|
||||||
| #PointwiseCombination(op, t1, t2) => Operation.Pointwise.format(op, toString(t1), toString(t2))
|
|
||||||
| #Normalize(t) => "normalize(k" ++ (toString(t) ++ ")")
|
|
||||||
| #Truncate(lc, rc, t) => Operation.T.truncateToString(lc, rc, toString(t))
|
|
||||||
| #Render(t) => toString(t)
|
|
||||||
| #Symbol(t) => "Symbol: " ++ t
|
|
||||||
| #FunctionCall(name, args) =>
|
|
||||||
"[Function call: (" ++
|
|
||||||
(name ++
|
|
||||||
((args |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ ")]"))
|
|
||||||
| #Function(args, internal) =>
|
|
||||||
"[Function: (" ++ ((args |> Js.String.concatMany(_, ",")) ++ (toString(internal) ++ ")]"))
|
|
||||||
| #Array(a) => "[" ++ ((a |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ "]")
|
|
||||||
| #Hash(h) =>
|
|
||||||
"{" ++
|
|
||||||
((h
|
|
||||||
|> E.A.fmap(((name, value)) => name ++ (":" ++ toString(value)))
|
|
||||||
|> Js.String.concatMany(_, ",")) ++
|
|
||||||
"}")
|
|
||||||
}
|
|
|
@ -56,7 +56,7 @@ module AlgebraicCombination = {
|
||||||
|
|
||||||
let operationToLeaf = (
|
let operationToLeaf = (
|
||||||
evaluationParams: evaluationParams,
|
evaluationParams: evaluationParams,
|
||||||
algebraicOp: ASTTypes.algebraicOperation,
|
algebraicOp: Operation.algebraicOperation,
|
||||||
t1: t,
|
t1: t,
|
||||||
t2: t,
|
t2: t,
|
||||||
): result<node, string> =>
|
): result<node, string> =>
|
||||||
|
@ -106,7 +106,7 @@ module PointwiseCombination = {
|
||||||
|
|
||||||
let operationToLeaf = (
|
let operationToLeaf = (
|
||||||
evaluationParams: evaluationParams,
|
evaluationParams: evaluationParams,
|
||||||
pointwiseOp: pointwiseOperation,
|
pointwiseOp: Operation.pointwiseOperation,
|
||||||
t1: t,
|
t1: t,
|
||||||
t2: t,
|
t2: t,
|
||||||
) =>
|
) =>
|
||||||
|
|
|
@ -1,20 +1,3 @@
|
||||||
type algebraicOperation = [
|
|
||||||
| #Add
|
|
||||||
| #Multiply
|
|
||||||
| #Subtract
|
|
||||||
| #Divide
|
|
||||||
| #Exponentiate
|
|
||||||
]
|
|
||||||
type pointwiseOperation = [#Add | #Multiply | #Exponentiate]
|
|
||||||
type scaleOperation = [#Multiply | #Exponentiate | #Log]
|
|
||||||
type distToFloatOperation = [
|
|
||||||
| #Pdf(float)
|
|
||||||
| #Cdf(float)
|
|
||||||
| #Inv(float)
|
|
||||||
| #Mean
|
|
||||||
| #Sample
|
|
||||||
]
|
|
||||||
|
|
||||||
module AST = {
|
module AST = {
|
||||||
type rec hash = array<(string, node)>
|
type rec hash = array<(string, node)>
|
||||||
and node = [
|
and node = [
|
||||||
|
@ -24,8 +7,8 @@ module AST = {
|
||||||
| #Hash(hash)
|
| #Hash(hash)
|
||||||
| #Array(array<node>)
|
| #Array(array<node>)
|
||||||
| #Function(array<string>, node)
|
| #Function(array<string>, node)
|
||||||
| #AlgebraicCombination(algebraicOperation, node, node)
|
| #AlgebraicCombination(Operation.algebraicOperation, node, node)
|
||||||
| #PointwiseCombination(pointwiseOperation, node, node)
|
| #PointwiseCombination(Operation.pointwiseOperation, node, node)
|
||||||
| #Normalize(node)
|
| #Normalize(node)
|
||||||
| #Render(node)
|
| #Render(node)
|
||||||
| #Truncate(option<float>, option<float>, node)
|
| #Truncate(option<float>, option<float>, node)
|
||||||
|
@ -172,3 +155,32 @@ module Program = {
|
||||||
]
|
]
|
||||||
type program = array<statement>
|
type program = array<statement>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
module Node = {
|
||||||
|
let rec toString: AST.node => string = x =>
|
||||||
|
switch x {
|
||||||
|
| #SymbolicDist(d) => SymbolicDist.T.toString(d)
|
||||||
|
| #RenderedDist(_) => "[renderedShape]"
|
||||||
|
| #AlgebraicCombination(op, t1, t2) =>
|
||||||
|
Operation.Algebraic.format(op, toString(t1), toString(t2))
|
||||||
|
| #PointwiseCombination(op, t1, t2) =>
|
||||||
|
Operation.Pointwise.format(op, toString(t1), toString(t2))
|
||||||
|
| #Normalize(t) => "normalize(k" ++ (toString(t) ++ ")")
|
||||||
|
| #Truncate(lc, rc, t) => Operation.Truncate.toString(lc, rc, toString(t))
|
||||||
|
| #Render(t) => toString(t)
|
||||||
|
| #Symbol(t) => "Symbol: " ++ t
|
||||||
|
| #FunctionCall(name, args) =>
|
||||||
|
"[Function call: (" ++
|
||||||
|
(name ++
|
||||||
|
((args |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ ")]"))
|
||||||
|
| #Function(args, internal) =>
|
||||||
|
"[Function: (" ++ ((args |> Js.String.concatMany(_, ",")) ++ (toString(internal) ++ ")]"))
|
||||||
|
| #Array(a) => "[" ++ ((a |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ "]")
|
||||||
|
| #Hash(h) =>
|
||||||
|
"{" ++
|
||||||
|
((h
|
||||||
|
|> E.A.fmap(((name, value)) => name ++ (":" ++ toString(value)))
|
||||||
|
|> Js.String.concatMany(_, ",")) ++
|
||||||
|
"}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -84,7 +84,7 @@ let makeDist = (name, fn) =>
|
||||||
)
|
)
|
||||||
|
|
||||||
let floatFromDist = (
|
let floatFromDist = (
|
||||||
distToFloatOp: ASTTypes.distToFloatOperation,
|
distToFloatOp: Operation.distToFloatOperation,
|
||||||
t: TypeSystem.samplingDist,
|
t: TypeSystem.samplingDist,
|
||||||
): result<node, string> =>
|
): result<node, string> =>
|
||||||
switch t {
|
switch t {
|
||||||
|
|
|
@ -61,7 +61,7 @@ module TypedValue = {
|
||||||
|> E.A.fmap(((name, t)) => fromNode(t) |> E.R.fmap(r => (name, r)))
|
|> E.A.fmap(((name, t)) => fromNode(t) |> E.R.fmap(r => (name, r)))
|
||||||
|> E.A.R.firstErrorOrOpen
|
|> E.A.R.firstErrorOrOpen
|
||||||
|> E.R.fmap(r => #Hash(r))
|
|> E.R.fmap(r => #Hash(r))
|
||||||
| e => Error("Wrong type: " ++ ASTBasic.toString(e))
|
| e => Error("Wrong type: " ++ ASTTypes.Node.toString(e))
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo: Arrays and hashes
|
// todo: Arrays and hashes
|
||||||
|
|
|
@ -96,7 +96,7 @@ let toDiscretePointMassesFromTriangulars = (
|
||||||
}
|
}
|
||||||
|
|
||||||
let combineShapesContinuousContinuous = (
|
let combineShapesContinuousContinuous = (
|
||||||
op: ASTTypes.algebraicOperation,
|
op: Operation.algebraicOperation,
|
||||||
s1: PointSetTypes.xyShape,
|
s1: PointSetTypes.xyShape,
|
||||||
s2: PointSetTypes.xyShape,
|
s2: PointSetTypes.xyShape,
|
||||||
): PointSetTypes.xyShape => {
|
): PointSetTypes.xyShape => {
|
||||||
|
@ -200,7 +200,7 @@ let toDiscretePointMassesFromDiscrete = (s: PointSetTypes.xyShape): pointMassesW
|
||||||
}
|
}
|
||||||
|
|
||||||
let combineShapesContinuousDiscrete = (
|
let combineShapesContinuousDiscrete = (
|
||||||
op: ASTTypes.algebraicOperation,
|
op: Operation.algebraicOperation,
|
||||||
continuousShape: PointSetTypes.xyShape,
|
continuousShape: PointSetTypes.xyShape,
|
||||||
discreteShape: PointSetTypes.xyShape,
|
discreteShape: PointSetTypes.xyShape,
|
||||||
): PointSetTypes.xyShape => {
|
): PointSetTypes.xyShape => {
|
||||||
|
|
|
@ -211,7 +211,7 @@ module T = Dist({
|
||||||
/* This simply creates multiple copies of the continuous distribution, scaled and shifted according to
|
/* This simply creates multiple copies of the continuous distribution, scaled and shifted according to
|
||||||
each discrete data point, and then adds them all together. */
|
each discrete data point, and then adds them all together. */
|
||||||
let combineAlgebraicallyWithDiscrete = (
|
let combineAlgebraicallyWithDiscrete = (
|
||||||
op: ASTTypes.algebraicOperation,
|
op: Operation.algebraicOperation,
|
||||||
t1: t,
|
t1: t,
|
||||||
t2: PointSetTypes.discreteShape,
|
t2: PointSetTypes.discreteShape,
|
||||||
) => {
|
) => {
|
||||||
|
@ -244,7 +244,7 @@ let combineAlgebraicallyWithDiscrete = (
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let combineAlgebraically = (op: ASTTypes.algebraicOperation, t1: t, t2: t) => {
|
let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t) => {
|
||||||
let s1 = t1 |> getShape
|
let s1 = t1 |> getShape
|
||||||
let s2 = t2 |> getShape
|
let s2 = t2 |> getShape
|
||||||
let t1n = s1 |> XYShape.T.length
|
let t1n = s1 |> XYShape.T.length
|
||||||
|
|
|
@ -85,7 +85,7 @@ let updateIntegralCache = (integralCache, t: t): t => {
|
||||||
|
|
||||||
/* This multiples all of the data points together and creates a new discrete distribution from the results.
|
/* This multiples all of the data points together and creates a new discrete distribution from the results.
|
||||||
Data points at the same xs get added together. It may be a good idea to downsample t1 and t2 before and/or the result after. */
|
Data points at the same xs get added together. It may be a good idea to downsample t1 and t2 before and/or the result after. */
|
||||||
let combineAlgebraically = (op: ASTTypes.algebraicOperation, t1: t, t2: t): t => {
|
let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t): t => {
|
||||||
let t1s = t1 |> getShape
|
let t1s = t1 |> getShape
|
||||||
let t2s = t2 |> getShape
|
let t2s = t2 |> getShape
|
||||||
let t1n = t1s |> XYShape.T.length
|
let t1n = t1s |> XYShape.T.length
|
||||||
|
|
|
@ -227,7 +227,7 @@ module T = Dist({
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
let combineAlgebraically = (op: ASTTypes.algebraicOperation, t1: t, t2: t): t => {
|
let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t): t => {
|
||||||
// Discrete convolution can cause a huge increase in the number of samples,
|
// Discrete convolution can cause a huge increase in the number of samples,
|
||||||
// so we'll first downsample.
|
// so we'll first downsample.
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,7 @@ let toMixed = mapToAll((
|
||||||
),
|
),
|
||||||
))
|
))
|
||||||
|
|
||||||
let combineAlgebraically = (op: ASTTypes.algebraicOperation, t1: t, t2: t): t =>
|
let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t): t =>
|
||||||
switch (t1, t2) {
|
switch (t1, t2) {
|
||||||
| (Continuous(m1), Continuous(m2)) =>
|
| (Continuous(m1), Continuous(m2)) =>
|
||||||
Continuous.combineAlgebraically(op, m1, m2) |> Continuous.T.toPointSetDist
|
Continuous.combineAlgebraically(op, m1, m2) |> Continuous.T.toPointSetDist
|
||||||
|
@ -197,7 +197,7 @@ let sampleNRendered = (n, dist) => {
|
||||||
doN(n, () => sample(distWithUpdatedIntegralCache))
|
doN(n, () => sample(distWithUpdatedIntegralCache))
|
||||||
}
|
}
|
||||||
|
|
||||||
let operate = (distToFloatOp: ASTTypes.distToFloatOperation, s): float =>
|
let operate = (distToFloatOp: Operation.distToFloatOperation, s): float =>
|
||||||
switch distToFloatOp {
|
switch distToFloatOp {
|
||||||
| #Pdf(f) => pdf(f, s)
|
| #Pdf(f) => pdf(f, s)
|
||||||
| #Cdf(f) => pdf(f, s)
|
| #Cdf(f) => pdf(f, s)
|
||||||
|
|
|
@ -272,7 +272,7 @@ module T = {
|
||||||
| #Float(n) => Float.mean(n)
|
| #Float(n) => Float.mean(n)
|
||||||
}
|
}
|
||||||
|
|
||||||
let operate = (distToFloatOp: ASTTypes.distToFloatOperation, s) =>
|
let operate = (distToFloatOp: Operation.distToFloatOperation, s) =>
|
||||||
switch distToFloatOp {
|
switch distToFloatOp {
|
||||||
| #Cdf(f) => Ok(cdf(f, s))
|
| #Cdf(f) => Ok(cdf(f, s))
|
||||||
| #Pdf(f) => Ok(pdf(f, s))
|
| #Pdf(f) => Ok(pdf(f, s))
|
||||||
|
@ -302,7 +302,7 @@ module T = {
|
||||||
let tryAnalyticalSimplification = (
|
let tryAnalyticalSimplification = (
|
||||||
d1: symbolicDist,
|
d1: symbolicDist,
|
||||||
d2: symbolicDist,
|
d2: symbolicDist,
|
||||||
op: ASTTypes.algebraicOperation,
|
op: Operation.algebraicOperation,
|
||||||
): analyticalSimplificationResult =>
|
): analyticalSimplificationResult =>
|
||||||
switch (d1, d2) {
|
switch (d1, d2) {
|
||||||
| (#Float(v1), #Float(v2)) =>
|
| (#Float(v1), #Float(v2)) =>
|
||||||
|
|
|
@ -1,4 +1,21 @@
|
||||||
open ASTTypes
|
// This file has no dependencies. It's used outside of the interpreter, but the interpreter depends on it.
|
||||||
|
|
||||||
|
type algebraicOperation = [
|
||||||
|
| #Add
|
||||||
|
| #Multiply
|
||||||
|
| #Subtract
|
||||||
|
| #Divide
|
||||||
|
| #Exponentiate
|
||||||
|
]
|
||||||
|
type pointwiseOperation = [#Add | #Multiply | #Exponentiate]
|
||||||
|
type scaleOperation = [#Multiply | #Exponentiate | #Log]
|
||||||
|
type distToFloatOperation = [
|
||||||
|
| #Pdf(float)
|
||||||
|
| #Cdf(float)
|
||||||
|
| #Inv(float)
|
||||||
|
| #Mean
|
||||||
|
| #Sample
|
||||||
|
]
|
||||||
|
|
||||||
module Algebraic = {
|
module Algebraic = {
|
||||||
type t = algebraicOperation
|
type t = algebraicOperation
|
||||||
|
@ -86,22 +103,10 @@ module Scale = {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
module T = {
|
module Truncate = {
|
||||||
let truncateToString = (left: option<float>, right: option<float>, nodeToString) => {
|
let toString = (left: option<float>, right: option<float>, nodeToString) => {
|
||||||
let left = left |> E.O.dimap(Js.Float.toString, () => "-inf")
|
let left = left |> E.O.dimap(Js.Float.toString, () => "-inf")
|
||||||
let right = right |> E.O.dimap(Js.Float.toString, () => "inf")
|
let right = right |> E.O.dimap(Js.Float.toString, () => "inf")
|
||||||
j`truncate($nodeToString, $left, $right)`
|
j`truncate($nodeToString, $left, $right)`
|
||||||
}
|
}
|
||||||
let toString = (nodeToString, x) =>
|
}
|
||||||
switch x {
|
|
||||||
| #AlgebraicCombination(op, t1, t2) => Algebraic.format(op, nodeToString(t1), nodeToString(t2))
|
|
||||||
| #PointwiseCombination(op, t1, t2) => Pointwise.format(op, nodeToString(t1), nodeToString(t2))
|
|
||||||
| #VerticalScaling(scaleOp, t, scaleBy) =>
|
|
||||||
Scale.format(scaleOp, nodeToString(t), nodeToString(scaleBy))
|
|
||||||
| #Normalize(t) => "normalize(k" ++ (nodeToString(t) ++ ")")
|
|
||||||
| #FloatFromDist(floatFromDistOp, t) => DistToFloat.format(floatFromDistOp, nodeToString(t))
|
|
||||||
| #Truncate(lc, rc, t) => truncateToString(lc, rc, nodeToString(t))
|
|
||||||
| #Render(t) => nodeToString(t)
|
|
||||||
| _ => ""
|
|
||||||
} // SymbolicDist and RenderedDist are handled in AST.toString.
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user