traffic_replay: Re-organize assignments to be group-based
author Tim Beale <timbeale@catalyst.net.nz>
Wed, 31 Oct 2018 03:50:27 +0000 (16:50 +1300)
committer Tim Beale <timbeale@samba.org>
Sun, 4 Nov 2018 22:55:16 +0000 (23:55 +0100)
We can speed up writing the group memberships by adding multiple users
to a group in a single DB modify operation.

To do this, we first need to reorganize the assignments: instead of
a set of (user, group) tuples, they are now stored as a dictionary
where key=group and value=list-of-users-in-group.

add_users_to_groups() now iterates through the users/groups slightly
differently, but mostly it's just indentation changes. We haven't
changed the number of DB operations yet - we'll do that in the next
patch.

Signed-off-by: Tim Beale <timbeale@catalyst.net.nz>
Reviewed-by: Douglas Bagnall <douglas.bagnall@catalyst.net.nz>
python/samba/emulate/traffic.py

index aabf6ed0a4225bd4e2581b9f8c404a444700ff62..0087b03a3798df9a8eda589515a8d8ace58aae69 100644 (file)
@@ -1800,7 +1800,7 @@ def generate_users_and_groups(ldb, instance_id, password,
                                        users_added,
                                        group_memberships)
         print("Adding users to groups", file=sys.stderr)
-        add_users_to_groups(ldb, instance_id, assignments.assignments)
+        add_users_to_groups(ldb, instance_id, assignments)
         memberships_added = assignments.total()
 
     if (groups_added > 0 and users_added == 0 and
@@ -1817,6 +1817,7 @@ class GroupAssignments(object):
     def __init__(self, number_of_groups, groups_added, number_of_users,
                  users_added, group_memberships):
 
+        self.count = 0
         self.generate_group_distribution(number_of_groups)
         self.generate_user_distribution(number_of_users, group_memberships)
         self.assignments = self.assign_groups(number_of_groups,
@@ -1890,6 +1891,12 @@ class GroupAssignments(object):
 
         return user, group
 
+    def users_in_group(self, group):
+        return self.assignments[group]
+
+    def get_groups(self):
+        return self.assignments.keys()
+
     def assign_groups(self, number_of_groups, groups_added,
                       number_of_users, users_added, group_memberships):
         """Allocate users to groups.
@@ -1903,7 +1910,7 @@ class GroupAssignments(object):
 
         assignments = set()
         if group_memberships <= 0:
-            return assignments
+            return {}
 
         # Calculate the number of group menberships required
         group_memberships = math.ceil(
@@ -1920,35 +1927,41 @@ class GroupAssignments(object):
                 # group or user number
                 assignments.add(((user + 1), (group + 1)))
 
-        return assignments
+        # convert the set into a dictionary, where key=group, value=list-of-
+        # users-in-group (indexing by group-ID allows us to optimize for
+        # DB membership writes)
+        assignment_dict = defaultdict(list)
+        for (user, group) in assignments:
+            assignment_dict[group].append(user)
+            self.count += 1
+
+        return assignment_dict
 
     def total(self):
-        return len(self.assignments)
+        return self.count
 
 
 def add_users_to_groups(db, instance_id, assignments):
-    """Add users to their assigned groups.
-
-    Takes the list of (group,user) tuples generated by assign_groups and
-    assign the users to their specified groups."""
+    """Takes the assignments of users to groups and applies them to the DB."""
 
     ou = ou_name(db, instance_id)
 
     def build_dn(name):
         return("cn=%s,%s" % (name, ou))
 
-    for (user, group) in assignments:
-        user_dn  = build_dn(user_name(instance_id, user))
-        group_dn = build_dn(group_name(instance_id, group))
+    for group in assignments.get_groups():
+        for user in assignments.users_in_group(group):
+            user_dn  = build_dn(user_name(instance_id, user))
+            group_dn = build_dn(group_name(instance_id, group))
 
-        m = ldb.Message()
-        m.dn = ldb.Dn(db, group_dn)
-        m["member"] = ldb.MessageElement(user_dn, ldb.FLAG_MOD_ADD, "member")
-        start = time.time()
-        db.modify(m)
-        end = time.time()
-        duration = end - start
-        LOGGER.info("%f\t0\tadd\tuser\t%f\tTrue\t" % (end, duration))
+            m = ldb.Message()
+            m.dn = ldb.Dn(db, group_dn)
+            m["member"] = ldb.MessageElement(user_dn, ldb.FLAG_MOD_ADD, "member")
+            start = time.time()
+            db.modify(m)
+            end = time.time()
+            duration = end - start
+            LOGGER.info("%f\t0\tadd\tuser\t%f\tTrue\t" % (end, duration))
 
 
 def generate_stats(statsdir, timing_file):