r12694: Move some headers to the directory of the subsystem they belong to.
[samba.git] / source4 / lib / stream / packet.c
1 /* 
2    Unix SMB/CIFS mplementation.
3
4    helper layer for breaking up streams into discrete requests
5    
6    Copyright (C) Andrew Tridgell  2005
7     
8    This program is free software; you can redistribute it and/or modify
9    it under the terms of the GNU General Public License as published by
10    the Free Software Foundation; either version 2 of the License, or
11    (at your option) any later version.
12    
13    This program is distributed in the hope that it will be useful,
14    but WITHOUT ANY WARRANTY; without even the implied warranty of
15    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16    GNU General Public License for more details.
17    
18    You should have received a copy of the GNU General Public License
19    along with this program; if not, write to the Free Software
20    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
21    
22 */
23
24 #include "includes.h"
25 #include "smb.h"
26 #include "dlinklist.h"
27 #include "lib/events/events.h"
28 #include "lib/socket/socket.h"
29 #include "lib/tls/tls.h"
30 #include "lib/stream/packet.h"
31
32
33 struct packet_context {
34         packet_callback_fn_t callback;
35         packet_full_request_fn_t full_request;
36         packet_error_handler_fn_t error_handler;
37         DATA_BLOB partial;
38         uint32_t num_read;
39         uint32_t initial_read;
40         struct tls_context *tls;
41         struct socket_context *sock;
42         struct event_context *ev;
43         size_t packet_size;
44         void *private;
45         struct fd_event *fde;
46         BOOL serialise;
47         int processing;
48         BOOL recv_disable;
49         BOOL nofree;
50
51         BOOL busy;
52         BOOL destructor_called;
53
54         struct send_element {
55                 struct send_element *next, *prev;
56                 DATA_BLOB blob;
57                 size_t nsent;
58         } *send_queue;
59 };
60
61 /*
62   a destructor used when we are processing packets to prevent freeing of this
63   context while it is being used
64 */
65 static int packet_destructor(void *p)
66 {
67         struct packet_context *pc = talloc_get_type(p, struct packet_context);
68
69         if (pc->busy) {
70                 pc->destructor_called = True;
71                 /* now we refuse the talloc_free() request. The free will
72                    happen again in the packet_recv() code */
73                 return -1;
74         }
75
76         return 0;
77 }
78
79
80 /*
81   initialise a packet receiver
82 */
83 struct packet_context *packet_init(TALLOC_CTX *mem_ctx)
84 {
85         struct packet_context *pc = talloc_zero(mem_ctx, struct packet_context);
86         if (pc != NULL) {
87                 talloc_set_destructor(pc, packet_destructor);
88         }
89         return pc;
90 }
91
92
93 /*
94   set the request callback, called when a full request is ready
95 */
96 void packet_set_callback(struct packet_context *pc, packet_callback_fn_t callback)
97 {
98         pc->callback = callback;
99 }
100
101 /*
102   set the error handler
103 */
104 void packet_set_error_handler(struct packet_context *pc, packet_error_handler_fn_t handler)
105 {
106         pc->error_handler = handler;
107 }
108
109 /*
110   set the private pointer passed to the callback functions
111 */
112 void packet_set_private(struct packet_context *pc, void *private)
113 {
114         pc->private = private;
115 }
116
117 /*
118   set the full request callback. Should return as follows:
119      NT_STATUS_OK == blob is a full request.
120      STATUS_MORE_ENTRIES == blob is not complete yet
121      any error == blob is not a valid 
122 */
123 void packet_set_full_request(struct packet_context *pc, packet_full_request_fn_t callback)
124 {
125         pc->full_request = callback;
126 }
127
128 /*
129   set a tls context to use. You must either set a tls_context or a socket_context
130 */
131 void packet_set_tls(struct packet_context *pc, struct tls_context *tls)
132 {
133         pc->tls = tls;
134 }
135
136 /*
137   set a socket context to use. You must either set a tls_context or a socket_context
138 */
139 void packet_set_socket(struct packet_context *pc, struct socket_context *sock)
140 {
141         pc->sock = sock;
142 }
143
144 /*
145   set an event context. If this is set then the code will ensure that
146   packets arrive with separate events, by creating a immediate event
147   for any secondary packets when more than one packet is read at one
148   time on a socket. This can matter for code that relies on not
149   getting more than one packet per event
150 */
151 void packet_set_event_context(struct packet_context *pc, struct event_context *ev)
152 {
153         pc->ev = ev;
154 }
155
156 /*
157   tell the packet layer the fde for the socket
158 */
159 void packet_set_fde(struct packet_context *pc, struct fd_event *fde)
160 {
161         pc->fde = fde;
162 }
163
164 /*
165   tell the packet layer to serialise requests, so we don't process two
166   requests at once on one connection. You must have set the
167   event_context and fde
168 */
169 void packet_set_serialise(struct packet_context *pc)
170 {
171         pc->serialise = True;
172 }
173
174 /*
175   tell the packet layer how much to read when starting a new packet
176   this ensures it doesn't overread
177 */
178 void packet_set_initial_read(struct packet_context *pc, uint32_t initial_read)
179 {
180         pc->initial_read = initial_read;
181 }
182
183 /*
184   tell the packet system not to steal/free blobs given to packet_send()
185 */
186 void packet_set_nofree(struct packet_context *pc)
187 {
188         pc->nofree = True;
189 }
190
191
192 /*
193   tell the caller we have an error
194 */
195 static void packet_error(struct packet_context *pc, NTSTATUS status)
196 {
197         pc->tls = NULL;
198         pc->sock = NULL;
199         if (pc->error_handler) {
200                 pc->error_handler(pc->private, status);
201                 return;
202         }
203         /* default error handler is to free the callers private pointer */
204         if (!NT_STATUS_EQUAL(status, NT_STATUS_END_OF_FILE)) {
205                 DEBUG(0,("packet_error on %s - %s\n", 
206                          talloc_get_name(pc->private), nt_errstr(status)));
207         }
208         talloc_free(pc->private);
209         return;
210 }
211
212
213 /*
214   tell the caller we have EOF
215 */
216 static void packet_eof(struct packet_context *pc)
217 {
218         packet_error(pc, NT_STATUS_END_OF_FILE);
219 }
220
221
222 /*
223   used to put packets on event boundaries
224 */
225 static void packet_next_event(struct event_context *ev, struct timed_event *te, 
226                               struct timeval t, void *private)
227 {
228         struct packet_context *pc = talloc_get_type(private, struct packet_context);
229         if (pc->num_read != 0 && pc->packet_size != 0 &&
230             pc->packet_size <= pc->num_read) {
231                 packet_recv(pc);
232         }
233 }
234
235
236 /*
237   call this when the socket becomes readable to kick off the whole
238   stream parsing process
239 */
240 void packet_recv(struct packet_context *pc)
241 {
242         size_t npending;
243         NTSTATUS status;
244         size_t nread = 0;
245         DATA_BLOB blob;
246
247         if (pc->processing) {
248                 EVENT_FD_NOT_READABLE(pc->fde);
249                 pc->processing++;
250                 return;
251         }
252
253         if (pc->recv_disable) {
254                 EVENT_FD_NOT_READABLE(pc->fde);
255                 return;
256         }
257
258         if (pc->packet_size != 0 && pc->num_read >= pc->packet_size) {
259                 goto next_partial;
260         }
261
262         if (pc->packet_size != 0) {
263                 /* we've already worked out how long this next packet is, so skip the
264                    socket_pending() call */
265                 npending = pc->packet_size - pc->num_read;
266         } else if (pc->initial_read != 0) {
267                 npending = pc->initial_read - pc->num_read;
268         } else {
269                 if (pc->tls) {
270                         status = tls_socket_pending(pc->tls, &npending);
271                 } else if (pc->sock) {
272                         status = socket_pending(pc->sock, &npending);
273                 } else {
274                         status = NT_STATUS_CONNECTION_DISCONNECTED;
275                 }
276                 if (!NT_STATUS_IS_OK(status)) {
277                         packet_error(pc, status);
278                         return;
279                 }
280         }
281
282         if (npending == 0) {
283                 packet_eof(pc);
284                 return;
285         }
286
287         /* possibly expand the partial packet buffer */
288         if (npending + pc->num_read > pc->partial.length) {
289                 status = data_blob_realloc(pc, &pc->partial, npending+pc->num_read);
290                 if (!NT_STATUS_IS_OK(status)) {
291                         packet_error(pc, status);
292                         return;
293                 }
294         }
295
296         if (pc->tls) {
297                 status = tls_socket_recv(pc->tls, pc->partial.data + pc->num_read, 
298                                          npending, &nread);
299         } else {
300                 status = socket_recv(pc->sock, pc->partial.data + pc->num_read, 
301                                      npending, &nread, 0);
302         }
303         if (NT_STATUS_IS_ERR(status)) {
304                 packet_error(pc, status);
305                 return;
306         }
307         if (!NT_STATUS_IS_OK(status)) {
308                 return;
309         }
310
311         if (nread == 0) {
312                 packet_eof(pc);
313                 return;
314         }
315
316         pc->num_read += nread;
317
318 next_partial:
319         if (pc->partial.length != pc->num_read) {
320                 status = data_blob_realloc(pc, &pc->partial, pc->num_read);
321                 if (!NT_STATUS_IS_OK(status)) {
322                         packet_error(pc, status);
323                         return;
324                 }
325         }
326
327         /* see if its a full request */
328         blob = pc->partial;
329         blob.length = pc->num_read;
330         status = pc->full_request(pc->private, blob, &pc->packet_size);
331         if (NT_STATUS_IS_ERR(status)) {
332                 packet_error(pc, status);
333                 return;
334         }
335         if (!NT_STATUS_IS_OK(status)) {
336                 return;
337         }
338
339         if (pc->packet_size > pc->num_read) {
340                 /* the caller made an error */
341                 DEBUG(0,("Invalid packet_size %lu greater than num_read %lu\n",
342                          (long)pc->packet_size, (long)pc->num_read));
343                 packet_error(pc, NT_STATUS_INVALID_PARAMETER);
344                 return;
345         }
346
347         /* it is a full request - give it to the caller */
348         blob = pc->partial;
349         blob.length = pc->num_read;
350
351         if (pc->packet_size < pc->num_read) {
352                 pc->partial = data_blob_talloc(pc, blob.data + pc->packet_size, 
353                                                pc->num_read - pc->packet_size);
354                 if (pc->partial.data == NULL) {
355                         packet_error(pc, NT_STATUS_NO_MEMORY);
356                         return;
357                 }
358                 status = data_blob_realloc(pc, &blob, pc->packet_size);
359                 if (!NT_STATUS_IS_OK(status)) {
360                         packet_error(pc, status);
361                         return;
362                 }
363         } else {
364                 pc->partial = data_blob(NULL, 0);
365         }
366         pc->num_read -= pc->packet_size;
367         pc->packet_size = 0;
368         
369         if (pc->serialise) {
370                 pc->processing = 1;
371         }
372
373         pc->busy = True;
374
375         status = pc->callback(pc->private, blob);
376
377         pc->busy = False;
378
379         if (pc->destructor_called) {
380                 talloc_free(pc);
381                 return;
382         }
383
384         if (pc->processing) {
385                 if (pc->processing > 1) {
386                         EVENT_FD_READABLE(pc->fde);
387                 }
388                 pc->processing = 0;
389         }
390
391         if (!NT_STATUS_IS_OK(status)) {
392                 packet_error(pc, status);
393                 return;
394         }
395
396         if (pc->partial.length == 0) {
397                 return;
398         }
399
400         /* we got multiple packets in one tcp read */
401         if (pc->ev == NULL) {
402                 goto next_partial;
403         }
404
405         blob = pc->partial;
406         blob.length = pc->num_read;
407
408         status = pc->full_request(pc->private, blob, &pc->packet_size);
409         if (NT_STATUS_IS_ERR(status)) {
410                 packet_error(pc, status);
411                 return;
412         }
413
414         if (!NT_STATUS_IS_OK(status)) {
415                 return;
416         }
417
418         event_add_timed(pc->ev, pc, timeval_zero(), packet_next_event, pc);
419 }
420
421
422 /*
423   temporarily disable receiving 
424 */
425 void packet_recv_disable(struct packet_context *pc)
426 {
427         EVENT_FD_NOT_READABLE(pc->fde);
428         pc->recv_disable = True;
429 }
430
431 /*
432   re-enable receiving 
433 */
434 void packet_recv_enable(struct packet_context *pc)
435 {
436         EVENT_FD_READABLE(pc->fde);
437         pc->recv_disable = False;
438         if (pc->num_read != 0 && pc->packet_size >= pc->num_read) {
439                 event_add_timed(pc->ev, pc, timeval_zero(), packet_next_event, pc);
440         }
441 }
442
443 /*
444   trigger a run of the send queue
445 */
446 void packet_queue_run(struct packet_context *pc)
447 {
448         while (pc->send_queue) {
449                 struct send_element *el = pc->send_queue;
450                 NTSTATUS status;
451                 size_t nwritten;
452                 DATA_BLOB blob = data_blob_const(el->blob.data + el->nsent,
453                                                  el->blob.length - el->nsent);
454
455                 if (pc->tls) {
456                         status = tls_socket_send(pc->tls, &blob, &nwritten);
457                 } else {
458                         status = socket_send(pc->sock, &blob, &nwritten, 0);
459                 }
460                 if (NT_STATUS_IS_ERR(status)) {
461                         packet_error(pc, NT_STATUS_NET_WRITE_FAULT);
462                         return;
463                 }
464                 if (!NT_STATUS_IS_OK(status)) {
465                         return;
466                 }
467                 el->nsent += nwritten;
468                 if (el->nsent == el->blob.length) {
469                         DLIST_REMOVE(pc->send_queue, el);
470                         talloc_free(el);
471                 }
472         }
473
474         /* we're out of requests to send, so don't wait for write
475            events any more */
476         EVENT_FD_NOT_WRITEABLE(pc->fde);
477 }
478
479 /*
480   put a packet in the send queue
481 */
482 NTSTATUS packet_send(struct packet_context *pc, DATA_BLOB blob)
483 {
484         struct send_element *el;
485         el = talloc(pc, struct send_element);
486         NT_STATUS_HAVE_NO_MEMORY(el);
487
488         DLIST_ADD_END(pc->send_queue, el, struct send_element *);
489         el->blob = blob;
490         el->nsent = 0;
491
492         /* if we aren't going to free the packet then we must reference it
493            to ensure it doesn't disappear before going out */
494         if (pc->nofree) {
495                 if (!talloc_reference(el, blob.data)) {
496                         return NT_STATUS_NO_MEMORY;
497                 }
498         } else {
499                 talloc_steal(el, blob.data);
500         }
501
502         EVENT_FD_WRITEABLE(pc->fde);
503
504         return NT_STATUS_OK;
505 }
506
507
508 /*
509   a full request checker for NBT formatted packets (first 3 bytes are length)
510 */
511 NTSTATUS packet_full_request_nbt(void *private, DATA_BLOB blob, size_t *size)
512 {
513         if (blob.length < 4) {
514                 return STATUS_MORE_ENTRIES;
515         }
516         *size = 4 + smb_len(blob.data);
517         if (*size > blob.length) {
518                 return STATUS_MORE_ENTRIES;
519         }
520         return NT_STATUS_OK;
521 }
522
523
524 /*
525   work out if a packet is complete for protocols that use a 32 bit network byte
526   order length
527 */
528 NTSTATUS packet_full_request_u32(void *private, DATA_BLOB blob, size_t *size)
529 {
530         if (blob.length < 4) {
531                 return STATUS_MORE_ENTRIES;
532         }
533         *size = 4 + RIVAL(blob.data, 0);
534         if (*size > blob.length) {
535                 return STATUS_MORE_ENTRIES;
536         }
537         return NT_STATUS_OK;
538 }