7431eec3250bb4678ca40ca1331b8289b05a8c66
[samba.git] / ctdb / common / sock_io.c
1 /*
2    Generic Unix-domain Socket I/O
3
4    Copyright (C) Amitay Isaacs  2016
5
6    This program is free software; you can redistribute it and/or modify
7    it under the terms of the GNU General Public License as published by
8    the Free Software Foundation; either version 3 of the License, or
9    (at your option) any later version.
10
11    This program is distributed in the hope that it will be useful,
12    but WITHOUT ANY WARRANTY; without even the implied warranty of
13    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14    GNU General Public License for more details.
15
16    You should have received a copy of the GNU General Public License
17    along with this program; if not, see <http://www.gnu.org/licenses/>.
18 */
19
20 #include "replace.h"
21 #include "system/filesys.h"
22 #include "system/network.h"
23
24 #include <talloc.h>
25 #include <tevent.h>
26
27 #include "lib/util/sys_rw.h"
28 #include "lib/util/debug.h"
29 #include "lib/util/blocking.h"
30
31 #include "common/logging.h"
32 #include "common/sock_io.h"
33
/*
 * Open a Unix-domain stream socket and connect it to sockpath.
 *
 * Returns the connected file descriptor on success.  Returns -1 if
 * sockpath is NULL, does not fit in sun_path, or if socket()/connect()
 * fails; each failure is logged with D_ERR.
 */
int sock_connect(const char *sockpath)
{
	struct sockaddr_un sa_un;
	size_t copied;
	int sock;

	if (sockpath == NULL) {
		D_ERR("Invalid socket path\n");
		return -1;
	}

	memset(&sa_un, 0, sizeof(sa_un));
	sa_un.sun_family = AF_UNIX;

	/* A result >= the destination size means the path was truncated */
	copied = strlcpy(sa_un.sun_path, sockpath, sizeof(sa_un.sun_path));
	if (copied >= sizeof(sa_un.sun_path)) {
		D_ERR("Socket path too long, len=%zu\n", strlen(sockpath));
		return -1;
	}

	sock = socket(AF_UNIX, SOCK_STREAM, 0);
	if (sock == -1) {
		D_ERR("socket() failed, errno=%d\n", errno);
		return -1;
	}

	if (connect(sock, (struct sockaddr *)&sa_un, sizeof(sa_un)) == -1) {
		D_ERR("connect() failed, errno=%d\n", errno);
		close(sock);
		return -1;
	}

	return sock;
}
68
69 struct sock_queue {
70         struct tevent_context *ev;
71         sock_queue_callback_fn_t callback;
72         void *private_data;
73         int fd;
74
75         struct tevent_immediate *im;
76         struct tevent_queue *queue;
77         struct tevent_fd *fde;
78         uint8_t *buf;
79         size_t buflen, begin, end;
80 };
81
82 static bool sock_queue_set_fd(struct sock_queue *queue, int fd);
83 static int sock_queue_destructor(struct sock_queue *queue);
84 static void sock_queue_handler(struct tevent_context *ev,
85                                struct tevent_fd *fde, uint16_t flags,
86                                void *private_data);
87 static void sock_queue_process(struct sock_queue *queue);
88 static void sock_queue_process_event(struct tevent_context *ev,
89                                      struct tevent_immediate *im,
90                                      void *private_data);
91
92 struct sock_queue *sock_queue_setup(TALLOC_CTX *mem_ctx,
93                                     struct tevent_context *ev,
94                                     int fd,
95                                     sock_queue_callback_fn_t callback,
96                                     void *private_data)
97 {
98         struct sock_queue *queue;
99
100         queue = talloc_zero(mem_ctx, struct sock_queue);
101         if (queue == NULL) {
102                 return NULL;
103         }
104
105         queue->ev = ev;
106         queue->callback = callback;
107         queue->private_data = private_data;
108
109         queue->im = tevent_create_immediate(queue);
110         if (queue->im == NULL) {
111                 talloc_free(queue);
112                 return NULL;
113         }
114
115         queue->queue = tevent_queue_create(queue, "out-queue");
116         if (queue->queue == NULL) {
117                 talloc_free(queue);
118                 return NULL;
119         }
120
121         if (! sock_queue_set_fd(queue, fd)) {
122                 talloc_free(queue);
123                 return NULL;
124         }
125
126         talloc_set_destructor(queue, sock_queue_destructor);
127
128         return queue;
129 }
130
131 static bool sock_queue_set_fd(struct sock_queue *queue, int fd)
132 {
133         TALLOC_FREE(queue->fde);
134         queue->fd = fd;
135
136         if (fd != -1) {
137                 int ret;
138
139                 ret = set_blocking(fd, false);
140                 if (ret != 0) {
141                         return false;
142                 }
143
144                 queue->fde = tevent_add_fd(queue->ev, queue, fd,
145                                            TEVENT_FD_READ,
146                                            sock_queue_handler, queue);
147                 if (queue->fde == NULL) {
148                         return false;
149                 }
150                 tevent_fd_set_auto_close(queue->fde);
151         }
152
153         return true;
154 }
155
/*
 * talloc destructor: detach the fd.  This frees the read watcher (whose
 * auto-close closes the descriptor) and sets fd to -1.  Always returns 0
 * so the free proceeds.
 */
static int sock_queue_destructor(struct sock_queue *queue)
{
	/* With fd == -1 this only frees fde and resets fd; it cannot fail */
	(void)sock_queue_set_fd(queue, -1);

	return 0;
}
163
164 static void sock_queue_handler(struct tevent_context *ev,
165                                struct tevent_fd *fde, uint16_t flags,
166                                void *private_data)
167 {
168         struct sock_queue *queue = talloc_get_type_abort(
169                 private_data, struct sock_queue);
170         int ret, num_ready;
171         ssize_t nread;
172
173         ret = ioctl(queue->fd, FIONREAD, &num_ready);
174         if (ret != 0) {
175                 /* Ignore */
176                 return;
177         }
178
179         if (num_ready == 0) {
180                 /* descriptor has been closed */
181                 goto fail;
182         }
183
184         if (num_ready > queue->buflen - queue->end) {
185                 queue->buf = talloc_realloc_size(queue, queue->buf,
186                                                  queue->end + num_ready);
187                 if (queue->buf == NULL) {
188                         goto fail;
189                 }
190                 queue->buflen = queue->end + num_ready;
191         }
192
193         nread = sys_read(queue->fd, queue->buf + queue->end, num_ready);
194         if (nread < 0) {
195                 goto fail;
196         }
197         queue->end += nread;
198
199         sock_queue_process(queue);
200         return;
201
202 fail:
203         queue->callback(NULL, 0, queue->private_data);
204 }
205
206 static void sock_queue_process(struct sock_queue *queue)
207 {
208         uint32_t pkt_size;
209
210         if ((queue->end - queue->begin) < sizeof(uint32_t)) {
211                 /* not enough data */
212                 return;
213         }
214
215         pkt_size = *(uint32_t *)(queue->buf + queue->begin);
216         if (pkt_size == 0) {
217                 D_ERR("Invalid packet of length 0\n");
218                 queue->callback(NULL, 0, queue->private_data);
219         }
220
221         if ((queue->end - queue->begin) < pkt_size) {
222                 /* not enough data */
223                 return;
224         }
225
226         queue->callback(queue->buf + queue->begin, pkt_size,
227                         queue->private_data);
228         queue->begin += pkt_size;
229
230         if (queue->begin < queue->end) {
231                 /* more data to be processed */
232                 tevent_schedule_immediate(queue->im, queue->ev,
233                                           sock_queue_process_event, queue);
234         } else {
235                 TALLOC_FREE(queue->buf);
236                 queue->buflen = 0;
237                 queue->begin = 0;
238                 queue->end = 0;
239         }
240 }
241
242 static void sock_queue_process_event(struct tevent_context *ev,
243                                      struct tevent_immediate *im,
244                                      void *private_data)
245 {
246         struct sock_queue *queue = talloc_get_type_abort(
247                 private_data, struct sock_queue);
248
249         sock_queue_process(queue);
250 }
251
/* State of one queued write request */
struct sock_queue_write_state {
	uint8_t *pkt;		/* caller's buffer; not copied */
	uint32_t pkt_size;	/* number of bytes to write from pkt */
};

static void sock_queue_trigger(struct tevent_req *req,
			       void *private_data);
258
259 int sock_queue_write(struct sock_queue *queue, uint8_t *buf, size_t buflen)
260 {
261         struct tevent_req *req;
262         struct sock_queue_write_state *state;
263         bool status;
264
265         if (buflen >= INT32_MAX) {
266                 return -1;
267         }
268
269         req = tevent_req_create(queue, &state, struct sock_queue_write_state);
270         if (req == NULL) {
271                 return -1;
272         }
273
274         state->pkt = buf;
275         state->pkt_size = (uint32_t)buflen;
276
277         status = tevent_queue_add_entry(queue->queue, queue->ev, req,
278                                         sock_queue_trigger, queue);
279         if (! status) {
280                 talloc_free(req);
281                 return -1;
282         }
283
284         return 0;
285 }
286
287 static void sock_queue_trigger(struct tevent_req *req, void *private_data)
288 {
289         struct sock_queue *queue = talloc_get_type_abort(
290                 private_data, struct sock_queue);
291         struct sock_queue_write_state *state = tevent_req_data(
292                 req, struct sock_queue_write_state);
293         size_t offset = 0;
294
295         do {
296                 ssize_t nwritten;
297
298                 nwritten = sys_write(queue->fd, state->pkt + offset,
299                                      state->pkt_size - offset);
300                 if (nwritten < 0) {
301                         queue->callback(NULL, 0, queue->private_data);
302                         return;
303                 }
304                 offset += nwritten;
305
306         } while (offset < state->pkt_size);
307
308         tevent_req_done(req);
309         talloc_free(req);
310 }