from _compat_pickle import (IMPORT_MAPPING, REVERSE_IMPORT_MAPPING,
                            NAME_MAPPING, REVERSE_NAME_MAPPING)
import builtins
import pickle
import io
import collections
import struct
import sys
import warnings
import weakref

import doctest
import unittest
from test import support
from test.support import import_helper

from test.pickletester import AbstractHookTests
from test.pickletester import AbstractUnpickleTests
from test.pickletester import AbstractPickleTests
from test.pickletester import AbstractPickleModuleTests
from test.pickletester import AbstractPersistentPicklerTests
from test.pickletester import AbstractIdentityPersistentPicklerTests
from test.pickletester import AbstractPicklerUnpicklerObjectTests
from test.pickletester import AbstractDispatchTableTests
from test.pickletester import AbstractCustomPicklerClass
from test.pickletester import BigmemPickleTests

# The C-specific test classes further down are only defined when the
# _pickle module can be imported.
try:
    import _pickle
    has_c_implementation = True
except ImportError:
    has_c_implementation = False


class PyPickleTests(AbstractPickleModuleTests, unittest.TestCase):
    """Module-level API tests bound to the underscore-prefixed pickle names."""

    dump = staticmethod(pickle._dump)
    dumps = staticmethod(pickle._dumps)
    load = staticmethod(pickle._load)
    loads = staticmethod(pickle._loads)
    Pickler = pickle._Pickler
    Unpickler = pickle._Unpickler


class PyUnpicklerTests(AbstractUnpickleTests, unittest.TestCase):
    """Run the inherited unpickling tests against pickle._Unpickler."""

    unpickler = pickle._Unpickler
    # Exception types for this unpickler; presumably consumed by the
    # inherited AbstractUnpickleTests checks -- verify in pickletester.
    bad_stack_errors = (IndexError,)
    truncated_errors = (pickle.UnpicklingError, EOFError,
                        AttributeError, ValueError,
                        struct.error, IndexError, ImportError)

    def loads(self, buf, **kwds):
        """Unpickle *buf* with self.unpickler, forwarding *kwds* to it."""
        f = io.BytesIO(buf)
        u = self.unpickler(f, **kwds)
        return u.load()


class PyPicklerTests(AbstractPickleTests, unittest.TestCase):
    """Run the inherited pickling tests with pickle._Pickler/_Unpickler.

    Subclasses swap in other implementations by overriding the ``pickler``
    and ``unpickler`` class attributes.
    """

    pickler = pickle._Pickler
    unpickler = pickle._Unpickler

    def dumps(self, arg, proto=None, **kwargs):
        """Pickle *arg* with self.pickler and return the resulting bytes."""
        f = io.BytesIO()
        p = self.pickler(f, proto, **kwargs)
        p.dump(arg)
        # getvalue() returns the whole buffer regardless of the stream
        # position, matching PersistentPicklerUnpicklerMixin.dumps().
        return f.getvalue()

    def loads(self, buf, **kwds):
        """Unpickle *buf* with self.unpickler, forwarding *kwds* to it."""
        f = io.BytesIO(buf)
        u = self.unpickler(f, **kwds)
        return u.load()


class InMemoryPickleTests(AbstractPickleTests, AbstractUnpickleTests,
                          BigmemPickleTests, unittest.TestCase):
    """Run the inherited tests through pickle.dumps() and pickle.loads()."""

    bad_stack_errors = (pickle.UnpicklingError, IndexError)
    truncated_errors = (pickle.UnpicklingError, EOFError,
                        AttributeError, ValueError,
                        struct.error, IndexError, ImportError)

    def dumps(self, arg, protocol=None, **kwargs):
        """Return pickle.dumps(arg, protocol, **kwargs)."""
        return pickle.dumps(arg, protocol, **kwargs)

    def loads(self, buf, **kwds):
        """Return pickle.loads(buf, **kwds)."""
        return pickle.loads(buf, **kwds)

    # Shadow the test of this name (presumably inherited from
    # AbstractPickleTests) with None so that it is not run for this class.
    test_framed_write_sizes_with_delayed_writer = None


class PersistentPicklerUnpicklerMixin:
    """Provide dumps()/loads() that route persistent IDs through the test.

    The host class must define ``pickler`` and ``unpickler`` attributes as
    well as ``persistent_id()`` and ``persistent_load()`` methods.
    """

    def dumps(self, arg, proto=None):
        """Pickle *arg*, delegating persistent_id() to this test case."""
        class PersPickler(self.pickler):
            # ``subself`` is the pickler; ``self`` is the enclosing test.
            def persistent_id(subself, obj):
                return self.persistent_id(obj)
        f = io.BytesIO()
        p = PersPickler(f, proto)
        p.dump(arg)
        return f.getvalue()

    def loads(self, buf, **kwds):
        """Unpickle *buf*, delegating persistent_load() to this test case."""
        class PersUnpickler(self.unpickler):
            # ``subself`` is the unpickler; ``self`` is the enclosing test.
            def persistent_load(subself, obj):
                return self.persistent_load(obj)
        f = io.BytesIO(buf)
        u = PersUnpickler(f, **kwds)
        return u.load()


class PyPersPicklerTests(AbstractPersistentPicklerTests,
                         PersistentPicklerUnpicklerMixin, unittest.TestCase):
    """Persistent-ID tests using pickle._Pickler and pickle._Unpickler."""

    pickler = pickle._Pickler
    unpickler = pickle._Unpickler


class PyIdPersPicklerTests(AbstractIdentityPersistentPicklerTests,
                           PersistentPicklerUnpicklerMixin, unittest.TestCase):
    """Identity persistent-ID tests plus reference-cycle checks."""

    pickler = pickle._Pickler
    unpickler = pickle._Unpickler

    @support.cpython_only
    def test_pickler_reference_cycle(self):
        # A pickler whose subclass defines persistent_id() as an instance
        # method, classmethod or staticmethod must be released as soon as
        # the last reference to it is deleted.
        def check(Pickler):
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f = io.BytesIO()
                pickler = Pickler(f, proto)
                pickler.dump('abc')
                self.assertEqual(self.loads(f.getvalue()), 'abc')
            pickler = Pickler(io.BytesIO())
            self.assertEqual(pickler.persistent_id('def'), 'def')
            r = weakref.ref(pickler)
            del pickler
            # The weak reference is dead only if the pickler was freed.
            self.assertIsNone(r())

        class PersPickler(self.pickler):
            def persistent_id(subself, obj):
                return obj
        check(PersPickler)

        class PersPickler(self.pickler):
            @classmethod
            def persistent_id(cls, obj):
                return obj
        check(PersPickler)

        class PersPickler(self.pickler):
            @staticmethod
            def persistent_id(obj):
                return obj
        check(PersPickler)

    @support.cpython_only
    def test_unpickler_reference_cycle(self):
        # Same check as test_pickler_reference_cycle, for unpicklers whose
        # subclass defines persistent_load() in each of the three forms.
        def check(Unpickler):
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                unpickler = Unpickler(io.BytesIO(self.dumps('abc', proto)))
                self.assertEqual(unpickler.load(), 'abc')
            unpickler = Unpickler(io.BytesIO())
            self.assertEqual(unpickler.persistent_load('def'), 'def')
            r = weakref.ref(unpickler)
            del unpickler
            # The weak reference is dead only if the unpickler was freed.
            self.assertIsNone(r())

        class PersUnpickler(self.unpickler):
            def persistent_load(subself, pid):
                return pid
        check(PersUnpickler)

        class PersUnpickler(self.unpickler):
            @classmethod
            def persistent_load(cls, pid):
                return pid
        check(PersUnpickler)

        class PersUnpickler(self.unpickler):
            @staticmethod
            def persistent_load(pid):
                return pid
        check(PersUnpickler)


class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests, unittest.TestCase):
    """Pickler/Unpickler object tests for pickle._Pickler/_Unpickler."""

    pickler_class = pickle._Pickler
    unpickler_class = pickle._Unpickler


class PyDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase):
    """Dispatch-table tests using a plain copy of pickle.dispatch_table."""

    pickler_class = pickle._Pickler

    def get_dispatch_table(self):
        """Return a fresh copy of pickle.dispatch_table."""
        return pickle.dispatch_table.copy()


class PyChainDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase):
    """Dispatch-table tests using a ChainMap over pickle.dispatch_table."""

    pickler_class = pickle._Pickler

    def get_dispatch_table(self):
        """Return a ChainMap with an empty dict ahead of the shared table."""
        return collections.ChainMap({}, pickle.dispatch_table)


class PyPicklerHookTests(AbstractHookTests, unittest.TestCase):
    """Hook tests with a custom pickler derived from pickle._Pickler."""

    class CustomPyPicklerClass(pickle._Pickler,
                               AbstractCustomPicklerClass):
        pass
    pickler_class = CustomPyPicklerClass


if has_c_implementation:
    class CPickleTests(AbstractPickleModuleTests, unittest.TestCase):
        """Module-level API tests bound to the names exported by _pickle."""

        from _pickle import dump, dumps, load, loads, Pickler, Unpickler

    class CUnpicklerTests(PyUnpicklerTests):
        """PyUnpicklerTests re-run against _pickle.Unpickler."""

        unpickler = _pickle.Unpickler
        bad_stack_errors = (pickle.UnpicklingError,)
        truncated_errors = (pickle.UnpicklingError,)

    class CPicklerTests(PyPicklerTests):
        """PyPicklerTests re-run with _pickle.Pickler and _pickle.Unpickler."""

        pickler = _pickle.Pickler
        unpickler = _pickle.Unpickler

    class CPersPicklerTests(PyPersPicklerTests):
        """PyPersPicklerTests re-run with the _pickle classes."""

        pickler = _pickle.Pickler
        unpickler = _pickle.Unpickler

    class CIdPersPicklerTests(PyIdPersPicklerTests):
        """PyIdPersPicklerTests re-run with the _pickle classes."""

        pickler = _pickle.Pickler
        unpickler = _pickle.Unpickler

    class CDumpPickle_LoadPickle(PyPicklerTests):
        """Pickle with _pickle.Pickler, unpickle with pickle._Unpickler."""

        pickler = _pickle.Pickler
        unpickler = pickle._Unpickler

    class DumpPickle_CLoadPickle(PyPicklerTests):
        """Pickle with pickle._Pickler, unpickle with _pickle.Unpickler."""

        pickler = pickle._Pickler
        unpickler = _pickle.Unpickler

    class CPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests, unittest.TestCase):
        """Pickler/Unpickler object tests for the _pickle classes."""

        pickler_class = _pickle.Pickler
        unpickler_class = _pickle.Unpickler

        def test_issue18339(self):
            # Assigning an invalid object to Unpickler.memo must raise
            # instead of crashing; a valid mapping must be accepted.
            unpickler = self.unpickler_class(io.BytesIO())
            with self.assertRaises(TypeError):
                unpickler.memo = object
            # used to cause a segfault
            with self.assertRaises(ValueError):
                unpickler.memo = {-1: None}
            unpickler.memo = {1: None}

    class CDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase):
        """Dispatch-table tests using pickle.Pickler and a copied table."""

        # NOTE(review): uses pickle.Pickler rather than _pickle.Pickler;
        # presumably the same class when _pickle is importable -- confirm.
        pickler_class = pickle.Pickler

        def get_dispatch_table(self):
            """Return a fresh copy of pickle.dispatch_table."""
            return pickle.dispatch_table.copy()

    class CChainDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase):
        """Dispatch-table tests using pickle.Pickler and a ChainMap table."""

        pickler_class = pickle.Pickler

        def get_dispatch_table(self):
            """Return a ChainMap with an empty dict ahead of the shared table."""
            return collections.ChainMap({}, pickle.dispatch_table)

    class CPicklerHookTests(AbstractHookTests, unittest.TestCase):
        """Hook tests with a custom pickler derived from _pickle.Pickler."""

        class CustomCPicklerClass(_pickle.Pickler, AbstractCustomPicklerClass):
            pass
        pickler_class = CustomCPicklerClass

    @support.cpython_only
    class SizeofTests(unittest.TestCase):
        """Check the sizes reported for _pickle.Pickler/Unpickler objects."""

        check_sizeof = support.check_sizeof

        def test_pickler(self):
            # Expected sizes are built from struct format strings describing
            # the object layout plus the memo table and write buffer.
            basesize = support.calcobjsize('7P2n3i2n3i2P')
            p = _pickle.Pickler(io.BytesIO())
            self.assertEqual(object.__sizeof__(p), basesize)
            MT_size = struct.calcsize('3nP0n')
            ME_size = struct.calcsize('Pn0P')
            check = self.check_sizeof
            check(p, basesize +
                MT_size + 8 * ME_size +  # Minimal memo table size.
                sys.getsizeof(b'x'*4096))  # Minimal write buffer size.
            for i in range(6):
                p.dump(chr(i))
            check(p, basesize +
                MT_size + 32 * ME_size +  # Size of memo table required to
                                          # save references to 6 objects.
                0)  # Write buffer is cleared after every dump().

        def test_unpickler(self):
            # Expected sizes are the base object size plus the encoding and
            # errors strings, the memo table and the mark stack.
            basesize = support.calcobjsize('2P2n2P 2P2n2i5P 2P3n8P2n2i')
            unpickler = _pickle.Unpickler
            P = struct.calcsize('P')  # Size of memo table entry.
            n = struct.calcsize('n')  # Size of mark table entry.
            check = self.check_sizeof
            for encoding in 'ASCII', 'UTF-16', 'latin-1':
                for errors in 'strict', 'replace':
                    u = unpickler(io.BytesIO(),
                                  encoding=encoding, errors=errors)
                    self.assertEqual(object.__sizeof__(u), basesize)
                    check(u, basesize +
                             32 * P +  # Minimal memo table size.
                             len(encoding) + 1 + len(errors) + 1)

            stdsize = basesize + len('ASCII') + 1 + len('strict') + 1
            def check_unpickler(data, memo_size, marks_size):
                # Load *data* and compare the unpickler's size with the
                # expected number of memo and mark table entries.
                dump = pickle.dumps(data)
                u = unpickler(io.BytesIO(dump),
                              encoding='ASCII', errors='strict')
                u.load()
                check(u, stdsize + memo_size * P + marks_size * n)

            check_unpickler(0, 32, 0)
            # 20 is minimal non-empty mark stack size.
            check_unpickler([0] * 100, 32, 20)
            # 128 is memo table size required to save references to 100 objects.
            check_unpickler([chr(i) for i in range(100)], 128, 20)
            def recurse(deep):
                # Build a list nested *deep* levels: [[..., ...], [..., ...]].
                data = 0
                for i in range(deep):
                    data = [data, data]
                return data
            check_unpickler(recurse(0), 32, 0)
            check_unpickler(recurse(1), 32, 20)
            check_unpickler(recurse(20), 32, 20)
            check_unpickler(recurse(50), 64, 60)
            check_unpickler(recurse(100), 128, 140)

            u = unpickler(io.BytesIO(pickle.dumps('a', 0)),
                          encoding='ASCII', errors='strict')
            u.load()
            check(u, stdsize + 32 * P + 2 + 1)


# (old module, new module) pairs that test_reverse_import_mapping accepts
# without a matching REVERSE_IMPORT_MAPPING entry.
ALT_IMPORT_MAPPING = {
    ('_elementtree', 'xml.etree.ElementTree'),
    ('cPickle', 'pickle'),
    ('StringIO', 'io'),
    ('cStringIO', 'io'),
}

# (old module, old name, new module, new name) entries for which
# test_reverse_name_mapping does not require the reverse mapping to return
# the old pair.
ALT_NAME_MAPPING = {
    ('__builtin__', 'basestring', 'builtins', 'str'),
    ('exceptions', 'StandardError', 'builtins', 'Exception'),
    ('UserDict', 'UserDict', 'collections', 'UserDict'),
    ('socket', '_socketobject', 'socket', 'SocketType'),
}

def mapping(module, name):
    """Translate (module, name) through NAME_MAPPING / IMPORT_MAPPING.

    An exact (module, name) entry in NAME_MAPPING takes precedence;
    otherwise only the module is translated through IMPORT_MAPPING.  Pairs
    that match neither table are returned unchanged.
    """
    if (module, name) in NAME_MAPPING:
        module, name = NAME_MAPPING[(module, name)]
    elif module in IMPORT_MAPPING:
        module = IMPORT_MAPPING[module]
    return module, name

def reverse_mapping(module, name):
    """Translate (module, name) through the REVERSE_* mapping tables.

    Mirrors mapping(): REVERSE_NAME_MAPPING is consulted first, then
    REVERSE_IMPORT_MAPPING for the module alone.
    """
    if (module, name) in REVERSE_NAME_MAPPING:
        module, name = REVERSE_NAME_MAPPING[(module, name)]
    elif module in REVERSE_IMPORT_MAPPING:
        module = REVERSE_IMPORT_MAPPING[module]
    return module, name

def getmodule(module):
    """Return the module object named *module*, importing it if needed.

    DeprecationWarning emitted during the import is displayed only when
    support.verbose is set.

    Raises:
        ImportError: if the import fails.  An AttributeError raised by the
            import is converted to ImportError, chained to the original
            exception, so callers only have to handle one exception type.
    """
    try:
        return sys.modules[module]
    except KeyError:
        try:
            with warnings.catch_warnings():
                action = 'always' if support.verbose else 'ignore'
                warnings.simplefilter(action, DeprecationWarning)
                __import__(module)
        except AttributeError as exc:
            message = "Can't import module %r: %s" % (module, exc)
            if support.verbose:
                print(message)
            # Keep the module name and the original error attached instead
            # of raising an empty ImportError.
            raise ImportError(message, name=module) from exc
        except ImportError as exc:
            if support.verbose:
                print(exc)
            raise
        return sys.modules[module]

def getattribute(module, name):
    """Resolve the dotted attribute path *name* inside module *module*."""
    obj = getmodule(module)
    for n in name.split('.'):
        obj = getattr(obj, n)
    return obj

def get_exceptions(mod):
    """Yield (name, class) for every BaseException subclass in dir(mod)."""
    for name in dir(mod):
        attr = getattr(mod, name)
        if isinstance(attr, type) and issubclass(attr, BaseException):
            yield name, attr

class CompatPickleTests(unittest.TestCase):
    """Consistency checks for the _compat_pickle mapping tables."""

    def test_import(self):
        # Try to import every module mentioned on the new-name side of the
        # tables; modules that cannot be imported are tolerated.
        modules = set(IMPORT_MAPPING.values())
        modules |= set(REVERSE_IMPORT_MAPPING)
        modules |= {module for module, name in REVERSE_NAME_MAPPING}
        modules |= {module for module, name in NAME_MAPPING.values()}
        for module in modules:
            try:
                getmodule(module)
            except ImportError:
                pass

    def test_import_mapping(self):
        # Every public entry of REVERSE_IMPORT_MAPPING must round-trip
        # through IMPORT_MAPPING.
        for module3, module2 in REVERSE_IMPORT_MAPPING.items():
            with self.subTest((module3, module2)):
                try:
                    getmodule(module3)
                except ImportError:
                    pass
                # Modules whose name starts with '_' are exempt.
                if module3[:1] != '_':
                    self.assertIn(module2, IMPORT_MAPPING)
                    self.assertEqual(IMPORT_MAPPING[module2], module3)

    def test_name_mapping(self):
        # Every entry of REVERSE_NAME_MAPPING must map back through
        # mapping() and resolve to the same object.
        for (module3, name3), (module2, name2) in REVERSE_NAME_MAPPING.items():
            with self.subTest(((module3, name3), (module2, name2))):
                # Several names reverse-map onto these two exceptions, so
                # only the subclass relationship is checked for them.
                if (module2, name2) == ('exceptions', 'OSError'):
                    attr = getattribute(module3, name3)
                    self.assertTrue(issubclass(attr, OSError))
                elif (module2, name2) == ('exceptions', 'ImportError'):
                    attr = getattribute(module3, name3)
                    self.assertTrue(issubclass(attr, ImportError))
                else:
                    module, name = mapping(module2, name2)
                    if module3[:1] != '_':
                        self.assertEqual((module, name), (module3, name3))
                    try:
                        attr = getattribute(module3, name3)
                    except ImportError:
                        pass
                    else:
                        self.assertEqual(getattribute(module, name), attr)

    def test_reverse_import_mapping(self):
        # Every entry of IMPORT_MAPPING needs some way back: a listed
        # alternative, a REVERSE_IMPORT_MAPPING entry, or a name-level
        # reverse mapping between the same two modules.
        for module2, module3 in IMPORT_MAPPING.items():
            with self.subTest((module2, module3)):
                try:
                    getmodule(module3)
                except ImportError as exc:
                    if support.verbose:
                        print(exc)
                if ((module2, module3) not in ALT_IMPORT_MAPPING and
                    REVERSE_IMPORT_MAPPING.get(module3, None) != module2):
                    if not any((module3, module2) == (m3, m2)
                               for (m3, _), (m2, _)
                               in REVERSE_NAME_MAPPING.items()):
                        self.fail('No reverse mapping from %r to %r' %
                                  (module3, module2))
                module = REVERSE_IMPORT_MAPPING.get(module3, module3)
                module = IMPORT_MAPPING.get(module, module)
                self.assertEqual(module, module3)

    def test_reverse_name_mapping(self):
        # Every entry of NAME_MAPPING must round-trip through
        # reverse_mapping() and mapping().
        for (module2, name2), (module3, name3) in NAME_MAPPING.items():
            with self.subTest(((module2, name2), (module3, name3))):
                # Only checks that the target can be looked up; a module
                # that cannot be imported is tolerated.
                try:
                    getattribute(module3, name3)
                except ImportError:
                    pass
                module, name = reverse_mapping(module3, name3)
                if (module2, name2, module3, name3) not in ALT_NAME_MAPPING:
                    self.assertEqual((module, name), (module2, name2))
                module, name = mapping(module, name)
                self.assertEqual((module, name), (module3, name3))

    def test_exceptions(self):
        # Spot-check a few well-known exceptions, then check every
        # exception class exposed by the builtins module.
        self.assertEqual(mapping('exceptions', 'StandardError'),
                         ('builtins', 'Exception'))
        self.assertEqual(mapping('exceptions', 'Exception'),
                         ('builtins', 'Exception'))
        self.assertEqual(reverse_mapping('builtins', 'Exception'),
                         ('exceptions', 'Exception'))
        self.assertEqual(mapping('exceptions', 'OSError'),
                         ('builtins', 'OSError'))
        self.assertEqual(reverse_mapping('builtins', 'OSError'),
                         ('exceptions', 'OSError'))

        for name, exc in get_exceptions(builtins):
            with self.subTest(name):
                # These exceptions are deliberately excluded from the check.
                if exc in (BlockingIOError,
                           ResourceWarning,
                           StopAsyncIteration,
                           RecursionError,
                           EncodingWarning):
                    continue
                if exc is not OSError and issubclass(exc, OSError):
                    self.assertEqual(reverse_mapping('builtins', name),
                                     ('exceptions', 'OSError'))
                elif exc is not ImportError and issubclass(exc, ImportError):
                    self.assertEqual(reverse_mapping('builtins', name),
                                     ('exceptions', 'ImportError'))
                    self.assertEqual(mapping('exceptions', name),
                                     ('exceptions', name))
                else:
                    self.assertEqual(reverse_mapping('builtins', name),
                                     ('exceptions', name))
                    self.assertEqual(mapping('exceptions', name),
                                     ('builtins', name))

    def test_multiprocessing_exceptions(self):
        # Exceptions found in multiprocessing.context must map to and from
        # the top-level multiprocessing module.
        module = import_helper.import_module('multiprocessing.context')
        for name, exc in get_exceptions(module):
            with self.subTest(name):
                self.assertEqual(reverse_mapping('multiprocessing.context', name),
                                 ('multiprocessing', name))
                self.assertEqual(mapping('multiprocessing', name),
                                 ('multiprocessing.context', name))


def load_tests(loader, tests, pattern):
    """unittest load_tests hook: add this module's doctests to *tests*."""
    tests.addTest(doctest.DocTestSuite())
    return tests


if __name__ == "__main__":
    unittest.main()