traffic: Rework how assignments are generated slightly
[samba.git] / python / samba / emulate / traffic.py
index af05163f145d23cf592c99f005ae8923d60e53a8..fd886e3865e6a5b392db4e3ad5e1e37aeea07a9b 100644 (file)
@@ -1603,11 +1603,10 @@ class ConversationAccounts(object):
 def generate_replay_accounts(ldb, instance_id, number, password):
     """Generate a series of unique machine and user account names."""
 
-    generate_traffic_accounts(ldb, instance_id, number, password)
     accounts = []
     for i in range(1, number + 1):
-        netbios_name = "STGM-%d-%d" % (instance_id, i)
-        username     = "STGU-%d-%d" % (instance_id, i)
+        netbios_name = machine_name(instance_id, i)
+        username = user_name(instance_id, i)
 
         account = ConversationAccounts(netbios_name, password, username,
                                        password)
@@ -1615,54 +1614,6 @@ def generate_replay_accounts(ldb, instance_id, number, password):
     return accounts
 
 
-def generate_traffic_accounts(ldb, instance_id, number, password):
-    """Create the specified number of user and machine accounts.
-
-    As accounts are not explicitly deleted between runs. This function starts
-    with the last account and iterates backwards stopping either when it
-    finds an already existing account or it has generated all the required
-    accounts.
-    """
-    print(("Generating machine and conversation accounts, "
-           "as required for %d conversations" % number),
-          file=sys.stderr)
-    added = 0
-    for i in range(number, 0, -1):
-        try:
-            netbios_name = "STGM-%d-%d" % (instance_id, i)
-            create_machine_account(ldb, instance_id, netbios_name, password)
-            added += 1
-            if added % 50 == 0:
-                LOGGER.info("Created %u/%u machine accounts" % (added, number))
-        except LdbError as e:
-            (status, _) = e.args
-            if status == 68:
-                break
-            else:
-                raise
-    if added > 0:
-        LOGGER.info("Added %d new machine accounts" % added)
-
-    added = 0
-    for i in range(number, 0, -1):
-        try:
-            username = "STGU-%d-%d" % (instance_id, i)
-            create_user_account(ldb, instance_id, username, password)
-            added += 1
-            if added % 50 == 0:
-                LOGGER.info("Created %u/%u users" % (added, number))
-
-        except LdbError as e:
-            (status, _) = e.args
-            if status == 68:
-                break
-            else:
-                raise
-
-    if added > 0:
-        LOGGER.info("Added %d new user accounts" % added)
-
-
 def create_machine_account(ldb, instance_id, netbios_name, machinepass,
                            traffic_account=True):
     """Create a machine account via ldap."""
@@ -1747,16 +1698,31 @@ def generate_users(ldb, instance_id, number, password):
     return users
 
 
-def generate_machine_accounts(ldb, instance_id, number, password):
+def machine_name(instance_id, i, traffic_account=True):
+    """Generate a machine account name from instance id."""
+    if traffic_account:
+        # traffic accounts correspond to a given user, and use different
+        # userAccountControl flags to ensure packets get processed correctly
+        # by the DC
+        return "STGM-%d-%d" % (instance_id, i)
+    else:
+        # Otherwise we're just generating computer accounts to simulate a
+        # semi-realistic network. These use the default computer
+        # userAccountControl flags, so we use a different account name so that
+        # we don't try to use them when generating packets
+        return "PC-%d-%d" % (instance_id, i)
+
+
+def generate_machine_accounts(ldb, instance_id, number, password,
+                              traffic_account=True):
     """Add machine accounts to the server"""
     existing_objects = search_objectclass(ldb, objectclass='computer')
     added = 0
     for i in range(number, 0, -1):
-        name = "STGM-%d-%d$" % (instance_id, i)
-        if name not in existing_objects:
-            name = "STGM-%d-%d" % (instance_id, i)
+        name = machine_name(instance_id, i, traffic_account)
+        if name + "$" not in existing_objects:
             create_machine_account(ldb, instance_id, name, password,
-                                   traffic_account=False)
+                                   traffic_account)
             added += 1
             if added % 50 == 0:
                 LOGGER.info("Created %u/%u machine accounts" % (added, number))
@@ -1798,22 +1764,23 @@ def clean_up_accounts(ldb, instance_id):
 
 def generate_users_and_groups(ldb, instance_id, password,
                               number_of_users, number_of_groups,
-                              group_memberships):
+                              group_memberships, machine_accounts,
+                              traffic_accounts=True):
     """Generate the required users and groups, allocating the users to
        those groups."""
     memberships_added = 0
-    groups_added  = 0
+    groups_added = 0
+    computers_added = 0
 
     create_ou(ldb, instance_id)
 
     LOGGER.info("Generating dummy user accounts")
     users_added = generate_users(ldb, instance_id, number_of_users, password)
 
-    # assume there will be some overhang with more computer accounts than users
-    computer_accounts = int(1.25 * number_of_users)
     LOGGER.info("Generating dummy machine accounts")
     computers_added = generate_machine_accounts(ldb, instance_id,
-                                                computer_accounts, password)
+                                                machine_accounts, password,
+                                                traffic_accounts)
 
     if number_of_groups > 0:
         LOGGER.info("Generating dummy groups")
@@ -1846,11 +1813,9 @@ class GroupAssignments(object):
         self.count = 0
         self.generate_group_distribution(number_of_groups)
         self.generate_user_distribution(number_of_users, group_memberships)
-        self.assignments = self.assign_groups(number_of_groups,
-                                              groups_added,
-                                              number_of_users,
-                                              users_added,
-                                              group_memberships)
+        self.assignments = defaultdict(list)
+        self.assign_groups(number_of_groups, groups_added, number_of_users,
+                           users_added, group_memberships)
 
     def cumulative_distribution(self, weights):
         # make sure the probabilities conform to a cumulative distribution
@@ -1923,6 +1888,14 @@ class GroupAssignments(object):
     def get_groups(self):
         return self.assignments.keys()
 
+    def add_assignment(self, user, group):
+        # the assignments are stored in a dictionary where key=group,
+        # value=list-of-users-in-group (indexing by group-ID allows us to
+        # optimize for DB membership writes)
+        if user not in self.assignments[group]:
+            self.assignments[group].append(user)
+            self.count += 1
+
     def assign_groups(self, number_of_groups, groups_added,
                       number_of_users, users_added, group_memberships):
         """Allocate users to groups.
@@ -1934,9 +1907,8 @@ class GroupAssignments(object):
         few users.
         """
 
-        assignments = set()
         if group_memberships <= 0:
-            return {}
+            return
 
         # Calculate the number of group menberships required
         group_memberships = math.ceil(
@@ -1945,23 +1917,13 @@ class GroupAssignments(object):
 
         existing_users  = number_of_users  - users_added  - 1
         existing_groups = number_of_groups - groups_added - 1
-        while len(assignments) < group_memberships:
+        while self.total() < group_memberships:
             user, group = self.generate_random_membership()
 
             if group > existing_groups or user > existing_users:
                 # the + 1 converts the array index to the corresponding
                 # group or user number
-                assignments.add(((user + 1), (group + 1)))
-
-        # convert the set into a dictionary, where key=group, value=list-of-
-        # users-in-group (indexing by group-ID allows us to optimize for
-        # DB membership writes)
-        assignment_dict = defaultdict(list)
-        for (user, group) in assignments:
-            assignment_dict[group].append(user)
-            self.count += 1
-
-        return assignment_dict
+                self.add_assignment(user + 1, group + 1)
 
     def total(self):
         return self.count