traffic_replay: Add a max-members option to cap group size
[samba.git] / python / samba / emulate / traffic.py
index afb57da36f08ac9baa4cbc02c09c500ae670619d..291162f279ac066757f1d4459df360759604ef9c 100644 (file)
@@ -42,10 +42,18 @@ from samba.drs_utils import drs_DsBind
 import traceback
 from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
 from samba.auth import system_session
-from samba.dsdb import UF_WORKSTATION_TRUST_ACCOUNT, UF_PASSWD_NOTREQD
-from samba.dsdb import UF_NORMAL_ACCOUNT
-from samba.dcerpc.misc import SEC_CHAN_WKSTA
+from samba.dsdb import (
+    UF_NORMAL_ACCOUNT,
+    UF_SERVER_TRUST_ACCOUNT,
+    UF_TRUSTED_FOR_DELEGATION,
+    UF_WORKSTATION_TRUST_ACCOUNT
+)
+from samba.dcerpc.misc import SEC_CHAN_BDC
 from samba import gensec
+from samba import sd_utils
+from samba.compat import get_string
+from samba.logger import get_samba_logger
+import bisect
 
 SLEEP_OVERHEAD = 3e-4
 
@@ -84,6 +92,8 @@ NO_WAIT_LOG_TIME_RANGE = (-10, -3)
 # DEBUG_LEVEL can be changed by scripts with -d
 DEBUG_LEVEL = 0
 
+LOGGER = get_samba_logger(name=__name__)
+
 
 def debug(level, msg, *args):
     """Print a formatted debug message to standard error.
@@ -134,10 +144,26 @@ class FakePacketError(Exception):
 
 class Packet(object):
     """Details of a network packet"""
-    def __init__(self, fields):
-        if isinstance(fields, str):
-            fields = fields.rstrip('\n').split('\t')
+    def __init__(self, timestamp, ip_protocol, stream_number, src, dest,
+                 protocol, opcode, desc, extra):
 
+        self.timestamp = timestamp
+        self.ip_protocol = ip_protocol
+        self.stream_number = stream_number
+        self.src = src
+        self.dest = dest
+        self.protocol = protocol
+        self.opcode = opcode
+        self.desc = desc
+        self.extra = extra
+        if self.src < self.dest:
+            self.endpoints = (self.src, self.dest)
+        else:
+            self.endpoints = (self.dest, self.src)
+
+    @classmethod
+    def from_line(self, line):
+        fields = line.rstrip('\n').split('\t')
         (timestamp,
          ip_protocol,
          stream_number,
@@ -148,23 +174,12 @@ class Packet(object):
          desc) = fields[:8]
         extra = fields[8:]
 
-        self.timestamp = float(timestamp)
-        self.ip_protocol = ip_protocol
-        try:
-            self.stream_number = int(stream_number)
-        except (ValueError, TypeError):
-            self.stream_number = None
-        self.src = int(src)
-        self.dest = int(dest)
-        self.protocol = protocol
-        self.opcode = opcode
-        self.desc = desc
-        self.extra = extra
+        timestamp = float(timestamp)
+        src = int(src)
+        dest = int(dest)
 
-        if self.src < self.dest:
-            self.endpoints = (self.src, self.dest)
-        else:
-            self.endpoints = (self.dest, self.src)
+        return Packet(timestamp, ip_protocol, stream_number, src, dest,
+                      protocol, opcode, desc, extra)
 
     def as_summary(self, time_offset=0.0):
         """Format the packet as a traffic_summary line.
@@ -192,14 +207,15 @@ class Packet(object):
         return "<Packet @%s>" % self
 
     def copy(self):
-        return self.__class__([self.timestamp,
-                               self.ip_protocol,
-                               self.stream_number,
-                               self.src,
-                               self.dest,
-                               self.protocol,
-                               self.opcode,
-                               self.desc] + self.extra)
+        return self.__class__(self.timestamp,
+                              self.ip_protocol,
+                              self.stream_number,
+                              self.src,
+                              self.dest,
+                              self.protocol,
+                              self.opcode,
+                              self.desc,
+                              self.extra)
 
     def as_packet_type(self):
         t = '%s:%s' % (self.protocol, self.opcode)
@@ -228,7 +244,7 @@ class Packet(object):
             fn = getattr(traffic_packets, fn_name)
 
         except AttributeError as e:
-            print("Conversation(%s) Missing handler %s" % \
+            print("Conversation(%s) Missing handler %s" %
                   (conversation.conversation_id, fn_name),
                   file=sys.stderr)
             return
@@ -272,13 +288,12 @@ class Packet(object):
             return False
 
         fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
-        try:
-            fn = getattr(traffic_packets, fn_name)
-            if fn is traffic_packets.null_packet:
-                return False
-        except AttributeError:
+        fn = getattr(traffic_packets, fn_name, None)
+        if not fn:
             print("missing packet %s" % fn_name, file=sys.stderr)
             return False
+        if fn is traffic_packets.null_packet:
+            return False
         return True
 
 
@@ -331,7 +346,7 @@ class ReplayContext(object):
         self.last_netlogon_bad        = False
         self.last_samlogon_bad        = False
         self.generate_ldap_search_tables()
-        self.next_conversation_id = itertools.count().next
+        self.next_conversation_id = itertools.count()
 
     def generate_ldap_search_tables(self):
         session = system_session()
@@ -343,6 +358,7 @@ class ReplayContext(object):
 
         res = db.search(db.domain_dn(),
                         scope=ldb.SCOPE_SUBTREE,
+                        controls=["paged_results:1:1000"],
                         attrs=['dn'])
 
         # find a list of dns for each pattern
@@ -365,7 +381,7 @@ class ReplayContext(object):
         # for k, v in self.dn_map.items():
         #     print >>sys.stderr, k, len(v)
 
-        for k, v in dn_map.items():
+        for k in list(dn_map.keys()):
             if k[-3:] != ',DC':
                 continue
             p = k[:-3]
@@ -394,8 +410,8 @@ class ReplayContext(object):
                                      'conversation-%d' %
                                      conversation.conversation_id)
 
-        self.lp.set("private dir",     self.tempdir)
-        self.lp.set("lock dir",        self.tempdir)
+        self.lp.set("private dir", self.tempdir)
+        self.lp.set("lock dir", self.tempdir)
         self.lp.set("state directory", self.tempdir)
         self.lp.set("tls verify peer", "no_check")
 
@@ -426,8 +442,8 @@ class ReplayContext(object):
            than that requested, but not significantly.
         """
         if not failed_last_time:
-            if (self.badpassword_frequency > 0 and
-               random.random() < self.badpassword_frequency):
+            if (self.badpassword_frequency and self.badpassword_frequency > 0
+                and random.random() < self.badpassword_frequency):
                 try:
                     f(bad)
                 except:
@@ -455,6 +471,7 @@ class ReplayContext(object):
         self.user_creds.set_workstation(self.netbios_name)
         self.user_creds.set_password(self.userpass)
         self.user_creds.set_username(self.username)
+        self.user_creds.set_domain(self.domain)
         if self.prefer_kerberos:
             self.user_creds.set_kerberos_state(MUST_USE_KERBEROS)
         else:
@@ -509,9 +526,10 @@ class ReplayContext(object):
         self.machine_creds = Credentials()
         self.machine_creds.guess(self.lp)
         self.machine_creds.set_workstation(self.netbios_name)
-        self.machine_creds.set_secure_channel_type(SEC_CHAN_WKSTA)
+        self.machine_creds.set_secure_channel_type(SEC_CHAN_BDC)
         self.machine_creds.set_password(self.machinepass)
         self.machine_creds.set_username(self.netbios_name + "$")
+        self.machine_creds.set_domain(self.domain)
         if self.prefer_kerberos:
             self.machine_creds.set_kerberos_state(MUST_USE_KERBEROS)
         else:
@@ -520,7 +538,7 @@ class ReplayContext(object):
         self.machine_creds_bad = Credentials()
         self.machine_creds_bad.guess(self.lp)
         self.machine_creds_bad.set_workstation(self.netbios_name)
-        self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_WKSTA)
+        self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_BDC)
         self.machine_creds_bad.set_password(self.machinepass[:-4])
         self.machine_creds_bad.set_username(self.netbios_name + "$")
         if self.prefer_kerberos:
@@ -643,6 +661,15 @@ class ReplayContext(object):
             return self.ldap_connections[-1]
 
         def simple_bind(creds):
+            """
+            To run simple bind against Windows, we need to run
+            following commands in PowerShell:
+
+                Install-windowsfeature ADCS-Cert-Authority
+                Install-AdcsCertificationAuthority -CAType EnterpriseRootCA
+                Restart-Computer
+
+            """
             return SamDB('ldaps://%s' % self.server,
                          credentials=creds,
                          lp=self.lp)
@@ -669,7 +696,8 @@ class ReplayContext(object):
 
     def get_samr_context(self, new=False):
         if not self.samr_contexts or new:
-            self.samr_contexts.append(SamrContext(self.server))
+            self.samr_contexts.append(
+                SamrContext(self.server, lp=self.lp, creds=self.creds))
         return self.samr_contexts[-1]
 
     def get_netlogon_connection(self):
@@ -696,7 +724,7 @@ class ReplayContext(object):
     def get_authenticator(self):
         auth = self.machine_creds.new_client_authenticator()
         current  = netr_Authenticator()
-        current.cred.data = [ord(x) for x in auth["credential"]]
+        current.cred.data = [x if isinstance(x, int) else ord(x) for x in auth["credential"]]
         current.timestamp = auth["timestamp"]
 
         subsequent = netr_Authenticator()
@@ -706,7 +734,7 @@ class ReplayContext(object):
 class SamrContext(object):
     """State/Context associated with a samr connection.
     """
-    def __init__(self, server):
+    def __init__(self, server, lp=None, creds=None):
         self.connection    = None
         self.handle        = None
         self.domain_handle = None
@@ -715,10 +743,16 @@ class SamrContext(object):
         self.user_handle   = None
         self.rids          = None
         self.server        = server
+        self.lp            = lp
+        self.creds         = creds
 
     def get_connection(self):
         if not self.connection:
-            self.connection = samr.samr("ncacn_ip_tcp:%s" % (self.server))
+            self.connection = samr.samr(
+                "ncacn_ip_tcp:%s[seal]" % (self.server),
+                lp_ctx=self.lp,
+                credentials=self.creds)
+
         return self.connection
 
     def get_handle(self):
@@ -774,23 +808,24 @@ class Conversation(object):
         if p.is_really_a_packet():
             self.packets.append(p)
 
-    def add_short_packet(self, timestamp, p, extra, client=True):
+    def add_short_packet(self, timestamp, protocol, opcode, extra,
+                         client=True):
         """Create a packet from a timestamp, and 'protocol:opcode' pair, and a
         (possibly empty) list of extra data. If client is True, assume
         this packet is from the client to the server.
         """
-        protocol, opcode = p.split(':', 1)
         src, dest = self.guess_client_server()
         if not client:
             src, dest = dest, src
-
-        desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
-        ip_protocol = IP_PROTOCOLS.get(protocol, '06')
-        fields = [timestamp - self.start_time, ip_protocol,
-                  '', src, dest,
-                  protocol, opcode, desc]
-        fields.extend(extra)
-        packet = Packet(fields)
+        key = (protocol, opcode)
+        desc = OP_DESCRIPTIONS[key] if key in OP_DESCRIPTIONS else ''
+        if protocol in IP_PROTOCOLS:
+            ip_protocol = IP_PROTOCOLS[protocol]
+        else:
+            ip_protocol = '06'
+        packet = Packet(timestamp - self.start_time, ip_protocol,
+                        '', src, dest,
+                        protocol, opcode, desc, extra)
         # XXX we're assuming the timestamp is already adjusted for
         # this conversation?
         # XXX should we adjust client balance for guessed packets?
@@ -853,7 +888,7 @@ class Conversation(object):
             gap = t - now
             print("gap is now %f" % gap, file=sys.stderr)
 
-        self.conversation_id = context.next_conversation_id()
+        self.conversation_id = next(context.next_conversation_id)
         pid = os.fork()
         if pid != 0:
             return pid
@@ -928,18 +963,8 @@ class Conversation(object):
         :param s: start of the window
         :param e: end of the window
         """
-
-        new_packets = []
-        for p in self.packets:
-            if p.timestamp < s or p.timestamp > e:
-                continue
-            new_packets.append(p)
-
-        self.packets = new_packets
-        if new_packets:
-            self.start_time = new_packets[0].timestamp
-        else:
-            self.start_time = None
+        self.packets = [p for p in self.packets if s <= p.timestamp <= e]
+        self.start_time = self.packets[0].timestamp if self.packets else None
 
     def renormalise_times(self, start_time):
         """Adjust the packet start times relative to the new start time."""
@@ -1012,7 +1037,7 @@ def ingest_summaries(files, dns_mode='count'):
             f = open(f)
         print("Ingesting %s" % (f.name,), file=sys.stderr)
         for line in f:
-            p = Packet(line)
+            p = Packet.from_line(line)
             if p.protocol == 'dns' and dns_mode != 'include':
                 dns_counts[p.opcode] += 1
             else:
@@ -1210,7 +1235,7 @@ class TrafficModel(object):
                 timestamp += wait
                 if hard_stop is not None and timestamp > hard_stop:
                     break
-                c.add_short_packet(timestamp, p, extra)
+                c.add_short_packet(timestamp, protocol, opcode, extra)
 
             key = key[1:] + (p,)
 
@@ -1248,7 +1273,7 @@ class TrafficModel(object):
             client += 1
 
         print(("we have %d conversations at rate %f" %
-                              (len(conversations), rate)), file=sys.stderr)
+               (len(conversations), rate)), file=sys.stderr)
         conversations.sort()
         return conversations
 
@@ -1418,7 +1443,7 @@ def replay(conversations,
 
     end = start + duration
 
-    print("Replaying traffic for %u conversations over %d seconds"
+    LOGGER.info("Replaying traffic for %u conversations over %d seconds"
           % (len(conversations), duration))
 
     children = {}
@@ -1481,7 +1506,7 @@ def replay(conversations,
     finally:
         for s in (15, 15, 9):
             print(("killing %d children with -%d" %
-                                 (len(children), s)), file=sys.stderr)
+                   (len(children), s)), file=sys.stderr)
             for pid in children:
                 try:
                     os.kill(pid, s)
@@ -1528,6 +1553,7 @@ def openLdb(host, creds, lp):
     session = system_session()
     ldb = SamDB(url="ldap://%s" % host,
                 session_info=session,
+                options=['modules:paged_searches'],
                 credentials=creds,
                 lp=lp)
     return ldb
@@ -1546,18 +1572,18 @@ def create_ou(ldb, instance_id):
     """
     ou = ou_name(ldb, instance_id)
     try:
-        ldb.add({"dn":          ou.split(',', 1)[1],
+        ldb.add({"dn": ou.split(',', 1)[1],
                  "objectclass": "organizationalunit"})
     except LdbError as e:
-        (status, _) = e
+        (status, _) = e.args
         # ignore already exists
         if status != 68:
             raise
     try:
-        ldb.add({"dn":          ou,
+        ldb.add({"dn": ou,
                  "objectclass": "organizationalunit"})
     except LdbError as e:
-        (status, _) = e
+        (status, _) = e.args
         # ignore already exists
         if status != 68:
             raise
@@ -1577,11 +1603,10 @@ class ConversationAccounts(object):
 def generate_replay_accounts(ldb, instance_id, number, password):
     """Generate a series of unique machine and user account names."""
 
-    generate_traffic_accounts(ldb, instance_id, number, password)
     accounts = []
     for i in range(1, number + 1):
-        netbios_name = "STGM-%d-%d" % (instance_id, i)
-        username     = "STGU-%d-%d" % (instance_id, i)
+        netbios_name = machine_name(instance_id, i)
+        username = user_name(instance_id, i)
 
         account = ConversationAccounts(netbios_name, password, username,
                                        password)
@@ -1589,80 +1614,36 @@ def generate_replay_accounts(ldb, instance_id, number, password):
     return accounts
 
 
-def generate_traffic_accounts(ldb, instance_id, number, password):
-    """Create the specified number of user and machine accounts.
-
-    As accounts are not explicitly deleted between runs. This function starts
-    with the last account and iterates backwards stopping either when it
-    finds an already existing account or it has generated all the required
-    accounts.
-    """
-    print(("Generating machine and conversation accounts, "
-           "as required for %d conversations" % number),
-          file=sys.stderr)
-    added = 0
-    for i in range(number, 0, -1):
-        try:
-            netbios_name = "STGM-%d-%d" % (instance_id, i)
-            create_machine_account(ldb, instance_id, netbios_name, password)
-            added += 1
-        except LdbError as e:
-            (status, _) = e
-            if status == 68:
-                break
-            else:
-                raise
-    if added > 0:
-        print("Added %d new machine accounts" % added,
-              file=sys.stderr)
-
-    added = 0
-    for i in range(number, 0, -1):
-        try:
-            username = "STGU-%d-%d" % (instance_id, i)
-            create_user_account(ldb, instance_id, username, password)
-            added += 1
-        except LdbError as e:
-            (status, _) = e
-            if status == 68:
-                break
-            else:
-                raise
-
-    if added > 0:
-        print("Added %d new user accounts" % added,
-              file=sys.stderr)
-
-
-def create_machine_account(ldb, instance_id, netbios_name, machinepass):
+def create_machine_account(ldb, instance_id, netbios_name, machinepass,
+                           traffic_account=True):
     """Create a machine account via ldap."""
 
     ou = ou_name(ldb, instance_id)
     dn = "cn=%s,%s" % (netbios_name, ou)
-    utf16pw = unicode(
-        '"' + machinepass.encode('utf-8') + '"', 'utf-8'
-    ).encode('utf-16-le')
-    start = time.time()
+    utf16pw = ('"%s"' % get_string(machinepass)).encode('utf-16-le')
+
+    if traffic_account:
+        # we set these bits for the machine account otherwise the replayed
+        # traffic throws up NT_STATUS_NO_TRUST_SAM_ACCOUNT errors
+        account_controls = str(UF_TRUSTED_FOR_DELEGATION |
+                               UF_SERVER_TRUST_ACCOUNT)
+
+    else:
+        account_controls = str(UF_WORKSTATION_TRUST_ACCOUNT)
+
     ldb.add({
         "dn": dn,
         "objectclass": "computer",
         "sAMAccountName": "%s$" % netbios_name,
-        "userAccountControl":
-        str(UF_WORKSTATION_TRUST_ACCOUNT | UF_PASSWD_NOTREQD),
+        "userAccountControl": account_controls,
         "unicodePwd": utf16pw})
-    end = time.time()
-    duration = end - start
-    print("%f\t0\tcreate\tmachine\t%f\tTrue\t" % (end, duration))
 
 
 def create_user_account(ldb, instance_id, username, userpass):
     """Create a user account via ldap."""
     ou = ou_name(ldb, instance_id)
     user_dn = "cn=%s,%s" % (username, ou)
-    utf16pw = unicode(
-        '"' + userpass.encode('utf-8') + '"', 'utf-8'
-    ).encode('utf-16-le')
-    start = time.time()
+    utf16pw = ('"%s"' % get_string(userpass)).encode('utf-16-le')
     ldb.add({
         "dn": user_dn,
         "objectclass": "user",
@@ -1670,9 +1651,10 @@ def create_user_account(ldb, instance_id, username, userpass):
         "userAccountControl": str(UF_NORMAL_ACCOUNT),
         "unicodePwd": utf16pw
     })
-    end = time.time()
-    duration = end - start
-    print("%f\t0\tcreate\tuser\t%f\tTrue\t" % (end, duration))
+
+    # grant user write permission to do things like write account SPN
+    sdutils = sd_utils.SDUtils(ldb)
+    sdutils.dacl_add_ace(user_dn, "(A;;WP;;;PS)")
 
 
 def create_group(ldb, instance_id, name):
@@ -1680,14 +1662,11 @@ def create_group(ldb, instance_id, name):
 
     ou = ou_name(ldb, instance_id)
     dn = "cn=%s,%s" % (name, ou)
-    start = time.time()
     ldb.add({
         "dn": dn,
         "objectclass": "group",
+        "sAMAccountName": name,
     })
-    end = time.time()
-    duration = end - start
-    print("%f\t0\tcreate\tgroup\t%f\tTrue\t" % (end, duration))
 
 
 def user_name(instance_id, i):
@@ -1695,25 +1674,62 @@ def user_name(instance_id, i):
     return "STGU-%d-%d" % (instance_id, i)
 
 
+def search_objectclass(ldb, objectclass='user', attr='sAMAccountName'):
+    """Seach objectclass, return attr in a set"""
+    objs = ldb.search(
+        expression="(objectClass={})".format(objectclass),
+        attrs=[attr]
+    )
+    return {str(obj[attr]) for obj in objs}
+
+
 def generate_users(ldb, instance_id, number, password):
     """Add users to the server"""
+    existing_objects = search_objectclass(ldb, objectclass='user')
     users = 0
     for i in range(number, 0, -1):
-        try:
-            username = user_name(instance_id, i)
-            create_user_account(ldb, instance_id, username, password)
+        name = user_name(instance_id, i)
+        if name not in existing_objects:
+            create_user_account(ldb, instance_id, name, password)
             users += 1
-        except LdbError as e:
-            (status, _) = e
-            # Stop if entry exists
-            if status == 68:
-                break
-            else:
-                raise
+            if users % 50 == 0:
+                LOGGER.info("Created %u/%u users" % (users, number))
 
     return users
 
 
+def machine_name(instance_id, i, traffic_account=True):
+    """Generate a machine account name from instance id."""
+    if traffic_account:
+        # traffic accounts correspond to a given user, and use different
+        # userAccountControl flags to ensure packets get processed correctly
+        # by the DC
+        return "STGM-%d-%d" % (instance_id, i)
+    else:
+        # Otherwise we're just generating computer accounts to simulate a
+        # semi-realistic network. These use the default computer
+        # userAccountControl flags, so we use a different account name so that
+        # we don't try to use them when generating packets
+        return "PC-%d-%d" % (instance_id, i)
+
+
+def generate_machine_accounts(ldb, instance_id, number, password,
+                              traffic_account=True):
+    """Add machine accounts to the server"""
+    existing_objects = search_objectclass(ldb, objectclass='computer')
+    added = 0
+    for i in range(number, 0, -1):
+        name = machine_name(instance_id, i, traffic_account)
+        if name + "$" not in existing_objects:
+            create_machine_account(ldb, instance_id, name, password,
+                                   traffic_account)
+            added += 1
+            if added % 50 == 0:
+                LOGGER.info("Created %u/%u machine accounts" % (added, number))
+
+    return added
+
+
 def group_name(instance_id, i):
     """Generate a group name from instance id."""
     return "STGG-%d-%d" % (instance_id, i)
@@ -1721,19 +1737,16 @@ def group_name(instance_id, i):
 
 def generate_groups(ldb, instance_id, number):
     """Create the required number of groups on the server."""
+    existing_objects = search_objectclass(ldb, objectclass='group')
     groups = 0
     for i in range(number, 0, -1):
-        try:
-            name = group_name(instance_id, i)
+        name = group_name(instance_id, i)
+        if name not in existing_objects:
             create_group(ldb, instance_id, name)
             groups += 1
-        except LdbError as e:
-            (status, _) = e
-            # Stop if entry exists
-            if status == 68:
-                break
-            else:
-                raise
+            if groups % 1000 == 0:
+                LOGGER.info("Created %u/%u groups" % (groups, number))
+
     return groups
 
 
@@ -1743,7 +1756,7 @@ def clean_up_accounts(ldb, instance_id):
     try:
         ldb.delete(ou, ["tree_delete:1"])
     except LdbError as e:
-        (status, _) = e
+        (status, _) = e.args
         # ignore does not exist
         if status != 32:
             raise
@@ -1751,123 +1764,239 @@ def clean_up_accounts(ldb, instance_id):
 
 def generate_users_and_groups(ldb, instance_id, password,
                               number_of_users, number_of_groups,
-                              group_memberships):
+                              group_memberships, max_members,
+                              machine_accounts, traffic_accounts=True):
     """Generate the required users and groups, allocating the users to
        those groups."""
-    assignments = []
-    groups_added  = 0
+    memberships_added = 0
+    groups_added = 0
+    computers_added = 0
 
     create_ou(ldb, instance_id)
 
-    print("Generating dummy user accounts", file=sys.stderr)
+    LOGGER.info("Generating dummy user accounts")
     users_added = generate_users(ldb, instance_id, number_of_users, password)
 
+    LOGGER.info("Generating dummy machine accounts")
+    computers_added = generate_machine_accounts(ldb, instance_id,
+                                                machine_accounts, password,
+                                                traffic_accounts)
+
     if number_of_groups > 0:
-        print("Generating dummy groups", file=sys.stderr)
+        LOGGER.info("Generating dummy groups")
         groups_added = generate_groups(ldb, instance_id, number_of_groups)
 
     if group_memberships > 0:
-        print("Assigning users to groups", file=sys.stderr)
-        assignments = assign_groups(number_of_groups,
-                                    groups_added,
-                                    number_of_users,
-                                    users_added,
-                                    group_memberships)
-        print("Adding users to groups", file=sys.stderr)
+        LOGGER.info("Assigning users to groups")
+        assignments = GroupAssignments(number_of_groups,
+                                       groups_added,
+                                       number_of_users,
+                                       users_added,
+                                       group_memberships,
+                                       max_members)
+        LOGGER.info("Adding users to groups")
         add_users_to_groups(ldb, instance_id, assignments)
+        memberships_added = assignments.total()
 
     if (groups_added > 0 and users_added == 0 and
        number_of_groups != groups_added):
-        print("Warning: the added groups will contain no members",
-              file=sys.stderr)
-
-    print(("Added %d users, %d groups and %d group memberships" %
-           (users_added, groups_added, len(assignments))),
-          file=sys.stderr)
-
-
-def assign_groups(number_of_groups,
-                  groups_added,
-                  number_of_users,
-                  users_added,
-                  group_memberships):
-    """Allocate users to groups.
-
-    The intention is to have a few users that belong to most groups, while
-    the majority of users belong to a few groups.
-
-    A few groups will contain most users, with the remaining only having a
-    few users.
-    """
+        LOGGER.warning("The added groups will contain no members")
+
+    LOGGER.info("Added %d users (%d machines), %d groups and %d memberships" %
+                (users_added, computers_added, groups_added,
+                 memberships_added))
+
+
+class GroupAssignments(object):
+    def __init__(self, number_of_groups, groups_added, number_of_users,
+                 users_added, group_memberships, max_members):
+
+        self.count = 0
+        self.generate_group_distribution(number_of_groups)
+        self.generate_user_distribution(number_of_users, group_memberships)
+        self.max_members = max_members
+        self.assignments = defaultdict(list)
+        self.assign_groups(number_of_groups, groups_added, number_of_users,
+                           users_added, group_memberships)
+
+    def cumulative_distribution(self, weights):
+        # make sure the probabilities conform to a cumulative distribution
+        # spread between 0.0 and 1.0. Dividing by the weighted total gives each
+        # probability a proportional share of 1.0. Higher probabilities get a
+        # bigger share, so are more likely to be picked. We use the cumulative
+        # value, so we can use random.random() as a simple index into the list
+        dist = []
+        total = sum(weights)
+        if total == 0:
+            return None
+
+        cumulative = 0.0
+        for probability in weights:
+            cumulative += probability
+            dist.append(cumulative / total)
+        return dist
 
-    def generate_user_distribution(n):
+    def generate_user_distribution(self, num_users, num_memberships):
         """Probability distribution of a user belonging to a group.
         """
-        dist = []
-        for x in range(1, n + 1):
-            p = 1 / (x + 0.001)
-            dist.append(p)
-        return dist
+        # Assign a weighted probability to each user. Use the Pareto
+        # Distribution so that some users are in a lot of groups, and the
+        # bulk of users are in only a few groups. If we're assigning a large
+        # number of group memberships, use a higher shape. This means slightly
+        # fewer outlying users that are in large numbers of groups. The aim is
+        # to have no users belonging to more than ~500 groups.
+        if num_memberships > 5000000:
+            shape = 3.0
+        elif num_memberships > 2000000:
+            shape = 2.5
+        elif num_memberships > 300000:
+            shape = 2.25
+        else:
+            shape = 1.75
 
-    def generate_group_distribution(n):
+        weights = []
+        for x in range(1, num_users + 1):
+            p = random.paretovariate(shape)
+            weights.append(p)
+
+        # convert the weights to a cumulative distribution between 0.0 and 1.0
+        self.user_dist = self.cumulative_distribution(weights)
+
+    def generate_group_distribution(self, n):
         """Probability distribution of a group containing a user."""
-        dist = []
+
+        # Assign a weighted probability to each user. Probability decreases
+        # as the group-ID increases
+        weights = []
         for x in range(1, n + 1):
             p = 1 / (x**1.3)
-            dist.append(p)
-        return dist
+            weights.append(p)
+
+        # convert the weights to a cumulative distribution between 0.0 and 1.0
+        self.group_weights = weights
+        self.group_dist = self.cumulative_distribution(weights)
+
+    def generate_random_membership(self):
+        """Returns a randomly generated user-group membership"""
+
+        # the list items are cumulative distribution values between 0.0 and
+        # 1.0, which makes random() a handy way to index the list to get a
+        # weighted random user/group. (Here the user/group returned are
+        # zero-based array indexes)
+        user = bisect.bisect(self.user_dist, random.random())
+        group = bisect.bisect(self.group_dist, random.random())
+
+        return user, group
+
+    def users_in_group(self, group):
+        return self.assignments[group]
+
+    def get_groups(self):
+        return self.assignments.keys()
+
+    def cap_group_membership(self, group, max_members):
+        """Prevent the group's membership from exceeding the max specified"""
+        num_members = len(self.assignments[group])
+        if num_members >= max_members:
+            LOGGER.info("Group {0} has {1} members".format(group, num_members))
+
+            # remove this group and then recalculate the cumulative
+            # distribution, so this group is no longer selected
+            self.group_weights[group - 1] = 0
+            new_dist = self.cumulative_distribution(self.group_weights)
+            self.group_dist = new_dist
+
+    def add_assignment(self, user, group):
+        # the assignments are stored in a dictionary where key=group,
+        # value=list-of-users-in-group (indexing by group-ID allows us to
+        # optimize for DB membership writes)
+        if user not in self.assignments[group]:
+            self.assignments[group].append(user)
+            self.count += 1
+
+        # check if there'a cap on how big the groups can grow
+        if self.max_members:
+            self.cap_group_membership(group, self.max_members)
+
+    def assign_groups(self, number_of_groups, groups_added,
+                      number_of_users, users_added, group_memberships):
+        """Allocate users to groups.
+
+        The intention is to have a few users that belong to most groups, while
+        the majority of users belong to a few groups.
+
+        A few groups will contain most users, with the remaining only having a
+        few users.
+        """
 
-    assignments = set()
-    if group_memberships <= 0:
-        return assignments
+        if group_memberships <= 0:
+            return
 
-    group_dist = generate_group_distribution(number_of_groups)
-    user_dist  = generate_user_distribution(number_of_users)
+        # Calculate the number of group menberships required
+        group_memberships = math.ceil(
+            float(group_memberships) *
+            (float(users_added) / float(number_of_users)))
 
-    # Calculate the number of group menberships required
-    group_memberships = math.ceil(
-        float(group_memberships) *
-        (float(users_added) / float(number_of_users)))
+        if self.max_members:
+            group_memberships = min(group_memberships,
+                                    self.max_members * number_of_groups)
 
-    existing_users  = number_of_users  - users_added  - 1
-    existing_groups = number_of_groups - groups_added - 1
-    while len(assignments) < group_memberships:
-        user        = random.randint(0, number_of_users - 1)
-        group       = random.randint(0, number_of_groups - 1)
-        probability = group_dist[group] * user_dist[user]
+        existing_users  = number_of_users  - users_added  - 1
+        existing_groups = number_of_groups - groups_added - 1
+        while self.total() < group_memberships:
+            user, group = self.generate_random_membership()
 
-        if ((random.random() < probability * 10000) and
-           (group > existing_groups or user > existing_users)):
-            # the + 1 converts the array index to the corresponding
-            # group or user number
-            assignments.add(((user + 1), (group + 1)))
+            if group > existing_groups or user > existing_users:
+                # the + 1 converts the array index to the corresponding
+                # group or user number
+                self.add_assignment(user + 1, group + 1)
 
-    return assignments
+    def total(self):
+        return self.count
 
 
 def add_users_to_groups(db, instance_id, assignments):
-    """Add users to their assigned groups.
+    """Takes the assignments of users to groups and applies them to the DB."""
+
+    total = assignments.total()
+    count = 0
+    added = 0
+
+    for group in assignments.get_groups():
+        users_in_group = assignments.users_in_group(group)
+        if len(users_in_group) == 0:
+            continue
 
-    Takes the list of (group,user) tuples generated by assign_groups and
-    assign the users to their specified groups."""
+        # Split up the users into chunks, so we write no more than 1K at a
+        # time. (Minimizing the DB modifies is more efficient, but writing
+        # 10K+ users to a single group becomes inefficient memory-wise)
+        for chunk in range(0, len(users_in_group), 1000):
+            chunk_of_users = users_in_group[chunk:chunk + 1000]
+            add_group_members(db, instance_id, group, chunk_of_users)
+
+            added += len(chunk_of_users)
+            count += 1
+            if count % 50 == 0:
+                LOGGER.info("Added %u/%u memberships" % (added, total))
+
+def add_group_members(db, instance_id, group, users_in_group):
+    """Adds the given users to group specified."""
 
     ou = ou_name(db, instance_id)
 
     def build_dn(name):
         return("cn=%s,%s" % (name, ou))
 
-    for (user, group) in assignments:
-        user_dn  = build_dn(user_name(instance_id, user))
-        group_dn = build_dn(group_name(instance_id, group))
+    group_dn = build_dn(group_name(instance_id, group))
+    m = ldb.Message()
+    m.dn = ldb.Dn(db, group_dn)
 
-        m = ldb.Message()
-        m.dn = ldb.Dn(db, group_dn)
-        m["member"] = ldb.MessageElement(user_dn, ldb.FLAG_MOD_ADD, "member")
-        start = time.time()
-        db.modify(m)
-        end = time.time()
-        duration = end - start
-        print("%f\t0\tadd\tuser\t%f\tTrue\t" % (end, duration))
+    for user in users_in_group:
+        user_dn = build_dn(user_name(instance_id, user))
+        idx = "member-" + str(user)
+        m[idx] = ldb.MessageElement(user_dn, ldb.FLAG_MOD_ADD, "member")
+
+    db.modify(m)
 
 
 def generate_stats(statsdir, timing_file):
@@ -1939,25 +2068,16 @@ def generate_stats(statsdir, timing_file):
     else:
         failure_rate = failed / duration
 
-    # print the stats in more human-readable format when stdout is going to the
-    # console (as opposed to being redirected to a file)
-    if sys.stdout.isatty():
-        print("Total conversations:   %10d" % conversations)
-        print("Successful operations: %10d (%.3f per second)"
-              % (successful, success_rate))
-        print("Failed operations:     %10d (%.3f per second)"
-              % (failed, failure_rate))
-    else:
-        print("(%d, %d, %d, %.3f, %.3f)" %
-              (conversations, successful, failed, success_rate, failure_rate))
+    print("Total conversations:   %10d" % conversations)
+    print("Successful operations: %10d (%.3f per second)"
+          % (successful, success_rate))
+    print("Failed operations:     %10d (%.3f per second)"
+          % (failed, failure_rate))
+
+    print("Protocol    Op Code  Description                               "
+          " Count       Failed         Mean       Median          "
+          "95%        Range          Max")
 
-    if sys.stdout.isatty():
-        print("Protocol    Op Code  Description                               "
-              " Count       Failed         Mean       Median          "
-              "95%        Range          Max")
-    else:
-        print("proto\top_code\tdesc\tcount\tfailed\tmean\tmedian\t95%\trange"
-              "\tmax")
     protocols = sorted(latencies.keys())
     for protocol in protocols:
         packet_types = sorted(latencies[protocol], key=opcode_key)