Added sampleN from Stdlib to allow for correct sampling of discrete distributions
This commit is contained in:
parent
7237f2709b
commit
e1efefaf7d
16
packages/squiggle-lang/__tests__/Stdlib_test.res
Normal file
16
packages/squiggle-lang/__tests__/Stdlib_test.res
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
open Jest
|
||||||
|
open Expect
|
||||||
|
|
||||||
|
let makeTest = (~only=false, str, item1, item2) =>
|
||||||
|
only
|
||||||
|
? Only.test(str, () => expect(item1)->toEqual(item2))
|
||||||
|
: test(str, () => expect(item1)->toEqual(item2))
|
||||||
|
|
||||||
|
describe("Stdlib", () => {
|
||||||
|
makeTest("min", Stdlib.Random.sample([1.0, 2.0], {probs: [0.5, 0.5], size: 10}) |> E.A.length, 10)
|
||||||
|
makeTest(
|
||||||
|
"min",
|
||||||
|
Stdlib.Random.sample([1.0, 2.0], {probs: [0.5, 0.5], size: 10}) |> E.A.uniq |> E.A.Floats.sort,
|
||||||
|
[1.0, 2.0],
|
||||||
|
)
|
||||||
|
})
|
|
@ -18,6 +18,7 @@
|
||||||
"benchmark": "ts-node benchmark/conversion_tests.ts",
|
"benchmark": "ts-node benchmark/conversion_tests.ts",
|
||||||
"test": "jest",
|
"test": "jest",
|
||||||
"test:ts": "jest __tests__/TS/",
|
"test:ts": "jest __tests__/TS/",
|
||||||
|
"test:stdlib": "jest __tests__/Stdlib_test.bs.js",
|
||||||
"test:rescript": "jest --modulePathIgnorePatterns=__tests__/TS/*",
|
"test:rescript": "jest --modulePathIgnorePatterns=__tests__/TS/*",
|
||||||
"test:watch": "jest --watchAll",
|
"test:watch": "jest --watchAll",
|
||||||
"test:fnRegistry": "jest __tests__/SquiggleLibrary/SquiggleLibrary_FunctionRegistryLibrary_test.bs.js",
|
"test:fnRegistry": "jest __tests__/SquiggleLibrary/SquiggleLibrary_FunctionRegistryLibrary_test.bs.js",
|
||||||
|
|
|
@ -224,3 +224,8 @@ module T = Dist({
|
||||||
XYShape.Analysis.getVarianceDangerously(t, mean, getMeanOfSquares)
|
XYShape.Analysis.getVarianceDangerously(t, mean, getMeanOfSquares)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
let sampleN = (t: t, n): array<float> => {
|
||||||
|
let normalized = t |> T.normalize |> getShape
|
||||||
|
Stdlib.Random.sample(normalized.xs, {probs: normalized.ys, size: n})
|
||||||
|
}
|
||||||
|
|
|
@ -257,3 +257,7 @@ let toSparkline = (t: t, bucketCount): result<string, PointSetTypes.sparklineErr
|
||||||
->E.O2.fmap(Continuous.downsampleEquallyOverX(bucketCount))
|
->E.O2.fmap(Continuous.downsampleEquallyOverX(bucketCount))
|
||||||
->E.O2.toResult(PointSetTypes.CannotSparklineDiscrete)
|
->E.O2.toResult(PointSetTypes.CannotSparklineDiscrete)
|
||||||
->E.R2.fmap(r => Continuous.getShape(r).ys->Sparklines.create())
|
->E.R2.fmap(r => Continuous.getShape(r).ys->Sparklines.create())
|
||||||
|
|
||||||
|
let makeDiscrete = (d):t => Discrete(d)
|
||||||
|
let makeContinuous = (d):t => Continuous(d)
|
||||||
|
let makeMixed = (d):t => Mixed(d)
|
|
@ -134,12 +134,19 @@ let percentile = (t, f) => T.get(t)->E.A.Floats.percentile(f)
|
||||||
|
|
||||||
let mixture = (values: array<(t, float)>, intendedLength: int) => {
|
let mixture = (values: array<(t, float)>, intendedLength: int) => {
|
||||||
let totalWeight = values->E.A2.fmap(E.Tuple2.second)->E.A.Floats.sum
|
let totalWeight = values->E.A2.fmap(E.Tuple2.second)->E.A.Floats.sum
|
||||||
values
|
let discreteSamples =
|
||||||
->E.A2.fmap(((dist, weight)) => {
|
values
|
||||||
let adjustedWeight = weight /. totalWeight
|
->Belt.Array.mapWithIndex((i, (_, weight)) => (E.I.toFloat(i), weight /. totalWeight))
|
||||||
let samplesToGet = adjustedWeight *. E.I.toFloat(intendedLength) |> E.Float.toInt
|
->XYShape.T.fromZippedArray
|
||||||
sampleN(dist, samplesToGet)
|
->Discrete.make
|
||||||
})
|
->Discrete.sampleN(intendedLength)
|
||||||
->E.A.concatMany
|
let dists = values->E.A2.fmap(E.Tuple2.first)->E.A2.fmap(T.get)
|
||||||
->T.make
|
let samples =
|
||||||
|
discreteSamples
|
||||||
|
->Belt.Array.mapWithIndex((index, distIndexToChoose) => {
|
||||||
|
let chosenDist = E.A.get(dists, E.Float.toInt(distIndexToChoose))
|
||||||
|
chosenDist |> E.O2.bind(E.A.get(_, index))
|
||||||
|
})
|
||||||
|
->E.A.O.openIfAllSome
|
||||||
|
(samples |> E.O.toExn("Mixture unreachable error"))->T.make
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,3 +38,11 @@ module Logistic = {
|
||||||
@module external variance: (float, float) => float = "@stdlib/stats/base/dists/logistic/variance"
|
@module external variance: (float, float) => float = "@stdlib/stats/base/dists/logistic/variance"
|
||||||
let variance = variance
|
let variance = variance
|
||||||
}
|
}
|
||||||
|
|
||||||
|
module Random = {
|
||||||
|
type sampleArgs = {
|
||||||
|
probs: array<float>,
|
||||||
|
size: int,
|
||||||
|
}
|
||||||
|
@module external sample: (array<float>, sampleArgs) => array<float> = "@stdlib/random/sample"
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user