diff --git a/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res b/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res index 253a9cc3..48bdcaa5 100644 --- a/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res +++ b/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res @@ -33,6 +33,7 @@ describe("eval on distribution functions", () => { testEval("mean(gamma(5,5))", "Ok(25)") testEval("mean(bernoulli(0.2))", "Ok(0.2)") testEval("mean(bernoulli(0.8))", "Ok(0.8)") + testEval("mean(logistic(5,1))", "Ok(5)") }) describe("toString", () => { testEval("toString(normal(5,2))", "Ok('Normal(5,2)')") diff --git a/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res b/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res index 6356f97c..d1fcb9ef 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDist.res @@ -216,6 +216,24 @@ module Uniform = { } } +module Logistic = { + type t = logistic + let make = (location, scale) => + scale > 0.0 + ? Ok(#Logistic({location: location, scale: scale})) + : Error("Scale must be positive") + + let pdf = (x, t: t) => Stdlib.Logistic.pdf(x, t.location, t.scale) + let cdf = (x, t: t) => Stdlib.Logistic.cdf(x, t.location, t.scale) + let inv = (p, t: t) => Stdlib.Logistic.quantile(p, t.location, t.scale) + let sample = (t: t) => { + let s = Uniform.sample({low: 0.0, high: 1.0}) + inv(s, t) + } + let mean = (t: t) => Ok(Stdlib.Logistic.mean(t.location, t.scale)) + let toString = ({location, scale}: t) => j`Logistic($location,$scale)` +} + module Bernoulli = { type t = bernoulli let make = p => @@ -304,6 +322,7 @@ module T = { | #Cauchy(n) => Cauchy.pdf(x, n) | #Gamma(n) => Gamma.pdf(x, n) | #Lognormal(n) => Lognormal.pdf(x, n) + | #Logistic(n) => Logistic.pdf(x, n) | #Uniform(n) => Uniform.pdf(x, n) | #Beta(n) => Beta.pdf(x, n) | #Float(n) => Float.pdf(x, n) @@ -317,6 +336,7 @@ module T = { | #Exponential(n) => Exponential.cdf(x, n) | #Cauchy(n) => Cauchy.cdf(x, n) | #Gamma(n) => Gamma.cdf(x, n) + | #Logistic(n) => Logistic.cdf(x, n) | #Lognormal(n) => Lognormal.cdf(x, n) | #Uniform(n) => Uniform.cdf(x, n) | #Beta(n) => Beta.cdf(x, n) @@ -331,6 +351,7 @@ module T = { | #Exponential(n) => Exponential.inv(x, n) | #Cauchy(n) => Cauchy.inv(x, n) | #Gamma(n) => Gamma.inv(x, n) + | #Logistic(n) => Logistic.inv(x, n) | #Lognormal(n) => Lognormal.inv(x, n) | #Uniform(n) => Uniform.inv(x, n) | #Beta(n) => Beta.inv(x, n) @@ -345,6 +366,7 @@ module T = { | #Exponential(n) => Exponential.sample(n) | #Cauchy(n) => Cauchy.sample(n) | #Gamma(n) => Gamma.sample(n) + | #Logistic(n) => Logistic.sample(n) | #Lognormal(n) => Lognormal.sample(n) | #Uniform(n) => Uniform.sample(n) | #Beta(n) => Beta.sample(n) @@ -369,6 +391,7 @@ module T = { | #Cauchy(n) => Cauchy.toString(n) | #Normal(n) => Normal.toString(n) | #Gamma(n) => Gamma.toString(n) + | #Logistic(n) => Logistic.toString(n) | #Lognormal(n) => Lognormal.toString(n) | #Uniform(n) => Uniform.toString(n) | #Beta(n) => Beta.toString(n) @@ -383,6 +406,7 @@ module T = { | #Cauchy(n) => Cauchy.inv(minCdfValue, n) | #Normal(n) => Normal.inv(minCdfValue, n) | #Lognormal(n) => Lognormal.inv(minCdfValue, n) + | #Logistic(n) => Logistic.inv(minCdfValue, n) | #Gamma(n) => Gamma.inv(minCdfValue, n) | #Uniform({low}) => low | #Bernoulli(n) => Bernoulli.min(n) @@ -398,6 +422,7 @@ module T = { | #Normal(n) => Normal.inv(maxCdfValue, n) | #Gamma(n) => Gamma.inv(maxCdfValue, n) | #Lognormal(n) => Lognormal.inv(maxCdfValue, n) + | #Logistic(n) => Logistic.inv(maxCdfValue, n) | #Beta(n) => Beta.inv(maxCdfValue, n) | #Bernoulli(n) => Bernoulli.max(n) | #Uniform({high}) => high @@ -412,6 +437,7 @@ module T = { | #Normal(n) => Normal.mean(n) | #Lognormal(n) => Lognormal.mean(n) | #Beta(n) => Beta.mean(n) + | #Logistic(n) => Logistic.mean(n) | #Uniform(n) => Uniform.mean(n) | #Gamma(n) => Gamma.mean(n) | #Bernoulli(n) => Bernoulli.mean(n) diff --git a/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDistTypes.res b/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDistTypes.res index e888fbd2..8a4956c5 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDistTypes.res +++ b/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDistTypes.res @@ -36,6 +36,11 @@ type gamma = { scale: float, } +type logistic = { + location: float, + scale: float, +} + type bernoulli = {p: float} @genType @@ -50,6 +55,7 @@ type symbolicDist = [ | #Gamma(gamma) | #Float(float) | #Bernoulli(bernoulli) + | #Logistic(logistic) ] type analyticalSimplificationResult = [ diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res index 8dffecb4..1f1291e9 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res @@ -178,6 +178,7 @@ module SymbolicConstructors = { | "uniform" => Ok(SymbolicDist.Uniform.make) | "beta" => Ok(SymbolicDist.Beta.make) | "lognormal" => Ok(SymbolicDist.Lognormal.make) + | "logistic" => Ok(SymbolicDist.Logistic.make) | "cauchy" => Ok(SymbolicDist.Cauchy.make) | "gamma" => Ok(SymbolicDist.Gamma.make) | "to" => Ok(SymbolicDist.From90thPercentile.make) @@ -212,7 +213,14 @@ let dispatchToGenericOutput = ( | ("delta", [EvNumber(f)]) => SymbolicDist.Float.makeSafe(f)->SymbolicConstructors.symbolicResultToOutput | ( - ("normal" | "uniform" | "beta" | "lognormal" | "cauchy" | "gamma" | "to") as fnName, + ("normal" + | "uniform" + | "beta" + | "lognormal" + | "cauchy" + | "gamma" + | "to" + | "logistic") as fnName, [EvNumber(f1), EvNumber(f2)], ) => SymbolicConstructors.twoFloat(fnName) diff --git a/packages/squiggle-lang/src/rescript/Utility/Stdlib.res b/packages/squiggle-lang/src/rescript/Utility/Stdlib.res index 3455e962..faa1cb1d 100644 --- a/packages/squiggle-lang/src/rescript/Utility/Stdlib.res +++ b/packages/squiggle-lang/src/rescript/Utility/Stdlib.res @@ -10,4 +10,31 @@ module Bernoulli = { @module external mean: float => float = "@stdlib/stats/base/dists/bernoulli/mean" let mean = mean + + @module external stdev: float => float = "@stdlib/stats/base/dists/bernoulli/stdev" + let stdev = stdev + + @module external variance: float => float = "@stdlib/stats/base/dists/bernoulli/variance" + let variance = variance +} + +module Logistic = { + @module external cdf: (float, float, float) => float = "@stdlib/stats/base/dists/logistic/cdf" + let cdf = cdf + + @module external pdf: (float, float, float) => float = "@stdlib/stats/base/dists/logistic/pdf" + let pdf = pdf + + @module + external quantile: (float, float, float) => float = "@stdlib/stats/base/dists/logistic/quantile" + let quantile = quantile + + @module external mean: (float, float) => float = "@stdlib/stats/base/dists/logistic/mean" + let mean = mean + + @module external stdev: (float, float) => float = "@stdlib/stats/base/dists/logistic/stdev" + let stdev = stdev + + @module external variance: (float, float) => float = "@stdlib/stats/base/dists/logistic/variance" + let variance = variance }