Minor refactors

This commit is contained in:
Ozzie Gooen 2022-03-25 22:11:27 -04:00 committed by GitHub
parent e93b500d4a
commit 1a2ce5bfa0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,19 +1,28 @@
type symboliDist = SymbolicDistTypes.symbolicDist
type error =
| NeedsPointSetConversion
| InputsNeedPointSetConversion
| NotYetImplemented
| ImpossiblePath
| Other(string)
type genericDist = [
| #XYShape(PointSetTypes.pointSetDist)
| #PointSet(PointSetTypes.pointSetDist)
| #SampleSet(array<float>)
| #Symbolic(symboliDist)
| #Symbolic(SymbolicDistTypes.symbolicDist)
]
type outputType = [
| #Dist(genericDist)
| #Error(error)
| #Float(float)
]
let fromResult = (r: result<outputType, error>): outputType =>
switch r {
| Ok(o) => o
| Error(e) => #Error(e)
}
type direction = [
| #Algebraic
| #Pointwise
@ -71,10 +80,11 @@ let genericParams = {
}
type wrapped = (genericDist, params)
type wrappedOutput = (outputType, params)
let wrapWithParams = (g: genericDist, f: params): wrapped => (g, f)
let exampleDist: genericDist = #XYShape(
let exampleDist: genericDist = #PointSet(
Discrete(Discrete.make(~integralSumCache=Some(1.0), {xs: [3.0], ys: [1.0]})),
)
@ -106,45 +116,79 @@ module AlgebraicCombination = {
let sampleN = (n, genericDist) => {
switch genericDist {
| #XYShape(r) => Ok(PointSetDist.sampleNRendered(n, r))
| #PointSet(r) => Ok(PointSetDist.sampleNRendered(n, r))
| #Symbolic(r) => Ok(SymbolicDist.T.sampleN(n, r))
| #SampleSet(r) => Error(NotYetImplemented)
| #Error(r) => Error(r)
| _ => Error(NotYetImplemented)
| #SampleSet(_) => Error(NotYetImplemented)
}
}
let rec applyFnInternal = (wrapped: wrapped, fnName: operation): wrapped => {
let (v, {sampleCount, xyPointLength} as extra) = wrapped
let newVal: genericDist = switch (fnName, v) {
| (#toFloat(n), #XYShape(r)) => #Float(PointSetDist.operate(n, r))
| (#toFloat(n), #Symbolic(r)) =>
switch SymbolicDist.T.operate(n, r) {
| Ok(float) => #Float(float)
| Error(e) => #Error(Other(e))
let toFloat = (
toPointSet: genericDist => result<PointSetTypes.pointSetDist, error>,
fnName,
value,
) => {
switch value {
| #Symbolic(r) if Belt.Result.isOk(SymbolicDist.T.operate(fnName, r)) =>
switch SymbolicDist.T.operate(fnName, r) {
| Ok(float) => Ok(float)
| Error(_) => Error(ImpossiblePath)
}
| (#toFloat(n), #SampleSet(_)) => #Error(NeedsPointSetConversion)
| (#toDist(#normalize), #XYShape(r)) => #XYShape(PointSetDist.T.normalize(r))
| (#toDist(#normalize), #Symbolic(_)) => v
| (#toDist(#normalize), #SampleSet(_)) => v
| (#toDist(#toPointSet), #XYShape(_)) => v
| (#toDist(#toPointSet), #Symbolic(r)) => #XYShape(SymbolicDist.T.toPointSetDist(sampleCount, r))
| (#toDist(#toPointSet), #SampleSet(r)) => {
| #PointSet(r) => Ok(PointSetDist.operate(fnName, r))
| _ =>
switch toPointSet(value) {
| Ok(r) => Ok(PointSetDist.operate(fnName, r))
| Error(r) => Error(r)
}
}
}
let distToPointSet = (sampleCount, dist: genericDist) => {
switch dist {
| #PointSet(pointSet) => Ok(pointSet)
| #Symbolic(r) => Ok(SymbolicDist.T.toPointSetDist(sampleCount, r))
| #SampleSet(r) => {
let response = SampleSet.toPointSetDist(
~samples=r,
~samplingInputs=defaultSamplingInputs,
(),
).pointSetDist
switch response {
| Some(r) => #XYShape(r)
| None => #Error(Other("Failed to convert sample into shape"))
| Some(r) => Ok(r)
| None => Error(Other("Converting sampleSet to pointSet failed"))
}
}
| (#toDist(#toSampleSet(n)), r) =>
switch sampleN(n, r) {
| Ok(r) => #SampleSet(r)
| Error(r) => #Error(r)
}
}
let rec applyFnInternal = (wrapped: wrapped, fnName: operation): wrappedOutput => {
let (value, {sampleCount, xyPointLength} as extra) = wrapped
let reCall = (~value=value, ~extra=extra, ~fnName=fnName, ()) => {
applyFnInternal((value, extra), fnName)
}
let reCallUnwrapped = (~value=value, ~extra=extra, ~fnName=fnName, ()) => {
let (value, _) = applyFnInternal((value, extra), fnName)
value
}
let toPointSet = r => {
switch reCallUnwrapped(~value=r, ~fnName=#toDist(#toPointSet), ()) {
| #Dist(#PointSet(p)) => Ok(p)
| #Error(r) => Error(r)
| _ => Error(Other("Impossible error"))
}
}
let toPointSetAndReCall = v =>
toPointSet(v) |> E.R.fmap(r => reCallUnwrapped(~value=#PointSet(r), ()))
let newVal: outputType = switch (fnName, value) {
// | (#toFloat(n), v) => toFloat(toPointSet, v, n)
| (#toFloat(fnName), _) =>
toFloat(toPointSet, fnName, value) |> E.R.fmap(r => #Float(r)) |> fromResult
| (#toDist(#normalize), #PointSet(r)) => #Dist(#PointSet(PointSetDist.T.normalize(r)))
| (#toDist(#normalize), #Symbolic(_)) => #Dist(value)
| (#toDist(#normalize), #SampleSet(_)) => #Dist(value)
| (#toDist(#toPointSet), _) =>
value |> distToPointSet(sampleCount) |> E.R.fmap(r => #Dist(#PointSet(r))) |> fromResult
| (#toDist(#toSampleSet(n)), _) =>
value |> sampleN(n) |> E.R.fmap(r => #Dist(#SampleSet(r))) |> fromResult
| (#toDistCombination(#Algebraic, operation, p2), p1) => {
// TODO: This could be more complex, to get possible simplification and similar.
let dist1 = sampleN(sampleCount, p1)
@ -153,43 +197,44 @@ let rec applyFnInternal = (wrapped: wrapped, fnName: operation): wrapped => {
Belt.Array.zip(d1, d2) |> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(operation, a, b))
})
switch samples {
| Ok(r) => #SampleSet(r)
| Ok(r) => #Dist(#SampleSet(r))
| Error(e) => #Error(e)
}
}
| (#toDistCombination(#Pointwise, operation, p2), p1) =>
switch (
applyFnInternal((p1, extra), #toDist(#toPointSet)),
applyFnInternal((p2, extra), #toDist(#toPointSet)),
toPointSet(p1),
toPointSet(p2)
) {
| ((#XYShape(p1), _), (#XYShape(p2), _)) =>
#XYShape(PointSetDist.combinePointwise(combinationToFn(operation), p1, p2))
| _ => #Error(Other("No Match or not supported"))
| (Ok(p1), Ok(p2)) =>
// TODO: If the dist is symbolic, then it doesn't need to be converted into a pointSet
#Dist(#PointSet(PointSetDist.combinePointwise(combinationToFn(operation), p1, p2)))
| (_, _) => #Error(Other("No Match or not supported"))
}
| _ => #Error(Other("No Match or not supported"))
}
(newVal, {sampleCount: sampleCount, xyPointLength: xyPointLength})
}
let applyFn = (wrapped, fnName): wrapped => {
let (v, extra) as result = applyFnInternal(wrapped, fnName)
switch v {
| #Error(NeedsPointSetConversion) => {
let convertedToPointSet = applyFnInternal(wrapped, #toDist(#toPointSet))
applyFnInternal(convertedToPointSet, fnName)
}
| #Error(InputsNeedPointSetConversion) => {
let altDist = switch fnName {
| #toDistCombination(p1, p2, dist) => {
let (newDist, _) = applyFnInternal((dist, extra), #toDist(#toPointSet))
applyFnInternal(wrapped, #toDistCombination(p1, p2, newDist))
}
| _ => (#Error(Other("Not needed")), extra)
}
altDist
}
| _ => result
}
}
// let applyFn = (wrapped, fnName): wrapped => {
// let (v, extra) as result = applyFnInternal(wrapped, fnName)
// switch v {
// | #Error(NeedsPointSetConversion) => {
// let convertedToPointSet = applyFnInternal(wrapped, #toDist(#toPointSet))
// applyFnInternal(convertedToPointSet, fnName)
// }
// | #Error(InputsNeedPointSetConversion) => {
// let altDist = switch fnName {
// | #toDistCombination(p1, p2, dist) => {
// let (newDist, _) = applyFnInternal((dist, extra), #toDist(#toPointSet))
// applyFnInternal(wrapped, #toDistCombination(p1, p2, newDist))
// }
// | _ => (#Error(Other("Not needed")), extra)
// }
// altDist
// }
// | _ => result
// }
// }
let foo = exampleDist->wrapWithParams(genericParams)->applyFn(#toDist(#normalize))
// let foo = exampleDist->wrapWithParams(genericParams)->applyFn(#toDist(#normalize))