1 # -*- encoding: utf-8 -*-
2 # Samba traffic replay and learning
4 # Copyright (C) Catalyst IT Ltd. 2017
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 # GNU General Public License for more details.
16 # You should have received a copy of the GNU General Public License
17 # along with this program. If not, see <http://www.gnu.org/licenses/>.
19 from __future__ import print_function, division
30 from collections import OrderedDict, Counter, defaultdict
31 from samba.emulate import traffic_packets
32 from samba.samdb import SamDB
34 from ldb import LdbError
35 from samba.dcerpc import ClientConnection
36 from samba.dcerpc import security, drsuapi, lsa
37 from samba.dcerpc import netlogon
38 from samba.dcerpc.netlogon import netr_Authenticator
39 from samba.dcerpc import srvsvc
40 from samba.dcerpc import samr
41 from samba.drs_utils import drs_DsBind
43 from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
44 from samba.auth import system_session
45 from samba.dsdb import UF_WORKSTATION_TRUST_ACCOUNT, UF_PASSWD_NOTREQD
46 from samba.dsdb import UF_NORMAL_ACCOUNT
47 from samba.dcerpc.misc import SEC_CHAN_WKSTA
48 from samba import gensec
49 from samba import sd_utils
53 # we don't use None, because it complicates [de]serialisation
57 ('dns', '0'): 1.0, # query
58 ('smb', '0x72'): 1.0, # Negotiate protocol
59 ('ldap', '0'): 1.0, # bind
60 ('ldap', '3'): 1.0, # searchRequest
61 ('ldap', '2'): 1.0, # unbindRequest
63 ('dcerpc', '11'): 1.0, # bind
64 ('dcerpc', '14'): 1.0, # Alter_context
65 ('nbns', '0'): 1.0, # query
69 ('dns', '1'): 1.0, # response
70 ('ldap', '1'): 1.0, # bind response
71 ('ldap', '4'): 1.0, # search result
72 ('ldap', '5'): 1.0, # search done
74 ('dcerpc', '12'): 1.0, # bind_ack
75 ('dcerpc', '13'): 1.0, # bind_nak
76 ('dcerpc', '15'): 1.0, # Alter_context response
79 SKIPPED_PROTOCOLS = {"smb", "smb2", "browser", "smb_netlogon"}
82 WAIT_THRESHOLD = (1.0 / WAIT_SCALE)
83 NO_WAIT_LOG_TIME_RANGE = (-10, -3)
85 # DEBUG_LEVEL can be changed by scripts with -d
def debug(level, msg, *args):
    """Print a formatted debug message to standard error.

    :param level: The debug level; the message is printed only if it is
                  <= the module-level DEBUG_LEVEL.
    :param msg:   The message to be logged, can contain C-style format
                  specifiers.
    :param args:  The parameters required by the format specifiers.
    """
    if level > DEBUG_LEVEL:
        return
    # Only apply %-formatting when parameters were supplied, so that a
    # bare message is printed verbatim.
    text = msg % tuple(args) if args else msg
    print(text, file=sys.stderr)
107 def debug_lineno(*args):
108 """ Print an unformatted log message to stderr, containing the line number
110 tb = traceback.extract_stack(limit=2)
111 print((" %s:" "\033[01;33m"
112 "%s " "\033[00m" % (tb[0][2], tb[0][1])), end=' ',
115 print(a, file=sys.stderr)
116 print(file=sys.stderr)
def random_colour_print():
    """Return a function that prints a randomly coloured line to stderr.

    The colour is chosen once, when this factory is called. The returned
    callable prints each of its arguments on its own line, wrapped in
    the colour escape sequence and a reset sequence.
    """
    colour_start = "\033[38;5;%dm" % (18 + random.randrange(214))

    def p(*args):
        for item in args:
            print("%s%s\033[00m" % (colour_start, item), file=sys.stderr)

    return p
class FakePacketError(Exception):
    """Raised when a packet's endpoints do not match its conversation."""
136 class Packet(object):
137 """Details of a network packet"""
138 def __init__(self, fields):
139 if isinstance(fields, str):
140 fields = fields.rstrip('\n').split('\t')
152 self.timestamp = float(timestamp)
153 self.ip_protocol = ip_protocol
155 self.stream_number = int(stream_number)
156 except (ValueError, TypeError):
157 self.stream_number = None
159 self.dest = int(dest)
160 self.protocol = protocol
165 if self.src < self.dest:
166 self.endpoints = (self.src, self.dest)
168 self.endpoints = (self.dest, self.src)
def as_summary(self, time_offset=0.0):
    """Format the packet as a traffic_summary line.

    :param time_offset: number of seconds added to the packet's
                        timestamp in the output.
    :return: a (timestamp, line) tuple, where line is the tab-separated
             summary: timestamp, ip protocol, stream number, source,
             destination, protocol, opcode, description and the extra
             fields.
    """
    extra = '\t'.join(self.extra)
    t = self.timestamp + time_offset
    # An unknown stream is written as an empty field. Test for None
    # rather than truthiness, so that stream number 0 is not written as
    # an empty field (which would read back as "no stream").
    if self.stream_number is None:
        stream = ''
    else:
        stream = self.stream_number
    return (t, '%f\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' %
            (t,
             self.ip_protocol,
             stream,
             self.src,
             self.dest,
             self.protocol,
             self.opcode,
             self.desc,
             extra))
187 return ("%.3f: %d -> %d; ip %s; strm %s; prot %s; op %s; desc %s %s" %
188 (self.timestamp, self.src, self.dest, self.ip_protocol or '-',
189 self.stream_number, self.protocol, self.opcode, self.desc,
190 ('«' + ' '.join(self.extra) + '»' if self.extra else '')))
193 return "<Packet @%s>" % self
196 return self.__class__([self.timestamp,
203 self.desc] + self.extra)
def as_packet_type(self):
    """Return the packet's 'protocol:opcode' identifier string."""
    return '%s:%s' % (self.protocol, self.opcode)
def client_score(self):
    """Estimate which end of the conversation sent this packet.

    A positive number means we think it was sent by a client; a negative
    number means we think it was sent by a server. Zero means no idea.
    The range is -1 to 1.
    """
    key = (self.protocol, self.opcode)
    try:
        return CLIENT_CLUES[key]
    except KeyError:
        pass
    try:
        return -SERVER_CLUES[key]
    except KeyError:
        return 0.0
220 def play(self, conversation, context):
221 """Send the packet over the network, if required.
223 Some packets are ignored, i.e. for protocols not handled,
224 server response messages, or messages that are generated by the
225 protocol layer associated with other packets.
227 fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
229 fn = getattr(traffic_packets, fn_name)
231 except AttributeError as e:
232 print("Conversation(%s) Missing handler %s" % \
233 (conversation.conversation_id, fn_name),
237 # Don't display a message for kerberos packets, they're not directly
238 # generated they're used to indicate kerberos should be used
239 if self.protocol != "kerberos":
240 debug(2, "Conversation(%s) Calling handler %s" %
241 (conversation.conversation_id, fn_name))
245 if fn(self, conversation, context):
246 # Only collect timing data for functions that generate
247 # network traffic, or fail
249 duration = end - start
250 print("%f\t%s\t%s\t%s\t%f\tTrue\t" %
251 (end, conversation.conversation_id, self.protocol,
252 self.opcode, duration))
253 except Exception as e:
255 duration = end - start
256 print("%f\t%s\t%s\t%s\t%f\tFalse\t%s" %
257 (end, conversation.conversation_id, self.protocol,
258 self.opcode, duration, e))
260 def __cmp__(self, other):
261 return self.timestamp - other.timestamp
def is_really_a_packet(self, missing_packet_stats=None):
    """Decide whether the packet matters for the replay.

    Returns False for packets that can be ignored, so that removing them
    has no effect on the replay:

    * packets for protocols in SKIPPED_PROTOCOLS;
    * ldap continuation packets (ldap with an empty opcode);
    * packets whose traffic_packets handler is null_packet;
    * packets with no handler at all (reported on stderr).

    Returns True otherwise.

    :param missing_packet_stats: not used by this method.
    """
    if self.protocol in SKIPPED_PROTOCOLS:
        # Ignore any packets for the protocols we're not interested in.
        return False
    if self.protocol == "ldap" and self.opcode == '':
        # skip ldap continuation packets
        return False

    fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
    try:
        handler = getattr(traffic_packets, fn_name)
        is_real = handler is not traffic_packets.null_packet
    except AttributeError:
        print("missing packet %s" % fn_name, file=sys.stderr)
        return False

    return is_real
286 class ReplayContext(object):
287 """State/Context for an individual conversation between a simulated client
295 badpassword_frequency=None,
296 prefer_kerberos=None,
305 self.ldap_connections = []
306 self.dcerpc_connections = []
307 self.lsarpc_connections = []
308 self.lsarpc_connections_named = []
309 self.drsuapi_connections = []
310 self.srvsvc_connections = []
311 self.samr_contexts = []
312 self.netlogon_connection = None
315 self.prefer_kerberos = prefer_kerberos
317 self.base_dn = base_dn
319 self.statsdir = statsdir
320 self.global_tempdir = tempdir
321 self.domain_sid = domain_sid
322 self.realm = lp.get('realm')
324 # Bad password attempt controls
325 self.badpassword_frequency = badpassword_frequency
326 self.last_lsarpc_bad = False
327 self.last_lsarpc_named_bad = False
328 self.last_simple_bind_bad = False
329 self.last_bind_bad = False
330 self.last_srvsvc_bad = False
331 self.last_drsuapi_bad = False
332 self.last_netlogon_bad = False
333 self.last_samlogon_bad = False
334 self.generate_ldap_search_tables()
335 self.next_conversation_id = itertools.count().next
337 def generate_ldap_search_tables(self):
338 session = system_session()
340 db = SamDB(url="ldap://%s" % self.server,
341 session_info=session,
342 credentials=self.creds,
345 res = db.search(db.domain_dn(),
346 scope=ldb.SCOPE_SUBTREE,
347 controls=["paged_results:1:1000"],
350 # find a list of dns for each pattern
351 # e.g. CN,CN,CN,DC,DC
353 attribute_clue_map = {
359 pattern = ','.join(x.lstrip()[:2] for x in dn.split(',')).upper()
360 dns = dn_map.setdefault(pattern, [])
362 if dn.startswith('CN=NTDS Settings,'):
363 attribute_clue_map['invocationId'].append(dn)
365 # extend the map in case we are working with a different
366 # number of DC components.
367 # for k, v in self.dn_map.items():
368 # print >>sys.stderr, k, len(v)
370 for k, v in dn_map.items():
374 while p[-3:] == ',DC':
378 if p != k and p in dn_map:
379 print('dn_map collison %s %s' % (k, p),
382 dn_map[p] = dn_map[k]
385 self.attribute_clue_map = attribute_clue_map
387 def generate_process_local_config(self, account, conversation):
390 self.netbios_name = account.netbios_name
391 self.machinepass = account.machinepass
392 self.username = account.username
393 self.userpass = account.userpass
395 self.tempdir = mk_masked_dir(self.global_tempdir,
397 conversation.conversation_id)
399 self.lp.set("private dir", self.tempdir)
400 self.lp.set("lock dir", self.tempdir)
401 self.lp.set("state directory", self.tempdir)
402 self.lp.set("tls verify peer", "no_check")
404 # If the domain was not specified, check for the environment
406 if self.domain is None:
407 self.domain = os.environ["DOMAIN"]
409 self.remoteAddress = "/root/ncalrpc_as_system"
410 self.samlogon_dn = ("cn=%s,%s" %
411 (self.netbios_name, self.ou))
412 self.user_dn = ("cn=%s,%s" %
413 (self.username, self.ou))
415 self.generate_machine_creds()
416 self.generate_user_creds()
def with_random_bad_credentials(self, f, good, bad, failed_last_time):
    """Execute the supplied logon function, sometimes with bad credentials.

    With probability badpassword_frequency the function is first run
    with the supplied bad credentials; any exception from that attempt
    is ignored. The function is then always run with the good
    credentials.

    failed_last_time is used to prevent consecutive bad credential
    attempts, so the overall bad credential frequency will be lower
    than that requested, but not significantly.

    :param f: callable taking a single credentials argument.
    :param good: credentials expected to succeed.
    :param bad: credentials expected to fail.
    :param failed_last_time: True if the previous call made a bad
                             credential attempt.
    :return: a (result, failed_last_time) tuple, where result is the
             value of f(good) and the flag says whether a bad attempt
             was made during this call.
    """
    frequency = self.badpassword_frequency
    # badpassword_frequency may be None (the constructor default), which
    # cannot be ordered against a number on Python 3; treat it as
    # "never attempt bad credentials".
    attempt_bad = (not failed_last_time and
                   frequency is not None and
                   frequency > 0 and
                   random.random() < frequency)
    if attempt_bad:
        try:
            f(bad)
        except Exception:
            # Ignore any exceptions as the operation may fail
            # as it's being performed with bad credentials
            pass

    result = f(good)
    # Report what happened in *this* call, so that a bad attempt only
    # suppresses the one immediately following it.
    return (result, attempt_bad)
def generate_user_creds(self):
    """Generate the conversation specific user Credentials.

    Each Conversation has an associated user account used to simulate
    any non Administrative user traffic.

    Sets four attributes: user_creds and user_creds_bad (good and bad
    passwords), and simple_bind_creds and simple_bind_creds_bad (good
    and bad passwords for ldap simple bind). The bad password is the
    real one with its last four characters removed.
    """
    if self.prefer_kerberos:
        kerberos_state = MUST_USE_KERBEROS
    else:
        kerberos_state = DONT_USE_KERBEROS

    # User credentials with the correct password.
    creds = self.user_creds = Credentials()
    creds.guess(self.lp)
    creds.set_workstation(self.netbios_name)
    creds.set_password(self.userpass)
    creds.set_username(self.username)
    creds.set_domain(self.domain)
    creds.set_kerberos_state(kerberos_state)

    # The same user with a truncated password. Note that no domain is
    # set explicitly on the bad credentials.
    creds = self.user_creds_bad = Credentials()
    creds.guess(self.lp)
    creds.set_workstation(self.netbios_name)
    creds.set_password(self.userpass[:-4])
    creds.set_username(self.username)
    creds.set_kerberos_state(kerberos_state)

    # Credentials for ldap simple bind.
    creds = self.simple_bind_creds = Credentials()
    creds.guess(self.lp)
    creds.set_workstation(self.netbios_name)
    creds.set_password(self.userpass)
    creds.set_username(self.username)
    creds.set_gensec_features(
        creds.get_gensec_features() | gensec.FEATURE_SEAL)
    creds.set_kerberos_state(kerberos_state)
    creds.set_bind_dn(self.user_dn)

    # Simple bind credentials with the truncated password.
    creds = self.simple_bind_creds_bad = Credentials()
    creds.guess(self.lp)
    creds.set_workstation(self.netbios_name)
    creds.set_password(self.userpass[:-4])
    creds.set_username(self.username)
    creds.set_gensec_features(
        creds.get_gensec_features() | gensec.FEATURE_SEAL)
    creds.set_kerberos_state(kerberos_state)
    creds.set_bind_dn(self.user_dn)
def generate_machine_creds(self):
    """Generate the conversation specific machine Credentials.

    Each Conversation has an associated machine account.

    Sets machine_creds (correct password) and machine_creds_bad (the
    password with its last four characters removed). Both use the
    netbios name followed by "$" as the user name.
    """
    if self.prefer_kerberos:
        kerberos_state = MUST_USE_KERBEROS
    else:
        kerberos_state = DONT_USE_KERBEROS

    creds = self.machine_creds = Credentials()
    creds.guess(self.lp)
    creds.set_workstation(self.netbios_name)
    creds.set_secure_channel_type(SEC_CHAN_WKSTA)
    creds.set_password(self.machinepass)
    creds.set_username(self.netbios_name + "$")
    creds.set_domain(self.domain)
    creds.set_kerberos_state(kerberos_state)

    # The bad credentials use a truncated password; note that no domain
    # is set explicitly on them.
    creds = self.machine_creds_bad = Credentials()
    creds.guess(self.lp)
    creds.set_workstation(self.netbios_name)
    creds.set_secure_channel_type(SEC_CHAN_WKSTA)
    creds.set_password(self.machinepass[:-4])
    creds.set_username(self.netbios_name + "$")
    creds.set_kerberos_state(kerberos_state)
def get_matching_dn(self, pattern, attributes=None):
    """Pick a random DN that suits a search pattern.

    :param pattern: comma-separated two-letter component types, e.g.
                    'CN,CN,DC,DC'; it is matched case-insensitively.
    :param attributes: key looked up in attribute_clue_map; if it has
                       a non-empty list of DNs, one of those is
                       returned regardless of the pattern.
    :return: a DN from dn_map for the longest trailing part of the
             pattern that has an entry, or base_dn as a last resort
             (including when the pattern is empty).
    """
    clue_dns = self.attribute_clue_map.get(attributes)
    if clue_dns:
        return random.choice(clue_dns)

    wanted = pattern.upper()
    # Each component is two letters plus a comma, so stepping by three
    # chops one component off the front on every round.
    for offset in range(0, len(wanted), 3):
        candidate = wanted[offset:]
        if candidate in self.dn_map:
            return random.choice(self.dn_map[candidate])

    return self.base_dn
557 def get_dcerpc_connection(self, new=False):
558 guid = '12345678-1234-abcd-ef00-01234567cffb' # RPC_NETLOGON UUID
559 if self.dcerpc_connections and not new:
560 return self.dcerpc_connections[-1]
561 c = ClientConnection("ncacn_ip_tcp:%s" % self.server,
563 self.dcerpc_connections.append(c)
566 def get_srvsvc_connection(self, new=False):
567 if self.srvsvc_connections and not new:
568 return self.srvsvc_connections[-1]
571 return srvsvc.srvsvc("ncacn_np:%s" % (self.server),
575 (c, self.last_srvsvc_bad) = \
576 self.with_random_bad_credentials(connect,
579 self.last_srvsvc_bad)
581 self.srvsvc_connections.append(c)
584 def get_lsarpc_connection(self, new=False):
585 if self.lsarpc_connections and not new:
586 return self.lsarpc_connections[-1]
589 binding_options = 'schannel,seal,sign'
590 return lsa.lsarpc("ncacn_ip_tcp:%s[%s]" %
591 (self.server, binding_options),
595 (c, self.last_lsarpc_bad) = \
596 self.with_random_bad_credentials(connect,
598 self.machine_creds_bad,
599 self.last_lsarpc_bad)
601 self.lsarpc_connections.append(c)
604 def get_lsarpc_named_pipe_connection(self, new=False):
605 if self.lsarpc_connections_named and not new:
606 return self.lsarpc_connections_named[-1]
609 return lsa.lsarpc("ncacn_np:%s" % (self.server),
613 (c, self.last_lsarpc_named_bad) = \
614 self.with_random_bad_credentials(connect,
616 self.machine_creds_bad,
617 self.last_lsarpc_named_bad)
619 self.lsarpc_connections_named.append(c)
622 def get_drsuapi_connection_pair(self, new=False, unbind=False):
623 """get a (drs, drs_handle) tuple"""
624 if self.drsuapi_connections and not new:
625 c = self.drsuapi_connections[-1]
629 binding_options = 'seal'
630 binding_string = "ncacn_ip_tcp:%s[%s]" %\
631 (self.server, binding_options)
632 return drsuapi.drsuapi(binding_string, self.lp, creds)
634 (drs, self.last_drsuapi_bad) = \
635 self.with_random_bad_credentials(connect,
638 self.last_drsuapi_bad)
640 (drs_handle, supported_extensions) = drs_DsBind(drs)
641 c = (drs, drs_handle)
642 self.drsuapi_connections.append(c)
645 def get_ldap_connection(self, new=False, simple=False):
646 if self.ldap_connections and not new:
647 return self.ldap_connections[-1]
649 def simple_bind(creds):
651 To run simple bind against Windows, we need to run
652 following commands in PowerShell:
654 Install-windowsfeature ADCS-Cert-Authority
655 Install-AdcsCertificationAuthority -CAType EnterpriseRootCA
659 return SamDB('ldaps://%s' % self.server,
663 def sasl_bind(creds):
664 return SamDB('ldap://%s' % self.server,
668 (samdb, self.last_simple_bind_bad) = \
669 self.with_random_bad_credentials(simple_bind,
670 self.simple_bind_creds,
671 self.simple_bind_creds_bad,
672 self.last_simple_bind_bad)
674 (samdb, self.last_bind_bad) = \
675 self.with_random_bad_credentials(sasl_bind,
680 self.ldap_connections.append(samdb)
def get_samr_context(self, new=False):
    """Return the most recently created SamrContext.

    A new SamrContext is created and remembered first if there is none
    yet, or if `new` is true.
    """
    contexts = self.samr_contexts
    if new or not contexts:
        fresh = SamrContext(self.server, lp=self.lp, creds=self.creds)
        contexts.append(fresh)
    return contexts[-1]
689 def get_netlogon_connection(self):
691 if self.netlogon_connection:
692 return self.netlogon_connection
695 return netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" %
699 (c, self.last_netlogon_bad) = \
700 self.with_random_bad_credentials(connect,
702 self.machine_creds_bad,
703 self.last_netlogon_bad)
704 self.netlogon_connection = c
def guess_a_dns_lookup(self):
    """Return a (name, record type) pair to query: the realm's A record."""
    name = self.realm
    return (name, 'A')
def get_authenticator(self):
    """Build a pair of netlogon authenticators for the machine account.

    :return: a (current, subsequent) tuple of netr_Authenticator
             objects. `current` is filled in from the machine
             credentials' new client authenticator; `subsequent` is
             left empty.
    """
    auth = self.machine_creds.new_client_authenticator()
    current = netr_Authenticator()
    # The credential may be a str (Python 2) or bytes (Python 3).
    # Iterating bytes already yields ints, on which ord() would raise
    # TypeError, so only convert the items that are not ints.
    current.cred.data = [x if isinstance(x, int) else ord(x)
                         for x in auth["credential"]]
    current.timestamp = auth["timestamp"]

    subsequent = netr_Authenticator()
    return (current, subsequent)
720 class SamrContext(object):
721 """State/Context associated with a samr connection.
723 def __init__(self, server, lp=None, creds=None):
724 self.connection = None
726 self.domain_handle = None
727 self.domain_sid = None
728 self.group_handle = None
729 self.user_handle = None
735 def get_connection(self):
736 if not self.connection:
737 self.connection = samr.samr(
738 "ncacn_ip_tcp:%s[seal]" % (self.server),
740 credentials=self.creds)
742 return self.connection
744 def get_handle(self):
746 c = self.get_connection()
747 self.handle = c.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
751 class Conversation(object):
752 """Details of a conversation between a simulated client and a server."""
753 conversation_id = None
755 def __init__(self, start_time=None, endpoints=None):
756 self.start_time = start_time
757 self.endpoints = endpoints
759 self.msg = random_colour_print()
760 self.client_balance = 0.0
762 def __cmp__(self, other):
763 if self.start_time is None:
764 if other.start_time is None:
767 if other.start_time is None:
769 return self.start_time - other.start_time
771 def add_packet(self, packet):
772 """Add a packet object to this conversation, making a local copy with
773 a conversation-relative timestamp."""
776 if self.start_time is None:
777 self.start_time = p.timestamp
779 if self.endpoints is None:
780 self.endpoints = p.endpoints
782 if p.endpoints != self.endpoints:
783 raise FakePacketError("Conversation endpoints %s don't match"
784 "packet endpoints %s" %
785 (self.endpoints, p.endpoints))
787 p.timestamp -= self.start_time
789 if p.src == p.endpoints[0]:
790 self.client_balance -= p.client_score()
792 self.client_balance += p.client_score()
794 if p.is_really_a_packet():
795 self.packets.append(p)
797 def add_short_packet(self, timestamp, p, extra, client=True):
798 """Create a packet from a timestamp, and 'protocol:opcode' pair, and a
799 (possibly empty) list of extra data. If client is True, assume
800 this packet is from the client to the server.
802 protocol, opcode = p.split(':', 1)
803 src, dest = self.guess_client_server()
805 src, dest = dest, src
807 desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
808 ip_protocol = IP_PROTOCOLS.get(protocol, '06')
809 fields = [timestamp - self.start_time, ip_protocol,
811 protocol, opcode, desc]
813 packet = Packet(fields)
814 # XXX we're assuming the timestamp is already adjusted for
816 # XXX should we adjust client balance for guessed packets?
817 if packet.src == packet.endpoints[0]:
818 self.client_balance -= packet.client_score()
820 self.client_balance += packet.client_score()
821 if packet.is_really_a_packet():
822 self.packets.append(packet)
825 return ("<Conversation %s %s starting %.3f %d packets>" %
826 (self.conversation_id, self.endpoints, self.start_time,
832 return iter(self.packets)
835 return len(self.packets)
def get_duration(self):
    """Return the time between the first and last packets.

    A conversation with fewer than two packets has a duration of 0.
    """
    packets = self.packets
    if len(packets) >= 2:
        return packets[-1].timestamp - packets[0].timestamp
    return 0
def replay_as_summary_lines(self):
    """Return the summary of every packet, in packet order.

    Each item is what the packet's as_summary() returns when given the
    conversation's start time as the time offset.
    """
    return [p.as_summary(self.start_time) for p in self.packets]
848 def replay_in_fork_with_delay(self, start, context=None, account=None):
849 """Fork a new process and replay the conversation.
851 def signal_handler(signal, frame):
852 """Signal handler closes standard out and error.
854 Triggered by a sigterm, ensures that the log messages are flushed
855 to disk and not lost.
862 now = time.time() - start
864 # we are replaying strictly in order, so it is safe to sleep
865 # in the main process if the gap is big enough. This reduces
866 # the number of concurrent threads, which allows us to make
868 if gap > 0.15 and False:
869 print("sleeping for %f in main process" % (gap - 0.1),
871 time.sleep(gap - 0.1)
872 now = time.time() - start
874 print("gap is now %f" % gap, file=sys.stderr)
876 self.conversation_id = context.next_conversation_id()
881 signal.signal(signal.SIGTERM, signal_handler)
882 # we must never return, or we'll end up running parts of the
883 # parent's clean-up code. So we work in a try...finally, and
884 # try to print any exceptions.
887 context.generate_process_local_config(account, self)
890 filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
891 self.conversation_id)
893 sys.stdout = open(filename, 'w')
895 sleep_time = gap - SLEEP_OVERHEAD
897 time.sleep(sleep_time)
899 miss = t - (time.time() - start)
900 self.msg("starting %s [miss %.3f pid %d]" % (self, miss, pid))
903 print(("EXCEPTION in child PID %d, conversation %s" % (pid, self)),
905 traceback.print_exc(sys.stderr)
911 def replay(self, context=None):
914 for p in self.packets:
915 now = time.time() - start
916 gap = p.timestamp - now
917 sleep_time = gap - SLEEP_OVERHEAD
919 time.sleep(sleep_time)
921 miss = p.timestamp - (time.time() - start)
923 self.msg("packet %s [miss %.3f pid %d]" % (p, miss,
926 p.play(self, context)
928 def guess_client_server(self, server_clue=None):
929 """Have a go at deciding who is the server and who is the client.
930 returns (client, server)
932 a, b = self.endpoints
934 if self.client_balance < 0:
937 # in the absence of a clue, we will fall through to assuming
938 # the lowest number is the server (which is usually true).
940 if self.client_balance == 0 and server_clue == b:
def forget_packets_outside_window(self, s, e):
    """Prune any packets outside the time window we're interested in.

    Packets with a timestamp before `s` or after `e` are dropped. The
    conversation's start time becomes the timestamp of the first
    remaining packet, or None if no packets remain.

    :param s: start of the window
    :param e: end of the window
    """
    kept = [p for p in self.packets
            if not (p.timestamp < s or p.timestamp > e)]
    self.packets = kept
    self.start_time = kept[0].timestamp if kept else None
def renormalise_times(self, start_time):
    """Adjust the packet start times relative to the new start time.

    `start_time` is subtracted from every packet's timestamp and, when
    it is set, from the conversation's own start time.
    """
    if self.start_time is not None:
        self.start_time -= start_time

    for packet in self.packets:
        packet.timestamp -= start_time
973 class DnsHammer(Conversation):
974 """A lightweight conversation that generates a lot of dns:0 packets on
977 def __init__(self, dns_rate, duration):
978 n = int(dns_rate * duration)
979 self.times = [random.uniform(0, duration) for i in range(n)]
982 self.duration = duration
984 self.msg = random_colour_print()
987 return ("<DnsHammer %d packets over %.1fs (rate %.2f)>" %
988 (len(self.times), self.duration, self.rate))
990 def replay_in_fork_with_delay(self, start, context=None, account=None):
991 return Conversation.replay_in_fork_with_delay(self,
996 def replay(self, context=None):
998 fn = traffic_packets.packet_dns_0
1000 now = time.time() - start
1002 sleep_time = gap - SLEEP_OVERHEAD
1004 time.sleep(sleep_time)
1007 miss = t - (time.time() - start)
1008 self.msg("packet %s [miss %.3f pid %d]" % (t, miss,
1012 packet_start = time.time()
1014 fn(self, self, context)
1016 duration = end - packet_start
1017 print("%f\tDNS\tdns\t0\t%f\tTrue\t" % (end, duration))
1018 except Exception as e:
1020 duration = end - packet_start
1021 print("%f\tDNS\tdns\t0\t%f\tFalse\t%s" % (end, duration, e))
1024 def ingest_summaries(files, dns_mode='count'):
1025 """Load a traffic summary file and generate Conversations from it.
1028 dns_counts = defaultdict(int)
1031 if isinstance(f, str):
1033 print("Ingesting %s" % (f.name,), file=sys.stderr)
1036 if p.protocol == 'dns' and dns_mode != 'include':
1037 dns_counts[p.opcode] += 1
1046 start_time = min(p.timestamp for p in packets)
1047 last_packet = max(p.timestamp for p in packets)
1049 print("gathering packets into conversations", file=sys.stderr)
1050 conversations = OrderedDict()
1052 p.timestamp -= start_time
1053 c = conversations.get(p.endpoints)
1056 conversations[p.endpoints] = c
1059 # We only care about conversations with actual traffic, so we
1060 # filter out conversations with nothing to say. We do that here,
1061 # rather than earlier, because those empty packets contain useful
1062 # hints as to which end of the conversation was the client.
1063 conversation_list = []
1064 for c in conversations.values():
1066 conversation_list.append(c)
1068 # This is obviously not correct, as many conversations will appear
1069 # to start roughly simultaneously at the beginning of the snapshot.
1070 # To which we say: oh well, so be it.
1071 duration = float(last_packet - start_time)
1072 mean_interval = len(conversations) / duration
1074 return conversation_list, mean_interval, duration, dns_counts
def guess_server_address(conversations):
    """Guess which address is the server: the most common endpoint.

    :param conversations: iterable of objects with an `endpoints` pair
                          of addresses.
    :return: the address that appears in the most conversations, or
             None if there are no endpoints to count.
    """
    addresses = Counter()
    for c in conversations:
        addresses.update(c.endpoints)

    if not addresses:
        # nothing to count, so there is no guess to make
        return None

    # most_common() yields (address, count) pairs. The result is used as
    # a clue that is compared against endpoint addresses, so return the
    # address alone rather than the pair.
    address, _count = addresses.most_common(1)[0]
    return address
1086 def stringify_keys(x):
1088 for k, v in x.items():
def unstringify_keys(x):
    """Return a copy of x whose keys are tuples.

    Each key is converted to a string and split on tab characters; the
    values are carried over unchanged.
    """
    return {tuple(str(k).split('\t')): v for k, v in x.items()}
1102 class TrafficModel(object):
1103 def __init__(self, n=3):
1105 self.query_details = {}
1107 self.dns_opcounts = defaultdict(int)
1108 self.cumulative_duration = 0.0
1109 self.conversation_rate = [0, 1]
1111 def learn(self, conversations, dns_opcounts={}):
1114 key = (NON_PACKET,) * (self.n - 1)
1116 server = guess_server_address(conversations)
1118 for k, v in dns_opcounts.items():
1119 self.dns_opcounts[k] += v
1121 if len(conversations) > 1:
1123 conversations[-1].start_time - conversations[0].start_time
1124 self.conversation_rate[0] = len(conversations)
1125 self.conversation_rate[1] = elapsed
1127 for c in conversations:
1128 client, server = c.guess_client_server(server)
1129 cum_duration += c.get_duration()
1130 key = (NON_PACKET,) * (self.n - 1)
1135 elapsed = p.timestamp - prev
1137 if elapsed > WAIT_THRESHOLD:
1138 # add the wait as an extra state
1139 wait = 'wait:%d' % (math.log(max(1.0,
1140 elapsed * WAIT_SCALE)))
1141 self.ngrams.setdefault(key, []).append(wait)
1142 key = key[1:] + (wait,)
1144 short_p = p.as_packet_type()
1145 self.query_details.setdefault(short_p,
1146 []).append(tuple(p.extra))
1147 self.ngrams.setdefault(key, []).append(short_p)
1148 key = key[1:] + (short_p,)
1150 self.cumulative_duration += cum_duration
1152 self.ngrams.setdefault(key, []).append(NON_PACKET)
1156 for k, v in self.ngrams.items():
1158 ngrams[k] = dict(Counter(v))
1161 for k, v in self.query_details.items():
1162 query_details[k] = dict(Counter('\t'.join(x) if x else '-'
1167 'query_details': query_details,
1168 'cumulative_duration': self.cumulative_duration,
1169 'conversation_rate': self.conversation_rate,
1171 d['dns'] = self.dns_opcounts
1173 if isinstance(f, str):
1176 json.dump(d, f, indent=2)
1179 if isinstance(f, str):
1184 for k, v in d['ngrams'].items():
1185 k = tuple(str(k).split('\t'))
1186 values = self.ngrams.setdefault(k, [])
1187 for p, count in v.items():
1188 values.extend([str(p)] * count)
1190 for k, v in d['query_details'].items():
1191 values = self.query_details.setdefault(str(k), [])
1192 for p, count in v.items():
1194 values.extend([()] * count)
1196 values.extend([tuple(str(p).split('\t'))] * count)
1199 for k, v in d['dns'].items():
1200 self.dns_opcounts[k] += v
1202 self.cumulative_duration = d['cumulative_duration']
1203 self.conversation_rate = d['conversation_rate']
1205 def construct_conversation(self, timestamp=0.0, client=2, server=1,
1206 hard_stop=None, packet_rate=1):
1207 """Construct an individual conversation from the model."""
1209 c = Conversation(timestamp, (server, client))
1211 key = (NON_PACKET,) * (self.n - 1)
1213 while key in self.ngrams:
1214 p = random.choice(self.ngrams.get(key, NON_PACKET))
1217 if p in self.query_details:
1218 extra = random.choice(self.query_details[p])
1222 protocol, opcode = p.split(':', 1)
1223 if protocol == 'wait':
1224 log_wait_time = int(opcode) + random.random()
1225 wait = math.exp(log_wait_time) / (WAIT_SCALE * packet_rate)
1228 log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
1229 wait = math.exp(log_wait) / packet_rate
1231 if hard_stop is not None and timestamp > hard_stop:
1233 c.add_short_packet(timestamp, p, extra)
1235 key = key[1:] + (p,)
1239 def generate_conversations(self, rate, duration, packet_rate=1):
1240 """Generate a list of conversations from the model."""
1242 # We run the simulation for at least ten times as long as our
1243 # desired duration, and take a section near the start.
1244 rate_n, rate_t = self.conversation_rate
1246 duration2 = max(rate_t, duration * 2)
1247 n = rate * duration2 * rate_n / rate_t
1254 start = end - duration
1256 while client < n + 2:
1257 start = random.uniform(0, duration2)
1258 c = self.construct_conversation(start,
1261 hard_stop=(duration2 * 5),
1262 packet_rate=packet_rate)
1264 c.forget_packets_outside_window(start, end)
1265 c.renormalise_times(start)
1267 conversations.append(c)
1270 print(("we have %d conversations at rate %f" %
1271 (len(conversations), rate)), file=sys.stderr)
1272 conversations.sort()
1273 return conversations
1278 'rpc_netlogon': '06',
1279 'kerberos': '06', # ratio 16248:258
1290 'smb_netlogon': '11',
# NOTE(review): presumably the body of the OP_DESCRIPTIONS mapping (its
# opening and closing lines are not visible here).  Keys are
# (protocol, opcode) pairs of strings and values are human-readable
# operation names; expand_short_packet() and generate_stats() look entries
# up with OP_DESCRIPTIONS.get((protocol, opcode), '').
1296 ('browser', '0x01'): 'Host Announcement (0x01)',
1297 ('browser', '0x02'): 'Request Announcement (0x02)',
1298 ('browser', '0x08'): 'Browser Election Request (0x08)',
1299 ('browser', '0x09'): 'Get Backup List Request (0x09)',
1300 ('browser', '0x0c'): 'Domain/Workgroup Announcement (0x0c)',
1301 ('browser', '0x0f'): 'Local Master Announcement (0x0f)',
1302 ('cldap', '3'): 'searchRequest',
1303 ('cldap', '5'): 'searchResDone',
1304 ('dcerpc', '0'): 'Request',
1305 ('dcerpc', '11'): 'Bind',
1306 ('dcerpc', '12'): 'Bind_ack',
1307 ('dcerpc', '13'): 'Bind_nak',
1308 ('dcerpc', '14'): 'Alter_context',
1309 ('dcerpc', '15'): 'Alter_context_resp',
1310 ('dcerpc', '16'): 'AUTH3',
1311 ('dcerpc', '2'): 'Response',
1312 ('dns', '0'): 'query',
1313 ('dns', '1'): 'response',
1314 ('drsuapi', '0'): 'DsBind',
1315 ('drsuapi', '12'): 'DsCrackNames',
1316 ('drsuapi', '13'): 'DsWriteAccountSpn',
1317 ('drsuapi', '1'): 'DsUnbind',
1318 ('drsuapi', '2'): 'DsReplicaSync',
1319 ('drsuapi', '3'): 'DsGetNCChanges',
1320 ('drsuapi', '4'): 'DsReplicaUpdateRefs',
1321 ('epm', '3'): 'Map',
1322 ('kerberos', ''): '',
1323 ('ldap', '0'): 'bindRequest',
1324 ('ldap', '1'): 'bindResponse',
1325 ('ldap', '2'): 'unbindRequest',
1326 ('ldap', '3'): 'searchRequest',
1327 ('ldap', '4'): 'searchResEntry',
1328 ('ldap', '5'): 'searchResDone',
1329 ('ldap', ''): '*** Unknown ***',
1330 ('lsarpc', '14'): 'lsa_LookupNames',
1331 ('lsarpc', '15'): 'lsa_LookupSids',
1332 ('lsarpc', '39'): 'lsa_QueryTrustedDomainInfoBySid',
1333 ('lsarpc', '40'): 'lsa_SetTrustedDomainInfo',
1334 ('lsarpc', '6'): 'lsa_OpenPolicy',
1335 ('lsarpc', '76'): 'lsa_LookupSids3',
1336 ('lsarpc', '77'): 'lsa_LookupNames4',
1337 ('nbns', '0'): 'query',
1338 ('nbns', '1'): 'response',
1339 ('rpc_netlogon', '21'): 'NetrLogonDummyRoutine1',
1340 ('rpc_netlogon', '26'): 'NetrServerAuthenticate3',
1341 ('rpc_netlogon', '29'): 'NetrLogonGetDomainInfo',
1342 ('rpc_netlogon', '30'): 'NetrServerPasswordSet2',
1343 ('rpc_netlogon', '39'): 'NetrLogonSamLogonEx',
1344 ('rpc_netlogon', '40'): 'DsrEnumerateDomainTrusts',
1345 ('rpc_netlogon', '45'): 'NetrLogonSamLogonWithFlags',
1346 ('rpc_netlogon', '4'): 'NetrServerReqChallenge',
1347 ('samr', '0',): 'Connect',
1348 ('samr', '16'): 'GetAliasMembership',
1349 ('samr', '17'): 'LookupNames',
1350 ('samr', '18'): 'LookupRids',
1351 ('samr', '19'): 'OpenGroup',
1352 ('samr', '1'): 'Close',
1353 ('samr', '25'): 'QueryGroupMember',
1354 ('samr', '34'): 'OpenUser',
1355 ('samr', '36'): 'QueryUserInfo',
1356 ('samr', '39'): 'GetGroupsForUser',
1357 ('samr', '3'): 'QuerySecurity',
1358 ('samr', '5'): 'LookupDomain',
1359 ('samr', '64'): 'Connect5',
1360 ('samr', '6'): 'EnumDomains',
1361 ('samr', '7'): 'OpenDomain',
1362 ('samr', '8'): 'QueryDomainInfo',
1363 ('smb', '0x04'): 'Close (0x04)',
1364 ('smb', '0x24'): 'Locking AndX (0x24)',
1365 ('smb', '0x2e'): 'Read AndX (0x2e)',
1366 ('smb', '0x32'): 'Trans2 (0x32)',
1367 ('smb', '0x71'): 'Tree Disconnect (0x71)',
1368 ('smb', '0x72'): 'Negotiate Protocol (0x72)',
1369 ('smb', '0x73'): 'Session Setup AndX (0x73)',
1370 ('smb', '0x74'): 'Logoff AndX (0x74)',
1371 ('smb', '0x75'): 'Tree Connect AndX (0x75)',
1372 ('smb', '0xa2'): 'NT Create AndX (0xa2)',
1373 ('smb2', '0'): 'NegotiateProtocol',
1374 ('smb2', '11'): 'Ioctl',
1375 ('smb2', '14'): 'Find',
1376 ('smb2', '16'): 'GetInfo',
1377 ('smb2', '18'): 'Break',
1378 ('smb2', '1'): 'SessionSetup',
1379 ('smb2', '2'): 'SessionLogoff',
1380 ('smb2', '3'): 'TreeConnect',
1381 ('smb2', '4'): 'TreeDisconnect',
1382 ('smb2', '5'): 'Create',
1383 ('smb2', '6'): 'Close',
1384 ('smb2', '8'): 'Read',
1385 ('smb_netlogon', '0x12'): 'SAM LOGON request from client (0x12)',
1386 ('smb_netlogon', '0x17'): ('SAM Active Directory Response - '
1387 'user unknown (0x17)'),
1388 ('srvsvc', '16'): 'NetShareGetInfo',
1389 ('srvsvc', '21'): 'NetSrvGetInfo',
def expand_short_packet(p, timestamp, src, dest, extra):
    """Expand a short ``protocol:opcode`` packet into a tab-separated line.

    The line holds, in order: timestamp, IP protocol (from IP_PROTOCOLS,
    defaulting to '06'), an empty field, src, dest, protocol, opcode, the
    description from OP_DESCRIPTIONS (empty when unknown), and then every
    item of *extra*.  All items must already be strings.
    """
    protocol, opcode = p.split(':', 1)
    fields = [
        timestamp,
        IP_PROTOCOLS.get(protocol, '06'),
        '',
        src,
        dest,
        protocol,
        opcode,
        OP_DESCRIPTIONS.get((protocol, opcode), ''),
    ]
    fields.extend(extra)
    return '\t'.join(fields)
def replay(conversations,
           host=None,
           creds=None,
           lp=None,
           accounts=None,
           dns_rate=0,
           duration=None,
           **kwargs):
    """Replay the conversations against a server, one forked child each.

    Conversations are paired with accounts and forked in batches in
    start-time order.  The parent reaps finished children between batches
    and, once the run is over (or on an exception), signals every
    remaining child with 15, 15 and finally 9.

    :param conversations: the conversations to replay
    :param host: server passed to the ReplayContext
    :param creds: credentials passed to the ReplayContext
    :param lp: loadparm context passed to the ReplayContext
    :param accounts: sequence of accounts, paired one-to-one with the
        conversations
    :param dns_rate: if non-zero, a DnsHammer at this rate is also run
    :param duration: seconds to run for; when None the run ends one
        second after the last packet of the last conversation to start
    :param kwargs: extra keyword arguments for the ReplayContext
    """
    import errno

    context = ReplayContext(server=host,
                            creds=creds,
                            lp=lp,
                            **kwargs)

    if len(accounts) < len(conversations):
        # Report the counts: formatting the lists themselves with %d
        # raises TypeError.
        print(("we have %d accounts but %d conversations" %
               (len(accounts), len(conversations))), file=sys.stderr)

    # Sorted latest-first so that pop() yields the earliest conversation.
    cstack = list(zip(
        sorted(conversations, key=lambda x: x.start_time, reverse=True),
        accounts))

    # Set the process group so that the calling scripts are not killed
    # when the forked child processes are killed.
    os.setpgrp()

    start = time.time()

    if duration is None:
        # end 1 second after the last packet of the last conversation
        # to start. Conversations other than the last could still be
        # going, but we don't care.
        duration = cstack[0][0].packets[-1].timestamp + 1.0
        print("We will stop after %.1f seconds" % duration,
              file=sys.stderr)

    end = start + duration

    print("Replaying traffic for %u conversations over %d seconds"
          % (len(conversations), duration))

    children = {}
    if dns_rate:
        dns_hammer = DnsHammer(dns_rate, duration)
        cstack.append((dns_hammer, None))

    try:
        while True:
            # we spawn a batch, wait for finishers, then spawn another
            now = time.time()
            batch_end = min(now + 2.0, end)
            fork_time = 0.0
            fork_n = 0
            while cstack:
                c, account = cstack.pop()
                if c.start_time + start > batch_end:
                    cstack.append((c, account))
                    break

                st = time.time()
                pid = c.replay_in_fork_with_delay(start, context, account)
                children[pid] = c
                elapsed = time.time() - st
                fork_time += elapsed
                fork_n += 1
                print("forked %s in pid %s (in %fs)" % (c, pid,
                                                        elapsed),
                      file=sys.stderr)

            if fork_n:
                print(("forked %d times in %f seconds (avg %f)" %
                       (fork_n, fork_time, fork_time / fork_n)),
                      file=sys.stderr)
            elif cstack:
                debug(2, "no forks in batch ending %f" % batch_end)

            while time.time() < batch_end - 1.0:
                time.sleep(0.01)
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != errno.ECHILD:  # no child processes
                        raise
                    break
                if pid:
                    c = children.pop(pid, None)
                    print(("process %d finished conversation %s;"
                           " %d to go" %
                           (pid, c, len(children))), file=sys.stderr)

            if time.time() >= end:
                print("time to stop", file=sys.stderr)
                break

    except Exception:
        print("EXCEPTION in parent", file=sys.stderr)
        traceback.print_exc()
    finally:
        for s in (15, 15, 9):
            print(("killing %d children with -%d" %
                   (len(children), s)), file=sys.stderr)
            for pid in children:
                try:
                    os.kill(pid, s)
                except OSError as e:
                    # don't fail if it has already died
                    if e.errno != errno.ESRCH:
                        raise
            time.sleep(0.5)
            end = time.time() + 1
            while children:
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != errno.ECHILD:
                        raise
                    # No child processes remain at all, so nothing in
                    # `children` can still be waited for; stop instead of
                    # spinning on a stale pid until the deadline.
                    children.clear()
                    break
                if pid != 0:
                    c = children.pop(pid, None)
                    print(("kill -%d %d KILLED conversation %s; "
                           "%d to go" %
                           (s, pid, c, len(children))),
                          file=sys.stderr)
                if time.time() >= end:
                    break

            if not children:
                break
            time.sleep(1)

        if children:
            print("%d children are missing" % len(children),
                  file=sys.stderr)

        # there may be stragglers that were forked just as ^C was hit
        # and don't appear in the list of children. We can get them
        # with killpg, but that will also kill us, so this is^H^H would be
        # goodbye, except we cheat and pretend to use ^C (SIG_INTERRUPT),
        # so as not to have to fuss around writing signal handlers.
        try:
            os.killpg(0, 2)
        except KeyboardInterrupt:
            print("ignoring fake ^C", file=sys.stderr)
def openLdb(host, creds, lp):
    """Return a SamDB connected to *host* over LDAP.

    The connection uses a system session together with the supplied
    credentials and loadparm context.
    """
    session = system_session()
    connection_args = {
        "url": "ldap://%s" % host,
        "session_info": session,
        "credentials": creds,
        "lp": lp,
    }
    return SamDB(**connection_args)
def ou_name(ldb, instance_id):
    """Return the DN of the OU used by this traffic replay instance.

    The OU is ``ou=instance-<id>,ou=traffic_replay`` under the DN reported
    by ``ldb.domain_dn()``.
    """
    base_dn = ldb.domain_dn()
    return "ou=instance-%d,ou=traffic_replay,%s" % (instance_id, base_dn)
def create_ou(ldb, instance_id):
    """Create the OU that will hold every account made for this instance.

    Keeping all created users and machines under one OU lets them be
    removed together.  The parent OU is created first, then the instance
    OU; an LdbError with status 68 (entry already exists) is ignored for
    either, any other LdbError propagates.

    :return: the DN of the instance OU
    """
    ou = ou_name(ldb, instance_id)
    parent_ou = ou.split(',', 1)[1]

    for dn in (parent_ou, ou):
        try:
            ldb.add({"dn": dn,
                     "objectclass": "organizationalunit"})
        except LdbError as e:
            (status, _) = e.args
            # ignore already exists
            if status != 68:
                raise
    return ou
class ConversationAccounts(object):
    """The machine and user account names and passwords that one
    conversation uses.
    """

    def __init__(self, netbios_name, machinepass, username, userpass):
        # Machine account: NetBIOS name and its password.
        self.netbios_name, self.machinepass = netbios_name, machinepass
        # User account: name and its password.
        self.username, self.userpass = username, userpass
def generate_replay_accounts(ldb, instance_id, number, password):
    """Create the replay accounts and return their details.

    The accounts are first created on the server through
    generate_traffic_accounts().  The result is a list of *number*
    ConversationAccounts, numbered from 1, in which the machine and the
    user share the same *password*.
    """
    generate_traffic_accounts(ldb, instance_id, number, password)

    return [
        ConversationAccounts("STGM-%d-%d" % (instance_id, i),
                             password,
                             "STGU-%d-%d" % (instance_id, i),
                             password)
        for i in range(1, number + 1)
    ]
def generate_traffic_accounts(ldb, instance_id, number, password):
    """Create up to *number* machine accounts and *number* user accounts.

    Accounts are not explicitly deleted between runs, so each kind is
    created from the highest number downwards.  Creation stops either at
    the first account that already exists (LdbError status 68) or when
    all required accounts have been made.  Any other LdbError propagates.
    """
    print(("Generating machine and conversation accounts, "
           "as required for %d conversations" % number),
          file=sys.stderr)

    def create_downwards(name_format, create):
        # Returns how many accounts were newly created.
        count = 0
        for i in range(number, 0, -1):
            try:
                name = name_format % (instance_id, i)
                create(ldb, instance_id, name, password)
            except LdbError as e:
                (status, _) = e.args
                if status != 68:
                    raise
                break
            count += 1
        return count

    added = create_downwards("STGM-%d-%d", create_machine_account)
    if added > 0:
        print("Added %d new machine accounts" % added,
              file=sys.stderr)

    added = create_downwards("STGU-%d-%d", create_user_account)
    if added > 0:
        print("Added %d new user accounts" % added,
              file=sys.stderr)
def create_machine_account(ldb, instance_id, netbios_name, machinepass):
    """Create a machine account via ldap.

    The account is added as ``cn=<netbios_name>`` under the instance OU
    with objectclass "computer".  A timing record for the add is printed
    to stdout.

    :param ldb: database connection offering add()
    :param instance_id: instance number used to derive the OU
    :param netbios_name: machine name; the sAMAccountName is this plus "$"
    :param machinepass: password as text, or as UTF-8 encoded bytes
    """
    ou = ou_name(ldb, instance_id)
    dn = "cn=%s,%s" % (netbios_name, ou)

    # unicodePwd takes the password wrapped in double quotes and encoded
    # as UTF-16-LE.  Decode bytes first so that the quoting happens on
    # text; the previous unicode()/str+bytes form only ran on Python 2.
    if isinstance(machinepass, bytes):
        machinepass = machinepass.decode('utf-8')
    utf16pw = ('"%s"' % machinepass).encode('utf-16-le')

    start = time.time()
    ldb.add({
        "dn": dn,
        "objectclass": "computer",
        "sAMAccountName": "%s$" % netbios_name,
        "userAccountControl":
        str(UF_WORKSTATION_TRUST_ACCOUNT | UF_PASSWD_NOTREQD),
        "unicodePwd": utf16pw})
    end = time.time()
    duration = end - start
    print("%f\t0\tcreate\tmachine\t%f\tTrue\t" % (end, duration))
def create_user_account(ldb, instance_id, username, userpass):
    """Create a user account via ldap.

    The account is added as ``cn=<username>`` under the instance OU with
    objectclass "user", then an ACE is added to its security descriptor.
    A timing record covering both steps is printed to stdout.

    :param ldb: database connection offering add()
    :param instance_id: instance number used to derive the OU
    :param username: used for both the cn and the sAMAccountName
    :param userpass: password as text, or as UTF-8 encoded bytes
    """
    ou = ou_name(ldb, instance_id)
    user_dn = "cn=%s,%s" % (username, ou)

    # unicodePwd takes the password wrapped in double quotes and encoded
    # as UTF-16-LE.  Decode bytes first so that the quoting happens on
    # text; the previous unicode()/str+bytes form only ran on Python 2.
    if isinstance(userpass, bytes):
        userpass = userpass.decode('utf-8')
    utf16pw = ('"%s"' % userpass).encode('utf-16-le')

    start = time.time()
    ldb.add({
        "dn": user_dn,
        "objectclass": "user",
        "sAMAccountName": username,
        "userAccountControl": str(UF_NORMAL_ACCOUNT),
        "unicodePwd": utf16pw
    })

    # grant user write permission to do things like write account SPN
    sdutils = sd_utils.SDUtils(ldb)
    sdutils.dacl_add_ace(user_dn, "(A;;WP;;;PS)")

    end = time.time()
    duration = end - start
    print("%f\t0\tcreate\tuser\t%f\tTrue\t" % (end, duration))
def create_group(ldb, instance_id, name):
    """Add a group named *name* under the instance OU via ldap.

    A timing record for the add is printed to stdout.
    """
    group_dn = "cn=%s,%s" % (name, ou_name(ldb, instance_id))

    began = time.time()
    ldb.add({
        "dn": group_dn,
        "objectclass": "group",
    })
    finished = time.time()

    elapsed = finished - began
    print("%f\t0\tcreate\tgroup\t%f\tTrue\t" % (finished, elapsed))
def user_name(instance_id, i):
    """Return the name of user number *i* of the given instance."""
    ids = (instance_id, i)
    return "STGU-%d-%d" % ids
def generate_users(ldb, instance_id, number, password):
    """Add users to the server, highest number first.

    Creation stops at the first user that already exists (LdbError status
    68); any other LdbError propagates.

    :return: how many users were newly created
    """
    created = 0
    for i in range(number, 0, -1):
        try:
            create_user_account(ldb, instance_id,
                                user_name(instance_id, i), password)
        except LdbError as e:
            (status, _) = e.args
            if status != 68:
                raise
            # Stop if entry exists
            break
        created += 1

    return created
def group_name(instance_id, i):
    """Return the name of group number *i* of the given instance."""
    ids = (instance_id, i)
    return "STGG-%d-%d" % ids
def generate_groups(ldb, instance_id, number):
    """Create the required groups on the server, highest number first.

    Creation stops at the first group that already exists (LdbError
    status 68); any other LdbError propagates.

    :return: how many groups were newly created
    """
    created = 0
    for i in range(number, 0, -1):
        try:
            create_group(ldb, instance_id, group_name(instance_id, i))
        except LdbError as e:
            (status, _) = e.args
            if status != 68:
                raise
            # Stop if entry exists
            break
        created += 1

    return created
def clean_up_accounts(ldb, instance_id):
    """Delete the instance OU and everything beneath it from the server.

    An LdbError with status 32 (the OU does not exist) is ignored; any
    other LdbError propagates.
    """
    instance_ou = ou_name(ldb, instance_id)
    try:
        ldb.delete(instance_ou, ["tree_delete:1"])
    except LdbError as e:
        (status, _) = e.args
        if status == 32:
            # ignore does not exist
            return
        raise
def generate_users_and_groups(ldb, instance_id, password,
                              number_of_users, number_of_groups,
                              group_memberships):
    """Create the instance OU, users and groups, then allocate the users
    to those groups.

    Progress messages and the final summary are written to stderr.
    """
    create_ou(ldb, instance_id)

    print("Generating dummy user accounts", file=sys.stderr)
    users_added = generate_users(ldb, instance_id, number_of_users, password)

    groups_added = 0
    if number_of_groups > 0:
        print("Generating dummy groups", file=sys.stderr)
        groups_added = generate_groups(ldb, instance_id, number_of_groups)

    assignments = []
    if group_memberships > 0:
        print("Assigning users to groups", file=sys.stderr)
        assignments = assign_groups(number_of_groups,
                                    groups_added,
                                    number_of_users,
                                    users_added,
                                    group_memberships)
        print("Adding users to groups", file=sys.stderr)
        add_users_to_groups(ldb, instance_id, assignments)

    # Some, but not all, of the groups are new while no user is new.
    only_some_groups_new = (groups_added > 0 and
                            number_of_groups != groups_added)
    if only_some_groups_new and users_added == 0:
        print("Warning: the added groups will contain no members",
              file=sys.stderr)

    summary = ("Added %d users, %d groups and %d group memberships" %
               (users_added, groups_added, len(assignments)))
    print(summary, file=sys.stderr)
# assign_groups picks random (user, group) pairs, both as 1-based numbers,
# and collects them in `assignments`.  generate_users_and_groups() passes
# the result on to add_users_to_groups().
1814 def assign_groups(number_of_groups,
1819 """Allocate users to groups.
1821 The intention is to have a few users that belong to most groups, while
1822 the majority of users belong to a few groups.
1824 A few groups will contain most users, with the remaining only having a
1828 def generate_user_distribution(n):
1829 """Probability distribution of a user belonging to a group.
1832 for x in range(1, n + 1):
1837 def generate_group_distribution(n):
1838 """Probability distribution of a group containing a user."""
1840 for x in range(1, n + 1):
# Nothing to allocate when no memberships were requested.
1846 if group_memberships <= 0:
1849 group_dist = generate_group_distribution(number_of_groups)
1850 user_dist = generate_user_distribution(number_of_users)
1852 # Calculate the number of group memberships required
# The requested count is scaled by the fraction of users that were newly
# added, and rounded up.
# NOTE(review): number_of_users == 0 would divide by zero here -- confirm
# that callers never pass it together with group_memberships > 0.
1853 group_memberships = math.ceil(
1854 float(group_memberships) *
1855 (float(users_added) / float(number_of_users)))
# Highest zero-based index taken by a pre-existing user / group.  The
# generate_* loops create from the highest number downwards, so the new
# entries are the ones with the higher indices.
1857 existing_users = number_of_users - users_added - 1
1858 existing_groups = number_of_groups - groups_added - 1
# Keep drawing random pairs until enough distinct ones are collected.
1859 while len(assignments) < group_memberships:
1860 user = random.randint(0, number_of_users - 1)
1861 group = random.randint(0, number_of_groups - 1)
1862 probability = group_dist[group] * user_dist[user]
# Accept a pair only if it passes the weighted random draw and involves at
# least one newly added user or group.
1864 if ((random.random() < probability * 10000) and
1865 (group > existing_groups or user > existing_users)):
1866 # the + 1 converts the array index to the corresponding
1867 # group or user number
1868 assignments.add(((user + 1), (group + 1)))
def add_users_to_groups(db, instance_id, assignments):
    """Add users to their assigned groups.

    *assignments* is an iterable of (user, group) number pairs, as
    produced by assign_groups.  For each pair the user's DN is added to
    the "member" attribute of the group, and a timing record for that
    modify is printed to stdout.
    """
    ou = ou_name(db, instance_id)
    dn_format = "cn=%s," + ou

    for (user, group) in assignments:
        user_dn = dn_format % user_name(instance_id, user)
        group_dn = dn_format % group_name(instance_id, group)

        msg = ldb.Message()
        msg.dn = ldb.Dn(db, group_dn)
        msg["member"] = ldb.MessageElement(user_dn, ldb.FLAG_MOD_ADD,
                                           "member")

        began = time.time()
        db.modify(msg)
        finished = time.time()
        print("%f\t0\tadd\tuser\t%f\tTrue\t" % (finished, finished - began))
def generate_stats(statsdir, timing_file):
    """Generate and print the summary stats for a run.

    Every file in *statsdir* is read as tab-separated records with the
    columns time, conv, protocol, type, duration, successful, error.  A
    record that cannot be parsed is echoed to stderr and contributes
    nothing to the statistics.

    Output goes to stdout: column-formatted when stdout is a terminal,
    tab-separated otherwise.

    :param statsdir: directory holding the per-conversation stats files
    :param timing_file: optional writable file; receives a header line
        followed by every valid record.  May be None.
    """
    first = sys.float_info.max
    last = 0
    successful = 0
    failed = 0
    latencies = {}
    failures = {}
    unique_converations = set()

    if timing_file is not None:
        tw = timing_file.write
    else:
        def tw(x):
            pass

    tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n")

    for filename in os.listdir(statsdir):
        path = os.path.join(statsdir, filename)
        with open(path, 'r') as f:
            for line in f:
                # Parse the whole record before touching any counter, so
                # that a malformed line cannot be half counted.
                try:
                    fields = line.rstrip('\n').split('\t')
                    timestamp = float(fields[0])
                    conversation = fields[1]
                    protocol = fields[2]
                    packet_type = fields[3]
                    latency = float(fields[4])
                    succeeded = fields[5] == 'True'
                except (ValueError, IndexError):
                    # not a valid line print and ignore
                    print(line, file=sys.stderr)
                    continue

                first = min(timestamp - latency, first)
                last = max(timestamp, last)

                by_type = latencies.setdefault(protocol, {})
                by_type.setdefault(packet_type, []).append(latency)

                failed_by_type = failures.setdefault(protocol, {})
                failed_by_type.setdefault(packet_type, 0)

                if succeeded:
                    successful += 1
                else:
                    failed += 1
                    failed_by_type[packet_type] += 1

                unique_converations.add(conversation)
                tw(line)

    conversations = len(unique_converations)
    duration = last - first

    # A non-positive duration (no records, or one instantaneous record)
    # gives no meaningful rate; report 0 rather than dividing by zero.
    if successful == 0 or duration <= 0:
        success_rate = 0
    else:
        success_rate = successful / duration
    if failed == 0 or duration <= 0:
        failure_rate = 0
    else:
        failure_rate = failed / duration

    # print the stats in more human-readable format when stdout is going to the
    # console (as opposed to being redirected to a file)
    to_console = sys.stdout.isatty()

    if to_console:
        print("Total conversations: %10d" % conversations)
        print("Successful operations: %10d (%.3f per second)"
              % (successful, success_rate))
        print("Failed operations: %10d (%.3f per second)"
              % (failed, failure_rate))
    else:
        print("(%d, %d, %d, %.3f, %.3f)" %
              (conversations, successful, failed, success_rate, failure_rate))

    if to_console:
        print("Protocol Op Code Description "
              " Count Failed Mean Median "
              "95% Range Max")
        row_format = ("%-12s %4s %-35s %12d %12d %12.6f "
                      "%12.6f %12.6f %12.6f %12.6f")
    else:
        print("proto\top_code\tdesc\tcount\tfailed\tmean\tmedian\t95%\trange"
              "\tmax")
        row_format = "%s\t%s\t%s\t%d\t%d\t%f\t%f\t%f\t%f\t%f"

    for protocol in sorted(latencies.keys()):
        packet_types = sorted(latencies[protocol], key=opcode_key)
        for packet_type in packet_types:
            values = sorted(latencies[protocol][packet_type])
            count = len(values)
            type_failed = failures[protocol][packet_type]
            mean = sum(values) / count
            median = calc_percentile(values, 0.50)
            percentile = calc_percentile(values, 0.95)
            rng = values[-1] - values[0]
            maxv = values[-1]
            desc = OP_DESCRIPTIONS.get((protocol, packet_type), '')
            # The row layout must follow the same terminal test as the
            # header; the bare `sys.stdout.isatty` attribute is always
            # true, which made the tab-separated branch unreachable.
            print(row_format
                  % (protocol,
                     packet_type,
                     desc,
                     count,
                     type_failed,
                     mean,
                     median,
                     percentile,
                     rng,
                     maxv))
def opcode_key(v):
    """Sort key for the operation code to ensure that it sorts numerically.

    An opcode that parses as a decimal integer is zero-padded to three
    digits, so that e.g. "9" sorts before "10".  Anything else (a hex
    string such as "0x72", an empty string, None) is returned unchanged.
    """
    try:
        return "%03d" % int(v)
    except (TypeError, ValueError):
        # Only conversion failures mean "not numeric"; a bare except
        # would also swallow KeyboardInterrupt and SystemExit.
        return v
def calc_percentile(values, percentile):
    """Return the requested percentile of *values*.

    *values* must already be sorted in ascending order and *percentile*
    is a fraction such as 0.95.  The result is linearly interpolated
    between the two nearest ranks.  An empty list gives 0.
    """
    if not values:
        return 0

    rank = (len(values) - 1) * percentile
    below = math.floor(rank)
    above = math.ceil(rank)

    if below == above:
        # The rank falls exactly on an element: nothing to interpolate.
        return values[int(rank)]

    low_part = values[int(below)] * (above - rank)
    high_part = values[int(above)] * (rank - below)
    return low_part + high_part
2053 def mk_masked_dir(*path):
2054 """In a testenv we end up with 0777 directories that look an alarming
2055 green colour with ls. Use umask to avoid that."""
2056 d = os.path.join(*path)
2057 mask = os.umask(0o077)