Refactor GenericOperation to allow for operations other than toDist operations

This commit is contained in:
Ozzie Gooen 2022-03-27 16:59:46 -04:00
parent c2ac9614d0
commit b70e8e02e1
3 changed files with 84 additions and 39 deletions

View File

@ -1,4 +1,4 @@
//TODO: multimodal, add interface, split up a little bit, test somehow, track performance, refactor sampleSet, refactor ASTEvaluator.res. //TODO: multimodal, add interface, test somehow, track performance, refactor sampleSet, refactor ASTEvaluator.res.
type genericDist = GenericDist_Types.genericDist type genericDist = GenericDist_Types.genericDist
type error = GenericDist_Types.error type error = GenericDist_Types.error
type toPointSetFn = genericDist => result<PointSetTypes.pointSetDist, error> type toPointSetFn = genericDist => result<PointSetTypes.pointSetDist, error>
@ -26,6 +26,8 @@ let normalize = (t: t) =>
| #SampleSet(_) => t | #SampleSet(_) => t
} }
// let isNormalized = (t:t) =>
let operationToFloat = (toPointSet: toPointSetFn, fnName, t: genericDist): result<float, error> => { let operationToFloat = (toPointSet: toPointSetFn, fnName, t: genericDist): result<float, error> => {
let symbolicSolution = switch t { let symbolicSolution = switch t {
@ -36,6 +38,7 @@ let operationToFloat = (toPointSet: toPointSetFn, fnName, t: genericDist): resul
} }
| _ => None | _ => None
} }
switch symbolicSolution { switch symbolicSolution {
| Some(r) => Ok(r) | Some(r) => Ok(r)
| None => toPointSet(t) |> E.R.fmap(PointSetDist.operate(fnName)) | None => toPointSet(t) |> E.R.fmap(PointSetDist.operate(fnName))

View File

@ -1,6 +1,8 @@
type operation = GenericDist_Types.Operation.t type operation = GenericDist_Types.Operation.genericFunction
type genericDist = GenericDist_Types.genericDist; type genericDist = GenericDist_Types.genericDist
type error = GenericDist_Types.error; type error = GenericDist_Types.error
// TODO: It could be great to use a cache for some calculations (basically, do memoization). Also, better analytics/tracking could go a long way.
type params = { type params = {
sampleCount: int, sampleCount: int,
@ -27,52 +29,70 @@ let fromResult = (r: result<outputType, error>): outputType =>
| Error(e) => #Error(e) | Error(e) => #Error(e)
} }
let rec run = (wrapped: wrapped, fnName: operation): outputType => { let rec run = (extra, fnName: operation): outputType => {
let (value, {sampleCount, xyPointLength} as extra) = wrapped let {sampleCount, xyPointLength} = extra
let reCall = (~value=value, ~extra=extra, ~fnName=fnName, ()) => { let reCall = (~extra=extra, ~fnName=fnName, ()) => {
run((value, extra), fnName) run(extra, fnName)
} }
let toPointSet = r => { let toPointSet = r => {
switch reCall(~value=r, ~fnName=#toDist(#toPointSet), ()) { switch reCall(~fnName=#fromDist(#toDist(#toPointSet), r), ()) {
| #Dist(#PointSet(p)) => Ok(p) | #Dist(#PointSet(p)) => Ok(p)
| #Error(r) => Error(r) | #Error(r) => Error(r)
| _ => Error(ImpossiblePath) | _ => Error(ImpossiblePath)
} }
} }
let toSampleSet = r => { let toSampleSet = r => {
switch reCall(~value=r, ~fnName=#toDist(#toSampleSet(sampleCount)), ()) { switch reCall(~fnName=#fromDist(#toDist(#toSampleSet(sampleCount)), r), ()) {
| #Dist(#SampleSet(p)) => Ok(p) | #Dist(#SampleSet(p)) => Ok(p)
| #Error(r) => Error(r) | #Error(r) => Error(r)
| _ => Error(ImpossiblePath) | _ => Error(ImpossiblePath)
} }
} }
let fromDistFn = (subFn: GenericDist_Types.Operation.fromDist, dist: genericDist) =>
switch subFn {
| #toFloat(fnName) =>
GenericDist.operationToFloat(toPointSet, fnName, dist)
|> E.R.fmap(r => #Float(r))
|> fromResult
| #toString => #Error(GenericDist_Types.NotYetImplemented)
| #toDist(#normalize) => dist |> GenericDist.normalize |> (r => #Dist(r))
| #toDist(#truncate(left, right)) =>
dist
|> GenericDist.Truncate.run(toPointSet, left, right)
|> E.R.fmap(r => #Dist(r))
|> fromResult
| #toDist(#toPointSet) =>
dist
|> GenericDist.toPointSet(xyPointLength)
|> E.R.fmap(r => #Dist(#PointSet(r)))
|> fromResult
| #toDist(#toSampleSet(n)) =>
dist |> GenericDist.sampleN(n) |> E.R.fmap(r => #Dist(#SampleSet(r))) |> fromResult
| #toDistCombination(#Algebraic, _, #Float(_)) => #Error(NotYetImplemented)
| #toDistCombination(#Algebraic, operation, #Dist(dist2)) =>
dist
|> GenericDist.AlgebraicCombination.run(toPointSet, toSampleSet, operation, dist2)
|> E.R.fmap(r => #Dist(r))
|> fromResult
| #toDistCombination(#Pointwise, operation, #Dist(dist2)) =>
dist
|> GenericDist.pointwiseCombination(toPointSet, operation, dist2)
|> E.R.fmap(r => #Dist(r))
|> fromResult
| #toDistCombination(#Pointwise, operation, #Float(f)) =>
dist
|> GenericDist.pointwiseCombinationFloat(toPointSet, operation, f)
|> E.R.fmap(r => #Dist(r))
|> fromResult
}
switch fnName { switch fnName {
| #toFloat(fnName) => | #fromDist(subFn, dist) => fromDistFn(subFn, dist)
GenericDist.operationToFloat(toPointSet, fnName, value) |> E.R.fmap(r => #Float(r)) |> fromResult | #fromFloat(subFn, float) => reCall(
| #toString => ~fnName=#fromDist(subFn, #Symbolic(SymbolicDist.Float.make(float))),
#Error(GenericDist_Types.NotYetImplemented) (),
| #toDist(#normalize) => value |> GenericDist.normalize |> (r => #Dist(r)) )
| #toDist(#truncate(left, right)) => | _ => #Error(NotYetImplemented)
value |> GenericDist.Truncate.run(toPointSet, left, right) |> E.R.fmap(r => #Dist(r)) |> fromResult
| #toDist(#toPointSet) =>
value |> GenericDist.toPointSet(xyPointLength) |> E.R.fmap(r => #Dist(#PointSet(r))) |> fromResult
| #toDist(#toSampleSet(n)) =>
value |> GenericDist.sampleN(n) |> E.R.fmap(r => #Dist(#SampleSet(r))) |> fromResult
| #toDistCombination(#Algebraic, _, #Float(_)) => #Error(NotYetImplemented)
| #toDistCombination(#Algebraic, operation, #Dist(value2)) =>
value
|> GenericDist.AlgebraicCombination.run(toPointSet, toSampleSet, operation, value2)
|> E.R.fmap(r => #Dist(r))
|> fromResult
| #toDistCombination(#Pointwise, operation, #Dist(value2)) =>
value
|> GenericDist.pointwiseCombination(toPointSet, operation, value2)
|> E.R.fmap(r => #Dist(r))
|> fromResult
| #toDistCombination(#Pointwise, operation, #Float(f)) =>
value
|> GenericDist.pointwiseCombinationFloat(toPointSet, operation, f)
|> E.R.fmap(r => #Dist(r))
|> fromResult
} }
} }

View File

@ -54,10 +54,32 @@ module Operation = {
| #Sample(int) | #Sample(int)
] ]
type t = [ type fromDist = [
| #toFloat(toFloat) | #toFloat(toFloat)
| #toDist(toDist) | #toDist(toDist)
| #toDistCombination(direction, arithmeticOperation, [#Dist(genericDist) | #Float(float)]) | #toDistCombination(direction, arithmeticOperation, [#Dist(genericDist) | #Float(float)])
| #toString | #toString
] ]
}
type genericFunction = [
| #fromDist(fromDist, genericDist)
| #fromFloat(fromDist, float)
| #mixture(array<(genericDist, float)>)
]
let toString = (distFunction: fromDist): string =>
switch distFunction {
| #toFloat(#Cdf(r)) => `cdf(${E.Float.toFixed(r)})`
| #toFloat(#Inv(r)) => `inv(${E.Float.toFixed(r)})`
| #toFloat(#Mean) => `mean`
| #toFloat(#Pdf(r)) => `pdf${E.Float.toFixed(r)}`
| #toFloat(#Sample) => `sample`
| #toDist(#normalize) => `normalize`
| #toDist(#toPointSet) => `toPointSet`
| #toDist(#toSampleSet(r)) => `toSampleSet${E.I.toString(r)}`
| #toDist(#truncate(_, _)) => `truncate`
| #toString => `toString`
| #toDistCombination(#Algebraic, _, _) => `algebraic`
| #toDistCombination(#Pointwise, _, _) => `pointwise`
}
}