traffic: grant user write permission
[samba.git] / python / samba / emulate / traffic.py
1 # -*- encoding: utf-8 -*-
2 # Samba traffic replay and learning
3 #
4 # Copyright (C) Catalyst IT Ltd. 2017
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 #
19 from __future__ import print_function, division
20
21 import time
22 import os
23 import random
24 import json
25 import math
26 import sys
27 import signal
28 import itertools
29
30 from collections import OrderedDict, Counter, defaultdict
31 from samba.emulate import traffic_packets
32 from samba.samdb import SamDB
33 import ldb
34 from ldb import LdbError
35 from samba.dcerpc import ClientConnection
36 from samba.dcerpc import security, drsuapi, lsa
37 from samba.dcerpc import netlogon
38 from samba.dcerpc.netlogon import netr_Authenticator
39 from samba.dcerpc import srvsvc
40 from samba.dcerpc import samr
41 from samba.drs_utils import drs_DsBind
42 import traceback
43 from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
44 from samba.auth import system_session
45 from samba.dsdb import UF_WORKSTATION_TRUST_ACCOUNT, UF_PASSWD_NOTREQD
46 from samba.dsdb import UF_NORMAL_ACCOUNT
47 from samba.dcerpc.misc import SEC_CHAN_WKSTA
48 from samba import gensec
49 from samba import sd_utils
50
SLEEP_OVERHEAD = 0.0003

# we don't use None, because it complicates [de]serialisation
NON_PACKET = '-'

# (protocol, opcode) pairs that suggest the sender is a client, with the
# weight Packet.client_score() reports for them.
CLIENT_CLUES = {
    # dns / nbns
    ('dns', '0'): 1.0,      # query
    ('nbns', '0'): 1.0,     # query
    # smb
    ('smb', '0x72'): 1.0,   # Negotiate protocol
    # ldap / cldap
    ('ldap', '0'): 1.0,     # bind
    ('ldap', '2'): 1.0,     # unbindRequest
    ('ldap', '3'): 1.0,     # searchRequest
    ('cldap', '3'): 1.0,
    # dcerpc
    ('dcerpc', '11'): 1.0,  # bind
    ('dcerpc', '14'): 1.0,  # Alter_context
}

# (protocol, opcode) pairs that suggest the sender is a server;
# Packet.client_score() negates these weights.
SERVER_CLUES = {
    # dns
    ('dns', '1'): 1.0,      # response
    # ldap / cldap
    ('ldap', '1'): 1.0,     # bind response
    ('ldap', '4'): 1.0,     # search result
    ('ldap', '5'): 1.0,     # search done
    ('cldap', '5'): 1.0,
    # dcerpc
    ('dcerpc', '12'): 1.0,  # bind_ack
    ('dcerpc', '13'): 1.0,  # bind_nak
    ('dcerpc', '15'): 1.0,  # Alter_context response
}

# Packets for these protocols are never treated as real packets
# (see Packet.is_really_a_packet).
SKIPPED_PROTOCOLS = set(["smb", "smb2", "browser", "smb_netlogon"])

WAIT_SCALE = 10.0
WAIT_THRESHOLD = 1.0 / WAIT_SCALE
NO_WAIT_LOG_TIME_RANGE = (-10, -3)

# DEBUG_LEVEL can be changed by scripts with -d
DEBUG_LEVEL = 0
87
88
def debug(level, msg, *args):
    """Print a formatted debug message to standard error.

    Nothing is printed when level is greater than the module-wide
    DEBUG_LEVEL (which scripts can change with the -d option).

    :param level: The debug level of this message.
    :param msg:   The message to be logged, can contain C-Style format
                  specifiers
    :param args:  The parameters required by the format specifiers; when
                  none are given msg is printed verbatim, without any
                  % interpolation.
    """
    if level > DEBUG_LEVEL:
        return

    text = msg % tuple(args) if args else msg
    print(text, file=sys.stderr)
105
106
def debug_lineno(*args):
    """Print an unformatted log message to stderr, containing the line number.

    The message is prefixed with the calling function's name and the
    (colour highlighted) line number of the call; each argument is then
    printed on its own line, followed by a blank line.
    """
    # With limit=2 the stack holds [caller, debug_lineno]; entry 0 is the
    # caller. Index 2 is the function name, index 1 the line number.
    caller = traceback.extract_stack(limit=2)[0]
    header = " %s:\033[01;33m%s \033[00m" % (caller[2], caller[1])
    print(header, end=' ', file=sys.stderr)

    for item in args:
        print(item, file=sys.stderr)

    print(file=sys.stderr)
    sys.stderr.flush()
118
119
def random_colour_print():
    """Return a function that prints a randomly coloured line to stderr

    The colour (a 256-colour palette index between 18 and 231) is picked
    once, when this function is called; the returned function prints each
    of its arguments on a separate line in that colour.
    """
    colour = random.randrange(214) + 18
    colour_on = "\033[38;5;%dm" % colour

    def colour_print(*args):
        for item in args:
            line = "%s%s\033[00m" % (colour_on, item)
            print(line, file=sys.stderr)

    return colour_print
130
131
class FakePacketError(Exception):
    """Raised when a packet is added to a Conversation whose endpoints do
    not match the packet's endpoints."""
134
135
class Packet(object):
    """Details of a network packet"""

    def __init__(self, fields):
        """Build a packet from a traffic summary line or a list of fields.

        :param fields: either a tab separated summary line (a str), or a
                       sequence holding, in order: timestamp, ip_protocol,
                       stream_number, src, dest, protocol, opcode, desc,
                       followed by any number of extra fields.
        """
        if isinstance(fields, str):
            fields = fields.rstrip('\n').split('\t')

        (timestamp,
         ip_protocol,
         stream_number,
         src,
         dest,
         protocol,
         opcode,
         desc) = fields[:8]

        self.timestamp = float(timestamp)
        self.ip_protocol = ip_protocol
        try:
            self.stream_number = int(stream_number)
        except (ValueError, TypeError):
            # An empty field (or None) means there is no stream number.
            self.stream_number = None
        self.src = int(src)
        self.dest = int(dest)
        self.protocol = protocol
        self.opcode = opcode
        self.desc = desc
        # Always a list, so copy() can concatenate it whatever sequence
        # type the caller supplied.
        self.extra = list(fields[8:])

        # endpoints is direction independent: lowest address first.
        if self.src < self.dest:
            self.endpoints = (self.src, self.dest)
        else:
            self.endpoints = (self.dest, self.src)

    def as_summary(self, time_offset=0.0):
        """Format the packet as a traffic_summary line.

        :param time_offset: added to the packet timestamp.
        :return: a (timestamp, line) tuple, where line is tab separated
                 and can be parsed back by Packet().
        """
        extra = '\t'.join(self.extra)
        t = self.timestamp + time_offset
        # Only a missing stream number is written as an empty field;
        # stream 0 is a real stream and must survive the round trip.
        if self.stream_number is None:
            stream_number = ''
        else:
            stream_number = self.stream_number
        return (t, '%f\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' %
                (t,
                 self.ip_protocol,
                 stream_number,
                 self.src,
                 self.dest,
                 self.protocol,
                 self.opcode,
                 self.desc,
                 extra))

    def __str__(self):
        return ("%.3f: %d -> %d; ip %s; strm %s; prot %s; op %s; desc %s %s" %
                (self.timestamp, self.src, self.dest, self.ip_protocol or '-',
                 self.stream_number, self.protocol, self.opcode, self.desc,
                 ('«' + ' '.join(self.extra) + '»' if self.extra else '')))

    def __repr__(self):
        return "<Packet @%s>" % self

    def copy(self):
        """Return a new, independent packet with the same field values."""
        return self.__class__([self.timestamp,
                               self.ip_protocol,
                               self.stream_number,
                               self.src,
                               self.dest,
                               self.protocol,
                               self.opcode,
                               self.desc] + self.extra)

    def as_packet_type(self):
        """Return the packet type as a 'protocol:opcode' string."""
        return '%s:%s' % (self.protocol, self.opcode)

    def client_score(self):
        """A positive number means we think it is a client; a negative number
        means we think it is a server. Zero means no idea. range: -1 to 1.
        """
        key = (self.protocol, self.opcode)
        if key in CLIENT_CLUES:
            return CLIENT_CLUES[key]
        if key in SERVER_CLUES:
            return -SERVER_CLUES[key]
        return 0.0

    def play(self, conversation, context):
        """Send the packet over the network, if required.

        Some packets are ignored, i.e. for  protocols not handled,
        server response messages, or messages that are generated by the
        protocol layer associated with other packets.

        The handler is looked up in traffic_packets by the name
        packet_<protocol>_<opcode>. When the handler returns a true value
        or raises, a tab separated timing record (end time, conversation
        id, protocol, opcode, duration, success flag, error text) is
        printed to stdout. Exceptions raised by the handler are reported
        in that record and not propagated.
        """
        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
        try:
            fn = getattr(traffic_packets, fn_name)
        except AttributeError:
            print("Conversation(%s) Missing handler %s" %
                  (conversation.conversation_id, fn_name),
                  file=sys.stderr)
            return

        # Don't display a message for kerberos packets, they're not directly
        # generated they're used to indicate kerberos should be used
        if self.protocol != "kerberos":
            debug(2, "Conversation(%s) Calling handler %s" %
                     (conversation.conversation_id, fn_name))

        start = time.time()
        try:
            if fn(self, conversation, context):
                # Only collect timing data for functions that generate
                # network traffic, or fail
                end = time.time()
                duration = end - start
                print("%f\t%s\t%s\t%s\t%f\tTrue\t" %
                      (end, conversation.conversation_id, self.protocol,
                       self.opcode, duration))
        except Exception as e:
            end = time.time()
            duration = end - start
            print("%f\t%s\t%s\t%s\t%f\tFalse\t%s" %
                  (end, conversation.conversation_id, self.protocol,
                   self.opcode, duration, e))

    def __cmp__(self, other):
        """Order packets by timestamp (Python 2 comparison protocol).

        Returns -1, 0 or 1; the protocol requires an int, so the raw
        (float) timestamp difference cannot be returned directly.
        """
        return ((self.timestamp > other.timestamp) -
                (self.timestamp < other.timestamp))

    def __lt__(self, other):
        """Order packets by timestamp; this is what sorting uses on
        Python 3, where __cmp__ is ignored."""
        return self.timestamp < other.timestamp

    def is_really_a_packet(self, missing_packet_stats=None):
        """Is this a packet that matters for the replay?

        Returns False for packets that can be ignored (removing them has
        no effect on the replay): skipped protocols, ldap continuation
        packets, packets whose handler is traffic_packets.null_packet and
        packets with no handler at all. Returns True otherwise.

        :param missing_packet_stats: currently unused.
        """
        if self.protocol in SKIPPED_PROTOCOLS:
            # Ignore any packets for the protocols we're not interested in.
            return False
        if self.protocol == "ldap" and self.opcode == '':
            # skip ldap continuation packets
            return False

        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
        try:
            fn = getattr(traffic_packets, fn_name)
        except AttributeError:
            print("missing packet %s" % fn_name, file=sys.stderr)
            return False
        return fn is not traffic_packets.null_packet
284
285
class ReplayContext(object):
    """State/Context for an individual conversation between a simulated client
       and a server.

    Holds the cached connections, the credentials and the bad-password
    bookkeeping for one conversation.
    """

    def __init__(self,
                 server=None,
                 lp=None,
                 creds=None,
                 badpassword_frequency=None,
                 prefer_kerberos=None,
                 tempdir=None,
                 statsdir=None,
                 ou=None,
                 base_dn=None,
                 domain=None,
                 domain_sid=None):
        """Store the settings and build the LDAP search tables.

        Note that construction connects to the server over LDAP (see
        generate_ldap_search_tables) and reads the realm from lp, so lp
        must be supplied even though it defaults to None.
        """
        self.server = server
        self.ldap_connections = []
        self.dcerpc_connections = []
        self.lsarpc_connections = []
        self.lsarpc_connections_named = []
        self.drsuapi_connections = []
        self.srvsvc_connections = []
        self.samr_contexts = []
        self.netlogon_connection = None
        self.creds = creds
        self.lp = lp
        self.prefer_kerberos = prefer_kerberos
        self.ou = ou
        self.base_dn = base_dn
        self.domain = domain
        self.statsdir = statsdir
        self.global_tempdir = tempdir
        self.domain_sid = domain_sid
        self.realm = lp.get('realm')

        # Bad password attempt controls
        self.badpassword_frequency = badpassword_frequency
        self.last_lsarpc_bad = False
        self.last_lsarpc_named_bad = False
        self.last_simple_bind_bad = False
        self.last_bind_bad = False
        self.last_srvsvc_bad = False
        self.last_drsuapi_bad = False
        self.last_netlogon_bad = False
        self.last_samlogon_bad = False
        self.generate_ldap_search_tables()

        # next_conversation_id() returns 0, 1, 2, ... The counter is
        # wrapped with the next() builtin because the .next method of
        # iterators only exists on Python 2.
        id_counter = itertools.count()
        self.next_conversation_id = lambda: next(id_counter)

    def generate_ldap_search_tables(self):
        """Query the server for every DN and index the DNs by their shape.

        Sets self.dn_map, mapping a pattern made of the upper-cased first
        two characters of each DN component (e.g. CN,CN,CN,DC,DC) to the
        list of matching DNs, and self.attribute_clue_map, mapping
        'invocationId' to the DNs that start with 'CN=NTDS Settings,'.
        """
        session = system_session()

        db = SamDB(url="ldap://%s" % self.server,
                   session_info=session,
                   credentials=self.creds,
                   lp=self.lp)

        res = db.search(db.domain_dn(),
                        scope=ldb.SCOPE_SUBTREE,
                        controls=["paged_results:1:1000"],
                        attrs=['dn'])

        # find a list of dns for each pattern
        # e.g. CN,CN,CN,DC,DC
        dn_map = {}
        attribute_clue_map = {
            'invocationId': []
        }

        for r in res:
            dn = str(r.dn)
            pattern = ','.join(x.lstrip()[:2] for x in dn.split(',')).upper()
            dns = dn_map.setdefault(pattern, [])
            dns.append(dn)
            if dn.startswith('CN=NTDS Settings,'):
                attribute_clue_map['invocationId'].append(dn)

        # extend the map in case we are working with a different
        # number of DC components.
        #
        # Iterate over a snapshot of the keys: the loop adds entries to
        # dn_map, which is an error while iterating the dict itself on
        # Python 3.
        for k in list(dn_map):
            if k[-3:] != ',DC':
                continue
            p = k[:-3]
            while p[-3:] == ',DC':
                p = p[:-3]
            for i in range(5):
                p += ',DC'
                if p != k and p in dn_map:
                    print('dn_map collison %s %s' % (k, p),
                          file=sys.stderr)
                    continue
                dn_map[p] = dn_map[k]

        self.dn_map = dn_map
        self.attribute_clue_map = attribute_clue_map

    def generate_process_local_config(self, account, conversation):
        """Set up the per-conversation names, temp directory and credentials.

        Does nothing when account is None. Otherwise copies the account
        details, creates a conversation specific temporary directory and
        points the loadparm directories at it, then generates the machine
        and user credentials.

        :raises KeyError: if no domain was supplied and the DOMAIN
                          environment variable is not set.
        """
        if account is None:
            return
        self.netbios_name = account.netbios_name
        self.machinepass = account.machinepass
        self.username = account.username
        self.userpass = account.userpass

        self.tempdir = mk_masked_dir(self.global_tempdir,
                                     'conversation-%d' %
                                     conversation.conversation_id)

        self.lp.set("private dir", self.tempdir)
        self.lp.set("lock dir", self.tempdir)
        self.lp.set("state directory", self.tempdir)
        self.lp.set("tls verify peer", "no_check")

        # If the domain was not specified, check for the environment
        # variable.
        if self.domain is None:
            self.domain = os.environ["DOMAIN"]

        self.remoteAddress = "/root/ncalrpc_as_system"
        self.samlogon_dn = ("cn=%s,%s" %
                            (self.netbios_name, self.ou))
        self.user_dn = ("cn=%s,%s" %
                        (self.username, self.ou))

        self.generate_machine_creds()
        self.generate_user_creds()

    def with_random_bad_credentials(self, f, good, bad, failed_last_time):
        """Execute the supplied logon function, randomly choosing the
           bad credentials.

           Based on the frequency in badpassword_frequency randomly perform the
           function with the supplied bad credentials.
           If run with bad credentials, the function is re-run with the good
           credentials.
           failed_last_time is used to prevent consecutive bad credential
           attempts. So the over all bad credential frequency will be lower
           than that requested, but not significantly.

           :param f: function taking a credentials object.
           :param good: credentials expected to work.
           :param bad: credentials expected to fail.
           :param failed_last_time: whether the previous call for this kind
                                    of connection tried the bad credentials.
           :return: (result of f(good), whether this call tried the bad
                    credentials); the second item is what callers pass
                    back as failed_last_time on the next call.
        """
        # An unset (None) frequency means "never use bad credentials".
        frequency = self.badpassword_frequency or 0

        # The returned flag describes THIS call only, so a bad attempt
        # suppresses just the next one instead of all later ones.
        attempt_bad = False
        if (not failed_last_time and
                frequency > 0 and
                random.random() < frequency):
            attempt_bad = True
            try:
                f(bad)
            except Exception:
                # Ignore any exceptions as the operation may fail
                # as it's being performed with bad credentials
                pass

        result = f(good)
        return (result, attempt_bad)

    def _kerberos_state(self):
        """Return the credentials kerberos state for prefer_kerberos."""
        if self.prefer_kerberos:
            return MUST_USE_KERBEROS
        return DONT_USE_KERBEROS

    def generate_user_creds(self):
        """Generate the conversation specific user Credentials.

        Each Conversation has an associated user account used to simulate
        any non Administrative user traffic.

        Generates user credentials with good and bad passwords and ldap
        simple bind credentials with good and bad passwords. The bad
        password is the good one with its last four characters removed.
        """
        kerberos_state = self._kerberos_state()
        bad_password = self.userpass[:-4]

        self.user_creds = Credentials()
        self.user_creds.guess(self.lp)
        self.user_creds.set_workstation(self.netbios_name)
        self.user_creds.set_password(self.userpass)
        self.user_creds.set_username(self.username)
        self.user_creds.set_domain(self.domain)
        self.user_creds.set_kerberos_state(kerberos_state)

        self.user_creds_bad = Credentials()
        self.user_creds_bad.guess(self.lp)
        self.user_creds_bad.set_workstation(self.netbios_name)
        self.user_creds_bad.set_password(bad_password)
        self.user_creds_bad.set_username(self.username)
        self.user_creds_bad.set_kerberos_state(kerberos_state)

        # Credentials for ldap simple bind.
        self.simple_bind_creds = Credentials()
        self.simple_bind_creds.guess(self.lp)
        self.simple_bind_creds.set_workstation(self.netbios_name)
        self.simple_bind_creds.set_password(self.userpass)
        self.simple_bind_creds.set_username(self.username)
        self.simple_bind_creds.set_gensec_features(
            self.simple_bind_creds.get_gensec_features() | gensec.FEATURE_SEAL)
        self.simple_bind_creds.set_kerberos_state(kerberos_state)
        self.simple_bind_creds.set_bind_dn(self.user_dn)

        self.simple_bind_creds_bad = Credentials()
        self.simple_bind_creds_bad.guess(self.lp)
        self.simple_bind_creds_bad.set_workstation(self.netbios_name)
        self.simple_bind_creds_bad.set_password(bad_password)
        self.simple_bind_creds_bad.set_username(self.username)
        self.simple_bind_creds_bad.set_gensec_features(
            self.simple_bind_creds_bad.get_gensec_features() |
            gensec.FEATURE_SEAL)
        self.simple_bind_creds_bad.set_kerberos_state(kerberos_state)
        self.simple_bind_creds_bad.set_bind_dn(self.user_dn)

    def generate_machine_creds(self):
        """Generate the conversation specific machine Credentials.

        Each Conversation has an associated machine account.

        Generates machine credentials with good and bad passwords. The bad
        password is the good one with its last four characters removed.
        """
        kerberos_state = self._kerberos_state()

        self.machine_creds = Credentials()
        self.machine_creds.guess(self.lp)
        self.machine_creds.set_workstation(self.netbios_name)
        self.machine_creds.set_secure_channel_type(SEC_CHAN_WKSTA)
        self.machine_creds.set_password(self.machinepass)
        self.machine_creds.set_username(self.netbios_name + "$")
        self.machine_creds.set_domain(self.domain)
        self.machine_creds.set_kerberos_state(kerberos_state)

        self.machine_creds_bad = Credentials()
        self.machine_creds_bad.guess(self.lp)
        self.machine_creds_bad.set_workstation(self.netbios_name)
        self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_WKSTA)
        self.machine_creds_bad.set_password(self.machinepass[:-4])
        self.machine_creds_bad.set_username(self.netbios_name + "$")
        self.machine_creds_bad.set_kerberos_state(kerberos_state)

    def get_matching_dn(self, pattern, attributes=None):
        """Return a random DN from the server that matches pattern.

        :param pattern: DN shape such as 'CN,CN,DC,DC' (case insensitive).
        :param attributes: optional key into attribute_clue_map; when it
                           has candidate DNs one of those is returned
                           regardless of pattern.
        :return: a DN string, or base_dn when nothing matches.
        """
        # If the pattern is an empty string, we assume ROOTDSE,
        # Otherwise we try adding or removing DC suffixes, then
        # shorter leading patterns until we hit one.
        # e.g if there is no CN,CN,CN,CN,DC,DC
        # we first try       CN,CN,CN,CN,DC
        # and                CN,CN,CN,CN,DC,DC,DC
        # then change to        CN,CN,CN,DC,DC
        # and as last resort we use the base_dn
        attr_clue = self.attribute_clue_map.get(attributes)
        if attr_clue:
            return random.choice(attr_clue)

        pattern = pattern.upper()
        while pattern:
            if pattern in self.dn_map:
                return random.choice(self.dn_map[pattern])
            # chop one off the front and try it all again.
            pattern = pattern[3:]

        return self.base_dn

    def get_dcerpc_connection(self, new=False):
        """Return the latest dcerpc connection, creating one if there is
        none or if new is true."""
        guid = '12345678-1234-abcd-ef00-01234567cffb'  # RPC_NETLOGON UUID
        if self.dcerpc_connections and not new:
            return self.dcerpc_connections[-1]
        c = ClientConnection("ncacn_ip_tcp:%s" % self.server,
                             (guid, 1), self.lp)
        self.dcerpc_connections.append(c)
        return c

    def get_srvsvc_connection(self, new=False):
        """Return the latest srvsvc connection, creating one with the user
        credentials if there is none or if new is true."""
        if self.srvsvc_connections and not new:
            return self.srvsvc_connections[-1]

        def connect(creds):
            return srvsvc.srvsvc("ncacn_np:%s" % (self.server),
                                 self.lp,
                                 creds)

        (c, self.last_srvsvc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_srvsvc_bad)

        self.srvsvc_connections.append(c)
        return c

    def get_lsarpc_connection(self, new=False):
        """Return the latest lsarpc connection over TCP, creating one with
        the machine credentials if there is none or if new is true."""
        if self.lsarpc_connections and not new:
            return self.lsarpc_connections[-1]

        def connect(creds):
            binding_options = 'schannel,seal,sign'
            return lsa.lsarpc("ncacn_ip_tcp:%s[%s]" %
                              (self.server, binding_options),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_bad)

        self.lsarpc_connections.append(c)
        return c

    def get_lsarpc_named_pipe_connection(self, new=False):
        """Return the latest lsarpc connection over a named pipe, creating
        one with the machine credentials if there is none or if new is
        true."""
        if self.lsarpc_connections_named and not new:
            return self.lsarpc_connections_named[-1]

        def connect(creds):
            return lsa.lsarpc("ncacn_np:%s" % (self.server),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_named_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_named_bad)

        self.lsarpc_connections_named.append(c)
        return c

    def get_drsuapi_connection_pair(self, new=False, unbind=False):
        """get a (drs, drs_handle) tuple

        The latest pair is reused unless there is none or new is true.
        The unbind parameter is currently unused.
        """
        if self.drsuapi_connections and not new:
            return self.drsuapi_connections[-1]

        def connect(creds):
            binding_options = 'seal'
            binding_string = "ncacn_ip_tcp:%s[%s]" %\
                             (self.server, binding_options)
            return drsuapi.drsuapi(binding_string, self.lp, creds)

        (drs, self.last_drsuapi_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_drsuapi_bad)

        # drs_DsBind also returns the supported extensions; not needed here.
        drs_handle = drs_DsBind(drs)[0]
        c = (drs, drs_handle)
        self.drsuapi_connections.append(c)
        return c

    def get_ldap_connection(self, new=False, simple=False):
        """Return the latest LDAP connection, creating one if there is none
        or if new is true.

        :param simple: when creating a connection, use an ldaps:// simple
                       bind instead of an ldap:// SASL bind.
        """
        if self.ldap_connections and not new:
            return self.ldap_connections[-1]

        def simple_bind(creds):
            """
            To run simple bind against Windows, we need to run
            following commands in PowerShell:

                Install-windowsfeature ADCS-Cert-Authority
                Install-AdcsCertificationAuthority -CAType EnterpriseRootCA
                Restart-Computer

            """
            return SamDB('ldaps://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)

        def sasl_bind(creds):
            return SamDB('ldap://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)
        if simple:
            (samdb, self.last_simple_bind_bad) = \
                self.with_random_bad_credentials(simple_bind,
                                                 self.simple_bind_creds,
                                                 self.simple_bind_creds_bad,
                                                 self.last_simple_bind_bad)
        else:
            (samdb, self.last_bind_bad) = \
                self.with_random_bad_credentials(sasl_bind,
                                                 self.user_creds,
                                                 self.user_creds_bad,
                                                 self.last_bind_bad)

        self.ldap_connections.append(samdb)
        return samdb

    def get_samr_context(self, new=False):
        """Return the latest SamrContext, creating one if there is none or
        if new is true."""
        if not self.samr_contexts or new:
            self.samr_contexts.append(
                SamrContext(self.server, lp=self.lp, creds=self.creds))
        return self.samr_contexts[-1]

    def get_netlogon_connection(self):
        """Return the netlogon connection, creating it with the machine
        credentials on first use."""
        if self.netlogon_connection:
            return self.netlogon_connection

        def connect(creds):
            return netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" %
                                     (self.server),
                                     self.lp,
                                     creds)
        (c, self.last_netlogon_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_netlogon_bad)
        self.netlogon_connection = c
        return c

    def guess_a_dns_lookup(self):
        """Return a (name, record type) pair to look up: the realm's A
        record."""
        return (self.realm, 'A')

    def get_authenticator(self):
        """Return a (current, subsequent) pair of netr_Authenticator
        objects; current is filled in from a new client authenticator of
        the machine credentials, subsequent is left empty."""
        auth = self.machine_creds.new_client_authenticator()
        current = netr_Authenticator()
        # Iterating a byte string yields 1-character strings on Python 2
        # but ints on Python 3; accept both.
        current.cred.data = [x if isinstance(x, int) else ord(x)
                             for x in auth["credential"]]
        current.timestamp = auth["timestamp"]

        subsequent = netr_Authenticator()
        return (current, subsequent)
718
719
class SamrContext(object):
    """State/Context associated with a samr connection.

    The connection and the connect handle are created lazily, on the
    first call to get_connection() / get_handle(), and then reused.
    """

    def __init__(self, server, lp=None, creds=None):
        # Where and how to connect.
        self.server = server
        self.lp = lp
        self.creds = creds

        # Nothing is opened until it is asked for.
        self.connection = None
        self.handle = None

        # Remaining per-connection state; not assigned anywhere in this
        # class after construction.
        self.domain_handle = None
        self.domain_sid = None
        self.group_handle = None
        self.user_handle = None
        self.rids = None

    def get_connection(self):
        """Return the samr connection, opening it on first use."""
        if self.connection:
            return self.connection

        binding = "ncacn_ip_tcp:%s[seal]" % (self.server)
        self.connection = samr.samr(binding,
                                    lp_ctx=self.lp,
                                    credentials=self.creds)
        return self.connection

    def get_handle(self):
        """Return the samr connect handle, obtaining it on first use."""
        if self.handle:
            return self.handle

        connection = self.get_connection()
        self.handle = connection.Connect2(None,
                                          security.SEC_FLAG_MAXIMUM_ALLOWED)
        return self.handle
749
750
class Conversation(object):
    """Details of a conversation between a simulated client and a server.

    Packets are kept in ``packets`` in the order they were added.
    ``client_balance`` is a running score that guess_client_server() uses
    to decide which endpoint is the client.
    """
    conversation_id = None

    def __init__(self, start_time=None, endpoints=None):
        self.start_time = start_time
        self.endpoints = endpoints
        self.packets = []
        self.msg = random_colour_print()
        self.client_balance = 0.0

    def __cmp__(self, other):
        """Order conversations by start time; an unset start time sorts
        first.

        :return: -1, 0 or 1
        """
        if self.start_time is None:
            if other.start_time is None:
                return 0
            return -1
        if other.start_time is None:
            return 1
        # Return the sign only. A comparison result has to be an integer,
        # and a raw float difference of less than a second could be
        # truncated to 0 ("equal") when coerced to one.
        return ((self.start_time > other.start_time) -
                (self.start_time < other.start_time))

    def __lt__(self, other):
        """Rich comparison, so sorting works where __cmp__ is not used."""
        return self.__cmp__(other) < 0

    def _record_packet(self, packet):
        """Adjust the client balance for packet and keep it if it is a
        real packet."""
        if packet.src == packet.endpoints[0]:
            self.client_balance -= packet.client_score()
        else:
            self.client_balance += packet.client_score()

        if packet.is_really_a_packet():
            self.packets.append(packet)

    def add_packet(self, packet):
        """Add a packet object to this conversation, making a local copy with
        a conversation-relative timestamp.

        The first packet added fixes start_time and endpoints if they are
        still unset.

        :raises FakePacketError: if the packet's endpoints differ from the
                                 conversation's
        """
        p = packet.copy()

        if self.start_time is None:
            self.start_time = p.timestamp

        if self.endpoints is None:
            self.endpoints = p.endpoints

        if p.endpoints != self.endpoints:
            raise FakePacketError("Conversation endpoints %s don't match "
                                  "packet endpoints %s" %
                                  (self.endpoints, p.endpoints))

        p.timestamp -= self.start_time
        self._record_packet(p)

    def add_short_packet(self, timestamp, p, extra, client=True):
        """Create a packet from a timestamp, and 'protocol:opcode' pair, and a
        (possibly empty) list of extra data. If client is True, assume
        this packet is from the client to the server.
        """
        protocol, opcode = p.split(':', 1)
        src, dest = self.guess_client_server()
        if not client:
            src, dest = dest, src

        desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
        ip_protocol = IP_PROTOCOLS.get(protocol, '06')
        fields = [timestamp - self.start_time, ip_protocol,
                  '', src, dest,
                  protocol, opcode, desc]
        fields.extend(extra)
        packet = Packet(fields)
        # XXX we're assuming the timestamp is already adjusted for
        # this conversation?
        # XXX should we adjust client balance for guessed packets?
        self._record_packet(packet)

    def __str__(self):
        # start_time can legitimately be None (nothing added yet, or every
        # packet pruned), which '%.3f' cannot format.
        if self.start_time is None:
            start = "None"
        else:
            start = "%.3f" % self.start_time
        return ("<Conversation %s %s starting %s %d packets>" %
                (self.conversation_id, self.endpoints, start,
                 len(self.packets)))

    __repr__ = __str__

    def __iter__(self):
        return iter(self.packets)

    def __len__(self):
        return len(self.packets)

    def get_duration(self):
        """Return the time between the first and last stored packets, or 0
        when there are fewer than two."""
        if len(self.packets) < 2:
            return 0
        return self.packets[-1].timestamp - self.packets[0].timestamp

    def replay_as_summary_lines(self):
        """Return one summary line per packet, as produced by
        p.as_summary(start_time)."""
        return [p.as_summary(self.start_time) for p in self.packets]

    def replay_in_fork_with_delay(self, start, context=None, account=None):
        """Fork a new process and replay the conversation.

        The parent returns the child's pid straight away. The child
        sleeps until the conversation is due, replays it, and exits
        without ever returning to the caller.

        :param start: the time.time() value start times are relative to
        :param context: replay context; required despite the default,
                        because it is used unconditionally
        :param account: handed to context.generate_process_local_config()
        :return: (in the parent) the pid of the forked child
        """
        def signal_handler(signum, frame):
            """Signal handler closes standard out and error.

            Triggered by a sigterm, ensures that the log messages are flushed
            to disk and not lost.
            """
            sys.stderr.close()
            sys.stdout.close()
            os._exit(0)

        t = self.start_time
        now = time.time() - start
        gap = t - now

        self.conversation_id = context.next_conversation_id()
        pid = os.fork()
        if pid != 0:
            return pid
        pid = os.getpid()
        signal.signal(signal.SIGTERM, signal_handler)
        # we must never return, or we'll end up running parts of the
        # parent's clean-up code. So we work in a try...finally, and
        # try to print any exceptions.

        try:
            context.generate_process_local_config(account, self)
            sys.stdin.close()
            os.close(0)
            filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
                                    self.conversation_id)
            # from here on the child's stdout is its own statistics file
            sys.stdout.close()
            sys.stdout = open(filename, 'w')

            sleep_time = gap - SLEEP_OVERHEAD
            if sleep_time > 0:
                time.sleep(sleep_time)

            miss = t - (time.time() - start)
            self.msg("starting %s [miss %.3f pid %d]" % (self, miss, pid))
            self.replay(context)
        except Exception:
            print(("EXCEPTION in child PID %d, conversation %s" % (pid, self)),
                  file=sys.stderr)
            traceback.print_exc(sys.stderr)
        finally:
            sys.stderr.close()
            sys.stdout.close()
            os._exit(0)

    def replay(self, context=None):
        """Play each packet at its timestamp, measured from now.

        With no context the packets are only reported through self.msg
        and nothing is played.
        """
        start = time.time()

        for p in self.packets:
            now = time.time() - start
            gap = p.timestamp - now
            sleep_time = gap - SLEEP_OVERHEAD
            if sleep_time > 0:
                time.sleep(sleep_time)

            miss = p.timestamp - (time.time() - start)
            if context is None:
                self.msg("packet %s [miss %.3f pid %d]" % (p, miss,
                                                           os.getpid()))
                continue
            p.play(self, context)

    def guess_client_server(self, server_clue=None):
        """Have a go at deciding who is the server and who is the client.
        returns (client, server)
        """
        a, b = self.endpoints

        if self.client_balance < 0:
            return (a, b)

        # in the absence of a clue, we will fall through to assuming
        # the lowest number is the server (which is usually true).

        if self.client_balance == 0 and server_clue == b:
            return (a, b)

        return (b, a)

    def forget_packets_outside_window(self, s, e):
        """Prune any packets outside the time window we're interested in

        start_time becomes the timestamp of the first surviving packet,
        or None when no packet survives.

        :param s: start of the window
        :param e: end of the window
        """
        self.packets = [p for p in self.packets
                        if not (p.timestamp < s or p.timestamp > e)]
        if self.packets:
            self.start_time = self.packets[0].timestamp
        else:
            self.start_time = None

    def renormalise_times(self, start_time):
        """Adjust the packet start times relative to the new start time."""
        for p in self.packets:
            p.timestamp -= start_time

        if self.start_time is not None:
            self.start_time -= start_time
971
972
class DnsHammer(Conversation):
    """A stand-in conversation that fires dns:0 queries at random moments.

    No packets are stored: the query times are drawn up front and the
    queries are generated on the fly during replay.
    """

    def __init__(self, dns_rate, duration):
        count = int(dns_rate * duration)
        # ascending list of offsets, each drawn uniformly from the duration
        self.times = sorted(random.uniform(0, duration)
                            for _ in range(count))
        self.rate = dns_rate
        self.duration = duration
        self.start_time = 0
        self.msg = random_colour_print()

    def __str__(self):
        summary = (len(self.times), self.duration, self.rate)
        return "<DnsHammer %d packets over %.1fs (rate %.2f)>" % summary

    def replay_in_fork_with_delay(self, start, context=None, account=None):
        """Delegate unchanged to Conversation.replay_in_fork_with_delay."""
        return super(DnsHammer, self).replay_in_fork_with_delay(
            start, context, account)

    def replay(self, context=None):
        """Send one dns:0 query at each scheduled time.

        One tab-separated result line is printed per query. With no
        context the schedule is only reported through self.msg.
        """
        began = time.time()
        send_query = traffic_packets.packet_dns_0
        for when in self.times:
            pause = when - (time.time() - began) - SLEEP_OVERHEAD
            if pause > 0:
                time.sleep(pause)

            if context is None:
                miss = when - (time.time() - began)
                self.msg("packet %s [miss %.3f pid %d]" % (when, miss,
                                                           os.getpid()))
                continue

            query_began = time.time()
            try:
                send_query(self, self, context)
                finished = time.time()
                print("%f\tDNS\tdns\t0\t%f\tTrue\t" %
                      (finished, finished - query_began))
            except Exception as e:
                finished = time.time()
                print("%f\tDNS\tdns\t0\t%f\tFalse\t%s" %
                      (finished, finished - query_began, e))
1022
1023
def ingest_summaries(files, dns_mode='count'):
    """Load traffic summary files and generate Conversations from them.

    :param files: iterable of file names (str) and/or open file objects.
                  Every file is closed once read, including the ones
                  that were passed in already open.
    :param dns_mode: unless this is 'include', dns packets are only
                     counted per opcode and not put in any conversation.
    :return: a tuple (conversations, mean_interval, duration, dns_counts).
             With no usable packets this is ([], 0, 0, dns_counts), so
             the shape is always the same and the DNS counts survive.
    """

    dns_counts = defaultdict(int)
    packets = []
    for f in files:
        if isinstance(f, str):
            f = open(f)
        try:
            print("Ingesting %s" % (f.name,), file=sys.stderr)
            for line in f:
                p = Packet(line)
                if p.protocol == 'dns' and dns_mode != 'include':
                    dns_counts[p.opcode] += 1
                else:
                    packets.append(p)
        finally:
            # don't leak the file if a line fails to parse
            f.close()

    if not packets:
        return [], 0, 0, dns_counts

    start_time = min(p.timestamp for p in packets)
    last_packet = max(p.timestamp for p in packets)

    print("gathering packets into conversations", file=sys.stderr)
    conversations = OrderedDict()
    for p in packets:
        p.timestamp -= start_time
        c = conversations.get(p.endpoints)
        if c is None:
            c = Conversation()
            conversations[p.endpoints] = c
        c.add_packet(p)

    # We only care about conversations with actual traffic, so we
    # filter out conversations with nothing to say. We do that here,
    # rather than earlier, because those empty packets contain useful
    # hints as to which end of the conversation was the client.
    conversation_list = [c for c in conversations.values() if len(c) != 0]

    # This is obviously not correct, as many conversations will appear
    # to start roughly simultaneously at the beginning of the snapshot.
    # To which we say: oh well, so be it.
    duration = float(last_packet - start_time)
    if duration:
        mean_interval = len(conversations) / duration
    else:
        # every packet carries the same timestamp, so no rate can be
        # derived; report 0.0 rather than dividing by zero
        mean_interval = 0.0

    return conversation_list, mean_interval, duration, dns_counts
1075
1076
def guess_server_address(conversations):
    """Guess the server address: the most common conversation endpoint.

    :param conversations: iterable of objects with an ``endpoints`` pair
    :return: the address seen in the most conversations, or None when
             there are no endpoints to count
    """
    # we guess the most common address.
    addresses = Counter()
    for c in conversations:
        addresses.update(c.endpoints)
    if not addresses:
        return None
    # most_common() yields (address, count) pairs. The result is compared
    # with an endpoint by guess_client_server(), so return the address
    # alone rather than the pair.
    address, _count = addresses.most_common(1)[0]
    return address
1084
1085
def stringify_keys(x):
    """Return a copy of x with each key (a sequence of strings) joined
    into a single tab-separated string."""
    return {'\t'.join(key): value for key, value in x.items()}
1092
1093
def unstringify_keys(x):
    """Return a copy of x with each key split on tabs into a tuple of
    strings (the reverse of stringify_keys)."""
    return {tuple(str(key).split('\t')): value for key, value in x.items()}
1100
1101
class TrafficModel(object):
    """An n-gram model of client traffic, learnt from conversations.

    ``ngrams`` maps a tuple of the previous n - 1 packet types to the list
    of packet types seen to follow it. ``query_details`` maps a packet
    type to the list of 'extra' tuples seen with it.
    """
    def __init__(self, n=3):
        self.ngrams = {}
        self.query_details = {}
        self.n = n
        self.dns_opcounts = defaultdict(int)
        self.cumulative_duration = 0.0
        self.conversation_rate = [0, 1]

    def learn(self, conversations, dns_opcounts=None):
        """Add the client packets of conversations to the model.

        :param conversations: sequence of Conversation objects
        :param dns_opcounts: optional mapping of dns opcode to count,
                             added to the model's own counts
        """
        if dns_opcounts is None:
            dns_opcounts = {}

        # NOTE(review): prev is not reset per conversation, so the first
        # wait of each conversation is measured from the last packet of
        # the previous one -- confirm this is intended.
        prev = 0.0
        cum_duration = 0.0
        key = (NON_PACKET,) * (self.n - 1)

        server = guess_server_address(conversations)

        for k, v in dns_opcounts.items():
            self.dns_opcounts[k] += v

        if len(conversations) > 1:
            elapsed =\
                conversations[-1].start_time - conversations[0].start_time
            self.conversation_rate[0] = len(conversations)
            self.conversation_rate[1] = elapsed

        for c in conversations:
            # NOTE(review): this rebinds server, so every conversation
            # after the first is given the previous conversation's server
            # as its clue -- confirm this is intended.
            client, server = c.guess_client_server(server)
            cum_duration += c.get_duration()
            key = (NON_PACKET,) * (self.n - 1)
            for p in c:
                if p.src != client:
                    continue

                elapsed = p.timestamp - prev
                prev = p.timestamp
                if elapsed > WAIT_THRESHOLD:
                    # add the wait as an extra state
                    wait = 'wait:%d' % (math.log(max(1.0,
                                                     elapsed * WAIT_SCALE)))
                    self.ngrams.setdefault(key, []).append(wait)
                    key = key[1:] + (wait,)

                short_p = p.as_packet_type()
                self.query_details.setdefault(short_p,
                                              []).append(tuple(p.extra))
                self.ngrams.setdefault(key, []).append(short_p)
                key = key[1:] + (short_p,)

        self.cumulative_duration += cum_duration
        # add in the end
        # NOTE(review): this runs once, after the loop, so only the final
        # key of the last conversation gets an end marker -- confirm.
        self.ngrams.setdefault(key, []).append(NON_PACKET)

    def save(self, f):
        """Write the model as JSON.

        :param f: a file name (str) or a writable file object. A file
                  opened here is closed again; a file object passed in is
                  left open for the caller.
        """
        ngrams = {}
        for k, v in self.ngrams.items():
            k = '\t'.join(k)
            ngrams[k] = dict(Counter(v))

        query_details = {}
        for k, v in self.query_details.items():
            query_details[k] = dict(Counter('\t'.join(x) if x else '-'
                                            for x in v))

        d = {
            'ngrams': ngrams,
            'query_details': query_details,
            'cumulative_duration': self.cumulative_duration,
            'conversation_rate': self.conversation_rate,
        }
        d['dns'] = self.dns_opcounts

        if isinstance(f, str):
            with open(f, 'w') as fh:
                json.dump(d, fh, indent=2)
        else:
            json.dump(d, f, indent=2)

    def load(self, f):
        """Merge a model saved by save() into this one.

        n-grams, query details and dns counts are added to what is already
        held; cumulative_duration and conversation_rate are replaced.

        :param f: a file name (str) or a readable file object. A file
                  opened here is closed again.
        """
        if isinstance(f, str):
            with open(f) as fh:
                d = json.load(fh)
        else:
            d = json.load(f)

        for k, v in d['ngrams'].items():
            k = tuple(str(k).split('\t'))
            values = self.ngrams.setdefault(k, [])
            for p, count in v.items():
                values.extend([str(p)] * count)

        for k, v in d['query_details'].items():
            values = self.query_details.setdefault(str(k), [])
            for p, count in v.items():
                if p == '-':
                    values.extend([()] * count)
                else:
                    values.extend([tuple(str(p).split('\t'))] * count)

        if 'dns' in d:
            for k, v in d['dns'].items():
                self.dns_opcounts[k] += v

        self.cumulative_duration = d['cumulative_duration']
        self.conversation_rate = d['conversation_rate']

    def construct_conversation(self, timestamp=0.0, client=2, server=1,
                               hard_stop=None, packet_rate=1):
        """Construct an individual conversation from the model.

        Walks the n-gram chain from the all-NON_PACKET state, picking each
        next state at random, until the end marker or an unknown state is
        reached or the timestamp passes hard_stop.

        :param timestamp: start time of the conversation
        :param client: client endpoint
        :param server: server endpoint
        :param hard_stop: if not None, stop once the timestamp exceeds it
        :param packet_rate: divisor applied to every generated wait
        :return: the new Conversation
        """

        c = Conversation(timestamp, (server, client))

        key = (NON_PACKET,) * (self.n - 1)

        while key in self.ngrams:
            p = random.choice(self.ngrams[key])
            if p == NON_PACKET:
                break
            if p in self.query_details:
                extra = random.choice(self.query_details[p])
            else:
                extra = []

            protocol, opcode = p.split(':', 1)
            if protocol == 'wait':
                # a wait state only advances the clock
                log_wait_time = int(opcode) + random.random()
                wait = math.exp(log_wait_time) / (WAIT_SCALE * packet_rate)
                timestamp += wait
            else:
                log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
                wait = math.exp(log_wait) / packet_rate
                timestamp += wait
                if hard_stop is not None and timestamp > hard_stop:
                    break
                c.add_short_packet(timestamp, p, extra)

            key = key[1:] + (p,)

        return c

    def generate_conversations(self, rate, duration, packet_rate=1):
        """Generate a list of conversations from the model.

        :param rate: multiplier applied to the learnt conversation rate
        :param duration: the desired duration
        :param packet_rate: passed on to construct_conversation()
        :return: the non-empty conversations, sorted by start time
        """

        # We simulate a window at least twice as long as the desired
        # duration (and no shorter than the learnt time span).
        rate_n, rate_t = self.conversation_rate

        duration2 = max(rate_t, duration * 2)
        n = rate * duration2 * rate_n / rate_t

        server = 1
        client = 2

        conversations = []
        end = duration2

        while client < n + 2:
            start = random.uniform(0, duration2)
            c = self.construct_conversation(start,
                                            client,
                                            server,
                                            hard_stop=(duration2 * 5),
                                            packet_rate=packet_rate)

            c.forget_packets_outside_window(start, end)
            c.renormalise_times(start)
            if len(c) != 0:
                conversations.append(c)
            client += 1

        print(("we have %d conversations at rate %f" %
               (len(conversations), rate)), file=sys.stderr)
        # Sort on the start time itself instead of relying on the
        # conversations' own comparison methods. Every conversation kept
        # above has at least one packet, so its start_time is a number.
        conversations.sort(key=lambda c: c.start_time)
        return conversations
1274
1275
# The ip_protocol field written for each protocol name; protocols not listed
# here are given '06' by the lookups in this module.
# NOTE(review): the values look like hex IP protocol numbers ('06' TCP,
# '11' UDP) -- confirm.
IP_PROTOCOLS = dict(
    dns='11',
    rpc_netlogon='06',
    kerberos='06',      # ratio 16248:258
    smb='06',
    smb2='06',
    ldap='06',
    cldap='11',
    lsarpc='06',
    samr='06',
    dcerpc='06',
    epm='06',
    drsuapi='06',
    browser='11',
    smb_netlogon='11',
    srvsvc='06',
    nbns='11',
)
1294
# Description text for each known (protocol, opcode) pair. The table is
# written grouped by protocol and flattened into the (protocol, opcode)
# keyed dict that the lookups in this module use.
OP_DESCRIPTIONS = {
    (protocol, opcode): description
    for protocol, operations in (
        ('browser', (
            ('0x01', 'Host Announcement (0x01)'),
            ('0x02', 'Request Announcement (0x02)'),
            ('0x08', 'Browser Election Request (0x08)'),
            ('0x09', 'Get Backup List Request (0x09)'),
            ('0x0c', 'Domain/Workgroup Announcement (0x0c)'),
            ('0x0f', 'Local Master Announcement (0x0f)'),
        )),
        ('cldap', (
            ('3', 'searchRequest'),
            ('5', 'searchResDone'),
        )),
        ('dcerpc', (
            ('0', 'Request'),
            ('11', 'Bind'),
            ('12', 'Bind_ack'),
            ('13', 'Bind_nak'),
            ('14', 'Alter_context'),
            ('15', 'Alter_context_resp'),
            ('16', 'AUTH3'),
            ('2', 'Response'),
        )),
        ('dns', (
            ('0', 'query'),
            ('1', 'response'),
        )),
        ('drsuapi', (
            ('0', 'DsBind'),
            ('12', 'DsCrackNames'),
            ('13', 'DsWriteAccountSpn'),
            ('1', 'DsUnbind'),
            ('2', 'DsReplicaSync'),
            ('3', 'DsGetNCChanges'),
            ('4', 'DsReplicaUpdateRefs'),
        )),
        ('epm', (
            ('3', 'Map'),
        )),
        ('kerberos', (
            ('', ''),
        )),
        ('ldap', (
            ('0', 'bindRequest'),
            ('1', 'bindResponse'),
            ('2', 'unbindRequest'),
            ('3', 'searchRequest'),
            ('4', 'searchResEntry'),
            ('5', 'searchResDone'),
            ('', '*** Unknown ***'),
        )),
        ('lsarpc', (
            ('14', 'lsa_LookupNames'),
            ('15', 'lsa_LookupSids'),
            ('39', 'lsa_QueryTrustedDomainInfoBySid'),
            ('40', 'lsa_SetTrustedDomainInfo'),
            ('6', 'lsa_OpenPolicy'),
            ('76', 'lsa_LookupSids3'),
            ('77', 'lsa_LookupNames4'),
        )),
        ('nbns', (
            ('0', 'query'),
            ('1', 'response'),
        )),
        ('rpc_netlogon', (
            ('21', 'NetrLogonDummyRoutine1'),
            ('26', 'NetrServerAuthenticate3'),
            ('29', 'NetrLogonGetDomainInfo'),
            ('30', 'NetrServerPasswordSet2'),
            ('39', 'NetrLogonSamLogonEx'),
            ('40', 'DsrEnumerateDomainTrusts'),
            ('45', 'NetrLogonSamLogonWithFlags'),
            ('4', 'NetrServerReqChallenge'),
        )),
        ('samr', (
            ('0', 'Connect'),
            ('16', 'GetAliasMembership'),
            ('17', 'LookupNames'),
            ('18', 'LookupRids'),
            ('19', 'OpenGroup'),
            ('1', 'Close'),
            ('25', 'QueryGroupMember'),
            ('34', 'OpenUser'),
            ('36', 'QueryUserInfo'),
            ('39', 'GetGroupsForUser'),
            ('3', 'QuerySecurity'),
            ('5', 'LookupDomain'),
            ('64', 'Connect5'),
            ('6', 'EnumDomains'),
            ('7', 'OpenDomain'),
            ('8', 'QueryDomainInfo'),
        )),
        ('smb', (
            ('0x04', 'Close (0x04)'),
            ('0x24', 'Locking AndX (0x24)'),
            ('0x2e', 'Read AndX (0x2e)'),
            ('0x32', 'Trans2 (0x32)'),
            ('0x71', 'Tree Disconnect (0x71)'),
            ('0x72', 'Negotiate Protocol (0x72)'),
            ('0x73', 'Session Setup AndX (0x73)'),
            ('0x74', 'Logoff AndX (0x74)'),
            ('0x75', 'Tree Connect AndX (0x75)'),
            ('0xa2', 'NT Create AndX (0xa2)'),
        )),
        ('smb2', (
            ('0', 'NegotiateProtocol'),
            ('11', 'Ioctl'),
            ('14', 'Find'),
            ('16', 'GetInfo'),
            ('18', 'Break'),
            ('1', 'SessionSetup'),
            ('2', 'SessionLogoff'),
            ('3', 'TreeConnect'),
            ('4', 'TreeDisconnect'),
            ('5', 'Create'),
            ('6', 'Close'),
            ('8', 'Read'),
        )),
        ('smb_netlogon', (
            ('0x12', 'SAM LOGON request from client (0x12)'),
            ('0x17', ('SAM Active Directory Response - '
                      'user unknown (0x17)')),
        )),
        ('srvsvc', (
            ('16', 'NetShareGetInfo'),
            ('21', 'NetSrvGetInfo'),
        )),
    )
    for opcode, description in operations
}
1391
1392
def expand_short_packet(p, timestamp, src, dest, extra):
    """Expand a 'protocol:opcode' short packet into a summary line.

    :param p: the packet type, as 'protocol:opcode'
    :param timestamp: value for the timestamp field
    :param src: value for the source field
    :param dest: value for the destination field
    :param extra: iterable of further fields appended to the line
    :return: the fields joined with tabs
    """
    protocol, opcode = p.split(':', 1)
    desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
    ip_protocol = IP_PROTOCOLS.get(protocol, '06')

    line = [timestamp, ip_protocol, '', src, dest, protocol, opcode, desc]
    line.extend(extra)
    # str.join() only accepts strings, while timestamps and endpoints are
    # handled as numbers elsewhere in this module, so format every field
    # first. Fields that are already strings pass through unchanged.
    return '\t'.join('%s' % (field,) for field in line)
1401
1402
def replay(conversations,
           host=None,
           creds=None,
           lp=None,
           accounts=None,
           dns_rate=0,
           duration=None,
           **kwargs):
    """Replay conversations, each one in its own forked child process.

    Conversations are forked in start-time order, in batches covering the
    next two seconds. When the duration is over (or on any exception) the
    remaining children are killed.

    :param conversations: the Conversation objects to replay
    :param host: passed to ReplayContext as server
    :param creds: passed to ReplayContext
    :param lp: passed to ReplayContext
    :param accounts: sequence of accounts, paired one-to-one with the
                     conversations; conversations beyond the number of
                     accounts are not replayed
    :param dns_rate: if non-zero, a DnsHammer at this rate is added
    :param duration: seconds to run for; by default one second more than
                     the last packet timestamp of the last conversation
                     to start
    :param kwargs: further keyword arguments for ReplayContext
    """
    # local import: the module's own import block does not provide errno
    import errno

    context = ReplayContext(server=host,
                            creds=creds,
                            lp=lp,
                            **kwargs)

    if len(accounts) < len(conversations):
        # the counts are what is being reported; formatting the lists
        # themselves with %d raises TypeError
        print(("we have %d accounts but %d conversations" %
               (len(accounts), len(conversations))), file=sys.stderr)

    # a stack with the earliest conversation on top (at the end)
    cstack = list(zip(
        sorted(conversations, key=lambda x: x.start_time, reverse=True),
        accounts))

    # Set the process group so that the calling scripts are not killed
    # when the forked child processes are killed.
    os.setpgrp()

    start = time.time()

    if duration is None:
        # end 1 second after the last packet of the last conversation
        # to start. Conversations other than the last could still be
        # going, but we don't care.
        duration = cstack[0][0].packets[-1].timestamp + 1.0
        print("We will stop after %.1f seconds" % duration,
              file=sys.stderr)

    end = start + duration

    print("Replaying traffic for %u conversations over %d seconds"
          % (len(conversations), duration))

    children = {}
    if dns_rate:
        dns_hammer = DnsHammer(dns_rate, duration)
        cstack.append((dns_hammer, None))

    try:
        while True:
            # we spawn a batch, wait for finishers, then spawn another
            now = time.time()
            batch_end = min(now + 2.0, end)
            fork_time = 0.0
            fork_n = 0
            while cstack:
                c, account = cstack.pop()
                if c.start_time + start > batch_end:
                    # not due in this batch; put it back for the next one
                    cstack.append((c, account))
                    break

                st = time.time()
                pid = c.replay_in_fork_with_delay(start, context, account)
                children[pid] = c
                t = time.time()
                elapsed = t - st
                fork_time += elapsed
                fork_n += 1
                print("forked %s in pid %s (in %fs)" % (c, pid,
                                                        elapsed),
                      file=sys.stderr)

            if fork_n:
                print(("forked %d times in %f seconds (avg %f)" %
                       (fork_n, fork_time, fork_time / fork_n)),
                      file=sys.stderr)
            elif cstack:
                debug(2, "no forks in batch ending %f" % batch_end)

            while time.time() < batch_end - 1.0:
                time.sleep(0.01)
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != errno.ECHILD:  # no child processes
                        raise
                    break
                if pid:
                    c = children.pop(pid, None)
                    print(("process %d finished conversation %s;"
                           " %d to go" %
                           (pid, c, len(children))), file=sys.stderr)

            if time.time() >= end:
                print("time to stop", file=sys.stderr)
                break

    except Exception:
        print("EXCEPTION in parent", file=sys.stderr)
        traceback.print_exc()
    finally:
        # ask nicely twice, then insist
        for s in (signal.SIGTERM, signal.SIGTERM, signal.SIGKILL):
            print(("killing %d children with -%d" %
                   (len(children), s)), file=sys.stderr)
            for pid in children:
                try:
                    os.kill(pid, s)
                except OSError as e:
                    # don't fail if it has already died
                    if e.errno != errno.ESRCH:
                        raise
            time.sleep(0.5)
            deadline = time.time() + 1
            while children:
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != errno.ECHILD:
                        raise
                    # Nothing is left to wait for. Stop here rather than
                    # carrying on with a stale (or unset) pid.
                    break
                if pid != 0:
                    c = children.pop(pid, None)
                    print(("kill -%d %d KILLED conversation %s; "
                           "%d to go" %
                           (s, pid, c, len(children))),
                          file=sys.stderr)
                if time.time() >= deadline:
                    break

            if not children:
                break
            time.sleep(1)

        if children:
            print("%d children are missing" % len(children),
                  file=sys.stderr)

        # there may be stragglers that were forked just as ^C was hit
        # and don't appear in the list of children. We can get them
        # with killpg, but that will also kill us, so this is^H^H would be
        # goodbye, except we cheat and pretend to use ^C (SIG_INTERRUPT),
        # so as not to have to fuss around writing signal handlers.
        try:
            os.killpg(0, signal.SIGINT)
        except KeyboardInterrupt:
            print("ignoring fake ^C", file=sys.stderr)
1545
1546
def openLdb(host, creds, lp):
    """Return a SamDB opened on ldap://<host> with a system session and
    the given credentials and loadparm context."""
    session = system_session()
    url = "ldap://%s" % host
    return SamDB(url=url, session_info=session, credentials=creds, lp=lp)
1554
1555
def ou_name(ldb, instance_id):
    """Return the DN of the ou for this instance id: it sits inside
    ou=traffic_replay under the domain DN reported by ldb."""
    domain = ldb.domain_dn()
    return "ou=instance-%d,ou=traffic_replay,%s" % (instance_id, domain)
1560
1561
def create_ou(ldb, instance_id):
    """Create an ou, all created user and machine accounts will belong to it.

    This allows all the created resources to be cleaned up easily.

    The parent ou (everything after the first component of the instance
    ou's DN) is created first, then the instance ou itself. Either may
    already exist.

    :param ldb: database handle providing add() and domain_dn()
    :param instance_id: integer id of this instance
    :return: the DN of the instance ou
    :raises LdbError: for any failure other than "already exists"
    """
    ou = ou_name(ldb, instance_id)
    for dn in (ou.split(',', 1)[1], ou):
        try:
            ldb.add({"dn":          dn,
                     "objectclass": "organizationalunit"})
        except LdbError as e:
            # Read the status from e.args: unpacking the exception object
            # itself only works on Python 2.
            status = e.args[0]
            # ignore already exists
            if status != 68:
                raise
    return ou
1585
1586
class ConversationAccounts(object):
    """The machine account and user account names and passwords that go
    with one conversation.
    """
    def __init__(self, netbios_name, machinepass, username, userpass):
        # machine account
        self.netbios_name, self.machinepass = netbios_name, machinepass
        # user account
        self.username, self.userpass = username, userpass
1595
1596
def generate_replay_accounts(ldb, instance_id, number, password):
    """Create the accounts and return one ConversationAccounts per index.

    Index i (from 1 to number) gets machine name STGM-<instance>-<i> and
    user name STGU-<instance>-<i>, both with the same password.
    """
    generate_traffic_accounts(ldb, instance_id, number, password)
    return [
        ConversationAccounts("STGM-%d-%d" % (instance_id, index),
                             password,
                             "STGU-%d-%d" % (instance_id, index),
                             password)
        for index in range(1, number + 1)
    ]
1610
1611
def _add_missing_accounts(create, ldb, instance_id, number, name_format,
                          password):
    """Create accounts number..1 with create(), stopping at the first one
    that already exists; return how many were created."""
    added = 0
    for i in range(number, 0, -1):
        name = name_format % (instance_id, i)
        try:
            create(ldb, instance_id, name, password)
        except LdbError as e:
            # Read the status from e.args: unpacking the exception object
            # itself only works on Python 2.
            if e.args[0] == 68:
                # already exists, and so do all the lower-numbered ones
                break
            raise
        added += 1
    return added


def generate_traffic_accounts(ldb, instance_id, number, password):
    """Create the specified number of user and machine accounts.

    As accounts are not explicitly deleted between runs. This function starts
    with the last account and iterates backwards stopping either when it
    finds an already existing account or it has generated all the required
    accounts.

    :param ldb: database handle passed to the account creation functions
    :param instance_id: integer id used in the account names
    :param number: how many accounts of each kind are required
    :param password: password given to every account
    :raises LdbError: for any failure other than "already exists"
    """
    print(("Generating machine and conversation accounts, "
           "as required for %d conversations" % number),
          file=sys.stderr)

    added = _add_missing_accounts(create_machine_account, ldb, instance_id,
                                  number, "STGM-%d-%d", password)
    if added > 0:
        print("Added %d new machine accounts" % added,
              file=sys.stderr)

    added = _add_missing_accounts(create_user_account, ldb, instance_id,
                                  number, "STGU-%d-%d", password)
    if added > 0:
        print("Added %d new user accounts" % added,
              file=sys.stderr)
1655
1656
def create_machine_account(ldb, instance_id, netbios_name, machinepass):
    """Create a machine account via ldap.

    The computer object is added as cn=<netbios_name> below the OU
    returned by ou_name(), with sAMAccountName "<netbios_name>$".
    `machinepass` may be text or UTF-8 encoded bytes.

    A tab-separated timing record for the add is printed to stdout.
    """

    ou = ou_name(ldb, instance_id)
    dn = "cn=%s,%s" % (netbios_name, ou)
    # The unicodePwd value is the double-quoted password encoded as
    # UTF-16-LE. Normalise to text first so this works on both
    # Python 2 and Python 3 (which has no unicode() builtin).
    if isinstance(machinepass, bytes):
        machinepass = machinepass.decode('utf-8')
    utf16pw = (u'"' + machinepass + u'"').encode('utf-16-le')
    start = time.time()
    ldb.add({
        "dn": dn,
        "objectclass": "computer",
        "sAMAccountName": "%s$" % netbios_name,
        "userAccountControl":
        str(UF_WORKSTATION_TRUST_ACCOUNT | UF_PASSWD_NOTREQD),
        "unicodePwd": utf16pw})
    end = time.time()
    duration = end - start
    print("%f\t0\tcreate\tmachine\t%f\tTrue\t" % (end, duration))
1676
1677
def create_user_account(ldb, instance_id, username, userpass):
    """Create a user account via ldap.

    The user object is added as cn=<username> below the OU returned by
    ou_name(), and an ACE is then added to its DACL. `userpass` may be
    text or UTF-8 encoded bytes.

    A tab-separated timing record covering both steps is printed to
    stdout.
    """
    ou = ou_name(ldb, instance_id)
    user_dn = "cn=%s,%s" % (username, ou)
    # The unicodePwd value is the double-quoted password encoded as
    # UTF-16-LE. Normalise to text first so this works on both
    # Python 2 and Python 3 (which has no unicode() builtin).
    if isinstance(userpass, bytes):
        userpass = userpass.decode('utf-8')
    utf16pw = (u'"' + userpass + u'"').encode('utf-16-le')
    start = time.time()
    ldb.add({
        "dn": user_dn,
        "objectclass": "user",
        "sAMAccountName": username,
        "userAccountControl": str(UF_NORMAL_ACCOUNT),
        "unicodePwd": utf16pw
    })

    # grant user write permission to do things like write account SPN
    sdutils = sd_utils.SDUtils(ldb)
    sdutils.dacl_add_ace(user_dn,  "(A;;WP;;;PS)")

    end = time.time()
    duration = end - start
    print("%f\t0\tcreate\tuser\t%f\tTrue\t" % (end, duration))
1701
1702
def create_group(ldb, instance_id, name):
    """Add a group object called `name` via ldap.

    The group is placed below the OU returned by ou_name() and a
    tab-separated timing record for the add is printed to stdout.
    """
    group_dn = "cn=%s,%s" % (name, ou_name(ldb, instance_id))
    record = {
        "dn": group_dn,
        "objectclass": "group",
    }
    started = time.time()
    ldb.add(record)
    finished = time.time()
    print("%f\t0\tcreate\tgroup\t%f\tTrue\t"
          % (finished, finished - started))
1716
1717
def user_name(instance_id, i):
    """Return the name of user number `i` for the given instance id."""
    ids = (instance_id, i)
    return "STGU-%d-%d" % ids
1721
1722
def generate_users(ldb, instance_id, number, password):
    """Add users to the server.

    Users are created from number `number` downwards, stopping at the
    first one that already exists. Returns the number of users created.
    Any LdbError other than "entry already exists" (status 68) is
    re-raised.
    """
    users = 0
    for i in range(number, 0, -1):
        try:
            username = user_name(instance_id, i)
            create_user_account(ldb, instance_id, username, password)
            users += 1
        except LdbError as e:
            # Unpack e.args, not e itself: exception objects are not
            # iterable on Python 3.
            (status, _) = e.args
            # Stop if entry exists
            if status == 68:
                break
            else:
                raise

    return users
1740
1741
def group_name(instance_id, i):
    """Return the name of group number `i` for the given instance id."""
    ids = (instance_id, i)
    return "STGG-%d-%d" % ids
1745
1746
def generate_groups(ldb, instance_id, number):
    """Create the required number of groups on the server.

    Groups are created from number `number` downwards, stopping at the
    first one that already exists. Returns the number of groups created.
    Any LdbError other than "entry already exists" (status 68) is
    re-raised.
    """
    groups = 0
    for i in range(number, 0, -1):
        try:
            name = group_name(instance_id, i)
            create_group(ldb, instance_id, name)
            groups += 1
        except LdbError as e:
            # Unpack e.args, not e itself: exception objects are not
            # iterable on Python 3.
            (status, _) = e.args
            # Stop if entry exists
            if status == 68:
                break
            else:
                raise
    return groups
1763
1764
def clean_up_accounts(ldb, instance_id):
    """Remove the created accounts and groups from the server.

    Deletes the OU returned by ou_name() with the tree_delete control.
    A missing OU (status 32) is not an error; any other LdbError is
    re-raised.
    """
    ou = ou_name(ldb, instance_id)
    try:
        ldb.delete(ou, ["tree_delete:1"])
    except LdbError as e:
        # Unpack e.args, not e itself: exception objects are not
        # iterable on Python 3.
        (status, _) = e.args
        # ignore does not exist
        if status != 32:
            raise
1775
1776
def generate_users_and_groups(ldb, instance_id, password,
                              number_of_users, number_of_groups,
                              group_memberships):
    """Create the requested users and groups and populate the groups.

    Ensures the OU exists, creates the users, creates the groups (when
    any are requested) and, when group memberships are requested, works
    out the assignments and applies them. Progress and a final summary
    are written to stderr.
    """
    create_ou(ldb, instance_id)

    print("Generating dummy user accounts", file=sys.stderr)
    users_added = generate_users(ldb, instance_id, number_of_users, password)

    groups_added = 0
    if number_of_groups > 0:
        print("Generating dummy groups", file=sys.stderr)
        groups_added = generate_groups(ldb, instance_id, number_of_groups)

    assignments = []
    if group_memberships > 0:
        print("Assigning users to groups", file=sys.stderr)
        assignments = assign_groups(number_of_groups,
                                    groups_added,
                                    number_of_users,
                                    users_added,
                                    group_memberships)
        print("Adding users to groups", file=sys.stderr)
        add_users_to_groups(ldb, instance_id, assignments)

    # some, but not all, of the groups are new and no user was added
    memberless_groups = (groups_added > 0 and
                         users_added == 0 and
                         number_of_groups != groups_added)
    if memberless_groups:
        print("Warning: the added groups will contain no members",
              file=sys.stderr)

    summary = ("Added %d users, %d groups and %d group memberships" %
               (users_added, groups_added, len(assignments)))
    print(summary, file=sys.stderr)
1812
1813
def assign_groups(number_of_groups,
                  groups_added,
                  number_of_users,
                  users_added,
                  group_memberships):
    """Allocate users to groups.

    The intention is to have a few users that belong to most groups, while
    the majority of users belong to a few groups.

    A few groups will contain most users, with the remaining only having a
    few users.

    Returns a set of (user number, group number) tuples, both counted
    from 1. Every pair involves a newly added user or a newly added
    group, i.e. one of the `users_added` highest-numbered users or one
    of the `groups_added` highest-numbered groups. An empty set is
    returned when no memberships are requested or when there are no
    users or no groups at all.
    """

    def generate_user_distribution(n):
        """Probability distribution of a user belonging to a group.
        """
        return [1 / (x + 0.001) for x in range(1, n + 1)]

    def generate_group_distribution(n):
        """Probability distribution of a group containing a user."""
        return [1 / (x**1.3) for x in range(1, n + 1)]

    assignments = set()
    # Without users or groups there is nothing to pick from (and the
    # scaling below would divide by zero).
    if (group_memberships <= 0 or
            number_of_users <= 0 or
            number_of_groups <= 0):
        return assignments

    group_dist = generate_group_distribution(number_of_groups)
    user_dist  = generate_user_distribution(number_of_users)

    # Calculate the number of group memberships required, scaled by the
    # fraction of the users that were actually added
    group_memberships = math.ceil(
        float(group_memberships) *
        (float(users_added) / float(number_of_users)))

    # highest array index belonging to an already existing user / group
    existing_users  = number_of_users  - users_added  - 1
    existing_groups = number_of_groups - groups_added - 1

    # Only pairs with a new user or a new group can be chosen. Never ask
    # for more memberships than there are such pairs, otherwise the loop
    # below could not terminate.
    old_pairs = max(0, existing_users + 1) * max(0, existing_groups + 1)
    eligible_pairs = number_of_users * number_of_groups - old_pairs
    group_memberships = min(group_memberships, eligible_pairs)

    while len(assignments) < group_memberships:
        user        = random.randint(0, number_of_users - 1)
        group       = random.randint(0, number_of_groups - 1)
        probability = group_dist[group] * user_dist[user]

        if ((random.random() < probability * 10000) and
           (group > existing_groups or user > existing_users)):
            # the + 1 converts the array index to the corresponding
            # group or user number
            assignments.add(((user + 1), (group + 1)))

    return assignments
1871
1872
def add_users_to_groups(db, instance_id, assignments):
    """Apply the group memberships in `assignments` to the server.

    Each assignment is a (user, group) tuple of numbers as produced by
    assign_groups. For every tuple the user's DN is added to the
    "member" attribute of the group, and a tab-separated timing record
    for the modify is printed to stdout.
    """
    ou = ou_name(db, instance_id)

    for (user_number, group_number) in assignments:
        member_dn = "cn=%s,%s" % (user_name(instance_id, user_number), ou)
        target_dn = "cn=%s,%s" % (group_name(instance_id, group_number), ou)

        msg = ldb.Message()
        msg.dn = ldb.Dn(db, target_dn)
        msg["member"] = ldb.MessageElement(member_dn,
                                           ldb.FLAG_MOD_ADD,
                                           "member")
        before = time.time()
        db.modify(msg)
        after = time.time()
        print("%f\t0\tadd\tuser\t%f\tTrue\t" % (after, after - before))
1896
1897
def generate_stats(statsdir, timing_file):
    """Generate and print the summary stats for a run.

    Every file in `statsdir` is read as tab-separated records of
    time, conversation, protocol, type, duration and successful
    (matching the header row written below). Lines that cannot be parsed
    are echoed to stderr and otherwise ignored. Valid lines are copied
    to `timing_file`, below a header row, when it is not None.

    The totals and a per protocol/operation latency table are printed to
    stdout: formatted for reading when stdout is a terminal, and in a
    compact tab-separated form otherwise.
    """
    first      = sys.float_info.max
    last       = 0
    successful = 0
    failed     = 0
    latencies  = {}
    failures   = {}
    unique_conversations = set()

    if timing_file is not None:
        tw = timing_file.write
    else:
        def tw(x):
            pass

    tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n")

    for filename in os.listdir(statsdir):
        path = os.path.join(statsdir, filename)
        with open(path, 'r') as f:
            for line in f:
                # Parse the whole record before touching any totals, so
                # a malformed line is never partially counted.
                try:
                    fields       = line.rstrip('\n').split('\t')
                    timestamp    = float(fields[0])
                    conversation = fields[1]
                    protocol     = fields[2]
                    packet_type  = fields[3]
                    latency      = float(fields[4])
                    succeeded    = fields[5] == 'True'
                except (ValueError, IndexError):
                    # not a valid line print and ignore
                    print(line, file=sys.stderr)
                    continue

                # the recorded time is the end of the operation
                first = min(timestamp - latency, first)
                last  = max(timestamp, last)

                by_type = latencies.setdefault(protocol, {})
                by_type.setdefault(packet_type, []).append(latency)

                op_failures = failures.setdefault(protocol, {})
                op_failures.setdefault(packet_type, 0)

                if succeeded:
                    successful += 1
                else:
                    failed += 1
                    op_failures[packet_type] += 1

                unique_conversations.add(conversation)

                tw(line)

    conversations = len(unique_conversations)
    duration = last - first
    # With no valid records, or records spanning no time at all, there
    # is no meaningful rate; report 0 rather than dividing by zero.
    if duration > 0:
        success_rate = successful / duration
        failure_rate = failed / duration
    else:
        success_rate = 0
        failure_rate = 0

    # print the stats in more human-readable format when stdout is going to the
    # console (as opposed to being redirected to a file)
    interactive = sys.stdout.isatty()
    if interactive:
        print("Total conversations:   %10d" % conversations)
        print("Successful operations: %10d (%.3f per second)"
              % (successful, success_rate))
        print("Failed operations:     %10d (%.3f per second)"
              % (failed, failure_rate))
    else:
        print("(%d, %d, %d, %.3f, %.3f)" %
              (conversations, successful, failed, success_rate, failure_rate))

    if interactive:
        print("Protocol    Op Code  Description                               "
              " Count       Failed         Mean       Median          "
              "95%        Range          Max")
    else:
        print("proto\top_code\tdesc\tcount\tfailed\tmean\tmedian\t95%\trange"
              "\tmax")
    protocols = sorted(latencies.keys())
    for protocol in protocols:
        packet_types = sorted(latencies[protocol], key=opcode_key)
        for packet_type in packet_types:
            values     = sorted(latencies[protocol][packet_type])
            count      = len(values)
            op_failed  = failures[protocol][packet_type]
            mean       = sum(values) / count
            median     = calc_percentile(values, 0.50)
            percentile = calc_percentile(values, 0.95)
            rng        = values[-1] - values[0]
            maxv       = values[-1]
            desc       = OP_DESCRIPTIONS.get((protocol, packet_type), '')
            # isatty must be *called*: the bare method object is always
            # true, which made the tab-separated branch unreachable.
            if interactive:
                print("%-12s   %4s  %-35s %12d %12d %12.6f "
                      "%12.6f %12.6f %12.6f %12.6f"
                      % (protocol,
                         packet_type,
                         desc,
                         count,
                         op_failed,
                         mean,
                         median,
                         percentile,
                         rng,
                         maxv))
            else:
                print("%s\t%s\t%s\t%d\t%d\t%f\t%f\t%f\t%f\t%f"
                      % (protocol,
                         packet_type,
                         desc,
                         count,
                         op_failed,
                         mean,
                         median,
                         percentile,
                         rng,
                         maxv))
2025
2026
def opcode_key(v):
    """Sort key for the operation code to ensure that it sorts numerically.

    Values that convert to an integer are returned zero-padded to three
    digits; anything else is returned unchanged.
    """
    try:
        return "%03d" % int(v)
    except (ValueError, TypeError, OverflowError):
        # Not an integer op code. Only conversion failures are caught, so
        # KeyboardInterrupt and SystemExit are no longer swallowed.
        return v
2033
2034
def calc_percentile(values, percentile):
    """Return the requested percentile of `values`.

    `values` must already be sorted in ascending order and `percentile`
    is a fraction (e.g. 0.95). When the percentile falls between two
    entries the result is interpolated linearly between them. An empty
    list gives 0.
    """
    if values:
        rank = (len(values) - 1) * percentile
        lower = math.floor(rank)
        upper = math.ceil(rank)
        if lower != upper:
            # weight each neighbour by its closeness to the rank
            below = values[int(lower)] * (upper - rank)
            above = values[int(upper)] * (rank - lower)
            return below + above
        # the rank lands exactly on an entry
        return values[int(rank)]
    return 0
2051
2052
def mk_masked_dir(*path):
    """In a testenv we end up with 0777 directories that look an alarming
    green colour with ls. Use umask to avoid that.

    Joins the `path` components, creates that directory with a umask of
    077 and returns the joined path. The previous umask is restored
    afterwards, even when os.mkdir fails.
    """
    d = os.path.join(*path)
    mask = os.umask(0o077)
    try:
        os.mkdir(d)
    finally:
        # never leave the process with the restrictive umask
        os.umask(mask)
    return d