factorize paralellization in C code out

- Conceptually clearer
- Allows for composing multiple mixtures together
- Considering incorporating it into squiggle.c
This commit is contained in:
NunoSempere 2023-11-18 19:43:10 +00:00
parent 9a60392849
commit 2a9d3bf135
9 changed files with 23 additions and 72 deletions

Binary file not shown.

View File

@ -152,7 +152,7 @@ float random_to(float low, float high, uint32_t* seed)
// Mixture function
float mixture_one_thread(float (*samplers[])(uint32_t*), float* weights, int n_dists, uint32_t* seed)
float mixture(float (*samplers[])(uint32_t*), float* weights, int n_dists, uint32_t* seed)
{
// You can see a slightly simpler version of this function in the git history
@ -167,68 +167,19 @@ float mixture_one_thread(float (*samplers[])(uint32_t*), float* weights, int n_d
//create var holders
float p1, result;
int sample_index, i, own_length;
p1 = random_uniform(0, 1);
p1 = random_uniform(0, 1, seed);
for (int i = 0; i < n_dists; i++) {
if (p1 < cummulative_weights[i]) {
result = samplers[i]();
if (p1 < cumsummed_normalized_weights[i]) {
result = samplers[i](seed);
break;
}
}
free(normalized_weights);
free(cummulative_weights);
free(cumsummed_normalized_weights);
return result;
}
// mixture paralellized
void mixture_paralell(float (*samplers[])(uint32_t*), float* weights, int n_dists, float** results, int n_threads)
{
// You can see a simpler version of this function in the git history
// or in alt/C-02-better-algorithm-one-thread/
float sum_weights = array_sum(weights, n_dists);
float* cumsummed_normalized_weights = malloc(n_dists * sizeof(float));
cumsummed_normalized_weights[0] = weights[0] / sum_weights;
for (int i = 1; i < n_dists; i++) {
cumsummed_normalized_weights[i] = cumsummed_normalized_weights[i - 1] + weights[i] / sum_weights;
}
//create var holders
float p1;
int sample_index, i, split_array_length;
// uint32_t* seeds[n_threads];
uint32_t** seeds = malloc(n_threads * sizeof(uint32_t*));
for (uint32_t i = 0; i < n_threads; i++) {
seeds[i] = malloc(sizeof(uint32_t));
*seeds[i] = i + 1; // xorshift can't start with 0
}
#pragma omp parallel private(i, p1, sample_index, split_array_length)
{
#pragma omp for
for (i = 0; i < n_threads; i++) {
split_array_length = split_array_get_length(i, N, n_threads);
for (int j = 0; j < split_array_length; j++) {
p1 = random_uniform(0, 1, seeds[i]);
for (int k = 0; k < n_dists; k++) {
if (p1 < cumsummed_normalized_weights[k]) {
results[i][j] = samplers[k](seeds[i]);
break;
}
}
}
}
}
// free(normalized_weights);
// free(cummulative_weights);
free(cumsummed_normalized_weights);
for (uint32_t i = 0; i < n_threads; i++) {
free(seeds[i]);
}
free(seeds);
}
// Parallization function
void paralellize(float *sampler(uint32_t* seed), float** results, int n_threads){
void paralellize(float (*sampler)(uint32_t* seed), float** results, int n_threads){
int sample_index, i, split_array_length;
uint32_t** seeds = malloc(n_threads * sizeof(uint32_t*));
@ -237,18 +188,17 @@ void paralellize(float *sampler(uint32_t* seed), float** results, int n_threads)
*seeds[i] = i + 1; // xorshift can't start with 0
}
#pragma omp parallel private(i, p1, sample_index, split_array_length)
#pragma omp parallel private(i, sample_index, split_array_length)
{
#pragma omp for
for (i = 0; i < n_threads; i++) {
split_array_length = split_array_get_length(i, N, n_threads);
for (int j = 0; j < split_array_length; j++) {
results[i][j] = sampler(seeds[i]);
break;
}
}
}
free(cumsummed_normalized_weights);
for (uint32_t i = 0; i < n_threads; i++) {
free(seeds[i]);
}
@ -278,17 +228,8 @@ float sample_many(uint32_t* seed)
return random_to(2, 10, seed);
}
int main()
{
// Toy example
// Declare variables in play
float sample_mixture(uint32_t* seed){
float p_a, p_b, p_c;
int n_threads = omp_get_max_threads();
// printf("Max threads: %d\n", n_threads);
// omp_set_num_threads(n_threads);
float** dist_mixture = malloc(n_threads * sizeof(float*));
split_array_allocate(dist_mixture, N, n_threads);
// Initialize variables
p_a = 0.8;
@ -300,10 +241,20 @@ int main()
float weights[] = { 1 - p_c, p_c / 2, p_c / 4, p_c / 4 };
float (*samplers[])(uint32_t*) = { sample_0, sample_1, sample_few, sample_many };
mixture(samplers, weights, n_dists, dist_mixture, n_threads);
printf("Sum(dist_mixture, N)/N = %f\n", split_array_sum(dist_mixture, N, n_threads) / N);
// array_print(dist_mixture[0], N);
split_array_free(dist_mixture, n_threads);
return mixture(samplers, weights, n_dists, seed);
}
int main()
{
int n_threads = omp_get_max_threads();
// printf("Max threads: %d\n", n_threads);
// omp_set_num_threads(n_threads);
float** split_array_results = malloc(n_threads * sizeof(float*));
split_array_allocate(split_array_results, N, n_threads);
paralellize(sample_mixture, split_array_results, n_threads);
printf("Sum(split_array_results, N)/N = %f\n", split_array_sum(split_array_results, N, n_threads) / N);
split_array_free(split_array_results, n_threads);
return 0;
}