"""Tests for imghdr.what(), which identifies image types from file headers."""

import imghdr
import io
import os
import pathlib
import unittest
import warnings
from test.support import findfile
from test.support.os_helper import TESTFN, unlink


# (sample file name inside the 'imghdrdata' test-data directory,
#  value imghdr.what() is expected to return for it)
TEST_FILES = (
    ('python.png', 'png'),
    ('python.gif', 'gif'),
    ('python.bmp', 'bmp'),
    ('python.ppm', 'ppm'),
    ('python.pgm', 'pgm'),
    ('python.pbm', 'pbm'),
    ('python.jpg', 'jpeg'),
    ('python.ras', 'rast'),
    ('python.sgi', 'rgb'),
    ('python.tiff', 'tiff'),
    ('python.xbm', 'xbm'),
    ('python.webp', 'webp'),
    ('python.exr', 'exr'),
)


class UnseekableIO(io.FileIO):
    """A FileIO whose tell() and seek() always raise UnsupportedOperation."""

    def tell(self):
        raise io.UnsupportedOperation

    def seek(self, *args, **kwargs):
        raise io.UnsupportedOperation


class TestImghdr(unittest.TestCase):
    """Exercise imghdr.what() with sample files and with invalid inputs."""

    @classmethod
    def setUpClass(cls):
        # One known-good PNG, shared by the tests that need a valid image.
        cls.testfile = findfile('python.png', subdir='imghdrdata')
        with open(cls.testfile, 'rb') as stream:
            cls.testdata = stream.read()

    def tearDown(self):
        # Several tests write TESTFN; remove it after each test.
        unlink(TESTFN)

    def test_data(self):
        # Every sample must be recognised when given as a path, as an open
        # binary stream, as bytes and as a bytearray.  subTest reports which
        # sample failed instead of aborting on the first mismatch.
        for filename, expected in TEST_FILES:
            with self.subTest(filename=filename):
                filename = findfile(filename, subdir='imghdrdata')
                self.assertEqual(imghdr.what(filename), expected)
                with open(filename, 'rb') as stream:
                    self.assertEqual(imghdr.what(stream), expected)
                with open(filename, 'rb') as stream:
                    data = stream.read()
                self.assertEqual(imghdr.what(None, data), expected)
                self.assertEqual(imghdr.what(None, bytearray(data)),
                                 expected)

    def test_pathlike_filename(self):
        # os.PathLike objects are accepted as the file argument.
        for filename, expected in TEST_FILES:
            with self.subTest(filename=filename):
                filename = findfile(filename, subdir='imghdrdata')
                self.assertEqual(imghdr.what(pathlib.Path(filename)),
                                 expected)

    def test_register_test(self):
        # A callable appended to imghdr.tests is consulted by what().
        def test_jumbo(h, file):
            if h.startswith(b'eggs'):
                return 'ham'
        imghdr.tests.append(test_jumbo)
        self.addCleanup(imghdr.tests.pop)
        self.assertEqual(imghdr.what(None, b'eggs'), 'ham')

    def test_file_pos(self):
        # Detection starts at the stream's current position, and that
        # position is restored afterwards.
        with open(TESTFN, 'wb') as stream:
            stream.write(b'ababagalamaga')
            pos = stream.tell()
            stream.write(self.testdata)
        with open(TESTFN, 'rb') as stream:
            stream.seek(pos)
            self.assertEqual(imghdr.what(stream), 'png')
            self.assertEqual(stream.tell(), pos)

    def test_bad_args(self):
        with self.assertRaises(TypeError):
            imghdr.what()
        with self.assertRaises(AttributeError):
            imghdr.what(None)
        with self.assertRaises(TypeError):
            imghdr.what(self.testfile, 1)
        with self.assertRaises(AttributeError):
            imghdr.what(os.fsencode(self.testfile))
        with open(self.testfile, 'rb') as f:
            with self.assertRaises(AttributeError):
                imghdr.what(f.fileno())

    def test_invalid_headers(self):
        # Truncated or near-miss signatures must not be recognised.
        for header in (b'\211PN\r\n',
                       b'\001\331',
                       b'\x59\xA6',
                       b'cutecat',
                       b'000000JFI',
                       b'GIF80'):
            with self.subTest(header=header):
                self.assertIsNone(imghdr.what(None, header))

    def test_string_data(self):
        # Text (str) input is rejected with TypeError, both as a stream and
        # as in-memory data.  BytesWarning is silenced because the
        # bytes/str comparisons can trigger it under ``python -b``.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", BytesWarning)
            for filename, _ in TEST_FILES:
                with self.subTest(filename=filename):
                    filename = findfile(filename, subdir='imghdrdata')
                    with open(filename, 'rb') as stream:
                        data = stream.read().decode('latin1')
                    with self.assertRaises(TypeError):
                        imghdr.what(io.StringIO(data))
                    with self.assertRaises(TypeError):
                        imghdr.what(None, data)

    def test_missing_file(self):
        with self.assertRaises(FileNotFoundError):
            imghdr.what('missing')

    def test_closed_file(self):
        # Closed streams, both real files and in-memory ones, raise
        # ValueError.
        stream = open(self.testfile, 'rb')
        stream.close()
        with self.assertRaises(ValueError):
            imghdr.what(stream)
        stream = io.BytesIO(self.testdata)
        stream.close()
        with self.assertRaises(ValueError):
            imghdr.what(stream)

    def test_unseekable(self):
        # A stream that cannot report or restore its position propagates
        # UnsupportedOperation.
        with open(TESTFN, 'wb') as stream:
            stream.write(self.testdata)
        with UnseekableIO(TESTFN, 'rb') as stream:
            with self.assertRaises(io.UnsupportedOperation):
                imghdr.what(stream)

    def test_output_stream(self):
        # Reading a write-only stream must fail with an OSError.
        with open(TESTFN, 'wb') as stream:
            stream.write(self.testdata)
            stream.seek(0)
            with self.assertRaises(OSError):
                imghdr.what(stream)


if __name__ == '__main__':
    unittest.main()