Minor refactors

This commit is contained in:
Ozzie Gooen 2022-03-25 22:11:27 -04:00 committed by GitHub
parent e93b500d4a
commit 1a2ce5bfa0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,19 +1,28 @@
type symboliDist = SymbolicDistTypes.symbolicDist
type error = type error =
| NeedsPointSetConversion | NeedsPointSetConversion
| InputsNeedPointSetConversion | InputsNeedPointSetConversion
| NotYetImplemented | NotYetImplemented
| ImpossiblePath
| Other(string) | Other(string)
type genericDist = [ type genericDist = [
| #XYShape(PointSetTypes.pointSetDist) | #PointSet(PointSetTypes.pointSetDist)
| #SampleSet(array<float>) | #SampleSet(array<float>)
| #Symbolic(symboliDist) | #Symbolic(SymbolicDistTypes.symbolicDist)
]
type outputType = [
| #Dist(genericDist)
| #Error(error) | #Error(error)
| #Float(float) | #Float(float)
] ]
let fromResult = (r: result<outputType, error>): outputType =>
switch r {
| Ok(o) => o
| Error(e) => #Error(e)
}
type direction = [ type direction = [
| #Algebraic | #Algebraic
| #Pointwise | #Pointwise
@ -71,10 +80,11 @@ let genericParams = {
} }
type wrapped = (genericDist, params) type wrapped = (genericDist, params)
type wrappedOutput = (outputType, params)
let wrapWithParams = (g: genericDist, f: params): wrapped => (g, f) let wrapWithParams = (g: genericDist, f: params): wrapped => (g, f)
let exampleDist: genericDist = #XYShape( let exampleDist: genericDist = #PointSet(
Discrete(Discrete.make(~integralSumCache=Some(1.0), {xs: [3.0], ys: [1.0]})), Discrete(Discrete.make(~integralSumCache=Some(1.0), {xs: [3.0], ys: [1.0]})),
) )
@ -106,45 +116,79 @@ module AlgebraicCombination = {
let sampleN = (n, genericDist) => { let sampleN = (n, genericDist) => {
switch genericDist { switch genericDist {
| #XYShape(r) => Ok(PointSetDist.sampleNRendered(n, r)) | #PointSet(r) => Ok(PointSetDist.sampleNRendered(n, r))
| #Symbolic(r) => Ok(SymbolicDist.T.sampleN(n, r)) | #Symbolic(r) => Ok(SymbolicDist.T.sampleN(n, r))
| #SampleSet(r) => Error(NotYetImplemented) | #SampleSet(_) => Error(NotYetImplemented)
| #Error(r) => Error(r)
| _ => Error(NotYetImplemented)
} }
} }
let rec applyFnInternal = (wrapped: wrapped, fnName: operation): wrapped => { let toFloat = (
let (v, {sampleCount, xyPointLength} as extra) = wrapped toPointSet: genericDist => result<PointSetTypes.pointSetDist, error>,
let newVal: genericDist = switch (fnName, v) { fnName,
| (#toFloat(n), #XYShape(r)) => #Float(PointSetDist.operate(n, r)) value,
| (#toFloat(n), #Symbolic(r)) => ) => {
switch SymbolicDist.T.operate(n, r) { switch value {
| Ok(float) => #Float(float) | #Symbolic(r) if Belt.Result.isOk(SymbolicDist.T.operate(fnName, r)) =>
| Error(e) => #Error(Other(e)) switch SymbolicDist.T.operate(fnName, r) {
| Ok(float) => Ok(float)
| Error(_) => Error(ImpossiblePath)
} }
| (#toFloat(n), #SampleSet(_)) => #Error(NeedsPointSetConversion) | #PointSet(r) => Ok(PointSetDist.operate(fnName, r))
| (#toDist(#normalize), #XYShape(r)) => #XYShape(PointSetDist.T.normalize(r)) | _ =>
| (#toDist(#normalize), #Symbolic(_)) => v switch toPointSet(value) {
| (#toDist(#normalize), #SampleSet(_)) => v | Ok(r) => Ok(PointSetDist.operate(fnName, r))
| (#toDist(#toPointSet), #XYShape(_)) => v | Error(r) => Error(r)
| (#toDist(#toPointSet), #Symbolic(r)) => #XYShape(SymbolicDist.T.toPointSetDist(sampleCount, r)) }
| (#toDist(#toPointSet), #SampleSet(r)) => { }
}
let distToPointSet = (sampleCount, dist: genericDist) => {
switch dist {
| #PointSet(pointSet) => Ok(pointSet)
| #Symbolic(r) => Ok(SymbolicDist.T.toPointSetDist(sampleCount, r))
| #SampleSet(r) => {
let response = SampleSet.toPointSetDist( let response = SampleSet.toPointSetDist(
~samples=r, ~samples=r,
~samplingInputs=defaultSamplingInputs, ~samplingInputs=defaultSamplingInputs,
(), (),
).pointSetDist ).pointSetDist
switch response { switch response {
| Some(r) => #XYShape(r) | Some(r) => Ok(r)
| None => #Error(Other("Failed to convert sample into shape")) | None => Error(Other("Converting sampleSet to pointSet failed"))
} }
} }
| (#toDist(#toSampleSet(n)), r) =>
switch sampleN(n, r) {
| Ok(r) => #SampleSet(r)
| Error(r) => #Error(r)
} }
}
let rec applyFnInternal = (wrapped: wrapped, fnName: operation): wrappedOutput => {
let (value, {sampleCount, xyPointLength} as extra) = wrapped
let reCall = (~value=value, ~extra=extra, ~fnName=fnName, ()) => {
applyFnInternal((value, extra), fnName)
}
let reCallUnwrapped = (~value=value, ~extra=extra, ~fnName=fnName, ()) => {
let (value, _) = applyFnInternal((value, extra), fnName)
value
}
let toPointSet = r => {
switch reCallUnwrapped(~value=r, ~fnName=#toDist(#toPointSet), ()) {
| #Dist(#PointSet(p)) => Ok(p)
| #Error(r) => Error(r)
| _ => Error(Other("Impossible error"))
}
}
let toPointSetAndReCall = v =>
toPointSet(v) |> E.R.fmap(r => reCallUnwrapped(~value=#PointSet(r), ()))
let newVal: outputType = switch (fnName, value) {
// | (#toFloat(n), v) => toFloat(toPointSet, v, n)
| (#toFloat(fnName), _) =>
toFloat(toPointSet, fnName, value) |> E.R.fmap(r => #Float(r)) |> fromResult
| (#toDist(#normalize), #PointSet(r)) => #Dist(#PointSet(PointSetDist.T.normalize(r)))
| (#toDist(#normalize), #Symbolic(_)) => #Dist(value)
| (#toDist(#normalize), #SampleSet(_)) => #Dist(value)
| (#toDist(#toPointSet), _) =>
value |> distToPointSet(sampleCount) |> E.R.fmap(r => #Dist(#PointSet(r))) |> fromResult
| (#toDist(#toSampleSet(n)), _) =>
value |> sampleN(n) |> E.R.fmap(r => #Dist(#SampleSet(r))) |> fromResult
| (#toDistCombination(#Algebraic, operation, p2), p1) => { | (#toDistCombination(#Algebraic, operation, p2), p1) => {
// TODO: This could be more complex, to get possible simplification and similar. // TODO: This could be more complex, to get possible simplification and similar.
let dist1 = sampleN(sampleCount, p1) let dist1 = sampleN(sampleCount, p1)
@ -153,43 +197,44 @@ let rec applyFnInternal = (wrapped: wrapped, fnName: operation): wrapped => {
Belt.Array.zip(d1, d2) |> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(operation, a, b)) Belt.Array.zip(d1, d2) |> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(operation, a, b))
}) })
switch samples { switch samples {
| Ok(r) => #SampleSet(r) | Ok(r) => #Dist(#SampleSet(r))
| Error(e) => #Error(e) | Error(e) => #Error(e)
} }
} }
| (#toDistCombination(#Pointwise, operation, p2), p1) => | (#toDistCombination(#Pointwise, operation, p2), p1) =>
switch ( switch (
applyFnInternal((p1, extra), #toDist(#toPointSet)), toPointSet(p1),
applyFnInternal((p2, extra), #toDist(#toPointSet)), toPointSet(p2)
) { ) {
| ((#XYShape(p1), _), (#XYShape(p2), _)) => | (Ok(p1), Ok(p2)) =>
#XYShape(PointSetDist.combinePointwise(combinationToFn(operation), p1, p2)) // TODO: If the dist is symbolic, then it doesn't need to be converted into a pointSet
| _ => #Error(Other("No Match or not supported")) #Dist(#PointSet(PointSetDist.combinePointwise(combinationToFn(operation), p1, p2)))
| (_, _) => #Error(Other("No Match or not supported"))
} }
| _ => #Error(Other("No Match or not supported")) | _ => #Error(Other("No Match or not supported"))
} }
(newVal, {sampleCount: sampleCount, xyPointLength: xyPointLength}) (newVal, {sampleCount: sampleCount, xyPointLength: xyPointLength})
} }
let applyFn = (wrapped, fnName): wrapped => { // let applyFn = (wrapped, fnName): wrapped => {
let (v, extra) as result = applyFnInternal(wrapped, fnName) // let (v, extra) as result = applyFnInternal(wrapped, fnName)
switch v { // switch v {
| #Error(NeedsPointSetConversion) => { // | #Error(NeedsPointSetConversion) => {
let convertedToPointSet = applyFnInternal(wrapped, #toDist(#toPointSet)) // let convertedToPointSet = applyFnInternal(wrapped, #toDist(#toPointSet))
applyFnInternal(convertedToPointSet, fnName) // applyFnInternal(convertedToPointSet, fnName)
} // }
| #Error(InputsNeedPointSetConversion) => { // | #Error(InputsNeedPointSetConversion) => {
let altDist = switch fnName { // let altDist = switch fnName {
| #toDistCombination(p1, p2, dist) => { // | #toDistCombination(p1, p2, dist) => {
let (newDist, _) = applyFnInternal((dist, extra), #toDist(#toPointSet)) // let (newDist, _) = applyFnInternal((dist, extra), #toDist(#toPointSet))
applyFnInternal(wrapped, #toDistCombination(p1, p2, newDist)) // applyFnInternal(wrapped, #toDistCombination(p1, p2, newDist))
} // }
| _ => (#Error(Other("Not needed")), extra) // | _ => (#Error(Other("Not needed")), extra)
} // }
altDist // altDist
} // }
| _ => result // | _ => result
} // }
} // }
let foo = exampleDist->wrapWithParams(genericParams)->applyFn(#toDist(#normalize)) // let foo = exampleDist->wrapWithParams(genericParams)->applyFn(#toDist(#normalize))