ldb: complex expression testing
[samba.git] / python / samba / tests / complex_expressions.py
1 # -*- coding: utf-8 -*-
2
3 # Copyright Andrew Bartlett 2018
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18
19 from __future__ import print_function
20 import optparse
21 import samba
22 import samba.getopt as options
23 import sys
24 import os
25 import time
26 from samba.auth import system_session
27 from samba.tests import TestCase
28 import ldb
29
30 ERRCODE_ENTRY_EXISTS = 68
31 ERRCODE_OPERATIONS_ERROR = 1
32 ERRCODE_INVALID_VALUE = 21
33 ERRCODE_CLASS_VIOLATION = 65
34
35 parser = optparse.OptionParser("{0} <host>".format(sys.argv[0]))
36 sambaopts = options.SambaOptions(parser)
37
38 # use command line creds if available
39 credopts = options.CredentialsOptions(parser)
40 parser.add_option_group(credopts)
41 parser.add_option("-v", action="store_true", dest="verbose",
42                   help="print successful expression outputs")
43 opts, args = parser.parse_args()
44
45 if len(args) < 1:
46     parser.print_usage()
47     sys.exit(1)
48
49 lp = sambaopts.get_loadparm()
50 creds = credopts.get_credentials(lp)
51
52 # Set properly at end of file.
53 host = None
54
55 global ou_count
56 ou_count = 0
57
58
59 class ComplexExpressionTests(TestCase):
60     # Using setUpClass instead of setup because we're not modifying any
61     # records in the tests
62     @classmethod
63     def setUpClass(cls):
64         super(ComplexExpressionTests, cls).setUpClass()
65         cls.samdb = samba.samdb.SamDB(host, lp=lp,
66                                       session_info=system_session(),
67                                       credentials=creds)
68
69         ou_name = "ComplexExprTest"
70         cls.base_dn = "OU={0},{1}".format(ou_name, cls.samdb.domain_dn())
71
72         try:
73             cls.samdb.delete(cls.base_dn, ["tree_delete:1"])
74         except:
75             pass
76
77         try:
78             cls.samdb.create_ou(cls.base_dn)
79         except ldb.LdbError as e:
80             if e.args[0] == ERRCODE_ENTRY_EXISTS:
81                 print(('test ou {ou} already exists. Delete with '
82                        '"samba-tool group delete OU={ou} '
83                        '--force-subtree-delete"').format(ou=ou_name))
84             raise e
85
86         cls.name_template = "testuser{0}"
87         cls.default_n = 10
88
89         # These fields are carefully hand-picked from the schema. They have
90         # syntax and handling appropriate for our test structure.
91         cls.largeint_f = "accountExpires"
92         cls.str_f = "accountNameHistory"
93         cls.int_f = "flags"
94         cls.enum_f = "preferredDeliveryMethod"
95         cls.time_f = "msTSExpireDate"
96         cls.ranged_int_f = "countryCode"
97
98     @classmethod
99     def tearDownClass(cls):
100         cls.samdb.delete(cls.base_dn, ["tree_delete:1"])
101
102     # Make test OU containing users with field=val for each val
103     def make_test_objects(self, field, vals):
104         global ou_count
105         ou_count += 1
106         ou_dn = "OU=testou{0},{1}".format(ou_count, self.base_dn)
107         self.samdb.create_ou(ou_dn)
108
109         ldap_objects = [{"dn": "CN=testuser{0},{1}".format(n, ou_dn),
110                          "name": self.name_template.format(n),
111                          "objectClass": "user",
112                          field: n}
113                         for n in vals]
114
115         for ldap_object in ldap_objects:
116             # It's useful to keep appropriate python types in the ldap_object
117             # dict but smdb's 'add' function expects strings.
118             stringed_ldap_object = {k: str(v)
119                                     for (k, v) in ldap_object.items()}
120             try:
121                 self.samdb.add(stringed_ldap_object)
122             except ldb.LdbError as e:
123                 print("failed to add %s" % (stringed_ldap_object))
124                 raise e
125
126         return ou_dn, ldap_objects
127
128     # Run search expr and print out time.  This function should be used for
129     # almost all searching.
130     def time_ldap_search(self, expr, dn):
131         time_taken = 0
132         try:
133             start_time = time.time()
134             res = self.samdb.search(base=dn,
135                                     scope=ldb.SCOPE_SUBTREE,
136                                     expression=expr)
137             time_taken = time.time() - start_time
138         except Exception as e:
139             print("failed expr " + expr)
140             raise e
141         print("{0} took {1}s".format(expr, time_taken))
142         return res, time_taken
143
144     # Take an ldap expression and an equivalent python expression.
145     # Run and time the ldap expression and compare the result to the python
146     # expression run over the a list of ldap_object dicts.
147     def assertLDAPQuery(self, ldap_expr, ou_dn, py_expr, ldap_objects):
148
149         # run (and time) the LDAP search expression over the DB
150         res, time_taken = self.time_ldap_search(ldap_expr, ou_dn)
151         results = {str(row.get('name')[0]) for row in res}
152
153         # build the set of expected results by evaluating the python-equivalent
154         # of the search expression over the same set of objects
155         expected_results = set()
156         for ldap_object in ldap_objects:
157             try:
158                 final_expr = py_expr.format(**ldap_object)
159             except KeyError:
160                 # If the format on the py_expr hits a key error, then
161                 # ldap_object doesn't have the field, so it shouldn't match.
162                 continue
163
164             if eval(final_expr):
165                 expected_results.add(str(ldap_object['name']))
166
167         self.assertEqual(results, expected_results)
168
169         if opts.verbose:
170             ldap_object_names = {l['name'] for l in ldap_objects}
171             excluded = ldap_object_names - results
172             excluded = "\n  ".join(excluded) or "[NOTHING]"
173             returned = "\n  ".join(expected_results) or "[NOTHING]"
174
175             print("PASS: Expression {0} took {1}s and returned:"
176                   "\n  {2}\n"
177                   "Excluded:\n  {3}\n".format(ldap_expr,
178                                               time_taken,
179                                               returned,
180                                               excluded))
181
182     # Basic integer range test
183     def test_int_range(self, field=None):
184         n = self.default_n
185         field = field or self.int_f
186         ou_dn, ldap_objects = self.make_test_objects(field, range(n))
187
188         expr = "(&(%s>=%s)(%s<=%s))" % (field, n-1, field, n+1)
189         py_expr = "%d <= {%s} <= %d" % (n-1, field, n+1)
190         self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)
191
192     # Same test again for largeint and enum
193     def test_largeint_range(self):
194         self.test_int_range(self.largeint_f)
195
196     def test_enum_range(self):
197         self.test_int_range(self.enum_f)
198
199     # Special range test for integer field with upper and lower bounds defined.
200     # The bounds are checked on insertion, not search, so we should be able
201     # to compare to a constant that is outside bounds.
202     def test_ranged_int_range(self):
203         field = self.ranged_int_f
204         ubound = 2**16
205         width = 8
206
207         vals = list(range(ubound-width, ubound))
208         ou_dn, ldap_objects = self.make_test_objects(field, vals)
209
210         # Check <= value above overflow returns all vals
211         expr = "(%s<=%d)" % (field, ubound+5)
212         py_expr = "{%s} <= %d" % (field, ubound+5)
213         self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)
214
215     # Test range also works for time fields
216     def test_time_range(self):
217         n = self.default_n
218         field = self.time_f
219         n = self.default_n
220         width = int(n/2)
221
222         base_time = 20050116175514
223         time_range = [base_time + t for t in range(-width, width)]
224         time_range = [str(t) + ".0Z" for t in time_range]
225         ou_dn, ldap_objects = self.make_test_objects(field, time_range)
226
227         expr = "(%s<=%s)" % (field, str(base_time) + ".0Z")
228         py_expr = 'int("{%s}"[:-3]) <= %d' % (field, base_time)
229         self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)
230
231         expr = "(&(%s>=%s)(%s<=%s))" % (field, str(base_time-1) + ".0Z",
232                                         field, str(base_time+1) + ".0Z")
233         py_expr = '%d <= int("{%s}"[:-3]) <= %d' % (base_time-1,
234                                                     field,
235                                                     base_time+1)
236         self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)
237
238     # Run each comparison op on a simple test set.  Time taken will be printed.
239     def test_int_single_cmp_op_speeds(self, field=None):
240         n = self.default_n
241         field = field or self.int_f
242         ou_dn, ldap_objects = self.make_test_objects(field, range(n))
243
244         comp_ops = ['=', '<=', '>=']
245         py_comp_ops = ['==', '<=', '>=']
246         exprs = ["(%s%s%d)" % (field, c, n) for c in comp_ops]
247         py_exprs = ["{%s}%s%d" % (field, c, n) for c in py_comp_ops]
248
249         for expr, py_expr in zip(exprs, py_exprs):
250             self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)
251
252     def test_largeint_single_cmp_op_speeds(self):
253         self.test_int_single_cmp_op_speeds(self.largeint_f)
254
255     def test_enum_single_cmp_op_speeds(self):
256         self.test_int_single_cmp_op_speeds(self.enum_f)
257
258     # Check strings are ordered using a naive ordering.
259     def test_str_ordering(self):
260         field = self.str_f
261         a_ord = ord('A')
262         n = 10
263         str_range = ['abc{0}d'.format(chr(c)) for c in range(a_ord, a_ord+n)]
264         ou_dn, ldap_objects = self.make_test_objects(field, str_range)
265         half_n = int(a_ord + n/2)
266
267         # Basic <= and >= statements
268         expr = "(%s>=abc%s)" % (field, chr(half_n))
269         py_expr = "'{%s}' >= 'abc%s'" % (field, chr(half_n))
270         self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)
271
272         expr = "(%s<=abc%s)" % (field, chr(half_n))
273         py_expr = "'{%s}' <= 'abc%s'" % (field, chr(half_n))
274         self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)
275
276         # String range
277         expr = "(&(%s>=abc%s)(%s<=abc%s))" % (field, chr(half_n-2),
278                                               field, chr(half_n+2))
279         py_expr = "'abc%s' <= '{%s}' <= 'abc%s'" % (chr(half_n-2),
280                                                     field,
281                                                     chr(half_n+2))
282         self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)
283
284         # Integers treated as string
285         expr = "(%s>=1)" % (field)
286         py_expr = "'{%s}' >= '1'" % (field)
287         self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)
288
289     # Windows returns nothing for invalid expressions. Expected fail on samba.
290     def test_invalid_expressions(self, field=None):
291         field = field or self.int_f
292         n = self.default_n
293         ou_dn, ldap_objects = self.make_test_objects(field, list(range(n)))
294         int_expressions = ["(%s>=abc)",
295                            "(%s<=abc)",
296                            "(%s=abc)"]
297
298         for expr in int_expressions:
299             expr = expr % (field)
300             self.assertLDAPQuery(expr, ou_dn, "False", ldap_objects)
301
302     def test_largeint_invalid_expressions(self):
303         self.test_invalid_expressions(self.largeint_f)
304
305     def test_enum_invalid_expressions(self):
306         self.test_invalid_expressions(self.enum_f)
307
308     def test_case_insensitive(self):
309         str_range = ["äbc"+str(n) for n in range(10)]
310         ou_dn, ldap_objects = self.make_test_objects(self.str_f, str_range)
311
312         expr = "(%s=äbc1)" % (self.str_f)
313         pyexpr = '"{%s}"=="äbc1"' % (self.str_f)
314         self.assertLDAPQuery(expr, ou_dn, pyexpr, ldap_objects)
315
316         expr = "(%s=ÄbC1)" % (self.str_f)
317         self.assertLDAPQuery(expr, ou_dn, pyexpr, ldap_objects)
318
319     # Check negative numbers can be entered and compared
320     def test_negative_cmp(self, field=None):
321         field = field or self.int_f
322         width = 6
323         around_zero = list(range(-width, width))
324         ou_dn, ldap_objects = self.make_test_objects(field, around_zero)
325
326         expr = "(%s>=-3)" % (field)
327         py_expr = "{%s} >= -3" % (field)
328         self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)
329
330     def test_negative_cmp_largeint(self):
331         self.test_negative_cmp(self.largeint_f)
332
333     def test_negative_cmp_enum(self):
334         self.test_negative_cmp(self.enum_f)
335
336     # Check behaviour on insertion and comparison of zero-prefixed numbers.
337     # Samba errors on insertion, Windows strips the leading zeroes.
338     def test_zero_prefix(self, field=None):
339         field = field or self.int_f
340
341         # Test comparison with 0-prefixed constants.
342         n = self.default_n
343         ou_dn, ldap_objects = self.make_test_objects(field, list(range(n)))
344
345         expr = "(%s>=00%d)" % (field, n/2)
346         py_expr = "{%s} >= %d" % (field, n/2)
347         self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)
348
349         # Delete the test OU so we don't mix it up with the next one.
350         self.samdb.delete(ou_dn, ["tree_delete:1"])
351
352         # Try inserting 0-prefixed numbers, check it fails.
353         zero_pref_nums = ['00'+str(num) for num in range(n)]
354         try:
355             ou_dn, ldap_objects = self.make_test_objects(field, zero_pref_nums)
356         except ldb.LdbError as e:
357             if e.args[0] != ERRCODE_INVALID_VALUE:
358                 raise e
359             return
360
361         # Samba doesn't get this far - the exception is raised.  Windows allows
362         # the insertion and removes the leading 0s as tested below.
363         # Either behaviour is fine.
364         print("LDAP allowed insertion of 0-prefixed nums for field " + field)
365
366         res = self.samdb.search(base=ou_dn,
367                                 scope=ldb.SCOPE_SUBTREE,
368                                 expression="(objectClass=user)")
369         returned_nums = [str(r.get(field)[0]) for r in res]
370         expect = [str(n) for n in range(n)]
371         self.assertEqual(set(returned_nums), set(expect))
372
373         expr = "(%s>=%d)" % (field, n/2)
374         py_expr = "{%s} >= %d" % (field, n/2)
375         for ldap_object in ldap_objects:
376             ldap_object[field] = int(ldap_object[field])
377
378         self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)
379
380     def test_zero_prefix_largeint(self):
381         self.test_zero_prefix(self.largeint_f)
382
383     def test_zero_prefix_enum(self):
384         self.test_zero_prefix(self.enum_f)
385
386     # Check integer overflow is handled as best it can be.
387     def test_int_overflow(self, field=None, of=None):
388         field = field or self.int_f
389         of = of or 2**31-1
390         width = 8
391
392         vals = list(range(of-width, of+width))
393         ou_dn, ldap_objects = self.make_test_objects(field, vals)
394
395         # Check ">=overflow" doesn't return vals past overflow
396         expr = "(%s>=%d)" % (field, of-3)
397         py_expr = "%d <= {%s} <= %d" % (of-3, field, of)
398         self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)
399
400         # "<=overflow" returns everything
401         expr = "(%s<=%d)" % (field, of)
402         py_expr = "True"
403         self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)
404
405         # Values past overflow should be negative
406         expr = "(&(%s<=%d)(%s>=0))" % (field, of, field)
407         py_expr = "{%s} <= %d" % (field, of)
408         self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)
409         expr = "(%s<=0)" % (field)
410         py_expr = "{%s} >= %d" % (field, of+1)
411         self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)
412
413         # Get the values back out and check vals past overflow are negative.
414         res = self.samdb.search(base=ou_dn,
415                                 scope=ldb.SCOPE_SUBTREE,
416                                 expression="(objectClass=user)")
417         returned_nums = [str(r.get(field)[0]) for r in res]
418
419         # Note: range(a,b) == [a..b-1] (confusing)
420         up_to_overflow = list(range(of-width, of+1))
421         negatives = list(range(-of-1, -of+width-2))
422
423         expect = [str(n) for n in up_to_overflow + negatives]
424         self.assertEqual(set(returned_nums), set(expect))
425
426     def test_enum_overflow(self):
427         self.test_int_overflow(self.enum_f, 2**31-1)
428
429     # Check cmp works on uSNChanged. We can't insert uSNChanged vals, they get
430     # added automatically so we'll just insert some objects and go with what
431     # we get.
432     def test_usnchanged(self):
433         field = "uSNChanged"
434         n = 10
435         # Note we can't actually set uSNChanged via LDAP (LDB ignores it),
436         # so the input val range doesn't matter here
437         ou_dn, _ = self.make_test_objects(field, list(range(n)))
438
439         # Get the assigned uSNChanged values
440         res = self.samdb.search(base=ou_dn,
441                                 scope=ldb.SCOPE_SUBTREE,
442                                 expression="(objectClass=user)")
443
444         # Our vals got ignored so make ldap_objects from search result
445         ldap_objects = [{'name': str(r['name'][0]),
446                          field: int(r[field][0])}
447                         for r in res]
448
449         # Get the median val and use as the number in the test search expr.
450         nums = [l[field] for l in ldap_objects]
451         nums = list(sorted(nums))
452         search_num = nums[int(len(nums)/2)]
453
454         expr = "(&(%s<=%d)(objectClass=user))" % (field, search_num)
455         py_expr = "{%s} <= %d" % (field, search_num)
456         self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)
457
458 # If we're called independently then import subunit, get host from first
459 # arg and run.  Otherwise, subunit ran us so just set host from env.
460 # We always try to run over LDAP rather than direct file, so that
461 # search timings are not impacted by opening and closing the tdb file.
462 if __name__ == "__main__":
463     from samba.tests.subunitrun import TestProgram
464     host = args[0]
465
466     if "://" not in host:
467         if os.path.isfile(host):
468             host = "tdb://%s" % host
469         else:
470             host = "ldap://%s" % host
471     TestProgram(module=__name__)
472 else:
473     host = "ldap://" + os.getenv("SERVER")