ionic: update documentation for XDP support
[sfrench/cifs-2.6.git] / io_uring / net.c
1 // SPDX-License-Identifier: GPL-2.0
2 #include <linux/kernel.h>
3 #include <linux/errno.h>
4 #include <linux/file.h>
5 #include <linux/slab.h>
6 #include <linux/net.h>
7 #include <linux/compat.h>
8 #include <net/compat.h>
9 #include <linux/io_uring.h>
10
11 #include <uapi/linux/io_uring.h>
12
13 #include "io_uring.h"
14 #include "kbuf.h"
15 #include "alloc_cache.h"
16 #include "net.h"
17 #include "notif.h"
18 #include "rsrc.h"
19
20 #if defined(CONFIG_NET)
21 struct io_shutdown {
22         struct file                     *file;
23         int                             how;
24 };
25
26 struct io_accept {
27         struct file                     *file;
28         struct sockaddr __user          *addr;
29         int __user                      *addr_len;
30         int                             flags;
31         u32                             file_slot;
32         unsigned long                   nofile;
33 };
34
35 struct io_socket {
36         struct file                     *file;
37         int                             domain;
38         int                             type;
39         int                             protocol;
40         int                             flags;
41         u32                             file_slot;
42         unsigned long                   nofile;
43 };
44
45 struct io_connect {
46         struct file                     *file;
47         struct sockaddr __user          *addr;
48         int                             addr_len;
49         bool                            in_progress;
50         bool                            seen_econnaborted;
51 };
52
53 struct io_sr_msg {
54         struct file                     *file;
55         union {
56                 struct compat_msghdr __user     *umsg_compat;
57                 struct user_msghdr __user       *umsg;
58                 void __user                     *buf;
59         };
60         unsigned                        len;
61         unsigned                        done_io;
62         unsigned                        msg_flags;
63         unsigned                        nr_multishot_loops;
64         u16                             flags;
65         /* initialised and used only by !msg send variants */
66         u16                             addr_len;
67         u16                             buf_group;
68         void __user                     *addr;
69         void __user                     *msg_control;
70         /* used only for send zerocopy */
71         struct io_kiocb                 *notif;
72 };
73
74 /*
75  * Number of times we'll try and do receives if there's more data. If we
76  * exceed this limit, then add us to the back of the queue and retry from
77  * there. This helps fairness between flooding clients.
78  */
79 #define MULTISHOT_MAX_RETRY     32
80
81 int io_shutdown_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
82 {
83         struct io_shutdown *shutdown = io_kiocb_to_cmd(req, struct io_shutdown);
84
85         if (unlikely(sqe->off || sqe->addr || sqe->rw_flags ||
86                      sqe->buf_index || sqe->splice_fd_in))
87                 return -EINVAL;
88
89         shutdown->how = READ_ONCE(sqe->len);
90         req->flags |= REQ_F_FORCE_ASYNC;
91         return 0;
92 }
93
94 int io_shutdown(struct io_kiocb *req, unsigned int issue_flags)
95 {
96         struct io_shutdown *shutdown = io_kiocb_to_cmd(req, struct io_shutdown);
97         struct socket *sock;
98         int ret;
99
100         WARN_ON_ONCE(issue_flags & IO_URING_F_NONBLOCK);
101
102         sock = sock_from_file(req->file);
103         if (unlikely(!sock))
104                 return -ENOTSOCK;
105
106         ret = __sys_shutdown_sock(sock, shutdown->how);
107         io_req_set_res(req, ret, 0);
108         return IOU_OK;
109 }
110
111 static bool io_net_retry(struct socket *sock, int flags)
112 {
113         if (!(flags & MSG_WAITALL))
114                 return false;
115         return sock->type == SOCK_STREAM || sock->type == SOCK_SEQPACKET;
116 }
117
118 static void io_netmsg_recycle(struct io_kiocb *req, unsigned int issue_flags)
119 {
120         struct io_async_msghdr *hdr = req->async_data;
121
122         if (!req_has_async_data(req) || issue_flags & IO_URING_F_UNLOCKED)
123                 return;
124
125         /* Let normal cleanup path reap it if we fail adding to the cache */
126         if (io_alloc_cache_put(&req->ctx->netmsg_cache, &hdr->cache)) {
127                 req->async_data = NULL;
128                 req->flags &= ~REQ_F_ASYNC_DATA;
129         }
130 }
131
132 static struct io_async_msghdr *io_msg_alloc_async(struct io_kiocb *req,
133                                                   unsigned int issue_flags)
134 {
135         struct io_ring_ctx *ctx = req->ctx;
136         struct io_cache_entry *entry;
137         struct io_async_msghdr *hdr;
138
139         if (!(issue_flags & IO_URING_F_UNLOCKED)) {
140                 entry = io_alloc_cache_get(&ctx->netmsg_cache);
141                 if (entry) {
142                         hdr = container_of(entry, struct io_async_msghdr, cache);
143                         hdr->free_iov = NULL;
144                         req->flags |= REQ_F_ASYNC_DATA;
145                         req->async_data = hdr;
146                         return hdr;
147                 }
148         }
149
150         if (!io_alloc_async_data(req)) {
151                 hdr = req->async_data;
152                 hdr->free_iov = NULL;
153                 return hdr;
154         }
155         return NULL;
156 }
157
158 static inline struct io_async_msghdr *io_msg_alloc_async_prep(struct io_kiocb *req)
159 {
160         /* ->prep_async is always called from the submission context */
161         return io_msg_alloc_async(req, 0);
162 }
163
164 static int io_setup_async_msg(struct io_kiocb *req,
165                               struct io_async_msghdr *kmsg,
166                               unsigned int issue_flags)
167 {
168         struct io_async_msghdr *async_msg;
169
170         if (req_has_async_data(req))
171                 return -EAGAIN;
172         async_msg = io_msg_alloc_async(req, issue_flags);
173         if (!async_msg) {
174                 kfree(kmsg->free_iov);
175                 return -ENOMEM;
176         }
177         req->flags |= REQ_F_NEED_CLEANUP;
178         memcpy(async_msg, kmsg, sizeof(*kmsg));
179         if (async_msg->msg.msg_name)
180                 async_msg->msg.msg_name = &async_msg->addr;
181
182         if ((req->flags & REQ_F_BUFFER_SELECT) && !async_msg->msg.msg_iter.nr_segs)
183                 return -EAGAIN;
184
185         /* if were using fast_iov, set it to the new one */
186         if (iter_is_iovec(&kmsg->msg.msg_iter) && !kmsg->free_iov) {
187                 size_t fast_idx = iter_iov(&kmsg->msg.msg_iter) - kmsg->fast_iov;
188                 async_msg->msg.msg_iter.__iov = &async_msg->fast_iov[fast_idx];
189         }
190
191         return -EAGAIN;
192 }
193
194 #ifdef CONFIG_COMPAT
195 static int io_compat_msg_copy_hdr(struct io_kiocb *req,
196                                   struct io_async_msghdr *iomsg,
197                                   struct compat_msghdr *msg, int ddir)
198 {
199         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
200         struct compat_iovec __user *uiov;
201         int ret;
202
203         if (copy_from_user(msg, sr->umsg_compat, sizeof(*msg)))
204                 return -EFAULT;
205
206         uiov = compat_ptr(msg->msg_iov);
207         if (req->flags & REQ_F_BUFFER_SELECT) {
208                 compat_ssize_t clen;
209
210                 iomsg->free_iov = NULL;
211                 if (msg->msg_iovlen == 0) {
212                         sr->len = 0;
213                 } else if (msg->msg_iovlen > 1) {
214                         return -EINVAL;
215                 } else {
216                         if (!access_ok(uiov, sizeof(*uiov)))
217                                 return -EFAULT;
218                         if (__get_user(clen, &uiov->iov_len))
219                                 return -EFAULT;
220                         if (clen < 0)
221                                 return -EINVAL;
222                         sr->len = clen;
223                 }
224
225                 return 0;
226         }
227
228         iomsg->free_iov = iomsg->fast_iov;
229         ret = __import_iovec(ddir, (struct iovec __user *)uiov, msg->msg_iovlen,
230                                 UIO_FASTIOV, &iomsg->free_iov,
231                                 &iomsg->msg.msg_iter, true);
232         if (unlikely(ret < 0))
233                 return ret;
234
235         return 0;
236 }
237 #endif
238
239 static int io_msg_copy_hdr(struct io_kiocb *req, struct io_async_msghdr *iomsg,
240                            struct user_msghdr *msg, int ddir)
241 {
242         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
243         int ret;
244
245         if (!user_access_begin(sr->umsg, sizeof(*sr->umsg)))
246                 return -EFAULT;
247
248         ret = -EFAULT;
249         unsafe_get_user(msg->msg_name, &sr->umsg->msg_name, ua_end);
250         unsafe_get_user(msg->msg_namelen, &sr->umsg->msg_namelen, ua_end);
251         unsafe_get_user(msg->msg_iov, &sr->umsg->msg_iov, ua_end);
252         unsafe_get_user(msg->msg_iovlen, &sr->umsg->msg_iovlen, ua_end);
253         unsafe_get_user(msg->msg_control, &sr->umsg->msg_control, ua_end);
254         unsafe_get_user(msg->msg_controllen, &sr->umsg->msg_controllen, ua_end);
255         msg->msg_flags = 0;
256
257         if (req->flags & REQ_F_BUFFER_SELECT) {
258                 if (msg->msg_iovlen == 0) {
259                         sr->len = iomsg->fast_iov[0].iov_len = 0;
260                         iomsg->fast_iov[0].iov_base = NULL;
261                         iomsg->free_iov = NULL;
262                 } else if (msg->msg_iovlen > 1) {
263                         ret = -EINVAL;
264                         goto ua_end;
265                 } else {
266                         /* we only need the length for provided buffers */
267                         if (!access_ok(&msg->msg_iov[0].iov_len, sizeof(__kernel_size_t)))
268                                 goto ua_end;
269                         unsafe_get_user(iomsg->fast_iov[0].iov_len,
270                                         &msg->msg_iov[0].iov_len, ua_end);
271                         sr->len = iomsg->fast_iov[0].iov_len;
272                         iomsg->free_iov = NULL;
273                 }
274                 ret = 0;
275 ua_end:
276                 user_access_end();
277                 return ret;
278         }
279
280         user_access_end();
281         iomsg->free_iov = iomsg->fast_iov;
282         ret = __import_iovec(ddir, msg->msg_iov, msg->msg_iovlen, UIO_FASTIOV,
283                                 &iomsg->free_iov, &iomsg->msg.msg_iter, false);
284         if (unlikely(ret < 0))
285                 return ret;
286
287         return 0;
288 }
289
290 static int io_sendmsg_copy_hdr(struct io_kiocb *req,
291                                struct io_async_msghdr *iomsg)
292 {
293         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
294         struct user_msghdr msg;
295         int ret;
296
297         iomsg->msg.msg_name = &iomsg->addr;
298         iomsg->msg.msg_iter.nr_segs = 0;
299
300 #ifdef CONFIG_COMPAT
301         if (unlikely(req->ctx->compat)) {
302                 struct compat_msghdr cmsg;
303
304                 ret = io_compat_msg_copy_hdr(req, iomsg, &cmsg, ITER_SOURCE);
305                 if (unlikely(ret))
306                         return ret;
307
308                 return __get_compat_msghdr(&iomsg->msg, &cmsg, NULL);
309         }
310 #endif
311
312         ret = io_msg_copy_hdr(req, iomsg, &msg, ITER_SOURCE);
313         if (unlikely(ret))
314                 return ret;
315
316         ret = __copy_msghdr(&iomsg->msg, &msg, NULL);
317
318         /* save msg_control as sys_sendmsg() overwrites it */
319         sr->msg_control = iomsg->msg.msg_control_user;
320         return ret;
321 }
322
323 int io_send_prep_async(struct io_kiocb *req)
324 {
325         struct io_sr_msg *zc = io_kiocb_to_cmd(req, struct io_sr_msg);
326         struct io_async_msghdr *io;
327         int ret;
328
329         if (!zc->addr || req_has_async_data(req))
330                 return 0;
331         io = io_msg_alloc_async_prep(req);
332         if (!io)
333                 return -ENOMEM;
334         ret = move_addr_to_kernel(zc->addr, zc->addr_len, &io->addr);
335         return ret;
336 }
337
338 static int io_setup_async_addr(struct io_kiocb *req,
339                               struct sockaddr_storage *addr_storage,
340                               unsigned int issue_flags)
341 {
342         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
343         struct io_async_msghdr *io;
344
345         if (!sr->addr || req_has_async_data(req))
346                 return -EAGAIN;
347         io = io_msg_alloc_async(req, issue_flags);
348         if (!io)
349                 return -ENOMEM;
350         memcpy(&io->addr, addr_storage, sizeof(io->addr));
351         return -EAGAIN;
352 }
353
354 int io_sendmsg_prep_async(struct io_kiocb *req)
355 {
356         int ret;
357
358         if (!io_msg_alloc_async_prep(req))
359                 return -ENOMEM;
360         ret = io_sendmsg_copy_hdr(req, req->async_data);
361         if (!ret)
362                 req->flags |= REQ_F_NEED_CLEANUP;
363         return ret;
364 }
365
366 void io_sendmsg_recvmsg_cleanup(struct io_kiocb *req)
367 {
368         struct io_async_msghdr *io = req->async_data;
369
370         kfree(io->free_iov);
371 }
372
373 int io_sendmsg_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
374 {
375         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
376
377         sr->done_io = 0;
378
379         if (req->opcode == IORING_OP_SEND) {
380                 if (READ_ONCE(sqe->__pad3[0]))
381                         return -EINVAL;
382                 sr->addr = u64_to_user_ptr(READ_ONCE(sqe->addr2));
383                 sr->addr_len = READ_ONCE(sqe->addr_len);
384         } else if (sqe->addr2 || sqe->file_index) {
385                 return -EINVAL;
386         }
387
388         sr->umsg = u64_to_user_ptr(READ_ONCE(sqe->addr));
389         sr->len = READ_ONCE(sqe->len);
390         sr->flags = READ_ONCE(sqe->ioprio);
391         if (sr->flags & ~IORING_RECVSEND_POLL_FIRST)
392                 return -EINVAL;
393         sr->msg_flags = READ_ONCE(sqe->msg_flags) | MSG_NOSIGNAL;
394         if (sr->msg_flags & MSG_DONTWAIT)
395                 req->flags |= REQ_F_NOWAIT;
396
397 #ifdef CONFIG_COMPAT
398         if (req->ctx->compat)
399                 sr->msg_flags |= MSG_CMSG_COMPAT;
400 #endif
401         return 0;
402 }
403
404 static void io_req_msg_cleanup(struct io_kiocb *req,
405                                struct io_async_msghdr *kmsg,
406                                unsigned int issue_flags)
407 {
408         req->flags &= ~REQ_F_NEED_CLEANUP;
409         /* fast path, check for non-NULL to avoid function call */
410         if (kmsg->free_iov)
411                 kfree(kmsg->free_iov);
412         io_netmsg_recycle(req, issue_flags);
413 }
414
415 int io_sendmsg(struct io_kiocb *req, unsigned int issue_flags)
416 {
417         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
418         struct io_async_msghdr iomsg, *kmsg;
419         struct socket *sock;
420         unsigned flags;
421         int min_ret = 0;
422         int ret;
423
424         sock = sock_from_file(req->file);
425         if (unlikely(!sock))
426                 return -ENOTSOCK;
427
428         if (req_has_async_data(req)) {
429                 kmsg = req->async_data;
430                 kmsg->msg.msg_control_user = sr->msg_control;
431         } else {
432                 ret = io_sendmsg_copy_hdr(req, &iomsg);
433                 if (ret)
434                         return ret;
435                 kmsg = &iomsg;
436         }
437
438         if (!(req->flags & REQ_F_POLLED) &&
439             (sr->flags & IORING_RECVSEND_POLL_FIRST))
440                 return io_setup_async_msg(req, kmsg, issue_flags);
441
442         flags = sr->msg_flags;
443         if (issue_flags & IO_URING_F_NONBLOCK)
444                 flags |= MSG_DONTWAIT;
445         if (flags & MSG_WAITALL)
446                 min_ret = iov_iter_count(&kmsg->msg.msg_iter);
447
448         ret = __sys_sendmsg_sock(sock, &kmsg->msg, flags);
449
450         if (ret < min_ret) {
451                 if (ret == -EAGAIN && (issue_flags & IO_URING_F_NONBLOCK))
452                         return io_setup_async_msg(req, kmsg, issue_flags);
453                 if (ret > 0 && io_net_retry(sock, flags)) {
454                         kmsg->msg.msg_controllen = 0;
455                         kmsg->msg.msg_control = NULL;
456                         sr->done_io += ret;
457                         req->flags |= REQ_F_BL_NO_RECYCLE;
458                         return io_setup_async_msg(req, kmsg, issue_flags);
459                 }
460                 if (ret == -ERESTARTSYS)
461                         ret = -EINTR;
462                 req_set_fail(req);
463         }
464         io_req_msg_cleanup(req, kmsg, issue_flags);
465         if (ret >= 0)
466                 ret += sr->done_io;
467         else if (sr->done_io)
468                 ret = sr->done_io;
469         io_req_set_res(req, ret, 0);
470         return IOU_OK;
471 }
472
473 int io_send(struct io_kiocb *req, unsigned int issue_flags)
474 {
475         struct sockaddr_storage __address;
476         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
477         struct msghdr msg;
478         struct socket *sock;
479         unsigned flags;
480         int min_ret = 0;
481         int ret;
482
483         msg.msg_name = NULL;
484         msg.msg_control = NULL;
485         msg.msg_controllen = 0;
486         msg.msg_namelen = 0;
487         msg.msg_ubuf = NULL;
488
489         if (sr->addr) {
490                 if (req_has_async_data(req)) {
491                         struct io_async_msghdr *io = req->async_data;
492
493                         msg.msg_name = &io->addr;
494                 } else {
495                         ret = move_addr_to_kernel(sr->addr, sr->addr_len, &__address);
496                         if (unlikely(ret < 0))
497                                 return ret;
498                         msg.msg_name = (struct sockaddr *)&__address;
499                 }
500                 msg.msg_namelen = sr->addr_len;
501         }
502
503         if (!(req->flags & REQ_F_POLLED) &&
504             (sr->flags & IORING_RECVSEND_POLL_FIRST))
505                 return io_setup_async_addr(req, &__address, issue_flags);
506
507         sock = sock_from_file(req->file);
508         if (unlikely(!sock))
509                 return -ENOTSOCK;
510
511         ret = import_ubuf(ITER_SOURCE, sr->buf, sr->len, &msg.msg_iter);
512         if (unlikely(ret))
513                 return ret;
514
515         flags = sr->msg_flags;
516         if (issue_flags & IO_URING_F_NONBLOCK)
517                 flags |= MSG_DONTWAIT;
518         if (flags & MSG_WAITALL)
519                 min_ret = iov_iter_count(&msg.msg_iter);
520
521         flags &= ~MSG_INTERNAL_SENDMSG_FLAGS;
522         msg.msg_flags = flags;
523         ret = sock_sendmsg(sock, &msg);
524         if (ret < min_ret) {
525                 if (ret == -EAGAIN && (issue_flags & IO_URING_F_NONBLOCK))
526                         return io_setup_async_addr(req, &__address, issue_flags);
527
528                 if (ret > 0 && io_net_retry(sock, flags)) {
529                         sr->len -= ret;
530                         sr->buf += ret;
531                         sr->done_io += ret;
532                         req->flags |= REQ_F_BL_NO_RECYCLE;
533                         return io_setup_async_addr(req, &__address, issue_flags);
534                 }
535                 if (ret == -ERESTARTSYS)
536                         ret = -EINTR;
537                 req_set_fail(req);
538         }
539         if (ret >= 0)
540                 ret += sr->done_io;
541         else if (sr->done_io)
542                 ret = sr->done_io;
543         io_req_set_res(req, ret, 0);
544         return IOU_OK;
545 }
546
547 static int io_recvmsg_mshot_prep(struct io_kiocb *req,
548                                  struct io_async_msghdr *iomsg,
549                                  int namelen, size_t controllen)
550 {
551         if ((req->flags & (REQ_F_APOLL_MULTISHOT|REQ_F_BUFFER_SELECT)) ==
552                           (REQ_F_APOLL_MULTISHOT|REQ_F_BUFFER_SELECT)) {
553                 int hdr;
554
555                 if (unlikely(namelen < 0))
556                         return -EOVERFLOW;
557                 if (check_add_overflow(sizeof(struct io_uring_recvmsg_out),
558                                         namelen, &hdr))
559                         return -EOVERFLOW;
560                 if (check_add_overflow(hdr, controllen, &hdr))
561                         return -EOVERFLOW;
562
563                 iomsg->namelen = namelen;
564                 iomsg->controllen = controllen;
565                 return 0;
566         }
567
568         return 0;
569 }
570
571 static int io_recvmsg_copy_hdr(struct io_kiocb *req,
572                                struct io_async_msghdr *iomsg)
573 {
574         struct user_msghdr msg;
575         int ret;
576
577         iomsg->msg.msg_name = &iomsg->addr;
578         iomsg->msg.msg_iter.nr_segs = 0;
579
580 #ifdef CONFIG_COMPAT
581         if (unlikely(req->ctx->compat)) {
582                 struct compat_msghdr cmsg;
583
584                 ret = io_compat_msg_copy_hdr(req, iomsg, &cmsg, ITER_DEST);
585                 if (unlikely(ret))
586                         return ret;
587
588                 ret = __get_compat_msghdr(&iomsg->msg, &cmsg, &iomsg->uaddr);
589                 if (unlikely(ret))
590                         return ret;
591
592                 return io_recvmsg_mshot_prep(req, iomsg, cmsg.msg_namelen,
593                                                 cmsg.msg_controllen);
594         }
595 #endif
596
597         ret = io_msg_copy_hdr(req, iomsg, &msg, ITER_DEST);
598         if (unlikely(ret))
599                 return ret;
600
601         ret = __copy_msghdr(&iomsg->msg, &msg, &iomsg->uaddr);
602         if (unlikely(ret))
603                 return ret;
604
605         return io_recvmsg_mshot_prep(req, iomsg, msg.msg_namelen,
606                                         msg.msg_controllen);
607 }
608
609 int io_recvmsg_prep_async(struct io_kiocb *req)
610 {
611         struct io_async_msghdr *iomsg;
612         int ret;
613
614         if (!io_msg_alloc_async_prep(req))
615                 return -ENOMEM;
616         iomsg = req->async_data;
617         ret = io_recvmsg_copy_hdr(req, iomsg);
618         if (!ret)
619                 req->flags |= REQ_F_NEED_CLEANUP;
620         return ret;
621 }
622
623 #define RECVMSG_FLAGS (IORING_RECVSEND_POLL_FIRST | IORING_RECV_MULTISHOT)
624
625 int io_recvmsg_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
626 {
627         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
628
629         sr->done_io = 0;
630
631         if (unlikely(sqe->file_index || sqe->addr2))
632                 return -EINVAL;
633
634         sr->umsg = u64_to_user_ptr(READ_ONCE(sqe->addr));
635         sr->len = READ_ONCE(sqe->len);
636         sr->flags = READ_ONCE(sqe->ioprio);
637         if (sr->flags & ~(RECVMSG_FLAGS))
638                 return -EINVAL;
639         sr->msg_flags = READ_ONCE(sqe->msg_flags);
640         if (sr->msg_flags & MSG_DONTWAIT)
641                 req->flags |= REQ_F_NOWAIT;
642         if (sr->msg_flags & MSG_ERRQUEUE)
643                 req->flags |= REQ_F_CLEAR_POLLIN;
644         if (sr->flags & IORING_RECV_MULTISHOT) {
645                 if (!(req->flags & REQ_F_BUFFER_SELECT))
646                         return -EINVAL;
647                 if (sr->msg_flags & MSG_WAITALL)
648                         return -EINVAL;
649                 if (req->opcode == IORING_OP_RECV && sr->len)
650                         return -EINVAL;
651                 req->flags |= REQ_F_APOLL_MULTISHOT;
652                 /*
653                  * Store the buffer group for this multishot receive separately,
654                  * as if we end up doing an io-wq based issue that selects a
655                  * buffer, it has to be committed immediately and that will
656                  * clear ->buf_list. This means we lose the link to the buffer
657                  * list, and the eventual buffer put on completion then cannot
658                  * restore it.
659                  */
660                 sr->buf_group = req->buf_index;
661         }
662
663 #ifdef CONFIG_COMPAT
664         if (req->ctx->compat)
665                 sr->msg_flags |= MSG_CMSG_COMPAT;
666 #endif
667         sr->nr_multishot_loops = 0;
668         return 0;
669 }
670
671 static inline void io_recv_prep_retry(struct io_kiocb *req)
672 {
673         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
674
675         req->flags &= ~REQ_F_BL_EMPTY;
676         sr->done_io = 0;
677         sr->len = 0; /* get from the provided buffer */
678         req->buf_index = sr->buf_group;
679 }
680
681 /*
682  * Finishes io_recv and io_recvmsg.
683  *
684  * Returns true if it is actually finished, or false if it should run
685  * again (for multishot).
686  */
687 static inline bool io_recv_finish(struct io_kiocb *req, int *ret,
688                                   struct msghdr *msg, bool mshot_finished,
689                                   unsigned issue_flags)
690 {
691         unsigned int cflags;
692
693         cflags = io_put_kbuf(req, issue_flags);
694         if (msg->msg_inq > 0)
695                 cflags |= IORING_CQE_F_SOCK_NONEMPTY;
696
697         /*
698          * Fill CQE for this receive and see if we should keep trying to
699          * receive from this socket.
700          */
701         if ((req->flags & REQ_F_APOLL_MULTISHOT) && !mshot_finished &&
702             io_fill_cqe_req_aux(req, issue_flags & IO_URING_F_COMPLETE_DEFER,
703                                 *ret, cflags | IORING_CQE_F_MORE)) {
704                 struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
705                 int mshot_retry_ret = IOU_ISSUE_SKIP_COMPLETE;
706
707                 io_recv_prep_retry(req);
708                 /* Known not-empty or unknown state, retry */
709                 if (cflags & IORING_CQE_F_SOCK_NONEMPTY || msg->msg_inq < 0) {
710                         if (sr->nr_multishot_loops++ < MULTISHOT_MAX_RETRY)
711                                 return false;
712                         /* mshot retries exceeded, force a requeue */
713                         sr->nr_multishot_loops = 0;
714                         mshot_retry_ret = IOU_REQUEUE;
715                 }
716                 if (issue_flags & IO_URING_F_MULTISHOT)
717                         *ret = mshot_retry_ret;
718                 else
719                         *ret = -EAGAIN;
720                 return true;
721         }
722
723         /* Finish the request / stop multishot. */
724         io_req_set_res(req, *ret, cflags);
725
726         if (issue_flags & IO_URING_F_MULTISHOT)
727                 *ret = IOU_STOP_MULTISHOT;
728         else
729                 *ret = IOU_OK;
730         return true;
731 }
732
733 static int io_recvmsg_prep_multishot(struct io_async_msghdr *kmsg,
734                                      struct io_sr_msg *sr, void __user **buf,
735                                      size_t *len)
736 {
737         unsigned long ubuf = (unsigned long) *buf;
738         unsigned long hdr;
739
740         hdr = sizeof(struct io_uring_recvmsg_out) + kmsg->namelen +
741                 kmsg->controllen;
742         if (*len < hdr)
743                 return -EFAULT;
744
745         if (kmsg->controllen) {
746                 unsigned long control = ubuf + hdr - kmsg->controllen;
747
748                 kmsg->msg.msg_control_user = (void __user *) control;
749                 kmsg->msg.msg_controllen = kmsg->controllen;
750         }
751
752         sr->buf = *buf; /* stash for later copy */
753         *buf = (void __user *) (ubuf + hdr);
754         kmsg->payloadlen = *len = *len - hdr;
755         return 0;
756 }
757
758 struct io_recvmsg_multishot_hdr {
759         struct io_uring_recvmsg_out msg;
760         struct sockaddr_storage addr;
761 };
762
763 static int io_recvmsg_multishot(struct socket *sock, struct io_sr_msg *io,
764                                 struct io_async_msghdr *kmsg,
765                                 unsigned int flags, bool *finished)
766 {
767         int err;
768         int copy_len;
769         struct io_recvmsg_multishot_hdr hdr;
770
771         if (kmsg->namelen)
772                 kmsg->msg.msg_name = &hdr.addr;
773         kmsg->msg.msg_flags = flags & (MSG_CMSG_CLOEXEC|MSG_CMSG_COMPAT);
774         kmsg->msg.msg_namelen = 0;
775
776         if (sock->file->f_flags & O_NONBLOCK)
777                 flags |= MSG_DONTWAIT;
778
779         err = sock_recvmsg(sock, &kmsg->msg, flags);
780         *finished = err <= 0;
781         if (err < 0)
782                 return err;
783
784         hdr.msg = (struct io_uring_recvmsg_out) {
785                 .controllen = kmsg->controllen - kmsg->msg.msg_controllen,
786                 .flags = kmsg->msg.msg_flags & ~MSG_CMSG_COMPAT
787         };
788
789         hdr.msg.payloadlen = err;
790         if (err > kmsg->payloadlen)
791                 err = kmsg->payloadlen;
792
793         copy_len = sizeof(struct io_uring_recvmsg_out);
794         if (kmsg->msg.msg_namelen > kmsg->namelen)
795                 copy_len += kmsg->namelen;
796         else
797                 copy_len += kmsg->msg.msg_namelen;
798
799         /*
800          *      "fromlen shall refer to the value before truncation.."
801          *                      1003.1g
802          */
803         hdr.msg.namelen = kmsg->msg.msg_namelen;
804
805         /* ensure that there is no gap between hdr and sockaddr_storage */
806         BUILD_BUG_ON(offsetof(struct io_recvmsg_multishot_hdr, addr) !=
807                      sizeof(struct io_uring_recvmsg_out));
808         if (copy_to_user(io->buf, &hdr, copy_len)) {
809                 *finished = true;
810                 return -EFAULT;
811         }
812
813         return sizeof(struct io_uring_recvmsg_out) + kmsg->namelen +
814                         kmsg->controllen + err;
815 }
816
817 int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
818 {
819         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
820         struct io_async_msghdr iomsg, *kmsg;
821         struct socket *sock;
822         unsigned flags;
823         int ret, min_ret = 0;
824         bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
825         bool mshot_finished = true;
826
827         sock = sock_from_file(req->file);
828         if (unlikely(!sock))
829                 return -ENOTSOCK;
830
831         if (req_has_async_data(req)) {
832                 kmsg = req->async_data;
833         } else {
834                 ret = io_recvmsg_copy_hdr(req, &iomsg);
835                 if (ret)
836                         return ret;
837                 kmsg = &iomsg;
838         }
839
840         if (!(req->flags & REQ_F_POLLED) &&
841             (sr->flags & IORING_RECVSEND_POLL_FIRST))
842                 return io_setup_async_msg(req, kmsg, issue_flags);
843
844         flags = sr->msg_flags;
845         if (force_nonblock)
846                 flags |= MSG_DONTWAIT;
847
848 retry_multishot:
849         if (io_do_buffer_select(req)) {
850                 void __user *buf;
851                 size_t len = sr->len;
852
853                 buf = io_buffer_select(req, &len, issue_flags);
854                 if (!buf)
855                         return -ENOBUFS;
856
857                 if (req->flags & REQ_F_APOLL_MULTISHOT) {
858                         ret = io_recvmsg_prep_multishot(kmsg, sr, &buf, &len);
859                         if (ret) {
860                                 io_kbuf_recycle(req, issue_flags);
861                                 return ret;
862                         }
863                 }
864
865                 iov_iter_ubuf(&kmsg->msg.msg_iter, ITER_DEST, buf, len);
866         }
867
868         kmsg->msg.msg_get_inq = 1;
869         kmsg->msg.msg_inq = -1;
870         if (req->flags & REQ_F_APOLL_MULTISHOT) {
871                 ret = io_recvmsg_multishot(sock, sr, kmsg, flags,
872                                            &mshot_finished);
873         } else {
874                 /* disable partial retry for recvmsg with cmsg attached */
875                 if (flags & MSG_WAITALL && !kmsg->msg.msg_controllen)
876                         min_ret = iov_iter_count(&kmsg->msg.msg_iter);
877
878                 ret = __sys_recvmsg_sock(sock, &kmsg->msg, sr->umsg,
879                                          kmsg->uaddr, flags);
880         }
881
882         if (ret < min_ret) {
883                 if (ret == -EAGAIN && force_nonblock) {
884                         ret = io_setup_async_msg(req, kmsg, issue_flags);
885                         if (ret == -EAGAIN && (issue_flags & IO_URING_F_MULTISHOT)) {
886                                 io_kbuf_recycle(req, issue_flags);
887                                 return IOU_ISSUE_SKIP_COMPLETE;
888                         }
889                         return ret;
890                 }
891                 if (ret > 0 && io_net_retry(sock, flags)) {
892                         sr->done_io += ret;
893                         req->flags |= REQ_F_BL_NO_RECYCLE;
894                         return io_setup_async_msg(req, kmsg, issue_flags);
895                 }
896                 if (ret == -ERESTARTSYS)
897                         ret = -EINTR;
898                 req_set_fail(req);
899         } else if ((flags & MSG_WAITALL) && (kmsg->msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))) {
900                 req_set_fail(req);
901         }
902
903         if (ret > 0)
904                 ret += sr->done_io;
905         else if (sr->done_io)
906                 ret = sr->done_io;
907         else
908                 io_kbuf_recycle(req, issue_flags);
909
910         if (!io_recv_finish(req, &ret, &kmsg->msg, mshot_finished, issue_flags))
911                 goto retry_multishot;
912
913         if (mshot_finished)
914                 io_req_msg_cleanup(req, kmsg, issue_flags);
915         else if (ret == -EAGAIN)
916                 return io_setup_async_msg(req, kmsg, issue_flags);
917
918         return ret;
919 }
920
921 int io_recv(struct io_kiocb *req, unsigned int issue_flags)
922 {
923         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
924         struct msghdr msg;
925         struct socket *sock;
926         unsigned flags;
927         int ret, min_ret = 0;
928         bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
929         size_t len = sr->len;
930
931         if (!(req->flags & REQ_F_POLLED) &&
932             (sr->flags & IORING_RECVSEND_POLL_FIRST))
933                 return -EAGAIN;
934
935         sock = sock_from_file(req->file);
936         if (unlikely(!sock))
937                 return -ENOTSOCK;
938
939         msg.msg_name = NULL;
940         msg.msg_namelen = 0;
941         msg.msg_control = NULL;
942         msg.msg_get_inq = 1;
943         msg.msg_controllen = 0;
944         msg.msg_iocb = NULL;
945         msg.msg_ubuf = NULL;
946
947         flags = sr->msg_flags;
948         if (force_nonblock)
949                 flags |= MSG_DONTWAIT;
950
951 retry_multishot:
952         if (io_do_buffer_select(req)) {
953                 void __user *buf;
954
955                 buf = io_buffer_select(req, &len, issue_flags);
956                 if (!buf)
957                         return -ENOBUFS;
958                 sr->buf = buf;
959                 sr->len = len;
960         }
961
962         ret = import_ubuf(ITER_DEST, sr->buf, len, &msg.msg_iter);
963         if (unlikely(ret))
964                 goto out_free;
965
966         msg.msg_inq = -1;
967         msg.msg_flags = 0;
968
969         if (flags & MSG_WAITALL)
970                 min_ret = iov_iter_count(&msg.msg_iter);
971
972         ret = sock_recvmsg(sock, &msg, flags);
973         if (ret < min_ret) {
974                 if (ret == -EAGAIN && force_nonblock) {
975                         if (issue_flags & IO_URING_F_MULTISHOT) {
976                                 io_kbuf_recycle(req, issue_flags);
977                                 return IOU_ISSUE_SKIP_COMPLETE;
978                         }
979
980                         return -EAGAIN;
981                 }
982                 if (ret > 0 && io_net_retry(sock, flags)) {
983                         sr->len -= ret;
984                         sr->buf += ret;
985                         sr->done_io += ret;
986                         req->flags |= REQ_F_BL_NO_RECYCLE;
987                         return -EAGAIN;
988                 }
989                 if (ret == -ERESTARTSYS)
990                         ret = -EINTR;
991                 req_set_fail(req);
992         } else if ((flags & MSG_WAITALL) && (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))) {
993 out_free:
994                 req_set_fail(req);
995         }
996
997         if (ret > 0)
998                 ret += sr->done_io;
999         else if (sr->done_io)
1000                 ret = sr->done_io;
1001         else
1002                 io_kbuf_recycle(req, issue_flags);
1003
1004         if (!io_recv_finish(req, &ret, &msg, ret <= 0, issue_flags))
1005                 goto retry_multishot;
1006
1007         return ret;
1008 }
1009
1010 void io_send_zc_cleanup(struct io_kiocb *req)
1011 {
1012         struct io_sr_msg *zc = io_kiocb_to_cmd(req, struct io_sr_msg);
1013         struct io_async_msghdr *io;
1014
1015         if (req_has_async_data(req)) {
1016                 io = req->async_data;
1017                 /* might be ->fast_iov if *msg_copy_hdr failed */
1018                 if (io->free_iov != io->fast_iov)
1019                         kfree(io->free_iov);
1020         }
1021         if (zc->notif) {
1022                 io_notif_flush(zc->notif);
1023                 zc->notif = NULL;
1024         }
1025 }
1026
1027 #define IO_ZC_FLAGS_COMMON (IORING_RECVSEND_POLL_FIRST | IORING_RECVSEND_FIXED_BUF)
1028 #define IO_ZC_FLAGS_VALID  (IO_ZC_FLAGS_COMMON | IORING_SEND_ZC_REPORT_USAGE)
1029
1030 int io_send_zc_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
1031 {
1032         struct io_sr_msg *zc = io_kiocb_to_cmd(req, struct io_sr_msg);
1033         struct io_ring_ctx *ctx = req->ctx;
1034         struct io_kiocb *notif;
1035
1036         zc->done_io = 0;
1037
1038         if (unlikely(READ_ONCE(sqe->__pad2[0]) || READ_ONCE(sqe->addr3)))
1039                 return -EINVAL;
1040         /* we don't support IOSQE_CQE_SKIP_SUCCESS just yet */
1041         if (req->flags & REQ_F_CQE_SKIP)
1042                 return -EINVAL;
1043
1044         notif = zc->notif = io_alloc_notif(ctx);
1045         if (!notif)
1046                 return -ENOMEM;
1047         notif->cqe.user_data = req->cqe.user_data;
1048         notif->cqe.res = 0;
1049         notif->cqe.flags = IORING_CQE_F_NOTIF;
1050         req->flags |= REQ_F_NEED_CLEANUP;
1051
1052         zc->flags = READ_ONCE(sqe->ioprio);
1053         if (unlikely(zc->flags & ~IO_ZC_FLAGS_COMMON)) {
1054                 if (zc->flags & ~IO_ZC_FLAGS_VALID)
1055                         return -EINVAL;
1056                 if (zc->flags & IORING_SEND_ZC_REPORT_USAGE) {
1057                         io_notif_set_extended(notif);
1058                         io_notif_to_data(notif)->zc_report = true;
1059                 }
1060         }
1061
1062         if (zc->flags & IORING_RECVSEND_FIXED_BUF) {
1063                 unsigned idx = READ_ONCE(sqe->buf_index);
1064
1065                 if (unlikely(idx >= ctx->nr_user_bufs))
1066                         return -EFAULT;
1067                 idx = array_index_nospec(idx, ctx->nr_user_bufs);
1068                 req->imu = READ_ONCE(ctx->user_bufs[idx]);
1069                 io_req_set_rsrc_node(notif, ctx, 0);
1070         }
1071
1072         if (req->opcode == IORING_OP_SEND_ZC) {
1073                 if (READ_ONCE(sqe->__pad3[0]))
1074                         return -EINVAL;
1075                 zc->addr = u64_to_user_ptr(READ_ONCE(sqe->addr2));
1076                 zc->addr_len = READ_ONCE(sqe->addr_len);
1077         } else {
1078                 if (unlikely(sqe->addr2 || sqe->file_index))
1079                         return -EINVAL;
1080                 if (unlikely(zc->flags & IORING_RECVSEND_FIXED_BUF))
1081                         return -EINVAL;
1082         }
1083
1084         zc->buf = u64_to_user_ptr(READ_ONCE(sqe->addr));
1085         zc->len = READ_ONCE(sqe->len);
1086         zc->msg_flags = READ_ONCE(sqe->msg_flags) | MSG_NOSIGNAL;
1087         if (zc->msg_flags & MSG_DONTWAIT)
1088                 req->flags |= REQ_F_NOWAIT;
1089
1090 #ifdef CONFIG_COMPAT
1091         if (req->ctx->compat)
1092                 zc->msg_flags |= MSG_CMSG_COMPAT;
1093 #endif
1094         return 0;
1095 }
1096
1097 static int io_sg_from_iter_iovec(struct sock *sk, struct sk_buff *skb,
1098                                  struct iov_iter *from, size_t length)
1099 {
1100         skb_zcopy_downgrade_managed(skb);
1101         return __zerocopy_sg_from_iter(NULL, sk, skb, from, length);
1102 }
1103
1104 static int io_sg_from_iter(struct sock *sk, struct sk_buff *skb,
1105                            struct iov_iter *from, size_t length)
1106 {
1107         struct skb_shared_info *shinfo = skb_shinfo(skb);
1108         int frag = shinfo->nr_frags;
1109         int ret = 0;
1110         struct bvec_iter bi;
1111         ssize_t copied = 0;
1112         unsigned long truesize = 0;
1113
1114         if (!frag)
1115                 shinfo->flags |= SKBFL_MANAGED_FRAG_REFS;
1116         else if (unlikely(!skb_zcopy_managed(skb)))
1117                 return __zerocopy_sg_from_iter(NULL, sk, skb, from, length);
1118
1119         bi.bi_size = min(from->count, length);
1120         bi.bi_bvec_done = from->iov_offset;
1121         bi.bi_idx = 0;
1122
1123         while (bi.bi_size && frag < MAX_SKB_FRAGS) {
1124                 struct bio_vec v = mp_bvec_iter_bvec(from->bvec, bi);
1125
1126                 copied += v.bv_len;
1127                 truesize += PAGE_ALIGN(v.bv_len + v.bv_offset);
1128                 __skb_fill_page_desc_noacc(shinfo, frag++, v.bv_page,
1129                                            v.bv_offset, v.bv_len);
1130                 bvec_iter_advance_single(from->bvec, &bi, v.bv_len);
1131         }
1132         if (bi.bi_size)
1133                 ret = -EMSGSIZE;
1134
1135         shinfo->nr_frags = frag;
1136         from->bvec += bi.bi_idx;
1137         from->nr_segs -= bi.bi_idx;
1138         from->count -= copied;
1139         from->iov_offset = bi.bi_bvec_done;
1140
1141         skb->data_len += copied;
1142         skb->len += copied;
1143         skb->truesize += truesize;
1144
1145         if (sk && sk->sk_type == SOCK_STREAM) {
1146                 sk_wmem_queued_add(sk, truesize);
1147                 if (!skb_zcopy_pure(skb))
1148                         sk_mem_charge(sk, truesize);
1149         } else {
1150                 refcount_add(truesize, &skb->sk->sk_wmem_alloc);
1151         }
1152         return ret;
1153 }
1154
1155 int io_send_zc(struct io_kiocb *req, unsigned int issue_flags)
1156 {
1157         struct sockaddr_storage __address;
1158         struct io_sr_msg *zc = io_kiocb_to_cmd(req, struct io_sr_msg);
1159         struct msghdr msg;
1160         struct socket *sock;
1161         unsigned msg_flags;
1162         int ret, min_ret = 0;
1163
1164         sock = sock_from_file(req->file);
1165         if (unlikely(!sock))
1166                 return -ENOTSOCK;
1167         if (!test_bit(SOCK_SUPPORT_ZC, &sock->flags))
1168                 return -EOPNOTSUPP;
1169
1170         msg.msg_name = NULL;
1171         msg.msg_control = NULL;
1172         msg.msg_controllen = 0;
1173         msg.msg_namelen = 0;
1174
1175         if (zc->addr) {
1176                 if (req_has_async_data(req)) {
1177                         struct io_async_msghdr *io = req->async_data;
1178
1179                         msg.msg_name = &io->addr;
1180                 } else {
1181                         ret = move_addr_to_kernel(zc->addr, zc->addr_len, &__address);
1182                         if (unlikely(ret < 0))
1183                                 return ret;
1184                         msg.msg_name = (struct sockaddr *)&__address;
1185                 }
1186                 msg.msg_namelen = zc->addr_len;
1187         }
1188
1189         if (!(req->flags & REQ_F_POLLED) &&
1190             (zc->flags & IORING_RECVSEND_POLL_FIRST))
1191                 return io_setup_async_addr(req, &__address, issue_flags);
1192
1193         if (zc->flags & IORING_RECVSEND_FIXED_BUF) {
1194                 ret = io_import_fixed(ITER_SOURCE, &msg.msg_iter, req->imu,
1195                                         (u64)(uintptr_t)zc->buf, zc->len);
1196                 if (unlikely(ret))
1197                         return ret;
1198                 msg.sg_from_iter = io_sg_from_iter;
1199         } else {
1200                 io_notif_set_extended(zc->notif);
1201                 ret = import_ubuf(ITER_SOURCE, zc->buf, zc->len, &msg.msg_iter);
1202                 if (unlikely(ret))
1203                         return ret;
1204                 ret = io_notif_account_mem(zc->notif, zc->len);
1205                 if (unlikely(ret))
1206                         return ret;
1207                 msg.sg_from_iter = io_sg_from_iter_iovec;
1208         }
1209
1210         msg_flags = zc->msg_flags | MSG_ZEROCOPY;
1211         if (issue_flags & IO_URING_F_NONBLOCK)
1212                 msg_flags |= MSG_DONTWAIT;
1213         if (msg_flags & MSG_WAITALL)
1214                 min_ret = iov_iter_count(&msg.msg_iter);
1215         msg_flags &= ~MSG_INTERNAL_SENDMSG_FLAGS;
1216
1217         msg.msg_flags = msg_flags;
1218         msg.msg_ubuf = &io_notif_to_data(zc->notif)->uarg;
1219         ret = sock_sendmsg(sock, &msg);
1220
1221         if (unlikely(ret < min_ret)) {
1222                 if (ret == -EAGAIN && (issue_flags & IO_URING_F_NONBLOCK))
1223                         return io_setup_async_addr(req, &__address, issue_flags);
1224
1225                 if (ret > 0 && io_net_retry(sock, msg.msg_flags)) {
1226                         zc->len -= ret;
1227                         zc->buf += ret;
1228                         zc->done_io += ret;
1229                         req->flags |= REQ_F_BL_NO_RECYCLE;
1230                         return io_setup_async_addr(req, &__address, issue_flags);
1231                 }
1232                 if (ret == -ERESTARTSYS)
1233                         ret = -EINTR;
1234                 req_set_fail(req);
1235         }
1236
1237         if (ret >= 0)
1238                 ret += zc->done_io;
1239         else if (zc->done_io)
1240                 ret = zc->done_io;
1241
1242         /*
1243          * If we're in io-wq we can't rely on tw ordering guarantees, defer
1244          * flushing notif to io_send_zc_cleanup()
1245          */
1246         if (!(issue_flags & IO_URING_F_UNLOCKED)) {
1247                 io_notif_flush(zc->notif);
1248                 req->flags &= ~REQ_F_NEED_CLEANUP;
1249         }
1250         io_req_set_res(req, ret, IORING_CQE_F_MORE);
1251         return IOU_OK;
1252 }
1253
1254 int io_sendmsg_zc(struct io_kiocb *req, unsigned int issue_flags)
1255 {
1256         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
1257         struct io_async_msghdr iomsg, *kmsg;
1258         struct socket *sock;
1259         unsigned flags;
1260         int ret, min_ret = 0;
1261
1262         io_notif_set_extended(sr->notif);
1263
1264         sock = sock_from_file(req->file);
1265         if (unlikely(!sock))
1266                 return -ENOTSOCK;
1267         if (!test_bit(SOCK_SUPPORT_ZC, &sock->flags))
1268                 return -EOPNOTSUPP;
1269
1270         if (req_has_async_data(req)) {
1271                 kmsg = req->async_data;
1272         } else {
1273                 ret = io_sendmsg_copy_hdr(req, &iomsg);
1274                 if (ret)
1275                         return ret;
1276                 kmsg = &iomsg;
1277         }
1278
1279         if (!(req->flags & REQ_F_POLLED) &&
1280             (sr->flags & IORING_RECVSEND_POLL_FIRST))
1281                 return io_setup_async_msg(req, kmsg, issue_flags);
1282
1283         flags = sr->msg_flags | MSG_ZEROCOPY;
1284         if (issue_flags & IO_URING_F_NONBLOCK)
1285                 flags |= MSG_DONTWAIT;
1286         if (flags & MSG_WAITALL)
1287                 min_ret = iov_iter_count(&kmsg->msg.msg_iter);
1288
1289         kmsg->msg.msg_ubuf = &io_notif_to_data(sr->notif)->uarg;
1290         kmsg->msg.sg_from_iter = io_sg_from_iter_iovec;
1291         ret = __sys_sendmsg_sock(sock, &kmsg->msg, flags);
1292
1293         if (unlikely(ret < min_ret)) {
1294                 if (ret == -EAGAIN && (issue_flags & IO_URING_F_NONBLOCK))
1295                         return io_setup_async_msg(req, kmsg, issue_flags);
1296
1297                 if (ret > 0 && io_net_retry(sock, flags)) {
1298                         sr->done_io += ret;
1299                         req->flags |= REQ_F_BL_NO_RECYCLE;
1300                         return io_setup_async_msg(req, kmsg, issue_flags);
1301                 }
1302                 if (ret == -ERESTARTSYS)
1303                         ret = -EINTR;
1304                 req_set_fail(req);
1305         }
1306         /* fast path, check for non-NULL to avoid function call */
1307         if (kmsg->free_iov) {
1308                 kfree(kmsg->free_iov);
1309                 kmsg->free_iov = NULL;
1310         }
1311
1312         io_netmsg_recycle(req, issue_flags);
1313         if (ret >= 0)
1314                 ret += sr->done_io;
1315         else if (sr->done_io)
1316                 ret = sr->done_io;
1317
1318         /*
1319          * If we're in io-wq we can't rely on tw ordering guarantees, defer
1320          * flushing notif to io_send_zc_cleanup()
1321          */
1322         if (!(issue_flags & IO_URING_F_UNLOCKED)) {
1323                 io_notif_flush(sr->notif);
1324                 req->flags &= ~REQ_F_NEED_CLEANUP;
1325         }
1326         io_req_set_res(req, ret, IORING_CQE_F_MORE);
1327         return IOU_OK;
1328 }
1329
1330 void io_sendrecv_fail(struct io_kiocb *req)
1331 {
1332         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
1333
1334         if (sr->done_io)
1335                 req->cqe.res = sr->done_io;
1336
1337         if ((req->flags & REQ_F_NEED_CLEANUP) &&
1338             (req->opcode == IORING_OP_SEND_ZC || req->opcode == IORING_OP_SENDMSG_ZC))
1339                 req->cqe.flags |= IORING_CQE_F_MORE;
1340 }
1341
1342 int io_accept_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
1343 {
1344         struct io_accept *accept = io_kiocb_to_cmd(req, struct io_accept);
1345         unsigned flags;
1346
1347         if (sqe->len || sqe->buf_index)
1348                 return -EINVAL;
1349
1350         accept->addr = u64_to_user_ptr(READ_ONCE(sqe->addr));
1351         accept->addr_len = u64_to_user_ptr(READ_ONCE(sqe->addr2));
1352         accept->flags = READ_ONCE(sqe->accept_flags);
1353         accept->nofile = rlimit(RLIMIT_NOFILE);
1354         flags = READ_ONCE(sqe->ioprio);
1355         if (flags & ~IORING_ACCEPT_MULTISHOT)
1356                 return -EINVAL;
1357
1358         accept->file_slot = READ_ONCE(sqe->file_index);
1359         if (accept->file_slot) {
1360                 if (accept->flags & SOCK_CLOEXEC)
1361                         return -EINVAL;
1362                 if (flags & IORING_ACCEPT_MULTISHOT &&
1363                     accept->file_slot != IORING_FILE_INDEX_ALLOC)
1364                         return -EINVAL;
1365         }
1366         if (accept->flags & ~(SOCK_CLOEXEC | SOCK_NONBLOCK))
1367                 return -EINVAL;
1368         if (SOCK_NONBLOCK != O_NONBLOCK && (accept->flags & SOCK_NONBLOCK))
1369                 accept->flags = (accept->flags & ~SOCK_NONBLOCK) | O_NONBLOCK;
1370         if (flags & IORING_ACCEPT_MULTISHOT)
1371                 req->flags |= REQ_F_APOLL_MULTISHOT;
1372         return 0;
1373 }
1374
1375 int io_accept(struct io_kiocb *req, unsigned int issue_flags)
1376 {
1377         struct io_accept *accept = io_kiocb_to_cmd(req, struct io_accept);
1378         bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
1379         unsigned int file_flags = force_nonblock ? O_NONBLOCK : 0;
1380         bool fixed = !!accept->file_slot;
1381         struct file *file;
1382         int ret, fd;
1383
1384 retry:
1385         if (!fixed) {
1386                 fd = __get_unused_fd_flags(accept->flags, accept->nofile);
1387                 if (unlikely(fd < 0))
1388                         return fd;
1389         }
1390         file = do_accept(req->file, file_flags, accept->addr, accept->addr_len,
1391                          accept->flags);
1392         if (IS_ERR(file)) {
1393                 if (!fixed)
1394                         put_unused_fd(fd);
1395                 ret = PTR_ERR(file);
1396                 if (ret == -EAGAIN && force_nonblock) {
1397                         /*
1398                          * if it's multishot and polled, we don't need to
1399                          * return EAGAIN to arm the poll infra since it
1400                          * has already been done
1401                          */
1402                         if (issue_flags & IO_URING_F_MULTISHOT)
1403                                 return IOU_ISSUE_SKIP_COMPLETE;
1404                         return ret;
1405                 }
1406                 if (ret == -ERESTARTSYS)
1407                         ret = -EINTR;
1408                 req_set_fail(req);
1409         } else if (!fixed) {
1410                 fd_install(fd, file);
1411                 ret = fd;
1412         } else {
1413                 ret = io_fixed_fd_install(req, issue_flags, file,
1414                                                 accept->file_slot);
1415         }
1416
1417         if (!(req->flags & REQ_F_APOLL_MULTISHOT)) {
1418                 io_req_set_res(req, ret, 0);
1419                 return IOU_OK;
1420         }
1421
1422         if (ret < 0)
1423                 return ret;
1424         if (io_fill_cqe_req_aux(req, issue_flags & IO_URING_F_COMPLETE_DEFER,
1425                                 ret, IORING_CQE_F_MORE))
1426                 goto retry;
1427
1428         io_req_set_res(req, ret, 0);
1429         return IOU_STOP_MULTISHOT;
1430 }
1431
1432 int io_socket_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
1433 {
1434         struct io_socket *sock = io_kiocb_to_cmd(req, struct io_socket);
1435
1436         if (sqe->addr || sqe->rw_flags || sqe->buf_index)
1437                 return -EINVAL;
1438
1439         sock->domain = READ_ONCE(sqe->fd);
1440         sock->type = READ_ONCE(sqe->off);
1441         sock->protocol = READ_ONCE(sqe->len);
1442         sock->file_slot = READ_ONCE(sqe->file_index);
1443         sock->nofile = rlimit(RLIMIT_NOFILE);
1444
1445         sock->flags = sock->type & ~SOCK_TYPE_MASK;
1446         if (sock->file_slot && (sock->flags & SOCK_CLOEXEC))
1447                 return -EINVAL;
1448         if (sock->flags & ~(SOCK_CLOEXEC | SOCK_NONBLOCK))
1449                 return -EINVAL;
1450         return 0;
1451 }
1452
1453 int io_socket(struct io_kiocb *req, unsigned int issue_flags)
1454 {
1455         struct io_socket *sock = io_kiocb_to_cmd(req, struct io_socket);
1456         bool fixed = !!sock->file_slot;
1457         struct file *file;
1458         int ret, fd;
1459
1460         if (!fixed) {
1461                 fd = __get_unused_fd_flags(sock->flags, sock->nofile);
1462                 if (unlikely(fd < 0))
1463                         return fd;
1464         }
1465         file = __sys_socket_file(sock->domain, sock->type, sock->protocol);
1466         if (IS_ERR(file)) {
1467                 if (!fixed)
1468                         put_unused_fd(fd);
1469                 ret = PTR_ERR(file);
1470                 if (ret == -EAGAIN && (issue_flags & IO_URING_F_NONBLOCK))
1471                         return -EAGAIN;
1472                 if (ret == -ERESTARTSYS)
1473                         ret = -EINTR;
1474                 req_set_fail(req);
1475         } else if (!fixed) {
1476                 fd_install(fd, file);
1477                 ret = fd;
1478         } else {
1479                 ret = io_fixed_fd_install(req, issue_flags, file,
1480                                             sock->file_slot);
1481         }
1482         io_req_set_res(req, ret, 0);
1483         return IOU_OK;
1484 }
1485
1486 int io_connect_prep_async(struct io_kiocb *req)
1487 {
1488         struct io_async_connect *io = req->async_data;
1489         struct io_connect *conn = io_kiocb_to_cmd(req, struct io_connect);
1490
1491         return move_addr_to_kernel(conn->addr, conn->addr_len, &io->address);
1492 }
1493
1494 int io_connect_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
1495 {
1496         struct io_connect *conn = io_kiocb_to_cmd(req, struct io_connect);
1497
1498         if (sqe->len || sqe->buf_index || sqe->rw_flags || sqe->splice_fd_in)
1499                 return -EINVAL;
1500
1501         conn->addr = u64_to_user_ptr(READ_ONCE(sqe->addr));
1502         conn->addr_len =  READ_ONCE(sqe->addr2);
1503         conn->in_progress = conn->seen_econnaborted = false;
1504         return 0;
1505 }
1506
1507 int io_connect(struct io_kiocb *req, unsigned int issue_flags)
1508 {
1509         struct io_connect *connect = io_kiocb_to_cmd(req, struct io_connect);
1510         struct io_async_connect __io, *io;
1511         unsigned file_flags;
1512         int ret;
1513         bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
1514
1515         if (req_has_async_data(req)) {
1516                 io = req->async_data;
1517         } else {
1518                 ret = move_addr_to_kernel(connect->addr,
1519                                                 connect->addr_len,
1520                                                 &__io.address);
1521                 if (ret)
1522                         goto out;
1523                 io = &__io;
1524         }
1525
1526         file_flags = force_nonblock ? O_NONBLOCK : 0;
1527
1528         ret = __sys_connect_file(req->file, &io->address,
1529                                         connect->addr_len, file_flags);
1530         if ((ret == -EAGAIN || ret == -EINPROGRESS || ret == -ECONNABORTED)
1531             && force_nonblock) {
1532                 if (ret == -EINPROGRESS) {
1533                         connect->in_progress = true;
1534                 } else if (ret == -ECONNABORTED) {
1535                         if (connect->seen_econnaborted)
1536                                 goto out;
1537                         connect->seen_econnaborted = true;
1538                 }
1539                 if (req_has_async_data(req))
1540                         return -EAGAIN;
1541                 if (io_alloc_async_data(req)) {
1542                         ret = -ENOMEM;
1543                         goto out;
1544                 }
1545                 memcpy(req->async_data, &__io, sizeof(__io));
1546                 return -EAGAIN;
1547         }
1548         if (connect->in_progress) {
1549                 /*
1550                  * At least bluetooth will return -EBADFD on a re-connect
1551                  * attempt, and it's (supposedly) also valid to get -EISCONN
1552                  * which means the previous result is good. For both of these,
1553                  * grab the sock_error() and use that for the completion.
1554                  */
1555                 if (ret == -EBADFD || ret == -EISCONN)
1556                         ret = sock_error(sock_from_file(req->file)->sk);
1557         }
1558         if (ret == -ERESTARTSYS)
1559                 ret = -EINTR;
1560 out:
1561         if (ret < 0)
1562                 req_set_fail(req);
1563         io_req_set_res(req, ret, 0);
1564         return IOU_OK;
1565 }
1566
1567 void io_netmsg_cache_free(struct io_cache_entry *entry)
1568 {
1569         kfree(container_of(entry, struct io_async_msghdr, cache));
1570 }
1571 #endif