src/socket_wrapper.c: make use of uid_wrapper_syscall_{valid,va}()
authorStefan Metzmacher <metze@samba.org>
Mon, 16 Jan 2023 18:51:05 +0000 (19:51 +0100)
committerAndreas Schneider <asn@samba.org>
Tue, 17 Jan 2023 16:49:01 +0000 (17:49 +0100)
If we find uid_wrapper_syscall_{valid,va}() symbols in the already
loaded libraries, we'll try to hand over syscall() invocations to
uid_wrapper.

Signed-off-by: Stefan Metzmacher <metze@samba.org>
Reviewed-by: Andreas Schneider <asn@samba.org>
src/socket_wrapper.c

index 199d65243cba738d5de064453680e6b2a042cf56..bf4a976eaee7069ab4bade1c650ca3969d5bfd56 100644 (file)
@@ -659,6 +659,28 @@ struct swrap_libc_symbols {
        SWRAP_SYMBOL_ENTRY(syscall);
 #endif
 };
+#undef SWRAP_SYMBOL_ENTRY
+
+#define SWRAP_SYMBOL_ENTRY(i) \
+       union { \
+               __rtld_default_##i f; \
+               void *obj; \
+       } _rtld_default_##i
+
+#ifdef HAVE_SYSCALL
+typedef bool (*__rtld_default_uid_wrapper_syscall_valid)(long int sysno);
+typedef long int (*__rtld_default_uid_wrapper_syscall_va)(long int sysno, va_list va);
+#endif
+
+struct swrap_rtld_default_symbols {
+#ifdef HAVE_SYSCALL
+       SWRAP_SYMBOL_ENTRY(uid_wrapper_syscall_valid);
+       SWRAP_SYMBOL_ENTRY(uid_wrapper_syscall_va);
+#else
+       uint8_t dummy;
+#endif
+};
+#undef SWRAP_SYMBOL_ENTRY
 
 struct swrap {
        struct {
@@ -666,6 +688,10 @@ struct swrap {
                void *socket_handle;
                struct swrap_libc_symbols symbols;
        } libc;
+
+       struct {
+               struct swrap_rtld_default_symbols symbols;
+       } rtld_default;
 };
 
 static struct swrap swrap;
@@ -846,6 +872,11 @@ static void _swrap_mutex_unlock(pthread_mutex_t *mutex, const char *name, const
 #define swrap_bind_symbol_libsocket(sym_name) \
        _swrap_bind_symbol_generic(SWRAP_LIBSOCKET, sym_name)
 
+#define swrap_bind_symbol_rtld_default_optional(sym_name) do { \
+       swrap.rtld_default.symbols._rtld_default_##sym_name.obj = \
+               dlsym(RTLD_DEFAULT, #sym_name); \
+} while(0);
+
 static void swrap_bind_symbol_all(void);
 
 /****************************************************************************
@@ -1321,6 +1352,36 @@ static long int libc_vsyscall(long int sysno, va_list va)
 
        return rc;
 }
+
+static bool swrap_uwrap_syscall_valid(long int sysno)
+{
+       swrap_bind_symbol_all();
+
+       if (swrap.rtld_default.symbols._rtld_default_uid_wrapper_syscall_valid.f == NULL) {
+               return false;
+       }
+
+       return swrap.rtld_default.symbols._rtld_default_uid_wrapper_syscall_valid.f(
+                                               sysno);
+}
+
+DO_NOT_SANITIZE_ADDRESS_ATTRIBUTE
+static long int swrap_uwrap_syscall_va(long int sysno, va_list va)
+{
+       swrap_bind_symbol_all();
+
+       if (swrap.rtld_default.symbols._rtld_default_uid_wrapper_syscall_va.f == NULL) {
+               /*
+                * Fallback to libc, if uid_wrapper_syscall_va is not
+                * available.
+                */
+               return libc_vsyscall(sysno, va);
+       }
+
+       return swrap.rtld_default.symbols._rtld_default_uid_wrapper_syscall_va.f(
+                                               sysno,
+                                               va);
+}
 #endif /* HAVE_SYSCALL */
 
 /* DO NOT call this function during library initialization! */
@@ -1385,6 +1446,8 @@ static void __swrap_bind_symbol_all_once(void)
        swrap_bind_symbol_libsocket(writev);
 #ifdef HAVE_SYSCALL
        swrap_bind_symbol_libc(syscall);
+       swrap_bind_symbol_rtld_default_optional(uid_wrapper_syscall_valid);
+       swrap_bind_symbol_rtld_default_optional(uid_wrapper_syscall_va);
 #endif
 }
 
@@ -8504,6 +8567,16 @@ long int syscall(long int sysno, ...)
         * we care about...
         */
        if (!swrap_is_swrap_related_syscall(sysno)) {
+               /*
+                * We need to give socket_wrapper a
+                * chance to take over...
+                */
+               if (swrap_uwrap_syscall_valid(sysno)) {
+                       rc = swrap_uwrap_syscall_va(sysno, va);
+                       va_end(va);
+                       return rc;
+               }
+
                rc = libc_vsyscall(sysno, va);
                va_end(va);
                return rc;