s4-join: Setup correct DNS configuration
[ddiss/samba.git] / source4 / scripting / python / samba / join.py
index 195dfc23120f2c81314a4d3f49add3c960d9f44f..9ef7d3dd1737658719ae69fa7efa2bedd35efac3 100644 (file)
@@ -1,5 +1,3 @@
-#!/usr/bin/env python
-#
 # python join code
 # Copyright Andrew Tridgell 2010
 # Copyright Andrew Bartlett 2010
 from samba.auth import system_session
 from samba.samdb import SamDB
 from samba import gensec, Ldb, drs_utils
-import ldb, samba, sys, os, uuid
+import ldb, samba, sys, uuid
 from samba.ndr import ndr_pack
 from samba.dcerpc import security, drsuapi, misc, nbt, lsa, drsblobs
 from samba.credentials import Credentials, DONT_USE_KERBEROS
 from samba.provision import secretsdb_self_join, provision, provision_fill, FILL_DRS, FILL_SUBDOMAIN
 from samba.schema import Schema
 from samba.net import Net
-from samba.dcerpc import security
+from samba.provision.sambadns import setup_bind9_dns
 import logging
 import talloc
 import random
@@ -49,12 +47,21 @@ class dc_join(object):
     '''perform a DC join'''
 
     def __init__(ctx, server=None, creds=None, lp=None, site=None,
-            netbios_name=None, targetdir=None, domain=None):
+            netbios_name=None, targetdir=None, domain=None,
+            machinepass=None, use_ntvfs=False, dns_backend=None):
         ctx.creds = creds
         ctx.lp = lp
         ctx.site = site
         ctx.netbios_name = netbios_name
         ctx.targetdir = targetdir
+        ctx.use_ntvfs = use_ntvfs
+        if dns_backend is None:
+            ctx.dns_backend = "NONE"
+        else:
+            ctx.dns_backend = dns_backend
+
+        ctx.nc_list = []
+        ctx.full_nc_list = []
 
         ctx.creds.set_gensec_features(creds.get_gensec_features() | gensec.FEATURE_SEAL)
         ctx.net = Net(creds=ctx.creds, lp=ctx.lp)
@@ -84,13 +91,17 @@ class dc_join(object):
         ctx.config_dn = str(ctx.samdb.get_config_basedn())
         ctx.domsid = ctx.samdb.get_domain_sid()
         ctx.domain_name = ctx.get_domain_name()
+        ctx.forest_domain_name = ctx.get_forest_domain_name()
         ctx.invocation_id = misc.GUID(str(uuid.uuid4()))
 
-        ctx.dc_ntds_dn = ctx.get_dsServiceName()
+        ctx.dc_ntds_dn = ctx.samdb.get_dsServiceName()
         ctx.dc_dnsHostName = ctx.get_dnsHostName()
         ctx.behavior_version = ctx.get_behavior_version()
 
-        ctx.acct_pass = samba.generate_random_password(32, 40)
+        if machinepass is not None:
+            ctx.acct_pass = machinepass
+        else:
+            ctx.acct_pass = samba.generate_random_password(32, 40)
 
         # work out the DNs of all the objects we will be adding
         ctx.server_dn = "CN=%s,CN=Servers,CN=%s,CN=Sites,%s" % (ctx.myname, ctx.site, ctx.config_dn)
@@ -125,7 +136,6 @@ class dc_join(object):
         ctx.managedby = None
         ctx.subdomain = False
 
-
     def del_noerror(ctx, dn, recursive=False):
         if recursive:
             try:
@@ -144,12 +154,12 @@ class dc_join(object):
         '''remove any DNs from a previous join'''
         try:
             # find the krbtgt link
-            print("checking samaccountname")
+            print("checking sAMAccountName")
             if ctx.subdomain:
                 res = None
             else:
                 res = ctx.samdb.search(base=ctx.samdb.get_default_basedn(),
-                                       expression='samAccountName=%s' % ldb.binary_encode(ctx.samname),
+                                       expression='sAMAccountName=%s' % ldb.binary_encode(ctx.samname),
                                        attrs=["msDS-krbTgtLink"])
                 if res:
                     ctx.del_noerror(res[0].dn, recursive=True)
@@ -185,7 +195,7 @@ class dc_join(object):
                 lsaconn.DeleteTrustedDomain(pol_handle, info.info_ex.sid)
 
                 name = lsa.String()
-                name.string = ctx.domain_name
+                name.string = ctx.forest_domain_name
                 info = lsaconn.QueryTrustedDomainInfoByName(pol_handle, name, lsa.LSA_TRUSTED_DOMAIN_INFO_FULL_INFO)
 
                 lsaconn.DeleteTrustedDomain(pol_handle, info.info_ex.sid)
@@ -196,7 +206,7 @@ class dc_join(object):
     def find_dc(ctx, domain):
         '''find a writeable DC for the given domain'''
         try:
-            ctx.cldap_ret = ctx.net.finddc(domainnbt.NBT_SERVER_LDAP | nbt.NBT_SERVER_DS | nbt.NBT_SERVER_WRITABLE)
+            ctx.cldap_ret = ctx.net.finddc(domain=domain, flags=nbt.NBT_SERVER_LDAP | nbt.NBT_SERVER_DS | nbt.NBT_SERVER_WRITABLE)
         except Exception:
             raise Exception("Failed to find a writeable DC for domain '%s'" % domain)
         if ctx.cldap_ret.client_site is not None and ctx.cldap_ret.client_site != "":
@@ -204,10 +214,6 @@ class dc_join(object):
         return ctx.cldap_ret.pdc_dns_name
 
 
-    def get_dsServiceName(ctx):
-        res = ctx.samdb.search(base="", scope=ldb.SCOPE_BASE, attrs=["dsServiceName"])
-        return res[0]["dsServiceName"][0]
-
     def get_behavior_version(ctx):
         res = ctx.samdb.search(base=ctx.base_dn, scope=ldb.SCOPE_BASE, attrs=["msDS-Behavior-Version"])
         if "msDS-Behavior-Version" in res[0]:
@@ -226,6 +232,13 @@ class dc_join(object):
                                expression='ncName=%s' % ctx.samdb.get_default_basedn())
         return res[0]["nETBIOSName"][0]
 
+    def get_forest_domain_name(ctx):
+        '''get netbios name of the domain from the partitions record'''
+        partitions_dn = ctx.samdb.get_partitions_dn()
+        res = ctx.samdb.search(base=partitions_dn, scope=ldb.SCOPE_ONELEVEL, attrs=["nETBIOSName"],
+                               expression='ncName=%s' % ctx.samdb.get_root_basedn())
+        return res[0]["nETBIOSName"][0]
+
     def get_parent_partition_dn(ctx):
         '''get the parent domain partition DN from parent DNS name'''
         res = ctx.samdb.search(base=ctx.config_dn, attrs=[],
@@ -233,6 +246,16 @@ class dc_join(object):
                                (ctx.parent_dnsdomain, ldb.OID_COMPARATOR_AND, samba.dsdb.SYSTEM_FLAG_CR_NTDS_DOMAIN))
         return str(res[0].dn)
 
+    def get_naming_master(ctx):
+        '''get the parent domain partition DN from parent DNS name'''
+        res = ctx.samdb.search(base='CN=Partitions,%s' % ctx.config_dn, attrs=['fSMORoleOwner'],
+                               scope=ldb.SCOPE_BASE, controls=["extended_dn:1:1"])
+        if not 'fSMORoleOwner' in res[0]:
+            raise DCJoinException("Can't find naming master on partition DN %s" % ctx.partition_dn)
+        master_guid = str(misc.GUID(ldb.Dn(ctx.samdb, res[0]['fSMORoleOwner'][0]).get_extended_component('GUID')))
+        master_host = '%s._msdcs.%s' % (master_guid, ctx.dnsforest)
+        return master_host
+
     def get_mysid(ctx):
         '''get the SID of the connected user. Only works with w2k8 and later,
            so only used for RODC join'''
@@ -280,9 +303,9 @@ class dc_join(object):
         ctx.samdb.rename(ctx.krbtgt_dn, ctx.new_krbtgt_dn)
 
     def drsuapi_connect(ctx):
-        '''make a DRSUAPI connection to the server'''
+        '''make a DRSUAPI connection to the naming master'''
         binding_options = "seal"
-        if int(ctx.lp.get("log level")) >= 5:
+        if int(ctx.lp.get("log level")) >= 4:
             binding_options += ",print"
         binding_string = "ncacn_ip_tcp:%s[%s]" % (ctx.server, binding_options)
         ctx.drsuapi = drsuapi.drsuapi(binding_string, ctx.lp, ctx.creds)
@@ -347,17 +370,25 @@ class dc_join(object):
             prev = o
 
         (level, ctr) = ctx.drsuapi.DsAddEntry(ctx.drsuapi_handle, 2, req2)
-        if ctr.err_ver != 1:
-            raise RuntimeError("expected err_ver 1, got %u" % ctr.err_ver)
-        if ctr.err_data.status != (0, 'WERR_OK'):
-            print("DsAddEntry failed with status %s info %s" % (ctr.err_data.status,
-                                                                ctr.err_data.info.extended_err))
-            raise RuntimeError("DsAddEntry failed")
-        if ctr.err_data.dir_err != drsuapi.DRSUAPI_DIRERR_OK:
-            print("DsAddEntry failed with dir_err %u" % ctr.err_data.dir_err)
-            raise RuntimeError("DsAddEntry failed")
-        return ctr.objects
+        if level == 2:
+            if ctr.dir_err != drsuapi.DRSUAPI_DIRERR_OK:
+                print("DsAddEntry failed with dir_err %u" % ctr.dir_err)
+                raise RuntimeError("DsAddEntry failed")
+            if ctr.extended_err != (0, 'WERR_OK'):
+                print("DsAddEntry failed with status %s info %s" % (ctr.extended_err))
+                raise RuntimeError("DsAddEntry failed")
+        if level == 3:
+            if ctr.err_ver != 1:
+                raise RuntimeError("expected err_ver 1, got %u" % ctr.err_ver)
+            if ctr.err_data.status != (0, 'WERR_OK'):
+                print("DsAddEntry failed with status %s info %s" % (ctr.err_data.status,
+                                                                    ctr.err_data.info.extended_err))
+                raise RuntimeError("DsAddEntry failed")
+            if ctr.err_data.dir_err != drsuapi.DRSUAPI_DIRERR_OK:
+                print("DsAddEntry failed with dir_err %u" % ctr.err_data.dir_err)
+                raise RuntimeError("DsAddEntry failed")
 
+        return ctr.objects
 
     def join_add_ntdsdsa(ctx):
         '''add the ntdsdsa object'''
@@ -372,21 +403,21 @@ class dc_join(object):
         nc_list = [ ctx.base_dn, ctx.config_dn, ctx.schema_dn ]
 
         if ctx.behavior_version >= samba.dsdb.DS_DOMAIN_FUNCTION_2003:
-            rec["msDS-Behavior-Version"] = str(ctx.behavior_version)
+            rec["msDS-Behavior-Version"] = str(samba.dsdb.DS_DOMAIN_FUNCTION_2008_R2)
 
         if ctx.behavior_version >= samba.dsdb.DS_DOMAIN_FUNCTION_2003:
             rec["msDS-HasDomainNCs"] = ctx.base_dn
 
         if ctx.RODC:
             rec["objectCategory"] = "CN=NTDS-DSA-RO,%s" % ctx.schema_dn
-            rec["msDS-HasFullReplicaNCs"] = nc_list
+            rec["msDS-HasFullReplicaNCs"] = ctx.nc_list
             rec["options"] = "37"
             ctx.samdb.add(rec, ["rodc_join:1:1"])
         else:
             rec["objectCategory"] = "CN=NTDS-DSA,%s" % ctx.schema_dn
             rec["HasMasterNCs"]      = nc_list
             if ctx.behavior_version >= samba.dsdb.DS_DOMAIN_FUNCTION_2003:
-                rec["msDS-HasMasterNCs"] = nc_list
+                rec["msDS-HasMasterNCs"] = ctx.nc_list
             rec["options"] = "1"
             rec["invocationId"] = ndr_pack(ctx.invocation_id)
             ctx.DsAddEntry([rec])
@@ -395,7 +426,6 @@ class dc_join(object):
         res = ctx.samdb.search(base=ctx.ntds_dn, scope=ldb.SCOPE_BASE, attrs=["objectGUID"])
         ctx.ntds_guid = misc.GUID(ctx.samdb.schema_format_value("objectGUID", res[0]["objectGUID"][0]))
 
-
     def join_add_objects(ctx):
         '''add the various objects needed for the join'''
         if ctx.acct_dn:
@@ -453,15 +483,6 @@ class dc_join(object):
                 "fromServer" : ctx.dc_ntds_dn}
             ctx.samdb.add(rec)
 
-        if ctx.topology_dn and ctx.acct_dn:
-            print "Adding %s" % ctx.topology_dn
-            rec = {
-                "dn" : ctx.topology_dn,
-                "objectclass" : "msDFSR-Member",
-                "msDFSR-ComputerReference" : ctx.acct_dn,
-                "serverReference" : ctx.ntds_dn}
-            ctx.samdb.add(rec)
-
         if ctx.acct_dn:
             print "Adding SPNs to %s" % ctx.acct_dn
             m = ldb.Message()
@@ -473,13 +494,31 @@ class dc_join(object):
                                                            "servicePrincipalName")
             ctx.samdb.modify(m)
 
+            # The account password set operation should normally be done over
+            # LDAP. Windows 2000 DCs however allow this only with SSL
+            # connections which are hard to set up and otherwise refuse with
+            # ERR_UNWILLING_TO_PERFORM. In this case we fall back to libnet
+            # over SAMR.
             print "Setting account password for %s" % ctx.samname
-            ctx.samdb.setpassword("(&(objectClass=user)(sAMAccountName=%s))" % ldb.binary_encode(ctx.samname),
-                                  ctx.acct_pass,
-                                  force_change_at_next_login=False,
-                                  username=ctx.samname)
-            res = ctx.samdb.search(base=ctx.acct_dn, scope=ldb.SCOPE_BASE, attrs=["msDS-keyVersionNumber"])
-            ctx.key_version_number = int(res[0]["msDS-keyVersionNumber"][0])
+            try:
+                ctx.samdb.setpassword("(&(objectClass=user)(sAMAccountName=%s))"
+                                      % ldb.binary_encode(ctx.samname),
+                                      ctx.acct_pass,
+                                      force_change_at_next_login=False,
+                                      username=ctx.samname)
+            except ldb.LdbError, (num, _):
+                if num != ldb.ERR_UNWILLING_TO_PERFORM:
+                    pass
+                ctx.net.set_password(account_name=ctx.samname,
+                                     domain_name=ctx.domain_name,
+                                     newpassword=ctx.acct_pass)
+
+            res = ctx.samdb.search(base=ctx.acct_dn, scope=ldb.SCOPE_BASE,
+                                   attrs=["msDS-KeyVersionNumber"])
+            if "msDS-KeyVersionNumber" in res[0]:
+                ctx.key_version_number = int(res[0]["msDS-KeyVersionNumber"][0])
+            else:
+                ctx.key_version_number = None
 
             print("Enabling account")
             m = ldb.Message()
@@ -489,7 +528,6 @@ class dc_join(object):
                                                          "userAccountControl")
             ctx.samdb.modify(m)
 
-
     def join_add_objects2(ctx):
         '''add the various objects needed for the join, for subdomains post replication'''
 
@@ -525,7 +563,7 @@ class dc_join(object):
         rec2["objectCategory"] = "CN=NTDS-DSA,%s" % ctx.schema_dn
         rec2["HasMasterNCs"]      = nc_list
         if ctx.behavior_version >= samba.dsdb.DS_DOMAIN_FUNCTION_2003:
-            rec2["msDS-HasMasterNCs"] = nc_list
+            rec2["msDS-HasMasterNCs"] = ctx.nc_list
         rec2["options"] = "1"
         rec2["invocationId"] = ndr_pack(ctx.invocation_id)
 
@@ -558,15 +596,15 @@ class dc_join(object):
         logger.addHandler(logging.StreamHandler(sys.stdout))
         smbconf = ctx.lp.configfile
 
-        presult = provision(logger, system_session(), None,
-                            smbconf=smbconf, targetdir=ctx.targetdir, samdb_fill=FILL_DRS,
-                            realm=ctx.realm, rootdn=ctx.root_dn, domaindn=ctx.base_dn,
-                            schemadn=ctx.schema_dn,
-                            configdn=ctx.config_dn,
-                            serverdn=ctx.server_dn, domain=ctx.domain_name,
-                            hostname=ctx.myname, domainsid=ctx.domsid,
-                            machinepass=ctx.acct_pass, serverrole="domain controller",
-                            sitename=ctx.site, lp=ctx.lp, ntdsguid=ctx.ntds_guid)
+        presult = provision(logger, system_session(), None, smbconf=smbconf,
+                targetdir=ctx.targetdir, samdb_fill=FILL_DRS, realm=ctx.realm,
+                rootdn=ctx.root_dn, domaindn=ctx.base_dn,
+                schemadn=ctx.schema_dn, configdn=ctx.config_dn,
+                serverdn=ctx.server_dn, domain=ctx.domain_name,
+                hostname=ctx.myname, domainsid=ctx.domsid,
+                machinepass=ctx.acct_pass, serverrole="domain controller",
+                sitename=ctx.site, lp=ctx.lp, ntdsguid=ctx.ntds_guid,
+                use_ntvfs=ctx.use_ntvfs, dns_backend=ctx.dns_backend)
         print "Provision OK for domain DN %s" % presult.domaindn
         ctx.local_samdb = presult.samdb
         ctx.lp          = presult.lp
@@ -604,10 +642,10 @@ class dc_join(object):
                                  domainguid=domguid,
                                  targetdir=ctx.targetdir, samdb_fill=FILL_SUBDOMAIN,
                                  machinepass=ctx.acct_pass, serverrole="domain controller",
-                                 lp=ctx.lp, hostip=ctx.names.hostip, hostip6=ctx.names.hostip6)
+                                 lp=ctx.lp, hostip=ctx.names.hostip, hostip6=ctx.names.hostip6,
+                                 dns_backend=ctx.dns_backend)
         print("Provision OK for domain %s" % ctx.names.dnsdomain)
 
-
     def join_replicate(ctx):
         '''replicate the SAM'''
 
@@ -644,9 +682,30 @@ class dc_join(object):
                     destination_dsa_guid, rodc=ctx.RODC,
                     replica_flags=ctx.replica_flags)
             if not ctx.subdomain:
+                # Replicate first the critical object for the basedn
+                if not ctx.domain_replica_flags & drsuapi.DRSUAPI_DRS_CRITICAL_ONLY:
+                    print "Replicating critical objects from the base DN of the domain"
+                    ctx.domain_replica_flags |= drsuapi.DRSUAPI_DRS_CRITICAL_ONLY | drsuapi.DRSUAPI_DRS_GET_ANC
+                    repl.replicate(ctx.base_dn, source_dsa_invocation_id,
+                                destination_dsa_guid, rodc=ctx.RODC,
+                                replica_flags=ctx.domain_replica_flags)
+                    ctx.domain_replica_flags ^= drsuapi.DRSUAPI_DRS_CRITICAL_ONLY | drsuapi.DRSUAPI_DRS_GET_ANC
+                else:
+                    ctx.domain_replica_flags |= drsuapi.DRSUAPI_DRS_GET_ANC
                 repl.replicate(ctx.base_dn, source_dsa_invocation_id,
                                destination_dsa_guid, rodc=ctx.RODC,
                                replica_flags=ctx.domain_replica_flags)
+
+            if 'DC=DomainDnsZones,%s' % ctx.base_dn in ctx.nc_list:
+                repl.replicate('DC=DomainDnsZones,%s' % ctx.base_dn, source_dsa_invocation_id,
+                               destination_dsa_guid, rodc=ctx.RODC,
+                               replica_flags=ctx.replica_flags)
+
+            if 'DC=ForestDnsZones,%s' % ctx.root_dn in ctx.nc_list:
+                repl.replicate('DC=ForestDnsZones,%s' % ctx.root_dn, source_dsa_invocation_id,
+                               destination_dsa_guid, rodc=ctx.RODC,
+                               replica_flags=ctx.replica_flags)
+
             if ctx.RODC:
                 repl.replicate(ctx.acct_dn, source_dsa_invocation_id,
                         destination_dsa_guid,
@@ -665,10 +724,31 @@ class dc_join(object):
         else:
             ctx.local_samdb.transaction_commit()
 
+    def send_DsReplicaUpdateRefs(ctx, dn):
+        r = drsuapi.DsReplicaUpdateRefsRequest1()
+        r.naming_context = drsuapi.DsReplicaObjectIdentifier()
+        r.naming_context.dn = str(dn)
+        r.naming_context.guid = misc.GUID("00000000-0000-0000-0000-000000000000")
+        r.naming_context.sid = security.dom_sid("S-0-0")
+        r.dest_dsa_guid = ctx.ntds_guid
+        r.dest_dsa_dns_name = "%s._msdcs.%s" % (str(ctx.ntds_guid), ctx.dnsforest)
+        r.options = drsuapi.DRSUAPI_DRS_ADD_REF | drsuapi.DRSUAPI_DRS_DEL_REF
+        if not ctx.RODC:
+            r.options |= drsuapi.DRSUAPI_DRS_WRIT_REP
+
+        if ctx.drsuapi:
+            ctx.drsuapi.DsReplicaUpdateRefs(ctx.drsuapi_handle, 1, r)
 
     def join_finalise(ctx):
         '''finalise the join, mark us synchronised and setup secrets db'''
 
+        logger = logging.getLogger("provision")
+        logger.addHandler(logging.StreamHandler(sys.stdout))
+
+        print "Sending DsReplicateUpdateRefs for all the partitions"
+        for nc in ctx.full_nc_list:
+            ctx.send_DsReplicaUpdateRefs(nc)
+
         print "Setting isSynchronized and dsServiceName"
         m = ldb.Message()
         m.dn = ldb.Dn(ctx.local_samdb, '@ROOTDSE')
@@ -692,6 +772,15 @@ class dc_join(object):
                             secure_channel_type=ctx.secure_channel_type,
                             key_version_number=ctx.key_version_number)
 
+        if ctx.dns_backend.startswith("BIND9_"):
+            dnspass = samba.generate_random_password(128, 255)
+
+            setup_bind9_dns(ctx.local_samdb, secrets_ldb, security.dom_sid(ctx.domsid),
+                            ctx.names, ctx.paths, ctx.lp, logger,
+                            dns_backend=ctx.dns_backend,
+                            dnspass=dnspass, os_level=ctx.behavior_version,
+                            targetdir=ctx.targetdir)
+
     def join_setup_trusts(ctx):
         '''provision the local SAM'''
 
@@ -709,7 +798,7 @@ class dc_join(object):
             return blob
 
         print "Setup domain trusts with server %s" % ctx.server
-        binding_options = ""  # why doesn't signing work gere? w2k8r2 claims no session key
+        binding_options = ""  # why doesn't signing work here? w2k8r2 claims no session key
         lsaconn = lsa.lsarpc("ncacn_np:%s[%s]" % (ctx.server, binding_options),
                              ctx.lp, ctx.creds)
 
@@ -784,20 +873,20 @@ class dc_join(object):
                                                          security.SEC_STD_DELETE)
 
         rec = {
-            "dn" : "cn=%s,cn=system,%s" % (ctx.parent_dnsdomain, ctx.base_dn),
+            "dn" : "cn=%s,cn=system,%s" % (ctx.dnsforest, ctx.base_dn),
             "objectclass" : "trustedDomain",
             "trustType" : str(info.trust_type),
             "trustAttributes" : str(info.trust_attributes),
             "trustDirection" : str(info.trust_direction),
-            "flatname" : ctx.parent_domain_name,
-            "trustPartner" : ctx.parent_dnsdomain,
+            "flatname" : ctx.forest_domain_name,
+            "trustPartner" : ctx.dnsforest,
             "trustAuthIncoming" : ndr_pack(outgoing),
             "trustAuthOutgoing" : ndr_pack(outgoing)
             }
         ctx.local_samdb.add(rec)
 
         rec = {
-            "dn" : "cn=%s$,cn=users,%s" % (ctx.parent_domain_name, ctx.base_dn),
+            "dn" : "cn=%s$,cn=users,%s" % (ctx.forest_domain_name, ctx.base_dn),
             "objectclass" : "user",
             "userAccountControl" : str(samba.dsdb.UF_INTERDOMAIN_TRUST_ACCOUNT),
             "clearTextPassword" : ctx.trustdom_pass.encode('utf-16-le')
@@ -806,6 +895,20 @@ class dc_join(object):
 
 
     def do_join(ctx):
+        ctx.nc_list = [ ctx.config_dn, ctx.schema_dn ]
+        ctx.full_nc_list = [ctx.base_dn, ctx.config_dn, ctx.schema_dn ]
+
+        if not ctx.subdomain:
+            ctx.nc_list += [ctx.base_dn]
+            if ctx.dns_backend != "NONE":
+                ctx.nc_list += ['DC=DomainDnsZones,%s' % ctx.base_dn]
+
+        if ctx.dns_backend != "NONE":
+            ctx.full_nc_list += ['DC=DomainDnsZones,%s' % ctx.base_dn]
+            ctx.full_nc_list += ['DC=ForestDnsZones,%s' % ctx.root_dn]
+            ctx.nc_list += ['DC=ForestDnsZones,%s' % ctx.root_dn]
+
+
         ctx.cleanup_old_join()
         try:
             ctx.join_add_objects()
@@ -816,17 +919,19 @@ class dc_join(object):
                 ctx.join_provision_own_domain()
                 ctx.join_setup_trusts()
             ctx.join_finalise()
-        except Exception:
+        except:
             print "Join failed - cleaning up"
-            #ctx.cleanup_old_join()
+            ctx.cleanup_old_join()
             raise
 
 
 def join_RODC(server=None, creds=None, lp=None, site=None, netbios_name=None,
-              targetdir=None, domain=None, domain_critical_only=False):
+              targetdir=None, domain=None, domain_critical_only=False,
+              machinepass=None, use_ntvfs=False, dns_backend=None):
     """join as a RODC"""
 
-    ctx = dc_join(server, creds, lp, site, netbios_name, targetdir, domain)
+    ctx = dc_join(server, creds, lp, site, netbios_name, targetdir, domain,
+                  machinepass, use_ntvfs, dns_backend)
 
     lp.set("workgroup", ctx.domain_name)
     print("workgroup is %s" % ctx.domain_name)
@@ -875,9 +980,11 @@ def join_RODC(server=None, creds=None, lp=None, site=None, netbios_name=None,
 
 
 def join_DC(server=None, creds=None, lp=None, site=None, netbios_name=None,
-            targetdir=None, domain=None, domain_critical_only=False):
+            targetdir=None, domain=None, domain_critical_only=False,
+            machinepass=None, use_ntvfs=False, dns_backend=None):
     """join as a DC"""
-    ctx = dc_join(server, creds, lp, site, netbios_name, targetdir, domain)
+    ctx = dc_join(server, creds, lp, site, netbios_name, targetdir, domain,
+                  machinepass, use_ntvfs, dns_backend)
 
     lp.set("workgroup", ctx.domain_name)
     print("workgroup is %s" % ctx.domain_name)
@@ -903,9 +1010,11 @@ def join_DC(server=None, creds=None, lp=None, site=None, netbios_name=None,
     print "Joined domain %s (SID %s) as a DC" % (ctx.domain_name, ctx.domsid)
 
 def join_subdomain(server=None, creds=None, lp=None, site=None, netbios_name=None,
-                   targetdir=None, parent_domain=None, dnsdomain=None, netbios_domain=None):
+                   targetdir=None, parent_domain=None, dnsdomain=None, netbios_domain=None,
+                   machinepass=None, use_ntvfs=False, dns_backend=None):
     """join as a DC"""
-    ctx = dc_join(server, creds, lp, site, netbios_name, targetdir, parent_domain)
+    ctx = dc_join(server, creds, lp, site, netbios_name, targetdir, parent_domain,
+                  machinepass, use_ntvfs, dns_backend)
     ctx.subdomain = True
     ctx.parent_domain_name = ctx.domain_name
     ctx.domain_name = netbios_domain
@@ -914,6 +1023,14 @@ def join_subdomain(server=None, creds=None, lp=None, site=None, netbios_name=Non
     ctx.parent_partition_dn = ctx.get_parent_partition_dn()
     ctx.dnsdomain = dnsdomain
     ctx.partition_dn = "CN=%s,CN=Partitions,%s" % (ctx.domain_name, ctx.config_dn)
+    ctx.naming_master = ctx.get_naming_master()
+    if ctx.naming_master != ctx.server:
+        print("Reconnecting to naming master %s" % ctx.naming_master)
+        ctx.server = ctx.naming_master
+        ctx.samdb = SamDB(url="ldap://%s" % ctx.server,
+                          session_info=system_session(),
+                          credentials=ctx.creds, lp=ctx.lp)
+
     ctx.base_dn = samba.dn_from_dns_name(dnsdomain)
     ctx.domsid = str(security.random_sid())
     ctx.acct_dn = None