traffic: load dns query from file and write stats to file
authorJoe Guo <joeg@catalyst.net.nz>
Tue, 26 Mar 2019 04:48:39 +0000 (17:48 +1300)
committerAndrew Bartlett <abartlet@samba.org>
Wed, 1 May 2019 01:10:42 +0000 (01:10 +0000)
Signed-off-by: Joe Guo <joeg@catalyst.net.nz>
Reviewed-by: Andrew Bartlett <abartlet@samba.org>
Reviewed-by: Andreas Schneider <asn@samba.org>
Autobuild-User(master): Andrew Bartlett <abartlet@samba.org>
Autobuild-Date(master): Wed May  1 01:10:42 UTC 2019 on sn-devel-184

python/samba/emulate/traffic.py
script/traffic_replay

index dc13241d5bafd0b1bf21ba5c11026d72c54dc2ef..7f720c98671bb4d48182bd8ea2d113db07ccf007 100644 (file)
@@ -28,6 +28,8 @@ import signal
 from errno import ECHILD, ESRCH
 
 from collections import OrderedDict, Counter, defaultdict, namedtuple
+from dns.resolver import query as dns_query
+
 from samba.emulate import traffic_packets
 from samba.samdb import SamDB
 import ldb
@@ -967,22 +969,56 @@ class DnsHammer(Conversation):
     """A lightweight conversation that generates a lot of dns:0 packets on
     the fly"""
 
-    def __init__(self, dns_rate, duration):
+    def __init__(self, dns_rate, duration, query_file=None):
         n = int(dns_rate * duration)
         self.times = [random.uniform(0, duration) for i in range(n)]
         self.times.sort()
         self.rate = dns_rate
         self.duration = duration
         self.start_time = 0
-        self.msg = random_colour_print()
+        self.query_choices = self._get_query_choices(query_file=query_file)
 
     def __str__(self):
         return ("<DnsHammer %d packets over %.1fs (rate %.2f)>" %
                 (len(self.times), self.duration, self.rate))
 
+    def _get_query_choices(self, query_file=None):
+        """
+        Read dns query choices from a file, or return default
+
+        rname may contain format string like `{realm}`
+        realm can be fetched from context.realm
+        """
+
+        if query_file:
+            with open(query_file, 'r') as f:
+                text = f.read()
+            choices = []
+            for line in text.splitlines():
+                line = line.strip()
+                if line and not line.startswith('#'):
+                    args = line.split(',')
+                    assert len(args) == 4
+                    choices.append(args)
+            return choices
+        else:
+            return [
+                (0, '{realm}', 'A', 'yes'),
+                (1, '{realm}', 'NS', 'yes'),
+                (2, '*.{realm}', 'A', 'no'),
+                (3, '*.{realm}', 'NS', 'no'),
+                (10, '_msdcs.{realm}', 'A', 'yes'),
+                (11, '_msdcs.{realm}', 'NS', 'yes'),
+                (20, 'nx.realm.com', 'A', 'no'),
+                (21, 'nx.realm.com', 'NS', 'no'),
+                (22, '*.nx.realm.com', 'A', 'no'),
+                (23, '*.nx.realm.com', 'NS', 'no'),
+            ]
+
     def replay(self, context=None):
+        assert context
+        assert context.realm
         start = time.time()
-        fn = traffic_packets.packet_dns_0
         for t in self.times:
             now = time.time() - start
             gap = t - now
@@ -990,16 +1026,21 @@ class DnsHammer(Conversation):
             if sleep_time > 0:
                 time.sleep(sleep_time)
 
+            opcode, rname, rtype, exist = random.choice(self.query_choices)
+            rname = rname.format(realm=context.realm)
+            success = True
             packet_start = time.time()
             try:
-                fn(None, None, context)
+                answers = dns_query(rname, rtype)
+                if exist == 'yes' and not len(answers):
+                    # expect answers but didn't get, fail
+                    success = False
+            except Exception:
+                success = False
+            finally:
                 end = time.time()
                 duration = end - packet_start
-                print("%f\tDNS\tdns\t0\t%f\tTrue\t" % (end, duration))
-            except Exception as e:
-                end = time.time()
-                duration = end - packet_start
-                print("%f\tDNS\tdns\t0\t%f\tFalse\t%s" % (end, duration, e))
+                print("%f\tDNS\tdns\t%s\t%f\t%s\t" % (end, opcode, duration, success))
 
 
 def ingest_summaries(files, dns_mode='count'):
@@ -1516,17 +1557,30 @@ def replay_seq_in_fork(cs, start, context, account, client_id, server_id=1):
         os._exit(status)
 
 
-def dnshammer_in_fork(dns_rate, duration):
+def dnshammer_in_fork(dns_rate, duration, context, query_file=None):
     sys.stdout.flush()
     sys.stderr.flush()
     pid = os.fork()
     if pid != 0:
         return pid
+
+    sys.stdin.close()
+    os.close(0)
+
+    try:
+        sys.stdout.close()
+        os.close(1)
+    except IOError as e:
+        LOGGER.warn("stdout closing failed with %s" % e)
+        pass
+    filename = os.path.join(context.statsdir, 'stats-dns')
+    sys.stdout = open(filename, 'w')
+
     try:
         status = 0
         signal.signal(signal.SIGTERM, flushing_signal_handler)
-        hammer = DnsHammer(dns_rate, duration)
-        hammer.replay()
+        hammer = DnsHammer(dns_rate, duration, query_file=query_file)
+        hammer.replay(context=context)
     except Exception:
         status = 1
         print(("EXCEPTION in child PID %d, the DNS hammer" % (os.getpid())),
@@ -1544,6 +1598,7 @@ def replay(conversation_seq,
            lp=None,
            accounts=None,
            dns_rate=0,
+           dns_query_file=None,
            duration=None,
            latency_timeout=1.0,
            stop_on_any_error=False,
@@ -1593,7 +1648,8 @@ def replay(conversation_seq,
     children = {}
     try:
         if dns_rate:
-            pid = dnshammer_in_fork(dns_rate, duration)
+            pid = dnshammer_in_fork(dns_rate, duration, context,
+                                    query_file=dns_query_file)
             children[pid] = 1
 
         for i, cs in enumerate(conversation_seq):
index 77eef7c0322bda95060985665c7d06dd1ce667c4..0d74c876d127b8823b8e10bb9483b87bf58b1653 100755 (executable)
@@ -50,6 +50,8 @@ def main():
 
     parser.add_option('--dns-rate', type='float', default=0,
                       help='fire extra DNS packets at this rate')
+    parser.add_option('--dns-query-file', dest="dns_query_file",
+                      help='A file contains DNS query list')
     parser.add_option('-B', '--badpassword-frequency',
                       type='float', default=0.0,
                       help='frequency of connections with bad passwords')
@@ -403,6 +405,7 @@ def main():
                    creds=creds,
                    accounts=accounts,
                    dns_rate=opts.dns_rate,
+                   dns_query_file=opts.dns_query_file,
                    duration=opts.duration,
                    latency_timeout=opts.latency_timeout,
                    badpassword_frequency=opts.badpassword_frequency,