--- /dev/null
+#include <stdarg.h>
+#include <stddef.h>
+#include <setjmp.h>
+#include <cmocka.h>
+
+#include "config.h"
+#include "torture.h"
+
+#include <errno.h>
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <netinet/in.h>
+#include <arpa/inet.h>
+#include <stdlib.h>
+#include <stdio.h>
+#include <unistd.h>
+#include <poll.h>
+
+static int setup_echo_srv_tcp_ipv4(void **state)
+{
+ torture_setup_echo_srv_tcp_ipv4(state);
+
+ return 0;
+}
+
+#ifdef HAVE_IPV6
+static int setup_echo_srv_tcp_ipv6(void **state)
+{
+ torture_setup_echo_srv_tcp_ipv6(state);
+
+ return 0;
+}
+#endif
+
+static int teardown(void **state)
+{
+ torture_teardown_echo_srv(state);
+
+ return 0;
+}
+
+static void handle_poll_loop(size_t size, int s)
+{
+ char send_buf[size];
+ char recv_buf[size];
+ int nfds, num_open_fds;
+ struct pollfd pfds[1];
+ size_t nread = 0, nwrote = 0;
+ ssize_t ret;
+ int i;
+
+ num_open_fds = nfds = 1;
+ pfds[0].fd = s;
+ pfds[0].events = POLLIN | POLLOUT;
+
+ i = 0;
+ memset(send_buf, 0, sizeof(send_buf));
+
+ while (num_open_fds > 0 && i < 10) {
+ int ready;
+
+ printf("About to poll()\n");
+ ready = poll(pfds, nfds, -1);
+ assert_int_not_equal(ready, -1);
+
+ printf("Ready: %d\n", ready);
+
+ /* Deal with array returned by poll(). */
+ for (int j = 0; j < nfds; j++) {
+ if (pfds[j].revents != 0) {
+ printf(" fd=%d; events: %s%s%s%s\n", pfds[j].fd,
+ (pfds[j].revents & POLLIN) ? "POLLIN " : "",
+ (pfds[j].revents & POLLOUT) ? "POLLOUT " : "",
+ (pfds[j].revents & POLLHUP) ? "POLLHUP " : "",
+ (pfds[j].revents & POLLERR) ? "POLLERR " : "");
+ }
+
+ if (pfds[j].revents & POLLIN) {
+ ret = read(s,
+ recv_buf + nread,
+ sizeof(recv_buf) - nread);
+ printf(" fd=%d: read=%zd\n", pfds[j].fd, ret);
+ assert_int_not_equal(ret, -1);
+ nread += ret;
+ /* try to delay */
+ sleep(5);
+ }
+ if (pfds[j].revents & POLLOUT) {
+ snprintf(send_buf, sizeof(send_buf),
+ "packet.%d", i);
+ ret = write(s,
+ send_buf + nwrote,
+ sizeof(send_buf) - nwrote);
+ printf(" fd=%d: wrote=%zd\n", pfds[j].fd, ret);
+ assert_int_not_equal(ret, -1);
+ nwrote += ret;
+ if (nwrote == sizeof(send_buf)) {
+ /* no more to write */
+ pfds[j].events &= ~POLLOUT;
+ }
+ }
+ if (pfds[j].revents & (POLLERR | POLLHUP)) {
+ printf(" closing fd %d\n", pfds[j].fd);
+ close(pfds[j].fd);
+ num_open_fds--;
+ }
+
+ /* verify the data */
+ if (nwrote == sizeof(send_buf) && nread == nwrote) {
+ assert_memory_equal(send_buf, recv_buf,
+ sizeof(send_buf));
+ i++;
+ nwrote = 0;
+ nread = 0;
+ /* new packet to write */
+ pfds[j].events |= POLLOUT;
+ printf("== Next packet %d\n", i);
+ }
+ }
+ }
+
+ printf("All file descriptors closed; bye\n");
+}
+
+static void test_write_read_ipv4_size(void **state, size_t size)
+{
+ struct torture_address addr = {
+ .sa_socklen = sizeof(struct sockaddr_in),
+ };
+ int rc;
+ int s;
+
+ (void) state; /* unused */
+
+ s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+ assert_int_not_equal(s, -1);
+
+ addr.sa.in.sin_family = AF_INET;
+ addr.sa.in.sin_port = htons(torture_server_port());
+
+ rc = inet_pton(addr.sa.in.sin_family,
+ torture_server_address(AF_INET),
+ &addr.sa.in.sin_addr);
+ assert_int_equal(rc, 1);
+
+ rc = connect(s, &addr.sa.s, addr.sa_socklen);
+ assert_int_equal(rc, 0);
+
+ /* closes the socket too */
+ handle_poll_loop(size, s);
+}
+
+static void test_write_read_ipv4(void **state)
+{
+ test_write_read_ipv4_size(state, 64);
+}
+
+static void test_write_read_ipv4_large(void **state)
+{
+ test_write_read_ipv4_size(state, 2000);
+}
+
+#ifdef HAVE_IPV6
+static void test_write_read_ipv6_size(void **state, size_t size)
+{
+ struct torture_address addr = {
+ .sa_socklen = sizeof(struct sockaddr_in6),
+ };
+ int rc;
+ int s;
+
+ (void) state; /* unused */
+
+ s = socket(AF_INET6, SOCK_STREAM, IPPROTO_TCP);
+ assert_int_not_equal(s, -1);
+
+ addr.sa.in6.sin6_family = AF_INET6;
+ addr.sa.in6.sin6_port = htons(torture_server_port());
+
+ rc = inet_pton(AF_INET6,
+ torture_server_address(AF_INET6),
+ &addr.sa.in6.sin6_addr);
+ assert_int_equal(rc, 1);
+
+ rc = connect(s, &addr.sa.s, addr.sa_socklen);
+ assert_int_equal(rc, 0);
+
+ /* closes the socket too */
+ handle_poll_loop(size, s);
+}
+
+static void test_write_read_ipv6(void **state)
+{
+ test_write_read_ipv6_size(state, 64);
+}
+
+static void test_write_read_ipv6_large(void **state)
+{
+ test_write_read_ipv6_size(state, 2000);
+}
+#endif
+
+int main(void) {
+ int rc;
+
+ const struct CMUnitTest tcp_write_tests[] = {
+ cmocka_unit_test_setup_teardown(test_write_read_ipv4,
+ setup_echo_srv_tcp_ipv4,
+ teardown),
+ cmocka_unit_test_setup_teardown(test_write_read_ipv4_large,
+ setup_echo_srv_tcp_ipv4,
+ teardown),
+#ifdef HAVE_IPV6
+ cmocka_unit_test_setup_teardown(test_write_read_ipv6,
+ setup_echo_srv_tcp_ipv6,
+ teardown),
+ cmocka_unit_test_setup_teardown(test_write_read_ipv6_large,
+ setup_echo_srv_tcp_ipv6,
+ teardown),
+#endif
+ };
+
+ rc = cmocka_run_group_tests(tcp_write_tests, NULL, NULL);
+
+ return rc;
+}