rwrap: split out a rwrap_reset_nameservers() function
[resolv_wrapper.git] / src / resolv_wrapper.c
index c88893ef9ec432c3ca641efba7433ddfec321ed5..7c5bf81ab54ef6c95becaa8ebede45a0bdbd8285 100644 (file)
 
 #include <resolv.h>
 
+#ifdef HAVE_RES_STATE_U_EXT_NSADDRS
+#define HAVE_RESOLV_IPV6_NSADDRS 1
+#endif
+
 /* GCC has printf type attribute check. */
 #ifdef HAVE_ATTRIBUTE_PRINTF_FORMAT
 #define PRINTF_ATTRIBUTE(a,b) __attribute__ ((__format__ (__printf__, a, b)))
@@ -83,6 +87,19 @@ enum rwrap_dbglvl_e {
        RWRAP_LOG_TRACE
 };
 
+#ifndef HAVE_GETPROGNAME
+static const char *getprogname(void)
+{
+#if defined(HAVE_PROGRAM_INVOCATION_SHORT_NAME)
+       return program_invocation_short_name;
+#elif defined(HAVE_GETEXECNAME)
+       return getexecname();
+#else
+       return NULL;
+#endif /* HAVE_PROGRAM_INVOCATION_SHORT_NAME */
+}
+#endif /* HAVE_GETPROGNAME */
+
 static void rwrap_log(enum rwrap_dbglvl_e dbglvl, const char *func, const char *format, ...) PRINTF_ATTRIBUTE(3, 4);
 # define RWRAP_LOG(dbglvl, ...) rwrap_log((dbglvl), __func__, __VA_ARGS__)
 
@@ -94,8 +111,8 @@ static void rwrap_log(enum rwrap_dbglvl_e dbglvl,
        va_list va;
        const char *d;
        unsigned int lvl = 0;
-       int pid = getpid();
        const char *prefix = NULL;
+       const char *progname = NULL;
 
        d = getenv("RESOLV_WRAPPER_DEBUGLEVEL");
        if (d != NULL) {
@@ -128,10 +145,16 @@ static void rwrap_log(enum rwrap_dbglvl_e dbglvl,
                        break;
        }
 
+       progname = getprogname();
+       if (progname == NULL) {
+               progname = "<unknown>";
+       }
+
        fprintf(stderr,
-               "%s(%d) - %s: %s\n",
+               "%s[%s (%u)] - %s: %s\n",
                prefix,
-               pid,
+               progname,
+               (unsigned int)getpid(),
                func,
                buffer);
 }
@@ -1583,6 +1606,29 @@ static int libc_res_nsearch(struct __res_state *state,
  *   RES_HELPER
  ***************************************************************************/
 
+static void rwrap_reset_nameservers(struct __res_state *state)
+{
+#ifdef HAVE_RES_STATE_U_EXT_NSADDRS
+       size_t i;
+
+       for (i = 0; i < (size_t)state->nscount; i++) {
+               if (state->_u._ext.nssocks[i] != -1) {
+                       close(state->_u._ext.nssocks[i]);
+                       state->_u._ext.nssocks[i] = -1;
+               }
+               SAFE_FREE(state->_u._ext.nsaddrs[i]);
+       }
+       memset(&state->_u._ext, 0, sizeof(state->_u._ext));
+       for (i = 0; i < MAXNS; i++) {
+               state->_u._ext.nssocks[i] = -1;
+               state->_u._ext.nsmap[i] = MAXNS + 1;
+       }
+       state->ipv6_unavail = false;
+#endif
+       memset(state->nsaddr_list, 0, sizeof(state->nsaddr_list));
+       state->nscount = 0;
+}
+
 static int rwrap_parse_resolv_conf(struct __res_state *state,
                                   const char *resolv_conf)
 {
@@ -1590,6 +1636,8 @@ static int rwrap_parse_resolv_conf(struct __res_state *state,
        char buf[BUFSIZ];
        int nserv = 0;
 
+       rwrap_reset_nameservers(state);
+
        fp = fopen(resolv_conf, "r");
        if (fp == NULL) {
                RWRAP_LOG(RWRAP_LOG_ERROR,
@@ -1626,14 +1674,13 @@ static int rwrap_parse_resolv_conf(struct __res_state *state,
 
                        ok = inet_pton(AF_INET, p, &a);
                        if (ok) {
-                               state->nsaddr_list[state->nscount] = (struct sockaddr_in) {
+                               state->nsaddr_list[nserv] = (struct sockaddr_in) {
                                        .sin_family = AF_INET,
                                        .sin_addr = a,
                                        .sin_port = htons(53),
                                        .sin_zero = { 0 },
                                };
 
-                               state->nscount++;
                                nserv++;
                        } else {
 #ifdef HAVE_RESOLV_IPV6_NSADDRS
@@ -1654,11 +1701,11 @@ static int rwrap_parse_resolv_conf(struct __res_state *state,
                                        sa6->sin6_flowinfo = 0;
                                        sa6->sin6_addr = a6;
 
-                                       state->_u._ext.nsaddrs[state->_u._ext.nscount] = sa6;
-                                       state->_u._ext.nssocks[state->_u._ext.nscount] = -1;
-                                       state->_u._ext.nsmap[state->_u._ext.nscount] = MAXNS + 1;
+                                       state->_u._ext.nsaddrs[nserv] = sa6;
+                                       state->_u._ext.nssocks[nserv] = -1;
+                                       state->_u._ext.nsmap[nserv] = MAXNS + 1;
 
-                                       state->_u._ext.nscount++;
+                                       state->_u._ext.nscount6++;
                                        nserv++;
                                } else {
                                        RWRAP_LOG(RWRAP_LOG_ERROR,
@@ -1681,6 +1728,13 @@ static int rwrap_parse_resolv_conf(struct __res_state *state,
                } /* TODO: match other keywords */
        }
 
+       /*
+        * note that state->_u._ext.nscount is left as 0,
+        * this matches glibc and allows resolv wrapper
+        * to work with most (maybe all) glibc versions.
+        */
+       state->nscount = nserv;
+
        if (ferror(fp)) {
                RWRAP_LOG(RWRAP_LOG_ERROR,
                          "Reading from %s failed",
@@ -1706,21 +1760,6 @@ static int rwrap_res_ninit(struct __res_state *state)
                const char *resolv_conf = getenv("RESOLV_WRAPPER_CONF");
 
                if (resolv_conf != NULL) {
-                       uint16_t i;
-
-                       (void)i; /* maybe unused */
-
-                       /* Delete name servers */
-                       state->nscount = 0;
-                       memset(state->nsaddr_list, 0, sizeof(state->nsaddr_list));
-
-#ifdef HAVE_RESOLV_IPV6_NSADDRS
-                       state->_u._ext.nscount = 0;
-                       for (i = 0; i < state->_u._ext.nscount; i++) {
-                               SAFE_FREE(state->_u._ext.nsaddrs[i]);
-                       }
-#endif
-
                        rc = rwrap_parse_resolv_conf(state, resolv_conf);
                }
        }
@@ -1767,19 +1806,8 @@ int __res_init(void)
 
 static void rwrap_res_nclose(struct __res_state *state)
 {
-#ifdef HAVE_RESOLV_IPV6_NSADDRS
-       int i;
-#endif
-
+       rwrap_reset_nameservers(state);
        libc_res_nclose(state);
-
-#ifdef HAVE_RESOLV_IPV6_NSADDRS
-       if (state != NULL) {
-               for (i = 0; i < state->_u._ext.nscount; i++) {
-                       SAFE_FREE(state->_u._ext.nsaddrs[i]);
-               }
-       }
-#endif
 }
 
 #if !defined(res_nclose) && defined(HAVE_RES_NCLOSE)