s3: Slight reshaping of server_exists_parse
[kamenim/samba.git] / source3 / lib / serverid.c
1 /*
2    Unix SMB/CIFS implementation.
3    Implementation of a reliable server_exists()
4    Copyright (C) Volker Lendecke 2010
5
6    This program is free software; you can redistribute it and/or modify
7    it under the terms of the GNU General Public License as published by
8    the Free Software Foundation; either version 3 of the License, or
9    (at your option) any later version.
10
11    This program is distributed in the hope that it will be useful,
12    but WITHOUT ANY WARRANTY; without even the implied warranty of
13    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14    GNU General Public License for more details.
15
16    You should have received a copy of the GNU General Public License
17    along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 */
19
20 #include "includes.h"
21 #include "serverid.h"
22
23 struct serverid_key {
24         pid_t pid;
25 #ifdef CLUSTER_SUPPORT
26         uint32_t vnn;
27 #endif
28 };
29
30 bool serverid_parent_init(void)
31 {
32         struct tdb_wrap *db;
33
34         /*
35          * Open the tdb in the parent process (smbd) so that our
36          * CLEAR_IF_FIRST optimization in tdb_reopen_all can properly
37          * work.
38          */
39
40         db = tdb_wrap_open(talloc_autofree_context(),
41                            lock_path("serverid.tdb"),
42                            0, TDB_DEFAULT|TDB_CLEAR_IF_FIRST, O_RDWR|O_CREAT,
43                            0644);
44         if (db == NULL) {
45                 DEBUG(1, ("could not open serverid.tdb: %s\n",
46                           strerror(errno)));
47                 return false;
48         }
49         return true;
50 }
51
52 struct serverid_data {
53         uint64_t unique_id;
54         uint32_t msg_flags;
55 };
56
57 static struct db_context *serverid_db(void)
58 {
59         static struct db_context *db;
60
61         if (db != NULL) {
62                 return db;
63         }
64         db = db_open(talloc_autofree_context(), lock_path("serverid.tdb"),
65                      0, TDB_DEFAULT|TDB_CLEAR_IF_FIRST, O_RDWR|O_CREAT, 0644);
66         return db;
67 }
68
69 static void serverid_fill_key(const struct server_id *id,
70                               struct serverid_key *key)
71 {
72         ZERO_STRUCTP(key);
73         key->pid = id->pid;
74 #ifdef CLUSTER_SUPPORT
75         key->vnn = id->vnn;
76 #endif
77 }
78
79 bool serverid_register(const struct server_id *id, uint32_t msg_flags)
80 {
81         struct db_context *db;
82         struct serverid_key key;
83         struct serverid_data data;
84         struct db_record *rec;
85         TDB_DATA tdbkey, tdbdata;
86         NTSTATUS status;
87         bool ret = false;
88
89         db = serverid_db();
90         if (db == NULL) {
91                 return false;
92         }
93
94         serverid_fill_key(id, &key);
95         tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
96
97         rec = db->fetch_locked(db, talloc_tos(), tdbkey);
98         if (rec == NULL) {
99                 DEBUG(1, ("Could not fetch_lock serverid.tdb record\n"));
100                 return false;
101         }
102
103         ZERO_STRUCT(data);
104         data.unique_id = id->unique_id;
105         data.msg_flags = msg_flags;
106
107         tdbdata = make_tdb_data((uint8_t *)&data, sizeof(data));
108         status = rec->store(rec, tdbdata, 0);
109         if (!NT_STATUS_IS_OK(status)) {
110                 DEBUG(1, ("Storing serverid.tdb record failed: %s\n",
111                           nt_errstr(status)));
112                 goto done;
113         }
114         ret = true;
115 done:
116         TALLOC_FREE(rec);
117         return ret;
118 }
119
120 bool serverid_register_self(uint32_t msg_flags)
121 {
122         struct server_id pid;
123
124         pid = procid_self();
125         return serverid_register(&pid, msg_flags);
126 }
127
128 bool serverid_deregister(const struct server_id *id)
129 {
130         struct db_context *db;
131         struct serverid_key key;
132         struct db_record *rec;
133         TDB_DATA tdbkey;
134         NTSTATUS status;
135         bool ret = false;
136
137         db = serverid_db();
138         if (db == NULL) {
139                 return false;
140         }
141
142         serverid_fill_key(id, &key);
143         tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
144
145         rec = db->fetch_locked(db, talloc_tos(), tdbkey);
146         if (rec == NULL) {
147                 DEBUG(1, ("Could not fetch_lock serverid.tdb record\n"));
148                 return false;
149         }
150
151         status = rec->delete_rec(rec);
152         if (!NT_STATUS_IS_OK(status)) {
153                 DEBUG(1, ("Deleting serverid.tdb record failed: %s\n",
154                           nt_errstr(status)));
155                 goto done;
156         }
157         ret = true;
158 done:
159         TALLOC_FREE(rec);
160         return ret;
161 }
162
163 bool serverid_deregister_self(void)
164 {
165         struct server_id pid;
166
167         pid = procid_self();
168         return serverid_deregister(&pid);
169 }
170
171 struct serverid_exists_state {
172         const struct server_id *id;
173         bool exists;
174 };
175
176 static int server_exists_parse(TDB_DATA key, TDB_DATA data, void *priv)
177 {
178         struct serverid_exists_state *state =
179                 (struct serverid_exists_state *)priv;
180         uint64_t unique_id;
181
182         if (data.dsize != sizeof(struct serverid_data)) {
183                 return -1;
184         }
185
186         /*
187          * Use memcmp, not direct compare. data.dptr might not be
188          * aligned.
189          */
190         state->exists =
191                 (memcmp(&unique_id, data.dptr, sizeof(unique_id)) == 0);
192         return 0;
193 }
194
195 bool serverid_exists(const struct server_id *id)
196 {
197         struct db_context *db;
198         struct serverid_exists_state state;
199         struct serverid_key key;
200         TDB_DATA tdbkey;
201
202         db = serverid_db();
203         if (db == NULL) {
204                 return false;
205         }
206
207         serverid_fill_key(id, &key);
208         tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
209
210         state.id = id;
211         state.exists = false;
212
213         if (db->parse_record(db, tdbkey, server_exists_parse, &state) == -1) {
214                 return false;
215         }
216         return state.exists;
217 }
218
219 static bool serverid_rec_parse(const struct db_record *rec,
220                                struct server_id *id, uint32_t *msg_flags)
221 {
222         struct serverid_key key;
223         struct serverid_data data;
224
225         if (rec->key.dsize != sizeof(key)) {
226                 DEBUG(1, ("Found invalid key length %d in serverid.tdb\n",
227                           (int)rec->key.dsize));
228                 return false;
229         }
230         if (rec->value.dsize != sizeof(data)) {
231                 DEBUG(1, ("Found invalid value length %d in serverid.tdb\n",
232                           (int)rec->value.dsize));
233                 return false;
234         }
235
236         memcpy(&key, rec->key.dptr, sizeof(key));
237         memcpy(&data, rec->value.dptr, sizeof(data));
238
239         id->pid = key.pid;
240 #ifdef CLUSTER_SUPPORT
241         id->vnn = key.vnn;
242 #endif
243         id->unique_id = data.unique_id;
244         *msg_flags = data.msg_flags;
245         return true;
246 }
247
248 struct serverid_traverse_read_state {
249         int (*fn)(const struct server_id *id, uint32_t msg_flags,
250                   void *private_data);
251         void *private_data;
252 };
253
254 static int serverid_traverse_read_fn(struct db_record *rec, void *private_data)
255 {
256         struct serverid_traverse_read_state *state =
257                 (struct serverid_traverse_read_state *)private_data;
258         struct server_id id;
259         uint32_t msg_flags;
260
261         if (!serverid_rec_parse(rec, &id, &msg_flags)) {
262                 return 0;
263         }
264         return state->fn(&id, msg_flags,state->private_data);
265 }
266
267 bool serverid_traverse_read(int (*fn)(const struct server_id *id,
268                                       uint32_t msg_flags, void *private_data),
269                             void *private_data)
270 {
271         struct db_context *db;
272         struct serverid_traverse_read_state state;
273
274         db = serverid_db();
275         if (db == NULL) {
276                 return false;
277         }
278         state.fn = fn;
279         state.private_data = private_data;
280         return db->traverse_read(db, serverid_traverse_read_fn, &state);
281 }
282
283 struct serverid_traverse_state {
284         int (*fn)(struct db_record *rec, const struct server_id *id,
285                   uint32_t msg_flags, void *private_data);
286         void *private_data;
287 };
288
289 static int serverid_traverse_fn(struct db_record *rec, void *private_data)
290 {
291         struct serverid_traverse_state *state =
292                 (struct serverid_traverse_state *)private_data;
293         struct server_id id;
294         uint32_t msg_flags;
295
296         if (!serverid_rec_parse(rec, &id, &msg_flags)) {
297                 return 0;
298         }
299         return state->fn(rec, &id, msg_flags, state->private_data);
300 }
301
302 bool serverid_traverse(int (*fn)(struct db_record *rec,
303                                  const struct server_id *id,
304                                  uint32_t msg_flags, void *private_data),
305                             void *private_data)
306 {
307         struct db_context *db;
308         struct serverid_traverse_state state;
309
310         db = serverid_db();
311         if (db == NULL) {
312                 return false;
313         }
314         state.fn = fn;
315         state.private_data = private_data;
316         return db->traverse(db, serverid_traverse_fn, &state);
317 }