"""Unit tests for collections.defaultdict."""

import os
import copy
import pickle
import tempfile
import unittest

from collections import defaultdict


def foobar():
    """Module-level default_factory used by the copy tests; returns the ``list`` type."""
    return list


class TestDefaultDict(unittest.TestCase):
    """Behavioural tests for collections.defaultdict."""

    def _print_to_real_file(self, *objs):
        """Print each of *objs* to a new on-disk file; return the lines read back.

        A real file is used on purpose rather than io.StringIO: historically
        (Python 2) the C-level tp_print slot was only invoked when printing to
        a real file object, and these tests were written to exercise it.
        Python 3's print() does not use tp_print, so today this checks that
        print() writes repr() to a file.

        The file is created with tempfile.mkstemp() (not the deprecated,
        race-prone tempfile.mktemp()) and is always removed afterwards.
        """
        fd, path = tempfile.mkstemp()
        try:
            # fdopen takes ownership of fd; leaving the with-block closes it
            # before the finally clause removes the file (needed on Windows).
            with os.fdopen(fd, "w+", encoding="utf-8") as f:
                for obj in objs:
                    print(obj, file=f)
                f.seek(0)
                return f.readlines()
        finally:
            os.remove(path)

    def test_basic(self):
        d1 = defaultdict()
        self.assertIsNone(d1.default_factory)
        d1.default_factory = list
        d1[12].append(42)
        self.assertEqual(d1, {12: [42]})
        d1[12].append(24)
        self.assertEqual(d1, {12: [42, 24]})
        d1[13]
        d1[14]
        self.assertEqual(d1, {12: [42, 24], 13: [], 14: []})
        # Every missing key must get its own fresh object from the factory.
        # Checked pairwise: a chained "a is not b is not c" skips the a/c pair.
        self.assertIsNot(d1[12], d1[13])
        self.assertIsNot(d1[13], d1[14])
        self.assertIsNot(d1[12], d1[14])
        d2 = defaultdict(list, foo=1, bar=2)
        self.assertIs(d2.default_factory, list)
        self.assertEqual(d2, {"foo": 1, "bar": 2})
        self.assertEqual(d2["foo"], 1)
        self.assertEqual(d2["bar"], 2)
        self.assertEqual(d2[42], [])
        self.assertIn("foo", d2)
        self.assertIn("foo", d2.keys())
        self.assertIn("bar", d2)
        self.assertIn("bar", d2.keys())
        self.assertIn(42, d2)
        self.assertIn(42, d2.keys())
        self.assertNotIn(12, d2)
        self.assertNotIn(12, d2.keys())
        d2.default_factory = None
        self.assertIsNone(d2.default_factory)
        # Without a factory a missing key raises KeyError carrying the key.
        with self.assertRaises(KeyError) as cm:
            d2[15]
        self.assertEqual(cm.exception.args, (15,))
        self.assertRaises(TypeError, defaultdict, 1)

    def test_missing(self):
        d1 = defaultdict()
        self.assertRaises(KeyError, d1.__missing__, 42)
        d1.default_factory = list
        self.assertEqual(d1.__missing__(42), [])

    def test_repr(self):
        d1 = defaultdict()
        self.assertIsNone(d1.default_factory)
        self.assertEqual(repr(d1), "defaultdict(None, {})")
        self.assertEqual(eval(repr(d1)), d1)
        d1[11] = 41
        self.assertEqual(repr(d1), "defaultdict(None, {11: 41})")
        d2 = defaultdict(int)
        self.assertIs(d2.default_factory, int)
        d2[12] = 42
        self.assertEqual(repr(d2), "defaultdict(<class 'int'>, {12: 42})")
        def foo(): return 43
        d3 = defaultdict(foo)
        self.assertIs(d3.default_factory, foo)
        d3[13]
        self.assertEqual(repr(d3), "defaultdict(%s, {13: 43})" % repr(foo))

    def test_print(self):
        d1 = defaultdict()
        def foo(): return 42
        d2 = defaultdict(foo, {1: 2})
        # Printed output must match repr(), one object per line.
        self.assertEqual(self._print_to_real_file(d1, d2),
                         [repr(d1) + "\n", repr(d2) + "\n"])

    def test_copy(self):
        d1 = defaultdict()
        d2 = d1.copy()
        self.assertIs(type(d2), defaultdict)
        self.assertIsNone(d2.default_factory)
        self.assertEqual(d2, {})
        d1.default_factory = list
        d3 = d1.copy()
        self.assertIs(type(d3), defaultdict)
        self.assertIs(d3.default_factory, list)
        self.assertEqual(d3, {})
        d1[42]
        d4 = d1.copy()
        self.assertIs(type(d4), defaultdict)
        self.assertIs(d4.default_factory, list)
        self.assertEqual(d4, {42: []})
        d4[12]
        self.assertEqual(d4, {42: [], 12: []})

        # Issue 6637: Copy fails for empty default dict
        d = defaultdict()
        d['a'] = 42
        e = d.copy()
        self.assertEqual(e['a'], 42)

    def test_shallow_copy(self):
        d1 = defaultdict(foobar, {1: 1})
        d2 = copy.copy(d1)
        self.assertIs(d2.default_factory, foobar)
        self.assertEqual(d2, d1)
        d1.default_factory = list
        d2 = copy.copy(d1)
        self.assertIs(d2.default_factory, list)
        self.assertEqual(d2, d1)

    def test_deep_copy(self):
        d1 = defaultdict(foobar, {1: [1]})
        d2 = copy.deepcopy(d1)
        self.assertIs(d2.default_factory, foobar)
        self.assertEqual(d2, d1)
        # Values are copied deeply, not shared.
        self.assertIsNot(d1[1], d2[1])
        d1.default_factory = list
        d2 = copy.deepcopy(d1)
        self.assertIs(d2.default_factory, list)
        self.assertEqual(d2, d1)

    def test_keyerror_without_factory(self):
        d1 = defaultdict()
        # A tuple key must come back intact in the exception arguments.
        with self.assertRaises(KeyError) as cm:
            d1[(1,)]
        self.assertEqual(cm.exception.args[0], (1,))

    def test_recursive_repr(self):
        # Issue2045: stack overflow when default_factory is a bound method
        class sub(defaultdict):
            def __init__(self):
                self.default_factory = self._factory
            def _factory(self):
                return []
        d = sub()
        self.assertRegex(repr(d),
            r"sub\(<bound method .*sub\._factory "
            r"of sub\(\.\.\., \{\}\)>, \{\}\)")

        # Printing must go through the same recursion-safe repr; check what
        # actually reaches the file instead of only checking that print()
        # does not blow up.
        self.assertEqual(self._print_to_real_file(d), [repr(d) + "\n"])

    def test_callable_arg(self):
        self.assertRaises(TypeError, defaultdict, {})

    def test_pickling(self):
        d = defaultdict(int)
        d[1]
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.subTest(proto=proto):
                o = pickle.loads(pickle.dumps(d, proto))
                self.assertEqual(d, o)
                # The round trip must also preserve the type and the factory.
                self.assertIs(type(o), defaultdict)
                self.assertIs(o.default_factory, int)

    def test_union(self):
        i = defaultdict(int, {1: 1, 2: 2})
        s = defaultdict(str, {0: "zero", 1: "one"})

        i_s = i | s
        self.assertIs(i_s.default_factory, int)
        self.assertDictEqual(i_s, {1: "one", 2: 2, 0: "zero"})
        self.assertEqual(list(i_s), [1, 2, 0])

        s_i = s | i
        self.assertIs(s_i.default_factory, str)
        self.assertDictEqual(s_i, {0: "zero", 1: 1, 2: 2})
        self.assertEqual(list(s_i), [0, 1, 2])

        i_ds = i | dict(s)
        self.assertIs(i_ds.default_factory, int)
        self.assertDictEqual(i_ds, {1: "one", 2: 2, 0: "zero"})
        self.assertEqual(list(i_ds), [1, 2, 0])

        ds_i = dict(s) | i
        self.assertIs(ds_i.default_factory, int)
        self.assertDictEqual(ds_i, {0: "zero", 1: 1, 2: 2})
        self.assertEqual(list(ds_i), [0, 1, 2])

        with self.assertRaises(TypeError):
            i | list(s.items())
        with self.assertRaises(TypeError):
            list(s.items()) | i

        # We inherit a fine |= from dict, so just a few sanity checks here:
        i |= list(s.items())
        self.assertIs(i.default_factory, int)
        self.assertDictEqual(i, {1: "one", 2: 2, 0: "zero"})
        self.assertEqual(list(i), [1, 2, 0])

        with self.assertRaises(TypeError):
            i |= None


if __name__ == "__main__":
    unittest.main()