traffic: rework conversation generation to better use memory
authorDouglas Bagnall <douglas.bagnall@catalyst.net.nz>
Fri, 19 Oct 2018 04:11:52 +0000 (17:11 +1300)
committerDouglas Bagnall <dbagnall@samba.org>
Tue, 8 Jan 2019 22:55:33 +0000 (23:55 +0100)
Use less memory altogether and don't allocated shared mutable before
the fork.

Signed-off-by: Douglas Bagnall <douglas.bagnall@catalyst.net.nz>
Reviewed-by: Andrew Bartlett <abartlet@samba.org>
python/samba/emulate/traffic.py
script/traffic_replay

index 3310ce768bd1ce9a87ef115c1398c0479721db69..b6501c49282477729bfb4821114b93771b092f77 100644 (file)
@@ -887,86 +887,41 @@ class Conversation(object):
             lines.append(p.as_summary(self.start_time))
         return lines
 
-    def replay_in_fork_with_delay(self, start, context=None, account=None):
-        """Fork a new process and replay the conversation.
-        """
-        def signal_handler(signal, frame):
-            """Signal handler closes standard out and error.
-
-            Triggered by a sigterm, ensures that the log messages are flushed
-            to disk and not lost.
-            """
-            sys.stderr.close()
-            sys.stdout.close()
-            os._exit(0)
-
+    def replay_with_delay(self, start, context=None, account=None):
+        """Replay the conversation at the right time.
+        (We're already in a fork)."""
+        # first we sleep until the first packet
         t = self.start_time
         now = time.time() - start
         gap = t - now
-        # we are replaying strictly in order, so it is safe to sleep
-        # in the main process if the gap is big enough. This reduces
-        # the number of concurrent threads, which allows us to make
-        # larger loads.
-        if gap > 0.15 and False:
-            print("sleeping for %f in main process" % (gap - 0.1),
-                  file=sys.stderr)
-            time.sleep(gap - 0.1)
-            now = time.time() - start
-            gap = t - now
-            print("gap is now %f" % gap, file=sys.stderr)
-
-        self.conversation_id = next(context.next_conversation_id)
-        pid = os.fork()
-        if pid != 0:
-            return pid
-        pid = os.getpid()
-        signal.signal(signal.SIGTERM, signal_handler)
-        # we must never return, or we'll end up running parts of the
-        # parent's clean-up code. So we work in a try...finally, and
-        # try to print any exceptions.
-
-        try:
-            context.generate_process_local_config(account, self)
-            sys.stdin.close()
-            os.close(0)
-            filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
-                                    self.conversation_id)
-            sys.stdout.close()
-            sys.stdout = open(filename, 'w')
-
-            sleep_time = gap - SLEEP_OVERHEAD
-            if sleep_time > 0:
-                time.sleep(sleep_time)
+        sleep_time = gap - SLEEP_OVERHEAD
+        if sleep_time > 0:
+            time.sleep(sleep_time)
 
-            miss = t - (time.time() - start)
-            self.msg("starting %s [miss %.3f pid %d]" % (self, miss, pid))
-            self.replay(context)
-        except Exception:
-            print(("EXCEPTION in child PID %d, conversation %s" % (pid, self)),
-                  file=sys.stderr)
-            traceback.print_exc(sys.stderr)
-        finally:
-            sys.stderr.close()
-            sys.stdout.close()
-            os._exit(0)
-
-    def replay(self, context=None):
-        start = time.time()
+        miss = (time.time() - start) - t
+        self.msg("starting %s [miss %.3f]" % (self, miss))
 
+        max_gap = 0.0
+        max_sleep_miss = 0.0
+        # packet times are relative to conversation start
+        p_start = time.time()
         for p in self.packets:
-            now = time.time() - start
-            gap = p.timestamp - now
-            sleep_time = gap - SLEEP_OVERHEAD
-            if sleep_time > 0:
-                time.sleep(sleep_time)
+            now = time.time() - p_start
+            gap = now - p.timestamp
+            if gap > max_gap:
+                max_gap = gap
+            if gap < 0:
+                sleep_time = -gap - SLEEP_OVERHEAD
+                if sleep_time > 0:
+                    time.sleep(sleep_time)
+                    t = time.time() - p_start
+                    if t - p.timestamp > max_sleep_miss:
+                        max_sleep_miss = t - p.timestamp
 
-            miss = p.timestamp - (time.time() - start)
-            if context is None:
-                self.msg("packet %s [miss %.3f pid %d]" % (p, miss,
-                                                           os.getpid()))
-                continue
             p.play(self, context)
 
+        return max_gap, miss, max_sleep_miss
+
     def guess_client_server(self, server_clue=None):
         """Have a go at deciding who is the server and who is the client.
         returns (client, server)
@@ -1019,12 +974,6 @@ class DnsHammer(Conversation):
         return ("<DnsHammer %d packets over %.1fs (rate %.2f)>" %
                 (len(self.times), self.duration, self.rate))
 
-    def replay_in_fork_with_delay(self, start, context=None, account=None):
-        return Conversation.replay_in_fork_with_delay(self,
-                                                      start,
-                                                      context,
-                                                      account)
-
     def replay(self, context=None):
         start = time.time()
         fn = traffic_packets.packet_dns_0
@@ -1035,15 +984,9 @@ class DnsHammer(Conversation):
             if sleep_time > 0:
                 time.sleep(sleep_time)
 
-            if context is None:
-                miss = t - (time.time() - start)
-                self.msg("packet %s [miss %.3f pid %d]" % (t, miss,
-                                                           os.getpid()))
-                continue
-
             packet_start = time.time()
             try:
-                fn(self, self, context)
+                fn(None, None, context)
                 end = time.time()
                 duration = end - packet_start
                 print("%f\tDNS\tdns\t0\t%f\tTrue\t" % (end, duration))
@@ -1458,7 +1401,98 @@ def expand_short_packet(p, timestamp, src, dest, extra):
     return '\t'.join(line)
 
 
-def replay(conversations,
+def flushing_signal_handler(signal, frame):
+    """Signal handler closes standard out and error.
+
+    Triggered by a sigterm, ensures that the log messages are flushed
+    to disk and not lost.
+    """
+    sys.stderr.close()
+    sys.stdout.close()
+    os._exit(0)
+
+
+def replay_seq_in_fork(cs, start, context, account, client_id, server_id=1):
+    """Fork a new process and replay the conversation sequence."""
+    endpoints = (server_id, client_id)
+    # flush our buffers so messages won't be written by both sides
+    sys.stdout.flush()
+    sys.stderr.flush()
+    pid = os.fork()
+    if pid != 0:
+        return pid
+
+    # we must never return, or we'll end up running parts of the
+    # parent's clean-up code. So we work in a try...finally, and
+    # try to print any exceptions.
+
+    try:
+        status = 0
+        t = cs[0][0]
+        c = Conversation(t, endpoints, seq=cs, conversation_id=client_id)
+        signal.signal(signal.SIGTERM, flushing_signal_handler)
+
+        context.generate_process_local_config(account, c)
+        sys.stdin.close()
+        os.close(0)
+        filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
+                                c.conversation_id)
+        f = open(filename, 'w')
+        try:
+            sys.stdout.close()
+            os.close(1)
+        except IOError as e:
+            LOGGER.info("stdout closing failed with %s" % e)
+            pass
+
+        sys.stdout = f
+        now = time.time() - start
+        gap = t - now
+        sleep_time = gap - SLEEP_OVERHEAD
+        if sleep_time > 0:
+            time.sleep(sleep_time)
+
+        max_lag, start_lag, max_sleep_miss = c.replay_with_delay(start=start,
+                                                                 context=context)
+        print("Maximum lag: %f" % max_lag)
+        print("Start lag: %f" % start_lag)
+        print("Max sleep miss: %f" % max_sleep_miss)
+
+    except Exception:
+        status = 1
+        print(("EXCEPTION in child PID %d, conversation %s" % (os.getpid(), c)),
+              file=sys.stderr)
+        traceback.print_exc(sys.stderr)
+        sys.stderr.flush()
+    finally:
+        sys.stderr.close()
+        sys.stdout.close()
+        os._exit(status)
+
+
+def dnshammer_in_fork(dns_rate, duration):
+    sys.stdout.flush()
+    sys.stderr.flush()
+    pid = os.fork()
+    if pid != 0:
+        return pid
+    try:
+        status = 0
+        signal.signal(signal.SIGTERM, flushing_signal_handler)
+        hammer = DnsHammer(dns_rate, duration)
+        hammer.replay()
+    except Exception:
+        status = 1
+        print(("EXCEPTION in child PID %d, the DNS hammer" % (os.getpid())),
+              file=sys.stderr)
+        traceback.print_exc(sys.stderr)
+    finally:
+        sys.stderr.close()
+        sys.stdout.close()
+        os._exit(status)
+
+
+def replay(conversation_seq,
            host=None,
            creds=None,
            lp=None,
@@ -1472,87 +1506,72 @@ def replay(conversations,
                             lp=lp,
                             **kwargs)
 
-    if len(accounts) < len(conversations):
-        print(("we have %d accounts but %d conversations" %
-               (accounts, conversations)), file=sys.stderr)
-
-    cstack = list(zip(
-        sorted(conversations, key=lambda x: x.start_time, reverse=True),
-        accounts))
+    if len(accounts) < len(conversation_seq):
+        raise ValueError(("we have %d accounts but %d conversations" %
+                          (len(accounts), len(conversation_seq))))
 
     # Set the process group so that the calling scripts are not killed
     # when the forked child processes are killed.
     os.setpgrp()
 
-    start = time.time()
+    # we delay the start by a bit to allow all the forks to get up and
+    # running.
+    delay = len(conversation_seq) * 0.02
+    start = time.time() + delay
 
     if duration is None:
-        # end 1 second after the last packet of the last conversation
+        # end slightly after the last packet of the last conversation
         # to start. Conversations other than the last could still be
         # going, but we don't care.
-        duration = cstack[0][0].packets[-1].timestamp + 1.0
-        print("We will stop after %.1f seconds" % duration,
-              file=sys.stderr)
+        duration = conversation_seq[-1][-1][0] + 1.0
 
-    end = start + duration
+    print("We will start in %.1f seconds" % delay,
+          file=sys.stderr)
+    print("We will stop after %.1f seconds" % (duration + delay),
+          file=sys.stderr)
+    print("runtime %.1f seconds" % duration,
+          file=sys.stderr)
+
+    # give one second grace for packets to finish before killing begins
+    end = start + duration + 1.0
 
     LOGGER.info("Replaying traffic for %u conversations over %d seconds"
-          % (len(conversations), duration))
+          % (len(conversation_seq), duration))
 
-    children = {}
-    if dns_rate:
-        dns_hammer = DnsHammer(dns_rate, duration)
-        cstack.append((dns_hammer, None))
 
+    children = {}
     try:
-        while True:
-            # we spawn a batch, wait for finishers, then spawn another
-            now = time.time()
-            batch_end = min(now + 2.0, end)
-            fork_time = 0.0
-            fork_n = 0
-            while cstack:
-                c, account = cstack.pop()
-                if c.start_time + start > batch_end:
-                    cstack.append((c, account))
-                    break
+        if dns_rate:
+            pid = dnshammer_in_fork(dns_rate, duration)
+            children[pid] = 1
+
+        for i, cs in enumerate(conversation_seq):
+            account = accounts[i]
+            client_id = i + 2
+            pid = replay_seq_in_fork(cs, start, context, account, client_id)
+            children[pid] = client_id
+
+        # HERE, we are past all the forks
+        t = time.time()
+        print("all forks done in %.1f seconds, waiting %.1f" %
+              (t - start + delay, t - start),
+              file=sys.stderr)
 
-                st = time.time()
-                pid = c.replay_in_fork_with_delay(start, context, account)
-                children[pid] = c
-                t = time.time()
-                elapsed = t - st
-                fork_time += elapsed
-                fork_n += 1
-                print("forked %s in pid %s (in %fs)" % (c, pid,
-                                                        elapsed),
-                      file=sys.stderr)
-
-            if fork_n:
-                print(("forked %d times in %f seconds (avg %f)" %
-                       (fork_n, fork_time, fork_time / fork_n)),
-                      file=sys.stderr)
-            elif cstack:
-                debug(2, "no forks in batch ending %f" % batch_end)
-
-            while time.time() < batch_end - 1.0:
-                time.sleep(0.01)
-                try:
-                    pid, status = os.waitpid(-1, os.WNOHANG)
-                except OSError as e:
-                    if e.errno != 10:  # no child processes
-                        raise
-                    break
-                if pid:
-                    c = children.pop(pid, None)
-                    print(("process %d finished conversation %s;"
+        while time.time() < end and children:
+            time.sleep(0.003)
+            try:
+                pid, status = os.waitpid(-1, os.WNOHANG)
+            except OSError as e:
+                if e.errno != ECHILD:  # no child processes
+                    raise
+                break
+            if pid:
+                c = children.pop(pid, None)
+                if DEBUG_LEVEL > 0:
+                    print(("process %d finished conversation %d;"
                            " %d to go" %
                            (pid, c, len(children))), file=sys.stderr)
 
-            if time.time() >= end:
-                print("time to stop", file=sys.stderr)
-                break
-
     except Exception:
         print("EXCEPTION in parent", file=sys.stderr)
         traceback.print_exc()
@@ -1576,9 +1595,14 @@ def replay(conversations,
                         raise
                 if pid != 0:
                     c = children.pop(pid, None)
-                    print(("kill -%d %d KILLED conversation %s; "
+                    if c is None:
+                        print("children is %s, no pid found" % children)
+                        sys.stderr.flush()
+                        sys.stdout.flush()
+                        os._exit(1)
+                    print(("kill -%d %d KILLED conversation; "
                            "%d to go" %
-                           (s, pid, c, len(children))),
+                           (s, pid, len(children))),
                           file=sys.stderr)
                 if time.time() >= end:
                     break
index 83b7041f63548074592aeb00970837bfe032ea88..f9ceef5878f9cbb48b673ba2680ccb79979e0811 100755 (executable)
@@ -359,7 +359,7 @@ def main():
 
         exit(0)
 
-    traffic.replay(traffic.seq_to_conversations(conversations),
+    traffic.replay(conversations,
                    host,
                    lp=lp,
                    creds=creds,