Colors for loss testing

This commit is contained in:
Sam Nolan 2022-04-26 16:38:17 -04:00
parent 752f2a1ea5
commit 69ab296bb4
6 changed files with 78 additions and 63 deletions

View File

@ -1,29 +1,46 @@
import { distributions, generateFloat, generateFloatRange } from "./generators"; import { distributions, generateFloat, generateFloatRange } from "./generators";
import { test, expectEqual } from "./lib"; import { test, expectEqual } from "./lib";
let checkDistributionSame = (distribution: string, operation: (arg: string) => string): void => { let checkDistributionSame = (
expectEqual(operation(distribution), operation(`toPointSet(${distribution})`)); distribution: string,
expectEqual(operation(distribution), operation(`toSampleSet(${distribution})`)); operation: (arg: string) => string
} ): void => {
expectEqual(
operation(distribution),
operation(`toPointSet(${distribution})`)
);
expectEqual(
operation(distribution),
operation(`toSampleSet(${distribution})`)
);
};
Object.entries(distributions).map(([key, generator]) => { Object.entries(distributions).map(([key, generator]) => {
let distribution = generator(); let distribution = generator();
test(`mean is the same for ${key} distribution under all distribution types`, () => test(`mean is the same for ${key} distribution under all distribution types`, () =>
checkDistributionSame(distribution, (d: string) => `mean(${d})`) checkDistributionSame(distribution, (d: string) => `mean(${d})`));
)
test(`cdf is the same for ${key} distribution under all distribution types`, () => { test(`cdf is the same for ${key} distribution under all distribution types`, () => {
let cdf_value = generateFloat(); let cdf_value = generateFloat();
checkDistributionSame(distribution, (d: string) => `cdf(${d}, ${cdf_value})`) checkDistributionSame(
}) distribution,
(d: string) => `cdf(${d}, ${cdf_value})`
);
});
test(`pdf is the same for ${key} distribution under all distribution types`, () => { test(`pdf is the same for ${key} distribution under all distribution types`, () => {
let pdf_value = generateFloat(); let pdf_value = generateFloat();
checkDistributionSame(distribution, (d: string) => `pdf(${d}, ${pdf_value})`) checkDistributionSame(
}) distribution,
(d: string) => `pdf(${d}, ${pdf_value})`
);
});
test(`inv is the same for ${key} distribution under all distribution types`, () => { test(`inv is the same for ${key} distribution under all distribution types`, () => {
let inv_value = generateFloatRange(0, 1); let inv_value = generateFloatRange(0, 1);
checkDistributionSame(distribution, (d: string) => `inv(${d}, ${inv_value})`) checkDistributionSame(
}) distribution,
(d: string) => `inv(${d}, ${inv_value})`
);
});
}); });

View File

@ -1,8 +1,12 @@
export let generateFloat = (): number => Math.random() * 200 - 100; export let generateFloatRange = (min: number, max: number): number =>
export let generateFloatMin = (min:number): number => Math.random() * (100 - min) + min; Math.floor(Math.random() * (max - min) + min);
export let generateFloatRange = (min:number, max: number): number => Math.random() * (max - min) + min;
let generatePositive = (): number => Math.random() * 100; export let generateFloatMin = (min: number): number =>
generateFloatRange(min, 100);
export let generateFloat = (): number => generateFloatMin(-100);
let generatePositive = (): number => generateFloatMin(1);
export let generateNormal = (): string => export let generateNormal = (): string =>
`normal(${generateFloat()}, ${generatePositive()})`; `normal(${generateFloat()}, ${generatePositive()})`;
@ -17,20 +21,20 @@ export let generateExponential = (): string =>
`exponential(${generatePositive()})`; `exponential(${generatePositive()})`;
export let generateUniform = (): string => { export let generateUniform = (): string => {
let a = generateFloat() let a = generateFloat();
let b = generateFloatMin(a) let b = generateFloatMin(a + 1);
return `uniform(${a}, ${b})` return `uniform(${a}, ${b})`;
} };
export let generateCauchy = (): string => { export let generateCauchy = (): string => {
return `cauchy(${generateFloat()}, ${generatePositive()})` return `cauchy(${generateFloat()}, ${generatePositive()})`;
} };
export let generateTriangular = (): string => { export let generateTriangular = (): string => {
let a = generateFloat() let a = generateFloat();
let b = generateFloatMin(a) let b = generateFloatMin(a + 1);
let c = generateFloatMin(b) let c = generateFloatMin(b + 1);
return `triangular(${a}, ${b}, ${c})` return `triangular(${a}, ${b}, ${c})`;
} };
export let distributions: { [key: string]: () => string } = { export let distributions: { [key: string]: () => string } = {
normal: generateNormal, normal: generateNormal,
@ -39,6 +43,5 @@ export let distributions : {[key: string]: () => string} = {
exponential: generateExponential, exponential: generateExponential,
triangular: generateTriangular, triangular: generateTriangular,
cauchy: generateCauchy, cauchy: generateCauchy,
uniform: generateUniform uniform: generateUniform,
} };

View File

@ -1,5 +1,5 @@
import { run, squiggleExpression, errorValueToString } from "../src/js/index"; import { run, squiggleExpression, errorValueToString } from "../src/js/index";
import _ from "lodash"; import * as chalk from "chalk";
let testRun = (x: string): squiggleExpression => { let testRun = (x: string): squiggleExpression => {
let result = run(x, { sampleCount: 100, xyPointLength: 100 }); let result = run(x, { sampleCount: 100, xyPointLength: 100 });
@ -22,17 +22,21 @@ export function expectEqual(expression1: string, expression2: string) {
let result1 = testRun(expression1); let result1 = testRun(expression1);
let result2 = testRun(expression2); let result2 = testRun(expression2);
if (result1.tag === "number" && result2.tag === "number") { if (result1.tag === "number" && result2.tag === "number") {
let loss = getLoss(result1.value, result2.value); let logloss = Math.log(Math.abs(result1.value - result2.value));
console.log(`${expression1} === ${expression2}`); let isBadLogless = logloss > 1;
console.log(`${result1.value} === ${result2.value}`); console.log(chalk.blue(`${expression1} = ${expression2}`));
console.log(`loss = ${loss}`); console.log(`${result1.value} = ${result2.value}`);
console.log(`logloss = ${Math.abs(Math.log(result1.value) - Math.log(result2.value))}`); console.log(
console.log() `logloss = ${
} isBadLogless
else { ? chalk.red(logloss.toFixed(2))
throw Error(`Expected both to be number, but got ${result1.tag} and ${result2.tag}`) : chalk.green(logloss.toFixed(2))
}`
);
console.log();
} else {
throw Error(
`Expected both to be number, but got ${result1.tag} and ${result2.tag}`
);
} }
} }
let getLoss = (actual: number, expected: number): number =>
Math.abs(expected - actual);

View File

@ -36,6 +36,7 @@
"@types/jest": "^27.4.0", "@types/jest": "^27.4.0",
"babel-plugin-transform-es2015-modules-commonjs": "^6.26.2", "babel-plugin-transform-es2015-modules-commonjs": "^6.26.2",
"bisect_ppx": "^2.7.1", "bisect_ppx": "^2.7.1",
"chalk": "^4.1.2",
"codecov": "3.8.3", "codecov": "3.8.3",
"fast-check": "2.25.0", "fast-check": "2.25.0",
"gentype": "^4.3.0", "gentype": "^4.3.0",

View File

@ -4,7 +4,6 @@
"jsx": "react", "jsx": "react",
"allowJs": true, "allowJs": true,
"noImplicitAny": true, "noImplicitAny": true,
"esModuleInterop": true,
"removeComments": true, "removeComments": true,
"preserveConstEnums": true, "preserveConstEnums": true,
"sourceMap": true, "sourceMap": true,

View File

@ -4120,19 +4120,10 @@
dependencies: dependencies:
"@types/react" "*" "@types/react" "*"
"@types/react@*", "@types/react@^18.0.3": "@types/react@*", "@types/react@^16.9.19", "@types/react@^18.0.1", "@types/react@^18.0.3":
version "18.0.7" version "18.0.8"
resolved "https://registry.yarnpkg.com/@types/react/-/react-18.0.7.tgz#8437a226763adf854969954dfe582529a406cbad" resolved "https://registry.yarnpkg.com/@types/react/-/react-18.0.8.tgz#a051eb380a9fbcaa404550543c58e1cf5ce4ab87"
integrity sha512-CXSXHzTexlX9esf4ReIUJeaemKcmBEvYzxHDUk19c3BCcEGUvUjkeC3jkscPSfSaQ6SPDRNd/zMxi8oc/P1zxA== integrity sha512-+j2hk9BzCOrrOSJASi5XiOyBbERk9jG5O73Ya4M0env5Ixi6vUNli4qy994AINcEF+1IEHISYFfIT4zwr++LKw==
dependencies:
"@types/prop-types" "*"
"@types/scheduler" "*"
csstype "^3.0.2"
"@types/react@^16.9.19":
version "16.14.25"
resolved "https://registry.yarnpkg.com/@types/react/-/react-16.14.25.tgz#d003f712c7563fdef5a87327f1892825af375608"
integrity sha512-cXRVHd7vBT5v1is72mmvmsg9stZrbJO04DJqFeh3Yj2tVKO6vmxg5BI+ybI6Ls7ROXRG3aFbZj9x0WA3ZAoDQw==
dependencies: dependencies:
"@types/prop-types" "*" "@types/prop-types" "*"
"@types/scheduler" "*" "@types/scheduler" "*"