traffic_packets: add windows instructions for ldap 0 simple bind
[samba.git] / python / samba / emulate / traffic.py
1 # -*- encoding: utf-8 -*-
2 # Samba traffic replay and learning
3 #
4 # Copyright (C) Catalyst IT Ltd. 2017
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 #
19 from __future__ import print_function, division
20
21 import time
22 import os
23 import random
24 import json
25 import math
26 import sys
27 import signal
28 import itertools
29
30 from collections import OrderedDict, Counter, defaultdict
31 from samba.emulate import traffic_packets
32 from samba.samdb import SamDB
33 import ldb
34 from ldb import LdbError
35 from samba.dcerpc import ClientConnection
36 from samba.dcerpc import security, drsuapi, lsa
37 from samba.dcerpc import netlogon
38 from samba.dcerpc.netlogon import netr_Authenticator
39 from samba.dcerpc import srvsvc
40 from samba.dcerpc import samr
41 from samba.drs_utils import drs_DsBind
42 import traceback
43 from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
44 from samba.auth import system_session
45 from samba.dsdb import UF_WORKSTATION_TRUST_ACCOUNT, UF_PASSWD_NOTREQD
46 from samba.dsdb import UF_NORMAL_ACCOUNT
47 from samba.dcerpc.misc import SEC_CHAN_WKSTA
48 from samba import gensec
49
SLEEP_OVERHEAD = 3e-4

# Marker for "no packet".  None is avoided on purpose because it
# complicates [de]serialisation.
NON_PACKET = '-'

# (protocol, opcode) pairs only a client sends.  Packet.client_score()
# returns the weight as a positive score.
CLIENT_CLUES = dict.fromkeys([
    ('dns', '0'),       # query
    ('smb', '0x72'),    # Negotiate protocol
    ('ldap', '0'),      # bind
    ('ldap', '3'),      # searchRequest
    ('ldap', '2'),      # unbindRequest
    ('cldap', '3'),
    ('dcerpc', '11'),   # bind
    ('dcerpc', '14'),   # Alter_context
    ('nbns', '0'),      # query
], 1.0)

# (protocol, opcode) pairs only a server sends.  Packet.client_score()
# returns the weight negated.
SERVER_CLUES = dict.fromkeys([
    ('dns', '1'),       # response
    ('ldap', '1'),      # bind response
    ('ldap', '4'),      # search result
    ('ldap', '5'),      # search done
    ('cldap', '5'),
    ('dcerpc', '12'),   # bind_ack
    ('dcerpc', '13'),   # bind_nak
    ('dcerpc', '15'),   # Alter_context response
], 1.0)

# Packets of these protocols are never replayed (see
# Packet.is_really_a_packet).
SKIPPED_PROTOCOLS = {"smb", "smb2", "browser", "smb_netlogon"}

WAIT_SCALE = 10.0
WAIT_THRESHOLD = 1.0 / WAIT_SCALE
NO_WAIT_LOG_TIME_RANGE = (-10, -3)

# Current verbosity for debug(); scripts may raise it with -d.
DEBUG_LEVEL = 0
86
87
def debug(level, msg, *args):
    """Write a debug message to standard error if its level is enabled.

    :param level: verbosity of the message; it is printed only when
                  level <= DEBUG_LEVEL (scripts set DEBUG_LEVEL with -d).
    :param msg:   the message, which may contain C-style (%) format
                  specifiers.
    :param args:  values for those specifiers.  With no args, msg is
                  printed verbatim and is not %-formatted at all.
    """
    if not level <= DEBUG_LEVEL:
        return
    if args:
        msg = msg % tuple(args)
    print(msg, file=sys.stderr)
104
105
def debug_lineno(*args):
    """Print an unformatted log message to stderr, prefixed with the
    calling function's name and line number.

    Each argument is printed on its own line, followed by a blank line,
    and stderr is flushed.
    """
    # extract_stack(limit=2) yields [caller, this function]; index 0 is
    # the caller, whose [2] is the function name and [1] the line number.
    caller = traceback.extract_stack(limit=2)[0]
    location = " %s:\033[01;33m%s \033[00m" % (caller[2], caller[1])
    print(location, end=' ', file=sys.stderr)
    for item in args:
        print(item, file=sys.stderr)
    print(file=sys.stderr)
    sys.stderr.flush()
117
118
def random_colour_print():
    """Return a function that prints its arguments to stderr, one per
    line, each wrapped in the same randomly chosen 256-colour ANSI
    foreground escape (colour index 18..231)."""
    colour = random.randrange(214) + 18
    start = "\033[38;5;%dm" % colour

    def colour_print(*args):
        for item in args:
            print("%s%s\033[00m" % (start, item), file=sys.stderr)

    return colour_print
129
130
class FakePacketError(Exception):
    """Raised when a packet's endpoints do not match those of the
    conversation it is being added to (see Conversation.add_packet)."""
133
134
class Packet(object):
    """Details of a network packet"""
    def __init__(self, fields):
        """Build a packet from a traffic_summary record.

        :param fields: either a single tab-separated summary line (a
                       trailing newline is ignored) or a sequence of fields
                       in the order timestamp, ip_protocol, stream_number,
                       src, dest, protocol, opcode, desc, optionally
                       followed by extra fields.
        """
        if isinstance(fields, str):
            fields = fields.rstrip('\n').split('\t')

        (timestamp,
         ip_protocol,
         stream_number,
         src,
         dest,
         protocol,
         opcode,
         desc) = fields[:8]
        extra = fields[8:]

        self.timestamp = float(timestamp)
        self.ip_protocol = ip_protocol
        try:
            self.stream_number = int(stream_number)
        except (ValueError, TypeError):
            # empty string or None: the packet has no stream number
            self.stream_number = None
        self.src = int(src)
        self.dest = int(dest)
        self.protocol = protocol
        self.opcode = opcode
        self.desc = desc
        self.extra = extra

        # Store endpoints as (lower, higher) so both directions of an
        # exchange produce the same tuple.
        if self.src < self.dest:
            self.endpoints = (self.src, self.dest)
        else:
            self.endpoints = (self.dest, self.src)

    def as_summary(self, time_offset=0.0):
        """Format the packet as a traffic_summary line.

        :param time_offset: added to the packet's timestamp.
        :return: a (timestamp, line) tuple holding the offset timestamp
                 and the tab-separated summary line.
        """
        extra = '\t'.join(self.extra)
        t = self.timestamp + time_offset
        # 0 is a valid stream number (int() accepts it in __init__), so
        # only a missing (None) stream becomes an empty field; the old
        # "stream_number or ''" also blanked stream 0.
        if self.stream_number is None:
            stream_number = ''
        else:
            stream_number = self.stream_number
        return (t, '%f\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' %
                (t,
                 self.ip_protocol,
                 stream_number,
                 self.src,
                 self.dest,
                 self.protocol,
                 self.opcode,
                 self.desc,
                 extra))

    def __str__(self):
        return ("%.3f: %d -> %d; ip %s; strm %s; prot %s; op %s; desc %s %s" %
                (self.timestamp, self.src, self.dest, self.ip_protocol or '-',
                 self.stream_number, self.protocol, self.opcode, self.desc,
                 ('«' + ' '.join(self.extra) + '»' if self.extra else '')))

    def __repr__(self):
        return "<Packet @%s>" % self

    def copy(self):
        """Return a new packet of the same class with the same fields."""
        return self.__class__([self.timestamp,
                               self.ip_protocol,
                               self.stream_number,
                               self.src,
                               self.dest,
                               self.protocol,
                               self.opcode,
                               self.desc] + self.extra)

    def as_packet_type(self):
        """Return the 'protocol:opcode' string identifying this packet."""
        t = '%s:%s' % (self.protocol, self.opcode)
        return t

    def client_score(self):
        """A positive number means we think it is a client; a negative number
        means we think it is a server. Zero means no idea. range: -1 to 1.
        """
        key = (self.protocol, self.opcode)
        if key in CLIENT_CLUES:
            return CLIENT_CLUES[key]
        if key in SERVER_CLUES:
            return -SERVER_CLUES[key]
        return 0.0

    def play(self, conversation, context):
        """Send the packet over the network, if required.

        Some packets are ignored, i.e. for protocols not handled,
        server response messages, or messages that are generated by the
        protocol layer associated with other packets.

        The handler is traffic_packets.packet_<protocol>_<opcode>.  A
        missing handler is reported on stderr.  When the handler returns a
        true value, or raises, a tab-separated timing record is printed to
        stdout.  A handler exception is recorded in that line rather than
        propagated.
        """
        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
        try:
            fn = getattr(traffic_packets, fn_name)
        except AttributeError:
            print("Conversation(%s) Missing handler %s" %
                  (conversation.conversation_id, fn_name),
                  file=sys.stderr)
            return

        # Don't display a message for kerberos packets, they're not directly
        # generated they're used to indicate kerberos should be used
        if self.protocol != "kerberos":
            debug(2, "Conversation(%s) Calling handler %s" %
                     (conversation.conversation_id, fn_name))

        start = time.time()
        try:
            if fn(self, conversation, context):
                # Only collect timing data for functions that generate
                # network traffic, or fail
                end = time.time()
                duration = end - start
                print("%f\t%s\t%s\t%s\t%f\tTrue\t" %
                      (end, conversation.conversation_id, self.protocol,
                       self.opcode, duration))
        except Exception as e:
            end = time.time()
            duration = end - start
            print("%f\t%s\t%s\t%s\t%f\tFalse\t%s" %
                  (end, conversation.conversation_id, self.protocol,
                   self.opcode, duration, e))

    def __cmp__(self, other):
        """Python 2 ordering by timestamp; returns -1, 0 or 1.

        Returning the raw float difference is wrong on Python 2, which
        truncates __cmp__'s result to an int: packets less than a second
        apart would compare equal.
        """
        return ((self.timestamp > other.timestamp) -
                (self.timestamp < other.timestamp))

    def __lt__(self, other):
        """Order packets by timestamp (used by sorting on Python 3)."""
        return self.timestamp < other.timestamp

    def is_really_a_packet(self, missing_packet_stats=None):
        """Return True if this packet needs to be replayed.

        False means removing the packet has no effect on the replay.  That
        is the case when its protocol is in SKIPPED_PROTOCOLS, when it is an
        ldap continuation packet (empty opcode), or when its handler is
        traffic_packets.null_packet.  It is also the case when there is no
        handler at all, which is reported on stderr.

        :param missing_packet_stats: currently unused.
        """
        if self.protocol in SKIPPED_PROTOCOLS:
            # Ignore any packets for the protocols we're not interested in.
            return False
        if self.protocol == "ldap" and self.opcode == '':
            # skip ldap continuation packets
            return False

        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
        try:
            fn = getattr(traffic_packets, fn_name)
            if fn is traffic_packets.null_packet:
                return False
        except AttributeError:
            print("missing packet %s" % fn_name, file=sys.stderr)
            return False
        return True
283
284
class ReplayContext(object):
    """State/Context for an individual conversation between a simulated
       client and a server.

       Holds the per-conversation credentials and the cached protocol
       connections.  It also holds the DN lookup tables used to choose
       search bases, and the flags that stop bad-password attempts from
       happening twice in a row.
    """

    def __init__(self,
                 server=None,
                 lp=None,
                 creds=None,
                 badpassword_frequency=None,
                 prefer_kerberos=None,
                 tempdir=None,
                 statsdir=None,
                 ou=None,
                 base_dn=None,
                 domain=None,
                 domain_sid=None):
        """Create the context and build the LDAP DN lookup tables.

        Both ``server`` and ``lp`` must be supplied despite their None
        defaults: the realm is read from ``lp``, and
        generate_ldap_search_tables() connects to ``server`` over LDAP.

        :param badpassword_frequency: probability (compared against
            random.random()) that a connection is first attempted with bad
            credentials.  None or a value <= 0 disables this.
        """

        self.server                   = server
        # Connection caches: the newest is last, and the get_*_connection
        # methods return it unless new=True is passed.
        self.ldap_connections         = []
        self.dcerpc_connections       = []
        self.lsarpc_connections       = []
        self.lsarpc_connections_named = []
        self.drsuapi_connections      = []
        self.srvsvc_connections       = []
        self.samr_contexts            = []
        self.netlogon_connection      = None
        self.creds                    = creds
        self.lp                       = lp
        self.prefer_kerberos          = prefer_kerberos
        self.ou                       = ou
        self.base_dn                  = base_dn
        self.domain                   = domain
        self.statsdir                 = statsdir
        self.global_tempdir           = tempdir
        self.domain_sid               = domain_sid
        self.realm                    = lp.get('realm')

        # Bad password attempt controls
        self.badpassword_frequency    = badpassword_frequency
        self.last_lsarpc_bad          = False
        self.last_lsarpc_named_bad    = False
        self.last_simple_bind_bad     = False
        self.last_bind_bad            = False
        self.last_srvsvc_bad          = False
        self.last_drsuapi_bad         = False
        self.last_netlogon_bad        = False
        self.last_samlogon_bad        = False
        self.generate_ldap_search_tables()

        # itertools.count().next only exists on Python 2.  The builtin
        # next() works on both, and next_conversation_id stays a
        # zero-argument callable returning 0, 1, 2, ...
        counter = itertools.count()
        self.next_conversation_id = lambda: next(counter)

    def generate_ldap_search_tables(self):
        """Build the DN lookup tables used by get_matching_dn().

        Every DN under the domain DN is grouped by its "pattern": the
        upper-cased first two characters of each RDN.  For example,
        'CN=Users,DC=samba,DC=example' maps to 'CN,DC,DC'.  DNs starting
        with 'CN=NTDS Settings,' are also recorded as candidates for
        invocationId lookups.
        """
        session = system_session()

        db = SamDB(url="ldap://%s" % self.server,
                   session_info=session,
                   credentials=self.creds,
                   lp=self.lp)

        res = db.search(db.domain_dn(),
                        scope=ldb.SCOPE_SUBTREE,
                        controls=["paged_results:1:1000"],
                        attrs=['dn'])

        # find a list of dns for each pattern
        # e.g. CN,CN,CN,DC,DC
        dn_map = {}
        attribute_clue_map = {
            'invocationId': []
        }

        for r in res:
            dn = str(r.dn)
            pattern = ','.join(x.lstrip()[:2] for x in dn.split(',')).upper()
            dns = dn_map.setdefault(pattern, [])
            dns.append(dn)
            if dn.startswith('CN=NTDS Settings,'):
                attribute_clue_map['invocationId'].append(dn)

        # Extend the map in case we are working with a different number of
        # DC components: each pattern ending in DC is also registered with
        # 1 to 5 trailing DCs, unless that would clobber a real pattern.
        # Iterate over a snapshot of the keys because new keys are added
        # inside the loop.  Iterating the live dict raises RuntimeError on
        # Python 3.
        for k in list(dn_map):
            if k[-3:] != ',DC':
                continue
            p = k[:-3]
            while p[-3:] == ',DC':
                p = p[:-3]
            for i in range(5):
                p += ',DC'
                if p != k and p in dn_map:
                    print('dn_map collison %s %s' % (k, p),
                          file=sys.stderr)
                    continue
                dn_map[p] = dn_map[k]

        self.dn_map = dn_map
        self.attribute_clue_map = attribute_clue_map

    def generate_process_local_config(self, account, conversation):
        """Set up the per-conversation account, temp dir and credentials.

        Does nothing when account is None.  Otherwise it copies the account
        names and passwords.  It then points the loadparm private, lock and
        state directories at a fresh conversation-specific temp dir and
        derives the samlogon and user DNs under self.ou.  Finally it
        generates the machine and user credentials.

        If no domain was configured, the DOMAIN environment variable is
        used; a KeyError is raised if that is unset.
        """
        if account is None:
            return
        self.netbios_name             = account.netbios_name
        self.machinepass              = account.machinepass
        self.username                 = account.username
        self.userpass                 = account.userpass

        self.tempdir = mk_masked_dir(self.global_tempdir,
                                     'conversation-%d' %
                                     conversation.conversation_id)

        self.lp.set("private dir",     self.tempdir)
        self.lp.set("lock dir",        self.tempdir)
        self.lp.set("state directory", self.tempdir)
        self.lp.set("tls verify peer", "no_check")

        # If the domain was not specified, check for the environment
        # variable.
        if self.domain is None:
            self.domain = os.environ["DOMAIN"]

        self.remoteAddress = "/root/ncalrpc_as_system"
        self.samlogon_dn   = ("cn=%s,%s" %
                              (self.netbios_name, self.ou))
        self.user_dn       = ("cn=%s,%s" %
                              (self.username, self.ou))

        self.generate_machine_creds()
        self.generate_user_creds()

    def with_random_bad_credentials(self, f, good, bad, failed_last_time):
        """Execute the supplied logon function, randomly choosing the
           bad credentials.

           Based on the frequency in badpassword_frequency, f is first run
           with the bad credentials.  Any exception from that attempt is
           ignored.  f is then always run with the good credentials, and
           its result is returned.

           failed_last_time is used to prevent consecutive bad credential
           attempts.  If the previous call used bad credentials, this call
           does not, and the returned flag is cleared so the next call may
           again.  The overall bad credential frequency will therefore be
           lower than requested, but not significantly.

           :return: (result of f(good), whether bad credentials were used)
        """
        frequency = self.badpassword_frequency
        if failed_last_time:
            # Skip bad credentials this time; allow them again next call.
            failed_last_time = False
        elif (frequency is not None and frequency > 0 and
              random.random() < frequency):
            try:
                f(bad)
            except Exception:
                # Ignore any exceptions as the operation may fail
                # as it's being performed with bad credentials
                pass
            failed_last_time = True

        result = f(good)
        return (result, failed_last_time)

    def _new_creds(self, username, password, secure_channel_type=None,
                   seal=False, bind_dn=None):
        """Return Credentials for this conversation's workstation.

        Calls are made in this order: guess from the loadparm, set the
        workstation, set the optional secure channel type, then the
        password and username.  Next come optional FEATURE_SEAL and the
        kerberos state (from prefer_kerberos), and finally the optional
        simple bind DN.
        """
        creds = Credentials()
        creds.guess(self.lp)
        creds.set_workstation(self.netbios_name)
        if secure_channel_type is not None:
            creds.set_secure_channel_type(secure_channel_type)
        creds.set_password(password)
        creds.set_username(username)
        if seal:
            creds.set_gensec_features(
                creds.get_gensec_features() | gensec.FEATURE_SEAL)
        if self.prefer_kerberos:
            creds.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            creds.set_kerberos_state(DONT_USE_KERBEROS)
        if bind_dn is not None:
            creds.set_bind_dn(bind_dn)
        return creds

    def generate_user_creds(self):
        """Generate the conversation specific user Credentials.

        Each Conversation has an associated user account used to simulate
        any non Administrative user traffic.

        Generates user credentials with good and bad passwords, and ldap
        simple bind credentials (sealed, with the user's bind DN) with good
        and bad passwords.  A "bad" password is the real one with its last
        four characters removed.
        """
        bad_userpass = self.userpass[:-4]

        self.user_creds = self._new_creds(self.username, self.userpass)
        self.user_creds_bad = self._new_creds(self.username, bad_userpass)

        # Credentials for ldap simple bind.
        self.simple_bind_creds = self._new_creds(self.username,
                                                 self.userpass,
                                                 seal=True,
                                                 bind_dn=self.user_dn)
        self.simple_bind_creds_bad = self._new_creds(self.username,
                                                     bad_userpass,
                                                     seal=True,
                                                     bind_dn=self.user_dn)

    def generate_machine_creds(self):
        """Generate the conversation specific machine Credentials.

        Each Conversation has an associated machine account.

        Generates workstation secure-channel credentials for
        "<netbios_name>$" with good and bad passwords.  The bad password is
        the real one with its last four characters removed.
        """
        machine_user = self.netbios_name + "$"
        self.machine_creds = self._new_creds(
            machine_user, self.machinepass,
            secure_channel_type=SEC_CHAN_WKSTA)
        self.machine_creds_bad = self._new_creds(
            machine_user, self.machinepass[:-4],
            secure_channel_type=SEC_CHAN_WKSTA)

    def get_matching_dn(self, pattern, attributes=None):
        """Pick a random DN matching a pattern such as 'CN,CN,DC,DC'.

        If ``attributes`` names a key in attribute_clue_map with candidate
        DNs (e.g. 'invocationId'), one of those is returned instead.
        """
        # If the pattern is an empty string, we assume ROOTDSE,
        # Otherwise we try adding or removing DC suffixes, then
        # shorter leading patterns until we hit one.
        # e.g if there is no CN,CN,CN,CN,DC,DC
        # we first try       CN,CN,CN,CN,DC
        # and                CN,CN,CN,CN,DC,DC,DC
        # then change to        CN,CN,CN,DC,DC
        # and as last resort we use the base_dn
        attr_clue = self.attribute_clue_map.get(attributes)
        if attr_clue:
            return random.choice(attr_clue)

        pattern = pattern.upper()
        while pattern:
            if pattern in self.dn_map:
                return random.choice(self.dn_map[pattern])
            # chop one off the front and try it all again.
            pattern = pattern[3:]

        return self.base_dn

    def get_dcerpc_connection(self, new=False):
        """Return the cached netlogon-interface ClientConnection, or open a
        new one over ncacn_ip_tcp when there is none or new is True."""
        guid = '12345678-1234-abcd-ef00-01234567cffb'  # RPC_NETLOGON UUID
        if self.dcerpc_connections and not new:
            return self.dcerpc_connections[-1]
        c = ClientConnection("ncacn_ip_tcp:%s" % self.server,
                             (guid, 1), self.lp)
        self.dcerpc_connections.append(c)
        return c

    def get_srvsvc_connection(self, new=False):
        """Return a srvsvc named-pipe connection made with the user
        credentials (possibly preceded by a bad-password attempt)."""
        if self.srvsvc_connections and not new:
            return self.srvsvc_connections[-1]

        def connect(creds):
            return srvsvc.srvsvc("ncacn_np:%s" % (self.server),
                                 self.lp,
                                 creds)

        (c, self.last_srvsvc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_srvsvc_bad)

        self.srvsvc_connections.append(c)
        return c

    def get_lsarpc_connection(self, new=False):
        """Return an lsarpc connection over TCP (schannel, sealed, signed)
        made with the machine credentials."""
        if self.lsarpc_connections and not new:
            return self.lsarpc_connections[-1]

        def connect(creds):
            binding_options = 'schannel,seal,sign'
            return lsa.lsarpc("ncacn_ip_tcp:%s[%s]" %
                              (self.server, binding_options),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_bad)

        self.lsarpc_connections.append(c)
        return c

    def get_lsarpc_named_pipe_connection(self, new=False):
        """Return an lsarpc named-pipe connection made with the machine
        credentials."""
        if self.lsarpc_connections_named and not new:
            return self.lsarpc_connections_named[-1]

        def connect(creds):
            return lsa.lsarpc("ncacn_np:%s" % (self.server),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_named_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_named_bad)

        self.lsarpc_connections_named.append(c)
        return c

    def get_drsuapi_connection_pair(self, new=False, unbind=False):
        """get a (drs, drs_handle) tuple

        The connection is sealed and made with the user credentials, then
        bound with drs_DsBind.  ``unbind`` is currently unused.
        """
        if self.drsuapi_connections and not new:
            c = self.drsuapi_connections[-1]
            return c

        def connect(creds):
            binding_options = 'seal'
            binding_string = "ncacn_ip_tcp:%s[%s]" %\
                             (self.server, binding_options)
            return drsuapi.drsuapi(binding_string, self.lp, creds)

        (drs, self.last_drsuapi_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_drsuapi_bad)

        (drs_handle, _supported_extensions) = drs_DsBind(drs)
        c = (drs, drs_handle)
        self.drsuapi_connections.append(c)
        return c

    def get_ldap_connection(self, new=False, simple=False):
        """Return a SamDB connection for the conversation's user.

        With simple=True, an LDAP simple bind is made over ldaps:// using
        the simple bind credentials.  Otherwise a SASL bind is made over
        ldap:// using the user credentials.
        """
        if self.ldap_connections and not new:
            return self.ldap_connections[-1]

        def simple_bind(creds):
            """
            To run simple bind against Windows, we need to run the
            following commands in PowerShell (presumably so the DC has a
            certificate to serve ldaps, which this bind uses):

                Install-windowsfeature ADCS-Cert-Authority
                Install-AdcsCertificationAuthority -CAType EnterpriseRootCA
                Restart-Computer

            """
            return SamDB('ldaps://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)

        def sasl_bind(creds):
            return SamDB('ldap://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)
        if simple:
            (samdb, self.last_simple_bind_bad) = \
                self.with_random_bad_credentials(simple_bind,
                                                 self.simple_bind_creds,
                                                 self.simple_bind_creds_bad,
                                                 self.last_simple_bind_bad)
        else:
            (samdb, self.last_bind_bad) = \
                self.with_random_bad_credentials(sasl_bind,
                                                 self.user_creds,
                                                 self.user_creds_bad,
                                                 self.last_bind_bad)

        self.ldap_connections.append(samdb)
        return samdb

    def get_samr_context(self, new=False):
        """Return the latest SamrContext (made with the context's admin
        creds), creating one if none exists or new is True."""
        if not self.samr_contexts or new:
            self.samr_contexts.append(
                SamrContext(self.server, lp=self.lp, creds=self.creds))
        return self.samr_contexts[-1]

    def get_netlogon_connection(self):
        """Return the cached netlogon connection (schannel, sealed), making
        it with the machine credentials on first use."""

        if self.netlogon_connection:
            return self.netlogon_connection

        def connect(creds):
            return netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" %
                                     (self.server),
                                     self.lp,
                                     creds)
        (c, self.last_netlogon_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_netlogon_bad)
        self.netlogon_connection = c
        return c

    def guess_a_dns_lookup(self):
        """Return a (name, record type) pair: an A lookup of the realm."""
        return (self.realm, 'A')

    def get_authenticator(self):
        """Return (current, subsequent) netr_Authenticators, where current
        is built from a new client authenticator of the machine creds."""
        auth = self.machine_creds.new_client_authenticator()
        current  = netr_Authenticator()
        # The credential iterates as 1-char strings on Python 2 but as
        # ints when it is bytes on Python 3; accept either.
        current.cred.data = [x if isinstance(x, int) else ord(x)
                             for x in auth["credential"]]
        current.timestamp = auth["timestamp"]

        subsequent = netr_Authenticator()
        return (current, subsequent)
715
716
class SamrContext(object):
    """A samr connection opened on first use, plus the handles kept for it.
    """
    def __init__(self, server, lp=None, creds=None):
        self.server        = server
        self.lp            = lp
        self.creds         = creds
        # connection and handle are filled lazily by get_connection() and
        # get_handle(); the remaining fields start empty and are presumably
        # set by callers elsewhere.
        self.connection    = None
        self.handle        = None
        self.domain_handle = None
        self.domain_sid    = None
        self.group_handle  = None
        self.user_handle   = None
        self.rids          = None

    def get_connection(self):
        """Return the samr connection, opening a sealed TCP one if unset."""
        if self.connection:
            return self.connection
        binding = "ncacn_ip_tcp:%s[seal]" % (self.server)
        self.connection = samr.samr(binding,
                                    lp_ctx=self.lp,
                                    credentials=self.creds)
        return self.connection

    def get_handle(self):
        """Return the Connect2 handle, requesting maximum allowed access
        on first use."""
        if self.handle:
            return self.handle
        conn = self.get_connection()
        self.handle = conn.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
        return self.handle
746
747
class Conversation(object):
    """Details of a conversation between a simulated client and a server.

    Packets are held in self.packets with timestamps relative to
    self.start_time (see add_packet and add_short_packet).
    """
    conversation_id = None

    def __init__(self, start_time=None, endpoints=None):
        self.start_time = start_time
        self.endpoints = endpoints
        self.packets = []
        self.msg = random_colour_print()
        # Running score updated from each packet's client_score(); a
        # negative value means endpoints[0] looks like the client (see
        # guess_client_server).
        self.client_balance = 0.0

    def __cmp__(self, other):
        """Python 2 ordering by start_time; a None start_time sorts first.

        Returns -1, 0 or 1, because __cmp__ must return an integer (a float
        difference of start times is not a valid cmp result).
        """
        mine = self.start_time
        theirs = other.start_time
        if mine is None:
            if theirs is None:
                return 0
            return -1
        if theirs is None:
            return 1
        return (mine > theirs) - (mine < theirs)

    def __lt__(self, other):
        # Python 3 ignores __cmp__ and sorting only needs "<"; derive it
        # from __cmp__ so both interpreters order conversations the same.
        return self.__cmp__(other) < 0

    def _update_client_balance(self, packet):
        """Shift client_balance by packet.client_score(): down when the
        packet comes from endpoints[0], up otherwise."""
        if packet.src == packet.endpoints[0]:
            self.client_balance -= packet.client_score()
        else:
            self.client_balance += packet.client_score()

    def add_packet(self, packet):
        """Add a packet object to this conversation, making a local copy with
        a conversation-relative timestamp.

        The first packet added fixes start_time and endpoints (if they were
        not given to the constructor). Every packet contributes to
        client_balance, but only those for which is_really_a_packet() is
        true are stored.

        :raises FakePacketError: if the packet's endpoints differ from the
            conversation's.
        """
        p = packet.copy()

        if self.start_time is None:
            self.start_time = p.timestamp

        if self.endpoints is None:
            self.endpoints = p.endpoints

        if p.endpoints != self.endpoints:
            raise FakePacketError("Conversation endpoints %s don't match "
                                  "packet endpoints %s" %
                                  (self.endpoints, p.endpoints))

        p.timestamp -= self.start_time

        self._update_client_balance(p)

        if p.is_really_a_packet():
            self.packets.append(p)

    def add_short_packet(self, timestamp, p, extra, client=True):
        """Create a packet from a timestamp, and 'protocol:opcode' pair, and a
        (possibly empty) list of extra data. If client is True, assume
        this packet is from the client to the server.

        The stored timestamp is made relative to self.start_time.
        """
        protocol, opcode = p.split(':', 1)
        src, dest = self.guess_client_server()
        if not client:
            src, dest = dest, src

        desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
        ip_protocol = IP_PROTOCOLS.get(protocol, '06')
        fields = [timestamp - self.start_time, ip_protocol,
                  '', src, dest,
                  protocol, opcode, desc]
        fields.extend(extra)
        packet = Packet(fields)
        # XXX should we adjust client balance for guessed packets?
        self._update_client_balance(packet)
        if packet.is_really_a_packet():
            self.packets.append(packet)

    def __str__(self):
        return ("<Conversation %s %s starting %.3f %d packets>" %
                (self.conversation_id, self.endpoints, self.start_time,
                 len(self.packets)))

    __repr__ = __str__

    def __iter__(self):
        return iter(self.packets)

    def __len__(self):
        return len(self.packets)

    def get_duration(self):
        """Return the time between the first and last stored packets
        (0 when there are fewer than two)."""
        if len(self.packets) < 2:
            return 0
        return self.packets[-1].timestamp - self.packets[0].timestamp

    def replay_as_summary_lines(self):
        """Return each packet's as_summary(self.start_time) in order."""
        return [p.as_summary(self.start_time) for p in self.packets]

    def replay_in_fork_with_delay(self, start, context=None, account=None):
        """Fork a new process and replay the conversation in it.

        The parent gets the child's pid back immediately. The child sets up
        its process-local config, redirects stdout to a per-conversation
        stats file in context.statsdir, sleeps until start + start_time and
        then replays. The child always leaves via os._exit() and never
        returns into the caller's code.
        """
        def signal_handler(signum, frame):
            """Close stdout and stderr and exit (installed for SIGTERM).

            Ensures that the log messages are flushed to disk and not lost.
            """
            sys.stderr.close()
            sys.stdout.close()
            os._exit(0)

        t = self.start_time
        now = time.time() - start
        gap = t - now
        # we are replaying strictly in order, so it is safe to sleep
        # in the main process if the gap is big enough. This reduces
        # the number of concurrent threads, which allows us to make
        # larger loads.
        # NOTE: this is currently disabled by the `and False`.
        if gap > 0.15 and False:
            print("sleeping for %f in main process" % (gap - 0.1),
                  file=sys.stderr)
            time.sleep(gap - 0.1)
            now = time.time() - start
            gap = t - now
            print("gap is now %f" % gap, file=sys.stderr)

        self.conversation_id = context.next_conversation_id()
        pid = os.fork()
        if pid != 0:
            return pid
        pid = os.getpid()
        signal.signal(signal.SIGTERM, signal_handler)
        # we must never return, or we'll end up running parts of the
        # parent's clean-up code. So we work in a try...finally, and
        # try to print any exceptions.

        try:
            context.generate_process_local_config(account, self)
            sys.stdin.close()
            os.close(0)
            filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
                                    self.conversation_id)
            sys.stdout.close()
            sys.stdout = open(filename, 'w')

            sleep_time = gap - SLEEP_OVERHEAD
            if sleep_time > 0:
                time.sleep(sleep_time)

            miss = t - (time.time() - start)
            self.msg("starting %s [miss %.3f pid %d]" % (self, miss, pid))
            self.replay(context)
        except Exception:
            print(("EXCEPTION in child PID %d, conversation %s" % (pid, self)),
                  file=sys.stderr)
            # print_exc's first positional parameter is `limit`, so the
            # stream must be passed by keyword.
            traceback.print_exc(file=sys.stderr)
        finally:
            sys.stderr.close()
            sys.stdout.close()
            os._exit(0)

    def replay(self, context=None):
        """Play each packet at its timestamp, measured from when this method
        was called. With no context, only log each packet and how late it
        would have been sent."""
        start = time.time()

        for p in self.packets:
            now = time.time() - start
            gap = p.timestamp - now
            sleep_time = gap - SLEEP_OVERHEAD
            if sleep_time > 0:
                time.sleep(sleep_time)

            miss = p.timestamp - (time.time() - start)
            if context is None:
                self.msg("packet %s [miss %.3f pid %d]" % (p, miss,
                                                           os.getpid()))
                continue
            p.play(self, context)

    def guess_client_server(self, server_clue=None):
        """Have a go at deciding who is the server and who is the client.

        Returns (client, server). A negative client_balance means
        endpoints[0] is the client. With a zero balance, server_clue
        matching endpoints[1] decides the same way. Otherwise endpoints[1]
        is taken as the client.
        """
        a, b = self.endpoints

        if self.client_balance < 0:
            return (a, b)

        # in the absence of a clue, we will fall through to assuming
        # the lowest number is the server (which is usually true).

        if self.client_balance == 0 and server_clue == b:
            return (a, b)

        return (b, a)

    def forget_packets_outside_window(self, s, e):
        """Prune any packets outside the time window we're interested in.

        start_time becomes the first remaining packet's timestamp, or None
        if no packets remain.

        :param s: start of the window (inclusive)
        :param e: end of the window (inclusive)
        """
        self.packets = [p for p in self.packets if s <= p.timestamp <= e]
        if self.packets:
            self.start_time = self.packets[0].timestamp
        else:
            self.start_time = None

    def renormalise_times(self, start_time):
        """Adjust the packet start times relative to the new start time."""
        for p in self.packets:
            p.timestamp -= start_time

        if self.start_time is not None:
            self.start_time -= start_time
968
969
class DnsHammer(Conversation):
    """A lightweight pseudo-conversation that sends many dns:0 packets.

    Rather than replaying recorded packets, it issues int(dns_rate *
    duration) DNS queries at times drawn uniformly at random from
    [0, duration] seconds, printing a tab-separated timing line for each.
    """

    def __init__(self, dns_rate, duration):
        # Conversation.__init__ is not called; the attributes needed here
        # (including start_time and msg) are set directly.
        count = int(dns_rate * duration)
        self.times = sorted(random.uniform(0, duration)
                            for _ in range(count))
        self.rate = dns_rate
        self.duration = duration
        self.start_time = 0
        self.msg = random_colour_print()

    def __str__(self):
        template = "<DnsHammer %d packets over %.1fs (rate %.2f)>"
        return template % (len(self.times), self.duration, self.rate)

    def replay_in_fork_with_delay(self, start, context=None, account=None):
        # Delegates unchanged to the Conversation implementation.
        return Conversation.replay_in_fork_with_delay(
            self, start, context, account)

    def replay(self, context=None):
        """Send a DNS query at each scheduled time, measured from now.

        Without a context, only log each scheduled time and how late it ran.
        """
        t0 = time.time()
        send_query = traffic_packets.packet_dns_0
        for when in self.times:
            delay = when - (time.time() - t0) - SLEEP_OVERHEAD
            if delay > 0:
                time.sleep(delay)

            if context is None:
                late = when - (time.time() - t0)
                self.msg("packet %s [miss %.3f pid %d]" %
                         (when, late, os.getpid()))
                continue

            began = time.time()
            try:
                send_query(self, self, context)
                finished = time.time()
                print("%f\tDNS\tdns\t0\t%f\tTrue\t" %
                      (finished, finished - began))
            except Exception as e:
                finished = time.time()
                print("%f\tDNS\tdns\t0\t%f\tFalse\t%s" %
                      (finished, finished - began, e))
1019
1020
def ingest_summaries(files, dns_mode='count'):
    """Load traffic summary files and gather their packets into Conversations.

    :param files: an iterable of filenames or open file objects. Each is
        read line by line (one Packet per line) and then closed.
    :param dns_mode: unless this is 'include', DNS packets are not added to
        conversations but only counted per opcode.
    :return: a tuple (conversations, mean_interval, duration, dns_counts):
        the non-empty Conversations in order of first appearance; the number
        of conversations (including empty ones) per second of capture; the
        capture duration in seconds; and a defaultdict of DNS opcode counts.
        If no packets are kept, returns ([], 0, 0.0, dns_counts) so the
        result always unpacks into four values.
    """
    dns_counts = defaultdict(int)
    packets = []
    for f in files:
        if isinstance(f, str):
            f = open(f)
        try:
            print("Ingesting %s" % (f.name,), file=sys.stderr)
            for line in f:
                p = Packet(line)
                if p.protocol == 'dns' and dns_mode != 'include':
                    dns_counts[p.opcode] += 1
                else:
                    packets.append(p)
        finally:
            f.close()

    if not packets:
        return [], 0, 0.0, dns_counts

    start_time = min(p.timestamp for p in packets)
    last_packet = max(p.timestamp for p in packets)

    print("gathering packets into conversations", file=sys.stderr)
    conversations = OrderedDict()
    for p in packets:
        p.timestamp -= start_time
        c = conversations.get(p.endpoints)
        if c is None:
            c = Conversation()
            conversations[p.endpoints] = c
        c.add_packet(p)

    # We only care about conversations with actual traffic, so we
    # filter out conversations with nothing to say. We do that here,
    # rather than earlier, because those empty packets contain useful
    # hints as to which end of the conversation was the client.
    conversation_list = [c for c in conversations.values() if len(c) != 0]

    # This is obviously not correct, as many conversations will appear
    # to start roughly simultaneously at the beginning of the snapshot.
    # To which we say: oh well, so be it.
    duration = float(last_packet - start_time)
    if duration > 0:
        mean_interval = len(conversations) / duration
    else:
        # Every packet shares one timestamp, so no rate can be measured.
        mean_interval = 0.0

    return conversation_list, mean_interval, duration, dns_counts
1072
1073
def guess_server_address(conversations):
    """Guess the server's address as the most common conversation endpoint.

    :param conversations: objects with an ``endpoints`` pair.
    :return: the single most frequent endpoint address (not an
        (address, count) pair, so it can be compared directly with an
        endpoint), or None if there are no conversations.
    """
    addresses = Counter()
    for c in conversations:
        addresses.update(c.endpoints)
    if not addresses:
        return None
    address, _count = addresses.most_common(1)[0]
    return address
1081
1082
def stringify_keys(x):
    """Return a copy of dict *x* with each tuple key joined by tabs.

    This turns tuple-keyed dicts into string-keyed (JSON-friendly) ones;
    unstringify_keys() performs the reverse mapping.
    """
    return {'\t'.join(key): value for key, value in x.items()}
1089
1090
def unstringify_keys(x):
    """Return a copy of dict *x* whose keys are split on tabs into tuples
    (the inverse of stringify_keys)."""
    return dict((tuple(str(key).split('\t')), value)
                for key, value in x.items())
1097
1098
class TrafficModel(object):
    """An n-gram model of client traffic.

    It is learnt from recorded Conversations (learn), saved to and loaded
    from JSON (save/load), and used to generate synthetic Conversations
    (construct_conversation/generate_conversations).
    """
    def __init__(self, n=3):
        # Maps a tuple of the previous (n - 1) states to a list of the
        # states seen to follow it. A state is a 'protocol:opcode' packet
        # type, a 'wait:N' pause, or NON_PACKET (end of a conversation).
        # Repeats in the list encode frequency.
        self.ngrams = {}
        # Maps a 'protocol:opcode' packet type to the list of extra-data
        # tuples seen with it.
        self.query_details = {}
        self.n = n
        # DNS opcode -> number of packets
        self.dns_opcounts = defaultdict(int)
        self.cumulative_duration = 0.0
        # [number of conversations, seconds they were spread over]
        self.conversation_rate = [0, 1]

    def learn(self, conversations, dns_opcounts=None):
        """Add the client traffic in *conversations* to the model.

        :param conversations: a list of Conversations, taken to be in start
            time order (the first and last set conversation_rate).
        :param dns_opcounts: optional mapping of DNS opcode -> count to add
            to self.dns_opcounts.

        Only packets sent by the guessed client are learnt. A gap of more
        than WAIT_THRESHOLD before a packet is recorded as an extra
        'wait:N' state, and the end of each conversation is recorded as a
        transition to NON_PACKET.
        """
        if dns_opcounts is None:
            dns_opcounts = {}
        cum_duration = 0.0

        server = guess_server_address(conversations)

        for k, v in dns_opcounts.items():
            self.dns_opcounts[k] += v

        if len(conversations) > 1:
            elapsed = (conversations[-1].start_time -
                       conversations[0].start_time)
            self.conversation_rate[0] = len(conversations)
            self.conversation_rate[1] = elapsed

        for c in conversations:
            # NOTE(review): `server` is rebound here, so each conversation's
            # clue is the server guessed for the previous conversation
            # rather than the overall guess above -- confirm this is meant.
            client, server = c.guess_client_server(server)
            cum_duration += c.get_duration()
            key = (NON_PACKET,) * (self.n - 1)
            # Conversation.add_packet makes packet timestamps relative to
            # the conversation start, so gaps are measured afresh for each
            # conversation.
            prev = 0.0
            for p in c:
                if p.src != client:
                    continue

                elapsed = p.timestamp - prev
                prev = p.timestamp
                if elapsed > WAIT_THRESHOLD:
                    # add the wait as an extra state
                    wait = 'wait:%d' % (math.log(max(1.0,
                                                     elapsed * WAIT_SCALE)))
                    self.ngrams.setdefault(key, []).append(wait)
                    key = key[1:] + (wait,)

                short_p = p.as_packet_type()
                self.query_details.setdefault(short_p,
                                              []).append(tuple(p.extra))
                self.ngrams.setdefault(key, []).append(short_p)
                key = key[1:] + (short_p,)

            # Record the end of every conversation, not only the last one,
            # so generated conversations can finish where real ones did.
            self.ngrams.setdefault(key, []).append(NON_PACKET)

        self.cumulative_duration += cum_duration

    def save(self, f):
        """Write the model as JSON to *f*, a filename or writable file.

        Lists with repeated entries are stored as {value: count} dicts;
        ngram keys are tab-joined and empty query details are stored as '-'.
        A file opened here from a filename is closed afterwards.
        """
        ngrams = {}
        for k, v in self.ngrams.items():
            ngrams['\t'.join(k)] = dict(Counter(v))

        query_details = {}
        for k, v in self.query_details.items():
            query_details[k] = dict(Counter('\t'.join(x) if x else '-'
                                            for x in v))

        d = {
            'ngrams': ngrams,
            'query_details': query_details,
            'cumulative_duration': self.cumulative_duration,
            'conversation_rate': self.conversation_rate,
        }
        d['dns'] = self.dns_opcounts

        if isinstance(f, str):
            with open(f, 'w') as fh:
                json.dump(d, fh, indent=2)
        else:
            json.dump(d, f, indent=2)

    def load(self, f):
        """Merge a model saved by save() from *f*, a filename or readable
        file, into this one.

        ngrams, query_details and DNS counts are added to; the duration and
        conversation rate are replaced. A file opened here from a filename
        is closed afterwards.
        """
        if isinstance(f, str):
            with open(f) as fh:
                d = json.load(fh)
        else:
            d = json.load(f)

        for k, v in d['ngrams'].items():
            k = tuple(str(k).split('\t'))
            values = self.ngrams.setdefault(k, [])
            for p, count in v.items():
                values.extend([str(p)] * count)

        for k, v in d['query_details'].items():
            values = self.query_details.setdefault(str(k), [])
            for p, count in v.items():
                if p == '-':
                    values.extend([()] * count)
                else:
                    values.extend([tuple(str(p).split('\t'))] * count)

        if 'dns' in d:
            for k, v in d['dns'].items():
                self.dns_opcounts[k] += v

        self.cumulative_duration = d['cumulative_duration']
        self.conversation_rate = d['conversation_rate']

    def construct_conversation(self, timestamp=0.0, client=2, server=1,
                               hard_stop=None, packet_rate=1):
        """Construct an individual conversation from the model.

        Starting from the all-NON_PACKET key, repeatedly pick a random
        successor state from self.ngrams. Stop when NON_PACKET is chosen,
        when the current key has no recorded successors, or (if hard_stop
        is given) when the next packet would fall after hard_stop.

        A 'wait:N' state advances the clock by
        exp(N + u) / (WAIT_SCALE * packet_rate), u uniform in [0, 1), and
        emits no packet. Any other state advances the clock by
        exp(uniform(*NO_WAIT_LOG_TIME_RANGE)) / packet_rate and is added
        with add_short_packet.

        :param timestamp: start time of the conversation.
        :param client: client endpoint identifier.
        :param server: server endpoint identifier; the conversation's
            endpoints are (server, client).
        """

        c = Conversation(timestamp, (server, client))

        key = (NON_PACKET,) * (self.n - 1)

        while key in self.ngrams:
            p = random.choice(self.ngrams.get(key, NON_PACKET))
            if p == NON_PACKET:
                break
            if p in self.query_details:
                extra = random.choice(self.query_details[p])
            else:
                extra = []

            protocol, opcode = p.split(':', 1)
            if protocol == 'wait':
                log_wait_time = int(opcode) + random.random()
                wait = math.exp(log_wait_time) / (WAIT_SCALE * packet_rate)
                timestamp += wait
            else:
                log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
                wait = math.exp(log_wait) / packet_rate
                timestamp += wait
                if hard_stop is not None and timestamp > hard_stop:
                    break
                c.add_short_packet(timestamp, p, extra)

            key = key[1:] + (p,)

        return c

    def generate_conversations(self, rate, duration, packet_rate=1):
        """Generate a list of conversations from the model.

        :param rate: multiplier applied to the learnt conversation rate.
        :param duration: desired length of the generated traffic, seconds.
        :param packet_rate: passed to construct_conversation, which divides
            inter-packet delays by it.
        :return: the non-empty conversations, sorted by start_time.
        """

        # Simulate over max(rate_t, 2 * duration) seconds, where rate_t is
        # the time span of the learnt conversations, with conversations
        # starting uniformly at random within that period.
        rate_n, rate_t = self.conversation_rate

        duration2 = max(rate_t, duration * 2)
        n = rate * duration2 * rate_n / rate_t

        server = 1
        client = 2

        conversations = []
        end = duration2

        while client < n + 2:
            start = random.uniform(0, duration2)
            c = self.construct_conversation(start,
                                            client,
                                            server,
                                            hard_stop=(duration2 * 5),
                                            packet_rate=packet_rate)

            # NOTE(review): the window below runs from this conversation's
            # own start to `end` and renormalise_times() then shifts by that
            # same start, while add_short_packet stores packet timestamps
            # relative to the conversation start. Confirm the intended
            # window (perhaps [end - duration, end]) before relying on it.
            c.forget_packets_outside_window(start, end)
            c.renormalise_times(start)
            if len(c) != 0:
                conversations.append(c)
            client += 1

        print(("we have %d conversations at rate %f" %
               (len(conversations), rate)), file=sys.stderr)
        # Sort on the float start_time directly, in the same way replay()
        # orders conversations.
        conversations.sort(key=lambda conv: conv.start_time)
        return conversations
1271
1272
# IP protocol field used when synthesising packets, keyed by protocol name.
# The values are presumably hex IP protocol numbers ('06' = TCP, '11' = UDP
# i.e. 17) -- TODO confirm against the summary format.
IP_PROTOCOLS = dict(
    dns='11',
    rpc_netlogon='06',
    kerberos='06',      # ratio 16248:258
    smb='06',
    smb2='06',
    ldap='06',
    cldap='11',
    lsarpc='06',
    samr='06',
    dcerpc='06',
    epm='06',
    drsuapi='06',
    browser='11',
    smb_netlogon='11',
    srvsvc='06',
    nbns='11',
)
1291
# Descriptions for (protocol, opcode) pairs, both given as strings. They are
# used to fill the description field when a packet is synthesised from its
# short 'protocol:opcode' form (Conversation.add_short_packet and
# expand_short_packet look them up with .get(..., ''), so unlisted pairs get
# an empty description).
# NOTE(review): the strings presumably mirror the descriptions found in the
# traffic summaries the model is learnt from -- confirm before editing.
OP_DESCRIPTIONS = {
    ('browser', '0x01'): 'Host Announcement (0x01)',
    ('browser', '0x02'): 'Request Announcement (0x02)',
    ('browser', '0x08'): 'Browser Election Request (0x08)',
    ('browser', '0x09'): 'Get Backup List Request (0x09)',
    ('browser', '0x0c'): 'Domain/Workgroup Announcement (0x0c)',
    ('browser', '0x0f'): 'Local Master Announcement (0x0f)',
    ('cldap', '3'): 'searchRequest',
    ('cldap', '5'): 'searchResDone',
    ('dcerpc', '0'): 'Request',
    ('dcerpc', '11'): 'Bind',
    ('dcerpc', '12'): 'Bind_ack',
    ('dcerpc', '13'): 'Bind_nak',
    ('dcerpc', '14'): 'Alter_context',
    ('dcerpc', '15'): 'Alter_context_resp',
    ('dcerpc', '16'): 'AUTH3',
    ('dcerpc', '2'): 'Response',
    ('dns', '0'): 'query',
    ('dns', '1'): 'response',
    ('drsuapi', '0'): 'DsBind',
    ('drsuapi', '12'): 'DsCrackNames',
    ('drsuapi', '13'): 'DsWriteAccountSpn',
    ('drsuapi', '1'): 'DsUnbind',
    ('drsuapi', '2'): 'DsReplicaSync',
    ('drsuapi', '3'): 'DsGetNCChanges',
    ('drsuapi', '4'): 'DsReplicaUpdateRefs',
    ('epm', '3'): 'Map',
    ('kerberos', ''): '',
    ('ldap', '0'): 'bindRequest',
    ('ldap', '1'): 'bindResponse',
    ('ldap', '2'): 'unbindRequest',
    ('ldap', '3'): 'searchRequest',
    ('ldap', '4'): 'searchResEntry',
    ('ldap', '5'): 'searchResDone',
    ('ldap', ''): '*** Unknown ***',
    ('lsarpc', '14'): 'lsa_LookupNames',
    ('lsarpc', '15'): 'lsa_LookupSids',
    ('lsarpc', '39'): 'lsa_QueryTrustedDomainInfoBySid',
    ('lsarpc', '40'): 'lsa_SetTrustedDomainInfo',
    ('lsarpc', '6'): 'lsa_OpenPolicy',
    ('lsarpc', '76'): 'lsa_LookupSids3',
    ('lsarpc', '77'): 'lsa_LookupNames4',
    ('nbns', '0'): 'query',
    ('nbns', '1'): 'response',
    ('rpc_netlogon', '21'): 'NetrLogonDummyRoutine1',
    ('rpc_netlogon', '26'): 'NetrServerAuthenticate3',
    ('rpc_netlogon', '29'): 'NetrLogonGetDomainInfo',
    ('rpc_netlogon', '30'): 'NetrServerPasswordSet2',
    ('rpc_netlogon', '39'): 'NetrLogonSamLogonEx',
    ('rpc_netlogon', '40'): 'DsrEnumerateDomainTrusts',
    ('rpc_netlogon', '45'): 'NetrLogonSamLogonWithFlags',
    ('rpc_netlogon', '4'): 'NetrServerReqChallenge',
    # (the trailing comma in the next key is harmless: it is a 2-tuple)
    ('samr', '0',): 'Connect',
    ('samr', '16'): 'GetAliasMembership',
    ('samr', '17'): 'LookupNames',
    ('samr', '18'): 'LookupRids',
    ('samr', '19'): 'OpenGroup',
    ('samr', '1'): 'Close',
    ('samr', '25'): 'QueryGroupMember',
    ('samr', '34'): 'OpenUser',
    ('samr', '36'): 'QueryUserInfo',
    ('samr', '39'): 'GetGroupsForUser',
    ('samr', '3'): 'QuerySecurity',
    ('samr', '5'): 'LookupDomain',
    ('samr', '64'): 'Connect5',
    ('samr', '6'): 'EnumDomains',
    ('samr', '7'): 'OpenDomain',
    ('samr', '8'): 'QueryDomainInfo',
    ('smb', '0x04'): 'Close (0x04)',
    ('smb', '0x24'): 'Locking AndX (0x24)',
    ('smb', '0x2e'): 'Read AndX (0x2e)',
    ('smb', '0x32'): 'Trans2 (0x32)',
    ('smb', '0x71'): 'Tree Disconnect (0x71)',
    ('smb', '0x72'): 'Negotiate Protocol (0x72)',
    ('smb', '0x73'): 'Session Setup AndX (0x73)',
    ('smb', '0x74'): 'Logoff AndX (0x74)',
    ('smb', '0x75'): 'Tree Connect AndX (0x75)',
    ('smb', '0xa2'): 'NT Create AndX (0xa2)',
    ('smb2', '0'): 'NegotiateProtocol',
    ('smb2', '11'): 'Ioctl',
    ('smb2', '14'): 'Find',
    ('smb2', '16'): 'GetInfo',
    ('smb2', '18'): 'Break',
    ('smb2', '1'): 'SessionSetup',
    ('smb2', '2'): 'SessionLogoff',
    ('smb2', '3'): 'TreeConnect',
    ('smb2', '4'): 'TreeDisconnect',
    ('smb2', '5'): 'Create',
    ('smb2', '6'): 'Close',
    ('smb2', '8'): 'Read',
    ('smb_netlogon', '0x12'): 'SAM LOGON request from client (0x12)',
    ('smb_netlogon', '0x17'): ('SAM Active Directory Response - '
                               'user unknown (0x17)'),
    ('srvsvc', '16'): 'NetShareGetInfo',
    ('srvsvc', '21'): 'NetSrvGetInfo',
}
1388
1389
def expand_short_packet(p, timestamp, src, dest, extra):
    """Expand a short 'protocol:opcode' packet type into a summary line.

    The result is tab-separated: timestamp, IP protocol (default '06'),
    an empty field, src, dest, protocol, opcode, description (default ''),
    followed by the items of *extra*. Fields are joined as-is, so every
    argument must already be a string.
    """
    protocol, opcode = p.split(':', 1)
    fields = [timestamp,
              IP_PROTOCOLS.get(protocol, '06'),
              '',
              src,
              dest,
              protocol,
              opcode,
              OP_DESCRIPTIONS.get((protocol, opcode), '')]
    fields += extra
    return '\t'.join(fields)
1398
1399
def replay(conversations,
           host=None,
           creds=None,
           lp=None,
           accounts=None,
           dns_rate=0,
           duration=None,
           **kwargs):
    """Replay *conversations* against *host*, each in a forked child.

    :param conversations: Conversations to replay, each paired with one
        entry of *accounts* (zip stops at the shorter of the two).
    :param accounts: per-conversation account details passed to the child.
    :param dns_rate: if non-zero, also run a DnsHammer at this rate.
    :param duration: seconds to run for. By default this is one second
        after the last packet of the last conversation to start.
    Remaining keyword arguments are passed to ReplayContext.

    Children are forked in batches up to two seconds ahead of their start
    times. When time is up (or on an exception), remaining children are
    sent SIGTERM twice and then SIGKILL. Finally the whole process group is
    sent SIGINT to catch any stragglers.
    """
    import errno

    context = ReplayContext(server=host,
                            creds=creds,
                            lp=lp,
                            **kwargs)

    if len(accounts) < len(conversations):
        print(("we have %d accounts but %d conversations" %
               (len(accounts), len(conversations))), file=sys.stderr)

    cstack = list(zip(
        sorted(conversations, key=lambda x: x.start_time, reverse=True),
        accounts))

    # Set the process group so that the calling scripts are not killed
    # when the forked child processes are killed.
    os.setpgrp()

    start = time.time()

    if duration is None:
        # End 1 second after the last packet of the last conversation to
        # start. Packet timestamps are relative to their conversation's
        # start_time, which is added to get a time relative to `start`.
        # Conversations other than the last could still be going, but we
        # don't care.
        last = cstack[0][0]
        duration = last.start_time + last.packets[-1].timestamp + 1.0
        print("We will stop after %.1f seconds" % duration,
              file=sys.stderr)

    end = start + duration

    print("Replaying traffic for %u conversations over %d seconds"
          % (len(conversations), duration))

    children = {}
    if dns_rate:
        dns_hammer = DnsHammer(dns_rate, duration)
        cstack.append((dns_hammer, None))

    try:
        while True:
            # we spawn a batch, wait for finishers, then spawn another
            now = time.time()
            batch_end = min(now + 2.0, end)
            fork_time = 0.0
            fork_n = 0
            while cstack:
                c, account = cstack.pop()
                if c.start_time + start > batch_end:
                    cstack.append((c, account))
                    break

                st = time.time()
                pid = c.replay_in_fork_with_delay(start, context, account)
                children[pid] = c
                t = time.time()
                elapsed = t - st
                fork_time += elapsed
                fork_n += 1
                print("forked %s in pid %s (in %fs)" % (c, pid,
                                                        elapsed),
                      file=sys.stderr)

            if fork_n:
                print(("forked %d times in %f seconds (avg %f)" %
                       (fork_n, fork_time, fork_time / fork_n)),
                      file=sys.stderr)
            elif cstack:
                debug(2, "no forks in batch ending %f" % batch_end)

            while time.time() < batch_end - 1.0:
                time.sleep(0.01)
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != errno.ECHILD:  # no child processes
                        raise
                    break
                if pid:
                    c = children.pop(pid, None)
                    print(("process %d finished conversation %s;"
                           " %d to go" %
                           (pid, c, len(children))), file=sys.stderr)

            if time.time() >= end:
                print("time to stop", file=sys.stderr)
                break

    except Exception:
        print("EXCEPTION in parent", file=sys.stderr)
        traceback.print_exc()
    finally:
        for s in (signal.SIGTERM, signal.SIGTERM, signal.SIGKILL):
            print(("killing %d children with -%d" %
                   (len(children), s)), file=sys.stderr)
            for pid in children:
                try:
                    os.kill(pid, s)
                except OSError as e:
                    # don't fail if it has already died
                    if e.errno != errno.ESRCH:
                        raise
            time.sleep(0.5)
            end = time.time() + 1
            while children:
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != errno.ECHILD:
                        raise
                    # No children are left to reap; stop instead of
                    # treating a stale pid as a freshly killed child.
                    break
                if pid != 0:
                    c = children.pop(pid, None)
                    print(("kill -%d %d KILLED conversation %s; "
                           "%d to go" %
                           (s, pid, c, len(children))),
                          file=sys.stderr)
                else:
                    # nothing reaped yet; avoid spinning the CPU
                    time.sleep(0.01)
                if time.time() >= end:
                    break

            if not children:
                break
            time.sleep(1)

        if children:
            print("%d children are missing" % len(children),
                  file=sys.stderr)

        # there may be stragglers that were forked just as ^C was hit
        # and don't appear in the list of children. We can get them
        # with killpg, but that will also kill us, so this is^H^H would be
        # goodbye, except we cheat and pretend to use ^C (SIGINT),
        # so as not to have to fuss around writing signal handlers.
        try:
            os.killpg(0, signal.SIGINT)
        except KeyboardInterrupt:
            print("ignoring fake ^C", file=sys.stderr)
1542
1543
def openLdb(host, creds, lp):
    """Return a SamDB connected to ldap://<host> with a system session,
    authenticating with *creds* and using loadparm *lp*."""
    session = system_session()
    samdb = SamDB(url="ldap://%s" % host,
                  session_info=session,
                  credentials=creds,
                  lp=lp)
    return samdb
1551
1552
def ou_name(ldb, instance_id):
    """Return the DN of the OU for *instance_id*.

    It is ``ou=instance-<id>`` under ``ou=traffic_replay`` in the domain
    partition reported by ``ldb.domain_dn()``.
    """
    parent = "ou=traffic_replay,%s" % ldb.domain_dn()
    return "ou=instance-%d,%s" % (instance_id, parent)
1557
1558
def create_ou(ldb, instance_id):
    """Create an ou, all created user and machine accounts will belong to it.

    This allows all the created resources to be cleaned up easily. The
    parent OU (ou=traffic_replay,...) is created first, then the instance
    OU. Either may already exist. Returns the instance OU's DN.
    """
    ou = ou_name(ldb, instance_id)
    for dn in (ou.split(',', 1)[1], ou):
        try:
            ldb.add({"dn":          dn,
                     "objectclass": "organizationalunit"})
        except LdbError as e:
            # e.args is (status, message); unpacking the exception object
            # itself only works on Python 2.
            status = e.args[0]
            # ignore already exists (LDB_ERR_ENTRY_ALREADY_EXISTS)
            if status != 68:
                raise
    return ou
1582
1583
class ConversationAccounts(object):
    """Machine and user credentials used by one replayed conversation."""

    def __init__(self, netbios_name, machinepass, username, userpass):
        # machine account name and password
        self.netbios_name, self.machinepass = netbios_name, machinepass
        # user account name and password
        self.username, self.userpass = username, userpass
1592
1593
def generate_replay_accounts(ldb, instance_id, number, password):
    """Make sure *number* traffic accounts exist and describe them.

    Returns a list of ConversationAccounts for machines STGM-<id>-<i> and
    users STGU-<id>-<i>, i = 1..number, all using *password*.
    """
    generate_traffic_accounts(ldb, instance_id, number, password)
    return [ConversationAccounts("STGM-%d-%d" % (instance_id, i),
                                 password,
                                 "STGU-%d-%d" % (instance_id, i),
                                 password)
            for i in range(1, number + 1)]
1607
1608
def _create_numbered_accounts(create, ldb, instance_id, number, password,
                              name_format):
    """Call create(ldb, instance_id, name, password) for names
    name_format % (instance_id, i), i from number down to 1, stopping at the
    first that already exists; return how many were created."""
    added = 0
    for i in range(number, 0, -1):
        name = name_format % (instance_id, i)
        try:
            create(ldb, instance_id, name, password)
        except LdbError as e:
            # e.args is (status, message); unpacking the exception object
            # itself only works on Python 2.
            if e.args[0] == 68:  # LDB_ERR_ENTRY_ALREADY_EXISTS
                break
            raise
        added += 1
    return added


def generate_traffic_accounts(ldb, instance_id, number, password):
    """Create the specified number of user and machine accounts.

    Accounts are not explicitly deleted between runs, so for each kind of
    account this starts with the highest-numbered one and iterates
    backwards. It stops either when it finds an account that already exists
    (on the assumption that the lower-numbered ones do too) or when it has
    generated all the required accounts.
    """
    print(("Generating machine and conversation accounts, "
           "as required for %d conversations" % number),
          file=sys.stderr)

    added = _create_numbered_accounts(create_machine_account, ldb,
                                      instance_id, number, password,
                                      "STGM-%d-%d")
    if added > 0:
        print("Added %d new machine accounts" % added,
              file=sys.stderr)

    added = _create_numbered_accounts(create_user_account, ldb,
                                      instance_id, number, password,
                                      "STGU-%d-%d")
    if added > 0:
        print("Added %d new user accounts" % added,
              file=sys.stderr)
1652
1653
def create_machine_account(ldb, instance_id, netbios_name, machinepass):
    """Create a machine account via ldap.

    Adds a ``computer`` object ``cn=<netbios_name>`` in the instance's OU,
    with sAMAccountName ``<netbios_name>$``, the workstation-trust and
    password-not-required account-control flags, and ``machinepass`` as its
    password. On success a tab-separated timing line is printed to stdout.

    Errors from ``ldb.add`` propagate to the caller as LdbError.
    """

    ou = ou_name(ldb, instance_id)
    dn = "cn=%s,%s" % (netbios_name, ou)
    # unicodePwd is written as the double-quoted password encoded as
    # UTF-16-LE. Decode byte strings first so this works on Python 2 and 3
    # alike (the ``unicode`` builtin does not exist on Python 3).
    if isinstance(machinepass, bytes):
        machinepass = machinepass.decode('utf-8')
    utf16pw = (u'"%s"' % machinepass).encode('utf-16-le')
    start = time.time()
    ldb.add({
        "dn": dn,
        "objectclass": "computer",
        "sAMAccountName": "%s$" % netbios_name,
        "userAccountControl":
        str(UF_WORKSTATION_TRUST_ACCOUNT | UF_PASSWD_NOTREQD),
        "unicodePwd": utf16pw})
    end = time.time()
    duration = end - start
    print("%f\t0\tcreate\tmachine\t%f\tTrue\t" % (end, duration))
1673
1674
def create_user_account(ldb, instance_id, username, userpass):
    """Create a user account via ldap.

    Adds a ``user`` object ``cn=<username>`` in the instance's OU, with
    sAMAccountName ``username``, the normal-account control flag and
    ``userpass`` as its password. On success a tab-separated timing line is
    printed to stdout.

    Errors from ``ldb.add`` propagate to the caller as LdbError.
    """
    ou = ou_name(ldb, instance_id)
    user_dn = "cn=%s,%s" % (username, ou)
    # unicodePwd is written as the double-quoted password encoded as
    # UTF-16-LE. Decode byte strings first so this works on Python 2 and 3
    # alike (the ``unicode`` builtin does not exist on Python 3).
    if isinstance(userpass, bytes):
        userpass = userpass.decode('utf-8')
    utf16pw = (u'"%s"' % userpass).encode('utf-16-le')
    start = time.time()
    ldb.add({
        "dn": user_dn,
        "objectclass": "user",
        "sAMAccountName": username,
        "userAccountControl": str(UF_NORMAL_ACCOUNT),
        "unicodePwd": utf16pw
    })
    end = time.time()
    duration = end - start
    print("%f\t0\tcreate\tuser\t%f\tTrue\t" % (end, duration))
1693
1694
def create_group(ldb, instance_id, name):
    """Create a ``group`` object named ``cn=<name>`` in the instance's OU.

    The add is timed and a tab-separated timing line is printed to stdout;
    any error raised by ``ldb.add`` is left to the caller.
    """
    group_dn = "cn=%s,%s" % (name, ou_name(ldb, instance_id))
    started = time.time()
    ldb.add({"dn": group_dn, "objectclass": "group"})
    finished = time.time()
    print("%f\t0\tcreate\tgroup\t%f\tTrue\t" % (finished, finished - started))
1708
1709
def user_name(instance_id, i):
    """Return the name of user number ``i`` for this instance (STGU-<id>-<i>)."""
    prefix = "STGU"
    return "%s-%d-%d" % (prefix, instance_id, i)
1713
1714
def generate_users(ldb, instance_id, number, password):
    """Add users to the server.

    Users are created from number ``number`` down to 1. Creation stops at
    the first user that already exists (LdbError status 68), on the
    assumption that it and all lower-numbered users were created by an
    earlier run. Any other LdbError is re-raised.

    Returns the number of users actually created.
    """
    users = 0
    for i in range(number, 0, -1):
        try:
            username = user_name(instance_id, i)
            create_user_account(ldb, instance_id, username, password)
            users += 1
        except LdbError as e:
            # The (status, message) pair lives in e.args; unpacking the
            # exception object itself only works on Python 2.
            (status, _) = e.args
            # Stop if entry exists
            if status == 68:
                break
            else:
                raise

    return users
1732
1733
def group_name(instance_id, i):
    """Return the name of group number ``i`` for this instance (STGG-<id>-<i>)."""
    prefix = "STGG"
    return "%s-%d-%d" % (prefix, instance_id, i)
1737
1738
def generate_groups(ldb, instance_id, number):
    """Create the required number of groups on the server.

    Groups are created from number ``number`` down to 1. Creation stops at
    the first group that already exists (LdbError status 68). Any other
    LdbError is re-raised.

    Returns the number of groups actually created.
    """
    groups = 0
    for i in range(number, 0, -1):
        try:
            name = group_name(instance_id, i)
            create_group(ldb, instance_id, name)
            groups += 1
        except LdbError as e:
            # The (status, message) pair lives in e.args; unpacking the
            # exception object itself only works on Python 2.
            (status, _) = e.args
            # Stop if entry exists
            if status == 68:
                break
            else:
                raise
    return groups
1755
1756
def clean_up_accounts(ldb, instance_id):
    """Remove the created accounts and groups from the server.

    Deletes the instance's OU together with everything below it (using the
    ``tree_delete:1`` control). A missing OU (LdbError status 32) is
    ignored; any other LdbError is re-raised.
    """
    ou = ou_name(ldb, instance_id)
    try:
        ldb.delete(ou, ["tree_delete:1"])
    except LdbError as e:
        # The (status, message) pair lives in e.args; unpacking the
        # exception object itself only works on Python 2.
        (status, _) = e.args
        # ignore does not exist
        if status != 32:
            raise
1767
1768
def generate_users_and_groups(ldb, instance_id, password,
                              number_of_users, number_of_groups,
                              group_memberships):
    """Generate the required users and groups, allocating the users to
       those groups.

    Creates the instance's OU, then up to ``number_of_users`` users and
    ``number_of_groups`` groups (skipping those left by an earlier run).
    It then adds roughly ``group_memberships`` memberships, scaled by the
    fraction of users that were newly created. Progress and a summary are
    written to stderr.
    """
    assignments = set()
    groups_added = 0

    create_ou(ldb, instance_id)

    print("Generating dummy user accounts", file=sys.stderr)
    users_added = generate_users(ldb, instance_id, number_of_users, password)

    if number_of_groups > 0:
        print("Generating dummy groups", file=sys.stderr)
        groups_added = generate_groups(ldb, instance_id, number_of_groups)

    # Memberships need at least one user and one group to pair up;
    # assign_groups would otherwise divide by zero or pick from an empty
    # range.
    if group_memberships > 0 and number_of_users > 0 and number_of_groups > 0:
        print("Assigning users to groups", file=sys.stderr)
        assignments = assign_groups(number_of_groups,
                                    groups_added,
                                    number_of_users,
                                    users_added,
                                    group_memberships)
        print("Adding users to groups", file=sys.stderr)
        add_users_to_groups(ldb, instance_id, assignments)

    if (groups_added > 0 and users_added == 0 and
       number_of_groups != groups_added):
        print("Warning: the added groups will contain no members",
              file=sys.stderr)

    print(("Added %d users, %d groups and %d group memberships" %
           (users_added, groups_added, len(assignments))),
          file=sys.stderr)
1804
1805
def assign_groups(number_of_groups,
                  groups_added,
                  number_of_users,
                  users_added,
                  group_memberships):
    """Allocate users to groups.

    The intention is to have a few users that belong to most groups, while
    the majority of users belong to a few groups.

    A few groups will contain most users, with the remaining only having a
    few users.

    Users and groups are numbered from 1. The highest-numbered
    ``users_added`` users and ``groups_added`` groups are the new ones, and
    a pairing is only generated when the user or the group is new. The
    requested ``group_memberships`` is scaled by the fraction of users that
    are new, and capped at the number of eligible pairs so the selection
    loop always terminates.

    Returns a set of ``(user, group)`` number tuples. The set is empty when
    no memberships are requested or there are no users or groups.
    """

    def generate_user_distribution(n):
        """Probability distribution of a user belonging to a group.
        """
        dist = []
        for x in range(1, n + 1):
            p = 1 / (x + 0.001)
            dist.append(p)
        return dist

    def generate_group_distribution(n):
        """Probability distribution of a group containing a user."""
        dist = []
        for x in range(1, n + 1):
            p = 1 / (x**1.3)
            dist.append(p)
        return dist

    assignments = set()
    # Nothing to pair (and the maths below would divide by zero or draw
    # from an empty range).
    if group_memberships <= 0 or number_of_users <= 0 or number_of_groups <= 0:
        return assignments

    group_dist = generate_group_distribution(number_of_groups)
    user_dist = generate_user_distribution(number_of_users)

    # Calculate the number of group memberships required
    group_memberships = math.ceil(
        float(group_memberships) *
        (float(users_added) / float(number_of_users)))

    # Highest 0-based index of a pre-existing user/group; anything above is
    # new.
    existing_users = number_of_users - users_added - 1
    existing_groups = number_of_groups - groups_added - 1

    # Only pairs where the user or the group is new are eligible. Never ask
    # for more unique pairs than exist, or the loop below would spin
    # forever.
    old_users = max(number_of_users - users_added, 0)
    old_groups = max(number_of_groups - groups_added, 0)
    eligible = number_of_users * number_of_groups - old_users * old_groups
    group_memberships = min(group_memberships, eligible)

    while len(assignments) < group_memberships:
        user = random.randint(0, number_of_users - 1)
        group = random.randint(0, number_of_groups - 1)
        probability = group_dist[group] * user_dist[user]

        if ((random.random() < probability * 10000) and
           (group > existing_groups or user > existing_users)):
            # the + 1 converts the array index to the corresponding
            # group or user number
            assignments.add(((user + 1), (group + 1)))

    return assignments
1863
1864
def add_users_to_groups(db, instance_id, assignments):
    """Add users to their assigned groups.

    Takes the collection of ``(user, group)`` number tuples generated by
    assign_groups and adds each user's DN to the ``member`` attribute of the
    corresponding group in the instance's OU. A tab-separated timing line is
    printed to stdout for every modify.

    ``db`` is the database connection (presumably a SamDB); inside this
    function the name ``ldb`` refers to the ldb module, which is used to
    build the modify message.
    """

    ou = ou_name(db, instance_id)

    def build_dn(name):
        return "cn=%s,%s" % (name, ou)

    for (user, group) in assignments:
        user_dn = build_dn(user_name(instance_id, user))
        group_dn = build_dn(group_name(instance_id, group))

        m = ldb.Message()
        m.dn = ldb.Dn(db, group_dn)
        m["member"] = ldb.MessageElement(user_dn, ldb.FLAG_MOD_ADD, "member")
        start = time.time()
        db.modify(m)
        end = time.time()
        duration = end - start
        print("%f\t0\tadd\tuser\t%f\tTrue\t" % (end, duration))
1888
1889
def generate_stats(statsdir, timing_file):
    """Generate and print the summary stats for a run.

    Every file in ``statsdir`` is read line by line. A valid line has at
    least six tab-separated fields: end time, conversation, protocol, packet
    type, duration (latency) and a success flag (``'True'`` on success).
    Lines that cannot be parsed are echoed to stderr and otherwise ignored.

    :param statsdir: directory containing the stats files to summarise.
    :param timing_file: writable file object, or None. When given, a header
        line and then every valid input line are copied to it.

    Prints an overall summary followed by per-protocol / per-op-code latency
    statistics to stdout. The layout is human-readable when stdout is a
    terminal and tab-separated otherwise.
    """
    first = sys.float_info.max
    last = 0
    successful = 0
    failed = 0
    latencies = {}
    failures = {}
    unique_conversations = set()

    if timing_file is not None:
        tw = timing_file.write
    else:
        def tw(x):
            pass

    tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n")

    for filename in os.listdir(statsdir):
        path = os.path.join(statsdir, filename)
        with open(path, 'r') as f:
            for line in f:
                # Parse every field before touching any accumulator, so a
                # malformed line cannot leave the stats half-updated.
                try:
                    fields = line.rstrip('\n').split('\t')
                    end_time = float(fields[0])
                    conversation = fields[1]
                    protocol = fields[2]
                    packet_type = fields[3]
                    latency = float(fields[4])
                    succeeded = fields[5] == 'True'
                except (ValueError, IndexError):
                    # not a valid line print and ignore
                    print(line, file=sys.stderr)
                    continue

                first = min(end_time - latency, first)
                last = max(end_time, last)

                latencies.setdefault(protocol, {}).setdefault(
                    packet_type, []).append(latency)
                op_failures = failures.setdefault(protocol, {})
                op_failures.setdefault(packet_type, 0)

                if succeeded:
                    successful += 1
                else:
                    failed += 1
                    op_failures[packet_type] += 1

                unique_conversations.add(conversation)
                tw(line)

    conversations = len(unique_conversations)
    duration = last - first
    # A zero (or, with no input, negative) span would make the rates
    # meaningless or raise ZeroDivisionError.
    if successful == 0 or duration <= 0:
        success_rate = 0
    else:
        success_rate = successful / duration
    if failed == 0 or duration <= 0:
        failure_rate = 0
    else:
        failure_rate = failed / duration

    # print the stats in more human-readable format when stdout is going to
    # the console (as opposed to being redirected to a file)
    human_readable = sys.stdout.isatty()
    if human_readable:
        print("Total conversations:   %10d" % conversations)
        print("Successful operations: %10d (%.3f per second)"
              % (successful, success_rate))
        print("Failed operations:     %10d (%.3f per second)"
              % (failed, failure_rate))
        print("Protocol    Op Code  Description                               "
              " Count       Failed         Mean       Median          "
              "95%        Range          Max")
        row_format = ("%-12s   %4s  %-35s %12d %12d %12.6f "
                      "%12.6f %12.6f %12.6f %12.6f")
    else:
        print("(%d, %d, %d, %.3f, %.3f)" %
              (conversations, successful, failed, success_rate, failure_rate))
        print("proto\top_code\tdesc\tcount\tfailed\tmean\tmedian\t95%\trange"
              "\tmax")
        row_format = "%s\t%s\t%s\t%d\t%d\t%f\t%f\t%f\t%f\t%f"

    for protocol in sorted(latencies):
        for packet_type in sorted(latencies[protocol], key=opcode_key):
            values = sorted(latencies[protocol][packet_type])
            count = len(values)
            op_failed = failures[protocol][packet_type]
            mean = sum(values) / count
            median = calc_percentile(values, 0.50)
            percentile = calc_percentile(values, 0.95)
            rng = values[-1] - values[0]
            maxv = values[-1]
            desc = OP_DESCRIPTIONS.get((protocol, packet_type), '')
            print(row_format % (protocol,
                                packet_type,
                                desc,
                                count,
                                op_failed,
                                mean,
                                median,
                                percentile,
                                rng,
                                maxv))
2017
2018
def opcode_key(v):
    """Sort key for the operation code to ensure that it sorts numerically.

    Integer op codes are zero-padded to three digits, so "2" sorts before
    "10". Anything that cannot be converted to an int is returned unchanged.
    """
    try:
        return "%03d" % int(v)
    except (ValueError, TypeError, OverflowError):
        # non-numeric op code: sort it by its own value
        return v
2025
2026
def calc_percentile(values, percentile):
    """Return ``percentile`` (a fraction, e.g. 0.95) of a sorted list.

    ``values`` must already be in ascending order. The rank
    ``(len(values) - 1) * percentile`` is located in the list. When it falls
    between two entries, the result is linearly interpolated between them.
    An empty list yields 0.
    """

    if not values:
        return 0
    rank = (len(values) - 1) * percentile
    below = math.floor(rank)
    above = math.ceil(rank)
    if below != above:
        lower_part = values[int(below)] * (above - rank)
        upper_part = values[int(above)] * (rank - below)
        return lower_part + upper_part
    return values[int(rank)]
2043
2044
def mk_masked_dir(*path):
    """Create a directory that only its owner can access.

    In a testenv we end up with 0777 directories that look an alarming
    green colour with ls. Use umask to avoid that.

    :param path: path components, joined with os.path.join.
    :return: the path of the newly created directory.
    :raises OSError: if the directory cannot be created (e.g. it already
        exists). The caller's umask is restored in every case.
    """
    d = os.path.join(*path)
    # umask is process-wide state, so always put the caller's value back,
    # even when mkdir fails.
    mask = os.umask(0o077)
    try:
        os.mkdir(d)
    finally:
        os.umask(mask)
    return d