• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Tests for transports.py."""
2
3import unittest
4from unittest import mock
5
6import asyncio
7from asyncio import transports
8
9
def tearDownModule():
    """Reset asyncio's event loop policy once this module's tests are done.

    This file's tests do not need the reset themselves. It is kept so the
    module follows the same cleanup convention as every other asyncio test
    file.
    """
    # Per the asyncio docs, a None policy restores the default policy.
    asyncio.set_event_loop_policy(policy=None)
14
15
16class TransportTests(unittest.TestCase):
17
18    def test_ctor_extra_is_none(self):
19        transport = asyncio.Transport()
20        self.assertEqual(transport._extra, {})
21
22    def test_get_extra_info(self):
23        transport = asyncio.Transport({'extra': 'info'})
24        self.assertEqual('info', transport.get_extra_info('extra'))
25        self.assertIsNone(transport.get_extra_info('unknown'))
26
27        default = object()
28        self.assertIs(default, transport.get_extra_info('unknown', default))
29
30    def test_writelines(self):
31        writer = mock.Mock()
32
33        class MyTransport(asyncio.Transport):
34            def write(self, data):
35                writer(data)
36
37        transport = MyTransport()
38
39        transport.writelines([b'line1',
40                              bytearray(b'line2'),
41                              memoryview(b'line3')])
42        self.assertEqual(1, writer.call_count)
43        writer.assert_called_with(b'line1line2line3')
44
45    def test_not_implemented(self):
46        transport = asyncio.Transport()
47
48        self.assertRaises(NotImplementedError,
49                          transport.set_write_buffer_limits)
50        self.assertRaises(NotImplementedError, transport.get_write_buffer_size)
51        self.assertRaises(NotImplementedError, transport.write, 'data')
52        self.assertRaises(NotImplementedError, transport.write_eof)
53        self.assertRaises(NotImplementedError, transport.can_write_eof)
54        self.assertRaises(NotImplementedError, transport.pause_reading)
55        self.assertRaises(NotImplementedError, transport.resume_reading)
56        self.assertRaises(NotImplementedError, transport.is_reading)
57        self.assertRaises(NotImplementedError, transport.close)
58        self.assertRaises(NotImplementedError, transport.abort)
59
60    def test_dgram_not_implemented(self):
61        transport = asyncio.DatagramTransport()
62
63        self.assertRaises(NotImplementedError, transport.sendto, 'data')
64        self.assertRaises(NotImplementedError, transport.abort)
65
66    def test_subprocess_transport_not_implemented(self):
67        transport = asyncio.SubprocessTransport()
68
69        self.assertRaises(NotImplementedError, transport.get_pid)
70        self.assertRaises(NotImplementedError, transport.get_returncode)
71        self.assertRaises(NotImplementedError, transport.get_pipe_transport, 1)
72        self.assertRaises(NotImplementedError, transport.send_signal, 1)
73        self.assertRaises(NotImplementedError, transport.terminate)
74        self.assertRaises(NotImplementedError, transport.kill)
75
76    def test_flowcontrol_mixin_set_write_limits(self):
77
78        class MyTransport(transports._FlowControlMixin,
79                          transports.Transport):
80
81            def get_write_buffer_size(self):
82                return 512
83
84        loop = mock.Mock()
85        transport = MyTransport(loop=loop)
86        transport._protocol = mock.Mock()
87
88        self.assertFalse(transport._protocol_paused)
89
90        with self.assertRaisesRegex(ValueError, 'high.*must be >= low'):
91            transport.set_write_buffer_limits(high=0, low=1)
92
93        transport.set_write_buffer_limits(high=1024, low=128)
94        self.assertFalse(transport._protocol_paused)
95        self.assertEqual(transport.get_write_buffer_limits(), (128, 1024))
96
97        transport.set_write_buffer_limits(high=256, low=128)
98        self.assertTrue(transport._protocol_paused)
99        self.assertEqual(transport.get_write_buffer_limits(), (128, 256))
100
101
if __name__ == '__main__':
    # Running the file directly executes its tests. module='__main__' is
    # unittest.main()'s own default and is only spelled out for clarity.
    unittest.main(module='__main__')
104