diff options
author | Paul Eggert <eggert@cs.ucla.edu> | 2010-12-11 00:27:05 -0800 |
---|---|---|
committer | Paul Eggert <eggert@cs.ucla.edu> | 2010-12-11 00:29:13 -0800 |
commit | 27e997d0ebf3b2411e26b92a90365209e5cc1d00 (patch) | |
tree | 571e28d157699197104810161289df260aee4d3b /src | |
parent | c9db0ac6decb8121097d67e13659748ff3b3bcd6 (diff) | |
download | coreutils-27e997d0ebf3b2411e26b92a90365209e5cc1d00.tar.xz |
sort: integer overflow checks in thread counts, etc.
* src/sort.c (specify_nthreads, merge_tree_init, init_node):
(queue_init, sortlines, struct thread_args, sort, main):
Use size_t, not unsigned long int, for thread counts, since thread
counts are now used to compute sizes.
(specify_nthreads): Check for size_t overflow.
(merge_tree_init, sort): Shorten name of local variable, for
readability.
(merge_tree_init): Move constants next to each other in product,
so that the constant folding is easier to see.
(init_node): Now static. Add 'restrict' only where it might
be helpful for compiler optimization.
(queue_init): 2nd arg is now nthreads, not "reserve", which is
a bit harder to follow. All uses changed.
(struct thread_args): Rename lo_child to is_lo_child, so that
it's obvious to the reader when we're talking about this boolean
as opposed to the new lo_child member of the other structure.
All uses changed.
(sort): Remove unused local variable end_node.
(main): Don't allow large thread counts to cause undefined behavior
later, due to integer overflow.
Diffstat (limited to 'src')
-rw-r--r-- | src/sort.c | 115 |
1 files changed, 64 insertions, 51 deletions
diff --git a/src/sort.c b/src/sort.c index b724f3dc1..2c0f8522f 100644 --- a/src/sort.c +++ b/src/sort.c @@ -1379,15 +1379,17 @@ specify_sort_size (int oi, char c, char const *s) } /* Specify the number of threads to spawn during internal sort. */ -static unsigned long int +static size_t specify_nthreads (int oi, char c, char const *s) { unsigned long int nthreads; enum strtol_error e = xstrtoul (s, NULL, 10, &nthreads, ""); if (e == LONGINT_OVERFLOW) - return ULONG_MAX; + return SIZE_MAX; if (e != LONGINT_OK) xstrtol_fatal (e, oi, c, long_options, s); + if (SIZE_MAX < nthreads) + nthreads = SIZE_MAX; if (nthreads == 0) error (SORT_FAILURE, 0, _("number in parallel must be nonzero")); return nthreads; @@ -3139,28 +3141,28 @@ sequential_sort (struct line *restrict lines, size_t nlines, } } -struct merge_node * init_node (struct merge_node *, struct merge_node *, - struct line *restrict, unsigned long int, - size_t, bool); +static struct merge_node *init_node (struct merge_node *restrict, + struct merge_node *restrict, + struct line *, size_t, size_t, bool); -/* Initialize the merge tree. */ +/* Create and return a merge tree for NTHREADS threads, sorting NLINES + lines, with destination DEST. */ static struct merge_node * -merge_tree_init (unsigned long int nthreads, size_t nlines, - struct line *restrict dest) +merge_tree_init (size_t nthreads, size_t nlines, struct line *dest) { - struct merge_node *merge_tree = xmalloc (2 * nthreads * sizeof *merge_tree); - - struct merge_node *root_node = merge_tree; - root_node->lo = root_node->hi = root_node->end_lo = root_node->end_hi = NULL; - root_node->dest = NULL; - root_node->nlo = root_node->nhi = nlines; - root_node->parent = NULL; - root_node->level = MERGE_END; - root_node->queued = false; - pthread_mutex_init (&root_node->lock, NULL); - - init_node (root_node, root_node + 1, dest, nthreads, nlines, false); + struct merge_node *merge_tree = xmalloc (2 * sizeof *merge_tree * nthreads); + + struct merge_node *root = merge_tree; + root->lo = root->hi = root->end_lo = root->end_hi = NULL; + root->dest = NULL; + root->nlo = root->nhi = nlines; + root->parent = NULL; + root->level = MERGE_END; + root->queued = false; + pthread_mutex_init (&root->lock, NULL); + + init_node (root, root + 1, dest, nthreads, nlines, false); return merge_tree; } @@ -3171,19 +3173,25 @@ merge_tree_destroy (struct merge_node *merge_tree) free (merge_tree); } -/* Initialize a merge tree node. */ +/* Initialize a merge tree node and its descendants. The node's + parent is PARENT. The node and its descendants are taken from the + array of nodes NODE_POOL. Their destination starts at DEST; they + will consume NTHREADS threads. The total number of sort lines is + TOTAL_LINES. IS_LO_CHILD is true if the node is the low child of + its parent. */ -struct merge_node * -init_node (struct merge_node *parent, struct merge_node *node_pool, - struct line *restrict dest, unsigned long int nthreads, - size_t total_lines, bool lo_child) +static struct merge_node * +init_node (struct merge_node *restrict parent, + struct merge_node *restrict node_pool, + struct line *dest, size_t nthreads, + size_t total_lines, bool is_lo_child) { - size_t nlines = (lo_child)? parent->nlo : parent->nhi; + size_t nlines = (is_lo_child ? parent->nlo : parent->nhi); size_t nlo = nlines / 2; size_t nhi = nlines - nlo; struct line *lo = dest - total_lines; struct line *hi = lo - nlo; - struct line **parent_end = (lo_child)? &parent->end_lo : &parent->end_hi; + struct line **parent_end = (is_lo_child ? &parent->end_lo : &parent->end_hi); struct merge_node *node = node_pool++; node->lo = node->end_lo = lo; @@ -3198,8 +3206,8 @@ init_node (struct merge_node *parent, struct merge_node *node_pool, if (nthreads > 1) { - unsigned long int lo_threads = nthreads / 2; - unsigned long int hi_threads = nthreads - lo_threads; + size_t lo_threads = nthreads / 2; + size_t hi_threads = nthreads - lo_threads; node->lo_child = node_pool; node_pool = init_node (node, node_pool, lo, lo_threads, total_lines, true); @@ -3254,15 +3262,16 @@ queue_destroy (struct merge_node_queue *queue) pthread_mutex_destroy (&queue->mutex); } -/* Initialize merge QUEUE, allocating space for a maximum of RESERVE nodes. - Though it's highly unlikely all nodes are in the heap at the same time, - RESERVE should accommodate all of them. Counting a NULL dummy head for the - heap, RESERVE should be 2 * NTHREADS. */ +/* Initialize merge QUEUE, allocating space suitable for a maximum of + NTHREADS threads. */ static void -queue_init (struct merge_node_queue *queue, size_t reserve) +queue_init (struct merge_node_queue *queue, size_t nthreads) { - queue->priority_queue = heap_alloc (compare_nodes, reserve); + /* Though it's highly unlikely all nodes are in the heap at the same + time, the heap should accommodate all of them. Counting a NULL + dummy head for the heap, reserve 2 * NTHREADS nodes. */ + queue->priority_queue = heap_alloc (compare_nodes, 2 * nthreads); pthread_mutex_init (&queue->mutex, NULL); pthread_cond_init (&queue->cond, NULL); } @@ -3454,7 +3463,7 @@ merge_loop (struct merge_node_queue *queue, } -static void sortlines (struct line *restrict, unsigned long int, size_t, +static void sortlines (struct line *restrict, size_t, size_t, struct merge_node *, bool, struct merge_node_queue *, FILE *, char const *); @@ -3467,7 +3476,7 @@ struct thread_args struct line *lines; /* Number of threads to use. If 0 or 1, sort single-threaded. */ - unsigned long int nthreads; + size_t nthreads; /* Number of lines in LINES and DEST. */ size_t const total_lines; @@ -3477,7 +3486,7 @@ struct thread_args struct merge_node *const node; /* True if this node is sorting the lower half of the parent's work. */ - bool lo_child; + bool is_lo_child; /* The priority queue controlling available work for the entire internal sort. */ @@ -3496,7 +3505,7 @@ sortlines_thread (void *data) { struct thread_args const *args = data; sortlines (args->lines, args->nthreads, args->total_lines, - args->node, args->lo_child, args->queue, args->tfp, + args->node, args->is_lo_child, args->queue, args->tfp, args->output_temp); return NULL; } @@ -3526,15 +3535,15 @@ sortlines_thread (void *data) have been merged. */ static void -sortlines (struct line *restrict lines, unsigned long int nthreads, - size_t total_lines, struct merge_node *node, bool lo_child, +sortlines (struct line *restrict lines, size_t nthreads, + size_t total_lines, struct merge_node *node, bool is_lo_child, struct merge_node_queue *queue, FILE *tfp, char const *temp_output) { size_t nlines = node->nlo + node->nhi; /* Calculate thread arguments. */ - unsigned long int lo_threads = nthreads / 2; - unsigned long int hi_threads = nthreads - lo_threads; + size_t lo_threads = nthreads / 2; + size_t hi_threads = nthreads - lo_threads; pthread_t thread; struct thread_args args = {lines, lo_threads, total_lines, node->lo_child, true, queue, tfp, temp_output}; @@ -3774,7 +3783,7 @@ merge (struct sortfile *files, size_t ntemps, size_t nfiles, static void sort (char * const *files, size_t nfiles, char const *output_file, - unsigned long int nthreads) + size_t nthreads) { struct buffer buf; size_t ntemps = 0; @@ -3793,7 +3802,7 @@ sort (char * const *files, size_t nfiles, char const *output_file, if (nthreads > 1) { /* Get log P. */ - unsigned long int tmp = 1; + size_t tmp = 1; size_t mult = 1; while (tmp < nthreads) { @@ -3843,16 +3852,15 @@ sort (char * const *files, size_t nfiles, char const *output_file, if (1 < buf.nlines) { struct merge_node_queue queue; - queue_init (&queue, 2 * nthreads); + queue_init (&queue, nthreads); struct merge_node *merge_tree = merge_tree_init (nthreads, buf.nlines, line); - struct merge_node *end_node = merge_tree; - struct merge_node *root_node = merge_tree + 1; + struct merge_node *root = merge_tree + 1; - sortlines (line, nthreads, buf.nlines, root_node, + sortlines (line, nthreads, buf.nlines, root, true, &queue, tfp, temp_output); queue_destroy (&queue); - pthread_mutex_destroy (&root_node->lock); + pthread_mutex_destroy (&root->lock); merge_tree_destroy (merge_tree); } else @@ -4076,7 +4084,7 @@ main (int argc, char **argv) bool mergeonly = false; char *random_source = NULL; bool need_random = false; - unsigned long int nthreads = 0; + size_t nthreads = 0; size_t nfiles = 0; bool posixly_correct = (getenv ("POSIXLY_CORRECT") != NULL); bool obsolete_usage = (posix2_version () < 200112); @@ -4620,6 +4628,11 @@ main (int argc, char **argv) if (!nthreads || nthreads > np2) nthreads = np2; + /* Avoid integer overflow later. */ + size_t nthreads_max = SIZE_MAX / (2 * sizeof (struct merge_node)); + if (nthreads_max < nthreads) + nthreads = nthreads_max; + sort (files, nfiles, outfile, nthreads); } |