diff --git a/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res b/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res index ecc07bfa..aa878a66 100644 --- a/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res +++ b/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res @@ -31,6 +31,7 @@ describe("eval on distribution functions", () => { testEval("mean(normal(5,2))", "Ok(5)") testEval("mean(lognormal(1,2))", "Ok(20.085536923187668)") testEval("mean(gamma(5,5))", "Ok(25)") + testEval("mean(metalog([1, 2]))", "Ok(1.0000000000000024)") testEval("mean(bernoulli(0.2))", "Ok(0.2)") testEval("mean(bernoulli(0.8))", "Ok(0.8)") testEval("mean(logistic(5,1))", "Ok(5)") diff --git a/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res b/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res index 249deb02..39f23b0a 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res @@ -254,6 +254,92 @@ module Logistic = { let toString = ({location, scale}: t) => j`Logistic($location,$scale)` } +module Metalog = { + type t = metalog + let make = terms => + Js.Array.length(terms) > 1 + ? Ok(#Metalog({terms: terms})) + : Error("Metalog must have 2 or more terms") + + let inv = (p, t: t) => { + let logy = log(p /. (1. -. p)) + E.A.Floats.sum(Js.Array.mapi((term, i) => + if i == 0 { + term + } else if i == 1 { + term *. logy + } else if i == 2 { + term *. (p -. 0.5) *. logy + } else if i == 3 { + term *. (p -. 0.5) + } else if mod(i, 2) == 0 { + term *. (p -. 0.5) ** (Belt.Int.toFloat(i) /. 2.) + } else { + term *. (p -. 0.5) ** ((Belt.Int.toFloat(i) -. 1.) /. 2.) *. logy + } + , t.terms)) + } + + let invderiv = (y: float, t: t): float => E.A.Floats.sum(Js.Array.mapi((term, i) => { + let k = Belt.Int.toFloat(i) +. 1. + if i == 0 { + 0. + } else if i == 1 { + term /. ((1. -. y) *. y) + } else if i == 2 { + term /. (1. -. y) ** 2. + } else if i == 3 { + term + } else if mod(i, 2) == 0 { + term *. ((k -. 1.) /. 2.) *. (y -. 0.5) ** ((k -. 3.) /. 2.) + } else { + term *. + ((y -. 0.5) ** (k /. 2. -. 1.) /. (y *. (1. -. y)) +. + (k /. 2. -. 1.) *. (y -. 0.5) ** (k /. 2. -. 2.) *. log(y /. (1. -. y))) + } + }, t.terms)) + + let rec improveCdfGuess = (guess: float, roundsLeft: int, target: float, t: t) => + if roundsLeft == 0 { + guess + } else { + improveCdfGuess( + guess -. (inv(guess, t) -. target) /. invderiv(guess, t), + roundsLeft - 1, + target, + t, + ) + } + + let cdf = (x: float, t: t): float => { + let guess = ref(0.5) + let bisection_rounds = 50 + for i in 0 to bisection_rounds - 1 { + let correction = 0.5 ** Belt.Int.toFloat(i + 2) + if inv(guess.contents, t) < x { + guess := guess.contents +. correction + } else { + guess := guess.contents -. correction + } + } + guess.contents + } + + let pdf = (x: float, t: t): float => 1. /. invderiv(cdf(x, t), t) + + let sample = (t: t) => inv(Jstat.Uniform.sample(0., 1.), t) + + let toString = ({terms}: t) => + j`Metalog([${Js.Array.joinWith(", ", Js.Array.map(Belt.Float.toString, terms))}])` + + let meanSteps = 1000 + let mean = (t: t) => { + let stepCountFloat = Belt.Int.toFloat(meanSteps) + let range = E.A.Floats.range(0. +. 1. /. stepCountFloat, 1. -. 1. /. stepCountFloat, meanSteps) + Ok(E.A.Floats.sum(E.A.fmap(x => inv(x, t) /. stepCountFloat, range))) + } +} + module Bernoulli = { type t = bernoulli let make = p => @@ -347,6 +433,7 @@ module T = { | #Beta(n) => Beta.pdf(x, n) | #Float(n) => Float.pdf(x, n) | #Bernoulli(n) => Bernoulli.pdf(x, n) + | #Metalog(n) => Metalog.pdf(x, n) } let cdf = (x, dist) => @@ -362,6 +449,7 @@ module T = { | #Beta(n) => Beta.cdf(x, n) | #Float(n) => Float.cdf(x, n) | #Bernoulli(n) => Bernoulli.cdf(x, n) + | #Metalog(n) => Metalog.cdf(x, n) } let inv = (x, dist) => @@ -377,6 +465,7 @@ module T = { | #Beta(n) => Beta.inv(x, n) | #Float(n) => Float.inv(x, n) | #Bernoulli(n) => Bernoulli.inv(x, n) + | #Metalog(n) => Metalog.inv(x, n) } let sample: symbolicDist => float = x => @@ -392,6 +481,7 @@ module T = { | #Beta(n) => Beta.sample(n) | #Float(n) => Float.sample(n) | #Bernoulli(n) => Bernoulli.sample(n) + | #Metalog(n) => Metalog.sample(n) } let doN = (n, fn) => { @@ -417,6 +507,7 @@ module T = { | #Beta(n) => Beta.toString(n) | #Float(n) => Float.toString(n) | #Bernoulli(n) => Bernoulli.toString(n) + | #Metalog(n) => Metalog.toString(n) } let min: symbolicDist => float = x => @@ -431,6 +522,7 @@ module T = { | #Uniform({low}) => low | #Bernoulli(n) => Bernoulli.min(n) | #Beta(n) => Beta.inv(minCdfValue, n) + | #Metalog(n) => Metalog.inv(minCdfValue, n) | #Float(n) => n } @@ -445,6 +537,7 @@ module T = { | #Logistic(n) => Logistic.inv(maxCdfValue, n) | #Beta(n) => Beta.inv(maxCdfValue, n) | #Bernoulli(n) => Bernoulli.max(n) + | #Metalog(n) => Metalog.inv(maxCdfValue, n) | #Uniform({high}) => high | #Float(n) => n } @@ -461,6 +554,7 @@ module T = { | #Uniform(n) => Uniform.mean(n) | #Gamma(n) => Gamma.mean(n) | #Bernoulli(n) => Bernoulli.mean(n) + | #Metalog(n) => Metalog.mean(n) | #Float(n) => Float.mean(n) } diff --git a/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDistTypes.res b/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDistTypes.res index 8a4956c5..9ff37e88 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDistTypes.res +++ b/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDistTypes.res @@ -41,6 +41,8 @@ type logistic = { scale: float, } +type metalog = {terms: array} + type bernoulli = {p: float} @genType @@ -56,6 +58,7 @@ type symbolicDist = [ | #Float(float) | #Bernoulli(bernoulli) | #Logistic(logistic) + | #Metalog(metalog) ] type analyticalSimplificationResult = [ diff --git a/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Library.res b/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Library.res index add5cfa3..c6d96dc2 100644 --- a/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Library.res +++ b/packages/squiggle-lang/src/rescript/FunctionRegistry/FunctionRegistry_Library.res @@ -139,10 +139,20 @@ beta({mean: 0.39, stdev: 0.1})`, ), Function.make( ~name="Logistic", - ~examples=`gamma(5, 1)`, + ~examples=`logistic(5, 1)`, ~definitions=[TwoArgDist.make("logistic", twoArgs(SymbolicDist.Logistic.make))], (), ), + Function.make( + ~name="Metalog", + ~examples=`metalog([1, 2, 3])`, + ~definitions=[ + ArrayNumberDist.make("metalog", x => + SymbolicDist.Metalog.make(x)->E.R2.fmap(x => Wrappers.symbolic(x)->Wrappers.evDistribution) + ), + ], + (), + ), Function.make( ~name="To (Distribution)", ~examples=`5 to 10