2022-03-27 20:59:46 +00:00
//TODO: multimodal, add interface, test somehow, track performance, refactor sampleSet, refactor ASTEvaluator.res.
2022-04-11 18:00:56 +00:00
// The distribution type every function in this file operates on.
// The code below matches on three constructors: PointSet, Symbolic and SampleSet.
type t = DistributionTypes.genericDist
// Error type carried by the `result` values returned from this file.
type error = DistributionTypes.error
2022-03-28 12:39:07 +00:00
// Callback that converts a distribution to a point-set distribution, or fails with `error`.
type toPointSetFn = t => result<PointSetTypes.pointSetDist, error>
2022-04-09 22:10:06 +00:00
// Callback that converts a distribution to a sample-set distribution, or fails with `error`.
type toSampleSetFn = t => result<SampleSetDist.t, error>
2022-03-28 12:39:07 +00:00
// Callback taking a distribution and a float factor; used by `mixture` to apply weights.
type scaleMultiplyFn = (t, float) => result<t, error>
// Callback combining two distributions into one; used by `mixture` to sum weighted parts.
type pointwiseAddFn = (t, t) => result<t, error>
2022-04-10 01:24:44 +00:00
2022-04-28 13:08:53 +00:00
// True only for the PointSet representation.
let isPointSet = (t: t) =>
  switch t {
  | PointSet(_) => true
  | Symbolic(_) | SampleSet(_) => false
  }
// True only for the SampleSet representation.
let isSampleSetSet = (t: t) =>
  switch t {
  | SampleSet(_) => true
  | PointSet(_) | Symbolic(_) => false
  }
// True only for the Symbolic representation.
let isSymbolic = (t: t) =>
  switch t {
  | Symbolic(_) => true
  | PointSet(_) | SampleSet(_) => false
  }
2022-03-29 18:36:54 +00:00
// Draws `n` samples from `t` by delegating to the sampler of its representation.
let sampleN = (t: t, n) =>
  switch t {
  | SampleSet(samples) => SampleSetDist.sampleN(samples, n)
  | Symbolic(dist) => SymbolicDist.T.sampleN(n, dist)
  | PointSet(pointSet) => PointSetDist.sampleNRendered(n, pointSet)
  }
2022-04-10 01:24:44 +00:00
2022-05-21 02:54:15 +00:00
// Draws a single sample: requests one sample via `sampleN` and takes the first element.
// `E.O.toExn` is applied to the result, so a missing first element is presumably
// reported by raising with the message below — confirm against E.O.toExn.
let sample = (t: t) => {
  let firstSample = E.A.first(sampleN(t, 1))
  E.O.toExn("Should not have happened", firstSample)
}
2022-05-19 13:25:34 +00:00
2022-04-10 00:34:21 +00:00
// Builds a sample-set distribution from `n` samples of `t`.
// Errors from `SampleSetDist.make` are translated with `sampleErrorToDistErr`.
let toSampleSetDist = (t: t, n) => {
  let samples = sampleN(t, n)
  samples->SampleSetDist.make->E.R2.errMap(DistributionTypes.Error.sampleErrorToDistErr)
}
2022-03-27 18:22:26 +00:00
2022-03-31 18:15:21 +00:00
// Wraps a float as a Symbolic distribution built by `SymbolicDist.Float.make`.
let fromFloat = (f: float): t => f->SymbolicDist.Float.make->Symbolic
2022-03-27 21:37:27 +00:00
2022-03-27 18:22:26 +00:00
// Human-readable label: symbolic distributions render themselves,
// the other two representations get a fixed description.
let toString = (t: t) =>
  switch t {
  | Symbolic(dist) => SymbolicDist.T.toString(dist)
  | SampleSet(_) => "Sample Set Distribution"
  | PointSet(_) => "Point Set Distribution"
  }
2022-03-31 18:15:21 +00:00
// Normalizes a point-set distribution; Symbolic and SampleSet values are returned unchanged.
let normalize = (t: t): t =>
  switch t {
  | Symbolic(_) | SampleSet(_) => t
  | PointSet(pointSet) => pointSet->PointSetDist.T.normalize->PointSet
  }
2022-04-15 17:58:00 +00:00
// Final value of the integral: computed for point sets, fixed at 1.0 for the other
// two representations.
let integralEndY = (t: t): float =>
  switch t {
  | Symbolic(_) | SampleSet(_) => 1.0
  | PointSet(pointSet) => PointSetDist.T.integralEndY(pointSet)
  }
2022-04-15 20:28:51 +00:00
// A distribution counts as normalized when its integral end value is within 1e-7 of 1.0.
let isNormalized = (t: t): bool => {
  let deviation = Js.Math.abs_float(integralEndY(t) -. 1.0)
  deviation < 1e-7
}
2022-04-15 17:58:00 +00:00
2022-03-31 18:07:39 +00:00
// Reduces a distribution to a single float according to `distToFloatOperation`.
//
// - #IntegralSum is answered directly from `integralEndY`.
// - #Pdf/#Cdf/#Inv/#Mean/#Sample/#Min/#Max: a Symbolic input is first tried with
//   `SymbolicDist.T.operate`; a SampleSet input answers #Mean/#Sample/#Inv/#Min/#Max
//   from the samples. Everything else (including a failed symbolic attempt) is
//   converted with `toPointSetFn` and handed to `PointSetDist.operate`.
// - #Stdev/#Variance/#Mode are only available for SampleSet inputs; other
//   representations yield `NotYetImplemented`.
let toFloatOperation = (
  t,
  ~toPointSetFn: toPointSetFn,
  ~distToFloatOperation: DistributionTypes.DistributionOperation.toFloat,
) =>
  switch distToFloatOperation {
  | #IntegralSum => Ok(integralEndY(t))
  | (#Pdf(_) | #Cdf(_) | #Inv(_) | #Mean | #Sample | #Min | #Max) as op => {
      // Generic fallback: convert to a point set, then run the operation there.
      let viaPointSet = () => toPointSetFn(t)->E.R2.fmap(PointSetDist.operate(op))
      switch (t: t) {
      | Symbolic(dist) =>
        switch SymbolicDist.T.operate(op, dist)->E.R.toOption {
        | Some(value) => Ok(value)
        | None => viaPointSet()
        }
      | SampleSet(samples) =>
        switch op {
        | #Mean => Ok(SampleSetDist.mean(samples))
        | #Sample => Ok(SampleSetDist.sample(samples))
        | #Inv(p) => Ok(SampleSetDist.percentile(samples, p))
        | #Min => Ok(SampleSetDist.min(samples))
        | #Max => Ok(SampleSetDist.max(samples))
        | #Pdf(_) | #Cdf(_) => viaPointSet()
        }
      | PointSet(_) => viaPointSet()
      }
    }
  | (#Stdev | #Variance | #Mode) as op =>
    switch t {
    | SampleSet(samples) =>
      switch op {
      | #Stdev => Ok(SampleSetDist.stdev(samples))
      | #Variance => Ok(SampleSetDist.variance(samples))
      | #Mode => Ok(SampleSetDist.mode(samples))
      }
    | PointSet(_) | Symbolic(_) => Error(DistributionTypes.NotYetImplemented)
    }
  }
2022-04-08 00:17:01 +00:00
//Todo: If it's a pointSet, but the xyPointLength is different from what it has, it should change.
2022-03-29 19:21:38 +00:00
// This is tricky because of the case of discrete distributions.
2022-03-31 18:07:39 +00:00
// Also, change the outputXYPoints/pointSetDistLength details
2022-04-08 12:44:04 +00:00
// Converts any distribution to a point-set distribution.
//
// - PointSet: returned as-is (`xyPointLength` is not applied to it).
// - Symbolic: rendered with `xyPointLength` points using `xSelection` (default #ByWeight).
// - SampleSet: converted via `SampleSetDist.toPointSetDist`; `xyPointLength` is used for
//   both `outputXYPoints` and `pointSetDistLength`, and no kernel width is supplied.
//   Conversion failures are wrapped in `PointSetConversionError`.
let toPointSet = (
  t,
  ~xyPointLength,
  ~sampleCount,
  ~xSelection: DistributionTypes.DistributionOperation.pointsetXSelection=#ByWeight,
  (),
): result<PointSetTypes.pointSetDist, error> =>
  switch (t: t) {
  | PointSet(existing) => Ok(existing)
  | Symbolic(dist) => SymbolicDist.T.toPointSetDist(~xSelection, xyPointLength, dist)->Ok
  | SampleSet(samples) =>
    switch SampleSetDist.toPointSetDist(
      ~samples,
      ~samplingInputs={
        sampleCount: sampleCount,
        outputXYPoints: xyPointLength,
        pointSetDistLength: xyPointLength,
        kernelWidth: None,
      },
    ) {
    | Ok(pointSet) => Ok(pointSet)
    | Error(err) => Error(DistributionTypes.PointSetConversionError(err))
    }
  }
2022-05-25 12:17:45 +00:00
// Scoring of an estimate against an answer, optionally relative to a prior.
module Score = {
  type genericDistOrScalar = DistributionTypes.DistributionOperation.genericDistOrScalar
  // A score argument after conversion: either a point-set distribution or a plain float.
  type pointSet_ScoreDistOrScalar = PSDist(PointSetTypes.pointSetDist) | PSScalar(float)

  // Converts the (estimate, answer, prior) triple into `PointSetDist_Scoring.scoreArgs`.
  // Distributions are converted to point sets with the default environment sizes.
  // Returns:
  //  - the prior's conversion error, if converting the prior failed;
  //  - `Unreachable` when the prior's kind (dist vs scalar) differs from the estimate's;
  //  - `NotYetImplemented` for a scalar estimate with a distribution answer.
  let argsMake = (
    ~esti: genericDistOrScalar,
    ~answ: genericDistOrScalar,
    ~prior: option<genericDistOrScalar>,
  ): result<PointSetDist_Scoring.scoreArgs, error> => {
    let asPointSet = dist =>
      toPointSet(
        dist,
        ~xyPointLength=MagicNumbers.Environment.defaultXYPointLength,
        ~sampleCount=MagicNumbers.Environment.defaultSampleCount,
        ~xSelection=#ByWeight,
        (),
      )
    // Converts estimate and answer together; fails if either conversion fails.
    let bothAsPointSets = (estimateDist: t, answerDist: t) =>
      E.R.merge(asPointSet(estimateDist), asPointSet(answerDist))

    let convertedPrior: option<result<pointSet_ScoreDistOrScalar, error>> = switch prior {
    | Some(Score_Scalar(scalar)) => Some(Ok(PSScalar(scalar)))
    | Some(Score_Dist(dist)) => Some(asPointSet(dist)->E.R2.fmap(pointSet => PSDist(pointSet)))
    | None => None
    }

    switch (esti, answ, convertedPrior) {
    // A failed prior conversion wins over everything else.
    | (_, _, Some(Error(err))) => Error(err)

    // Prior kind does not match estimate kind.
    | (Score_Dist(_), _, Some(Ok(PSScalar(_)))) => Error(DistributionTypes.Unreachable)
    | (Score_Scalar(_), _, Some(Ok(PSDist(_)))) => Error(DistributionTypes.Unreachable)

    // Scalar estimate against a distribution answer is not supported.
    | (Score_Scalar(_), Score_Dist(_), None) => Error(DistributionTypes.NotYetImplemented)
    | (Score_Scalar(_), Score_Dist(_), Some(Ok(PSScalar(_)))) =>
      Error(DistributionTypes.NotYetImplemented)

    // Distribution estimate, distribution answer.
    | (Score_Dist(estimateDist), Score_Dist(answerDist), None) =>
      bothAsPointSets(estimateDist, answerDist)->E.R2.fmap(((estimate, answer)) =>
        PointSetDist_Scoring.DistEstimateDistAnswer({
          estimate: estimate,
          answer: answer,
          prior: None,
        })
      )
    | (Score_Dist(estimateDist), Score_Dist(answerDist), Some(Ok(PSDist(priorDist)))) =>
      bothAsPointSets(estimateDist, answerDist)->E.R2.fmap(((estimate, answer)) =>
        PointSetDist_Scoring.DistEstimateDistAnswer({
          estimate: estimate,
          answer: answer,
          prior: Some(priorDist),
        })
      )

    // Distribution estimate, scalar answer.
    | (Score_Dist(estimateDist), Score_Scalar(answer), None) =>
      asPointSet(estimateDist)->E.R2.fmap(estimate =>
        PointSetDist_Scoring.DistEstimateScalarAnswer({
          estimate: estimate,
          answer: answer,
          prior: None,
        })
      )
    | (Score_Dist(estimateDist), Score_Scalar(answer), Some(Ok(PSDist(priorDist)))) =>
      asPointSet(estimateDist)->E.R2.fmap(estimate =>
        PointSetDist_Scoring.DistEstimateScalarAnswer({
          estimate: estimate,
          answer: answer,
          prior: Some(priorDist),
        })
      )

    // Scalar estimate, scalar answer.
    | (Score_Scalar(estimate), Score_Scalar(answer), None) =>
      Ok(
        PointSetDist_Scoring.ScalarEstimateScalarAnswer({
          estimate: estimate,
          answer: answer,
          prior: None,
        }),
      )
    | (Score_Scalar(estimate), Score_Scalar(answer), Some(Ok(PSScalar(priorScalar)))) =>
      Ok(
        PointSetDist_Scoring.ScalarEstimateScalarAnswer({
          estimate: estimate,
          answer: answer,
          prior: Some(priorScalar),
        }),
      )
    }
  }

  // Builds the scoring arguments and runs `PointSetDist.logScore` on them.
  // Scoring failures are wrapped in `OperationError`.
  let logScore = (
    ~estimate: genericDistOrScalar,
    ~answer: genericDistOrScalar,
    ~prior: option<genericDistOrScalar>,
  ): result<float, error> =>
    switch argsMake(~esti=estimate, ~answ=answer, ~prior) {
    | Error(err) => Error(err)
    | Ok(args) =>
      PointSetDist.logScore(args)->E.R2.errMap(opErr => DistributionTypes.OperationError(opErr))
    }
}
2022-04-09 02:55:06 +00:00
/*
PointSetDist.toSparkline calls "downsampleEquallyOverX", which downsamples it to n=bucketCount.
It first needs a pointSetDist, so we convert to a pointSetDist. In this process we want the
xyPointLength to be a bit longer than the eventual toSparkline downsampling. I chose 3
fairly arbitrarily.
*/
2022-04-21 22:09:06 +00:00
// Renders `t` as a sparkline string with `bucketCount` buckets (default 20).
// The distribution is first converted to a point set with linear x selection and
// three times as many points as buckets; sparkline failures become `SparklineError`.
let toSparkline = (t: t, ~sampleCount: int, ~bucketCount: int=20, ()): result<string, error> => {
  let xyPointLength = bucketCount * 3
  switch toPointSet(t, ~xSelection=#Linear, ~xyPointLength, ~sampleCount, ()) {
  | Error(err) => Error(err)
  | Ok(pointSet) =>
    PointSetDist.toSparkline(pointSet, bucketCount)->E.R2.errMap(err =>
      DistributionTypes.SparklineError(err)
    )
  }
}
2022-04-08 00:17:01 +00:00
2022-03-27 18:22:26 +00:00
// Truncation of a distribution to an optional left and/or right cutoff.
module Truncate = {
  // Attempts to truncate without leaving the symbolic representation.
  // Only uniform distributions are handled; returns None when no cutoff is given
  // or when `t` is not a symbolic uniform.
  // NOTE(review): the cutoffs are forwarded to `SymbolicDist.Uniform.truncate` without
  // checking that left < right — confirm whether inverted cutoffs should be rejected.
  let trySymbolicSimplification = (
    leftCutoff: option<float>,
    rightCutoff: option<float>,
    t: t,
  ): option<t> =>
    switch (leftCutoff, rightCutoff, t) {
    | (None, None, _) => None
    | (left, right, Symbolic(#Uniform(uniform))) =>
      Some(Symbolic(#Uniform(SymbolicDist.Uniform.truncate(left, right, uniform))))
    | _ => None
    }

  // Truncates `t` to the given cutoffs.
  // With no cutoff the input is returned untouched. Otherwise the symbolic shortcut
  // is tried first; failing that, `t` is converted with `toPointSetFn`, truncated,
  // and normalized, and the result is a PointSet distribution.
  let run = (
    t: t,
    ~toPointSetFn: toPointSetFn,
    ~leftCutoff=None: option<float>,
    ~rightCutoff=None: option<float>,
    (),
  ): result<t, error> =>
    switch (leftCutoff, rightCutoff) {
    | (None, None) => Ok(t)
    | _ =>
      switch trySymbolicSimplification(leftCutoff, rightCutoff, t) {
      | Some(simplified) => Ok(simplified)
      | None =>
        toPointSetFn(t)->E.R2.fmap(pointSet => {
          let truncated = PointSetDist.T.truncate(leftCutoff, rightCutoff, pointSet)
          DistributionTypes.PointSet(PointSetDist.T.normalize(truncated))
        })
      }
    }
}
2022-03-28 11:56:20 +00:00
// Top-level alias for `Truncate.run`.
let truncate = Truncate.run
2022-03-27 18:22:26 +00:00
/* Given two random variables A and B, this returns the distribution
of a new variable that is the result of the operation on A and B.
For instance, normal(0, 1) + normal(1, 1) -> normal(1, 2).
2022-03-29 19:21:38 +00:00
In general, this is implemented via convolution.
*/
2022-03-27 18:22:26 +00:00
module AlgebraicCombination = {
2022-04-27 16:48:46 +00:00
// Checks that the operands are acceptable for the requested operation before any
// combination strategy is run. Only #Logarithm currently has a check.
module InputValidator = {
  /*
     It would be good to also do a check to make sure that probability mass for the second
     operand, at value 1.0, is 0 (or approximately 0). However, we'd ideally want to check
     that both the probability mass and the probability density are greater than zero.
     Right now we don't yet have a way of getting probability mass, so I'll leave this for later.
 */
  // Returns an error when either operand has a non-zero CDF at
  // `MagicNumbers.Epsilon.ten`, i.e. is not entirely above that threshold.
  // Errors raised while computing the CDF are passed through.
  let getLogarithmInputError = (t1: t, t2: t, ~toPointSetFn: toPointSetFn): option<error> => {
    // Ok(true) means the operand has cumulative probability at the epsilon threshold.
    let hasMassAtOrBelowEpsilon = (operand: t) =>
      toFloatOperation(
        operand,
        ~toPointSetFn,
        ~distToFloatOperation=#Cdf(MagicNumbers.Epsilon.ten),
      )->E.R2.fmap(cdf => cdf > 0.)
    let firstFails = hasMassAtOrBelowEpsilon(t1)
    let secondFails = hasMassAtOrBelowEpsilon(t2)
    switch E.A.R.firstErrorOrOpen([firstFails, secondFails]) {
    | Error(err) => Some(err)
    | Ok([false, false]) => None
    | Ok([true, _]) =>
      Some(LogarithmOfDistributionError("First input must be completely greater than 0"))
    | Ok([false, true]) =>
      Some(LogarithmOfDistributionError("Second input must be completely greater than 0"))
    | Ok(_) => Some(Unreachable)
    }
  }

  // Runs the validation that applies to `arithmeticOperation`; None means "inputs are fine".
  let run = (t1: t, t2: t, ~toPointSetFn: toPointSetFn, ~arithmeticOperation): option<error> =>
    switch arithmeticOperation {
    | #Logarithm => getLogarithmInputError(t1, t2, ~toPointSetFn)
    | _ => None
    }
}
2022-04-27 16:48:46 +00:00
// The three ways of combining two distributions algebraically.
// None of these functions validate their inputs.
module StrategyCallOnValidatedInputs = {
  // Converts both operands to point sets and combines them with
  // `PointSetDist.combineAlgebraically`; the result is a PointSet distribution.
  let convolution = (
    toPointSet: toPointSetFn,
    arithmeticOperation: Operation.convolutionOperation,
    t1: t,
    t2: t,
  ): result<t, error> =>
    E.R.merge(toPointSet(t1), toPointSet(t2))->E.R2.fmap(((left, right)) =>
      DistributionTypes.PointSet(
        PointSetDist.combineAlgebraically(arithmeticOperation, left, right),
      )
    )

  // Converts both operands to sample sets and combines them with `SampleSetDist.map2`,
  // using the float function of the operation; the result is a SampleSet distribution.
  // Sample-set failures are wrapped in `SampleSetError`.
  let monteCarlo = (
    toSampleSet: toSampleSetFn,
    arithmeticOperation: Operation.algebraicOperation,
    t1: t,
    t2: t,
  ): result<t, error> => {
    let fn = Operation.Algebraic.toFn(arithmeticOperation)
    E.R.merge(toSampleSet(t1), toSampleSet(t2))->E.R.bind(((samples1, samples2)) =>
      switch SampleSetDist.map2(~fn, ~t1=samples1, ~t2=samples2) {
      | Ok(combined) => Ok(DistributionTypes.SampleSet(combined))
      | Error(err) => Error(DistributionTypes.SampleSetError(err))
      }
    )
  }

  // Asks `SymbolicDist` for an analytical result; anything other than two
  // symbolic operands yields #NoSolution.
  let symbolic = (
    arithmeticOperation: Operation.algebraicOperation,
    t1: t,
    t2: t,
  ): SymbolicDistTypes.analyticalSimplificationResult =>
    switch (t1, t2) {
    | (DistributionTypes.Symbolic(left), DistributionTypes.Symbolic(right)) =>
      SymbolicDist.T.tryAnalyticalSimplification(left, right, arithmeticOperation)
    | _ => #NoSolution
    }
}
2022-04-26 22:41:57 +00:00
2022-04-27 16:48:46 +00:00
// Picks a combination strategy when the caller did not request a specific one.
module StrategyChooser = {
  type specificStrategy = [#AsSymbolic | #AsMonteCarlo | #AsConvolution]

  //I'm (Ozzie) really just guessing here, very little idea what's best
  // Per-operand cost estimate; a discrete point set costs its number of points,
  // everything else uses a constant from `MagicNumbers.OpCost`.
  let expectedConvolutionCost = (dist: t): int =>
    switch dist {
    | Symbolic(#Float(_)) => MagicNumbers.OpCost.floatCost
    | Symbolic(_) => MagicNumbers.OpCost.symbolicCost
    | PointSet(Discrete(discrete)) => XYShape.T.length(discrete.xyShape)
    | PointSet(Mixed(_)) => MagicNumbers.OpCost.mixedCost
    | PointSet(Continuous(_)) => MagicNumbers.OpCost.continuousCost
    | _ => MagicNumbers.OpCost.wildcardCost
    }

  let hasSampleSetDist = (t1: t, t2: t): bool => isSampleSetSet(t1) || isSampleSetSet(t2)

  // Convolution is judged faster when the product of both operand costs is below
  // the Monte Carlo cost constant.
  let convolutionIsFasterThanMonteCarlo = (t1: t, t2: t): bool => {
    let convolutionCost = expectedConvolutionCost(t1) * expectedConvolutionCost(t2)
    convolutionCost < MagicNumbers.OpCost.monteCarloCost
  }

  // Convolution is preferred only if neither operand is a sample set, the operation
  // can be expressed as a convolution, and the cost estimate favours it.
  let preferConvolutionToMonteCarlo = (t1, t2, arithmeticOperation) =>
    if hasSampleSetDist(t1, t2) {
      false
    } else if !Operation.Convolution.canDoAlgebraicOperation(arithmeticOperation) {
      false
    } else {
      convolutionIsFasterThanMonteCarlo(t1, t2)
    }

  // Symbolic wins whenever the symbolic attempt produces a solution or an error;
  // otherwise the choice is between convolution and Monte Carlo.
  let run = (~t1: t, ~t2: t, ~arithmeticOperation): specificStrategy =>
    switch StrategyCallOnValidatedInputs.symbolic(arithmeticOperation, t1, t2) {
    | #NoSolution =>
      if preferConvolutionToMonteCarlo(t1, t2, arithmeticOperation) {
        #AsConvolution
      } else {
        #AsMonteCarlo
      }
    | #AnalyticalSolution(_) | #Error(_) => #AsSymbolic
    }
}
2022-04-27 16:48:46 +00:00
// Executes a strategy chosen by `StrategyChooser`. Because the chooser has already
// established that the strategy applies, the "cannot apply" outcomes here are
// reported as `Unreachable`.
let runStrategyOnValidatedInputs = (
  ~t1: t,
  ~t2: t,
  ~arithmeticOperation,
  ~strategy: StrategyChooser.specificStrategy,
  ~toPointSetFn: toPointSetFn,
  ~toSampleSetFn: toSampleSetFn,
): result<t, error> =>
  switch strategy {
  | #AsSymbolic =>
    switch StrategyCallOnValidatedInputs.symbolic(arithmeticOperation, t1, t2) {
    | #NoSolution => Error(Unreachable)
    | #Error(opError) => Error(OperationError(opError))
    | #AnalyticalSolution(solution) => Ok(Symbolic(solution))
    }
  | #AsConvolution =>
    switch Operation.Convolution.fromAlgebraicOperation(arithmeticOperation) {
    | None => Error(Unreachable)
    | Some(convolutionOp) =>
      StrategyCallOnValidatedInputs.convolution(toPointSetFn, convolutionOp, t1, t2)
    }
  | #AsMonteCarlo =>
    StrategyCallOnValidatedInputs.monteCarlo(toSampleSetFn, arithmeticOperation, t1, t2)
  }
2022-04-26 22:41:57 +00:00
// Entry point for algebraic combination of `t1` and `t2`.
// Inputs are validated first; a validation error is returned immediately.
// `AsDefault` lets `StrategyChooser` decide; the other strategies are forced, and
// return `RequestedStrategyInvalidError` when they cannot handle the inputs.
let run = (
  ~strategy: DistributionTypes.asAlgebraicCombinationStrategy,
  t1: t,
  ~toPointSetFn: toPointSetFn,
  ~toSampleSetFn: toSampleSetFn,
  ~arithmeticOperation: Operation.algebraicOperation,
  ~t2: t,
): result<t, error> =>
  switch InputValidator.run(t1, t2, ~arithmeticOperation, ~toPointSetFn) {
  | Some(inputError) => Error(inputError)
  | None =>
    switch strategy {
    | AsDefault => {
        let chosenStrategy = StrategyChooser.run(~arithmeticOperation, ~t1, ~t2)
        runStrategyOnValidatedInputs(
          ~t1,
          ~t2,
          ~strategy=chosenStrategy,
          ~arithmeticOperation,
          ~toPointSetFn,
          ~toSampleSetFn,
        )
      }
    | AsMonteCarlo =>
      StrategyCallOnValidatedInputs.monteCarlo(toSampleSetFn, arithmeticOperation, t1, t2)
    | AsSymbolic =>
      switch StrategyCallOnValidatedInputs.symbolic(arithmeticOperation, t1, t2) {
      | #AnalyticalSolution(solution) => Ok(Symbolic(solution))
      | #Error(opError) => Error(OperationError(opError))
      | #NoSolution => Error(RequestedStrategyInvalidError(`No analytic solution for inputs`))
      }
    | AsConvolution =>
      switch Operation.Convolution.fromAlgebraicOperation(arithmeticOperation) {
      | Some(convolutionOp) =>
        StrategyCallOnValidatedInputs.convolution(toPointSetFn, convolutionOp, t1, t2)
      | None => {
          let operationName = Operation.Algebraic.toString(arithmeticOperation)
          Error(RequestedStrategyInvalidError(`Convolution not supported for ${operationName}`))
        }
      }
    }
  }
2022-03-27 18:22:26 +00:00
}
2022-03-28 11:56:20 +00:00
// Top-level alias for `AlgebraicCombination.run`.
let algebraicCombination = AlgebraicCombination.run
2022-03-27 18:22:26 +00:00
//TODO: Add faster pointwiseCombine fn
2022-03-31 18:07:39 +00:00
// Combines two distributions pointwise: both are converted with `toPointSetFn`, then
// merged by `PointSetDist.combinePointwise` using the float function of
// `algebraicCombination`. The result is a PointSet distribution; combination failures
// are wrapped in `OperationError`.
let pointwiseCombination = (
  t1: t,
  ~toPointSetFn: toPointSetFn,
  ~algebraicCombination: Operation.algebraicOperation,
  ~t2: t,
): result<t, error> =>
  switch E.R.merge(toPointSetFn(t1), toPointSetFn(t2)) {
  | Error(err) => Error(err)
  | Ok((pointSet1, pointSet2)) => {
      let combineFn = Operation.Algebraic.toFn(algebraicCombination)
      switch PointSetDist.combinePointwise(combineFn, pointSet1, pointSet2) {
      | Ok(combined) => Ok(DistributionTypes.PointSet(combined))
      | Error(opError) => Error(DistributionTypes.OperationError(opError))
      }
    }
  }
// Applies a scalar operation with the float `f` to every y value of `t`.
// #Add and #Subtract are rejected with `DistributionVerticalShiftIsInvalid`; the
// remaining operations convert `t` with `toPointSetFn` and map its y values.
// The result is a PointSet distribution; mapping failures become `OperationError`.
let pointwiseCombinationFloat = (
  t: t,
  ~toPointSetFn: toPointSetFn,
  ~algebraicCombination: Operation.algebraicOperation,
  ~f: float,
): result<t, error> => {
  let scaleYValues = scaleOp =>
    toPointSetFn(t)->E.R.bind(pointSet => {
      //TODO: Move to PointSet codebase
      // `f` is always the second operand of the scale function.
      let applyToY = y => Operation.Scale.toFn(scaleOp, y, f)
      let sumCacheFn = Operation.Scale.toIntegralSumCacheFn(scaleOp)
      let cacheFn = Operation.Scale.toIntegralCacheFn(scaleOp)
      let mapped = PointSetDist.T.mapYResult(
        ~integralSumCacheFn=sumCacheFn(f),
        ~integralCacheFn=cacheFn(f),
        ~fn=applyToY,
        pointSet,
      )
      mapped->E.R2.errMap(opError => DistributionTypes.OperationError(opError))
    })
  let scaled = switch algebraicCombination {
  | (#Multiply | #Divide | #Power | #Logarithm) as arithmeticOperation =>
    scaleYValues(arithmeticOperation)
  | #LogarithmWithThreshold(eps) => scaleYValues(#LogarithmWithThreshold(eps))
  | #Add | #Subtract => Error(DistributionTypes.DistributionVerticalShiftIsInvalid)
  }
  switch scaled {
  | Ok(pointSet) => Ok(DistributionTypes.PointSet(pointSet))
  | Error(err) => Error(err)
  }
}
2022-03-27 21:37:27 +00:00
2022-03-29 21:35:33 +00:00
//Note: The result should always cumulatively sum to 1. This would be good to test.
2022-03-30 01:28:14 +00:00
//Note: If the inputs are not normalized, this will return poor results. The weights probably refer to the post-normalized forms. It would be good to apply a catch to this.
2022-03-27 21:37:27 +00:00
// Builds a weighted mixture of distributions.
//
// `values` pairs each distribution with its weight. Weights are divided by their
// sum, each distribution is scaled by its share via `scaleMultiplyFn`, and the
// scaled parts are added left to right with `pointwiseAddFn`.
//
// Returns an `OtherError` when `values` is empty, or when the weights do not sum to a
// finite, non-zero number (dividing by such a total would produce NaN or infinite
// scale factors). The first error from `scaleMultiplyFn` or `pointwiseAddFn` is
// returned as-is.
let mixture = (
  values: array<(t, float)>,
  ~scaleMultiplyFn: scaleMultiplyFn,
  ~pointwiseAddFn: pointwiseAddFn,
) => {
  if E.A.length(values) == 0 {
    Error(DistributionTypes.OtherError("Mixture error: mixture must have at least 1 element"))
  } else {
    let totalWeight = values->E.A2.fmap(E.Tuple2.second)->E.A.Floats.sum
    if totalWeight == 0.0 || !Js.Float.isFinite(totalWeight) {
      Error(
        DistributionTypes.OtherError(
          "Mixture error: mixture weights must sum to a finite, non-zero number",
        ),
      )
    } else {
      let properlyWeightedValues =
        values
        ->E.A2.fmap(((dist, weight)) => scaleMultiplyFn(dist, weight /. totalWeight))
        ->E.A.R.firstErrorOrOpen
      properlyWeightedValues->E.R.bind(weighted => {
        // `weighted` has the same length as `values`, which is non-empty here,
        // so index 0 exists and seeds the fold over the remaining elements.
        weighted
        |> Js.Array.sliceFrom(1)
        |> E.A.fold_left(
          (acc, next) => E.R.bind(acc, sumSoFar => pointwiseAddFn(sumSoFar, next)),
          Ok(E.A.unsafe_get(weighted, 0)),
        )
      })
    }
  }
}