r7750: handle STATUS_MORE_ENTRIES on send in tls
[samba.git] / source4 / lib / tls / tls.c
1 /* 
2    Unix SMB/CIFS implementation.
3
4    transport layer security handling code
5
6    Copyright (C) Andrew Tridgell 2005
7    
8    This program is free software; you can redistribute it and/or modify
9    it under the terms of the GNU General Public License as published by
10    the Free Software Foundation; either version 2 of the License, or
11    (at your option) any later version.
12    
13    This program is distributed in the hope that it will be useful,
14    but WITHOUT ANY WARRANTY; without even the implied warranty of
15    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16    GNU General Public License for more details.
17    
18    You should have received a copy of the GNU General Public License
19    along with this program; if not, write to the Free Software
20    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
21 */
22
23 #include "includes.h"
24 #include "lib/events/events.h"
25 #include "lib/socket/socket.h"
26 #include "lib/tls/tls.h"
27
28 #if HAVE_LIBGNUTLS
29 #include "gnutls/gnutls.h"
30
/* DH parameter size in bits: used by gnutls_dh_params_generate2() in
   tls_initialise() and passed to gnutls_dh_set_prime_bits() for each
   server session in tls_init_server() */
#define DH_BITS 1024

/* hold persistent tls data (one instance shared by all connections) */
struct tls_params {
	gnutls_certificate_credentials x509_cred; /* credentials loaded from the configured CA/CRL/cert/key files */
	gnutls_dh_params dh_params;               /* DH parameters attached to x509_cred */
	BOOL tls_enabled;                         /* False when TLS is disabled by config or initialisation failed */
};
39
/* hold per connection tls data */
struct tls_context {
	struct tls_params *params;      /* global TLS state this connection was set up from */
	struct socket_context *socket;  /* underlying transport socket */
	struct fd_event *fde;           /* event for socket; tls_pull()/tls_push() toggle its read/write flags */
	gnutls_session session;         /* gnutls server session (created in tls_init_server()) */
	BOOL done_handshake;            /* set once gnutls_handshake() has succeeded */
	BOOL have_first_byte;           /* first_byte was consumed by TLS detection and must be replayed to gnutls */
	uint8_t first_byte;             /* the byte read while detecting TLS vs plain text */
	BOOL tls_enabled;               /* False for plain-text connections */
	BOOL tls_detect;                /* still waiting for the first byte to choose TLS vs plain text */
	const char *plain_chars;        /* first bytes that indicate a plain-text connection, or NULL */
	BOOL output_pending;            /* last tls_socket_send() accepted only part of its buffer */
};
54
55
56 /*
57   callback for reading from a socket
58 */
59 static ssize_t tls_pull(gnutls_transport_ptr ptr, void *buf, size_t size)
60 {
61         struct tls_context *tls = talloc_get_type(ptr, struct tls_context);
62         NTSTATUS status;
63         size_t nread;
64         
65         if (tls->have_first_byte) {
66                 *(uint8_t *)buf = tls->first_byte;
67                 tls->have_first_byte = False;
68                 return 1;
69         }
70
71         status = socket_recv(tls->socket, buf, size, &nread, 0);
72         if (NT_STATUS_EQUAL(status, NT_STATUS_END_OF_FILE)) {
73                 return 0;
74         }
75         if (NT_STATUS_IS_ERR(status)) {
76                 EVENT_FD_NOT_READABLE(tls->fde);
77                 EVENT_FD_NOT_WRITEABLE(tls->fde);
78                 errno = EBADF;
79                 return -1;
80         }
81         if (!NT_STATUS_IS_OK(status)) {
82                 EVENT_FD_READABLE(tls->fde);
83                 EVENT_FD_NOT_WRITEABLE(tls->fde);
84                 errno = EAGAIN;
85                 return -1;
86         }
87         if (tls->output_pending) {
88                 EVENT_FD_WRITEABLE(tls->fde);
89         }
90         if (size != nread) {
91                 EVENT_FD_READABLE(tls->fde);
92         }
93         return nread;
94 }
95
96 /*
97   callback for writing to a socket
98 */
99 static ssize_t tls_push(gnutls_transport_ptr ptr, const void *buf, size_t size)
100 {
101         struct tls_context *tls = talloc_get_type(ptr, struct tls_context);
102         NTSTATUS status;
103         size_t nwritten;
104         DATA_BLOB b;
105
106         if (!tls->tls_enabled) {
107                 return size;
108         }
109
110         b.data = discard_const(buf);
111         b.length = size;
112
113         status = socket_send(tls->socket, &b, &nwritten, 0);
114         if (NT_STATUS_EQUAL(status, STATUS_MORE_ENTRIES)) {
115                 errno = EAGAIN;
116                 return -1;
117         }
118         if (!NT_STATUS_IS_OK(status)) {
119                 EVENT_FD_WRITEABLE(tls->fde);
120                 return -1;
121         }
122         if (size != nwritten) {
123                 EVENT_FD_WRITEABLE(tls->fde);
124         }
125         return nwritten;
126 }
127
128 /*
129   destroy a tls session
130  */
131 static int tls_destructor(void *ptr)
132 {
133         struct tls_context *tls = talloc_get_type(ptr, struct tls_context);
134         int ret;
135         ret = gnutls_bye(tls->session, GNUTLS_SHUT_WR);
136         if (ret < 0) {
137                 DEBUG(0,("TLS gnutls_bye failed - %s\n", gnutls_strerror(ret)));
138         }
139         return 0;
140 }
141
142
143 /*
144   possibly continue the handshake process
145 */
146 static NTSTATUS tls_handshake(struct tls_context *tls)
147 {
148         int ret;
149
150         if (tls->done_handshake) {
151                 return NT_STATUS_OK;
152         }
153         
154         ret = gnutls_handshake(tls->session);
155         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
156                 return STATUS_MORE_ENTRIES;
157         }
158         if (ret < 0) {
159                 DEBUG(0,("TLS gnutls_handshake failed - %s\n", gnutls_strerror(ret)));
160                 return NT_STATUS_UNEXPECTED_NETWORK_ERROR;
161         }
162         tls->done_handshake = True;
163         return NT_STATUS_OK;
164 }
165
166 /*
167   see how many bytes are pending on the connection
168 */
169 NTSTATUS tls_socket_pending(struct tls_context *tls, size_t *npending)
170 {
171         if (!tls->tls_enabled || tls->tls_detect) {
172                 return socket_pending(tls->socket, npending);
173         }
174         *npending = gnutls_record_check_pending(tls->session);
175         if (*npending == 0) {
176                 return socket_pending(tls->socket, npending);
177         }
178         return NT_STATUS_OK;
179 }
180
181 /*
182   receive data either by tls or normal socket_recv
183 */
184 NTSTATUS tls_socket_recv(struct tls_context *tls, void *buf, size_t wantlen, 
185                          size_t *nread)
186 {
187         int ret;
188         NTSTATUS status;
189         if (tls->tls_enabled && tls->tls_detect) {
190                 status = socket_recv(tls->socket, &tls->first_byte, 1, nread, 0);
191                 NT_STATUS_NOT_OK_RETURN(status);
192                 if (*nread == 0) return NT_STATUS_OK;
193                 tls->tls_detect = False;
194                 /* look for the first byte of a valid HTTP operation */
195                 if (strchr(tls->plain_chars, tls->first_byte)) {
196                         /* not a tls link */
197                         tls->tls_enabled = False;
198                         *(uint8_t *)buf = tls->first_byte;
199                         return NT_STATUS_OK;
200                 }
201                 tls->have_first_byte = True;
202         }
203
204         if (!tls->tls_enabled) {
205                 return socket_recv(tls->socket, buf, wantlen, nread, 0);
206         }
207
208         status = tls_handshake(tls);
209         NT_STATUS_NOT_OK_RETURN(status);
210
211         ret = gnutls_record_recv(tls->session, buf, wantlen);
212         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
213                 return STATUS_MORE_ENTRIES;
214         }
215         if (ret < 0) {
216                 return NT_STATUS_UNEXPECTED_NETWORK_ERROR;
217         }
218         *nread = ret;
219         return NT_STATUS_OK;
220 }
221
222
223 /*
224   send data either by tls or normal socket_recv
225 */
226 NTSTATUS tls_socket_send(struct tls_context *tls, const DATA_BLOB *blob, size_t *sendlen)
227 {
228         NTSTATUS status;
229         int ret;
230
231         if (!tls->tls_enabled) {
232                 return socket_send(tls->socket, blob, sendlen, 0);
233         }
234
235         status = tls_handshake(tls);
236         NT_STATUS_NOT_OK_RETURN(status);
237
238         ret = gnutls_record_send(tls->session, blob->data, blob->length);
239         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
240                 return STATUS_MORE_ENTRIES;
241         }
242         if (ret < 0) {
243                 DEBUG(0,("gnutls_record_send of %d failed - %s\n", blob->length, gnutls_strerror(ret)));
244                 return NT_STATUS_UNEXPECTED_NETWORK_ERROR;
245         }
246         *sendlen = ret;
247         tls->output_pending = (ret < blob->length);
248         return NT_STATUS_OK;
249 }
250
251
252 /*
253   initialise global tls state
254 */
255 struct tls_params *tls_initialise(TALLOC_CTX *mem_ctx)
256 {
257         struct tls_params *params;
258         int ret;
259         const char *keyfile = lp_tls_keyfile();
260         const char *certfile = lp_tls_certfile();
261         const char *cafile = lp_tls_cafile();
262         const char *crlfile = lp_tls_crlfile();
263         void tls_cert_generate(TALLOC_CTX *, const char *, const char *, const char *);
264
265         params = talloc(mem_ctx, struct tls_params);
266         if (params == NULL) return NULL;
267
268         if (!lp_tls_enabled() || keyfile == NULL || *keyfile == 0) {
269                 params->tls_enabled = False;
270                 return params;
271         }
272
273         if (!file_exist(cafile)) {
274                 tls_cert_generate(params, keyfile, certfile, cafile);
275         }
276
277         ret = gnutls_global_init();
278         if (ret < 0) goto init_failed;
279
280         gnutls_certificate_allocate_credentials(&params->x509_cred);
281         if (ret < 0) goto init_failed;
282
283         if (cafile && *cafile) {
284                 ret = gnutls_certificate_set_x509_trust_file(params->x509_cred, cafile, 
285                                                              GNUTLS_X509_FMT_PEM);      
286                 if (ret < 0) {
287                         DEBUG(0,("TLS failed to initialise cafile %s\n", cafile));
288                         goto init_failed;
289                 }
290         }
291
292         if (crlfile && *crlfile) {
293                 ret = gnutls_certificate_set_x509_crl_file(params->x509_cred, 
294                                                            crlfile, 
295                                                            GNUTLS_X509_FMT_PEM);
296                 if (ret < 0) {
297                         DEBUG(0,("TLS failed to initialise crlfile %s\n", crlfile));
298                         goto init_failed;
299                 }
300         }
301         
302         ret = gnutls_certificate_set_x509_key_file(params->x509_cred, 
303                                                    certfile, keyfile,
304                                                    GNUTLS_X509_FMT_PEM);
305         if (ret < 0) {
306                 DEBUG(0,("TLS failed to initialise certfile %s and keyfile %s\n", 
307                          certfile, keyfile));
308                 goto init_failed;
309         }
310         
311         ret = gnutls_dh_params_init(&params->dh_params);
312         if (ret < 0) goto init_failed;
313
314         ret = gnutls_dh_params_generate2(params->dh_params, DH_BITS);
315         if (ret < 0) goto init_failed;
316
317         gnutls_certificate_set_dh_params(params->x509_cred, params->dh_params);
318
319         params->tls_enabled = True;
320         return params;
321
322 init_failed:
323         DEBUG(0,("GNUTLS failed to initialise - %s\n", gnutls_strerror(ret)));
324         params->tls_enabled = False;
325         return params;
326 }
327
328
329 /*
330   setup for a new connection
331 */
332 struct tls_context *tls_init_server(struct tls_params *params, 
333                                     struct socket_context *socket,
334                                     struct fd_event *fde, 
335                                     const char *plain_chars)
336 {
337         struct tls_context *tls;
338         int ret;
339
340         tls = talloc(socket, struct tls_context);
341         if (tls == NULL) return NULL;
342
343         tls->socket          = socket;
344         tls->fde             = fde;
345
346         if (!params->tls_enabled) {
347                 tls->tls_enabled = False;
348                 return tls;
349         }
350
351 #define TLSCHECK(call) do { \
352         ret = call; \
353         if (ret < 0) { \
354                 DEBUG(0,("TLS %s - %s\n", #call, gnutls_strerror(ret))); \
355                 goto failed; \
356         } \
357 } while (0)
358
359         TLSCHECK(gnutls_init(&tls->session, GNUTLS_SERVER));
360
361         talloc_set_destructor(tls, tls_destructor);
362
363         TLSCHECK(gnutls_set_default_priority(tls->session));
364         TLSCHECK(gnutls_credentials_set(tls->session, GNUTLS_CRD_CERTIFICATE, 
365                                         params->x509_cred));
366         gnutls_certificate_server_set_request(tls->session, GNUTLS_CERT_REQUEST);
367         gnutls_dh_set_prime_bits(tls->session, DH_BITS);
368         gnutls_transport_set_ptr(tls->session, (gnutls_transport_ptr)tls);
369         gnutls_transport_set_pull_function(tls->session, (gnutls_pull_func)tls_pull);
370         gnutls_transport_set_push_function(tls->session, (gnutls_push_func)tls_push);
371         gnutls_transport_set_lowat(tls->session, 0);
372
373         tls->plain_chars = plain_chars;
374         if (plain_chars) {
375                 tls->tls_detect = True;
376         } else {
377                 tls->tls_detect = False;
378         }
379
380         tls->output_pending  = False;
381         tls->params          = params;
382         tls->done_handshake  = False;
383         tls->have_first_byte = False;
384         tls->tls_enabled     = True;
385         
386         return tls;
387
388 failed:
389         DEBUG(0,("TLS init connection failed - %s\n", gnutls_strerror(ret)));
390         tls->tls_enabled = False;
391         params->tls_enabled = False;
392         return tls;
393 }
394
395 BOOL tls_enabled(struct tls_context *tls)
396 {
397         return tls->tls_enabled;
398 }
399
400 BOOL tls_support(struct tls_params *params)
401 {
402         return params->tls_enabled;
403 }
404
405
406 #else
407
408 /* for systems without tls we just map the tls socket calls to the
409    normal socket calls */
410
411 struct tls_params *tls_initialise(TALLOC_CTX *mem_ctx)
412 {
413         return talloc_new(mem_ctx);
414 }
415
/*
  non-TLS build: a caller that passed no plain_chars would get a TLS-only
  connection, which cannot be served, so return NULL.  Otherwise the
  socket itself is returned, disguised as the tls_context.
*/
struct tls_context *tls_init_server(struct tls_params *params, 
				    struct socket_context *sock, 
				    struct fd_event *fde,
				    const char *plain_chars)
{
	if (plain_chars != NULL) {
		return (struct tls_context *)sock;
	}
	return NULL;
}
424
425
426 NTSTATUS tls_socket_recv(struct tls_context *tls, void *buf, size_t wantlen, 
427                          size_t *nread)
428 {
429         return socket_recv((struct socket_context *)tls, buf, wantlen, nread, 0);
430 }
431
432 NTSTATUS tls_socket_send(struct tls_context *tls, const DATA_BLOB *blob, size_t *sendlen)
433 {
434         return socket_send((struct socket_context *)tls, blob, sendlen, 0);
435 }
436
437 BOOL tls_enabled(struct tls_context *tls)
438 {
439         return False;
440 }
441
442 BOOL tls_support(struct tls_params *params)
443 {
444         return False;
445 }
446
447 NTSTATUS tls_socket_pending(struct tls_context *tls, size_t *npending)
448 {
449         return socket_pending((struct socket_context *)tls, npending);
450 }
451
452 #endif