Merge pull request #1161 from quantified-uncertainty/sampleset-to-pointset-speedups

Sampleset to pointset speedups
This commit is contained in:
Vyacheslav Matyukhin 2022-09-21 02:51:26 +04:00 committed by GitHub
commit d9f4171943
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 152 additions and 128 deletions

View File

@ -9,22 +9,28 @@ let prepareInputs = (ar, minWeight) =>
describe("Continuous and discrete splits", () => { describe("Continuous and discrete splits", () => {
makeTest( makeTest(
"is empty, with no common elements", "is empty, with no common elements",
prepareInputs([1.432, 1.33455, 2.0], 2), prepareInputs([1.33455, 1.432, 2.0], 2),
([1.33455, 1.432, 2.0], []), ([1.33455, 1.432, 2.0], []),
) )
makeTest( makeTest(
"only stores 3.5 as discrete when minWeight is 3", "only stores 3.5 as discrete when minWeight is 3",
prepareInputs([1.432, 1.33455, 2.0, 2.0, 3.5, 3.5, 3.5], 3), prepareInputs([1.33455, 1.432, 2.0, 2.0, 3.5, 3.5, 3.5], 3),
([1.33455, 1.432, 2.0, 2.0], [(3.5, 3.0)]), ([1.33455, 1.432, 2.0, 2.0], [(3.5, 3.0)]),
) )
makeTest( makeTest(
"doesn't store 3.5 as discrete when minWeight is 5", "doesn't store 3.5 as discrete when minWeight is 5",
prepareInputs([1.432, 1.33455, 2.0, 2.0, 3.5, 3.5, 3.5], 5), prepareInputs([1.33455, 1.432, 2.0, 2.0, 3.5, 3.5, 3.5], 5),
([1.33455, 1.432, 2.0, 2.0, 3.5, 3.5, 3.5], []), ([1.33455, 1.432, 2.0, 2.0, 3.5, 3.5, 3.5], []),
) )
makeTest(
"more general test",
prepareInputs([10., 10., 11., 11., 11., 12., 13., 13., 13., 13., 13., 14.], 3),
([10., 10., 12., 14.], [(11., 3.), (13., 5.)]),
)
let makeDuplicatedArray = count => { let makeDuplicatedArray = count => {
let arr = Belt.Array.range(1, count) |> E.A.fmap(float_of_int) let arr = Belt.Array.range(1, count) |> E.A.fmap(float_of_int)
let sorted = arr |> Belt.SortArray.stableSortBy(_, compare) let sorted = arr |> Belt.SortArray.stableSortBy(_, compare)

View File

@ -1,15 +1,6 @@
#!/usr/bin/env node #!/usr/bin/env node
import { SqProject } from "@quri/squiggle-lang"; import { SqProject } from "@quri/squiggle-lang";
import { measure } from "./lib.mjs";
const measure = (cb, times = 1) => {
const t1 = new Date();
for (let i = 1; i <= times; i++) {
cb();
}
const t2 = new Date();
return (t2 - t1) / 1000;
};
const maxP = 5; const maxP = 5;

View File

@ -1,15 +1,6 @@
#!/usr/bin/env node #!/usr/bin/env node
import { SqProject } from "@quri/squiggle-lang"; import { SqProject } from "@quri/squiggle-lang";
import { measure } from "./lib.mjs";
const measure = (cb, times = 1) => {
const t1 = new Date();
for (let i = 1; i <= times; i++) {
cb();
}
const t2 = new Date();
return (t2 - t1) / 1000;
};
const maxP = 7; const maxP = 7;

View File

@ -0,0 +1,34 @@
#!/usr/bin/env node
import { SqProject } from "@quri/squiggle-lang";
import { measure } from "./lib.mjs";
const maxP = 3;
const sampleCount = process.env.SAMPLE_COUNT;
for (let p = 0; p <= maxP; p++) {
const size = Math.pow(10, p);
const project = SqProject.create();
if (sampleCount) {
project.setEnvironment({
sampleCount: Number(sampleCount),
xyPointLength: Number(sampleCount),
});
}
project.setSource(
"main",
`
List.upTo(1, ${size}) -> map(
{ |x| normal(x,2) -> SampleSet.fromDist -> PointSet.fromDist }
)->List.last
`
);
const time = measure(() => {
project.run("main");
});
const result = project.getResult("main");
if (result.tag != "Ok") {
throw new Error("Code failed: " + result.value.toString());
}
console.log(`1e${p}`, "\t", time);
}

View File

@ -0,0 +1,41 @@
import { SqProject } from "@quri/squiggle-lang";
export const measure = (cb, times = 1) => {
const t1 = new Date();
for (let i = 1; i <= times; i++) {
cb();
}
const t2 = new Date();
return (t2 - t1) / 1000;
};
export const red = (str) => `\x1b[31m${str}\x1b[0m`;
export const green = (str) => `\x1b[32m${str}\x1b[0m`;
export const run = (src, { output, sampleCount }) => {
const project = SqProject.create();
if (sampleCount) {
project.setEnvironment({
sampleCount: Number(sampleCount),
xyPointLength: Number(sampleCount),
});
}
project.setSource("main", src);
const time = measure(() => project.run("main"));
const bindings = project.getBindings("main");
const result = project.getResult("main");
if (output) {
console.log("Result:", result.tag, result.value.toString());
console.log("Bindings:", bindings.toString());
}
console.log(
"Time:",
String(time),
result.tag === "Error" ? red(result.tag) : green(result.tag),
result.tag === "Error" ? result.value.toString() : ""
);
};

View File

@ -1,21 +1,9 @@
#!/usr/bin/env node #!/usr/bin/env node
import { SqProject } from "@quri/squiggle-lang";
import fs from "fs"; import fs from "fs";
import { Command } from "commander"; import { Command } from "commander";
const measure = (cb, times = 1) => { import { run } from "./lib.mjs";
const t1 = new Date();
for (let i = 1; i <= times; i++) {
cb();
}
const t2 = new Date();
return (t2 - t1) / 1000;
};
const red = (str) => `\x1b[31m${str}\x1b[0m`;
const green = (str) => `\x1b[32m${str}\x1b[0m`;
const program = new Command(); const program = new Command();
@ -24,34 +12,11 @@ program.arguments("<string>");
const options = program.parse(process.argv); const options = program.parse(process.argv);
const project = SqProject.create();
const sampleCount = process.env.SAMPLE_COUNT; const sampleCount = process.env.SAMPLE_COUNT;
if (sampleCount) {
project.setEnvironment({
sampleCount: Number(sampleCount),
xyPointLength: Number(sampleCount),
});
}
const src = fs.readFileSync(program.args[0], "utf-8"); const src = fs.readFileSync(program.args[0], "utf-8");
if (!src) { if (!src) {
throw new Error("Expected src"); throw new Error("Expected src");
} }
project.setSource("main", src); run(src, { output: options.output, sampleCount });
const time = measure(() => project.run("main"));
const bindings = project.getBindings("main");
const result = project.getResult("main");
if (options.output) {
console.log("Result:", result.tag, result.value.toString());
console.log("Bindings:", bindings.toString());
}
console.log(
"Time:",
String(time),
result.tag === "Error" ? red(result.tag) : green(result.tag),
result.tag === "Error" ? result.value.toString() : ""
);

View File

@ -1,18 +1,10 @@
#!/usr/bin/env node #!/usr/bin/env node
import { SqProject } from "@quri/squiggle-lang"; import { run } from "./lib.mjs";
const project = SqProject.create();
const src = process.argv[2]; const src = process.argv[2];
if (!src) { if (!src) {
throw new Error("Expected src"); throw new Error("Expected src");
} }
console.log(`Running ${src}`); console.log(`Running ${src}`);
project.setSource("a", src);
project.run("a");
const result = project.getResult("a"); run(src);
console.log(result.tag, result.value.toString());
const bindings = project.getBindings("a");
console.log(bindings.asValue().toString());

View File

@ -33,19 +33,19 @@ module Internals = {
module KDE = { module KDE = {
let normalSampling = (samples, outputXYPoints, kernelWidth) => let normalSampling = (samples, outputXYPoints, kernelWidth) =>
samples |> JS.samplesToContinuousPdf(_, outputXYPoints, kernelWidth) |> JS.jsToDist samples->JS.samplesToContinuousPdf(outputXYPoints, kernelWidth)->JS.jsToDist
} }
module T = { module T = {
type t = array<float> type t = array<float>
let xWidthToUnitWidth = (samples, outputXYPoints, xWidth) => { let xWidthToUnitWidth = (samples, outputXYPoints, xWidth) => {
let xyPointRange = E.A.Sorted.range(samples) |> E.O.default(0.0) let xyPointRange = E.A.Sorted.range(samples)->E.O2.default(0.0)
let xyPointWidth = xyPointRange /. float_of_int(outputXYPoints) let xyPointWidth = xyPointRange /. float_of_int(outputXYPoints)
xWidth /. xyPointWidth xWidth /. xyPointWidth
} }
let formatUnitWidth = w => Jstat.max([w, 1.0]) |> int_of_float let formatUnitWidth = w => Jstat.max([w, 1.0])->int_of_float
let suggestedUnitWidth = (samples, outputXYPoints) => { let suggestedUnitWidth = (samples, outputXYPoints) => {
let suggestedXWidth = SampleSetDist_Bandwidth.nrd0(samples) let suggestedXWidth = SampleSetDist_Bandwidth.nrd0(samples)
@ -62,23 +62,24 @@ let toPointSetDist = (
~samplingInputs: SamplingInputs.samplingInputs, ~samplingInputs: SamplingInputs.samplingInputs,
(), (),
): Internals.Types.outputs => { ): Internals.Types.outputs => {
let samples = Js.Array2.copy(samples) let samples = samples->Js.Array2.copy->Js.Array2.sortInPlaceWith(compare)
Array.fast_sort(compare, samples)
let minDiscreteToKeep = MagicNumbers.ToPointSet.minDiscreteToKeep(samples) let minDiscreteToKeep = MagicNumbers.ToPointSet.minDiscreteToKeep(samples)
let (continuousPart, discretePart) = E.A.Floats.Sorted.splitContinuousAndDiscreteForMinWeight( let (continuousPart, discretePart) = E.A.Floats.Sorted.splitContinuousAndDiscreteForMinWeight(
samples, samples,
~minDiscreteWeight=minDiscreteToKeep, ~minDiscreteWeight=minDiscreteToKeep,
) )
let length = samples |> E.A.length |> float_of_int
let length = samples->E.A.length->float_of_int
let discrete: PointSetTypes.discreteShape = let discrete: PointSetTypes.discreteShape =
discretePart discretePart
|> E.FloatFloatMap.fmap(r => r /. length) ->E.FloatFloatMap.fmap(r => r /. length, _)
|> E.FloatFloatMap.toArray ->E.FloatFloatMap.toArray
|> XYShape.T.fromZippedArray ->XYShape.T.fromZippedArray
|> Discrete.make ->Discrete.make
let pdf = let pdf =
continuousPart |> E.A.length > 5 continuousPart->E.A.length > 5
? { ? {
let _suggestedXWidth = SampleSetDist_Bandwidth.nrd0(continuousPart) let _suggestedXWidth = SampleSetDist_Bandwidth.nrd0(continuousPart)
// todo: This does some recalculating from the last step. // todo: This does some recalculating from the last step.
@ -86,7 +87,7 @@ let toPointSetDist = (
continuousPart, continuousPart,
samplingInputs.outputXYPoints, samplingInputs.outputXYPoints,
) )
let usedWidth = samplingInputs.kernelWidth |> E.O.default(_suggestedXWidth) let usedWidth = samplingInputs.kernelWidth->E.O2.default(_suggestedXWidth)
let usedUnitWidth = Internals.T.xWidthToUnitWidth( let usedUnitWidth = Internals.T.xWidthToUnitWidth(
samples, samples,
samplingInputs.outputXYPoints, samplingInputs.outputXYPoints,
@ -101,18 +102,18 @@ let toPointSetDist = (
bandwidthUnitImplemented: usedUnitWidth, bandwidthUnitImplemented: usedUnitWidth,
} }
continuousPart continuousPart
|> Internals.T.kde( ->Internals.T.kde(
~samples=_, ~samples=_,
~outputXYPoints=samplingInputs.outputXYPoints, ~outputXYPoints=samplingInputs.outputXYPoints,
Internals.T.formatUnitWidth(usedUnitWidth), Internals.T.formatUnitWidth(usedUnitWidth),
) )
|> Continuous.make ->Continuous.make
|> (r => Some((r, samplingStats))) ->(r => Some((r, samplingStats)))
} }
: None : None
let pointSetDist = MixedShapeBuilder.buildSimple( let pointSetDist = MixedShapeBuilder.buildSimple(
~continuous=pdf |> E.O.fmap(fst), ~continuous=pdf->E.O2.fmap(fst),
~discrete=Some(discrete), ~discrete=Some(discrete),
) )
@ -125,7 +126,7 @@ let toPointSetDist = (
let normalizedPointSet = pointSetDist->E.O2.fmap(PointSetDist.T.normalize) let normalizedPointSet = pointSetDist->E.O2.fmap(PointSetDist.T.normalize)
let samplesParse: Internals.Types.outputs = { let samplesParse: Internals.Types.outputs = {
continuousParseParams: pdf |> E.O.fmap(snd), continuousParseParams: pdf->E.O2.fmap(snd),
pointSetDist: normalizedPointSet, pointSetDist: normalizedPointSet,
} }

View File

@ -305,55 +305,50 @@ module Floats = {
/* /*
This function goes through a sorted array and divides it into two different clusters: This function goes through a sorted array and divides it into two different clusters:
continuous samples and discrete samples. The discrete samples are stored in a mutable map. continuous samples and discrete samples. The discrete samples are stored in a mutable map.
Samples are thought to be discrete if they have any duplicates. Samples are thought to be discrete if they have at least `minDiscreteWight` duplicates.
*/
let _splitContinuousAndDiscreteForDuplicates = (sortedArray: array<float>) => {
let continuous: array<float> = []
let discrete = FloatFloatMap.empty()
Belt.Array.forEachWithIndex(sortedArray, (index, element) => {
let maxIndex = (sortedArray |> Array.length) - 1
let possiblySimilarElements = switch index {
| 0 => [index + 1]
| n if n == maxIndex => [index - 1]
| _ => [index - 1, index + 1]
} |> Belt.Array.map(_, r => sortedArray[r])
let hasSimilarElement = Belt.Array.some(possiblySimilarElements, r => r == element)
hasSimilarElement
? FloatFloatMap.increment(element, discrete)
: {
let _ = Js.Array.push(element, continuous)
}
() If the min discrete weight is 4, that would mean that at least four elements needed from a specific
}) value for that to be kept as discrete. This is important because in some cases, we can expect that
some common elements will be generated by regular operations. The final continuous array will be sorted.
(continuous, discrete) This function is performance-critical, don't change it significantly without benchmarking
} SampleSet->PointSet conversion performance.
/*
This function works very similarly to splitContinuousAndDiscreteForDuplicates. The one major difference
is that you can specify a minDiscreteWeight. If the min discreet weight is 4, that would mean that
at least four elements needed from a specific value for that to be kept as discrete. This is important
because in some cases, we can expect that some common elements will be generated by regular operations.
The final continous array will be sorted.
*/ */
let splitContinuousAndDiscreteForMinWeight = ( let splitContinuousAndDiscreteForMinWeight = (
sortedArray: array<float>, sortedArray: array<float>,
~minDiscreteWeight: int, ~minDiscreteWeight: int,
) => { ) => {
let (continuous, discrete) = _splitContinuousAndDiscreteForDuplicates(sortedArray) let continuous: array<float> = []
let keepFn = v => Belt.Float.toInt(v) >= minDiscreteWeight let discrete = FloatFloatMap.empty()
let (discreteToKeep, discreteToIntegrate) = FloatFloatMap.partition(
((_, v)) => keepFn(v), let addData = (count: int, value: float): unit => {
discrete, if count >= minDiscreteWeight {
FloatFloatMap.add(value, count->Belt.Int.toFloat, discrete)
} else {
for _ in 1 to count {
continuous->Js.Array2.push(value)->ignore
}
}
}
let (finalCount, finalValue) = sortedArray->Belt.Array.reduce(
// initial prev value doesn't matter; if it collides with the first element of the array, flush won't do anything
(0, 0.),
((count, prev), element) => {
if element == prev {
(count + 1, prev)
} else {
// new value, process previous ones
addData(count, prev)
(1, element)
}
},
) )
let newContinousSamples =
discreteToIntegrate->FloatFloatMap.toArray // flush final values
|> fmap(((k, v)) => Belt.Array.makeBy(Belt.Float.toInt(v), _ => k)) addData(finalCount, finalValue)
|> Belt.Array.concatMany
let newContinuous = concat(continuous, newContinousSamples) (continuous, discrete)
newContinuous |> Array.fast_sort(floatCompare)
(newContinuous, discreteToKeep)
} }
} }
} }

View File

@ -16,6 +16,14 @@ let increment = (el, t: t) =>
} }
) )
let add = (el, amount: float, t: t) =>
Belt.MutableMap.update(t, el, x =>
switch x {
| Some(n) => Some(n +. amount)
| None => Some(amount)
}
)
let get = (el, t: t) => Belt.MutableMap.get(t, el) let get = (el, t: t) => Belt.MutableMap.get(t, el)
let fmap = (fn, t: t) => Belt.MutableMap.map(t, fn) let fmap = (fn, t: t) => Belt.MutableMap.map(t, fn)
let partition = (fn, t: t) => { let partition = (fn, t: t) => {