dns: Delete dnsNode objects when they are empty
[obnox/samba/samba-obnox.git] / python / samba / tests / dns.py
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Kai Blin  <kai@samba.org> 2011
3 #
4 # This program is free software; you can redistribute it and/or modify
5 # it under the terms of the GNU General Public License as published by
6 # the Free Software Foundation; either version 3 of the License, or
7 # (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 #
17
18 import os
19 import struct
20 import random
21 from samba import socket
22 import samba.ndr as ndr
23 import samba.dcerpc.dns as dns
24 from samba.tests import TestCase
25
# Translation table used by DNSTest.hexdump(): characters whose repr() is a
# single unescaped character (quote + char + quote) map to themselves, every
# other character is shown as '.'.
FILTER = ''.join([chr(x) if len(repr(chr(x))) == 3 else '.' for x in range(256)])
27
28
class DNSTest(TestCase):
    """Base class for the DNS server tests.

    Provides helpers to build dns.name_packet requests, send them to the
    server under test over UDP or TCP, and check the reply header.
    """

    def errstr(self, errcode):
        """Return a readable name for the DNS RCODE ``errcode``.

        Codes without an entry in the table are rendered as ``UNKNOWN(<n>)``
        so that assertion messages never fail with an IndexError.
        """
        string_codes = [
            "OK",
            "FORMERR",
            "SERVFAIL",
            "NXDOMAIN",
            "NOTIMP",
            "REFUSED",
            "YXDOMAIN",
            "YXRRSET",
            "NXRRSET",
            "NOTAUTH",
            "NOTZONE",
        ]
        if 0 <= errcode < len(string_codes):
            return string_codes[errcode]
        return "UNKNOWN(%d)" % errcode

    def assert_dns_rcode_equals(self, packet, rcode):
        "Helper function to check return code"
        # The RCODE is held in the low four bits of the operation field.
        p_errcode = packet.operation & 0x000F
        self.assertEqual(p_errcode, rcode, "Expected RCODE %s, got %s" %
                         (self.errstr(rcode), self.errstr(p_errcode)))

    def assert_dns_opcode_equals(self, packet, opcode):
        """Helper function to check opcode.

        The masked bits are compared without shifting, so ``opcode`` must use
        the same encoding as the dns.DNS_OPCODE_* values the tests pass in.
        """
        p_opcode = packet.operation & 0x7800
        self.assertEqual(p_opcode, opcode, "Expected OPCODE %s, got %s" %
                         (opcode, p_opcode))

    def make_name_packet(self, opcode, qid=None):
        """Create an empty dns.name_packet.

        :param opcode: value stored in the packet's operation field
        :param qid: query ID to use; a random 16-bit ID is picked when None
        """
        p = dns.name_packet()
        if qid is None:
            qid = random.randint(0x0, 0xffff)
        p.id = qid
        p.operation = opcode
        p.questions = []
        return p

    def finish_name_packet(self, packet, questions):
        "Helper to finalize a dns.name_packet"
        packet.qdcount = len(questions)
        packet.questions = questions

    def make_name_question(self, name, qtype, qclass):
        "Helper creating a dns.name_question"
        q = dns.name_question()
        q.name = name
        q.question_type = qtype
        q.question_class = qclass
        return q

    def get_dns_domain(self):
        "Helper to get dns domain"
        return os.getenv('REALM', 'example.com').lower()

    def dns_transaction_udp(self, packet, host=None, dump=False):
        """Send ``packet`` over UDP and return the parsed reply.

        :param host: server address; defaults to $SERVER_IP, read at call time
        :param dump: print hex dumps of the request and the reply
        """
        if host is None:
            host = os.getenv('SERVER_IP')
        s = None
        try:
            send_packet = ndr.ndr_pack(packet)
            if dump:
                print(self.hexdump(send_packet))
            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
            s.connect((host, 53))
            s.send(send_packet, 0)
            recv_packet = s.recv(2048, 0)
            if dump:
                print(self.hexdump(recv_packet))
            return ndr.ndr_unpack(dns.name_packet, recv_packet)
        finally:
            if s is not None:
                s.close()

    @staticmethod
    def _recv_exactly(sock, count):
        """Read exactly ``count`` bytes from the stream socket ``sock``.

        Raises IOError if the peer closes the connection before ``count``
        bytes have arrived.
        """
        chunks = []
        remaining = count
        while remaining > 0:
            chunk = sock.recv(remaining, 0)
            if not chunk:
                raise IOError("connection closed after %d of %d bytes" %
                              (count - remaining, count))
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    def dns_transaction_tcp(self, packet, host=None, dump=False):
        """Send ``packet`` over TCP and return the parsed reply.

        :param host: server address; defaults to $SERVER_IP, read at call time
        :param dump: print hex dumps of the request and the reply (the reply
                     dump includes the 2-byte length prefix)
        """
        if host is None:
            host = os.getenv('SERVER_IP')
        s = None
        try:
            send_packet = ndr.ndr_pack(packet)
            if dump:
                print(self.hexdump(send_packet))
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
            s.connect((host, 53))
            # Messages on the TCP transport carry a 16-bit big-endian length
            # prefix.  sendall() guards against short writes.
            s.sendall(struct.pack('!H', len(send_packet)) + send_packet)
            # A stream socket may deliver the reply in pieces: read the
            # length prefix first, then exactly that many payload bytes.
            length_prefix = self._recv_exactly(s, 2)
            (reply_len,) = struct.unpack('!H', length_prefix)
            recv_packet = self._recv_exactly(s, reply_len)
            if dump:
                print(self.hexdump(length_prefix + recv_packet))
            return ndr.ndr_unpack(dns.name_packet, recv_packet)
        finally:
            if s is not None:
                s.close()

    def hexdump(self, src, length=8):
        """Return a hex dump of ``src`` with ``length`` bytes per line.

        Each line shows the offset, the bytes in hex, and the bytes as text
        with characters outside FILTER's kept set replaced by '.'.
        """
        lines = []
        offset = 0
        while src:
            chunk, src = src[:length], src[length:]
            hexa = ' '.join(["%02X" % ord(x) for x in chunk])
            text = chunk.translate(FILTER)
            lines.append("%04X   %-*s   %s\n" % (offset, length * 3, hexa, text))
            offset += length
        return ''.join(lines)
135
136 class TestSimpleQueries(DNSTest):
137
138     def test_one_a_query(self):
139         "create a query packet containing one query record"
140         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
141         questions = []
142
143         name = "%s.%s" % (os.getenv('SERVER'), self.get_dns_domain())
144         q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
145         print "asking for ", q.name
146         questions.append(q)
147
148         self.finish_name_packet(p, questions)
149         response = self.dns_transaction_udp(p)
150         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
151         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
152         self.assertEquals(response.ancount, 1)
153         self.assertEquals(response.answers[0].rdata,
154                           os.getenv('SERVER_IP'))
155
156     def test_one_a_query_tcp(self):
157         "create a query packet containing one query record via TCP"
158         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
159         questions = []
160
161         name = "%s.%s" % (os.getenv('SERVER'), self.get_dns_domain())
162         q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
163         print "asking for ", q.name
164         questions.append(q)
165
166         self.finish_name_packet(p, questions)
167         response = self.dns_transaction_tcp(p)
168         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
169         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
170         self.assertEquals(response.ancount, 1)
171         self.assertEquals(response.answers[0].rdata,
172                           os.getenv('SERVER_IP'))
173
174     def test_two_queries(self):
175         "create a query packet containing two query records"
176         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
177         questions = []
178
179         name = "%s.%s" % (os.getenv('SERVER'), self.get_dns_domain())
180         q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
181         questions.append(q)
182
183         name = "%s.%s" % ('bogusname', self.get_dns_domain())
184         q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
185         questions.append(q)
186
187         self.finish_name_packet(p, questions)
188         response = self.dns_transaction_udp(p)
189         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_FORMERR)
190
191     def test_qtype_all_query(self):
192         "create a QTYPE_ALL query"
193         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
194         questions = []
195
196         name = "%s.%s" % (os.getenv('SERVER'), self.get_dns_domain())
197         q = self.make_name_question(name, dns.DNS_QTYPE_ALL, dns.DNS_QCLASS_IN)
198         print "asking for ", q.name
199         questions.append(q)
200
201         self.finish_name_packet(p, questions)
202         response = self.dns_transaction_udp(p)
203
204         num_answers = 1
205         dc_ipv6 = os.getenv('SERVER_IPV6')
206         if dc_ipv6 is not None:
207             num_answers += 1
208
209         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
210         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
211         self.assertEquals(response.ancount, num_answers)
212         self.assertEquals(response.answers[0].rdata,
213                           os.getenv('SERVER_IP'))
214         if dc_ipv6 is not None:
215             self.assertEquals(response.answers[1].rdata, dc_ipv6)
216
217     def test_qclass_none_query(self):
218         "create a QCLASS_NONE query"
219         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
220         questions = []
221
222         name = "%s.%s" % (os.getenv('SERVER'), self.get_dns_domain())
223         q = self.make_name_question(name, dns.DNS_QTYPE_ALL, dns.DNS_QCLASS_NONE)
224         questions.append(q)
225
226         self.finish_name_packet(p, questions)
227         response = self.dns_transaction_udp(p)
228         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NOTIMP)
229
230 # Only returns an authority section entry in BIND and Win DNS
231 # FIXME: Enable one Samba implements this feature
232 #    def test_soa_hostname_query(self):
233 #        "create a SOA query for a hostname"
234 #        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
235 #        questions = []
236 #
237 #        name = "%s.%s" % (os.getenv('SERVER'), self.get_dns_domain())
238 #        q = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
239 #        questions.append(q)
240 #
241 #        self.finish_name_packet(p, questions)
242 #        response = self.dns_transaction_udp(p)
243 #        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
244 #        self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
245 #        # We don't get SOA records for single hosts
246 #        self.assertEquals(response.ancount, 0)
247
248     def test_soa_domain_query(self):
249         "create a SOA query for a domain"
250         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
251         questions = []
252
253         name = self.get_dns_domain()
254         q = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
255         questions.append(q)
256
257         self.finish_name_packet(p, questions)
258         response = self.dns_transaction_udp(p)
259         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
260         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
261         self.assertEquals(response.ancount, 1)
262
263
class TestDNSUpdates(DNSTest):
    """UPDATE-opcode tests: malformed requests, prerequisites, and adding,
    deleting and re-adding records in the DNS domain's zone."""

    def _make_update_packet(self):
        """Return an UPDATE packet whose zone section is a single SOA/IN
        question for the DNS domain, as every valid update here uses."""
        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
        zone = self.make_name_question(self.get_dns_domain(),
                                       dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
        self.finish_name_packet(p, [zone])
        return p

    def _send_update(self, records):
        "Send ``records`` as the update section of a zone UPDATE; return the reply"
        p = self._make_update_packet()
        p.nscount = len(records)
        p.nsrecs = records
        return self.dns_transaction_udp(p)

    def _send_prereqs(self, prereqs):
        "Send ``prereqs`` as the prerequisite section of a zone UPDATE; return the reply"
        p = self._make_update_packet()
        p.ancount = len(prereqs)
        p.answers = prereqs
        return self.dns_transaction_udp(p)

    def _make_prereq(self, name, rr_class, ttl):
        "Build a TXT prerequisite record without rdata (length 0)"
        r = dns.res_rec()
        r.name = name
        r.rr_type = dns.DNS_QTYPE_TXT
        r.rr_class = rr_class
        r.ttl = ttl
        r.length = 0
        return r

    def _make_txt_rr(self, name, txt, rr_class, ttl):
        "Build a TXT res_rec carrying ``txt`` for the update section"
        r = dns.res_rec()
        r.name = name
        r.rr_type = dns.DNS_QTYPE_TXT
        r.rr_class = rr_class
        r.ttl = ttl
        # 0xffff is what this file uses for every record carrying rdata;
        # presumably the real length is filled in when packing - TODO confirm.
        r.length = 0xffff
        rdata = dns.txt_record()
        rdata.txt = txt
        r.rdata = rdata
        return r

    def _add_txt(self, name, txt):
        "Add a TXT record (class IN, TTL 900) and require NOERROR"
        response = self._send_update(
            [self._make_txt_rr(name, txt, dns.DNS_QCLASS_IN, 900)])
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)

    def _delete_txt(self, name, txt):
        "Delete a TXT record (class NONE, TTL 0) and require NOERROR"
        response = self._send_update(
            [self._make_txt_rr(name, txt, dns.DNS_QCLASS_NONE, 0)])
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)

    def _query(self, name, qtype):
        "Send a single class-IN query for ``name``/``qtype`` over UDP; return the reply"
        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
        q = self.make_name_question(name, qtype, dns.DNS_QCLASS_IN)
        self.finish_name_packet(p, [q])
        return self.dns_transaction_udp(p)

    def _assert_txt_present(self, name, txt):
        "Require that ``name`` resolves to exactly one TXT record holding ``txt``"
        response = self._query(name, dns.DNS_QTYPE_TXT)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
        self.assertEqual(response.ancount, 1)
        self.assertEqual(response.answers[0].rdata.txt, txt)

    def _assert_name_absent(self, name):
        "Require that a TXT query for ``name`` yields NXDOMAIN"
        response = self._query(name, dns.DNS_QTYPE_TXT)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXDOMAIN)

    def test_two_updates(self):
        "create two update requests"
        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
        server = "%s.%s" % (os.getenv('SERVER'), self.get_dns_domain())
        zones = [
            self.make_name_question(server, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN),
            self.make_name_question(self.get_dns_domain(),
                                    dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN),
        ]
        self.finish_name_packet(p, zones)
        response = self.dns_transaction_udp(p)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_FORMERR)

    def test_update_wrong_qclass(self):
        "create update with DNS_QCLASS_NONE"
        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
        zone = self.make_name_question(self.get_dns_domain(),
                                       dns.DNS_QTYPE_A, dns.DNS_QCLASS_NONE)
        self.finish_name_packet(p, [zone])
        response = self.dns_transaction_udp(p)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NOTIMP)

    def test_update_prereq_with_non_null_ttl(self):
        "test update with a non-null TTL"
        name = "%s.%s" % (os.getenv('SERVER'), self.get_dns_domain())
        response = self._send_prereqs(
            [self._make_prereq(name, dns.DNS_QCLASS_NONE, 1)])
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_FORMERR)

    # I'd love to test this one, but it segfaults. :)
    #    def test_update_prereq_with_non_null_length(self):
    #        "test update with a non-null length"
    #        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
    #        updates = []
    #
    #        name = self.get_dns_domain()
    #
    #        u = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
    #        updates.append(u)
    #        self.finish_name_packet(p, updates)
    #
    #        prereqs = []
    #        r = dns.res_rec()
    #        r.name = "%s.%s" % (os.getenv('SERVER'), self.get_dns_domain())
    #        r.rr_type = dns.DNS_QTYPE_TXT
    #        r.rr_class = dns.DNS_QCLASS_ANY
    #        r.ttl = 0
    #        r.length = 1
    #        prereqs.append(r)
    #
    #        p.ancount = len(prereqs)
    #        p.answers = prereqs
    #
    #        response = self.dns_transaction_udp(p)
    #        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_FORMERR)

    def test_update_prereq_nonexisting_name(self):
        "test update with a nonexisting name"
        name = "idontexist.%s" % self.get_dns_domain()
        response = self._send_prereqs(
            [self._make_prereq(name, dns.DNS_QCLASS_ANY, 0)])
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXRRSET)

    def test_update_add_txt_record(self):
        "test adding records works"
        name = "textrec.%s" % self.get_dns_domain()
        txt = '"This is a test"'
        self._add_txt(name, txt)
        self._assert_txt_present(name, txt)

    def test_update_add_two_txt_records(self):
        "test adding two txt records works"
        name = "textrec2.%s" % self.get_dns_domain()
        txt = '"This is a test" "and this is a test, too"'
        self._add_txt(name, txt)
        self._assert_txt_present(name, txt)

    def test_delete_record(self):
        "Test if deleting records works"
        name = "deleterec.%s" % self.get_dns_domain()
        txt = '"This is a test"'

        # First, create a record to make sure we have a record to delete.
        self._add_txt(name, txt)
        # Now check the record is really there (not just a NOERROR reply).
        self._assert_txt_present(name, txt)
        # Now delete the record
        self._delete_txt(name, txt)
        # And finally check it's gone, including the now-empty name
        self._assert_name_absent(name)

    def test_readd_record(self):
        "Test if adding, deleting and then readding a records works"
        name = "readdrec.%s" % self.get_dns_domain()
        txt = '"This is a test"'

        # Create the record and check it is there
        self._add_txt(name, txt)
        self._assert_txt_present(name, txt)
        # Delete it and check it's gone
        self._delete_txt(name, txt)
        self._assert_name_absent(name)
        # Recreate it and check it is back
        self._add_txt(name, txt)
        self._assert_txt_present(name, txt)

    def test_update_add_mx_record(self):
        "test adding MX records works"
        domain = self.get_dns_domain()
        exchange = 'mail.%s' % domain

        r = dns.res_rec()
        r.name = domain
        r.rr_type = dns.DNS_QTYPE_MX
        r.rr_class = dns.DNS_QCLASS_IN
        r.ttl = 900
        r.length = 0xffff
        rdata = dns.mx_record()
        rdata.preference = 10
        rdata.exchange = exchange
        r.rdata = rdata

        response = self._send_update([r])
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)

        response = self._query(domain, dns.DNS_QTYPE_MX)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
        self.assertEqual(response.ancount, 1)
        ans = response.answers[0]
        self.assertEqual(ans.rr_type, dns.DNS_QTYPE_MX)
        self.assertEqual(ans.rdata.preference, 10)
        self.assertEqual(ans.rdata.exchange, exchange)
698
699
700 class TestComplexQueries(DNSTest):
701
702     def setUp(self):
703         super(TestComplexQueries, self).setUp()
704         p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
705         updates = []
706
707         name = self.get_dns_domain()
708
709         u = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
710         updates.append(u)
711         self.finish_name_packet(p, updates)
712
713         updates = []
714         r = dns.res_rec()
715         r.name = "cname_test.%s" % self.get_dns_domain()
716         r.rr_type = dns.DNS_QTYPE_CNAME
717         r.rr_class = dns.DNS_QCLASS_IN
718         r.ttl = 900
719         r.length = 0xffff
720         r.rdata = "%s.%s" % (os.getenv('SERVER'), self.get_dns_domain())
721         updates.append(r)
722         p.nscount = len(updates)
723         p.nsrecs = updates
724
725         response = self.dns_transaction_udp(p)
726         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
727
728     def tearDown(self):
729         super(TestComplexQueries, self).tearDown()
730         p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
731         updates = []
732
733         name = self.get_dns_domain()
734
735         u = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
736         updates.append(u)
737         self.finish_name_packet(p, updates)
738
739         updates = []
740         r = dns.res_rec()
741         r.name = "cname_test.%s" % self.get_dns_domain()
742         r.rr_type = dns.DNS_QTYPE_CNAME
743         r.rr_class = dns.DNS_QCLASS_NONE
744         r.ttl = 0
745         r.length = 0xffff
746         r.rdata = "%s.%s" % (os.getenv('SERVER'), self.get_dns_domain())
747         updates.append(r)
748         p.nscount = len(updates)
749         p.nsrecs = updates
750
751         response = self.dns_transaction_udp(p)
752         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
753
754     def test_one_a_query(self):
755         "create a query packet containing one query record"
756         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
757         questions = []
758
759         name = "cname_test.%s" % self.get_dns_domain()
760         q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
761         print "asking for ", q.name
762         questions.append(q)
763
764         self.finish_name_packet(p, questions)
765         response = self.dns_transaction_udp(p)
766         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
767         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
768         self.assertEquals(response.ancount, 2)
769         self.assertEquals(response.answers[0].rr_type, dns.DNS_QTYPE_CNAME)
770         self.assertEquals(response.answers[0].rdata, "%s.%s" %
771                           (os.getenv('SERVER'), self.get_dns_domain()))
772         self.assertEquals(response.answers[1].rr_type, dns.DNS_QTYPE_A)
773         self.assertEquals(response.answers[1].rdata,
774                           os.getenv('SERVER_IP'))
775
776 class TestInvalidQueries(DNSTest):
777
778     def test_one_a_query(self):
779         "send 0 bytes follows by create a query packet containing one query record"
780
781         s = None
782         try:
783             s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
784             s.connect((os.getenv('SERVER_IP'), 53))
785             s.send("", 0)
786         finally:
787             if s is not None:
788                 s.close()
789
790         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
791         questions = []
792
793         name = "%s.%s" % (os.getenv('SERVER'), self.get_dns_domain())
794         q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
795         print "asking for ", q.name
796         questions.append(q)
797
798         self.finish_name_packet(p, questions)
799         response = self.dns_transaction_udp(p)
800         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
801         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
802         self.assertEquals(response.ancount, 1)
803         self.assertEquals(response.answers[0].rdata,
804                           os.getenv('SERVER_IP'))
805
if __name__ == "__main__":
    # Allow running this module directly with the stdlib unittest runner.
    from unittest import main as run_unittests
    run_unittests()