b1f6a5711ef65abac9d7939f600c9cca143066bf
[ddiss/samba.git] / source3 / lib / serverid.c
1 /*
2    Unix SMB/CIFS implementation.
3    Implementation of a reliable server_exists()
4    Copyright (C) Volker Lendecke 2010
5
6    This program is free software; you can redistribute it and/or modify
7    it under the terms of the GNU General Public License as published by
8    the Free Software Foundation; either version 3 of the License, or
9    (at your option) any later version.
10
11    This program is distributed in the hope that it will be useful,
12    but WITHOUT ANY WARRANTY; without even the implied warranty of
13    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14    GNU General Public License for more details.
15
16    You should have received a copy of the GNU General Public License
17    along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 */
19
20 #include "includes.h"
21 #include "system/filesys.h"
22 #include "serverid.h"
23 #include "util_tdb.h"
24 #include "dbwrap.h"
25 #include "lib/util/tdb_wrap.h"
26
27 struct serverid_key {
28         pid_t pid;
29         uint32_t vnn;
30 };
31
32 struct serverid_data {
33         uint64_t unique_id;
34         uint32_t msg_flags;
35 };
36
37 bool serverid_parent_init(TALLOC_CTX *mem_ctx)
38 {
39         struct tdb_wrap *db;
40
41         /*
42          * Open the tdb in the parent process (smbd) so that our
43          * CLEAR_IF_FIRST optimization in tdb_reopen_all can properly
44          * work.
45          */
46
47         db = tdb_wrap_open(mem_ctx, lock_path("serverid.tdb"),
48                            0, TDB_DEFAULT|TDB_CLEAR_IF_FIRST|TDB_INCOMPATIBLE_HASH, O_RDWR|O_CREAT,
49                            0644);
50         if (db == NULL) {
51                 DEBUG(1, ("could not open serverid.tdb: %s\n",
52                           strerror(errno)));
53                 return false;
54         }
55         return true;
56 }
57
58 static struct db_context *serverid_db(void)
59 {
60         static struct db_context *db;
61
62         if (db != NULL) {
63                 return db;
64         }
65         db = db_open(NULL, lock_path("serverid.tdb"), 0,
66                      TDB_DEFAULT|TDB_CLEAR_IF_FIRST|TDB_INCOMPATIBLE_HASH, O_RDWR|O_CREAT, 0644);
67         return db;
68 }
69
70 static void serverid_fill_key(const struct server_id *id,
71                               struct serverid_key *key)
72 {
73         ZERO_STRUCTP(key);
74         key->pid = id->pid;
75         key->vnn = id->vnn;
76 }
77
78 bool serverid_register(const struct server_id id, uint32_t msg_flags)
79 {
80         struct db_context *db;
81         struct serverid_key key;
82         struct serverid_data data;
83         struct db_record *rec;
84         TDB_DATA tdbkey, tdbdata;
85         NTSTATUS status;
86         bool ret = false;
87
88         db = serverid_db();
89         if (db == NULL) {
90                 return false;
91         }
92
93         serverid_fill_key(&id, &key);
94         tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
95
96         rec = db->fetch_locked(db, talloc_tos(), tdbkey);
97         if (rec == NULL) {
98                 DEBUG(1, ("Could not fetch_lock serverid.tdb record\n"));
99                 return false;
100         }
101
102         ZERO_STRUCT(data);
103         data.unique_id = id.unique_id;
104         data.msg_flags = msg_flags;
105
106         tdbdata = make_tdb_data((uint8_t *)&data, sizeof(data));
107         status = rec->store(rec, tdbdata, 0);
108         if (!NT_STATUS_IS_OK(status)) {
109                 DEBUG(1, ("Storing serverid.tdb record failed: %s\n",
110                           nt_errstr(status)));
111                 goto done;
112         }
113         ret = true;
114 done:
115         TALLOC_FREE(rec);
116         return ret;
117 }
118
119 bool serverid_register_msg_flags(const struct server_id id, bool do_reg,
120                                  uint32_t msg_flags)
121 {
122         struct db_context *db;
123         struct serverid_key key;
124         struct serverid_data *data;
125         struct db_record *rec;
126         TDB_DATA tdbkey;
127         NTSTATUS status;
128         bool ret = false;
129
130         db = serverid_db();
131         if (db == NULL) {
132                 return false;
133         }
134
135         serverid_fill_key(&id, &key);
136         tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
137
138         rec = db->fetch_locked(db, talloc_tos(), tdbkey);
139         if (rec == NULL) {
140                 DEBUG(1, ("Could not fetch_lock serverid.tdb record\n"));
141                 return false;
142         }
143
144         if (rec->value.dsize != sizeof(struct serverid_data)) {
145                 DEBUG(1, ("serverid record has unexpected size %d "
146                           "(wanted %d)\n", (int)rec->value.dsize,
147                           (int)sizeof(struct serverid_data)));
148                 goto done;
149         }
150
151         data = (struct serverid_data *)rec->value.dptr;
152
153         if (do_reg) {
154                 data->msg_flags |= msg_flags;
155         } else {
156                 data->msg_flags &= ~msg_flags;
157         }
158
159         status = rec->store(rec, rec->value, 0);
160         if (!NT_STATUS_IS_OK(status)) {
161                 DEBUG(1, ("Storing serverid.tdb record failed: %s\n",
162                           nt_errstr(status)));
163                 goto done;
164         }
165         ret = true;
166 done:
167         TALLOC_FREE(rec);
168         return ret;
169 }
170
171 bool serverid_deregister(struct server_id id)
172 {
173         struct db_context *db;
174         struct serverid_key key;
175         struct db_record *rec;
176         TDB_DATA tdbkey;
177         NTSTATUS status;
178         bool ret = false;
179
180         db = serverid_db();
181         if (db == NULL) {
182                 return false;
183         }
184
185         serverid_fill_key(&id, &key);
186         tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
187
188         rec = db->fetch_locked(db, talloc_tos(), tdbkey);
189         if (rec == NULL) {
190                 DEBUG(1, ("Could not fetch_lock serverid.tdb record\n"));
191                 return false;
192         }
193
194         status = rec->delete_rec(rec);
195         if (!NT_STATUS_IS_OK(status)) {
196                 DEBUG(1, ("Deleting serverid.tdb record failed: %s\n",
197                           nt_errstr(status)));
198                 goto done;
199         }
200         ret = true;
201 done:
202         TALLOC_FREE(rec);
203         return ret;
204 }
205
206 struct serverid_exists_state {
207         const struct server_id *id;
208         bool exists;
209 };
210
211 static int server_exists_parse(TDB_DATA key, TDB_DATA data, void *priv)
212 {
213         struct serverid_exists_state *state =
214                 (struct serverid_exists_state *)priv;
215
216         if (data.dsize != sizeof(struct serverid_data)) {
217                 return -1;
218         }
219
220         /*
221          * Use memcmp, not direct compare. data.dptr might not be
222          * aligned.
223          */
224         state->exists = (memcmp(&state->id->unique_id, data.dptr,
225                                 sizeof(state->id->unique_id)) == 0);
226         return 0;
227 }
228
229 bool serverid_exists(const struct server_id *id)
230 {
231         struct db_context *db;
232         struct serverid_exists_state state;
233         struct serverid_key key;
234         TDB_DATA tdbkey;
235
236         if (lp_clustering() && !process_exists(*id)) {
237                 return false;
238         }
239
240         db = serverid_db();
241         if (db == NULL) {
242                 return false;
243         }
244
245         serverid_fill_key(id, &key);
246         tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
247
248         state.id = id;
249         state.exists = false;
250
251         if (db->parse_record(db, tdbkey, server_exists_parse, &state) == -1) {
252                 return false;
253         }
254         return state.exists;
255 }
256
257 static bool serverid_rec_parse(const struct db_record *rec,
258                                struct server_id *id, uint32_t *msg_flags)
259 {
260         struct serverid_key key;
261         struct serverid_data data;
262
263         if (rec->key.dsize != sizeof(key)) {
264                 DEBUG(1, ("Found invalid key length %d in serverid.tdb\n",
265                           (int)rec->key.dsize));
266                 return false;
267         }
268         if (rec->value.dsize != sizeof(data)) {
269                 DEBUG(1, ("Found invalid value length %d in serverid.tdb\n",
270                           (int)rec->value.dsize));
271                 return false;
272         }
273
274         memcpy(&key, rec->key.dptr, sizeof(key));
275         memcpy(&data, rec->value.dptr, sizeof(data));
276
277         id->pid = key.pid;
278         id->vnn = key.vnn;
279         id->unique_id = data.unique_id;
280         *msg_flags = data.msg_flags;
281         return true;
282 }
283
284 struct serverid_traverse_read_state {
285         int (*fn)(const struct server_id *id, uint32_t msg_flags,
286                   void *private_data);
287         void *private_data;
288 };
289
290 static int serverid_traverse_read_fn(struct db_record *rec, void *private_data)
291 {
292         struct serverid_traverse_read_state *state =
293                 (struct serverid_traverse_read_state *)private_data;
294         struct server_id id;
295         uint32_t msg_flags;
296
297         if (!serverid_rec_parse(rec, &id, &msg_flags)) {
298                 return 0;
299         }
300         return state->fn(&id, msg_flags,state->private_data);
301 }
302
303 bool serverid_traverse_read(int (*fn)(const struct server_id *id,
304                                       uint32_t msg_flags, void *private_data),
305                             void *private_data)
306 {
307         struct db_context *db;
308         struct serverid_traverse_read_state state;
309
310         db = serverid_db();
311         if (db == NULL) {
312                 return false;
313         }
314         state.fn = fn;
315         state.private_data = private_data;
316         return db->traverse_read(db, serverid_traverse_read_fn, &state);
317 }
318
319 struct serverid_traverse_state {
320         int (*fn)(struct db_record *rec, const struct server_id *id,
321                   uint32_t msg_flags, void *private_data);
322         void *private_data;
323 };
324
325 static int serverid_traverse_fn(struct db_record *rec, void *private_data)
326 {
327         struct serverid_traverse_state *state =
328                 (struct serverid_traverse_state *)private_data;
329         struct server_id id;
330         uint32_t msg_flags;
331
332         if (!serverid_rec_parse(rec, &id, &msg_flags)) {
333                 return 0;
334         }
335         return state->fn(rec, &id, msg_flags, state->private_data);
336 }
337
338 bool serverid_traverse(int (*fn)(struct db_record *rec,
339                                  const struct server_id *id,
340                                  uint32_t msg_flags, void *private_data),
341                             void *private_data)
342 {
343         struct db_context *db;
344         struct serverid_traverse_state state;
345
346         db = serverid_db();
347         if (db == NULL) {
348                 return false;
349         }
350         state.fn = fn;
351         state.private_data = private_data;
352         return db->traverse(db, serverid_traverse_fn, &state);
353 }