09fe5714ebee359731bb1d0c269e23497f60c60f
[metze/samba/wip.git] / source4 / lib / tls / tls_tstream.c
1 /*
2    Unix SMB/CIFS implementation.
3
4    Copyright (C) Stefan Metzmacher 2010
5
6    This program is free software; you can redistribute it and/or modify
7    it under the terms of the GNU General Public License as published by
8    the Free Software Foundation; either version 3 of the License, or
9    (at your option) any later version.
10
11    This program is distributed in the hope that it will be useful,
12    but WITHOUT ANY WARRANTY; without even the implied warranty of
13    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14    GNU General Public License for more details.
15
16    You should have received a copy of the GNU General Public License
17    along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 */
19
20 #include "includes.h"
21 #include "system/network.h"
22 #include "../util/tevent_unix.h"
23 #include "../lib/tsocket/tsocket.h"
24 #include "../lib/tsocket/tsocket_internal.h"
25 #include "lib/tls/tls.h"
26
27 #if ENABLE_GNUTLS
28 #include "gnutls/gnutls.h"
29
30 #define DH_BITS 1024
31
32 #if defined(HAVE_GNUTLS_DATUM) && !defined(HAVE_GNUTLS_DATUM_T)
33 typedef gnutls_datum gnutls_datum_t;
34 #endif
35
36 #endif /* ENABLE_GNUTLS */
37
/*
 * Forward declaration: the ops table is defined further down, after the
 * callbacks it references, but _tstream_tls_connect_send() already needs
 * its address when creating the stream.
 */
static const struct tstream_context_ops tstream_tls_ops;

/*
 * Private state of a TLS tstream layered on top of a plain tstream.
 *
 * The gnutls transport callbacks (push/pull) never wait: they queue I/O
 * on plain_stream, report EAGAIN when data is not available yet, and the
 * completion callbacks call tstream_tls_retry() to resume whichever TLS
 * operation (handshake, read, write, disconnect) is waiting.
 */
struct tstream_tls {
	/* the underlying, unencrypted transport stream */
	struct tstream_context *plain_stream;
	/* sticky errno value: once non-zero, every operation fails with it */
	int error;

#if ENABLE_GNUTLS
	gnutls_session tls_session;
#endif /* ENABLE_GNUTLS */

	/* event context passed to the most recent *_send() entry point */
	struct tevent_context *current_ev;

	/* defers a second retry pass when both a read and a write wait */
	struct tevent_immediate *retry_im;

	/* ciphertext from gnutls, batched before writing it to plain_stream */
	struct {
		uint8_t *buf;
		off_t ofs;			/* bytes queued in buf */
		struct iovec iov;		/* describes buf for tstream_writev */
		struct tevent_req *subreq;	/* in-flight tstream_writev, or NULL */
		struct tevent_immediate *im;	/* starts the batched write */
	} push;

	/* ciphertext read from plain_stream, waiting to be consumed by gnutls */
	struct {
		uint8_t buffer[1024];
		struct iovec iov;		/* unconsumed region of buffer (iov_base != NULL) */
		struct tevent_req *subreq;	/* in-flight tstream_readv, or NULL */
	} pull;

	/* pending handshake request, or NULL */
	struct {
		struct tevent_req *req;
	} handshake;

	/* plaintext chunk: buffer[ofs .. ofs + left) still has to be sent */
	struct {
		off_t ofs;
		size_t left;
		uint8_t buffer[1024];
		struct tevent_req *req;		/* pending writev request, or NULL */
	} write;

	/* decrypted record: buffer[ofs .. ofs + left) not yet returned */
	struct {
		off_t ofs;
		size_t left;
		uint8_t buffer[1024];
		struct tevent_req *req;		/* pending readv request, or NULL */
	} read;

	/* pending disconnect request, or NULL */
	struct {
		struct tevent_req *req;
	} disconnect;
};
88
89 static void tstream_tls_retry_handshake(struct tstream_context *stream);
90 static void tstream_tls_retry_read(struct tstream_context *stream);
91 static void tstream_tls_retry_write(struct tstream_context *stream);
92 static void tstream_tls_retry_disconnect(struct tstream_context *stream);
93 static void tstream_tls_retry_trigger(struct tevent_context *ctx,
94                                       struct tevent_immediate *im,
95                                       void *private_data);
96
97 static void tstream_tls_retry(struct tstream_context *stream, bool deferred)
98 {
99
100         struct tstream_tls *tlss =
101                 tstream_context_data(stream,
102                 struct tstream_tls);
103
104         if (tlss->disconnect.req) {
105                 tstream_tls_retry_disconnect(stream);
106                 return;
107         }
108
109         if (tlss->handshake.req) {
110                 tstream_tls_retry_handshake(stream);
111                 return;
112         }
113
114         if (tlss->write.req && tlss->read.req && !deferred) {
115                 tevent_schedule_immediate(tlss->retry_im, tlss->current_ev,
116                                           tstream_tls_retry_trigger,
117                                           stream);
118         }
119
120         if (tlss->write.req) {
121                 tstream_tls_retry_write(stream);
122                 return;
123         }
124
125         if (tlss->read.req) {
126                 tstream_tls_retry_read(stream);
127                 return;
128         }
129 }
130
131 static void tstream_tls_retry_trigger(struct tevent_context *ctx,
132                                       struct tevent_immediate *im,
133                                       void *private_data)
134 {
135         struct tstream_context *stream =
136                 talloc_get_type_abort(private_data,
137                 struct tstream_context);
138
139         tstream_tls_retry(stream, true);
140 }
141
142 #if ENABLE_GNUTLS
143 static void tstream_tls_push_trigger_write(struct tevent_context *ev,
144                                            struct tevent_immediate *im,
145                                            void *private_data);
146
147 static ssize_t tstream_tls_push_function(gnutls_transport_ptr ptr,
148                                          const void *buf, size_t size)
149 {
150         struct tstream_context *stream =
151                 talloc_get_type_abort(ptr,
152                 struct tstream_context);
153         struct tstream_tls *tlss =
154                 tstream_context_data(stream,
155                 struct tstream_tls);
156         uint8_t *nbuf;
157         size_t len;
158
159         if (tlss->error != 0) {
160                 errno = tlss->error;
161                 return -1;
162         }
163
164         if (tlss->push.subreq) {
165                 errno = EAGAIN;
166                 return -1;
167         }
168
169         len = MIN(size, UINT16_MAX - tlss->push.ofs);
170
171         if (len == 0) {
172                 errno = EAGAIN;
173                 return -1;
174         }
175
176         nbuf = talloc_realloc(tlss, tlss->push.buf,
177                               uint8_t, tlss->push.ofs + len);
178         if (nbuf == NULL) {
179                 if (tlss->push.buf) {
180                         errno = EAGAIN;
181                         return -1;
182                 }
183
184                 return -1;
185         }
186         tlss->push.buf = nbuf;
187
188         memcpy(tlss->push.buf + tlss->push.ofs, buf, len);
189
190         if (tlss->push.im == NULL) {
191                 tlss->push.im = tevent_create_immediate(tlss);
192                 if (tlss->push.im == NULL) {
193                         errno = ENOMEM;
194                         return -1;
195                 }
196         }
197
198         if (tlss->push.ofs == 0) {
199                 /*
200                  * We'll do start the tstream_writev
201                  * in the next event cycle.
202                  *
203                  * This way we can batch all push requests,
204                  * if they fit into a UINT16_MAX buffer.
205                  *
206                  * This is important as gnutls_handshake()
207                  * had a bug in some versions e.g. 2.4.1
208                  * and others (See bug #7218) and it doesn't
209                  * handle EAGAIN.
210                  */
211                 tevent_schedule_immediate(tlss->push.im,
212                                           tlss->current_ev,
213                                           tstream_tls_push_trigger_write,
214                                           stream);
215         }
216
217         tlss->push.ofs += len;
218         return len;
219 }
220
221 static void tstream_tls_push_done(struct tevent_req *subreq);
222
223 static void tstream_tls_push_trigger_write(struct tevent_context *ev,
224                                            struct tevent_immediate *im,
225                                            void *private_data)
226 {
227         struct tstream_context *stream =
228                 talloc_get_type_abort(private_data,
229                 struct tstream_context);
230         struct tstream_tls *tlss =
231                 tstream_context_data(stream,
232                 struct tstream_tls);
233         struct tevent_req *subreq;
234
235         if (tlss->push.subreq) {
236                 /* nothing todo */
237                 return;
238         }
239
240         tlss->push.iov.iov_base = (char *)tlss->push.buf;
241         tlss->push.iov.iov_len = tlss->push.ofs;
242
243         subreq = tstream_writev_send(tlss,
244                                      tlss->current_ev,
245                                      tlss->plain_stream,
246                                      &tlss->push.iov, 1);
247         if (subreq == NULL) {
248                 tlss->error = ENOMEM;
249                 tstream_tls_retry(stream, false);
250                 return;
251         }
252         tevent_req_set_callback(subreq, tstream_tls_push_done, stream);
253
254         tlss->push.subreq = subreq;
255 }
256
257 static void tstream_tls_push_done(struct tevent_req *subreq)
258 {
259         struct tstream_context *stream =
260                 tevent_req_callback_data(subreq,
261                 struct tstream_context);
262         struct tstream_tls *tlss =
263                 tstream_context_data(stream,
264                 struct tstream_tls);
265         int ret;
266         int sys_errno;
267
268         tlss->push.subreq = NULL;
269         ZERO_STRUCT(tlss->push.iov);
270         TALLOC_FREE(tlss->push.buf);
271         tlss->push.ofs = 0;
272
273         ret = tstream_writev_recv(subreq, &sys_errno);
274         TALLOC_FREE(subreq);
275         if (ret == -1) {
276                 tlss->error = sys_errno;
277                 tstream_tls_retry(stream, false);
278                 return;
279         }
280
281         tstream_tls_retry(stream, false);
282 }
283
284 static void tstream_tls_pull_done(struct tevent_req *subreq);
285
286 static ssize_t tstream_tls_pull_function(gnutls_transport_ptr ptr,
287                                          void *buf, size_t size)
288 {
289         struct tstream_context *stream =
290                 talloc_get_type_abort(ptr,
291                 struct tstream_context);
292         struct tstream_tls *tlss =
293                 tstream_context_data(stream,
294                 struct tstream_tls);
295         struct tevent_req *subreq;
296
297         if (tlss->error != 0) {
298                 errno = tlss->error;
299                 return -1;
300         }
301
302         if (tlss->pull.subreq) {
303                 errno = EAGAIN;
304                 return -1;
305         }
306
307         if (tlss->pull.iov.iov_base) {
308                 uint8_t *b;
309                 size_t n;
310
311                 b = (uint8_t *)tlss->pull.iov.iov_base;
312
313                 n = MIN(tlss->pull.iov.iov_len, size);
314                 memcpy(buf, b, n);
315
316                 tlss->pull.iov.iov_len -= n;
317                 b += n;
318                 tlss->pull.iov.iov_base = (char *)b;
319                 if (tlss->pull.iov.iov_len == 0) {
320                         tlss->pull.iov.iov_base = NULL;
321                 }
322
323                 return n;
324         }
325
326         if (size == 0) {
327                 return 0;
328         }
329
330         tlss->pull.iov.iov_base = tlss->pull.buffer;
331         tlss->pull.iov.iov_len = MIN(size, sizeof(tlss->pull.buffer));
332
333         subreq = tstream_readv_send(tlss,
334                                     tlss->current_ev,
335                                     tlss->plain_stream,
336                                     &tlss->pull.iov, 1);
337         if (subreq == NULL) {
338                 errno = ENOMEM;
339                 return -1;
340         }
341         tevent_req_set_callback(subreq, tstream_tls_pull_done, stream);
342
343         tlss->pull.subreq = subreq;
344         errno = EAGAIN;
345         return -1;
346 }
347
348 static void tstream_tls_pull_done(struct tevent_req *subreq)
349 {
350         struct tstream_context *stream =
351                 tevent_req_callback_data(subreq,
352                 struct tstream_context);
353         struct tstream_tls *tlss =
354                 tstream_context_data(stream,
355                 struct tstream_tls);
356         int ret;
357         int sys_errno;
358
359         tlss->pull.subreq = NULL;
360
361         ret = tstream_readv_recv(subreq, &sys_errno);
362         TALLOC_FREE(subreq);
363         if (ret == -1) {
364                 tlss->error = sys_errno;
365                 tstream_tls_retry(stream, false);
366                 return;
367         }
368
369         tstream_tls_retry(stream, false);
370 }
371 #endif /* ENABLE_GNUTLS */
372
/*
 * talloc destructor for struct tstream_tls: releases the gnutls session
 * when GnuTLS support is built in.  Always lets the free proceed.
 */
static int tstream_tls_destructor(struct tstream_tls *tlss)
{
#if ENABLE_GNUTLS
	if (tlss->tls_session == NULL) {
		return 0;
	}
	gnutls_deinit(tlss->tls_session);
	tlss->tls_session = NULL;
#endif /* ENABLE_GNUTLS */
	return 0;
}
383
384 static ssize_t tstream_tls_pending_bytes(struct tstream_context *stream)
385 {
386         struct tstream_tls *tlss =
387                 tstream_context_data(stream,
388                 struct tstream_tls);
389         size_t ret;
390
391         if (tlss->error != 0) {
392                 errno = tlss->error;
393                 return -1;
394         }
395
396 #if ENABLE_GNUTLS
397         ret = gnutls_record_check_pending(tlss->tls_session);
398         ret += tlss->read.left;
399 #else /* ENABLE_GNUTLS */
400         errno = ENOSYS;
401         ret = -1;
402 #endif /* ENABLE_GNUTLS */
403         return ret;
404 }
405
/* Per-request state of tstream_tls_readv_send(). */
struct tstream_tls_readv_state {
	struct tstream_context *stream;

	/* private copy of the caller's vector, advanced as it is filled */
	struct iovec *vector;
	int count;

	/* bytes copied into the caller's buffers so far */
	int ret;
};
414
415 static void tstream_tls_readv_crypt_next(struct tevent_req *req);
416
417 static struct tevent_req *tstream_tls_readv_send(TALLOC_CTX *mem_ctx,
418                                         struct tevent_context *ev,
419                                         struct tstream_context *stream,
420                                         struct iovec *vector,
421                                         size_t count)
422 {
423         struct tstream_tls *tlss =
424                 tstream_context_data(stream,
425                 struct tstream_tls);
426         struct tevent_req *req;
427         struct tstream_tls_readv_state *state;
428
429         tlss->read.req = NULL;
430         tlss->current_ev = ev;
431
432         req = tevent_req_create(mem_ctx, &state,
433                                 struct tstream_tls_readv_state);
434         if (req == NULL) {
435                 return NULL;
436         }
437
438         state->stream = stream;
439         state->ret = 0;
440
441         if (tlss->error != 0) {
442                 tevent_req_error(req, tlss->error);
443                 return tevent_req_post(req, ev);
444         }
445
446         /*
447          * we make a copy of the vector so we can change the structure
448          */
449         state->vector = talloc_array(state, struct iovec, count);
450         if (tevent_req_nomem(state->vector, req)) {
451                 return tevent_req_post(req, ev);
452         }
453         memcpy(state->vector, vector, sizeof(struct iovec) * count);
454         state->count = count;
455
456         tstream_tls_readv_crypt_next(req);
457         if (!tevent_req_is_in_progress(req)) {
458                 return tevent_req_post(req, ev);
459         }
460
461         return req;
462 }
463
464 static void tstream_tls_readv_crypt_next(struct tevent_req *req)
465 {
466         struct tstream_tls_readv_state *state =
467                 tevent_req_data(req,
468                 struct tstream_tls_readv_state);
469         struct tstream_tls *tlss =
470                 tstream_context_data(state->stream,
471                 struct tstream_tls);
472
473         /*
474          * copy the pending buffer first
475          */
476         while (tlss->read.left > 0 && state->count > 0) {
477                 uint8_t *base = (uint8_t *)state->vector[0].iov_base;
478                 size_t len = MIN(tlss->read.left, state->vector[0].iov_len);
479
480                 memcpy(base, tlss->read.buffer + tlss->read.ofs, len);
481
482                 base += len;
483                 state->vector[0].iov_base = (char *) base;
484                 state->vector[0].iov_len -= len;
485
486                 tlss->read.ofs += len;
487                 tlss->read.left -= len;
488
489                 if (state->vector[0].iov_len == 0) {
490                         state->vector += 1;
491                         state->count -= 1;
492                 }
493
494                 state->ret += len;
495         }
496
497         if (state->count == 0) {
498                 tevent_req_done(req);
499                 return;
500         }
501
502         tlss->read.req = req;
503         tstream_tls_retry_read(state->stream);
504 }
505
506 static void tstream_tls_retry_read(struct tstream_context *stream)
507 {
508         struct tstream_tls *tlss =
509                 tstream_context_data(stream,
510                 struct tstream_tls);
511         struct tevent_req *req = tlss->read.req;
512 #if ENABLE_GNUTLS
513         int ret;
514
515         if (tlss->error != 0) {
516                 tevent_req_error(req, tlss->error);
517                 return;
518         }
519
520         tlss->read.left = 0;
521         tlss->read.ofs = 0;
522
523         ret = gnutls_record_recv(tlss->tls_session,
524                                  tlss->read.buffer,
525                                  sizeof(tlss->read.buffer));
526         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
527                 return;
528         }
529
530         tlss->read.req = NULL;
531
532         if (gnutls_error_is_fatal(ret) != 0) {
533                 DEBUG(1,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
534                 tlss->error = EIO;
535                 tevent_req_error(req, tlss->error);
536                 return;
537         }
538
539         if (ret == 0) {
540                 tlss->error = EPIPE;
541                 tevent_req_error(req, tlss->error);
542                 return;
543         }
544
545         tlss->read.left = ret;
546         tstream_tls_readv_crypt_next(req);
547 #else /* ENABLE_GNUTLS */
548         tevent_req_error(req, ENOSYS);
549 #endif /* ENABLE_GNUTLS */
550 }
551
552 static int tstream_tls_readv_recv(struct tevent_req *req,
553                                   int *perrno)
554 {
555         struct tstream_tls_readv_state *state =
556                 tevent_req_data(req,
557                 struct tstream_tls_readv_state);
558         struct tstream_tls *tlss =
559                 tstream_context_data(state->stream,
560                 struct tstream_tls);
561         int ret;
562
563         tlss->read.req = NULL;
564
565         ret = tsocket_simple_int_recv(req, perrno);
566         if (ret == 0) {
567                 ret = state->ret;
568         }
569
570         tevent_req_received(req);
571         return ret;
572 }
573
/* Per-request state of tstream_tls_writev_send(). */
struct tstream_tls_writev_state {
	struct tstream_context *stream;

	/* private copy of the caller's vector, advanced as it is consumed */
	struct iovec *vector;
	int count;

	/* bytes taken from the caller's buffers so far */
	int ret;
};
582
583 static void tstream_tls_writev_crypt_next(struct tevent_req *req);
584
585 static struct tevent_req *tstream_tls_writev_send(TALLOC_CTX *mem_ctx,
586                                         struct tevent_context *ev,
587                                         struct tstream_context *stream,
588                                         const struct iovec *vector,
589                                         size_t count)
590 {
591         struct tstream_tls *tlss =
592                 tstream_context_data(stream,
593                 struct tstream_tls);
594         struct tevent_req *req;
595         struct tstream_tls_writev_state *state;
596
597         tlss->write.req = NULL;
598         tlss->current_ev = ev;
599
600         req = tevent_req_create(mem_ctx, &state,
601                                 struct tstream_tls_writev_state);
602         if (req == NULL) {
603                 return NULL;
604         }
605
606         state->stream = stream;
607         state->ret = 0;
608
609         if (tlss->error != 0) {
610                 tevent_req_error(req, tlss->error);
611                 return tevent_req_post(req, ev);
612         }
613
614         /*
615          * we make a copy of the vector so we can change the structure
616          */
617         state->vector = talloc_array(state, struct iovec, count);
618         if (tevent_req_nomem(state->vector, req)) {
619                 return tevent_req_post(req, ev);
620         }
621         memcpy(state->vector, vector, sizeof(struct iovec) * count);
622         state->count = count;
623
624         tstream_tls_writev_crypt_next(req);
625         if (!tevent_req_is_in_progress(req)) {
626                 return tevent_req_post(req, ev);
627         }
628
629         return req;
630 }
631
632 static void tstream_tls_writev_crypt_next(struct tevent_req *req)
633 {
634         struct tstream_tls_writev_state *state =
635                 tevent_req_data(req,
636                 struct tstream_tls_writev_state);
637         struct tstream_tls *tlss =
638                 tstream_context_data(state->stream,
639                 struct tstream_tls);
640
641         tlss->write.left = sizeof(tlss->write.buffer);
642         tlss->write.ofs = 0;
643
644         /*
645          * first fill our buffer
646          */
647         while (tlss->write.left > 0 && state->count > 0) {
648                 uint8_t *base = (uint8_t *)state->vector[0].iov_base;
649                 size_t len = MIN(tlss->write.left, state->vector[0].iov_len);
650
651                 memcpy(tlss->write.buffer + tlss->write.ofs, base, len);
652
653                 base += len;
654                 state->vector[0].iov_base = (char *) base;
655                 state->vector[0].iov_len -= len;
656
657                 tlss->write.ofs += len;
658                 tlss->write.left -= len;
659
660                 if (state->vector[0].iov_len == 0) {
661                         state->vector += 1;
662                         state->count -= 1;
663                 }
664
665                 state->ret += len;
666         }
667
668         if (tlss->write.ofs == 0) {
669                 tevent_req_done(req);
670                 return;
671         }
672
673         tlss->write.left = tlss->write.ofs;
674         tlss->write.ofs = 0;
675
676         tlss->write.req = req;
677         tstream_tls_retry_write(state->stream);
678 }
679
680 static void tstream_tls_retry_write(struct tstream_context *stream)
681 {
682         struct tstream_tls *tlss =
683                 tstream_context_data(stream,
684                 struct tstream_tls);
685         struct tevent_req *req = tlss->write.req;
686 #if ENABLE_GNUTLS
687         int ret;
688
689         if (tlss->error != 0) {
690                 tevent_req_error(req, tlss->error);
691                 return;
692         }
693
694         ret = gnutls_record_send(tlss->tls_session,
695                                  tlss->write.buffer + tlss->write.ofs,
696                                  tlss->write.left);
697         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
698                 return;
699         }
700
701         tlss->write.req = NULL;
702
703         if (gnutls_error_is_fatal(ret) != 0) {
704                 DEBUG(1,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
705                 tlss->error = EIO;
706                 tevent_req_error(req, tlss->error);
707                 return;
708         }
709
710         if (ret == 0) {
711                 tlss->error = EPIPE;
712                 tevent_req_error(req, tlss->error);
713                 return;
714         }
715
716         tlss->write.ofs += ret;
717         tlss->write.left -= ret;
718
719         if (tlss->write.left > 0) {
720                 tlss->write.req = req;
721                 tstream_tls_retry_write(stream);
722                 return;
723         }
724
725         tstream_tls_writev_crypt_next(req);
726 #else /* ENABLE_GNUTLS */
727         tevent_req_error(req, ENOSYS);
728 #endif /* ENABLE_GNUTLS */
729 }
730
731 static int tstream_tls_writev_recv(struct tevent_req *req,
732                                    int *perrno)
733 {
734         struct tstream_tls_writev_state *state =
735                 tevent_req_data(req,
736                 struct tstream_tls_writev_state);
737         struct tstream_tls *tlss =
738                 tstream_context_data(state->stream,
739                 struct tstream_tls);
740         int ret;
741
742         tlss->write.req = NULL;
743
744         ret = tsocket_simple_int_recv(req, perrno);
745         if (ret == 0) {
746                 ret = state->ret;
747         }
748
749         tevent_req_received(req);
750         return ret;
751 }
752
/*
 * Per-request state of tstream_tls_disconnect_send().  Nothing needs to
 * be stored; the dummy member exists only because C does not allow a
 * struct without members.
 */
struct tstream_tls_disconnect_state {
	uint8_t _dummy;
};
756
757 static struct tevent_req *tstream_tls_disconnect_send(TALLOC_CTX *mem_ctx,
758                                                 struct tevent_context *ev,
759                                                 struct tstream_context *stream)
760 {
761         struct tstream_tls *tlss =
762                 tstream_context_data(stream,
763                 struct tstream_tls);
764         struct tevent_req *req;
765         struct tstream_tls_disconnect_state *state;
766
767         tlss->disconnect.req = NULL;
768         tlss->current_ev = ev;
769
770         req = tevent_req_create(mem_ctx, &state,
771                                 struct tstream_tls_disconnect_state);
772         if (req == NULL) {
773                 return NULL;
774         }
775
776         if (tlss->error != 0) {
777                 tevent_req_error(req, tlss->error);
778                 return tevent_req_post(req, ev);
779         }
780
781         tlss->disconnect.req = req;
782         tstream_tls_retry_disconnect(stream);
783         if (!tevent_req_is_in_progress(req)) {
784                 return tevent_req_post(req, ev);
785         }
786
787         return req;
788 }
789
790 static void tstream_tls_retry_disconnect(struct tstream_context *stream)
791 {
792         struct tstream_tls *tlss =
793                 tstream_context_data(stream,
794                 struct tstream_tls);
795         struct tevent_req *req = tlss->disconnect.req;
796 #if ENABLE_GNUTLS
797         int ret;
798
799         if (tlss->error != 0) {
800                 tevent_req_error(req, tlss->error);
801                 return;
802         }
803
804         ret = gnutls_bye(tlss->tls_session, GNUTLS_SHUT_WR);
805         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
806                 return;
807         }
808
809         tlss->disconnect.req = NULL;
810
811         if (gnutls_error_is_fatal(ret) != 0) {
812                 DEBUG(1,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
813                 tlss->error = EIO;
814                 tevent_req_error(req, tlss->error);
815                 return;
816         }
817
818         if (ret != GNUTLS_E_SUCCESS) {
819                 DEBUG(1,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
820                 tlss->error = EIO;
821                 tevent_req_error(req, tlss->error);
822                 return;
823         }
824
825         tevent_req_done(req);
826 #else /* ENABLE_GNUTLS */
827         tevent_req_error(req, ENOSYS);
828 #endif /* ENABLE_GNUTLS */
829 }
830
/*
 * tstream disconnect_recv hook: returns the result of
 * tsocket_simple_int_recv() and releases the request.
 */
static int tstream_tls_disconnect_recv(struct tevent_req *req,
				       int *perrno)
{
	int rc = tsocket_simple_int_recv(req, perrno);

	tevent_req_received(req);
	return rc;
}
841
842 static const struct tstream_context_ops tstream_tls_ops = {
843         .name                   = "tls",
844
845         .pending_bytes          = tstream_tls_pending_bytes,
846
847         .readv_send             = tstream_tls_readv_send,
848         .readv_recv             = tstream_tls_readv_recv,
849
850         .writev_send            = tstream_tls_writev_send,
851         .writev_recv            = tstream_tls_writev_recv,
852
853         .disconnect_send        = tstream_tls_disconnect_send,
854         .disconnect_recv        = tstream_tls_disconnect_recv,
855 };
856
/*
 * TLS configuration (gnutls credentials) passed to
 * _tstream_tls_connect_send(); created by tstream_tls_params_client().
 */
struct tstream_tls_params {
#if ENABLE_GNUTLS
	gnutls_certificate_credentials x509_cred;
	/* freed by the destructor; not set by tstream_tls_params_client()
	 * (presumably used by a server-side constructor - confirm) */
	gnutls_dh_params dh_params;
#endif /* ENABLE_GNUTLS */
	/* set to true by tstream_tls_params_client() on success */
	bool tls_enabled;
};
864
/*
 * talloc destructor for struct tstream_tls_params: releases whichever
 * gnutls objects were allocated.  Never vetoes the free.
 */
static int tstream_tls_params_destructor(struct tstream_tls_params *tlsp)
{
#if ENABLE_GNUTLS
	if (tlsp->x509_cred != NULL) {
		gnutls_certificate_free_credentials(tlsp->x509_cred);
		tlsp->x509_cred = NULL;
	}

	if (tlsp->dh_params != NULL) {
		gnutls_dh_params_deinit(tlsp->dh_params);
		tlsp->dh_params = NULL;
	}
#endif /* ENABLE_GNUTLS */

	return 0;
}
879
880 bool tstream_tls_params_enabled(struct tstream_tls_params *tlsp)
881 {
882         return tlsp->tls_enabled;
883 }
884
885 NTSTATUS tstream_tls_params_client(TALLOC_CTX *mem_ctx,
886                                    const char *ca_file,
887                                    const char *crl_file,
888                                    struct tstream_tls_params **_tlsp)
889 {
890 #if ENABLE_GNUTLS
891         struct tstream_tls_params *tlsp;
892         int ret;
893
894         ret = gnutls_global_init();
895         if (ret != GNUTLS_E_SUCCESS) {
896                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
897                 return NT_STATUS_NOT_SUPPORTED;
898         }
899
900         tlsp = talloc_zero(mem_ctx, struct tstream_tls_params);
901         NT_STATUS_HAVE_NO_MEMORY(tlsp);
902
903         talloc_set_destructor(tlsp, tstream_tls_params_destructor);
904
905         ret = gnutls_certificate_allocate_credentials(&tlsp->x509_cred);
906         if (ret != GNUTLS_E_SUCCESS) {
907                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
908                 talloc_free(tlsp);
909                 return NT_STATUS_NO_MEMORY;
910         }
911
912         if (ca_file && *ca_file) {
913                 ret = gnutls_certificate_set_x509_trust_file(tlsp->x509_cred,
914                                                              ca_file,
915                                                              GNUTLS_X509_FMT_PEM);
916                 if (ret < 0) {
917                         DEBUG(0,("TLS failed to initialise cafile %s - %s\n",
918                                  ca_file, gnutls_strerror(ret)));
919                         talloc_free(tlsp);
920                         return NT_STATUS_CANT_ACCESS_DOMAIN_INFO;
921                 }
922         }
923
924         if (crl_file && *crl_file) {
925                 ret = gnutls_certificate_set_x509_crl_file(tlsp->x509_cred,
926                                                            crl_file, 
927                                                            GNUTLS_X509_FMT_PEM);
928                 if (ret < 0) {
929                         DEBUG(0,("TLS failed to initialise crlfile %s - %s\n",
930                                  crl_file, gnutls_strerror(ret)));
931                         talloc_free(tlsp);
932                         return NT_STATUS_CANT_ACCESS_DOMAIN_INFO;
933                 }
934         }
935
936         tlsp->tls_enabled = true;
937
938         *_tlsp = tlsp;
939         return NT_STATUS_OK;
940 #else /* ENABLE_GNUTLS */
941         return NT_STATUS_NOT_IMPLEMENTED;
942 #endif /* ENABLE_GNUTLS */
943 }
944
/* Per-request state of _tstream_tls_connect_send(). */
struct tstream_tls_connect_state {
	/* the TLS stream being set up; allocated as a talloc child of this state */
	struct tstream_context *tls_stream;
};
948
949 struct tevent_req *_tstream_tls_connect_send(TALLOC_CTX *mem_ctx,
950                                              struct tevent_context *ev,
951                                              struct tstream_context *plain_stream,
952                                              struct tstream_tls_params *tls_params,
953                                              const char *location)
954 {
955         struct tevent_req *req;
956         struct tstream_tls_connect_state *state;
957 #if ENABLE_GNUTLS
958         struct tstream_tls *tlss;
959         int ret;
960         static const int cert_type_priority[] = {
961                 GNUTLS_CRT_X509,
962                 GNUTLS_CRT_OPENPGP,
963                 0
964         };
965 #endif /* ENABLE_GNUTLS */
966
967         req = tevent_req_create(mem_ctx, &state,
968                                 struct tstream_tls_connect_state);
969         if (req == NULL) {
970                 return NULL;
971         }
972
973 #if ENABLE_GNUTLS
974         state->tls_stream = tstream_context_create(state,
975                                                    &tstream_tls_ops,
976                                                    &tlss,
977                                                    struct tstream_tls,
978                                                    location);
979         if (tevent_req_nomem(state->tls_stream, req)) {
980                 return tevent_req_post(req, ev);
981         }
982         ZERO_STRUCTP(tlss);
983         talloc_set_destructor(tlss, tstream_tls_destructor);
984
985         tlss->plain_stream = plain_stream;
986
987         tlss->current_ev = ev;
988         tlss->retry_im = tevent_create_immediate(tlss);
989         if (tevent_req_nomem(tlss->retry_im, req)) {
990                 return tevent_req_post(req, ev);
991         }
992
993         ret = gnutls_init(&tlss->tls_session, GNUTLS_CLIENT);
994         if (ret != GNUTLS_E_SUCCESS) {
995                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
996                 tevent_req_error(req, EINVAL);
997                 return tevent_req_post(req, ev);
998         }
999
1000         ret = gnutls_set_default_priority(tlss->tls_session);
1001         if (ret != GNUTLS_E_SUCCESS) {
1002                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1003                 tevent_req_error(req, EINVAL);
1004                 return tevent_req_post(req, ev);
1005         }
1006
1007         gnutls_certificate_type_set_priority(tlss->tls_session, cert_type_priority);
1008
1009         ret = gnutls_credentials_set(tlss->tls_session,
1010                                      GNUTLS_CRD_CERTIFICATE,
1011                                      tls_params->x509_cred);
1012         if (ret != GNUTLS_E_SUCCESS) {
1013                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1014                 tevent_req_error(req, EINVAL);
1015                 return tevent_req_post(req, ev);
1016         }
1017
1018         gnutls_transport_set_ptr(tlss->tls_session, (gnutls_transport_ptr)state->tls_stream);
1019         gnutls_transport_set_pull_function(tlss->tls_session,
1020                                            (gnutls_pull_func)tstream_tls_pull_function);
1021         gnutls_transport_set_push_function(tlss->tls_session,
1022                                            (gnutls_push_func)tstream_tls_push_function);
1023         gnutls_transport_set_lowat(tlss->tls_session, 0);
1024
1025         tlss->handshake.req = req;
1026         tstream_tls_retry_handshake(state->tls_stream);
1027         if (!tevent_req_is_in_progress(req)) {
1028                 return tevent_req_post(req, ev);
1029         }
1030
1031         return req;
1032 #else /* ENABLE_GNUTLS */
1033         tevent_req_error(req, ENOSYS);
1034         return tevent_req_post(req, ev);
1035 #endif /* ENABLE_GNUTLS */
1036 }
1037
1038 int tstream_tls_connect_recv(struct tevent_req *req,
1039                              int *perrno,
1040                              TALLOC_CTX *mem_ctx,
1041                              struct tstream_context **tls_stream)
1042 {
1043         struct tstream_tls_connect_state *state =
1044                 tevent_req_data(req,
1045                 struct tstream_tls_connect_state);
1046
1047         if (tevent_req_is_unix_error(req, perrno)) {
1048                 tevent_req_received(req);
1049                 return -1;
1050         }
1051
1052         *tls_stream = talloc_move(mem_ctx, &state->tls_stream);
1053         tevent_req_received(req);
1054         return 0;
1055 }
1056
1057 extern void tls_cert_generate(TALLOC_CTX *, const char *, const char *, const char *, const char *);
1058
1059 /*
1060   initialise global tls state
1061 */
1062 NTSTATUS tstream_tls_params_server(TALLOC_CTX *mem_ctx,
1063                                    const char *dns_host_name,
1064                                    bool enabled,
1065                                    const char *key_file,
1066                                    const char *cert_file,
1067                                    const char *ca_file,
1068                                    const char *crl_file,
1069                                    const char *dhp_file,
1070                                    struct tstream_tls_params **_tlsp)
1071 {
1072         struct tstream_tls_params *tlsp;
1073 #if ENABLE_GNUTLS
1074         int ret;
1075
1076         if (!enabled || key_file == NULL || *key_file == 0) {
1077                 tlsp = talloc_zero(mem_ctx, struct tstream_tls_params);
1078                 NT_STATUS_HAVE_NO_MEMORY(tlsp);
1079                 talloc_set_destructor(tlsp, tstream_tls_params_destructor);
1080                 tlsp->tls_enabled = false;
1081
1082                 *_tlsp = tlsp;
1083                 return NT_STATUS_OK;
1084         }
1085
1086         ret = gnutls_global_init();
1087         if (ret != GNUTLS_E_SUCCESS) {
1088                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1089                 return NT_STATUS_NOT_SUPPORTED;
1090         }
1091
1092         tlsp = talloc_zero(mem_ctx, struct tstream_tls_params);
1093         NT_STATUS_HAVE_NO_MEMORY(tlsp);
1094
1095         talloc_set_destructor(tlsp, tstream_tls_params_destructor);
1096
1097         if (!file_exist(ca_file)) {
1098                 tls_cert_generate(tlsp, dns_host_name,
1099                                   key_file, cert_file, ca_file);
1100         }
1101
1102         ret = gnutls_certificate_allocate_credentials(&tlsp->x509_cred);
1103         if (ret != GNUTLS_E_SUCCESS) {
1104                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1105                 talloc_free(tlsp);
1106                 return NT_STATUS_NO_MEMORY;
1107         }
1108
1109         if (ca_file && *ca_file) {
1110                 ret = gnutls_certificate_set_x509_trust_file(tlsp->x509_cred,
1111                                                              ca_file,
1112                                                              GNUTLS_X509_FMT_PEM);
1113                 if (ret < 0) {
1114                         DEBUG(0,("TLS failed to initialise cafile %s - %s\n",
1115                                  ca_file, gnutls_strerror(ret)));
1116                         talloc_free(tlsp);
1117                         return NT_STATUS_CANT_ACCESS_DOMAIN_INFO;
1118                 }
1119         }
1120
1121         if (crl_file && *crl_file) {
1122                 ret = gnutls_certificate_set_x509_crl_file(tlsp->x509_cred,
1123                                                            crl_file, 
1124                                                            GNUTLS_X509_FMT_PEM);
1125                 if (ret < 0) {
1126                         DEBUG(0,("TLS failed to initialise crlfile %s - %s\n",
1127                                  crl_file, gnutls_strerror(ret)));
1128                         talloc_free(tlsp);
1129                         return NT_STATUS_CANT_ACCESS_DOMAIN_INFO;
1130                 }
1131         }
1132
1133         ret = gnutls_certificate_set_x509_key_file(tlsp->x509_cred,
1134                                                    cert_file, key_file,
1135                                                    GNUTLS_X509_FMT_PEM);
1136         if (ret != GNUTLS_E_SUCCESS) {
1137                 DEBUG(0,("TLS failed to initialise certfile %s and keyfile %s - %s\n",
1138                          cert_file, key_file, gnutls_strerror(ret)));
1139                 talloc_free(tlsp);
1140                 return NT_STATUS_CANT_ACCESS_DOMAIN_INFO;
1141         }
1142
1143         ret = gnutls_dh_params_init(&tlsp->dh_params);
1144         if (ret != GNUTLS_E_SUCCESS) {
1145                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1146                 talloc_free(tlsp);
1147                 return NT_STATUS_NO_MEMORY;
1148         }
1149
1150         if (dhp_file && *dhp_file) {
1151                 gnutls_datum_t dhparms;
1152                 size_t size;
1153
1154                 dhparms.data = (uint8_t *)file_load(dhp_file, &size, 0, tlsp);
1155
1156                 if (!dhparms.data) {
1157                         DEBUG(0,("TLS failed to read DH Parms from %s - %d:%s\n",
1158                                  dhp_file, errno, strerror(errno)));
1159                         talloc_free(tlsp);
1160                         return NT_STATUS_CANT_ACCESS_DOMAIN_INFO;
1161                 }
1162                 dhparms.size = size;
1163
1164                 ret = gnutls_dh_params_import_pkcs3(tlsp->dh_params,
1165                                                     &dhparms,
1166                                                     GNUTLS_X509_FMT_PEM);
1167                 if (ret != GNUTLS_E_SUCCESS) {
1168                         DEBUG(0,("TLS failed to import pkcs3 %s - %s\n",
1169                                  dhp_file, gnutls_strerror(ret)));
1170                         talloc_free(tlsp);
1171                         return NT_STATUS_CANT_ACCESS_DOMAIN_INFO;
1172                 }
1173         } else {
1174                 ret = gnutls_dh_params_generate2(tlsp->dh_params, DH_BITS);
1175                 if (ret != GNUTLS_E_SUCCESS) {
1176                         DEBUG(0,("TLS failed to generate dh_params - %s\n",
1177                                  gnutls_strerror(ret)));
1178                         talloc_free(tlsp);
1179                         return NT_STATUS_INTERNAL_ERROR;
1180                 }
1181         }
1182
1183         gnutls_certificate_set_dh_params(tlsp->x509_cred, tlsp->dh_params);
1184
1185         tlsp->tls_enabled = true;
1186
1187 #else /* ENABLE_GNUTLS */
1188         tlsp = talloc_zero(mem_ctx, struct tstream_tls_params);
1189         NT_STATUS_HAVE_NO_MEMORY(tlsp);
1190         talloc_set_destructor(tlsp, tstream_tls_params_destructor);
1191         tlsp->tls_enabled = false;
1192 #endif /* ENABLE_GNUTLS */
1193
1194         *_tlsp = tlsp;
1195         return NT_STATUS_OK;
1196 }
1197
1198 struct tstream_tls_accept_state {
1199         struct tstream_context *tls_stream;
1200 };
1201
1202 struct tevent_req *_tstream_tls_accept_send(TALLOC_CTX *mem_ctx,
1203                                             struct tevent_context *ev,
1204                                             struct tstream_context *plain_stream,
1205                                             struct tstream_tls_params *tlsp,
1206                                             const char *location)
1207 {
1208         struct tevent_req *req;
1209         struct tstream_tls_accept_state *state;
1210         struct tstream_tls *tlss;
1211 #if ENABLE_GNUTLS
1212         int ret;
1213 #endif /* ENABLE_GNUTLS */
1214
1215         req = tevent_req_create(mem_ctx, &state,
1216                                 struct tstream_tls_accept_state);
1217         if (req == NULL) {
1218                 return NULL;
1219         }
1220
1221         state->tls_stream = tstream_context_create(state,
1222                                                    &tstream_tls_ops,
1223                                                    &tlss,
1224                                                    struct tstream_tls,
1225                                                    location);
1226         if (tevent_req_nomem(state->tls_stream, req)) {
1227                 return tevent_req_post(req, ev);
1228         }
1229         ZERO_STRUCTP(tlss);
1230         talloc_set_destructor(tlss, tstream_tls_destructor);
1231
1232 #if ENABLE_GNUTLS
1233         tlss->plain_stream = plain_stream;
1234
1235         tlss->current_ev = ev;
1236         tlss->retry_im = tevent_create_immediate(tlss);
1237         if (tevent_req_nomem(tlss->retry_im, req)) {
1238                 return tevent_req_post(req, ev);
1239         }
1240
1241         ret = gnutls_init(&tlss->tls_session, GNUTLS_SERVER);
1242         if (ret != GNUTLS_E_SUCCESS) {
1243                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1244                 tevent_req_error(req, EINVAL);
1245                 return tevent_req_post(req, ev);
1246         }
1247
1248         ret = gnutls_set_default_priority(tlss->tls_session);
1249         if (ret != GNUTLS_E_SUCCESS) {
1250                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1251                 tevent_req_error(req, EINVAL);
1252                 return tevent_req_post(req, ev);
1253         }
1254
1255         ret = gnutls_credentials_set(tlss->tls_session, GNUTLS_CRD_CERTIFICATE,
1256                                      tlsp->x509_cred);
1257         if (ret != GNUTLS_E_SUCCESS) {
1258                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1259                 tevent_req_error(req, EINVAL);
1260                 return tevent_req_post(req, ev);
1261         }
1262
1263         gnutls_certificate_server_set_request(tlss->tls_session,
1264                                               GNUTLS_CERT_REQUEST);
1265         gnutls_dh_set_prime_bits(tlss->tls_session, DH_BITS);
1266
1267         gnutls_transport_set_ptr(tlss->tls_session, (gnutls_transport_ptr)state->tls_stream);
1268         gnutls_transport_set_pull_function(tlss->tls_session,
1269                                            (gnutls_pull_func)tstream_tls_pull_function);
1270         gnutls_transport_set_push_function(tlss->tls_session,
1271                                            (gnutls_push_func)tstream_tls_push_function);
1272         gnutls_transport_set_lowat(tlss->tls_session, 0);
1273
1274         tlss->handshake.req = req;
1275         tstream_tls_retry_handshake(state->tls_stream);
1276         if (!tevent_req_is_in_progress(req)) {
1277                 return tevent_req_post(req, ev);
1278         }
1279
1280         return req;
1281 #else /* ENABLE_GNUTLS */
1282         tevent_req_error(req, ENOSYS);
1283         return tevent_req_post(req, ev);
1284 #endif /* ENABLE_GNUTLS */
1285 }
1286
1287 static void tstream_tls_retry_handshake(struct tstream_context *stream)
1288 {
1289         struct tstream_tls *tlss =
1290                 tstream_context_data(stream,
1291                 struct tstream_tls);
1292         struct tevent_req *req = tlss->handshake.req;
1293 #if ENABLE_GNUTLS
1294         int ret;
1295
1296         if (tlss->error != 0) {
1297                 tevent_req_error(req, tlss->error);
1298                 return;
1299         }
1300
1301         ret = gnutls_handshake(tlss->tls_session);
1302         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
1303                 return;
1304         }
1305
1306         tlss->handshake.req = NULL;
1307
1308         if (gnutls_error_is_fatal(ret) != 0) {
1309                 DEBUG(1,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1310                 tlss->error = EIO;
1311                 tevent_req_error(req, tlss->error);
1312                 return;
1313         }
1314
1315         if (ret != GNUTLS_E_SUCCESS) {
1316                 DEBUG(1,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1317                 tlss->error = EIO;
1318                 tevent_req_error(req, tlss->error);
1319                 return;
1320         }
1321
1322         tevent_req_done(req);
1323 #else /* ENABLE_GNUTLS */
1324         tevent_req_error(req, ENOSYS);
1325 #endif /* ENABLE_GNUTLS */
1326 }
1327
1328 int tstream_tls_accept_recv(struct tevent_req *req,
1329                             int *perrno,
1330                             TALLOC_CTX *mem_ctx,
1331                             struct tstream_context **tls_stream)
1332 {
1333         struct tstream_tls_accept_state *state =
1334                 tevent_req_data(req,
1335                 struct tstream_tls_accept_state);
1336
1337         if (tevent_req_is_unix_error(req, perrno)) {
1338                 tevent_req_received(req);
1339                 return -1;
1340         }
1341
1342         *tls_stream = talloc_move(mem_ctx, &state->tls_stream);
1343         tevent_req_received(req);
1344         return 0;
1345 }