traffic_replay: Improve user generation debug
[samba.git] / python / samba / emulate / traffic.py
1 # -*- encoding: utf-8 -*-
2 # Samba traffic replay and learning
3 #
4 # Copyright (C) Catalyst IT Ltd. 2017
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 #
19 from __future__ import print_function, division
20
21 import time
22 import os
23 import random
24 import json
25 import math
26 import sys
27 import signal
28 import itertools
29
30 from collections import OrderedDict, Counter, defaultdict
31 from samba.emulate import traffic_packets
32 from samba.samdb import SamDB
33 import ldb
34 from ldb import LdbError
35 from samba.dcerpc import ClientConnection
36 from samba.dcerpc import security, drsuapi, lsa
37 from samba.dcerpc import netlogon
38 from samba.dcerpc.netlogon import netr_Authenticator
39 from samba.dcerpc import srvsvc
40 from samba.dcerpc import samr
41 from samba.drs_utils import drs_DsBind
42 import traceback
43 from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
44 from samba.auth import system_session
45 from samba.dsdb import (
46     UF_NORMAL_ACCOUNT,
47     UF_SERVER_TRUST_ACCOUNT,
48     UF_TRUSTED_FOR_DELEGATION
49 )
50 from samba.dcerpc.misc import SEC_CHAN_BDC
51 from samba import gensec
52 from samba import sd_utils
53 from samba.compat import get_string
54 from samba.logger import get_samba_logger
55 import bisect
56
# Assumed fixed overhead (seconds) of making a sleep call during replay.
SLEEP_OVERHEAD = 3e-4

# we don't use None, because it complicates [de]serialisation
NON_PACKET = '-'

# (protocol, opcode) pairs that suggest the sender is a client,
# with a weight in [0, 1] (consumed by Packet.client_score()).
CLIENT_CLUES = {
    ('dns', '0'): 1.0,      # query
    ('smb', '0x72'): 1.0,   # Negotiate protocol
    ('ldap', '0'): 1.0,     # bind
    ('ldap', '3'): 1.0,     # searchRequest
    ('ldap', '2'): 1.0,     # unbindRequest
    ('cldap', '3'): 1.0,
    ('dcerpc', '11'): 1.0,  # bind
    ('dcerpc', '14'): 1.0,  # Alter_context
    ('nbns', '0'): 1.0,     # query
}

# (protocol, opcode) pairs that suggest the sender is a server,
# with a weight in [0, 1] (consumed by Packet.client_score()).
SERVER_CLUES = {
    ('dns', '1'): 1.0,      # response
    ('ldap', '1'): 1.0,     # bind response
    ('ldap', '4'): 1.0,     # search result
    ('ldap', '5'): 1.0,     # search done
    ('cldap', '5'): 1.0,
    ('dcerpc', '12'): 1.0,  # bind_ack
    ('dcerpc', '13'): 1.0,  # bind_nak
    ('dcerpc', '15'): 1.0,  # Alter_context response
}

# Protocols whose packets are not replayed (see Packet.is_really_a_packet).
SKIPPED_PROTOCOLS = {"smb", "smb2", "browser", "smb_netlogon"}

WAIT_SCALE = 10.0
WAIT_THRESHOLD = (1.0 / WAIT_SCALE)
NO_WAIT_LOG_TIME_RANGE = (-10, -3)

# DEBUG_LEVEL can be changed by scripts with -d
DEBUG_LEVEL = 0

LOGGER = get_samba_logger(name=__name__)
95
96
def debug(level, msg, *args):
    """Write a formatted debug message to standard error.

    :param level: message is printed only if this is <= the module-wide
                  DEBUG_LEVEL (which scripts can raise with -d)
    :param msg:   the message; may contain C-style format specifiers
    :param args:  values consumed by the format specifiers, if any
    """
    if level > DEBUG_LEVEL:
        return
    text = msg % tuple(args) if args else msg
    print(text, file=sys.stderr)
113
114
def debug_lineno(*args):
    """Print an unformatted log message to stderr, prefixed with the
    calling function's name and its (colour-highlighted) line number.
    """
    caller = traceback.extract_stack(limit=2)[0]
    # caller[2] is the function name, caller[1] the line number
    prefix = " %s:\033[01;33m%s \033[00m" % (caller[2], caller[1])
    print(prefix, end=' ', file=sys.stderr)
    for arg in args:
        print(arg, file=sys.stderr)
    print(file=sys.stderr)
    sys.stderr.flush()
126
127
def random_colour_print():
    """Return a closure that writes each argument to stderr on its own
    line, tinted with a 256-colour escape chosen once at creation time."""
    colour = 18 + random.randrange(214)
    prefix = "\033[38;5;%dm" % colour

    def printer(*args):
        for arg in args:
            print("%s%s\033[00m" % (prefix, arg), file=sys.stderr)

    return printer
138
139
class FakePacketError(Exception):
    """Raised when a synthesised packet is inconsistent with its
    conversation (e.g. its endpoints don't match)."""
142
143
class Packet(object):
    """Details of a network packet from a traffic summary.

    A packet records its timing, addressing (src and dest are integer
    endpoint ids), its protocol/opcode pair, a description, and any
    extra summary fields.
    """
    def __init__(self, timestamp, ip_protocol, stream_number, src, dest,
                 protocol, opcode, desc, extra):

        self.timestamp = timestamp
        self.ip_protocol = ip_protocol
        self.stream_number = stream_number
        self.src = src
        self.dest = dest
        self.protocol = protocol
        self.opcode = opcode
        self.desc = desc
        self.extra = extra
        # Canonical (low, high) endpoint pair, so packets travelling in
        # either direction of a conversation share the same key.
        if self.src < self.dest:
            self.endpoints = (self.src, self.dest)
        else:
            self.endpoints = (self.dest, self.src)

    @classmethod
    def from_line(cls, line):
        """Parse one tab-separated traffic_summary line into a packet.

        The first 8 fields are fixed; any remaining fields are kept
        verbatim in self.extra.
        """
        fields = line.rstrip('\n').split('\t')
        (timestamp,
         ip_protocol,
         stream_number,
         src,
         dest,
         protocol,
         opcode,
         desc) = fields[:8]
        extra = fields[8:]

        timestamp = float(timestamp)
        src = int(src)
        dest = int(dest)

        # Use cls, not Packet, so subclasses parse into themselves
        # (mirroring copy(), which uses self.__class__).
        return cls(timestamp, ip_protocol, stream_number, src, dest,
                   protocol, opcode, desc, extra)

    def as_summary(self, time_offset=0.0):
        """Format the packet as a traffic_summary line.

        :param time_offset: added to the timestamp in the output
        :return: an (adjusted timestamp, summary line) tuple
        """
        extra = '\t'.join(self.extra)
        t = self.timestamp + time_offset
        return (t, '%f\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' %
                (t,
                 self.ip_protocol,
                 self.stream_number or '',
                 self.src,
                 self.dest,
                 self.protocol,
                 self.opcode,
                 self.desc,
                 extra))

    def __str__(self):
        return ("%.3f: %d -> %d; ip %s; strm %s; prot %s; op %s; desc %s %s" %
                (self.timestamp, self.src, self.dest, self.ip_protocol or '-',
                 self.stream_number, self.protocol, self.opcode, self.desc,
                 ('«' + ' '.join(self.extra) + '»' if self.extra else '')))

    def __repr__(self):
        return "<Packet @%s>" % self

    def copy(self):
        """Return a new packet of the same (sub)class with equal fields."""
        return self.__class__(self.timestamp,
                              self.ip_protocol,
                              self.stream_number,
                              self.src,
                              self.dest,
                              self.protocol,
                              self.opcode,
                              self.desc,
                              self.extra)

    def as_packet_type(self):
        """Return the 'protocol:opcode' string identifying this packet type."""
        t = '%s:%s' % (self.protocol, self.opcode)
        return t

    def client_score(self):
        """A positive number means we think it is a client; a negative number
        means we think it is a server. Zero means no idea. range: -1 to 1.
        """
        key = (self.protocol, self.opcode)
        if key in CLIENT_CLUES:
            return CLIENT_CLUES[key]
        if key in SERVER_CLUES:
            return -SERVER_CLUES[key]
        return 0.0

    def play(self, conversation, context):
        """Send the packet over the network, if required.

        Some packets are ignored, i.e. for  protocols not handled,
        server response messages, or messages that are generated by the
        protocol layer associated with other packets.

        Timing data for handlers that generate traffic (or fail) is
        written to stdout as a tab-separated line.
        """
        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
        try:
            fn = getattr(traffic_packets, fn_name)

        except AttributeError as e:
            print("Conversation(%s) Missing handler %s" %
                  (conversation.conversation_id, fn_name),
                  file=sys.stderr)
            return

        # Don't display a message for kerberos packets, they're not directly
        # generated they're used to indicate kerberos should be used
        if self.protocol != "kerberos":
            debug(2, "Conversation(%s) Calling handler %s" %
                     (conversation.conversation_id, fn_name))

        start = time.time()
        try:
            if fn(self, conversation, context):
                # Only collect timing data for functions that generate
                # network traffic, or fail
                end = time.time()
                duration = end - start
                print("%f\t%s\t%s\t%s\t%f\tTrue\t" %
                      (end, conversation.conversation_id, self.protocol,
                       self.opcode, duration))
        except Exception as e:
            end = time.time()
            duration = end - start
            print("%f\t%s\t%s\t%s\t%f\tFalse\t%s" %
                  (end, conversation.conversation_id, self.protocol,
                   self.opcode, duration, e))

    def __cmp__(self, other):
        # Python 2 ordering only; has no effect under Python 3.
        return self.timestamp - other.timestamp

    def is_really_a_packet(self, missing_packet_stats=None):
        """Return True if the packet would be replayed.

        Packets for skipped protocols, ldap continuations, and packets
        whose handler is missing or a no-op can be dropped without
        affecting the replay.
        """
        if self.protocol in SKIPPED_PROTOCOLS:
            # Ignore any packets for the protocols we're not interested in.
            return False
        if self.protocol == "ldap" and self.opcode == '':
            # skip ldap continuation packets
            return False

        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
        fn = getattr(traffic_packets, fn_name, None)
        if not fn:
            print("missing packet %s" % fn_name, file=sys.stderr)
            return False
        if fn is traffic_packets.null_packet:
            return False
        return True
297
298
class ReplayContext(object):
    """State/Context for an individual conversation between a simulated
    client and a server.

    Caches the various protocol connections (ldap, dcerpc, lsarpc,
    srvsvc, drsuapi, samr, netlogon) so that repeated packets in a
    conversation reuse them, and manages the good/bad credential pairs
    used to inject occasional bad-password attempts.
    """

    def __init__(self,
                 server=None,
                 lp=None,
                 creds=None,
                 badpassword_frequency=None,
                 prefer_kerberos=None,
                 tempdir=None,
                 statsdir=None,
                 ou=None,
                 base_dn=None,
                 domain=None,
                 domain_sid=None):

        self.server                   = server
        self.ldap_connections         = []
        self.dcerpc_connections       = []
        self.lsarpc_connections       = []
        self.lsarpc_connections_named = []
        self.drsuapi_connections      = []
        self.srvsvc_connections       = []
        self.samr_contexts            = []
        self.netlogon_connection      = None
        self.creds                    = creds
        self.lp                       = lp
        self.prefer_kerberos          = prefer_kerberos
        self.ou                       = ou
        self.base_dn                  = base_dn
        self.domain                   = domain
        self.statsdir                 = statsdir
        self.global_tempdir           = tempdir
        self.domain_sid               = domain_sid
        self.realm                    = lp.get('realm')

        # Bad password attempt controls.  The last_*_bad flags prevent
        # two consecutive bad-password attempts on the same channel.
        self.badpassword_frequency    = badpassword_frequency
        self.last_lsarpc_bad          = False
        self.last_lsarpc_named_bad    = False
        self.last_simple_bind_bad     = False
        self.last_bind_bad            = False
        self.last_srvsvc_bad          = False
        self.last_drsuapi_bad         = False
        self.last_netlogon_bad        = False
        self.last_samlogon_bad        = False
        self.generate_ldap_search_tables()
        self.next_conversation_id = itertools.count()

    def generate_ldap_search_tables(self):
        """Build lookup tables of existing DNs, keyed by their RDN-type
        pattern (e.g. 'CN,CN,CN,DC,DC'), so replayed searches can pick
        plausible targets.  Sets self.dn_map and self.attribute_clue_map.
        """
        session = system_session()

        db = SamDB(url="ldap://%s" % self.server,
                   session_info=session,
                   credentials=self.creds,
                   lp=self.lp)

        res = db.search(db.domain_dn(),
                        scope=ldb.SCOPE_SUBTREE,
                        controls=["paged_results:1:1000"],
                        attrs=['dn'])

        # find a list of dns for each pattern
        # e.g. CN,CN,CN,DC,DC
        dn_map = {}
        attribute_clue_map = {
            'invocationId': []
        }

        for r in res:
            dn = str(r.dn)
            pattern = ','.join(x.lstrip()[:2] for x in dn.split(',')).upper()
            dns = dn_map.setdefault(pattern, [])
            dns.append(dn)
            if dn.startswith('CN=NTDS Settings,'):
                attribute_clue_map['invocationId'].append(dn)

        # extend the map in case we are working with a different
        # number of DC components.
        # for k, v in self.dn_map.items():
        #     print >>sys.stderr, k, len(v)

        for k in list(dn_map.keys()):
            if k[-3:] != ',DC':
                continue
            # strip all trailing ',DC' components from the pattern...
            p = k[:-3]
            while p[-3:] == ',DC':
                p = p[:-3]
            # ...then alias the same DN list under 1 to 5 of them.
            for i in range(5):
                p += ',DC'
                if p != k and p in dn_map:
                    print('dn_map collision %s %s' % (k, p),
                          file=sys.stderr)
                    continue
                dn_map[p] = dn_map[k]

        self.dn_map = dn_map
        self.attribute_clue_map = attribute_clue_map

    def generate_process_local_config(self, account, conversation):
        """Set up this process's private state (paths, DNs, credentials)
        for the given user/machine account and conversation.

        Does nothing if account is None.
        """
        if account is None:
            return
        self.netbios_name             = account.netbios_name
        self.machinepass              = account.machinepass
        self.username                 = account.username
        self.userpass                 = account.userpass

        self.tempdir = mk_masked_dir(self.global_tempdir,
                                     'conversation-%d' %
                                     conversation.conversation_id)

        self.lp.set("private dir", self.tempdir)
        self.lp.set("lock dir", self.tempdir)
        self.lp.set("state directory", self.tempdir)
        self.lp.set("tls verify peer", "no_check")

        # If the domain was not specified, check for the environment
        # variable.
        if self.domain is None:
            self.domain = os.environ["DOMAIN"]

        self.remoteAddress = "/root/ncalrpc_as_system"
        self.samlogon_dn   = ("cn=%s,%s" %
                              (self.netbios_name, self.ou))
        self.user_dn       = ("cn=%s,%s" %
                              (self.username, self.ou))

        self.generate_machine_creds()
        self.generate_user_creds()

    def with_random_bad_credentials(self, f, good, bad, failed_last_time):
        """Execute the supplied logon function, randomly choosing the
           bad credentials.

           Based on the frequency in badpassword_frequency randomly perform the
           function with the supplied bad credentials.
           If run with bad credentials, the function is re-run with the good
           credentials.
           failed_last_time is used to prevent consecutive bad credential
           attempts. So the over all bad credential frequency will be lower
           than that requested, but not significantly.

        :return: (result of f(good), whether a bad attempt was made)
        """
        if not failed_last_time:
            if (self.badpassword_frequency and self.badpassword_frequency > 0
                and random.random() < self.badpassword_frequency):
                try:
                    f(bad)
                except Exception:
                    # Ignore (deliberately) any failure here, as the
                    # operation is being performed with bad credentials
                    # and is expected to fail.  Only Exception is caught
                    # so KeyboardInterrupt/SystemExit still propagate.
                    pass
                failed_last_time = True
            else:
                failed_last_time = False

        result = f(good)
        return (result, failed_last_time)

    def generate_user_creds(self):
        """Generate the conversation specific user Credentials.

        Each Conversation has an associated user account used to simulate
        any non Administrative user traffic.

        Generates user credentials with good and bad passwords and ldap
        simple bind credentials with good and bad passwords.  Bad
        passwords are derived by truncating the good one.
        """
        self.user_creds = Credentials()
        self.user_creds.guess(self.lp)
        self.user_creds.set_workstation(self.netbios_name)
        self.user_creds.set_password(self.userpass)
        self.user_creds.set_username(self.username)
        self.user_creds.set_domain(self.domain)
        if self.prefer_kerberos:
            self.user_creds.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.user_creds.set_kerberos_state(DONT_USE_KERBEROS)

        # NOTE(review): unlike user_creds, the bad-password variant does
        # not set the domain — confirm whether that is intentional.
        self.user_creds_bad = Credentials()
        self.user_creds_bad.guess(self.lp)
        self.user_creds_bad.set_workstation(self.netbios_name)
        self.user_creds_bad.set_password(self.userpass[:-4])
        self.user_creds_bad.set_username(self.username)
        if self.prefer_kerberos:
            self.user_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.user_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)

        # Credentials for ldap simple bind.
        self.simple_bind_creds = Credentials()
        self.simple_bind_creds.guess(self.lp)
        self.simple_bind_creds.set_workstation(self.netbios_name)
        self.simple_bind_creds.set_password(self.userpass)
        self.simple_bind_creds.set_username(self.username)
        self.simple_bind_creds.set_gensec_features(
            self.simple_bind_creds.get_gensec_features() | gensec.FEATURE_SEAL)
        if self.prefer_kerberos:
            self.simple_bind_creds.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.simple_bind_creds.set_kerberos_state(DONT_USE_KERBEROS)
        self.simple_bind_creds.set_bind_dn(self.user_dn)

        self.simple_bind_creds_bad = Credentials()
        self.simple_bind_creds_bad.guess(self.lp)
        self.simple_bind_creds_bad.set_workstation(self.netbios_name)
        self.simple_bind_creds_bad.set_password(self.userpass[:-4])
        self.simple_bind_creds_bad.set_username(self.username)
        self.simple_bind_creds_bad.set_gensec_features(
            self.simple_bind_creds_bad.get_gensec_features() |
            gensec.FEATURE_SEAL)
        if self.prefer_kerberos:
            self.simple_bind_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.simple_bind_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
        self.simple_bind_creds_bad.set_bind_dn(self.user_dn)

    def generate_machine_creds(self):
        """Generate the conversation specific machine Credentials.

        Each Conversation has an associated machine account.

        Generates machine credentials with good and bad passwords; the
        bad password is derived by truncating the good one.
        """

        self.machine_creds = Credentials()
        self.machine_creds.guess(self.lp)
        self.machine_creds.set_workstation(self.netbios_name)
        self.machine_creds.set_secure_channel_type(SEC_CHAN_BDC)
        self.machine_creds.set_password(self.machinepass)
        self.machine_creds.set_username(self.netbios_name + "$")
        self.machine_creds.set_domain(self.domain)
        if self.prefer_kerberos:
            self.machine_creds.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.machine_creds.set_kerberos_state(DONT_USE_KERBEROS)

        # NOTE(review): the bad-password variant does not set the
        # domain — confirm whether that is intentional.
        self.machine_creds_bad = Credentials()
        self.machine_creds_bad.guess(self.lp)
        self.machine_creds_bad.set_workstation(self.netbios_name)
        self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_BDC)
        self.machine_creds_bad.set_password(self.machinepass[:-4])
        self.machine_creds_bad.set_username(self.netbios_name + "$")
        if self.prefer_kerberos:
            self.machine_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.machine_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)

    def get_matching_dn(self, pattern, attributes=None):
        """Pick a random existing DN matching the RDN-type pattern.

        If the pattern is an empty string, we assume ROOTDSE,
        Otherwise we try adding or removing DC suffixes, then
        shorter leading patterns until we hit one.
        e.g if there is no CN,CN,CN,CN,DC,DC
        we first try       CN,CN,CN,CN,DC
        and                CN,CN,CN,CN,DC,DC,DC
        then change to        CN,CN,CN,DC,DC
        and as last resort we use the base_dn
        """
        attr_clue = self.attribute_clue_map.get(attributes)
        if attr_clue:
            return random.choice(attr_clue)

        pattern = pattern.upper()
        while pattern:
            if pattern in self.dn_map:
                return random.choice(self.dn_map[pattern])
            # chop one off the front and try it all again.
            pattern = pattern[3:]

        return self.base_dn

    def get_dcerpc_connection(self, new=False):
        """Return a dcerpc connection, reusing the last one unless new
        is True."""
        guid = '12345678-1234-abcd-ef00-01234567cffb'  # RPC_NETLOGON UUID
        if self.dcerpc_connections and not new:
            return self.dcerpc_connections[-1]
        c = ClientConnection("ncacn_ip_tcp:%s" % self.server,
                             (guid, 1), self.lp)
        self.dcerpc_connections.append(c)
        return c

    def get_srvsvc_connection(self, new=False):
        """Return a srvsvc connection (made with the user credentials,
        possibly after a bad-password attempt), reusing the last one
        unless new is True."""
        if self.srvsvc_connections and not new:
            return self.srvsvc_connections[-1]

        def connect(creds):
            return srvsvc.srvsvc("ncacn_np:%s" % (self.server),
                                 self.lp,
                                 creds)

        (c, self.last_srvsvc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_srvsvc_bad)

        self.srvsvc_connections.append(c)
        return c

    def get_lsarpc_connection(self, new=False):
        """Return an lsarpc connection over TCP (schannel, made with the
        machine credentials), reusing the last one unless new is True."""
        if self.lsarpc_connections and not new:
            return self.lsarpc_connections[-1]

        def connect(creds):
            binding_options = 'schannel,seal,sign'
            return lsa.lsarpc("ncacn_ip_tcp:%s[%s]" %
                              (self.server, binding_options),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_bad)

        self.lsarpc_connections.append(c)
        return c

    def get_lsarpc_named_pipe_connection(self, new=False):
        """Return an lsarpc connection over a named pipe (made with the
        machine credentials), reusing the last one unless new is True."""
        if self.lsarpc_connections_named and not new:
            return self.lsarpc_connections_named[-1]

        def connect(creds):
            return lsa.lsarpc("ncacn_np:%s" % (self.server),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_named_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_named_bad)

        self.lsarpc_connections_named.append(c)
        return c

    def get_drsuapi_connection_pair(self, new=False, unbind=False):
        """get a (drs, drs_handle) tuple, reusing the last one unless
        new is True.  (The unbind parameter is currently unused.)"""
        if self.drsuapi_connections and not new:
            c = self.drsuapi_connections[-1]
            return c

        def connect(creds):
            binding_options = 'seal'
            binding_string = "ncacn_ip_tcp:%s[%s]" %\
                             (self.server, binding_options)
            return drsuapi.drsuapi(binding_string, self.lp, creds)

        (drs, self.last_drsuapi_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_drsuapi_bad)

        (drs_handle, supported_extensions) = drs_DsBind(drs)
        c = (drs, drs_handle)
        self.drsuapi_connections.append(c)
        return c

    def get_ldap_connection(self, new=False, simple=False):
        """Return an ldap connection (simple bind over ldaps, or SASL
        bind over ldap), reusing the last one unless new is True."""
        if self.ldap_connections and not new:
            return self.ldap_connections[-1]

        def simple_bind(creds):
            """
            To run simple bind against Windows, we need to run
            following commands in PowerShell:

                Install-windowsfeature ADCS-Cert-Authority
                Install-AdcsCertificationAuthority -CAType EnterpriseRootCA
                Restart-Computer

            """
            return SamDB('ldaps://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)

        def sasl_bind(creds):
            return SamDB('ldap://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)
        if simple:
            (samdb, self.last_simple_bind_bad) = \
                self.with_random_bad_credentials(simple_bind,
                                                 self.simple_bind_creds,
                                                 self.simple_bind_creds_bad,
                                                 self.last_simple_bind_bad)
        else:
            (samdb, self.last_bind_bad) = \
                self.with_random_bad_credentials(sasl_bind,
                                                 self.user_creds,
                                                 self.user_creds_bad,
                                                 self.last_bind_bad)

        self.ldap_connections.append(samdb)
        return samdb

    def get_samr_context(self, new=False):
        """Return a SamrContext, reusing the last one unless new is True."""
        if not self.samr_contexts or new:
            self.samr_contexts.append(
                SamrContext(self.server, lp=self.lp, creds=self.creds))
        return self.samr_contexts[-1]

    def get_netlogon_connection(self):
        """Return the (single, cached) netlogon schannel connection,
        creating it with the machine credentials on first use."""

        if self.netlogon_connection:
            return self.netlogon_connection

        def connect(creds):
            return netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" %
                                     (self.server),
                                     self.lp,
                                     creds)
        (c, self.last_netlogon_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_netlogon_bad)
        self.netlogon_connection = c
        return c

    def guess_a_dns_lookup(self):
        """Return a (name, type) pair for a plausible DNS query: an A
        record lookup of the realm."""
        return (self.realm, 'A')

    def get_authenticator(self):
        """Return a (current, subsequent) pair of netr_Authenticator
        objects built from the machine credentials."""
        auth = self.machine_creds.new_client_authenticator()
        current  = netr_Authenticator()
        # The credential bytes may arrive as ints (py3 bytes) or as
        # 1-char strings (py2 str); normalise to ints either way.
        current.cred.data = [x if isinstance(x, int) else ord(x) for x in auth["credential"]]
        current.timestamp = auth["timestamp"]

        subsequent = netr_Authenticator()
        return (current, subsequent)
731
732
class SamrContext(object):
    """State/Context associated with a samr connection.

    The connection and the handles derived from it all start unset and
    are created lazily on first use.
    """
    def __init__(self, server, lp=None, creds=None):
        self.server        = server
        self.lp            = lp
        self.creds         = creds
        # Everything below is created on demand.
        self.connection    = None
        self.handle        = None
        self.domain_handle = None
        self.domain_sid    = None
        self.group_handle  = None
        self.user_handle   = None
        self.rids          = None

    def get_connection(self):
        """Return the samr connection, creating it on first use."""
        if not self.connection:
            self.connection = samr.samr(
                "ncacn_ip_tcp:%s[seal]" % (self.server),
                lp_ctx=self.lp,
                credentials=self.creds)

        return self.connection

    def get_handle(self):
        """Return the samr connect handle, creating it on first use."""
        if not self.handle:
            conn = self.get_connection()
            self.handle = conn.Connect2(None,
                                        security.SEC_FLAG_MAXIMUM_ALLOWED)
        return self.handle
762
763
class Conversation(object):
    """Details of a conversation between a simulated client and a server."""
    # Assigned (from context.next_conversation_id) just before the
    # conversation is replayed in a forked child.
    conversation_id = None

    def __init__(self, start_time=None, endpoints=None):
        """
        :param start_time: trace-relative time of the first packet, or None
            until the first packet is added
        :param endpoints: tuple of the two addresses involved, or None to
            adopt the endpoints of the first packet added
        """
        self.start_time = start_time
        self.endpoints = endpoints
        self.packets = []
        self.msg = random_colour_print()
        # tallies client_score() of observed packets; negative means
        # endpoints[0] looks like the client (see guess_client_server)
        self.client_balance = 0.0

    def __cmp__(self, other):
        # Python 2 ordering: sort by start time, None sorting first.
        if self.start_time is None:
            if other.start_time is None:
                return 0
            return -1
        if other.start_time is None:
            return 1
        return self.start_time - other.start_time

    def __lt__(self, other):
        # Python 3 ignores __cmp__, but lists of conversations get
        # sorted, which needs __lt__; use the same ordering as __cmp__
        # (a None start_time sorts before everything else).
        if self.start_time is None:
            return other.start_time is not None
        if other.start_time is None:
            return False
        return self.start_time < other.start_time

    def add_packet(self, packet):
        """Add a packet object to this conversation, making a local copy with
        a conversation-relative timestamp."""
        p = packet.copy()

        if self.start_time is None:
            self.start_time = p.timestamp

        if self.endpoints is None:
            self.endpoints = p.endpoints

        if p.endpoints != self.endpoints:
            raise FakePacketError("Conversation endpoints %s don't match "
                                  "packet endpoints %s" %
                                  (self.endpoints, p.endpoints))

        p.timestamp -= self.start_time

        # a packet from endpoints[0] pushes the balance down; the sign of
        # the total is later used to guess which endpoint is the client.
        if p.src == p.endpoints[0]:
            self.client_balance -= p.client_score()
        else:
            self.client_balance += p.client_score()

        if p.is_really_a_packet():
            self.packets.append(p)

    def add_short_packet(self, timestamp, protocol, opcode, extra,
                         client=True):
        """Create a packet from a timestamp, and 'protocol:opcode' pair, and a
        (possibly empty) list of extra data. If client is True, assume
        this packet is from the client to the server.
        """
        src, dest = self.guess_client_server()
        if not client:
            src, dest = dest, src
        key = (protocol, opcode)
        desc = OP_DESCRIPTIONS.get(key, '')
        # default to TCP ('06') for protocols we don't know about
        ip_protocol = IP_PROTOCOLS.get(protocol, '06')
        packet = Packet(timestamp - self.start_time, ip_protocol,
                        '', src, dest,
                        protocol, opcode, desc, extra)
        # XXX we're assuming the timestamp is already adjusted for
        # this conversation?
        # XXX should we adjust client balance for guessed packets?
        if packet.src == packet.endpoints[0]:
            self.client_balance -= packet.client_score()
        else:
            self.client_balance += packet.client_score()
        if packet.is_really_a_packet():
            self.packets.append(packet)

    def __str__(self):
        return ("<Conversation %s %s starting %.3f %d packets>" %
                (self.conversation_id, self.endpoints, self.start_time,
                 len(self.packets)))

    __repr__ = __str__

    def __iter__(self):
        return iter(self.packets)

    def __len__(self):
        return len(self.packets)

    def get_duration(self):
        """Return the time between the first and last packets (0 if there
        are fewer than two packets)."""
        if len(self.packets) < 2:
            return 0
        return self.packets[-1].timestamp - self.packets[0].timestamp

    def replay_as_summary_lines(self):
        """Return the conversation as a list of summary-file lines."""
        lines = []
        for p in self.packets:
            lines.append(p.as_summary(self.start_time))
        return lines

    def replay_in_fork_with_delay(self, start, context=None, account=None):
        """Fork a new process and replay the conversation.

        :param start: wall-clock time at which the whole replay began
        :param context: the ReplayContext (provides per-process config)
        :param account: the ConversationAccounts to use in the child
        :return: the child pid (in the parent; the child never returns)
        """
        def signal_handler(signal, frame):
            """Signal handler closes standard out and error.

            Triggered by a sigterm, ensures that the log messages are flushed
            to disk and not lost.
            """
            sys.stderr.close()
            sys.stdout.close()
            os._exit(0)

        t = self.start_time
        now = time.time() - start
        gap = t - now
        # we are replaying strictly in order, so it is safe to sleep
        # in the main process if the gap is big enough. This reduces
        # the number of concurrent threads, which allows us to make
        # larger loads.
        # NOTE: this branch is deliberately disabled with `and False`.
        if gap > 0.15 and False:
            print("sleeping for %f in main process" % (gap - 0.1),
                  file=sys.stderr)
            time.sleep(gap - 0.1)
            now = time.time() - start
            gap = t - now
            print("gap is now %f" % gap, file=sys.stderr)

        self.conversation_id = next(context.next_conversation_id)
        pid = os.fork()
        if pid != 0:
            # parent: hand the child pid back to the scheduler
            return pid
        pid = os.getpid()
        signal.signal(signal.SIGTERM, signal_handler)
        # we must never return, or we'll end up running parts of the
        # parent's clean-up code. So we work in a try...finally, and
        # try to print any exceptions.

        try:
            context.generate_process_local_config(account, self)
            sys.stdin.close()
            os.close(0)
            filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
                                    self.conversation_id)
            sys.stdout.close()
            sys.stdout = open(filename, 'w')

            sleep_time = gap - SLEEP_OVERHEAD
            if sleep_time > 0:
                time.sleep(sleep_time)

            miss = t - (time.time() - start)
            self.msg("starting %s [miss %.3f pid %d]" % (self, miss, pid))
            self.replay(context)
        except Exception:
            print(("EXCEPTION in child PID %d, conversation %s" % (pid, self)),
                  file=sys.stderr)
            # print_exc takes the file as a keyword argument; passing it
            # positionally would (wrongly) set the traceback limit.
            traceback.print_exc(file=sys.stderr)
        finally:
            sys.stderr.close()
            sys.stdout.close()
            os._exit(0)

    def replay(self, context=None):
        """Play back the conversation's packets in order, sleeping to
        honour the recorded inter-packet gaps. With no context, just log
        what would be sent."""
        start = time.time()

        for p in self.packets:
            now = time.time() - start
            gap = p.timestamp - now
            sleep_time = gap - SLEEP_OVERHEAD
            if sleep_time > 0:
                time.sleep(sleep_time)

            miss = p.timestamp - (time.time() - start)
            if context is None:
                self.msg("packet %s [miss %.3f pid %d]" % (p, miss,
                                                           os.getpid()))
                continue
            p.play(self, context)

    def guess_client_server(self, server_clue=None):
        """Have a go at deciding who is the server and who is the client.
        returns (client, server)
        """
        a, b = self.endpoints

        if self.client_balance < 0:
            return (a, b)

        # in the absence of a clue, we will fall through to assuming
        # the lowest number is the server (which is usually true).

        if self.client_balance == 0 and server_clue == b:
            return (a, b)

        return (b, a)

    def forget_packets_outside_window(self, s, e):
        """Prune any packets outside the time window we're interested in

        :param s: start of the window
        :param e: end of the window
        """
        self.packets = [p for p in self.packets if s <= p.timestamp <= e]
        self.start_time = self.packets[0].timestamp if self.packets else None

    def renormalise_times(self, start_time):
        """Adjust the packet start times relative to the new start time."""
        for p in self.packets:
            p.timestamp -= start_time

        if self.start_time is not None:
            self.start_time -= start_time
975
976
class DnsHammer(Conversation):
    """A lightweight conversation that generates a lot of dns:0 packets on
    the fly"""

    def __init__(self, dns_rate, duration):
        """
        :param dns_rate: packets per second to aim for
        :param duration: seconds over which to spread the packets
        """
        # pre-compute a sorted schedule of uniformly random send times
        n_packets = int(dns_rate * duration)
        self.times = sorted(random.uniform(0, duration)
                            for _ in range(n_packets))
        self.rate = dns_rate
        self.duration = duration
        self.start_time = 0
        self.msg = random_colour_print()

    def __str__(self):
        return ("<DnsHammer %d packets over %.1fs (rate %.2f)>" %
                (len(self.times), self.duration, self.rate))

    def replay_in_fork_with_delay(self, start, context=None, account=None):
        """Delegate to the generic forking machinery in Conversation."""
        return Conversation.replay_in_fork_with_delay(self,
                                                      start,
                                                      context,
                                                      account)

    def replay(self, context=None):
        """Send the scheduled dns queries, logging a tab-separated result
        line per packet. With no context, just log what would be sent."""
        start = time.time()
        fn = traffic_packets.packet_dns_0
        for when in self.times:
            gap = when - (time.time() - start)
            sleep_time = gap - SLEEP_OVERHEAD
            if sleep_time > 0:
                time.sleep(sleep_time)

            if context is None:
                miss = when - (time.time() - start)
                self.msg("packet %s [miss %.3f pid %d]" % (when, miss,
                                                           os.getpid()))
                continue

            packet_start = time.time()
            try:
                fn(self, self, context)
                end = time.time()
                print("%f\tDNS\tdns\t0\t%f\tTrue\t" %
                      (end, end - packet_start))
            except Exception as e:
                end = time.time()
                print("%f\tDNS\tdns\t0\t%f\tFalse\t%s" %
                      (end, end - packet_start, e))
1026
1027
def ingest_summaries(files, dns_mode='count'):
    """Load summary traffic files and generate Conversations from them.

    :param files: a list of filenames or open file objects
    :param dns_mode: unless 'include', dns packets are tallied separately
        instead of joining conversations
    :return: a tuple of (list of Conversations, mean conversation rate,
        duration, dns opcode counts)
    """
    dns_counts = defaultdict(int)
    packets = []
    for f in files:
        if isinstance(f, str):
            f = open(f)
        print("Ingesting %s" % (f.name,), file=sys.stderr)
        for line in f:
            p = Packet.from_line(line)
            if p.protocol == 'dns' and dns_mode != 'include':
                dns_counts[p.opcode] += 1
            else:
                packets.append(p)

        f.close()

    if not packets:
        # shape this like the success case, so callers can always
        # unpack four values (previously this returned a 2-tuple)
        return [], 0, 0, dns_counts

    start_time = min(p.timestamp for p in packets)
    last_packet = max(p.timestamp for p in packets)

    print("gathering packets into conversations", file=sys.stderr)
    conversations = OrderedDict()
    for p in packets:
        p.timestamp -= start_time
        c = conversations.get(p.endpoints)
        if c is None:
            c = Conversation()
            conversations[p.endpoints] = c
        c.add_packet(p)

    # We only care about conversations with actual traffic, so we
    # filter out conversations with nothing to say. We do that here,
    # rather than earlier, because those empty packets contain useful
    # hints as to which end of the conversation was the client.
    conversation_list = []
    for c in conversations.values():
        if len(c) != 0:
            conversation_list.append(c)

    # This is obviously not correct, as many conversations will appear
    # to start roughly simultaneously at the beginning of the snapshot.
    # To which we say: oh well, so be it.
    duration = float(last_packet - start_time)
    mean_interval = len(conversations) / duration

    return conversation_list, mean_interval, duration, dns_counts
1079
1080
def guess_server_address(conversations):
    """Guess the server address: the single address that turns up in the
    most conversation endpoints.

    :param conversations: an iterable of Conversation objects
    :return: the most common address, or None if there are no endpoints
    """
    addresses = Counter()
    for c in conversations:
        addresses.update(c.endpoints)
    if addresses:
        # most_common(1) returns [(address, count)]; we want the bare
        # address so callers can compare it against endpoints directly
        return addresses.most_common(1)[0][0]
1088
1089
def stringify_keys(x):
    """Return a copy of dict *x* with each tuple key collapsed into a
    single tab-separated string (the inverse of unstringify_keys)."""
    return {'\t'.join(k): v for k, v in x.items()}
1096
1097
1098 def unstringify_keys(x):
1099     y = {}
1100     for k, v in x.items():
1101         t = tuple(str(k).split('\t'))
1102         y[t] = v
1103     return y
1104
1105
class TrafficModel(object):
    """An n-gram model of client traffic, learnt from summaries, that can
    generate plausible new conversations."""

    def __init__(self, n=3):
        """
        :param n: n-gram order; the model state is the n-1 preceding
            packet types
        """
        self.ngrams = {}
        self.query_details = {}
        self.n = n
        self.dns_opcounts = defaultdict(int)
        self.cumulative_duration = 0.0
        # [conversation count, elapsed seconds] -> conversations per second
        self.conversation_rate = [0, 1]

    def learn(self, conversations, dns_opcounts=None):
        """Fold a list of conversations into the model.

        :param conversations: a list of Conversation objects
        :param dns_opcounts: optional mapping of dns opcode -> count to
            merge into the model (default None, meaning none)
        """
        # None default avoids the shared-mutable-default pitfall
        if dns_opcounts is None:
            dns_opcounts = {}
        prev = 0.0
        cum_duration = 0.0
        key = (NON_PACKET,) * (self.n - 1)

        server = guess_server_address(conversations)

        for k, v in dns_opcounts.items():
            self.dns_opcounts[k] += v

        if len(conversations) > 1:
            elapsed =\
                conversations[-1].start_time - conversations[0].start_time
            self.conversation_rate[0] = len(conversations)
            self.conversation_rate[1] = elapsed

        for c in conversations:
            client, server = c.guess_client_server(server)
            cum_duration += c.get_duration()
            key = (NON_PACKET,) * (self.n - 1)
            for p in c:
                # the model learns from client requests only
                if p.src != client:
                    continue

                elapsed = p.timestamp - prev
                prev = p.timestamp
                if elapsed > WAIT_THRESHOLD:
                    # add the wait as an extra state
                    wait = 'wait:%d' % (math.log(max(1.0,
                                                     elapsed * WAIT_SCALE)))
                    self.ngrams.setdefault(key, []).append(wait)
                    key = key[1:] + (wait,)

                short_p = p.as_packet_type()
                self.query_details.setdefault(short_p,
                                              []).append(tuple(p.extra))
                self.ngrams.setdefault(key, []).append(short_p)
                key = key[1:] + (short_p,)

        self.cumulative_duration += cum_duration
        # add in the end
        self.ngrams.setdefault(key, []).append(NON_PACKET)

    def save(self, f):
        """Save the model as JSON.

        :param f: a filename or open file object (a filename is opened
            and closed here; a file object is left open)
        """
        ngrams = {}
        for k, v in self.ngrams.items():
            k = '\t'.join(k)
            ngrams[k] = dict(Counter(v))

        query_details = {}
        for k, v in self.query_details.items():
            query_details[k] = dict(Counter('\t'.join(x) if x else '-'
                                            for x in v))

        d = {
            'ngrams': ngrams,
            'query_details': query_details,
            'cumulative_duration': self.cumulative_duration,
            'conversation_rate': self.conversation_rate,
        }
        d['dns'] = self.dns_opcounts

        if isinstance(f, str):
            # we opened it, so we close it
            with open(f, 'w') as f2:
                json.dump(d, f2, indent=2)
        else:
            json.dump(d, f, indent=2)

    def load(self, f):
        """Load a model saved by save().

        :param f: a filename or open file object (a filename is opened
            and closed here; a file object is left open)
        """
        if isinstance(f, str):
            with open(f) as f2:
                d = json.load(f2)
        else:
            d = json.load(f)

        for k, v in d['ngrams'].items():
            k = tuple(str(k).split('\t'))
            values = self.ngrams.setdefault(k, [])
            for p, count in v.items():
                values.extend([str(p)] * count)

        for k, v in d['query_details'].items():
            values = self.query_details.setdefault(str(k), [])
            for p, count in v.items():
                if p == '-':
                    # '-' is the save() placeholder for "no extra data"
                    values.extend([()] * count)
                else:
                    values.extend([tuple(str(p).split('\t'))] * count)

        if 'dns' in d:
            for k, v in d['dns'].items():
                self.dns_opcounts[k] += v

        self.cumulative_duration = d['cumulative_duration']
        self.conversation_rate = d['conversation_rate']

    def construct_conversation(self, timestamp=0.0, client=2, server=1,
                               hard_stop=None, packet_rate=1):
        """Construct an individual conversation from the model.

        :param timestamp: the conversation start time
        :param client: client endpoint identifier
        :param server: server endpoint identifier
        :param hard_stop: if set, stop adding packets after this time
        :param packet_rate: speed-up factor for inter-packet waits
        """
        c = Conversation(timestamp, (server, client))

        key = (NON_PACKET,) * (self.n - 1)

        # random walk over the n-gram chain until we hit an end state
        while key in self.ngrams:
            p = random.choice(self.ngrams.get(key, NON_PACKET))
            if p == NON_PACKET:
                break
            if p in self.query_details:
                extra = random.choice(self.query_details[p])
            else:
                extra = []

            protocol, opcode = p.split(':', 1)
            if protocol == 'wait':
                log_wait_time = int(opcode) + random.random()
                wait = math.exp(log_wait_time) / (WAIT_SCALE * packet_rate)
                timestamp += wait
            else:
                log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
                wait = math.exp(log_wait) / packet_rate
                timestamp += wait
                if hard_stop is not None and timestamp > hard_stop:
                    break
                c.add_short_packet(timestamp, protocol, opcode, extra)

            key = key[1:] + (p,)

        return c

    def generate_conversations(self, rate, duration, packet_rate=1):
        """Generate a list of conversations from the model.

        :param rate: target conversation start rate, relative to the
            learnt rate
        :param duration: length of the window to keep
        :param packet_rate: speed-up factor passed to each conversation
        """
        # We run the simulation for at least twice as long as our
        # desired duration, and take a section near the start.
        rate_n, rate_t  = self.conversation_rate

        duration2 = max(rate_t, duration * 2)
        n = rate * duration2 * rate_n / rate_t

        server = 1
        client = 2

        conversations = []
        end = duration2

        while client < n + 2:
            start = random.uniform(0, duration2)
            c = self.construct_conversation(start,
                                            client,
                                            server,
                                            hard_stop=(duration2 * 5),
                                            packet_rate=packet_rate)

            c.forget_packets_outside_window(start, end)
            c.renormalise_times(start)
            if len(c) != 0:
                conversations.append(c)
            client += 1

        print(("we have %d conversations at rate %f" %
               (len(conversations), rate)), file=sys.stderr)
        conversations.sort()
        return conversations
1278
1279
# Map each summary-file protocol name to the IP protocol number used in
# summary lines ('06' for TCP, '11' for UDP — presumably hex, since
# 0x11 == 17 == UDP; confirm against the capture format).
IP_PROTOCOLS = {
    'dns': '11',
    'rpc_netlogon': '06',
    'kerberos': '06',      # ratio 16248:258
    'smb': '06',
    'smb2': '06',
    'ldap': '06',
    'cldap': '11',
    'lsarpc': '06',
    'samr': '06',
    'dcerpc': '06',
    'epm': '06',
    'drsuapi': '06',
    'browser': '11',
    'smb_netlogon': '11',
    'srvsvc': '06',
    'nbns': '11',
}

# Human-readable descriptions for (protocol, opcode) pairs, used when
# synthesising packet summary lines (see add_short_packet and
# expand_short_packet).
OP_DESCRIPTIONS = {
    ('browser', '0x01'): 'Host Announcement (0x01)',
    ('browser', '0x02'): 'Request Announcement (0x02)',
    ('browser', '0x08'): 'Browser Election Request (0x08)',
    ('browser', '0x09'): 'Get Backup List Request (0x09)',
    ('browser', '0x0c'): 'Domain/Workgroup Announcement (0x0c)',
    ('browser', '0x0f'): 'Local Master Announcement (0x0f)',
    ('cldap', '3'): 'searchRequest',
    ('cldap', '5'): 'searchResDone',
    ('dcerpc', '0'): 'Request',
    ('dcerpc', '11'): 'Bind',
    ('dcerpc', '12'): 'Bind_ack',
    ('dcerpc', '13'): 'Bind_nak',
    ('dcerpc', '14'): 'Alter_context',
    ('dcerpc', '15'): 'Alter_context_resp',
    ('dcerpc', '16'): 'AUTH3',
    ('dcerpc', '2'): 'Response',
    ('dns', '0'): 'query',
    ('dns', '1'): 'response',
    ('drsuapi', '0'): 'DsBind',
    ('drsuapi', '12'): 'DsCrackNames',
    ('drsuapi', '13'): 'DsWriteAccountSpn',
    ('drsuapi', '1'): 'DsUnbind',
    ('drsuapi', '2'): 'DsReplicaSync',
    ('drsuapi', '3'): 'DsGetNCChanges',
    ('drsuapi', '4'): 'DsReplicaUpdateRefs',
    ('epm', '3'): 'Map',
    ('kerberos', ''): '',
    ('ldap', '0'): 'bindRequest',
    ('ldap', '1'): 'bindResponse',
    ('ldap', '2'): 'unbindRequest',
    ('ldap', '3'): 'searchRequest',
    ('ldap', '4'): 'searchResEntry',
    ('ldap', '5'): 'searchResDone',
    ('ldap', ''): '*** Unknown ***',
    ('lsarpc', '14'): 'lsa_LookupNames',
    ('lsarpc', '15'): 'lsa_LookupSids',
    ('lsarpc', '39'): 'lsa_QueryTrustedDomainInfoBySid',
    ('lsarpc', '40'): 'lsa_SetTrustedDomainInfo',
    ('lsarpc', '6'): 'lsa_OpenPolicy',
    ('lsarpc', '76'): 'lsa_LookupSids3',
    ('lsarpc', '77'): 'lsa_LookupNames4',
    ('nbns', '0'): 'query',
    ('nbns', '1'): 'response',
    ('rpc_netlogon', '21'): 'NetrLogonDummyRoutine1',
    ('rpc_netlogon', '26'): 'NetrServerAuthenticate3',
    ('rpc_netlogon', '29'): 'NetrLogonGetDomainInfo',
    ('rpc_netlogon', '30'): 'NetrServerPasswordSet2',
    ('rpc_netlogon', '39'): 'NetrLogonSamLogonEx',
    ('rpc_netlogon', '40'): 'DsrEnumerateDomainTrusts',
    ('rpc_netlogon', '45'): 'NetrLogonSamLogonWithFlags',
    ('rpc_netlogon', '4'): 'NetrServerReqChallenge',
    ('samr', '0',): 'Connect',
    ('samr', '16'): 'GetAliasMembership',
    ('samr', '17'): 'LookupNames',
    ('samr', '18'): 'LookupRids',
    ('samr', '19'): 'OpenGroup',
    ('samr', '1'): 'Close',
    ('samr', '25'): 'QueryGroupMember',
    ('samr', '34'): 'OpenUser',
    ('samr', '36'): 'QueryUserInfo',
    ('samr', '39'): 'GetGroupsForUser',
    ('samr', '3'): 'QuerySecurity',
    ('samr', '5'): 'LookupDomain',
    ('samr', '64'): 'Connect5',
    ('samr', '6'): 'EnumDomains',
    ('samr', '7'): 'OpenDomain',
    ('samr', '8'): 'QueryDomainInfo',
    ('smb', '0x04'): 'Close (0x04)',
    ('smb', '0x24'): 'Locking AndX (0x24)',
    ('smb', '0x2e'): 'Read AndX (0x2e)',
    ('smb', '0x32'): 'Trans2 (0x32)',
    ('smb', '0x71'): 'Tree Disconnect (0x71)',
    ('smb', '0x72'): 'Negotiate Protocol (0x72)',
    ('smb', '0x73'): 'Session Setup AndX (0x73)',
    ('smb', '0x74'): 'Logoff AndX (0x74)',
    ('smb', '0x75'): 'Tree Connect AndX (0x75)',
    ('smb', '0xa2'): 'NT Create AndX (0xa2)',
    ('smb2', '0'): 'NegotiateProtocol',
    ('smb2', '11'): 'Ioctl',
    ('smb2', '14'): 'Find',
    ('smb2', '16'): 'GetInfo',
    ('smb2', '18'): 'Break',
    ('smb2', '1'): 'SessionSetup',
    ('smb2', '2'): 'SessionLogoff',
    ('smb2', '3'): 'TreeConnect',
    ('smb2', '4'): 'TreeDisconnect',
    ('smb2', '5'): 'Create',
    ('smb2', '6'): 'Close',
    ('smb2', '8'): 'Read',
    ('smb_netlogon', '0x12'): 'SAM LOGON request from client (0x12)',
    ('smb_netlogon', '0x17'): ('SAM Active Directory Response - '
                               'user unknown (0x17)'),
    ('srvsvc', '16'): 'NetShareGetInfo',
    ('srvsvc', '21'): 'NetSrvGetInfo',
}
1395
1396
def expand_short_packet(p, timestamp, src, dest, extra):
    """Rebuild a tab-separated summary line from a 'protocol:opcode'
    short packet and the supplied timestamp, endpoints and extra fields."""
    protocol, opcode = p.split(':', 1)
    fields = [timestamp,
              IP_PROTOCOLS.get(protocol, '06'),
              '',
              src,
              dest,
              protocol,
              opcode,
              OP_DESCRIPTIONS.get((protocol, opcode), '')]
    fields.extend(extra)
    return '\t'.join(fields)
1405
1406
def replay(conversations,
           host=None,
           creds=None,
           lp=None,
           accounts=None,
           dns_rate=0,
           duration=None,
           **kwargs):
    """Replay the conversations against a server, forking one child
    process per conversation.

    :param conversations: a list of Conversation objects to replay
    :param host: the server to direct the traffic at
    :param creds: credentials to use
    :param lp: the loadparm context
    :param accounts: a list of ConversationAccounts, paired one-to-one
        with the conversations
    :param dns_rate: if non-zero, add a DnsHammer generating dns packets
        at this rate
    :param duration: seconds to run for (default: one second after the
        last conversation starts)

    Remaining keyword arguments are passed to ReplayContext.
    """
    context = ReplayContext(server=host,
                            creds=creds,
                            lp=lp,
                            **kwargs)

    if len(accounts) < len(conversations):
        # use len() here -- formatting the lists themselves with %d
        # would raise a TypeError
        print(("we have %d accounts but %d conversations" %
               (len(accounts), len(conversations))), file=sys.stderr)

    cstack = list(zip(
        sorted(conversations, key=lambda x: x.start_time, reverse=True),
        accounts))

    # Set the process group so that the calling scripts are not killed
    # when the forked child processes are killed.
    os.setpgrp()

    start = time.time()

    if duration is None:
        # end 1 second after the last packet of the last conversation
        # to start. Conversations other than the last could still be
        # going, but we don't care.
        duration = cstack[0][0].packets[-1].timestamp + 1.0
        print("We will stop after %.1f seconds" % duration,
              file=sys.stderr)

    end = start + duration

    LOGGER.info("Replaying traffic for %u conversations over %d seconds"
          % (len(conversations), duration))

    children = {}
    if dns_rate:
        dns_hammer = DnsHammer(dns_rate, duration)
        cstack.append((dns_hammer, None))

    try:
        while True:
            # we spawn a batch, wait for finishers, then spawn another
            now = time.time()
            batch_end = min(now + 2.0, end)
            fork_time = 0.0
            fork_n = 0
            while cstack:
                c, account = cstack.pop()
                if c.start_time + start > batch_end:
                    # not due yet; push it back for the next batch
                    cstack.append((c, account))
                    break

                st = time.time()
                pid = c.replay_in_fork_with_delay(start, context, account)
                children[pid] = c
                t = time.time()
                elapsed = t - st
                fork_time += elapsed
                fork_n += 1
                print("forked %s in pid %s (in %fs)" % (c, pid,
                                                        elapsed),
                      file=sys.stderr)

            if fork_n:
                print(("forked %d times in %f seconds (avg %f)" %
                       (fork_n, fork_time, fork_time / fork_n)),
                      file=sys.stderr)
            elif cstack:
                debug(2, "no forks in batch ending %f" % batch_end)

            # reap finished children until the batch window closes
            while time.time() < batch_end - 1.0:
                time.sleep(0.01)
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != 10:  # no child processes
                        raise
                    break
                if pid:
                    c = children.pop(pid, None)
                    print(("process %d finished conversation %s;"
                           " %d to go" %
                           (pid, c, len(children))), file=sys.stderr)

            if time.time() >= end:
                print("time to stop", file=sys.stderr)
                break

    except Exception:
        print("EXCEPTION in parent", file=sys.stderr)
        traceback.print_exc()
    finally:
        # escalate: SIGTERM, SIGTERM again, then SIGKILL
        for s in (15, 15, 9):
            print(("killing %d children with -%d" %
                   (len(children), s)), file=sys.stderr)
            for pid in children:
                try:
                    os.kill(pid, s)
                except OSError as e:
                    if e.errno != 3:  # don't fail if it has already died
                        raise
            time.sleep(0.5)
            end = time.time() + 1
            while children:
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != 10:
                        raise
                # NOTE(review): if waitpid raised, pid still holds the
                # last value from the kill loop above -- confirm intended
                if pid != 0:
                    c = children.pop(pid, None)
                    print(("kill -%d %d KILLED conversation %s; "
                           "%d to go" %
                           (s, pid, c, len(children))),
                          file=sys.stderr)
                if time.time() >= end:
                    break

            if not children:
                break
            time.sleep(1)

        if children:
            print("%d children are missing" % len(children),
                  file=sys.stderr)

        # there may be stragglers that were forked just as ^C was hit
        # and don't appear in the list of children. We can get them
        # with killpg, but that will also kill us, so this is^H^H would be
        # goodbye, except we cheat and pretend to use ^C (SIG_INTERRUPT),
        # so as not to have to fuss around writing signal handlers.
        try:
            os.killpg(0, 2)
        except KeyboardInterrupt:
            print("ignoring fake ^C", file=sys.stderr)
1549
1550
def openLdb(host, creds, lp):
    """Open an LDAP connection to *host* and return the SamDB.

    Uses the system session and paged searches so large result sets
    can be traversed.
    """
    return SamDB(url="ldap://%s" % host,
                 session_info=system_session(),
                 options=['modules:paged_searches'],
                 credentials=creds,
                 lp=lp)
1559
1560
def ou_name(ldb, instance_id):
    """Generate an ou name from the instance id"""
    template = "ou=instance-%d,ou=traffic_replay,%s"
    return template % (instance_id, ldb.domain_dn())
1565
1566
def create_ou(ldb, instance_id):
    """Create an ou, all created user and machine accounts will belong to it.

    This allows all the created resources to be cleaned up easily.

    :param ldb: the SamDB connection
    :param instance_id: used to derive the ou name
    :return: the dn of the instance ou
    """
    ou = ou_name(ldb, instance_id)

    def add_ou(dn):
        # add the ou, ignoring "entry already exists" (LDB error 68)
        try:
            ldb.add({"dn": dn,
                     "objectclass": "organizationalunit"})
        except LdbError as e:
            (status, _) = e.args
            if status != 68:
                raise

    # parent ou=traffic_replay first, then the per-instance child
    add_ou(ou.split(',', 1)[1])
    add_ou(ou)
    return ou
1590
1591
class ConversationAccounts(object):
    """Details of the machine and user accounts associated with a conversation.
    """

    def __init__(self, netbios_name, machinepass, username, userpass):
        """
        :param netbios_name: the machine account's netbios name
        :param machinepass: the machine account password
        :param username: the user account name
        :param userpass: the user account password
        """
        self.netbios_name = netbios_name
        self.machinepass = machinepass
        self.username = username
        self.userpass = userpass
1600
1601
def generate_replay_accounts(ldb, instance_id, number, password):
    """Generate a series of unique machine and user account names."""
    # make sure the accounts exist on the server first
    generate_traffic_accounts(ldb, instance_id, number, password)
    return [
        ConversationAccounts("STGM-%d-%d" % (instance_id, i),
                             password,
                             "STGU-%d-%d" % (instance_id, i),
                             password)
        for i in range(1, number + 1)
    ]
1615
1616
def generate_traffic_accounts(ldb, instance_id, number, password):
    """Create the specified number of user and machine accounts.

    As accounts are not explicitly deleted between runs. This function starts
    with the last account and iterates backwards stopping either when it
    finds an already existing account or it has generated all the required
    accounts.
    """
    print(("Generating machine and conversation accounts, "
           "as required for %d conversations" % number),
          file=sys.stderr)

    def _populate(name_format, create, progress_msg, summary_msg):
        # Walk from the highest account number downwards; hitting an
        # existing account (error 68) means the rest already exist.
        created = 0
        for i in range(number, 0, -1):
            try:
                create(ldb, instance_id,
                       name_format % (instance_id, i), password)
                created += 1
                if created % 50 == 0:
                    LOGGER.info(progress_msg % (created, number))
            except LdbError as e:
                (status, _) = e.args
                if status != 68:
                    raise
                break
        if created > 0:
            LOGGER.info(summary_msg % created)

    _populate("STGM-%d-%d", create_machine_account,
              "Created %u/%u machine accounts",
              "Added %d new machine accounts")
    _populate("STGU-%d-%d", create_user_account,
              "Created %u/%u users",
              "Added %d new user accounts")
1663
1664
def create_machine_account(ldb, instance_id, netbios_name, machinepass):
    """Create a machine account via ldap."""
    dn = "cn=%s,%s" % (netbios_name, ou_name(ldb, instance_id))
    # AD stores the password as a UTF-16-LE encoded, quoted string
    password_utf16 = ('"%s"' % get_string(machinepass)).encode('utf-16-le')
    account_flags = UF_TRUSTED_FOR_DELEGATION | UF_SERVER_TRUST_ACCOUNT

    ldb.add({
        "dn": dn,
        "objectclass": "computer",
        "sAMAccountName": "%s$" % netbios_name,
        "userAccountControl": str(account_flags),
        "unicodePwd": password_utf16})
1679
1680
def create_user_account(ldb, instance_id, username, userpass):
    """Create a user account via ldap."""
    user_dn = "cn=%s,%s" % (username, ou_name(ldb, instance_id))
    # AD stores the password as a UTF-16-LE encoded, quoted string
    password_utf16 = ('"%s"' % get_string(userpass)).encode('utf-16-le')
    ldb.add({
        "dn": user_dn,
        "objectclass": "user",
        "sAMAccountName": username,
        "userAccountControl": str(UF_NORMAL_ACCOUNT),
        "unicodePwd": password_utf16
    })

    # grant user write permission to do things like write account SPN
    sd_utils.SDUtils(ldb).dacl_add_ace(user_dn, "(A;;WP;;;PS)")
1697
1698
def create_group(ldb, instance_id, name):
    """Create a group via ldap."""
    group_dn = "cn=%s,%s" % (name, ou_name(ldb, instance_id))
    ldb.add({
        "dn": group_dn,
        "objectclass": "group",
        "sAMAccountName": name,
    })
1709
1710
def user_name(instance_id, i):
    """Return the name of the i'th traffic-replay user for this instance."""
    return "STGU-%d-%d" % (instance_id, i)
1714
1715
def search_objectclass(ldb, objectclass='user', attr='sAMAccountName'):
    """Search for the given objectclass, returning attr values as a set."""
    results = ldb.search(
        expression="(objectClass={})".format(objectclass),
        attrs=[attr]
    )
    return set(str(result[attr]) for result in results)
1723
1724
def generate_users(ldb, instance_id, number, password):
    """Add users to the server"""
    existing_objects = search_objectclass(ldb, objectclass='user')
    created = 0
    # iterate backwards, matching the account-creation convention used
    # elsewhere in this module
    for i in range(number, 0, -1):
        candidate = user_name(instance_id, i)
        if candidate in existing_objects:
            continue
        create_user_account(ldb, instance_id, candidate, password)
        created += 1
        if created % 50 == 0:
            LOGGER.info("Created %u/%u users" % (created, number))

    return created
1738
1739
def group_name(instance_id, i):
    """Return the name of the i'th traffic-replay group for this instance."""
    return "STGG-%d-%d" % (instance_id, i)
1743
1744
def generate_groups(ldb, instance_id, number):
    """Create the required number of groups on the server."""
    existing_objects = search_objectclass(ldb, objectclass='group')
    created = 0
    for i in range(number, 0, -1):
        candidate = group_name(instance_id, i)
        if candidate in existing_objects:
            continue
        create_group(ldb, instance_id, candidate)
        created += 1
        if created % 1000 == 0:
            LOGGER.info("Created %u/%u groups" % (created, number))

    return created
1758
1759
def clean_up_accounts(ldb, instance_id):
    """Remove the created accounts and groups from the server."""
    try:
        # deleting the instance OU recursively removes everything in it
        ldb.delete(ou_name(ldb, instance_id), ["tree_delete:1"])
    except LdbError as e:
        (status, _) = e.args
        # 32 == "no such object"; nothing to clean up in that case
        if status != 32:
            raise
1770
1771
def generate_users_and_groups(ldb, instance_id, password,
                              number_of_users, number_of_groups,
                              group_memberships):
    """Generate the required users and groups, allocating the users to
       those groups."""
    groups_added = 0
    memberships_added = 0

    create_ou(ldb, instance_id)

    LOGGER.info("Generating dummy user accounts")
    users_added = generate_users(ldb, instance_id, number_of_users, password)

    if number_of_groups > 0:
        LOGGER.info("Generating dummy groups")
        groups_added = generate_groups(ldb, instance_id, number_of_groups)

    if group_memberships > 0:
        # pick a random (weighted) allocation of users to groups, then
        # apply it to the DB
        LOGGER.info("Assigning users to groups")
        assignments = GroupAssignments(number_of_groups, groups_added,
                                       number_of_users, users_added,
                                       group_memberships)
        LOGGER.info("Adding users to groups")
        add_users_to_groups(ldb, instance_id, assignments)
        memberships_added = assignments.total()

    if (groups_added > 0 and users_added == 0 and
       number_of_groups != groups_added):
        LOGGER.warning("The added groups will contain no members")

    LOGGER.info("Added %d users, %d groups and %d group memberships" %
                (users_added, groups_added, memberships_added))
1806
1807
1808 class GroupAssignments(object):
1809     def __init__(self, number_of_groups, groups_added, number_of_users,
1810                  users_added, group_memberships):
1811
1812         self.count = 0
1813         self.generate_group_distribution(number_of_groups)
1814         self.generate_user_distribution(number_of_users, group_memberships)
1815         self.assignments = self.assign_groups(number_of_groups,
1816                                               groups_added,
1817                                               number_of_users,
1818                                               users_added,
1819                                               group_memberships)
1820
1821     def cumulative_distribution(self, weights):
1822         # make sure the probabilities conform to a cumulative distribution
1823         # spread between 0.0 and 1.0. Dividing by the weighted total gives each
1824         # probability a proportional share of 1.0. Higher probabilities get a
1825         # bigger share, so are more likely to be picked. We use the cumulative
1826         # value, so we can use random.random() as a simple index into the list
1827         dist = []
1828         total = sum(weights)
1829         cumulative = 0.0
1830         for probability in weights:
1831             cumulative += probability
1832             dist.append(cumulative / total)
1833         return dist
1834
1835     def generate_user_distribution(self, num_users, num_memberships):
1836         """Probability distribution of a user belonging to a group.
1837         """
1838         # Assign a weighted probability to each user. Use the Pareto
1839         # Distribution so that some users are in a lot of groups, and the
1840         # bulk of users are in only a few groups. If we're assigning a large
1841         # number of group memberships, use a higher shape. This means slightly
1842         # fewer outlying users that are in large numbers of groups. The aim is
1843         # to have no users belonging to more than ~500 groups.
1844         if num_memberships > 5000000:
1845             shape = 3.0
1846         elif num_memberships > 2000000:
1847             shape = 2.5
1848         elif num_memberships > 300000:
1849             shape = 2.25
1850         else:
1851             shape = 1.75
1852
1853         weights = []
1854         for x in range(1, num_users + 1):
1855             p = random.paretovariate(shape)
1856             weights.append(p)
1857
1858         # convert the weights to a cumulative distribution between 0.0 and 1.0
1859         self.user_dist = self.cumulative_distribution(weights)
1860
1861     def generate_group_distribution(self, n):
1862         """Probability distribution of a group containing a user."""
1863
1864         # Assign a weighted probability to each user. Probability decreases
1865         # as the group-ID increases
1866         weights = []
1867         for x in range(1, n + 1):
1868             p = 1 / (x**1.3)
1869             weights.append(p)
1870
1871         # convert the weights to a cumulative distribution between 0.0 and 1.0
1872         self.group_dist = self.cumulative_distribution(weights)
1873
1874     def generate_random_membership(self):
1875         """Returns a randomly generated user-group membership"""
1876
1877         # the list items are cumulative distribution values between 0.0 and
1878         # 1.0, which makes random() a handy way to index the list to get a
1879         # weighted random user/group. (Here the user/group returned are
1880         # zero-based array indexes)
1881         user = bisect.bisect(self.user_dist, random.random())
1882         group = bisect.bisect(self.group_dist, random.random())
1883
1884         return user, group
1885
1886     def users_in_group(self, group):
1887         return self.assignments[group]
1888
1889     def get_groups(self):
1890         return self.assignments.keys()
1891
1892     def assign_groups(self, number_of_groups, groups_added,
1893                       number_of_users, users_added, group_memberships):
1894         """Allocate users to groups.
1895
1896         The intention is to have a few users that belong to most groups, while
1897         the majority of users belong to a few groups.
1898
1899         A few groups will contain most users, with the remaining only having a
1900         few users.
1901         """
1902
1903         assignments = set()
1904         if group_memberships <= 0:
1905             return {}
1906
1907         # Calculate the number of group menberships required
1908         group_memberships = math.ceil(
1909             float(group_memberships) *
1910             (float(users_added) / float(number_of_users)))
1911
1912         existing_users  = number_of_users  - users_added  - 1
1913         existing_groups = number_of_groups - groups_added - 1
1914         while len(assignments) < group_memberships:
1915             user, group = self.generate_random_membership()
1916
1917             if group > existing_groups or user > existing_users:
1918                 # the + 1 converts the array index to the corresponding
1919                 # group or user number
1920                 assignments.add(((user + 1), (group + 1)))
1921
1922         # convert the set into a dictionary, where key=group, value=list-of-
1923         # users-in-group (indexing by group-ID allows us to optimize for
1924         # DB membership writes)
1925         assignment_dict = defaultdict(list)
1926         for (user, group) in assignments:
1927             assignment_dict[group].append(user)
1928             self.count += 1
1929
1930         return assignment_dict
1931
1932     def total(self):
1933         return self.count
1934
1935
def add_users_to_groups(db, instance_id, assignments):
    """Takes the assignments of users to groups and applies them to the DB."""

    total = assignments.total()
    batches = 0
    memberships_written = 0

    for group in assignments.get_groups():
        members = assignments.users_in_group(group)
        if len(members) == 0:
            continue

        # Write at most 1000 members per modify. Batching keeps the number
        # of DB modifies down, but a single modify with 10K+ users becomes
        # inefficient memory-wise.
        for start in range(0, len(members), 1000):
            batch = members[start:start + 1000]
            add_group_members(db, instance_id, group, batch)

            memberships_written += len(batch)
            batches += 1
            if batches % 50 == 0:
                LOGGER.info("Added %u/%u memberships" % (memberships_written,
                                                         total))
1959
def add_group_members(db, instance_id, group, users_in_group):
    """Adds the given users to group specified."""

    ou = ou_name(db, instance_id)

    def dn_of(name):
        # all generated accounts live directly under the instance OU
        return("cn=%s,%s" % (name, ou))

    # build a single modify message adding every user as a "member"
    msg = ldb.Message()
    msg.dn = ldb.Dn(db, dn_of(group_name(instance_id, group)))

    for user in users_in_group:
        element = "member-" + str(user)
        msg[element] = ldb.MessageElement(dn_of(user_name(instance_id, user)),
                                          ldb.FLAG_MOD_ADD, "member")

    db.modify(msg)
1978
1979
def generate_stats(statsdir, timing_file):
    """Generate and print the summary stats for a run.

    Reads every file in statsdir as tab-separated stats lines, optionally
    echoing valid lines to timing_file, and prints totals, rates and
    per-operation latency percentiles.
    """
    first      = sys.float_info.max
    last       = 0
    successful = 0
    failed     = 0
    latencies  = {}
    failures   = {}
    unique_converations = set()
    conversations = 0

    # tw() echoes valid lines to the timing file, or is a no-op if none given
    if timing_file is not None:
        tw = timing_file.write
    else:
        def tw(x):
            pass

    tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n")

    for filename in os.listdir(statsdir):
        path = os.path.join(statsdir, filename)
        with open(path, 'r') as f:
            for line in f:
                try:
                    fields       = line.rstrip('\n').split('\t')
                    conversation = fields[1]
                    protocol     = fields[2]
                    packet_type  = fields[3]
                    latency      = float(fields[4])
                    # fields[0] is the completion time; subtracting the
                    # latency gives the start time of the operation
                    first        = min(float(fields[0]) - latency, first)
                    last         = max(float(fields[0]), last)

                    latencies.setdefault(protocol, {}).setdefault(
                        packet_type, []).append(latency)
                    failures.setdefault(protocol, {}).setdefault(
                        packet_type, 0)

                    if fields[5] == 'True':
                        successful += 1
                    else:
                        failed += 1
                        failures[protocol][packet_type] += 1

                    if conversation not in unique_converations:
                        unique_converations.add(conversation)
                        conversations += 1

                    tw(line)
                except (ValueError, IndexError):
                    # not a valid line print and ignore
                    print(line, file=sys.stderr)

    duration = last - first
    # Guard duration <= 0 (e.g. no valid stats lines were found), which
    # would otherwise raise ZeroDivisionError or produce nonsense rates.
    if successful == 0 or duration <= 0:
        success_rate = 0
    else:
        success_rate = successful / duration
    if failed == 0 or duration <= 0:
        failure_rate = 0
    else:
        failure_rate = failed / duration

    print("Total conversations:   %10d" % conversations)
    print("Successful operations: %10d (%.3f per second)"
          % (successful, success_rate))
    print("Failed operations:     %10d (%.3f per second)"
          % (failed, failure_rate))

    print("Protocol    Op Code  Description                               "
          " Count       Failed         Mean       Median          "
          "95%        Range          Max")

    protocols = sorted(latencies.keys())
    for protocol in protocols:
        packet_types = sorted(latencies[protocol], key=opcode_key)
        for packet_type in packet_types:
            values     = sorted(latencies[protocol][packet_type])
            count      = len(values)
            failed     = failures[protocol][packet_type]
            mean       = sum(values) / count
            median     = calc_percentile(values, 0.50)
            percentile = calc_percentile(values, 0.95)
            rng        = values[-1] - values[0]
            maxv       = values[-1]
            desc       = OP_DESCRIPTIONS.get((protocol, packet_type), '')
            # BUG FIX: isatty is a method; the original tested the bound
            # method object (always truthy), so the tab-separated branch
            # below was unreachable.
            if sys.stdout.isatty():
                # human-readable aligned columns for terminals
                print("%-12s   %4s  %-35s %12d %12d %12.6f "
                      "%12.6f %12.6f %12.6f %12.6f"
                      % (protocol,
                         packet_type,
                         desc,
                         count,
                         failed,
                         mean,
                         median,
                         percentile,
                         rng,
                         maxv))
            else:
                # machine-parseable tab-separated output for pipes/files
                print("%s\t%s\t%s\t%d\t%d\t%f\t%f\t%f\t%f\t%f"
                      % (protocol,
                         packet_type,
                         desc,
                         count,
                         failed,
                         mean,
                         median,
                         percentile,
                         rng,
                         maxv))
2098
2099
def opcode_key(v):
    """Sort key for the operation code to ensure that it sorts numerically.

    Numeric opcodes are zero-padded to three digits so that lexical
    ordering matches numeric ordering; non-numeric opcodes sort as-is.
    """
    try:
        return "%03d" % int(v)
    except (ValueError, TypeError):
        # BUG FIX: was a bare except, which would also swallow
        # KeyboardInterrupt/SystemExit; only conversion failures are expected.
        return v
2106
2107
def calc_percentile(values, percentile):
    """Return the requested percentile of a list of values.

    Uses linear interpolation between the two nearest ranks. The input
    list must already be sorted in ascending order.
    """
    if not values:
        return 0
    rank = (len(values) - 1) * percentile
    lower = math.floor(rank)
    upper = math.ceil(rank)
    if lower == upper:
        # the rank is a whole number: no interpolation required
        return values[int(rank)]
    # interpolate between the two neighbouring values
    return (values[int(lower)] * (upper - rank) +
            values[int(upper)] * (rank - lower))
2124
2125
def mk_masked_dir(*path):
    """Create a directory with a restrictive (0700) mode.

    In a testenv we would otherwise end up with 0777 directories that look
    an alarming green colour with ls; temporarily tightening the umask
    avoids that.
    """
    target = os.path.join(*path)
    saved_umask = os.umask(0o077)
    os.mkdir(target)
    os.umask(saved_umask)
    return target