squiggle/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res

173 lines
5.5 KiB
Plaintext
Raw Normal View History

// Local aliases for the shared GenericDist types used throughout this module.
type functionCallInfo = GenericDist_Types.Operation.genericFunctionCallInfo
type genericDist = GenericDist_Types.genericDist
type error = GenericDist_Types.error

// TODO: It would be great to cache some calculations (i.e. memoization). Better
// analytics/tracking could also go a long way.

// Settings passed to every call of `run`.
type env = {
  // Sample count used when converting to a sample set and when building point sets.
  sampleCount: int,
  // Point count forwarded to `GenericDist.toPointSet`.
  xyPointLength: int,
}

// What running an operation can produce: a distribution, a number, a string, or an error.
type outputType =
  | Dist(genericDist)
  | Float(float)
  | String(string)
  | GenDistError(error)
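/*
`Output` (defined later in this file) extends this module with `fmap`, which
depends on `run`. Because `run` itself needs these helpers, they are first
defined here in a local version, `OutputLocal`, which callers are not meant to
use directly; `Output` re-exports all of it via `include`.
*/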
module OutputLocal = {
  type t = outputType

  // Returns the error payload of a `GenDistError`, or `None` for any other output.
  let toError = (t: outputType) =>
    switch t {
    | GenDistError(err) => Some(err)
    | Dist(_) | Float(_) | String(_) => None
    }

  // Returns the error payload, or `Unreachable` when the output is not an error.
  let toErrorOrUnreachable = (t: t): error =>
    switch toError(t) {
    | Some(err) => err
    | None => (Unreachable: error)
    }

  // `Ok` with the distribution for a `Dist`; otherwise `Error` with the carried error
  // (or `Unreachable` for a non-error, non-dist output).
  let toDistR = (t: t): result<genericDist, error> =>
    switch t {
    | Dist(dist) => Ok(dist)
    | Float(_) | String(_) | GenDistError(_) => Error(toErrorOrUnreachable(t))
    }

  // Option-returning accessors for each successful output kind.
  let toDist = (t: t) =>
    switch t {
    | Dist(dist) => Some(dist)
    | Float(_) | String(_) | GenDistError(_) => None
    }

  let toFloat = (t: t) =>
    switch t {
    | Float(value) => Some(value)
    | Dist(_) | String(_) | GenDistError(_) => None
    }

  let toString = (t: t) =>
    switch t {
    | String(str) => Some(str)
    | Dist(_) | Float(_) | GenDistError(_) => None
    }

  // Collapses a `result` into an output: `Ok` passes through, `Error` becomes `GenDistError`.
  // Used to fold failures from other switch statements into `outputType`.
  let fromResult = (r: result<t, error>): outputType =>
    switch r {
    | Error(err) => GenDistError(err)
    | Ok(output) => output
    }
}
// Evaluates a generic-distribution operation using the settings in `env`.
// Failures are returned as `GenDistError` values.
// Sub-steps (point-set / sample-set conversion, scaling, pointwise addition) are
// themselves expressed as operations and dispatched by re-entering `run`.
let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
  let {sampleCount, xyPointLength} = env

  // Re-enters `run`; `env` and `functionCallInfo` default to those of the current call.
  let reCall = (~env=env, ~functionCallInfo=functionCallInfo, ()) => run(~env, functionCallInfo)

  // Lifts a `result` into `outputType`, wrapping a success with `wrap`.
  let toOutput = (res, wrap) => res->E.R2.fmap(wrap)->OutputLocal.fromResult

  // Converts `dist` to a point set by running `#toDist(#toPointSet)` on it.
  let toPointSetFn = dist =>
    switch reCall(~functionCallInfo=#fromDist(#toDist(#toPointSet), dist), ()) {
    | Dist(PointSet(pointSet)) => Ok(pointSet)
    | other => Error(OutputLocal.toErrorOrUnreachable(other))
    }

  // Converts `dist` to a sample set of `sampleCount` samples.
  let toSampleSetFn = dist =>
    switch reCall(~functionCallInfo=#fromDist(#toDist(#toSampleSet(sampleCount)), dist), ()) {
    | Dist(SampleSet(sampleSet)) => Ok(sampleSet)
    | other => Error(OutputLocal.toErrorOrUnreachable(other))
    }

  // Pointwise-multiplies `dist` by the scalar `weight`.
  let scaleMultiply = (dist, weight) =>
    reCall(
      ~functionCallInfo=#fromDist(#toDistCombination(#Pointwise, #Multiply, #Float(weight)), dist),
      (),
    )->OutputLocal.toDistR

  // Pointwise-adds `dist2` to `dist1`.
  let pointwiseAdd = (dist1, dist2) =>
    reCall(
      ~functionCallInfo=#fromDist(#toDistCombination(#Pointwise, #Add, #Dist(dist2)), dist1),
      (),
    )->OutputLocal.toDistR

  // Applies a single `fromDist` operation `op` to `dist`.
  let fromDistFn = (op: GenericDist_Types.Operation.fromDist, dist: genericDist) =>
    switch op {
    | #toFloat(floatOp) =>
      GenericDist.toFloatOperation(
        dist,
        ~toPointSetFn,
        ~distToFloatOperation=floatOp,
      )->toOutput(value => Float(value))
    | #toString => String(GenericDist.toString(dist))
    | #toDist(#inspect) => {
        Js.log2("Console log requested: ", dist)
        Dist(dist)
      }
    | #toDist(#normalize) => Dist(GenericDist.normalize(dist))
    | #toDist(#truncate(leftCutoff, rightCutoff)) =>
      GenericDist.truncate(~toPointSetFn, ~leftCutoff, ~rightCutoff, dist, ())->toOutput(d =>
        Dist(d)
      )
    | #toDist(#toPointSet) =>
      dist
      ->GenericDist.toPointSet(~xyPointLength, ~sampleCount)
      ->toOutput(pointSet => Dist(PointSet(pointSet)))
    | #toDist(#toSampleSet(n)) =>
      dist->GenericDist.sampleN(n)->toOutput(sampleSet => Dist(SampleSet(sampleSet)))
    | #toDistCombination(#Algebraic, _, #Float(_)) => GenDistError(NotYetImplemented)
    | #toDistCombination(#Algebraic, operation, #Dist(otherDist)) =>
      dist
      ->GenericDist.algebraicCombination(
        ~toPointSetFn,
        ~toSampleSetFn,
        ~arithmeticOperation=operation,
        ~t2=otherDist,
      )
      ->toOutput(d => Dist(d))
    | #toDistCombination(#Pointwise, operation, #Dist(otherDist)) =>
      dist
      ->GenericDist.pointwiseCombination(~toPointSetFn, ~arithmeticOperation=operation, ~t2=otherDist)
      ->toOutput(d => Dist(d))
    | #toDistCombination(#Pointwise, operation, #Float(value)) =>
      dist
      ->GenericDist.pointwiseCombinationFloat(
        ~toPointSetFn,
        ~arithmeticOperation=operation,
        ~float=value,
      )
      ->toOutput(d => Dist(d))
    }

  switch functionCallInfo {
  | #fromDist(op, dist) => fromDistFn(op, dist)
  | #fromFloat(op, value) =>
    // Lift the number to a distribution, then run the same operation on it.
    reCall(~functionCallInfo=#fromDist(op, GenericDist.fromFloat(value)), ())
  | #mixture(dists) =>
    dists
    ->GenericDist.mixture(~scaleMultiplyFn=scaleMultiply, ~pointwiseAddFn=pointwiseAdd)
    ->toOutput(d => Dist(d))
  }
}
// Shorthand for running a `fromDist` operation on a single distribution.
let runFromDist = (~env, ~functionCallInfo, dist) =>
  #fromDist(functionCallInfo, dist)->run(~env)

// Shorthand for running a `fromFloat` operation on a single number.
let runFromFloat = (~env, ~functionCallInfo, float) =>
  #fromFloat(functionCallInfo, float)->run(~env)
module Output = {
  include OutputLocal

  // Applies a single-parameter operation to a previous output.
  // A `GenDistError` input is passed through unchanged; an input of the wrong kind
  // for the operation yields an `Other` error; otherwise the operation is `run`.
  let fmap = (
    ~env,
    input: outputType,
    functionCallInfo: GenericDist_Types.Operation.singleParamaterFunction,
  ): outputType => {
    let newFnCall: result<functionCallInfo, error> = switch (functionCallInfo, input) {
    | (#fromDist(distOp), Dist(dist)) => Ok(#fromDist(distOp, dist))
    | (#fromFloat(floatOp), Float(value)) => Ok(#fromFloat(floatOp, value))
    | (_, GenDistError(err)) => Error(err)
    | (#fromDist(_), _) => Error(Other("Expected dist, got something else"))
    | (#fromFloat(_), _) => Error(Other("Expected float, got something else"))
    }
    switch newFnCall {
    | Ok(call) => run(~env, call)
    | Error(err) => GenDistError(err)
    }
  }
}