PEP8: fix E302: expected 2 blank lines, found 1
[metze/samba/wip.git] / python / samba / tests / dns_forwarder.py
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Kai Blin  <kai@samba.org> 2011
3 #
4 # This program is free software; you can redistribute it and/or modify
5 # it under the terms of the GNU General Public License as published by
6 # the Free Software Foundation; either version 3 of the License, or
7 # (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 #
17
18 from __future__ import print_function
19 import os
20 import sys
21 import struct
22 import random
23 import socket
24 import samba
25 import time
26 import errno
27 import samba.ndr as ndr
28 from samba import credentials, param
29 from samba.tests import TestCase
30 from samba.dcerpc import dns, dnsp, dnsserver
31 from samba.netcmd.dns import TXTRecord, dns_record_match, data_to_dns_record
32 from samba.tests.subunitrun import SubunitOptions, TestProgram
33 import samba.getopt as options
34 import optparse
35 import subprocess
36
# Command-line handling; this all runs at import time.  Positional
# arguments are <server name> <server ip> followed by one or more DNS
# forwarder addresses, plus the standard Samba, credential and subunit
# option groups registered below.
parser = optparse.OptionParser("dns_forwarder.py <server name> <server ip> (dns forwarder)+ [options]")
sambaopts = options.SambaOptions(parser)
parser.add_option_group(sambaopts)

# This timeout only has relevance when testing against Windows
# Format errors tend to return patchy responses, so a timeout is needed.
# NOTE(review): no default is given, so opts.timeout is None when the
# option is omitted; socket.settimeout(None) then makes every recv() in
# the tests block indefinitely instead of failing with "too slow".
parser.add_option("--timeout", type="int", dest="timeout",
                  help="Specify timeout for DNS requests")

# use command line creds if available
credopts = options.CredentialsOptions(parser)
parser.add_option_group(credopts)
subunitopts = SubunitOptions(parser)
parser.add_option_group(subunitopts)

opts, args = parser.parse_args()

lp = sambaopts.get_loadparm()
creds = credopts.get_credentials(lp)

# Socket timeout (seconds) used for every exchange with the server under
# test; also bound as the default of DNSTest.dns_transaction_udp.
timeout = opts.timeout

# Need at least the server name, the server IP and one forwarder address.
if len(args) < 3:
    parser.print_usage()
    sys.exit(1)

server_name = args[0]
server_ip = args[1]
# Addresses the toy forwarders are started on; presumably these match the
# forwarders configured on the server under test -- TODO confirm.
dns_servers = args[2:]

creds.set_krb_forwardable(credentials.NO_KRB_FORWARDABLE)
68
69
def make_txt_record(records):
    """Wrap a list of strings in a dns.txt_record.

    The strings are stored in a dnsp.string_list whose ``count`` is set to
    the number of entries, and that list becomes the record's ``txt``.
    """
    strings = dnsp.string_list()
    strings.str = records
    strings.count = len(records)
    txt = dns.txt_record()
    txt.txt = strings
    return txt
77
78
class DNSTest(TestCase):
    """Base test case with helpers for building and sending DNS packets.

    Subclasses must set ``self.creds`` before using get_dns_domain() or the
    helpers that depend on it (e.g. make_cname_update()).
    """

    # Numeric RCODE value -> DNS_RCODE_* constant name, used only to make
    # assertion messages readable.
    errcodes = {v: k for k, v in vars(dns).items() if k.startswith('DNS_RCODE_')}

    def assert_dns_rcode_equals(self, packet, rcode):
        """Assert that the RCODE (low 4 bits of ``operation``) equals ``rcode``."""
        p_errcode = packet.operation & 0x000F
        # .get(): an rcode without a DNS_RCODE_* name must still produce a
        # normal assertion failure, not a KeyError while formatting it.
        self.assertEqual(p_errcode, rcode, "Expected RCODE %s, got %s" %
                         (self.errcodes.get(rcode, rcode),
                          self.errcodes.get(p_errcode, p_errcode)))

    def assert_dns_opcode_equals(self, packet, opcode):
        """Assert that the opcode bits (mask 0x7800) of ``operation`` equal ``opcode``."""
        p_opcode = packet.operation & 0x7800
        self.assertEqual(p_opcode, opcode, "Expected OPCODE %s, got %s" %
                         (opcode, p_opcode))

    def make_name_packet(self, opcode, qid=None):
        """Create a dns.name_packet with the given opcode and no questions.

        :param opcode: value stored in the packet's ``operation`` field.
        :param qid: transaction id; a random 16-bit id is used when None.
        :return: the new dns.name_packet.
        """
        p = dns.name_packet()
        if qid is None:
            qid = random.randint(0x0, 0xffff)
        # Previously an explicit qid was silently ignored and id left unset.
        p.id = qid
        p.operation = opcode
        p.questions = []
        return p

    def finish_name_packet(self, packet, questions):
        """Attach ``questions`` to ``packet`` and set ``qdcount`` to match."""
        packet.qdcount = len(questions)
        packet.questions = questions

    def make_name_question(self, name, qtype, qclass):
        """Create a dns.name_question for ``name`` with the given type and class."""
        q = dns.name_question()
        q.name = name
        q.question_type = qtype
        q.question_class = qclass
        return q

    def get_dns_domain(self):
        """Return the lower-cased Kerberos realm of ``self.creds``."""
        return self.creds.get_realm().lower()

    def dns_transaction_udp(self, packet, host=server_ip,
                            dump=False, timeout=timeout):
        """Send ``packet`` over UDP to ``host`` port 53 and return the reply.

        The defaults for ``host`` and ``timeout`` are the module-level
        ``server_ip`` and ``timeout`` as bound when the class is defined.

        :param dump: when true, print hexdumps of the request and the reply.
        :return: the reply unpacked as a dns.name_packet.
        """
        s = None
        try:
            send_packet = ndr.ndr_pack(packet)
            if dump:
                print(self.hexdump(send_packet))
            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
            s.settimeout(timeout)
            s.connect((host, 53))
            s.send(send_packet, 0)
            recv_packet = s.recv(2048, 0)
            if dump:
                print(self.hexdump(recv_packet))
            return ndr.ndr_unpack(dns.name_packet, recv_packet)
        finally:
            if s is not None:
                s.close()

    def make_cname_update(self, key, value):
        """Add a CNAME record ``key`` -> ``value`` with a DNS UPDATE.

        The update's zone question is the SOA of get_dns_domain(); the
        server's reply must carry RCODE OK.
        """
        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)

        name = self.get_dns_domain()
        u = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
        self.finish_name_packet(p, [u])

        r = dns.res_rec()
        r.name = key
        r.rr_type = dns.DNS_QTYPE_CNAME
        r.rr_class = dns.DNS_QCLASS_IN
        r.ttl = 900
        # NOTE(review): presumably a placeholder the NDR packer overwrites
        # with the real rdata length -- TODO confirm.
        r.length = 0xffff
        r.rdata = value
        p.nscount = 1
        p.nsrecs = [r]
        response = self.dns_transaction_udp(p)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
160
161
162
def contact_real_server(host, port):
    """Return an IPv4 UDP socket connected to ``host``:``port``.

    The socket is returned open; the caller is responsible for closing it.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
    address = (host, port)
    sock.connect(address)
    return sock
167
168
class TestDnsForwarding(DNSTest):
    """Tests for forwarding of DNS queries to the configured forwarders.

    Toy forwarders are started from
    python/samba/tests/dns_forwarder_helpers/server.py on the addresses in
    ``dns_servers``; queries are sent to the server under test at
    ``server_ip`` port 53.
    """

    def __init__(self, *args, **kwargs):
        super(TestDnsForwarding, self).__init__(*args, **kwargs)
        # Toy-server processes and their control sockets started by this
        # test; both are cleaned up in tearDown().
        self.subprocesses = []
        self.control_sockets = []

    def setUp(self):
        super(TestDnsForwarding, self).setUp()
        self.server = server_name
        self.server_ip = server_ip
        self.lp = lp
        self.creds = creds

    def start_toy_server(self, host, port, id):
        """Start a toy forwarder and return a UDP socket connected to it.

        The helper server.py is launched with ``host``, ``port`` and ``id``
        as arguments. The returned socket is used to send it control
        messages such as ``b'timeout N'``. Fails the test if the process has
        already exited or never becomes reachable.
        """
        python = sys.executable
        p = subprocess.Popen([python,
                              os.path.join(samba.source_tree_topdir(),
                                           'python/samba/tests/'
                                           'dns_forwarder_helpers/server.py'),
                              host, str(port), id])
        self.subprocesses.append(p)
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
        self.control_sockets.append(s)
        for _ in range(300):
            time.sleep(0.05)
            s.connect((host, port))
            try:
                # bytes: socket.send() rejects str on Python 3.
                s.send(b'timeout 0', 0)
            except socket.error as e:
                if e.errno in (errno.ECONNREFUSED, errno.EHOSTUNREACH):
                    continue
                raise

            # poll() refreshes returncode; the bare attribute stays None
            # until the child has been waited for.
            if p.poll() is not None:
                self.fail("Toy server has managed to die already!")

            return s
        self.fail("Toy server %s never became reachable on %s:%s" %
                  (id, host, port))

    def tearDown(self):
        super(TestDnsForwarding, self).tearDown()
        for s in self.control_sockets:
            s.close()
        for p in self.subprocesses:
            # Only signal a child that is still running, then reap it so it
            # does not linger as a zombie.
            if p.poll() is None:
                p.kill()
            p.wait()

    def _require_two_forwarders(self):
        """Skip the current test unless at least two forwarders were given."""
        if len(dns_servers) < 2:
            self.skipTest("needs at least two DNS forwarders")

    def _make_query(self, name, qtype, recursive=True):
        """Build a query packet with one IN-class question for ``name``.

        :param qtype: DNS_QTYPE_* constant for the question.
        :param recursive: set the recursion-desired flag when true.
        """
        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
        q = self.make_name_question(name, qtype, dns.DNS_QCLASS_IN)
        self.finish_name_packet(p, [q])
        if recursive:
            p.operation |= dns.DNS_FLAG_RECURSION_DESIRED
        return p

    def _ask_server(self, name, qtype, recursive=True):
        """Query the DNS server under test and return the unpacked reply.

        The query is sent over UDP to ``server_ip`` port 53. The test fails
        if no reply arrives within the module-level ``timeout``.
        """
        send_packet = ndr.ndr_pack(self._make_query(name, qtype, recursive))
        ad = contact_real_server(server_ip, 53)
        try:
            ad.send(send_packet, 0)
            ad.settimeout(timeout)
            try:
                data = ad.recv(0xffff + 2, 0)
            except socket.timeout:
                self.fail("DNS server is too slow (timeout %s)" % timeout)
        finally:
            ad.close()
        return ndr.ndr_unpack(dns.name_packet, data)

    def test_comatose_forwarder(self):
        s = self.start_toy_server(dns_servers[0], 53, 'forwarder1')
        s.send(b"timeout 1000000", 0)

        # Query the toy forwarder directly; it has been told to stay silent.
        p = self._make_query("an-address-that-will-not-resolve",
                             dns.DNS_QTYPE_TXT, recursive=False)
        s.send(ndr.ndr_pack(p), 0)
        s.settimeout(1)
        try:
            s.recv(0xffff + 2, 0)
        except socket.timeout:
            # Expected forwarder to be dead
            return
        self.fail("DNS forwarder should have been inactive")

    def test_no_active_forwarder(self):
        data = self._ask_server("dsfsfds.dsfsdfs", dns.DNS_QTYPE_TXT)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_SERVFAIL)
        self.assertEqual(data.ancount, 0)

    def test_no_flag_recursive_forwarder(self):
        # Leave off the recursive flag
        data = self._ask_server("dsfsfds.dsfsdfs", dns.DNS_QTYPE_TXT,
                                recursive=False)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_NXDOMAIN)
        self.assertEqual(data.ancount, 0)

    def test_single_forwarder(self):
        self.start_toy_server(dns_servers[0], 53, 'forwarder1')
        data = self._ask_server("dsfsfds.dsfsdfs", dns.DNS_QTYPE_CNAME)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_OK)
        self.assertEqual('forwarder1', data.answers[0].rdata)

    def test_single_forwarder_not_actually_there(self):
        data = self._ask_server("dsfsfds.dsfsdfs", dns.DNS_QTYPE_CNAME)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_SERVFAIL)

    def test_single_forwarder_waiting_forever(self):
        s = self.start_toy_server(dns_servers[0], 53, 'forwarder1')
        s.send(b'timeout 10000', 0)
        data = self._ask_server("dsfsfds.dsfsdfs", dns.DNS_QTYPE_CNAME)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_SERVFAIL)

    def test_double_forwarder_first_frozen(self):
        self._require_two_forwarders()
        s1 = self.start_toy_server(dns_servers[0], 53, 'forwarder1')
        self.start_toy_server(dns_servers[1], 53, 'forwarder2')
        s1.send(b'timeout 1000', 0)
        data = self._ask_server("dsfsfds.dsfsdfs", dns.DNS_QTYPE_CNAME)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_OK)
        self.assertEqual('forwarder2', data.answers[0].rdata)

    def test_double_forwarder_first_down(self):
        self._require_two_forwarders()
        self.start_toy_server(dns_servers[1], 53, 'forwarder2')
        data = self._ask_server("dsfsfds.dsfsdfs", dns.DNS_QTYPE_CNAME)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_OK)
        self.assertEqual('forwarder2', data.answers[0].rdata)

    def test_double_forwarder_both_slow(self):
        self._require_two_forwarders()
        s1 = self.start_toy_server(dns_servers[0], 53, 'forwarder1')
        s2 = self.start_toy_server(dns_servers[1], 53, 'forwarder2')
        s1.send(b'timeout 1.5', 0)
        s2.send(b'timeout 1.5', 0)
        data = self._ask_server("dsfsfds.dsfsdfs", dns.DNS_QTYPE_CNAME)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_OK)
        self.assertEqual('forwarder1', data.answers[0].rdata)

    def test_cname(self):
        self.start_toy_server(dns_servers[0], 53, 'forwarder1')
        data = self._ask_server("resolve.cname", dns.DNS_QTYPE_CNAME)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_OK)
        self.assertEqual(len(data.answers), 1)
        self.assertEqual('forwarder1', data.answers[0].rdata)

    def test_double_cname(self):
        self.start_toy_server(dns_servers[0], 53, 'forwarder1')

        name = 'resolve.cname.%s' % self.get_dns_domain()
        self.make_cname_update(name, "dsfsfds.dsfsdfs")

        data = self._ask_server(name, dns.DNS_QTYPE_A)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_OK)
        self.assertEqual('forwarder1', data.answers[1].rdata)

    def test_cname_forwarding_with_slow_server(self):
        self._require_two_forwarders()
        s1 = self.start_toy_server(dns_servers[0], 53, 'forwarder1')
        self.start_toy_server(dns_servers[1], 53, 'forwarder2')
        s1.send(b'timeout 10000', 0)

        name = 'resolve.cname.%s' % self.get_dns_domain()
        self.make_cname_update(name, "dsfsfds.dsfsdfs")

        data = self._ask_server(name, dns.DNS_QTYPE_A)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_OK)
        self.assertEqual('forwarder2', data.answers[-1].rdata)

    def test_cname_forwarding_with_server_down(self):
        self._require_two_forwarders()
        self.start_toy_server(dns_servers[1], 53, 'forwarder2')

        name1 = 'resolve1.cname.%s' % self.get_dns_domain()
        name2 = 'resolve2.cname.%s' % self.get_dns_domain()
        self.make_cname_update(name1, name2)
        self.make_cname_update(name2, "dsfsfds.dsfsdfs")

        data = self._ask_server(name1, dns.DNS_QTYPE_A)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_OK)
        self.assertEqual('forwarder2', data.answers[-1].rdata)

    def test_cname_forwarding_with_lots_of_cnames(self):
        name3 = 'resolve3.cname.%s' % self.get_dns_domain()
        self.start_toy_server(dns_servers[0], 53, name3)

        name1 = 'resolve1.cname.%s' % self.get_dns_domain()
        name2 = 'resolve2.cname.%s' % self.get_dns_domain()
        self.make_cname_update(name1, name2)
        self.make_cname_update(name3, name1)
        self.make_cname_update(name2, "dsfsfds.dsfsdfs")

        data = self._ask_server(name1, dns.DNS_QTYPE_A)
        # This should cause a loop in Windows
        # (which is restricted by a 20 CNAME limit)
        #
        # The reason it doesn't here is because forwarded CNAME have no
        # additional processing in the internal DNS server.
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_OK)
        self.assertEqual(name3, data.answers[-1].rdata)
607
# Run the tests only when executed as a script, not when imported.
if __name__ == "__main__":
    TestProgram(module=__name__, opts=subunitopts)