1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2010
3 # Copyright (C) Stefan Metzmacher 2014,2015
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
15 # You should have received a copy of the GNU General Public License
16 # along with this program. If not, see <http://www.gnu.org/licenses/>.
19 """Samba Python tests."""
25 from samba import param
26 from samba import credentials
27 from samba.credentials import Credentials
28 from samba import gensec
37 import samba.dcerpc.base
38 from samba.compat import PY3, text_type
39 from samba.compat import string_types
40 from random import randint
42 from samba.samdb import SamDB
44 # We are built without samdb support,
45 # imitate it so that connect_samdb() can recover
46 def SamDB(*args, **kwargs):
50 import samba.dcerpc.dcerpc
51 import samba.dcerpc.epmapper
54 from unittest import SkipTest
56 class SkipTest(Exception):
# Translation table used by hexdump(): byte values below 127 whose repr()
# is a single quoted character (e.g. "'A'") map to themselves; every other
# byte (control characters, characters repr() escapes such as backslash,
# and all values from 127 upward) maps to '.'.
HEXDUMP_FILTER = bytearray(
    code if code < 127 and len(repr(chr(code))) == 3 else ord('.')
    for code in range(256)
)
62 class TestCase(unittest.TestCase):
63 """A Samba test case."""
66 super(TestCase, self).setUp()
67 test_debug_level = os.getenv("TEST_DEBUG_LEVEL")
68 if test_debug_level is not None:
69 test_debug_level = int(test_debug_level)
70 self._old_debug_level = samba.get_debug_level()
71 samba.set_debug_level(test_debug_level)
72 self.addCleanup(samba.set_debug_level, test_debug_level)
74 def get_loadparm(self):
def get_credentials(self):
    """Return the module-level ``cmdline_credentials`` object shared by tests."""
    shared_creds = cmdline_credentials
    return shared_creds
80 def get_creds_ccache_name(self):
81 creds = self.get_credentials()
82 ccache = creds.get_named_ccache(self.get_loadparm())
83 ccache_name = ccache.get_name()
87 def hexdump(self, src):
90 is_string = isinstance(src, string_types)
96 hl = ' '.join(["%02X" % ord(x) for x in ll])
97 hr = ' '.join(["%02X" % ord(x) for x in lr])
98 ll = ll.translate(HEXDUMP_FILTER)
99 lr = lr.translate(HEXDUMP_FILTER)
101 hl = ' '.join(["%02X" % x for x in ll])
102 hr = ' '.join(["%02X" % x for x in lr])
103 ll = ll.translate(HEXDUMP_FILTER).decode('utf8')
104 lr = lr.translate(HEXDUMP_FILTER).decode('utf8')
105 result += "[%04X] %-*s %-*s %s %s\n" % (N, 8 * 3, hl, 8 * 3, hr, ll, lr)
109 def insta_creds(self, template=None, username=None, userpass=None, kerberos_state=None):
112 assert template is not None
114 if username is not None:
115 assert userpass is not None
118 assert userpass is None
120 username = template.get_username()
121 userpass = template.get_password()
123 if kerberos_state is None:
124 kerberos_state = template.get_kerberos_state()
126 # get a copy of the global creds or the passed-in creds
128 c.set_username(username)
129 c.set_password(userpass)
130 c.set_domain(template.get_domain())
131 c.set_realm(template.get_realm())
132 c.set_workstation(template.get_workstation())
133 c.set_gensec_features(c.get_gensec_features()
134 | gensec.FEATURE_SEAL)
135 c.set_kerberos_state(kerberos_state)
140 # These functions didn't exist before Python2.7:
141 if sys.version_info < (2, 7):
def skipTest(self, reason):
    """Abort the running test by raising SkipTest carrying *reason*.

    Fallback for interpreters older than Python 2.7.
    """
    skip_error = SkipTest(reason)
    raise skip_error
def assertIn(self, member, container, msg=None):
    """Fail unless *member* is found in *container* (pre-2.7 fallback)."""
    is_member = member in container
    self.assertTrue(is_member, msg)
def assertIs(self, a, b, msg=None):
    """Fail unless *a* and *b* are the very same object (pre-2.7 fallback)."""
    same_object = a is b
    self.assertTrue(same_object, msg)
def assertIsNot(self, a, b, msg=None):
    """Fail if *a* and *b* are the very same object (pre-2.7 fallback)."""
    self.assertTrue(not (a is b), msg)
def assertIsNotNone(self, a, msg=None):
    """Fail if *a* is None (pre-2.7 fallback).

    :param a: Value to check.
    :param msg: Optional failure message, forwarded to assertTrue() like
        every other fallback assertion helper does. The previous version
        accepted but silently dropped it.
    """
    self.assertTrue(a is not None, msg)
def assertIsInstance(self, a, b, msg=None):
    """Fail unless isinstance(a, b) holds (pre-2.7 fallback)."""
    matches_type = isinstance(a, b)
    self.assertTrue(matches_type, msg)
def assertIsNone(self, a, msg=None):
    """Fail unless *a* is None (pre-2.7 fallback)."""
    is_none = a is None
    self.assertTrue(is_none, msg)
def assertGreater(self, a, b, msg=None):
    """Fail unless a > b (pre-2.7 fallback)."""
    outcome = a > b
    self.assertTrue(outcome, msg)
def assertGreaterEqual(self, a, b, msg=None):
    """Fail unless a >= b (pre-2.7 fallback)."""
    outcome = a >= b
    self.assertTrue(outcome, msg)
def assertLess(self, a, b, msg=None):
    """Fail unless a < b (pre-2.7 fallback)."""
    outcome = a < b
    self.assertTrue(outcome, msg)
def assertLessEqual(self, a, b, msg=None):
    """Fail unless a <= b (pre-2.7 fallback)."""
    outcome = a <= b
    self.assertTrue(outcome, msg)
177 def addCleanup(self, fn, *args, **kwargs):
178 self._cleanups = getattr(self, "_cleanups", []) + [
181 def assertRegexpMatches(self, text, regex, msg=None):
182 # PY3 note: Python 3 will never see this, but we use
183 # text_type for the benefit of linters.
184 if isinstance(regex, (str, text_type)):
185 regex = re.compile(regex)
186 if not regex.search(text):
189 def _addSkip(self, result, reason):
190 addSkip = getattr(result, 'addSkip', None)
191 if addSkip is not None:
192 addSkip(self, reason)
194 warnings.warn("TestResult has no addSkip method, skips not reported",
196 result.addSuccess(self)
198 def run(self, result=None):
199 if result is None: result = self.defaultTestResult()
200 result.startTest(self)
201 testMethod = getattr(self, self._testMethodName)
205 except SkipTest as e:
206 self._addSkip(result, str(e))
208 except KeyboardInterrupt:
211 result.addError(self, self._exc_info())
218 except SkipTest as e:
219 self._addSkip(result, str(e))
221 except self.failureException:
222 result.addFailure(self, self._exc_info())
223 except KeyboardInterrupt:
226 result.addError(self, self._exc_info())
230 except SkipTest as e:
231 self._addSkip(result, str(e))
232 except KeyboardInterrupt:
235 result.addError(self, self._exc_info())
238 for (fn, args, kwargs) in reversed(getattr(self, "_cleanups", [])):
240 if ok: result.addSuccess(self)
242 result.stopTest(self)
244 def assertStringsEqual(self, a, b, msg=None, strip=False):
245 """Assert equality between two strings and highlight any differences.
246 If strip is true, leading and trailing whitespace is ignored."""
252 sys.stderr.write("The strings differ %s(lengths %d vs %d); "
254 % ('when stripped ' if strip else '',
258 from difflib import unified_diff
259 diff = unified_diff(a.splitlines(True),
263 sys.stderr.write(line)
268 class LdbTestCase(TestCase):
269 """Trivial test case for running tests against a LDB."""
272 super(LdbTestCase, self).setUp()
273 self.tempfile = tempfile.NamedTemporaryFile(delete=False)
274 self.filename = self.tempfile.name
275 self.ldb = samba.Ldb(self.filename)
277 def set_modules(self, modules=[]):
278 """Change the modules for this Ldb."""
280 m.dn = ldb.Dn(self.ldb, "@MODULES")
281 m["@LIST"] = ",".join(modules)
283 self.ldb = samba.Ldb(self.filename)
286 class TestCaseInTempDir(TestCase):
289 super(TestCaseInTempDir, self).setUp()
290 self.tempdir = tempfile.mkdtemp()
291 self.addCleanup(self._remove_tempdir)
293 def _remove_tempdir(self):
294 self.assertEquals([], os.listdir(self.tempdir))
295 os.rmdir(self.tempdir)
300 lp = param.LoadParm()
302 lp.load(os.environ["SMB_CONF_PATH"])
304 raise KeyError("SMB_CONF_PATH not set")
308 def env_get_var_value(var_name, allow_missing=False):
309 """Returns value for variable in os.environ
311 Function throws AssertionError if variable is not defined.
312 Unit-test based python tests require certain input params
313 to be set in environment, otherwise they can't be run
316 if var_name not in os.environ.keys():
318 assert var_name in os.environ.keys(), "Please supply %s in environment" % var_name
319 return os.environ[var_name]
# Credentials shared by the test helpers: TestCase.get_credentials() returns
# this object and connect_samdb() falls back to it when no credentials are
# passed. It starts out as None here; presumably the test runner replaces it
# before tests execute (TODO confirm).
cmdline_credentials = None
class RpcInterfaceTestCase(TestCase):
    """Base class for DCE/RPC interface test cases; adds nothing to TestCase."""
class ValidNetbiosNameTests(TestCase):
    """Checks for samba.valid_netbios_name()."""

    def test_valid(self):
        # A short, plain upper-case name is accepted.
        candidate = "FOO"
        self.assertTrue(samba.valid_netbios_name(candidate))

    def test_too_long(self):
        # 30 characters: this test expects that length to be rejected.
        candidate = "FOO" * 10
        self.assertFalse(samba.valid_netbios_name(candidate))

    def test_invalid_characters(self):
        # A leading '*' is expected to be rejected.
        candidate = "*BLA"
        self.assertFalse(samba.valid_netbios_name(candidate))
341 class BlackboxProcessError(Exception):
342 """This is raised when check_output() process returns a non-zero exit status
344 Exception instance should contain the exact exit code (S.returncode),
345 command line (S.cmd), process output (S.stdout) and process error stream
349 def __init__(self, returncode, cmd, stdout, stderr, msg=None):
350 self.returncode = returncode
357 s = ("Command '%s'; exit status %d; stdout: '%s'; stderr: '%s'" %
358 (self.cmd, self.returncode, self.stdout, self.stderr))
359 if self.msg is not None:
360 s = "%s; message: %s" % (s, self.msg)
365 class BlackboxTestCase(TestCaseInTempDir):
366 """Base test case for blackbox tests."""
368 def _make_cmdline(self, line):
369 bindir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../bin"))
370 parts = line.split(" ")
371 if os.path.exists(os.path.join(bindir, parts[0])):
372 parts[0] = os.path.join(bindir, parts[0])
373 line = " ".join(parts)
376 def check_run(self, line, msg=None):
377 self.check_exit_code(line, 0, msg=msg)
379 def check_exit_code(self, line, expected, msg=None):
380 line = self._make_cmdline(line)
381 p = subprocess.Popen(line,
382 stdout=subprocess.PIPE,
383 stderr=subprocess.PIPE,
385 stdoutdata, stderrdata = p.communicate()
386 retcode = p.returncode
387 if retcode != expected:
388 raise BlackboxProcessError(retcode,
394 def check_output(self, line):
395 line = self._make_cmdline(line)
396 p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, close_fds=True)
397 stdoutdata, stderrdata = p.communicate()
398 retcode = p.returncode
400 raise BlackboxProcessError(retcode, line, stdoutdata, stderrdata)
404 def connect_samdb(samdb_url, lp=None, session_info=None, credentials=None,
405 flags=0, ldb_options=None, ldap_only=False, global_schema=True):
406 """Create a SamDB instance and connect to the samdb_url database.
408 :param samdb_url: Url for database to connect to.
409 :param lp: Optional loadparm object
410 :param session_info: Optional session information
411 :param credentials: Optional credentials, defaults to cmdline_credentials.
412 :param flags: Optional LDB flags
413 :param ldap_only: If set, only remote LDAP connection will be created.
414 :param global_schema: Whether to use global schema.
416 Added value for tests is that we have a shorthand function
417 to make proper URL for ldb.connect() while using default
418 parameters for connection based on test environment
420 if not "://" in samdb_url:
421 if not ldap_only and os.path.isfile(samdb_url):
422 samdb_url = "tdb://%s" % samdb_url
424 samdb_url = "ldap://%s" % samdb_url
425 # use 'paged_search' module when connecting remotely
426 if samdb_url.startswith("ldap://"):
427 ldb_options = ["modules:paged_searches"]
429 raise AssertionError("Trying to connect to %s while remote "
430 "connection is required" % samdb_url)
432 # set defaults for test environment
435 if session_info is None:
436 session_info = samba.auth.system_session(lp)
437 if credentials is None:
438 credentials = cmdline_credentials
440 return SamDB(url=samdb_url,
442 session_info=session_info,
443 credentials=credentials,
446 global_schema=global_schema)
449 def connect_samdb_ex(samdb_url, lp=None, session_info=None, credentials=None,
450 flags=0, ldb_options=None, ldap_only=False):
451 """Connects to samdb_url database
453 :param samdb_url: Url for database to connect to.
454 :param lp: Optional loadparm object
455 :param session_info: Optional session information
456 :param credentials: Optional credentials, defaults to cmdline_credentials.
457 :param flags: Optional LDB flags
458 :param ldap_only: If set, only remote LDAP connection will be created.
459 :return: (sam_db_connection, rootDse_record) tuple
461 sam_db = connect_samdb(samdb_url, lp, session_info, credentials,
462 flags, ldb_options, ldap_only)
464 res = sam_db.search(base="", expression="", scope=ldb.SCOPE_BASE,
466 return (sam_db, res[0])
469 def connect_samdb_env(env_url, env_username, env_password, lp=None):
470 """Connect to SamDB by getting URL and Credentials from environment
472 :param env_url: Environment variable name to get ldb url from
473 :param env_username: Username environment variable
474 :param env_password: Password environment variable
475 :return: sam_db_connection
477 samdb_url = env_get_var_value(env_url)
478 creds = credentials.Credentials()
480 # guess Credentials parameters here. Otherwise workstation
481 # and domain fields are NULL and gencache code segfaults
482 lp = param.LoadParm()
484 creds.set_username(env_get_var_value(env_username))
485 creds.set_password(env_get_var_value(env_password))
486 return connect_samdb(samdb_url, credentials=creds, lp=lp)
489 def delete_force(samdb, dn, **kwargs):
491 samdb.delete(dn, **kwargs)
492 except ldb.LdbError as error:
493 (num, errstr) = error.args
494 assert num == ldb.ERR_NO_SUCH_OBJECT, "ldb.delete() failed: %s" % errstr
497 def create_test_ou(samdb, name):
498 """Creates a unique OU for the test"""
500 # Add some randomness to the test OU. Replication between the testenvs is
501 # constantly happening in the background. Deletion of the last test's
502 # objects can be slow to replicate out. So the OU created by a previous
503 # testenv may still exist at the point that tests start on another testenv.
504 rand = randint(1, 10000000)
505 dn = ldb.Dn(samdb, "OU=%s%d,%s" % (name, rand, samdb.get_default_basedn()))
506 samdb.add({"dn": dn, "objectclass": "organizationalUnit"})