python/tests/dsdb: Add tests for RID allocation functions
authorJoseph Sutton <josephsutton@catalyst.net.nz>
Mon, 24 May 2021 04:46:28 +0000 (16:46 +1200)
committerKarolin Seeger <kseeger@samba.org>
Tue, 13 Jul 2021 12:31:15 +0000 (12:31 +0000)
BUG: https://bugzilla.samba.org/show_bug.cgi?id=14669

Signed-off-by: Joseph Sutton <josephsutton@catalyst.net.nz>
Reviewed-by: Andrew Bartlett <abartlet@samba.org>
Reviewed-by: Douglas Bagnall <douglas.bagnall@catalyst.net.nz>
(cherry picked from commit 7c7cad81844950c3efe9a540a47b9d4e1ce1b2a1)

python/samba/tests/dsdb.py

index 33cfcc1227151d8b9226647f1d96668e75796145..f1d0557743ebe37348b14a8b6dab93c01630c010 100644 (file)
@@ -50,13 +50,316 @@ class DsdbTests(TestCase):
 
         base_dn = self.samdb.domain_dn()
 
-        self.account_dn = "cn=" + user_name + ",cn=Users," + base_dn
+        self.account_dn = "CN=" + user_name + ",CN=Users," + base_dn
         self.samdb.newuser(username=user_name,
                            password=user_pass,
                            description=user_description)
         # Cleanup (teardown)
         self.addCleanup(delete_force, self.samdb, self.account_dn)
 
+        # Get server reference DN
+        res = self.samdb.search(base=ldb.Dn(self.samdb,
+                                            self.samdb.get_serverName()),
+                                scope=ldb.SCOPE_BASE,
+                                attrs=["serverReference"])
+        # Get server reference
+        self.server_ref_dn = ldb.Dn(
+            self.samdb, res[0]["serverReference"][0].decode("utf-8"))
+
+        # Get RID Set DN
+        res = self.samdb.search(base=self.server_ref_dn,
+                                scope=ldb.SCOPE_BASE,
+                                attrs=["rIDSetReferences"])
+        rid_set_refs = res[0]
+        self.assertIn("rIDSetReferences", rid_set_refs)
+        rid_set_str = rid_set_refs["rIDSetReferences"][0].decode("utf-8")
+        self.rid_set_dn = ldb.Dn(self.samdb, rid_set_str)
+
+    def get_rid_set(self, rid_set_dn):
+        res = self.samdb.search(base=rid_set_dn,
+                                scope=ldb.SCOPE_BASE,
+                                attrs=["rIDAllocationPool",
+                                       "rIDPreviousAllocationPool",
+                                       "rIDUsedPool",
+                                       "rIDNextRID"])
+        return res[0]
+
+    def test_ridalloc_next_free_rid(self):
+        # Test RID allocation. We assume that RID
+        # pools allocated to us are continguous.
+        self.samdb.transaction_start()
+        try:
+            orig_rid_set = self.get_rid_set(self.rid_set_dn)
+            self.assertIn("rIDAllocationPool", orig_rid_set)
+            self.assertIn("rIDPreviousAllocationPool", orig_rid_set)
+            self.assertIn("rIDUsedPool", orig_rid_set)
+            self.assertIn("rIDNextRID", orig_rid_set)
+
+            # Get rIDNextRID value from RID set.
+            next_rid = int(orig_rid_set["rIDNextRID"][0])
+
+            # Check the result of next_free_rid().
+            next_free_rid = self.samdb.next_free_rid()
+            self.assertEqual(next_rid + 1, next_free_rid)
+
+            # Check calling it twice in succession gives the same result.
+            next_free_rid2 = self.samdb.next_free_rid()
+            self.assertEqual(next_free_rid, next_free_rid2)
+
+            # Ensure that the RID set attributes have not changed.
+            rid_set2 = self.get_rid_set(self.rid_set_dn)
+            self.assertEqual(orig_rid_set, rid_set2)
+        finally:
+            self.samdb.transaction_cancel()
+
+    def test_ridalloc_no_ridnextrid(self):
+        self.samdb.transaction_start()
+        try:
+            # Delete the rIDNextRID attribute of the RID set,
+            # and set up previous and next pools.
+            prev_lo = 1000
+            prev_hi = 1999
+            next_lo = 3000
+            next_hi = 3999
+            msg = ldb.Message()
+            msg.dn = self.rid_set_dn
+            msg["rIDNextRID"] = ldb.MessageElement([],
+                                                   ldb.FLAG_MOD_DELETE,
+                                                   "rIDNextRID")
+            msg["rIDPreviousAllocationPool"] = (
+                ldb.MessageElement(str((prev_hi << 32) | prev_lo),
+                                   ldb.FLAG_MOD_REPLACE,
+                                   "rIDPreviousAllocationPool"))
+            msg["rIDAllocationPool"] = (
+                ldb.MessageElement(str((next_hi << 32) | next_lo),
+                                   ldb.FLAG_MOD_REPLACE,
+                                   "rIDAllocationPool"))
+            self.samdb.modify(msg)
+
+            # Ensure that next_free_rid() returns the start of the next pool
+            # plus one.
+            next_free_rid3 = self.samdb.next_free_rid()
+            self.assertEqual(next_lo + 1, next_free_rid3)
+
+            # Check the result of allocate_rid() matches.
+            rid = self.samdb.allocate_rid()
+            self.assertEqual(next_free_rid3, rid)
+
+            # Check that the result of next_free_rid() has now changed.
+            next_free_rid4 = self.samdb.next_free_rid()
+            self.assertEqual(rid + 1, next_free_rid4)
+
+            # Check the range of available RIDs.
+            free_lo, free_hi = self.samdb.free_rid_bounds()
+            self.assertEqual(rid + 1, free_lo)
+            self.assertEqual(next_hi, free_hi)
+        finally:
+            self.samdb.transaction_cancel()
+
+    def test_ridalloc_no_free_rids(self):
+        self.samdb.transaction_start()
+        try:
+            # Exhaust our current pool of RIDs.
+            pool_lo = 2000
+            pool_hi = 2999
+            msg = ldb.Message()
+            msg.dn = self.rid_set_dn
+            msg["rIDPreviousAllocationPool"] = (
+                ldb.MessageElement(str((pool_hi << 32) | pool_lo),
+                                   ldb.FLAG_MOD_REPLACE,
+                                   "rIDPreviousAllocationPool"))
+            msg["rIDAllocationPool"] = (
+                ldb.MessageElement(str((pool_hi << 32) | pool_lo),
+                                   ldb.FLAG_MOD_REPLACE,
+                                   "rIDAllocationPool"))
+            msg["rIDNextRID"] = (
+            ldb.MessageElement(str(pool_hi),
+                                   ldb.FLAG_MOD_REPLACE,
+                                   "rIDNextRID"))
+            self.samdb.modify(msg)
+
+            # Ensure that calculating the next free RID fails.
+            with self.assertRaises(ldb.LdbError) as err:
+                self.samdb.next_free_rid()
+
+            self.assertEqual("RID pools out of RIDs", err.exception.args[1])
+
+            # Ensure we can still allocate a new RID.
+            self.samdb.allocate_rid()
+        finally:
+            self.samdb.transaction_cancel()
+
+    def test_ridalloc_new_ridset(self):
+        self.samdb.transaction_start()
+        try:
+            # Test what happens with RID Set values set to zero (similar to
+            # when a RID Set is first created, except we also set
+            # rIDAllocationPool to zero).
+            msg = ldb.Message()
+            msg.dn = self.rid_set_dn
+            msg["rIDPreviousAllocationPool"] = (
+                ldb.MessageElement("0",
+                                   ldb.FLAG_MOD_REPLACE,
+                                   "rIDPreviousAllocationPool"))
+            msg["rIDAllocationPool"] = (
+                ldb.MessageElement("0",
+                                   ldb.FLAG_MOD_REPLACE,
+                                   "rIDAllocationPool"))
+            msg["rIDNextRID"] = (
+                ldb.MessageElement("0",
+                                   ldb.FLAG_MOD_REPLACE,
+                                   "rIDNextRID"))
+            self.samdb.modify(msg)
+
+            # Ensure that calculating the next free RID fails.
+            with self.assertRaises(ldb.LdbError) as err:
+                self.samdb.next_free_rid()
+
+            self.assertEqual("RID pools out of RIDs", err.exception.args[1])
+
+            # Set values for the next pool.
+            pool_lo = 2000
+            pool_hi = 2999
+            msg = ldb.Message()
+            msg.dn = self.rid_set_dn
+            msg["rIDAllocationPool"] = (
+                ldb.MessageElement(str((pool_hi << 32) | pool_lo),
+                                   ldb.FLAG_MOD_REPLACE,
+                                   "rIDAllocationPool"))
+            self.samdb.modify(msg)
+
+            # Ensure the next free RID value is equal to the next pool's lower
+            # bound.
+            next_free_rid5 = self.samdb.next_free_rid()
+            self.assertEqual(pool_lo, next_free_rid5)
+
+            # Check the range of available RIDs.
+            free_lo, free_hi = self.samdb.free_rid_bounds()
+            self.assertEqual(pool_lo, free_lo)
+            self.assertEqual(pool_hi, free_hi)
+        finally:
+            self.samdb.transaction_cancel()
+
+    def test_ridalloc_move_to_new_pool(self):
+        self.samdb.transaction_start()
+        try:
+            # Test moving to a new pool from the previous pool.
+            pool_lo = 2000
+            pool_hi = 2999
+            new_pool_lo = 4500
+            new_pool_hi = 4599
+            msg = ldb.Message()
+            msg.dn = self.rid_set_dn
+            msg["rIDPreviousAllocationPool"] = (
+                ldb.MessageElement(str((pool_hi << 32) | pool_lo),
+                                   ldb.FLAG_MOD_REPLACE,
+                                   "rIDPreviousAllocationPool"))
+            msg["rIDAllocationPool"] = (
+                ldb.MessageElement(str((new_pool_hi << 32) | new_pool_lo),
+                                   ldb.FLAG_MOD_REPLACE,
+                                   "rIDAllocationPool"))
+            msg["rIDNextRID"] = (
+                ldb.MessageElement(str(pool_hi - 1),
+                                   ldb.FLAG_MOD_REPLACE,
+                                   "rIDNextRID"))
+            self.samdb.modify(msg)
+
+            # We should have remained in the previous pool.
+            next_free_rid6 = self.samdb.next_free_rid()
+            self.assertEqual(pool_hi, next_free_rid6)
+
+            # Check the range of available RIDs.
+            free_lo, free_hi = self.samdb.free_rid_bounds()
+            self.assertEqual(pool_hi, free_lo)
+            self.assertEqual(pool_hi, free_hi)
+
+            # Allocate a new RID.
+            rid2 = self.samdb.allocate_rid()
+            self.assertEqual(next_free_rid6, rid2)
+
+            # We should now move to the next pool.
+            next_free_rid7 = self.samdb.next_free_rid()
+            self.assertEqual(new_pool_lo, next_free_rid7)
+
+            # Check the new range of available RIDs.
+            free_lo2, free_hi2 = self.samdb.free_rid_bounds()
+            self.assertEqual(new_pool_lo, free_lo2)
+            self.assertEqual(new_pool_hi, free_hi2)
+
+            # Ensure that allocate_rid() matches.
+            rid3 = self.samdb.allocate_rid()
+            self.assertEqual(next_free_rid7, rid3)
+        finally:
+            self.samdb.transaction_cancel()
+
+    def test_ridalloc_no_ridsetreferences(self):
+        self.samdb.transaction_start()
+        try:
+            # Delete the rIDSetReferences attribute.
+            msg = ldb.Message()
+            msg.dn = self.server_ref_dn
+            msg["rIDSetReferences"] = (
+                ldb.MessageElement([],
+                                   ldb.FLAG_MOD_DELETE,
+                                   "rIDSetReferences"))
+            self.samdb.modify(msg)
+
+            # Ensure calculating the next free RID fails.
+            with self.assertRaises(ldb.LdbError) as err:
+                self.samdb.next_free_rid()
+
+            enum, estr = err.exception.args
+            self.assertEqual(ldb.ERR_NO_SUCH_ATTRIBUTE, enum)
+            self.assertIn("No RID Set DN - "
+                          "Cannot find attribute rIDSetReferences of %s "
+                          "to calculate reference dn" % self.server_ref_dn,
+                          estr)
+
+            # Ensure allocating a new RID fails.
+            with self.assertRaises(ldb.LdbError) as err:
+                self.samdb.allocate_rid()
+
+            enum, estr = err.exception.args
+            self.assertEqual(ldb.ERR_ENTRY_ALREADY_EXISTS, enum)
+            self.assertIn("No RID Set DN - "
+                          "Failed to add RID Set %s - "
+                          "Entry %s already exists" %
+                          (self.rid_set_dn, self.rid_set_dn),
+                          estr)
+        finally:
+            self.samdb.transaction_cancel()
+
+    def test_ridalloc_no_rid_set(self):
+        self.samdb.transaction_start()
+        try:
+            # Set the rIDSetReferences attribute to not point to a RID Set.
+            fake_rid_set_str = self.account_dn
+            msg = ldb.Message()
+            msg.dn = self.server_ref_dn
+            msg["rIDSetReferences"] = (
+                ldb.MessageElement(fake_rid_set_str,
+                                   ldb.FLAG_MOD_REPLACE,
+                                   "rIDSetReferences"))
+            self.samdb.modify(msg)
+
+            # Ensure calculating the next free RID fails.
+            with self.assertRaises(ldb.LdbError) as err:
+                self.samdb.next_free_rid()
+
+            enum, estr = err.exception.args
+            self.assertEqual(ldb.ERR_OPERATIONS_ERROR, enum)
+            self.assertIn("Bad RID Set " + fake_rid_set_str, estr)
+
+            # Ensure allocating a new RID fails.
+            with self.assertRaises(ldb.LdbError) as err:
+                self.samdb.allocate_rid()
+
+            enum, estr = err.exception.args
+            self.assertEqual(ldb.ERR_OPERATIONS_ERROR, enum)
+            self.assertIn("Bad RID Set " + fake_rid_set_str,  estr)
+        finally:
+            self.samdb.transaction_cancel()
+
     def test_get_oid_from_attrid(self):
         oid = self.samdb.get_oid_from_attid(591614)
         self.assertEqual(oid, "1.2.840.113556.1.4.1790")