2022-03-27 20:59:46 +00:00
//TODO: multimodal, add interface, test somehow, track performance, refactor sampleSet, refactor ASTEvaluator.res.
2022-03-28 12:39:07 +00:00
// The generic distribution type and its error type, re-exported from GenericDist_Types.
type t = GenericDist_Types.genericDist
type error = GenericDist_Types.error

// Signatures of the conversion / combination functions that callers inject
// into the operations below.
type toSampleSetFn = t => result<array<float>, error>
type toPointSetFn = t => result<PointSetTypes.pointSetDist, error>
type pointwiseAddFn = (t, t) => result<t, error>
type scaleMultiplyFn = (t, float) => result<t, error>
2022-03-27 18:22:26 +00:00
2022-03-29 18:36:54 +00:00
// Draws `n` samples from a distribution.
// Point-set and symbolic distributions are supported; sample sets currently
// return a NotYetImplemented error.
let sampleN = (t: t, n) =>
  switch t {
  | #SampleSet(_) => Error(GenericDist_Types.NotYetImplemented)
  | #Symbolic(symbolic) => Ok(SymbolicDist.T.sampleN(n, symbolic))
  | #PointSet(pointSet) => Ok(PointSetDist.sampleNRendered(n, pointSet))
  }
2022-03-27 21:37:27 +00:00
// Wraps a plain number as a symbolic `Float` distribution (built with SymbolicDist.Float.make).
let fromFloat = (f: float) => {
  let floatDist = SymbolicDist.Float.make(f)
  #Symbolic(floatDist)
}
2022-03-27 18:22:26 +00:00
// Human-readable description: symbolic distributions describe themselves via
// SymbolicDist.T.toString; the other representations get a fixed label.
let toString = (t: t) =>
  switch t {
  | #Symbolic(symbolic) => SymbolicDist.T.toString(symbolic)
  | #SampleSet(_) => "Sample Set Distribution"
  | #PointSet(_) => "Point Set Distribution"
  }
2022-03-27 21:37:27 +00:00
// Normalizes point-set distributions with PointSetDist.T.normalize.
// Symbolic and sample-set distributions are returned unchanged.
let normalize = (t: t) =>
  switch t {
  | #Symbolic(_) | #SampleSet(_) => t
  | #PointSet(pointSet) => #PointSet(PointSetDist.T.normalize(pointSet))
  }
2022-03-31 18:07:39 +00:00
// Computes a float-valued operation (`distToFloatOperation`) on a distribution.
//
// Symbolic distributions are first evaluated with SymbolicDist.T.operate. If
// that returns an error, or the distribution is not symbolic, the distribution
// is converted with `toPointSetFn` and evaluated with PointSetDist.operate.
let toFloatOperation = (
  t,
  ~toPointSetFn: toPointSetFn,
  ~distToFloatOperation: Operation.distToFloatOperation,
) => {
  // Fallback path; only evaluated when the analytical route is unavailable.
  let viaPointSet = () => toPointSetFn(t)->E.R2.fmap(PointSetDist.operate(distToFloatOperation))

  switch t {
  | #Symbolic(symbolic) =>
    switch SymbolicDist.T.operate(distToFloatOperation, symbolic) {
    | Ok(value) => Ok(value)
    | _ => viaPointSet()
    }
  | _ => viaPointSet()
  }
}
2022-03-29 19:21:38 +00:00
// Converts a generic distribution into a point-set distribution.
// - Point sets are returned as-is.
// - Symbolic distributions are discretized with SymbolicDist.T.toPointSetDist at `xyPointLength`.
// - Sample sets go through SampleSet.toPointSetDist, using `sampleCount` and
//   `xyPointLength` and no explicit kernel width. An error is returned if that
//   conversion yields no point set.
//
// TODO: If it's a pointSet, but the xyPointLength is different from what it has, it should change.
// This is tricky because of the case of discrete distributions.
// Also, change the outputXYPoints/pointSetDistLength details.
let toPointSet = (~xyPointLength, ~sampleCount, t): result<PointSetTypes.pointSetDist, error> =>
  switch t {
  | #Symbolic(symbolic) => Ok(SymbolicDist.T.toPointSetDist(xyPointLength, symbolic))
  | #SampleSet(samples) => {
      let converted = SampleSet.toPointSetDist(
        ~samples,
        ~samplingInputs={
          sampleCount: sampleCount,
          outputXYPoints: xyPointLength,
          pointSetDistLength: xyPointLength,
          kernelWidth: None,
        },
        (),
      ).pointSetDist
      switch converted {
      | None => Error(Other("Converting sampleSet to pointSet failed"))
      | Some(pointSet) => Ok(pointSet)
      }
    }
  | #PointSet(pointSet) => Ok(pointSet)
  }
module Truncate = {
  // Returns a symbolic truncation of `t` when one is available, or None when the
  // caller should fall back to point-set truncation.
  //
  // Only symbolic uniform distributions are simplified. Either cutoff may be
  // absent. When both are present they must satisfy left < right; otherwise we
  // decline and let the generic path handle it.
  //
  // The cutoffs are compared as floats, and only when both exist. Comparing the
  // options directly (polymorphic `<` on option<float>) orders None below every
  // Some(_). That accepted right-only cutoffs but rejected left-only ones.
  let trySymbolicSimplification = (leftCutoff, rightCutoff, t): option<t> =>
    switch (leftCutoff, rightCutoff, t) {
    | (None, None, _) => None
    // Empty or inverted interval (NaN also lands here): no symbolic shortcut.
    | (Some(lc), Some(rc), _) if !(lc < rc) => None
    | (lc, rc, #Symbolic(#Uniform(u))) =>
      Some(#Symbolic(#Uniform(SymbolicDist.Uniform.truncate(lc, rc, u))))
    | _ => None
    }

  // Truncates `t` to [leftCutoff, rightCutoff]; either bound may be omitted.
  // With no bounds, `t` is returned unchanged. Otherwise a symbolic truncation is
  // tried first. Failing that, `t` is converted with `toPointSetFn` and truncated
  // with PointSetDist.T.truncate, producing a #PointSet.
  let run = (
    t: t,
    ~toPointSetFn: toPointSetFn,
    ~leftCutoff=None: option<float>,
    ~rightCutoff=None: option<float>,
    (),
  ): result<t, error> =>
    switch (leftCutoff, rightCutoff) {
    | (None, None) => Ok(t)
    | _ =>
      switch trySymbolicSimplification(leftCutoff, rightCutoff, t) {
      | Some(truncated) => Ok(truncated)
      | None =>
        toPointSetFn(t)->E.R2.fmap(pointSet =>
          #PointSet(PointSetDist.T.truncate(leftCutoff, rightCutoff, pointSet))
        )
      }
    }
}
2022-03-28 11:56:20 +00:00
// Public entry point for truncation. Truncate.run tries a symbolic truncation
// first and falls back to truncating the point-set conversion of the distribution.
let truncate = Truncate.run
2022-03-27 18:22:26 +00:00
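/* Given two random variables A and B, this returns the distribution
   of a new variable that is the result of the operation on A and B.
   For instance, with normals parameterized by (mean, stdev) and independent inputs,
   normal(0, 1) + normal(1, 1) -> normal(1, sqrt(2)): means add, and variances add.
   When both inputs are symbolic, an analytical solution is tried first; otherwise the
   result is computed either by convolution or by Monte Carlo sampling, chosen by a rough cost estimate.
   TODO: It would be useful to be able to pass in a parameter to choose explicitly between convolution and Monte Carlo.
 */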
2022-03-27 18:22:26 +00:00
module AlgebraicCombination = {
  // Tries to combine two symbolic distributions in closed form.
  // Returns None when either operand is not symbolic, or when SymbolicDist finds
  // no analytical solution. Some(Error(_)) carries SymbolicDist's error message.
  let tryAnalyticalSimplification = (
    arithmeticOperation: GenericDist_Types.Operation.arithmeticOperation,
    t1: t,
    t2: t,
  ): option<result<SymbolicDistTypes.symbolicDist, string>> =>
    switch (t1, t2) {
    | (#Symbolic(left), #Symbolic(right)) =>
      switch SymbolicDist.T.tryAnalyticalSimplification(left, right, arithmeticOperation) {
      | #NoSolution => None
      | #Error(message) => Some(Error(message))
      | #AnalyticalSolution(solution) => Some(Ok(solution))
      }
    | _ => None
    }

  // Converts both operands to point sets and combines them with
  // PointSetDist.combineAlgebraically.
  let runConvolution = (
    toPointSet: toPointSetFn,
    arithmeticOperation: GenericDist_Types.Operation.arithmeticOperation,
    t1: t,
    t2: t,
  ) => {
    let combine = ((left, right)) =>
      PointSetDist.combineAlgebraically(arithmeticOperation, left, right)
    E.R.merge(toPointSet(t1), toPointSet(t2))->E.R2.fmap(combine)
  }

  // Converts both operands to sample arrays and applies the operation elementwise
  // to paired samples. The output length is the shorter of the two inputs.
  let runMonteCarlo = (
    toSampleSet: toSampleSetFn,
    arithmeticOperation: GenericDist_Types.Operation.arithmeticOperation,
    t1: t,
    t2: t,
  ) => {
    let operationFn = Operation.Algebraic.toFn(arithmeticOperation)
    E.R.merge(toSampleSet(t1), toSampleSet(t2))->E.R2.fmap(((samplesA, samplesB)) =>
      Belt.Array.zipBy(samplesA, samplesB, operationFn)
    )
  }

  // Rough relative cost of convolving with this operand. These numbers are guesses
  // by the original author (Ozzie), not measurements.
  let expectedConvolutionCost: t => int = dist =>
    switch dist {
    | #Symbolic(#Float(_)) => 1
    | #PointSet(Discrete(discrete)) => discrete.xyShape->XYShape.T.length
    | _ => 1000
    }

  // Monte Carlo is chosen once the product of both operands' costs exceeds 10000.
  let chooseConvolutionOrMonteCarlo = (first: t, second: t) => {
    let combinedCost = expectedConvolutionCost(first) * expectedConvolutionCost(second)
    if combinedCost > 10000 {
      #CalculateWithMonteCarlo
    } else {
      #CalculateWithConvolution
    }
  }

  // Combines two distributions algebraically. An analytical solution is preferred;
  // otherwise it uses convolution (yielding #PointSet) or Monte Carlo (yielding
  // #SampleSet), depending on the estimated cost.
  let run = (
    t1: t,
    ~toPointSetFn: toPointSetFn,
    ~toSampleSetFn: toSampleSetFn,
    ~arithmeticOperation,
    ~t2: t,
  ): result<t, error> =>
    switch tryAnalyticalSimplification(arithmeticOperation, t1, t2) {
    | Some(Error(message)) => Error(Other(message))
    | Some(Ok(solution)) => Ok(#Symbolic(solution))
    | None =>
      switch chooseConvolutionOrMonteCarlo(t1, t2) {
      | #CalculateWithConvolution =>
        runConvolution(toPointSetFn, arithmeticOperation, t1, t2)->E.R2.fmap(pointSet =>
          #PointSet(pointSet)
        )
      | #CalculateWithMonteCarlo =>
        runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2)->E.R2.fmap(samples =>
          #SampleSet(samples)
        )
      }
    }
}
2022-03-28 11:56:20 +00:00
// Public entry point for algebraic combination. AlgebraicCombination.run tries an
// analytical symbolic solution first, then picks convolution or Monte Carlo by a
// rough cost estimate.
let algebraicCombination = AlgebraicCombination.run
2022-03-27 18:22:26 +00:00
//TODO: Add faster pointwiseCombine fn
// Converts both operands with `toPointSetFn` and combines them with
// PointSetDist.combinePointwise. The result is always a #PointSet.
let pointwiseCombination = (
  t1: t,
  ~toPointSetFn: toPointSetFn,
  ~arithmeticOperation,
  ~t2: t,
): result<t, error> => {
  let bothPointSets = E.R.merge(toPointSetFn(t1), toPointSetFn(t2))
  bothPointSets->E.R2.fmap(((left, right)) => {
    let combined = PointSetDist.combinePointwise(
      GenericDist_Types.Operation.arithmeticToFn(arithmeticOperation),
      left,
      right,
    )
    #PointSet(combined)
  })
}
// Applies a scalar operation to a distribution's y-values with PointSetDist.T.mapY,
// after converting it with `toPointSetFn`.
// #Add and #Subtract are rejected with DistributionVerticalShiftIsInvalid;
// #Multiply, #Divide, #Exponentiate and #Log are applied with `float` as the
// secondary operand.
let pointwiseCombinationFloat = (
  t: t,
  ~toPointSetFn: toPointSetFn,
  ~arithmeticOperation: GenericDist_Types.Operation.arithmeticOperation,
  ~float: float,
): result<t, error> =>
  switch arithmeticOperation {
  | #Add | #Subtract => Error(GenericDist_Types.DistributionVerticalShiftIsInvalid)
  | (#Multiply | #Divide | #Exponentiate | #Log) as scaleOperation =>
    toPointSetFn(t)->E.R2.fmap(pointSet => {
      //TODO: Move to PointSet codebase
      let sumCacheFn = Operation.Scale.toIntegralSumCacheFn(scaleOperation)
      let cacheFn = Operation.Scale.toIntegralCacheFn(scaleOperation)
      let scaled = PointSetDist.T.mapY(
        ~integralSumCacheFn=sumCacheFn(float),
        ~integralCacheFn=cacheFn(float),
        // Each y-value is the main operand; `float` is the secondary one.
        ~fn=y => Operation.Scale.toFn(scaleOperation, y, float),
        pointSet,
      )
      #PointSet(scaled)
    })
  }
2022-03-27 21:37:27 +00:00
2022-03-29 21:35:33 +00:00
// Builds a weighted mixture of distributions.
//
// Each weight is divided by the total weight, and each distribution is scaled by
// its normalized weight with `scaleMultiplyFn`. The scaled results are then summed
// left to right with `pointwiseAddFn`. The first error from either step is returned.
//
// Note: The result should always cumulatively sum to 1. This would be good to test.
// Note: If the inputs are not normalized, this will return poor results. The weights
// probably refer to the post-normalized forms. It would be good to apply a catch to this.
//
// Returns an Error when:
// - `values` is empty;
// - the weights do not sum to a finite, positive number (e.g. all zero), since
//   dividing by such a total would yield NaN/Infinity scale factors.
let mixture = (
  values: array<(t, float)>,
  ~scaleMultiplyFn: scaleMultiplyFn,
  ~pointwiseAddFn: pointwiseAddFn,
) =>
  if E.A.length(values) == 0 {
    Error(GenericDist_Types.Other("mixture must have at least 1 element"))
  } else {
    let totalWeight = values->E.A2.fmap(E.Tuple2.second)->E.A.Floats.sum
    // `!(x > 0.0)` also rejects NaN.
    if !(Js.Float.isFinite(totalWeight) && totalWeight > 0.0) {
      Error(GenericDist_Types.Other("mixture weights must sum to a positive, finite number"))
    } else {
      values
      ->E.A2.fmap(((dist, weight)) => scaleMultiplyFn(dist, weight /. totalWeight))
      ->E.A.R.firstErrorOrOpen
      ->E.R.bind(scaled =>
        // Non-empty is guaranteed above, so index 0 is safe.
        scaled
        |> Js.Array.sliceFrom(1)
        |> E.A.fold_left(
          (acc, x) => E.R.bind(acc, acc => pointwiseAddFn(acc, x)),
          Ok(E.A.unsafe_get(scaled, 0)),
        )
      )
    }
  }