Merge pull request #479 from quantified-uncertainty/gamma-distribution

Add Gamma distribution
This commit is contained in:
Ozzie Gooen 2022-05-04 12:01:28 -04:00 committed by GitHub
commit 7585bd3599
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 46 additions and 1 deletions

View File

@ -30,6 +30,7 @@ describe("eval on distribution functions", () => {
describe("mean", () => { describe("mean", () => {
testEval("mean(normal(5,2))", "Ok(5)") testEval("mean(normal(5,2))", "Ok(5)")
testEval("mean(lognormal(1,2))", "Ok(20.085536923187668)") testEval("mean(lognormal(1,2))", "Ok(20.085536923187668)")
testEval("mean(gamma(5,5))", "Ok(25)")
}) })
describe("toString", () => { describe("toString", () => {
testEval("toString(normal(5,2))", "Ok('Normal(5,2)')") testEval("toString(normal(5,2))", "Ok('Normal(5,2)')")

View File

@ -216,6 +216,27 @@ module Uniform = {
} }
} }
module Gamma = {
type t = gamma
let make = (shape: float, scale: float) => {
if shape > 0. {
if scale > 0. {
Ok(#Gamma({shape: shape, scale: scale}))
} else {
Error("scale must be larger than 0")
}
} else {
Error("shape must be larger than 0")
}
}
let pdf = (x: float, t: t) => Jstat.Gamma.pdf(x, t.shape, t.scale)
let cdf = (x: float, t: t) => Jstat.Gamma.cdf(x, t.shape, t.scale)
let inv = (p: float, t: t) => Jstat.Gamma.inv(p, t.shape, t.scale)
let sample = (t: t) => Jstat.Gamma.sample(t.shape, t.scale)
let mean = (t: t) => Ok(Jstat.Gamma.mean(t.shape, t.scale))
let toString = ({shape, scale}: t) => j`($shape, $scale)`
}
module Float = { module Float = {
type t = float type t = float
let make = t => #Float(t) let make = t => #Float(t)
@ -252,6 +273,7 @@ module T = {
| #Triangular(n) => Triangular.pdf(x, n) | #Triangular(n) => Triangular.pdf(x, n)
| #Exponential(n) => Exponential.pdf(x, n) | #Exponential(n) => Exponential.pdf(x, n)
| #Cauchy(n) => Cauchy.pdf(x, n) | #Cauchy(n) => Cauchy.pdf(x, n)
| #Gamma(n) => Gamma.pdf(x, n)
| #Lognormal(n) => Lognormal.pdf(x, n) | #Lognormal(n) => Lognormal.pdf(x, n)
| #Uniform(n) => Uniform.pdf(x, n) | #Uniform(n) => Uniform.pdf(x, n)
| #Beta(n) => Beta.pdf(x, n) | #Beta(n) => Beta.pdf(x, n)
@ -264,6 +286,7 @@ module T = {
| #Triangular(n) => Triangular.cdf(x, n) | #Triangular(n) => Triangular.cdf(x, n)
| #Exponential(n) => Exponential.cdf(x, n) | #Exponential(n) => Exponential.cdf(x, n)
| #Cauchy(n) => Cauchy.cdf(x, n) | #Cauchy(n) => Cauchy.cdf(x, n)
| #Gamma(n) => Gamma.cdf(x, n)
| #Lognormal(n) => Lognormal.cdf(x, n) | #Lognormal(n) => Lognormal.cdf(x, n)
| #Uniform(n) => Uniform.cdf(x, n) | #Uniform(n) => Uniform.cdf(x, n)
| #Beta(n) => Beta.cdf(x, n) | #Beta(n) => Beta.cdf(x, n)
@ -276,6 +299,7 @@ module T = {
| #Triangular(n) => Triangular.inv(x, n) | #Triangular(n) => Triangular.inv(x, n)
| #Exponential(n) => Exponential.inv(x, n) | #Exponential(n) => Exponential.inv(x, n)
| #Cauchy(n) => Cauchy.inv(x, n) | #Cauchy(n) => Cauchy.inv(x, n)
| #Gamma(n) => Gamma.inv(x, n)
| #Lognormal(n) => Lognormal.inv(x, n) | #Lognormal(n) => Lognormal.inv(x, n)
| #Uniform(n) => Uniform.inv(x, n) | #Uniform(n) => Uniform.inv(x, n)
| #Beta(n) => Beta.inv(x, n) | #Beta(n) => Beta.inv(x, n)
@ -288,6 +312,7 @@ module T = {
| #Triangular(n) => Triangular.sample(n) | #Triangular(n) => Triangular.sample(n)
| #Exponential(n) => Exponential.sample(n) | #Exponential(n) => Exponential.sample(n)
| #Cauchy(n) => Cauchy.sample(n) | #Cauchy(n) => Cauchy.sample(n)
| #Gamma(n) => Gamma.sample(n)
| #Lognormal(n) => Lognormal.sample(n) | #Lognormal(n) => Lognormal.sample(n)
| #Uniform(n) => Uniform.sample(n) | #Uniform(n) => Uniform.sample(n)
| #Beta(n) => Beta.sample(n) | #Beta(n) => Beta.sample(n)
@ -310,6 +335,7 @@ module T = {
| #Exponential(n) => Exponential.toString(n) | #Exponential(n) => Exponential.toString(n)
| #Cauchy(n) => Cauchy.toString(n) | #Cauchy(n) => Cauchy.toString(n)
| #Normal(n) => Normal.toString(n) | #Normal(n) => Normal.toString(n)
| #Gamma(n) => Gamma.toString(n)
| #Lognormal(n) => Lognormal.toString(n) | #Lognormal(n) => Lognormal.toString(n)
| #Uniform(n) => Uniform.toString(n) | #Uniform(n) => Uniform.toString(n)
| #Beta(n) => Beta.toString(n) | #Beta(n) => Beta.toString(n)
@ -323,6 +349,7 @@ module T = {
| #Cauchy(n) => Cauchy.inv(minCdfValue, n) | #Cauchy(n) => Cauchy.inv(minCdfValue, n)
| #Normal(n) => Normal.inv(minCdfValue, n) | #Normal(n) => Normal.inv(minCdfValue, n)
| #Lognormal(n) => Lognormal.inv(minCdfValue, n) | #Lognormal(n) => Lognormal.inv(minCdfValue, n)
| #Gamma(n) => Gamma.inv(minCdfValue, n)
| #Uniform({low}) => low | #Uniform({low}) => low
| #Beta(n) => Beta.inv(minCdfValue, n) | #Beta(n) => Beta.inv(minCdfValue, n)
| #Float(n) => n | #Float(n) => n
@ -334,6 +361,7 @@ module T = {
| #Exponential(n) => Exponential.inv(maxCdfValue, n) | #Exponential(n) => Exponential.inv(maxCdfValue, n)
| #Cauchy(n) => Cauchy.inv(maxCdfValue, n) | #Cauchy(n) => Cauchy.inv(maxCdfValue, n)
| #Normal(n) => Normal.inv(maxCdfValue, n) | #Normal(n) => Normal.inv(maxCdfValue, n)
| #Gamma(n) => Gamma.inv(maxCdfValue, n)
| #Lognormal(n) => Lognormal.inv(maxCdfValue, n) | #Lognormal(n) => Lognormal.inv(maxCdfValue, n)
| #Beta(n) => Beta.inv(maxCdfValue, n) | #Beta(n) => Beta.inv(maxCdfValue, n)
| #Uniform({high}) => high | #Uniform({high}) => high
@ -349,6 +377,7 @@ module T = {
| #Lognormal(n) => Lognormal.mean(n) | #Lognormal(n) => Lognormal.mean(n)
| #Beta(n) => Beta.mean(n) | #Beta(n) => Beta.mean(n)
| #Uniform(n) => Uniform.mean(n) | #Uniform(n) => Uniform.mean(n)
| #Gamma(n) => Gamma.mean(n)
| #Float(n) => Float.mean(n) | #Float(n) => Float.mean(n)
} }

View File

@ -31,6 +31,11 @@ type triangular = {
high: float, high: float,
} }
type gamma = {
shape: float,
scale: float,
}
@genType @genType
type symbolicDist = [ type symbolicDist = [
| #Normal(normal) | #Normal(normal)
@ -40,6 +45,7 @@ type symbolicDist = [
| #Exponential(exponential) | #Exponential(exponential)
| #Cauchy(cauchy) | #Cauchy(cauchy)
| #Triangular(triangular) | #Triangular(triangular)
| #Gamma(gamma)
| #Float(float) | #Float(float)
] ]

View File

@ -154,6 +154,7 @@ module SymbolicConstructors = {
| "beta" => Ok(SymbolicDist.Beta.make) | "beta" => Ok(SymbolicDist.Beta.make)
| "lognormal" => Ok(SymbolicDist.Lognormal.make) | "lognormal" => Ok(SymbolicDist.Lognormal.make)
| "cauchy" => Ok(SymbolicDist.Cauchy.make) | "cauchy" => Ok(SymbolicDist.Cauchy.make)
| "gamma" => Ok(SymbolicDist.Gamma.make)
| "to" => Ok(SymbolicDist.From90thPercentile.make) | "to" => Ok(SymbolicDist.From90thPercentile.make)
| _ => Error("Unreachable state") | _ => Error("Unreachable state")
} }
@ -185,7 +186,7 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment)
| ("delta", [EvNumber(f)]) => | ("delta", [EvNumber(f)]) =>
SymbolicDist.Float.makeSafe(f)->SymbolicConstructors.symbolicResultToOutput SymbolicDist.Float.makeSafe(f)->SymbolicConstructors.symbolicResultToOutput
| ( | (
("normal" | "uniform" | "beta" | "lognormal" | "cauchy" | "to") as fnName, ("normal" | "uniform" | "beta" | "lognormal" | "cauchy" | "gamma" | "to") as fnName,
[EvNumber(f1), EvNumber(f2)], [EvNumber(f1), EvNumber(f2)],
) => ) =>
SymbolicConstructors.twoFloat(fnName) SymbolicConstructors.twoFloat(fnName)

View File

@ -81,6 +81,14 @@ module Binomial = {
@module("jstat") @scope("binomial") external cdf: (float, float, float) => float = "cdf" @module("jstat") @scope("binomial") external cdf: (float, float, float) => float = "cdf"
} }
module Gamma = {
@module("jstat") @scope("gamma") external pdf: (float, float, float) => float = "pdf"
@module("jstat") @scope("gamma") external cdf: (float, float, float) => float = "cdf"
@module("jstat") @scope("gamma") external inv: (float, float, float) => float = "inv"
@module("jstat") @scope("gamma") external mean: (float, float) => float = "mean"
@module("jstat") @scope("gamma") external sample: (float, float) => float = "sample"
}
@module("jstat") external sum: array<float> => float = "sum" @module("jstat") external sum: array<float> => float = "sum"
@module("jstat") external product: array<float> => float = "product" @module("jstat") external product: array<float> => float = "product"
@module("jstat") external min: array<float> => float = "min" @module("jstat") external min: array<float> => float = "min"