Added OldInterpreter files
parent: a2729f34cb
commit: 6b69a94a1a
24  packages/squiggle-lang/src/rescript/OldInterpreter/AST.res  Normal file

@@ -0,0 +1,24 @@
open ASTTypes

let toString = ASTTypes.Node.toString

let envs = (samplingInputs, environment) => {
  samplingInputs: samplingInputs,
  environment: environment,
  evaluateNode: ASTEvaluator.toLeaf,
}

let toLeaf = (samplingInputs, environment, node: node) =>
  ASTEvaluator.toLeaf(envs(samplingInputs, environment), node)

let toPointSetDist = (samplingInputs, environment, node: node) =>
  switch toLeaf(samplingInputs, environment, node) {
  | Ok(#RenderedDist(pointSetDist)) => Ok(pointSetDist)
  | Ok(_) => Error("Rendering failed.")
  | Error(e) => Error(e)
  }

let runFunction = (samplingInputs, environment, inputs, fn: ASTTypes.Function.t) => {
  let params = envs(samplingInputs, environment)
  ASTTypes.Function.run(params, inputs, fn)
}
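A minimal usage sketch of these entry points. This is not part of the commit: `samplingInputs` is assumed to be a `SamplingInputs.samplingInputs` value already in scope, and the `normal` call node is a hypothetical example.

// Hypothetical sketch: reduce `normal(0, 1)` to a point-set distribution.
// Wrapping in #Render is needed so toLeaf produces a #RenderedDist rather than a #SymbolicDist.
let dist: result<PointSetTypes.pointSetDist, string> = AST.toPointSetDist(
  samplingInputs,
  ASTTypes.Environment.empty,
  #Render(#FunctionCall("normal", [#SymbolicDist(#Float(0.0)), #SymbolicDist(#Float(1.0))])),
)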
@@ -0,0 +1,257 @@
open ASTTypes

type tResult = node => result<node, string>

/* Given two random variables A and B, this returns the distribution
   of a new variable that is the result of the operation on A and B.
   For instance, normal(0, 1) + normal(1, 1) -> normal(1, sqrt(2)).
   In general, this is implemented via convolution. */
module AlgebraicCombination = {
  let tryAnalyticalSimplification = (operation, t1: node, t2: node) =>
    switch (operation, t1, t2) {
    | (operation, #SymbolicDist(d1), #SymbolicDist(d2)) =>
      switch SymbolicDist.T.tryAnalyticalSimplification(d1, d2, operation) {
      | #AnalyticalSolution(symbolicDist) => Ok(#SymbolicDist(symbolicDist))
      | #Error(er) => Error(er)
      | #NoSolution => Ok(#AlgebraicCombination(operation, t1, t2))
      }
    | _ => Ok(#AlgebraicCombination(operation, t1, t2))
    }

  let combinationByRendering = (evaluationParams, algebraicOp, t1: node, t2: node): result<
    node,
    string,
  > =>
    E.R.merge(
      Node.ensureIsRenderedAndGetShape(evaluationParams, t1),
      Node.ensureIsRenderedAndGetShape(evaluationParams, t2),
    ) |> E.R.fmap(((a, b)) => #RenderedDist(PointSetDist.combineAlgebraically(algebraicOp, a, b)))

  let nodeScore: node => int = x =>
    switch x {
    | #SymbolicDist(#Float(_)) => 1
    | #SymbolicDist(_) => 1000
    | #RenderedDist(Discrete(m)) => m.xyShape |> XYShape.T.length
    | #RenderedDist(Mixed(_)) => 1000
    | #RenderedDist(Continuous(_)) => 1000
    | _ => 1000
    }

  let choose = (t1: node, t2: node) =>
    nodeScore(t1) * nodeScore(t2) > 10000 ? #Sampling : #Analytical

  let combine = (evaluationParams, algebraicOp, t1: node, t2: node): result<node, string> =>
    E.R.merge(
      ASTTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams, t1),
      ASTTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams, t2),
    ) |> E.R.bind(_, ((a, b)) =>
      switch choose(a, b) {
      | #Sampling =>
        ASTTypes.SamplingDistribution.combineShapesUsingSampling(
          evaluationParams,
          algebraicOp,
          a,
          b,
        )
      | #Analytical => combinationByRendering(evaluationParams, algebraicOp, a, b)
      }
    )

  let operationToLeaf = (
    evaluationParams: evaluationParams,
    algebraicOp: Operation.algebraicOperation,
    t1: node,
    t2: node,
  ): result<node, string> =>
    algebraicOp
    |> tryAnalyticalSimplification(_, t1, t2)
    |> E.R.bind(_, x =>
      switch x {
      | #SymbolicDist(_) as t => Ok(t)
      | _ => combine(evaluationParams, algebraicOp, t1, t2)
      }
    )
}

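An illustration of the `choose` heuristic above, in terms of the scores assigned by `nodeScore`. This only applies once `tryAnalyticalSimplification` has found no closed form; the operand shapes are hypothetical.

// nodeScore(#SymbolicDist(#Float(_)))   == 1
// nodeScore(#SymbolicDist(_))           == 1000   (any other symbolic dist)
// nodeScore(#RenderedDist(Discrete(m))) == number of points in m
// symbolic dist  op  float constant  -> 1000 * 1    <= 10000 -> #Analytical (render + convolve)
// symbolic dist  op  symbolic dist   -> 1000 * 1000 >  10000 -> #Sampling  (Monte Carlo)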
module PointwiseCombination = {
  // TODO: This is crude and slow. It forces everything to be a pointSetDist, even though much
  // of the process could happen on symbolic distributions without a conversion to a pointSetDist.
  let pointwiseAdd = (evaluationParams: evaluationParams, t1: node, t2: node) =>
    switch (Node.render(evaluationParams, t1), Node.render(evaluationParams, t2)) {
    | (Ok(#RenderedDist(rs1)), Ok(#RenderedDist(rs2))) =>
      Ok(
        #RenderedDist(
          PointSetDist.combinePointwise(
            ~integralSumCachesFn=(a, b) => Some(a +. b),
            ~integralCachesFn=(a, b) => Some(
              Continuous.combinePointwise(~distributionType=#CDF, \"+.", a, b),
            ),
            \"+.",
            rs1,
            rs2,
          ),
        ),
      )
    | (Error(e1), _) => Error(e1)
    | (_, Error(e2)) => Error(e2)
    | _ => Error("Pointwise combination: rendering failed.")
    }

  // TODO: construct a function that we can easily sample from, to construct a RenderedDist.
  // Use the xMin and xMax of the rendered pointSetDists to tell the sampling function where to look.
  // TODO: This should work for symbolic distributions too!
  let pointwiseCombine = (fn, evaluationParams: evaluationParams, t1: node, t2: node) =>
    switch (Node.render(evaluationParams, t1), Node.render(evaluationParams, t2)) {
    | (Ok(#RenderedDist(rs1)), Ok(#RenderedDist(rs2))) =>
      Ok(#RenderedDist(PointSetDist.combinePointwise(fn, rs1, rs2)))
    | (Error(e1), _) => Error(e1)
    | (_, Error(e2)) => Error(e2)
    | _ => Error("Pointwise combination: rendering failed.")
    }

  let operationToLeaf = (
    evaluationParams: evaluationParams,
    pointwiseOp: Operation.pointwiseOperation,
    t1: node,
    t2: node,
  ) =>
    switch pointwiseOp {
    | #Add => pointwiseAdd(evaluationParams, t1, t2)
    | #Multiply => pointwiseCombine(\"*.", evaluationParams, t1, t2)
    | #Exponentiate => pointwiseCombine(\"**", evaluationParams, t1, t2)
    }
}

module Truncate = {
  type simplificationResult = [
    | #Solution(ASTTypes.node)
    | #Error(string)
    | #NoSolution
  ]

  let trySimplification = (leftCutoff, rightCutoff, t): simplificationResult =>
    switch (leftCutoff, rightCutoff, t) {
    | (None, None, t) => #Solution(t)
    | (Some(lc), Some(rc), _) if lc > rc =>
      #Error("Left truncation bound must be smaller than right truncation bound.")
    | (lc, rc, #SymbolicDist(#Uniform(u))) =>
      #Solution(#SymbolicDist(#Uniform(SymbolicDist.Uniform.truncate(lc, rc, u))))
    | _ => #NoSolution
    }

  // TODO: use named args for xMin/xMax in renderToShape; if we're lucky we can at least get the tail
  // of a distribution we otherwise wouldn't get at all
  let truncateAsShape = (evaluationParams: evaluationParams, leftCutoff, rightCutoff, t) =>
    switch Node.ensureIsRendered(evaluationParams, t) {
    | Ok(#RenderedDist(rs)) =>
      Ok(#RenderedDist(PointSetDist.T.truncate(leftCutoff, rightCutoff, rs)))
    | Error(e) => Error(e)
    | _ => Error("Could not truncate distribution.")
    }

  let operationToLeaf = (
    evaluationParams,
    leftCutoff: option<float>,
    rightCutoff: option<float>,
    t: node,
  ): result<node, string> =>
    t
    |> trySimplification(leftCutoff, rightCutoff)
    |> (
      x =>
        switch x {
        | #Solution(t) => Ok(t)
        | #Error(e) => Error(e)
        | #NoSolution => truncateAsShape(evaluationParams, leftCutoff, rightCutoff, t)
        }
    )
}

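The symbolic simplification above only fires for uniform distributions; everything else falls back to rendering. Hypothetical traces (concrete field values omitted):

// trySimplification(Some(0.2), Some(0.8), #SymbolicDist(#Uniform(u)))
//   -> #Solution(#SymbolicDist(#Uniform(truncated u)))          // stays symbolic
// trySimplification(Some(0.0), None, #SymbolicDist(#Normal(n)))
//   -> #NoSolution -> truncateAsShape renders, then truncates the point-set dist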
module Normalize = {
  let rec operationToLeaf = (evaluationParams, t: node): result<node, string> =>
    switch t {
    | #RenderedDist(s) => Ok(#RenderedDist(PointSetDist.T.normalize(s)))
    | #SymbolicDist(_) => Ok(t)
    | _ => ASTTypes.Node.evaluateAndRetry(evaluationParams, operationToLeaf, t)
    }
}

module FunctionCall = {
  let _runHardcodedFunction = (name, evaluationParams, args) =>
    TypeSystem.Function.Ts.findByNameAndRun(HardcodedFunctions.all, name, evaluationParams, args)

  let _runLocalFunction = (name, evaluationParams: evaluationParams, args) =>
    Environment.getFunction(evaluationParams.environment, name) |> E.R.bind(_, ((argNames, fn)) =>
      ASTTypes.Function.run(evaluationParams, args, (argNames, fn))
    )

  let _runWithEvaluatedInputs = (
    evaluationParams: ASTTypes.evaluationParams,
    name,
    args: array<ASTTypes.node>,
  ) =>
    _runHardcodedFunction(name, evaluationParams, args) |> E.O.default(
      _runLocalFunction(name, evaluationParams, args),
    )

  // TODO: This forces things to be floats
  let run = (evaluationParams, name, args) =>
    args
    |> E.A.fmap(a => evaluationParams.evaluateNode(evaluationParams, a))
    |> E.A.R.firstErrorOrOpen
    |> E.R.bind(_, _runWithEvaluatedInputs(evaluationParams, name))
}

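In short, a call node is resolved by evaluating its arguments, then trying the hardcoded function table first and falling back to a user-defined function from the environment. A hypothetical trace:

// FunctionCall.run(params, "mean", [#FunctionCall("normal", [...])])
//   1. evaluate args     -> [#SymbolicDist(...)]
//   2. hardcoded lookup  -> "mean" exists in HardcodedFunctions.all, so run it
//   3. otherwise         -> Environment.getFunction(params.environment, name)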
module Render = {
  let rec operationToLeaf = (evaluationParams: evaluationParams, t: node): result<node, string> =>
    switch t {
    | #Function(_) => Error("Cannot render a function")
    | #SymbolicDist(d) =>
      Ok(
        #RenderedDist(
          SymbolicDist.T.toPointSetDist(evaluationParams.samplingInputs.pointSetDistLength, d),
        ),
      )
    | #RenderedDist(_) as t => Ok(t) // already a rendered pointSetDist, we're done here
    | _ => ASTTypes.Node.evaluateAndRetry(evaluationParams, operationToLeaf, t)
    }
}

/* This function recursively goes through the nodes of the parse tree,
   replacing each Operation node and its subtree with a Data node.
   Whenever possible, the replacement produces a new Symbolic Data node,
   but most often it will produce a RenderedDist.
   This function is used mainly to turn a parse tree into a single RenderedDist
   that can then be displayed to the user. */
let rec toLeaf = (evaluationParams: ASTTypes.evaluationParams, node: node): result<node, string> =>
  switch node {
  // Leaf nodes just stay leaf nodes
  | #SymbolicDist(_)
  | #Function(_)
  | #RenderedDist(_) =>
    Ok(node)
  | #Array(args) =>
    args |> E.A.fmap(toLeaf(evaluationParams)) |> E.A.R.firstErrorOrOpen |> E.R.fmap(r => #Array(r))
  // Operations need to be turned into leaves
  | #AlgebraicCombination(algebraicOp, t1, t2) =>
    AlgebraicCombination.operationToLeaf(evaluationParams, algebraicOp, t1, t2)
  | #PointwiseCombination(pointwiseOp, t1, t2) =>
    PointwiseCombination.operationToLeaf(evaluationParams, pointwiseOp, t1, t2)
  | #Truncate(leftCutoff, rightCutoff, t) =>
    Truncate.operationToLeaf(evaluationParams, leftCutoff, rightCutoff, t)
  | #Normalize(t) => Normalize.operationToLeaf(evaluationParams, t)
  | #Render(t) => Render.operationToLeaf(evaluationParams, t)
  | #Hash(t) =>
    t
    |> E.A.fmap(((name: string, node: node)) =>
      toLeaf(evaluationParams, node) |> E.R.fmap(r => (name, r))
    )
    |> E.A.R.firstErrorOrOpen
    |> E.R.fmap(r => #Hash(r))
  | #Symbol(r) =>
    ASTTypes.Environment.get(evaluationParams.environment, r)
    |> E.O.toResult("Undeclared variable " ++ r)
    |> E.R.bind(_, toLeaf(evaluationParams))
  | #FunctionCall(name, args) =>
    FunctionCall.run(evaluationParams, name, args) |> E.R.bind(_, toLeaf(evaluationParams))
  }
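A compressed sketch of the reduction `toLeaf` performs on two small trees (the concrete node values are hypothetical):

// toLeaf(params, #AlgebraicCombination(#Add, #SymbolicDist(n1), #SymbolicDist(n2)))
//   -> tryAnalyticalSimplification finds a closed form for normal + normal
//   -> Ok(#SymbolicDist(...))                       // still symbolic: no rendering needed
// toLeaf(params, #Render(#SymbolicDist(d)))
//   -> Ok(#RenderedDist(SymbolicDist.T.toPointSetDist(params.samplingInputs.pointSetDistLength, d)))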
233  packages/squiggle-lang/src/rescript/OldInterpreter/ASTTypes.res  Normal file

@@ -0,0 +1,233 @@
@genType
type rec hash = array<(string, node)>
and node = [
  | #SymbolicDist(SymbolicDistTypes.symbolicDist)
  | #RenderedDist(PointSetTypes.pointSetDist)
  | #Symbol(string)
  | #Hash(hash)
  | #Array(array<node>)
  | #Function(array<string>, node)
  | #AlgebraicCombination(Operation.algebraicOperation, node, node)
  | #PointwiseCombination(Operation.pointwiseOperation, node, node)
  | #Normalize(node)
  | #Render(node)
  | #Truncate(option<float>, option<float>, node)
  | #FunctionCall(string, array<node>)
]

type statement = [
  | #Assignment(string, node)
  | #Expression(node)
]
type program = array<statement>

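For concreteness, a small example of a `program` value corresponding to a source program along the lines of `x = normal(0, 1); truncate(x, 0, 1)`. The surface syntax and the exact lowering are hypothetical; only the types above are relied on.

// Hypothetical example program:
let example: ASTTypes.program = [
  #Assignment(
    "x",
    #FunctionCall("normal", [#SymbolicDist(#Float(0.0)), #SymbolicDist(#Float(1.0))]),
  ),
  #Expression(#Truncate(Some(0.0), Some(1.0), #Symbol("x"))),
]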
type environment = Belt.Map.String.t<node>

type rec evaluationParams = {
  samplingInputs: SamplingInputs.samplingInputs,
  environment: environment,
  evaluateNode: (evaluationParams, node) => Belt.Result.t<node, string>,
}

module Environment = {
  type t = environment
  module MS = Belt.Map.String
  let fromArray = MS.fromArray
  let empty: t = []->fromArray
  let mergeKeepSecond = (a: t, b: t) =>
    MS.merge(a, b, (_, a, b) =>
      switch (a, b) {
      | (_, Some(b)) => Some(b)
      | (Some(a), _) => Some(a)
      | _ => None
      }
    )
  let update = (t, str, fn) => MS.update(t, str, fn)
  let get = (t: t, str) => MS.get(t, str)
  let getFunction = (t: t, str) =>
    switch get(t, str) {
    | Some(#Function(argNames, fn)) => Ok((argNames, fn))
    | _ => Error("Function " ++ (str ++ " not found"))
    }
}

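A minimal sketch of how an environment is built and queried (the bound value is hypothetical):

// Hypothetical: bind `x` to the constant 1.0 and look it up again.
let env = ASTTypes.Environment.fromArray([("x", #SymbolicDist(#Float(1.0)))])
let x = ASTTypes.Environment.get(env, "x") // Some(#SymbolicDist(#Float(1.0)))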
module Node = {
  let getFloat = (node: node) =>
    node |> (
      x =>
        switch x {
        | #RenderedDist(Discrete({xyShape: {xs: [x], ys: [1.0]}})) => Some(x)
        | #SymbolicDist(#Float(x)) => Some(x)
        | _ => None
        }
    )

  let evaluate = (evaluationParams: evaluationParams) =>
    evaluationParams.evaluateNode(evaluationParams)

  let evaluateAndRetry = (evaluationParams, fn, node) =>
    node |> evaluationParams.evaluateNode(evaluationParams) |> E.R.bind(_, fn(evaluationParams))

  let rec toString: node => string = x =>
    switch x {
    | #SymbolicDist(d) => SymbolicDist.T.toString(d)
    | #RenderedDist(_) => "[renderedShape]"
    | #AlgebraicCombination(op, t1, t2) =>
      Operation.Algebraic.format(op, toString(t1), toString(t2))
    | #PointwiseCombination(op, t1, t2) =>
      Operation.Pointwise.format(op, toString(t1), toString(t2))
    | #Normalize(t) => "normalize(" ++ (toString(t) ++ ")")
    | #Truncate(lc, rc, t) => Operation.Truncate.toString(lc, rc, toString(t))
    | #Render(t) => toString(t)
    | #Symbol(t) => "Symbol: " ++ t
    | #FunctionCall(name, args) =>
      "[Function call: (" ++
      (name ++
      ((args |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ ")]"))
    | #Function(args, internal) =>
      "[Function: (" ++ ((args |> Js.String.concatMany(_, ",")) ++ (toString(internal) ++ ")]"))
    | #Array(a) => "[" ++ ((a |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ "]")
    | #Hash(h) =>
      "{" ++
      ((h
      |> E.A.fmap(((name, value)) => name ++ (":" ++ toString(value)))
      |> Js.String.concatMany(_, ",")) ++
      "}")
    }

  let render = (evaluationParams: evaluationParams, r) => #Render(r) |> evaluate(evaluationParams)

  let ensureIsRendered = (params, t) =>
    switch t {
    | #RenderedDist(_) => Ok(t)
    | _ =>
      switch render(params, t) {
      | Ok(#RenderedDist(r)) => Ok(#RenderedDist(r))
      | Ok(_) => Error("Did not render as requested")
      | Error(e) => Error(e)
      }
    }

  let ensureIsRenderedAndGetShape = (params, t) =>
    switch ensureIsRendered(params, t) {
    | Ok(#RenderedDist(r)) => Ok(r)
    | Ok(_) => Error("Did not render as requested")
    | Error(e) => Error(e)
    }

  let toPointSetDist = (item: node) =>
    switch item {
    | #RenderedDist(r) => Some(r)
    | _ => None
    }

  let _toFloat = (t: PointSetTypes.pointSetDist) =>
    switch t {
    | Discrete({xyShape: {xs: [x], ys: [1.0]}}) => Some(#SymbolicDist(#Float(x)))
    | _ => None
    }

  let toFloat = (item: node): result<node, string> =>
    item |> toPointSetDist |> E.O.bind(_, _toFloat) |> E.O.toResult("Not valid shape")
}

module Function = {
  type t = (array<string>, node)
  let fromNode: node => option<t> = node =>
    switch node {
    | #Function(r) => Some(r)
    | _ => None
    }
  let argumentNames = ((a, _): t) => a
  let internals = ((_, b): t) => b
  let run = (evaluationParams: evaluationParams, args: array<node>, t: t) =>
    if E.A.length(args) == E.A.length(argumentNames(t)) {
      let newEnvironment = Belt.Array.zip(argumentNames(t), args) |> Environment.fromArray
      let newEvaluationParams: evaluationParams = {
        samplingInputs: evaluationParams.samplingInputs,
        environment: Environment.mergeKeepSecond(evaluationParams.environment, newEnvironment),
        evaluateNode: evaluationParams.evaluateNode,
      }
      evaluationParams.evaluateNode(newEvaluationParams, internals(t))
    } else {
      Error("Wrong number of variables")
    }
}

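A sketch of how `Function.run` behaves for a one-argument lambda. The lambda body and argument are hypothetical, and `params` is an `evaluationParams` value assumed to be in scope.

// Hypothetical: f(a) = normalize(a), called with a float argument.
let f: ASTTypes.Function.t = (["a"], #Normalize(#Symbol("a")))
let result = ASTTypes.Function.run(params, [#SymbolicDist(#Float(1.0))], f)
// run binds "a" in a child environment, then evaluates #Normalize(#Symbol("a")) there.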
module SamplingDistribution = {
  type t = [
    | #SymbolicDist(SymbolicDistTypes.symbolicDist)
    | #RenderedDist(PointSetTypes.pointSetDist)
  ]

  let isSamplingDistribution: node => bool = x =>
    switch x {
    | #SymbolicDist(_) => true
    | #RenderedDist(_) => true
    | _ => false
    }

  let fromNode: node => result<t, string> = x =>
    switch x {
    | #SymbolicDist(n) => Ok(#SymbolicDist(n))
    | #RenderedDist(n) => Ok(#RenderedDist(n))
    | _ => Error("Not valid type")
    }

  let renderIfIsNotSamplingDistribution = (params, t): result<node, string> =>
    !isSamplingDistribution(t)
      ? switch Node.render(params, t) {
        | Ok(r) => Ok(r)
        | Error(e) => Error(e)
        }
      : Ok(t)

  let map = (~renderedDistFn, ~symbolicDistFn, node: node) =>
    node |> (
      x =>
        switch x {
        | #RenderedDist(r) => Some(renderedDistFn(r))
        | #SymbolicDist(s) => Some(symbolicDistFn(s))
        | _ => None
        }
    )

  let sampleN = n =>
    map(~renderedDistFn=PointSetDist.sampleNRendered(n), ~symbolicDistFn=SymbolicDist.T.sampleN(n))

  let getCombinationSamples = (n, algebraicOp, t1: node, t2: node) =>
    switch (sampleN(n, t1), sampleN(n, t2)) {
    | (Some(a), Some(b)) =>
      Some(
        Belt.Array.zip(a, b) |> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(algebraicOp, a, b)),
      )
    | _ => None
    }

  let combineShapesUsingSampling = (
    evaluationParams: evaluationParams,
    algebraicOp,
    t1: node,
    t2: node,
  ) => {
    let i1 = renderIfIsNotSamplingDistribution(evaluationParams, t1)
    let i2 = renderIfIsNotSamplingDistribution(evaluationParams, t2)
    E.R.merge(i1, i2) |> E.R.bind(_, ((a, b)) => {
      let samples = getCombinationSamples(
        evaluationParams.samplingInputs.sampleCount,
        algebraicOp,
        a,
        b,
      )

      let pointSetDist =
        samples
        |> E.O.fmap(r =>
          SampleSet.toPointSetDist(~samplingInputs=evaluationParams.samplingInputs, ~samples=r, ())
        )
        |> E.O.bind(_, r => r.pointSetDist)
        |> E.O.toResult("No response")
      pointSetDist |> E.R.fmap(r => #Normalize(#RenderedDist(r)))
    })
  }
}
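In other words, the sampling path draws `sampleCount` paired samples from each operand, applies the algebraic operation elementwise, and re-fits a point-set distribution from the resulting sample set. A compressed, hypothetical trace:

// combineShapesUsingSampling(params, #Multiply, normalNode, lognormalNode)
//   samples_a = sampleN(n, normalNode), samples_b = sampleN(n, lognormalNode)
//   combined  = zip(samples_a, samples_b) mapped with (a, b) => a *. b
//   result    = #Normalize(#RenderedDist(SampleSet.toPointSetDist(... combined ...)))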
116  packages/squiggle-lang/src/rescript/OldInterpreter/DistPlus.res  Normal file

@@ -0,0 +1,116 @@
open PointSetTypes;

@genType
type t = PointSetTypes.distPlus;

let pointSetDistIntegral = pointSetDist => PointSetDist.T.Integral.get(pointSetDist);

let make = (~pointSetDist, ~squiggleString, ()): t => {
  let integral = pointSetDistIntegral(pointSetDist);
  {pointSetDist, integralCache: integral, squiggleString};
};

let update = (~pointSetDist=?, ~integralCache=?, ~squiggleString=?, t: t) => {
  pointSetDist: E.O.default(t.pointSetDist, pointSetDist),
  integralCache: E.O.default(t.integralCache, integralCache),
  squiggleString: E.O.default(t.squiggleString, squiggleString),
};

let updateShape = (pointSetDist, t) => {
  let integralCache = pointSetDistIntegral(pointSetDist);
  update(~pointSetDist, ~integralCache, t);
};

let toPointSetDist = ({pointSetDist, _}: t) => pointSetDist;

let pointSetDistFn = (fn, {pointSetDist}: t) => fn(pointSetDist);

module T =
  Distributions.Dist({
    type t = PointSetTypes.distPlus;
    type integral = PointSetTypes.distPlus;
    let toPointSetDist = toPointSetDist;
    let toContinuous = pointSetDistFn(PointSetDist.T.toContinuous);
    let toDiscrete = pointSetDistFn(PointSetDist.T.toDiscrete);

    let normalize = (t: t): t => {
      let normalizedShape = t |> toPointSetDist |> PointSetDist.T.normalize;
      t |> updateShape(normalizedShape);
    };

    let truncate = (leftCutoff, rightCutoff, t: t): t => {
      let truncatedShape =
        t
        |> toPointSetDist
        |> PointSetDist.T.truncate(leftCutoff, rightCutoff);

      t |> updateShape(truncatedShape);
    };

    let xToY = (f, t: t) =>
      t
      |> toPointSetDist
      |> PointSetDist.T.xToY(f);

    let minX = pointSetDistFn(PointSetDist.T.minX);
    let maxX = pointSetDistFn(PointSetDist.T.maxX);
    let toDiscreteProbabilityMassFraction =
      pointSetDistFn(PointSetDist.T.toDiscreteProbabilityMassFraction);

    // This bit is kind of awkward, could probably use rethinking.
    let integral = (t: t) =>
      updateShape(Continuous(t.integralCache), t);

    let updateIntegralCache = (integralCache: option<PointSetTypes.continuousShape>, t) =>
      update(~integralCache=E.O.default(t.integralCache, integralCache), t);

    let downsample = (i, t): t =>
      updateShape(t |> toPointSetDist |> PointSetDist.T.downsample(i), t);

    // todo: adjust for limit, maybe?
    let mapY = (
      ~integralSumCacheFn=previousIntegralSum => None,
      ~integralCacheFn=previousIntegralCache => None,
      ~fn,
      {pointSetDist, _} as t: t,
    ): t =>
      PointSetDist.T.mapY(~integralSumCacheFn, ~fn, pointSetDist)
      |> updateShape(_, t);

    // get the total of everything
    let integralEndY = (t: t) => {
      PointSetDist.T.Integral.sum(toPointSetDist(t));
    };

    // TODO: Fix this below, obviously. Adjust for limits
    let integralXtoY = (f, t: t) => {
      PointSetDist.T.Integral.xToY(f, toPointSetDist(t));
    };

    // TODO: This part is broken when there is a limit, if this is supposed to be taken into account.
    let integralYtoX = (f, t: t) => {
      PointSetDist.T.Integral.yToX(f, toPointSetDist(t));
    };

    let mean = (t: t) => {
      PointSetDist.T.mean(t.pointSetDist);
    };
    let variance = (t: t) => PointSetDist.T.variance(t.pointSetDist);
  });
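A small construction sketch for `DistPlus`. Not part of the commit: `somePointSetDist` is assumed to be a `PointSetTypes.pointSetDist` in scope, and `squiggleString` is assumed to be an `option<string>` field.

// Hypothetical: wrap an already-rendered point-set dist together with its integral cache.
let plus = DistPlus.make(~pointSetDist=somePointSetDist, ~squiggleString=Some("normal(0, 1)"), ())
let m = DistPlus.T.mean(plus)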
@@ -0,0 +1,234 @@
open TypeSystem

let wrongInputsError = (r: array<typedValue>) => {
  let inputs = r |> E.A.fmap(TypedValue.toString) |> Js.String.concatMany(_, ",")
  Js.log3("Inputs were", inputs, r)
  Error("Wrong inputs. The inputs were:" ++ inputs)
}

let to_: (float, float) => result<node, string> = (low, high) =>
  switch (low, high) {
  | (low, high) if low <= 0.0 && low < high =>
    Ok(#SymbolicDist(SymbolicDist.Normal.from90PercentCI(low, high)))
  | (low, high) if low < high =>
    Ok(#SymbolicDist(SymbolicDist.Lognormal.from90PercentCI(low, high)))
  | (_, _) => Error("Low value must be less than high value.")
  }

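So `to` interprets its two arguments as a 90% credible interval (5th and 95th percentiles) and picks the distribution family from the sign of the lower bound. For example (values hypothetical):

// to_(-1.0, 1.0)  -> Ok(#SymbolicDist(normal fitted to percentiles -1 and 1))
// to_(1.0, 10.0)  -> Ok(#SymbolicDist(lognormal fitted to percentiles 1 and 10))
// to_(5.0, 2.0)   -> Error("Low value must be less than high value.")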
let makeSymbolicFromTwoFloats = (name, fn) =>
  Function.T.make(
    ~name,
    ~outputType=#SamplingDistribution,
    ~inputTypes=[#Float, #Float],
    ~run=x =>
      switch x {
      | [#Float(a), #Float(b)] => fn(a, b) |> E.R.fmap(r => #SymbolicDist(r))
      | e => wrongInputsError(e)
      },
    (),
  )

let makeSymbolicFromOneFloat = (name, fn) =>
  Function.T.make(
    ~name,
    ~outputType=#SamplingDistribution,
    ~inputTypes=[#Float],
    ~run=x =>
      switch x {
      | [#Float(a)] => fn(a) |> E.R.fmap(r => #SymbolicDist(r))
      | e => wrongInputsError(e)
      },
    (),
  )

let makeDistFloat = (name, fn) =>
  Function.T.make(
    ~name,
    ~outputType=#SamplingDistribution,
    ~inputTypes=[#SamplingDistribution, #Float],
    ~run=x =>
      switch x {
      | [#SamplingDist(a), #Float(b)] => fn(a, b)
      | [#RenderedDist(a), #Float(b)] => fn(#RenderedDist(a), b)
      | e => wrongInputsError(e)
      },
    (),
  )

let makeRenderedDistFloat = (name, fn) =>
  Function.T.make(
    ~name,
    ~outputType=#RenderedDistribution,
    ~inputTypes=[#RenderedDistribution, #Float],
    ~shouldCoerceTypes=true,
    ~run=x =>
      switch x {
      | [#RenderedDist(a), #Float(b)] => fn(a, b)
      | e => wrongInputsError(e)
      },
    (),
  )

let makeDist = (name, fn) =>
  Function.T.make(
    ~name,
    ~outputType=#SamplingDistribution,
    ~inputTypes=[#SamplingDistribution],
    ~run=x =>
      switch x {
      | [#SamplingDist(a)] => fn(a)
      | [#RenderedDist(a)] => fn(#RenderedDist(a))
      | e => wrongInputsError(e)
      },
    (),
  )

let floatFromDist = (
  distToFloatOp: Operation.distToFloatOperation,
  t: TypeSystem.samplingDist,
): result<node, string> =>
  switch t {
  | #SymbolicDist(s) =>
    SymbolicDist.T.operate(distToFloatOp, s) |> E.R.bind(_, v => Ok(#SymbolicDist(#Float(v))))
  | #RenderedDist(rs) =>
    PointSetDist.operate(distToFloatOp, rs) |> (v => Ok(#SymbolicDist(#Float(v))))
  }

let verticalScaling = (scaleOp, rs, scaleBy) => {
  // scaleBy has to be a single float, otherwise we'll return an error.
  let fn = (secondary, main) => Operation.Scale.toFn(scaleOp, main, secondary)
  let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(scaleOp)
  let integralCacheFn = Operation.Scale.toIntegralCacheFn(scaleOp)
  Ok(
    #RenderedDist(
      PointSetDist.T.mapY(
        ~integralSumCacheFn=integralSumCacheFn(scaleBy),
        ~integralCacheFn=integralCacheFn(scaleBy),
        ~fn=fn(scaleBy),
        rs,
      ),
    ),
  )
}

module Multimodal = {
  let getByNameResult = Hash.getByNameResult

  let _paramsToDistsAndWeights = (r: array<typedValue>) =>
    switch r {
    | [#Hash(r)] =>
      let dists =
        getByNameResult(r, "dists")
        ->E.R.bind(TypeSystem.TypedValue.toArray)
        ->E.R.bind(r => r |> E.A.fmap(TypeSystem.TypedValue.toDist) |> E.A.R.firstErrorOrOpen)
      let weights =
        getByNameResult(r, "weights")
        ->E.R.bind(TypeSystem.TypedValue.toArray)
        ->E.R.bind(r => r |> E.A.fmap(TypeSystem.TypedValue.toFloat) |> E.A.R.firstErrorOrOpen)

      E.R.merge(dists, weights)->E.R.bind(((a, b)) =>
        E.A.length(b) > E.A.length(a)
          ? Error("Too many weights provided")
          : Ok(
              E.A.zipMaxLength(a, b) |> E.A.fmap(((a, b)) => (
                a |> E.O.toExn(""),
                b |> E.O.default(1.0),
              )),
            )
      )
    | _ => Error("Needs items")
    }

  let _runner: array<typedValue> => result<node, string> = r => {
    let paramsToDistsAndWeights =
      _paramsToDistsAndWeights(r) |> E.R.fmap(
        E.A.fmap(((dist, weight)) =>
          #FunctionCall("scaleMultiply", [dist, #SymbolicDist(#Float(weight))])
        ),
      )
    let pointwiseSum: result<node, string> =
      paramsToDistsAndWeights->E.R.bind(E.R.errorIfCondition(E.A.isEmpty, "Needs one input"))
      |> E.R.fmap(r =>
        r
        |> Js.Array.sliceFrom(1)
        |> E.A.fold_left((acc, x) => #PointwiseCombination(#Add, acc, x), E.A.unsafe_get(r, 0))
      )
    pointwiseSum
  }

  let _function = Function.T.make(
    ~name="multimodal",
    ~outputType=#SamplingDistribution,
    ~inputTypes=[#Hash([("dists", #Array(#SamplingDistribution)), ("weights", #Array(#Float))])],
    ~run=_runner,
    (),
  )
}

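Concretely, `_runner` turns a `{dists, weights}` hash into a weighted pointwise sum, with missing weights defaulting to 1.0. A hypothetical expansion for two components (the surface syntax is an assumption):

// multimodal({dists: [d1, d2], weights: [0.3]})
//   -> #PointwiseCombination(
//        #Add,
//        #FunctionCall("scaleMultiply", [d1, #SymbolicDist(#Float(0.3))]),
//        #FunctionCall("scaleMultiply", [d2, #SymbolicDist(#Float(1.0))]),
//      )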
let all = [
  makeSymbolicFromTwoFloats("normal", SymbolicDist.Normal.make),
  makeSymbolicFromTwoFloats("uniform", SymbolicDist.Uniform.make),
  makeSymbolicFromTwoFloats("beta", SymbolicDist.Beta.make),
  makeSymbolicFromTwoFloats("lognormal", SymbolicDist.Lognormal.make),
  makeSymbolicFromTwoFloats("lognormalFromMeanAndStdDev", SymbolicDist.Lognormal.fromMeanAndStdev),
  makeSymbolicFromOneFloat("exponential", SymbolicDist.Exponential.make),
  Function.T.make(
    ~name="to",
    ~outputType=#SamplingDistribution,
    ~inputTypes=[#Float, #Float],
    ~run=x =>
      switch x {
      | [#Float(a), #Float(b)] => to_(a, b)
      | e => wrongInputsError(e)
      },
    (),
  ),
  Function.T.make(
    ~name="triangular",
    ~outputType=#SamplingDistribution,
    ~inputTypes=[#Float, #Float, #Float],
    ~run=x =>
      switch x {
      | [#Float(a), #Float(b), #Float(c)] =>
        SymbolicDist.Triangular.make(a, b, c) |> E.R.fmap(r => #SymbolicDist(r))
      | e => wrongInputsError(e)
      },
    (),
  ),
  Function.T.make(
    ~name="log",
    ~outputType=#Float,
    ~inputTypes=[#Float],
    ~run=x =>
      switch x {
      | [#Float(a)] => Ok(#SymbolicDist(#Float(Js.Math.log(a))))
      | e => wrongInputsError(e)
      },
    (),
  ),
  makeDistFloat("pdf", (dist, float) => floatFromDist(#Pdf(float), dist)),
  makeDistFloat("inv", (dist, float) => floatFromDist(#Inv(float), dist)),
  makeDistFloat("cdf", (dist, float) => floatFromDist(#Cdf(float), dist)),
  makeDist("mean", dist => floatFromDist(#Mean, dist)),
  makeDist("sample", dist => floatFromDist(#Sample, dist)),
  Function.T.make(
    ~name="render",
    ~outputType=#RenderedDistribution,
    ~inputTypes=[#RenderedDistribution],
    ~run=x =>
      switch x {
      | [#RenderedDist(c)] => Ok(#RenderedDist(c))
      | e => wrongInputsError(e)
      },
    (),
  ),
  Function.T.make(
    ~name="normalize",
    ~outputType=#SamplingDistribution,
    ~inputTypes=[#SamplingDistribution],
    ~run=x =>
      switch x {
      | [#SamplingDist(#SymbolicDist(c))] => Ok(#SymbolicDist(c))
      | [#SamplingDist(#RenderedDist(c))] => Ok(#RenderedDist(PointSetDist.T.normalize(c)))
      | e => wrongInputsError(e)
      },
    (),
  ),
  makeRenderedDistFloat("scaleExp", (dist, float) => verticalScaling(#Exponentiate, dist, float)),
  makeRenderedDistFloat("scaleMultiply", (dist, float) => verticalScaling(#Multiply, dist, float)),
  makeRenderedDistFloat("scaleLog", (dist, float) => verticalScaling(#Logarithm, dist, float)),
  Multimodal._function,
]
@@ -0,0 +1,204 @@
type node = ASTTypes.node
let getFloat = ASTTypes.Node.getFloat

type samplingDist = [
  | #SymbolicDist(SymbolicDistTypes.symbolicDist)
  | #RenderedDist(PointSetTypes.pointSetDist)
]

type rec hashType = array<(string, _type)>
and _type = [
  | #Float
  | #SamplingDistribution
  | #RenderedDistribution
  | #Array(_type)
  | #Hash(hashType)
]

type rec hashTypedValue = array<(string, typedValue)>
and typedValue = [
  | #Float(float)
  | #RenderedDist(PointSetTypes.pointSetDist)
  | #SamplingDist(samplingDist)
  | #Array(array<typedValue>)
  | #Hash(hashTypedValue)
]

type _function = {
  name: string,
  inputTypes: array<_type>,
  outputType: _type,
  run: array<typedValue> => result<node, string>,
  shouldCoerceTypes: bool,
}

type functions = array<_function>
type inputNodes = array<node>

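To connect these types to their use in HardcodedFunctions: the `multimodal` builtin declares its single input as a hash of typed fields, and after coercion the evaluator hands `run` a matching `typedValue`. The type below is taken from that declaration; the coerced argument is hypothetical, with `somePointSetDist` assumed to be a `PointSetTypes.pointSetDist` in scope.

// The declared input type of `multimodal` (from HardcodedFunctions):
let multimodalInput: _type = #Hash([
  ("dists", #Array(#SamplingDistribution)),
  ("weights", #Array(#Float)),
])

// A hypothetical coerced argument matching that type:
let arg: typedValue = #Hash([
  ("dists", #Array([#SamplingDist(#RenderedDist(somePointSetDist))])),
  ("weights", #Array([#Float(1.0)])),
])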
module TypedValue = {
  let rec toString: typedValue => string = x =>
    switch x {
    | #SamplingDist(_) => "[sampling dist]"
    | #RenderedDist(_) => "[rendered PointSetDist]"
    | #Float(f) => "Float: " ++ Js.Float.toString(f)
    | #Array(a) => "[" ++ ((a |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ "]")
    | #Hash(v) =>
      "{" ++
      ((v
      |> E.A.fmap(((name, value)) => name ++ (":" ++ toString(value)))
      |> Js.String.concatMany(_, ",")) ++
      "}")
    }

  let rec fromNode = (node: node): result<typedValue, string> =>
    switch node {
    | #SymbolicDist(#Float(r)) => Ok(#Float(r))
    | #SymbolicDist(s) => Ok(#SamplingDist(#SymbolicDist(s)))
    | #RenderedDist(s) => Ok(#RenderedDist(s))
    | #Array(r) => r |> E.A.fmap(fromNode) |> E.A.R.firstErrorOrOpen |> E.R.fmap(r => #Array(r))
    | #Hash(hash) =>
      hash
      |> E.A.fmap(((name, t)) => fromNode(t) |> E.R.fmap(r => (name, r)))
      |> E.A.R.firstErrorOrOpen
      |> E.R.fmap(r => #Hash(r))
    | e => Error("Wrong type: " ++ ASTTypes.Node.toString(e))
    }

  // todo: Arrays and hashes
  let rec fromNodeWithTypeCoercion = (evaluationParams, _type: _type, node) =>
    switch (_type, node) {
    | (#Float, _) =>
      switch getFloat(node) {
      | Some(a) => Ok(#Float(a))
      | _ => Error("Type Error: Expected float.")
      }
    | (#SamplingDistribution, _) =>
      ASTTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
        evaluationParams,
        node,
      ) |> E.R.bind(_, fromNode)
    | (#RenderedDistribution, _) =>
      ASTTypes.Node.render(evaluationParams, node) |> E.R.bind(_, fromNode)
    | (#Array(_type), #Array(b)) =>
      b
      |> E.A.fmap(fromNodeWithTypeCoercion(evaluationParams, _type))
      |> E.A.R.firstErrorOrOpen
      |> E.R.fmap(r => #Array(r))
    | (#Hash(named), #Hash(r)) =>
      let keyValues =
        named |> E.A.fmap(((name, intendedType)) => (name, intendedType, Hash.getByName(r, name)))
      let typedHash =
        keyValues
        |> E.A.fmap(((name, intendedType, optionNode)) =>
          switch optionNode {
          | Some(node) =>
            fromNodeWithTypeCoercion(evaluationParams, intendedType, node) |> E.R.fmap(node => (
              name,
              node,
            ))
          | None => Error("Hash parameter not present in hash.")
          }
        )
        |> E.A.R.firstErrorOrOpen
        |> E.R.fmap(r => #Hash(r))
      typedHash
    | _ => Error("fromNodeWithTypeCoercion error, sorry.")
    }

  let toFloat: typedValue => result<float, string> = x =>
    switch x {
    | #Float(x) => Ok(x)
    | _ => Error("Not a float")
    }

  let toArray: typedValue => result<array<'a>, string> = x =>
    switch x {
    | #Array(x) => Ok(x)
    | _ => Error("Not an array")
    }

  let toNamed: typedValue => result<hashTypedValue, string> = x =>
    switch x {
    | #Hash(x) => Ok(x)
    | _ => Error("Not a named item")
    }

  let toDist: typedValue => result<node, string> = x =>
    switch x {
    | #SamplingDist(#SymbolicDist(c)) => Ok(#SymbolicDist(c))
    | #SamplingDist(#RenderedDist(c)) => Ok(#RenderedDist(c))
    | #RenderedDist(c) => Ok(#RenderedDist(c))
    | #Float(x) => Ok(#SymbolicDist(#Float(x)))
    | x => Error("Cannot be converted into a distribution: " ++ toString(x))
    }
}

module Function = {
  type t = _function
  type ts = functions

  module T = {
    let make = (~name, ~inputTypes, ~outputType, ~run, ~shouldCoerceTypes=true, _): t => {
      name: name,
      inputTypes: inputTypes,
      outputType: outputType,
      run: run,
      shouldCoerceTypes: shouldCoerceTypes,
    }

    let _inputLengthCheck = (inputNodes: inputNodes, t: t) => {
      let expectedLength = E.A.length(t.inputTypes)
      let actualLength = E.A.length(inputNodes)
      expectedLength == actualLength
        ? Ok(inputNodes)
        : Error(
            "Wrong number of inputs. Expected" ++
            ((expectedLength |> E.I.toString) ++
            (". Got:" ++ (actualLength |> E.I.toString))),
          )
    }

    let _coerceInputNodes = (evaluationParams, inputTypes, shouldCoerce, inputNodes) =>
      Belt.Array.zip(inputTypes, inputNodes)
      |> E.A.fmap(((def, input)) =>
        shouldCoerce
          ? TypedValue.fromNodeWithTypeCoercion(evaluationParams, def, input)
          : TypedValue.fromNode(input)
      )
      |> E.A.R.firstErrorOrOpen

    let inputsToTypedValues = (
      evaluationParams: ASTTypes.evaluationParams,
      inputNodes: inputNodes,
      t: t,
    ) =>
      _inputLengthCheck(inputNodes, t)->E.R.bind(
        _coerceInputNodes(evaluationParams, t.inputTypes, t.shouldCoerceTypes),
      )

    let run = (
      evaluationParams: ASTTypes.evaluationParams,
      inputNodes: inputNodes,
      t: t,
    ) =>
      inputsToTypedValues(evaluationParams, inputNodes, t)->E.R.bind(t.run)
      |> (
        x =>
          switch x {
          | Ok(i) => Ok(i)
          | Error(r) => Error("Function " ++ (t.name ++ (" error: " ++ r)))
          }
      )
  }

  module Ts = {
    let findByName = (ts: ts, n: string) => ts |> Belt.Array.getBy(_, ({name}) => name == n)

    let findByNameAndRun = (ts: ts, n: string, evaluationParams, inputTypes) =>
      findByName(ts, n) |> E.O.fmap(T.run(evaluationParams, inputTypes))
  }
}