nwrap: Bind symbols only once
authorAndreas Schneider <asn@samba.org>
Fri, 4 Nov 2022 13:24:54 +0000 (14:24 +0100)
committerAndreas Schneider <asn@samba.org>
Fri, 4 Nov 2022 13:41:50 +0000 (14:41 +0100)
Signed-off-by: Andreas Schneider <asn@samba.org>
Reviewed-by: Stefan Metzmacher <metze@samba.org>
src/nss_wrapper.c

index 8d91e86642d5f047ef2de5213974aa9f6c0d4765..97c908ef4bc644cb6740f05bb1b4cc7e83a7bce1 100644 (file)
@@ -177,7 +177,6 @@ typedef nss_status_t NSS_STATUS;
 #define NWRAP_INET_ADDRSTRLEN INET_ADDRSTRLEN
 #endif
 
-static pthread_mutex_t libc_symbol_binding_mutex = PTHREAD_MUTEX_INITIALIZER;
 static pthread_mutex_t nss_module_symbol_binding_mutex = PTHREAD_MUTEX_INITIALIZER;
 
 static bool nwrap_initialized = false;
@@ -195,7 +194,6 @@ static pthread_mutex_t nwrap_sp_global_mutex = PTHREAD_MUTEX_INITIALIZER;
  * nwrap_init() function.
  */
 # define NWRAP_LOCK_ALL do { \
-       nwrap_mutex_lock(&libc_symbol_binding_mutex); \
        nwrap_mutex_lock(&nss_module_symbol_binding_mutex); \
        nwrap_mutex_lock(&nwrap_initialized_mutex); \
        nwrap_mutex_lock(&nwrap_global_mutex); \
@@ -213,7 +211,6 @@ static pthread_mutex_t nwrap_sp_global_mutex = PTHREAD_MUTEX_INITIALIZER;
        nwrap_mutex_unlock(&nwrap_global_mutex); \
        nwrap_mutex_unlock(&nwrap_initialized_mutex); \
        nwrap_mutex_unlock(&nss_module_symbol_binding_mutex); \
-       nwrap_mutex_unlock(&libc_symbol_binding_mutex); \
 } while (0);
 
 static void nwrap_init(void);
@@ -1233,36 +1230,30 @@ static void _nwrap_mutex_unlock(pthread_mutex_t *mutex, const char *name, const
 }
 
 #define nwrap_bind_symbol_libc(sym_name) \
-       nwrap_mutex_lock(&libc_symbol_binding_mutex); \
        if (nwrap_main_global->libc->symbols._libc_##sym_name.obj == NULL) { \
                nwrap_main_global->libc->symbols._libc_##sym_name.obj = \
                        _nwrap_bind_symbol(NWRAP_LIBC, #sym_name); \
        } \
-       nwrap_mutex_unlock(&libc_symbol_binding_mutex)
 
 #define nwrap_bind_symbol_libc_posix(sym_name) \
-       nwrap_mutex_lock(&libc_symbol_binding_mutex); \
        if (nwrap_main_global->libc->symbols._libc_##sym_name.obj == NULL) { \
                nwrap_main_global->libc->symbols._libc_##sym_name.obj = \
                        _nwrap_bind_symbol(NWRAP_LIBC, "__posix_" #sym_name); \
        } \
-       nwrap_mutex_unlock(&libc_symbol_binding_mutex)
 
 #define nwrap_bind_symbol_libnsl(sym_name) \
-       nwrap_mutex_lock(&libc_symbol_binding_mutex); \
        if (nwrap_main_global->libc->symbols._libc_##sym_name.obj == NULL) { \
                nwrap_main_global->libc->symbols._libc_##sym_name.obj = \
                        _nwrap_bind_symbol(NWRAP_LIBNSL, #sym_name); \
        } \
-       nwrap_mutex_unlock(&libc_symbol_binding_mutex)
 
 #define nwrap_bind_symbol_libsocket(sym_name) \
-       nwrap_mutex_lock(&libc_symbol_binding_mutex); \
        if (nwrap_main_global->libc->symbols._libc_##sym_name.obj == NULL) { \
                nwrap_main_global->libc->symbols._libc_##sym_name.obj = \
                        _nwrap_bind_symbol(NWRAP_LIBSOCKET, #sym_name); \
        } \
-       nwrap_mutex_unlock(&libc_symbol_binding_mutex)
+
+static void nwrap_bind_symbol_all(void);
 
 /* INTERNAL HELPER FUNCTIONS */
 static void nwrap_lines_unload(struct nwrap_cache *const nwrap)
@@ -1287,7 +1278,7 @@ static void nwrap_lines_unload(struct nwrap_cache *const nwrap)
  */
 static struct passwd *libc_getpwnam(const char *name)
 {
-       nwrap_bind_symbol_libc(getpwnam);
+       nwrap_bind_symbol_all();
 
        return nwrap_main_global->libc->symbols._libc_getpwnam.f(name);
 }
@@ -1299,11 +1290,7 @@ static int libc_getpwnam_r(const char *name,
                           size_t buflen,
                           struct passwd **result)
 {
-#ifdef HAVE___POSIX_GETPWNAM_R
-       nwrap_bind_symbol_libc_posix(getpwnam_r);
-#else
-       nwrap_bind_symbol_libc(getpwnam_r);
-#endif
+       nwrap_bind_symbol_all();
 
        return nwrap_main_global->libc->symbols._libc_getpwnam_r.f(name,
                                                                   pwd,
@@ -1315,7 +1302,7 @@ static int libc_getpwnam_r(const char *name,
 
 static struct passwd *libc_getpwuid(uid_t uid)
 {
-       nwrap_bind_symbol_libc(getpwuid);
+       nwrap_bind_symbol_all();
 
        return nwrap_main_global->libc->symbols._libc_getpwuid.f(uid);
 }
@@ -1327,11 +1314,7 @@ static int libc_getpwuid_r(uid_t uid,
                           size_t buflen,
                           struct passwd **result)
 {
-#ifdef HAVE___POSIX_GETPWUID_R
-       nwrap_bind_symbol_libc_posix(getpwuid_r);
-#else
-       nwrap_bind_symbol_libc(getpwuid_r);
-#endif
+       nwrap_bind_symbol_all();
 
        return nwrap_main_global->libc->symbols._libc_getpwuid_r.f(uid,
                                                                   pwd,
@@ -1374,14 +1357,14 @@ static bool str_tolower_copy(char **dst_name, const char *const src_name)
 
 static void libc_setpwent(void)
 {
-       nwrap_bind_symbol_libc(setpwent);
+       nwrap_bind_symbol_all();
 
        nwrap_main_global->libc->symbols._libc_setpwent.f();
 }
 
 static struct passwd *libc_getpwent(void)
 {
-       nwrap_bind_symbol_libc(getpwent);
+       nwrap_bind_symbol_all();
 
        return nwrap_main_global->libc->symbols._libc_getpwent.f();
 }
@@ -1392,7 +1375,7 @@ static struct passwd *libc_getpwent_r(struct passwd *pwdst,
                                      char *buf,
                                      int buflen)
 {
-       nwrap_bind_symbol_libc(getpwent_r);
+       nwrap_bind_symbol_all();
 
        return nwrap_main_global->libc->symbols._libc_getpwent_r.f(pwdst,
                                                                   buf,
@@ -1404,7 +1387,7 @@ static int libc_getpwent_r(struct passwd *pwdst,
                           size_t buflen,
                           struct passwd **pwdstp)
 {
-       nwrap_bind_symbol_libc(getpwent_r);
+       nwrap_bind_symbol_all();
 
        return nwrap_main_global->libc->symbols._libc_getpwent_r.f(pwdst,
                                                                   buf,
@@ -1416,21 +1399,21 @@ static int libc_getpwent_r(struct passwd *pwdst,
 
 static void libc_endpwent(void)
 {
-       nwrap_bind_symbol_libc(endpwent);
+       nwrap_bind_symbol_all();
 
        nwrap_main_global->libc->symbols._libc_endpwent.f();
 }
 
 static int libc_initgroups(const char *user, gid_t gid)
 {
-       nwrap_bind_symbol_libc(initgroups);
+       nwrap_bind_symbol_all();
 
        return nwrap_main_global->libc->symbols._libc_initgroups.f(user, gid);
 }
 
 static struct group *libc_getgrnam(const char *name)
 {
-       nwrap_bind_symbol_libc(getgrnam);
+       nwrap_bind_symbol_all();
 
        return nwrap_main_global->libc->symbols._libc_getgrnam.f(name);
 }
@@ -1442,11 +1425,7 @@ static int libc_getgrnam_r(const char *name,
                           size_t buflen,
                           struct group **result)
 {
-#ifdef HAVE___POSIX_GETGRNAM_R
-       nwrap_bind_symbol_libc_posix(getgrnam_r);
-#else
-       nwrap_bind_symbol_libc(getgrnam_r);
-#endif
+       nwrap_bind_symbol_all();
 
        return nwrap_main_global->libc->symbols._libc_getgrnam_r.f(name,
                                                                   grp,
@@ -1458,7 +1437,7 @@ static int libc_getgrnam_r(const char *name,
 
 static struct group *libc_getgrgid(gid_t gid)
 {
-       nwrap_bind_symbol_libc(getgrgid);
+       nwrap_bind_symbol_all();
 
        return nwrap_main_global->libc->symbols._libc_getgrgid.f(gid);
 }
@@ -1470,14 +1449,7 @@ static int libc_getgrgid_r(gid_t gid,
                           size_t buflen,
                           struct group **result)
 {
-#ifdef HAVE___POSIX_GETGRGID_R
-       if (nwrap_main_global->libc->symbols._libc_getgrgid_r == NULL) {
-               *(void **) (&nwrap_main_global->libc->symbols._libc_getgrgid_r) =
-                       _nwrap_bind_symbol_libc("__posix_getgrgid_r");
-       }
-#else
-       nwrap_bind_symbol_libc(getgrgid_r);
-#endif
+       nwrap_bind_symbol_all();
 
        return nwrap_main_global->libc->symbols._libc_getgrgid_r.f(gid,
                                                                   grp,
@@ -1489,14 +1461,14 @@ static int libc_getgrgid_r(gid_t gid,
 
 static void libc_setgrent(void)
 {
-       nwrap_bind_symbol_libc(setgrent);
+       nwrap_bind_symbol_all();
 
        nwrap_main_global->libc->symbols._libc_setgrent.f();
 }
 
 static struct group *libc_getgrent(void)
 {
-       nwrap_bind_symbol_libc(getgrent);
+       nwrap_bind_symbol_all();
 
        return nwrap_main_global->libc->symbols._libc_getgrent.f();
 }
@@ -1507,7 +1479,7 @@ static struct group *libc_getgrent_r(struct group *group,
                                     char *buf,
                                     size_t buflen)
 {
-       nwrap_bind_symbol_libc(getgrent_r);
+       nwrap_bind_symbol_all();
 
        return nwrap_main_global->libc->symbols._libc_getgrent_r.f(group,
                                                                   buf,
@@ -1519,7 +1491,7 @@ static int libc_getgrent_r(struct group *group,
                           size_t buflen,
                           struct group **result)
 {
-       nwrap_bind_symbol_libc(getgrent_r);
+       nwrap_bind_symbol_all();
 
        return nwrap_main_global->libc->symbols._libc_getgrent_r.f(group,
                                                                   buf,
@@ -1531,7 +1503,7 @@ static int libc_getgrent_r(struct group *group,
 
 static void libc_endgrent(void)
 {
-       nwrap_bind_symbol_libc(endgrent);
+       nwrap_bind_symbol_all();
 
        nwrap_main_global->libc->symbols._libc_endgrent.f();
 }
@@ -1542,7 +1514,7 @@ static int libc_getgrouplist(const char *user,
                             gid_t *groups,
                             int *ngroups)
 {
-       nwrap_bind_symbol_libc(getgrouplist);
+       nwrap_bind_symbol_all();
 
        return nwrap_main_global->libc->symbols._libc_getgrouplist.f(user,
                                                                     group,
@@ -1553,28 +1525,28 @@ static int libc_getgrouplist(const char *user,
 
 static void libc_sethostent(int stayopen)
 {
-       nwrap_bind_symbol_libnsl(sethostent);
+       nwrap_bind_symbol_all();
 
        nwrap_main_global->libc->symbols._libc_sethostent.f(stayopen);
 }
 
 static struct hostent *libc_gethostent(void)
 {
-       nwrap_bind_symbol_libnsl(gethostent);
+       nwrap_bind_symbol_all();
 
        return nwrap_main_global->libc->symbols._libc_gethostent.f();
 }
 
 static void libc_endhostent(void)
 {
-       nwrap_bind_symbol_libnsl(endhostent);
+       nwrap_bind_symbol_all();
 
        nwrap_main_global->libc->symbols._libc_endhostent.f();
 }
 
 static struct hostent *libc_gethostbyname(const char *name)
 {
-       nwrap_bind_symbol_libnsl(gethostbyname);
+       nwrap_bind_symbol_all();
 
        return nwrap_main_global->libc->symbols._libc_gethostbyname.f(name);
 }
@@ -1582,7 +1554,7 @@ static struct hostent *libc_gethostbyname(const char *name)
 #ifdef HAVE_GETHOSTBYNAME2 /* GNU extension */
 static struct hostent *libc_gethostbyname2(const char *name, int af)
 {
-       nwrap_bind_symbol_libnsl(gethostbyname2);
+       nwrap_bind_symbol_all();
 
        return nwrap_main_global->libc->symbols._libc_gethostbyname2.f(name, af);
 }
@@ -1597,7 +1569,7 @@ static int libc_gethostbyname2_r(const char *name,
                                 struct hostent **result,
                                 int *h_errnop)
 {
-       nwrap_bind_symbol_libnsl(gethostbyname2_r);
+       nwrap_bind_symbol_all();
 
        return nwrap_main_global->libc->symbols._libc_gethostbyname2_r.f(name,
                                                                         af,
@@ -1613,7 +1585,7 @@ static struct hostent *libc_gethostbyaddr(const void *addr,
                                          socklen_t len,
                                          int type)
 {
-       nwrap_bind_symbol_libnsl(gethostbyaddr);
+       nwrap_bind_symbol_all();
 
        return nwrap_main_global->libc->symbols._libc_gethostbyaddr.f(addr,
                                                                      len,
@@ -1622,7 +1594,7 @@ static struct hostent *libc_gethostbyaddr(const void *addr,
 
 static int libc_gethostname(char *name, size_t len)
 {
-       nwrap_bind_symbol_libnsl(gethostname);
+       nwrap_bind_symbol_all();
 
        return nwrap_main_global->libc->symbols._libc_gethostname.f(name, len);
 }
@@ -1635,7 +1607,7 @@ static int libc_gethostbyname_r(const char *name,
                                struct hostent **result,
                                int *h_errnop)
 {
-       nwrap_bind_symbol_libnsl(gethostbyname_r);
+       nwrap_bind_symbol_all();
 
        return nwrap_main_global->libc->symbols._libc_gethostbyname_r.f(name,
                                                                        ret,
@@ -1656,7 +1628,7 @@ static int libc_gethostbyaddr_r(const void *addr,
                                struct hostent **result,
                                int *h_errnop)
 {
-       nwrap_bind_symbol_libnsl(gethostbyaddr_r);
+       nwrap_bind_symbol_all();
 
        return nwrap_main_global->libc->symbols._libc_gethostbyaddr_r.f(addr,
                                                                        len,
@@ -1674,7 +1646,7 @@ static int libc_getaddrinfo(const char *node,
                            const struct addrinfo *hints,
                            struct addrinfo **res)
 {
-       nwrap_bind_symbol_libsocket(getaddrinfo);
+       nwrap_bind_symbol_all();
 
        return nwrap_main_global->libc->symbols._libc_getaddrinfo.f(node,
                                                                    service,
@@ -1690,7 +1662,7 @@ static int libc_getnameinfo(const struct sockaddr *sa,
                            size_t servlen,
                            int flags)
 {
-       nwrap_bind_symbol_libsocket(getnameinfo);
+       nwrap_bind_symbol_all();
 
        return nwrap_main_global->libc->symbols._libc_getnameinfo.f(sa,
                                                                    salen,
@@ -1701,6 +1673,81 @@ static int libc_getnameinfo(const struct sockaddr *sa,
                                                                    flags);
 }
 
+static void __nwrap_bind_symbol_all_once(void)
+{
+       nwrap_bind_symbol_libc(getpwnam);
+#ifdef HAVE_GETPWNAM_R
+# ifdef HAVE___POSIX_GETPWNAM_R
+       nwrap_bind_symbol_libc_posix(getpwnam_r);
+# else
+       nwrap_bind_symbol_libc(getpwnam_r);
+# endif
+#endif
+       nwrap_bind_symbol_libc(getpwuid);
+#ifdef HAVE_GETPWUID_R
+# ifdef HAVE___POSIX_GETPWUID_R
+       nwrap_bind_symbol_libc_posix(getpwuid_r);
+# else
+       nwrap_bind_symbol_libc(getpwuid_r);
+# endif
+#endif
+       nwrap_bind_symbol_libc(setpwent);
+       nwrap_bind_symbol_libc(getpwent);
+#ifdef HAVE_GETPWENT_R
+       nwrap_bind_symbol_libc(getpwent_r);
+#endif
+       nwrap_bind_symbol_libc(endpwent);
+       nwrap_bind_symbol_libc(initgroups);
+       nwrap_bind_symbol_libc(getgrnam);
+#ifdef HAVE_GETGRNAM_R
+# ifdef HAVE___POSIX_GETGRNAM_R
+       nwrap_bind_symbol_libc_posix(getgrnam_r);
+# else
+       nwrap_bind_symbol_libc(getgrnam_r);
+# endif
+#endif
+       nwrap_bind_symbol_libc(getgrgid);
+#ifdef HAVE_GETGRGID_R
+# ifdef HAVE___POSIX_GETGRGID_R
+       nwrap_bind_symbol_libc_posix(getgrgid_r);
+# else
+       nwrap_bind_symbol_libc(getgrgid_r);
+# endif
+#endif
+       nwrap_bind_symbol_libc(setgrent);
+       nwrap_bind_symbol_libc(getgrent);
+       nwrap_bind_symbol_libc(getgrent_r);
+       nwrap_bind_symbol_libc(endgrent);
+       nwrap_bind_symbol_libc(getgrouplist);
+       nwrap_bind_symbol_libnsl(sethostent);
+       nwrap_bind_symbol_libnsl(gethostent);
+       nwrap_bind_symbol_libnsl(endhostent);
+       nwrap_bind_symbol_libnsl(gethostbyname);
+#ifdef HAVE_GETHOSTBYNAME2 /* GNU extension */
+       nwrap_bind_symbol_libnsl(gethostbyname2);
+#endif
+#ifdef HAVE_GETHOSTBYNAME2_R /* GNU extension */
+       nwrap_bind_symbol_libnsl(gethostbyname2_r);
+#endif
+       nwrap_bind_symbol_libnsl(gethostbyaddr);
+       nwrap_bind_symbol_libnsl(gethostname);
+#ifdef HAVE_GETHOSTBYNAME_R
+       nwrap_bind_symbol_libnsl(gethostbyname_r);
+#endif
+#ifdef HAVE_GETHOSTBYADDR_R
+       nwrap_bind_symbol_libnsl(gethostbyaddr_r);
+#endif
+       nwrap_bind_symbol_libsocket(getaddrinfo);
+       nwrap_bind_symbol_libsocket(getnameinfo);
+}
+
+static void nwrap_bind_symbol_all(void)
+{
+       static pthread_once_t all_symbol_binding_once = PTHREAD_ONCE_INIT;
+
+       pthread_once(&all_symbol_binding_once, __nwrap_bind_symbol_all_once);
+}
+
 /*********************************************************
  * NWRAP NSS MODULE LOADER FUNCTIONS
  *********************************************************/