traffic_replay: Add a max-members option to cap group size
[samba.git] / script / traffic_replay
index 6f42f2d68cd5b0e535c46e23ef58b7a36c5c2ec5..0ee0f9b65752291a38ecb256acb28e2217fda71a 100755 (executable)
@@ -22,12 +22,16 @@ import os
 import optparse
 import tempfile
 import shutil
+import random
 
 sys.path.insert(0, "bin/python")
 
-from samba import gensec
+from samba import gensec, get_debug_level
 from samba.emulate import traffic
 import samba.getopt as options
+from samba.logger import get_samba_logger
+from samba.samdb import SamDB
+from samba.auth import system_session
 
 
 def print_err(*args, **kwargs):
@@ -67,6 +71,8 @@ def main():
     parser.add_option('-c', '--clean-up',
                       action="store_true",
                       help='Clean up the generated groups and user accounts')
+    parser.add_option('--random-seed', type='int', default=0,
+                      help='Use to keep randomness consistent across multiple runs')
 
     model_group = optparse.OptionGroup(parser, 'Traffic Model Options',
                                        'These options alter the traffic '
@@ -106,6 +112,8 @@ def main():
     user_gen_group.add_option('--group-memberships', type='int', default=0,
                               help='Total memberships to assign across all '
                               'test users and all groups')
+    user_gen_group.add_option('--max-members', type='int', default=None,
+                              help='Max users to add to any one group')
     parser.add_option_group(user_gen_group)
 
     sambaopts = options.SambaOptions(parser)
@@ -131,8 +139,18 @@ def main():
         parser.print_usage()
         return
 
+    lp = sambaopts.get_loadparm()
+    debuglevel = get_debug_level()
+    logger = get_samba_logger(name=__name__,
+                              verbose=debuglevel > 3,
+                              quiet=debuglevel < 1)
+
+    traffic.DEBUG_LEVEL = debuglevel
+    # pass log level down to traffic module to make sure level is controlled
+    traffic.LOGGER.setLevel(logger.getEffectiveLevel())
+
     if opts.clean_up:
-        print_err("Removing user and machine accounts")
+        logger.info("Removing user and machine accounts")
         lp    = sambaopts.get_loadparm()
         creds = credopts.get_credentials(lp)
         creds.set_gensec_features(creds.get_gensec_features() | gensec.FEATURE_SEAL)
@@ -142,20 +160,22 @@ def main():
 
     if summary:
         if not os.path.exists(summary):
-            print_err("Summary file %s doesn't exist" % summary)
+            logger.error("Summary file %s doesn't exist" % summary)
             sys.exit(1)
     # the summary-file can be ommitted for --generate-users-only and
     # --cleanup-up, but it should be specified in all other cases
     elif not opts.generate_users_only:
-        print_err("No summary-file specified to replay traffic from")
+        logger.error("No summary-file specified to replay traffic from")
         sys.exit(1)
 
     if not opts.fixed_password:
-        print_err(("Please use --fixed-password to specify a password"
-                             " for the users created as part of this test"))
+        logger.error(("Please use --fixed-password to specify a password"
+                      " for the users created as part of this test"))
         sys.exit(1)
 
-    lp = sambaopts.get_loadparm()
+    if opts.random_seed:
+        random.seed(opts.random_seed)
+
     creds = credopts.get_credentials(lp)
     creds.set_gensec_features(creds.get_gensec_features() | gensec.FEATURE_SEAL)
 
@@ -165,54 +185,58 @@ def main():
     else:
         domain = lp.get("workgroup")
         if domain == "WORKGROUP":
-            print_err(("NETBIOS domain does not appear to be "
-                       "specified, use the --workgroup option"))
+            logger.error(("NETBIOS domain does not appear to be "
+                          "specified, use the --workgroup option"))
             sys.exit(1)
 
     if not opts.realm and not lp.get('realm'):
-        print_err("Realm not specified, use the --realm option")
+        logger.error("Realm not specified, use the --realm option")
         sys.exit(1)
 
     if opts.generate_users_only and not (opts.number_of_users or
                                          opts.number_of_groups):
-        print_err(("Please specify the number of users and/or groups "
-                   "to generate."))
+        logger.error(("Please specify the number of users and/or groups "
+                      "to generate."))
         sys.exit(1)
 
     if opts.group_memberships and opts.average_groups_per_user:
-        print_err(("--group-memberships and --average-groups-per-user"
-                   " are incompatible options - use one or the other"))
+        logger.error(("--group-memberships and --average-groups-per-user"
+                      " are incompatible options - use one or the other"))
         sys.exit(1)
 
     if not opts.number_of_groups and opts.average_groups_per_user:
-        print_err(("--average-groups-per-user requires "
-                   "--number-of-groups"))
+        logger.error(("--average-groups-per-user requires "
+                      "--number-of-groups"))
         sys.exit(1)
 
+    if opts.number_of_groups and opts.average_groups_per_user:
+        if opts.number_of_groups < opts.average_groups_per_user:
+            logger.error(("--average-groups-per-user can not be more than "
+                          "--number-of-groups"))
+            sys.exit(1)
+
     if not opts.number_of_groups and opts.group_memberships:
-        print_err("--group-memberships requires --number-of-groups")
+        logger.error("--group-memberships requires --number-of-groups")
         sys.exit(1)
 
     if opts.timing_data not in ('-', None):
         try:
             open(opts.timing_data, 'w').close()
-        except IOError as e:
-            print_err(("the supplied timing data destination "
-                       "(%s) is not writable" % opts.timing_data))
-            print_err(e)
+        except IOError:
+            # exception info will be added to log automatically
+            logger.exception(("the supplied timing data destination "
+                              "(%s) is not writable" % opts.timing_data))
             sys.exit()
 
     if opts.traffic_summary not in ('-', None):
         try:
             open(opts.traffic_summary, 'w').close()
-        except IOError as e:
-            print_err(("the supplied traffic summary destination "
-                       "(%s) is not writable" % opts.traffic_summary))
-            print_err(e)
+        except IOError:
+            # exception info will be added to log automatically
+            logger.exception(("the supplied traffic summary destination "
+                              "(%s) is not writable" % opts.traffic_summary))
             sys.exit()
 
-    traffic.DEBUG_LEVEL = opts.debuglevel
-
     duration = opts.duration
     if duration is None:
         duration = 60.0
@@ -223,8 +247,8 @@ def main():
             conversations, interval, duration, dns_counts = \
                                             traffic.ingest_summaries([summary])
 
-            print_err(("Using conversations from the traffic summary "
-                       "file specified"))
+            logger.info(("Using conversations from the traffic summary "
+                         "file specified"))
 
             # honour the specified duration if it's different to the
             # capture duration
@@ -232,7 +256,7 @@ def main():
                 duration = opts.duration
 
         except ValueError as e:
-            if not e.message.startswith('need more than'):
+            if not str(e).startswith('need more than'):
                 raise
 
             model = traffic.TrafficModel()
@@ -240,15 +264,14 @@ def main():
             try:
                 model.load(summary)
             except ValueError:
-                print_err(("Could not parse %s. The summary file "
-                           "should be the output from either the "
-                           "traffic_summary.pl or "
-                           "traffic_learner scripts."
-                           % summary))
+                logger.error(("Could not parse %s. The summary file "
+                              "should be the output from either the "
+                              "traffic_summary.pl or "
+                              "traffic_learner scripts.") % summary)
                 sys.exit()
 
-            print_err(("Using the specified model file to "
-                       "generate conversations"))
+            logger.info(("Using the specified model file to "
+                         "generate conversations"))
 
             conversations = model.generate_conversations(opts.scale_traffic,
                                                          duration,
@@ -257,17 +280,17 @@ def main():
     else:
         conversations = []
 
-    if opts.debuglevel > 5:
+    if debuglevel > 5:
         for c in conversations:
             for p in c.packets:
-                print("    ", p)
+                print("    ", p, file=sys.stderr)
 
-        print('=' * 72)
+        print('=' * 72, file=sys.stderr)
 
     if opts.number_of_users and opts.number_of_users < len(conversations):
-        print_err(("--number-of-users (%d) is less than the "
-                   "number of conversations to replay (%d)"
-                   % (opts.number_of_users, len(conversations))))
+        logger.error(("--number-of-users (%d) is less than the "
+                      "number of conversations to replay (%d)"
+                     % (opts.number_of_users, len(conversations))))
         sys.exit(1)
 
     number_of_users = max(opts.number_of_users, len(conversations))
@@ -275,43 +298,60 @@ def main():
 
     if not opts.group_memberships and opts.average_groups_per_user:
         opts.group_memberships = opts.average_groups_per_user * number_of_users
-        print_err(("Using %d group-memberships based on %u average "
-                   "memberships for %d users"
-                   % (opts.group_memberships,
-                      opts.average_groups_per_user, number_of_users)))
+        logger.info(("Using %d group-memberships based on %u average "
+                     "memberships for %d users"
+                     % (opts.group_memberships,
+                        opts.average_groups_per_user, number_of_users)))
 
     if opts.group_memberships > max_memberships:
-        print_err(("The group memberships specified (%d) exceeds "
-                   "the total users (%d) * total groups (%d)"
-                   % (opts.group_memberships, number_of_users,
-                      opts.number_of_groups)))
+        logger.error(("The group memberships specified (%d) exceeds "
+                      "the total users (%d) * total groups (%d)"
+                      % (opts.group_memberships, number_of_users,
+                         opts.number_of_groups)))
         sys.exit(1)
 
+    # Get an LDB connection.
     try:
-        ldb = traffic.openLdb(host, creds, lp)
+        # if we're only adding users, then it's OK to pass a sam.ldb filepath
+        # as the host, which creates the users much faster. In all other cases
+        # we should be connecting to a remote DC
+        if opts.generate_users_only and os.path.isfile(host):
+            ldb = SamDB(url="ldb://{0}".format(host),
+                        session_info=system_session(), lp=lp)
+        else:
+            ldb = traffic.openLdb(host, creds, lp)
     except:
-        print_err(("\nInitial LDAP connection failed! Did you supply "
-                   "a DNS host name and the correct credentials?"))
+        logger.error(("\nInitial LDAP connection failed! Did you supply "
+                      "a DNS host name and the correct credentials?"))
         sys.exit(1)
 
     if opts.generate_users_only:
+        # generate computer accounts for added realism. Assume there will be
+        # some overhang with more computer accounts than users
+        computer_accounts = int(1.25 * number_of_users)
         traffic.generate_users_and_groups(ldb,
                                           opts.instance_id,
                                           opts.fixed_password,
                                           opts.number_of_users,
                                           opts.number_of_groups,
-                                          opts.group_memberships)
+                                          opts.group_memberships,
+                                          opts.max_members,
+                                          machine_accounts=computer_accounts,
+                                          traffic_accounts=False)
         sys.exit()
 
     tempdir = tempfile.mkdtemp(prefix="samba_tg_")
-    print_err("Using temp dir %s" % tempdir)
+    logger.info("Using temp dir %s" % tempdir)
 
     traffic.generate_users_and_groups(ldb,
                                       opts.instance_id,
                                       opts.fixed_password,
                                       number_of_users,
                                       opts.number_of_groups,
-                                      opts.group_memberships)
+                                      opts.group_memberships,
+                                      opts.max_members,
+                                      machine_accounts=len(conversations),
+                                      traffic_accounts=True)
 
     accounts = traffic.generate_replay_accounts(ldb,
                                                 opts.instance_id,
@@ -326,7 +366,7 @@ def main():
         else:
             summary_dest = open(opts.traffic_summary, 'w')
 
-        print_err("Writing traffic summary")
+        logger.info("Writing traffic summary")
         summaries = []
         for c in conversations:
             summaries += c.replay_as_summary_lines()
@@ -359,11 +399,11 @@ def main():
     else:
         timing_dest = open(opts.timing_data, 'w')
 
-    print_err("Generating statistics")
+    logger.info("Generating statistics")
     traffic.generate_stats(statsdir, timing_dest)
 
     if not opts.preserve_tempdir:
-        print_err("Removing temporary directory")
+        logger.info("Removing temporary directory")
         shutil.rmtree(tempdir)