squiggle/packages/squiggle-lang/__tests__/Distributions/Invariants/Means_test.res

157 lines
4.6 KiB
Plaintext
Raw Normal View History

2022-04-13 15:54:37 +00:00
/*
This is the most basic file in our invariants family of tests.
2022-04-13 23:17:49 +00:00
Validate that the sum of the means equals the mean of the sum, and similarly for subtraction and multiplication.
2022-04-13 15:54:37 +00:00
Details in https://develop--squiggle-documentation.netlify.app/docs/internal/invariants/
2022-04-13 23:17:49 +00:00
Note: the large tolerance used below (epsilon = 5e1) reflects that these invariants are, in general, not yet satisfied exactly.
2022-04-13 15:54:37 +00:00
*/
2022-04-12 21:06:53 +00:00
open Jest
open Expect
open TestHelpers
2022-04-13 23:17:49 +00:00
module Internals = {
  // Absolute tolerance allowed between the mean of a combined distribution
  // and the same operation applied to the individual means.
  let epsilon = 5e1

  // Wraps a distribution in a mean request; the request is evaluated with `run`.
  let mean = GenericDist_Types.Constructors.UsingDists.mean

  // Deliberately failing assertion: "<op> has" can never equal "failed", so
  // Jest's failure output shows which operation produced no float mean.
  let expectImpossiblePath: string => assertion = algebraicOp =>
    expect(`${algebraicOp} has`)->toEqual("failed")

  // Sample distributions used by every invariant test. Each entry is a
  // `result` (note the explicit `Ok` around `floatMake`).
  // NOTE(review): cauchy is presumably excluded because its mean is undefined — confirm.
  let distributions = list{
    normalMake(4e0, 1e0),
    betaMake(2e0, 4e0),
    exponentialMake(1.234e0),
    uniformMake(7e0, 1e1),
    // cauchyMake(1e0, 1e0),
    lognormalMake(2e0, 1e0),
    triangularMake(1e0, 1e1, 5e1),
    Ok(floatMake(1e1)),
  }

  // Pairs drawn from `distributions` by `E.L.combinations2`
  // (presumably each unordered pair of distinct entries once).
  let pairsOfDifferentDistributions = distributions->E.L.combinations2

  // Evaluates the mean of `dist` as a float; `E.O2.toExn` fails with the
  // given message if the evaluation does not produce a float.
  let runMean: DistributionTypes.genericDist => float = dist => {
    let meanOutput = dist->mean->run
    meanOutput->toFloat->E.O2.toExn("Shouldn't see this because we trust testcase input")
  }

  // Checks that the mean of `distOp(a, b)` is within `epsilon` of
  // `floatOp(mean(a), mean(b))`.
  //
  // - `description` names the operation in the failure message when the
  //   combined distribution yields no float mean.
  // - An `Error` from `distOp` is not handled here; it is unwrapped with
  //   `E.R.toExn`, which presumably raises.
  // - The two symbolic distributions come last (before `~epsilon`) so callers
  //   can partially apply the operation and get a two-argument test function.
  let testOperationMean = (
    distOp: (
      DistributionTypes.genericDist,
      DistributionTypes.genericDist,
    ) => result<DistributionTypes.genericDist, DistributionTypes.error>,
    description: string,
    floatOp: (float, float) => float,
    symbolic1: SymbolicDistTypes.symbolicDist,
    symbolic2: SymbolicDistTypes.symbolicDist,
    ~epsilon: float,
  ) => {
    let first = DistributionTypes.Symbolic(symbolic1)
    let second = DistributionTypes.Symbolic(symbolic2)
    let received =
      distOp(first, second)
      ->E.R2.fmap(combined => combined->mean->run->toFloat)
      ->E.R.toExn
    let expected = floatOp(runMean(first), runMean(second))
    switch received {
    | Some(actual) => expectErrorToBeBounded(actual, expected, ~epsilon)
    | None => expectImpossiblePath(description)
    }
  }
}
2022-04-13 23:17:49 +00:00
// Binary algebraic operations on generic distributions, with the test
// environment `env` already applied so each takes just the two operands.
let algebraicAdd = DistributionOperation.Constructors.algebraicAdd(~env)
let algebraicMultiply = DistributionOperation.Constructors.algebraicMultiply(~env)
let algebraicDivide = DistributionOperation.Constructors.algebraicDivide(~env)
let algebraicSubtract = DistributionOperation.Constructors.algebraicSubtract(~env)
let algebraicLogarithm = DistributionOperation.Constructors.algebraicLogarithm(~env)
let algebraicPower = DistributionOperation.Constructors.algebraicPower(~env)

// Helpers from `Internals` that the test suite below uses unqualified.
let testOperationMean = Internals.testOperationMean
let distributions = Internals.distributions
let pairsOfDifferentDistributions = Internals.pairsOfDifferentDistributions
let epsilon = Internals.epsilon
2022-04-14 00:58:16 +00:00
describe("Means are invariant", () => {
  // Registers the three pairing checks shared by every operation, in order:
  //   1. each distribution combined with itself,
  //   2. each pair of different distributions,
  //   3. the same pairs with the operands swapped.
  // `testMean` is `testOperationMean` partially applied to one operation.
  // The sample distributions are `result`s, so `E.R.liftM2` applies
  // `testMean` to the unwrapped values. `E.R.toExn` then unwraps the outcome
  // and presumably raises if a sample failed to construct.
  let testAllPairings = testMean => {
    testAll("with two of the same distribution", distributions, dist =>
      E.R.liftM2(testMean, dist, dist)->E.R.toExn
    )
    testAll("with two different distributions", pairsOfDifferentDistributions, ((dist1, dist2)) =>
      E.R.liftM2(testMean, dist1, dist2)->E.R.toExn
    )
    // Addition and multiplication commute, but subtraction does not, so the
    // swapped order is a separate case worth checking.
    testAll(
      "with two different distributions in swapped order",
      pairsOfDifferentDistributions,
      ((dist1, dist2)) => E.R.liftM2(testMean, dist2, dist1)->E.R.toExn,
    )
  }

  describe("for addition", () =>
    testAllPairings(testOperationMean(algebraicAdd, "algebraicAdd", \"+.", ~epsilon))
  )
  describe("for subtraction", () =>
    testAllPairings(testOperationMean(algebraicSubtract, "algebraicSubtract", \"-.", ~epsilon))
  )
  describe("for multiplication", () =>
    testAllPairings(testOperationMean(algebraicMultiply, "algebraicMultiply", \"*.", ~epsilon))
  )
})