From a9d93c078a80bb301df3f811da7442f6e99edc68 Mon Sep 17 00:00:00 2001 From: Stefan Metzmacher Date: Thu, 28 Jan 2021 14:31:31 +0100 Subject: [PATCH] src/socket_wrapper.c: always go through swrap_bind_symbol_all() protected by pthread_once() Signed-off-by: Stefan Metzmacher Reviewed-by: Andreas Schneider --- src/socket_wrapper.c | 107 ++++++++++++++++++++++--------------------- 1 file changed, 55 insertions(+), 52 deletions(-) diff --git a/src/socket_wrapper.c b/src/socket_wrapper.c index 5101e7a..388b764 100644 --- a/src/socket_wrapper.c +++ b/src/socket_wrapper.c @@ -179,11 +179,11 @@ enum swrap_dbglvl_e { #endif /* Add new global locks here please */ -# define SWRAP_LOCK_ALL \ - swrap_mutex_lock(&libc_symbol_binding_mutex); \ +# define SWRAP_LOCK_ALL do { \ +} while(0) -# define SWRAP_UNLOCK_ALL \ - swrap_mutex_unlock(&libc_symbol_binding_mutex); \ +# define SWRAP_UNLOCK_ALL do { \ +} while(0) #define SOCKET_INFO_CONTAINER(si) \ (struct socket_info_container *)(si) @@ -309,9 +309,6 @@ static size_t socket_fds_max = SOCKET_WRAPPER_MAX_SOCKETS_LIMIT; /* Hash table to map fds to corresponding socket_info index */ static int *socket_fds_idx; -/* Mutex to synchronize access to global libc.symbols */ -static pthread_mutex_t libc_symbol_binding_mutex = PTHREAD_MUTEX_INITIALIZER; - /* Mutex for syncronizing port selection during swrap_auto_bind() */ static pthread_mutex_t autobind_start_mutex; @@ -726,15 +723,10 @@ static void swrap_mutex_unlock(pthread_mutex_t *mutex) * This is an optimization to avoid locking each time we check if the symbol is * bound. */ -#define _swrap_bind_symbol_generic(lib, sym_name) \ - if (swrap.libc.symbols._libc_##sym_name.obj == NULL) { \ - swrap_mutex_lock(&libc_symbol_binding_mutex); \ - if (swrap.libc.symbols._libc_##sym_name.obj == NULL) { \ - swrap.libc.symbols._libc_##sym_name.obj = \ - _swrap_bind_symbol(lib, #sym_name); \ - } \ - swrap_mutex_unlock(&libc_symbol_binding_mutex); \ - } +#define _swrap_bind_symbol_generic(lib, sym_name) do { \ + swrap.libc.symbols._libc_##sym_name.obj = \ + _swrap_bind_symbol(lib, #sym_name); \ +} while(0); #define swrap_bind_symbol_libc(sym_name) \ _swrap_bind_symbol_generic(SWRAP_LIBC, sym_name) @@ -742,6 +734,8 @@ static void swrap_mutex_unlock(pthread_mutex_t *mutex) #define swrap_bind_symbol_libsocket(sym_name) \ _swrap_bind_symbol_generic(SWRAP_LIBSOCKET, sym_name) +static void swrap_bind_symbol_all(void); + /**************************************************************************** * IMPORTANT **************************************************************************** @@ -759,7 +753,7 @@ static int libc_accept4(int sockfd, socklen_t *addrlen, int flags) { - swrap_bind_symbol_libsocket(accept4); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_accept4.f(sockfd, addr, addrlen, flags); } @@ -768,7 +762,7 @@ static int libc_accept4(int sockfd, static int libc_accept(int sockfd, struct sockaddr *addr, socklen_t *addrlen) { - swrap_bind_symbol_libsocket(accept); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_accept.f(sockfd, addr, addrlen); } @@ -778,14 +772,14 @@ static int libc_bind(int sockfd, const struct sockaddr *addr, socklen_t addrlen) { - swrap_bind_symbol_libsocket(bind); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_bind.f(sockfd, addr, addrlen); } static int libc_close(int fd) { - swrap_bind_symbol_libc(close); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_close.f(fd); } @@ -794,21 +788,21 @@ static int libc_connect(int sockfd, const struct sockaddr *addr, socklen_t addrlen) { - swrap_bind_symbol_libsocket(connect); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_connect.f(sockfd, addr, addrlen); } static int libc_dup(int fd) { - swrap_bind_symbol_libc(dup); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_dup.f(fd); } static int libc_dup2(int oldfd, int newfd) { - swrap_bind_symbol_libc(dup2); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_dup2.f(oldfd, newfd); } @@ -816,7 +810,7 @@ static int libc_dup2(int oldfd, int newfd) #ifdef HAVE_EVENTFD static int libc_eventfd(int count, int flags) { - swrap_bind_symbol_libc(eventfd); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_eventfd.f(count, flags); } @@ -828,7 +822,7 @@ static int libc_vfcntl(int fd, int cmd, va_list ap) void *arg; int rc; - swrap_bind_symbol_libc(fcntl); + swrap_bind_symbol_all(); arg = va_arg(ap, void *); @@ -841,7 +835,7 @@ static int libc_getpeername(int sockfd, struct sockaddr *addr, socklen_t *addrlen) { - swrap_bind_symbol_libsocket(getpeername); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_getpeername.f(sockfd, addr, addrlen); } @@ -850,7 +844,7 @@ static int libc_getsockname(int sockfd, struct sockaddr *addr, socklen_t *addrlen) { - swrap_bind_symbol_libsocket(getsockname); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_getsockname.f(sockfd, addr, addrlen); } @@ -861,7 +855,7 @@ static int libc_getsockopt(int sockfd, void *optval, socklen_t *optlen) { - swrap_bind_symbol_libsocket(getsockopt); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_getsockopt.f(sockfd, level, @@ -876,7 +870,7 @@ static int libc_vioctl(int d, unsigned long int request, va_list ap) void *arg; int rc; - swrap_bind_symbol_libc(ioctl); + swrap_bind_symbol_all(); arg = va_arg(ap, void *); @@ -887,14 +881,14 @@ static int libc_vioctl(int d, unsigned long int request, va_list ap) static int libc_listen(int sockfd, int backlog) { - swrap_bind_symbol_libsocket(listen); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_listen.f(sockfd, backlog); } static FILE *libc_fopen(const char *name, const char *mode) { - swrap_bind_symbol_libc(fopen); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_fopen.f(name, mode); } @@ -902,7 +896,7 @@ static FILE *libc_fopen(const char *name, const char *mode) #ifdef HAVE_FOPEN64 static FILE *libc_fopen64(const char *name, const char *mode) { - swrap_bind_symbol_libc(fopen64); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_fopen64.f(name, mode); } @@ -913,7 +907,7 @@ static int libc_vopen(const char *pathname, int flags, va_list ap) int mode = 0; int fd; - swrap_bind_symbol_libc(open); + swrap_bind_symbol_all(); if (flags & O_CREAT) { mode = va_arg(ap, int); @@ -941,7 +935,7 @@ static int libc_vopen64(const char *pathname, int flags, va_list ap) int mode = 0; int fd; - swrap_bind_symbol_libc(open64); + swrap_bind_symbol_all(); if (flags & O_CREAT) { mode = va_arg(ap, int); @@ -957,7 +951,7 @@ static int libc_vopenat(int dirfd, const char *path, int flags, va_list ap) int mode = 0; int fd; - swrap_bind_symbol_libc(openat); + swrap_bind_symbol_all(); if (flags & O_CREAT) { mode = va_arg(ap, int); @@ -986,28 +980,28 @@ static int libc_openat(int dirfd, const char *path, int flags, ...) static int libc_pipe(int pipefd[2]) { - swrap_bind_symbol_libsocket(pipe); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_pipe.f(pipefd); } static int libc_read(int fd, void *buf, size_t count) { - swrap_bind_symbol_libc(read); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_read.f(fd, buf, count); } static ssize_t libc_readv(int fd, const struct iovec *iov, int iovcnt) { - swrap_bind_symbol_libsocket(readv); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_readv.f(fd, iov, iovcnt); } static int libc_recv(int sockfd, void *buf, size_t len, int flags) { - swrap_bind_symbol_libsocket(recv); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_recv.f(sockfd, buf, len, flags); } @@ -1019,7 +1013,7 @@ static int libc_recvfrom(int sockfd, struct sockaddr *src_addr, socklen_t *addrlen) { - swrap_bind_symbol_libsocket(recvfrom); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_recvfrom.f(sockfd, buf, @@ -1031,21 +1025,21 @@ static int libc_recvfrom(int sockfd, static int libc_recvmsg(int sockfd, struct msghdr *msg, int flags) { - swrap_bind_symbol_libsocket(recvmsg); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_recvmsg.f(sockfd, msg, flags); } static int libc_send(int sockfd, const void *buf, size_t len, int flags) { - swrap_bind_symbol_libsocket(send); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_send.f(sockfd, buf, len, flags); } static int libc_sendmsg(int sockfd, const struct msghdr *msg, int flags) { - swrap_bind_symbol_libsocket(sendmsg); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_sendmsg.f(sockfd, msg, flags); } @@ -1057,7 +1051,7 @@ static int libc_sendto(int sockfd, const struct sockaddr *dst_addr, socklen_t addrlen) { - swrap_bind_symbol_libsocket(sendto); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_sendto.f(sockfd, buf, @@ -1073,7 +1067,7 @@ static int libc_setsockopt(int sockfd, const void *optval, socklen_t optlen) { - swrap_bind_symbol_libsocket(setsockopt); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_setsockopt.f(sockfd, level, @@ -1085,7 +1079,7 @@ static int libc_setsockopt(int sockfd, #ifdef HAVE_SIGNALFD static int libc_signalfd(int fd, const sigset_t *mask, int flags) { - swrap_bind_symbol_libsocket(signalfd); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_signalfd.f(fd, mask, flags); } @@ -1093,14 +1087,14 @@ static int libc_signalfd(int fd, const sigset_t *mask, int flags) static int libc_socket(int domain, int type, int protocol) { - swrap_bind_symbol_libsocket(socket); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_socket.f(domain, type, protocol); } static int libc_socketpair(int domain, int type, int protocol, int sv[2]) { - swrap_bind_symbol_libsocket(socketpair); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_socketpair.f(domain, type, protocol, sv); } @@ -1108,7 +1102,7 @@ static int libc_socketpair(int domain, int type, int protocol, int sv[2]) #ifdef HAVE_TIMERFD_CREATE static int libc_timerfd_create(int clockid, int flags) { - swrap_bind_symbol_libc(timerfd_create); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_timerfd_create.f(clockid, flags); } @@ -1116,20 +1110,20 @@ static int libc_timerfd_create(int clockid, int flags) static ssize_t libc_write(int fd, const void *buf, size_t count) { - swrap_bind_symbol_libc(write); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_write.f(fd, buf, count); } static ssize_t libc_writev(int fd, const struct iovec *iov, int iovcnt) { - swrap_bind_symbol_libsocket(writev); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_writev.f(fd, iov, iovcnt); } /* DO NOT call this function during library initialization! */ -static void swrap_bind_symbol_all(void) +static void __swrap_bind_symbol_all_once(void) { #ifdef HAVE_ACCEPT4 swrap_bind_symbol_libsocket(accept4); @@ -1181,6 +1175,13 @@ static void swrap_bind_symbol_all(void) swrap_bind_symbol_libsocket(writev); } +static void swrap_bind_symbol_all(void) +{ + static pthread_once_t all_symbol_binding_once = PTHREAD_ONCE_INIT; + + pthread_once(&all_symbol_binding_once, __swrap_bind_symbol_all_once); +} + /********************************************************* * SWRAP HELPER FUNCTIONS *********************************************************/ @@ -1609,6 +1610,8 @@ static void socket_wrapper_init_sockets(void) size_t i; int ret; + swrap_bind_symbol_all(); + swrap_mutex_lock(&sockets_mutex); if (sockets != NULL) { -- 2.34.1