Minor refactors
This commit is contained in:
parent
e93b500d4a
commit
1a2ce5bfa0
|
@ -1,19 +1,28 @@
|
||||||
type symboliDist = SymbolicDistTypes.symbolicDist
|
|
||||||
|
|
||||||
type error =
|
type error =
|
||||||
| NeedsPointSetConversion
|
| NeedsPointSetConversion
|
||||||
| InputsNeedPointSetConversion
|
| InputsNeedPointSetConversion
|
||||||
| NotYetImplemented
|
| NotYetImplemented
|
||||||
|
| ImpossiblePath
|
||||||
| Other(string)
|
| Other(string)
|
||||||
|
|
||||||
type genericDist = [
|
type genericDist = [
|
||||||
| #XYShape(PointSetTypes.pointSetDist)
|
| #PointSet(PointSetTypes.pointSetDist)
|
||||||
| #SampleSet(array<float>)
|
| #SampleSet(array<float>)
|
||||||
| #Symbolic(symboliDist)
|
| #Symbolic(SymbolicDistTypes.symbolicDist)
|
||||||
|
]
|
||||||
|
|
||||||
|
type outputType = [
|
||||||
|
| #Dist(genericDist)
|
||||||
| #Error(error)
|
| #Error(error)
|
||||||
| #Float(float)
|
| #Float(float)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
let fromResult = (r: result<outputType, error>): outputType =>
|
||||||
|
switch r {
|
||||||
|
| Ok(o) => o
|
||||||
|
| Error(e) => #Error(e)
|
||||||
|
}
|
||||||
|
|
||||||
type direction = [
|
type direction = [
|
||||||
| #Algebraic
|
| #Algebraic
|
||||||
| #Pointwise
|
| #Pointwise
|
||||||
|
@ -71,10 +80,11 @@ let genericParams = {
|
||||||
}
|
}
|
||||||
|
|
||||||
type wrapped = (genericDist, params)
|
type wrapped = (genericDist, params)
|
||||||
|
type wrappedOutput = (outputType, params)
|
||||||
|
|
||||||
let wrapWithParams = (g: genericDist, f: params): wrapped => (g, f)
|
let wrapWithParams = (g: genericDist, f: params): wrapped => (g, f)
|
||||||
|
|
||||||
let exampleDist: genericDist = #XYShape(
|
let exampleDist: genericDist = #PointSet(
|
||||||
Discrete(Discrete.make(~integralSumCache=Some(1.0), {xs: [3.0], ys: [1.0]})),
|
Discrete(Discrete.make(~integralSumCache=Some(1.0), {xs: [3.0], ys: [1.0]})),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -106,45 +116,79 @@ module AlgebraicCombination = {
|
||||||
|
|
||||||
let sampleN = (n, genericDist) => {
|
let sampleN = (n, genericDist) => {
|
||||||
switch genericDist {
|
switch genericDist {
|
||||||
| #XYShape(r) => Ok(PointSetDist.sampleNRendered(n, r))
|
| #PointSet(r) => Ok(PointSetDist.sampleNRendered(n, r))
|
||||||
| #Symbolic(r) => Ok(SymbolicDist.T.sampleN(n, r))
|
| #Symbolic(r) => Ok(SymbolicDist.T.sampleN(n, r))
|
||||||
| #SampleSet(r) => Error(NotYetImplemented)
|
| #SampleSet(_) => Error(NotYetImplemented)
|
||||||
| #Error(r) => Error(r)
|
|
||||||
| _ => Error(NotYetImplemented)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let rec applyFnInternal = (wrapped: wrapped, fnName: operation): wrapped => {
|
let toFloat = (
|
||||||
let (v, {sampleCount, xyPointLength} as extra) = wrapped
|
toPointSet: genericDist => result<PointSetTypes.pointSetDist, error>,
|
||||||
let newVal: genericDist = switch (fnName, v) {
|
fnName,
|
||||||
| (#toFloat(n), #XYShape(r)) => #Float(PointSetDist.operate(n, r))
|
value,
|
||||||
| (#toFloat(n), #Symbolic(r)) =>
|
) => {
|
||||||
switch SymbolicDist.T.operate(n, r) {
|
switch value {
|
||||||
| Ok(float) => #Float(float)
|
| #Symbolic(r) if Belt.Result.isOk(SymbolicDist.T.operate(fnName, r)) =>
|
||||||
| Error(e) => #Error(Other(e))
|
switch SymbolicDist.T.operate(fnName, r) {
|
||||||
|
| Ok(float) => Ok(float)
|
||||||
|
| Error(_) => Error(ImpossiblePath)
|
||||||
}
|
}
|
||||||
| (#toFloat(n), #SampleSet(_)) => #Error(NeedsPointSetConversion)
|
| #PointSet(r) => Ok(PointSetDist.operate(fnName, r))
|
||||||
| (#toDist(#normalize), #XYShape(r)) => #XYShape(PointSetDist.T.normalize(r))
|
| _ =>
|
||||||
| (#toDist(#normalize), #Symbolic(_)) => v
|
switch toPointSet(value) {
|
||||||
| (#toDist(#normalize), #SampleSet(_)) => v
|
| Ok(r) => Ok(PointSetDist.operate(fnName, r))
|
||||||
| (#toDist(#toPointSet), #XYShape(_)) => v
|
| Error(r) => Error(r)
|
||||||
| (#toDist(#toPointSet), #Symbolic(r)) => #XYShape(SymbolicDist.T.toPointSetDist(sampleCount, r))
|
}
|
||||||
| (#toDist(#toPointSet), #SampleSet(r)) => {
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let distToPointSet = (sampleCount, dist: genericDist) => {
|
||||||
|
switch dist {
|
||||||
|
| #PointSet(pointSet) => Ok(pointSet)
|
||||||
|
| #Symbolic(r) => Ok(SymbolicDist.T.toPointSetDist(sampleCount, r))
|
||||||
|
| #SampleSet(r) => {
|
||||||
let response = SampleSet.toPointSetDist(
|
let response = SampleSet.toPointSetDist(
|
||||||
~samples=r,
|
~samples=r,
|
||||||
~samplingInputs=defaultSamplingInputs,
|
~samplingInputs=defaultSamplingInputs,
|
||||||
(),
|
(),
|
||||||
).pointSetDist
|
).pointSetDist
|
||||||
switch response {
|
switch response {
|
||||||
| Some(r) => #XYShape(r)
|
| Some(r) => Ok(r)
|
||||||
| None => #Error(Other("Failed to convert sample into shape"))
|
| None => Error(Other("Converting sampleSet to pointSet failed"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
| (#toDist(#toSampleSet(n)), r) =>
|
|
||||||
switch sampleN(n, r) {
|
|
||||||
| Ok(r) => #SampleSet(r)
|
|
||||||
| Error(r) => #Error(r)
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let rec applyFnInternal = (wrapped: wrapped, fnName: operation): wrappedOutput => {
|
||||||
|
let (value, {sampleCount, xyPointLength} as extra) = wrapped
|
||||||
|
let reCall = (~value=value, ~extra=extra, ~fnName=fnName, ()) => {
|
||||||
|
applyFnInternal((value, extra), fnName)
|
||||||
|
}
|
||||||
|
let reCallUnwrapped = (~value=value, ~extra=extra, ~fnName=fnName, ()) => {
|
||||||
|
let (value, _) = applyFnInternal((value, extra), fnName)
|
||||||
|
value
|
||||||
|
}
|
||||||
|
let toPointSet = r => {
|
||||||
|
switch reCallUnwrapped(~value=r, ~fnName=#toDist(#toPointSet), ()) {
|
||||||
|
| #Dist(#PointSet(p)) => Ok(p)
|
||||||
|
| #Error(r) => Error(r)
|
||||||
|
| _ => Error(Other("Impossible error"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let toPointSetAndReCall = v =>
|
||||||
|
toPointSet(v) |> E.R.fmap(r => reCallUnwrapped(~value=#PointSet(r), ()))
|
||||||
|
let newVal: outputType = switch (fnName, value) {
|
||||||
|
// | (#toFloat(n), v) => toFloat(toPointSet, v, n)
|
||||||
|
| (#toFloat(fnName), _) =>
|
||||||
|
toFloat(toPointSet, fnName, value) |> E.R.fmap(r => #Float(r)) |> fromResult
|
||||||
|
| (#toDist(#normalize), #PointSet(r)) => #Dist(#PointSet(PointSetDist.T.normalize(r)))
|
||||||
|
| (#toDist(#normalize), #Symbolic(_)) => #Dist(value)
|
||||||
|
| (#toDist(#normalize), #SampleSet(_)) => #Dist(value)
|
||||||
|
| (#toDist(#toPointSet), _) =>
|
||||||
|
value |> distToPointSet(sampleCount) |> E.R.fmap(r => #Dist(#PointSet(r))) |> fromResult
|
||||||
|
| (#toDist(#toSampleSet(n)), _) =>
|
||||||
|
value |> sampleN(n) |> E.R.fmap(r => #Dist(#SampleSet(r))) |> fromResult
|
||||||
| (#toDistCombination(#Algebraic, operation, p2), p1) => {
|
| (#toDistCombination(#Algebraic, operation, p2), p1) => {
|
||||||
// TODO: This could be more complex, to get possible simplification and similar.
|
// TODO: This could be more complex, to get possible simplification and similar.
|
||||||
let dist1 = sampleN(sampleCount, p1)
|
let dist1 = sampleN(sampleCount, p1)
|
||||||
|
@ -153,43 +197,44 @@ let rec applyFnInternal = (wrapped: wrapped, fnName: operation): wrapped => {
|
||||||
Belt.Array.zip(d1, d2) |> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(operation, a, b))
|
Belt.Array.zip(d1, d2) |> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(operation, a, b))
|
||||||
})
|
})
|
||||||
switch samples {
|
switch samples {
|
||||||
| Ok(r) => #SampleSet(r)
|
| Ok(r) => #Dist(#SampleSet(r))
|
||||||
| Error(e) => #Error(e)
|
| Error(e) => #Error(e)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
| (#toDistCombination(#Pointwise, operation, p2), p1) =>
|
| (#toDistCombination(#Pointwise, operation, p2), p1) =>
|
||||||
switch (
|
switch (
|
||||||
applyFnInternal((p1, extra), #toDist(#toPointSet)),
|
toPointSet(p1),
|
||||||
applyFnInternal((p2, extra), #toDist(#toPointSet)),
|
toPointSet(p2)
|
||||||
) {
|
) {
|
||||||
| ((#XYShape(p1), _), (#XYShape(p2), _)) =>
|
| (Ok(p1), Ok(p2)) =>
|
||||||
#XYShape(PointSetDist.combinePointwise(combinationToFn(operation), p1, p2))
|
// TODO: If the dist is symbolic, then it doesn't need to be converted into a pointSet
|
||||||
| _ => #Error(Other("No Match or not supported"))
|
#Dist(#PointSet(PointSetDist.combinePointwise(combinationToFn(operation), p1, p2)))
|
||||||
|
| (_, _) => #Error(Other("No Match or not supported"))
|
||||||
}
|
}
|
||||||
| _ => #Error(Other("No Match or not supported"))
|
| _ => #Error(Other("No Match or not supported"))
|
||||||
}
|
}
|
||||||
(newVal, {sampleCount: sampleCount, xyPointLength: xyPointLength})
|
(newVal, {sampleCount: sampleCount, xyPointLength: xyPointLength})
|
||||||
}
|
}
|
||||||
|
|
||||||
let applyFn = (wrapped, fnName): wrapped => {
|
// let applyFn = (wrapped, fnName): wrapped => {
|
||||||
let (v, extra) as result = applyFnInternal(wrapped, fnName)
|
// let (v, extra) as result = applyFnInternal(wrapped, fnName)
|
||||||
switch v {
|
// switch v {
|
||||||
| #Error(NeedsPointSetConversion) => {
|
// | #Error(NeedsPointSetConversion) => {
|
||||||
let convertedToPointSet = applyFnInternal(wrapped, #toDist(#toPointSet))
|
// let convertedToPointSet = applyFnInternal(wrapped, #toDist(#toPointSet))
|
||||||
applyFnInternal(convertedToPointSet, fnName)
|
// applyFnInternal(convertedToPointSet, fnName)
|
||||||
}
|
// }
|
||||||
| #Error(InputsNeedPointSetConversion) => {
|
// | #Error(InputsNeedPointSetConversion) => {
|
||||||
let altDist = switch fnName {
|
// let altDist = switch fnName {
|
||||||
| #toDistCombination(p1, p2, dist) => {
|
// | #toDistCombination(p1, p2, dist) => {
|
||||||
let (newDist, _) = applyFnInternal((dist, extra), #toDist(#toPointSet))
|
// let (newDist, _) = applyFnInternal((dist, extra), #toDist(#toPointSet))
|
||||||
applyFnInternal(wrapped, #toDistCombination(p1, p2, newDist))
|
// applyFnInternal(wrapped, #toDistCombination(p1, p2, newDist))
|
||||||
}
|
// }
|
||||||
| _ => (#Error(Other("Not needed")), extra)
|
// | _ => (#Error(Other("Not needed")), extra)
|
||||||
}
|
// }
|
||||||
altDist
|
// altDist
|
||||||
}
|
// }
|
||||||
| _ => result
|
// | _ => result
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
let foo = exampleDist->wrapWithParams(genericParams)->applyFn(#toDist(#normalize))
|
// let foo = exampleDist->wrapWithParams(genericParams)->applyFn(#toDist(#normalize))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user