fix: SampleSet.fromDist works for discrete and mixed
This commit is contained in:
		
							parent
							
								
									462f1c9649
								
							
						
					
					
						commit
						93f4c1e0c2
					
				|  | @ -31,9 +31,9 @@ let isSymbolic = (t: t) => | ||||||
| 
 | 
 | ||||||
| let sampleN = (t: t, n) => | let sampleN = (t: t, n) => | ||||||
|   switch t { |   switch t { | ||||||
|   | PointSet(r) => PointSetDist.sampleNRendered(n, r) |   | PointSet(r) => PointSetDist.T.sampleN(r,n) | ||||||
|   | Symbolic(r) => SymbolicDist.T.sampleN(n, r) |  | ||||||
|   | SampleSet(r) => SampleSetDist.sampleN(r, n) |   | SampleSet(r) => SampleSetDist.sampleN(r, n) | ||||||
|  |   | Symbolic(r) => SymbolicDist.T.sampleN(n, r) | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
| let sample = (t: t) => sampleN(t, 1)->E.A.first |> E.O.toExn("Should not have happened") | let sample = (t: t) => sampleN(t, 1)->E.A.first |> E.O.toExn("Should not have happened") | ||||||
|  |  | ||||||
|  | @ -270,6 +270,25 @@ module T = Dist({ | ||||||
|   } |   } | ||||||
|   let variance = (t: t): float => |   let variance = (t: t): float => | ||||||
|     XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares) |     XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares) | ||||||
|  | 
 | ||||||
|  | let doN = (n, fn) => { | ||||||
|  |   let items = Belt.Array.make(n, 0.0) | ||||||
|  |   for x in 0 to n - 1 { | ||||||
|  |     let _ = Belt.Array.set(items, x, fn()) | ||||||
|  |   } | ||||||
|  |   items | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | let sample = (t: t): float => { | ||||||
|  |   let randomItem = Random.float(1.0) | ||||||
|  |   t |> integralYtoX(randomItem) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | let sampleN = (dist, n) => { | ||||||
|  |   let integralCache = integral(dist) | ||||||
|  |   let distWithUpdatedIntegralCache = updateIntegralCache(Some(integralCache), dist) | ||||||
|  |   doN(n, () => sample(distWithUpdatedIntegralCache)) | ||||||
|  | } | ||||||
| }) | }) | ||||||
| 
 | 
 | ||||||
| let isNormalized = (t: t): bool => { | let isNormalized = (t: t): bool => { | ||||||
|  |  | ||||||
|  | @ -223,9 +223,9 @@ module T = Dist({ | ||||||
|     let getMeanOfSquares = t => t |> shapeMap(XYShape.T.square) |> mean |     let getMeanOfSquares = t => t |> shapeMap(XYShape.T.square) |> mean | ||||||
|     XYShape.Analysis.getVarianceDangerously(t, mean, getMeanOfSquares) |     XYShape.Analysis.getVarianceDangerously(t, mean, getMeanOfSquares) | ||||||
|   } |   } | ||||||
| }) |  | ||||||
| 
 | 
 | ||||||
| let sampleN = (t: t, n): array<float> => { | let sampleN = (t: t, n): array<float> => { | ||||||
|   let normalized = t->T.normalize->getShape |   let normalized = t->normalize->getShape | ||||||
|   Stdlib.Random.sample(normalized.xs, {probs: normalized.ys, size: n}) |   Stdlib.Random.sample(normalized.xs, {probs: normalized.ys, size: n}) | ||||||
| } | } | ||||||
|  | }) | ||||||
|  |  | ||||||
|  | @ -33,6 +33,7 @@ module type dist = { | ||||||
| 
 | 
 | ||||||
|   let mean: t => float |   let mean: t => float | ||||||
|   let variance: t => float |   let variance: t => float | ||||||
|  |   let sampleN: (t, int) => array<float> | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| module Dist = (T: dist) => { | module Dist = (T: dist) => { | ||||||
|  | @ -64,6 +65,8 @@ module Dist = (T: dist) => { | ||||||
|     let yToX = T.integralYtoX |     let yToX = T.integralYtoX | ||||||
|     let sum = T.integralEndY |     let sum = T.integralEndY | ||||||
|   } |   } | ||||||
|  | 
 | ||||||
|  |   let sampleN = T.sampleN | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| module Common = { | module Common = { | ||||||
|  |  | ||||||
|  | @ -270,38 +270,47 @@ module T = Dist({ | ||||||
|     }) |     }) | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   let mean = ({discrete, continuous}: t): float => { |   let discreteIntegralSum =({discrete}: t): float => Discrete.T.Integral.sum(discrete) | ||||||
|  |   let continuousIntegralSum =({continuous}: t): float => Continuous.T.Integral.sum(continuous) | ||||||
|  |   let integralSum =(t:t): float => discreteIntegralSum(t) +. continuousIntegralSum(t) | ||||||
|  | 
 | ||||||
|  |   let mean = ({discrete, continuous} as t: t): float => { | ||||||
|     let discreteMean = Discrete.T.mean(discrete) |     let discreteMean = Discrete.T.mean(discrete) | ||||||
|     let continuousMean = Continuous.T.mean(continuous) |     let continuousMean = Continuous.T.mean(continuous) | ||||||
| 
 | 
 | ||||||
|     // the combined mean is the weighted sum of the two: |     (discreteMean *. discreteIntegralSum(t) +. continuousMean *. continuousIntegralSum(t)) /. | ||||||
|     let discreteIntegralSum = Discrete.T.Integral.sum(discrete) |       integralSum(t) | ||||||
|     let continuousIntegralSum = Continuous.T.Integral.sum(continuous) |  | ||||||
|     let totalIntegralSum = discreteIntegralSum +. continuousIntegralSum |  | ||||||
| 
 |  | ||||||
|     (discreteMean *. discreteIntegralSum +. continuousMean *. continuousIntegralSum) /. |  | ||||||
|       totalIntegralSum |  | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   let variance = ({discrete, continuous} as t: t): float => { |   let variance = ({discrete, continuous} as t: t): float => { | ||||||
|     // the combined mean is the weighted sum of the two: |     // the combined mean is the weighted sum of the two: | ||||||
|     let discreteIntegralSum = Discrete.T.Integral.sum(discrete) |  | ||||||
|     let continuousIntegralSum = Continuous.T.Integral.sum(continuous) |  | ||||||
|     let totalIntegralSum = discreteIntegralSum +. continuousIntegralSum |  | ||||||
| 
 | 
 | ||||||
|  |     let _discreteIntegralSum = discreteIntegralSum(t) | ||||||
|  |     let _integralSum = integralSum(t) | ||||||
|     let getMeanOfSquares = ({discrete, continuous}: t) => { |     let getMeanOfSquares = ({discrete, continuous}: t) => { | ||||||
|       let discreteMean = discrete |> Discrete.shapeMap(XYShape.T.square) |> Discrete.T.mean |       let discreteMean = discrete |> Discrete.shapeMap(XYShape.T.square) |> Discrete.T.mean | ||||||
|       let continuousMean = continuous |> Continuous.Analysis.getMeanOfSquares |       let continuousMean = continuous -> Continuous.Analysis.getMeanOfSquares | ||||||
|       (discreteMean *. discreteIntegralSum +. continuousMean *. continuousIntegralSum) /. |       (discreteMean *. discreteIntegralSum(t) +. continuousMean *. continuousIntegralSum(t)) /. | ||||||
|         totalIntegralSum |         integralSum(t) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     switch discreteIntegralSum /. totalIntegralSum { |     switch _discreteIntegralSum /. _integralSum { | ||||||
|     | 1.0 => Discrete.T.variance(discrete) |     | 1.0 => Discrete.T.variance(discrete) | ||||||
|     | 0.0 => Continuous.T.variance(continuous) |     | 0.0 => Continuous.T.variance(continuous) | ||||||
|     | _ => XYShape.Analysis.getVarianceDangerously(t, mean, getMeanOfSquares) |     | _ => XYShape.Analysis.getVarianceDangerously(t, mean, getMeanOfSquares) | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|  |    | ||||||
|  |   let sampleN = (t: t, n:int): array<float> => { | ||||||
|  |     let discreteIntegralSum = discreteIntegralSum(t); | ||||||
|  |     let integralSum = integralSum(t); | ||||||
|  |     let discreteSampleLength:int = (Js.Int.toFloat(n) *. discreteIntegralSum /. integralSum) -> E.Float.toInt | ||||||
|  |     let continuousSampleLength = n - discreteSampleLength; | ||||||
|  |     let continuousSamples = t.continuous ->Continuous.T.normalize-> Continuous.T.sampleN( continuousSampleLength) | ||||||
|  |     let discreteSamples = t.discrete ->Discrete.T.normalize->Discrete.T.sampleN(discreteSampleLength) | ||||||
|  |     Js.log3("Samples", continuousSamples, discreteSamples); | ||||||
|  |     E.A.concat(discreteSamples, continuousSamples) -> E.A.shuffle | ||||||
|  |   } | ||||||
| }) | }) | ||||||
| 
 | 
 | ||||||
| let combineAlgebraically = (op: Operation.convolutionOperation, t1: t, t2: t): t => { | let combineAlgebraically = (op: Operation.convolutionOperation, t1: t, t2: t): t => { | ||||||
|  |  | ||||||
|  | @ -198,6 +198,13 @@ module T = Dist({ | ||||||
|     | Discrete(m) => Discrete.T.variance(m) |     | Discrete(m) => Discrete.T.variance(m) | ||||||
|     | Continuous(m) => Continuous.T.variance(m) |     | Continuous(m) => Continuous.T.variance(m) | ||||||
|     } |     } | ||||||
|  | 
 | ||||||
|  |   let sampleN = (t: t, int): array<float> => | ||||||
|  |     switch t { | ||||||
|  |     | Mixed(m) => Mixed.T.sampleN(m,int) | ||||||
|  |     | Discrete(m) => Discrete.T.sampleN(m,int) | ||||||
|  |     | Continuous(m) => Continuous.T.sampleN(m,int) | ||||||
|  |     } | ||||||
| }) | }) | ||||||
| 
 | 
 | ||||||
| let logScore = (args: PointSetDist_Scoring.scoreArgs): result<float, Operation.Error.t> => | let logScore = (args: PointSetDist_Scoring.scoreArgs): result<float, Operation.Error.t> => | ||||||
|  | @ -235,12 +242,6 @@ let isFloat = (t: t) => | ||||||
|   | _ => false |   | _ => false | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
| let sampleNRendered = (n, dist) => { |  | ||||||
|   let integralCache = T.Integral.get(dist) |  | ||||||
|   let distWithUpdatedIntegralCache = T.updateIntegralCache(Some(integralCache), dist) |  | ||||||
|   doN(n, () => sample(distWithUpdatedIntegralCache)) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| let operate = (distToFloatOp: Operation.distToFloatOperation, s): float => | let operate = (distToFloatOp: Operation.distToFloatOperation, s): float => | ||||||
|   switch distToFloatOp { |   switch distToFloatOp { | ||||||
|   | #Pdf(f) => pdf(f, s) |   | #Pdf(f) => pdf(f, s) | ||||||
|  |  | ||||||
|  | @ -139,7 +139,7 @@ let mixture = (values: array<(t, float)>, intendedLength: int) => { | ||||||
|     ->Belt.Array.mapWithIndex((i, (_, weight)) => (E.I.toFloat(i), weight /. totalWeight)) |     ->Belt.Array.mapWithIndex((i, (_, weight)) => (E.I.toFloat(i), weight /. totalWeight)) | ||||||
|     ->XYShape.T.fromZippedArray |     ->XYShape.T.fromZippedArray | ||||||
|     ->Discrete.make |     ->Discrete.make | ||||||
|     ->Discrete.sampleN(intendedLength) |     ->Discrete.T.sampleN(intendedLength) | ||||||
|   let dists = values->E.A2.fmap(E.Tuple2.first)->E.A2.fmap(T.get) |   let dists = values->E.A2.fmap(E.Tuple2.first)->E.A2.fmap(T.get) | ||||||
|   let samples = |   let samples = | ||||||
|     discreteSamples |     discreteSamples | ||||||
|  |  | ||||||
|  | @ -559,6 +559,7 @@ module A = { | ||||||
|   let isEmpty = r => length(r) < 1 |   let isEmpty = r => length(r) < 1 | ||||||
|   let stableSortBy = Belt.SortArray.stableSortBy |   let stableSortBy = Belt.SortArray.stableSortBy | ||||||
|   let toNoneIfEmpty = r => isEmpty(r) ? None : Some(r) |   let toNoneIfEmpty = r => isEmpty(r) ? None : Some(r) | ||||||
|  |   let shuffle = Belt.Array.shuffle | ||||||
|   let toRanges = (a: array<'a>) => |   let toRanges = (a: array<'a>) => | ||||||
|     switch a |> Belt.Array.length { |     switch a |> Belt.Array.length { | ||||||
|     | 0 |     | 0 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user