source4 samr: cache samr_QueryDisplayInfo results
authorGary Lockyer <gary@catalyst.net.nz>
Tue, 9 Oct 2018 20:20:25 +0000 (09:20 +1300)
committerAndrew Bartlett <abartlet@samba.org>
Tue, 20 Nov 2018 21:14:17 +0000 (22:14 +0100)
Add a cache of GUID's that matched the last samr_QueryDisplayInfo made on a
domain handle.  The cache is cleared if the requested start index is
zero, or if the level does not match that in the cache.

The cache is maintained in the guid_caches array of the dcesrv_handle.

Note: that currently this cache exists for the lifetime of the RPC
      handle.

Signed-off-by: Gary Lockyer <gary@catalyst.net.nz>
Reviewed-by: Andrew Bartlett <abartlet@samba.org>
selftest/knownfail.d/samr [deleted file]
source4/rpc_server/samr/dcesrv_samr.c
source4/rpc_server/samr/dcesrv_samr.h

diff --git a/selftest/knownfail.d/samr b/selftest/knownfail.d/samr
deleted file mode 100644 (file)
index d92b890..0000000
+++ /dev/null
@@ -1,10 +0,0 @@
-samba.tests.dcerpc.sam.samba.tests.dcerpc.sam.SamrTests.test_QueryDisplayInfo_level_1\(ad_dc_ntvfs:local\)
-samba.tests.dcerpc.sam.samba.tests.dcerpc.sam.SamrTests.test_QueryDisplayInfo_level_2\(ad_dc_ntvfs:local\)
-samba.tests.dcerpc.sam.samba.tests.dcerpc.sam.SamrTests.test_QueryDisplayInfo_level_3\(ad_dc_ntvfs:local\)
-samba.tests.dcerpc.sam.samba.tests.dcerpc.sam.SamrTests.test_QueryDisplayInfo_level_4\(ad_dc_ntvfs:local\)
-samba.tests.dcerpc.sam.samba.tests.dcerpc.sam.SamrTests.test_QueryDisplayInfo_level_5\(ad_dc_ntvfs:local\)
-samba.tests.dcerpc.sam.python3.samba.tests.dcerpc.sam.SamrTests.test_QueryDisplayInfo_level_1\(ad_dc_ntvfs:local\)
-samba.tests.dcerpc.sam.python3.samba.tests.dcerpc.sam.SamrTests.test_QueryDisplayInfo_level_2\(ad_dc_ntvfs:local\)
-samba.tests.dcerpc.sam.python3.samba.tests.dcerpc.sam.SamrTests.test_QueryDisplayInfo_level_3\(ad_dc_ntvfs:local\)
-samba.tests.dcerpc.sam.python3.samba.tests.dcerpc.sam.SamrTests.test_QueryDisplayInfo_level_4\(ad_dc_ntvfs:local\)
-samba.tests.dcerpc.sam.python3.samba.tests.dcerpc.sam.SamrTests.test_QueryDisplayInfo_level_5\(ad_dc_ntvfs:local\)
index 3df0c51dfee947fd70a15ddd94675cbdc030b577..bd7ffda7ebb3d8c0851b57d4b961f461d02d8f35 100644 (file)
@@ -145,7 +145,62 @@ static NTSTATUS dcesrv_interface_samr_bind(struct dcesrv_call_state *dce_call,
        }                                                               \
 } while (0)
 
+/*
+ * Clear a GUID cache
+ */
+static void clear_guid_cache(struct samr_guid_cache *cache)
+{
+       cache->handle = 0;
+       cache->size = 0;
+       TALLOC_FREE(cache->entries);
+}
 
+/*
+ * initialize a GUID cache
+ */
+static void initialize_guid_cache(struct samr_guid_cache *cache)
+{
+       cache->handle = 0;
+       cache->size = 0;
+       cache->entries = NULL;
+}
+
+static NTSTATUS load_guid_cache(
+       struct samr_guid_cache *cache,
+       struct samr_domain_state *d_state,
+       unsigned int ldb_cnt,
+       struct ldb_message **res)
+{
+       NTSTATUS status = NT_STATUS_OK;
+       unsigned int i;
+       TALLOC_CTX *frame = talloc_stackframe();
+
+       clear_guid_cache(cache);
+
+       /*
+        * Store the GUID's in the cache.
+        */
+       cache->handle = 0;
+       cache->size = ldb_cnt;
+       cache->entries = talloc_array(d_state, struct GUID, ldb_cnt);
+       if (cache->entries == NULL) {
+               clear_guid_cache(cache);
+               status = NT_STATUS_NO_MEMORY;
+               goto exit;
+       }
+
+       /*
+        * Extract a list of the GUIDs for all the matching objects
+        * we cache just the GUIDS to reduce the memory overhead of
+        * the result cache.
+        */
+       for (i = 0; i < ldb_cnt; i++) {
+               cache->entries[i] = samdb_result_guid(res[i], "objectGUID");
+       }
+exit:
+       TALLOC_FREE(frame);
+       return status;
+}
 
 /*
   samr_Connect
@@ -384,6 +439,7 @@ static NTSTATUS dcesrv_samr_OpenDomain(struct dcesrv_call_state *dce_call, TALLO
        const char * const dom_attrs[] = { "cn", NULL};
        struct ldb_message **dom_msgs;
        int ret;
+       unsigned int i;
 
        ZERO_STRUCTP(r->out.domain_handle);
 
@@ -435,6 +491,10 @@ static NTSTATUS dcesrv_samr_OpenDomain(struct dcesrv_call_state *dce_call, TALLO
 
        d_state->lp_ctx = dce_call->conn->dce_ctx->lp_ctx;
 
+       for (i = 0; i < SAMR_LAST_CACHE; i++) {
+               initialize_guid_cache(&d_state->guid_caches[i]);
+       }
+
        h_domain = dcesrv_handle_new(dce_call->context, SAMR_HANDLE_DOMAIN);
        if (!h_domain) {
                talloc_free(d_state);
@@ -3707,88 +3767,164 @@ static NTSTATUS dcesrv_samr_GetGroupsForUser(struct dcesrv_call_state *dce_call,
        return NT_STATUS_OK;
 }
 
-
 /*
-  samr_QueryDisplayInfo
-*/
+ * samr_QueryDisplayInfo
+ *
+ * A cache of the GUID's matching the last query is maintained
+ * in the SAMR_QUERY_DISPLAY_INFO_CACHE guid_cache maintained o
+ * n the dcesrv_handle.
+ */
 static NTSTATUS dcesrv_samr_QueryDisplayInfo(struct dcesrv_call_state *dce_call, TALLOC_CTX *mem_ctx,
                       struct samr_QueryDisplayInfo *r)
 {
        struct dcesrv_handle *h;
        struct samr_domain_state *d_state;
        struct ldb_result *res;
-       unsigned int i;
-       uint32_t count;
-       const char * const attrs[] = { "objectSid", "sAMAccountName",
-               "displayName", "description", "userAccountControl",
-               "pwdLastSet", NULL };
+       uint32_t i;
+       uint32_t results = 0;
+       uint32_t count = 0;
+       const char *const cache_attrs[] = {"objectGUID", NULL};
+       const char *const attrs[] = {
+           "objectSID", "sAMAccountName", "displayName", "description", NULL};
        struct samr_DispEntryFull *entriesFull = NULL;
        struct samr_DispEntryFullGroup *entriesFullGroup = NULL;
        struct samr_DispEntryAscii *entriesAscii = NULL;
        struct samr_DispEntryGeneral *entriesGeneral = NULL;
        const char *filter;
        int ret;
+       NTSTATUS status;
+       struct samr_guid_cache *cache = NULL;
 
        DCESRV_PULL_HANDLE(h, r->in.domain_handle, SAMR_HANDLE_DOMAIN);
 
        d_state = h->data;
 
-       switch (r->in.level) {
-       case 1:
-       case 4:
-               filter = talloc_asprintf(mem_ctx, "(&(objectclass=user)"
-                                        "(sAMAccountType=%d))",
-                                        ATYPE_NORMAL_ACCOUNT);
-               break;
-       case 2:
-               filter = talloc_asprintf(mem_ctx, "(&(objectclass=user)"
-                                        "(sAMAccountType=%d))",
-                                        ATYPE_WORKSTATION_TRUST);
-               break;
-       case 3:
-       case 5:
-               filter = talloc_asprintf(mem_ctx,
-                                        "(&(|(groupType=%d)(groupType=%d))"
-                                        "(objectClass=group))",
-                                        GTYPE_SECURITY_UNIVERSAL_GROUP,
-                                        GTYPE_SECURITY_GLOBAL_GROUP);
-               break;
-       default:
-               return NT_STATUS_INVALID_INFO_CLASS;
-       }
+       cache = &d_state->guid_caches[SAMR_QUERY_DISPLAY_INFO_CACHE];
+       /*
+        * Can the cached results be used?
+        * The cache is discarded if the start index is zero, or the requested
+        * level is different from that in the cache.
+        */
+       if ((r->in.start_idx == 0) || (r->in.level != cache->handle)) {
+               /*
+                * The cached results can not be used, so will need to query
+                * the database.
+                */
 
-       /* search for all requested objects in all domains. This could
-          possibly be cached and resumed based on resume_key */
-       ret = dsdb_search(d_state->sam_ctx, mem_ctx, &res, ldb_get_default_basedn(d_state->sam_ctx),
-                         LDB_SCOPE_SUBTREE, attrs, 0, "%s", filter);
-       if (ret != LDB_SUCCESS) {
-               return NT_STATUS_INTERNAL_DB_CORRUPTION;
+               /*
+                * Get the search filter for the current level
+                */
+               switch (r->in.level) {
+               case 1:
+               case 4:
+                       filter = talloc_asprintf(mem_ctx,
+                                                "(&(objectclass=user)"
+                                                "(sAMAccountType=%d))",
+                                                ATYPE_NORMAL_ACCOUNT);
+                       break;
+               case 2:
+                       filter = talloc_asprintf(mem_ctx,
+                                                "(&(objectclass=user)"
+                                                "(sAMAccountType=%d))",
+                                                ATYPE_WORKSTATION_TRUST);
+                       break;
+               case 3:
+               case 5:
+                       filter =
+                           talloc_asprintf(mem_ctx,
+                                           "(&(|(groupType=%d)(groupType=%d))"
+                                           "(objectClass=group))",
+                                           GTYPE_SECURITY_UNIVERSAL_GROUP,
+                                           GTYPE_SECURITY_GLOBAL_GROUP);
+                       break;
+               default:
+                       return NT_STATUS_INVALID_INFO_CLASS;
+               }
+               clear_guid_cache(cache);
+
+               /*
+                * search for all requested objects in all domains.
+                */
+               ret = dsdb_search(d_state->sam_ctx,
+                                 mem_ctx,
+                                 &res,
+                                 ldb_get_default_basedn(d_state->sam_ctx),
+                                 LDB_SCOPE_SUBTREE,
+                                 cache_attrs,
+                                 0,
+                                 "%s",
+                                 filter);
+               if (ret != LDB_SUCCESS) {
+                       return NT_STATUS_INTERNAL_DB_CORRUPTION;
+               }
+               if ((res->count == 0) || (r->in.max_entries == 0)) {
+                       return NT_STATUS_OK;
+               }
+
+               status = load_guid_cache(cache, d_state, res->count, res->msgs);
+               TALLOC_FREE(res);
+               if (!NT_STATUS_IS_OK(status)) {
+                       return status;
+               }
+               cache->handle = r->in.level;
        }
-       if ((res->count == 0) || (r->in.max_entries == 0)) {
+       *r->out.total_size = cache->size;
+
+       /*
+        * if there are no entries or the requested start index is greater
+        * than the number of entries, we return an empty response.
+        */
+       if (r->in.start_idx >= cache->size) {
+               *r->out.returned_size = 0;
+               switch(r->in.level) {
+               case 1:
+                       r->out.info->info1.count = *r->out.returned_size;
+                       r->out.info->info1.entries = NULL;
+                       break;
+               case 2:
+                       r->out.info->info2.count = *r->out.returned_size;
+                       r->out.info->info2.entries = NULL;
+                       break;
+               case 3:
+                       r->out.info->info3.count = *r->out.returned_size;
+                       r->out.info->info3.entries = NULL;
+                       break;
+               case 4:
+                       r->out.info->info4.count = *r->out.returned_size;
+                       r->out.info->info4.entries = NULL;
+                       break;
+               case 5:
+                       r->out.info->info5.count = *r->out.returned_size;
+                       r->out.info->info5.entries = NULL;
+                       break;
+               }
                return NT_STATUS_OK;
        }
 
+       /*
+        * Allocate an array of the appropriate result structures for the
+        * current query level.
+        *
+        * r->in.start_idx is always < cache->size due to the check above
+        */
+       results = MIN((cache->size - r->in.start_idx), r->in.max_entries);
        switch (r->in.level) {
        case 1:
-               entriesGeneral = talloc_array(mem_ctx,
-                                             struct samr_DispEntryGeneral,
-                                             res->count);
+               entriesGeneral = talloc_array(
+                   mem_ctx, struct samr_DispEntryGeneral, results);
                break;
        case 2:
-               entriesFull = talloc_array(mem_ctx,
-                                          struct samr_DispEntryFull,
-                                          res->count);
+               entriesFull =
+                   talloc_array(mem_ctx, struct samr_DispEntryFull, results);
                break;
        case 3:
-               entriesFullGroup = talloc_array(mem_ctx,
-                                               struct samr_DispEntryFullGroup,
-                                               res->count);
+               entriesFullGroup = talloc_array(
+                   mem_ctx, struct samr_DispEntryFullGroup, results);
                break;
        case 4:
        case 5:
-               entriesAscii = talloc_array(mem_ctx,
-                                           struct samr_DispEntryAscii,
-                                           res->count);
+               entriesAscii =
+                   talloc_array(mem_ctx, struct samr_DispEntryAscii, results);
                break;
        }
 
@@ -3796,135 +3932,157 @@ static NTSTATUS dcesrv_samr_QueryDisplayInfo(struct dcesrv_call_state *dce_call,
            (entriesAscii == NULL) && (entriesFullGroup == NULL))
                return NT_STATUS_NO_MEMORY;
 
+       /*
+        * Process the list of result GUID's.
+        * Read the details of each object and populate the result structure
+        * for the current level.
+        */
        count = 0;
-
-       for (i = 0; i < res->count; i++) {
+       for (i = 0; i < results; i++) {
                struct dom_sid *objectsid;
+               struct ldb_result *rec;
+               const uint32_t idx = r->in.start_idx + i;
 
-               objectsid = samdb_result_dom_sid(mem_ctx, res->msgs[i],
-                                                "objectSid");
-               if (objectsid == NULL)
+               /*
+                * Read an object from disk using the GUID as the key
+                *
+                * If the object can not be read, or it does not have a SID
+                * it is ignored.  In this case the number of entries returned
+                * will be less than the requested size, there will also be
+                * a gap in the idx numbers in the returned elements e.g. if
+                * there are 3 GUIDs a, b, c in the cache and b is deleted from
+                * disk then details for a, and c will be returned with
+                * idx values of 1 and 3 respectively.
+                *
+                */
+               ret = dsdb_search_by_dn_guid(d_state->sam_ctx,
+                                            mem_ctx,
+                                            &rec,
+                                            &cache->entries[idx],
+                                            attrs,
+                                            0);
+               if (ret == LDB_ERR_NO_SUCH_OBJECT) {
+                       struct GUID_txt_buf guid_buf;
+                       char *guid_str =
+                               GUID_buf_string(&cache->entries[idx],
+                                               &guid_buf);
+                       DBG_WARNING("GUID [%s] not found\n", guid_str);
                        continue;
+               } else if (ret != LDB_SUCCESS) {
+                       clear_guid_cache(cache);
+                       return NT_STATUS_INTERNAL_DB_CORRUPTION;
+               }
+               objectsid = samdb_result_dom_sid(mem_ctx,
+                                                rec->msgs[0],
+                                                "objectSID");
+               if (objectsid == NULL) {
+                       struct GUID_txt_buf guid_buf;
+                       char *guid_str =
+                               GUID_buf_string(&cache->entries[idx],
+                                               &guid_buf);
+                       DBG_WARNING("objectSID for GUID [%s] not found\n",
+                                   guid_str);
+                       continue;
+               }
 
+               /*
+                * Populate the result structure for the current object
+                */
                switch(r->in.level) {
                case 1:
-                       entriesGeneral[count].idx = count + 1;
+
+                       entriesGeneral[count].idx = idx + 1;
                        entriesGeneral[count].rid =
-                               objectsid->sub_auths[objectsid->num_auths-1];
+                           objectsid->sub_auths[objectsid->num_auths - 1];
                        entriesGeneral[count].acct_flags =
-                               samdb_result_acct_flags(res->msgs[i], NULL);
+                           samdb_result_acct_flags(rec->msgs[0], NULL);
                        entriesGeneral[count].account_name.string =
-                               ldb_msg_find_attr_as_string(res->msgs[i],
-                                                           "sAMAccountName", "");
+                           ldb_msg_find_attr_as_string(
+                               rec->msgs[0], "sAMAccountName", "");
                        entriesGeneral[count].full_name.string =
-                               ldb_msg_find_attr_as_string(res->msgs[i],
-                                                           "displayName", "");
+                           ldb_msg_find_attr_as_string(
+                               rec->msgs[0], "displayName", "");
                        entriesGeneral[count].description.string =
-                               ldb_msg_find_attr_as_string(res->msgs[i],
-                                                           "description", "");
+                           ldb_msg_find_attr_as_string(
+                               rec->msgs[0], "description", "");
                        break;
                case 2:
-                       entriesFull[count].idx = count + 1;
+                       entriesFull[count].idx = idx + 1;
                        entriesFull[count].rid =
-                               objectsid->sub_auths[objectsid->num_auths-1];
+                           objectsid->sub_auths[objectsid->num_auths - 1];
 
-                       /* No idea why we need to or in ACB_NORMAL here, but this is what Win2k3 seems to do... */
+                       /*
+                        * No idea why we need to or in ACB_NORMAL here,
+                        * but this is what Win2k3 seems to do...
+                        */
                        entriesFull[count].acct_flags =
-                               samdb_result_acct_flags(res->msgs[i],
-                                                       NULL) | ACB_NORMAL;
+                           samdb_result_acct_flags(rec->msgs[0], NULL) |
+                           ACB_NORMAL;
                        entriesFull[count].account_name.string =
-                               ldb_msg_find_attr_as_string(res->msgs[i],
-                                                           "sAMAccountName", "");
+                           ldb_msg_find_attr_as_string(
+                               rec->msgs[0], "sAMAccountName", "");
                        entriesFull[count].description.string =
-                               ldb_msg_find_attr_as_string(res->msgs[i],
-                                                           "description", "");
+                           ldb_msg_find_attr_as_string(
+                               rec->msgs[0], "description", "");
                        break;
                case 3:
-                       entriesFullGroup[count].idx = count + 1;
+                       entriesFullGroup[count].idx = idx + 1;
                        entriesFullGroup[count].rid =
-                               objectsid->sub_auths[objectsid->num_auths-1];
-                       /* We get a "7" here for groups */
-                       entriesFullGroup[count].acct_flags
-                               = SE_GROUP_MANDATORY | SE_GROUP_ENABLED_BY_DEFAULT | SE_GROUP_ENABLED;
+                           objectsid->sub_auths[objectsid->num_auths - 1];
+                       /*
+                        * We get a "7" here for groups
+                        */
+                       entriesFullGroup[count].acct_flags =
+                           SE_GROUP_MANDATORY | SE_GROUP_ENABLED_BY_DEFAULT |
+                           SE_GROUP_ENABLED;
                        entriesFullGroup[count].account_name.string =
-                               ldb_msg_find_attr_as_string(res->msgs[i],
-                                                           "sAMAccountName", "");
+                           ldb_msg_find_attr_as_string(
+                               rec->msgs[0], "sAMAccountName", "");
                        entriesFullGroup[count].description.string =
-                               ldb_msg_find_attr_as_string(res->msgs[i],
-                                                           "description", "");
+                           ldb_msg_find_attr_as_string(
+                               rec->msgs[0], "description", "");
                        break;
                case 4:
                case 5:
-                       entriesAscii[count].idx = count + 1;
+                       entriesAscii[count].idx = idx + 1;
                        entriesAscii[count].account_name.string =
-                               ldb_msg_find_attr_as_string(res->msgs[i],
-                                                           "sAMAccountName", "");
+                           ldb_msg_find_attr_as_string(
+                               rec->msgs[0], "sAMAccountName", "");
                        break;
                }
-
-               count += 1;
+               count++;
        }
 
-       *r->out.total_size = count;
-
-       if (r->in.start_idx >= count) {
-               *r->out.returned_size = 0;
-               switch(r->in.level) {
-               case 1:
-                       r->out.info->info1.count = *r->out.returned_size;
-                       r->out.info->info1.entries = NULL;
-                       break;
-               case 2:
-                       r->out.info->info2.count = *r->out.returned_size;
-                       r->out.info->info2.entries = NULL;
-                       break;
-               case 3:
-                       r->out.info->info3.count = *r->out.returned_size;
-                       r->out.info->info3.entries = NULL;
-                       break;
-               case 4:
-                       r->out.info->info4.count = *r->out.returned_size;
-                       r->out.info->info4.entries = NULL;
-                       break;
-               case 5:
-                       r->out.info->info5.count = *r->out.returned_size;
-                       r->out.info->info5.entries = NULL;
-                       break;
-               }
-       } else {
-               *r->out.returned_size = MIN(count - r->in.start_idx,
-                                          r->in.max_entries);
-               switch(r->in.level) {
-               case 1:
-                       r->out.info->info1.count = *r->out.returned_size;
-                       r->out.info->info1.entries =
-                               &(entriesGeneral[r->in.start_idx]);
-                       break;
-               case 2:
-                       r->out.info->info2.count = *r->out.returned_size;
-                       r->out.info->info2.entries =
-                               &(entriesFull[r->in.start_idx]);
-                       break;
-               case 3:
-                       r->out.info->info3.count = *r->out.returned_size;
-                       r->out.info->info3.entries =
-                               &(entriesFullGroup[r->in.start_idx]);
-                       break;
-               case 4:
-                       r->out.info->info4.count = *r->out.returned_size;
-                       r->out.info->info4.entries =
-                               &(entriesAscii[r->in.start_idx]);
-                       break;
-               case 5:
-                       r->out.info->info5.count = *r->out.returned_size;
-                       r->out.info->info5.entries =
-                               &(entriesAscii[r->in.start_idx]);
-                       break;
-               }
+       /*
+        * Build the response based on the request level.
+        */
+       *r->out.returned_size = count;
+       switch(r->in.level) {
+       case 1:
+               r->out.info->info1.count = count;
+               r->out.info->info1.entries = entriesGeneral;
+               break;
+       case 2:
+               r->out.info->info2.count = count;
+               r->out.info->info2.entries = entriesFull;
+               break;
+       case 3:
+               r->out.info->info3.count = count;
+               r->out.info->info3.entries = entriesFullGroup;
+               break;
+       case 4:
+               r->out.info->info4.count = count;
+               r->out.info->info4.entries = entriesAscii;
+               break;
+       case 5:
+               r->out.info->info5.count = count;
+               r->out.info->info5.entries = entriesAscii;
+               break;
        }
 
-       return (*r->out.returned_size < (count - r->in.start_idx)) ?
-               STATUS_MORE_ENTRIES : NT_STATUS_OK;
+       return ((r->in.start_idx + results) < cache->size)
+                  ? STATUS_MORE_ENTRIES
+                  : NT_STATUS_OK;
 }
 
 
index 261bd052efe6a67375da0344734c373965ce1f60..f08bac053c80341da7cd7fefa6530f17dd12e71e 100644 (file)
@@ -42,6 +42,20 @@ struct samr_connect_state {
        uint32_t access_mask;
 };
 
+/*
+ * Cache of object GUIDS
+ */
+struct samr_guid_cache {
+       unsigned handle;
+       unsigned size;
+       struct GUID *entries;
+};
+
+enum samr_guid_cache_id {
+       SAMR_QUERY_DISPLAY_INFO_CACHE,
+       SAMR_LAST_CACHE
+};
+
 /*
   state associated with a samr_OpenDomain() operation
 */
@@ -55,6 +69,7 @@ struct samr_domain_state {
        enum server_role role;
        bool builtin;
        struct loadparm_context *lp_ctx;
+       struct samr_guid_cache guid_caches[SAMR_LAST_CACHE];
 };
 
 /*