a032f23a50f1cda414cef42d98a580b0c66cd07a
[metze/samba/wip.git] / source4 / scripting / python / samba / tests / dns.py
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Kai Blin  <kai@samba.org> 2011
3 #
4 # This program is free software; you can redistribute it and/or modify
5 # it under the terms of the GNU General Public License as published by
6 # the Free Software Foundation; either version 3 of the License, or
7 # (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 #
17
18 import os
19 import struct
20 import random
21 from samba import socket
22 import samba.ndr as ndr
23 import samba.dcerpc.dns as dns
24 from samba.tests import TestCase
25
class DNSTest(TestCase):
    """Common helpers for building DNS packets and exchanging them with the
    server under test.

    The server address defaults to the DC_SERVER_IP environment variable and
    the zone name is derived from REALM (see get_dns_domain).
    """

    def errstr(self, errcode):
        """Return a readable name for a DNS RCODE.

        Codes 0-10 map to their mnemonic.  Any other value (the 4-bit RCODE
        field can also hold 11-15) is rendered as "UNKNOWN(<code>)" instead
        of raising IndexError, which would otherwise mask the real assertion
        failure while its message is being formatted.
        """
        string_codes = (
            "OK",
            "FORMERR",
            "SERVFAIL",
            "NXDOMAIN",
            "NOTIMP",
            "REFUSED",
            "YXDOMAIN",
            "YXRRSET",
            "NXRRSET",
            "NOTAUTH",
            "NOTZONE",
        )
        if 0 <= errcode < len(string_codes):
            return string_codes[errcode]
        return "UNKNOWN(%d)" % errcode

    def assert_dns_rcode_equals(self, packet, rcode):
        "Assert that the RCODE (low 4 bits of packet.operation) equals rcode"
        p_errcode = packet.operation & 0x000F
        self.assertEqual(p_errcode, rcode, "Expected RCODE %s, got %s" %
                         (self.errstr(rcode), self.errstr(p_errcode)))

    def assert_dns_opcode_equals(self, packet, opcode):
        "Assert that the OPCODE bits of packet.operation equal opcode"
        # 0x7800 keeps bits 11-14 of the flags word without shifting them;
        # this assumes the dns.DNS_OPCODE_* constants are already shifted
        # into that position -- TODO(review) confirm against dns.idl.
        p_opcode = packet.operation & 0x7800
        self.assertEqual(p_opcode, opcode, "Expected OPCODE %s, got %s" %
                         (opcode, p_opcode))

    def make_name_packet(self, opcode, qid=None):
        """Create a dns.name_packet with the given opcode and no questions.

        :param opcode: value stored in the packet's operation field
        :param qid: transaction id to use; a random 16-bit id when None
        :return: the new packet
        """
        p = dns.name_packet()
        if qid is None:
            qid = random.randint(0x0, 0xffff)
        p.id = qid
        p.operation = opcode
        p.questions = []
        return p

    def finish_name_packet(self, packet, questions):
        "Store questions (and their count) in packet's question section"
        packet.qdcount = len(questions)
        packet.questions = questions

    def make_name_question(self, name, qtype, qclass):
        "Create a dns.name_question for name with the given type and class"
        q = dns.name_question()
        q.name = name
        q.question_type = qtype
        q.question_class = qclass
        return q

    def get_dns_domain(self):
        "Return the DNS domain: $REALM lower-cased, or 'example.com' if unset"
        return os.getenv('REALM', 'example.com').lower()

    def dns_transaction_udp(self, packet, host=None):
        """Send packet over UDP to port 53 of host and return the parsed reply.

        :param host: server address; defaults to $DC_SERVER_IP, looked up
            at call time rather than when the module is imported
        """
        if host is None:
            host = os.getenv('DC_SERVER_IP')
        s = None
        try:
            send_packet = ndr.ndr_pack(packet)
            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
            s.connect((host, 53))
            s.send(send_packet, 0)
            recv_packet = s.recv(2048, 0)
            return ndr.ndr_unpack(dns.name_packet, recv_packet)
        finally:
            if s is not None:
                s.close()

    def dns_transaction_tcp(self, packet, host=None):
        """Send packet over TCP to port 53 of host and return the parsed reply.

        :param host: server address; defaults to $DC_SERVER_IP, looked up
            at call time rather than when the module is imported
        :raises IOError: if the server closes the connection mid-reply
        """
        if host is None:
            host = os.getenv('DC_SERVER_IP')
        s = None
        try:
            send_packet = ndr.ndr_pack(packet)
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
            s.connect((host, 53))
            # DNS over TCP prefixes every message with its length as a
            # 16-bit big-endian integer (RFC 1035, section 4.2.2).
            tcp_packet = struct.pack('!H', len(send_packet))
            tcp_packet += send_packet
            # NOTE(review): assumes the whole request leaves in one send().
            s.send(tcp_packet, 0)
            # A single recv() on a stream socket may return only part of the
            # reply, so read the length prefix, then exactly that many bytes.
            (length,) = struct.unpack('!H', self._recv_exactly(s, 2))
            recv_packet = self._recv_exactly(s, length)
            return ndr.ndr_unpack(dns.name_packet, recv_packet)
        finally:
            if s is not None:
                s.close()

    def _recv_exactly(self, sock, count):
        """Read exactly count bytes from the stream socket sock.

        :raises IOError: if the peer closes the connection before count
            bytes have arrived
        """
        data = sock.recv(count, 0)
        while len(data) < count:
            chunk = sock.recv(count - len(data), 0)
            if not chunk:
                raise IOError("connection closed after %d of %d bytes"
                              % (len(data), count))
            data += chunk
        return data
114
115 class TestSimpleQueries(DNSTest):
116     def test_one_a_query(self):
117         "create a query packet containing one query record"
118         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
119         questions = []
120
121         name = "%s.%s" % (os.getenv('DC_SERVER'), self.get_dns_domain())
122         q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
123         print "asking for ", q.name
124         questions.append(q)
125
126         self.finish_name_packet(p, questions)
127         response = self.dns_transaction_udp(p)
128         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
129         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
130         self.assertEquals(response.ancount, 1)
131         self.assertEquals(response.answers[0].rdata,
132                           os.getenv('DC_SERVER_IP'))
133
134     def test_one_a_query_tcp(self):
135         "create a query packet containing one query record via TCP"
136         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
137         questions = []
138
139         name = "%s.%s" % (os.getenv('DC_SERVER'), self.get_dns_domain())
140         q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
141         print "asking for ", q.name
142         questions.append(q)
143
144         self.finish_name_packet(p, questions)
145         response = self.dns_transaction_tcp(p)
146         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
147         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
148         self.assertEquals(response.ancount, 1)
149         self.assertEquals(response.answers[0].rdata,
150                           os.getenv('DC_SERVER_IP'))
151
152     def test_two_queries(self):
153         "create a query packet containing two query records"
154         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
155         questions = []
156
157         name = "%s.%s" % (os.getenv('DC_SERVER'), self.get_dns_domain())
158         q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
159         questions.append(q)
160
161         name = "%s.%s" % ('bogusname', self.get_dns_domain())
162         q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
163         questions.append(q)
164
165         self.finish_name_packet(p, questions)
166         response = self.dns_transaction_udp(p)
167         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_FORMERR)
168
169     def test_qtype_all_query(self):
170         "create a QTYPE_ALL query"
171         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
172         questions = []
173
174         name = "%s.%s" % (os.getenv('DC_SERVER'), self.get_dns_domain())
175         q = self.make_name_question(name, dns.DNS_QTYPE_ALL, dns.DNS_QCLASS_IN)
176         print "asking for ", q.name
177         questions.append(q)
178
179         self.finish_name_packet(p, questions)
180         response = self.dns_transaction_udp(p)
181
182         num_answers = 1
183         dc_ipv6 = os.getenv('DC_SERVER_IPV6')
184         if dc_ipv6 is not None:
185             num_answers += 1
186
187         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
188         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
189         self.assertEquals(response.ancount, num_answers)
190         self.assertEquals(response.answers[0].rdata,
191                           os.getenv('DC_SERVER_IP'))
192         if dc_ipv6 is not None:
193             self.assertEquals(response.answers[1].rdata, dc_ipv6)
194
195     def test_qclass_none_query(self):
196         "create a QCLASS_NONE query"
197         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
198         questions = []
199
200         name = "%s.%s" % (os.getenv('DC_SERVER'), self.get_dns_domain())
201         q = self.make_name_question(name, dns.DNS_QTYPE_ALL, dns.DNS_QCLASS_NONE)
202         questions.append(q)
203
204         self.finish_name_packet(p, questions)
205         response = self.dns_transaction_udp(p)
206         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NOTIMP)
207
208 # Only returns an authority section entry in BIND and Win DNS
209 # FIXME: Enable one Samba implements this feature
210 #    def test_soa_hostname_query(self):
211 #        "create a SOA query for a hostname"
212 #        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
213 #        questions = []
214 #
215 #        name = "%s.%s" % (os.getenv('DC_SERVER'), self.get_dns_domain())
216 #        q = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
217 #        questions.append(q)
218 #
219 #        self.finish_name_packet(p, questions)
220 #        response = self.dns_transaction_udp(p)
221 #        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
222 #        self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
223 #        # We don't get SOA records for single hosts
224 #        self.assertEquals(response.ancount, 0)
225
226     def test_soa_domain_query(self):
227         "create a SOA query for a domain"
228         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
229         questions = []
230
231         name = self.get_dns_domain()
232         q = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
233         questions.append(q)
234
235         self.finish_name_packet(p, questions)
236         response = self.dns_transaction_udp(p)
237         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
238         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
239         self.assertEquals(response.ancount, 1)
240
241
class TestDNSUpdates(DNSTest):
    """DNS UPDATE requests against the server under test.

    In an UPDATE message the question section carries the zone, the answer
    section the prerequisites and the authority section the updates
    (RFC 2136, section 2).
    """

    def _make_update_packet(self):
        "Start an UPDATE packet whose zone section is the SOA/IN of our domain"
        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
        zone = self.make_name_question(self.get_dns_domain(),
                                       dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
        self.finish_name_packet(p, [zone])
        return p

    def _send_prereq(self, name, rr_class, ttl, length):
        """Send an UPDATE with one rdata-less TXT prerequisite for name.

        Returns the server's reply.
        """
        p = self._make_update_packet()
        r = dns.res_rec()
        r.name = name
        r.rr_type = dns.DNS_QTYPE_TXT
        r.rr_class = rr_class
        r.ttl = ttl
        r.length = length
        prereqs = [r]
        p.ancount = len(prereqs)
        p.answers = prereqs
        return self.dns_transaction_udp(p)

    def _send_txt_update(self, name, txt, rr_class, ttl):
        """Send an UPDATE whose update section holds one TXT record.

        With rr_class IN this adds the record; with rr_class NONE and ttl 0
        it asks for the matching record to be deleted.  Returns the reply.
        """
        p = self._make_update_packet()
        r = dns.res_rec()
        r.name = name
        r.rr_type = dns.DNS_QTYPE_TXT
        r.rr_class = rr_class
        r.ttl = ttl
        r.length = 0xffff
        r.rdata = dns.txt_record()
        r.rdata.txt = txt
        updates = [r]
        p.nscount = len(updates)
        p.nsrecs = updates
        return self.dns_transaction_udp(p)

    def _query_txt(self, name):
        "Query the TXT/IN records of name over UDP and return the reply"
        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
        q = self.make_name_question(name, dns.DNS_QTYPE_TXT, dns.DNS_QCLASS_IN)
        self.finish_name_packet(p, [q])
        return self.dns_transaction_udp(p)

    def test_two_updates(self):
        "create two update requests"
        # The zone section of an UPDATE must hold exactly one entry
        # (RFC 2136, section 3.1.1); this sends two.
        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
        updates = []

        name = "%s.%s" % (os.getenv('DC_SERVER'), self.get_dns_domain())
        updates.append(self.make_name_question(name, dns.DNS_QTYPE_A,
                                               dns.DNS_QCLASS_IN))

        name = self.get_dns_domain()
        updates.append(self.make_name_question(name, dns.DNS_QTYPE_A,
                                               dns.DNS_QCLASS_IN))

        self.finish_name_packet(p, updates)
        response = self.dns_transaction_udp(p)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_FORMERR)

    def test_update_wrong_qclass(self):
        "create update with DNS_QCLASS_NONE"
        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
        u = self.make_name_question(self.get_dns_domain(), dns.DNS_QTYPE_A,
                                    dns.DNS_QCLASS_NONE)
        self.finish_name_packet(p, [u])
        response = self.dns_transaction_udp(p)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NOTIMP)

    def test_update_prereq_with_non_null_ttl(self):
        "test update with a non-null TTL"
        name = "%s.%s" % (os.getenv('DC_SERVER'), self.get_dns_domain())
        response = self._send_prereq(name, dns.DNS_QCLASS_NONE, 1, 0)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_FORMERR)

# I'd love to test this one, but it segfaults. :)
#    def test_update_prereq_with_non_null_length(self):
#        "test update with a non-null length"
#        name = "%s.%s" % (os.getenv('DC_SERVER'), self.get_dns_domain())
#        response = self._send_prereq(name, dns.DNS_QCLASS_ANY, 0, 1)
#        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_FORMERR)

    def test_update_prereq_nonexisting_name(self):
        "test update with a nonexisting name"
        name = "idontexist.%s" % self.get_dns_domain()
        response = self._send_prereq(name, dns.DNS_QCLASS_ANY, 0, 0)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXRRSET)

    def test_update_add_txt_record(self):
        "test adding records works"
        name = "textrec.%s" % self.get_dns_domain()
        txt = '"This is a test"'

        response = self._send_txt_update(name, txt, dns.DNS_QCLASS_IN, 900)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)

        response = self._query_txt(name)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
        self.assertEqual(response.ancount, 1)
        self.assertEqual(response.answers[0].rdata.txt, txt)

    def test_update_add_two_txt_records(self):
        "test adding two txt records works"
        name = "textrec2.%s" % self.get_dns_domain()
        txt = '"This is a test" "and this is a test, too"'

        response = self._send_txt_update(name, txt, dns.DNS_QCLASS_IN, 900)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)

        response = self._query_txt(name)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
        self.assertEqual(response.ancount, 1)
        self.assertEqual(response.answers[0].rdata.txt, txt)

    def test_delete_record(self):
        "Test if deleting records works"
        name = "textrec.%s" % self.get_dns_domain()
        txt = '"This is a test"'

        # Create the record here instead of relying on
        # test_update_add_txt_record: unittest runs test methods in
        # alphabetical order, so that test actually runs *after* this one.
        response = self._send_txt_update(name, txt, dns.DNS_QCLASS_IN, 900)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
        response = self._query_txt(name)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
        self.assertEqual(response.ancount, 1)
        self.assertEqual(response.answers[0].rdata.txt, txt)

        # Class NONE with TTL 0 deletes the matching RR (RFC 2136, 2.5.4).
        response = self._send_txt_update(name, txt, dns.DNS_QCLASS_NONE, 0)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)

        response = self._query_txt(name)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXDOMAIN)
470
471
if __name__ == "__main__":
    # Allow running this module directly as a stand-alone test program.
    from unittest import main
    main()