s4-socket: detect NULL server in socket connection code
[mat/samba.git] / source4 / lib / socket / connect_multi.c
1 /* 
2    Unix SMB/CIFS implementation.
3
4    Fire connect requests to a host and a number of ports, with a timeout
5    between the connect request. Return if the first connect comes back
6    successfully or return the last error.
7
8    Copyright (C) Volker Lendecke 2005
9    
10    This program is free software; you can redistribute it and/or modify
11    it under the terms of the GNU General Public License as published by
12    the Free Software Foundation; either version 3 of the License, or
13    (at your option) any later version.
14    
15    This program is distributed in the hope that it will be useful,
16    but WITHOUT ANY WARRANTY; without even the implied warranty of
17    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18    GNU General Public License for more details.
19    
20    You should have received a copy of the GNU General Public License
21    along with this program.  If not, see <http://www.gnu.org/licenses/>.
22 */
23
24 #include "includes.h"
25 #include "lib/socket/socket.h"
26 #include "lib/events/events.h"
27 #include "libcli/composite/composite.h"
28 #include "libcli/resolve/resolve.h"
29
/* Pause before starting a connect to the next port even though the
   previous attempt is still pending: 2000 microseconds (2 ms). */
#define MULTI_PORT_DELAY (2 * 1000)
31
/*
  State shared by every connect attempt belonging to one
  socket_connect_multi_send() request.
*/
struct connect_multi_state {
	/* first address returned by name resolution; NULL until resolution
	   completes, and also NULL if resolution returned no address */
	struct socket_address *server_address;

	/* ports to try, in order */
	int num_ports;
	uint16_t *ports;

	/* the connected socket and its port, filled in by the winning attempt */
	struct socket_context *sock;
	uint16_t result_port;

	/* number of connect attempts started so far */
	int num_connects_sent;
	/* number of connect attempts that have completed (either way) */
	int num_connects_recv;
};
45
/*
  Per-attempt state: one socket_connect_send() towards one port.
*/
struct connect_one_state {
	struct composite_context *result;	/* the overall multi-port request */
	struct socket_context *sock;		/* socket used by this attempt */
	struct socket_address *addr;		/* server address carrying this attempt's port */
};
54
/* Async callback chain, in the order the steps normally run:
   resolve -> start a connect -> (timer | attempt finished) -> next connect */
static void continue_resolve_name(struct composite_context *creq);
static void connect_multi_next_socket(struct composite_context *result);
static void connect_multi_timer(struct tevent_context *ev,
				struct tevent_timer *te,
				struct timeval tv, void *p);
static void continue_one(struct composite_context *creq);
61
62 /*
63   setup an async socket_connect, with multiple ports
64 */
65 _PUBLIC_ struct composite_context *socket_connect_multi_send(
66                                                     TALLOC_CTX *mem_ctx,
67                                                     const char *server_name,
68                                                     int num_server_ports,
69                                                     uint16_t *server_ports,
70                                                     struct resolve_context *resolve_ctx,
71                                                     struct tevent_context *event_ctx)
72 {
73         struct composite_context *result;
74         struct connect_multi_state *multi;
75         int i;
76
77         struct nbt_name name;
78         struct composite_context *creq;
79                 
80         result = talloc_zero(mem_ctx, struct composite_context);
81         if (result == NULL) return NULL;
82         result->state = COMPOSITE_STATE_IN_PROGRESS;
83         result->event_ctx = event_ctx;
84
85         multi = talloc_zero(result, struct connect_multi_state);
86         if (composite_nomem(multi, result)) goto failed;
87         result->private_data = multi;
88
89         multi->num_ports = num_server_ports;
90         multi->ports = talloc_array(multi, uint16_t, multi->num_ports);
91         if (composite_nomem(multi->ports, result)) goto failed;
92
93         for (i=0; i<multi->num_ports; i++) {
94                 multi->ports[i] = server_ports[i];
95         }
96
97         /*  
98             we don't want to do the name resolution separately
99                     for each port, so start it now, then only start on
100                     the real sockets once we have an IP
101         */
102         make_nbt_name_server(&name, server_name);
103
104         creq = resolve_name_all_send(resolve_ctx, multi, 0, multi->ports[0], &name, result->event_ctx);
105         if (composite_nomem(creq, result)) goto failed;
106
107         composite_continue(result, creq, continue_resolve_name, result);
108
109         return result;
110
111
112  failed:
113         composite_error(result, result->status);
114         return result;
115 }
116
117 /*
118   start connecting to the next socket/port in the list
119 */
120 static void connect_multi_next_socket(struct composite_context *result)
121 {
122         struct connect_multi_state *multi = talloc_get_type(result->private_data, 
123                                                             struct connect_multi_state);
124         struct connect_one_state *state;
125         struct composite_context *creq;
126         int next = multi->num_connects_sent;
127
128         if (next == multi->num_ports) {
129                 /* don't do anything, just wait for the existing ones to finish */
130                 return;
131         }
132
133         multi->num_connects_sent += 1;
134
135         if (multi->server_address == NULL) {
136                 composite_error(result, NT_STATUS_OBJECT_NAME_NOT_FOUND);
137                 return;
138         }
139
140         state = talloc(multi, struct connect_one_state);
141         if (composite_nomem(state, result)) return;
142
143         state->result = result;
144         result->status = socket_create(multi->server_address->family, SOCKET_TYPE_STREAM, &state->sock, 0);
145         if (!composite_is_ok(result)) return;
146
147         state->addr = socket_address_copy(state, multi->server_address);
148         if (composite_nomem(state->addr, result)) return;
149
150         socket_address_set_port(state->addr, multi->ports[next]);
151
152         talloc_steal(state, state->sock);
153
154         creq = socket_connect_send(state->sock, NULL, 
155                                    state->addr, 0,
156                                    result->event_ctx);
157         if (composite_nomem(creq, result)) return;
158         talloc_steal(state, creq);
159
160         composite_continue(result, creq, continue_one, state);
161
162         /* if there are more ports to go then setup a timer to fire when we have waited
163            for a couple of milli-seconds, when that goes off we try the next port regardless
164            of whether this port has completed */
165         if (multi->num_ports > multi->num_connects_sent) {
166                 /* note that this timer is a child of the single
167                    connect attempt state, so it will go away when this
168                    request completes */
169                 tevent_add_timer(result->event_ctx, state,
170                                 timeval_current_ofs_usec(MULTI_PORT_DELAY),
171                                 connect_multi_timer, result);
172         }
173 }
174
175 /*
176   a timer has gone off telling us that we should try the next port
177 */
178 static void connect_multi_timer(struct tevent_context *ev,
179                                 struct tevent_timer *te,
180                                 struct timeval tv, void *p)
181 {
182         struct composite_context *result = talloc_get_type(p, struct composite_context);
183         connect_multi_next_socket(result);
184 }
185
186
187 /*
188   recv name resolution reply then send the next connect
189 */
190 static void continue_resolve_name(struct composite_context *creq)
191 {
192         struct composite_context *result = talloc_get_type(creq->async.private_data, 
193                                                            struct composite_context);
194         struct connect_multi_state *multi = talloc_get_type(result->private_data, 
195                                                             struct connect_multi_state);
196         struct socket_address **addr;
197
198         result->status = resolve_name_all_recv(creq, multi, &addr, NULL);
199         if (!composite_is_ok(result)) return;
200
201         /* Let's just go for the first for now */
202         multi->server_address = addr[0];
203
204         connect_multi_next_socket(result);
205 }
206
207 /*
208   one of our socket_connect_send() calls hash finished. If it got a
209   connection or there are none left then we are done
210 */
211 static void continue_one(struct composite_context *creq)
212 {
213         struct connect_one_state *state = talloc_get_type(creq->async.private_data, 
214                                                           struct connect_one_state);
215         struct composite_context *result = state->result;
216         struct connect_multi_state *multi = talloc_get_type(result->private_data, 
217                                                             struct connect_multi_state);
218         NTSTATUS status;
219         multi->num_connects_recv++;
220
221         status = socket_connect_recv(creq);
222
223         if (NT_STATUS_IS_OK(status)) {
224                 multi->sock = talloc_steal(multi, state->sock);
225                 multi->result_port = state->addr->port;
226         }
227
228         talloc_free(state);
229
230         if (NT_STATUS_IS_OK(status) || 
231             multi->num_connects_recv == multi->num_ports) {
232                 result->status = status;
233                 composite_done(result);
234                 return;
235         }
236
237         /* try the next port */
238         connect_multi_next_socket(result);
239 }
240
241 /*
242   async recv routine for socket_connect_multi()
243  */
244 _PUBLIC_ NTSTATUS socket_connect_multi_recv(struct composite_context *ctx,
245                                    TALLOC_CTX *mem_ctx,
246                                    struct socket_context **sock,
247                                    uint16_t *port)
248 {
249         NTSTATUS status = composite_wait(ctx);
250         if (NT_STATUS_IS_OK(status)) {
251                 struct connect_multi_state *multi =
252                         talloc_get_type(ctx->private_data,
253                                         struct connect_multi_state);
254                 *sock = talloc_steal(mem_ctx, multi->sock);
255                 *port = multi->result_port;
256         }
257         talloc_free(ctx);
258         return status;
259 }
260
261 NTSTATUS socket_connect_multi(TALLOC_CTX *mem_ctx,
262                               const char *server_address,
263                               int num_server_ports, uint16_t *server_ports,
264                               struct resolve_context *resolve_ctx,
265                               struct tevent_context *event_ctx,
266                               struct socket_context **result,
267                               uint16_t *result_port)
268 {
269         struct composite_context *ctx =
270                 socket_connect_multi_send(mem_ctx, server_address,
271                                           num_server_ports, server_ports,
272                                           resolve_ctx,
273                                           event_ctx);
274         return socket_connect_multi_recv(ctx, mem_ctx, result, result_port);
275 }