tests/gkdi: Allow current time to be overridden
authorJo Sutton <josutton@catalyst.net.nz>
Tue, 26 Mar 2024 03:25:31 +0000 (16:25 +1300)
committerAndrew Bartlett <abartlet@samba.org>
Tue, 16 Apr 2024 03:58:31 +0000 (03:58 +0000)
Signed-off-by: Jo Sutton <josutton@catalyst.net.nz>
Reviewed-by: Andrew Bartlett <abartlet@samba.org>
python/samba/tests/gkdi.py
python/samba/tests/krb5/gkdi_tests.py

index ac62eb51b70a60b26066e113a4b6f00fc29282db..d16ae9a1066cf9651dd15574a2aa6a853da05069 100644 (file)
@@ -57,6 +57,7 @@ from samba.hresult import (
 )
 from samba.ndr import ndr_pack, ndr_unpack
 from samba.nt_time import (
+    datetime_from_nt_time,
     nt_time_from_datetime,
     NtTime,
     NtTimeDelta,
@@ -74,6 +75,8 @@ RootKey = NewType("RootKey", ldb.Message)
 
 ROOT_KEY_START_TIME = NtTime(KEY_CYCLE_DURATION + MAX_CLOCK_SKEW)
 
+DSDB_GMSA_TIME_OPAQUE = "dsdb_gmsa_time_opaque"
+
 
 class GetKeyError(Exception):
     def __init__(self, status: HResult, message: str):
@@ -89,24 +92,39 @@ class GkdiBaseTest(TestCase):
         b"\x01\x01\x00\x00\x00\x00\x00\x05\x12\x00\x00\x00"
     )
 
-    @staticmethod
-    def current_time(offset: Optional[datetime.timedelta] = None) -> datetime.datetime:
-        current_time = datetime.datetime.now(tz=datetime.timezone.utc)
+    def set_db_time(self, samdb: SamDB, time: Optional[NtTime]) -> None:
+        samdb.set_opaque(DSDB_GMSA_TIME_OPAQUE, time)
+
+    def get_db_time(self, samdb: SamDB) -> Optional[NtTime]:
+        return samdb.get_opaque(DSDB_GMSA_TIME_OPAQUE)
+
+    def current_time(
+        self, samdb: SamDB, *, offset: Optional[datetime.timedelta] = None
+    ) -> datetime.datetime:
+        now = self.get_db_time(samdb)
+        if now is None:
+            current_time = datetime.datetime.now(tz=datetime.timezone.utc)
+        else:
+            current_time = datetime_from_nt_time(now)
 
         if offset is not None:
             current_time += offset
 
         return current_time
 
-    def current_nt_time(self, offset: Optional[datetime.timedelta] = None) -> NtTime:
-        return nt_time_from_datetime(self.current_time(offset))
+    def current_nt_time(
+        self, samdb: SamDB, *, offset: Optional[datetime.timedelta] = None
+    ) -> NtTime:
+        return nt_time_from_datetime(self.current_time(samdb, offset=offset))
 
-    def current_gkid(self, offset: Optional[datetime.timedelta] = None) -> Gkid:
+    def current_gkid(
+        self, samdb: SamDB, *, offset: Optional[datetime.timedelta] = None
+    ) -> Gkid:
         if offset is None:
             # Allow for clock skew.
             offset = timedelta_from_nt_time_delta(MAX_CLOCK_SKEW)
 
-        return Gkid.from_nt_time(self.current_nt_time(offset))
+        return Gkid.from_nt_time(self.current_nt_time(samdb, offset=offset))
 
     def gkdi_connect(
         self, host: str, lp: LoadParm, server_creds: Credentials
@@ -287,7 +305,7 @@ class GkdiBaseTest(TestCase):
         particular root key to use."""
 
         if current_gkid is None:
-            current_gkid = self.current_gkid()
+            current_gkid = self.current_gkid(samdb)
 
         root_key_specified = root_key_id is not None
         if root_key_specified:
@@ -367,7 +385,7 @@ class GkdiBaseTest(TestCase):
         current_gkid: Optional[Gkid] = None,
     ) -> GroupKey:
         if current_gkid is None:
-            current_gkid = self.current_gkid()
+            current_gkid = self.current_gkid(samdb)
 
         root_key_specified = root_key_id is not None
         self.validate_get_key_request(gkid, current_gkid, root_key_specified)
@@ -523,8 +541,9 @@ class GkdiBaseTest(TestCase):
             samdb,
             domain_dn,
             current_nt_time=self.current_nt_time(
+                samdb,
                 # Allow for clock skew.
-                timedelta_from_nt_time_delta(MAX_CLOCK_SKEW)
+                offset=timedelta_from_nt_time_delta(MAX_CLOCK_SKEW),
             ),
             use_start_time=use_start_time,
             hash_algorithm=hash_algorithm,
index 58a65c4c764a26ad4ffcf6a397ab41244bcf2e86..accaca0bc1cdd9cc1adce62a562d83962163eed3 100755 (executable)
@@ -92,7 +92,7 @@ class GkdiExplicitRootKeyTests(GkdiKdcBaseTest):
         # It actually doesn’t matter what we specify for the L1 and L2 indices.
         # We’ll get the same result regardless — they just cannot specify a key
         # from the future.
-        current_gkid = self.current_gkid()
+        current_gkid = self.current_gkid(self.get_samdb())
         key = self.check_rpc_get_key(root_key_id, current_gkid)
 
         self.assertEqual(current_gkid, key.gkid)
@@ -104,7 +104,7 @@ class GkdiExplicitRootKeyTests(GkdiKdcBaseTest):
 
         # It actually doesn’t matter what we specify for the L1 and L2 indices.
         # We’ll get the same result regardless.
-        previous_l0_idx = self.current_gkid().l0_idx - 1
+        previous_l0_idx = self.current_gkid(self.get_samdb()).l0_idx - 1
         key = self.check_rpc_get_key(root_key_id, Gkid(previous_l0_idx, 0, 0))
 
         # Expect to get an L1 seed key.
@@ -117,7 +117,7 @@ class GkdiExplicitRootKeyTests(GkdiKdcBaseTest):
         """Test with the SHA1 algorithm."""
         key = self.check_rpc_get_key(
             self.new_root_key(hash_algorithm=Algorithm.SHA1),
-            self.current_gkid(),
+            self.current_gkid(self.get_samdb()),
         )
         self.assertIs(Algorithm.SHA1, key.hash_algorithm)
 
@@ -125,7 +125,7 @@ class GkdiExplicitRootKeyTests(GkdiKdcBaseTest):
         """Test with the SHA256 algorithm."""
         key = self.check_rpc_get_key(
             self.new_root_key(hash_algorithm=Algorithm.SHA256),
-            self.current_gkid(),
+            self.current_gkid(self.get_samdb()),
         )
         self.assertIs(Algorithm.SHA256, key.hash_algorithm)
 
@@ -133,7 +133,7 @@ class GkdiExplicitRootKeyTests(GkdiKdcBaseTest):
         """Test with the SHA384 algorithm."""
         key = self.check_rpc_get_key(
             self.new_root_key(hash_algorithm=Algorithm.SHA384),
-            self.current_gkid(),
+            self.current_gkid(self.get_samdb()),
         )
         self.assertIs(Algorithm.SHA384, key.hash_algorithm)
 
@@ -141,7 +141,7 @@ class GkdiExplicitRootKeyTests(GkdiKdcBaseTest):
         """Test with the SHA512 algorithm."""
         key = self.check_rpc_get_key(
             self.new_root_key(hash_algorithm=Algorithm.SHA512),
-            self.current_gkid(),
+            self.current_gkid(self.get_samdb()),
         )
         self.assertIs(Algorithm.SHA512, key.hash_algorithm)
 
@@ -149,7 +149,7 @@ class GkdiExplicitRootKeyTests(GkdiKdcBaseTest):
         """Test without a specified algorithm."""
         key = self.check_rpc_get_key(
             self.new_root_key(hash_algorithm=None),
-            self.current_gkid(),
+            self.current_gkid(self.get_samdb()),
         )
         self.assertIs(Algorithm.SHA256, key.hash_algorithm)
 
@@ -158,6 +158,7 @@ class GkdiExplicitRootKeyTests(GkdiKdcBaseTest):
         root_key_id = self.new_root_key(use_start_time=ROOT_KEY_START_TIME)
 
         future_gkid = self.current_gkid(
+            self.get_samdb(),
             offset=timedelta_from_nt_time_delta(
                 NtTimeDelta(KEY_CYCLE_DURATION + MAX_CLOCK_SKEW)
             )
@@ -185,7 +186,7 @@ class GkdiExplicitRootKeyTests(GkdiKdcBaseTest):
         """Attempt to use a root key with an effective time of zero."""
         root_key_id = self.new_root_key(use_start_time=NtTime(0))
 
-        gkid = self.current_gkid()
+        gkid = self.current_gkid(self.get_samdb())
 
         with self.assertRaises(GetKeyError) as err:
             self.get_key(self.get_samdb(), self.gmsa_sd, root_key_id, gkid)
@@ -209,7 +210,7 @@ class GkdiExplicitRootKeyTests(GkdiKdcBaseTest):
         """Attempt to use a root key with an effective time set too low."""
         root_key_id = self.new_root_key(use_start_time=NtTime(ROOT_KEY_START_TIME - 1))
 
-        gkid = self.current_gkid()
+        gkid = self.current_gkid(self.get_samdb())
 
         with self.assertRaises(GetKeyError) as err:
             self.get_key(self.get_samdb(), self.gmsa_sd, root_key_id, gkid)
@@ -233,7 +234,7 @@ class GkdiExplicitRootKeyTests(GkdiKdcBaseTest):
 
     def test_before_valid(self):
         """Attempt to use a key before it is valid."""
-        gkid = self.current_gkid()
+        gkid = self.current_gkid(self.get_samdb())
         valid_start_time = NtTime(
             gkid.start_nt_time() + KEY_CYCLE_DURATION + MAX_CLOCK_SKEW
         )
@@ -268,7 +269,7 @@ class GkdiExplicitRootKeyTests(GkdiKdcBaseTest):
         """Attempt to use a non‐existent root key."""
         root_key_id = misc.GUID(secrets.token_bytes(16))
 
-        gkid = self.current_gkid()
+        gkid = self.current_gkid(self.get_samdb())
 
         with self.assertRaises(GetKeyError) as err:
             self.get_key(self.get_samdb(), self.gmsa_sd, root_key_id, gkid)
@@ -292,7 +293,7 @@ class GkdiExplicitRootKeyTests(GkdiKdcBaseTest):
         """Attempt to use a root key that is the wrong length."""
         root_key_id = self.new_root_key(data=bytes(KEY_LEN_BYTES // 2))
 
-        gkid = self.current_gkid()
+        gkid = self.current_gkid(self.get_samdb())
 
         with self.assertRaises(GetKeyError) as err:
             self.get_key(self.get_samdb(), self.gmsa_sd, root_key_id, gkid)
@@ -724,7 +725,7 @@ class GkdiSelfTests(GkdiKdcBaseTest):
             self.gmsa_sd,
             root_key_id,
             gkid,
-            current_gkid=self.current_gkid(),
+            current_gkid=self.current_gkid(self.get_samdb()),
         )
 
         self.assertEqual(gkid, key.gkid)