summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorPaul Eggert <eggert@cs.ucla.edu>2010-12-11 00:27:05 -0800
committerPaul Eggert <eggert@cs.ucla.edu>2010-12-11 00:29:13 -0800
commit27e997d0ebf3b2411e26b92a90365209e5cc1d00 (patch)
tree571e28d157699197104810161289df260aee4d3b /src
parentc9db0ac6decb8121097d67e13659748ff3b3bcd6 (diff)
downloadcoreutils-27e997d0ebf3b2411e26b92a90365209e5cc1d00.tar.xz
sort: integer overflow checks in thread counts, etc.
* src/sort.c (specify_nthreads, merge_tree_init, init_node): (queue_init, sortlines, struct thread_args, sort, main): Use size_t, not unsigned long int, for thread counts, since thread counts are now used to compute sizes. (specify_nthreads): Check for size_t overflow. (merge_tree_init, sort): Shorten name of local variable, for readability. (merge_tree_init): Move constants next to each other in product, so that the constant folding is easier to see. (init_node): Now static. Add 'restrict' only where it might be helpful for compiler optimization. (queue_init): 2nd arg is now nthreads, not "reserve", which is a bit harder to follow. All uses changed. (struct thread_args): Rename lo_child to is_lo_child, so that it's obvious to the reader when we're talking about this boolean as opposed to the new lo_child member of the other structure. All uses changed. (sort): Remove unused local variable end_node. (main): Don't allow large thread counts to cause undefined behavior later, due to integer overflow.
Diffstat (limited to 'src')
-rw-r--r--src/sort.c115
1 files changed, 64 insertions, 51 deletions
diff --git a/src/sort.c b/src/sort.c
index b724f3dc1..2c0f8522f 100644
--- a/src/sort.c
+++ b/src/sort.c
@@ -1379,15 +1379,17 @@ specify_sort_size (int oi, char c, char const *s)
}
/* Specify the number of threads to spawn during internal sort. */
-static unsigned long int
+static size_t
specify_nthreads (int oi, char c, char const *s)
{
unsigned long int nthreads;
enum strtol_error e = xstrtoul (s, NULL, 10, &nthreads, "");
if (e == LONGINT_OVERFLOW)
- return ULONG_MAX;
+ return SIZE_MAX;
if (e != LONGINT_OK)
xstrtol_fatal (e, oi, c, long_options, s);
+ if (SIZE_MAX < nthreads)
+ nthreads = SIZE_MAX;
if (nthreads == 0)
error (SORT_FAILURE, 0, _("number in parallel must be nonzero"));
return nthreads;
@@ -3139,28 +3141,28 @@ sequential_sort (struct line *restrict lines, size_t nlines,
}
}
-struct merge_node * init_node (struct merge_node *, struct merge_node *,
- struct line *restrict, unsigned long int,
- size_t, bool);
+static struct merge_node *init_node (struct merge_node *restrict,
+ struct merge_node *restrict,
+ struct line *, size_t, size_t, bool);
-/* Initialize the merge tree. */
+/* Create and return a merge tree for NTHREADS threads, sorting NLINES
+ lines, with destination DEST. */
static struct merge_node *
-merge_tree_init (unsigned long int nthreads, size_t nlines,
- struct line *restrict dest)
+merge_tree_init (size_t nthreads, size_t nlines, struct line *dest)
{
- struct merge_node *merge_tree = xmalloc (2 * nthreads * sizeof *merge_tree);
-
- struct merge_node *root_node = merge_tree;
- root_node->lo = root_node->hi = root_node->end_lo = root_node->end_hi = NULL;
- root_node->dest = NULL;
- root_node->nlo = root_node->nhi = nlines;
- root_node->parent = NULL;
- root_node->level = MERGE_END;
- root_node->queued = false;
- pthread_mutex_init (&root_node->lock, NULL);
-
- init_node (root_node, root_node + 1, dest, nthreads, nlines, false);
+ struct merge_node *merge_tree = xmalloc (2 * sizeof *merge_tree * nthreads);
+
+ struct merge_node *root = merge_tree;
+ root->lo = root->hi = root->end_lo = root->end_hi = NULL;
+ root->dest = NULL;
+ root->nlo = root->nhi = nlines;
+ root->parent = NULL;
+ root->level = MERGE_END;
+ root->queued = false;
+ pthread_mutex_init (&root->lock, NULL);
+
+ init_node (root, root + 1, dest, nthreads, nlines, false);
return merge_tree;
}
@@ -3171,19 +3173,25 @@ merge_tree_destroy (struct merge_node *merge_tree)
free (merge_tree);
}
-/* Initialize a merge tree node. */
+/* Initialize a merge tree node and its descendants. The node's
+ parent is PARENT. The node and its descendants are taken from the
+ array of nodes NODE_POOL. Their destination starts at DEST; they
+ will consume NTHREADS threads. The total number of sort lines is
+ TOTAL_LINES. IS_LO_CHILD is true if the node is the low child of
+ its parent. */
-struct merge_node *
-init_node (struct merge_node *parent, struct merge_node *node_pool,
- struct line *restrict dest, unsigned long int nthreads,
- size_t total_lines, bool lo_child)
+static struct merge_node *
+init_node (struct merge_node *restrict parent,
+ struct merge_node *restrict node_pool,
+ struct line *dest, size_t nthreads,
+ size_t total_lines, bool is_lo_child)
{
- size_t nlines = (lo_child)? parent->nlo : parent->nhi;
+ size_t nlines = (is_lo_child ? parent->nlo : parent->nhi);
size_t nlo = nlines / 2;
size_t nhi = nlines - nlo;
struct line *lo = dest - total_lines;
struct line *hi = lo - nlo;
- struct line **parent_end = (lo_child)? &parent->end_lo : &parent->end_hi;
+ struct line **parent_end = (is_lo_child ? &parent->end_lo : &parent->end_hi);
struct merge_node *node = node_pool++;
node->lo = node->end_lo = lo;
@@ -3198,8 +3206,8 @@ init_node (struct merge_node *parent, struct merge_node *node_pool,
if (nthreads > 1)
{
- unsigned long int lo_threads = nthreads / 2;
- unsigned long int hi_threads = nthreads - lo_threads;
+ size_t lo_threads = nthreads / 2;
+ size_t hi_threads = nthreads - lo_threads;
node->lo_child = node_pool;
node_pool = init_node (node, node_pool, lo, lo_threads,
total_lines, true);
@@ -3254,15 +3262,16 @@ queue_destroy (struct merge_node_queue *queue)
pthread_mutex_destroy (&queue->mutex);
}
-/* Initialize merge QUEUE, allocating space for a maximum of RESERVE nodes.
- Though it's highly unlikely all nodes are in the heap at the same time,
- RESERVE should accommodate all of them. Counting a NULL dummy head for the
- heap, RESERVE should be 2 * NTHREADS. */
+/* Initialize merge QUEUE, allocating space suitable for a maximum of
+ NTHREADS threads. */
static void
-queue_init (struct merge_node_queue *queue, size_t reserve)
+queue_init (struct merge_node_queue *queue, size_t nthreads)
{
- queue->priority_queue = heap_alloc (compare_nodes, reserve);
+ /* Though it's highly unlikely all nodes are in the heap at the same
+ time, the heap should accommodate all of them. Counting a NULL
+ dummy head for the heap, reserve 2 * NTHREADS nodes. */
+ queue->priority_queue = heap_alloc (compare_nodes, 2 * nthreads);
pthread_mutex_init (&queue->mutex, NULL);
pthread_cond_init (&queue->cond, NULL);
}
@@ -3454,7 +3463,7 @@ merge_loop (struct merge_node_queue *queue,
}
-static void sortlines (struct line *restrict, unsigned long int, size_t,
+static void sortlines (struct line *restrict, size_t, size_t,
struct merge_node *, bool, struct merge_node_queue *,
FILE *, char const *);
@@ -3467,7 +3476,7 @@ struct thread_args
struct line *lines;
/* Number of threads to use. If 0 or 1, sort single-threaded. */
- unsigned long int nthreads;
+ size_t nthreads;
/* Number of lines in LINES and DEST. */
size_t const total_lines;
@@ -3477,7 +3486,7 @@ struct thread_args
struct merge_node *const node;
/* True if this node is sorting the lower half of the parent's work. */
- bool lo_child;
+ bool is_lo_child;
/* The priority queue controlling available work for the entire
internal sort. */
@@ -3496,7 +3505,7 @@ sortlines_thread (void *data)
{
struct thread_args const *args = data;
sortlines (args->lines, args->nthreads, args->total_lines,
- args->node, args->lo_child, args->queue, args->tfp,
+ args->node, args->is_lo_child, args->queue, args->tfp,
args->output_temp);
return NULL;
}
@@ -3526,15 +3535,15 @@ sortlines_thread (void *data)
have been merged. */
static void
-sortlines (struct line *restrict lines, unsigned long int nthreads,
- size_t total_lines, struct merge_node *node, bool lo_child,
+sortlines (struct line *restrict lines, size_t nthreads,
+ size_t total_lines, struct merge_node *node, bool is_lo_child,
struct merge_node_queue *queue, FILE *tfp, char const *temp_output)
{
size_t nlines = node->nlo + node->nhi;
/* Calculate thread arguments. */
- unsigned long int lo_threads = nthreads / 2;
- unsigned long int hi_threads = nthreads - lo_threads;
+ size_t lo_threads = nthreads / 2;
+ size_t hi_threads = nthreads - lo_threads;
pthread_t thread;
struct thread_args args = {lines, lo_threads, total_lines,
node->lo_child, true, queue, tfp, temp_output};
@@ -3774,7 +3783,7 @@ merge (struct sortfile *files, size_t ntemps, size_t nfiles,
static void
sort (char * const *files, size_t nfiles, char const *output_file,
- unsigned long int nthreads)
+ size_t nthreads)
{
struct buffer buf;
size_t ntemps = 0;
@@ -3793,7 +3802,7 @@ sort (char * const *files, size_t nfiles, char const *output_file,
if (nthreads > 1)
{
/* Get log P. */
- unsigned long int tmp = 1;
+ size_t tmp = 1;
size_t mult = 1;
while (tmp < nthreads)
{
@@ -3843,16 +3852,15 @@ sort (char * const *files, size_t nfiles, char const *output_file,
if (1 < buf.nlines)
{
struct merge_node_queue queue;
- queue_init (&queue, 2 * nthreads);
+ queue_init (&queue, nthreads);
struct merge_node *merge_tree =
merge_tree_init (nthreads, buf.nlines, line);
- struct merge_node *end_node = merge_tree;
- struct merge_node *root_node = merge_tree + 1;
+ struct merge_node *root = merge_tree + 1;
- sortlines (line, nthreads, buf.nlines, root_node,
+ sortlines (line, nthreads, buf.nlines, root,
true, &queue, tfp, temp_output);
queue_destroy (&queue);
- pthread_mutex_destroy (&root_node->lock);
+ pthread_mutex_destroy (&root->lock);
merge_tree_destroy (merge_tree);
}
else
@@ -4076,7 +4084,7 @@ main (int argc, char **argv)
bool mergeonly = false;
char *random_source = NULL;
bool need_random = false;
- unsigned long int nthreads = 0;
+ size_t nthreads = 0;
size_t nfiles = 0;
bool posixly_correct = (getenv ("POSIXLY_CORRECT") != NULL);
bool obsolete_usage = (posix2_version () < 200112);
@@ -4620,6 +4628,11 @@ main (int argc, char **argv)
if (!nthreads || nthreads > np2)
nthreads = np2;
+ /* Avoid integer overflow later. */
+ size_t nthreads_max = SIZE_MAX / (2 * sizeof (struct merge_node));
+ if (nthreads_max < nthreads)
+ nthreads = nthreads_max;
+
sort (files, nfiles, outfile, nthreads);
}