• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import sys
2import unittest
3from doctest import DocTestSuite
4from test import support
5import weakref
6import gc
7
8# Modules under test
9import _thread
10import threading
11import _threading_local
12
13
class Weak:
    """Plain, weak-referenceable object that the tests store in thread locals."""
16
def target(local, weaklist):
    """Thread body: attach a new ``Weak`` to *local* and weakly track it.

    A weak reference to the freshly created object is appended to
    *weaklist*, so the caller can later check whether it has been freed.
    """
    instance = Weak()
    local.weak = instance
    tracker = weakref.ref(instance)
    weaklist.append(tracker)
21
22
class BaseLocalTest:
    """Tests shared by every thread-local implementation.

    Concrete subclasses mix this in alongside ``unittest.TestCase`` and set
    the class attribute ``_local`` to the thread-local type under test.
    """

    def test_local_refs(self):
        self._local_refs(20)
        self._local_refs(50)
        self._local_refs(100)

    def _local_refs(self, n):
        """Check that values stored on a local by finished threads are freed.

        Runs *n* short-lived threads, each storing a fresh ``Weak`` object on
        a shared local (via ``target``) and recording a weakref to it, then
        checks that all, or all but one, of those objects have died.
        """
        local = self._local()
        weaklist = []
        for _ in range(n):
            t = threading.Thread(target=target, args=(local, weaklist))
            t.start()
            t.join()
            # Drop the finished thread immediately so it cannot keep anything
            # alive. (Deleting it after the loop would fail when n == 0.)
            del t

        gc.collect()
        self.assertEqual(len(weaklist), n)

        # XXX _threading_local keeps the local of the last stopped thread alive.
        deadlist = [weak for weak in weaklist if weak() is None]
        self.assertIn(len(deadlist), (n-1, n))

        # Assignment to the same thread local frees it sometimes (!)
        local.someothervar = None
        gc.collect()
        deadlist = [weak for weak in weaklist if weak() is None]
        self.assertIn(len(deadlist), (n-1, n), (n, len(deadlist)))

    def test_derived(self):
        # Issue 3088: if there is a threads switch inside the __init__
        # of a threading.local derived class, the per-thread dictionary
        # is created but not correctly set on the object.
        # The first member set may be bogus.
        import time
        class Local(self._local):
            def __init__(self):
                # Sleep to encourage a thread switch inside __init__.
                time.sleep(0.01)
        local = Local()

        def f(i):
            local.x = i
            # Simply check that the variable is correctly set
            self.assertEqual(local.x, i)

        with support.start_threads(threading.Thread(target=f, args=(i,))
                                   for i in range(10)):
            pass

    def test_derived_cycle_dealloc(self):
        # http://bugs.python.org/issue6990
        class Local(self._local):
            pass
        # Named so as not to shadow the builtin ``locals``.
        new_locals = None
        passed = False
        e1 = threading.Event()
        e2 = threading.Event()

        def f():
            nonlocal passed
            try:
                # 1) Involve Local in a cycle
                cycle = [Local()]
                cycle.append(cycle)
                cycle[0].foo = 'bar'

                # 2) GC the cycle (triggers threadmodule.c::local_clear
                # before local_dealloc)
                del cycle
                gc.collect()
            finally:
                # Always release the main thread, even if the steps above
                # fail; otherwise e1.wait() below would block forever.
                e1.set()
            e2.wait()

            # 4) New Locals should be empty
            passed = all(not hasattr(local, 'foo') for local in new_locals)

        t = threading.Thread(target=f)
        t.start()
        e1.wait()

        # 3) New Locals should recycle the original's address. Creating
        # them in the thread overwrites the thread state and avoids the
        # bug
        new_locals = [Local() for i in range(10)]
        e2.set()
        t.join()

        self.assertTrue(passed)

    def test_arguments(self):
        # Issue 1522237: a subclass with its own __init__ may accept
        # arguments, but the base local type itself must reject them.
        class MyLocal(self._local):
            def __init__(self, *args, **kwargs):
                pass

        MyLocal(a=1)
        MyLocal(1)
        self.assertRaises(TypeError, self._local, a=1)
        self.assertRaises(TypeError, self._local, 1)

    def _test_one_class(self, c):
        """Check that an attribute set on a ``c()`` instance in one thread
        is not visible from a second thread."""
        self._failed = "No error message set or cleared."
        obj = c()
        e1 = threading.Event()
        e2 = threading.Event()

        def f1():
            try:
                obj.x = 'foo'
                obj.y = 'bar'
                del obj.y
            finally:
                # Unblock the main thread even if the setup above fails,
                # so the test reports a result instead of hanging.
                e1.set()
            e2.wait()

        def f2():
            try:
                foo = obj.x
            except AttributeError:
                # This is expected -- we haven't set obj.x in this thread yet!
                self._failed = ""  # passed
            else:
                self._failed = ('Incorrectly got value %r from class %r\n' %
                                (foo, c))
                sys.stderr.write(self._failed)

        t1 = threading.Thread(target=f1)
        t1.start()
        e1.wait()
        t2 = threading.Thread(target=f2)
        t2.start()
        t2.join()
        # The test is done; just let t1 know it can exit, and wait for it.
        e2.set()
        t1.join()

        self.assertFalse(self._failed, self._failed)

    def test_threading_local(self):
        self._test_one_class(self._local)

    def test_threading_local_subclass(self):
        class LocalSubclass(self._local):
            """To test that subclasses behave properly."""
        self._test_one_class(LocalSubclass)

    def _test_dict_attribute(self, cls):
        """Check that ``__dict__`` reflects set attributes and is read-only."""
        obj = cls()
        obj.x = 5
        self.assertEqual(obj.__dict__, {'x': 5})
        with self.assertRaises(AttributeError):
            obj.__dict__ = {}
        with self.assertRaises(AttributeError):
            del obj.__dict__

    def test_dict_attribute(self):
        self._test_dict_attribute(self._local)

    def test_dict_attribute_subclass(self):
        class LocalSubclass(self._local):
            """To test that subclasses behave properly."""
        self._test_dict_attribute(LocalSubclass)

    def test_cycle_collection(self):
        # A cycle that passes through a local (x -> local -> x) must be
        # collectable by the cyclic GC.
        class X:
            pass

        x = X()
        x.local = self._local()
        x.local.x = x
        wr = weakref.ref(x)
        del x
        gc.collect()
        self.assertIsNone(wr())
195
class ThreadLocalTest(unittest.TestCase, BaseLocalTest):
    """Run the shared local tests against ``_thread._local``."""

    _local = _thread._local
198
class PyThreadingLocalTest(unittest.TestCase, BaseLocalTest):
    """Run the shared local tests against ``_threading_local.local``."""

    _local = _threading_local.local
201
202
def test_main():
    """Build and run the full suite: both local implementations plus doctests.

    The ``_threading_local`` doctests are added twice. The first run uses
    the module as-is. The second run temporarily rebinds
    ``_threading_local.local`` to ``_thread._local`` and restores it
    afterwards.
    """
    # unittest.makeSuite() is deprecated (3.11) and removed in 3.13; the
    # default loader uses the same "test" prefix and method ordering.
    loader = unittest.defaultTestLoader
    suite = unittest.TestSuite()
    suite.addTest(DocTestSuite('_threading_local'))
    suite.addTest(loader.loadTestsFromTestCase(ThreadLocalTest))
    suite.addTest(loader.loadTestsFromTestCase(PyThreadingLocalTest))

    local_orig = _threading_local.local

    def setUp(test):
        # Swap in the C implementation for the second doctest run.
        _threading_local.local = _thread._local

    def tearDown(test):
        _threading_local.local = local_orig

    suite.addTest(DocTestSuite('_threading_local',
                               setUp=setUp, tearDown=tearDown))

    support.run_unittest(suite)
219
if __name__ == '__main__':
    # Script entry point: run the unit tests and doctests.
    test_main()
222