swrap: fix fd-passing without 4 padding bytes
[socket_wrapper.git] / src / socket_wrapper.c
index 14d0dda00d25cd94643630eaf178bcf14c13f42b..ece34934f9c90de80baa4c9918920049d8ea22bb 100644 (file)
@@ -183,7 +183,6 @@ enum swrap_dbglvl_e {
 
 /* Add new global locks here please */
 # define SWRAP_REINIT_ALL do { \
-       size_t __i; \
        int ret; \
        ret = socket_wrapper_init_mutex(&sockets_mutex); \
        if (ret != 0) exit(-1); \
@@ -191,10 +190,8 @@ enum swrap_dbglvl_e {
        if (ret != 0) exit(-1); \
        ret = socket_wrapper_init_mutex(&first_free_mutex); \
        if (ret != 0) exit(-1); \
-       for (__i = 0; (sockets != NULL) && __i < socket_info_max; __i++) { \
-               ret = socket_wrapper_init_mutex(&sockets[__i].meta.mutex); \
-               if (ret != 0) exit(-1); \
-       } \
+       ret = socket_wrapper_init_mutex(&sockets_si_global); \
+       if (ret != 0) exit(-1); \
        ret = socket_wrapper_init_mutex(&autobind_start_mutex); \
        if (ret != 0) exit(-1); \
        ret = socket_wrapper_init_mutex(&pcap_dump_mutex); \
@@ -204,27 +201,20 @@ enum swrap_dbglvl_e {
 } while(0)
 
 # define SWRAP_LOCK_ALL do { \
-       size_t __i; \
        swrap_mutex_lock(&sockets_mutex); \
        swrap_mutex_lock(&socket_reset_mutex); \
        swrap_mutex_lock(&first_free_mutex); \
-       for (__i = 0; (sockets != NULL) && __i < socket_info_max; __i++) { \
-               swrap_mutex_lock(&sockets[__i].meta.mutex); \
-       } \
+       swrap_mutex_lock(&sockets_si_global); \
        swrap_mutex_lock(&autobind_start_mutex); \
        swrap_mutex_lock(&pcap_dump_mutex); \
        swrap_mutex_lock(&mtu_update_mutex); \
 } while(0)
 
 # define SWRAP_UNLOCK_ALL do { \
-       size_t __s; \
        swrap_mutex_unlock(&mtu_update_mutex); \
        swrap_mutex_unlock(&pcap_dump_mutex); \
        swrap_mutex_unlock(&autobind_start_mutex); \
-       for (__s = 0; (sockets != NULL) && __s < socket_info_max; __s++) { \
-               size_t __i = (socket_info_max - 1) - __s; \
-               swrap_mutex_unlock(&sockets[__i].meta.mutex); \
-       } \
+       swrap_mutex_unlock(&sockets_si_global); \
        swrap_mutex_unlock(&first_free_mutex); \
        swrap_mutex_unlock(&socket_reset_mutex); \
        swrap_mutex_unlock(&sockets_mutex); \
@@ -235,12 +225,20 @@ enum swrap_dbglvl_e {
 
 #define SWRAP_LOCK_SI(si) do { \
        struct socket_info_container *sic = SOCKET_INFO_CONTAINER(si); \
-       swrap_mutex_lock(&sic->meta.mutex); \
+       if (sic != NULL) { \
+               swrap_mutex_lock(&sockets_si_global); \
+       } else { \
+               abort(); \
+       } \
 } while(0)
 
 #define SWRAP_UNLOCK_SI(si) do { \
        struct socket_info_container *sic = SOCKET_INFO_CONTAINER(si); \
-       swrap_mutex_unlock(&sic->meta.mutex); \
+       if (sic != NULL) { \
+               swrap_mutex_unlock(&sockets_si_global); \
+       } else { \
+               abort(); \
+       } \
 } while(0)
 
 #if defined(HAVE_GETTIMEOFDAY_TZ) || defined(HAVE_GETTIMEOFDAY_TZ_VOID)
@@ -337,7 +335,13 @@ struct socket_info_meta
 {
        unsigned int refcount;
        int next_free;
-       pthread_mutex_t mutex;
+       /*
+        * As long as we don't use shared memory
+        * for the sockets array, we use
+        * sockets_si_global as a single mutex.
+        *
+        * pthread_mutex_t mutex;
+        */
 };
 
 struct socket_info_container
@@ -372,6 +376,14 @@ static pthread_mutex_t socket_reset_mutex = PTHREAD_MUTEX_INITIALIZER;
 /* Mutex to synchronize access to first free index in socket_info array */
 static pthread_mutex_t first_free_mutex = PTHREAD_MUTEX_INITIALIZER;
 
+/*
+ * Mutex to synchronize access to socket_info structures
+ * We use a single global mutex in order to avoid leaking
+ * ~ 38M copy on write memory per fork.
+ * max_sockets=65535 * sizeof(struct socket_info_container)=592 = 38796720
+ */
+static pthread_mutex_t sockets_si_global = PTHREAD_MUTEX_INITIALIZER;
+
 /* Mutex to synchronize access to packet capture dump file */
 static pthread_mutex_t pcap_dump_mutex = PTHREAD_MUTEX_INITIALIZER;
 
@@ -755,6 +767,7 @@ static void _swrap_mutex_lock(pthread_mutex_t *mutex, const char *name, const ch
        if (ret != 0) {
                SWRAP_LOG(SWRAP_LOG_ERROR, "PID(%d):PPID(%d): %s(%u): Couldn't lock pthread mutex(%s) - %s",
                          getpid(), getppid(), caller, line, name, strerror(ret));
+               abort();
        }
 }
 
@@ -767,6 +780,7 @@ static void _swrap_mutex_unlock(pthread_mutex_t *mutex, const char *name, const
        if (ret != 0) {
                SWRAP_LOG(SWRAP_LOG_ERROR, "PID(%d):PPID(%d): %s(%u): Couldn't unlock pthread mutex(%s) - %s",
                          getpid(), getppid(), caller, line, name, strerror(ret));
+               abort();
        }
 }
 
@@ -1705,27 +1719,18 @@ static void socket_wrapper_init_sockets(void)
        }
 
        swrap_mutex_lock(&first_free_mutex);
+       swrap_mutex_lock(&sockets_si_global);
 
        first_free = 0;
 
        for (i = 0; i < max_sockets; i++) {
                swrap_set_next_free(&sockets[i].info, i+1);
-               sockets[i].meta.mutex = (pthread_mutex_t)PTHREAD_MUTEX_INITIALIZER;
-       }
-
-       for (i = 0; i < max_sockets; i++) {
-               ret = socket_wrapper_init_mutex(&sockets[i].meta.mutex);
-               if (ret != 0) {
-                       SWRAP_LOG(SWRAP_LOG_ERROR,
-                                 "Failed to initialize pthread mutex i=%zu", i);
-                       goto done;
-               }
        }
 
        /* mark the end of the free list */
        swrap_set_next_free(&sockets[max_sockets-1].info, -1);
 
-done:
+       swrap_mutex_unlock(&sockets_si_global);
        swrap_mutex_unlock(&first_free_mutex);
        swrap_mutex_unlock(&sockets_mutex);
        if (ret != 0) {
@@ -1885,31 +1890,40 @@ static int convert_un_in(const struct sockaddr_un *un, struct sockaddr *in, sock
        if (p) p++; else p = un->sun_path;
 
        if (sscanf(p, SOCKET_FORMAT, &type, &iface, &prt) != 3) {
+               SWRAP_LOG(SWRAP_LOG_ERROR, "sun_path[%s] p[%s]",
+                         un->sun_path, p);
                errno = EINVAL;
                return -1;
        }
 
-       SWRAP_LOG(SWRAP_LOG_TRACE, "type %c iface %u port %u",
-                       type, iface, prt);
-
        if (iface == 0 || iface > MAX_WRAPPED_INTERFACES) {
+               SWRAP_LOG(SWRAP_LOG_ERROR, "type %c iface %u port %u",
+                         type, iface, prt);
                errno = EINVAL;
                return -1;
        }
 
        if (prt > 0xFFFF) {
+               SWRAP_LOG(SWRAP_LOG_ERROR, "type %c iface %u port %u",
+                         type, iface, prt);
                errno = EINVAL;
                return -1;
        }
 
+       SWRAP_LOG(SWRAP_LOG_TRACE, "type %c iface %u port %u",
+                 type, iface, prt);
+
        switch(type) {
        case SOCKET_TYPE_CHAR_TCP:
        case SOCKET_TYPE_CHAR_UDP: {
                struct sockaddr_in *in2 = (struct sockaddr_in *)(void *)in;
 
                if ((*len) < sizeof(*in2)) {
-                   errno = EINVAL;
-                   return -1;
+                       SWRAP_LOG(SWRAP_LOG_ERROR,
+                                 "V4: *len(%zu) < sizeof(*in2)=%zu",
+                                 (size_t)*len, sizeof(*in2));
+                       errno = EINVAL;
+                       return -1;
                }
 
                memset(in2, 0, sizeof(*in2));
@@ -1926,6 +1940,10 @@ static int convert_un_in(const struct sockaddr_un *un, struct sockaddr *in, sock
                struct sockaddr_in6 *in2 = (struct sockaddr_in6 *)(void *)in;
 
                if ((*len) < sizeof(*in2)) {
+                       SWRAP_LOG(SWRAP_LOG_ERROR,
+                                 "V6: *len(%zu) < sizeof(*in2)=%zu",
+                                 (size_t)*len, sizeof(*in2));
+                       SWRAP_LOG(SWRAP_LOG_ERROR, "LINE:%d", __LINE__);
                        errno = EINVAL;
                        return -1;
                }
@@ -1941,6 +1959,8 @@ static int convert_un_in(const struct sockaddr_un *un, struct sockaddr *in, sock
        }
 #endif
        default:
+               SWRAP_LOG(SWRAP_LOG_ERROR, "type %c iface %u port %u",
+                         type, iface, prt);
                errno = EINVAL;
                return -1;
        }
@@ -3627,10 +3647,12 @@ static int swrap_accept(int s,
        ret = libc_accept(s, &un_addr.sa.s, &un_addr.sa_socklen);
 #endif
        if (ret == -1) {
-               if (errno == ENOTSOCK) {
+               int saved_errno = errno;
+               if (saved_errno == ENOTSOCK) {
                        /* Remove stale fds */
                        swrap_remove_stale(s);
                }
+               errno = saved_errno;
                return ret;
        }
 
@@ -3639,6 +3661,50 @@ static int swrap_accept(int s,
        /* Check if we have a stale fd and remove it */
        swrap_remove_stale(fd);
 
+       if (un_addr.sa.un.sun_path[0] == '\0') {
+               /*
+                * FreeBSD seems to have a problem where
+                * accept4() on the unix socket doesn't return
+                * ECONNABORTED for already disconnected connections.
+                *
+                * Let's try libc_getpeername() to get the peer address
+                * as a fallback, but it'll likely return ENOTCONN,
+                * which we have to map to ECONNABORTED.
+                */
+               un_addr.sa_socklen = sizeof(struct sockaddr_un),
+               ret = libc_getpeername(fd, &un_addr.sa.s, &un_addr.sa_socklen);
+               if (ret == -1) {
+                       int saved_errno = errno;
+                       libc_close(fd);
+                       if (saved_errno == ENOTCONN) {
+                               /*
+                                * If the connection is already disconnected
+                                * we should return ECONNABORTED.
+                                */
+                               saved_errno = ECONNABORTED;
+                       }
+                       errno = saved_errno;
+                       return ret;
+               }
+       }
+
+       ret = libc_getsockname(fd,
+                              &un_my_addr.sa.s,
+                              &un_my_addr.sa_socklen);
+       if (ret == -1) {
+               int saved_errno = errno;
+               libc_close(fd);
+               if (saved_errno == ENOTCONN) {
+                       /*
+                        * If the connection is already disconnected
+                        * we should return ECONNABORTED.
+                        */
+                       saved_errno = ECONNABORTED;
+               }
+               errno = saved_errno;
+               return ret;
+       }
+
        SWRAP_LOCK_SI(parent_si);
 
        ret = sockaddr_convert_from_un(parent_si,
@@ -3648,8 +3714,10 @@ static int swrap_accept(int s,
                                       &in_addr.sa.s,
                                       &in_addr.sa_socklen);
        if (ret == -1) {
+               int saved_errno = errno;
                SWRAP_UNLOCK_SI(parent_si);
                libc_close(fd);
+               errno = saved_errno;
                return ret;
        }
 
@@ -3677,14 +3745,6 @@ static int swrap_accept(int s,
                *addrlen = in_addr.sa_socklen;
        }
 
-       ret = libc_getsockname(fd,
-                              &un_my_addr.sa.s,
-                              &un_my_addr.sa_socklen);
-       if (ret == -1) {
-               libc_close(fd);
-               return ret;
-       }
-
        ret = sockaddr_convert_from_un(child_si,
                                       &un_my_addr.sa.un,
                                       un_my_addr.sa_socklen,
@@ -3692,7 +3752,9 @@ static int swrap_accept(int s,
                                       &in_my_addr.sa.s,
                                       &in_my_addr.sa_socklen);
        if (ret == -1) {
+               int saved_errno = errno;
                libc_close(fd);
+               errno = saved_errno;
                return ret;
        }
 
@@ -5388,7 +5450,7 @@ static int swrap_sendmsg_unix_scm_rights(const struct cmsghdr *cmsg,
        *new_cmsg = *cmsg;
        __fds_out.p = CMSG_DATA(new_cmsg);
        fds_out = __fds_out.fds;
-       memcpy(fds_out, fds_in, size_fds_out);
+       memcpy(fds_out, fds_in, size_fds_in);
        new_cmsg->cmsg_len = cmsg->cmsg_len;
 
        for (i = 0; i < num_fds_in; i++) {
@@ -5900,8 +5962,48 @@ static ssize_t swrap_sendmsg_after_unix(struct msghdr *msg_tmp,
 static int swrap_recvmsg_before_unix(struct msghdr *msg_in,
                                     struct msghdr *msg_tmp)
 {
+#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL
+       const size_t cm_extra_space = CMSG_SPACE(sizeof(int));
+       uint8_t *cm_data = NULL;
+       size_t cm_data_space = 0;
+
+       *msg_tmp = *msg_in;
+
+       SWRAP_LOG(SWRAP_LOG_TRACE,
+                 "msg_in->msg_controllen=%zu",
+                 (size_t)msg_in->msg_controllen);
+
+       /* Nothing to do */
+       if (msg_in->msg_controllen == 0 || msg_in->msg_control == NULL) {
+               return 0;
+       }
+
+       /*
+        * We need to give the kernel a bit more space in order to
+        * recv the pipe fd (added by swrap_sendmsg_before_unix()).
+        * swrap_recvmsg_after_unix() will hide it again.
+        */
+       cm_data_space = msg_in->msg_controllen;
+       if (cm_data_space < (INT32_MAX - cm_extra_space)) {
+               cm_data_space += cm_extra_space;
+       }
+       cm_data = calloc(1, cm_data_space);
+       if (cm_data == NULL) {
+               return -1;
+       }
+       memcpy(cm_data, msg_in->msg_control, msg_in->msg_controllen);
+
+       msg_tmp->msg_controllen = cm_data_space;
+       msg_tmp->msg_control = cm_data;
+
+       SWRAP_LOG(SWRAP_LOG_TRACE,
+                 "msg_tmp->msg_controllen=%zu",
+                 (size_t)msg_tmp->msg_controllen);
+       return 0;
+#else /* HAVE_STRUCT_MSGHDR_MSG_CONTROL */
        *msg_tmp = *msg_in;
        return 0;
+#endif /* ! HAVE_STRUCT_MSGHDR_MSG_CONTROL */
 }
 
 static ssize_t swrap_recvmsg_after_unix(struct msghdr *msg_tmp,
@@ -5914,9 +6016,14 @@ static ssize_t swrap_recvmsg_after_unix(struct msghdr *msg_tmp,
        size_t cm_data_space = 0;
        int rc = -1;
 
+       SWRAP_LOG(SWRAP_LOG_TRACE,
+                 "msg_tmp->msg_controllen=%zu",
+                 (size_t)msg_tmp->msg_controllen);
+
        /* Nothing to do */
        if (msg_tmp->msg_controllen == 0 || msg_tmp->msg_control == NULL) {
-               goto done;
+               *msg_out = *msg_tmp;
+               return ret;
        }
 
        for (cmsg = CMSG_FIRSTHDR(msg_tmp);
@@ -5944,15 +6051,27 @@ static ssize_t swrap_recvmsg_after_unix(struct msghdr *msg_tmp,
        }
 
        /*
-        * msg_tmp->msg_control is still the buffer of the caller.
+        * msg_tmp->msg_control was created by swrap_recvmsg_before_unix()
+        * and msg_out->msg_control is still the buffer of the caller.
         */
-       memcpy(msg_tmp->msg_control, cm_data, cm_data_space);
-       msg_tmp->msg_controllen = cm_data_space;
+       SAFE_FREE(msg_tmp->msg_control);
+       msg_tmp->msg_control = msg_out->msg_control;
+       msg_tmp->msg_controllen = msg_out->msg_controllen;
+       *msg_out = *msg_tmp;
+
+       cm_data_space = MIN(cm_data_space, msg_out->msg_controllen);
+       memcpy(msg_out->msg_control, cm_data, cm_data_space);
+       msg_out->msg_controllen = cm_data_space;
        SAFE_FREE(cm_data);
-done:
-#endif /* ! HAVE_STRUCT_MSGHDR_MSG_CONTROL */
+
+       SWRAP_LOG(SWRAP_LOG_TRACE,
+                 "msg_out->msg_controllen=%zu",
+                 (size_t)msg_out->msg_controllen);
+       return ret;
+#else /* HAVE_STRUCT_MSGHDR_MSG_CONTROL */
        *msg_out = *msg_tmp;
        return ret;
+#endif /* ! HAVE_STRUCT_MSGHDR_MSG_CONTROL */
 }
 
 static ssize_t swrap_sendmsg_before(int fd,