swrap: export a public socket_wrapper_indicate_no_inet_fd() helper function
[socket_wrapper.git] / src / socket_wrapper.c
index ece34934f9c90de80baa4c9918920049d8ea22bb..714cd259fa46402066e2c9f1bfff316a4a647af8 100644 (file)
@@ -2,8 +2,8 @@
  * BSD 3-Clause License
  *
  * Copyright (c) 2005-2008, Jelmer Vernooij <jelmer@samba.org>
- * Copyright (c) 2006-2018, Stefan Metzmacher <metze@samba.org>
- * Copyright (c) 2013-2018, Andreas Schneider <asn@samba.org>
+ * Copyright (c) 2006-2021, Stefan Metzmacher <metze@samba.org>
+ * Copyright (c) 2013-2021, Andreas Schneider <asn@samba.org>
  * Copyright (c) 2014-2017, Michael Adam <obnox@samba.org>
  * Copyright (c) 2016-2018, Anoop C S <anoopcs@redhat.com>
  * All rights reserved.
@@ -86,6 +86,8 @@
 #endif
 #include <pthread.h>
 
+#include "socket_wrapper.h"
+
 enum swrap_dbglvl_e {
        SWRAP_LOG_ERROR = 0,
        SWRAP_LOG_WARN,
@@ -370,7 +372,7 @@ static pthread_mutex_t autobind_start_mutex = PTHREAD_MUTEX_INITIALIZER;
 /* Mutex to guard the initialization of array of socket_info structures */
 static pthread_mutex_t sockets_mutex = PTHREAD_MUTEX_INITIALIZER;
 
-/* Mutex to guard the socket reset in swrap_close() and swrap_remove_stale() */
+/* Mutex to guard the socket reset in swrap_remove_wrapper() */
 static pthread_mutex_t socket_reset_mutex = PTHREAD_MUTEX_INITIALIZER;
 
 /* Mutex to synchronize access to first free index in socket_info array */
@@ -392,8 +394,6 @@ static pthread_mutex_t mtu_update_mutex = PTHREAD_MUTEX_INITIALIZER;
 
 /* Function prototypes */
 
-bool socket_wrapper_enabled(void);
-
 #if ! defined(HAVE_CONSTRUCTOR_ATTRIBUTE) && defined(HAVE_PRAGMA_INIT)
 /* xlC and other oldschool compilers support (only) this */
 #pragma init (swrap_constructor)
@@ -2027,6 +2027,13 @@ static int convert_in_un_remote(struct socket_info *si, const struct sockaddr *i
                        type = u_type;
                        iface = (addr & 0x000000FF);
                } else {
+                       char str[256] = {0,};
+                       inet_ntop(inaddr->sa_family,
+                                 &in->sin_addr,
+                                 str, sizeof(str));
+                       SWRAP_LOG(SWRAP_LOG_WARN,
+                                 "str[%s] prt[%u]",
+                                 str, (unsigned)prt);
                        errno = ENETUNREACH;
                        return -1;
                }
@@ -2062,6 +2069,13 @@ static int convert_in_un_remote(struct socket_info *si, const struct sockaddr *i
                if (IN6_ARE_ADDR_EQUAL(&cmp1, &cmp2)) {
                        iface = in->sin6_addr.s6_addr[15];
                } else {
+                       char str[256] = {0,};
+                       inet_ntop(inaddr->sa_family,
+                                 &in->sin6_addr,
+                                 str, sizeof(str));
+                       SWRAP_LOG(SWRAP_LOG_WARN,
+                                 "str[%s] prt[%u]",
+                                 str, (unsigned)prt);
                        errno = ENETUNREACH;
                        return -1;
                }
@@ -2390,46 +2404,7 @@ static bool check_addr_port_in_use(const struct sockaddr *sa, socklen_t len)
 }
 #endif
 
-static void swrap_remove_stale(int fd)
-{
-       struct socket_info *si;
-       int si_index;
-
-       SWRAP_LOG(SWRAP_LOG_TRACE, "remove stale wrapper for %d", fd);
-
-       swrap_mutex_lock(&socket_reset_mutex);
-
-       si_index = find_socket_info_index(fd);
-       if (si_index == -1) {
-               swrap_mutex_unlock(&socket_reset_mutex);
-               return;
-       }
-
-       reset_socket_info_index(fd);
-
-       si = swrap_get_socket_info(si_index);
-
-       swrap_mutex_lock(&first_free_mutex);
-       SWRAP_LOCK_SI(si);
-
-       swrap_dec_refcount(si);
-
-       if (swrap_get_refcount(si) > 0) {
-               goto out;
-       }
-
-       if (si->un_addr.sun_path[0] != '\0') {
-               unlink(si->un_addr.sun_path);
-       }
-
-       swrap_set_next_free(si, first_free);
-       first_free = si_index;
-
-out:
-       SWRAP_UNLOCK_SI(si);
-       swrap_mutex_unlock(&first_free_mutex);
-       swrap_mutex_unlock(&socket_reset_mutex);
-}
+static void swrap_remove_stale(int fd);
 
 static int sockaddr_convert_to_un(struct socket_info *si,
                                  const struct sockaddr *in_addr,
@@ -2990,7 +2965,7 @@ static int swrap_pcap_get_fd(const char *fname)
                file_hdr.frame_max_len  = SWRAP_FRAME_LENGTH_MAX;
                file_hdr.link_type      = 0x0065; /* 101 RAW IP */
 
-               if (write(fd, &file_hdr, sizeof(file_hdr)) != sizeof(file_hdr)) {
+               if (libc_write(fd, &file_hdr, sizeof(file_hdr)) != sizeof(file_hdr)) {
                        libc_close(fd);
                        fd = -1;
                }
@@ -3325,7 +3300,7 @@ static void swrap_pcap_dump_packet(struct socket_info *si,
 
        fd = swrap_pcap_get_fd(file_name);
        if (fd != -1) {
-               if (write(fd, packet, packet_len) != (ssize_t)packet_len) {
+               if (libc_write(fd, packet, packet_len) != (ssize_t)packet_len) {
                        free(packet);
                        goto done;
                }
@@ -5532,7 +5507,7 @@ static int swrap_sendmsg_unix_scm_rights(const struct cmsghdr *cmsg,
                return -1;
        }
 
-       sret = write(pipefd[1], &info, sizeof(info));
+       sret = libc_write(pipefd[1], &info, sizeof(info));
        if (sret != sizeof(info)) {
                int saved_errno = errno;
                if (sret != -1) {
@@ -5960,7 +5935,8 @@ static ssize_t swrap_sendmsg_after_unix(struct msghdr *msg_tmp,
 }
 
 static int swrap_recvmsg_before_unix(struct msghdr *msg_in,
-                                    struct msghdr *msg_tmp)
+                                    struct msghdr *msg_tmp,
+                                    uint8_t **tmp_control)
 {
 #ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL
        const size_t cm_extra_space = CMSG_SPACE(sizeof(int));
@@ -5968,6 +5944,7 @@ static int swrap_recvmsg_before_unix(struct msghdr *msg_in,
        size_t cm_data_space = 0;
 
        *msg_tmp = *msg_in;
+       *tmp_control = NULL;
 
        SWRAP_LOG(SWRAP_LOG_TRACE,
                  "msg_in->msg_controllen=%zu",
@@ -5991,10 +5968,10 @@ static int swrap_recvmsg_before_unix(struct msghdr *msg_in,
        if (cm_data == NULL) {
                return -1;
        }
-       memcpy(cm_data, msg_in->msg_control, msg_in->msg_controllen);
 
        msg_tmp->msg_controllen = cm_data_space;
        msg_tmp->msg_control = cm_data;
+       *tmp_control = cm_data;
 
        SWRAP_LOG(SWRAP_LOG_TRACE,
                  "msg_tmp->msg_controllen=%zu",
@@ -6002,11 +5979,13 @@ static int swrap_recvmsg_before_unix(struct msghdr *msg_in,
        return 0;
 #else /* HAVE_STRUCT_MSGHDR_MSG_CONTROL */
        *msg_tmp = *msg_in;
+       *tmp_control = NULL;
        return 0;
 #endif /* ! HAVE_STRUCT_MSGHDR_MSG_CONTROL */
 }
 
 static ssize_t swrap_recvmsg_after_unix(struct msghdr *msg_tmp,
+                                       uint8_t **tmp_control,
                                        struct msghdr *msg_out,
                                        ssize_t ret)
 {
@@ -6016,13 +5995,26 @@ static ssize_t swrap_recvmsg_after_unix(struct msghdr *msg_tmp,
        size_t cm_data_space = 0;
        int rc = -1;
 
+       if (ret < 0) {
+               int saved_errno = errno;
+               SWRAP_LOG(SWRAP_LOG_TRACE, "ret=%zd - %d - %s", ret,
+                         saved_errno, strerror(saved_errno));
+               SAFE_FREE(*tmp_control);
+               /* msg_out should not be touched on error */
+               errno = saved_errno;
+               return ret;
+       }
+
        SWRAP_LOG(SWRAP_LOG_TRACE,
                  "msg_tmp->msg_controllen=%zu",
                  (size_t)msg_tmp->msg_controllen);
 
        /* Nothing to do */
        if (msg_tmp->msg_controllen == 0 || msg_tmp->msg_control == NULL) {
+               int saved_errno = errno;
                *msg_out = *msg_tmp;
+               SAFE_FREE(*tmp_control);
+               errno = saved_errno;
                return ret;
        }
 
@@ -6045,16 +6037,17 @@ static ssize_t swrap_recvmsg_after_unix(struct msghdr *msg_tmp,
                if (rc < 0) {
                        int saved_errno = errno;
                        SAFE_FREE(cm_data);
+                       SAFE_FREE(*tmp_control);
                        errno = saved_errno;
                        return rc;
                }
        }
 
        /*
-        * msg_tmp->msg_control was created by swrap_recvmsg_before_unix()
-        * and msg_out->msg_control is still the buffer of the caller.
+        * msg_tmp->msg_control (*tmp_control) was created by
+        * swrap_recvmsg_before_unix() and msg_out->msg_control
+        * is still the buffer of the caller.
         */
-       SAFE_FREE(msg_tmp->msg_control);
        msg_tmp->msg_control = msg_out->msg_control;
        msg_tmp->msg_controllen = msg_out->msg_controllen;
        *msg_out = *msg_tmp;
@@ -6063,13 +6056,17 @@ static ssize_t swrap_recvmsg_after_unix(struct msghdr *msg_tmp,
        memcpy(msg_out->msg_control, cm_data, cm_data_space);
        msg_out->msg_controllen = cm_data_space;
        SAFE_FREE(cm_data);
+       SAFE_FREE(*tmp_control);
 
        SWRAP_LOG(SWRAP_LOG_TRACE,
                  "msg_out->msg_controllen=%zu",
                  (size_t)msg_out->msg_controllen);
        return ret;
 #else /* HAVE_STRUCT_MSGHDR_MSG_CONTROL */
+       int saved_errno = errno;
        *msg_out = *msg_tmp;
+       SAFE_FREE(*tmp_control);
+       errno = saved_errno;
        return ret;
 #endif /* ! HAVE_STRUCT_MSGHDR_MSG_CONTROL */
 }
@@ -6986,12 +6983,13 @@ static ssize_t swrap_recvmsg(int s, struct msghdr *omsg, int flags)
 
        si = find_socket_info(s);
        if (si == NULL) {
-               rc = swrap_recvmsg_before_unix(omsg, &msg);
+               uint8_t *tmp_control = NULL;
+               rc = swrap_recvmsg_before_unix(omsg, &msg, &tmp_control);
                if (rc < 0) {
                        return rc;
                }
                ret = libc_recvmsg(s, &msg, flags);
-               return swrap_recvmsg_after_unix(&msg, omsg, ret);
+               return swrap_recvmsg_after_unix(&msg, &tmp_control, omsg, ret);
        }
 
        tmp.iov_base = NULL;
@@ -7380,10 +7378,13 @@ ssize_t writev(int s, const struct iovec *vector, int count)
  * CLOSE
  ***************************/
 
-static int swrap_close(int fd)
+static int swrap_remove_wrapper(const char *__func_name,
+                               int (*__close_fd_fn)(int fd),
+                               int fd)
 {
        struct socket_info *si = NULL;
        int si_index;
+       int ret_errno = errno;
        int ret;
 
        swrap_mutex_lock(&socket_reset_mutex);
@@ -7391,10 +7392,10 @@ static int swrap_close(int fd)
        si_index = find_socket_info_index(fd);
        if (si_index == -1) {
                swrap_mutex_unlock(&socket_reset_mutex);
-               return libc_close(fd);
+               return __close_fd_fn(fd);
        }
 
-       SWRAP_LOG(SWRAP_LOG_TRACE, "Close wrapper for fd=%d", fd);
+       swrap_log(SWRAP_LOG_TRACE, __func_name, "Remove wrapper for fd=%d", fd);
        reset_socket_info_index(fd);
 
        si = swrap_get_socket_info(si_index);
@@ -7402,7 +7403,10 @@ static int swrap_close(int fd)
        swrap_mutex_lock(&first_free_mutex);
        SWRAP_LOCK_SI(si);
 
-       ret = libc_close(fd);
+       ret = __close_fd_fn(fd);
+       if (ret == -1) {
+               ret_errno = errno;
+       }
 
        swrap_dec_refcount(si);
 
@@ -7437,9 +7441,48 @@ out:
        swrap_mutex_unlock(&first_free_mutex);
        swrap_mutex_unlock(&socket_reset_mutex);
 
+       errno = ret_errno;
        return ret;
 }
 
+static int swrap_noop_close(int fd)
+{
+       (void)fd; /* unused */
+       return 0;
+}
+
+static void swrap_remove_stale(int fd)
+{
+       swrap_remove_wrapper(__func__, swrap_noop_close, fd);
+}
+
+/*
+ * This allows socket_wrapper aware applications to
+ * indicate that the given fd does not belong to
+ * an inet socket.
+ *
+ * We already overload a lot of unrelated functions
+ * like eventfd(), timerfd_create(), ... in order to
+ * call swrap_remove_stale() on the returned fd, but
+ * we'll never be able to handle all possible syscalls.
+ *
+ * socket_wrapper_indicate_no_inet_fd() gives them a way
+ * to do the same.
+ *
+ * We don't export swrap_remove_stale() in order to
+ * make it easier to analyze SOCKET_WRAPPER_DEBUGLEVEL=3
+ * log files.
+ */
+void socket_wrapper_indicate_no_inet_fd(int fd)
+{
+       swrap_remove_wrapper(__func__, swrap_noop_close, fd);
+}
+
+static int swrap_close(int fd)
+{
+       return swrap_remove_wrapper(__func__, libc_close, fd);
+}
+
 int close(int fd)
 {
        return swrap_close(fd);