pyldb: Support getting the parent of special DNs without segfaulting.
authorJelmer Vernooij <jelmer@samba.org>
Wed, 17 Jun 2009 16:25:21 +0000 (18:25 +0200)
committerJelmer Vernooij <jelmer@samba.org>
Wed, 17 Jun 2009 18:45:37 +0000 (20:45 +0200)
Found by: Андрей Григорьев <andrew@ei-grad.ru>

source4/lib/ldb/pyldb.c
source4/lib/ldb/tests/python/api.py

index 52d85304396ef57007057e45eebf9fcc16dcf09d..ab2a1215b849de46f2d19c820554a7bf6bdb969e 100644 (file)
@@ -206,7 +206,15 @@ static int py_ldb_dn_compare(PyLdbDnObject *dn1, PyLdbDnObject *dn2)
 static PyObject *py_ldb_dn_get_parent(PyLdbDnObject *self)
 {
        struct ldb_dn *dn = PyLdbDn_AsDn((PyObject *)self);
-       return PyLdbDn_FromDn(ldb_dn_get_parent(NULL, dn));
+       struct ldb_dn *parent;
+
+       parent = ldb_dn_get_parent(NULL, dn);
+
+       if (parent == NULL) {
+               Py_RETURN_NONE;
+       } else {
+               return PyLdbDn_FromDn(parent);
+       }
 }
 
 #define dn_ldb_ctx(dn) ((struct ldb_context *)dn)
index 07500e23728d9c86bc9640fdcacfc2c195a17ebc..177e2e986459dd53544d3ff89d5dfcd1e0b16d7a 100755 (executable)
@@ -14,6 +14,7 @@ def filename():
     return os.tempnam()
 
 class NoContextTests(unittest.TestCase):
+
     def test_valid_attr_name(self):
         self.assertTrue(ldb.valid_attr_name("foo"))
         self.assertFalse(ldb.valid_attr_name("24foo"))
@@ -28,6 +29,7 @@ class NoContextTests(unittest.TestCase):
 
 
 class SimpleLdb(unittest.TestCase):
+
     def test_connect(self):
         ldb.Ldb(filename())
 
@@ -273,6 +275,7 @@ class SimpleLdb(unittest.TestCase):
 
 
 class DnTests(unittest.TestCase):
+
     def setUp(self):
         self.ldb = ldb.Ldb(filename())
 
@@ -301,6 +304,10 @@ class DnTests(unittest.TestCase):
         x = ldb.Dn(self.ldb, "dc=foo,bar=bloe")
         self.assertEquals("bar=bloe", x.parent().__str__())
 
+    def test_parent_nonexistant(self):
+        x = ldb.Dn(self.ldb, "@BLA")
+        self.assertEquals(None, x.parent())
+
     def test_compare(self):
         x = ldb.Dn(self.ldb, "dc=foo,bar=bloe")
         y = ldb.Dn(self.ldb, "dc=foo,bar=bloe")
@@ -373,6 +380,7 @@ class DnTests(unittest.TestCase):
 
 
 class LdbMsgTests(unittest.TestCase):
+
     def setUp(self):
         self.msg = ldb.Message()
 
@@ -439,6 +447,7 @@ class LdbMsgTests(unittest.TestCase):
 
 
 class MessageElementTests(unittest.TestCase):
+
     def test_cmp_element(self):
         x = ldb.MessageElement(["foo"])
         y = ldb.MessageElement(["foo"])
@@ -479,6 +488,7 @@ class MessageElementTests(unittest.TestCase):
 
 
 class ModuleTests(unittest.TestCase):
+
     def test_register_module(self):
         class ExampleModule:
             name = "example"
@@ -505,6 +515,7 @@ class ModuleTests(unittest.TestCase):
         l = ldb.Ldb("usemodule.ldb")
         self.assertEquals(["init"], ops)
 
+
 if __name__ == '__main__':
     import unittest
     unittest.TestProgram()