diff --git a/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res b/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res index 571a838b..253a9cc3 100644 --- a/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res +++ b/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res @@ -31,6 +31,8 @@ describe("eval on distribution functions", () => { testEval("mean(normal(5,2))", "Ok(5)") testEval("mean(lognormal(1,2))", "Ok(20.085536923187668)") testEval("mean(gamma(5,5))", "Ok(25)") + testEval("mean(bernoulli(0.2))", "Ok(0.2)") + testEval("mean(bernoulli(0.8))", "Ok(0.8)") }) describe("toString", () => { testEval("toString(normal(5,2))", "Ok('Normal(5,2)')") diff --git a/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res b/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res index 6c57430c..6356f97c 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res @@ -221,20 +221,27 @@ module Bernoulli = { let make = p => p >= 0.0 && p <= 1.0 ? Ok(#Bernoulli({p: p})) - : Error("Beta distribution parameters must be positive") + : Error("Bernoulli parameter must be between 0 and 1") let pmf = (x, t: t) => Stdlib.Bernoulli.pmf(x, t.p) + + //Bernoulli is a discrete distribution, so it doesn't really have a pdf(). + //We fake this for now with the pmf function, but this should be fixed at some point. let pdf = (x, t: t) => Stdlib.Bernoulli.pmf(x, t.p) let cdf = (x, t: t) => Stdlib.Bernoulli.cdf(x, t.p) let inv = (p, t: t) => Stdlib.Bernoulli.quantile(p, t.p) let mean = (t: t) => Ok(Stdlib.Bernoulli.mean(t.p)) + let min = (t: t) => t.p == 1.0 ? 1.0 : 0.0 + let max = (t: t) => t.p == 0.0 ? 0.0 : 1.0 let sample = (t: t) => { - let s = Uniform.sample(({low: 0.0, high: 1.0})); - inv(s,t) + let s = Uniform.sample({low: 0.0, high: 1.0}) + inv(s, t) } let toString = ({p}: t) => j`Bernoulli($p)` + let toPointSetDist = ({p}: t): PointSetTypes.pointSetDist => Discrete( + Discrete.make(~integralSumCache=Some(1.0), {xs: [0.0, 1.0], ys: [1.0 -. p, p]}), + ) } - module Gamma = { type t = gamma let make = (shape: float, scale: float) => { @@ -271,6 +278,9 @@ module Float = { let mean = (t: t) => Ok(t) let sample = (t: t) => t let toString = (t: t) => j`Delta($t)` + let toPointSetDist = (t: t): PointSetTypes.pointSetDist => Discrete( + Discrete.make(~integralSumCache=Some(1.0), {xs: [t], ys: [1.0]}), + ) } module From90thPercentile = { @@ -375,7 +385,7 @@ module T = { | #Lognormal(n) => Lognormal.inv(minCdfValue, n) | #Gamma(n) => Gamma.inv(minCdfValue, n) | #Uniform({low}) => low - | #Bernoulli(n) => 0.0 + | #Bernoulli(n) => Bernoulli.min(n) | #Beta(n) => Beta.inv(minCdfValue, n) | #Float(n) => n } @@ -389,7 +399,7 @@ module T = { | #Gamma(n) => Gamma.inv(maxCdfValue, n) | #Lognormal(n) => Lognormal.inv(maxCdfValue, n) | #Beta(n) => Beta.inv(maxCdfValue, n) - | #Bernoulli(n) => 1.0 + | #Bernoulli(n) => Bernoulli.max(n) | #Uniform({high}) => high | #Float(n) => n } @@ -404,8 +414,8 @@ module T = { | #Beta(n) => Beta.mean(n) | #Uniform(n) => Uniform.mean(n) | #Gamma(n) => Gamma.mean(n) - | #Float(n) => Float.mean(n) | #Bernoulli(n) => Bernoulli.mean(n) + | #Float(n) => Float.mean(n) } let operate = (distToFloatOp: Operation.distToFloatOperation, s) => @@ -480,9 +490,8 @@ module T = { d: symbolicDist, ): PointSetTypes.pointSetDist => switch d { - | #Float(v) => Discrete(Discrete.make(~integralSumCache=Some(1.0), {xs: [v], ys: [1.0]})) - | #Bernoulli(v) => - Discrete(Discrete.make(~integralSumCache=Some(1.0), {xs: [0.0, 1.0], ys: [1.0 -. v.p, v.p]})) + | #Float(v) => Float.toPointSetDist(v) + | #Bernoulli(v) => Bernoulli.toPointSetDist(v) | _ => let xs = interpolateXs(~xSelection, d, sampleCount) let ys = xs |> E.A.fmap(x => pdf(x, d)) diff --git a/packages/squiggle-lang/src/rescript/Utility/stdlib.ts b/packages/squiggle-lang/src/rescript/Utility/stdlib.ts deleted file mode 100644 index 5d106084..00000000 --- a/packages/squiggle-lang/src/rescript/Utility/stdlib.ts +++ /dev/null @@ -1,4 +0,0 @@ -var Bernoulli = require("@stdlib/stats/base/dists/bernoulli").Bernoulli; - -let bernoulliCdf = (p: number, x: number): number => new Bernoulli(p).cdf(x); -let bernoulliPmf = (p: number, x: number): number => new Bernoulli(p).cmf(x);