Removed AST module from ASTTypes file
This commit is contained in:
parent
f76eaf6d03
commit
a324f8a7d6
|
@ -10,7 +10,7 @@ let makeTest = (~only=false, str, item1, item2) =>
|
|||
expect(item1) |> toEqual(item2)
|
||||
);
|
||||
|
||||
let evalParams: ASTTypes.AST.evaluationParams = {
|
||||
let evalParams: ASTTypes.evaluationParams = {
|
||||
samplingInputs: {
|
||||
sampleCount: 1000,
|
||||
outputXYPoints: 10000,
|
||||
|
|
|
@ -14,7 +14,7 @@ module Inputs = {
|
|||
type inputs = {
|
||||
squiggleString: string,
|
||||
samplingInputs: SamplingInputs.t,
|
||||
environment: ASTTypes.AST.environment,
|
||||
environment: ASTTypes.environment,
|
||||
}
|
||||
|
||||
let empty: SamplingInputs.t = {
|
||||
|
@ -27,7 +27,7 @@ module Inputs = {
|
|||
let make = (
|
||||
~samplingInputs=empty,
|
||||
~squiggleString,
|
||||
~environment=ASTTypes.AST.Environment.empty,
|
||||
~environment=ASTTypes.Environment.empty,
|
||||
(),
|
||||
): inputs => {
|
||||
samplingInputs: samplingInputs,
|
||||
|
@ -40,8 +40,8 @@ type \"export" = [
|
|||
| #DistPlus(DistPlus.t)
|
||||
| #Float(float)
|
||||
| #Function(
|
||||
(array<string>, ASTTypes.AST.node),
|
||||
ASTTypes.AST.environment,
|
||||
(array<string>, ASTTypes.node),
|
||||
ASTTypes.environment,
|
||||
)
|
||||
]
|
||||
|
||||
|
@ -53,13 +53,13 @@ module Internals = {
|
|||
): Inputs.inputs => {
|
||||
samplingInputs: samplingInputs,
|
||||
squiggleString: squiggleString,
|
||||
environment: ASTTypes.AST.Environment.update(environment, str, _ => Some(
|
||||
environment: ASTTypes.Environment.update(environment, str, _ => Some(
|
||||
node,
|
||||
)),
|
||||
}
|
||||
|
||||
type outputs = {
|
||||
graph: ASTTypes.AST.node,
|
||||
graph: ASTTypes.node,
|
||||
pointSetDist: PointSetTypes.pointSetDist,
|
||||
}
|
||||
let makeOutputs = (graph, pointSetDist): outputs => {graph: graph, pointSetDist: pointSetDist}
|
||||
|
@ -74,7 +74,7 @@ module Internals = {
|
|||
let runNode = (inputs, node) =>
|
||||
AST.toLeaf(makeInputs(inputs), inputs.environment, node)
|
||||
|
||||
let runProgram = (inputs: Inputs.inputs, p: ASTTypes.AST.program) => {
|
||||
let runProgram = (inputs: Inputs.inputs, p: ASTTypes.program) => {
|
||||
let ins = ref(inputs)
|
||||
p
|
||||
|> E.A.fmap(x =>
|
||||
|
@ -97,8 +97,8 @@ module Internals = {
|
|||
DistPlus.make(~pointSetDist, ~squiggleString=Some(inputs.squiggleString), ())
|
||||
}
|
||||
|
||||
let renderIfNeeded = (inputs: Inputs.inputs, node: ASTTypes.AST.node): result<
|
||||
ASTTypes.AST.node,
|
||||
let renderIfNeeded = (inputs: Inputs.inputs, node: ASTTypes.node): result<
|
||||
ASTTypes.node,
|
||||
string,
|
||||
> =>
|
||||
node |> (
|
||||
|
@ -121,11 +121,11 @@ let renderIfNeeded = (inputs: Inputs.inputs, node: ASTTypes.AST.node): result<
|
|||
}
|
||||
)
|
||||
|
||||
// TODO: Consider using ASTTypes.AST.getFloat or similar in this function
|
||||
// TODO: Consider using ASTTypes.getFloat or similar in this function
|
||||
let coersionToExportedTypes = (
|
||||
inputs,
|
||||
env: ASTTypes.AST.environment,
|
||||
node: ASTTypes.AST.node,
|
||||
env: ASTTypes.environment,
|
||||
node: ASTTypes.node,
|
||||
): result<\"export", string> =>
|
||||
node
|
||||
|> renderIfNeeded(inputs)
|
||||
|
@ -160,7 +160,7 @@ let evaluateProgram = (inputs: Inputs.inputs) =>
|
|||
|
||||
let evaluateFunction = (
|
||||
inputs: Inputs.inputs,
|
||||
fn: (array<string>, ASTTypes.AST.node),
|
||||
fn: (array<string>, ASTTypes.node),
|
||||
fnInputs,
|
||||
) => {
|
||||
let output = AST.runFunction(
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
open ASTTypes.AST
|
||||
open ASTTypes
|
||||
|
||||
let toString = ASTTypes.AST.Node.toString
|
||||
let toString = ASTTypes.Node.toString
|
||||
|
||||
let envs = (samplingInputs, environment) => {
|
||||
samplingInputs: samplingInputs,
|
||||
|
@ -18,7 +18,7 @@ let toPointSetDist = (samplingInputs, environment, node: node) =>
|
|||
| Error(e) => Error(e)
|
||||
}
|
||||
|
||||
let runFunction = (samplingInputs, environment, inputs, fn: ASTTypes.AST.Function.t) => {
|
||||
let runFunction = (samplingInputs, environment, inputs, fn: ASTTypes.Function.t) => {
|
||||
let params = envs(samplingInputs, environment)
|
||||
ASTTypes.AST.Function.run(params, inputs, fn)
|
||||
ASTTypes.Function.run(params, inputs, fn)
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
open ASTTypes.AST
|
||||
open ASTTypes
|
||||
|
||||
type t = node
|
||||
type tResult = node => result<node, string>
|
||||
|
@ -43,12 +43,12 @@ module AlgebraicCombination = {
|
|||
|
||||
let combine = (evaluationParams, algebraicOp, t1: node, t2: node): result<node, string> =>
|
||||
E.R.merge(
|
||||
ASTTypes.AST.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams, t1),
|
||||
ASTTypes.AST.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams, t2),
|
||||
ASTTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams, t1),
|
||||
ASTTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams, t2),
|
||||
) |> E.R.bind(_, ((a, b)) =>
|
||||
switch choose(a, b) {
|
||||
| #Sampling =>
|
||||
ASTTypes.AST.SamplingDistribution.combineShapesUsingSampling(
|
||||
ASTTypes.SamplingDistribution.combineShapesUsingSampling(
|
||||
evaluationParams,
|
||||
algebraicOp,
|
||||
a,
|
||||
|
@ -123,7 +123,7 @@ module PointwiseCombination = {
|
|||
|
||||
module Truncate = {
|
||||
type simplificationResult = [
|
||||
| #Solution(ASTTypes.AST.node)
|
||||
| #Solution(ASTTypes.node)
|
||||
| #Error(string)
|
||||
| #NoSolution
|
||||
]
|
||||
|
@ -171,7 +171,7 @@ module Normalize = {
|
|||
switch t {
|
||||
| #RenderedDist(s) => Ok(#RenderedDist(PointSetDist.T.normalize(s)))
|
||||
| #SymbolicDist(_) => Ok(t)
|
||||
| _ => ASTTypes.AST.Node.evaluateAndRetry(evaluationParams, operationToLeaf, t)
|
||||
| _ => ASTTypes.Node.evaluateAndRetry(evaluationParams, operationToLeaf, t)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -181,13 +181,13 @@ module FunctionCall = {
|
|||
|
||||
let _runLocalFunction = (name, evaluationParams: evaluationParams, args) =>
|
||||
Environment.getFunction(evaluationParams.environment, name) |> E.R.bind(_, ((argNames, fn)) =>
|
||||
ASTTypes.AST.Function.run(evaluationParams, args, (argNames, fn))
|
||||
ASTTypes.Function.run(evaluationParams, args, (argNames, fn))
|
||||
)
|
||||
|
||||
let _runWithEvaluatedInputs = (
|
||||
evaluationParams: ASTTypes.AST.evaluationParams,
|
||||
evaluationParams: ASTTypes.evaluationParams,
|
||||
name,
|
||||
args: array<ASTTypes.AST.node>,
|
||||
args: array<ASTTypes.node>,
|
||||
) =>
|
||||
_runHardcodedFunction(name, evaluationParams, args) |> E.O.default(
|
||||
_runLocalFunction(name, evaluationParams, args),
|
||||
|
@ -212,7 +212,7 @@ module Render = {
|
|||
),
|
||||
)
|
||||
| #RenderedDist(_) as t => Ok(t) // already a rendered pointSetDist, we're done here
|
||||
| _ => ASTTypes.AST.Node.evaluateAndRetry(evaluationParams, operationToLeaf, t)
|
||||
| _ => ASTTypes.Node.evaluateAndRetry(evaluationParams, operationToLeaf, t)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -222,7 +222,7 @@ module Render = {
|
|||
but most often it will produce a RenderedDist.
|
||||
This function is used mainly to turn a parse tree into a single RenderedDist
|
||||
that can then be displayed to the user. */
|
||||
let rec toLeaf = (evaluationParams: ASTTypes.AST.evaluationParams, node: t): result<t, string> =>
|
||||
let rec toLeaf = (evaluationParams: ASTTypes.evaluationParams, node: t): result<t, string> =>
|
||||
switch node {
|
||||
// Leaf nodes just stay leaf nodes
|
||||
| #SymbolicDist(_)
|
||||
|
@ -248,7 +248,7 @@ let rec toLeaf = (evaluationParams: ASTTypes.AST.evaluationParams, node: t): res
|
|||
|> E.A.R.firstErrorOrOpen
|
||||
|> E.R.fmap(r => #Hash(r))
|
||||
| #Symbol(r) =>
|
||||
ASTTypes.AST.Environment.get(evaluationParams.environment, r)
|
||||
ASTTypes.Environment.get(evaluationParams.environment, r)
|
||||
|> E.O.toResult("Undeclared variable " ++ r)
|
||||
|> E.R.bind(_, toLeaf(evaluationParams))
|
||||
| #FunctionCall(name, args) =>
|
||||
|
|
|
@ -1,272 +1,257 @@
|
|||
module AST = {
|
||||
type rec hash = array<(string, node)>
|
||||
and node = [
|
||||
| #SymbolicDist(SymbolicDistTypes.symbolicDist)
|
||||
| #RenderedDist(PointSetTypes.pointSetDist)
|
||||
| #Symbol(string)
|
||||
| #Hash(hash)
|
||||
| #Array(array<node>)
|
||||
| #Function(array<string>, node)
|
||||
| #AlgebraicCombination(Operation.algebraicOperation, node, node)
|
||||
| #PointwiseCombination(Operation.pointwiseOperation, node, node)
|
||||
| #Normalize(node)
|
||||
| #Render(node)
|
||||
| #Truncate(option<float>, option<float>, node)
|
||||
| #FunctionCall(string, array<node>)
|
||||
]
|
||||
type rec hash = array<(string, node)>
|
||||
and node = [
|
||||
| #SymbolicDist(SymbolicDistTypes.symbolicDist)
|
||||
| #RenderedDist(PointSetTypes.pointSetDist)
|
||||
| #Symbol(string)
|
||||
| #Hash(hash)
|
||||
| #Array(array<node>)
|
||||
| #Function(array<string>, node)
|
||||
| #AlgebraicCombination(Operation.algebraicOperation, node, node)
|
||||
| #PointwiseCombination(Operation.pointwiseOperation, node, node)
|
||||
| #Normalize(node)
|
||||
| #Render(node)
|
||||
| #Truncate(option<float>, option<float>, node)
|
||||
| #FunctionCall(string, array<node>)
|
||||
]
|
||||
|
||||
type statement = [
|
||||
| #Assignment(string, node)
|
||||
| #Expression(node)
|
||||
]
|
||||
type program = array<statement>
|
||||
type statement = [
|
||||
| #Assignment(string, node)
|
||||
| #Expression(node)
|
||||
]
|
||||
type program = array<statement>
|
||||
|
||||
type environment = Belt.Map.String.t<node>
|
||||
type environment = Belt.Map.String.t<node>
|
||||
|
||||
type rec evaluationParams = {
|
||||
samplingInputs: SamplingInputs.samplingInputs,
|
||||
environment: environment,
|
||||
evaluateNode: (evaluationParams, node) => Belt.Result.t<node, string>,
|
||||
}
|
||||
type rec evaluationParams = {
|
||||
samplingInputs: SamplingInputs.samplingInputs,
|
||||
environment: environment,
|
||||
evaluateNode: (evaluationParams, node) => Belt.Result.t<node, string>,
|
||||
}
|
||||
|
||||
module Environment = {
|
||||
type t = environment
|
||||
module MS = Belt.Map.String
|
||||
let fromArray = MS.fromArray
|
||||
let empty: t = []->fromArray
|
||||
let mergeKeepSecond = (a: t, b: t) =>
|
||||
MS.merge(a, b, (_, a, b) =>
|
||||
switch (a, b) {
|
||||
| (_, Some(b)) => Some(b)
|
||||
| (Some(a), _) => Some(a)
|
||||
module Environment = {
|
||||
type t = environment
|
||||
module MS = Belt.Map.String
|
||||
let fromArray = MS.fromArray
|
||||
let empty: t = []->fromArray
|
||||
let mergeKeepSecond = (a: t, b: t) =>
|
||||
MS.merge(a, b, (_, a, b) =>
|
||||
switch (a, b) {
|
||||
| (_, Some(b)) => Some(b)
|
||||
| (Some(a), _) => Some(a)
|
||||
| _ => None
|
||||
}
|
||||
)
|
||||
let update = (t, str, fn) => MS.update(t, str, fn)
|
||||
let get = (t: t, str) => MS.get(t, str)
|
||||
let getFunction = (t: t, str) =>
|
||||
switch get(t, str) {
|
||||
| Some(#Function(argNames, fn)) => Ok((argNames, fn))
|
||||
| _ => Error("Function " ++ (str ++ " not found"))
|
||||
}
|
||||
}
|
||||
|
||||
module Node = {
|
||||
let getFloat = (node: node) =>
|
||||
node |> (
|
||||
x =>
|
||||
switch x {
|
||||
| #RenderedDist(Discrete({xyShape: {xs: [x], ys: [1.0]}})) => Some(x)
|
||||
| #SymbolicDist(#Float(x)) => Some(x)
|
||||
| _ => None
|
||||
}
|
||||
)
|
||||
let update = (t, str, fn) => MS.update(t, str, fn)
|
||||
let get = (t: t, str) => MS.get(t, str)
|
||||
let getFunction = (t: t, str) =>
|
||||
switch get(t, str) {
|
||||
| Some(#Function(argNames, fn)) => Ok((argNames, fn))
|
||||
| _ => Error("Function " ++ (str ++ " not found"))
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
module Node = {
|
||||
let getFloat = (node: node) =>
|
||||
node |> (
|
||||
x =>
|
||||
switch x {
|
||||
| #RenderedDist(Discrete({xyShape: {xs: [x], ys: [1.0]}})) => Some(x)
|
||||
| #SymbolicDist(#Float(x)) => Some(x)
|
||||
| _ => None
|
||||
}
|
||||
)
|
||||
let evaluate = (evaluationParams: evaluationParams) =>
|
||||
evaluationParams.evaluateNode(evaluationParams)
|
||||
|
||||
let evaluate = (evaluationParams: evaluationParams) =>
|
||||
evaluationParams.evaluateNode(evaluationParams)
|
||||
let evaluateAndRetry = (evaluationParams, fn, node) =>
|
||||
node |> evaluationParams.evaluateNode(evaluationParams) |> E.R.bind(_, fn(evaluationParams))
|
||||
|
||||
let evaluateAndRetry = (evaluationParams, fn, node) =>
|
||||
node |> evaluationParams.evaluateNode(evaluationParams) |> E.R.bind(_, fn(evaluationParams))
|
||||
let rec toString: node => string = x =>
|
||||
switch x {
|
||||
| #SymbolicDist(d) => SymbolicDist.T.toString(d)
|
||||
| #RenderedDist(_) => "[renderedShape]"
|
||||
| #AlgebraicCombination(op, t1, t2) =>
|
||||
Operation.Algebraic.format(op, toString(t1), toString(t2))
|
||||
| #PointwiseCombination(op, t1, t2) =>
|
||||
Operation.Pointwise.format(op, toString(t1), toString(t2))
|
||||
| #Normalize(t) => "normalize(k" ++ (toString(t) ++ ")")
|
||||
| #Truncate(lc, rc, t) => Operation.Truncate.toString(lc, rc, toString(t))
|
||||
| #Render(t) => toString(t)
|
||||
| #Symbol(t) => "Symbol: " ++ t
|
||||
| #FunctionCall(name, args) =>
|
||||
"[Function call: (" ++
|
||||
(name ++
|
||||
((args |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ ")]"))
|
||||
| #Function(args, internal) =>
|
||||
"[Function: (" ++ ((args |> Js.String.concatMany(_, ",")) ++ (toString(internal) ++ ")]"))
|
||||
| #Array(a) => "[" ++ ((a |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ "]")
|
||||
| #Hash(h) =>
|
||||
"{" ++
|
||||
((h
|
||||
|> E.A.fmap(((name, value)) => name ++ (":" ++ toString(value)))
|
||||
|> Js.String.concatMany(_, ",")) ++
|
||||
"}")
|
||||
}
|
||||
|
||||
let rec toString: node => string = x =>
|
||||
switch x {
|
||||
| #SymbolicDist(d) => SymbolicDist.T.toString(d)
|
||||
| #RenderedDist(_) => "[renderedShape]"
|
||||
| #AlgebraicCombination(op, t1, t2) =>
|
||||
Operation.Algebraic.format(op, toString(t1), toString(t2))
|
||||
| #PointwiseCombination(op, t1, t2) =>
|
||||
Operation.Pointwise.format(op, toString(t1), toString(t2))
|
||||
| #Normalize(t) => "normalize(k" ++ (toString(t) ++ ")")
|
||||
| #Truncate(lc, rc, t) => Operation.Truncate.toString(lc, rc, toString(t))
|
||||
| #Render(t) => toString(t)
|
||||
| #Symbol(t) => "Symbol: " ++ t
|
||||
| #FunctionCall(name, args) =>
|
||||
"[Function call: (" ++
|
||||
(name ++
|
||||
((args |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ ")]"))
|
||||
| #Function(args, internal) =>
|
||||
"[Function: (" ++ ((args |> Js.String.concatMany(_, ",")) ++ (toString(internal) ++ ")]"))
|
||||
| #Array(a) => "[" ++ ((a |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ "]")
|
||||
| #Hash(h) =>
|
||||
"{" ++
|
||||
((h
|
||||
|> E.A.fmap(((name, value)) => name ++ (":" ++ toString(value)))
|
||||
|> Js.String.concatMany(_, ",")) ++
|
||||
"}")
|
||||
}
|
||||
let render = (evaluationParams: evaluationParams, r) => #Render(r) |> evaluate(evaluationParams)
|
||||
|
||||
let render = (evaluationParams: evaluationParams, r) => #Render(r) |> evaluate(evaluationParams)
|
||||
|
||||
let ensureIsRendered = (params, t) =>
|
||||
switch t {
|
||||
| #RenderedDist(_) => Ok(t)
|
||||
| _ =>
|
||||
switch render(params, t) {
|
||||
| Ok(#RenderedDist(r)) => Ok(#RenderedDist(r))
|
||||
| Ok(_) => Error("Did not render as requested")
|
||||
| Error(e) => Error(e)
|
||||
}
|
||||
}
|
||||
|
||||
let ensureIsRenderedAndGetShape = (params, t) =>
|
||||
switch ensureIsRendered(params, t) {
|
||||
| Ok(#RenderedDist(r)) => Ok(r)
|
||||
let ensureIsRendered = (params, t) =>
|
||||
switch t {
|
||||
| #RenderedDist(_) => Ok(t)
|
||||
| _ =>
|
||||
switch render(params, t) {
|
||||
| Ok(#RenderedDist(r)) => Ok(#RenderedDist(r))
|
||||
| Ok(_) => Error("Did not render as requested")
|
||||
| Error(e) => Error(e)
|
||||
}
|
||||
|
||||
let toPointSetDist = (item: node) =>
|
||||
switch item {
|
||||
| #RenderedDist(r) => Some(r)
|
||||
| _ => None
|
||||
}
|
||||
|
||||
let _toFloat = (t: PointSetTypes.pointSetDist) =>
|
||||
switch t {
|
||||
| Discrete({xyShape: {xs: [x], ys: [1.0]}}) => Some(#SymbolicDist(#Float(x)))
|
||||
| _ => None
|
||||
}
|
||||
|
||||
let toFloat = (item: node): result<node, string> =>
|
||||
item |> toPointSetDist |> E.O.bind(_, _toFloat) |> E.O.toResult("Not valid shape")
|
||||
}
|
||||
|
||||
module Function = {
|
||||
type t = (array<string>, node)
|
||||
let fromNode: node => option<t> = node =>
|
||||
switch node {
|
||||
| #Function(r) => Some(r)
|
||||
| _ => None
|
||||
}
|
||||
let argumentNames = ((a, _): t) => a
|
||||
let internals = ((_, b): t) => b
|
||||
let run = (evaluationParams: evaluationParams, args: array<node>, t: t) =>
|
||||
if E.A.length(args) == E.A.length(argumentNames(t)) {
|
||||
let newEnvironment = Belt.Array.zip(argumentNames(t), args) |> Environment.fromArray
|
||||
let newEvaluationParams: evaluationParams = {
|
||||
samplingInputs: evaluationParams.samplingInputs,
|
||||
environment: Environment.mergeKeepSecond(
|
||||
evaluationParams.environment,
|
||||
newEnvironment,
|
||||
),
|
||||
evaluateNode: evaluationParams.evaluateNode,
|
||||
}
|
||||
evaluationParams.evaluateNode(newEvaluationParams, internals(t))
|
||||
} else {
|
||||
Error("Wrong number of variables")
|
||||
}
|
||||
}
|
||||
|
||||
module Primative = {
|
||||
type t = [
|
||||
| #SymbolicDist(SymbolicDistTypes.symbolicDist)
|
||||
| #RenderedDist(PointSetTypes.pointSetDist)
|
||||
| #Function(array<string>, node)
|
||||
]
|
||||
|
||||
let isPrimative: node => bool = x =>
|
||||
switch x {
|
||||
| #SymbolicDist(_)
|
||||
| #RenderedDist(_)
|
||||
| #Function(_) => true
|
||||
| _ => false
|
||||
}
|
||||
|
||||
let fromNode: node => option<t> = x =>
|
||||
switch x {
|
||||
| #SymbolicDist(_) as n
|
||||
| #RenderedDist(_) as n
|
||||
| #Function(_) as n =>
|
||||
Some(n)
|
||||
| _ => None
|
||||
}
|
||||
}
|
||||
|
||||
module SamplingDistribution = {
|
||||
type t = [
|
||||
| #SymbolicDist(SymbolicDistTypes.symbolicDist)
|
||||
| #RenderedDist(PointSetTypes.pointSetDist)
|
||||
]
|
||||
|
||||
let isSamplingDistribution: node => bool = x =>
|
||||
switch x {
|
||||
| #SymbolicDist(_) => true
|
||||
| #RenderedDist(_) => true
|
||||
| _ => false
|
||||
}
|
||||
|
||||
let fromNode: node => result<t, string> = x =>
|
||||
switch x {
|
||||
| #SymbolicDist(n) => Ok(#SymbolicDist(n))
|
||||
| #RenderedDist(n) => Ok(#RenderedDist(n))
|
||||
| _ => Error("Not valid type")
|
||||
}
|
||||
|
||||
let renderIfIsNotSamplingDistribution = (params, t): result<node, string> =>
|
||||
!isSamplingDistribution(t)
|
||||
? switch Node.render(params, t) {
|
||||
| Ok(r) => Ok(r)
|
||||
| Error(e) => Error(e)
|
||||
}
|
||||
: Ok(t)
|
||||
|
||||
let map = (~renderedDistFn, ~symbolicDistFn, node: node) =>
|
||||
node |> (
|
||||
x =>
|
||||
switch x {
|
||||
| #RenderedDist(r) => Some(renderedDistFn(r))
|
||||
| #SymbolicDist(s) => Some(symbolicDistFn(s))
|
||||
| _ => None
|
||||
}
|
||||
)
|
||||
|
||||
let sampleN = n =>
|
||||
map(
|
||||
~renderedDistFn=PointSetDist.sampleNRendered(n),
|
||||
~symbolicDistFn=SymbolicDist.T.sampleN(n),
|
||||
)
|
||||
|
||||
let getCombinationSamples = (n, algebraicOp, t1: node, t2: node) =>
|
||||
switch (sampleN(n, t1), sampleN(n, t2)) {
|
||||
| (Some(a), Some(b)) =>
|
||||
Some(
|
||||
Belt.Array.zip(a, b) |> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(algebraicOp, a, b)),
|
||||
)
|
||||
| _ => None
|
||||
}
|
||||
|
||||
let combineShapesUsingSampling = (
|
||||
evaluationParams: evaluationParams,
|
||||
algebraicOp,
|
||||
t1: node,
|
||||
t2: node,
|
||||
) => {
|
||||
let i1 = renderIfIsNotSamplingDistribution(evaluationParams, t1)
|
||||
let i2 = renderIfIsNotSamplingDistribution(evaluationParams, t2)
|
||||
E.R.merge(i1, i2) |> E.R.bind(_, ((a, b)) => {
|
||||
let samples = getCombinationSamples(
|
||||
evaluationParams.samplingInputs.sampleCount,
|
||||
algebraicOp,
|
||||
a,
|
||||
b,
|
||||
)
|
||||
|
||||
let pointSetDist =
|
||||
samples
|
||||
|> E.O.fmap(r =>
|
||||
SampleSet.toPointSetDist(
|
||||
~samplingInputs=evaluationParams.samplingInputs,
|
||||
~samples=r,
|
||||
(),
|
||||
)
|
||||
)
|
||||
|> E.O.bind(_, r => r.pointSetDist)
|
||||
|> E.O.toResult("No response")
|
||||
pointSetDist |> E.R.fmap(r => #Normalize(#RenderedDist(r)))
|
||||
})
|
||||
|
||||
// todo: This bottom part should probably be somewhere else.
|
||||
// todo: REFACTOR: I'm not sure about the SampleSet line.
|
||||
}
|
||||
|
||||
let ensureIsRenderedAndGetShape = (params, t) =>
|
||||
switch ensureIsRendered(params, t) {
|
||||
| Ok(#RenderedDist(r)) => Ok(r)
|
||||
| Ok(_) => Error("Did not render as requested")
|
||||
| Error(e) => Error(e)
|
||||
}
|
||||
|
||||
let toPointSetDist = (item: node) =>
|
||||
switch item {
|
||||
| #RenderedDist(r) => Some(r)
|
||||
| _ => None
|
||||
}
|
||||
|
||||
let _toFloat = (t: PointSetTypes.pointSetDist) =>
|
||||
switch t {
|
||||
| Discrete({xyShape: {xs: [x], ys: [1.0]}}) => Some(#SymbolicDist(#Float(x)))
|
||||
| _ => None
|
||||
}
|
||||
|
||||
let toFloat = (item: node): result<node, string> =>
|
||||
item |> toPointSetDist |> E.O.bind(_, _toFloat) |> E.O.toResult("Not valid shape")
|
||||
}
|
||||
|
||||
module Function = {
|
||||
type t = (array<string>, node)
|
||||
let fromNode: node => option<t> = node =>
|
||||
switch node {
|
||||
| #Function(r) => Some(r)
|
||||
| _ => None
|
||||
}
|
||||
let argumentNames = ((a, _): t) => a
|
||||
let internals = ((_, b): t) => b
|
||||
let run = (evaluationParams: evaluationParams, args: array<node>, t: t) =>
|
||||
if E.A.length(args) == E.A.length(argumentNames(t)) {
|
||||
let newEnvironment = Belt.Array.zip(argumentNames(t), args) |> Environment.fromArray
|
||||
let newEvaluationParams: evaluationParams = {
|
||||
samplingInputs: evaluationParams.samplingInputs,
|
||||
environment: Environment.mergeKeepSecond(evaluationParams.environment, newEnvironment),
|
||||
evaluateNode: evaluationParams.evaluateNode,
|
||||
}
|
||||
evaluationParams.evaluateNode(newEvaluationParams, internals(t))
|
||||
} else {
|
||||
Error("Wrong number of variables")
|
||||
}
|
||||
}
|
||||
|
||||
module Primative = {
|
||||
type t = [
|
||||
| #SymbolicDist(SymbolicDistTypes.symbolicDist)
|
||||
| #RenderedDist(PointSetTypes.pointSetDist)
|
||||
| #Function(array<string>, node)
|
||||
]
|
||||
|
||||
let isPrimative: node => bool = x =>
|
||||
switch x {
|
||||
| #SymbolicDist(_)
|
||||
| #RenderedDist(_)
|
||||
| #Function(_) => true
|
||||
| _ => false
|
||||
}
|
||||
|
||||
let fromNode: node => option<t> = x =>
|
||||
switch x {
|
||||
| #SymbolicDist(_) as n
|
||||
| #RenderedDist(_) as n
|
||||
| #Function(_) as n =>
|
||||
Some(n)
|
||||
| _ => None
|
||||
}
|
||||
}
|
||||
|
||||
module SamplingDistribution = {
|
||||
type t = [
|
||||
| #SymbolicDist(SymbolicDistTypes.symbolicDist)
|
||||
| #RenderedDist(PointSetTypes.pointSetDist)
|
||||
]
|
||||
|
||||
let isSamplingDistribution: node => bool = x =>
|
||||
switch x {
|
||||
| #SymbolicDist(_) => true
|
||||
| #RenderedDist(_) => true
|
||||
| _ => false
|
||||
}
|
||||
|
||||
let fromNode: node => result<t, string> = x =>
|
||||
switch x {
|
||||
| #SymbolicDist(n) => Ok(#SymbolicDist(n))
|
||||
| #RenderedDist(n) => Ok(#RenderedDist(n))
|
||||
| _ => Error("Not valid type")
|
||||
}
|
||||
|
||||
let renderIfIsNotSamplingDistribution = (params, t): result<node, string> =>
|
||||
!isSamplingDistribution(t)
|
||||
? switch Node.render(params, t) {
|
||||
| Ok(r) => Ok(r)
|
||||
| Error(e) => Error(e)
|
||||
}
|
||||
: Ok(t)
|
||||
|
||||
let map = (~renderedDistFn, ~symbolicDistFn, node: node) =>
|
||||
node |> (
|
||||
x =>
|
||||
switch x {
|
||||
| #RenderedDist(r) => Some(renderedDistFn(r))
|
||||
| #SymbolicDist(s) => Some(symbolicDistFn(s))
|
||||
| _ => None
|
||||
}
|
||||
)
|
||||
|
||||
let sampleN = n =>
|
||||
map(~renderedDistFn=PointSetDist.sampleNRendered(n), ~symbolicDistFn=SymbolicDist.T.sampleN(n))
|
||||
|
||||
let getCombinationSamples = (n, algebraicOp, t1: node, t2: node) =>
|
||||
switch (sampleN(n, t1), sampleN(n, t2)) {
|
||||
| (Some(a), Some(b)) =>
|
||||
Some(
|
||||
Belt.Array.zip(a, b) |> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(algebraicOp, a, b)),
|
||||
)
|
||||
| _ => None
|
||||
}
|
||||
|
||||
let combineShapesUsingSampling = (
|
||||
evaluationParams: evaluationParams,
|
||||
algebraicOp,
|
||||
t1: node,
|
||||
t2: node,
|
||||
) => {
|
||||
let i1 = renderIfIsNotSamplingDistribution(evaluationParams, t1)
|
||||
let i2 = renderIfIsNotSamplingDistribution(evaluationParams, t2)
|
||||
E.R.merge(i1, i2) |> E.R.bind(_, ((a, b)) => {
|
||||
let samples = getCombinationSamples(
|
||||
evaluationParams.samplingInputs.sampleCount,
|
||||
algebraicOp,
|
||||
a,
|
||||
b,
|
||||
)
|
||||
|
||||
let pointSetDist =
|
||||
samples
|
||||
|> E.O.fmap(r =>
|
||||
SampleSet.toPointSetDist(~samplingInputs=evaluationParams.samplingInputs, ~samples=r, ())
|
||||
)
|
||||
|> E.O.bind(_, r => r.pointSetDist)
|
||||
|> E.O.toResult("No response")
|
||||
pointSetDist |> E.R.fmap(r => #Normalize(#RenderedDist(r)))
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,5 +1,5 @@
|
|||
type node = ASTTypes.AST.node
|
||||
let getFloat = ASTTypes.AST.Node.getFloat
|
||||
type node = ASTTypes.node
|
||||
let getFloat = ASTTypes.Node.getFloat
|
||||
|
||||
type samplingDist = [
|
||||
| #SymbolicDist(SymbolicDistTypes.symbolicDist)
|
||||
|
@ -61,7 +61,7 @@ module TypedValue = {
|
|||
|> E.A.fmap(((name, t)) => fromNode(t) |> E.R.fmap(r => (name, r)))
|
||||
|> E.A.R.firstErrorOrOpen
|
||||
|> E.R.fmap(r => #Hash(r))
|
||||
| e => Error("Wrong type: " ++ ASTTypes.AST.Node.toString(e))
|
||||
| e => Error("Wrong type: " ++ ASTTypes.Node.toString(e))
|
||||
}
|
||||
|
||||
// todo: Arrays and hashes
|
||||
|
@ -73,12 +73,12 @@ module TypedValue = {
|
|||
| _ => Error("Type Error: Expected float.")
|
||||
}
|
||||
| (#SamplingDistribution, _) =>
|
||||
ASTTypes.AST.SamplingDistribution.renderIfIsNotSamplingDistribution(
|
||||
ASTTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
|
||||
evaluationParams,
|
||||
node,
|
||||
) |> E.R.bind(_, fromNode)
|
||||
| (#RenderedDistribution, _) =>
|
||||
ASTTypes.AST.Node.render(evaluationParams, node) |> E.R.bind(_, fromNode)
|
||||
ASTTypes.Node.render(evaluationParams, node) |> E.R.bind(_, fromNode)
|
||||
| (#Array(_type), #Array(b)) =>
|
||||
b
|
||||
|> E.A.fmap(fromNodeWithTypeCoercion(evaluationParams, _type))
|
||||
|
@ -172,7 +172,7 @@ module Function = {
|
|||
|> E.A.R.firstErrorOrOpen
|
||||
|
||||
let inputsToTypedValues = (
|
||||
evaluationParams: ASTTypes.AST.evaluationParams,
|
||||
evaluationParams: ASTTypes.evaluationParams,
|
||||
inputNodes: inputNodes,
|
||||
t: t,
|
||||
) =>
|
||||
|
@ -181,7 +181,7 @@ module Function = {
|
|||
)
|
||||
|
||||
let run = (
|
||||
evaluationParams: ASTTypes.AST.evaluationParams,
|
||||
evaluationParams: ASTTypes.evaluationParams,
|
||||
inputNodes: inputNodes,
|
||||
t: t,
|
||||
) =>
|
||||
|
|
|
@ -122,7 +122,7 @@ module MathAdtToDistDst = {
|
|||
| _ => Error("Lognormal distribution needs either mean and stdev or mu and sigma")
|
||||
}
|
||||
| _ =>
|
||||
parseArgs() |> E.R.fmap((args: array<ASTTypes.AST.node>) =>
|
||||
parseArgs() |> E.R.fmap((args: array<ASTTypes.node>) =>
|
||||
#FunctionCall("lognormal", args)
|
||||
)
|
||||
}
|
||||
|
@ -130,8 +130,8 @@ module MathAdtToDistDst = {
|
|||
// Error("Dotwise exponentiation needs two operands")
|
||||
let operationParser = (
|
||||
name: string,
|
||||
args: result<array<ASTTypes.AST.node>, string>,
|
||||
): result<ASTTypes.AST.node, string> => {
|
||||
args: result<array<ASTTypes.node>, string>,
|
||||
): result<ASTTypes.node, string> => {
|
||||
let toOkAlgebraic = r => Ok(#AlgebraicCombination(r))
|
||||
let toOkPointwise = r => Ok(#PointwiseCombination(r))
|
||||
let toOkTruncate = r => Ok(#Truncate(r))
|
||||
|
@ -170,12 +170,12 @@ module MathAdtToDistDst = {
|
|||
|
||||
let functionParser = (
|
||||
nodeParser: MathJsonToMathJsAdt.arg => Belt.Result.t<
|
||||
ASTTypes.AST.node,
|
||||
ASTTypes.node,
|
||||
string,
|
||||
>,
|
||||
name: string,
|
||||
args: array<MathJsonToMathJsAdt.arg>,
|
||||
): result<ASTTypes.AST.node, string> => {
|
||||
): result<ASTTypes.node, string> => {
|
||||
let parseArray = ags => ags |> E.A.fmap(nodeParser) |> E.A.R.firstErrorOrOpen
|
||||
let parseArgs = () => parseArray(args)
|
||||
switch name {
|
||||
|
@ -212,27 +212,27 @@ module MathAdtToDistDst = {
|
|||
| (Some(Error(r)), _) => Error(r)
|
||||
| (_, Error(r)) => Error(r)
|
||||
| (None, Ok(dists)) =>
|
||||
let hash: ASTTypes.AST.node = #FunctionCall(
|
||||
let hash: ASTTypes.node = #FunctionCall(
|
||||
"multimodal",
|
||||
[#Hash([("dists", #Array(dists)), ("weights", #Array([]))])],
|
||||
)
|
||||
Ok(hash)
|
||||
| (Some(Ok(weights)), Ok(dists)) =>
|
||||
let hash: ASTTypes.AST.node = #FunctionCall(
|
||||
let hash: ASTTypes.node = #FunctionCall(
|
||||
"multimodal",
|
||||
[#Hash([("dists", #Array(dists)), ("weights", #Array(weights))])],
|
||||
)
|
||||
Ok(hash)
|
||||
}
|
||||
| name =>
|
||||
parseArgs() |> E.R.fmap((args: array<ASTTypes.AST.node>) =>
|
||||
parseArgs() |> E.R.fmap((args: array<ASTTypes.node>) =>
|
||||
#FunctionCall(name, args)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
let rec nodeParser: MathJsonToMathJsAdt.arg => result<
|
||||
ASTTypes.AST.node,
|
||||
ASTTypes.node,
|
||||
string,
|
||||
> = x =>
|
||||
switch x {
|
||||
|
@ -246,7 +246,7 @@ module MathAdtToDistDst = {
|
|||
// let evaluatedExpression = run(expression);
|
||||
// `Function(_ => Ok(evaluatedExpression));
|
||||
// }
|
||||
let rec topLevel = (r): result<ASTTypes.AST.program, string> =>
|
||||
let rec topLevel = (r): result<ASTTypes.program, string> =>
|
||||
switch r {
|
||||
| FunctionAssignment({name, args, expression}) =>
|
||||
switch nodeParser(expression) {
|
||||
|
@ -267,7 +267,7 @@ module MathAdtToDistDst = {
|
|||
blocks |> E.A.fmap(b => topLevel(b)) |> E.A.R.firstErrorOrOpen |> E.R.fmap(E.A.concatMany)
|
||||
}
|
||||
|
||||
let run = (r): result<ASTTypes.AST.program, string> =>
|
||||
let run = (r): result<ASTTypes.program, string> =>
|
||||
r |> MathAdtCleaner.run |> topLevel
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user