s4:lib/stream: make use of smb_len_tcp()
[rusty/samba.git] / source4 / lib / stream / packet.c
1 /* 
2    Unix SMB/CIFS mplementation.
3
4    helper layer for breaking up streams into discrete requests
5    
6    Copyright (C) Andrew Tridgell  2005
7     
8    This program is free software; you can redistribute it and/or modify
9    it under the terms of the GNU General Public License as published by
10    the Free Software Foundation; either version 3 of the License, or
11    (at your option) any later version.
12    
13    This program is distributed in the hope that it will be useful,
14    but WITHOUT ANY WARRANTY; without even the implied warranty of
15    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16    GNU General Public License for more details.
17    
18    You should have received a copy of the GNU General Public License
19    along with this program.  If not, see <http://www.gnu.org/licenses/>.
20    
21 */
22
23 #include "includes.h"
24 #include "../lib/util/dlinklist.h"
25 #include "lib/events/events.h"
26 #include "lib/socket/socket.h"
27 #include "lib/stream/packet.h"
28 #include "libcli/raw/smb.h"
29
30 struct packet_context {
31         packet_callback_fn_t callback;
32         packet_full_request_fn_t full_request;
33         packet_error_handler_fn_t error_handler;
34         DATA_BLOB partial;
35         uint32_t num_read;
36         uint32_t initial_read;
37         struct socket_context *sock;
38         struct tevent_context *ev;
39         size_t packet_size;
40         void *private_data;
41         struct tevent_fd *fde;
42         bool serialise;
43         int processing;
44         bool recv_disable;
45         bool recv_need_enable;
46         bool nofree;
47
48         bool busy;
49         bool destructor_called;
50
51         bool unreliable_select;
52
53         struct send_element {
54                 struct send_element *next, *prev;
55                 DATA_BLOB blob;
56                 size_t nsent;
57                 packet_send_callback_fn_t send_callback;
58                 void *send_callback_private;
59         } *send_queue;
60 };
61
62 /*
63   a destructor used when we are processing packets to prevent freeing of this
64   context while it is being used
65 */
66 static int packet_destructor(struct packet_context *pc)
67 {
68         if (pc->busy) {
69                 pc->destructor_called = true;
70                 /* now we refuse the talloc_free() request. The free will
71                    happen again in the packet_recv() code */
72                 return -1;
73         }
74
75         return 0;
76 }
77
78
79 /*
80   initialise a packet receiver
81 */
82 _PUBLIC_ struct packet_context *packet_init(TALLOC_CTX *mem_ctx)
83 {
84         struct packet_context *pc = talloc_zero(mem_ctx, struct packet_context);
85         if (pc != NULL) {
86                 talloc_set_destructor(pc, packet_destructor);
87         }
88         return pc;
89 }
90
91
92 /*
93   set the request callback, called when a full request is ready
94 */
95 _PUBLIC_ void packet_set_callback(struct packet_context *pc, packet_callback_fn_t callback)
96 {
97         pc->callback = callback;
98 }
99
100 /*
101   set the error handler
102 */
103 _PUBLIC_ void packet_set_error_handler(struct packet_context *pc, packet_error_handler_fn_t handler)
104 {
105         pc->error_handler = handler;
106 }
107
108 /*
109   set the private pointer passed to the callback functions
110 */
111 _PUBLIC_ void packet_set_private(struct packet_context *pc, void *private_data)
112 {
113         pc->private_data = private_data;
114 }
115
116 /*
117   set the full request callback. Should return as follows:
118      NT_STATUS_OK == blob is a full request.
119      STATUS_MORE_ENTRIES == blob is not complete yet
120      any error == blob is not a valid 
121 */
122 _PUBLIC_ void packet_set_full_request(struct packet_context *pc, packet_full_request_fn_t callback)
123 {
124         pc->full_request = callback;
125 }
126
127 /*
128   set a socket context to use. You must set a socket_context
129 */
130 _PUBLIC_ void packet_set_socket(struct packet_context *pc, struct socket_context *sock)
131 {
132         pc->sock = sock;
133 }
134
135 /*
136   set an event context. If this is set then the code will ensure that
137   packets arrive with separate events, by creating a immediate event
138   for any secondary packets when more than one packet is read at one
139   time on a socket. This can matter for code that relies on not
140   getting more than one packet per event
141 */
142 _PUBLIC_ void packet_set_event_context(struct packet_context *pc, struct tevent_context *ev)
143 {
144         pc->ev = ev;
145 }
146
147 /*
148   tell the packet layer the fde for the socket
149 */
150 _PUBLIC_ void packet_set_fde(struct packet_context *pc, struct tevent_fd *fde)
151 {
152         pc->fde = fde;
153 }
154
155 /*
156   tell the packet layer to serialise requests, so we don't process two
157   requests at once on one connection. You must have set the
158   event_context and fde
159 */
160 _PUBLIC_ void packet_set_serialise(struct packet_context *pc)
161 {
162         pc->serialise = true;
163 }
164
165 /*
166   tell the packet layer how much to read when starting a new packet
167   this ensures it doesn't overread
168 */
169 _PUBLIC_ void packet_set_initial_read(struct packet_context *pc, uint32_t initial_read)
170 {
171         pc->initial_read = initial_read;
172 }
173
174 /*
175   tell the packet system not to steal/free blobs given to packet_send()
176 */
177 _PUBLIC_ void packet_set_nofree(struct packet_context *pc)
178 {
179         pc->nofree = true;
180 }
181
182 /*
183   tell the packet system that select/poll/epoll on the underlying
184   socket may not be a reliable way to determine if data is available
185   for receive. This happens with underlying socket systems such as the
186   one implemented on top of GNUTLS, where there may be data in
187   encryption/compression buffers that could be received by
188   socket_recv(), while there is no data waiting at the real socket
189   level as seen by select/poll/epoll. The GNUTLS library is supposed
190   to cope with this by always leaving some data sitting in the socket
191   buffer, but it does not seem to be reliable.
192  */
193 _PUBLIC_ void packet_set_unreliable_select(struct packet_context *pc)
194 {
195         pc->unreliable_select = true;
196 }
197
198 /*
199   tell the caller we have an error
200 */
201 static void packet_error(struct packet_context *pc, NTSTATUS status)
202 {
203         pc->sock = NULL;
204         if (pc->error_handler) {
205                 pc->error_handler(pc->private_data, status);
206                 return;
207         }
208         /* default error handler is to free the callers private pointer */
209         if (!NT_STATUS_EQUAL(status, NT_STATUS_END_OF_FILE)) {
210                 DEBUG(0,("packet_error on %s - %s\n", 
211                          talloc_get_name(pc->private_data), nt_errstr(status)));
212         }
213         talloc_free(pc->private_data);
214         return;
215 }
216
217
218 /*
219   tell the caller we have EOF
220 */
221 static void packet_eof(struct packet_context *pc)
222 {
223         packet_error(pc, NT_STATUS_END_OF_FILE);
224 }
225
226
227 /*
228   used to put packets on event boundaries
229 */
230 static void packet_next_event(struct tevent_context *ev, struct tevent_timer *te, 
231                               struct timeval t, void *private_data)
232 {
233         struct packet_context *pc = talloc_get_type(private_data, struct packet_context);
234         if (pc->num_read != 0 && pc->packet_size != 0 &&
235             pc->packet_size <= pc->num_read) {
236                 packet_recv(pc);
237         }
238 }
239
240
241 /*
242   call this when the socket becomes readable to kick off the whole
243   stream parsing process
244 */
245 _PUBLIC_ void packet_recv(struct packet_context *pc)
246 {
247         size_t npending;
248         NTSTATUS status;
249         size_t nread = 0;
250         DATA_BLOB blob;
251         bool recv_retry = false;
252
253         if (pc->processing) {
254                 TEVENT_FD_NOT_READABLE(pc->fde);
255                 pc->processing++;
256                 return;
257         }
258
259         if (pc->recv_disable) {
260                 pc->recv_need_enable = true;
261                 TEVENT_FD_NOT_READABLE(pc->fde);
262                 return;
263         }
264
265         if (pc->packet_size != 0 && pc->num_read >= pc->packet_size) {
266                 goto next_partial;
267         }
268
269         if (pc->packet_size != 0) {
270                 /* we've already worked out how long this next packet is, so skip the
271                    socket_pending() call */
272                 npending = pc->packet_size - pc->num_read;
273         } else if (pc->initial_read != 0) {
274                 npending = pc->initial_read - pc->num_read;
275         } else {
276                 if (pc->sock) {
277                         status = socket_pending(pc->sock, &npending);
278                 } else {
279                         status = NT_STATUS_CONNECTION_DISCONNECTED;
280                 }
281                 if (!NT_STATUS_IS_OK(status)) {
282                         packet_error(pc, status);
283                         return;
284                 }
285         }
286
287         if (npending == 0) {
288                 packet_eof(pc);
289                 return;
290         }
291
292 again:
293
294         if (npending + pc->num_read < npending) {
295                 packet_error(pc, NT_STATUS_INVALID_PARAMETER);
296                 return;
297         }
298
299         if (npending + pc->num_read < pc->num_read) {
300                 packet_error(pc, NT_STATUS_INVALID_PARAMETER);
301                 return;
302         }
303
304         /* possibly expand the partial packet buffer */
305         if (npending + pc->num_read > pc->partial.length) {
306                 if (!data_blob_realloc(pc, &pc->partial, npending+pc->num_read)) {
307                         packet_error(pc, NT_STATUS_NO_MEMORY);
308                         return;
309                 }
310         }
311
312         if (pc->partial.length < pc->num_read + npending) {
313                 packet_error(pc, NT_STATUS_INVALID_PARAMETER);
314                 return;
315         }
316
317         if ((uint8_t *)pc->partial.data + pc->num_read < (uint8_t *)pc->partial.data) {
318                 packet_error(pc, NT_STATUS_INVALID_PARAMETER);
319                 return;
320         }
321         if ((uint8_t *)pc->partial.data + pc->num_read + npending < (uint8_t *)pc->partial.data) {
322                 packet_error(pc, NT_STATUS_INVALID_PARAMETER);
323                 return;
324         }
325
326         status = socket_recv(pc->sock, pc->partial.data + pc->num_read, 
327                              npending, &nread);
328
329         if (NT_STATUS_IS_ERR(status)) {
330                 packet_error(pc, status);
331                 return;
332         }
333         if (recv_retry && NT_STATUS_EQUAL(status, STATUS_MORE_ENTRIES)) {
334                 nread = 0;
335                 status = NT_STATUS_OK;
336         }
337         if (!NT_STATUS_IS_OK(status)) {
338                 return;
339         }
340
341         if (nread == 0 && !recv_retry) {
342                 packet_eof(pc);
343                 return;
344         }
345
346         pc->num_read += nread;
347
348         if (pc->unreliable_select && nread != 0) {
349                 recv_retry = true;
350                 status = socket_pending(pc->sock, &npending);
351                 if (!NT_STATUS_IS_OK(status)) {
352                         packet_error(pc, status);
353                         return;
354                 }
355                 if (npending != 0) {
356                         goto again;
357                 }
358         }
359
360 next_partial:
361         if (pc->partial.length != pc->num_read) {
362                 if (!data_blob_realloc(pc, &pc->partial, pc->num_read)) {
363                         packet_error(pc, NT_STATUS_NO_MEMORY);
364                         return;
365                 }
366         }
367
368         /* see if its a full request */
369         blob = pc->partial;
370         blob.length = pc->num_read;
371         status = pc->full_request(pc->private_data, blob, &pc->packet_size);
372         if (NT_STATUS_IS_ERR(status)) {
373                 packet_error(pc, status);
374                 return;
375         }
376         if (!NT_STATUS_IS_OK(status)) {
377                 return;
378         }
379
380         if (pc->packet_size > pc->num_read) {
381                 /* the caller made an error */
382                 DEBUG(0,("Invalid packet_size %lu greater than num_read %lu\n",
383                          (long)pc->packet_size, (long)pc->num_read));
384                 packet_error(pc, NT_STATUS_INVALID_PARAMETER);
385                 return;
386         }
387
388         /* it is a full request - give it to the caller */
389         blob = pc->partial;
390         blob.length = pc->num_read;
391
392         if (pc->packet_size < pc->num_read) {
393                 pc->partial = data_blob_talloc(pc, blob.data + pc->packet_size, 
394                                                pc->num_read - pc->packet_size);
395                 if (pc->partial.data == NULL) {
396                         packet_error(pc, NT_STATUS_NO_MEMORY);
397                         return;
398                 }
399                 /* Trunate the blob sent to the caller to only the packet length */
400                 if (!data_blob_realloc(pc, &blob, pc->packet_size)) {
401                         packet_error(pc, NT_STATUS_NO_MEMORY);
402                         return;
403                 }
404         } else {
405                 pc->partial = data_blob(NULL, 0);
406         }
407         pc->num_read -= pc->packet_size;
408         pc->packet_size = 0;
409         
410         if (pc->serialise) {
411                 pc->processing = 1;
412         }
413
414         pc->busy = true;
415
416         status = pc->callback(pc->private_data, blob);
417
418         pc->busy = false;
419
420         if (pc->destructor_called) {
421                 talloc_free(pc);
422                 return;
423         }
424
425         if (pc->processing) {
426                 if (pc->processing > 1) {
427                         TEVENT_FD_READABLE(pc->fde);
428                 }
429                 pc->processing = 0;
430         }
431
432         if (!NT_STATUS_IS_OK(status)) {
433                 packet_error(pc, status);
434                 return;
435         }
436
437         /* Have we consumed the whole buffer yet? */
438         if (pc->partial.length == 0) {
439                 return;
440         }
441
442         /* we got multiple packets in one tcp read */
443         if (pc->ev == NULL) {
444                 goto next_partial;
445         }
446
447         blob = pc->partial;
448         blob.length = pc->num_read;
449
450         status = pc->full_request(pc->private_data, blob, &pc->packet_size);
451         if (NT_STATUS_IS_ERR(status)) {
452                 packet_error(pc, status);
453                 return;
454         }
455
456         if (!NT_STATUS_IS_OK(status)) {
457                 return;
458         }
459
460         tevent_add_timer(pc->ev, pc, timeval_zero(), packet_next_event, pc);
461 }
462
463
464 /*
465   temporarily disable receiving 
466 */
467 _PUBLIC_ void packet_recv_disable(struct packet_context *pc)
468 {
469         pc->recv_disable = true;
470 }
471
472 /*
473   re-enable receiving 
474 */
475 _PUBLIC_ void packet_recv_enable(struct packet_context *pc)
476 {
477         if (pc->recv_need_enable) {
478                 pc->recv_need_enable = false;
479                 TEVENT_FD_READABLE(pc->fde);
480         }
481         pc->recv_disable = false;
482         if (pc->num_read != 0 && pc->packet_size >= pc->num_read) {
483                 tevent_add_timer(pc->ev, pc, timeval_zero(), packet_next_event, pc);
484         }
485 }
486
487 /*
488   trigger a run of the send queue
489 */
490 _PUBLIC_ void packet_queue_run(struct packet_context *pc)
491 {
492         while (pc->send_queue) {
493                 struct send_element *el = pc->send_queue;
494                 NTSTATUS status;
495                 size_t nwritten;
496                 DATA_BLOB blob = data_blob_const(el->blob.data + el->nsent,
497                                                  el->blob.length - el->nsent);
498
499                 status = socket_send(pc->sock, &blob, &nwritten);
500
501                 if (NT_STATUS_IS_ERR(status)) {
502                         packet_error(pc, status);
503                         return;
504                 }
505                 if (!NT_STATUS_IS_OK(status)) {
506                         return;
507                 }
508                 el->nsent += nwritten;
509                 if (el->nsent == el->blob.length) {
510                         DLIST_REMOVE(pc->send_queue, el);
511                         if (el->send_callback) {
512                                 pc->busy = true;
513                                 el->send_callback(el->send_callback_private);
514                                 pc->busy = false;
515                                 if (pc->destructor_called) {
516                                         talloc_free(pc);
517                                         return;
518                                 }
519                         }
520                         talloc_free(el);
521                 }
522         }
523
524         /* we're out of requests to send, so don't wait for write
525            events any more */
526         TEVENT_FD_NOT_WRITEABLE(pc->fde);
527 }
528
529 /*
530   put a packet in the send queue.  When the packet is actually sent,
531   call send_callback.  
532
533   Useful for operations that must occur after sending a message, such
534   as the switch to SASL encryption after as sucessful LDAP bind relpy.
535 */
536 _PUBLIC_ NTSTATUS packet_send_callback(struct packet_context *pc, DATA_BLOB blob,
537                                        packet_send_callback_fn_t send_callback, 
538                                        void *private_data)
539 {
540         struct send_element *el;
541         el = talloc(pc, struct send_element);
542         NT_STATUS_HAVE_NO_MEMORY(el);
543
544         DLIST_ADD_END(pc->send_queue, el, struct send_element *);
545         el->blob = blob;
546         el->nsent = 0;
547         el->send_callback = send_callback;
548         el->send_callback_private = private_data;
549
550         /* if we aren't going to free the packet then we must reference it
551            to ensure it doesn't disappear before going out */
552         if (pc->nofree) {
553                 if (!talloc_reference(el, blob.data)) {
554                         return NT_STATUS_NO_MEMORY;
555                 }
556         } else {
557                 talloc_steal(el, blob.data);
558         }
559
560         if (private_data && !talloc_reference(el, private_data)) {
561                 return NT_STATUS_NO_MEMORY;
562         }
563
564         TEVENT_FD_WRITEABLE(pc->fde);
565
566         return NT_STATUS_OK;
567 }
568
569 /*
570   put a packet in the send queue
571 */
572 _PUBLIC_ NTSTATUS packet_send(struct packet_context *pc, DATA_BLOB blob)
573 {
574         return packet_send_callback(pc, blob, NULL, NULL);
575 }
576
577
578 /*
579   a full request checker for NBT formatted packets (first 3 bytes are length)
580 */
581 _PUBLIC_ NTSTATUS packet_full_request_nbt(void *private_data, DATA_BLOB blob, size_t *size)
582 {
583         if (blob.length < 4) {
584                 return STATUS_MORE_ENTRIES;
585         }
586         /*
587          * Note: that we use smb_len_tcp() instead
588          *       of smb_len_nbt() as this function is not
589          *       used for nbt and the source4 copy
590          *       of smb_len() was smb_len_tcp()
591          */
592         *size = 4 + smb_len_tcp(blob.data);
593         if (*size > blob.length) {
594                 return STATUS_MORE_ENTRIES;
595         }
596         return NT_STATUS_OK;
597 }
598
599
600 /*
601   work out if a packet is complete for protocols that use a 32 bit network byte
602   order length
603 */
604 _PUBLIC_ NTSTATUS packet_full_request_u32(void *private_data, DATA_BLOB blob, size_t *size)
605 {
606         if (blob.length < 4) {
607                 return STATUS_MORE_ENTRIES;
608         }
609         *size = 4 + RIVAL(blob.data, 0);
610         if (*size > blob.length) {
611                 return STATUS_MORE_ENTRIES;
612         }
613         return NT_STATUS_OK;
614 }