381d71e64abab6134309943506a1d247a2a19341
[metze/samba/wip.git] / source4 / lib / tls / tls_tstream.c
1 /*
2    Unix SMB/CIFS implementation.
3
4    Copyright (C) Stefan Metzmacher 2010
5
6    This program is free software; you can redistribute it and/or modify
7    it under the terms of the GNU General Public License as published by
8    the Free Software Foundation; either version 3 of the License, or
9    (at your option) any later version.
10
11    This program is distributed in the hope that it will be useful,
12    but WITHOUT ANY WARRANTY; without even the implied warranty of
13    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14    GNU General Public License for more details.
15
16    You should have received a copy of the GNU General Public License
17    along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 */
19
20 #include "includes.h"
21 #include "system/network.h"
22 #include "../util/tevent_unix.h"
23 #include "../lib/tsocket/tsocket.h"
24 #include "../lib/tsocket/tsocket_internal.h"
25 #include "lib/tls/tls.h"
26
27 #if ENABLE_GNUTLS
28 #include "gnutls/gnutls.h"
29
30 #define DH_BITS 1024
31
32 #if defined(HAVE_GNUTLS_DATUM) && !defined(HAVE_GNUTLS_DATUM_T)
33 typedef gnutls_datum gnutls_datum_t;
34 #endif
35
36 #endif /* ENABLE_GNUTLS */
37
38 static const struct tstream_context_ops tstream_tls_ops;
39
40 struct tstream_tls {
41         struct tstream_context *plain_stream;
42         int error;
43
44 #if ENABLE_GNUTLS
45         gnutls_session tls_session;
46 #endif /* ENABLE_GNUTLS */
47
48         struct tevent_context *current_ev;
49
50         struct tevent_immediate *retry_im;
51
52         struct {
53                 uint8_t buffer[4096];
54                 off_t ofs;
55                 struct iovec iov;
56                 struct tevent_req *subreq;
57                 struct tevent_immediate *im;
58         } push;
59
60         struct {
61                 uint8_t buffer[1024];
62                 struct iovec iov;
63                 struct tevent_req *subreq;
64         } pull;
65
66         struct {
67                 struct tevent_req *req;
68         } handshake;
69
70         struct {
71                 off_t ofs;
72                 size_t left;
73                 uint8_t buffer[1024];
74                 struct tevent_req *req;
75         } write;
76
77         struct {
78                 off_t ofs;
79                 size_t left;
80                 uint8_t buffer[1024];
81                 struct tevent_req *req;
82         } read;
83
84         struct {
85                 struct tevent_req *req;
86         } disconnect;
87 };
88
89 static void tstream_tls_retry_handshake(struct tstream_context *stream);
90 static void tstream_tls_retry_read(struct tstream_context *stream);
91 static void tstream_tls_retry_write(struct tstream_context *stream);
92 static void tstream_tls_retry_disconnect(struct tstream_context *stream);
93 static void tstream_tls_retry_trigger(struct tevent_context *ctx,
94                                       struct tevent_immediate *im,
95                                       void *private_data);
96
97 static void tstream_tls_retry(struct tstream_context *stream, bool deferred)
98 {
99
100         struct tstream_tls *tlss =
101                 tstream_context_data(stream,
102                 struct tstream_tls);
103
104         if (tlss->disconnect.req) {
105                 tstream_tls_retry_disconnect(stream);
106                 return;
107         }
108
109         if (tlss->handshake.req) {
110                 tstream_tls_retry_handshake(stream);
111                 return;
112         }
113
114         if (tlss->write.req && tlss->read.req && !deferred) {
115                 tevent_schedule_immediate(tlss->retry_im, tlss->current_ev,
116                                           tstream_tls_retry_trigger,
117                                           stream);
118         }
119
120         if (tlss->write.req) {
121                 tstream_tls_retry_write(stream);
122                 return;
123         }
124
125         if (tlss->read.req) {
126                 tstream_tls_retry_read(stream);
127                 return;
128         }
129 }
130
131 static void tstream_tls_retry_trigger(struct tevent_context *ctx,
132                                       struct tevent_immediate *im,
133                                       void *private_data)
134 {
135         struct tstream_context *stream =
136                 talloc_get_type_abort(private_data,
137                 struct tstream_context);
138
139         tstream_tls_retry(stream, true);
140 }
141
142 #if ENABLE_GNUTLS
143 static void tstream_tls_push_trigger_write(struct tevent_context *ev,
144                                            struct tevent_immediate *im,
145                                            void *private_data);
146
147 static ssize_t tstream_tls_push_function(gnutls_transport_ptr ptr,
148                                          const void *buf, size_t size)
149 {
150         struct tstream_context *stream =
151                 talloc_get_type_abort(ptr,
152                 struct tstream_context);
153         struct tstream_tls *tlss =
154                 tstream_context_data(stream,
155                 struct tstream_tls);
156         size_t len;
157
158         if (tlss->error != 0) {
159                 errno = tlss->error;
160                 return -1;
161         }
162
163         if (tlss->push.subreq) {
164                 errno = EAGAIN;
165                 return -1;
166         }
167
168         if (tlss->push.ofs == sizeof(tlss->push.buffer)) {
169                 errno = EAGAIN;
170                 return -1;
171         }
172
173         len = MIN(size, sizeof(tlss->push.buffer) - tlss->push.ofs);
174         memcpy(tlss->push.buffer + tlss->push.ofs, buf, len);
175
176         if (tlss->push.im == NULL) {
177                 tlss->push.im = tevent_create_immediate(tlss);
178                 if (tlss->push.im == NULL) {
179                         errno = ENOMEM;
180                         return -1;
181                 }
182         }
183
184         if (tlss->push.ofs == 0) {
185                 /*
186                  * We'll do start the tstream_writev
187                  * in the next event cycle.
188                  *
189                  * This way we can batch all push requests,
190                  * if they fit into the buffer.
191                  *
192                  * This is important as gnutls_handshake()
193                  * had a bug in some versions e.g. 2.4.1
194                  * and others (See bug #7218) and it doesn't
195                  * handle EAGAIN.
196                  */
197                 tevent_schedule_immediate(tlss->push.im,
198                                           tlss->current_ev,
199                                           tstream_tls_push_trigger_write,
200                                           stream);
201         }
202
203         tlss->push.ofs += len;
204         return len;
205 }
206
207 static void tstream_tls_push_done(struct tevent_req *subreq);
208
209 static void tstream_tls_push_trigger_write(struct tevent_context *ev,
210                                            struct tevent_immediate *im,
211                                            void *private_data)
212 {
213         struct tstream_context *stream =
214                 talloc_get_type_abort(private_data,
215                 struct tstream_context);
216         struct tstream_tls *tlss =
217                 tstream_context_data(stream,
218                 struct tstream_tls);
219         struct tevent_req *subreq;
220
221         if (tlss->push.subreq) {
222                 /* nothing todo */
223                 return;
224         }
225
226         tlss->push.iov.iov_base = (char *)tlss->push.buffer;
227         tlss->push.iov.iov_len = tlss->push.ofs;
228
229         subreq = tstream_writev_send(tlss,
230                                      tlss->current_ev,
231                                      tlss->plain_stream,
232                                      &tlss->push.iov, 1);
233         if (subreq == NULL) {
234                 tlss->error = ENOMEM;
235                 tstream_tls_retry(stream, false);
236                 return;
237         }
238         tevent_req_set_callback(subreq, tstream_tls_push_done, stream);
239
240         tlss->push.subreq = subreq;
241 }
242
243 static void tstream_tls_push_done(struct tevent_req *subreq)
244 {
245         struct tstream_context *stream =
246                 tevent_req_callback_data(subreq,
247                 struct tstream_context);
248         struct tstream_tls *tlss =
249                 tstream_context_data(stream,
250                 struct tstream_tls);
251         int ret;
252         int sys_errno;
253
254         tlss->push.subreq = NULL;
255         ZERO_STRUCT(tlss->push.iov);
256         tlss->push.ofs = 0;
257
258         ret = tstream_writev_recv(subreq, &sys_errno);
259         TALLOC_FREE(subreq);
260         if (ret == -1) {
261                 tlss->error = sys_errno;
262                 tstream_tls_retry(stream, false);
263                 return;
264         }
265
266         tstream_tls_retry(stream, false);
267 }
268
269 static void tstream_tls_pull_done(struct tevent_req *subreq);
270
271 static ssize_t tstream_tls_pull_function(gnutls_transport_ptr ptr,
272                                          void *buf, size_t size)
273 {
274         struct tstream_context *stream =
275                 talloc_get_type_abort(ptr,
276                 struct tstream_context);
277         struct tstream_tls *tlss =
278                 tstream_context_data(stream,
279                 struct tstream_tls);
280         struct tevent_req *subreq;
281
282         if (tlss->error != 0) {
283                 errno = tlss->error;
284                 return -1;
285         }
286
287         if (tlss->pull.subreq) {
288                 errno = EAGAIN;
289                 return -1;
290         }
291
292         if (tlss->pull.iov.iov_base) {
293                 size_t n;
294
295                 n = MIN(tlss->pull.iov.iov_len, size);
296                 memcpy(buf, tlss->pull.iov.iov_base, n);
297
298                 tlss->pull.iov.iov_len -= n;
299                 if (tlss->pull.iov.iov_len == 0) {
300                         tlss->pull.iov.iov_base = NULL;
301                 }
302
303                 return n;
304         }
305
306         if (size == 0) {
307                 return 0;
308         }
309
310         tlss->pull.iov.iov_base = tlss->pull.buffer;
311         tlss->pull.iov.iov_len = MIN(size, sizeof(tlss->pull.buffer));
312
313         subreq = tstream_readv_send(tlss,
314                                     tlss->current_ev,
315                                     tlss->plain_stream,
316                                     &tlss->pull.iov, 1);
317         if (subreq == NULL) {
318                 errno = ENOMEM;
319                 return -1;
320         }
321         tevent_req_set_callback(subreq, tstream_tls_pull_done, stream);
322
323         tlss->pull.subreq = subreq;
324         errno = EAGAIN;
325         return -1;
326 }
327
328 static void tstream_tls_pull_done(struct tevent_req *subreq)
329 {
330         struct tstream_context *stream =
331                 tevent_req_callback_data(subreq,
332                 struct tstream_context);
333         struct tstream_tls *tlss =
334                 tstream_context_data(stream,
335                 struct tstream_tls);
336         int ret;
337         int sys_errno;
338
339         tlss->pull.subreq = NULL;
340
341         ret = tstream_readv_recv(subreq, &sys_errno);
342         TALLOC_FREE(subreq);
343         if (ret == -1) {
344                 tlss->error = sys_errno;
345                 tstream_tls_retry(stream, false);
346                 return;
347         }
348
349         tstream_tls_retry(stream, false);
350 }
351 #endif /* ENABLE_GNUTLS */
352
/*
 * talloc destructor for struct tstream_tls. It releases the gnutls
 * session when built with GnuTLS, and never vetoes the free.
 */
static int tstream_tls_destructor(struct tstream_tls *tlss)
{
#if ENABLE_GNUTLS
        if (tlss->tls_session == NULL) {
                return 0;
        }
        gnutls_deinit(tlss->tls_session);
        tlss->tls_session = NULL;
#endif /* ENABLE_GNUTLS */
        return 0;
}
363
364 static ssize_t tstream_tls_pending_bytes(struct tstream_context *stream)
365 {
366         struct tstream_tls *tlss =
367                 tstream_context_data(stream,
368                 struct tstream_tls);
369         size_t ret;
370
371         if (tlss->error != 0) {
372                 errno = tlss->error;
373                 return -1;
374         }
375
376 #if ENABLE_GNUTLS
377         ret = gnutls_record_check_pending(tlss->tls_session);
378         ret += tlss->read.left;
379 #else /* ENABLE_GNUTLS */
380         errno = ENOSYS;
381         ret = -1;
382 #endif /* ENABLE_GNUTLS */
383         return ret;
384 }
385
/* Per-request state of tstream_tls_readv_send(). */
struct tstream_tls_readv_state {
        struct tstream_context *stream;         /* the TLS stream being read */

        struct iovec *vector;                   /* private copy, advanced as it fills */
        int count;                              /* iovec entries still to fill */

        int ret;                                /* plaintext bytes delivered so far */
};
394
395 static void tstream_tls_readv_crypt_next(struct tevent_req *req);
396
397 static struct tevent_req *tstream_tls_readv_send(TALLOC_CTX *mem_ctx,
398                                         struct tevent_context *ev,
399                                         struct tstream_context *stream,
400                                         struct iovec *vector,
401                                         size_t count)
402 {
403         struct tstream_tls *tlss =
404                 tstream_context_data(stream,
405                 struct tstream_tls);
406         struct tevent_req *req;
407         struct tstream_tls_readv_state *state;
408
409         tlss->read.req = NULL;
410         tlss->current_ev = ev;
411
412         req = tevent_req_create(mem_ctx, &state,
413                                 struct tstream_tls_readv_state);
414         if (req == NULL) {
415                 return NULL;
416         }
417
418         state->stream = stream;
419         state->ret = 0;
420
421         if (tlss->error != 0) {
422                 tevent_req_error(req, tlss->error);
423                 return tevent_req_post(req, ev);
424         }
425
426         /*
427          * we make a copy of the vector so we can change the structure
428          */
429         state->vector = talloc_array(state, struct iovec, count);
430         if (tevent_req_nomem(state->vector, req)) {
431                 return tevent_req_post(req, ev);
432         }
433         memcpy(state->vector, vector, sizeof(struct iovec) * count);
434         state->count = count;
435
436         tstream_tls_readv_crypt_next(req);
437         if (!tevent_req_is_in_progress(req)) {
438                 return tevent_req_post(req, ev);
439         }
440
441         return req;
442 }
443
444 static void tstream_tls_readv_crypt_next(struct tevent_req *req)
445 {
446         struct tstream_tls_readv_state *state =
447                 tevent_req_data(req,
448                 struct tstream_tls_readv_state);
449         struct tstream_tls *tlss =
450                 tstream_context_data(state->stream,
451                 struct tstream_tls);
452
453         /*
454          * copy the pending buffer first
455          */
456         while (tlss->read.left > 0 && state->count > 0) {
457                 uint8_t *base = (uint8_t *)state->vector[0].iov_base;
458                 size_t len = MIN(tlss->read.left, state->vector[0].iov_len);
459
460                 memcpy(base, tlss->read.buffer + tlss->read.ofs, len);
461
462                 base += len;
463                 state->vector[0].iov_base = base;
464                 state->vector[0].iov_len -= len;
465
466                 tlss->read.ofs += len;
467                 tlss->read.left -= len;
468
469                 if (state->vector[0].iov_len == 0) {
470                         state->vector += 1;
471                         state->count -= 1;
472                 }
473
474                 state->ret += len;
475         }
476
477         if (state->count == 0) {
478                 tevent_req_done(req);
479                 return;
480         }
481
482         tlss->read.req = req;
483         tstream_tls_retry_read(state->stream);
484 }
485
486 static void tstream_tls_retry_read(struct tstream_context *stream)
487 {
488         struct tstream_tls *tlss =
489                 tstream_context_data(stream,
490                 struct tstream_tls);
491         struct tevent_req *req = tlss->read.req;
492 #if ENABLE_GNUTLS
493         int ret;
494
495         if (tlss->error != 0) {
496                 tevent_req_error(req, tlss->error);
497                 return;
498         }
499
500         tlss->read.left = 0;
501         tlss->read.ofs = 0;
502
503         ret = gnutls_record_recv(tlss->tls_session,
504                                  tlss->read.buffer,
505                                  sizeof(tlss->read.buffer));
506         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
507                 return;
508         }
509
510         tlss->read.req = NULL;
511
512         if (gnutls_error_is_fatal(ret) != 0) {
513                 DEBUG(1,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
514                 tlss->error = EIO;
515                 tevent_req_error(req, tlss->error);
516                 return;
517         }
518
519         if (ret == 0) {
520                 tlss->error = EPIPE;
521                 tevent_req_error(req, tlss->error);
522                 return;
523         }
524
525         tlss->read.left = ret;
526         tstream_tls_readv_crypt_next(req);
527 #else /* ENABLE_GNUTLS */
528         tevent_req_error(req, ENOSYS);
529 #endif /* ENABLE_GNUTLS */
530 }
531
532 static int tstream_tls_readv_recv(struct tevent_req *req,
533                                   int *perrno)
534 {
535         struct tstream_tls_readv_state *state =
536                 tevent_req_data(req,
537                 struct tstream_tls_readv_state);
538         struct tstream_tls *tlss =
539                 tstream_context_data(state->stream,
540                 struct tstream_tls);
541         int ret;
542
543         tlss->read.req = NULL;
544
545         ret = tsocket_simple_int_recv(req, perrno);
546         if (ret == 0) {
547                 ret = state->ret;
548         }
549
550         tevent_req_received(req);
551         return ret;
552 }
553
/* Per-request state of tstream_tls_writev_send(). */
struct tstream_tls_writev_state {
        struct tstream_context *stream;         /* the TLS stream being written */

        struct iovec *vector;                   /* private copy, advanced as it drains */
        int count;                              /* iovec entries not yet consumed */

        int ret;                                /* plaintext bytes accepted so far */
};
562
563 static void tstream_tls_writev_crypt_next(struct tevent_req *req);
564
565 static struct tevent_req *tstream_tls_writev_send(TALLOC_CTX *mem_ctx,
566                                         struct tevent_context *ev,
567                                         struct tstream_context *stream,
568                                         const struct iovec *vector,
569                                         size_t count)
570 {
571         struct tstream_tls *tlss =
572                 tstream_context_data(stream,
573                 struct tstream_tls);
574         struct tevent_req *req;
575         struct tstream_tls_writev_state *state;
576
577         tlss->write.req = NULL;
578         tlss->current_ev = ev;
579
580         req = tevent_req_create(mem_ctx, &state,
581                                 struct tstream_tls_writev_state);
582         if (req == NULL) {
583                 return NULL;
584         }
585
586         state->stream = stream;
587         state->ret = 0;
588
589         if (tlss->error != 0) {
590                 tevent_req_error(req, tlss->error);
591                 return tevent_req_post(req, ev);
592         }
593
594         /*
595          * we make a copy of the vector so we can change the structure
596          */
597         state->vector = talloc_array(state, struct iovec, count);
598         if (tevent_req_nomem(state->vector, req)) {
599                 return tevent_req_post(req, ev);
600         }
601         memcpy(state->vector, vector, sizeof(struct iovec) * count);
602         state->count = count;
603
604         tstream_tls_writev_crypt_next(req);
605         if (!tevent_req_is_in_progress(req)) {
606                 return tevent_req_post(req, ev);
607         }
608
609         return req;
610 }
611
612 static void tstream_tls_writev_crypt_next(struct tevent_req *req)
613 {
614         struct tstream_tls_writev_state *state =
615                 tevent_req_data(req,
616                 struct tstream_tls_writev_state);
617         struct tstream_tls *tlss =
618                 tstream_context_data(state->stream,
619                 struct tstream_tls);
620
621         tlss->write.left = sizeof(tlss->write.buffer);
622         tlss->write.ofs = 0;
623
624         /*
625          * first fill our buffer
626          */
627         while (tlss->write.left > 0 && state->count > 0) {
628                 uint8_t *base = (uint8_t *)state->vector[0].iov_base;
629                 size_t len = MIN(tlss->write.left, state->vector[0].iov_len);
630
631                 memcpy(tlss->write.buffer + tlss->write.ofs, base, len);
632
633                 base += len;
634                 state->vector[0].iov_base = base;
635                 state->vector[0].iov_len -= len;
636
637                 tlss->write.ofs += len;
638                 tlss->write.left -= len;
639
640                 if (state->vector[0].iov_len == 0) {
641                         state->vector += 1;
642                         state->count -= 1;
643                 }
644
645                 state->ret += len;
646         }
647
648         if (tlss->write.ofs == 0) {
649                 tevent_req_done(req);
650                 return;
651         }
652
653         tlss->write.left = tlss->write.ofs;
654         tlss->write.ofs = 0;
655
656         tlss->write.req = req;
657         tstream_tls_retry_write(state->stream);
658 }
659
660 static void tstream_tls_retry_write(struct tstream_context *stream)
661 {
662         struct tstream_tls *tlss =
663                 tstream_context_data(stream,
664                 struct tstream_tls);
665         struct tevent_req *req = tlss->write.req;
666 #if ENABLE_GNUTLS
667         int ret;
668
669         if (tlss->error != 0) {
670                 tevent_req_error(req, tlss->error);
671                 return;
672         }
673
674         ret = gnutls_record_send(tlss->tls_session,
675                                  tlss->write.buffer + tlss->write.ofs,
676                                  tlss->write.left);
677         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
678                 return;
679         }
680
681         tlss->write.req = NULL;
682
683         if (gnutls_error_is_fatal(ret) != 0) {
684                 DEBUG(1,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
685                 tlss->error = EIO;
686                 tevent_req_error(req, tlss->error);
687                 return;
688         }
689
690         if (ret == 0) {
691                 tlss->error = EPIPE;
692                 tevent_req_error(req, tlss->error);
693                 return;
694         }
695
696         tlss->write.ofs += ret;
697         tlss->write.left -= ret;
698
699         if (tlss->write.left > 0) {
700                 tlss->write.req = req;
701                 tstream_tls_retry_write(stream);
702                 return;
703         }
704
705         tstream_tls_writev_crypt_next(req);
706 #else /* ENABLE_GNUTLS */
707         tevent_req_error(req, ENOSYS);
708 #endif /* ENABLE_GNUTLS */
709 }
710
711 static int tstream_tls_writev_recv(struct tevent_req *req,
712                                    int *perrno)
713 {
714         struct tstream_tls_writev_state *state =
715                 tevent_req_data(req,
716                 struct tstream_tls_writev_state);
717         struct tstream_tls *tlss =
718                 tstream_context_data(state->stream,
719                 struct tstream_tls);
720         int ret;
721
722         tlss->write.req = NULL;
723
724         ret = tsocket_simple_int_recv(req, perrno);
725         if (ret == 0) {
726                 ret = state->ret;
727         }
728
729         tevent_req_received(req);
730         return ret;
731 }
732
/*
 * tstream_tls_disconnect_send() keeps no per-request data. The single
 * member exists only because ISO C does not allow an empty struct.
 */
struct tstream_tls_disconnect_state {
        uint8_t _dummy;
};
736
737 static struct tevent_req *tstream_tls_disconnect_send(TALLOC_CTX *mem_ctx,
738                                                 struct tevent_context *ev,
739                                                 struct tstream_context *stream)
740 {
741         struct tstream_tls *tlss =
742                 tstream_context_data(stream,
743                 struct tstream_tls);
744         struct tevent_req *req;
745         struct tstream_tls_disconnect_state *state;
746
747         tlss->disconnect.req = NULL;
748         tlss->current_ev = ev;
749
750         req = tevent_req_create(mem_ctx, &state,
751                                 struct tstream_tls_disconnect_state);
752         if (req == NULL) {
753                 return NULL;
754         }
755
756         if (tlss->error != 0) {
757                 tevent_req_error(req, tlss->error);
758                 return tevent_req_post(req, ev);
759         }
760
761         tlss->disconnect.req = req;
762         tstream_tls_retry_disconnect(stream);
763         if (!tevent_req_is_in_progress(req)) {
764                 return tevent_req_post(req, ev);
765         }
766
767         return req;
768 }
769
770 static void tstream_tls_retry_disconnect(struct tstream_context *stream)
771 {
772         struct tstream_tls *tlss =
773                 tstream_context_data(stream,
774                 struct tstream_tls);
775         struct tevent_req *req = tlss->disconnect.req;
776 #if ENABLE_GNUTLS
777         int ret;
778
779         if (tlss->error != 0) {
780                 tevent_req_error(req, tlss->error);
781                 return;
782         }
783
784         ret = gnutls_bye(tlss->tls_session, GNUTLS_SHUT_WR);
785         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
786                 return;
787         }
788
789         tlss->disconnect.req = NULL;
790
791         if (gnutls_error_is_fatal(ret) != 0) {
792                 DEBUG(1,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
793                 tlss->error = EIO;
794                 tevent_req_error(req, tlss->error);
795                 return;
796         }
797
798         if (ret != GNUTLS_E_SUCCESS) {
799                 DEBUG(1,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
800                 tlss->error = EIO;
801                 tevent_req_error(req, tlss->error);
802                 return;
803         }
804
805         tevent_req_done(req);
806 #else /* ENABLE_GNUTLS */
807         tevent_req_error(req, ENOSYS);
808 #endif /* ENABLE_GNUTLS */
809 }
810
/*
 * Collect the result of tstream_tls_disconnect_send(). It returns
 * whatever tsocket_simple_int_recv() reports, with *perrno set on
 * failure.
 */
static int tstream_tls_disconnect_recv(struct tevent_req *req,
                                       int *perrno)
{
        int result = tsocket_simple_int_recv(req, perrno);

        tevent_req_received(req);
        return result;
}
821
822 static const struct tstream_context_ops tstream_tls_ops = {
823         .name                   = "tls",
824
825         .pending_bytes          = tstream_tls_pending_bytes,
826
827         .readv_send             = tstream_tls_readv_send,
828         .readv_recv             = tstream_tls_readv_recv,
829
830         .writev_send            = tstream_tls_writev_send,
831         .writev_recv            = tstream_tls_writev_recv,
832
833         .disconnect_send        = tstream_tls_disconnect_send,
834         .disconnect_recv        = tstream_tls_disconnect_recv,
835 };
836
/*
 * TLS configuration built by tstream_tls_params_client() and passed to
 * _tstream_tls_connect_send().
 */
struct tstream_tls_params {
#if ENABLE_GNUTLS
        gnutls_certificate_credentials x509_cred;       /* CA / CRL trust data */
        gnutls_dh_params dh_params;                     /* freed by the destructor */
#endif /* ENABLE_GNUTLS */
        bool tls_enabled;                               /* true once setup succeeded */
};
844
/*
 * talloc destructor for struct tstream_tls_params. It frees whichever
 * gnutls credentials and DH parameters were allocated, and never vetoes
 * the free.
 */
static int tstream_tls_params_destructor(struct tstream_tls_params *tlsp)
{
#if ENABLE_GNUTLS
        if (tlsp->x509_cred != NULL) {
                gnutls_certificate_free_credentials(tlsp->x509_cred);
                tlsp->x509_cred = NULL;
        }
        if (tlsp->dh_params != NULL) {
                gnutls_dh_params_deinit(tlsp->dh_params);
                tlsp->dh_params = NULL;
        }
#else /* ENABLE_GNUTLS */
        (void)tlsp;
#endif /* ENABLE_GNUTLS */
        return 0;
}
859
860 bool tstream_tls_params_enabled(struct tstream_tls_params *tlsp)
861 {
862         return tlsp->tls_enabled;
863 }
864
865 NTSTATUS tstream_tls_params_client(TALLOC_CTX *mem_ctx,
866                                    const char *ca_file,
867                                    const char *crl_file,
868                                    struct tstream_tls_params **_tlsp)
869 {
870 #if ENABLE_GNUTLS
871         struct tstream_tls_params *tlsp;
872         int ret;
873
874         ret = gnutls_global_init();
875         if (ret != GNUTLS_E_SUCCESS) {
876                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
877                 return NT_STATUS_NOT_SUPPORTED;
878         }
879
880         tlsp = talloc_zero(mem_ctx, struct tstream_tls_params);
881         NT_STATUS_HAVE_NO_MEMORY(tlsp);
882
883         talloc_set_destructor(tlsp, tstream_tls_params_destructor);
884
885         ret = gnutls_certificate_allocate_credentials(&tlsp->x509_cred);
886         if (ret != GNUTLS_E_SUCCESS) {
887                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
888                 talloc_free(tlsp);
889                 return NT_STATUS_NO_MEMORY;
890         }
891
892         if (ca_file && *ca_file) {
893                 ret = gnutls_certificate_set_x509_trust_file(tlsp->x509_cred,
894                                                              ca_file,
895                                                              GNUTLS_X509_FMT_PEM);
896                 if (ret < 0) {
897                         DEBUG(0,("TLS failed to initialise cafile %s - %s\n",
898                                  ca_file, gnutls_strerror(ret)));
899                         talloc_free(tlsp);
900                         return NT_STATUS_CANT_ACCESS_DOMAIN_INFO;
901                 }
902         }
903
904         if (crl_file && *crl_file) {
905                 ret = gnutls_certificate_set_x509_crl_file(tlsp->x509_cred,
906                                                            crl_file, 
907                                                            GNUTLS_X509_FMT_PEM);
908                 if (ret < 0) {
909                         DEBUG(0,("TLS failed to initialise crlfile %s - %s\n",
910                                  crl_file, gnutls_strerror(ret)));
911                         talloc_free(tlsp);
912                         return NT_STATUS_CANT_ACCESS_DOMAIN_INFO;
913                 }
914         }
915
916         tlsp->tls_enabled = true;
917
918         *_tlsp = tlsp;
919         return NT_STATUS_OK;
920 #else /* ENABLE_GNUTLS */
921         return NT_STATUS_NOT_IMPLEMENTED;
922 #endif /* ENABLE_GNUTLS */
923 }
924
/* Per-request state of _tstream_tls_connect_send(). */
struct tstream_tls_connect_state {
        struct tstream_context *tls_stream;     /* the TLS stream being set up */
};
928
929 struct tevent_req *_tstream_tls_connect_send(TALLOC_CTX *mem_ctx,
930                                              struct tevent_context *ev,
931                                              struct tstream_context *plain_stream,
932                                              struct tstream_tls_params *tls_params,
933                                              const char *location)
934 {
935         struct tevent_req *req;
936         struct tstream_tls_connect_state *state;
937 #if ENABLE_GNUTLS
938         struct tstream_tls *tlss;
939         int ret;
940         static const int cert_type_priority[] = {
941                 GNUTLS_CRT_X509,
942                 GNUTLS_CRT_OPENPGP,
943                 0
944         };
945 #endif /* ENABLE_GNUTLS */
946
947         req = tevent_req_create(mem_ctx, &state,
948                                 struct tstream_tls_connect_state);
949         if (req == NULL) {
950                 return NULL;
951         }
952
953 #if ENABLE_GNUTLS
954         state->tls_stream = tstream_context_create(state,
955                                                    &tstream_tls_ops,
956                                                    &tlss,
957                                                    struct tstream_tls,
958                                                    location);
959         if (tevent_req_nomem(state->tls_stream, req)) {
960                 return tevent_req_post(req, ev);
961         }
962         ZERO_STRUCTP(tlss);
963         talloc_set_destructor(tlss, tstream_tls_destructor);
964
965         tlss->plain_stream = plain_stream;
966
967         tlss->current_ev = ev;
968         tlss->retry_im = tevent_create_immediate(tlss);
969         if (tevent_req_nomem(tlss->retry_im, req)) {
970                 return tevent_req_post(req, ev);
971         }
972
973         ret = gnutls_init(&tlss->tls_session, GNUTLS_CLIENT);
974         if (ret != GNUTLS_E_SUCCESS) {
975                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
976                 tevent_req_error(req, EINVAL);
977                 return tevent_req_post(req, ev);
978         }
979
980         ret = gnutls_set_default_priority(tlss->tls_session);
981         if (ret != GNUTLS_E_SUCCESS) {
982                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
983                 tevent_req_error(req, EINVAL);
984                 return tevent_req_post(req, ev);
985         }
986
987         gnutls_certificate_type_set_priority(tlss->tls_session, cert_type_priority);
988
989         ret = gnutls_credentials_set(tlss->tls_session,
990                                      GNUTLS_CRD_CERTIFICATE,
991                                      tls_params->x509_cred);
992         if (ret != GNUTLS_E_SUCCESS) {
993                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
994                 tevent_req_error(req, EINVAL);
995                 return tevent_req_post(req, ev);
996         }
997
998         gnutls_transport_set_ptr(tlss->tls_session, (gnutls_transport_ptr)state->tls_stream);
999         gnutls_transport_set_pull_function(tlss->tls_session,
1000                                            (gnutls_pull_func)tstream_tls_pull_function);
1001         gnutls_transport_set_push_function(tlss->tls_session,
1002                                            (gnutls_push_func)tstream_tls_push_function);
1003         gnutls_transport_set_lowat(tlss->tls_session, 0);
1004
1005         tlss->handshake.req = req;
1006         tstream_tls_retry_handshake(state->tls_stream);
1007         if (!tevent_req_is_in_progress(req)) {
1008                 return tevent_req_post(req, ev);
1009         }
1010
1011         return req;
1012 #else /* ENABLE_GNUTLS */
1013         tevent_req_error(req, ENOSYS);
1014         return tevent_req_post(req, ev);
1015 #endif /* ENABLE_GNUTLS */
1016 }
1017
1018 int tstream_tls_connect_recv(struct tevent_req *req,
1019                              int *perrno,
1020                              TALLOC_CTX *mem_ctx,
1021                              struct tstream_context **tls_stream)
1022 {
1023         struct tstream_tls_connect_state *state =
1024                 tevent_req_data(req,
1025                 struct tstream_tls_connect_state);
1026
1027         if (tevent_req_is_unix_error(req, perrno)) {
1028                 tevent_req_received(req);
1029                 return -1;
1030         }
1031
1032         *tls_stream = talloc_move(mem_ctx, &state->tls_stream);
1033         tevent_req_received(req);
1034         return 0;
1035 }
1036
1037 extern void tls_cert_generate(TALLOC_CTX *, const char *, const char *, const char *, const char *);
1038
1039 /*
1040   initialise global tls state
1041 */
1042 NTSTATUS tstream_tls_params_server(TALLOC_CTX *mem_ctx,
1043                                    const char *dns_host_name,
1044                                    bool enabled,
1045                                    const char *key_file,
1046                                    const char *cert_file,
1047                                    const char *ca_file,
1048                                    const char *crl_file,
1049                                    const char *dhp_file,
1050                                    struct tstream_tls_params **_tlsp)
1051 {
1052         struct tstream_tls_params *tlsp;
1053 #if ENABLE_GNUTLS
1054         int ret;
1055
1056         if (!enabled || key_file == NULL || *key_file == 0) {
1057                 tlsp = talloc_zero(mem_ctx, struct tstream_tls_params);
1058                 NT_STATUS_HAVE_NO_MEMORY(tlsp);
1059                 talloc_set_destructor(tlsp, tstream_tls_params_destructor);
1060                 tlsp->tls_enabled = false;
1061
1062                 *_tlsp = tlsp;
1063                 return NT_STATUS_OK;
1064         }
1065
1066         ret = gnutls_global_init();
1067         if (ret != GNUTLS_E_SUCCESS) {
1068                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1069                 return NT_STATUS_NOT_SUPPORTED;
1070         }
1071
1072         tlsp = talloc_zero(mem_ctx, struct tstream_tls_params);
1073         NT_STATUS_HAVE_NO_MEMORY(tlsp);
1074
1075         talloc_set_destructor(tlsp, tstream_tls_params_destructor);
1076
1077         if (!file_exist(ca_file)) {
1078                 tls_cert_generate(tlsp, dns_host_name,
1079                                   key_file, cert_file, ca_file);
1080         }
1081
1082         ret = gnutls_certificate_allocate_credentials(&tlsp->x509_cred);
1083         if (ret != GNUTLS_E_SUCCESS) {
1084                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1085                 talloc_free(tlsp);
1086                 return NT_STATUS_NO_MEMORY;
1087         }
1088
1089         if (ca_file && *ca_file) {
1090                 ret = gnutls_certificate_set_x509_trust_file(tlsp->x509_cred,
1091                                                              ca_file,
1092                                                              GNUTLS_X509_FMT_PEM);
1093                 if (ret < 0) {
1094                         DEBUG(0,("TLS failed to initialise cafile %s - %s\n",
1095                                  ca_file, gnutls_strerror(ret)));
1096                         talloc_free(tlsp);
1097                         return NT_STATUS_CANT_ACCESS_DOMAIN_INFO;
1098                 }
1099         }
1100
1101         if (crl_file && *crl_file) {
1102                 ret = gnutls_certificate_set_x509_crl_file(tlsp->x509_cred,
1103                                                            crl_file, 
1104                                                            GNUTLS_X509_FMT_PEM);
1105                 if (ret < 0) {
1106                         DEBUG(0,("TLS failed to initialise crlfile %s - %s\n",
1107                                  crl_file, gnutls_strerror(ret)));
1108                         talloc_free(tlsp);
1109                         return NT_STATUS_CANT_ACCESS_DOMAIN_INFO;
1110                 }
1111         }
1112
1113         ret = gnutls_certificate_set_x509_key_file(tlsp->x509_cred,
1114                                                    cert_file, key_file,
1115                                                    GNUTLS_X509_FMT_PEM);
1116         if (ret != GNUTLS_E_SUCCESS) {
1117                 DEBUG(0,("TLS failed to initialise certfile %s and keyfile %s - %s\n",
1118                          cert_file, key_file, gnutls_strerror(ret)));
1119                 talloc_free(tlsp);
1120                 return NT_STATUS_CANT_ACCESS_DOMAIN_INFO;
1121         }
1122
1123         ret = gnutls_dh_params_init(&tlsp->dh_params);
1124         if (ret != GNUTLS_E_SUCCESS) {
1125                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1126                 talloc_free(tlsp);
1127                 return NT_STATUS_NO_MEMORY;
1128         }
1129
1130         if (dhp_file && *dhp_file) {
1131                 gnutls_datum_t dhparms;
1132                 size_t size;
1133
1134                 dhparms.data = (uint8_t *)file_load(dhp_file, &size, 0, tlsp);
1135
1136                 if (!dhparms.data) {
1137                         DEBUG(0,("TLS failed to read DH Parms from %s - %d:%s\n",
1138                                  dhp_file, errno, strerror(errno)));
1139                         talloc_free(tlsp);
1140                         return NT_STATUS_CANT_ACCESS_DOMAIN_INFO;
1141                 }
1142                 dhparms.size = size;
1143
1144                 ret = gnutls_dh_params_import_pkcs3(tlsp->dh_params,
1145                                                     &dhparms,
1146                                                     GNUTLS_X509_FMT_PEM);
1147                 if (ret != GNUTLS_E_SUCCESS) {
1148                         DEBUG(0,("TLS failed to import pkcs3 %s - %s\n",
1149                                  dhp_file, gnutls_strerror(ret)));
1150                         talloc_free(tlsp);
1151                         return NT_STATUS_CANT_ACCESS_DOMAIN_INFO;
1152                 }
1153         } else {
1154                 ret = gnutls_dh_params_generate2(tlsp->dh_params, DH_BITS);
1155                 if (ret != GNUTLS_E_SUCCESS) {
1156                         DEBUG(0,("TLS failed to generate dh_params - %s\n",
1157                                  gnutls_strerror(ret)));
1158                         talloc_free(tlsp);
1159                         return NT_STATUS_INTERNAL_ERROR;
1160                 }
1161         }
1162
1163         gnutls_certificate_set_dh_params(tlsp->x509_cred, tlsp->dh_params);
1164
1165         tlsp->tls_enabled = true;
1166
1167 #else /* ENABLE_GNUTLS */
1168         tlsp = talloc_zero(mem_ctx, struct tstream_tls_params);
1169         NT_STATUS_HAVE_NO_MEMORY(tlsp);
1170         talloc_set_destructor(tlsp, tstream_tls_params_destructor);
1171         tlsp->tls_enabled = false;
1172 #endif /* ENABLE_GNUTLS */
1173
1174         *_tlsp = tlsp;
1175         return NT_STATUS_OK;
1176 }
1177
1178 struct tstream_tls_accept_state {
1179         struct tstream_context *tls_stream;
1180 };
1181
1182 struct tevent_req *_tstream_tls_accept_send(TALLOC_CTX *mem_ctx,
1183                                             struct tevent_context *ev,
1184                                             struct tstream_context *plain_stream,
1185                                             struct tstream_tls_params *tlsp,
1186                                             const char *location)
1187 {
1188         struct tevent_req *req;
1189         struct tstream_tls_accept_state *state;
1190         struct tstream_tls *tlss;
1191 #if ENABLE_GNUTLS
1192         int ret;
1193 #endif /* ENABLE_GNUTLS */
1194
1195         req = tevent_req_create(mem_ctx, &state,
1196                                 struct tstream_tls_accept_state);
1197         if (req == NULL) {
1198                 return NULL;
1199         }
1200
1201         state->tls_stream = tstream_context_create(state,
1202                                                    &tstream_tls_ops,
1203                                                    &tlss,
1204                                                    struct tstream_tls,
1205                                                    location);
1206         if (tevent_req_nomem(state->tls_stream, req)) {
1207                 return tevent_req_post(req, ev);
1208         }
1209         ZERO_STRUCTP(tlss);
1210         talloc_set_destructor(tlss, tstream_tls_destructor);
1211
1212 #if ENABLE_GNUTLS
1213         tlss->plain_stream = plain_stream;
1214
1215         tlss->current_ev = ev;
1216         tlss->retry_im = tevent_create_immediate(tlss);
1217         if (tevent_req_nomem(tlss->retry_im, req)) {
1218                 return tevent_req_post(req, ev);
1219         }
1220
1221         ret = gnutls_init(&tlss->tls_session, GNUTLS_SERVER);
1222         if (ret != GNUTLS_E_SUCCESS) {
1223                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1224                 tevent_req_error(req, EINVAL);
1225                 return tevent_req_post(req, ev);
1226         }
1227
1228         ret = gnutls_set_default_priority(tlss->tls_session);
1229         if (ret != GNUTLS_E_SUCCESS) {
1230                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1231                 tevent_req_error(req, EINVAL);
1232                 return tevent_req_post(req, ev);
1233         }
1234
1235         ret = gnutls_credentials_set(tlss->tls_session, GNUTLS_CRD_CERTIFICATE,
1236                                      tlsp->x509_cred);
1237         if (ret != GNUTLS_E_SUCCESS) {
1238                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1239                 tevent_req_error(req, EINVAL);
1240                 return tevent_req_post(req, ev);
1241         }
1242
1243         gnutls_certificate_server_set_request(tlss->tls_session,
1244                                               GNUTLS_CERT_REQUEST);
1245         gnutls_dh_set_prime_bits(tlss->tls_session, DH_BITS);
1246
1247         gnutls_transport_set_ptr(tlss->tls_session, (gnutls_transport_ptr)state->tls_stream);
1248         gnutls_transport_set_pull_function(tlss->tls_session,
1249                                            (gnutls_pull_func)tstream_tls_pull_function);
1250         gnutls_transport_set_push_function(tlss->tls_session,
1251                                            (gnutls_push_func)tstream_tls_push_function);
1252         gnutls_transport_set_lowat(tlss->tls_session, 0);
1253
1254         tlss->handshake.req = req;
1255         tstream_tls_retry_handshake(state->tls_stream);
1256         if (!tevent_req_is_in_progress(req)) {
1257                 return tevent_req_post(req, ev);
1258         }
1259
1260         return req;
1261 #else /* ENABLE_GNUTLS */
1262         tevent_req_error(req, ENOSYS);
1263         return tevent_req_post(req, ev);
1264 #endif /* ENABLE_GNUTLS */
1265 }
1266
1267 static void tstream_tls_retry_handshake(struct tstream_context *stream)
1268 {
1269         struct tstream_tls *tlss =
1270                 tstream_context_data(stream,
1271                 struct tstream_tls);
1272         struct tevent_req *req = tlss->handshake.req;
1273 #if ENABLE_GNUTLS
1274         int ret;
1275
1276         if (tlss->error != 0) {
1277                 tevent_req_error(req, tlss->error);
1278                 return;
1279         }
1280
1281         ret = gnutls_handshake(tlss->tls_session);
1282         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
1283                 return;
1284         }
1285
1286         tlss->handshake.req = NULL;
1287
1288         if (gnutls_error_is_fatal(ret) != 0) {
1289                 DEBUG(1,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1290                 tlss->error = EIO;
1291                 tevent_req_error(req, tlss->error);
1292                 return;
1293         }
1294
1295         if (ret != GNUTLS_E_SUCCESS) {
1296                 DEBUG(1,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1297                 tlss->error = EIO;
1298                 tevent_req_error(req, tlss->error);
1299                 return;
1300         }
1301
1302         tevent_req_done(req);
1303 #else /* ENABLE_GNUTLS */
1304         tevent_req_error(req, ENOSYS);
1305 #endif /* ENABLE_GNUTLS */
1306 }
1307
1308 int tstream_tls_accept_recv(struct tevent_req *req,
1309                             int *perrno,
1310                             TALLOC_CTX *mem_ctx,
1311                             struct tstream_context **tls_stream)
1312 {
1313         struct tstream_tls_accept_state *state =
1314                 tevent_req_data(req,
1315                 struct tstream_tls_accept_state);
1316
1317         if (tevent_req_is_unix_error(req, perrno)) {
1318                 tevent_req_received(req);
1319                 return -1;
1320         }
1321
1322         *tls_stream = talloc_move(mem_ctx, &state->tls_stream);
1323         tevent_req_received(req);
1324         return 0;
1325 }