dsdb: Add tokenGroupsGlobalAndUniversal, tokenGroups, tokenGroupsNoGCAcceptable
[samba.git] / source4 / dsdb / tests / python / token_group.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3 # test tokengroups attribute against internal token calculation
4
5 import optparse
6 import sys
7 import os
8
9 sys.path.insert(0, "bin/python")
10 import samba
11
12 from samba.tests.subunitrun import SubunitOptions, TestProgram
13
14 import samba.getopt as options
15
16 from samba.auth import system_session
17 from samba import ldb, dsdb
18 from samba.samdb import SamDB
19 from samba.auth import AuthContext
20 from samba.ndr import ndr_unpack
21 from samba import gensec
22 from samba.credentials import Credentials, DONT_USE_KERBEROS
23 from samba.dsdb import GTYPE_SECURITY_GLOBAL_GROUP, GTYPE_SECURITY_UNIVERSAL_GROUP
24 import samba.tests
25 from samba.tests import delete_force
26
27 from samba.auth import AUTH_SESSION_INFO_DEFAULT_GROUPS, AUTH_SESSION_INFO_AUTHENTICATED, AUTH_SESSION_INFO_SIMPLE_PRIVILEGES
28
29
# Command line handling: the Samba, version, credential and subunit option
# groups are registered, followed by one mandatory positional argument that
# names the server (or database file) the tests run against.
parser = optparse.OptionParser("token_group.py [options] <host>")
sambaopts = options.SambaOptions(parser)
parser.add_option_group(sambaopts)
parser.add_option_group(options.VersionOptions(parser))
# use command line creds if available
credopts = options.CredentialsOptions(parser)
parser.add_option_group(credopts)
subunitopts = SubunitOptions(parser)
parser.add_option_group(subunitopts)
opts, args = parser.parse_args()

# Without a target there is nothing to test; show the usage and give up.
if not args:
    parser.print_usage()
    sys.exit(1)

url = args[0]

lp = sambaopts.get_loadparm()
creds = credopts.get_credentials(lp)
# Add FEATURE_SEAL to whatever gensec features the credentials already
# request.
creds.set_gensec_features(creds.get_gensec_features() | gensec.FEATURE_SEAL)
50
def closure(vSet, wSet, aSet):
    """Expand wSet in place to everything reachable from it along aSet.

    vSet -- vertices that are allowed to be added to wSet
    wSet -- set of starting vertices; reachable vertices are added to it
    aSet -- iterable of (start, end) edge pairs

    A vertex is added when an edge leads to it from a vertex that is already
    in wSet and the vertex itself is a member of vSet.  The starting vertices
    do not have to be in vSet.  Nothing is returned.
    """
    # Index the edges by start vertex once, so that the outgoing edges of a
    # vertex are looked up directly instead of re-scanning the whole edge set
    # after every addition.
    successors = {}
    for start, end in aSet:
        successors.setdefault(start, []).append(end)

    # An explicit worklist is used instead of recursion: a recursive walk
    # nests one call per added vertex and can exceed the interpreter's
    # recursion limit on long membership chains.
    pending = list(wSet)
    while pending:
        node = pending.pop()
        for end in successors.get(node, ()):
            if end in vSet and end not in wSet:
                wSet.add(end)
                pending.append(end)
58
class StaticTokenTest(samba.tests.TestCase):
    """Check tokenGroups of the connecting user against a calculated token.

    setUp() builds a session token with samba.auth.user_session() and keeps
    its SIDs; each test obtains a list of SIDs in another way and fails if
    that list contains a SID that is missing from the calculated token.
    """

    def setUp(self):
        """Connect to the target and calculate the reference SID list."""
        super(StaticTokenTest, self).setUp()
        self.ldb = SamDB(url, credentials=creds, session_info=system_session(lp), lp=lp)
        self.base_dn = self.ldb.domain_dn()

        res = self.ldb.search("", scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        # The first rootDSE tokenGroups value is taken as the SID of the
        # connected user and turned into a <SID=...> search base.
        self.user_sid_dn = "<SID=%s>" % str(ndr_unpack(samba.dcerpc.security.dom_sid, res[0]["tokenGroups"][0]))

        session_info_flags = ( AUTH_SESSION_INFO_DEFAULT_GROUPS |
                               AUTH_SESSION_INFO_AUTHENTICATED |
                               AUTH_SESSION_INFO_SIMPLE_PRIVILEGES)
        session = samba.auth.user_session(self.ldb, lp_ctx=lp, dn=self.user_sid_dn,
                                          session_info_flags=session_info_flags)

        token = session.security_token
        self.user_sids = [str(s) for s in token.sids]

    def test_rootDSE_tokenGroups(self):
        """Testing rootDSE tokengroups against internal calculation"""
        if not url.startswith("ldap"):
            self.fail(msg="This test is only valid on ldap")

        res = self.ldb.search("", scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        print("Getting tokenGroups from rootDSE")
        tokengroups = [str(ndr_unpack(samba.dcerpc.security.dom_sid, sid))
                       for sid in res[0]['tokenGroups']]

        # Only SIDs that the server reports but the calculation lacks count
        # as a mismatch.
        unexpected = set(tokengroups).difference(set(self.user_sids))
        if unexpected:
            print("token sids don't match")
            print("tokengroups: %s" % tokengroups)
            print("calculated : %s" % self.user_sids)
            print("difference : %s" % unexpected)
            self.fail(msg="calculated groups don't match against rootDSE tokenGroups")

    def test_dn_tokenGroups(self):
        """Compare tokenGroups read from the user's own entry with the calculation"""
        print("Getting tokenGroups from user DN")
        res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        dn_tokengroups = [str(ndr_unpack(samba.dcerpc.security.dom_sid, sid))
                          for sid in res[0]['tokenGroups']]

        unexpected = set(dn_tokengroups).difference(set(self.user_sids))
        if unexpected:
            print("token sids don't match")
            print("difference : %s" % unexpected)
            self.fail(msg="calculated groups don't match against user DN tokenGroups")

    def test_pac_groups(self):
        """Compare the SIDs of a GSSAPI-established session with the calculation"""
        settings = {}
        settings["lp_ctx"] = lp
        settings["target_hostname"] = lp.get("netbios name")

        # Client side of the exchange, using the command line credentials.
        gensec_client = gensec.Security.start_client(settings)
        gensec_client.set_credentials(creds)
        gensec_client.want_feature(gensec.FEATURE_SEAL)
        gensec_client.start_mech_by_sasl_name("GSSAPI")

        auth_context = AuthContext(lp_ctx=lp, ldb=self.ldb, methods=[])

        # Server side of the exchange, using the machine account.
        gensec_server = gensec.Security.start_server(settings, auth_context)
        machine_creds = Credentials()
        machine_creds.guess(lp)
        machine_creds.set_machine_account(lp)
        gensec_server.set_credentials(machine_creds)

        gensec_server.want_feature(gensec.FEATURE_SEAL)
        gensec_server.start_mech_by_sasl_name("GSSAPI")

        client_finished = False
        server_finished = False
        server_to_client = ""

        # Run the actual call loop: tokens are passed back and forth until
        # one of the two sides reports that it has finished.
        while not client_finished and not server_finished:
            if not client_finished:
                print("running client gensec_update")
                (client_finished, client_to_server) = gensec_client.update(server_to_client)
            if not server_finished:
                print("running server gensec_update")
                (server_finished, server_to_client) = gensec_server.update(client_to_server)

        session = gensec_server.session_info()

        token = session.security_token
        pac_sids = [str(s) for s in token.sids]

        unexpected = set(pac_sids).difference(set(self.user_sids))
        if unexpected:
            print("token sids don't match")
            print("difference : %s" % unexpected)
            self.fail(msg="calculated groups don't match against user PAC tokenGroups")
167
class DynamicTokenTest(samba.tests.TestCase):
    """tokenGroups tests for a user and three groups created by the test.

    setUp() creates a user and makes it a member of a domain local, a global
    and a universal security group, connects as that user and calculates a
    reference session token.  tearDown() removes the created objects again.
    """

    def get_creds(self, target_username, target_password):
        """Return credentials for the given user name and password.

        Domain, realm and workstation are copied from the command line
        credentials, and FEATURE_SEAL is added to the gensec features.
        """
        creds_tmp = Credentials()
        creds_tmp.set_username(target_username)
        creds_tmp.set_password(target_password)
        creds_tmp.set_domain(creds.get_domain())
        creds_tmp.set_realm(creds.get_realm())
        creds_tmp.set_workstation(creds.get_workstation())
        creds_tmp.set_gensec_features(creds_tmp.get_gensec_features()
                                      | gensec.FEATURE_SEAL)
        return creds_tmp

    def get_ldb_connection(self, target_username, target_password):
        """Return a SamDB connection to the target, bound as the given user."""
        creds_tmp = self.get_creds(target_username, target_password)
        return SamDB(url=url, credentials=creds_tmp, lp=lp)

    def _create_group_with_test_user(self, group_name, grouptype):
        """Create a group, add the test user to it and return the group SID."""
        self.admin_ldb.newgroup(group_name, grouptype=grouptype)
        res = self.admin_ldb.search(base="cn=%s,cn=users,%s" % (group_name, self.base_dn),
                                    attrs=["objectSid"], scope=ldb.SCOPE_BASE)
        group_sid = ndr_unpack(samba.dcerpc.security.dom_sid, res[0]["objectSid"][0])

        self.admin_ldb.add_remove_group_members(group_name, [self.test_user],
                                                add_members_operation=True)
        return group_sid

    def setUp(self):
        """Create the test user and groups and calculate the reference SIDs."""
        super(DynamicTokenTest, self).setUp()
        self.admin_ldb = SamDB(url, credentials=creds, session_info=system_session(lp), lp=lp)

        self.base_dn = self.admin_ldb.domain_dn()

        self.test_user = "tokengroups_user1"
        self.test_user_pass = "samba123@"
        self.admin_ldb.newuser(self.test_user, self.test_user_pass)

        # One group of each scope, each with the test user as a member.
        self.test_group0 = "tokengroups_group0"
        self.test_group0_sid = self._create_group_with_test_user(
            self.test_group0, dsdb.GTYPE_SECURITY_DOMAIN_LOCAL_GROUP)

        self.test_group1 = "tokengroups_group1"
        self.test_group1_sid = self._create_group_with_test_user(
            self.test_group1, dsdb.GTYPE_SECURITY_GLOBAL_GROUP)

        self.test_group2 = "tokengroups_group2"
        self.test_group2_sid = self._create_group_with_test_user(
            self.test_group2, dsdb.GTYPE_SECURITY_UNIVERSAL_GROUP)

        self.ldb = self.get_ldb_connection(self.test_user, self.test_user_pass)

        res = self.ldb.search("", scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        # The first rootDSE tokenGroups value is taken as the SID of the
        # connected user and turned into a <SID=...> search base.
        self.user_sid_dn = "<SID=%s>" % str(ndr_unpack(samba.dcerpc.security.dom_sid, res[0]["tokenGroups"][0]))

        res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=[])
        self.assertEqual(len(res), 1)

        self.test_user_dn = res[0].dn

        session_info_flags = ( AUTH_SESSION_INFO_DEFAULT_GROUPS |
                               AUTH_SESSION_INFO_AUTHENTICATED |
                               AUTH_SESSION_INFO_SIMPLE_PRIVILEGES)
        session = samba.auth.user_session(self.ldb, lp_ctx=lp, dn=self.user_sid_dn,
                                          session_info_flags=session_info_flags)

        token = session.security_token
        self.user_sids = [str(s) for s in token.sids]

    def tearDown(self):
        """Delete the user and the groups created by setUp()."""
        super(DynamicTokenTest, self).tearDown()
        delete_force(self.admin_ldb, "CN=%s,%s,%s" %
                          (self.test_user, "cn=users", self.base_dn))
        delete_force(self.admin_ldb, "CN=%s,%s,%s" %
                          (self.test_group0, "cn=users", self.base_dn))
        delete_force(self.admin_ldb, "CN=%s,%s,%s" %
                          (self.test_group1, "cn=users", self.base_dn))
        delete_force(self.admin_ldb, "CN=%s,%s,%s" %
                          (self.test_group2, "cn=users", self.base_dn))

    def test_rootDSE_tokenGroups(self):
        """Testing rootDSE tokengroups against internal calculation"""
        if not url.startswith("ldap"):
            self.fail(msg="This test is only valid on ldap")

        res = self.ldb.search("", scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        print("Getting tokenGroups from rootDSE")
        tokengroups = [str(ndr_unpack(samba.dcerpc.security.dom_sid, sid))
                       for sid in res[0]['tokenGroups']]

        # Only SIDs that the server reports but the calculation lacks count
        # as a mismatch.
        unexpected = set(tokengroups).difference(set(self.user_sids))
        if unexpected:
            print("token sids don't match")
            print("tokengroups: %s" % tokengroups)
            print("calculated : %s" % self.user_sids)
            print("difference : %s" % unexpected)
            self.fail(msg="calculated groups don't match against rootDSE tokenGroups")

    def test_dn_tokenGroups(self):
        """Compare tokenGroups read from the user's own entry with the calculation"""
        print("Getting tokenGroups from user DN")
        res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        dn_tokengroups = [str(ndr_unpack(samba.dcerpc.security.dom_sid, sid))
                          for sid in res[0]['tokenGroups']]

        unexpected = set(dn_tokengroups).difference(set(self.user_sids))
        if unexpected:
            print("token sids don't match")
            print("difference : %s" % unexpected)
            self.fail(msg="calculated groups don't match against user DN tokenGroups")

    def test_pac_groups(self):
        """Compare the SIDs of a GSSAPI-established session with the calculation"""
        settings = {}
        settings["lp_ctx"] = lp
        settings["target_hostname"] = lp.get("netbios name")

        # Client side of the exchange, authenticating as the test user.
        gensec_client = gensec.Security.start_client(settings)
        gensec_client.set_credentials(self.get_creds(self.test_user, self.test_user_pass))
        gensec_client.want_feature(gensec.FEATURE_SEAL)
        gensec_client.start_mech_by_sasl_name("GSSAPI")

        auth_context = AuthContext(lp_ctx=lp, ldb=self.ldb, methods=[])

        # Server side of the exchange, using the machine account.
        gensec_server = gensec.Security.start_server(settings, auth_context)
        machine_creds = Credentials()
        machine_creds.guess(lp)
        machine_creds.set_machine_account(lp)
        gensec_server.set_credentials(machine_creds)

        gensec_server.want_feature(gensec.FEATURE_SEAL)
        gensec_server.start_mech_by_sasl_name("GSSAPI")

        client_finished = False
        server_finished = False
        server_to_client = ""

        # Run the actual call loop: tokens are passed back and forth until
        # one of the two sides reports that it has finished.
        while not client_finished and not server_finished:
            if not client_finished:
                print("running client gensec_update")
                (client_finished, client_to_server) = gensec_client.update(server_to_client)
            if not server_finished:
                print("running server gensec_update")
                (server_finished, server_to_client) = gensec_server.update(client_to_server)

        session = gensec_server.session_info()

        token = session.security_token
        pac_sids = [str(s) for s in token.sids]

        unexpected = set(pac_sids).difference(set(self.user_sids))
        if unexpected:
            print("token sids don't match")
            print("difference : %s" % unexpected)
            self.fail(msg="calculated groups don't match against user PAC tokenGroups")

    def _build_membership_graph(self):
        """Return (vertices, edges) describing memberships below base_dn.

        vertices is a set with the casefolded DN of every user and group
        found, plus every DN that appears as a group in an edge.  edges is a
        set of (member DN, group DN) pairs taken from the memberOf values
        and from the primaryGroupID of each user.
        """
        vertices = set()
        edges = set()

        res = self.admin_ldb.search(base=self.base_dn, scope=ldb.SCOPE_SUBTREE,
                                    expression="(|(objectclass=user)(objectclass=group))",
                                    attrs=["memberOf"])
        for obj in res:
            member = obj.dn.get_casefold()
            vertices.add(member)
            if "memberOf" in obj:
                for dn in obj["memberOf"]:
                    group = ldb.Dn(self.admin_ldb, dn).get_casefold()
                    edges.add((member, group))
                    vertices.add(group)

        # The primary group is stored as a RID in primaryGroupID rather than
        # as a memberOf value, so it is resolved through <domain SID>-<RID>.
        domain_sid = self.admin_ldb.get_domain_sid()
        res = self.admin_ldb.search(base=self.base_dn, scope=ldb.SCOPE_SUBTREE,
                                    expression="(objectclass=user)",
                                    attrs=["primaryGroupID"])
        for obj in res:
            if "primaryGroupID" in obj:
                sid = "%s-%d" % (domain_sid, int(obj["primaryGroupID"][0]))
                res2 = self.admin_ldb.search(base="<SID=%s>" % sid, scope=ldb.SCOPE_BASE,
                                             attrs=[])
                member = obj.dn.get_casefold()
                group = res2[0].dn.get_casefold()

                edges.add((member, group))
                vertices.add(member)
                vertices.add(group)

        return vertices, edges

    def _dns_for_sids(self, sid_values):
        """Return the set of casefolded DNs for packed SID attribute values."""
        dns = set()
        for value in sid_values:
            sid = ndr_unpack(samba.dcerpc.security.dom_sid, value)
            res = self.admin_ldb.search(base="<SID=%s>" % sid, scope=ldb.SCOPE_BASE,
                                        attrs=[])
            dns.add(res[0].dn.get_casefold())
        return dns

    def test_tokenGroups_manual(self):
        """Compare tokenGroups with a manually computed membership closure"""
        # Manually run the tokenGroups algorithm from MS-ADTS 3.1.1.4.5.19 and MS-DRSR 4.1.8.3
        # and compare the result
        vSet, aSet = self._build_membership_graph()

        # Everything reachable from the test user, without the user itself.
        user_dn = self.test_user_dn.get_casefold()
        wSet = set()
        wSet.add(user_dn)
        closure(vSet, wSet, aSet)
        wSet.remove(user_dn)

        res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        tokenGroupsSet = self._dns_for_sids(res[0]['tokenGroups'])

        only_calculated = wSet.difference(tokenGroupsSet)
        if only_calculated:
            self.fail(msg="additional calculated: %s" % only_calculated)

        only_reported = tokenGroupsSet.difference(wSet)
        if only_reported:
            self.fail(msg="additional tokenGroups: %s" % only_reported)

    def filtered_closure(self, wSet, filter_grouptype):
        """Expand wSet in place along memberships, limited to one group type.

        wSet -- set of casefolded DNs to start from; reachable DNs are added
        filter_grouptype -- groupType value a group must have to be added

        Vertices that are not found as a group are not filtered.
        """
        vSet, aSet = self._build_membership_graph()

        uSet = set()
        for v in vSet:
            res_group = self.admin_ldb.search(base=v, scope=ldb.SCOPE_BASE,
                                              attrs=["groupType"],
                                              expression="objectClass=group")
            if len(res_group) == 1:
                # The stored value is masked to 32 bits before it is
                # compared.  Integers are compared directly because hex()
                # strings of equal int and long values differ on Python 2.
                if (int(res_group[0]["groupType"][0]) & 0x00000000FFFFFFFF) == filter_grouptype:
                    uSet.add(v)
            else:
                uSet.add(v)

        closure(uSet, wSet, aSet)

    def test_tokenGroupsGlobalAndUniversal_manual(self):
        """Compare tokenGroupsGlobalAndUniversal with a manual calculation"""
        # Manually run the tokenGroups algorithm from MS-ADTS 3.1.1.4.5.19 and MS-DRSR 4.1.8.3
        # and compare the result

        # The variable names come from MS-ADTS May 15, 2014

        user_dn = self.test_user_dn.get_casefold()

        S = set()
        S.add(user_dn)

        self.filtered_closure(S, GTYPE_SECURITY_GLOBAL_GROUP)

        T = set()
        # Not really a SID, we do this on DNs...
        for sid in S:
            X = set()
            X.add(sid)
            self.filtered_closure(X, GTYPE_SECURITY_UNIVERSAL_GROUP)

            T = T.union(X)

        T.remove(user_dn)

        res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=["tokenGroupsGlobalAndUniversal"])
        self.assertEqual(len(res), 1)

        tokenGroupsSet = self._dns_for_sids(res[0]['tokenGroupsGlobalAndUniversal'])

        only_calculated = T.difference(tokenGroupsSet)
        if only_calculated:
            self.fail(msg="additional calculated: %s" % only_calculated)

        only_reported = tokenGroupsSet.difference(T)
        if only_reported:
            self.fail(msg="additional tokenGroupsGlobalAndUniversal: %s" % only_reported)
490
# A target given without a scheme is treated as a tdb file when it names an
# existing file, and as an LDAP host otherwise.
if "://" not in url:
    url_template = "tdb://%s" if os.path.isfile(url) else "ldap://%s"
    url = url_template % url

# Hand over to the subunit test runner for the tests in this module.
TestProgram(module=__name__, opts=subunitopts)