diff --git a/packages/squiggle-lang/__tests__/JS__Test.ts b/packages/squiggle-lang/__tests__/JS__Test.ts index 07a4cabe..ba2f91f4 100644 --- a/packages/squiggle-lang/__tests__/JS__Test.ts +++ b/packages/squiggle-lang/__tests__/JS__Test.ts @@ -39,10 +39,18 @@ describe("Multimodal too many weights error", () => { }); describe("GenericDist", () => { + + //It's important that sampleCount is less than 9. If it's more, than that will create randomness + let env = { sampleCount: 8, xyPointLength: 100 }; let dist = new GenericDist( { tag: "SampleSet", value: [3, 4, 5, 6, 6, 7, 10, 15, 30] }, - { sampleCount: 100, xyPointLength: 100 } + env ); + let dist2 = new GenericDist( + { tag: "SampleSet", value: [20, 22, 24, 29, 30, 35, 38, 44, 52] }, + env + ); + test("mean", () => { expect(dist.mean().value).toBeCloseTo(3.737); }); @@ -63,4 +71,16 @@ describe("GenericDist", () => { test("toSparkline", () => { expect(dist.toSparkline(20).value).toBe("▁▁▃▅███▆▄▃▂▁▁▂▂▃▂▁▁▁"); }); + test("algebraicAdd", () => { + expect( + resultMap(dist.algebraicAdd(dist2), (r: GenericDist) => r.toSparkline(20)) + .value.value + ).toBe("▁▁▂▄▆████▇▆▄▄▃▃▃▂▁▁▁"); + }); + test("pointwiseAdd", () => { + expect( + resultMap(dist.pointwiseAdd(dist2), (r: GenericDist) => r.toSparkline(20)) + .value.value + ).toBe("▁▂▅██▅▅▅▆▇█▆▅▃▃▂▂▁▁▁"); + }); }); diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation/DistributionOperation.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation/DistributionOperation.res index 12b94bcc..11e7f190 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation/DistributionOperation.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation/DistributionOperation.res @@ -221,7 +221,7 @@ module Constructors = { let pointwiseSubtract = (~env, dist1, dist2) => C.pointwiseSubtract(dist1, dist2)->run(~env)->toDistR let pointwiseLogarithm = (~env, dist1, dist2) => - C.pointwiseSubtract(dist1, dist2)->run(~env)->toDistR + C.pointwiseLogarithm(dist1, dist2)->run(~env)->toDistR let pointwiseExponentiate = (~env, dist1, dist2) => - C.pointwiseSubtract(dist1, dist2)->run(~env)->toDistR + C.pointwiseExponentiate(dist1, dist2)->run(~env)->toDistR } diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res index 4815e98c..c321dc4a 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res @@ -10,7 +10,7 @@ let sampleN = (t: t, n) => switch t { | PointSet(r) => Ok(PointSetDist.sampleNRendered(n, r)) | Symbolic(r) => Ok(SymbolicDist.T.sampleN(n, r)) - | SampleSet(_) => Error(GenericDist_Types.NotYetImplemented) + | SampleSet(r) => Ok(SampleSet.sampleN(r, n)) } let fromFloat = (f: float): t => Symbolic(SymbolicDist.Float.make(f)) @@ -83,8 +83,10 @@ let toPointSet = ( let toSparkline = (t: t, ~sampleCount: int, ~buckets: int=20, unit): result => t - ->toPointSet(~xSelection=#Linear, ~xyPointLength=buckets*3, ~sampleCount, ()) - ->E.R.bind(r => r->PointSetDist.toSparkline(buckets)->E.R2.errMap(r => Error(GenericDist_Types.Other(r)))) + ->toPointSet(~xSelection=#Linear, ~xyPointLength=buckets * 3, ~sampleCount, ()) + ->E.R.bind(r => + r->PointSetDist.toSparkline(buckets)->E.R2.errMap(r => Error(GenericDist_Types.Other(r))) + ) module Truncate = { let trySymbolicSimplification = (leftCutoff, rightCutoff, t: t): option => diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist_Types.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist_Types.res index 1bfb3e82..b0a30d52 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist_Types.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist_Types.res @@ -140,7 +140,7 @@ module Constructors = { dist1, ) let pointwiseAdd = (dist1, dist2): t => FromDist( - ToDistCombination(Algebraic, #Add, #Dist(dist2)), + ToDistCombination(Pointwise, #Add, #Dist(dist2)), dist1, ) let pointwiseMultiply = (dist1, dist2): t => FromDist( diff --git a/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSet.res b/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSet.res index 3c8f686a..07855686 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSet.res +++ b/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSet.res @@ -60,6 +60,7 @@ module Internals = { : { let _ = Js.Array.push(element, continuous) } + () }) (continuous, discrete) @@ -143,4 +144,17 @@ let toPointSetDist = ( } samplesParse +} + +let sample = (t: t): float => { + let i = E.Int.random(~min=0, ~max=E.A.length(t) - 1) + E.A.unsafe_get(t, i) +} + +let sampleN = (t: t, n) => { + if n <= E.A.length(t) { + E.A.slice(t, ~offset=0, ~len=n) + } else { + Belt.Array.makeBy(n, _ => sample(t)) + } } \ No newline at end of file diff --git a/packages/squiggle-lang/src/rescript/Utility/E.res b/packages/squiggle-lang/src/rescript/Utility/E.res index 01dbde7d..16d87896 100644 --- a/packages/squiggle-lang/src/rescript/Utility/E.res +++ b/packages/squiggle-lang/src/rescript/Utility/E.res @@ -24,6 +24,7 @@ module FloatFloatMap = { module Int = { let max = (i1: int, i2: int) => i1 > i2 ? i1 : i2 + let random = (~min, ~max) => Js.Math.random_int(min, max) } /* Utils */ module U = { @@ -277,6 +278,7 @@ module A = { let fold_right = Array.fold_right let concatMany = Belt.Array.concatMany let keepMap = Belt.Array.keepMap + let slice = Belt.Array.slice let init = Array.init let reduce = Belt.Array.reduce let reducei = Belt.Array.reduceWithIndex