traffic_replay: Make sure naming assumptions are in a single place
[samba.git] / python / samba / emulate / traffic.py
1 # -*- encoding: utf-8 -*-
2 # Samba traffic replay and learning
3 #
4 # Copyright (C) Catalyst IT Ltd. 2017
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 #
19 from __future__ import print_function, division
20
21 import time
22 import os
23 import random
24 import json
25 import math
26 import sys
27 import signal
28 import itertools
29
30 from collections import OrderedDict, Counter, defaultdict
31 from samba.emulate import traffic_packets
32 from samba.samdb import SamDB
33 import ldb
34 from ldb import LdbError
35 from samba.dcerpc import ClientConnection
36 from samba.dcerpc import security, drsuapi, lsa
37 from samba.dcerpc import netlogon
38 from samba.dcerpc.netlogon import netr_Authenticator
39 from samba.dcerpc import srvsvc
40 from samba.dcerpc import samr
41 from samba.drs_utils import drs_DsBind
42 import traceback
43 from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
44 from samba.auth import system_session
45 from samba.dsdb import (
46     UF_NORMAL_ACCOUNT,
47     UF_SERVER_TRUST_ACCOUNT,
48     UF_TRUSTED_FOR_DELEGATION,
49     UF_WORKSTATION_TRUST_ACCOUNT
50 )
51 from samba.dcerpc.misc import SEC_CHAN_BDC
52 from samba import gensec
53 from samba import sd_utils
54 from samba.compat import get_string
55 from samba.logger import get_samba_logger
56 import bisect
57
58 SLEEP_OVERHEAD = 3e-4
59
60 # we don't use None, because it complicates [de]serialisation
61 NON_PACKET = '-'
62
63 CLIENT_CLUES = {
64     ('dns', '0'): 1.0,      # query
65     ('smb', '0x72'): 1.0,   # Negotiate protocol
66     ('ldap', '0'): 1.0,     # bind
67     ('ldap', '3'): 1.0,     # searchRequest
68     ('ldap', '2'): 1.0,     # unbindRequest
69     ('cldap', '3'): 1.0,
70     ('dcerpc', '11'): 1.0,  # bind
71     ('dcerpc', '14'): 1.0,  # Alter_context
72     ('nbns', '0'): 1.0,     # query
73 }
74
75 SERVER_CLUES = {
76     ('dns', '1'): 1.0,      # response
77     ('ldap', '1'): 1.0,     # bind response
78     ('ldap', '4'): 1.0,     # search result
79     ('ldap', '5'): 1.0,     # search done
80     ('cldap', '5'): 1.0,
81     ('dcerpc', '12'): 1.0,  # bind_ack
82     ('dcerpc', '13'): 1.0,  # bind_nak
83     ('dcerpc', '15'): 1.0,  # Alter_context response
84 }
85
86 SKIPPED_PROTOCOLS = {"smb", "smb2", "browser", "smb_netlogon"}
87
88 WAIT_SCALE = 10.0
89 WAIT_THRESHOLD = (1.0 / WAIT_SCALE)
90 NO_WAIT_LOG_TIME_RANGE = (-10, -3)
91
92 # DEBUG_LEVEL can be changed by scripts with -d
93 DEBUG_LEVEL = 0
94
95 LOGGER = get_samba_logger(name=__name__)
96
97
def debug(level, msg, *args):
    """Log a formatted message to standard error, honouring DEBUG_LEVEL.

    :param level: The message is emitted only when level <= the current
                  DEBUG_LEVEL (which scripts can raise with -d).
    :param msg:   The message, optionally containing C-style format
                  specifiers.
    :param args:  Values consumed by the format specifiers, if any.
    """
    if level > DEBUG_LEVEL:
        return
    text = msg % tuple(args) if args else msg
    print(text, file=sys.stderr)
114
115
def debug_lineno(*args):
    """Print an unformatted message to stderr, prefixed with the calling
    function's name and line number (line number highlighted in yellow).
    """
    caller = traceback.extract_stack(limit=2)[0]
    prefix = (" %s:" "\033[01;33m"
              "%s " "\033[00m" % (caller[2], caller[1]))
    print(prefix, end=' ', file=sys.stderr)
    for arg in args:
        print(arg, file=sys.stderr)
    print(file=sys.stderr)
    sys.stderr.flush()
127
128
def random_colour_print():
    """Return a printing function bound to one randomly chosen ANSI
    256-colour; each argument is written to stderr on its own line."""
    colour = 18 + random.randrange(214)
    prefix = "\033[38;5;%dm" % colour

    def printer(*args):
        for arg in args:
            print("%s%s\033[00m" % (prefix, arg), file=sys.stderr)

    return printer
139
140
class FakePacketError(Exception):
    """Raised when a synthesised packet is inconsistent with its
    conversation (e.g. its endpoints don't match)."""
143
144
class Packet(object):
    """Details of a network packet parsed from a traffic summary line."""
    def __init__(self, timestamp, ip_protocol, stream_number, src, dest,
                 protocol, opcode, desc, extra):
        # timestamp is seconds (float); src/dest are integer endpoint ids
        self.timestamp = timestamp
        self.ip_protocol = ip_protocol
        self.stream_number = stream_number
        self.src = src
        self.dest = dest
        self.protocol = protocol
        self.opcode = opcode
        self.desc = desc
        self.extra = extra
        # Canonical (low, high) endpoint pair so both directions of a
        # conversation map to the same key.
        if self.src < self.dest:
            self.endpoints = (self.src, self.dest)
        else:
            self.endpoints = (self.dest, self.src)

    @classmethod
    def from_line(cls, line):
        """Parse a tab-separated traffic summary line into a Packet.

        The first 8 fields are fixed; any remaining fields become the
        packet's extra data.
        """
        fields = line.rstrip('\n').split('\t')
        (timestamp,
         ip_protocol,
         stream_number,
         src,
         dest,
         protocol,
         opcode,
         desc) = fields[:8]
        extra = fields[8:]

        timestamp = float(timestamp)
        src = int(src)
        dest = int(dest)

        # use cls() rather than a hard-coded class name so that
        # subclasses parse into their own type
        return cls(timestamp, ip_protocol, stream_number, src, dest,
                   protocol, opcode, desc, extra)

    def as_summary(self, time_offset=0.0):
        """Format the packet as a traffic_summary line.

        Returns a (timestamp, line) tuple; the line's field order matches
        what from_line() parses.
        """
        extra = '\t'.join(self.extra)
        t = self.timestamp + time_offset
        return (t, '%f\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' %
                (t,
                 self.ip_protocol,
                 self.stream_number or '',
                 self.src,
                 self.dest,
                 self.protocol,
                 self.opcode,
                 self.desc,
                 extra))

    def __str__(self):
        return ("%.3f: %d -> %d; ip %s; strm %s; prot %s; op %s; desc %s %s" %
                (self.timestamp, self.src, self.dest, self.ip_protocol or '-',
                 self.stream_number, self.protocol, self.opcode, self.desc,
                 ('«' + ' '.join(self.extra) + '»' if self.extra else '')))

    def __repr__(self):
        return "<Packet @%s>" % self

    def copy(self):
        """Return a new packet of the same class with identical fields."""
        return self.__class__(self.timestamp,
                              self.ip_protocol,
                              self.stream_number,
                              self.src,
                              self.dest,
                              self.protocol,
                              self.opcode,
                              self.desc,
                              self.extra)

    def as_packet_type(self):
        """Return the 'protocol:opcode' string used as a model key."""
        t = '%s:%s' % (self.protocol, self.opcode)
        return t

    def client_score(self):
        """A positive number means we think it is a client; a negative number
        means we think it is a server. Zero means no idea. range: -1 to 1.
        """
        key = (self.protocol, self.opcode)
        if key in CLIENT_CLUES:
            return CLIENT_CLUES[key]
        if key in SERVER_CLUES:
            return -SERVER_CLUES[key]
        return 0.0

    def play(self, conversation, context):
        """Send the packet over the network, if required.

        Some packets are ignored, i.e. for  protocols not handled,
        server response messages, or messages that are generated by the
        protocol layer associated with other packets.
        """
        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
        try:
            fn = getattr(traffic_packets, fn_name)

        except AttributeError:
            print("Conversation(%s) Missing handler %s" %
                  (conversation.conversation_id, fn_name),
                  file=sys.stderr)
            return

        # Don't display a message for kerberos packets, they're not directly
        # generated they're used to indicate kerberos should be used
        if self.protocol != "kerberos":
            debug(2, "Conversation(%s) Calling handler %s" %
                     (conversation.conversation_id, fn_name))

        start = time.time()
        try:
            if fn(self, conversation, context):
                # Only collect timing data for functions that generate
                # network traffic, or fail
                end = time.time()
                duration = end - start
                print("%f\t%s\t%s\t%s\t%f\tTrue\t" %
                      (end, conversation.conversation_id, self.protocol,
                       self.opcode, duration))
        except Exception as e:
            end = time.time()
            duration = end - start
            print("%f\t%s\t%s\t%s\t%f\tFalse\t%s" %
                  (end, conversation.conversation_id, self.protocol,
                   self.opcode, duration, e))

    def __cmp__(self, other):
        # Python 2 ordering hook, retained for compatibility.
        return self.timestamp - other.timestamp

    def __lt__(self, other):
        # Python 3 ignores __cmp__, so sorted()/bisect need __lt__;
        # same ordering (by timestamp) as __cmp__.
        return self.timestamp < other.timestamp

    def is_really_a_packet(self, missing_packet_stats=None):
        """Return True if replaying this packet would do something.

        Packets for skipped protocols, LDAP continuations, and packets
        whose handler is the null handler can be dropped with no effect
        on the replay.
        """
        if self.protocol in SKIPPED_PROTOCOLS:
            # Ignore any packets for the protocols we're not interested in.
            return False
        if self.protocol == "ldap" and self.opcode == '':
            # skip ldap continuation packets
            return False

        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
        fn = getattr(traffic_packets, fn_name, None)
        if not fn:
            print("missing packet %s" % fn_name, file=sys.stderr)
            return False
        if fn is traffic_packets.null_packet:
            return False
        return True
298
299
class ReplayContext(object):
    """State/Context for an individual conversation between a simulated
       client and a server. Holds the cached connections, credentials and
       DN lookup tables used by the packet replay handlers.
    """

    def __init__(self,
                 server=None,
                 lp=None,
                 creds=None,
                 badpassword_frequency=None,
                 prefer_kerberos=None,
                 tempdir=None,
                 statsdir=None,
                 ou=None,
                 base_dn=None,
                 domain=None,
                 domain_sid=None):

        # Connection caches; the most recent connection of each kind is
        # reused unless a handler explicitly asks for a new one.
        self.server                   = server
        self.ldap_connections         = []
        self.dcerpc_connections       = []
        self.lsarpc_connections       = []
        self.lsarpc_connections_named = []
        self.drsuapi_connections      = []
        self.srvsvc_connections       = []
        self.samr_contexts            = []
        self.netlogon_connection      = None
        self.creds                    = creds
        self.lp                       = lp
        self.prefer_kerberos          = prefer_kerberos
        self.ou                       = ou
        self.base_dn                  = base_dn
        self.domain                   = domain
        self.statsdir                 = statsdir
        self.global_tempdir           = tempdir
        self.domain_sid               = domain_sid
        self.realm                    = lp.get('realm')

        # Bad password attempt controls; the last_*_bad flags prevent
        # two consecutive bad-password attempts on the same channel.
        self.badpassword_frequency    = badpassword_frequency
        self.last_lsarpc_bad          = False
        self.last_lsarpc_named_bad    = False
        self.last_simple_bind_bad     = False
        self.last_bind_bad            = False
        self.last_srvsvc_bad          = False
        self.last_drsuapi_bad         = False
        self.last_netlogon_bad        = False
        self.generate_ldap_search_tables()
        self.next_conversation_id = itertools.count()

    def generate_ldap_search_tables(self):
        """Build DN lookup tables from a single search of the database.

        Each DN is filed under its attribute-type pattern (e.g.
        'CN,CN,CN,DC,DC') so replayed searches can pick a plausible DN
        of the right shape; NTDS Settings DNs are also noted as clues
        for invocationId searches.
        """
        session = system_session()

        db = SamDB(url="ldap://%s" % self.server,
                   session_info=session,
                   credentials=self.creds,
                   lp=self.lp)

        res = db.search(db.domain_dn(),
                        scope=ldb.SCOPE_SUBTREE,
                        controls=["paged_results:1:1000"],
                        attrs=['dn'])

        # find a list of dns for each pattern
        # e.g. CN,CN,CN,DC,DC
        dn_map = {}
        attribute_clue_map = {
            'invocationId': []
        }

        for r in res:
            dn = str(r.dn)
            pattern = ','.join(x.lstrip()[:2] for x in dn.split(',')).upper()
            dns = dn_map.setdefault(pattern, [])
            dns.append(dn)
            if dn.startswith('CN=NTDS Settings,'):
                attribute_clue_map['invocationId'].append(dn)

        # extend the map in case we are working with a different
        # number of DC components.
        for k in list(dn_map.keys()):
            if k[-3:] != ',DC':
                continue
            p = k[:-3]
            while p[-3:] == ',DC':
                p = p[:-3]
            for i in range(5):
                p += ',DC'
                if p != k and p in dn_map:
                    print('dn_map collision %s %s' % (k, p),
                          file=sys.stderr)
                    continue
                dn_map[p] = dn_map[k]

        self.dn_map = dn_map
        self.attribute_clue_map = attribute_clue_map

    def generate_process_local_config(self, account, conversation):
        """Bind this context to one conversation's account and tempdir.

        Copies the account's names/passwords onto the context, gives the
        process its own private directories, then generates the machine
        and user credentials.
        """
        if account is None:
            return
        self.netbios_name             = account.netbios_name
        self.machinepass              = account.machinepass
        self.username                 = account.username
        self.userpass                 = account.userpass

        self.tempdir = mk_masked_dir(self.global_tempdir,
                                     'conversation-%d' %
                                     conversation.conversation_id)

        self.lp.set("private dir", self.tempdir)
        self.lp.set("lock dir", self.tempdir)
        self.lp.set("state directory", self.tempdir)
        self.lp.set("tls verify peer", "no_check")

        # If the domain was not specified, check for the environment
        # variable.
        if self.domain is None:
            self.domain = os.environ["DOMAIN"]

        self.remoteAddress = "/root/ncalrpc_as_system"
        self.samlogon_dn   = ("cn=%s,%s" %
                              (self.netbios_name, self.ou))
        self.user_dn       = ("cn=%s,%s" %
                              (self.username, self.ou))

        self.generate_machine_creds()
        self.generate_user_creds()

    def with_random_bad_credentials(self, f, good, bad, failed_last_time):
        """Execute the supplied logon function, randomly choosing the
           bad credentials.

           Based on the frequency in badpassword_frequency randomly perform the
           function with the supplied bad credentials.
           If run with bad credentials, the function is re-run with the good
           credentials.
           failed_last_time is used to prevent consecutive bad credential
           attempts. So the over all bad credential frequency will be lower
           than that requested, but not significantly.
        """
        if not failed_last_time:
            if (self.badpassword_frequency and self.badpassword_frequency > 0
                and random.random() < self.badpassword_frequency):
                try:
                    f(bad)
                except Exception:
                    # Ignore any exceptions as the operation may fail
                    # as it's being performed with bad credentials
                    pass
                failed_last_time = True
            else:
                failed_last_time = False

        result = f(good)
        return (result, failed_last_time)

    def generate_user_creds(self):
        """Generate the conversation specific user Credentials.

        Each Conversation has an associated user account used to simulate
        any non Administrative user traffic.

        Generates user credentials with good and bad passwords and ldap
        simple bind credentials with good and bad passwords.
        """
        self.user_creds = Credentials()
        self.user_creds.guess(self.lp)
        self.user_creds.set_workstation(self.netbios_name)
        self.user_creds.set_password(self.userpass)
        self.user_creds.set_username(self.username)
        self.user_creds.set_domain(self.domain)
        if self.prefer_kerberos:
            self.user_creds.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.user_creds.set_kerberos_state(DONT_USE_KERBEROS)

        # The bad password is the good one truncated by four characters.
        self.user_creds_bad = Credentials()
        self.user_creds_bad.guess(self.lp)
        self.user_creds_bad.set_workstation(self.netbios_name)
        self.user_creds_bad.set_password(self.userpass[:-4])
        self.user_creds_bad.set_username(self.username)
        if self.prefer_kerberos:
            self.user_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.user_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)

        # Credentials for ldap simple bind.
        self.simple_bind_creds = Credentials()
        self.simple_bind_creds.guess(self.lp)
        self.simple_bind_creds.set_workstation(self.netbios_name)
        self.simple_bind_creds.set_password(self.userpass)
        self.simple_bind_creds.set_username(self.username)
        self.simple_bind_creds.set_gensec_features(
            self.simple_bind_creds.get_gensec_features() | gensec.FEATURE_SEAL)
        if self.prefer_kerberos:
            self.simple_bind_creds.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.simple_bind_creds.set_kerberos_state(DONT_USE_KERBEROS)
        self.simple_bind_creds.set_bind_dn(self.user_dn)

        self.simple_bind_creds_bad = Credentials()
        self.simple_bind_creds_bad.guess(self.lp)
        self.simple_bind_creds_bad.set_workstation(self.netbios_name)
        self.simple_bind_creds_bad.set_password(self.userpass[:-4])
        self.simple_bind_creds_bad.set_username(self.username)
        self.simple_bind_creds_bad.set_gensec_features(
            self.simple_bind_creds_bad.get_gensec_features() |
            gensec.FEATURE_SEAL)
        if self.prefer_kerberos:
            self.simple_bind_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.simple_bind_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
        self.simple_bind_creds_bad.set_bind_dn(self.user_dn)

    def generate_machine_creds(self):
        """Generate the conversation specific machine Credentials.

        Each Conversation has an associated machine account.

        Generates machine credentials with good and bad passwords.
        """

        self.machine_creds = Credentials()
        self.machine_creds.guess(self.lp)
        self.machine_creds.set_workstation(self.netbios_name)
        self.machine_creds.set_secure_channel_type(SEC_CHAN_BDC)
        self.machine_creds.set_password(self.machinepass)
        self.machine_creds.set_username(self.netbios_name + "$")
        self.machine_creds.set_domain(self.domain)
        if self.prefer_kerberos:
            self.machine_creds.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.machine_creds.set_kerberos_state(DONT_USE_KERBEROS)

        # The bad password is the good one truncated by four characters.
        self.machine_creds_bad = Credentials()
        self.machine_creds_bad.guess(self.lp)
        self.machine_creds_bad.set_workstation(self.netbios_name)
        self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_BDC)
        self.machine_creds_bad.set_password(self.machinepass[:-4])
        self.machine_creds_bad.set_username(self.netbios_name + "$")
        if self.prefer_kerberos:
            self.machine_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.machine_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)

    def get_matching_dn(self, pattern, attributes=None):
        """Return a DN whose shape matches *pattern* (e.g. 'CN,CN,DC,DC').

        Attribute clues (currently invocationId) take precedence; failing
        those, progressively shorter patterns are tried, and the base_dn
        is the last resort.
        """
        # If the pattern is an empty string, we assume ROOTDSE,
        # Otherwise we try adding or removing DC suffixes, then
        # shorter leading patterns until we hit one.
        # e.g if there is no CN,CN,CN,CN,DC,DC
        # we first try       CN,CN,CN,CN,DC
        # and                CN,CN,CN,CN,DC,DC,DC
        # then change to        CN,CN,CN,DC,DC
        # and as last resort we use the base_dn
        attr_clue = self.attribute_clue_map.get(attributes)
        if attr_clue:
            return random.choice(attr_clue)

        pattern = pattern.upper()
        while pattern:
            if pattern in self.dn_map:
                return random.choice(self.dn_map[pattern])
            # chop one off the front and try it all again.
            pattern = pattern[3:]

        return self.base_dn

    def get_dcerpc_connection(self, new=False):
        """Return a cached (or new) raw dcerpc client connection."""
        guid = '12345678-1234-abcd-ef00-01234567cffb'  # RPC_NETLOGON UUID
        if self.dcerpc_connections and not new:
            return self.dcerpc_connections[-1]
        c = ClientConnection("ncacn_ip_tcp:%s" % self.server,
                             (guid, 1), self.lp)
        self.dcerpc_connections.append(c)
        return c

    def get_srvsvc_connection(self, new=False):
        """Return a cached (or new) srvsvc connection, possibly after a
        deliberate bad-password attempt."""
        if self.srvsvc_connections and not new:
            return self.srvsvc_connections[-1]

        def connect(creds):
            return srvsvc.srvsvc("ncacn_np:%s" % (self.server),
                                 self.lp,
                                 creds)

        (c, self.last_srvsvc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_srvsvc_bad)

        self.srvsvc_connections.append(c)
        return c

    def get_lsarpc_connection(self, new=False):
        """Return a cached (or new) schannel lsarpc connection."""
        if self.lsarpc_connections and not new:
            return self.lsarpc_connections[-1]

        def connect(creds):
            binding_options = 'schannel,seal,sign'
            return lsa.lsarpc("ncacn_ip_tcp:%s[%s]" %
                              (self.server, binding_options),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_bad)

        self.lsarpc_connections.append(c)
        return c

    def get_lsarpc_named_pipe_connection(self, new=False):
        """Return a cached (or new) named-pipe lsarpc connection."""
        if self.lsarpc_connections_named and not new:
            return self.lsarpc_connections_named[-1]

        def connect(creds):
            return lsa.lsarpc("ncacn_np:%s" % (self.server),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_named_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_named_bad)

        self.lsarpc_connections_named.append(c)
        return c

    def get_drsuapi_connection_pair(self, new=False, unbind=False):
        """get a (drs, drs_handle) tuple"""
        if self.drsuapi_connections and not new:
            c = self.drsuapi_connections[-1]
            return c

        def connect(creds):
            binding_options = 'seal'
            binding_string = "ncacn_ip_tcp:%s[%s]" %\
                             (self.server, binding_options)
            return drsuapi.drsuapi(binding_string, self.lp, creds)

        (drs, self.last_drsuapi_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_drsuapi_bad)

        (drs_handle, supported_extensions) = drs_DsBind(drs)
        c = (drs, drs_handle)
        self.drsuapi_connections.append(c)
        return c

    def get_ldap_connection(self, new=False, simple=False):
        """Return a cached (or new) SamDB LDAP connection, using either a
        simple bind (over ldaps) or a SASL bind."""
        if self.ldap_connections and not new:
            return self.ldap_connections[-1]

        def simple_bind(creds):
            """
            To run simple bind against Windows, we need to run
            following commands in PowerShell:

                Install-windowsfeature ADCS-Cert-Authority
                Install-AdcsCertificationAuthority -CAType EnterpriseRootCA
                Restart-Computer

            """
            return SamDB('ldaps://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)

        def sasl_bind(creds):
            return SamDB('ldap://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)
        if simple:
            (samdb, self.last_simple_bind_bad) = \
                self.with_random_bad_credentials(simple_bind,
                                                 self.simple_bind_creds,
                                                 self.simple_bind_creds_bad,
                                                 self.last_simple_bind_bad)
        else:
            (samdb, self.last_bind_bad) = \
                self.with_random_bad_credentials(sasl_bind,
                                                 self.user_creds,
                                                 self.user_creds_bad,
                                                 self.last_bind_bad)

        self.ldap_connections.append(samdb)
        return samdb

    def get_samr_context(self, new=False):
        """Return a cached (or new) SamrContext for this server."""
        if not self.samr_contexts or new:
            self.samr_contexts.append(
                SamrContext(self.server, lp=self.lp, creds=self.creds))
        return self.samr_contexts[-1]

    def get_netlogon_connection(self):
        """Return the (single, cached) schannel netlogon connection."""
        if self.netlogon_connection:
            return self.netlogon_connection

        def connect(creds):
            return netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" %
                                     (self.server),
                                     self.lp,
                                     creds)
        (c, self.last_netlogon_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_netlogon_bad)
        self.netlogon_connection = c
        return c

    def guess_a_dns_lookup(self):
        """Return a (name, type) pair for a plausible DNS query."""
        return (self.realm, 'A')

    def get_authenticator(self):
        """Return a (current, subsequent) netr_Authenticator pair."""
        auth = self.machine_creds.new_client_authenticator()
        current = netr_Authenticator()
        # credential bytes may arrive as ints (py3) or 1-char strings (py2)
        current.cred.data = [x if isinstance(x, int) else ord(x)
                             for x in auth["credential"]]
        current.timestamp = auth["timestamp"]

        subsequent = netr_Authenticator()
        return (current, subsequent)
732
733
class SamrContext(object):
    """State/Context associated with a samr connection.
    """
    def __init__(self, server, lp=None, creds=None):
        self.server        = server
        self.lp            = lp
        self.creds         = creds
        # Lazily created connection and handles.
        self.connection    = None
        self.handle        = None
        self.domain_handle = None
        self.domain_sid    = None
        self.group_handle  = None
        self.user_handle   = None
        self.rids          = None

    def get_connection(self):
        """Return the samr connection, creating it on first use."""
        if not self.connection:
            binding = "ncacn_ip_tcp:%s[seal]" % (self.server)
            self.connection = samr.samr(binding,
                                        lp_ctx=self.lp,
                                        credentials=self.creds)
        return self.connection

    def get_handle(self):
        """Return the samr connect handle, creating it on first use."""
        if not self.handle:
            conn = self.get_connection()
            self.handle = conn.Connect2(None,
                                        security.SEC_FLAG_MAXIMUM_ALLOWED)
        return self.handle
763
764
class Conversation(object):
    """Details of a conversation between a simulated client and a server."""
    # Assigned elsewhere when conversations are numbered for replay;
    # None until then.
    conversation_id = None

    def __init__(self, start_time=None, endpoints=None):
        # start_time/endpoints may be None; they are adopted from the
        # first packet added (see add_packet).
        self.start_time = start_time
        self.endpoints = endpoints
        self.packets = []
        self.msg = random_colour_print()  # per-conversation coloured logger
        # Accumulated client/server evidence from packet clues; its sign
        # indicates which endpoint looks like the client (see add_packet).
        self.client_balance = 0.0
775
776     def __cmp__(self, other):
777         if self.start_time is None:
778             if other.start_time is None:
779                 return 0
780             return -1
781         if other.start_time is None:
782             return 1
783         return self.start_time - other.start_time
784
785     def add_packet(self, packet):
786         """Add a packet object to this conversation, making a local copy with
787         a conversation-relative timestamp."""
788         p = packet.copy()
789
790         if self.start_time is None:
791             self.start_time = p.timestamp
792
793         if self.endpoints is None:
794             self.endpoints = p.endpoints
795
796         if p.endpoints != self.endpoints:
797             raise FakePacketError("Conversation endpoints %s don't match"
798                                   "packet endpoints %s" %
799                                   (self.endpoints, p.endpoints))
800
801         p.timestamp -= self.start_time
802
803         if p.src == p.endpoints[0]:
804             self.client_balance -= p.client_score()
805         else:
806             self.client_balance += p.client_score()
807
808         if p.is_really_a_packet():
809             self.packets.append(p)
810
811     def add_short_packet(self, timestamp, protocol, opcode, extra,
812                          client=True):
813         """Create a packet from a timestamp, and 'protocol:opcode' pair, and a
814         (possibly empty) list of extra data. If client is True, assume
815         this packet is from the client to the server.
816         """
817         src, dest = self.guess_client_server()
818         if not client:
819             src, dest = dest, src
820         key = (protocol, opcode)
821         desc = OP_DESCRIPTIONS[key] if key in OP_DESCRIPTIONS else ''
822         if protocol in IP_PROTOCOLS:
823             ip_protocol = IP_PROTOCOLS[protocol]
824         else:
825             ip_protocol = '06'
826         packet = Packet(timestamp - self.start_time, ip_protocol,
827                         '', src, dest,
828                         protocol, opcode, desc, extra)
829         # XXX we're assuming the timestamp is already adjusted for
830         # this conversation?
831         # XXX should we adjust client balance for guessed packets?
832         if packet.src == packet.endpoints[0]:
833             self.client_balance -= packet.client_score()
834         else:
835             self.client_balance += packet.client_score()
836         if packet.is_really_a_packet():
837             self.packets.append(packet)
838
839     def __str__(self):
840         return ("<Conversation %s %s starting %.3f %d packets>" %
841                 (self.conversation_id, self.endpoints, self.start_time,
842                  len(self.packets)))
843
844     __repr__ = __str__
845
846     def __iter__(self):
847         return iter(self.packets)
848
849     def __len__(self):
850         return len(self.packets)
851
852     def get_duration(self):
853         if len(self.packets) < 2:
854             return 0
855         return self.packets[-1].timestamp - self.packets[0].timestamp
856
857     def replay_as_summary_lines(self):
858         lines = []
859         for p in self.packets:
860             lines.append(p.as_summary(self.start_time))
861         return lines
862
863     def replay_in_fork_with_delay(self, start, context=None, account=None):
864         """Fork a new process and replay the conversation.
865         """
866         def signal_handler(signal, frame):
867             """Signal handler closes standard out and error.
868
869             Triggered by a sigterm, ensures that the log messages are flushed
870             to disk and not lost.
871             """
872             sys.stderr.close()
873             sys.stdout.close()
874             os._exit(0)
875
876         t = self.start_time
877         now = time.time() - start
878         gap = t - now
879         # we are replaying strictly in order, so it is safe to sleep
880         # in the main process if the gap is big enough. This reduces
881         # the number of concurrent threads, which allows us to make
882         # larger loads.
883         if gap > 0.15 and False:
884             print("sleeping for %f in main process" % (gap - 0.1),
885                   file=sys.stderr)
886             time.sleep(gap - 0.1)
887             now = time.time() - start
888             gap = t - now
889             print("gap is now %f" % gap, file=sys.stderr)
890
891         self.conversation_id = next(context.next_conversation_id)
892         pid = os.fork()
893         if pid != 0:
894             return pid
895         pid = os.getpid()
896         signal.signal(signal.SIGTERM, signal_handler)
897         # we must never return, or we'll end up running parts of the
898         # parent's clean-up code. So we work in a try...finally, and
899         # try to print any exceptions.
900
901         try:
902             context.generate_process_local_config(account, self)
903             sys.stdin.close()
904             os.close(0)
905             filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
906                                     self.conversation_id)
907             sys.stdout.close()
908             sys.stdout = open(filename, 'w')
909
910             sleep_time = gap - SLEEP_OVERHEAD
911             if sleep_time > 0:
912                 time.sleep(sleep_time)
913
914             miss = t - (time.time() - start)
915             self.msg("starting %s [miss %.3f pid %d]" % (self, miss, pid))
916             self.replay(context)
917         except Exception:
918             print(("EXCEPTION in child PID %d, conversation %s" % (pid, self)),
919                   file=sys.stderr)
920             traceback.print_exc(sys.stderr)
921         finally:
922             sys.stderr.close()
923             sys.stdout.close()
924             os._exit(0)
925
926     def replay(self, context=None):
927         start = time.time()
928
929         for p in self.packets:
930             now = time.time() - start
931             gap = p.timestamp - now
932             sleep_time = gap - SLEEP_OVERHEAD
933             if sleep_time > 0:
934                 time.sleep(sleep_time)
935
936             miss = p.timestamp - (time.time() - start)
937             if context is None:
938                 self.msg("packet %s [miss %.3f pid %d]" % (p, miss,
939                                                            os.getpid()))
940                 continue
941             p.play(self, context)
942
943     def guess_client_server(self, server_clue=None):
944         """Have a go at deciding who is the server and who is the client.
945         returns (client, server)
946         """
947         a, b = self.endpoints
948
949         if self.client_balance < 0:
950             return (a, b)
951
952         # in the absense of a clue, we will fall through to assuming
953         # the lowest number is the server (which is usually true).
954
955         if self.client_balance == 0 and server_clue == b:
956             return (a, b)
957
958         return (b, a)
959
960     def forget_packets_outside_window(self, s, e):
961         """Prune any packets outside the timne window we're interested in
962
963         :param s: start of the window
964         :param e: end of the window
965         """
966         self.packets = [p for p in self.packets if s <= p.timestamp <= e]
967         self.start_time = self.packets[0].timestamp if self.packets else None
968
969     def renormalise_times(self, start_time):
970         """Adjust the packet start times relative to the new start time."""
971         for p in self.packets:
972             p.timestamp -= start_time
973
974         if self.start_time is not None:
975             self.start_time -= start_time
976
977
class DnsHammer(Conversation):
    """A lightweight conversation that generates a lot of dns:0 packets on
    the fly"""

    def __init__(self, dns_rate, duration):
        # schedule roughly dns_rate queries per second at uniformly
        # random offsets across the duration
        count = int(dns_rate * duration)
        self.times = sorted(random.uniform(0, duration)
                            for _ in range(count))
        self.rate = dns_rate
        self.duration = duration
        self.start_time = 0
        self.msg = random_colour_print()

    def __str__(self):
        return ("<DnsHammer %d packets over %.1fs (rate %.2f)>" %
                (len(self.times), self.duration, self.rate))

    def replay_in_fork_with_delay(self, start, context=None, account=None):
        # nothing special to do; use the plain Conversation fork logic
        return Conversation.replay_in_fork_with_delay(self, start, context,
                                                      account)

    def replay(self, context=None):
        """Fire off the scheduled DNS queries, sleeping between them to
        keep to the schedule. With no context, only log the schedule."""
        start = time.time()
        fn = traffic_packets.packet_dns_0
        for when in self.times:
            gap = when - (time.time() - start)
            sleep_time = gap - SLEEP_OVERHEAD
            if sleep_time > 0:
                time.sleep(sleep_time)

            if context is None:
                miss = when - (time.time() - start)
                self.msg("packet %s [miss %.3f pid %d]" % (when, miss,
                                                           os.getpid()))
                continue

            packet_start = time.time()
            try:
                fn(self, self, context)
            except Exception as e:
                end = time.time()
                duration = end - packet_start
                print("%f\tDNS\tdns\t0\t%f\tFalse\t%s" % (end, duration, e))
            else:
                end = time.time()
                duration = end - packet_start
                print("%f\tDNS\tdns\t0\t%f\tTrue\t" % (end, duration))
1027
1028
def ingest_summaries(files, dns_mode='count'):
    """Load summary traffic files and generate Conversations from them.

    :param files: a list of filenames or open file objects
    :param dns_mode: unless 'include', DNS packets are tallied separately
        rather than kept in their conversations
    :return: a tuple of (conversations, mean_interval, duration,
        dns_counts) -- the same shape on every path, so callers can
        always unpack four values.
    """

    dns_counts = defaultdict(int)
    packets = []
    for f in files:
        if isinstance(f, str):
            f = open(f)
        print("Ingesting %s" % (f.name,), file=sys.stderr)
        for line in f:
            p = Packet.from_line(line)
            if p.protocol == 'dns' and dns_mode != 'include':
                dns_counts[p.opcode] += 1
            else:
                packets.append(p)

        f.close()

    if not packets:
        # previously this returned a 2-tuple, which broke callers
        # unpacking the normal 4-tuple
        return [], 0, 0, dns_counts

    start_time = min(p.timestamp for p in packets)
    last_packet = max(p.timestamp for p in packets)

    print("gathering packets into conversations", file=sys.stderr)
    conversations = OrderedDict()
    for p in packets:
        p.timestamp -= start_time
        c = conversations.get(p.endpoints)
        if c is None:
            c = Conversation()
            conversations[p.endpoints] = c
        c.add_packet(p)

    # We only care about conversations with actual traffic, so we
    # filter out conversations with nothing to say. We do that here,
    # rather than earlier, because those empty packets contain useful
    # hints as to which end of the conversation was the client.
    conversation_list = []
    for c in conversations.values():
        if len(c) != 0:
            conversation_list.append(c)

    # This is obviously not correct, as many conversations will appear
    # to start roughly simultaneously at the beginning of the snapshot.
    # To which we say: oh well, so be it.
    duration = float(last_packet - start_time)
    # NOTE(review): this is conversations per second, despite the name
    mean_interval = len(conversations) / duration

    return conversation_list, mean_interval, duration, dns_counts
1080
1081
def guess_server_address(conversations):
    """Return the single most common endpoint address, which is almost
    certainly the server.

    Returns None if there are no conversations.
    """
    addresses = Counter()
    for c in conversations:
        addresses.update(c.endpoints)
    if addresses:
        # most_common(1) gives [(address, count)]; callers (e.g. the
        # server_clue compared against endpoints in guess_client_server)
        # expect the bare address, not the (address, count) pair.
        return addresses.most_common(1)[0][0]
1089
1090
def stringify_keys(x):
    """Return a copy of dict *x* with each tuple key joined into a
    tab-separated string (the inverse of unstringify_keys)."""
    return {'\t'.join(key): value for key, value in x.items()}
1097
1098
def unstringify_keys(x):
    """Return a copy of dict *x* with each tab-separated string key split
    back into a tuple (the inverse of stringify_keys)."""
    return {tuple(str(key).split('\t')): value for key, value in x.items()}
1105
1106
class TrafficModel(object):
    """A simple model of the packet traffic between clients and a server,
    learnt from traffic summaries.

    The model is an n-gram chain over packet types (keeping repeats, so
    random.choice reproduces the learnt distribution), plus the observed
    extra-argument tuples for each packet type and overall timing data.
    It can then generate synthetic conversations with similar statistics.
    """
    def __init__(self, n=3):
        # (n-1)-tuple of preceding states -> list of observed next states
        self.ngrams = {}
        # packet type -> list of observed extra-argument tuples
        self.query_details = {}
        self.n = n
        self.dns_opcounts = defaultdict(int)
        self.cumulative_duration = 0.0
        # [conversation count, elapsed seconds] -- kept as a fraction
        self.conversation_rate = [0, 1]

    def learn(self, conversations, dns_opcounts=None):
        """Absorb a list of conversations into the model.

        :param conversations: a list of Conversation objects
        :param dns_opcounts: an optional mapping of DNS opcode to count,
            added to the model's tally
        """
        # avoid the mutable-default-argument trap ({} was previously the
        # default value, shared between calls)
        if dns_opcounts is None:
            dns_opcounts = {}
        prev = 0.0
        cum_duration = 0.0
        key = (NON_PACKET,) * (self.n - 1)

        server = guess_server_address(conversations)

        for k, v in dns_opcounts.items():
            self.dns_opcounts[k] += v

        if len(conversations) > 1:
            elapsed =\
                conversations[-1].start_time - conversations[0].start_time
            self.conversation_rate[0] = len(conversations)
            self.conversation_rate[1] = elapsed

        for c in conversations:
            client, server = c.guess_client_server(server)
            cum_duration += c.get_duration()
            key = (NON_PACKET,) * (self.n - 1)
            for p in c:
                if p.src != client:
                    continue

                # NOTE(review): prev is not reset per conversation even
                # though timestamps are conversation-relative -- confirm
                # this is intended.
                elapsed = p.timestamp - prev
                prev = p.timestamp
                if elapsed > WAIT_THRESHOLD:
                    # add the wait as an extra state
                    wait = 'wait:%d' % (math.log(max(1.0,
                                                     elapsed * WAIT_SCALE)))
                    self.ngrams.setdefault(key, []).append(wait)
                    key = key[1:] + (wait,)

                short_p = p.as_packet_type()
                self.query_details.setdefault(short_p,
                                              []).append(tuple(p.extra))
                self.ngrams.setdefault(key, []).append(short_p)
                key = key[1:] + (short_p,)

        self.cumulative_duration += cum_duration
        # add in the end
        self.ngrams.setdefault(key, []).append(NON_PACKET)

    def save(self, f):
        """Save the model as JSON to a filename or open file object."""
        ngrams = {}
        for k, v in self.ngrams.items():
            k = '\t'.join(k)
            ngrams[k] = dict(Counter(v))

        query_details = {}
        for k, v in self.query_details.items():
            query_details[k] = dict(Counter('\t'.join(x) if x else '-'
                                            for x in v))

        d = {
            'ngrams': ngrams,
            'query_details': query_details,
            'cumulative_duration': self.cumulative_duration,
            'conversation_rate': self.conversation_rate,
        }
        d['dns'] = self.dns_opcounts

        if isinstance(f, str):
            # don't leak the handle when we opened the file ourselves
            with open(f, 'w') as fh:
                json.dump(d, fh, indent=2)
        else:
            json.dump(d, f, indent=2)

    def load(self, f):
        """Load a model previously written by save() from a filename or
        open file object, merging it into this model."""
        if isinstance(f, str):
            # don't leak the handle when we opened the file ourselves
            with open(f) as fh:
                d = json.load(fh)
        else:
            d = json.load(f)

        for k, v in d['ngrams'].items():
            k = tuple(str(k).split('\t'))
            values = self.ngrams.setdefault(k, [])
            for p, count in v.items():
                values.extend([str(p)] * count)

        for k, v in d['query_details'].items():
            values = self.query_details.setdefault(str(k), [])
            for p, count in v.items():
                if p == '-':
                    values.extend([()] * count)
                else:
                    values.extend([tuple(str(p).split('\t'))] * count)

        if 'dns' in d:
            for k, v in d['dns'].items():
                self.dns_opcounts[k] += v

        self.cumulative_duration = d['cumulative_duration']
        self.conversation_rate = d['conversation_rate']

    def construct_conversation(self, timestamp=0.0, client=2, server=1,
                               hard_stop=None, packet_rate=1):
        """Construct an individual conversation from the model.

        :param timestamp: the conversation start time
        :param client: the client address token
        :param server: the server address token
        :param hard_stop: stop adding packets after this time
        :param packet_rate: scale factor applied to inter-packet waits
        """
        c = Conversation(timestamp, (server, client))

        key = (NON_PACKET,) * (self.n - 1)

        while key in self.ngrams:
            p = random.choice(self.ngrams.get(key, NON_PACKET))
            if p == NON_PACKET:
                break
            if p in self.query_details:
                extra = random.choice(self.query_details[p])
            else:
                extra = []

            protocol, opcode = p.split(':', 1)
            if protocol == 'wait':
                # a learnt wait state: sleep, emit no packet
                log_wait_time = int(opcode) + random.random()
                wait = math.exp(log_wait_time) / (WAIT_SCALE * packet_rate)
                timestamp += wait
            else:
                log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
                wait = math.exp(log_wait) / packet_rate
                timestamp += wait
                if hard_stop is not None and timestamp > hard_stop:
                    break
                c.add_short_packet(timestamp, protocol, opcode, extra)

            key = key[1:] + (p,)

        return c

    def generate_conversations(self, rate, duration, packet_rate=1):
        """Generate a list of conversations from the model."""

        # We run the simulation for at least ten times as long as our
        # desired duration, and take a section near the start.
        rate_n, rate_t = self.conversation_rate

        duration2 = max(rate_t, duration * 2)
        n = rate * duration2 * rate_n / rate_t

        server = 1
        client = 2

        conversations = []
        end = duration2

        while client < n + 2:
            start = random.uniform(0, duration2)
            c = self.construct_conversation(start,
                                            client,
                                            server,
                                            hard_stop=(duration2 * 5),
                                            packet_rate=packet_rate)

            c.forget_packets_outside_window(start, end)
            c.renormalise_times(start)
            if len(c) != 0:
                conversations.append(c)
            client += 1

        print(("we have %d conversations at rate %f" %
               (len(conversations), rate)), file=sys.stderr)
        conversations.sort()
        return conversations
1279
1280
# Map each summary protocol name to its IP protocol number as a two-digit
# hex string: '06' is TCP, '11' is UDP (0x11 == 17). Used when expanding
# short packets into summary lines.
IP_PROTOCOLS = {
    'dns': '11',
    'rpc_netlogon': '06',
    'kerberos': '06',      # ratio 16248:258
    'smb': '06',
    'smb2': '06',
    'ldap': '06',
    'cldap': '11',
    'lsarpc': '06',
    'samr': '06',
    'dcerpc': '06',
    'epm': '06',
    'drsuapi': '06',
    'browser': '11',
    'smb_netlogon': '11',
    'srvsvc': '06',
    'nbns': '11',
}
1299
# Map a (protocol, opcode) pair to its human-readable operation name, as
# seen in traffic summaries. Used as the `desc` field when constructing
# packets from the model.
OP_DESCRIPTIONS = {
    ('browser', '0x01'): 'Host Announcement (0x01)',
    ('browser', '0x02'): 'Request Announcement (0x02)',
    ('browser', '0x08'): 'Browser Election Request (0x08)',
    ('browser', '0x09'): 'Get Backup List Request (0x09)',
    ('browser', '0x0c'): 'Domain/Workgroup Announcement (0x0c)',
    ('browser', '0x0f'): 'Local Master Announcement (0x0f)',
    ('cldap', '3'): 'searchRequest',
    ('cldap', '5'): 'searchResDone',
    ('dcerpc', '0'): 'Request',
    ('dcerpc', '11'): 'Bind',
    ('dcerpc', '12'): 'Bind_ack',
    ('dcerpc', '13'): 'Bind_nak',
    ('dcerpc', '14'): 'Alter_context',
    ('dcerpc', '15'): 'Alter_context_resp',
    ('dcerpc', '16'): 'AUTH3',
    ('dcerpc', '2'): 'Response',
    ('dns', '0'): 'query',
    ('dns', '1'): 'response',
    ('drsuapi', '0'): 'DsBind',
    ('drsuapi', '12'): 'DsCrackNames',
    ('drsuapi', '13'): 'DsWriteAccountSpn',
    ('drsuapi', '1'): 'DsUnbind',
    ('drsuapi', '2'): 'DsReplicaSync',
    ('drsuapi', '3'): 'DsGetNCChanges',
    ('drsuapi', '4'): 'DsReplicaUpdateRefs',
    ('epm', '3'): 'Map',
    ('kerberos', ''): '',
    ('ldap', '0'): 'bindRequest',
    ('ldap', '1'): 'bindResponse',
    ('ldap', '2'): 'unbindRequest',
    ('ldap', '3'): 'searchRequest',
    ('ldap', '4'): 'searchResEntry',
    ('ldap', '5'): 'searchResDone',
    ('ldap', ''): '*** Unknown ***',
    ('lsarpc', '14'): 'lsa_LookupNames',
    ('lsarpc', '15'): 'lsa_LookupSids',
    ('lsarpc', '39'): 'lsa_QueryTrustedDomainInfoBySid',
    ('lsarpc', '40'): 'lsa_SetTrustedDomainInfo',
    ('lsarpc', '6'): 'lsa_OpenPolicy',
    ('lsarpc', '76'): 'lsa_LookupSids3',
    ('lsarpc', '77'): 'lsa_LookupNames4',
    ('nbns', '0'): 'query',
    ('nbns', '1'): 'response',
    ('rpc_netlogon', '21'): 'NetrLogonDummyRoutine1',
    ('rpc_netlogon', '26'): 'NetrServerAuthenticate3',
    ('rpc_netlogon', '29'): 'NetrLogonGetDomainInfo',
    ('rpc_netlogon', '30'): 'NetrServerPasswordSet2',
    ('rpc_netlogon', '39'): 'NetrLogonSamLogonEx',
    ('rpc_netlogon', '40'): 'DsrEnumerateDomainTrusts',
    ('rpc_netlogon', '45'): 'NetrLogonSamLogonWithFlags',
    ('rpc_netlogon', '4'): 'NetrServerReqChallenge',
    ('samr', '0',): 'Connect',
    ('samr', '16'): 'GetAliasMembership',
    ('samr', '17'): 'LookupNames',
    ('samr', '18'): 'LookupRids',
    ('samr', '19'): 'OpenGroup',
    ('samr', '1'): 'Close',
    ('samr', '25'): 'QueryGroupMember',
    ('samr', '34'): 'OpenUser',
    ('samr', '36'): 'QueryUserInfo',
    ('samr', '39'): 'GetGroupsForUser',
    ('samr', '3'): 'QuerySecurity',
    ('samr', '5'): 'LookupDomain',
    ('samr', '64'): 'Connect5',
    ('samr', '6'): 'EnumDomains',
    ('samr', '7'): 'OpenDomain',
    ('samr', '8'): 'QueryDomainInfo',
    ('smb', '0x04'): 'Close (0x04)',
    ('smb', '0x24'): 'Locking AndX (0x24)',
    ('smb', '0x2e'): 'Read AndX (0x2e)',
    ('smb', '0x32'): 'Trans2 (0x32)',
    ('smb', '0x71'): 'Tree Disconnect (0x71)',
    ('smb', '0x72'): 'Negotiate Protocol (0x72)',
    ('smb', '0x73'): 'Session Setup AndX (0x73)',
    ('smb', '0x74'): 'Logoff AndX (0x74)',
    ('smb', '0x75'): 'Tree Connect AndX (0x75)',
    ('smb', '0xa2'): 'NT Create AndX (0xa2)',
    ('smb2', '0'): 'NegotiateProtocol',
    ('smb2', '11'): 'Ioctl',
    ('smb2', '14'): 'Find',
    ('smb2', '16'): 'GetInfo',
    ('smb2', '18'): 'Break',
    ('smb2', '1'): 'SessionSetup',
    ('smb2', '2'): 'SessionLogoff',
    ('smb2', '3'): 'TreeConnect',
    ('smb2', '4'): 'TreeDisconnect',
    ('smb2', '5'): 'Create',
    ('smb2', '6'): 'Close',
    ('smb2', '8'): 'Read',
    ('smb_netlogon', '0x12'): 'SAM LOGON request from client (0x12)',
    ('smb_netlogon', '0x17'): ('SAM Active Directory Response - '
                               'user unknown (0x17)'),
    ('srvsvc', '16'): 'NetShareGetInfo',
    ('srvsvc', '21'): 'NetSrvGetInfo',
}
1396
1397
def expand_short_packet(p, timestamp, src, dest, extra):
    """Expand a 'protocol:opcode' short packet into a tab-separated
    summary line."""
    protocol, opcode = p.split(':', 1)
    desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
    ip_protocol = IP_PROTOCOLS.get(protocol, '06')

    fields = [timestamp, ip_protocol, '', src, dest, protocol, opcode, desc]
    fields.extend(extra)
    return '\t'.join(fields)
1406
1407
def replay(conversations,
           host=None,
           creds=None,
           lp=None,
           accounts=None,
           dns_rate=0,
           duration=None,
           **kwargs):
    """Replay the conversations against host, forking one child process
    per conversation.

    :param conversations: a list of Conversation objects
    :param host: the server to replay against
    :param creds: credentials for the ReplayContext
    :param lp: a loadparm context
    :param accounts: a list of ConversationAccounts, paired one-to-one
        with the conversations
    :param dns_rate: if non-zero, also run a DnsHammer at this rate
    :param duration: stop after this many seconds (default: one second
        after the last conversation starts)

    Remaining keyword arguments are passed to ReplayContext.
    """
    context = ReplayContext(server=host,
                            creds=creds,
                            lp=lp,
                            **kwargs)

    if len(accounts) < len(conversations):
        # %d needs the counts, not the lists themselves
        print(("we have %d accounts but %d conversations" %
               (len(accounts), len(conversations))), file=sys.stderr)

    cstack = list(zip(
        sorted(conversations, key=lambda x: x.start_time, reverse=True),
        accounts))

    # Set the process group so that the calling scripts are not killed
    # when the forked child processes are killed.
    os.setpgrp()

    start = time.time()

    if duration is None:
        # end 1 second after the last packet of the last conversation
        # to start. Conversations other than the last could still be
        # going, but we don't care.
        duration = cstack[0][0].packets[-1].timestamp + 1.0
        print("We will stop after %.1f seconds" % duration,
              file=sys.stderr)

    end = start + duration

    LOGGER.info("Replaying traffic for %u conversations over %d seconds"
          % (len(conversations), duration))

    children = {}
    if dns_rate:
        dns_hammer = DnsHammer(dns_rate, duration)
        cstack.append((dns_hammer, None))

    try:
        while True:
            # we spawn a batch, wait for finishers, then spawn another
            now = time.time()
            batch_end = min(now + 2.0, end)
            fork_time = 0.0
            fork_n = 0
            while cstack:
                c, account = cstack.pop()
                if c.start_time + start > batch_end:
                    # not due yet; push it back for the next batch
                    cstack.append((c, account))
                    break

                st = time.time()
                pid = c.replay_in_fork_with_delay(start, context, account)
                children[pid] = c
                t = time.time()
                elapsed = t - st
                fork_time += elapsed
                fork_n += 1
                print("forked %s in pid %s (in %fs)" % (c, pid,
                                                        elapsed),
                      file=sys.stderr)

            if fork_n:
                print(("forked %d times in %f seconds (avg %f)" %
                       (fork_n, fork_time, fork_time / fork_n)),
                      file=sys.stderr)
            elif cstack:
                debug(2, "no forks in batch ending %f" % batch_end)

            while time.time() < batch_end - 1.0:
                time.sleep(0.01)
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != 10:  # ECHILD: no child processes
                        raise
                    break
                if pid:
                    c = children.pop(pid, None)
                    print(("process %d finished conversation %s;"
                           " %d to go" %
                           (pid, c, len(children))), file=sys.stderr)

            if time.time() >= end:
                print("time to stop", file=sys.stderr)
                break

    except Exception:
        print("EXCEPTION in parent", file=sys.stderr)
        traceback.print_exc()
    finally:
        # ask nicely twice (SIGTERM), then insist (SIGKILL)
        for s in (15, 15, 9):
            print(("killing %d children with -%d" %
                   (len(children), s)), file=sys.stderr)
            for pid in children:
                try:
                    os.kill(pid, s)
                except OSError as e:
                    if e.errno != 3:  # ESRCH: already died, which is fine
                        raise
            time.sleep(0.5)
            end = time.time() + 1
            while children:
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != 10:
                        raise
                if pid != 0:
                    c = children.pop(pid, None)
                    print(("kill -%d %d KILLED conversation %s; "
                           "%d to go" %
                           (s, pid, c, len(children))),
                          file=sys.stderr)
                if time.time() >= end:
                    break

            if not children:
                break
            time.sleep(1)

        if children:
            print("%d children are missing" % len(children),
                  file=sys.stderr)

        # there may be stragglers that were forked just as ^C was hit
        # and don't appear in the list of children. We can get them
        # with killpg, but that will also kill us, so this is^H^H would be
        # goodbye, except we cheat and pretend to use ^C (SIG_INTERRUPT),
        # so as not to have to fuss around writing signal handlers.
        try:
            os.killpg(0, 2)
        except KeyboardInterrupt:
            print("ignoring fake ^C", file=sys.stderr)
1550
1551
def openLdb(host, creds, lp):
    """Open and return a SamDB LDAP connection to *host*."""
    # note: local renamed from `ldb` to avoid shadowing the ldb module
    samdb = SamDB(url="ldap://%s" % host,
                  session_info=system_session(),
                  options=['modules:paged_searches'],
                  credentials=creds,
                  lp=lp)
    return samdb
1560
1561
def ou_name(ldb, instance_id):
    """Generate the DN of the per-instance traffic_replay OU."""
    base_dn = ldb.domain_dn()
    return "ou=instance-%d,ou=traffic_replay,%s" % (instance_id, base_dn)
1566
1567
def _add_ou_ignoring_exists(ldb, dn):
    """Add an organizationalunit at dn, ignoring 'already exists' errors."""
    try:
        ldb.add({"dn": dn,
                 "objectclass": "organizationalunit"})
    except LdbError as e:
        (status, _) = e.args
        # 68 is LDB_ERR_ENTRY_ALREADY_EXISTS; a pre-existing OU is fine
        if status != 68:
            raise


def create_ou(ldb, instance_id):
    """Create an ou, all created user and machine accounts will belong to it.

    This allows all the created resources to be cleaned up easily.
    """
    ou = ou_name(ldb, instance_id)
    # create the parent ou=traffic_replay container first, then the
    # per-instance child OU
    _add_ou_ignoring_exists(ldb, ou.split(',', 1)[1])
    _add_ou_ignoring_exists(ldb, ou)
    return ou
1591
1592
class ConversationAccounts(object):
    """The machine and user accounts associated with a conversation."""

    def __init__(self, netbios_name, machinepass, username, userpass):
        # plain value holder; attributes are read directly by the
        # replay machinery
        self.netbios_name = netbios_name
        self.machinepass = machinepass
        self.username = username
        self.userpass = userpass
1601
1602
def generate_replay_accounts(ldb, instance_id, number, password):
    """Generate a series of unique machine and user account names."""
    # make sure the accounts exist on the server first
    generate_traffic_accounts(ldb, instance_id, number, password)
    return [ConversationAccounts(machine_name(instance_id, i),
                                 password,
                                 user_name(instance_id, i),
                                 password)
            for i in range(1, number + 1)]
1616
1617
def generate_traffic_accounts(ldb, instance_id, number, password):
    """Create the specified number of user and machine accounts.

    As accounts are not explicitly deleted between runs. This function starts
    with the last account and iterates backwards stopping either when it
    finds an already existing account or it has generated all the required
    accounts.
    """
    print(("Generating machine and conversation accounts, "
           "as required for %d conversations" % number),
          file=sys.stderr)

    def _create_accounts(name_for, create, progress_noun, summary_noun):
        """Create accounts number..1, stopping at the first existing one.

        name_for maps (instance_id, i) to an account name, create adds the
        account via ldap. Returns how many accounts were actually added.
        """
        added = 0
        for i in range(number, 0, -1):
            try:
                create(ldb, instance_id, name_for(instance_id, i), password)
                added += 1
                if added % 50 == 0:
                    LOGGER.info("Created %u/%u %s"
                                % (added, number, progress_noun))
            except LdbError as e:
                (status, _) = e.args
                # LDB error 68 is "entry already exists": accounts at and
                # below this index were created by an earlier run, so we
                # can stop early.
                if status == 68:
                    break
                else:
                    raise
        if added > 0:
            LOGGER.info("Added %d new %s" % (added, summary_noun))
        return added

    _create_accounts(machine_name, create_machine_account,
                     "machine accounts", "machine accounts")
    _create_accounts(user_name, create_user_account,
                     "users", "user accounts")
1664
1665
def create_machine_account(ldb, instance_id, netbios_name, machinepass,
                           traffic_account=True):
    """Create a machine account via ldap."""

    dn = "cn=%s,%s" % (netbios_name, ou_name(ldb, instance_id))
    utf16pw = ('"%s"' % get_string(machinepass)).encode('utf-16-le')

    if traffic_account:
        # we set these bits for the machine account otherwise the replayed
        # traffic throws up NT_STATUS_NO_TRUST_SAM_ACCOUNT errors
        uac = UF_TRUSTED_FOR_DELEGATION | UF_SERVER_TRUST_ACCOUNT
    else:
        uac = UF_WORKSTATION_TRUST_ACCOUNT

    ldb.add({
        "dn": dn,
        "objectclass": "computer",
        "sAMAccountName": "%s$" % netbios_name,
        "userAccountControl": str(uac),
        "unicodePwd": utf16pw})
1689
1690
def create_user_account(ldb, instance_id, username, userpass):
    """Create a user account via ldap."""
    user_dn = "cn=%s,%s" % (username, ou_name(ldb, instance_id))
    utf16pw = ('"%s"' % get_string(userpass)).encode('utf-16-le')
    ldb.add({
        "dn": user_dn,
        "objectclass": "user",
        "sAMAccountName": username,
        "userAccountControl": str(UF_NORMAL_ACCOUNT),
        "unicodePwd": utf16pw
    })

    # grant user write permission to do things like write account SPN
    sd_utils.SDUtils(ldb).dacl_add_ace(user_dn, "(A;;WP;;;PS)")
1707
1708
def create_group(ldb, instance_id, name):
    """Create a group via ldap."""

    group_dn = "cn=%s,%s" % (name, ou_name(ldb, instance_id))
    ldb.add({
        "dn": group_dn,
        "objectclass": "group",
        "sAMAccountName": name,
    })
1719
1720
def user_name(instance_id, i):
    """Generate a user name based in the instance id"""
    return "STGU-{0:d}-{1:d}".format(instance_id, i)
1724
1725
def search_objectclass(ldb, objectclass='user', attr='sAMAccountName'):
    """Search for objects of the given objectclass, returning the requested
    attribute values as a set of strings."""
    expression = "(objectClass=%s)" % objectclass
    results = ldb.search(expression=expression, attrs=[attr])
    found = set()
    for result in results:
        found.add(str(result[attr]))
    return found
1733
1734
def generate_users(ldb, instance_id, number, password):
    """Add users to the server"""
    existing = search_objectclass(ldb, objectclass='user')
    created = 0
    for i in range(number, 0, -1):
        username = user_name(instance_id, i)
        if username in existing:
            continue
        create_user_account(ldb, instance_id, username, password)
        created += 1
        if created % 50 == 0:
            LOGGER.info("Created %u/%u users" % (created, number))

    return created
1748
1749
def machine_name(instance_id, i):
    """Generate a machine account name from instance id."""
    return "STGM-{0:d}-{1:d}".format(instance_id, i)
1753
1754
def generate_machine_accounts(ldb, instance_id, number, password,
                              traffic_account=True):
    """Add machine accounts to the server"""
    existing = search_objectclass(ldb, objectclass='computer')
    created = 0
    for i in range(number, 0, -1):
        name = machine_name(instance_id, i)
        # computer sAMAccountNames carry a trailing '$'
        if name + "$" in existing:
            continue
        create_machine_account(ldb, instance_id, name, password,
                               traffic_account)
        created += 1
        if created % 50 == 0:
            LOGGER.info("Created %u/%u machine accounts" % (created, number))

    return created
1770
1771
def group_name(instance_id, i):
    """Generate a group name from instance id."""
    return "STGG-{0:d}-{1:d}".format(instance_id, i)
1775
1776
def generate_groups(ldb, instance_id, number):
    """Create the required number of groups on the server."""
    existing = search_objectclass(ldb, objectclass='group')
    created = 0
    for i in range(number, 0, -1):
        name = group_name(instance_id, i)
        if name in existing:
            continue
        create_group(ldb, instance_id, name)
        created += 1
        if created % 1000 == 0:
            LOGGER.info("Created %u/%u groups" % (created, number))

    return created
1790
1791
def clean_up_accounts(ldb, instance_id):
    """Remove the created accounts and groups from the server."""
    try:
        # delete the whole instance ou, taking its children with it
        ldb.delete(ou_name(ldb, instance_id), ["tree_delete:1"])
    except LdbError as e:
        (status, _) = e.args
        # ignore does not exist (LDB error 32)
        if status != 32:
            raise
1802
1803
def generate_users_and_groups(ldb, instance_id, password,
                              number_of_users, number_of_groups,
                              group_memberships, machine_accounts=0,
                              traffic_accounts=True):
    """Generate the required users and groups, allocating the users to
       those groups."""
    n_memberships = 0
    n_groups = 0
    n_computers = 0

    create_ou(ldb, instance_id)

    LOGGER.info("Generating dummy user accounts")
    n_users = generate_users(ldb, instance_id, number_of_users, password)

    if machine_accounts > 0:
        LOGGER.info("Generating dummy machine accounts")
        n_computers = generate_machine_accounts(
            ldb, instance_id, machine_accounts, password, traffic_accounts)

    if number_of_groups > 0:
        LOGGER.info("Generating dummy groups")
        n_groups = generate_groups(ldb, instance_id, number_of_groups)

    if group_memberships > 0:
        LOGGER.info("Assigning users to groups")
        assignments = GroupAssignments(number_of_groups, n_groups,
                                       number_of_users, n_users,
                                       group_memberships)
        LOGGER.info("Adding users to groups")
        add_users_to_groups(ldb, instance_id, assignments)
        n_memberships = assignments.total()

    # warn if the newly added groups will end up with no members
    if n_groups > 0 and n_users == 0 and number_of_groups != n_groups:
        LOGGER.warning("The added groups will contain no members")

    LOGGER.info("Added %d users (%d machines), %d groups and %d memberships" %
                (n_users, n_computers, n_groups, n_memberships))
1847
1848
class GroupAssignments(object):
    """Randomly assign users to groups using weighted distributions.

    A few users end up in many groups and a few groups contain many
    users; the bulk of users/groups get only a handful of memberships.
    """

    def __init__(self, number_of_groups, groups_added, number_of_users,
                 users_added, group_memberships):

        self.count = 0
        # Build the group distribution first, then the user distribution
        # (which consumes random numbers), then draw the assignments.
        self.generate_group_distribution(number_of_groups)
        self.generate_user_distribution(number_of_users, group_memberships)
        self.assignments = self.assign_groups(number_of_groups,
                                              groups_added,
                                              number_of_users,
                                              users_added,
                                              group_memberships)

    def cumulative_distribution(self, weights):
        # Normalise the weights into a cumulative distribution spread
        # between 0.0 and 1.0. Each weight gets a share of 1.0
        # proportional to its size, so larger weights are more likely to
        # be picked; the cumulative form lets random.random() act as a
        # simple index into the list.
        total = sum(weights)
        return [partial / total
                for partial in itertools.accumulate(weights)]

    def generate_user_distribution(self, num_users, num_memberships):
        """Probability distribution of a user belonging to a group.
        """
        # Weight each user with a Pareto-distributed draw, so a handful
        # of users land in a lot of groups while most users are only in
        # a few. Larger membership counts use a higher shape parameter,
        # giving slightly fewer extreme outliers; the aim is that no
        # user belongs to more than ~500 groups.
        shape = 1.75
        for threshold, pareto_shape in ((5000000, 3.0),
                                        (2000000, 2.5),
                                        (300000, 2.25)):
            if num_memberships > threshold:
                shape = pareto_shape
                break

        weights = [random.paretovariate(shape)
                   for _ in range(1, num_users + 1)]

        # convert the weights to a cumulative distribution between 0.0 and 1.0
        self.user_dist = self.cumulative_distribution(weights)

    def generate_group_distribution(self, n):
        """Probability distribution of a group containing a user."""

        # Weight each group so the probability decays as the group-ID
        # increases.
        weights = [1 / (gid ** 1.3) for gid in range(1, n + 1)]

        # convert the weights to a cumulative distribution between 0.0 and 1.0
        self.group_dist = self.cumulative_distribution(weights)

    def generate_random_membership(self):
        """Returns a randomly generated user-group membership"""

        # The distribution lists hold cumulative values between 0.0 and
        # 1.0, so bisecting with random.random() yields a weighted
        # random zero-based user/group index.
        user = bisect.bisect(self.user_dist, random.random())
        group = bisect.bisect(self.group_dist, random.random())

        return user, group

    def users_in_group(self, group):
        return self.assignments[group]

    def get_groups(self):
        return self.assignments.keys()

    def assign_groups(self, number_of_groups, groups_added,
                      number_of_users, users_added, group_memberships):
        """Allocate users to groups.

        The intention is to have a few users that belong to most groups, while
        the majority of users belong to a few groups.

        A few groups will contain most users, with the remaining only having a
        few users.
        """

        if group_memberships <= 0:
            return {}

        # Scale the requested number of memberships by the fraction of
        # users that were actually created on this run.
        group_memberships = math.ceil(
            float(group_memberships) *
            (float(users_added) / float(number_of_users)))

        # Indexes at or below these belong to accounts that already
        # existed before this run.
        last_old_user = number_of_users - users_added - 1
        last_old_group = number_of_groups - groups_added - 1

        pairs = set()
        while len(pairs) < group_memberships:
            user, group = self.generate_random_membership()

            if group > last_old_group or user > last_old_user:
                # the + 1 converts the array index to the corresponding
                # group or user number
                pairs.add(((user + 1), (group + 1)))

        # Convert the set into a dict keyed by group, with a list of the
        # users in that group as the value (grouping by group-ID lets us
        # batch the DB membership writes).
        by_group = defaultdict(list)
        for (user, group) in pairs:
            by_group[group].append(user)
            self.count += 1

        return by_group

    def total(self):
        return self.count
1975
1976
def add_users_to_groups(db, instance_id, assignments):
    """Takes the assignments of users to groups and applies them to the DB."""

    total = assignments.total()
    batches_done = 0
    members_added = 0

    for group in assignments.get_groups():
        members = assignments.users_in_group(group)
        if not members:
            continue

        # Write at most 1000 members per DB modify. Fewer modifies is
        # more efficient, but pushing 10K+ users into a group in one go
        # becomes inefficient memory-wise.
        for start in range(0, len(members), 1000):
            batch = members[start:start + 1000]
            add_group_members(db, instance_id, group, batch)

            members_added += len(batch)
            batches_done += 1
            if batches_done % 50 == 0:
                LOGGER.info("Added %u/%u memberships" % (members_added, total))
2000
def add_group_members(db, instance_id, group, users_in_group):
    """Adds the given users to group specified."""

    ou = ou_name(db, instance_id)

    group_dn = "cn=%s,%s" % (group_name(instance_id, group), ou)
    msg = ldb.Message()
    msg.dn = ldb.Dn(db, group_dn)

    for user in users_in_group:
        member_dn = "cn=%s,%s" % (user_name(instance_id, user), ou)
        # one uniquely named element per user, all modifying "member"
        msg["member-" + str(user)] = ldb.MessageElement(member_dn,
                                                        ldb.FLAG_MOD_ADD,
                                                        "member")

    db.modify(msg)
2019
2020
def generate_stats(statsdir, timing_file):
    """Generate and print the summary stats for a run.

    Reads the tab-separated stats files in statsdir (fields: time,
    conversation, protocol, op type, duration, success flag, error),
    optionally copying each valid line to timing_file, then prints
    overall counts/rates and a per protocol/op-type latency table.
    """
    first      = sys.float_info.max   # earliest operation start seen
    last       = 0                    # latest operation end seen
    successful = 0
    failed     = 0
    latencies  = {}    # protocol -> op type -> list of latencies
    failures   = {}    # protocol -> op type -> failure count
    unique_conversations = set()
    conversations = 0

    # if no timing file was supplied, write to nowhere
    if timing_file is not None:
        tw = timing_file.write
    else:
        def tw(x):
            pass

    tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n")

    for filename in os.listdir(statsdir):
        path = os.path.join(statsdir, filename)
        with open(path, 'r') as f:
            for line in f:
                try:
                    fields       = line.rstrip('\n').split('\t')
                    conversation = fields[1]
                    protocol     = fields[2]
                    packet_type  = fields[3]
                    latency      = float(fields[4])
                    # field 0 is the completion time; subtracting the
                    # latency recovers when the operation started
                    first        = min(float(fields[0]) - latency, first)
                    last         = max(float(fields[0]), last)

                    latencies.setdefault(protocol, {}) \
                             .setdefault(packet_type, []).append(latency)
                    failures.setdefault(protocol, {}) \
                            .setdefault(packet_type, 0)

                    if fields[5] == 'True':
                        successful += 1
                    else:
                        failed += 1
                        failures[protocol][packet_type] += 1

                    if conversation not in unique_conversations:
                        unique_conversations.add(conversation)
                        conversations += 1

                    tw(line)
                except (ValueError, IndexError):
                    # not a valid line; print it for debugging and ignore
                    print(line, file=sys.stderr)

    duration = last - first
    # Guard against a zero (or negative, when no lines parsed) duration,
    # which would otherwise raise ZeroDivisionError below.
    if successful == 0 or duration <= 0:
        success_rate = 0
    else:
        success_rate = successful / duration
    if failed == 0 or duration <= 0:
        failure_rate = 0
    else:
        failure_rate = failed / duration

    print("Total conversations:   %10d" % conversations)
    print("Successful operations: %10d (%.3f per second)"
          % (successful, success_rate))
    print("Failed operations:     %10d (%.3f per second)"
          % (failed, failure_rate))

    print("Protocol    Op Code  Description                               "
          " Count       Failed         Mean       Median          "
          "95%        Range          Max")

    protocols = sorted(latencies.keys())
    for protocol in protocols:
        packet_types = sorted(latencies[protocol], key=opcode_key)
        for packet_type in packet_types:
            values     = sorted(latencies[protocol][packet_type])
            count      = len(values)
            failed     = failures[protocol][packet_type]
            mean       = sum(values) / count
            median     = calc_percentile(values, 0.50)
            percentile = calc_percentile(values, 0.95)
            rng        = values[-1] - values[0]
            maxv       = values[-1]
            desc       = OP_DESCRIPTIONS.get((protocol, packet_type), '')
            # isatty must be *called*: the bare attribute is a bound
            # method and always truthy, which made the tab-separated
            # branch below unreachable.
            if sys.stdout.isatty():
                print("%-12s   %4s  %-35s %12d %12d %12.6f "
                      "%12.6f %12.6f %12.6f %12.6f"
                      % (protocol,
                         packet_type,
                         desc,
                         count,
                         failed,
                         mean,
                         median,
                         percentile,
                         rng,
                         maxv))
            else:
                print("%s\t%s\t%s\t%d\t%d\t%f\t%f\t%f\t%f\t%f"
                      % (protocol,
                         packet_type,
                         desc,
                         count,
                         failed,
                         mean,
                         median,
                         percentile,
                         rng,
                         maxv))
2139
2140
def opcode_key(v):
    """Sort key for the operation code to ensure that it sorts numerically.

    Numeric opcodes are zero-padded to three digits; anything that is
    not convertible to an int is returned unchanged (and so sorts
    lexically). The bare except is narrowed to the conversion errors
    int() actually raises, so unrelated exceptions propagate.
    """
    try:
        return "%03d" % int(v)
    except (TypeError, ValueError):
        return v
2147
2148
def calc_percentile(values, percentile):
    """Calculate the specified percentile from the list of values.

    Assumes the list is sorted in ascending order.
    """

    if not values:
        return 0
    # fractional rank of the requested percentile
    rank = (len(values) - 1) * percentile
    lower = math.floor(rank)
    upper = math.ceil(rank)
    if lower == upper:
        # rank is exact; return the value directly
        return values[int(rank)]
    # otherwise interpolate linearly between the bracketing values
    return (values[int(lower)] * (upper - rank) +
            values[int(upper)] * (rank - lower))
2165
2166
def mk_masked_dir(*path):
    """In a testenv we end up with 0777 directories that look an alarming
    green colour with ls. Use umask to avoid that.

    Joins the path components, creates the directory with permissions
    restricted by a 0o077 umask, and returns the resulting path.
    """
    d = os.path.join(*path)
    mask = os.umask(0o077)
    try:
        os.mkdir(d)
    finally:
        # always restore the process-wide umask, even if mkdir raises
        # (the original leaked the restrictive umask on failure)
        os.umask(mask)
    return d