First implementation of sampleSet mixed distribution

This commit is contained in:
Ozzie Gooen 2022-08-20 20:32:41 -07:00
parent c3bb0ba291
commit 7237f2709b
5 changed files with 33 additions and 4 deletions

View File

@ -216,7 +216,7 @@ let rec run = (~env: env, functionCallInfo: functionCallInfo): outputType => {
| FromFloat(subFnName, x) => reCall(~functionCallInfo=FromFloat(subFnName, x), ()) | FromFloat(subFnName, x) => reCall(~functionCallInfo=FromFloat(subFnName, x), ())
| Mixture(dists) => | Mixture(dists) =>
dists dists
->GenericDist.mixture(~scaleMultiplyFn=scaleMultiply, ~pointwiseAddFn=pointwiseAdd) ->GenericDist.mixture(~scaleMultiplyFn=scaleMultiply, ~pointwiseAddFn=pointwiseAdd, ~env)
->E.R2.fmap(r => Dist(r)) ->E.R2.fmap(r => Dist(r))
->OutputLocal.fromResult ->OutputLocal.fromResult
| FromSamples(xs) => | FromSamples(xs) =>

View File

@ -491,15 +491,30 @@ let pointwiseCombinationFloat = (
m->E.R2.fmap(r => DistributionTypes.PointSet(r)) m->E.R2.fmap(r => DistributionTypes.PointSet(r))
} }
//Note: The result should always cumulatively sum to 1. This would be good to test. //TODO: The result should always cumulatively sum to 1. This would be good to test.
//Note: If the inputs are not normalized, this will return poor results. The weights probably refer to the post-normalized forms. It would be good to apply a catch to this. //TODO: If the inputs are not normalized, this will return poor results. The weights probably refer to the post-normalized forms. It would be good to apply a catch to this.
let mixture = ( let mixture = (
values: array<(t, float)>, values: array<(t, float)>,
~scaleMultiplyFn: scaleMultiplyFn, ~scaleMultiplyFn: scaleMultiplyFn,
~pointwiseAddFn: pointwiseAddFn, ~pointwiseAddFn: pointwiseAddFn,
~env: env,
) => { ) => {
if E.A.length(values) == 0 { let allValuesAreSampleSet = v => E.A.all(((t, _)) => isSampleSetSet(t), v)
if E.A.isEmpty(values) {
Error(DistributionTypes.OtherError("Mixture error: mixture must have at least 1 element")) Error(DistributionTypes.OtherError("Mixture error: mixture must have at least 1 element"))
} else if allValuesAreSampleSet(values) {
let withSampleSetValues = values->E.A2.fmap(((value, weight)) =>
switch value {
| SampleSet(sampleSet) => Ok((sampleSet, weight))
| _ => Error("Unreachable")
} |> E.R.toExn("Mixture coding error: SampleSet expected. This should be inaccessible.")
)
let sampleSetMixture = SampleSetDist.mixture(withSampleSetValues, env.sampleCount)
switch sampleSetMixture {
| Ok(sampleSet) => Ok(DistributionTypes.SampleSet(sampleSet))
| Error(err) => Error(DistributionTypes.Error.sampleErrorToDistErr(err))
}
} else { } else {
let totalWeight = values->E.A2.fmap(E.Tuple2.second)->E.A.Floats.sum let totalWeight = values->E.A2.fmap(E.Tuple2.second)->E.A.Floats.sum
let properlyWeightedValues = let properlyWeightedValues =

View File

@ -81,6 +81,7 @@ let mixture: (
array<(t, float)>, array<(t, float)>,
~scaleMultiplyFn: scaleMultiplyFn, ~scaleMultiplyFn: scaleMultiplyFn,
~pointwiseAddFn: pointwiseAddFn, ~pointwiseAddFn: pointwiseAddFn,
~env: env,
) => result<t, error> ) => result<t, error>
let isSymbolic: t => bool let isSymbolic: t => bool

View File

@ -131,3 +131,15 @@ let max = t => T.get(t)->E.A.Floats.max
let stdev = t => T.get(t)->E.A.Floats.stdev let stdev = t => T.get(t)->E.A.Floats.stdev
let variance = t => T.get(t)->E.A.Floats.variance let variance = t => T.get(t)->E.A.Floats.variance
let percentile = (t, f) => T.get(t)->E.A.Floats.percentile(f) let percentile = (t, f) => T.get(t)->E.A.Floats.percentile(f)
let mixture = (values: array<(t, float)>, intendedLength: int) => {
let totalWeight = values->E.A2.fmap(E.Tuple2.second)->E.A.Floats.sum
values
->E.A2.fmap(((dist, weight)) => {
let adjustedWeight = weight /. totalWeight
let samplesToGet = adjustedWeight *. E.I.toFloat(intendedLength) |> E.Float.toInt
sampleN(dist, samplesToGet)
})
->E.A.concatMany
->T.make
}

View File

@ -220,6 +220,7 @@ module I = {
let increment = n => n + 1 let increment = n => n + 1
let decrement = n => n - 1 let decrement = n => n - 1
let toString = Js.Int.toString let toString = Js.Int.toString
let toFloat = Js.Int.toFloat
} }
exception Assertion(string) exception Assertion(string)