squiggle/packages/squiggle-lang/src/rescript/ProgramEvaluator.res

211 lines
5.5 KiB
Plaintext
Raw Normal View History

2022-01-18 00:10:06 +00:00
// TODO: This setup is more confusing than it should be, there's more work to do in cleanup here.
module Inputs = {
  // Caller-supplied sampling options. Every field is optional; `None` values are
  // replaced with fixed defaults when `Internals.makeInputs` converts them for the evaluator.
  module SamplingInputs = {
    type t = {
      sampleCount: option<int>,
      outputXYPoints: option<int>,
      kernelWidth: option<float>,
      pointDistLength: option<int>,
    }
  }

  // NOTE(review): neither constant is referenced in this file; presumably used by
  // other modules — confirm before changing.
  let defaultRecommendedLength = 100
  let defaultShouldDownsample = true

  // Everything needed to evaluate a program: its source text, the sampling
  // options, and the variable environment it starts from.
  type inputs = {
    squiggleString: string,
    samplingInputs: SamplingInputs.t,
    environment: ASTTypes.environment,
  }

  // Sampling options with nothing specified.
  let empty: SamplingInputs.t = {
    sampleCount: None,
    outputXYPoints: None,
    kernelWidth: None,
    pointDistLength: None,
  }

  // Builds an `inputs` record. Sampling options default to `empty` and the
  // environment defaults to `ASTTypes.Environment.empty`.
  let make = (
    ~samplingInputs=empty,
    ~squiggleString,
    ~environment=ASTTypes.Environment.empty,
    (),
  ): inputs => {squiggleString, samplingInputs, environment}
}
2022-02-27 04:25:30 +00:00
// What a Squiggle program can hand back to callers:
// - a distribution (`#DistPlus`),
// - a plain number (`#Float`), or
// - a function from a number to a distribution, or to an error message (`#Function`).
type exportType = [
  | #DistPlus(DistPlus.t)
  | #Float(float)
  | #Function(float => Belt.Result.t<DistPlus.t, string>)
]
module Internals = {
  // Returns a copy of `inputs` whose environment additionally binds `str` to
  // `node`. The source string and sampling options are carried over unchanged.
  let addVariable = (inputs: Inputs.inputs, str, node): Inputs.inputs => {
    ...inputs,
    environment: ASTTypes.Environment.update(inputs.environment, str, _ => Some(node)),
  }

  type outputs = {
    graph: ASTTypes.node,
    pointSetDist: PointSetTypes.pointSetDist,
  }

  let makeOutputs = (graph, pointSetDist): outputs => {graph, pointSetDist}

  // Converts the caller's optional sampling options into the evaluator's settings.
  // Missing counts all fall back to 10000; `kernelWidth` stays optional.
  let makeInputs = (inputs: Inputs.inputs): SamplingInputs.samplingInputs => {
    let settings = inputs.samplingInputs
    let orDefault = count => count |> E.O.default(10000)
    {
      sampleCount: orDefault(settings.sampleCount),
      outputXYPoints: orDefault(settings.outputXYPoints),
      kernelWidth: settings.kernelWidth,
      pointSetDistLength: orDefault(settings.pointDistLength),
    }
  }

  // Evaluates a single node with the sampling settings and environment taken from `inputs`.
  let runNode = (inputs: Inputs.inputs, node) => {
    let samplingSettings = makeInputs(inputs)
    AST.toLeaf(samplingSettings, inputs.environment, node)
  }

  // Runs the program's statements in order.
  // - An assignment extends the environment seen by later statements.
  // - An expression is evaluated and paired with the environment in effect at that point.
  // Returns the first error encountered, or all expression results.
  let runProgram = (inputs: Inputs.inputs, p: ASTTypes.program) => {
    let (_, results) = p->Belt.Array.reduce((inputs, []), ((current, acc), statement) =>
      switch statement {
      | #Assignment(name, node) => (addVariable(current, name, node), acc)
      | #Expression(node) =>
        let result = runNode(current, node) |> E.R.fmap(r => (current.environment, r))
        (current, Belt.Array.concat(acc, [result]))
      }
    )
    results |> E.A.R.firstErrorOrOpen
  }

  // Parses `inputs.squiggleString` and runs the resulting program.
  let inputsToLeaf = (inputs: Inputs.inputs) =>
    inputs.squiggleString
    |> Parser.fromString
    |> E.R.bind(_, program => runProgram(inputs, program))

  // Wraps a point-set distribution in a `DistPlus`, tagged with the program's source text.
  let outputToDistPlus = (inputs: Inputs.inputs, pointSetDist: PointSetTypes.pointSetDist) => {
    let squiggleString = Some(inputs.squiggleString)
    DistPlus.make(~pointSetDist, ~squiggleString, ())
  }
}
2022-02-17 13:51:24 +00:00
// Runs `#Normalize` and `#SymbolicDist` nodes through `#Render` so the caller
// receives a `#RenderedDist`. Any other node is returned as-is.
// If rendering yields something other than `#RenderedDist`, that is reported as an error.
let renderIfNeeded = (inputs: Inputs.inputs, node: ASTTypes.node): result<
  ASTTypes.node,
  string,
> =>
  switch node {
  | #Normalize(_) as n
  | #SymbolicDist(_) as n =>
    switch Internals.runNode(inputs, #Render(n)) {
    | Ok(#RenderedDist(_)) as rendered => rendered
    | Ok(_) => Error("Didn't render, but intended to")
    | Error(message) => Error(message)
    }
  | untouched => Ok(untouched)
  }
2022-02-27 04:25:30 +00:00
// Turns a Squiggle function value into a host-level `float => result<DistPlus.t, string>`.
// `functionInfo` is an `(array<string>, node)` pair that is passed through to `AST.runFunction`.
// Each call does the following:
// - runs the function in environment `env`, with the float as its only argument
//   (wrapped as `#SymbolicDist(#Float(_))`);
// - normalizes the resulting distribution.
// Non-distribution results become an `Error` describing what was returned.
let rec returnDist = (
  functionInfo: (array<string>, ASTTypes.node),
  inputs: Inputs.inputs,
  env: ASTTypes.environment,
) => {
  (input: float) => {
    let callInputs: Inputs.inputs = {...inputs, environment: env}
    evaluateFunction(callInputs, functionInfo, [#SymbolicDist(#Float(input))]) |> E.R.bind(_, a =>
      switch a {
      | #DistPlus(d) => Ok(DistPlus.T.normalize(d))
      // Report the mismatch through the returned error rather than printing to
      // the console, so library callers can see why the call failed.
      | #Float(_) => Error("wrong type: function returned a number, expected a distribution")
      | #Function(_) =>
        Error("wrong type: function returned a function, expected a distribution")
      }
    )
  }
}
// TODO: Consider using ExpressionTypes.ExpressionTree.getFloat or similar in this function
// Converts an evaluated node into an `exportType`. Nodes that need it are
// rendered first (see `renderIfNeeded`). Then:
// - a rendered discrete distribution with a single point of mass 1.0 becomes `#Float`;
// - a symbolic float becomes `#Float`;
// - any other rendered distribution becomes `#DistPlus`;
// - a function becomes `#Function`, bound to environment `env` via `returnDist`.
// Anything else is an error.
and coersionToExportedTypes = (
  inputs,
  env: ASTTypes.environment,
  node: ASTTypes.node,
): result<exportType, string> =>
  node
  |> renderIfNeeded(inputs)
  |> E.R.bind(_, x =>
    switch x {
    | #RenderedDist(Discrete({xyShape: {xs: [x], ys: [1.0]}})) => Ok(#Float(x))
    | #SymbolicDist(#Float(x)) => Ok(#Float(x))
    | #RenderedDist(n) => Ok(#DistPlus(Internals.outputToDistPlus(inputs, n)))
    | #Function(n) => Ok(#Function(returnDist(n, inputs, env)))
    | n => Error("Didn't output a rendered distribution. Format:" ++ AST.toString(n))
    }
  )
// Applies the function `fn` to `fnInputs` via `AST.runFunction`, using the
// sampling settings and environment from `inputs`. The result is converted to an
// `exportType` in that same environment.
and evaluateFunction = (
  inputs: Inputs.inputs,
  fn: (array<string>, ASTTypes.node),
  fnInputs,
) => {
  let output = AST.runFunction(
    Internals.makeInputs(inputs),
    inputs.environment,
    fnInputs,
    fn,
  )
  output |> E.R.bind(_, coersionToExportedTypes(inputs, inputs.environment))
}
2022-01-18 00:10:06 +00:00
// Applies `f` to each element of `xs`, left to right.
// - If every call succeeds, returns `Ok` with the collected values.
// - On the first `Error`, returns that error immediately; later elements are not visited.
// Each element is visited once and values are appended to a single output array,
// so the work is linear in the input length. The helper only calls itself in tail
// position, so the compiler can turn it into a loop and long inputs do not deepen
// the call stack.
let mapM = (f, xs) => {
  let collected = []
  let length = Belt.Array.length(xs)
  let rec step = index =>
    if index >= length {
      Ok(collected)
    } else {
      switch f(Belt.Array.getUnsafe(xs, index)) {
      | Error(err) => Error(err)
      | Ok(value) =>
        collected->Js.Array2.push(value)->ignore
        step(index + 1)
      }
    }
  step(0)
}
2022-01-18 00:10:06 +00:00
// Parses and runs `inputs.squiggleString`, then converts each top-level
// expression result (paired with the environment it was evaluated in) into an
// `exportType`. Returns the first error encountered, if any.
let evaluateProgram = (inputs: Inputs.inputs) => {
  let toExport = ((environment, node)) => coersionToExportedTypes(inputs, environment, node)
  Internals.inputsToLeaf(inputs) |> E.R.bind(_, leaves => mapM(toExport, leaves))
}
2022-01-18 00:10:06 +00:00
2022-01-29 22:43:08 +00:00
@genType
// Entry point exported to TypeScript: evaluates `squiggleString` from an empty
// environment with fixed sampling settings:
// - 10000 samples,
// - 10000 output XY points,
// - no kernel width,
// - point-set length 1000.
let runAll = (squiggleString: string) => {
  let samplingInputs: Inputs.SamplingInputs.t = {
    sampleCount: Some(10000),
    outputXYPoints: Some(10000),
    kernelWidth: None,
    pointDistLength: Some(1000),
  }
  Inputs.make(~samplingInputs, ~squiggleString, ~environment=Belt.Map.String.empty, ())
  ->evaluateProgram
}