r4831: added udp support to our generic sockets library.
authorAndrew Tridgell <tridge@samba.org>
Wed, 19 Jan 2005 03:20:20 +0000 (03:20 +0000)
committerGerald (Jerry) Carter <jerry@samba.org>
Wed, 10 Oct 2007 18:08:59 +0000 (13:08 -0500)
I decided to incorporate the udp support into the socket_ipv4.c
backend (and later in socket_ipv6.c) rather than doing a separate
backend, as so much of the code is shareable. Basically this adds a
socket_sendto() and a socket_recvfrom() call and not much all.

For udp servers, I decided to keep the call as socket_listen(), even
though dgram servers don't actually call listen(). This keeps the API
consistent.

I also added a simple local sockets testsuite in smbtorture,
LOCAL-SOCKET
(This used to be commit 9f12a45a05c5c447fb4ec18c8dd28f70e90e32a5)

source4/lib/socket/socket.c
source4/lib/socket/socket.h
source4/lib/socket/socket_ipv4.c
source4/lib/socket/socket_ipv6.c
source4/lib/socket/socket_unix.c
source4/torture/config.mk
source4/torture/local/socket.c [new file with mode: 0644]
source4/torture/torture.c

index cc43348e79f9594ea6c522459124e982b6f2e19e..97176ea15045b4bc198a14ff4e22ed566ad2a8d3 100644 (file)
@@ -32,7 +32,9 @@ static int socket_destructor(void *ptr)
        return 0;
 }
 
-NTSTATUS socket_create_with_ops(TALLOC_CTX *mem_ctx, const struct socket_ops *ops, struct socket_context **new_sock, uint32_t flags)
+static NTSTATUS socket_create_with_ops(TALLOC_CTX *mem_ctx, const struct socket_ops *ops,
+                                      struct socket_context **new_sock, 
+                                      enum socket_type type, uint32_t flags)
 {
        NTSTATUS status;
 
@@ -41,7 +43,7 @@ NTSTATUS socket_create_with_ops(TALLOC_CTX *mem_ctx, const struct socket_ops *op
                return NT_STATUS_NO_MEMORY;
        }
 
-       (*new_sock)->type = ops->type;
+       (*new_sock)->type = type;
        (*new_sock)->state = SOCKET_STATE_UNDEFINED;
        (*new_sock)->flags = flags;
 
@@ -60,6 +62,7 @@ NTSTATUS socket_create_with_ops(TALLOC_CTX *mem_ctx, const struct socket_ops *op
           send calls on non-blocking sockets will randomly recv/send
           less data than requested */
        if (!(flags & SOCKET_FLAG_BLOCK) &&
+           type == SOCKET_TYPE_STREAM &&
            lp_parm_bool(-1, "socket", "testnonblock", False)) {
                (*new_sock)->flags |= SOCKET_FLAG_TESTNONBLOCK;
        }
@@ -69,7 +72,8 @@ NTSTATUS socket_create_with_ops(TALLOC_CTX *mem_ctx, const struct socket_ops *op
        return NT_STATUS_OK;
 }
 
-NTSTATUS socket_create(const char *name, enum socket_type type, struct socket_context **new_sock, uint32_t flags)
+NTSTATUS socket_create(const char *name, enum socket_type type, 
+                      struct socket_context **new_sock, uint32_t flags)
 {
        const struct socket_ops *ops;
 
@@ -78,7 +82,7 @@ NTSTATUS socket_create(const char *name, enum socket_type type, struct socket_co
                return NT_STATUS_INVALID_PARAMETER;
        }
 
-       return socket_create_with_ops(NULL, ops, new_sock, flags);
+       return socket_create_with_ops(NULL, ops, new_sock, type, flags);
 }
 
 void socket_destroy(struct socket_context *sock)
@@ -92,10 +96,6 @@ NTSTATUS socket_connect(struct socket_context *sock,
                        const char *server_address, int server_port,
                        uint32_t flags)
 {
-       if (sock->type != SOCKET_TYPE_STREAM) {
-               return NT_STATUS_INVALID_PARAMETER;
-       }
-
        if (sock->state != SOCKET_STATE_UNDEFINED) {
                return NT_STATUS_INVALID_PARAMETER;
        }
@@ -117,10 +117,6 @@ NTSTATUS socket_connect_complete(struct socket_context *sock, uint32_t flags)
 
 NTSTATUS socket_listen(struct socket_context *sock, const char *my_address, int port, int queue_size, uint32_t flags)
 {
-       if (sock->type != SOCKET_TYPE_STREAM) {
-               return NT_STATUS_INVALID_PARAMETER;
-       }
-
        if (sock->state != SOCKET_STATE_UNDEFINED) {
                return NT_STATUS_INVALID_PARAMETER;
        }
@@ -160,10 +156,6 @@ NTSTATUS socket_accept(struct socket_context *sock, struct socket_context **new_
 NTSTATUS socket_recv(struct socket_context *sock, void *buf, 
                     size_t wantlen, size_t *nread, uint32_t flags)
 {
-       if (sock->type != SOCKET_TYPE_STREAM) {
-               return NT_STATUS_INVALID_PARAMETER;
-       }
-
        if (sock->state != SOCKET_STATE_CLIENT_CONNECTED &&
            sock->state != SOCKET_STATE_SERVER_CONNECTED) {
                return NT_STATUS_INVALID_PARAMETER;
@@ -184,13 +176,25 @@ NTSTATUS socket_recv(struct socket_context *sock, void *buf,
        return sock->ops->fn_recv(sock, buf, wantlen, nread, flags);
 }
 
-NTSTATUS socket_send(struct socket_context *sock, 
-                    const DATA_BLOB *blob, size_t *sendlen, uint32_t flags)
+NTSTATUS socket_recvfrom(struct socket_context *sock, void *buf, 
+                        size_t wantlen, size_t *nread, uint32_t flags,
+                        const char **src_addr, int *src_port)
 {
-       if (sock->type != SOCKET_TYPE_STREAM) {
+       if (sock->type != SOCKET_TYPE_DGRAM) {
                return NT_STATUS_INVALID_PARAMETER;
        }
 
+       if (!sock->ops->fn_recvfrom) {
+               return NT_STATUS_NOT_IMPLEMENTED;
+       }
+
+       return sock->ops->fn_recvfrom(sock, buf, wantlen, nread, flags, 
+                                     src_addr, src_port);
+}
+
+NTSTATUS socket_send(struct socket_context *sock, 
+                    const DATA_BLOB *blob, size_t *sendlen, uint32_t flags)
+{
        if (sock->state != SOCKET_STATE_CLIENT_CONNECTED &&
            sock->state != SOCKET_STATE_SERVER_CONNECTED) {
                return NT_STATUS_INVALID_PARAMETER;
@@ -213,6 +217,27 @@ NTSTATUS socket_send(struct socket_context *sock,
        return sock->ops->fn_send(sock, blob, sendlen, flags);
 }
 
+
+NTSTATUS socket_sendto(struct socket_context *sock, 
+                      const DATA_BLOB *blob, size_t *sendlen, uint32_t flags,
+                      const char *dest_addr, int dest_port)
+{
+       if (sock->type != SOCKET_TYPE_DGRAM) {
+               return NT_STATUS_INVALID_PARAMETER;
+       }
+
+       if (sock->state == SOCKET_STATE_CLIENT_CONNECTED ||
+           sock->state == SOCKET_STATE_SERVER_CONNECTED) {
+               return NT_STATUS_INVALID_PARAMETER;
+       }
+
+       if (!sock->ops->fn_sendto) {
+               return NT_STATUS_NOT_IMPLEMENTED;
+       }
+
+       return sock->ops->fn_sendto(sock, blob, sendlen, flags, dest_addr, dest_port);
+}
+
 NTSTATUS socket_set_option(struct socket_context *sock, const char *option, const char *val)
 {
        if (!sock->ops->fn_set_option) {
@@ -302,7 +327,7 @@ const struct socket_ops *socket_getops_byname(const char *name, enum socket_type
 {
        if (strcmp("ip", name) == 0 || 
            strcmp("ipv4", name) == 0) {
-               return socket_ipv4_ops();
+               return socket_ipv4_ops(type);
        }
 
 #if HAVE_SOCKET_IPV6
@@ -311,12 +336,12 @@ const struct socket_ops *socket_getops_byname(const char *name, enum socket_type
                        DEBUG(3, ("IPv6 support was disabled in smb.conf"));
                        return NULL;
                }
-               return socket_ipv6_ops();
+               return socket_ipv6_ops(type);
        }
 #endif
 
        if (strcmp("unix", name) == 0) {
-               return socket_unixdom_ops();
+               return socket_unixdom_ops(type);
        }
 
        return NULL;
index 7dd8c0ae179294da7bbd41902e0b442aa5cdf454..162a05cb402e3f01fcca1277b2c9e50540efd1dc 100644 (file)
 struct socket_context;
 
 enum socket_type {
-       SOCKET_TYPE_STREAM
+       SOCKET_TYPE_STREAM,
+       SOCKET_TYPE_DGRAM
 };
 
 struct socket_ops {
        const char *name;
-       enum socket_type type;
 
        NTSTATUS (*fn_init)(struct socket_context *sock);
 
@@ -50,9 +50,16 @@ struct socket_ops {
 
        /* general ops */
        NTSTATUS (*fn_recv)(struct socket_context *sock, void *buf,
-                        size_t wantlen, size_t *nread, uint32_t flags);
+                           size_t wantlen, size_t *nread, uint32_t flags);
        NTSTATUS (*fn_send)(struct socket_context *sock, 
-                        const DATA_BLOB *blob, size_t *sendlen, uint32_t flags);
+                           const DATA_BLOB *blob, size_t *sendlen, uint32_t flags);
+
+       NTSTATUS (*fn_sendto)(struct socket_context *sock, 
+                             const DATA_BLOB *blob, size_t *sendlen, uint32_t flags,
+                             const char *dest_addr, int dest_port);
+       NTSTATUS (*fn_recvfrom)(struct socket_context *sock, 
+                               void *buf, size_t wantlen, size_t *nread, uint32_t flags,
+                               const char **src_addr, int *src_port);
 
        void (*fn_close)(struct socket_context *sock);
 
index 7cf2b73e4e70ef13b4f90f67a0da19ca802d18b9..d6b6bf7be4ec90a3dd0378ea273f43e43b047683 100644 (file)
 #include "includes.h"
 #include "system/network.h"
 
-static NTSTATUS ipv4_tcp_init(struct socket_context *sock)
+static NTSTATUS ipv4_init(struct socket_context *sock)
 {
-       sock->fd = socket(PF_INET, SOCK_STREAM, 0);
+       int type;
+
+       switch (sock->type) {
+       case SOCKET_TYPE_STREAM:
+               type = SOCK_STREAM;
+               break;
+       case SOCKET_TYPE_DGRAM:
+               type = SOCK_DGRAM;
+               break;
+       default:
+               return NT_STATUS_INVALID_PARAMETER;
+       }
+
+       sock->fd = socket(PF_INET, type, 0);
        if (sock->fd == -1) {
                return map_nt_error_from_unix(errno);
        }
@@ -34,12 +47,12 @@ static NTSTATUS ipv4_tcp_init(struct socket_context *sock)
        return NT_STATUS_OK;
 }
 
-static void ipv4_tcp_close(struct socket_context *sock)
+static void ipv4_close(struct socket_context *sock)
 {
        close(sock->fd);
 }
 
-static NTSTATUS ipv4_tcp_connect_complete(struct socket_context *sock, uint32_t flags)
+static NTSTATUS ipv4_connect_complete(struct socket_context *sock, uint32_t flags)
 {
        int error=0, ret;
        socklen_t len = sizeof(error);
@@ -67,7 +80,7 @@ static NTSTATUS ipv4_tcp_connect_complete(struct socket_context *sock, uint32_t
 }
 
 
-static NTSTATUS ipv4_tcp_connect(struct socket_context *sock,
+static NTSTATUS ipv4_connect(struct socket_context *sock,
                                 const char *my_address, int my_port,
                                 const char *srv_address, int srv_port,
                                 uint32_t flags)
@@ -110,13 +123,17 @@ static NTSTATUS ipv4_tcp_connect(struct socket_context *sock,
                return map_nt_error_from_unix(errno);
        }
 
-       return ipv4_tcp_connect_complete(sock, flags);
+       return ipv4_connect_complete(sock, flags);
 }
 
 
-static NTSTATUS ipv4_tcp_listen(struct socket_context *sock,
-                                       const char *my_address, int port,
-                                       int queue_size, uint32_t flags)
+/*
+  note that for simplicity of the API, socket_listen() is also
+  use for DGRAM sockets, but in reality only a bind() is done
+*/
+static NTSTATUS ipv4_listen(struct socket_context *sock,
+                           const char *my_address, int port,
+                           int queue_size, uint32_t flags)
 {
        struct sockaddr_in my_addr;
        struct ipv4_addr ip_addr;
@@ -137,9 +154,11 @@ static NTSTATUS ipv4_tcp_listen(struct socket_context *sock,
                return map_nt_error_from_unix(errno);
        }
 
-       ret = listen(sock->fd, queue_size);
-       if (ret == -1) {
-               return map_nt_error_from_unix(errno);
+       if (sock->type == SOCKET_TYPE_STREAM) {
+               ret = listen(sock->fd, queue_size);
+               if (ret == -1) {
+                       return map_nt_error_from_unix(errno);
+               }
        }
 
        if (!(flags & SOCKET_FLAG_BLOCK)) {
@@ -154,12 +173,16 @@ static NTSTATUS ipv4_tcp_listen(struct socket_context *sock,
        return NT_STATUS_OK;
 }
 
-static NTSTATUS ipv4_tcp_accept(struct socket_context *sock, struct socket_context **new_sock)
+static NTSTATUS ipv4_accept(struct socket_context *sock, struct socket_context **new_sock)
 {
        struct sockaddr_in cli_addr;
        socklen_t cli_addr_len = sizeof(cli_addr);
        int new_fd;
 
+       if (sock->type != SOCKET_TYPE_STREAM) {
+               return NT_STATUS_INVALID_PARAMETER;
+       }
+
        new_fd = accept(sock->fd, (struct sockaddr *)&cli_addr, &cli_addr_len);
        if (new_fd == -1) {
                return map_nt_error_from_unix(errno);
@@ -198,7 +221,7 @@ static NTSTATUS ipv4_tcp_accept(struct socket_context *sock, struct socket_conte
        return NT_STATUS_OK;
 }
 
-static NTSTATUS ipv4_tcp_recv(struct socket_context *sock, void *buf, 
+static NTSTATUS ipv4_recv(struct socket_context *sock, void *buf, 
                              size_t wantlen, size_t *nread, uint32_t flags)
 {
        ssize_t gotlen;
@@ -227,7 +250,49 @@ static NTSTATUS ipv4_tcp_recv(struct socket_context *sock, void *buf,
        return NT_STATUS_OK;
 }
 
-static NTSTATUS ipv4_tcp_send(struct socket_context *sock, 
+
+static NTSTATUS ipv4_recvfrom(struct socket_context *sock, void *buf, 
+                             size_t wantlen, size_t *nread, uint32_t flags,
+                             const char **src_addr, int *src_port)
+{
+       ssize_t gotlen;
+       int flgs = 0;
+       struct sockaddr_in from_addr;
+       socklen_t from_len = sizeof(from_addr);
+       const char *addr;
+
+       if (flags & SOCKET_FLAG_PEEK) {
+               flgs |= MSG_PEEK;
+       }
+
+       if (flags & SOCKET_FLAG_BLOCK) {
+               flgs |= MSG_WAITALL;
+       }
+
+       *nread = 0;
+
+       gotlen = recvfrom(sock->fd, buf, wantlen, flgs, 
+                         (struct sockaddr *)&from_addr, &from_len);
+       if (gotlen == 0) {
+               return NT_STATUS_END_OF_FILE;
+       } else if (gotlen == -1) {
+               return map_nt_error_from_unix(errno);
+       }
+
+       addr = inet_ntoa(from_addr.sin_addr);
+       if (addr == NULL) {
+               return NT_STATUS_INTERNAL_ERROR;
+       }
+       *src_addr = talloc_strdup(sock, addr);
+       NT_STATUS_HAVE_NO_MEMORY(*src_addr);
+       *src_port = ntohs(from_addr.sin_port);
+
+       *nread = gotlen;
+
+       return NT_STATUS_OK;
+}
+
+static NTSTATUS ipv4_send(struct socket_context *sock, 
                              const DATA_BLOB *blob, size_t *sendlen, uint32_t flags)
 {
        ssize_t len;
@@ -245,13 +310,44 @@ static NTSTATUS ipv4_tcp_send(struct socket_context *sock,
        return NT_STATUS_OK;
 }
 
-static NTSTATUS ipv4_tcp_set_option(struct socket_context *sock, const char *option, const char *val)
+static NTSTATUS ipv4_sendto(struct socket_context *sock, 
+                           const DATA_BLOB *blob, size_t *sendlen, uint32_t flags,
+                           const char *dest_addr, int dest_port)
+{
+       ssize_t len;
+       int flgs = 0;
+       struct sockaddr_in srv_addr;
+       struct ipv4_addr addr;
+
+       ZERO_STRUCT(srv_addr);
+#ifdef HAVE_SOCK_SIN_LEN
+       srv_addr.sin_len         = sizeof(srv_addr);
+#endif
+       addr                     = interpret_addr2(dest_addr);
+       srv_addr.sin_addr.s_addr = addr.addr;
+       srv_addr.sin_port        = htons(dest_port);
+       srv_addr.sin_family      = PF_INET;
+
+       *sendlen = 0;
+
+       len = sendto(sock->fd, blob->data, blob->length, flgs, 
+                  (struct sockaddr *)&srv_addr, sizeof(srv_addr));
+       if (len == -1) {
+               return map_nt_error_from_unix(errno);
+       }       
+
+       *sendlen = len;
+
+       return NT_STATUS_OK;
+}
+
+static NTSTATUS ipv4_set_option(struct socket_context *sock, const char *option, const char *val)
 {
        set_socket_options(sock->fd, option);
        return NT_STATUS_OK;
 }
 
-static char *ipv4_tcp_get_peer_name(struct socket_context *sock, TALLOC_CTX *mem_ctx)
+static char *ipv4_get_peer_name(struct socket_context *sock, TALLOC_CTX *mem_ctx)
 {
        struct sockaddr_in peer_addr;
        socklen_t len = sizeof(peer_addr);
@@ -271,7 +367,7 @@ static char *ipv4_tcp_get_peer_name(struct socket_context *sock, TALLOC_CTX *mem
        return talloc_strdup(mem_ctx, he->h_name);
 }
 
-static char *ipv4_tcp_get_peer_addr(struct socket_context *sock, TALLOC_CTX *mem_ctx)
+static char *ipv4_get_peer_addr(struct socket_context *sock, TALLOC_CTX *mem_ctx)
 {
        struct sockaddr_in peer_addr;
        socklen_t len = sizeof(peer_addr);
@@ -285,7 +381,7 @@ static char *ipv4_tcp_get_peer_addr(struct socket_context *sock, TALLOC_CTX *mem
        return talloc_strdup(mem_ctx, inet_ntoa(peer_addr.sin_addr));
 }
 
-static int ipv4_tcp_get_peer_port(struct socket_context *sock)
+static int ipv4_get_peer_port(struct socket_context *sock)
 {
        struct sockaddr_in peer_addr;
        socklen_t len = sizeof(peer_addr);
@@ -299,7 +395,7 @@ static int ipv4_tcp_get_peer_port(struct socket_context *sock)
        return ntohs(peer_addr.sin_port);
 }
 
-static char *ipv4_tcp_get_my_addr(struct socket_context *sock, TALLOC_CTX *mem_ctx)
+static char *ipv4_get_my_addr(struct socket_context *sock, TALLOC_CTX *mem_ctx)
 {
        struct sockaddr_in my_addr;
        socklen_t len = sizeof(my_addr);
@@ -313,7 +409,7 @@ static char *ipv4_tcp_get_my_addr(struct socket_context *sock, TALLOC_CTX *mem_c
        return talloc_strdup(mem_ctx, inet_ntoa(my_addr.sin_addr));
 }
 
-static int ipv4_tcp_get_my_port(struct socket_context *sock)
+static int ipv4_get_my_port(struct socket_context *sock)
 {
        struct sockaddr_in my_addr;
        socklen_t len = sizeof(my_addr);
@@ -327,36 +423,36 @@ static int ipv4_tcp_get_my_port(struct socket_context *sock)
        return ntohs(my_addr.sin_port);
 }
 
-static int ipv4_tcp_get_fd(struct socket_context *sock)
+static int ipv4_get_fd(struct socket_context *sock)
 {
        return sock->fd;
 }
 
-static const struct socket_ops ipv4_tcp_ops = {
+static const struct socket_ops ipv4_ops = {
        .name                   = "ipv4",
-       .type                   = SOCKET_TYPE_STREAM,
-
-       .fn_init                = ipv4_tcp_init,
-       .fn_connect             = ipv4_tcp_connect,
-       .fn_connect_complete    = ipv4_tcp_connect_complete,
-       .fn_listen              = ipv4_tcp_listen,
-       .fn_accept              = ipv4_tcp_accept,
-       .fn_recv                = ipv4_tcp_recv,
-       .fn_send                = ipv4_tcp_send,
-       .fn_close               = ipv4_tcp_close,
-
-       .fn_set_option          = ipv4_tcp_set_option,
-
-       .fn_get_peer_name       = ipv4_tcp_get_peer_name,
-       .fn_get_peer_addr       = ipv4_tcp_get_peer_addr,
-       .fn_get_peer_port       = ipv4_tcp_get_peer_port,
-       .fn_get_my_addr         = ipv4_tcp_get_my_addr,
-       .fn_get_my_port         = ipv4_tcp_get_my_port,
-
-       .fn_get_fd              = ipv4_tcp_get_fd
+       .fn_init                = ipv4_init,
+       .fn_connect             = ipv4_connect,
+       .fn_connect_complete    = ipv4_connect_complete,
+       .fn_listen              = ipv4_listen,
+       .fn_accept              = ipv4_accept,
+       .fn_recv                = ipv4_recv,
+       .fn_recvfrom            = ipv4_recvfrom,
+       .fn_send                = ipv4_send,
+       .fn_sendto              = ipv4_sendto,
+       .fn_close               = ipv4_close,
+
+       .fn_set_option          = ipv4_set_option,
+
+       .fn_get_peer_name       = ipv4_get_peer_name,
+       .fn_get_peer_addr       = ipv4_get_peer_addr,
+       .fn_get_peer_port       = ipv4_get_peer_port,
+       .fn_get_my_addr         = ipv4_get_my_addr,
+       .fn_get_my_port         = ipv4_get_my_port,
+
+       .fn_get_fd              = ipv4_get_fd
 };
 
-const struct socket_ops *socket_ipv4_ops(void)
+const struct socket_ops *socket_ipv4_ops(enum socket_type type)
 {
-       return &ipv4_tcp_ops;
+       return &ipv4_ops;
 }
index 35b4037ff48860c6d471517a58dc8a3da394877c..27e452b14eb5810009d5e114612d3289bd65cdd5 100644 (file)
@@ -347,8 +347,6 @@ static int ipv6_tcp_get_fd(struct socket_context *sock)
 
 static const struct socket_ops ipv6_tcp_ops = {
        .name                   = "ipv6",
-       .type                   = SOCKET_TYPE_STREAM,
-
        .fn_init                = ipv6_tcp_init,
        .fn_connect             = ipv6_tcp_connect,
        .fn_connect_complete    = ipv6_tcp_connect_complete,
@@ -369,7 +367,10 @@ static const struct socket_ops ipv6_tcp_ops = {
        .fn_get_fd              = ipv6_tcp_get_fd
 };
 
-const struct socket_ops *socket_ipv6_ops(void)
+const struct socket_ops *socket_ipv6_ops(enum socket_type type)
 {
+       if (type != SOCKET_TYPE_STREAM) {
+               return NULL;
+       }
        return &ipv6_tcp_ops;
 }
index 60a4b9ec4816d9c0790741dd92f5f4ea9c1170fb..bdd68f9d9d99f7c8a4734a7dca7d15d31bd28650 100644 (file)
@@ -4,7 +4,7 @@
    unix domain socket functions
 
    Copyright (C) Stefan Metzmacher 2004
-   Copyright (C) Andrew Tridgell 2004
+   Copyright (C) Andrew Tridgell 2004-2005
    
    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
@@ -266,8 +266,6 @@ static int unixdom_get_fd(struct socket_context *sock)
 
 static const struct socket_ops unixdom_ops = {
        .name                   = "unix",
-       .type                   = SOCKET_TYPE_STREAM,
-
        .fn_init                = unixdom_init,
        .fn_connect             = unixdom_connect,
        .fn_connect_complete    = unixdom_connect_complete,
@@ -288,7 +286,10 @@ static const struct socket_ops unixdom_ops = {
        .fn_get_fd              = unixdom_get_fd
 };
 
-const struct socket_ops *socket_unixdom_ops(void)
+const struct socket_ops *socket_unixdom_ops(enum socket_type type)
 {
+       if (type != SOCKET_TYPE_STREAM) {
+               return NULL;
+       }
        return &unixdom_ops;
 }
index 0e74d7c040ec1f34eb39d76d4a3d87d1d948f740..a31c2c05983e8d0aef7d12b9a1ad6ca9bafb06ac 100644 (file)
@@ -138,7 +138,8 @@ ADD_OBJ_FILES = \
                lib/talloc/testsuite.o \
                torture/local/messaging.o \
                torture/local/binding_string.o \
-               torture/local/idtree.o
+               torture/local/idtree.o \
+               torture/local/socket.o
 REQUIRED_SUBSYSTEMS = \
                LIBSMB \
                MESSAGING
diff --git a/source4/torture/local/socket.c b/source4/torture/local/socket.c
new file mode 100644 (file)
index 0000000..cdd379e
--- /dev/null
@@ -0,0 +1,134 @@
+/* 
+   Unix SMB/CIFS implementation.
+
+   local testing of socket routines.
+
+   Copyright (C) Andrew Tridgell 2005
+   
+   This program is free software; you can redistribute it and/or modify
+   it under the terms of the GNU General Public License as published by
+   the Free Software Foundation; either version 2 of the License, or
+   (at your option) any later version.
+   
+   This program is distributed in the hope that it will be useful,
+   but WITHOUT ANY WARRANTY; without even the implied warranty of
+   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+   GNU General Public License for more details.
+   
+   You should have received a copy of the GNU General Public License
+   along with this program; if not, write to the Free Software
+   Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
+*/
+
+#include "includes.h"
+
+#define CHECK_STATUS(status, correct) do { \
+       if (!NT_STATUS_EQUAL(status, correct)) { \
+               printf("(%s) Incorrect status %s - should be %s\n", \
+                      __location__, nt_errstr(status), nt_errstr(correct)); \
+               ret = False; \
+               goto done; \
+       }} while (0)
+
+
+/*
+  basic testing of udp routines
+*/
+static BOOL test_udp(TALLOC_CTX *mem_ctx)
+{
+       struct socket_context *sock1, *sock2;
+       NTSTATUS status;
+       int srv_port, from_port;
+       const char *srv_addr, *from_addr;
+       size_t size = 100 + (random() % 100);
+       DATA_BLOB blob, blob2;
+       size_t sent, nread;
+       BOOL ret = True;
+
+       status = socket_create("ip", SOCKET_TYPE_DGRAM, &sock1, 0);
+       CHECK_STATUS(status, NT_STATUS_OK);
+       talloc_steal(mem_ctx, sock1);
+
+       status = socket_create("ip", SOCKET_TYPE_DGRAM, &sock2, 0);
+       CHECK_STATUS(status, NT_STATUS_OK);
+       talloc_steal(mem_ctx, sock2);
+
+       status = socket_listen(sock1, "127.0.0.1", 0, 0, 0);
+       CHECK_STATUS(status, NT_STATUS_OK);
+
+       srv_addr = socket_get_my_addr(sock1, mem_ctx);
+       if (srv_addr == NULL || strcmp(srv_addr, "127.0.0.1") != 0) {
+               printf("Expected server address of 127.0.0.1 but got %s\n", srv_addr);
+               return False;
+       }
+
+       srv_port = socket_get_my_port(sock1);
+       printf("server port is %d\n", srv_port);
+
+       blob  = data_blob_talloc(mem_ctx, NULL, size);
+       blob2 = data_blob_talloc(mem_ctx, NULL, size);
+       generate_random_buffer(blob.data, blob.length);
+
+       sent = size;
+       status = socket_sendto(sock2, &blob, &sent, 0, srv_addr, srv_port);
+       CHECK_STATUS(status, NT_STATUS_OK);
+
+       status = socket_recvfrom(sock1, blob2.data, size, &nread, 0, 
+                                &from_addr, &from_port);
+       CHECK_STATUS(status, NT_STATUS_OK);
+
+       if (strcmp(from_addr, srv_addr) != 0) {
+               printf("Unexpected recvfrom addr %s\n", from_addr);
+               ret = False;
+       }
+       if (nread != size) {
+               printf("Unexpected recvfrom size %d should be %d\n", nread, size);
+               ret = False;
+       }
+
+       if (memcmp(blob2.data, blob.data, size) != 0) {
+               printf("Bad data in recvfrom\n");
+               ret = False;
+       }
+
+       generate_random_buffer(blob.data, blob.length);
+       status = socket_sendto(sock1, &blob, &sent, 0, from_addr, from_port);
+       CHECK_STATUS(status, NT_STATUS_OK);
+
+       status = socket_recvfrom(sock2, blob2.data, size, &nread, 0, 
+                                &from_addr, &from_port);
+       CHECK_STATUS(status, NT_STATUS_OK);
+       if (strcmp(from_addr, srv_addr) != 0) {
+               printf("Unexpected recvfrom addr %s\n", from_addr);
+               ret = False;
+       }
+       if (nread != size) {
+               printf("Unexpected recvfrom size %d should be %d\n", nread, size);
+               ret = False;
+       }
+       if (from_port != srv_port) {
+               printf("Unexpected recvfrom port %d should be %d\n", 
+                      from_port, srv_port);
+               ret = False;
+       }
+       if (memcmp(blob2.data, blob.data, size) != 0) {
+               printf("Bad data in recvfrom\n");
+               ret = False;
+       }
+
+done:
+       talloc_free(sock1);
+       talloc_free(sock2);
+
+       return ret;
+}
+
+BOOL torture_local_socket(void) 
+{
+       BOOL ret = True;
+       TALLOC_CTX *mem_ctx = talloc_new(NULL);
+
+       ret &= test_udp(mem_ctx);
+
+       return ret;
+}
index f9faf21957791e139bdbccd19dd858455bb2cfd7..b2219c617c42aced3d3ae9b0de898d2a66ad923d 100644 (file)
@@ -2432,6 +2432,7 @@ static struct {
        {"LOCAL-MESSAGING", torture_local_messaging, 0},
        {"LOCAL-BINDING", torture_local_binding_string, 0},
        {"LOCAL-IDTREE", torture_local_idtree, 0},
+       {"LOCAL-SOCKET", torture_local_socket, 0},
 
        /* ldap testers */
        {"LDAP-BASIC", torture_ldap_basic, 0},