pyldb: Don't segfault when invalid type is specified to Dn.get().
authorJelmer Vernooij <jelmer@samba.org>
Mon, 14 Sep 2009 15:03:30 +0000 (17:03 +0200)
committerJelmer Vernooij <jelmer@samba.org>
Mon, 14 Sep 2009 15:03:30 +0000 (17:03 +0200)
(#6722)

source4/lib/ldb/pyldb.c
source4/lib/ldb/tests/python/api.py

index 3f7fa2f395825671c882ea2f1d98954b1af2c932..b4f03dc538640e1ba909df65d30e645ae6b35a27 100644 (file)
@@ -1758,8 +1758,13 @@ static PyObject *py_ldb_msg_keys(PyLdbMessageObject *self)
 static PyObject *py_ldb_msg_getitem_helper(PyLdbMessageObject *self, PyObject *py_name)
 {
        struct ldb_message_element *el;
-       char *name = PyString_AsString(py_name);
+       char *name;
        struct ldb_message *msg = PyLdbMessage_AsMessage(self);
+       if (!PyString_Check(py_name)) {
+               PyErr_SetNone(PyExc_TypeError);
+               return NULL;
+       }
+       name = PyString_AsString(py_name);
        if (!strcmp(name, "dn"))
                return PyLdbDn_FromDn(msg->dn);
        el = ldb_msg_find_element(msg, name);
@@ -1786,8 +1791,11 @@ static PyObject *py_ldb_msg_get(PyLdbMessageObject *self, PyObject *args)
                return NULL;
 
        ret = py_ldb_msg_getitem_helper(self, name);
-       if (ret == NULL)
+       if (ret == NULL) {
+               if (PyErr_Occurred())
+                       return NULL;
                Py_RETURN_NONE;
+       }
        return ret;
 }
 
index 88983ac738bc53099130085c7b0312c551d600cb..133bd180c17b1275315f182f62e8ccb2380da6b2 100755 (executable)
@@ -480,6 +480,10 @@ class LdbMsgTests(unittest.TestCase):
         self.msg.dn = ldb.Dn(ldb.Ldb("foo.tdb"), "@BASEINFO")
         self.assertEquals("@BASEINFO", self.msg.get("dn").__str__())
 
+    def test_get_invalid(self):
+        self.msg.dn = ldb.Dn(ldb.Ldb("foo.tdb"), "@BASEINFO")
+        self.assertRaises(TypeError, self.msg.get, 42)
+
     def test_get_other(self):
         self.msg["foo"] = ["bar"]
         self.assertEquals("bar", self.msg.get("foo")[0])