s4:lib/tls - include GNUTLS headers consistently using <...>
[samba.git] / source4 / lib / tls / tls_tstream.c
index 7f37643a659d29468ea08df8581439427c1c4f94..6bb68fb34c0a9e606fec5130bf8cdbb8b9aac948 100644 (file)
@@ -25,7 +25,7 @@
 #include "lib/tls/tls.h"
 
 #if ENABLE_GNUTLS
-#include "gnutls/gnutls.h"
+#include <gnutls/gnutls.h>
 
 #define DH_BITS 1024
 
@@ -50,7 +50,7 @@ struct tstream_tls {
        struct tevent_immediate *retry_im;
 
        struct {
-               uint8_t buffer[1024];
+               uint8_t *buf;
                off_t ofs;
                struct iovec iov;
                struct tevent_req *subreq;
@@ -58,7 +58,7 @@ struct tstream_tls {
        } push;
 
        struct {
-               uint8_t buffer[1024];
+               uint8_t *buf;
                struct iovec iov;
                struct tevent_req *subreq;
        } pull;
@@ -153,6 +153,7 @@ static ssize_t tstream_tls_push_function(gnutls_transport_ptr ptr,
        struct tstream_tls *tlss =
                tstream_context_data(stream,
                struct tstream_tls);
+       uint8_t *nbuf;
        size_t len;
 
        if (tlss->error != 0) {
@@ -165,13 +166,26 @@ static ssize_t tstream_tls_push_function(gnutls_transport_ptr ptr,
                return -1;
        }
 
-       if (tlss->push.ofs == sizeof(tlss->push.buffer)) {
+       len = MIN(size, UINT16_MAX - tlss->push.ofs);
+
+       if (len == 0) {
                errno = EAGAIN;
                return -1;
        }
 
-       len = MIN(size, sizeof(tlss->push.buffer) - tlss->push.ofs);
-       memcpy(tlss->push.buffer + tlss->push.ofs, buf, len);
+       nbuf = talloc_realloc(tlss, tlss->push.buf,
+                             uint8_t, tlss->push.ofs + len);
+       if (nbuf == NULL) {
+               if (tlss->push.buf) {
+                       errno = EAGAIN;
+                       return -1;
+               }
+
+               return -1;
+       }
+       tlss->push.buf = nbuf;
+
+       memcpy(tlss->push.buf + tlss->push.ofs, buf, len);
 
        if (tlss->push.im == NULL) {
                tlss->push.im = tevent_create_immediate(tlss);
@@ -187,7 +201,7 @@ static ssize_t tstream_tls_push_function(gnutls_transport_ptr ptr,
                 * in the next event cycle.
                 *
                 * This way we can batch all push requests,
-                * if they fit into the buffer.
+                * if they fit into a UINT16_MAX buffer.
                 *
                 * This is important as gnutls_handshake()
                 * had a bug in some versions e.g. 2.4.1
@@ -223,7 +237,7 @@ static void tstream_tls_push_trigger_write(struct tevent_context *ev,
                return;
        }
 
-       tlss->push.iov.iov_base = (char *)tlss->push.buffer;
+       tlss->push.iov.iov_base = (char *)tlss->push.buf;
        tlss->push.iov.iov_len = tlss->push.ofs;
 
        subreq = tstream_writev_send(tlss,
@@ -253,6 +267,7 @@ static void tstream_tls_push_done(struct tevent_req *subreq)
 
        tlss->push.subreq = NULL;
        ZERO_STRUCT(tlss->push.iov);
+       TALLOC_FREE(tlss->push.buf);
        tlss->push.ofs = 0;
 
        ret = tstream_writev_recv(subreq, &sys_errno);
@@ -278,6 +293,7 @@ static ssize_t tstream_tls_pull_function(gnutls_transport_ptr ptr,
                tstream_context_data(stream,
                struct tstream_tls);
        struct tevent_req *subreq;
+       size_t len;
 
        if (tlss->error != 0) {
                errno = tlss->error;
@@ -290,14 +306,20 @@ static ssize_t tstream_tls_pull_function(gnutls_transport_ptr ptr,
        }
 
        if (tlss->pull.iov.iov_base) {
+               uint8_t *b;
                size_t n;
 
+               b = (uint8_t *)tlss->pull.iov.iov_base;
+
                n = MIN(tlss->pull.iov.iov_len, size);
-               memcpy(buf, tlss->pull.iov.iov_base, n);
+               memcpy(buf, b, n);
 
                tlss->pull.iov.iov_len -= n;
+               b += n;
+               tlss->pull.iov.iov_base = (char *)b;
                if (tlss->pull.iov.iov_len == 0) {
                        tlss->pull.iov.iov_base = NULL;
+                       TALLOC_FREE(tlss->pull.buf);
                }
 
                return n;
@@ -307,8 +329,15 @@ static ssize_t tstream_tls_pull_function(gnutls_transport_ptr ptr,
                return 0;
        }
 
-       tlss->pull.iov.iov_base = tlss->pull.buffer;
-       tlss->pull.iov.iov_len = MIN(size, sizeof(tlss->pull.buffer));
+       len = MIN(size, UINT16_MAX);
+
+       tlss->pull.buf = talloc_array(tlss, uint8_t, len);
+       if (tlss->pull.buf == NULL) {
+               return -1;
+       }
+
+       tlss->pull.iov.iov_base = (char *)tlss->pull.buf;
+       tlss->pull.iov.iov_len = len;
 
        subreq = tstream_readv_send(tlss,
                                    tlss->current_ev,
@@ -460,7 +489,7 @@ static void tstream_tls_readv_crypt_next(struct tevent_req *req)
                memcpy(base, tlss->read.buffer + tlss->read.ofs, len);
 
                base += len;
-               state->vector[0].iov_base = base;
+               state->vector[0].iov_base = (char *) base;
                state->vector[0].iov_len -= len;
 
                tlss->read.ofs += len;
@@ -631,7 +660,7 @@ static void tstream_tls_writev_crypt_next(struct tevent_req *req)
                memcpy(tlss->write.buffer + tlss->write.ofs, base, len);
 
                base += len;
-               state->vector[0].iov_base = base;
+               state->vector[0].iov_base = (char *) base;
                state->vector[0].iov_len -= len;
 
                tlss->write.ofs += len;
@@ -1000,7 +1029,9 @@ struct tevent_req *_tstream_tls_connect_send(TALLOC_CTX *mem_ctx,
                                           (gnutls_pull_func)tstream_tls_pull_function);
        gnutls_transport_set_push_function(tlss->tls_session,
                                           (gnutls_push_func)tstream_tls_push_function);
+#if GNUTLS_VERSION_MAJOR < 3
        gnutls_transport_set_lowat(tlss->tls_session, 0);
+#endif
 
        tlss->handshake.req = req;
        tstream_tls_retry_handshake(state->tls_stream);
@@ -1249,7 +1280,9 @@ struct tevent_req *_tstream_tls_accept_send(TALLOC_CTX *mem_ctx,
                                           (gnutls_pull_func)tstream_tls_pull_function);
        gnutls_transport_set_push_function(tlss->tls_session,
                                           (gnutls_push_func)tstream_tls_push_function);
+#if GNUTLS_VERSION_MAJOR < 3
        gnutls_transport_set_lowat(tlss->tls_session, 0);
+#endif
 
        tlss->handshake.req = req;
        tstream_tls_retry_handshake(state->tls_stream);