tls_tstream: use a dynamic buffer for the push case
[metze/samba/wip.git] / source4 / lib / tls / tls_tstream.c
1 /*
2    Unix SMB/CIFS implementation.
3
4    Copyright (C) Stefan Metzmacher 2010
5
6    This program is free software; you can redistribute it and/or modify
7    it under the terms of the GNU General Public License as published by
8    the Free Software Foundation; either version 3 of the License, or
9    (at your option) any later version.
10
11    This program is distributed in the hope that it will be useful,
12    but WITHOUT ANY WARRANTY; without even the implied warranty of
13    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14    GNU General Public License for more details.
15
16    You should have received a copy of the GNU General Public License
17    along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 */
19
20 #include "includes.h"
21 #include "system/network.h"
22 #include "../util/tevent_unix.h"
23 #include "../lib/tsocket/tsocket.h"
24 #include "../lib/tsocket/tsocket_internal.h"
25 #include "lib/tls/tls.h"
26
27 #if ENABLE_GNUTLS
28 #include "gnutls/gnutls.h"
29
30 #define DH_BITS 1024
31
32 #if defined(HAVE_GNUTLS_DATUM) && !defined(HAVE_GNUTLS_DATUM_T)
33 typedef gnutls_datum gnutls_datum_t;
34 #endif
35
36 #endif /* ENABLE_GNUTLS */
37
38 static const struct tstream_context_ops tstream_tls_ops;
39
/*
 * Per-stream state of a TLS tstream layered on top of a plain tstream.
 *
 * Every offset and length below indexes an in-memory buffer, so they are
 * size_t (unsigned, the natural type for buffer sizes) rather than off_t
 * (which is signed and meant for file offsets).
 */
struct tstream_tls {
	/* the underlying (unencrypted) transport stream */
	struct tstream_context *plain_stream;
	/* sticky errno value: once non-zero every operation fails with it */
	int error;

#if ENABLE_GNUTLS
	gnutls_session tls_session;
#endif /* ENABLE_GNUTLS */

	/* event context of the most recent *_send request on this stream */
	struct tevent_context *current_ev;

	/* used by tstream_tls_retry() to schedule a deferred retry pass */
	struct tevent_immediate *retry_im;

	/* outgoing ciphertext queued by the gnutls push callback */
	struct {
		uint8_t *buf;			/* talloc'ed, grown on demand */
		size_t ofs;			/* bytes currently queued in buf */
		struct iovec iov;		/* describes buf for tstream_writev */
		struct tevent_req *subreq;	/* in-flight tstream_writev */
		struct tevent_immediate *im;	/* triggers the batched write */
	} push;

	/* incoming ciphertext read from plain_stream for the pull callback */
	struct {
		uint8_t buffer[1024];
		struct iovec iov;		/* not yet consumed part of buffer */
		struct tevent_req *subreq;	/* in-flight tstream_readv */
	} pull;

	struct {
		struct tevent_req *req;
	} handshake;

	/* plaintext staged for gnutls_record_send() */
	struct {
		size_t ofs;
		size_t left;
		uint8_t buffer[1024];
		struct tevent_req *req;
	} write;

	/* plaintext returned by gnutls_record_recv() not yet handed out */
	struct {
		size_t ofs;
		size_t left;
		uint8_t buffer[1024];
		struct tevent_req *req;
	} read;

	struct {
		struct tevent_req *req;
	} disconnect;
};

static void tstream_tls_retry_handshake(struct tstream_context *stream);
static void tstream_tls_retry_read(struct tstream_context *stream);
static void tstream_tls_retry_write(struct tstream_context *stream);
static void tstream_tls_retry_disconnect(struct tstream_context *stream);
static void tstream_tls_retry_trigger(struct tevent_context *ctx,
				      struct tevent_immediate *im,
				      void *private_data);
96
97 static void tstream_tls_retry(struct tstream_context *stream, bool deferred)
98 {
99
100         struct tstream_tls *tlss =
101                 tstream_context_data(stream,
102                 struct tstream_tls);
103
104         if (tlss->disconnect.req) {
105                 tstream_tls_retry_disconnect(stream);
106                 return;
107         }
108
109         if (tlss->handshake.req) {
110                 tstream_tls_retry_handshake(stream);
111                 return;
112         }
113
114         if (tlss->write.req && tlss->read.req && !deferred) {
115                 tevent_schedule_immediate(tlss->retry_im, tlss->current_ev,
116                                           tstream_tls_retry_trigger,
117                                           stream);
118         }
119
120         if (tlss->write.req) {
121                 tstream_tls_retry_write(stream);
122                 return;
123         }
124
125         if (tlss->read.req) {
126                 tstream_tls_retry_read(stream);
127                 return;
128         }
129 }
130
131 static void tstream_tls_retry_trigger(struct tevent_context *ctx,
132                                       struct tevent_immediate *im,
133                                       void *private_data)
134 {
135         struct tstream_context *stream =
136                 talloc_get_type_abort(private_data,
137                 struct tstream_context);
138
139         tstream_tls_retry(stream, true);
140 }
141
142 #if ENABLE_GNUTLS
143 static void tstream_tls_push_trigger_write(struct tevent_context *ev,
144                                            struct tevent_immediate *im,
145                                            void *private_data);
146
147 static ssize_t tstream_tls_push_function(gnutls_transport_ptr ptr,
148                                          const void *buf, size_t size)
149 {
150         struct tstream_context *stream =
151                 talloc_get_type_abort(ptr,
152                 struct tstream_context);
153         struct tstream_tls *tlss =
154                 tstream_context_data(stream,
155                 struct tstream_tls);
156         uint8_t *nbuf;
157         size_t len;
158
159         if (tlss->error != 0) {
160                 errno = tlss->error;
161                 return -1;
162         }
163
164         if (tlss->push.subreq) {
165                 errno = EAGAIN;
166                 return -1;
167         }
168
169         len = MIN(size, UINT16_MAX - tlss->push.ofs);
170
171         if (len == 0) {
172                 errno = EAGAIN;
173                 return -1;
174         }
175
176         nbuf = talloc_realloc(tlss, tlss->push.buf,
177                               uint8_t, tlss->push.ofs + len);
178         if (nbuf == NULL) {
179                 if (tlss->push.buf) {
180                         errno = EAGAIN;
181                         return -1;
182                 }
183
184                 return -1;
185         }
186         tlss->push.buf = nbuf;
187
188         memcpy(tlss->push.buf + tlss->push.ofs, buf, len);
189
190         if (tlss->push.im == NULL) {
191                 tlss->push.im = tevent_create_immediate(tlss);
192                 if (tlss->push.im == NULL) {
193                         errno = ENOMEM;
194                         return -1;
195                 }
196         }
197
198         if (tlss->push.ofs == 0) {
199                 /*
200                  * We'll do start the tstream_writev
201                  * in the next event cycle.
202                  *
203                  * This way we can batch all push requests,
204                  * if they fit into a UINT16_MAX buffer.
205                  *
206                  * This is important as gnutls_handshake()
207                  * had a bug in some versions e.g. 2.4.1
208                  * and others (See bug #7218) and it doesn't
209                  * handle EAGAIN.
210                  */
211                 tevent_schedule_immediate(tlss->push.im,
212                                           tlss->current_ev,
213                                           tstream_tls_push_trigger_write,
214                                           stream);
215         }
216
217         tlss->push.ofs += len;
218         return len;
219 }
220
221 static void tstream_tls_push_done(struct tevent_req *subreq);
222
223 static void tstream_tls_push_trigger_write(struct tevent_context *ev,
224                                            struct tevent_immediate *im,
225                                            void *private_data)
226 {
227         struct tstream_context *stream =
228                 talloc_get_type_abort(private_data,
229                 struct tstream_context);
230         struct tstream_tls *tlss =
231                 tstream_context_data(stream,
232                 struct tstream_tls);
233         struct tevent_req *subreq;
234
235         if (tlss->push.subreq) {
236                 /* nothing todo */
237                 return;
238         }
239
240         tlss->push.iov.iov_base = (char *)tlss->push.buf;
241         tlss->push.iov.iov_len = tlss->push.ofs;
242
243         subreq = tstream_writev_send(tlss,
244                                      tlss->current_ev,
245                                      tlss->plain_stream,
246                                      &tlss->push.iov, 1);
247         if (subreq == NULL) {
248                 tlss->error = ENOMEM;
249                 tstream_tls_retry(stream, false);
250                 return;
251         }
252         tevent_req_set_callback(subreq, tstream_tls_push_done, stream);
253
254         tlss->push.subreq = subreq;
255 }
256
257 static void tstream_tls_push_done(struct tevent_req *subreq)
258 {
259         struct tstream_context *stream =
260                 tevent_req_callback_data(subreq,
261                 struct tstream_context);
262         struct tstream_tls *tlss =
263                 tstream_context_data(stream,
264                 struct tstream_tls);
265         int ret;
266         int sys_errno;
267
268         tlss->push.subreq = NULL;
269         ZERO_STRUCT(tlss->push.iov);
270         TALLOC_FREE(tlss->push.buf);
271         tlss->push.ofs = 0;
272
273         ret = tstream_writev_recv(subreq, &sys_errno);
274         TALLOC_FREE(subreq);
275         if (ret == -1) {
276                 tlss->error = sys_errno;
277                 tstream_tls_retry(stream, false);
278                 return;
279         }
280
281         tstream_tls_retry(stream, false);
282 }
283
284 static void tstream_tls_pull_done(struct tevent_req *subreq);
285
286 static ssize_t tstream_tls_pull_function(gnutls_transport_ptr ptr,
287                                          void *buf, size_t size)
288 {
289         struct tstream_context *stream =
290                 talloc_get_type_abort(ptr,
291                 struct tstream_context);
292         struct tstream_tls *tlss =
293                 tstream_context_data(stream,
294                 struct tstream_tls);
295         struct tevent_req *subreq;
296
297         if (tlss->error != 0) {
298                 errno = tlss->error;
299                 return -1;
300         }
301
302         if (tlss->pull.subreq) {
303                 errno = EAGAIN;
304                 return -1;
305         }
306
307         if (tlss->pull.iov.iov_base) {
308                 size_t n;
309
310                 n = MIN(tlss->pull.iov.iov_len, size);
311                 memcpy(buf, tlss->pull.iov.iov_base, n);
312
313                 tlss->pull.iov.iov_len -= n;
314                 if (tlss->pull.iov.iov_len == 0) {
315                         tlss->pull.iov.iov_base = NULL;
316                 }
317
318                 return n;
319         }
320
321         if (size == 0) {
322                 return 0;
323         }
324
325         tlss->pull.iov.iov_base = tlss->pull.buffer;
326         tlss->pull.iov.iov_len = MIN(size, sizeof(tlss->pull.buffer));
327
328         subreq = tstream_readv_send(tlss,
329                                     tlss->current_ev,
330                                     tlss->plain_stream,
331                                     &tlss->pull.iov, 1);
332         if (subreq == NULL) {
333                 errno = ENOMEM;
334                 return -1;
335         }
336         tevent_req_set_callback(subreq, tstream_tls_pull_done, stream);
337
338         tlss->pull.subreq = subreq;
339         errno = EAGAIN;
340         return -1;
341 }
342
343 static void tstream_tls_pull_done(struct tevent_req *subreq)
344 {
345         struct tstream_context *stream =
346                 tevent_req_callback_data(subreq,
347                 struct tstream_context);
348         struct tstream_tls *tlss =
349                 tstream_context_data(stream,
350                 struct tstream_tls);
351         int ret;
352         int sys_errno;
353
354         tlss->pull.subreq = NULL;
355
356         ret = tstream_readv_recv(subreq, &sys_errno);
357         TALLOC_FREE(subreq);
358         if (ret == -1) {
359                 tlss->error = sys_errno;
360                 tstream_tls_retry(stream, false);
361                 return;
362         }
363
364         tstream_tls_retry(stream, false);
365 }
366 #endif /* ENABLE_GNUTLS */
367
/*
 * talloc destructor for the TLS stream state.
 *
 * Releases the gnutls session, if one was created, and always lets the
 * free proceed.
 */
static int tstream_tls_destructor(struct tstream_tls *tlss)
{
#if ENABLE_GNUTLS
	if (tlss->tls_session == NULL) {
		return 0;
	}

	gnutls_deinit(tlss->tls_session);
	tlss->tls_session = NULL;
#endif /* ENABLE_GNUTLS */
	return 0;
}
378
379 static ssize_t tstream_tls_pending_bytes(struct tstream_context *stream)
380 {
381         struct tstream_tls *tlss =
382                 tstream_context_data(stream,
383                 struct tstream_tls);
384         size_t ret;
385
386         if (tlss->error != 0) {
387                 errno = tlss->error;
388                 return -1;
389         }
390
391 #if ENABLE_GNUTLS
392         ret = gnutls_record_check_pending(tlss->tls_session);
393         ret += tlss->read.left;
394 #else /* ENABLE_GNUTLS */
395         errno = ENOSYS;
396         ret = -1;
397 #endif /* ENABLE_GNUTLS */
398         return ret;
399 }
400
/* State of one tstream_tls_readv_send() request. */
struct tstream_tls_readv_state {
	struct tstream_context *stream;

	/* private copy of the caller's iovecs, advanced as they are filled */
	struct iovec *vector;
	/* iovecs still to fill; size_t like the count passed to readv_send */
	size_t count;

	/* total number of bytes copied into the caller's buffers */
	int ret;
};

static void tstream_tls_readv_crypt_next(struct tevent_req *req);
411
412 static struct tevent_req *tstream_tls_readv_send(TALLOC_CTX *mem_ctx,
413                                         struct tevent_context *ev,
414                                         struct tstream_context *stream,
415                                         struct iovec *vector,
416                                         size_t count)
417 {
418         struct tstream_tls *tlss =
419                 tstream_context_data(stream,
420                 struct tstream_tls);
421         struct tevent_req *req;
422         struct tstream_tls_readv_state *state;
423
424         tlss->read.req = NULL;
425         tlss->current_ev = ev;
426
427         req = tevent_req_create(mem_ctx, &state,
428                                 struct tstream_tls_readv_state);
429         if (req == NULL) {
430                 return NULL;
431         }
432
433         state->stream = stream;
434         state->ret = 0;
435
436         if (tlss->error != 0) {
437                 tevent_req_error(req, tlss->error);
438                 return tevent_req_post(req, ev);
439         }
440
441         /*
442          * we make a copy of the vector so we can change the structure
443          */
444         state->vector = talloc_array(state, struct iovec, count);
445         if (tevent_req_nomem(state->vector, req)) {
446                 return tevent_req_post(req, ev);
447         }
448         memcpy(state->vector, vector, sizeof(struct iovec) * count);
449         state->count = count;
450
451         tstream_tls_readv_crypt_next(req);
452         if (!tevent_req_is_in_progress(req)) {
453                 return tevent_req_post(req, ev);
454         }
455
456         return req;
457 }
458
459 static void tstream_tls_readv_crypt_next(struct tevent_req *req)
460 {
461         struct tstream_tls_readv_state *state =
462                 tevent_req_data(req,
463                 struct tstream_tls_readv_state);
464         struct tstream_tls *tlss =
465                 tstream_context_data(state->stream,
466                 struct tstream_tls);
467
468         /*
469          * copy the pending buffer first
470          */
471         while (tlss->read.left > 0 && state->count > 0) {
472                 uint8_t *base = (uint8_t *)state->vector[0].iov_base;
473                 size_t len = MIN(tlss->read.left, state->vector[0].iov_len);
474
475                 memcpy(base, tlss->read.buffer + tlss->read.ofs, len);
476
477                 base += len;
478                 state->vector[0].iov_base = base;
479                 state->vector[0].iov_len -= len;
480
481                 tlss->read.ofs += len;
482                 tlss->read.left -= len;
483
484                 if (state->vector[0].iov_len == 0) {
485                         state->vector += 1;
486                         state->count -= 1;
487                 }
488
489                 state->ret += len;
490         }
491
492         if (state->count == 0) {
493                 tevent_req_done(req);
494                 return;
495         }
496
497         tlss->read.req = req;
498         tstream_tls_retry_read(state->stream);
499 }
500
501 static void tstream_tls_retry_read(struct tstream_context *stream)
502 {
503         struct tstream_tls *tlss =
504                 tstream_context_data(stream,
505                 struct tstream_tls);
506         struct tevent_req *req = tlss->read.req;
507 #if ENABLE_GNUTLS
508         int ret;
509
510         if (tlss->error != 0) {
511                 tevent_req_error(req, tlss->error);
512                 return;
513         }
514
515         tlss->read.left = 0;
516         tlss->read.ofs = 0;
517
518         ret = gnutls_record_recv(tlss->tls_session,
519                                  tlss->read.buffer,
520                                  sizeof(tlss->read.buffer));
521         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
522                 return;
523         }
524
525         tlss->read.req = NULL;
526
527         if (gnutls_error_is_fatal(ret) != 0) {
528                 DEBUG(1,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
529                 tlss->error = EIO;
530                 tevent_req_error(req, tlss->error);
531                 return;
532         }
533
534         if (ret == 0) {
535                 tlss->error = EPIPE;
536                 tevent_req_error(req, tlss->error);
537                 return;
538         }
539
540         tlss->read.left = ret;
541         tstream_tls_readv_crypt_next(req);
542 #else /* ENABLE_GNUTLS */
543         tevent_req_error(req, ENOSYS);
544 #endif /* ENABLE_GNUTLS */
545 }
546
547 static int tstream_tls_readv_recv(struct tevent_req *req,
548                                   int *perrno)
549 {
550         struct tstream_tls_readv_state *state =
551                 tevent_req_data(req,
552                 struct tstream_tls_readv_state);
553         struct tstream_tls *tlss =
554                 tstream_context_data(state->stream,
555                 struct tstream_tls);
556         int ret;
557
558         tlss->read.req = NULL;
559
560         ret = tsocket_simple_int_recv(req, perrno);
561         if (ret == 0) {
562                 ret = state->ret;
563         }
564
565         tevent_req_received(req);
566         return ret;
567 }
568
/* State of one tstream_tls_writev_send() request. */
struct tstream_tls_writev_state {
	struct tstream_context *stream;

	/* private copy of the caller's iovecs, advanced as they are consumed */
	struct iovec *vector;
	/* iovecs still to send; size_t like the count passed to writev_send */
	size_t count;

	/* total number of bytes taken from the caller's buffers */
	int ret;
};

static void tstream_tls_writev_crypt_next(struct tevent_req *req);
579
580 static struct tevent_req *tstream_tls_writev_send(TALLOC_CTX *mem_ctx,
581                                         struct tevent_context *ev,
582                                         struct tstream_context *stream,
583                                         const struct iovec *vector,
584                                         size_t count)
585 {
586         struct tstream_tls *tlss =
587                 tstream_context_data(stream,
588                 struct tstream_tls);
589         struct tevent_req *req;
590         struct tstream_tls_writev_state *state;
591
592         tlss->write.req = NULL;
593         tlss->current_ev = ev;
594
595         req = tevent_req_create(mem_ctx, &state,
596                                 struct tstream_tls_writev_state);
597         if (req == NULL) {
598                 return NULL;
599         }
600
601         state->stream = stream;
602         state->ret = 0;
603
604         if (tlss->error != 0) {
605                 tevent_req_error(req, tlss->error);
606                 return tevent_req_post(req, ev);
607         }
608
609         /*
610          * we make a copy of the vector so we can change the structure
611          */
612         state->vector = talloc_array(state, struct iovec, count);
613         if (tevent_req_nomem(state->vector, req)) {
614                 return tevent_req_post(req, ev);
615         }
616         memcpy(state->vector, vector, sizeof(struct iovec) * count);
617         state->count = count;
618
619         tstream_tls_writev_crypt_next(req);
620         if (!tevent_req_is_in_progress(req)) {
621                 return tevent_req_post(req, ev);
622         }
623
624         return req;
625 }
626
627 static void tstream_tls_writev_crypt_next(struct tevent_req *req)
628 {
629         struct tstream_tls_writev_state *state =
630                 tevent_req_data(req,
631                 struct tstream_tls_writev_state);
632         struct tstream_tls *tlss =
633                 tstream_context_data(state->stream,
634                 struct tstream_tls);
635
636         tlss->write.left = sizeof(tlss->write.buffer);
637         tlss->write.ofs = 0;
638
639         /*
640          * first fill our buffer
641          */
642         while (tlss->write.left > 0 && state->count > 0) {
643                 uint8_t *base = (uint8_t *)state->vector[0].iov_base;
644                 size_t len = MIN(tlss->write.left, state->vector[0].iov_len);
645
646                 memcpy(tlss->write.buffer + tlss->write.ofs, base, len);
647
648                 base += len;
649                 state->vector[0].iov_base = base;
650                 state->vector[0].iov_len -= len;
651
652                 tlss->write.ofs += len;
653                 tlss->write.left -= len;
654
655                 if (state->vector[0].iov_len == 0) {
656                         state->vector += 1;
657                         state->count -= 1;
658                 }
659
660                 state->ret += len;
661         }
662
663         if (tlss->write.ofs == 0) {
664                 tevent_req_done(req);
665                 return;
666         }
667
668         tlss->write.left = tlss->write.ofs;
669         tlss->write.ofs = 0;
670
671         tlss->write.req = req;
672         tstream_tls_retry_write(state->stream);
673 }
674
675 static void tstream_tls_retry_write(struct tstream_context *stream)
676 {
677         struct tstream_tls *tlss =
678                 tstream_context_data(stream,
679                 struct tstream_tls);
680         struct tevent_req *req = tlss->write.req;
681 #if ENABLE_GNUTLS
682         int ret;
683
684         if (tlss->error != 0) {
685                 tevent_req_error(req, tlss->error);
686                 return;
687         }
688
689         ret = gnutls_record_send(tlss->tls_session,
690                                  tlss->write.buffer + tlss->write.ofs,
691                                  tlss->write.left);
692         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
693                 return;
694         }
695
696         tlss->write.req = NULL;
697
698         if (gnutls_error_is_fatal(ret) != 0) {
699                 DEBUG(1,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
700                 tlss->error = EIO;
701                 tevent_req_error(req, tlss->error);
702                 return;
703         }
704
705         if (ret == 0) {
706                 tlss->error = EPIPE;
707                 tevent_req_error(req, tlss->error);
708                 return;
709         }
710
711         tlss->write.ofs += ret;
712         tlss->write.left -= ret;
713
714         if (tlss->write.left > 0) {
715                 tlss->write.req = req;
716                 tstream_tls_retry_write(stream);
717                 return;
718         }
719
720         tstream_tls_writev_crypt_next(req);
721 #else /* ENABLE_GNUTLS */
722         tevent_req_error(req, ENOSYS);
723 #endif /* ENABLE_GNUTLS */
724 }
725
726 static int tstream_tls_writev_recv(struct tevent_req *req,
727                                    int *perrno)
728 {
729         struct tstream_tls_writev_state *state =
730                 tevent_req_data(req,
731                 struct tstream_tls_writev_state);
732         struct tstream_tls *tlss =
733                 tstream_context_data(state->stream,
734                 struct tstream_tls);
735         int ret;
736
737         tlss->write.req = NULL;
738
739         ret = tsocket_simple_int_recv(req, perrno);
740         if (ret == 0) {
741                 ret = state->ret;
742         }
743
744         tevent_req_received(req);
745         return ret;
746 }
747
748 struct tstream_tls_disconnect_state {
749         uint8_t _dummy;
750 };
751
752 static struct tevent_req *tstream_tls_disconnect_send(TALLOC_CTX *mem_ctx,
753                                                 struct tevent_context *ev,
754                                                 struct tstream_context *stream)
755 {
756         struct tstream_tls *tlss =
757                 tstream_context_data(stream,
758                 struct tstream_tls);
759         struct tevent_req *req;
760         struct tstream_tls_disconnect_state *state;
761
762         tlss->disconnect.req = NULL;
763         tlss->current_ev = ev;
764
765         req = tevent_req_create(mem_ctx, &state,
766                                 struct tstream_tls_disconnect_state);
767         if (req == NULL) {
768                 return NULL;
769         }
770
771         if (tlss->error != 0) {
772                 tevent_req_error(req, tlss->error);
773                 return tevent_req_post(req, ev);
774         }
775
776         tlss->disconnect.req = req;
777         tstream_tls_retry_disconnect(stream);
778         if (!tevent_req_is_in_progress(req)) {
779                 return tevent_req_post(req, ev);
780         }
781
782         return req;
783 }
784
785 static void tstream_tls_retry_disconnect(struct tstream_context *stream)
786 {
787         struct tstream_tls *tlss =
788                 tstream_context_data(stream,
789                 struct tstream_tls);
790         struct tevent_req *req = tlss->disconnect.req;
791 #if ENABLE_GNUTLS
792         int ret;
793
794         if (tlss->error != 0) {
795                 tevent_req_error(req, tlss->error);
796                 return;
797         }
798
799         ret = gnutls_bye(tlss->tls_session, GNUTLS_SHUT_WR);
800         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
801                 return;
802         }
803
804         tlss->disconnect.req = NULL;
805
806         if (gnutls_error_is_fatal(ret) != 0) {
807                 DEBUG(1,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
808                 tlss->error = EIO;
809                 tevent_req_error(req, tlss->error);
810                 return;
811         }
812
813         if (ret != GNUTLS_E_SUCCESS) {
814                 DEBUG(1,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
815                 tlss->error = EIO;
816                 tevent_req_error(req, tlss->error);
817                 return;
818         }
819
820         tevent_req_done(req);
821 #else /* ENABLE_GNUTLS */
822         tevent_req_error(req, ENOSYS);
823 #endif /* ENABLE_GNUTLS */
824 }
825
/*
 * tstream disconnect_recv hook: returns 0 on success or -1 with *perrno
 * set, and releases the request.
 */
static int tstream_tls_disconnect_recv(struct tevent_req *req,
				       int *perrno)
{
	int result = tsocket_simple_int_recv(req, perrno);

	tevent_req_received(req);
	return result;
}
836
837 static const struct tstream_context_ops tstream_tls_ops = {
838         .name                   = "tls",
839
840         .pending_bytes          = tstream_tls_pending_bytes,
841
842         .readv_send             = tstream_tls_readv_send,
843         .readv_recv             = tstream_tls_readv_recv,
844
845         .writev_send            = tstream_tls_writev_send,
846         .writev_recv            = tstream_tls_writev_recv,
847
848         .disconnect_send        = tstream_tls_disconnect_send,
849         .disconnect_recv        = tstream_tls_disconnect_recv,
850 };
851
/* TLS configuration shared by the TLS streams created from it. */
struct tstream_tls_params {
#if ENABLE_GNUTLS
	gnutls_certificate_credentials x509_cred;
	gnutls_dh_params dh_params;
#endif /* ENABLE_GNUTLS */
	bool tls_enabled;
};

/*
 * talloc destructor: releases the gnutls objects owned by the params and
 * always lets the free proceed.
 */
static int tstream_tls_params_destructor(struct tstream_tls_params *tlsp)
{
#if ENABLE_GNUTLS
	if (tlsp->x509_cred != NULL) {
		gnutls_certificate_free_credentials(tlsp->x509_cred);
		tlsp->x509_cred = NULL;
	}

	if (tlsp->dh_params != NULL) {
		gnutls_dh_params_deinit(tlsp->dh_params);
		tlsp->dh_params = NULL;
	}
#endif /* ENABLE_GNUTLS */
	return 0;
}

/* Report the tls_enabled flag of the given params. */
bool tstream_tls_params_enabled(struct tstream_tls_params *tlsp)
{
	bool enabled = tlsp->tls_enabled;

	return enabled;
}
879
880 NTSTATUS tstream_tls_params_client(TALLOC_CTX *mem_ctx,
881                                    const char *ca_file,
882                                    const char *crl_file,
883                                    struct tstream_tls_params **_tlsp)
884 {
885 #if ENABLE_GNUTLS
886         struct tstream_tls_params *tlsp;
887         int ret;
888
889         ret = gnutls_global_init();
890         if (ret != GNUTLS_E_SUCCESS) {
891                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
892                 return NT_STATUS_NOT_SUPPORTED;
893         }
894
895         tlsp = talloc_zero(mem_ctx, struct tstream_tls_params);
896         NT_STATUS_HAVE_NO_MEMORY(tlsp);
897
898         talloc_set_destructor(tlsp, tstream_tls_params_destructor);
899
900         ret = gnutls_certificate_allocate_credentials(&tlsp->x509_cred);
901         if (ret != GNUTLS_E_SUCCESS) {
902                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
903                 talloc_free(tlsp);
904                 return NT_STATUS_NO_MEMORY;
905         }
906
907         if (ca_file && *ca_file) {
908                 ret = gnutls_certificate_set_x509_trust_file(tlsp->x509_cred,
909                                                              ca_file,
910                                                              GNUTLS_X509_FMT_PEM);
911                 if (ret < 0) {
912                         DEBUG(0,("TLS failed to initialise cafile %s - %s\n",
913                                  ca_file, gnutls_strerror(ret)));
914                         talloc_free(tlsp);
915                         return NT_STATUS_CANT_ACCESS_DOMAIN_INFO;
916                 }
917         }
918
919         if (crl_file && *crl_file) {
920                 ret = gnutls_certificate_set_x509_crl_file(tlsp->x509_cred,
921                                                            crl_file, 
922                                                            GNUTLS_X509_FMT_PEM);
923                 if (ret < 0) {
924                         DEBUG(0,("TLS failed to initialise crlfile %s - %s\n",
925                                  crl_file, gnutls_strerror(ret)));
926                         talloc_free(tlsp);
927                         return NT_STATUS_CANT_ACCESS_DOMAIN_INFO;
928                 }
929         }
930
931         tlsp->tls_enabled = true;
932
933         *_tlsp = tlsp;
934         return NT_STATUS_OK;
935 #else /* ENABLE_GNUTLS */
936         return NT_STATUS_NOT_IMPLEMENTED;
937 #endif /* ENABLE_GNUTLS */
938 }
939
/* State of a _tstream_tls_connect_send() request. */
struct tstream_tls_connect_state {
	/* the TLS stream being set up, allocated as a child of this state */
	struct tstream_context *tls_stream;
};
943
944 struct tevent_req *_tstream_tls_connect_send(TALLOC_CTX *mem_ctx,
945                                              struct tevent_context *ev,
946                                              struct tstream_context *plain_stream,
947                                              struct tstream_tls_params *tls_params,
948                                              const char *location)
949 {
950         struct tevent_req *req;
951         struct tstream_tls_connect_state *state;
952 #if ENABLE_GNUTLS
953         struct tstream_tls *tlss;
954         int ret;
955         static const int cert_type_priority[] = {
956                 GNUTLS_CRT_X509,
957                 GNUTLS_CRT_OPENPGP,
958                 0
959         };
960 #endif /* ENABLE_GNUTLS */
961
962         req = tevent_req_create(mem_ctx, &state,
963                                 struct tstream_tls_connect_state);
964         if (req == NULL) {
965                 return NULL;
966         }
967
968 #if ENABLE_GNUTLS
969         state->tls_stream = tstream_context_create(state,
970                                                    &tstream_tls_ops,
971                                                    &tlss,
972                                                    struct tstream_tls,
973                                                    location);
974         if (tevent_req_nomem(state->tls_stream, req)) {
975                 return tevent_req_post(req, ev);
976         }
977         ZERO_STRUCTP(tlss);
978         talloc_set_destructor(tlss, tstream_tls_destructor);
979
980         tlss->plain_stream = plain_stream;
981
982         tlss->current_ev = ev;
983         tlss->retry_im = tevent_create_immediate(tlss);
984         if (tevent_req_nomem(tlss->retry_im, req)) {
985                 return tevent_req_post(req, ev);
986         }
987
988         ret = gnutls_init(&tlss->tls_session, GNUTLS_CLIENT);
989         if (ret != GNUTLS_E_SUCCESS) {
990                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
991                 tevent_req_error(req, EINVAL);
992                 return tevent_req_post(req, ev);
993         }
994
995         ret = gnutls_set_default_priority(tlss->tls_session);
996         if (ret != GNUTLS_E_SUCCESS) {
997                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
998                 tevent_req_error(req, EINVAL);
999                 return tevent_req_post(req, ev);
1000         }
1001
1002         gnutls_certificate_type_set_priority(tlss->tls_session, cert_type_priority);
1003
1004         ret = gnutls_credentials_set(tlss->tls_session,
1005                                      GNUTLS_CRD_CERTIFICATE,
1006                                      tls_params->x509_cred);
1007         if (ret != GNUTLS_E_SUCCESS) {
1008                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1009                 tevent_req_error(req, EINVAL);
1010                 return tevent_req_post(req, ev);
1011         }
1012
1013         gnutls_transport_set_ptr(tlss->tls_session, (gnutls_transport_ptr)state->tls_stream);
1014         gnutls_transport_set_pull_function(tlss->tls_session,
1015                                            (gnutls_pull_func)tstream_tls_pull_function);
1016         gnutls_transport_set_push_function(tlss->tls_session,
1017                                            (gnutls_push_func)tstream_tls_push_function);
1018         gnutls_transport_set_lowat(tlss->tls_session, 0);
1019
1020         tlss->handshake.req = req;
1021         tstream_tls_retry_handshake(state->tls_stream);
1022         if (!tevent_req_is_in_progress(req)) {
1023                 return tevent_req_post(req, ev);
1024         }
1025
1026         return req;
1027 #else /* ENABLE_GNUTLS */
1028         tevent_req_error(req, ENOSYS);
1029         return tevent_req_post(req, ev);
1030 #endif /* ENABLE_GNUTLS */
1031 }
1032
1033 int tstream_tls_connect_recv(struct tevent_req *req,
1034                              int *perrno,
1035                              TALLOC_CTX *mem_ctx,
1036                              struct tstream_context **tls_stream)
1037 {
1038         struct tstream_tls_connect_state *state =
1039                 tevent_req_data(req,
1040                 struct tstream_tls_connect_state);
1041
1042         if (tevent_req_is_unix_error(req, perrno)) {
1043                 tevent_req_received(req);
1044                 return -1;
1045         }
1046
1047         *tls_stream = talloc_move(mem_ctx, &state->tls_stream);
1048         tevent_req_received(req);
1049         return 0;
1050 }
1051
1052 extern void tls_cert_generate(TALLOC_CTX *, const char *, const char *, const char *, const char *);
1053
1054 /*
1055   initialise global tls state
1056 */
1057 NTSTATUS tstream_tls_params_server(TALLOC_CTX *mem_ctx,
1058                                    const char *dns_host_name,
1059                                    bool enabled,
1060                                    const char *key_file,
1061                                    const char *cert_file,
1062                                    const char *ca_file,
1063                                    const char *crl_file,
1064                                    const char *dhp_file,
1065                                    struct tstream_tls_params **_tlsp)
1066 {
1067         struct tstream_tls_params *tlsp;
1068 #if ENABLE_GNUTLS
1069         int ret;
1070
1071         if (!enabled || key_file == NULL || *key_file == 0) {
1072                 tlsp = talloc_zero(mem_ctx, struct tstream_tls_params);
1073                 NT_STATUS_HAVE_NO_MEMORY(tlsp);
1074                 talloc_set_destructor(tlsp, tstream_tls_params_destructor);
1075                 tlsp->tls_enabled = false;
1076
1077                 *_tlsp = tlsp;
1078                 return NT_STATUS_OK;
1079         }
1080
1081         ret = gnutls_global_init();
1082         if (ret != GNUTLS_E_SUCCESS) {
1083                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1084                 return NT_STATUS_NOT_SUPPORTED;
1085         }
1086
1087         tlsp = talloc_zero(mem_ctx, struct tstream_tls_params);
1088         NT_STATUS_HAVE_NO_MEMORY(tlsp);
1089
1090         talloc_set_destructor(tlsp, tstream_tls_params_destructor);
1091
1092         if (!file_exist(ca_file)) {
1093                 tls_cert_generate(tlsp, dns_host_name,
1094                                   key_file, cert_file, ca_file);
1095         }
1096
1097         ret = gnutls_certificate_allocate_credentials(&tlsp->x509_cred);
1098         if (ret != GNUTLS_E_SUCCESS) {
1099                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1100                 talloc_free(tlsp);
1101                 return NT_STATUS_NO_MEMORY;
1102         }
1103
1104         if (ca_file && *ca_file) {
1105                 ret = gnutls_certificate_set_x509_trust_file(tlsp->x509_cred,
1106                                                              ca_file,
1107                                                              GNUTLS_X509_FMT_PEM);
1108                 if (ret < 0) {
1109                         DEBUG(0,("TLS failed to initialise cafile %s - %s\n",
1110                                  ca_file, gnutls_strerror(ret)));
1111                         talloc_free(tlsp);
1112                         return NT_STATUS_CANT_ACCESS_DOMAIN_INFO;
1113                 }
1114         }
1115
1116         if (crl_file && *crl_file) {
1117                 ret = gnutls_certificate_set_x509_crl_file(tlsp->x509_cred,
1118                                                            crl_file, 
1119                                                            GNUTLS_X509_FMT_PEM);
1120                 if (ret < 0) {
1121                         DEBUG(0,("TLS failed to initialise crlfile %s - %s\n",
1122                                  crl_file, gnutls_strerror(ret)));
1123                         talloc_free(tlsp);
1124                         return NT_STATUS_CANT_ACCESS_DOMAIN_INFO;
1125                 }
1126         }
1127
1128         ret = gnutls_certificate_set_x509_key_file(tlsp->x509_cred,
1129                                                    cert_file, key_file,
1130                                                    GNUTLS_X509_FMT_PEM);
1131         if (ret != GNUTLS_E_SUCCESS) {
1132                 DEBUG(0,("TLS failed to initialise certfile %s and keyfile %s - %s\n",
1133                          cert_file, key_file, gnutls_strerror(ret)));
1134                 talloc_free(tlsp);
1135                 return NT_STATUS_CANT_ACCESS_DOMAIN_INFO;
1136         }
1137
1138         ret = gnutls_dh_params_init(&tlsp->dh_params);
1139         if (ret != GNUTLS_E_SUCCESS) {
1140                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1141                 talloc_free(tlsp);
1142                 return NT_STATUS_NO_MEMORY;
1143         }
1144
1145         if (dhp_file && *dhp_file) {
1146                 gnutls_datum_t dhparms;
1147                 size_t size;
1148
1149                 dhparms.data = (uint8_t *)file_load(dhp_file, &size, 0, tlsp);
1150
1151                 if (!dhparms.data) {
1152                         DEBUG(0,("TLS failed to read DH Parms from %s - %d:%s\n",
1153                                  dhp_file, errno, strerror(errno)));
1154                         talloc_free(tlsp);
1155                         return NT_STATUS_CANT_ACCESS_DOMAIN_INFO;
1156                 }
1157                 dhparms.size = size;
1158
1159                 ret = gnutls_dh_params_import_pkcs3(tlsp->dh_params,
1160                                                     &dhparms,
1161                                                     GNUTLS_X509_FMT_PEM);
1162                 if (ret != GNUTLS_E_SUCCESS) {
1163                         DEBUG(0,("TLS failed to import pkcs3 %s - %s\n",
1164                                  dhp_file, gnutls_strerror(ret)));
1165                         talloc_free(tlsp);
1166                         return NT_STATUS_CANT_ACCESS_DOMAIN_INFO;
1167                 }
1168         } else {
1169                 ret = gnutls_dh_params_generate2(tlsp->dh_params, DH_BITS);
1170                 if (ret != GNUTLS_E_SUCCESS) {
1171                         DEBUG(0,("TLS failed to generate dh_params - %s\n",
1172                                  gnutls_strerror(ret)));
1173                         talloc_free(tlsp);
1174                         return NT_STATUS_INTERNAL_ERROR;
1175                 }
1176         }
1177
1178         gnutls_certificate_set_dh_params(tlsp->x509_cred, tlsp->dh_params);
1179
1180         tlsp->tls_enabled = true;
1181
1182 #else /* ENABLE_GNUTLS */
1183         tlsp = talloc_zero(mem_ctx, struct tstream_tls_params);
1184         NT_STATUS_HAVE_NO_MEMORY(tlsp);
1185         talloc_set_destructor(tlsp, tstream_tls_params_destructor);
1186         tlsp->tls_enabled = false;
1187 #endif /* ENABLE_GNUTLS */
1188
1189         *_tlsp = tlsp;
1190         return NT_STATUS_OK;
1191 }
1192
1193 struct tstream_tls_accept_state {
1194         struct tstream_context *tls_stream;
1195 };
1196
1197 struct tevent_req *_tstream_tls_accept_send(TALLOC_CTX *mem_ctx,
1198                                             struct tevent_context *ev,
1199                                             struct tstream_context *plain_stream,
1200                                             struct tstream_tls_params *tlsp,
1201                                             const char *location)
1202 {
1203         struct tevent_req *req;
1204         struct tstream_tls_accept_state *state;
1205         struct tstream_tls *tlss;
1206 #if ENABLE_GNUTLS
1207         int ret;
1208 #endif /* ENABLE_GNUTLS */
1209
1210         req = tevent_req_create(mem_ctx, &state,
1211                                 struct tstream_tls_accept_state);
1212         if (req == NULL) {
1213                 return NULL;
1214         }
1215
1216         state->tls_stream = tstream_context_create(state,
1217                                                    &tstream_tls_ops,
1218                                                    &tlss,
1219                                                    struct tstream_tls,
1220                                                    location);
1221         if (tevent_req_nomem(state->tls_stream, req)) {
1222                 return tevent_req_post(req, ev);
1223         }
1224         ZERO_STRUCTP(tlss);
1225         talloc_set_destructor(tlss, tstream_tls_destructor);
1226
1227 #if ENABLE_GNUTLS
1228         tlss->plain_stream = plain_stream;
1229
1230         tlss->current_ev = ev;
1231         tlss->retry_im = tevent_create_immediate(tlss);
1232         if (tevent_req_nomem(tlss->retry_im, req)) {
1233                 return tevent_req_post(req, ev);
1234         }
1235
1236         ret = gnutls_init(&tlss->tls_session, GNUTLS_SERVER);
1237         if (ret != GNUTLS_E_SUCCESS) {
1238                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1239                 tevent_req_error(req, EINVAL);
1240                 return tevent_req_post(req, ev);
1241         }
1242
1243         ret = gnutls_set_default_priority(tlss->tls_session);
1244         if (ret != GNUTLS_E_SUCCESS) {
1245                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1246                 tevent_req_error(req, EINVAL);
1247                 return tevent_req_post(req, ev);
1248         }
1249
1250         ret = gnutls_credentials_set(tlss->tls_session, GNUTLS_CRD_CERTIFICATE,
1251                                      tlsp->x509_cred);
1252         if (ret != GNUTLS_E_SUCCESS) {
1253                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1254                 tevent_req_error(req, EINVAL);
1255                 return tevent_req_post(req, ev);
1256         }
1257
1258         gnutls_certificate_server_set_request(tlss->tls_session,
1259                                               GNUTLS_CERT_REQUEST);
1260         gnutls_dh_set_prime_bits(tlss->tls_session, DH_BITS);
1261
1262         gnutls_transport_set_ptr(tlss->tls_session, (gnutls_transport_ptr)state->tls_stream);
1263         gnutls_transport_set_pull_function(tlss->tls_session,
1264                                            (gnutls_pull_func)tstream_tls_pull_function);
1265         gnutls_transport_set_push_function(tlss->tls_session,
1266                                            (gnutls_push_func)tstream_tls_push_function);
1267         gnutls_transport_set_lowat(tlss->tls_session, 0);
1268
1269         tlss->handshake.req = req;
1270         tstream_tls_retry_handshake(state->tls_stream);
1271         if (!tevent_req_is_in_progress(req)) {
1272                 return tevent_req_post(req, ev);
1273         }
1274
1275         return req;
1276 #else /* ENABLE_GNUTLS */
1277         tevent_req_error(req, ENOSYS);
1278         return tevent_req_post(req, ev);
1279 #endif /* ENABLE_GNUTLS */
1280 }
1281
1282 static void tstream_tls_retry_handshake(struct tstream_context *stream)
1283 {
1284         struct tstream_tls *tlss =
1285                 tstream_context_data(stream,
1286                 struct tstream_tls);
1287         struct tevent_req *req = tlss->handshake.req;
1288 #if ENABLE_GNUTLS
1289         int ret;
1290
1291         if (tlss->error != 0) {
1292                 tevent_req_error(req, tlss->error);
1293                 return;
1294         }
1295
1296         ret = gnutls_handshake(tlss->tls_session);
1297         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
1298                 return;
1299         }
1300
1301         tlss->handshake.req = NULL;
1302
1303         if (gnutls_error_is_fatal(ret) != 0) {
1304                 DEBUG(1,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1305                 tlss->error = EIO;
1306                 tevent_req_error(req, tlss->error);
1307                 return;
1308         }
1309
1310         if (ret != GNUTLS_E_SUCCESS) {
1311                 DEBUG(1,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
1312                 tlss->error = EIO;
1313                 tevent_req_error(req, tlss->error);
1314                 return;
1315         }
1316
1317         tevent_req_done(req);
1318 #else /* ENABLE_GNUTLS */
1319         tevent_req_error(req, ENOSYS);
1320 #endif /* ENABLE_GNUTLS */
1321 }
1322
1323 int tstream_tls_accept_recv(struct tevent_req *req,
1324                             int *perrno,
1325                             TALLOC_CTX *mem_ctx,
1326                             struct tstream_context **tls_stream)
1327 {
1328         struct tstream_tls_accept_state *state =
1329                 tevent_req_data(req,
1330                 struct tstream_tls_accept_state);
1331
1332         if (tevent_req_is_unix_error(req, perrno)) {
1333                 tevent_req_received(req);
1334                 return -1;
1335         }
1336
1337         *tls_stream = talloc_move(mem_ctx, &state->tls_stream);
1338         tevent_req_received(req);
1339         return 0;
1340 }