Minor refactors
This commit is contained in:
parent
e93b500d4a
commit
1a2ce5bfa0
|
@ -1,19 +1,28 @@
|
|||
type symboliDist = SymbolicDistTypes.symbolicDist
|
||||
|
||||
type error =
|
||||
| NeedsPointSetConversion
|
||||
| InputsNeedPointSetConversion
|
||||
| NotYetImplemented
|
||||
| ImpossiblePath
|
||||
| Other(string)
|
||||
|
||||
type genericDist = [
|
||||
| #XYShape(PointSetTypes.pointSetDist)
|
||||
| #PointSet(PointSetTypes.pointSetDist)
|
||||
| #SampleSet(array<float>)
|
||||
| #Symbolic(symboliDist)
|
||||
| #Symbolic(SymbolicDistTypes.symbolicDist)
|
||||
]
|
||||
|
||||
type outputType = [
|
||||
| #Dist(genericDist)
|
||||
| #Error(error)
|
||||
| #Float(float)
|
||||
]
|
||||
|
||||
let fromResult = (r: result<outputType, error>): outputType =>
|
||||
switch r {
|
||||
| Ok(o) => o
|
||||
| Error(e) => #Error(e)
|
||||
}
|
||||
|
||||
type direction = [
|
||||
| #Algebraic
|
||||
| #Pointwise
|
||||
|
@ -71,10 +80,11 @@ let genericParams = {
|
|||
}
|
||||
|
||||
type wrapped = (genericDist, params)
|
||||
type wrappedOutput = (outputType, params)
|
||||
|
||||
let wrapWithParams = (g: genericDist, f: params): wrapped => (g, f)
|
||||
|
||||
let exampleDist: genericDist = #XYShape(
|
||||
let exampleDist: genericDist = #PointSet(
|
||||
Discrete(Discrete.make(~integralSumCache=Some(1.0), {xs: [3.0], ys: [1.0]})),
|
||||
)
|
||||
|
||||
|
@ -106,45 +116,79 @@ module AlgebraicCombination = {
|
|||
|
||||
let sampleN = (n, genericDist) => {
|
||||
switch genericDist {
|
||||
| #XYShape(r) => Ok(PointSetDist.sampleNRendered(n, r))
|
||||
| #PointSet(r) => Ok(PointSetDist.sampleNRendered(n, r))
|
||||
| #Symbolic(r) => Ok(SymbolicDist.T.sampleN(n, r))
|
||||
| #SampleSet(r) => Error(NotYetImplemented)
|
||||
| #Error(r) => Error(r)
|
||||
| _ => Error(NotYetImplemented)
|
||||
| #SampleSet(_) => Error(NotYetImplemented)
|
||||
}
|
||||
}
|
||||
|
||||
let rec applyFnInternal = (wrapped: wrapped, fnName: operation): wrapped => {
|
||||
let (v, {sampleCount, xyPointLength} as extra) = wrapped
|
||||
let newVal: genericDist = switch (fnName, v) {
|
||||
| (#toFloat(n), #XYShape(r)) => #Float(PointSetDist.operate(n, r))
|
||||
| (#toFloat(n), #Symbolic(r)) =>
|
||||
switch SymbolicDist.T.operate(n, r) {
|
||||
| Ok(float) => #Float(float)
|
||||
| Error(e) => #Error(Other(e))
|
||||
let toFloat = (
|
||||
toPointSet: genericDist => result<PointSetTypes.pointSetDist, error>,
|
||||
fnName,
|
||||
value,
|
||||
) => {
|
||||
switch value {
|
||||
| #Symbolic(r) if Belt.Result.isOk(SymbolicDist.T.operate(fnName, r)) =>
|
||||
switch SymbolicDist.T.operate(fnName, r) {
|
||||
| Ok(float) => Ok(float)
|
||||
| Error(_) => Error(ImpossiblePath)
|
||||
}
|
||||
| (#toFloat(n), #SampleSet(_)) => #Error(NeedsPointSetConversion)
|
||||
| (#toDist(#normalize), #XYShape(r)) => #XYShape(PointSetDist.T.normalize(r))
|
||||
| (#toDist(#normalize), #Symbolic(_)) => v
|
||||
| (#toDist(#normalize), #SampleSet(_)) => v
|
||||
| (#toDist(#toPointSet), #XYShape(_)) => v
|
||||
| (#toDist(#toPointSet), #Symbolic(r)) => #XYShape(SymbolicDist.T.toPointSetDist(sampleCount, r))
|
||||
| (#toDist(#toPointSet), #SampleSet(r)) => {
|
||||
| #PointSet(r) => Ok(PointSetDist.operate(fnName, r))
|
||||
| _ =>
|
||||
switch toPointSet(value) {
|
||||
| Ok(r) => Ok(PointSetDist.operate(fnName, r))
|
||||
| Error(r) => Error(r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let distToPointSet = (sampleCount, dist: genericDist) => {
|
||||
switch dist {
|
||||
| #PointSet(pointSet) => Ok(pointSet)
|
||||
| #Symbolic(r) => Ok(SymbolicDist.T.toPointSetDist(sampleCount, r))
|
||||
| #SampleSet(r) => {
|
||||
let response = SampleSet.toPointSetDist(
|
||||
~samples=r,
|
||||
~samplingInputs=defaultSamplingInputs,
|
||||
(),
|
||||
).pointSetDist
|
||||
switch response {
|
||||
| Some(r) => #XYShape(r)
|
||||
| None => #Error(Other("Failed to convert sample into shape"))
|
||||
| Some(r) => Ok(r)
|
||||
| None => Error(Other("Converting sampleSet to pointSet failed"))
|
||||
}
|
||||
}
|
||||
| (#toDist(#toSampleSet(n)), r) =>
|
||||
switch sampleN(n, r) {
|
||||
| Ok(r) => #SampleSet(r)
|
||||
| Error(r) => #Error(r)
|
||||
}
|
||||
}
|
||||
|
||||
let rec applyFnInternal = (wrapped: wrapped, fnName: operation): wrappedOutput => {
|
||||
let (value, {sampleCount, xyPointLength} as extra) = wrapped
|
||||
let reCall = (~value=value, ~extra=extra, ~fnName=fnName, ()) => {
|
||||
applyFnInternal((value, extra), fnName)
|
||||
}
|
||||
let reCallUnwrapped = (~value=value, ~extra=extra, ~fnName=fnName, ()) => {
|
||||
let (value, _) = applyFnInternal((value, extra), fnName)
|
||||
value
|
||||
}
|
||||
let toPointSet = r => {
|
||||
switch reCallUnwrapped(~value=r, ~fnName=#toDist(#toPointSet), ()) {
|
||||
| #Dist(#PointSet(p)) => Ok(p)
|
||||
| #Error(r) => Error(r)
|
||||
| _ => Error(Other("Impossible error"))
|
||||
}
|
||||
}
|
||||
let toPointSetAndReCall = v =>
|
||||
toPointSet(v) |> E.R.fmap(r => reCallUnwrapped(~value=#PointSet(r), ()))
|
||||
let newVal: outputType = switch (fnName, value) {
|
||||
// | (#toFloat(n), v) => toFloat(toPointSet, v, n)
|
||||
| (#toFloat(fnName), _) =>
|
||||
toFloat(toPointSet, fnName, value) |> E.R.fmap(r => #Float(r)) |> fromResult
|
||||
| (#toDist(#normalize), #PointSet(r)) => #Dist(#PointSet(PointSetDist.T.normalize(r)))
|
||||
| (#toDist(#normalize), #Symbolic(_)) => #Dist(value)
|
||||
| (#toDist(#normalize), #SampleSet(_)) => #Dist(value)
|
||||
| (#toDist(#toPointSet), _) =>
|
||||
value |> distToPointSet(sampleCount) |> E.R.fmap(r => #Dist(#PointSet(r))) |> fromResult
|
||||
| (#toDist(#toSampleSet(n)), _) =>
|
||||
value |> sampleN(n) |> E.R.fmap(r => #Dist(#SampleSet(r))) |> fromResult
|
||||
| (#toDistCombination(#Algebraic, operation, p2), p1) => {
|
||||
// TODO: This could be more complex, to get possible simplification and similar.
|
||||
let dist1 = sampleN(sampleCount, p1)
|
||||
|
@ -153,43 +197,44 @@ let rec applyFnInternal = (wrapped: wrapped, fnName: operation): wrapped => {
|
|||
Belt.Array.zip(d1, d2) |> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(operation, a, b))
|
||||
})
|
||||
switch samples {
|
||||
| Ok(r) => #SampleSet(r)
|
||||
| Ok(r) => #Dist(#SampleSet(r))
|
||||
| Error(e) => #Error(e)
|
||||
}
|
||||
}
|
||||
| (#toDistCombination(#Pointwise, operation, p2), p1) =>
|
||||
switch (
|
||||
applyFnInternal((p1, extra), #toDist(#toPointSet)),
|
||||
applyFnInternal((p2, extra), #toDist(#toPointSet)),
|
||||
toPointSet(p1),
|
||||
toPointSet(p2)
|
||||
) {
|
||||
| ((#XYShape(p1), _), (#XYShape(p2), _)) =>
|
||||
#XYShape(PointSetDist.combinePointwise(combinationToFn(operation), p1, p2))
|
||||
| _ => #Error(Other("No Match or not supported"))
|
||||
| (Ok(p1), Ok(p2)) =>
|
||||
// TODO: If the dist is symbolic, then it doesn't need to be converted into a pointSet
|
||||
#Dist(#PointSet(PointSetDist.combinePointwise(combinationToFn(operation), p1, p2)))
|
||||
| (_, _) => #Error(Other("No Match or not supported"))
|
||||
}
|
||||
| _ => #Error(Other("No Match or not supported"))
|
||||
}
|
||||
(newVal, {sampleCount: sampleCount, xyPointLength: xyPointLength})
|
||||
}
|
||||
|
||||
let applyFn = (wrapped, fnName): wrapped => {
|
||||
let (v, extra) as result = applyFnInternal(wrapped, fnName)
|
||||
switch v {
|
||||
| #Error(NeedsPointSetConversion) => {
|
||||
let convertedToPointSet = applyFnInternal(wrapped, #toDist(#toPointSet))
|
||||
applyFnInternal(convertedToPointSet, fnName)
|
||||
}
|
||||
| #Error(InputsNeedPointSetConversion) => {
|
||||
let altDist = switch fnName {
|
||||
| #toDistCombination(p1, p2, dist) => {
|
||||
let (newDist, _) = applyFnInternal((dist, extra), #toDist(#toPointSet))
|
||||
applyFnInternal(wrapped, #toDistCombination(p1, p2, newDist))
|
||||
}
|
||||
| _ => (#Error(Other("Not needed")), extra)
|
||||
}
|
||||
altDist
|
||||
}
|
||||
| _ => result
|
||||
}
|
||||
}
|
||||
// let applyFn = (wrapped, fnName): wrapped => {
|
||||
// let (v, extra) as result = applyFnInternal(wrapped, fnName)
|
||||
// switch v {
|
||||
// | #Error(NeedsPointSetConversion) => {
|
||||
// let convertedToPointSet = applyFnInternal(wrapped, #toDist(#toPointSet))
|
||||
// applyFnInternal(convertedToPointSet, fnName)
|
||||
// }
|
||||
// | #Error(InputsNeedPointSetConversion) => {
|
||||
// let altDist = switch fnName {
|
||||
// | #toDistCombination(p1, p2, dist) => {
|
||||
// let (newDist, _) = applyFnInternal((dist, extra), #toDist(#toPointSet))
|
||||
// applyFnInternal(wrapped, #toDistCombination(p1, p2, newDist))
|
||||
// }
|
||||
// | _ => (#Error(Other("Not needed")), extra)
|
||||
// }
|
||||
// altDist
|
||||
// }
|
||||
// | _ => result
|
||||
// }
|
||||
// }
|
||||
|
||||
let foo = exampleDist->wrapWithParams(genericParams)->applyFn(#toDist(#normalize))
|
||||
// let foo = exampleDist->wrapWithParams(genericParams)->applyFn(#toDist(#normalize))
|
||||
|
|
Loading…
Reference in New Issue
Block a user