• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Tests for streams.py."""
2
3import gc
4import os
5import queue
6import pickle
7import socket
8import sys
9import threading
10import unittest
11from unittest import mock
12from test.support import socket_helper
13try:
14    import ssl
15except ImportError:
16    ssl = None
17
18import asyncio
19from test.test_asyncio import utils as test_utils
20
21
def tearDownModule():
    """Put the default event loop policy back once this module is done."""
    default_policy = None  # None asks asyncio to reinstall its default policy
    asyncio.set_event_loop_policy(default_policy)
24
25
26class StreamTests(test_utils.TestCase):
27
28    DATA = b'line1\nline2\nline3\n'
29
30    def setUp(self):
31        super().setUp()
32        self.loop = asyncio.new_event_loop()
33        self.set_event_loop(self.loop)
34
35    def tearDown(self):
36        # just in case if we have transport close callbacks
37        test_utils.run_briefly(self.loop)
38
39        self.loop.close()
40        gc.collect()
41        super().tearDown()
42
43    def _basetest_open_connection(self, open_connection_fut):
44        messages = []
45        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
46        reader, writer = self.loop.run_until_complete(open_connection_fut)
47        writer.write(b'GET / HTTP/1.0\r\n\r\n')
48        f = reader.readline()
49        data = self.loop.run_until_complete(f)
50        self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
51        f = reader.read()
52        data = self.loop.run_until_complete(f)
53        self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
54        writer.close()
55        self.assertEqual(messages, [])
56
57    def test_open_connection(self):
58        with test_utils.run_test_server() as httpd:
59            conn_fut = asyncio.open_connection(*httpd.address)
60            self._basetest_open_connection(conn_fut)
61
62    @socket_helper.skip_unless_bind_unix_socket
63    def test_open_unix_connection(self):
64        with test_utils.run_test_unix_server() as httpd:
65            conn_fut = asyncio.open_unix_connection(httpd.address)
66            self._basetest_open_connection(conn_fut)
67
68    def _basetest_open_connection_no_loop_ssl(self, open_connection_fut):
69        messages = []
70        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
71        try:
72            reader, writer = self.loop.run_until_complete(open_connection_fut)
73        finally:
74            asyncio.set_event_loop(None)
75        writer.write(b'GET / HTTP/1.0\r\n\r\n')
76        f = reader.read()
77        data = self.loop.run_until_complete(f)
78        self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
79
80        writer.close()
81        self.assertEqual(messages, [])
82
83    @unittest.skipIf(ssl is None, 'No ssl module')
84    def test_open_connection_no_loop_ssl(self):
85        with test_utils.run_test_server(use_ssl=True) as httpd:
86            conn_fut = asyncio.open_connection(
87                *httpd.address,
88                ssl=test_utils.dummy_ssl_context())
89
90            self._basetest_open_connection_no_loop_ssl(conn_fut)
91
92    @socket_helper.skip_unless_bind_unix_socket
93    @unittest.skipIf(ssl is None, 'No ssl module')
94    def test_open_unix_connection_no_loop_ssl(self):
95        with test_utils.run_test_unix_server(use_ssl=True) as httpd:
96            conn_fut = asyncio.open_unix_connection(
97                httpd.address,
98                ssl=test_utils.dummy_ssl_context(),
99                server_hostname='',
100            )
101
102            self._basetest_open_connection_no_loop_ssl(conn_fut)
103
104    def _basetest_open_connection_error(self, open_connection_fut):
105        messages = []
106        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
107        reader, writer = self.loop.run_until_complete(open_connection_fut)
108        writer._protocol.connection_lost(ZeroDivisionError())
109        f = reader.read()
110        with self.assertRaises(ZeroDivisionError):
111            self.loop.run_until_complete(f)
112        writer.close()
113        test_utils.run_briefly(self.loop)
114        self.assertEqual(messages, [])
115
116    def test_open_connection_error(self):
117        with test_utils.run_test_server() as httpd:
118            conn_fut = asyncio.open_connection(*httpd.address)
119            self._basetest_open_connection_error(conn_fut)
120
121    @socket_helper.skip_unless_bind_unix_socket
122    def test_open_unix_connection_error(self):
123        with test_utils.run_test_unix_server() as httpd:
124            conn_fut = asyncio.open_unix_connection(httpd.address)
125            self._basetest_open_connection_error(conn_fut)
126
127    def test_feed_empty_data(self):
128        stream = asyncio.StreamReader(loop=self.loop)
129
130        stream.feed_data(b'')
131        self.assertEqual(b'', stream._buffer)
132
133    def test_feed_nonempty_data(self):
134        stream = asyncio.StreamReader(loop=self.loop)
135
136        stream.feed_data(self.DATA)
137        self.assertEqual(self.DATA, stream._buffer)
138
139    def test_read_zero(self):
140        # Read zero bytes.
141        stream = asyncio.StreamReader(loop=self.loop)
142        stream.feed_data(self.DATA)
143
144        data = self.loop.run_until_complete(stream.read(0))
145        self.assertEqual(b'', data)
146        self.assertEqual(self.DATA, stream._buffer)
147
148    def test_read(self):
149        # Read bytes.
150        stream = asyncio.StreamReader(loop=self.loop)
151        read_task = self.loop.create_task(stream.read(30))
152
153        def cb():
154            stream.feed_data(self.DATA)
155        self.loop.call_soon(cb)
156
157        data = self.loop.run_until_complete(read_task)
158        self.assertEqual(self.DATA, data)
159        self.assertEqual(b'', stream._buffer)
160
161    def test_read_line_breaks(self):
162        # Read bytes without line breaks.
163        stream = asyncio.StreamReader(loop=self.loop)
164        stream.feed_data(b'line1')
165        stream.feed_data(b'line2')
166
167        data = self.loop.run_until_complete(stream.read(5))
168
169        self.assertEqual(b'line1', data)
170        self.assertEqual(b'line2', stream._buffer)
171
172    def test_read_eof(self):
173        # Read bytes, stop at eof.
174        stream = asyncio.StreamReader(loop=self.loop)
175        read_task = self.loop.create_task(stream.read(1024))
176
177        def cb():
178            stream.feed_eof()
179        self.loop.call_soon(cb)
180
181        data = self.loop.run_until_complete(read_task)
182        self.assertEqual(b'', data)
183        self.assertEqual(b'', stream._buffer)
184
185    def test_read_until_eof(self):
186        # Read all bytes until eof.
187        stream = asyncio.StreamReader(loop=self.loop)
188        read_task = self.loop.create_task(stream.read(-1))
189
190        def cb():
191            stream.feed_data(b'chunk1\n')
192            stream.feed_data(b'chunk2')
193            stream.feed_eof()
194        self.loop.call_soon(cb)
195
196        data = self.loop.run_until_complete(read_task)
197
198        self.assertEqual(b'chunk1\nchunk2', data)
199        self.assertEqual(b'', stream._buffer)
200
201    def test_read_exception(self):
202        stream = asyncio.StreamReader(loop=self.loop)
203        stream.feed_data(b'line\n')
204
205        data = self.loop.run_until_complete(stream.read(2))
206        self.assertEqual(b'li', data)
207
208        stream.set_exception(ValueError())
209        self.assertRaises(
210            ValueError, self.loop.run_until_complete, stream.read(2))
211
212    def test_invalid_limit(self):
213        with self.assertRaisesRegex(ValueError, 'imit'):
214            asyncio.StreamReader(limit=0, loop=self.loop)
215
216        with self.assertRaisesRegex(ValueError, 'imit'):
217            asyncio.StreamReader(limit=-1, loop=self.loop)
218
219    def test_read_limit(self):
220        stream = asyncio.StreamReader(limit=3, loop=self.loop)
221        stream.feed_data(b'chunk')
222        data = self.loop.run_until_complete(stream.read(5))
223        self.assertEqual(b'chunk', data)
224        self.assertEqual(b'', stream._buffer)
225
226    def test_readline(self):
227        # Read one line. 'readline' will need to wait for the data
228        # to come from 'cb'
229        stream = asyncio.StreamReader(loop=self.loop)
230        stream.feed_data(b'chunk1 ')
231        read_task = self.loop.create_task(stream.readline())
232
233        def cb():
234            stream.feed_data(b'chunk2 ')
235            stream.feed_data(b'chunk3 ')
236            stream.feed_data(b'\n chunk4')
237        self.loop.call_soon(cb)
238
239        line = self.loop.run_until_complete(read_task)
240        self.assertEqual(b'chunk1 chunk2 chunk3 \n', line)
241        self.assertEqual(b' chunk4', stream._buffer)
242
243    def test_readline_limit_with_existing_data(self):
244        # Read one line. The data is in StreamReader's buffer
245        # before the event loop is run.
246
247        stream = asyncio.StreamReader(limit=3, loop=self.loop)
248        stream.feed_data(b'li')
249        stream.feed_data(b'ne1\nline2\n')
250
251        self.assertRaises(
252            ValueError, self.loop.run_until_complete, stream.readline())
253        # The buffer should contain the remaining data after exception
254        self.assertEqual(b'line2\n', stream._buffer)
255
256        stream = asyncio.StreamReader(limit=3, loop=self.loop)
257        stream.feed_data(b'li')
258        stream.feed_data(b'ne1')
259        stream.feed_data(b'li')
260
261        self.assertRaises(
262            ValueError, self.loop.run_until_complete, stream.readline())
263        # No b'\n' at the end. The 'limit' is set to 3. So before
264        # waiting for the new data in buffer, 'readline' will consume
265        # the entire buffer, and since the length of the consumed data
266        # is more than 3, it will raise a ValueError. The buffer is
267        # expected to be empty now.
268        self.assertEqual(b'', stream._buffer)
269
270    def test_at_eof(self):
271        stream = asyncio.StreamReader(loop=self.loop)
272        self.assertFalse(stream.at_eof())
273
274        stream.feed_data(b'some data\n')
275        self.assertFalse(stream.at_eof())
276
277        self.loop.run_until_complete(stream.readline())
278        self.assertFalse(stream.at_eof())
279
280        stream.feed_data(b'some data\n')
281        stream.feed_eof()
282        self.loop.run_until_complete(stream.readline())
283        self.assertTrue(stream.at_eof())
284
285    def test_readline_limit(self):
286        # Read one line. StreamReaders are fed with data after
287        # their 'readline' methods are called.
288
289        stream = asyncio.StreamReader(limit=7, loop=self.loop)
290        def cb():
291            stream.feed_data(b'chunk1')
292            stream.feed_data(b'chunk2')
293            stream.feed_data(b'chunk3\n')
294            stream.feed_eof()
295        self.loop.call_soon(cb)
296
297        self.assertRaises(
298            ValueError, self.loop.run_until_complete, stream.readline())
299        # The buffer had just one line of data, and after raising
300        # a ValueError it should be empty.
301        self.assertEqual(b'', stream._buffer)
302
303        stream = asyncio.StreamReader(limit=7, loop=self.loop)
304        def cb():
305            stream.feed_data(b'chunk1')
306            stream.feed_data(b'chunk2\n')
307            stream.feed_data(b'chunk3\n')
308            stream.feed_eof()
309        self.loop.call_soon(cb)
310
311        self.assertRaises(
312            ValueError, self.loop.run_until_complete, stream.readline())
313        self.assertEqual(b'chunk3\n', stream._buffer)
314
315        # check strictness of the limit
316        stream = asyncio.StreamReader(limit=7, loop=self.loop)
317        stream.feed_data(b'1234567\n')
318        line = self.loop.run_until_complete(stream.readline())
319        self.assertEqual(b'1234567\n', line)
320        self.assertEqual(b'', stream._buffer)
321
322        stream.feed_data(b'12345678\n')
323        with self.assertRaises(ValueError) as cm:
324            self.loop.run_until_complete(stream.readline())
325        self.assertEqual(b'', stream._buffer)
326
327        stream.feed_data(b'12345678')
328        with self.assertRaises(ValueError) as cm:
329            self.loop.run_until_complete(stream.readline())
330        self.assertEqual(b'', stream._buffer)
331
332    def test_readline_nolimit_nowait(self):
333        # All needed data for the first 'readline' call will be
334        # in the buffer.
335        stream = asyncio.StreamReader(loop=self.loop)
336        stream.feed_data(self.DATA[:6])
337        stream.feed_data(self.DATA[6:])
338
339        line = self.loop.run_until_complete(stream.readline())
340
341        self.assertEqual(b'line1\n', line)
342        self.assertEqual(b'line2\nline3\n', stream._buffer)
343
344    def test_readline_eof(self):
345        stream = asyncio.StreamReader(loop=self.loop)
346        stream.feed_data(b'some data')
347        stream.feed_eof()
348
349        line = self.loop.run_until_complete(stream.readline())
350        self.assertEqual(b'some data', line)
351
352    def test_readline_empty_eof(self):
353        stream = asyncio.StreamReader(loop=self.loop)
354        stream.feed_eof()
355
356        line = self.loop.run_until_complete(stream.readline())
357        self.assertEqual(b'', line)
358
359    def test_readline_read_byte_count(self):
360        stream = asyncio.StreamReader(loop=self.loop)
361        stream.feed_data(self.DATA)
362
363        self.loop.run_until_complete(stream.readline())
364
365        data = self.loop.run_until_complete(stream.read(7))
366
367        self.assertEqual(b'line2\nl', data)
368        self.assertEqual(b'ine3\n', stream._buffer)
369
370    def test_readline_exception(self):
371        stream = asyncio.StreamReader(loop=self.loop)
372        stream.feed_data(b'line\n')
373
374        data = self.loop.run_until_complete(stream.readline())
375        self.assertEqual(b'line\n', data)
376
377        stream.set_exception(ValueError())
378        self.assertRaises(
379            ValueError, self.loop.run_until_complete, stream.readline())
380        self.assertEqual(b'', stream._buffer)
381
382    def test_readuntil_separator(self):
383        stream = asyncio.StreamReader(loop=self.loop)
384        with self.assertRaisesRegex(ValueError, 'Separator should be'):
385            self.loop.run_until_complete(stream.readuntil(separator=b''))
386
387    def test_readuntil_multi_chunks(self):
388        stream = asyncio.StreamReader(loop=self.loop)
389
390        stream.feed_data(b'lineAAA')
391        data = self.loop.run_until_complete(stream.readuntil(separator=b'AAA'))
392        self.assertEqual(b'lineAAA', data)
393        self.assertEqual(b'', stream._buffer)
394
395        stream.feed_data(b'lineAAA')
396        data = self.loop.run_until_complete(stream.readuntil(b'AAA'))
397        self.assertEqual(b'lineAAA', data)
398        self.assertEqual(b'', stream._buffer)
399
400        stream.feed_data(b'lineAAAxxx')
401        data = self.loop.run_until_complete(stream.readuntil(b'AAA'))
402        self.assertEqual(b'lineAAA', data)
403        self.assertEqual(b'xxx', stream._buffer)
404
405    def test_readuntil_multi_chunks_1(self):
406        stream = asyncio.StreamReader(loop=self.loop)
407
408        stream.feed_data(b'QWEaa')
409        stream.feed_data(b'XYaa')
410        stream.feed_data(b'a')
411        data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
412        self.assertEqual(b'QWEaaXYaaa', data)
413        self.assertEqual(b'', stream._buffer)
414
415        stream.feed_data(b'QWEaa')
416        stream.feed_data(b'XYa')
417        stream.feed_data(b'aa')
418        data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
419        self.assertEqual(b'QWEaaXYaaa', data)
420        self.assertEqual(b'', stream._buffer)
421
422        stream.feed_data(b'aaa')
423        data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
424        self.assertEqual(b'aaa', data)
425        self.assertEqual(b'', stream._buffer)
426
427        stream.feed_data(b'Xaaa')
428        data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
429        self.assertEqual(b'Xaaa', data)
430        self.assertEqual(b'', stream._buffer)
431
432        stream.feed_data(b'XXX')
433        stream.feed_data(b'a')
434        stream.feed_data(b'a')
435        stream.feed_data(b'a')
436        data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
437        self.assertEqual(b'XXXaaa', data)
438        self.assertEqual(b'', stream._buffer)
439
440    def test_readuntil_eof(self):
441        stream = asyncio.StreamReader(loop=self.loop)
442        data = b'some dataAA'
443        stream.feed_data(data)
444        stream.feed_eof()
445
446        with self.assertRaisesRegex(asyncio.IncompleteReadError,
447                                    'undefined expected bytes') as cm:
448            self.loop.run_until_complete(stream.readuntil(b'AAA'))
449        self.assertEqual(cm.exception.partial, data)
450        self.assertIsNone(cm.exception.expected)
451        self.assertEqual(b'', stream._buffer)
452
453    def test_readuntil_limit_found_sep(self):
454        stream = asyncio.StreamReader(loop=self.loop, limit=3)
455        stream.feed_data(b'some dataAA')
456        with self.assertRaisesRegex(asyncio.LimitOverrunError,
457                                    'not found') as cm:
458            self.loop.run_until_complete(stream.readuntil(b'AAA'))
459
460        self.assertEqual(b'some dataAA', stream._buffer)
461
462        stream.feed_data(b'A')
463        with self.assertRaisesRegex(asyncio.LimitOverrunError,
464                                    'is found') as cm:
465            self.loop.run_until_complete(stream.readuntil(b'AAA'))
466
467        self.assertEqual(b'some dataAAA', stream._buffer)
468
469    def test_readexactly_zero_or_less(self):
470        # Read exact number of bytes (zero or less).
471        stream = asyncio.StreamReader(loop=self.loop)
472        stream.feed_data(self.DATA)
473
474        data = self.loop.run_until_complete(stream.readexactly(0))
475        self.assertEqual(b'', data)
476        self.assertEqual(self.DATA, stream._buffer)
477
478        with self.assertRaisesRegex(ValueError, 'less than zero'):
479            self.loop.run_until_complete(stream.readexactly(-1))
480        self.assertEqual(self.DATA, stream._buffer)
481
482    def test_readexactly(self):
483        # Read exact number of bytes.
484        stream = asyncio.StreamReader(loop=self.loop)
485
486        n = 2 * len(self.DATA)
487        read_task = self.loop.create_task(stream.readexactly(n))
488
489        def cb():
490            stream.feed_data(self.DATA)
491            stream.feed_data(self.DATA)
492            stream.feed_data(self.DATA)
493        self.loop.call_soon(cb)
494
495        data = self.loop.run_until_complete(read_task)
496        self.assertEqual(self.DATA + self.DATA, data)
497        self.assertEqual(self.DATA, stream._buffer)
498
499    def test_readexactly_limit(self):
500        stream = asyncio.StreamReader(limit=3, loop=self.loop)
501        stream.feed_data(b'chunk')
502        data = self.loop.run_until_complete(stream.readexactly(5))
503        self.assertEqual(b'chunk', data)
504        self.assertEqual(b'', stream._buffer)
505
506    def test_readexactly_eof(self):
507        # Read exact number of bytes (eof).
508        stream = asyncio.StreamReader(loop=self.loop)
509        n = 2 * len(self.DATA)
510        read_task = self.loop.create_task(stream.readexactly(n))
511
512        def cb():
513            stream.feed_data(self.DATA)
514            stream.feed_eof()
515        self.loop.call_soon(cb)
516
517        with self.assertRaises(asyncio.IncompleteReadError) as cm:
518            self.loop.run_until_complete(read_task)
519        self.assertEqual(cm.exception.partial, self.DATA)
520        self.assertEqual(cm.exception.expected, n)
521        self.assertEqual(str(cm.exception),
522                         '18 bytes read on a total of 36 expected bytes')
523        self.assertEqual(b'', stream._buffer)
524
525    def test_readexactly_exception(self):
526        stream = asyncio.StreamReader(loop=self.loop)
527        stream.feed_data(b'line\n')
528
529        data = self.loop.run_until_complete(stream.readexactly(2))
530        self.assertEqual(b'li', data)
531
532        stream.set_exception(ValueError())
533        self.assertRaises(
534            ValueError, self.loop.run_until_complete, stream.readexactly(2))
535
536    def test_exception(self):
537        stream = asyncio.StreamReader(loop=self.loop)
538        self.assertIsNone(stream.exception())
539
540        exc = ValueError()
541        stream.set_exception(exc)
542        self.assertIs(stream.exception(), exc)
543
544    def test_exception_waiter(self):
545        stream = asyncio.StreamReader(loop=self.loop)
546
547        async def set_err():
548            stream.set_exception(ValueError())
549
550        t1 = self.loop.create_task(stream.readline())
551        t2 = self.loop.create_task(set_err())
552
553        self.loop.run_until_complete(asyncio.wait([t1, t2]))
554
555        self.assertRaises(ValueError, t1.result)
556
557    def test_exception_cancel(self):
558        stream = asyncio.StreamReader(loop=self.loop)
559
560        t = self.loop.create_task(stream.readline())
561        test_utils.run_briefly(self.loop)
562        t.cancel()
563        test_utils.run_briefly(self.loop)
564        # The following line fails if set_exception() isn't careful.
565        stream.set_exception(RuntimeError('message'))
566        test_utils.run_briefly(self.loop)
567        self.assertIs(stream._waiter, None)
568
569    def test_start_server(self):
570
571        class MyServer:
572
573            def __init__(self, loop):
574                self.server = None
575                self.loop = loop
576
577            async def handle_client(self, client_reader, client_writer):
578                data = await client_reader.readline()
579                client_writer.write(data)
580                await client_writer.drain()
581                client_writer.close()
582                await client_writer.wait_closed()
583
584            def start(self):
585                sock = socket.create_server(('127.0.0.1', 0))
586                self.server = self.loop.run_until_complete(
587                    asyncio.start_server(self.handle_client,
588                                         sock=sock))
589                return sock.getsockname()
590
591            def handle_client_callback(self, client_reader, client_writer):
592                self.loop.create_task(self.handle_client(client_reader,
593                                                         client_writer))
594
595            def start_callback(self):
596                sock = socket.create_server(('127.0.0.1', 0))
597                addr = sock.getsockname()
598                sock.close()
599                self.server = self.loop.run_until_complete(
600                    asyncio.start_server(self.handle_client_callback,
601                                         host=addr[0], port=addr[1]))
602                return addr
603
604            def stop(self):
605                if self.server is not None:
606                    self.server.close()
607                    self.loop.run_until_complete(self.server.wait_closed())
608                    self.server = None
609
610        async def client(addr):
611            reader, writer = await asyncio.open_connection(*addr)
612            # send a line
613            writer.write(b"hello world!\n")
614            # read it back
615            msgback = await reader.readline()
616            writer.close()
617            await writer.wait_closed()
618            return msgback
619
620        messages = []
621        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
622
623        # test the server variant with a coroutine as client handler
624        server = MyServer(self.loop)
625        addr = server.start()
626        msg = self.loop.run_until_complete(self.loop.create_task(client(addr)))
627        server.stop()
628        self.assertEqual(msg, b"hello world!\n")
629
630        # test the server variant with a callback as client handler
631        server = MyServer(self.loop)
632        addr = server.start_callback()
633        msg = self.loop.run_until_complete(self.loop.create_task(client(addr)))
634        server.stop()
635        self.assertEqual(msg, b"hello world!\n")
636
637        self.assertEqual(messages, [])
638
639    @socket_helper.skip_unless_bind_unix_socket
640    def test_start_unix_server(self):
641
642        class MyServer:
643
644            def __init__(self, loop, path):
645                self.server = None
646                self.loop = loop
647                self.path = path
648
649            async def handle_client(self, client_reader, client_writer):
650                data = await client_reader.readline()
651                client_writer.write(data)
652                await client_writer.drain()
653                client_writer.close()
654                await client_writer.wait_closed()
655
656            def start(self):
657                self.server = self.loop.run_until_complete(
658                    asyncio.start_unix_server(self.handle_client,
659                                              path=self.path))
660
661            def handle_client_callback(self, client_reader, client_writer):
662                self.loop.create_task(self.handle_client(client_reader,
663                                                         client_writer))
664
665            def start_callback(self):
666                start = asyncio.start_unix_server(self.handle_client_callback,
667                                                  path=self.path)
668                self.server = self.loop.run_until_complete(start)
669
670            def stop(self):
671                if self.server is not None:
672                    self.server.close()
673                    self.loop.run_until_complete(self.server.wait_closed())
674                    self.server = None
675
676        async def client(path):
677            reader, writer = await asyncio.open_unix_connection(path)
678            # send a line
679            writer.write(b"hello world!\n")
680            # read it back
681            msgback = await reader.readline()
682            writer.close()
683            await writer.wait_closed()
684            return msgback
685
686        messages = []
687        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
688
689        # test the server variant with a coroutine as client handler
690        with test_utils.unix_socket_path() as path:
691            server = MyServer(self.loop, path)
692            server.start()
693            msg = self.loop.run_until_complete(
694                self.loop.create_task(client(path)))
695            server.stop()
696            self.assertEqual(msg, b"hello world!\n")
697
698        # test the server variant with a callback as client handler
699        with test_utils.unix_socket_path() as path:
700            server = MyServer(self.loop, path)
701            server.start_callback()
702            msg = self.loop.run_until_complete(
703                self.loop.create_task(client(path)))
704            server.stop()
705            self.assertEqual(msg, b"hello world!\n")
706
707        self.assertEqual(messages, [])
708
709    @unittest.skipIf(sys.platform == 'win32', "Don't have pipes")
710    def test_read_all_from_pipe_reader(self):
711        # See asyncio issue 168.  This test is derived from the example
712        # subprocess_attach_read_pipe.py, but we configure the
713        # StreamReader's limit so that twice it is less than the size
714        # of the data writer.  Also we must explicitly attach a child
715        # watcher to the event loop.
716
717        code = """\
718import os, sys
719fd = int(sys.argv[1])
720os.write(fd, b'data')
721os.close(fd)
722"""
723        rfd, wfd = os.pipe()
724        args = [sys.executable, '-c', code, str(wfd)]
725
726        pipe = open(rfd, 'rb', 0)
727        reader = asyncio.StreamReader(loop=self.loop, limit=1)
728        protocol = asyncio.StreamReaderProtocol(reader, loop=self.loop)
729        transport, _ = self.loop.run_until_complete(
730            self.loop.connect_read_pipe(lambda: protocol, pipe))
731
732        watcher = asyncio.SafeChildWatcher()
733        watcher.attach_loop(self.loop)
734        try:
735            asyncio.set_child_watcher(watcher)
736            create = asyncio.create_subprocess_exec(
737                *args,
738                pass_fds={wfd},
739            )
740            proc = self.loop.run_until_complete(create)
741            self.loop.run_until_complete(proc.wait())
742        finally:
743            asyncio.set_child_watcher(None)
744
745        os.close(wfd)
746        data = self.loop.run_until_complete(reader.read(-1))
747        self.assertEqual(data, b'data')
748
749    def test_streamreader_constructor_without_loop(self):
750        with self.assertWarns(DeprecationWarning) as cm:
751            with self.assertRaisesRegex(RuntimeError, 'There is no current event loop'):
752                asyncio.StreamReader()
753        self.assertEqual(cm.warnings[0].filename, __file__)
754
755    def test_streamreader_constructor_use_running_loop(self):
756        # asyncio issue #184: Ensure that StreamReaderProtocol constructor
757        # retrieves the current loop if the loop parameter is not set
758        async def test():
759            return asyncio.StreamReader()
760
761        reader = self.loop.run_until_complete(test())
762        self.assertIs(reader._loop, self.loop)
763
764    def test_streamreader_constructor_use_global_loop(self):
765        # asyncio issue #184: Ensure that StreamReaderProtocol constructor
766        # retrieves the current loop if the loop parameter is not set
767        # Deprecated in 3.10
768        self.addCleanup(asyncio.set_event_loop, None)
769        asyncio.set_event_loop(self.loop)
770        with self.assertWarns(DeprecationWarning) as cm:
771            reader = asyncio.StreamReader()
772        self.assertEqual(cm.warnings[0].filename, __file__)
773        self.assertIs(reader._loop, self.loop)
774
775
776    def test_streamreaderprotocol_constructor_without_loop(self):
777        reader = mock.Mock()
778        with self.assertWarns(DeprecationWarning) as cm:
779            with self.assertRaisesRegex(RuntimeError, 'There is no current event loop'):
780                asyncio.StreamReaderProtocol(reader)
781        self.assertEqual(cm.warnings[0].filename, __file__)
782
783    def test_streamreaderprotocol_constructor_use_running_loop(self):
784        # asyncio issue #184: Ensure that StreamReaderProtocol constructor
785        # retrieves the current loop if the loop parameter is not set
786        reader = mock.Mock()
787        async def test():
788            return asyncio.StreamReaderProtocol(reader)
789        protocol = self.loop.run_until_complete(test())
790        self.assertIs(protocol._loop, self.loop)
791
792    def test_streamreaderprotocol_constructor_use_global_loop(self):
793        # asyncio issue #184: Ensure that StreamReaderProtocol constructor
794        # retrieves the current loop if the loop parameter is not set
795        # Deprecated in 3.10
796        self.addCleanup(asyncio.set_event_loop, None)
797        asyncio.set_event_loop(self.loop)
798        reader = mock.Mock()
799        with self.assertWarns(DeprecationWarning) as cm:
800            protocol = asyncio.StreamReaderProtocol(reader)
801        self.assertEqual(cm.warnings[0].filename, __file__)
802        self.assertIs(protocol._loop, self.loop)
803
804    def test_drain_raises(self):
805        # See http://bugs.python.org/issue25441
806
807        # This test should not use asyncio for the mock server; the
808        # whole point of the test is to test for a bug in drain()
809        # where it never gives up the event loop but the socket is
810        # closed on the  server side.
811
812        messages = []
813        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
814        q = queue.Queue()
815
816        def server():
817            # Runs in a separate thread.
818            with socket.create_server(('localhost', 0)) as sock:
819                addr = sock.getsockname()
820                q.put(addr)
821                clt, _ = sock.accept()
822                clt.close()
823
824        async def client(host, port):
825            reader, writer = await asyncio.open_connection(host, port)
826
827            while True:
828                writer.write(b"foo\n")
829                await writer.drain()
830
831        # Start the server thread and wait for it to be listening.
832        thread = threading.Thread(target=server)
833        thread.daemon = True
834        thread.start()
835        addr = q.get()
836
837        # Should not be stuck in an infinite loop.
838        with self.assertRaises((ConnectionResetError, ConnectionAbortedError,
839                                BrokenPipeError)):
840            self.loop.run_until_complete(client(*addr))
841
842        # Clean up the thread.  (Only on success; on failure, it may
843        # be stuck in accept().)
844        thread.join()
845        self.assertEqual([], messages)
846
847    def test___repr__(self):
848        stream = asyncio.StreamReader(loop=self.loop)
849        self.assertEqual("<StreamReader>", repr(stream))
850
851    def test___repr__nondefault_limit(self):
852        stream = asyncio.StreamReader(loop=self.loop, limit=123)
853        self.assertEqual("<StreamReader limit=123>", repr(stream))
854
855    def test___repr__eof(self):
856        stream = asyncio.StreamReader(loop=self.loop)
857        stream.feed_eof()
858        self.assertEqual("<StreamReader eof>", repr(stream))
859
860    def test___repr__data(self):
861        stream = asyncio.StreamReader(loop=self.loop)
862        stream.feed_data(b'data')
863        self.assertEqual("<StreamReader 4 bytes>", repr(stream))
864
865    def test___repr__exception(self):
866        stream = asyncio.StreamReader(loop=self.loop)
867        exc = RuntimeError()
868        stream.set_exception(exc)
869        self.assertEqual("<StreamReader exception=RuntimeError()>",
870                         repr(stream))
871
872    def test___repr__waiter(self):
873        stream = asyncio.StreamReader(loop=self.loop)
874        stream._waiter = asyncio.Future(loop=self.loop)
875        self.assertRegex(
876            repr(stream),
877            r"<StreamReader waiter=<Future pending[\S ]*>>")
878        stream._waiter.set_result(None)
879        self.loop.run_until_complete(stream._waiter)
880        stream._waiter = None
881        self.assertEqual("<StreamReader>", repr(stream))
882
883    def test___repr__transport(self):
884        stream = asyncio.StreamReader(loop=self.loop)
885        stream._transport = mock.Mock()
886        stream._transport.__repr__ = mock.Mock()
887        stream._transport.__repr__.return_value = "<Transport>"
888        self.assertEqual("<StreamReader transport=<Transport>>", repr(stream))
889
890    def test_IncompleteReadError_pickleable(self):
891        e = asyncio.IncompleteReadError(b'abc', 10)
892        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
893            with self.subTest(pickle_protocol=proto):
894                e2 = pickle.loads(pickle.dumps(e, protocol=proto))
895                self.assertEqual(str(e), str(e2))
896                self.assertEqual(e.partial, e2.partial)
897                self.assertEqual(e.expected, e2.expected)
898
899    def test_LimitOverrunError_pickleable(self):
900        e = asyncio.LimitOverrunError('message', 10)
901        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
902            with self.subTest(pickle_protocol=proto):
903                e2 = pickle.loads(pickle.dumps(e, protocol=proto))
904                self.assertEqual(str(e), str(e2))
905                self.assertEqual(e.consumed, e2.consumed)
906
907    def test_wait_closed_on_close(self):
908        with test_utils.run_test_server() as httpd:
909            rd, wr = self.loop.run_until_complete(
910                asyncio.open_connection(*httpd.address))
911
912            wr.write(b'GET / HTTP/1.0\r\n\r\n')
913            f = rd.readline()
914            data = self.loop.run_until_complete(f)
915            self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
916            f = rd.read()
917            data = self.loop.run_until_complete(f)
918            self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
919            self.assertFalse(wr.is_closing())
920            wr.close()
921            self.assertTrue(wr.is_closing())
922            self.loop.run_until_complete(wr.wait_closed())
923
924    def test_wait_closed_on_close_with_unread_data(self):
925        with test_utils.run_test_server() as httpd:
926            rd, wr = self.loop.run_until_complete(
927                asyncio.open_connection(*httpd.address))
928
929            wr.write(b'GET / HTTP/1.0\r\n\r\n')
930            f = rd.readline()
931            data = self.loop.run_until_complete(f)
932            self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
933            wr.close()
934            self.loop.run_until_complete(wr.wait_closed())
935
936    def test_async_writer_api(self):
937        async def inner(httpd):
938            rd, wr = await asyncio.open_connection(*httpd.address)
939
940            wr.write(b'GET / HTTP/1.0\r\n\r\n')
941            data = await rd.readline()
942            self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
943            data = await rd.read()
944            self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
945            wr.close()
946            await wr.wait_closed()
947
948        messages = []
949        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
950
951        with test_utils.run_test_server() as httpd:
952            self.loop.run_until_complete(inner(httpd))
953
954        self.assertEqual(messages, [])
955
956    def test_async_writer_api_exception_after_close(self):
957        async def inner(httpd):
958            rd, wr = await asyncio.open_connection(*httpd.address)
959
960            wr.write(b'GET / HTTP/1.0\r\n\r\n')
961            data = await rd.readline()
962            self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
963            data = await rd.read()
964            self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
965            wr.close()
966            with self.assertRaises(ConnectionResetError):
967                wr.write(b'data')
968                await wr.drain()
969
970        messages = []
971        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
972
973        with test_utils.run_test_server() as httpd:
974            self.loop.run_until_complete(inner(httpd))
975
976        self.assertEqual(messages, [])
977
978    def test_eof_feed_when_closing_writer(self):
979        # See http://bugs.python.org/issue35065
980        messages = []
981        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
982
983        with test_utils.run_test_server() as httpd:
984            rd, wr = self.loop.run_until_complete(
985                    asyncio.open_connection(*httpd.address))
986
987            wr.close()
988            f = wr.wait_closed()
989            self.loop.run_until_complete(f)
990            self.assertTrue(rd.at_eof())
991            f = rd.read()
992            data = self.loop.run_until_complete(f)
993            self.assertEqual(data, b'')
994
995        self.assertEqual(messages, [])
996
997
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
1000