testtools: Update to latest version.
[metze/samba/wip.git] / lib / testtools / testtools / matchers / _dict.py
1 # Copyright (c) 2009-2012 testtools developers. See LICENSE for details.
2
3 __all__ = [
4     'KeysEqual',
5     ]
6
7 from ..helpers import (
8     dict_subtract,
9     filter_values,
10     map_values,
11     )
12 from ._higherorder import (
13     AnnotatedMismatch,
14     PrefixedMismatch,
15     MismatchesAll,
16     )
17 from ._impl import Matcher, Mismatch
18
19
20 def LabelledMismatches(mismatches, details=None):
21     """A collection of mismatches, each labelled."""
22     return MismatchesAll(
23         (PrefixedMismatch(k, v) for (k, v) in sorted(mismatches.items())),
24         wrap=False)
25
26
27 class MatchesAllDict(Matcher):
28     """Matches if all of the matchers it is created with match.
29
30     A lot like ``MatchesAll``, but takes a dict of Matchers and labels any
31     mismatches with the key of the dictionary.
32     """
33
34     def __init__(self, matchers):
35         super(MatchesAllDict, self).__init__()
36         self.matchers = matchers
37
38     def __str__(self):
39         return 'MatchesAllDict(%s)' % (_format_matcher_dict(self.matchers),)
40
41     def match(self, observed):
42         mismatches = {}
43         for label in self.matchers:
44             mismatches[label] = self.matchers[label].match(observed)
45         return _dict_to_mismatch(
46             mismatches, result_mismatch=LabelledMismatches)
47
48
49 class DictMismatches(Mismatch):
50     """A mismatch with a dict of child mismatches."""
51
52     def __init__(self, mismatches, details=None):
53         super(DictMismatches, self).__init__(None, details=details)
54         self.mismatches = mismatches
55
56     def describe(self):
57         lines = ['{']
58         lines.extend(
59             ['  %r: %s,' % (key, mismatch.describe())
60              for (key, mismatch) in sorted(self.mismatches.items())])
61         lines.append('}')
62         return '\n'.join(lines)
63
64
65 def _dict_to_mismatch(data, to_mismatch=None,
66                       result_mismatch=DictMismatches):
67     if to_mismatch:
68         data = map_values(to_mismatch, data)
69     mismatches = filter_values(bool, data)
70     if mismatches:
71         return result_mismatch(mismatches)
72
73
74 class _MatchCommonKeys(Matcher):
75     """Match on keys in a dictionary.
76
77     Given a dictionary where the values are matchers, this will look for
78     common keys in the matched dictionary and match if and only if all common
79     keys match the given matchers.
80
81     Thus::
82
83       >>> structure = {'a': Equals('x'), 'b': Equals('y')}
84       >>> _MatchCommonKeys(structure).match({'a': 'x', 'c': 'z'})
85       None
86     """
87
88     def __init__(self, dict_of_matchers):
89         super(_MatchCommonKeys, self).__init__()
90         self._matchers = dict_of_matchers
91
92     def _compare_dicts(self, expected, observed):
93         common_keys = set(expected.keys()) & set(observed.keys())
94         mismatches = {}
95         for key in common_keys:
96             mismatch = expected[key].match(observed[key])
97             if mismatch:
98                 mismatches[key] = mismatch
99         return mismatches
100
101     def match(self, observed):
102         mismatches = self._compare_dicts(self._matchers, observed)
103         if mismatches:
104             return DictMismatches(mismatches)
105
106
107 class _SubDictOf(Matcher):
108     """Matches if the matched dict only has keys that are in given dict."""
109
110     def __init__(self, super_dict, format_value=repr):
111         super(_SubDictOf, self).__init__()
112         self.super_dict = super_dict
113         self.format_value = format_value
114
115     def match(self, observed):
116         excess = dict_subtract(observed, self.super_dict)
117         return _dict_to_mismatch(
118             excess, lambda v: Mismatch(self.format_value(v)))
119
120
121 class _SuperDictOf(Matcher):
122     """Matches if all of the keys in the given dict are in the matched dict.
123     """
124
125     def __init__(self, sub_dict, format_value=repr):
126         super(_SuperDictOf, self).__init__()
127         self.sub_dict = sub_dict
128         self.format_value = format_value
129
130     def match(self, super_dict):
131         return _SubDictOf(super_dict, self.format_value).match(self.sub_dict)
132
133
134 def _format_matcher_dict(matchers):
135     return '{%s}' % (
136         ', '.join(sorted('%r: %s' % (k, v) for k, v in matchers.items())))
137
138
139 class _CombinedMatcher(Matcher):
140     """Many matchers labelled and combined into one uber-matcher.
141
142     Subclass this and then specify a dict of matcher factories that take a
143     single 'expected' value and return a matcher.  The subclass will match
144     only if all of the matchers made from factories match.
145
146     Not **entirely** dissimilar from ``MatchesAll``.
147     """
148
149     matcher_factories = {}
150
151     def __init__(self, expected):
152         super(_CombinedMatcher, self).__init__()
153         self._expected = expected
154
155     def format_expected(self, expected):
156         return repr(expected)
157
158     def __str__(self):
159         return '%s(%s)' % (
160             self.__class__.__name__, self.format_expected(self._expected))
161
162     def match(self, observed):
163         matchers = dict(
164             (k, v(self._expected)) for k, v in self.matcher_factories.items())
165         return MatchesAllDict(matchers).match(observed)
166
167
168 class MatchesDict(_CombinedMatcher):
169     """Match a dictionary exactly, by its keys.
170
171     Specify a dictionary mapping keys (often strings) to matchers.  This is
172     the 'expected' dict.  Any dictionary that matches this must have exactly
173     the same keys, and the values must match the corresponding matchers in the
174     expected dict.
175     """
176
177     matcher_factories = {
178         'Extra': _SubDictOf,
179         'Missing': lambda m: _SuperDictOf(m, format_value=str),
180         'Differences': _MatchCommonKeys,
181         }
182
183     format_expected = lambda self, expected: _format_matcher_dict(expected)
184
185
186 class ContainsDict(_CombinedMatcher):
187     """Match a dictionary for that contains a specified sub-dictionary.
188
189     Specify a dictionary mapping keys (often strings) to matchers.  This is
190     the 'expected' dict.  Any dictionary that matches this must have **at
191     least** these keys, and the values must match the corresponding matchers
192     in the expected dict.  Dictionaries that have more keys will also match.
193
194     In other words, any matching dictionary must contain the dictionary given
195     to the constructor.
196
197     Does not check for strict sub-dictionary.  That is, equal dictionaries
198     match.
199     """
200
201     matcher_factories = {
202         'Missing': lambda m: _SuperDictOf(m, format_value=str),
203         'Differences': _MatchCommonKeys,
204         }
205
206     format_expected = lambda self, expected: _format_matcher_dict(expected)
207
208
209 class ContainedByDict(_CombinedMatcher):
210     """Match a dictionary for which this is a super-dictionary.
211
212     Specify a dictionary mapping keys (often strings) to matchers.  This is
213     the 'expected' dict.  Any dictionary that matches this must have **only**
214     these keys, and the values must match the corresponding matchers in the
215     expected dict.  Dictionaries that have fewer keys can also match.
216
217     In other words, any matching dictionary must be contained by the
218     dictionary given to the constructor.
219
220     Does not check for strict super-dictionary.  That is, equal dictionaries
221     match.
222     """
223
224     matcher_factories = {
225         'Extra': _SubDictOf,
226         'Differences': _MatchCommonKeys,
227         }
228
229     format_expected = lambda self, expected: _format_matcher_dict(expected)
230
231
232 class KeysEqual(Matcher):
233     """Checks whether a dict has particular keys."""
234
235     def __init__(self, *expected):
236         """Create a `KeysEqual` Matcher.
237
238         :param expected: The keys the dict is expected to have.  If a dict,
239             then we use the keys of that dict, if a collection, we assume it
240             is a collection of expected keys.
241         """
242         super(KeysEqual, self).__init__()
243         try:
244             self.expected = expected.keys()
245         except AttributeError:
246             self.expected = list(expected)
247
248     def __str__(self):
249         return "KeysEqual(%s)" % ', '.join(map(repr, self.expected))
250
251     def match(self, matchee):
252         from ._basic import _BinaryMismatch, Equals
253         expected = sorted(self.expected)
254         matched = Equals(expected).match(sorted(matchee.keys()))
255         if matched:
256             return AnnotatedMismatch(
257                 'Keys not equal',
258                 _BinaryMismatch(expected, 'does not match', matchee))
259         return None