s4:provision Move helper functions back to provision
[metze/samba/wip.git] / source4 / scripting / python / samba / samdb.py
1 #!/usr/bin/python
2
3 # Unix SMB/CIFS implementation.
4 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2008
5 #
6 # Based on the original in EJS:
7 # Copyright (C) Andrew Tridgell <tridge@samba.org> 2005
8 #   
9 # This program is free software; you can redistribute it and/or modify
10 # it under the terms of the GNU General Public License as published by
11 # the Free Software Foundation; either version 3 of the License, or
12 # (at your option) any later version.
13 #   
14 # This program is distributed in the hope that it will be useful,
15 # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17 # GNU General Public License for more details.
18 #   
19 # You should have received a copy of the GNU General Public License
20 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
21 #
22
23 """Convenience functions for using the SAM."""
24
25 import samba
26 import glue
27 import ldb
28 from samba.idmap import IDmapDB
29 import pwd
30 import time
31 import base64
32
33 __docformat__ = "restructuredText"
34
35 class SamDB(samba.Ldb):
36     """The SAM database."""
37
38     def __init__(self, url=None, session_info=None, credentials=None, 
39                  modules_dir=None, lp=None, options=None):
40         """Open the Sam Database.
41
42         :param url: URL of the database.
43         """
44         self.lp = lp
45         super(SamDB, self).__init__(session_info=session_info, credentials=credentials,
46                                     modules_dir=modules_dir, lp=lp, options=options)
47         glue.dsdb_set_global_schema(self)
48         if url:
49             self.connect(url)
50         else:
51             self.connect(lp.get("sam database"))
52
53     def connect(self, url):
54         super(SamDB, self).connect(self.lp.private_path(url))
55
56     def enable_account(self, user_dn):
57         """Enable an account.
58         
59         :param user_dn: Dn of the account to enable.
60         """
61         res = self.search(user_dn, ldb.SCOPE_BASE, None, ["userAccountControl"])
62         assert len(res) == 1
63         userAccountControl = int(res[0]["userAccountControl"][0])
64         if (userAccountControl & 0x2):
65             userAccountControl = userAccountControl & ~0x2 # remove disabled bit
66         if (userAccountControl & 0x20):
67             userAccountControl = userAccountControl & ~0x20 # remove 'no password required' bit
68
69         mod = """
70 dn: %s
71 changetype: modify
72 replace: userAccountControl
73 userAccountControl: %u
74 """ % (user_dn, userAccountControl)
75         self.modify_ldif(mod)
76
77         
78     def force_password_change_at_next_login(self, user_dn):
79         """Force a password change at next login
80         
81         :param user_dn: Dn of the account to force password change on
82         """
83         mod = """
84 dn: %s
85 changetype: modify
86 replace: pwdLastSet
87 pwdLastSet: 0
88 """ % (user_dn)
89         self.modify_ldif(mod)
90
91     def domain_dn(self):
92         # find the DNs for the domain and the domain users group
93         res = self.search("", scope=ldb.SCOPE_BASE, 
94                           expression="(defaultNamingContext=*)", 
95                           attrs=["defaultNamingContext"])
96         assert(len(res) == 1 and res[0]["defaultNamingContext"] is not None)
97         return res[0]["defaultNamingContext"][0]
98
99     def newuser(self, username, unixname, password, force_password_change_at_next_login=False):
100         """add a new user record.
101         
102         :param username: Name of the new user.
103         :param unixname: Name of the unix user to map to.
104         :param password: Password for the new user
105         """
106         # connect to the sam 
107         self.transaction_start()
108         try:
109             domain_dn = self.domain_dn()
110             assert(domain_dn is not None)
111             user_dn = "CN=%s,CN=Users,%s" % (username, domain_dn)
112
113             #
114             #  the new user record. note the reliance on the samdb module to 
115             #  fill in a sid, guid etc
116             #
117             #  now the real work
118             self.add({"dn": user_dn, 
119                 "sAMAccountName": username,
120                 "userPassword": password,
121                 "objectClass": "user"})
122
123             res = self.search(user_dn, scope=ldb.SCOPE_BASE,
124                               expression="objectclass=*",
125                               attrs=["objectSid"])
126             assert len(res) == 1
127             user_sid = self.schema_format_value("objectSid", res[0]["objectSid"][0])
128             
129             try:
130                 idmap = IDmapDB(lp=self.lp)
131
132                 user = pwd.getpwnam(unixname)
133                 # setup ID mapping for this UID
134                 
135                 idmap.setup_name_mapping(user_sid, idmap.TYPE_UID, user[2])
136
137             except KeyError:
138                 pass
139
140             if force_password_change_at_next_login:
141                 self.force_password_change_at_next_login(user_dn)
142
143             #  modify the userAccountControl to remove the disabled bit
144             self.enable_account(user_dn)
145         except:
146             self.transaction_cancel()
147             raise
148         self.transaction_commit()
149
150     def setpassword(self, filter, password, force_password_change_at_next_login=False):
151         """Set a password on a user record
152         
153         :param filter: LDAP filter to find the user (eg samccountname=name)
154         :param password: Password for the user
155         """
156         # connect to the sam 
157         self.transaction_start()
158         try:
159             # find the DNs for the domain
160             res = self.search("", scope=ldb.SCOPE_BASE, 
161                               expression="(defaultNamingContext=*)", 
162                               attrs=["defaultNamingContext"])
163             assert(len(res) == 1 and res[0]["defaultNamingContext"] is not None)
164             domain_dn = res[0]["defaultNamingContext"][0]
165             assert(domain_dn is not None)
166
167             res = self.search(domain_dn, scope=ldb.SCOPE_SUBTREE, 
168                               expression=filter)
169             assert(len(res) == 1)
170             user_dn = res[0].dn
171
172             setpw = """
173 dn: %s
174 changetype: modify
175 replace: userPassword
176 userPassword:: %s
177 """ % (user_dn, base64.b64encode(password))
178
179             self.modify_ldif(setpw)
180
181             if force_password_change_at_next_login:
182                 self.force_password_change_at_next_login(user_dn)
183
184             #  modify the userAccountControl to remove the disabled bit
185             self.enable_account(user_dn)
186         except:
187             self.transaction_cancel()
188             raise
189         self.transaction_commit()
190
191     def setexpiry(self, user, expiry_seconds, noexpiry):
192         """Set the account expiry for a user
193         
194         :param expiry_seconds: expiry time from now in seconds
195         :param noexpiry: if set, then don't expire password
196         """
197         self.transaction_start()
198         try:
199             res = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
200                               expression=("(samAccountName=%s)" % user),
201                               attrs=["userAccountControl", "accountExpires"])
202             assert len(res) == 1
203             userAccountControl = int(res[0]["userAccountControl"][0])
204             accountExpires     = int(res[0]["accountExpires"][0])
205             if noexpiry:
206                 userAccountControl = userAccountControl | 0x10000
207                 accountExpires = 0
208             else:
209                 userAccountControl = userAccountControl & ~0x10000
210                 accountExpires = glue.unix2nttime(expiry_seconds + int(time.time()))
211
212             mod = """
213 dn: %s
214 changetype: modify
215 replace: userAccountControl
216 userAccountControl: %u
217 replace: accountExpires
218 accountExpires: %u
219 """ % (res[0].dn, userAccountControl, accountExpires)
220             # now change the database
221             self.modify_ldif(mod)
222         except:
223             self.transaction_cancel()
224             raise
225         self.transaction_commit();
226