HACK-TODO: tls_tstream...
[metze/samba/wip.git] / source4 / lib / tls / tls_tstream.c
1 /*
2    Unix SMB/CIFS implementation.
3
4    Copyright (C) Stefan Metzmacher 2010
5
6    This program is free software; you can redistribute it and/or modify
7    it under the terms of the GNU General Public License as published by
8    the Free Software Foundation; either version 3 of the License, or
9    (at your option) any later version.
10
11    This program is distributed in the hope that it will be useful,
12    but WITHOUT ANY WARRANTY; without even the implied warranty of
13    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14    GNU General Public License for more details.
15
16    You should have received a copy of the GNU General Public License
17    along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 */
19
20 #include "includes.h"
21 #include "system/network.h"
22 #include "../util/tevent_unix.h"
23 #include "../lib/tsocket/tsocket.h"
24 #include "../lib/tsocket/tsocket_internal.h"
25 #include "lib/tls/tls.h"
26 #include "gnutls/gnutls.h"
27
28 #define DH_BITS 1024
29
30 static const struct tstream_context_ops tstream_tls_ops;
31
32 struct tstream_tls {
33         struct tstream_context *plain_stream;
34         int plain_errno;
35
36         gnutls_session tls_session;
37         int tls_error;
38
39         gnutls_certificate_credentials xcred;
40
41         struct tevent_context *current_ev;
42
43         struct {
44                 uint8_t buffer[1024];
45                 struct iovec iov;
46         } push, pull;
47 };
48
/*
 * Placeholder - not implemented yet.
 *
 * Called by the connect/accept paths when gnutls_handshake() reports
 * GNUTLS_E_AGAIN or GNUTLS_E_INTERRUPTED. Presumably meant to arrange
 * for the interrupted operation to be retried from 'ev' - TODO confirm
 * once implemented. Right now it does nothing.
 */
static void tstream_tls_schedule_retry(struct tstream_context *tls,
                                       struct tevent_context *ev)
{
        /* TODO */
        (void)tls;
        (void)ev;
}
/*
 * Placeholder - not implemented yet.
 *
 * Invoked by the push/pull completion handlers after a plain-stream
 * request has finished (successfully or not). Presumably meant to
 * re-run the interrupted GnuTLS operation - TODO confirm once
 * implemented. Right now it does nothing.
 */
static void tstream_tls_retry(struct tstream_context *tls)
{
        /* TODO */
        (void)tls;
}
58
59 static void tstream_tls_push_done(struct tevent_req *subreq);
60
61 static ssize_t tstream_tls_push_function(gnutls_transport_ptr ptr,
62                                          const void *buf, size_t size)
63 {
64         struct tstream_context *tls = talloc_get_type_abort(ptr,
65                                       struct tstream_context);
66         struct tstream_tls *tlss = tstream_context_data(tls, struct tstream_tls);
67         struct tevent_req *subreq;
68
69         if (tlss->push.iov.iov_base) {
70                 errno = EAGAIN;
71                 return -1;
72         }
73
74         tlss->push.iov.iov_base = tlss->push.buffer;
75         tlss->push.iov.iov_len = MIN(size, sizeof(tlss->push.buffer));
76
77         memcpy(tlss->push.buffer, buf, tlss->push.iov.iov_len);
78
79         subreq = tstream_writev_send(tlss,
80                                      tlss->current_ev,
81                                      tlss->plain_stream,
82                                      &tlss->push.iov, 1);
83         if (subreq == NULL) {
84                 errno = ENOMEM;
85                 return -1;
86         }
87         tevent_req_set_callback(subreq, tstream_tls_push_done, tls);
88
89         return tlss->push.iov.iov_len;
90 }
91
92 static void tstream_tls_push_done(struct tevent_req *subreq)
93 {
94         struct tstream_context *tls = tevent_req_callback_data(subreq,
95                                       struct tstream_context);
96         struct tstream_tls *tlss = tstream_context_data(tls, struct tstream_tls);
97         int ret;
98         int perrno;
99
100         ZERO_STRUCT(tlss->push.iov);
101
102         ret = tstream_writev_recv(subreq, &perrno);
103         TALLOC_FREE(subreq);
104         if (ret == -1) {
105                 tlss->plain_errno = perrno;
106                 tstream_tls_retry(tls);
107                 return;
108         }
109
110         tstream_tls_retry(tls);
111 }
112
113 static void tstream_tls_pull_done(struct tevent_req *subreq);
114
115 static ssize_t tstream_tls_pull_function(gnutls_transport_ptr ptr,
116                                          void *buf, size_t size)
117 {
118         struct tstream_context *tls = talloc_get_type_abort(ptr,
119                                       struct tstream_context);
120         struct tstream_tls *tlss = tstream_context_data(tls, struct tstream_tls);
121         struct tevent_req *subreq;
122
123         if (tlss->pull.iov.iov_base) {
124                 size_t n;
125
126                 n = MIN(tlss->pull.iov.iov_len, size);
127                 memcpy(buf, tlss->pull.iov.iov_base, n);
128
129                 tlss->pull.iov.iov_len -= n;
130                 if (tlss->pull.iov.iov_len == 0) {
131                         tlss->pull.iov.iov_base = NULL;
132                 }
133
134                 return n;
135         }
136
137         if (size == 0) {
138                 return 0;
139         }
140
141         tlss->pull.iov.iov_base = tlss->pull.buffer;
142         tlss->pull.iov.iov_len = MIN(size, sizeof(tlss->pull.buffer));
143
144         subreq = tstream_readv_send(tlss,
145                                     tlss->current_ev,
146                                     tlss->plain_stream,
147                                     &tlss->pull.iov, 1);
148         if (subreq == NULL) {
149                 errno = ENOMEM;
150                 return -1;
151         }
152         tevent_req_set_callback(subreq, tstream_tls_pull_done, tls);
153
154         errno = EAGAIN;
155         return -1;
156 }
157
158 static void tstream_tls_pull_done(struct tevent_req *subreq)
159 {
160         struct tstream_context *tls = tevent_req_callback_data(subreq,
161                                       struct tstream_context);
162         struct tstream_tls *tlss = tstream_context_data(tls, struct tstream_tls);
163         int ret;
164         int perrno;
165
166         ret = tstream_readv_recv(subreq, &perrno);
167         TALLOC_FREE(subreq);
168         if (ret == -1) {
169                 tlss->plain_errno = perrno;
170                 tstream_tls_retry(tls);
171                 return;
172         }
173
174         tstream_tls_retry(tls);
175 }
176
177 static int tstream_tls_destructor(struct tstream_tls *tlss)
178 {
179         if (tlss->xcred) {
180                 gnutls_certificate_free_credentials(tlss->xcred);
181                 tlss->xcred = NULL;
182         }
183         if (tlss->tls_session) {
184                 gnutls_deinit(tlss->tls_session);
185                 tlss->tls_session = NULL;
186         }
187         return 0;
188 }
189
190 static ssize_t tstream_tls_pending_bytes(struct tstream_context *stream)
191 {
192         struct tstream_tls *tlss = tstream_context_data(stream,
193                                    struct tstream_tls);
194         ssize_t ret;
195
196         if (!tlss->plain_stream) {
197                 errno = ENOTCONN;
198                 return -1;
199         }
200
201         if (!tlss->tls_session) {
202                 ret = tstream_pending_bytes(tlss->plain_stream);
203                 return ret;
204         }
205
206         ret = gnutls_record_check_pending(tlss->tls_session);
207         if (ret < 0) {
208                 /* TODO: better mapping */
209                 errno = EIO;
210                 return -1;
211         }
212
213         return ret;
214 }
215
216 struct tstream_tls_readv_state {
217         int ret;
218 };
219
220 static void tstream_tls_readv_plain_handler(struct tevent_req *subreq);
221
222 static struct tevent_req *tstream_tls_readv_send(TALLOC_CTX *mem_ctx,
223                                         struct tevent_context *ev,
224                                         struct tstream_context *stream,
225                                         struct iovec *vector,
226                                         size_t count)
227 {
228         struct tevent_req *req;
229         struct tstream_tls_readv_state *state;
230         struct tstream_tls *tlss = tstream_context_data(stream, struct tstream_tls);
231         struct tevent_req *subreq;
232
233         req = tevent_req_create(mem_ctx, &state,
234                                 struct tstream_tls_readv_state);
235         if (!req) {
236                 return NULL;
237         }
238
239         state->ret      = 0;
240
241         if (!tlss->plain_stream) {
242                 tevent_req_error(req, ENOTCONN);
243                 return tevent_req_post(req, ev);
244         }
245
246         if (!tlss->tls_session) {
247                 subreq = tstream_readv_send(state,
248                                             ev,
249                                             tlss->plain_stream,
250                                             vector,
251                                             count);
252                 if (tevent_req_nomem(subreq,req)) {
253                         return tevent_req_post(req, ev);
254                 }
255                 tevent_req_set_callback(subreq,
256                                         tstream_tls_readv_plain_handler,
257                                         req);
258
259                 return req;
260         }
261
262         /* TODO */
263         tevent_req_error(req, ENOSYS);
264         return tevent_req_post(req, ev);
265 }
266
267 static void tstream_tls_readv_plain_handler(struct tevent_req *subreq)
268 {
269         struct tevent_req *req = tevent_req_callback_data(subreq,
270                                  struct tevent_req);
271         struct tstream_tls_readv_state *state = tevent_req_data(req,
272                                         struct tstream_tls_readv_state);
273         int ret;
274         int sys_errno;
275
276         ret = tstream_readv_recv(subreq, &sys_errno);
277         TALLOC_FREE(subreq);
278         if (ret == -1) {
279                 tevent_req_error(req, sys_errno);
280                 return;
281         }
282
283         state->ret = ret;
284
285         tevent_req_done(req);
286 }
287
288 static int tstream_tls_readv_recv(struct tevent_req *req,
289                                   int *perrno)
290 {
291         struct tstream_tls_readv_state *state = tevent_req_data(req,
292                                         struct tstream_tls_readv_state);
293         int ret;
294
295         ret = tsocket_simple_int_recv(req, perrno);
296         if (ret == 0) {
297                 ret = state->ret;
298         }
299
300         tevent_req_received(req);
301         return ret;
302 }
303
304 struct tstream_tls_writev_state {
305         int ret;
306 };
307
308 static void tstream_tls_writev_plain_handler(struct tevent_req *subreq);
309
310 static struct tevent_req *tstream_tls_writev_send(TALLOC_CTX *mem_ctx,
311                                         struct tevent_context *ev,
312                                         struct tstream_context *stream,
313                                         const struct iovec *vector,
314                                         size_t count)
315 {
316         struct tevent_req *req;
317         struct tstream_tls_writev_state *state;
318         struct tstream_tls *tlss = tstream_context_data(stream, struct tstream_tls);
319         struct tevent_req *subreq;
320
321         req = tevent_req_create(mem_ctx, &state,
322                                 struct tstream_tls_writev_state);
323         if (!req) {
324                 return NULL;
325         }
326
327         state->ret      = 0;
328
329         if (!tlss->plain_stream) {
330                 tevent_req_error(req, ENOTCONN);
331                 return tevent_req_post(req, ev);
332         }
333
334         if (!tlss->tls_session) {
335                 subreq = tstream_writev_send(state,
336                                              ev,
337                                              tlss->plain_stream,
338                                              vector,
339                                              count);
340                 if (tevent_req_nomem(subreq, req)) {
341                         return tevent_req_post(req, ev);
342                 }
343                 tevent_req_set_callback(subreq, tstream_tls_writev_plain_handler, req);
344
345                 return req;
346         }
347
348         /* TODO */
349         tevent_req_error(req, ENOSYS);
350         return tevent_req_post(req, ev);
351 }
352
353 static void tstream_tls_writev_plain_handler(struct tevent_req *subreq)
354 {
355         struct tevent_req *req = tevent_req_callback_data(subreq,
356                                  struct tevent_req);
357         struct tstream_tls_writev_state *state = tevent_req_data(req,
358                                         struct tstream_tls_writev_state);
359         int ret;
360         int sys_errno;
361
362         ret = tstream_writev_recv(subreq, &sys_errno);
363         TALLOC_FREE(subreq);
364         if (ret == -1) {
365                 tevent_req_error(req, sys_errno);
366                 return;
367         }
368
369         state->ret = ret;
370
371         tevent_req_done(req);
372 }
373
374 static int tstream_tls_writev_recv(struct tevent_req *req,
375                                    int *perrno)
376 {
377         struct tstream_tls_writev_state *state = tevent_req_data(req,
378                                         struct tstream_tls_writev_state);
379         int ret;
380
381         ret = tsocket_simple_int_recv(req, perrno);
382         if (ret == 0) {
383                 ret = state->ret;
384         }
385
386         tevent_req_received(req);
387         return ret;
388 }
389
390 struct tstream_tls_disconnect_state {
391         uint8_t _dummy;
392 };
393
394 static struct tevent_req *tstream_tls_disconnect_send(TALLOC_CTX *mem_ctx,
395                                                 struct tevent_context *ev,
396                                                 struct tstream_context *stream)
397 {
398         struct tstream_tls *tlss = tstream_context_data(stream, struct tstream_tls);
399         struct tevent_req *req;
400         struct tstream_tls_disconnect_state *state;
401
402         req = tevent_req_create(mem_ctx, &state,
403                                 struct tstream_tls_disconnect_state);
404         if (req == NULL) {
405                 return NULL;
406         }
407
408         if (!tlss->plain_stream) {
409                 tevent_req_error(req, ENOTCONN);
410                 return tevent_req_post(req, ev);
411         }
412
413         if (!tlss->tls_session) {
414                 /*
415                  * The caller is responsible to do the real disconnect
416                  * on the plain stream!
417                  */
418                 tlss->plain_stream = NULL;
419                 tevent_req_done(req);
420                 return tevent_req_post(req, ev);
421         }
422
423         /* TODO */
424         tevent_req_error(req, ENOSYS);
425         return tevent_req_post(req, ev);
426 }
427
/*
 * tstream "disconnect_recv" operation: reports the outcome obtained
 * from tsocket_simple_int_recv() and marks the request as received.
 */
static int tstream_tls_disconnect_recv(struct tevent_req *req,
                                       int *perrno)
{
        int result = tsocket_simple_int_recv(req, perrno);

        tevent_req_received(req);
        return result;
}
438
439 static const struct tstream_context_ops tstream_tls_ops = {
440         .name                   = "tls",
441
442         .pending_bytes          = tstream_tls_pending_bytes,
443
444         .readv_send             = tstream_tls_readv_send,
445         .readv_recv             = tstream_tls_readv_recv,
446
447         .writev_send            = tstream_tls_writev_send,
448         .writev_recv            = tstream_tls_writev_recv,
449
450         .disconnect_send        = tstream_tls_disconnect_send,
451         .disconnect_recv        = tstream_tls_disconnect_recv,
452 };
453
454 struct tstream_tls_params {
455         const char *ca_path;
456         gnutls_certificate_credentials xcred;
457 };
458
459 struct tstream_tls_connect_state {
460         struct {
461                 struct tevent_context *ev;
462         } caller;
463         struct tstream_context *tls_stream;
464 };
465
466 struct tevent_req *_tstream_tls_connect_send(TALLOC_CTX *mem_ctx,
467                                              struct tevent_context *ev,
468                                              struct tstream_context *plain_stream,
469                                              struct tstream_tls_params *tls_params,
470                                              const char *location)
471 {
472         struct tevent_req *req;
473         struct tstream_tls_connect_state *state;
474         struct tstream_tls *tlss;
475         int ret;
476         static const int cert_type_priority[] = {
477                 GNUTLS_CRT_X509,
478                 GNUTLS_CRT_OPENPGP,
479                 0
480         };
481
482         req = tevent_req_create(mem_ctx, &state,
483                                 struct tstream_tls_connect_state);
484         if (!req) {
485                 return NULL;
486         }
487
488         state->caller.ev = ev;
489
490         state->tls_stream = tstream_context_create(state,
491                                                    &tstream_tls_ops,
492                                                    &tlss,
493                                                    struct tstream_tls,
494                                                    location);
495         if (tevent_req_nomem(state->tls_stream, req)) {
496                 return tevent_req_post(req, ev);
497         }
498         ZERO_STRUCTP(tlss);
499         talloc_set_destructor(tlss, tstream_tls_destructor);
500
501         tlss->plain_stream = plain_stream;
502
503         gnutls_global_init();
504
505         ret = gnutls_certificate_allocate_credentials(&tlss->xcred);
506         if (tevent_req_error(req, ret)) {
507                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
508                 return tevent_req_post(req, ev);
509         }
510
511         gnutls_certificate_set_x509_trust_file(tlss->xcred,
512                                                tls_params->ca_path,
513                                                GNUTLS_X509_FMT_PEM);
514
515         ret = gnutls_init(&tlss->tls_session, GNUTLS_CLIENT);
516         if (tevent_req_error(req, ret)) {
517                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
518                 return tevent_req_post(req, ev);
519         }
520
521         ret = gnutls_set_default_priority(tlss->tls_session);
522         if (tevent_req_error(req, ret)) {
523                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
524                 return tevent_req_post(req, ev);
525         }
526
527         gnutls_certificate_type_set_priority(tlss->tls_session, cert_type_priority);
528
529         ret = gnutls_credentials_set(tlss->tls_session, GNUTLS_CRD_CERTIFICATE, tlss->xcred);
530         if (tevent_req_error(req, ret)) {
531                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
532                 return tevent_req_post(req, ev);
533         }
534
535         gnutls_transport_set_ptr(tlss->tls_session, (gnutls_transport_ptr)state->tls_stream);
536         gnutls_transport_set_pull_function(tlss->tls_session,
537                                            (gnutls_pull_func)tstream_tls_pull_function);
538         gnutls_transport_set_push_function(tlss->tls_session,
539                                            (gnutls_push_func)tstream_tls_push_function);
540         gnutls_transport_set_lowat(tlss->tls_session, 0);
541
542         //tstream_tls_prepare_operation(tls, state->caller.ev);
543         ret = gnutls_handshake(tlss->tls_session);
544         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
545                 tstream_tls_schedule_retry(state->tls_stream, state->caller.ev);
546                 return req;
547         }
548         if (tevent_req_error(req, ret)) {
549                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
550                 return tevent_req_post(req, ev);
551         }
552
553         return tevent_req_post(req, ev);
554 }
555
556 int tstream_tls_connect_recv(struct tevent_req *req,
557                              int *perrno,
558                              TALLOC_CTX *mem_ctx,
559                              struct tstream_context **tls_stream)
560 {
561         struct tstream_tls_connect_state *state =
562                 tevent_req_data(req,
563                 struct tstream_tls_connect_state);
564
565         if (tevent_req_is_unix_error(req, perrno)) {
566                 tevent_req_received(req);
567                 return -1;
568         }
569
570         *tls_stream = talloc_move(mem_ctx, &state->tls_stream);
571         tevent_req_received(req);
572         return 0;
573 }
574
575 struct tstream_tls_accept_state {
576         struct {
577                 struct tevent_context *ev;
578         } caller;
579         struct tstream_context *tls_stream;
580 };
581
582 struct tevent_req *_tstream_tls_accept_send(TALLOC_CTX *mem_ctx,
583                                             struct tevent_context *ev,
584                                             struct tstream_context *plain_stream,
585                                             struct tstream_tls_params *tls_params,
586                                             const char *location)
587 {
588         struct tevent_req *req;
589         struct tstream_tls_accept_state *state;
590         struct tstream_tls *tlss;
591         int ret;
592
593         req = tevent_req_create(mem_ctx, &state,
594                                 struct tstream_tls_accept_state);
595         if (!req) {
596                 return NULL;
597         }
598
599         state->caller.ev = ev;
600
601         state->tls_stream = tstream_context_create(state,
602                                                    &tstream_tls_ops,
603                                                    &tlss,
604                                                    struct tstream_tls,
605                                                    location);
606         if (tevent_req_nomem(state->tls_stream, req)) {
607                 return tevent_req_post(req, ev);
608         }
609         ZERO_STRUCTP(tlss);
610         talloc_set_destructor(tlss, tstream_tls_destructor);
611
612         tlss->plain_stream = plain_stream;
613
614         gnutls_global_init();
615
616         ret = gnutls_init(&tlss->tls_session, GNUTLS_SERVER);
617         if (tevent_req_error(req, ret)) {
618                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
619                 return tevent_req_post(req, ev);
620         }
621
622         ret = gnutls_set_default_priority(tlss->tls_session);
623         if (tevent_req_error(req, ret)) {
624                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
625                 return tevent_req_post(req, ev);
626         }
627
628         ret = gnutls_credentials_set(tlss->tls_session, GNUTLS_CRD_CERTIFICATE,
629                                      tls_params->xcred);
630         if (tevent_req_error(req, ret)) {
631                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
632                 return tevent_req_post(req, ev);
633         }
634
635         gnutls_certificate_server_set_request(tlss->tls_session,
636                                               GNUTLS_CERT_REQUEST);
637         gnutls_dh_set_prime_bits(tlss->tls_session, DH_BITS);
638
639         gnutls_transport_set_ptr(tlss->tls_session, (gnutls_transport_ptr)state->tls_stream);
640         gnutls_transport_set_pull_function(tlss->tls_session,
641                                            (gnutls_pull_func)tstream_tls_pull_function);
642         gnutls_transport_set_push_function(tlss->tls_session,
643                                            (gnutls_push_func)tstream_tls_push_function);
644         gnutls_transport_set_lowat(tlss->tls_session, 0);
645
646         //tstream_tls_prepare_operation(tls, state->caller.ev);
647         ret = gnutls_handshake(tlss->tls_session);
648         if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
649                 tstream_tls_schedule_retry(state->tls_stream, state->caller.ev);
650                 return req;
651         }
652         if (tevent_req_error(req, ret)) {
653                 DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret)));
654                 return tevent_req_post(req, ev);
655         }
656
657         return tevent_req_post(req, ev);
658 }
659
660 int tstream_tls_accept_recv(struct tevent_req *req,
661                             int *perrno,
662                             TALLOC_CTX *mem_ctx,
663                             struct tstream_context **tls_stream)
664 {
665         struct tstream_tls_accept_state *state =
666                 tevent_req_data(req,
667                 struct tstream_tls_accept_state);
668
669         if (tevent_req_is_unix_error(req, perrno)) {
670                 tevent_req_received(req);
671                 return -1;
672         }
673
674         *tls_stream = talloc_move(mem_ctx, &state->tls_stream);
675         tevent_req_received(req);
676         return 0;
677 }
678