s3-lib: Fixed a possible crash bug.
[abartlet/samba.git/.git] / source3 / lib / serverid.c
1 /*
2    Unix SMB/CIFS implementation.
3    Implementation of a reliable server_exists()
4    Copyright (C) Volker Lendecke 2010
5
6    This program is free software; you can redistribute it and/or modify
7    it under the terms of the GNU General Public License as published by
8    the Free Software Foundation; either version 3 of the License, or
9    (at your option) any later version.
10
11    This program is distributed in the hope that it will be useful,
12    but WITHOUT ANY WARRANTY; without even the implied warranty of
13    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14    GNU General Public License for more details.
15
16    You should have received a copy of the GNU General Public License
17    along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 */
19
20 #include "includes.h"
21 #include "serverid.h"
22
23 struct serverid_key {
24         pid_t pid;
25 #ifdef CLUSTER_SUPPORT
26         uint32_t vnn;
27 #endif
28 };
29
30 struct serverid_data {
31         uint64_t unique_id;
32         uint32_t msg_flags;
33 };
34
35 bool serverid_parent_init(void)
36 {
37         struct tdb_wrap *db;
38
39         /*
40          * Open the tdb in the parent process (smbd) so that our
41          * CLEAR_IF_FIRST optimization in tdb_reopen_all can properly
42          * work.
43          */
44
45         db = tdb_wrap_open(talloc_autofree_context(),
46                            lock_path("serverid.tdb"),
47                            0, TDB_DEFAULT|TDB_CLEAR_IF_FIRST, O_RDWR|O_CREAT,
48                            0644);
49         if (db == NULL) {
50                 DEBUG(1, ("could not open serverid.tdb: %s\n",
51                           strerror(errno)));
52                 return false;
53         }
54         return true;
55 }
56
57 static struct db_context *serverid_db(void)
58 {
59         static struct db_context *db;
60
61         if (db != NULL) {
62                 return db;
63         }
64         db = db_open(talloc_autofree_context(), lock_path("serverid.tdb"),
65                      0, TDB_DEFAULT|TDB_CLEAR_IF_FIRST, O_RDWR|O_CREAT, 0644);
66         return db;
67 }
68
69 static void serverid_fill_key(const struct server_id *id,
70                               struct serverid_key *key)
71 {
72         ZERO_STRUCTP(key);
73         key->pid = id->pid;
74 #ifdef CLUSTER_SUPPORT
75         key->vnn = id->vnn;
76 #endif
77 }
78
79 bool serverid_register(const struct server_id id, uint32_t msg_flags)
80 {
81         struct db_context *db;
82         struct serverid_key key;
83         struct serverid_data data;
84         struct db_record *rec;
85         TDB_DATA tdbkey, tdbdata;
86         NTSTATUS status;
87         bool ret = false;
88
89         db = serverid_db();
90         if (db == NULL) {
91                 return false;
92         }
93
94         serverid_fill_key(&id, &key);
95         tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
96
97         rec = db->fetch_locked(db, talloc_tos(), tdbkey);
98         if (rec == NULL) {
99                 DEBUG(1, ("Could not fetch_lock serverid.tdb record\n"));
100                 return false;
101         }
102
103         ZERO_STRUCT(data);
104         data.unique_id = id.unique_id;
105         data.msg_flags = msg_flags;
106
107         tdbdata = make_tdb_data((uint8_t *)&data, sizeof(data));
108         status = rec->store(rec, tdbdata, 0);
109         if (!NT_STATUS_IS_OK(status)) {
110                 DEBUG(1, ("Storing serverid.tdb record failed: %s\n",
111                           nt_errstr(status)));
112                 goto done;
113         }
114         ret = true;
115 done:
116         TALLOC_FREE(rec);
117         return ret;
118 }
119
120 bool serverid_register_msg_flags(const struct server_id id, bool do_reg,
121                                  uint32_t msg_flags)
122 {
123         struct db_context *db;
124         struct serverid_key key;
125         struct serverid_data *data;
126         struct db_record *rec;
127         TDB_DATA tdbkey, tdbdata;
128         NTSTATUS status;
129         bool ret = false;
130
131         db = serverid_db();
132         if (db == NULL) {
133                 return false;
134         }
135
136         serverid_fill_key(&id, &key);
137         tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
138
139         rec = db->fetch_locked(db, talloc_tos(), tdbkey);
140         if (rec == NULL) {
141                 DEBUG(1, ("Could not fetch_lock serverid.tdb record\n"));
142                 return false;
143         }
144
145         if (rec->value.dsize != sizeof(struct serverid_data)) {
146                 DEBUG(1, ("serverid record has unexpected size %d "
147                           "(wanted %d)\n", (int)rec->value.dsize,
148                           (int)sizeof(struct serverid_data)));
149                 goto done;
150         }
151
152         data = (struct serverid_data *)rec->value.dptr;
153
154         if (do_reg) {
155                 data->msg_flags |= msg_flags;
156         } else {
157                 data->msg_flags &= ~msg_flags;
158         }
159
160         ZERO_STRUCT(tdbdata);
161
162         status = rec->store(rec, tdbdata, 0);
163         if (!NT_STATUS_IS_OK(status)) {
164                 DEBUG(1, ("Storing serverid.tdb record failed: %s\n",
165                           nt_errstr(status)));
166                 goto done;
167         }
168         ret = true;
169 done:
170         TALLOC_FREE(rec);
171         return ret;
172 }
173
174 bool serverid_deregister(struct server_id id)
175 {
176         struct db_context *db;
177         struct serverid_key key;
178         struct db_record *rec;
179         TDB_DATA tdbkey;
180         NTSTATUS status;
181         bool ret = false;
182
183         db = serverid_db();
184         if (db == NULL) {
185                 return false;
186         }
187
188         serverid_fill_key(&id, &key);
189         tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
190
191         rec = db->fetch_locked(db, talloc_tos(), tdbkey);
192         if (rec == NULL) {
193                 DEBUG(1, ("Could not fetch_lock serverid.tdb record\n"));
194                 return false;
195         }
196
197         status = rec->delete_rec(rec);
198         if (!NT_STATUS_IS_OK(status)) {
199                 DEBUG(1, ("Deleting serverid.tdb record failed: %s\n",
200                           nt_errstr(status)));
201                 goto done;
202         }
203         ret = true;
204 done:
205         TALLOC_FREE(rec);
206         return ret;
207 }
208
209 struct serverid_exists_state {
210         const struct server_id *id;
211         bool exists;
212 };
213
214 static int server_exists_parse(TDB_DATA key, TDB_DATA data, void *priv)
215 {
216         struct serverid_exists_state *state =
217                 (struct serverid_exists_state *)priv;
218
219         if (data.dsize != sizeof(struct serverid_data)) {
220                 return -1;
221         }
222
223         /*
224          * Use memcmp, not direct compare. data.dptr might not be
225          * aligned.
226          */
227         state->exists = (memcmp(&state->id->unique_id, data.dptr,
228                                 sizeof(state->id->unique_id)) == 0);
229         return 0;
230 }
231
232 bool serverid_exists(const struct server_id *id)
233 {
234         struct db_context *db;
235         struct serverid_exists_state state;
236         struct serverid_key key;
237         TDB_DATA tdbkey;
238
239         db = serverid_db();
240         if (db == NULL) {
241                 return false;
242         }
243
244         serverid_fill_key(id, &key);
245         tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
246
247         state.id = id;
248         state.exists = false;
249
250         if (db->parse_record(db, tdbkey, server_exists_parse, &state) == -1) {
251                 return false;
252         }
253         return state.exists;
254 }
255
256 static bool serverid_rec_parse(const struct db_record *rec,
257                                struct server_id *id, uint32_t *msg_flags)
258 {
259         struct serverid_key key;
260         struct serverid_data data;
261
262         if (rec->key.dsize != sizeof(key)) {
263                 DEBUG(1, ("Found invalid key length %d in serverid.tdb\n",
264                           (int)rec->key.dsize));
265                 return false;
266         }
267         if (rec->value.dsize != sizeof(data)) {
268                 DEBUG(1, ("Found invalid value length %d in serverid.tdb\n",
269                           (int)rec->value.dsize));
270                 return false;
271         }
272
273         memcpy(&key, rec->key.dptr, sizeof(key));
274         memcpy(&data, rec->value.dptr, sizeof(data));
275
276         id->pid = key.pid;
277 #ifdef CLUSTER_SUPPORT
278         id->vnn = key.vnn;
279 #endif
280         id->unique_id = data.unique_id;
281         *msg_flags = data.msg_flags;
282         return true;
283 }
284
285 struct serverid_traverse_read_state {
286         int (*fn)(const struct server_id *id, uint32_t msg_flags,
287                   void *private_data);
288         void *private_data;
289 };
290
291 static int serverid_traverse_read_fn(struct db_record *rec, void *private_data)
292 {
293         struct serverid_traverse_read_state *state =
294                 (struct serverid_traverse_read_state *)private_data;
295         struct server_id id;
296         uint32_t msg_flags;
297
298         if (!serverid_rec_parse(rec, &id, &msg_flags)) {
299                 return 0;
300         }
301         return state->fn(&id, msg_flags,state->private_data);
302 }
303
304 bool serverid_traverse_read(int (*fn)(const struct server_id *id,
305                                       uint32_t msg_flags, void *private_data),
306                             void *private_data)
307 {
308         struct db_context *db;
309         struct serverid_traverse_read_state state;
310
311         db = serverid_db();
312         if (db == NULL) {
313                 return false;
314         }
315         state.fn = fn;
316         state.private_data = private_data;
317         return db->traverse_read(db, serverid_traverse_read_fn, &state);
318 }
319
320 struct serverid_traverse_state {
321         int (*fn)(struct db_record *rec, const struct server_id *id,
322                   uint32_t msg_flags, void *private_data);
323         void *private_data;
324 };
325
326 static int serverid_traverse_fn(struct db_record *rec, void *private_data)
327 {
328         struct serverid_traverse_state *state =
329                 (struct serverid_traverse_state *)private_data;
330         struct server_id id;
331         uint32_t msg_flags;
332
333         if (!serverid_rec_parse(rec, &id, &msg_flags)) {
334                 return 0;
335         }
336         return state->fn(rec, &id, msg_flags, state->private_data);
337 }
338
339 bool serverid_traverse(int (*fn)(struct db_record *rec,
340                                  const struct server_id *id,
341                                  uint32_t msg_flags, void *private_data),
342                             void *private_data)
343 {
344         struct db_context *db;
345         struct serverid_traverse_state state;
346
347         db = serverid_db();
348         if (db == NULL) {
349                 return false;
350         }
351         state.fn = fn;
352         state.private_data = private_data;
353         return db->traverse(db, serverid_traverse_fn, &state);
354 }