Added sampleN from Stdlib to allow for correct sampling of discrete distributions
This commit is contained in:
		
							parent
							
								
									7237f2709b
								
							
						
					
					
						commit
						e1efefaf7d
					
				
							
								
								
									
										16
									
								
								packages/squiggle-lang/__tests__/Stdlib_test.res
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								packages/squiggle-lang/__tests__/Stdlib_test.res
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,16 @@
 | 
			
		|||
open Jest
 | 
			
		||||
open Expect
 | 
			
		||||
 | 
			
		||||
let makeTest = (~only=false, str, item1, item2) =>
 | 
			
		||||
  only
 | 
			
		||||
    ? Only.test(str, () => expect(item1)->toEqual(item2))
 | 
			
		||||
    : test(str, () => expect(item1)->toEqual(item2))
 | 
			
		||||
 | 
			
		||||
describe("Stdlib", () => {
 | 
			
		||||
  makeTest("min", Stdlib.Random.sample([1.0, 2.0], {probs: [0.5, 0.5], size: 10}) |> E.A.length, 10)
 | 
			
		||||
  makeTest(
 | 
			
		||||
    "min",
 | 
			
		||||
    Stdlib.Random.sample([1.0, 2.0], {probs: [0.5, 0.5], size: 10}) |> E.A.uniq |> E.A.Floats.sort,
 | 
			
		||||
    [1.0, 2.0],
 | 
			
		||||
  )
 | 
			
		||||
})
 | 
			
		||||
| 
						 | 
				
			
			@ -18,6 +18,7 @@
 | 
			
		|||
    "benchmark": "ts-node benchmark/conversion_tests.ts",
 | 
			
		||||
    "test": "jest",
 | 
			
		||||
    "test:ts": "jest __tests__/TS/",
 | 
			
		||||
    "test:stdlib": "jest __tests__/Stdlib_test.bs.js",
 | 
			
		||||
    "test:rescript": "jest --modulePathIgnorePatterns=__tests__/TS/*",
 | 
			
		||||
    "test:watch": "jest --watchAll",
 | 
			
		||||
    "test:fnRegistry": "jest __tests__/SquiggleLibrary/SquiggleLibrary_FunctionRegistryLibrary_test.bs.js",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -224,3 +224,8 @@ module T = Dist({
 | 
			
		|||
    XYShape.Analysis.getVarianceDangerously(t, mean, getMeanOfSquares)
 | 
			
		||||
  }
 | 
			
		||||
})
 | 
			
		||||
 | 
			
		||||
let sampleN = (t: t, n): array<float> => {
 | 
			
		||||
  let normalized = t |> T.normalize |> getShape
 | 
			
		||||
  Stdlib.Random.sample(normalized.xs, {probs: normalized.ys, size: n})
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -257,3 +257,7 @@ let toSparkline = (t: t, bucketCount): result<string, PointSetTypes.sparklineErr
 | 
			
		|||
  ->E.O2.fmap(Continuous.downsampleEquallyOverX(bucketCount))
 | 
			
		||||
  ->E.O2.toResult(PointSetTypes.CannotSparklineDiscrete)
 | 
			
		||||
  ->E.R2.fmap(r => Continuous.getShape(r).ys->Sparklines.create())
 | 
			
		||||
 | 
			
		||||
let makeDiscrete = (d):t => Discrete(d)
 | 
			
		||||
let makeContinuous = (d):t => Continuous(d)
 | 
			
		||||
let makeMixed = (d):t => Mixed(d)
 | 
			
		||||
| 
						 | 
				
			
			@ -134,12 +134,19 @@ let percentile = (t, f) => T.get(t)->E.A.Floats.percentile(f)
 | 
			
		|||
 | 
			
		||||
let mixture = (values: array<(t, float)>, intendedLength: int) => {
 | 
			
		||||
  let totalWeight = values->E.A2.fmap(E.Tuple2.second)->E.A.Floats.sum
 | 
			
		||||
  let discreteSamples =
 | 
			
		||||
    values
 | 
			
		||||
  ->E.A2.fmap(((dist, weight)) => {
 | 
			
		||||
    let adjustedWeight = weight /. totalWeight
 | 
			
		||||
    let samplesToGet = adjustedWeight *. E.I.toFloat(intendedLength) |> E.Float.toInt
 | 
			
		||||
    sampleN(dist, samplesToGet)
 | 
			
		||||
    ->Belt.Array.mapWithIndex((i, (_, weight)) => (E.I.toFloat(i), weight /. totalWeight))
 | 
			
		||||
    ->XYShape.T.fromZippedArray
 | 
			
		||||
    ->Discrete.make
 | 
			
		||||
    ->Discrete.sampleN(intendedLength)
 | 
			
		||||
  let dists = values->E.A2.fmap(E.Tuple2.first)->E.A2.fmap(T.get)
 | 
			
		||||
  let samples =
 | 
			
		||||
    discreteSamples
 | 
			
		||||
    ->Belt.Array.mapWithIndex((index, distIndexToChoose) => {
 | 
			
		||||
      let chosenDist = E.A.get(dists, E.Float.toInt(distIndexToChoose))
 | 
			
		||||
      chosenDist |> E.O2.bind(E.A.get(_, index))
 | 
			
		||||
    })
 | 
			
		||||
  ->E.A.concatMany
 | 
			
		||||
  ->T.make
 | 
			
		||||
    ->E.A.O.openIfAllSome
 | 
			
		||||
  (samples |> E.O.toExn("Mixture unreachable error"))->T.make
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -38,3 +38,11 @@ module Logistic = {
 | 
			
		|||
  @module external variance: (float, float) => float = "@stdlib/stats/base/dists/logistic/variance"
 | 
			
		||||
  let variance = variance
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
module Random = {
 | 
			
		||||
  type sampleArgs = {
 | 
			
		||||
    probs: array<float>,
 | 
			
		||||
    size: int,
 | 
			
		||||
  }
 | 
			
		||||
  @module external sample: (array<float>, sampleArgs) => array<float> = "@stdlib/random/sample"
 | 
			
		||||
}
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user