6a28c0de4c0684bce63736b9880edd50a293c4ce
[samba.git] / python / samba / tests / __init__.py
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2010
3 # Copyright (C) Stefan Metzmacher 2014,2015
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18
19 """Samba Python tests."""
20 import os
21 import tempfile
22 import traceback
23 import collections
24 import ldb
25 import samba
26 from samba import param
27 from samba import credentials
28 from samba.credentials import Credentials
29 import subprocess
30 import sys
31 import unittest
32 import re
33 from enum import IntEnum, unique
34 import samba.auth
35 import samba.gensec
36 import samba.dcerpc.base
37 from random import randint
38 from random import SystemRandom
39 from contextlib import contextmanager
40 import shutil
41 import string
42 try:
43     from samba.samdb import SamDB
44 except ImportError:
45     # We are built without samdb support,
46     # imitate it so that connect_samdb() can recover
47     def SamDB(*args, **kwargs):
48         return None
49
50 import samba.ndr
51 import samba.dcerpc.dcerpc
52 import samba.dcerpc.epmapper
53
54 from unittest import SkipTest
55
56
57 BINDIR = os.path.abspath(os.path.join(os.path.dirname(__file__),
58                                       "../../../../bin"))
59
60 HEXDUMP_FILTER = bytearray([x if ((len(repr(chr(x))) == 3) and (x < 127)) else ord('.') for x in range(256)])
61
62 LDB_ERR_LUT = {v: k for k, v in vars(ldb).items() if k.startswith('ERR_')}
63
64 RE_CAMELCASE = re.compile(r"([_\-])+")
65
66
67 def ldb_err(v):
68     if isinstance(v, ldb.LdbError):
69         v = v.args[0]
70
71     if v in LDB_ERR_LUT:
72         return LDB_ERR_LUT[v]
73
74     try:
75         return f"[{', '.join(LDB_ERR_LUT.get(x, x) for x in v)}]"
76     except TypeError as e:
77         print(e)
78     return v
79
80
81 def DynamicTestCase(cls):
82     cls.setUpDynamicTestCases()
83     return cls
84
85
86 class TestCase(unittest.TestCase):
87     """A Samba test case."""
88
89     # Re-implement addClassCleanup to support Python versions older than 3.8.
90     # Can be removed once these older Python versions are no longer needed.
91     if sys.version_info.major == 3 and sys.version_info.minor < 8:
92         _class_cleanups = []
93
94         @classmethod
95         def addClassCleanup(cls, function, *args, **kwargs):
96             cls._class_cleanups.append((function, args, kwargs))
97
98         @classmethod
99         def tearDownClass(cls):
100             teardown_exceptions = []
101
102             while cls._class_cleanups:
103                 function, args, kwargs = cls._class_cleanups.pop()
104                 try:
105                     function(*args, **kwargs)
106                 except Exception:
107                     teardown_exceptions.append(traceback.format_exc())
108
109             # ExceptionGroup would be better but requires Python 3.11
110             if teardown_exceptions:
111                 raise ValueError("tearDownClass failed:\n\n" +
112                                  "\n".join(teardown_exceptions))
113
114         @classmethod
115         def setUpClass(cls):
116             """
117             Call setUpTestData, ensure tearDownClass is called on exceptions.
118
119             This is only required on Python versions older than 3.8.
120             """
121             try:
122                 cls.setUpTestData()
123             except Exception:
124                 cls.tearDownClass()
125                 raise
126     else:
127         @classmethod
128         def setUpClass(cls):
129             """
130             setUpClass only needs to call setUpTestData.
131
132             On Python 3.8 and above unittest will always call tearDownClass,
133             even if an exception was raised in setUpClass.
134             """
135             cls.setUpTestData()
136
137     @classmethod
138     def setUpTestData(cls):
139         """Create class level test fixtures here."""
140         pass
141
142     @classmethod
143     def generate_dynamic_test(cls, fnname, suffix, *args, doc=None):
144         """
145         fnname is something like "test_dynamic_sum"
146         suffix is something like "1plus2"
147         argstr could be (1, 2)
148
149         This would generate a test case called
150         "test_dynamic_sum_1plus2(self)" that
151         calls
152         self._test_dynamic_sum_with_args(1, 2)
153         """
154         def fn(self):
155             getattr(self, "_%s_with_args" % fnname)(*args)
156         fn.__doc__ = doc
157         attr = "%s_%s" % (fnname, suffix)
158         if hasattr(cls, attr):
159             raise RuntimeError(f"Dynamic test {attr} already exists!")
160         setattr(cls, attr, fn)
161
162     @classmethod
163     def setUpDynamicTestCases(cls):
164         """This can be implemented in order to call cls.generate_dynamic_test()
165         In order to implement autogenerated testcase permutations.
166         """
167         msg = "%s needs setUpDynamicTestCases() if @DynamicTestCase is used!" % (cls)
168         raise NotImplementedError(msg)
169
170     def unique_name(self):
171         """Generate a unique name from within a test for creating objects.
172
173         Used to ensure each test generates uniquely named objects that don't
174         interfere with other tests.
175         """
176         # name of calling function
177         name = self.id().rsplit(".", 1)[1]
178
179         # remove test_ prefix
180         if name.startswith("test_"):
181             name = name[5:]
182
183         # finally, convert to camelcase
184         name = RE_CAMELCASE.sub(" ", name).title().replace(" ", "")
185         return "".join([name[0].lower(), name[1:]])
186
187     def setUp(self):
188         super().setUp()
189         test_debug_level = os.getenv("TEST_DEBUG_LEVEL")
190         if test_debug_level is not None:
191             test_debug_level = int(test_debug_level)
192             self._old_debug_level = samba.get_debug_level()
193             samba.set_debug_level(test_debug_level)
194             self.addCleanup(samba.set_debug_level, test_debug_level)
195
196     @classmethod
197     def get_loadparm(cls):
198         return env_loadparm()
199
200     def get_credentials(self):
201         return cmdline_credentials
202
203     @classmethod
204     def get_env_credentials(cls, *, lp, env_username, env_password,
205                             env_realm=None, env_domain=None):
206         creds = credentials.Credentials()
207
208         # guess Credentials parameters here. Otherwise, workstation
209         # and domain fields are NULL and gencache code segfaults
210         creds.guess(lp)
211         creds.set_username(env_get_var_value(env_username))
212         creds.set_password(env_get_var_value(env_password))
213
214         if env_realm is not None:
215             creds.set_realm(env_get_var_value(env_realm))
216
217         if env_domain is not None:
218             creds.set_domain(env_get_var_value(env_domain))
219
220         return creds
221
222     def get_creds_ccache_name(self):
223         creds = self.get_credentials()
224         ccache = creds.get_named_ccache(self.get_loadparm())
225         ccache_name = ccache.get_name()
226
227         return ccache_name
228
229     def hexdump(self, src):
230         N = 0
231         result = ''
232         is_string = isinstance(src, str)
233         while src:
234             ll = src[:8]
235             lr = src[8:16]
236             src = src[16:]
237             if is_string:
238                 hl = ' '.join(["%02X" % ord(x) for x in ll])
239                 hr = ' '.join(["%02X" % ord(x) for x in lr])
240                 ll = ll.translate(HEXDUMP_FILTER)
241                 lr = lr.translate(HEXDUMP_FILTER)
242             else:
243                 hl = ' '.join(["%02X" % x for x in ll])
244                 hr = ' '.join(["%02X" % x for x in lr])
245                 ll = ll.translate(HEXDUMP_FILTER).decode('utf8')
246                 lr = lr.translate(HEXDUMP_FILTER).decode('utf8')
247             result += "[%04X] %-*s  %-*s  %s %s\n" % (N, 8 * 3, hl, 8 * 3, hr, ll, lr)
248             N += 16
249         return result
250
251     def insta_creds(self, template=None, username=None, userpass=None, kerberos_state=None):
252
253         if template is None:
254             raise ValueError("you need to supply a Credentials template")
255
256         if username is not None and userpass is None:
257             raise ValueError(
258                 "you cannot set creds username without setting a password")
259
260         if username is None:
261             assert userpass is None
262
263             username = template.get_username()
264             userpass = template.get_password()
265
266         simple_bind_dn = template.get_bind_dn()
267
268         if kerberos_state is None:
269             kerberos_state = template.get_kerberos_state()
270
271         # get a copy of the global creds or the passed in creds
272         c = Credentials()
273         c.set_username(username)
274         c.set_password(userpass)
275         c.set_domain(template.get_domain())
276         c.set_realm(template.get_realm())
277         c.set_workstation(template.get_workstation())
278         c.set_gensec_features(c.get_gensec_features()
279                               | samba.gensec.FEATURE_SEAL)
280         c.set_kerberos_state(kerberos_state)
281         if simple_bind_dn:
282             c.set_bind_dn(simple_bind_dn)
283         return c
284
285     def assertStringsEqual(self, a, b, msg=None, strip=False):
286         """Assert equality between two strings and highlight any differences.
287         If strip is true, leading and trailing whitespace is ignored."""
288         if strip:
289             a = a.strip()
290             b = b.strip()
291
292         if a != b:
293             sys.stderr.write("The strings differ %s(lengths %d vs %d); "
294                              "a diff follows\n"
295                              % ('when stripped ' if strip else '',
296                                 len(a), len(b),
297                                 ))
298
299             from difflib import unified_diff
300             diff = unified_diff(a.splitlines(True),
301                                 b.splitlines(True),
302                                 'a', 'b')
303             for line in diff:
304                 sys.stderr.write(line)
305
306             self.fail(msg)
307
308     def assertRaisesLdbError(self, errcode, message, f, *args, **kwargs):
309         """Assert a function raises a particular LdbError."""
310         if message is None:
311             message = f"{f.__name__}(*{args}, **{kwargs})"
312         try:
313             f(*args, **kwargs)
314         except ldb.LdbError as e:
315             (num, msg) = e.args
316             if isinstance(errcode, collections.abc.Container):
317                 found = num in errcode
318             else:
319                 found = num == errcode
320             if not found:
321                 lut = {v: k for k, v in vars(ldb).items()
322                        if k.startswith('ERR_') and isinstance(v, int)}
323                 if isinstance(errcode, collections.abc.Container):
324                     errcode_name = ' '.join(lut.get(x) for x in errcode)
325                 else:
326                     errcode_name = lut.get(errcode)
327                 self.fail(f"{message}, expected "
328                           f"LdbError {errcode_name}, {errcode} "
329                           f"got {lut.get(num)} ({num}) "
330                           f"{msg}")
331         else:
332             lut = {v: k for k, v in vars(ldb).items()
333                    if k.startswith('ERR_') and isinstance(v, int)}
334             if isinstance(errcode, collections.abc.Container):
335                 errcode_name = ' '.join(lut.get(x) for x in errcode)
336             else:
337                 errcode_name = lut.get(errcode)
338             self.fail("%s, expected "
339                       "LdbError %s, (%s) "
340                       "but we got success" % (message,
341                                               errcode_name,
342                                               errcode))
343
344
345 class LdbTestCase(TestCase):
346     """Trivial test case for running tests against a LDB."""
347
348     def setUp(self):
349         super().setUp()
350         self.tempfile = tempfile.NamedTemporaryFile(delete=False)
351         self.filename = self.tempfile.name
352         self.ldb = samba.Ldb(self.filename)
353
354     def set_modules(self, modules=None):
355         """Change the modules for this Ldb."""
356         if modules is None:
357             modules = []
358         m = ldb.Message()
359         m.dn = ldb.Dn(self.ldb, "@MODULES")
360         m["@LIST"] = ",".join(modules)
361         self.ldb.add(m)
362         self.ldb = samba.Ldb(self.filename)
363
364
365 class TestCaseInTempDir(TestCase):
366
367     def setUp(self):
368         super().setUp()
369         self.tempdir = tempfile.mkdtemp()
370         self.addCleanup(self._remove_tempdir)
371
372     def _remove_tempdir(self):
373         # Note asserting here is treated as an error rather than a test failure
374         self.assertEqual([], os.listdir(self.tempdir))
375         os.rmdir(self.tempdir)
376         self.tempdir = None
377
378     @contextmanager
379     def mktemp(self):
380         """Yield a temporary filename in the tempdir."""
381         try:
382             fd, fn = tempfile.mkstemp(dir=self.tempdir)
383             yield fn
384         finally:
385             try:
386                 os.close(fd)
387                 os.unlink(fn)
388             except (OSError, IOError) as e:
389                 print("could not remove temporary file: %s" % e,
390                       file=sys.stderr)
391
392     def rm_files(self, *files, allow_missing=False, _rm=os.remove):
393         """Remove listed files from the temp directory.
394
395         The files must be true files in the directory itself, not in
396         sub-directories.
397
398         By default a non-existent file will cause a test failure (or
399         error if used outside a test in e.g. tearDown), but if
400         allow_missing is true, the absence will be ignored.
401         """
402         for f in files:
403             path = os.path.join(self.tempdir, f)
404
405             # os.path.join will happily step out of the tempdir,
406             # so let's just check.
407             if os.path.dirname(path) != self.tempdir:
408                 raise ValueError(f"{path} might be outside {self.tempdir}")
409
410             try:
411                 _rm(path)
412             except FileNotFoundError as e:
413                 if not allow_missing:
414                     raise AssertionError(f"{f} not in {self.tempdir}: {e}")
415
416                 print(f"{f} not in {self.tempdir}")
417
418     def rm_dirs(self, *dirs, allow_missing=False):
419         """Remove listed directories from temp directory.
420
421         This works like rm_files, but only removes directories,
422         including their contents.
423         """
424         self.rm_files(*dirs, allow_missing=allow_missing, _rm=shutil.rmtree)
425
426
427 def env_loadparm():
428     lp = param.LoadParm()
429     try:
430         lp.load(os.environ["SMB_CONF_PATH"])
431     except KeyError:
432         raise KeyError("SMB_CONF_PATH not set")
433     return lp
434
435
436 def env_get_var_value(var_name, allow_missing=False):
437     """Returns value for variable in os.environ
438
439     Function throws AssertionError if variable is undefined.
440     Unit-test based python tests require certain input params
441     to be set in environment, otherwise they can't be run
442     """
443     if allow_missing:
444         if var_name not in os.environ.keys():
445             return None
446     assert var_name in os.environ.keys(), "Please supply %s in environment" % var_name
447     return os.environ[var_name]
448
449
450 cmdline_credentials = None
451
452
453 class RpcInterfaceTestCase(TestCase):
454     """DCE/RPC Test case."""
455
456
457 class BlackboxProcessError(Exception):
458     """This is raised when check_output() process returns a non-zero exit status
459
460     Exception instance should contain the exact exit code (S.returncode),
461     command line (S.cmd), process output (S.stdout) and process error stream
462     (S.stderr)
463     """
464
465     def __init__(self, returncode, cmd, stdout, stderr, msg=None):
466         self.returncode = returncode
467         if isinstance(cmd, list):
468             self.cmd = ' '.join(cmd)
469             self.shell = False
470         else:
471             self.cmd = cmd
472             self.shell = True
473         self.stdout = stdout
474         self.stderr = stderr
475         self.msg = msg
476
477     def __str__(self):
478         s = ("Command '%s'; shell %s; exit status %d; "
479              "stdout: '%s'; stderr: '%s'" %
480              (self.cmd, self.shell, self.returncode, self.stdout, self.stderr))
481         if self.msg is not None:
482             s = "%s; message: %s" % (s, self.msg)
483
484         return s
485
486
487 class BlackboxTestCase(TestCaseInTempDir):
488     """Base test case for blackbox tests."""
489
490     @staticmethod
491     def _make_cmdline(line):
492         """Expand the called script into a fully resolved path in the bin
493         directory."""
494         if isinstance(line, list):
495             parts = line
496         else:
497             parts = line.split(" ", 1)
498         cmd = parts[0]
499         exe = os.path.join(BINDIR, cmd)
500
501         python_cmds = ["samba-tool",
502                        "samba_dnsupdate",
503                        "samba_upgradedns",
504                        "script/traffic_replay",
505                        "script/traffic_learner"]
506
507         if os.path.exists(exe):
508             parts[0] = exe
509         if cmd in python_cmds and os.getenv("PYTHON", False):
510             parts.insert(0, os.environ["PYTHON"])
511
512         if not isinstance(line, list):
513             line = " ".join(parts)
514
515         return line
516
517     @classmethod
518     def check_run(cls, line, msg=None):
519         cls.check_exit_code(line, 0, msg=msg)
520
521     @classmethod
522     def check_exit_code(cls, line, expected, msg=None):
523         line = cls._make_cmdline(line)
524         use_shell = not isinstance(line, list)
525         p = subprocess.Popen(line,
526                              stdout=subprocess.PIPE,
527                              stderr=subprocess.PIPE,
528                              shell=use_shell)
529         stdoutdata, stderrdata = p.communicate()
530         retcode = p.returncode
531         if retcode != expected:
532             if msg is None:
533                 msg = "expected return code %s; got %s" % (expected, retcode)
534             raise BlackboxProcessError(retcode,
535                                        line,
536                                        stdoutdata,
537                                        stderrdata,
538                                        msg)
539         return stdoutdata
540
541     @classmethod
542     def check_output(cls, line):
543         use_shell = not isinstance(line, list)
544         line = cls._make_cmdline(line)
545         p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
546                              shell=use_shell, close_fds=True)
547         stdoutdata, stderrdata = p.communicate()
548         retcode = p.returncode
549         if retcode:
550             raise BlackboxProcessError(retcode, line, stdoutdata, stderrdata)
551         return stdoutdata
552
553     #
554     # Run a command without checking the return code, returns the tuple
555     # (ret, stdout, stderr)
556     # where ret is the return code
557     #       stdout is a string containing the commands stdout
558     #       stderr is a string containing the commands stderr
559     @classmethod
560     def run_command(cls, line):
561         line = cls._make_cmdline(line)
562         use_shell = not isinstance(line, list)
563         p = subprocess.Popen(line,
564                              stdout=subprocess.PIPE,
565                              stderr=subprocess.PIPE,
566                              shell=use_shell)
567         stdoutdata, stderrdata = p.communicate()
568         retcode = p.returncode
569         return (retcode, stdoutdata.decode('UTF8'), stderrdata.decode('UTF8'))
570
571     # Generate a random password that can be safely  passed on the command line
572     # i.e. it does not contain any shell meta characters.
573     def random_password(self, count=32):
574         password = SystemRandom().choice(string.ascii_uppercase)
575         password += SystemRandom().choice(string.digits)
576         password += SystemRandom().choice(string.ascii_lowercase)
577         password += ''.join(SystemRandom().choice(string.ascii_uppercase +
578                             string.ascii_lowercase +
579                             string.digits) for x in range(count - 3))
580         return password
581
582
583 def connect_samdb(samdb_url, *, lp=None, session_info=None, credentials=None,
584                   flags=0, ldb_options=None, ldap_only=False, global_schema=True):
585     """Create SamDB instance and connects to samdb_url database.
586
587     :param samdb_url: Url for database to connect to.
588     :param lp: Optional loadparm object
589     :param session_info: Optional session information
590     :param credentials: Optional credentials, defaults to anonymous.
591     :param flags: Optional LDB flags
592     :param ldap_only: If set, only remote LDAP connection will be created.
593     :param global_schema: Whether to use global schema.
594
595     Added value for tests is that we have a shorthand function
596     to make proper URL for ldb.connect() while using default
597     parameters for connection based on test environment
598     """
599     if "://" not in samdb_url:
600         if not ldap_only and os.path.isfile(samdb_url):
601             samdb_url = "tdb://%s" % samdb_url
602         else:
603             samdb_url = "ldap://%s" % samdb_url
604     # use 'paged_search' module when connecting remotely
605     if samdb_url.startswith("ldap://"):
606         ldb_options = ["modules:paged_searches"]
607     elif ldap_only:
608         raise AssertionError("Trying to connect to %s while remote "
609                              "connection is required" % samdb_url)
610
611     # set defaults for test environment
612     if lp is None:
613         lp = env_loadparm()
614     if session_info is None:
615         session_info = samba.auth.system_session(lp)
616     if credentials is None:
617         credentials = cmdline_credentials
618
619     return SamDB(url=samdb_url,
620                  lp=lp,
621                  session_info=session_info,
622                  credentials=credentials,
623                  flags=flags,
624                  options=ldb_options,
625                  global_schema=global_schema)
626
627
628 def connect_samdb_ex(samdb_url, *, lp=None, session_info=None, credentials=None,
629                      flags=0, ldb_options=None, ldap_only=False):
630     """Connects to samdb_url database
631
632     :param samdb_url: Url for database to connect to.
633     :param lp: Optional loadparm object
634     :param session_info: Optional session information
635     :param credentials: Optional credentials, defaults to anonymous.
636     :param flags: Optional LDB flags
637     :param ldap_only: If set, only remote LDAP connection will be created.
638     :return: (sam_db_connection, rootDse_record) tuple
639     """
640     sam_db = connect_samdb(samdb_url, lp=lp, session_info=session_info,
641                            credentials=credentials, flags=flags,
642                            ldb_options=ldb_options, ldap_only=ldap_only)
643     # fetch RootDse
644     res = sam_db.search(base="", expression="", scope=ldb.SCOPE_BASE,
645                         attrs=["*"])
646     return (sam_db, res[0])
647
648
649 def connect_samdb_env(env_url, env_username, env_password, lp=None):
650     """Connect to SamDB by getting URL and Credentials from environment
651
652     :param env_url: Environment variable name to get lsb url from
653     :param env_username: Username environment variable
654     :param env_password: Password environment variable
655     :return: sam_db_connection
656     """
657     samdb_url = env_get_var_value(env_url)
658     creds = credentials.Credentials()
659     if lp is None:
660         # guess Credentials parameters here. Otherwise workstation
661         # and domain fields are NULL and gencache code segfaults
662         lp = param.LoadParm()
663         creds.guess(lp)
664     creds.set_username(env_get_var_value(env_username))
665     creds.set_password(env_get_var_value(env_password))
666     return connect_samdb(samdb_url, credentials=creds, lp=lp)
667
668
669 def delete_force(samdb, dn, **kwargs):
670     try:
671         samdb.delete(dn, **kwargs)
672     except ldb.LdbError as error:
673         (num, errstr) = error.args
674         assert num == ldb.ERR_NO_SUCH_OBJECT, "ldb.delete() failed: %s" % errstr
675
676
677 def create_test_ou(samdb, name):
678     """Creates a unique OU for the test"""
679
680     # Add some randomness to the test OU. Replication between the testenvs is
681     # constantly happening in the background. Deletion of the last test's
682     # objects can be slow to replicate out. So the OU created by a previous
683     # testenv may still exist at the point that tests start on another testenv.
684     rand = randint(1, 10000000)
685     dn = ldb.Dn(samdb, "OU=%s%d,%s" % (name, rand, samdb.get_default_basedn()))
686     samdb.add({"dn": dn, "objectclass": "organizationalUnit"})
687     return dn
688
689
690 @unique
691 class OptState(IntEnum):
692     NOOPT = 0
693     HYPHEN1 = 1
694     HYPHEN2 = 2
695     NAME = 3
696
697
698 def parse_help_consistency(out,
699                            options_start=None,
700                            options_end=None,
701                            optmap=None,
702                            max_leading_spaces=10):
703     if options_start is None:
704         opt_lines = []
705     else:
706         opt_lines = None
707
708     for raw_line in out.split('\n'):
709         line = raw_line.lstrip()
710         if line == '':
711             continue
712         if opt_lines is None:
713             if line == options_start:
714                 opt_lines = []
715             else:
716                 continue
717         if len(line) < len(raw_line) - max_leading_spaces:
718             # for the case where we have:
719             #
720             #  --foo        frobnicate or barlify depending on
721             #               --bar option.
722             #
723             # where we want to ignore the --bar.
724             continue
725         if line[0] == '-':
726             opt_lines.append(line)
727         if line == options_end:
728             break
729
730     if opt_lines is None:
731         # No --help options is not an error in *this* test.
732         return
733
734     is_longname_char = re.compile(r'^[\w-]$').match
735     for line in opt_lines:
736         state = OptState.NOOPT
737         name = None
738         prev = ' '
739         for c in line:
740             if state == OptState.NOOPT:
741                 if c == '-' and prev.isspace():
742                     state = OptState.HYPHEN1
743                 prev = c
744                 continue
745             if state == OptState.HYPHEN1:
746                 if c.isalnum():
747                     name = '-' + c
748                     state = OptState.NAME
749                 elif c == '-':
750                     state = OptState.HYPHEN2
751                 continue
752             if state == OptState.HYPHEN2:
753                 if c.isalnum():
754                     name = '--' + c
755                     state = OptState.NAME
756                 else:  # WTF, perhaps '--' ending option list.
757                     state = OptState.NOOPT
758                     prev = c
759                 continue
760             if state == OptState.NAME:
761                 if is_longname_char(c):
762                     name += c
763                 else:
764                     optmap.setdefault(name, []).append(line)
765                     state = OptState.NOOPT
766                     prev = c
767
768         if state == OptState.NAME:
769             optmap.setdefault(name, []).append(line)
770
771
772 def check_help_consistency(out,
773                            options_start=None,
774                            options_end=None):
775     """Ensure that options are not repeated and redefined in --help
776     output.
777
778     Returns None if everything is OK, otherwise a string indicating
779     the problems.
780
781     If options_start and/or options_end are provided, only the bit in
782     the output between these two lines is considered. For example,
783     with samba-tool,
784
785     options_start='Options:', options_end='Available subcommands:'
786
787     will prevent the test looking at the preamble which may contain
788     examples using options.
789     """
790     # Silly test, you might think, but this happens
791     optmap = {}
792     parse_help_consistency(out,
793                            options_start,
794                            options_end,
795                            optmap)
796
797     errors = []
798     for k, values in sorted(optmap.items()):
799         if len(values) > 1:
800             for v in values:
801                 errors.append("%s: %s" % (k, v))
802
803     if errors:
804         return "\n".join(errors)
805
806
807 def get_env_dir(key):
808     """A helper to pull a directory name from the environment, used in
809     some tests that optionally write e.g. fuzz seeds into a directory
810     named in an environment variable.
811     """
812     dir = os.environ.get(key)
813     if dir is None:
814         return None
815
816     if not os.path.isdir(dir):
817         raise ValueError(
818             f"{key} should name an existing directory (got '{dir}')")
819
820     return dir