Close repositories after use.
[jelmer/dulwich.git] / dulwich / tests / test_repository.py
1 # -*- coding: utf-8 -*-
2 # test_repository.py -- tests for repository.py
3 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
4 #
5 # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
6 # General Public License as published by the Free Software Foundation; version 2.0
7 # or (at your option) any later version. You can redistribute it and/or
8 # modify it under the terms of either of these two licenses.
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15 #
16 # You should have received a copy of the licenses; if not, see
17 # <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
18 # and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
19 # License, Version 2.0.
20 #
21
22 """Tests for the repository."""
23
24 import locale
25 import os
26 import stat
27 import shutil
28 import sys
29 import tempfile
30 import warnings
31
32 from dulwich import errors
33 from dulwich.object_store import (
34     tree_lookup_path,
35     )
36 from dulwich import objects
37 from dulwich.config import Config
38 from dulwich.errors import NotGitRepository
39 from dulwich.repo import (
40     InvalidUserIdentity,
41     Repo,
42     MemoryRepo,
43     check_user_identity,
44     )
45 from dulwich.tests import (
46     TestCase,
47     skipIf,
48     )
49 from dulwich.tests.utils import (
50     open_repo,
51     tear_down_repo,
52     setup_warning_catcher,
53     )
54
55 missing_sha = b'b91fa4d900e17e99b433218e988c4eb4a3e9a097'
56
57
class CreateRepositoryTests(TestCase):
    """Tests for initializing new repositories in memory and on disk."""

    def assertFileContentsEqual(self, expected, repo, path):
        """Assert that control file ``path`` of ``repo`` contains ``expected``.

        :param expected: Expected file contents as bytes, or None to assert
            that ``repo.get_named_file(path)`` returned a falsy value.
        :param repo: Repository whose control file is read.
        :param path: Path of the file relative to the control directory.
        """
        f = repo.get_named_file(path)
        if not f:
            self.assertEqual(expected, None)
        else:
            with f:
                self.assertEqual(expected, f.read())

    def _check_repo_contents(self, repo, expect_bare):
        """Check the files and config a freshly initialized repo provides.

        :param repo: Repository to inspect.
        :param expect_bare: Whether ``repo`` is expected to be bare.
        """
        self.assertEqual(expect_bare, repo.bare)
        self.assertFileContentsEqual(
            b'Unnamed repository', repo, 'description')
        self.assertFileContentsEqual(
            b'', repo, os.path.join('info', 'exclude'))
        self.assertFileContentsEqual(None, repo, 'nonexistent file')
        # Read the config once and check every expected setting against it.
        with repo.get_named_file('config') as f:
            config_text = f.read()
        barestr = b'bare = ' + str(expect_bare).lower().encode('ascii')
        self.assertIn(barestr, config_text, "%r" % config_text)
        # filemode is expected to be enabled everywhere except Windows.
        expect_filemode = sys.platform != 'win32'
        filemodestr = (
            b'filemode = ' + str(expect_filemode).lower().encode('ascii'))
        self.assertIn(filemodestr, config_text, "%r" % config_text)

    def test_create_memory(self):
        repo = MemoryRepo.init_bare([], {})
        self._check_repo_contents(repo, True)

    def test_create_disk_bare(self):
        tmp_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, tmp_dir)
        # Close the repo before the rmtree cleanup removes its directory.
        with Repo.init_bare(tmp_dir) as repo:
            self.assertEqual(tmp_dir, repo._controldir)
            self._check_repo_contents(repo, True)

    def test_create_disk_non_bare(self):
        tmp_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, tmp_dir)
        with Repo.init(tmp_dir) as repo:
            self.assertEqual(os.path.join(tmp_dir, '.git'), repo._controldir)
            self._check_repo_contents(repo, False)

    def test_create_disk_non_bare_mkdir(self):
        tmp_dir = tempfile.mkdtemp()
        target_dir = os.path.join(tmp_dir, "target")
        self.addCleanup(shutil.rmtree, tmp_dir)
        with Repo.init(target_dir, mkdir=True) as repo:
            self.assertEqual(
                os.path.join(target_dir, '.git'), repo._controldir)
            self._check_repo_contents(repo, False)

    def test_create_disk_bare_mkdir(self):
        tmp_dir = tempfile.mkdtemp()
        target_dir = os.path.join(tmp_dir, "target")
        self.addCleanup(shutil.rmtree, tmp_dir)
        with Repo.init_bare(target_dir, mkdir=True) as repo:
            self.assertEqual(target_dir, repo._controldir)
            self._check_repo_contents(repo, True)
118
119
class MemoryRepoTests(TestCase):
    """Tests for MemoryRepo."""

    def test_set_description(self):
        # A description written via set_description must be read back as-is.
        repo = MemoryRepo.init_bare([], {})
        expected = b"Some description"
        repo.set_description(expected)
        actual = repo.get_description()
        self.assertEqual(expected, actual)
127
128
class RepositoryRootTests(TestCase):
    """Tests for Repo operations against the bundled test repositories.

    Repositories obtained through ``self.open_repo`` are handed to
    ``tear_down_repo`` on cleanup; every other repository a test creates or
    clones is closed explicitly (``with`` or ``addCleanup``) so no repository
    stays open after the test.
    """

    def mkdtemp(self):
        """Create a new temporary directory and return its path.

        The directory is not removed automatically; callers register their
        own cleanup.
        """
        return tempfile.mkdtemp()

    def open_repo(self, name):
        """Open test data repository ``name`` in a new temporary directory.

        ``tear_down_repo`` is registered as a cleanup for the returned repo.
        """
        temp_dir = self.mkdtemp()
        repo = open_repo(name, temp_dir)
        self.addCleanup(tear_down_repo, repo)
        return repo

    def test_simple_props(self):
        r = self.open_repo('a.git')
        self.assertEqual(r.controldir(), r.path)

    def test_setitem(self):
        r = self.open_repo('a.git')
        r[b"refs/tags/foo"] = b'a90fa2d900a17e99b433217e988c4eb4a2e9a097'
        self.assertEqual(b'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
                         r[b"refs/tags/foo"].id)

    def test_getitem_unicode(self):
        r = self.open_repo('a.git')

        test_keys = [
            (b'refs/heads/master', True),
            (b'a90fa2d900a17e99b433217e988c4eb4a2e9a097', True),
            (b'11' * 19 + b'--', False),
        ]

        for k, contained in test_keys:
            self.assertEqual(k in r, contained)

        # Avoid deprecation warning under Py3.2+
        if getattr(self, 'assertRaisesRegex', None):
            assertRaisesRegexp = self.assertRaisesRegex
        else:
            assertRaisesRegexp = self.assertRaisesRegexp
        # Non-bytes names must be rejected with a TypeError.
        assertRaisesRegexp(
            TypeError, "'name' must be bytestring, not int",
            r.__getitem__, 12
        )

    def test_delitem(self):
        r = self.open_repo('a.git')

        del r[b'refs/heads/master']
        self.assertRaises(KeyError, lambda: r[b'refs/heads/master'])

        del r[b'HEAD']
        self.assertRaises(KeyError, lambda: r[b'HEAD'])

        self.assertRaises(ValueError, r.__delitem__, b'notrefs/foo')

    def test_get_refs(self):
        r = self.open_repo('a.git')
        self.assertEqual({
            b'HEAD': b'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
            b'refs/heads/master': b'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
            b'refs/tags/mytag': b'28237f4dc30d0d462658d6b937b08a0f0b6ef55a',
            b'refs/tags/mytag-packed':
                b'b0931cadc54336e78a1d980420e3268903b57a50',
            }, r.get_refs())

    def test_head(self):
        r = self.open_repo('a.git')
        self.assertEqual(r.head(), b'a90fa2d900a17e99b433217e988c4eb4a2e9a097')

    def test_get_object(self):
        r = self.open_repo('a.git')
        obj = r.get_object(r.head())
        self.assertEqual(obj.type_name, b'commit')

    def test_get_object_non_existant(self):
        r = self.open_repo('a.git')
        self.assertRaises(KeyError, r.get_object, missing_sha)

    def test_contains_object(self):
        r = self.open_repo('a.git')
        self.assertTrue(r.head() in r)

    def test_contains_ref(self):
        r = self.open_repo('a.git')
        self.assertTrue(b"HEAD" in r)

    def test_get_no_description(self):
        r = self.open_repo('a.git')
        self.assertIs(None, r.get_description())

    def test_get_description(self):
        r = self.open_repo('a.git')
        with open(os.path.join(r.path, 'description'), 'wb') as f:
            f.write(b"Some description")
        self.assertEqual(b"Some description", r.get_description())

    def test_set_description(self):
        r = self.open_repo('a.git')
        description = b"Some description"
        r.set_description(description)
        self.assertEqual(description, r.get_description())

    def test_contains_missing(self):
        r = self.open_repo('a.git')
        self.assertFalse(b"bar" in r)

    def test_get_peeled(self):
        # unpacked ref
        r = self.open_repo('a.git')
        tag_sha = b'28237f4dc30d0d462658d6b937b08a0f0b6ef55a'
        self.assertNotEqual(r[tag_sha].sha().hexdigest(), r.head())
        self.assertEqual(r.get_peeled(b'refs/tags/mytag'), r.head())

        # packed ref with cached peeled value
        packed_tag_sha = b'b0931cadc54336e78a1d980420e3268903b57a50'
        parent_sha = r[r.head()].parents[0]
        self.assertNotEqual(r[packed_tag_sha].sha().hexdigest(), parent_sha)
        self.assertEqual(r.get_peeled(b'refs/tags/mytag-packed'), parent_sha)

        # TODO: add more corner cases to test repo

    def test_get_peeled_not_tag(self):
        r = self.open_repo('a.git')
        self.assertEqual(r.get_peeled(b'HEAD'), r.head())

    def test_get_walker(self):
        r = self.open_repo('a.git')
        # include defaults to [r.head()]
        self.assertEqual(
            [e.commit.id for e in r.get_walker()],
            [r.head(), b'2a72d929692c41d8554c07f6301757ba18a65d91'])
        self.assertEqual(
            [e.commit.id for e in
                r.get_walker([b'2a72d929692c41d8554c07f6301757ba18a65d91'])],
            [b'2a72d929692c41d8554c07f6301757ba18a65d91'])
        self.assertEqual(
            [e.commit.id for e in
                r.get_walker(b'2a72d929692c41d8554c07f6301757ba18a65d91')],
            [b'2a72d929692c41d8554c07f6301757ba18a65d91'])

    def test_fetch(self):
        r = self.open_repo('a.git')
        tmp_dir = self.mkdtemp()
        self.addCleanup(shutil.rmtree, tmp_dir)
        t = Repo.init(tmp_dir)
        # Registered after rmtree, so it runs first (cleanups run LIFO).
        self.addCleanup(t.close)
        r.fetch(t)
        self.assertIn(b'a90fa2d900a17e99b433217e988c4eb4a2e9a097', t)
        self.assertIn(b'28237f4dc30d0d462658d6b937b08a0f0b6ef55a', t)
        self.assertIn(b'b0931cadc54336e78a1d980420e3268903b57a50', t)

    def test_fetch_ignores_missing_refs(self):
        r = self.open_repo('a.git')
        missing = b'1234566789123456789123567891234657373833'
        r.refs[b'refs/heads/blah'] = missing
        tmp_dir = self.mkdtemp()
        self.addCleanup(shutil.rmtree, tmp_dir)
        t = Repo.init(tmp_dir)
        self.addCleanup(t.close)
        r.fetch(t)
        self.assertIn(b'a90fa2d900a17e99b433217e988c4eb4a2e9a097', t)
        self.assertIn(b'28237f4dc30d0d462658d6b937b08a0f0b6ef55a', t)
        self.assertIn(b'b0931cadc54336e78a1d980420e3268903b57a50', t)
        self.assertNotIn(missing, t)

    def test_clone(self):
        r = self.open_repo('a.git')
        tmp_dir = self.mkdtemp()
        self.addCleanup(shutil.rmtree, tmp_dir)
        with r.clone(tmp_dir, mkdir=False) as t:
            self.assertEqual({
                b'HEAD': b'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
                b'refs/remotes/origin/master':
                    b'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
                b'refs/heads/master':
                    b'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
                b'refs/tags/mytag':
                    b'28237f4dc30d0d462658d6b937b08a0f0b6ef55a',
                b'refs/tags/mytag-packed':
                    b'b0931cadc54336e78a1d980420e3268903b57a50',
                }, t.refs.as_dict())
            shas = [e.commit.id for e in r.get_walker()]
            self.assertEqual(shas, [t.head(),
                             b'2a72d929692c41d8554c07f6301757ba18a65d91'])
            c = t.get_config()
            # The remote URL is stored as bytes in the clone's config.
            encoded_path = r.path
            if not isinstance(encoded_path, bytes):
                encoded_path = encoded_path.encode(sys.getfilesystemencoding())
            self.assertEqual(encoded_path,
                             c.get((b'remote', b'origin'), b'url'))
            self.assertEqual(
                b'+refs/heads/*:refs/remotes/origin/*',
                c.get((b'remote', b'origin'), b'fetch'))

    def test_clone_no_head(self):
        temp_dir = self.mkdtemp()
        self.addCleanup(shutil.rmtree, temp_dir)
        repo_dir = os.path.join(os.path.dirname(__file__), 'data', 'repos')
        dest_dir = os.path.join(temp_dir, 'a.git')
        shutil.copytree(os.path.join(repo_dir, 'a.git'),
                        dest_dir, symlinks=True)
        r = Repo(dest_dir)
        self.addCleanup(r.close)
        del r.refs[b"refs/heads/master"]
        del r.refs[b"HEAD"]
        t = r.clone(os.path.join(temp_dir, 'b.git'), mkdir=True)
        self.addCleanup(t.close)
        self.assertEqual({
            b'refs/tags/mytag': b'28237f4dc30d0d462658d6b937b08a0f0b6ef55a',
            b'refs/tags/mytag-packed':
                b'b0931cadc54336e78a1d980420e3268903b57a50',
            }, t.refs.as_dict())

    def test_clone_empty(self):
        """Test clone() doesn't crash if HEAD points to a non-existing ref.

        This simulates cloning server-side bare repository either when it is
        still empty or if user renames master branch and pushes private repo
        to the server.
        Non-bare repo HEAD always points to an existing ref.
        """
        r = self.open_repo('empty.git')
        tmp_dir = self.mkdtemp()
        self.addCleanup(shutil.rmtree, tmp_dir)
        # Only checks that clone() does not raise; close the result right away.
        r.clone(tmp_dir, mkdir=False, bare=True).close()

    def test_clone_bare(self):
        r = self.open_repo('a.git')
        tmp_dir = self.mkdtemp()
        self.addCleanup(shutil.rmtree, tmp_dir)
        # NOTE(review): despite the test name, no bare=True is passed here, so
        # this performs a regular clone -- confirm the intended coverage.
        t = r.clone(tmp_dir, mkdir=False)
        t.close()

    def test_clone_checkout_and_bare(self):
        r = self.open_repo('a.git')
        tmp_dir = self.mkdtemp()
        self.addCleanup(shutil.rmtree, tmp_dir)
        self.assertRaises(ValueError, r.clone, tmp_dir, mkdir=False,
                          checkout=True, bare=True)

    def test_merge_history(self):
        r = self.open_repo('simple_merge.git')
        shas = [e.commit.id for e in r.get_walker()]
        self.assertEqual(shas, [b'5dac377bdded4c9aeb8dff595f0faeebcc8498cc',
                                b'ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd',
                                b'4cffe90e0a41ad3f5190079d7c8f036bde29cbe6',
                                b'60dacdc733de308bb77bb76ce0fb0f9b44c9769e',
                                b'0d89f20333fbb1d2f3a94da77f4981373d8f4310'])

    def test_out_of_order_merge(self):
        """Test that revision history is ordered by date, not parent order."""
        r = self.open_repo('ooo_merge.git')
        shas = [e.commit.id for e in r.get_walker()]
        self.assertEqual(shas, [b'7601d7f6231db6a57f7bbb79ee52e4d462fd44d1',
                                b'f507291b64138b875c28e03469025b1ea20bc614',
                                b'fb5b0425c7ce46959bec94d54b9a157645e114f5',
                                b'f9e39b120c68182a4ba35349f832d0e4e61f485c'])

    def test_get_tags_empty(self):
        r = self.open_repo('ooo_merge.git')
        self.assertEqual({}, r.refs.as_dict(b'refs/tags'))

    def test_get_config(self):
        r = self.open_repo('ooo_merge.git')
        self.assertIsInstance(r.get_config(), Config)

    def test_get_config_stack(self):
        r = self.open_repo('ooo_merge.git')
        self.assertIsInstance(r.get_config_stack(), Config)

    @skipIf(not getattr(os, 'symlink', None), 'Requires symlink support')
    def test_submodule(self):
        temp_dir = self.mkdtemp()
        self.addCleanup(shutil.rmtree, temp_dir)
        repo_dir = os.path.join(os.path.dirname(__file__), 'data', 'repos')
        shutil.copytree(os.path.join(repo_dir, 'a.git'),
                        os.path.join(temp_dir, 'a.git'), symlinks=True)
        rel = os.path.relpath(os.path.join(repo_dir, 'submodule'), temp_dir)
        os.symlink(os.path.join(rel, 'dotgit'), os.path.join(temp_dir, '.git'))
        with Repo(temp_dir) as r:
            self.assertEqual(r.head(),
                             b'a90fa2d900a17e99b433217e988c4eb4a2e9a097')

    def test_common_revisions(self):
        """
        This test demonstrates that ``find_common_revisions()`` actually
        returns common heads, not revisions; dulwich already uses
        ``find_common_revisions()`` in such a manner (see
        ``Repo.fetch_objects()``).
        """

        expected_shas = set([b'60dacdc733de308bb77bb76ce0fb0f9b44c9769e'])

        # Source for objects.
        r_base = self.open_repo('simple_merge.git')

        # Re-create each-side of the merge in simple_merge.git.
        #
        # Since the trees and blobs are missing, the repository created is
        # corrupted, but we're only checking for commits for the purpose of
        # this test, so it's immaterial.
        r1_dir = self.mkdtemp()
        self.addCleanup(shutil.rmtree, r1_dir)
        r1_commits = [b'ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd',  # HEAD
                      b'60dacdc733de308bb77bb76ce0fb0f9b44c9769e',
                      b'0d89f20333fbb1d2f3a94da77f4981373d8f4310']

        r2_dir = self.mkdtemp()
        self.addCleanup(shutil.rmtree, r2_dir)
        r2_commits = [b'4cffe90e0a41ad3f5190079d7c8f036bde29cbe6',  # HEAD
                      b'60dacdc733de308bb77bb76ce0fb0f9b44c9769e',
                      b'0d89f20333fbb1d2f3a94da77f4981373d8f4310']

        r1 = Repo.init_bare(r1_dir)
        self.addCleanup(r1.close)
        for c in r1_commits:
            r1.object_store.add_object(r_base.get_object(c))
        r1.refs[b'HEAD'] = r1_commits[0]

        r2 = Repo.init_bare(r2_dir)
        self.addCleanup(r2.close)
        for c in r2_commits:
            r2.object_store.add_object(r_base.get_object(c))
        r2.refs[b'HEAD'] = r2_commits[0]

        # Finally, the 'real' testing!
        shas = r2.object_store.find_common_revisions(r1.get_graph_walker())
        self.assertEqual(set(shas), expected_shas)

        shas = r1.object_store.find_common_revisions(r2.get_graph_walker())
        self.assertEqual(set(shas), expected_shas)

    def test_shell_hook_pre_commit(self):
        if os.name != 'posix':
            self.skipTest('shell hook tests requires POSIX shell')

        pre_commit_fail = """#!/bin/sh
exit 1
"""

        pre_commit_success = """#!/bin/sh
exit 0
"""

        repo_dir = self.mkdtemp()
        self.addCleanup(shutil.rmtree, repo_dir)
        r = Repo.init(repo_dir)
        self.addCleanup(r.close)

        pre_commit = os.path.join(r.controldir(), 'hooks', 'pre-commit')

        with open(pre_commit, 'w') as f:
            f.write(pre_commit_fail)
        os.chmod(pre_commit, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)

        # Use the same (bytes) arguments as the successful commit below so
        # that the failing hook is the only reason for a CommitError.
        self.assertRaises(errors.CommitError, r.do_commit, b'failed commit',
                          committer=b'Test Committer <test@nodomain.com>',
                          author=b'Test Author <test@nodomain.com>',
                          commit_timestamp=12345, commit_timezone=0,
                          author_timestamp=12345, author_timezone=0)

        with open(pre_commit, 'w') as f:
            f.write(pre_commit_success)
        os.chmod(pre_commit, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)

        commit_sha = r.do_commit(
            b'empty commit',
            committer=b'Test Committer <test@nodomain.com>',
            author=b'Test Author <test@nodomain.com>',
            commit_timestamp=12395, commit_timezone=0,
            author_timestamp=12395, author_timezone=0)
        self.assertEqual([], r[commit_sha].parents)

    def test_shell_hook_commit_msg(self):
        if os.name != 'posix':
            self.skipTest('shell hook tests requires POSIX shell')

        commit_msg_fail = """#!/bin/sh
exit 1
"""

        commit_msg_success = """#!/bin/sh
exit 0
"""

        repo_dir = self.mkdtemp()
        self.addCleanup(shutil.rmtree, repo_dir)
        r = Repo.init(repo_dir)
        self.addCleanup(r.close)

        commit_msg = os.path.join(r.controldir(), 'hooks', 'commit-msg')

        with open(commit_msg, 'w') as f:
            f.write(commit_msg_fail)
        os.chmod(commit_msg, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)

        self.assertRaises(errors.CommitError, r.do_commit, b'failed commit',
                          committer=b'Test Committer <test@nodomain.com>',
                          author=b'Test Author <test@nodomain.com>',
                          commit_timestamp=12345, commit_timezone=0,
                          author_timestamp=12345, author_timezone=0)

        with open(commit_msg, 'w') as f:
            f.write(commit_msg_success)
        os.chmod(commit_msg, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)

        commit_sha = r.do_commit(
            b'empty commit',
            committer=b'Test Committer <test@nodomain.com>',
            author=b'Test Author <test@nodomain.com>',
            commit_timestamp=12395, commit_timezone=0,
            author_timestamp=12395, author_timezone=0)
        self.assertEqual([], r[commit_sha].parents)

    def test_shell_hook_post_commit(self):
        if os.name != 'posix':
            self.skipTest('shell hook tests requires POSIX shell')

        repo_dir = self.mkdtemp()
        self.addCleanup(shutil.rmtree, repo_dir)

        r = Repo.init(repo_dir)
        self.addCleanup(r.close)

        # The post-commit hook deletes this file, proving that it ran.
        (fd, path) = tempfile.mkstemp(dir=repo_dir)
        os.close(fd)
        post_commit_msg = """#!/bin/sh
rm """ + path + """
"""

        root_sha = r.do_commit(
            b'empty commit',
            committer=b'Test Committer <test@nodomain.com>',
            author=b'Test Author <test@nodomain.com>',
            commit_timestamp=12345, commit_timezone=0,
            author_timestamp=12345, author_timezone=0)
        self.assertEqual([], r[root_sha].parents)

        post_commit = os.path.join(r.controldir(), 'hooks', 'post-commit')

        with open(post_commit, 'wb') as f:
            f.write(post_commit_msg.encode(locale.getpreferredencoding()))
        os.chmod(post_commit, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)

        commit_sha = r.do_commit(
            b'empty commit',
            committer=b'Test Committer <test@nodomain.com>',
            author=b'Test Author <test@nodomain.com>',
            commit_timestamp=12345, commit_timezone=0,
            author_timestamp=12345, author_timezone=0)
        self.assertEqual([root_sha], r[commit_sha].parents)

        self.assertFalse(os.path.exists(path))

        post_commit_msg_fail = """#!/bin/sh
exit 1
"""
        with open(post_commit, 'w') as f:
            f.write(post_commit_msg_fail)
        os.chmod(post_commit, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)

        # A failing post-commit hook must only warn; the commit still lands.
        warnings.simplefilter("always", UserWarning)
        self.addCleanup(warnings.resetwarnings)
        warnings_list, restore_warnings = setup_warning_catcher()
        self.addCleanup(restore_warnings)

        commit_sha2 = r.do_commit(
            b'empty commit',
            committer=b'Test Committer <test@nodomain.com>',
            author=b'Test Author <test@nodomain.com>',
            commit_timestamp=12345, commit_timezone=0,
            author_timestamp=12345, author_timezone=0)
        expected_warning = UserWarning(
            'post-commit hook failed: Hook post-commit exited with '
            'non-zero status',)
        for w in warnings_list:
            if (type(w) == type(expected_warning) and
                    w.args == expected_warning.args):
                break
        else:
            raise AssertionError(
                'Expected warning %r not in %r' %
                (expected_warning, warnings_list))
        self.assertEqual([commit_sha], r[commit_sha2].parents)

    def test_as_dict(self):
        def check(repo):
            # Ref prefixes must behave the same with or without trailing '/'.
            self.assertEqual(
                repo.refs.subkeys(b'refs/tags'),
                repo.refs.subkeys(b'refs/tags/'))
            self.assertEqual(
                repo.refs.as_dict(b'refs/tags'),
                repo.refs.as_dict(b'refs/tags/'))
            self.assertEqual(
                repo.refs.as_dict(b'refs/heads'),
                repo.refs.as_dict(b'refs/heads/'))

        bare = self.open_repo('a.git')
        tmp_dir = self.mkdtemp()
        self.addCleanup(shutil.rmtree, tmp_dir)
        with bare.clone(tmp_dir, mkdir=False) as nonbare:
            check(nonbare)
            check(bare)

    def test_working_tree(self):
        temp_dir = self.mkdtemp()
        self.addCleanup(shutil.rmtree, temp_dir)
        worktree_temp_dir = self.mkdtemp()
        self.addCleanup(shutil.rmtree, worktree_temp_dir)
        r = Repo.init(temp_dir)
        self.addCleanup(r.close)
        root_sha = r.do_commit(
                b'empty commit',
                committer=b'Test Committer <test@nodomain.com>',
                author=b'Test Author <test@nodomain.com>',
                commit_timestamp=12345, commit_timezone=0,
                author_timestamp=12345, author_timezone=0)
        r.refs[b'refs/heads/master'] = root_sha
        w = Repo._init_new_working_directory(worktree_temp_dir, r)
        self.addCleanup(w.close)
        new_sha = w.do_commit(
                b'new commit',
                committer=b'Test Committer <test@nodomain.com>',
                author=b'Test Author <test@nodomain.com>',
                commit_timestamp=12345, commit_timezone=0,
                author_timestamp=12345, author_timezone=0)
        w.refs[b'HEAD'] = new_sha
        # The worktree shares refs with the main repo but has its own HEAD.
        self.assertEqual(os.path.abspath(r.controldir()),
                         os.path.abspath(w.commondir()))
        self.assertEqual(r.refs.keys(), w.refs.keys())
        self.assertNotEqual(r.head(), w.head())
659         self.assertNotEqual(r.head(), w.head())
660
661
662 class BuildRepoRootTests(TestCase):
663     """Tests that build on-disk repos from scratch.
664
665     Repos live in a temp dir and are torn down after each test. They start with
666     a single commit in master having single file named 'a'.
667     """
668
669     def get_repo_dir(self):
670         return os.path.join(tempfile.mkdtemp(), 'test')
671
672     def setUp(self):
673         super(BuildRepoRootTests, self).setUp()
674         self._repo_dir = self.get_repo_dir()
675         os.makedirs(self._repo_dir)
676         r = self._repo = Repo.init(self._repo_dir)
677         self.addCleanup(tear_down_repo, r)
678         self.assertFalse(r.bare)
679         self.assertEqual(b'ref: refs/heads/master', r.refs.read_ref(b'HEAD'))
680         self.assertRaises(KeyError, lambda: r.refs[b'refs/heads/master'])
681
682         with open(os.path.join(r.path, 'a'), 'wb') as f:
683             f.write(b'file contents')
684         r.stage(['a'])
685         commit_sha = r.do_commit(
686                 b'msg',
687                 committer=b'Test Committer <test@nodomain.com>',
688                 author=b'Test Author <test@nodomain.com>',
689                 commit_timestamp=12345, commit_timezone=0,
690                 author_timestamp=12345, author_timezone=0)
691         self.assertEqual([], r[commit_sha].parents)
692         self._root_commit = commit_sha
693
694     def test_get_shallow(self):
695         self.assertEqual(set(), self._repo.get_shallow())
696         with open(os.path.join(self._repo.path, '.git', 'shallow'), 'wb') as f:
697             f.write(b'a90fa2d900a17e99b433217e988c4eb4a2e9a097\n')
698         self.assertEqual({b'a90fa2d900a17e99b433217e988c4eb4a2e9a097'},
699                          self._repo.get_shallow())
700
701     def test_update_shallow(self):
702         self._repo.update_shallow(None, None)  # no op
703         self.assertEquals(set(), self._repo.get_shallow())
704         self._repo.update_shallow(
705                 [b'a90fa2d900a17e99b433217e988c4eb4a2e9a097'],
706                 None)
707         self.assertEqual(
708                 {b'a90fa2d900a17e99b433217e988c4eb4a2e9a097'},
709                 self._repo.get_shallow())
710         self._repo.update_shallow(
711                 [b'a90fa2d900a17e99b433217e988c4eb4a2e9a097'],
712                 [b'f9e39b120c68182a4ba35349f832d0e4e61f485c'])
713         self.assertEqual({b'a90fa2d900a17e99b433217e988c4eb4a2e9a097'},
714                          self._repo.get_shallow())
715
716     def test_build_repo(self):
717         r = self._repo
718         self.assertEqual(b'ref: refs/heads/master', r.refs.read_ref(b'HEAD'))
719         self.assertEqual(self._root_commit, r.refs[b'refs/heads/master'])
720         expected_blob = objects.Blob.from_string(b'file contents')
721         self.assertEqual(expected_blob.data, r[expected_blob.id].data)
722         actual_commit = r[self._root_commit]
723         self.assertEqual(b'msg', actual_commit.message)
724
725     def test_commit_modified(self):
726         r = self._repo
727         with open(os.path.join(r.path, 'a'), 'wb') as f:
728             f.write(b'new contents')
729         r.stage(['a'])
730         commit_sha = r.do_commit(
731             b'modified a',
732             committer=b'Test Committer <test@nodomain.com>',
733             author=b'Test Author <test@nodomain.com>',
734             commit_timestamp=12395, commit_timezone=0,
735             author_timestamp=12395, author_timezone=0)
736         self.assertEqual([self._root_commit], r[commit_sha].parents)
737         a_mode, a_id = tree_lookup_path(r.get_object, r[commit_sha].tree, b'a')
738         self.assertEqual(stat.S_IFREG | 0o644, a_mode)
739         self.assertEqual(b'new contents', r[a_id].data)
740
741     @skipIf(not getattr(os, 'symlink', None), 'Requires symlink support')
742     def test_commit_symlink(self):
743         r = self._repo
744         os.symlink('a', os.path.join(r.path, 'b'))
745         r.stage(['a', 'b'])
746         commit_sha = r.do_commit(
747             b'Symlink b',
748             committer=b'Test Committer <test@nodomain.com>',
749             author=b'Test Author <test@nodomain.com>',
750             commit_timestamp=12395, commit_timezone=0,
751             author_timestamp=12395, author_timezone=0)
752         self.assertEqual([self._root_commit], r[commit_sha].parents)
753         b_mode, b_id = tree_lookup_path(r.get_object, r[commit_sha].tree, b'b')
754         self.assertTrue(stat.S_ISLNK(b_mode))
755         self.assertEqual(b'a', r[b_id].data)
756
757     def test_commit_deleted(self):
758         r = self._repo
759         os.remove(os.path.join(r.path, 'a'))
760         r.stage(['a'])
761         commit_sha = r.do_commit(
762             b'deleted a',
763             committer=b'Test Committer <test@nodomain.com>',
764             author=b'Test Author <test@nodomain.com>',
765             commit_timestamp=12395, commit_timezone=0,
766             author_timestamp=12395, author_timezone=0)
767         self.assertEqual([self._root_commit], r[commit_sha].parents)
768         self.assertEqual([], list(r.open_index()))
769         tree = r[r[commit_sha].tree]
770         self.assertEqual([], list(tree.iteritems()))
771
772     def test_commit_follows(self):
773         r = self._repo
774         r.refs.set_symbolic_ref(b'HEAD', b'refs/heads/bla')
775         commit_sha = r.do_commit(
776             b'commit with strange character',
777             committer=b'Test Committer <test@nodomain.com>',
778             author=b'Test Author <test@nodomain.com>',
779             commit_timestamp=12395, commit_timezone=0,
780             author_timestamp=12395, author_timezone=0,
781             ref=b'HEAD')
782         self.assertEqual(commit_sha, r[b'refs/heads/bla'].id)
783
784     def test_commit_encoding(self):
785         r = self._repo
786         commit_sha = r.do_commit(
787             b'commit with strange character \xee',
788             committer=b'Test Committer <test@nodomain.com>',
789             author=b'Test Author <test@nodomain.com>',
790             commit_timestamp=12395, commit_timezone=0,
791             author_timestamp=12395, author_timezone=0,
792             encoding=b"iso8859-1")
793         self.assertEqual(b"iso8859-1", r[commit_sha].encoding)
794
795     def test_commit_encoding_from_config(self):
796         r = self._repo
797         c = r.get_config()
798         c.set(('i18n',), 'commitEncoding', 'iso8859-1')
799         c.write_to_path()
800         commit_sha = r.do_commit(
801             b'commit with strange character \xee',
802             committer=b'Test Committer <test@nodomain.com>',
803             author=b'Test Author <test@nodomain.com>',
804             commit_timestamp=12395, commit_timezone=0,
805             author_timestamp=12395, author_timezone=0)
806         self.assertEqual(b"iso8859-1", r[commit_sha].encoding)
807
808     def test_commit_config_identity(self):
809         # commit falls back to the users' identity if it wasn't specified
810         r = self._repo
811         c = r.get_config()
812         c.set((b"user", ), b"name", b"Jelmer")
813         c.set((b"user", ), b"email", b"jelmer@apache.org")
814         c.write_to_path()
815         commit_sha = r.do_commit(b'message')
816         self.assertEqual(
817             b"Jelmer <jelmer@apache.org>",
818             r[commit_sha].author)
819         self.assertEqual(
820             b"Jelmer <jelmer@apache.org>",
821             r[commit_sha].committer)
822
823     def test_commit_config_identity_in_memoryrepo(self):
824         # commit falls back to the users' identity if it wasn't specified
825         r = MemoryRepo.init_bare([], {})
826         c = r.get_config()
827         c.set((b"user", ), b"name", b"Jelmer")
828         c.set((b"user", ), b"email", b"jelmer@apache.org")
829
830         commit_sha = r.do_commit(b'message', tree=objects.Tree().id)
831         self.assertEqual(
832             b"Jelmer <jelmer@apache.org>",
833             r[commit_sha].author)
834         self.assertEqual(
835             b"Jelmer <jelmer@apache.org>",
836             r[commit_sha].committer)
837
838     def test_commit_fail_ref(self):
839         r = self._repo
840
841         def set_if_equals(name, old_ref, new_ref, **kwargs):
842             return False
843         r.refs.set_if_equals = set_if_equals
844
845         def add_if_new(name, new_ref, **kwargs):
846             self.fail('Unexpected call to add_if_new')
847         r.refs.add_if_new = add_if_new
848
849         old_shas = set(r.object_store)
850         self.assertRaises(errors.CommitError, r.do_commit, b'failed commit',
851                           committer=b'Test Committer <test@nodomain.com>',
852                           author=b'Test Author <test@nodomain.com>',
853                           commit_timestamp=12345, commit_timezone=0,
854                           author_timestamp=12345, author_timezone=0)
855         new_shas = set(r.object_store) - old_shas
856         self.assertEqual(1, len(new_shas))
857         # Check that the new commit (now garbage) was added.
858         new_commit = r[new_shas.pop()]
859         self.assertEqual(r[self._root_commit].tree, new_commit.tree)
860         self.assertEqual(b'failed commit', new_commit.message)
861
862     def test_commit_branch(self):
863         r = self._repo
864
865         commit_sha = r.do_commit(
866             b'commit to branch',
867             committer=b'Test Committer <test@nodomain.com>',
868             author=b'Test Author <test@nodomain.com>',
869             commit_timestamp=12395, commit_timezone=0,
870             author_timestamp=12395, author_timezone=0,
871             ref=b"refs/heads/new_branch")
872         self.assertEqual(self._root_commit, r[b"HEAD"].id)
873         self.assertEqual(commit_sha, r[b"refs/heads/new_branch"].id)
874         self.assertEqual([], r[commit_sha].parents)
875         self.assertTrue(b"refs/heads/new_branch" in r)
876
877         new_branch_head = commit_sha
878
879         commit_sha = r.do_commit(
880             b'commit to branch 2',
881             committer=b'Test Committer <test@nodomain.com>',
882             author=b'Test Author <test@nodomain.com>',
883             commit_timestamp=12395, commit_timezone=0,
884             author_timestamp=12395, author_timezone=0,
885             ref=b"refs/heads/new_branch")
886         self.assertEqual(self._root_commit, r[b"HEAD"].id)
887         self.assertEqual(commit_sha, r[b"refs/heads/new_branch"].id)
888         self.assertEqual([new_branch_head], r[commit_sha].parents)
889
890     def test_commit_merge_heads(self):
891         r = self._repo
892         merge_1 = r.do_commit(
893             b'commit to branch 2',
894             committer=b'Test Committer <test@nodomain.com>',
895             author=b'Test Author <test@nodomain.com>',
896             commit_timestamp=12395, commit_timezone=0,
897             author_timestamp=12395, author_timezone=0,
898             ref=b"refs/heads/new_branch")
899         commit_sha = r.do_commit(
900             b'commit with merge',
901             committer=b'Test Committer <test@nodomain.com>',
902             author=b'Test Author <test@nodomain.com>',
903             commit_timestamp=12395, commit_timezone=0,
904             author_timestamp=12395, author_timezone=0,
905             merge_heads=[merge_1])
906         self.assertEqual(
907             [self._root_commit, merge_1],
908             r[commit_sha].parents)
909
910     def test_commit_dangling_commit(self):
911         r = self._repo
912
913         old_shas = set(r.object_store)
914         old_refs = r.get_refs()
915         commit_sha = r.do_commit(
916             b'commit with no ref',
917             committer=b'Test Committer <test@nodomain.com>',
918             author=b'Test Author <test@nodomain.com>',
919             commit_timestamp=12395, commit_timezone=0,
920             author_timestamp=12395, author_timezone=0,
921             ref=None)
922         new_shas = set(r.object_store) - old_shas
923
924         # New sha is added, but no new refs
925         self.assertEqual(1, len(new_shas))
926         new_commit = r[new_shas.pop()]
927         self.assertEqual(r[self._root_commit].tree, new_commit.tree)
928         self.assertEqual([], r[commit_sha].parents)
929         self.assertEqual(old_refs, r.get_refs())
930
931     def test_commit_dangling_commit_with_parents(self):
932         r = self._repo
933
934         old_shas = set(r.object_store)
935         old_refs = r.get_refs()
936         commit_sha = r.do_commit(
937             b'commit with no ref',
938             committer=b'Test Committer <test@nodomain.com>',
939             author=b'Test Author <test@nodomain.com>',
940             commit_timestamp=12395, commit_timezone=0,
941             author_timestamp=12395, author_timezone=0,
942             ref=None, merge_heads=[self._root_commit])
943         new_shas = set(r.object_store) - old_shas
944
945         # New sha is added, but no new refs
946         self.assertEqual(1, len(new_shas))
947         new_commit = r[new_shas.pop()]
948         self.assertEqual(r[self._root_commit].tree, new_commit.tree)
949         self.assertEqual([self._root_commit], r[commit_sha].parents)
950         self.assertEqual(old_refs, r.get_refs())
951
952     def test_stage_absolute(self):
953         r = self._repo
954         os.remove(os.path.join(r.path, 'a'))
955         self.assertRaises(ValueError, r.stage, [os.path.join(r.path, 'a')])
956
957     def test_stage_deleted(self):
958         r = self._repo
959         os.remove(os.path.join(r.path, 'a'))
960         r.stage(['a'])
961         r.stage(['a'])  # double-stage a deleted path
962
963     def test_stage_directory(self):
964         r = self._repo
965         os.mkdir(os.path.join(r.path, 'c'))
966         r.stage(['c'])
967         self.assertEqual([b'a'], list(r.open_index()))
968
969     @skipIf(sys.platform == 'win32' and sys.version_info[:2] >= (3, 6),
970             'tries to implicitly decode as utf8')
971     def test_commit_no_encode_decode(self):
972         r = self._repo
973         repo_path_bytes = r.path.encode(sys.getfilesystemencoding())
974         encodings = ('utf8', 'latin1')
975         names = [u'À'.encode(encoding) for encoding in encodings]
976         for name, encoding in zip(names, encodings):
977             full_path = os.path.join(repo_path_bytes, name)
978             with open(full_path, 'wb') as f:
979                 f.write(encoding.encode('ascii'))
980             # These files are break tear_down_repo, so cleanup these files
981             # ourselves.
982             self.addCleanup(os.remove, full_path)
983
984         r.stage(names)
985         commit_sha = r.do_commit(
986             b'Files with different encodings',
987             committer=b'Test Committer <test@nodomain.com>',
988             author=b'Test Author <test@nodomain.com>',
989             commit_timestamp=12395, commit_timezone=0,
990             author_timestamp=12395, author_timezone=0,
991             ref=None, merge_heads=[self._root_commit])
992
993         for name, encoding in zip(names, encodings):
994             mode, id = tree_lookup_path(r.get_object, r[commit_sha].tree, name)
995             self.assertEqual(stat.S_IFREG | 0o644, mode)
996             self.assertEqual(encoding.encode('ascii'), r[id].data)
997
998     def test_discover_intended(self):
999         path = os.path.join(self._repo_dir, 'b/c')
1000         r = Repo.discover(path)
1001         self.assertEqual(r.head(), self._repo.head())
1002
1003     def test_discover_isrepo(self):
1004         r = Repo.discover(self._repo_dir)
1005         self.assertEqual(r.head(), self._repo.head())
1006
1007     def test_discover_notrepo(self):
1008         with self.assertRaises(NotGitRepository):
1009             Repo.discover('/')
1010
1011
class CheckUserIdentityTests(TestCase):
    """Tests for check_user_identity()."""

    def test_valid(self):
        # A well-formed "Name <email>" identity must not raise.
        check_user_identity(b'Me <me@example.com>')

    def test_invalid(self):
        # Each malformed identity must be rejected, checked in this order.
        malformed = (
            b'No Email',
            b'Fullname <missing',
            b'Fullname missing>',
            b'Fullname >order<>',
        )
        for identity in malformed:
            self.assertRaises(InvalidUserIdentity,
                              check_user_identity, identity)