traffic_replay: Add a max-members option to cap group size
[samba.git] / python / samba / emulate / traffic.py
index af99e66823af4aedc669dd91f7534840a397e30c..291162f279ac066757f1d4459df360759604ef9c 100644 (file)
@@ -1603,11 +1603,10 @@ class ConversationAccounts(object):
 def generate_replay_accounts(ldb, instance_id, number, password):
     """Generate a series of unique machine and user account names."""
 
-    generate_traffic_accounts(ldb, instance_id, number, password)
     accounts = []
     for i in range(1, number + 1):
-        netbios_name = "STGM-%d-%d" % (instance_id, i)
-        username     = "STGU-%d-%d" % (instance_id, i)
+        netbios_name = machine_name(instance_id, i)
+        username = user_name(instance_id, i)
 
         account = ConversationAccounts(netbios_name, password, username,
                                        password)
@@ -1615,54 +1614,6 @@ def generate_replay_accounts(ldb, instance_id, number, password):
     return accounts
 
 
-def generate_traffic_accounts(ldb, instance_id, number, password):
-    """Create the specified number of user and machine accounts.
-
-    As accounts are not explicitly deleted between runs. This function starts
-    with the last account and iterates backwards stopping either when it
-    finds an already existing account or it has generated all the required
-    accounts.
-    """
-    print(("Generating machine and conversation accounts, "
-           "as required for %d conversations" % number),
-          file=sys.stderr)
-    added = 0
-    for i in range(number, 0, -1):
-        try:
-            netbios_name = "STGM-%d-%d" % (instance_id, i)
-            create_machine_account(ldb, instance_id, netbios_name, password)
-            added += 1
-            if added % 50 == 0:
-                LOGGER.info("Created %u/%u machine accounts" % (added, number))
-        except LdbError as e:
-            (status, _) = e.args
-            if status == 68:
-                break
-            else:
-                raise
-    if added > 0:
-        LOGGER.info("Added %d new machine accounts" % added)
-
-    added = 0
-    for i in range(number, 0, -1):
-        try:
-            username = "STGU-%d-%d" % (instance_id, i)
-            create_user_account(ldb, instance_id, username, password)
-            added += 1
-            if added % 50 == 0:
-                LOGGER.info("Created %u/%u users" % (added, number))
-
-        except LdbError as e:
-            (status, _) = e.args
-            if status == 68:
-                break
-            else:
-                raise
-
-    if added > 0:
-        LOGGER.info("Added %d new user accounts" % added)
-
-
 def create_machine_account(ldb, instance_id, netbios_name, machinepass,
                            traffic_account=True):
     """Create a machine account via ldap."""
@@ -1747,15 +1698,29 @@ def generate_users(ldb, instance_id, number, password):
     return users
 
 
+def machine_name(instance_id, i, traffic_account=True):
+    """Generate a machine account name from instance id."""
+    if traffic_account:
+        # traffic accounts correspond to a given user, and use different
+        # userAccountControl flags to ensure packets get processed correctly
+        # by the DC
+        return "STGM-%d-%d" % (instance_id, i)
+    else:
+        # Otherwise we're just generating computer accounts to simulate a
+        # semi-realistic network. These use the default computer
+        # userAccountControl flags, so we use a different account name so that
+        # we don't try to use them when generating packets
+        return "PC-%d-%d" % (instance_id, i)
+
+
 def generate_machine_accounts(ldb, instance_id, number, password,
                               traffic_account=True):
     """Add machine accounts to the server"""
     existing_objects = search_objectclass(ldb, objectclass='computer')
     added = 0
     for i in range(number, 0, -1):
-        name = "STGM-%d-%d$" % (instance_id, i)
-        if name not in existing_objects:
-            name = "STGM-%d-%d" % (instance_id, i)
+        name = machine_name(instance_id, i, traffic_account)
+        if name + "$" not in existing_objects:
             create_machine_account(ldb, instance_id, name, password,
                                    traffic_account)
             added += 1
@@ -1799,8 +1764,8 @@ def clean_up_accounts(ldb, instance_id):
 
 def generate_users_and_groups(ldb, instance_id, password,
                               number_of_users, number_of_groups,
-                              group_memberships, machine_accounts=0,
-                              traffic_accounts=True):
+                              group_memberships, max_members,
+                              machine_accounts, traffic_accounts=True):
     """Generate the required users and groups, allocating the users to
        those groups."""
     memberships_added = 0
@@ -1812,11 +1777,10 @@ def generate_users_and_groups(ldb, instance_id, password,
     LOGGER.info("Generating dummy user accounts")
     users_added = generate_users(ldb, instance_id, number_of_users, password)
 
-    if machine_accounts > 0:
-        LOGGER.info("Generating dummy machine accounts")
-        computers_added = generate_machine_accounts(ldb, instance_id,
-                                                    machine_accounts, password,
-                                                    traffic_accounts)
+    LOGGER.info("Generating dummy machine accounts")
+    computers_added = generate_machine_accounts(ldb, instance_id,
+                                                machine_accounts, password,
+                                                traffic_accounts)
 
     if number_of_groups > 0:
         LOGGER.info("Generating dummy groups")
@@ -1828,7 +1792,8 @@ def generate_users_and_groups(ldb, instance_id, password,
                                        groups_added,
                                        number_of_users,
                                        users_added,
-                                       group_memberships)
+                                       group_memberships,
+                                       max_members)
         LOGGER.info("Adding users to groups")
         add_users_to_groups(ldb, instance_id, assignments)
         memberships_added = assignments.total()
@@ -1844,16 +1809,15 @@ def generate_users_and_groups(ldb, instance_id, password,
 
 class GroupAssignments(object):
     def __init__(self, number_of_groups, groups_added, number_of_users,
-                 users_added, group_memberships):
+                 users_added, group_memberships, max_members):
 
         self.count = 0
         self.generate_group_distribution(number_of_groups)
         self.generate_user_distribution(number_of_users, group_memberships)
-        self.assignments = self.assign_groups(number_of_groups,
-                                              groups_added,
-                                              number_of_users,
-                                              users_added,
-                                              group_memberships)
+        self.max_members = max_members
+        self.assignments = defaultdict(list)
+        self.assign_groups(number_of_groups, groups_added, number_of_users,
+                           users_added, group_memberships)
 
     def cumulative_distribution(self, weights):
         # make sure the probabilities conform to a cumulative distribution
@@ -1863,6 +1827,9 @@ class GroupAssignments(object):
         # value, so we can use random.random() as a simple index into the list
         dist = []
         total = sum(weights)
+        if total == 0:
+            return None
+
         cumulative = 0.0
         for probability in weights:
             cumulative += probability
@@ -1906,6 +1873,7 @@ class GroupAssignments(object):
             weights.append(p)
 
         # convert the weights to a cumulative distribution between 0.0 and 1.0
+        self.group_weights = weights
         self.group_dist = self.cumulative_distribution(weights)
 
     def generate_random_membership(self):
@@ -1926,6 +1894,30 @@ class GroupAssignments(object):
     def get_groups(self):
         return self.assignments.keys()
 
+    def cap_group_membership(self, group, max_members):
+        """Prevent the group's membership from exceeding the max specified"""
+        num_members = len(self.assignments[group])
+        if num_members >= max_members:
+            LOGGER.info("Group {0} has {1} members".format(group, num_members))
+
+            # remove this group and then recalculate the cumulative
+            # distribution, so this group is no longer selected
+            self.group_weights[group - 1] = 0
+            new_dist = self.cumulative_distribution(self.group_weights)
+            self.group_dist = new_dist
+
+    def add_assignment(self, user, group):
+        # the assignments are stored in a dictionary where key=group,
+        # value=list-of-users-in-group (indexing by group-ID allows us to
+        # optimize for DB membership writes)
+        if user not in self.assignments[group]:
+            self.assignments[group].append(user)
+            self.count += 1
+
+        # check if there's a cap on how big the groups can grow
+        if self.max_members:
+            self.cap_group_membership(group, self.max_members)
+
     def assign_groups(self, number_of_groups, groups_added,
                       number_of_users, users_added, group_memberships):
         """Allocate users to groups.
@@ -1937,34 +1929,27 @@ class GroupAssignments(object):
         few users.
         """
 
-        assignments = set()
         if group_memberships <= 0:
-            return {}
+            return
 
         # Calculate the number of group menberships required
         group_memberships = math.ceil(
             float(group_memberships) *
             (float(users_added) / float(number_of_users)))
 
+        if self.max_members:
+            group_memberships = min(group_memberships,
+                                    self.max_members * number_of_groups)
+
         existing_users  = number_of_users  - users_added  - 1
         existing_groups = number_of_groups - groups_added - 1
-        while len(assignments) < group_memberships:
+        while self.total() < group_memberships:
             user, group = self.generate_random_membership()
 
             if group > existing_groups or user > existing_users:
                 # the + 1 converts the array index to the corresponding
                 # group or user number
-                assignments.add(((user + 1), (group + 1)))
-
-        # convert the set into a dictionary, where key=group, value=list-of-
-        # users-in-group (indexing by group-ID allows us to optimize for
-        # DB membership writes)
-        assignment_dict = defaultdict(list)
-        for (user, group) in assignments:
-            assignment_dict[group].append(user)
-            self.count += 1
-
-        return assignment_dict
+                self.add_assignment(user + 1, group + 1)
 
     def total(self):
         return self.count