Add custom implementations of TestCase.assertIs and TestCase.assertIsNot, for Python2.6.
[obnox/samba/samba-obnox.git] / python / samba / tests / __init__.py
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2010
3 #
4 # This program is free software; you can redistribute it and/or modify
5 # it under the terms of the GNU General Public License as published by
6 # the Free Software Foundation; either version 3 of the License, or
7 # (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 #
17
18 """Samba Python tests."""
19
20 import os
21 import ldb
22 import samba
23 import samba.auth
24 from samba import param
25 from samba.samdb import SamDB
26 from samba import credentials
27 import subprocess
28 import tempfile
29 import unittest
30
31 try:
32     from unittest import SkipTest
33 except ImportError:
34     class SkipTest(Exception):
35         """Test skipped."""
36
37
38 class TestCase(unittest.TestCase):
39     """A Samba test case."""
40
41     def setUp(self):
42         super(TestCase, self).setUp()
43         test_debug_level = os.getenv("TEST_DEBUG_LEVEL")
44         if test_debug_level is not None:
45             test_debug_level = int(test_debug_level)
46             self._old_debug_level = samba.get_debug_level()
47             samba.set_debug_level(test_debug_level)
48             self.addCleanup(samba.set_debug_level, test_debug_level)
49
50     def get_loadparm(self):
51         return env_loadparm()
52
53     def get_credentials(self):
54         return cmdline_credentials
55
56     if not getattr(unittest.TestCase, "skipTest", None):
57         def skipTest(self, reason):
58             raise SkipTest(reason)
59
60     if not getattr(unittest.TestCase, "assertIs", None):
61         def assertIs(self, a, b):
62             self.assertTrue(a is b)
63
64     if not getattr(unittest.TestCase, "assertIsNot", None):
65         def assertIsNot(self, a, b):
66             self.assertTrue(a is not b)
67
68
69 class LdbTestCase(unittest.TestCase):
70     """Trivial test case for running tests against a LDB."""
71
72     def setUp(self):
73         super(LdbTestCase, self).setUp()
74         self.filename = os.tempnam()
75         self.ldb = samba.Ldb(self.filename)
76
77     def set_modules(self, modules=[]):
78         """Change the modules for this Ldb."""
79         m = ldb.Message()
80         m.dn = ldb.Dn(self.ldb, "@MODULES")
81         m["@LIST"] = ",".join(modules)
82         self.ldb.add(m)
83         self.ldb = samba.Ldb(self.filename)
84
85
86 class TestCaseInTempDir(TestCase):
87
88     def setUp(self):
89         super(TestCaseInTempDir, self).setUp()
90         self.tempdir = tempfile.mkdtemp()
91         self.addCleanup(self._remove_tempdir)
92
93     def _remove_tempdir(self):
94         self.assertEquals([], os.listdir(self.tempdir))
95         os.rmdir(self.tempdir)
96         self.tempdir = None
97
98
99 def env_loadparm():
100     lp = param.LoadParm()
101     try:
102         lp.load(os.environ["SMB_CONF_PATH"])
103     except KeyError:
104         raise KeyError("SMB_CONF_PATH not set")
105     return lp
106
107
108 def env_get_var_value(var_name):
109     """Returns value for variable in os.environ
110
111     Function throws AssertionError if variable is defined.
112     Unit-test based python tests require certain input params
113     to be set in environment, otherwise they can't be run
114     """
115     assert var_name in os.environ.keys(), "Please supply %s in environment" % var_name
116     return os.environ[var_name]
117
118
119 cmdline_credentials = None
120
121 class RpcInterfaceTestCase(TestCase):
122     """DCE/RPC Test case."""
123
124
125 class ValidNetbiosNameTests(TestCase):
126
127     def test_valid(self):
128         self.assertTrue(samba.valid_netbios_name("FOO"))
129
130     def test_too_long(self):
131         self.assertFalse(samba.valid_netbios_name("FOO"*10))
132
133     def test_invalid_characters(self):
134         self.assertFalse(samba.valid_netbios_name("*BLA"))
135
136
137 class BlackboxProcessError(Exception):
138     """This is raised when check_output() process returns a non-zero exit status
139
140     Exception instance should contain the exact exit code (S.returncode),
141     command line (S.cmd), process output (S.stdout) and process error stream
142     (S.stderr)
143     """
144
145     def __init__(self, returncode, cmd, stdout, stderr):
146         self.returncode = returncode
147         self.cmd = cmd
148         self.stdout = stdout
149         self.stderr = stderr
150
151     def __str__(self):
152         return "Command '%s'; exit status %d; stdout: '%s'; stderr: '%s'" % (self.cmd, self.returncode,
153                                                                              self.stdout, self.stderr)
154
155 class BlackboxTestCase(TestCase):
156     """Base test case for blackbox tests."""
157
158     def _make_cmdline(self, line):
159         bindir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../bin"))
160         parts = line.split(" ")
161         if os.path.exists(os.path.join(bindir, parts[0])):
162             parts[0] = os.path.join(bindir, parts[0])
163         line = " ".join(parts)
164         return line
165
166     def check_run(self, line):
167         line = self._make_cmdline(line)
168         p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
169         retcode = p.wait()
170         if retcode:
171             raise BlackboxProcessError(retcode, line, p.stdout.read(), p.stderr.read())
172
173     def check_output(self, line):
174         line = self._make_cmdline(line)
175         p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, close_fds=True)
176         retcode = p.wait()
177         if retcode:
178             raise BlackboxProcessError(retcode, line, p.stdout.read(), p.stderr.read())
179         return p.stdout.read()
180
181
182 def connect_samdb(samdb_url, lp=None, session_info=None, credentials=None,
183                   flags=0, ldb_options=None, ldap_only=False, global_schema=True):
184     """Create SamDB instance and connects to samdb_url database.
185
186     :param samdb_url: Url for database to connect to.
187     :param lp: Optional loadparm object
188     :param session_info: Optional session information
189     :param credentials: Optional credentials, defaults to anonymous.
190     :param flags: Optional LDB flags
191     :param ldap_only: If set, only remote LDAP connection will be created.
192     :param global_schema: Whether to use global schema.
193
194     Added value for tests is that we have a shorthand function
195     to make proper URL for ldb.connect() while using default
196     parameters for connection based on test environment
197     """
198     if not "://" in samdb_url:
199         if not ldap_only and os.path.isfile(samdb_url):
200             samdb_url = "tdb://%s" % samdb_url
201         else:
202             samdb_url = "ldap://%s" % samdb_url
203     # use 'paged_search' module when connecting remotely
204     if samdb_url.startswith("ldap://"):
205         ldb_options = ["modules:paged_searches"]
206     elif ldap_only:
207         raise AssertionError("Trying to connect to %s while remote "
208                              "connection is required" % samdb_url)
209
210     # set defaults for test environment
211     if lp is None:
212         lp = env_loadparm()
213     if session_info is None:
214         session_info = samba.auth.system_session(lp)
215     if credentials is None:
216         credentials = cmdline_credentials
217
218     return SamDB(url=samdb_url,
219                  lp=lp,
220                  session_info=session_info,
221                  credentials=credentials,
222                  flags=flags,
223                  options=ldb_options,
224                  global_schema=global_schema)
225
226
227 def connect_samdb_ex(samdb_url, lp=None, session_info=None, credentials=None,
228                      flags=0, ldb_options=None, ldap_only=False):
229     """Connects to samdb_url database
230
231     :param samdb_url: Url for database to connect to.
232     :param lp: Optional loadparm object
233     :param session_info: Optional session information
234     :param credentials: Optional credentials, defaults to anonymous.
235     :param flags: Optional LDB flags
236     :param ldap_only: If set, only remote LDAP connection will be created.
237     :return: (sam_db_connection, rootDse_record) tuple
238     """
239     sam_db = connect_samdb(samdb_url, lp, session_info, credentials,
240                            flags, ldb_options, ldap_only)
241     # fetch RootDse
242     res = sam_db.search(base="", expression="", scope=ldb.SCOPE_BASE,
243                         attrs=["*"])
244     return (sam_db, res[0])
245
246
247 def connect_samdb_env(env_url, env_username, env_password, lp=None):
248     """Connect to SamDB by getting URL and Credentials from environment
249
250     :param env_url: Environment variable name to get lsb url from
251     :param env_username: Username environment variable
252     :param env_password: Password environment variable
253     :return: sam_db_connection
254     """
255     samdb_url = env_get_var_value(env_url)
256     creds = credentials.Credentials()
257     if lp is None:
258         # guess Credentials parameters here. Otherwise workstation
259         # and domain fields are NULL and gencache code segfalts
260         lp = param.LoadParm()
261         creds.guess(lp)
262     creds.set_username(env_get_var_value(env_username))
263     creds.set_password(env_get_var_value(env_password))
264     return connect_samdb(samdb_url, credentials=creds, lp=lp)
265
266
267 def delete_force(samdb, dn):
268     try:
269         samdb.delete(dn)
270     except ldb.LdbError, (num, errstr):
271         assert num == ldb.ERR_NO_SUCH_OBJECT, "ldb.delete() failed: %s" % errstr