s3:libads: add ads_set_reconnect_fn() and only reconnect if we can get creds
authorStefan Metzmacher <metze@samba.org>
Wed, 6 Mar 2024 09:13:11 +0000 (10:13 +0100)
committerStefan Metzmacher <metze@samba.org>
Wed, 8 May 2024 08:17:06 +0000 (10:17 +0200)
This reconnect is only useful for long running connections (e.g. in winbindd)
and there we'll make use of it...

Signed-off-by: Stefan Metzmacher <metze@samba.org>
source3/include/ads.h
source3/libads/ads_ldap_protos.h
source3/libads/ads_struct.c
source3/libads/ldap_utils.c
source3/librpc/idl/ads.idl

index 6c9e57b9ed0ca8796e576b9c6ec83f9f09d01a3e..92430cd1edc46212ce41b80b5897f068af5e6b78 100644 (file)
@@ -6,6 +6,9 @@
   basically this is a wrapper around ldap
 */
 
+struct cli_credentials;
+struct ads_reconnect_state;
+
 #include "libads/ads_status.h"
 #include "smb_ldap.h"
 #include "librpc/gen_ndr/ads.h"
@@ -19,6 +22,14 @@ struct ads_saslwrap_ops {
        void (*disconnect)(struct ads_saslwrap *);
 };
 
+struct ads_reconnect_state {
+       NTSTATUS (*fn)(struct ads_struct *ads,
+                      void *private_data,
+                      TALLOC_CTX *mem_ctx,
+                      struct cli_credentials **creds);
+       void *private_data;
+};
+
 typedef struct ads_struct ADS_STRUCT;
 
 #ifdef HAVE_ADS
index b063815678a16f3e1e7e9ec770b63120ea6bb180..aba80476c0520d2c38d114f6fae058467f2e8ddf 100644 (file)
@@ -77,6 +77,12 @@ ADS_STATUS ads_search(ADS_STRUCT *ads, LDAPMessage **res,
                      const char *expr, const char **attrs);
 ADS_STATUS ads_search_dn(ADS_STRUCT *ads, LDAPMessage **res,
                         const char *dn, const char **attrs);
+void ads_set_reconnect_fn(ADS_STRUCT *ads,
+                         NTSTATUS (*fn)(struct ads_struct *ads,
+                                        void *private_data,
+                                        TALLOC_CTX *mem_ctx,
+                                        struct cli_credentials **creds),
+                         void *private_data);
 ADS_STATUS ads_do_search_all_args(ADS_STRUCT *ads, const char *bind_path,
                                  int scope, const char *expr,
                                  const char **attrs, void *args,
index 55f55e7e3643a6b2b8ef055ab9c9b88d12657e2f..c597d58230fa2757b4a1025bb2732fd541a9cf47 100644 (file)
@@ -216,6 +216,13 @@ ADS_STRUCT *ads_init(TALLOC_CTX *mem_ctx,
 
        ads->auth.flags = wrap_flags;
 
+       ads->auth.reconnect_state = talloc_zero(ads,
+                                               struct ads_reconnect_state);
+       if (ads->auth.reconnect_state == NULL) {
+               TALLOC_FREE(ads);
+               return NULL;
+       }
+
        /* Start with the configured page size when the connection is new,
         * we will drop it by half we get a timeout.   */
        ads->config.ldap_page_size     = lp_ldap_page_size();
index c08f046a4057c7375a8443beaa588ca04e3f9f38..9d6d962a2bc47d36a208bff8ab9b2c8134d5a9ad 100644 (file)
 #include "includes.h"
 #include "ads.h"
 #include "lib/param/loadparm.h"
+#include "auth/credentials/credentials.h"
 
 #ifdef HAVE_LDAP
 
+void ads_set_reconnect_fn(ADS_STRUCT *ads,
+                         NTSTATUS (*fn)(struct ads_struct *ads,
+                                        void *private_data,
+                                        TALLOC_CTX *mem_ctx,
+                                        struct cli_credentials **creds),
+                         void *private_data)
+{
+       ads->auth.reconnect_state->fn = fn;
+       ads->auth.reconnect_state->private_data = private_data;
+}
+
 static ADS_STATUS ads_ranged_search_internal(ADS_STRUCT *ads,
                                             TALLOC_CTX *mem_ctx,
                                             int scope,
@@ -84,6 +96,9 @@ static ADS_STATUS ads_do_search_retry_internal(ADS_STRUCT *ads, const char *bind
        }
 
        while (--count) {
+               struct cli_credentials *creds = NULL;
+               char *cred_name = NULL;
+               NTSTATUS ntstatus;
 
                if (NT_STATUS_EQUAL(ads_ntstatus(status), NT_STATUS_IO_TIMEOUT) &&
                    ads->config.ldap_page_size >= (lp_ldap_page_size() / 4) &&
@@ -98,24 +113,49 @@ static ADS_STATUS ads_do_search_retry_internal(ADS_STRUCT *ads, const char *bind
                        ads_msgfree(ads, *res);
                *res = NULL;
 
-               DEBUG(3,("Reopening ads connection to realm '%s' after error %s\n", 
-                        ads->config.realm, ads_errstr(status)));
-
                ads_disconnect(ads);
-               status = ads_connect(ads);
 
+               if (ads->auth.reconnect_state->fn == NULL) {
+                       DBG_NOTICE("Search for %s in <%s> failed: %s\n",
+                                  expr, bp, ads_errstr(status));
+                       SAFE_FREE(bp);
+                       return status;
+               }
+
+               ntstatus = ads->auth.reconnect_state->fn(ads,
+                               ads->auth.reconnect_state->private_data,
+                               ads, &creds);
+               if (!NT_STATUS_IS_OK(ntstatus)) {
+                       DBG_WARNING("Failed to get creds for realm(%s): %s\n",
+                                   ads->server.realm, nt_errstr(ntstatus));
+                       DBG_WARNING("Search for %s in <%s> failed: %s\n",
+                                  expr, bp, ads_errstr(status));
+                       SAFE_FREE(bp);
+                       return status;
+               }
+
+               cred_name = cli_credentials_get_unparsed_name(creds, creds);
+               DBG_NOTICE("Reopening ads connection as %s to "
+                          "realm '%s' after error %s\n",
+                          cred_name, ads->server.realm, ads_errstr(status));
+
+               status = ads_connect_creds(ads, creds);
                if (!ADS_ERR_OK(status)) {
-                       DEBUG(1,("ads_search_retry: failed to reconnect (%s)\n",
-                                ads_errstr(status)));
+                       DBG_WARNING("Reconnect ads connection as %s to "
+                                   "realm '%s' failed: %s\n",
+                                   cred_name, ads->server.realm,
+                                   ads_errstr(status));
                        /*
                         * We need to keep the ads pointer
                         * from being freed here as we don't own it and
                         * callers depend on it being around.
                         */
                        ads_disconnect(ads);
+                       TALLOC_FREE(creds);
                        SAFE_FREE(bp);
                        return status;
                }
+               TALLOC_FREE(creds);
 
                *res = NULL;
 
index 81c0f273ec01d55ad69786c6bb25a8ac93c7cfde..e47f2c98ad865f839cfddbdd103f22aef5319e97 100644 (file)
@@ -53,6 +53,7 @@ interface ads
                ads_auth_flags flags;
                string ccache_name;
                NTTIME expire_time;
+               [ignore] struct ads_reconnect_state *reconnect_state;
        } ads_auth;
 
        typedef [nopull,nopush] struct {