2c472641cc1406f4e0d5efd4d3afb1dca09f1495
[samba.git] / source4 / lib / stream / packet.c
1 /* 
2    Unix SMB/CIFS mplementation.
3
4    helper layer for breaking up streams into discrete requests
5    
6    Copyright (C) Andrew Tridgell  2005
7     
8    This program is free software; you can redistribute it and/or modify
9    it under the terms of the GNU General Public License as published by
10    the Free Software Foundation; either version 2 of the License, or
11    (at your option) any later version.
12    
13    This program is distributed in the hope that it will be useful,
14    but WITHOUT ANY WARRANTY; without even the implied warranty of
15    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16    GNU General Public License for more details.
17    
18    You should have received a copy of the GNU General Public License
19    along with this program; if not, write to the Free Software
20    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
21    
22 */
23
24 #include "includes.h"
25 #include "lib/util/dlinklist.h"
26 #include "lib/events/events.h"
27 #include "lib/socket/socket.h"
28 #include "lib/stream/packet.h"
29 #include "libcli/raw/smb.h"
30
31 struct packet_context {
32         packet_callback_fn_t callback;
33         packet_full_request_fn_t full_request;
34         packet_error_handler_fn_t error_handler;
35         DATA_BLOB partial;
36         uint32_t num_read;
37         uint32_t initial_read;
38         struct socket_context *sock;
39         struct event_context *ev;
40         size_t packet_size;
41         void *private;
42         struct fd_event *fde;
43         BOOL serialise;
44         int processing;
45         BOOL recv_disable;
46         BOOL nofree;
47
48         BOOL busy;
49         BOOL destructor_called;
50
51         struct send_element {
52                 struct send_element *next, *prev;
53                 DATA_BLOB blob;
54                 size_t nsent;
55                 packet_send_callback_fn_t send_callback;
56                 void *send_callback_private;
57         } *send_queue;
58 };
59
60 /*
61   a destructor used when we are processing packets to prevent freeing of this
62   context while it is being used
63 */
64 static int packet_destructor(struct packet_context *pc)
65 {
66         if (pc->busy) {
67                 pc->destructor_called = True;
68                 /* now we refuse the talloc_free() request. The free will
69                    happen again in the packet_recv() code */
70                 return -1;
71         }
72
73         return 0;
74 }
75
76
77 /*
78   initialise a packet receiver
79 */
80 _PUBLIC_ struct packet_context *packet_init(TALLOC_CTX *mem_ctx)
81 {
82         struct packet_context *pc = talloc_zero(mem_ctx, struct packet_context);
83         if (pc != NULL) {
84                 talloc_set_destructor(pc, packet_destructor);
85         }
86         return pc;
87 }
88
89
90 /*
91   set the request callback, called when a full request is ready
92 */
93 _PUBLIC_ void packet_set_callback(struct packet_context *pc, packet_callback_fn_t callback)
94 {
95         pc->callback = callback;
96 }
97
98 /*
99   set the error handler
100 */
101 _PUBLIC_ void packet_set_error_handler(struct packet_context *pc, packet_error_handler_fn_t handler)
102 {
103         pc->error_handler = handler;
104 }
105
106 /*
107   set the private pointer passed to the callback functions
108 */
109 _PUBLIC_ void packet_set_private(struct packet_context *pc, void *private)
110 {
111         pc->private = private;
112 }
113
114 /*
115   set the full request callback. Should return as follows:
116      NT_STATUS_OK == blob is a full request.
117      STATUS_MORE_ENTRIES == blob is not complete yet
118      any error == blob is not a valid 
119 */
120 _PUBLIC_ void packet_set_full_request(struct packet_context *pc, packet_full_request_fn_t callback)
121 {
122         pc->full_request = callback;
123 }
124
125 /*
126   set a socket context to use. You must set a socket_context
127 */
128 _PUBLIC_ void packet_set_socket(struct packet_context *pc, struct socket_context *sock)
129 {
130         pc->sock = sock;
131 }
132
133 /*
134   set an event context. If this is set then the code will ensure that
135   packets arrive with separate events, by creating a immediate event
136   for any secondary packets when more than one packet is read at one
137   time on a socket. This can matter for code that relies on not
138   getting more than one packet per event
139 */
140 _PUBLIC_ void packet_set_event_context(struct packet_context *pc, struct event_context *ev)
141 {
142         pc->ev = ev;
143 }
144
145 /*
146   tell the packet layer the fde for the socket
147 */
148 _PUBLIC_ void packet_set_fde(struct packet_context *pc, struct fd_event *fde)
149 {
150         pc->fde = fde;
151 }
152
153 /*
154   tell the packet layer to serialise requests, so we don't process two
155   requests at once on one connection. You must have set the
156   event_context and fde
157 */
158 _PUBLIC_ void packet_set_serialise(struct packet_context *pc)
159 {
160         pc->serialise = True;
161 }
162
163 /*
164   tell the packet layer how much to read when starting a new packet
165   this ensures it doesn't overread
166 */
167 _PUBLIC_ void packet_set_initial_read(struct packet_context *pc, uint32_t initial_read)
168 {
169         pc->initial_read = initial_read;
170 }
171
172 /*
173   tell the packet system not to steal/free blobs given to packet_send()
174 */
175 _PUBLIC_ void packet_set_nofree(struct packet_context *pc)
176 {
177         pc->nofree = True;
178 }
179
180
181 /*
182   tell the caller we have an error
183 */
184 static void packet_error(struct packet_context *pc, NTSTATUS status)
185 {
186         pc->sock = NULL;
187         if (pc->error_handler) {
188                 pc->error_handler(pc->private, status);
189                 return;
190         }
191         /* default error handler is to free the callers private pointer */
192         if (!NT_STATUS_EQUAL(status, NT_STATUS_END_OF_FILE)) {
193                 DEBUG(0,("packet_error on %s - %s\n", 
194                          talloc_get_name(pc->private), nt_errstr(status)));
195         }
196         talloc_free(pc->private);
197         return;
198 }
199
200
201 /*
202   tell the caller we have EOF
203 */
204 static void packet_eof(struct packet_context *pc)
205 {
206         packet_error(pc, NT_STATUS_END_OF_FILE);
207 }
208
209
210 /*
211   used to put packets on event boundaries
212 */
213 static void packet_next_event(struct event_context *ev, struct timed_event *te, 
214                               struct timeval t, void *private)
215 {
216         struct packet_context *pc = talloc_get_type(private, struct packet_context);
217         if (pc->num_read != 0 && pc->packet_size != 0 &&
218             pc->packet_size <= pc->num_read) {
219                 packet_recv(pc);
220         }
221 }
222
223
224 /*
225   call this when the socket becomes readable to kick off the whole
226   stream parsing process
227 */
228 _PUBLIC_ void packet_recv(struct packet_context *pc)
229 {
230         size_t npending;
231         NTSTATUS status;
232         size_t nread = 0;
233         DATA_BLOB blob;
234
235         if (pc->processing) {
236                 EVENT_FD_NOT_READABLE(pc->fde);
237                 pc->processing++;
238                 return;
239         }
240
241         if (pc->recv_disable) {
242                 EVENT_FD_NOT_READABLE(pc->fde);
243                 return;
244         }
245
246         if (pc->packet_size != 0 && pc->num_read >= pc->packet_size) {
247                 goto next_partial;
248         }
249
250         if (pc->packet_size != 0) {
251                 /* we've already worked out how long this next packet is, so skip the
252                    socket_pending() call */
253                 npending = pc->packet_size - pc->num_read;
254         } else if (pc->initial_read != 0) {
255                 npending = pc->initial_read - pc->num_read;
256         } else {
257                 if (pc->sock) {
258                         status = socket_pending(pc->sock, &npending);
259                 } else {
260                         status = NT_STATUS_CONNECTION_DISCONNECTED;
261                 }
262                 if (!NT_STATUS_IS_OK(status)) {
263                         packet_error(pc, status);
264                         return;
265                 }
266         }
267
268         if (npending == 0) {
269                 packet_eof(pc);
270                 return;
271         }
272
273         if (npending + pc->num_read < npending) {
274                 packet_error(pc, NT_STATUS_INVALID_PARAMETER);
275                 return;
276         }
277
278         if (npending + pc->num_read < pc->num_read) {
279                 packet_error(pc, NT_STATUS_INVALID_PARAMETER);
280                 return;
281         }
282
283         /* possibly expand the partial packet buffer */
284         if (npending + pc->num_read > pc->partial.length) {
285                 status = data_blob_realloc(pc, &pc->partial, npending+pc->num_read);
286                 if (!NT_STATUS_IS_OK(status)) {
287                         packet_error(pc, status);
288                         return;
289                 }
290         }
291
292         if (pc->partial.length < pc->num_read + npending) {
293                 packet_error(pc, NT_STATUS_INVALID_PARAMETER);
294                 return;
295         }
296
297         if ((uint8_t *)pc->partial.data + pc->num_read < (uint8_t *)pc->partial.data) {
298                 packet_error(pc, NT_STATUS_INVALID_PARAMETER);
299                 return;
300         }
301         if ((uint8_t *)pc->partial.data + pc->num_read + npending < (uint8_t *)pc->partial.data) {
302                 packet_error(pc, NT_STATUS_INVALID_PARAMETER);
303                 return;
304         }
305
306         status = socket_recv(pc->sock, pc->partial.data + pc->num_read, 
307                              npending, &nread);
308
309         if (NT_STATUS_IS_ERR(status)) {
310                 packet_error(pc, status);
311                 return;
312         }
313         if (!NT_STATUS_IS_OK(status)) {
314                 return;
315         }
316
317         if (nread == 0) {
318                 packet_eof(pc);
319                 return;
320         }
321
322         pc->num_read += nread;
323
324 next_partial:
325         if (pc->partial.length != pc->num_read) {
326                 status = data_blob_realloc(pc, &pc->partial, pc->num_read);
327                 if (!NT_STATUS_IS_OK(status)) {
328                         packet_error(pc, status);
329                         return;
330                 }
331         }
332
333         /* see if its a full request */
334         blob = pc->partial;
335         blob.length = pc->num_read;
336         status = pc->full_request(pc->private, blob, &pc->packet_size);
337         if (NT_STATUS_IS_ERR(status)) {
338                 packet_error(pc, status);
339                 return;
340         }
341         if (!NT_STATUS_IS_OK(status)) {
342                 return;
343         }
344
345         if (pc->packet_size > pc->num_read) {
346                 /* the caller made an error */
347                 DEBUG(0,("Invalid packet_size %lu greater than num_read %lu\n",
348                          (long)pc->packet_size, (long)pc->num_read));
349                 packet_error(pc, NT_STATUS_INVALID_PARAMETER);
350                 return;
351         }
352
353         /* it is a full request - give it to the caller */
354         blob = pc->partial;
355         blob.length = pc->num_read;
356
357         if (pc->packet_size < pc->num_read) {
358                 pc->partial = data_blob_talloc(pc, blob.data + pc->packet_size, 
359                                                pc->num_read - pc->packet_size);
360                 if (pc->partial.data == NULL) {
361                         packet_error(pc, NT_STATUS_NO_MEMORY);
362                         return;
363                 }
364                 /* Trunate the blob sent to the caller to only the packet length */
365                 status = data_blob_realloc(pc, &blob, pc->packet_size);
366                 if (!NT_STATUS_IS_OK(status)) {
367                         packet_error(pc, status);
368                         return;
369                 }
370         } else {
371                 pc->partial = data_blob(NULL, 0);
372         }
373         pc->num_read -= pc->packet_size;
374         pc->packet_size = 0;
375         
376         if (pc->serialise) {
377                 pc->processing = 1;
378         }
379
380         pc->busy = True;
381
382         status = pc->callback(pc->private, blob);
383
384         pc->busy = False;
385
386         if (pc->destructor_called) {
387                 talloc_free(pc);
388                 return;
389         }
390
391         if (pc->processing) {
392                 if (pc->processing > 1) {
393                         EVENT_FD_READABLE(pc->fde);
394                 }
395                 pc->processing = 0;
396         }
397
398         if (!NT_STATUS_IS_OK(status)) {
399                 packet_error(pc, status);
400                 return;
401         }
402
403         /* Have we consumed the whole buffer yet? */
404         if (pc->partial.length == 0) {
405                 return;
406         }
407
408         /* we got multiple packets in one tcp read */
409         if (pc->ev == NULL) {
410                 goto next_partial;
411         }
412
413         blob = pc->partial;
414         blob.length = pc->num_read;
415
416         status = pc->full_request(pc->private, blob, &pc->packet_size);
417         if (NT_STATUS_IS_ERR(status)) {
418                 packet_error(pc, status);
419                 return;
420         }
421
422         if (!NT_STATUS_IS_OK(status)) {
423                 return;
424         }
425
426         event_add_timed(pc->ev, pc, timeval_zero(), packet_next_event, pc);
427 }
428
429
430 /*
431   temporarily disable receiving 
432 */
433 _PUBLIC_ void packet_recv_disable(struct packet_context *pc)
434 {
435         EVENT_FD_NOT_READABLE(pc->fde);
436         pc->recv_disable = True;
437 }
438
439 /*
440   re-enable receiving 
441 */
442 _PUBLIC_ void packet_recv_enable(struct packet_context *pc)
443 {
444         EVENT_FD_READABLE(pc->fde);
445         pc->recv_disable = False;
446         if (pc->num_read != 0 && pc->packet_size >= pc->num_read) {
447                 event_add_timed(pc->ev, pc, timeval_zero(), packet_next_event, pc);
448         }
449 }
450
451 /*
452   trigger a run of the send queue
453 */
454 _PUBLIC_ void packet_queue_run(struct packet_context *pc)
455 {
456         while (pc->send_queue) {
457                 struct send_element *el = pc->send_queue;
458                 NTSTATUS status;
459                 size_t nwritten;
460                 DATA_BLOB blob = data_blob_const(el->blob.data + el->nsent,
461                                                  el->blob.length - el->nsent);
462
463                 status = socket_send(pc->sock, &blob, &nwritten);
464
465                 if (NT_STATUS_IS_ERR(status)) {
466                         packet_error(pc, status);
467                         return;
468                 }
469                 if (!NT_STATUS_IS_OK(status)) {
470                         return;
471                 }
472                 el->nsent += nwritten;
473                 if (el->nsent == el->blob.length) {
474                         DLIST_REMOVE(pc->send_queue, el);
475                         if (el->send_callback) {
476                                 el->send_callback(el->send_callback_private);
477                         }
478                         talloc_free(el);
479                 }
480         }
481
482         /* we're out of requests to send, so don't wait for write
483            events any more */
484         EVENT_FD_NOT_WRITEABLE(pc->fde);
485 }
486
487 /*
488   put a packet in the send queue.  When the packet is actually sent,
489   call send_callback.  
490
491   Useful for operations that must occour after sending a message, such
492   as the switch to SASL encryption after as sucessful LDAP bind relpy.
493 */
494 _PUBLIC_ NTSTATUS packet_send_callback(struct packet_context *pc, DATA_BLOB blob,
495                                        packet_send_callback_fn_t send_callback, 
496                                        void *private)
497 {
498         struct send_element *el;
499         el = talloc(pc, struct send_element);
500         NT_STATUS_HAVE_NO_MEMORY(el);
501
502         DLIST_ADD_END(pc->send_queue, el, struct send_element *);
503         el->blob = blob;
504         el->nsent = 0;
505         el->send_callback = send_callback;
506         el->send_callback_private = private;
507
508         /* if we aren't going to free the packet then we must reference it
509            to ensure it doesn't disappear before going out */
510         if (pc->nofree) {
511                 if (!talloc_reference(el, blob.data)) {
512                         return NT_STATUS_NO_MEMORY;
513                 }
514         } else {
515                 talloc_steal(el, blob.data);
516         }
517
518         if (private && !talloc_reference(el, private)) {
519                 return NT_STATUS_NO_MEMORY;
520         }
521
522         EVENT_FD_WRITEABLE(pc->fde);
523
524         return NT_STATUS_OK;
525 }
526
527 /*
528   put a packet in the send queue
529 */
530 _PUBLIC_ NTSTATUS packet_send(struct packet_context *pc, DATA_BLOB blob)
531 {
532         return packet_send_callback(pc, blob, NULL, NULL);
533 }
534
535
536 /*
537   a full request checker for NBT formatted packets (first 3 bytes are length)
538 */
539 _PUBLIC_ NTSTATUS packet_full_request_nbt(void *private, DATA_BLOB blob, size_t *size)
540 {
541         if (blob.length < 4) {
542                 return STATUS_MORE_ENTRIES;
543         }
544         *size = 4 + smb_len(blob.data);
545         if (*size > blob.length) {
546                 return STATUS_MORE_ENTRIES;
547         }
548         return NT_STATUS_OK;
549 }
550
551
552 /*
553   work out if a packet is complete for protocols that use a 32 bit network byte
554   order length
555 */
556 _PUBLIC_ NTSTATUS packet_full_request_u32(void *private, DATA_BLOB blob, size_t *size)
557 {
558         if (blob.length < 4) {
559                 return STATUS_MORE_ENTRIES;
560         }
561         *size = 4 + RIVAL(blob.data, 0);
562         if (*size > blob.length) {
563                 return STATUS_MORE_ENTRIES;
564         }
565         return NT_STATUS_OK;
566 }