s4:rpc_server/lsa: prepare dcesrv_lsa_LookupSids* for async processing
authorStefan Metzmacher <metze@samba.org>
Fri, 19 Jan 2018 12:42:40 +0000 (13:42 +0100)
committerRalph Boehme <slow@samba.org>
Wed, 21 Feb 2018 13:19:18 +0000 (14:19 +0100)
Bug: https://bugzilla.samba.org/show_bug.cgi?id=13286

Signed-off-by: Stefan Metzmacher <metze@samba.org>
Reviewed-by: Ralph Boehme <slow@samba.org>
source4/rpc_server/lsa/lsa_lookup.c

index bb14cf7a16ba24c3790b676cb88275aa878b932d..62d29d3173ebfd28487de4fabb27e0bbdad1f43c 100644 (file)
@@ -587,11 +587,27 @@ static NTSTATUS dcesrv_lsa_lookup_sid(struct lsa_policy_state *state, TALLOC_CTX
        return NT_STATUS_OK;
 }
 
-static NTSTATUS dcesrv_lsa_LookupSids_common(struct dcesrv_call_state *dce_call,
-                                            TALLOC_CTX *mem_ctx,
-                                            struct lsa_policy_state *policy_state,
-                                            struct lsa_LookupSids2 *r)
+struct dcesrv_lsa_LookupSids_base_state {
+       struct dcesrv_call_state *dce_call;
+
+       TALLOC_CTX *mem_ctx;
+
+       struct lsa_policy_state *policy_state;
+
+       struct lsa_LookupSids3 r;
+
+       struct {
+               struct lsa_LookupSids *l;
+               struct lsa_LookupSids2 *l2;
+               struct lsa_LookupSids3 *l3;
+       } _r;
+};
+
+static NTSTATUS dcesrv_lsa_LookupSids_base_call(struct dcesrv_lsa_LookupSids_base_state *state)
 {
+       struct lsa_policy_state *policy_state = state->policy_state;
+       TALLOC_CTX *mem_ctx = state->mem_ctx;
+       struct lsa_LookupSids3 *r = &state->r;
        struct lsa_RefDomainList *domains = NULL;
        uint32_t i;
 
@@ -675,6 +691,45 @@ static NTSTATUS dcesrv_lsa_LookupSids_common(struct dcesrv_call_state *dce_call,
        return NT_STATUS_OK;
 }
 
+static void dcesrv_lsa_LookupSids_base_map(
+       struct dcesrv_lsa_LookupSids_base_state *state)
+{
+       if (state->_r.l3 != NULL) {
+               struct lsa_LookupSids3 *r = state->_r.l3;
+
+               r->out.result = state->r.out.result;
+               return;
+       }
+
+       if (state->_r.l2 != NULL) {
+               struct lsa_LookupSids2 *r = state->_r.l2;
+
+               r->out.result = state->r.out.result;
+               return;
+       }
+
+       if (state->_r.l != NULL) {
+               struct lsa_LookupSids *r = state->_r.l;
+               uint32_t i;
+
+               r->out.result = state->r.out.result;
+
+               SMB_ASSERT(state->r.out.names->count <= r->in.sids->num_sids);
+               for (i = 0; i < state->r.out.names->count; i++) {
+                       struct lsa_TranslatedName2 *n2 =
+                               &state->r.out.names->names[i];
+                       struct lsa_TranslatedName *n =
+                               &r->out.names->names[i];
+
+                       n->sid_type = n2->sid_type;
+                       n->name = n2->name;
+                       n->sid_index = n2->sid_index;
+               }
+               r->out.names->count = state->r.out.names->count;
+               return;
+       }
+}
+
 /*
   lsa_LookupSids2
 */
@@ -684,8 +739,9 @@ NTSTATUS dcesrv_lsa_LookupSids2(struct dcesrv_call_state *dce_call,
 {
        enum dcerpc_transport_t transport =
                dcerpc_binding_get_transport(dce_call->conn->endpoint->ep_description);
-       struct lsa_policy_state *policy_state = NULL;
+       struct dcesrv_lsa_LookupSids_base_state *state = NULL;
        struct dcesrv_handle *policy_handle = NULL;
+       NTSTATUS status;
 
        if (transport != NCACN_NP && transport != NCALRPC) {
                DCESRV_FAULT(DCERPC_FAULT_ACCESS_DENIED);
@@ -693,12 +749,43 @@ NTSTATUS dcesrv_lsa_LookupSids2(struct dcesrv_call_state *dce_call,
 
        DCESRV_PULL_HANDLE(policy_handle, r->in.handle, LSA_HANDLE_POLICY);
 
-       policy_state = policy_handle->data;
+       *r->out.domains = NULL;
+       r->out.names->count = 0;
+       r->out.names->names = NULL;
+       *r->out.count = 0;
+
+       state = talloc_zero(mem_ctx, struct dcesrv_lsa_LookupSids_base_state);
+       if (state == NULL) {
+               return NT_STATUS_NO_MEMORY;
+       }
+
+       state->dce_call = dce_call;
+       state->mem_ctx = mem_ctx;
+
+       state->policy_state = policy_handle->data;
+
+       state->r.in.sids = r->in.sids;
+       state->r.in.level = r->in.level;
+       state->r.in.lookup_options = r->in.lookup_options;
+       state->r.in.client_revision = r->in.client_revision;
+       state->r.in.names = r->in.names;
+       state->r.in.count = r->in.count;
+       state->r.out.domains = r->out.domains;
+       state->r.out.names = r->out.names;
+       state->r.out.count = r->out.count;
+
+       state->_r.l2 = r;
 
-       return dcesrv_lsa_LookupSids_common(dce_call,
-                                           mem_ctx,
-                                           policy_state,
-                                           r);
+       status = dcesrv_lsa_LookupSids_base_call(state);
+
+       if (dce_call->state_flags & DCESRV_CALL_STATE_FLAG_ASYNC) {
+               return status;
+       }
+
+       state->r.out.result = status;
+       dcesrv_lsa_LookupSids_base_map(state);
+       TALLOC_FREE(state);
+       return status;
 }
 
 
@@ -715,8 +802,7 @@ NTSTATUS dcesrv_lsa_LookupSids3(struct dcesrv_call_state *dce_call,
        enum dcerpc_transport_t transport =
                dcerpc_binding_get_transport(dce_call->conn->endpoint->ep_description);
        const struct dcesrv_auth *auth = &dce_call->conn->auth_state;
-       struct lsa_policy_state *policy_state;
-       struct lsa_LookupSids2 q;
+       struct dcesrv_lsa_LookupSids_base_state *state = NULL;
        NTSTATUS status;
 
        if (transport != NCACN_IP_TCP) {
@@ -737,37 +823,42 @@ NTSTATUS dcesrv_lsa_LookupSids3(struct dcesrv_call_state *dce_call,
        r->out.names->names = NULL;
        *r->out.count = 0;
 
-       status = dcesrv_lsa_get_policy_state(dce_call, mem_ctx,
+       state = talloc_zero(mem_ctx, struct dcesrv_lsa_LookupSids_base_state);
+       if (state == NULL) {
+               return NT_STATUS_NO_MEMORY;
+       }
+
+       state->dce_call = dce_call;
+       state->mem_ctx = mem_ctx;
+
+       status = dcesrv_lsa_get_policy_state(state->dce_call, mem_ctx,
                                             0, /* we skip access checks */
-                                            &policy_state);
+                                            &state->policy_state);
        if (!NT_STATUS_IS_OK(status)) {
                return status;
        }
 
-       ZERO_STRUCT(q);
+       state->r.in.sids = r->in.sids;
+       state->r.in.level = r->in.level;
+       state->r.in.lookup_options = r->in.lookup_options;
+       state->r.in.client_revision = r->in.client_revision;
+       state->r.in.names = r->in.names;
+       state->r.in.count = r->in.count;
+       state->r.out.domains = r->out.domains;
+       state->r.out.names = r->out.names;
+       state->r.out.count = r->out.count;
 
-       q.in.handle   = NULL;
-       q.in.sids     = r->in.sids;
-       q.in.names    = r->in.names;
-       q.in.level    = r->in.level;
-       q.in.count    = r->in.count;
-       q.in.lookup_options = r->in.lookup_options;
-       q.in.client_revision = r->in.client_revision;
-       q.out.count   = r->out.count;
-       q.out.names   = r->out.names;
-       q.out.domains = r->out.domains;
+       state->_r.l3 = r;
 
-       status = dcesrv_lsa_LookupSids_common(dce_call,
-                                             mem_ctx,
-                                             policy_state,
-                                             &q);
-
-       talloc_free(policy_state);
+       status = dcesrv_lsa_LookupSids_base_call(state);
 
-       r->out.count = q.out.count;
-       r->out.names = q.out.names;
-       r->out.domains = q.out.domains;
+       if (dce_call->state_flags & DCESRV_CALL_STATE_FLAG_ASYNC) {
+               return status;
+       }
 
+       state->r.out.result = status;
+       dcesrv_lsa_LookupSids_base_map(state);
+       TALLOC_FREE(state);
        return status;
 }
 
@@ -780,14 +871,16 @@ NTSTATUS dcesrv_lsa_LookupSids(struct dcesrv_call_state *dce_call, TALLOC_CTX *m
 {
        enum dcerpc_transport_t transport =
                dcerpc_binding_get_transport(dce_call->conn->endpoint->ep_description);
-       struct lsa_LookupSids2 r2;
+       struct dcesrv_lsa_LookupSids_base_state *state = NULL;
+       struct dcesrv_handle *policy_handle = NULL;
        NTSTATUS status;
-       uint32_t i;
 
        if (transport != NCACN_NP && transport != NCALRPC) {
                DCESRV_FAULT(DCERPC_FAULT_ACCESS_DENIED);
        }
 
+       DCESRV_PULL_HANDLE(policy_handle, r->in.handle, LSA_HANDLE_POLICY);
+
        *r->out.domains = NULL;
        r->out.names->count = 0;
        r->out.names->names = NULL;
@@ -800,37 +893,43 @@ NTSTATUS dcesrv_lsa_LookupSids(struct dcesrv_call_state *dce_call, TALLOC_CTX *m
                return NT_STATUS_NO_MEMORY;
        }
 
-       ZERO_STRUCT(r2);
+       state = talloc_zero(mem_ctx, struct dcesrv_lsa_LookupSids_base_state);
+       if (state == NULL) {
+               return NT_STATUS_NO_MEMORY;
+       }
+
+       state->dce_call = dce_call;
+       state->mem_ctx = mem_ctx;
+
+       state->policy_state = policy_handle->data;
 
-       r2.in.handle   = r->in.handle;
-       r2.in.sids     = r->in.sids;
-       r2.in.names    = talloc_zero(mem_ctx, struct lsa_TransNameArray2);
-       if (r2.in.names == NULL) {
+       state->r.in.sids = r->in.sids;
+       state->r.in.level = r->in.level;
+       state->r.in.lookup_options = LSA_LOOKUP_OPTION_SEARCH_ISOLATED_NAMES;
+       state->r.in.client_revision = LSA_CLIENT_REVISION_1;
+       state->r.in.names = talloc_zero(state, struct lsa_TransNameArray2);
+       if (state->r.in.names == NULL) {
                return NT_STATUS_NO_MEMORY;
        }
-       r2.in.level    = r->in.level;
-       r2.in.count    = r->in.count;
-       r2.in.lookup_options = LSA_LOOKUP_OPTION_SEARCH_ISOLATED_NAMES;
-       r2.in.client_revision = LSA_CLIENT_REVISION_1;
-       r2.out.count   = r->out.count;
-       r2.out.names   = talloc_zero(mem_ctx, struct lsa_TransNameArray2);
-       if (r2.out.names == NULL) {
+       state->r.in.count = r->in.count;
+       state->r.out.domains = r->out.domains;
+       state->r.out.names = talloc_zero(state, struct lsa_TransNameArray2);
+       if (state->r.out.names == NULL) {
                return NT_STATUS_NO_MEMORY;
        }
-       r2.out.domains = r->out.domains;
+       state->r.out.count = r->out.count;
 
-       status = dcesrv_lsa_LookupSids2(dce_call, mem_ctx, &r2);
-       /* we deliberately don't check for error from the above,
-          as even on error we are supposed to return the names  */
+       state->_r.l = r;
 
-       SMB_ASSERT(r2.out.names->count <= r->in.sids->num_sids);
-       for (i=0;i<r2.out.names->count;i++) {
-               r->out.names->names[i].sid_type    = r2.out.names->names[i].sid_type;
-               r->out.names->names[i].name.string = r2.out.names->names[i].name.string;
-               r->out.names->names[i].sid_index   = r2.out.names->names[i].sid_index;
+       status = dcesrv_lsa_LookupSids_base_call(state);
+
+       if (dce_call->state_flags & DCESRV_CALL_STATE_FLAG_ASYNC) {
+               return status;
        }
-       r->out.names->count = r2.out.names->count;
 
+       state->r.out.result = status;
+       dcesrv_lsa_LookupSids_base_map(state);
+       TALLOC_FREE(state);
        return status;
 }