e11795879f27ac4c10cee09bc9103c82b03074e0
[metze/samba/wip.git] / python / samba / emulate / traffic.py
1 # -*- encoding: utf-8 -*-
2 # Samba traffic replay and learning
3 #
4 # Copyright (C) Catalyst IT Ltd. 2017
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 #
19 from __future__ import print_function, division
20
21 import time
22 import os
23 import random
24 import json
25 import math
26 import sys
27 import signal
28 from errno import ECHILD, ESRCH
29
30 from collections import OrderedDict, Counter, defaultdict, namedtuple
31 from samba.emulate import traffic_packets
32 from samba.samdb import SamDB
33 import ldb
34 from ldb import LdbError
35 from samba.dcerpc import ClientConnection
36 from samba.dcerpc import security, drsuapi, lsa
37 from samba.dcerpc import netlogon
38 from samba.dcerpc.netlogon import netr_Authenticator
39 from samba.dcerpc import srvsvc
40 from samba.dcerpc import samr
41 from samba.drs_utils import drs_DsBind
42 import traceback
43 from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
44 from samba.auth import system_session
45 from samba.dsdb import (
46     UF_NORMAL_ACCOUNT,
47     UF_SERVER_TRUST_ACCOUNT,
48     UF_TRUSTED_FOR_DELEGATION,
49     UF_WORKSTATION_TRUST_ACCOUNT
50 )
51 from samba.dcerpc.misc import SEC_CHAN_BDC
52 from samba import gensec
53 from samba import sd_utils
54 from samba.compat import get_string
55 from samba.logger import get_samba_logger
56 import bisect
57
58 CURRENT_MODEL_VERSION = 2   # save as this
59 REQUIRED_MODEL_VERSION = 2  # load accepts this or greater
60 SLEEP_OVERHEAD = 3e-4
61
62 # we don't use None, because it complicates [de]serialisation
63 NON_PACKET = '-'
64
65 CLIENT_CLUES = {
66     ('dns', '0'): 1.0,      # query
67     ('smb', '0x72'): 1.0,   # Negotiate protocol
68     ('ldap', '0'): 1.0,     # bind
69     ('ldap', '3'): 1.0,     # searchRequest
70     ('ldap', '2'): 1.0,     # unbindRequest
71     ('cldap', '3'): 1.0,
72     ('dcerpc', '11'): 1.0,  # bind
73     ('dcerpc', '14'): 1.0,  # Alter_context
74     ('nbns', '0'): 1.0,     # query
75 }
76
77 SERVER_CLUES = {
78     ('dns', '1'): 1.0,      # response
79     ('ldap', '1'): 1.0,     # bind response
80     ('ldap', '4'): 1.0,     # search result
81     ('ldap', '5'): 1.0,     # search done
82     ('cldap', '5'): 1.0,
83     ('dcerpc', '12'): 1.0,  # bind_ack
84     ('dcerpc', '13'): 1.0,  # bind_nak
85     ('dcerpc', '15'): 1.0,  # Alter_context response
86 }
87
88 SKIPPED_PROTOCOLS = {"smb", "smb2", "browser", "smb_netlogon"}
89
90 WAIT_SCALE = 10.0
91 WAIT_THRESHOLD = (1.0 / WAIT_SCALE)
92 NO_WAIT_LOG_TIME_RANGE = (-10, -3)
93
94 # DEBUG_LEVEL can be changed by scripts with -d
95 DEBUG_LEVEL = 0
96
97 LOGGER = get_samba_logger(name=__name__)
98
99
100 def debug(level, msg, *args):
101     """Print a formatted debug message to standard error.
102
103
104     :param level: The debug level, message will be printed if it is <= the
105                   currently set debug level. The debug level can be set with
106                   the -d option.
107     :param msg:   The message to be logged, can contain C-Style format
108                   specifiers
109     :param args:  The parameters required by the format specifiers
110     """
111     if level <= DEBUG_LEVEL:
112         if not args:
113             print(msg, file=sys.stderr)
114         else:
115             print(msg % tuple(args), file=sys.stderr)
116
117
118 def debug_lineno(*args):
119     """ Print an unformatted log message to stderr, contaning the line number
120     """
121     tb = traceback.extract_stack(limit=2)
122     print((" %s:" "\033[01;33m"
123            "%s " "\033[00m" % (tb[0][2], tb[0][1])), end=' ',
124           file=sys.stderr)
125     for a in args:
126         print(a, file=sys.stderr)
127     print(file=sys.stderr)
128     sys.stderr.flush()
129
130
131 def random_colour_print(seeds):
132     """Return a function that prints a coloured line to stderr. The colour
133     of the line depends on a sort of hash of the integer arguments."""
134     if seeds:
135         s = 214
136         for x in seeds:
137             s += 17
138             s *= x
139             s %= 214
140         prefix = "\033[38;5;%dm" % (18 + s)
141
142         def p(*args):
143             if DEBUG_LEVEL > 0:
144                 for a in args:
145                     print("%s%s\033[00m" % (prefix, a), file=sys.stderr)
146     else:
147         def p(*args):
148             if DEBUG_LEVEL > 0:
149                 for a in args:
150                     print(a, file=sys.stderr)
151
152     return p
153
154
155 class FakePacketError(Exception):
156     pass
157
158
159 class Packet(object):
160     """Details of a network packet"""
161     __slots__ = ('timestamp',
162                  'ip_protocol',
163                  'stream_number',
164                  'src',
165                  'dest',
166                  'protocol',
167                  'opcode',
168                  'desc',
169                  'extra',
170                  'endpoints')
171     def __init__(self, timestamp, ip_protocol, stream_number, src, dest,
172                  protocol, opcode, desc, extra):
173         self.timestamp = timestamp
174         self.ip_protocol = ip_protocol
175         self.stream_number = stream_number
176         self.src = src
177         self.dest = dest
178         self.protocol = protocol
179         self.opcode = opcode
180         self.desc = desc
181         self.extra = extra
182         if self.src < self.dest:
183             self.endpoints = (self.src, self.dest)
184         else:
185             self.endpoints = (self.dest, self.src)
186
187     @classmethod
188     def from_line(cls, line):
189         fields = line.rstrip('\n').split('\t')
190         (timestamp,
191          ip_protocol,
192          stream_number,
193          src,
194          dest,
195          protocol,
196          opcode,
197          desc) = fields[:8]
198         extra = fields[8:]
199
200         timestamp = float(timestamp)
201         src = int(src)
202         dest = int(dest)
203
204         return cls(timestamp, ip_protocol, stream_number, src, dest,
205                    protocol, opcode, desc, extra)
206
207     def as_summary(self, time_offset=0.0):
208         """Format the packet as a traffic_summary line.
209         """
210         extra = '\t'.join(self.extra)
211         t = self.timestamp + time_offset
212         return (t, '%f\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' %
213                 (t,
214                  self.ip_protocol,
215                  self.stream_number or '',
216                  self.src,
217                  self.dest,
218                  self.protocol,
219                  self.opcode,
220                  self.desc,
221                  extra))
222
223     def __str__(self):
224         return ("%.3f: %d -> %d; ip %s; strm %s; prot %s; op %s; desc %s %s" %
225                 (self.timestamp, self.src, self.dest, self.ip_protocol or '-',
226                  self.stream_number, self.protocol, self.opcode, self.desc,
227                  ('«' + ' '.join(self.extra) + '»' if self.extra else '')))
228
229     def __repr__(self):
230         return "<Packet @%s>" % self
231
232     def copy(self):
233         return self.__class__(self.timestamp,
234                               self.ip_protocol,
235                               self.stream_number,
236                               self.src,
237                               self.dest,
238                               self.protocol,
239                               self.opcode,
240                               self.desc,
241                               self.extra)
242
243     def as_packet_type(self):
244         t = '%s:%s' % (self.protocol, self.opcode)
245         return t
246
247     def client_score(self):
248         """A positive number means we think it is a client; a negative number
249         means we think it is a server. Zero means no idea. range: -1 to 1.
250         """
251         key = (self.protocol, self.opcode)
252         if key in CLIENT_CLUES:
253             return CLIENT_CLUES[key]
254         if key in SERVER_CLUES:
255             return -SERVER_CLUES[key]
256         return 0.0
257
258     def play(self, conversation, context):
259         """Send the packet over the network, if required.
260
261         Some packets are ignored, i.e. for  protocols not handled,
262         server response messages, or messages that are generated by the
263         protocol layer associated with other packets.
264         """
265         fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
266         try:
267             fn = getattr(traffic_packets, fn_name)
268
269         except AttributeError as e:
270             print("Conversation(%s) Missing handler %s" %
271                   (conversation.conversation_id, fn_name),
272                   file=sys.stderr)
273             return
274
275         # Don't display a message for kerberos packets, they're not directly
276         # generated they're used to indicate kerberos should be used
277         if self.protocol != "kerberos":
278             debug(2, "Conversation(%s) Calling handler %s" %
279                      (conversation.conversation_id, fn_name))
280
281         start = time.time()
282         try:
283             if fn(self, conversation, context):
284                 # Only collect timing data for functions that generate
285                 # network traffic, or fail
286                 end = time.time()
287                 duration = end - start
288                 print("%f\t%s\t%s\t%s\t%f\tTrue\t" %
289                       (end, conversation.conversation_id, self.protocol,
290                        self.opcode, duration))
291         except Exception as e:
292             end = time.time()
293             duration = end - start
294             print("%f\t%s\t%s\t%s\t%f\tFalse\t%s" %
295                   (end, conversation.conversation_id, self.protocol,
296                    self.opcode, duration, e))
297
298     def __cmp__(self, other):
299         return self.timestamp - other.timestamp
300
301     def is_really_a_packet(self, missing_packet_stats=None):
302         return is_a_real_packet(self.protocol, self.opcode)
303
304
305 def is_a_real_packet(protocol, opcode):
306     """Is the packet one that can be ignored?
307
308     If so removing it will have no effect on the replay
309     """
310     if protocol in SKIPPED_PROTOCOLS:
311         # Ignore any packets for the protocols we're not interested in.
312         return False
313     if protocol == "ldap" and opcode == '':
314         # skip ldap continuation packets
315         return False
316
317     fn_name = 'packet_%s_%s' % (protocol, opcode)
318     fn = getattr(traffic_packets, fn_name, None)
319     if fn is None:
320         LOGGER.debug("missing packet %s" % fn_name, file=sys.stderr)
321         return False
322     if fn is traffic_packets.null_packet:
323         return False
324     return True
325
326
327 def is_a_traffic_generating_packet(protocol, opcode):
328     """Return true if a packet generates traffic in its own right. Some of
329     these will generate traffic in certain contexts (e.g. ldap unbind
330     after a bind) but not if the conversation consists only of these packets.
331     """
332     if protocol == 'wait':
333         return False
334
335     if (protocol, opcode) in (
336             ('kerberos', ''),
337             ('ldap', '2'),
338             ('dcerpc', '15'),
339             ('dcerpc', '16')):
340         return False
341
342     return is_a_real_packet(protocol, opcode)
343
344
345 class ReplayContext(object):
346     """State/Context for a conversation between an simulated client and a
347        server. Some of the context is shared amongst all conversations
348        and should be generated before the fork, while other context is
349        specific to a particular conversation and should be generated
350        *after* the fork, in generate_process_local_config().
351     """
352     def __init__(self,
353                  server=None,
354                  lp=None,
355                  creds=None,
356                  badpassword_frequency=None,
357                  prefer_kerberos=None,
358                  tempdir=None,
359                  statsdir=None,
360                  ou=None,
361                  base_dn=None,
362                  domain=os.environ.get("DOMAIN"),
363                  domain_sid=None):
364         self.server                   = server
365         self.netlogon_connection      = None
366         self.creds                    = creds
367         self.lp                       = lp
368         self.prefer_kerberos          = prefer_kerberos
369         self.ou                       = ou
370         self.base_dn                  = base_dn
371         self.domain                   = domain
372         self.statsdir                 = statsdir
373         self.global_tempdir           = tempdir
374         self.domain_sid               = domain_sid
375         self.realm                    = lp.get('realm')
376
377         # Bad password attempt controls
378         self.badpassword_frequency    = badpassword_frequency
379         self.last_lsarpc_bad          = False
380         self.last_lsarpc_named_bad    = False
381         self.last_simple_bind_bad     = False
382         self.last_bind_bad            = False
383         self.last_srvsvc_bad          = False
384         self.last_drsuapi_bad         = False
385         self.last_netlogon_bad        = False
386         self.last_samlogon_bad        = False
387         self.generate_ldap_search_tables()
388
389     def generate_ldap_search_tables(self):
390         session = system_session()
391
392         db = SamDB(url="ldap://%s" % self.server,
393                    session_info=session,
394                    credentials=self.creds,
395                    lp=self.lp)
396
397         res = db.search(db.domain_dn(),
398                         scope=ldb.SCOPE_SUBTREE,
399                         controls=["paged_results:1:1000"],
400                         attrs=['dn'])
401
402         # find a list of dns for each pattern
403         # e.g. CN,CN,CN,DC,DC
404         dn_map = {}
405         attribute_clue_map = {
406             'invocationId': []
407         }
408
409         for r in res:
410             dn = str(r.dn)
411             pattern = ','.join(x.lstrip()[:2] for x in dn.split(',')).upper()
412             dns = dn_map.setdefault(pattern, [])
413             dns.append(dn)
414             if dn.startswith('CN=NTDS Settings,'):
415                 attribute_clue_map['invocationId'].append(dn)
416
417         # extend the map in case we are working with a different
418         # number of DC components.
419         # for k, v in self.dn_map.items():
420         #     print >>sys.stderr, k, len(v)
421
422         for k in list(dn_map.keys()):
423             if k[-3:] != ',DC':
424                 continue
425             p = k[:-3]
426             while p[-3:] == ',DC':
427                 p = p[:-3]
428             for i in range(5):
429                 p += ',DC'
430                 if p != k and p in dn_map:
431                     print('dn_map collison %s %s' % (k, p),
432                           file=sys.stderr)
433                     continue
434                 dn_map[p] = dn_map[k]
435
436         self.dn_map = dn_map
437         self.attribute_clue_map = attribute_clue_map
438
439     def generate_process_local_config(self, account, conversation):
440         self.ldap_connections         = []
441         self.dcerpc_connections       = []
442         self.lsarpc_connections       = []
443         self.lsarpc_connections_named = []
444         self.drsuapi_connections      = []
445         self.srvsvc_connections       = []
446         self.samr_contexts            = []
447         self.netbios_name             = account.netbios_name
448         self.machinepass              = account.machinepass
449         self.username                 = account.username
450         self.userpass                 = account.userpass
451
452         self.tempdir = mk_masked_dir(self.global_tempdir,
453                                      'conversation-%d' %
454                                      conversation.conversation_id)
455
456         self.lp.set("private dir", self.tempdir)
457         self.lp.set("lock dir", self.tempdir)
458         self.lp.set("state directory", self.tempdir)
459         self.lp.set("tls verify peer", "no_check")
460
461         self.remoteAddress = "/root/ncalrpc_as_system"
462         self.samlogon_dn   = ("cn=%s,%s" %
463                               (self.netbios_name, self.ou))
464         self.user_dn       = ("cn=%s,%s" %
465                               (self.username, self.ou))
466
467         self.generate_machine_creds()
468         self.generate_user_creds()
469
470     def with_random_bad_credentials(self, f, good, bad, failed_last_time):
471         """Execute the supplied logon function, randomly choosing the
472            bad credentials.
473
474            Based on the frequency in badpassword_frequency randomly perform the
475            function with the supplied bad credentials.
476            If run with bad credentials, the function is re-run with the good
477            credentials.
478            failed_last_time is used to prevent consecutive bad credential
479            attempts. So the over all bad credential frequency will be lower
480            than that requested, but not significantly.
481         """
482         if not failed_last_time:
483             if (self.badpassword_frequency and
484                 random.random() < self.badpassword_frequency):
485                 try:
486                     f(bad)
487                 except:
488                     # Ignore any exceptions as the operation may fail
489                     # as it's being performed with bad credentials
490                     pass
491                 failed_last_time = True
492             else:
493                 failed_last_time = False
494
495         result = f(good)
496         return (result, failed_last_time)
497
498     def generate_user_creds(self):
499         """Generate the conversation specific user Credentials.
500
501         Each Conversation has an associated user account used to simulate
502         any non Administrative user traffic.
503
504         Generates user credentials with good and bad passwords and ldap
505         simple bind credentials with good and bad passwords.
506         """
507         self.user_creds = Credentials()
508         self.user_creds.guess(self.lp)
509         self.user_creds.set_workstation(self.netbios_name)
510         self.user_creds.set_password(self.userpass)
511         self.user_creds.set_username(self.username)
512         self.user_creds.set_domain(self.domain)
513         if self.prefer_kerberos:
514             self.user_creds.set_kerberos_state(MUST_USE_KERBEROS)
515         else:
516             self.user_creds.set_kerberos_state(DONT_USE_KERBEROS)
517
518         self.user_creds_bad = Credentials()
519         self.user_creds_bad.guess(self.lp)
520         self.user_creds_bad.set_workstation(self.netbios_name)
521         self.user_creds_bad.set_password(self.userpass[:-4])
522         self.user_creds_bad.set_username(self.username)
523         if self.prefer_kerberos:
524             self.user_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
525         else:
526             self.user_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
527
528         # Credentials for ldap simple bind.
529         self.simple_bind_creds = Credentials()
530         self.simple_bind_creds.guess(self.lp)
531         self.simple_bind_creds.set_workstation(self.netbios_name)
532         self.simple_bind_creds.set_password(self.userpass)
533         self.simple_bind_creds.set_username(self.username)
534         self.simple_bind_creds.set_gensec_features(
535             self.simple_bind_creds.get_gensec_features() | gensec.FEATURE_SEAL)
536         if self.prefer_kerberos:
537             self.simple_bind_creds.set_kerberos_state(MUST_USE_KERBEROS)
538         else:
539             self.simple_bind_creds.set_kerberos_state(DONT_USE_KERBEROS)
540         self.simple_bind_creds.set_bind_dn(self.user_dn)
541
542         self.simple_bind_creds_bad = Credentials()
543         self.simple_bind_creds_bad.guess(self.lp)
544         self.simple_bind_creds_bad.set_workstation(self.netbios_name)
545         self.simple_bind_creds_bad.set_password(self.userpass[:-4])
546         self.simple_bind_creds_bad.set_username(self.username)
547         self.simple_bind_creds_bad.set_gensec_features(
548             self.simple_bind_creds_bad.get_gensec_features() |
549             gensec.FEATURE_SEAL)
550         if self.prefer_kerberos:
551             self.simple_bind_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
552         else:
553             self.simple_bind_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
554         self.simple_bind_creds_bad.set_bind_dn(self.user_dn)
555
556     def generate_machine_creds(self):
557         """Generate the conversation specific machine Credentials.
558
559         Each Conversation has an associated machine account.
560
561         Generates machine credentials with good and bad passwords.
562         """
563
564         self.machine_creds = Credentials()
565         self.machine_creds.guess(self.lp)
566         self.machine_creds.set_workstation(self.netbios_name)
567         self.machine_creds.set_secure_channel_type(SEC_CHAN_BDC)
568         self.machine_creds.set_password(self.machinepass)
569         self.machine_creds.set_username(self.netbios_name + "$")
570         self.machine_creds.set_domain(self.domain)
571         if self.prefer_kerberos:
572             self.machine_creds.set_kerberos_state(MUST_USE_KERBEROS)
573         else:
574             self.machine_creds.set_kerberos_state(DONT_USE_KERBEROS)
575
576         self.machine_creds_bad = Credentials()
577         self.machine_creds_bad.guess(self.lp)
578         self.machine_creds_bad.set_workstation(self.netbios_name)
579         self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_BDC)
580         self.machine_creds_bad.set_password(self.machinepass[:-4])
581         self.machine_creds_bad.set_username(self.netbios_name + "$")
582         if self.prefer_kerberos:
583             self.machine_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
584         else:
585             self.machine_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
586
587     def get_matching_dn(self, pattern, attributes=None):
588         # If the pattern is an empty string, we assume ROOTDSE,
589         # Otherwise we try adding or removing DC suffixes, then
590         # shorter leading patterns until we hit one.
591         # e.g if there is no CN,CN,CN,CN,DC,DC
592         # we first try       CN,CN,CN,CN,DC
593         # and                CN,CN,CN,CN,DC,DC,DC
594         # then change to        CN,CN,CN,DC,DC
595         # and as last resort we use the base_dn
596         attr_clue = self.attribute_clue_map.get(attributes)
597         if attr_clue:
598             return random.choice(attr_clue)
599
600         pattern = pattern.upper()
601         while pattern:
602             if pattern in self.dn_map:
603                 return random.choice(self.dn_map[pattern])
604             # chop one off the front and try it all again.
605             pattern = pattern[3:]
606
607         return self.base_dn
608
609     def get_dcerpc_connection(self, new=False):
610         guid = '12345678-1234-abcd-ef00-01234567cffb'  # RPC_NETLOGON UUID
611         if self.dcerpc_connections and not new:
612             return self.dcerpc_connections[-1]
613         c = ClientConnection("ncacn_ip_tcp:%s" % self.server,
614                              (guid, 1), self.lp)
615         self.dcerpc_connections.append(c)
616         return c
617
618     def get_srvsvc_connection(self, new=False):
619         if self.srvsvc_connections and not new:
620             return self.srvsvc_connections[-1]
621
622         def connect(creds):
623             return srvsvc.srvsvc("ncacn_np:%s" % (self.server),
624                                  self.lp,
625                                  creds)
626
627         (c, self.last_srvsvc_bad) = \
628             self.with_random_bad_credentials(connect,
629                                              self.user_creds,
630                                              self.user_creds_bad,
631                                              self.last_srvsvc_bad)
632
633         self.srvsvc_connections.append(c)
634         return c
635
636     def get_lsarpc_connection(self, new=False):
637         if self.lsarpc_connections and not new:
638             return self.lsarpc_connections[-1]
639
640         def connect(creds):
641             binding_options = 'schannel,seal,sign'
642             return lsa.lsarpc("ncacn_ip_tcp:%s[%s]" %
643                               (self.server, binding_options),
644                               self.lp,
645                               creds)
646
647         (c, self.last_lsarpc_bad) = \
648             self.with_random_bad_credentials(connect,
649                                              self.machine_creds,
650                                              self.machine_creds_bad,
651                                              self.last_lsarpc_bad)
652
653         self.lsarpc_connections.append(c)
654         return c
655
656     def get_lsarpc_named_pipe_connection(self, new=False):
657         if self.lsarpc_connections_named and not new:
658             return self.lsarpc_connections_named[-1]
659
660         def connect(creds):
661             return lsa.lsarpc("ncacn_np:%s" % (self.server),
662                               self.lp,
663                               creds)
664
665         (c, self.last_lsarpc_named_bad) = \
666             self.with_random_bad_credentials(connect,
667                                              self.machine_creds,
668                                              self.machine_creds_bad,
669                                              self.last_lsarpc_named_bad)
670
671         self.lsarpc_connections_named.append(c)
672         return c
673
674     def get_drsuapi_connection_pair(self, new=False, unbind=False):
675         """get a (drs, drs_handle) tuple"""
676         if self.drsuapi_connections and not new:
677             c = self.drsuapi_connections[-1]
678             return c
679
680         def connect(creds):
681             binding_options = 'seal'
682             binding_string = "ncacn_ip_tcp:%s[%s]" %\
683                              (self.server, binding_options)
684             return drsuapi.drsuapi(binding_string, self.lp, creds)
685
686         (drs, self.last_drsuapi_bad) = \
687             self.with_random_bad_credentials(connect,
688                                              self.user_creds,
689                                              self.user_creds_bad,
690                                              self.last_drsuapi_bad)
691
692         (drs_handle, supported_extensions) = drs_DsBind(drs)
693         c = (drs, drs_handle)
694         self.drsuapi_connections.append(c)
695         return c
696
697     def get_ldap_connection(self, new=False, simple=False):
698         if self.ldap_connections and not new:
699             return self.ldap_connections[-1]
700
701         def simple_bind(creds):
702             """
703             To run simple bind against Windows, we need to run
704             following commands in PowerShell:
705
706                 Install-windowsfeature ADCS-Cert-Authority
707                 Install-AdcsCertificationAuthority -CAType EnterpriseRootCA
708                 Restart-Computer
709
710             """
711             return SamDB('ldaps://%s' % self.server,
712                          credentials=creds,
713                          lp=self.lp)
714
715         def sasl_bind(creds):
716             return SamDB('ldap://%s' % self.server,
717                          credentials=creds,
718                          lp=self.lp)
719         if simple:
720             (samdb, self.last_simple_bind_bad) = \
721                 self.with_random_bad_credentials(simple_bind,
722                                                  self.simple_bind_creds,
723                                                  self.simple_bind_creds_bad,
724                                                  self.last_simple_bind_bad)
725         else:
726             (samdb, self.last_bind_bad) = \
727                 self.with_random_bad_credentials(sasl_bind,
728                                                  self.user_creds,
729                                                  self.user_creds_bad,
730                                                  self.last_bind_bad)
731
732         self.ldap_connections.append(samdb)
733         return samdb
734
735     def get_samr_context(self, new=False):
736         if not self.samr_contexts or new:
737             self.samr_contexts.append(
738                 SamrContext(self.server, lp=self.lp, creds=self.creds))
739         return self.samr_contexts[-1]
740
741     def get_netlogon_connection(self):
742
743         if self.netlogon_connection:
744             return self.netlogon_connection
745
746         def connect(creds):
747             return netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" %
748                                      (self.server),
749                                      self.lp,
750                                      creds)
751         (c, self.last_netlogon_bad) = \
752             self.with_random_bad_credentials(connect,
753                                              self.machine_creds,
754                                              self.machine_creds_bad,
755                                              self.last_netlogon_bad)
756         self.netlogon_connection = c
757         return c
758
759     def guess_a_dns_lookup(self):
760         return (self.realm, 'A')
761
762     def get_authenticator(self):
763         auth = self.machine_creds.new_client_authenticator()
764         current  = netr_Authenticator()
765         current.cred.data = [x if isinstance(x, int) else ord(x)
766                              for x in auth["credential"]]
767         current.timestamp = auth["timestamp"]
768
769         subsequent = netr_Authenticator()
770         return (current, subsequent)
771
772
773 class SamrContext(object):
774     """State/Context associated with a samr connection.
775     """
776     def __init__(self, server, lp=None, creds=None):
777         self.connection    = None
778         self.handle        = None
779         self.domain_handle = None
780         self.domain_sid    = None
781         self.group_handle  = None
782         self.user_handle   = None
783         self.rids          = None
784         self.server        = server
785         self.lp            = lp
786         self.creds         = creds
787
788     def get_connection(self):
789         if not self.connection:
790             self.connection = samr.samr(
791                 "ncacn_ip_tcp:%s[seal]" % (self.server),
792                 lp_ctx=self.lp,
793                 credentials=self.creds)
794
795         return self.connection
796
797     def get_handle(self):
798         if not self.handle:
799             c = self.get_connection()
800             self.handle = c.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
801         return self.handle
802
803
804 class Conversation(object):
805     """Details of a converation between a simulated client and a server."""
806     def __init__(self, start_time=None, endpoints=None, seq=(),
807                  conversation_id=None):
808         self.start_time = start_time
809         self.endpoints = endpoints
810         self.packets = []
811         self.msg = random_colour_print(endpoints)
812         self.client_balance = 0.0
813         self.conversation_id = conversation_id
814         for p in seq:
815             self.add_short_packet(*p)
816
817     def __cmp__(self, other):
818         if self.start_time is None:
819             if other.start_time is None:
820                 return 0
821             return -1
822         if other.start_time is None:
823             return 1
824         return self.start_time - other.start_time
825
826     def add_packet(self, packet):
827         """Add a packet object to this conversation, making a local copy with
828         a conversation-relative timestamp."""
829         p = packet.copy()
830
831         if self.start_time is None:
832             self.start_time = p.timestamp
833
834         if self.endpoints is None:
835             self.endpoints = p.endpoints
836
837         if p.endpoints != self.endpoints:
838             raise FakePacketError("Conversation endpoints %s don't match"
839                                   "packet endpoints %s" %
840                                   (self.endpoints, p.endpoints))
841
842         p.timestamp -= self.start_time
843
844         if p.src == p.endpoints[0]:
845             self.client_balance -= p.client_score()
846         else:
847             self.client_balance += p.client_score()
848
849         if p.is_really_a_packet():
850             self.packets.append(p)
851
852     def add_short_packet(self, timestamp, protocol, opcode, extra,
853                          client=True):
854         """Create a packet from a timestamp, and 'protocol:opcode' pair, and a
855         (possibly empty) list of extra data. If client is True, assume
856         this packet is from the client to the server.
857         """
858         src, dest = self.guess_client_server()
859         if not client:
860             src, dest = dest, src
861         key = (protocol, opcode)
862         desc = OP_DESCRIPTIONS[key] if key in OP_DESCRIPTIONS else ''
863         if protocol in IP_PROTOCOLS:
864             ip_protocol = IP_PROTOCOLS[protocol]
865         else:
866             ip_protocol = '06'
867         packet = Packet(timestamp - self.start_time, ip_protocol,
868                         '', src, dest,
869                         protocol, opcode, desc, extra)
870         # XXX we're assuming the timestamp is already adjusted for
871         # this conversation?
872         # XXX should we adjust client balance for guessed packets?
873         if packet.src == packet.endpoints[0]:
874             self.client_balance -= packet.client_score()
875         else:
876             self.client_balance += packet.client_score()
877         if packet.is_really_a_packet():
878             self.packets.append(packet)
879
880     def __str__(self):
881         return ("<Conversation %s %s starting %.3f %d packets>" %
882                 (self.conversation_id, self.endpoints, self.start_time,
883                  len(self.packets)))
884
885     __repr__ = __str__
886
887     def __iter__(self):
888         return iter(self.packets)
889
890     def __len__(self):
891         return len(self.packets)
892
893     def get_duration(self):
894         if len(self.packets) < 2:
895             return 0
896         return self.packets[-1].timestamp - self.packets[0].timestamp
897
898     def replay_as_summary_lines(self):
899         lines = []
900         for p in self.packets:
901             lines.append(p.as_summary(self.start_time))
902         return lines
903
904     def replay_with_delay(self, start, context=None, account=None):
905         """Replay the conversation at the right time.
906         (We're already in a fork)."""
907         # first we sleep until the first packet
908         t = self.start_time
909         now = time.time() - start
910         gap = t - now
911         sleep_time = gap - SLEEP_OVERHEAD
912         if sleep_time > 0:
913             time.sleep(sleep_time)
914
915         miss = (time.time() - start) - t
916         self.msg("starting %s [miss %.3f]" % (self, miss))
917
918         max_gap = 0.0
919         max_sleep_miss = 0.0
920         # packet times are relative to conversation start
921         p_start = time.time()
922         for p in self.packets:
923             now = time.time() - p_start
924             gap = now - p.timestamp
925             if gap > max_gap:
926                 max_gap = gap
927             if gap < 0:
928                 sleep_time = -gap - SLEEP_OVERHEAD
929                 if sleep_time > 0:
930                     time.sleep(sleep_time)
931                     t = time.time() - p_start
932                     if t - p.timestamp > max_sleep_miss:
933                         max_sleep_miss = t - p.timestamp
934
935             p.play(self, context)
936
937         return max_gap, miss, max_sleep_miss
938
939     def guess_client_server(self, server_clue=None):
940         """Have a go at deciding who is the server and who is the client.
941         returns (client, server)
942         """
943         a, b = self.endpoints
944
945         if self.client_balance < 0:
946             return (a, b)
947
948         # in the absense of a clue, we will fall through to assuming
949         # the lowest number is the server (which is usually true).
950
951         if self.client_balance == 0 and server_clue == b:
952             return (a, b)
953
954         return (b, a)
955
956     def forget_packets_outside_window(self, s, e):
957         """Prune any packets outside the timne window we're interested in
958
959         :param s: start of the window
960         :param e: end of the window
961         """
962         self.packets = [p for p in self.packets if s <= p.timestamp <= e]
963         self.start_time = self.packets[0].timestamp if self.packets else None
964
965     def renormalise_times(self, start_time):
966         """Adjust the packet start times relative to the new start time."""
967         for p in self.packets:
968             p.timestamp -= start_time
969
970         if self.start_time is not None:
971             self.start_time -= start_time
972
973
974 class DnsHammer(Conversation):
975     """A lightweight conversation that generates a lot of dns:0 packets on
976     the fly"""
977
978     def __init__(self, dns_rate, duration):
979         n = int(dns_rate * duration)
980         self.times = [random.uniform(0, duration) for i in range(n)]
981         self.times.sort()
982         self.rate = dns_rate
983         self.duration = duration
984         self.start_time = 0
985         self.msg = random_colour_print()
986
987     def __str__(self):
988         return ("<DnsHammer %d packets over %.1fs (rate %.2f)>" %
989                 (len(self.times), self.duration, self.rate))
990
991     def replay(self, context=None):
992         start = time.time()
993         fn = traffic_packets.packet_dns_0
994         for t in self.times:
995             now = time.time() - start
996             gap = t - now
997             sleep_time = gap - SLEEP_OVERHEAD
998             if sleep_time > 0:
999                 time.sleep(sleep_time)
1000
1001             packet_start = time.time()
1002             try:
1003                 fn(None, None, context)
1004                 end = time.time()
1005                 duration = end - packet_start
1006                 print("%f\tDNS\tdns\t0\t%f\tTrue\t" % (end, duration))
1007             except Exception as e:
1008                 end = time.time()
1009                 duration = end - packet_start
1010                 print("%f\tDNS\tdns\t0\t%f\tFalse\t%s" % (end, duration, e))
1011
1012
1013 def ingest_summaries(files, dns_mode='count'):
1014     """Load a summary traffic summary file and generated Converations from it.
1015     """
1016
1017     dns_counts = defaultdict(int)
1018     packets = []
1019     for f in files:
1020         if isinstance(f, str):
1021             f = open(f)
1022         print("Ingesting %s" % (f.name,), file=sys.stderr)
1023         for line in f:
1024             p = Packet.from_line(line)
1025             if p.protocol == 'dns' and dns_mode != 'include':
1026                 dns_counts[p.opcode] += 1
1027             else:
1028                 packets.append(p)
1029
1030         f.close()
1031
1032     if not packets:
1033         return [], 0
1034
1035     start_time = min(p.timestamp for p in packets)
1036     last_packet = max(p.timestamp for p in packets)
1037
1038     print("gathering packets into conversations", file=sys.stderr)
1039     conversations = OrderedDict()
1040     for i, p in enumerate(packets):
1041         p.timestamp -= start_time
1042         c = conversations.get(p.endpoints)
1043         if c is None:
1044             c = Conversation(conversation_id=(i + 2))
1045             conversations[p.endpoints] = c
1046         c.add_packet(p)
1047
1048     # We only care about conversations with actual traffic, so we
1049     # filter out conversations with nothing to say. We do that here,
1050     # rather than earlier, because those empty packets contain useful
1051     # hints as to which end of the conversation was the client.
1052     conversation_list = []
1053     for c in conversations.values():
1054         if len(c) != 0:
1055             conversation_list.append(c)
1056
1057     # This is obviously not correct, as many conversations will appear
1058     # to start roughly simultaneously at the beginning of the snapshot.
1059     # To which we say: oh well, so be it.
1060     duration = float(last_packet - start_time)
1061     mean_interval = len(conversations) / duration
1062
1063     return conversation_list, mean_interval, duration, dns_counts
1064
1065
1066 def guess_server_address(conversations):
1067     # we guess the most common address.
1068     addresses = Counter()
1069     for c in conversations:
1070         addresses.update(c.endpoints)
1071     if addresses:
1072         return addresses.most_common(1)[0]
1073
1074
1075 def stringify_keys(x):
1076     y = {}
1077     for k, v in x.items():
1078         k2 = '\t'.join(k)
1079         y[k2] = v
1080     return y
1081
1082
1083 def unstringify_keys(x):
1084     y = {}
1085     for k, v in x.items():
1086         t = tuple(str(k).split('\t'))
1087         y[t] = v
1088     return y
1089
1090
1091 class TrafficModel(object):
1092     def __init__(self, n=3):
1093         self.ngrams = {}
1094         self.query_details = {}
1095         self.n = n
1096         self.dns_opcounts = defaultdict(int)
1097         self.cumulative_duration = 0.0
1098         self.packet_rate = [0, 1]
1099
1100     def learn(self, conversations, dns_opcounts={}):
1101         prev = 0.0
1102         cum_duration = 0.0
1103         key = (NON_PACKET,) * (self.n - 1)
1104
1105         server = guess_server_address(conversations)
1106
1107         for k, v in dns_opcounts.items():
1108             self.dns_opcounts[k] += v
1109
1110         if len(conversations) > 1:
1111             first = conversations[0].start_time
1112             total = 0
1113             last = first + 0.1
1114             for c in conversations:
1115                 total += len(c)
1116                 last = max(last, c.packets[-1].timestamp)
1117
1118             self.packet_rate[0] = total
1119             self.packet_rate[1] = last - first
1120
1121         for c in conversations:
1122             client, server = c.guess_client_server(server)
1123             cum_duration += c.get_duration()
1124             key = (NON_PACKET,) * (self.n - 1)
1125             for p in c:
1126                 if p.src != client:
1127                     continue
1128
1129                 elapsed = p.timestamp - prev
1130                 prev = p.timestamp
1131                 if elapsed > WAIT_THRESHOLD:
1132                     # add the wait as an extra state
1133                     wait = 'wait:%d' % (math.log(max(1.0,
1134                                                      elapsed * WAIT_SCALE)))
1135                     self.ngrams.setdefault(key, []).append(wait)
1136                     key = key[1:] + (wait,)
1137
1138                 short_p = p.as_packet_type()
1139                 self.query_details.setdefault(short_p,
1140                                               []).append(tuple(p.extra))
1141                 self.ngrams.setdefault(key, []).append(short_p)
1142                 key = key[1:] + (short_p,)
1143
1144         self.cumulative_duration += cum_duration
1145         # add in the end
1146         self.ngrams.setdefault(key, []).append(NON_PACKET)
1147
1148     def save(self, f):
1149         ngrams = {}
1150         for k, v in self.ngrams.items():
1151             k = '\t'.join(k)
1152             ngrams[k] = dict(Counter(v))
1153
1154         query_details = {}
1155         for k, v in self.query_details.items():
1156             query_details[k] = dict(Counter('\t'.join(x) if x else '-'
1157                                             for x in v))
1158
1159         d = {
1160             'ngrams': ngrams,
1161             'query_details': query_details,
1162             'cumulative_duration': self.cumulative_duration,
1163             'packet_rate': self.packet_rate,
1164             'version': CURRENT_MODEL_VERSION
1165         }
1166         d['dns'] = self.dns_opcounts
1167
1168         if isinstance(f, str):
1169             f = open(f, 'w')
1170
1171         json.dump(d, f, indent=2)
1172
1173     def load(self, f):
1174         if isinstance(f, str):
1175             f = open(f)
1176
1177         d = json.load(f)
1178
1179         try:
1180             version = d["version"]
1181             if version < REQUIRED_MODEL_VERSION:
1182                 raise ValueError("the model file is version %d; "
1183                                  "version %d is required" %
1184                                  (version, REQUIRED_MODEL_VERSION))
1185         except KeyError:
1186                 raise ValueError("the model file lacks a version number; "
1187                                  "version %d is required" %
1188                                  (REQUIRED_MODEL_VERSION))
1189
1190         for k, v in d['ngrams'].items():
1191             k = tuple(str(k).split('\t'))
1192             values = self.ngrams.setdefault(k, [])
1193             for p, count in v.items():
1194                 values.extend([str(p)] * count)
1195             values.sort()
1196
1197         for k, v in d['query_details'].items():
1198             values = self.query_details.setdefault(str(k), [])
1199             for p, count in v.items():
1200                 if p == '-':
1201                     values.extend([()] * count)
1202                 else:
1203                     values.extend([tuple(str(p).split('\t'))] * count)
1204             values.sort()
1205
1206         if 'dns' in d:
1207             for k, v in d['dns'].items():
1208                 self.dns_opcounts[k] += v
1209
1210         self.cumulative_duration = d['cumulative_duration']
1211         self.packet_rate = d['packet_rate']
1212
1213     def construct_conversation_sequence(self, timestamp=0.0,
1214                                         hard_stop=None,
1215                                         replay_speed=1,
1216                                         ignore_before=0):
1217         """Construct an individual conversation packet sequence from the
1218         model.
1219         """
1220         c = []
1221         key = (NON_PACKET,) * (self.n - 1)
1222         if ignore_before is None:
1223             ignore_before = timestamp - 1
1224
1225         while True:
1226             p = random.choice(self.ngrams.get(key, (NON_PACKET,)))
1227             if p == NON_PACKET:
1228                 break
1229
1230             if p in self.query_details:
1231                 extra = random.choice(self.query_details[p])
1232             else:
1233                 extra = []
1234
1235             protocol, opcode = p.split(':', 1)
1236             if protocol == 'wait':
1237                 log_wait_time = int(opcode) + random.random()
1238                 wait = math.exp(log_wait_time) / (WAIT_SCALE * replay_speed)
1239                 timestamp += wait
1240             else:
1241                 log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
1242                 wait = math.exp(log_wait) / replay_speed
1243                 timestamp += wait
1244                 if hard_stop is not None and timestamp > hard_stop:
1245                     break
1246                 if timestamp >= ignore_before:
1247                     c.append((timestamp, protocol, opcode, extra))
1248
1249             key = key[1:] + (p,)
1250
1251         return c
1252
1253     def generate_conversation_sequences(self, scale, duration, replay_speed=1):
1254         """Generate a list of conversation descriptions from the model."""
1255
1256         # We run the simulation for ten times as long as our desired
1257         # duration, and take the section at the end.
1258         lead_in = 9 * duration
1259         rate_n, rate_t  = self.packet_rate
1260         target_packets = int(duration * scale * rate_n / rate_t)
1261
1262         conversations = []
1263         n_packets = 0
1264
1265         while n_packets < target_packets:
1266             start = random.uniform(-lead_in, duration)
1267             c = self.construct_conversation_sequence(start,
1268                                                      hard_stop=duration,
1269                                                      replay_speed=replay_speed,
1270                                                      ignore_before=0)
1271             # will these "packets" generate actual traffic?
1272             # some (e.g. ldap unbind) will not generate anything
1273             # if the previous packets are not there, and if the
1274             # conversation only has those it wastes a process doing nothing.
1275             for timestamp, protocol, opcode, extra in c:
1276                 if is_a_traffic_generating_packet(protocol, opcode):
1277                     break
1278             else:
1279                 continue
1280
1281             conversations.append(c)
1282             n_packets += len(c)
1283
1284         print(("we have %d packets (target %d) in %d conversations at scale %f"
1285                % (n_packets, target_packets, len(conversations), scale)),
1286               file=sys.stderr)
1287         conversations.sort()  # sorts by first element == start time
1288         return conversations
1289
1290
1291 def seq_to_conversations(seq, server=1, client=2):
1292     conversations = []
1293     for s in seq:
1294         if s:
1295             c = Conversation(s[0][0], (server, client), s)
1296             client += 1
1297             conversations.append(c)
1298     return conversations
1299
1300
1301 IP_PROTOCOLS = {
1302     'dns': '11',
1303     'rpc_netlogon': '06',
1304     'kerberos': '06',      # ratio 16248:258
1305     'smb': '06',
1306     'smb2': '06',
1307     'ldap': '06',
1308     'cldap': '11',
1309     'lsarpc': '06',
1310     'samr': '06',
1311     'dcerpc': '06',
1312     'epm': '06',
1313     'drsuapi': '06',
1314     'browser': '11',
1315     'smb_netlogon': '11',
1316     'srvsvc': '06',
1317     'nbns': '11',
1318 }
1319
1320 OP_DESCRIPTIONS = {
1321     ('browser', '0x01'): 'Host Announcement (0x01)',
1322     ('browser', '0x02'): 'Request Announcement (0x02)',
1323     ('browser', '0x08'): 'Browser Election Request (0x08)',
1324     ('browser', '0x09'): 'Get Backup List Request (0x09)',
1325     ('browser', '0x0c'): 'Domain/Workgroup Announcement (0x0c)',
1326     ('browser', '0x0f'): 'Local Master Announcement (0x0f)',
1327     ('cldap', '3'): 'searchRequest',
1328     ('cldap', '5'): 'searchResDone',
1329     ('dcerpc', '0'): 'Request',
1330     ('dcerpc', '11'): 'Bind',
1331     ('dcerpc', '12'): 'Bind_ack',
1332     ('dcerpc', '13'): 'Bind_nak',
1333     ('dcerpc', '14'): 'Alter_context',
1334     ('dcerpc', '15'): 'Alter_context_resp',
1335     ('dcerpc', '16'): 'AUTH3',
1336     ('dcerpc', '2'): 'Response',
1337     ('dns', '0'): 'query',
1338     ('dns', '1'): 'response',
1339     ('drsuapi', '0'): 'DsBind',
1340     ('drsuapi', '12'): 'DsCrackNames',
1341     ('drsuapi', '13'): 'DsWriteAccountSpn',
1342     ('drsuapi', '1'): 'DsUnbind',
1343     ('drsuapi', '2'): 'DsReplicaSync',
1344     ('drsuapi', '3'): 'DsGetNCChanges',
1345     ('drsuapi', '4'): 'DsReplicaUpdateRefs',
1346     ('epm', '3'): 'Map',
1347     ('kerberos', ''): '',
1348     ('ldap', '0'): 'bindRequest',
1349     ('ldap', '1'): 'bindResponse',
1350     ('ldap', '2'): 'unbindRequest',
1351     ('ldap', '3'): 'searchRequest',
1352     ('ldap', '4'): 'searchResEntry',
1353     ('ldap', '5'): 'searchResDone',
1354     ('ldap', ''): '*** Unknown ***',
1355     ('lsarpc', '14'): 'lsa_LookupNames',
1356     ('lsarpc', '15'): 'lsa_LookupSids',
1357     ('lsarpc', '39'): 'lsa_QueryTrustedDomainInfoBySid',
1358     ('lsarpc', '40'): 'lsa_SetTrustedDomainInfo',
1359     ('lsarpc', '6'): 'lsa_OpenPolicy',
1360     ('lsarpc', '76'): 'lsa_LookupSids3',
1361     ('lsarpc', '77'): 'lsa_LookupNames4',
1362     ('nbns', '0'): 'query',
1363     ('nbns', '1'): 'response',
1364     ('rpc_netlogon', '21'): 'NetrLogonDummyRoutine1',
1365     ('rpc_netlogon', '26'): 'NetrServerAuthenticate3',
1366     ('rpc_netlogon', '29'): 'NetrLogonGetDomainInfo',
1367     ('rpc_netlogon', '30'): 'NetrServerPasswordSet2',
1368     ('rpc_netlogon', '39'): 'NetrLogonSamLogonEx',
1369     ('rpc_netlogon', '40'): 'DsrEnumerateDomainTrusts',
1370     ('rpc_netlogon', '45'): 'NetrLogonSamLogonWithFlags',
1371     ('rpc_netlogon', '4'): 'NetrServerReqChallenge',
1372     ('samr', '0',): 'Connect',
1373     ('samr', '16'): 'GetAliasMembership',
1374     ('samr', '17'): 'LookupNames',
1375     ('samr', '18'): 'LookupRids',
1376     ('samr', '19'): 'OpenGroup',
1377     ('samr', '1'): 'Close',
1378     ('samr', '25'): 'QueryGroupMember',
1379     ('samr', '34'): 'OpenUser',
1380     ('samr', '36'): 'QueryUserInfo',
1381     ('samr', '39'): 'GetGroupsForUser',
1382     ('samr', '3'): 'QuerySecurity',
1383     ('samr', '5'): 'LookupDomain',
1384     ('samr', '64'): 'Connect5',
1385     ('samr', '6'): 'EnumDomains',
1386     ('samr', '7'): 'OpenDomain',
1387     ('samr', '8'): 'QueryDomainInfo',
1388     ('smb', '0x04'): 'Close (0x04)',
1389     ('smb', '0x24'): 'Locking AndX (0x24)',
1390     ('smb', '0x2e'): 'Read AndX (0x2e)',
1391     ('smb', '0x32'): 'Trans2 (0x32)',
1392     ('smb', '0x71'): 'Tree Disconnect (0x71)',
1393     ('smb', '0x72'): 'Negotiate Protocol (0x72)',
1394     ('smb', '0x73'): 'Session Setup AndX (0x73)',
1395     ('smb', '0x74'): 'Logoff AndX (0x74)',
1396     ('smb', '0x75'): 'Tree Connect AndX (0x75)',
1397     ('smb', '0xa2'): 'NT Create AndX (0xa2)',
1398     ('smb2', '0'): 'NegotiateProtocol',
1399     ('smb2', '11'): 'Ioctl',
1400     ('smb2', '14'): 'Find',
1401     ('smb2', '16'): 'GetInfo',
1402     ('smb2', '18'): 'Break',
1403     ('smb2', '1'): 'SessionSetup',
1404     ('smb2', '2'): 'SessionLogoff',
1405     ('smb2', '3'): 'TreeConnect',
1406     ('smb2', '4'): 'TreeDisconnect',
1407     ('smb2', '5'): 'Create',
1408     ('smb2', '6'): 'Close',
1409     ('smb2', '8'): 'Read',
1410     ('smb_netlogon', '0x12'): 'SAM LOGON request from client (0x12)',
1411     ('smb_netlogon', '0x17'): ('SAM Active Directory Response - '
1412                                'user unknown (0x17)'),
1413     ('srvsvc', '16'): 'NetShareGetInfo',
1414     ('srvsvc', '21'): 'NetSrvGetInfo',
1415 }
1416
1417
1418 def expand_short_packet(p, timestamp, src, dest, extra):
1419     protocol, opcode = p.split(':', 1)
1420     desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
1421     ip_protocol = IP_PROTOCOLS.get(protocol, '06')
1422
1423     line = [timestamp, ip_protocol, '', src, dest, protocol, opcode, desc]
1424     line.extend(extra)
1425     return '\t'.join(line)
1426
1427
1428 def flushing_signal_handler(signal, frame):
1429     """Signal handler closes standard out and error.
1430
1431     Triggered by a sigterm, ensures that the log messages are flushed
1432     to disk and not lost.
1433     """
1434     sys.stderr.close()
1435     sys.stdout.close()
1436     os._exit(0)
1437
1438
1439 def replay_seq_in_fork(cs, start, context, account, client_id, server_id=1):
1440     """Fork a new process and replay the conversation sequence."""
1441     # We will need to reseed the random number generator or all the
1442     # clients will end up using the same sequence of random
1443     # numbers. random.randint() is mixed in so the initial seed will
1444     # have an effect here.
1445     seed = client_id * 1000 + random.randint(0, 999)
1446
1447     # flush our buffers so messages won't be written by both sides
1448     sys.stdout.flush()
1449     sys.stderr.flush()
1450     pid = os.fork()
1451     if pid != 0:
1452         return pid
1453
1454     # we must never return, or we'll end up running parts of the
1455     # parent's clean-up code. So we work in a try...finally, and
1456     # try to print any exceptions.
1457     try:
1458         random.seed(seed)
1459         endpoints = (server_id, client_id)
1460         status = 0
1461         t = cs[0][0]
1462         c = Conversation(t, endpoints, seq=cs, conversation_id=client_id)
1463         signal.signal(signal.SIGTERM, flushing_signal_handler)
1464
1465         context.generate_process_local_config(account, c)
1466         sys.stdin.close()
1467         os.close(0)
1468         filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
1469                                 c.conversation_id)
1470         f = open(filename, 'w')
1471         try:
1472             sys.stdout.close()
1473             os.close(1)
1474         except IOError as e:
1475             LOGGER.info("stdout closing failed with %s" % e)
1476             pass
1477
1478         sys.stdout = f
1479         now = time.time() - start
1480         gap = t - now
1481         sleep_time = gap - SLEEP_OVERHEAD
1482         if sleep_time > 0:
1483             time.sleep(sleep_time)
1484
1485         max_lag, start_lag, max_sleep_miss = c.replay_with_delay(start=start,
1486                                                                  context=context)
1487         print("Maximum lag: %f" % max_lag)
1488         print("Start lag: %f" % start_lag)
1489         print("Max sleep miss: %f" % max_sleep_miss)
1490
1491     except Exception:
1492         status = 1
1493         print(("EXCEPTION in child PID %d, conversation %s" % (os.getpid(), c)),
1494               file=sys.stderr)
1495         traceback.print_exc(sys.stderr)
1496         sys.stderr.flush()
1497     finally:
1498         sys.stderr.close()
1499         sys.stdout.close()
1500         os._exit(status)
1501
1502
1503 def dnshammer_in_fork(dns_rate, duration):
1504     sys.stdout.flush()
1505     sys.stderr.flush()
1506     pid = os.fork()
1507     if pid != 0:
1508         return pid
1509     try:
1510         status = 0
1511         signal.signal(signal.SIGTERM, flushing_signal_handler)
1512         hammer = DnsHammer(dns_rate, duration)
1513         hammer.replay()
1514     except Exception:
1515         status = 1
1516         print(("EXCEPTION in child PID %d, the DNS hammer" % (os.getpid())),
1517               file=sys.stderr)
1518         traceback.print_exc(sys.stderr)
1519     finally:
1520         sys.stderr.close()
1521         sys.stdout.close()
1522         os._exit(status)
1523
1524
1525 def replay(conversation_seq,
1526            host=None,
1527            creds=None,
1528            lp=None,
1529            accounts=None,
1530            dns_rate=0,
1531            duration=None,
1532            latency_timeout=1.0,
1533            stop_on_any_error=False,
1534            **kwargs):
1535
1536     context = ReplayContext(server=host,
1537                             creds=creds,
1538                             lp=lp,
1539                             **kwargs)
1540
1541     if len(accounts) < len(conversation_seq):
1542         raise ValueError(("we have %d accounts but %d conversations" %
1543                           (len(accounts), len(conversation_seq))))
1544
1545     # Set the process group so that the calling scripts are not killed
1546     # when the forked child processes are killed.
1547     os.setpgrp()
1548
1549     # we delay the start by a bit to allow all the forks to get up and
1550     # running.
1551     delay = len(conversation_seq) * 0.02
1552     start = time.time() + delay
1553
1554     if duration is None:
1555         # end slightly after the last packet of the last conversation
1556         # to start. Conversations other than the last could still be
1557         # going, but we don't care.
1558         duration = conversation_seq[-1][-1][0] + latency_timeout
1559
1560     print("We will start in %.1f seconds" % delay,
1561           file=sys.stderr)
1562     print("We will stop after %.1f seconds" % (duration + delay),
1563           file=sys.stderr)
1564     print("runtime %.1f seconds" % duration,
1565           file=sys.stderr)
1566
1567     # give one second grace for packets to finish before killing begins
1568     end = start + duration + 1.0
1569
1570     LOGGER.info("Replaying traffic for %u conversations over %d seconds"
1571           % (len(conversation_seq), duration))
1572
1573
1574     children = {}
1575     try:
1576         if dns_rate:
1577             pid = dnshammer_in_fork(dns_rate, duration)
1578             children[pid] = 1
1579
1580         for i, cs in enumerate(conversation_seq):
1581             account = accounts[i]
1582             client_id = i + 2
1583             pid = replay_seq_in_fork(cs, start, context, account, client_id)
1584             children[pid] = client_id
1585
1586         # HERE, we are past all the forks
1587         t = time.time()
1588         print("all forks done in %.1f seconds, waiting %.1f" %
1589               (t - start + delay, t - start),
1590               file=sys.stderr)
1591
1592         while time.time() < end and children:
1593             time.sleep(0.003)
1594             try:
1595                 pid, status = os.waitpid(-1, os.WNOHANG)
1596             except OSError as e:
1597                 if e.errno != ECHILD:  # no child processes
1598                     raise
1599                 break
1600             if pid:
1601                 c = children.pop(pid, None)
1602                 if DEBUG_LEVEL > 0:
1603                     print(("process %d finished conversation %d;"
1604                            " %d to go" %
1605                            (pid, c, len(children))), file=sys.stderr)
1606                 if stop_on_any_error and status != 0:
1607                     break
1608
1609     except Exception:
1610         print("EXCEPTION in parent", file=sys.stderr)
1611         traceback.print_exc()
1612     finally:
1613         for s in (15, 15, 9):
1614             print(("killing %d children with -%d" %
1615                    (len(children), s)), file=sys.stderr)
1616             for pid in children:
1617                 try:
1618                     os.kill(pid, s)
1619                 except OSError as e:
1620                     if e.errno != ESRCH:  # don't fail if it has already died
1621                         raise
1622             time.sleep(0.5)
1623             end = time.time() + 1
1624             while children:
1625                 try:
1626                     pid, status = os.waitpid(-1, os.WNOHANG)
1627                 except OSError as e:
1628                     if e.errno != ECHILD:
1629                         raise
1630                 if pid != 0:
1631                     c = children.pop(pid, None)
1632                     if c is None:
1633                         print("children is %s, no pid found" % children)
1634                         sys.stderr.flush()
1635                         sys.stdout.flush()
1636                         os._exit(1)
1637                     print(("kill -%d %d KILLED conversation; "
1638                            "%d to go" %
1639                            (s, pid, len(children))),
1640                           file=sys.stderr)
1641                 if time.time() >= end:
1642                     break
1643
1644             if not children:
1645                 break
1646             time.sleep(1)
1647
1648         if children:
1649             print("%d children are missing" % len(children),
1650                   file=sys.stderr)
1651
1652         # there may be stragglers that were forked just as ^C was hit
1653         # and don't appear in the list of children. We can get them
1654         # with killpg, but that will also kill us, so this is^H^H would be
1655         # goodbye, except we cheat and pretend to use ^C (SIG_INTERRUPT),
1656         # so as not to have to fuss around writing signal handlers.
1657         try:
1658             os.killpg(0, 2)
1659         except KeyboardInterrupt:
1660             print("ignoring fake ^C", file=sys.stderr)
1661
1662
1663 def openLdb(host, creds, lp):
1664     session = system_session()
1665     ldb = SamDB(url="ldap://%s" % host,
1666                 session_info=session,
1667                 options=['modules:paged_searches'],
1668                 credentials=creds,
1669                 lp=lp)
1670     return ldb
1671
1672
1673 def ou_name(ldb, instance_id):
1674     """Generate an ou name from the instance id"""
1675     return "ou=instance-%d,ou=traffic_replay,%s" % (instance_id,
1676                                                     ldb.domain_dn())
1677
1678
1679 def create_ou(ldb, instance_id):
1680     """Create an ou, all created user and machine accounts will belong to it.
1681
1682     This allows all the created resources to be cleaned up easily.
1683     """
1684     ou = ou_name(ldb, instance_id)
1685     try:
1686         ldb.add({"dn": ou.split(',', 1)[1],
1687                  "objectclass": "organizationalunit"})
1688     except LdbError as e:
1689         (status, _) = e.args
1690         # ignore already exists
1691         if status != 68:
1692             raise
1693     try:
1694         ldb.add({"dn": ou,
1695                  "objectclass": "organizationalunit"})
1696     except LdbError as e:
1697         (status, _) = e.args
1698         # ignore already exists
1699         if status != 68:
1700             raise
1701     return ou
1702
1703
1704 # ConversationAccounts holds details of the machine and user accounts
1705 # associated with a conversation.
1706 #
1707 # We use a named tuple to reduce shared memory usage.
1708 ConversationAccounts = namedtuple('ConversationAccounts',
1709                                   ('netbios_name',
1710                                    'machinepass',
1711                                    'username',
1712                                    'userpass'))
1713
1714
1715 def generate_replay_accounts(ldb, instance_id, number, password):
1716     """Generate a series of unique machine and user account names."""
1717
1718     accounts = []
1719     for i in range(1, number + 1):
1720         netbios_name = machine_name(instance_id, i)
1721         username = user_name(instance_id, i)
1722
1723         account = ConversationAccounts(netbios_name, password, username,
1724                                        password)
1725         accounts.append(account)
1726     return accounts
1727
1728
1729 def create_machine_account(ldb, instance_id, netbios_name, machinepass,
1730                            traffic_account=True):
1731     """Create a machine account via ldap."""
1732
1733     ou = ou_name(ldb, instance_id)
1734     dn = "cn=%s,%s" % (netbios_name, ou)
1735     utf16pw = ('"%s"' % get_string(machinepass)).encode('utf-16-le')
1736
1737     if traffic_account:
1738         # we set these bits for the machine account otherwise the replayed
1739         # traffic throws up NT_STATUS_NO_TRUST_SAM_ACCOUNT errors
1740         account_controls = str(UF_TRUSTED_FOR_DELEGATION |
1741                                UF_SERVER_TRUST_ACCOUNT)
1742
1743     else:
1744         account_controls = str(UF_WORKSTATION_TRUST_ACCOUNT)
1745
1746     ldb.add({
1747         "dn": dn,
1748         "objectclass": "computer",
1749         "sAMAccountName": "%s$" % netbios_name,
1750         "userAccountControl": account_controls,
1751         "unicodePwd": utf16pw})
1752
1753
1754 def create_user_account(ldb, instance_id, username, userpass):
1755     """Create a user account via ldap."""
1756     ou = ou_name(ldb, instance_id)
1757     user_dn = "cn=%s,%s" % (username, ou)
1758     utf16pw = ('"%s"' % get_string(userpass)).encode('utf-16-le')
1759     ldb.add({
1760         "dn": user_dn,
1761         "objectclass": "user",
1762         "sAMAccountName": username,
1763         "userAccountControl": str(UF_NORMAL_ACCOUNT),
1764         "unicodePwd": utf16pw
1765     })
1766
1767     # grant user write permission to do things like write account SPN
1768     sdutils = sd_utils.SDUtils(ldb)
1769     sdutils.dacl_add_ace(user_dn, "(A;;WP;;;PS)")
1770
1771
1772 def create_group(ldb, instance_id, name):
1773     """Create a group via ldap."""
1774
1775     ou = ou_name(ldb, instance_id)
1776     dn = "cn=%s,%s" % (name, ou)
1777     ldb.add({
1778         "dn": dn,
1779         "objectclass": "group",
1780         "sAMAccountName": name,
1781     })
1782
1783
1784 def user_name(instance_id, i):
1785     """Generate a user name based in the instance id"""
1786     return "STGU-%d-%d" % (instance_id, i)
1787
1788
1789 def search_objectclass(ldb, objectclass='user', attr='sAMAccountName'):
1790     """Seach objectclass, return attr in a set"""
1791     objs = ldb.search(
1792         expression="(objectClass={})".format(objectclass),
1793         attrs=[attr]
1794     )
1795     return {str(obj[attr]) for obj in objs}
1796
1797
1798 def generate_users(ldb, instance_id, number, password):
1799     """Add users to the server"""
1800     existing_objects = search_objectclass(ldb, objectclass='user')
1801     users = 0
1802     for i in range(number, 0, -1):
1803         name = user_name(instance_id, i)
1804         if name not in existing_objects:
1805             create_user_account(ldb, instance_id, name, password)
1806             users += 1
1807             if users % 50 == 0:
1808                 LOGGER.info("Created %u/%u users" % (users, number))
1809
1810     return users
1811
1812
1813 def machine_name(instance_id, i, traffic_account=True):
1814     """Generate a machine account name from instance id."""
1815     if traffic_account:
1816         # traffic accounts correspond to a given user, and use different
1817         # userAccountControl flags to ensure packets get processed correctly
1818         # by the DC
1819         return "STGM-%d-%d" % (instance_id, i)
1820     else:
1821         # Otherwise we're just generating computer accounts to simulate a
1822         # semi-realistic network. These use the default computer
1823         # userAccountControl flags, so we use a different account name so that
1824         # we don't try to use them when generating packets
1825         return "PC-%d-%d" % (instance_id, i)
1826
1827
1828 def generate_machine_accounts(ldb, instance_id, number, password,
1829                               traffic_account=True):
1830     """Add machine accounts to the server"""
1831     existing_objects = search_objectclass(ldb, objectclass='computer')
1832     added = 0
1833     for i in range(number, 0, -1):
1834         name = machine_name(instance_id, i, traffic_account)
1835         if name + "$" not in existing_objects:
1836             create_machine_account(ldb, instance_id, name, password,
1837                                    traffic_account)
1838             added += 1
1839             if added % 50 == 0:
1840                 LOGGER.info("Created %u/%u machine accounts" % (added, number))
1841
1842     return added
1843
1844
1845 def group_name(instance_id, i):
1846     """Generate a group name from instance id."""
1847     return "STGG-%d-%d" % (instance_id, i)
1848
1849
1850 def generate_groups(ldb, instance_id, number):
1851     """Create the required number of groups on the server."""
1852     existing_objects = search_objectclass(ldb, objectclass='group')
1853     groups = 0
1854     for i in range(number, 0, -1):
1855         name = group_name(instance_id, i)
1856         if name not in existing_objects:
1857             create_group(ldb, instance_id, name)
1858             groups += 1
1859             if groups % 1000 == 0:
1860                 LOGGER.info("Created %u/%u groups" % (groups, number))
1861
1862     return groups
1863
1864
1865 def clean_up_accounts(ldb, instance_id):
1866     """Remove the created accounts and groups from the server."""
1867     ou = ou_name(ldb, instance_id)
1868     try:
1869         ldb.delete(ou, ["tree_delete:1"])
1870     except LdbError as e:
1871         (status, _) = e.args
1872         # ignore does not exist
1873         if status != 32:
1874             raise
1875
1876
1877 def generate_users_and_groups(ldb, instance_id, password,
1878                               number_of_users, number_of_groups,
1879                               group_memberships, max_members,
1880                               machine_accounts, traffic_accounts=True):
1881     """Generate the required users and groups, allocating the users to
1882        those groups."""
1883     memberships_added = 0
1884     groups_added = 0
1885     computers_added = 0
1886
1887     create_ou(ldb, instance_id)
1888
1889     LOGGER.info("Generating dummy user accounts")
1890     users_added = generate_users(ldb, instance_id, number_of_users, password)
1891
1892     LOGGER.info("Generating dummy machine accounts")
1893     computers_added = generate_machine_accounts(ldb, instance_id,
1894                                                 machine_accounts, password,
1895                                                 traffic_accounts)
1896
1897     if number_of_groups > 0:
1898         LOGGER.info("Generating dummy groups")
1899         groups_added = generate_groups(ldb, instance_id, number_of_groups)
1900
1901     if group_memberships > 0:
1902         LOGGER.info("Assigning users to groups")
1903         assignments = GroupAssignments(number_of_groups,
1904                                        groups_added,
1905                                        number_of_users,
1906                                        users_added,
1907                                        group_memberships,
1908                                        max_members)
1909         LOGGER.info("Adding users to groups")
1910         add_users_to_groups(ldb, instance_id, assignments)
1911         memberships_added = assignments.total()
1912
1913     if (groups_added > 0 and users_added == 0 and
1914        number_of_groups != groups_added):
1915         LOGGER.warning("The added groups will contain no members")
1916
1917     LOGGER.info("Added %d users (%d machines), %d groups and %d memberships" %
1918                 (users_added, computers_added, groups_added,
1919                  memberships_added))
1920
1921
1922 class GroupAssignments(object):
1923     def __init__(self, number_of_groups, groups_added, number_of_users,
1924                  users_added, group_memberships, max_members):
1925
1926         self.count = 0
1927         self.generate_group_distribution(number_of_groups)
1928         self.generate_user_distribution(number_of_users, group_memberships)
1929         self.max_members = max_members
1930         self.assignments = defaultdict(list)
1931         self.assign_groups(number_of_groups, groups_added, number_of_users,
1932                            users_added, group_memberships)
1933
1934     def cumulative_distribution(self, weights):
1935         # make sure the probabilities conform to a cumulative distribution
1936         # spread between 0.0 and 1.0. Dividing by the weighted total gives each
1937         # probability a proportional share of 1.0. Higher probabilities get a
1938         # bigger share, so are more likely to be picked. We use the cumulative
1939         # value, so we can use random.random() as a simple index into the list
1940         dist = []
1941         total = sum(weights)
1942         if total == 0:
1943             return None
1944
1945         cumulative = 0.0
1946         for probability in weights:
1947             cumulative += probability
1948             dist.append(cumulative / total)
1949         return dist
1950
1951     def generate_user_distribution(self, num_users, num_memberships):
1952         """Probability distribution of a user belonging to a group.
1953         """
1954         # Assign a weighted probability to each user. Use the Pareto
1955         # Distribution so that some users are in a lot of groups, and the
1956         # bulk of users are in only a few groups. If we're assigning a large
1957         # number of group memberships, use a higher shape. This means slightly
1958         # fewer outlying users that are in large numbers of groups. The aim is
1959         # to have no users belonging to more than ~500 groups.
1960         if num_memberships > 5000000:
1961             shape = 3.0
1962         elif num_memberships > 2000000:
1963             shape = 2.5
1964         elif num_memberships > 300000:
1965             shape = 2.25
1966         else:
1967             shape = 1.75
1968
1969         weights = []
1970         for x in range(1, num_users + 1):
1971             p = random.paretovariate(shape)
1972             weights.append(p)
1973
1974         # convert the weights to a cumulative distribution between 0.0 and 1.0
1975         self.user_dist = self.cumulative_distribution(weights)
1976
1977     def generate_group_distribution(self, n):
1978         """Probability distribution of a group containing a user."""
1979
1980         # Assign a weighted probability to each user. Probability decreases
1981         # as the group-ID increases
1982         weights = []
1983         for x in range(1, n + 1):
1984             p = 1 / (x**1.3)
1985             weights.append(p)
1986
1987         # convert the weights to a cumulative distribution between 0.0 and 1.0
1988         self.group_weights = weights
1989         self.group_dist = self.cumulative_distribution(weights)
1990
1991     def generate_random_membership(self):
1992         """Returns a randomly generated user-group membership"""
1993
1994         # the list items are cumulative distribution values between 0.0 and
1995         # 1.0, which makes random() a handy way to index the list to get a
1996         # weighted random user/group. (Here the user/group returned are
1997         # zero-based array indexes)
1998         user = bisect.bisect(self.user_dist, random.random())
1999         group = bisect.bisect(self.group_dist, random.random())
2000
2001         return user, group
2002
2003     def users_in_group(self, group):
2004         return self.assignments[group]
2005
2006     def get_groups(self):
2007         return self.assignments.keys()
2008
2009     def cap_group_membership(self, group, max_members):
2010         """Prevent the group's membership from exceeding the max specified"""
2011         num_members = len(self.assignments[group])
2012         if num_members >= max_members:
2013             LOGGER.info("Group {0} has {1} members".format(group, num_members))
2014
2015             # remove this group and then recalculate the cumulative
2016             # distribution, so this group is no longer selected
2017             self.group_weights[group - 1] = 0
2018             new_dist = self.cumulative_distribution(self.group_weights)
2019             self.group_dist = new_dist
2020
2021     def add_assignment(self, user, group):
2022         # the assignments are stored in a dictionary where key=group,
2023         # value=list-of-users-in-group (indexing by group-ID allows us to
2024         # optimize for DB membership writes)
2025         if user not in self.assignments[group]:
2026             self.assignments[group].append(user)
2027             self.count += 1
2028
2029         # check if there'a cap on how big the groups can grow
2030         if self.max_members:
2031             self.cap_group_membership(group, self.max_members)
2032
2033     def assign_groups(self, number_of_groups, groups_added,
2034                       number_of_users, users_added, group_memberships):
2035         """Allocate users to groups.
2036
2037         The intention is to have a few users that belong to most groups, while
2038         the majority of users belong to a few groups.
2039
2040         A few groups will contain most users, with the remaining only having a
2041         few users.
2042         """
2043
2044         if group_memberships <= 0:
2045             return
2046
2047         # Calculate the number of group menberships required
2048         group_memberships = math.ceil(
2049             float(group_memberships) *
2050             (float(users_added) / float(number_of_users)))
2051
2052         if self.max_members:
2053             group_memberships = min(group_memberships,
2054                                     self.max_members * number_of_groups)
2055
2056         existing_users  = number_of_users  - users_added  - 1
2057         existing_groups = number_of_groups - groups_added - 1
2058         while self.total() < group_memberships:
2059             user, group = self.generate_random_membership()
2060
2061             if group > existing_groups or user > existing_users:
2062                 # the + 1 converts the array index to the corresponding
2063                 # group or user number
2064                 self.add_assignment(user + 1, group + 1)
2065
2066     def total(self):
2067         return self.count
2068
2069
2070 def add_users_to_groups(db, instance_id, assignments):
2071     """Takes the assignments of users to groups and applies them to the DB."""
2072
2073     total = assignments.total()
2074     count = 0
2075     added = 0
2076
2077     for group in assignments.get_groups():
2078         users_in_group = assignments.users_in_group(group)
2079         if len(users_in_group) == 0:
2080             continue
2081
2082         # Split up the users into chunks, so we write no more than 1K at a
2083         # time. (Minimizing the DB modifies is more efficient, but writing
2084         # 10K+ users to a single group becomes inefficient memory-wise)
2085         for chunk in range(0, len(users_in_group), 1000):
2086             chunk_of_users = users_in_group[chunk:chunk + 1000]
2087             add_group_members(db, instance_id, group, chunk_of_users)
2088
2089             added += len(chunk_of_users)
2090             count += 1
2091             if count % 50 == 0:
2092                 LOGGER.info("Added %u/%u memberships" % (added, total))
2093
2094 def add_group_members(db, instance_id, group, users_in_group):
2095     """Adds the given users to group specified."""
2096
2097     ou = ou_name(db, instance_id)
2098
2099     def build_dn(name):
2100         return("cn=%s,%s" % (name, ou))
2101
2102     group_dn = build_dn(group_name(instance_id, group))
2103     m = ldb.Message()
2104     m.dn = ldb.Dn(db, group_dn)
2105
2106     for user in users_in_group:
2107         user_dn = build_dn(user_name(instance_id, user))
2108         idx = "member-" + str(user)
2109         m[idx] = ldb.MessageElement(user_dn, ldb.FLAG_MOD_ADD, "member")
2110
2111     db.modify(m)
2112
2113
2114 def generate_stats(statsdir, timing_file):
2115     """Generate and print the summary stats for a run."""
2116     first      = sys.float_info.max
2117     last       = 0
2118     successful = 0
2119     failed     = 0
2120     latencies  = {}
2121     failures   = {}
2122     unique_converations = set()
2123     conversations = 0
2124
2125     if timing_file is not None:
2126         tw = timing_file.write
2127     else:
2128         def tw(x):
2129             pass
2130
2131     tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n")
2132
2133     for filename in os.listdir(statsdir):
2134         path = os.path.join(statsdir, filename)
2135         with open(path, 'r') as f:
2136             for line in f:
2137                 try:
2138                     fields       = line.rstrip('\n').split('\t')
2139                     conversation = fields[1]
2140                     protocol     = fields[2]
2141                     packet_type  = fields[3]
2142                     latency      = float(fields[4])
2143                     first        = min(float(fields[0]) - latency, first)
2144                     last         = max(float(fields[0]), last)
2145
2146                     if protocol not in latencies:
2147                         latencies[protocol] = {}
2148                     if packet_type not in latencies[protocol]:
2149                         latencies[protocol][packet_type] = []
2150
2151                     latencies[protocol][packet_type].append(latency)
2152
2153                     if protocol not in failures:
2154                         failures[protocol] = {}
2155                     if packet_type not in failures[protocol]:
2156                         failures[protocol][packet_type] = 0
2157
2158                     if fields[5] == 'True':
2159                         successful += 1
2160                     else:
2161                         failed += 1
2162                         failures[protocol][packet_type] += 1
2163
2164                     if conversation not in unique_converations:
2165                         unique_converations.add(conversation)
2166                         conversations += 1
2167
2168                     tw(line)
2169                 except (ValueError, IndexError):
2170                     # not a valid line print and ignore
2171                     print(line, file=sys.stderr)
2172                     pass
2173     duration = last - first
2174     if successful == 0:
2175         success_rate = 0
2176     else:
2177         success_rate = successful / duration
2178     if failed == 0:
2179         failure_rate = 0
2180     else:
2181         failure_rate = failed / duration
2182
2183     print("Total conversations:   %10d" % conversations)
2184     print("Successful operations: %10d (%.3f per second)"
2185           % (successful, success_rate))
2186     print("Failed operations:     %10d (%.3f per second)"
2187           % (failed, failure_rate))
2188
2189     print("Protocol    Op Code  Description                               "
2190           " Count       Failed         Mean       Median          "
2191           "95%        Range          Max")
2192
2193     protocols = sorted(latencies.keys())
2194     for protocol in protocols:
2195         packet_types = sorted(latencies[protocol], key=opcode_key)
2196         for packet_type in packet_types:
2197             values     = latencies[protocol][packet_type]
2198             values     = sorted(values)
2199             count      = len(values)
2200             failed     = failures[protocol][packet_type]
2201             mean       = sum(values) / count
2202             median     = calc_percentile(values, 0.50)
2203             percentile = calc_percentile(values, 0.95)
2204             rng        = values[-1] - values[0]
2205             maxv       = values[-1]
2206             desc       = OP_DESCRIPTIONS.get((protocol, packet_type), '')
2207             if sys.stdout.isatty:
2208                 print("%-12s   %4s  %-35s %12d %12d %12.6f "
2209                       "%12.6f %12.6f %12.6f %12.6f"
2210                       % (protocol,
2211                          packet_type,
2212                          desc,
2213                          count,
2214                          failed,
2215                          mean,
2216                          median,
2217                          percentile,
2218                          rng,
2219                          maxv))
2220             else:
2221                 print("%s\t%s\t%s\t%d\t%d\t%f\t%f\t%f\t%f\t%f"
2222                       % (protocol,
2223                          packet_type,
2224                          desc,
2225                          count,
2226                          failed,
2227                          mean,
2228                          median,
2229                          percentile,
2230                          rng,
2231                          maxv))
2232
2233
2234 def opcode_key(v):
2235     """Sort key for the operation code to ensure that it sorts numerically"""
2236     try:
2237         return "%03d" % int(v)
2238     except:
2239         return v
2240
2241
2242 def calc_percentile(values, percentile):
2243     """Calculate the specified percentile from the list of values.
2244
2245     Assumes the list is sorted in ascending order.
2246     """
2247
2248     if not values:
2249         return 0
2250     k = (len(values) - 1) * percentile
2251     f = math.floor(k)
2252     c = math.ceil(k)
2253     if f == c:
2254         return values[int(k)]
2255     d0 = values[int(f)] * (c - k)
2256     d1 = values[int(c)] * (k - f)
2257     return d0 + d1
2258
2259
2260 def mk_masked_dir(*path):
2261     """In a testenv we end up with 0777 directories that look an alarming
2262     green colour with ls. Use umask to avoid that."""
2263     # py3 os.mkdir can do this
2264     d = os.path.join(*path)
2265     mask = os.umask(0o077)
2266     os.mkdir(d)
2267     os.umask(mask)
2268     return d