CVE-2016-0771: tests/dns: prepare script for further testing
[samba.git] / python / samba / tests / dns.py
index 49d699edb78bd6c92cbbd5fce7b68030910c191d..fe0ac3b03ec071690f15da592c8de9c177b991c6 100644 (file)
@@ -20,11 +20,27 @@ import struct
 import random
 from samba import socket
 import samba.ndr as ndr
-import samba.dcerpc.dns as dns
+from samba import credentials, param
 from samba.tests import TestCase
+from samba.dcerpc import dns, dnsp, dnsserver
+
+FILTER=''.join([(len(repr(chr(x)))==3) and chr(x) or '.' for x in range(256)])
+
+def make_txt_record(records):
+    rdata_txt = dns.txt_record()
+    s_list = dnsp.string_list()
+    s_list.count = len(records)
+    s_list.str = records
+    rdata_txt.txt = s_list
+    return rdata_txt
 
 class DNSTest(TestCase):
 
+    def get_loadparm(self):
+        lp = param.LoadParm()
+        lp.load(os.getenv("SMB_CONF_PATH"))
+        return lp
+
     def errstr(self, errcode):
         "Return a readable error code"
         string_codes = [
@@ -82,36 +98,53 @@ class DNSTest(TestCase):
         "Helper to get dns domain"
         return os.getenv('REALM', 'example.com').lower()
 
-    def dns_transaction_udp(self, packet, host=os.getenv('SERVER_IP')):
+    def dns_transaction_udp(self, packet, host=os.getenv('SERVER_IP'), dump=False):
         "send a DNS query and read the reply"
         s = None
         try:
             send_packet = ndr.ndr_pack(packet)
+            if dump:
+                print self.hexdump(send_packet)
             s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
             s.connect((host, 53))
             s.send(send_packet, 0)
             recv_packet = s.recv(2048, 0)
+            if dump:
+                print self.hexdump(recv_packet)
             return ndr.ndr_unpack(dns.name_packet, recv_packet)
         finally:
             if s is not None:
                 s.close()
 
-    def dns_transaction_tcp(self, packet, host=os.getenv('SERVER_IP')):
+    def dns_transaction_tcp(self, packet, host=os.getenv('SERVER_IP'), dump=False):
         "send a DNS query and read the reply"
         s = None
         try:
             send_packet = ndr.ndr_pack(packet)
+            if dump:
+                print self.hexdump(send_packet)
             s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
             s.connect((host, 53))
             tcp_packet = struct.pack('!H', len(send_packet))
             tcp_packet += send_packet
             s.send(tcp_packet, 0)
             recv_packet = s.recv(0xffff + 2, 0)
+            if dump:
+                print self.hexdump(recv_packet)
             return ndr.ndr_unpack(dns.name_packet, recv_packet[2:])
         finally:
                 if s is not None:
                     s.close()
 
+    def hexdump(self, src, length=8):
+        N=0; result=''
+        while src:
+           s,src = src[:length],src[length:]
+           hexa = ' '.join(["%02X"%ord(x) for x in s])
+           s = s.translate(FILTER)
+           result += "%04X   %-*s   %s\n" % (N, length*3, hexa, s)
+           N+=length
+        return result
 
 class TestSimpleQueries(DNSTest):
 
@@ -151,6 +184,36 @@ class TestSimpleQueries(DNSTest):
         self.assertEquals(response.answers[0].rdata,
                           os.getenv('SERVER_IP'))
 
+    def test_one_mx_query(self):
+        "create a query packet causing an empty RCODE_OK answer"
+        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
+        questions = []
+
+        name = "%s.%s" % (os.getenv('SERVER'), self.get_dns_domain())
+        q = self.make_name_question(name, dns.DNS_QTYPE_MX, dns.DNS_QCLASS_IN)
+        print "asking for ", q.name
+        questions.append(q)
+
+        self.finish_name_packet(p, questions)
+        response = self.dns_transaction_udp(p)
+        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
+        self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
+        self.assertEquals(response.ancount, 0)
+
+        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
+        questions = []
+
+        name = "invalid-%s.%s" % (os.getenv('SERVER'), self.get_dns_domain())
+        q = self.make_name_question(name, dns.DNS_QTYPE_MX, dns.DNS_QCLASS_IN)
+        print "asking for ", q.name
+        questions.append(q)
+
+        self.finish_name_packet(p, questions)
+        response = self.dns_transaction_udp(p)
+        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXDOMAIN)
+        self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
+        self.assertEquals(response.ancount, 0)
+
     def test_two_queries(self):
         "create a query packet containing two query records"
         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
@@ -239,6 +302,7 @@ class TestSimpleQueries(DNSTest):
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 1)
+        self.assertEquals(response.answers[0].rdata.minimum, 3600)
 
 
 class TestDNSUpdates(DNSTest):
@@ -370,8 +434,8 @@ class TestDNSUpdates(DNSTest):
         r.rr_class = dns.DNS_QCLASS_IN
         r.ttl = 900
         r.length = 0xffff
-        r.rdata = dns.txt_record()
-        r.rdata.txt = '"This is a test"'
+        rdata = make_txt_record(['"This is a test"'])
+        r.rdata = rdata
         updates.append(r)
         p.nscount = len(updates)
         p.nsrecs = updates
@@ -390,7 +454,7 @@ class TestDNSUpdates(DNSTest):
         response = self.dns_transaction_udp(p)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assertEquals(response.ancount, 1)
-        self.assertEquals(response.answers[0].rdata.txt, '"This is a test"')
+        self.assertEquals(response.answers[0].rdata.txt.str[0], '"This is a test"')
 
     def test_update_add_two_txt_records(self):
         "test adding two txt records works"
@@ -410,8 +474,9 @@ class TestDNSUpdates(DNSTest):
         r.rr_class = dns.DNS_QCLASS_IN
         r.ttl = 900
         r.length = 0xffff
-        r.rdata = dns.txt_record()
-        r.rdata.txt = '"This is a test" "and this is a test, too"'
+        rdata = make_txt_record(['"This is a test"',
+                                 '"and this is a test, too"'])
+        r.rdata = rdata
         updates.append(r)
         p.nscount = len(updates)
         p.nsrecs = updates
@@ -430,7 +495,8 @@ class TestDNSUpdates(DNSTest):
         response = self.dns_transaction_udp(p)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assertEquals(response.ancount, 1)
-        self.assertEquals(response.answers[0].rdata.txt, '"This is a test" "and this is a test, too"')
+        self.assertEquals(response.answers[0].rdata.txt.str[0], '"This is a test"')
+        self.assertEquals(response.answers[0].rdata.txt.str[1], '"and this is a test, too"')
 
     def test_delete_record(self):
         "Test if deleting records works"
@@ -454,8 +520,8 @@ class TestDNSUpdates(DNSTest):
         r.rr_class = dns.DNS_QCLASS_IN
         r.ttl = 900
         r.length = 0xffff
-        r.rdata = dns.txt_record()
-        r.rdata.txt = '"This is a test"'
+        rdata = make_txt_record(['"This is a test"'])
+        r.rdata = rdata
         updates.append(r)
         p.nscount = len(updates)
         p.nsrecs = updates
@@ -490,8 +556,8 @@ class TestDNSUpdates(DNSTest):
         r.rr_class = dns.DNS_QCLASS_NONE
         r.ttl = 0
         r.length = 0xffff
-        r.rdata = dns.txt_record()
-        r.rdata.txt = '"This is a test"'
+        rdata = make_txt_record(['"This is a test"'])
+        r.rdata = rdata
         updates.append(r)
         p.nscount = len(updates)
         p.nsrecs = updates
@@ -510,6 +576,165 @@ class TestDNSUpdates(DNSTest):
         response = self.dns_transaction_udp(p)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXDOMAIN)
 
+    def test_readd_record(self):
+        "Test if adding, deleting and then readding a records works"
+
+        NAME = "readdrec.%s" % self.get_dns_domain()
+
+        # Create the record
+        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
+        updates = []
+
+        name = self.get_dns_domain()
+
+        u = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
+        updates.append(u)
+        self.finish_name_packet(p, updates)
+
+        updates = []
+        r = dns.res_rec()
+        r.name = NAME
+        r.rr_type = dns.DNS_QTYPE_TXT
+        r.rr_class = dns.DNS_QCLASS_IN
+        r.ttl = 900
+        r.length = 0xffff
+        rdata = make_txt_record(['"This is a test"'])
+        r.rdata = rdata
+        updates.append(r)
+        p.nscount = len(updates)
+        p.nsrecs = updates
+
+        response = self.dns_transaction_udp(p)
+        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
+
+        # Now check the record is around
+        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
+        questions = []
+        q = self.make_name_question(NAME, dns.DNS_QTYPE_TXT, dns.DNS_QCLASS_IN)
+        questions.append(q)
+
+        self.finish_name_packet(p, questions)
+        response = self.dns_transaction_udp(p)
+        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
+
+        # Now delete the record
+        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
+        updates = []
+
+        name = self.get_dns_domain()
+
+        u = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
+        updates.append(u)
+        self.finish_name_packet(p, updates)
+
+        updates = []
+        r = dns.res_rec()
+        r.name = NAME
+        r.rr_type = dns.DNS_QTYPE_TXT
+        r.rr_class = dns.DNS_QCLASS_NONE
+        r.ttl = 0
+        r.length = 0xffff
+        rdata = make_txt_record(['"This is a test"'])
+        r.rdata = rdata
+        updates.append(r)
+        p.nscount = len(updates)
+        p.nsrecs = updates
+
+        response = self.dns_transaction_udp(p)
+        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
+
+        # check it's gone
+        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
+        questions = []
+
+        q = self.make_name_question(NAME, dns.DNS_QTYPE_TXT, dns.DNS_QCLASS_IN)
+        questions.append(q)
+
+        self.finish_name_packet(p, questions)
+        response = self.dns_transaction_udp(p)
+        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXDOMAIN)
+
+        # recreate the record
+        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
+        updates = []
+
+        name = self.get_dns_domain()
+
+        u = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
+        updates.append(u)
+        self.finish_name_packet(p, updates)
+
+        updates = []
+        r = dns.res_rec()
+        r.name = NAME
+        r.rr_type = dns.DNS_QTYPE_TXT
+        r.rr_class = dns.DNS_QCLASS_IN
+        r.ttl = 900
+        r.length = 0xffff
+        rdata = make_txt_record(['"This is a test"'])
+        r.rdata = rdata
+        updates.append(r)
+        p.nscount = len(updates)
+        p.nsrecs = updates
+
+        response = self.dns_transaction_udp(p)
+        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
+
+        # Now check the record is around
+        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
+        questions = []
+        q = self.make_name_question(NAME, dns.DNS_QTYPE_TXT, dns.DNS_QCLASS_IN)
+        questions.append(q)
+
+        self.finish_name_packet(p, questions)
+        response = self.dns_transaction_udp(p)
+        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
+
+    def test_update_add_mx_record(self):
+        "test adding MX records works"
+        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
+        updates = []
+
+        name = self.get_dns_domain()
+
+        u = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
+        updates.append(u)
+        self.finish_name_packet(p, updates)
+
+        updates = []
+        r = dns.res_rec()
+        r.name = "%s" % self.get_dns_domain()
+        r.rr_type = dns.DNS_QTYPE_MX
+        r.rr_class = dns.DNS_QCLASS_IN
+        r.ttl = 900
+        r.length = 0xffff
+        rdata = dns.mx_record()
+        rdata.preference = 10
+        rdata.exchange = 'mail.%s' % self.get_dns_domain()
+        r.rdata = rdata
+        updates.append(r)
+        p.nscount = len(updates)
+        p.nsrecs = updates
+
+        response = self.dns_transaction_udp(p)
+        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
+
+        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
+        questions = []
+
+        name = "%s" % self.get_dns_domain()
+        q = self.make_name_question(name, dns.DNS_QTYPE_MX, dns.DNS_QCLASS_IN)
+        questions.append(q)
+
+        self.finish_name_packet(p, questions)
+        response = self.dns_transaction_udp(p)
+        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
+        self.assertEqual(response.ancount, 1)
+        ans = response.answers[0]
+        self.assertEqual(ans.rr_type, dns.DNS_QTYPE_MX)
+        self.assertEqual(ans.rdata.preference, 10)
+        self.assertEqual(ans.rdata.exchange, 'mail.%s' % self.get_dns_domain())
+
 
 class TestComplexQueries(DNSTest):
 
@@ -617,6 +842,35 @@ class TestInvalidQueries(DNSTest):
         self.assertEquals(response.answers[0].rdata,
                           os.getenv('SERVER_IP'))
 
+    def test_one_a_reply(self):
+        "send a reply instead of a query"
+
+        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
+        questions = []
+
+        name = "%s.%s" % ('fakefakefake', self.get_dns_domain())
+        q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
+        print "asking for ", q.name
+        questions.append(q)
+
+        self.finish_name_packet(p, questions)
+        p.operation |= dns.DNS_FLAG_REPLY
+        s = None
+        try:
+            send_packet = ndr.ndr_pack(p)
+            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
+            host=os.getenv('SERVER_IP')
+            s.connect((host, 53))
+            tcp_packet = struct.pack('!H', len(send_packet))
+            tcp_packet += send_packet
+            s.send(tcp_packet, 0)
+            recv_packet = s.recv(0xffff + 2, 0)
+            self.assertEquals(0, len(recv_packet))
+        finally:
+            if s is not None:
+                s.close()
+
+
 if __name__ == "__main__":
     import unittest
     unittest.main()