0cb23920f1cc42be6d369088f1a6255c36e3cf76
[samba.git] / source4 / lib / util_sock.c
1 /* 
2    Unix SMB/CIFS implementation.
3    Samba utility functions
4    Copyright (C) Andrew Tridgell 1992-1998
5    Copyright (C) Tim Potter      2000-2001
6    
7    This program is free software; you can redistribute it and/or modify
8    it under the terms of the GNU General Public License as published by
9    the Free Software Foundation; either version 2 of the License, or
10    (at your option) any later version.
11    
12    This program is distributed in the hope that it will be useful,
13    but WITHOUT ANY WARRANTY; without even the implied warranty of
14    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15    GNU General Public License for more details.
16    
17    You should have received a copy of the GNU General Public License
18    along with this program; if not, write to the Free Software
19    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
20 */
21
22 #include "includes.h"
23
24
25 /****************************************************************************
26  Determine if a file descriptor is in fact a socket.
27 ****************************************************************************/
28 BOOL is_a_socket(int fd)
29 {
30         int v,l;
31         l = sizeof(int);
32         return getsockopt(fd, SOL_SOCKET, SO_TYPE, (char *)&v, &l) == 0;
33 }
34
35 enum SOCK_OPT_TYPES {OPT_BOOL,OPT_INT,OPT_ON};
36
37 typedef struct smb_socket_option {
38         const char *name;
39         int level;
40         int option;
41         int value;
42         int opttype;
43 } smb_socket_option;
44
45 static const smb_socket_option socket_options[] = {
46   {"SO_KEEPALIVE",      SOL_SOCKET,    SO_KEEPALIVE,    0,                 OPT_BOOL},
47   {"SO_REUSEADDR",      SOL_SOCKET,    SO_REUSEADDR,    0,                 OPT_BOOL},
48   {"SO_BROADCAST",      SOL_SOCKET,    SO_BROADCAST,    0,                 OPT_BOOL},
49 #ifdef TCP_NODELAY
50   {"TCP_NODELAY",       IPPROTO_TCP,   TCP_NODELAY,     0,                 OPT_BOOL},
51 #endif
52 #ifdef IPTOS_LOWDELAY
53   {"IPTOS_LOWDELAY",    IPPROTO_IP,    IP_TOS,          IPTOS_LOWDELAY,    OPT_ON},
54 #endif
55 #ifdef IPTOS_THROUGHPUT
56   {"IPTOS_THROUGHPUT",  IPPROTO_IP,    IP_TOS,          IPTOS_THROUGHPUT,  OPT_ON},
57 #endif
58 #ifdef SO_REUSEPORT
59   {"SO_REUSEPORT",      SOL_SOCKET,    SO_REUSEPORT,    0,                 OPT_BOOL},
60 #endif
61 #ifdef SO_SNDBUF
62   {"SO_SNDBUF",         SOL_SOCKET,    SO_SNDBUF,       0,                 OPT_INT},
63 #endif
64 #ifdef SO_RCVBUF
65   {"SO_RCVBUF",         SOL_SOCKET,    SO_RCVBUF,       0,                 OPT_INT},
66 #endif
67 #ifdef SO_SNDLOWAT
68   {"SO_SNDLOWAT",       SOL_SOCKET,    SO_SNDLOWAT,     0,                 OPT_INT},
69 #endif
70 #ifdef SO_RCVLOWAT
71   {"SO_RCVLOWAT",       SOL_SOCKET,    SO_RCVLOWAT,     0,                 OPT_INT},
72 #endif
73 #ifdef SO_SNDTIMEO
74   {"SO_SNDTIMEO",       SOL_SOCKET,    SO_SNDTIMEO,     0,                 OPT_INT},
75 #endif
76 #ifdef SO_RCVTIMEO
77   {"SO_RCVTIMEO",       SOL_SOCKET,    SO_RCVTIMEO,     0,                 OPT_INT},
78 #endif
79   {NULL,0,0,0,0}};
80
81 /****************************************************************************
82  Print socket options.
83 ****************************************************************************/
84
85 static void print_socket_options(int s)
86 {
87         int value, vlen = 4;
88         const smb_socket_option *p = &socket_options[0];
89
90         for (; p->name != NULL; p++) {
91                 if (getsockopt(s, p->level, p->option, (void *)&value, &vlen) == -1) {
92                         DEBUG(5,("Could not test socket option %s.\n", p->name));
93                 } else {
94                         DEBUG(5,("socket option %s = %d\n",p->name,value));
95                 }
96         }
97  }
98
99 /****************************************************************************
100  Set user socket options.
101 ****************************************************************************/
102
103 void set_socket_options(int fd, const char *options)
104 {
105         fstring tok;
106
107         while (next_token(&options,tok," \t,", sizeof(tok))) {
108                 int ret=0,i;
109                 int value = 1;
110                 char *p;
111                 BOOL got_value = False;
112
113                 if ((p = strchr_m(tok,'='))) {
114                         *p = 0;
115                         value = atoi(p+1);
116                         got_value = True;
117                 }
118
119                 for (i=0;socket_options[i].name;i++)
120                         if (strequal(socket_options[i].name,tok))
121                                 break;
122
123                 if (!socket_options[i].name) {
124                         DEBUG(0,("Unknown socket option %s\n",tok));
125                         continue;
126                 }
127
128                 switch (socket_options[i].opttype) {
129                 case OPT_BOOL:
130                 case OPT_INT:
131                         ret = setsockopt(fd,socket_options[i].level,
132                                                 socket_options[i].option,(char *)&value,sizeof(int));
133                         break;
134
135                 case OPT_ON:
136                         if (got_value)
137                                 DEBUG(0,("syntax error - %s does not take a value\n",tok));
138
139                         {
140                                 int on = socket_options[i].value;
141                                 ret = setsockopt(fd,socket_options[i].level,
142                                                         socket_options[i].option,(char *)&on,sizeof(int));
143                         }
144                         break;    
145                 }
146       
147                 if (ret != 0)
148                         DEBUG(0,("Failed to set socket option %s (Error %s)\n",tok, strerror(errno) ));
149         }
150
151         print_socket_options(fd);
152 }
153
154 /****************************************************************************
155  Read from a socket.
156 ****************************************************************************/
157
158 ssize_t read_udp_socket(int fd, char *buf, size_t len, 
159                         struct in_addr *from_addr, int *from_port)
160 {
161         ssize_t ret;
162         struct sockaddr_in sock;
163         socklen_t socklen = sizeof(sock);
164
165         ret = recvfrom(fd,buf,len, 0, (struct sockaddr *)&sock, &socklen);
166         if (ret <= 0) {
167                 DEBUG(2,("read socket failed. ERRNO=%s\n",strerror(errno)));
168                 return 0;
169         }
170
171         if (from_addr) {
172                 *from_addr = sock.sin_addr;
173         }
174         if (from_port) {
175                 *from_port = ntohs(sock.sin_port);
176         }
177
178         return ret;
179 }
180
181
182
183 /****************************************************************************
184  Check the timeout. 
185 ****************************************************************************/
186
187 static BOOL timeout_until(struct timeval *timeout,
188                           const struct timeval *endtime)
189 {
190         struct timeval now;
191
192         GetTimeOfDay(&now);
193
194         if ((now.tv_sec > endtime->tv_sec) ||
195             ((now.tv_sec == endtime->tv_sec) &&
196              (now.tv_usec > endtime->tv_usec)))
197                 return False;
198
199         timeout->tv_sec = endtime->tv_sec - now.tv_sec;
200         timeout->tv_usec = endtime->tv_usec - now.tv_usec;
201         return True;
202 }
203
204
205 /****************************************************************************
206  Read data from the client, reading exactly N bytes, with timeout. 
207 ****************************************************************************/
208
209 ssize_t read_data_until(int fd,char *buffer,size_t N,
210                         const struct timeval *endtime)
211 {
212         ssize_t ret;
213         size_t total=0;  
214  
215         while (total < N) {
216
217                 if (endtime != NULL) {
218                         fd_set r_fds;
219                         struct timeval timeout;
220                         int res;
221
222                         FD_ZERO(&r_fds);
223                         FD_SET(fd, &r_fds);
224
225                         if (!timeout_until(&timeout, endtime))
226                                 return -1;
227
228                         res = sys_select(fd+1, &r_fds, NULL, NULL, &timeout);
229                         if (res <= 0)
230                                 return -1;
231                 }
232
233                 ret = sys_read(fd,buffer + total,N - total);
234
235                 if (ret == 0) {
236                         DEBUG(10,("read_data: read of %d returned 0. Error = %s\n", (int)(N - total), strerror(errno) ));
237                         return 0;
238                 }
239
240                 if (ret == -1) {
241                         DEBUG(0,("read_data: read failure for %d. Error = %s\n", (int)(N - total), strerror(errno) ));
242                         return -1;
243                 }
244                 total += ret;
245         }
246         return (ssize_t)total;
247 }
248
249 /****************************************************************************
250  Write data to a fd with timeout.
251 ****************************************************************************/
252
253 ssize_t write_data_until(int fd,char *buffer,size_t N,
254                          const struct timeval *endtime)
255 {
256         size_t total=0;
257         ssize_t ret;
258
259         while (total < N) {
260
261                 if (endtime != NULL) {
262                         fd_set w_fds;
263                         struct timeval timeout;
264                         int res;
265
266                         FD_ZERO(&w_fds);
267                         FD_SET(fd, &w_fds);
268
269                         if (!timeout_until(&timeout, endtime))
270                                 return -1;
271
272                         res = sys_select(fd+1, NULL, &w_fds, NULL, &timeout);
273                         if (res <= 0)
274                                 return -1;
275                 }
276
277                 ret = sys_write(fd,buffer + total,N - total);
278
279                 if (ret == -1) {
280                         DEBUG(0,("write_data: write failure. Error = %s\n", strerror(errno) ));
281                         return -1;
282                 }
283                 if (ret == 0)
284                         return total;
285
286                 total += ret;
287         }
288         return (ssize_t)total;
289 }
290
291
292 /****************************************************************************
293  Open a socket of the specified type, port, and address for incoming data.
294 ****************************************************************************/
295 int open_socket_in( int type, int port, int dlevel, uint32_t socket_addr, BOOL rebind )
296 {
297         struct sockaddr_in sock;
298         int res;
299
300         memset( (char *)&sock, '\0', sizeof(sock) );
301
302 #ifdef HAVE_SOCK_SIN_LEN
303         sock.sin_len         = sizeof(sock);
304 #endif
305         sock.sin_port        = htons( port );
306         sock.sin_family      = AF_INET;
307         sock.sin_addr.s_addr = socket_addr;
308
309         res = socket( AF_INET, type, 0 );
310         if( res == -1 ) {
311                 DEBUG(0,("open_socket_in(): socket() call failed: %s\n", strerror(errno)));
312                 return -1;
313         }
314
315         /* This block sets/clears the SO_REUSEADDR and possibly SO_REUSEPORT. */
316         {
317                 int val = rebind ? 1 : 0;
318                 setsockopt(res,SOL_SOCKET,SO_REUSEADDR,(char *)&val,sizeof(val));
319 #ifdef SO_REUSEPORT
320                 setsockopt(res,SOL_SOCKET,SO_REUSEPORT,(char *)&val,sizeof(val));
321 #endif
322         }
323
324         /* now we've got a socket - we need to bind it */
325         if( bind( res, (struct sockaddr *)&sock, sizeof(sock) ) == -1 ) {
326                 DEBUG(0,("bind failed on port %d - %s\n", port, strerror(errno)));
327                 close( res ); 
328                 return( -1 ); 
329         }
330
331         DEBUG( 10, ( "bind succeeded on port %d\n", port ) );
332
333         return( res );
334  }
335
336
337 /****************************************************************************
338   create an outgoing socket. timeout is in milliseconds.
339   **************************************************************************/
340 int open_socket_out(int type, struct in_addr *addr, int port, int timeout)
341 {
342         struct sockaddr_in sock_out;
343         int res,ret;
344         int connect_loop = 250; /* 250 milliseconds */
345         int loops = (timeout) / connect_loop;
346
347         /* create a socket to write to */
348         res = socket(PF_INET, type, 0);
349         if (res == -1) 
350         { DEBUG(0,("socket error\n")); return -1; }
351         
352         if (type != SOCK_STREAM) return(res);
353         
354         memset((char *)&sock_out,'\0',sizeof(sock_out));
355         putip((char *)&sock_out.sin_addr,(char *)addr);
356         
357         sock_out.sin_port = htons( port );
358         sock_out.sin_family = PF_INET;
359         
360         /* set it non-blocking */
361         set_blocking(res,False);
362         
363         DEBUG(3,("Connecting to %s at port %d\n",inet_ntoa(*addr),port));
364         
365         /* and connect it to the destination */
366 connect_again:
367         ret = connect(res,(struct sockaddr *)&sock_out,sizeof(sock_out));
368         
369         /* Some systems return EAGAIN when they mean EINPROGRESS */
370         if (ret < 0 && (errno == EINPROGRESS || errno == EALREADY ||
371                         errno == EAGAIN) && loops--) {
372                 msleep(connect_loop);
373                 goto connect_again;
374         }
375         
376         if (ret < 0 && (errno == EINPROGRESS || errno == EALREADY ||
377                         errno == EAGAIN)) {
378                 DEBUG(1,("timeout connecting to %s:%d\n",inet_ntoa(*addr),port));
379                 close(res);
380                 return -1;
381         }
382         
383 #ifdef EISCONN
384         if (ret < 0 && errno == EISCONN) {
385                 errno = 0;
386                 ret = 0;
387         }
388 #endif
389         
390         if (ret < 0) {
391                 DEBUG(2,("error connecting to %s:%d (%s)\n",
392                          inet_ntoa(*addr),port,strerror(errno)));
393                 close(res);
394                 return -1;
395         }
396         
397         /* set it blocking again */
398         set_blocking(res,True);
399         
400         return res;
401 }
402
403 /*
404   open a connected UDP socket to host on port
405 */
406 int open_udp_socket(const char *host, int port)
407 {
408         int type = SOCK_DGRAM;
409         struct sockaddr_in sock_out;
410         int res;
411         struct in_addr addr;
412         TALLOC_CTX *mem_ctx;
413
414         mem_ctx = talloc_init("open_udp_socket");
415         if (!mem_ctx) {
416                 return -1;
417         }
418         addr = interpret_addr2(host);
419
420         res = socket(PF_INET, type, 0);
421         if (res == -1) {
422                 return -1;
423         }
424
425         memset((char *)&sock_out,'\0',sizeof(sock_out));
426         putip((char *)&sock_out.sin_addr,(char *)&addr);
427         sock_out.sin_port = htons(port);
428         sock_out.sin_family = PF_INET;
429         
430         talloc_destroy(mem_ctx);
431
432         if (connect(res,(struct sockaddr *)&sock_out,sizeof(sock_out))) {
433                 close(res);
434                 return -1;
435         }
436
437         return res;
438 }
439
440
441 /*******************************************************************
442  matchname - determine if host name matches IP address. Used to
443  confirm a hostname lookup to prevent spoof attacks
444  ******************************************************************/
445 static BOOL matchname(char *remotehost, struct in_addr addr)
446 {
447         struct hostent *hp;
448         int     i;
449         
450         if ((hp = sys_gethostbyname(remotehost)) == 0) {
451                 DEBUG(0,("sys_gethostbyname(%s): lookup failure.\n", remotehost));
452                 return False;
453         } 
454
455         /*
456          * Make sure that gethostbyname() returns the "correct" host name.
457          * Unfortunately, gethostbyname("localhost") sometimes yields
458          * "localhost.domain". Since the latter host name comes from the
459          * local DNS, we just have to trust it (all bets are off if the local
460          * DNS is perverted). We always check the address list, though.
461          */
462         
463         if (strcasecmp(remotehost, hp->h_name)
464             && strcasecmp(remotehost, "localhost")) {
465                 DEBUG(0,("host name/name mismatch: %s != %s\n",
466                          remotehost, hp->h_name));
467                 return False;
468         }
469         
470         /* Look up the host address in the address list we just got. */
471         for (i = 0; hp->h_addr_list[i]; i++) {
472                 if (memcmp(hp->h_addr_list[i], (char *) & addr, sizeof(addr)) == 0)
473                         return True;
474         }
475         
476         /*
477          * The host name does not map to the original host address. Perhaps
478          * someone has compromised a name server. More likely someone botched
479          * it, but that could be dangerous, too.
480          */
481         
482         DEBUG(0,("host name/address mismatch: %s != %s\n",
483                  inet_ntoa(addr), hp->h_name));
484         return False;
485 }
486
487  
488 /*******************************************************************
489  return the DNS name of the remote end of a socket
490  ******************************************************************/
491 char *get_socket_name(TALLOC_CTX *mem_ctx, int fd, BOOL force_lookup)
492 {
493         char *name_buf;
494         struct hostent *hp;
495         struct in_addr addr;
496         char *p;
497
498         /* reverse lookups can be *very* expensive, and in many
499            situations won't work because many networks don't link dhcp
500            with dns. To avoid the delay we avoid the lookup if
501            possible */
502         if (!lp_hostname_lookups() && (force_lookup == False)) {
503                 return get_socket_addr(mem_ctx, fd);
504         }
505         
506         p = get_socket_addr(mem_ctx, fd);
507
508         name_buf = talloc_strdup(mem_ctx, "UNKNOWN");
509         if (fd == -1) return name_buf;
510
511         addr = interpret_addr2(p);
512         
513         /* Look up the remote host name. */
514         if ((hp = gethostbyaddr((char *)&addr.s_addr, sizeof(addr.s_addr), AF_INET)) == 0) {
515                 DEBUG(1,("Gethostbyaddr failed for %s\n",p));
516                 name_buf = talloc_strdup(mem_ctx, p);
517         } else {
518                 name_buf = talloc_strdup(mem_ctx, (char *)hp->h_name);
519                 if (!matchname(name_buf, addr)) {
520                         DEBUG(0,("Matchname failed on %s %s\n",name_buf,p));
521                         name_buf = talloc_strdup(mem_ctx, "UNKNOWN");
522                 }
523         }
524
525         alpha_strcpy(name_buf, name_buf, "_-.", strlen(name_buf)+1);
526         if (strstr(name_buf,"..")) {
527                 name_buf = talloc_strdup(mem_ctx, "UNKNOWN");
528         }
529
530         return name_buf;
531 }
532
533 /*******************************************************************
534  return the IP addr of the remote end of a socket as a string 
535  ******************************************************************/
536 char *get_socket_addr(TALLOC_CTX *mem_ctx, int fd)
537 {
538         struct sockaddr sa;
539         struct sockaddr_in *sockin = (struct sockaddr_in *) (&sa);
540         int     length = sizeof(sa);
541
542         if (fd == -1 || getpeername(fd, &sa, &length) == -1) {
543                 return talloc_strdup(mem_ctx, "0.0.0.0");
544         }
545         
546         return talloc_strdup(mem_ctx, (char *)inet_ntoa(sockin->sin_addr));
547 }
548
549
550
551 /*******************************************************************
552 this is like socketpair but uses tcp. It is used by the Samba
553 regression test code
554 The function guarantees that nobody else can attach to the socket,
555 or if they do that this function fails and the socket gets closed
556 returns 0 on success, -1 on failure
557 the resulting file descriptors are symmetrical
558  ******************************************************************/
559 static int socketpair_tcp(int fd[2])
560 {
561         int listener;
562         struct sockaddr_in sock;
563         struct sockaddr_in sock2;
564         socklen_t socklen = sizeof(sock);
565         int connect_done = 0;
566         
567         fd[0] = fd[1] = listener = -1;
568
569         memset(&sock, 0, sizeof(sock));
570         
571         if ((listener = socket(PF_INET, SOCK_STREAM, 0)) == -1) goto failed;
572
573         memset(&sock2, 0, sizeof(sock2));
574 #ifdef HAVE_SOCK_SIN_LEN
575         sock2.sin_len = sizeof(sock2);
576 #endif
577         sock2.sin_family = PF_INET;
578
579         bind(listener, (struct sockaddr *)&sock2, sizeof(sock2));
580
581         if (listen(listener, 1) != 0) goto failed;
582
583         if (getsockname(listener, (struct sockaddr *)&sock, &socklen) != 0) goto failed;
584
585         if ((fd[1] = socket(PF_INET, SOCK_STREAM, 0)) == -1) goto failed;
586
587         set_blocking(fd[1], 0);
588
589         sock.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
590
591         if (connect(fd[1],(struct sockaddr *)&sock,sizeof(sock)) == -1) {
592                 if (errno != EINPROGRESS) goto failed;
593         } else {
594                 connect_done = 1;
595         }
596
597         if ((fd[0] = accept(listener, (struct sockaddr *)&sock, &socklen)) == -1) goto failed;
598
599         close(listener);
600         if (connect_done == 0) {
601                 if (connect(fd[1],(struct sockaddr *)&sock,sizeof(sock)) != 0
602                     && errno != EISCONN) goto failed;
603         }
604
605         set_blocking(fd[1], 1);
606
607         /* all OK! */
608         return 0;
609
610  failed:
611         if (fd[0] != -1) close(fd[0]);
612         if (fd[1] != -1) close(fd[1]);
613         if (listener != -1) close(listener);
614         return -1;
615 }
616
617
618 /*******************************************************************
619 run a program on a local tcp socket, this is used to launch smbd
620 when regression testing
621 the return value is a socket which is attached to a subprocess
622 running "prog". stdin and stdout are attached. stderr is left
623 attached to the original stderr
624  ******************************************************************/
625 int sock_exec(const char *prog)
626 {
627         int fd[2];
628         if (socketpair_tcp(fd) != 0) {
629                 DEBUG(0,("socketpair_tcp failed (%s)\n", strerror(errno)));
630                 return -1;
631         }
632         if (fork() == 0) {
633                 close(fd[0]);
634                 close(0);
635                 close(1);
636                 dup(fd[1]);
637                 dup(fd[1]);
638                 exit(system(prog));
639         }
640         close(fd[1]);
641         return fd[0];
642 }
643
644
645 /*
646   determine if a packet is pending for receive on a socket
647 */
648 BOOL socket_pending(int fd)
649 {
650         fd_set fds;
651         int selrtn;
652         struct timeval timeout;
653
654         FD_ZERO(&fds);
655         FD_SET(fd,&fds);
656         
657         /* immediate timeout */
658         timeout.tv_sec = 0;
659         timeout.tv_usec = 0;
660
661         /* yes, this is supposed to be a normal select not a sys_select() */
662         selrtn = select(fd+1,&fds,NULL,NULL,&timeout);
663                 
664         if (selrtn == 1) {
665                 /* the fd is readable */
666                 return True;
667         }
668
669         return False;
670 }