io_uring/net: unify how recvmsg and sendmsg copy in the msghdr
[sfrench/cifs-2.6.git] / io_uring / net.c
index 43bc9a5f96f9d1ce48f2d6e2a5cf850c675aee0d..8c64e552a00f032edc0c698a6dedfb579ece1aaa 100644 (file)
@@ -204,16 +204,150 @@ static int io_setup_async_msg(struct io_kiocb *req,
        return -EAGAIN;
 }
 
+static bool io_recvmsg_multishot_overflow(struct io_async_msghdr *iomsg)
+{
+       int hdr;
+
+       if (iomsg->namelen < 0)
+               return true;
+       if (check_add_overflow((int)sizeof(struct io_uring_recvmsg_out),
+                              iomsg->namelen, &hdr))
+               return true;
+       if (check_add_overflow(hdr, (int)iomsg->controllen, &hdr))
+               return true;
+
+       return false;
+}
+
+#ifdef CONFIG_COMPAT
+static int __io_compat_msg_copy_hdr(struct io_kiocb *req,
+                                   struct io_async_msghdr *iomsg,
+                                   struct sockaddr __user **addr, int ddir)
+{
+       struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
+       struct compat_msghdr msg;
+       struct compat_iovec __user *uiov;
+       int ret;
+
+       if (copy_from_user(&msg, sr->umsg_compat, sizeof(msg)))
+               return -EFAULT;
+
+       ret = __get_compat_msghdr(&iomsg->msg, &msg, addr);
+       if (ret)
+               return ret;
+
+       uiov = compat_ptr(msg.msg_iov);
+       if (req->flags & REQ_F_BUFFER_SELECT) {
+               compat_ssize_t clen;
+
+               iomsg->free_iov = NULL;
+               if (msg.msg_iovlen == 0) {
+                       sr->len = 0;
+               } else if (msg.msg_iovlen > 1) {
+                       return -EINVAL;
+               } else {
+                       if (!access_ok(uiov, sizeof(*uiov)))
+                               return -EFAULT;
+                       if (__get_user(clen, &uiov->iov_len))
+                               return -EFAULT;
+                       if (clen < 0)
+                               return -EINVAL;
+                       sr->len = clen;
+               }
+
+               if (ddir == ITER_DEST && req->flags & REQ_F_APOLL_MULTISHOT) {
+                       iomsg->namelen = msg.msg_namelen;
+                       iomsg->controllen = msg.msg_controllen;
+                       if (io_recvmsg_multishot_overflow(iomsg))
+                               return -EOVERFLOW;
+               }
+
+               return 0;
+       }
+
+       iomsg->free_iov = iomsg->fast_iov;
+       ret = __import_iovec(ddir, (struct iovec __user *)uiov, msg.msg_iovlen,
+                               UIO_FASTIOV, &iomsg->free_iov,
+                               &iomsg->msg.msg_iter, true);
+       if (unlikely(ret < 0))
+               return ret;
+
+       return 0;
+}
+#endif
+
+static int __io_msg_copy_hdr(struct io_kiocb *req, struct io_async_msghdr *iomsg,
+                            struct sockaddr __user **addr, int ddir)
+{
+       struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
+       struct user_msghdr msg;
+       int ret;
+
+       if (copy_from_user(&msg, sr->umsg, sizeof(*sr->umsg)))
+               return -EFAULT;
+
+       ret = __copy_msghdr(&iomsg->msg, &msg, addr);
+       if (ret)
+               return ret;
+
+       if (req->flags & REQ_F_BUFFER_SELECT) {
+               if (msg.msg_iovlen == 0) {
+                       sr->len = iomsg->fast_iov[0].iov_len = 0;
+                       iomsg->fast_iov[0].iov_base = NULL;
+                       iomsg->free_iov = NULL;
+               } else if (msg.msg_iovlen > 1) {
+                       return -EINVAL;
+               } else {
+                       if (copy_from_user(iomsg->fast_iov, msg.msg_iov,
+                                          sizeof(*msg.msg_iov)))
+                               return -EFAULT;
+                       sr->len = iomsg->fast_iov[0].iov_len;
+                       iomsg->free_iov = NULL;
+               }
+
+               if (ddir == ITER_DEST && req->flags & REQ_F_APOLL_MULTISHOT) {
+                       iomsg->namelen = msg.msg_namelen;
+                       iomsg->controllen = msg.msg_controllen;
+                       if (io_recvmsg_multishot_overflow(iomsg))
+                               return -EOVERFLOW;
+               }
+
+               return 0;
+       }
+
+       iomsg->free_iov = iomsg->fast_iov;
+       ret = __import_iovec(ddir, msg.msg_iov, msg.msg_iovlen, UIO_FASTIOV,
+                               &iomsg->free_iov, &iomsg->msg.msg_iter, false);
+       if (unlikely(ret < 0))
+               return ret;
+
+       return 0;
+}
+
+static int io_msg_copy_hdr(struct io_kiocb *req, struct io_async_msghdr *iomsg,
+                          struct sockaddr __user **addr, int ddir)
+{
+       iomsg->msg.msg_name = &iomsg->addr;
+       iomsg->msg.msg_iter.nr_segs = 0;
+
+#ifdef CONFIG_COMPAT
+       if (req->ctx->compat)
+               return __io_compat_msg_copy_hdr(req, iomsg, addr, ddir);
+#endif
+
+       return __io_msg_copy_hdr(req, iomsg, addr, ddir);
+}
+
 static int io_sendmsg_copy_hdr(struct io_kiocb *req,
                               struct io_async_msghdr *iomsg)
 {
        struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
        int ret;
 
-       iomsg->msg.msg_name = &iomsg->addr;
-       iomsg->free_iov = iomsg->fast_iov;
-       ret = sendmsg_copy_msghdr(&iomsg->msg, sr->umsg, sr->msg_flags,
-                                       &iomsg->free_iov);
+       ret = io_msg_copy_hdr(req, iomsg, NULL, ITER_SOURCE);
+       if (ret)
+               return ret;
+
        /* save msg_control as sys_sendmsg() overwrites it */
        sr->msg_control = iomsg->msg.msg_control_user;
        return ret;
@@ -435,142 +569,21 @@ int io_send(struct io_kiocb *req, unsigned int issue_flags)
        return IOU_OK;
 }
 
-static bool io_recvmsg_multishot_overflow(struct io_async_msghdr *iomsg)
-{
-       int hdr;
-
-       if (iomsg->namelen < 0)
-               return true;
-       if (check_add_overflow((int)sizeof(struct io_uring_recvmsg_out),
-                              iomsg->namelen, &hdr))
-               return true;
-       if (check_add_overflow(hdr, (int)iomsg->controllen, &hdr))
-               return true;
-
-       return false;
-}
-
-static int __io_recvmsg_copy_hdr(struct io_kiocb *req,
-                                struct io_async_msghdr *iomsg)
-{
-       struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
-       struct user_msghdr msg;
-       int ret;
-
-       if (copy_from_user(&msg, sr->umsg, sizeof(*sr->umsg)))
-               return -EFAULT;
-
-       ret = __copy_msghdr(&iomsg->msg, &msg, &iomsg->uaddr);
-       if (ret)
-               return ret;
-
-       if (req->flags & REQ_F_BUFFER_SELECT) {
-               if (msg.msg_iovlen == 0) {
-                       sr->len = iomsg->fast_iov[0].iov_len = 0;
-                       iomsg->fast_iov[0].iov_base = NULL;
-                       iomsg->free_iov = NULL;
-               } else if (msg.msg_iovlen > 1) {
-                       return -EINVAL;
-               } else {
-                       if (copy_from_user(iomsg->fast_iov, msg.msg_iov, sizeof(*msg.msg_iov)))
-                               return -EFAULT;
-                       sr->len = iomsg->fast_iov[0].iov_len;
-                       iomsg->free_iov = NULL;
-               }
-
-               if (req->flags & REQ_F_APOLL_MULTISHOT) {
-                       iomsg->namelen = msg.msg_namelen;
-                       iomsg->controllen = msg.msg_controllen;
-                       if (io_recvmsg_multishot_overflow(iomsg))
-                               return -EOVERFLOW;
-               }
-       } else {
-               iomsg->free_iov = iomsg->fast_iov;
-               ret = __import_iovec(ITER_DEST, msg.msg_iov, msg.msg_iovlen, UIO_FASTIOV,
-                                    &iomsg->free_iov, &iomsg->msg.msg_iter,
-                                    false);
-               if (ret > 0)
-                       ret = 0;
-       }
-
-       return ret;
-}
-
-#ifdef CONFIG_COMPAT
-static int __io_compat_recvmsg_copy_hdr(struct io_kiocb *req,
-                                       struct io_async_msghdr *iomsg)
-{
-       struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
-       struct compat_msghdr msg;
-       struct compat_iovec __user *uiov;
-       int ret;
-
-       if (copy_from_user(&msg, sr->umsg_compat, sizeof(msg)))
-               return -EFAULT;
-
-       ret = __get_compat_msghdr(&iomsg->msg, &msg, &iomsg->uaddr);
-       if (ret)
-               return ret;
-
-       uiov = compat_ptr(msg.msg_iov);
-       if (req->flags & REQ_F_BUFFER_SELECT) {
-               compat_ssize_t clen;
-
-               iomsg->free_iov = NULL;
-               if (msg.msg_iovlen == 0) {
-                       sr->len = 0;
-               } else if (msg.msg_iovlen > 1) {
-                       return -EINVAL;
-               } else {
-                       if (!access_ok(uiov, sizeof(*uiov)))
-                               return -EFAULT;
-                       if (__get_user(clen, &uiov->iov_len))
-                               return -EFAULT;
-                       if (clen < 0)
-                               return -EINVAL;
-                       sr->len = clen;
-               }
-
-               if (req->flags & REQ_F_APOLL_MULTISHOT) {
-                       iomsg->namelen = msg.msg_namelen;
-                       iomsg->controllen = msg.msg_controllen;
-                       if (io_recvmsg_multishot_overflow(iomsg))
-                               return -EOVERFLOW;
-               }
-       } else {
-               iomsg->free_iov = iomsg->fast_iov;
-               ret = __import_iovec(ITER_DEST, (struct iovec __user *)uiov, msg.msg_iovlen,
-                                  UIO_FASTIOV, &iomsg->free_iov,
-                                  &iomsg->msg.msg_iter, true);
-               if (ret < 0)
-                       return ret;
-       }
-
-       return 0;
-}
-#endif
-
 static int io_recvmsg_copy_hdr(struct io_kiocb *req,
                               struct io_async_msghdr *iomsg)
 {
-       iomsg->msg.msg_name = &iomsg->addr;
-       iomsg->msg.msg_iter.nr_segs = 0;
-
-#ifdef CONFIG_COMPAT
-       if (req->ctx->compat)
-               return __io_compat_recvmsg_copy_hdr(req, iomsg);
-#endif
-
-       return __io_recvmsg_copy_hdr(req, iomsg);
+       return io_msg_copy_hdr(req, iomsg, &iomsg->uaddr, ITER_DEST);
 }
 
 int io_recvmsg_prep_async(struct io_kiocb *req)
 {
+       struct io_async_msghdr *iomsg;
        int ret;
 
        if (!io_msg_alloc_async_prep(req))
                return -ENOMEM;
-       ret = io_recvmsg_copy_hdr(req, req->async_data);
+       iomsg = req->async_data;
+       ret = io_recvmsg_copy_hdr(req, iomsg);
        if (!ret)
                req->flags |= REQ_F_NEED_CLEANUP;
        return ret;