token_group: Use samba.tests.subunitrun.
[samba.git] / source4 / dsdb / tests / python / token_group.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3 # test tokengroups attribute against internal token calculation
4
5 import optparse
6 import sys
7 import os
8
9 sys.path.insert(0, "bin/python")
10 import samba
11
12 from samba.tests.subunitrun import SubunitOptions, TestProgram
13
14 import samba.getopt as options
15
16 from samba.auth import system_session
17 from samba import ldb, dsdb
18 from samba.samdb import SamDB
19 from samba.auth import AuthContext
20 from samba.ndr import ndr_unpack
21 from samba import gensec
22 from samba.credentials import Credentials, DONT_USE_KERBEROS
23 from samba.dsdb import GTYPE_SECURITY_GLOBAL_GROUP, GTYPE_SECURITY_UNIVERSAL_GROUP
24
25 import samba.tests
26 from samba.tests import delete_force
27
28 from samba.auth import AUTH_SESSION_INFO_DEFAULT_GROUPS, AUTH_SESSION_INFO_AUTHENTICATED, AUTH_SESSION_INFO_SIMPLE_PRIVILEGES
29
30
# Command-line handling.  "%prog" is expanded by optparse to the basename of
# the running script, so the usage line no longer names an unrelated script.
parser = optparse.OptionParser("%prog [options] <host>")
sambaopts = options.SambaOptions(parser)
parser.add_option_group(sambaopts)
parser.add_option_group(options.VersionOptions(parser))
# use command line creds if available
credopts = options.CredentialsOptions(parser)
parser.add_option_group(credopts)
subunitopts = SubunitOptions(parser)
parser.add_option_group(subunitopts)
opts, args = parser.parse_args()

# Exactly one positional argument is required: the server (or URL) to test.
if not args:
    parser.print_usage()
    sys.exit(1)

url = args[0]

lp = sambaopts.get_loadparm()
creds = credopts.get_credentials(lp)
# Request the gensec SEAL feature for connections made with these credentials.
creds.set_gensec_features(creds.get_gensec_features() | gensec.FEATURE_SEAL)
51
def closure(vSet, wSet, aSet):
    """Extend ``wSet`` in place with every vertex reachable from it.

    Follows the directed edges in ``aSet`` starting from the vertices already
    in ``wSet``; an edge's end vertex is only added (and only traversed
    further) when it is a member of ``vSet``.

    :param vSet: set of vertices that may be added to ``wSet``
    :param wSet: set of start vertices; updated in place with the closure
    :param aSet: iterable of ``(start, end)`` edge tuples
    :return: None (the result is the mutated ``wSet``)
    """
    # Index the edges by start vertex once.  The previous recursive form
    # rescanned every edge and recursed once per added vertex, which was
    # quadratic and could exceed Python's recursion limit on long chains.
    successors = {}
    for start, end in aSet:
        successors.setdefault(start, []).append(end)

    pending = list(wSet)
    while pending:
        start = pending.pop()
        for end in successors.get(start, ()):
            if end not in wSet and end in vSet:
                wSet.add(end)
                pending.append(end)
59
class StaticTokenTest(samba.tests.TestCase):
    """Check the token of the command-line user against Samba's calculation.

    setUp computes a reference SID list with samba.auth.user_session().  Each
    test fetches SIDs from another source (rootDSE tokenGroups, the user
    object's tokenGroups, a GSSAPI server-side session) and fails if any of
    them is missing from the reference list.  The check is one-directional:
    extra SIDs in the calculated token are tolerated.
    """

    def setUp(self):
        super(StaticTokenTest, self).setUp()
        self.ldb = SamDB(url, credentials=creds, session_info=system_session(lp), lp=lp)
        self.base_dn = self.ldb.domain_dn()

        res = self.ldb.search("", scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        # The first rootDSE tokenGroups value is used as the connected user's
        # SID, addressed below through a <SID=...> DN.
        self.user_sid_dn = "<SID=%s>" % str(ndr_unpack(samba.dcerpc.security.dom_sid, res[0]["tokenGroups"][0]))

        session_info_flags = (AUTH_SESSION_INFO_DEFAULT_GROUPS |
                              AUTH_SESSION_INFO_AUTHENTICATED |
                              AUTH_SESSION_INFO_SIMPLE_PRIVILEGES)
        session = samba.auth.user_session(self.ldb, lp_ctx=lp, dn=self.user_sid_dn,
                                          session_info_flags=session_info_flags)

        # Reference SIDs (as strings) calculated internally by Samba.
        self.user_sids = [str(s) for s in session.security_token.sids]

    def test_rootDSE_tokenGroups(self):
        """Testing rootDSE tokengroups against internal calculation"""
        if not url.startswith("ldap"):
            self.fail(msg="This test is only valid on ldap")

        res = self.ldb.search("", scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        print("Getting tokenGroups from rootDSE")
        tokengroups = [str(ndr_unpack(samba.dcerpc.security.dom_sid, sid))
                       for sid in res[0]['tokenGroups']]

        missing = set(tokengroups) - set(self.user_sids)
        if missing:
            print("token sids don't match")
            print("tokengroups: %s" % tokengroups)
            print("calculated : %s" % self.user_sids)
            print("difference : %s" % missing)
            self.fail(msg="calculated groups don't match against rootDSE tokenGroups")

    def test_dn_tokenGroups(self):
        """Testing the user object's tokenGroups against internal calculation"""
        print("Getting tokenGroups from user DN")
        res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        dn_tokengroups = [str(ndr_unpack(samba.dcerpc.security.dom_sid, sid))
                          for sid in res[0]['tokenGroups']]

        missing = set(dn_tokengroups) - set(self.user_sids)
        if missing:
            print("token sids don't match")
            print("difference : %s" % missing)
            self.fail(msg="calculated groups don't match against user DN tokenGroups")

    def test_pac_groups(self):
        """Testing a GSSAPI server-side session token against internal calculation"""
        settings = {}
        settings["lp_ctx"] = lp
        settings["target_hostname"] = lp.get("netbios name")

        gensec_client = gensec.Security.start_client(settings)
        gensec_client.set_credentials(creds)
        gensec_client.want_feature(gensec.FEATURE_SEAL)
        gensec_client.start_mech_by_sasl_name("GSSAPI")

        auth_context = AuthContext(lp_ctx=lp, ldb=self.ldb, methods=[])

        gensec_server = gensec.Security.start_server(settings, auth_context)
        machine_creds = Credentials()
        machine_creds.guess(lp)
        machine_creds.set_machine_account(lp)
        gensec_server.set_credentials(machine_creds)

        gensec_server.want_feature(gensec.FEATURE_SEAL)
        gensec_server.start_mech_by_sasl_name("GSSAPI")

        client_finished = False
        server_finished = False
        # Explicit bytes literal; identical to "" on Python 2.
        server_to_client = b""

        # Run the actual call loop.  It stops as soon as either side reports
        # completion; only the server's session is inspected afterwards.
        while not client_finished and not server_finished:
            if not client_finished:
                print("running client gensec_update")
                (client_finished, client_to_server) = gensec_client.update(server_to_client)
            if not server_finished:
                print("running server gensec_update")
                (server_finished, server_to_client) = gensec_server.update(client_to_server)

        session = gensec_server.session_info()

        pac_sids = [str(s) for s in session.security_token.sids]

        missing = set(pac_sids) - set(self.user_sids)
        if missing:
            print("token sids don't match")
            print("difference : %s" % missing)
            self.fail(msg="calculated groups don't match against user PAC tokenGroups")
168
class DynamicTokenTest(samba.tests.TestCase):
    """Token-group tests for a freshly created user with known memberships.

    setUp creates "tokengroups_user1" and makes it a member of a domain-local
    (group0), a global (group1) and a universal (group2) security group;
    tearDown deletes those four objects.  Tests compare tokenGroups from
    several sources against Samba's internal calculation and against a manual
    re-implementation of the MS-ADTS tokenGroups algorithm.
    """

    def get_creds(self, target_username, target_password):
        """Return Credentials for target_username.

        Domain, realm and workstation are copied from the module-level
        command-line credentials; the gensec SEAL feature is requested.
        """
        creds_tmp = Credentials()
        creds_tmp.set_username(target_username)
        creds_tmp.set_password(target_password)
        creds_tmp.set_domain(creds.get_domain())
        creds_tmp.set_realm(creds.get_realm())
        creds_tmp.set_workstation(creds.get_workstation())
        creds_tmp.set_gensec_features(creds_tmp.get_gensec_features()
                                      | gensec.FEATURE_SEAL)
        return creds_tmp

    def get_ldb_connection(self, target_username, target_password):
        """Open a SamDB connection to `url` authenticated as target_username."""
        creds_tmp = self.get_creds(target_username, target_password)
        ldb_target = SamDB(url=url, credentials=creds_tmp, lp=lp)
        return ldb_target

    def _add_group_with_member(self, name, grouptype):
        """Create group `name` under CN=Users, add the test user to it, return its SID."""
        self.admin_ldb.newgroup(name, grouptype=grouptype)
        res = self.admin_ldb.search(base="cn=%s,cn=users,%s" % (name, self.base_dn),
                                    attrs=["objectSid"], scope=ldb.SCOPE_BASE)
        group_sid = ndr_unpack(samba.dcerpc.security.dom_sid, res[0]["objectSid"][0])

        self.admin_ldb.add_remove_group_members(name, [self.test_user],
                                                add_members_operation=True)
        return group_sid

    def setUp(self):
        super(DynamicTokenTest, self).setUp()
        self.admin_ldb = SamDB(url, credentials=creds, session_info=system_session(lp), lp=lp)

        self.base_dn = self.admin_ldb.domain_dn()

        self.test_user = "tokengroups_user1"
        self.test_user_pass = "samba123@"
        self.admin_ldb.newuser(self.test_user, self.test_user_pass)

        self.test_group0 = "tokengroups_group0"
        self.test_group0_sid = self._add_group_with_member(
            self.test_group0, dsdb.GTYPE_SECURITY_DOMAIN_LOCAL_GROUP)

        self.test_group1 = "tokengroups_group1"
        self.test_group1_sid = self._add_group_with_member(
            self.test_group1, dsdb.GTYPE_SECURITY_GLOBAL_GROUP)

        self.test_group2 = "tokengroups_group2"
        self.test_group2_sid = self._add_group_with_member(
            self.test_group2, dsdb.GTYPE_SECURITY_UNIVERSAL_GROUP)

        # From here on, self.ldb is a connection authenticated as the test user.
        self.ldb = self.get_ldb_connection(self.test_user, self.test_user_pass)

        res = self.ldb.search("", scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        # The first rootDSE tokenGroups value is used as the user's own SID.
        self.user_sid_dn = "<SID=%s>" % str(ndr_unpack(samba.dcerpc.security.dom_sid, res[0]["tokenGroups"][0]))

        res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=[])
        self.assertEqual(len(res), 1)

        self.test_user_dn = res[0].dn

        session_info_flags = (AUTH_SESSION_INFO_DEFAULT_GROUPS |
                              AUTH_SESSION_INFO_AUTHENTICATED |
                              AUTH_SESSION_INFO_SIMPLE_PRIVILEGES)
        session = samba.auth.user_session(self.ldb, lp_ctx=lp, dn=self.user_sid_dn,
                                          session_info_flags=session_info_flags)

        # Reference SIDs (as strings) calculated internally by Samba.
        self.user_sids = [str(s) for s in session.security_token.sids]

    def tearDown(self):
        super(DynamicTokenTest, self).tearDown()
        delete_force(self.admin_ldb, "CN=%s,%s,%s" %
                     (self.test_user, "cn=users", self.base_dn))
        delete_force(self.admin_ldb, "CN=%s,%s,%s" %
                     (self.test_group0, "cn=users", self.base_dn))
        delete_force(self.admin_ldb, "CN=%s,%s,%s" %
                     (self.test_group1, "cn=users", self.base_dn))
        delete_force(self.admin_ldb, "CN=%s,%s,%s" %
                     (self.test_group2, "cn=users", self.base_dn))

    def test_rootDSE_tokenGroups(self):
        """Testing rootDSE tokengroups against internal calculation"""
        if not url.startswith("ldap"):
            self.fail(msg="This test is only valid on ldap")

        res = self.ldb.search("", scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        print("Getting tokenGroups from rootDSE")
        tokengroups = [str(ndr_unpack(samba.dcerpc.security.dom_sid, sid))
                       for sid in res[0]['tokenGroups']]

        # One-directional check: extra calculated SIDs are tolerated.
        missing = set(tokengroups) - set(self.user_sids)
        if missing:
            print("token sids don't match")
            print("tokengroups: %s" % tokengroups)
            print("calculated : %s" % self.user_sids)
            print("difference : %s" % missing)
            self.fail(msg="calculated groups don't match against rootDSE tokenGroups")

    def test_dn_tokenGroups(self):
        """Testing the user object's tokenGroups against internal calculation"""
        print("Getting tokenGroups from user DN")
        res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        dn_tokengroups = [str(ndr_unpack(samba.dcerpc.security.dom_sid, sid))
                          for sid in res[0]['tokenGroups']]

        missing = set(dn_tokengroups) - set(self.user_sids)
        if missing:
            print("token sids don't match")
            print("difference : %s" % missing)
            self.fail(msg="calculated groups don't match against user DN tokenGroups")

    def test_pac_groups(self):
        """Testing a GSSAPI server-side session token against internal calculation"""
        settings = {}
        settings["lp_ctx"] = lp
        settings["target_hostname"] = lp.get("netbios name")

        gensec_client = gensec.Security.start_client(settings)
        gensec_client.set_credentials(self.get_creds(self.test_user, self.test_user_pass))
        gensec_client.want_feature(gensec.FEATURE_SEAL)
        gensec_client.start_mech_by_sasl_name("GSSAPI")

        auth_context = AuthContext(lp_ctx=lp, ldb=self.ldb, methods=[])

        gensec_server = gensec.Security.start_server(settings, auth_context)
        machine_creds = Credentials()
        machine_creds.guess(lp)
        machine_creds.set_machine_account(lp)
        gensec_server.set_credentials(machine_creds)

        gensec_server.want_feature(gensec.FEATURE_SEAL)
        gensec_server.start_mech_by_sasl_name("GSSAPI")

        client_finished = False
        server_finished = False
        # Explicit bytes literal; identical to "" on Python 2.
        server_to_client = b""

        # Run the actual call loop.  It stops as soon as either side reports
        # completion; only the server's session is inspected afterwards.
        while not client_finished and not server_finished:
            if not client_finished:
                print("running client gensec_update")
                (client_finished, client_to_server) = gensec_client.update(server_to_client)
            if not server_finished:
                print("running server gensec_update")
                (server_finished, server_to_client) = gensec_server.update(client_to_server)

        session = gensec_server.session_info()

        pac_sids = [str(s) for s in session.security_token.sids]

        missing = set(pac_sids) - set(self.user_sids)
        if missing:
            print("token sids don't match")
            print("difference : %s" % missing)
            self.fail(msg="calculated groups don't match against user PAC tokenGroups")

    def _membership_graph(self):
        """Return ``(vSet, aSet)``: the membership graph under base_dn.

        vSet holds the casefolded DNs of every user and group found under
        base_dn, plus every DN referenced by their memberOf values or by a
        user's primaryGroupID.  aSet holds ``(member, group)`` edges built
        from memberOf and from each user's primaryGroupID.
        """
        vSet = set()
        aSet = set()

        res = self.admin_ldb.search(base=self.base_dn, scope=ldb.SCOPE_SUBTREE,
                                    expression="(|(objectclass=user)(objectclass=group))",
                                    attrs=["memberOf"])
        for obj in res:
            member = obj.dn.get_casefold()
            vSet.add(member)
            if "memberOf" in obj:
                for dn in obj["memberOf"]:
                    group = ldb.Dn(self.admin_ldb, dn).get_casefold()
                    aSet.add((member, group))
                    vSet.add(group)

        # primaryGroupID is a RID relative to the domain SID; the domain SID
        # does not change inside the loop, so look it up once.
        domain_sid = self.admin_ldb.get_domain_sid()
        res = self.admin_ldb.search(base=self.base_dn, scope=ldb.SCOPE_SUBTREE,
                                    expression="(objectclass=user)",
                                    attrs=["primaryGroupID"])
        for obj in res:
            if "primaryGroupID" in obj:
                sid = "%s-%d" % (domain_sid, int(obj["primaryGroupID"][0]))
                res2 = self.admin_ldb.search(base="<SID=%s>" % sid, scope=ldb.SCOPE_BASE,
                                             attrs=[])
                member = obj.dn.get_casefold()
                group = res2[0].dn.get_casefold()
                aSet.add((member, group))
                vSet.add(member)
                vSet.add(group)

        return vSet, aSet

    def _filter_vertices(self, vSet, filter_grouptype):
        """Return the vertices of vSet that are non-groups or groups of filter_grouptype.

        groupType values are compared as unsigned 32-bit integers on both
        sides, so signed and unsigned spellings of the same type match.
        """
        wanted = filter_grouptype & 0xFFFFFFFF
        uSet = set()
        for v in vSet:
            res_group = self.admin_ldb.search(base=v, scope=ldb.SCOPE_BASE,
                                              attrs=["groupType"],
                                              expression="objectClass=group")
            if len(res_group) != 1:
                # Not a group (e.g. a user): always kept.
                uSet.add(v)
            elif int(res_group[0]["groupType"][0]) & 0xFFFFFFFF == wanted:
                uSet.add(v)
        return uSet

    def _sids_to_casefolded_dns(self, values):
        """Map packed SID attribute values to the casefolded DNs of their objects."""
        dns = set()
        for value in values:
            sid = ndr_unpack(samba.dcerpc.security.dom_sid, value)
            res = self.admin_ldb.search(base="<SID=%s>" % sid, scope=ldb.SCOPE_BASE,
                                        attrs=[])
            dns.add(res[0].dn.get_casefold())
        return dns

    def test_tokenGroups_manual(self):
        # Manually run the tokenGroups algorithm from MS-ADTS 3.1.1.4.5.19 and MS-DRSR 4.1.8.3
        # and compare the result
        vSet, aSet = self._membership_graph()

        user_dn = self.test_user_dn.get_casefold()
        wSet = set([user_dn])
        closure(vSet, wSet, aSet)
        wSet.remove(user_dn)

        res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        tokenGroupsSet = self._sids_to_casefolded_dns(res[0]['tokenGroups'])

        if wSet - tokenGroupsSet:
            self.fail(msg="additional calculated: %s" % wSet.difference(tokenGroupsSet))

        if tokenGroupsSet - wSet:
            self.fail(msg="additional tokenGroups: %s" % tokenGroupsSet.difference(wSet))

    def filtered_closure(self, wSet, filter_grouptype):
        """Extend wSet in place with the membership closure restricted by group type.

        Only non-group objects and groups whose groupType equals
        filter_grouptype may be added to (or traversed from) wSet.
        """
        vSet, aSet = self._membership_graph()
        closure(self._filter_vertices(vSet, filter_grouptype), wSet, aSet)

    def test_tokenGroupsGlobalAndUniversal_manual(self):
        # Manually run the tokenGroups algorithm from MS-ADTS 3.1.1.4.5.19 and MS-DRSR 4.1.8.3
        # and compare the result

        # The variable names come from MS-ADTS May 15, 2014

        # Build the membership graph and the filtered vertex sets once and
        # reuse them, rather than re-reading the directory for every closure.
        vSet, aSet = self._membership_graph()
        user_dn = self.test_user_dn.get_casefold()

        S = set([user_dn])
        closure(self._filter_vertices(vSet, GTYPE_SECURITY_GLOBAL_GROUP), S, aSet)

        universal = self._filter_vertices(vSet, GTYPE_SECURITY_UNIVERSAL_GROUP)
        T = set()
        # Not really a SID, we do this on DNs...
        for sid in S:
            X = set([sid])
            closure(universal, X, aSet)
            T |= X

        T.remove(user_dn)

        res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=["tokenGroupsGlobalAndUniversal"])
        self.assertEqual(len(res), 1)

        tokenGroupsSet = self._sids_to_casefolded_dns(res[0]['tokenGroupsGlobalAndUniversal'])

        if T - tokenGroupsSet:
            self.fail(msg="additional calculated: %s" % T.difference(tokenGroupsSet))

        if tokenGroupsSet - T:
            self.fail(msg="additional tokenGroupsGlobalAndUniversal: %s" % tokenGroupsSet.difference(T))
491
# Turn a bare argument into a URL: an existing local file is opened as a tdb
# database, anything else is treated as an LDAP host name.
if "://" not in url:
    scheme = "tdb" if os.path.isfile(url) else "ldap"
    url = "%s://%s" % (scheme, url)

# NOTE(review): `samdb` is not referenced elsewhere in the visible file; this
# connection presumably just verifies connectivity/credentials up front.
samdb = SamDB(url, credentials=creds, session_info=system_session(lp), lp=lp)

# Hand control to the subunit test runner for every TestCase in this module.
TestProgram(module=__name__, opts=subunitopts)