traffic: fix userAccountControl for machine account
[metze/samba/wip.git] / python / samba / emulate / traffic.py
index ea0529c4cb49c622c393cc25742aae697a7c1ebe..84a9a6ab067045a773b8ddea4e66b6f281923d5f 100644 (file)
@@ -42,10 +42,14 @@ from samba.drs_utils import drs_DsBind
 import traceback
 from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
 from samba.auth import system_session
-from samba.dsdb import UF_WORKSTATION_TRUST_ACCOUNT, UF_PASSWD_NOTREQD
-from samba.dsdb import UF_NORMAL_ACCOUNT
-from samba.dcerpc.misc import SEC_CHAN_WKSTA
+from samba.dsdb import (
+    UF_NORMAL_ACCOUNT,
+    UF_SERVER_TRUST_ACCOUNT,
+    UF_TRUSTED_FOR_DELEGATION
+)
+from samba.dcerpc.misc import SEC_CHAN_BDC
 from samba import gensec
+from samba import sd_utils
 
 SLEEP_OVERHEAD = 3e-4
 
@@ -272,13 +276,12 @@ class Packet(object):
             return False
 
         fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
-        try:
-            fn = getattr(traffic_packets, fn_name)
-            if fn is traffic_packets.null_packet:
-                return False
-        except AttributeError:
+        fn = getattr(traffic_packets, fn_name, None)
+        if not fn:
             print("missing packet %s" % fn_name, file=sys.stderr)
             return False
+        if fn is traffic_packets.null_packet:
+            return False
         return True
 
 
@@ -456,6 +459,7 @@ class ReplayContext(object):
         self.user_creds.set_workstation(self.netbios_name)
         self.user_creds.set_password(self.userpass)
         self.user_creds.set_username(self.username)
+        self.user_creds.set_domain(self.domain)
         if self.prefer_kerberos:
             self.user_creds.set_kerberos_state(MUST_USE_KERBEROS)
         else:
@@ -510,9 +514,10 @@ class ReplayContext(object):
         self.machine_creds = Credentials()
         self.machine_creds.guess(self.lp)
         self.machine_creds.set_workstation(self.netbios_name)
-        self.machine_creds.set_secure_channel_type(SEC_CHAN_WKSTA)
+        self.machine_creds.set_secure_channel_type(SEC_CHAN_BDC)
         self.machine_creds.set_password(self.machinepass)
         self.machine_creds.set_username(self.netbios_name + "$")
+        self.machine_creds.set_domain(self.domain)
         if self.prefer_kerberos:
             self.machine_creds.set_kerberos_state(MUST_USE_KERBEROS)
         else:
@@ -521,7 +526,7 @@ class ReplayContext(object):
         self.machine_creds_bad = Credentials()
         self.machine_creds_bad.guess(self.lp)
         self.machine_creds_bad.set_workstation(self.netbios_name)
-        self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_WKSTA)
+        self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_BDC)
         self.machine_creds_bad.set_password(self.machinepass[:-4])
         self.machine_creds_bad.set_username(self.netbios_name + "$")
         if self.prefer_kerberos:
@@ -644,6 +649,15 @@ class ReplayContext(object):
             return self.ldap_connections[-1]
 
         def simple_bind(creds):
+            """
+            To run simple bind against Windows, we need to run
+            following commands in PowerShell:
+
+                Install-windowsfeature ADCS-Cert-Authority
+                Install-AdcsCertificationAuthority -CAType EnterpriseRootCA
+                Restart-Computer
+
+            """
             return SamDB('ldaps://%s' % self.server,
                          credentials=creds,
                          lp=self.lp)
@@ -670,7 +684,8 @@ class ReplayContext(object):
 
     def get_samr_context(self, new=False):
         if not self.samr_contexts or new:
-            self.samr_contexts.append(SamrContext(self.server))
+            self.samr_contexts.append(
+                SamrContext(self.server, lp=self.lp, creds=self.creds))
         return self.samr_contexts[-1]
 
     def get_netlogon_connection(self):
@@ -707,7 +722,7 @@ class ReplayContext(object):
 class SamrContext(object):
     """State/Context associated with a samr connection.
     """
-    def __init__(self, server):
+    def __init__(self, server, lp=None, creds=None):
         self.connection    = None
         self.handle        = None
         self.domain_handle = None
@@ -716,10 +731,16 @@ class SamrContext(object):
         self.user_handle   = None
         self.rids          = None
         self.server        = server
+        self.lp            = lp
+        self.creds         = creds
 
     def get_connection(self):
         if not self.connection:
-            self.connection = samr.samr("ncacn_ip_tcp:%s" % (self.server))
+            self.connection = samr.samr(
+                "ncacn_ip_tcp:%s[seal]" % (self.server),
+                lp_ctx=self.lp,
+                credentials=self.creds)
+
         return self.connection
 
     def get_handle(self):
@@ -775,12 +796,12 @@ class Conversation(object):
         if p.is_really_a_packet():
             self.packets.append(p)
 
-    def add_short_packet(self, timestamp, p, extra, client=True):
+    def add_short_packet(self, timestamp, protocol, opcode, extra,
+                         client=True):
         """Create a packet from a timestamp, and 'protocol:opcode' pair, and a
         (possibly empty) list of extra data. If client is True, assume
         this packet is from the client to the server.
         """
-        protocol, opcode = p.split(':', 1)
         src, dest = self.guess_client_server()
         if not client:
             src, dest = dest, src
@@ -929,18 +950,8 @@ class Conversation(object):
         :param s: start of the window
         :param e: end of the window
         """
-
-        new_packets = []
-        for p in self.packets:
-            if p.timestamp < s or p.timestamp > e:
-                continue
-            new_packets.append(p)
-
-        self.packets = new_packets
-        if new_packets:
-            self.start_time = new_packets[0].timestamp
-        else:
-            self.start_time = None
+        self.packets = [p for p in self.packets if s <= p.timestamp <= e]
+        self.start_time = self.packets[0].timestamp if self.packets else None
 
     def renormalise_times(self, start_time):
         """Adjust the packet start times relative to the new start time."""
@@ -1211,7 +1222,7 @@ class TrafficModel(object):
                 timestamp += wait
                 if hard_stop is not None and timestamp > hard_stop:
                     break
-                c.add_short_packet(timestamp, p, extra)
+                c.add_short_packet(timestamp, protocol, opcode, extra)
 
             key = key[1:] + (p,)
 
@@ -1649,7 +1660,7 @@ def create_machine_account(ldb, instance_id, netbios_name, machinepass):
         "objectclass": "computer",
         "sAMAccountName": "%s$" % netbios_name,
         "userAccountControl":
-        str(UF_WORKSTATION_TRUST_ACCOUNT | UF_PASSWD_NOTREQD),
+            str(UF_TRUSTED_FOR_DELEGATION | UF_SERVER_TRUST_ACCOUNT),
         "unicodePwd": utf16pw})
     end = time.time()
     duration = end - start
@@ -1671,6 +1682,11 @@ def create_user_account(ldb, instance_id, username, userpass):
         "userAccountControl": str(UF_NORMAL_ACCOUNT),
         "unicodePwd": utf16pw
     })
+
+    # grant user write permission to do things like write account SPN
+    sdutils = sd_utils.SDUtils(ldb)
+    sdutils.dacl_add_ace(user_dn,  "(A;;WP;;;PS)")
+
     end = time.time()
     duration = end - start
     print("%f\t0\tcreate\tuser\t%f\tTrue\t" % (end, duration))