io: Make queue_io_read() safe for reentry
authorDavid Disseldorp <ddiss@suse.de>
Sun, 31 Jul 2011 01:14:54 +0000 (03:14 +0200)
committerRonnie Sahlberg <ronniesahlberg@gmail.com>
Fri, 5 Aug 2011 04:28:43 +0000 (14:28 +1000)
queue_io_read() may be reentered via the queue callback, recoverd is
particularly guilty of this.

queue_io_read() is not safe for reentry if more than one packet is
received and partial chunks follow - data read off the pipe on re-entry
is assumed to be the start-of-packet four byte length. This leads to a
wrongly aligned stream and the notorious "Invalid packet of length 0"
errors.

This change fixes queue_io_read() to be safe under reentry, only a
single packet is processed per call.

https://bugzilla.samba.org/show_bug.cgi?id=8319

common/ctdb_io.c

index 81f9451396b1e62e888c7903077f8de0fe076155..0f44b8785e0a09d805e975d67c1ab999c48b4859 100644 (file)
@@ -81,12 +81,17 @@ static void dump_packet(unsigned char *data, size_t len)
 
 /*
   called when an incoming connection is readable
+  This function MUST be safe for reentry via the queue callback!
 */
 static void queue_io_read(struct ctdb_queue *queue)
 {
        int num_ready = 0;
-       ssize_t nread, totread, partlen;
-       uint8_t *data, *data_base;
+       uint32_t sz_bytes_req;
+       uint32_t pkt_size;
+       uint32_t pkt_bytes_remaining;
+       uint32_t to_read;
+       ssize_t nread;
+       uint8_t *data;
 
        if (ioctl(queue->fd, FIONREAD, &num_ready) != 0) {
                return;
@@ -96,93 +101,77 @@ static void queue_io_read(struct ctdb_queue *queue)
                goto failed;
        }
 
-
-       queue->partial.data = talloc_realloc_size(queue, queue->partial.data, 
-                                                 num_ready + queue->partial.length);
-
        if (queue->partial.data == NULL) {
-               DEBUG(DEBUG_ERR,("%s: read error alloc failed for %u\n",
-                       queue->name, num_ready + queue->partial.length));
-               goto failed;
-       }
-
-       nread = read(queue->fd, queue->partial.data + queue->partial.length, num_ready);
-       if (nread <= 0) {
-               DEBUG(DEBUG_ERR,("%s: read error nread=%d\n",
-                                queue->name, (int)nread));
-               goto failed;
+               /* starting fresh, allocate buf for size bytes */
+               sz_bytes_req = sizeof(pkt_size);
+               queue->partial.data = talloc_size(queue, sz_bytes_req);
+               if (queue->partial.data == NULL) {
+                       DEBUG(DEBUG_ERR,("read error alloc failed for %u\n",
+                                        sz_bytes_req));
+                       goto failed;
+               }
+       } else if (queue->partial.length < sizeof(pkt_size)) {
+               /* yet to find out the packet length */
+               sz_bytes_req = sizeof(pkt_size) - queue->partial.length;
+       } else {
+               /* partial packet, length known, full buf allocated */
+               sz_bytes_req = 0;
        }
-       totread = nread;
-       partlen = queue->partial.length;
-
        data = queue->partial.data;
-       nread += queue->partial.length;
-
-       queue->partial.data = NULL;
-       queue->partial.length = 0;
-
-       if (nread >= 4 && *(uint32_t *)data == nread) {
-               /* it is the responsibility of the incoming packet
-                function to free 'data' */
-               queue->callback(data, nread, queue->private_data);
-               return;
-       }
 
-       data_base = data;
-
-       while (nread >= 4 && *(uint32_t *)data <= nread) {
-               /* we have at least one packet */
-               uint8_t *d2;
-               uint32_t len;
-               bool destroyed = false;
-
-               len = *(uint32_t *)data;
-               if (len == 0) {
-                       /* bad packet! treat as EOF */
-                       DEBUG(DEBUG_CRIT,("%s: Invalid packet of length 0 (nread = %zu, totread = %zu, partlen = %zu)\n",
-                                         queue->name, nread, totread, partlen));
-                       dump_packet(data_base, totread + partlen);
-                       goto failed;
-               }
-               d2 = talloc_memdup(queue, data, len);
-               if (d2 == NULL) {
-                       DEBUG(DEBUG_ERR,("%s: read error memdup failed for %u\n",
-                                        queue->name, len));
-                       /* sigh */
+       if (sz_bytes_req > 0) {
+               to_read = MIN(sz_bytes_req, num_ready);
+               nread = read(queue->fd, data + queue->partial.length,
+                            to_read);
+               if (nread <= 0) {
+                       DEBUG(DEBUG_ERR,("read error nread=%d\n", (int)nread));
                        goto failed;
                }
+               queue->partial.length += nread;
 
-               queue->destroyed = &destroyed;
-               queue->callback(d2, len, queue->private_data);
-               /* If callback freed us, don't do anything else. */
-               if (destroyed) {
+               if (nread < sz_bytes_req) {
+                       /* not enough to know the length */
+                       DEBUG(DEBUG_DEBUG,("Partial packet length read\n"));
                        return;
                }
-               queue->destroyed = NULL;
+               /* size now known, allocate buffer for the full packet */
+               queue->partial.data = talloc_realloc_size(queue, data,
+                                                         *(uint32_t *)data);
+               if (queue->partial.data == NULL) {
+                       DEBUG(DEBUG_ERR,("read error alloc failed for %u\n",
+                                        *(uint32_t *)data));
+                       goto failed;
+               }
+               data = queue->partial.data;
+               num_ready -= nread;
+       }
 
-               data += len;
-               nread -= len;           
+       pkt_size = *(uint32_t *)data;
+       if (pkt_size == 0) {
+               DEBUG(DEBUG_CRIT,("Invalid packet of length 0\n"));
+               goto failed;
        }
 
-       if (nread > 0) {
-               /* we have only part of a packet */
-               if (data_base == data) {
-                       queue->partial.data = data;
-                       queue->partial.length = nread;
-               } else {
-                       queue->partial.data = talloc_memdup(queue, data, nread);
-                       if (queue->partial.data == NULL) {
-                               DEBUG(DEBUG_ERR,("%s: read error memdup partial failed for %u\n",
-                                                queue->name, (unsigned)nread));
-                               goto failed;
-                       }
-                       queue->partial.length = nread;
-                       talloc_free(data_base);
-               }
+       pkt_bytes_remaining = pkt_size - queue->partial.length;
+       to_read = MIN(pkt_bytes_remaining, num_ready);
+       nread = read(queue->fd, data + queue->partial.length,
+                    to_read);
+       if (nread <= 0) {
+               DEBUG(DEBUG_ERR,("read error nread=%d\n",
+                                (int)nread));
+               goto failed;
+       }
+       queue->partial.length += nread;
+
+       if (queue->partial.length < pkt_size) {
+               DEBUG(DEBUG_DEBUG,("Partial packet data read\n"));
                return;
        }
 
-       talloc_free(data_base);
+       queue->partial.data = NULL;
+       queue->partial.length = 0;
+       /* it is the responsibility of the callback to free 'data' */
+       queue->callback(data, pkt_size, queue->private_data);
        return;
 
 failed: