scripting/samba_upgradedns: Tighten up exception and attribute list handling
[metze/samba/wip.git] / source4 / scripting / bin / samba_upgradedns
1 #!/usr/bin/env python
2 #
3 # Unix SMB/CIFS implementation.
4 # Copyright (C) Amitay Isaacs <amitay@gmail.com> 2012
5 #
6 # Upgrade DNS provision from BIND9_FLATFILE to BIND9_DLZ or SAMBA_INTERNAL
7 #
8 # This program is free software; you can redistribute it and/or modify
9 # it under the terms of the GNU General Public License as published by
10 # the Free Software Foundation; either version 3 of the License, or
11 # (at your option) any later version.
12 #
13 # This program is distributed in the hope that it will be useful,
14 # but WITHOUT ANY WARRANTY; without even the implied warranty of
15 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 # GNU General Public License for more details.
17 #
18 # You should have received a copy of the GNU General Public License
19 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
20
21 import sys
22 import os
23 import optparse
24 import logging
25 import grp
26 from base64 import b64encode
27 import shlex
28
29 sys.path.insert(0, "bin/python")
30
31 import ldb
32 import samba
33 from samba import param
34 from samba.auth import system_session
35 from samba.ndr import (
36     ndr_pack,
37     ndr_unpack )
38 import samba.getopt as options
39 from samba.upgradehelpers import (
40     get_paths,
41     get_ldbs )
42 from samba.dsdb import DS_DOMAIN_FUNCTION_2003
43 from samba.provision import (
44     find_provision_key_parameters,
45     interface_ips_v4,
46     interface_ips_v6 )
47 from samba.provision.common import (
48     setup_path,
49     setup_add_ldif )
50 from samba.provision.sambadns import (
51     ARecord,
52     AAAARecord,
53     CNameRecord,
54     NSRecord,
55     SOARecord,
56     SRVRecord,
57     TXTRecord,
58     get_dnsadmins_sid,
59     add_dns_accounts,
60     create_dns_partitions,
61     fill_dns_data_partitions,
62     create_dns_dir,
63     secretsdb_setup_dns,
64     create_samdb_copy,
65     create_named_conf,
66     create_named_txt )
67 from samba.dcerpc import security
68
69 samba.ensure_external_module("dns", "dnspython")
70 import dns.zone, dns.rdatatype
71
# Declares the markup used by the docstrings in this module (reST).
__docformat__ = 'restructuredText'
73
74
def find_bind_gid():
    """Return the numeric gid of the BIND system group, if one exists.

    The group names "bind" and "named" are looked up in that order and the
    gid of the first one found is returned.  None is returned when neither
    group exists on this system.
    """
    for group_name in ("bind", "named"):
        try:
            entry = grp.getgrnam(group_name)
        except KeyError:
            # No such group; try the next candidate name.
            continue
        return entry.gr_gid
    return None
84
85
def fix_names(pnames):
    """Flatten the DN attributes of a provision-names object to plain values.

    rootdn, domaindn, configdn and schemadn are replaced by their first
    element, and serverdn is converted with str().  The object passed in is
    modified in place and the same object is returned.
    """
    for attr in ("rootdn", "domaindn", "configdn", "schemadn"):
        setattr(pnames, attr, getattr(pnames, attr)[0])
    # Left unchanged; the read still requires the attribute to be present.
    pnames.root_gid = pnames.root_gid
    pnames.serverdn = str(pnames.serverdn)
    return pnames
97
98
def convert_dns_rdata(rdata, serial=1):
    """Translate a dnspython rdata object into the matching Samba record.

    A, AAAA, CNAME, NS, SRV, TXT and SOA rdata are supported.  Every type
    except SOA is stamped with *serial*; for SOA the serial is read from the
    rdata itself and *serial* is not used.  Any other rdata type yields None.
    """
    rdtype = rdata.rdtype

    if rdtype == dns.rdatatype.A:
        return ARecord(rdata.address, serial=serial)
    if rdtype == dns.rdatatype.AAAA:
        return AAAARecord(rdata.address, serial=serial)
    if rdtype == dns.rdatatype.CNAME:
        return CNameRecord(rdata.target.to_text(), serial=serial)
    if rdtype == dns.rdatatype.NS:
        return NSRecord(rdata.target.to_text(), serial=serial)
    if rdtype == dns.rdatatype.SRV:
        return SRVRecord(rdata.target.to_text(), int(rdata.port),
                         priority=int(rdata.priority),
                         weight=int(rdata.weight),
                         serial=serial)
    if rdtype == dns.rdatatype.TXT:
        # The text form is split back into its individual strings with
        # shell-like quote handling (assumes to_text() quotes each string).
        strings = shlex.split(rdata.to_text())
        return TXTRecord(strings, serial=serial)
    if rdtype == dns.rdatatype.SOA:
        return SOARecord(rdata.mname.to_text(), rdata.rname.to_text(),
                         serial=int(rdata.serial),
                         refresh=int(rdata.refresh),
                         retry=int(rdata.retry),
                         expire=int(rdata.expire),
                         minimum=int(rdata.minimum))
    return None
125
126
def _add_dns_node(samdb, logger, dn, records, description):
    """Add one dnsNode object at *dn* holding the packed *records*.

    :param samdb: SamDB connection the object is added to
    :param logger: logger for the success/failure messages
    :param dn: DN string of the new dnsNode object
    :param records: list of ndr-packed dnsRecord values
    :param description: text used after "Added " / "Failed to add " in the
        log messages
    :raises: whatever samdb.add() raises, after logging an error
    """
    msg = ldb.Message(ldb.Dn(samdb, dn))
    msg["objectClass"] = ["top", "dnsNode"]
    msg["dnsRecord"] = ldb.MessageElement(records, ldb.FLAG_MOD_ADD,
                                          "dnsRecord")
    try:
        samdb.add(msg)
    except Exception:
        logger.error("Failed to add %s" % description)
        raise
    logger.debug("Added %s" % description)


def _pack_supported_rdata(logger, rdatas, name, serial=1):
    """Convert *rdatas* with convert_dns_rdata() and ndr-pack the results.

    Rdata of a type that convert_dns_rdata() does not handle is reported
    with a warning and skipped, instead of being handed to ndr_pack().

    :param logger: logger for the "unsupported record type" warning
    :param rdatas: iterable of dnspython rdata objects
    :param name: owner name shown in the warning
    :param serial: serial passed on to convert_dns_rdata()
    :return: list of packed dnsRecord values (possibly empty)
    """
    packed = []
    for rdata in rdatas:
        rec = convert_dns_rdata(rdata, serial)
        if rec is None:
            logger.warning("Unsupported record type (%s) for %s, ignoring" %
                           (dns.rdatatype.to_text(rdata.rdtype), name))
            continue
        packed.append(ndr_pack(rec))
    return packed


def import_zone_data(samdb, logger, zone, serial, domaindn, forestdn,
                     dnsdomain, dnsforest):
    """Insert zone data in DNS partitions.

    The apex ("@") node of *dnsdomain* is removed from *zone* and stored
    twice.  Under DomainDnsZones it gets SOA, all NS records and any other
    supported apex records.  As the @ node of _msdcs.<dnsforest> under
    ForestDnsZones it gets SOA and NS only.  Every other node goes under
    ForestDnsZones if its name is below _msdcs.<dnsforest>, otherwise under
    DomainDnsZones.

    :param samdb: SamDB connection the dnsNode objects are added to
    :param logger: logger for progress and warning messages
    :param zone: parsed dnspython zone; expected to hold absolute
        (non-relativized) names.  Its apex node is deleted as a side effect.
    :param serial: serial stamped on the records of non-apex nodes
    :param domaindn: DN of the domain partition
    :param forestdn: DN of the forest root partition
    :param dnsdomain: DNS domain name, without a trailing dot
    :param dnsforest: DNS forest name, without a trailing dot
    :raises: re-raises any error from samdb.add() after logging it
    """
    domain_root = dns.name.Name(dnsdomain.split('.') + [''])
    domain_prefix = "DC=%s,CN=MicrosoftDNS,DC=DomainDnsZones,%s" % (dnsdomain,
                                                                    domaindn)

    dnsmsdcs = "_msdcs.%s" % dnsforest
    forest_root = dns.name.Name(dnsmsdcs.split('.') + [''])
    forest_prefix = "DC=%s,CN=MicrosoftDNS,DC=ForestDnsZones,%s" % (dnsmsdcs,
                                                                    forestdn)

    # Take the @ node out of the zone so the generic loop below does not
    # process it a second time.
    at_record = zone.get_node(domain_root)
    zone.delete_node(domain_root)

    # SOA record: only the first rdata of the SOA rdataset is used.
    rdset = at_record.get_rdataset(dns.rdataclass.IN, dns.rdatatype.SOA)
    soa_rec = ndr_pack(convert_dns_rdata(rdset[0]))
    at_record.delete_rdataset(dns.rdataclass.IN, dns.rdatatype.SOA)

    # NS records: keep every name server listed at the apex, not just the
    # first one.
    rdset = at_record.get_rdataset(dns.rdataclass.IN, dns.rdatatype.NS)
    ns_recs = [ndr_pack(convert_dns_rdata(r)) for r in rdset]
    at_record.delete_rdataset(dns.rdataclass.IN, dns.rdatatype.NS)

    # Whatever is left at the apex (expected to be A/AAAA records).
    # Unsupported types are skipped rather than crashing ndr_pack(None).
    ip_recs = _pack_supported_rdata(
        logger, (r for apex_rdset in at_record for r in apex_rdset), "@")

    _add_dns_node(samdb, logger, 'DC=@,%s' % domain_prefix,
                  [soa_rec] + ns_recs + ip_recs, "@ record for domain")
    _add_dns_node(samdb, logger, 'DC=@,%s' % forest_prefix,
                  [soa_rec] + ns_recs, "@ record for forest")

    # Add remaining records in domain and forest
    for node in zone.nodes:
        # relativize() hands back the name unchanged when it is not below
        # forest_root, so an unchanged text means "belongs to the domain".
        name = node.relativize(forest_root).to_text()
        if name == node.to_text():
            name = node.relativize(domain_root).to_text()
            dn = "DC=%s,%s" % (name, domain_prefix)
            fqdn = "%s.%s" % (name, dnsdomain)
        else:
            dn = "DC=%s,%s" % (name, forest_prefix)
            fqdn = "%s.%s" % (name, dnsmsdcs)

        dns_rec = _pack_supported_rdata(
            logger,
            (rdata for rdataset in zone.nodes[node] for rdata in rdataset),
            name, serial)
        if not dns_rec:
            # Nothing storable for this name; an add with an empty
            # dnsRecord element is expected to be rejected by ldb.
            logger.warning("No supported records for %s, ignoring" % fqdn)
            continue

        _add_dns_node(samdb, logger, dn, dns_rec, "DNS record %s" % fqdn)
223
224
# This script creates the application partitions for AD based DNS
# (DomainDnsZones and ForestDnsZones) when they are missing.  That is mainly
# needed when the existing provision was created with an earlier samba4
# snapshot that did not support DNS partitions.

if __name__ == '__main__':

    # Setup command line parser
    parser = optparse.OptionParser("upgradedns [options]")
    sambaopts = options.SambaOptions(parser)
    credopts = options.CredentialsOptions(parser)

    parser.add_option_group(options.VersionOptions(parser))
    parser.add_option_group(sambaopts)
    parser.add_option_group(credopts)

    parser.add_option("--dns-backend", type="choice", metavar="<BIND9_DLZ|SAMBA_INTERNAL>",
                      choices=["SAMBA_INTERNAL", "BIND9_DLZ"], default="SAMBA_INTERNAL",
                      help="The DNS server backend, default SAMBA_INTERNAL")
    parser.add_option("--migrate", type="choice", metavar="<yes|no>",
                      choices=["yes","no"], default="yes",
                      help="Migrate existing zone data, default yes")
    parser.add_option("--verbose", help="Be verbose", action="store_true")

    opts = parser.parse_args()[0]

    if opts.dns_backend is None:
        opts.dns_backend = 'SAMBA_INTERNAL'

    # opts.migrate is the string "yes" or "no"; both are truthy, so it must
    # be compared explicitly.  autofill=True means the DNS records are
    # generated automatically instead of being imported from the zone file.
    autofill = (opts.migrate != "yes")

    # Set up logger
    logger = logging.getLogger("upgradedns")
    logger.addHandler(logging.StreamHandler(sys.stdout))
    logger.setLevel(logging.INFO)
    if opts.verbose:
        logger.setLevel(logging.DEBUG)

    lp = sambaopts.get_loadparm()
    lp.load(lp.configfile)
    creds = credopts.get_credentials(lp)

    logger.info("Reading domain information")
    paths = get_paths(param, smbconf=lp.configfile)
    paths.bind_gid = find_bind_gid()
    ldbs = get_ldbs(paths, creds, system_session(), lp)
    pnames = find_provision_key_parameters(ldbs.sam, ldbs.secrets, ldbs.idmap,
                                           paths, lp.configfile, lp)
    names = fix_names(pnames)

    if names.domainlevel < DS_DOMAIN_FUNCTION_2003:
        logger.error("Cannot create AD based DNS for OS level < 2003")
        sys.exit(1)

    domaindn = names.domaindn
    forestdn = names.rootdn

    dnsdomain = names.dnsdomain.lower()
    dnsforest = dnsdomain

    site = names.sitename
    hostname = names.hostname
    dnsname = '%s.%s' % (hostname, dnsdomain)

    domainsid = names.domainsid
    domainguid = names.domainguid
    ntdsguid = names.ntdsguid

    # Check for DNS accounts and create them if required
    try:
        msg = ldbs.sam.search(base=domaindn, scope=ldb.SCOPE_DEFAULT,
                              expression='(sAMAccountName=DnsAdmins)',
                              attrs=['objectSid'])
        dnsadmins_sid = ndr_unpack(security.dom_sid, msg[0]['objectSid'][0])
    except IndexError:
        logger.info("Adding DNS accounts")
        add_dns_accounts(ldbs.sam, domaindn)
        dnsadmins_sid = get_dnsadmins_sid(ldbs.sam, domaindn)
    else:
        logger.info("DNS accounts already exist")

    # Import dns records from zone file.  Only needed when migrating; zone
    # and serial are used further down only while autofill is still False.
    if not autofill:
        if os.path.exists(paths.dns):
            logger.info("Reading records from zone file %s" % paths.dns)
            try:
                zone = dns.zone.from_file(paths.dns, relativize=False)
                rrset = zone.get_rdataset("%s." % dnsdomain, dns.rdatatype.SOA)
                serial = int(rrset[0].serial)
            except Exception as e:
                logger.warning("Error parsing DNS data from '%s' (%s)" % (paths.dns, str(e)))
                logger.warning("DNS records will be automatically created")
                autofill = True
        else:
            logger.info("No zone file %s" % paths.dns)
            logger.warning("DNS records will be automatically created")
            autofill = True

    # Create DNS partitions if missing and fill DNS information
    expression = '(|(dnsRoot=DomainDnsZones.%s)(dnsRoot=ForestDnsZones.%s))' % \
                 (dnsdomain, dnsforest)
    res = ldbs.sam.search(base=names.configdn, scope=ldb.SCOPE_DEFAULT,
                          expression=expression, attrs=['nCName'])
    if len(res) == 0:
        logger.info("Creating DNS partitions")

        logger.info("Looking up IPv4 addresses")
        hostip = interface_ips_v4(lp)
        try:
            hostip.remove('127.0.0.1')
        except ValueError:
            pass
        if not hostip:
            logger.error("No IPv4 addresses found")
            sys.exit(1)
        else:
            hostip = hostip[0]
            logger.debug("IPv4 addresses: %s" % hostip)

        logger.info("Looking up IPv6 addresses")
        hostip6 = interface_ips_v6(lp, linklocal=False)
        if not hostip6:
            hostip6 = None
        else:
            hostip6 = hostip6[0]
        logger.debug("IPv6 addresses: %s" % hostip6)

        create_dns_partitions(ldbs.sam, domainsid, names, domaindn, forestdn,
                              dnsadmins_sid)

        logger.info("Populating DNS partitions")
        fill_dns_data_partitions(ldbs.sam, domainsid, site, domaindn, forestdn,
                                 dnsdomain, dnsforest, hostname, hostip, hostip6,
                                 domainguid, ntdsguid, dnsadmins_sid,
                                 autofill=autofill)

        if not autofill:
            logger.info("Importing records from zone file")
            import_zone_data(ldbs.sam, logger, zone, serial, domaindn, forestdn,
                             dnsdomain, dnsforest)
    else:
        logger.info("DNS partitions already exist")

    # Mark that we are hosting DNS partitions: both DNS NCs must be listed
    # in msDS-hasMasterNCs and must not be listed in hasPartialReplicaNCs.
    dns_nclist = ['DC=DomainDnsZones,%s' % domaindn,
                  'DC=ForestDnsZones,%s' % forestdn]

    msgs = ldbs.sam.search(base=names.serverdn, scope=ldb.SCOPE_DEFAULT,
                           expression='(objectclass=nTDSDSa)',
                           attrs=['hasPartialReplicaNCs',
                                  'msDS-hasMasterNCs'])
    msg = msgs[0]

    master_el = msg.get("msDS-hasMasterNCs")
    master_nclist = list(master_el) if master_el else []

    # Keep the original element: its values are needed if the attribute
    # ends up being deleted entirely.
    partial_el = msg.get("hasPartialReplicaNCs")
    partial_nclist = list(partial_el) if partial_el else []

    modified_master = False
    modified_partial = False
    for nc in dns_nclist:
        if nc not in master_nclist:
            master_nclist.append(nc)
            modified_master = True
        if nc in partial_nclist:
            partial_nclist.remove(nc)
            modified_partial = True

    if modified_master or modified_partial:
        logger.debug("Updating msDS-hasMasterNCs and hasPartialReplicaNCs attributes")
        m = ldb.Message()
        m.dn = msg.dn
        if modified_master:
            m["msDS-hasMasterNCs"] = ldb.MessageElement(master_nclist,
                                                        ldb.FLAG_MOD_REPLACE,
                                                        "msDS-hasMasterNCs")
        if modified_partial:
            if partial_nclist:
                m["hasPartialReplicaNCs"] = ldb.MessageElement(partial_nclist,
                                                               ldb.FLAG_MOD_REPLACE,
                                                               "hasPartialReplicaNCs")
            else:
                # Only DNS NCs were listed: delete all original values.
                m["hasPartialReplicaNCs"] = ldb.MessageElement(list(partial_el),
                                                               ldb.FLAG_MOD_DELETE,
                                                               "hasPartialReplicaNCs")
        ldbs.sam.modify(m)

    # Special stuff for DLZ backend
    if opts.dns_backend == "BIND9_DLZ":
        # Check if dns-HOSTNAME account exists and create it if required
        try:
            dn = 'samAccountName=dns-%s,CN=Principals' % hostname
            msg = ldbs.secrets.search(expression='(dn=%s)' % dn, attrs=['secret'])
            dnssecret = msg[0]['secret'][0]
        except IndexError:

            logger.info("Adding dns-%s account" % hostname)

            # Remove a stale sam account (without a secrets entry) first
            try:
                msg = ldbs.sam.search(base=domaindn, scope=ldb.SCOPE_DEFAULT,
                                      expression='(sAMAccountName=dns-%s)' % (hostname),
                                      attrs=[])
                dn = msg[0].dn
                ldbs.sam.delete(dn)
            except IndexError:
                pass

            dnspass = samba.generate_random_password(128, 255)
            setup_add_ldif(ldbs.sam, setup_path("provision_dns_add_samba.ldif"), {
                    "DNSDOMAIN": dnsdomain,
                    "DOMAINDN": domaindn,
                    "DNSPASS_B64": b64encode(dnspass.encode('utf-16-le')),
                    "HOSTNAME" : hostname,
                    "DNSNAME" : dnsname }
                           )

            res = ldbs.sam.search(base=domaindn, scope=ldb.SCOPE_DEFAULT,
                                  expression='(sAMAccountName=dns-%s)' % (hostname),
                                  attrs=["msDS-KeyVersionNumber"])
            if "msDS-KeyVersionNumber" in res[0]:
                dns_key_version_number = int(res[0]["msDS-KeyVersionNumber"][0])
            else:
                dns_key_version_number = None

            secretsdb_setup_dns(ldbs.secrets, names,
                                paths.private_dir, realm=names.realm,
                                dnsdomain=names.dnsdomain,
                                dns_keytab_path=paths.dns_keytab, dnspass=dnspass,
                                key_version_number=dns_key_version_number)
        else:
            logger.info("dns-%s account already exists" % hostname)

        # This forces a re-creation of dns directory and all the files within
        # It's an overkill, but it's easier to re-create a samdb copy, rather
        # than trying to fix a broken copy.
        create_dns_dir(logger, paths)

        # Setup a copy of SAM for BIND9
        create_samdb_copy(ldbs.sam, logger, paths, names, domainsid,
                          domainguid)

        create_named_conf(paths, names.realm, dnsdomain, opts.dns_backend)

        create_named_txt(paths.namedtxt, names.realm, dnsdomain, dnsname,
                         paths.private_dir, paths.dns_keytab)
        logger.info("See %s for an example configuration include file for BIND", paths.namedconf)
        logger.info("and %s for further documentation required for secure DNS "
                    "updates", paths.namedtxt)
    elif opts.dns_backend == "SAMBA_INTERNAL":
        # Check if dns-HOSTNAME account exists and delete it if required
        try:
            dn_str = 'samAccountName=dns-%s,CN=Principals' % hostname
            msg = ldbs.secrets.search(expression='(dn=%s)' % dn_str, attrs=[])
            dn = msg[0].dn
        except IndexError:
            dn = None

        if dn is not None:
            try:
                ldbs.secrets.delete(dn)
            except Exception:
                # Best effort: report and carry on
                logger.warning("Failed to delete %s from secrets.ldb" % dn)

        try:
            msg = ldbs.sam.search(base=domaindn, scope=ldb.SCOPE_DEFAULT,
                                  expression='(sAMAccountName=dns-%s)' % (hostname),
                                  attrs=[])
            dn = msg[0].dn
        except IndexError:
            dn = None

        if dn is not None:
            try:
                ldbs.sam.delete(dn)
            except Exception:
                # Best effort: report and carry on
                logger.warning("Failed to delete %s from sam.ldb" % dn)

    logger.info("Finished upgrading DNS")