2022-03-27 20:59:46 +00:00
|
|
|
//TODO: multimodal, add interface, test somehow, track performance, refactor sampleSet, refactor ASTEvaluator.res.
|
2022-03-28 12:39:07 +00:00
|
|
|
// Local aliases for the shared generic-distribution types.
type t = GenericDist_Types.genericDist
type error = GenericDist_Types.error

// Callback signatures. The functions in this file receive these operations as
// arguments instead of calling a concrete implementation directly.

// Converts a distribution into its point-set representation.
type toPointSetFn = (t) => result<PointSetTypes.pointSetDist, error>
// Produces an array of float samples from a distribution.
type toSampleSetFn = (t) => result<array<float>, error>
// Multiplies a distribution by a float factor (used by `mixture` to apply weights).
type scaleMultiplyFn = (t, float) => result<t, error>
// Adds two distributions pointwise (used by `mixture` to sum weighted parts).
type pointwiseAddFn = (t, t) => result<t, error>
|
2022-03-27 18:22:26 +00:00
|
|
|
|
2022-03-29 18:36:54 +00:00
|
|
|
// Draws `n` samples from `t`.
// Symbolic and point-set distributions delegate to their own samplers;
// sampling from a sample-set distribution is not implemented yet.
let sampleN = (t: t, n) =>
  switch t {
  | #SampleSet(_) => Error(GenericDist_Types.NotYetImplemented)
  | #Symbolic(dist) => Ok(SymbolicDist.T.sampleN(n, dist))
  | #PointSet(dist) => Ok(PointSetDist.sampleNRendered(n, dist))
  }
|
|
|
|
|
2022-03-27 21:37:27 +00:00
|
|
|
// Wraps a plain float as a symbolic distribution built by SymbolicDist.Float.make.
let fromFloat = (f: float) => {
  let dist = SymbolicDist.Float.make(f)
  #Symbolic(dist)
}
|
|
|
|
|
2022-03-27 18:22:26 +00:00
|
|
|
// Human-readable description of a distribution.
// Symbolic distributions delegate to SymbolicDist.T.toString; the other
// kinds are described with a fixed label.
let toString = (t: t) =>
  switch t {
  | #Symbolic(dist) => SymbolicDist.T.toString(dist)
  | #PointSet(_) => "Point Set Distribution"
  | #SampleSet(_) => "Sample Set Distribution"
  }
|
|
|
|
|
2022-03-27 21:37:27 +00:00
|
|
|
// Normalizes point-set distributions with PointSetDist.T.normalize.
// Symbolic and sample-set distributions are returned unchanged.
let normalize = (t: t) =>
  switch t {
  | #PointSet(pointSet) => #PointSet(PointSetDist.T.normalize(pointSet))
  | #Symbolic(_) | #SampleSet(_) => t
  }
|
|
|
|
|
2022-03-29 19:47:32 +00:00
|
|
|
// Computes the float-valued operation `fnName` on `t`.
// For symbolic distributions, SymbolicDist.T.operate is tried first. If it
// succeeds, its value is returned. If it fails, or `t` is not symbolic, `t` is
// converted with `toPointSet` and PointSetDist.operate is applied instead.
let operationToFloat = (t, toPointSet: toPointSetFn, fnName) => {
  // Fallback path; only evaluated when the symbolic attempt is unavailable.
  let viaPointSet = () => toPointSet(t)->E.R.fmap2(PointSetDist.operate(fnName))
  switch t {
  | #Symbolic(r) =>
    switch SymbolicDist.T.operate(fnName, r) {
    | Ok(f) => Ok(f)
    | _ => viaPointSet()
    }
  | _ => viaPointSet()
  }
}
|
|
|
|
|
|
|
|
// TODO: Refactor this bit; these sampling settings are hard-coded here.
// Fixed sampling parameters used when a sample set is converted into a
// point-set distribution (see the #SampleSet branch of toPointSet).
let defaultSamplingInputs: SamplingInputs.samplingInputs = {
  kernelWidth: None,
  pointSetDistLength: 1000,
  outputXYPoints: 10000,
  sampleCount: 10000,
}
|
|
|
|
|
2022-03-29 19:21:38 +00:00
|
|
|
// Converts any generic distribution into a point-set distribution.
//
// `xyPointLength` sets the resolution of the result for symbolic and
// sample-set inputs. Sample sets are converted with `defaultSamplingInputs`,
// with its output sizes overridden by `xyPointLength`. Previously the
// #SampleSet branch ignored `xyPointLength` and always used the defaults.
//
// TODO: If `t` is already a point set whose xyPoint length differs from
// `xyPointLength`, it should be resampled. This is tricky because of discrete
// distributions, so existing point sets are currently returned unchanged.
let toPointSet = (t, xyPointLength): result<PointSetTypes.pointSetDist, error> => {
  switch t {
  | #PointSet(pointSet) => Ok(pointSet)
  | #Symbolic(r) => Ok(SymbolicDist.T.toPointSetDist(xyPointLength, r))
  | #SampleSet(r) => {
      let samplingInputs: SamplingInputs.samplingInputs = {
        ...defaultSamplingInputs,
        outputXYPoints: xyPointLength,
        pointSetDistLength: xyPointLength,
      }
      let response = SampleSet.toPointSetDist(~samples=r, ~samplingInputs, ()).pointSetDist
      switch response {
      | Some(r) => Ok(r)
      | None => Error(Other("Converting sampleSet to pointSet failed"))
      }
    }
  }
}
|
|
|
|
|
|
|
|
module Truncate = {
  // Tries to truncate `t` while keeping it symbolic. Only uniform
  // distributions are handled, via SymbolicDist.Uniform.truncate.
  // Returns None when no symbolic shortcut applies, or when both cutoffs are
  // given but not strictly ordered. The caller then falls back to truncating
  // a point-set conversion.
  //
  // The cutoffs are matched explicitly rather than comparing the two options
  // with `<`. Polymorphic comparison orders None below every Some, so
  // `Some(lc) < None` is false: left-only cutoffs used to skip this path while
  // right-only cutoffs took it.
  let trySymbolicSimplification = (
    leftCutoff: option<float>,
    rightCutoff: option<float>,
    t: t,
  ): option<t> => {
    let cutoffsAreOrdered = switch (leftCutoff, rightCutoff) {
    | (Some(lc), Some(rc)) => lc < rc
    | _ => true
    }
    switch (leftCutoff, rightCutoff, t) {
    | (None, None, _) => None
    | (lc, rc, #Symbolic(#Uniform(u))) if cutoffsAreOrdered =>
      Some(#Symbolic(#Uniform(SymbolicDist.Uniform.truncate(lc, rc, u))))
    | _ => None
    }
  }

  // Truncates `t` using the optional `leftCutoff` / `rightCutoff`.
  // With no cutoffs, `t` is returned unchanged. Otherwise the symbolic
  // shortcut is tried first. Failing that, `t` is converted with `toPointSet`
  // and truncated with PointSetDist.T.truncate.
  let run = (
    t: t,
    toPointSet: toPointSetFn,
    leftCutoff: option<float>,
    rightCutoff: option<float>,
  ): result<t, error> => {
    let doesNotNeedCutoff = E.O.isNone(leftCutoff) && E.O.isNone(rightCutoff)
    if doesNotNeedCutoff {
      Ok(t)
    } else {
      switch trySymbolicSimplification(leftCutoff, rightCutoff, t) {
      | Some(truncated) => Ok(truncated)
      | None =>
        toPointSet(t)->E.R.fmap2(pointSet =>
          #PointSet(PointSetDist.T.truncate(leftCutoff, rightCutoff, pointSet))
        )
      }
    }
  }
}
|
|
|
|
|
2022-03-28 11:56:20 +00:00
|
|
|
// Public entry point for truncation; forwards all arguments to Truncate.run.
let truncate = (t, toPointSet, leftCutoff, rightCutoff) =>
  Truncate.run(t, toPointSet, leftCutoff, rightCutoff)
|
|
|
|
|
2022-03-27 18:22:26 +00:00
|
|
|
/* Given two random variables A and B, this returns the distribution
   of a new variable that is the result of the operation on A and B.
   For instance, normal(0, 1) + normal(1, 1) -> normal(1, sqrt(2)).

   An analytical (symbolic) result is used when one is available. Otherwise
   the result is computed numerically, either by convolution or by Monte
   Carlo sampling; the method is picked by a rough cost heuristic
   (chooseConvolutionOrMonteCarlo).

   TODO: It would be useful to be able to pass in a parameter to force this
   to run with either convolution or Monte Carlo.
*/
|
2022-03-27 18:22:26 +00:00
|
|
|
// Algebraic combination of two distributions: analytical shortcut first,
// then numerical convolution or Monte Carlo sampling.
module AlgebraicCombination = {
  // Attempts a closed-form result when both operands are symbolic.
  //   None             -> no shortcut (an operand is not symbolic, or
  //                       SymbolicDist reports #NoSolution)
  //   Some(Ok(d))      -> the combined symbolic distribution
  //   Some(Error(msg)) -> SymbolicDist reported an error
  let tryAnalyticalSimplification = (
    operation: GenericDist_Types.Operation.arithmeticOperation,
    t1: t,
    t2: t,
  ): option<result<SymbolicDistTypes.symbolicDist, string>> =>
    switch (t1, t2) {
    | (#Symbolic(d1), #Symbolic(d2)) =>
      switch SymbolicDist.T.tryAnalyticalSimplification(d1, d2, operation) {
      | #NoSolution => None
      | #Error(message) => Some(Error(message))
      | #AnalyticalSolution(symbolicDist) => Some(Ok(symbolicDist))
      }
    | _ => None
    }

  // Converts both operands with `toPointSet` and combines them with
  // PointSetDist.combineAlgebraically.
  let runConvolution = (
    toPointSet: toPointSetFn,
    operation: GenericDist_Types.Operation.arithmeticOperation,
    t1: t,
    t2: t,
  ) => {
    let pointSets = E.R.merge(toPointSet(t1), toPointSet(t2))
    pointSets->E.R.fmap2(((left, right)) =>
      PointSetDist.combineAlgebraically(operation, left, right)
    )
  }

  // Samples both operands with `toSampleSet` and applies the operation to
  // each aligned pair of samples. Pairs are formed with Belt.Array.zip, so the
  // result is as long as the shorter sample array.
  let runMonteCarlo = (
    toSampleSet: toSampleSetFn,
    operation: GenericDist_Types.Operation.arithmeticOperation,
    t1: t,
    t2: t,
  ) => {
    let combine = Operation.Algebraic.toFn(operation)
    E.R.merge(toSampleSet(t1), toSampleSet(t2))->E.R.fmap2(((samples1, samples2)) =>
      Belt.Array.zip(samples1, samples2)->E.A.fmap2(((x, y)) => combine(x, y))
    )
  }

  // Rough relative cost of using `x` in a convolution. These numbers are
  // guesses by the original author (Ozzie), not measurements.
  let expectedConvolutionCost: t => int = x =>
    switch x {
    | #Symbolic(#Float(_)) => 1
    | #PointSet(Discrete(m)) => m.xyShape->XYShape.T.length
    | _ => 1000
    }

  // Chooses Monte Carlo when the product of the operands' estimated costs
  // exceeds 10000, and convolution otherwise.
  let chooseConvolutionOrMonteCarlo = (t1: t, t2: t) => {
    let combinedCost = expectedConvolutionCost(t1) * expectedConvolutionCost(t2)
    if combinedCost > 10000 {
      #CalculateWithMonteCarlo
    } else {
      #CalculateWithConvolution
    }
  }

  // Combines `t1` and `t2` with `algebraicOp`.
  // Uses the analytical result when available (wrapped as #Symbolic).
  // Otherwise it runs Monte Carlo (wrapped as #SampleSet) or convolution
  // (wrapped as #PointSet), as selected by chooseConvolutionOrMonteCarlo.
  let run = (
    t1: t,
    toPointSet: toPointSetFn,
    toSampleSet: toSampleSetFn,
    algebraicOp,
    t2: t,
  ): result<t, error> =>
    switch tryAnalyticalSimplification(algebraicOp, t1, t2) {
    | Some(Error(e)) => Error(Other(e))
    | Some(Ok(symbolicDist)) => Ok(#Symbolic(symbolicDist))
    | None =>
      switch chooseConvolutionOrMonteCarlo(t1, t2) {
      | #CalculateWithConvolution =>
        runConvolution(toPointSet, algebraicOp, t1, t2)->E.R.fmap2(pointSet => #PointSet(pointSet))
      | #CalculateWithMonteCarlo =>
        runMonteCarlo(toSampleSet, algebraicOp, t1, t2)->E.R.fmap2(samples => #SampleSet(samples))
      }
    }
}
|
|
|
|
|
2022-03-28 11:56:20 +00:00
|
|
|
// Public entry point for algebraic combination; forwards all arguments to
// AlgebraicCombination.run.
let algebraicCombination = (t1, toPointSet, toSampleSet, algebraicOp, t2) =>
  AlgebraicCombination.run(t1, toPointSet, toSampleSet, algebraicOp, t2)
|
|
|
|
|
2022-03-27 18:22:26 +00:00
|
|
|
// TODO: Add a faster pointwiseCombine function.
// Combines two distributions pointwise. Both operands are first converted
// with `toPointSet`. They are then merged by PointSetDist.combinePointwise,
// using the function that GenericDist_Types.Operation.arithmeticToFn returns
// for `operation`. The result is always a #PointSet.
let pointwiseCombination = (t1: t, toPointSet: toPointSetFn, operation, t2: t): result<
  t,
  error,
> =>
  E.R.merge(toPointSet(t1), toPointSet(t2))->E.R.fmap2(((left, right)) => {
    let combined = PointSetDist.combinePointwise(
      GenericDist_Types.Operation.arithmeticToFn(operation),
      left,
      right,
    )
    #PointSet(combined)
  })
|
|
|
|
|
|
|
|
// Applies `operation` with the scalar `f` to the y-values of `t`'s point-set
// form. It uses PointSetDist.T.mapY together with the matching
// Operation.Scale integral-cache functions. #Add and #Subtract would shift
// the distribution vertically, so they are rejected with
// DistributionVerticalShiftIsInvalid.
let pointwiseCombinationFloat = (
  t: t,
  toPointSet: toPointSetFn,
  operation: GenericDist_Types.Operation.arithmeticOperation,
  f: float,
): result<t, error> =>
  switch operation {
  | #Add | #Subtract => Error(GenericDist_Types.DistributionVerticalShiftIsInvalid)
  | (#Multiply | #Divide | #Exponentiate | #Log) as scaleOp =>
    toPointSet(t)->E.R.fmap2(pointSet => {
      // TODO: Move to PointSet codebase
      let sumCacheFn = Operation.Scale.toIntegralSumCacheFn(scaleOp)
      let cacheFn = Operation.Scale.toIntegralCacheFn(scaleOp)
      let scaled = PointSetDist.T.mapY(
        ~integralSumCacheFn=sumCacheFn(f),
        ~integralCacheFn=cacheFn(f),
        ~fn=main => Operation.Scale.toFn(scaleOp, main, f),
        pointSet,
      )
      #PointSet(scaled)
    })
  }
|
2022-03-27 21:37:27 +00:00
|
|
|
|
2022-03-29 19:21:38 +00:00
|
|
|
// Builds a weighted mixture of the given distributions.
// Weights are divided by their total, so the rescaled weights always sum to 1.
// Each distribution is scaled by its rescaled weight with `scaleMultiply`,
// and the scaled parts are summed left to right with `pointwiseAdd`.
// Returns an error when `values` is empty, or when the weights sum to zero.
// Normalizing a zero total would otherwise produce NaN/Infinity scale
// factors that propagate silently.
let mixture = (
  values: array<(t, float)>,
  scaleMultiply: scaleMultiplyFn,
  pointwiseAdd: pointwiseAddFn,
) => {
  if E.A.length(values) == 0 {
    Error(GenericDist_Types.Other("mixture must have at least 1 element"))
  } else {
    let totalWeight = values->E.A.fmap2(E.Tuple2.second)->E.A.Floats.sum
    if totalWeight == 0.0 {
      Error(GenericDist_Types.Other("mixture weights must not sum to zero"))
    } else {
      values
      ->E.A.fmap2(((dist, weight)) => scaleMultiply(dist, weight /. totalWeight))
      ->E.A.R.firstErrorOrOpen
      ->E.R.bind(scaled => {
        // Non-empty here, so index 0 is safe; fold the remaining parts into it.
        scaled
        |> Js.Array.sliceFrom(1)
        |> E.A.fold_left(
          (acc, x) => E.R.bind(acc, acc => pointwiseAdd(acc, x)),
          Ok(E.A.unsafe_get(scaled, 0)),
        )
      })
    }
  }
}
|