• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Unit tests for collections.defaultdict."""
2
3import copy
4import pickle
5import unittest
6
7from collections import defaultdict
8
def foobar():
    """Return the built-in ``list`` type itself (not a new list).

    Serves as a named, module-level ``default_factory`` in the shallow
    and deep copy tests.
    """
    return list
11
class TestDefaultDict(unittest.TestCase):
    """Behavioural tests for ``collections.defaultdict``."""

    def test_basic(self):
        # With no arguments there is no factory.
        d1 = defaultdict()
        self.assertIsNone(d1.default_factory)
        d1.default_factory = list
        d1[12].append(42)
        self.assertEqual(d1, {12: [42]})
        d1[12].append(24)
        self.assertEqual(d1, {12: [42, 24]})
        # Merely reading a missing key inserts a fresh factory value.
        d1[13]
        d1[14]
        self.assertEqual(d1, {12: [42, 24], 13: [], 14: []})
        # Every key must hold a distinct object (check all three pairs).
        self.assertIsNot(d1[12], d1[13])
        self.assertIsNot(d1[13], d1[14])
        self.assertIsNot(d1[12], d1[14])

        # Arguments after the factory initialise the contents like dict().
        d2 = defaultdict(list, foo=1, bar=2)
        self.assertIs(d2.default_factory, list)
        self.assertEqual(d2, {"foo": 1, "bar": 2})
        self.assertEqual(d2["foo"], 1)
        self.assertEqual(d2["bar"], 2)
        self.assertEqual(d2[42], [])
        self.assertIn("foo", d2)
        self.assertIn("foo", d2.keys())
        self.assertIn("bar", d2)
        self.assertIn("bar", d2.keys())
        self.assertIn(42, d2)
        self.assertIn(42, d2.keys())
        # Membership tests must not insert the key.
        self.assertNotIn(12, d2)
        self.assertNotIn(12, d2.keys())

        # Once the factory is cleared, a missing key raises KeyError(key).
        d2.default_factory = None
        self.assertIsNone(d2.default_factory)
        with self.assertRaises(KeyError,
                               msg="d2[15] didn't raise KeyError") as cm:
            d2[15]
        self.assertEqual(cm.exception.args, (15,))
        # A non-callable factory is rejected.
        self.assertRaises(TypeError, defaultdict, 1)

    def test_missing(self):
        d1 = defaultdict()
        self.assertRaises(KeyError, d1.__missing__, 42)
        d1.default_factory = list
        self.assertEqual(d1.__missing__(42), [])
        # __missing__ also stores the value it produced.
        self.assertEqual(d1, {42: []})

    def test_repr(self):
        d1 = defaultdict()
        self.assertIsNone(d1.default_factory)
        self.assertEqual(repr(d1), "defaultdict(None, {})")
        # A factory-less repr round-trips through eval().
        self.assertEqual(eval(repr(d1)), d1)
        d1[11] = 41
        self.assertEqual(repr(d1), "defaultdict(None, {11: 41})")
        d2 = defaultdict(int)
        self.assertIs(d2.default_factory, int)
        d2[12] = 42
        self.assertEqual(repr(d2), "defaultdict(<class 'int'>, {12: 42})")

        def foo():
            return 43

        d3 = defaultdict(foo)
        self.assertIs(d3.default_factory, foo)
        d3[13]
        self.assertEqual(repr(d3), "defaultdict(%s, {13: 43})" % repr(foo))

    def test_copy(self):
        d1 = defaultdict()
        d2 = d1.copy()
        self.assertIs(type(d2), defaultdict)
        self.assertIsNone(d2.default_factory)
        self.assertEqual(d2, {})
        d1.default_factory = list
        d3 = d1.copy()
        self.assertIs(type(d3), defaultdict)
        self.assertIs(d3.default_factory, list)
        self.assertEqual(d3, {})
        d1[42]
        d4 = d1.copy()
        self.assertIs(type(d4), defaultdict)
        self.assertIs(d4.default_factory, list)
        self.assertEqual(d4, {42: []})
        d4[12]
        self.assertEqual(d4, {42: [], 12: []})
        # Keys the copy creates must not appear in the original.
        self.assertNotIn(12, d1)

        # Issue 6637: Copy fails for empty default dict
        d = defaultdict()
        d['a'] = 42
        e = d.copy()
        self.assertEqual(e['a'], 42)

    def test_shallow_copy(self):
        d1 = defaultdict(foobar, {1: 1})
        d2 = copy.copy(d1)
        self.assertIs(d2.default_factory, foobar)
        self.assertEqual(d2, d1)
        d1.default_factory = list
        d2 = copy.copy(d1)
        self.assertIs(d2.default_factory, list)
        self.assertEqual(d2, d1)
        # A shallow copy is a new mapping that shares its value objects.
        d3 = defaultdict(list, {1: [1]})
        d4 = copy.copy(d3)
        self.assertIsNot(d4, d3)
        self.assertIs(d4[1], d3[1])

    def test_deep_copy(self):
        d1 = defaultdict(foobar, {1: [1]})
        d2 = copy.deepcopy(d1)
        self.assertIs(d2.default_factory, foobar)
        self.assertEqual(d2, d1)
        self.assertIsNot(d1[1], d2[1])
        d1.default_factory = list
        d2 = copy.deepcopy(d1)
        self.assertIs(d2.default_factory, list)
        self.assertEqual(d2, d1)
        # Values must be copied for a builtin factory too.
        self.assertIsNot(d1[1], d2[1])

    def test_keyerror_without_factory(self):
        # A tuple key is reported as one argument, not unpacked.
        d1 = defaultdict()
        with self.assertRaises(KeyError, msg="expected KeyError") as cm:
            d1[(1,)]
        self.assertEqual(cm.exception.args[0], (1,))

    def test_recursive_repr(self):
        # Issue2045: stack overflow when default_factory is a bound method
        class sub(defaultdict):
            def __init__(self):
                self.default_factory = self._factory

            def _factory(self):
                return []

        d = sub()
        self.assertRegex(repr(d),
            r"sub\(<bound method .*sub\._factory "
            r"of sub\(\.\.\., \{\}\)>, \{\}\)")

    def test_callable_arg(self):
        self.assertRaises(TypeError, defaultdict, {})

    def test_pickling(self):
        d = defaultdict(int)
        d[1]
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.subTest(proto=proto):
                s = pickle.dumps(d, proto)
                o = pickle.loads(s)
                self.assertEqual(d, o)
                # Equality alone cannot detect a lost type or factory.
                self.assertIs(type(o), defaultdict)
                self.assertIs(o.default_factory, int)

    def test_union(self):
        i = defaultdict(int, {1: 1, 2: 2})
        s = defaultdict(str, {0: "zero", 1: "one"})

        # defaultdict | defaultdict: left factory wins, right values win.
        i_s = i | s
        self.assertIs(i_s.default_factory, int)
        self.assertDictEqual(i_s, {1: "one", 2: 2, 0: "zero"})
        self.assertEqual(list(i_s), [1, 2, 0])

        s_i = s | i
        self.assertIs(s_i.default_factory, str)
        self.assertDictEqual(s_i, {0: "zero", 1: 1, 2: 2})
        self.assertEqual(list(s_i), [0, 1, 2])

        # Mixing with a plain dict on either side keeps the defaultdict's factory.
        i_ds = i | dict(s)
        self.assertIs(i_ds.default_factory, int)
        self.assertDictEqual(i_ds, {1: "one", 2: 2, 0: "zero"})
        self.assertEqual(list(i_ds), [1, 2, 0])

        ds_i = dict(s) | i
        self.assertIs(ds_i.default_factory, int)
        self.assertDictEqual(ds_i, {0: "zero", 1: 1, 2: 2})
        self.assertEqual(list(ds_i), [0, 1, 2])

        # The binary operator accepts only mappings, not item lists.
        with self.assertRaises(TypeError):
            i | list(s.items())
        with self.assertRaises(TypeError):
            list(s.items()) | i

        # We inherit a fine |= from dict, so just a few sanity checks here:
        i |= list(s.items())
        self.assertIs(i.default_factory, int)
        self.assertDictEqual(i, {1: "one", 2: 2, 0: "zero"})
        self.assertEqual(list(i), [1, 2, 0])

        with self.assertRaises(TypeError):
            i |= None
if __name__ == "__main__":
    # Run the test cases above when this file is executed as a script.
    unittest.main()
191