join: Refactor clone_only case to simplify code
authorTim Beale <timbeale@catalyst.net.nz>
Mon, 11 Jun 2018 04:33:19 +0000 (16:33 +1200)
committerGary Lockyer <gary@samba.org>
Tue, 3 Jul 2018 03:24:14 +0000 (05:24 +0200)
Currently for DC clones, we create a regular DCJoinContext, se a
'clone_only' flag, and then make lots of special checks for this flag
throughout the code. Instead, we can use inheritance to create a
DCCloneContext sub-class, and put the specialization there.

This means we can remove all the 'clone_only' checks from the code. The
only 2 methods that really differ are do_join() and join_finalize(), and
these don't share much code at all. (To avoid duplication, I split the
first part of do_join() into a new build_nc_lists() function, but this
is a pretty trivial code move).

We still pass the clone_only flag into the __init__() as there's still
one case where we want to avoid doing work in the case of the clone.
For clarity, I'll refactor this in a subsequent patch.

Signed-off-by: Tim Beale <timbeale@catalyst.net.nz>
Reviewed-by: Andrew Bartlett <abartlet@samba.org>
Reviewed-by: Gary Lockyer <gary@catalyst.net.nz>
python/samba/join.py

index 9635a6b6888b853311d2b7909bd2148425d635f8..51a82f20bd4b4517a053262acbe05eecd2d0d47f 100644 (file)
@@ -61,8 +61,6 @@ class DCJoinContext(object):
         if site is None:
             site = "Default-First-Site-Name"
 
-        ctx.clone_only=clone_only
-
         ctx.logger = logger
         ctx.creds = creds
         ctx.lp = lp
@@ -119,18 +117,7 @@ class DCJoinContext(object):
             ctx.acct_pass = samba.generate_random_machine_password(128, 255)
 
         ctx.dnsdomain = ctx.samdb.domain_dns_name()
-        if clone_only:
-            # As we don't want to create or delete these DNs, we set them to None
-            ctx.server_dn = None
-            ctx.ntds_dn = None
-            ctx.acct_dn = None
-            ctx.myname = ctx.server.split('.')[0]
-            ctx.ntds_guid = None
-            ctx.rid_manager_dn = None
-
-            # Save this early
-            ctx.remote_dc_ntds_guid = ctx.samdb.get_ntds_GUID()
-        else:
+        if not clone_only:
             # work out the DNs of all the objects we will be adding
             ctx.myname = netbios_name
             ctx.samname = "%s$" % ctx.myname
@@ -1188,12 +1175,11 @@ class DCJoinContext(object):
         # DC we just replicated from then we don't need to send the updatereplicateref
         # as replication between sites is time based and on the initiative of the
         # requesting DC
-        if not ctx.clone_only:
-            ctx.logger.info("Sending DsReplicaUpdateRefs for all the replicated partitions")
-            for nc in ctx.nc_list:
-                ctx.send_DsReplicaUpdateRefs(nc)
+        ctx.logger.info("Sending DsReplicaUpdateRefs for all the replicated partitions")
+        for nc in ctx.nc_list:
+            ctx.send_DsReplicaUpdateRefs(nc)
 
-        if not ctx.clone_only and ctx.RODC:
+        if ctx.RODC:
             print("Setting RODC invocationId")
             ctx.local_samdb.set_invocation_id(str(ctx.invocation_id))
             ctx.local_samdb.set_opaque_integer("domainFunctionality",
@@ -1224,17 +1210,12 @@ class DCJoinContext(object):
         m.dn = ldb.Dn(ctx.local_samdb, '@ROOTDSE')
         m["isSynchronized"] = ldb.MessageElement("TRUE", ldb.FLAG_MOD_REPLACE, "isSynchronized")
 
-        # We want to appear to be the server we just cloned
-        if ctx.clone_only:
-            guid = ctx.remote_dc_ntds_guid
-        else:
-            guid = ctx.ntds_guid
-
+        guid = ctx.ntds_guid
         m["dsServiceName"] = ldb.MessageElement("<GUID=%s>" % str(guid),
                                                 ldb.FLAG_MOD_REPLACE, "dsServiceName")
         ctx.local_samdb.modify(m)
 
-        if ctx.clone_only or ctx.subdomain:
+        if ctx.subdomain:
             return
 
         secrets_ldb = Ldb(ctx.paths.secrets, session_info=system_session(), lp=ctx.lp)
@@ -1359,7 +1340,7 @@ class DCJoinContext(object):
         ctx.local_samdb.add(rec)
 
 
-    def do_join(ctx):
+    def build_nc_lists(ctx):
         # nc_list is the list of naming context (NC) for which we will
         # replicate in and send a updateRef command to the partner DC
 
@@ -1380,23 +1361,24 @@ class DCJoinContext(object):
                 ctx.full_nc_list += [ctx.domaindns_zone]
                 ctx.full_nc_list += [ctx.forestdns_zone]
 
-        if not ctx.clone_only:
-            if ctx.promote_existing:
-                ctx.promote_possible()
-            else:
-                ctx.cleanup_old_join()
+    def do_join(ctx):
+        ctx.build_nc_lists()
+
+        if ctx.promote_existing:
+            ctx.promote_possible()
+        else:
+            ctx.cleanup_old_join()
 
         try:
-            if not ctx.clone_only:
-                ctx.join_add_objects()
+            ctx.join_add_objects()
             ctx.join_provision()
             ctx.join_replicate()
-            if (not ctx.clone_only and ctx.subdomain):
+            if ctx.subdomain:
                 ctx.join_add_objects2()
                 ctx.join_provision_own_domain()
                 ctx.join_setup_trusts()
 
-            if not ctx.clone_only and ctx.dns_backend != "NONE":
+            if ctx.dns_backend != "NONE":
                 ctx.join_add_dns_records()
                 ctx.join_replicate_new_dns_records()
 
@@ -1406,8 +1388,7 @@ class DCJoinContext(object):
                 print("Join failed - cleaning up")
             except IOError:
                 pass
-            if not ctx.clone_only:
-                ctx.cleanup_old_join()
+            ctx.cleanup_old_join()
             raise
 
 
@@ -1499,11 +1480,10 @@ def join_DC(logger=None, server=None, creds=None, lp=None, site=None, netbios_na
 def join_clone(logger=None, server=None, creds=None, lp=None,
                targetdir=None, domain=None, include_secrets=False,
                dns_backend="NONE"):
-    """Join as a DC."""
-    ctx = DCJoinContext(logger, server, creds, lp, site=None, netbios_name=None,
-                        targetdir=targetdir, domain=domain, machinepass=None,
-                        use_ntvfs=False, dns_backend=dns_backend,
-                        promote_existing=False, clone_only=True)
+    """Creates a local clone of a remote DC."""
+    ctx = DCCloneContext(logger, server, creds, lp, targetdir=targetdir,
+                         domain=domain, dns_backend=dns_backend,
+                         include_secrets=include_secrets)
 
     lp.set("workgroup", ctx.domain_name)
     logger.info("workgroup is %s" % ctx.domain_name)
@@ -1511,12 +1491,6 @@ def join_clone(logger=None, server=None, creds=None, lp=None,
     lp.set("realm", ctx.realm)
     logger.info("realm is %s" % ctx.realm)
 
-    ctx.replica_flags |= (drsuapi.DRSUAPI_DRS_WRIT_REP |
-                          drsuapi.DRSUAPI_DRS_FULL_SYNC_IN_PROGRESS)
-    if not include_secrets:
-        ctx.replica_flags |= drsuapi.DRSUAPI_DRS_SPECIAL_SECRET_PROCESSING
-    ctx.domain_replica_flags = ctx.replica_flags
-
     ctx.do_join()
     logger.info("Cloned domain %s (SID %s)" % (ctx.domain_name, ctx.domsid))
 
@@ -1573,3 +1547,55 @@ def join_subdomain(logger=None, server=None, creds=None, lp=None, site=None,
 
     ctx.do_join()
     ctx.logger.info("Created domain %s (SID %s) as a DC" % (ctx.domain_name, ctx.domsid))
+
+
+class DCCloneContext(DCJoinContext):
+    """Clones a remote DC."""
+
+    def __init__(ctx, logger=None, server=None, creds=None, lp=None,
+                 targetdir=None, domain=None, dns_backend=None,
+                 include_secrets=False):
+        super(DCCloneContext, ctx).__init__(logger, server, creds, lp,
+                                            targetdir=targetdir, domain=domain,
+                                            dns_backend=dns_backend,
+                                            clone_only=True)
+
+        # As we don't want to create or delete these DNs, we set them to None
+        ctx.server_dn = None
+        ctx.ntds_dn = None
+        ctx.acct_dn = None
+        ctx.myname = ctx.server.split('.')[0]
+        ctx.ntds_guid = None
+        ctx.rid_manager_dn = None
+
+        # Save this early
+        ctx.remote_dc_ntds_guid = ctx.samdb.get_ntds_GUID()
+
+        ctx.replica_flags |= (drsuapi.DRSUAPI_DRS_WRIT_REP |
+                              drsuapi.DRSUAPI_DRS_FULL_SYNC_IN_PROGRESS)
+        if not include_secrets:
+            ctx.replica_flags |= drsuapi.DRSUAPI_DRS_SPECIAL_SECRET_PROCESSING
+        ctx.domain_replica_flags = ctx.replica_flags
+
+    def join_finalise(ctx):
+        ctx.logger.info("Setting isSynchronized and dsServiceName")
+        m = ldb.Message()
+        m.dn = ldb.Dn(ctx.local_samdb, '@ROOTDSE')
+        m["isSynchronized"] = ldb.MessageElement("TRUE", ldb.FLAG_MOD_REPLACE,
+                                                 "isSynchronized")
+
+        # We want to appear to be the server we just cloned
+        guid = ctx.remote_dc_ntds_guid
+        m["dsServiceName"] = ldb.MessageElement("<GUID=%s>" % str(guid),
+                                                ldb.FLAG_MOD_REPLACE,
+                                                "dsServiceName")
+        ctx.local_samdb.modify(m)
+
+    def do_join(ctx):
+        ctx.build_nc_lists()
+
+        # When cloning a DC, we just want to provision a DC locally, then
+        # grab the remote DC's entire DB via DRS replication
+        ctx.join_provision()
+        ctx.join_replicate()
+        ctx.join_finalise()