traffic: optimize packet init for better performance
[metze/samba/wip.git] / python / samba / emulate / traffic.py
index 84a9a6ab067045a773b8ddea4e66b6f281923d5f..227477a5425c0fb6f89631cc56b59d42d7dbbf97 100644 (file)
@@ -138,10 +138,26 @@ class FakePacketError(Exception):
 
 class Packet(object):
     """Details of a network packet"""
-    def __init__(self, fields):
-        if isinstance(fields, str):
-            fields = fields.rstrip('\n').split('\t')
+    def __init__(self, timestamp, ip_protocol, stream_number, src, dest,
+                 protocol, opcode, desc, extra):
 
+        self.timestamp = timestamp
+        self.ip_protocol = ip_protocol
+        self.stream_number = stream_number
+        self.src = src
+        self.dest = dest
+        self.protocol = protocol
+        self.opcode = opcode
+        self.desc = desc
+        self.extra = extra
+        if self.src < self.dest:
+            self.endpoints = (self.src, self.dest)
+        else:
+            self.endpoints = (self.dest, self.src)
+
+    @classmethod
+    def from_line(self, line):
+        fields = line.rstrip('\n').split('\t')
         (timestamp,
          ip_protocol,
          stream_number,
@@ -152,23 +168,12 @@ class Packet(object):
          desc) = fields[:8]
         extra = fields[8:]
 
-        self.timestamp = float(timestamp)
-        self.ip_protocol = ip_protocol
-        try:
-            self.stream_number = int(stream_number)
-        except (ValueError, TypeError):
-            self.stream_number = None
-        self.src = int(src)
-        self.dest = int(dest)
-        self.protocol = protocol
-        self.opcode = opcode
-        self.desc = desc
-        self.extra = extra
+        timestamp = float(timestamp)
+        src = int(src)
+        dest = int(dest)
 
-        if self.src < self.dest:
-            self.endpoints = (self.src, self.dest)
-        else:
-            self.endpoints = (self.dest, self.src)
+        return Packet(timestamp, ip_protocol, stream_number, src, dest,
+                      protocol, opcode, desc, extra)
 
     def as_summary(self, time_offset=0.0):
         """Format the packet as a traffic_summary line.
@@ -196,14 +201,15 @@ class Packet(object):
         return "<Packet @%s>" % self
 
     def copy(self):
-        return self.__class__([self.timestamp,
-                               self.ip_protocol,
-                               self.stream_number,
-                               self.src,
-                               self.dest,
-                               self.protocol,
-                               self.opcode,
-                               self.desc] + self.extra)
+        return self.__class__(self.timestamp,
+                              self.ip_protocol,
+                              self.stream_number,
+                              self.src,
+                              self.dest,
+                              self.protocol,
+                              self.opcode,
+                              self.desc,
+                              self.extra)
 
     def as_packet_type(self):
         t = '%s:%s' % (self.protocol, self.opcode)
@@ -808,11 +814,9 @@ class Conversation(object):
 
         desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
         ip_protocol = IP_PROTOCOLS.get(protocol, '06')
-        fields = [timestamp - self.start_time, ip_protocol,
-                  '', src, dest,
-                  protocol, opcode, desc]
-        fields.extend(extra)
-        packet = Packet(fields)
+        packet = Packet(timestamp - self.start_time, ip_protocol,
+                        '', src, dest,
+                        protocol, opcode, desc, extra)
         # XXX we're assuming the timestamp is already adjusted for
         # this conversation?
         # XXX should we adjust client balance for guessed packets?
@@ -1024,7 +1028,7 @@ def ingest_summaries(files, dns_mode='count'):
             f = open(f)
         print("Ingesting %s" % (f.name,), file=sys.stderr)
         for line in f:
-            p = Packet(line)
+            p = Packet.from_line(line)
             if p.protocol == 'dns' and dns_mode != 'include':
                 dns_counts[p.opcode] += 1
             else: