VERSION: Bump version number up to 4.0.4.
[samba.git] / source4 / scripting / python / samba / kcc_utils.py
index 93096e96899b69fa86ba3602cbb0f06b8a089c57..57c31876a69aed74abcc97c3a697410eef8015bb 100644 (file)
@@ -1,5 +1,3 @@
-#!/usr/bin/env python
-#
 # KCC topology utilities
 #
 # Copyright (C) Dave Craft 2011
@@ -22,12 +20,14 @@ import ldb
 import uuid
 import time
 
-from samba        import (dsdb, unix2nttime)
-from samba.dcerpc import (drsblobs, \
-                          drsuapi,  \
-                          misc)
+from samba import dsdb, unix2nttime
+from samba.dcerpc import (
+    drsblobs,
+    drsuapi,
+    misc,
+    )
 from samba.common import dsdb_Dn
-from samba.ndr    import (ndr_unpack, ndr_pack)
+from samba.ndr import (ndr_unpack, ndr_pack)
 
 
 class NCType(object):
@@ -47,9 +47,9 @@ class NamingContext(object):
         :param nc_dnstr: NC dn string
         """
         self.nc_dnstr = nc_dnstr
-        self.nc_guid  = None
-        self.nc_sid   = None
-        self.nc_type  = NCType.unknown
+        self.nc_guid = None
+        self.nc_sid = None
+        self.nc_type = NCType.unknown
 
     def __str__(self):
         '''Debug dump string output of class'''
@@ -73,7 +73,7 @@ class NamingContext(object):
                                scope=ldb.SCOPE_BASE, attrs=attrs)
 
         except ldb.LdbError, (enum, estr):
-            raise Exception("Unable to find naming context (%s)" % \
+            raise Exception("Unable to find naming context (%s)" %
                             (self.nc_dnstr, estr))
         msg = res[0]
         if "objectGUID" in msg:
@@ -83,7 +83,6 @@ class NamingContext(object):
             self.nc_sid = msg["objectSid"][0]
 
         assert self.nc_guid is not None
-        return
 
     def is_schema(self):
         '''Return True if NC is schema'''
@@ -214,7 +213,7 @@ class NCReplica(NamingContext):
 
     def set_instantiated_flags(self, flags=None):
         '''Set or clear NC replica instantiated flags'''
-        if (flags == None):
+        if flags is None:
             self.rep_instantiated_flags = 0
         else:
             self.rep_instantiated_flags = flags
@@ -338,7 +337,7 @@ class NCReplica(NamingContext):
             # replacement list.  Build a list
             # of to be deleted reps which we will
             # remove from rep_repsFrom list below
-            if repsFrom.to_be_deleted == True:
+            if repsFrom.to_be_deleted:
                 delreps.append(repsFrom)
                 modify = True
                 continue
@@ -362,7 +361,7 @@ class NCReplica(NamingContext):
         # need to be deleted or input option has informed
         # us to be "readonly" (ro).  Leave database
         # record "as is"
-        if modify == False or ro == True:
+        if not modify or ro:
             return
 
         m = ldb.Message()
@@ -381,7 +380,7 @@ class NCReplica(NamingContext):
     def dumpstr_to_be_deleted(self):
         text=""
         for repsFrom in self.rep_repsFrom:
-            if repsFrom.to_be_deleted == True:
+            if repsFrom.to_be_deleted:
                 if text:
                     text = text + "\n%s" % repsFrom
                 else:
@@ -391,7 +390,7 @@ class NCReplica(NamingContext):
     def dumpstr_to_be_modified(self):
         text=""
         for repsFrom in self.rep_repsFrom:
-            if repsFrom.is_modified() == True:
+            if repsFrom.is_modified():
                 if text:
                     text = text + "\n%s" % repsFrom
                 else:
@@ -415,7 +414,6 @@ class NCReplica(NamingContext):
         # Possibly no fSMORoleOwner
         if "fSMORoleOwner" in msg:
             self.rep_fsmo_role_owner = msg["fSMORoleOwner"]
-        return
 
     def is_fsmo_role_owner(self, dsa_dnstr):
         if self.rep_fsmo_role_owner is not None and \
@@ -423,6 +421,7 @@ class NCReplica(NamingContext):
             return True
         return False
 
+
 class DirectoryServiceAgent(object):
 
     def __init__(self, dsa_dnstr):
@@ -626,7 +625,7 @@ class DirectoryServiceAgent(object):
                 for value in res[0][k]:
                     # Turn dn into a dsdb_Dn so we can use
                     # its methods to parse a binary DN
-                    dsdn  = dsdb_Dn(samdb, value)
+                    dsdn = dsdb_Dn(samdb, value)
                     flags = dsdn.get_binary_integer()
                     dnstr = str(dsdn.dn)
 
@@ -712,8 +711,6 @@ class DirectoryServiceAgent(object):
         for dnstr in delconn:
             del self.connect_table[dnstr]
 
-        return
-
     def add_connection(self, dnstr, connect):
         assert dnstr not in self.connect_table.keys()
         self.connect_table[dnstr] = connect
@@ -768,12 +765,12 @@ class DirectoryServiceAgent(object):
         """
         dnstr = "CN=%s," % str(uuid.uuid4()) + self.dsa_dnstr
 
-        connect             = NTDSConnection(dnstr)
+        connect = NTDSConnection(dnstr)
         connect.to_be_added = True
-        connect.enabled     = True
-        connect.from_dnstr  = from_dnstr
-        connect.options     = options
-        connect.flags       = flags
+        connect.enabled = True
+        connect.from_dnstr = from_dnstr
+        connect.options = options
+        connect.flags = flags
 
         if transport is not None:
             connect.transport_dnstr = transport.dnstr
@@ -814,11 +811,11 @@ class NTDSConnection(object):
     """
     def __init__(self, dnstr):
         self.dnstr = dnstr
-        self.guid  = None
+        self.guid = None
         self.enabled = False
         self.whenCreated = 0
-        self.to_be_added    = False # new connection needs to be added
-        self.to_be_deleted  = False # old connection needs to be deleted
+        self.to_be_added = False # new connection needs to be added
+        self.to_be_deleted = False # old connection needs to be deleted
         self.to_be_modified = False
         self.options = 0
         self.system_flags = 0
@@ -936,12 +933,11 @@ class NTDSConnection(object):
 
         if "objectGUID" in res[0]:
             self.transport_dnstr = tdnstr
-            self.transport_guid  = \
+            self.transport_guid = \
                 misc.GUID(samdb.schema_format_value("objectGUID",
                                                     msg["objectGUID"][0]))
         assert self.transport_dnstr is not None
         assert self.transport_guid is not None
-        return
 
     def commit_deleted(self, samdb, ro=False):
         """Local helper routine for commit_connections() which
@@ -952,17 +948,15 @@ class NTDSConnection(object):
         self.to_be_deleted = False
 
         # No database modification requested
-        if ro == True:
+        if ro:
             return
 
         try:
             samdb.delete(self.dnstr)
         except ldb.LdbError, (enum, estr):
-            raise Exception("Could not delete nTDSConnection for (%s) - (%s)" % \
+            raise Exception("Could not delete nTDSConnection for (%s) - (%s)" %
                             (self.dnstr, estr))
 
-        return
-
     def commit_added(self, samdb, ro=False):
         """Local helper routine for commit_connections() which
         handles committed connections that are to be added to the
@@ -972,7 +966,7 @@ class NTDSConnection(object):
         self.to_be_added = False
 
         # No database modification requested
-        if ro == True:
+        if ro:
             return
 
         # First verify we don't have this entry to ensure nothing
@@ -985,10 +979,10 @@ class NTDSConnection(object):
 
         except ldb.LdbError, (enum, estr):
             if enum != ldb.ERR_NO_SUCH_OBJECT:
-                raise Exception("Unable to search for (%s) - (%s)" % \
+                raise Exception("Unable to search for (%s) - (%s)" %
                                 (self.dnstr, estr))
         if found:
-            raise Exception("nTDSConnection for (%s) already exists!" % \
+            raise Exception("nTDSConnection for (%s) already exists!" %
                             self.dnstr)
 
         if self.enabled:
@@ -1001,10 +995,10 @@ class NTDSConnection(object):
         m.dn = ldb.Dn(samdb, self.dnstr)
 
         m["objectClass"] = \
-            ldb.MessageElement("nTDSConnection", ldb.FLAG_MOD_ADD, \
+            ldb.MessageElement("nTDSConnection", ldb.FLAG_MOD_ADD,
                                "objectClass")
         m["showInAdvancedViewOnly"] = \
-            ldb.MessageElement("TRUE", ldb.FLAG_MOD_ADD, \
+            ldb.MessageElement("TRUE", ldb.FLAG_MOD_ADD,
                                "showInAdvancedViewOnly")
         m["enabledConnection"] = \
             ldb.MessageElement(enablestr, ldb.FLAG_MOD_ADD, "enabledConnection")
@@ -1013,12 +1007,12 @@ class NTDSConnection(object):
         m["options"] = \
             ldb.MessageElement(str(self.options), ldb.FLAG_MOD_ADD, "options")
         m["systemFlags"] = \
-            ldb.MessageElement(str(self.system_flags), ldb.FLAG_MOD_ADD, \
+            ldb.MessageElement(str(self.system_flags), ldb.FLAG_MOD_ADD,
                                "systemFlags")
 
         if self.transport_dnstr is not None:
             m["transportType"] = \
-                ldb.MessageElement(str(self.transport_dnstr), ldb.FLAG_MOD_ADD, \
+                ldb.MessageElement(str(self.transport_dnstr), ldb.FLAG_MOD_ADD,
                                    "transportType")
 
         if self.schedule is not None:
@@ -1028,9 +1022,8 @@ class NTDSConnection(object):
         try:
             samdb.add(m)
         except ldb.LdbError, (enum, estr):
-            raise Exception("Could not add nTDSConnection for (%s) - (%s)" % \
+            raise Exception("Could not add nTDSConnection for (%s) - (%s)" %
                             (self.dnstr, estr))
-        return
 
     def commit_modified(self, samdb, ro=False):
         """Local helper routine for commit_connections() which
@@ -1041,7 +1034,7 @@ class NTDSConnection(object):
         self.to_be_modified = False
 
         # No database modification requested
-        if ro == True:
+        if ro:
             return
 
         # First verify we have this entry to ensure nothing
@@ -1054,10 +1047,10 @@ class NTDSConnection(object):
             if enum == ldb.ERR_NO_SUCH_OBJECT:
                 found = False
             else:
-                raise Exception("Unable to search for (%s) - (%s)" % \
+                raise Exception("Unable to search for (%s) - (%s)" %
                                 (self.dnstr, estr))
-        if found == False:
-            raise Exception("nTDSConnection for (%s) doesn't exist!" % \
+        if not found:
+            raise Exception("nTDSConnection for (%s) doesn't exist!" %
                             self.dnstr)
 
         if self.enabled:
@@ -1070,53 +1063,47 @@ class NTDSConnection(object):
         m.dn = ldb.Dn(samdb, self.dnstr)
 
         m["enabledConnection"] = \
-            ldb.MessageElement(enablestr, ldb.FLAG_MOD_REPLACE, \
+            ldb.MessageElement(enablestr, ldb.FLAG_MOD_REPLACE,
                                "enabledConnection")
         m["fromServer"] = \
-            ldb.MessageElement(self.from_dnstr, ldb.FLAG_MOD_REPLACE, \
+            ldb.MessageElement(self.from_dnstr, ldb.FLAG_MOD_REPLACE,
                                "fromServer")
         m["options"] = \
-            ldb.MessageElement(str(self.options), ldb.FLAG_MOD_REPLACE, \
+            ldb.MessageElement(str(self.options), ldb.FLAG_MOD_REPLACE,
                                "options")
         m["systemFlags"] = \
-            ldb.MessageElement(str(self.system_flags), ldb.FLAG_MOD_REPLACE, \
+            ldb.MessageElement(str(self.system_flags), ldb.FLAG_MOD_REPLACE,
                                "systemFlags")
 
         if self.transport_dnstr is not None:
             m["transportType"] = \
-                ldb.MessageElement(str(self.transport_dnstr), \
+                ldb.MessageElement(str(self.transport_dnstr),
                                    ldb.FLAG_MOD_REPLACE, "transportType")
         else:
             m["transportType"] = \
-                ldb.MessageElement([], \
-                                   ldb.FLAG_MOD_DELETE, "transportType")
+                ldb.MessageElement([], ldb.FLAG_MOD_DELETE, "transportType")
 
         if self.schedule is not None:
             m["schedule"] = \
-                ldb.MessageElement(ndr_pack(self.schedule), \
+                ldb.MessageElement(ndr_pack(self.schedule),
                                    ldb.FLAG_MOD_REPLACE, "schedule")
         else:
             m["schedule"] = \
-                ldb.MessageElement([], \
-                                   ldb.FLAG_MOD_DELETE, "schedule")
+                ldb.MessageElement([], ldb.FLAG_MOD_DELETE, "schedule")
         try:
             samdb.modify(m)
         except ldb.LdbError, (enum, estr):
-            raise Exception("Could not modify nTDSConnection for (%s) - (%s)" % \
+            raise Exception("Could not modify nTDSConnection for (%s) - (%s)" %
                             (self.dnstr, estr))
-        return
 
     def set_modified(self, truefalse):
         self.to_be_modified = truefalse
-        return
 
     def set_added(self, truefalse):
         self.to_be_added = truefalse
-        return
 
     def set_deleted(self, truefalse):
         self.to_be_deleted = truefalse
-        return
 
     def is_schedule_minimum_once_per_week(self):
         """Returns True if our schedule includes at least one
@@ -1142,9 +1129,9 @@ class NTDSConnection(object):
         elif sched is None:
             return True
 
-        if self.schedule.size              != sched.size or \
-           self.schedule.bandwidth         != sched.bandwidth or \
-           self.schedule.numberOfSchedules != sched.numberOfSchedules:
+        if (self.schedule.size != sched.size or
+            self.schedule.bandwidth != sched.bandwidth or
+            self.schedule.numberOfSchedules != sched.numberOfSchedules):
             return False
 
         for i, header in enumerate(self.schedule.headerArray):
@@ -1156,7 +1143,7 @@ class NTDSConnection(object):
                sched.headerArray[i].offset:
                 return False
 
-            for a, b in zip(self.schedule.dataArray[i].slots, \
+            for a, b in zip(self.schedule.dataArray[i].slots,
                             sched.dataArray[i].slots):
                 if a != b:
                     return False
@@ -1307,7 +1294,7 @@ class Partition(NamingContext):
                 continue
 
             for value in msg[k]:
-                dsdn  = dsdb_Dn(samdb, value)
+                dsdn = dsdb_Dn(samdb, value)
                 dnstr = str(dsdn.dn)
 
                 if k == "nCName":
@@ -1406,11 +1393,11 @@ class Site(object):
     naming context.  Contains all DSAs that exist within the site
     """
     def __init__(self, site_dnstr):
-        self.site_dnstr          = site_dnstr
-        self.site_options        = 0
+        self.site_dnstr = site_dnstr
+        self.site_options = 0
         self.site_topo_generator = None
-        self.site_topo_failover  = 0  # appears to be in minutes
-        self.dsa_table           = {}
+        self.site_topo_failover = 0  # appears to be in minutes
+        self.dsa_table = {}
 
     def load_site(self, samdb):
         """Loads the NTDS Site Settions options attribute for the site
@@ -1502,7 +1489,7 @@ class Site(object):
                 break
 
         if c_rep is None:
-            raise Exception("Unable to find config NC replica for (%s)" % \
+            raise Exception("Unable to find config NC replica for (%s)" %
                             mydsa.dsa_dnstr)
 
         # Load repsFrom if not already loaded so we can get the current
@@ -1525,10 +1512,10 @@ class Site(object):
         # in the site by guid in ascending order".   Place sorted list
         # in D_sort[]
         D_sort = []
-        d_dsa  = None
+        d_dsa = None
 
         unixnow = int(time.time())     # seconds since 1970
-        ntnow   = unix2nttime(unixnow) # double word number of 100 nanosecond
+        ntnow = unix2nttime(unixnow) # double word number of 100 nanosecond
                                        # intervals since 1600s
 
         for dsa in self.dsa_table.values():
@@ -1583,15 +1570,14 @@ class Site(object):
            # last_success appears to be a double word containing
            #     number of 100 nanosecond intervals since the 1600s
            if d_dsa.dsa_ivid != c_rep.source_dsa_invocation_id:
-               i_idx  = j_idx
+               i_idx = j_idx
                t_time = 0
 
            elif ntnow < (c_rep.last_success - f):
-               i_idx  = 0
+               i_idx = 0
                t_time = 0
-
            else:
-               i_idx  = j_idx
+               i_idx = j_idx
                t_time = c_rep.last_success
 
         # Otherwise (Nominate local DC as ISTG):
@@ -1599,7 +1585,7 @@ class Site(object):
         #         object for the local DC.
         #     Let t = the current time.
         else:
-            i_idx  = D_sort.index(mydsa)
+            i_idx = D_sort.index(mydsa)
             t_time = ntnow
 
         # Compute a function that maintains the current ISTG if
@@ -1631,7 +1617,7 @@ class Site(object):
 
         # If readonly database then do not perform a
         # persistent update
-        if ro == True:
+        if ro:
             return True
 
         # Perform update to the samdb
@@ -1641,14 +1627,15 @@ class Site(object):
         m.dn = ldb.Dn(samdb, ssdn)
 
         m["interSiteTopologyGenerator"] = \
-            ldb.MessageElement(mydsa.dsa_dnstr, ldb.FLAG_MOD_REPLACE, \
+            ldb.MessageElement(mydsa.dsa_dnstr, ldb.FLAG_MOD_REPLACE,
                                "interSiteTopologyGenerator")
         try:
             samdb.modify(m)
 
         except ldb.LdbError, estr:
-            raise Exception("Could not set interSiteTopologyGenerator for (%s) - (%s)" %
-                            (ssdn, estr))
+            raise Exception(
+                "Could not set interSiteTopologyGenerator for (%s) - (%s)" %
+                (ssdn, estr))
         return True
 
     def is_intrasite_topology_disabled(self):
@@ -1780,7 +1767,7 @@ class GraphNode(object):
             #    the DC on which ri "is present".
             #
             #    c.options does not contain NTDSCONN_OPT_RODC_TOPOLOGY
-            if connect and connect.is_rodc_topology() == False:
+            if connect and not connect.is_rodc_topology():
                 exists = True
             else:
                 exists = False
@@ -1791,13 +1778,11 @@ class GraphNode(object):
                 return
 
             # Generate a new dnstr for this nTDSConnection
-            opt   = dsdb.NTDSCONN_OPT_IS_GENERATED
+            opt = dsdb.NTDSCONN_OPT_IS_GENERATED
             flags = dsdb.SYSTEM_FLAG_CONFIG_ALLOW_RENAME + \
                      dsdb.SYSTEM_FLAG_CONFIG_ALLOW_MOVE
 
             dsa.create_connection(opt, flags, None, edge_dnstr, None)
-            return
-
 
     def has_sufficient_edges(self):
         '''Return True if we have met the maximum "from edges" criteria'''
@@ -1806,7 +1791,6 @@ class GraphNode(object):
         return False
 
 
-
 class Transport(object):
     """Class defines a Inter-site transport found under Sites
     """
@@ -1865,11 +1849,11 @@ class Transport(object):
 
         if "bridgeheadServerListBL" in msg:
             for value in msg["bridgeheadServerListBL"]:
-                dsdn  = dsdb_Dn(samdb, value)
+                dsdn = dsdb_Dn(samdb, value)
                 dnstr = str(dsdn.dn)
                 if dnstr not in self.bridgehead_list:
                     self.bridgehead_list.append(dnstr)
-        return
+
 
 class RepsFromTo(object):
     """Class encapsulation of the NDR repsFromToBlob.
@@ -2002,7 +1986,6 @@ class RepsFromTo(object):
             raise AttributeError, "Unknown attribute %s" % item
 
         self.__dict__['update_flags'] |= drsuapi.DRSUAPI_DRS_UPDATE_ADDRESS
-        return
 
     def __getattr__(self, item):
         """Overload of RepsFromTo attribute retrieval.
@@ -2047,18 +2030,19 @@ class RepsFromTo(object):
     def set_unmodified(self):
         self.__dict__['update_flags'] = 0x0
 
+
 class SiteLink(object):
     """Class defines a site link found under sites
     """
 
     def __init__(self, dnstr):
-        self.dnstr        = dnstr
-        self.options      = 0
+        self.dnstr = dnstr
+        self.options = 0
         self.system_flags = 0
-        self.cost         = 0
-        self.schedule     = None
-        self.interval     = None
-        self.site_list    = []
+        self.cost = 0
+        self.schedule = None
+        self.interval = None
+        self.site_list = []
 
     def __str__(self):
         '''Debug dump string output of Transport object'''
@@ -2124,33 +2108,32 @@ class SiteLink(object):
 
         if "siteList" in msg:
             for value in msg["siteList"]:
-                dsdn  = dsdb_Dn(samdb, value)
+                dsdn = dsdb_Dn(samdb, value)
                 dnstr = str(dsdn.dn)
                 if dnstr not in self.site_list:
                     self.site_list.append(dnstr)
-        return
 
     def is_sitelink(self, site1_dnstr, site2_dnstr):
         """Given a siteLink object, determine if it is a link
         between the two input site DNs
         """
-        if site1_dnstr in self.site_list and \
-           site2_dnstr in self.site_list:
+        if site1_dnstr in self.site_list and site2_dnstr in self.site_list:
             return True
         return False
 
-class VertexColor():
+
+class VertexColor(object):
     (unknown, white, black, red) = range(0, 4)
 
+
 class Vertex(object):
     """Class encapsulation of a Site Vertex in the
     intersite topology replication algorithm
     """
     def __init__(self, site, part):
-        self.site  = site
-        self.part  = part
+        self.site = site
+        self.part = part
         self.color = VertexColor.unknown
-        return
 
     def color_vertex(self):
         """Color each vertex to indicate which kind of NC
@@ -2174,12 +2157,11 @@ class Vertex(object):
 
             # We have a full replica which is the largest
             # value so exit
-            if rep.is_partial() == False:
+            if not rep.is_partial():
                 self.color = VertexColor.red
                 break
             else:
                 self.color = VertexColor.black
-        return
 
     def is_red(self):
         assert(self.color != VertexColor.unknown)