traffic: add paged_results control for ldb search
[samba.git] / python / samba / emulate / traffic.py
1 # -*- encoding: utf-8 -*-
2 # Samba traffic replay and learning
3 #
4 # Copyright (C) Catalyst IT Ltd. 2017
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 #
19 from __future__ import print_function, division
20
21 import time
22 import os
23 import random
24 import json
25 import math
26 import sys
27 import signal
28 import itertools
29
30 from collections import OrderedDict, Counter, defaultdict
31 from samba.emulate import traffic_packets
32 from samba.samdb import SamDB
33 import ldb
34 from ldb import LdbError
35 from samba.dcerpc import ClientConnection
36 from samba.dcerpc import security, drsuapi, lsa
37 from samba.dcerpc import netlogon
38 from samba.dcerpc.netlogon import netr_Authenticator
39 from samba.dcerpc import srvsvc
40 from samba.dcerpc import samr
41 from samba.drs_utils import drs_DsBind
42 import traceback
43 from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
44 from samba.auth import system_session
45 from samba.dsdb import UF_WORKSTATION_TRUST_ACCOUNT, UF_PASSWD_NOTREQD
46 from samba.dsdb import UF_NORMAL_ACCOUNT
47 from samba.dcerpc.misc import SEC_CHAN_WKSTA
48 from samba import gensec
49
# Per-sleep overhead allowance, in seconds.
# NOTE(review): the code that consumes this is outside this view -- confirm.
SLEEP_OVERHEAD = 0.0003

# Marker for "no packet".  A string is used instead of None because None
# complicates [de]serialisation.
NON_PACKET = '-'

# (protocol, opcode) -> weight of evidence that the SENDER is a client.
CLIENT_CLUES = {
    ('dns', '0'): 1.0,        # query
    ('smb', '0x72'): 1.0,     # Negotiate protocol
    ('ldap', '0'): 1.0,       # bind
    ('ldap', '3'): 1.0,       # searchRequest
    ('ldap', '2'): 1.0,       # unbindRequest
    ('cldap', '3'): 1.0,
    ('dcerpc', '11'): 1.0,    # bind
    ('dcerpc', '14'): 1.0,    # Alter_context
    ('nbns', '0'): 1.0,       # query
}

# (protocol, opcode) -> weight of evidence that the SENDER is a server.
SERVER_CLUES = {
    ('dns', '1'): 1.0,        # response
    ('ldap', '1'): 1.0,       # bind response
    ('ldap', '4'): 1.0,       # search result
    ('ldap', '5'): 1.0,       # search done
    ('cldap', '5'): 1.0,
    ('dcerpc', '12'): 1.0,    # bind_ack
    ('dcerpc', '13'): 1.0,    # bind_nak
    ('dcerpc', '15'): 1.0,    # Alter_context response
}

# Protocols whose packets are never replayed.
SKIPPED_PROTOCOLS = set(["smb", "smb2", "browser", "smb_netlogon"])

WAIT_SCALE = 10.0
WAIT_THRESHOLD = 1.0 / WAIT_SCALE
NO_WAIT_LOG_TIME_RANGE = (-10, -3)

# Verbosity used by debug(); scripts may change it with -d.
DEBUG_LEVEL = 0
86
87
def debug(level, msg, *args):
    """Write a debug message to standard error if its level is enabled.

    :param level: Verbosity of this message.  It is printed only when
                  ``level`` is less than or equal to the module-level
                  ``DEBUG_LEVEL`` (scripts set that with the -d option).
    :param msg:   The message text; may contain C-style ``%`` format
                  specifiers.
    :param args:  Values for the format specifiers.  With no values the
                  message is printed verbatim, without ``%`` formatting.
    """
    if not level <= DEBUG_LEVEL:
        return
    text = msg % tuple(args) if args else msg
    print(text, file=sys.stderr)
104
105
def debug_lineno(*args):
    """Print the calling function's name and line number to stderr, then
    each argument, unformatted, on its own line, followed by a blank line.
    """
    # limit=2 keeps only the caller's frame and this one; [0] is the caller.
    caller = traceback.extract_stack(limit=2)[0]
    func_name = caller[2]
    line_no = caller[1]
    location = " %s:\033[01;33m%s \033[00m" % (func_name, line_no)
    print(location, end=' ', file=sys.stderr)
    for arg in args:
        print(arg, file=sys.stderr)
    print(file=sys.stderr)
    sys.stderr.flush()
117
118
def random_colour_print():
    """Return a printer that writes in one randomly chosen terminal colour.

    A 256-colour ANSI foreground index in 18..231 is picked once.  The
    returned callable prints each of its arguments to standard error on its
    own line, wrapped in that colour and a reset sequence.
    """
    colour_index = 18 + random.randrange(214)
    colour_on = "\033[38;5;%dm" % colour_index

    def coloured_print(*args):
        for item in args:
            print("%s%s\033[00m" % (colour_on, item), file=sys.stderr)

    return coloured_print
129
130
class FakePacketError(Exception):
    """Raised when a packet cannot be added to a conversation, e.g. because
    its endpoints differ from the conversation's endpoints."""
133
134
class Packet(object):
    """Details of a network packet, as read from a traffic summary line."""
    def __init__(self, fields):
        """Build a packet from a traffic-summary record.

        :param fields: Either a tab-separated summary line (a trailing
                       newline is ignored) or a sequence of field values.
                       The first eight fields are timestamp, IP protocol,
                       stream number, source, destination, protocol, opcode
                       and description; any further fields are kept, in
                       order, as the list ``extra``.
        """
        if isinstance(fields, str):
            fields = fields.rstrip('\n').split('\t')

        (timestamp,
         ip_protocol,
         stream_number,
         src,
         dest,
         protocol,
         opcode,
         desc) = fields[:8]
        # Always a list, so copy() can concatenate it even if a tuple
        # was passed in.
        extra = list(fields[8:])

        self.timestamp = float(timestamp)
        self.ip_protocol = ip_protocol
        try:
            self.stream_number = int(stream_number)
        except (ValueError, TypeError):
            # An empty or missing stream number is stored as None.
            self.stream_number = None
        self.src = int(src)
        self.dest = int(dest)
        self.protocol = protocol
        self.opcode = opcode
        self.desc = desc
        self.extra = extra

        # Endpoints are stored in ascending order so both directions of a
        # conversation produce the same pair.
        if self.src < self.dest:
            self.endpoints = (self.src, self.dest)
        else:
            self.endpoints = (self.dest, self.src)

    def as_summary(self, time_offset=0.0):
        """Format the packet as a traffic_summary line.

        :param time_offset: Added to the packet's timestamp, e.g. to turn a
                            conversation-relative time back into an
                            absolute one.
        :return: ``(timestamp, line)`` where ``timestamp`` is the offset
                 time and ``line`` is the tab-separated record.
        """
        extra = '\t'.join(self.extra)
        t = self.timestamp + time_offset
        # Stream 0 is a valid stream number; only a missing (None) stream
        # is written as an empty field.  (``stream_number or ''`` dropped
        # stream 0, which then re-parsed as None.)
        if self.stream_number is None:
            stream = ''
        else:
            stream = self.stream_number
        return (t, '%f\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' %
                (t,
                 self.ip_protocol,
                 stream,
                 self.src,
                 self.dest,
                 self.protocol,
                 self.opcode,
                 self.desc,
                 extra))

    def __str__(self):
        """Return a human-readable one-line description of the packet."""
        return ("%.3f: %d -> %d; ip %s; strm %s; prot %s; op %s; desc %s %s" %
                (self.timestamp, self.src, self.dest, self.ip_protocol or '-',
                 self.stream_number, self.protocol, self.opcode, self.desc,
                 ('«' + ' '.join(self.extra) + '»' if self.extra else '')))

    def __repr__(self):
        return "<Packet @%s>" % self

    def copy(self):
        """Return a new Packet with the same field values."""
        return self.__class__([self.timestamp,
                               self.ip_protocol,
                               self.stream_number,
                               self.src,
                               self.dest,
                               self.protocol,
                               self.opcode,
                               self.desc] + self.extra)

    def as_packet_type(self):
        """Return the packet type as a 'protocol:opcode' string."""
        t = '%s:%s' % (self.protocol, self.opcode)
        return t

    def client_score(self):
        """A positive number means we think it is a client; a negative number
        means we think it is a server. Zero means no idea. range: -1 to 1.
        """
        key = (self.protocol, self.opcode)
        if key in CLIENT_CLUES:
            return CLIENT_CLUES[key]
        if key in SERVER_CLUES:
            return -SERVER_CLUES[key]
        return 0.0

    def play(self, conversation, context):
        """Send the packet over the network, if required.

        Some packets are ignored, i.e. for  protocols not handled,
        server response messages, or messages that are generated by the
        protocol layer associated with other packets.

        The handler is ``traffic_packets.packet_<protocol>_<opcode>``.  If
        it returns a true value, or raises, a tab-separated timing record
        (end time, conversation id, protocol, opcode, duration, success
        flag, error) is printed to standard output.  Handler exceptions are
        reported this way and not propagated.
        """
        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
        try:
            fn = getattr(traffic_packets, fn_name)
        except AttributeError:
            print("Conversation(%s) Missing handler %s" %
                  (conversation.conversation_id, fn_name),
                  file=sys.stderr)
            return

        # Don't display a message for kerberos packets, they're not directly
        # generated they're used to indicate kerberos should be used
        if self.protocol != "kerberos":
            debug(2, "Conversation(%s) Calling handler %s",
                  conversation.conversation_id, fn_name)

        start = time.time()
        try:
            if fn(self, conversation, context):
                # Only collect timing data for functions that generate
                # network traffic, or fail
                end = time.time()
                duration = end - start
                print("%f\t%s\t%s\t%s\t%f\tTrue\t" %
                      (end, conversation.conversation_id, self.protocol,
                       self.opcode, duration))
        except Exception as e:
            end = time.time()
            duration = end - start
            print("%f\t%s\t%s\t%s\t%f\tFalse\t%s" %
                  (end, conversation.conversation_id, self.protocol,
                   self.opcode, duration, e))

    def __cmp__(self, other):
        """Order packets by timestamp (Python 2 comparison hook).

        Returns -1, 0 or 1.  Python 2 truncates a non-integer ``__cmp__``
        result to an int, so returning the raw float difference made
        packets less than a second apart compare as equal.
        """
        return ((self.timestamp > other.timestamp) -
                (self.timestamp < other.timestamp))

    def __lt__(self, other):
        """Order by timestamp; Python 3 ignores __cmp__ when sorting."""
        return self.timestamp < other.timestamp

    def is_really_a_packet(self, missing_packet_stats=None):
        """Return True if the packet must be replayed, False if it can be
        ignored (removing it would have no effect on the replay).

        Packets are ignored when their protocol is in SKIPPED_PROTOCOLS,
        when they are LDAP continuation packets (empty opcode), when their
        handler is ``traffic_packets.null_packet``, or when there is no
        handler at all (which is also reported on stderr).

        :param missing_packet_stats: currently unused.
        """
        if self.protocol in SKIPPED_PROTOCOLS:
            # Ignore any packets for the protocols we're not interested in.
            return False
        if self.protocol == "ldap" and self.opcode == '':
            # skip ldap continuation packets
            return False

        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
        try:
            fn = getattr(traffic_packets, fn_name)
            if fn is traffic_packets.null_packet:
                return False
        except AttributeError:
            print("missing packet %s" % fn_name, file=sys.stderr)
            return False
        return True
283
284
class ReplayContext(object):
    """State/Context for an individual conversation between a simulated
       client and a server.

       Holds the server, loadparm and credential configuration, caches the
       RPC and LDAP connections opened while replaying, and remembers
       whether the previous attempt on each connection type used bad
       credentials.
    """

    def __init__(self,
                 server=None,
                 lp=None,
                 creds=None,
                 badpassword_frequency=None,
                 prefer_kerberos=None,
                 tempdir=None,
                 statsdir=None,
                 ou=None,
                 base_dn=None,
                 domain=None,
                 domain_sid=None):
        """Set up the context and build the LDAP DN lookup tables.

        Reads the realm from ``lp`` (so ``lp`` must not be None) and
        searches ``server`` over LDAP via generate_ldap_search_tables().
        """

        self.server                   = server
        self.ldap_connections         = []
        self.dcerpc_connections       = []
        self.lsarpc_connections       = []
        self.lsarpc_connections_named = []
        self.drsuapi_connections      = []
        self.srvsvc_connections       = []
        self.samr_contexts            = []
        self.netlogon_connection      = None
        self.creds                    = creds
        self.lp                       = lp
        self.prefer_kerberos          = prefer_kerberos
        self.ou                       = ou
        self.base_dn                  = base_dn
        self.domain                   = domain
        self.statsdir                 = statsdir
        self.global_tempdir           = tempdir
        self.domain_sid               = domain_sid
        self.realm                    = lp.get('realm')

        # Bad password attempt controls
        self.badpassword_frequency    = badpassword_frequency
        self.last_lsarpc_bad          = False
        self.last_lsarpc_named_bad    = False
        self.last_simple_bind_bad     = False
        self.last_bind_bad            = False
        self.last_srvsvc_bad          = False
        self.last_drsuapi_bad         = False
        self.last_netlogon_bad        = False
        self.last_samlogon_bad        = False
        self.generate_ldap_search_tables()
        # ``itertools.count().next`` only exists on Python 2; a closure over
        # next() yields the same 0, 1, 2, ... sequence on Python 2 and 3.
        counter = itertools.count()
        self.next_conversation_id = lambda: next(counter)

    def generate_ldap_search_tables(self):
        """Index every DN in the domain partition by its RDN-type pattern.

        Searches the whole domain subtree on ``self.server`` using the
        ``paged_results:1:1000`` control and stores the tables built by
        _build_dn_tables() in ``self.dn_map`` and
        ``self.attribute_clue_map``.
        """
        session = system_session()

        db = SamDB(url="ldap://%s" % self.server,
                   session_info=session,
                   credentials=self.creds,
                   lp=self.lp)

        res = db.search(db.domain_dn(),
                        scope=ldb.SCOPE_SUBTREE,
                        controls=["paged_results:1:1000"],
                        attrs=['dn'])

        dn_map, attribute_clue_map = self._build_dn_tables(
            str(r.dn) for r in res)
        self.dn_map = dn_map
        self.attribute_clue_map = attribute_clue_map

    @staticmethod
    def _build_dn_tables(dns):
        """Group DN strings by pattern; return (dn_map, attribute_clue_map).

        A DN's pattern is the upper-cased first two characters of each
        comma-separated component, e.g. ``CN=x,DC=y,DC=z`` -> ``CN,DC,DC``.
        ``dn_map`` maps each pattern to the list of DNs having it; it is
        then extended so that a pattern ending in DC components is also
        reachable with one to five trailing DC components (so traffic
        learnt against a domain with a different number of DC components
        still finds a match).  A derived pattern that collides with an
        existing, different one is reported on stderr and not overwritten.
        ``attribute_clue_map['invocationId']`` lists the
        ``CN=NTDS Settings,...`` DNs.
        """
        dn_map = {}
        attribute_clue_map = {
            'invocationId': []
        }

        for dn in dns:
            pattern = ','.join(x.lstrip()[:2] for x in dn.split(',')).upper()
            dn_map.setdefault(pattern, []).append(dn)
            if dn.startswith('CN=NTDS Settings,'):
                attribute_clue_map['invocationId'].append(dn)

        # Iterate over a snapshot of the keys: the loop adds keys to
        # dn_map, which raises RuntimeError on a live Python 3 dict view.
        for k in list(dn_map):
            if k[-3:] != ',DC':
                continue
            p = k[:-3]
            while p[-3:] == ',DC':
                p = p[:-3]
            for _ in range(5):
                p += ',DC'
                if p != k and p in dn_map:
                    print('dn_map collison %s %s' % (k, p),
                          file=sys.stderr)
                    continue
                dn_map[p] = dn_map[k]

        return dn_map, attribute_clue_map

    def generate_process_local_config(self, account, conversation):
        """Configure this context for one conversation's account.

        Copies the account's names and passwords, points the loadparm's
        private/lock/state directories at a per-conversation temporary
        directory, sets "tls verify peer" to no_check, and generates the
        machine and user credentials.  Does nothing if ``account`` is None.

        :param account: object with netbios_name, machinepass, username and
                        userpass attributes, or None.
        :param conversation: the conversation being replayed; its
                             conversation_id names the temporary directory.
        :raises KeyError: if no domain was given and $DOMAIN is unset.
        """
        if account is None:
            return
        self.netbios_name             = account.netbios_name
        self.machinepass              = account.machinepass
        self.username                 = account.username
        self.userpass                 = account.userpass

        self.tempdir = mk_masked_dir(self.global_tempdir,
                                     'conversation-%d' %
                                     conversation.conversation_id)

        self.lp.set("private dir",     self.tempdir)
        self.lp.set("lock dir",        self.tempdir)
        self.lp.set("state directory", self.tempdir)
        self.lp.set("tls verify peer", "no_check")

        # If the domain was not specified, check for the environment
        # variable.
        if self.domain is None:
            self.domain = os.environ["DOMAIN"]

        self.remoteAddress = "/root/ncalrpc_as_system"
        self.samlogon_dn   = ("cn=%s,%s" %
                              (self.netbios_name, self.ou))
        self.user_dn       = ("cn=%s,%s" %
                              (self.username, self.ou))

        self.generate_machine_creds()
        self.generate_user_creds()

    def with_random_bad_credentials(self, f, good, bad, failed_last_time):
        """Execute the supplied logon function, randomly choosing the
           bad credentials.

           Based on the frequency in badpassword_frequency randomly perform the
           function with the supplied bad credentials.
           If run with bad credentials, the function is re-run with the good
           credentials.
           failed_last_time is used to prevent consecutive bad credential
           attempts: the call after a bad attempt uses only the good
           credentials, and the call after that is eligible again.  So the
           over all bad credential frequency will be lower than that
           requested, but not significantly.

           :return: ``(f(good), failed_last_time)``; pass the second element
                    back in as ``failed_last_time`` on the next call.
        """
        # None (the constructor default) means "never use bad credentials".
        frequency = self.badpassword_frequency or 0
        if failed_last_time:
            # Skip the bad attempt this time, but allow one next time.
            # (Previously the flag stayed True forever after the first
            # bad attempt, so no further bad attempts were ever made.)
            failed_last_time = False
        elif frequency > 0 and random.random() < frequency:
            try:
                f(bad)
            except Exception:
                # Ignore any exceptions as the operation may fail
                # as it's being performed with bad credentials
                pass
            failed_last_time = True

        result = f(good)
        return (result, failed_last_time)

    def _make_credentials(self, username, password,
                          secure_channel_type=None, seal=False,
                          bind_dn=None):
        """Build a Credentials object for this conversation's workstation.

        Settings are applied in this order: guess from the loadparm,
        workstation, optional secure channel type, password, username,
        optional gensec FEATURE_SEAL, Kerberos state (MUST_USE_KERBEROS
        when prefer_kerberos is set, otherwise DONT_USE_KERBEROS), optional
        bind DN.
        """
        creds = Credentials()
        creds.guess(self.lp)
        creds.set_workstation(self.netbios_name)
        if secure_channel_type is not None:
            creds.set_secure_channel_type(secure_channel_type)
        creds.set_password(password)
        creds.set_username(username)
        if seal:
            creds.set_gensec_features(
                creds.get_gensec_features() | gensec.FEATURE_SEAL)
        if self.prefer_kerberos:
            creds.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            creds.set_kerberos_state(DONT_USE_KERBEROS)
        if bind_dn is not None:
            creds.set_bind_dn(bind_dn)
        return creds

    def generate_user_creds(self):
        """Generate the conversation specific user Credentials.

        Each Conversation has an associated user account used to simulate
        any non Administrative user traffic.

        Generates user credentials with good and bad passwords and ldap
        simple bind credentials with good and bad passwords.  The "bad"
        password is the good one with its last four characters removed.
        """
        self.user_creds = self._make_credentials(self.username,
                                                 self.userpass)
        self.user_creds_bad = self._make_credentials(self.username,
                                                     self.userpass[:-4])

        # Credentials for ldap simple bind.
        self.simple_bind_creds = self._make_credentials(
            self.username, self.userpass, seal=True, bind_dn=self.user_dn)
        self.simple_bind_creds_bad = self._make_credentials(
            self.username, self.userpass[:-4], seal=True,
            bind_dn=self.user_dn)

    def generate_machine_creds(self):
        """Generate the conversation specific machine Credentials.

        Each Conversation has an associated machine account.

        Generates workstation secure-channel credentials for
        ``<netbios_name>$`` with good and bad passwords (the bad password is
        the good one with its last four characters removed).
        """
        machine_account = self.netbios_name + "$"
        self.machine_creds = self._make_credentials(
            machine_account, self.machinepass,
            secure_channel_type=SEC_CHAN_WKSTA)
        self.machine_creds_bad = self._make_credentials(
            machine_account, self.machinepass[:-4],
            secure_channel_type=SEC_CHAN_WKSTA)

    def get_matching_dn(self, pattern, attributes=None):
        """Pick a random known DN resembling ``pattern`` (e.g. CN,CN,DC,DC).

        If ``attributes`` has an entry in attribute_clue_map (e.g.
        'invocationId'), a DN from that list is returned instead.
        Otherwise leading components are chopped off the upper-cased
        pattern until it matches a dn_map key; variations in the number of
        trailing DC components are already covered because dn_map was
        extended when it was built.  If nothing matches (including for an
        empty pattern), base_dn is returned.

        :param attributes: used as a dict key, so it must be hashable.
        """
        attr_clue = self.attribute_clue_map.get(attributes)
        if attr_clue:
            return random.choice(attr_clue)

        pattern = pattern.upper()
        while pattern:
            if pattern in self.dn_map:
                return random.choice(self.dn_map[pattern])
            # chop one off the front and try it all again.
            pattern = pattern[3:]

        return self.base_dn

    def get_dcerpc_connection(self, new=False):
        """Return the latest netlogon-interface DCE/RPC connection, opening
        one over ncacn_ip_tcp if none exists or ``new`` is set."""
        guid = '12345678-1234-abcd-ef00-01234567cffb'  # RPC_NETLOGON UUID
        if self.dcerpc_connections and not new:
            return self.dcerpc_connections[-1]
        c = ClientConnection("ncacn_ip_tcp:%s" % self.server,
                             (guid, 1), self.lp)
        self.dcerpc_connections.append(c)
        return c

    def get_srvsvc_connection(self, new=False):
        """Return the latest srvsvc connection (named pipe, user creds),
        opening one if none exists or ``new`` is set."""
        if self.srvsvc_connections and not new:
            return self.srvsvc_connections[-1]

        def connect(creds):
            return srvsvc.srvsvc("ncacn_np:%s" % (self.server),
                                 self.lp,
                                 creds)

        (c, self.last_srvsvc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_srvsvc_bad)

        self.srvsvc_connections.append(c)
        return c

    def get_lsarpc_connection(self, new=False):
        """Return the latest lsarpc TCP connection (schannel, sealed and
        signed, machine creds), opening one if none exists or ``new``."""
        if self.lsarpc_connections and not new:
            return self.lsarpc_connections[-1]

        def connect(creds):
            binding_options = 'schannel,seal,sign'
            return lsa.lsarpc("ncacn_ip_tcp:%s[%s]" %
                              (self.server, binding_options),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_bad)

        self.lsarpc_connections.append(c)
        return c

    def get_lsarpc_named_pipe_connection(self, new=False):
        """Return the latest lsarpc named-pipe connection (machine creds),
        opening one if none exists or ``new`` is set."""
        if self.lsarpc_connections_named and not new:
            return self.lsarpc_connections_named[-1]

        def connect(creds):
            return lsa.lsarpc("ncacn_np:%s" % (self.server),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_named_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_named_bad)

        self.lsarpc_connections_named.append(c)
        return c

    def get_drsuapi_connection_pair(self, new=False, unbind=False):
        """get a (drs, drs_handle) tuple

        Opens a sealed drsuapi TCP connection with the user credentials and
        binds it, unless a cached pair exists and ``new`` is false.
        ``unbind`` is currently unused.
        """
        if self.drsuapi_connections and not new:
            c = self.drsuapi_connections[-1]
            return c

        def connect(creds):
            binding_options = 'seal'
            binding_string = "ncacn_ip_tcp:%s[%s]" %\
                             (self.server, binding_options)
            return drsuapi.drsuapi(binding_string, self.lp, creds)

        (drs, self.last_drsuapi_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_drsuapi_bad)

        (drs_handle, _supported_extensions) = drs_DsBind(drs)
        c = (drs, drs_handle)
        self.drsuapi_connections.append(c)
        return c

    def get_ldap_connection(self, new=False, simple=False):
        """Return the latest LDAP SamDB connection, opening one if none
        exists or ``new`` is set.

        With ``simple`` the connection uses ldaps:// and the simple-bind
        credentials; otherwise ldap:// with the user credentials.
        """
        if self.ldap_connections and not new:
            return self.ldap_connections[-1]

        def simple_bind(creds):
            return SamDB('ldaps://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)

        def sasl_bind(creds):
            return SamDB('ldap://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)
        if simple:
            (samdb, self.last_simple_bind_bad) = \
                self.with_random_bad_credentials(simple_bind,
                                                 self.simple_bind_creds,
                                                 self.simple_bind_creds_bad,
                                                 self.last_simple_bind_bad)
        else:
            (samdb, self.last_bind_bad) = \
                self.with_random_bad_credentials(sasl_bind,
                                                 self.user_creds,
                                                 self.user_creds_bad,
                                                 self.last_bind_bad)

        self.ldap_connections.append(samdb)
        return samdb

    def get_samr_context(self, new=False):
        """Return the latest SamrContext, creating one if none exists or
        ``new`` is set."""
        if not self.samr_contexts or new:
            self.samr_contexts.append(SamrContext(self.server))
        return self.samr_contexts[-1]

    def get_netlogon_connection(self):
        """Return the netlogon connection (schannel, sealed, machine creds),
        opening it on first use."""

        if self.netlogon_connection:
            return self.netlogon_connection

        def connect(creds):
            return netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" %
                                     (self.server),
                                     self.lp,
                                     creds)
        (c, self.last_netlogon_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_netlogon_bad)
        self.netlogon_connection = c
        return c

    def guess_a_dns_lookup(self):
        """Return a (name, record type) pair: the realm's A record."""
        return (self.realm, 'A')

    def get_authenticator(self):
        """Return a (current, subsequent) pair of netr_Authenticators.

        ``current`` is filled from a new client authenticator produced by
        the machine credentials; ``subsequent`` is left empty.
        """
        auth = self.machine_creds.new_client_authenticator()
        current  = netr_Authenticator()
        # The credential is presumably a byte string: iterating it yields
        # 1-char strs on Python 2 but ints on Python 3, so only call ord()
        # when needed.
        current.cred.data = [x if isinstance(x, int) else ord(x)
                             for x in auth["credential"]]
        current.timestamp = auth["timestamp"]

        subsequent = netr_Authenticator()
        return (current, subsequent)
705
706
class SamrContext(object):
    """Per-conversation samr state: a lazily opened connection and handles.
    """
    def __init__(self, server):
        self.server = server
        # Opened on demand by get_connection() / get_handle().
        self.connection = None
        self.handle = None
        # Not assigned anywhere in this class; presumably populated by the
        # packet handlers that use this context -- confirm.
        self.domain_handle = None
        self.domain_sid = None
        self.group_handle = None
        self.user_handle = None
        self.rids = None

    def get_connection(self):
        """Return the samr connection, opening it over TCP on first use."""
        if self.connection:
            return self.connection
        self.connection = samr.samr("ncacn_ip_tcp:%s" % (self.server))
        return self.connection

    def get_handle(self):
        """Return the Connect2 handle, requesting it on first use."""
        if self.handle:
            return self.handle
        conn = self.get_connection()
        self.handle = conn.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
        return self.handle
730
731
732 class Conversation(object):
733     """Details of a converation between a simulated client and a server."""
734     conversation_id = None
735
736     def __init__(self, start_time=None, endpoints=None):
737         self.start_time = start_time
738         self.endpoints = endpoints
739         self.packets = []
740         self.msg = random_colour_print()
741         self.client_balance = 0.0
742
743     def __cmp__(self, other):
744         if self.start_time is None:
745             if other.start_time is None:
746                 return 0
747             return -1
748         if other.start_time is None:
749             return 1
750         return self.start_time - other.start_time
751
752     def add_packet(self, packet):
753         """Add a packet object to this conversation, making a local copy with
754         a conversation-relative timestamp."""
755         p = packet.copy()
756
757         if self.start_time is None:
758             self.start_time = p.timestamp
759
760         if self.endpoints is None:
761             self.endpoints = p.endpoints
762
763         if p.endpoints != self.endpoints:
764             raise FakePacketError("Conversation endpoints %s don't match"
765                                   "packet endpoints %s" %
766                                   (self.endpoints, p.endpoints))
767
768         p.timestamp -= self.start_time
769
770         if p.src == p.endpoints[0]:
771             self.client_balance -= p.client_score()
772         else:
773             self.client_balance += p.client_score()
774
775         if p.is_really_a_packet():
776             self.packets.append(p)
777
778     def add_short_packet(self, timestamp, p, extra, client=True):
779         """Create a packet from a timestamp, and 'protocol:opcode' pair, and a
780         (possibly empty) list of extra data. If client is True, assume
781         this packet is from the client to the server.
782         """
783         protocol, opcode = p.split(':', 1)
784         src, dest = self.guess_client_server()
785         if not client:
786             src, dest = dest, src
787
788         desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
789         ip_protocol = IP_PROTOCOLS.get(protocol, '06')
790         fields = [timestamp - self.start_time, ip_protocol,
791                   '', src, dest,
792                   protocol, opcode, desc]
793         fields.extend(extra)
794         packet = Packet(fields)
795         # XXX we're assuming the timestamp is already adjusted for
796         # this conversation?
797         # XXX should we adjust client balance for guessed packets?
798         if packet.src == packet.endpoints[0]:
799             self.client_balance -= packet.client_score()
800         else:
801             self.client_balance += packet.client_score()
802         if packet.is_really_a_packet():
803             self.packets.append(packet)
804
805     def __str__(self):
806         return ("<Conversation %s %s starting %.3f %d packets>" %
807                 (self.conversation_id, self.endpoints, self.start_time,
808                  len(self.packets)))
809
810     __repr__ = __str__
811
812     def __iter__(self):
813         return iter(self.packets)
814
815     def __len__(self):
816         return len(self.packets)
817
818     def get_duration(self):
819         if len(self.packets) < 2:
820             return 0
821         return self.packets[-1].timestamp - self.packets[0].timestamp
822
823     def replay_as_summary_lines(self):
824         lines = []
825         for p in self.packets:
826             lines.append(p.as_summary(self.start_time))
827         return lines
828
829     def replay_in_fork_with_delay(self, start, context=None, account=None):
830         """Fork a new process and replay the conversation.
831         """
832         def signal_handler(signal, frame):
833             """Signal handler closes standard out and error.
834
835             Triggered by a sigterm, ensures that the log messages are flushed
836             to disk and not lost.
837             """
838             sys.stderr.close()
839             sys.stdout.close()
840             os._exit(0)
841
842         t = self.start_time
843         now = time.time() - start
844         gap = t - now
845         # we are replaying strictly in order, so it is safe to sleep
846         # in the main process if the gap is big enough. This reduces
847         # the number of concurrent threads, which allows us to make
848         # larger loads.
849         if gap > 0.15 and False:
850             print("sleeping for %f in main process" % (gap - 0.1),
851                   file=sys.stderr)
852             time.sleep(gap - 0.1)
853             now = time.time() - start
854             gap = t - now
855             print("gap is now %f" % gap, file=sys.stderr)
856
857         self.conversation_id = context.next_conversation_id()
858         pid = os.fork()
859         if pid != 0:
860             return pid
861         pid = os.getpid()
862         signal.signal(signal.SIGTERM, signal_handler)
863         # we must never return, or we'll end up running parts of the
864         # parent's clean-up code. So we work in a try...finally, and
865         # try to print any exceptions.
866
867         try:
868             context.generate_process_local_config(account, self)
869             sys.stdin.close()
870             os.close(0)
871             filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
872                                     self.conversation_id)
873             sys.stdout.close()
874             sys.stdout = open(filename, 'w')
875
876             sleep_time = gap - SLEEP_OVERHEAD
877             if sleep_time > 0:
878                 time.sleep(sleep_time)
879
880             miss = t - (time.time() - start)
881             self.msg("starting %s [miss %.3f pid %d]" % (self, miss, pid))
882             self.replay(context)
883         except Exception:
884             print(("EXCEPTION in child PID %d, conversation %s" % (pid, self)),
885                   file=sys.stderr)
886             traceback.print_exc(sys.stderr)
887         finally:
888             sys.stderr.close()
889             sys.stdout.close()
890             os._exit(0)
891
892     def replay(self, context=None):
893         start = time.time()
894
895         for p in self.packets:
896             now = time.time() - start
897             gap = p.timestamp - now
898             sleep_time = gap - SLEEP_OVERHEAD
899             if sleep_time > 0:
900                 time.sleep(sleep_time)
901
902             miss = p.timestamp - (time.time() - start)
903             if context is None:
904                 self.msg("packet %s [miss %.3f pid %d]" % (p, miss,
905                                                            os.getpid()))
906                 continue
907             p.play(self, context)
908
909     def guess_client_server(self, server_clue=None):
910         """Have a go at deciding who is the server and who is the client.
911         returns (client, server)
912         """
913         a, b = self.endpoints
914
915         if self.client_balance < 0:
916             return (a, b)
917
918         # in the absense of a clue, we will fall through to assuming
919         # the lowest number is the server (which is usually true).
920
921         if self.client_balance == 0 and server_clue == b:
922             return (a, b)
923
924         return (b, a)
925
926     def forget_packets_outside_window(self, s, e):
927         """Prune any packets outside the timne window we're interested in
928
929         :param s: start of the window
930         :param e: end of the window
931         """
932
933         new_packets = []
934         for p in self.packets:
935             if p.timestamp < s or p.timestamp > e:
936                 continue
937             new_packets.append(p)
938
939         self.packets = new_packets
940         if new_packets:
941             self.start_time = new_packets[0].timestamp
942         else:
943             self.start_time = None
944
945     def renormalise_times(self, start_time):
946         """Adjust the packet start times relative to the new start time."""
947         for p in self.packets:
948             p.timestamp -= start_time
949
950         if self.start_time is not None:
951             self.start_time -= start_time
952
953
class DnsHammer(Conversation):
    """A lightweight conversation that sends dns:0 queries generated on
    the fly instead of replaying recorded packets.

    Note that Conversation.__init__ is not called, so only the
    attributes set here (plus conversation_id, assigned when forked)
    exist on an instance.

    :param dns_rate: average queries per second.
    :param duration: length of the run in seconds.
    """

    def __init__(self, dns_rate, duration):
        # int(dns_rate * duration) send times, uniformly spread over the
        # run and kept in chronological order.
        count = int(dns_rate * duration)
        self.times = sorted(random.uniform(0, duration)
                            for _ in range(count))
        self.rate = dns_rate
        self.duration = duration
        self.start_time = 0
        self.msg = random_colour_print()

    def __str__(self):
        template = "<DnsHammer %d packets over %.1fs (rate %.2f)>"
        return template % (len(self.times), self.duration, self.rate)

    def replay_in_fork_with_delay(self, start, context=None, account=None):
        """Fork and replay, exactly as Conversation does."""
        return Conversation.replay_in_fork_with_delay(
            self, start, context, account)

    def replay(self, context=None):
        """Send one dns:0 query at each scheduled time.

        With a context, each query prints a tab-separated line to stdout:
        completion time, "DNS", "dns", "0", elapsed seconds, a True/False
        success flag and, on failure, the exception. Without a context
        the queries are only reported through self.msg.
        """
        began = time.time()
        send_query = traffic_packets.packet_dns_0
        for when in self.times:
            delay = when - (time.time() - began) - SLEEP_OVERHEAD
            if delay > 0:
                time.sleep(delay)

            if context is None:
                miss = when - (time.time() - began)
                self.msg("packet %s [miss %.3f pid %d]" % (when, miss,
                                                           os.getpid()))
                continue

            sent_at = time.time()
            try:
                send_query(self, self, context)
                finished = time.time()
                print("%f\tDNS\tdns\t0\t%f\tTrue\t" %
                      (finished, finished - sent_at))
            except Exception as e:
                finished = time.time()
                print("%f\tDNS\tdns\t0\t%f\tFalse\t%s" %
                      (finished, finished - sent_at, e))
1003
1004
def ingest_summaries(files, dns_mode='count'):
    """Load traffic summary files and gather their packets into Conversations.

    :param files: iterable of file names or open file objects. Each file,
        including ones passed in already open, is closed after reading.
    :param dns_mode: 'include' keeps dns packets in the conversations; any
        other value only counts them, per opcode, in dns_counts.
    :return: (conversations, mean_interval, duration, dns_counts).
        conversations holds the non-empty Conversations in first-seen
        order. duration is the time from the first to the last packet.
        mean_interval is the number of conversations (empty ones included)
        divided by duration, or 0.0 when duration is 0. With no packets at
        all, an empty list and zeros are returned in the same 4-tuple shape.
    """
    dns_counts = defaultdict(int)
    packets = []
    for f in files:
        if isinstance(f, str):
            f = open(f)
        try:
            print("Ingesting %s" % (f.name,), file=sys.stderr)
            for line in f:
                p = Packet(line)
                if p.protocol == 'dns' and dns_mode != 'include':
                    dns_counts[p.opcode] += 1
                else:
                    packets.append(p)
        finally:
            f.close()

    if not packets:
        # Keep the same shape as the normal return so callers can unpack.
        return [], 0, 0.0, dns_counts

    start_time = min(p.timestamp for p in packets)
    last_packet = max(p.timestamp for p in packets)

    print("gathering packets into conversations", file=sys.stderr)
    conversations = OrderedDict()
    for p in packets:
        p.timestamp -= start_time
        c = conversations.get(p.endpoints)
        if c is None:
            c = Conversation()
            conversations[p.endpoints] = c
        c.add_packet(p)

    # We only care about conversations with actual traffic, so we
    # filter out conversations with nothing to say. We do that here,
    # rather than earlier, because those empty packets contain useful
    # hints as to which end of the conversation was the client.
    conversation_list = [c for c in conversations.values() if len(c) != 0]

    # This is obviously not correct, as many conversations will appear
    # to start roughly simultaneously at the beginning of the snapshot.
    # To which we say: oh well, so be it.
    duration = float(last_packet - start_time)
    if duration > 0:
        mean_interval = len(conversations) / duration
    else:
        # every packet shares one timestamp; there is no span to divide by
        mean_interval = 0.0

    return conversation_list, mean_interval, duration, dns_counts
1056
1057
def guess_server_address(conversations):
    """Guess the server's address: the endpoint seen in most conversations.

    :param conversations: objects with an ``endpoints`` pair of addresses.
    :return: the most common address, or None if there are no
        conversations.
    """
    addresses = Counter()
    for c in conversations:
        addresses.update(c.endpoints)
    if not addresses:
        return None
    # most_common(1) is [(address, count)]; callers compare the result
    # with endpoint addresses, so return the address alone.
    return addresses.most_common(1)[0][0]
1065
1066
def stringify_keys(x):
    """Return a copy of mapping *x* with each tuple-of-str key joined by
    tabs, so it can be used as a JSON object key."""
    return {'\t'.join(key): value for key, value in x.items()}
1073
1074
def unstringify_keys(x):
    """Inverse of stringify_keys: split each key of *x* on tabs into a
    tuple of str."""
    return {tuple(str(key).split('\t')): value for key, value in x.items()}
1081
1082
class TrafficModel(object):
    """An n-gram model of client traffic.

    It is learnt from Conversations (learn), stored as JSON (save, load)
    and used to generate new Conversations (construct_conversation,
    generate_conversations).

    :param n: model order; each next packet type is predicted from the
        previous n - 1 packet types.
    """
    def __init__(self, n=3):
        # (previous n - 1 packet types) -> list, with repeats, of the
        # packet types that followed; NON_PACKET marks an end.
        self.ngrams = {}
        # packet type -> list of extra-field tuples seen with it
        self.query_details = {}
        self.n = n
        # dns opcode -> packet count
        self.dns_opcounts = defaultdict(int)
        # sum of get_duration() over all learnt conversations
        self.cumulative_duration = 0.0
        # [number of conversations, seconds between first and last start]
        self.conversation_rate = [0, 1]

    def learn(self, conversations, dns_opcounts=None):
        """Add *conversations* to the model.

        Only packets from each conversation's guessed client are learnt.
        A gap longer than WAIT_THRESHOLD before a learnt packet is recorded
        as an extra 'wait:N' state, with
        N = int(log(max(1, gap * WAIT_SCALE))).

        :param conversations: list of Conversations, presumably ordered by
            start time, since the learnt conversation rate uses the first
            and last entries.
        :param dns_opcounts: optional {opcode: count} added to
            self.dns_opcounts.
        """
        if dns_opcounts is None:
            dns_opcounts = {}
        prev = 0.0
        cum_duration = 0.0
        key = (NON_PACKET,) * (self.n - 1)

        server = guess_server_address(conversations)

        for k, v in dns_opcounts.items():
            self.dns_opcounts[k] += v

        if len(conversations) > 1:
            elapsed =\
                conversations[-1].start_time - conversations[0].start_time
            self.conversation_rate[0] = len(conversations)
            self.conversation_rate[1] = elapsed

        for c in conversations:
            # NOTE(review): this rebinds ``server``, so each conversation's
            # guess (not guess_server_address()) is the clue for the next
            # conversation. Confirm that is intended.
            client, server = c.guess_client_server(server)
            cum_duration += c.get_duration()
            key = (NON_PACKET,) * (self.n - 1)
            for p in c:
                if p.src != client:
                    continue

                # NOTE(review): prev is not reset per conversation, so a
                # conversation's first gap is measured from the previous
                # conversation's last client packet.
                elapsed = p.timestamp - prev
                prev = p.timestamp
                if elapsed > WAIT_THRESHOLD:
                    # add the wait as an extra state
                    wait = 'wait:%d' % (math.log(max(1.0,
                                                     elapsed * WAIT_SCALE)))
                    self.ngrams.setdefault(key, []).append(wait)
                    key = key[1:] + (wait,)

                short_p = p.as_packet_type()
                self.query_details.setdefault(short_p,
                                              []).append(tuple(p.extra))
                self.ngrams.setdefault(key, []).append(short_p)
                key = key[1:] + (short_p,)

        self.cumulative_duration += cum_duration
        # add in the end (NOTE(review): only after the final conversation)
        self.ngrams.setdefault(key, []).append(NON_PACKET)

    def save(self, f):
        """Write the model as JSON to *f*, a file name or writable file.

        n-gram keys are tab-joined and their value lists stored as
        {value: count}. Query-detail tuples are tab-joined, with '-'
        standing for an empty tuple. A file opened here is closed here.
        """
        ngrams = {}
        for k, v in self.ngrams.items():
            k = '\t'.join(k)
            ngrams[k] = dict(Counter(v))

        query_details = {}
        for k, v in self.query_details.items():
            query_details[k] = dict(Counter('\t'.join(x) if x else '-'
                                            for x in v))

        d = {
            'ngrams': ngrams,
            'query_details': query_details,
            'cumulative_duration': self.cumulative_duration,
            'conversation_rate': self.conversation_rate,
        }
        d['dns'] = self.dns_opcounts

        if isinstance(f, str):
            with open(f, 'w') as fh:
                json.dump(d, fh, indent=2)
        else:
            json.dump(d, f, indent=2)

    def load(self, f):
        """Merge a model saved by save() from *f*, a file name or file.

        n-grams, query details and dns counts are added to the existing
        model; cumulative_duration and conversation_rate are replaced.
        """
        if isinstance(f, str):
            with open(f) as fh:
                d = json.load(fh)
        else:
            d = json.load(f)

        for k, v in d['ngrams'].items():
            k = tuple(str(k).split('\t'))
            values = self.ngrams.setdefault(k, [])
            for p, count in v.items():
                values.extend([str(p)] * count)

        for k, v in d['query_details'].items():
            values = self.query_details.setdefault(str(k), [])
            for p, count in v.items():
                if p == '-':
                    values.extend([()] * count)
                else:
                    values.extend([tuple(str(p).split('\t'))] * count)

        if 'dns' in d:
            for k, v in d['dns'].items():
                self.dns_opcounts[k] += v

        self.cumulative_duration = d['cumulative_duration']
        self.conversation_rate = d['conversation_rate']

    def construct_conversation(self, timestamp=0.0, client=2, server=1,
                               hard_stop=None, packet_rate=1):
        """Construct an individual conversation from the model.

        Starting from the all-NON_PACKET key, packet types are drawn from
        the n-grams until NON_PACKET is drawn or the key is unknown.
        'wait:N' states only advance the clock; every other packet is
        preceded by a short random gap. All gaps are divided by
        packet_rate.

        :param timestamp: start time of the conversation.
        :param client: client endpoint identifier.
        :param server: server endpoint identifier.
        :param hard_stop: if given, stop before adding a packet timed
            after it.
        :param packet_rate: speed-up factor applied to every gap.
        :return: the new Conversation.
        """
        c = Conversation(timestamp, (server, client))

        key = (NON_PACKET,) * (self.n - 1)

        while key in self.ngrams:
            p = random.choice(self.ngrams.get(key, NON_PACKET))
            if p == NON_PACKET:
                break
            if p in self.query_details:
                extra = random.choice(self.query_details[p])
            else:
                extra = []

            protocol, opcode = p.split(':', 1)
            if protocol == 'wait':
                log_wait_time = int(opcode) + random.random()
                wait = math.exp(log_wait_time) / (WAIT_SCALE * packet_rate)
                timestamp += wait
            else:
                log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
                wait = math.exp(log_wait) / packet_rate
                timestamp += wait
                if hard_stop is not None and timestamp > hard_stop:
                    break
                c.add_short_packet(timestamp, p, extra)

            key = key[1:] + (p,)

        return c

    def generate_conversations(self, rate, duration, packet_rate=1):
        """Generate a list of conversations from the model.

        About ``rate * duration2 * rate_n / rate_t`` conversations are
        constructed, where [rate_n, rate_t] is the learnt
        conversation_rate; empty ones are discarded.

        :param rate: multiplier on the learnt conversation rate.
        :param duration: wanted length of the replay, in seconds.
        :param packet_rate: passed on to construct_conversation().
        :return: the non-empty Conversations, sorted.
        """
        # The simulation runs over duration2: at least twice the wanted
        # duration, and at least the span the conversation rate was
        # learnt over.
        rate_n, rate_t = self.conversation_rate

        duration2 = max(rate_t, duration * 2)
        n = rate * duration2 * rate_n / rate_t

        server = 1
        client = 2

        conversations = []
        end = duration2
        # NOTE(review): this window start is never used; the loop below
        # rebinds ``start`` to each conversation's own start before the
        # window is applied. Confirm which start was intended.
        start = end - duration

        while client < n + 2:
            start = random.uniform(0, duration2)
            c = self.construct_conversation(start,
                                            client,
                                            server,
                                            hard_stop=(duration2 * 5),
                                            packet_rate=packet_rate)

            c.forget_packets_outside_window(start, end)
            c.renormalise_times(start)
            if len(c) != 0:
                conversations.append(c)
            client += 1

        print(("we have %d conversations at rate %f" %
               (len(conversations), rate)), file=sys.stderr)
        conversations.sort()
        return conversations
1255
1256
# IP protocol number, as two hex digits, used for each application protocol
# when a packet line is synthesised (add_short_packet, expand_short_packet);
# protocols not listed here default to '06'. '06' is TCP and '11'
# (0x11 = 17) is UDP.
IP_PROTOCOLS = {
    'dns': '11',
    'rpc_netlogon': '06',
    'kerberos': '06',      # ratio 16248:258
    'smb': '06',
    'smb2': '06',
    'ldap': '06',
    'cldap': '11',
    'lsarpc': '06',
    'samr': '06',
    'dcerpc': '06',
    'epm': '06',
    'drsuapi': '06',
    'browser': '11',
    'smb_netlogon': '11',
    'srvsvc': '06',
    'nbns': '11',
}
1275
# Human-readable description for each (protocol, opcode) pair. It fills the
# description field of synthesised packet lines; pairs not listed get an
# empty description. Opcodes are the strings used in packet summaries
# (hex-formatted for browser, smb and smb_netlogon).
OP_DESCRIPTIONS = {
    ('browser', '0x01'): 'Host Announcement (0x01)',
    ('browser', '0x02'): 'Request Announcement (0x02)',
    ('browser', '0x08'): 'Browser Election Request (0x08)',
    ('browser', '0x09'): 'Get Backup List Request (0x09)',
    ('browser', '0x0c'): 'Domain/Workgroup Announcement (0x0c)',
    ('browser', '0x0f'): 'Local Master Announcement (0x0f)',
    ('cldap', '3'): 'searchRequest',
    ('cldap', '5'): 'searchResDone',
    ('dcerpc', '0'): 'Request',
    ('dcerpc', '11'): 'Bind',
    ('dcerpc', '12'): 'Bind_ack',
    ('dcerpc', '13'): 'Bind_nak',
    ('dcerpc', '14'): 'Alter_context',
    ('dcerpc', '15'): 'Alter_context_resp',
    ('dcerpc', '16'): 'AUTH3',
    ('dcerpc', '2'): 'Response',
    ('dns', '0'): 'query',
    ('dns', '1'): 'response',
    ('drsuapi', '0'): 'DsBind',
    ('drsuapi', '12'): 'DsCrackNames',
    ('drsuapi', '13'): 'DsWriteAccountSpn',
    ('drsuapi', '1'): 'DsUnbind',
    ('drsuapi', '2'): 'DsReplicaSync',
    ('drsuapi', '3'): 'DsGetNCChanges',
    ('drsuapi', '4'): 'DsReplicaUpdateRefs',
    ('epm', '3'): 'Map',
    ('kerberos', ''): '',
    ('ldap', '0'): 'bindRequest',
    ('ldap', '1'): 'bindResponse',
    ('ldap', '2'): 'unbindRequest',
    ('ldap', '3'): 'searchRequest',
    ('ldap', '4'): 'searchResEntry',
    ('ldap', '5'): 'searchResDone',
    ('ldap', ''): '*** Unknown ***',
    ('lsarpc', '14'): 'lsa_LookupNames',
    ('lsarpc', '15'): 'lsa_LookupSids',
    ('lsarpc', '39'): 'lsa_QueryTrustedDomainInfoBySid',
    ('lsarpc', '40'): 'lsa_SetTrustedDomainInfo',
    ('lsarpc', '6'): 'lsa_OpenPolicy',
    ('lsarpc', '76'): 'lsa_LookupSids3',
    ('lsarpc', '77'): 'lsa_LookupNames4',
    ('nbns', '0'): 'query',
    ('nbns', '1'): 'response',
    ('rpc_netlogon', '21'): 'NetrLogonDummyRoutine1',
    ('rpc_netlogon', '26'): 'NetrServerAuthenticate3',
    ('rpc_netlogon', '29'): 'NetrLogonGetDomainInfo',
    ('rpc_netlogon', '30'): 'NetrServerPasswordSet2',
    ('rpc_netlogon', '39'): 'NetrLogonSamLogonEx',
    ('rpc_netlogon', '40'): 'DsrEnumerateDomainTrusts',
    ('rpc_netlogon', '45'): 'NetrLogonSamLogonWithFlags',
    ('rpc_netlogon', '4'): 'NetrServerReqChallenge',
    ('samr', '0',): 'Connect',
    ('samr', '16'): 'GetAliasMembership',
    ('samr', '17'): 'LookupNames',
    ('samr', '18'): 'LookupRids',
    ('samr', '19'): 'OpenGroup',
    ('samr', '1'): 'Close',
    ('samr', '25'): 'QueryGroupMember',
    ('samr', '34'): 'OpenUser',
    ('samr', '36'): 'QueryUserInfo',
    ('samr', '39'): 'GetGroupsForUser',
    ('samr', '3'): 'QuerySecurity',
    ('samr', '5'): 'LookupDomain',
    ('samr', '64'): 'Connect5',
    ('samr', '6'): 'EnumDomains',
    ('samr', '7'): 'OpenDomain',
    ('samr', '8'): 'QueryDomainInfo',
    ('smb', '0x04'): 'Close (0x04)',
    ('smb', '0x24'): 'Locking AndX (0x24)',
    ('smb', '0x2e'): 'Read AndX (0x2e)',
    ('smb', '0x32'): 'Trans2 (0x32)',
    ('smb', '0x71'): 'Tree Disconnect (0x71)',
    ('smb', '0x72'): 'Negotiate Protocol (0x72)',
    ('smb', '0x73'): 'Session Setup AndX (0x73)',
    ('smb', '0x74'): 'Logoff AndX (0x74)',
    ('smb', '0x75'): 'Tree Connect AndX (0x75)',
    ('smb', '0xa2'): 'NT Create AndX (0xa2)',
    ('smb2', '0'): 'NegotiateProtocol',
    ('smb2', '11'): 'Ioctl',
    ('smb2', '14'): 'Find',
    ('smb2', '16'): 'GetInfo',
    ('smb2', '18'): 'Break',
    ('smb2', '1'): 'SessionSetup',
    ('smb2', '2'): 'SessionLogoff',
    ('smb2', '3'): 'TreeConnect',
    ('smb2', '4'): 'TreeDisconnect',
    ('smb2', '5'): 'Create',
    ('smb2', '6'): 'Close',
    ('smb2', '8'): 'Read',
    ('smb_netlogon', '0x12'): 'SAM LOGON request from client (0x12)',
    ('smb_netlogon', '0x17'): ('SAM Active Directory Response - '
                               'user unknown (0x17)'),
    ('srvsvc', '16'): 'NetShareGetInfo',
    ('srvsvc', '21'): 'NetSrvGetInfo',
}
1372
1373
def expand_short_packet(p, timestamp, src, dest, extra):
    """Expand a 'protocol:opcode' packet type into a tab-separated line.

    The fields are: timestamp, IP protocol number, an empty field, source,
    destination, protocol, opcode and description, followed by *extra*.

    :param p: packet type as 'protocol:opcode'.
    :param timestamp: packet time; a float is written with '%f'.
    :param src: source address (any value; non-strings are formatted).
    :param dest: destination address (any value; non-strings are
        formatted).
    :param extra: sequence of extra fields.
    :return: the summary line, without a trailing newline.
    """
    protocol, opcode = p.split(':', 1)
    desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
    ip_protocol = IP_PROTOCOLS.get(protocol, '06')

    def as_field(value):
        # str.join() rejects non-strings, e.g. numeric timestamps or the
        # integer endpoint ids used when generating conversations.
        if isinstance(value, float):
            return '%f' % value
        return '%s' % (value,)

    line = [timestamp, ip_protocol, '', src, dest, protocol, opcode, desc]
    line.extend(extra)
    return '\t'.join(as_field(x) for x in line)
1382
1383
def _kill_children(children):
    """Signal and reap the replay children still listed in *children*.

    Sends SIGTERM, SIGTERM again, then SIGKILL, reaping for about a second
    after each round, and stops as soon as every child is reaped.
    *children* maps pid -> Conversation and is emptied as children exit.
    """
    import errno

    for s in (signal.SIGTERM, signal.SIGTERM, signal.SIGKILL):
        print(("killing %d children with -%d" %
               (len(children), s)), file=sys.stderr)
        for pid in children:
            try:
                os.kill(pid, s)
            except OSError as e:
                if e.errno != errno.ESRCH:  # don't fail if it already died
                    raise
        time.sleep(0.5)
        deadline = time.time() + 1
        while children:
            try:
                pid, status = os.waitpid(-1, os.WNOHANG)
            except OSError as e:
                if e.errno != errno.ECHILD:
                    raise
                # nothing left to wait for; don't reuse a stale pid
                break
            if pid != 0:
                c = children.pop(pid, None)
                print(("kill -%d %d KILLED conversation %s; "
                       "%d to go" %
                       (s, pid, c, len(children))),
                      file=sys.stderr)
            else:
                time.sleep(0.01)  # avoid busy-spinning while they exit
            if time.time() >= deadline:
                break

        if not children:
            break
        time.sleep(1)

    if children:
        print("%d children are missing" % len(children),
              file=sys.stderr)


def replay(conversations,
           host=None,
           creds=None,
           lp=None,
           accounts=None,
           dns_rate=0,
           duration=None,
           **kwargs):
    """Replay *conversations* against *host*, each in a forked child.

    Conversations are paired with *accounts* latest-start first; zip()
    stops at the shorter list, so when accounts run out the
    earliest-starting conversations are the ones dropped. Children are
    forked in batches covering about two seconds while finished children
    are reaped. When *duration* has elapsed, or on any error, remaining
    children are killed and the process group is sent SIGINT to catch
    stragglers.

    :param conversations: Conversations to replay.
    :param host: server to replay against.
    :param creds: credentials for the ReplayContext.
    :param lp: loadparm for the ReplayContext.
    :param accounts: ConversationAccounts, one per conversation.
    :param dns_rate: if non-zero, also run a DnsHammer at this many dns
        queries per second.
    :param duration: seconds to run; defaults to one second after the last
        packet of the last conversation to start.
    :param kwargs: passed on to ReplayContext.
    """
    import errno

    context = ReplayContext(server=host,
                            creds=creds,
                            lp=lp,
                            **kwargs)

    if len(accounts) < len(conversations):
        print(("we have %d accounts but %d conversations" %
               (len(accounts), len(conversations))), file=sys.stderr)

    cstack = list(zip(
        sorted(conversations, key=lambda x: x.start_time, reverse=True),
        accounts))

    # Set the process group so that the calling scripts are not killed
    # when the forked child processes are killed.
    os.setpgrp()

    start = time.time()

    if duration is None:
        # end 1 second after the last packet of the last conversation
        # to start. Packet timestamps are relative to their conversation's
        # start_time, so add that back on. Conversations other than the
        # last could still be going, but we don't care.
        last = cstack[0][0]
        duration = last.start_time + last.packets[-1].timestamp + 1.0
        print("We will stop after %.1f seconds" % duration,
              file=sys.stderr)

    end = start + duration

    print("Replaying traffic for %u conversations over %d seconds"
          % (len(conversations), duration))

    children = {}
    if dns_rate:
        dns_hammer = DnsHammer(dns_rate, duration)
        cstack.append((dns_hammer, None))

    try:
        while True:
            # we spawn a batch, wait for finishers, then spawn another
            now = time.time()
            batch_end = min(now + 2.0, end)
            fork_time = 0.0
            fork_n = 0
            while cstack:
                c, account = cstack.pop()
                if c.start_time + start > batch_end:
                    cstack.append((c, account))
                    break

                st = time.time()
                pid = c.replay_in_fork_with_delay(start, context, account)
                children[pid] = c
                t = time.time()
                elapsed = t - st
                fork_time += elapsed
                fork_n += 1
                print("forked %s in pid %s (in %fs)" % (c, pid,
                                                        elapsed),
                      file=sys.stderr)

            if fork_n:
                print(("forked %d times in %f seconds (avg %f)" %
                       (fork_n, fork_time, fork_time / fork_n)),
                      file=sys.stderr)
            elif cstack:
                debug(2, "no forks in batch ending %f" % batch_end)

            while time.time() < batch_end - 1.0:
                time.sleep(0.01)
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != errno.ECHILD:  # no child processes
                        raise
                    break
                if pid:
                    c = children.pop(pid, None)
                    print(("process %d finished conversation %s;"
                           " %d to go" %
                           (pid, c, len(children))), file=sys.stderr)

            if time.time() >= end:
                print("time to stop", file=sys.stderr)
                break

    except Exception:
        print("EXCEPTION in parent", file=sys.stderr)
        traceback.print_exc()
    finally:
        _kill_children(children)

        # there may be stragglers that were forked just as ^C was hit
        # and don't appear in the list of children. We can get them
        # with killpg, but that will also kill us, so this is^H^H would be
        # goodbye, except we cheat and pretend to use ^C (SIG_INTERRUPT),
        # so as not to have to fuss around writing signal handlers.
        try:
            os.killpg(0, signal.SIGINT)
        except KeyboardInterrupt:
            print("ignoring fake ^C", file=sys.stderr)
1526
1527
def openLdb(host, creds, lp):
    """Open a SamDB connection to *host* over LDAP using the system
    session and the given credentials and loadparm."""
    url = "ldap://%s" % host
    return SamDB(url=url,
                 session_info=system_session(),
                 credentials=creds,
                 lp=lp)
1535
1536
def ou_name(ldb, instance_id):
    """Return the DN of the OU for *instance_id*: a child of
    ou=traffic_replay under the domain DN."""
    replay_root = "ou=traffic_replay,%s" % ldb.domain_dn()
    return "ou=instance-%d,%s" % (instance_id, replay_root)
1541
1542
def create_ou(ldb, instance_id):
    """Create an ou, all created user and machine accounts will belong to it.

    Both the shared ou=traffic_replay parent and the per-instance OU are
    created; either already existing is not an error. This allows all the
    created resources to be cleaned up easily.

    :return: the DN of the per-instance OU.
    """
    ou = ou_name(ldb, instance_id)
    # the parent must exist before the per-instance OU can be added
    for dn in (ou.split(',', 1)[1], ou):
        try:
            ldb.add({"dn":          dn,
                     "objectclass": "organizationalunit"})
        except LdbError as e:
            # e.args works on Python 2 and 3; unpacking e itself does not.
            status = e.args[0]
            if status != 68:  # LDAP entryAlreadyExists: ignore
                raise
    return ou
1566
1567
class ConversationAccounts(object):
    """Machine and user account details for one replayed conversation.

    :param netbios_name: NetBIOS name of the machine account.
    :param machinepass: machine account password.
    :param username: user account name.
    :param userpass: user account password.
    """
    def __init__(self, netbios_name, machinepass, username, userpass):
        self.netbios_name, self.machinepass = netbios_name, machinepass
        self.username, self.userpass = username, userpass
1576
1577
def generate_replay_accounts(ldb, instance_id, number, password):
    """Ensure *number* account pairs exist and return their details.

    Pair i (1-based) is machine STGM-<instance_id>-<i> and user
    STGU-<instance_id>-<i>, both using *password*.
    """
    generate_traffic_accounts(ldb, instance_id, number, password)
    return [ConversationAccounts("STGM-%d-%d" % (instance_id, i),
                                 password,
                                 "STGU-%d-%d" % (instance_id, i),
                                 password)
            for i in range(1, number + 1)]
1591
1592
def _create_accounts_until_existing(create, ldb, instance_id, name_format,
                                    number, password):
    """Call create() for names name_format % (instance_id, i), i = number..1,
    stopping at the first that already exists; return the number created."""
    added = 0
    for i in range(number, 0, -1):
        name = name_format % (instance_id, i)
        try:
            create(ldb, instance_id, name, password)
        except LdbError as e:
            # e.args works on Python 2 and 3; unpacking e itself does not.
            if e.args[0] == 68:  # LDAP entryAlreadyExists
                break
            raise
        added += 1
    return added


def generate_traffic_accounts(ldb, instance_id, number, password):
    """Create the specified number of user and machine accounts.

    Accounts are not explicitly deleted between runs, so this function
    starts with the last account and works backwards, stopping either when
    it finds an account that already exists or when it has created all the
    required accounts. Machine accounts are done first, then user
    accounts.
    """
    print(("Generating machine and conversation accounts, "
           "as required for %d conversations" % number),
          file=sys.stderr)
    added = _create_accounts_until_existing(create_machine_account, ldb,
                                            instance_id, "STGM-%d-%d",
                                            number, password)
    if added > 0:
        print("Added %d new machine accounts" % added,
              file=sys.stderr)

    added = _create_accounts_until_existing(create_user_account, ldb,
                                            instance_id, "STGU-%d-%d",
                                            number, password)
    if added > 0:
        print("Added %d new user accounts" % added,
              file=sys.stderr)
1636
1637
def create_machine_account(ldb, instance_id, netbios_name, machinepass):
    """Create a machine account via ldap.

    Adds a ``computer`` object ``cn=<netbios_name>`` under the instance's OU,
    with sAMAccountName ``<netbios_name>$``, workstation-trust /
    password-not-required account control flags and password ``machinepass``.
    A tab-separated timing record for the add is printed to stdout.

    :param machinepass: the password, as text or as UTF-8 encoded bytes.
    :raises LdbError: if the add fails (callers treat code 68 as "already
        exists").
    """
    ou = ou_name(ldb, instance_id)
    dn = "cn=%s,%s" % (netbios_name, ou)
    # unicodePwd takes the password wrapped in double quotes, encoded as
    # UTF-16-LE. Normalise to text first: the builtin ``unicode`` used here
    # previously does not exist on Python 3, and calling .encode('utf-8') on
    # a non-ASCII Python 2 byte string raised UnicodeDecodeError.
    if isinstance(machinepass, bytes):
        machinepass = machinepass.decode('utf-8')
    utf16pw = (u'"%s"' % machinepass).encode('utf-16-le')
    start = time.time()
    ldb.add({
        "dn": dn,
        "objectclass": "computer",
        "sAMAccountName": "%s$" % netbios_name,
        "userAccountControl":
        str(UF_WORKSTATION_TRUST_ACCOUNT | UF_PASSWD_NOTREQD),
        "unicodePwd": utf16pw})
    end = time.time()
    duration = end - start
    print("%f\t0\tcreate\tmachine\t%f\tTrue\t" % (end, duration))
1657
1658
def create_user_account(ldb, instance_id, username, userpass):
    """Create a user account via ldap.

    Adds a ``user`` object ``cn=<username>`` under the instance's OU, with
    sAMAccountName ``username``, a normal-account control flag and password
    ``userpass``. A tab-separated timing record for the add is printed to
    stdout.

    :param userpass: the password, as text or as UTF-8 encoded bytes.
    :raises LdbError: if the add fails (callers treat code 68 as "already
        exists").
    """
    ou = ou_name(ldb, instance_id)
    user_dn = "cn=%s,%s" % (username, ou)
    # unicodePwd takes the password wrapped in double quotes, encoded as
    # UTF-16-LE. Normalise to text first so this works on Python 3 (no
    # builtin ``unicode``) and with non-ASCII Python 2 byte strings.
    if isinstance(userpass, bytes):
        userpass = userpass.decode('utf-8')
    utf16pw = (u'"%s"' % userpass).encode('utf-16-le')
    start = time.time()
    ldb.add({
        "dn": user_dn,
        "objectclass": "user",
        "sAMAccountName": username,
        "userAccountControl": str(UF_NORMAL_ACCOUNT),
        "unicodePwd": utf16pw
    })
    end = time.time()
    duration = end - start
    print("%f\t0\tcreate\tuser\t%f\tTrue\t" % (end, duration))
1677
1678
def create_group(ldb, instance_id, name):
    """Add a ``group`` object ``cn=<name>`` under the instance's OU via ldap,
    then print a tab-separated timing record for the add to stdout."""
    group_dn = "cn=%s,%s" % (name, ou_name(ldb, instance_id))
    started = time.time()
    ldb.add({"dn": group_dn, "objectclass": "group"})
    finished = time.time()
    print("%f\t0\tcreate\tgroup\t%f\tTrue\t" % (finished, finished - started))
1692
1693
def user_name(instance_id, i):
    """Return the account name of user number ``i`` for this instance,
    formatted as ``STGU-<instance_id>-<i>``."""
    return "-".join(["STGU", "%d" % instance_id, "%d" % i])
1697
1698
def generate_users(ldb, instance_id, number, password):
    """Add users to the server.

    Creates users ``number`` down to 1 (named by user_name) and stops at the
    first one that already exists.

    :return: the number of users actually created.
    :raises LdbError: for any failure other than code 68 (entry already
        exists).
    """
    users = 0
    for i in range(number, 0, -1):
        try:
            username = user_name(instance_id, i)
            create_user_account(ldb, instance_id, username, password)
            users += 1
        except LdbError as e:
            # Use e.args: exception objects cannot be unpacked directly on
            # Python 3.
            status = e.args[0]
            # Stop if entry exists
            if status == 68:
                break
            raise

    return users
1716
1717
def group_name(instance_id, i):
    """Return the name of group number ``i`` for this instance, formatted as
    ``STGG-<instance_id>-<i>``."""
    return "-".join(["STGG", "%d" % instance_id, "%d" % i])
1721
1722
def generate_groups(ldb, instance_id, number):
    """Create the required number of groups on the server.

    Creates groups ``number`` down to 1 (named by group_name) and stops at
    the first one that already exists.

    :return: the number of groups actually created.
    :raises LdbError: for any failure other than code 68 (entry already
        exists).
    """
    groups = 0
    for i in range(number, 0, -1):
        try:
            name = group_name(instance_id, i)
            create_group(ldb, instance_id, name)
            groups += 1
        except LdbError as e:
            # Use e.args: exception objects cannot be unpacked directly on
            # Python 3.
            status = e.args[0]
            # Stop if entry exists
            if status == 68:
                break
            raise
    return groups
1739
1740
def clean_up_accounts(ldb, instance_id):
    """Remove the created accounts and groups from the server.

    Deletes the instance's OU with the ``tree_delete`` control so that the
    objects beneath it go too. A missing OU is silently ignored.

    :raises LdbError: for any failure other than code 32 (no such object).
    """
    ou = ou_name(ldb, instance_id)
    try:
        ldb.delete(ou, ["tree_delete:1"])
    except LdbError as e:
        # Use e.args: exception objects cannot be unpacked directly on
        # Python 3.
        status = e.args[0]
        # ignore does not exist
        if status != 32:
            raise
1751
1752
def generate_users_and_groups(ldb, instance_id, password,
                              number_of_users, number_of_groups,
                              group_memberships):
    """Create the instance OU and its dummy users and groups, then populate
    the groups.

    Users are created by generate_users and, when ``number_of_groups`` is
    positive, groups by generate_groups. When ``group_memberships`` is
    positive, assign_groups chooses (user, group) pairs and
    add_users_to_groups applies them. Progress messages and a final summary
    go to stderr.
    """
    create_ou(ldb, instance_id)

    print("Generating dummy user accounts", file=sys.stderr)
    users_added = generate_users(ldb, instance_id, number_of_users, password)

    groups_added = 0
    if number_of_groups > 0:
        print("Generating dummy groups", file=sys.stderr)
        groups_added = generate_groups(ldb, instance_id, number_of_groups)

    assignments = []
    if group_memberships > 0:
        print("Assigning users to groups", file=sys.stderr)
        assignments = assign_groups(number_of_groups, groups_added,
                                    number_of_users, users_added,
                                    group_memberships)
        print("Adding users to groups", file=sys.stderr)
        add_users_to_groups(ldb, instance_id, assignments)

    # NOTE(review): this warns only when some (but not all) groups are new and
    # no users were added; confirm the number_of_groups != groups_added clause
    # is intended.
    warn_empty_groups = (groups_added > 0 and users_added == 0 and
                         number_of_groups != groups_added)
    if warn_empty_groups:
        print("Warning: the added groups will contain no members",
              file=sys.stderr)

    summary = ("Added %d users, %d groups and %d group memberships" %
               (users_added, groups_added, len(assignments)))
    print(summary, file=sys.stderr)
1788
1789
def assign_groups(number_of_groups,
                  groups_added,
                  number_of_users,
                  users_added,
                  group_memberships):
    """Allocate users to groups.

    The intention is to have a few users that belong to most groups, while
    the majority of users belong to a few groups.

    A few groups will contain most users, with the remaining only having a
    few users.

    Only pairs in which the user or the group (or both) is new are generated.
    New means one of the highest-numbered ``users_added`` users or
    ``groups_added`` groups. Pairs of pre-existing users and groups are
    skipped (presumably they were assigned on an earlier run).

    :param number_of_groups: total number of groups (existing + new).
    :param groups_added: how many of those groups are new.
    :param number_of_users: total number of users (existing + new).
    :param users_added: how many of those users are new.
    :param group_memberships: requested memberships for the full population;
        scaled by the fraction of users that are new and capped at the number
        of eligible pairs.
    :return: a set of 1-based ``(user_number, group_number)`` tuples.
    """

    def generate_user_distribution(n):
        """Probability distribution of a user belonging to a group."""
        return [1 / (x + 0.001) for x in range(1, n + 1)]

    def generate_group_distribution(n):
        """Probability distribution of a group containing a user."""
        return [1 / (x**1.3) for x in range(1, n + 1)]

    assignments = set()
    if (group_memberships <= 0 or number_of_users <= 0 or
            number_of_groups <= 0):
        # Nothing to assign. This also avoids a ZeroDivisionError below and
        # random.randint(0, -1) raising ValueError.
        return assignments

    group_dist = generate_group_distribution(number_of_groups)
    user_dist = generate_user_distribution(number_of_users)

    # Calculate the number of group memberships required
    group_memberships = math.ceil(
        float(group_memberships) *
        (float(users_added) / float(number_of_users)))

    # 0-based index of the last pre-existing user/group; any index above it
    # refers to one created on this run.
    existing_users = number_of_users - users_added - 1
    existing_groups = number_of_groups - groups_added - 1

    # Never ask for more distinct pairs than are eligible, otherwise the
    # loop below can never terminate.
    existing_pairs = (max(0, existing_users + 1) *
                      max(0, existing_groups + 1))
    eligible_pairs = number_of_users * number_of_groups - existing_pairs
    group_memberships = min(group_memberships, eligible_pairs)

    while len(assignments) < group_memberships:
        user = random.randint(0, number_of_users - 1)
        group = random.randint(0, number_of_groups - 1)
        probability = group_dist[group] * user_dist[user]

        if ((random.random() < probability * 10000) and
           (group > existing_groups or user > existing_users)):
            # the + 1 converts the array index to the corresponding
            # group or user number
            assignments.add(((user + 1), (group + 1)))

    return assignments
1847
1848
def add_users_to_groups(db, instance_id, assignments):
    """Add users to their assigned groups.

    Takes the collection of ``(user, group)`` tuples generated by
    assign_groups (1-based user and group numbers). Each user is added to its
    group by appending the user's DN to the group's ``member`` attribute. A
    tab-separated timing record is printed to stdout for each modification.

    :param db: connection to modify. It is named ``db`` so that the ``ldb``
        module remains available for building the messages.
    """

    ou = ou_name(db, instance_id)

    def build_dn(name):
        return "cn=%s,%s" % (name, ou)

    for (user, group) in assignments:
        user_dn = build_dn(user_name(instance_id, user))
        group_dn = build_dn(group_name(instance_id, group))

        m = ldb.Message()
        m.dn = ldb.Dn(db, group_dn)
        m["member"] = ldb.MessageElement(user_dn, ldb.FLAG_MOD_ADD, "member")
        start = time.time()
        db.modify(m)
        end = time.time()
        duration = end - start
        print("%f\t0\tadd\tuser\t%f\tTrue\t" % (end, duration))
1872
1873
def generate_stats(statsdir, timing_file):
    """Generate and print the summary stats for a run.

    Every file in ``statsdir`` is read line by line. Each line must hold
    tab-separated fields: end time, conversation, protocol, packet type,
    latency, success flag (``'True'`` means success) and error. Lines that
    cannot be parsed are echoed to stderr and skipped.

    :param statsdir: directory holding the stats files to summarise.
    :param timing_file: writable file object. It receives a header line and
        then every valid input line. Pass None to skip writing it.

    The summary goes to stdout. It is printed as aligned columns when stdout
    is a terminal and as tab-separated values otherwise.
    """
    first = sys.float_info.max
    last = 0
    successful = 0
    failed = 0
    latencies = {}
    failures = {}
    conversations_seen = set()

    if timing_file is not None:
        tw = timing_file.write
    else:
        def tw(x):
            pass

    tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n")

    for filename in os.listdir(statsdir):
        path = os.path.join(statsdir, filename)
        with open(path, 'r') as f:
            for line in f:
                # Parse every field before touching the accumulators, so a
                # truncated line cannot be half-counted (latency recorded but
                # neither success nor failure).
                try:
                    fields = line.rstrip('\n').split('\t')
                    end_time = float(fields[0])
                    conversation = fields[1]
                    protocol = fields[2]
                    packet_type = fields[3]
                    latency = float(fields[4])
                    succeeded = fields[5] == 'True'
                except (ValueError, IndexError):
                    # not a valid line: print and ignore
                    print(line.rstrip('\n'), file=sys.stderr)
                    continue

                first = min(end_time - latency, first)
                last = max(end_time, last)

                latencies.setdefault(protocol, {}).setdefault(
                    packet_type, []).append(latency)
                proto_failures = failures.setdefault(protocol, {})
                proto_failures.setdefault(packet_type, 0)

                if succeeded:
                    successful += 1
                else:
                    failed += 1
                    proto_failures[packet_type] += 1

                conversations_seen.add(conversation)

                # Keep the timing file at one record per line, even when an
                # input file's last line has no trailing newline.
                if not line.endswith('\n'):
                    line += '\n'
                tw(line)

    conversations = len(conversations_seen)
    duration = last - first
    if duration > 0:
        success_rate = successful / duration
        failure_rate = failed / duration
    else:
        # No valid lines (duration is then hugely negative) or a zero-length
        # run: report zero rates rather than dividing by zero.
        success_rate = 0
        failure_rate = 0

    # print the stats in more human-readable format when stdout is going to
    # the console (as opposed to being redirected to a file)
    human_readable = sys.stdout.isatty()
    if human_readable:
        print("Total conversations:   %10d" % conversations)
        print("Successful operations: %10d (%.3f per second)"
              % (successful, success_rate))
        print("Failed operations:     %10d (%.3f per second)"
              % (failed, failure_rate))
    else:
        print("(%d, %d, %d, %.3f, %.3f)" %
              (conversations, successful, failed, success_rate, failure_rate))

    if human_readable:
        print("Protocol    Op Code  Description                               "
              " Count       Failed         Mean       Median          "
              "95%        Range          Max")
    else:
        print("proto\top_code\tdesc\tcount\tfailed\tmean\tmedian\t95%\trange"
              "\tmax")

    for protocol in sorted(latencies):
        for packet_type in sorted(latencies[protocol], key=opcode_key):
            values = sorted(latencies[protocol][packet_type])
            count = len(values)
            op_failed = failures[protocol][packet_type]
            mean = sum(values) / count
            median = calc_percentile(values, 0.50)
            percentile = calc_percentile(values, 0.95)
            rng = values[-1] - values[0]
            maxv = values[-1]
            desc = OP_DESCRIPTIONS.get((protocol, packet_type), '')
            row = (protocol, packet_type, desc, count, op_failed, mean,
                   median, percentile, rng, maxv)
            # isatty() must be *called*: the bare bound method is always
            # truthy, which made the tab-separated branch unreachable.
            if human_readable:
                print("%-12s   %4s  %-35s %12d %12d %12.6f "
                      "%12.6f %12.6f %12.6f %12.6f" % row)
            else:
                print("%s\t%s\t%s\t%d\t%d\t%f\t%f\t%f\t%f\t%f" % row)
2001
2002
def opcode_key(v):
    """Sort key for the operation code to ensure that it sorts numerically.

    Integer-like codes are zero-padded to three digits, so that for example
    ``"9"`` sorts before ``"10"``. Values that cannot be converted with
    ``int()`` are returned unchanged.
    """
    try:
        return "%03d" % int(v)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures are expected here; a bare except would also
        # swallow KeyboardInterrupt/SystemExit.
        return v
2009
2010
def calc_percentile(values, percentile):
    """Return the given percentile (0..1) of ``values``.

    Linearly interpolates between the two nearest ranks. ``values`` must
    already be sorted ascending; an empty list yields 0.
    """
    if not values:
        return 0
    rank = (len(values) - 1) * percentile
    lower = int(math.floor(rank))
    upper = int(math.ceil(rank))
    if lower == upper:
        # rank falls exactly on an element
        return values[lower]
    return values[lower] * (upper - rank) + values[upper] * (rank - lower)
2027
2028
def mk_masked_dir(*path):
    """Create the directory ``os.path.join(*path)`` and return its path.

    In a testenv we end up with 0777 directories that look an alarming
    green colour with ls. Use umask to avoid that: the directory is created
    with group/other permissions masked off.

    :raises OSError: if the directory cannot be created (e.g. it exists).
    """
    d = os.path.join(*path)
    mask = os.umask(0o077)
    try:
        os.mkdir(d)
    finally:
        # Always restore the process umask, even if mkdir fails; otherwise
        # every file the process creates afterwards would be masked too.
        os.umask(mask)
    return d