rpc_server3: Use fdopen_keepfd()
[samba.git] / source3 / rpc_server / rpc_worker.c
index 4f47a0ad4f3b8a6fef78496f21e84f8b52b1e1ba..81c7a90972180391a53fd6e2b3412e80c79ecfce 100644 (file)
@@ -24,6 +24,7 @@
 #include "source3/librpc/gen_ndr/ndr_rpc_host.h"
 #include "lib/util/debug.h"
 #include "lib/util/fault.h"
+#include "lib/util/util_file.h"
 #include "rpc_server.h"
 #include "rpc_pipes.h"
 #include "source3/smbd/proto.h"
@@ -95,11 +96,13 @@ static void rpc_worker_print_interface(
 
 static NTSTATUS rpc_worker_report_status(struct rpc_worker *worker)
 {
-       uint8_t buf[9];
+       uint8_t buf[16];
        DATA_BLOB blob = { .data = buf, .length = sizeof(buf), };
        enum ndr_err_code ndr_err;
        NTSTATUS status;
 
+       worker->status.num_association_groups = worker->dce_ctx->assoc_groups_num;
+
        if (DEBUGLEVEL >= 10) {
                NDR_PRINT_DEBUG(rpc_worker_status, &worker->status);
        }
@@ -129,7 +132,19 @@ static void rpc_worker_connection_terminated(
        NTSTATUS status;
        bool found = false;
 
-       SMB_ASSERT(worker->status.num_clients > 0);
+       /*
+        * We need to drop the association group reference
+        * explicitly here in order to avoid the order given
+        * by the destructors. rpc_worker_report_status() below,
+        * expects worker->dce_ctx->assoc_groups_num to be updated
+        * already.
+        */
+       if (conn->assoc_group != NULL) {
+               talloc_unlink(conn, conn->assoc_group);
+               conn->assoc_group = NULL;
+       }
+
+       SMB_ASSERT(worker->status.num_connections > 0);
 
        for (w = worker->conns; w != NULL; w = w->next) {
                if (w == ncacn_conn) {
@@ -141,7 +156,7 @@ static void rpc_worker_connection_terminated(
 
        DLIST_REMOVE(worker->conns, ncacn_conn);
 
-       worker->status.num_clients -= 1;
+       worker->status.num_connections -= 1;
 
        status = rpc_worker_report_status(worker);
        if (!NT_STATUS_IS_OK(status)) {
@@ -172,7 +187,7 @@ static void rpc_worker_new_client(
        int sock)
 {
        struct dcesrv_context *dce_ctx = worker->dce_ctx;
-       struct named_pipe_auth_req_info7 *info7 = client->npa_info7;
+       struct named_pipe_auth_req_info8 *info8 = client->npa_info8;
        struct tsocket_address *remote_client_addr = NULL;
        struct tsocket_address *local_server_addr = NULL;
        struct dcerpc_binding *b = NULL;
@@ -265,84 +280,84 @@ static void rpc_worker_new_client(
 
        if (transport == NCALRPC) {
                ret = tsocket_address_unix_from_path(ncacn_conn,
-                                                    info7->remote_client_addr,
+                                                    info8->remote_client_addr,
                                                     &remote_client_addr);
                if (ret == -1) {
                        DBG_DEBUG("tsocket_address_unix_from_path"
                                  "(%s) failed: %s\n",
-                                 info7->remote_client_addr,
+                                 info8->remote_client_addr,
                                  strerror(errno));
                        goto fail;
                }
 
                ncacn_conn->remote_client_name =
-                       talloc_strdup(ncacn_conn, info7->remote_client_name);
+                       talloc_strdup(ncacn_conn, info8->remote_client_name);
                if (ncacn_conn->remote_client_name == NULL) {
                        DBG_DEBUG("talloc_strdup(%s) failed\n",
-                                 info7->remote_client_name);
+                                 info8->remote_client_name);
                        goto fail;
                }
 
                ret = tsocket_address_unix_from_path(ncacn_conn,
-                                                    info7->local_server_addr,
+                                                    info8->local_server_addr,
                                                     &local_server_addr);
                if (ret == -1) {
                        DBG_DEBUG("tsocket_address_unix_from_path"
                                  "(%s) failed: %s\n",
-                                 info7->local_server_addr,
+                                 info8->local_server_addr,
                                  strerror(errno));
                        goto fail;
                }
 
                ncacn_conn->local_server_name =
-                       talloc_strdup(ncacn_conn, info7->local_server_name);
+                       talloc_strdup(ncacn_conn, info8->local_server_name);
                if (ncacn_conn->local_server_name == NULL) {
                        DBG_DEBUG("talloc_strdup(%s) failed\n",
-                                 info7->local_server_name);
+                                 info8->local_server_name);
                        goto fail;
                }
        } else {
                ret = tsocket_address_inet_from_strings(
                        ncacn_conn,
                        "ip",
-                       info7->remote_client_addr,
-                       info7->remote_client_port,
+                       info8->remote_client_addr,
+                       info8->remote_client_port,
                        &remote_client_addr);
                if (ret == -1) {
                        DBG_DEBUG("tsocket_address_inet_from_strings"
                                  "(%s, %" PRIu16 ") failed: %s\n",
-                                 info7->remote_client_addr,
-                                 info7->remote_client_port,
+                                 info8->remote_client_addr,
+                                 info8->remote_client_port,
                                  strerror(errno));
                        goto fail;
                }
                ncacn_conn->remote_client_name =
-                       talloc_strdup(ncacn_conn, info7->remote_client_name);
+                       talloc_strdup(ncacn_conn, info8->remote_client_name);
                if (ncacn_conn->remote_client_name == NULL) {
                        DBG_DEBUG("talloc_strdup(%s) failed\n",
-                                 info7->remote_client_name);
+                                 info8->remote_client_name);
                        goto fail;
                }
 
                ret = tsocket_address_inet_from_strings(
                        ncacn_conn,
                        "ip",
-                       info7->local_server_addr,
-                       info7->local_server_port,
+                       info8->local_server_addr,
+                       info8->local_server_port,
                        &local_server_addr);
                if (ret == -1) {
                        DBG_DEBUG("tsocket_address_inet_from_strings"
                                  "(%s, %" PRIu16 ") failed: %s\n",
-                                 info7->local_server_addr,
-                                 info7->local_server_port,
+                                 info8->local_server_addr,
+                                 info8->local_server_port,
                                  strerror(errno));
                        goto fail;
                }
                ncacn_conn->local_server_name =
-                       talloc_strdup(ncacn_conn, info7->local_server_name);
+                       talloc_strdup(ncacn_conn, info8->local_server_name);
                if (ncacn_conn->local_server_name == NULL) {
                        DBG_DEBUG("talloc_strdup(%s) failed\n",
-                                 info7->local_server_name);
+                                 info8->local_server_name);
                        goto fail;
                }
        }
@@ -364,10 +379,10 @@ static void rpc_worker_new_client(
                 * socket that the client connected to, passed in from
                 * samba-dcerpcd via the binding. For NCACN_NP (root
                 * only by unix permissions) we got a
-                * named_pipe_auth_req_info7 where the transport can
+                * named_pipe_auth_req_info8 where the transport can
                 * be overridden.
                 */
-               transport = info7->transport;
+               transport = info8->transport;
        } else {
                ret = tstream_bsd_existing_socket(
                        ncacn_conn, sock, &tstream);
@@ -376,10 +391,12 @@ static void rpc_worker_new_client(
                                  strerror(errno));
                        goto fail;
                }
+               /* as server we want to fail early */
+               tstream_bsd_fail_readv_first_error(tstream, true);
        }
        sock = -1;
 
-       token = info7->session_info->session_info->security_token;
+       token = info8->session_info->session_info->security_token;
 
        if (security_token_is_system(token) && (transport != NCALRPC)) {
                DBG_DEBUG("System token only allowed on NCALRPC\n");
@@ -410,7 +427,7 @@ static void rpc_worker_new_client(
        status = dcesrv_endpoint_connect(dce_ctx,
                                         ncacn_conn,
                                         ep,
-                                        info7->session_info->session_info,
+                                        info8->session_info->session_info,
                                         global_event_context(),
                                         state_flags,
                                         &dcesrv_conn);
@@ -467,7 +484,7 @@ static void rpc_worker_new_client(
        TALLOC_FREE(client);
 
        DLIST_ADD(worker->conns, ncacn_conn);
-       worker->status.num_clients += 1;
+       worker->status.num_connections += 1;
 
        dcesrv_loop_next_packet(dcesrv_conn, pkt, buffer);
 
@@ -553,7 +570,6 @@ static bool rpc_worker_status_filter(
                private_data, struct rpc_worker);
        struct dcerpc_ncacn_conn *conn = NULL;
        FILE *f = NULL;
-       int fd;
 
        if (rec->msg_type != MSG_RPC_DUMP_STATUS) {
                return false;
@@ -564,18 +580,9 @@ static bool rpc_worker_status_filter(
                return false;
        }
 
-       fd = dup(rec->fds[0]);
-       if (fd == -1) {
-               DBG_DEBUG("dup(%"PRIi64") failed: %s\n",
-                         rec->fds[0],
-                         strerror(errno));
-               return false;
-       }
-
-       f = fdopen(fd, "w");
+       f = fdopen_keepfd(rec->fds[0], "w");
        if (f == NULL) {
-               DBG_DEBUG("fdopen failed: %s\n", strerror(errno));
-               close(fd);
+               DBG_DEBUG("fdopen_keepfd failed: %s\n", strerror(errno));
                return false;
        }
 
@@ -612,7 +619,7 @@ static struct dcesrv_assoc_group *rpc_worker_assoc_group_reference(
        void *id_ptr = NULL;
 
        /* find an association group given a assoc_group_id */
-       id_ptr = idr_find(conn->dce_ctx->assoc_groups_idr, id & 0xffffff);
+       id_ptr = idr_find(conn->dce_ctx->assoc_groups_idr, id & UINT16_MAX);
        if (id_ptr == NULL) {
                DBG_NOTICE("Failed to find assoc_group 0x%08x\n", id);
                return NULL;
@@ -626,7 +633,7 @@ static struct dcesrv_assoc_group *rpc_worker_assoc_group_reference(
                        transport);
 
                DBG_NOTICE("assoc_group 0x%08x (transport %s) "
-                          "is not available on transport %s",
+                          "is not available on transport %s\n",
                           id, at, ct);
                return NULL;
        }
@@ -647,11 +654,14 @@ static int rpc_worker_assoc_group_destructor(
 
        ret = idr_remove(
                assoc_group->dce_ctx->assoc_groups_idr,
-               assoc_group->id & 0xffffff);
+               assoc_group->id & UINT16_MAX);
        if (ret != 0) {
                DBG_WARNING("Failed to remove assoc_group 0x%08x\n",
                            assoc_group->id);
        }
+
+       SMB_ASSERT(assoc_group->dce_ctx->assoc_groups_num > 0);
+       assoc_group->dce_ctx->assoc_groups_num -= 1;
        return 0;
 }
 
@@ -659,7 +669,7 @@ static int rpc_worker_assoc_group_destructor(
   allocate a new association group
  */
 static struct dcesrv_assoc_group *rpc_worker_assoc_group_new(
-       struct dcesrv_connection *conn, uint8_t worker_index)
+       struct dcesrv_connection *conn, uint16_t worker_index)
 {
        struct dcesrv_context *dce_ctx = conn->dce_ctx;
        const struct dcesrv_endpoint *endpoint = conn->endpoint;
@@ -673,6 +683,11 @@ static struct dcesrv_assoc_group *rpc_worker_assoc_group_new(
                return NULL;
        }
 
+       /*
+        * We use 16-bit to encode the worker index,
+        * have 16-bits left within the worker to form a
+        * 32-bit association group id.
+        */
        id = idr_get_new_random(
                dce_ctx->assoc_groups_idr, assoc_group, 1, UINT16_MAX);
        if (id == -1) {
@@ -680,12 +695,15 @@ static struct dcesrv_assoc_group *rpc_worker_assoc_group_new(
                DBG_WARNING("Out of association groups!\n");
                return NULL;
        }
-       assoc_group->id = (worker_index << 24) + id;
+       assoc_group->id = (((uint32_t)worker_index) << 16) | id;
        assoc_group->transport = transport;
        assoc_group->dce_ctx = dce_ctx;
 
        talloc_set_destructor(assoc_group, rpc_worker_assoc_group_destructor);
 
+       SMB_ASSERT(dce_ctx->assoc_groups_num < UINT16_MAX);
+       dce_ctx->assoc_groups_num += 1;
+
        return assoc_group;
 }
 
@@ -698,10 +716,10 @@ static NTSTATUS rpc_worker_assoc_group_find(
        uint32_t assoc_group_id = call->pkt.u.bind.assoc_group_id;
 
        if (assoc_group_id != 0) {
-               uint8_t worker_index = (assoc_group_id & 0xff000000) >> 24;
+               uint16_t worker_index = (assoc_group_id & 0xffff0000) >> 16;
                if (worker_index != w->status.worker_index) {
-                       DBG_DEBUG("Wrong worker id %"PRIu8", "
-                                 "expected %"PRIu8"\n",
+                       DBG_DEBUG("Wrong worker id %"PRIu16", "
+                                 "expected %"PRIu32"\n",
                                  worker_index,
                                  w->status.worker_index);
                        return NT_STATUS_NOT_FOUND;
@@ -801,7 +819,7 @@ static struct tevent_req *rpc_worker_send(
                tevent_req_error(req, EINVAL);
                return tevent_req_post(req, ev);
        }
-       if ((worker_index < 0) || ((unsigned)worker_index > UINT32_MAX)) {
+       if ((worker_index < 0) || ((unsigned)worker_index > UINT16_MAX)) {
                DBG_ERR("Invalid worker index %d\n", worker_index);
                tevent_req_error(req, EINVAL);
                return tevent_req_post(req, ev);
@@ -955,7 +973,7 @@ static NTSTATUS register_ep_server(
  *
  * get_servers() is called when the process is about to do the real
  * work. So more heavy-weight initialization should happen here. It
- * should return the number of server implementations provided.
+ * should return NT_STATUS_OK and the number of server implementations provided.
  *
  * @param[in] argc argc from main()
  * @param[in] argv argv from main()
@@ -974,9 +992,10 @@ int rpc_worker_main(
        size_t (*get_interfaces)(
                const struct ndr_interface_table ***ifaces,
                void *private_data),
-       size_t (*get_servers)(
+       NTSTATUS (*get_servers)(
                struct dcesrv_context *dce_ctx,
                const struct dcesrv_endpoint_server ***ep_servers,
+               size_t *num_ep_servers,
                void *private_data),
        void *private_data)
 {
@@ -1122,10 +1141,12 @@ int rpc_worker_main(
        /* Ignore children - no zombies. */
        CatchChild();
 
-       DEBUG(0, ("%s version %s started.\n",
-                 progname,
-                 samba_version_string()));
-       DEBUGADD(0,("%s\n", COPYRIGHT_STARTUP_MESSAGE));
+       reopen_logs();
+
+       DBG_STARTUP_NOTICE("%s version %s started.\n%s\n",
+                          progname,
+                          samba_version_string(),
+                          samba_copyright_string());
 
        msg_ctx = global_messaging_context();
        if (msg_ctx == NULL) {
@@ -1185,15 +1206,24 @@ int rpc_worker_main(
 
        DBG_INFO("Initializing DCE/RPC registered endpoint servers\n");
 
-       num_servers = get_servers(dce_ctx, &ep_servers, private_data);
+       status = get_servers(dce_ctx,
+                            &ep_servers,
+                            &num_servers,
+                            private_data);
+       if (!NT_STATUS_IS_OK(status)) {
+               DBG_ERR("get_servers failed: %s\n", nt_errstr(status));
+               global_messaging_context_free();
+               TALLOC_FREE(frame);
+               exit(1);
+       }
 
        DBG_DEBUG("get_servers() returned %zu servers\n", num_servers);
 
        for (i=0; i<num_servers; i++) {
                status = register_ep_server(dce_ctx, ep_servers[i]);
                if (!NT_STATUS_IS_OK(status)) {
-                       DBG_DEBUG("register_ep_server failed: %s\n",
-                                 nt_errstr(status));
+                       DBG_ERR("register_ep_server failed: %s\n",
+                               nt_errstr(status));
                        global_messaging_context_free();
                        TALLOC_FREE(frame);
                        exit(1);