Merge tag 'core-entry-2024-03-23' of git://git.kernel.org/pub/scm/linux/kernel/git...
[sfrench/cifs-2.6.git] / io_uring / net.c
1 // SPDX-License-Identifier: GPL-2.0
2 #include <linux/kernel.h>
3 #include <linux/errno.h>
4 #include <linux/file.h>
5 #include <linux/slab.h>
6 #include <linux/net.h>
7 #include <linux/compat.h>
8 #include <net/compat.h>
9 #include <linux/io_uring.h>
10
11 #include <uapi/linux/io_uring.h>
12
13 #include "io_uring.h"
14 #include "kbuf.h"
15 #include "alloc_cache.h"
16 #include "net.h"
17 #include "notif.h"
18 #include "rsrc.h"
19
20 #if defined(CONFIG_NET)
21 struct io_shutdown {
22         struct file                     *file;
23         int                             how;
24 };
25
26 struct io_accept {
27         struct file                     *file;
28         struct sockaddr __user          *addr;
29         int __user                      *addr_len;
30         int                             flags;
31         u32                             file_slot;
32         unsigned long                   nofile;
33 };
34
35 struct io_socket {
36         struct file                     *file;
37         int                             domain;
38         int                             type;
39         int                             protocol;
40         int                             flags;
41         u32                             file_slot;
42         unsigned long                   nofile;
43 };
44
45 struct io_connect {
46         struct file                     *file;
47         struct sockaddr __user          *addr;
48         int                             addr_len;
49         bool                            in_progress;
50         bool                            seen_econnaborted;
51 };
52
53 struct io_sr_msg {
54         struct file                     *file;
55         union {
56                 struct compat_msghdr __user     *umsg_compat;
57                 struct user_msghdr __user       *umsg;
58                 void __user                     *buf;
59         };
60         unsigned                        len;
61         unsigned                        done_io;
62         unsigned                        msg_flags;
63         unsigned                        nr_multishot_loops;
64         u16                             flags;
65         /* initialised and used only by !msg send variants */
66         u16                             addr_len;
67         u16                             buf_group;
68         void __user                     *addr;
69         void __user                     *msg_control;
70         /* used only for send zerocopy */
71         struct io_kiocb                 *notif;
72 };
73
74 /*
75  * Number of times we'll try and do receives if there's more data. If we
76  * exceed this limit, then add us to the back of the queue and retry from
77  * there. This helps fairness between flooding clients.
78  */
79 #define MULTISHOT_MAX_RETRY     32
80
81 int io_shutdown_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
82 {
83         struct io_shutdown *shutdown = io_kiocb_to_cmd(req, struct io_shutdown);
84
85         if (unlikely(sqe->off || sqe->addr || sqe->rw_flags ||
86                      sqe->buf_index || sqe->splice_fd_in))
87                 return -EINVAL;
88
89         shutdown->how = READ_ONCE(sqe->len);
90         req->flags |= REQ_F_FORCE_ASYNC;
91         return 0;
92 }
93
94 int io_shutdown(struct io_kiocb *req, unsigned int issue_flags)
95 {
96         struct io_shutdown *shutdown = io_kiocb_to_cmd(req, struct io_shutdown);
97         struct socket *sock;
98         int ret;
99
100         WARN_ON_ONCE(issue_flags & IO_URING_F_NONBLOCK);
101
102         sock = sock_from_file(req->file);
103         if (unlikely(!sock))
104                 return -ENOTSOCK;
105
106         ret = __sys_shutdown_sock(sock, shutdown->how);
107         io_req_set_res(req, ret, 0);
108         return IOU_OK;
109 }
110
111 static bool io_net_retry(struct socket *sock, int flags)
112 {
113         if (!(flags & MSG_WAITALL))
114                 return false;
115         return sock->type == SOCK_STREAM || sock->type == SOCK_SEQPACKET;
116 }
117
118 static void io_netmsg_recycle(struct io_kiocb *req, unsigned int issue_flags)
119 {
120         struct io_async_msghdr *hdr = req->async_data;
121
122         if (!req_has_async_data(req) || issue_flags & IO_URING_F_UNLOCKED)
123                 return;
124
125         /* Let normal cleanup path reap it if we fail adding to the cache */
126         if (io_alloc_cache_put(&req->ctx->netmsg_cache, &hdr->cache)) {
127                 req->async_data = NULL;
128                 req->flags &= ~REQ_F_ASYNC_DATA;
129         }
130 }
131
132 static struct io_async_msghdr *io_msg_alloc_async(struct io_kiocb *req,
133                                                   unsigned int issue_flags)
134 {
135         struct io_ring_ctx *ctx = req->ctx;
136         struct io_cache_entry *entry;
137         struct io_async_msghdr *hdr;
138
139         if (!(issue_flags & IO_URING_F_UNLOCKED)) {
140                 entry = io_alloc_cache_get(&ctx->netmsg_cache);
141                 if (entry) {
142                         hdr = container_of(entry, struct io_async_msghdr, cache);
143                         hdr->free_iov = NULL;
144                         req->flags |= REQ_F_ASYNC_DATA;
145                         req->async_data = hdr;
146                         return hdr;
147                 }
148         }
149
150         if (!io_alloc_async_data(req)) {
151                 hdr = req->async_data;
152                 hdr->free_iov = NULL;
153                 return hdr;
154         }
155         return NULL;
156 }
157
158 static inline struct io_async_msghdr *io_msg_alloc_async_prep(struct io_kiocb *req)
159 {
160         /* ->prep_async is always called from the submission context */
161         return io_msg_alloc_async(req, 0);
162 }
163
164 static int io_setup_async_msg(struct io_kiocb *req,
165                               struct io_async_msghdr *kmsg,
166                               unsigned int issue_flags)
167 {
168         struct io_async_msghdr *async_msg;
169
170         if (req_has_async_data(req))
171                 return -EAGAIN;
172         async_msg = io_msg_alloc_async(req, issue_flags);
173         if (!async_msg) {
174                 kfree(kmsg->free_iov);
175                 return -ENOMEM;
176         }
177         req->flags |= REQ_F_NEED_CLEANUP;
178         memcpy(async_msg, kmsg, sizeof(*kmsg));
179         if (async_msg->msg.msg_name)
180                 async_msg->msg.msg_name = &async_msg->addr;
181
182         if ((req->flags & REQ_F_BUFFER_SELECT) && !async_msg->msg.msg_iter.nr_segs)
183                 return -EAGAIN;
184
185         /* if were using fast_iov, set it to the new one */
186         if (iter_is_iovec(&kmsg->msg.msg_iter) && !kmsg->free_iov) {
187                 size_t fast_idx = iter_iov(&kmsg->msg.msg_iter) - kmsg->fast_iov;
188                 async_msg->msg.msg_iter.__iov = &async_msg->fast_iov[fast_idx];
189         }
190
191         return -EAGAIN;
192 }
193
194 #ifdef CONFIG_COMPAT
195 static int io_compat_msg_copy_hdr(struct io_kiocb *req,
196                                   struct io_async_msghdr *iomsg,
197                                   struct compat_msghdr *msg, int ddir)
198 {
199         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
200         struct compat_iovec __user *uiov;
201         int ret;
202
203         if (copy_from_user(msg, sr->umsg_compat, sizeof(*msg)))
204                 return -EFAULT;
205
206         uiov = compat_ptr(msg->msg_iov);
207         if (req->flags & REQ_F_BUFFER_SELECT) {
208                 compat_ssize_t clen;
209
210                 iomsg->free_iov = NULL;
211                 if (msg->msg_iovlen == 0) {
212                         sr->len = 0;
213                 } else if (msg->msg_iovlen > 1) {
214                         return -EINVAL;
215                 } else {
216                         if (!access_ok(uiov, sizeof(*uiov)))
217                                 return -EFAULT;
218                         if (__get_user(clen, &uiov->iov_len))
219                                 return -EFAULT;
220                         if (clen < 0)
221                                 return -EINVAL;
222                         sr->len = clen;
223                 }
224
225                 return 0;
226         }
227
228         iomsg->free_iov = iomsg->fast_iov;
229         ret = __import_iovec(ddir, (struct iovec __user *)uiov, msg->msg_iovlen,
230                                 UIO_FASTIOV, &iomsg->free_iov,
231                                 &iomsg->msg.msg_iter, true);
232         if (unlikely(ret < 0))
233                 return ret;
234
235         return 0;
236 }
237 #endif
238
239 static int io_msg_copy_hdr(struct io_kiocb *req, struct io_async_msghdr *iomsg,
240                            struct user_msghdr *msg, int ddir)
241 {
242         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
243         int ret;
244
245         if (!user_access_begin(sr->umsg, sizeof(*sr->umsg)))
246                 return -EFAULT;
247
248         ret = -EFAULT;
249         unsafe_get_user(msg->msg_name, &sr->umsg->msg_name, ua_end);
250         unsafe_get_user(msg->msg_namelen, &sr->umsg->msg_namelen, ua_end);
251         unsafe_get_user(msg->msg_iov, &sr->umsg->msg_iov, ua_end);
252         unsafe_get_user(msg->msg_iovlen, &sr->umsg->msg_iovlen, ua_end);
253         unsafe_get_user(msg->msg_control, &sr->umsg->msg_control, ua_end);
254         unsafe_get_user(msg->msg_controllen, &sr->umsg->msg_controllen, ua_end);
255         msg->msg_flags = 0;
256
257         if (req->flags & REQ_F_BUFFER_SELECT) {
258                 if (msg->msg_iovlen == 0) {
259                         sr->len = iomsg->fast_iov[0].iov_len = 0;
260                         iomsg->fast_iov[0].iov_base = NULL;
261                         iomsg->free_iov = NULL;
262                 } else if (msg->msg_iovlen > 1) {
263                         ret = -EINVAL;
264                         goto ua_end;
265                 } else {
266                         /* we only need the length for provided buffers */
267                         if (!access_ok(&msg->msg_iov[0].iov_len, sizeof(__kernel_size_t)))
268                                 goto ua_end;
269                         unsafe_get_user(iomsg->fast_iov[0].iov_len,
270                                         &msg->msg_iov[0].iov_len, ua_end);
271                         sr->len = iomsg->fast_iov[0].iov_len;
272                         iomsg->free_iov = NULL;
273                 }
274                 ret = 0;
275 ua_end:
276                 user_access_end();
277                 return ret;
278         }
279
280         user_access_end();
281         iomsg->free_iov = iomsg->fast_iov;
282         ret = __import_iovec(ddir, msg->msg_iov, msg->msg_iovlen, UIO_FASTIOV,
283                                 &iomsg->free_iov, &iomsg->msg.msg_iter, false);
284         if (unlikely(ret < 0))
285                 return ret;
286
287         return 0;
288 }
289
290 static int io_sendmsg_copy_hdr(struct io_kiocb *req,
291                                struct io_async_msghdr *iomsg)
292 {
293         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
294         struct user_msghdr msg;
295         int ret;
296
297         iomsg->msg.msg_name = &iomsg->addr;
298         iomsg->msg.msg_iter.nr_segs = 0;
299
300 #ifdef CONFIG_COMPAT
301         if (unlikely(req->ctx->compat)) {
302                 struct compat_msghdr cmsg;
303
304                 ret = io_compat_msg_copy_hdr(req, iomsg, &cmsg, ITER_SOURCE);
305                 if (unlikely(ret))
306                         return ret;
307
308                 return __get_compat_msghdr(&iomsg->msg, &cmsg, NULL);
309         }
310 #endif
311
312         ret = io_msg_copy_hdr(req, iomsg, &msg, ITER_SOURCE);
313         if (unlikely(ret))
314                 return ret;
315
316         ret = __copy_msghdr(&iomsg->msg, &msg, NULL);
317
318         /* save msg_control as sys_sendmsg() overwrites it */
319         sr->msg_control = iomsg->msg.msg_control_user;
320         return ret;
321 }
322
323 int io_send_prep_async(struct io_kiocb *req)
324 {
325         struct io_sr_msg *zc = io_kiocb_to_cmd(req, struct io_sr_msg);
326         struct io_async_msghdr *io;
327         int ret;
328
329         if (req_has_async_data(req))
330                 return 0;
331         zc->done_io = 0;
332         if (!zc->addr)
333                 return 0;
334         io = io_msg_alloc_async_prep(req);
335         if (!io)
336                 return -ENOMEM;
337         ret = move_addr_to_kernel(zc->addr, zc->addr_len, &io->addr);
338         return ret;
339 }
340
341 static int io_setup_async_addr(struct io_kiocb *req,
342                               struct sockaddr_storage *addr_storage,
343                               unsigned int issue_flags)
344 {
345         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
346         struct io_async_msghdr *io;
347
348         if (!sr->addr || req_has_async_data(req))
349                 return -EAGAIN;
350         io = io_msg_alloc_async(req, issue_flags);
351         if (!io)
352                 return -ENOMEM;
353         memcpy(&io->addr, addr_storage, sizeof(io->addr));
354         return -EAGAIN;
355 }
356
357 int io_sendmsg_prep_async(struct io_kiocb *req)
358 {
359         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
360         int ret;
361
362         sr->done_io = 0;
363         if (!io_msg_alloc_async_prep(req))
364                 return -ENOMEM;
365         ret = io_sendmsg_copy_hdr(req, req->async_data);
366         if (!ret)
367                 req->flags |= REQ_F_NEED_CLEANUP;
368         return ret;
369 }
370
371 void io_sendmsg_recvmsg_cleanup(struct io_kiocb *req)
372 {
373         struct io_async_msghdr *io = req->async_data;
374
375         kfree(io->free_iov);
376 }
377
378 int io_sendmsg_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
379 {
380         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
381
382         sr->done_io = 0;
383
384         if (req->opcode == IORING_OP_SEND) {
385                 if (READ_ONCE(sqe->__pad3[0]))
386                         return -EINVAL;
387                 sr->addr = u64_to_user_ptr(READ_ONCE(sqe->addr2));
388                 sr->addr_len = READ_ONCE(sqe->addr_len);
389         } else if (sqe->addr2 || sqe->file_index) {
390                 return -EINVAL;
391         }
392
393         sr->umsg = u64_to_user_ptr(READ_ONCE(sqe->addr));
394         sr->len = READ_ONCE(sqe->len);
395         sr->flags = READ_ONCE(sqe->ioprio);
396         if (sr->flags & ~IORING_RECVSEND_POLL_FIRST)
397                 return -EINVAL;
398         sr->msg_flags = READ_ONCE(sqe->msg_flags) | MSG_NOSIGNAL;
399         if (sr->msg_flags & MSG_DONTWAIT)
400                 req->flags |= REQ_F_NOWAIT;
401
402 #ifdef CONFIG_COMPAT
403         if (req->ctx->compat)
404                 sr->msg_flags |= MSG_CMSG_COMPAT;
405 #endif
406         return 0;
407 }
408
409 static void io_req_msg_cleanup(struct io_kiocb *req,
410                                struct io_async_msghdr *kmsg,
411                                unsigned int issue_flags)
412 {
413         req->flags &= ~REQ_F_NEED_CLEANUP;
414         /* fast path, check for non-NULL to avoid function call */
415         if (kmsg->free_iov)
416                 kfree(kmsg->free_iov);
417         io_netmsg_recycle(req, issue_flags);
418 }
419
420 int io_sendmsg(struct io_kiocb *req, unsigned int issue_flags)
421 {
422         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
423         struct io_async_msghdr iomsg, *kmsg;
424         struct socket *sock;
425         unsigned flags;
426         int min_ret = 0;
427         int ret;
428
429         sock = sock_from_file(req->file);
430         if (unlikely(!sock))
431                 return -ENOTSOCK;
432
433         if (req_has_async_data(req)) {
434                 kmsg = req->async_data;
435                 kmsg->msg.msg_control_user = sr->msg_control;
436         } else {
437                 ret = io_sendmsg_copy_hdr(req, &iomsg);
438                 if (ret)
439                         return ret;
440                 kmsg = &iomsg;
441         }
442
443         if (!(req->flags & REQ_F_POLLED) &&
444             (sr->flags & IORING_RECVSEND_POLL_FIRST))
445                 return io_setup_async_msg(req, kmsg, issue_flags);
446
447         flags = sr->msg_flags;
448         if (issue_flags & IO_URING_F_NONBLOCK)
449                 flags |= MSG_DONTWAIT;
450         if (flags & MSG_WAITALL)
451                 min_ret = iov_iter_count(&kmsg->msg.msg_iter);
452
453         ret = __sys_sendmsg_sock(sock, &kmsg->msg, flags);
454
455         if (ret < min_ret) {
456                 if (ret == -EAGAIN && (issue_flags & IO_URING_F_NONBLOCK))
457                         return io_setup_async_msg(req, kmsg, issue_flags);
458                 if (ret > 0 && io_net_retry(sock, flags)) {
459                         kmsg->msg.msg_controllen = 0;
460                         kmsg->msg.msg_control = NULL;
461                         sr->done_io += ret;
462                         req->flags |= REQ_F_BL_NO_RECYCLE;
463                         return io_setup_async_msg(req, kmsg, issue_flags);
464                 }
465                 if (ret == -ERESTARTSYS)
466                         ret = -EINTR;
467                 req_set_fail(req);
468         }
469         io_req_msg_cleanup(req, kmsg, issue_flags);
470         if (ret >= 0)
471                 ret += sr->done_io;
472         else if (sr->done_io)
473                 ret = sr->done_io;
474         io_req_set_res(req, ret, 0);
475         return IOU_OK;
476 }
477
478 int io_send(struct io_kiocb *req, unsigned int issue_flags)
479 {
480         struct sockaddr_storage __address;
481         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
482         struct msghdr msg;
483         struct socket *sock;
484         unsigned flags;
485         int min_ret = 0;
486         int ret;
487
488         msg.msg_name = NULL;
489         msg.msg_control = NULL;
490         msg.msg_controllen = 0;
491         msg.msg_namelen = 0;
492         msg.msg_ubuf = NULL;
493
494         if (sr->addr) {
495                 if (req_has_async_data(req)) {
496                         struct io_async_msghdr *io = req->async_data;
497
498                         msg.msg_name = &io->addr;
499                 } else {
500                         ret = move_addr_to_kernel(sr->addr, sr->addr_len, &__address);
501                         if (unlikely(ret < 0))
502                                 return ret;
503                         msg.msg_name = (struct sockaddr *)&__address;
504                 }
505                 msg.msg_namelen = sr->addr_len;
506         }
507
508         if (!(req->flags & REQ_F_POLLED) &&
509             (sr->flags & IORING_RECVSEND_POLL_FIRST))
510                 return io_setup_async_addr(req, &__address, issue_flags);
511
512         sock = sock_from_file(req->file);
513         if (unlikely(!sock))
514                 return -ENOTSOCK;
515
516         ret = import_ubuf(ITER_SOURCE, sr->buf, sr->len, &msg.msg_iter);
517         if (unlikely(ret))
518                 return ret;
519
520         flags = sr->msg_flags;
521         if (issue_flags & IO_URING_F_NONBLOCK)
522                 flags |= MSG_DONTWAIT;
523         if (flags & MSG_WAITALL)
524                 min_ret = iov_iter_count(&msg.msg_iter);
525
526         flags &= ~MSG_INTERNAL_SENDMSG_FLAGS;
527         msg.msg_flags = flags;
528         ret = sock_sendmsg(sock, &msg);
529         if (ret < min_ret) {
530                 if (ret == -EAGAIN && (issue_flags & IO_URING_F_NONBLOCK))
531                         return io_setup_async_addr(req, &__address, issue_flags);
532
533                 if (ret > 0 && io_net_retry(sock, flags)) {
534                         sr->len -= ret;
535                         sr->buf += ret;
536                         sr->done_io += ret;
537                         req->flags |= REQ_F_BL_NO_RECYCLE;
538                         return io_setup_async_addr(req, &__address, issue_flags);
539                 }
540                 if (ret == -ERESTARTSYS)
541                         ret = -EINTR;
542                 req_set_fail(req);
543         }
544         if (ret >= 0)
545                 ret += sr->done_io;
546         else if (sr->done_io)
547                 ret = sr->done_io;
548         io_req_set_res(req, ret, 0);
549         return IOU_OK;
550 }
551
552 static int io_recvmsg_mshot_prep(struct io_kiocb *req,
553                                  struct io_async_msghdr *iomsg,
554                                  int namelen, size_t controllen)
555 {
556         if ((req->flags & (REQ_F_APOLL_MULTISHOT|REQ_F_BUFFER_SELECT)) ==
557                           (REQ_F_APOLL_MULTISHOT|REQ_F_BUFFER_SELECT)) {
558                 int hdr;
559
560                 if (unlikely(namelen < 0))
561                         return -EOVERFLOW;
562                 if (check_add_overflow(sizeof(struct io_uring_recvmsg_out),
563                                         namelen, &hdr))
564                         return -EOVERFLOW;
565                 if (check_add_overflow(hdr, controllen, &hdr))
566                         return -EOVERFLOW;
567
568                 iomsg->namelen = namelen;
569                 iomsg->controllen = controllen;
570                 return 0;
571         }
572
573         return 0;
574 }
575
576 static int io_recvmsg_copy_hdr(struct io_kiocb *req,
577                                struct io_async_msghdr *iomsg)
578 {
579         struct user_msghdr msg;
580         int ret;
581
582         iomsg->msg.msg_name = &iomsg->addr;
583         iomsg->msg.msg_iter.nr_segs = 0;
584
585 #ifdef CONFIG_COMPAT
586         if (unlikely(req->ctx->compat)) {
587                 struct compat_msghdr cmsg;
588
589                 ret = io_compat_msg_copy_hdr(req, iomsg, &cmsg, ITER_DEST);
590                 if (unlikely(ret))
591                         return ret;
592
593                 ret = __get_compat_msghdr(&iomsg->msg, &cmsg, &iomsg->uaddr);
594                 if (unlikely(ret))
595                         return ret;
596
597                 return io_recvmsg_mshot_prep(req, iomsg, cmsg.msg_namelen,
598                                                 cmsg.msg_controllen);
599         }
600 #endif
601
602         ret = io_msg_copy_hdr(req, iomsg, &msg, ITER_DEST);
603         if (unlikely(ret))
604                 return ret;
605
606         ret = __copy_msghdr(&iomsg->msg, &msg, &iomsg->uaddr);
607         if (unlikely(ret))
608                 return ret;
609
610         return io_recvmsg_mshot_prep(req, iomsg, msg.msg_namelen,
611                                         msg.msg_controllen);
612 }
613
614 int io_recvmsg_prep_async(struct io_kiocb *req)
615 {
616         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
617         struct io_async_msghdr *iomsg;
618         int ret;
619
620         sr->done_io = 0;
621         if (!io_msg_alloc_async_prep(req))
622                 return -ENOMEM;
623         iomsg = req->async_data;
624         ret = io_recvmsg_copy_hdr(req, iomsg);
625         if (!ret)
626                 req->flags |= REQ_F_NEED_CLEANUP;
627         return ret;
628 }
629
630 #define RECVMSG_FLAGS (IORING_RECVSEND_POLL_FIRST | IORING_RECV_MULTISHOT)
631
632 int io_recvmsg_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
633 {
634         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
635
636         sr->done_io = 0;
637
638         if (unlikely(sqe->file_index || sqe->addr2))
639                 return -EINVAL;
640
641         sr->umsg = u64_to_user_ptr(READ_ONCE(sqe->addr));
642         sr->len = READ_ONCE(sqe->len);
643         sr->flags = READ_ONCE(sqe->ioprio);
644         if (sr->flags & ~(RECVMSG_FLAGS))
645                 return -EINVAL;
646         sr->msg_flags = READ_ONCE(sqe->msg_flags);
647         if (sr->msg_flags & MSG_DONTWAIT)
648                 req->flags |= REQ_F_NOWAIT;
649         if (sr->msg_flags & MSG_ERRQUEUE)
650                 req->flags |= REQ_F_CLEAR_POLLIN;
651         if (sr->flags & IORING_RECV_MULTISHOT) {
652                 if (!(req->flags & REQ_F_BUFFER_SELECT))
653                         return -EINVAL;
654                 if (sr->msg_flags & MSG_WAITALL)
655                         return -EINVAL;
656                 if (req->opcode == IORING_OP_RECV && sr->len)
657                         return -EINVAL;
658                 req->flags |= REQ_F_APOLL_MULTISHOT;
659                 /*
660                  * Store the buffer group for this multishot receive separately,
661                  * as if we end up doing an io-wq based issue that selects a
662                  * buffer, it has to be committed immediately and that will
663                  * clear ->buf_list. This means we lose the link to the buffer
664                  * list, and the eventual buffer put on completion then cannot
665                  * restore it.
666                  */
667                 sr->buf_group = req->buf_index;
668         }
669
670 #ifdef CONFIG_COMPAT
671         if (req->ctx->compat)
672                 sr->msg_flags |= MSG_CMSG_COMPAT;
673 #endif
674         sr->nr_multishot_loops = 0;
675         return 0;
676 }
677
678 static inline void io_recv_prep_retry(struct io_kiocb *req)
679 {
680         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
681
682         req->flags &= ~REQ_F_BL_EMPTY;
683         sr->done_io = 0;
684         sr->len = 0; /* get from the provided buffer */
685         req->buf_index = sr->buf_group;
686 }
687
688 /*
689  * Finishes io_recv and io_recvmsg.
690  *
691  * Returns true if it is actually finished, or false if it should run
692  * again (for multishot).
693  */
694 static inline bool io_recv_finish(struct io_kiocb *req, int *ret,
695                                   struct msghdr *msg, bool mshot_finished,
696                                   unsigned issue_flags)
697 {
698         unsigned int cflags;
699
700         cflags = io_put_kbuf(req, issue_flags);
701         if (msg->msg_inq > 0)
702                 cflags |= IORING_CQE_F_SOCK_NONEMPTY;
703
704         /*
705          * Fill CQE for this receive and see if we should keep trying to
706          * receive from this socket.
707          */
708         if ((req->flags & REQ_F_APOLL_MULTISHOT) && !mshot_finished &&
709             io_fill_cqe_req_aux(req, issue_flags & IO_URING_F_COMPLETE_DEFER,
710                                 *ret, cflags | IORING_CQE_F_MORE)) {
711                 struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
712                 int mshot_retry_ret = IOU_ISSUE_SKIP_COMPLETE;
713
714                 io_recv_prep_retry(req);
715                 /* Known not-empty or unknown state, retry */
716                 if (cflags & IORING_CQE_F_SOCK_NONEMPTY || msg->msg_inq < 0) {
717                         if (sr->nr_multishot_loops++ < MULTISHOT_MAX_RETRY)
718                                 return false;
719                         /* mshot retries exceeded, force a requeue */
720                         sr->nr_multishot_loops = 0;
721                         mshot_retry_ret = IOU_REQUEUE;
722                 }
723                 if (issue_flags & IO_URING_F_MULTISHOT)
724                         *ret = mshot_retry_ret;
725                 else
726                         *ret = -EAGAIN;
727                 return true;
728         }
729
730         /* Finish the request / stop multishot. */
731         io_req_set_res(req, *ret, cflags);
732
733         if (issue_flags & IO_URING_F_MULTISHOT)
734                 *ret = IOU_STOP_MULTISHOT;
735         else
736                 *ret = IOU_OK;
737         return true;
738 }
739
740 static int io_recvmsg_prep_multishot(struct io_async_msghdr *kmsg,
741                                      struct io_sr_msg *sr, void __user **buf,
742                                      size_t *len)
743 {
744         unsigned long ubuf = (unsigned long) *buf;
745         unsigned long hdr;
746
747         hdr = sizeof(struct io_uring_recvmsg_out) + kmsg->namelen +
748                 kmsg->controllen;
749         if (*len < hdr)
750                 return -EFAULT;
751
752         if (kmsg->controllen) {
753                 unsigned long control = ubuf + hdr - kmsg->controllen;
754
755                 kmsg->msg.msg_control_user = (void __user *) control;
756                 kmsg->msg.msg_controllen = kmsg->controllen;
757         }
758
759         sr->buf = *buf; /* stash for later copy */
760         *buf = (void __user *) (ubuf + hdr);
761         kmsg->payloadlen = *len = *len - hdr;
762         return 0;
763 }
764
765 struct io_recvmsg_multishot_hdr {
766         struct io_uring_recvmsg_out msg;
767         struct sockaddr_storage addr;
768 };
769
770 static int io_recvmsg_multishot(struct socket *sock, struct io_sr_msg *io,
771                                 struct io_async_msghdr *kmsg,
772                                 unsigned int flags, bool *finished)
773 {
774         int err;
775         int copy_len;
776         struct io_recvmsg_multishot_hdr hdr;
777
778         if (kmsg->namelen)
779                 kmsg->msg.msg_name = &hdr.addr;
780         kmsg->msg.msg_flags = flags & (MSG_CMSG_CLOEXEC|MSG_CMSG_COMPAT);
781         kmsg->msg.msg_namelen = 0;
782
783         if (sock->file->f_flags & O_NONBLOCK)
784                 flags |= MSG_DONTWAIT;
785
786         err = sock_recvmsg(sock, &kmsg->msg, flags);
787         *finished = err <= 0;
788         if (err < 0)
789                 return err;
790
791         hdr.msg = (struct io_uring_recvmsg_out) {
792                 .controllen = kmsg->controllen - kmsg->msg.msg_controllen,
793                 .flags = kmsg->msg.msg_flags & ~MSG_CMSG_COMPAT
794         };
795
796         hdr.msg.payloadlen = err;
797         if (err > kmsg->payloadlen)
798                 err = kmsg->payloadlen;
799
800         copy_len = sizeof(struct io_uring_recvmsg_out);
801         if (kmsg->msg.msg_namelen > kmsg->namelen)
802                 copy_len += kmsg->namelen;
803         else
804                 copy_len += kmsg->msg.msg_namelen;
805
806         /*
807          *      "fromlen shall refer to the value before truncation.."
808          *                      1003.1g
809          */
810         hdr.msg.namelen = kmsg->msg.msg_namelen;
811
812         /* ensure that there is no gap between hdr and sockaddr_storage */
813         BUILD_BUG_ON(offsetof(struct io_recvmsg_multishot_hdr, addr) !=
814                      sizeof(struct io_uring_recvmsg_out));
815         if (copy_to_user(io->buf, &hdr, copy_len)) {
816                 *finished = true;
817                 return -EFAULT;
818         }
819
820         return sizeof(struct io_uring_recvmsg_out) + kmsg->namelen +
821                         kmsg->controllen + err;
822 }
823
824 int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
825 {
826         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
827         struct io_async_msghdr iomsg, *kmsg;
828         struct socket *sock;
829         unsigned flags;
830         int ret, min_ret = 0;
831         bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
832         bool mshot_finished = true;
833
834         sock = sock_from_file(req->file);
835         if (unlikely(!sock))
836                 return -ENOTSOCK;
837
838         if (req_has_async_data(req)) {
839                 kmsg = req->async_data;
840         } else {
841                 ret = io_recvmsg_copy_hdr(req, &iomsg);
842                 if (ret)
843                         return ret;
844                 kmsg = &iomsg;
845         }
846
847         if (!(req->flags & REQ_F_POLLED) &&
848             (sr->flags & IORING_RECVSEND_POLL_FIRST))
849                 return io_setup_async_msg(req, kmsg, issue_flags);
850
851         flags = sr->msg_flags;
852         if (force_nonblock)
853                 flags |= MSG_DONTWAIT;
854
855 retry_multishot:
856         if (io_do_buffer_select(req)) {
857                 void __user *buf;
858                 size_t len = sr->len;
859
860                 buf = io_buffer_select(req, &len, issue_flags);
861                 if (!buf)
862                         return -ENOBUFS;
863
864                 if (req->flags & REQ_F_APOLL_MULTISHOT) {
865                         ret = io_recvmsg_prep_multishot(kmsg, sr, &buf, &len);
866                         if (ret) {
867                                 io_kbuf_recycle(req, issue_flags);
868                                 return ret;
869                         }
870                 }
871
872                 iov_iter_ubuf(&kmsg->msg.msg_iter, ITER_DEST, buf, len);
873         }
874
875         kmsg->msg.msg_get_inq = 1;
876         kmsg->msg.msg_inq = -1;
877         if (req->flags & REQ_F_APOLL_MULTISHOT) {
878                 ret = io_recvmsg_multishot(sock, sr, kmsg, flags,
879                                            &mshot_finished);
880         } else {
881                 /* disable partial retry for recvmsg with cmsg attached */
882                 if (flags & MSG_WAITALL && !kmsg->msg.msg_controllen)
883                         min_ret = iov_iter_count(&kmsg->msg.msg_iter);
884
885                 ret = __sys_recvmsg_sock(sock, &kmsg->msg, sr->umsg,
886                                          kmsg->uaddr, flags);
887         }
888
889         if (ret < min_ret) {
890                 if (ret == -EAGAIN && force_nonblock) {
891                         ret = io_setup_async_msg(req, kmsg, issue_flags);
892                         if (ret == -EAGAIN && (issue_flags & IO_URING_F_MULTISHOT)) {
893                                 io_kbuf_recycle(req, issue_flags);
894                                 return IOU_ISSUE_SKIP_COMPLETE;
895                         }
896                         return ret;
897                 }
898                 if (ret > 0 && io_net_retry(sock, flags)) {
899                         sr->done_io += ret;
900                         req->flags |= REQ_F_BL_NO_RECYCLE;
901                         return io_setup_async_msg(req, kmsg, issue_flags);
902                 }
903                 if (ret == -ERESTARTSYS)
904                         ret = -EINTR;
905                 req_set_fail(req);
906         } else if ((flags & MSG_WAITALL) && (kmsg->msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))) {
907                 req_set_fail(req);
908         }
909
910         if (ret > 0)
911                 ret += sr->done_io;
912         else if (sr->done_io)
913                 ret = sr->done_io;
914         else
915                 io_kbuf_recycle(req, issue_flags);
916
917         if (!io_recv_finish(req, &ret, &kmsg->msg, mshot_finished, issue_flags))
918                 goto retry_multishot;
919
920         if (mshot_finished)
921                 io_req_msg_cleanup(req, kmsg, issue_flags);
922         else if (ret == -EAGAIN)
923                 return io_setup_async_msg(req, kmsg, issue_flags);
924
925         return ret;
926 }
927
928 int io_recv(struct io_kiocb *req, unsigned int issue_flags)
929 {
930         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
931         struct msghdr msg;
932         struct socket *sock;
933         unsigned flags;
934         int ret, min_ret = 0;
935         bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
936         size_t len = sr->len;
937
938         if (!(req->flags & REQ_F_POLLED) &&
939             (sr->flags & IORING_RECVSEND_POLL_FIRST))
940                 return -EAGAIN;
941
942         sock = sock_from_file(req->file);
943         if (unlikely(!sock))
944                 return -ENOTSOCK;
945
946         msg.msg_name = NULL;
947         msg.msg_namelen = 0;
948         msg.msg_control = NULL;
949         msg.msg_get_inq = 1;
950         msg.msg_controllen = 0;
951         msg.msg_iocb = NULL;
952         msg.msg_ubuf = NULL;
953
954         flags = sr->msg_flags;
955         if (force_nonblock)
956                 flags |= MSG_DONTWAIT;
957
958 retry_multishot:
959         if (io_do_buffer_select(req)) {
960                 void __user *buf;
961
962                 buf = io_buffer_select(req, &len, issue_flags);
963                 if (!buf)
964                         return -ENOBUFS;
965                 sr->buf = buf;
966                 sr->len = len;
967         }
968
969         ret = import_ubuf(ITER_DEST, sr->buf, len, &msg.msg_iter);
970         if (unlikely(ret))
971                 goto out_free;
972
973         msg.msg_inq = -1;
974         msg.msg_flags = 0;
975
976         if (flags & MSG_WAITALL)
977                 min_ret = iov_iter_count(&msg.msg_iter);
978
979         ret = sock_recvmsg(sock, &msg, flags);
980         if (ret < min_ret) {
981                 if (ret == -EAGAIN && force_nonblock) {
982                         if (issue_flags & IO_URING_F_MULTISHOT) {
983                                 io_kbuf_recycle(req, issue_flags);
984                                 return IOU_ISSUE_SKIP_COMPLETE;
985                         }
986
987                         return -EAGAIN;
988                 }
989                 if (ret > 0 && io_net_retry(sock, flags)) {
990                         sr->len -= ret;
991                         sr->buf += ret;
992                         sr->done_io += ret;
993                         req->flags |= REQ_F_BL_NO_RECYCLE;
994                         return -EAGAIN;
995                 }
996                 if (ret == -ERESTARTSYS)
997                         ret = -EINTR;
998                 req_set_fail(req);
999         } else if ((flags & MSG_WAITALL) && (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))) {
1000 out_free:
1001                 req_set_fail(req);
1002         }
1003
1004         if (ret > 0)
1005                 ret += sr->done_io;
1006         else if (sr->done_io)
1007                 ret = sr->done_io;
1008         else
1009                 io_kbuf_recycle(req, issue_flags);
1010
1011         if (!io_recv_finish(req, &ret, &msg, ret <= 0, issue_flags))
1012                 goto retry_multishot;
1013
1014         return ret;
1015 }
1016
1017 void io_send_zc_cleanup(struct io_kiocb *req)
1018 {
1019         struct io_sr_msg *zc = io_kiocb_to_cmd(req, struct io_sr_msg);
1020         struct io_async_msghdr *io;
1021
1022         if (req_has_async_data(req)) {
1023                 io = req->async_data;
1024                 /* might be ->fast_iov if *msg_copy_hdr failed */
1025                 if (io->free_iov != io->fast_iov)
1026                         kfree(io->free_iov);
1027         }
1028         if (zc->notif) {
1029                 io_notif_flush(zc->notif);
1030                 zc->notif = NULL;
1031         }
1032 }
1033
1034 #define IO_ZC_FLAGS_COMMON (IORING_RECVSEND_POLL_FIRST | IORING_RECVSEND_FIXED_BUF)
1035 #define IO_ZC_FLAGS_VALID  (IO_ZC_FLAGS_COMMON | IORING_SEND_ZC_REPORT_USAGE)
1036
1037 int io_send_zc_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
1038 {
1039         struct io_sr_msg *zc = io_kiocb_to_cmd(req, struct io_sr_msg);
1040         struct io_ring_ctx *ctx = req->ctx;
1041         struct io_kiocb *notif;
1042
1043         zc->done_io = 0;
1044
1045         if (unlikely(READ_ONCE(sqe->__pad2[0]) || READ_ONCE(sqe->addr3)))
1046                 return -EINVAL;
1047         /* we don't support IOSQE_CQE_SKIP_SUCCESS just yet */
1048         if (req->flags & REQ_F_CQE_SKIP)
1049                 return -EINVAL;
1050
1051         notif = zc->notif = io_alloc_notif(ctx);
1052         if (!notif)
1053                 return -ENOMEM;
1054         notif->cqe.user_data = req->cqe.user_data;
1055         notif->cqe.res = 0;
1056         notif->cqe.flags = IORING_CQE_F_NOTIF;
1057         req->flags |= REQ_F_NEED_CLEANUP;
1058
1059         zc->flags = READ_ONCE(sqe->ioprio);
1060         if (unlikely(zc->flags & ~IO_ZC_FLAGS_COMMON)) {
1061                 if (zc->flags & ~IO_ZC_FLAGS_VALID)
1062                         return -EINVAL;
1063                 if (zc->flags & IORING_SEND_ZC_REPORT_USAGE) {
1064                         io_notif_set_extended(notif);
1065                         io_notif_to_data(notif)->zc_report = true;
1066                 }
1067         }
1068
1069         if (zc->flags & IORING_RECVSEND_FIXED_BUF) {
1070                 unsigned idx = READ_ONCE(sqe->buf_index);
1071
1072                 if (unlikely(idx >= ctx->nr_user_bufs))
1073                         return -EFAULT;
1074                 idx = array_index_nospec(idx, ctx->nr_user_bufs);
1075                 req->imu = READ_ONCE(ctx->user_bufs[idx]);
1076                 io_req_set_rsrc_node(notif, ctx, 0);
1077         }
1078
1079         if (req->opcode == IORING_OP_SEND_ZC) {
1080                 if (READ_ONCE(sqe->__pad3[0]))
1081                         return -EINVAL;
1082                 zc->addr = u64_to_user_ptr(READ_ONCE(sqe->addr2));
1083                 zc->addr_len = READ_ONCE(sqe->addr_len);
1084         } else {
1085                 if (unlikely(sqe->addr2 || sqe->file_index))
1086                         return -EINVAL;
1087                 if (unlikely(zc->flags & IORING_RECVSEND_FIXED_BUF))
1088                         return -EINVAL;
1089         }
1090
1091         zc->buf = u64_to_user_ptr(READ_ONCE(sqe->addr));
1092         zc->len = READ_ONCE(sqe->len);
1093         zc->msg_flags = READ_ONCE(sqe->msg_flags) | MSG_NOSIGNAL;
1094         if (zc->msg_flags & MSG_DONTWAIT)
1095                 req->flags |= REQ_F_NOWAIT;
1096
1097 #ifdef CONFIG_COMPAT
1098         if (req->ctx->compat)
1099                 zc->msg_flags |= MSG_CMSG_COMPAT;
1100 #endif
1101         return 0;
1102 }
1103
1104 static int io_sg_from_iter_iovec(struct sock *sk, struct sk_buff *skb,
1105                                  struct iov_iter *from, size_t length)
1106 {
1107         skb_zcopy_downgrade_managed(skb);
1108         return __zerocopy_sg_from_iter(NULL, sk, skb, from, length);
1109 }
1110
1111 static int io_sg_from_iter(struct sock *sk, struct sk_buff *skb,
1112                            struct iov_iter *from, size_t length)
1113 {
1114         struct skb_shared_info *shinfo = skb_shinfo(skb);
1115         int frag = shinfo->nr_frags;
1116         int ret = 0;
1117         struct bvec_iter bi;
1118         ssize_t copied = 0;
1119         unsigned long truesize = 0;
1120
1121         if (!frag)
1122                 shinfo->flags |= SKBFL_MANAGED_FRAG_REFS;
1123         else if (unlikely(!skb_zcopy_managed(skb)))
1124                 return __zerocopy_sg_from_iter(NULL, sk, skb, from, length);
1125
1126         bi.bi_size = min(from->count, length);
1127         bi.bi_bvec_done = from->iov_offset;
1128         bi.bi_idx = 0;
1129
1130         while (bi.bi_size && frag < MAX_SKB_FRAGS) {
1131                 struct bio_vec v = mp_bvec_iter_bvec(from->bvec, bi);
1132
1133                 copied += v.bv_len;
1134                 truesize += PAGE_ALIGN(v.bv_len + v.bv_offset);
1135                 __skb_fill_page_desc_noacc(shinfo, frag++, v.bv_page,
1136                                            v.bv_offset, v.bv_len);
1137                 bvec_iter_advance_single(from->bvec, &bi, v.bv_len);
1138         }
1139         if (bi.bi_size)
1140                 ret = -EMSGSIZE;
1141
1142         shinfo->nr_frags = frag;
1143         from->bvec += bi.bi_idx;
1144         from->nr_segs -= bi.bi_idx;
1145         from->count -= copied;
1146         from->iov_offset = bi.bi_bvec_done;
1147
1148         skb->data_len += copied;
1149         skb->len += copied;
1150         skb->truesize += truesize;
1151
1152         if (sk && sk->sk_type == SOCK_STREAM) {
1153                 sk_wmem_queued_add(sk, truesize);
1154                 if (!skb_zcopy_pure(skb))
1155                         sk_mem_charge(sk, truesize);
1156         } else {
1157                 refcount_add(truesize, &skb->sk->sk_wmem_alloc);
1158         }
1159         return ret;
1160 }
1161
1162 int io_send_zc(struct io_kiocb *req, unsigned int issue_flags)
1163 {
1164         struct sockaddr_storage __address;
1165         struct io_sr_msg *zc = io_kiocb_to_cmd(req, struct io_sr_msg);
1166         struct msghdr msg;
1167         struct socket *sock;
1168         unsigned msg_flags;
1169         int ret, min_ret = 0;
1170
1171         sock = sock_from_file(req->file);
1172         if (unlikely(!sock))
1173                 return -ENOTSOCK;
1174         if (!test_bit(SOCK_SUPPORT_ZC, &sock->flags))
1175                 return -EOPNOTSUPP;
1176
1177         msg.msg_name = NULL;
1178         msg.msg_control = NULL;
1179         msg.msg_controllen = 0;
1180         msg.msg_namelen = 0;
1181
1182         if (zc->addr) {
1183                 if (req_has_async_data(req)) {
1184                         struct io_async_msghdr *io = req->async_data;
1185
1186                         msg.msg_name = &io->addr;
1187                 } else {
1188                         ret = move_addr_to_kernel(zc->addr, zc->addr_len, &__address);
1189                         if (unlikely(ret < 0))
1190                                 return ret;
1191                         msg.msg_name = (struct sockaddr *)&__address;
1192                 }
1193                 msg.msg_namelen = zc->addr_len;
1194         }
1195
1196         if (!(req->flags & REQ_F_POLLED) &&
1197             (zc->flags & IORING_RECVSEND_POLL_FIRST))
1198                 return io_setup_async_addr(req, &__address, issue_flags);
1199
1200         if (zc->flags & IORING_RECVSEND_FIXED_BUF) {
1201                 ret = io_import_fixed(ITER_SOURCE, &msg.msg_iter, req->imu,
1202                                         (u64)(uintptr_t)zc->buf, zc->len);
1203                 if (unlikely(ret))
1204                         return ret;
1205                 msg.sg_from_iter = io_sg_from_iter;
1206         } else {
1207                 io_notif_set_extended(zc->notif);
1208                 ret = import_ubuf(ITER_SOURCE, zc->buf, zc->len, &msg.msg_iter);
1209                 if (unlikely(ret))
1210                         return ret;
1211                 ret = io_notif_account_mem(zc->notif, zc->len);
1212                 if (unlikely(ret))
1213                         return ret;
1214                 msg.sg_from_iter = io_sg_from_iter_iovec;
1215         }
1216
1217         msg_flags = zc->msg_flags | MSG_ZEROCOPY;
1218         if (issue_flags & IO_URING_F_NONBLOCK)
1219                 msg_flags |= MSG_DONTWAIT;
1220         if (msg_flags & MSG_WAITALL)
1221                 min_ret = iov_iter_count(&msg.msg_iter);
1222         msg_flags &= ~MSG_INTERNAL_SENDMSG_FLAGS;
1223
1224         msg.msg_flags = msg_flags;
1225         msg.msg_ubuf = &io_notif_to_data(zc->notif)->uarg;
1226         ret = sock_sendmsg(sock, &msg);
1227
1228         if (unlikely(ret < min_ret)) {
1229                 if (ret == -EAGAIN && (issue_flags & IO_URING_F_NONBLOCK))
1230                         return io_setup_async_addr(req, &__address, issue_flags);
1231
1232                 if (ret > 0 && io_net_retry(sock, msg.msg_flags)) {
1233                         zc->len -= ret;
1234                         zc->buf += ret;
1235                         zc->done_io += ret;
1236                         req->flags |= REQ_F_BL_NO_RECYCLE;
1237                         return io_setup_async_addr(req, &__address, issue_flags);
1238                 }
1239                 if (ret == -ERESTARTSYS)
1240                         ret = -EINTR;
1241                 req_set_fail(req);
1242         }
1243
1244         if (ret >= 0)
1245                 ret += zc->done_io;
1246         else if (zc->done_io)
1247                 ret = zc->done_io;
1248
1249         /*
1250          * If we're in io-wq we can't rely on tw ordering guarantees, defer
1251          * flushing notif to io_send_zc_cleanup()
1252          */
1253         if (!(issue_flags & IO_URING_F_UNLOCKED)) {
1254                 io_notif_flush(zc->notif);
1255                 req->flags &= ~REQ_F_NEED_CLEANUP;
1256         }
1257         io_req_set_res(req, ret, IORING_CQE_F_MORE);
1258         return IOU_OK;
1259 }
1260
1261 int io_sendmsg_zc(struct io_kiocb *req, unsigned int issue_flags)
1262 {
1263         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
1264         struct io_async_msghdr iomsg, *kmsg;
1265         struct socket *sock;
1266         unsigned flags;
1267         int ret, min_ret = 0;
1268
1269         io_notif_set_extended(sr->notif);
1270
1271         sock = sock_from_file(req->file);
1272         if (unlikely(!sock))
1273                 return -ENOTSOCK;
1274         if (!test_bit(SOCK_SUPPORT_ZC, &sock->flags))
1275                 return -EOPNOTSUPP;
1276
1277         if (req_has_async_data(req)) {
1278                 kmsg = req->async_data;
1279         } else {
1280                 ret = io_sendmsg_copy_hdr(req, &iomsg);
1281                 if (ret)
1282                         return ret;
1283                 kmsg = &iomsg;
1284         }
1285
1286         if (!(req->flags & REQ_F_POLLED) &&
1287             (sr->flags & IORING_RECVSEND_POLL_FIRST))
1288                 return io_setup_async_msg(req, kmsg, issue_flags);
1289
1290         flags = sr->msg_flags | MSG_ZEROCOPY;
1291         if (issue_flags & IO_URING_F_NONBLOCK)
1292                 flags |= MSG_DONTWAIT;
1293         if (flags & MSG_WAITALL)
1294                 min_ret = iov_iter_count(&kmsg->msg.msg_iter);
1295
1296         kmsg->msg.msg_ubuf = &io_notif_to_data(sr->notif)->uarg;
1297         kmsg->msg.sg_from_iter = io_sg_from_iter_iovec;
1298         ret = __sys_sendmsg_sock(sock, &kmsg->msg, flags);
1299
1300         if (unlikely(ret < min_ret)) {
1301                 if (ret == -EAGAIN && (issue_flags & IO_URING_F_NONBLOCK))
1302                         return io_setup_async_msg(req, kmsg, issue_flags);
1303
1304                 if (ret > 0 && io_net_retry(sock, flags)) {
1305                         sr->done_io += ret;
1306                         req->flags |= REQ_F_BL_NO_RECYCLE;
1307                         return io_setup_async_msg(req, kmsg, issue_flags);
1308                 }
1309                 if (ret == -ERESTARTSYS)
1310                         ret = -EINTR;
1311                 req_set_fail(req);
1312         }
1313         /* fast path, check for non-NULL to avoid function call */
1314         if (kmsg->free_iov) {
1315                 kfree(kmsg->free_iov);
1316                 kmsg->free_iov = NULL;
1317         }
1318
1319         io_netmsg_recycle(req, issue_flags);
1320         if (ret >= 0)
1321                 ret += sr->done_io;
1322         else if (sr->done_io)
1323                 ret = sr->done_io;
1324
1325         /*
1326          * If we're in io-wq we can't rely on tw ordering guarantees, defer
1327          * flushing notif to io_send_zc_cleanup()
1328          */
1329         if (!(issue_flags & IO_URING_F_UNLOCKED)) {
1330                 io_notif_flush(sr->notif);
1331                 req->flags &= ~REQ_F_NEED_CLEANUP;
1332         }
1333         io_req_set_res(req, ret, IORING_CQE_F_MORE);
1334         return IOU_OK;
1335 }
1336
1337 void io_sendrecv_fail(struct io_kiocb *req)
1338 {
1339         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
1340
1341         if (sr->done_io)
1342                 req->cqe.res = sr->done_io;
1343
1344         if ((req->flags & REQ_F_NEED_CLEANUP) &&
1345             (req->opcode == IORING_OP_SEND_ZC || req->opcode == IORING_OP_SENDMSG_ZC))
1346                 req->cqe.flags |= IORING_CQE_F_MORE;
1347 }
1348
1349 int io_accept_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
1350 {
1351         struct io_accept *accept = io_kiocb_to_cmd(req, struct io_accept);
1352         unsigned flags;
1353
1354         if (sqe->len || sqe->buf_index)
1355                 return -EINVAL;
1356
1357         accept->addr = u64_to_user_ptr(READ_ONCE(sqe->addr));
1358         accept->addr_len = u64_to_user_ptr(READ_ONCE(sqe->addr2));
1359         accept->flags = READ_ONCE(sqe->accept_flags);
1360         accept->nofile = rlimit(RLIMIT_NOFILE);
1361         flags = READ_ONCE(sqe->ioprio);
1362         if (flags & ~IORING_ACCEPT_MULTISHOT)
1363                 return -EINVAL;
1364
1365         accept->file_slot = READ_ONCE(sqe->file_index);
1366         if (accept->file_slot) {
1367                 if (accept->flags & SOCK_CLOEXEC)
1368                         return -EINVAL;
1369                 if (flags & IORING_ACCEPT_MULTISHOT &&
1370                     accept->file_slot != IORING_FILE_INDEX_ALLOC)
1371                         return -EINVAL;
1372         }
1373         if (accept->flags & ~(SOCK_CLOEXEC | SOCK_NONBLOCK))
1374                 return -EINVAL;
1375         if (SOCK_NONBLOCK != O_NONBLOCK && (accept->flags & SOCK_NONBLOCK))
1376                 accept->flags = (accept->flags & ~SOCK_NONBLOCK) | O_NONBLOCK;
1377         if (flags & IORING_ACCEPT_MULTISHOT)
1378                 req->flags |= REQ_F_APOLL_MULTISHOT;
1379         return 0;
1380 }
1381
1382 int io_accept(struct io_kiocb *req, unsigned int issue_flags)
1383 {
1384         struct io_accept *accept = io_kiocb_to_cmd(req, struct io_accept);
1385         bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
1386         unsigned int file_flags = force_nonblock ? O_NONBLOCK : 0;
1387         bool fixed = !!accept->file_slot;
1388         struct file *file;
1389         int ret, fd;
1390
1391 retry:
1392         if (!fixed) {
1393                 fd = __get_unused_fd_flags(accept->flags, accept->nofile);
1394                 if (unlikely(fd < 0))
1395                         return fd;
1396         }
1397         file = do_accept(req->file, file_flags, accept->addr, accept->addr_len,
1398                          accept->flags);
1399         if (IS_ERR(file)) {
1400                 if (!fixed)
1401                         put_unused_fd(fd);
1402                 ret = PTR_ERR(file);
1403                 if (ret == -EAGAIN && force_nonblock) {
1404                         /*
1405                          * if it's multishot and polled, we don't need to
1406                          * return EAGAIN to arm the poll infra since it
1407                          * has already been done
1408                          */
1409                         if (issue_flags & IO_URING_F_MULTISHOT)
1410                                 return IOU_ISSUE_SKIP_COMPLETE;
1411                         return ret;
1412                 }
1413                 if (ret == -ERESTARTSYS)
1414                         ret = -EINTR;
1415                 req_set_fail(req);
1416         } else if (!fixed) {
1417                 fd_install(fd, file);
1418                 ret = fd;
1419         } else {
1420                 ret = io_fixed_fd_install(req, issue_flags, file,
1421                                                 accept->file_slot);
1422         }
1423
1424         if (!(req->flags & REQ_F_APOLL_MULTISHOT)) {
1425                 io_req_set_res(req, ret, 0);
1426                 return IOU_OK;
1427         }
1428
1429         if (ret < 0)
1430                 return ret;
1431         if (io_fill_cqe_req_aux(req, issue_flags & IO_URING_F_COMPLETE_DEFER,
1432                                 ret, IORING_CQE_F_MORE))
1433                 goto retry;
1434
1435         io_req_set_res(req, ret, 0);
1436         return IOU_STOP_MULTISHOT;
1437 }
1438
1439 int io_socket_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
1440 {
1441         struct io_socket *sock = io_kiocb_to_cmd(req, struct io_socket);
1442
1443         if (sqe->addr || sqe->rw_flags || sqe->buf_index)
1444                 return -EINVAL;
1445
1446         sock->domain = READ_ONCE(sqe->fd);
1447         sock->type = READ_ONCE(sqe->off);
1448         sock->protocol = READ_ONCE(sqe->len);
1449         sock->file_slot = READ_ONCE(sqe->file_index);
1450         sock->nofile = rlimit(RLIMIT_NOFILE);
1451
1452         sock->flags = sock->type & ~SOCK_TYPE_MASK;
1453         if (sock->file_slot && (sock->flags & SOCK_CLOEXEC))
1454                 return -EINVAL;
1455         if (sock->flags & ~(SOCK_CLOEXEC | SOCK_NONBLOCK))
1456                 return -EINVAL;
1457         return 0;
1458 }
1459
1460 int io_socket(struct io_kiocb *req, unsigned int issue_flags)
1461 {
1462         struct io_socket *sock = io_kiocb_to_cmd(req, struct io_socket);
1463         bool fixed = !!sock->file_slot;
1464         struct file *file;
1465         int ret, fd;
1466
1467         if (!fixed) {
1468                 fd = __get_unused_fd_flags(sock->flags, sock->nofile);
1469                 if (unlikely(fd < 0))
1470                         return fd;
1471         }
1472         file = __sys_socket_file(sock->domain, sock->type, sock->protocol);
1473         if (IS_ERR(file)) {
1474                 if (!fixed)
1475                         put_unused_fd(fd);
1476                 ret = PTR_ERR(file);
1477                 if (ret == -EAGAIN && (issue_flags & IO_URING_F_NONBLOCK))
1478                         return -EAGAIN;
1479                 if (ret == -ERESTARTSYS)
1480                         ret = -EINTR;
1481                 req_set_fail(req);
1482         } else if (!fixed) {
1483                 fd_install(fd, file);
1484                 ret = fd;
1485         } else {
1486                 ret = io_fixed_fd_install(req, issue_flags, file,
1487                                             sock->file_slot);
1488         }
1489         io_req_set_res(req, ret, 0);
1490         return IOU_OK;
1491 }
1492
1493 int io_connect_prep_async(struct io_kiocb *req)
1494 {
1495         struct io_async_connect *io = req->async_data;
1496         struct io_connect *conn = io_kiocb_to_cmd(req, struct io_connect);
1497
1498         return move_addr_to_kernel(conn->addr, conn->addr_len, &io->address);
1499 }
1500
1501 int io_connect_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
1502 {
1503         struct io_connect *conn = io_kiocb_to_cmd(req, struct io_connect);
1504
1505         if (sqe->len || sqe->buf_index || sqe->rw_flags || sqe->splice_fd_in)
1506                 return -EINVAL;
1507
1508         conn->addr = u64_to_user_ptr(READ_ONCE(sqe->addr));
1509         conn->addr_len =  READ_ONCE(sqe->addr2);
1510         conn->in_progress = conn->seen_econnaborted = false;
1511         return 0;
1512 }
1513
1514 int io_connect(struct io_kiocb *req, unsigned int issue_flags)
1515 {
1516         struct io_connect *connect = io_kiocb_to_cmd(req, struct io_connect);
1517         struct io_async_connect __io, *io;
1518         unsigned file_flags;
1519         int ret;
1520         bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
1521
1522         if (req_has_async_data(req)) {
1523                 io = req->async_data;
1524         } else {
1525                 ret = move_addr_to_kernel(connect->addr,
1526                                                 connect->addr_len,
1527                                                 &__io.address);
1528                 if (ret)
1529                         goto out;
1530                 io = &__io;
1531         }
1532
1533         file_flags = force_nonblock ? O_NONBLOCK : 0;
1534
1535         ret = __sys_connect_file(req->file, &io->address,
1536                                         connect->addr_len, file_flags);
1537         if ((ret == -EAGAIN || ret == -EINPROGRESS || ret == -ECONNABORTED)
1538             && force_nonblock) {
1539                 if (ret == -EINPROGRESS) {
1540                         connect->in_progress = true;
1541                 } else if (ret == -ECONNABORTED) {
1542                         if (connect->seen_econnaborted)
1543                                 goto out;
1544                         connect->seen_econnaborted = true;
1545                 }
1546                 if (req_has_async_data(req))
1547                         return -EAGAIN;
1548                 if (io_alloc_async_data(req)) {
1549                         ret = -ENOMEM;
1550                         goto out;
1551                 }
1552                 memcpy(req->async_data, &__io, sizeof(__io));
1553                 return -EAGAIN;
1554         }
1555         if (connect->in_progress) {
1556                 /*
1557                  * At least bluetooth will return -EBADFD on a re-connect
1558                  * attempt, and it's (supposedly) also valid to get -EISCONN
1559                  * which means the previous result is good. For both of these,
1560                  * grab the sock_error() and use that for the completion.
1561                  */
1562                 if (ret == -EBADFD || ret == -EISCONN)
1563                         ret = sock_error(sock_from_file(req->file)->sk);
1564         }
1565         if (ret == -ERESTARTSYS)
1566                 ret = -EINTR;
1567 out:
1568         if (ret < 0)
1569                 req_set_fail(req);
1570         io_req_set_res(req, ret, 0);
1571         return IOU_OK;
1572 }
1573
1574 void io_netmsg_cache_free(struct io_cache_entry *entry)
1575 {
1576         kfree(container_of(entry, struct io_async_msghdr, cache));
1577 }
1578 #endif