1import copy_reg 2import unittest 3 4from test import test_support 5from test.pickletester import ExtensionSaver 6 7class C: 8 pass 9 10 11class WithoutSlots(object): 12 pass 13 14class WithWeakref(object): 15 __slots__ = ('__weakref__',) 16 17class WithPrivate(object): 18 __slots__ = ('__spam',) 19 20class WithSingleString(object): 21 __slots__ = 'spam' 22 23class WithInherited(WithSingleString): 24 __slots__ = ('eggs',) 25 26 27class CopyRegTestCase(unittest.TestCase): 28 29 def test_class(self): 30 self.assertRaises(TypeError, copy_reg.pickle, 31 C, None, None) 32 33 def test_noncallable_reduce(self): 34 self.assertRaises(TypeError, copy_reg.pickle, 35 type(1), "not a callable") 36 37 def test_noncallable_constructor(self): 38 self.assertRaises(TypeError, copy_reg.pickle, 39 type(1), int, "not a callable") 40 41 def test_bool(self): 42 import copy 43 self.assertEqual(True, copy.copy(True)) 44 45 def test_extension_registry(self): 46 mod, func, code = 'junk1 ', ' junk2', 0xabcd 47 e = ExtensionSaver(code) 48 try: 49 # Shouldn't be in registry now. 50 self.assertRaises(ValueError, copy_reg.remove_extension, 51 mod, func, code) 52 copy_reg.add_extension(mod, func, code) 53 # Should be in the registry. 54 self.assertTrue(copy_reg._extension_registry[mod, func] == code) 55 self.assertTrue(copy_reg._inverted_registry[code] == (mod, func)) 56 # Shouldn't be in the cache. 57 self.assertNotIn(code, copy_reg._extension_cache) 58 # Redundant registration should be OK. 59 copy_reg.add_extension(mod, func, code) # shouldn't blow up 60 # Conflicting code. 61 self.assertRaises(ValueError, copy_reg.add_extension, 62 mod, func, code + 1) 63 self.assertRaises(ValueError, copy_reg.remove_extension, 64 mod, func, code + 1) 65 # Conflicting module name. 66 self.assertRaises(ValueError, copy_reg.add_extension, 67 mod[1:], func, code ) 68 self.assertRaises(ValueError, copy_reg.remove_extension, 69 mod[1:], func, code ) 70 # Conflicting function name. 71 self.assertRaises(ValueError, copy_reg.add_extension, 72 mod, func[1:], code) 73 self.assertRaises(ValueError, copy_reg.remove_extension, 74 mod, func[1:], code) 75 # Can't remove one that isn't registered at all. 76 if code + 1 not in copy_reg._inverted_registry: 77 self.assertRaises(ValueError, copy_reg.remove_extension, 78 mod[1:], func[1:], code + 1) 79 80 finally: 81 e.restore() 82 83 # Shouldn't be there anymore. 84 self.assertNotIn((mod, func), copy_reg._extension_registry) 85 # The code *may* be in copy_reg._extension_registry, though, if 86 # we happened to pick on a registered code. So don't check for 87 # that. 88 89 # Check valid codes at the limits. 90 for code in 1, 0x7fffffff: 91 e = ExtensionSaver(code) 92 try: 93 copy_reg.add_extension(mod, func, code) 94 copy_reg.remove_extension(mod, func, code) 95 finally: 96 e.restore() 97 98 # Ensure invalid codes blow up. 99 for code in -1, 0, 0x80000000L: 100 self.assertRaises(ValueError, copy_reg.add_extension, 101 mod, func, code) 102 103 def test_slotnames(self): 104 self.assertEqual(copy_reg._slotnames(WithoutSlots), []) 105 self.assertEqual(copy_reg._slotnames(WithWeakref), []) 106 expected = ['_WithPrivate__spam'] 107 self.assertEqual(copy_reg._slotnames(WithPrivate), expected) 108 self.assertEqual(copy_reg._slotnames(WithSingleString), ['spam']) 109 expected = ['eggs', 'spam'] 110 expected.sort() 111 result = copy_reg._slotnames(WithInherited) 112 result.sort() 113 self.assertEqual(result, expected) 114 115 116def test_main(): 117 test_support.run_unittest(CopyRegTestCase) 118 119 120if __name__ == "__main__": 121 test_main() 122