3ea47d80e7fc9f4999096363ed9c451a170d23b5
[metze/samba/wip.git] / lib / testtools / testtools / matchers.py
1 # Copyright (c) 2009-2011 testtools developers. See LICENSE for details.
2
3 """Matchers, a way to express complex assertions outside the testcase.
4
5 Inspired by 'hamcrest'.
6
7 Matcher provides the abstract API that all matchers need to implement.
8
9 Bundled matchers are listed in __all__: a list can be obtained by running
10 $ python -c 'import testtools.matchers; print testtools.matchers.__all__'
11 """
12
13 __metaclass__ = type
14 __all__ = [
15     'AfterPreprocessing',
16     'AllMatch',
17     'Annotate',
18     'Contains',
19     'DirExists',
20     'DocTestMatches',
21     'EndsWith',
22     'Equals',
23     'FileContains',
24     'FileExists',
25     'GreaterThan',
26     'HasPermissions',
27     'Is',
28     'IsInstance',
29     'KeysEqual',
30     'LessThan',
31     'MatchesAll',
32     'MatchesAny',
33     'MatchesException',
34     'MatchesListwise',
35     'MatchesPredicate',
36     'MatchesRegex',
37     'MatchesSetwise',
38     'MatchesStructure',
39     'NotEquals',
40     'Not',
41     'PathExists',
42     'Raises',
43     'raises',
44     'SamePath',
45     'StartsWith',
46     'TarballContains',
47     ]
48
49 import doctest
50 import operator
51 from pprint import pformat
52 import re
53 import os
54 import sys
55 import tarfile
56 import types
57
58 from testtools.compat import (
59     classtypes,
60     _error_repr,
61     isbaseexception,
62     _isbytes,
63     istext,
64     str_is_unicode,
65     text_repr
66     )
67
68
69 class Matcher(object):
70     """A pattern matcher.
71
72     A Matcher must implement match and __str__ to be used by
73     testtools.TestCase.assertThat. Matcher.match(thing) returns None when
74     thing is completely matched, and a Mismatch object otherwise.
75
76     Matchers can be useful outside of test cases, as they are simply a
77     pattern matching language expressed as objects.
78
79     testtools.matchers is inspired by hamcrest, but is pythonic rather than
80     a Java transcription.
81     """
82
83     def match(self, something):
84         """Return None if this matcher matches something, a Mismatch otherwise.
85         """
86         raise NotImplementedError(self.match)
87
88     def __str__(self):
89         """Get a sensible human representation of the matcher.
90
91         This should include the parameters given to the matcher and any
92         state that would affect the matches operation.
93         """
94         raise NotImplementedError(self.__str__)
95
96
97 class Mismatch(object):
98     """An object describing a mismatch detected by a Matcher."""
99
100     def __init__(self, description=None, details=None):
101         """Construct a `Mismatch`.
102
103         :param description: A description to use.  If not provided,
104             `Mismatch.describe` must be implemented.
105         :param details: Extra details about the mismatch.  Defaults
106             to the empty dict.
107         """
108         if description:
109             self._description = description
110         if details is None:
111             details = {}
112         self._details = details
113
114     def describe(self):
115         """Describe the mismatch.
116
117         This should be either a human-readable string or castable to a string.
118         In particular, is should either be plain ascii or unicode on Python 2,
119         and care should be taken to escape control characters.
120         """
121         try:
122             return self._description
123         except AttributeError:
124             raise NotImplementedError(self.describe)
125
126     def get_details(self):
127         """Get extra details about the mismatch.
128
129         This allows the mismatch to provide extra information beyond the basic
130         description, including large text or binary files, or debugging internals
131         without having to force it to fit in the output of 'describe'.
132
133         The testtools assertion assertThat will query get_details and attach
134         all its values to the test, permitting them to be reported in whatever
135         manner the test environment chooses.
136
137         :return: a dict mapping names to Content objects. name is a string to
138             name the detail, and the Content object is the detail to add
139             to the result. For more information see the API to which items from
140             this dict are passed testtools.TestCase.addDetail.
141         """
142         return getattr(self, '_details', {})
143
144     def __repr__(self):
145         return  "<testtools.matchers.Mismatch object at %x attributes=%r>" % (
146             id(self), self.__dict__)
147
148
149 class MismatchError(AssertionError):
150     """Raised when a mismatch occurs."""
151
152     # This class exists to work around
153     # <https://bugs.launchpad.net/testtools/+bug/804127>.  It provides a
154     # guaranteed way of getting a readable exception, no matter what crazy
155     # characters are in the matchee, matcher or mismatch.
156
157     def __init__(self, matchee, matcher, mismatch, verbose=False):
158         # Have to use old-style upcalling for Python 2.4 and 2.5
159         # compatibility.
160         AssertionError.__init__(self)
161         self.matchee = matchee
162         self.matcher = matcher
163         self.mismatch = mismatch
164         self.verbose = verbose
165
166     def __str__(self):
167         difference = self.mismatch.describe()
168         if self.verbose:
169             # GZ 2011-08-24: Smelly API? Better to take any object and special
170             #                case text inside?
171             if istext(self.matchee) or _isbytes(self.matchee):
172                 matchee = text_repr(self.matchee, multiline=False)
173             else:
174                 matchee = repr(self.matchee)
175             return (
176                 'Match failed. Matchee: %s\nMatcher: %s\nDifference: %s\n'
177                 % (matchee, self.matcher, difference))
178         else:
179             return difference
180
181     if not str_is_unicode:
182
183         __unicode__ = __str__
184
185         def __str__(self):
186             return self.__unicode__().encode("ascii", "backslashreplace")
187
188
189 class MismatchDecorator(object):
190     """Decorate a ``Mismatch``.
191
192     Forwards all messages to the original mismatch object.  Probably the best
193     way to use this is inherit from this class and then provide your own
194     custom decoration logic.
195     """
196
197     def __init__(self, original):
198         """Construct a `MismatchDecorator`.
199
200         :param original: A `Mismatch` object to decorate.
201         """
202         self.original = original
203
204     def __repr__(self):
205         return '<testtools.matchers.MismatchDecorator(%r)>' % (self.original,)
206
207     def describe(self):
208         return self.original.describe()
209
210     def get_details(self):
211         return self.original.get_details()
212
213
214 class _NonManglingOutputChecker(doctest.OutputChecker):
215     """Doctest checker that works with unicode rather than mangling strings
216
217     This is needed because current Python versions have tried to fix string
218     encoding related problems, but regressed the default behaviour with
219     unicode inputs in the process.
220
221     In Python 2.6 and 2.7 ``OutputChecker.output_difference`` is was changed
222     to return a bytestring encoded as per ``sys.stdout.encoding``, or utf-8 if
223     that can't be determined. Worse, that encoding process happens in the
224     innocent looking `_indent` global function. Because the
225     `DocTestMismatch.describe` result may well not be destined for printing to
226     stdout, this is no good for us. To get a unicode return as before, the
227     method is monkey patched if ``doctest._encoding`` exists.
228
229     Python 3 has a different problem. For some reason both inputs are encoded
230     to ascii with 'backslashreplace', making an escaped string matches its
231     unescaped form. Overriding the offending ``OutputChecker._toAscii`` method
232     is sufficient to revert this.
233     """
234
235     def _toAscii(self, s):
236         """Return ``s`` unchanged rather than mangling it to ascii"""
237         return s
238
239     # Only do this overriding hackery if doctest has a broken _input function
240     if getattr(doctest, "_encoding", None) is not None:
241         from types import FunctionType as __F
242         __f = doctest.OutputChecker.output_difference.im_func
243         __g = dict(__f.func_globals)
244         def _indent(s, indent=4, _pattern=re.compile("^(?!$)", re.MULTILINE)):
245             """Prepend non-empty lines in ``s`` with ``indent`` number of spaces"""
246             return _pattern.sub(indent*" ", s)
247         __g["_indent"] = _indent
248         output_difference = __F(__f.func_code, __g, "output_difference")
249         del __F, __f, __g, _indent
250
251
252 class DocTestMatches(object):
253     """See if a string matches a doctest example."""
254
255     def __init__(self, example, flags=0):
256         """Create a DocTestMatches to match example.
257
258         :param example: The example to match e.g. 'foo bar baz'
259         :param flags: doctest comparison flags to match on. e.g.
260             doctest.ELLIPSIS.
261         """
262         if not example.endswith('\n'):
263             example += '\n'
264         self.want = example # required variable name by doctest.
265         self.flags = flags
266         self._checker = _NonManglingOutputChecker()
267
268     def __str__(self):
269         if self.flags:
270             flagstr = ", flags=%d" % self.flags
271         else:
272             flagstr = ""
273         return 'DocTestMatches(%r%s)' % (self.want, flagstr)
274
275     def _with_nl(self, actual):
276         result = self.want.__class__(actual)
277         if not result.endswith('\n'):
278             result += '\n'
279         return result
280
281     def match(self, actual):
282         with_nl = self._with_nl(actual)
283         if self._checker.check_output(self.want, with_nl, self.flags):
284             return None
285         return DocTestMismatch(self, with_nl)
286
287     def _describe_difference(self, with_nl):
288         return self._checker.output_difference(self, with_nl, self.flags)
289
290
291 class DocTestMismatch(Mismatch):
292     """Mismatch object for DocTestMatches."""
293
294     def __init__(self, matcher, with_nl):
295         self.matcher = matcher
296         self.with_nl = with_nl
297
298     def describe(self):
299         s = self.matcher._describe_difference(self.with_nl)
300         if str_is_unicode or isinstance(s, unicode):
301             return s
302         # GZ 2011-08-24: This is actually pretty bogus, most C0 codes should
303         #                be escaped, in addition to non-ascii bytes.
304         return s.decode("latin1").encode("ascii", "backslashreplace")
305
306
307 class DoesNotContain(Mismatch):
308
309     def __init__(self, matchee, needle):
310         """Create a DoesNotContain Mismatch.
311
312         :param matchee: the object that did not contain needle.
313         :param needle: the needle that 'matchee' was expected to contain.
314         """
315         self.matchee = matchee
316         self.needle = needle
317
318     def describe(self):
319         return "%r not in %r" % (self.needle, self.matchee)
320
321
322 class DoesNotStartWith(Mismatch):
323
324     def __init__(self, matchee, expected):
325         """Create a DoesNotStartWith Mismatch.
326
327         :param matchee: the string that did not match.
328         :param expected: the string that 'matchee' was expected to start with.
329         """
330         self.matchee = matchee
331         self.expected = expected
332
333     def describe(self):
334         return "%s does not start with %s." % (
335             text_repr(self.matchee), text_repr(self.expected))
336
337
338 class DoesNotEndWith(Mismatch):
339
340     def __init__(self, matchee, expected):
341         """Create a DoesNotEndWith Mismatch.
342
343         :param matchee: the string that did not match.
344         :param expected: the string that 'matchee' was expected to end with.
345         """
346         self.matchee = matchee
347         self.expected = expected
348
349     def describe(self):
350         return "%s does not end with %s." % (
351             text_repr(self.matchee), text_repr(self.expected))
352
353
354 class _BinaryComparison(object):
355     """Matcher that compares an object to another object."""
356
357     def __init__(self, expected):
358         self.expected = expected
359
360     def __str__(self):
361         return "%s(%r)" % (self.__class__.__name__, self.expected)
362
363     def match(self, other):
364         if self.comparator(other, self.expected):
365             return None
366         return _BinaryMismatch(self.expected, self.mismatch_string, other)
367
368     def comparator(self, expected, other):
369         raise NotImplementedError(self.comparator)
370
371
372 class _BinaryMismatch(Mismatch):
373     """Two things did not match."""
374
375     def __init__(self, expected, mismatch_string, other):
376         self.expected = expected
377         self._mismatch_string = mismatch_string
378         self.other = other
379
380     def _format(self, thing):
381         # Blocks of text with newlines are formatted as triple-quote
382         # strings. Everything else is pretty-printed.
383         if istext(thing) or _isbytes(thing):
384             return text_repr(thing)
385         return pformat(thing)
386
387     def describe(self):
388         left = repr(self.expected)
389         right = repr(self.other)
390         if len(left) + len(right) > 70:
391             return "%s:\nreference = %s\nactual    = %s\n" % (
392                 self._mismatch_string, self._format(self.expected),
393                 self._format(self.other))
394         else:
395             return "%s %s %s" % (left, self._mismatch_string, right)
396
397
398 class MatchesPredicate(Matcher):
399     """Match if a given function returns True.
400
401     It is reasonably common to want to make a very simple matcher based on a
402     function that you already have that returns True or False given a single
403     argument (i.e. a predicate function).  This matcher makes it very easy to
404     do so. e.g.::
405
406       IsEven = MatchesPredicate(lambda x: x % 2 == 0, '%s is not even')
407       self.assertThat(4, IsEven)
408     """
409
410     def __init__(self, predicate, message):
411         """Create a ``MatchesPredicate`` matcher.
412
413         :param predicate: A function that takes a single argument and returns
414             a value that will be interpreted as a boolean.
415         :param message: A message to describe a mismatch.  It will be formatted
416             with '%' and be given whatever was passed to ``match()``. Thus, it
417             needs to contain exactly one thing like '%s', '%d' or '%f'.
418         """
419         self.predicate = predicate
420         self.message = message
421
422     def __str__(self):
423         return '%s(%r, %r)' % (
424             self.__class__.__name__, self.predicate, self.message)
425
426     def match(self, x):
427         if not self.predicate(x):
428             return Mismatch(self.message % x)
429
430
431 class Equals(_BinaryComparison):
432     """Matches if the items are equal."""
433
434     comparator = operator.eq
435     mismatch_string = '!='
436
437
438 class NotEquals(_BinaryComparison):
439     """Matches if the items are not equal.
440
441     In most cases, this is equivalent to ``Not(Equals(foo))``. The difference
442     only matters when testing ``__ne__`` implementations.
443     """
444
445     comparator = operator.ne
446     mismatch_string = '=='
447
448
449 class Is(_BinaryComparison):
450     """Matches if the items are identical."""
451
452     comparator = operator.is_
453     mismatch_string = 'is not'
454
455
456 class IsInstance(object):
457     """Matcher that wraps isinstance."""
458
459     def __init__(self, *types):
460         self.types = tuple(types)
461
462     def __str__(self):
463         return "%s(%s)" % (self.__class__.__name__,
464                 ', '.join(type.__name__ for type in self.types))
465
466     def match(self, other):
467         if isinstance(other, self.types):
468             return None
469         return NotAnInstance(other, self.types)
470
471
472 class NotAnInstance(Mismatch):
473
474     def __init__(self, matchee, types):
475         """Create a NotAnInstance Mismatch.
476
477         :param matchee: the thing which is not an instance of any of types.
478         :param types: A tuple of the types which were expected.
479         """
480         self.matchee = matchee
481         self.types = types
482
483     def describe(self):
484         if len(self.types) == 1:
485             typestr = self.types[0].__name__
486         else:
487             typestr = 'any of (%s)' % ', '.join(type.__name__ for type in
488                     self.types)
489         return "'%s' is not an instance of %s" % (self.matchee, typestr)
490
491
492 class LessThan(_BinaryComparison):
493     """Matches if the item is less than the matchers reference object."""
494
495     comparator = operator.__lt__
496     mismatch_string = 'is not >'
497
498
499 class GreaterThan(_BinaryComparison):
500     """Matches if the item is greater than the matchers reference object."""
501
502     comparator = operator.__gt__
503     mismatch_string = 'is not <'
504
505
506 class MatchesAny(object):
507     """Matches if any of the matchers it is created with match."""
508
509     def __init__(self, *matchers):
510         self.matchers = matchers
511
512     def match(self, matchee):
513         results = []
514         for matcher in self.matchers:
515             mismatch = matcher.match(matchee)
516             if mismatch is None:
517                 return None
518             results.append(mismatch)
519         return MismatchesAll(results)
520
521     def __str__(self):
522         return "MatchesAny(%s)" % ', '.join([
523             str(matcher) for matcher in self.matchers])
524
525
526 class MatchesAll(object):
527     """Matches if all of the matchers it is created with match."""
528
529     def __init__(self, *matchers, **options):
530         """Construct a MatchesAll matcher.
531
532         Just list the component matchers as arguments in the ``*args``
533         style. If you want only the first mismatch to be reported, past in
534         first_only=True as a keyword argument. By default, all mismatches are
535         reported.
536         """
537         self.matchers = matchers
538         self.first_only = options.get('first_only', False)
539
540     def __str__(self):
541         return 'MatchesAll(%s)' % ', '.join(map(str, self.matchers))
542
543     def match(self, matchee):
544         results = []
545         for matcher in self.matchers:
546             mismatch = matcher.match(matchee)
547             if mismatch is not None:
548                 if self.first_only:
549                     return mismatch
550                 results.append(mismatch)
551         if results:
552             return MismatchesAll(results)
553         else:
554             return None
555
556
557 class MismatchesAll(Mismatch):
558     """A mismatch with many child mismatches."""
559
560     def __init__(self, mismatches):
561         self.mismatches = mismatches
562
563     def describe(self):
564         descriptions = ["Differences: ["]
565         for mismatch in self.mismatches:
566             descriptions.append(mismatch.describe())
567         descriptions.append("]")
568         return '\n'.join(descriptions)
569
570
571 class Not(object):
572     """Inverts a matcher."""
573
574     def __init__(self, matcher):
575         self.matcher = matcher
576
577     def __str__(self):
578         return 'Not(%s)' % (self.matcher,)
579
580     def match(self, other):
581         mismatch = self.matcher.match(other)
582         if mismatch is None:
583             return MatchedUnexpectedly(self.matcher, other)
584         else:
585             return None
586
587
588 class MatchedUnexpectedly(Mismatch):
589     """A thing matched when it wasn't supposed to."""
590
591     def __init__(self, matcher, other):
592         self.matcher = matcher
593         self.other = other
594
595     def describe(self):
596         return "%r matches %s" % (self.other, self.matcher)
597
598
599 class MatchesException(Matcher):
600     """Match an exc_info tuple against an exception instance or type."""
601
602     def __init__(self, exception, value_re=None):
603         """Create a MatchesException that will match exc_info's for exception.
604
605         :param exception: Either an exception instance or type.
606             If an instance is given, the type and arguments of the exception
607             are checked. If a type is given only the type of the exception is
608             checked. If a tuple is given, then as with isinstance, any of the
609             types in the tuple matching is sufficient to match.
610         :param value_re: If 'exception' is a type, and the matchee exception
611             is of the right type, then match against this.  If value_re is a
612             string, then assume value_re is a regular expression and match
613             the str() of the exception against it.  Otherwise, assume value_re
614             is a matcher, and match the exception against it.
615         """
616         Matcher.__init__(self)
617         self.expected = exception
618         if istext(value_re):
619             value_re = AfterPreproccessing(str, MatchesRegex(value_re), False)
620         self.value_re = value_re
621         self._is_instance = type(self.expected) not in classtypes() + (tuple,)
622
623     def match(self, other):
624         if type(other) != tuple:
625             return Mismatch('%r is not an exc_info tuple' % other)
626         expected_class = self.expected
627         if self._is_instance:
628             expected_class = expected_class.__class__
629         if not issubclass(other[0], expected_class):
630             return Mismatch('%r is not a %r' % (other[0], expected_class))
631         if self._is_instance:
632             if other[1].args != self.expected.args:
633                 return Mismatch('%s has different arguments to %s.' % (
634                         _error_repr(other[1]), _error_repr(self.expected)))
635         elif self.value_re is not None:
636             return self.value_re.match(other[1])
637
638     def __str__(self):
639         if self._is_instance:
640             return "MatchesException(%s)" % _error_repr(self.expected)
641         return "MatchesException(%s)" % repr(self.expected)
642
643
644 class Contains(Matcher):
645     """Checks whether something is contained in another thing."""
646
647     def __init__(self, needle):
648         """Create a Contains Matcher.
649
650         :param needle: the thing that needs to be contained by matchees.
651         """
652         self.needle = needle
653
654     def __str__(self):
655         return "Contains(%r)" % (self.needle,)
656
657     def match(self, matchee):
658         try:
659             if self.needle not in matchee:
660                 return DoesNotContain(matchee, self.needle)
661         except TypeError:
662             # e.g. 1 in 2 will raise TypeError
663             return DoesNotContain(matchee, self.needle)
664         return None
665
666
667 class StartsWith(Matcher):
668     """Checks whether one string starts with another."""
669
670     def __init__(self, expected):
671         """Create a StartsWith Matcher.
672
673         :param expected: the string that matchees should start with.
674         """
675         self.expected = expected
676
677     def __str__(self):
678         return "StartsWith(%r)" % (self.expected,)
679
680     def match(self, matchee):
681         if not matchee.startswith(self.expected):
682             return DoesNotStartWith(matchee, self.expected)
683         return None
684
685
686 class EndsWith(Matcher):
687     """Checks whether one string starts with another."""
688
689     def __init__(self, expected):
690         """Create a EndsWith Matcher.
691
692         :param expected: the string that matchees should end with.
693         """
694         self.expected = expected
695
696     def __str__(self):
697         return "EndsWith(%r)" % (self.expected,)
698
699     def match(self, matchee):
700         if not matchee.endswith(self.expected):
701             return DoesNotEndWith(matchee, self.expected)
702         return None
703
704
705 class KeysEqual(Matcher):
706     """Checks whether a dict has particular keys."""
707
708     def __init__(self, *expected):
709         """Create a `KeysEqual` Matcher.
710
711         :param expected: The keys the dict is expected to have.  If a dict,
712             then we use the keys of that dict, if a collection, we assume it
713             is a collection of expected keys.
714         """
715         try:
716             self.expected = expected.keys()
717         except AttributeError:
718             self.expected = list(expected)
719
720     def __str__(self):
721         return "KeysEqual(%s)" % ', '.join(map(repr, self.expected))
722
723     def match(self, matchee):
724         expected = sorted(self.expected)
725         matched = Equals(expected).match(sorted(matchee.keys()))
726         if matched:
727             return AnnotatedMismatch(
728                 'Keys not equal',
729                 _BinaryMismatch(expected, 'does not match', matchee))
730         return None
731
732
733 class Annotate(object):
734     """Annotates a matcher with a descriptive string.
735
736     Mismatches are then described as '<mismatch>: <annotation>'.
737     """
738
739     def __init__(self, annotation, matcher):
740         self.annotation = annotation
741         self.matcher = matcher
742
743     @classmethod
744     def if_message(cls, annotation, matcher):
745         """Annotate ``matcher`` only if ``annotation`` is non-empty."""
746         if not annotation:
747             return matcher
748         return cls(annotation, matcher)
749
750     def __str__(self):
751         return 'Annotate(%r, %s)' % (self.annotation, self.matcher)
752
753     def match(self, other):
754         mismatch = self.matcher.match(other)
755         if mismatch is not None:
756             return AnnotatedMismatch(self.annotation, mismatch)
757
758
759 class AnnotatedMismatch(MismatchDecorator):
760     """A mismatch annotated with a descriptive string."""
761
762     def __init__(self, annotation, mismatch):
763         super(AnnotatedMismatch, self).__init__(mismatch)
764         self.annotation = annotation
765         self.mismatch = mismatch
766
767     def describe(self):
768         return '%s: %s' % (self.original.describe(), self.annotation)
769
770
771 class Raises(Matcher):
772     """Match if the matchee raises an exception when called.
773
774     Exceptions which are not subclasses of Exception propogate out of the
775     Raises.match call unless they are explicitly matched.
776     """
777
778     def __init__(self, exception_matcher=None):
779         """Create a Raises matcher.
780
781         :param exception_matcher: Optional validator for the exception raised
782             by matchee. If supplied the exc_info tuple for the exception raised
783             is passed into that matcher. If no exception_matcher is supplied
784             then the simple fact of raising an exception is considered enough
785             to match on.
786         """
787         self.exception_matcher = exception_matcher
788
789     def match(self, matchee):
790         try:
791             result = matchee()
792             return Mismatch('%r returned %r' % (matchee, result))
793         # Catch all exceptions: Raises() should be able to match a
794         # KeyboardInterrupt or SystemExit.
795         except:
796             exc_info = sys.exc_info()
797             if self.exception_matcher:
798                 mismatch = self.exception_matcher.match(exc_info)
799                 if not mismatch:
800                     del exc_info
801                     return
802             else:
803                 mismatch = None
804             # The exception did not match, or no explicit matching logic was
805             # performed. If the exception is a non-user exception (that is, not
806             # a subclass of Exception on Python 2.5+) then propogate it.
807             if isbaseexception(exc_info[1]):
808                 del exc_info
809                 raise
810             return mismatch
811
812     def __str__(self):
813         return 'Raises()'
814
815
816 def raises(exception):
817     """Make a matcher that checks that a callable raises an exception.
818
819     This is a convenience function, exactly equivalent to::
820
821         return Raises(MatchesException(exception))
822
823     See `Raises` and `MatchesException` for more information.
824     """
825     return Raises(MatchesException(exception))
826
827
828 class MatchesListwise(object):
829     """Matches if each matcher matches the corresponding value.
830
831     More easily explained by example than in words:
832
833     >>> MatchesListwise([Equals(1)]).match([1])
834     >>> MatchesListwise([Equals(1), Equals(2)]).match([1, 2])
835     >>> print (MatchesListwise([Equals(1), Equals(2)]).match([2, 1]).describe())
836     Differences: [
837     1 != 2
838     2 != 1
839     ]
840     >>> matcher = MatchesListwise([Equals(1), Equals(2)], first_only=True)
841     >>> print (matcher.match([3, 4]).describe())
842     1 != 3
843     """
844
845     def __init__(self, matchers, first_only=False):
846         """Construct a MatchesListwise matcher.
847
848         :param matchers: A list of matcher that the matched values must match.
849         :param first_only: If True, then only report the first mismatch,
850             otherwise report all of them. Defaults to False.
851         """
852         self.matchers = matchers
853         self.first_only = first_only
854
855     def match(self, values):
856         mismatches = []
857         length_mismatch = Annotate(
858             "Length mismatch", Equals(len(self.matchers))).match(len(values))
859         if length_mismatch:
860             mismatches.append(length_mismatch)
861         for matcher, value in zip(self.matchers, values):
862             mismatch = matcher.match(value)
863             if mismatch:
864                 if self.first_only:
865                     return mismatch
866                 mismatches.append(mismatch)
867         if mismatches:
868             return MismatchesAll(mismatches)
869
870
871 class MatchesStructure(object):
872     """Matcher that matches an object structurally.
873
874     'Structurally' here means that attributes of the object being matched are
875     compared against given matchers.
876
877     `fromExample` allows the creation of a matcher from a prototype object and
878     then modified versions can be created with `update`.
879
880     `byEquality` creates a matcher in much the same way as the constructor,
881     except that the matcher for each of the attributes is assumed to be
882     `Equals`.
883
884     `byMatcher` creates a similar matcher to `byEquality`, but you get to pick
885     the matcher, rather than just using `Equals`.
886     """
887
888     def __init__(self, **kwargs):
889         """Construct a `MatchesStructure`.
890
891         :param kwargs: A mapping of attributes to matchers.
892         """
893         self.kws = kwargs
894
895     @classmethod
896     def byEquality(cls, **kwargs):
897         """Matches an object where the attributes equal the keyword values.
898
899         Similar to the constructor, except that the matcher is assumed to be
900         Equals.
901         """
902         return cls.byMatcher(Equals, **kwargs)
903
904     @classmethod
905     def byMatcher(cls, matcher, **kwargs):
906         """Matches an object where the attributes match the keyword values.
907
908         Similar to the constructor, except that the provided matcher is used
909         to match all of the values.
910         """
911         return cls(
912             **dict((name, matcher(value)) for name, value in kwargs.items()))
913
914     @classmethod
915     def fromExample(cls, example, *attributes):
916         kwargs = {}
917         for attr in attributes:
918             kwargs[attr] = Equals(getattr(example, attr))
919         return cls(**kwargs)
920
921     def update(self, **kws):
922         new_kws = self.kws.copy()
923         for attr, matcher in kws.items():
924             if matcher is None:
925                 new_kws.pop(attr, None)
926             else:
927                 new_kws[attr] = matcher
928         return type(self)(**new_kws)
929
930     def __str__(self):
931         kws = []
932         for attr, matcher in sorted(self.kws.items()):
933             kws.append("%s=%s" % (attr, matcher))
934         return "%s(%s)" % (self.__class__.__name__, ', '.join(kws))
935
936     def match(self, value):
937         matchers = []
938         values = []
939         for attr, matcher in sorted(self.kws.items()):
940             matchers.append(Annotate(attr, matcher))
941             values.append(getattr(value, attr))
942         return MatchesListwise(matchers).match(values)
943
944
945 class MatchesRegex(object):
946     """Matches if the matchee is matched by a regular expression."""
947
948     def __init__(self, pattern, flags=0):
949         self.pattern = pattern
950         self.flags = flags
951
952     def __str__(self):
953         args = ['%r' % self.pattern]
954         flag_arg = []
955         # dir() sorts the attributes for us, so we don't need to do it again.
956         for flag in dir(re):
957             if len(flag) == 1:
958                 if self.flags & getattr(re, flag):
959                     flag_arg.append('re.%s' % flag)
960         if flag_arg:
961             args.append('|'.join(flag_arg))
962         return '%s(%s)' % (self.__class__.__name__, ', '.join(args))
963
964     def match(self, value):
965         if not re.match(self.pattern, value, self.flags):
966             pattern = self.pattern
967             if not isinstance(pattern, str_is_unicode and str or unicode):
968                 pattern = pattern.decode("latin1")
969             pattern = pattern.encode("unicode_escape").decode("ascii")
970             return Mismatch("%r does not match /%s/" % (
971                     value, pattern.replace("\\\\", "\\")))
972
973
974 class MatchesSetwise(object):
975     """Matches if all the matchers match elements of the value being matched.
976
977     That is, each element in the 'observed' set must match exactly one matcher
978     from the set of matchers, with no matchers left over.
979
980     The difference compared to `MatchesListwise` is that the order of the
981     matchings does not matter.
982     """
983
984     def __init__(self, *matchers):
985         self.matchers = matchers
986
987     def match(self, observed):
988         remaining_matchers = set(self.matchers)
989         not_matched = []
990         for value in observed:
991             for matcher in remaining_matchers:
992                 if matcher.match(value) is None:
993                     remaining_matchers.remove(matcher)
994                     break
995             else:
996                 not_matched.append(value)
997         if not_matched or remaining_matchers:
998             remaining_matchers = list(remaining_matchers)
999             # There are various cases that all should be reported somewhat
1000             # differently.
1001
1002             # There are two trivial cases:
1003             # 1) There are just some matchers left over.
1004             # 2) There are just some values left over.
1005
1006             # Then there are three more interesting cases:
1007             # 3) There are the same number of matchers and values left over.
1008             # 4) There are more matchers left over than values.
1009             # 5) There are more values left over than matchers.
1010
1011             if len(not_matched) == 0:
1012                 if len(remaining_matchers) > 1:
1013                     msg = "There were %s matchers left over: " % (
1014                         len(remaining_matchers),)
1015                 else:
1016                     msg = "There was 1 matcher left over: "
1017                 msg += ', '.join(map(str, remaining_matchers))
1018                 return Mismatch(msg)
1019             elif len(remaining_matchers) == 0:
1020                 if len(not_matched) > 1:
1021                     return Mismatch(
1022                         "There were %s values left over: %s" % (
1023                             len(not_matched), not_matched))
1024                 else:
1025                     return Mismatch(
1026                         "There was 1 value left over: %s" % (
1027                             not_matched, ))
1028             else:
1029                 common_length = min(len(remaining_matchers), len(not_matched))
1030                 if common_length == 0:
1031                     raise AssertionError("common_length can't be 0 here")
1032                 if common_length > 1:
1033                     msg = "There were %s mismatches" % (common_length,)
1034                 else:
1035                     msg = "There was 1 mismatch"
1036                 if len(remaining_matchers) > len(not_matched):
1037                     extra_matchers = remaining_matchers[common_length:]
1038                     msg += " and %s extra matcher" % (len(extra_matchers), )
1039                     if len(extra_matchers) > 1:
1040                         msg += "s"
1041                     msg += ': ' + ', '.join(map(str, extra_matchers))
1042                 elif len(not_matched) > len(remaining_matchers):
1043                     extra_values = not_matched[common_length:]
1044                     msg += " and %s extra value" % (len(extra_values), )
1045                     if len(extra_values) > 1:
1046                         msg += "s"
1047                     msg += ': ' + str(extra_values)
1048                 return Annotate(
1049                     msg, MatchesListwise(remaining_matchers[:common_length])
1050                     ).match(not_matched[:common_length])
1051
1052
1053 class AfterPreprocessing(object):
1054     """Matches if the value matches after passing through a function.
1055
1056     This can be used to aid in creating trivial matchers as functions, for
1057     example::
1058
1059       def PathHasFileContent(content):
1060           def _read(path):
1061               return open(path).read()
1062           return AfterPreprocessing(_read, Equals(content))
1063     """
1064
1065     def __init__(self, preprocessor, matcher, annotate=True):
1066         """Create an AfterPreprocessing matcher.
1067
1068         :param preprocessor: A function called with the matchee before
1069             matching.
1070         :param matcher: What to match the preprocessed matchee against.
1071         :param annotate: Whether or not to annotate the matcher with
1072             something explaining how we transformed the matchee. Defaults
1073             to True.
1074         """
1075         self.preprocessor = preprocessor
1076         self.matcher = matcher
1077         self.annotate = annotate
1078
1079     def _str_preprocessor(self):
1080         if isinstance(self.preprocessor, types.FunctionType):
1081             return '<function %s>' % self.preprocessor.__name__
1082         return str(self.preprocessor)
1083
1084     def __str__(self):
1085         return "AfterPreprocessing(%s, %s)" % (
1086             self._str_preprocessor(), self.matcher)
1087
1088     def match(self, value):
1089         after = self.preprocessor(value)
1090         if self.annotate:
1091             matcher = Annotate(
1092                 "after %s on %r" % (self._str_preprocessor(), value),
1093                 self.matcher)
1094         else:
1095             matcher = self.matcher
1096         return matcher.match(after)
1097
1098 # This is the old, deprecated. spelling of the name, kept for backwards
1099 # compatibility.
1100 AfterPreproccessing = AfterPreprocessing
1101
1102
1103 class AllMatch(object):
1104     """Matches if all provided values match the given matcher."""
1105
1106     def __init__(self, matcher):
1107         self.matcher = matcher
1108
1109     def __str__(self):
1110         return 'AllMatch(%s)' % (self.matcher,)
1111
1112     def match(self, values):
1113         mismatches = []
1114         for value in values:
1115             mismatch = self.matcher.match(value)
1116             if mismatch:
1117                 mismatches.append(mismatch)
1118         if mismatches:
1119             return MismatchesAll(mismatches)
1120
1121
1122 def PathExists():
1123     """Matches if the given path exists.
1124
1125     Use like this::
1126
1127       assertThat('/some/path', PathExists())
1128     """
1129     return MatchesPredicate(os.path.exists, "%s does not exist.")
1130
1131
1132 def DirExists():
1133     """Matches if the path exists and is a directory."""
1134     return MatchesAll(
1135         PathExists(),
1136         MatchesPredicate(os.path.isdir, "%s is not a directory."),
1137         first_only=True)
1138
1139
1140 def FileExists():
1141     """Matches if the given path exists and is a file."""
1142     return MatchesAll(
1143         PathExists(),
1144         MatchesPredicate(os.path.isfile, "%s is not a file."),
1145         first_only=True)
1146
1147
1148 class DirContains(Matcher):
1149     """Matches if the given directory contains files with the given names.
1150
1151     That is, is the directory listing exactly equal to the given files?
1152     """
1153
1154     def __init__(self, filenames=None, matcher=None):
1155         """Construct a ``DirContains`` matcher.
1156
1157         Can be used in a basic mode where the whole directory listing is
1158         matched against an expected directory listing (by passing
1159         ``filenames``).  Can also be used in a more advanced way where the
1160         whole directory listing is matched against an arbitrary matcher (by
1161         passing ``matcher`` instead).
1162
1163         :param filenames: If specified, match the sorted directory listing
1164             against this list of filenames, sorted.
1165         :param matcher: If specified, match the sorted directory listing
1166             against this matcher.
1167         """
1168         if filenames == matcher == None:
1169             raise AssertionError(
1170                 "Must provide one of `filenames` or `matcher`.")
1171         if None not in (filenames, matcher):
1172             raise AssertionError(
1173                 "Must provide either `filenames` or `matcher`, not both.")
1174         if filenames is None:
1175             self.matcher = matcher
1176         else:
1177             self.matcher = Equals(sorted(filenames))
1178
1179     def match(self, path):
1180         mismatch = DirExists().match(path)
1181         if mismatch is not None:
1182             return mismatch
1183         return self.matcher.match(sorted(os.listdir(path)))
1184
1185
1186 class FileContains(Matcher):
1187     """Matches if the given file has the specified contents."""
1188
1189     def __init__(self, contents=None, matcher=None):
1190         """Construct a ``FileContains`` matcher.
1191
1192         Can be used in a basic mode where the file contents are compared for
1193         equality against the expected file contents (by passing ``contents``).
1194         Can also be used in a more advanced way where the file contents are
1195         matched against an arbitrary matcher (by passing ``matcher`` instead).
1196
1197         :param contents: If specified, match the contents of the file with
1198             these contents.
1199         :param matcher: If specified, match the contents of the file against
1200             this matcher.
1201         """
1202         if contents == matcher == None:
1203             raise AssertionError(
1204                 "Must provide one of `contents` or `matcher`.")
1205         if None not in (contents, matcher):
1206             raise AssertionError(
1207                 "Must provide either `contents` or `matcher`, not both.")
1208         if matcher is None:
1209             self.matcher = Equals(contents)
1210         else:
1211             self.matcher = matcher
1212
1213     def match(self, path):
1214         mismatch = PathExists().match(path)
1215         if mismatch is not None:
1216             return mismatch
1217         f = open(path)
1218         try:
1219             actual_contents = f.read()
1220             return self.matcher.match(actual_contents)
1221         finally:
1222             f.close()
1223
1224     def __str__(self):
1225         return "File at path exists and contains %s" % self.contents
1226
1227
1228 class TarballContains(Matcher):
1229     """Matches if the given tarball contains the given paths.
1230
1231     Uses TarFile.getnames() to get the paths out of the tarball.
1232     """
1233
1234     def __init__(self, paths):
1235         super(TarballContains, self).__init__()
1236         self.paths = paths
1237
1238     def match(self, tarball_path):
1239         tarball = tarfile.open(tarball_path)
1240         try:
1241             return Equals(sorted(self.paths)).match(sorted(tarball.getnames()))
1242         finally:
1243             tarball.close()
1244
1245
1246 class SamePath(Matcher):
1247     """Matches if two paths are the same.
1248
1249     That is, the paths are equal, or they point to the same file but in
1250     different ways.  The paths do not have to exist.
1251     """
1252
1253     def __init__(self, path):
1254         super(SamePath, self).__init__()
1255         self.path = path
1256
1257     def match(self, other_path):
1258         f = lambda x: os.path.abspath(os.path.realpath(x))
1259         return Equals(f(self.path)).match(f(other_path))
1260
1261
1262 class HasPermissions(Matcher):
1263     """Matches if a file has the given permissions.
1264
1265     Permissions are specified and matched as a four-digit octal string.
1266     """
1267
1268     def __init__(self, octal_permissions):
1269         """Construct a HasPermissions matcher.
1270
1271         :param octal_permissions: A four digit octal string, representing the
1272             intended access permissions. e.g. '0775' for rwxrwxr-x.
1273         """
1274         super(HasPermissions, self).__init__()
1275         self.octal_permissions = octal_permissions
1276
1277     def match(self, filename):
1278         permissions = oct(os.stat(filename).st_mode)[-4:]
1279         return Equals(self.octal_permissions).match(permissions)
1280
1281
1282 # Signal that this is part of the testing framework, and that code from this
1283 # should not normally appear in tracebacks.
1284 __unittest = True