Add metalog

This commit is contained in:
Sam Nolan 2022-07-19 17:05:15 +10:00
parent 2c903a335e
commit 2adaba7d91
4 changed files with 109 additions and 1 deletions

View File

@ -31,6 +31,7 @@ describe("eval on distribution functions", () => {
testEval("mean(normal(5,2))", "Ok(5)") testEval("mean(normal(5,2))", "Ok(5)")
testEval("mean(lognormal(1,2))", "Ok(20.085536923187668)") testEval("mean(lognormal(1,2))", "Ok(20.085536923187668)")
testEval("mean(gamma(5,5))", "Ok(25)") testEval("mean(gamma(5,5))", "Ok(25)")
testEval("mean(metalog([1, 2]))", "Ok(1.0000000000000024)")
testEval("mean(bernoulli(0.2))", "Ok(0.2)") testEval("mean(bernoulli(0.2))", "Ok(0.2)")
testEval("mean(bernoulli(0.8))", "Ok(0.8)") testEval("mean(bernoulli(0.8))", "Ok(0.8)")
testEval("mean(logistic(5,1))", "Ok(5)") testEval("mean(logistic(5,1))", "Ok(5)")

View File

@ -254,6 +254,92 @@ module Logistic = {
let toString = ({location, scale}: t) => j`Logistic($location,$scale)` let toString = ({location, scale}: t) => j`Logistic($location,$scale)`
} }
// Metalog distribution, parameterised by the coefficient array `terms`
// (a_1 .. a_n). It is defined through its quantile function; the CDF and PDF
// are obtained numerically from that quantile function.
module Metalog = {
  type t = metalog

  // Validating constructor: a metalog needs at least two coefficients.
  // Returns Ok(#Metalog(...)) or Error with a message.
  let make = terms =>
    Js.Array.length(terms) > 1
      ? Ok(#Metalog({terms: terms}))
      : Error("Metalog must have 2 or more terms")

  // Quantile function M(p): sum over the coefficients of
  //   a_1
  //   a_2 * L
  //   a_3 * (p - 0.5) * L
  //   a_4 * (p - 0.5)
  //   a_k * (p - 0.5)^((k - 1) / 2)       for odd  k >= 5
  //   a_k * (p - 0.5)^(k / 2 - 1) * L     for even k >= 6
  // where L = ln(p / (1 - p)) and k = i + 1 is the 1-based coefficient index.
  let inv = (p, t: t) => {
    let logy = log(p /. (1. -. p))
    E.A.Floats.sum(Js.Array.mapi((term, i) => {
        if i == 0 {
          term
        } else if i == 1 {
          term *. logy
        } else if i == 2 {
          term *. (p -. 0.5) *. logy
        } else if i == 3 {
          term *. (p -. 0.5)
        } else if mod(i, 2) == 0 {
          term *. (p -. 0.5) ** (Belt.Int.toFloat(i) /. 2.)
        } else {
          term *. (p -. 0.5) ** ((Belt.Int.toFloat(i) -. 1.) /. 2.) *. logy
        }
      }, t.terms))
  }

  // Derivative dM/dy of the quantile function `inv`, taken term by term.
  // The PDF at x = M(y) is the reciprocal of this value.
  let invderiv = (y: float, t: t): float => E.A.Floats.sum(Js.Array.mapi((term, i) => {
        // k is the 1-based coefficient index.
        let k = Belt.Int.toFloat(i) +. 1.
        if i == 0 {
          // d/dy a_1 = 0
          0.
        } else if i == 1 {
          // d/dy ln(y / (1 - y)) = 1 / (y * (1 - y))
          term /. ((1. -. y) *. y)
        } else if i == 2 {
          // Product rule on (y - 0.5) * ln(y / (1 - y)):
          //   (y - 0.5) / (y * (1 - y)) + ln(y / (1 - y))
          // The previous expression, term / (1 - y)^2, was not this derivative
          // and gave wrong densities for metalogs with three or more terms.
          term *. ((y -. 0.5) /. (y *. (1. -. y)) +. log(y /. (1. -. y)))
        } else if i == 3 {
          // d/dy (y - 0.5) = 1
          term
        } else if mod(i, 2) == 0 {
          // Odd k >= 5: pure power term.
          term *. ((k -. 1.) /. 2.) *. (y -. 0.5) ** ((k -. 3.) /. 2.)
        } else {
          // Even k >= 6: product rule on (y - 0.5)^(k / 2 - 1) * ln(y / (1 - y)).
          term *.
          ((y -. 0.5) ** (k /. 2. -. 1.) /. (y *. (1. -. y)) +.
            (k /. 2. -. 1.) *. (y -. 0.5) ** (k /. 2. -. 2.) *. log(y /. (1. -. y)))
        }
      }, t.terms))

  // Newton-Raphson refinement of a probability `guess` so that
  // inv(guess, t) approaches `target`; performs exactly `roundsLeft` updates.
  // Not called by `cdf` below, which uses plain bisection.
  let rec improveCdfGuess = (guess: float, roundsLeft: int, target: float, t: t) =>
    if roundsLeft == 0 {
      guess
    } else {
      improveCdfGuess(
        guess -. (inv(guess, t) -. target) /. invderiv(guess, t),
        roundsLeft - 1,
        target,
        t,
      )
    }

  // CDF by bisection on the quantile function: start at p = 0.5 and move by
  // 0.25, 0.125, ... towards the p for which inv(p, t) == x. The result always
  // stays strictly inside (0, 1).
  let cdf = (x: float, t: t): float => {
    let guess = ref(0.5)
    let bisection_rounds = 50
    for i in 0 to bisection_rounds - 1 {
      let correction = 0.5 ** Belt.Int.toFloat(i + 2)
      if inv(guess.contents, t) < x {
        guess := guess.contents +. correction
      } else {
        guess := guess.contents -. correction
      }
    }
    guess.contents
  }

  // Density at x: reciprocal of the quantile derivative at y = cdf(x).
  let pdf = (x: float, t: t): float => 1. /. invderiv(cdf(x, t), t)

  // Inverse-transform sampling: push a Uniform(0, 1) draw through `inv`.
  let sample = (t: t) => inv(Jstat.Uniform.sample(0., 1.), t)

  let toString = ({terms}: t) =>
    j`Metalog([${Js.Array.joinWith(", ", Js.Array.map(Belt.Float.toString, terms))}])`

  // Number of quantile levels used by the numerical mean below.
  let meanSteps = 1000

  // Numerical mean: sums inv(p) / meanSteps over the values produced by
  // E.A.Floats.range between 1/meanSteps and 1 - 1/meanSteps.
  // NOTE(review): presumably `range` yields `meanSteps` evenly spaced points
  // including both endpoints; confirm against E.A.Floats.
  let mean = (t: t) => {
    let stepCountFloat = Belt.Int.toFloat(meanSteps)
    let range = E.A.Floats.range(0. +. 1. /. stepCountFloat, 1. -. 1. /. stepCountFloat, meanSteps)
    Ok(E.A.Floats.sum(E.A.fmap(x => inv(x, t) /. stepCountFloat, range)))
  }
}
module Bernoulli = { module Bernoulli = {
type t = bernoulli type t = bernoulli
let make = p => let make = p =>
@ -347,6 +433,7 @@ module T = {
| #Beta(n) => Beta.pdf(x, n) | #Beta(n) => Beta.pdf(x, n)
| #Float(n) => Float.pdf(x, n) | #Float(n) => Float.pdf(x, n)
| #Bernoulli(n) => Bernoulli.pdf(x, n) | #Bernoulli(n) => Bernoulli.pdf(x, n)
| #Metalog(n) => Metalog.pdf(x, n)
} }
let cdf = (x, dist) => let cdf = (x, dist) =>
@ -362,6 +449,7 @@ module T = {
| #Beta(n) => Beta.cdf(x, n) | #Beta(n) => Beta.cdf(x, n)
| #Float(n) => Float.cdf(x, n) | #Float(n) => Float.cdf(x, n)
| #Bernoulli(n) => Bernoulli.cdf(x, n) | #Bernoulli(n) => Bernoulli.cdf(x, n)
| #Metalog(n) => Metalog.cdf(x, n)
} }
let inv = (x, dist) => let inv = (x, dist) =>
@ -377,6 +465,7 @@ module T = {
| #Beta(n) => Beta.inv(x, n) | #Beta(n) => Beta.inv(x, n)
| #Float(n) => Float.inv(x, n) | #Float(n) => Float.inv(x, n)
| #Bernoulli(n) => Bernoulli.inv(x, n) | #Bernoulli(n) => Bernoulli.inv(x, n)
| #Metalog(n) => Metalog.inv(x, n)
} }
let sample: symbolicDist => float = x => let sample: symbolicDist => float = x =>
@ -392,6 +481,7 @@ module T = {
| #Beta(n) => Beta.sample(n) | #Beta(n) => Beta.sample(n)
| #Float(n) => Float.sample(n) | #Float(n) => Float.sample(n)
| #Bernoulli(n) => Bernoulli.sample(n) | #Bernoulli(n) => Bernoulli.sample(n)
| #Metalog(n) => Metalog.sample(n)
} }
let doN = (n, fn) => { let doN = (n, fn) => {
@ -417,6 +507,7 @@ module T = {
| #Beta(n) => Beta.toString(n) | #Beta(n) => Beta.toString(n)
| #Float(n) => Float.toString(n) | #Float(n) => Float.toString(n)
| #Bernoulli(n) => Bernoulli.toString(n) | #Bernoulli(n) => Bernoulli.toString(n)
| #Metalog(n) => Metalog.toString(n)
} }
let min: symbolicDist => float = x => let min: symbolicDist => float = x =>
@ -431,6 +522,7 @@ module T = {
| #Uniform({low}) => low | #Uniform({low}) => low
| #Bernoulli(n) => Bernoulli.min(n) | #Bernoulli(n) => Bernoulli.min(n)
| #Beta(n) => Beta.inv(minCdfValue, n) | #Beta(n) => Beta.inv(minCdfValue, n)
| #Metalog(n) => Metalog.inv(minCdfValue, n)
| #Float(n) => n | #Float(n) => n
} }
@ -445,6 +537,7 @@ module T = {
| #Logistic(n) => Logistic.inv(maxCdfValue, n) | #Logistic(n) => Logistic.inv(maxCdfValue, n)
| #Beta(n) => Beta.inv(maxCdfValue, n) | #Beta(n) => Beta.inv(maxCdfValue, n)
| #Bernoulli(n) => Bernoulli.max(n) | #Bernoulli(n) => Bernoulli.max(n)
| #Metalog(n) => Metalog.inv(maxCdfValue, n)
| #Uniform({high}) => high | #Uniform({high}) => high
| #Float(n) => n | #Float(n) => n
} }
@ -461,6 +554,7 @@ module T = {
| #Uniform(n) => Uniform.mean(n) | #Uniform(n) => Uniform.mean(n)
| #Gamma(n) => Gamma.mean(n) | #Gamma(n) => Gamma.mean(n)
| #Bernoulli(n) => Bernoulli.mean(n) | #Bernoulli(n) => Bernoulli.mean(n)
| #Metalog(n) => Metalog.mean(n)
| #Float(n) => Float.mean(n) | #Float(n) => Float.mean(n)
} }

View File

@ -41,6 +41,8 @@ type logistic = {
scale: float, scale: float,
} }
type metalog = {terms: array<float>}
type bernoulli = {p: float} type bernoulli = {p: float}
@genType @genType
@ -56,6 +58,7 @@ type symbolicDist = [
| #Float(float) | #Float(float)
| #Bernoulli(bernoulli) | #Bernoulli(bernoulli)
| #Logistic(logistic) | #Logistic(logistic)
| #Metalog(metalog)
] ]
type analyticalSimplificationResult = [ type analyticalSimplificationResult = [

View File

@ -139,10 +139,20 @@ beta({mean: 0.39, stdev: 0.1})`,
), ),
Function.make( Function.make(
~name="Logistic", ~name="Logistic",
~examples=`gamma(5, 1)`, ~examples=`logistic(5, 1)`,
~definitions=[TwoArgDist.make("logistic", twoArgs(SymbolicDist.Logistic.make))], ~definitions=[TwoArgDist.make("logistic", twoArgs(SymbolicDist.Logistic.make))],
(), (),
), ),
Function.make(
~name="Metalog",
~examples=`metalog([1, 2, 3])`,
~definitions=[
ArrayNumberDist.make("metalog", x =>
SymbolicDist.Metalog.make(x)->E.R2.fmap(x => Wrappers.symbolic(x)->Wrappers.evDistribution)
),
],
(),
),
Function.make( Function.make(
~name="To (Distribution)", ~name="To (Distribution)",
~examples=`5 to 10 ~examples=`5 to 10