src/socket_wrapper.c: always go through swrap_bind_symbol_all() protected by pthread_...
authorStefan Metzmacher <metze@samba.org>
Thu, 28 Jan 2021 13:31:31 +0000 (14:31 +0100)
committerStefan Metzmacher <metze@samba.org>
Thu, 28 Jan 2021 14:27:14 +0000 (15:27 +0100)
Signed-off-by: Stefan Metzmacher <metze@samba.org>
Reviewed-by: Andreas Schneider <asn@samba.org>
src/socket_wrapper.c

index 5101e7a7becd99c6339f4e9bcb9cf2432475ec89..388b764018d2af105445dfe158902f9ed657e746 100644 (file)
@@ -179,11 +179,11 @@ enum swrap_dbglvl_e {
 #endif
 
 /* Add new global locks here please */
-# define SWRAP_LOCK_ALL \
-       swrap_mutex_lock(&libc_symbol_binding_mutex); \
+# define SWRAP_LOCK_ALL do { \
+} while(0)
 
-# define SWRAP_UNLOCK_ALL \
-       swrap_mutex_unlock(&libc_symbol_binding_mutex); \
+# define SWRAP_UNLOCK_ALL do { \
+} while(0)
 
 #define SOCKET_INFO_CONTAINER(si) \
        (struct socket_info_container *)(si)
@@ -309,9 +309,6 @@ static size_t socket_fds_max = SOCKET_WRAPPER_MAX_SOCKETS_LIMIT;
 /* Hash table to map fds to corresponding socket_info index */
 static int *socket_fds_idx;
 
-/* Mutex to synchronize access to global libc.symbols */
-static pthread_mutex_t libc_symbol_binding_mutex = PTHREAD_MUTEX_INITIALIZER;
-
 /* Mutex for syncronizing port selection during swrap_auto_bind() */
 static pthread_mutex_t autobind_start_mutex;
 
@@ -726,15 +723,10 @@ static void swrap_mutex_unlock(pthread_mutex_t *mutex)
  * This is an optimization to avoid locking each time we check if the symbol is
  * bound.
  */
-#define _swrap_bind_symbol_generic(lib, sym_name) \
-       if (swrap.libc.symbols._libc_##sym_name.obj == NULL) { \
-               swrap_mutex_lock(&libc_symbol_binding_mutex); \
-               if (swrap.libc.symbols._libc_##sym_name.obj == NULL) { \
-                       swrap.libc.symbols._libc_##sym_name.obj = \
-                               _swrap_bind_symbol(lib, #sym_name); \
-               } \
-               swrap_mutex_unlock(&libc_symbol_binding_mutex); \
-       }
+#define _swrap_bind_symbol_generic(lib, sym_name) do { \
+       swrap.libc.symbols._libc_##sym_name.obj = \
+               _swrap_bind_symbol(lib, #sym_name); \
+} while(0);
 
 #define swrap_bind_symbol_libc(sym_name) \
        _swrap_bind_symbol_generic(SWRAP_LIBC, sym_name)
@@ -742,6 +734,8 @@ static void swrap_mutex_unlock(pthread_mutex_t *mutex)
 #define swrap_bind_symbol_libsocket(sym_name) \
        _swrap_bind_symbol_generic(SWRAP_LIBSOCKET, sym_name)
 
+static void swrap_bind_symbol_all(void);
+
 /****************************************************************************
  *                               IMPORTANT
  ****************************************************************************
@@ -759,7 +753,7 @@ static int libc_accept4(int sockfd,
                        socklen_t *addrlen,
                        int flags)
 {
-       swrap_bind_symbol_libsocket(accept4);
+       swrap_bind_symbol_all();
 
        return swrap.libc.symbols._libc_accept4.f(sockfd, addr, addrlen, flags);
 }
@@ -768,7 +762,7 @@ static int libc_accept4(int sockfd,
 
 static int libc_accept(int sockfd, struct sockaddr *addr, socklen_t *addrlen)
 {
-       swrap_bind_symbol_libsocket(accept);
+       swrap_bind_symbol_all();
 
        return swrap.libc.symbols._libc_accept.f(sockfd, addr, addrlen);
 }
@@ -778,14 +772,14 @@ static int libc_bind(int sockfd,
                     const struct sockaddr *addr,
                     socklen_t addrlen)
 {
-       swrap_bind_symbol_libsocket(bind);
+       swrap_bind_symbol_all();
 
        return swrap.libc.symbols._libc_bind.f(sockfd, addr, addrlen);
 }
 
 static int libc_close(int fd)
 {
-       swrap_bind_symbol_libc(close);
+       swrap_bind_symbol_all();
 
        return swrap.libc.symbols._libc_close.f(fd);
 }
@@ -794,21 +788,21 @@ static int libc_connect(int sockfd,
                        const struct sockaddr *addr,
                        socklen_t addrlen)
 {
-       swrap_bind_symbol_libsocket(connect);
+       swrap_bind_symbol_all();
 
        return swrap.libc.symbols._libc_connect.f(sockfd, addr, addrlen);
 }
 
 static int libc_dup(int fd)
 {
-       swrap_bind_symbol_libc(dup);
+       swrap_bind_symbol_all();
 
        return swrap.libc.symbols._libc_dup.f(fd);
 }
 
 static int libc_dup2(int oldfd, int newfd)
 {
-       swrap_bind_symbol_libc(dup2);
+       swrap_bind_symbol_all();
 
        return swrap.libc.symbols._libc_dup2.f(oldfd, newfd);
 }
@@ -816,7 +810,7 @@ static int libc_dup2(int oldfd, int newfd)
 #ifdef HAVE_EVENTFD
 static int libc_eventfd(int count, int flags)
 {
-       swrap_bind_symbol_libc(eventfd);
+       swrap_bind_symbol_all();
 
        return swrap.libc.symbols._libc_eventfd.f(count, flags);
 }
@@ -828,7 +822,7 @@ static int libc_vfcntl(int fd, int cmd, va_list ap)
        void *arg;
        int rc;
 
-       swrap_bind_symbol_libc(fcntl);
+       swrap_bind_symbol_all();
 
        arg = va_arg(ap, void *);
 
@@ -841,7 +835,7 @@ static int libc_getpeername(int sockfd,
                            struct sockaddr *addr,
                            socklen_t *addrlen)
 {
-       swrap_bind_symbol_libsocket(getpeername);
+       swrap_bind_symbol_all();
 
        return swrap.libc.symbols._libc_getpeername.f(sockfd, addr, addrlen);
 }
@@ -850,7 +844,7 @@ static int libc_getsockname(int sockfd,
                            struct sockaddr *addr,
                            socklen_t *addrlen)
 {
-       swrap_bind_symbol_libsocket(getsockname);
+       swrap_bind_symbol_all();
 
        return swrap.libc.symbols._libc_getsockname.f(sockfd, addr, addrlen);
 }
@@ -861,7 +855,7 @@ static int libc_getsockopt(int sockfd,
                           void *optval,
                           socklen_t *optlen)
 {
-       swrap_bind_symbol_libsocket(getsockopt);
+       swrap_bind_symbol_all();
 
        return swrap.libc.symbols._libc_getsockopt.f(sockfd,
                                                     level,
@@ -876,7 +870,7 @@ static int libc_vioctl(int d, unsigned long int request, va_list ap)
        void *arg;
        int rc;
 
-       swrap_bind_symbol_libc(ioctl);
+       swrap_bind_symbol_all();
 
        arg = va_arg(ap, void *);
 
@@ -887,14 +881,14 @@ static int libc_vioctl(int d, unsigned long int request, va_list ap)
 
 static int libc_listen(int sockfd, int backlog)
 {
-       swrap_bind_symbol_libsocket(listen);
+       swrap_bind_symbol_all();
 
        return swrap.libc.symbols._libc_listen.f(sockfd, backlog);
 }
 
 static FILE *libc_fopen(const char *name, const char *mode)
 {
-       swrap_bind_symbol_libc(fopen);
+       swrap_bind_symbol_all();
 
        return swrap.libc.symbols._libc_fopen.f(name, mode);
 }
@@ -902,7 +896,7 @@ static FILE *libc_fopen(const char *name, const char *mode)
 #ifdef HAVE_FOPEN64
 static FILE *libc_fopen64(const char *name, const char *mode)
 {
-       swrap_bind_symbol_libc(fopen64);
+       swrap_bind_symbol_all();
 
        return swrap.libc.symbols._libc_fopen64.f(name, mode);
 }
@@ -913,7 +907,7 @@ static int libc_vopen(const char *pathname, int flags, va_list ap)
        int mode = 0;
        int fd;
 
-       swrap_bind_symbol_libc(open);
+       swrap_bind_symbol_all();
 
        if (flags & O_CREAT) {
                mode = va_arg(ap, int);
@@ -941,7 +935,7 @@ static int libc_vopen64(const char *pathname, int flags, va_list ap)
        int mode = 0;
        int fd;
 
-       swrap_bind_symbol_libc(open64);
+       swrap_bind_symbol_all();
 
        if (flags & O_CREAT) {
                mode = va_arg(ap, int);
@@ -957,7 +951,7 @@ static int libc_vopenat(int dirfd, const char *path, int flags, va_list ap)
        int mode = 0;
        int fd;
 
-       swrap_bind_symbol_libc(openat);
+       swrap_bind_symbol_all();
 
        if (flags & O_CREAT) {
                mode = va_arg(ap, int);
@@ -986,28 +980,28 @@ static int libc_openat(int dirfd, const char *path, int flags, ...)
 
 static int libc_pipe(int pipefd[2])
 {
-       swrap_bind_symbol_libsocket(pipe);
+       swrap_bind_symbol_all();
 
        return swrap.libc.symbols._libc_pipe.f(pipefd);
 }
 
 static int libc_read(int fd, void *buf, size_t count)
 {
-       swrap_bind_symbol_libc(read);
+       swrap_bind_symbol_all();
 
        return swrap.libc.symbols._libc_read.f(fd, buf, count);
 }
 
 static ssize_t libc_readv(int fd, const struct iovec *iov, int iovcnt)
 {
-       swrap_bind_symbol_libsocket(readv);
+       swrap_bind_symbol_all();
 
        return swrap.libc.symbols._libc_readv.f(fd, iov, iovcnt);
 }
 
 static int libc_recv(int sockfd, void *buf, size_t len, int flags)
 {
-       swrap_bind_symbol_libsocket(recv);
+       swrap_bind_symbol_all();
 
        return swrap.libc.symbols._libc_recv.f(sockfd, buf, len, flags);
 }
@@ -1019,7 +1013,7 @@ static int libc_recvfrom(int sockfd,
                         struct sockaddr *src_addr,
                         socklen_t *addrlen)
 {
-       swrap_bind_symbol_libsocket(recvfrom);
+       swrap_bind_symbol_all();
 
        return swrap.libc.symbols._libc_recvfrom.f(sockfd,
                                                   buf,
@@ -1031,21 +1025,21 @@ static int libc_recvfrom(int sockfd,
 
 static int libc_recvmsg(int sockfd, struct msghdr *msg, int flags)
 {
-       swrap_bind_symbol_libsocket(recvmsg);
+       swrap_bind_symbol_all();
 
        return swrap.libc.symbols._libc_recvmsg.f(sockfd, msg, flags);
 }
 
 static int libc_send(int sockfd, const void *buf, size_t len, int flags)
 {
-       swrap_bind_symbol_libsocket(send);
+       swrap_bind_symbol_all();
 
        return swrap.libc.symbols._libc_send.f(sockfd, buf, len, flags);
 }
 
 static int libc_sendmsg(int sockfd, const struct msghdr *msg, int flags)
 {
-       swrap_bind_symbol_libsocket(sendmsg);
+       swrap_bind_symbol_all();
 
        return swrap.libc.symbols._libc_sendmsg.f(sockfd, msg, flags);
 }
@@ -1057,7 +1051,7 @@ static int libc_sendto(int sockfd,
                       const  struct sockaddr *dst_addr,
                       socklen_t addrlen)
 {
-       swrap_bind_symbol_libsocket(sendto);
+       swrap_bind_symbol_all();
 
        return swrap.libc.symbols._libc_sendto.f(sockfd,
                                                 buf,
@@ -1073,7 +1067,7 @@ static int libc_setsockopt(int sockfd,
                           const void *optval,
                           socklen_t optlen)
 {
-       swrap_bind_symbol_libsocket(setsockopt);
+       swrap_bind_symbol_all();
 
        return swrap.libc.symbols._libc_setsockopt.f(sockfd,
                                                     level,
@@ -1085,7 +1079,7 @@ static int libc_setsockopt(int sockfd,
 #ifdef HAVE_SIGNALFD
 static int libc_signalfd(int fd, const sigset_t *mask, int flags)
 {
-       swrap_bind_symbol_libsocket(signalfd);
+       swrap_bind_symbol_all();
 
        return swrap.libc.symbols._libc_signalfd.f(fd, mask, flags);
 }
@@ -1093,14 +1087,14 @@ static int libc_signalfd(int fd, const sigset_t *mask, int flags)
 
 static int libc_socket(int domain, int type, int protocol)
 {
-       swrap_bind_symbol_libsocket(socket);
+       swrap_bind_symbol_all();
 
        return swrap.libc.symbols._libc_socket.f(domain, type, protocol);
 }
 
 static int libc_socketpair(int domain, int type, int protocol, int sv[2])
 {
-       swrap_bind_symbol_libsocket(socketpair);
+       swrap_bind_symbol_all();
 
        return swrap.libc.symbols._libc_socketpair.f(domain, type, protocol, sv);
 }
@@ -1108,7 +1102,7 @@ static int libc_socketpair(int domain, int type, int protocol, int sv[2])
 #ifdef HAVE_TIMERFD_CREATE
 static int libc_timerfd_create(int clockid, int flags)
 {
-       swrap_bind_symbol_libc(timerfd_create);
+       swrap_bind_symbol_all();
 
        return swrap.libc.symbols._libc_timerfd_create.f(clockid, flags);
 }
@@ -1116,20 +1110,20 @@ static int libc_timerfd_create(int clockid, int flags)
 
 static ssize_t libc_write(int fd, const void *buf, size_t count)
 {
-       swrap_bind_symbol_libc(write);
+       swrap_bind_symbol_all();
 
        return swrap.libc.symbols._libc_write.f(fd, buf, count);
 }
 
 static ssize_t libc_writev(int fd, const struct iovec *iov, int iovcnt)
 {
-       swrap_bind_symbol_libsocket(writev);
+       swrap_bind_symbol_all();
 
        return swrap.libc.symbols._libc_writev.f(fd, iov, iovcnt);
 }
 
 /* DO NOT call this function during library initialization! */
-static void swrap_bind_symbol_all(void)
+static void __swrap_bind_symbol_all_once(void)
 {
 #ifdef HAVE_ACCEPT4
        swrap_bind_symbol_libsocket(accept4);
@@ -1181,6 +1175,13 @@ static void swrap_bind_symbol_all(void)
        swrap_bind_symbol_libsocket(writev);
 }
 
+static void swrap_bind_symbol_all(void)
+{
+       static pthread_once_t all_symbol_binding_once = PTHREAD_ONCE_INIT;
+
+       pthread_once(&all_symbol_binding_once, __swrap_bind_symbol_all_once);
+}
+
 /*********************************************************
  * SWRAP HELPER FUNCTIONS
  *********************************************************/
@@ -1609,6 +1610,8 @@ static void socket_wrapper_init_sockets(void)
        size_t i;
        int ret;
 
+       swrap_bind_symbol_all();
+
        swrap_mutex_lock(&sockets_mutex);
 
        if (sockets != NULL) {