r7769: added client support in the tls library api
[samba.git] / source4 / lib / tls / tls.c
1 /* 
2    Unix SMB/CIFS implementation.
3
4    transport layer security handling code
5
6    Copyright (C) Andrew Tridgell 2005
7    
8    This program is free software; you can redistribute it and/or modify
9    it under the terms of the GNU General Public License as published by
10    the Free Software Foundation; either version 2 of the License, or
11    (at your option) any later version.
12    
13    This program is distributed in the hope that it will be useful,
14    but WITHOUT ANY WARRANTY; without even the implied warranty of
15    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16    GNU General Public License for more details.
17    
18    You should have received a copy of the GNU General Public License
19    along with this program; if not, write to the Free Software
20    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
21 */
22
23 #include "includes.h"
24 #include "lib/events/events.h"
25 #include "lib/socket/socket.h"
26 #include "lib/tls/tls.h"
27
28 #if HAVE_LIBGNUTLS
29 #include "gnutls/gnutls.h"
30
31 #define DH_BITS 1024
32
33 /* hold persistent tls data */
34 struct tls_params {
35         gnutls_certificate_credentials x509_cred;
36         gnutls_dh_params dh_params;
37         BOOL tls_enabled;
38 };
39
40 /* hold per connection tls data */
41 struct tls_context {
42         struct socket_context *socket;
43         struct fd_event *fde;
44         gnutls_session session;
45         BOOL done_handshake;
46         BOOL have_first_byte;
47         uint8_t first_byte;
48         BOOL tls_enabled;
49         BOOL tls_detect;
50         const char *plain_chars;
51         BOOL output_pending;
52         gnutls_certificate_credentials xcred;
53         BOOL interrupted;
54 };
55
/*
  run a gnutls call and bail out on failure

  The enclosing function must declare "int ret" and provide a
  "failed:" label; a negative result is logged and jumps there.
*/
#define TLSCHECK(call) do { \
        ret = (call); \
        if (ret < 0) { \
                DEBUG(0,("TLS %s - %s\n", #call, gnutls_strerror(ret))); \
                goto failed; \
        } \
} while (0)
63
64
65
66 /*
67   callback for reading from a socket
68 */
69 static ssize_t tls_pull(gnutls_transport_ptr ptr, void *buf, size_t size)
70 {
71         struct tls_context *tls = talloc_get_type(ptr, struct tls_context);
72         NTSTATUS status;
73         size_t nread;
74         
75         if (tls->have_first_byte) {
76                 *(uint8_t *)buf = tls->first_byte;
77                 tls->have_first_byte = False;
78                 return 1;
79         }
80
81         status = socket_recv(tls->socket, buf, size, &nread, 0);
82         if (NT_STATUS_EQUAL(status, NT_STATUS_END_OF_FILE)) {
83                 return 0;
84         }
85         if (NT_STATUS_IS_ERR(status)) {
86                 EVENT_FD_NOT_READABLE(tls->fde);
87                 EVENT_FD_NOT_WRITEABLE(tls->fde);
88                 errno = EBADF;
89                 return -1;
90         }
91         if (!NT_STATUS_IS_OK(status)) {
92                 EVENT_FD_READABLE(tls->fde);
93                 errno = EAGAIN;
94                 return -1;
95         }
96         if (tls->output_pending) {
97                 EVENT_FD_WRITEABLE(tls->fde);
98         }
99         if (size != nread) {
100                 EVENT_FD_READABLE(tls->fde);
101         }
102         return nread;
103 }
104
105 /*
106   callback for writing to a socket
107 */
108 static ssize_t tls_push(gnutls_transport_ptr ptr, const void *buf, size_t size)
109 {
110         struct tls_context *tls = talloc_get_type(ptr, struct tls_context);
111         NTSTATUS status;
112         size_t nwritten;
113         DATA_BLOB b;
114
115         if (!tls->tls_enabled) {
116                 return size;
117         }
118
119         b.data = discard_const(buf);
120         b.length = size;
121
122         status = socket_send(tls->socket, &b, &nwritten, 0);
123         if (NT_STATUS_EQUAL(status, STATUS_MORE_ENTRIES)) {
124                 errno = EAGAIN;
125                 return -1;
126         }
127         if (!NT_STATUS_IS_OK(status)) {
128                 EVENT_FD_WRITEABLE(tls->fde);
129                 return -1;
130         }
131         if (size != nwritten) {
132                 EVENT_FD_WRITEABLE(tls->fde);
133         }
134         return nwritten;
135 }
136
137 /*
138   destroy a tls session
139  */
140 static int tls_destructor(void *ptr)
141 {
142         struct tls_context *tls = talloc_get_type(ptr, struct tls_context);
143         int ret;
144         ret = gnutls_bye(tls->session, GNUTLS_SHUT_WR);
145         if (ret < 0) {
146                 DEBUG(0,("TLS gnutls_bye failed - %s\n", gnutls_strerror(ret)));
147         }
148         return 0;
149 }
150
151
152 /*
153   possibly continue the handshake process
154 */
155 static NTSTATUS tls_handshake(struct tls_context *tls)
156 {
157         int ret;
158
159         if (tls->done_handshake) {
160                 return NT_STATUS_OK;
161         }
162         
163         ret = gnutls_handshake(tls->session);
164         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
165                 if (gnutls_record_get_direction(tls->session) == 1) {
166                         EVENT_FD_WRITEABLE(tls->fde);
167                 }
168                 return STATUS_MORE_ENTRIES;
169         }
170         if (ret < 0) {
171                 DEBUG(0,("TLS gnutls_handshake failed - %s\n", gnutls_strerror(ret)));
172                 return NT_STATUS_UNEXPECTED_NETWORK_ERROR;
173         }
174         tls->done_handshake = True;
175         return NT_STATUS_OK;
176 }
177
178 /*
179   possibly continue an interrupted operation
180 */
181 static NTSTATUS tls_interrupted(struct tls_context *tls)
182 {
183         int ret;
184
185         if (!tls->interrupted) {
186                 return NT_STATUS_OK;
187         }
188         if (gnutls_record_get_direction(tls->session) == 1) {
189                 ret = gnutls_record_send(tls->session, NULL, 0);
190         } else {
191                 ret = gnutls_record_recv(tls->session, NULL, 0);
192         }
193         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
194                 return STATUS_MORE_ENTRIES;
195         }
196         tls->interrupted = False;
197         return NT_STATUS_OK;
198 }
199
200 /*
201   see how many bytes are pending on the connection
202 */
203 NTSTATUS tls_socket_pending(struct tls_context *tls, size_t *npending)
204 {
205         if (!tls->tls_enabled || tls->tls_detect) {
206                 return socket_pending(tls->socket, npending);
207         }
208         *npending = gnutls_record_check_pending(tls->session);
209         if (*npending == 0) {
210                 NTSTATUS status = socket_pending(tls->socket, npending);
211                 if (*npending == 0) {
212                         /* seems to be a gnutls bug */
213                         (*npending) = 100;
214                 }
215                 return status;
216         }
217         return NT_STATUS_OK;
218 }
219
220 /*
221   receive data either by tls or normal socket_recv
222 */
223 NTSTATUS tls_socket_recv(struct tls_context *tls, void *buf, size_t wantlen, 
224                          size_t *nread)
225 {
226         int ret;
227         NTSTATUS status;
228         if (tls->tls_enabled && tls->tls_detect) {
229                 status = socket_recv(tls->socket, &tls->first_byte, 1, nread, 0);
230                 NT_STATUS_NOT_OK_RETURN(status);
231                 if (*nread == 0) return NT_STATUS_OK;
232                 tls->tls_detect = False;
233                 /* look for the first byte of a valid HTTP operation */
234                 if (strchr(tls->plain_chars, tls->first_byte)) {
235                         /* not a tls link */
236                         tls->tls_enabled = False;
237                         *(uint8_t *)buf = tls->first_byte;
238                         return NT_STATUS_OK;
239                 }
240                 tls->have_first_byte = True;
241         }
242
243         if (!tls->tls_enabled) {
244                 return socket_recv(tls->socket, buf, wantlen, nread, 0);
245         }
246
247         status = tls_handshake(tls);
248         NT_STATUS_NOT_OK_RETURN(status);
249
250         status = tls_interrupted(tls);
251         NT_STATUS_NOT_OK_RETURN(status);
252
253         ret = gnutls_record_recv(tls->session, buf, wantlen);
254         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
255                 if (gnutls_record_get_direction(tls->session) == 1) {
256                         EVENT_FD_WRITEABLE(tls->fde);
257                 }
258                 tls->interrupted = True;
259                 return STATUS_MORE_ENTRIES;
260         }
261         if (ret < 0) {
262                 return NT_STATUS_UNEXPECTED_NETWORK_ERROR;
263         }
264         *nread = ret;
265         return NT_STATUS_OK;
266 }
267
268
269 /*
270   send data either by tls or normal socket_recv
271 */
272 NTSTATUS tls_socket_send(struct tls_context *tls, const DATA_BLOB *blob, size_t *sendlen)
273 {
274         NTSTATUS status;
275         int ret;
276
277         if (!tls->tls_enabled) {
278                 return socket_send(tls->socket, blob, sendlen, 0);
279         }
280
281         status = tls_handshake(tls);
282         NT_STATUS_NOT_OK_RETURN(status);
283
284         status = tls_interrupted(tls);
285         NT_STATUS_NOT_OK_RETURN(status);
286
287         ret = gnutls_record_send(tls->session, blob->data, blob->length);
288         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
289                 if (gnutls_record_get_direction(tls->session) == 1) {
290                         EVENT_FD_WRITEABLE(tls->fde);
291                 }
292                 tls->interrupted = True;
293                 return STATUS_MORE_ENTRIES;
294         }
295         if (ret < 0) {
296                 DEBUG(0,("gnutls_record_send of %d failed - %s\n", blob->length, gnutls_strerror(ret)));
297                 return NT_STATUS_UNEXPECTED_NETWORK_ERROR;
298         }
299         *sendlen = ret;
300         tls->output_pending = (ret < blob->length);
301         return NT_STATUS_OK;
302 }
303
304
305 /*
306   initialise global tls state
307 */
308 struct tls_params *tls_initialise(TALLOC_CTX *mem_ctx)
309 {
310         struct tls_params *params;
311         int ret;
312         const char *keyfile = lp_tls_keyfile();
313         const char *certfile = lp_tls_certfile();
314         const char *cafile = lp_tls_cafile();
315         const char *crlfile = lp_tls_crlfile();
316         void tls_cert_generate(TALLOC_CTX *, const char *, const char *, const char *);
317
318         params = talloc(mem_ctx, struct tls_params);
319         if (params == NULL) return NULL;
320
321         if (!lp_tls_enabled() || keyfile == NULL || *keyfile == 0) {
322                 params->tls_enabled = False;
323                 return params;
324         }
325
326         if (!file_exist(cafile)) {
327                 tls_cert_generate(params, keyfile, certfile, cafile);
328         }
329
330         ret = gnutls_global_init();
331         if (ret < 0) goto init_failed;
332
333         gnutls_certificate_allocate_credentials(&params->x509_cred);
334         if (ret < 0) goto init_failed;
335
336         if (cafile && *cafile) {
337                 ret = gnutls_certificate_set_x509_trust_file(params->x509_cred, cafile, 
338                                                              GNUTLS_X509_FMT_PEM);      
339                 if (ret < 0) {
340                         DEBUG(0,("TLS failed to initialise cafile %s\n", cafile));
341                         goto init_failed;
342                 }
343         }
344
345         if (crlfile && *crlfile) {
346                 ret = gnutls_certificate_set_x509_crl_file(params->x509_cred, 
347                                                            crlfile, 
348                                                            GNUTLS_X509_FMT_PEM);
349                 if (ret < 0) {
350                         DEBUG(0,("TLS failed to initialise crlfile %s\n", crlfile));
351                         goto init_failed;
352                 }
353         }
354         
355         ret = gnutls_certificate_set_x509_key_file(params->x509_cred, 
356                                                    certfile, keyfile,
357                                                    GNUTLS_X509_FMT_PEM);
358         if (ret < 0) {
359                 DEBUG(0,("TLS failed to initialise certfile %s and keyfile %s\n", 
360                          certfile, keyfile));
361                 goto init_failed;
362         }
363         
364         ret = gnutls_dh_params_init(&params->dh_params);
365         if (ret < 0) goto init_failed;
366
367         ret = gnutls_dh_params_generate2(params->dh_params, DH_BITS);
368         if (ret < 0) goto init_failed;
369
370         gnutls_certificate_set_dh_params(params->x509_cred, params->dh_params);
371
372         params->tls_enabled = True;
373
374         return params;
375
376 init_failed:
377         DEBUG(0,("GNUTLS failed to initialise - %s\n", gnutls_strerror(ret)));
378         params->tls_enabled = False;
379         return params;
380 }
381
382
383 /*
384   setup for a new connection
385 */
386 struct tls_context *tls_init_server(struct tls_params *params, 
387                                     struct socket_context *socket,
388                                     struct fd_event *fde, 
389                                     const char *plain_chars,
390                                     BOOL tls_enable)
391 {
392         struct tls_context *tls;
393         int ret;
394
395         tls = talloc(socket, struct tls_context);
396         if (tls == NULL) return NULL;
397
398         tls->socket          = socket;
399         tls->fde             = fde;
400
401         if (!params->tls_enabled || !tls_enable) {
402                 tls->tls_enabled = False;
403                 return tls;
404         }
405
406         TLSCHECK(gnutls_init(&tls->session, GNUTLS_SERVER));
407
408         talloc_set_destructor(tls, tls_destructor);
409
410         TLSCHECK(gnutls_set_default_priority(tls->session));
411         TLSCHECK(gnutls_credentials_set(tls->session, GNUTLS_CRD_CERTIFICATE, 
412                                         params->x509_cred));
413         gnutls_certificate_server_set_request(tls->session, GNUTLS_CERT_REQUEST);
414         gnutls_dh_set_prime_bits(tls->session, DH_BITS);
415         gnutls_transport_set_ptr(tls->session, (gnutls_transport_ptr)tls);
416         gnutls_transport_set_pull_function(tls->session, (gnutls_pull_func)tls_pull);
417         gnutls_transport_set_push_function(tls->session, (gnutls_push_func)tls_push);
418         gnutls_transport_set_lowat(tls->session, 0);
419
420         tls->plain_chars = plain_chars;
421         if (plain_chars) {
422                 tls->tls_detect = True;
423         } else {
424                 tls->tls_detect = False;
425         }
426
427         tls->output_pending  = False;
428         tls->done_handshake  = False;
429         tls->have_first_byte = False;
430         tls->tls_enabled     = True;
431         tls->interrupted     = False;
432         
433         return tls;
434
435 failed:
436         DEBUG(0,("TLS init connection failed - %s\n", gnutls_strerror(ret)));
437         tls->tls_enabled = False;
438         params->tls_enabled = False;
439         return tls;
440 }
441
442
443 /*
444   setup for a new client connection
445 */
446 struct tls_context *tls_init_client(struct socket_context *socket,
447                                     struct fd_event *fde, 
448                                     BOOL tls_enable)
449 {
450         struct tls_context *tls;
451         int ret;
452         const int cert_type_priority[] = { GNUTLS_CRT_X509, GNUTLS_CRT_OPENPGP, 0 };
453         tls = talloc(socket, struct tls_context);
454         if (tls == NULL) return NULL;
455
456         tls->socket          = socket;
457         tls->fde             = fde;
458         tls->tls_enabled     = tls_enable;
459
460         if (!tls->tls_enabled) {
461                 return tls;
462         }
463
464         gnutls_global_init();
465
466         gnutls_certificate_allocate_credentials(&tls->xcred);
467         gnutls_certificate_set_x509_trust_file(tls->xcred, lp_tls_cafile(),
468                                                GNUTLS_X509_FMT_PEM);
469         TLSCHECK(gnutls_init(&tls->session, GNUTLS_CLIENT));
470         TLSCHECK(gnutls_set_default_priority(tls->session));
471         gnutls_certificate_type_set_priority(tls->session, cert_type_priority);
472         TLSCHECK(gnutls_credentials_set(tls->session, GNUTLS_CRD_CERTIFICATE, tls->xcred));
473
474         talloc_set_destructor(tls, tls_destructor);
475
476         gnutls_transport_set_ptr(tls->session, (gnutls_transport_ptr)tls);
477         gnutls_transport_set_pull_function(tls->session, (gnutls_pull_func)tls_pull);
478         gnutls_transport_set_push_function(tls->session, (gnutls_push_func)tls_push);
479         gnutls_transport_set_lowat(tls->session, 0);
480         tls->tls_detect = False;
481
482         tls->output_pending  = False;
483         tls->done_handshake  = False;
484         tls->have_first_byte = False;
485         tls->tls_enabled     = True;
486         tls->interrupted     = False;
487         
488         return tls;
489
490 failed:
491         DEBUG(0,("TLS init connection failed - %s\n", gnutls_strerror(ret)));
492         tls->tls_enabled = False;
493         return tls;
494 }
495
496 BOOL tls_enabled(struct tls_context *tls)
497 {
498         return tls->tls_enabled;
499 }
500
501 BOOL tls_support(struct tls_params *params)
502 {
503         return params->tls_enabled;
504 }
505
506 #else
507
508 /* for systems without tls we just map the tls socket calls to the
509    normal socket calls */
510
511 struct tls_params *tls_initialise(TALLOC_CTX *mem_ctx)
512 {
513         return talloc_new(mem_ctx);
514 }
515
516 struct tls_context *tls_init_server(struct tls_params *params, 
517                                     struct socket_context *sock, 
518                                     struct fd_event *fde,
519                                     const char *plain_chars,
520                                     BOOL tls_enable)
521 {
522         if (plain_chars == NULL) return NULL;
523         return (struct tls_context *)sock;
524 }
525
526 struct tls_context *tls_init_client(struct socket_context *sock, 
527                                     struct fd_event *fde,
528                                     BOOL tls_enable)
529 {
530         return (struct tls_context *)sock;
531 }
532
533
534 NTSTATUS tls_socket_recv(struct tls_context *tls, void *buf, size_t wantlen, 
535                          size_t *nread)
536 {
537         return socket_recv((struct socket_context *)tls, buf, wantlen, nread, 0);
538 }
539
540 NTSTATUS tls_socket_send(struct tls_context *tls, const DATA_BLOB *blob, size_t *sendlen)
541 {
542         return socket_send((struct socket_context *)tls, blob, sendlen, 0);
543 }
544
545 BOOL tls_enabled(struct tls_context *tls)
546 {
547         return False;
548 }
549
550 BOOL tls_support(struct tls_params *params)
551 {
552         return False;
553 }
554
555 NTSTATUS tls_socket_pending(struct tls_context *tls, size_t *npending)
556 {
557         return socket_pending((struct socket_context *)tls, npending);
558 }
559
560 #endif