Fleshed out AlgebraicCombination
This commit is contained in:
parent
d490af38f0
commit
c5afb2d867
|
@ -115,6 +115,7 @@ let combineShapesContinuousContinuous = (
|
||||||
| #Multiply => (m1, m2) => m1 *. m2
|
| #Multiply => (m1, m2) => m1 *. m2
|
||||||
| #Divide => (m1, mInv2) => m1 *. mInv2
|
| #Divide => (m1, mInv2) => m1 *. mInv2
|
||||||
| #Exponentiate => (m1, mInv2) => m1 ** mInv2
|
| #Exponentiate => (m1, mInv2) => m1 ** mInv2
|
||||||
|
| #Log => (m1, m2) => log(m1) /. log(m2)
|
||||||
} // note: here, mInv2 = mean(1 / t2) ~= 1 / mean(t2)
|
} // note: here, mInv2 = mean(1 / t2) ~= 1 / mean(t2)
|
||||||
|
|
||||||
// TODO: I don't know what the variances are for exponentatiation
|
// TODO: I don't know what the variances are for exponentatiation
|
||||||
|
@ -232,6 +233,7 @@ let combineShapesContinuousDiscrete = (
|
||||||
}
|
}
|
||||||
| #Multiply
|
| #Multiply
|
||||||
| #Exponentiate
|
| #Exponentiate
|
||||||
|
| #Log
|
||||||
| #Divide =>
|
| #Divide =>
|
||||||
for j in 0 to t2n - 1 {
|
for j in 0 to t2n - 1 {
|
||||||
// creates a new continuous shape for each one of the discrete points, and collects them in outXYShapes.
|
// creates a new continuous shape for each one of the discrete points, and collects them in outXYShapes.
|
||||||
|
|
|
@ -41,7 +41,8 @@ let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t): t =
|
||||||
| (Continuous(m1), Discrete(m2))
|
| (Continuous(m1), Discrete(m2))
|
||||||
| (Discrete(m2), Continuous(m1)) =>
|
| (Discrete(m2), Continuous(m1)) =>
|
||||||
Continuous.combineAlgebraicallyWithDiscrete(op, m1, m2) |> Continuous.T.toPointSetDist
|
Continuous.combineAlgebraicallyWithDiscrete(op, m1, m2) |> Continuous.T.toPointSetDist
|
||||||
| (Discrete(m1), Discrete(m2)) => Discrete.combineAlgebraically(op, m1, m2) |> Discrete.T.toPointSetDist
|
| (Discrete(m1), Discrete(m2)) =>
|
||||||
|
Discrete.combineAlgebraically(op, m1, m2) |> Discrete.T.toPointSetDist
|
||||||
| (m1, m2) => Mixed.combineAlgebraically(op, toMixed(m1), toMixed(m2)) |> Mixed.T.toPointSetDist
|
| (m1, m2) => Mixed.combineAlgebraically(op, toMixed(m1), toMixed(m2)) |> Mixed.T.toPointSetDist
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -55,12 +55,6 @@ module OperationType = {
|
||||||
| #Sample(int)
|
| #Sample(int)
|
||||||
]
|
]
|
||||||
|
|
||||||
type scale = [
|
|
||||||
| #Multiply
|
|
||||||
| #Exponentiate
|
|
||||||
| #Log
|
|
||||||
]
|
|
||||||
|
|
||||||
type t = [
|
type t = [
|
||||||
| #toFloat(toFloat)
|
| #toFloat(toFloat)
|
||||||
| #toDist(toDist)
|
| #toDist(toDist)
|
||||||
|
@ -73,6 +67,7 @@ type operation = OperationType.t
|
||||||
module T = {
|
module T = {
|
||||||
type t = genericDist
|
type t = genericDist
|
||||||
type toPointSetFn = genericDist => result<PointSetTypes.pointSetDist, error>
|
type toPointSetFn = genericDist => result<PointSetTypes.pointSetDist, error>
|
||||||
|
type toSampleSetFn = genericDist => result<array<float>, error>
|
||||||
let sampleN = (n, t: t) => {
|
let sampleN = (n, t: t) => {
|
||||||
switch t {
|
switch t {
|
||||||
| #PointSet(r) => Ok(PointSetDist.sampleNRendered(n, r))
|
| #PointSet(r) => Ok(PointSetDist.sampleNRendered(n, r))
|
||||||
|
@ -81,7 +76,7 @@ module T = {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let toFloat = (toPointSet: toPointSetFn, fnName, t: genericDist) => {
|
let toFloat = (toPointSet: toPointSetFn, fnName, t: genericDist): result<float, error> => {
|
||||||
switch t {
|
switch t {
|
||||||
| #Symbolic(r) if Belt.Result.isOk(SymbolicDist.T.operate(fnName, r)) =>
|
| #Symbolic(r) if Belt.Result.isOk(SymbolicDist.T.operate(fnName, r)) =>
|
||||||
switch SymbolicDist.T.operate(fnName, r) {
|
switch SymbolicDist.T.operate(fnName, r) {
|
||||||
|
@ -104,7 +99,7 @@ module T = {
|
||||||
kernelWidth: None,
|
kernelWidth: None,
|
||||||
}
|
}
|
||||||
|
|
||||||
let toPointSet = (xyPointLength, t: t) => {
|
let toPointSet = (xyPointLength, t: t): result<PointSetTypes.pointSetDist, error> => {
|
||||||
switch t {
|
switch t {
|
||||||
| #PointSet(pointSet) => Ok(pointSet)
|
| #PointSet(pointSet) => Ok(pointSet)
|
||||||
| #Symbolic(r) => Ok(SymbolicDist.T.toPointSetDist(xyPointLength, r))
|
| #Symbolic(r) => Ok(SymbolicDist.T.toPointSetDist(xyPointLength, r))
|
||||||
|
@ -122,16 +117,82 @@ module T = {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let algebraicCombination = (operation, sampleCount, dist1: t, dist2: t) => {
|
module AlgebraicCombination = {
|
||||||
let dist1 = sampleN(sampleCount, dist1)
|
let tryAnalyticalSimplification = (operation: OperationType.combination, t1: t, t2: t): option<
|
||||||
let dist2 = sampleN(sampleCount, dist2)
|
result<SymbolicDistTypes.symbolicDist, string>,
|
||||||
let samples = E.R.merge(dist1, dist2) |> E.R.fmap(((d1, d2)) => {
|
> =>
|
||||||
Belt.Array.zip(d1, d2) |> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(operation, a, b))
|
switch (operation, t1, t2) {
|
||||||
})
|
| (operation, #Symbolic(d1), #Symbolic(d2)) =>
|
||||||
samples |> E.R.fmap(r => #SampleSet(r))
|
switch SymbolicDist.T.tryAnalyticalSimplification(d1, d2, operation) {
|
||||||
|
| #AnalyticalSolution(symbolicDist) => Some(Ok(symbolicDist))
|
||||||
|
| #Error(er) => Some(Error(er))
|
||||||
|
| #NoSolution => None
|
||||||
|
}
|
||||||
|
| _ => None
|
||||||
|
}
|
||||||
|
|
||||||
|
let runConvolution = (
|
||||||
|
toPointSet: toPointSetFn,
|
||||||
|
operation: OperationType.combination,
|
||||||
|
t1: t,
|
||||||
|
t2: t,
|
||||||
|
) =>
|
||||||
|
E.R.merge(toPointSet(t1), toPointSet(t2)) |> E.R.fmap(((a, b)) =>
|
||||||
|
PointSetDist.combineAlgebraically(operation, a, b)
|
||||||
|
)
|
||||||
|
|
||||||
|
let runMonteCarlo = (
|
||||||
|
toSampleSet: toSampleSetFn,
|
||||||
|
operation: OperationType.combination,
|
||||||
|
t1: t,
|
||||||
|
t2: t,
|
||||||
|
) => {
|
||||||
|
E.R.merge(toSampleSet(t1), toSampleSet(t2)) |> E.R.fmap(((a, b)) => {
|
||||||
|
Belt.Array.zip(a, b) |> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(operation, a, b))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
//I'm (Ozzie) really just guessing here, very little idea what's best
|
||||||
|
let expectedConvolutionCost: t => int = x =>
|
||||||
|
switch x {
|
||||||
|
| #Symbolic(#Float(_)) => 1
|
||||||
|
| #Symbolic(_) => 1000
|
||||||
|
| #PointSet(Discrete(m)) => m.xyShape |> XYShape.T.length
|
||||||
|
| #PointSet(Mixed(_)) => 1000
|
||||||
|
| #PointSet(Continuous(_)) => 1000
|
||||||
|
| _ => 1000
|
||||||
|
}
|
||||||
|
|
||||||
|
let chooseConvolutionOrMonteCarlo = (t1: t, t2: t) =>
|
||||||
|
expectedConvolutionCost(t1) * expectedConvolutionCost(t2) > 10000
|
||||||
|
? #CalculateWithMonteCarlo
|
||||||
|
: #CalculateWithConvolution
|
||||||
|
|
||||||
|
let run = (
|
||||||
|
toPointSet: toPointSetFn,
|
||||||
|
toSampleSet: toSampleSetFn,
|
||||||
|
algebraicOp,
|
||||||
|
t1: t,
|
||||||
|
t2: t,
|
||||||
|
): result<t, error> => {
|
||||||
|
switch tryAnalyticalSimplification(algebraicOp, t1, t2) {
|
||||||
|
| Some(Ok(symbolicDist)) => Ok(#Symbolic(symbolicDist))
|
||||||
|
| Some(Error(e)) => Error(Other(e))
|
||||||
|
| None =>
|
||||||
|
switch chooseConvolutionOrMonteCarlo(t1, t2) {
|
||||||
|
| #CalculateWithMonteCarlo =>
|
||||||
|
runMonteCarlo(toSampleSet, algebraicOp, t1, t2) |> E.R.fmap(r => #SampleSet(r))
|
||||||
|
| #CalculateWithConvolution =>
|
||||||
|
runConvolution(toPointSet, algebraicOp, t1, t2) |> E.R.fmap(r => #PointSet(r))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let pointwiseCombination = (toPointSet: toPointSetFn, operation, t1: t, t2: t) => {
|
let pointwiseCombination = (toPointSet: toPointSetFn, operation, t1: t, t2: t): result<
|
||||||
|
t,
|
||||||
|
error,
|
||||||
|
> => {
|
||||||
E.R.merge(toPointSet(t1), toPointSet(t2))
|
E.R.merge(toPointSet(t1), toPointSet(t2))
|
||||||
|> E.R.fmap(((t1, t2)) =>
|
|> E.R.fmap(((t1, t2)) =>
|
||||||
PointSetDist.combinePointwise(OperationType.combinationToFn(operation), t1, t2)
|
PointSetDist.combinePointwise(OperationType.combinationToFn(operation), t1, t2)
|
||||||
|
@ -144,11 +205,12 @@ module T = {
|
||||||
operation: OperationType.combination,
|
operation: OperationType.combination,
|
||||||
t: t,
|
t: t,
|
||||||
f: float,
|
f: float,
|
||||||
) => {
|
): result<t, error> => {
|
||||||
switch operation {
|
switch operation {
|
||||||
| #Add | #Subtract => Error(DistributionVerticalShiftIsInvalid)
|
| #Add | #Subtract => Error(DistributionVerticalShiftIsInvalid)
|
||||||
| (#Multiply | #Divide | #Exponentiate | #Log) as operation =>
|
| (#Multiply | #Divide | #Exponentiate | #Log) as operation =>
|
||||||
toPointSet(t) |> E.R.fmap(t => {
|
toPointSet(t) |> E.R.fmap(t => {
|
||||||
|
//TODO: Move to PointSet codebase
|
||||||
let fn = (secondary, main) => Operation.Scale.toFn(operation, main, secondary)
|
let fn = (secondary, main) => Operation.Scale.toFn(operation, main, secondary)
|
||||||
let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(operation)
|
let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(operation)
|
||||||
let integralCacheFn = Operation.Scale.toIntegralCacheFn(operation)
|
let integralCacheFn = Operation.Scale.toIntegralCacheFn(operation)
|
||||||
|
@ -159,7 +221,7 @@ module T = {
|
||||||
t,
|
t,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
}
|
} |> E.R.fmap(r => #PointSet(r))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -188,10 +250,10 @@ module OmniRunner = {
|
||||||
| Error(e) => #Error(e)
|
| Error(e) => #Error(e)
|
||||||
}
|
}
|
||||||
|
|
||||||
let rec applyFnInternal = (wrapped: wrapped, fnName: operation): outputType => {
|
let rec run = (wrapped: wrapped, fnName: operation): outputType => {
|
||||||
let (value, {sampleCount, xyPointLength} as extra) = wrapped
|
let (value, {sampleCount, xyPointLength} as extra) = wrapped
|
||||||
let reCall = (~value=value, ~extra=extra, ~fnName=fnName, ()) => {
|
let reCall = (~value=value, ~extra=extra, ~fnName=fnName, ()) => {
|
||||||
applyFnInternal((value, extra), fnName)
|
run((value, extra), fnName)
|
||||||
}
|
}
|
||||||
let toPointSet = r => {
|
let toPointSet = r => {
|
||||||
switch reCall(~value=r, ~fnName=#toDist(#toPointSet), ()) {
|
switch reCall(~value=r, ~fnName=#toDist(#toPointSet), ()) {
|
||||||
|
@ -200,8 +262,14 @@ module OmniRunner = {
|
||||||
| _ => Error(Other("Impossible error"))
|
| _ => Error(Other("Impossible error"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let toPointSetAndReCall = v => toPointSet(v) |> E.R.fmap(r => reCall(~value=#PointSet(r), ()))
|
let toSampleSet = r => {
|
||||||
let newVal: outputType = switch (fnName, value) {
|
switch reCall(~value=r, ~fnName=#toDist(#toSampleSet(sampleCount)), ()) {
|
||||||
|
| #Dist(#SampleSet(p)) => Ok(p)
|
||||||
|
| #Error(r) => Error(r)
|
||||||
|
| _ => Error(Other("Impossible error"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
switch (fnName, value) {
|
||||||
// | (#toFloat(n), v) => toFloat(toPointSet, v, n)
|
// | (#toFloat(n), v) => toFloat(toPointSet, v, n)
|
||||||
| (#toFloat(fnName), _) =>
|
| (#toFloat(fnName), _) =>
|
||||||
T.toFloat(toPointSet, fnName, value) |> E.R.fmap(r => #Float(r)) |> fromResult
|
T.toFloat(toPointSet, fnName, value) |> E.R.fmap(r => #Float(r)) |> fromResult
|
||||||
|
@ -214,17 +282,16 @@ module OmniRunner = {
|
||||||
value |> T.sampleN(n) |> E.R.fmap(r => #Dist(#SampleSet(r))) |> fromResult
|
value |> T.sampleN(n) |> E.R.fmap(r => #Dist(#SampleSet(r))) |> fromResult
|
||||||
| (#toDistCombination(#Algebraic, _, #Float(_)), _) => #Error(NotYetImplemented)
|
| (#toDistCombination(#Algebraic, _, #Float(_)), _) => #Error(NotYetImplemented)
|
||||||
| (#toDistCombination(#Algebraic, operation, #Dist(p2)), p1) =>
|
| (#toDistCombination(#Algebraic, operation, #Dist(p2)), p1) =>
|
||||||
T.algebraicCombination(operation, sampleCount, p1, p2)
|
T.AlgebraicCombination.run(toPointSet, toSampleSet, operation, p1, p2)
|
||||||
|> E.R.fmap(r => #Dist(r))
|
|> E.R.fmap(r => #Dist(r))
|
||||||
|> fromResult
|
|> fromResult
|
||||||
| (#toDistCombination(#Pointwise, operation, #Dist(p2)), p1) =>
|
| (#toDistCombination(#Pointwise, operation, #Dist(p2)), p1) =>
|
||||||
T.pointwiseCombination(toPointSet, operation, p1, p2) |> E.R.fmap(r => #Dist(r)) |> fromResult
|
T.pointwiseCombination(toPointSet, operation, p1, p2) |> E.R.fmap(r => #Dist(r)) |> fromResult
|
||||||
| (#toDistCombination(#Pointwise, operation, #Float(f)), _) =>
|
| (#toDistCombination(#Pointwise, operation, #Float(f)), _) =>
|
||||||
T.pointwiseCombinationFloat(toPointSet, operation, value, f)
|
T.pointwiseCombinationFloat(toPointSet, operation, value, f)
|
||||||
|> E.R.fmap(r => #Dist(#PointSet(r)))
|
|> E.R.fmap(r => #Dist(r))
|
||||||
|> fromResult
|
|> fromResult
|
||||||
}
|
}
|
||||||
newVal
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user