s4:lib/tls - call "gnutls_transport_set_lowat" only on GNUTLS < 3.0
[mdw/samba.git] / source4 / lib / tls / tls_tstream.c
index f7d27ebbf497dc6a7886bdf04d49256668668dc7..eb4a6d90dad31510eb2764cd109b2d6c3748fd9b 100644 (file)
@@ -50,10 +50,18 @@ struct tstream_tls {
        struct tevent_immediate *retry_im;
 
        struct {
-               uint8_t buffer[1024];
+               uint8_t *buf;
+               off_t ofs;
                struct iovec iov;
                struct tevent_req *subreq;
-       } push, pull;
+               struct tevent_immediate *im;
+       } push;
+
+       struct {
+               uint8_t *buf;
+               struct iovec iov;
+               struct tevent_req *subreq;
+       } pull;
 
        struct {
                struct tevent_req *req;
@@ -132,7 +140,9 @@ static void tstream_tls_retry_trigger(struct tevent_context *ctx,
 }
 
 #if ENABLE_GNUTLS
-static void tstream_tls_push_done(struct tevent_req *subreq);
+static void tstream_tls_push_trigger_write(struct tevent_context *ev,
+                                          struct tevent_immediate *im,
+                                          void *private_data);
 
 static ssize_t tstream_tls_push_function(gnutls_transport_ptr ptr,
                                         const void *buf, size_t size)
@@ -143,7 +153,8 @@ static ssize_t tstream_tls_push_function(gnutls_transport_ptr ptr,
        struct tstream_tls *tlss =
                tstream_context_data(stream,
                struct tstream_tls);
-       struct tevent_req *subreq;
+       uint8_t *nbuf;
+       size_t len;
 
        if (tlss->error != 0) {
                errno = tlss->error;
@@ -155,24 +166,92 @@ static ssize_t tstream_tls_push_function(gnutls_transport_ptr ptr,
                return -1;
        }
 
-       tlss->push.iov.iov_base = tlss->push.buffer;
-       tlss->push.iov.iov_len = MIN(size, sizeof(tlss->push.buffer));
+       len = MIN(size, UINT16_MAX - tlss->push.ofs);
+
+       if (len == 0) {
+               errno = EAGAIN;
+               return -1;
+       }
+
+       nbuf = talloc_realloc(tlss, tlss->push.buf,
+                             uint8_t, tlss->push.ofs + len);
+       if (nbuf == NULL) {
+               if (tlss->push.buf) {
+                       errno = EAGAIN;
+                       return -1;
+               }
 
-       memcpy(tlss->push.buffer, buf, tlss->push.iov.iov_len);
+               return -1;
+       }
+       tlss->push.buf = nbuf;
+
+       memcpy(tlss->push.buf + tlss->push.ofs, buf, len);
+
+       if (tlss->push.im == NULL) {
+               tlss->push.im = tevent_create_immediate(tlss);
+               if (tlss->push.im == NULL) {
+                       errno = ENOMEM;
+                       return -1;
+               }
+       }
+
+       if (tlss->push.ofs == 0) {
+               /*
+                * We'll do start the tstream_writev
+                * in the next event cycle.
+                *
+                * This way we can batch all push requests,
+                * if they fit into a UINT16_MAX buffer.
+                *
+                * This is important as gnutls_handshake()
+                * had a bug in some versions e.g. 2.4.1
+                * and others (See bug #7218) and it doesn't
+                * handle EAGAIN.
+                */
+               tevent_schedule_immediate(tlss->push.im,
+                                         tlss->current_ev,
+                                         tstream_tls_push_trigger_write,
+                                         stream);
+       }
+
+       tlss->push.ofs += len;
+       return len;
+}
+
+static void tstream_tls_push_done(struct tevent_req *subreq);
+
+static void tstream_tls_push_trigger_write(struct tevent_context *ev,
+                                          struct tevent_immediate *im,
+                                          void *private_data)
+{
+       struct tstream_context *stream =
+               talloc_get_type_abort(private_data,
+               struct tstream_context);
+       struct tstream_tls *tlss =
+               tstream_context_data(stream,
+               struct tstream_tls);
+       struct tevent_req *subreq;
+
+       if (tlss->push.subreq) {
+               /* nothing todo */
+               return;
+       }
+
+       tlss->push.iov.iov_base = (char *)tlss->push.buf;
+       tlss->push.iov.iov_len = tlss->push.ofs;
 
        subreq = tstream_writev_send(tlss,
                                     tlss->current_ev,
                                     tlss->plain_stream,
                                     &tlss->push.iov, 1);
        if (subreq == NULL) {
-               errno = ENOMEM;
-               return -1;
+               tlss->error = ENOMEM;
+               tstream_tls_retry(stream, false);
+               return;
        }
        tevent_req_set_callback(subreq, tstream_tls_push_done, stream);
 
        tlss->push.subreq = subreq;
-
-       return tlss->push.iov.iov_len;
 }
 
 static void tstream_tls_push_done(struct tevent_req *subreq)
@@ -188,6 +267,8 @@ static void tstream_tls_push_done(struct tevent_req *subreq)
 
        tlss->push.subreq = NULL;
        ZERO_STRUCT(tlss->push.iov);
+       TALLOC_FREE(tlss->push.buf);
+       tlss->push.ofs = 0;
 
        ret = tstream_writev_recv(subreq, &sys_errno);
        TALLOC_FREE(subreq);
@@ -212,6 +293,7 @@ static ssize_t tstream_tls_pull_function(gnutls_transport_ptr ptr,
                tstream_context_data(stream,
                struct tstream_tls);
        struct tevent_req *subreq;
+       size_t len;
 
        if (tlss->error != 0) {
                errno = tlss->error;
@@ -224,14 +306,20 @@ static ssize_t tstream_tls_pull_function(gnutls_transport_ptr ptr,
        }
 
        if (tlss->pull.iov.iov_base) {
+               uint8_t *b;
                size_t n;
 
+               b = (uint8_t *)tlss->pull.iov.iov_base;
+
                n = MIN(tlss->pull.iov.iov_len, size);
-               memcpy(buf, tlss->pull.iov.iov_base, n);
+               memcpy(buf, b, n);
 
                tlss->pull.iov.iov_len -= n;
+               b += n;
+               tlss->pull.iov.iov_base = (char *)b;
                if (tlss->pull.iov.iov_len == 0) {
                        tlss->pull.iov.iov_base = NULL;
+                       TALLOC_FREE(tlss->pull.buf);
                }
 
                return n;
@@ -241,8 +329,15 @@ static ssize_t tstream_tls_pull_function(gnutls_transport_ptr ptr,
                return 0;
        }
 
-       tlss->pull.iov.iov_base = tlss->pull.buffer;
-       tlss->pull.iov.iov_len = MIN(size, sizeof(tlss->pull.buffer));
+       len = MIN(size, UINT16_MAX);
+
+       tlss->pull.buf = talloc_array(tlss, uint8_t, len);
+       if (tlss->pull.buf == NULL) {
+               return -1;
+       }
+
+       tlss->pull.iov.iov_base = (char *)tlss->pull.buf;
+       tlss->pull.iov.iov_len = len;
 
        subreq = tstream_readv_send(tlss,
                                    tlss->current_ev,
@@ -394,7 +489,7 @@ static void tstream_tls_readv_crypt_next(struct tevent_req *req)
                memcpy(base, tlss->read.buffer + tlss->read.ofs, len);
 
                base += len;
-               state->vector[0].iov_base = base;
+               state->vector[0].iov_base = (char *) base;
                state->vector[0].iov_len -= len;
 
                tlss->read.ofs += len;
@@ -565,7 +660,7 @@ static void tstream_tls_writev_crypt_next(struct tevent_req *req)
                memcpy(tlss->write.buffer + tlss->write.ofs, base, len);
 
                base += len;
-               state->vector[0].iov_base = base;
+               state->vector[0].iov_base = (char *) base;
                state->vector[0].iov_len -= len;
 
                tlss->write.ofs += len;
@@ -934,7 +1029,9 @@ struct tevent_req *_tstream_tls_connect_send(TALLOC_CTX *mem_ctx,
                                           (gnutls_pull_func)tstream_tls_pull_function);
        gnutls_transport_set_push_function(tlss->tls_session,
                                           (gnutls_push_func)tstream_tls_push_function);
+#if GNUTLS_VERSION_MAJOR < 3
        gnutls_transport_set_lowat(tlss->tls_session, 0);
+#endif
 
        tlss->handshake.req = req;
        tstream_tls_retry_handshake(state->tls_stream);
@@ -1183,7 +1280,9 @@ struct tevent_req *_tstream_tls_accept_send(TALLOC_CTX *mem_ctx,
                                           (gnutls_pull_func)tstream_tls_pull_function);
        gnutls_transport_set_push_function(tlss->tls_session,
                                           (gnutls_push_func)tstream_tls_push_function);
+#if GNUTLS_VERSION_MAJOR < 3
        gnutls_transport_set_lowat(tlss->tls_session, 0);
+#endif
 
        tlss->handshake.req = req;
        tstream_tls_retry_handshake(state->tls_stream);