• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (C) 2003 Python Software Foundation
2
3import unittest
4import shutil
5import tempfile
6import sys
7import stat
8import os
9import os.path
10import errno
11import subprocess
12from distutils.spawn import find_executable
13from shutil import (make_archive,
14                    register_archive_format, unregister_archive_format,
15                    get_archive_formats)
16import tarfile
17import warnings
18
19from test import test_support as support
20from test.test_support import TESTFN, check_warnings, captured_stdout
21
22TESTFN2 = TESTFN + "2"
23
24try:
25    import grp
26    import pwd
27    UID_GID_SUPPORT = True
28except ImportError:
29    UID_GID_SUPPORT = False
30
31try:
32    import zlib
33except ImportError:
34    zlib = None
35
36try:
37    import zipfile
38    ZIP_SUPPORT = True
39except ImportError:
40    ZIP_SUPPORT = find_executable('zip')
41
class TestShutil(unittest.TestCase):
    """Tests for shutil's tree helpers (rmtree, copytree, copystat, copyfile)
    and for the archive API (make_archive and the archive-format registry)."""

    def setUp(self):
        super(TestShutil, self).setUp()
        # Directories handed out by self.mkdtemp(); removed in tearDown().
        self.tempdirs = []

    def tearDown(self):
        super(TestShutil, self).tearDown()
        while self.tempdirs:
            d = self.tempdirs.pop()
            # Removal errors are only ignored on Windows-like platforms.
            shutil.rmtree(d, os.name in ('nt', 'cygwin'))

    def write_file(self, path, content='xxx'):
        """Write *content* to the file at *path* (text mode).

        path can be a string or a sequence of path components, which are
        joined with os.path.join().
        """
        if isinstance(path, (list, tuple)):
            path = os.path.join(*path)
        with open(path, 'w') as f:
            f.write(content)

    def read_file(self, path):
        """Return the contents of the file at *path* (text mode).

        path can be a string or a sequence of path components, which are
        joined with os.path.join().
        """
        if isinstance(path, (list, tuple)):
            path = os.path.join(*path)
        with open(path) as f:
            return f.read()

    def mkdtemp(self):
        """Create a temporary directory that will be cleaned up.

        Returns the path of the directory.
        """
        d = tempfile.mkdtemp()
        self.tempdirs.append(d)
        return d

    def test_rmtree_errors(self):
        # A name inside a fresh, empty temporary directory is guaranteed not
        # to exist (unlike tempfile.mktemp(), which is racy).
        filename = os.path.join(self.mkdtemp(), 'missing')
        self.assertRaises(OSError, shutil.rmtree, filename)

    @unittest.skipUnless(hasattr(os, 'chmod'), 'requires os.chmod()')
    @unittest.skipIf(sys.platform[:6] == 'cygwin',
                     "This test can't be run on Cygwin (issue #1071513).")
    @unittest.skipIf(hasattr(os, 'geteuid') and os.geteuid() == 0,
                     "This test can't be run reliably as root (issue #1076467).")
    def test_on_error(self):
        self.errorState = 0
        os.mkdir(TESTFN)
        self.childpath = os.path.join(TESTFN, 'a')
        with open(self.childpath, 'w'):
            pass
        old_dir_mode = os.stat(TESTFN).st_mode
        old_child_mode = os.stat(self.childpath).st_mode
        try:
            # Make unwritable.
            os.chmod(self.childpath, stat.S_IREAD)
            os.chmod(TESTFN, stat.S_IREAD)

            shutil.rmtree(TESTFN, onerror=self.check_args_to_onerror)
            # Test whether onerror has actually been called.
            self.assertEqual(self.errorState, 2,
                             "Expected call to onerror function did not happen.")
        finally:
            # Make writable again and clean up even if an assertion above
            # failed, so later tests don't trip over a read-only TESTFN.
            # The directory must be restored first so the child is reachable.
            os.chmod(TESTFN, old_dir_mode)
            os.chmod(self.childpath, old_child_mode)
            shutil.rmtree(TESTFN)

    def check_args_to_onerror(self, func, arg, exc):
        # onerror callback for test_on_error, which deliberately runs rmtree
        # on a directory that is chmod 400, so removal fails.
        # 99.9% of the time it initially fails to remove
        # a file in the directory, so the first time through
        # func is os.remove.
        # However, some Linux machines running ZFS on
        # FUSE experienced a failure earlier in the process
        # at os.listdir.  The first failure may legally
        # be either.  The second failure must be os.rmdir on TESTFN itself.
        if self.errorState == 0:
            if func is os.remove:
                self.assertEqual(arg, self.childpath)
            else:
                self.assertIs(func, os.listdir,
                              "func must be either os.remove or os.listdir")
                self.assertEqual(arg, TESTFN)
            self.assertTrue(issubclass(exc[0], OSError))
            self.errorState = 1
        else:
            self.assertIs(func, os.rmdir)
            self.assertEqual(arg, TESTFN)
            self.assertTrue(issubclass(exc[0], OSError))
            self.errorState = 2

    def test_rmtree_dont_delete_file(self):
        # When called on a file instead of a directory, don't delete it.
        handle, path = tempfile.mkstemp()
        os.close(handle)
        try:
            self.assertRaises(OSError, shutil.rmtree, path)
            self.assertTrue(os.path.isfile(path))
        finally:
            # Guarded so a wrongly-deleted file doesn't mask the assertion.
            if os.path.exists(path):
                os.remove(path)

    def test_copytree_simple(self):
        # Both directories are registered with self.mkdtemp(), so tearDown()
        # removes everything created here.
        src_dir = self.mkdtemp()
        dst_dir = os.path.join(self.mkdtemp(), 'destination')

        self.write_file((src_dir, 'test.txt'), '123')
        os.mkdir(os.path.join(src_dir, 'test_dir'))
        self.write_file((src_dir, 'test_dir', 'test.txt'), '456')

        shutil.copytree(src_dir, dst_dir)
        self.assertTrue(os.path.isfile(os.path.join(dst_dir, 'test.txt')))
        self.assertTrue(os.path.isdir(os.path.join(dst_dir, 'test_dir')))
        self.assertTrue(os.path.isfile(os.path.join(dst_dir, 'test_dir',
                                                    'test.txt')))
        self.assertEqual(self.read_file((dst_dir, 'test.txt')), '123')
        self.assertEqual(self.read_file((dst_dir, 'test_dir', 'test.txt')),
                         '456')

    def test_copytree_with_exclude(self):
        # creating data (cleaned up by tearDown() via self.mkdtemp())
        join = os.path.join
        exists = os.path.exists
        src_dir = self.mkdtemp()
        dst_dir = join(self.mkdtemp(), 'destination')
        self.write_file((src_dir, 'test.txt'), '123')
        self.write_file((src_dir, 'test.tmp'), '123')
        os.mkdir(join(src_dir, 'test_dir'))
        self.write_file((src_dir, 'test_dir', 'test.txt'), '456')
        os.mkdir(join(src_dir, 'test_dir2'))
        self.write_file((src_dir, 'test_dir2', 'test.txt'), '456')
        os.mkdir(join(src_dir, 'test_dir2', 'subdir'))
        os.mkdir(join(src_dir, 'test_dir2', 'subdir2'))
        self.write_file((src_dir, 'test_dir2', 'subdir', 'test.txt'), '456')
        self.write_file((src_dir, 'test_dir2', 'subdir2', 'test.py'), '456')

        # testing glob-like patterns
        try:
            patterns = shutil.ignore_patterns('*.tmp', 'test_dir2')
            shutil.copytree(src_dir, dst_dir, ignore=patterns)
            # checking the result: some elements should not be copied
            self.assertTrue(exists(join(dst_dir, 'test.txt')))
            self.assertFalse(exists(join(dst_dir, 'test.tmp')))
            self.assertFalse(exists(join(dst_dir, 'test_dir2')))
        finally:
            if os.path.exists(dst_dir):
                shutil.rmtree(dst_dir)
        try:
            patterns = shutil.ignore_patterns('*.tmp', 'subdir*')
            shutil.copytree(src_dir, dst_dir, ignore=patterns)
            # checking the result: some elements should not be copied
            self.assertFalse(exists(join(dst_dir, 'test.tmp')))
            self.assertFalse(exists(join(dst_dir, 'test_dir2', 'subdir2')))
            self.assertFalse(exists(join(dst_dir, 'test_dir2', 'subdir')))
        finally:
            if os.path.exists(dst_dir):
                shutil.rmtree(dst_dir)

        # testing callable-style
        def _filter(src, names):
            # Ignore directories named exactly 'subdir' and any '*.py' file.
            # Compare the bare entry name: str.split() splits on whitespace,
            # and "x in ('.py')" is a substring test that matches ''.
            ignored = []
            for name in names:
                path = join(src, name)
                if os.path.isdir(path) and name == 'subdir':
                    ignored.append(name)
                elif os.path.splitext(name)[1] == '.py':
                    ignored.append(name)
            return ignored

        try:
            shutil.copytree(src_dir, dst_dir, ignore=_filter)

            # checking the result: some elements should not be copied...
            self.assertFalse(exists(join(dst_dir, 'test_dir2', 'subdir2',
                                         'test.py')))
            self.assertFalse(exists(join(dst_dir, 'test_dir2', 'subdir')))
            # ...while everything the filter does not match is still copied.
            self.assertTrue(exists(join(dst_dir, 'test_dir2', 'subdir2')))
            self.assertTrue(exists(join(dst_dir, 'test_dir', 'test.txt')))
        finally:
            if os.path.exists(dst_dir):
                shutil.rmtree(dst_dir)

    if hasattr(os, "symlink"):
        def test_dont_copy_file_onto_link_to_itself(self):
            # bug 851123.
            os.mkdir(TESTFN)
            src = os.path.join(TESTFN, 'cheese')
            dst = os.path.join(TESTFN, 'shop')
            try:
                with open(src, 'w') as f:
                    f.write('cheddar')

                os.link(src, dst)
                self.assertRaises(shutil.Error, shutil.copyfile, src, dst)
                with open(src, 'r') as f:
                    self.assertEqual(f.read(), 'cheddar')
                os.remove(dst)

                # Using `src` here would mean we end up with a symlink pointing
                # to TESTFN/TESTFN/cheese, while it should point at
                # TESTFN/cheese.
                os.symlink('cheese', dst)
                self.assertRaises(shutil.Error, shutil.copyfile, src, dst)
                with open(src, 'r') as f:
                    self.assertEqual(f.read(), 'cheddar')
                os.remove(dst)
            finally:
                try:
                    shutil.rmtree(TESTFN)
                except OSError:
                    pass

        def test_rmtree_on_symlink(self):
            # bug 1669.
            os.mkdir(TESTFN)
            try:
                src = os.path.join(TESTFN, 'cheese')
                dst = os.path.join(TESTFN, 'shop')
                os.mkdir(src)
                os.symlink(src, dst)
                self.assertRaises(OSError, shutil.rmtree, dst)
            finally:
                shutil.rmtree(TESTFN, ignore_errors=True)

    # Issue #3002: copyfile and copytree block indefinitely on named pipes
    @unittest.skipUnless(hasattr(os, "mkfifo"), 'requires os.mkfifo()')
    def test_copyfile_named_pipe(self):
        os.mkfifo(TESTFN)
        try:
            self.assertRaises(shutil.SpecialFileError,
                              shutil.copyfile, TESTFN, TESTFN2)
            self.assertRaises(shutil.SpecialFileError,
                              shutil.copyfile, __file__, TESTFN)
        finally:
            os.remove(TESTFN)

    @unittest.skipUnless(hasattr(os, "mkfifo"), 'requires os.mkfifo()')
    def test_copytree_named_pipe(self):
        os.mkdir(TESTFN)
        try:
            subdir = os.path.join(TESTFN, "subdir")
            os.mkdir(subdir)
            pipe = os.path.join(subdir, "mypipe")
            os.mkfifo(pipe)
            try:
                shutil.copytree(TESTFN, TESTFN2)
            except shutil.Error as e:
                errors = e.args[0]
                self.assertEqual(len(errors), 1)
                src, dst, error_msg = errors[0]
                self.assertEqual("`%s` is a named pipe" % pipe, error_msg)
            else:
                self.fail("shutil.Error should have been raised")
        finally:
            shutil.rmtree(TESTFN, ignore_errors=True)
            shutil.rmtree(TESTFN2, ignore_errors=True)

    @unittest.skipUnless(hasattr(os, 'chflags') and
                         hasattr(errno, 'EOPNOTSUPP') and
                         hasattr(errno, 'ENOTSUP'),
                         "requires os.chflags, EOPNOTSUPP & ENOTSUP")
    def test_copystat_handles_harmless_chflags_errors(self):
        tmpdir = self.mkdtemp()
        file1 = os.path.join(tmpdir, 'file1')
        file2 = os.path.join(tmpdir, 'file2')
        self.write_file(file1, 'xxx')
        self.write_file(file2, 'xxx')

        def make_chflags_raiser(err):
            # Return a stand-in for os.chflags that always fails with errno err.
            ex = OSError()

            def _chflags_raiser(path, flags):
                ex.errno = err
                raise ex
            return _chflags_raiser
        old_chflags = os.chflags
        try:
            for err in errno.EOPNOTSUPP, errno.ENOTSUP:
                os.chflags = make_chflags_raiser(err)
                shutil.copystat(file1, file2)
            # assert others errors break it
            os.chflags = make_chflags_raiser(errno.EOPNOTSUPP + errno.ENOTSUP)
            self.assertRaises(OSError, shutil.copystat, file1, file2)
        finally:
            os.chflags = old_chflags

    @unittest.skipUnless(zlib, "requires zlib")
    def test_make_tarball(self):
        # creating something to tar
        root_dir, base_dir = self._create_files('')

        tmpdir2 = self.mkdtemp()
        # force shutil to create the directory
        os.rmdir(tmpdir2)
        # working with relative paths
        work_dir = os.path.dirname(tmpdir2)
        rel_base_name = os.path.join(os.path.basename(tmpdir2), 'archive')

        with support.change_cwd(work_dir):
            base_name = os.path.abspath(rel_base_name)
            tarball = make_archive(rel_base_name, 'gztar', root_dir, '.')

        # check if the compressed tarball was created
        self.assertEqual(tarball, base_name + '.tar.gz')
        self.assertTrue(os.path.isfile(tarball))
        self.assertTrue(tarfile.is_tarfile(tarball))
        with tarfile.open(tarball, 'r:gz') as tf:
            self.assertEqual(sorted(tf.getnames()),
                             ['.', './file1', './file2',
                              './sub', './sub/file3', './sub2'])

        # trying an uncompressed one
        with support.change_cwd(work_dir):
            tarball = make_archive(rel_base_name, 'tar', root_dir, '.')
        self.assertEqual(tarball, base_name + '.tar')
        self.assertTrue(os.path.isfile(tarball))
        self.assertTrue(tarfile.is_tarfile(tarball))
        with tarfile.open(tarball, 'r') as tf:
            self.assertEqual(sorted(tf.getnames()),
                             ['.', './file1', './file2',
                              './sub', './sub/file3', './sub2'])

    def _tarinfo(self, path):
        """Return the sorted member names of the tar archive at *path*."""
        with tarfile.open(path) as tar:
            return tuple(sorted(tar.getnames()))

    def _create_files(self, base_dir='dist'):
        """Populate a fresh temp dir with a small tree to archive.

        Creates root_dir/base_dir/{file1, file2, sub/file3, sub2/} and, when
        base_dir is non-empty, root_dir/outer.  Returns (root_dir, base_dir).
        """
        root_dir = self.mkdtemp()
        dist = os.path.join(root_dir, base_dir)
        if not os.path.isdir(dist):
            os.makedirs(dist)
        self.write_file((dist, 'file1'), 'xxx')
        self.write_file((dist, 'file2'), 'xxx')
        os.mkdir(os.path.join(dist, 'sub'))
        self.write_file((dist, 'sub', 'file3'), 'xxx')
        os.mkdir(os.path.join(dist, 'sub2'))
        if base_dir:
            self.write_file((root_dir, 'outer'), 'xxx')
        return root_dir, base_dir

    @unittest.skipUnless(zlib, "Requires zlib")
    @unittest.skipUnless(find_executable('tar'),
                         'Need the tar command to run')
    def test_tarfile_vs_tar(self):
        root_dir, base_dir = self._create_files()
        base_name = os.path.join(self.mkdtemp(), 'archive')
        tarball = make_archive(base_name, 'gztar', root_dir, base_dir)

        # check if the compressed tarball was created
        self.assertEqual(tarball, base_name + '.tar.gz')
        self.assertTrue(os.path.isfile(tarball))

        # now create another tarball using `tar`
        tarball2 = os.path.join(root_dir, 'archive2.tar')
        tar_cmd = ['tar', '-cf', 'archive2.tar', base_dir]
        subprocess.check_call(tar_cmd, cwd=root_dir)

        self.assertTrue(os.path.isfile(tarball2))
        # let's compare both tarballs
        self.assertEqual(self._tarinfo(tarball), self._tarinfo(tarball2))

        # trying an uncompressed one
        tarball = make_archive(base_name, 'tar', root_dir, base_dir)
        self.assertEqual(tarball, base_name + '.tar')
        self.assertTrue(os.path.isfile(tarball))

        # now for a dry_run
        tarball = make_archive(base_name, 'tar', root_dir, base_dir,
                               dry_run=True)
        self.assertEqual(tarball, base_name + '.tar')
        self.assertTrue(os.path.isfile(tarball))

    @unittest.skipUnless(zlib, "Requires zlib")
    @unittest.skipUnless(ZIP_SUPPORT, 'Need zip support to run')
    def test_make_zipfile(self):
        # creating something to zip
        root_dir, base_dir = self._create_files()

        tmpdir2 = self.mkdtemp()
        # force shutil to create the directory
        os.rmdir(tmpdir2)
        # working with relative paths
        work_dir = os.path.dirname(tmpdir2)
        rel_base_name = os.path.join(os.path.basename(tmpdir2), 'archive')

        with support.change_cwd(work_dir):
            base_name = os.path.abspath(rel_base_name)
            res = make_archive(rel_base_name, 'zip', root_dir)

        self.assertEqual(res, base_name + '.zip')
        self.assertTrue(os.path.isfile(res))
        self.assertTrue(zipfile.is_zipfile(res))
        with zipfile.ZipFile(res) as zf:
            self.assertEqual(sorted(zf.namelist()),
                    ['dist/', 'dist/file1', 'dist/file2',
                     'dist/sub/', 'dist/sub/file3', 'dist/sub2/',
                     'outer'])

        with support.change_cwd(work_dir):
            base_name = os.path.abspath(rel_base_name)
            res = make_archive(rel_base_name, 'zip', root_dir, base_dir)

        self.assertEqual(res, base_name + '.zip')
        self.assertTrue(os.path.isfile(res))
        self.assertTrue(zipfile.is_zipfile(res))
        with zipfile.ZipFile(res) as zf:
            self.assertEqual(sorted(zf.namelist()),
                    ['dist/', 'dist/file1', 'dist/file2',
                     'dist/sub/', 'dist/sub/file3', 'dist/sub2/'])

    @unittest.skipUnless(zlib, "Requires zlib")
    @unittest.skipUnless(ZIP_SUPPORT, 'Need zip support to run')
    @unittest.skipUnless(find_executable('zip'),
                         'Need the zip command to run')
    def test_zipfile_vs_zip(self):
        root_dir, base_dir = self._create_files()
        base_name = os.path.join(self.mkdtemp(), 'archive')
        archive = make_archive(base_name, 'zip', root_dir, base_dir)

        # check if ZIP file  was created
        self.assertEqual(archive, base_name + '.zip')
        self.assertTrue(os.path.isfile(archive))

        # now create another ZIP file using `zip`
        archive2 = os.path.join(root_dir, 'archive2.zip')
        zip_cmd = ['zip', '-q', '-r', 'archive2.zip', base_dir]
        subprocess.check_call(zip_cmd, cwd=root_dir)

        self.assertTrue(os.path.isfile(archive2))
        # let's compare both ZIP files
        with zipfile.ZipFile(archive) as zf:
            names = zf.namelist()
        with zipfile.ZipFile(archive2) as zf:
            names2 = zf.namelist()
        self.assertEqual(sorted(names), sorted(names2))

    @unittest.skipUnless(zlib, "Requires zlib")
    @unittest.skipUnless(ZIP_SUPPORT, 'Need zip support to run')
    @unittest.skipUnless(find_executable('unzip'),
                         'Need the unzip command to run')
    def test_unzip_zipfile(self):
        root_dir, base_dir = self._create_files()
        base_name = os.path.join(self.mkdtemp(), 'archive')
        archive = make_archive(base_name, 'zip', root_dir, base_dir)

        # check if ZIP file  was created
        self.assertEqual(archive, base_name + '.zip')
        self.assertTrue(os.path.isfile(archive))

        # now check the ZIP file using `unzip -t`
        zip_cmd = ['unzip', '-t', archive]
        with support.change_cwd(root_dir):
            try:
                subprocess.check_output(zip_cmd, stderr=subprocess.STDOUT)
            except subprocess.CalledProcessError as exc:
                details = exc.output
                msg = "{}\n\n**Unzip Output**\n{}"
                self.fail(msg.format(exc, details))

    def test_make_archive(self):
        tmpdir = self.mkdtemp()
        base_name = os.path.join(tmpdir, 'archive')
        self.assertRaises(ValueError, make_archive, base_name, 'xxx')

    @unittest.skipUnless(zlib, "Requires zlib")
    def test_make_archive_owner_group(self):
        # testing make_archive with owner and group, with various combinations
        # this works even if there's not gid/uid support
        if UID_GID_SUPPORT:
            group = grp.getgrgid(0)[0]
            owner = pwd.getpwuid(0)[0]
        else:
            group = owner = 'root'

        root_dir, base_dir = self._create_files()
        base_name = os.path.join(self.mkdtemp(), 'archive')
        res = make_archive(base_name, 'zip', root_dir, base_dir, owner=owner,
                           group=group)
        self.assertTrue(os.path.isfile(res))

        res = make_archive(base_name, 'zip', root_dir, base_dir)
        self.assertTrue(os.path.isfile(res))

        res = make_archive(base_name, 'tar', root_dir, base_dir,
                           owner=owner, group=group)
        self.assertTrue(os.path.isfile(res))

        res = make_archive(base_name, 'tar', root_dir, base_dir,
                           owner='kjhkjhkjg', group='oihohoh')
        self.assertTrue(os.path.isfile(res))

    @unittest.skipUnless(zlib, "Requires zlib")
    @unittest.skipUnless(UID_GID_SUPPORT, "Requires grp and pwd support")
    def test_tarfile_root_owner(self):
        root_dir, base_dir = self._create_files()
        base_name = os.path.join(self.mkdtemp(), 'archive')
        group = grp.getgrgid(0)[0]
        owner = pwd.getpwuid(0)[0]
        with support.change_cwd(root_dir):
            archive_name = make_archive(base_name, 'gztar', root_dir, 'dist',
                                        owner=owner, group=group)

        # check if the compressed tarball was created
        self.assertTrue(os.path.isfile(archive_name))

        # now checks the rights
        with tarfile.open(archive_name) as archive:
            for member in archive.getmembers():
                self.assertEqual(member.uid, 0)
                self.assertEqual(member.gid, 0)

    def test_make_archive_cwd(self):
        # make_archive() must restore the working directory even when the
        # archiver it dispatches to raises.
        current_dir = os.getcwd()
        def _breaks(*args, **kw):
            raise RuntimeError()

        register_archive_format('xxx', _breaks, [], 'xxx file')
        try:
            # Expect exactly the archiver's error, so the test fails if the
            # registered function was never reached.
            self.assertRaises(RuntimeError, make_archive, 'xxx', 'xxx',
                              root_dir=self.mkdtemp())
            self.assertEqual(os.getcwd(), current_dir)
        finally:
            unregister_archive_format('xxx')

    def test_make_tarfile_in_curdir(self):
        # Issue #21280
        root_dir = self.mkdtemp()
        with support.change_cwd(root_dir):
            self.assertEqual(make_archive('test', 'tar'), 'test.tar')
            self.assertTrue(os.path.isfile('test.tar'))

    @unittest.skipUnless(zlib, "Requires zlib")
    def test_make_zipfile_in_curdir(self):
        # Issue #21280
        root_dir = self.mkdtemp()
        with support.change_cwd(root_dir):
            self.assertEqual(make_archive('test', 'zip'), 'test.zip')
            self.assertTrue(os.path.isfile('test.zip'))

    def test_register_archive_format(self):
        # The registered callables are never invoked, only validated.
        self.assertRaises(TypeError, register_archive_format, 'xxx', 1)
        self.assertRaises(TypeError, register_archive_format, 'xxx', lambda: x,
                          1)
        self.assertRaises(TypeError, register_archive_format, 'xxx', lambda: x,
                          [(1, 2), (1, 2, 3)])

        register_archive_format('xxx', lambda: x, [(1, 2)], 'xxx file')
        formats = [name for name, params in get_archive_formats()]
        self.assertIn('xxx', formats)

        unregister_archive_format('xxx')
        formats = [name for name, params in get_archive_formats()]
        self.assertNotIn('xxx', formats)
659
660class TestMove(unittest.TestCase):
661
662    def setUp(self):
663        filename = "foo"
664        self.src_dir = tempfile.mkdtemp()
665        self.dst_dir = tempfile.mkdtemp()
666        self.src_file = os.path.join(self.src_dir, filename)
667        self.dst_file = os.path.join(self.dst_dir, filename)
668        # Try to create a dir in the current directory, hoping that it is
669        # not located on the same filesystem as the system tmp dir.
670        try:
671            self.dir_other_fs = tempfile.mkdtemp(
672                dir=os.path.dirname(__file__))
673            self.file_other_fs = os.path.join(self.dir_other_fs,
674                filename)
675        except OSError:
676            self.dir_other_fs = None
677        with open(self.src_file, "wb") as f:
678            f.write("spam")
679
680    def tearDown(self):
681        for d in (self.src_dir, self.dst_dir, self.dir_other_fs):
682            try:
683                if d:
684                    shutil.rmtree(d)
685            except:
686                pass
687
688    def _check_move_file(self, src, dst, real_dst):
689        with open(src, "rb") as f:
690            contents = f.read()
691        shutil.move(src, dst)
692        with open(real_dst, "rb") as f:
693            self.assertEqual(contents, f.read())
694        self.assertFalse(os.path.exists(src))
695
696    def _check_move_dir(self, src, dst, real_dst):
697        contents = sorted(os.listdir(src))
698        shutil.move(src, dst)
699        self.assertEqual(contents, sorted(os.listdir(real_dst)))
700        self.assertFalse(os.path.exists(src))
701
702    def test_move_file(self):
703        # Move a file to another location on the same filesystem.
704        self._check_move_file(self.src_file, self.dst_file, self.dst_file)
705
706    def test_move_file_to_dir(self):
707        # Move a file inside an existing dir on the same filesystem.
708        self._check_move_file(self.src_file, self.dst_dir, self.dst_file)
709
710    def test_move_file_other_fs(self):
711        # Move a file to an existing dir on another filesystem.
712        if not self.dir_other_fs:
713            self.skipTest('dir on other filesystem not available')
714        self._check_move_file(self.src_file, self.file_other_fs,
715            self.file_other_fs)
716
717    def test_move_file_to_dir_other_fs(self):
718        # Move a file to another location on another filesystem.
719        if not self.dir_other_fs:
720            self.skipTest('dir on other filesystem not available')
721        self._check_move_file(self.src_file, self.dir_other_fs,
722            self.file_other_fs)
723
724    def test_move_dir(self):
725        # Move a dir to another location on the same filesystem.
726        dst_dir = tempfile.mktemp()
727        try:
728            self._check_move_dir(self.src_dir, dst_dir, dst_dir)
729        finally:
730            try:
731                shutil.rmtree(dst_dir)
732            except:
733                pass
734
735    def test_move_dir_other_fs(self):
736        # Move a dir to another location on another filesystem.
737        if not self.dir_other_fs:
738            self.skipTest('dir on other filesystem not available')
739        dst_dir = tempfile.mktemp(dir=self.dir_other_fs)
740        try:
741            self._check_move_dir(self.src_dir, dst_dir, dst_dir)
742        finally:
743            try:
744                shutil.rmtree(dst_dir)
745            except:
746                pass
747
748    def test_move_dir_to_dir(self):
749        # Move a dir inside an existing dir on the same filesystem.
750        self._check_move_dir(self.src_dir, self.dst_dir,
751            os.path.join(self.dst_dir, os.path.basename(self.src_dir)))
752
753    def test_move_dir_to_dir_other_fs(self):
754        # Move a dir inside an existing dir on another filesystem.
755        if not self.dir_other_fs:
756            self.skipTest('dir on other filesystem not available')
757        self._check_move_dir(self.src_dir, self.dir_other_fs,
758            os.path.join(self.dir_other_fs, os.path.basename(self.src_dir)))
759
760    def test_move_dir_sep_to_dir(self):
761        self._check_move_dir(self.src_dir + os.path.sep, self.dst_dir,
762            os.path.join(self.dst_dir, os.path.basename(self.src_dir)))
763
764    @unittest.skipUnless(os.path.altsep, 'requires os.path.altsep')
765    def test_move_dir_altsep_to_dir(self):
766        self._check_move_dir(self.src_dir + os.path.altsep, self.dst_dir,
767            os.path.join(self.dst_dir, os.path.basename(self.src_dir)))
768
769    def test_existing_file_inside_dest_dir(self):
770        # A file with the same name inside the destination dir already exists.
771        with open(self.dst_file, "wb"):
772            pass
773        self.assertRaises(shutil.Error, shutil.move, self.src_file, self.dst_dir)
774
775    def test_dont_move_dir_in_itself(self):
776        # Moving a dir inside itself raises an Error.
777        dst = os.path.join(self.src_dir, "bar")
778        self.assertRaises(shutil.Error, shutil.move, self.src_dir, dst)
779
780    def test_destinsrc_false_negative(self):
781        os.mkdir(TESTFN)
782        try:
783            for src, dst in [('srcdir', 'srcdir/dest')]:
784                src = os.path.join(TESTFN, src)
785                dst = os.path.join(TESTFN, dst)
786                self.assertTrue(shutil._destinsrc(src, dst),
787                             msg='_destinsrc() wrongly concluded that '
788                             'dst (%s) is not in src (%s)' % (dst, src))
789        finally:
790            shutil.rmtree(TESTFN, ignore_errors=True)
791
792    def test_destinsrc_false_positive(self):
793        os.mkdir(TESTFN)
794        try:
795            for src, dst in [('srcdir', 'src/dest'), ('srcdir', 'srcdir.new')]:
796                src = os.path.join(TESTFN, src)
797                dst = os.path.join(TESTFN, dst)
798                self.assertFalse(shutil._destinsrc(src, dst),
799                            msg='_destinsrc() wrongly concluded that '
800                            'dst (%s) is in src (%s)' % (dst, src))
801        finally:
802            shutil.rmtree(TESTFN, ignore_errors=True)
803
804
class TestCopyFile(unittest.TestCase):
    """Tests for shutil.copyfile's handling of open/close failures.

    The copyfile tests replace shutil's module-level ``open`` with a local
    function that returns Faux objects or raises IOError, so no real files
    are touched; tearDown removes the replacement again.
    """

    # Set to True on the instance by _set_shutil_open(); tells tearDown
    # that shutil.open was patched and must be deleted.
    _delete = False

    class Faux(object):
        """Fake file object usable as a context manager.

        Records whether it was entered, what __exit__ received and whether
        __exit__ raised.

        raise_in_exit -- if true, __exit__ raises IOError("Cannot close").
        suppress_at_exit -- value __exit__ returns when it does not raise;
            a true value suppresses an exception leaving the with-block.
        """
        _entered = False
        _exited_with = None
        _raised = False

        def __init__(self, raise_in_exit=False, suppress_at_exit=True):
            self._raise_in_exit = raise_in_exit
            self._suppress_at_exit = suppress_at_exit

        def read(self, *args):
            # Always report EOF, so a read/write copy loop ends at once.
            return ''

        def write(self, data):
            # Accept and discard data so the fake can also stand in for a
            # writable destination file (code may bind fdst.write before
            # the first read).
            pass

        def __enter__(self):
            self._entered = True
            # Return self, as real file objects do, so that
            # "with open(...) as f" binds this fake rather than None.
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            self._exited_with = exc_type, exc_val, exc_tb
            if self._raise_in_exit:
                self._raised = True
                raise IOError("Cannot close")
            return self._suppress_at_exit

    def tearDown(self):
        # Drop the module attribute installed by _set_shutil_open so name
        # lookup inside shutil falls back to the builtin open().
        if self._delete:
            del shutil.open

    def _set_shutil_open(self, func):
        """Make shutil call *func* wherever it would call open()."""
        shutil.open = func
        self._delete = True

    def test_w_source_open_fails(self):
        # Opening the source raises: copyfile must propagate the IOError.
        def _open(filename, mode='r'):
            if filename == 'srcfile':
                raise IOError('Cannot open "srcfile"')
            # Explicit raise rather than "assert 0", which -O strips.
            raise AssertionError('unexpected open(%r)' % (filename,))

        self._set_shutil_open(_open)

        self.assertRaises(IOError, shutil.copyfile, 'srcfile', 'destfile')

    def test_w_dest_open_fails(self):
        # Opening the destination raises: the already-open source must be
        # exited with that IOError.
        srcfile = self.Faux()

        def _open(filename, mode='r'):
            if filename == 'srcfile':
                return srcfile
            if filename == 'destfile':
                raise IOError('Cannot open "destfile"')
            raise AssertionError('unexpected open(%r)' % (filename,))

        self._set_shutil_open(_open)

        # srcfile.__exit__ suppresses the IOError, so nothing escapes here.
        shutil.copyfile('srcfile', 'destfile')
        self.assertTrue(srcfile._entered)
        self.assertIs(srcfile._exited_with[0], IOError)
        self.assertEqual(srcfile._exited_with[1].args,
                         ('Cannot open "destfile"',))

    def test_w_dest_close_fails(self):
        # Closing the destination raises: the source must see that error.
        srcfile = self.Faux()
        destfile = self.Faux(True)

        def _open(filename, mode='r'):
            if filename == 'srcfile':
                return srcfile
            if filename == 'destfile':
                return destfile
            raise AssertionError('unexpected open(%r)' % (filename,))

        self._set_shutil_open(_open)

        shutil.copyfile('srcfile', 'destfile')
        self.assertTrue(srcfile._entered)
        self.assertTrue(destfile._entered)
        self.assertTrue(destfile._raised)
        # The copy itself completed cleanly; only the close failed.
        self.assertEqual(destfile._exited_with, (None, None, None))
        self.assertIs(srcfile._exited_with[0], IOError)
        self.assertEqual(srcfile._exited_with[1].args,
                         ('Cannot close',))

    def test_w_source_close_fails(self):
        # Closing the source raises after the destination closed cleanly:
        # copyfile must propagate that IOError.
        srcfile = self.Faux(True)
        destfile = self.Faux()

        def _open(filename, mode='r'):
            if filename == 'srcfile':
                return srcfile
            if filename == 'destfile':
                return destfile
            raise AssertionError('unexpected open(%r)' % (filename,))

        self._set_shutil_open(_open)

        self.assertRaises(IOError,
                          shutil.copyfile, 'srcfile', 'destfile')
        self.assertTrue(srcfile._entered)
        self.assertTrue(destfile._entered)
        self.assertFalse(destfile._raised)
        self.assertEqual(destfile._exited_with, (None, None, None))
        self.assertIsNone(srcfile._exited_with[0])
        self.assertTrue(srcfile._raised)

    def test_move_dir_caseinsensitive(self):
        # Rename a folder to the same name but a different case.
        # NOTE(review): this exercises shutil.move, not copyfile; it would
        # sit more naturally with the other move tests.
        src_dir = tempfile.mkdtemp()
        dst_dir = os.path.join(os.path.dirname(src_dir),
                               os.path.basename(src_dir).upper())
        self.assertNotEqual(src_dir, dst_dir)

        try:
            shutil.move(src_dir, dst_dir)
            self.assertTrue(os.path.isdir(dst_dir))
        finally:
            # Remove whichever name still exists, so a failed move does not
            # leak the source directory.  On a case-insensitive filesystem
            # both names refer to one directory, gone after the first rmdir.
            for path in (dst_dir, src_dir):
                if os.path.exists(path):
                    os.rmdir(path)
924
925
926
def test_main():
    """Run every test case class in this module via test_support."""
    test_classes = (TestShutil, TestMove, TestCopyFile)
    support.run_unittest(*test_classes)

if __name__ == '__main__':
    test_main()
932