s4:lib/tls - include GNUTLS headers consistently using <...>
[samba.git] / source4 / lib / tls / tls_tstream.c
1 /*
2    Unix SMB/CIFS implementation.
3
4    Copyright (C) Stefan Metzmacher 2010
5
6    This program is free software; you can redistribute it and/or modify
7    it under the terms of the GNU General Public License as published by
8    the Free Software Foundation; either version 3 of the License, or
9    (at your option) any later version.
10
11    This program is distributed in the hope that it will be useful,
12    but WITHOUT ANY WARRANTY; without even the implied warranty of
13    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14    GNU General Public License for more details.
15
16    You should have received a copy of the GNU General Public License
17    along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 */
19
20 #include "includes.h"
21 #include "system/network.h"
22 #include "../util/tevent_unix.h"
23 #include "../lib/tsocket/tsocket.h"
24 #include "../lib/tsocket/tsocket_internal.h"
25 #include "lib/tls/tls.h"
26
#if ENABLE_GNUTLS
#include <gnutls/gnutls.h>

/*
 * Diffie-Hellman prime size in bits.  Not referenced in this part of the
 * file - NOTE(review): presumably used by server-side parameter setup,
 * confirm.
 */
#define DH_BITS 1024

/*
 * Compatibility shim: if configure only found the legacy 'gnutls_datum'
 * type name and not 'gnutls_datum_t', provide the newer name as an alias.
 */
#if defined(HAVE_GNUTLS_DATUM) && !defined(HAVE_GNUTLS_DATUM_T)
typedef gnutls_datum gnutls_datum_t;
#endif

#endif /* ENABLE_GNUTLS */
37
static const struct tstream_context_ops tstream_tls_ops;

/*
 * Private state of a TLS tstream.  It is attached to the generic
 * tstream_context and retrieved with tstream_context_data().  gnutls drives
 * the record layer; ciphertext moves to and from 'plain_stream' through the
 * push/pull transport callbacks.
 */
struct tstream_tls {
        /* underlying transport stream carrying the ciphertext */
        struct tstream_context *plain_stream;
        /* sticky errno value: once non-zero, every operation fails with it */
        int error;

#if ENABLE_GNUTLS
        gnutls_session tls_session;
#endif /* ENABLE_GNUTLS */

        /* event context passed to the most recent *_send() call */
        struct tevent_context *current_ev;

        /* immediate used by tstream_tls_retry() to schedule a deferred pass */
        struct tevent_immediate *retry_im;

        /* ciphertext queued by gnutls, waiting to be written to plain_stream */
        struct {
                uint8_t *buf;
                off_t ofs;                      /* bytes queued in buf */
                struct iovec iov;               /* describes buf for tstream_writev */
                struct tevent_req *subreq;      /* in-flight tstream_writev, if any */
                struct tevent_immediate *im;    /* starts the batched write */
        } push;

        /* ciphertext read from plain_stream, not yet consumed by gnutls */
        struct {
                uint8_t *buf;
                struct iovec iov;               /* unconsumed remainder of buf */
                struct tevent_req *subreq;      /* in-flight tstream_readv, if any */
        } pull;

        /* pending handshake request, if any */
        struct {
                struct tevent_req *req;
        } handshake;

        /* plaintext chunk currently handed to gnutls_record_send() */
        struct {
                off_t ofs;
                size_t left;
                uint8_t buffer[1024];
                struct tevent_req *req;
        } write;

        /* plaintext returned by gnutls_record_recv(), not yet copied out */
        struct {
                off_t ofs;
                size_t left;
                uint8_t buffer[1024];
                struct tevent_req *req;
        } read;

        /* pending disconnect request, if any */
        struct {
                struct tevent_req *req;
        } disconnect;
};
88
89 static void tstream_tls_retry_handshake(struct tstream_context *stream);
90 static void tstream_tls_retry_read(struct tstream_context *stream);
91 static void tstream_tls_retry_write(struct tstream_context *stream);
92 static void tstream_tls_retry_disconnect(struct tstream_context *stream);
93 static void tstream_tls_retry_trigger(struct tevent_context *ctx,
94                                       struct tevent_immediate *im,
95                                       void *private_data);
96
97 static void tstream_tls_retry(struct tstream_context *stream, bool deferred)
98 {
99
100         struct tstream_tls *tlss =
101                 tstream_context_data(stream,
102                 struct tstream_tls);
103
104         if (tlss->disconnect.req) {
105                 tstream_tls_retry_disconnect(stream);
106                 return;
107         }
108
109         if (tlss->handshake.req) {
110                 tstream_tls_retry_handshake(stream);
111                 return;
112         }
113
114         if (tlss->write.req && tlss->read.req && !deferred) {
115                 tevent_schedule_immediate(tlss->retry_im, tlss->current_ev,
116                                           tstream_tls_retry_trigger,
117                                           stream);
118         }
119
120         if (tlss->write.req) {
121                 tstream_tls_retry_write(stream);
122                 return;
123         }
124
125         if (tlss->read.req) {
126                 tstream_tls_retry_read(stream);
127                 return;
128         }
129 }
130
131 static void tstream_tls_retry_trigger(struct tevent_context *ctx,
132                                       struct tevent_immediate *im,
133                                       void *private_data)
134 {
135         struct tstream_context *stream =
136                 talloc_get_type_abort(private_data,
137                 struct tstream_context);
138
139         tstream_tls_retry(stream, true);
140 }
141
142 #if ENABLE_GNUTLS
143 static void tstream_tls_push_trigger_write(struct tevent_context *ev,
144                                            struct tevent_immediate *im,
145                                            void *private_data);
146
147 static ssize_t tstream_tls_push_function(gnutls_transport_ptr ptr,
148                                          const void *buf, size_t size)
149 {
150         struct tstream_context *stream =
151                 talloc_get_type_abort(ptr,
152                 struct tstream_context);
153         struct tstream_tls *tlss =
154                 tstream_context_data(stream,
155                 struct tstream_tls);
156         uint8_t *nbuf;
157         size_t len;
158
159         if (tlss->error != 0) {
160                 errno = tlss->error;
161                 return -1;
162         }
163
164         if (tlss->push.subreq) {
165                 errno = EAGAIN;
166                 return -1;
167         }
168
169         len = MIN(size, UINT16_MAX - tlss->push.ofs);
170
171         if (len == 0) {
172                 errno = EAGAIN;
173                 return -1;
174         }
175
176         nbuf = talloc_realloc(tlss, tlss->push.buf,
177                               uint8_t, tlss->push.ofs + len);
178         if (nbuf == NULL) {
179                 if (tlss->push.buf) {
180                         errno = EAGAIN;
181                         return -1;
182                 }
183
184                 return -1;
185         }
186         tlss->push.buf = nbuf;
187
188         memcpy(tlss->push.buf + tlss->push.ofs, buf, len);
189
190         if (tlss->push.im == NULL) {
191                 tlss->push.im = tevent_create_immediate(tlss);
192                 if (tlss->push.im == NULL) {
193                         errno = ENOMEM;
194                         return -1;
195                 }
196         }
197
198         if (tlss->push.ofs == 0) {
199                 /*
200                  * We'll do start the tstream_writev
201                  * in the next event cycle.
202                  *
203                  * This way we can batch all push requests,
204                  * if they fit into a UINT16_MAX buffer.
205                  *
206                  * This is important as gnutls_handshake()
207                  * had a bug in some versions e.g. 2.4.1
208                  * and others (See bug #7218) and it doesn't
209                  * handle EAGAIN.
210                  */
211                 tevent_schedule_immediate(tlss->push.im,
212                                           tlss->current_ev,
213                                           tstream_tls_push_trigger_write,
214                                           stream);
215         }
216
217         tlss->push.ofs += len;
218         return len;
219 }
220
221 static void tstream_tls_push_done(struct tevent_req *subreq);
222
223 static void tstream_tls_push_trigger_write(struct tevent_context *ev,
224                                            struct tevent_immediate *im,
225                                            void *private_data)
226 {
227         struct tstream_context *stream =
228                 talloc_get_type_abort(private_data,
229                 struct tstream_context);
230         struct tstream_tls *tlss =
231                 tstream_context_data(stream,
232                 struct tstream_tls);
233         struct tevent_req *subreq;
234
235         if (tlss->push.subreq) {
236                 /* nothing todo */
237                 return;
238         }
239
240         tlss->push.iov.iov_base = (char *)tlss->push.buf;
241         tlss->push.iov.iov_len = tlss->push.ofs;
242
243         subreq = tstream_writev_send(tlss,
244                                      tlss->current_ev,
245                                      tlss->plain_stream,
246                                      &tlss->push.iov, 1);
247         if (subreq == NULL) {
248                 tlss->error = ENOMEM;
249                 tstream_tls_retry(stream, false);
250                 return;
251         }
252         tevent_req_set_callback(subreq, tstream_tls_push_done, stream);
253
254         tlss->push.subreq = subreq;
255 }
256
257 static void tstream_tls_push_done(struct tevent_req *subreq)
258 {
259         struct tstream_context *stream =
260                 tevent_req_callback_data(subreq,
261                 struct tstream_context);
262         struct tstream_tls *tlss =
263                 tstream_context_data(stream,
264                 struct tstream_tls);
265         int ret;
266         int sys_errno;
267
268         tlss->push.subreq = NULL;
269         ZERO_STRUCT(tlss->push.iov);
270         TALLOC_FREE(tlss->push.buf);
271         tlss->push.ofs = 0;
272
273         ret = tstream_writev_recv(subreq, &sys_errno);
274         TALLOC_FREE(subreq);
275         if (ret == -1) {
276                 tlss->error = sys_errno;
277                 tstream_tls_retry(stream, false);
278                 return;
279         }
280
281         tstream_tls_retry(stream, false);
282 }
283
284 static void tstream_tls_pull_done(struct tevent_req *subreq);
285
286 static ssize_t tstream_tls_pull_function(gnutls_transport_ptr ptr,
287                                          void *buf, size_t size)
288 {
289         struct tstream_context *stream =
290                 talloc_get_type_abort(ptr,
291                 struct tstream_context);
292         struct tstream_tls *tlss =
293                 tstream_context_data(stream,
294                 struct tstream_tls);
295         struct tevent_req *subreq;
296         size_t len;
297
298         if (tlss->error != 0) {
299                 errno = tlss->error;
300                 return -1;
301         }
302
303         if (tlss->pull.subreq) {
304                 errno = EAGAIN;
305                 return -1;
306         }
307
308         if (tlss->pull.iov.iov_base) {
309                 uint8_t *b;
310                 size_t n;
311
312                 b = (uint8_t *)tlss->pull.iov.iov_base;
313
314                 n = MIN(tlss->pull.iov.iov_len, size);
315                 memcpy(buf, b, n);
316
317                 tlss->pull.iov.iov_len -= n;
318                 b += n;
319                 tlss->pull.iov.iov_base = (char *)b;
320                 if (tlss->pull.iov.iov_len == 0) {
321                         tlss->pull.iov.iov_base = NULL;
322                         TALLOC_FREE(tlss->pull.buf);
323                 }
324
325                 return n;
326         }
327
328         if (size == 0) {
329                 return 0;
330         }
331
332         len = MIN(size, UINT16_MAX);
333
334         tlss->pull.buf = talloc_array(tlss, uint8_t, len);
335         if (tlss->pull.buf == NULL) {
336                 return -1;
337         }
338
339         tlss->pull.iov.iov_base = (char *)tlss->pull.buf;
340         tlss->pull.iov.iov_len = len;
341
342         subreq = tstream_readv_send(tlss,
343                                     tlss->current_ev,
344                                     tlss->plain_stream,
345                                     &tlss->pull.iov, 1);
346         if (subreq == NULL) {
347                 errno = ENOMEM;
348                 return -1;
349         }
350         tevent_req_set_callback(subreq, tstream_tls_pull_done, stream);
351
352         tlss->pull.subreq = subreq;
353         errno = EAGAIN;
354         return -1;
355 }
356
357 static void tstream_tls_pull_done(struct tevent_req *subreq)
358 {
359         struct tstream_context *stream =
360                 tevent_req_callback_data(subreq,
361                 struct tstream_context);
362         struct tstream_tls *tlss =
363                 tstream_context_data(stream,
364                 struct tstream_tls);
365         int ret;
366         int sys_errno;
367
368         tlss->pull.subreq = NULL;
369
370         ret = tstream_readv_recv(subreq, &sys_errno);
371         TALLOC_FREE(subreq);
372         if (ret == -1) {
373                 tlss->error = sys_errno;
374                 tstream_tls_retry(stream, false);
375                 return;
376         }
377
378         tstream_tls_retry(stream, false);
379 }
380 #endif /* ENABLE_GNUTLS */
381
/*
 * talloc destructor for struct tstream_tls: releases the gnutls session,
 * if one was created.  Always returns 0, so the free proceeds.
 */
static int tstream_tls_destructor(struct tstream_tls *tlss)
{
#if ENABLE_GNUTLS
        if (tlss->tls_session == NULL) {
                return 0;
        }
        gnutls_deinit(tlss->tls_session);
        tlss->tls_session = NULL;
#endif /* ENABLE_GNUTLS */
        return 0;
}
392
393 static ssize_t tstream_tls_pending_bytes(struct tstream_context *stream)
394 {
395         struct tstream_tls *tlss =
396                 tstream_context_data(stream,
397                 struct tstream_tls);
398         size_t ret;
399
400         if (tlss->error != 0) {
401                 errno = tlss->error;
402                 return -1;
403         }
404
405 #if ENABLE_GNUTLS
406         ret = gnutls_record_check_pending(tlss->tls_session);
407         ret += tlss->read.left;
408 #else /* ENABLE_GNUTLS */
409         errno = ENOSYS;
410         ret = -1;
411 #endif /* ENABLE_GNUTLS */
412         return ret;
413 }
414
415 struct tstream_tls_readv_state {
416         struct tstream_context *stream;
417
418         struct iovec *vector;
419         int count;
420
421         int ret;
422 };
423
424 static void tstream_tls_readv_crypt_next(struct tevent_req *req);
425
426 static struct tevent_req *tstream_tls_readv_send(TALLOC_CTX *mem_ctx,
427                                         struct tevent_context *ev,
428                                         struct tstream_context *stream,
429                                         struct iovec *vector,
430                                         size_t count)
431 {
432         struct tstream_tls *tlss =
433                 tstream_context_data(stream,
434                 struct tstream_tls);
435         struct tevent_req *req;
436         struct tstream_tls_readv_state *state;
437
438         tlss->read.req = NULL;
439         tlss->current_ev = ev;
440
441         req = tevent_req_create(mem_ctx, &state,
442                                 struct tstream_tls_readv_state);
443         if (req == NULL) {
444                 return NULL;
445         }
446
447         state->stream = stream;
448         state->ret = 0;
449
450         if (tlss->error != 0) {
451                 tevent_req_error(req, tlss->error);
452                 return tevent_req_post(req, ev);
453         }
454
455         /*
456          * we make a copy of the vector so we can change the structure
457          */
458         state->vector = talloc_array(state, struct iovec, count);
459         if (tevent_req_nomem(state->vector, req)) {
460                 return tevent_req_post(req, ev);
461         }
462         memcpy(state->vector, vector, sizeof(struct iovec) * count);
463         state->count = count;
464
465         tstream_tls_readv_crypt_next(req);
466         if (!tevent_req_is_in_progress(req)) {
467                 return tevent_req_post(req, ev);
468         }
469
470         return req;
471 }
472
473 static void tstream_tls_readv_crypt_next(struct tevent_req *req)
474 {
475         struct tstream_tls_readv_state *state =
476                 tevent_req_data(req,
477                 struct tstream_tls_readv_state);
478         struct tstream_tls *tlss =
479                 tstream_context_data(state->stream,
480                 struct tstream_tls);
481
482         /*
483          * copy the pending buffer first
484          */
485         while (tlss->read.left > 0 && state->count > 0) {
486                 uint8_t *base = (uint8_t *)state->vector[0].iov_base;
487                 size_t len = MIN(tlss->read.left, state->vector[0].iov_len);
488
489                 memcpy(base, tlss->read.buffer + tlss->read.ofs, len);
490
491                 base += len;
492                 state->vector[0].iov_base = (char *) base;
493                 state->vector[0].iov_len -= len;
494
495                 tlss->read.ofs += len;
496                 tlss->read.left -= len;
497
498                 if (state->vector[0].iov_len == 0) {
499                         state->vector += 1;
500                         state->count -= 1;
501                 }
502
503                 state->ret += len;
504         }
505
506         if (state->count == 0) {
507                 tevent_req_done(req);
508                 return;
509         }
510
511         tlss->read.req = req;
512         tstream_tls_retry_read(state->stream);
513 }
514
515 static void tstream_tls_retry_read(struct tstream_context *stream)
516 {
517         struct tstream_tls *tlss =
518                 tstream_context_data(stream,
519                 struct tstream_tls);
520         struct tevent_req *req = tlss->read.req;
521 #if ENABLE_GNUTLS
522         int ret;
523
524         if (tlss->error != 0) {
525                 tevent_req_error(req, tlss->error);
526                 return;
527         }
528
529         tlss->read.left = 0;
530         tlss->read.ofs = 0;
531
532         ret = gnutls_record_recv(tlss->tls_session,
533                                  tlss->read.buffer,
534                                  sizeof(tlss->read.buffer));
535         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
536                 return;
537         }
538
539         tlss->read.req = NULL;
540
541         if (gnutls_error_is_fatal(ret) != 0) {
542                 DEBUG(1,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
543                 tlss->error = EIO;
544                 tevent_req_error(req, tlss->error);
545                 return;
546         }
547
548         if (ret == 0) {
549                 tlss->error = EPIPE;
550                 tevent_req_error(req, tlss->error);
551                 return;
552         }
553
554         tlss->read.left = ret;
555         tstream_tls_readv_crypt_next(req);
556 #else /* ENABLE_GNUTLS */
557         tevent_req_error(req, ENOSYS);
558 #endif /* ENABLE_GNUTLS */
559 }
560
561 static int tstream_tls_readv_recv(struct tevent_req *req,
562                                   int *perrno)
563 {
564         struct tstream_tls_readv_state *state =
565                 tevent_req_data(req,
566                 struct tstream_tls_readv_state);
567         struct tstream_tls *tlss =
568                 tstream_context_data(state->stream,
569                 struct tstream_tls);
570         int ret;
571
572         tlss->read.req = NULL;
573
574         ret = tsocket_simple_int_recv(req, perrno);
575         if (ret == 0) {
576                 ret = state->ret;
577         }
578
579         tevent_req_received(req);
580         return ret;
581 }
582
583 struct tstream_tls_writev_state {
584         struct tstream_context *stream;
585
586         struct iovec *vector;
587         int count;
588
589         int ret;
590 };
591
592 static void tstream_tls_writev_crypt_next(struct tevent_req *req);
593
594 static struct tevent_req *tstream_tls_writev_send(TALLOC_CTX *mem_ctx,
595                                         struct tevent_context *ev,
596                                         struct tstream_context *stream,
597                                         const struct iovec *vector,
598                                         size_t count)
599 {
600         struct tstream_tls *tlss =
601                 tstream_context_data(stream,
602                 struct tstream_tls);
603         struct tevent_req *req;
604         struct tstream_tls_writev_state *state;
605
606         tlss->write.req = NULL;
607         tlss->current_ev = ev;
608
609         req = tevent_req_create(mem_ctx, &state,
610                                 struct tstream_tls_writev_state);
611         if (req == NULL) {
612                 return NULL;
613         }
614
615         state->stream = stream;
616         state->ret = 0;
617
618         if (tlss->error != 0) {
619                 tevent_req_error(req, tlss->error);
620                 return tevent_req_post(req, ev);
621         }
622
623         /*
624          * we make a copy of the vector so we can change the structure
625          */
626         state->vector = talloc_array(state, struct iovec, count);
627         if (tevent_req_nomem(state->vector, req)) {
628                 return tevent_req_post(req, ev);
629         }
630         memcpy(state->vector, vector, sizeof(struct iovec) * count);
631         state->count = count;
632
633         tstream_tls_writev_crypt_next(req);
634         if (!tevent_req_is_in_progress(req)) {
635                 return tevent_req_post(req, ev);
636         }
637
638         return req;
639 }
640
641 static void tstream_tls_writev_crypt_next(struct tevent_req *req)
642 {
643         struct tstream_tls_writev_state *state =
644                 tevent_req_data(req,
645                 struct tstream_tls_writev_state);
646         struct tstream_tls *tlss =
647                 tstream_context_data(state->stream,
648                 struct tstream_tls);
649
650         tlss->write.left = sizeof(tlss->write.buffer);
651         tlss->write.ofs = 0;
652
653         /*
654          * first fill our buffer
655          */
656         while (tlss->write.left > 0 && state->count > 0) {
657                 uint8_t *base = (uint8_t *)state->vector[0].iov_base;
658                 size_t len = MIN(tlss->write.left, state->vector[0].iov_len);
659
660                 memcpy(tlss->write.buffer + tlss->write.ofs, base, len);
661
662                 base += len;
663                 state->vector[0].iov_base = (char *) base;
664                 state->vector[0].iov_len -= len;
665
666                 tlss->write.ofs += len;
667                 tlss->write.left -= len;
668
669                 if (state->vector[0].iov_len == 0) {
670                         state->vector += 1;
671                         state->count -= 1;
672                 }
673
674                 state->ret += len;
675         }
676
677         if (tlss->write.ofs == 0) {
678                 tevent_req_done(req);
679                 return;
680         }
681
682         tlss->write.left = tlss->write.ofs;
683         tlss->write.ofs = 0;
684
685         tlss->write.req = req;
686         tstream_tls_retry_write(state->stream);
687 }
688
689 static void tstream_tls_retry_write(struct tstream_context *stream)
690 {
691         struct tstream_tls *tlss =
692                 tstream_context_data(stream,
693                 struct tstream_tls);
694         struct tevent_req *req = tlss->write.req;
695 #if ENABLE_GNUTLS
696         int ret;
697
698         if (tlss->error != 0) {
699                 tevent_req_error(req, tlss->error);
700                 return;
701         }
702
703         ret = gnutls_record_send(tlss->tls_session,
704                                  tlss->write.buffer + tlss->write.ofs,
705                                  tlss->write.left);
706         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
707                 return;
708         }
709
710         tlss->write.req = NULL;
711
712         if (gnutls_error_is_fatal(ret) != 0) {
713                 DEBUG(1,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
714                 tlss->error = EIO;
715                 tevent_req_error(req, tlss->error);
716                 return;
717         }
718
719         if (ret == 0) {
720                 tlss->error = EPIPE;
721                 tevent_req_error(req, tlss->error);
722                 return;
723         }
724
725         tlss->write.ofs += ret;
726         tlss->write.left -= ret;
727
728         if (tlss->write.left > 0) {
729                 tlss->write.req = req;
730                 tstream_tls_retry_write(stream);
731                 return;
732         }
733
734         tstream_tls_writev_crypt_next(req);
735 #else /* ENABLE_GNUTLS */
736         tevent_req_error(req, ENOSYS);
737 #endif /* ENABLE_GNUTLS */
738 }
739
740 static int tstream_tls_writev_recv(struct tevent_req *req,
741                                    int *perrno)
742 {
743         struct tstream_tls_writev_state *state =
744                 tevent_req_data(req,
745                 struct tstream_tls_writev_state);
746         struct tstream_tls *tlss =
747                 tstream_context_data(state->stream,
748                 struct tstream_tls);
749         int ret;
750
751         tlss->write.req = NULL;
752
753         ret = tsocket_simple_int_recv(req, perrno);
754         if (ret == 0) {
755                 ret = state->ret;
756         }
757
758         tevent_req_received(req);
759         return ret;
760 }
761
762 struct tstream_tls_disconnect_state {
763         uint8_t _dummy;
764 };
765
766 static struct tevent_req *tstream_tls_disconnect_send(TALLOC_CTX *mem_ctx,
767                                                 struct tevent_context *ev,
768                                                 struct tstream_context *stream)
769 {
770         struct tstream_tls *tlss =
771                 tstream_context_data(stream,
772                 struct tstream_tls);
773         struct tevent_req *req;
774         struct tstream_tls_disconnect_state *state;
775
776         tlss->disconnect.req = NULL;
777         tlss->current_ev = ev;
778
779         req = tevent_req_create(mem_ctx, &state,
780                                 struct tstream_tls_disconnect_state);
781         if (req == NULL) {
782                 return NULL;
783         }
784
785         if (tlss->error != 0) {
786                 tevent_req_error(req, tlss->error);
787                 return tevent_req_post(req, ev);
788         }
789
790         tlss->disconnect.req = req;
791         tstream_tls_retry_disconnect(stream);
792         if (!tevent_req_is_in_progress(req)) {
793                 return tevent_req_post(req, ev);
794         }
795
796         return req;
797 }
798
799 static void tstream_tls_retry_disconnect(struct tstream_context *stream)
800 {
801         struct tstream_tls *tlss =
802                 tstream_context_data(stream,
803                 struct tstream_tls);
804         struct tevent_req *req = tlss->disconnect.req;
805 #if ENABLE_GNUTLS
806         int ret;
807
808         if (tlss->error != 0) {
809                 tevent_req_error(req, tlss->error);
810                 return;
811         }
812
813         ret = gnutls_bye(tlss->tls_session, GNUTLS_SHUT_WR);
814         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
815                 return;
816         }
817
818         tlss->disconnect.req = NULL;
819
820         if (gnutls_error_is_fatal(ret) != 0) {
821                 DEBUG(1,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
822                 tlss->error = EIO;
823                 tevent_req_error(req, tlss->error);
824                 return;
825         }
826
827         if (ret != GNUTLS_E_SUCCESS) {
828                 DEBUG(1,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
829                 tlss->error = EIO;
830                 tevent_req_error(req, tlss->error);
831                 return;
832         }
833
834         tevent_req_done(req);
835 #else /* ENABLE_GNUTLS */
836         tevent_req_error(req, ENOSYS);
837 #endif /* ENABLE_GNUTLS */
838 }
839
/*
 * tstream disconnect_recv hook: returns 0 on success, or -1 with *perrno
 * set.
 */
static int tstream_tls_disconnect_recv(struct tevent_req *req,
                                       int *perrno)
{
        int status = tsocket_simple_int_recv(req, perrno);

        tevent_req_received(req);
        return status;
}
850
/*
 * Operations table connecting the generic tstream API to the TLS
 * implementation in this file.  TLS stream contexts are created with it
 * in _tstream_tls_connect_send().
 */
static const struct tstream_context_ops tstream_tls_ops = {
        .name                   = "tls",

        .pending_bytes          = tstream_tls_pending_bytes,

        .readv_send             = tstream_tls_readv_send,
        .readv_recv             = tstream_tls_readv_recv,

        .writev_send            = tstream_tls_writev_send,
        .writev_recv            = tstream_tls_writev_recv,

        .disconnect_send        = tstream_tls_disconnect_send,
        .disconnect_recv        = tstream_tls_disconnect_recv,
};
865
/* TLS configuration shared by the streams created from it */
struct tstream_tls_params {
#if ENABLE_GNUTLS
        gnutls_certificate_credentials x509_cred;
        gnutls_dh_params dh_params;
#endif /* ENABLE_GNUTLS */
        bool tls_enabled;
};

/*
 * talloc destructor for struct tstream_tls_params: releases the gnutls
 * certificate credentials and DH parameters, if they were allocated.
 * Always returns 0, so the free proceeds.
 */
static int tstream_tls_params_destructor(struct tstream_tls_params *tlsp)
{
#if ENABLE_GNUTLS
        if (tlsp->x509_cred != NULL) {
                gnutls_certificate_free_credentials(tlsp->x509_cred);
        }
        tlsp->x509_cred = NULL;

        if (tlsp->dh_params != NULL) {
                gnutls_dh_params_deinit(tlsp->dh_params);
        }
        tlsp->dh_params = NULL;
#endif /* ENABLE_GNUTLS */
        return 0;
}
888
889 bool tstream_tls_params_enabled(struct tstream_tls_params *tlsp)
890 {
891         return tlsp->tls_enabled;
892 }
893
894 NTSTATUS tstream_tls_params_client(TALLOC_CTX *mem_ctx,
895                                    const char *ca_file,
896                                    const char *crl_file,
897                                    struct tstream_tls_params **_tlsp)
898 {
899 #if ENABLE_GNUTLS
900         struct tstream_tls_params *tlsp;
901         int ret;
902
903         ret = gnutls_global_init();
904         if (ret != GNUTLS_E_SUCCESS) {
905                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
906                 return NT_STATUS_NOT_SUPPORTED;
907         }
908
909         tlsp = talloc_zero(mem_ctx, struct tstream_tls_params);
910         NT_STATUS_HAVE_NO_MEMORY(tlsp);
911
912         talloc_set_destructor(tlsp, tstream_tls_params_destructor);
913
914         ret = gnutls_certificate_allocate_credentials(&tlsp->x509_cred);
915         if (ret != GNUTLS_E_SUCCESS) {
916                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
917                 talloc_free(tlsp);
918                 return NT_STATUS_NO_MEMORY;
919         }
920
921         if (ca_file && *ca_file) {
922                 ret = gnutls_certificate_set_x509_trust_file(tlsp->x509_cred,
923                                                              ca_file,
924                                                              GNUTLS_X509_FMT_PEM);
925                 if (ret < 0) {
926                         DEBUG(0,("TLS failed to initialise cafile %s - %s\n",
927                                  ca_file, gnutls_strerror(ret)));
928                         talloc_free(tlsp);
929                         return NT_STATUS_CANT_ACCESS_DOMAIN_INFO;
930                 }
931         }
932
933         if (crl_file && *crl_file) {
934                 ret = gnutls_certificate_set_x509_crl_file(tlsp->x509_cred,
935                                                            crl_file, 
936                                                            GNUTLS_X509_FMT_PEM);
937                 if (ret < 0) {
938                         DEBUG(0,("TLS failed to initialise crlfile %s - %s\n",
939                                  crl_file, gnutls_strerror(ret)));
940                         talloc_free(tlsp);
941                         return NT_STATUS_CANT_ACCESS_DOMAIN_INFO;
942                 }
943         }
944
945         tlsp->tls_enabled = true;
946
947         *_tlsp = tlsp;
948         return NT_STATUS_OK;
949 #else /* ENABLE_GNUTLS */
950         return NT_STATUS_NOT_IMPLEMENTED;
951 #endif /* ENABLE_GNUTLS */
952 }
953
954 struct tstream_tls_connect_state {
955         struct tstream_context *tls_stream;
956 };
957
958 struct tevent_req *_tstream_tls_connect_send(TALLOC_CTX *mem_ctx,
959                                              struct tevent_context *ev,
960                                              struct tstream_context *plain_stream,
961                                              struct tstream_tls_params *tls_params,
962                                              const char *location)
963 {
964         struct tevent_req *req;
965         struct tstream_tls_connect_state *state;
966 #if ENABLE_GNUTLS
967         struct tstream_tls *tlss;
968         int ret;
969         static const int cert_type_priority[] = {
970                 GNUTLS_CRT_X509,
971                 GNUTLS_CRT_OPENPGP,
972                 0
973         };
974 #endif /* ENABLE_GNUTLS */
975
976         req = tevent_req_create(mem_ctx, &state,
977                                 struct tstream_tls_connect_state);
978         if (req == NULL) {
979                 return NULL;
980         }
981
982 #if ENABLE_GNUTLS
983         state->tls_stream = tstream_context_create(state,
984                                                    &tstream_tls_ops,
985                                                    &tlss,
986                                                    struct tstream_tls,
987                                                    location);
988         if (tevent_req_nomem(state->tls_stream, req)) {
989                 return tevent_req_post(req, ev);
990         }
991         ZERO_STRUCTP(tlss);
992         talloc_set_destructor(tlss, tstream_tls_destructor);
993
994         tlss->plain_stream = plain_stream;
995
996         tlss->current_ev = ev;
997         tlss->retry_im = tevent_create_immediate(tlss);
998         if (tevent_req_nomem(tlss->retry_im, req)) {
999                 return tevent_req_post(req, ev);
1000         }
1001
1002         ret = gnutls_init(&tlss->tls_session, GNUTLS_CLIENT);
1003         if (ret != GNUTLS_E_SUCCESS) {
1004                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1005                 tevent_req_error(req, EINVAL);
1006                 return tevent_req_post(req, ev);
1007         }
1008
1009         ret = gnutls_set_default_priority(tlss->tls_session);
1010         if (ret != GNUTLS_E_SUCCESS) {
1011                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1012                 tevent_req_error(req, EINVAL);
1013                 return tevent_req_post(req, ev);
1014         }
1015
1016         gnutls_certificate_type_set_priority(tlss->tls_session, cert_type_priority);
1017
1018         ret = gnutls_credentials_set(tlss->tls_session,
1019                                      GNUTLS_CRD_CERTIFICATE,
1020                                      tls_params->x509_cred);
1021         if (ret != GNUTLS_E_SUCCESS) {
1022                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1023                 tevent_req_error(req, EINVAL);
1024                 return tevent_req_post(req, ev);
1025         }
1026
1027         gnutls_transport_set_ptr(tlss->tls_session, (gnutls_transport_ptr)state->tls_stream);
1028         gnutls_transport_set_pull_function(tlss->tls_session,
1029                                            (gnutls_pull_func)tstream_tls_pull_function);
1030         gnutls_transport_set_push_function(tlss->tls_session,
1031                                            (gnutls_push_func)tstream_tls_push_function);
1032 #if GNUTLS_VERSION_MAJOR < 3
1033         gnutls_transport_set_lowat(tlss->tls_session, 0);
1034 #endif
1035
1036         tlss->handshake.req = req;
1037         tstream_tls_retry_handshake(state->tls_stream);
1038         if (!tevent_req_is_in_progress(req)) {
1039                 return tevent_req_post(req, ev);
1040         }
1041
1042         return req;
1043 #else /* ENABLE_GNUTLS */
1044         tevent_req_error(req, ENOSYS);
1045         return tevent_req_post(req, ev);
1046 #endif /* ENABLE_GNUTLS */
1047 }
1048
1049 int tstream_tls_connect_recv(struct tevent_req *req,
1050                              int *perrno,
1051                              TALLOC_CTX *mem_ctx,
1052                              struct tstream_context **tls_stream)
1053 {
1054         struct tstream_tls_connect_state *state =
1055                 tevent_req_data(req,
1056                 struct tstream_tls_connect_state);
1057
1058         if (tevent_req_is_unix_error(req, perrno)) {
1059                 tevent_req_received(req);
1060                 return -1;
1061         }
1062
1063         *tls_stream = talloc_move(mem_ctx, &state->tls_stream);
1064         tevent_req_received(req);
1065         return 0;
1066 }
1067
1068 extern void tls_cert_generate(TALLOC_CTX *, const char *, const char *, const char *, const char *);
1069
1070 /*
1071   initialise global tls state
1072 */
1073 NTSTATUS tstream_tls_params_server(TALLOC_CTX *mem_ctx,
1074                                    const char *dns_host_name,
1075                                    bool enabled,
1076                                    const char *key_file,
1077                                    const char *cert_file,
1078                                    const char *ca_file,
1079                                    const char *crl_file,
1080                                    const char *dhp_file,
1081                                    struct tstream_tls_params **_tlsp)
1082 {
1083         struct tstream_tls_params *tlsp;
1084 #if ENABLE_GNUTLS
1085         int ret;
1086
1087         if (!enabled || key_file == NULL || *key_file == 0) {
1088                 tlsp = talloc_zero(mem_ctx, struct tstream_tls_params);
1089                 NT_STATUS_HAVE_NO_MEMORY(tlsp);
1090                 talloc_set_destructor(tlsp, tstream_tls_params_destructor);
1091                 tlsp->tls_enabled = false;
1092
1093                 *_tlsp = tlsp;
1094                 return NT_STATUS_OK;
1095         }
1096
1097         ret = gnutls_global_init();
1098         if (ret != GNUTLS_E_SUCCESS) {
1099                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1100                 return NT_STATUS_NOT_SUPPORTED;
1101         }
1102
1103         tlsp = talloc_zero(mem_ctx, struct tstream_tls_params);
1104         NT_STATUS_HAVE_NO_MEMORY(tlsp);
1105
1106         talloc_set_destructor(tlsp, tstream_tls_params_destructor);
1107
1108         if (!file_exist(ca_file)) {
1109                 tls_cert_generate(tlsp, dns_host_name,
1110                                   key_file, cert_file, ca_file);
1111         }
1112
1113         ret = gnutls_certificate_allocate_credentials(&tlsp->x509_cred);
1114         if (ret != GNUTLS_E_SUCCESS) {
1115                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1116                 talloc_free(tlsp);
1117                 return NT_STATUS_NO_MEMORY;
1118         }
1119
1120         if (ca_file && *ca_file) {
1121                 ret = gnutls_certificate_set_x509_trust_file(tlsp->x509_cred,
1122                                                              ca_file,
1123                                                              GNUTLS_X509_FMT_PEM);
1124                 if (ret < 0) {
1125                         DEBUG(0,("TLS failed to initialise cafile %s - %s\n",
1126                                  ca_file, gnutls_strerror(ret)));
1127                         talloc_free(tlsp);
1128                         return NT_STATUS_CANT_ACCESS_DOMAIN_INFO;
1129                 }
1130         }
1131
1132         if (crl_file && *crl_file) {
1133                 ret = gnutls_certificate_set_x509_crl_file(tlsp->x509_cred,
1134                                                            crl_file, 
1135                                                            GNUTLS_X509_FMT_PEM);
1136                 if (ret < 0) {
1137                         DEBUG(0,("TLS failed to initialise crlfile %s - %s\n",
1138                                  crl_file, gnutls_strerror(ret)));
1139                         talloc_free(tlsp);
1140                         return NT_STATUS_CANT_ACCESS_DOMAIN_INFO;
1141                 }
1142         }
1143
1144         ret = gnutls_certificate_set_x509_key_file(tlsp->x509_cred,
1145                                                    cert_file, key_file,
1146                                                    GNUTLS_X509_FMT_PEM);
1147         if (ret != GNUTLS_E_SUCCESS) {
1148                 DEBUG(0,("TLS failed to initialise certfile %s and keyfile %s - %s\n",
1149                          cert_file, key_file, gnutls_strerror(ret)));
1150                 talloc_free(tlsp);
1151                 return NT_STATUS_CANT_ACCESS_DOMAIN_INFO;
1152         }
1153
1154         ret = gnutls_dh_params_init(&tlsp->dh_params);
1155         if (ret != GNUTLS_E_SUCCESS) {
1156                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1157                 talloc_free(tlsp);
1158                 return NT_STATUS_NO_MEMORY;
1159         }
1160
1161         if (dhp_file && *dhp_file) {
1162                 gnutls_datum_t dhparms;
1163                 size_t size;
1164
1165                 dhparms.data = (uint8_t *)file_load(dhp_file, &size, 0, tlsp);
1166
1167                 if (!dhparms.data) {
1168                         DEBUG(0,("TLS failed to read DH Parms from %s - %d:%s\n",
1169                                  dhp_file, errno, strerror(errno)));
1170                         talloc_free(tlsp);
1171                         return NT_STATUS_CANT_ACCESS_DOMAIN_INFO;
1172                 }
1173                 dhparms.size = size;
1174
1175                 ret = gnutls_dh_params_import_pkcs3(tlsp->dh_params,
1176                                                     &dhparms,
1177                                                     GNUTLS_X509_FMT_PEM);
1178                 if (ret != GNUTLS_E_SUCCESS) {
1179                         DEBUG(0,("TLS failed to import pkcs3 %s - %s\n",
1180                                  dhp_file, gnutls_strerror(ret)));
1181                         talloc_free(tlsp);
1182                         return NT_STATUS_CANT_ACCESS_DOMAIN_INFO;
1183                 }
1184         } else {
1185                 ret = gnutls_dh_params_generate2(tlsp->dh_params, DH_BITS);
1186                 if (ret != GNUTLS_E_SUCCESS) {
1187                         DEBUG(0,("TLS failed to generate dh_params - %s\n",
1188                                  gnutls_strerror(ret)));
1189                         talloc_free(tlsp);
1190                         return NT_STATUS_INTERNAL_ERROR;
1191                 }
1192         }
1193
1194         gnutls_certificate_set_dh_params(tlsp->x509_cred, tlsp->dh_params);
1195
1196         tlsp->tls_enabled = true;
1197
1198 #else /* ENABLE_GNUTLS */
1199         tlsp = talloc_zero(mem_ctx, struct tstream_tls_params);
1200         NT_STATUS_HAVE_NO_MEMORY(tlsp);
1201         talloc_set_destructor(tlsp, tstream_tls_params_destructor);
1202         tlsp->tls_enabled = false;
1203 #endif /* ENABLE_GNUTLS */
1204
1205         *_tlsp = tlsp;
1206         return NT_STATUS_OK;
1207 }
1208
1209 struct tstream_tls_accept_state {
1210         struct tstream_context *tls_stream;
1211 };
1212
1213 struct tevent_req *_tstream_tls_accept_send(TALLOC_CTX *mem_ctx,
1214                                             struct tevent_context *ev,
1215                                             struct tstream_context *plain_stream,
1216                                             struct tstream_tls_params *tlsp,
1217                                             const char *location)
1218 {
1219         struct tevent_req *req;
1220         struct tstream_tls_accept_state *state;
1221         struct tstream_tls *tlss;
1222 #if ENABLE_GNUTLS
1223         int ret;
1224 #endif /* ENABLE_GNUTLS */
1225
1226         req = tevent_req_create(mem_ctx, &state,
1227                                 struct tstream_tls_accept_state);
1228         if (req == NULL) {
1229                 return NULL;
1230         }
1231
1232         state->tls_stream = tstream_context_create(state,
1233                                                    &tstream_tls_ops,
1234                                                    &tlss,
1235                                                    struct tstream_tls,
1236                                                    location);
1237         if (tevent_req_nomem(state->tls_stream, req)) {
1238                 return tevent_req_post(req, ev);
1239         }
1240         ZERO_STRUCTP(tlss);
1241         talloc_set_destructor(tlss, tstream_tls_destructor);
1242
1243 #if ENABLE_GNUTLS
1244         tlss->plain_stream = plain_stream;
1245
1246         tlss->current_ev = ev;
1247         tlss->retry_im = tevent_create_immediate(tlss);
1248         if (tevent_req_nomem(tlss->retry_im, req)) {
1249                 return tevent_req_post(req, ev);
1250         }
1251
1252         ret = gnutls_init(&tlss->tls_session, GNUTLS_SERVER);
1253         if (ret != GNUTLS_E_SUCCESS) {
1254                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1255                 tevent_req_error(req, EINVAL);
1256                 return tevent_req_post(req, ev);
1257         }
1258
1259         ret = gnutls_set_default_priority(tlss->tls_session);
1260         if (ret != GNUTLS_E_SUCCESS) {
1261                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1262                 tevent_req_error(req, EINVAL);
1263                 return tevent_req_post(req, ev);
1264         }
1265
1266         ret = gnutls_credentials_set(tlss->tls_session, GNUTLS_CRD_CERTIFICATE,
1267                                      tlsp->x509_cred);
1268         if (ret != GNUTLS_E_SUCCESS) {
1269                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1270                 tevent_req_error(req, EINVAL);
1271                 return tevent_req_post(req, ev);
1272         }
1273
1274         gnutls_certificate_server_set_request(tlss->tls_session,
1275                                               GNUTLS_CERT_REQUEST);
1276         gnutls_dh_set_prime_bits(tlss->tls_session, DH_BITS);
1277
1278         gnutls_transport_set_ptr(tlss->tls_session, (gnutls_transport_ptr)state->tls_stream);
1279         gnutls_transport_set_pull_function(tlss->tls_session,
1280                                            (gnutls_pull_func)tstream_tls_pull_function);
1281         gnutls_transport_set_push_function(tlss->tls_session,
1282                                            (gnutls_push_func)tstream_tls_push_function);
1283 #if GNUTLS_VERSION_MAJOR < 3
1284         gnutls_transport_set_lowat(tlss->tls_session, 0);
1285 #endif
1286
1287         tlss->handshake.req = req;
1288         tstream_tls_retry_handshake(state->tls_stream);
1289         if (!tevent_req_is_in_progress(req)) {
1290                 return tevent_req_post(req, ev);
1291         }
1292
1293         return req;
1294 #else /* ENABLE_GNUTLS */
1295         tevent_req_error(req, ENOSYS);
1296         return tevent_req_post(req, ev);
1297 #endif /* ENABLE_GNUTLS */
1298 }
1299
1300 static void tstream_tls_retry_handshake(struct tstream_context *stream)
1301 {
1302         struct tstream_tls *tlss =
1303                 tstream_context_data(stream,
1304                 struct tstream_tls);
1305         struct tevent_req *req = tlss->handshake.req;
1306 #if ENABLE_GNUTLS
1307         int ret;
1308
1309         if (tlss->error != 0) {
1310                 tevent_req_error(req, tlss->error);
1311                 return;
1312         }
1313
1314         ret = gnutls_handshake(tlss->tls_session);
1315         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
1316                 return;
1317         }
1318
1319         tlss->handshake.req = NULL;
1320
1321         if (gnutls_error_is_fatal(ret) != 0) {
1322                 DEBUG(1,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1323                 tlss->error = EIO;
1324                 tevent_req_error(req, tlss->error);
1325                 return;
1326         }
1327
1328         if (ret != GNUTLS_E_SUCCESS) {
1329                 DEBUG(1,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1330                 tlss->error = EIO;
1331                 tevent_req_error(req, tlss->error);
1332                 return;
1333         }
1334
1335         tevent_req_done(req);
1336 #else /* ENABLE_GNUTLS */
1337         tevent_req_error(req, ENOSYS);
1338 #endif /* ENABLE_GNUTLS */
1339 }
1340
1341 int tstream_tls_accept_recv(struct tevent_req *req,
1342                             int *perrno,
1343                             TALLOC_CTX *mem_ctx,
1344                             struct tstream_context **tls_stream)
1345 {
1346         struct tstream_tls_accept_state *state =
1347                 tevent_req_data(req,
1348                 struct tstream_tls_accept_state);
1349
1350         if (tevent_req_is_unix_error(req, perrno)) {
1351                 tevent_req_received(req);
1352                 return -1;
1353         }
1354
1355         *tls_stream = talloc_move(mem_ctx, &state->tls_stream);
1356         tevent_req_received(req);
1357         return 0;
1358 }