This commit is contained in:
Ozzie Gooen 2022-05-15 19:42:10 -04:00
parent 7216f8079f
commit 6156ae65d1
3 changed files with 21 additions and 14 deletions

View File

@ -31,6 +31,8 @@ describe("eval on distribution functions", () => {
testEval("mean(normal(5,2))", "Ok(5)")
testEval("mean(lognormal(1,2))", "Ok(20.085536923187668)")
testEval("mean(gamma(5,5))", "Ok(25)")
testEval("mean(bernoulli(0.2))", "Ok(0.2)")
testEval("mean(bernoulli(0.8))", "Ok(0.8)")
})
describe("toString", () => {
testEval("toString(normal(5,2))", "Ok('Normal(5,2)')")

View File

@ -221,20 +221,27 @@ module Bernoulli = {
let make = p =>
p >= 0.0 && p <= 1.0
? Ok(#Bernoulli({p: p}))
: Error("Beta distribution parameters must be positive")
: Error("Bernoulli parameter must be between 0 and 1")
let pmf = (x, t: t) => Stdlib.Bernoulli.pmf(x, t.p)
//Bernoulli is a discrete distribution, so it doesn't really have a pdf().
//We fake this for now with the pmf function, but this should be fixed at some point.
let pdf = (x, t: t) => Stdlib.Bernoulli.pmf(x, t.p)
let cdf = (x, t: t) => Stdlib.Bernoulli.cdf(x, t.p)
let inv = (p, t: t) => Stdlib.Bernoulli.quantile(p, t.p)
let mean = (t: t) => Ok(Stdlib.Bernoulli.mean(t.p))
let min = (t: t) => t.p == 1.0 ? 1.0 : 0.0
let max = (t: t) => t.p == 0.0 ? 0.0 : 1.0
let sample = (t: t) => {
let s = Uniform.sample(({low: 0.0, high: 1.0}));
let s = Uniform.sample({low: 0.0, high: 1.0})
inv(s, t)
}
let toString = ({p}: t) => j`Bernoulli($p)`
let toPointSetDist = ({p}: t): PointSetTypes.pointSetDist => Discrete(
Discrete.make(~integralSumCache=Some(1.0), {xs: [0.0, 1.0], ys: [1.0 -. p, p]}),
)
}
module Gamma = {
type t = gamma
let make = (shape: float, scale: float) => {
@ -271,6 +278,9 @@ module Float = {
let mean = (t: t) => Ok(t)
let sample = (t: t) => t
let toString = (t: t) => j`Delta($t)`
let toPointSetDist = (t: t): PointSetTypes.pointSetDist => Discrete(
Discrete.make(~integralSumCache=Some(1.0), {xs: [t], ys: [1.0]}),
)
}
module From90thPercentile = {
@ -375,7 +385,7 @@ module T = {
| #Lognormal(n) => Lognormal.inv(minCdfValue, n)
| #Gamma(n) => Gamma.inv(minCdfValue, n)
| #Uniform({low}) => low
| #Bernoulli(n) => 0.0
| #Bernoulli(n) => Bernoulli.min(n)
| #Beta(n) => Beta.inv(minCdfValue, n)
| #Float(n) => n
}
@ -389,7 +399,7 @@ module T = {
| #Gamma(n) => Gamma.inv(maxCdfValue, n)
| #Lognormal(n) => Lognormal.inv(maxCdfValue, n)
| #Beta(n) => Beta.inv(maxCdfValue, n)
| #Bernoulli(n) => 1.0
| #Bernoulli(n) => Bernoulli.max(n)
| #Uniform({high}) => high
| #Float(n) => n
}
@ -404,8 +414,8 @@ module T = {
| #Beta(n) => Beta.mean(n)
| #Uniform(n) => Uniform.mean(n)
| #Gamma(n) => Gamma.mean(n)
| #Float(n) => Float.mean(n)
| #Bernoulli(n) => Bernoulli.mean(n)
| #Float(n) => Float.mean(n)
}
let operate = (distToFloatOp: Operation.distToFloatOperation, s) =>
@ -480,9 +490,8 @@ module T = {
d: symbolicDist,
): PointSetTypes.pointSetDist =>
switch d {
| #Float(v) => Discrete(Discrete.make(~integralSumCache=Some(1.0), {xs: [v], ys: [1.0]}))
| #Bernoulli(v) =>
Discrete(Discrete.make(~integralSumCache=Some(1.0), {xs: [0.0, 1.0], ys: [1.0 -. v.p, v.p]}))
| #Float(v) => Float.toPointSetDist(v)
| #Bernoulli(v) => Bernoulli.toPointSetDist(v)
| _ =>
let xs = interpolateXs(~xSelection, d, sampleCount)
let ys = xs |> E.A.fmap(x => pdf(x, d))

View File

@ -1,4 +0,0 @@
var Bernoulli = require("@stdlib/stats/base/dists/bernoulli").Bernoulli;
let bernoulliCdf = (p: number, x: number): number => new Bernoulli(p).cdf(x);
let bernoulliPmf = (p: number, x: number): number => new Bernoulli(p).cmf(x);