lock caches in case they are shared
authorBob Halley <halley@dnspython.org>
Sun, 31 Mar 2013 11:50:04 +0000 (12:50 +0100)
committerBob Halley <halley@dnspython.org>
Sun, 31 Mar 2013 11:50:04 +0000 (12:50 +0100)
dns/resolver.py

index 08e86a2819063bb6faca284ec86ef4d27ce034e9..54d053dcba74693eae0b2fc3b8fef671fb980f10 100644 (file)
@@ -22,6 +22,11 @@ import socket
 import sys
 import time
 
+try:
+    import threading as _threading
+except ImportError:
+    import dummy_threading as _threading
+
 import dns.exception
 import dns.flags
 import dns.ipv4
@@ -216,8 +221,9 @@ class Cache(object):
         self.data = {}
         self.cleaning_interval = cleaning_interval
         self.next_cleaning = time.time() + self.cleaning_interval
+        self.lock = _threading.Lock()
 
-    def maybe_clean(self):
+    def _maybe_clean(self):
         """Clean the cache if it's time to do so."""
 
         now = time.time()
@@ -240,11 +246,15 @@ class Cache(object):
         @rtype: dns.resolver.Answer object or None
         """
 
-        self.maybe_clean()
-        v = self.data.get(key)
-        if v is None or v.expiration <= time.time():
-            return None
-        return v
+        try:
+            self.lock.acquire()
+            self._maybe_clean()
+            v = self.data.get(key)
+            if v is None or v.expiration <= time.time():
+                return None
+            return v
+        finally:
+            self.lock.release()
 
     def put(self, key, value):
         """Associate key and value in the cache.
@@ -255,8 +265,12 @@ class Cache(object):
         @type value: dns.resolver.Answer object
         """
 
-        self.maybe_clean()
-        self.data[key] = value
+        try:
+            self.lock.acquire()
+            self._maybe_clean()
+            self.data[key] = value
+        finally:
+            self.lock.release()
 
     def flush(self, key=None):
         """Flush the cache.
@@ -268,12 +282,16 @@ class Cache(object):
         @type key: (dns.name.Name, int, int) tuple or None
         """
 
-        if not key is None:
-            if self.data.has_key(key):
-                del self.data[key]
-        else:
-            self.data = {}
-            self.next_cleaning = time.time() + self.cleaning_interval
+        try:
+            self.lock.acquire()
+            if not key is None:
+                if self.data.has_key(key):
+                    del self.data[key]
+            else:
+                self.data = {}
+                self.next_cleaning = time.time() + self.cleaning_interval
+        finally:
+            self.lock.release()
 
 class LRUCacheNode(object):
     """LRUCache node.
@@ -326,6 +344,7 @@ class LRUCache(object):
         self.data = {}
         self.set_max_size(max_size)
         self.sentinel = LRUCacheNode(None, None)
+        self.lock = _threading.Lock()
 
     def set_max_size(self, max_size):
         if max_size < 1:
@@ -340,17 +359,21 @@ class LRUCache(object):
         query name, rdtype, and rdclass.
         @rtype: dns.resolver.Answer object or None
         """
-        node = self.data.get(key)
-        if node is None:
-            return None
-        # Unlink because we're either going to move the node to the front
-        # of the LRU list or we're going to free it.
-        node.unlink()
-        if node.value.expiration <= time.time():
-            del self.data[node.key]
-            return None
-        node.link_after(self.sentinel)
-        return node.value
+        try:
+            self.lock.acquire()
+            node = self.data.get(key)
+            if node is None:
+                return None
+            # Unlink because we're either going to move the node to the front
+            # of the LRU list or we're going to free it.
+            node.unlink()
+            if node.value.expiration <= time.time():
+                del self.data[node.key]
+                return None
+            node.link_after(self.sentinel)
+            return node.value
+        finally:
+            self.lock.release()
 
     def put(self, key, value):
         """Associate key and value in the cache.
@@ -360,17 +383,21 @@ class LRUCache(object):
         @param value: The answer being cached
         @type value: dns.resolver.Answer object
         """
-        node = self.data.get(key)
-        if not node is None:
-            node.unlink()
-            del self.data[node.key]
-        while len(self.data) >= self.max_size:
-            node = self.sentinel.prev
-            node.unlink()
-            del self.data[node.key]
-        node = LRUCacheNode(key, value)
-        node.link_after(self.sentinel)
-        self.data[key] = node
+        try:
+            self.lock.acquire()
+            node = self.data.get(key)
+            if not node is None:
+                node.unlink()
+                del self.data[node.key]
+            while len(self.data) >= self.max_size:
+                node = self.sentinel.prev
+                node.unlink()
+                del self.data[node.key]
+            node = LRUCacheNode(key, value)
+            node.link_after(self.sentinel)
+            self.data[key] = node
+        finally:
+            self.lock.release()
 
     def flush(self, key=None):
         """Flush the cache.
@@ -381,19 +408,23 @@ class LRUCache(object):
         @param key: the key to flush
         @type key: (dns.name.Name, int, int) tuple or None
         """
-        if not key is None:
-            node = self.data.get(key)
-            if not node is None:
-                node.unlink()
-                del self.data[node.key]
-        else:
-            node = self.sentinel.next
-            while node != self.sentinel:
-                next = node.next
-                node.prev = None
-                node.next = None
-                node = next
-            self.data = {}
+        try:
+            self.lock.acquire()
+            if not key is None:
+                node = self.data.get(key)
+                if not node is None:
+                    node.unlink()
+                    del self.data[node.key]
+            else:
+                node = self.sentinel.next
+                while node != self.sentinel:
+                    next = node.next
+                    node.prev = None
+                    node.next = None
+                    node = next
+                self.data = {}
+        finally:
+            self.lock.release()
 
 class Resolver(object):
     """DNS stub resolver