1 # -*- encoding: utf-8 -*-
2 # Samba traffic replay and learning
4 # Copyright (C) Catalyst IT Ltd. 2017
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 # GNU General Public License for more details.
16 # You should have received a copy of the GNU General Public License
17 # along with this program. If not, see <http://www.gnu.org/licenses/>.
19 from __future__ import print_function, division
28 from errno import ECHILD, ESRCH
30 from collections import OrderedDict, Counter, defaultdict, namedtuple
31 from samba.emulate import traffic_packets
32 from samba.samdb import SamDB
34 from ldb import LdbError
35 from samba.dcerpc import ClientConnection
36 from samba.dcerpc import security, drsuapi, lsa
37 from samba.dcerpc import netlogon
38 from samba.dcerpc.netlogon import netr_Authenticator
39 from samba.dcerpc import srvsvc
40 from samba.dcerpc import samr
41 from samba.drs_utils import drs_DsBind
43 from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
44 from samba.auth import system_session
45 from samba.dsdb import (
47 UF_SERVER_TRUST_ACCOUNT,
48 UF_TRUSTED_FOR_DELEGATION,
49 UF_WORKSTATION_TRUST_ACCOUNT
51 from samba.dcerpc.misc import SEC_CHAN_BDC
52 from samba import gensec
53 from samba import sd_utils
54 from samba.compat import get_string
55 from samba.logger import get_samba_logger
58 CURRENT_MODEL_VERSION = 2 # save as this
59 REQUIRED_MODEL_VERSION = 2 # load accepts this or greater
62 # we don't use None, because it complicates [de]serialisation
66 ('dns', '0'): 1.0, # query
67 ('smb', '0x72'): 1.0, # Negotiate protocol
68 ('ldap', '0'): 1.0, # bind
69 ('ldap', '3'): 1.0, # searchRequest
70 ('ldap', '2'): 1.0, # unbindRequest
72 ('dcerpc', '11'): 1.0, # bind
73 ('dcerpc', '14'): 1.0, # Alter_context
74 ('nbns', '0'): 1.0, # query
78 ('dns', '1'): 1.0, # response
79 ('ldap', '1'): 1.0, # bind response
80 ('ldap', '4'): 1.0, # search result
81 ('ldap', '5'): 1.0, # search done
83 ('dcerpc', '12'): 1.0, # bind_ack
84 ('dcerpc', '13'): 1.0, # bind_nak
85 ('dcerpc', '15'): 1.0, # Alter_context response
88 SKIPPED_PROTOCOLS = {"smb", "smb2", "browser", "smb_netlogon"}
91 WAIT_THRESHOLD = (1.0 / WAIT_SCALE)
92 NO_WAIT_LOG_TIME_RANGE = (-10, -3)
94 # DEBUG_LEVEL can be changed by scripts with -d
97 LOGGER = get_samba_logger(name=__name__)
def debug(level, msg, *args):
    """Write a (possibly %-formatted) debug message to standard error.

    :param level: the message is emitted only when level <= DEBUG_LEVEL
                  (DEBUG_LEVEL can be changed by scripts with -d)
    :param msg: the message; may contain C-style format specifiers
    :param args: the values consumed by the format specifiers in msg
    """
    if level > DEBUG_LEVEL:
        return
    text = msg % tuple(args) if args else msg
    print(text, file=sys.stderr)
118 def debug_lineno(*args):
119 """ Print an unformatted log message to stderr, contaning the line number
121 tb = traceback.extract_stack(limit=2)
122 print((" %s:" "\033[01;33m"
123 "%s " "\033[00m" % (tb[0][2], tb[0][1])), end=' ',
126 print(a, file=sys.stderr)
127 print(file=sys.stderr)
131 def random_colour_print(seeds):
132 """Return a function that prints a coloured line to stderr. The colour
133 of the line depends on a sort of hash of the integer arguments."""
140 prefix = "\033[38;5;%dm" % (18 + s)
145 print("%s%s\033[00m" % (prefix, a), file=sys.stderr)
150 print(a, file=sys.stderr)
155 class FakePacketError(Exception):
159 class Packet(object):
160 """Details of a network packet"""
161 __slots__ = ('timestamp',
171 def __init__(self, timestamp, ip_protocol, stream_number, src, dest,
172 protocol, opcode, desc, extra):
173 self.timestamp = timestamp
174 self.ip_protocol = ip_protocol
175 self.stream_number = stream_number
178 self.protocol = protocol
182 if self.src < self.dest:
183 self.endpoints = (self.src, self.dest)
185 self.endpoints = (self.dest, self.src)
188 def from_line(cls, line):
189 fields = line.rstrip('\n').split('\t')
200 timestamp = float(timestamp)
204 return cls(timestamp, ip_protocol, stream_number, src, dest,
205 protocol, opcode, desc, extra)
207 def as_summary(self, time_offset=0.0):
208 """Format the packet as a traffic_summary line.
210 extra = '\t'.join(self.extra)
211 t = self.timestamp + time_offset
212 return (t, '%f\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' %
215 self.stream_number or '',
224 return ("%.3f: %d -> %d; ip %s; strm %s; prot %s; op %s; desc %s %s" %
225 (self.timestamp, self.src, self.dest, self.ip_protocol or '-',
226 self.stream_number, self.protocol, self.opcode, self.desc,
227 ('«' + ' '.join(self.extra) + '»' if self.extra else '')))
230 return "<Packet @%s>" % self
233 return self.__class__(self.timestamp,
243 def as_packet_type(self):
244 t = '%s:%s' % (self.protocol, self.opcode)
247 def client_score(self):
248 """A positive number means we think it is a client; a negative number
249 means we think it is a server. Zero means no idea. range: -1 to 1.
251 key = (self.protocol, self.opcode)
252 if key in CLIENT_CLUES:
253 return CLIENT_CLUES[key]
254 if key in SERVER_CLUES:
255 return -SERVER_CLUES[key]
258 def play(self, conversation, context):
259 """Send the packet over the network, if required.
261 Some packets are ignored, i.e. for protocols not handled,
262 server response messages, or messages that are generated by the
263 protocol layer associated with other packets.
265 fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
267 fn = getattr(traffic_packets, fn_name)
269 except AttributeError as e:
270 print("Conversation(%s) Missing handler %s" %
271 (conversation.conversation_id, fn_name),
275 # Don't display a message for kerberos packets, they're not directly
276 # generated they're used to indicate kerberos should be used
277 if self.protocol != "kerberos":
278 debug(2, "Conversation(%s) Calling handler %s" %
279 (conversation.conversation_id, fn_name))
283 if fn(self, conversation, context):
284 # Only collect timing data for functions that generate
285 # network traffic, or fail
287 duration = end - start
288 print("%f\t%s\t%s\t%s\t%f\tTrue\t" %
289 (end, conversation.conversation_id, self.protocol,
290 self.opcode, duration))
291 except Exception as e:
293 duration = end - start
294 print("%f\t%s\t%s\t%s\t%f\tFalse\t%s" %
295 (end, conversation.conversation_id, self.protocol,
296 self.opcode, duration, e))
298 def __cmp__(self, other):
299 return self.timestamp - other.timestamp
301 def is_really_a_packet(self, missing_packet_stats=None):
302 return is_a_real_packet(self.protocol, self.opcode)
def is_a_real_packet(protocol, opcode):
    """Is the packet one that can be ignored?

    If so removing it will have no effect on the replay.

    :param protocol: the packet's protocol name (e.g. 'ldap')
    :param opcode: the packet's opcode, as a string
    :return: True if replaying the packet would do real work
    """
    if protocol in SKIPPED_PROTOCOLS:
        # Ignore any packets for the protocols we're not interested in.
        return False
    if protocol == "ldap" and opcode == '':
        # skip ldap continuation packets
        return False

    fn_name = 'packet_%s_%s' % (protocol, opcode)
    fn = getattr(traffic_packets, fn_name, None)
    if fn is None:
        # BUG FIX: Logger.debug() takes no file= keyword (that belongs
        # to print()); passing it raised TypeError whenever a handler
        # was missing.  Also use lazy %-args rather than pre-formatting.
        LOGGER.debug("missing packet %s", fn_name)
        return False
    if fn is traffic_packets.null_packet:
        # an explicit no-op handler: not a real packet
        return False
    return True
327 def is_a_traffic_generating_packet(protocol, opcode):
328     """Return true if a packet generates traffic in its own right. Some of
329     these will generate traffic in certain contexts (e.g. ldap unbind
330     after a bind) but not if the conversation consists only of these packets.
# 'wait' pseudo-packets only pace the replay; they never touch the network.
332     if protocol == 'wait':
# NOTE(review): the (protocol, opcode) pairs tested here are elided in
# this view -- presumably the piggy-backing packets (e.g. kerberos
# markers, ldap unbind); confirm against the full source before editing.
335     if (protocol, opcode) in (
# Anything else counts iff it maps onto a real replay handler.
342     return is_a_real_packet(protocol, opcode)
345 class ReplayContext(object):
346     """State/Context for a conversation between a simulated client and a
347     server. Some of the context is shared amongst all conversations
348     and should be generated before the fork, while other context is
349     specific to a particular conversation and should be generated
350     *after* the fork, in generate_process_local_config().
356 badpassword_frequency=None,
357 prefer_kerberos=None,
362 domain=os.environ.get("DOMAIN"),
365 self.netlogon_connection = None
368 self.prefer_kerberos = prefer_kerberos
370 self.base_dn = base_dn
372 self.statsdir = statsdir
373 self.global_tempdir = tempdir
374 self.domain_sid = domain_sid
375 self.realm = lp.get('realm')
377 # Bad password attempt controls
378 self.badpassword_frequency = badpassword_frequency
379 self.last_lsarpc_bad = False
380 self.last_lsarpc_named_bad = False
381 self.last_simple_bind_bad = False
382 self.last_bind_bad = False
383 self.last_srvsvc_bad = False
384 self.last_drsuapi_bad = False
385 self.last_netlogon_bad = False
386 self.last_samlogon_bad = False
387 self.generate_ldap_search_tables()
389 def generate_ldap_search_tables(self):
390 session = system_session()
392 db = SamDB(url="ldap://%s" % self.server,
393 session_info=session,
394 credentials=self.creds,
397 res = db.search(db.domain_dn(),
398 scope=ldb.SCOPE_SUBTREE,
399 controls=["paged_results:1:1000"],
402 # find a list of dns for each pattern
403 # e.g. CN,CN,CN,DC,DC
405 attribute_clue_map = {
411 pattern = ','.join(x.lstrip()[:2] for x in dn.split(',')).upper()
412 dns = dn_map.setdefault(pattern, [])
414 if dn.startswith('CN=NTDS Settings,'):
415 attribute_clue_map['invocationId'].append(dn)
417 # extend the map in case we are working with a different
418 # number of DC components.
419 # for k, v in self.dn_map.items():
420 # print >>sys.stderr, k, len(v)
422 for k in list(dn_map.keys()):
426 while p[-3:] == ',DC':
430 if p != k and p in dn_map:
431 print('dn_map collison %s %s' % (k, p),
434 dn_map[p] = dn_map[k]
437 self.attribute_clue_map = attribute_clue_map
439 def generate_process_local_config(self, account, conversation):
440 self.ldap_connections = []
441 self.dcerpc_connections = []
442 self.lsarpc_connections = []
443 self.lsarpc_connections_named = []
444 self.drsuapi_connections = []
445 self.srvsvc_connections = []
446 self.samr_contexts = []
447 self.netbios_name = account.netbios_name
448 self.machinepass = account.machinepass
449 self.username = account.username
450 self.userpass = account.userpass
452 self.tempdir = mk_masked_dir(self.global_tempdir,
454 conversation.conversation_id)
456 self.lp.set("private dir", self.tempdir)
457 self.lp.set("lock dir", self.tempdir)
458 self.lp.set("state directory", self.tempdir)
459 self.lp.set("tls verify peer", "no_check")
461 self.remoteAddress = "/root/ncalrpc_as_system"
462 self.samlogon_dn = ("cn=%s,%s" %
463 (self.netbios_name, self.ou))
464 self.user_dn = ("cn=%s,%s" %
465 (self.username, self.ou))
467 self.generate_machine_creds()
468 self.generate_user_creds()
470 def with_random_bad_credentials(self, f, good, bad, failed_last_time):
471 """Execute the supplied logon function, randomly choosing the
474 Based on the frequency in badpassword_frequency randomly perform the
475 function with the supplied bad credentials.
476 If run with bad credentials, the function is re-run with the good
478 failed_last_time is used to prevent consecutive bad credential
479 attempts. So the over all bad credential frequency will be lower
480 than that requested, but not significantly.
482 if not failed_last_time:
483 if (self.badpassword_frequency and
484 random.random() < self.badpassword_frequency):
488 # Ignore any exceptions as the operation may fail
489 # as it's being performed with bad credentials
491 failed_last_time = True
493 failed_last_time = False
496 return (result, failed_last_time)
498 def generate_user_creds(self):
499 """Generate the conversation specific user Credentials.
501 Each Conversation has an associated user account used to simulate
502 any non Administrative user traffic.
504 Generates user credentials with good and bad passwords and ldap
505 simple bind credentials with good and bad passwords.
507 self.user_creds = Credentials()
508 self.user_creds.guess(self.lp)
509 self.user_creds.set_workstation(self.netbios_name)
510 self.user_creds.set_password(self.userpass)
511 self.user_creds.set_username(self.username)
512 self.user_creds.set_domain(self.domain)
513 if self.prefer_kerberos:
514 self.user_creds.set_kerberos_state(MUST_USE_KERBEROS)
516 self.user_creds.set_kerberos_state(DONT_USE_KERBEROS)
518 self.user_creds_bad = Credentials()
519 self.user_creds_bad.guess(self.lp)
520 self.user_creds_bad.set_workstation(self.netbios_name)
521 self.user_creds_bad.set_password(self.userpass[:-4])
522 self.user_creds_bad.set_username(self.username)
523 if self.prefer_kerberos:
524 self.user_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
526 self.user_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
528 # Credentials for ldap simple bind.
529 self.simple_bind_creds = Credentials()
530 self.simple_bind_creds.guess(self.lp)
531 self.simple_bind_creds.set_workstation(self.netbios_name)
532 self.simple_bind_creds.set_password(self.userpass)
533 self.simple_bind_creds.set_username(self.username)
534 self.simple_bind_creds.set_gensec_features(
535 self.simple_bind_creds.get_gensec_features() | gensec.FEATURE_SEAL)
536 if self.prefer_kerberos:
537 self.simple_bind_creds.set_kerberos_state(MUST_USE_KERBEROS)
539 self.simple_bind_creds.set_kerberos_state(DONT_USE_KERBEROS)
540 self.simple_bind_creds.set_bind_dn(self.user_dn)
542 self.simple_bind_creds_bad = Credentials()
543 self.simple_bind_creds_bad.guess(self.lp)
544 self.simple_bind_creds_bad.set_workstation(self.netbios_name)
545 self.simple_bind_creds_bad.set_password(self.userpass[:-4])
546 self.simple_bind_creds_bad.set_username(self.username)
547 self.simple_bind_creds_bad.set_gensec_features(
548 self.simple_bind_creds_bad.get_gensec_features() |
550 if self.prefer_kerberos:
551 self.simple_bind_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
553 self.simple_bind_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
554 self.simple_bind_creds_bad.set_bind_dn(self.user_dn)
556 def generate_machine_creds(self):
557 """Generate the conversation specific machine Credentials.
559 Each Conversation has an associated machine account.
561 Generates machine credentials with good and bad passwords.
564 self.machine_creds = Credentials()
565 self.machine_creds.guess(self.lp)
566 self.machine_creds.set_workstation(self.netbios_name)
567 self.machine_creds.set_secure_channel_type(SEC_CHAN_BDC)
568 self.machine_creds.set_password(self.machinepass)
569 self.machine_creds.set_username(self.netbios_name + "$")
570 self.machine_creds.set_domain(self.domain)
571 if self.prefer_kerberos:
572 self.machine_creds.set_kerberos_state(MUST_USE_KERBEROS)
574 self.machine_creds.set_kerberos_state(DONT_USE_KERBEROS)
576 self.machine_creds_bad = Credentials()
577 self.machine_creds_bad.guess(self.lp)
578 self.machine_creds_bad.set_workstation(self.netbios_name)
579 self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_BDC)
580 self.machine_creds_bad.set_password(self.machinepass[:-4])
581 self.machine_creds_bad.set_username(self.netbios_name + "$")
582 if self.prefer_kerberos:
583 self.machine_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
585 self.machine_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
587 def get_matching_dn(self, pattern, attributes=None):
588 # If the pattern is an empty string, we assume ROOTDSE,
589 # Otherwise we try adding or removing DC suffixes, then
590 # shorter leading patterns until we hit one.
591 # e.g if there is no CN,CN,CN,CN,DC,DC
592 # we first try CN,CN,CN,CN,DC
593 # and CN,CN,CN,CN,DC,DC,DC
594 # then change to CN,CN,CN,DC,DC
595 # and as last resort we use the base_dn
596 attr_clue = self.attribute_clue_map.get(attributes)
598 return random.choice(attr_clue)
600 pattern = pattern.upper()
602 if pattern in self.dn_map:
603 return random.choice(self.dn_map[pattern])
604 # chop one off the front and try it all again.
605 pattern = pattern[3:]
609 def get_dcerpc_connection(self, new=False):
610 guid = '12345678-1234-abcd-ef00-01234567cffb' # RPC_NETLOGON UUID
611 if self.dcerpc_connections and not new:
612 return self.dcerpc_connections[-1]
613 c = ClientConnection("ncacn_ip_tcp:%s" % self.server,
615 self.dcerpc_connections.append(c)
618 def get_srvsvc_connection(self, new=False):
619 if self.srvsvc_connections and not new:
620 return self.srvsvc_connections[-1]
623 return srvsvc.srvsvc("ncacn_np:%s" % (self.server),
627 (c, self.last_srvsvc_bad) = \
628 self.with_random_bad_credentials(connect,
631 self.last_srvsvc_bad)
633 self.srvsvc_connections.append(c)
636 def get_lsarpc_connection(self, new=False):
637 if self.lsarpc_connections and not new:
638 return self.lsarpc_connections[-1]
641 binding_options = 'schannel,seal,sign'
642 return lsa.lsarpc("ncacn_ip_tcp:%s[%s]" %
643 (self.server, binding_options),
647 (c, self.last_lsarpc_bad) = \
648 self.with_random_bad_credentials(connect,
650 self.machine_creds_bad,
651 self.last_lsarpc_bad)
653 self.lsarpc_connections.append(c)
656 def get_lsarpc_named_pipe_connection(self, new=False):
657 if self.lsarpc_connections_named and not new:
658 return self.lsarpc_connections_named[-1]
661 return lsa.lsarpc("ncacn_np:%s" % (self.server),
665 (c, self.last_lsarpc_named_bad) = \
666 self.with_random_bad_credentials(connect,
668 self.machine_creds_bad,
669 self.last_lsarpc_named_bad)
671 self.lsarpc_connections_named.append(c)
674 def get_drsuapi_connection_pair(self, new=False, unbind=False):
675 """get a (drs, drs_handle) tuple"""
676 if self.drsuapi_connections and not new:
677 c = self.drsuapi_connections[-1]
681 binding_options = 'seal'
682 binding_string = "ncacn_ip_tcp:%s[%s]" %\
683 (self.server, binding_options)
684 return drsuapi.drsuapi(binding_string, self.lp, creds)
686 (drs, self.last_drsuapi_bad) = \
687 self.with_random_bad_credentials(connect,
690 self.last_drsuapi_bad)
692 (drs_handle, supported_extensions) = drs_DsBind(drs)
693 c = (drs, drs_handle)
694 self.drsuapi_connections.append(c)
697 def get_ldap_connection(self, new=False, simple=False):
698 if self.ldap_connections and not new:
699 return self.ldap_connections[-1]
701 def simple_bind(creds):
703 To run simple bind against Windows, we need to run
704 following commands in PowerShell:
706 Install-windowsfeature ADCS-Cert-Authority
707 Install-AdcsCertificationAuthority -CAType EnterpriseRootCA
711 return SamDB('ldaps://%s' % self.server,
715 def sasl_bind(creds):
716 return SamDB('ldap://%s' % self.server,
720 (samdb, self.last_simple_bind_bad) = \
721 self.with_random_bad_credentials(simple_bind,
722 self.simple_bind_creds,
723 self.simple_bind_creds_bad,
724 self.last_simple_bind_bad)
726 (samdb, self.last_bind_bad) = \
727 self.with_random_bad_credentials(sasl_bind,
732 self.ldap_connections.append(samdb)
735 def get_samr_context(self, new=False):
736 if not self.samr_contexts or new:
737 self.samr_contexts.append(
738 SamrContext(self.server, lp=self.lp, creds=self.creds))
739 return self.samr_contexts[-1]
741 def get_netlogon_connection(self):
743 if self.netlogon_connection:
744 return self.netlogon_connection
747 return netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" %
751 (c, self.last_netlogon_bad) = \
752 self.with_random_bad_credentials(connect,
754 self.machine_creds_bad,
755 self.last_netlogon_bad)
756 self.netlogon_connection = c
759 def guess_a_dns_lookup(self):
760 return (self.realm, 'A')
762 def get_authenticator(self):
763 auth = self.machine_creds.new_client_authenticator()
764 current = netr_Authenticator()
765 current.cred.data = [x if isinstance(x, int) else ord(x)
766 for x in auth["credential"]]
767 current.timestamp = auth["timestamp"]
769 subsequent = netr_Authenticator()
770 return (current, subsequent)
773 class SamrContext(object):
774 """State/Context associated with a samr connection.
776 def __init__(self, server, lp=None, creds=None):
777 self.connection = None
779 self.domain_handle = None
780 self.domain_sid = None
781 self.group_handle = None
782 self.user_handle = None
788 def get_connection(self):
789 if not self.connection:
790 self.connection = samr.samr(
791 "ncacn_ip_tcp:%s[seal]" % (self.server),
793 credentials=self.creds)
795 return self.connection
797 def get_handle(self):
799 c = self.get_connection()
800 self.handle = c.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
804 class Conversation(object):
805     """Details of a conversation between a simulated client and a server."""
806 def __init__(self, start_time=None, endpoints=None, seq=(),
807 conversation_id=None):
808 self.start_time = start_time
809 self.endpoints = endpoints
811 self.msg = random_colour_print(endpoints)
812 self.client_balance = 0.0
813 self.conversation_id = conversation_id
815 self.add_short_packet(*p)
817 def __cmp__(self, other):
818 if self.start_time is None:
819 if other.start_time is None:
822 if other.start_time is None:
824 return self.start_time - other.start_time
826 def add_packet(self, packet):
827 """Add a packet object to this conversation, making a local copy with
828 a conversation-relative timestamp."""
831 if self.start_time is None:
832 self.start_time = p.timestamp
834 if self.endpoints is None:
835 self.endpoints = p.endpoints
837 if p.endpoints != self.endpoints:
838 raise FakePacketError("Conversation endpoints %s don't match"
839 "packet endpoints %s" %
840 (self.endpoints, p.endpoints))
842 p.timestamp -= self.start_time
844 if p.src == p.endpoints[0]:
845 self.client_balance -= p.client_score()
847 self.client_balance += p.client_score()
849 if p.is_really_a_packet():
850 self.packets.append(p)
852 def add_short_packet(self, timestamp, protocol, opcode, extra,
854 """Create a packet from a timestamp, and 'protocol:opcode' pair, and a
855 (possibly empty) list of extra data. If client is True, assume
856 this packet is from the client to the server.
858 src, dest = self.guess_client_server()
860 src, dest = dest, src
861 key = (protocol, opcode)
862 desc = OP_DESCRIPTIONS[key] if key in OP_DESCRIPTIONS else ''
863 if protocol in IP_PROTOCOLS:
864 ip_protocol = IP_PROTOCOLS[protocol]
867 packet = Packet(timestamp - self.start_time, ip_protocol,
869 protocol, opcode, desc, extra)
870 # XXX we're assuming the timestamp is already adjusted for
872 # XXX should we adjust client balance for guessed packets?
873 if packet.src == packet.endpoints[0]:
874 self.client_balance -= packet.client_score()
876 self.client_balance += packet.client_score()
877 if packet.is_really_a_packet():
878 self.packets.append(packet)
881 return ("<Conversation %s %s starting %.3f %d packets>" %
882 (self.conversation_id, self.endpoints, self.start_time,
888 return iter(self.packets)
891 return len(self.packets)
893 def get_duration(self):
894 if len(self.packets) < 2:
896 return self.packets[-1].timestamp - self.packets[0].timestamp
898 def replay_as_summary_lines(self):
900 for p in self.packets:
901 lines.append(p.as_summary(self.start_time))
904 def replay_with_delay(self, start, context=None, account=None):
905 """Replay the conversation at the right time.
906 (We're already in a fork)."""
907 # first we sleep until the first packet
909 now = time.time() - start
911 sleep_time = gap - SLEEP_OVERHEAD
913 time.sleep(sleep_time)
915 miss = (time.time() - start) - t
916 self.msg("starting %s [miss %.3f]" % (self, miss))
920 # packet times are relative to conversation start
921 p_start = time.time()
922 for p in self.packets:
923 now = time.time() - p_start
924 gap = now - p.timestamp
928 sleep_time = -gap - SLEEP_OVERHEAD
930 time.sleep(sleep_time)
931 t = time.time() - p_start
932 if t - p.timestamp > max_sleep_miss:
933 max_sleep_miss = t - p.timestamp
935 p.play(self, context)
937 return max_gap, miss, max_sleep_miss
939     def guess_client_server(self, server_clue=None):
940         """Have a go at deciding who is the server and who is the client.
941         returns (client, server)
# client_balance is accumulated from per-packet client_score(): negative
# means endpoint a looked like the client more often than b did.
943         a, b = self.endpoints
945         if self.client_balance < 0:
948         # in the absence of a clue, we will fall through to assuming
949         # the lowest number is the server (which is usually true).
951         if self.client_balance == 0 and server_clue == b:
956     def forget_packets_outside_window(self, s, e):
957         """Prune any packets outside the time window we're interested in
959         :param s: start of the window
960         :param e: end of the window
962         self.packets = [p for p in self.packets if s <= p.timestamp <= e]
# start_time becomes the first surviving packet's timestamp, or None
# when the pruning removed everything.
963         self.start_time = self.packets[0].timestamp if self.packets else None
965 def renormalise_times(self, start_time):
966 """Adjust the packet start times relative to the new start time."""
967 for p in self.packets:
968 p.timestamp -= start_time
970 if self.start_time is not None:
971 self.start_time -= start_time
974 class DnsHammer(Conversation):
975 """A lightweight conversation that generates a lot of dns:0 packets on
978 def __init__(self, dns_rate, duration):
979 n = int(dns_rate * duration)
980 self.times = [random.uniform(0, duration) for i in range(n)]
983 self.duration = duration
985 self.msg = random_colour_print()
988 return ("<DnsHammer %d packets over %.1fs (rate %.2f)>" %
989 (len(self.times), self.duration, self.rate))
991 def replay(self, context=None):
993 fn = traffic_packets.packet_dns_0
995 now = time.time() - start
997 sleep_time = gap - SLEEP_OVERHEAD
999 time.sleep(sleep_time)
1001 packet_start = time.time()
1003 fn(None, None, context)
1005 duration = end - packet_start
1006 print("%f\tDNS\tdns\t0\t%f\tTrue\t" % (end, duration))
1007 except Exception as e:
1009 duration = end - packet_start
1010 print("%f\tDNS\tdns\t0\t%f\tFalse\t%s" % (end, duration, e))
1013 def ingest_summaries(files, dns_mode='count'):
1014 """Load a summary traffic summary file and generated Converations from it.
1017 dns_counts = defaultdict(int)
1020 if isinstance(f, str):
1022 print("Ingesting %s" % (f.name,), file=sys.stderr)
1024 p = Packet.from_line(line)
1025 if p.protocol == 'dns' and dns_mode != 'include':
1026 dns_counts[p.opcode] += 1
1035 start_time = min(p.timestamp for p in packets)
1036 last_packet = max(p.timestamp for p in packets)
1038 print("gathering packets into conversations", file=sys.stderr)
1039 conversations = OrderedDict()
1040 for i, p in enumerate(packets):
1041 p.timestamp -= start_time
1042 c = conversations.get(p.endpoints)
1044 c = Conversation(conversation_id=(i + 2))
1045 conversations[p.endpoints] = c
1048 # We only care about conversations with actual traffic, so we
1049 # filter out conversations with nothing to say. We do that here,
1050 # rather than earlier, because those empty packets contain useful
1051 # hints as to which end of the conversation was the client.
1052 conversation_list = []
1053 for c in conversations.values():
1055 conversation_list.append(c)
1057 # This is obviously not correct, as many conversations will appear
1058 # to start roughly simultaneously at the beginning of the snapshot.
1059 # To which we say: oh well, so be it.
1060 duration = float(last_packet - start_time)
1061 mean_interval = len(conversations) / duration
1063 return conversation_list, mean_interval, duration, dns_counts
def guess_server_address(conversations):
    """Guess which endpoint is the server.

    The endpoint that turns up most often across all the conversations
    is assumed to be the server; the busiest (address, count) pair from
    the tally is returned (or None when there is nothing to count).
    """
    tally = Counter()
    for conv in conversations:
        tally.update(conv.endpoints)
    if tally:
        return tally.most_common(1)[0]
def stringify_keys(x):
    """Return a copy of dict *x* with each tuple key joined into a
    single tab-separated string (so the dict can be JSON-serialised).
    The inverse operation is unstringify_keys()."""
    return {'\t'.join(k): v for k, v in x.items()}
def unstringify_keys(x):
    """Return a copy of dict *x* with each tab-separated string key
    split back into a tuple (the inverse of stringify_keys())."""
    return {tuple(str(k).split('\t')): v for k, v in x.items()}
1091 class TrafficModel(object):
1092 def __init__(self, n=3):
1094 self.query_details = {}
1096 self.dns_opcounts = defaultdict(int)
1097 self.cumulative_duration = 0.0
1098 self.packet_rate = [0, 1]
1100 def learn(self, conversations, dns_opcounts={}):
1103 key = (NON_PACKET,) * (self.n - 1)
1105 server = guess_server_address(conversations)
1107 for k, v in dns_opcounts.items():
1108 self.dns_opcounts[k] += v
1110 if len(conversations) > 1:
1111 first = conversations[0].start_time
1114 for c in conversations:
1116 last = max(last, c.packets[-1].timestamp)
1118 self.packet_rate[0] = total
1119 self.packet_rate[1] = last - first
1121 for c in conversations:
1122 client, server = c.guess_client_server(server)
1123 cum_duration += c.get_duration()
1124 key = (NON_PACKET,) * (self.n - 1)
1129 elapsed = p.timestamp - prev
1131 if elapsed > WAIT_THRESHOLD:
1132 # add the wait as an extra state
1133 wait = 'wait:%d' % (math.log(max(1.0,
1134 elapsed * WAIT_SCALE)))
1135 self.ngrams.setdefault(key, []).append(wait)
1136 key = key[1:] + (wait,)
1138 short_p = p.as_packet_type()
1139 self.query_details.setdefault(short_p,
1140 []).append(tuple(p.extra))
1141 self.ngrams.setdefault(key, []).append(short_p)
1142 key = key[1:] + (short_p,)
1144 self.cumulative_duration += cum_duration
1146 self.ngrams.setdefault(key, []).append(NON_PACKET)
1150 for k, v in self.ngrams.items():
1152 ngrams[k] = dict(Counter(v))
1155 for k, v in self.query_details.items():
1156 query_details[k] = dict(Counter('\t'.join(x) if x else '-'
1161 'query_details': query_details,
1162 'cumulative_duration': self.cumulative_duration,
1163 'packet_rate': self.packet_rate,
1164 'version': CURRENT_MODEL_VERSION
1166 d['dns'] = self.dns_opcounts
1168 if isinstance(f, str):
1171 json.dump(d, f, indent=2)
1174 if isinstance(f, str):
1180 version = d["version"]
1181 if version < REQUIRED_MODEL_VERSION:
1182 raise ValueError("the model file is version %d; "
1183 "version %d is required" %
1184 (version, REQUIRED_MODEL_VERSION))
1186 raise ValueError("the model file lacks a version number; "
1187 "version %d is required" %
1188 (REQUIRED_MODEL_VERSION))
1190 for k, v in d['ngrams'].items():
1191 k = tuple(str(k).split('\t'))
1192 values = self.ngrams.setdefault(k, [])
1193 for p, count in v.items():
1194 values.extend([str(p)] * count)
1197 for k, v in d['query_details'].items():
1198 values = self.query_details.setdefault(str(k), [])
1199 for p, count in v.items():
1201 values.extend([()] * count)
1203 values.extend([tuple(str(p).split('\t'))] * count)
1207 for k, v in d['dns'].items():
1208 self.dns_opcounts[k] += v
1210 self.cumulative_duration = d['cumulative_duration']
1211 self.packet_rate = d['packet_rate']
1213 def construct_conversation_sequence(self, timestamp=0.0,
1217 """Construct an individual conversation packet sequence from the
1221 key = (NON_PACKET,) * (self.n - 1)
1222 if ignore_before is None:
1223 ignore_before = timestamp - 1
1226 p = random.choice(self.ngrams.get(key, (NON_PACKET,)))
1230 if p in self.query_details:
1231 extra = random.choice(self.query_details[p])
1235 protocol, opcode = p.split(':', 1)
1236 if protocol == 'wait':
1237 log_wait_time = int(opcode) + random.random()
1238 wait = math.exp(log_wait_time) / (WAIT_SCALE * replay_speed)
1241 log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
1242 wait = math.exp(log_wait) / replay_speed
1244 if hard_stop is not None and timestamp > hard_stop:
1246 if timestamp >= ignore_before:
1247 c.append((timestamp, protocol, opcode, extra))
1249 key = key[1:] + (p,)
    def generate_conversation_sequences(self, scale, duration, replay_speed=1):
        """Generate a list of conversation descriptions from the model."""

        # We run the simulation for ten times as long as our desired
        # duration, and take the section at the end.
        lead_in = 9 * duration
        # packet_rate is stored as (number of packets, timespan)
        rate_n, rate_t = self.packet_rate
        # how many packets a run at this scale should produce
        target_packets = int(duration * scale * rate_n / rate_t)
        while n_packets < target_packets:
            # conversations may begin long before the window of interest;
            # anything too early is dropped by the sequence constructor
            start = random.uniform(-lead_in, duration)
            c = self.construct_conversation_sequence(start,
                                                     replay_speed=replay_speed,
            # will these "packets" generate actual traffic?
            # some (e.g. ldap unbind) will not generate anything
            # if the previous packets are not there, and if the
            # conversation only has those it wastes a process doing nothing.
            for timestamp, protocol, opcode, extra in c:
                if is_a_traffic_generating_packet(protocol, opcode):
            conversations.append(c)
        print(("we have %d packets (target %d) in %d conversations at scale %f"
               % (n_packets, target_packets, len(conversations), scale)),
        conversations.sort()  # sorts by first element == start time
        return conversations
def seq_to_conversations(seq, server=1, client=2):
    """Wrap packet sequences in Conversation objects.

    Each non-empty sequence becomes one Conversation starting at the
    timestamp of its first packet, with (server, client) as the endpoint
    pair; the client id advances for every conversation created.
    """
    result = []
    for packets in seq:
        if not packets:
            # an empty sequence produces no conversation
            continue
        conv = Conversation(packets[0][0], (server, client), packets)
        client += 1
        result.append(conv)
    return result
1303 'rpc_netlogon': '06',
1304 'kerberos': '06', # ratio 16248:258
1315 'smb_netlogon': '11',
1321 ('browser', '0x01'): 'Host Announcement (0x01)',
1322 ('browser', '0x02'): 'Request Announcement (0x02)',
1323 ('browser', '0x08'): 'Browser Election Request (0x08)',
1324 ('browser', '0x09'): 'Get Backup List Request (0x09)',
1325 ('browser', '0x0c'): 'Domain/Workgroup Announcement (0x0c)',
1326 ('browser', '0x0f'): 'Local Master Announcement (0x0f)',
1327 ('cldap', '3'): 'searchRequest',
1328 ('cldap', '5'): 'searchResDone',
1329 ('dcerpc', '0'): 'Request',
1330 ('dcerpc', '11'): 'Bind',
1331 ('dcerpc', '12'): 'Bind_ack',
1332 ('dcerpc', '13'): 'Bind_nak',
1333 ('dcerpc', '14'): 'Alter_context',
1334 ('dcerpc', '15'): 'Alter_context_resp',
1335 ('dcerpc', '16'): 'AUTH3',
1336 ('dcerpc', '2'): 'Response',
1337 ('dns', '0'): 'query',
1338 ('dns', '1'): 'response',
1339 ('drsuapi', '0'): 'DsBind',
1340 ('drsuapi', '12'): 'DsCrackNames',
1341 ('drsuapi', '13'): 'DsWriteAccountSpn',
1342 ('drsuapi', '1'): 'DsUnbind',
1343 ('drsuapi', '2'): 'DsReplicaSync',
1344 ('drsuapi', '3'): 'DsGetNCChanges',
1345 ('drsuapi', '4'): 'DsReplicaUpdateRefs',
1346 ('epm', '3'): 'Map',
1347 ('kerberos', ''): '',
1348 ('ldap', '0'): 'bindRequest',
1349 ('ldap', '1'): 'bindResponse',
1350 ('ldap', '2'): 'unbindRequest',
1351 ('ldap', '3'): 'searchRequest',
1352 ('ldap', '4'): 'searchResEntry',
1353 ('ldap', '5'): 'searchResDone',
1354 ('ldap', ''): '*** Unknown ***',
1355 ('lsarpc', '14'): 'lsa_LookupNames',
1356 ('lsarpc', '15'): 'lsa_LookupSids',
1357 ('lsarpc', '39'): 'lsa_QueryTrustedDomainInfoBySid',
1358 ('lsarpc', '40'): 'lsa_SetTrustedDomainInfo',
1359 ('lsarpc', '6'): 'lsa_OpenPolicy',
1360 ('lsarpc', '76'): 'lsa_LookupSids3',
1361 ('lsarpc', '77'): 'lsa_LookupNames4',
1362 ('nbns', '0'): 'query',
1363 ('nbns', '1'): 'response',
1364 ('rpc_netlogon', '21'): 'NetrLogonDummyRoutine1',
1365 ('rpc_netlogon', '26'): 'NetrServerAuthenticate3',
1366 ('rpc_netlogon', '29'): 'NetrLogonGetDomainInfo',
1367 ('rpc_netlogon', '30'): 'NetrServerPasswordSet2',
1368 ('rpc_netlogon', '39'): 'NetrLogonSamLogonEx',
1369 ('rpc_netlogon', '40'): 'DsrEnumerateDomainTrusts',
1370 ('rpc_netlogon', '45'): 'NetrLogonSamLogonWithFlags',
1371 ('rpc_netlogon', '4'): 'NetrServerReqChallenge',
1372 ('samr', '0',): 'Connect',
1373 ('samr', '16'): 'GetAliasMembership',
1374 ('samr', '17'): 'LookupNames',
1375 ('samr', '18'): 'LookupRids',
1376 ('samr', '19'): 'OpenGroup',
1377 ('samr', '1'): 'Close',
1378 ('samr', '25'): 'QueryGroupMember',
1379 ('samr', '34'): 'OpenUser',
1380 ('samr', '36'): 'QueryUserInfo',
1381 ('samr', '39'): 'GetGroupsForUser',
1382 ('samr', '3'): 'QuerySecurity',
1383 ('samr', '5'): 'LookupDomain',
1384 ('samr', '64'): 'Connect5',
1385 ('samr', '6'): 'EnumDomains',
1386 ('samr', '7'): 'OpenDomain',
1387 ('samr', '8'): 'QueryDomainInfo',
1388 ('smb', '0x04'): 'Close (0x04)',
1389 ('smb', '0x24'): 'Locking AndX (0x24)',
1390 ('smb', '0x2e'): 'Read AndX (0x2e)',
1391 ('smb', '0x32'): 'Trans2 (0x32)',
1392 ('smb', '0x71'): 'Tree Disconnect (0x71)',
1393 ('smb', '0x72'): 'Negotiate Protocol (0x72)',
1394 ('smb', '0x73'): 'Session Setup AndX (0x73)',
1395 ('smb', '0x74'): 'Logoff AndX (0x74)',
1396 ('smb', '0x75'): 'Tree Connect AndX (0x75)',
1397 ('smb', '0xa2'): 'NT Create AndX (0xa2)',
1398 ('smb2', '0'): 'NegotiateProtocol',
1399 ('smb2', '11'): 'Ioctl',
1400 ('smb2', '14'): 'Find',
1401 ('smb2', '16'): 'GetInfo',
1402 ('smb2', '18'): 'Break',
1403 ('smb2', '1'): 'SessionSetup',
1404 ('smb2', '2'): 'SessionLogoff',
1405 ('smb2', '3'): 'TreeConnect',
1406 ('smb2', '4'): 'TreeDisconnect',
1407 ('smb2', '5'): 'Create',
1408 ('smb2', '6'): 'Close',
1409 ('smb2', '8'): 'Read',
1410 ('smb_netlogon', '0x12'): 'SAM LOGON request from client (0x12)',
1411 ('smb_netlogon', '0x17'): ('SAM Active Directory Response - '
1412 'user unknown (0x17)'),
1413 ('srvsvc', '16'): 'NetShareGetInfo',
1414 ('srvsvc', '21'): 'NetSrvGetInfo',
def expand_short_packet(p, timestamp, src, dest, extra):
    """Expand a short 'protocol:opcode' packet description into a full
    tab-separated traffic-summary line.

    The resulting line holds: timestamp, IP protocol number, an empty
    flags column, src, dest, protocol, opcode, a human readable
    description, then any extra per-packet fields.
    """
    protocol, opcode = p.split(':', 1)
    desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
    ip_protocol = IP_PROTOCOLS.get(protocol, '06')

    fields = [timestamp, ip_protocol, '', src, dest, protocol, opcode, desc]
    fields.extend(extra)
    return '\t'.join(fields)
def flushing_signal_handler(signal, frame):
    """Signal handler that closes standard out and error.

    Triggered by a sigterm; ensures that the log messages are flushed
    to disk and not lost before the process exits.
    """
    # closing flushes; exit immediately without running cleanup handlers
    sys.stderr.close()
    sys.stdout.close()
    os._exit(0)
def replay_seq_in_fork(cs, start, context, account, client_id, server_id=1):
    """Fork a new process and replay the conversation sequence."""
    # We will need to reseed the random number generator or all the
    # clients will end up using the same sequence of random
    # numbers. random.randint() is mixed in so the initial seed will
    # have an effect here.
    seed = client_id * 1000 + random.randint(0, 999)
    # flush our buffers so messages won't be written by both sides
    # we must never return, or we'll end up running parts of the
    # parent's clean-up code. So we work in a try...finally, and
    # try to print any exceptions.
        # the two endpoint ids identify the sides of this conversation
        endpoints = (server_id, client_id)
        c = Conversation(t, endpoints, seq=cs, conversation_id=client_id)
        # make sure a SIGTERM from the parent flushes our output first
        signal.signal(signal.SIGTERM, flushing_signal_handler)
        context.generate_process_local_config(account, c)
        # per-conversation stats go to a dedicated file under statsdir
        filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
        f = open(filename, 'w')
    except IOError as e:
        LOGGER.info("stdout closing failed with %s" % e)
        # sleep until just before the agreed global start time
        now = time.time() - start
        sleep_time = gap - SLEEP_OVERHEAD
        time.sleep(sleep_time)
        max_lag, start_lag, max_sleep_miss = c.replay_with_delay(start=start,
        print("Maximum lag: %f" % max_lag)
        print("Start lag: %f" % start_lag)
        print("Max sleep miss: %f" % max_sleep_miss)
        print(("EXCEPTION in child PID %d, conversation %s" % (os.getpid(), c)),
        traceback.print_exc(sys.stderr)
def dnshammer_in_fork(dns_rate, duration):
    """Run a DnsHammer in a forked child process."""
    # ensure logs are flushed if the parent kills the child with SIGTERM
    signal.signal(signal.SIGTERM, flushing_signal_handler)
    hammer = DnsHammer(dns_rate, duration)
    print(("EXCEPTION in child PID %d, the DNS hammer" % (os.getpid())),
    traceback.print_exc(sys.stderr)
def replay(conversation_seq,
           latency_timeout=1.0,
           stop_on_any_error=False,
    """Replay the conversation sequences against a server.

    One child process is forked per conversation (plus a DNS hammer
    child where configured); the parent reaps children until the
    duration expires, then kills stragglers with escalating signals
    (SIGTERM twice, then SIGKILL).
    """
    context = ReplayContext(server=host,
    # each conversation consumes one pre-generated account
    if len(accounts) < len(conversation_seq):
        raise ValueError(("we have %d accounts but %d conversations" %
                          (len(accounts), len(conversation_seq))))
    # Set the process group so that the calling scripts are not killed
    # when the forked child processes are killed.
    # we delay the start by a bit to allow all the forks to get up and
    delay = len(conversation_seq) * 0.02
    start = time.time() + delay
    if duration is None:
        # end slightly after the last packet of the last conversation
        # to start. Conversations other than the last could still be
        # going, but we don't care.
        duration = conversation_seq[-1][-1][0] + latency_timeout
    print("We will start in %.1f seconds" % delay,
    print("We will stop after %.1f seconds" % (duration + delay),
    print("runtime %.1f seconds" % duration,
    # give one second grace for packets to finish before killing begins
    end = start + duration + 1.0
    LOGGER.info("Replaying traffic for %u conversations over %d seconds"
                % (len(conversation_seq), duration))
        pid = dnshammer_in_fork(dns_rate, duration)
    # fork one child per conversation, remembering pid -> client id
    for i, cs in enumerate(conversation_seq):
        account = accounts[i]
        pid = replay_seq_in_fork(cs, start, context, account, client_id)
        children[pid] = client_id
    # HERE, we are past all the forks
    print("all forks done in %.1f seconds, waiting %.1f" %
          (t - start + delay, t - start),
    # reap children as they finish, until the end time is reached
    while time.time() < end and children:
            # non-blocking reap of any finished child
            pid, status = os.waitpid(-1, os.WNOHANG)
        except OSError as e:
            if e.errno != ECHILD:  # no child processes
        c = children.pop(pid, None)
            print(("process %d finished conversation %d;"
                   (pid, c, len(children))), file=sys.stderr)
            if stop_on_any_error and status != 0:
        print("EXCEPTION in parent", file=sys.stderr)
        traceback.print_exc()
        # kill whatever is left: 15 is SIGTERM (tried twice), 9 is SIGKILL
        for s in (15, 15, 9):
            print(("killing %d children with -%d" %
                   (len(children), s)), file=sys.stderr)
            for pid in children:
                except OSError as e:
                    if e.errno != ESRCH:  # don't fail if it has already died
            # give this round of signals a second to take effect
            end = time.time() + 1
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != ECHILD:
                c = children.pop(pid, None)
                    print("children is %s, no pid found" % children)
                    print(("kill -%d %d KILLED conversation; "
                           (s, pid, len(children))),
                if time.time() >= end:
            print("%d children are missing" % len(children),
        # there may be stragglers that were forked just as ^C was hit
        # and don't appear in the list of children. We can get them
        # with killpg, but that will also kill us, so this is^H^H would be
        # goodbye, except we cheat and pretend to use ^C (SIG_INTERRUPT),
        # so as not to have to fuss around writing signal handlers.
        except KeyboardInterrupt:
            print("ignoring fake ^C", file=sys.stderr)
def openLdb(host, creds, lp):
    """Open a SamDB LDAP connection to *host*.

    Uses the system session with the supplied credentials and loadparm
    context; paged searches are enabled so large result sets work.
    """
    session = system_session()
    samdb = SamDB(url="ldap://%s" % host,
                  session_info=session,
                  options=['modules:paged_searches'],
                  credentials=creds,
                  lp=lp)
    return samdb
def ou_name(ldb, instance_id):
    """Generate the DN of the ou that holds this instance's objects."""
    domain = ldb.domain_dn()
    return "ou=instance-%d,ou=traffic_replay,%s" % (instance_id, domain)
def create_ou(ldb, instance_id):
    """Create an ou, all created user and machine accounts will belong to it.

    This allows all the created resources to be cleaned up easily.
    """
    ou = ou_name(ldb, instance_id)
        # first ensure the parent ou=traffic_replay container exists;
        # splitting off the first RDN of the instance ou yields its parent
        ldb.add({"dn": ou.split(',', 1)[1],
                 "objectclass": "organizationalunit"})
    except LdbError as e:
        (status, _) = e.args
        # ignore already exists
                 "objectclass": "organizationalunit"})
    except LdbError as e:
        (status, _) = e.args
        # ignore already exists
# ConversationAccounts holds details of the machine and user accounts
# associated with a conversation.
#
# We use a named tuple to reduce shared memory usage.
ConversationAccounts = namedtuple('ConversationAccounts',
                                  ['netbios_name',
                                   'machinepass',
                                   'username',
                                   'userpass'])
def generate_replay_accounts(ldb, instance_id, number, password):
    """Generate a series of unique machine and user account names."""
    accounts = []
    for i in range(1, number + 1):
        # machine and user accounts are paired per conversation and
        # share the same password
        accounts.append(ConversationAccounts(machine_name(instance_id, i),
                                             password,
                                             user_name(instance_id, i),
                                             password))
    return accounts
def create_machine_account(ldb, instance_id, netbios_name, machinepass,
                           traffic_account=True):
    """Create a machine account via ldap."""
    ou = ou_name(ldb, instance_id)
    machine_dn = "cn=%s,%s" % (netbios_name, ou)
    # AD stores the password as a UTF-16-LE encoded, quoted string
    utf16pw = ('"%s"' % get_string(machinepass)).encode('utf-16-le')

    if traffic_account:
        # we set these bits for the machine account otherwise the replayed
        # traffic throws up NT_STATUS_NO_TRUST_SAM_ACCOUNT errors
        flags = str(UF_TRUSTED_FOR_DELEGATION |
                    UF_SERVER_TRUST_ACCOUNT)
    else:
        # a plain workstation account, used only to pad out the directory
        flags = str(UF_WORKSTATION_TRUST_ACCOUNT)

    ldb.add({
        "dn": machine_dn,
        "objectclass": "computer",
        "sAMAccountName": "%s$" % netbios_name,
        "userAccountControl": flags,
        "unicodePwd": utf16pw})
def create_user_account(ldb, instance_id, username, userpass):
    """Create a user account via ldap and let it write to itself."""
    ou = ou_name(ldb, instance_id)
    user_dn = "cn=%s,%s" % (username, ou)
    # AD stores the password as a UTF-16-LE encoded, quoted string
    encoded_pw = ('"%s"' % get_string(userpass)).encode('utf-16-le')
    ldb.add({
        "dn": user_dn,
        "objectclass": "user",
        "sAMAccountName": username,
        "userAccountControl": str(UF_NORMAL_ACCOUNT),
        "unicodePwd": encoded_pw
    })

    # grant user write permission to do things like write account SPN
    sdutils = sd_utils.SDUtils(ldb)
    sdutils.dacl_add_ace(user_dn, "(A;;WP;;;PS)")
def create_group(ldb, instance_id, name):
    """Create a group via ldap inside this instance's ou."""
    ou = ou_name(ldb, instance_id)
    group_dn = "cn=%s,%s" % (name, ou)
    ldb.add({
        "dn": group_dn,
        "objectclass": "group",
        "sAMAccountName": name,
    })
def user_name(instance_id, i):
    """Generate the i-th user account name for an instance.

    The name embeds both the instance id and the index, keeping names
    unique across instances.
    """
    name = "STGU-%d-%d" % (instance_id, i)
    return name
def search_objectclass(ldb, objectclass='user', attr='sAMAccountName'):
    """Search for objects of *objectclass*; return their *attr* values
    as a set of strings."""
    found = ldb.search(
        expression="(objectClass={})".format(objectclass),
        attrs=[attr])
    return {str(entry[attr]) for entry in found}
def generate_users(ldb, instance_id, number, password):
    """Add users to the server"""
    # skip users that already exist so reruns are cheap and idempotent
    existing_objects = search_objectclass(ldb, objectclass='user')
    for i in range(number, 0, -1):
        name = user_name(instance_id, i)
        if name not in existing_objects:
            create_user_account(ldb, instance_id, name, password)
            # periodic progress logging
            LOGGER.info("Created %u/%u users" % (users, number))
def machine_name(instance_id, i, traffic_account=True):
    """Generate a machine account name from an instance id.

    Traffic accounts (the default) correspond to a generated user and
    use different userAccountControl flags so replayed packets get
    processed correctly; they are named STGM-<instance>-<i>.

    Non-traffic accounts only pad the directory out to a semi-realistic
    network. They keep the default computer userAccountControl flags,
    so they get a different name (PC-<instance>-<i>) to ensure the
    packet generator never picks them.
    """
    if not traffic_account:
        return "PC-%d-%d" % (instance_id, i)
    return "STGM-%d-%d" % (instance_id, i)
def generate_machine_accounts(ldb, instance_id, number, password,
                              traffic_account=True):
    """Add machine accounts to the server"""
    # computer sAMAccountNames carry a trailing '$', hence the
    # name + "$" comparison below
    existing_objects = search_objectclass(ldb, objectclass='computer')
    for i in range(number, 0, -1):
        name = machine_name(instance_id, i, traffic_account)
        if name + "$" not in existing_objects:
            create_machine_account(ldb, instance_id, name, password,
            # periodic progress logging
            LOGGER.info("Created %u/%u machine accounts" % (added, number))
def group_name(instance_id, i):
    """Generate the i-th group name for an instance."""
    name = "STGG-%d-%d" % (instance_id, i)
    return name
def generate_groups(ldb, instance_id, number):
    """Create the required number of groups on the server.

    Groups that already exist are skipped, so reruns are idempotent;
    returns the number of groups actually created.
    """
    existing = search_objectclass(ldb, objectclass='group')
    created = 0
    for i in range(number, 0, -1):
        name = group_name(instance_id, i)
        if name not in existing:
            create_group(ldb, instance_id, name)
            created += 1
            # periodic progress logging
            if created % 1000 == 0:
                LOGGER.info("Created %u/%u groups" % (created, number))

    return created
def clean_up_accounts(ldb, instance_id):
    """Remove the created accounts and groups from the server."""
    ou = ou_name(ldb, instance_id)
        # tree_delete removes the instance ou and everything beneath it
        ldb.delete(ou, ["tree_delete:1"])
    except LdbError as e:
        (status, _) = e.args
        # ignore does not exist
def generate_users_and_groups(ldb, instance_id, password,
                              number_of_users, number_of_groups,
                              group_memberships, max_members,
                              machine_accounts, traffic_accounts=True):
    """Generate the required users and groups, allocating the users to
    memberships_added = 0
    # everything is created under this instance's ou so it can be
    # cleaned up in one tree delete later
    create_ou(ldb, instance_id)
    LOGGER.info("Generating dummy user accounts")
    users_added = generate_users(ldb, instance_id, number_of_users, password)
    LOGGER.info("Generating dummy machine accounts")
    computers_added = generate_machine_accounts(ldb, instance_id,
                                                machine_accounts, password,
    if number_of_groups > 0:
        LOGGER.info("Generating dummy groups")
        groups_added = generate_groups(ldb, instance_id, number_of_groups)
    if group_memberships > 0:
        LOGGER.info("Assigning users to groups")
        # plan the memberships first, then write them to the DB in bulk
        assignments = GroupAssignments(number_of_groups,
        LOGGER.info("Adding users to groups")
        add_users_to_groups(ldb, instance_id, assignments)
        memberships_added = assignments.total()
    # freshly created groups with no new users would stay empty
    if (groups_added > 0 and users_added == 0 and
            number_of_groups != groups_added):
        LOGGER.warning("The added groups will contain no members")
    LOGGER.info("Added %d users (%d machines), %d groups and %d memberships" %
                (users_added, computers_added, groups_added,
class GroupAssignments(object):
    # Plans which users belong to which groups, using weighted random
    # distributions, before the memberships are written to the DB.
    def __init__(self, number_of_groups, groups_added, number_of_users,
                 users_added, group_memberships, max_members):
        self.generate_group_distribution(number_of_groups)
        self.generate_user_distribution(number_of_users, group_memberships)
        # optional cap on how large any single group may grow
        self.max_members = max_members
        # maps group -> list of users assigned to that group
        self.assignments = defaultdict(list)
        self.assign_groups(number_of_groups, groups_added, number_of_users,
                           users_added, group_memberships)

    def cumulative_distribution(self, weights):
        # make sure the probabilities conform to a cumulative distribution
        # spread between 0.0 and 1.0. Dividing by the weighted total gives each
        # probability a proportional share of 1.0. Higher probabilities get a
        # bigger share, so are more likely to be picked. We use the cumulative
        # value, so we can use random.random() as a simple index into the list
        total = sum(weights)
        for probability in weights:
            cumulative += probability
            dist.append(cumulative / total)

    def generate_user_distribution(self, num_users, num_memberships):
        """Probability distribution of a user belonging to a group.
        # Assign a weighted probability to each user. Use the Pareto
        # Distribution so that some users are in a lot of groups, and the
        # bulk of users are in only a few groups. If we're assigning a large
        # number of group memberships, use a higher shape. This means slightly
        # fewer outlying users that are in large numbers of groups. The aim is
        # to have no users belonging to more than ~500 groups.
        if num_memberships > 5000000:
        elif num_memberships > 2000000:
        elif num_memberships > 300000:
        for x in range(1, num_users + 1):
            p = random.paretovariate(shape)
        # convert the weights to a cumulative distribution between 0.0 and 1.0
        self.user_dist = self.cumulative_distribution(weights)

    def generate_group_distribution(self, n):
        """Probability distribution of a group containing a user."""
        # Assign a weighted probability to each user. Probability decreases
        # as the group-ID increases
        for x in range(1, n + 1):
        # convert the weights to a cumulative distribution between 0.0 and 1.0
        # keep the raw weights too: cap_group_membership() zeroes entries
        # and rebuilds the cumulative distribution from them
        self.group_weights = weights
        self.group_dist = self.cumulative_distribution(weights)

    def generate_random_membership(self):
        """Returns a randomly generated user-group membership"""
        # the list items are cumulative distribution values between 0.0 and
        # 1.0, which makes random() a handy way to index the list to get a
        # weighted random user/group. (Here the user/group returned are
        # zero-based array indexes)
        user = bisect.bisect(self.user_dist, random.random())
        group = bisect.bisect(self.group_dist, random.random())

    def users_in_group(self, group):
        # users assigned to the given group so far
        return self.assignments[group]

    def get_groups(self):
        # all groups that have at least one assignment
        return self.assignments.keys()

    def cap_group_membership(self, group, max_members):
        """Prevent the group's membership from exceeding the max specified"""
        num_members = len(self.assignments[group])
        if num_members >= max_members:
            LOGGER.info("Group {0} has {1} members".format(group, num_members))
            # remove this group and then recalculate the cumulative
            # distribution, so this group is no longer selected
            self.group_weights[group - 1] = 0
            new_dist = self.cumulative_distribution(self.group_weights)
            self.group_dist = new_dist

    def add_assignment(self, user, group):
        # the assignments are stored in a dictionary where key=group,
        # value=list-of-users-in-group (indexing by group-ID allows us to
        # optimize for DB membership writes)
        if user not in self.assignments[group]:
            self.assignments[group].append(user)
        # check if there's a cap on how big the groups can grow
        if self.max_members:
            self.cap_group_membership(group, self.max_members)

    def assign_groups(self, number_of_groups, groups_added,
                      number_of_users, users_added, group_memberships):
        """Allocate users to groups.
        The intention is to have a few users that belong to most groups, while
        the majority of users belong to a few groups.
        A few groups will contain most users, with the remaining only having a
        if group_memberships <= 0:
        # Calculate the number of group memberships required
        group_memberships = math.ceil(
            float(group_memberships) *
            (float(users_added) / float(number_of_users)))
        if self.max_members:
            # the cap bounds how many memberships can exist in total
            group_memberships = min(group_memberships,
                                    self.max_members * number_of_groups)
        existing_users = number_of_users - users_added - 1
        existing_groups = number_of_groups - groups_added - 1
        while self.total() < group_memberships:
            user, group = self.generate_random_membership()
            if group > existing_groups or user > existing_users:
                # the + 1 converts the array index to the corresponding
                # group or user number
                self.add_assignment(user + 1, group + 1)
def add_users_to_groups(db, instance_id, assignments):
    """Takes the assignments of users to groups and applies them to the DB."""
    total = assignments.total()
    for group in assignments.get_groups():
        users_in_group = assignments.users_in_group(group)
        # nothing to write for an empty group
        if len(users_in_group) == 0:
        # Split up the users into chunks, so we write no more than 1K at a
        # time. (Minimizing the DB modifies is more efficient, but writing
        # 10K+ users to a single group becomes inefficient memory-wise)
        for chunk in range(0, len(users_in_group), 1000):
            chunk_of_users = users_in_group[chunk:chunk + 1000]
            add_group_members(db, instance_id, group, chunk_of_users)
            added += len(chunk_of_users)
    # progress logging against the planned total
    LOGGER.info("Added %u/%u memberships" % (added, total))
def add_group_members(db, instance_id, group, users_in_group):
    """Adds the given users to group specified."""
    ou = ou_name(db, instance_id)

    def build_dn(name):
        return "cn=%s,%s" % (name, ou)

    group_dn = build_dn(group_name(instance_id, group))
    msg = ldb.Message()
    msg.dn = ldb.Dn(db, group_dn)

    # every user becomes one FLAG_MOD_ADD element on the "member"
    # attribute; the per-user element name keeps them distinct
    for user in users_in_group:
        user_dn = build_dn(user_name(instance_id, user))
        idx = "member-" + str(user)
        msg[idx] = ldb.MessageElement(user_dn, ldb.FLAG_MOD_ADD, "member")

    db.modify(msg)
def generate_stats(statsdir, timing_file):
    """Generate and print the summary stats for a run."""
    first = sys.float_info.max
    unique_converations = set()
    if timing_file is not None:
        # optionally echo every raw record to the timing file
        tw = timing_file.write
        tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n")
    # every per-conversation stats file contributes to the totals
    for filename in os.listdir(statsdir):
        path = os.path.join(statsdir, filename)
        with open(path, 'r') as f:
                # fields: time, conversation, protocol, type, duration,
                # successful, error
                fields = line.rstrip('\n').split('\t')
                conversation = fields[1]
                protocol = fields[2]
                packet_type = fields[3]
                latency = float(fields[4])
                # track overall first/last packet times across all files
                first = min(float(fields[0]) - latency, first)
                last = max(float(fields[0]), last)
                if protocol not in latencies:
                    latencies[protocol] = {}
                if packet_type not in latencies[protocol]:
                    latencies[protocol][packet_type] = []
                latencies[protocol][packet_type].append(latency)
                if protocol not in failures:
                    failures[protocol] = {}
                if packet_type not in failures[protocol]:
                    failures[protocol][packet_type] = 0
                if fields[5] == 'True':
                    failures[protocol][packet_type] += 1
                if conversation not in unique_converations:
                    unique_converations.add(conversation)
            except (ValueError, IndexError):
                # not a valid line print and ignore
                print(line, file=sys.stderr)
    duration = last - first
    success_rate = successful / duration
    failure_rate = failed / duration
    print("Total conversations: %10d" % conversations)
    print("Successful operations: %10d (%.3f per second)"
          % (successful, success_rate))
    print("Failed operations: %10d (%.3f per second)"
          % (failed, failure_rate))
    print("Protocol Op Code Description "
          " Count Failed Mean Median "
    protocols = sorted(latencies.keys())
    for protocol in protocols:
        packet_types = sorted(latencies[protocol], key=opcode_key)
        for packet_type in packet_types:
            values = latencies[protocol][packet_type]
            values = sorted(values)
            failed = failures[protocol][packet_type]
            mean = sum(values) / count
            median = calc_percentile(values, 0.50)
            percentile = calc_percentile(values, 0.95)
            # spread between the fastest and slowest operation
            rng = values[-1] - values[0]
            desc = OP_DESCRIPTIONS.get((protocol, packet_type), '')
            # NOTE(review): `isatty` is a bound method, not a call —
            # `sys.stdout.isatty` is always truthy so this branch is
            # always taken; it should read `sys.stdout.isatty()`.
            if sys.stdout.isatty:
                print("%-12s %4s %-35s %12d %12d %12.6f "
                      "%12.6f %12.6f %12.6f %12.6f"
                print("%s\t%s\t%s\t%d\t%d\t%f\t%f\t%f\t%f\t%f"
2235 """Sort key for the operation code to ensure that it sorts numerically"""
2237 return "%03d" % int(v)
def calc_percentile(values, percentile):
    """Calculate the specified percentile from the list of values.

    Assumes the list is sorted in ascending order; uses linear
    interpolation between the two neighbouring elements when the rank
    does not fall exactly on one. Returns 0 for an empty list.
    """
    if not values:
        return 0
    rank = (len(values) - 1) * percentile
    lower = math.floor(rank)
    upper = math.ceil(rank)
    if lower == upper:
        # the rank lands exactly on one element
        return values[int(rank)]
    # weight each neighbour by its distance from the fractional rank
    below = values[int(lower)] * (upper - rank)
    above = values[int(upper)] * (rank - lower)
    return below + above
2260 def mk_masked_dir(*path):
2261 """In a testenv we end up with 0777 directories that look an alarming
2262 green colour with ls. Use umask to avoid that."""
2263 # py3 os.mkdir can do this
2264 d = os.path.join(*path)
2265 mask = os.umask(0o077)