emulate/traffic: apply new logger to replace print
[samba.git] / python / samba / emulate / traffic.py
1 # -*- encoding: utf-8 -*-
2 # Samba traffic replay and learning
3 #
4 # Copyright (C) Catalyst IT Ltd. 2017
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 #
19 from __future__ import print_function, division
20
21 import time
22 import os
23 import random
24 import json
25 import math
26 import sys
27 import signal
28 import itertools
29
30 from collections import OrderedDict, Counter, defaultdict
31 from samba.emulate import traffic_packets
32 from samba.samdb import SamDB
33 import ldb
34 from ldb import LdbError
35 from samba.dcerpc import ClientConnection
36 from samba.dcerpc import security, drsuapi, lsa
37 from samba.dcerpc import netlogon
38 from samba.dcerpc.netlogon import netr_Authenticator
39 from samba.dcerpc import srvsvc
40 from samba.dcerpc import samr
41 from samba.drs_utils import drs_DsBind
42 import traceback
43 from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
44 from samba.auth import system_session
45 from samba.dsdb import (
46     UF_NORMAL_ACCOUNT,
47     UF_SERVER_TRUST_ACCOUNT,
48     UF_TRUSTED_FOR_DELEGATION
49 )
50 from samba.dcerpc.misc import SEC_CHAN_BDC
51 from samba import gensec
52 from samba import sd_utils
53 from samba.compat import get_string
54 from samba.logger import get_samba_logger
55
# NOTE(review): presumably a per-sleep overhead in seconds; the code that
# uses it is not in this part of the file -- confirm against its usage.
SLEEP_OVERHEAD = 3e-4

# we don't use None, because it complicates [de]serialisation
NON_PACKET = '-'

# Maps (protocol, opcode) to a score. Packet.client_score() returns this
# score (positive) for a matching packet, hinting that the sender is a client.
CLIENT_CLUES = {
    ('dns', '0'): 1.0,      # query
    ('smb', '0x72'): 1.0,   # Negotiate protocol
    ('ldap', '0'): 1.0,     # bind
    ('ldap', '3'): 1.0,     # searchRequest
    ('ldap', '2'): 1.0,     # unbindRequest
    ('cldap', '3'): 1.0,
    ('dcerpc', '11'): 1.0,  # bind
    ('dcerpc', '14'): 1.0,  # Alter_context
    ('nbns', '0'): 1.0,     # query
}

# Maps (protocol, opcode) to a score. Packet.client_score() returns the
# negated score for a matching packet, hinting that the sender is a server.
SERVER_CLUES = {
    ('dns', '1'): 1.0,      # response
    ('ldap', '1'): 1.0,     # bind response
    ('ldap', '4'): 1.0,     # search result
    ('ldap', '5'): 1.0,     # search done
    ('cldap', '5'): 1.0,
    ('dcerpc', '12'): 1.0,  # bind_ack
    ('dcerpc', '13'): 1.0,  # bind_nak
    ('dcerpc', '15'): 1.0,  # Alter_context response
}

# Packet.is_really_a_packet() returns False for packets of these protocols.
SKIPPED_PROTOCOLS = {"smb", "smb2", "browser", "smb_netlogon"}

# NOTE(review): the three wait-related constants below are not used in this
# part of the file; their exact meaning should be confirmed at the use site.
WAIT_SCALE = 10.0
WAIT_THRESHOLD = (1.0 / WAIT_SCALE)
NO_WAIT_LOG_TIME_RANGE = (-10, -3)

# DEBUG_LEVEL can be changed by scripts with -d
# debug() only prints messages whose level is <= this value.
DEBUG_LEVEL = 0

# Module-level logger obtained from samba.logger.get_samba_logger().
LOGGER = get_samba_logger(name=__name__)
94
95
def debug(level, msg, *args):
    """Write a debug message to standard error if the level permits it.

    :param level: Verbosity of this message. It is emitted only when
                  level <= DEBUG_LEVEL (settable by scripts with -d).
    :param msg:   The message text; may contain C-style format specifiers.
    :param args:  Values for the format specifiers. When no values are
                  given the message is printed verbatim, without any
                  %-formatting being applied.
    """
    if level > DEBUG_LEVEL:
        return

    # Only interpolate when arguments were supplied, so that a bare message
    # containing a literal '%' is passed through untouched.
    text = msg % tuple(args) if args else msg
    print(text, file=sys.stderr)
112
113
def debug_lineno(*args):
    """Print an unformatted log message to stderr, tagged with its origin.

    The output starts with the name of the calling function and the line
    number of the call (the line number is wrapped in ANSI colour escapes),
    followed by each argument on its own line and a final empty line.
    Standard error is flushed afterwards.

    :param args: The values to print; each is printed on a separate line.
    """
    # With limit=2 the stack holds [caller, debug_lineno]; entry 0 is the
    # caller. Fields 2 and 1 of the entry are function name and line number.
    caller = traceback.extract_stack(limit=2)[0]
    header = " %s:\033[01;33m%s \033[00m" % (caller[2], caller[1])

    print(header, end=' ', file=sys.stderr)
    for item in args:
        print(item, file=sys.stderr)
    print(file=sys.stderr)
    sys.stderr.flush()
125
126
def random_colour_print():
    """Build a printer that writes lines to stderr in one random colour.

    A colour index in the range 18..231 is picked once, when this function
    is called; every line written by the returned function uses that colour.

    :return: A function taking any number of arguments and printing each
             one on its own line to stderr, wrapped in the colour escape
             sequence and followed by a reset sequence.
    """
    colour = random.randrange(214) + 18
    # Pre-build the whole line template; only the payload slot remains.
    template = "\033[38;5;%dm%%s\033[00m" % colour

    def p(*args):
        for item in args:
            print(template % (item,), file=sys.stderr)

    return p
137
138
class FakePacketError(Exception):
    """Raised when a packet cannot belong to the conversation it is added to.

    Conversation.add_packet() raises this when the endpoints of the packet
    do not match the endpoints of the conversation.
    """
    pass
141
142
class Packet(object):
    """Details of a network packet"""
    def __init__(self, timestamp, ip_protocol, stream_number, src, dest,
                 protocol, opcode, desc, extra):
        """Store the packet fields and derive the endpoint pair.

        :param timestamp:     Time of the packet, as a number.
        :param ip_protocol:   The IP protocol field of the summary line.
        :param stream_number: The stream number field of the summary line.
        :param src:           Source endpoint, comparable with dest.
        :param dest:          Destination endpoint.
        :param protocol:      Protocol name, e.g. 'ldap'.
        :param opcode:        Protocol operation code, as a string.
        :param desc:          Description of the operation.
        :param extra:         List of any additional fields.
        """
        self.timestamp = timestamp
        self.ip_protocol = ip_protocol
        self.stream_number = stream_number
        self.src = src
        self.dest = dest
        self.protocol = protocol
        self.opcode = opcode
        self.desc = desc
        self.extra = extra
        # endpoints is direction independent: always (lower, higher), so
        # both directions of a conversation share the same key.
        if self.src < self.dest:
            self.endpoints = (self.src, self.dest)
        else:
            self.endpoints = (self.dest, self.src)

    @classmethod
    def from_line(cls, line):
        """Alternate constructor: parse a tab separated summary line.

        The first eight fields are timestamp, ip_protocol, stream_number,
        src, dest, protocol, opcode and desc; any further fields are kept
        as the list `extra`.

        :param line: The summary line; a trailing newline is ignored.
        :return: A new instance of the class the method was called on.
        :raises ValueError: if there are fewer than eight fields, or the
                            timestamp, src or dest are not numeric.
        """
        fields = line.rstrip('\n').split('\t')
        (timestamp,
         ip_protocol,
         stream_number,
         src,
         dest,
         protocol,
         opcode,
         desc) = fields[:8]
        extra = fields[8:]

        timestamp = float(timestamp)
        src = int(src)
        dest = int(dest)

        # Use cls rather than a hard-coded Packet so subclasses get
        # instances of their own type.
        return cls(timestamp, ip_protocol, stream_number, src, dest,
                   protocol, opcode, desc, extra)

    def as_summary(self, time_offset=0.0):
        """Format the packet as a traffic_summary line.

        :param time_offset: Value added to the timestamp before formatting.
        :return: A tuple (adjusted timestamp, tab separated summary line).
        """
        extra = '\t'.join(self.extra)
        t = self.timestamp + time_offset
        return (t, '%f\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' %
                (t,
                 self.ip_protocol,
                 self.stream_number or '',
                 self.src,
                 self.dest,
                 self.protocol,
                 self.opcode,
                 self.desc,
                 extra))

    def __str__(self):
        return ("%.3f: %d -> %d; ip %s; strm %s; prot %s; op %s; desc %s %s" %
                (self.timestamp, self.src, self.dest, self.ip_protocol or '-',
                 self.stream_number, self.protocol, self.opcode, self.desc,
                 ('«' + ' '.join(self.extra) + '»' if self.extra else '')))

    def __repr__(self):
        return "<Packet @%s>" % self

    def copy(self):
        """Return a new packet with the same field values.

        The `extra` list is shared with the original, not duplicated.
        """
        return self.__class__(self.timestamp,
                              self.ip_protocol,
                              self.stream_number,
                              self.src,
                              self.dest,
                              self.protocol,
                              self.opcode,
                              self.desc,
                              self.extra)

    def as_packet_type(self):
        """Return the packet type as the string 'protocol:opcode'."""
        t = '%s:%s' % (self.protocol, self.opcode)
        return t

    def client_score(self):
        """A positive number means we think it is a client; a negative number
        means we think it is a server. Zero means no idea. range: -1 to 1.
        """
        key = (self.protocol, self.opcode)
        if key in CLIENT_CLUES:
            return CLIENT_CLUES[key]
        if key in SERVER_CLUES:
            return -SERVER_CLUES[key]
        return 0.0

    def play(self, conversation, context):
        """Send the packet over the network, if required.

        Some packets are ignored, i.e. for  protocols not handled,
        server response messages, or messages that are generated by the
        protocol layer associated with other packets.

        The handler is looked up in traffic_packets by the name
        'packet_<protocol>_<opcode>'. When the handler returns a true value
        or raises an exception, a tab separated timing record is printed to
        standard output: end time, conversation id, protocol, opcode,
        duration, a True/False success flag and, on failure, the error.

        :param conversation: The conversation this packet belongs to; its
                             conversation_id is used in the output.
        :param context:      The context object passed on to the handler.
        """
        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
        try:
            fn = getattr(traffic_packets, fn_name)

        except AttributeError:
            print("Conversation(%s) Missing handler %s" %
                  (conversation.conversation_id, fn_name),
                  file=sys.stderr)
            return

        # Don't display a message for kerberos packets, they're not directly
        # generated they're used to indicate kerberos should be used
        if self.protocol != "kerberos":
            debug(2, "Conversation(%s) Calling handler %s" %
                     (conversation.conversation_id, fn_name))

        start = time.time()
        try:
            if fn(self, conversation, context):
                # Only collect timing data for functions that generate
                # network traffic, or fail
                end = time.time()
                duration = end - start
                print("%f\t%s\t%s\t%s\t%f\tTrue\t" %
                      (end, conversation.conversation_id, self.protocol,
                       self.opcode, duration))
        except Exception as e:
            # A failing handler must not stop the replay; record the failure
            # in the timing output instead.
            end = time.time()
            duration = end - start
            print("%f\t%s\t%s\t%s\t%f\tFalse\t%s" %
                  (end, conversation.conversation_id, self.protocol,
                   self.opcode, duration, e))

    def __cmp__(self, other):
        """Python 2 ordering by timestamp; returns -1, 0 or 1.

        An integer is returned because a timestamp difference is a float,
        which is not a valid __cmp__ result.
        """
        return ((self.timestamp > other.timestamp) -
                (self.timestamp < other.timestamp))

    def __lt__(self, other):
        """Order packets by timestamp.

        Python 3 ignores __cmp__, so this is what makes sorted(), min()
        and heapq work on packets there.
        """
        return self.timestamp < other.timestamp

    def is_really_a_packet(self, missing_packet_stats=None):
        """Is the packet one that should be kept for the replay?

        Returns False for packets that can be ignored, i.e. removing them
        has no effect on the replay: packets of a skipped protocol, ldap
        continuation packets (empty opcode), packets without a handler in
        traffic_packets, and packets whose handler is the null_packet
        handler. Returns True otherwise.

        :param missing_packet_stats: Accepted but currently not used.
        """
        if self.protocol in SKIPPED_PROTOCOLS:
            # Ignore any packets for the protocols we're not interested in.
            return False
        if self.protocol == "ldap" and self.opcode == '':
            # skip ldap continuation packets
            return False

        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
        fn = getattr(traffic_packets, fn_name, None)
        if not fn:
            print("missing packet %s" % fn_name, file=sys.stderr)
            return False
        if fn is traffic_packets.null_packet:
            return False
        return True
296
297
class ReplayContext(object):
    """State/Context for an individual conversation between a simulated client
       and a server.
    """

    def __init__(self,
                 server=None,
                 lp=None,
                 creds=None,
                 badpassword_frequency=None,
                 prefer_kerberos=None,
                 tempdir=None,
                 statsdir=None,
                 ou=None,
                 base_dn=None,
                 domain=None,
                 domain_sid=None):
        """Store the settings and build the LDAP search tables.

        Note that construction connects to the server (see
        generate_ldap_search_tables) and reads 'realm' from lp, so lp must
        be a usable loadparm context.

        :param server:                The server to connect to.
        :param lp:                    The loadparm context.
        :param creds:                 Credentials used for the initial
                                      search and for samr connections.
        :param badpassword_frequency: Probability (0..1) of preceding a
                                      logon with a bad password attempt.
        :param prefer_kerberos:       If true, generated credentials must
                                      use kerberos, otherwise they don't.
        :param tempdir:               Parent of per conversation temp dirs.
        :param statsdir:              Stored as self.statsdir.
        :param ou:                    The DN under which the accounts live.
        :param base_dn:               Fallback DN for get_matching_dn.
        :param domain:                The domain name; if None the DOMAIN
                                      environment variable is used later.
        :param domain_sid:            Stored as self.domain_sid.
        """

        self.server                   = server
        self.ldap_connections         = []
        self.dcerpc_connections       = []
        self.lsarpc_connections       = []
        self.lsarpc_connections_named = []
        self.drsuapi_connections      = []
        self.srvsvc_connections       = []
        self.samr_contexts            = []
        self.netlogon_connection      = None
        self.creds                    = creds
        self.lp                       = lp
        self.prefer_kerberos          = prefer_kerberos
        self.ou                       = ou
        self.base_dn                  = base_dn
        self.domain                   = domain
        self.statsdir                 = statsdir
        self.global_tempdir           = tempdir
        self.domain_sid               = domain_sid
        self.realm                    = lp.get('realm')

        # Bad password attempt controls
        self.badpassword_frequency    = badpassword_frequency
        self.last_lsarpc_bad          = False
        self.last_lsarpc_named_bad    = False
        self.last_simple_bind_bad     = False
        self.last_bind_bad            = False
        self.last_srvsvc_bad          = False
        self.last_drsuapi_bad         = False
        self.last_netlogon_bad        = False
        self.last_samlogon_bad        = False
        self.generate_ldap_search_tables()
        self.next_conversation_id = itertools.count()

    def generate_ldap_search_tables(self):
        """Search the server and build the tables used to pick DNs.

        Sets self.dn_map, mapping a DN pattern such as 'CN,CN,DC,DC' to the
        list of DNs having that pattern, and self.attribute_clue_map,
        mapping an attribute name to the DNs known to be relevant to it.
        """
        session = system_session()

        db = SamDB(url="ldap://%s" % self.server,
                   session_info=session,
                   credentials=self.creds,
                   lp=self.lp)

        res = db.search(db.domain_dn(),
                        scope=ldb.SCOPE_SUBTREE,
                        controls=["paged_results:1:1000"],
                        attrs=['dn'])

        # find a list of dns for each pattern
        # e.g. CN,CN,CN,DC,DC
        dn_map = {}
        attribute_clue_map = {
            'invocationId': []
        }

        for r in res:
            dn = str(r.dn)
            # The pattern is the first two characters of each DN component.
            pattern = ','.join(x.lstrip()[:2] for x in dn.split(',')).upper()
            dns = dn_map.setdefault(pattern, [])
            dns.append(dn)
            if dn.startswith('CN=NTDS Settings,'):
                attribute_clue_map['invocationId'].append(dn)

        # extend the map in case we are working with a different
        # number of DC components.
        for k in list(dn_map.keys()):
            if k[-3:] != ',DC':
                continue
            # Strip all trailing DC components, then register the pattern
            # again with one to five of them.
            p = k[:-3]
            while p[-3:] == ',DC':
                p = p[:-3]
            for _ in range(5):
                p += ',DC'
                if p != k and p in dn_map:
                    print('dn_map collison %s %s' % (k, p),
                          file=sys.stderr)
                    continue
                dn_map[p] = dn_map[k]

        self.dn_map = dn_map
        self.attribute_clue_map = attribute_clue_map

    def generate_process_local_config(self, account, conversation):
        """Set up the account specific state for one conversation.

        Copies the account names and passwords, creates a temporary
        directory for the conversation and points lp at it, then generates
        the machine and user credentials. Does nothing if account is None.

        :param account:      Object with netbios_name, machinepass, username
                             and userpass attributes, or None.
        :param conversation: The conversation; its conversation_id names
                             the temporary directory.
        :raises KeyError: if no domain was given and the DOMAIN environment
                          variable is not set.
        """
        if account is None:
            return
        self.netbios_name             = account.netbios_name
        self.machinepass              = account.machinepass
        self.username                 = account.username
        self.userpass                 = account.userpass

        self.tempdir = mk_masked_dir(self.global_tempdir,
                                     'conversation-%d' %
                                     conversation.conversation_id)

        self.lp.set("private dir", self.tempdir)
        self.lp.set("lock dir", self.tempdir)
        self.lp.set("state directory", self.tempdir)
        self.lp.set("tls verify peer", "no_check")

        # If the domain was not specified, check for the environment
        # variable.
        if self.domain is None:
            self.domain = os.environ["DOMAIN"]

        self.remoteAddress = "/root/ncalrpc_as_system"
        self.samlogon_dn   = ("cn=%s,%s" %
                              (self.netbios_name, self.ou))
        self.user_dn       = ("cn=%s,%s" %
                              (self.username, self.ou))

        self.generate_machine_creds()
        self.generate_user_creds()

    def with_random_bad_credentials(self, f, good, bad, failed_last_time):
        """Execute the supplied logon function, randomly choosing the
           bad credentials.

           Based on the frequency in badpassword_frequency randomly perform the
           function with the supplied bad credentials.
           If run with bad credentials, the function is re-run with the good
           credentials.
           failed_last_time is used to prevent consecutive bad credential
           attempts. So the over all bad credential frequency will be lower
           than that requested, but not significantly.

           :param f:    The logon function, called with one credentials
                        argument.
           :param good: The credentials expected to succeed.
           :param bad:  The credentials expected to fail.
           :param failed_last_time: True if the previous call for this kind
                        of connection made a bad credentials attempt.
           :return: A tuple (result of f(good), new failed_last_time value).
                    The flag is True only if this call made a bad attempt.
        """
        if failed_last_time:
            # The previous call made a bad attempt, so skip it this time and
            # clear the flag; otherwise it would stay set forever and no
            # further bad attempts would ever be made.
            failed_last_time = False
        elif (self.badpassword_frequency and self.badpassword_frequency > 0
              and random.random() < self.badpassword_frequency):
            try:
                f(bad)
            except Exception:
                # Ignore any exceptions as the operation may fail
                # as it's being performed with bad credentials.
                # KeyboardInterrupt and SystemExit are deliberately not
                # caught, so the process can still be stopped.
                pass
            failed_last_time = True

        result = f(good)
        return (result, failed_last_time)

    def generate_user_creds(self):
        """Generate the conversation specific user Credentials.

        Each Conversation has an associated user account used to simulate
        any non Administrative user traffic.

        Generates user credentials with good and bad passwords and ldap
        simple bind credentials with good and bad passwords.
        The bad passwords are the good ones with the last four characters
        removed.
        """
        self.user_creds = Credentials()
        self.user_creds.guess(self.lp)
        self.user_creds.set_workstation(self.netbios_name)
        self.user_creds.set_password(self.userpass)
        self.user_creds.set_username(self.username)
        self.user_creds.set_domain(self.domain)
        if self.prefer_kerberos:
            self.user_creds.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.user_creds.set_kerberos_state(DONT_USE_KERBEROS)

        self.user_creds_bad = Credentials()
        self.user_creds_bad.guess(self.lp)
        self.user_creds_bad.set_workstation(self.netbios_name)
        self.user_creds_bad.set_password(self.userpass[:-4])
        self.user_creds_bad.set_username(self.username)
        if self.prefer_kerberos:
            self.user_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.user_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)

        # Credentials for ldap simple bind.
        self.simple_bind_creds = Credentials()
        self.simple_bind_creds.guess(self.lp)
        self.simple_bind_creds.set_workstation(self.netbios_name)
        self.simple_bind_creds.set_password(self.userpass)
        self.simple_bind_creds.set_username(self.username)
        self.simple_bind_creds.set_gensec_features(
            self.simple_bind_creds.get_gensec_features() | gensec.FEATURE_SEAL)
        if self.prefer_kerberos:
            self.simple_bind_creds.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.simple_bind_creds.set_kerberos_state(DONT_USE_KERBEROS)
        self.simple_bind_creds.set_bind_dn(self.user_dn)

        self.simple_bind_creds_bad = Credentials()
        self.simple_bind_creds_bad.guess(self.lp)
        self.simple_bind_creds_bad.set_workstation(self.netbios_name)
        self.simple_bind_creds_bad.set_password(self.userpass[:-4])
        self.simple_bind_creds_bad.set_username(self.username)
        self.simple_bind_creds_bad.set_gensec_features(
            self.simple_bind_creds_bad.get_gensec_features() |
            gensec.FEATURE_SEAL)
        if self.prefer_kerberos:
            self.simple_bind_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.simple_bind_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
        self.simple_bind_creds_bad.set_bind_dn(self.user_dn)

    def generate_machine_creds(self):
        """Generate the conversation specific machine Credentials.

        Each Conversation has an associated machine account.

        Generates machine credentials with good and bad passwords.
        The bad password is the good one with the last four characters
        removed.
        """

        self.machine_creds = Credentials()
        self.machine_creds.guess(self.lp)
        self.machine_creds.set_workstation(self.netbios_name)
        self.machine_creds.set_secure_channel_type(SEC_CHAN_BDC)
        self.machine_creds.set_password(self.machinepass)
        self.machine_creds.set_username(self.netbios_name + "$")
        self.machine_creds.set_domain(self.domain)
        if self.prefer_kerberos:
            self.machine_creds.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.machine_creds.set_kerberos_state(DONT_USE_KERBEROS)

        self.machine_creds_bad = Credentials()
        self.machine_creds_bad.guess(self.lp)
        self.machine_creds_bad.set_workstation(self.netbios_name)
        self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_BDC)
        self.machine_creds_bad.set_password(self.machinepass[:-4])
        self.machine_creds_bad.set_username(self.netbios_name + "$")
        if self.prefer_kerberos:
            self.machine_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.machine_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)

    def get_matching_dn(self, pattern, attributes=None):
        """Pick a random DN matching the pattern, or a clue for attributes.

        :param pattern:    A DN pattern such as 'CN,CN,DC,DC'; it is
                           compared case insensitively.
        :param attributes: Key into the attribute clue map; when it has
                           entries one of them is returned instead.
        :return: A DN string, or self.base_dn if nothing matches.
        """
        # If the pattern is an empty string, we assume ROOTDSE,
        # Otherwise we try adding or removing DC suffixes, then
        # shorter leading patterns until we hit one.
        # e.g if there is no CN,CN,CN,CN,DC,DC
        # we first try       CN,CN,CN,CN,DC
        # and                CN,CN,CN,CN,DC,DC,DC
        # then change to        CN,CN,CN,DC,DC
        # and as last resort we use the base_dn
        # (the DC suffix variants were added to dn_map when it was built)
        attr_clue = self.attribute_clue_map.get(attributes)
        if attr_clue:
            return random.choice(attr_clue)

        pattern = pattern.upper()
        while pattern:
            if pattern in self.dn_map:
                return random.choice(self.dn_map[pattern])
            # chop one off the front and try it all again.
            pattern = pattern[3:]

        return self.base_dn

    def get_dcerpc_connection(self, new=False):
        """Return the latest dcerpc connection, creating one if needed.

        :param new: If true always create (and remember) a new connection.
        """
        guid = '12345678-1234-abcd-ef00-01234567cffb'  # RPC_NETLOGON UUID
        if self.dcerpc_connections and not new:
            return self.dcerpc_connections[-1]
        c = ClientConnection("ncacn_ip_tcp:%s" % self.server,
                             (guid, 1), self.lp)
        self.dcerpc_connections.append(c)
        return c

    def get_srvsvc_connection(self, new=False):
        """Return the latest srvsvc connection, creating one if needed.

        New connections use the user credentials, possibly preceded by a
        bad credentials attempt.

        :param new: If true always create (and remember) a new connection.
        """
        if self.srvsvc_connections and not new:
            return self.srvsvc_connections[-1]

        def connect(creds):
            return srvsvc.srvsvc("ncacn_np:%s" % (self.server),
                                 self.lp,
                                 creds)

        (c, self.last_srvsvc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_srvsvc_bad)

        self.srvsvc_connections.append(c)
        return c

    def get_lsarpc_connection(self, new=False):
        """Return the latest lsarpc (tcp) connection, creating one if needed.

        New connections use the machine credentials, possibly preceded by a
        bad credentials attempt.

        :param new: If true always create (and remember) a new connection.
        """
        if self.lsarpc_connections and not new:
            return self.lsarpc_connections[-1]

        def connect(creds):
            binding_options = 'schannel,seal,sign'
            return lsa.lsarpc("ncacn_ip_tcp:%s[%s]" %
                              (self.server, binding_options),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_bad)

        self.lsarpc_connections.append(c)
        return c

    def get_lsarpc_named_pipe_connection(self, new=False):
        """Return the latest lsarpc named pipe connection, creating one if
        needed.

        New connections use the machine credentials, possibly preceded by a
        bad credentials attempt.

        :param new: If true always create (and remember) a new connection.
        """
        if self.lsarpc_connections_named and not new:
            return self.lsarpc_connections_named[-1]

        def connect(creds):
            return lsa.lsarpc("ncacn_np:%s" % (self.server),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_named_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_named_bad)

        self.lsarpc_connections_named.append(c)
        return c

    def get_drsuapi_connection_pair(self, new=False, unbind=False):
        """get a (drs, drs_handle) tuple

        New connections use the user credentials, possibly preceded by a
        bad credentials attempt.

        :param new:    If true always create (and remember) a new pair.
        :param unbind: Accepted but currently not used.
        """
        if self.drsuapi_connections and not new:
            c = self.drsuapi_connections[-1]
            return c

        def connect(creds):
            binding_options = 'seal'
            binding_string = "ncacn_ip_tcp:%s[%s]" %\
                             (self.server, binding_options)
            return drsuapi.drsuapi(binding_string, self.lp, creds)

        (drs, self.last_drsuapi_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_drsuapi_bad)

        # The supported extensions returned by the bind are not needed.
        (drs_handle, _supported_extensions) = drs_DsBind(drs)
        c = (drs, drs_handle)
        self.drsuapi_connections.append(c)
        return c

    def get_ldap_connection(self, new=False, simple=False):
        """Return the latest ldap connection, creating one if needed.

        New connections are possibly preceded by a bad credentials attempt.

        :param new:    If true always create (and remember) a new connection.
        :param simple: If true connect over ldaps with the simple bind
                       credentials, otherwise over ldap with the user
                       credentials.
        """
        if self.ldap_connections and not new:
            return self.ldap_connections[-1]

        def simple_bind(creds):
            """
            To run simple bind against Windows, we need to run
            following commands in PowerShell:

                Install-windowsfeature ADCS-Cert-Authority
                Install-AdcsCertificationAuthority -CAType EnterpriseRootCA
                Restart-Computer

            """
            return SamDB('ldaps://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)

        def sasl_bind(creds):
            return SamDB('ldap://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)
        if simple:
            (samdb, self.last_simple_bind_bad) = \
                self.with_random_bad_credentials(simple_bind,
                                                 self.simple_bind_creds,
                                                 self.simple_bind_creds_bad,
                                                 self.last_simple_bind_bad)
        else:
            (samdb, self.last_bind_bad) = \
                self.with_random_bad_credentials(sasl_bind,
                                                 self.user_creds,
                                                 self.user_creds_bad,
                                                 self.last_bind_bad)

        self.ldap_connections.append(samdb)
        return samdb

    def get_samr_context(self, new=False):
        """Return the latest SamrContext, creating one if needed.

        :param new: If true always create (and remember) a new context.
        """
        if not self.samr_contexts or new:
            self.samr_contexts.append(
                SamrContext(self.server, lp=self.lp, creds=self.creds))
        return self.samr_contexts[-1]

    def get_netlogon_connection(self):
        """Return the netlogon connection, creating it on first use.

        The connection uses the machine credentials, possibly preceded by a
        bad credentials attempt.
        """

        if self.netlogon_connection:
            return self.netlogon_connection

        def connect(creds):
            return netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" %
                                     (self.server),
                                     self.lp,
                                     creds)
        (c, self.last_netlogon_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_netlogon_bad)
        self.netlogon_connection = c
        return c

    def guess_a_dns_lookup(self):
        """Return a (name, record type) pair to look up: the realm's A record.
        """
        return (self.realm, 'A')

    def get_authenticator(self):
        """Return a (current, subsequent) pair of netr_Authenticator objects.

        current is filled in from a new client authenticator of the machine
        credentials; subsequent is left empty.
        """
        auth = self.machine_creds.new_client_authenticator()
        current  = netr_Authenticator()
        # The credential may be a sequence of ints or of one character
        # strings; normalise it to a list of ints.
        current.cred.data = [x if isinstance(x, int) else ord(x)
                             for x in auth["credential"]]
        current.timestamp = auth["timestamp"]

        subsequent = netr_Authenticator()
        return (current, subsequent)
730
731
class SamrContext(object):
    """Lazily created samr connection and the handles obtained through it.
    """
    def __init__(self, server, lp=None, creds=None):
        """Remember the connection settings; nothing is connected yet.

        :param server: The server to connect to.
        :param lp:     The loadparm context passed to the samr connection.
        :param creds:  The credentials passed to the samr connection.
        """
        # Connection settings
        self.server        = server
        self.lp            = lp
        self.creds         = creds

        # State filled in on demand
        self.connection    = None
        self.handle        = None
        self.domain_handle = None
        self.domain_sid    = None
        self.group_handle  = None
        self.user_handle   = None
        self.rids          = None

    def get_connection(self):
        """Return the samr connection, opening it on first use."""
        if self.connection:
            return self.connection

        binding = "ncacn_ip_tcp:%s[seal]" % (self.server)
        self.connection = samr.samr(binding,
                                    lp_ctx=self.lp,
                                    credentials=self.creds)
        return self.connection

    def get_handle(self):
        """Return the connect handle, obtaining it via Connect2 on first use.
        """
        if self.handle:
            return self.handle

        conn = self.get_connection()
        self.handle = conn.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
        return self.handle
761
762
class Conversation(object):
    """Details of a conversation between a simulated client and a server.

    A conversation is a list of packets exchanged between one pair of
    endpoints, with packet timestamps held relative to the start of the
    conversation. Conversations order by start time, with those that
    have no start time sorting first.
    """
    conversation_id = None

    def __init__(self, start_time=None, endpoints=None):
        self.start_time = start_time
        self.endpoints = endpoints
        self.packets = []
        self.msg = random_colour_print()
        # Running score used by guess_client_server(); packets sent by
        # endpoints[0] lower it, packets from the other end raise it.
        self.client_balance = 0.0

    def __cmp__(self, other):
        """Three-way comparison on start_time; None sorts before anything.

        Always returns -1, 0 or 1 (an int), as the Python 2 comparison
        protocol requires.
        """
        if self.start_time is None:
            if other.start_time is None:
                return 0
            return -1
        if other.start_time is None:
            return 1
        mine, theirs = self.start_time, other.start_time
        return (mine > theirs) - (mine < theirs)

    def __lt__(self, other):
        """Ordering hook for Python 3, which ignores __cmp__.

        sort() and sorted() only need __lt__. Equality and hashing are
        deliberately left as the identity-based defaults.
        """
        return self.__cmp__(other) < 0

    def add_packet(self, packet):
        """Add a packet object to this conversation, making a local copy with
        a conversation-relative timestamp.

        The first packet added sets the start time and endpoints if they
        are not already known.

        :param packet: the packet to copy into the conversation
        :raises FakePacketError: if the packet's endpoints differ from
            the conversation's endpoints
        """
        p = packet.copy()

        if self.start_time is None:
            self.start_time = p.timestamp

        if self.endpoints is None:
            self.endpoints = p.endpoints

        if p.endpoints != self.endpoints:
            raise FakePacketError("Conversation endpoints %s don't match "
                                  "packet endpoints %s" %
                                  (self.endpoints, p.endpoints))

        p.timestamp -= self.start_time

        if p.src == p.endpoints[0]:
            self.client_balance -= p.client_score()
        else:
            self.client_balance += p.client_score()

        # The balance is updated for every packet, but only "real"
        # packets are kept for replay.
        if p.is_really_a_packet():
            self.packets.append(p)

    def add_short_packet(self, timestamp, protocol, opcode, extra,
                         client=True):
        """Create a packet from a timestamp, and 'protocol:opcode' pair, and a
        (possibly empty) list of extra data. If client is True, assume
        this packet is from the client to the server.
        """
        src, dest = self.guess_client_server()
        if not client:
            src, dest = dest, src
        desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
        # protocols missing from the table fall back to '06'
        ip_protocol = IP_PROTOCOLS.get(protocol, '06')
        packet = Packet(timestamp - self.start_time, ip_protocol,
                        '', src, dest,
                        protocol, opcode, desc, extra)
        # XXX we're assuming the timestamp is already adjusted for
        # this conversation?
        # XXX should we adjust client balance for guessed packets?
        if packet.src == packet.endpoints[0]:
            self.client_balance -= packet.client_score()
        else:
            self.client_balance += packet.client_score()
        if packet.is_really_a_packet():
            self.packets.append(packet)

    def __str__(self):
        # start_time is None until the first packet arrives (or after
        # every packet has been pruned), and '%.3f' cannot format None.
        if self.start_time is None:
            start = "None"
        else:
            start = "%.3f" % self.start_time
        return ("<Conversation %s %s starting %s %d packets>" %
                (self.conversation_id, self.endpoints, start,
                 len(self.packets)))

    __repr__ = __str__

    def __iter__(self):
        return iter(self.packets)

    def __len__(self):
        return len(self.packets)

    def get_duration(self):
        """Return the time between the first and last stored packets, or 0
        when there are fewer than two packets."""
        if len(self.packets) < 2:
            return 0
        return self.packets[-1].timestamp - self.packets[0].timestamp

    def replay_as_summary_lines(self):
        """Return the summary form of each packet, as produced by
        as_summary(self.start_time) on the packet."""
        return [p.as_summary(self.start_time) for p in self.packets]

    def replay_in_fork_with_delay(self, start, context=None, account=None):
        """Fork a new process and replay the conversation.

        The parent returns the child's pid straight away. The child
        redirects stdout to a per-conversation stats file in
        context.statsdir, sleeps until the conversation is due, replays
        it, and always finishes with os._exit() rather than returning.

        :param start: time.time() value that conversation start times
            are relative to
        :param context: the replay context; it is used unconditionally,
            so despite the default it must not be None
        :param account: passed to context.generate_process_local_config()
        :return: the child pid (in the parent process only)
        """
        def signal_handler(signum, frame):
            """Signal handler closes standard out and error.

            Triggered by a sigterm, ensures that the log messages are flushed
            to disk and not lost.
            """
            sys.stderr.close()
            sys.stdout.close()
            os._exit(0)

        t = self.start_time
        now = time.time() - start
        gap = t - now
        # we are replaying strictly in order, so it is safe to sleep
        # in the main process if the gap is big enough. This reduces
        # the number of concurrent threads, which allows us to make
        # larger loads.
        # NOTE: the 'and False' keeps this branch switched off.
        if gap > 0.15 and False:
            print("sleeping for %f in main process" % (gap - 0.1),
                  file=sys.stderr)
            time.sleep(gap - 0.1)
            now = time.time() - start
            gap = t - now
            print("gap is now %f" % gap, file=sys.stderr)

        self.conversation_id = next(context.next_conversation_id)
        pid = os.fork()
        if pid != 0:
            return pid
        pid = os.getpid()
        signal.signal(signal.SIGTERM, signal_handler)
        # we must never return, or we'll end up running parts of the
        # parent's clean-up code. So we work in a try...finally, and
        # try to print any exceptions.

        try:
            context.generate_process_local_config(account, self)
            sys.stdin.close()
            os.close(0)
            filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
                                    self.conversation_id)
            sys.stdout.close()
            sys.stdout = open(filename, 'w')

            sleep_time = gap - SLEEP_OVERHEAD
            if sleep_time > 0:
                time.sleep(sleep_time)

            miss = t - (time.time() - start)
            self.msg("starting %s [miss %.3f pid %d]" % (self, miss, pid))
            self.replay(context)
        except Exception:
            print(("EXCEPTION in child PID %d, conversation %s" % (pid, self)),
                  file=sys.stderr)
            traceback.print_exc(sys.stderr)
        finally:
            sys.stderr.close()
            sys.stdout.close()
            os._exit(0)

    def replay(self, context=None):
        """Play each packet at its conversation-relative time.

        With no context nothing is sent; each packet is only reported
        through self.msg, along with how late it was.
        """
        start = time.time()

        for p in self.packets:
            now = time.time() - start
            gap = p.timestamp - now
            sleep_time = gap - SLEEP_OVERHEAD
            if sleep_time > 0:
                time.sleep(sleep_time)

            miss = p.timestamp - (time.time() - start)
            if context is None:
                self.msg("packet %s [miss %.3f pid %d]" % (p, miss,
                                                           os.getpid()))
                continue
            p.play(self, context)

    def guess_client_server(self, server_clue=None):
        """Have a go at deciding who is the server and who is the client.
        returns (client, server)
        """
        a, b = self.endpoints

        if self.client_balance < 0:
            return (a, b)

        # in the absence of a clue, we will fall through to assuming
        # the lowest number is the server (which is usually true).

        if self.client_balance == 0 and server_clue == b:
            return (a, b)

        return (b, a)

    def forget_packets_outside_window(self, s, e):
        """Prune any packets outside the time window we're interested in

        The start time becomes the timestamp of the first remaining
        packet, or None if no packets remain.

        :param s: start of the window
        :param e: end of the window
        """
        self.packets = [p for p in self.packets if s <= p.timestamp <= e]
        self.start_time = self.packets[0].timestamp if self.packets else None

    def renormalise_times(self, start_time):
        """Adjust the packet start times relative to the new start time."""
        for p in self.packets:
            p.timestamp -= start_time

        if self.start_time is not None:
            self.start_time -= start_time
974
975
class DnsHammer(Conversation):
    """A lightweight conversation that generates a lot of dns:0 packets on
    the fly.

    Rather than holding packet objects it keeps a sorted list of send
    times, drawn uniformly at random from the replay duration.
    """

    def __init__(self, dns_rate, duration):
        packet_count = int(dns_rate * duration)
        self.times = sorted(random.uniform(0, duration)
                            for _ in range(packet_count))
        self.rate = dns_rate
        self.duration = duration
        self.start_time = 0
        self.msg = random_colour_print()

    def __str__(self):
        details = (len(self.times), self.duration, self.rate)
        return "<DnsHammer %d packets over %.1fs (rate %.2f)>" % details

    def replay_in_fork_with_delay(self, start, context=None, account=None):
        """Fork and replay exactly as an ordinary Conversation does."""
        return Conversation.replay_in_fork_with_delay(self,
                                                      start,
                                                      context,
                                                      account)

    def replay(self, context=None):
        """Send one dns:0 query at each of the pre-computed times.

        One tab-separated result line is printed to standard output per
        query, recording the finish time, how long the query took and
        whether it succeeded. With no context nothing is sent and each
        scheduled time is only reported through self.msg.
        """
        began = time.time()
        send_query = traffic_packets.packet_dns_0
        for due in self.times:
            pause = due - (time.time() - began) - SLEEP_OVERHEAD
            if pause > 0:
                time.sleep(pause)

            if context is None:
                miss = due - (time.time() - began)
                self.msg("packet %s [miss %.3f pid %d]" % (due, miss,
                                                           os.getpid()))
                continue

            packet_start = time.time()
            try:
                send_query(self, self, context)
                finished = time.time()
                print("%f\tDNS\tdns\t0\t%f\tTrue\t" %
                      (finished, finished - packet_start))
            except Exception as e:
                finished = time.time()
                print("%f\tDNS\tdns\t0\t%f\tFalse\t%s" %
                      (finished, finished - packet_start, e))
1025
1026
def ingest_summaries(files, dns_mode='count'):
    """Load traffic summary files and generate Conversations from them.

    :param files: an iterable of file names and/or open file objects.
        Every file is closed once it has been read, including file
        objects supplied by the caller.
    :param dns_mode: unless this is 'include', dns packets are left out
        of the conversations and only counted per opcode.
    :return: a tuple (conversation_list, mean_interval, duration,
        dns_counts). conversation_list holds only conversations with
        at least one stored packet; mean_interval is the number of
        conversations seen divided by duration; duration is the time
        between the earliest and latest packets; dns_counts maps dns
        opcode to the number of dns packets that were counted instead
        of kept. When no packets are kept the result is
        ([], 0, 0, dns_counts).
    """

    dns_counts = defaultdict(int)
    packets = []
    for f in files:
        if isinstance(f, str):
            f = open(f)
        # close the file even if a line fails to parse
        try:
            print("Ingesting %s" % (f.name,), file=sys.stderr)
            for line in f:
                p = Packet.from_line(line)
                if p.protocol == 'dns' and dns_mode != 'include':
                    dns_counts[p.opcode] += 1
                else:
                    packets.append(p)
        finally:
            f.close()

    if not packets:
        # same shape as the normal return, so callers can always
        # unpack four values
        return [], 0, 0, dns_counts

    start_time = min(p.timestamp for p in packets)
    last_packet = max(p.timestamp for p in packets)

    print("gathering packets into conversations", file=sys.stderr)
    conversations = OrderedDict()
    for p in packets:
        p.timestamp -= start_time
        c = conversations.get(p.endpoints)
        if c is None:
            c = Conversation()
            conversations[p.endpoints] = c
        c.add_packet(p)

    # We only care about conversations with actual traffic, so we
    # filter out conversations with nothing to say. We do that here,
    # rather than earlier, because those empty packets contain useful
    # hints as to which end of the conversation was the client.
    conversation_list = [c for c in conversations.values() if len(c) != 0]

    # This is obviously not correct, as many conversations will appear
    # to start roughly simultaneously at the beginning of the snapshot.
    # To which we say: oh well, so be it.
    duration = float(last_packet - start_time)
    if duration > 0:
        mean_interval = len(conversations) / duration
    else:
        # every packet has the same timestamp, so no rate can be
        # derived; avoid dividing by zero
        mean_interval = 0

    return conversation_list, mean_interval, duration, dns_counts
1078
1079
def guess_server_address(conversations):
    """Guess which address is the server: the most common endpoint.

    :param conversations: an iterable of objects with an `endpoints`
        attribute (an iterable of addresses)
    :return: the address seen most often across all endpoints, or None
        if there are no endpoints at all
    """
    # we guess the most common address.
    addresses = Counter()
    for c in conversations:
        addresses.update(c.endpoints)
    if addresses:
        # most_common() yields (address, count) pairs; only the address
        # is useful as a clue for comparing against endpoints.
        return addresses.most_common(1)[0][0]
    return None
1087
1088
def stringify_keys(x):
    """Return a copy of dict x with each key (a sequence of strings)
    joined into a single tab-separated string."""
    return {'\t'.join(key): value for key, value in x.items()}
1095
1096
def unstringify_keys(x):
    """Return a copy of dict x with each key converted to str and split
    on tabs into a tuple of strings."""
    return {tuple(str(key).split('\t')): value for key, value in x.items()}
1103
1104
class TrafficModel(object):
    """An n-gram model of client traffic, learnt from conversations.

    self.ngrams maps a tuple of the previous (n - 1) packet types to the
    list of packet types seen to follow it. self.query_details maps a
    packet type to the list of `extra` tuples seen with it.
    """
    def __init__(self, n=3):
        self.ngrams = {}
        self.query_details = {}
        self.n = n
        self.dns_opcounts = defaultdict(int)
        self.cumulative_duration = 0.0
        # [number of conversations, seconds they started over]
        self.conversation_rate = [0, 1]

    def learn(self, conversations, dns_opcounts=None):
        """Add the client packets of the given conversations to the model.

        :param conversations: a list of Conversation objects
        :param dns_opcounts: optional mapping of dns opcode to count,
            added to the model's running dns totals
        """
        prev = 0.0
        cum_duration = 0.0
        key = (NON_PACKET,) * (self.n - 1)

        server = guess_server_address(conversations)

        if dns_opcounts:
            for k, v in dns_opcounts.items():
                self.dns_opcounts[k] += v

        if len(conversations) > 1:
            elapsed =\
                conversations[-1].start_time - conversations[0].start_time
            self.conversation_rate[0] = len(conversations)
            self.conversation_rate[1] = elapsed

        for c in conversations:
            client, server = c.guess_client_server(server)
            cum_duration += c.get_duration()
            key = (NON_PACKET,) * (self.n - 1)
            for p in c:
                # only packets sent by the client are modelled
                if p.src != client:
                    continue

                elapsed = p.timestamp - prev
                prev = p.timestamp
                if elapsed > WAIT_THRESHOLD:
                    # add the wait as an extra state
                    wait = 'wait:%d' % (math.log(max(1.0,
                                                     elapsed * WAIT_SCALE)))
                    self.ngrams.setdefault(key, []).append(wait)
                    key = key[1:] + (wait,)

                short_p = p.as_packet_type()
                self.query_details.setdefault(short_p,
                                              []).append(tuple(p.extra))
                self.ngrams.setdefault(key, []).append(short_p)
                key = key[1:] + (short_p,)

        self.cumulative_duration += cum_duration
        # add in the end
        self.ngrams.setdefault(key, []).append(NON_PACKET)

    def save(self, f):
        """Write the model as JSON.

        :param f: a file name or a writable file object. A file opened
            here from a name is closed again; a file object passed in
            is left open for the caller.
        """
        ngrams = {}
        for k, v in self.ngrams.items():
            k = '\t'.join(k)
            ngrams[k] = dict(Counter(v))

        query_details = {}
        for k, v in self.query_details.items():
            # an empty `extra` tuple is stored as '-'
            query_details[k] = dict(Counter('\t'.join(x) if x else '-'
                                            for x in v))

        d = {
            'ngrams': ngrams,
            'query_details': query_details,
            'cumulative_duration': self.cumulative_duration,
            'conversation_rate': self.conversation_rate,
        }
        d['dns'] = self.dns_opcounts

        if isinstance(f, str):
            with open(f, 'w') as stream:
                json.dump(d, stream, indent=2)
        else:
            json.dump(d, f, indent=2)

    def load(self, f):
        """Merge a model saved by save() into this one.

        n-grams, query details and dns counts are added to what is
        already held; cumulative_duration and conversation_rate are
        replaced by the loaded values.

        :param f: a file name or a readable file object. A file opened
            here from a name is closed again; a file object passed in
            is left open for the caller.
        """
        if isinstance(f, str):
            with open(f) as stream:
                d = json.load(stream)
        else:
            d = json.load(f)

        for k, v in d['ngrams'].items():
            k = tuple(str(k).split('\t'))
            values = self.ngrams.setdefault(k, [])
            for p, count in v.items():
                values.extend([str(p)] * count)

        for k, v in d['query_details'].items():
            values = self.query_details.setdefault(str(k), [])
            for p, count in v.items():
                if p == '-':
                    values.extend([()] * count)
                else:
                    values.extend([tuple(str(p).split('\t'))] * count)

        if 'dns' in d:
            for k, v in d['dns'].items():
                self.dns_opcounts[k] += v

        self.cumulative_duration = d['cumulative_duration']
        self.conversation_rate = d['conversation_rate']

    def construct_conversation(self, timestamp=0.0, client=2, server=1,
                               hard_stop=None, packet_rate=1):
        """Construct an individual conversation from the model.

        Packet types are drawn at random by walking the n-gram table
        until the end marker is drawn, an unknown key is reached, or a
        packet would fall after hard_stop.

        :param timestamp: the start time of the conversation
        :param client: the client endpoint
        :param server: the server endpoint
        :param hard_stop: if not None, stop once a packet's time
            exceeds this
        :param packet_rate: divisor applied to the generated waits
        :return: the new Conversation
        """

        c = Conversation(timestamp, (server, client))

        key = (NON_PACKET,) * (self.n - 1)

        while key in self.ngrams:
            p = random.choice(self.ngrams[key])
            if p == NON_PACKET:
                break
            if p in self.query_details:
                extra = random.choice(self.query_details[p])
            else:
                extra = []

            protocol, opcode = p.split(':', 1)
            if protocol == 'wait':
                # a 'wait:N' state only advances the clock
                log_wait_time = int(opcode) + random.random()
                wait = math.exp(log_wait_time) / (WAIT_SCALE * packet_rate)
                timestamp += wait
            else:
                log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
                wait = math.exp(log_wait) / packet_rate
                timestamp += wait
                if hard_stop is not None and timestamp > hard_stop:
                    break
                c.add_short_packet(timestamp, protocol, opcode, extra)

            key = key[1:] + (p,)

        return c

    def generate_conversations(self, rate, duration, packet_rate=1):
        """Generate a list of conversations from the model.

        :param rate: multiplier applied to the learnt conversation rate
        :param duration: the length of the window wanted
        :param packet_rate: passed on to construct_conversation()
        :return: the non-empty generated conversations, sorted by
            start time
        """

        # We run the simulation over the longer of the learnt time span
        # and twice our desired duration.
        rate_n, rate_t = self.conversation_rate

        duration2 = max(rate_t, duration * 2)
        n = rate * duration2 * rate_n / rate_t

        server = 1
        client = 2

        conversations = []
        end = duration2
        # NOTE(review): this value is overwritten by the random start
        # drawn for each conversation below before it is ever read;
        # confirm whether a fixed window start was intended.
        start = end - duration

        while client < n + 2:
            start = random.uniform(0, duration2)
            c = self.construct_conversation(start,
                                            client,
                                            server,
                                            hard_stop=(duration2 * 5),
                                            packet_rate=packet_rate)

            c.forget_packets_outside_window(start, end)
            c.renormalise_times(start)
            if len(c) != 0:
                conversations.append(c)
            client += 1

        print(("we have %d conversations at rate %f" %
               (len(conversations), rate)), file=sys.stderr)
        # Sort on an explicit key so this does not depend on the objects
        # defining an ordering; a missing start time sorts first.
        conversations.sort(key=lambda conv: (conv.start_time is not None,
                                             conv.start_time or 0))
        return conversations
1277
1278
# The ip_protocol field to use for each application protocol when a
# packet is built from a short 'protocol:opcode' description. Protocols
# not listed here are given '06' by the code that reads this table.
# NOTE(review): the values look like hex IP protocol numbers
# (06 = TCP, 11 = UDP) -- confirm against the summary file format.
IP_PROTOCOLS = {
    'dns': '11',
    'rpc_netlogon': '06',
    'kerberos': '06',      # ratio 16248:258
    'smb': '06',
    'smb2': '06',
    'ldap': '06',
    'cldap': '11',
    'lsarpc': '06',
    'samr': '06',
    'dcerpc': '06',
    'epm': '06',
    'drsuapi': '06',
    'browser': '11',
    'smb_netlogon': '11',
    'srvsvc': '06',
    'nbns': '11',
}
1297
# Description text for each known (protocol, opcode) pair, both given
# as strings. It fills the description field of packets built from a
# short 'protocol:opcode' form; pairs not listed here get an empty
# description from the code that reads this table.
OP_DESCRIPTIONS = {
    ('browser', '0x01'): 'Host Announcement (0x01)',
    ('browser', '0x02'): 'Request Announcement (0x02)',
    ('browser', '0x08'): 'Browser Election Request (0x08)',
    ('browser', '0x09'): 'Get Backup List Request (0x09)',
    ('browser', '0x0c'): 'Domain/Workgroup Announcement (0x0c)',
    ('browser', '0x0f'): 'Local Master Announcement (0x0f)',
    ('cldap', '3'): 'searchRequest',
    ('cldap', '5'): 'searchResDone',
    ('dcerpc', '0'): 'Request',
    ('dcerpc', '11'): 'Bind',
    ('dcerpc', '12'): 'Bind_ack',
    ('dcerpc', '13'): 'Bind_nak',
    ('dcerpc', '14'): 'Alter_context',
    ('dcerpc', '15'): 'Alter_context_resp',
    ('dcerpc', '16'): 'AUTH3',
    ('dcerpc', '2'): 'Response',
    ('dns', '0'): 'query',
    ('dns', '1'): 'response',
    ('drsuapi', '0'): 'DsBind',
    ('drsuapi', '12'): 'DsCrackNames',
    ('drsuapi', '13'): 'DsWriteAccountSpn',
    ('drsuapi', '1'): 'DsUnbind',
    ('drsuapi', '2'): 'DsReplicaSync',
    ('drsuapi', '3'): 'DsGetNCChanges',
    ('drsuapi', '4'): 'DsReplicaUpdateRefs',
    ('epm', '3'): 'Map',
    ('kerberos', ''): '',
    ('ldap', '0'): 'bindRequest',
    ('ldap', '1'): 'bindResponse',
    ('ldap', '2'): 'unbindRequest',
    ('ldap', '3'): 'searchRequest',
    ('ldap', '4'): 'searchResEntry',
    ('ldap', '5'): 'searchResDone',
    ('ldap', ''): '*** Unknown ***',
    ('lsarpc', '14'): 'lsa_LookupNames',
    ('lsarpc', '15'): 'lsa_LookupSids',
    ('lsarpc', '39'): 'lsa_QueryTrustedDomainInfoBySid',
    ('lsarpc', '40'): 'lsa_SetTrustedDomainInfo',
    ('lsarpc', '6'): 'lsa_OpenPolicy',
    ('lsarpc', '76'): 'lsa_LookupSids3',
    ('lsarpc', '77'): 'lsa_LookupNames4',
    ('nbns', '0'): 'query',
    ('nbns', '1'): 'response',
    ('rpc_netlogon', '21'): 'NetrLogonDummyRoutine1',
    ('rpc_netlogon', '26'): 'NetrServerAuthenticate3',
    ('rpc_netlogon', '29'): 'NetrLogonGetDomainInfo',
    ('rpc_netlogon', '30'): 'NetrServerPasswordSet2',
    ('rpc_netlogon', '39'): 'NetrLogonSamLogonEx',
    ('rpc_netlogon', '40'): 'DsrEnumerateDomainTrusts',
    ('rpc_netlogon', '45'): 'NetrLogonSamLogonWithFlags',
    ('rpc_netlogon', '4'): 'NetrServerReqChallenge',
    ('samr', '0',): 'Connect',
    ('samr', '16'): 'GetAliasMembership',
    ('samr', '17'): 'LookupNames',
    ('samr', '18'): 'LookupRids',
    ('samr', '19'): 'OpenGroup',
    ('samr', '1'): 'Close',
    ('samr', '25'): 'QueryGroupMember',
    ('samr', '34'): 'OpenUser',
    ('samr', '36'): 'QueryUserInfo',
    ('samr', '39'): 'GetGroupsForUser',
    ('samr', '3'): 'QuerySecurity',
    ('samr', '5'): 'LookupDomain',
    ('samr', '64'): 'Connect5',
    ('samr', '6'): 'EnumDomains',
    ('samr', '7'): 'OpenDomain',
    ('samr', '8'): 'QueryDomainInfo',
    ('smb', '0x04'): 'Close (0x04)',
    ('smb', '0x24'): 'Locking AndX (0x24)',
    ('smb', '0x2e'): 'Read AndX (0x2e)',
    ('smb', '0x32'): 'Trans2 (0x32)',
    ('smb', '0x71'): 'Tree Disconnect (0x71)',
    ('smb', '0x72'): 'Negotiate Protocol (0x72)',
    ('smb', '0x73'): 'Session Setup AndX (0x73)',
    ('smb', '0x74'): 'Logoff AndX (0x74)',
    ('smb', '0x75'): 'Tree Connect AndX (0x75)',
    ('smb', '0xa2'): 'NT Create AndX (0xa2)',
    ('smb2', '0'): 'NegotiateProtocol',
    ('smb2', '11'): 'Ioctl',
    ('smb2', '14'): 'Find',
    ('smb2', '16'): 'GetInfo',
    ('smb2', '18'): 'Break',
    ('smb2', '1'): 'SessionSetup',
    ('smb2', '2'): 'SessionLogoff',
    ('smb2', '3'): 'TreeConnect',
    ('smb2', '4'): 'TreeDisconnect',
    ('smb2', '5'): 'Create',
    ('smb2', '6'): 'Close',
    ('smb2', '8'): 'Read',
    ('smb_netlogon', '0x12'): 'SAM LOGON request from client (0x12)',
    ('smb_netlogon', '0x17'): ('SAM Active Directory Response - '
                               'user unknown (0x17)'),
    ('srvsvc', '16'): 'NetShareGetInfo',
    ('srvsvc', '21'): 'NetSrvGetInfo',
}
1394
1395
def expand_short_packet(p, timestamp, src, dest, extra):
    """Expand a short 'protocol:opcode' packet type into a summary line.

    :param p: the packet type, as a 'protocol:opcode' string
    :param timestamp: the packet time
    :param src: the source endpoint
    :param dest: the destination endpoint
    :param extra: an iterable of extra fields to append
    :return: the fields joined with tabs into a single string
    """
    protocol, opcode = p.split(':', 1)
    desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
    ip_protocol = IP_PROTOCOLS.get(protocol, '06')

    line = [timestamp, ip_protocol, '', src, dest, protocol, opcode, desc]
    line.extend(extra)
    # join() only accepts strings, so format each field first; this
    # leaves string fields unchanged and lets numeric timestamps and
    # endpoints through.
    return '\t'.join('%s' % (field,) for field in line)
1404
1405
def replay(conversations,
           host=None,
           creds=None,
           lp=None,
           accounts=None,
           dns_rate=0,
           duration=None,
           **kwargs):
    """Replay the conversations, each one in its own forked process.

    Conversations are paired with accounts in start-time order and
    forked in batches shortly before they are due. Once the run is
    over, or on any exception, the children still being tracked are
    signalled and reaped, and finally the whole process group is sent
    an interrupt to catch stragglers.

    :param conversations: the conversations to replay
    :param host: the server to replay against (given to ReplayContext)
    :param creds: credentials (given to ReplayContext)
    :param lp: loadparm context (given to ReplayContext)
    :param accounts: a sequence of accounts, one per conversation;
        conversations beyond the number of accounts are not replayed
    :param dns_rate: if non-zero, also run a DnsHammer at this rate
    :param duration: how long to run for; by default, until one second
        after the last packet of the last conversation to start
    :param kwargs: further arguments passed on to ReplayContext
    """

    context = ReplayContext(server=host,
                            creds=creds,
                            lp=lp,
                            **kwargs)

    if len(accounts) < len(conversations):
        # report the counts; the lists themselves cannot be
        # formatted with %d
        print(("we have %d accounts but %d conversations" %
               (len(accounts), len(conversations))), file=sys.stderr)

    # Sorted latest-first so that pop() hands back the earliest
    # conversation still waiting.
    cstack = list(zip(
        sorted(conversations, key=lambda x: x.start_time, reverse=True),
        accounts))

    # Set the process group so that the calling scripts are not killed
    # when the forked child processes are killed.
    os.setpgrp()

    start = time.time()

    if duration is None:
        # end 1 second after the last packet of the last conversation
        # to start. Conversations other than the last could still be
        # going, but we don't care.
        duration = cstack[0][0].packets[-1].timestamp + 1.0
        print("We will stop after %.1f seconds" % duration,
              file=sys.stderr)

    end = start + duration

    LOGGER.info("Replaying traffic for %u conversations over %d seconds"
                % (len(conversations), duration))

    children = {}
    if dns_rate:
        # appended last, so it is the first thing popped and forked
        dns_hammer = DnsHammer(dns_rate, duration)
        cstack.append((dns_hammer, None))

    try:
        while True:
            # we spawn a batch, wait for finishers, then spawn another
            now = time.time()
            batch_end = min(now + 2.0, end)
            fork_time = 0.0
            fork_n = 0
            while cstack:
                c, account = cstack.pop()
                if c.start_time + start > batch_end:
                    # not due in this batch; put it back for the next
                    cstack.append((c, account))
                    break

                st = time.time()
                pid = c.replay_in_fork_with_delay(start, context, account)
                children[pid] = c
                t = time.time()
                elapsed = t - st
                fork_time += elapsed
                fork_n += 1
                print("forked %s in pid %s (in %fs)" % (c, pid,
                                                        elapsed),
                      file=sys.stderr)

            if fork_n:
                print(("forked %d times in %f seconds (avg %f)" %
                       (fork_n, fork_time, fork_time / fork_n)),
                      file=sys.stderr)
            elif cstack:
                debug(2, "no forks in batch ending %f" % batch_end)

            while time.time() < batch_end - 1.0:
                time.sleep(0.01)
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != 10:  # no child processes
                        raise
                    break
                if pid:
                    c = children.pop(pid, None)
                    print(("process %d finished conversation %s;"
                           " %d to go" %
                           (pid, c, len(children))), file=sys.stderr)

            if time.time() >= end:
                print("time to stop", file=sys.stderr)
                break

    except Exception:
        print("EXCEPTION in parent", file=sys.stderr)
        traceback.print_exc()
    finally:
        # Signal 15 twice, then signal 9, stopping as soon as every
        # tracked child has been reaped.
        for s in (15, 15, 9):
            print(("killing %d children with -%d" %
                   (len(children), s)), file=sys.stderr)
            for pid in children:
                try:
                    os.kill(pid, s)
                except OSError as e:
                    if e.errno != 3:  # don't fail if it has already died
                        raise
            time.sleep(0.5)
            end = time.time() + 1
            while children:
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != 10:  # no child processes
                        raise
                    # Nothing is left to wait for, so stop polling
                    # instead of re-using a stale pid until the timeout.
                    break
                if pid != 0:
                    c = children.pop(pid, None)
                    print(("kill -%d %d KILLED conversation %s; "
                           "%d to go" %
                           (s, pid, c, len(children))),
                          file=sys.stderr)
                if time.time() >= end:
                    break

            if not children:
                break
            time.sleep(1)

        if children:
            print("%d children are missing" % len(children),
                  file=sys.stderr)

        # there may be stragglers that were forked just as ^C was hit
        # and don't appear in the list of children. We can get them
        # with killpg, but that will also kill us, so this is^H^H would be
        # goodbye, except we cheat and pretend to use ^C (SIG_INTERRUPT),
        # so as not to have to fuss around writing signal handlers.
        try:
            os.killpg(0, 2)
        except KeyboardInterrupt:
            print("ignoring fake ^C", file=sys.stderr)
1548
1549
def openLdb(host, creds, lp):
    """Open and return a SamDB connected to ldap://<host>.

    The connection uses a system session, the given credentials and
    loadparm context, and the 'modules:paged_searches' option.
    """
    session = system_session()
    url = "ldap://%s" % host
    return SamDB(url=url,
                 session_info=session,
                 options=['modules:paged_searches'],
                 credentials=creds,
                 lp=lp)
1558
1559
def ou_name(ldb, instance_id):
    """Return the DN of the OU for this instance id, which sits under
    ou=traffic_replay in the domain DN reported by ldb."""
    domain_dn = ldb.domain_dn()
    return "ou=instance-%d,ou=traffic_replay,%s" % (instance_id, domain_dn)
1564
1565
def create_ou(ldb, instance_id):
    """Create the OU that every generated user and machine account
    lives in, along with its parent OU.

    Keeping the accounts together means the created resources can be
    cleaned up easily. An OU that already exists is left alone.

    :return: the DN of the instance OU
    """
    ou = ou_name(ldb, instance_id)
    parent = ou.split(',', 1)[1]
    # the parent has to be added before the instance OU beneath it
    for dn in (parent, ou):
        try:
            ldb.add({"dn": dn,
                     "objectclass": "organizationalunit"})
        except LdbError as e:
            (status, _) = e.args
            # ignore already exists
            if status != 68:
                raise
    return ou
1589
1590
class ConversationAccounts(object):
    """The machine account and user account used by one conversation.

    Holds the machine account's netbios name and password alongside
    the user account's name and password.
    """
    def __init__(self, netbios_name, machinepass, username, userpass):
        # machine account
        self.netbios_name, self.machinepass = netbios_name, machinepass
        # user account
        self.username, self.userpass = username, userpass
1599
1600
def generate_replay_accounts(ldb, instance_id, number, password):
    """Make sure `number` machine and user accounts exist, and return a
    ConversationAccounts for each pair.

    Account i (counting from 1) uses the names STGM-<instance_id>-<i>
    and STGU-<instance_id>-<i>; every account shares `password`.
    """
    generate_traffic_accounts(ldb, instance_id, number, password)
    return [ConversationAccounts("STGM-%d-%d" % (instance_id, index),
                                 password,
                                 "STGU-%d-%d" % (instance_id, index),
                                 password)
            for index in range(1, number + 1)]
1614
1615
def generate_traffic_accounts(ldb, instance_id, number, password):
    """Create the specified number of user and machine accounts.

    Accounts are not explicitly deleted between runs, so for each kind of
    account this function starts with the last account and iterates
    backwards, stopping either when it finds an already existing account
    (LdbError status 68) or when it has generated all the required
    accounts. Any other LdbError is re-raised.

    Progress messages are printed to stderr.
    """
    def create_missing(name_format, create):
        """Create accounts counting down from number; return the count."""
        created = 0
        for index in range(number, 0, -1):
            try:
                create(ldb, instance_id,
                       name_format % (instance_id, index), password)
            except LdbError as err:
                (code, _) = err.args
                if code != 68:
                    raise
                # already exists: stop counting down
                break
            created += 1
        return created

    print(("Generating machine and conversation accounts, "
           "as required for %d conversations" % number),
          file=sys.stderr)

    machines = create_missing("STGM-%d-%d", create_machine_account)
    if machines > 0:
        print("Added %d new machine accounts" % machines,
              file=sys.stderr)

    users = create_missing("STGU-%d-%d", create_user_account)
    if users > 0:
        print("Added %d new user accounts" % users,
              file=sys.stderr)
1659
1660
def create_machine_account(ldb, instance_id, netbios_name, machinepass):
    """Create a machine account via ldap.

    The account is added as a computer object under the instance ou, and
    the time taken by the add is logged as a tab separated timing record.
    """
    machine_dn = "cn=%s,%s" % (netbios_name, ou_name(ldb, instance_id))
    # the password is wrapped in double quotes and encoded as UTF-16-LE
    quoted_pw = '"%s"' % get_string(machinepass)
    encoded_pw = quoted_pw.encode('utf-16-le')

    began = time.time()
    ldb.add({
        "dn": machine_dn,
        "objectclass": "computer",
        "sAMAccountName": "%s$" % netbios_name,
        "userAccountControl":
            str(UF_TRUSTED_FOR_DELEGATION | UF_SERVER_TRUST_ACCOUNT),
        "unicodePwd": encoded_pw})
    finished = time.time()
    LOGGER.info("%f\t0\tcreate\tmachine\t%f\tTrue\t" %
                (finished, finished - began))
1679
1680
def create_user_account(ldb, instance_id, username, userpass):
    """Create a user account via ldap.

    The account is added as a user object under the instance ou and an
    ace is then added to its security descriptor. The time taken by both
    steps together is logged as a tab separated timing record.
    """
    account_dn = "cn=%s,%s" % (username, ou_name(ldb, instance_id))
    # the password is wrapped in double quotes and encoded as UTF-16-LE
    quoted_pw = '"%s"' % get_string(userpass)
    encoded_pw = quoted_pw.encode('utf-16-le')

    began = time.time()
    ldb.add({
        "dn": account_dn,
        "objectclass": "user",
        "sAMAccountName": username,
        "userAccountControl": str(UF_NORMAL_ACCOUNT),
        "unicodePwd": encoded_pw
    })

    # grant user write permission to do things like write account SPN
    sd_utils.SDUtils(ldb).dacl_add_ace(account_dn, "(A;;WP;;;PS)")

    finished = time.time()
    LOGGER.info("%f\t0\tcreate\tuser\t%f\tTrue\t" %
                (finished, finished - began))
1702
1703
def create_group(ldb, instance_id, name):
    """Create a group via ldap.

    The group is added under the instance ou, and the time taken by the
    add is logged as a tab separated timing record.
    """
    group_dn = "cn=%s,%s" % (name, ou_name(ldb, instance_id))

    began = time.time()
    ldb.add({
        "dn": group_dn,
        "objectclass": "group",
        "sAMAccountName": name,
    })
    finished = time.time()
    LOGGER.info("%f\t0\tcreate\tgroup\t%f\tTrue\t" %
                (finished, finished - began))
1718
1719
def user_name(instance_id, i):
    """Generate a user name based on the instance id and user number."""
    numbers = (instance_id, i)
    return "STGU-%d-%d" % numbers
1723
1724
def search_objectclass(ldb, objectclass='user', attr='sAMAccountName'):
    """Search for an objectclass and return one attribute as a set.

    Every object matching (objectClass=<objectclass>) is searched for,
    requesting only attr, and the str() of that attribute from each
    result is collected into a set.
    """
    query = "(objectClass={})".format(objectclass)
    found = set()
    for entry in ldb.search(expression=query, attrs=[attr]):
        found.add(str(entry[attr]))
    return found
1732
1733
def generate_users(ldb, instance_id, number, password):
    """Add users to the server.

    Counting down from number to 1, an account is created for every
    generated user name that the initial search of user objects did not
    return. Returns the number of accounts created.
    """
    already_present = search_objectclass(ldb, objectclass='user')
    created = 0
    for index in reversed(range(1, number + 1)):
        candidate = user_name(instance_id, index)
        if candidate in already_present:
            continue
        create_user_account(ldb, instance_id, candidate, password)
        created += 1

    return created
1745
1746
def group_name(instance_id, i):
    """Generate a group name from the instance id and group number."""
    numbers = (instance_id, i)
    return "STGG-%d-%d" % numbers
1750
1751
def generate_groups(ldb, instance_id, number):
    """Create the required number of groups on the server.

    Counting down from number to 1, a group is created for every
    generated group name that the initial search of group objects did not
    return. Returns the number of groups created.
    """
    already_present = search_objectclass(ldb, objectclass='group')
    created = 0
    for index in reversed(range(1, number + 1)):
        candidate = group_name(instance_id, index)
        if candidate in already_present:
            continue
        create_group(ldb, instance_id, candidate)
        created += 1

    return created
1763
1764
def clean_up_accounts(ldb, instance_id):
    """Remove the created accounts and groups from the server.

    The instance ou is deleted with the tree_delete control. An LdbError
    with status 32 (the ou does not exist) is ignored; any other LdbError
    is re-raised.
    """
    target = ou_name(ldb, instance_id)
    try:
        ldb.delete(target, ["tree_delete:1"])
    except LdbError as err:
        (code, _) = err.args
        if code == 32:
            # ignore does not exist
            return
        raise
1775
1776
def generate_users_and_groups(ldb, instance_id, password,
                              number_of_users, number_of_groups,
                              group_memberships):
    """Generate the required users and groups, allocating the users to
    those groups.

    The instance ou is created first, then the users, then (if any are
    requested) the groups, and finally (if any are requested) the group
    memberships. Progress messages and a final summary are printed to
    stderr.
    """
    def report(message):
        print(message, file=sys.stderr)

    create_ou(ldb, instance_id)

    report("Generating dummy user accounts")
    users_added = generate_users(ldb, instance_id, number_of_users, password)

    groups_added = 0
    if number_of_groups > 0:
        report("Generating dummy groups")
        groups_added = generate_groups(ldb, instance_id, number_of_groups)

    assignments = []
    if group_memberships > 0:
        report("Assigning users to groups")
        assignments = assign_groups(number_of_groups,
                                    groups_added,
                                    number_of_users,
                                    users_added,
                                    group_memberships)
        report("Adding users to groups")
        add_users_to_groups(ldb, instance_id, assignments)

    only_groups_added = groups_added > 0 and users_added == 0
    if only_groups_added and number_of_groups != groups_added:
        report("Warning: the added groups will contain no members")

    report("Added %d users, %d groups and %d group memberships" %
           (users_added, groups_added, len(assignments)))
1812
1813
def assign_groups(number_of_groups,
                  groups_added,
                  number_of_users,
                  users_added,
                  group_memberships):
    """Allocate users to groups.

    The intention is to have a few users that belong to most groups, while
    the majority of users belong to a few groups.

    A few groups will contain most users, with the remaining only having a
    few users.

    Only pairs that involve a newly added user or a newly added group are
    generated; the highest numbered users_added users and groups_added
    groups are the ones treated as new.

    :param number_of_groups: total number of groups to choose from
    :param groups_added: how many of those groups are new
    :param number_of_users: total number of users to choose from
    :param users_added: how many of those users are new
    :param group_memberships: memberships wanted if every user were new;
                              scaled down by users_added / number_of_users
    :return: a set of (user number, group number) tuples, numbered from 1.
             It is empty when there are no users, no groups, or no
             memberships requested.
    """
    assignments = set()
    if (group_memberships <= 0 or
            number_of_users <= 0 or
            number_of_groups <= 0):
        # nothing can be assigned; also avoids dividing by zero users and
        # asking randint for an empty range
        return assignments

    # Probability distribution of a group containing a user.
    group_dist = [1 / (x**1.3) for x in range(1, number_of_groups + 1)]
    # Probability distribution of a user belonging to a group.
    user_dist = [1 / (x + 0.001) for x in range(1, number_of_users + 1)]

    # Calculate the number of group memberships required
    group_memberships = math.ceil(
        float(group_memberships) *
        (float(users_added) / float(number_of_users)))

    existing_users  = number_of_users  - users_added  - 1
    existing_groups = number_of_groups - groups_added - 1

    # Pairs of an existing user with an existing group are never accepted
    # below, so there are only this many distinct pairs to find. Asking
    # for more would make the loop spin forever.
    old_users = min(max(number_of_users - users_added, 0), number_of_users)
    old_groups = min(max(number_of_groups - groups_added, 0),
                     number_of_groups)
    eligible = number_of_users * number_of_groups - old_users * old_groups
    group_memberships = min(group_memberships, eligible)

    while len(assignments) < group_memberships:
        user        = random.randint(0, number_of_users - 1)
        group       = random.randint(0, number_of_groups - 1)
        probability = group_dist[group] * user_dist[user]

        if ((random.random() < probability * 10000) and
           (group > existing_groups or user > existing_users)):
            # the + 1 converts the array index to the corresponding
            # group or user number
            assignments.add(((user + 1), (group + 1)))

    return assignments
1871
1872
def add_users_to_groups(db, instance_id, assignments):
    """Add users to their assigned groups.

    Takes the (user, group) number tuples generated by assign_groups and
    adds each user to the "member" attribute of its group. The time taken
    by each modify is logged as a tab separated timing record, in the same
    way as the records logged when accounts and groups are created.

    :param db: the database the modifies are sent to
    :param instance_id: instance id used to build the user, group and ou
                        names
    :param assignments: iterable of (user number, group number) tuples
    """

    ou = ou_name(db, instance_id)

    def build_dn(name):
        return "cn=%s,%s" % (name, ou)

    for (user, group) in assignments:
        user_dn  = build_dn(user_name(instance_id, user))
        group_dn = build_dn(group_name(instance_id, group))

        m = ldb.Message()
        m.dn = ldb.Dn(db, group_dn)
        m["member"] = ldb.MessageElement(user_dn, ldb.FLAG_MOD_ADD, "member")
        start = time.time()
        db.modify(m)
        end = time.time()
        duration = end - start
        # use the module logger, like the create_* functions, rather than
        # printing the timing record to stdout
        LOGGER.info("%f\t0\tadd\tuser\t%f\tTrue\t" % (end, duration))
1896
1897
def generate_stats(statsdir, timing_file):
    """Generate and print the summary stats for a run.

    Every file in statsdir is read as a series of tab separated records
    with the fields time, conv, protocol, type, duration, successful and
    error. Valid records are copied to timing_file and aggregated per
    protocol and packet type. Lines that cannot be parsed are printed to
    stderr and take no part in the statistics.

    The summary is printed to stdout: as aligned columns when stdout is a
    terminal, otherwise as tab separated values.

    :param statsdir: directory containing the stats files to read
    :param timing_file: file object that the header line and all valid
                        records are written to, or None to skip that
    """
    first      = sys.float_info.max
    last       = 0
    successful = 0
    failed     = 0
    latencies  = {}
    failures   = {}
    unique_conversations = set()

    if timing_file is not None:
        tw = timing_file.write
    else:
        def tw(x):
            pass

    tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n")

    for filename in os.listdir(statsdir):
        path = os.path.join(statsdir, filename)
        with open(path, 'r') as f:
            for line in f:
                # Parse the whole record before touching any of the
                # totals, so a malformed line is never partially counted.
                try:
                    fields       = line.rstrip('\n').split('\t')
                    timestamp    = float(fields[0])
                    conversation = fields[1]
                    protocol     = fields[2]
                    packet_type  = fields[3]
                    latency      = float(fields[4])
                    succeeded    = fields[5] == 'True'
                except (ValueError, IndexError):
                    # not a valid line print and ignore
                    print(line, file=sys.stderr)
                    continue

                # the record's time is when the operation finished
                first = min(timestamp - latency, first)
                last  = max(timestamp, last)

                latencies.setdefault(protocol, {}).setdefault(
                    packet_type, []).append(latency)

                protocol_failures = failures.setdefault(protocol, {})
                protocol_failures.setdefault(packet_type, 0)

                if succeeded:
                    successful += 1
                else:
                    failed += 1
                    protocol_failures[packet_type] += 1

                unique_conversations.add(conversation)
                tw(line)

    duration = last - first

    def per_second(count):
        """Rate over the whole run; 0 if there is nothing to divide."""
        if count == 0 or duration <= 0:
            return 0
        return count / duration

    print("Total conversations:   %10d" % len(unique_conversations))
    print("Successful operations: %10d (%.3f per second)"
          % (successful, per_second(successful)))
    print("Failed operations:     %10d (%.3f per second)"
          % (failed, per_second(failed)))

    print("Protocol    Op Code  Description                               "
          " Count       Failed         Mean       Median          "
          "95%        Range          Max")

    # isatty is a method: it has to be called, otherwise the test is
    # always true and the tab separated form is never used.
    if sys.stdout.isatty():
        row_format = ("%-12s   %4s  %-35s %12d %12d %12.6f "
                      "%12.6f %12.6f %12.6f %12.6f")
    else:
        row_format = "%s\t%s\t%s\t%d\t%d\t%f\t%f\t%f\t%f\t%f"

    for protocol in sorted(latencies):
        for packet_type in sorted(latencies[protocol], key=opcode_key):
            values      = sorted(latencies[protocol][packet_type])
            count       = len(values)
            type_failed = failures[protocol][packet_type]
            mean        = sum(values) / count
            median      = calc_percentile(values, 0.50)
            percentile  = calc_percentile(values, 0.95)
            rng         = values[-1] - values[0]
            maxv        = values[-1]
            desc        = OP_DESCRIPTIONS.get((protocol, packet_type), '')
            print(row_format % (protocol,
                                packet_type,
                                desc,
                                count,
                                type_failed,
                                mean,
                                median,
                                percentile,
                                rng,
                                maxv))
2016
2017
def opcode_key(v):
    """Sort key for the operation code to ensure that it sorts numerically.

    A value that int() can convert is returned as a string zero padded to
    three digits; any other value is returned unchanged.
    """
    try:
        return "%03d" % int(v)
    except (TypeError, ValueError):
        # not a number: sort on the raw value. Only conversion failures
        # are caught, so KeyboardInterrupt and SystemExit still propagate.
        return v
2024
2025
def calc_percentile(values, percentile):
    """Calculate the specified percentile from the list of values.

    Assumes the list is sorted in ascending order. percentile is a
    fraction, e.g. 0.95. When the percentile falls between two entries the
    result is interpolated linearly between them. An empty list gives 0.
    """
    if not values:
        return 0

    rank = (len(values) - 1) * percentile
    lower, upper = math.floor(rank), math.ceil(rank)
    if lower == upper:
        # the rank lands exactly on one entry
        return values[int(rank)]

    below = values[int(lower)] * (upper - rank)
    above = values[int(upper)] * (rank - lower)
    return below + above
2042
2043
def mk_masked_dir(*path):
    """Create a directory that group and others have no access to.

    In a testenv we end up with 0777 directories that look an alarming
    green colour with ls. Use umask to avoid that.

    :param path: path components, joined with os.path.join
    :return: the joined path of the new directory
    :raises OSError: if os.mkdir fails; the previous umask is restored
                     first
    """
    d = os.path.join(*path)
    mask = os.umask(0o077)
    try:
        os.mkdir(d)
    finally:
        # always put the previous umask back, even when mkdir fails
        os.umask(mask)
    return d