tests/krb5: Add tests for gMSAs
authorJo Sutton <josutton@catalyst.net.nz>
Fri, 5 Apr 2024 00:44:08 +0000 (13:44 +1300)
committerAndrew Bartlett <abartlet@samba.org>
Tue, 16 Apr 2024 03:58:31 +0000 (03:58 +0000)
Signed-off-by: Jo Sutton <josutton@catalyst.net.nz>
Reviewed-by: Andrew Bartlett <abartlet@samba.org>
python/samba/tests/krb5/gmsa_tests.py [new file with mode: 0755]
selftest/knownfail.d/gmsa
selftest/knownfail_mit_kdc_1_20
source4/selftest/tests.py

diff --git a/python/samba/tests/krb5/gmsa_tests.py b/python/samba/tests/krb5/gmsa_tests.py
new file mode 100755 (executable)
index 0000000..1d3787a
--- /dev/null
@@ -0,0 +1,905 @@
+#!/usr/bin/env python3
+# Unix SMB/CIFS implementation.
+# Copyright (C) Stefan Metzmacher 2020
+# Copyright (C) Catalyst.Net Ltd 2024
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <https://www.gnu.org/licenses/>.
+#
+
+import sys
+import os
+
+sys.path.insert(0, "bin/python")
+os.environ["PYTHONUNBUFFERED"] = "1"
+
+from typing import Iterable, NewType, Optional, Tuple, TypeVar
+
+import datetime
+from itertools import chain
+
+import ldb
+
+from samba import auth, dsdb, gensec
+from samba.dcerpc import gkdi, gmsa, misc, netlogon, security
+from samba.ndr import ndr_pack, ndr_unpack
+from samba.nt_time import (
+    nt_time_delta_from_timedelta,
+    nt_time_from_datetime,
+    NtTime,
+    NtTimeDelta,
+    timedelta_from_nt_time_delta,
+)
+from samba.samdb import SamDB
+from samba.credentials import Credentials, DONT_USE_KERBEROS
+from samba.gkdi import (
+    Gkid,
+    GroupKey,
+    KEY_CYCLE_DURATION,
+)
+
+from samba.tests import connect_samdb
+from samba.tests.krb5 import kcrypto
+from samba.tests.gkdi import GkdiBaseTest, MAX_CLOCK_SKEW
+from samba.tests.krb5.kdc_base_test import KDCBaseTest
+from samba.tests.krb5.raw_testcase import KerberosCredentials
+from samba.tests.krb5.rfc4120_constants import (
+    KU_PA_ENC_TIMESTAMP,
+    NT_PRINCIPAL,
+    PADATA_ENC_TIMESTAMP,
+)
+import samba.tests.krb5.rfc4120_pyasn1 as krb5_asn1
+
+GMSA_DEFAULT_MANAGED_PASSWORD_INTERVAL = 30
+
+Gmsa = NewType("Gmsa", ldb.Message)
+
+
+def gkdi_rollover_interval(managed_password_interval: int) -> NtTimeDelta:
+    rollover_interval = NtTimeDelta(
+        managed_password_interval * 24 // 10 * KEY_CYCLE_DURATION
+    )
+    if rollover_interval == 0:
+        raise ValueError("rollover interval must not be zero")
+    return rollover_interval
+
+
+class GmsaSeries:
+    start_time: NtTime
+    rollover_interval: NtTimeDelta
+
+    def __init__(self, start_gkid: Gkid, rollover_interval: NtTimeDelta) -> None:
+        self.start_time = start_gkid.start_nt_time()
+        self.rollover_interval = rollover_interval
+
+    def interval_gkid(self, n: int) -> Gkid:
+        return Gkid.from_nt_time(self.start_of_interval(n))
+
+    def start_of_interval(self, n: int) -> NtTime:
+        if not isinstance(n, int):
+            raise ValueError(f"{n} must be an integer")
+        return NtTime(int(self.start_time + n * self.rollover_interval))
+
+    def during_interval(self, n: int) -> NtTime:
+        return NtTime(int(self.start_of_interval(n) + self.rollover_interval // 2))
+
+    def during_skew_window(self, n: int) -> NtTime:
+        two_minutes = nt_time_delta_from_timedelta(datetime.timedelta(minutes=2))
+        return NtTime(
+            int(self.start_of_interval(n) + self.rollover_interval - two_minutes)
+        )
+
+
+class GmsaTests(GkdiBaseTest, KDCBaseTest):
+    def _as_req(
+        self,
+        creds: KerberosCredentials,
+        target_creds: KerberosCredentials,
+        enctype: kcrypto.Enctype,
+    ) -> dict:
+        preauth_key = self.PasswordKey_from_creds(creds, enctype)
+
+        def generate_padata_fn(
+            _kdc_exchange_dict: dict, _callback_dict: Optional[dict], req_body: dict
+        ) -> Tuple[list, dict]:
+            padata = []
+
+            patime, pausec = self.get_KerberosTimeWithUsec()
+            enc_ts = self.PA_ENC_TS_ENC_create(patime, pausec)
+            enc_ts = self.der_encode(enc_ts, asn1Spec=krb5_asn1.PA_ENC_TS_ENC())
+
+            enc_ts = self.EncryptedData_create(preauth_key, KU_PA_ENC_TIMESTAMP, enc_ts)
+            enc_ts = self.der_encode(enc_ts, asn1Spec=krb5_asn1.EncryptedData())
+
+            enc_ts = self.PA_DATA_create(PADATA_ENC_TIMESTAMP, enc_ts)
+
+            padata.append(enc_ts)
+
+            return padata, req_body
+
+        user_name = creds.get_username()
+        cname = self.PrincipalName_create(
+            name_type=NT_PRINCIPAL, names=user_name.split("/")
+        )
+
+        target_name = target_creds.get_username()
+        target_realm = target_creds.get_realm()
+
+        sname = self.PrincipalName_create(
+            name_type=NT_PRINCIPAL, names=["host", target_name[:-1]]
+        )
+
+        check_error_fn = None
+        check_rep_fn = self.generic_check_kdc_rep
+
+        expected_sname = self.PrincipalName_create(
+            name_type=NT_PRINCIPAL, names=[target_name]
+        )
+
+        kdc_options = "forwardable,renewable,canonicalize,renewable-ok"
+        kdc_options = krb5_asn1.KDCOptions(kdc_options)
+
+        ticket_decryption_key = self.TicketDecryptionKey_from_creds(target_creds)
+
+        kdc_exchange_dict = self.as_exchange_dict(
+            creds=creds,
+            expected_crealm=creds.get_realm(),
+            expected_cname=cname,
+            expected_srealm=target_realm,
+            expected_sname=expected_sname,
+            expected_supported_etypes=target_creds.tgs_supported_enctypes,
+            ticket_decryption_key=ticket_decryption_key,
+            generate_padata_fn=generate_padata_fn,
+            check_error_fn=check_error_fn,
+            check_rep_fn=check_rep_fn,
+            check_kdc_private_fn=self.generic_check_kdc_private,
+            expected_error_mode=0,
+            expected_salt=creds.get_salt(),
+            preauth_key=preauth_key,
+            kdc_options=str(kdc_options),
+        )
+
+        till = self.get_KerberosTime(offset=36000)
+
+        etypes = kcrypto.Enctype.AES256, kcrypto.Enctype.RC4
+
+        rep = self._generic_kdc_exchange(
+            kdc_exchange_dict,
+            cname=cname,
+            realm=target_realm,
+            sname=sname,
+            till_time=till,
+            etypes=etypes,
+        )
+        self.check_as_reply(rep)
+
+        return kdc_exchange_dict
+
+    # Note: unused
+    def gkdi_get_key_start_time(self, key_id: gkdi.KeyEnvelope) -> NtTime:
+        return Gkid.from_key_envelope(key_id).start_nt_time()
+
+    def get_password(
+        self,
+        samdb: SamDB,
+        target_sd: bytes,
+        root_key_id: Optional[misc.GUID],
+        gkid: Gkid,
+        sid: security.dom_sid,
+    ) -> bytes:
+        group_key = self.get_key_exact(samdb, target_sd, root_key_id, gkid)
+
+        password = self.generate_gmsa_password(group_key, sid)
+        return self.post_process_password_buffer(password)
+
+    def get_password_based_on_gkid(
+        self, samdb: SamDB, gkid: Gkid, sid: security.dom_sid
+    ) -> bytes:
+        return self.get_password(samdb, self.gmsa_sd, None, gkid, sid)
+
+    def get_password_based_on_timestamp(
+        self, samdb: SamDB, timestamp: NtTime, sid: security.dom_sid
+    ) -> bytes:
+        return self.get_password_based_on_gkid(samdb, Gkid.from_nt_time(timestamp), sid)
+
+    # Note: unused
+    def get_password_based_on_key_id(
+        self, samdb: SamDB, managed_password: gkdi.KeyEnvelope, sid: str
+    ) -> bytes:
+        return self.get_password(
+            samdb,
+            self.gmsa_sd,
+            managed_password.root_key_id,
+            Gkid.from_key_envelope(managed_password),
+            sid,
+        )
+
+    def generate_gmsa_password(self, key: GroupKey, sid: str) -> bytes:
+        context = ndr_pack(security.dom_sid(sid))
+        algorithm = key.hash_algorithm.algorithm()
+        gmsa_password_len = 256
+
+        return self.kdf(
+            algorithm,
+            key.key,
+            context,
+            label="GMSA PASSWORD",
+            len_in_bytes=gmsa_password_len,
+        )
+
+    def post_process_password_buffer(self, key: bytes) -> bytes:
+        self.assertEqual(0, len(key) & 1, f"length of key ({len(key)}) is not even")
+
+        def convert_null(t: Tuple[int, int]) -> Tuple[int, int]:
+            if t == (0, 0):
+                return 1, 0
+
+            return t
+
+        T = TypeVar("T")
+
+        def take_pairs(iterable: Iterable[T]) -> Iterable[Tuple[T, T]]:
+            it = iter(iterable)
+            while True:
+                try:
+                    first = next(it)
+                except StopIteration:
+                    break
+
+                yield first, next(it)
+
+        return bytes(chain.from_iterable(map(convert_null, take_pairs(key))))
+
+    def get_gmsa_object(self, samdb: SamDB, dn: ldb.Dn) -> Gmsa:
+        res = samdb.search(
+            dn,
+            scope=ldb.SCOPE_BASE,
+            attrs=[
+                "msDS-ManagedPasswordInterval",
+                "msDS-ManagedPasswordId",
+                "msDS-ManagedPasswordPreviousId",
+                "whenCreated",
+            ],
+        )
+        return res[0]
+
+    def gmsa_rollover_interval(self, gmsa_object: Gmsa) -> NtTimeDelta:
+        managed_password_interval = gmsa_object.get(
+            "msDS-ManagedPasswordInterval", idx=0
+        )
+        if managed_password_interval is None:
+            managed_password_interval = GMSA_DEFAULT_MANAGED_PASSWORD_INTERVAL
+        else:
+            managed_password_interval = int(managed_password_interval)
+
+        return gkdi_rollover_interval(managed_password_interval)
+
+    def gmsa_creation_nt_time(self, gmsa_object: Gmsa) -> NtTime:
+        creation_time: Optional[bytes] = gmsa_object.get("whenCreated", idx=0)
+        self.assertIsNotNone(creation_time)
+        assert creation_time is not None  # to help the type checker
+
+        create_time = datetime.datetime.fromtimestamp(
+            ldb.string_to_time(creation_time.decode()), tz=datetime.timezone.utc
+        )
+        return nt_time_from_datetime(create_time)
+
+    def gmsa_series(self, managed_password_interval: int) -> GmsaSeries:
+        return GmsaSeries(
+            self.future_gkid(), gkdi_rollover_interval(managed_password_interval)
+        )
+
+    def expected_gmsa_password_blob(
+        self,
+        samdb: SamDB,
+        creds: KerberosCredentials,
+        gkid: Gkid,
+        *,
+        query_expiration_gkid: Gkid,
+        previous_gkid: Optional[Gkid] = None,
+        return_future_key: bool = False,
+    ) -> gmsa.MANAGEDPASSWORD_BLOB:
+        new_password = self.get_password_based_on_gkid(samdb, gkid, creds.get_sid())
+        old_password = None
+        if previous_gkid is not None:
+            old_password = self.get_password_based_on_gkid(
+                samdb, previous_gkid, creds.get_sid()
+            )
+
+        current_time = self.current_nt_time(samdb)
+
+        gmsa_object = self.get_gmsa_object(samdb, creds.get_dn())
+        gkdi_rollover_interval = self.gmsa_rollover_interval(gmsa_object)
+
+        query_expiration_time = query_expiration_gkid.start_nt_time()
+        query_password_interval = NtTimeDelta(query_expiration_time - current_time)
+        unchanged_password_interval = NtTimeDelta(
+            max(
+                0,
+                query_expiration_time
+                + (gkdi_rollover_interval if return_future_key else 0)
+                - current_time
+                - MAX_CLOCK_SKEW,
+            )
+        )
+
+        return self.marshal_password(
+            new_password,
+            old_password,
+            query_password_interval,
+            unchanged_password_interval,
+        )
+
+    def expected_current_gmsa_password_blob(
+        self,
+        samdb: SamDB,
+        creds: KerberosCredentials,
+        *,
+        future_key_is_acceptable: bool,
+    ) -> gmsa.MANAGEDPASSWORD_BLOB:
+        gmsa_object = self.get_gmsa_object(samdb, creds.get_dn())
+
+        gkdi_rollover_interval = self.gmsa_rollover_interval(gmsa_object)
+
+        pwd_id_blob = gmsa_object.get("msDS-ManagedPasswordId", idx=0)
+        self.assertIsNotNone(pwd_id_blob, "SAM should have initialized password ID")
+
+        pwd_id = ndr_unpack(gkdi.KeyEnvelope, pwd_id_blob)
+        key_start_time = Gkid.from_key_envelope(pwd_id).start_nt_time()
+
+        current_time = self.current_nt_time(samdb)
+
+        time_since_key_start = NtTimeDelta(current_time - key_start_time)
+        quantized_time_since_key_start = NtTimeDelta(
+            time_since_key_start // gkdi_rollover_interval * gkdi_rollover_interval
+        )
+        new_key_start_time = NtTime(key_start_time + quantized_time_since_key_start)
+        new_key_expiration_time = NtTime(new_key_start_time + gkdi_rollover_interval)
+
+        account_sid = creds.get_sid()
+
+        within_clock_skew_window = (
+            new_key_expiration_time - current_time <= MAX_CLOCK_SKEW
+        )
+        return_future_key = future_key_is_acceptable and within_clock_skew_window
+        if return_future_key:
+            new_password = self.get_password_based_on_timestamp(
+                samdb, new_key_expiration_time, account_sid
+            )
+            old_password = self.get_password_based_on_timestamp(
+                samdb, new_key_start_time, account_sid
+            )
+        else:
+            new_password = self.get_password_based_on_timestamp(
+                samdb, new_key_start_time, account_sid
+            )
+
+            account_age = NtTimeDelta(
+                current_time - self.gmsa_creation_nt_time(gmsa_object)
+            )
+            if account_age >= gkdi_rollover_interval:
+                old_password = self.get_password_based_on_timestamp(
+                    samdb,
+                    NtTime(new_key_start_time - gkdi_rollover_interval),
+                    account_sid,
+                )
+            else:
+                # The account is not old enough to have a previous password.
+                old_password = None
+
+        key_expiration_time = NtTime(key_start_time + gkdi_rollover_interval)
+        key_is_expired = key_expiration_time <= current_time
+
+        query_expiration_time = NtTime(
+            new_key_expiration_time if key_is_expired else key_expiration_time
+        )
+        query_password_interval = NtTimeDelta(query_expiration_time - current_time)
+        unchanged_password_interval = NtTimeDelta(
+            max(
+                0,
+                query_expiration_time
+                + (gkdi_rollover_interval if return_future_key else 0)
+                - current_time
+                - MAX_CLOCK_SKEW,
+            )
+        )
+
+        return self.marshal_password(
+            new_password,
+            old_password,
+            query_password_interval,
+            unchanged_password_interval,
+        )
+
+    def marshal_password(
+        self,
+        current_password: bytes,
+        previous_password: Optional[bytes],
+        query_password_interval: NtTimeDelta,
+        unchanged_password_interval: NtTimeDelta,
+    ) -> gmsa.MANAGEDPASSWORD_BLOB:
+        managed_password = gmsa.MANAGEDPASSWORD_BLOB()
+
+        managed_password.passwords.current = current_password
+        managed_password.passwords.previous = previous_password
+        managed_password.passwords.query_interval = query_password_interval
+        managed_password.passwords.unchanged_interval = unchanged_password_interval
+
+        return managed_password
+
+    def gmsa_account(
+        self,
+        *,
+        samdb: Optional[SamDB] = None,
+        interval: int = 1,
+        msa_membership: Optional[str] = None,
+        **kwargs,
+    ) -> KerberosCredentials:
+        if msa_membership is None:
+            allow_world_sddl = "O:SYD:(A;;RP;;;WD)"
+            msa_membership = allow_world_sddl
+
+        msa_membership_sd = ndr_pack(
+            security.descriptor.from_sddl(msa_membership, security.dom_sid())
+        )
+
+        try:
+            creds = self.get_cached_creds(
+                samdb=samdb,
+                account_type=self.AccountType.GROUP_MANAGED_SERVICE,
+                opts={
+                    "additional_details": self.freeze(
+                        {
+                            "msDS-GroupMSAMembership": msa_membership_sd,
+                            "msDS-ManagedPasswordInterval": str(interval),
+                        }
+                    ),
+                    **kwargs,
+                },
+                # Ensure the gMSA is a brand‐new account.
+                use_cache=False,
+            )
+        except ldb.LdbError as err:
+            if err.args[0] == ldb.ERR_UNWILLING_TO_PERFORM:
+                self.fail(
+                    "If you’re running these tests against Windows, try “warming up”"
+                    " the GKDI service by running `samba.tests.krb5.gkdi_tests` first."
+                )
+
+            raise
+
+        # Derive the account’s current password. The account is too new to have a previous password yet.
+        managed_pwd = self.expected_current_gmsa_password_blob(
+            self.get_samdb() if samdb is None else samdb,
+            creds,
+            future_key_is_acceptable=False,
+        )
+
+        # Set the password.
+        self.assertIsNotNone(
+            managed_pwd.passwords.current, "current password must be present"
+        )
+        creds.set_utf16_password(managed_pwd.passwords.current)
+
+        return creds
+
+    def get_local_samdb(self) -> SamDB:
+        """Return a connection to the local database."""
+
+        lp = self.get_lp()
+        samdb = connect_samdb(
+            samdb_url=lp.samdb_url(), lp=lp, credentials=self.get_admin_creds()
+        )
+        self.assertLocalSamDB(samdb)
+
+        return samdb
+
+    # Perform a gensec logon using NTLMSSP. As samdb is passed in as a
+    # parameter, it can have a time set on it with set_db_time().
+    def gensec_ntlmssp_logon(
+        self, client_creds: Credentials, samdb: SamDB
+    ) -> "auth.session_info":
+        lp = self.get_lp()
+        lp.set("server role", "active directory domain controller")
+
+        settings = {"lp_ctx": lp, "target_hostname": lp.get("netbios name")}
+
+        gensec_client = gensec.Security.start_client(settings)
+        # Ensure that we don’t use Kerberos.
+        self.assertEqual(DONT_USE_KERBEROS, client_creds.get_kerberos_state())
+        gensec_client.set_credentials(client_creds)
+        gensec_client.want_feature(gensec.FEATURE_SEAL)
+        gensec_client.start_mech_by_name("ntlmssp")
+
+        auth_context = auth.AuthContext(lp_ctx=lp, ldb=samdb)
+
+        gensec_server = gensec.Security.start_server(settings, auth_context)
+        machine_creds = Credentials()
+        machine_creds.guess(lp)
+        machine_creds.set_machine_account(lp)
+        gensec_server.set_credentials(machine_creds)
+
+        gensec_server.start_mech_by_name("ntlmssp")
+
+        client_finished = False
+        server_finished = False
+        client_to_server = b""
+        server_to_client = b""
+
+        # Operate as both the client and the server to verify the user’s credentials.
+        while not client_finished or not server_finished:
+            if not client_finished:
+                client_finished, client_to_server = gensec_client.update(
+                    server_to_client
+                )
+            if not server_finished:
+                server_finished, server_to_client = gensec_server.update(
+                    client_to_server
+                )
+
+        # Retrieve the SIDs from the security token.
+        return gensec_server.session_info()
+
+    def check_nt_interval(
+        self,
+        expected_nt_interval: NtTimeDelta,
+        nt_interval: NtTimeDelta,
+        interval_name: str,
+    ) -> None:
+        """Check that the intervals match to within thirty seconds or so."""
+
+        threshold = datetime.timedelta(seconds=30)
+
+        interval = timedelta_from_nt_time_delta(nt_interval)
+        expected_interval = timedelta_from_nt_time_delta(expected_nt_interval)
+        interval_difference = abs(interval - expected_interval)
+        self.assertLess(
+            interval_difference,
+            threshold,
+            f"{interval_name} ({interval}) is out by {interval_difference} from"
+            f" expected ({expected_interval})",
+        )
+
+    def check_managed_pwd_intervals(
+        self,
+        expected_managed_pwd: gmsa.MANAGEDPASSWORD_BLOB,
+        managed_pwd: gmsa.MANAGEDPASSWORD_BLOB,
+    ) -> None:
+        expected_passwords = expected_managed_pwd.passwords
+        passwords = managed_pwd.passwords
+
+        self.check_nt_interval(
+            expected_passwords.query_interval,
+            passwords.query_interval,
+            "query interval",
+        )
+        self.check_nt_interval(
+            expected_passwords.unchanged_interval,
+            passwords.unchanged_interval,
+            "unchanged interval",
+        )
+
+    def check_managed_pwd(
+        self,
+        samdb: SamDB,
+        creds: KerberosCredentials,
+        *,
+        expected_managed_pwd: gmsa.MANAGEDPASSWORD_BLOB,
+    ) -> None:
+        res = samdb.search(
+            creds.get_dn(), scope=ldb.SCOPE_BASE, attrs=["msDS-ManagedPassword"]
+        )
+        self.assertEqual(1, len(res), "gMSA not found")
+        managed_password = res[0].get("msDS-ManagedPassword", idx=0)
+
+        self.assertIsNotNone(managed_password)
+        managed_pwd = ndr_unpack(gmsa.MANAGEDPASSWORD_BLOB, managed_password)
+
+        self.assertEqual(1, managed_pwd.version)
+        self.assertEqual(0, managed_pwd.reserved)
+        self.assertEqual(len(managed_password), managed_pwd.length)
+
+        self.assertIsNotNone(expected_managed_pwd.passwords.current)
+
+        self.assertEqual(
+            managed_pwd.passwords.current, expected_managed_pwd.passwords.current
+        )
+        self.assertEqual(
+            managed_pwd.passwords.previous, expected_managed_pwd.passwords.previous
+        )
+
+        self.check_managed_pwd_intervals(expected_managed_pwd, managed_pwd)
+
+    # When creating a gMSA, Windows seems to pick the root key with the
+    # greatest msKds-CreateTime having msKds-UseStartTime ≤ ten hours ago.
+    # Bear in mind that it seems also to cache the key, so it won’t always
+    # use the latest one.
+
+    def get_managed_service_accounts_dn(self) -> ldb.Dn:
+        samdb = self.get_samdb()
+        return samdb.get_wellknown_dn(
+            samdb.get_default_basedn(), dsdb.DS_GUID_MANAGED_SERVICE_ACCOUNTS_CONTAINER
+        )
+
+    def check_managed_password_access(
+        self, creds: Credentials, *, expect_access
+    ) -> None:
+        samdb = self.get_samdb()
+        managed_service_accounts_dn = self.get_managed_service_accounts_dn()
+        username = creds.get_username()
+
+        # Try base, subtree, and one‐level searches.
+        searches = (
+            (creds.get_dn(), ldb.SCOPE_BASE),
+            (managed_service_accounts_dn, ldb.SCOPE_SUBTREE),
+            (managed_service_accounts_dn, ldb.SCOPE_ONELEVEL),
+        )
+
+        for dn, scope in searches:
+            # Perform a search and see whether we’re allowed to view the managed password.
+
+            res = samdb.search(
+                dn,
+                scope=scope,
+                expression=f"sAMAccountName={username}",
+                attrs=["msDS-ManagedPassword"],
+            )
+            self.assertEqual(1, len(res), "should always find the gMSA")
+
+            managed_password = res[0].get("msDS-ManagedPassword", idx=0)
+            if expect_access:
+                self.assertIsNotNone(
+                    managed_password, "should be allowed to view the password"
+                )
+            else:
+                self.assertIsNone(
+                    managed_password, "should not be allowed to view the password"
+                )
+
+    def test_retrieved_password_allowed(self):
+        """Test being allowed to view the managed password."""
+        self.check_managed_password_access(self.gmsa_account(), expect_access=True)
+
+    def test_retrieved_password_denied(self):
+        """Test not being allowed to view the managed password."""
+        deny_world_sddl = "O:SYD:(D;;RP;;;WD)"
+        self.check_managed_password_access(
+            self.gmsa_account(msa_membership=deny_world_sddl), expect_access=False
+        )
+
+    def future_gkid(self) -> Gkid:
+        """Return (6333, 26, 5)—an arbitrary GKID far enough in the future that
+        it’s situated beyond any reasonable rollover period. But not so far in
+        the future that Python’s datetime library will throw OverflowErrors."""
+        future_date = datetime.datetime(9000, 1, 1, 0, 0, tzinfo=datetime.timezone.utc)
+        return Gkid.from_nt_time(nt_time_from_datetime(future_date))
+
+    def future_time(self) -> NtTime:
+        """Return an arbitrary time far enough in the future that it’s situated
+        beyond any reasonable rollover period. But not so far in the future that
+        Python’s datetime library will throw OverflowErrors."""
+        return self.future_gkid().start_nt_time()
+
+    def test_retrieved_password(self):
+        """Test that we can retrieve the correct password for a gMSA."""
+
+        samdb = self.get_samdb()
+        creds = self.gmsa_account()
+
+        expected = self.expected_current_gmsa_password_blob(
+            samdb,
+            creds,
+            future_key_is_acceptable=True,
+        )
+        self.check_managed_pwd(samdb, creds, expected_managed_pwd=expected)
+
+    def test_retrieved_password_when_current_key_is_valid(self):
+        """Test that we can retrieve the correct password for a gMSA at a time
+        when we are sure it is valid."""
+        password_interval = 37
+
+        samdb = self.get_local_samdb()
+        series = self.gmsa_series(password_interval)
+        self.set_db_time(samdb, series.start_of_interval(0))
+
+        creds = self.gmsa_account(samdb=samdb, interval=password_interval)
+
+        # Check the managed password of the account the moment it has been
+        # created.
+        expected = self.expected_gmsa_password_blob(
+            samdb,
+            creds,
+            series.interval_gkid(0),
+            previous_gkid=series.interval_gkid(-1),
+            query_expiration_gkid=series.interval_gkid(1),
+        )
+        self.check_managed_pwd(samdb, creds, expected_managed_pwd=expected)
+
+    def test_retrieved_password_when_current_key_is_expired(self):
+        """Test that we can retrieve the correct password for a gMSA when the
+        original password has expired."""
+        password_interval = 14
+
+        samdb = self.get_local_samdb()
+        series = self.gmsa_series(password_interval)
+        self.set_db_time(samdb, series.start_of_interval(0))
+
+        creds = self.gmsa_account(samdb=samdb, interval=password_interval)
+
+        # Set the time to the moment the original password has expired, and
+        # check that the managed password is correct.
+        expired_time = series.start_of_interval(1)
+        self.set_db_time(samdb, expired_time)
+        expected = self.expected_gmsa_password_blob(
+            samdb,
+            creds,
+            series.interval_gkid(1),
+            previous_gkid=series.interval_gkid(0),
+            query_expiration_gkid=series.interval_gkid(2),
+        )
+        self.check_managed_pwd(samdb, creds, expected_managed_pwd=expected)
+
+    def test_retrieved_password_when_next_key_is_expired(self):
+        password_interval = 1
+
+        samdb = self.get_local_samdb()
+        series = self.gmsa_series(password_interval)
+        self.set_db_time(samdb, series.start_of_interval(0))
+
+        creds = self.gmsa_account(samdb=samdb, interval=password_interval)
+
+        expired_time = series.start_of_interval(2)
+        self.set_db_time(samdb, expired_time)
+
+        expected = self.expected_gmsa_password_blob(
+            samdb,
+            creds,
+            series.interval_gkid(2),
+            previous_gkid=series.interval_gkid(1),
+            query_expiration_gkid=series.interval_gkid(3),
+        )
+        self.check_managed_pwd(samdb, creds, expected_managed_pwd=expected)
+
+    def test_retrieved_password_during_clock_skew_window_when_current_key_is_valid(
+        self,
+    ):
+        password_interval = 60
+
+        samdb = self.get_local_samdb()
+        series = self.gmsa_series(password_interval)
+        self.set_db_time(samdb, series.start_of_interval(0))
+
+        creds = self.gmsa_account(samdb=samdb, interval=password_interval)
+
+        self.set_db_time(samdb, series.during_skew_window(0))
+
+        expected = self.expected_gmsa_password_blob(
+            samdb,
+            creds,
+            series.interval_gkid(1),
+            previous_gkid=series.interval_gkid(0),
+            query_expiration_gkid=series.interval_gkid(1),
+            return_future_key=True,
+        )
+        self.check_managed_pwd(samdb, creds, expected_managed_pwd=expected)
+
+    def test_retrieved_password_during_clock_skew_window_when_current_key_is_expired(
+        self,
+    ):
+        password_interval = 100
+
+        samdb = self.get_local_samdb()
+        series = self.gmsa_series(password_interval)
+        self.set_db_time(samdb, series.start_of_interval(0))
+
+        creds = self.gmsa_account(samdb=samdb, interval=password_interval)
+
+        self.set_db_time(samdb, series.during_skew_window(1))
+
+        expected = self.expected_gmsa_password_blob(
+            samdb,
+            creds,
+            series.interval_gkid(2),
+            previous_gkid=series.interval_gkid(1),
+            query_expiration_gkid=series.interval_gkid(2),
+            return_future_key=True,
+        )
+        self.check_managed_pwd(samdb, creds, expected_managed_pwd=expected)
+
+    def test_retrieved_password_during_clock_skew_window_when_next_key_is_expired(
+        self,
+    ):
+        password_interval = 16
+
+        samdb = self.get_local_samdb()
+        series = self.gmsa_series(password_interval)
+        self.set_db_time(samdb, series.start_of_interval(0))
+
+        creds = self.gmsa_account(samdb=samdb, interval=password_interval)
+
+        self.set_db_time(samdb, series.during_skew_window(2))
+
+        expected = self.expected_gmsa_password_blob(
+            samdb,
+            creds,
+            series.interval_gkid(3),
+            previous_gkid=series.interval_gkid(2),
+            query_expiration_gkid=series.interval_gkid(3),
+            return_future_key=True,
+        )
+        self.check_managed_pwd(samdb, creds, expected_managed_pwd=expected)
+
+    def test_gmsa_can_perform_gensec_ntlmssp_logon(self):
+        creds = self.gmsa_account(kerberos_enabled=False)
+
+        # Perform a gensec logon.
+        session = self.gensec_ntlmssp_logon(creds, self.get_local_samdb())
+
+        # Ensure that the first SID contained within the security token is the gMSA’s SID.
+        token = session.security_token
+        token_sids = token.sids
+        self.assertGreater(len(token_sids), 0)
+
+        # Ensure that they match.
+        self.assertEqual(security.dom_sid(creds.get_sid()), token_sids[0])
+
+    def test_gmsa_can_perform_netlogon(self):
+        creds = self.gmsa_account(kerberos_enabled=False)
+        self._test_samlogon(
+            creds,
+            netlogon.NetlogonNetworkInformation,
+            validation_level=netlogon.NetlogonValidationSamInfo4,
+            domain_joined_mach_creds=creds,
+        )
+
+    def _gmsa_can_perform_as_req(self, *, enctype: kcrypto.Enctype) -> None:
+        self._as_req(self.gmsa_account(), self.get_service_creds(), enctype)
+
+    def test_gmsa_can_perform_as_req_with_aes256(self):
+        self._gmsa_can_perform_as_req(enctype=kcrypto.Enctype.AES256)
+
+    def test_gmsa_can_perform_as_req_with_rc4(self):
+        self._gmsa_can_perform_as_req(enctype=kcrypto.Enctype.RC4)
+
+    def _gmsa_can_authenticate_to_ldap(self, *, with_kerberos: bool) -> None:
+        creds = self.gmsa_account(kerberos_enabled=with_kerberos)
+
+        protocol = "ldap"
+
+        # Authenticate to LDAP.
+        samdb_user = SamDB(
+            url=f"{protocol}://{self.dc_host}", credentials=creds, lp=self.get_lp()
+        )
+
+        # Search for the user’s token groups.
+        res = samdb_user.search("", scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
+        self.assertEqual(1, len(res))
+
+        token_groups = res[0].get("tokenGroups", idx=0)
+        self.assertIsNotNone(token_groups)
+
+        # Ensure that the token SID matches.
+        token_sid = ndr_unpack(security.dom_sid, token_groups)
+        self.assertEqual(security.dom_sid(creds.get_sid()), token_sid)
+
+    def test_gmsa_can_authenticate_to_ldap_with_kerberos(self):
+        self._gmsa_can_authenticate_to_ldap(with_kerberos=True)
+
+    def test_gmsa_can_authenticate_to_ldap_without_kerberos(self):
+        self._gmsa_can_authenticate_to_ldap(with_kerberos=False)
+
+
+if __name__ == "__main__":
+    import unittest
+
+    unittest.main()
index 1d819e55b637dd4b98ecd7254bf22179df8f9f41..0eabead5a16c907ad7aa4226194a6b249dc8a782 100644 (file)
@@ -1,2 +1,17 @@
 ^samba.tests.dckeytab.samba.tests.dckeytab.DCKeytabTests.test_export_keytab_gmsa
 ^samba.tests.blackbox.gmsa.samba.tests.blackbox.gmsa.GMSABlackboxTest.test_gmsa_password_access
+^samba.tests.krb5.gmsa_tests.samba.tests.krb5.gmsa_tests.GmsaTests.test_gmsa_can_authenticate_to_ldap_with_kerberos\(ad_dc:local\)$
+^samba.tests.krb5.gmsa_tests.samba.tests.krb5.gmsa_tests.GmsaTests.test_gmsa_can_authenticate_to_ldap_without_kerberos\(ad_dc:local\)$
+^samba.tests.krb5.gmsa_tests.samba.tests.krb5.gmsa_tests.GmsaTests.test_gmsa_can_perform_as_req_with_aes256\(ad_dc:local\)$
+^samba.tests.krb5.gmsa_tests.samba.tests.krb5.gmsa_tests.GmsaTests.test_gmsa_can_perform_as_req_with_rc4\(ad_dc:local\)$
+^samba.tests.krb5.gmsa_tests.samba.tests.krb5.gmsa_tests.GmsaTests.test_gmsa_can_perform_gensec_logon\(ad_dc:local\)$
+^samba.tests.krb5.gmsa_tests.samba.tests.krb5.gmsa_tests.GmsaTests.test_gmsa_can_perform_netlogon\(ad_dc:local\)$
+^samba.tests.krb5.gmsa_tests.samba.tests.krb5.gmsa_tests.GmsaTests.test_retrieved_password\(ad_dc:local\)$
+^samba.tests.krb5.gmsa_tests.samba.tests.krb5.gmsa_tests.GmsaTests.test_retrieved_password_allowed\(ad_dc:local\)$
+^samba.tests.krb5.gmsa_tests.samba.tests.krb5.gmsa_tests.GmsaTests.test_retrieved_password_denied\(ad_dc:local\)$
+^samba.tests.krb5.gmsa_tests.samba.tests.krb5.gmsa_tests.GmsaTests.test_retrieved_password_during_clock_skew_window_when_current_key_is_expired\(ad_dc:local\)$
+^samba.tests.krb5.gmsa_tests.samba.tests.krb5.gmsa_tests.GmsaTests.test_retrieved_password_during_clock_skew_window_when_current_key_is_valid\(ad_dc:local\)$
+^samba.tests.krb5.gmsa_tests.samba.tests.krb5.gmsa_tests.GmsaTests.test_retrieved_password_during_clock_skew_window_when_next_key_is_expired\(ad_dc:local\)$
+^samba.tests.krb5.gmsa_tests.samba.tests.krb5.gmsa_tests.GmsaTests.test_retrieved_password_when_current_key_is_expired\(ad_dc:local\)$
+^samba.tests.krb5.gmsa_tests.samba.tests.krb5.gmsa_tests.GmsaTests.test_retrieved_password_when_current_key_is_valid\(ad_dc:local\)$
+^samba.tests.krb5.gmsa_tests.samba.tests.krb5.gmsa_tests.GmsaTests.test_retrieved_password_when_next_key_is_expired\(ad_dc:local\)$
index b3a2ab3f94349d10b628b2de5c8cab4e66f0cb25..e30836f1d124b87c7ae4c0151cb3d56d0bcda449 100644 (file)
 ^samba\.tests\.krb5\.kdc_tgs_tests\.samba\.tests\.krb5\.kdc_tgs_tests\.KdcTgsTests\.test_single_component_krbtgt_no_pac_as_req\(ad_dc\)$
 ^samba\.tests\.krb5\.kdc_tgs_tests\.samba\.tests\.krb5\.kdc_tgs_tests\.KdcTgsTests\.test_single_component_krbtgt_no_pac_tgs_req\(ad_dc\)$
 ^samba\.tests\.krb5\.kdc_tgs_tests\.samba\.tests\.krb5\.kdc_tgs_tests\.KdcTgsTests\.test_single_component_krbtgt_service_ticket\(ad_dc\)$
+#
+# gMSA tests
+#
+^samba.tests.krb5.gmsa_tests.samba.tests.krb5.gmsa_tests.GmsaTests.test_gmsa_can_perform_as_req_with_aes256\(ad_dc:local\)$
+^samba.tests.krb5.gmsa_tests.samba.tests.krb5.gmsa_tests.GmsaTests.test_gmsa_can_perform_as_req_with_rc4\(ad_dc:local\)$
index e3eccfbcda6ed089ff1ca56cc0226d80a49dfd4b..7d971090199ded4b85cb1bdb2bd8d45a22bc8cc8 100755 (executable)
@@ -2060,6 +2060,10 @@ planoldpythontestsuite(
     'ad_dc',
     'samba.tests.krb5.gkdi_tests',
     environ=krb5_environ)
+planoldpythontestsuite(
+    'ad_dc:local',
+    'samba.tests.krb5.gmsa_tests',
+    environ=krb5_environ)
 
 for env in [
         'vampire_dc',