fix: SampleSet.fromDist works for discrete and mixed

This commit is contained in:
Ozzie Gooen 2022-09-02 21:51:42 -07:00
parent 462f1c9649
commit 93f4c1e0c2
8 changed files with 59 additions and 26 deletions

View File

@ -31,9 +31,9 @@ let isSymbolic = (t: t) =>
let sampleN = (t: t, n) =>
switch t {
| PointSet(r) => PointSetDist.sampleNRendered(n, r)
| Symbolic(r) => SymbolicDist.T.sampleN(n, r)
| PointSet(r) => PointSetDist.T.sampleN(r,n)
| SampleSet(r) => SampleSetDist.sampleN(r, n)
| Symbolic(r) => SymbolicDist.T.sampleN(n, r)
}
// Draws a single value by requesting a one-element batch from `sampleN` and
// taking its first entry. The optional result is unwrapped with `E.O.toExn`,
// using the message below for the (unexpected) empty case.
let sample = (t: t) => {
  let firstOfOne = E.A.first(sampleN(t, 1))
  E.O.toExn("Should not have happened", firstOfOne)
}

View File

@ -270,6 +270,25 @@ module T = Dist({
}
let variance = (t: t): float =>
XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)
// Builds an array of `n` floats by calling `fn` once per slot, in index order
// (slot 0 first). A non-positive `n` yields an empty array and `fn` is never
// called.
let doN = (n, fn): array<float> => Belt.Array.makeBy(n, _ => fn())
// Draws one value from `t`: picks a random float via `Random.float(1.0)` and
// maps it through `integralYtoX`.
// NOTE(review): presumably `integralYtoX` is the inverse of the cumulative
// integral (inverse-CDF sampling) — confirm against its definition.
let sample = (t: t): float => integralYtoX(Random.float(1.0), t)
// Draws `n` values from `dist`.
// The integral is computed once up front and stored on the distribution via
// `updateIntegralCache`, so every call to `sample` below works from the same
// distribution value carrying that cached integral.
let sampleN = (dist, n) => {
  let cachedDist = dist |> updateIntegralCache(Some(integral(dist)))
  doN(n, () => sample(cachedDist))
}
})
let isNormalized = (t: t): bool => {

View File

@ -223,9 +223,9 @@ module T = Dist({
let getMeanOfSquares = t => t |> shapeMap(XYShape.T.square) |> mean
XYShape.Analysis.getVarianceDangerously(t, mean, getMeanOfSquares)
}
})
let sampleN = (t: t, n): array<float> => {
let normalized = t->T.normalize->getShape
let normalized = t->normalize->getShape
Stdlib.Random.sample(normalized.xs, {probs: normalized.ys, size: n})
}
})

View File

@ -33,6 +33,7 @@ module type dist = {
let mean: t => float
let variance: t => float
// Takes a distribution and an int and returns an array of floats; the
// implementations use the int as the number of samples to draw.
let sampleN: (t, int) => array<float>
}
module Dist = (T: dist) => {
@ -64,6 +65,8 @@ module Dist = (T: dist) => {
let yToX = T.integralYtoX
let sum = T.integralEndY
}
// Re-exports the `sampleN` supplied by the functor argument `T` unchanged.
let sampleN = T.sampleN
}
module Common = {

View File

@ -270,38 +270,47 @@ module T = Dist({
})
}
let mean = ({discrete, continuous}: t): float => {
// Integral sum of the discrete part, as reported by `Discrete.T.Integral.sum`.
let discreteIntegralSum = (t: t): float => t.discrete->Discrete.T.Integral.sum
// Integral sum of the continuous part, as reported by `Continuous.T.Integral.sum`.
let continuousIntegralSum = (t: t): float => t.continuous->Continuous.T.Integral.sum
// Combined integral sum: discrete part plus continuous part.
let integralSum = (t: t): float => discreteIntegralSum(t) +. continuousIntegralSum(t)
let mean = ({discrete, continuous} as t: t): float => {
let discreteMean = Discrete.T.mean(discrete)
let continuousMean = Continuous.T.mean(continuous)
// the combined mean is the weighted sum of the two:
let discreteIntegralSum = Discrete.T.Integral.sum(discrete)
let continuousIntegralSum = Continuous.T.Integral.sum(continuous)
let totalIntegralSum = discreteIntegralSum +. continuousIntegralSum
(discreteMean *. discreteIntegralSum +. continuousMean *. continuousIntegralSum) /.
totalIntegralSum
(discreteMean *. discreteIntegralSum(t) +. continuousMean *. continuousIntegralSum(t)) /.
integralSum(t)
}
let variance = ({discrete, continuous} as t: t): float => {
// the combined mean is the weighted sum of the two:
let discreteIntegralSum = Discrete.T.Integral.sum(discrete)
let continuousIntegralSum = Continuous.T.Integral.sum(continuous)
let totalIntegralSum = discreteIntegralSum +. continuousIntegralSum
let _discreteIntegralSum = discreteIntegralSum(t)
let _integralSum = integralSum(t)
let getMeanOfSquares = ({discrete, continuous}: t) => {
let discreteMean = discrete |> Discrete.shapeMap(XYShape.T.square) |> Discrete.T.mean
let continuousMean = continuous |> Continuous.Analysis.getMeanOfSquares
(discreteMean *. discreteIntegralSum +. continuousMean *. continuousIntegralSum) /.
totalIntegralSum
let continuousMean = continuous -> Continuous.Analysis.getMeanOfSquares
(discreteMean *. discreteIntegralSum(t) +. continuousMean *. continuousIntegralSum(t)) /.
integralSum(t)
}
switch discreteIntegralSum /. totalIntegralSum {
switch _discreteIntegralSum /. _integralSum {
| 1.0 => Discrete.T.variance(discrete)
| 0.0 => Continuous.T.variance(continuous)
| _ => XYShape.Analysis.getVarianceDangerously(t, mean, getMeanOfSquares)
}
}
// Draws `n` samples from a mixed (discrete + continuous) distribution.
//
// The `n` draws are split between the two parts in proportion to the discrete
// part's share of the total integral sum. Each part is normalized and sampled
// on its own, then the two batches are concatenated and shuffled so the result
// is not grouped by part.
//
// NOTE(review): `E.Float.toInt` presumably truncates, so any fractional
// remainder goes to the continuous part — confirm. A total integral sum of 0
// is not guarded against here.
let sampleN = (t: t, n: int): array<float> => {
  let discreteMass = discreteIntegralSum(t)
  let totalMass = integralSum(t)
  let discreteSampleLength: int =
    (Js.Int.toFloat(n) *. discreteMass /. totalMass)->E.Float.toInt
  // Whatever is not assigned to the discrete part is drawn from the continuous one.
  let continuousSampleLength = n - discreteSampleLength
  // Continuous is sampled before discrete; keep this order so the random
  // stream is consumed the same way as before.
  let continuousSamples =
    t.continuous->Continuous.T.normalize->Continuous.T.sampleN(continuousSampleLength)
  let discreteSamples =
    t.discrete->Discrete.T.normalize->Discrete.T.sampleN(discreteSampleLength)
  E.A.concat(discreteSamples, continuousSamples)->E.A.shuffle
}
})
let combineAlgebraically = (op: Operation.convolutionOperation, t1: t, t2: t): t => {

View File

@ -198,6 +198,13 @@ module T = Dist({
| Discrete(m) => Discrete.T.variance(m)
| Continuous(m) => Continuous.T.variance(m)
}
// Draws `n` samples by delegating to the `sampleN` of whichever variant
// (`Mixed`, `Discrete`, or `Continuous`) the point set holds.
let sampleN = (t: t, n: int): array<float> =>
  switch t {
  | Continuous(c) => c->Continuous.T.sampleN(n)
  | Discrete(d) => d->Discrete.T.sampleN(n)
  | Mixed(m) => m->Mixed.T.sampleN(n)
  }
})
let logScore = (args: PointSetDist_Scoring.scoreArgs): result<float, Operation.Error.t> =>
@ -235,12 +242,6 @@ let isFloat = (t: t) =>
| _ => false
}
let sampleNRendered = (n, dist) => {
let integralCache = T.Integral.get(dist)
let distWithUpdatedIntegralCache = T.updateIntegralCache(Some(integralCache), dist)
doN(n, () => sample(distWithUpdatedIntegralCache))
}
let operate = (distToFloatOp: Operation.distToFloatOperation, s): float =>
switch distToFloatOp {
| #Pdf(f) => pdf(f, s)

View File

@ -139,7 +139,7 @@ let mixture = (values: array<(t, float)>, intendedLength: int) => {
->Belt.Array.mapWithIndex((i, (_, weight)) => (E.I.toFloat(i), weight /. totalWeight))
->XYShape.T.fromZippedArray
->Discrete.make
->Discrete.sampleN(intendedLength)
->Discrete.T.sampleN(intendedLength)
let dists = values->E.A2.fmap(E.Tuple2.first)->E.A2.fmap(T.get)
let samples =
discreteSamples

View File

@ -559,6 +559,7 @@ module A = {
let isEmpty = r => length(r) < 1
let stableSortBy = Belt.SortArray.stableSortBy
let toNoneIfEmpty = r => isEmpty(r) ? None : Some(r)
// Alias for `Belt.Array.shuffle`.
let shuffle = Belt.Array.shuffle
let toRanges = (a: array<'a>) =>
switch a |> Belt.Array.length {
| 0