cf9f537e461d82c1336a6cf857bd1d22e994bdb2
[metze/samba/wip.git] / python / samba / tests / dns_forwarder.py
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Kai Blin  <kai@samba.org> 2011
3 #
4 # This program is free software; you can redistribute it and/or modify
5 # it under the terms of the GNU General Public License as published by
6 # the Free Software Foundation; either version 3 of the License, or
7 # (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 #
17
18 import os
19 import sys
20 import struct
21 import random
22 import socket
23 import samba
24 import time
25 import errno
26 import samba.ndr as ndr
27 from samba import credentials, param
28 from samba.tests import TestCase
29 from samba.dcerpc import dns, dnsp, dnsserver
30 from samba.netcmd.dns import TXTRecord, dns_record_match, data_to_dns_record
31 from samba.tests.subunitrun import SubunitOptions, TestProgram
32 import samba.getopt as options
33 import optparse
34 import subprocess
35
# Command line: two mandatory positionals (server name, server IP) followed
# by at least one forwarder address, plus the usual Samba option groups.
parser = optparse.OptionParser(
    "dns_forwarder.py <server name> <server ip> (dns forwarder)+ [options]")
sambaopts = options.SambaOptions(parser)
parser.add_option_group(sambaopts)

# This timeout only has relevance when testing against Windows.
# Format errors tend to return patchy responses, so a timeout is needed.
parser.add_option("--timeout", type="int", dest="timeout",
                  help="Specify timeout for DNS requests")

# Use command line credentials if available.
credopts = options.CredentialsOptions(parser)
parser.add_option_group(credopts)
subunitopts = SubunitOptions(parser)
parser.add_option_group(subunitopts)

opts, args = parser.parse_args()

lp = sambaopts.get_loadparm()
creds = credopts.get_credentials(lp)

# None when --timeout was not given on the command line.
timeout = opts.timeout

if len(args) < 3:
    parser.print_usage()
    sys.exit(1)

server_name, server_ip = args[0], args[1]
# Every remaining positional argument is a forwarder address.
dns_servers = args[2:]

creds.set_krb_forwardable(credentials.NO_KRB_FORWARDABLE)
67
def make_txt_record(records):
    """Wrap *records* in a dns.txt_record.

    A dnsp.string_list is filled with *records* (and its length) and
    attached to a fresh dns.txt_record as its ``txt`` attribute.
    """
    strings = dnsp.string_list()
    strings.count = len(records)
    strings.str = records

    txt_record = dns.txt_record()
    txt_record.txt = strings
    return txt_record
75
76
class DNSTest(TestCase):
    """Base class providing helpers to build, send and check DNS packets."""

    # Reverse map: numeric RCODE value -> DNS_RCODE_* constant name.  Only
    # used to make assertion messages readable.
    errcodes = dict((v, k) for k, v in vars(dns).items()
                    if k.startswith('DNS_RCODE_'))

    def assert_dns_rcode_equals(self, packet, rcode):
        """Assert that the RCODE (low 4 bits of packet.operation) is *rcode*."""
        p_errcode = packet.operation & 0x000F
        # Fall back to the raw number for codes missing from errcodes so an
        # unexpected RCODE is reported as a test failure, not a KeyError.
        self.assertEqual(p_errcode, rcode, "Expected RCODE %s, got %s" %
                         (self.errcodes.get(rcode, rcode),
                          self.errcodes.get(p_errcode, p_errcode)))

    def assert_dns_opcode_equals(self, packet, opcode):
        """Assert that the OPCODE bits (mask 0x7800) of *packet* equal *opcode*."""
        p_opcode = packet.operation & 0x7800
        self.assertEqual(p_opcode, opcode, "Expected OPCODE %s, got %s" %
                         (opcode, p_opcode))

    def make_name_packet(self, opcode, qid=None):
        """Create a dns.name_packet with the given operation word.

        :param opcode: value stored in the packet's ``operation`` field
        :param qid: query id to use; a random 16-bit id is picked when None
        :return: the new packet, with an empty question list
        """
        p = dns.name_packet()
        if qid is None:
            qid = random.randint(0x0, 0xffff)
        # An explicitly supplied qid used to be silently dropped.
        p.id = qid
        p.operation = opcode
        p.questions = []
        return p

    def finish_name_packet(self, packet, questions):
        """Attach *questions* to *packet* and set qdcount to match."""
        packet.qdcount = len(questions)
        packet.questions = questions

    def make_name_question(self, name, qtype, qclass):
        """Create a dns.name_question for *name* with the given type/class."""
        q = dns.name_question()
        q.name = name
        q.question_type = qtype
        q.question_class = qclass
        return q

    def get_dns_domain(self):
        """Return the lower-cased realm of self.creds."""
        return self.creds.get_realm().lower()

    def dns_transaction_udp(self, packet, host=server_ip,
                            dump=False, timeout=timeout):
        """Send *packet* to UDP port 53 on *host* and return the parsed reply.

        :param packet: dns.name_packet to send (packed with ndr_pack)
        :param host: address to send to; defaults to the server under test
        :param dump: when true, print a hexdump of both request and reply
        :param timeout: socket timeout passed to settimeout()
        :return: the reply unpacked as a dns.name_packet
        """
        s = None
        try:
            send_packet = ndr.ndr_pack(packet)
            if dump:
                print(self.hexdump(send_packet))
            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
            s.settimeout(timeout)
            s.connect((host, 53))
            s.send(send_packet, 0)
            recv_packet = s.recv(2048, 0)
            if dump:
                print(self.hexdump(recv_packet))
            return ndr.ndr_unpack(dns.name_packet, recv_packet)
        finally:
            # Always release the socket, whether or not a reply arrived.
            if s is not None:
                s.close()

    def make_cname_update(self, key, value):
        """Send a DNS UPDATE adding a CNAME record *key* -> *value*.

        The update targets the zone returned by get_dns_domain() and the
        reply is asserted to carry DNS_RCODE_OK.
        """
        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)

        # The single "question" of the update names the zone (SOA).
        name = self.get_dns_domain()
        u = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
        self.finish_name_packet(p, [u])

        r = dns.res_rec()
        r.name = key
        r.rr_type = dns.DNS_QTYPE_CNAME
        r.rr_class = dns.DNS_QCLASS_IN
        r.ttl = 900
        r.length = 0xffff
        r.rdata = value
        p.nscount = 1
        p.nsrecs = [r]
        response = self.dns_transaction_udp(p)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
158
159
160
def contact_real_server(host, port):
    """Return a UDP socket connected to (host, port).

    The caller owns the returned socket and must close it.  If connect()
    raises, the socket is closed before the exception propagates so that
    no file descriptor is leaked.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
    try:
        s.connect((host, port))
    except Exception:
        s.close()
        raise
    return s
165
166
class TestDnsForwarding(DNSTest):
    """Tests of the DNS server's forwarding behaviour.

    Each test may start one or more "toy" forwarder processes
    (dns_forwarder_helpers/server.py) on the addresses given on the command
    line, then queries the server under test and checks what comes back.
    Judging by the assertions below, a toy server answers with the id string
    it was started with, and accepts 'timeout <n>' control datagrams that
    delay its answers -- confirm against server.py.
    """

    def __init__(self, *args, **kwargs):
        super(TestDnsForwarding, self).__init__(*args, **kwargs)
        # Toy forwarder processes started by this test; stopped in tearDown().
        self.subprocesses = []

    def setUp(self):
        super(TestDnsForwarding, self).setUp()
        self.server = server_name
        self.server_ip = server_ip
        self.lp = lp
        self.creds = creds

    def start_toy_server(self, host, port, id):
        """Start a toy forwarder and return a UDP socket connected to it.

        :param host: address the toy server is told to use
        :param port: port the toy server is told to use
        :param id: identifier string handed to the toy server
        :return: UDP socket connected to (host, port); closed automatically
            when the test finishes

        The test fails if the process exits during start-up or if it does
        not accept a datagram within 300 attempts of 0.05s each.
        """
        python = sys.executable
        p = subprocess.Popen([python,
                              os.path.join(samba.source_tree_topdir(),
                                           'python/samba/tests/'
                                           'dns_forwarder_helpers/server.py'),
                              host, str(port), id])
        self.subprocesses.append(p)
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
        self.addCleanup(s.close)
        for _ in range(300):
            time.sleep(0.05)
            # returncode is only refreshed by poll()/wait(); without the
            # poll() a dead child would never be noticed.
            if p.poll() is not None:
                self.fail("Toy server has managed to die already!")
            s.connect((host, port))
            try:
                s.send('timeout 0', 0)
            except socket.error as e:
                if e.errno in (errno.ECONNREFUSED, errno.EHOSTUNREACH):
                    # Not listening yet; try again.
                    continue
                # Anything else is unexpected and must not be hidden.
                raise
            return s

        self.fail("Toy server on %s:%s did not become reachable" %
                  (host, port))

    def tearDown(self):
        try:
            for p in self.subprocesses:
                # Only signal processes that are still running, then reap
                # them so no zombie is left behind.
                if p.poll() is None:
                    p.kill()
                p.wait()
        finally:
            super(TestDnsForwarding, self).tearDown()

    def _pack_query(self, name, qtype, recursion_desired):
        """Return the packed bytes of a one-question IN-class QUERY packet."""
        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
        q = self.make_name_question(name, qtype, dns.DNS_QCLASS_IN)
        self.finish_name_packet(p, [q])
        if recursion_desired:
            p.operation |= dns.DNS_FLAG_RECURSION_DESIRED
        return ndr.ndr_pack(p)

    def _query_server(self, name, qtype, recursion_desired=True):
        """Query the server under test and return the unpacked reply.

        Fails the test if no reply arrives within the global timeout.
        """
        ad = contact_real_server(server_ip, 53)
        self.addCleanup(ad.close)
        ad.send(self._pack_query(name, qtype, recursion_desired), 0)
        ad.settimeout(timeout)
        try:
            data = ad.recv(0xffff + 2, 0)
        except socket.timeout:
            self.fail("DNS server is too slow (timeout %s)" % timeout)
        return ndr.ndr_unpack(dns.name_packet, data)

    def _have_second_forwarder(self, test_name):
        """Return True if two forwarders were given; otherwise say so."""
        if len(dns_servers) < 2:
            print("Ignoring %s" % test_name)
            return False
        return True

    def test_comatose_forwarder(self):
        """A toy server told to stall must not answer within a second."""
        s = self.start_toy_server(dns_servers[0], 53, 'forwarder1')
        s.send("timeout 1000000", 0)

        # Query the toy server directly, not the server under test.
        s.send(self._pack_query("an-address-that-will-not-resolve",
                                dns.DNS_QTYPE_TXT, False), 0)
        s.settimeout(1)
        try:
            s.recv(0xffff + 2, 0)
        except socket.timeout:
            # Expected forwarder to be dead
            return
        self.fail("DNS forwarder should have been inactive")

    def test_no_active_forwarder(self):
        """With no toy server running, a recursive query gets SERVFAIL."""
        data = self._query_server("dsfsfds.dsfsdfs", dns.DNS_QTYPE_TXT)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_SERVFAIL)
        self.assertEqual(data.ancount, 0)

    def test_no_flag_recursive_forwarder(self):
        """Without the recursion-desired flag the answer is NXDOMAIN."""
        # Leave off the recursive flag
        data = self._query_server("dsfsfds.dsfsdfs", dns.DNS_QTYPE_TXT,
                                  recursion_desired=False)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_NXDOMAIN)
        self.assertEqual(data.ancount, 0)

    def test_single_forwarder(self):
        self.start_toy_server(dns_servers[0], 53, 'forwarder1')
        data = self._query_server("dsfsfds.dsfsdfs", dns.DNS_QTYPE_CNAME)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_OK)
        self.assertEqual('forwarder1', data.answers[0].rdata)

    def test_single_forwarder_not_actually_there(self):
        data = self._query_server("dsfsfds.dsfsdfs", dns.DNS_QTYPE_CNAME)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_SERVFAIL)

    def test_single_forwarder_waiting_forever(self):
        s = self.start_toy_server(dns_servers[0], 53, 'forwarder1')
        s.send('timeout 10000', 0)
        data = self._query_server("dsfsfds.dsfsdfs", dns.DNS_QTYPE_CNAME)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_SERVFAIL)

    def test_double_forwarder_first_frozen(self):
        if not self._have_second_forwarder(
                "test_double_forwarder_first_frozen"):
            return
        s1 = self.start_toy_server(dns_servers[0], 53, 'forwarder1')
        self.start_toy_server(dns_servers[1], 53, 'forwarder2')
        s1.send('timeout 1000', 0)
        data = self._query_server("dsfsfds.dsfsdfs", dns.DNS_QTYPE_CNAME)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_OK)
        self.assertEqual('forwarder2', data.answers[0].rdata)

    def test_double_forwarder_first_down(self):
        if not self._have_second_forwarder(
                "test_double_forwarder_first_down"):
            return
        # Only the second forwarder is started.
        self.start_toy_server(dns_servers[1], 53, 'forwarder2')
        data = self._query_server("dsfsfds.dsfsdfs", dns.DNS_QTYPE_CNAME)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_OK)
        self.assertEqual('forwarder2', data.answers[0].rdata)

    def test_double_forwarder_both_slow(self):
        if not self._have_second_forwarder(
                "test_double_forwarder_both_slow"):
            return
        s1 = self.start_toy_server(dns_servers[0], 53, 'forwarder1')
        s2 = self.start_toy_server(dns_servers[1], 53, 'forwarder2')
        s1.send('timeout 1.5', 0)
        s2.send('timeout 1.5', 0)
        data = self._query_server("dsfsfds.dsfsdfs", dns.DNS_QTYPE_CNAME)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_OK)
        self.assertEqual('forwarder1', data.answers[0].rdata)

    def test_cname(self):
        self.start_toy_server(dns_servers[0], 53, 'forwarder1')
        data = self._query_server("resolve.cname", dns.DNS_QTYPE_CNAME)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_OK)
        self.assertEqual(len(data.answers), 1)
        self.assertEqual('forwarder1', data.answers[0].rdata)

    def test_double_cname(self):
        self.start_toy_server(dns_servers[0], 53, 'forwarder1')

        name = 'resolve.cname.%s' % self.get_dns_domain()
        self.make_cname_update(name, "dsfsfds.dsfsdfs")

        data = self._query_server(name, dns.DNS_QTYPE_A)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_OK)
        self.assertEqual('forwarder1', data.answers[1].rdata)

    def test_cname_forwarding_with_slow_server(self):
        if not self._have_second_forwarder(
                "test_cname_forwarding_with_slow_server"):
            return
        s1 = self.start_toy_server(dns_servers[0], 53, 'forwarder1')
        self.start_toy_server(dns_servers[1], 53, 'forwarder2')
        s1.send('timeout 10000', 0)

        name = 'resolve.cname.%s' % self.get_dns_domain()
        self.make_cname_update(name, "dsfsfds.dsfsdfs")

        data = self._query_server(name, dns.DNS_QTYPE_A)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_OK)
        self.assertEqual('forwarder2', data.answers[-1].rdata)

    def test_cname_forwarding_with_server_down(self):
        if not self._have_second_forwarder(
                "test_cname_forwarding_with_server_down"):
            return
        # Only the second forwarder is started.
        self.start_toy_server(dns_servers[1], 53, 'forwarder2')

        name1 = 'resolve1.cname.%s' % self.get_dns_domain()
        name2 = 'resolve2.cname.%s' % self.get_dns_domain()
        self.make_cname_update(name1, name2)
        self.make_cname_update(name2, "dsfsfds.dsfsdfs")

        data = self._query_server(name1, dns.DNS_QTYPE_A)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_OK)
        self.assertEqual('forwarder2', data.answers[-1].rdata)

    def test_cname_forwarding_with_lots_of_cnames(self):
        name3 = 'resolve3.cname.%s' % self.get_dns_domain()
        # The toy server is given name3 as its id.
        self.start_toy_server(dns_servers[0], 53, name3)

        name1 = 'resolve1.cname.%s' % self.get_dns_domain()
        name2 = 'resolve2.cname.%s' % self.get_dns_domain()
        self.make_cname_update(name1, name2)
        self.make_cname_update(name3, name1)
        self.make_cname_update(name2, "dsfsfds.dsfsdfs")

        data = self._query_server(name1, dns.DNS_QTYPE_A)
        # This should cause a loop in Windows
        # (which is restricted by a 20 CNAME limit)
        #
        # The reason it doesn't here is because forwarded CNAME have no
        # additional processing in the internal DNS server.
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_OK)
        self.assertEqual(name3, data.answers[-1].rdata)
605
# Run the tests defined in this module through samba.tests.subunitrun's
# TestProgram, handing it the subunit option group parsed above.
TestProgram(module=__name__, opts=subunitopts)