214 lines
6.5 KiB
Plaintext
214 lines
6.5 KiB
Plaintext
|
// TODO: multimodal, add interface, split up a little bit, test somehow, track performance, refactor sampleSet, refactor ASTEvaluator.res.

// Local aliases for the shared GenericDist_Types definitions, so the rest of
// this module can refer to them unqualified.
type genericDist = GenericDist_Types.genericDist
type error = GenericDist_Types.error

// Function type: turn any generic distribution into a point-set distribution.
type toPointSetFn = genericDist => result<PointSetTypes.pointSetDist, error>

// Function type: turn any generic distribution into an array of float samples.
type toSampleSetFn = genericDist => result<array<float>, error>

// The primary type of this module.
type t = genericDist
|
||
|
|
||
|
// Draws `n` samples from the distribution.
// Sample-set distributions are not supported yet and yield NotYetImplemented.
let sampleN = (n, t: t) =>
  switch t {
  | #SampleSet(_) => Error(GenericDist_Types.NotYetImplemented)
  | #Symbolic(dist) => Ok(SymbolicDist.T.sampleN(n, dist))
  | #PointSet(dist) => Ok(PointSetDist.sampleNRendered(n, dist))
  }
|
||
|
|
||
|
// Human-readable label for a distribution. Only symbolic distributions
// delegate to their own printer; the other variants get a fixed label.
let toString = (t: t) =>
  switch t {
  | #Symbolic(dist) => SymbolicDist.T.toString(dist)
  | #SampleSet(_) => "Sample Set Distribution"
  | #PointSet(_) => "Point Set Distribution"
  }
|
||
|
|
||
|
// Normalizes point-set distributions via PointSetDist.T.normalize;
// symbolic and sample-set distributions are returned unchanged.
let normalize = (t: t) =>
  switch t {
  | #Symbolic(_) | #SampleSet(_) => t
  | #PointSet(pointSet) => #PointSet(PointSetDist.T.normalize(pointSet))
  }
|
||
|
|
||
|
|
||
|
// Reduces a distribution to a single float using the operation `fnName`.
// Symbolic distributions are first tried with SymbolicDist.T.operate; if that
// returns anything other than Ok (or `t` is not symbolic), `t` is converted with
// `toPointSet` and the operation is run by PointSetDist.operate instead.
let operationToFloat = (toPointSet: toPointSetFn, fnName, t: genericDist): result<float, error> => {
  // Only invoked when no symbolic answer is available.
  let viaPointSet = () => toPointSet(t) |> E.R.fmap(PointSetDist.operate(fnName))
  switch t {
  | #Symbolic(dist) =>
    switch SymbolicDist.T.operate(fnName, dist) {
    | Ok(value) => Ok(value)
    | _ => viaPointSet()
    }
  | _ => viaPointSet()
  }
}
|
||
|
|
||
|
// TODO: Refactor this bit.
// Sampling settings used when a sample set has to be turned into a point set.
let defaultSamplingInputs: SamplingInputs.samplingInputs = {
  kernelWidth: None,
  pointSetDistLength: 1000,
  outputXYPoints: 10000,
  sampleCount: 10000,
}
|
||
|
|
||
|
// Converts any generic distribution into a point-set distribution.
//   #PointSet  - returned unchanged.
//   #Symbolic  - discretized by SymbolicDist.T.toPointSetDist with `xyPointLength`.
//   #SampleSet - converted by SampleSet.toPointSetDist using defaultSamplingInputs,
//                but with `outputXYPoints` overridden by `xyPointLength`.
// Previously the sample-set branch ignored `xyPointLength` and always used the
// default 10000 output points, so the caller's requested resolution only took
// effect for symbolic inputs.
// Returns Error(Other(...)) when the sample-set conversion yields no point set.
let toPointSet = (xyPointLength, t: t): result<PointSetTypes.pointSetDist, error> => {
  switch t {
  | #PointSet(pointSet) => Ok(pointSet)
  | #Symbolic(r) => Ok(SymbolicDist.T.toPointSetDist(xyPointLength, r))
  | #SampleSet(samples) => {
      // Annotated so the record spread resolves to SamplingInputs' record type.
      let samplingInputs: SamplingInputs.samplingInputs = {
        ...defaultSamplingInputs,
        outputXYPoints: xyPointLength,
      }
      let response = SampleSet.toPointSetDist(~samples, ~samplingInputs, ()).pointSetDist
      switch response {
      | Some(pointSet) => Ok(pointSet)
      | None => Error(Other("Converting sampleSet to pointSet failed"))
      }
    }
  }
}
|
||
|
|
||
|
module Truncate = {
  // Tries to truncate `t` exactly (symbolically). Returns None when the caller
  // has to fall back to truncating a point-set approximation instead.
  // Only uniform distributions are simplified. A one-sided cutoff is always
  // eligible; a two-sided cutoff is eligible only when leftCutoff < rightCutoff.
  //
  // The cutoff *values* are compared explicitly. The previous guard compared the
  // option<float> values directly, and polymorphic compare ranks None below every
  // Some(_). So (None, Some(r)) passed the guard while (Some(l), None) failed it,
  // and a left-only cutoff on a uniform never got the symbolic path.
  // NOTE(review): relies on SymbolicDist.Uniform.truncate accepting None as
  // "no cutoff" on either side; the existing (None, Some(_)) path already did.
  let trySymbolicSimplification = (leftCutoff, rightCutoff, t): option<t> => {
    let cutoffsAreOrdered = switch (leftCutoff, rightCutoff) {
    | (Some(lc), Some(rc)) => lc < rc
    | (Some(_), None) | (None, Some(_)) | (None, None) => true
    }
    switch (leftCutoff, rightCutoff, t) {
    | (None, None, _) => None
    | (lc, rc, #Symbolic(#Uniform(u))) if cutoffsAreOrdered =>
      Some(#Symbolic(#Uniform(SymbolicDist.Uniform.truncate(lc, rc, u))))
    | _ => None
    }
  }

  // Truncates `t` to the given cutoffs.
  // With no cutoffs at all, `t` is returned unchanged. Otherwise a symbolic
  // simplification is attempted first. Failing that, `t` is converted with
  // `toPointSet` and truncated by PointSetDist.T.truncate, yielding a #PointSet.
  let run = (
    toPointSet: toPointSetFn,
    leftCutoff: option<float>,
    rightCutoff: option<float>,
    t: t,
  ): result<t, error> => {
    let doesNotNeedCutoff = E.O.isNone(leftCutoff) && E.O.isNone(rightCutoff)
    if doesNotNeedCutoff {
      Ok(t)
    } else {
      switch trySymbolicSimplification(leftCutoff, rightCutoff, t) {
      | Some(r) => Ok(r)
      | None =>
        toPointSet(t) |> E.R.fmap(pointSet =>
          #PointSet(PointSetDist.T.truncate(leftCutoff, rightCutoff, pointSet))
        )
      }
    }
  }
}
|
||
|
|
||
|
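/* Given two random variables A and B, this returns the distribution
of a new variable that is the result of the operation on A and B.
For instance, normal(0, 1) + normal(1, 1) -> normal(1, sqrt(2)), taking the
second parameter as the standard deviation (the variances add).
An analytical solution is tried first; otherwise this falls back to either
convolution or Monte Carlo sampling, depending on the estimated cost. */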
|
||
|
module AlgebraicCombination = {
  // Attempts a closed-form combination when both operands are symbolic.
  //   Some(Ok(dist))   - SymbolicDist found an analytical solution
  //   Some(Error(msg)) - SymbolicDist reported an error
  //   None             - no analytical solution, or an operand is not symbolic
  let tryAnalyticalSimplification = (
    operation: GenericDist_Types.Operation.arithmeticOperation,
    t1: t,
    t2: t,
  ): option<result<SymbolicDistTypes.symbolicDist, string>> =>
    switch (operation, t1, t2) {
    | (operation, #Symbolic(d1), #Symbolic(d2)) =>
      switch SymbolicDist.T.tryAnalyticalSimplification(d1, d2, operation) {
      | #AnalyticalSolution(symbolicDist) => Some(Ok(symbolicDist))
      | #Error(er) => Some(Error(er))
      | #NoSolution => None
      }
    | _ => None
    }

  // Converts both operands to point sets and combines them with
  // PointSetDist.combineAlgebraically. Fails if either conversion fails.
  let runConvolution = (
    toPointSet: toPointSetFn,
    operation: GenericDist_Types.Operation.arithmeticOperation,
    t1: t,
    t2: t,
  ) =>
    E.R.merge(toPointSet(t1), toPointSet(t2)) |> E.R.fmap(((a, b)) =>
      PointSetDist.combineAlgebraically(operation, a, b)
    )

  // Converts both operands to sample arrays, pairs them with Belt.Array.zip
  // (which stops at the shorter array) and applies the operation to each pair.
  let runMonteCarlo = (
    toSampleSet: toSampleSetFn,
    operation: GenericDist_Types.Operation.arithmeticOperation,
    t1: t,
    t2: t,
  ) => {
    E.R.merge(toSampleSet(t1), toSampleSet(t2)) |> E.R.fmap(((a, b)) => {
      Belt.Array.zip(a, b) |> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(operation, a, b))
    })
  }

  // Rough relative cost of convolving with this distribution.
  // I'm (Ozzie) really just guessing here, very little idea what's best.
  // Discrete point sets cost their number of points; the rest are flat guesses.
  let expectedConvolutionCost: t => int = x =>
    switch x {
    | #Symbolic(#Float(_)) => 1
    | #Symbolic(_) => 1000
    | #PointSet(Discrete(m)) => m.xyShape |> XYShape.T.length
    | #PointSet(Mixed(_)) => 1000
    | #PointSet(Continuous(_)) => 1000
    | _ => 1000
    }

  // Estimated combined cost above which Monte Carlo is preferred to convolution.
  let monteCarloCostThreshold = 10000.

  // Picks the combination strategy from the product of both operands' costs.
  // The product is symmetric, so argument order does not matter.
  let chooseConvolutionOrMonteCarlo = (t1: t, t2: t) => {
    // Multiply as floats: ReScript ints are 32-bit and int `*` wraps on overflow,
    // so two large discrete point sets could produce a negative product and
    // wrongly select the expensive convolution path.
    let estimatedCost =
      Belt.Int.toFloat(expectedConvolutionCost(t1)) *.
      Belt.Int.toFloat(expectedConvolutionCost(t2))
    estimatedCost > monteCarloCostThreshold ? #CalculateWithMonteCarlo : #CalculateWithConvolution
  }

  // Combines t1 and t2 with `algebraicOp`.
  // An analytical solution is tried first; SymbolicDist errors become Error(Other(_)).
  // Otherwise, depending on the estimated cost, the result is either a #SampleSet
  // (Monte Carlo) or a #PointSet (convolution).
  let run = (
    toPointSet: toPointSetFn,
    toSampleSet: toSampleSetFn,
    algebraicOp,
    t1: t,
    t2: t,
  ): result<t, error> => {
    switch tryAnalyticalSimplification(algebraicOp, t1, t2) {
    | Some(Ok(symbolicDist)) => Ok(#Symbolic(symbolicDist))
    | Some(Error(e)) => Error(Other(e))
    | None =>
      switch chooseConvolutionOrMonteCarlo(t1, t2) {
      | #CalculateWithMonteCarlo =>
        runMonteCarlo(toSampleSet, algebraicOp, t1, t2) |> E.R.fmap(r => #SampleSet(r))
      | #CalculateWithConvolution =>
        runConvolution(toPointSet, algebraicOp, t1, t2) |> E.R.fmap(r => #PointSet(r))
      }
    }
  }
}
|
||
|
|
||
|
// TODO: Add faster pointwiseCombine fn
// Converts both distributions to point sets and merges them with
// PointSetDist.combinePointwise, producing a #PointSet.
// Parameters are ordered for pipe-last use: in
// `t1 |> pointwiseCombination(toPointSet, operation, t2)` the piped value (the
// last argument) supplies the first operand passed to combinePointwise.
let pointwiseCombination = (toPointSet: toPointSetFn, operation, t2: t, t1: t): result<
  t,
  error,
> => {
  let bothPointSets = E.R.merge(toPointSet(t1), toPointSet(t2))
  bothPointSets
  |> E.R.fmap(((first, second)) => {
    let combinedFn = GenericDist_Types.Operation.arithmeticToFn(operation)
    PointSetDist.combinePointwise(combinedFn, first, second)
  })
  |> E.R.fmap(combined => #PointSet(combined))
}
|
||
|
|
||
|
// Applies `operation` with the scalar `f` to every y-value of `t`.
// #Add and #Subtract are rejected with DistributionVerticalShiftIsInvalid.
// #Multiply, #Divide, #Exponentiate and #Log convert `t` with `toPointSet` and
// map it through PointSetDist.T.mapY, using the matching Operation.Scale helpers.
// The result is always a #PointSet on success.
let pointwiseCombinationFloat = (
  toPointSet: toPointSetFn,
  operation: GenericDist_Types.Operation.arithmeticOperation,
  f: float,
  t: t,
): result<t, error> => {
  let scaled = switch operation {
  | #Add | #Subtract => Error(GenericDist_Types.DistributionVerticalShiftIsInvalid)
  | (#Multiply | #Divide | #Exponentiate | #Log) as scaleOp =>
    toPointSet(t) |> E.R.fmap(pointSet => {
      // TODO: Move to PointSet codebase
      let sumCacheFn = Operation.Scale.toIntegralSumCacheFn(scaleOp)
      let cacheFn = Operation.Scale.toIntegralCacheFn(scaleOp)
      // `f` is always the second argument of the scale operation.
      let yFn = main => Operation.Scale.toFn(scaleOp, main, f)
      PointSetDist.T.mapY(
        ~integralSumCacheFn=sumCacheFn(f),
        ~integralCacheFn=cacheFn(f),
        ~fn=yFn,
        pointSet,
      )
    })
  }
  scaled |> E.R.fmap(r => #PointSet(r))
}
|