"""Unit tests for collections.defaultdict."""

import copy
import pickle
import unittest

from collections import defaultdict


def foobar():
    """Return the ``list`` type; used as a plain-function factory in copy tests."""
    return list


class TestDefaultDict(unittest.TestCase):

    def test_basic(self):
        """Factory assignment, auto-insertion on lookup, and KeyError without a factory."""
        d1 = defaultdict()
        self.assertIsNone(d1.default_factory)
        d1.default_factory = list
        d1[12].append(42)
        self.assertEqual(d1, {12: [42]})
        d1[12].append(24)
        self.assertEqual(d1, {12: [42, 24]})
        # Bare lookups of missing keys insert fresh factory results.
        d1[13]
        d1[14]
        self.assertEqual(d1, {12: [42, 24], 13: [], 14: []})
        # Every key must get its own list.  A chained ``a is not b is not c``
        # never compares a with c, so check each pair explicitly.
        self.assertIsNot(d1[12], d1[13])
        self.assertIsNot(d1[13], d1[14])
        self.assertIsNot(d1[12], d1[14])
        d2 = defaultdict(list, foo=1, bar=2)
        self.assertIs(d2.default_factory, list)
        self.assertEqual(d2, {"foo": 1, "bar": 2})
        self.assertEqual(d2["foo"], 1)
        self.assertEqual(d2["bar"], 2)
        self.assertEqual(d2[42], [])
        self.assertIn("foo", d2)
        self.assertIn("foo", d2.keys())
        self.assertIn("bar", d2)
        self.assertIn("bar", d2.keys())
        # The d2[42] lookup above inserted key 42.
        self.assertIn(42, d2)
        self.assertIn(42, d2.keys())
        # ``in`` must not insert a missing key (otherwise the keys() check fails).
        self.assertNotIn(12, d2)
        self.assertNotIn(12, d2.keys())
        d2.default_factory = None
        self.assertIsNone(d2.default_factory)
        with self.assertRaises(KeyError) as cm:
            d2[15]
        self.assertEqual(cm.exception.args, (15,))
        # A non-callable factory is rejected at construction time.
        self.assertRaises(TypeError, defaultdict, 1)

    def test_missing(self):
        """__missing__ raises KeyError without a factory, else inserts and returns a value."""
        d1 = defaultdict()
        self.assertRaises(KeyError, d1.__missing__, 42)
        d1.default_factory = list
        self.assertEqual(d1.__missing__(42), [])
        # __missing__ also stores the produced value under the key.
        self.assertEqual(d1, {42: []})

    def test_repr(self):
        """repr shows the factory and contents; the empty repr round-trips via eval."""
        d1 = defaultdict()
        self.assertIsNone(d1.default_factory)
        self.assertEqual(repr(d1), "defaultdict(None, {})")
        self.assertEqual(eval(repr(d1)), d1)
        d1[11] = 41
        self.assertEqual(repr(d1), "defaultdict(None, {11: 41})")
        d2 = defaultdict(int)
        self.assertIs(d2.default_factory, int)
        d2[12] = 42
        self.assertEqual(repr(d2), "defaultdict(<class 'int'>, {12: 42})")

        def foo():
            return 43

        d3 = defaultdict(foo)
        self.assertIs(d3.default_factory, foo)
        d3[13]
        self.assertEqual(repr(d3), f"defaultdict({foo!r}, {{13: 43}})")

    def test_copy(self):
        """copy() returns an independent defaultdict with the same factory and contents."""
        d1 = defaultdict()
        d2 = d1.copy()
        self.assertIs(type(d2), defaultdict)
        self.assertIsNone(d2.default_factory)
        self.assertEqual(d2, {})
        d1.default_factory = list
        d3 = d1.copy()
        self.assertIs(type(d3), defaultdict)
        self.assertIs(d3.default_factory, list)
        self.assertEqual(d3, {})
        d1[42]
        d4 = d1.copy()
        self.assertIs(type(d4), defaultdict)
        self.assertIs(d4.default_factory, list)
        self.assertEqual(d4, {42: []})
        # The copy auto-inserts via its own factory without touching d1.
        d4[12]
        self.assertEqual(d4, {42: [], 12: []})
        self.assertNotIn(12, d1)

        # Issue 6637: copy() failed for a defaultdict whose default_factory
        # is None.
        d = defaultdict()
        d['a'] = 42
        e = d.copy()
        self.assertEqual(e['a'], 42)
        self.assertIsNone(e.default_factory)

    def test_shallow_copy(self):
        """copy.copy preserves the factory (including a plain function) and contents."""
        d1 = defaultdict(foobar, {1: 1})
        d2 = copy.copy(d1)
        self.assertIs(d2.default_factory, foobar)
        self.assertEqual(d2, d1)
        d1.default_factory = list
        d2 = copy.copy(d1)
        self.assertIs(d2.default_factory, list)
        self.assertEqual(d2, d1)

    def test_deep_copy(self):
        """copy.deepcopy preserves the factory and does not share mutable values."""
        d1 = defaultdict(foobar, {1: [1]})
        d2 = copy.deepcopy(d1)
        self.assertIs(d2.default_factory, foobar)
        self.assertEqual(d2, d1)
        self.assertIsNot(d1[1], d2[1])
        d1.default_factory = list
        d2 = copy.deepcopy(d1)
        self.assertIs(d2.default_factory, list)
        self.assertEqual(d2, d1)
        self.assertIsNot(d1[1], d2[1])

    def test_keyerror_without_factory(self):
        """A missing tuple key is reported intact as the KeyError's first argument."""
        d1 = defaultdict()
        with self.assertRaises(KeyError) as cm:
            d1[(1,)]
        # If the tuple were unpacked into the exception args, args[0] would be 1.
        self.assertEqual(cm.exception.args[0], (1,))

    def test_recursive_repr(self):
        """repr must not recurse forever when the factory is a bound method of self."""
        # Issue2045: stack overflow when default_factory is a bound method
        class sub(defaultdict):
            def __init__(self):
                self.default_factory = self._factory

            def _factory(self):
                return []

        d = sub()
        self.assertRegex(repr(d),
                         r"sub\(<bound method .*sub\._factory "
                         r"of sub\(\.\.\., \{\}\)>, \{\}\)")

    def test_callable_arg(self):
        """A non-callable first argument is rejected with TypeError."""
        self.assertRaises(TypeError, defaultdict, {})

    def test_pickling(self):
        """Pickle round-trips preserve type, contents, and default_factory."""
        d = defaultdict(int)
        d[1]
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.subTest(proto=proto):
                s = pickle.dumps(d, proto)
                o = pickle.loads(s)
                self.assertIs(type(o), defaultdict)
                self.assertEqual(d, o)
                # Dict equality ignores default_factory, so check it explicitly.
                self.assertIs(o.default_factory, int)

    def test_union(self):
        """``|`` keeps the factory of the (left-most) defaultdict operand; ``|=`` takes pairs."""
        i = defaultdict(int, {1: 1, 2: 2})
        s = defaultdict(str, {0: "zero", 1: "one"})

        i_s = i | s
        self.assertIs(i_s.default_factory, int)
        self.assertDictEqual(i_s, {1: "one", 2: 2, 0: "zero"})
        self.assertEqual(list(i_s), [1, 2, 0])

        s_i = s | i
        self.assertIs(s_i.default_factory, str)
        self.assertDictEqual(s_i, {0: "zero", 1: 1, 2: 2})
        self.assertEqual(list(s_i), [0, 1, 2])

        i_ds = i | dict(s)
        self.assertIs(i_ds.default_factory, int)
        self.assertDictEqual(i_ds, {1: "one", 2: 2, 0: "zero"})
        self.assertEqual(list(i_ds), [1, 2, 0])

        ds_i = dict(s) | i
        self.assertIs(ds_i.default_factory, int)
        self.assertDictEqual(ds_i, {0: "zero", 1: 1, 2: 2})
        self.assertEqual(list(ds_i), [0, 1, 2])

        with self.assertRaises(TypeError):
            i | list(s.items())
        with self.assertRaises(TypeError):
            list(s.items()) | i

        # We inherit a fine |= from dict, so just a few sanity checks here:
        i |= list(s.items())
        self.assertIs(i.default_factory, int)
        self.assertDictEqual(i, {1: "one", 2: 2, 0: "zero"})
        self.assertEqual(list(i), [1, 2, 0])

        with self.assertRaises(TypeError):
            i |= None


if __name__ == "__main__":
    unittest.main()