squiggle/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res

164 lines
4.9 KiB
Plaintext
Raw Normal View History

2022-03-30 01:28:14 +00:00
// Local shorthands for types declared in GenericDist_Types.
// `operation` is the call descriptor that `run` dispatches on.
type operation = GenericDist_Types.Operation.genericFunctionCallInfo
// The distribution value carried by `#Dist` outputs.
type genericDist = GenericDist_Types.genericDist
// The error value carried by `#GenDistError` outputs.
type error = GenericDist_Types.error
// TODO: Consider caching (memoizing) some of these calculations. Better analytics/tracking would also help.
2022-03-27 18:22:26 +00:00
// Settings threaded through every call to `run`.
type params = {
// Number of samples requested when `run` converts a distribution to a sample set on its own behalf.
sampleCount: int,
// Value handed to `GenericDist.toPointSet` when a distribution is converted to a point set.
xyPointLength: int,
}
// Everything `run` can hand back to a caller:
//   #Dist         - a distribution
//   #Float        - a plain number
//   #String       - a string
//   #GenDistError - a failure, carried as a value instead of being raised
type outputType = [
  | #Dist(genericDist)
  | #Float(float)
  | #String(string)
  | #GenDistError(error)
]
2022-03-28 19:14:39 +00:00
// Accessors that extract one specific payload from an `outputType`.
// Each returns `Some(payload)` for its own constructor and `None` for all others.
module Output = {
  // The distribution, if the output is `#Dist`.
  let toDist = (o: outputType) =>
    switch o {
    | #Dist(dist) => Some(dist)
    | #Float(_) | #String(_) | #GenDistError(_) => None
    }

  // The number, if the output is `#Float`.
  let toFloat = (o: outputType) =>
    switch o {
    | #Float(value) => Some(value)
    | #Dist(_) | #String(_) | #GenDistError(_) => None
    }

  // The text, if the output is `#String`.
  let toString = (o: outputType) =>
    switch o {
    | #String(text) => Some(text)
    | #Dist(_) | #Float(_) | #GenDistError(_) => None
    }

  // The error, if the output is `#GenDistError`.
  let toError = (o: outputType) =>
    switch o {
    | #GenDistError(err) => Some(err)
    | #Dist(_) | #Float(_) | #String(_) => None
    }
}
2022-03-27 18:22:26 +00:00
// Flattens a result into an `outputType`: a successful output passes through
// unchanged, a failure becomes `#GenDistError`.
let fromResult = (r: result<outputType, error>): outputType =>
  switch r {
  | Error(err) => #GenDistError(err)
  | Ok(output) => output
  }
// Fallback used by catch-all switch arms elsewhere in this file: yields the
// error carried by a `#GenDistError`, and `Unreachable` for any other output.
let _errorMap = (o: outputType): error =>
  switch o {
  | #GenDistError(err) => err
  | #Dist(_) | #Float(_) | #String(_) => Unreachable
  }
2022-03-31 12:37:04 +00:00
// Narrows an output to a distribution result: `#Dist` becomes `Ok`, anything
// else becomes `Error` with the error chosen by `_errorMap`.
let outputToDistResult = (o: outputType): result<genericDist, error> =>
  switch o {
  | #Dist(dist) => Ok(dist)
  | other => Error(_errorMap(other))
  }
// Evaluates one generic operation (`fnName`) and returns its result as an
// `outputType`. Failures reported by the GenericDist helpers come back as
// `#GenDistError` values.
//
// `extra` supplies `sampleCount` (used when this function converts a
// distribution to a sample set on its own behalf) and `xyPointLength`
// (passed to `GenericDist.toPointSet`).
let rec run = (extra, fnName: operation): outputType => {
  let {sampleCount, xyPointLength} = extra

  // Re-enter `run` with the same settings and a different operation.
  let recurse = (op: operation): outputType => run(extra, op)

  // Lift a distribution-valued result into an `outputType`.
  let wrapDist = distResult => distResult->E.R2.fmap(d => #Dist(d))->fromResult

  // Conversion callbacks handed to the GenericDist helpers. Each re-enters
  // `run`, so the conversion goes through the same dispatch as a user call.
  let toPointSetFn = dist =>
    switch recurse(#fromDist(#toDist(#toPointSet), dist)) {
    | #Dist(#PointSet(pointSet)) => Ok(pointSet)
    | other => Error(_errorMap(other))
    }
  let toSampleSetFn = dist =>
    switch recurse(#fromDist(#toDist(#toSampleSet(sampleCount)), dist)) {
    | #Dist(#SampleSet(samples)) => Ok(samples)
    | other => Error(_errorMap(other))
    }

  // Building blocks passed to `GenericDist.mixture`.
  let scaleMultiply = (dist, weight) =>
    recurse(
      #fromDist(#toDistCombination(#Pointwise, #Multiply, #Float(weight)), dist),
    )->outputToDistResult
  let pointwiseAdd = (dist1, dist2) =>
    recurse(
      #fromDist(#toDistCombination(#Pointwise, #Add, #Dist(dist2)), dist1),
    )->outputToDistResult

  // Handles every operation whose input is a single distribution.
  let fromDistFn = (
    subFnName: GenericDist_Types.Operation.fromDist,
    dist: genericDist,
  ): outputType =>
    switch subFnName {
    | #toFloat(floatOperation) =>
      dist
      ->GenericDist.operationToFloat(~toPointSetFn, ~operation=floatOperation)
      ->E.R2.fmap(value => #Float(value))
      ->fromResult
    | #toString => #String(GenericDist.toString(dist))
    | #toDist(#inspect) => {
        // Debug aid: log the distribution and return it unchanged.
        Js.log2("Console log requested: ", dist)
        #Dist(dist)
      }
    | #toDist(#normalize) => #Dist(GenericDist.normalize(dist))
    | #toDist(#truncate(leftCutoff, rightCutoff)) =>
      GenericDist.truncate(~toPointSetFn, ~leftCutoff, ~rightCutoff, dist, ())->wrapDist
    | #toDist(#toPointSet) =>
      dist
      ->GenericDist.toPointSet(xyPointLength)
      ->E.R2.fmap(pointSet => #Dist(#PointSet(pointSet)))
      ->fromResult
    | #toDist(#toSampleSet(n)) =>
      dist
      ->GenericDist.sampleN(n)
      ->E.R2.fmap(samples => #Dist(#SampleSet(samples)))
      ->fromResult
    | #toDistCombination(direction, operation, operand) =>
      switch (direction, operand) {
      // Algebraic combination with a bare float is not supported yet.
      | (#Algebraic, #Float(_)) => #GenDistError(NotYetImplemented)
      | (#Algebraic, #Dist(t2)) =>
        dist
        ->GenericDist.algebraicCombination(~toPointSetFn, ~toSampleSetFn, ~operation, ~t2)
        ->wrapDist
      | (#Pointwise, #Dist(t2)) =>
        dist->GenericDist.pointwiseCombination(~toPointSetFn, ~operation, ~t2)->wrapDist
      | (#Pointwise, #Float(float)) =>
        dist->GenericDist.pointwiseCombinationFloat(~toPointSetFn, ~operation, ~float)->wrapDist
      }
    }

  switch fnName {
  | #fromDist(subFnName, dist) => fromDistFn(subFnName, dist)
  // A float input is first turned into a distribution, then handled as `#fromDist`.
  | #fromFloat(subFnName, value) =>
    recurse(#fromDist(subFnName, GenericDist.fromFloat(value)))
  | #mixture(dists) =>
    GenericDist.mixture(
      dists,
      ~scaleMultiplyFn=scaleMultiply,
      ~pointwiseAddFn=pointwiseAdd,
    )->wrapDist
  }
}
2022-03-28 19:14:39 +00:00
// Convenience wrapper: runs the distribution-input operation `fnName` on `dist`.
let runFromDist = (extra, fnName, dist) => {
  let call: operation = #fromDist(fnName, dist)
  run(extra, call)
}
// Convenience wrapper: runs the operation `fnName` on the number `float`.
let runFromFloat = (extra, fnName, float) => {
  let call: operation = #fromFloat(fnName, float)
  run(extra, call)
}
2022-03-31 12:41:50 +00:00
// Applies the single-parameter function `fn` to a previous output.
//
// - An incoming `#GenDistError` is passed through unchanged.
// - `#fromDist` needs a `#Dist` input and `#fromFloat` needs a `#Float` input;
//   any other pairing yields an `Other(...)` error describing what was expected.
let outputMap = (
  extra,
  input: outputType,
  fn: GenericDist_Types.Operation.singleParamaterFunction,
): outputType => {
  let nextCall: result<operation, error> = switch (input, fn) {
  | (#GenDistError(err), _) => Error(err)
  | (#Dist(dist), #fromDist(subFn)) => Ok(#fromDist(subFn, dist))
  | (#Float(value), #fromFloat(subFn)) => Ok(#fromFloat(subFn, value))
  | (_, #fromDist(_)) => Error(Other("Expected dist, got something else"))
  | (_, #fromFloat(_)) => Error(Other("Expected float, got something else"))
  }
  nextCall->E.R2.fmap(call => run(extra, call))->fromResult
}