Refactored PTypes away

This commit is contained in:
Ozzie Gooen 2022-02-16 17:37:59 -05:00
parent d8b37bb113
commit f76eaf6d03
11 changed files with 219 additions and 221 deletions

View File

@ -64,7 +64,7 @@ module Internals = {
} }
let makeOutputs = (graph, pointSetDist): outputs => {graph: graph, pointSetDist: pointSetDist} let makeOutputs = (graph, pointSetDist): outputs => {graph: graph, pointSetDist: pointSetDist}
let makeInputs = (inputs: Inputs.inputs): ASTTypes.AST.samplingInputs => { let makeInputs = (inputs: Inputs.inputs): SamplingInputs.samplingInputs => {
sampleCount: inputs.samplingInputs.sampleCount |> E.O.default(10000), sampleCount: inputs.samplingInputs.sampleCount |> E.O.default(10000),
outputXYPoints: inputs.samplingInputs.outputXYPoints |> E.O.default(10000), outputXYPoints: inputs.samplingInputs.outputXYPoints |> E.O.default(10000),
kernelWidth: inputs.samplingInputs.kernelWidth, kernelWidth: inputs.samplingInputs.kernelWidth,
@ -74,7 +74,7 @@ module Internals = {
let runNode = (inputs, node) => let runNode = (inputs, node) =>
AST.toLeaf(makeInputs(inputs), inputs.environment, node) AST.toLeaf(makeInputs(inputs), inputs.environment, node)
let runProgram = (inputs: Inputs.inputs, p: ASTTypes.Program.program) => { let runProgram = (inputs: Inputs.inputs, p: ASTTypes.AST.program) => {
let ins = ref(inputs) let ins = ref(inputs)
p p
|> E.A.fmap(x => |> E.A.fmap(x =>

View File

@ -18,7 +18,7 @@ let toPointSetDist = (samplingInputs, environment, node: node) =>
| Error(e) => Error(e) | Error(e) => Error(e)
} }
let runFunction = (samplingInputs, environment, inputs, fn: PTypes.Function.t) => { let runFunction = (samplingInputs, environment, inputs, fn: ASTTypes.AST.Function.t) => {
let params = envs(samplingInputs, environment) let params = envs(samplingInputs, environment)
PTypes.Function.run(params, inputs, fn) ASTTypes.AST.Function.run(params, inputs, fn)
} }

View File

@ -1,4 +1,3 @@
open ASTTypes
open ASTTypes.AST open ASTTypes.AST
type t = node type t = node
@ -44,12 +43,17 @@ module AlgebraicCombination = {
let combine = (evaluationParams, algebraicOp, t1: node, t2: node): result<node, string> => let combine = (evaluationParams, algebraicOp, t1: node, t2: node): result<node, string> =>
E.R.merge( E.R.merge(
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams, t1), ASTTypes.AST.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams, t1),
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams, t2), ASTTypes.AST.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams, t2),
) |> E.R.bind(_, ((a, b)) => ) |> E.R.bind(_, ((a, b)) =>
switch choose(a, b) { switch choose(a, b) {
| #Sampling => | #Sampling =>
PTypes.SamplingDistribution.combineShapesUsingSampling(evaluationParams, algebraicOp, a, b) ASTTypes.AST.SamplingDistribution.combineShapesUsingSampling(
evaluationParams,
algebraicOp,
a,
b,
)
| #Analytical => combinationByRendering(evaluationParams, algebraicOp, a, b) | #Analytical => combinationByRendering(evaluationParams, algebraicOp, a, b)
} }
) )
@ -118,6 +122,12 @@ module PointwiseCombination = {
} }
module Truncate = { module Truncate = {
type simplificationResult = [
| #Solution(ASTTypes.AST.node)
| #Error(string)
| #NoSolution
]
let trySimplification = (leftCutoff, rightCutoff, t): simplificationResult => let trySimplification = (leftCutoff, rightCutoff, t): simplificationResult =>
switch (leftCutoff, rightCutoff, t) { switch (leftCutoff, rightCutoff, t) {
| (None, None, t) => #Solution(t) | (None, None, t) => #Solution(t)
@ -132,7 +142,8 @@ module Truncate = {
switch // TODO: use named args for xMin/xMax in renderToShape; if we're lucky we can at least get the tail switch // TODO: use named args for xMin/xMax in renderToShape; if we're lucky we can at least get the tail
// of a distribution we otherwise wouldn't get at all // of a distribution we otherwise wouldn't get at all
Node.ensureIsRendered(evaluationParams, t) { Node.ensureIsRendered(evaluationParams, t) {
| Ok(#RenderedDist(rs)) => Ok(#RenderedDist(PointSetDist.T.truncate(leftCutoff, rightCutoff, rs))) | Ok(#RenderedDist(rs)) =>
Ok(#RenderedDist(PointSetDist.T.truncate(leftCutoff, rightCutoff, rs)))
| Error(e) => Error(e) | Error(e) => Error(e)
| _ => Error("Could not truncate distribution.") | _ => Error("Could not truncate distribution.")
} }
@ -160,7 +171,7 @@ module Normalize = {
switch t { switch t {
| #RenderedDist(s) => Ok(#RenderedDist(PointSetDist.T.normalize(s))) | #RenderedDist(s) => Ok(#RenderedDist(PointSetDist.T.normalize(s)))
| #SymbolicDist(_) => Ok(t) | #SymbolicDist(_) => Ok(t)
| _ => evaluateAndRetry(evaluationParams, operationToLeaf, t) | _ => ASTTypes.AST.Node.evaluateAndRetry(evaluationParams, operationToLeaf, t)
} }
} }
@ -170,7 +181,7 @@ module FunctionCall = {
let _runLocalFunction = (name, evaluationParams: evaluationParams, args) => let _runLocalFunction = (name, evaluationParams: evaluationParams, args) =>
Environment.getFunction(evaluationParams.environment, name) |> E.R.bind(_, ((argNames, fn)) => Environment.getFunction(evaluationParams.environment, name) |> E.R.bind(_, ((argNames, fn)) =>
PTypes.Function.run(evaluationParams, args, (argNames, fn)) ASTTypes.AST.Function.run(evaluationParams, args, (argNames, fn))
) )
let _runWithEvaluatedInputs = ( let _runWithEvaluatedInputs = (
@ -195,9 +206,13 @@ module Render = {
switch t { switch t {
| #Function(_) => Error("Cannot render a function") | #Function(_) => Error("Cannot render a function")
| #SymbolicDist(d) => | #SymbolicDist(d) =>
Ok(#RenderedDist(SymbolicDist.T.toPointSetDist(evaluationParams.samplingInputs.pointSetDistLength, d))) Ok(
#RenderedDist(
SymbolicDist.T.toPointSetDist(evaluationParams.samplingInputs.pointSetDistLength, d),
),
)
| #RenderedDist(_) as t => Ok(t) // already a rendered pointSetDist, we're done here | #RenderedDist(_) as t => Ok(t) // already a rendered pointSetDist, we're done here
| _ => evaluateAndRetry(evaluationParams, operationToLeaf, t) | _ => ASTTypes.AST.Node.evaluateAndRetry(evaluationParams, operationToLeaf, t)
} }
} }
@ -207,10 +222,7 @@ module Render = {
but most often it will produce a RenderedDist. but most often it will produce a RenderedDist.
This function is used mainly to turn a parse tree into a single RenderedDist This function is used mainly to turn a parse tree into a single RenderedDist
that can then be displayed to the user. */ that can then be displayed to the user. */
let rec toLeaf = ( let rec toLeaf = (evaluationParams: ASTTypes.AST.evaluationParams, node: t): result<t, string> =>
evaluationParams: ASTTypes.AST.evaluationParams,
node: t,
): result<t, string> =>
switch node { switch node {
// Leaf nodes just stay leaf nodes // Leaf nodes just stay leaf nodes
| #SymbolicDist(_) | #SymbolicDist(_)

View File

@ -15,43 +15,20 @@ module AST = {
| #FunctionCall(string, array<node>) | #FunctionCall(string, array<node>)
] ]
module Hash = { type statement = [
type t<'a> = array<(string, 'a)> | #Assignment(string, node)
let getByName = (t: t<'a>, name) => | #Expression(node)
E.A.getBy(t, ((n, _)) => n == name) |> E.O.fmap(((_, r)) => r) ]
type program = array<statement>
let getByNameResult = (t: t<'a>, name) =>
getByName(t, name) |> E.O.toResult(name ++ " expected and not found")
let getByNames = (hash: t<'a>, names: array<string>) =>
names |> E.A.fmap(name => (name, getByName(hash, name)))
}
// Have nil as option
type samplingInputs = {
sampleCount: int,
outputXYPoints: int,
kernelWidth: option<float>,
pointSetDistLength: int,
}
module SamplingInputs = {
type t = {
sampleCount: option<int>,
outputXYPoints: option<int>,
kernelWidth: option<float>,
pointSetDistLength: option<int>,
}
let withDefaults = (t: t): samplingInputs => {
sampleCount: t.sampleCount |> E.O.default(10000),
outputXYPoints: t.outputXYPoints |> E.O.default(10000),
kernelWidth: t.kernelWidth,
pointSetDistLength: t.pointSetDistLength |> E.O.default(10000),
}
}
type environment = Belt.Map.String.t<node> type environment = Belt.Map.String.t<node>
type rec evaluationParams = {
samplingInputs: SamplingInputs.samplingInputs,
environment: environment,
evaluateNode: (evaluationParams, node) => Belt.Result.t<node, string>,
}
module Environment = { module Environment = {
type t = environment type t = environment
module MS = Belt.Map.String module MS = Belt.Map.String
@ -74,18 +51,6 @@ module AST = {
} }
} }
type rec evaluationParams = {
samplingInputs: samplingInputs,
environment: environment,
evaluateNode: (evaluationParams, node) => Belt.Result.t<node, string>,
}
let evaluateNode = (evaluationParams: evaluationParams) =>
evaluationParams.evaluateNode(evaluationParams)
let evaluateAndRetry = (evaluationParams, fn, node) =>
node |> evaluationParams.evaluateNode(evaluationParams) |> E.R.bind(_, fn(evaluationParams))
module Node = { module Node = {
let getFloat = (node: node) => let getFloat = (node: node) =>
node |> ( node |> (
@ -97,6 +62,12 @@ module AST = {
} }
) )
let evaluate = (evaluationParams: evaluationParams) =>
evaluationParams.evaluateNode(evaluationParams)
let evaluateAndRetry = (evaluationParams, fn, node) =>
node |> evaluationParams.evaluateNode(evaluationParams) |> E.R.bind(_, fn(evaluationParams))
let rec toString: node => string = x => let rec toString: node => string = x =>
switch x { switch x {
| #SymbolicDist(d) => SymbolicDist.T.toString(d) | #SymbolicDist(d) => SymbolicDist.T.toString(d)
@ -124,8 +95,7 @@ module AST = {
"}") "}")
} }
let render = (evaluationParams: evaluationParams, r) => let render = (evaluationParams: evaluationParams, r) => #Render(r) |> evaluate(evaluationParams)
#Render(r) |> evaluateNode(evaluationParams)
let ensureIsRendered = (params, t) => let ensureIsRendered = (params, t) =>
switch t { switch t {
@ -160,18 +130,143 @@ module AST = {
let toFloat = (item: node): result<node, string> => let toFloat = (item: node): result<node, string> =>
item |> toPointSetDist |> E.O.bind(_, _toFloat) |> E.O.toResult("Not valid shape") item |> toPointSetDist |> E.O.bind(_, _toFloat) |> E.O.toResult("Not valid shape")
} }
}
type simplificationResult = [ module Function = {
| #Solution(AST.node) type t = (array<string>, node)
| #Error(string) let fromNode: node => option<t> = node =>
| #NoSolution switch node {
] | #Function(r) => Some(r)
| _ => None
}
let argumentNames = ((a, _): t) => a
let internals = ((_, b): t) => b
let run = (evaluationParams: evaluationParams, args: array<node>, t: t) =>
if E.A.length(args) == E.A.length(argumentNames(t)) {
let newEnvironment = Belt.Array.zip(argumentNames(t), args) |> Environment.fromArray
let newEvaluationParams: evaluationParams = {
samplingInputs: evaluationParams.samplingInputs,
environment: Environment.mergeKeepSecond(
evaluationParams.environment,
newEnvironment,
),
evaluateNode: evaluationParams.evaluateNode,
}
evaluationParams.evaluateNode(newEvaluationParams, internals(t))
} else {
Error("Wrong number of variables")
}
}
module Program = { module Primative = {
type statement = [ type t = [
| #Assignment(string, AST.node) | #SymbolicDist(SymbolicDistTypes.symbolicDist)
| #Expression(AST.node) | #RenderedDist(PointSetTypes.pointSetDist)
] | #Function(array<string>, node)
type program = array<statement> ]
let isPrimative: node => bool = x =>
switch x {
| #SymbolicDist(_)
| #RenderedDist(_)
| #Function(_) => true
| _ => false
}
let fromNode: node => option<t> = x =>
switch x {
| #SymbolicDist(_) as n
| #RenderedDist(_) as n
| #Function(_) as n =>
Some(n)
| _ => None
}
}
module SamplingDistribution = {
type t = [
| #SymbolicDist(SymbolicDistTypes.symbolicDist)
| #RenderedDist(PointSetTypes.pointSetDist)
]
let isSamplingDistribution: node => bool = x =>
switch x {
| #SymbolicDist(_) => true
| #RenderedDist(_) => true
| _ => false
}
let fromNode: node => result<t, string> = x =>
switch x {
| #SymbolicDist(n) => Ok(#SymbolicDist(n))
| #RenderedDist(n) => Ok(#RenderedDist(n))
| _ => Error("Not valid type")
}
let renderIfIsNotSamplingDistribution = (params, t): result<node, string> =>
!isSamplingDistribution(t)
? switch Node.render(params, t) {
| Ok(r) => Ok(r)
| Error(e) => Error(e)
}
: Ok(t)
let map = (~renderedDistFn, ~symbolicDistFn, node: node) =>
node |> (
x =>
switch x {
| #RenderedDist(r) => Some(renderedDistFn(r))
| #SymbolicDist(s) => Some(symbolicDistFn(s))
| _ => None
}
)
let sampleN = n =>
map(
~renderedDistFn=PointSetDist.sampleNRendered(n),
~symbolicDistFn=SymbolicDist.T.sampleN(n),
)
let getCombinationSamples = (n, algebraicOp, t1: node, t2: node) =>
switch (sampleN(n, t1), sampleN(n, t2)) {
| (Some(a), Some(b)) =>
Some(
Belt.Array.zip(a, b) |> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(algebraicOp, a, b)),
)
| _ => None
}
let combineShapesUsingSampling = (
evaluationParams: evaluationParams,
algebraicOp,
t1: node,
t2: node,
) => {
let i1 = renderIfIsNotSamplingDistribution(evaluationParams, t1)
let i2 = renderIfIsNotSamplingDistribution(evaluationParams, t2)
E.R.merge(i1, i2) |> E.R.bind(_, ((a, b)) => {
let samples = getCombinationSamples(
evaluationParams.samplingInputs.sampleCount,
algebraicOp,
a,
b,
)
let pointSetDist =
samples
|> E.O.fmap(r =>
SampleSet.toPointSetDist(
~samplingInputs=evaluationParams.samplingInputs,
~samples=r,
(),
)
)
|> E.O.bind(_, r => r.pointSetDist)
|> E.O.toResult("No response")
pointSetDist |> E.R.fmap(r => #Normalize(#RenderedDist(r)))
})
// todo: This bottom part should probably be somewhere else.
// todo: REFACTOR: I'm not sure about the SampleSet line.
}
}
} }

View File

@ -1,138 +0,0 @@
open ASTTypes.AST
module Function = {
type t = (array<string>, node)
let fromNode: node => option<t> = node =>
switch node {
| #Function(r) => Some(r)
| _ => None
}
let argumentNames = ((a, _): t) => a
let internals = ((_, b): t) => b
let run = (
evaluationParams: ASTTypes.AST.evaluationParams,
args: array<node>,
t: t,
) =>
if E.A.length(args) == E.A.length(argumentNames(t)) {
let newEnvironment =
Belt.Array.zip(
argumentNames(t),
args,
) |> ASTTypes.AST.Environment.fromArray
let newEvaluationParams: ASTTypes.AST.evaluationParams = {
samplingInputs: evaluationParams.samplingInputs,
environment: ASTTypes.AST.Environment.mergeKeepSecond(
evaluationParams.environment,
newEnvironment,
),
evaluateNode: evaluationParams.evaluateNode,
}
evaluationParams.evaluateNode(newEvaluationParams, internals(t))
} else {
Error("Wrong number of variables")
}
}
module Primative = {
type t = [
| #SymbolicDist(SymbolicDistTypes.symbolicDist)
| #RenderedDist(PointSetTypes.pointSetDist)
| #Function(array<string>, node)
]
let isPrimative: node => bool = x =>
switch x {
| #SymbolicDist(_)
| #RenderedDist(_)
| #Function(_) => true
| _ => false
}
let fromNode: node => option<t> = x =>
switch x {
| #SymbolicDist(_) as n
| #RenderedDist(_) as n
| #Function(_) as n =>
Some(n)
| _ => None
}
}
module SamplingDistribution = {
type t = [
| #SymbolicDist(SymbolicDistTypes.symbolicDist)
| #RenderedDist(PointSetTypes.pointSetDist)
]
let isSamplingDistribution: node => bool = x =>
switch x {
| #SymbolicDist(_) => true
| #RenderedDist(_) => true
| _ => false
}
let fromNode: node => result<t, string> = x =>
switch x {
| #SymbolicDist(n) => Ok(#SymbolicDist(n))
| #RenderedDist(n) => Ok(#RenderedDist(n))
| _ => Error("Not valid type")
}
let renderIfIsNotSamplingDistribution = (params, t): result<node, string> =>
!isSamplingDistribution(t)
? switch Node.render(params, t) {
| Ok(r) => Ok(r)
| Error(e) => Error(e)
}
: Ok(t)
let map = (~renderedDistFn, ~symbolicDistFn, node: node) =>
node |> (
x =>
switch x {
| #RenderedDist(r) => Some(renderedDistFn(r))
| #SymbolicDist(s) => Some(symbolicDistFn(s))
| _ => None
}
)
let sampleN = n =>
map(~renderedDistFn=PointSetDist.sampleNRendered(n), ~symbolicDistFn=SymbolicDist.T.sampleN(n))
let getCombinationSamples = (n, algebraicOp, t1: node, t2: node) =>
switch (sampleN(n, t1), sampleN(n, t2)) {
| (Some(a), Some(b)) =>
Some(
Belt.Array.zip(a, b) |> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(algebraicOp, a, b)),
)
| _ => None
}
let combineShapesUsingSampling = (
evaluationParams: evaluationParams,
algebraicOp,
t1: node,
t2: node,
) => {
let i1 = renderIfIsNotSamplingDistribution(evaluationParams, t1)
let i2 = renderIfIsNotSamplingDistribution(evaluationParams, t2)
E.R.merge(i1, i2) |> E.R.bind(_, ((a, b)) => {
let samples = getCombinationSamples(
evaluationParams.samplingInputs.sampleCount,
algebraicOp,
a,
b,
)
// todo: This bottom part should probably be somewhere else.
// todo: REFACTOR: I'm not sure about the SampleSet line.
let pointSetDist =
samples
|> E.O.fmap(r => SampleSet.toPointSetDist(~samplingInputs=evaluationParams.samplingInputs, ~samples=r, ()))
|> E.O.bind(_, r => r.pointSetDist)
|> E.O.toResult("No response")
pointSetDist |> E.R.fmap(r => #Normalize(#RenderedDist(r)))
})
}
}

View File

@ -111,7 +111,7 @@ let verticalScaling = (scaleOp, rs, scaleBy) => {
} }
module Multimodal = { module Multimodal = {
let getByNameResult = ASTTypes.AST.Hash.getByNameResult let getByNameResult = Hash.getByNameResult
let _paramsToDistsAndWeights = (r: array<typedValue>) => let _paramsToDistsAndWeights = (r: array<typedValue>) =>
switch r { switch r {

View File

@ -73,7 +73,7 @@ module TypedValue = {
| _ => Error("Type Error: Expected float.") | _ => Error("Type Error: Expected float.")
} }
| (#SamplingDistribution, _) => | (#SamplingDistribution, _) =>
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution( ASTTypes.AST.SamplingDistribution.renderIfIsNotSamplingDistribution(
evaluationParams, evaluationParams,
node, node,
) |> E.R.bind(_, fromNode) ) |> E.R.bind(_, fromNode)
@ -89,7 +89,7 @@ module TypedValue = {
named |> E.A.fmap(((name, intendedType)) => ( named |> E.A.fmap(((name, intendedType)) => (
name, name,
intendedType, intendedType,
ASTTypes.AST.Hash.getByName(r, name), Hash.getByName(r, name),
)) ))
let typedHash = let typedHash =
keyValues keyValues

View File

@ -246,7 +246,7 @@ module MathAdtToDistDst = {
// let evaluatedExpression = run(expression); // let evaluatedExpression = run(expression);
// `Function(_ => Ok(evaluatedExpression)); // `Function(_ => Ok(evaluatedExpression));
// } // }
let rec topLevel = (r): result<ASTTypes.Program.program, string> => let rec topLevel = (r): result<ASTTypes.AST.program, string> =>
switch r { switch r {
| FunctionAssignment({name, args, expression}) => | FunctionAssignment({name, args, expression}) =>
switch nodeParser(expression) { switch nodeParser(expression) {
@ -267,7 +267,7 @@ module MathAdtToDistDst = {
blocks |> E.A.fmap(b => topLevel(b)) |> E.A.R.firstErrorOrOpen |> E.R.fmap(E.A.concatMany) blocks |> E.A.fmap(b => topLevel(b)) |> E.A.R.firstErrorOrOpen |> E.R.fmap(E.A.concatMany)
} }
let run = (r): result<ASTTypes.Program.program, string> => let run = (r): result<ASTTypes.AST.program, string> =>
r |> MathAdtCleaner.run |> topLevel r |> MathAdtCleaner.run |> topLevel
} }

View File

@ -80,7 +80,7 @@ module Internals = {
let toPointSetDist = ( let toPointSetDist = (
~samples: Internals.T.t, ~samples: Internals.T.t,
~samplingInputs: ASTTypes.AST.samplingInputs, ~samplingInputs: SamplingInputs.samplingInputs,
(), (),
) => { ) => {
Array.fast_sort(compare, samples) Array.fast_sort(compare, samples)

View File

@ -0,0 +1,8 @@
type t<'a> = array<(string, 'a)>
let getByName = (t: t<'a>, name) => E.A.getBy(t, ((n, _)) => n == name) |> E.O.fmap(((_, r)) => r)
let getByNameResult = (t: t<'a>, name) =>
getByName(t, name) |> E.O.toResult(name ++ " expected and not found")
let getByNames = (hash: t<'a>, names: array<string>) =>
names |> E.A.fmap(name => (name, getByName(hash, name)))

View File

@ -0,0 +1,21 @@
type samplingInputs = {
sampleCount: int,
outputXYPoints: int,
kernelWidth: option<float>,
pointSetDistLength: int,
}
module SamplingInputs = {
type t = {
sampleCount: option<int>,
outputXYPoints: option<int>,
kernelWidth: option<float>,
pointSetDistLength: option<int>,
}
let withDefaults = (t: t): samplingInputs => {
sampleCount: t.sampleCount |> E.O.default(10000),
outputXYPoints: t.outputXYPoints |> E.O.default(10000),
kernelWidth: t.kernelWidth,
pointSetDistLength: t.pointSetDistLength |> E.O.default(10000),
}
}