Merge pull request #1161 from quantified-uncertainty/sampleset-to-pointset-speedups

Sampleset to pointset speedups
This commit is contained in:
Vyacheslav Matyukhin 2022-09-21 02:51:26 +04:00 committed by GitHub
commit d9f4171943
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 152 additions and 128 deletions

View File

@ -9,22 +9,28 @@ let prepareInputs = (ar, minWeight) =>
describe("Continuous and discrete splits", () => {
makeTest(
"is empty, with no common elements",
prepareInputs([1.432, 1.33455, 2.0], 2),
prepareInputs([1.33455, 1.432, 2.0], 2),
([1.33455, 1.432, 2.0], []),
)
makeTest(
"only stores 3.5 as discrete when minWeight is 3",
prepareInputs([1.432, 1.33455, 2.0, 2.0, 3.5, 3.5, 3.5], 3),
prepareInputs([1.33455, 1.432, 2.0, 2.0, 3.5, 3.5, 3.5], 3),
([1.33455, 1.432, 2.0, 2.0], [(3.5, 3.0)]),
)
makeTest(
"doesn't store 3.5 as discrete when minWeight is 5",
prepareInputs([1.432, 1.33455, 2.0, 2.0, 3.5, 3.5, 3.5], 5),
prepareInputs([1.33455, 1.432, 2.0, 2.0, 3.5, 3.5, 3.5], 5),
([1.33455, 1.432, 2.0, 2.0, 3.5, 3.5, 3.5], []),
)
makeTest(
"more general test",
prepareInputs([10., 10., 11., 11., 11., 12., 13., 13., 13., 13., 13., 14.], 3),
([10., 10., 12., 14.], [(11., 3.), (13., 5.)]),
)
let makeDuplicatedArray = count => {
let arr = Belt.Array.range(1, count) |> E.A.fmap(float_of_int)
let sorted = arr |> Belt.SortArray.stableSortBy(_, compare)

View File

@ -1,15 +1,6 @@
#!/usr/bin/env node
import { SqProject } from "@quri/squiggle-lang";
const measure = (cb, times = 1) => {
const t1 = new Date();
for (let i = 1; i <= times; i++) {
cb();
}
const t2 = new Date();
return (t2 - t1) / 1000;
};
import { measure } from "./lib.mjs";
const maxP = 5;

View File

@ -1,15 +1,6 @@
#!/usr/bin/env node
import { SqProject } from "@quri/squiggle-lang";
const measure = (cb, times = 1) => {
const t1 = new Date();
for (let i = 1; i <= times; i++) {
cb();
}
const t2 = new Date();
return (t2 - t1) / 1000;
};
import { measure } from "./lib.mjs";
const maxP = 7;

View File

@ -0,0 +1,34 @@
#!/usr/bin/env node
import { SqProject } from "@quri/squiggle-lang";
import { measure } from "./lib.mjs";
const maxP = 3;
const sampleCount = process.env.SAMPLE_COUNT;
for (let p = 0; p <= maxP; p++) {
const size = Math.pow(10, p);
const project = SqProject.create();
if (sampleCount) {
project.setEnvironment({
sampleCount: Number(sampleCount),
xyPointLength: Number(sampleCount),
});
}
project.setSource(
"main",
`
List.upTo(1, ${size}) -> map(
{ |x| normal(x,2) -> SampleSet.fromDist -> PointSet.fromDist }
)->List.last
`
);
const time = measure(() => {
project.run("main");
});
const result = project.getResult("main");
if (result.tag != "Ok") {
throw new Error("Code failed: " + result.value.toString());
}
console.log(`1e${p}`, "\t", time);
}

View File

@ -0,0 +1,41 @@
import { SqProject } from "@quri/squiggle-lang";
export const measure = (cb, times = 1) => {
const t1 = new Date();
for (let i = 1; i <= times; i++) {
cb();
}
const t2 = new Date();
return (t2 - t1) / 1000;
};
export const red = (str) => `\x1b[31m${str}\x1b[0m`;
export const green = (str) => `\x1b[32m${str}\x1b[0m`;
export const run = (src, { output, sampleCount }) => {
const project = SqProject.create();
if (sampleCount) {
project.setEnvironment({
sampleCount: Number(sampleCount),
xyPointLength: Number(sampleCount),
});
}
project.setSource("main", src);
const time = measure(() => project.run("main"));
const bindings = project.getBindings("main");
const result = project.getResult("main");
if (output) {
console.log("Result:", result.tag, result.value.toString());
console.log("Bindings:", bindings.toString());
}
console.log(
"Time:",
String(time),
result.tag === "Error" ? red(result.tag) : green(result.tag),
result.tag === "Error" ? result.value.toString() : ""
);
};

View File

@ -1,21 +1,9 @@
#!/usr/bin/env node
import { SqProject } from "@quri/squiggle-lang";
import fs from "fs";
import { Command } from "commander";
const measure = (cb, times = 1) => {
const t1 = new Date();
for (let i = 1; i <= times; i++) {
cb();
}
const t2 = new Date();
return (t2 - t1) / 1000;
};
const red = (str) => `\x1b[31m${str}\x1b[0m`;
const green = (str) => `\x1b[32m${str}\x1b[0m`;
import { run } from "./lib.mjs";
const program = new Command();
@ -24,34 +12,11 @@ program.arguments("<string>");
const options = program.parse(process.argv);
const project = SqProject.create();
const sampleCount = process.env.SAMPLE_COUNT;
if (sampleCount) {
project.setEnvironment({
sampleCount: Number(sampleCount),
xyPointLength: Number(sampleCount),
});
}
const src = fs.readFileSync(program.args[0], "utf-8");
if (!src) {
throw new Error("Expected src");
}
project.setSource("main", src);
const time = measure(() => project.run("main"));
const bindings = project.getBindings("main");
const result = project.getResult("main");
if (options.output) {
console.log("Result:", result.tag, result.value.toString());
console.log("Bindings:", bindings.toString());
}
console.log(
"Time:",
String(time),
result.tag === "Error" ? red(result.tag) : green(result.tag),
result.tag === "Error" ? result.value.toString() : ""
);
run(src, { output: options.output, sampleCount });

View File

@ -1,18 +1,10 @@
#!/usr/bin/env node
import { SqProject } from "@quri/squiggle-lang";
const project = SqProject.create();
import { run } from "./lib.mjs";
const src = process.argv[2];
if (!src) {
throw new Error("Expected src");
}
console.log(`Running ${src}`);
project.setSource("a", src);
project.run("a");
const result = project.getResult("a");
console.log(result.tag, result.value.toString());
const bindings = project.getBindings("a");
console.log(bindings.asValue().toString());
run(src);

View File

@ -33,19 +33,19 @@ module Internals = {
module KDE = {
let normalSampling = (samples, outputXYPoints, kernelWidth) =>
samples |> JS.samplesToContinuousPdf(_, outputXYPoints, kernelWidth) |> JS.jsToDist
samples->JS.samplesToContinuousPdf(outputXYPoints, kernelWidth)->JS.jsToDist
}
module T = {
type t = array<float>
let xWidthToUnitWidth = (samples, outputXYPoints, xWidth) => {
let xyPointRange = E.A.Sorted.range(samples) |> E.O.default(0.0)
let xyPointRange = E.A.Sorted.range(samples)->E.O2.default(0.0)
let xyPointWidth = xyPointRange /. float_of_int(outputXYPoints)
xWidth /. xyPointWidth
}
let formatUnitWidth = w => Jstat.max([w, 1.0]) |> int_of_float
let formatUnitWidth = w => Jstat.max([w, 1.0])->int_of_float
let suggestedUnitWidth = (samples, outputXYPoints) => {
let suggestedXWidth = SampleSetDist_Bandwidth.nrd0(samples)
@ -62,23 +62,24 @@ let toPointSetDist = (
~samplingInputs: SamplingInputs.samplingInputs,
(),
): Internals.Types.outputs => {
let samples = Js.Array2.copy(samples)
Array.fast_sort(compare, samples)
let samples = samples->Js.Array2.copy->Js.Array2.sortInPlaceWith(compare)
let minDiscreteToKeep = MagicNumbers.ToPointSet.minDiscreteToKeep(samples)
let (continuousPart, discretePart) = E.A.Floats.Sorted.splitContinuousAndDiscreteForMinWeight(
samples,
~minDiscreteWeight=minDiscreteToKeep,
)
let length = samples |> E.A.length |> float_of_int
let length = samples->E.A.length->float_of_int
let discrete: PointSetTypes.discreteShape =
discretePart
|> E.FloatFloatMap.fmap(r => r /. length)
|> E.FloatFloatMap.toArray
|> XYShape.T.fromZippedArray
|> Discrete.make
->E.FloatFloatMap.fmap(r => r /. length, _)
->E.FloatFloatMap.toArray
->XYShape.T.fromZippedArray
->Discrete.make
let pdf =
continuousPart |> E.A.length > 5
continuousPart->E.A.length > 5
? {
let _suggestedXWidth = SampleSetDist_Bandwidth.nrd0(continuousPart)
// todo: This does some recalculating from the last step.
@ -86,7 +87,7 @@ let toPointSetDist = (
continuousPart,
samplingInputs.outputXYPoints,
)
let usedWidth = samplingInputs.kernelWidth |> E.O.default(_suggestedXWidth)
let usedWidth = samplingInputs.kernelWidth->E.O2.default(_suggestedXWidth)
let usedUnitWidth = Internals.T.xWidthToUnitWidth(
samples,
samplingInputs.outputXYPoints,
@ -101,18 +102,18 @@ let toPointSetDist = (
bandwidthUnitImplemented: usedUnitWidth,
}
continuousPart
|> Internals.T.kde(
->Internals.T.kde(
~samples=_,
~outputXYPoints=samplingInputs.outputXYPoints,
Internals.T.formatUnitWidth(usedUnitWidth),
)
|> Continuous.make
|> (r => Some((r, samplingStats)))
->Continuous.make
->(r => Some((r, samplingStats)))
}
: None
let pointSetDist = MixedShapeBuilder.buildSimple(
~continuous=pdf |> E.O.fmap(fst),
~continuous=pdf->E.O2.fmap(fst),
~discrete=Some(discrete),
)
@ -125,7 +126,7 @@ let toPointSetDist = (
let normalizedPointSet = pointSetDist->E.O2.fmap(PointSetDist.T.normalize)
let samplesParse: Internals.Types.outputs = {
continuousParseParams: pdf |> E.O.fmap(snd),
continuousParseParams: pdf->E.O2.fmap(snd),
pointSetDist: normalizedPointSet,
}

View File

@ -305,55 +305,50 @@ module Floats = {
/*
This function goes through a sorted array and divides it into two different clusters:
continuous samples and discrete samples. The discrete samples are stored in a mutable map.
Samples are thought to be discrete if they have any duplicates.
*/
let _splitContinuousAndDiscreteForDuplicates = (sortedArray: array<float>) => {
let continuous: array<float> = []
let discrete = FloatFloatMap.empty()
Belt.Array.forEachWithIndex(sortedArray, (index, element) => {
let maxIndex = (sortedArray |> Array.length) - 1
let possiblySimilarElements = switch index {
| 0 => [index + 1]
| n if n == maxIndex => [index - 1]
| _ => [index - 1, index + 1]
} |> Belt.Array.map(_, r => sortedArray[r])
let hasSimilarElement = Belt.Array.some(possiblySimilarElements, r => r == element)
hasSimilarElement
? FloatFloatMap.increment(element, discrete)
: {
let _ = Js.Array.push(element, continuous)
}
Samples are thought to be discrete if they have at least `minDiscreteWight` duplicates.
()
})
If the min discrete weight is 4, that would mean that at least four elements needed from a specific
value for that to be kept as discrete. This is important because in some cases, we can expect that
some common elements will be generated by regular operations. The final continuous array will be sorted.
(continuous, discrete)
}
/*
This function works very similarly to splitContinuousAndDiscreteForDuplicates. The one major difference
is that you can specify a minDiscreteWeight. If the min discreet weight is 4, that would mean that
at least four elements needed from a specific value for that to be kept as discrete. This is important
because in some cases, we can expect that some common elements will be generated by regular operations.
The final continous array will be sorted.
This function is performance-critical, don't change it significantly without benchmarking
SampleSet->PointSet conversion performance.
*/
let splitContinuousAndDiscreteForMinWeight = (
sortedArray: array<float>,
~minDiscreteWeight: int,
) => {
let (continuous, discrete) = _splitContinuousAndDiscreteForDuplicates(sortedArray)
let keepFn = v => Belt.Float.toInt(v) >= minDiscreteWeight
let (discreteToKeep, discreteToIntegrate) = FloatFloatMap.partition(
((_, v)) => keepFn(v),
discrete,
let continuous: array<float> = []
let discrete = FloatFloatMap.empty()
let addData = (count: int, value: float): unit => {
if count >= minDiscreteWeight {
FloatFloatMap.add(value, count->Belt.Int.toFloat, discrete)
} else {
for _ in 1 to count {
continuous->Js.Array2.push(value)->ignore
}
}
}
let (finalCount, finalValue) = sortedArray->Belt.Array.reduce(
// initial prev value doesn't matter; if it collides with the first element of the array, flush won't do anything
(0, 0.),
((count, prev), element) => {
if element == prev {
(count + 1, prev)
} else {
// new value, process previous ones
addData(count, prev)
(1, element)
}
},
)
let newContinousSamples =
discreteToIntegrate->FloatFloatMap.toArray
|> fmap(((k, v)) => Belt.Array.makeBy(Belt.Float.toInt(v), _ => k))
|> Belt.Array.concatMany
let newContinuous = concat(continuous, newContinousSamples)
newContinuous |> Array.fast_sort(floatCompare)
(newContinuous, discreteToKeep)
// flush final values
addData(finalCount, finalValue)
(continuous, discrete)
}
}
}

View File

@ -16,6 +16,14 @@ let increment = (el, t: t) =>
}
)
let add = (el, amount: float, t: t) =>
Belt.MutableMap.update(t, el, x =>
switch x {
| Some(n) => Some(n +. amount)
| None => Some(amount)
}
)
let get = (el, t: t) => Belt.MutableMap.get(t, el)
let fmap = (fn, t: t) => Belt.MutableMap.map(t, fn)
let partition = (fn, t: t) => {