Merge pull request #20 from QURIresearch/Refactor-Feb-2022
Refactor feb 2022
Commit c0c6a45dc7
@@ -38,7 +38,7 @@ describe("Lodash", () =>
 let toArr = discrete |> E.FloatFloatMap.toArray
 makeTest("splitMedium", toArr |> Belt.Array.length, 10)

-let (c, discrete) = SampleSet.Internals.T.splitContinuousAndDiscrete(
+let (_c, discrete) = SampleSet.Internals.T.splitContinuousAndDiscrete(
 makeDuplicatedArray(500),
 )
 let toArr = discrete |> E.FloatFloatMap.toArray

@@ -4,9 +4,9 @@
 "homepage": "https://foretold-app.github.io/estiband/",
 "private": false,
 "scripts": {
-"build": "rescript build",
+"build": "rescript build -with-deps",
 "parcel": "parcel build ./src/js/index.js --no-source-maps --no-autoinstall",
-"start": "rescript build -w",
+"start": "rescript build -w -with-deps",
 "clean": "rescript clean",
 "test": "jest",
 "test:ci": "yarn jest ./__tests__/Lodash__test.re",
@@ -14,7 +14,7 @@ module Inputs = {
 type inputs = {
 squiggleString: string,
 samplingInputs: SamplingInputs.t,
-environment: ASTTypes.AST.environment,
+environment: ASTTypes.environment,
 }

 let empty: SamplingInputs.t = {

@@ -27,7 +27,7 @@ module Inputs = {
 let make = (
 ~samplingInputs=empty,
 ~squiggleString,
-~environment=ASTTypes.AST.Environment.empty,
+~environment=ASTTypes.Environment.empty,
 (),
 ): inputs => {
 samplingInputs: samplingInputs,

@@ -36,12 +36,12 @@ module Inputs = {
 }
 }

-type \"export" = [
+type exported = [
 | #DistPlus(DistPlus.t)
 | #Float(float)
 | #Function(
-(array<string>, ASTTypes.AST.node),
+(array<string>, ASTTypes.node),
-ASTTypes.AST.environment,
+ASTTypes.environment,
 )
 ]

@@ -53,18 +53,18 @@ module Internals = {
 ): Inputs.inputs => {
 samplingInputs: samplingInputs,
 squiggleString: squiggleString,
-environment: ASTTypes.AST.Environment.update(environment, str, _ => Some(
+environment: ASTTypes.Environment.update(environment, str, _ => Some(
 node,
 )),
 }

 type outputs = {
-graph: ASTTypes.AST.node,
+graph: ASTTypes.node,
 pointSetDist: PointSetTypes.pointSetDist,
 }
 let makeOutputs = (graph, pointSetDist): outputs => {graph: graph, pointSetDist: pointSetDist}

-let makeInputs = (inputs: Inputs.inputs): ASTTypes.AST.samplingInputs => {
+let makeInputs = (inputs: Inputs.inputs): SamplingInputs.samplingInputs => {
 sampleCount: inputs.samplingInputs.sampleCount |> E.O.default(10000),
 outputXYPoints: inputs.samplingInputs.outputXYPoints |> E.O.default(10000),
 kernelWidth: inputs.samplingInputs.kernelWidth,

@@ -74,7 +74,7 @@ module Internals = {
 let runNode = (inputs, node) =>
 AST.toLeaf(makeInputs(inputs), inputs.environment, node)

-let runProgram = (inputs: Inputs.inputs, p: ASTTypes.Program.program) => {
+let runProgram = (inputs: Inputs.inputs, p: ASTTypes.program) => {
 let ins = ref(inputs)
 p
 |> E.A.fmap(x =>

@@ -97,8 +97,8 @@ module Internals = {
 DistPlus.make(~pointSetDist, ~squiggleString=Some(inputs.squiggleString), ())
 }

-let renderIfNeeded = (inputs: Inputs.inputs, node: ASTTypes.AST.node): result<
+let renderIfNeeded = (inputs: Inputs.inputs, node: ASTTypes.node): result<
-ASTTypes.AST.node,
+ASTTypes.node,
 string,
 > =>
 node |> (
@@ -121,12 +121,12 @@ let renderIfNeeded = (inputs: Inputs.inputs, node: ASTTypes.AST.node): result<
 }
 )

-// TODO: Consider using ASTTypes.AST.getFloat or similar in this function
+// TODO: Consider using ASTTypes.getFloat or similar in this function
 let coersionToExportedTypes = (
 inputs,
-env: ASTTypes.AST.environment,
+env: ASTTypes.environment,
-node: ASTTypes.AST.node,
+node: ASTTypes.node,
-): result<\"export", string> =>
+): result<exported, string> =>
 node
 |> renderIfNeeded(inputs)
 |> E.R.bind(_, x =>

@@ -160,7 +160,7 @@ let evaluateProgram = (inputs: Inputs.inputs) =>

 let evaluateFunction = (
 inputs: Inputs.inputs,
-fn: (array<string>, ASTTypes.AST.node),
+fn: (array<string>, ASTTypes.node),
 fnInputs,
 ) => {
 let output = AST.runFunction(
@@ -1,6 +1,6 @@
-open ASTTypes.AST
+open ASTTypes

-let toString = ASTBasic.toString
+let toString = ASTTypes.Node.toString

 let envs = (samplingInputs, environment) => {
 samplingInputs: samplingInputs,

@@ -18,7 +18,7 @@ let toPointSetDist = (samplingInputs, environment, node: node) =>
 | Error(e) => Error(e)
 }

-let runFunction = (samplingInputs, environment, inputs, fn: PTypes.Function.t) => {
+let runFunction = (samplingInputs, environment, inputs, fn: ASTTypes.Function.t) => {
 let params = envs(samplingInputs, environment)
-PTypes.Function.run(params, inputs, fn)
+ASTTypes.Function.run(params, inputs, fn)
 }
@@ -1,27 +0,0 @@
-open ASTTypes.AST
-// This file exists to manage a dependency cycle. It would be good to refactor later.
-
-let rec toString: node => string = x =>
-switch x {
-| #SymbolicDist(d) => SymbolicDist.T.toString(d)
-| #RenderedDist(_) => "[renderedShape]"
-| #AlgebraicCombination(op, t1, t2) => Operation.Algebraic.format(op, toString(t1), toString(t2))
-| #PointwiseCombination(op, t1, t2) => Operation.Pointwise.format(op, toString(t1), toString(t2))
-| #Normalize(t) => "normalize(k" ++ (toString(t) ++ ")")
-| #Truncate(lc, rc, t) => Operation.T.truncateToString(lc, rc, toString(t))
-| #Render(t) => toString(t)
-| #Symbol(t) => "Symbol: " ++ t
-| #FunctionCall(name, args) =>
-"[Function call: (" ++
-(name ++
-((args |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ ")]"))
-| #Function(args, internal) =>
-"[Function: (" ++ ((args |> Js.String.concatMany(_, ",")) ++ (toString(internal) ++ ")]"))
-| #Array(a) => "[" ++ ((a |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ "]")
-| #Hash(h) =>
-"{" ++
-((h
-|> E.A.fmap(((name, value)) => name ++ (":" ++ toString(value)))
-|> Js.String.concatMany(_, ",")) ++
-"}")
-}
@@ -1,5 +1,4 @@
 open ASTTypes
-open ASTTypes.AST

 type t = node
 type tResult = node => result<node, string>

@@ -25,8 +24,8 @@ module AlgebraicCombination = {
 string,
 > =>
 E.R.merge(
-Render.ensureIsRenderedAndGetShape(evaluationParams, t1),
+Node.ensureIsRenderedAndGetShape(evaluationParams, t1),
-Render.ensureIsRenderedAndGetShape(evaluationParams, t2),
+Node.ensureIsRenderedAndGetShape(evaluationParams, t2),
 ) |> E.R.fmap(((a, b)) => #RenderedDist(PointSetDist.combineAlgebraically(algebraicOp, a, b)))

 let nodeScore: node => int = x =>

@@ -44,19 +43,24 @@ module AlgebraicCombination = {

 let combine = (evaluationParams, algebraicOp, t1: node, t2: node): result<node, string> =>
 E.R.merge(
-PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams, t1),
+ASTTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams, t1),
-PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams, t2),
+ASTTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams, t2),
 ) |> E.R.bind(_, ((a, b)) =>
 switch choose(a, b) {
 | #Sampling =>
-PTypes.SamplingDistribution.combineShapesUsingSampling(evaluationParams, algebraicOp, a, b)
+ASTTypes.SamplingDistribution.combineShapesUsingSampling(
+evaluationParams,
+algebraicOp,
+a,
+b,
+)
 | #Analytical => combinationByRendering(evaluationParams, algebraicOp, a, b)
 }
 )

 let operationToLeaf = (
 evaluationParams: evaluationParams,
-algebraicOp: ASTTypes.algebraicOperation,
+algebraicOp: Operation.algebraicOperation,
 t1: t,
 t2: t,
 ): result<node, string> =>
@@ -71,8 +75,10 @@ module AlgebraicCombination = {
 }

 module PointwiseCombination = {
+//TODO: This is crude and slow. It forces everything to be pointSetDist, even though much
+//of the process could happen on symbolic distributions without a conversion to be a pointSetDist.
 let pointwiseAdd = (evaluationParams: evaluationParams, t1: t, t2: t) =>
-switch (Render.render(evaluationParams, t1), Render.render(evaluationParams, t2)) {
+switch (Node.render(evaluationParams, t1), Node.render(evaluationParams, t2)) {
 | (Ok(#RenderedDist(rs1)), Ok(#RenderedDist(rs2))) =>
 Ok(
 #RenderedDist(

@@ -96,7 +102,7 @@ module PointwiseCombination = {
 switch // TODO: construct a function that we can easily sample from, to construct
 // a RenderedDist. Use the xMin and xMax of the rendered pointSetDists to tell the sampling function where to look.
 // TODO: This should work for symbolic distributions too!
-(Render.render(evaluationParams, t1), Render.render(evaluationParams, t2)) {
+(Node.render(evaluationParams, t1), Node.render(evaluationParams, t2)) {
 | (Ok(#RenderedDist(rs1)), Ok(#RenderedDist(rs2))) =>
 Ok(#RenderedDist(PointSetDist.combinePointwise(fn, rs1, rs2)))
 | (Error(e1), _) => Error(e1)

@@ -106,7 +112,7 @@ module PointwiseCombination = {

 let operationToLeaf = (
 evaluationParams: evaluationParams,
-pointwiseOp: pointwiseOperation,
+pointwiseOp: Operation.pointwiseOperation,
 t1: t,
 t2: t,
 ) =>

@@ -118,6 +124,12 @@ module PointwiseCombination = {
 }

 module Truncate = {
+type simplificationResult = [
+| #Solution(ASTTypes.node)
+| #Error(string)
+| #NoSolution
+]
+
 let trySimplification = (leftCutoff, rightCutoff, t): simplificationResult =>
 switch (leftCutoff, rightCutoff, t) {
 | (None, None, t) => #Solution(t)
@@ -131,8 +143,9 @@ module Truncate = {
 let truncateAsShape = (evaluationParams: evaluationParams, leftCutoff, rightCutoff, t) =>
 switch // TODO: use named args for xMin/xMax in renderToShape; if we're lucky we can at least get the tail
 // of a distribution we otherwise wouldn't get at all
-Render.ensureIsRendered(evaluationParams, t) {
+Node.ensureIsRendered(evaluationParams, t) {
-| Ok(#RenderedDist(rs)) => Ok(#RenderedDist(PointSetDist.T.truncate(leftCutoff, rightCutoff, rs)))
+| Ok(#RenderedDist(rs)) =>
+Ok(#RenderedDist(PointSetDist.T.truncate(leftCutoff, rightCutoff, rs)))
 | Error(e) => Error(e)
 | _ => Error("Could not truncate distribution.")
 }

@@ -160,7 +173,7 @@ module Normalize = {
 switch t {
 | #RenderedDist(s) => Ok(#RenderedDist(PointSetDist.T.normalize(s)))
 | #SymbolicDist(_) => Ok(t)
-| _ => evaluateAndRetry(evaluationParams, operationToLeaf, t)
+| _ => ASTTypes.Node.evaluateAndRetry(evaluationParams, operationToLeaf, t)
 }
 }

@@ -170,13 +183,13 @@ module FunctionCall = {

 let _runLocalFunction = (name, evaluationParams: evaluationParams, args) =>
 Environment.getFunction(evaluationParams.environment, name) |> E.R.bind(_, ((argNames, fn)) =>
-PTypes.Function.run(evaluationParams, args, (argNames, fn))
+ASTTypes.Function.run(evaluationParams, args, (argNames, fn))
 )

 let _runWithEvaluatedInputs = (
-evaluationParams: ASTTypes.AST.evaluationParams,
+evaluationParams: ASTTypes.evaluationParams,
 name,
-args: array<ASTTypes.AST.node>,
+args: array<ASTTypes.node>,
 ) =>
 _runHardcodedFunction(name, evaluationParams, args) |> E.O.default(
 _runLocalFunction(name, evaluationParams, args),
@@ -195,9 +208,13 @@ module Render = {
 switch t {
 | #Function(_) => Error("Cannot render a function")
 | #SymbolicDist(d) =>
-Ok(#RenderedDist(SymbolicDist.T.toPointSetDist(evaluationParams.samplingInputs.pointSetDistLength, d)))
+Ok(
+#RenderedDist(
+SymbolicDist.T.toPointSetDist(evaluationParams.samplingInputs.pointSetDistLength, d),
+),
+)
 | #RenderedDist(_) as t => Ok(t) // already a rendered pointSetDist, we're done here
-| _ => evaluateAndRetry(evaluationParams, operationToLeaf, t)
+| _ => ASTTypes.Node.evaluateAndRetry(evaluationParams, operationToLeaf, t)
 }
 }

@@ -207,10 +224,7 @@ module Render = {
 but most often it will produce a RenderedDist.
 This function is used mainly to turn a parse tree into a single RenderedDist
 that can then be displayed to the user. */
-let rec toLeaf = (
-evaluationParams: ASTTypes.AST.evaluationParams,
-node: t,
-): result<t, string> =>
+let rec toLeaf = (evaluationParams: ASTTypes.evaluationParams, node: t): result<t, string> =>
 switch node {
 // Leaf nodes just stay leaf nodes
 | #SymbolicDist(_)

@@ -236,7 +250,7 @@ let rec toLeaf = (
 |> E.A.R.firstErrorOrOpen
 |> E.R.fmap(r => #Hash(r))
 | #Symbol(r) =>
-ASTTypes.AST.Environment.get(evaluationParams.environment, r)
+ASTTypes.Environment.get(evaluationParams.environment, r)
 |> E.O.toResult("Undeclared variable " ++ r)
 |> E.R.bind(_, toLeaf(evaluationParams))
 | #FunctionCall(name, args) =>
@@ -1,90 +1,34 @@
-type algebraicOperation = [
-| #Add
-| #Multiply
-| #Subtract
-| #Divide
-| #Exponentiate
-]
-type pointwiseOperation = [#Add | #Multiply | #Exponentiate]
-type scaleOperation = [#Multiply | #Exponentiate | #Log]
-type distToFloatOperation = [
-| #Pdf(float)
-| #Cdf(float)
-| #Inv(float)
-| #Mean
-| #Sample
-]
-
-module AST = {
-type rec hash = array<(string, node)>
-and node = [
+type rec hash = array<(string, node)>
+and node = [
 | #SymbolicDist(SymbolicDistTypes.symbolicDist)
 | #RenderedDist(PointSetTypes.pointSetDist)
 | #Symbol(string)
 | #Hash(hash)
 | #Array(array<node>)
 | #Function(array<string>, node)
-| #AlgebraicCombination(algebraicOperation, node, node)
+| #AlgebraicCombination(Operation.algebraicOperation, node, node)
-| #PointwiseCombination(pointwiseOperation, node, node)
+| #PointwiseCombination(Operation.pointwiseOperation, node, node)
 | #Normalize(node)
 | #Render(node)
 | #Truncate(option<float>, option<float>, node)
 | #FunctionCall(string, array<node>)
 ]

-module Hash = {
-type t<'a> = array<(string, 'a)>
-let getByName = (t: t<'a>, name) =>
-E.A.getBy(t, ((n, _)) => n == name) |> E.O.fmap(((_, r)) => r)
-
-let getByNameResult = (t: t<'a>, name) =>
-getByName(t, name) |> E.O.toResult(name ++ " expected and not found")
-
-let getByNames = (hash: t<'a>, names: array<string>) =>
-names |> E.A.fmap(name => (name, getByName(hash, name)))
-}
-// Have nil as option
-let getFloat = (node: node) =>
-node |> (
-x =>
-switch x {
-| #RenderedDist(Discrete({xyShape: {xs: [x], ys: [1.0]}})) => Some(x)
-| #SymbolicDist(#Float(x)) => Some(x)
-| _ => None
-}
-)
-
-let toFloatIfNeeded = (node: node) =>
-switch node |> getFloat {
-| Some(float) => #SymbolicDist(#Float(float))
-| None => node
-}
-
-type samplingInputs = {
-sampleCount: int,
-outputXYPoints: int,
-kernelWidth: option<float>,
-pointSetDistLength: int,
-}
-
-module SamplingInputs = {
-type t = {
-sampleCount: option<int>,
-outputXYPoints: option<int>,
-kernelWidth: option<float>,
-pointSetDistLength: option<int>,
-}
-let withDefaults = (t: t): samplingInputs => {
-sampleCount: t.sampleCount |> E.O.default(10000),
-outputXYPoints: t.outputXYPoints |> E.O.default(10000),
-kernelWidth: t.kernelWidth,
-pointSetDistLength: t.pointSetDistLength |> E.O.default(10000),
-}
-}
-
-type environment = Belt.Map.String.t<node>
-
-module Environment = {
+type statement = [
+| #Assignment(string, node)
+| #Expression(node)
+]
+type program = array<statement>
+
+type environment = Belt.Map.String.t<node>
+
+type rec evaluationParams = {
+samplingInputs: SamplingInputs.samplingInputs,
+environment: environment,
+evaluateNode: (evaluationParams, node) => Belt.Result.t<node, string>,
+}
+
+module Environment = {
 type t = environment
 module MS = Belt.Map.String
 let fromArray = MS.fromArray
@@ -104,25 +48,53 @@ module AST = {
 | Some(#Function(argNames, fn)) => Ok((argNames, fn))
 | _ => Error("Function " ++ (str ++ " not found"))
 }
 }

-type rec evaluationParams = {
-samplingInputs: samplingInputs,
-environment: environment,
-evaluateNode: (evaluationParams, node) => Belt.Result.t<node, string>,
-}
+module Node = {
+let getFloat = (node: node) =>
+node |> (
+x =>
+switch x {
+| #RenderedDist(Discrete({xyShape: {xs: [x], ys: [1.0]}})) => Some(x)
+| #SymbolicDist(#Float(x)) => Some(x)
+| _ => None
+}
+)

-let evaluateNode = (evaluationParams: evaluationParams) =>
+let evaluate = (evaluationParams: evaluationParams) =>
 evaluationParams.evaluateNode(evaluationParams)

 let evaluateAndRetry = (evaluationParams, fn, node) =>
 node |> evaluationParams.evaluateNode(evaluationParams) |> E.R.bind(_, fn(evaluationParams))

-module Render = {
-type t = node
+let rec toString: node => string = x =>
+switch x {
+| #SymbolicDist(d) => SymbolicDist.T.toString(d)
+| #RenderedDist(_) => "[renderedShape]"
+| #AlgebraicCombination(op, t1, t2) =>
+Operation.Algebraic.format(op, toString(t1), toString(t2))
+| #PointwiseCombination(op, t1, t2) =>
+Operation.Pointwise.format(op, toString(t1), toString(t2))
+| #Normalize(t) => "normalize(k" ++ (toString(t) ++ ")")
+| #Truncate(lc, rc, t) => Operation.Truncate.toString(lc, rc, toString(t))
+| #Render(t) => toString(t)
+| #Symbol(t) => "Symbol: " ++ t
+| #FunctionCall(name, args) =>
+"[Function call: (" ++
+(name ++
+((args |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ ")]"))
+| #Function(args, internal) =>
+"[Function: (" ++ ((args |> Js.String.concatMany(_, ",")) ++ (toString(internal) ++ ")]"))
+| #Array(a) => "[" ++ ((a |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ "]")
+| #Hash(h) =>
+"{" ++
+((h
+|> E.A.fmap(((name, value)) => name ++ (":" ++ toString(value)))
+|> Js.String.concatMany(_, ",")) ++
+"}")
+}

-let render = (evaluationParams: evaluationParams, r) =>
-#Render(r) |> evaluateNode(evaluationParams)
+let render = (evaluationParams: evaluationParams, r) => #Render(r) |> evaluate(evaluationParams)

 let ensureIsRendered = (params, t) =>
 switch t {

@@ -142,7 +114,7 @@ module AST = {
 | Error(e) => Error(e)
 }

-let getShape = (item: node) =>
+let toPointSetDist = (item: node) =>
 switch item {
 | #RenderedDist(r) => Some(r)
 | _ => None
@@ -155,20 +127,106 @@ module AST = {
 }

 let toFloat = (item: node): result<node, string> =>
-item |> getShape |> E.O.bind(_, _toFloat) |> E.O.toResult("Not valid shape")
+item |> toPointSetDist |> E.O.bind(_, _toFloat) |> E.O.toResult("Not valid shape")
+}
+
+module Function = {
+type t = (array<string>, node)
+let fromNode: node => option<t> = node =>
+switch node {
+| #Function(r) => Some(r)
+| _ => None
+}
+let argumentNames = ((a, _): t) => a
+let internals = ((_, b): t) => b
+let run = (evaluationParams: evaluationParams, args: array<node>, t: t) =>
+if E.A.length(args) == E.A.length(argumentNames(t)) {
+let newEnvironment = Belt.Array.zip(argumentNames(t), args) |> Environment.fromArray
+let newEvaluationParams: evaluationParams = {
+samplingInputs: evaluationParams.samplingInputs,
+environment: Environment.mergeKeepSecond(evaluationParams.environment, newEnvironment),
+evaluateNode: evaluationParams.evaluateNode,
+}
+evaluationParams.evaluateNode(newEvaluationParams, internals(t))
+} else {
+Error("Wrong number of variables")
 }
 }

-type simplificationResult = [
-| #Solution(AST.node)
-| #Error(string)
-| #NoSolution
-]
-
-module Program = {
-type statement = [
-| #Assignment(string, AST.node)
-| #Expression(AST.node)
-]
-type program = array<statement>
+module SamplingDistribution = {
+type t = [
+| #SymbolicDist(SymbolicDistTypes.symbolicDist)
+| #RenderedDist(PointSetTypes.pointSetDist)
+]
+
+let isSamplingDistribution: node => bool = x =>
+switch x {
+| #SymbolicDist(_) => true
+| #RenderedDist(_) => true
+| _ => false
+}
+
+let fromNode: node => result<t, string> = x =>
+switch x {
+| #SymbolicDist(n) => Ok(#SymbolicDist(n))
+| #RenderedDist(n) => Ok(#RenderedDist(n))
+| _ => Error("Not valid type")
+}
+
+let renderIfIsNotSamplingDistribution = (params, t): result<node, string> =>
+!isSamplingDistribution(t)
+? switch Node.render(params, t) {
+| Ok(r) => Ok(r)
+| Error(e) => Error(e)
+}
+: Ok(t)
+
+let map = (~renderedDistFn, ~symbolicDistFn, node: node) =>
+node |> (
+x =>
+switch x {
+| #RenderedDist(r) => Some(renderedDistFn(r))
+| #SymbolicDist(s) => Some(symbolicDistFn(s))
+| _ => None
+}
+)
+
+let sampleN = n =>
+map(~renderedDistFn=PointSetDist.sampleNRendered(n), ~symbolicDistFn=SymbolicDist.T.sampleN(n))
+
+let getCombinationSamples = (n, algebraicOp, t1: node, t2: node) =>
+switch (sampleN(n, t1), sampleN(n, t2)) {
+| (Some(a), Some(b)) =>
+Some(
+Belt.Array.zip(a, b) |> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(algebraicOp, a, b)),
+)
+| _ => None
+}
+
+let combineShapesUsingSampling = (
+evaluationParams: evaluationParams,
+algebraicOp,
+t1: node,
+t2: node,
+) => {
+let i1 = renderIfIsNotSamplingDistribution(evaluationParams, t1)
+let i2 = renderIfIsNotSamplingDistribution(evaluationParams, t2)
+E.R.merge(i1, i2) |> E.R.bind(_, ((a, b)) => {
+let samples = getCombinationSamples(
+evaluationParams.samplingInputs.sampleCount,
+algebraicOp,
+a,
+b,
+)
+
+let pointSetDist =
+samples
+|> E.O.fmap(r =>
+SampleSet.toPointSetDist(~samplingInputs=evaluationParams.samplingInputs, ~samples=r, ())
+)
+|> E.O.bind(_, r => r.pointSetDist)
+|> E.O.toResult("No response")
+pointSetDist |> E.R.fmap(r => #Normalize(#RenderedDist(r)))
+})
+}
 }
@@ -1,138 +0,0 @@
-open ASTTypes.AST
-
-module Function = {
-type t = (array<string>, node)
-let fromNode: node => option<t> = node =>
-switch node {
-| #Function(r) => Some(r)
-| _ => None
-}
-let argumentNames = ((a, _): t) => a
-let internals = ((_, b): t) => b
-let run = (
-evaluationParams: ASTTypes.AST.evaluationParams,
-args: array<node>,
-t: t,
-) =>
-if E.A.length(args) == E.A.length(argumentNames(t)) {
-let newEnvironment =
-Belt.Array.zip(
-argumentNames(t),
-args,
-) |> ASTTypes.AST.Environment.fromArray
-let newEvaluationParams: ASTTypes.AST.evaluationParams = {
-samplingInputs: evaluationParams.samplingInputs,
-environment: ASTTypes.AST.Environment.mergeKeepSecond(
-evaluationParams.environment,
-newEnvironment,
-),
-evaluateNode: evaluationParams.evaluateNode,
-}
-evaluationParams.evaluateNode(newEvaluationParams, internals(t))
-} else {
-Error("Wrong number of variables")
-}
-}
-
-module Primative = {
-type t = [
-| #SymbolicDist(SymbolicDistTypes.symbolicDist)
-| #RenderedDist(PointSetTypes.pointSetDist)
-| #Function(array<string>, node)
-]
-
-let isPrimative: node => bool = x =>
-switch x {
-| #SymbolicDist(_)
-| #RenderedDist(_)
-| #Function(_) => true
-| _ => false
-}
-
-let fromNode: node => option<t> = x =>
-switch x {
-| #SymbolicDist(_) as n
-| #RenderedDist(_) as n
-| #Function(_) as n =>
-Some(n)
-| _ => None
-}
-}
-
-module SamplingDistribution = {
-type t = [
-| #SymbolicDist(SymbolicDistTypes.symbolicDist)
-| #RenderedDist(PointSetTypes.pointSetDist)
-]
-
-let isSamplingDistribution: node => bool = x =>
-switch x {
-| #SymbolicDist(_) => true
-| #RenderedDist(_) => true
-| _ => false
-}
-
-let fromNode: node => result<t, string> = x =>
-switch x {
-| #SymbolicDist(n) => Ok(#SymbolicDist(n))
-| #RenderedDist(n) => Ok(#RenderedDist(n))
-| _ => Error("Not valid type")
-}
-
-let renderIfIsNotSamplingDistribution = (params, t): result<node, string> =>
-!isSamplingDistribution(t)
-? switch Render.render(params, t) {
-| Ok(r) => Ok(r)
-| Error(e) => Error(e)
-}
-: Ok(t)
-
-let map = (~renderedDistFn, ~symbolicDistFn, node: node) =>
-node |> (
-x =>
-switch x {
-| #RenderedDist(r) => Some(renderedDistFn(r))
-| #SymbolicDist(s) => Some(symbolicDistFn(s))
-| _ => None
-}
-)
-
-let sampleN = n =>
-map(~renderedDistFn=PointSetDist.sampleNRendered(n), ~symbolicDistFn=SymbolicDist.T.sampleN(n))
-
-let getCombinationSamples = (n, algebraicOp, t1: node, t2: node) =>
-switch (sampleN(n, t1), sampleN(n, t2)) {
-| (Some(a), Some(b)) =>
-Some(
-Belt.Array.zip(a, b) |> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(algebraicOp, a, b)),
-)
-| _ => None
-}
-
-let combineShapesUsingSampling = (
-evaluationParams: evaluationParams,
-algebraicOp,
-t1: node,
-t2: node,
-) => {
-let i1 = renderIfIsNotSamplingDistribution(evaluationParams, t1)
-let i2 = renderIfIsNotSamplingDistribution(evaluationParams, t2)
-E.R.merge(i1, i2) |> E.R.bind(_, ((a, b)) => {
-let samples = getCombinationSamples(
-evaluationParams.samplingInputs.sampleCount,
-algebraicOp,
-a,
-b,
-)
-
-// todo: This bottom part should probably be somewhere else.
-// todo: REFACTOR: I'm not sure about the SampleSet line.
-let pointSetDist =
-samples
-|> E.O.fmap(r => SampleSet.toPointSetDist(~samplingInputs=evaluationParams.samplingInputs, ~samples=r, ()))
-|> E.O.bind(_, r => r.pointSetDist)
-|> E.O.toResult("No response")
-pointSetDist |> E.R.fmap(r => #Normalize(#RenderedDist(r)))
-})
-}
-}
@@ -84,7 +84,7 @@ let makeDist = (name, fn) =>
 )

 let floatFromDist = (
-distToFloatOp: ASTTypes.distToFloatOperation,
+distToFloatOp: Operation.distToFloatOperation,
 t: TypeSystem.samplingDist,
 ): result<node, string> =>
 switch t {

@@ -111,7 +111,7 @@ let verticalScaling = (scaleOp, rs, scaleBy) => {
 }

 module Multimodal = {
-let getByNameResult = ASTTypes.AST.Hash.getByNameResult
+let getByNameResult = Hash.getByNameResult

 let _paramsToDistsAndWeights = (r: array<typedValue>) =>
 switch r {
@@ -1,5 +1,5 @@
-type node = ASTTypes.AST.node
+type node = ASTTypes.node
-let getFloat = ASTTypes.AST.getFloat
+let getFloat = ASTTypes.Node.getFloat

 type samplingDist = [
 | #SymbolicDist(SymbolicDistTypes.symbolicDist)

@@ -61,7 +61,7 @@ module TypedValue = {
 |> E.A.fmap(((name, t)) => fromNode(t) |> E.R.fmap(r => (name, r)))
 |> E.A.R.firstErrorOrOpen
 |> E.R.fmap(r => #Hash(r))
-| e => Error("Wrong type: " ++ ASTBasic.toString(e))
+| e => Error("Wrong type: " ++ ASTTypes.Node.toString(e))
 }

 // todo: Arrays and hashes

@@ -73,12 +73,12 @@ module TypedValue = {
 | _ => Error("Type Error: Expected float.")
 }
 | (#SamplingDistribution, _) =>
-PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
+ASTTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
 evaluationParams,
 node,
 ) |> E.R.bind(_, fromNode)
 | (#RenderedDistribution, _) =>
-ASTTypes.AST.Render.render(evaluationParams, node) |> E.R.bind(_, fromNode)
+ASTTypes.Node.render(evaluationParams, node) |> E.R.bind(_, fromNode)
 | (#Array(_type), #Array(b)) =>
 b
 |> E.A.fmap(fromNodeWithTypeCoercion(evaluationParams, _type))
@@ -89,7 +89,7 @@ module TypedValue = {
 named |> E.A.fmap(((name, intendedType)) => (
 name,
 intendedType,
-ASTTypes.AST.Hash.getByName(r, name),
+Hash.getByName(r, name),
 ))
 let typedHash =
 keyValues

@@ -172,7 +172,7 @@ module Function = {
 |> E.A.R.firstErrorOrOpen

 let inputsToTypedValues = (
-evaluationParams: ASTTypes.AST.evaluationParams,
+evaluationParams: ASTTypes.evaluationParams,
 inputNodes: inputNodes,
 t: t,
 ) =>

@@ -181,7 +181,7 @@ module Function = {
 )

 let run = (
-evaluationParams: ASTTypes.AST.evaluationParams,
+evaluationParams: ASTTypes.evaluationParams,
 inputNodes: inputNodes,
 t: t,
 ) =>
@@ -122,7 +122,7 @@ module MathAdtToDistDst = {
 | _ => Error("Lognormal distribution needs either mean and stdev or mu and sigma")
 }
 | _ =>
-parseArgs() |> E.R.fmap((args: array<ASTTypes.AST.node>) =>
+parseArgs() |> E.R.fmap((args: array<ASTTypes.node>) =>
 #FunctionCall("lognormal", args)
 )
 }

@@ -130,8 +130,8 @@ module MathAdtToDistDst = {
 // Error("Dotwise exponentiation needs two operands")
 let operationParser = (
 name: string,
-args: result<array<ASTTypes.AST.node>, string>,
+args: result<array<ASTTypes.node>, string>,
-): result<ASTTypes.AST.node, string> => {
+): result<ASTTypes.node, string> => {
 let toOkAlgebraic = r => Ok(#AlgebraicCombination(r))
 let toOkPointwise = r => Ok(#PointwiseCombination(r))
 let toOkTruncate = r => Ok(#Truncate(r))

@@ -170,12 +170,12 @@ module MathAdtToDistDst = {

 let functionParser = (
 nodeParser: MathJsonToMathJsAdt.arg => Belt.Result.t<
-ASTTypes.AST.node,
+ASTTypes.node,
 string,
 >,
 name: string,
 args: array<MathJsonToMathJsAdt.arg>,
-): result<ASTTypes.AST.node, string> => {
+): result<ASTTypes.node, string> => {
 let parseArray = ags => ags |> E.A.fmap(nodeParser) |> E.A.R.firstErrorOrOpen
 let parseArgs = () => parseArray(args)
 switch name {
@@ -212,27 +212,27 @@ module MathAdtToDistDst = {
 | (Some(Error(r)), _) => Error(r)
 | (_, Error(r)) => Error(r)
 | (None, Ok(dists)) =>
-let hash: ASTTypes.AST.node = #FunctionCall(
+let hash: ASTTypes.node = #FunctionCall(
 "multimodal",
 [#Hash([("dists", #Array(dists)), ("weights", #Array([]))])],
 )
 Ok(hash)
 | (Some(Ok(weights)), Ok(dists)) =>
-let hash: ASTTypes.AST.node = #FunctionCall(
+let hash: ASTTypes.node = #FunctionCall(
 "multimodal",
 [#Hash([("dists", #Array(dists)), ("weights", #Array(weights))])],
 )
 Ok(hash)
 }
 | name =>
-parseArgs() |> E.R.fmap((args: array<ASTTypes.AST.node>) =>
+parseArgs() |> E.R.fmap((args: array<ASTTypes.node>) =>
 #FunctionCall(name, args)
 )
 }
 }

 let rec nodeParser: MathJsonToMathJsAdt.arg => result<
-ASTTypes.AST.node,
+ASTTypes.node,
 string,
 > = x =>
 switch x {

@@ -246,7 +246,7 @@ module MathAdtToDistDst = {
 // let evaluatedExpression = run(expression);
 // `Function(_ => Ok(evaluatedExpression));
 // }
-let rec topLevel = (r): result<ASTTypes.Program.program, string> =>
+let rec topLevel = (r): result<ASTTypes.program, string> =>
 switch r {
 | FunctionAssignment({name, args, expression}) =>
 switch nodeParser(expression) {

@@ -267,7 +267,7 @@ module MathAdtToDistDst = {
 blocks |> E.A.fmap(b => topLevel(b)) |> E.A.R.firstErrorOrOpen |> E.R.fmap(E.A.concatMany)
 }

-let run = (r): result<ASTTypes.Program.program, string> =>
+let run = (r): result<ASTTypes.program, string> =>
 r |> MathAdtCleaner.run |> topLevel
 }

@@ -96,12 +96,10 @@ let toDiscretePointMassesFromTriangulars = (
 }

 let combineShapesContinuousContinuous = (
-op: ASTTypes.algebraicOperation,
+op: Operation.algebraicOperation,
 s1: PointSetTypes.xyShape,
 s2: PointSetTypes.xyShape,
 ): PointSetTypes.xyShape => {
-let t1n = s1 |> XYShape.T.length
-let t2n = s2 |> XYShape.T.length

 // if we add the two distributions, we should probably use normal filters.
 // if we multiply the two distributions, we should probably use lognormal filters.

@@ -194,13 +192,13 @@ let toDiscretePointMassesFromDiscrete = (s: PointSetTypes.xyShape): pointMassesW

 let masses: array<float> = Belt.Array.makeBy(n, i => ys[i])
 let means: array<float> = Belt.Array.makeBy(n, i => xs[i])
-let variances: array<float> = Belt.Array.makeBy(n, i => 0.0)
+let variances: array<float> = Belt.Array.makeBy(n, _ => 0.0)

 {n: n, masses: masses, means: means, variances: variances}
 }

 let combineShapesContinuousDiscrete = (
-op: ASTTypes.algebraicOperation,
+op: Operation.algebraicOperation,
 continuousShape: PointSetTypes.xyShape,
 discreteShape: PointSetTypes.xyShape,
 ): PointSetTypes.xyShape => {

@@ -211,7 +211,7 @@ module T = Dist({
 /* This simply creates multiple copies of the continuous distribution, scaled and shifted according to
 each discrete data point, and then adds them all together. */
 let combineAlgebraicallyWithDiscrete = (
-op: ASTTypes.algebraicOperation,
+op: Operation.algebraicOperation,
 t1: t,
 t2: PointSetTypes.discreteShape,
 ) => {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let combineAlgebraically = (op: ASTTypes.algebraicOperation, t1: t, t2: t) => {
|
let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t) => {
|
||||||
let s1 = t1 |> getShape
|
let s1 = t1 |> getShape
|
||||||
let s2 = t2 |> getShape
|
let s2 = t2 |> getShape
|
||||||
let t1n = s1 |> XYShape.T.length
|
let t1n = s1 |> XYShape.T.length
|
||||||
|
|
|
@ -85,7 +85,7 @@ let updateIntegralCache = (integralCache, t: t): t => {
|
||||||
|
|
||||||
/* This multiples all of the data points together and creates a new discrete distribution from the results.
|
/* This multiples all of the data points together and creates a new discrete distribution from the results.
|
||||||
Data points at the same xs get added together. It may be a good idea to downsample t1 and t2 before and/or the result after. */
|
Data points at the same xs get added together. It may be a good idea to downsample t1 and t2 before and/or the result after. */
|
||||||
let combineAlgebraically = (op: ASTTypes.algebraicOperation, t1: t, t2: t): t => {
|
let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t): t => {
|
||||||
let t1s = t1 |> getShape
|
let t1s = t1 |> getShape
|
||||||
let t2s = t2 |> getShape
|
let t2s = t2 |> getShape
|
||||||
let t1n = t1s |> XYShape.T.length
|
let t1n = t1s |> XYShape.T.length
|
||||||
|
|
|
@ -227,7 +227,7 @@ module T = Dist({
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
let combineAlgebraically = (op: ASTTypes.algebraicOperation, t1: t, t2: t): t => {
|
let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t): t => {
|
||||||
// Discrete convolution can cause a huge increase in the number of samples,
|
// Discrete convolution can cause a huge increase in the number of samples,
|
||||||
// so we'll first downsample.
|
// so we'll first downsample.
|
||||||
|
|
||||||
|
@ -240,9 +240,6 @@ let combineAlgebraically = (op: ASTTypes.algebraicOperation, t1: t, t2: t): t =>
|
||||||
// sqtl > 10 ? T.downsample(int_of_float(sqtl), t) : t;
|
// sqtl > 10 ? T.downsample(int_of_float(sqtl), t) : t;
|
||||||
//};
|
//};
|
||||||
|
|
||||||
let t1d = t1
|
|
||||||
let t2d = t2
|
|
||||||
|
|
||||||
// continuous (*) continuous => continuous, but also
|
// continuous (*) continuous => continuous, but also
|
||||||
// discrete (*) continuous => continuous (and vice versa). We have to take care of all combos and then combine them:
|
// discrete (*) continuous => continuous (and vice versa). We have to take care of all combos and then combine them:
|
||||||
let ccConvResult = Continuous.combineAlgebraically(op, t1.continuous, t2.continuous)
|
let ccConvResult = Continuous.combineAlgebraically(op, t1.continuous, t2.continuous)
|
||||||
|
|
|
@@ -1,6 +1,7 @@
 open Distributions
+
 type t = PointSetTypes.pointSetDist

 let mapToAll = ((fn1, fn2, fn3), t: t) =>
 switch t {
 | Mixed(m) => fn1(m)

@@ -33,7 +34,7 @@ let toMixed = mapToAll((
 ),
 ))

-let combineAlgebraically = (op: ASTTypes.algebraicOperation, t1: t, t2: t): t =>
+let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t): t =>
 switch (t1, t2) {
 | (Continuous(m1), Continuous(m2)) =>
 Continuous.combineAlgebraically(op, m1, m2) |> Continuous.T.toPointSetDist

@@ -77,9 +78,6 @@ module T = Dist({

 let toPointSetDist = (t: t) => t

-let toContinuous = t => None
-let toDiscrete = t => None
-
 let downsample = (i, t) =>
 fmap((Mixed.T.downsample(i), Discrete.T.downsample(i), Continuous.T.downsample(i)), t)

@@ -93,8 +91,6 @@ module T = Dist({
 t,
 )

-let toDiscreteProbabilityMassFraction = t => 0.0
-
 let normalize = fmap((Mixed.T.normalize, Discrete.T.normalize, Continuous.T.normalize))

 let updateIntegralCache = (integralCache, t: t): t =>
|
||||||
doN(n, () => sample(distWithUpdatedIntegralCache))
|
doN(n, () => sample(distWithUpdatedIntegralCache))
|
||||||
}
|
}
|
||||||
|
|
||||||
let operate = (distToFloatOp: ASTTypes.distToFloatOperation, s): float =>
|
let operate = (distToFloatOp: Operation.distToFloatOperation, s): float =>
|
||||||
switch distToFloatOp {
|
switch distToFloatOp {
|
||||||
| #Pdf(f) => pdf(f, s)
|
| #Pdf(f) => pdf(f, s)
|
||||||
| #Cdf(f) => pdf(f, s)
|
| #Cdf(f) => pdf(f, s)
|
||||||
|
|
|
@ -159,7 +159,7 @@ module XtoY = {
|
||||||
y1 *. (1. -. fraction) +. y2 *. fraction
|
y1 *. (1. -. fraction) +. y2 *. fraction
|
||||||
}
|
}
|
||||||
| (#Stepwise, #UseZero) =>
|
| (#Stepwise, #UseZero) =>
|
||||||
(t: T.t, leftIndex: int, x: float) =>
|
(t: T.t, leftIndex: int, _x: float) =>
|
||||||
if leftIndex < 0 {
|
if leftIndex < 0 {
|
||||||
0.0
|
0.0
|
||||||
} else if leftIndex >= T.length(t) - 1 {
|
} else if leftIndex >= T.length(t) - 1 {
|
||||||
|
@ -168,7 +168,7 @@ module XtoY = {
|
||||||
t.ys[leftIndex]
|
t.ys[leftIndex]
|
||||||
}
|
}
|
||||||
| (#Stepwise, #UseOutermostPoints) =>
|
| (#Stepwise, #UseOutermostPoints) =>
|
||||||
(t: T.t, leftIndex: int, x: float) =>
|
(t: T.t, leftIndex: int, _x: float) =>
|
||||||
if leftIndex < 0 {
|
if leftIndex < 0 {
|
||||||
t.ys[0]
|
t.ys[0]
|
||||||
} else if leftIndex >= T.length(t) - 1 {
|
} else if leftIndex >= T.length(t) - 1 {
|
||||||
|
|
|
@ -80,7 +80,7 @@ module Internals = {
|
||||||
|
|
||||||
let toPointSetDist = (
|
let toPointSetDist = (
|
||||||
~samples: Internals.T.t,
|
~samples: Internals.T.t,
|
||||||
~samplingInputs: ASTTypes.AST.samplingInputs,
|
~samplingInputs: SamplingInputs.samplingInputs,
|
||||||
(),
|
(),
|
||||||
) => {
|
) => {
|
||||||
Array.fast_sort(compare, samples)
|
Array.fast_sort(compare, samples)
|
||||||
|
|
|
@@ -272,7 +272,7 @@ module T = {
 | #Float(n) => Float.mean(n)
 }

-let operate = (distToFloatOp: ASTTypes.distToFloatOperation, s) =>
+let operate = (distToFloatOp: Operation.distToFloatOperation, s) =>
 switch distToFloatOp {
 | #Cdf(f) => Ok(cdf(f, s))
 | #Pdf(f) => Ok(pdf(f, s))

@@ -302,7 +302,7 @@ module T = {
 let tryAnalyticalSimplification = (
 d1: symbolicDist,
 d2: symbolicDist,
-op: ASTTypes.algebraicOperation,
+op: Operation.algebraicOperation,
 ): analyticalSimplificationResult =>
 switch (d1, d2) {
 | (#Float(v1), #Float(v2)) =>
packages/squiggle-lang/src/rescript/utility/Hash.res (new file, 8 lines)
@@ -0,0 +1,8 @@
+type t<'a> = array<(string, 'a)>
+let getByName = (t: t<'a>, name) => E.A.getBy(t, ((n, _)) => n == name) |> E.O.fmap(((_, r)) => r)
+
+let getByNameResult = (t: t<'a>, name) =>
+getByName(t, name) |> E.O.toResult(name ++ " expected and not found")
+
+let getByNames = (hash: t<'a>, names: array<string>) =>
+names |> E.A.fmap(name => (name, getByName(hash, name)))
@@ -1,4 +1,21 @@
-open ASTTypes
+// This file has no dependencies. It's used outside of the interpreter, but the interpreter depends on it.

+type algebraicOperation = [
+| #Add
+| #Multiply
+| #Subtract
+| #Divide
+| #Exponentiate
+]
+type pointwiseOperation = [#Add | #Multiply | #Exponentiate]
+type scaleOperation = [#Multiply | #Exponentiate | #Log]
+type distToFloatOperation = [
+| #Pdf(float)
+| #Cdf(float)
+| #Inv(float)
+| #Mean
+| #Sample
+]
+
 module Algebraic = {
 type t = algebraicOperation

@@ -80,28 +97,16 @@ module Scale = {

 let toIntegralCacheFn = x =>
 switch x {
-| #Multiply => (a, b) => None // TODO: this could probably just be multiplied out (using Continuous.scaleBy)
+| #Multiply => (_, _) => None // TODO: this could probably just be multiplied out (using Continuous.scaleBy)
 | #Exponentiate => (_, _) => None
 | #Log => (_, _) => None
 }
 }

-module T = {
+module Truncate = {
-let truncateToString = (left: option<float>, right: option<float>, nodeToString) => {
+let toString = (left: option<float>, right: option<float>, nodeToString) => {
 let left = left |> E.O.dimap(Js.Float.toString, () => "-inf")
 let right = right |> E.O.dimap(Js.Float.toString, () => "inf")
 j`truncate($nodeToString, $left, $right)`
 }
-let toString = (nodeToString, x) =>
-switch x {
-| #AlgebraicCombination(op, t1, t2) => Algebraic.format(op, nodeToString(t1), nodeToString(t2))
-| #PointwiseCombination(op, t1, t2) => Pointwise.format(op, nodeToString(t1), nodeToString(t2))
-| #VerticalScaling(scaleOp, t, scaleBy) =>
-Scale.format(scaleOp, nodeToString(t), nodeToString(scaleBy))
-| #Normalize(t) => "normalize(k" ++ (nodeToString(t) ++ ")")
-| #FloatFromDist(floatFromDistOp, t) => DistToFloat.format(floatFromDistOp, nodeToString(t))
-| #Truncate(lc, rc, t) => truncateToString(lc, rc, nodeToString(t))
-| #Render(t) => nodeToString(t)
-| _ => ""
-} // SymbolicDist and RenderedDist are handled in AST.toString.
 }
@@ -0,0 +1,21 @@
+type samplingInputs = {
+sampleCount: int,
+outputXYPoints: int,
+kernelWidth: option<float>,
+pointSetDistLength: int,
+}
+
+module SamplingInputs = {
+type t = {
+sampleCount: option<int>,
+outputXYPoints: option<int>,
+kernelWidth: option<float>,
+pointSetDistLength: option<int>,
+}
+let withDefaults = (t: t): samplingInputs => {
+sampleCount: t.sampleCount |> E.O.default(10000),
+outputXYPoints: t.outputXYPoints |> E.O.default(10000),
+kernelWidth: t.kernelWidth,
+pointSetDistLength: t.pointSetDistLength |> E.O.default(10000),
+}
+}