Fix sendto_kdc.c on OS X after cm refactoring
[asn/mit-krb5.git] / src / lib / krb5 / os / sendto_kdc.c
1 /* -*- mode: c; c-basic-offset: 4; indent-tabs-mode: nil -*- */
2 /* lib/krb5/os/sendto_kdc.c */
3 /*
4  * Copyright 1990,1991,2001,2002,2004,2005,2007,2008 by the Massachusetts Institute of Technology.
5  * All Rights Reserved.
6  *
7  * Export of this software from the United States of America may
8  *   require a specific license from the United States Government.
9  *   It is the responsibility of any person or organization contemplating
10  *   export to obtain such a license before exporting.
11  *
12  * WITHIN THAT CONSTRAINT, permission to use, copy, modify, and
13  * distribute this software and its documentation for any purpose and
14  * without fee is hereby granted, provided that the above copyright
15  * notice appear in all copies and that both that copyright notice and
16  * this permission notice appear in supporting documentation, and that
17  * the name of M.I.T. not be used in advertising or publicity pertaining
18  * to distribution of the software without specific, written prior
19  * permission.  Furthermore if you modify this software you must label
20  * your software as modified software and not distribute it in such a
21  * fashion that it might be confused with the original M.I.T. software.
22  * M.I.T. makes no representations about the suitability of
23  * this software for any purpose.  It is provided "as is" without express
24  * or implied warranty.
25  */
26
27 /* Send packet to KDC for realm; wait for response, retransmitting
28  * as necessary. */
29
30 #include "fake-addrinfo.h"
31 #include "k5-int.h"
32
33 #include "os-proto.h"
34
35 #if defined(HAVE_POLL_H)
36 #include <poll.h>
37 #define USE_POLL
38 #define MAX_POLLFDS 1024
39 #elif defined(HAVE_SYS_SELECT_H)
40 #include <sys/select.h>
41 #endif
42
43 #ifndef _WIN32
44 /* For FIONBIO.  */
45 #include <sys/ioctl.h>
46 #ifdef HAVE_SYS_FILIO_H
47 #include <sys/filio.h>
48 #endif
49 #endif
50
51 #define MAX_PASS                    3
52 #define DEFAULT_UDP_PREF_LIMIT   1465
53 #define HARD_UDP_LIMIT          32700 /* could probably do 64K-epsilon ? */
54
55 /* Select state flags.  */
56 #define SSF_READ 0x01
57 #define SSF_WRITE 0x02
58 #define SSF_EXCEPTION 0x04
59
60 typedef int64_t time_ms;
61
62 /* This can be pretty large, so should not be stack-allocated. */
63 struct select_state {
64 #ifdef USE_POLL
65     struct pollfd fds[MAX_POLLFDS];
66 #else
67     int max;
68     fd_set rfds, wfds, xfds;
69 #endif
70     int nfds;
71 };
72
73 static const char *const state_strings[] = {
74     "INITIALIZING", "CONNECTING", "WRITING", "READING", "FAILED"
75 };
76
77 /* connection states */
78 enum conn_states { INITIALIZING, CONNECTING, WRITING, READING, FAILED };
79 struct incoming_krb5_message {
80     size_t bufsizebytes_read;
81     size_t bufsize;
82     char *buf;
83     char *pos;
84     unsigned char bufsizebytes[4];
85     size_t n_left;
86 };
87
88 struct conn_state {
89     SOCKET fd;
90     enum conn_states state;
91     int (*service)(krb5_context context, struct conn_state *,
92                    struct select_state *, int);
93     struct remote_address addr;
94     struct {
95         struct {
96             sg_buf sgbuf[2];
97             sg_buf *sgp;
98             int sg_count;
99             unsigned char msg_len_buf[4];
100         } out;
101         struct incoming_krb5_message in;
102     } x;
103     krb5_data callback_buffer;
104     size_t server_index;
105     struct conn_state *next;
106     time_ms endtime;
107 };
108
109 /* Get current time in milliseconds. */
110 static krb5_error_code
111 get_curtime_ms(time_ms *time_out)
112 {
113     struct timeval tv;
114
115     if (gettimeofday(&tv, 0))
116         return errno;
117     *time_out = (time_ms)tv.tv_sec * 1000 + tv.tv_usec / 1000;
118     return 0;
119 }
120
121 #ifdef USE_POLL
122
123 /* Find a pollfd in selstate by fd, or abort if we can't find it. */
124 static inline struct pollfd *
125 find_pollfd(struct select_state *selstate, int fd)
126 {
127     int i;
128
129     for (i = 0; i < selstate->nfds; i++) {
130         if (selstate->fds[i].fd == fd)
131             return &selstate->fds[i];
132     }
133     abort();
134 }
135
136 static void
137 cm_init_selstate(struct select_state *selstate)
138 {
139     selstate->nfds = 0;
140 }
141
142 static krb5_boolean
143 cm_add_fd(struct select_state *selstate, int fd)
144 {
145     if (selstate->nfds >= MAX_POLLFDS)
146         return FALSE;
147     selstate->fds[selstate->nfds].fd = fd;
148     selstate->fds[selstate->nfds].events = 0;
149     selstate->nfds++;
150     return TRUE;
151 }
152
153 static void
154 cm_remove_fd(struct select_state *selstate, int fd)
155 {
156     struct pollfd *pfd = find_pollfd(selstate, fd);
157
158     *pfd = selstate->fds[selstate->nfds - 1];
159     selstate->nfds--;
160 }
161
162 /* Poll for reading (and not writing) on fd the next time we poll. */
163 static void
164 cm_read(struct select_state *selstate, int fd)
165 {
166     find_pollfd(selstate, fd)->events = POLLIN;
167 }
168
169 /* Poll for writing (and not reading) on fd the next time we poll. */
170 static void
171 cm_write(struct select_state *selstate, int fd)
172 {
173     find_pollfd(selstate, fd)->events = POLLOUT;
174 }
175
176 /* Get the output events for fd in the form of ssflags. */
177 static unsigned int
178 cm_get_ssflags(struct select_state *selstate, int fd)
179 {
180     struct pollfd *pfd = find_pollfd(selstate, fd);
181
182     /*
183      * OS X sets POLLHUP without POLLOUT on connection error.  Catch this as
184      * well as other error events such as POLLNVAL, but only if POLLIN and
185      * POLLOUT aren't set, as we can get POLLHUP along with POLLIN with TCP
186      * data still to be read.
187      */
188     if (pfd->revents != 0 && !(pfd->revents & (POLLIN | POLLOUT)))
189         return SSF_EXCEPTION;
190
191     return ((pfd->revents & POLLIN) ? SSF_READ : 0) |
192         ((pfd->revents & POLLOUT) ? SSF_WRITE : 0) |
193         ((pfd->revents & POLLERR) ? SSF_EXCEPTION : 0);
194 }
195
196 #else /* not USE_POLL */
197
198 static void
199 cm_init_selstate(struct select_state *selstate)
200 {
201     selstate->nfds = 0;
202     selstate->max = 0;
203     FD_ZERO(&selstate->rfds);
204     FD_ZERO(&selstate->wfds);
205     FD_ZERO(&selstate->xfds);
206 }
207
208 static krb5_boolean
209 cm_add_fd(struct select_state *selstate, int fd)
210 {
211 #ifndef _WIN32  /* On Windows FD_SETSIZE is a count, not a max value. */
212     if (fd >= FD_SETSIZE)
213         return FALSE;
214 #endif
215     FD_SET(fd, &selstate->xfds);
216     if (selstate->max <= fd)
217         selstate->max = fd + 1;
218     selstate->nfds++;
219     return TRUE;
220 }
221
222 static void
223 cm_remove_fd(struct select_state *selstate, int fd)
224 {
225     FD_CLR(fd, &selstate->rfds);
226     FD_CLR(fd, &selstate->wfds);
227     FD_CLR(fd, &selstate->xfds);
228     if (selstate->max == fd + 1) {
229         while (selstate->max > 0 &&
230                !FD_ISSET(selstate->max - 1, &selstate->rfds) &&
231                !FD_ISSET(selstate->max - 1, &selstate->wfds) &&
232                !FD_ISSET(selstate->max - 1, &selstate->xfds))
233             selstate->max--;
234     }
235     selstate->nfds--;
236 }
237
238 /* Select for reading (and not writing) on fd the next time we select. */
239 static void
240 cm_read(struct select_state *selstate, int fd)
241 {
242     FD_SET(fd, &selstate->rfds);
243     FD_CLR(fd, &selstate->wfds);
244 }
245
246 /* Select for writing (and not reading) on fd the next time we select. */
247 static void
248 cm_write(struct select_state *selstate, int fd)
249 {
250     FD_CLR(fd, &selstate->rfds);
251     FD_SET(fd, &selstate->wfds);
252 }
253
254 /* Get the events for fd from selstate after a select. */
255 static unsigned int
256 cm_get_ssflags(struct select_state *selstate, int fd)
257 {
258     return (FD_ISSET(fd, &selstate->rfds) ? SSF_READ : 0) |
259         (FD_ISSET(fd, &selstate->wfds) ? SSF_WRITE : 0) |
260         (FD_ISSET(fd, &selstate->xfds) ? SSF_EXCEPTION : 0);
261 }
262
263 #endif /* not USE_POLL */
264
265 static krb5_error_code
266 cm_select_or_poll(const struct select_state *in, time_ms endtime,
267                   struct select_state *out, int *sret)
268 {
269 #ifndef USE_POLL
270     struct timeval tv;
271 #endif
272     krb5_error_code retval;
273     time_ms curtime, interval;
274
275     retval = get_curtime_ms(&curtime);
276     if (retval != 0)
277         return retval;
278     interval = (curtime < endtime) ? endtime - curtime : 0;
279
280     /* We don't need a separate copy of the selstate for poll, but use one for
281      * consistency with how we use select. */
282     *out = *in;
283
284 #ifdef USE_POLL
285     *sret = poll(out->fds, out->nfds, interval);
286 #else
287     tv.tv_sec = interval / 1000;
288     tv.tv_usec = interval % 1000 * 1000;
289     *sret = select(out->max, &out->rfds, &out->wfds, &out->xfds, &tv);
290 #endif
291
292     return (*sret < 0) ? SOCKET_ERRNO : 0;
293 }
294
295 static int
296 in_addrlist(struct server_entry *entry, struct serverlist *list)
297 {
298     size_t i;
299     struct server_entry *le;
300
301     for (i = 0; i < list->nservers; i++) {
302         le = &list->servers[i];
303         if (entry->hostname != NULL && le->hostname != NULL &&
304             strcmp(entry->hostname, le->hostname) == 0)
305             return 1;
306         if (entry->hostname == NULL && le->hostname == NULL &&
307             entry->addrlen == le->addrlen &&
308             memcmp(&entry->addr, &le->addr, entry->addrlen) == 0)
309             return 1;
310     }
311     return 0;
312 }
313
314 static int
315 check_for_svc_unavailable (krb5_context context,
316                            const krb5_data *reply,
317                            void *msg_handler_data)
318 {
319     krb5_error_code *retval = (krb5_error_code *)msg_handler_data;
320
321     *retval = 0;
322
323     if (krb5_is_krb_error(reply)) {
324         krb5_error *err_reply;
325
326         if (decode_krb5_error(reply, &err_reply) == 0) {
327             *retval = err_reply->error;
328             krb5_free_error(context, err_reply);
329
330             /* Returning 0 means continue to next KDC */
331             return (*retval != KDC_ERR_SVC_UNAVAILABLE);
332         }
333     }
334
335     return 1;
336 }
337
338 /*
339  * send the formatted request 'message' to a KDC for realm 'realm' and
340  * return the response (if any) in 'reply'.
341  *
342  * If the message is sent and a response is received, 0 is returned,
343  * otherwise an error code is returned.
344  *
345  * The storage for 'reply' is allocated and should be freed by the caller
346  * when finished.
347  */
348
349 krb5_error_code
350 krb5_sendto_kdc(krb5_context context, const krb5_data *message,
351                 const krb5_data *realm, krb5_data *reply, int *use_master,
352                 int tcp_only)
353 {
354     krb5_error_code retval, err;
355     struct serverlist servers;
356     int socktype1 = 0, socktype2 = 0, server_used;
357
358     /*
359      * find KDC location(s) for realm
360      */
361
362     /*
363      * BUG: This code won't return "interesting" errors (e.g., out of mem,
364      * bad config file) from locate_kdc.  KRB5_REALM_CANT_RESOLVE can be
365      * ignored from one query of two, but if only one query is done, or
366      * both return that error, it should be returned to the caller.  Also,
367      * "interesting" errors (not KRB5_KDC_UNREACH) from sendto_{udp,tcp}
368      * should probably be returned as well.
369      */
370
371     TRACE_SENDTO_KDC(context, message->length, realm, *use_master, tcp_only);
372
373     if (!tcp_only && context->udp_pref_limit < 0) {
374         int tmp;
375         retval = profile_get_integer(context->profile,
376                                      KRB5_CONF_LIBDEFAULTS, KRB5_CONF_UDP_PREFERENCE_LIMIT, 0,
377                                      DEFAULT_UDP_PREF_LIMIT, &tmp);
378         if (retval)
379             return retval;
380         if (tmp < 0)
381             tmp = DEFAULT_UDP_PREF_LIMIT;
382         else if (tmp > HARD_UDP_LIMIT)
383             /* In the unlikely case that a *really* big value is
384                given, let 'em use as big as we think we can
385                support.  */
386             tmp = HARD_UDP_LIMIT;
387         context->udp_pref_limit = tmp;
388     }
389
390     if (tcp_only)
391         socktype1 = SOCK_STREAM, socktype2 = 0;
392     else if (message->length <= (unsigned int) context->udp_pref_limit)
393         socktype1 = SOCK_DGRAM, socktype2 = SOCK_STREAM;
394     else
395         socktype1 = SOCK_STREAM, socktype2 = SOCK_DGRAM;
396
397     retval = k5_locate_kdc(context, realm, &servers, *use_master,
398                            tcp_only ? SOCK_STREAM : 0);
399     if (retval)
400         return retval;
401
402     err = 0;
403     retval = k5_sendto(context, message, &servers, socktype1, socktype2,
404                        NULL, reply, NULL, NULL, &server_used,
405                        check_for_svc_unavailable, &err);
406     if (retval == KRB5_KDC_UNREACH) {
407         if (err == KDC_ERR_SVC_UNAVAILABLE) {
408             retval = KRB5KDC_ERR_SVC_UNAVAILABLE;
409         } else {
410             krb5_set_error_message(context, retval,
411                                    _("Cannot contact any KDC for realm "
412                                      "'%.*s'"), realm->length, realm->data);
413         }
414     }
415     if (retval)
416         goto cleanup;
417
418     /* Set use_master to 1 if we ended up talking to a master when we didn't
419      * explicitly request to. */
420     if (*use_master == 0) {
421         struct serverlist mservers;
422         struct server_entry *entry = &servers.servers[server_used];
423         retval = k5_locate_kdc(context, realm, &mservers, TRUE,
424                                entry->socktype);
425         if (retval == 0) {
426             if (in_addrlist(entry, &mservers))
427                 *use_master = 1;
428             k5_free_serverlist(&mservers);
429         }
430         TRACE_SENDTO_KDC_MASTER(context, *use_master);
431         retval = 0;
432     }
433
434 cleanup:
435     k5_free_serverlist(&servers);
436     return retval;
437 }
438
439 /*
440  * Notes:
441  *
442  * Getting "connection refused" on a connected UDP socket causes
443  * select to indicate write capability on UNIX, but only shows up
444  * as an exception on Windows.  (I don't think any UNIX system flags
445  * the error as an exception.)  So we check for both, or make it
446  * system-specific.
447  *
448  * Always watch for responses from *any* of the servers.  Eventually
449  * fix the UDP code to do the same.
450  *
451  * To do:
452  * - TCP NOPUSH/CORK socket options?
453  * - error codes that don't suck
454  * - getsockopt(SO_ERROR) to check connect status
455  * - handle error RESPONSE_TOO_BIG from UDP server and use TCP
456  *   connections already in progress
457  */
458
459 static int service_tcp_fd(krb5_context context, struct conn_state *conn,
460                           struct select_state *selstate, int ssflags);
461 static int service_udp_fd(krb5_context context, struct conn_state *conn,
462                           struct select_state *selstate, int ssflags);
463
464 static void
465 set_conn_state_msg_length (struct conn_state *state, const krb5_data *message)
466 {
467     if (!message || message->length == 0)
468         return;
469
470     if (state->addr.type == SOCK_STREAM) {
471         store_32_be(message->length, state->x.out.msg_len_buf);
472         SG_SET(&state->x.out.sgbuf[0], state->x.out.msg_len_buf, 4);
473         SG_SET(&state->x.out.sgbuf[1], message->data, message->length);
474         state->x.out.sg_count = 2;
475
476     } else {
477
478         SG_SET(&state->x.out.sgbuf[0], message->data, message->length);
479         SG_SET(&state->x.out.sgbuf[1], 0, 0);
480         state->x.out.sg_count = 1;
481
482     }
483 }
484
485 static krb5_error_code
486 add_connection(struct conn_state **conns, struct addrinfo *ai,
487                size_t server_index, const krb5_data *message, char **udpbufp)
488 {
489     struct conn_state *state, **tailptr;
490
491     state = calloc(1, sizeof(*state));
492     if (state == NULL)
493         return ENOMEM;
494     state->state = INITIALIZING;
495     state->x.out.sgp = state->x.out.sgbuf;
496     state->addr.type = ai->ai_socktype;
497     state->addr.family = ai->ai_family;
498     state->addr.len = ai->ai_addrlen;
499     memcpy(&state->addr.saddr, ai->ai_addr, ai->ai_addrlen);
500     state->fd = INVALID_SOCKET;
501     state->server_index = server_index;
502     SG_SET(&state->x.out.sgbuf[1], 0, 0);
503     if (ai->ai_socktype == SOCK_STREAM) {
504         state->service = service_tcp_fd;
505         set_conn_state_msg_length (state, message);
506     } else {
507         state->service = service_udp_fd;
508         set_conn_state_msg_length (state, message);
509
510         if (*udpbufp == NULL) {
511             *udpbufp = malloc(MAX_DGRAM_SIZE);
512             if (*udpbufp == 0)
513                 return ENOMEM;
514         }
515         state->x.in.buf = *udpbufp;
516         state->x.in.bufsize = MAX_DGRAM_SIZE;
517     }
518
519     /* Chain the new state onto the tail of the list. */
520     for (tailptr = conns; *tailptr != NULL; tailptr = &(*tailptr)->next);
521     *tailptr = state;
522
523     return 0;
524 }
525
526 static int
527 translate_ai_error (int err)
528 {
529     switch (err) {
530     case 0:
531         return 0;
532     case EAI_BADFLAGS:
533     case EAI_FAMILY:
534     case EAI_SOCKTYPE:
535     case EAI_SERVICE:
536         /* All of these indicate bad inputs to getaddrinfo.  */
537         return EINVAL;
538     case EAI_AGAIN:
539         /* Translate to standard errno code.  */
540         return EAGAIN;
541     case EAI_MEMORY:
542         /* Translate to standard errno code.  */
543         return ENOMEM;
544 #ifdef EAI_ADDRFAMILY
545     case EAI_ADDRFAMILY:
546 #endif
547 #if defined(EAI_NODATA) && EAI_NODATA != EAI_NONAME
548     case EAI_NODATA:
549 #endif
550     case EAI_NONAME:
551         /* Name not known or no address data, but no error.  Do
552            nothing more.  */
553         return 0;
554 #ifdef EAI_OVERFLOW
555     case EAI_OVERFLOW:
556         /* An argument buffer overflowed.  */
557         return EINVAL;          /* XXX */
558 #endif
559 #ifdef EAI_SYSTEM
560     case EAI_SYSTEM:
561         /* System error, obviously.  */
562         return errno;
563 #endif
564     default:
565         /* An error code we haven't handled?  */
566         return EINVAL;
567     }
568 }
569
570 /*
571  * Resolve the entry in servers with index ind, adding connections to the list
572  * *conns.  Connections are added for each of socktype1 and (if not zero)
573  * socktype2.  message and udpbufp are used to initialize the connections; see
574  * add_connection above.  If no addresses are available for an entry but no
575  * internal name resolution failure occurs, return 0 without adding any new
576  * connections.
577  */
578 static krb5_error_code
579 resolve_server(krb5_context context, const struct serverlist *servers,
580                size_t ind, int socktype1, int socktype2,
581                const krb5_data *message, char **udpbufp,
582                struct conn_state **conns)
583 {
584     krb5_error_code retval;
585     struct server_entry *entry = &servers->servers[ind];
586     struct addrinfo *addrs, *a, hint, ai;
587     int err, result;
588     char portbuf[64];
589
590     /* Skip any stray entries of socktypes we don't want. */
591     if (entry->socktype != 0 && entry->socktype != socktype1 &&
592         entry->socktype != socktype2)
593         return 0;
594
595     if (entry->hostname == NULL) {
596         ai.ai_socktype = entry->socktype;
597         ai.ai_family = entry->family;
598         ai.ai_addrlen = entry->addrlen;
599         ai.ai_addr = (struct sockaddr *)&entry->addr;
600         return add_connection(conns, &ai, ind, message, udpbufp);
601     }
602
603     memset(&hint, 0, sizeof(hint));
604     hint.ai_family = entry->family;
605     hint.ai_socktype = (entry->socktype != 0) ? entry->socktype : socktype1;
606     hint.ai_flags = AI_ADDRCONFIG;
607 #ifdef AI_NUMERICSERV
608     hint.ai_flags |= AI_NUMERICSERV;
609 #endif
610     result = snprintf(portbuf, sizeof(portbuf), "%d", ntohs(entry->port));
611     if (SNPRINTF_OVERFLOW(result, sizeof(portbuf)))
612         return EINVAL;
613     TRACE_SENDTO_KDC_RESOLVING(context, entry->hostname);
614     err = getaddrinfo(entry->hostname, portbuf, &hint, &addrs);
615     if (err)
616         return translate_ai_error(err);
617     /* Add each address with the preferred socktype. */
618     retval = 0;
619     for (a = addrs; a != 0 && retval == 0; a = a->ai_next)
620         retval = add_connection(conns, a, ind, message, udpbufp);
621     if (retval == 0 && entry->socktype == 0 && socktype2 != 0) {
622         /* Add each address again with the non-preferred socktype. */
623         for (a = addrs; a != 0 && retval == 0; a = a->ai_next) {
624             a->ai_socktype = socktype2;
625             retval = add_connection(conns, a, ind, message, udpbufp);
626         }
627     }
628     freeaddrinfo(addrs);
629     return retval;
630 }
631
632 static int
633 start_connection(krb5_context context, struct conn_state *state,
634                  struct select_state *selstate,
635                  struct sendto_callback_info *callback_info)
636 {
637     int fd, e;
638     static const int one = 1;
639     static const struct linger lopt = { 0, 0 };
640
641     fd = socket(state->addr.family, state->addr.type, 0);
642     if (fd == INVALID_SOCKET)
643         return -1;              /* try other hosts */
644     set_cloexec_fd(fd);
645     /* Make it non-blocking.  */
646     ioctlsocket(fd, FIONBIO, (const void *) &one);
647     if (state->addr.type == SOCK_STREAM) {
648         setsockopt(fd, SOL_SOCKET, SO_LINGER, &lopt, sizeof(lopt));
649         TRACE_SENDTO_KDC_TCP_CONNECT(context, &state->addr);
650     }
651
652     /* Start connecting to KDC.  */
653     e = connect(fd, (struct sockaddr *)&state->addr.saddr, state->addr.len);
654     if (e != 0) {
655         /*
656          * This is the path that should be followed for non-blocking
657          * connections.
658          */
659         if (SOCKET_ERRNO == EINPROGRESS || SOCKET_ERRNO == EWOULDBLOCK) {
660             state->state = CONNECTING;
661             state->fd = fd;
662         } else {
663             (void) closesocket(fd);
664             state->state = FAILED;
665             return -2;
666         }
667     } else {
668         /*
669          * Connect returned zero even though we made it non-blocking.  This
670          * happens normally for UDP sockets, and can perhaps also happen for
671          * TCP sockets connecting to localhost.
672          */
673         state->state = WRITING;
674         state->fd = fd;
675     }
676
677     /*
678      * Here's where KPASSWD callback gets the socket information it needs for
679      * a kpasswd request
680      */
681     if (callback_info) {
682
683         e = callback_info->pfn_callback(state->fd, callback_info->data,
684                                         &state->callback_buffer);
685         if (e != 0) {
686             (void) closesocket(fd);
687             state->fd = INVALID_SOCKET;
688             state->state = FAILED;
689             return -3;
690         }
691
692         set_conn_state_msg_length(state, &state->callback_buffer);
693     }
694
695     if (state->addr.type == SOCK_DGRAM) {
696         /* Send it now.  */
697         ssize_t ret;
698         sg_buf *sg = &state->x.out.sgbuf[0];
699
700         TRACE_SENDTO_KDC_UDP_SEND_INITIAL(context, &state->addr);
701         ret = send(state->fd, SG_BUF(sg), SG_LEN(sg), 0);
702         if (ret < 0 || (size_t) ret != SG_LEN(sg)) {
703             TRACE_SENDTO_KDC_UDP_ERROR_SEND_INITIAL(context, &state->addr,
704                                                     SOCKET_ERRNO);
705             (void) closesocket(state->fd);
706             state->fd = INVALID_SOCKET;
707             state->state = FAILED;
708             return -4;
709         } else {
710             state->state = READING;
711         }
712     }
713
714     if (!cm_add_fd(selstate, state->fd)) {
715         (void) closesocket(state->fd);
716         state->fd = INVALID_SOCKET;
717         state->state = FAILED;
718         return -1;
719     }
720     if (state->state == CONNECTING || state->state == WRITING)
721         cm_write(selstate, state->fd);
722     else
723         cm_read(selstate, state->fd);
724
725     return 0;
726 }
727
728 /* Return 0 if we sent something, non-0 otherwise.
729    If 0 is returned, the caller should delay waiting for a response.
730    Otherwise, the caller should immediately move on to process the
731    next connection.  */
732 static int
733 maybe_send(krb5_context context, struct conn_state *conn,
734            struct select_state *selstate,
735            struct sendto_callback_info *callback_info)
736 {
737     sg_buf *sg;
738     ssize_t ret;
739
740     if (conn->state == INITIALIZING)
741         return start_connection(context, conn, selstate, callback_info);
742
743     /* Did we already shut down this channel?  */
744     if (conn->state == FAILED) {
745         return -1;
746     }
747
748     if (conn->addr.type == SOCK_STREAM) {
749         /* The select callback will handle flushing any data we
750            haven't written yet, and we only write it once.  */
751         return -1;
752     }
753
754     /* UDP - retransmit after a previous attempt timed out. */
755     sg = &conn->x.out.sgbuf[0];
756     TRACE_SENDTO_KDC_UDP_SEND_RETRY(context, &conn->addr);
757     ret = send(conn->fd, SG_BUF(sg), SG_LEN(sg), 0);
758     if (ret < 0 || (size_t) ret != SG_LEN(sg)) {
759         TRACE_SENDTO_KDC_UDP_ERROR_SEND_RETRY(context, &conn->addr,
760                                               SOCKET_ERRNO);
761         /* Keep connection alive, we'll try again next pass.
762
763            Is this likely to catch any errors we didn't get from the
764            select callbacks?  */
765         return -1;
766     }
767     /* Yay, it worked.  */
768     return 0;
769 }
770
771 static void
772 kill_conn(struct conn_state *conn, struct select_state *selstate)
773 {
774     cm_remove_fd(selstate, conn->fd);
775     closesocket(conn->fd);
776     conn->fd = INVALID_SOCKET;
777     conn->state = FAILED;
778 }
779
780 /* Check socket for error.  */
781 static int
782 get_so_error(int fd)
783 {
784     int e, sockerr;
785     socklen_t sockerrlen;
786
787     sockerr = 0;
788     sockerrlen = sizeof(sockerr);
789     e = getsockopt(fd, SOL_SOCKET, SO_ERROR, &sockerr, &sockerrlen);
790     if (e != 0) {
791         /* What to do now?  */
792         e = SOCKET_ERRNO;
793         return e;
794     }
795     return sockerr;
796 }
797
798 /* Process events on a TCP socket.  Return 1 if we get a complete reply. */
799 static int
800 service_tcp_fd(krb5_context context, struct conn_state *conn,
801                struct select_state *selstate, int ssflags)
802 {
803     int e = 0;
804     ssize_t nwritten, nread;
805     SOCKET_WRITEV_TEMP tmp;
806
807     /* Check for a socket exception. */
808     if (ssflags & SSF_EXCEPTION)
809         goto kill_conn;
810
811     switch (conn->state) {
812     case CONNECTING:
813         /* Check whether the connection succeeded. */
814         e = get_so_error(conn->fd);
815         if (e) {
816             TRACE_SENDTO_KDC_TCP_ERROR_CONNECT(context, &conn->addr, e);
817             goto kill_conn;
818         }
819         conn->state = WRITING;
820
821         /* Record this connection's timeout for service_fds. */
822         if (get_curtime_ms(&conn->endtime) == 0)
823             conn->endtime += 10000;
824
825         /* Fall through. */
826     case WRITING:
827         TRACE_SENDTO_KDC_TCP_SEND(context, &conn->addr);
828         nwritten = SOCKET_WRITEV(conn->fd, conn->x.out.sgp,
829                                  conn->x.out.sg_count, tmp);
830         if (nwritten < 0) {
831             TRACE_SENDTO_KDC_TCP_ERROR_SEND(context, &conn->addr,
832                                             SOCKET_ERRNO);
833             goto kill_conn;
834         }
835         while (nwritten) {
836             sg_buf *sgp = conn->x.out.sgp;
837             if ((size_t) nwritten < SG_LEN(sgp)) {
838                 SG_ADVANCE(sgp, (size_t) nwritten);
839                 nwritten = 0;
840             } else {
841                 nwritten -= SG_LEN(sgp);
842                 conn->x.out.sgp++;
843                 conn->x.out.sg_count--;
844             }
845         }
846         if (conn->x.out.sg_count == 0) {
847             /* Done writing, switch to reading. */
848             cm_read(selstate, conn->fd);
849             conn->state = READING;
850             conn->x.in.bufsizebytes_read = 0;
851             conn->x.in.bufsize = 0;
852             conn->x.in.buf = 0;
853             conn->x.in.pos = 0;
854             conn->x.in.n_left = 0;
855         }
856         return 0;
857
858     case READING:
859         if (conn->x.in.bufsizebytes_read == 4) {
860             /* Reading data.  */
861             nread = SOCKET_READ(conn->fd, conn->x.in.pos, conn->x.in.n_left);
862             if (nread <= 0) {
863                 e = nread ? SOCKET_ERRNO : ECONNRESET;
864                 TRACE_SENDTO_KDC_TCP_ERROR_RECV(context, &conn->addr, e);
865                 goto kill_conn;
866             }
867             conn->x.in.n_left -= nread;
868             conn->x.in.pos += nread;
869             if (conn->x.in.n_left <= 0)
870                 return 1;
871         } else {
872             /* Reading length.  */
873             nread = SOCKET_READ(conn->fd,
874                                 conn->x.in.bufsizebytes + conn->x.in.bufsizebytes_read,
875                                 4 - conn->x.in.bufsizebytes_read);
876             if (nread <= 0) {
877                 e = nread ? SOCKET_ERRNO : ECONNRESET;
878                 TRACE_SENDTO_KDC_TCP_ERROR_RECV_LEN(context, &conn->addr, e);
879                 goto kill_conn;
880             }
881             conn->x.in.bufsizebytes_read += nread;
882             if (conn->x.in.bufsizebytes_read == 4) {
883                 unsigned long len = load_32_be (conn->x.in.bufsizebytes);
884                 /* Arbitrary 1M cap.  */
885                 if (len > 1 * 1024 * 1024)
886                     goto kill_conn;
887                 conn->x.in.bufsize = conn->x.in.n_left = len;
888                 conn->x.in.buf = conn->x.in.pos = malloc(len);
889                 if (conn->x.in.buf == 0)
890                     goto kill_conn;
891             }
892         }
893         break;
894
895     default:
896         abort();
897     }
898     return 0;
899
900 kill_conn:
901     TRACE_SENDTO_KDC_TCP_DISCONNECT(context, &conn->addr);
902     kill_conn(conn, selstate);
903     return 0;
904 }
905
906 /* Process events on a UDP socket.  Return 1 if we get a reply. */
907 static int
908 service_udp_fd(krb5_context context, struct conn_state *conn,
909                struct select_state *selstate, int ssflags)
910 {
911     int nread;
912
913     if (!(ssflags & (SSF_READ|SSF_EXCEPTION)))
914         abort();
915     if (conn->state != READING)
916         abort();
917
918     nread = recv(conn->fd, conn->x.in.buf, conn->x.in.bufsize, 0);
919     if (nread < 0) {
920         TRACE_SENDTO_KDC_UDP_ERROR_RECV(context, &conn->addr, SOCKET_ERRNO);
921         kill_conn(conn, selstate);
922         return 0;
923     }
924     conn->x.in.pos = conn->x.in.buf + nread;
925     return 1;
926 }
927
928 /* Return the maximum of endtime and the endtime fields of all currently active
929  * TCP connections. */
930 static time_ms
931 get_endtime(time_ms endtime, struct conn_state *conns)
932 {
933     struct conn_state *state;
934
935     for (state = conns; state != NULL; state = state->next) {
936         if (state->addr.type == SOCK_STREAM &&
937             (state->state == READING || state->state == WRITING) &&
938             state->endtime > endtime)
939             endtime = state->endtime;
940     }
941     return endtime;
942 }
943
944 static krb5_boolean
945 service_fds(krb5_context context, struct select_state *selstate,
946             time_ms interval, struct conn_state *conns,
947             struct select_state *seltemp,
948             int (*msg_handler)(krb5_context, const krb5_data *, void *),
949             void *msg_handler_data, struct conn_state **winner_out)
950 {
951     int e, selret = 0;
952     time_ms endtime;
953     struct conn_state *state;
954
955     *winner_out = NULL;
956
957     e = get_curtime_ms(&endtime);
958     if (e)
959         return 1;
960     endtime += interval;
961
962     e = 0;
963     while (selstate->nfds > 0) {
964         e = cm_select_or_poll(selstate, get_endtime(endtime, conns),
965                               seltemp, &selret);
966         if (e == EINTR)
967             continue;
968         if (e != 0)
969             break;
970
971         if (selret == 0)
972             /* Timeout, return to caller.  */
973             return 0;
974
975         /* Got something on a socket, process it.  */
976         for (state = conns; state != NULL; state = state->next) {
977             int ssflags;
978
979             if (state->fd == INVALID_SOCKET)
980                 continue;
981             ssflags = cm_get_ssflags(seltemp, state->fd);
982             if (!ssflags)
983                 continue;
984
985             if (state->service(context, state, selstate, ssflags)) {
986                 int stop = 1;
987
988                 if (msg_handler != NULL) {
989                     krb5_data reply;
990
991                     reply.data = state->x.in.buf;
992                     reply.length = state->x.in.pos - state->x.in.buf;
993
994                     stop = (msg_handler(context, &reply, msg_handler_data) != 0);
995                 }
996
997                 if (stop) {
998                     *winner_out = state;
999                     return 1;
1000                 }
1001             }
1002         }
1003     }
1004     if (e != 0)
1005         return 1;
1006     return 0;
1007 }
1008
1009 /*
1010  * Current worst-case timeout behavior:
1011  *
1012  * First pass, 1s per udp or tcp server, plus 2s at end.
1013  * Second pass, 1s per udp server, plus 4s.
1014  * Third pass, 1s per udp server, plus 8s.
1015  * Fourth => 16s, etc.
1016  *
1017  * Restated:
1018  * Per UDP server, 1s per pass.
1019  * Per TCP server, 1s.
1020  * Backoff delay, 2**(P+1) - 2, where P is total number of passes.
1021  *
1022  * Total = 2**(P+1) + U*P + T - 2.
1023  *
1024  * If P=3, Total = 3*U + T + 14.
1025  * If P=4, Total = 4*U + T + 30.
1026  *
1027  * Note that if you try to reach two ports (e.g., both 88 and 750) on
1028  * one server, it counts as two.
1029  *
1030  * There is one exception to the above rules.  Whenever a TCP connection is
1031  * established, we wait up to ten seconds for it to finish or fail before
1032  * moving on.  This reduces network traffic significantly in a TCP environment.
1033  */
1034
1035 krb5_error_code
1036 k5_sendto(krb5_context context, const krb5_data *message,
1037           const struct serverlist *servers, int socktype1, int socktype2,
1038           struct sendto_callback_info* callback_info, krb5_data *reply,
1039           struct sockaddr *remoteaddr, socklen_t *remoteaddrlen,
1040           int *server_used,
1041           /* return 0 -> keep going, 1 -> quit */
1042           int (*msg_handler)(krb5_context, const krb5_data *, void *),
1043           void *msg_handler_data)
1044 {
1045     int pass;
1046     time_ms delay;
1047     krb5_error_code retval;
1048     struct conn_state *conns = NULL, *state, **tailptr, *next, *winner;
1049     size_t s;
1050     struct select_state *sel_state = NULL, *seltemp;
1051     char *udpbuf = NULL;
1052     krb5_boolean done = FALSE;
1053
1054     reply->data = 0;
1055     reply->length = 0;
1056
1057     /* One for use here, listing all our fds in use, and one for
1058      * temporary use in service_fds, for the fds of interest.  */
1059     sel_state = malloc(2 * sizeof(*sel_state));
1060     if (sel_state == NULL) {
1061         retval = ENOMEM;
1062         goto cleanup;
1063     }
1064     seltemp = &sel_state[1];
1065     cm_init_selstate(sel_state);
1066
1067     /* First pass: resolve server hosts, communicate with resulting addresses
1068      * of the preferred socktype, and wait 1s for an answer from each. */
1069     for (s = 0; s < servers->nservers && !done; s++) {
1070         /* Find the current tail pointer. */
1071         for (tailptr = &conns; *tailptr != NULL; tailptr = &(*tailptr)->next);
1072         retval = resolve_server(context, servers, s, socktype1, socktype2,
1073                                 message, &udpbuf, &conns);
1074         if (retval)
1075             goto cleanup;
1076         for (state = *tailptr; state != NULL && !done; state = state->next) {
1077             /* Contact each new connection whose socktype matches socktype1. */
1078             if (state->addr.type != socktype1)
1079                 continue;
1080             if (maybe_send(context, state, sel_state, callback_info))
1081                 continue;
1082             done = service_fds(context, sel_state, 1000, conns, seltemp,
1083                                msg_handler, msg_handler_data, &winner);
1084         }
1085     }
1086
1087     /* Complete the first pass by contacting servers of the non-preferred
1088      * socktype (if given), waiting 1s for an answer from each. */
1089     for (state = conns; state != NULL && !done; state = state->next) {
1090         if (state->addr.type != socktype2)
1091             continue;
1092         if (maybe_send(context, state, sel_state, callback_info))
1093             continue;
1094         done = service_fds(context, sel_state, 1000, conns, seltemp,
1095                            msg_handler, msg_handler_data, &winner);
1096     }
1097
1098     /* Wait for two seconds at the end of the first pass. */
1099     if (!done) {
1100         done = service_fds(context, sel_state, 2000, conns, seltemp,
1101                            msg_handler, msg_handler_data, &winner);
1102     }
1103
1104     /* Make remaining passes over all of the connections. */
1105     delay = 4000;
1106     for (pass = 1; pass < MAX_PASS && !done; pass++) {
1107         for (state = conns; state != NULL && !done; state = state->next) {
1108             if (maybe_send(context, state, sel_state, callback_info))
1109                 continue;
1110             done = service_fds(context, sel_state, 1000, conns, seltemp,
1111                                msg_handler, msg_handler_data, &winner);
1112             if (sel_state->nfds == 0)
1113                 break;
1114         }
1115         /* Wait for the delay backoff at the end of this pass. */
1116         if (!done) {
1117             done = service_fds(context, sel_state, delay, conns, seltemp,
1118                                msg_handler, msg_handler_data, &winner);
1119         }
1120         if (sel_state->nfds == 0)
1121             break;
1122         delay *= 2;
1123     }
1124
1125     if (sel_state->nfds == 0 || !done || winner == NULL) {
1126         retval = KRB5_KDC_UNREACH;
1127         goto cleanup;
1128     }
1129     /* Success!  */
1130     reply->data = winner->x.in.buf;
1131     reply->length = winner->x.in.pos - winner->x.in.buf;
1132     retval = 0;
1133     winner->x.in.buf = NULL;
1134     if (server_used != NULL)
1135         *server_used = winner->server_index;
1136     if (remoteaddr != NULL && remoteaddrlen != 0 && *remoteaddrlen > 0)
1137         (void)getpeername(winner->fd, remoteaddr, remoteaddrlen);
1138     TRACE_SENDTO_KDC_RESPONSE(context, reply->length, &winner->addr);
1139
1140 cleanup:
1141     for (state = conns; state != NULL; state = next) {
1142         next = state->next;
1143         if (state->fd != INVALID_SOCKET)
1144             closesocket(state->fd);
1145         if (state->state == READING && state->x.in.buf != udpbuf)
1146             free(state->x.in.buf);
1147         if (callback_info) {
1148             callback_info->pfn_cleanup(callback_info->data,
1149                                        &state->callback_buffer);
1150         }
1151         free(state);
1152     }
1153
1154     if (reply->data != udpbuf)
1155         free(udpbuf);
1156     free(sel_state);
1157     return retval;
1158 }