traffic_replay: Add a max-members option to cap group size
[samba.git] / python / samba / emulate / traffic.py
1 # -*- encoding: utf-8 -*-
2 # Samba traffic replay and learning
3 #
4 # Copyright (C) Catalyst IT Ltd. 2017
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 #
19 from __future__ import print_function, division
20
21 import time
22 import os
23 import random
24 import json
25 import math
26 import sys
27 import signal
28 import itertools
29
30 from collections import OrderedDict, Counter, defaultdict
31 from samba.emulate import traffic_packets
32 from samba.samdb import SamDB
33 import ldb
34 from ldb import LdbError
35 from samba.dcerpc import ClientConnection
36 from samba.dcerpc import security, drsuapi, lsa
37 from samba.dcerpc import netlogon
38 from samba.dcerpc.netlogon import netr_Authenticator
39 from samba.dcerpc import srvsvc
40 from samba.dcerpc import samr
41 from samba.drs_utils import drs_DsBind
42 import traceback
43 from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
44 from samba.auth import system_session
45 from samba.dsdb import (
46     UF_NORMAL_ACCOUNT,
47     UF_SERVER_TRUST_ACCOUNT,
48     UF_TRUSTED_FOR_DELEGATION,
49     UF_WORKSTATION_TRUST_ACCOUNT
50 )
51 from samba.dcerpc.misc import SEC_CHAN_BDC
52 from samba import gensec
53 from samba import sd_utils
54 from samba.compat import get_string
55 from samba.logger import get_samba_logger
56 import bisect
57
# Fixed cost (seconds) assumed for a sleep() call itself; presumably
# subtracted when timing waits during replay -- TODO confirm at callers.
SLEEP_OVERHEAD = 3e-4

# we don't use None, because it complicates [de]serialisation
NON_PACKET = '-'

# (protocol, opcode) pairs that suggest the sending endpoint is the
# client, with the weight each contributes to Packet.client_score().
CLIENT_CLUES = {
    ('dns', '0'): 1.0,      # query
    ('smb', '0x72'): 1.0,   # Negotiate protocol
    ('ldap', '0'): 1.0,     # bind
    ('ldap', '3'): 1.0,     # searchRequest
    ('ldap', '2'): 1.0,     # unbindRequest
    ('cldap', '3'): 1.0,
    ('dcerpc', '11'): 1.0,  # bind
    ('dcerpc', '14'): 1.0,  # Alter_context
    ('nbns', '0'): 1.0,     # query
}

# (protocol, opcode) pairs that suggest the sending endpoint is the
# server; these weights are subtracted from the client score.
SERVER_CLUES = {
    ('dns', '1'): 1.0,      # response
    ('ldap', '1'): 1.0,     # bind response
    ('ldap', '4'): 1.0,     # search result
    ('ldap', '5'): 1.0,     # search done
    ('cldap', '5'): 1.0,
    ('dcerpc', '12'): 1.0,  # bind_ack
    ('dcerpc', '13'): 1.0,  # bind_nak
    ('dcerpc', '15'): 1.0,  # Alter_context response
}

# Protocols that replay ignores entirely (see Packet.is_really_a_packet).
SKIPPED_PROTOCOLS = {"smb", "smb2", "browser", "smb_netlogon"}

# NOTE(review): these wait/scale tunables are consumed by code outside
# this view -- semantics unverified here.
WAIT_SCALE = 10.0
WAIT_THRESHOLD = (1.0 / WAIT_SCALE)
NO_WAIT_LOG_TIME_RANGE = (-10, -3)

# DEBUG_LEVEL can be changed by scripts with -d
DEBUG_LEVEL = 0

LOGGER = get_samba_logger(name=__name__)
96
97
def debug(level, msg, *args):
    """Write a formatted debug message to standard error.

    The message is emitted only when *level* is at or below the module's
    DEBUG_LEVEL (which scripts can raise with the -d option).

    :param level: importance threshold for this message
    :param msg:   message text, may contain C-style format specifiers
    :param args:  values consumed by the format specifiers
    """
    if level > DEBUG_LEVEL:
        return
    text = msg if not args else msg % tuple(args)
    print(text, file=sys.stderr)
114
115
def debug_lineno(*args):
    """Print an unformatted log message to stderr, with the caller's
    function name and line number prepended (line number highlighted).
    """
    # limit=2 must be evaluated directly in this frame so that entry [0]
    # is the caller of debug_lineno -- do not move into a helper.
    caller = traceback.extract_stack(limit=2)[0]
    print((" %s:" "\033[01;33m"
           "%s " "\033[00m" % (caller[2], caller[1])), end=' ',
          file=sys.stderr)
    for item in args:
        print(item, file=sys.stderr)
    print(file=sys.stderr)
    sys.stderr.flush()
127
128
def random_colour_print():
    """Return a function that prints its arguments to stderr, each line
    wrapped in a randomly chosen 256-colour ANSI escape."""
    # Pick a colour index from 18..231, avoiding the darkest entries.
    colour = 18 + random.randrange(214)
    prefix = "\033[38;5;%dm" % colour

    def emit(*args):
        for item in args:
            print("%s%s\033[00m" % (prefix, item), file=sys.stderr)

    return emit
139
140
class FakePacketError(Exception):
    """Raised when a synthesised packet cannot be used, e.g. when its
    endpoints do not match the conversation it is being added to."""
    pass
143
144
class Packet(object):
    """Details of a network packet."""

    def __init__(self, timestamp, ip_protocol, stream_number, src, dest,
                 protocol, opcode, desc, extra):
        """Record the fields of one captured/synthesised packet.

        :param timestamp: capture time (seconds, float)
        :param ip_protocol: IP protocol number as a string (e.g. '06')
        :param stream_number: TCP/UDP stream identifier ('' if none)
        :param src: integer id of the sending endpoint
        :param dest: integer id of the receiving endpoint
        :param protocol: application protocol name (e.g. 'ldap')
        :param opcode: protocol-specific operation code (string)
        :param desc: human readable description of the operation
        :param extra: list of extra protocol-specific fields
        """
        self.timestamp = timestamp
        self.ip_protocol = ip_protocol
        self.stream_number = stream_number
        self.src = src
        self.dest = dest
        self.protocol = protocol
        self.opcode = opcode
        self.desc = desc
        self.extra = extra
        # endpoints is the sorted (low, high) pair, identifying the
        # conversation regardless of which direction the packet flows.
        if self.src < self.dest:
            self.endpoints = (self.src, self.dest)
        else:
            self.endpoints = (self.dest, self.src)

    @classmethod
    def from_line(cls, line):
        """Parse one tab-separated traffic-summary line into a Packet.

        The first 8 fields are fixed; any further fields become the
        packet's ``extra`` list.

        :param line: a line in the format produced by as_summary()
        :raises ValueError: if the line has fewer than 8 fields or the
                            numeric fields do not parse
        """
        fields = line.rstrip('\n').split('\t')
        (timestamp,
         ip_protocol,
         stream_number,
         src,
         dest,
         protocol,
         opcode,
         desc) = fields[:8]
        extra = fields[8:]

        timestamp = float(timestamp)
        src = int(src)
        dest = int(dest)

        # Use cls (not a hard-coded Packet) so subclasses parse into
        # themselves; the first parameter was previously misnamed 'self'.
        return cls(timestamp, ip_protocol, stream_number, src, dest,
                   protocol, opcode, desc, extra)

    def as_summary(self, time_offset=0.0):
        """Format the packet as a traffic_summary line.

        :param time_offset: seconds to add to the stored timestamp
        :return: a (timestamp, line) tuple; the line round-trips through
                 from_line()
        """
        extra = '\t'.join(self.extra)
        t = self.timestamp + time_offset
        return (t, '%f\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' %
                (t,
                 self.ip_protocol,
                 self.stream_number or '',
                 self.src,
                 self.dest,
                 self.protocol,
                 self.opcode,
                 self.desc,
                 extra))

    def __str__(self):
        return ("%.3f: %d -> %d; ip %s; strm %s; prot %s; op %s; desc %s %s" %
                (self.timestamp, self.src, self.dest, self.ip_protocol or '-',
                 self.stream_number, self.protocol, self.opcode, self.desc,
                 ('«' + ' '.join(self.extra) + '»' if self.extra else '')))

    def __repr__(self):
        return "<Packet @%s>" % self

    def copy(self):
        """Return a new packet with identical field values.

        Note this is a shallow copy: ``extra`` is shared with the
        original (matching the previous behaviour).
        """
        return self.__class__(self.timestamp,
                              self.ip_protocol,
                              self.stream_number,
                              self.src,
                              self.dest,
                              self.protocol,
                              self.opcode,
                              self.desc,
                              self.extra)

    def as_packet_type(self):
        """Return the 'protocol:opcode' key used to classify packets."""
        t = '%s:%s' % (self.protocol, self.opcode)
        return t

    def client_score(self):
        """A positive number means we think it is a client; a negative number
        means we think it is a server. Zero means no idea. range: -1 to 1.
        """
        key = (self.protocol, self.opcode)
        if key in CLIENT_CLUES:
            return CLIENT_CLUES[key]
        if key in SERVER_CLUES:
            return -SERVER_CLUES[key]
        return 0.0

    def play(self, conversation, context):
        """Send the packet over the network, if required.

        Some packets are ignored, i.e. for protocols not handled,
        server response messages, or messages that are generated by the
        protocol layer associated with other packets.

        A timing line is printed to stdout for every handler that
        generates traffic or raises, in the form:
        end-time, conversation id, protocol, opcode, duration, success.
        """
        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
        try:
            fn = getattr(traffic_packets, fn_name)

        except AttributeError as e:
            print("Conversation(%s) Missing handler %s" %
                  (conversation.conversation_id, fn_name),
                  file=sys.stderr)
            return

        # Don't display a message for kerberos packets, they're not directly
        # generated they're used to indicate kerberos should be used
        if self.protocol != "kerberos":
            debug(2, "Conversation(%s) Calling handler %s" %
                     (conversation.conversation_id, fn_name))

        start = time.time()
        try:
            if fn(self, conversation, context):
                # Only collect timing data for functions that generate
                # network traffic, or fail
                end = time.time()
                duration = end - start
                print("%f\t%s\t%s\t%s\t%f\tTrue\t" %
                      (end, conversation.conversation_id, self.protocol,
                       self.opcode, duration))
        except Exception as e:
            # Broad catch is deliberate: a failing replay packet must not
            # abort the whole conversation; the failure is logged instead.
            end = time.time()
            duration = end - start
            print("%f\t%s\t%s\t%s\t%f\tFalse\t%s" %
                  (end, conversation.conversation_id, self.protocol,
                   self.opcode, duration, e))

    def __cmp__(self, other):
        # Python 2 ordering hook; kept for the py2/py3-straddling era
        # this module targets (note the __future__ imports at file top).
        return self.timestamp - other.timestamp

    def is_really_a_packet(self, missing_packet_stats=None):
        """Is the packet one that can be ignored?

        If so removing it will have no effect on the replay
        """
        if self.protocol in SKIPPED_PROTOCOLS:
            # Ignore any packets for the protocols we're not interested in.
            return False
        if self.protocol == "ldap" and self.opcode == '':
            # skip ldap continuation packets
            return False

        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
        fn = getattr(traffic_packets, fn_name, None)
        if not fn:
            print("missing packet %s" % fn_name, file=sys.stderr)
            return False
        if fn is traffic_packets.null_packet:
            return False
        return True
298
299
class ReplayContext(object):
    """State/Context for an individual conversation between a simulated
    client and a server.

    Holds the credentials (good and deliberately-bad variants), cached
    protocol connections (LDAP, DCE/RPC, LSARPC, DRSUAPI, SRVSVC, SAMR,
    NETLOGON) and the bookkeeping used to occasionally replay logons
    with bad passwords.
    """

    def __init__(self,
                 server=None,
                 lp=None,
                 creds=None,
                 badpassword_frequency=None,
                 prefer_kerberos=None,
                 tempdir=None,
                 statsdir=None,
                 ou=None,
                 base_dn=None,
                 domain=None,
                 domain_sid=None):
        """Capture configuration and pre-compute the LDAP search tables.

        Note this performs a live LDAP search against *server* (via
        generate_ldap_search_tables), so it needs working credentials.
        """
        self.server                   = server
        self.ldap_connections         = []
        self.dcerpc_connections       = []
        self.lsarpc_connections       = []
        self.lsarpc_connections_named = []
        self.drsuapi_connections      = []
        self.srvsvc_connections       = []
        self.samr_contexts            = []
        self.netlogon_connection      = None
        self.creds                    = creds
        self.lp                       = lp
        self.prefer_kerberos          = prefer_kerberos
        self.ou                       = ou
        self.base_dn                  = base_dn
        self.domain                   = domain
        self.statsdir                 = statsdir
        self.global_tempdir           = tempdir
        self.domain_sid               = domain_sid
        self.realm                    = lp.get('realm')

        # Bad password attempt controls.  The last_*_bad flags stop two
        # consecutive bad-password attempts on the same connection type
        # (see with_random_bad_credentials).
        self.badpassword_frequency    = badpassword_frequency
        self.last_lsarpc_bad          = False
        self.last_lsarpc_named_bad    = False
        self.last_simple_bind_bad     = False
        self.last_bind_bad            = False
        self.last_srvsvc_bad          = False
        self.last_drsuapi_bad         = False
        self.last_netlogon_bad        = False
        self.last_samlogon_bad        = False
        self.generate_ldap_search_tables()
        self.next_conversation_id = itertools.count()

    def generate_ldap_search_tables(self):
        """Search the whole directory and index the DNs found.

        Builds self.dn_map, mapping an RDN-type pattern such as
        'CN,CN,CN,DC,DC' to the list of DNs matching that shape, and
        self.attribute_clue_map for attributes (currently only
        invocationId) that are known to live on particular objects.
        """
        session = system_session()

        db = SamDB(url="ldap://%s" % self.server,
                   session_info=session,
                   credentials=self.creds,
                   lp=self.lp)

        res = db.search(db.domain_dn(),
                        scope=ldb.SCOPE_SUBTREE,
                        controls=["paged_results:1:1000"],
                        attrs=['dn'])

        # find a list of dns for each pattern
        # e.g. CN,CN,CN,DC,DC
        dn_map = {}
        attribute_clue_map = {
            'invocationId': []
        }

        for r in res:
            dn = str(r.dn)
            pattern = ','.join(x.lstrip()[:2] for x in dn.split(',')).upper()
            dns = dn_map.setdefault(pattern, [])
            dns.append(dn)
            if dn.startswith('CN=NTDS Settings,'):
                attribute_clue_map['invocationId'].append(dn)

        # extend the map in case we are working with a different
        # number of DC components.
        # for k, v in self.dn_map.items():
        #     print >>sys.stderr, k, len(v)

        for k in list(dn_map.keys()):
            if k[-3:] != ',DC':
                continue
            # strip all trailing ',DC' components, then re-add 1..5 of
            # them, aliasing each variant to the original DN list.
            p = k[:-3]
            while p[-3:] == ',DC':
                p = p[:-3]
            for i in range(5):
                p += ',DC'
                if p != k and p in dn_map:
                    print('dn_map collison %s %s' % (k, p),
                          file=sys.stderr)
                    continue
                dn_map[p] = dn_map[k]

        self.dn_map = dn_map
        self.attribute_clue_map = attribute_clue_map

    def generate_process_local_config(self, account, conversation):
        """Bind this context to a specific account and conversation.

        Sets up per-conversation account details, a private temp
        directory, and the machine/user credentials.  Does nothing when
        *account* is None.
        """
        if account is None:
            return
        self.netbios_name             = account.netbios_name
        self.machinepass              = account.machinepass
        self.username                 = account.username
        self.userpass                 = account.userpass

        self.tempdir = mk_masked_dir(self.global_tempdir,
                                     'conversation-%d' %
                                     conversation.conversation_id)

        self.lp.set("private dir", self.tempdir)
        self.lp.set("lock dir", self.tempdir)
        self.lp.set("state directory", self.tempdir)
        self.lp.set("tls verify peer", "no_check")

        # If the domain was not specified, check for the environment
        # variable.
        if self.domain is None:
            self.domain = os.environ["DOMAIN"]

        self.remoteAddress = "/root/ncalrpc_as_system"
        self.samlogon_dn   = ("cn=%s,%s" %
                              (self.netbios_name, self.ou))
        self.user_dn       = ("cn=%s,%s" %
                              (self.username, self.ou))

        self.generate_machine_creds()
        self.generate_user_creds()

    def with_random_bad_credentials(self, f, good, bad, failed_last_time):
        """Execute the supplied logon function, randomly choosing the
           bad credentials.

           Based on the frequency in badpassword_frequency randomly perform the
           function with the supplied bad credentials.
           If run with bad credentials, the function is re-run with the good
           credentials.
           failed_last_time is used to prevent consecutive bad credential
           attempts. So the over all bad credential frequency will be lower
           than that requested, but not significantly.

        :param f: one-argument callable performing the logon
        :param good: credentials expected to succeed
        :param bad: credentials expected to fail
        :param failed_last_time: True if the previous attempt used bad creds
        :return: (result of f(good), updated failed_last_time)
        """
        if not failed_last_time:
            if (self.badpassword_frequency and self.badpassword_frequency > 0
                and random.random() < self.badpassword_frequency):
                try:
                    f(bad)
                except Exception:
                    # Deliberate best-effort: the operation is expected
                    # to fail with bad credentials.  (Was a bare except,
                    # which would also have swallowed KeyboardInterrupt.)
                    pass
                failed_last_time = True
            else:
                failed_last_time = False

        result = f(good)
        return (result, failed_last_time)

    def generate_user_creds(self):
        """Generate the conversation specific user Credentials.

        Each Conversation has an associated user account used to simulate
        any non Administrative user traffic.

        Generates user credentials with good and bad passwords and ldap
        simple bind credentials with good and bad passwords.  The bad
        variants simply truncate the last four password characters.
        """
        self.user_creds = Credentials()
        self.user_creds.guess(self.lp)
        self.user_creds.set_workstation(self.netbios_name)
        self.user_creds.set_password(self.userpass)
        self.user_creds.set_username(self.username)
        self.user_creds.set_domain(self.domain)
        if self.prefer_kerberos:
            self.user_creds.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.user_creds.set_kerberos_state(DONT_USE_KERBEROS)

        # NOTE(review): unlike user_creds, the bad variant does not call
        # set_domain -- presumably intentional, matching prior behaviour.
        self.user_creds_bad = Credentials()
        self.user_creds_bad.guess(self.lp)
        self.user_creds_bad.set_workstation(self.netbios_name)
        self.user_creds_bad.set_password(self.userpass[:-4])
        self.user_creds_bad.set_username(self.username)
        if self.prefer_kerberos:
            self.user_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.user_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)

        # Credentials for ldap simple bind.
        self.simple_bind_creds = Credentials()
        self.simple_bind_creds.guess(self.lp)
        self.simple_bind_creds.set_workstation(self.netbios_name)
        self.simple_bind_creds.set_password(self.userpass)
        self.simple_bind_creds.set_username(self.username)
        self.simple_bind_creds.set_gensec_features(
            self.simple_bind_creds.get_gensec_features() | gensec.FEATURE_SEAL)
        if self.prefer_kerberos:
            self.simple_bind_creds.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.simple_bind_creds.set_kerberos_state(DONT_USE_KERBEROS)
        self.simple_bind_creds.set_bind_dn(self.user_dn)

        self.simple_bind_creds_bad = Credentials()
        self.simple_bind_creds_bad.guess(self.lp)
        self.simple_bind_creds_bad.set_workstation(self.netbios_name)
        self.simple_bind_creds_bad.set_password(self.userpass[:-4])
        self.simple_bind_creds_bad.set_username(self.username)
        self.simple_bind_creds_bad.set_gensec_features(
            self.simple_bind_creds_bad.get_gensec_features() |
            gensec.FEATURE_SEAL)
        if self.prefer_kerberos:
            self.simple_bind_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.simple_bind_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
        self.simple_bind_creds_bad.set_bind_dn(self.user_dn)

    def generate_machine_creds(self):
        """Generate the conversation specific machine Credentials.

        Each Conversation has an associated machine account.

        Generates machine credentials with good and bad passwords; the
        bad variant truncates the last four password characters.
        """

        self.machine_creds = Credentials()
        self.machine_creds.guess(self.lp)
        self.machine_creds.set_workstation(self.netbios_name)
        self.machine_creds.set_secure_channel_type(SEC_CHAN_BDC)
        self.machine_creds.set_password(self.machinepass)
        self.machine_creds.set_username(self.netbios_name + "$")
        self.machine_creds.set_domain(self.domain)
        if self.prefer_kerberos:
            self.machine_creds.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.machine_creds.set_kerberos_state(DONT_USE_KERBEROS)

        # NOTE(review): as with user_creds_bad, no set_domain here.
        self.machine_creds_bad = Credentials()
        self.machine_creds_bad.guess(self.lp)
        self.machine_creds_bad.set_workstation(self.netbios_name)
        self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_BDC)
        self.machine_creds_bad.set_password(self.machinepass[:-4])
        self.machine_creds_bad.set_username(self.netbios_name + "$")
        if self.prefer_kerberos:
            self.machine_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.machine_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)

    def get_matching_dn(self, pattern, attributes=None):
        """Return a random DN whose RDN-type shape matches *pattern*.

        If the pattern is an empty string, we assume ROOTDSE,
        Otherwise we try adding or removing DC suffixes, then
        shorter leading patterns until we hit one.
        e.g if there is no CN,CN,CN,CN,DC,DC
        we first try       CN,CN,CN,CN,DC
        and                CN,CN,CN,CN,DC,DC,DC
        then change to        CN,CN,CN,DC,DC
        and as last resort we use the base_dn
        """
        attr_clue = self.attribute_clue_map.get(attributes)
        if attr_clue:
            return random.choice(attr_clue)

        pattern = pattern.upper()
        while pattern:
            if pattern in self.dn_map:
                return random.choice(self.dn_map[pattern])
            # chop one off the front and try it all again.
            pattern = pattern[3:]

        return self.base_dn

    def get_dcerpc_connection(self, new=False):
        """Return a DCE/RPC (NETLOGON UUID) connection, reusing the most
        recent one unless *new* is True."""
        guid = '12345678-1234-abcd-ef00-01234567cffb'  # RPC_NETLOGON UUID
        if self.dcerpc_connections and not new:
            return self.dcerpc_connections[-1]
        c = ClientConnection("ncacn_ip_tcp:%s" % self.server,
                             (guid, 1), self.lp)
        self.dcerpc_connections.append(c)
        return c

    def get_srvsvc_connection(self, new=False):
        """Return a srvsvc connection (user creds, possibly preceded by
        a random bad-password attempt), cached unless *new* is True."""
        if self.srvsvc_connections and not new:
            return self.srvsvc_connections[-1]

        def connect(creds):
            return srvsvc.srvsvc("ncacn_np:%s" % (self.server),
                                 self.lp,
                                 creds)

        (c, self.last_srvsvc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_srvsvc_bad)

        self.srvsvc_connections.append(c)
        return c

    def get_lsarpc_connection(self, new=False):
        """Return an lsarpc connection over schannel (machine creds),
        cached unless *new* is True."""
        if self.lsarpc_connections and not new:
            return self.lsarpc_connections[-1]

        def connect(creds):
            binding_options = 'schannel,seal,sign'
            return lsa.lsarpc("ncacn_ip_tcp:%s[%s]" %
                              (self.server, binding_options),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_bad)

        self.lsarpc_connections.append(c)
        return c

    def get_lsarpc_named_pipe_connection(self, new=False):
        """Return an lsarpc connection over the named pipe (machine
        creds), cached unless *new* is True."""
        if self.lsarpc_connections_named and not new:
            return self.lsarpc_connections_named[-1]

        def connect(creds):
            return lsa.lsarpc("ncacn_np:%s" % (self.server),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_named_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_named_bad)

        self.lsarpc_connections_named.append(c)
        return c

    def get_drsuapi_connection_pair(self, new=False, unbind=False):
        """get a (drs, drs_handle) tuple

        The pair is cached unless *new* is True.
        NOTE(review): the *unbind* parameter is currently unused; kept
        for interface compatibility.
        """
        if self.drsuapi_connections and not new:
            c = self.drsuapi_connections[-1]
            return c

        def connect(creds):
            binding_options = 'seal'
            binding_string = "ncacn_ip_tcp:%s[%s]" %\
                             (self.server, binding_options)
            return drsuapi.drsuapi(binding_string, self.lp, creds)

        (drs, self.last_drsuapi_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_drsuapi_bad)

        (drs_handle, supported_extensions) = drs_DsBind(drs)
        c = (drs, drs_handle)
        self.drsuapi_connections.append(c)
        return c

    def get_ldap_connection(self, new=False, simple=False):
        """Return a SamDB LDAP connection, simple bind (ldaps) or SASL
        bind (ldap), cached unless *new* is True."""
        if self.ldap_connections and not new:
            return self.ldap_connections[-1]

        def simple_bind(creds):
            """
            To run simple bind against Windows, we need to run
            following commands in PowerShell:

                Install-windowsfeature ADCS-Cert-Authority
                Install-AdcsCertificationAuthority -CAType EnterpriseRootCA
                Restart-Computer

            """
            return SamDB('ldaps://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)

        def sasl_bind(creds):
            return SamDB('ldap://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)
        if simple:
            (samdb, self.last_simple_bind_bad) = \
                self.with_random_bad_credentials(simple_bind,
                                                 self.simple_bind_creds,
                                                 self.simple_bind_creds_bad,
                                                 self.last_simple_bind_bad)
        else:
            (samdb, self.last_bind_bad) = \
                self.with_random_bad_credentials(sasl_bind,
                                                 self.user_creds,
                                                 self.user_creds_bad,
                                                 self.last_bind_bad)

        self.ldap_connections.append(samdb)
        return samdb

    def get_samr_context(self, new=False):
        """Return the current SamrContext, creating one if needed or if
        *new* is True."""
        if not self.samr_contexts or new:
            self.samr_contexts.append(
                SamrContext(self.server, lp=self.lp, creds=self.creds))
        return self.samr_contexts[-1]

    def get_netlogon_connection(self):
        """Return the (single, cached) schannel NETLOGON connection."""

        if self.netlogon_connection:
            return self.netlogon_connection

        def connect(creds):
            return netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" %
                                     (self.server),
                                     self.lp,
                                     creds)
        (c, self.last_netlogon_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_netlogon_bad)
        self.netlogon_connection = c
        return c

    def guess_a_dns_lookup(self):
        """Return a plausible (name, record type) pair for a DNS query:
        an A lookup of the realm."""
        return (self.realm, 'A')

    def get_authenticator(self):
        """Return a (current, subsequent) netr_Authenticator pair built
        from the machine credentials."""
        auth = self.machine_creds.new_client_authenticator()
        current  = netr_Authenticator()
        # credential bytes may arrive as ints (py3) or chars (py2)
        current.cred.data = [x if isinstance(x, int) else ord(x) for x in auth["credential"]]
        current.timestamp = auth["timestamp"]

        subsequent = netr_Authenticator()
        return (current, subsequent)
732
733
class SamrContext(object):
    """State/Context associated with a samr connection."""

    def __init__(self, server, lp=None, creds=None):
        """Remember the target server and credentials; everything else
        starts empty and is populated lazily."""
        self.server = server
        self.lp = lp
        self.creds = creds
        # Connection, handles and cached RIDs are created on demand.
        self.connection = None
        self.handle = None
        self.domain_handle = None
        self.domain_sid = None
        self.group_handle = None
        self.user_handle = None
        self.rids = None

    def get_connection(self):
        """Return the samr connection, establishing it on first use."""
        if not self.connection:
            binding = "ncacn_ip_tcp:%s[seal]" % (self.server)
            self.connection = samr.samr(binding,
                                        lp_ctx=self.lp,
                                        credentials=self.creds)
        return self.connection

    def get_handle(self):
        """Return the samr connect handle, creating it on first use."""
        if not self.handle:
            conn = self.get_connection()
            self.handle = conn.Connect2(None,
                                        security.SEC_FLAG_MAXIMUM_ALLOWED)
        return self.handle
763
764
class Conversation(object):
    """Details of a conversation between a simulated client and a server."""
    # Assigned (from context.next_conversation_id) just before forking
    # in replay_in_fork_with_delay().
    conversation_id = None

    def __init__(self, start_time=None, endpoints=None):
        self.start_time = start_time
        self.endpoints = endpoints
        self.packets = []
        self.msg = random_colour_print()
        # A negative balance suggests endpoints[0] is the client; see
        # guess_client_server().
        self.client_balance = 0.0

    def __cmp__(self, other):
        # Python 2 ordering (by start time, with None sorting first).
        # Python 3 ignores __cmp__ and uses __lt__ below.
        if self.start_time is None:
            if other.start_time is None:
                return 0
            return -1
        if other.start_time is None:
            return 1
        return self.start_time - other.start_time

    def __lt__(self, other):
        """Order conversations by start time, None sorting first.

        Python 3 ignores __cmp__, so without this, sorting
        conversations (as generate_conversations() and replay() do)
        would raise TypeError.
        """
        if self.start_time is None:
            return other.start_time is not None
        if other.start_time is None:
            return False
        return self.start_time < other.start_time

    def add_packet(self, packet):
        """Add a packet object to this conversation, making a local copy with
        a conversation-relative timestamp."""
        p = packet.copy()

        if self.start_time is None:
            self.start_time = p.timestamp

        if self.endpoints is None:
            self.endpoints = p.endpoints

        if p.endpoints != self.endpoints:
            raise FakePacketError("Conversation endpoints %s don't match "
                                  "packet endpoints %s" %
                                  (self.endpoints, p.endpoints))

        p.timestamp -= self.start_time

        if p.src == p.endpoints[0]:
            self.client_balance -= p.client_score()
        else:
            self.client_balance += p.client_score()

        if p.is_really_a_packet():
            self.packets.append(p)

    def add_short_packet(self, timestamp, protocol, opcode, extra,
                         client=True):
        """Create a packet from a timestamp, and 'protocol:opcode' pair, and a
        (possibly empty) list of extra data. If client is True, assume
        this packet is from the client to the server.
        """
        src, dest = self.guess_client_server()
        if not client:
            src, dest = dest, src
        desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
        # default to TCP ('06') for protocols we don't know about
        ip_protocol = IP_PROTOCOLS.get(protocol, '06')
        packet = Packet(timestamp - self.start_time, ip_protocol,
                        '', src, dest,
                        protocol, opcode, desc, extra)
        # XXX we're assuming the timestamp is already adjusted for
        # this conversation?
        # XXX should we adjust client balance for guessed packets?
        if packet.src == packet.endpoints[0]:
            self.client_balance -= packet.client_score()
        else:
            self.client_balance += packet.client_score()
        if packet.is_really_a_packet():
            self.packets.append(packet)

    def __str__(self):
        return ("<Conversation %s %s starting %.3f %d packets>" %
                (self.conversation_id, self.endpoints, self.start_time,
                 len(self.packets)))

    __repr__ = __str__

    def __iter__(self):
        return iter(self.packets)

    def __len__(self):
        return len(self.packets)

    def get_duration(self):
        """Return the time between the first and last packets, or 0 for
        conversations with fewer than two packets."""
        if len(self.packets) < 2:
            return 0
        return self.packets[-1].timestamp - self.packets[0].timestamp

    def replay_as_summary_lines(self):
        """Return a list of summary-format lines, one per packet."""
        lines = []
        for p in self.packets:
            lines.append(p.as_summary(self.start_time))
        return lines

    def replay_in_fork_with_delay(self, start, context=None, account=None):
        """Fork a new process and replay the conversation.
        """
        def signal_handler(signal, frame):
            """Signal handler closes standard out and error.

            Triggered by a sigterm, ensures that the log messages are flushed
            to disk and not lost.
            """
            sys.stderr.close()
            sys.stdout.close()
            os._exit(0)

        t = self.start_time
        now = time.time() - start
        gap = t - now
        # we are replaying strictly in order, so it is safe to sleep
        # in the main process if the gap is big enough. This reduces
        # the number of concurrent threads, which allows us to make
        # larger loads.
        # NOTE: deliberately disabled with `and False`; kept for
        # experimentation.
        if gap > 0.15 and False:
            print("sleeping for %f in main process" % (gap - 0.1),
                  file=sys.stderr)
            time.sleep(gap - 0.1)
            now = time.time() - start
            gap = t - now
            print("gap is now %f" % gap, file=sys.stderr)

        self.conversation_id = next(context.next_conversation_id)
        pid = os.fork()
        if pid != 0:
            # parent: return the child pid for book-keeping
            return pid
        pid = os.getpid()
        signal.signal(signal.SIGTERM, signal_handler)
        # we must never return, or we'll end up running parts of the
        # parent's clean-up code. So we work in a try...finally, and
        # try to print any exceptions.

        try:
            context.generate_process_local_config(account, self)
            sys.stdin.close()
            os.close(0)
            filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
                                    self.conversation_id)
            # redirect stdout to the per-conversation stats file
            sys.stdout.close()
            sys.stdout = open(filename, 'w')

            sleep_time = gap - SLEEP_OVERHEAD
            if sleep_time > 0:
                time.sleep(sleep_time)

            miss = t - (time.time() - start)
            self.msg("starting %s [miss %.3f pid %d]" % (self, miss, pid))
            self.replay(context)
        except Exception:
            print(("EXCEPTION in child PID %d, conversation %s" % (pid, self)),
                  file=sys.stderr)
            # print_exc's first positional argument is `limit`, so the
            # stream must be passed as a keyword
            traceback.print_exc(file=sys.stderr)
        finally:
            sys.stderr.close()
            sys.stdout.close()
            os._exit(0)

    def replay(self, context=None):
        """Play back the conversation's packets in order, pausing to
        keep to their recorded relative timing. With no context, just
        log what would be sent."""
        start = time.time()

        for p in self.packets:
            now = time.time() - start
            gap = p.timestamp - now
            sleep_time = gap - SLEEP_OVERHEAD
            if sleep_time > 0:
                time.sleep(sleep_time)

            miss = p.timestamp - (time.time() - start)
            if context is None:
                self.msg("packet %s [miss %.3f pid %d]" % (p, miss,
                                                           os.getpid()))
                continue
            p.play(self, context)

    def guess_client_server(self, server_clue=None):
        """Have a go at deciding who is the server and who is the client.
        returns (client, server)
        """
        a, b = self.endpoints

        if self.client_balance < 0:
            return (a, b)

        # in the absense of a clue, we will fall through to assuming
        # the lowest number is the server (which is usually true).

        if self.client_balance == 0 and server_clue == b:
            return (a, b)

        return (b, a)

    def forget_packets_outside_window(self, s, e):
        """Prune any packets outside the time window we're interested in

        :param s: start of the window
        :param e: end of the window
        """
        self.packets = [p for p in self.packets if s <= p.timestamp <= e]
        self.start_time = self.packets[0].timestamp if self.packets else None

    def renormalise_times(self, start_time):
        """Adjust the packet start times relative to the new start time."""
        for p in self.packets:
            p.timestamp -= start_time

        if self.start_time is not None:
            self.start_time -= start_time
976
977
class DnsHammer(Conversation):
    """A lightweight conversation that generates a lot of dns:0 packets on
    the fly."""

    def __init__(self, dns_rate, duration):
        n_packets = int(dns_rate * duration)
        # pre-compute the send schedule as sorted random offsets
        self.times = sorted(random.uniform(0, duration)
                            for _ in range(n_packets))
        self.rate = dns_rate
        self.duration = duration
        self.start_time = 0
        self.msg = random_colour_print()

    def __str__(self):
        return ("<DnsHammer %d packets over %.1fs (rate %.2f)>" %
                (len(self.times), self.duration, self.rate))

    def replay_in_fork_with_delay(self, start, context=None, account=None):
        """Delegate to the ordinary Conversation fork-and-replay."""
        return Conversation.replay_in_fork_with_delay(self,
                                                      start,
                                                      context,
                                                      account)

    def replay(self, context=None):
        start = time.time()
        fn = traffic_packets.packet_dns_0
        for expected in self.times:
            gap = expected - (time.time() - start)
            sleep_time = gap - SLEEP_OVERHEAD
            if sleep_time > 0:
                time.sleep(sleep_time)

            if context is None:
                # dry run: just report when each packet would have gone
                miss = expected - (time.time() - start)
                self.msg("packet %s [miss %.3f pid %d]" % (expected, miss,
                                                           os.getpid()))
                continue

            packet_start = time.time()
            try:
                fn(self, self, context)
                end = time.time()
                duration = end - packet_start
                print("%f\tDNS\tdns\t0\t%f\tTrue\t" % (end, duration))
            except Exception as e:
                end = time.time()
                duration = end - packet_start
                print("%f\tDNS\tdns\t0\t%f\tFalse\t%s" % (end, duration, e))
1027
1028
def ingest_summaries(files, dns_mode='count'):
    """Load summary traffic files and generate Conversations from them.

    :param files: a sequence of filenames or open file objects
    :param dns_mode: unless 'include', dns packets are tallied
        separately rather than added to conversations

    :return: a 4-tuple (list of conversations, mean conversation rate,
        duration, dns opcode counts)
    """
    dns_counts = defaultdict(int)
    packets = []
    for f in files:
        if isinstance(f, str):
            f = open(f)
        print("Ingesting %s" % (f.name,), file=sys.stderr)
        for line in f:
            p = Packet.from_line(line)
            if p.protocol == 'dns' and dns_mode != 'include':
                dns_counts[p.opcode] += 1
            else:
                packets.append(p)

        f.close()

    if not packets:
        # return the same 4-tuple shape as the success path, so callers
        # can always unpack four values (previously this returned a
        # 2-tuple, which broke unpacking)
        return [], 0, 0, dns_counts

    start_time = min(p.timestamp for p in packets)
    last_packet = max(p.timestamp for p in packets)

    print("gathering packets into conversations", file=sys.stderr)
    conversations = OrderedDict()
    for p in packets:
        p.timestamp -= start_time
        c = conversations.get(p.endpoints)
        if c is None:
            c = Conversation()
            conversations[p.endpoints] = c
        c.add_packet(p)

    # We only care about conversations with actual traffic, so we
    # filter out conversations with nothing to say. We do that here,
    # rather than earlier, because those empty packets contain useful
    # hints as to which end of the conversation was the client.
    conversation_list = []
    for c in conversations.values():
        if len(c) != 0:
            conversation_list.append(c)

    # This is obviously not correct, as many conversations will appear
    # to start roughly simultaneously at the beginning of the snapshot.
    # To which we say: oh well, so be it.
    duration = float(last_packet - start_time)
    # despite the name, this is conversations per second
    mean_interval = len(conversations) / duration

    return conversation_list, mean_interval, duration, dns_counts
1080
1081
def guess_server_address(conversations):
    """Return the most common endpoint address across all conversations,
    on the assumption that the server appears in (almost) every one.

    Returns None if there are no conversations with endpoints.
    """
    addresses = Counter()
    for c in conversations:
        addresses.update(c.endpoints)
    if addresses:
        # most_common(1) gives [(address, count)]; callers (e.g.
        # Conversation.guess_client_server) want the bare address, not
        # the (address, count) pair
        return addresses.most_common(1)[0][0]
1089
1090
def stringify_keys(x):
    """Return a copy of dict *x* with each tuple key joined into a
    single tab-separated string (inverse of unstringify_keys)."""
    return {'\t'.join(key): value for key, value in x.items()}
1097
1098
def unstringify_keys(x):
    """Return a copy of dict *x* with each tab-separated string key
    split back into a tuple (inverse of stringify_keys)."""
    return {tuple(str(key).split('\t')): value for key, value in x.items()}
1105
1106
class TrafficModel(object):
    """A probabilistic model of client traffic, built from observed
    conversations as an n-gram model over packet types, together with
    observed query details and timing information."""

    def __init__(self, n=3):
        """Create an empty model using an (n - 1)-packet history."""
        self.ngrams = {}
        self.query_details = {}
        self.n = n
        self.dns_opcounts = defaultdict(int)
        self.cumulative_duration = 0.0
        # [number of conversations, seconds they were spread over]
        self.conversation_rate = [0, 1]

    def learn(self, conversations, dns_opcounts=None):
        """Fold the given conversations into the model.

        :param conversations: Conversation objects, ordered by start
            time
        :param dns_opcounts: optional mapping of dns opcode -> count,
            kept separately because dns packets are too numerous to
            model individually
        """
        if dns_opcounts is None:
            # not a `dns_opcounts={}` default: mutable defaults are
            # shared between calls
            dns_opcounts = {}
        prev = 0.0
        cum_duration = 0.0
        key = (NON_PACKET,) * (self.n - 1)

        server = guess_server_address(conversations)

        for k, v in dns_opcounts.items():
            self.dns_opcounts[k] += v

        if len(conversations) > 1:
            elapsed =\
                conversations[-1].start_time - conversations[0].start_time
            self.conversation_rate[0] = len(conversations)
            self.conversation_rate[1] = elapsed

        for c in conversations:
            client, server = c.guess_client_server(server)
            cum_duration += c.get_duration()
            key = (NON_PACKET,) * (self.n - 1)
            for p in c:
                # we model only the client side of each conversation
                if p.src != client:
                    continue

                elapsed = p.timestamp - prev
                prev = p.timestamp
                if elapsed > WAIT_THRESHOLD:
                    # add the wait as an extra state
                    wait = 'wait:%d' % (math.log(max(1.0,
                                                     elapsed * WAIT_SCALE)))
                    self.ngrams.setdefault(key, []).append(wait)
                    key = key[1:] + (wait,)

                short_p = p.as_packet_type()
                self.query_details.setdefault(short_p,
                                              []).append(tuple(p.extra))
                self.ngrams.setdefault(key, []).append(short_p)
                key = key[1:] + (short_p,)

        self.cumulative_duration += cum_duration
        # add in the end
        self.ngrams.setdefault(key, []).append(NON_PACKET)

    def save(self, f):
        """Write the model as JSON to *f* (a filename or file object)."""
        ngrams = {}
        for k, v in self.ngrams.items():
            k = '\t'.join(k)
            ngrams[k] = dict(Counter(v))

        query_details = {}
        for k, v in self.query_details.items():
            query_details[k] = dict(Counter('\t'.join(x) if x else '-'
                                            for x in v))

        d = {
            'ngrams': ngrams,
            'query_details': query_details,
            'cumulative_duration': self.cumulative_duration,
            'conversation_rate': self.conversation_rate,
        }
        d['dns'] = self.dns_opcounts

        if isinstance(f, str):
            # don't leak the handle when we opened the file ourselves
            with open(f, 'w') as handle:
                json.dump(d, handle, indent=2)
        else:
            json.dump(d, f, indent=2)

    def load(self, f):
        """Load a model saved by save() from *f* (a filename or file
        object), merging it into this model."""
        if isinstance(f, str):
            with open(f) as handle:
                d = json.load(handle)
        else:
            d = json.load(f)

        for k, v in d['ngrams'].items():
            k = tuple(str(k).split('\t'))
            values = self.ngrams.setdefault(k, [])
            for p, count in v.items():
                values.extend([str(p)] * count)

        for k, v in d['query_details'].items():
            values = self.query_details.setdefault(str(k), [])
            for p, count in v.items():
                if p == '-':
                    # '-' marks an empty extra-data tuple
                    values.extend([()] * count)
                else:
                    values.extend([tuple(str(p).split('\t'))] * count)

        if 'dns' in d:
            for k, v in d['dns'].items():
                self.dns_opcounts[k] += v

        self.cumulative_duration = d['cumulative_duration']
        self.conversation_rate = d['conversation_rate']

    def construct_conversation(self, timestamp=0.0, client=2, server=1,
                               hard_stop=None, packet_rate=1):
        """Construct an individual conversation from the model."""

        c = Conversation(timestamp, (server, client))

        key = (NON_PACKET,) * (self.n - 1)

        while key in self.ngrams:
            p = random.choice(self.ngrams.get(key, NON_PACKET))
            if p == NON_PACKET:
                break
            if p in self.query_details:
                extra = random.choice(self.query_details[p])
            else:
                extra = []

            protocol, opcode = p.split(':', 1)
            if protocol == 'wait':
                # a synthetic wait state: advance time, emit no packet
                log_wait_time = int(opcode) + random.random()
                wait = math.exp(log_wait_time) / (WAIT_SCALE * packet_rate)
                timestamp += wait
            else:
                log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
                wait = math.exp(log_wait) / packet_rate
                timestamp += wait
                if hard_stop is not None and timestamp > hard_stop:
                    break
                c.add_short_packet(timestamp, protocol, opcode, extra)

            key = key[1:] + (p,)

        return c

    def generate_conversations(self, rate, duration, packet_rate=1):
        """Generate a list of conversations from the model."""

        # We run the simulation for longer than the desired duration,
        # and take a section of it.
        rate_n, rate_t  = self.conversation_rate

        duration2 = max(rate_t, duration * 2)
        n = rate * duration2 * rate_n / rate_t

        server = 1
        client = 2

        conversations = []
        end = duration2

        while client < n + 2:
            start = random.uniform(0, duration2)
            c = self.construct_conversation(start,
                                            client,
                                            server,
                                            hard_stop=(duration2 * 5),
                                            packet_rate=packet_rate)

            c.forget_packets_outside_window(start, end)
            c.renormalise_times(start)
            if len(c) != 0:
                conversations.append(c)
            client += 1

        print(("we have %d conversations at rate %f" %
               (len(conversations), rate)), file=sys.stderr)
        conversations.sort()
        return conversations
1279
1280
# Map each application protocol name (as used in summary files) to the
# IP protocol number carrying it, as a two-digit hex string:
# '06' is TCP, '11' is UDP (0x11 == 17).
IP_PROTOCOLS = {
    'dns': '11',
    'rpc_netlogon': '06',
    'kerberos': '06',      # ratio 16248:258
    'smb': '06',
    'smb2': '06',
    'ldap': '06',
    'cldap': '11',
    'lsarpc': '06',
    'samr': '06',
    'dcerpc': '06',
    'epm': '06',
    'drsuapi': '06',
    'browser': '11',
    'smb_netlogon': '11',
    'srvsvc': '06',
    'nbns': '11',
}

# Map (protocol, opcode) pairs to human-readable operation names, used
# as the description field in summary lines.
OP_DESCRIPTIONS = {
    ('browser', '0x01'): 'Host Announcement (0x01)',
    ('browser', '0x02'): 'Request Announcement (0x02)',
    ('browser', '0x08'): 'Browser Election Request (0x08)',
    ('browser', '0x09'): 'Get Backup List Request (0x09)',
    ('browser', '0x0c'): 'Domain/Workgroup Announcement (0x0c)',
    ('browser', '0x0f'): 'Local Master Announcement (0x0f)',
    ('cldap', '3'): 'searchRequest',
    ('cldap', '5'): 'searchResDone',
    ('dcerpc', '0'): 'Request',
    ('dcerpc', '11'): 'Bind',
    ('dcerpc', '12'): 'Bind_ack',
    ('dcerpc', '13'): 'Bind_nak',
    ('dcerpc', '14'): 'Alter_context',
    ('dcerpc', '15'): 'Alter_context_resp',
    ('dcerpc', '16'): 'AUTH3',
    ('dcerpc', '2'): 'Response',
    ('dns', '0'): 'query',
    ('dns', '1'): 'response',
    ('drsuapi', '0'): 'DsBind',
    ('drsuapi', '12'): 'DsCrackNames',
    ('drsuapi', '13'): 'DsWriteAccountSpn',
    ('drsuapi', '1'): 'DsUnbind',
    ('drsuapi', '2'): 'DsReplicaSync',
    ('drsuapi', '3'): 'DsGetNCChanges',
    ('drsuapi', '4'): 'DsReplicaUpdateRefs',
    ('epm', '3'): 'Map',
    ('kerberos', ''): '',
    ('ldap', '0'): 'bindRequest',
    ('ldap', '1'): 'bindResponse',
    ('ldap', '2'): 'unbindRequest',
    ('ldap', '3'): 'searchRequest',
    ('ldap', '4'): 'searchResEntry',
    ('ldap', '5'): 'searchResDone',
    ('ldap', ''): '*** Unknown ***',
    ('lsarpc', '14'): 'lsa_LookupNames',
    ('lsarpc', '15'): 'lsa_LookupSids',
    ('lsarpc', '39'): 'lsa_QueryTrustedDomainInfoBySid',
    ('lsarpc', '40'): 'lsa_SetTrustedDomainInfo',
    ('lsarpc', '6'): 'lsa_OpenPolicy',
    ('lsarpc', '76'): 'lsa_LookupSids3',
    ('lsarpc', '77'): 'lsa_LookupNames4',
    ('nbns', '0'): 'query',
    ('nbns', '1'): 'response',
    ('rpc_netlogon', '21'): 'NetrLogonDummyRoutine1',
    ('rpc_netlogon', '26'): 'NetrServerAuthenticate3',
    ('rpc_netlogon', '29'): 'NetrLogonGetDomainInfo',
    ('rpc_netlogon', '30'): 'NetrServerPasswordSet2',
    ('rpc_netlogon', '39'): 'NetrLogonSamLogonEx',
    ('rpc_netlogon', '40'): 'DsrEnumerateDomainTrusts',
    ('rpc_netlogon', '45'): 'NetrLogonSamLogonWithFlags',
    ('rpc_netlogon', '4'): 'NetrServerReqChallenge',
    ('samr', '0',): 'Connect',
    ('samr', '16'): 'GetAliasMembership',
    ('samr', '17'): 'LookupNames',
    ('samr', '18'): 'LookupRids',
    ('samr', '19'): 'OpenGroup',
    ('samr', '1'): 'Close',
    ('samr', '25'): 'QueryGroupMember',
    ('samr', '34'): 'OpenUser',
    ('samr', '36'): 'QueryUserInfo',
    ('samr', '39'): 'GetGroupsForUser',
    ('samr', '3'): 'QuerySecurity',
    ('samr', '5'): 'LookupDomain',
    ('samr', '64'): 'Connect5',
    ('samr', '6'): 'EnumDomains',
    ('samr', '7'): 'OpenDomain',
    ('samr', '8'): 'QueryDomainInfo',
    ('smb', '0x04'): 'Close (0x04)',
    ('smb', '0x24'): 'Locking AndX (0x24)',
    ('smb', '0x2e'): 'Read AndX (0x2e)',
    ('smb', '0x32'): 'Trans2 (0x32)',
    ('smb', '0x71'): 'Tree Disconnect (0x71)',
    ('smb', '0x72'): 'Negotiate Protocol (0x72)',
    ('smb', '0x73'): 'Session Setup AndX (0x73)',
    ('smb', '0x74'): 'Logoff AndX (0x74)',
    ('smb', '0x75'): 'Tree Connect AndX (0x75)',
    ('smb', '0xa2'): 'NT Create AndX (0xa2)',
    ('smb2', '0'): 'NegotiateProtocol',
    ('smb2', '11'): 'Ioctl',
    ('smb2', '14'): 'Find',
    ('smb2', '16'): 'GetInfo',
    ('smb2', '18'): 'Break',
    ('smb2', '1'): 'SessionSetup',
    ('smb2', '2'): 'SessionLogoff',
    ('smb2', '3'): 'TreeConnect',
    ('smb2', '4'): 'TreeDisconnect',
    ('smb2', '5'): 'Create',
    ('smb2', '6'): 'Close',
    ('smb2', '8'): 'Read',
    ('smb_netlogon', '0x12'): 'SAM LOGON request from client (0x12)',
    ('smb_netlogon', '0x17'): ('SAM Active Directory Response - '
                               'user unknown (0x17)'),
    ('srvsvc', '16'): 'NetShareGetInfo',
    ('srvsvc', '21'): 'NetSrvGetInfo',
}
1396
1397
def expand_short_packet(p, timestamp, src, dest, extra):
    """Rebuild a full tab-separated summary line from a short
    'protocol:opcode' packet description plus endpoint and timing
    fields."""
    protocol, opcode = p.split(':', 1)
    fields = [timestamp,
              IP_PROTOCOLS.get(protocol, '06'),
              '',
              src,
              dest,
              protocol,
              opcode,
              OP_DESCRIPTIONS.get((protocol, opcode), '')]
    fields.extend(extra)
    return '\t'.join(fields)
1406
1407
def replay(conversations,
           host=None,
           creds=None,
           lp=None,
           accounts=None,
           dns_rate=0,
           duration=None,
           **kwargs):
    """Replay the conversations against the server, each conversation
    in a forked child process.

    :param conversations: the Conversation objects to replay
    :param host: the server to replay against
    :param creds: credentials to use
    :param lp: loadparm context
    :param accounts: a sequence of ConversationAccounts, one per
        conversation
    :param dns_rate: if non-zero, also generate DNS traffic at this
        many packets per second
    :param duration: seconds to run for (default: 1 second after the
        last conversation starts)

    Remaining keyword arguments are passed on to ReplayContext.
    """
    context = ReplayContext(server=host,
                            creds=creds,
                            lp=lp,
                            **kwargs)

    if len(accounts) < len(conversations):
        # report the counts, not the lists themselves (the %d
        # formatting would raise TypeError on the list objects)
        print(("we have %d accounts but %d conversations" %
               (len(accounts), len(conversations))), file=sys.stderr)

    cstack = list(zip(
        sorted(conversations, key=lambda x: x.start_time, reverse=True),
        accounts))

    # Set the process group so that the calling scripts are not killed
    # when the forked child processes are killed.
    os.setpgrp()

    start = time.time()

    if duration is None:
        # end 1 second after the last packet of the last conversation
        # to start. Conversations other than the last could still be
        # going, but we don't care.
        duration = cstack[0][0].packets[-1].timestamp + 1.0
        print("We will stop after %.1f seconds" % duration,
              file=sys.stderr)

    end = start + duration

    # lazy %-args: the logger does the formatting only if the message
    # is actually emitted
    LOGGER.info("Replaying traffic for %u conversations over %d seconds",
                len(conversations), duration)

    children = {}
    if dns_rate:
        dns_hammer = DnsHammer(dns_rate, duration)
        cstack.append((dns_hammer, None))

    try:
        while True:
            # we spawn a batch, wait for finishers, then spawn another
            now = time.time()
            batch_end = min(now + 2.0, end)
            fork_time = 0.0
            fork_n = 0
            while cstack:
                c, account = cstack.pop()
                if c.start_time + start > batch_end:
                    # not due yet; push it back for the next batch
                    cstack.append((c, account))
                    break

                st = time.time()
                pid = c.replay_in_fork_with_delay(start, context, account)
                children[pid] = c
                t = time.time()
                elapsed = t - st
                fork_time += elapsed
                fork_n += 1
                print("forked %s in pid %s (in %fs)" % (c, pid,
                                                        elapsed),
                      file=sys.stderr)

            if fork_n:
                print(("forked %d times in %f seconds (avg %f)" %
                       (fork_n, fork_time, fork_time / fork_n)),
                      file=sys.stderr)
            elif cstack:
                debug(2, "no forks in batch ending %f" % batch_end)

            # reap finished children until close to the batch deadline
            while time.time() < batch_end - 1.0:
                time.sleep(0.01)
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != 10:  # no child processes
                        raise
                    break
                if pid:
                    c = children.pop(pid, None)
                    print(("process %d finished conversation %s;"
                           " %d to go" %
                           (pid, c, len(children))), file=sys.stderr)

            if time.time() >= end:
                print("time to stop", file=sys.stderr)
                break

    except Exception:
        print("EXCEPTION in parent", file=sys.stderr)
        traceback.print_exc()
    finally:
        # terminate the children politely (SIGTERM, twice), then
        # forcibly (SIGKILL), reaping as we go
        for s in (15, 15, 9):
            print(("killing %d children with -%d" %
                   (len(children), s)), file=sys.stderr)
            for pid in children:
                try:
                    os.kill(pid, s)
                except OSError as e:
                    if e.errno != 3:  # don't fail if it has already died
                        raise
            time.sleep(0.5)
            end = time.time() + 1
            while children:
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != 10:
                        raise
                if pid != 0:
                    c = children.pop(pid, None)
                    print(("kill -%d %d KILLED conversation %s; "
                           "%d to go" %
                           (s, pid, c, len(children))),
                          file=sys.stderr)
                if time.time() >= end:
                    break

            if not children:
                break
            time.sleep(1)

        if children:
            print("%d children are missing" % len(children),
                  file=sys.stderr)

        # there may be stragglers that were forked just as ^C was hit
        # and don't appear in the list of children. We can get them
        # with killpg, but that will also kill us, so this is^H^H would be
        # goodbye, except we cheat and pretend to use ^C (SIG_INTERRUPT),
        # so as not to have to fuss around writing signal handlers.
        try:
            os.killpg(0, 2)
        except KeyboardInterrupt:
            print("ignoring fake ^C", file=sys.stderr)
1550
1551
def openLdb(host, creds, lp):
    """Open an LDAP connection to *host* as a SamDB, with paged
    searches enabled."""
    return SamDB(url="ldap://%s" % host,
                 session_info=system_session(),
                 options=['modules:paged_searches'],
                 credentials=creds,
                 lp=lp)
1560
1561
def ou_name(ldb, instance_id):
    """Generate an ou name from the instance id"""
    domain = ldb.domain_dn()
    return "ou=instance-%d,ou=traffic_replay,%s" % (instance_id, domain)
1566
1567
def create_ou(ldb, instance_id):
    """Create an ou, all created user and machine accounts will belong to it.

    This allows all the created resources to be cleaned up easily.
    """
    def add_if_missing(dn):
        # Add an organizationalunit, ignoring "entry already exists"
        # (LDAP result code 68) so that re-runs are harmless.
        try:
            ldb.add({"dn": dn,
                     "objectclass": "organizationalunit"})
        except LdbError as e:
            (status, _) = e.args
            if status != 68:
                raise

    ou = ou_name(ldb, instance_id)
    # create the parent container first, then the per-instance ou
    add_if_missing(ou.split(',', 1)[1])
    add_if_missing(ou)
    return ou
1591
1592
class ConversationAccounts(object):
    """Details of the machine and user accounts associated with a conversation.
    """

    def __init__(self, netbios_name, machinepass, username, userpass):
        self.netbios_name = netbios_name
        self.machinepass = machinepass
        self.username = username
        self.userpass = userpass
1601
1602
def generate_replay_accounts(ldb, instance_id, number, password):
    """Generate a series of unique machine and user account names."""
    # accounts are numbered from 1 to `number` inclusive; both the
    # machine and the user share the same password
    return [ConversationAccounts(machine_name(instance_id, i),
                                 password,
                                 user_name(instance_id, i),
                                 password)
            for i in range(1, number + 1)]
1615
1616
def create_machine_account(ldb, instance_id, netbios_name, machinepass,
                           traffic_account=True):
    """Create a machine account via ldap."""
    dn = "cn=%s,%s" % (netbios_name, ou_name(ldb, instance_id))
    utf16pw = ('"%s"' % get_string(machinepass)).encode('utf-16-le')

    if traffic_account:
        # we set these bits for the machine account otherwise the replayed
        # traffic throws up NT_STATUS_NO_TRUST_SAM_ACCOUNT errors
        account_controls = str(UF_TRUSTED_FOR_DELEGATION |
                               UF_SERVER_TRUST_ACCOUNT)
    else:
        account_controls = str(UF_WORKSTATION_TRUST_ACCOUNT)

    ldb.add({
        "dn": dn,
        "objectclass": "computer",
        "sAMAccountName": "%s$" % netbios_name,
        "userAccountControl": account_controls,
        "unicodePwd": utf16pw})
1640
1641
def create_user_account(ldb, instance_id, username, userpass):
    """Create a user account via ldap."""
    ou = ou_name(ldb, instance_id)
    user_dn = "cn=%s,%s" % (username, ou)
    # AD expects the password as a quoted UTF-16-LE string
    utf16pw = ('"%s"' % get_string(userpass)).encode('utf-16-le')
    ldb.add({
        "dn": user_dn,
        "objectclass": "user",
        "sAMAccountName": username,
        "userAccountControl": str(UF_NORMAL_ACCOUNT),
        "unicodePwd": utf16pw
    })

    # grant user write permission to do things like write account SPN
    sd_utils.SDUtils(ldb).dacl_add_ace(user_dn, "(A;;WP;;;PS)")
1658
1659
def create_group(ldb, instance_id, name):
    """Create a group via ldap."""
    group_dn = "cn=%s,%s" % (name, ou_name(ldb, instance_id))
    ldb.add({
        "dn": group_dn,
        "objectclass": "group",
        "sAMAccountName": name,
    })
1670
1671
def user_name(instance_id, i):
    """Generate a user name based on the instance id."""
    parts = ("STGU", instance_id, i)
    return "%s-%d-%d" % parts
1675
1676
def search_objectclass(ldb, objectclass='user', attr='sAMAccountName'):
    """Search for the given objectclass, returning attr values as a set."""
    expression = "(objectClass={})".format(objectclass)
    found = set()
    for entry in ldb.search(expression=expression, attrs=[attr]):
        found.add(str(entry[attr]))
    return found
1684
1685
def generate_users(ldb, instance_id, number, password):
    """Add users to the server, skipping any that already exist.

    Returns the number of users actually created.
    """
    present = search_objectclass(ldb, objectclass='user')
    created = 0
    for i in range(number, 0, -1):
        name = user_name(instance_id, i)
        if name in present:
            continue
        create_user_account(ldb, instance_id, name, password)
        created += 1
        # periodic progress report for long runs
        if created % 50 == 0:
            LOGGER.info("Created %u/%u users" % (created, number))

    return created
1699
1700
def machine_name(instance_id, i, traffic_account=True):
    """Generate a machine account name from instance id.

    Traffic accounts correspond to a given user and use different
    userAccountControl flags so packets get processed correctly by the DC.
    Non-traffic accounts just pad out the directory with semi-realistic
    computers, and get a distinct prefix so packet generation never
    picks them up.
    """
    prefix = "STGM" if traffic_account else "PC"
    return "%s-%d-%d" % (prefix, instance_id, i)
1714
1715
def generate_machine_accounts(ldb, instance_id, number, password,
                              traffic_account=True):
    """Add machine accounts to the server, skipping any that exist.

    Returns the number of accounts actually created.
    """
    present = search_objectclass(ldb, objectclass='computer')
    created = 0
    for i in range(number, 0, -1):
        name = machine_name(instance_id, i, traffic_account)
        # computer sAMAccountNames carry a trailing '$'
        if name + "$" in present:
            continue
        create_machine_account(ldb, instance_id, name, password,
                               traffic_account)
        created += 1
        # periodic progress report for long runs
        if created % 50 == 0:
            LOGGER.info("Created %u/%u machine accounts" % (created, number))

    return created
1731
1732
def group_name(instance_id, i):
    """Generate a group name from instance id."""
    parts = ("STGG", instance_id, i)
    return "%s-%d-%d" % parts
1736
1737
def generate_groups(ldb, instance_id, number):
    """Create the required number of groups on the server.

    Groups that already exist are skipped; returns the number created.
    """
    present = search_objectclass(ldb, objectclass='group')
    created = 0
    for i in range(number, 0, -1):
        name = group_name(instance_id, i)
        if name in present:
            continue
        create_group(ldb, instance_id, name)
        created += 1
        # periodic progress report for long runs
        if created % 1000 == 0:
            LOGGER.info("Created %u/%u groups" % (created, number))

    return created
1751
1752
def clean_up_accounts(ldb, instance_id):
    """Remove the created accounts and groups from the server.

    Everything lives under the instance's ou, so a recursive delete of
    that one container cleans up the lot.
    """
    ou = ou_name(ldb, instance_id)
    try:
        ldb.delete(ou, ["tree_delete:1"])
    except LdbError as e:
        (status, _msg) = e.args
        # 32 == LDB_ERR_NO_SUCH_OBJECT: nothing to clean up
        if status != 32:
            raise
1763
1764
def generate_users_and_groups(ldb, instance_id, password,
                              number_of_users, number_of_groups,
                              group_memberships, max_members,
                              machine_accounts, traffic_accounts=True):
    """Generate the required users and groups, allocating the users to
       those groups."""
    memberships_added = 0
    groups_added = 0

    # everything is created under a single per-instance ou so it can be
    # cleaned up in one go later
    create_ou(ldb, instance_id)

    LOGGER.info("Generating dummy user accounts")
    users_added = generate_users(ldb, instance_id, number_of_users, password)

    LOGGER.info("Generating dummy machine accounts")
    computers_added = generate_machine_accounts(ldb, instance_id,
                                                machine_accounts, password,
                                                traffic_accounts)

    if number_of_groups > 0:
        LOGGER.info("Generating dummy groups")
        groups_added = generate_groups(ldb, instance_id, number_of_groups)

    if group_memberships > 0:
        # first work out who goes in which group, then apply it to the DB
        LOGGER.info("Assigning users to groups")
        assignments = GroupAssignments(number_of_groups, groups_added,
                                       number_of_users, users_added,
                                       group_memberships, max_members)
        LOGGER.info("Adding users to groups")
        add_users_to_groups(ldb, instance_id, assignments)
        memberships_added = assignments.total()

    if (groups_added > 0 and users_added == 0 and
       number_of_groups != groups_added):
        LOGGER.warning("The added groups will contain no members")

    LOGGER.info("Added %d users (%d machines), %d groups and %d memberships" %
                (users_added, computers_added, groups_added,
                 memberships_added))
1808
1809
class GroupAssignments(object):
    """Computes a random assignment of users to groups.

    The assignment is weighted so that a few users belong to many groups
    and a few groups contain many users, approximating a real directory.
    An optional max_members cap stops any one group growing past a given
    size.  The whole assignment is computed in the constructor; applying
    it to the DB is done separately (see add_users_to_groups()).
    """
    def __init__(self, number_of_groups, groups_added, number_of_users,
                 users_added, group_memberships, max_members):

        # running total of memberships assigned so far
        self.count = 0
        self.generate_group_distribution(number_of_groups)
        self.generate_user_distribution(number_of_users, group_memberships)
        # zero/None/False means the group size is uncapped
        self.max_members = max_members
        # maps group-number -> list of user-numbers in that group
        self.assignments = defaultdict(list)
        self.assign_groups(number_of_groups, groups_added, number_of_users,
                           users_added, group_memberships)

    def cumulative_distribution(self, weights):
        """Convert a list of weights into a cumulative distribution.

        Returns a list of values spread between 0.0 and 1.0, or None if
        all weights are zero.
        """
        # make sure the probabilities conform to a cumulative distribution
        # spread between 0.0 and 1.0. Dividing by the weighted total gives each
        # probability a proportional share of 1.0. Higher probabilities get a
        # bigger share, so are more likely to be picked. We use the cumulative
        # value, so we can use random.random() as a simple index into the list
        dist = []
        total = sum(weights)
        if total == 0:
            return None

        cumulative = 0.0
        for probability in weights:
            cumulative += probability
            dist.append(cumulative / total)
        return dist

    def generate_user_distribution(self, num_users, num_memberships):
        """Probability distribution of a user belonging to a group.
        """
        # Assign a weighted probability to each user. Use the Pareto
        # Distribution so that some users are in a lot of groups, and the
        # bulk of users are in only a few groups. If we're assigning a large
        # number of group memberships, use a higher shape. This means slightly
        # fewer outlying users that are in large numbers of groups. The aim is
        # to have no users belonging to more than ~500 groups.
        if num_memberships > 5000000:
            shape = 3.0
        elif num_memberships > 2000000:
            shape = 2.5
        elif num_memberships > 300000:
            shape = 2.25
        else:
            shape = 1.75

        weights = []
        for x in range(1, num_users + 1):
            p = random.paretovariate(shape)
            weights.append(p)

        # convert the weights to a cumulative distribution between 0.0 and 1.0
        self.user_dist = self.cumulative_distribution(weights)

    def generate_group_distribution(self, n):
        """Probability distribution of a group containing a user."""

        # Assign a weighted probability to each group. Probability decreases
        # as the group-ID increases (power-law style fall-off)
        weights = []
        for x in range(1, n + 1):
            p = 1 / (x**1.3)
            weights.append(p)

        # convert the weights to a cumulative distribution between 0.0 and 1.0
        # (the raw weights are kept so cap_group_membership() can zero out a
        # full group and rebuild the distribution)
        self.group_weights = weights
        self.group_dist = self.cumulative_distribution(weights)

    def generate_random_membership(self):
        """Returns a randomly generated user-group membership"""

        # the list items are cumulative distribution values between 0.0 and
        # 1.0, which makes random() a handy way to index the list to get a
        # weighted random user/group. (Here the user/group returned are
        # zero-based array indexes)
        # NOTE(review): if every group has been capped, group_dist becomes
        # None (see cumulative_distribution) and bisect would raise here;
        # presumably the membership cap in assign_groups() prevents that
        # case being reached — confirm
        user = bisect.bisect(self.user_dist, random.random())
        group = bisect.bisect(self.group_dist, random.random())

        return user, group

    def users_in_group(self, group):
        # list of user-numbers assigned to the given group-number
        return self.assignments[group]

    def get_groups(self):
        # all group-numbers that received at least one lookup/assignment
        return self.assignments.keys()

    def cap_group_membership(self, group, max_members):
        """Prevent the group's membership from exceeding the max specified"""
        num_members = len(self.assignments[group])
        if num_members >= max_members:
            LOGGER.info("Group {0} has {1} members".format(group, num_members))

            # remove this group and then recalculate the cumulative
            # distribution, so this group is no longer selected
            self.group_weights[group - 1] = 0
            new_dist = self.cumulative_distribution(self.group_weights)
            self.group_dist = new_dist

    def add_assignment(self, user, group):
        """Record that the given user belongs to the given group.

        Duplicate user/group pairs are silently ignored.
        """
        # the assignments are stored in a dictionary where key=group,
        # value=list-of-users-in-group (indexing by group-ID allows us to
        # optimize for DB membership writes)
        if user not in self.assignments[group]:
            self.assignments[group].append(user)
            self.count += 1

        # check if there's a cap on how big the groups can grow
        if self.max_members:
            self.cap_group_membership(group, self.max_members)

    def assign_groups(self, number_of_groups, groups_added,
                      number_of_users, users_added, group_memberships):
        """Allocate users to groups.

        The intention is to have a few users that belong to most groups, while
        the majority of users belong to a few groups.

        A few groups will contain most users, with the remaining only having a
        few users.
        """

        if group_memberships <= 0:
            return

        # Calculate the number of group memberships required, scaled down
        # by the fraction of users that actually got created this run
        group_memberships = math.ceil(
            float(group_memberships) *
            (float(users_added) / float(number_of_users)))

        # with a cap in place, the total memberships can never exceed
        # max_members in every group
        if self.max_members:
            group_memberships = min(group_memberships,
                                    self.max_members * number_of_groups)

        # zero-based index of the last user/group that existed before this
        # run; memberships are only added for pairs involving something new
        existing_users  = number_of_users  - users_added  - 1
        existing_groups = number_of_groups - groups_added - 1
        while self.total() < group_memberships:
            user, group = self.generate_random_membership()

            if group > existing_groups or user > existing_users:
                # the + 1 converts the array index to the corresponding
                # group or user number
                self.add_assignment(user + 1, group + 1)

    def total(self):
        # number of unique user/group memberships assigned
        return self.count
1956
1957
def add_users_to_groups(db, instance_id, assignments):
    """Takes the assignments of users to groups and applies them to the DB."""

    total = assignments.total()
    batches_written = 0
    members_added = 0

    for group in assignments.get_groups():
        members = assignments.users_in_group(group)
        if not members:
            continue

        # Split up the users into chunks, so we write no more than 1K at a
        # time. (Minimizing the DB modifies is more efficient, but writing
        # 10K+ users to a single group becomes inefficient memory-wise)
        for start in range(0, len(members), 1000):
            batch = members[start:start + 1000]
            add_group_members(db, instance_id, group, batch)

            members_added += len(batch)
            batches_written += 1
            # periodic progress report for long runs
            if batches_written % 50 == 0:
                LOGGER.info("Added %u/%u memberships" % (members_added, total))
1981
def add_group_members(db, instance_id, group, users_in_group):
    """Adds the given users to group specified.

    All the additions are batched into a single LDB modify on the group.
    """

    ou = ou_name(db, instance_id)
    group_dn = "cn=%s,%s" % (group_name(instance_id, group), ou)

    msg = ldb.Message()
    msg.dn = ldb.Dn(db, group_dn)

    for user in users_in_group:
        user_dn = "cn=%s,%s" % (user_name(instance_id, user), ou)
        # each element needs a unique name within the message
        element = "member-" + str(user)
        msg[element] = ldb.MessageElement(user_dn, ldb.FLAG_MOD_ADD, "member")

    db.modify(msg)
2000
2001
def generate_stats(statsdir, timing_file):
    """Generate and print the summary stats for a run.

    Reads the tab-separated stats files found in statsdir, optionally
    copies the valid lines to timing_file, and prints overall success /
    failure rates plus per-protocol, per-op-code latency percentiles.

    Invalid lines are echoed to stderr and otherwise ignored.
    """
    first      = sys.float_info.max
    last       = 0
    successful = 0
    failed     = 0
    latencies  = {}
    failures   = {}
    unique_conversations = set()
    conversations = 0

    if timing_file is not None:
        tw = timing_file.write
    else:
        # no timing file requested: discard the copied lines
        def tw(x):
            pass

    tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n")

    for filename in os.listdir(statsdir):
        path = os.path.join(statsdir, filename)
        with open(path, 'r') as f:
            for line in f:
                try:
                    fields       = line.rstrip('\n').split('\t')
                    conversation = fields[1]
                    protocol     = fields[2]
                    packet_type  = fields[3]
                    latency      = float(fields[4])
                    # track the overall wall-clock span of the run
                    first        = min(float(fields[0]) - latency, first)
                    last         = max(float(fields[0]), last)

                    if protocol not in latencies:
                        latencies[protocol] = {}
                    if packet_type not in latencies[protocol]:
                        latencies[protocol][packet_type] = []

                    latencies[protocol][packet_type].append(latency)

                    if protocol not in failures:
                        failures[protocol] = {}
                    if packet_type not in failures[protocol]:
                        failures[protocol][packet_type] = 0

                    if fields[5] == 'True':
                        successful += 1
                    else:
                        failed += 1
                        failures[protocol][packet_type] += 1

                    if conversation not in unique_conversations:
                        unique_conversations.add(conversation)
                        conversations += 1

                    tw(line)
                except (ValueError, IndexError):
                    # not a valid line: show it to the user and ignore
                    print(line, file=sys.stderr)

    duration = last - first
    # duration <= 0 means no valid lines were parsed (or everything
    # happened at a single instant): avoid dividing by zero
    if successful == 0 or duration <= 0:
        success_rate = 0
    else:
        success_rate = successful / duration
    if failed == 0 or duration <= 0:
        failure_rate = 0
    else:
        failure_rate = failed / duration

    print("Total conversations:   %10d" % conversations)
    print("Successful operations: %10d (%.3f per second)"
          % (successful, success_rate))
    print("Failed operations:     %10d (%.3f per second)"
          % (failed, failure_rate))

    print("Protocol    Op Code  Description                               "
          " Count       Failed         Mean       Median          "
          "95%        Range          Max")

    protocols = sorted(latencies.keys())
    for protocol in protocols:
        packet_types = sorted(latencies[protocol], key=opcode_key)
        for packet_type in packet_types:
            values     = latencies[protocol][packet_type]
            values     = sorted(values)
            count      = len(values)
            failed     = failures[protocol][packet_type]
            mean       = sum(values) / count
            median     = calc_percentile(values, 0.50)
            percentile = calc_percentile(values, 0.95)
            rng        = values[-1] - values[0]
            maxv       = values[-1]
            desc       = OP_DESCRIPTIONS.get((protocol, packet_type), '')
            # isatty is a method and must be called; the bare attribute is
            # always truthy, which made the TSV branch below unreachable
            if sys.stdout.isatty():
                print("%-12s   %4s  %-35s %12d %12d %12.6f "
                      "%12.6f %12.6f %12.6f %12.6f"
                      % (protocol,
                         packet_type,
                         desc,
                         count,
                         failed,
                         mean,
                         median,
                         percentile,
                         rng,
                         maxv))
            else:
                print("%s\t%s\t%s\t%d\t%d\t%f\t%f\t%f\t%f\t%f"
                      % (protocol,
                         packet_type,
                         desc,
                         count,
                         failed,
                         mean,
                         median,
                         percentile,
                         rng,
                         maxv))
2120
2121
def opcode_key(v):
    """Sort key for the operation code to ensure that it sorts numerically.

    Numeric op codes are zero-padded to three digits; anything
    non-numeric (e.g. a named LDAP operation) is returned unchanged.
    """
    try:
        return "%03d" % int(v)
    except (ValueError, TypeError):
        # a bare except here would also swallow KeyboardInterrupt etc.
        return v
2128
2129
def calc_percentile(values, percentile):
    """Return the given percentile of a pre-sorted list of values.

    Uses linear interpolation between the two nearest ranks.  An empty
    list yields 0.
    """

    if not values:
        return 0
    rank = (len(values) - 1) * percentile
    lower = int(math.floor(rank))
    upper = int(math.ceil(rank))
    if lower == upper:
        # the rank lands exactly on an element
        return values[lower]
    # interpolate between the two surrounding values
    return (values[lower] * (upper - rank) +
            values[upper] * (rank - lower))
2146
2147
def mk_masked_dir(*path):
    """Create a directory (from joined path components) with a 0o077 umask.

    In a testenv we end up with 0777 directories that look an alarming
    green colour with ls. Use umask to avoid that.
    """
    d = os.path.join(*path)
    mask = os.umask(0o077)
    try:
        os.mkdir(d)
    finally:
        # restore the caller's umask even if mkdir fails, so an error
        # here doesn't change permissions of everything created later
        os.umask(mask)
    return d