First implementation of sampleSet mixed distribution
This commit is contained in:
parent
c3bb0ba291
commit
7237f2709b
|
@ -216,7 +216,7 @@ let rec run = (~env: env, functionCallInfo: functionCallInfo): outputType => {
|
||||||
| FromFloat(subFnName, x) => reCall(~functionCallInfo=FromFloat(subFnName, x), ())
|
| FromFloat(subFnName, x) => reCall(~functionCallInfo=FromFloat(subFnName, x), ())
|
||||||
| Mixture(dists) =>
|
| Mixture(dists) =>
|
||||||
dists
|
dists
|
||||||
->GenericDist.mixture(~scaleMultiplyFn=scaleMultiply, ~pointwiseAddFn=pointwiseAdd)
|
->GenericDist.mixture(~scaleMultiplyFn=scaleMultiply, ~pointwiseAddFn=pointwiseAdd, ~env)
|
||||||
->E.R2.fmap(r => Dist(r))
|
->E.R2.fmap(r => Dist(r))
|
||||||
->OutputLocal.fromResult
|
->OutputLocal.fromResult
|
||||||
| FromSamples(xs) =>
|
| FromSamples(xs) =>
|
||||||
|
|
|
@ -491,15 +491,30 @@ let pointwiseCombinationFloat = (
|
||||||
m->E.R2.fmap(r => DistributionTypes.PointSet(r))
|
m->E.R2.fmap(r => DistributionTypes.PointSet(r))
|
||||||
}
|
}
|
||||||
|
|
||||||
//Note: The result should always cumulatively sum to 1. This would be good to test.
|
//TODO: The result should always cumulatively sum to 1. This would be good to test.
|
||||||
//Note: If the inputs are not normalized, this will return poor results. The weights probably refer to the post-normalized forms. It would be good to apply a catch to this.
|
//TODO: If the inputs are not normalized, this will return poor results. The weights probably refer to the post-normalized forms. It would be good to apply a catch to this.
|
||||||
let mixture = (
|
let mixture = (
|
||||||
values: array<(t, float)>,
|
values: array<(t, float)>,
|
||||||
~scaleMultiplyFn: scaleMultiplyFn,
|
~scaleMultiplyFn: scaleMultiplyFn,
|
||||||
~pointwiseAddFn: pointwiseAddFn,
|
~pointwiseAddFn: pointwiseAddFn,
|
||||||
|
~env: env,
|
||||||
) => {
|
) => {
|
||||||
if E.A.length(values) == 0 {
|
let allValuesAreSampleSet = v => E.A.all(((t, _)) => isSampleSetSet(t), v)
|
||||||
|
|
||||||
|
if E.A.isEmpty(values) {
|
||||||
Error(DistributionTypes.OtherError("Mixture error: mixture must have at least 1 element"))
|
Error(DistributionTypes.OtherError("Mixture error: mixture must have at least 1 element"))
|
||||||
|
} else if allValuesAreSampleSet(values) {
|
||||||
|
let withSampleSetValues = values->E.A2.fmap(((value, weight)) =>
|
||||||
|
switch value {
|
||||||
|
| SampleSet(sampleSet) => Ok((sampleSet, weight))
|
||||||
|
| _ => Error("Unreachable")
|
||||||
|
} |> E.R.toExn("Mixture coding error: SampleSet expected. This should be inaccessible.")
|
||||||
|
)
|
||||||
|
let sampleSetMixture = SampleSetDist.mixture(withSampleSetValues, env.sampleCount)
|
||||||
|
switch sampleSetMixture {
|
||||||
|
| Ok(sampleSet) => Ok(DistributionTypes.SampleSet(sampleSet))
|
||||||
|
| Error(err) => Error(DistributionTypes.Error.sampleErrorToDistErr(err))
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
let totalWeight = values->E.A2.fmap(E.Tuple2.second)->E.A.Floats.sum
|
let totalWeight = values->E.A2.fmap(E.Tuple2.second)->E.A.Floats.sum
|
||||||
let properlyWeightedValues =
|
let properlyWeightedValues =
|
||||||
|
|
|
@ -81,6 +81,7 @@ let mixture: (
|
||||||
array<(t, float)>,
|
array<(t, float)>,
|
||||||
~scaleMultiplyFn: scaleMultiplyFn,
|
~scaleMultiplyFn: scaleMultiplyFn,
|
||||||
~pointwiseAddFn: pointwiseAddFn,
|
~pointwiseAddFn: pointwiseAddFn,
|
||||||
|
~env: env,
|
||||||
) => result<t, error>
|
) => result<t, error>
|
||||||
|
|
||||||
let isSymbolic: t => bool
|
let isSymbolic: t => bool
|
||||||
|
|
|
@ -131,3 +131,15 @@ let max = t => T.get(t)->E.A.Floats.max
|
||||||
let stdev = t => T.get(t)->E.A.Floats.stdev
|
let stdev = t => T.get(t)->E.A.Floats.stdev
|
||||||
let variance = t => T.get(t)->E.A.Floats.variance
|
let variance = t => T.get(t)->E.A.Floats.variance
|
||||||
let percentile = (t, f) => T.get(t)->E.A.Floats.percentile(f)
|
let percentile = (t, f) => T.get(t)->E.A.Floats.percentile(f)
|
||||||
|
|
||||||
|
let mixture = (values: array<(t, float)>, intendedLength: int) => {
|
||||||
|
let totalWeight = values->E.A2.fmap(E.Tuple2.second)->E.A.Floats.sum
|
||||||
|
values
|
||||||
|
->E.A2.fmap(((dist, weight)) => {
|
||||||
|
let adjustedWeight = weight /. totalWeight
|
||||||
|
let samplesToGet = adjustedWeight *. E.I.toFloat(intendedLength) |> E.Float.toInt
|
||||||
|
sampleN(dist, samplesToGet)
|
||||||
|
})
|
||||||
|
->E.A.concatMany
|
||||||
|
->T.make
|
||||||
|
}
|
||||||
|
|
|
@ -220,6 +220,7 @@ module I = {
|
||||||
let increment = n => n + 1
|
let increment = n => n + 1
|
||||||
let decrement = n => n - 1
|
let decrement = n => n - 1
|
||||||
let toString = Js.Int.toString
|
let toString = Js.Int.toString
|
||||||
|
let toFloat = Js.Int.toFloat
|
||||||
}
|
}
|
||||||
|
|
||||||
exception Assertion(string)
|
exception Assertion(string)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user