• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import imghdr
2import io
3import os
4import pathlib
5import unittest
6import warnings
7from test.support import findfile
8from test.support.os_helper import TESTFN, unlink
9
10
# Sample images from the 'imghdrdata' test data directory, each paired with
# the type name imghdr.what() is expected to report for it. Every sample is
# named 'python.<ext>', so only the extension and the expected type vary.
TEST_FILES = tuple(
    ('python.' + ext, kind)
    for ext, kind in (
        ('png', 'png'),
        ('gif', 'gif'),
        ('bmp', 'bmp'),
        ('ppm', 'ppm'),
        ('pgm', 'pgm'),
        ('pbm', 'pbm'),
        ('jpg', 'jpeg'),
        ('ras', 'rast'),
        ('sgi', 'rgb'),
        ('tiff', 'tiff'),
        ('xbm', 'xbm'),
        ('webp', 'webp'),
        ('exr', 'exr'),
    )
)
26
class UnseekableIO(io.FileIO):
    """A real file object whose position can be neither queried nor changed.

    Both tell() and seek() raise io.UnsupportedOperation, which simulates a
    non-seekable stream while every other FileIO operation is inherited
    unchanged.
    """

    def seek(self, *args, **kwargs):
        # Refuse any attempt to reposition, whatever the arguments.
        raise io.UnsupportedOperation()

    def tell(self):
        # Refuse to report the current offset as well.
        raise io.UnsupportedOperation()
33
class TestImghdr(unittest.TestCase):
    """Tests for imghdr.what() with paths, file objects and raw header bytes."""

    @classmethod
    def setUpClass(cls):
        # A single known PNG sample, kept both as a path (cls.testfile) and
        # as its raw bytes (cls.testdata) for tests that need one image.
        cls.testfile = findfile('python.png', subdir='imghdrdata')
        with open(cls.testfile, 'rb') as stream:
            cls.testdata = stream.read()

    def tearDown(self):
        # Several tests write a scratch file at TESTFN; remove it every time.
        unlink(TESTFN)

    def test_data(self):
        """Every sample is detected from a path, an open stream and bytes."""
        for filename, expected in TEST_FILES:
            # subTest reports each failing sample separately instead of
            # stopping at the first failure.
            with self.subTest(filename=filename):
                path = findfile(filename, subdir='imghdrdata')
                self.assertEqual(imghdr.what(path), expected)
                with open(path, 'rb') as stream:
                    self.assertEqual(imghdr.what(stream), expected)
                with open(path, 'rb') as stream:
                    data = stream.read()
                # Header passed directly; None stands in for the file argument.
                self.assertEqual(imghdr.what(None, data), expected)
                self.assertEqual(imghdr.what(None, bytearray(data)), expected)

    def test_pathlike_filename(self):
        """A pathlib.Path is accepted as the file argument."""
        for filename, expected in TEST_FILES:
            with self.subTest(filename=filename):
                path = findfile(filename, subdir='imghdrdata')
                self.assertEqual(imghdr.what(pathlib.Path(path)), expected)

    def test_register_test(self):
        """A detector appended to imghdr.tests is consulted by what()."""
        def test_jumbo(h, file):
            # Claim the made-up 'eggs' signature as type 'ham'.
            if h.startswith(b'eggs'):
                return 'ham'
            return None

        imghdr.tests.append(test_jumbo)
        # Remove the detector afterwards so other tests see the original list.
        self.addCleanup(imghdr.tests.pop)
        self.assertEqual(imghdr.what(None, b'eggs'), 'ham')

    def test_file_pos(self):
        """what() inspects a stream from its current offset and restores it."""
        with open(TESTFN, 'wb') as stream:
            # Non-image prefix, followed by the PNG starting at offset `pos`.
            stream.write(b'ababagalamaga')
            pos = stream.tell()
            stream.write(self.testdata)
        with open(TESTFN, 'rb') as stream:
            stream.seek(pos)
            self.assertEqual(imghdr.what(stream), 'png')
            self.assertEqual(stream.tell(), pos)

    def test_bad_args(self):
        """Invalid argument combinations raise TypeError or AttributeError."""
        # No arguments at all.
        with self.assertRaises(TypeError):
            imghdr.what()
        # None is neither a path nor a file object.
        with self.assertRaises(AttributeError):
            imghdr.what(None)
        # A header that is not bytes-like.
        with self.assertRaises(TypeError):
            imghdr.what(self.testfile, 1)
        # A bytes path is not accepted as a filename.
        with self.assertRaises(AttributeError):
            imghdr.what(os.fsencode(self.testfile))
        # Nor is a raw file descriptor.
        with open(self.testfile, 'rb') as f:
            with self.assertRaises(AttributeError):
                imghdr.what(f.fileno())

    def test_invalid_headers(self):
        """None of these short or malformed headers is recognized."""
        for header in (b'\211PN\r\n',
                       b'\001\331',
                       b'\x59\xA6',
                       b'cutecat',
                       b'000000JFI',
                       b'GIF80'):
            with self.subTest(header=header):
                self.assertIsNone(imghdr.what(None, header))

    def test_string_data(self):
        """str data, given as a stream or as the header, raises TypeError."""
        with warnings.catch_warnings():
            # Mixing str and bytes may emit BytesWarning under ``python -b``;
            # only the TypeError matters here.
            warnings.simplefilter("ignore", BytesWarning)
            for filename, _ in TEST_FILES:
                with self.subTest(filename=filename):
                    path = findfile(filename, subdir='imghdrdata')
                    with open(path, 'rb') as stream:
                        data = stream.read().decode('latin1')
                    with self.assertRaises(TypeError):
                        imghdr.what(io.StringIO(data))
                    with self.assertRaises(TypeError):
                        imghdr.what(None, data)

    def test_missing_file(self):
        """A nonexistent path raises FileNotFoundError."""
        with self.assertRaises(FileNotFoundError):
            imghdr.what('missing')

    def test_closed_file(self):
        """Closed file objects, on disk or in memory, raise ValueError."""
        stream = open(self.testfile, 'rb')
        stream.close()
        with self.assertRaises(ValueError):
            imghdr.what(stream)
        stream = io.BytesIO(self.testdata)
        stream.close()
        with self.assertRaises(ValueError):
            imghdr.what(stream)

    def test_unseekable(self):
        """A stream without tell()/seek() support propagates the error."""
        with open(TESTFN, 'wb') as stream:
            stream.write(self.testdata)
        with UnseekableIO(TESTFN, 'rb') as stream:
            with self.assertRaises(io.UnsupportedOperation):
                imghdr.what(stream)

    def test_output_stream(self):
        """A write-only stream cannot be inspected and raises OSError."""
        with open(TESTFN, 'wb') as stream:
            stream.write(self.testdata)
            stream.seek(0)
            with self.assertRaises(OSError):
                imghdr.what(stream)
140
if __name__ == "__main__":
    # Run this module's test cases when executed directly as a script.
    unittest.main()
143