traffic_replay: Add a max-members option to cap group size
[samba.git] / python / samba / emulate / traffic.py
index 2688f3487887afa3ed2e7dcb237b21d9755a97b7..291162f279ac066757f1d4459df360759604ef9c 100644 (file)
@@ -45,13 +45,15 @@ from samba.auth import system_session
 from samba.dsdb import (
     UF_NORMAL_ACCOUNT,
     UF_SERVER_TRUST_ACCOUNT,
-    UF_TRUSTED_FOR_DELEGATION
+    UF_TRUSTED_FOR_DELEGATION,
+    UF_WORKSTATION_TRUST_ACCOUNT
 )
 from samba.dcerpc.misc import SEC_CHAN_BDC
 from samba import gensec
 from samba import sd_utils
 from samba.compat import get_string
 from samba.logger import get_samba_logger
+import bisect
 
 SLEEP_OVERHEAD = 3e-4
 
@@ -1601,11 +1603,10 @@ class ConversationAccounts(object):
 def generate_replay_accounts(ldb, instance_id, number, password):
     """Generate a series of unique machine and user account names."""
 
-    generate_traffic_accounts(ldb, instance_id, number, password)
     accounts = []
     for i in range(1, number + 1):
-        netbios_name = "STGM-%d-%d" % (instance_id, i)
-        username     = "STGU-%d-%d" % (instance_id, i)
+        netbios_name = machine_name(instance_id, i)
+        username = user_name(instance_id, i)
 
         account = ConversationAccounts(netbios_name, password, username,
                                        password)
@@ -1613,69 +1614,29 @@ def generate_replay_accounts(ldb, instance_id, number, password):
     return accounts
 
 
-def generate_traffic_accounts(ldb, instance_id, number, password):
-    """Create the specified number of user and machine accounts.
-
-    As accounts are not explicitly deleted between runs. This function starts
-    with the last account and iterates backwards stopping either when it
-    finds an already existing account or it has generated all the required
-    accounts.
-    """
-    print(("Generating machine and conversation accounts, "
-           "as required for %d conversations" % number),
-          file=sys.stderr)
-    added = 0
-    for i in range(number, 0, -1):
-        try:
-            netbios_name = "STGM-%d-%d" % (instance_id, i)
-            create_machine_account(ldb, instance_id, netbios_name, password)
-            added += 1
-        except LdbError as e:
-            (status, _) = e.args
-            if status == 68:
-                break
-            else:
-                raise
-    if added > 0:
-        print("Added %d new machine accounts" % added,
-              file=sys.stderr)
-
-    added = 0
-    for i in range(number, 0, -1):
-        try:
-            username = "STGU-%d-%d" % (instance_id, i)
-            create_user_account(ldb, instance_id, username, password)
-            added += 1
-        except LdbError as e:
-            (status, _) = e.args
-            if status == 68:
-                break
-            else:
-                raise
-
-    if added > 0:
-        print("Added %d new user accounts" % added,
-              file=sys.stderr)
-
-
-def create_machine_account(ldb, instance_id, netbios_name, machinepass):
+def create_machine_account(ldb, instance_id, netbios_name, machinepass,
+                           traffic_account=True):
     """Create a machine account via ldap."""
 
     ou = ou_name(ldb, instance_id)
     dn = "cn=%s,%s" % (netbios_name, ou)
     utf16pw = ('"%s"' % get_string(machinepass)).encode('utf-16-le')
 
-    start = time.time()
+    if traffic_account:
+        # we set these bits for the machine account otherwise the replayed
+        # traffic throws up NT_STATUS_NO_TRUST_SAM_ACCOUNT errors
+        account_controls = str(UF_TRUSTED_FOR_DELEGATION |
+                               UF_SERVER_TRUST_ACCOUNT)
+
+    else:
+        account_controls = str(UF_WORKSTATION_TRUST_ACCOUNT)
+
     ldb.add({
         "dn": dn,
         "objectclass": "computer",
         "sAMAccountName": "%s$" % netbios_name,
-        "userAccountControl":
-            str(UF_TRUSTED_FOR_DELEGATION | UF_SERVER_TRUST_ACCOUNT),
+        "userAccountControl": account_controls,
         "unicodePwd": utf16pw})
-    end = time.time()
-    duration = end - start
-    LOGGER.info("%f\t0\tcreate\tmachine\t%f\tTrue\t" % (end, duration))
 
 
 def create_user_account(ldb, instance_id, username, userpass):
@@ -1683,7 +1644,6 @@ def create_user_account(ldb, instance_id, username, userpass):
     ou = ou_name(ldb, instance_id)
     user_dn = "cn=%s,%s" % (username, ou)
     utf16pw = ('"%s"' % get_string(userpass)).encode('utf-16-le')
-    start = time.time()
     ldb.add({
         "dn": user_dn,
         "objectclass": "user",
@@ -1696,25 +1656,17 @@ def create_user_account(ldb, instance_id, username, userpass):
     sdutils = sd_utils.SDUtils(ldb)
     sdutils.dacl_add_ace(user_dn, "(A;;WP;;;PS)")
 
-    end = time.time()
-    duration = end - start
-    LOGGER.info("%f\t0\tcreate\tuser\t%f\tTrue\t" % (end, duration))
-
 
 def create_group(ldb, instance_id, name):
     """Create a group via ldap."""
 
     ou = ou_name(ldb, instance_id)
     dn = "cn=%s,%s" % (name, ou)
-    start = time.time()
     ldb.add({
         "dn": dn,
         "objectclass": "group",
         "sAMAccountName": name,
     })
-    end = time.time()
-    duration = end - start
-    LOGGER.info("%f\t0\tcreate\tgroup\t%f\tTrue\t" % (end, duration))
 
 
 def user_name(instance_id, i):
@@ -1740,10 +1692,44 @@ def generate_users(ldb, instance_id, number, password):
         if name not in existing_objects:
             create_user_account(ldb, instance_id, name, password)
             users += 1
+            if users % 50 == 0:
+                LOGGER.info("Created %u/%u users" % (users, number))
 
     return users
 
 
+def machine_name(instance_id, i, traffic_account=True):
+    """Generate a machine account name from instance id."""
+    if traffic_account:
+        # traffic accounts correspond to a given user, and use different
+        # userAccountControl flags to ensure packets get processed correctly
+        # by the DC
+        return "STGM-%d-%d" % (instance_id, i)
+    else:
+        # Otherwise we're just generating computer accounts to simulate a
+        # semi-realistic network. These use the default computer
+        # userAccountControl flags, so we use a different account name so that
+        # we don't try to use them when generating packets
+        return "PC-%d-%d" % (instance_id, i)
+
+
+def generate_machine_accounts(ldb, instance_id, number, password,
+                              traffic_account=True):
+    """Add machine accounts to the server"""
+    existing_objects = search_objectclass(ldb, objectclass='computer')
+    added = 0
+    for i in range(number, 0, -1):
+        name = machine_name(instance_id, i, traffic_account)
+        if name + "$" not in existing_objects:
+            create_machine_account(ldb, instance_id, name, password,
+                                   traffic_account)
+            added += 1
+            if added % 50 == 0:
+                LOGGER.info("Created %u/%u machine accounts" % (added, number))
+
+    return added
+
+
 def group_name(instance_id, i):
     """Generate a group name from instance id."""
     return "STGG-%d-%d" % (instance_id, i)
@@ -1758,6 +1744,8 @@ def generate_groups(ldb, instance_id, number):
         if name not in existing_objects:
             create_group(ldb, instance_id, name)
             groups += 1
+            if groups % 1000 == 0:
+                LOGGER.info("Created %u/%u groups" % (groups, number))
 
     return groups
 
@@ -1776,53 +1764,162 @@ def clean_up_accounts(ldb, instance_id):
 
 def generate_users_and_groups(ldb, instance_id, password,
                               number_of_users, number_of_groups,
-                              group_memberships):
+                              group_memberships, max_members,
+                              machine_accounts, traffic_accounts=True):
     """Generate the required users and groups, allocating the users to
        those groups."""
     memberships_added = 0
-    groups_added  = 0
+    groups_added = 0
+    computers_added = 0
 
     create_ou(ldb, instance_id)
 
-    print("Generating dummy user accounts", file=sys.stderr)
+    LOGGER.info("Generating dummy user accounts")
     users_added = generate_users(ldb, instance_id, number_of_users, password)
 
+    LOGGER.info("Generating dummy machine accounts")
+    computers_added = generate_machine_accounts(ldb, instance_id,
+                                                machine_accounts, password,
+                                                traffic_accounts)
+
     if number_of_groups > 0:
-        print("Generating dummy groups", file=sys.stderr)
+        LOGGER.info("Generating dummy groups")
         groups_added = generate_groups(ldb, instance_id, number_of_groups)
 
     if group_memberships > 0:
-        print("Assigning users to groups", file=sys.stderr)
+        LOGGER.info("Assigning users to groups")
         assignments = GroupAssignments(number_of_groups,
                                        groups_added,
                                        number_of_users,
                                        users_added,
-                                       group_memberships)
-        print("Adding users to groups", file=sys.stderr)
-        add_users_to_groups(ldb, instance_id, assignments.assignments)
+                                       group_memberships,
+                                       max_members)
+        LOGGER.info("Adding users to groups")
+        add_users_to_groups(ldb, instance_id, assignments)
         memberships_added = assignments.total()
 
     if (groups_added > 0 and users_added == 0 and
        number_of_groups != groups_added):
-        print("Warning: the added groups will contain no members",
-              file=sys.stderr)
+        LOGGER.warning("The added groups will contain no members")
 
-    print(("Added %d users, %d groups and %d group memberships" %
-           (users_added, groups_added, memberships_added)),
-          file=sys.stderr)
+    LOGGER.info("Added %d users (%d machines), %d groups and %d memberships" %
+                (users_added, computers_added, groups_added,
+                 memberships_added))
 
 
 class GroupAssignments(object):
     def __init__(self, number_of_groups, groups_added, number_of_users,
-                 users_added, group_memberships):
-        self.assignments = self.assign_groups(number_of_groups,
-                                              groups_added,
-                                              number_of_users,
-                                              users_added,
-                                              group_memberships)
-
-    def assign_groups(self, number_of_groups, groups_added, number_of_users,
-                      users_added, group_memberships):
+                 users_added, group_memberships, max_members):
+
+        self.count = 0
+        self.generate_group_distribution(number_of_groups)
+        self.generate_user_distribution(number_of_users, group_memberships)
+        self.max_members = max_members
+        self.assignments = defaultdict(list)
+        self.assign_groups(number_of_groups, groups_added, number_of_users,
+                           users_added, group_memberships)
+
+    def cumulative_distribution(self, weights):
+        # make sure the probabilities conform to a cumulative distribution
+        # spread between 0.0 and 1.0. Dividing by the weighted total gives each
+        # probability a proportional share of 1.0. Higher probabilities get a
+        # bigger share, so are more likely to be picked. We use the cumulative
+        # value, so we can use random.random() as a simple index into the list
+        dist = []
+        total = sum(weights)
+        if total == 0:
+            return None
+
+        cumulative = 0.0
+        for probability in weights:
+            cumulative += probability
+            dist.append(cumulative / total)
+        return dist
+
+    def generate_user_distribution(self, num_users, num_memberships):
+        """Probability distribution of a user belonging to a group.
+        """
+        # Assign a weighted probability to each user. Use the Pareto
+        # Distribution so that some users are in a lot of groups, and the
+        # bulk of users are in only a few groups. If we're assigning a large
+        # number of group memberships, use a higher shape. This means slightly
+        # fewer outlying users that are in large numbers of groups. The aim is
+        # to have no users belonging to more than ~500 groups.
+        if num_memberships > 5000000:
+            shape = 3.0
+        elif num_memberships > 2000000:
+            shape = 2.5
+        elif num_memberships > 300000:
+            shape = 2.25
+        else:
+            shape = 1.75
+
+        weights = []
+        for x in range(1, num_users + 1):
+            p = random.paretovariate(shape)
+            weights.append(p)
+
+        # convert the weights to a cumulative distribution between 0.0 and 1.0
+        self.user_dist = self.cumulative_distribution(weights)
+
+    def generate_group_distribution(self, n):
+        """Probability distribution of a group containing a user."""
+
+        # Assign a weighted probability to each user. Probability decreases
+        # as the group-ID increases
+        weights = []
+        for x in range(1, n + 1):
+            p = 1 / (x**1.3)
+            weights.append(p)
+
+        # convert the weights to a cumulative distribution between 0.0 and 1.0
+        self.group_weights = weights
+        self.group_dist = self.cumulative_distribution(weights)
+
+    def generate_random_membership(self):
+        """Returns a randomly generated user-group membership"""
+
+        # the list items are cumulative distribution values between 0.0 and
+        # 1.0, which makes random() a handy way to index the list to get a
+        # weighted random user/group. (Here the user/group returned are
+        # zero-based array indexes)
+        user = bisect.bisect(self.user_dist, random.random())
+        group = bisect.bisect(self.group_dist, random.random())
+
+        return user, group
+
+    def users_in_group(self, group):
+        return self.assignments[group]
+
+    def get_groups(self):
+        return self.assignments.keys()
+
+    def cap_group_membership(self, group, max_members):
+        """Prevent the group's membership from exceeding the max specified"""
+        num_members = len(self.assignments[group])
+        if num_members >= max_members:
+            LOGGER.info("Group {0} has {1} members".format(group, num_members))
+
+            # remove this group and then recalculate the cumulative
+            # distribution, so this group is no longer selected
+            self.group_weights[group - 1] = 0
+            new_dist = self.cumulative_distribution(self.group_weights)
+            self.group_dist = new_dist
+
+    def add_assignment(self, user, group):
+        # the assignments are stored in a dictionary where key=group,
+        # value=list-of-users-in-group (indexing by group-ID allows us to
+        # optimize for DB membership writes)
+        if user not in self.assignments[group]:
+            self.assignments[group].append(user)
+            self.count += 1
+
+        # check if there'a cap on how big the groups can grow
+        if self.max_members:
+            self.cap_group_membership(group, self.max_members)
+
+    def assign_groups(self, number_of_groups, groups_added,
+                      number_of_users, users_added, group_memberships):
         """Allocate users to groups.
 
         The intention is to have a few users that belong to most groups, while
@@ -1832,77 +1929,74 @@ class GroupAssignments(object):
         few users.
         """
 
-        def generate_user_distribution(n):
-            """Probability distribution of a user belonging to a group.
-            """
-            dist = []
-            for x in range(1, n + 1):
-                p = 1 / (x + 0.001)
-                dist.append(p)
-            return dist
-
-        def generate_group_distribution(n):
-            """Probability distribution of a group containing a user."""
-            dist = []
-            for x in range(1, n + 1):
-                p = 1 / (x**1.3)
-                dist.append(p)
-            return dist
-
-        assignments = set()
         if group_memberships <= 0:
-            return assignments
-
-        group_dist = generate_group_distribution(number_of_groups)
-        user_dist  = generate_user_distribution(number_of_users)
+            return
 
         # Calculate the number of group menberships required
         group_memberships = math.ceil(
             float(group_memberships) *
             (float(users_added) / float(number_of_users)))
 
+        if self.max_members:
+            group_memberships = min(group_memberships,
+                                    self.max_members * number_of_groups)
+
         existing_users  = number_of_users  - users_added  - 1
         existing_groups = number_of_groups - groups_added - 1
-        while len(assignments) < group_memberships:
-            user        = random.randint(0, number_of_users - 1)
-            group       = random.randint(0, number_of_groups - 1)
-            probability = group_dist[group] * user_dist[user]
+        while self.total() < group_memberships:
+            user, group = self.generate_random_membership()
 
-            if ((random.random() < probability * 10000) and
-               (group > existing_groups or user > existing_users)):
+            if group > existing_groups or user > existing_users:
                 # the + 1 converts the array index to the corresponding
                 # group or user number
-                assignments.add(((user + 1), (group + 1)))
-
-        return assignments
+                self.add_assignment(user + 1, group + 1)
 
     def total(self):
-        return len(self.assignments)
+        return self.count
 
 
 def add_users_to_groups(db, instance_id, assignments):
-    """Add users to their assigned groups.
+    """Takes the assignments of users to groups and applies them to the DB."""
 
-    Takes the list of (group,user) tuples generated by assign_groups and
-    assign the users to their specified groups."""
+    total = assignments.total()
+    count = 0
+    added = 0
+
+    for group in assignments.get_groups():
+        users_in_group = assignments.users_in_group(group)
+        if len(users_in_group) == 0:
+            continue
+
+        # Split up the users into chunks, so we write no more than 1K at a
+        # time. (Minimizing the DB modifies is more efficient, but writing
+        # 10K+ users to a single group becomes inefficient memory-wise)
+        for chunk in range(0, len(users_in_group), 1000):
+            chunk_of_users = users_in_group[chunk:chunk + 1000]
+            add_group_members(db, instance_id, group, chunk_of_users)
+
+            added += len(chunk_of_users)
+            count += 1
+            if count % 50 == 0:
+                LOGGER.info("Added %u/%u memberships" % (added, total))
+
+def add_group_members(db, instance_id, group, users_in_group):
+    """Adds the given users to group specified."""
 
     ou = ou_name(db, instance_id)
 
     def build_dn(name):
         return("cn=%s,%s" % (name, ou))
 
-    for (user, group) in assignments:
-        user_dn  = build_dn(user_name(instance_id, user))
-        group_dn = build_dn(group_name(instance_id, group))
+    group_dn = build_dn(group_name(instance_id, group))
+    m = ldb.Message()
+    m.dn = ldb.Dn(db, group_dn)
 
-        m = ldb.Message()
-        m.dn = ldb.Dn(db, group_dn)
-        m["member"] = ldb.MessageElement(user_dn, ldb.FLAG_MOD_ADD, "member")
-        start = time.time()
-        db.modify(m)
-        end = time.time()
-        duration = end - start
-        LOGGER.info("%f\t0\tadd\tuser\t%f\tTrue\t" % (end, duration))
+    for user in users_in_group:
+        user_dn = build_dn(user_name(instance_id, user))
+        idx = "member-" + str(user)
+        m[idx] = ldb.MessageElement(user_dn, ldb.FLAG_MOD_ADD, "member")
+
+    db.modify(m)
 
 
 def generate_stats(statsdir, timing_file):