s3:serverid: don't ignore the result of dbwrap_parse_record()
[ddiss/samba.git] / source3 / lib / serverid.c
index ded72981ec4be7a203fdb9946b6db90995604974..4e3175662030b314b5e10eacc99884bfd5b4c661 100644 (file)
 */
 
 #include "includes.h"
+#include "system/filesys.h"
 #include "serverid.h"
-#include "dbwrap.h"
+#include "util_tdb.h"
+#include "dbwrap/dbwrap.h"
+#include "dbwrap/dbwrap_open.h"
+#include "lib/tdb_wrap/tdb_wrap.h"
+#include "lib/param/param.h"
+#include "ctdbd_conn.h"
+#include "messages.h"
 
 struct serverid_key {
        pid_t pid;
-#ifdef CLUSTER_SUPPORT
+       uint32_t task_id;
        uint32_t vnn;
-#endif
 };
 
 struct serverid_data {
@@ -36,6 +42,13 @@ struct serverid_data {
 bool serverid_parent_init(TALLOC_CTX *mem_ctx)
 {
        struct tdb_wrap *db;
+       struct loadparm_context *lp_ctx;
+
+       lp_ctx = loadparm_init_s3(mem_ctx, loadparm_s3_context());
+       if (lp_ctx == NULL) {
+               DEBUG(0, ("loadparm_init_s3 failed\n"));
+               return false;
+       }
 
        /*
         * Open the tdb in the parent process (smbd) so that our
@@ -45,7 +58,8 @@ bool serverid_parent_init(TALLOC_CTX *mem_ctx)
 
        db = tdb_wrap_open(mem_ctx, lock_path("serverid.tdb"),
                           0, TDB_DEFAULT|TDB_CLEAR_IF_FIRST|TDB_INCOMPATIBLE_HASH, O_RDWR|O_CREAT,
-                          0644);
+                          0644, lp_ctx);
+       talloc_unlink(mem_ctx, lp_ctx);
        if (db == NULL) {
                DEBUG(1, ("could not open serverid.tdb: %s\n",
                          strerror(errno)));
@@ -62,7 +76,8 @@ static struct db_context *serverid_db(void)
                return db;
        }
        db = db_open(NULL, lock_path("serverid.tdb"), 0,
-                    TDB_DEFAULT|TDB_CLEAR_IF_FIRST|TDB_INCOMPATIBLE_HASH, O_RDWR|O_CREAT, 0644);
+                    TDB_DEFAULT|TDB_CLEAR_IF_FIRST|TDB_INCOMPATIBLE_HASH,
+                    O_RDWR|O_CREAT, 0644, DBWRAP_LOCK_ORDER_2);
        return db;
 }
 
@@ -71,9 +86,8 @@ static void serverid_fill_key(const struct server_id *id,
 {
        ZERO_STRUCTP(key);
        key->pid = id->pid;
-#ifdef CLUSTER_SUPPORT
+       key->task_id = id->task_id;
        key->vnn = id->vnn;
-#endif
 }
 
 bool serverid_register(const struct server_id id, uint32_t msg_flags)
@@ -94,7 +108,7 @@ bool serverid_register(const struct server_id id, uint32_t msg_flags)
        serverid_fill_key(&id, &key);
        tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
 
-       rec = db->fetch_locked(db, talloc_tos(), tdbkey);
+       rec = dbwrap_fetch_locked(db, talloc_tos(), tdbkey);
        if (rec == NULL) {
                DEBUG(1, ("Could not fetch_lock serverid.tdb record\n"));
                return false;
@@ -105,12 +119,17 @@ bool serverid_register(const struct server_id id, uint32_t msg_flags)
        data.msg_flags = msg_flags;
 
        tdbdata = make_tdb_data((uint8_t *)&data, sizeof(data));
-       status = rec->store(rec, tdbdata, 0);
+       status = dbwrap_record_store(rec, tdbdata, 0);
        if (!NT_STATUS_IS_OK(status)) {
                DEBUG(1, ("Storing serverid.tdb record failed: %s\n",
                          nt_errstr(status)));
                goto done;
        }
+#ifdef HAVE_CTDB_CONTROL_CHECK_SRVIDS_DECL
+       if (lp_clustering()) {
+               register_with_ctdbd(messaging_ctdbd_connection(), id.unique_id);
+       }
+#endif
        ret = true;
 done:
        TALLOC_FREE(rec);
@@ -125,6 +144,7 @@ bool serverid_register_msg_flags(const struct server_id id, bool do_reg,
        struct serverid_data *data;
        struct db_record *rec;
        TDB_DATA tdbkey;
+       TDB_DATA value;
        NTSTATUS status;
        bool ret = false;
 
@@ -136,20 +156,22 @@ bool serverid_register_msg_flags(const struct server_id id, bool do_reg,
        serverid_fill_key(&id, &key);
        tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
 
-       rec = db->fetch_locked(db, talloc_tos(), tdbkey);
+       rec = dbwrap_fetch_locked(db, talloc_tos(), tdbkey);
        if (rec == NULL) {
                DEBUG(1, ("Could not fetch_lock serverid.tdb record\n"));
                return false;
        }
 
-       if (rec->value.dsize != sizeof(struct serverid_data)) {
+       value = dbwrap_record_get_value(rec);
+
+       if (value.dsize != sizeof(struct serverid_data)) {
                DEBUG(1, ("serverid record has unexpected size %d "
-                         "(wanted %d)\n", (int)rec->value.dsize,
+                         "(wanted %d)\n", (int)value.dsize,
                          (int)sizeof(struct serverid_data)));
                goto done;
        }
 
-       data = (struct serverid_data *)rec->value.dptr;
+       data = (struct serverid_data *)value.dptr;
 
        if (do_reg) {
                data->msg_flags |= msg_flags;
@@ -157,7 +179,7 @@ bool serverid_register_msg_flags(const struct server_id id, bool do_reg,
                data->msg_flags &= ~msg_flags;
        }
 
-       status = rec->store(rec, rec->value, 0);
+       status = dbwrap_record_store(rec, value, 0);
        if (!NT_STATUS_IS_OK(status)) {
                DEBUG(1, ("Storing serverid.tdb record failed: %s\n",
                          nt_errstr(status)));
@@ -186,13 +208,13 @@ bool serverid_deregister(struct server_id id)
        serverid_fill_key(&id, &key);
        tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
 
-       rec = db->fetch_locked(db, talloc_tos(), tdbkey);
+       rec = dbwrap_fetch_locked(db, talloc_tos(), tdbkey);
        if (rec == NULL) {
                DEBUG(1, ("Could not fetch_lock serverid.tdb record\n"));
                return false;
        }
 
-       status = rec->delete_rec(rec);
+       status = dbwrap_record_delete(rec);
        if (!NT_STATUS_IS_OK(status)) {
                DEBUG(1, ("Deleting serverid.tdb record failed: %s\n",
                          nt_errstr(status)));
@@ -209,13 +231,14 @@ struct serverid_exists_state {
        bool exists;
 };
 
-static int server_exists_parse(TDB_DATA key, TDB_DATA data, void *priv)
+static void server_exists_parse(TDB_DATA key, TDB_DATA data, void *priv)
 {
        struct serverid_exists_state *state =
                (struct serverid_exists_state *)priv;
 
        if (data.dsize != sizeof(struct serverid_data)) {
-               return -1;
+               state->exists = false;
+               return;
        }
 
        /*
@@ -224,7 +247,6 @@ static int server_exists_parse(TDB_DATA key, TDB_DATA data, void *priv)
         */
        state->exists = (memcmp(&state->id->unique_id, data.dptr,
                                sizeof(state->id->unique_id)) == 0);
-       return 0;
 }
 
 bool serverid_exists(const struct server_id *id)
@@ -233,11 +255,20 @@ bool serverid_exists(const struct server_id *id)
        struct serverid_exists_state state;
        struct serverid_key key;
        TDB_DATA tdbkey;
+       NTSTATUS status;
+
+       if (procid_is_me(id)) {
+               return true;
+       }
 
-       if (lp_clustering() && !process_exists(*id)) {
+       if (!process_exists(*id)) {
                return false;
        }
 
+       if (id->unique_id == SERVERID_UNIQUE_ID_NOT_TO_VERIFY) {
+               return true;
+       }
+
        db = serverid_db();
        if (db == NULL) {
                return false;
@@ -249,36 +280,90 @@ bool serverid_exists(const struct server_id *id)
        state.id = id;
        state.exists = false;
 
-       if (db->parse_record(db, tdbkey, server_exists_parse, &state) == -1) {
+       status = dbwrap_parse_record(db, tdbkey, server_exists_parse, &state);
+       if (!NT_STATUS_IS_OK(status)) {
                return false;
        }
        return state.exists;
 }
 
+bool serverids_exist(const struct server_id *ids, int num_ids, bool *results)
+{
+       struct db_context *db;
+       int i;
+
+#ifdef HAVE_CTDB_CONTROL_CHECK_SRVIDS_DECL
+       if (lp_clustering()) {
+               return ctdb_serverids_exist(messaging_ctdbd_connection(),
+                                           ids, num_ids, results);
+       }
+#endif
+       if (!processes_exist(ids, num_ids, results)) {
+               return false;
+       }
+
+       db = serverid_db();
+       if (db == NULL) {
+               return false;
+       }
+
+       for (i=0; i<num_ids; i++) {
+               struct serverid_exists_state state;
+               struct serverid_key key;
+               TDB_DATA tdbkey;
+               NTSTATUS status;
+
+               if (ids[i].unique_id == SERVERID_UNIQUE_ID_NOT_TO_VERIFY) {
+                       results[i] = true;
+                       continue;
+               }
+               if (!results[i]) {
+                       continue;
+               }
+
+               serverid_fill_key(&ids[i], &key);
+               tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
+
+               state.id = &ids[i];
+               state.exists = false;
+               status = dbwrap_parse_record(db, tdbkey, server_exists_parse, &state);
+               if (!NT_STATUS_IS_OK(status)) {
+                       results[i] = false;
+                       continue;
+               }
+               results[i] = state.exists;
+       }
+       return true;
+}
+
 static bool serverid_rec_parse(const struct db_record *rec,
                               struct server_id *id, uint32_t *msg_flags)
 {
        struct serverid_key key;
        struct serverid_data data;
+       TDB_DATA tdbkey;
+       TDB_DATA tdbdata;
 
-       if (rec->key.dsize != sizeof(key)) {
+       tdbkey = dbwrap_record_get_key(rec);
+       tdbdata = dbwrap_record_get_value(rec);
+
+       if (tdbkey.dsize != sizeof(key)) {
                DEBUG(1, ("Found invalid key length %d in serverid.tdb\n",
-                         (int)rec->key.dsize));
+                         (int)tdbkey.dsize));
                return false;
        }
-       if (rec->value.dsize != sizeof(data)) {
+       if (tdbdata.dsize != sizeof(data)) {
                DEBUG(1, ("Found invalid value length %d in serverid.tdb\n",
-                         (int)rec->value.dsize));
+                         (int)tdbdata.dsize));
                return false;
        }
 
-       memcpy(&key, rec->key.dptr, sizeof(key));
-       memcpy(&data, rec->value.dptr, sizeof(data));
+       memcpy(&key, tdbkey.dptr, sizeof(key));
+       memcpy(&data, tdbdata.dptr, sizeof(data));
 
        id->pid = key.pid;
-#ifdef CLUSTER_SUPPORT
+       id->task_id = key.task_id;
        id->vnn = key.vnn;
-#endif
        id->unique_id = data.unique_id;
        *msg_flags = data.msg_flags;
        return true;
@@ -309,6 +394,7 @@ bool serverid_traverse_read(int (*fn)(const struct server_id *id,
 {
        struct db_context *db;
        struct serverid_traverse_read_state state;
+       NTSTATUS status;
 
        db = serverid_db();
        if (db == NULL) {
@@ -316,7 +402,10 @@ bool serverid_traverse_read(int (*fn)(const struct server_id *id,
        }
        state.fn = fn;
        state.private_data = private_data;
-       return db->traverse_read(db, serverid_traverse_read_fn, &state);
+
+       status = dbwrap_traverse_read(db, serverid_traverse_read_fn, &state,
+                                     NULL);
+       return NT_STATUS_IS_OK(status);
 }
 
 struct serverid_traverse_state {
@@ -345,6 +434,7 @@ bool serverid_traverse(int (*fn)(struct db_record *rec,
 {
        struct db_context *db;
        struct serverid_traverse_state state;
+       NTSTATUS status;
 
        db = serverid_db();
        if (db == NULL) {
@@ -352,5 +442,40 @@ bool serverid_traverse(int (*fn)(struct db_record *rec,
        }
        state.fn = fn;
        state.private_data = private_data;
-       return db->traverse(db, serverid_traverse_fn, &state);
+
+       status = dbwrap_traverse(db, serverid_traverse_fn, &state, NULL);
+       return NT_STATUS_IS_OK(status);
+}
+
+uint64_t serverid_get_random_unique_id(void)
+{
+       uint64_t unique_id = SERVERID_UNIQUE_ID_NOT_TO_VERIFY;
+
+       while (unique_id == SERVERID_UNIQUE_ID_NOT_TO_VERIFY) {
+               generate_random_buffer((uint8_t *)&unique_id,
+                                      sizeof(unique_id));
+       }
+
+       return unique_id;
+}
+
+bool serverid_equal(const struct server_id *p1, const struct server_id *p2)
+{
+       if (p1->pid != p2->pid) {
+               return false;
+       }
+
+       if (p1->task_id != p2->task_id) {
+               return false;
+       }
+
+       if (p1->vnn != p2->vnn) {
+               return false;
+       }
+
+       if (p1->unique_id != p2->unique_id) {
+               return false;
+       }
+
+       return true;
 }