TODO writev_send not always TEVENT_FD_READ
[metze/samba/wip.git] / lib / async_req / async_sock.c
1 /*
2    Unix SMB/CIFS implementation.
3    async socket syscalls
4    Copyright (C) Volker Lendecke 2008
5
6      ** NOTE! The following LGPL license applies to the async_sock
7      ** library. This does NOT imply that all of Samba is released
8      ** under the LGPL
9
10    This library is free software; you can redistribute it and/or
11    modify it under the terms of the GNU Lesser General Public
12    License as published by the Free Software Foundation; either
13    version 3 of the License, or (at your option) any later version.
14
15    This library is distributed in the hope that it will be useful,
16    but WITHOUT ANY WARRANTY; without even the implied warranty of
17    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18    Library General Public License for more details.
19
20    You should have received a copy of the GNU Lesser General Public License
21    along with this program.  If not, see <http://www.gnu.org/licenses/>.
22 */
23
24 #include "replace.h"
25 #include "system/network.h"
26 #include "system/filesys.h"
27 #include <talloc.h>
28 #include <tevent.h>
29 #include "lib/async_req/async_sock.h"
30
31 /* Note: lib/util/ is currently GPL */
32 #include "lib/util/tevent_unix.h"
33 #include "lib/util/samba_util.h"
34
35 #ifndef TALLOC_FREE
36 #define TALLOC_FREE(ctx) do { talloc_free(ctx); ctx=NULL; } while(0)
37 #endif
38
39 struct sendto_state {
40         int fd;
41         const void *buf;
42         size_t len;
43         int flags;
44         const struct sockaddr_storage *addr;
45         socklen_t addr_len;
46         ssize_t sent;
47 };
48
49 static void sendto_handler(struct tevent_context *ev,
50                                struct tevent_fd *fde,
51                                uint16_t flags, void *private_data);
52
53 struct tevent_req *sendto_send(TALLOC_CTX *mem_ctx, struct tevent_context *ev,
54                                int fd, const void *buf, size_t len, int flags,
55                                const struct sockaddr_storage *addr)
56 {
57         struct tevent_req *result;
58         struct sendto_state *state;
59         struct tevent_fd *fde;
60
61         result = tevent_req_create(mem_ctx, &state, struct sendto_state);
62         if (result == NULL) {
63                 return result;
64         }
65         state->fd = fd;
66         state->buf = buf;
67         state->len = len;
68         state->flags = flags;
69         state->addr = addr;
70
71         switch (addr->ss_family) {
72         case AF_INET:
73                 state->addr_len = sizeof(struct sockaddr_in);
74                 break;
75 #if defined(HAVE_IPV6)
76         case AF_INET6:
77                 state->addr_len = sizeof(struct sockaddr_in6);
78                 break;
79 #endif
80         case AF_UNIX:
81                 state->addr_len = sizeof(struct sockaddr_un);
82                 break;
83         default:
84                 state->addr_len = sizeof(struct sockaddr_storage);
85                 break;
86         }
87
88         fde = tevent_add_fd(ev, state, fd, TEVENT_FD_WRITE, sendto_handler,
89                             result);
90         if (fde == NULL) {
91                 TALLOC_FREE(result);
92                 return NULL;
93         }
94         return result;
95 }
96
97 static void sendto_handler(struct tevent_context *ev,
98                                struct tevent_fd *fde,
99                                uint16_t flags, void *private_data)
100 {
101         struct tevent_req *req = talloc_get_type_abort(
102                 private_data, struct tevent_req);
103         struct sendto_state *state =
104                 tevent_req_data(req, struct sendto_state);
105
106         state->sent = sendto(state->fd, state->buf, state->len, state->flags,
107                              (const struct sockaddr *)state->addr,
108                              state->addr_len);
109         if ((state->sent == -1) && (errno == EINTR)) {
110                 /* retry */
111                 return;
112         }
113         if (state->sent == -1) {
114                 tevent_req_error(req, errno);
115                 return;
116         }
117         tevent_req_done(req);
118 }
119
120 ssize_t sendto_recv(struct tevent_req *req, int *perrno)
121 {
122         struct sendto_state *state =
123                 tevent_req_data(req, struct sendto_state);
124
125         if (tevent_req_is_unix_error(req, perrno)) {
126                 return -1;
127         }
128         return state->sent;
129 }
130
131 struct recvfrom_state {
132         int fd;
133         void *buf;
134         size_t len;
135         int flags;
136         struct sockaddr_storage *addr;
137         socklen_t *addr_len;
138         ssize_t received;
139 };
140
141 static void recvfrom_handler(struct tevent_context *ev,
142                                struct tevent_fd *fde,
143                                uint16_t flags, void *private_data);
144
145 struct tevent_req *recvfrom_send(TALLOC_CTX *mem_ctx,
146                                  struct tevent_context *ev,
147                                  int fd, void *buf, size_t len, int flags,
148                                  struct sockaddr_storage *addr,
149                                  socklen_t *addr_len)
150 {
151         struct tevent_req *result;
152         struct recvfrom_state *state;
153         struct tevent_fd *fde;
154
155         result = tevent_req_create(mem_ctx, &state, struct recvfrom_state);
156         if (result == NULL) {
157                 return result;
158         }
159         state->fd = fd;
160         state->buf = buf;
161         state->len = len;
162         state->flags = flags;
163         state->addr = addr;
164         state->addr_len = addr_len;
165
166         fde = tevent_add_fd(ev, state, fd, TEVENT_FD_READ, recvfrom_handler,
167                             result);
168         if (fde == NULL) {
169                 TALLOC_FREE(result);
170                 return NULL;
171         }
172         return result;
173 }
174
175 static void recvfrom_handler(struct tevent_context *ev,
176                                struct tevent_fd *fde,
177                                uint16_t flags, void *private_data)
178 {
179         struct tevent_req *req = talloc_get_type_abort(
180                 private_data, struct tevent_req);
181         struct recvfrom_state *state =
182                 tevent_req_data(req, struct recvfrom_state);
183
184         state->received = recvfrom(state->fd, state->buf, state->len,
185                                    state->flags, (struct sockaddr *)state->addr,
186                                    state->addr_len);
187         if ((state->received == -1) && (errno == EINTR)) {
188                 /* retry */
189                 return;
190         }
191         if (state->received == 0) {
192                 tevent_req_error(req, EPIPE);
193                 return;
194         }
195         if (state->received == -1) {
196                 tevent_req_error(req, errno);
197                 return;
198         }
199         tevent_req_done(req);
200 }
201
202 ssize_t recvfrom_recv(struct tevent_req *req, int *perrno)
203 {
204         struct recvfrom_state *state =
205                 tevent_req_data(req, struct recvfrom_state);
206
207         if (tevent_req_is_unix_error(req, perrno)) {
208                 return -1;
209         }
210         return state->received;
211 }
212
213 struct async_connect_state {
214         int fd;
215         int result;
216         int sys_errno;
217         long old_sockflags;
218         socklen_t address_len;
219         struct sockaddr_storage address;
220
221         void (*before_connect)(void *private_data);
222         void (*after_connect)(void *private_data);
223         void *private_data;
224 };
225
226 static void async_connect_connected(struct tevent_context *ev,
227                                     struct tevent_fd *fde, uint16_t flags,
228                                     void *priv);
229
230 /**
231  * @brief async version of connect(2)
232  * @param[in] mem_ctx   The memory context to hang the result off
233  * @param[in] ev        The event context to work from
234  * @param[in] fd        The socket to recv from
235  * @param[in] address   Where to connect?
236  * @param[in] address_len Length of *address
237  * @retval The async request
238  *
239  * This function sets the socket into non-blocking state to be able to call
240  * connect in an async state. This will be reset when the request is finished.
241  */
242
243 struct tevent_req *async_connect_send(
244         TALLOC_CTX *mem_ctx, struct tevent_context *ev, int fd,
245         const struct sockaddr *address, socklen_t address_len,
246         void (*before_connect)(void *private_data),
247         void (*after_connect)(void *private_data),
248         void *private_data)
249 {
250         struct tevent_req *result;
251         struct async_connect_state *state;
252         struct tevent_fd *fde;
253
254         result = tevent_req_create(
255                 mem_ctx, &state, struct async_connect_state);
256         if (result == NULL) {
257                 return NULL;
258         }
259
260         /**
261          * We have to set the socket to nonblocking for async connect(2). Keep
262          * the old sockflags around.
263          */
264
265         state->fd = fd;
266         state->sys_errno = 0;
267         state->before_connect = before_connect;
268         state->after_connect = after_connect;
269         state->private_data = private_data;
270
271         state->old_sockflags = fcntl(fd, F_GETFL, 0);
272         if (state->old_sockflags == -1) {
273                 goto post_errno;
274         }
275
276         state->address_len = address_len;
277         if (address_len > sizeof(state->address)) {
278                 errno = EINVAL;
279                 goto post_errno;
280         }
281         memcpy(&state->address, address, address_len);
282
283         set_blocking(fd, false);
284
285         if (state->before_connect != NULL) {
286                 state->before_connect(state->private_data);
287         }
288
289         state->result = connect(fd, address, address_len);
290
291         if (state->after_connect != NULL) {
292                 state->after_connect(state->private_data);
293         }
294
295         if (state->result == 0) {
296                 tevent_req_done(result);
297                 goto done;
298         }
299
300         /**
301          * A number of error messages show that something good is progressing
302          * and that we have to wait for readability.
303          *
304          * If none of them are present, bail out.
305          */
306
307         if (!(errno == EINPROGRESS || errno == EALREADY ||
308 #ifdef EISCONN
309               errno == EISCONN ||
310 #endif
311               errno == EAGAIN || errno == EINTR)) {
312                 state->sys_errno = errno;
313                 goto post_errno;
314         }
315
316         fde = tevent_add_fd(ev, state, fd, TEVENT_FD_READ | TEVENT_FD_WRITE,
317                            async_connect_connected, result);
318         if (fde == NULL) {
319                 state->sys_errno = ENOMEM;
320                 goto post_errno;
321         }
322         return result;
323
324  post_errno:
325         tevent_req_error(result, state->sys_errno);
326  done:
327         fcntl(fd, F_SETFL, state->old_sockflags);
328         return tevent_req_post(result, ev);
329 }
330
331 /**
332  * fde event handler for connect(2)
333  * @param[in] ev        The event context that sent us here
334  * @param[in] fde       The file descriptor event associated with the connect
335  * @param[in] flags     Indicate read/writeability of the socket
336  * @param[in] priv      private data, "struct async_req *" in this case
337  */
338
339 static void async_connect_connected(struct tevent_context *ev,
340                                     struct tevent_fd *fde, uint16_t flags,
341                                     void *priv)
342 {
343         struct tevent_req *req = talloc_get_type_abort(
344                 priv, struct tevent_req);
345         struct async_connect_state *state =
346                 tevent_req_data(req, struct async_connect_state);
347         int ret;
348
349         if (state->before_connect != NULL) {
350                 state->before_connect(state->private_data);
351         }
352
353         ret = connect(state->fd, (struct sockaddr *)(void *)&state->address,
354                       state->address_len);
355
356         if (state->after_connect != NULL) {
357                 state->after_connect(state->private_data);
358         }
359
360         if (ret == 0) {
361                 state->sys_errno = 0;
362                 TALLOC_FREE(fde);
363                 tevent_req_done(req);
364                 return;
365         }
366         if (errno == EINPROGRESS) {
367                 /* Try again later, leave the fde around */
368                 return;
369         }
370         state->sys_errno = errno;
371         TALLOC_FREE(fde);
372         tevent_req_error(req, errno);
373         return;
374 }
375
376 int async_connect_recv(struct tevent_req *req, int *perrno)
377 {
378         struct async_connect_state *state =
379                 tevent_req_data(req, struct async_connect_state);
380         int err;
381
382         fcntl(state->fd, F_SETFL, state->old_sockflags);
383
384         if (tevent_req_is_unix_error(req, &err)) {
385                 *perrno = err;
386                 return -1;
387         }
388
389         if (state->sys_errno == 0) {
390                 return 0;
391         }
392
393         *perrno = state->sys_errno;
394         return -1;
395 }
396
397 struct writev_state {
398         struct tevent_context *ev;
399         int fd;
400         struct iovec *iov;
401         int count;
402         size_t total_size;
403         uint16_t flags;
404         bool err_on_readability;
405 };
406
407 static void writev_trigger(struct tevent_req *req, void *private_data);
408 static void writev_handler(struct tevent_context *ev, struct tevent_fd *fde,
409                            uint16_t flags, void *private_data);
410
411 struct tevent_req *writev_send(TALLOC_CTX *mem_ctx, struct tevent_context *ev,
412                                struct tevent_queue *queue, int fd,
413                                bool err_on_readability,
414                                struct iovec *iov, int count)
415 {
416         struct tevent_req *req;
417         struct writev_state *state;
418
419         req = tevent_req_create(mem_ctx, &state, struct writev_state);
420         if (req == NULL) {
421                 return NULL;
422         }
423         state->ev = ev;
424         state->fd = fd;
425         state->total_size = 0;
426         state->count = count;
427         state->iov = (struct iovec *)talloc_memdup(
428                 state, iov, sizeof(struct iovec) * count);
429         if (state->iov == NULL) {
430                 goto fail;
431         }
432         state->flags = TEVENT_FD_WRITE;
433         state->err_on_readability = err_on_readability;
434         if (state->err_on_readability) {
435                 state->flags |= TEVENT_FD_READ;
436         }
437
438         if (queue == NULL) {
439                 struct tevent_fd *fde;
440                 fde = tevent_add_fd(state->ev, state, state->fd,
441                                     state->flags, writev_handler, req);
442                 if (tevent_req_nomem(fde, req)) {
443                         return tevent_req_post(req, ev);
444                 }
445                 return req;
446         }
447
448         if (!tevent_queue_add(queue, ev, req, writev_trigger, NULL)) {
449                 goto fail;
450         }
451         return req;
452  fail:
453         TALLOC_FREE(req);
454         return NULL;
455 }
456
457 static void writev_trigger(struct tevent_req *req, void *private_data)
458 {
459         struct writev_state *state = tevent_req_data(req, struct writev_state);
460         struct tevent_fd *fde;
461
462         fde = tevent_add_fd(state->ev, state, state->fd, state->flags,
463                             writev_handler, req);
464         if (fde == NULL) {
465                 tevent_req_error(req, ENOMEM);
466         }
467 }
468
469 static void writev_handler(struct tevent_context *ev, struct tevent_fd *fde,
470                            uint16_t flags, void *private_data)
471 {
472         struct tevent_req *req = talloc_get_type_abort(
473                 private_data, struct tevent_req);
474         struct writev_state *state =
475                 tevent_req_data(req, struct writev_state);
476         size_t to_write, written;
477         int i;
478
479         to_write = 0;
480
481         if ((state->flags & TEVENT_FD_READ) && (flags & TEVENT_FD_READ)) {
482                 int ret, value;
483
484                 if (state->err_on_readability) {
485                         /* Readable and the caller wants an error on read. */
486                         tevent_req_error(req, EPIPE);
487                         return;
488                 }
489
490                 /* Might be an error. Check if there are bytes to read */
491                 ret = ioctl(state->fd, FIONREAD, &value);
492                 /* FIXME - should we also check
493                    for ret == 0 and value == 0 here ? */
494                 if (ret == -1) {
495                         /* There's an error. */
496                         tevent_req_error(req, EPIPE);
497                         return;
498                 }
499                 /* A request for TEVENT_FD_READ will succeed from now and
500                    forevermore until the bytes are read so if there was
501                    an error we'll wait until we do read, then get it in
502                    the read callback function. Until then, remove TEVENT_FD_READ
503                    from the flags we're waiting for. */
504                 state->flags &= ~TEVENT_FD_READ;
505                 TEVENT_FD_NOT_READABLE(fde);
506
507                 /* If not writable, we're done. */
508                 if (!(flags & TEVENT_FD_WRITE)) {
509                         return;
510                 }
511         }
512
513         for (i=0; i<state->count; i++) {
514                 to_write += state->iov[i].iov_len;
515         }
516
517         written = writev(state->fd, state->iov, state->count);
518         if ((written == -1) && (errno == EINTR)) {
519                 /* retry */
520                 return;
521         }
522         if (written == -1) {
523                 tevent_req_error(req, errno);
524                 return;
525         }
526         if (written == 0) {
527                 tevent_req_error(req, EPIPE);
528                 return;
529         }
530         state->total_size += written;
531
532         if (written == to_write) {
533                 tevent_req_done(req);
534                 return;
535         }
536
537         /*
538          * We've written less than we were asked to, drop stuff from
539          * state->iov.
540          */
541
542         while (written > 0) {
543                 if (written < state->iov[0].iov_len) {
544                         state->iov[0].iov_base =
545                                 (char *)state->iov[0].iov_base + written;
546                         state->iov[0].iov_len -= written;
547                         break;
548                 }
549                 written -= state->iov[0].iov_len;
550                 state->iov += 1;
551                 state->count -= 1;
552         }
553 }
554
555 ssize_t writev_recv(struct tevent_req *req, int *perrno)
556 {
557         struct writev_state *state =
558                 tevent_req_data(req, struct writev_state);
559
560         if (tevent_req_is_unix_error(req, perrno)) {
561                 return -1;
562         }
563         return state->total_size;
564 }
565
566 struct read_packet_state {
567         int fd;
568         uint8_t *buf;
569         size_t nread;
570         ssize_t (*more)(uint8_t *buf, size_t buflen, void *private_data);
571         void *private_data;
572 };
573
574 static void read_packet_handler(struct tevent_context *ev,
575                                 struct tevent_fd *fde,
576                                 uint16_t flags, void *private_data);
577
578 struct tevent_req *read_packet_send(TALLOC_CTX *mem_ctx,
579                                     struct tevent_context *ev,
580                                     int fd, size_t initial,
581                                     ssize_t (*more)(uint8_t *buf,
582                                                     size_t buflen,
583                                                     void *private_data),
584                                     void *private_data)
585 {
586         struct tevent_req *result;
587         struct read_packet_state *state;
588         struct tevent_fd *fde;
589
590         result = tevent_req_create(mem_ctx, &state, struct read_packet_state);
591         if (result == NULL) {
592                 return NULL;
593         }
594         state->fd = fd;
595         state->nread = 0;
596         state->more = more;
597         state->private_data = private_data;
598
599         state->buf = talloc_array(state, uint8_t, initial);
600         if (state->buf == NULL) {
601                 goto fail;
602         }
603
604         fde = tevent_add_fd(ev, state, fd, TEVENT_FD_READ, read_packet_handler,
605                             result);
606         if (fde == NULL) {
607                 goto fail;
608         }
609         return result;
610  fail:
611         TALLOC_FREE(result);
612         return NULL;
613 }
614
615 static void read_packet_handler(struct tevent_context *ev,
616                                 struct tevent_fd *fde,
617                                 uint16_t flags, void *private_data)
618 {
619         struct tevent_req *req = talloc_get_type_abort(
620                 private_data, struct tevent_req);
621         struct read_packet_state *state =
622                 tevent_req_data(req, struct read_packet_state);
623         size_t total = talloc_get_size(state->buf);
624         ssize_t nread, more;
625         uint8_t *tmp;
626
627         nread = recv(state->fd, state->buf+state->nread, total-state->nread,
628                      0);
629         if ((nread == -1) && (errno == ENOTSOCK)) {
630                 nread = read(state->fd, state->buf+state->nread,
631                              total-state->nread);
632         }
633         if ((nread == -1) && (errno == EINTR)) {
634                 /* retry */
635                 return;
636         }
637         if (nread == -1) {
638                 tevent_req_error(req, errno);
639                 return;
640         }
641         if (nread == 0) {
642                 tevent_req_error(req, EPIPE);
643                 return;
644         }
645
646         state->nread += nread;
647         if (state->nread < total) {
648                 /* Come back later */
649                 return;
650         }
651
652         /*
653          * We got what was initially requested. See if "more" asks for -- more.
654          */
655         if (state->more == NULL) {
656                 /* Nobody to ask, this is a async read_data */
657                 tevent_req_done(req);
658                 return;
659         }
660
661         more = state->more(state->buf, total, state->private_data);
662         if (more == -1) {
663                 /* We got an invalid packet, tell the caller */
664                 tevent_req_error(req, EIO);
665                 return;
666         }
667         if (more == 0) {
668                 /* We're done, full packet received */
669                 tevent_req_done(req);
670                 return;
671         }
672
673         if (total + more < total) {
674                 tevent_req_error(req, EMSGSIZE);
675                 return;
676         }
677
678         tmp = talloc_realloc(state, state->buf, uint8_t, total+more);
679         if (tevent_req_nomem(tmp, req)) {
680                 return;
681         }
682         state->buf = tmp;
683 }
684
685 ssize_t read_packet_recv(struct tevent_req *req, TALLOC_CTX *mem_ctx,
686                          uint8_t **pbuf, int *perrno)
687 {
688         struct read_packet_state *state =
689                 tevent_req_data(req, struct read_packet_state);
690
691         if (tevent_req_is_unix_error(req, perrno)) {
692                 return -1;
693         }
694         *pbuf = talloc_move(mem_ctx, &state->buf);
695         return talloc_get_size(*pbuf);
696 }
697
698 struct wait_for_read_state {
699         struct tevent_req *req;
700         struct tevent_fd *fde;
701 };
702
703 static void wait_for_read_done(struct tevent_context *ev,
704                                struct tevent_fd *fde,
705                                uint16_t flags,
706                                void *private_data);
707
708 struct tevent_req *wait_for_read_send(TALLOC_CTX *mem_ctx,
709                                       struct tevent_context *ev,
710                                       int fd)
711 {
712         struct tevent_req *req;
713         struct wait_for_read_state *state;
714
715         req = tevent_req_create(mem_ctx, &state, struct wait_for_read_state);
716         if (req == NULL) {
717                 return NULL;
718         }
719         state->req = req;
720         state->fde = tevent_add_fd(ev, state, fd, TEVENT_FD_READ,
721                                    wait_for_read_done, state);
722         if (tevent_req_nomem(state->fde, req)) {
723                 return tevent_req_post(req, ev);
724         }
725         return req;
726 }
727
728 static void wait_for_read_done(struct tevent_context *ev,
729                                struct tevent_fd *fde,
730                                uint16_t flags,
731                                void *private_data)
732 {
733         struct wait_for_read_state *state = talloc_get_type_abort(
734                 private_data, struct wait_for_read_state);
735
736         if (flags & TEVENT_FD_READ) {
737                 TALLOC_FREE(state->fde);
738                 tevent_req_done(state->req);
739         }
740 }
741
742 bool wait_for_read_recv(struct tevent_req *req, int *perr)
743 {
744         int err;
745
746         if (tevent_req_is_unix_error(req, &err)) {
747                 *perr = err;
748                 return false;
749         }
750         return true;
751 }