• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import unittest
2from test import test_support
3
4import errno
5import itertools
6import socket
7import select
8import time
9import traceback
10import Queue
11import sys
12import os
13import array
14import contextlib
15import signal
16import math
17import weakref
18try:
19    import _socket
20except ImportError:
21    _socket = None
22
23
def try_address(host, port=0, family=socket.AF_INET):
    """Return True if a TCP socket of *family* can be bound to (host, port).

    Used to probe whether an address (e.g. the IPv6 loopback) is usable on
    this machine.  Returns False if either creating the socket or binding
    it fails.  The probe socket is always closed, including when bind()
    raises.
    """
    try:
        sock = socket.socket(family, socket.SOCK_STREAM)
    except (socket.error, socket.gaierror):
        # e.g. the address family itself is not supported
        return False
    try:
        sock.bind((host, port))
    except (socket.error, socket.gaierror):
        return False
    else:
        return True
    finally:
        # Runs before either return above, so the socket never leaks.
        sock.close()
35
HOST = test_support.HOST
MSG = 'Michael Gilfix was here\n'
# True only when the interpreter has IPv6 support *and* the IPv6 loopback
# address can actually be bound on this machine.
SUPPORTS_IPV6 = socket.has_ipv6 and try_address('::1', family=socket.AF_INET6)

# The threaded client/server tests need the thread modules.  When they are
# unavailable these names are None, and tests decorated with
# unittest.skipUnless(thread, ...) are skipped.
try:
    import thread
    import threading
except ImportError:
    thread = None
    threading = None
49
class SocketTCPTest(unittest.TestCase):
    # Fixture: a listening AF_INET stream socket (backlog 1) in self.serv,
    # bound via test_support.bind_port(); the bound port is in self.port.

    def setUp(self):
        server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.serv = server
        self.port = test_support.bind_port(server)
        server.listen(1)

    def tearDown(self):
        server = self.serv
        server.close()
        self.serv = None
60
class SocketUDPTest(unittest.TestCase):
    # Fixture: an AF_INET datagram socket in self.serv, bound via
    # test_support.bind_port(); the bound port is in self.port.

    def setUp(self):
        server = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.serv = server
        self.port = test_support.bind_port(server)

    def tearDown(self):
        server = self.serv
        server.close()
        self.serv = None
70
class ThreadableTest:
    """Threadable Test class

    The ThreadableTest class makes it easy to create a threaded
    client/server pair from an existing unit test. To create a
    new threaded class from an existing unit test, use multiple
    inheritance:

        class NewClass (OldClass, ThreadableTest):
            pass

    This class defines two new fixture functions with obvious
    purposes for overriding:

        clientSetUp ()
        clientTearDown ()

    Any new test functions within the class must then define
    tests in pairs, where the test name is preceded with a
    '_' to indicate the client portion of the test. Ex:

        def testFoo(self):
            # Server portion

        def _testFoo(self):
            # Client portion

    Any exceptions raised by the clients during their tests (including
    failures in clientSetUp()) are caught and transferred to the main
    thread to alert the testing framework.

    Note, the server setup function cannot call any blocking
    functions that rely on the client thread during setup,
    unless serverExplicitReady() is called just before
    the blocking call (such as in setting up a client/server
    connection and performing the accept() in setUp()).
    """

    def __init__(self):
        # Swap the true setup function: unittest will call _setUp/_tearDown,
        # which wrap the subclass's own setUp/tearDown.
        self.__setUp = self.setUp
        self.__tearDown = self.tearDown
        self.setUp = self._setUp
        self.tearDown = self._tearDown

    def serverExplicitReady(self):
        """This method allows the server to explicitly indicate that
        it wants the client thread to proceed. This is useful if the
        server is about to execute a blocking routine that is
        dependent upon the client thread during its setup routine."""
        self.server_ready.set()

    def _setUp(self):
        # server_ready: the server side lets the client thread start.
        # client_ready: clientSetUp() has run (or failed).
        # done:         the client thread has finished.
        # queue:        carries at most one client-side exception.
        self.server_ready = threading.Event()
        self.client_ready = threading.Event()
        self.done = threading.Event()
        self.queue = Queue.Queue(1)

        # Do some munging to start the client test: the client half of
        # "testFoo" is the method "_testFoo".
        methodname = self.id()
        i = methodname.rfind('.')
        methodname = methodname[i+1:]
        test_method = getattr(self, '_' + methodname)
        self.client_thread = thread.start_new_thread(
            self.clientRun, (test_method,))

        self.__setUp()
        if not self.server_ready.is_set():
            self.server_ready.set()
        self.client_ready.wait()

    def _tearDown(self):
        self.__tearDown()
        self.done.wait()

        # Re-raise (as a test failure) anything the client thread reported.
        if not self.queue.empty():
            msg = self.queue.get()
            self.fail(msg)

    def clientRun(self, test_func):
        """Client-thread entry point: set up, run *test_func*, tear down.

        Exceptions are put on self.queue for _tearDown() to report.  The
        main thread is always released: client_ready and done are set even
        when clientSetUp() fails, so _setUp()/_tearDown() cannot hang.
        """
        self.server_ready.wait()
        try:
            self.clientSetUp()
        except Exception as e:
            # Report the failure and unblock the main thread, which waits
            # on client_ready in _setUp() and on done in _tearDown().
            # clientTearDown() is skipped because the client state it
            # cleans up may never have been created.
            self.queue.put(e)
            self.client_ready.set()
            self.done.set()
            return
        self.client_ready.set()
        try:
            if not callable(test_func):
                raise TypeError("test_func must be a callable function.")
            test_func()
        except Exception as e:
            self.queue.put(e)
        finally:
            # Always tear down; the base clientTearDown() sets self.done.
            self.clientTearDown()

    def clientSetUp(self):
        raise NotImplementedError("clientSetUp must be implemented.")

    def clientTearDown(self):
        # Signal completion to _tearDown(), then end the client thread.
        self.done.set()
        thread.exit()
168
class ThreadedTCPSocketTest(SocketTCPTest, ThreadableTest):
    # The TCP server fixture of SocketTCPTest, plus a client thread whose
    # unconnected AF_INET stream socket is kept in self.cli.

    def __init__(self, methodName='runTest'):
        SocketTCPTest.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        self.cli = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    def clientTearDown(self):
        client = self.cli
        client.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)
182
class ThreadedUDPSocketTest(SocketUDPTest, ThreadableTest):
    # The UDP server fixture of SocketUDPTest, plus a client thread whose
    # AF_INET datagram socket is kept in self.cli.

    def __init__(self, methodName='runTest'):
        SocketUDPTest.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        self.cli = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

    def clientTearDown(self):
        client = self.cli
        client.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)
196
class SocketConnectedTest(ThreadedTCPSocketTest):
    # A connected TCP pair.  The server side sees the accepted connection
    # as self.cli_conn; the client side sees its connected socket as
    # self.serv_conn (the same object as self.cli).

    def __init__(self, methodName='runTest'):
        ThreadedTCPSocketTest.__init__(self, methodName=methodName)

    def setUp(self):
        ThreadedTCPSocketTest.setUp(self)
        # accept() below blocks until the client connects, so the client
        # thread has to be released before calling it.
        self.serverExplicitReady()
        accepted = self.serv.accept()
        self.cli_conn = accepted[0]

    def tearDown(self):
        conn = self.cli_conn
        conn.close()
        self.cli_conn = None
        ThreadedTCPSocketTest.tearDown(self)

    def clientSetUp(self):
        ThreadedTCPSocketTest.clientSetUp(self)
        client = self.cli
        client.connect((HOST, self.port))
        self.serv_conn = client

    def clientTearDown(self):
        conn = self.serv_conn
        conn.close()
        self.serv_conn = None
        ThreadedTCPSocketTest.clientTearDown(self)
224
class SocketPairTest(unittest.TestCase, ThreadableTest):
    # Both ends come from socket.socketpair(): tearDown() closes self.serv,
    # and the client thread's clientTearDown() closes self.cli.

    def __init__(self, methodName='runTest'):
        unittest.TestCase.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def setUp(self):
        pair = socket.socketpair()
        self.serv, self.cli = pair

    def tearDown(self):
        server = self.serv
        server.close()
        self.serv = None

    def clientSetUp(self):
        # Nothing to do: self.cli was already created in setUp().
        pass

    def clientTearDown(self):
        client = self.cli
        client.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)
245
246
247#######################################################################
248## Begin Tests
249
250class GeneralModuleTests(unittest.TestCase):
251
252    @unittest.skipUnless(_socket is not None, 'need _socket module')
253    def test_csocket_repr(self):
254        s = _socket.socket(_socket.AF_INET, _socket.SOCK_STREAM)
255        try:
256            expected = ('<socket object, fd=%s, family=%s, type=%s, protocol=%s>'
257                        % (s.fileno(), s.family, s.type, s.proto))
258            self.assertEqual(repr(s), expected)
259        finally:
260            s.close()
261        expected = ('<socket object, fd=-1, family=%s, type=%s, protocol=%s>'
262                    % (s.family, s.type, s.proto))
263        self.assertEqual(repr(s), expected)
264
265    def test_weakref(self):
266        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
267        p = weakref.proxy(s)
268        self.assertEqual(p.fileno(), s.fileno())
269        s.close()
270        s = None
271        try:
272            p.fileno()
273        except ReferenceError:
274            pass
275        else:
276            self.fail('Socket proxy still exists')
277
278    def test_weakref__sock(self):
279        s = socket.socket()._sock
280        w = weakref.ref(s)
281        self.assertIs(w(), s)
282        del s
283        test_support.gc_collect()
284        self.assertIsNone(w())
285
286    def testSocketError(self):
287        # Testing socket module exceptions
288        def raise_error(*args, **kwargs):
289            raise socket.error
290        def raise_herror(*args, **kwargs):
291            raise socket.herror
292        def raise_gaierror(*args, **kwargs):
293            raise socket.gaierror
294        self.assertRaises(socket.error, raise_error,
295                              "Error raising socket exception.")
296        self.assertRaises(socket.error, raise_herror,
297                              "Error raising socket exception.")
298        self.assertRaises(socket.error, raise_gaierror,
299                              "Error raising socket exception.")
300
301    def testSendtoErrors(self):
302        # Testing that sendto doesn't mask failures. See #10169.
303        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
304        self.addCleanup(s.close)
305        s.bind(('', 0))
306        sockname = s.getsockname()
307        # 2 args
308        with self.assertRaises(UnicodeEncodeError):
309            s.sendto(u'\u2620', sockname)
310        with self.assertRaises(TypeError) as cm:
311            s.sendto(5j, sockname)
312        self.assertIn('not complex', str(cm.exception))
313        with self.assertRaises(TypeError) as cm:
314            s.sendto('foo', None)
315        self.assertIn('not NoneType', str(cm.exception))
316        # 3 args
317        with self.assertRaises(UnicodeEncodeError):
318            s.sendto(u'\u2620', 0, sockname)
319        with self.assertRaises(TypeError) as cm:
320            s.sendto(5j, 0, sockname)
321        self.assertIn('not complex', str(cm.exception))
322        with self.assertRaises(TypeError) as cm:
323            s.sendto('foo', 0, None)
324        self.assertIn('not NoneType', str(cm.exception))
325        with self.assertRaises(TypeError) as cm:
326            s.sendto('foo', 'bar', sockname)
327        self.assertIn('an integer is required', str(cm.exception))
328        with self.assertRaises(TypeError) as cm:
329            s.sendto('foo', None, None)
330        self.assertIn('an integer is required', str(cm.exception))
331        # wrong number of args
332        with self.assertRaises(TypeError) as cm:
333            s.sendto('foo')
334        self.assertIn('(1 given)', str(cm.exception))
335        with self.assertRaises(TypeError) as cm:
336            s.sendto('foo', 0, sockname, 4)
337        self.assertIn('(4 given)', str(cm.exception))
338
339
340    def testCrucialConstants(self):
341        # Testing for mission critical constants
342        socket.AF_INET
343        socket.SOCK_STREAM
344        socket.SOCK_DGRAM
345        socket.SOCK_RAW
346        socket.SOCK_RDM
347        socket.SOCK_SEQPACKET
348        socket.SOL_SOCKET
349        socket.SO_REUSEADDR
350
351    def testHostnameRes(self):
352        # Testing hostname resolution mechanisms
353        hostname = socket.gethostname()
354        try:
355            ip = socket.gethostbyname(hostname)
356        except socket.error:
357            # Probably name lookup wasn't set up right; skip this test
358            self.skipTest('name lookup failure')
359        self.assertTrue(ip.find('.') >= 0, "Error resolving host to ip.")
360        try:
361            hname, aliases, ipaddrs = socket.gethostbyaddr(ip)
362        except socket.error:
363            # Probably a similar problem as above; skip this test
364            self.skipTest('address lookup failure')
365        all_host_names = [hostname, hname] + aliases
366        fqhn = socket.getfqdn(ip)
367        if not fqhn in all_host_names:
368            self.fail("Error testing host resolution mechanisms. (fqdn: %s, all: %s)" % (fqhn, repr(all_host_names)))
369
370    @unittest.skipUnless(hasattr(sys, 'getrefcount'),
371                         'test needs sys.getrefcount()')
372    def testRefCountGetNameInfo(self):
373        # Testing reference count for getnameinfo
374        try:
375            # On some versions, this loses a reference
376            orig = sys.getrefcount(__name__)
377            socket.getnameinfo(__name__,0)
378        except TypeError:
379            self.assertEqual(sys.getrefcount(__name__), orig,
380                             "socket.getnameinfo loses a reference")
381
382    def testInterpreterCrash(self):
383        # Making sure getnameinfo doesn't crash the interpreter
384        try:
385            # On some versions, this crashes the interpreter.
386            socket.getnameinfo(('x', 0, 0, 0), 0)
387        except socket.error:
388            pass
389
390    def testNtoH(self):
391        # This just checks that htons etc. are their own inverse,
392        # when looking at the lower 16 or 32 bits.
393        sizes = {socket.htonl: 32, socket.ntohl: 32,
394                 socket.htons: 16, socket.ntohs: 16}
395        for func, size in sizes.items():
396            mask = (1L<<size) - 1
397            for i in (0, 1, 0xffff, ~0xffff, 2, 0x01234567, 0x76543210):
398                self.assertEqual(i & mask, func(func(i&mask)) & mask)
399
400            swapped = func(mask)
401            self.assertEqual(swapped & mask, mask)
402            self.assertRaises(OverflowError, func, 1L<<34)
403
404    def testNtoHErrors(self):
405        good_values = [ 1, 2, 3, 1L, 2L, 3L ]
406        bad_values = [ -1, -2, -3, -1L, -2L, -3L ]
407        for k in good_values:
408            socket.ntohl(k)
409            socket.ntohs(k)
410            socket.htonl(k)
411            socket.htons(k)
412        for k in bad_values:
413            self.assertRaises(OverflowError, socket.ntohl, k)
414            self.assertRaises(OverflowError, socket.ntohs, k)
415            self.assertRaises(OverflowError, socket.htonl, k)
416            self.assertRaises(OverflowError, socket.htons, k)
417
418    def testGetServBy(self):
419        eq = self.assertEqual
420        # Find one service that exists, then check all the related interfaces.
421        # I've ordered this by protocols that have both a tcp and udp
422        # protocol, at least for modern Linuxes.
423        if (sys.platform.startswith('linux') or
424            sys.platform.startswith('freebsd') or
425            sys.platform.startswith('netbsd') or
426            sys.platform == 'darwin'):
427            # avoid the 'echo' service on this platform, as there is an
428            # assumption breaking non-standard port/protocol entry
429            services = ('daytime', 'qotd', 'domain')
430        else:
431            services = ('echo', 'daytime', 'domain')
432        for service in services:
433            try:
434                port = socket.getservbyname(service, 'tcp')
435                break
436            except socket.error:
437                pass
438        else:
439            raise socket.error
440        # Try same call with optional protocol omitted
441        port2 = socket.getservbyname(service)
442        eq(port, port2)
443        # Try udp, but don't barf if it doesn't exist
444        try:
445            udpport = socket.getservbyname(service, 'udp')
446        except socket.error:
447            udpport = None
448        else:
449            eq(udpport, port)
450        # Now make sure the lookup by port returns the same service name
451        eq(socket.getservbyport(port2), service)
452        eq(socket.getservbyport(port, 'tcp'), service)
453        if udpport is not None:
454            eq(socket.getservbyport(udpport, 'udp'), service)
455        # Make sure getservbyport does not accept out of range ports.
456        self.assertRaises(OverflowError, socket.getservbyport, -1)
457        self.assertRaises(OverflowError, socket.getservbyport, 65536)
458
459    def testDefaultTimeout(self):
460        # Testing default timeout
461        # The default timeout should initially be None
462        self.assertEqual(socket.getdefaulttimeout(), None)
463        s = socket.socket()
464        self.assertEqual(s.gettimeout(), None)
465        s.close()
466
467        # Set the default timeout to 10, and see if it propagates
468        socket.setdefaulttimeout(10)
469        self.assertEqual(socket.getdefaulttimeout(), 10)
470        s = socket.socket()
471        self.assertEqual(s.gettimeout(), 10)
472        s.close()
473
474        # Reset the default timeout to None, and see if it propagates
475        socket.setdefaulttimeout(None)
476        self.assertEqual(socket.getdefaulttimeout(), None)
477        s = socket.socket()
478        self.assertEqual(s.gettimeout(), None)
479        s.close()
480
481        # Check that setting it to an invalid value raises ValueError
482        self.assertRaises(ValueError, socket.setdefaulttimeout, -1)
483
484        # Check that setting it to an invalid type raises TypeError
485        self.assertRaises(TypeError, socket.setdefaulttimeout, "spam")
486
487    @unittest.skipUnless(hasattr(socket, 'inet_aton'),
488                         'test needs socket.inet_aton()')
489    def testIPv4_inet_aton_fourbytes(self):
490        # Test that issue1008086 and issue767150 are fixed.
491        # It must return 4 bytes.
492        self.assertEqual('\x00'*4, socket.inet_aton('0.0.0.0'))
493        self.assertEqual('\xff'*4, socket.inet_aton('255.255.255.255'))
494
495    @unittest.skipUnless(hasattr(socket, 'inet_pton'),
496                         'test needs socket.inet_pton()')
497    def testIPv4toString(self):
498        from socket import inet_aton as f, inet_pton, AF_INET
499        g = lambda a: inet_pton(AF_INET, a)
500
501        self.assertEqual('\x00\x00\x00\x00', f('0.0.0.0'))
502        self.assertEqual('\xff\x00\xff\x00', f('255.0.255.0'))
503        self.assertEqual('\xaa\xaa\xaa\xaa', f('170.170.170.170'))
504        self.assertEqual('\x01\x02\x03\x04', f('1.2.3.4'))
505        self.assertEqual('\xff\xff\xff\xff', f('255.255.255.255'))
506
507        self.assertEqual('\x00\x00\x00\x00', g('0.0.0.0'))
508        self.assertEqual('\xff\x00\xff\x00', g('255.0.255.0'))
509        self.assertEqual('\xaa\xaa\xaa\xaa', g('170.170.170.170'))
510        self.assertEqual('\xff\xff\xff\xff', g('255.255.255.255'))
511
512    @unittest.skipUnless(hasattr(socket, 'inet_pton'),
513                         'test needs socket.inet_pton()')
514    def testIPv6toString(self):
515        try:
516            from socket import inet_pton, AF_INET6, has_ipv6
517            if not has_ipv6:
518                self.skipTest('IPv6 not available')
519        except ImportError:
520            self.skipTest('could not import needed symbols from socket')
521        f = lambda a: inet_pton(AF_INET6, a)
522
523        self.assertEqual('\x00' * 16, f('::'))
524        self.assertEqual('\x00' * 16, f('0::0'))
525        self.assertEqual('\x00\x01' + '\x00' * 14, f('1::'))
526        self.assertEqual(
527            '\x45\xef\x76\xcb\x00\x1a\x56\xef\xaf\xeb\x0b\xac\x19\x24\xae\xae',
528            f('45ef:76cb:1a:56ef:afeb:bac:1924:aeae')
529        )
530
531    @unittest.skipUnless(hasattr(socket, 'inet_ntop'),
532                         'test needs socket.inet_ntop()')
533    def testStringToIPv4(self):
534        from socket import inet_ntoa as f, inet_ntop, AF_INET
535        g = lambda a: inet_ntop(AF_INET, a)
536
537        self.assertEqual('1.0.1.0', f('\x01\x00\x01\x00'))
538        self.assertEqual('170.85.170.85', f('\xaa\x55\xaa\x55'))
539        self.assertEqual('255.255.255.255', f('\xff\xff\xff\xff'))
540        self.assertEqual('1.2.3.4', f('\x01\x02\x03\x04'))
541
542        self.assertEqual('1.0.1.0', g('\x01\x00\x01\x00'))
543        self.assertEqual('170.85.170.85', g('\xaa\x55\xaa\x55'))
544        self.assertEqual('255.255.255.255', g('\xff\xff\xff\xff'))
545
546    @unittest.skipUnless(hasattr(socket, 'inet_ntop'),
547                         'test needs socket.inet_ntop()')
548    def testStringToIPv6(self):
549        try:
550            from socket import inet_ntop, AF_INET6, has_ipv6
551            if not has_ipv6:
552                self.skipTest('IPv6 not available')
553        except ImportError:
554            self.skipTest('could not import needed symbols from socket')
555        f = lambda a: inet_ntop(AF_INET6, a)
556
557        self.assertEqual('::', f('\x00' * 16))
558        self.assertEqual('::1', f('\x00' * 15 + '\x01'))
559        self.assertEqual(
560            'aef:b01:506:1001:ffff:9997:55:170',
561            f('\x0a\xef\x0b\x01\x05\x06\x10\x01\xff\xff\x99\x97\x00\x55\x01\x70')
562        )
563
564    # XXX The following don't test module-level functionality...
565
566    def _get_unused_port(self, bind_address='0.0.0.0'):
567        """Use a temporary socket to elicit an unused ephemeral port.
568
569        Args:
570            bind_address: Hostname or IP address to search for a port on.
571
572        Returns: A most likely to be unused port.
573        """
574        tempsock = socket.socket()
575        tempsock.bind((bind_address, 0))
576        host, port = tempsock.getsockname()
577        tempsock.close()
578        return port
579
580    def testSockName(self):
581        # Testing getsockname()
582        port = self._get_unused_port()
583        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
584        self.addCleanup(sock.close)
585        sock.bind(("0.0.0.0", port))
586        name = sock.getsockname()
587        # XXX(nnorwitz): http://tinyurl.com/os5jz seems to indicate
588        # it reasonable to get the host's addr in addition to 0.0.0.0.
589        # At least for eCos.  This is required for the S/390 to pass.
590        try:
591            my_ip_addr = socket.gethostbyname(socket.gethostname())
592        except socket.error:
593            # Probably name lookup wasn't set up right; skip this test
594            self.skipTest('name lookup failure')
595        self.assertIn(name[0], ("0.0.0.0", my_ip_addr), '%s invalid' % name[0])
596        self.assertEqual(name[1], port)
597
598    def testGetSockOpt(self):
599        # Testing getsockopt()
600        # We know a socket should start without reuse==0
601        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
602        self.addCleanup(sock.close)
603        reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)
604        self.assertFalse(reuse != 0, "initial mode is reuse")
605
606    def testSetSockOpt(self):
607        # Testing setsockopt()
608        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
609        self.addCleanup(sock.close)
610        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
611        reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)
612        self.assertFalse(reuse == 0, "failed to set reuse mode")
613
614    def testSendAfterClose(self):
615        # testing send() after close() with timeout
616        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
617        sock.settimeout(1)
618        sock.close()
619        self.assertRaises(socket.error, sock.send, "spam")
620
621    def testNewAttributes(self):
622        # testing .family, .type and .protocol
623        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
624        self.assertEqual(sock.family, socket.AF_INET)
625        self.assertEqual(sock.type, socket.SOCK_STREAM)
626        self.assertEqual(sock.proto, 0)
627        sock.close()
628
629    def test_getsockaddrarg(self):
630        sock = socket.socket()
631        self.addCleanup(sock.close)
632        port = test_support.find_unused_port()
633        big_port = port + 65536
634        neg_port = port - 65536
635        self.assertRaises(OverflowError, sock.bind, (HOST, big_port))
636        self.assertRaises(OverflowError, sock.bind, (HOST, neg_port))
637        # Since find_unused_port() is inherently subject to race conditions, we
638        # call it a couple times if necessary.
639        for i in itertools.count():
640            port = test_support.find_unused_port()
641            try:
642                sock.bind((HOST, port))
643            except OSError as e:
644                if e.errno != errno.EADDRINUSE or i == 5:
645                    raise
646            else:
647                break
648
649    @unittest.skipUnless(os.name == "nt", "Windows specific")
650    def test_sock_ioctl(self):
651        self.assertTrue(hasattr(socket.socket, 'ioctl'))
652        self.assertTrue(hasattr(socket, 'SIO_RCVALL'))
653        self.assertTrue(hasattr(socket, 'RCVALL_ON'))
654        self.assertTrue(hasattr(socket, 'RCVALL_OFF'))
655        self.assertTrue(hasattr(socket, 'SIO_KEEPALIVE_VALS'))
656        s = socket.socket()
657        self.addCleanup(s.close)
658        self.assertRaises(ValueError, s.ioctl, -1, None)
659        s.ioctl(socket.SIO_KEEPALIVE_VALS, (1, 100, 100))
660
661    def testGetaddrinfo(self):
662        try:
663            socket.getaddrinfo('localhost', 80)
664        except socket.gaierror as err:
665            if err.errno == socket.EAI_SERVICE:
666                # see http://bugs.python.org/issue1282647
667                self.skipTest("buggy libc version")
668            raise
669        # len of every sequence is supposed to be == 5
670        for info in socket.getaddrinfo(HOST, None):
671            self.assertEqual(len(info), 5)
672        # host can be a domain name, a string representation of an
673        # IPv4/v6 address or None
674        socket.getaddrinfo('localhost', 80)
675        socket.getaddrinfo('127.0.0.1', 80)
676        socket.getaddrinfo(None, 80)
677        if SUPPORTS_IPV6:
678            socket.getaddrinfo('::1', 80)
679        # port can be a string service name such as "http", a numeric
680        # port number (int or long), or None
681        socket.getaddrinfo(HOST, "http")
682        socket.getaddrinfo(HOST, 80)
683        socket.getaddrinfo(HOST, 80L)
684        socket.getaddrinfo(HOST, None)
685        # test family and socktype filters
686        infos = socket.getaddrinfo(HOST, None, socket.AF_INET)
687        for family, _, _, _, _ in infos:
688            self.assertEqual(family, socket.AF_INET)
689        infos = socket.getaddrinfo(HOST, None, 0, socket.SOCK_STREAM)
690        for _, socktype, _, _, _ in infos:
691            self.assertEqual(socktype, socket.SOCK_STREAM)
692        # test proto and flags arguments
693        socket.getaddrinfo(HOST, None, 0, 0, socket.SOL_TCP)
694        socket.getaddrinfo(HOST, None, 0, 0, 0, socket.AI_PASSIVE)
695        # a server willing to support both IPv4 and IPv6 will
696        # usually do this
697        socket.getaddrinfo(None, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, 0,
698                           socket.AI_PASSIVE)
699
700        # Issue 17269: test workaround for OS X platform bug segfault
701        if hasattr(socket, 'AI_NUMERICSERV'):
702            try:
703                # The arguments here are undefined and the call may succeed
704                # or fail.  All we care here is that it doesn't segfault.
705                socket.getaddrinfo("localhost", None, 0, 0, 0,
706                                   socket.AI_NUMERICSERV)
707            except socket.gaierror:
708                pass
709
710    def check_sendall_interrupted(self, with_timeout):
711        # socketpair() is not strictly required, but it makes things easier.
712        if not hasattr(signal, 'alarm') or not hasattr(socket, 'socketpair'):
713            self.skipTest("signal.alarm and socket.socketpair required for this test")
714        # Our signal handlers clobber the C errno by calling a math function
715        # with an invalid domain value.
716        def ok_handler(*args):
717            self.assertRaises(ValueError, math.acosh, 0)
718        def raising_handler(*args):
719            self.assertRaises(ValueError, math.acosh, 0)
720            1 // 0
721        c, s = socket.socketpair()
722        old_alarm = signal.signal(signal.SIGALRM, raising_handler)
723        try:
724            if with_timeout:
725                # Just above the one second minimum for signal.alarm
726                c.settimeout(1.5)
727            with self.assertRaises(ZeroDivisionError):
728                signal.alarm(1)
729                c.sendall(b"x" * test_support.SOCK_MAX_SIZE)
730            if with_timeout:
731                signal.signal(signal.SIGALRM, ok_handler)
732                signal.alarm(1)
733                self.assertRaises(socket.timeout, c.sendall,
734                                  b"x" * test_support.SOCK_MAX_SIZE)
735        finally:
736            signal.signal(signal.SIGALRM, old_alarm)
737            c.close()
738            s.close()
739
740    def test_sendall_interrupted(self):
741        self.check_sendall_interrupted(False)
742
743    def test_sendall_interrupted_with_timeout(self):
744        self.check_sendall_interrupted(True)
745
746    def test_listen_backlog(self):
747        for backlog in 0, -1:
748            srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
749            srv.bind((HOST, 0))
750            srv.listen(backlog)
751            srv.close()
752
753    @test_support.cpython_only
754    def test_listen_backlog_overflow(self):
755        # Issue 15989
756        import _testcapi
757        srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
758        srv.bind((HOST, 0))
759        self.assertRaises(OverflowError, srv.listen, _testcapi.INT_MAX + 1)
760        srv.close()
761
762    @unittest.skipUnless(SUPPORTS_IPV6, 'IPv6 required for this test.')
763    def test_flowinfo(self):
764        self.assertRaises(OverflowError, socket.getnameinfo,
765                          ('::1',0, 0xffffffff), 0)
766        s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
767        try:
768            self.assertRaises(OverflowError, s.bind, ('::1', 0, -10))
769        finally:
770            s.close()
771
772
@unittest.skipUnless(thread, 'Threading required for this test.')
class BasicTCPTest(SocketConnectedTest):
    """Basic data transfer over an established TCP connection.

    Each ``testX`` method reads from ``self.cli_conn``; the paired
    ``_testX`` method supplies the data by writing to ``self.serv_conn``.
    """

    def __init__(self, methodName='runTest'):
        SocketConnectedTest.__init__(self, methodName=methodName)

    def testRecv(self):
        # One recv() with a generous buffer returns the whole message.
        self.assertEqual(self.cli_conn.recv(1024), MSG)

    def _testRecv(self):
        # Shared sender: most tests below just need MSG sent once.
        self.serv_conn.send(MSG)

    def testOverFlowRecv(self):
        # Read MSG in two pieces: all but the last 3 bytes, then the rest.
        head = self.cli_conn.recv(len(MSG) - 3)
        tail = self.cli_conn.recv(1024)
        self.assertEqual(head + tail, MSG)

    _testOverFlowRecv = _testRecv

    def testRecvFrom(self):
        # recvfrom() on a connected TCP socket; the address is ignored.
        data, _peer = self.cli_conn.recvfrom(1024)
        self.assertEqual(data, MSG)

    _testRecvFrom = _testRecv

    def testOverFlowRecvFrom(self):
        # Same two-piece read as testOverFlowRecv, using recvfrom().
        head, _peer = self.cli_conn.recvfrom(len(MSG) - 3)
        tail, _peer = self.cli_conn.recvfrom(1024)
        self.assertEqual(head + tail, MSG)

    _testOverFlowRecvFrom = _testRecv

    def testSendAll(self):
        # Drain the connection until recv() reports EOF, then compare
        # with the 2048-byte payload handed to sendall().
        chunks = []
        while True:
            piece = self.cli_conn.recv(1024)
            if not piece:
                break
            chunks.append(piece)
        self.assertEqual(''.join(chunks), 'f' * 2048)

    def _testSendAll(self):
        self.serv_conn.sendall('f' * 2048)

    @unittest.skipUnless(hasattr(socket, 'fromfd'),
                         'socket.fromfd not available')
    def testFromFd(self):
        # A socket rebuilt from the connection's descriptor sees the data.
        rebuilt = socket.fromfd(self.cli_conn.fileno(),
                                socket.AF_INET, socket.SOCK_STREAM)
        self.addCleanup(rebuilt.close)
        self.assertEqual(rebuilt.recv(1024), MSG)

    _testFromFd = _testRecv

    def testDup(self):
        # A dup() of the connection sees the same incoming data.
        duplicate = self.cli_conn.dup()
        self.addCleanup(duplicate.close)
        self.assertEqual(duplicate.recv(1024), MSG)

    _testDup = _testRecv

    def testShutdown(self):
        self.assertEqual(self.cli_conn.recv(1024), MSG)
        # Block until _testShutdown has finished: on OS X the client is
        # disconnected as soon as the server closes the connection, which
        # would make the client's own shutdown() call fail (issue #4397).
        self.done.wait()

    def _testShutdown(self):
        self.serv_conn.send(MSG)
        self.serv_conn.shutdown(2)

    # Same receiving side as testShutdown; the paired sender below also
    # probes shutdown() with out-of-range arguments first.
    testShutdown_overflow = test_support.cpython_only(testShutdown)

    @test_support.cpython_only
    def _testShutdown_overflow(self):
        import _testcapi
        self.serv_conn.send(MSG)
        # Issue 15989: arguments that overflow a C int must be rejected.
        for how in (_testcapi.INT_MAX + 1, 2 + (_testcapi.UINT_MAX + 1)):
            self.assertRaises(OverflowError, self.serv_conn.shutdown, how)
        self.serv_conn.shutdown(2)
877
@unittest.skipUnless(thread, 'Threading required for this test.')
class BasicUDPTest(ThreadedUDPSocketTest):
    """Datagram tests: the client sends MSG to the server's bound port."""

    def __init__(self, methodName='runTest'):
        ThreadedUDPSocketTest.__init__(self, methodName=methodName)

    def testSendtoAndRecv(self):
        # A recv() sized exactly to MSG returns the whole datagram.
        self.assertEqual(self.serv.recv(len(MSG)), MSG)

    def _testSendtoAndRecv(self):
        # Shared sender for every test in this class.
        self.cli.sendto(MSG, 0, (HOST, self.port))

    def testRecvFrom(self):
        data, _peer = self.serv.recvfrom(len(MSG))
        self.assertEqual(data, MSG)

    _testRecvFrom = _testSendtoAndRecv

    def testRecvFromNegative(self):
        # recvfrom() refuses a negative buffer size with ValueError.
        with self.assertRaises(ValueError):
            self.serv.recvfrom(-1)

    _testRecvFromNegative = _testSendtoAndRecv
900    def testRecvFromNegative(self):
901        # Negative lengths passed to recvfrom should give ValueError.
902        self.assertRaises(ValueError, self.serv.recvfrom, -1)
903
904    def _testRecvFromNegative(self):
905        self.cli.sendto(MSG, 0, (HOST, self.port))
906
907@unittest.skipUnless(thread, 'Threading required for this test.')
908class TCPCloserTest(ThreadedTCPSocketTest):
909
910    def testClose(self):
911        conn, addr = self.serv.accept()
912        conn.close()
913
914        sd = self.cli
915        read, write, err = select.select([sd], [], [], 1.0)
916        self.assertEqual(read, [sd])
917        self.assertEqual(sd.recv(1), '')
918
919    def _testClose(self):
920        self.cli.connect((HOST, self.port))
921        time.sleep(1.0)
922
923@unittest.skipUnless(hasattr(socket, 'socketpair'),
924                     'test needs socket.socketpair()')
925@unittest.skipUnless(thread, 'Threading required for this test.')
926class BasicSocketPairTest(SocketPairTest):
927
928    def __init__(self, methodName='runTest'):
929        SocketPairTest.__init__(self, methodName=methodName)
930
931    def testRecv(self):
932        msg = self.serv.recv(1024)
933        self.assertEqual(msg, MSG)
934
935    def _testRecv(self):
936        self.cli.send(MSG)
937
938    def testSend(self):
939        self.serv.send(MSG)
940
941    def _testSend(self):
942        msg = self.cli.recv(1024)
943        self.assertEqual(msg, MSG)
944
945@unittest.skipUnless(thread, 'Threading required for this test.')
946class NonBlockingTCPTests(ThreadedTCPSocketTest):
947
948    def __init__(self, methodName='runTest'):
949        ThreadedTCPSocketTest.__init__(self, methodName=methodName)
950
951    def testSetBlocking(self):
952        # Testing whether set blocking works
953        self.serv.setblocking(True)
954        self.assertIsNone(self.serv.gettimeout())
955        self.serv.setblocking(False)
956        self.assertEqual(self.serv.gettimeout(), 0.0)
957        start = time.time()
958        try:
959            self.serv.accept()
960        except socket.error:
961            pass
962        end = time.time()
963        self.assertTrue((end - start) < 1.0, "Error setting non-blocking mode.")
964
965    def _testSetBlocking(self):
966        pass
967
968    @test_support.cpython_only
969    def testSetBlocking_overflow(self):
970        # Issue 15989
971        import _testcapi
972        if _testcapi.UINT_MAX >= _testcapi.ULONG_MAX:
973            self.skipTest('needs UINT_MAX < ULONG_MAX')
974        self.serv.setblocking(False)
975        self.assertEqual(self.serv.gettimeout(), 0.0)
976        self.serv.setblocking(_testcapi.UINT_MAX + 1)
977        self.assertIsNone(self.serv.gettimeout())
978
979    _testSetBlocking_overflow = test_support.cpython_only(_testSetBlocking)
980
981    def testAccept(self):
982        # Testing non-blocking accept
983        self.serv.setblocking(0)
984        try:
985            conn, addr = self.serv.accept()
986        except socket.error:
987            pass
988        else:
989            self.fail("Error trying to do non-blocking accept.")
990        read, write, err = select.select([self.serv], [], [])
991        if self.serv in read:
992            conn, addr = self.serv.accept()
993            conn.close()
994        else:
995            self.fail("Error trying to do accept after select.")
996
997    def _testAccept(self):
998        time.sleep(0.1)
999        self.cli.connect((HOST, self.port))
1000
1001    def testConnect(self):
1002        # Testing non-blocking connect
1003        conn, addr = self.serv.accept()
1004        conn.close()
1005
1006    def _testConnect(self):
1007        self.cli.settimeout(10)
1008        self.cli.connect((HOST, self.port))
1009
1010    def testRecv(self):
1011        # Testing non-blocking recv
1012        conn, addr = self.serv.accept()
1013        conn.setblocking(0)
1014        try:
1015            msg = conn.recv(len(MSG))
1016        except socket.error:
1017            pass
1018        else:
1019            self.fail("Error trying to do non-blocking recv.")
1020        read, write, err = select.select([conn], [], [])
1021        if conn in read:
1022            msg = conn.recv(len(MSG))
1023            conn.close()
1024            self.assertEqual(msg, MSG)
1025        else:
1026            self.fail("Error during select call to non-blocking socket.")
1027
1028    def _testRecv(self):
1029        self.cli.connect((HOST, self.port))
1030        time.sleep(0.1)
1031        self.cli.send(MSG)
1032
1033@unittest.skipUnless(thread, 'Threading required for this test.')
1034class FileObjectClassTestCase(SocketConnectedTest):
1035
1036    bufsize = -1 # Use default buffer size
1037
1038    def __init__(self, methodName='runTest'):
1039        SocketConnectedTest.__init__(self, methodName=methodName)
1040
1041    def setUp(self):
1042        SocketConnectedTest.setUp(self)
1043        self.serv_file = self.cli_conn.makefile('rb', self.bufsize)
1044
1045    def tearDown(self):
1046        self.serv_file.close()
1047        self.assertTrue(self.serv_file.closed)
1048        SocketConnectedTest.tearDown(self)
1049        self.serv_file = None
1050
1051    def clientSetUp(self):
1052        SocketConnectedTest.clientSetUp(self)
1053        self.cli_file = self.serv_conn.makefile('wb')
1054
1055    def clientTearDown(self):
1056        self.cli_file.close()
1057        self.assertTrue(self.cli_file.closed)
1058        self.cli_file = None
1059        SocketConnectedTest.clientTearDown(self)
1060
1061    def testSmallRead(self):
1062        # Performing small file read test
1063        first_seg = self.serv_file.read(len(MSG)-3)
1064        second_seg = self.serv_file.read(3)
1065        msg = first_seg + second_seg
1066        self.assertEqual(msg, MSG)
1067
1068    def _testSmallRead(self):
1069        self.cli_file.write(MSG)
1070        self.cli_file.flush()
1071
1072    def testFullRead(self):
1073        # read until EOF
1074        msg = self.serv_file.read()
1075        self.assertEqual(msg, MSG)
1076
1077    def _testFullRead(self):
1078        self.cli_file.write(MSG)
1079        self.cli_file.close()
1080
1081    def testUnbufferedRead(self):
1082        # Performing unbuffered file read test
1083        buf = ''
1084        while 1:
1085            char = self.serv_file.read(1)
1086            if not char:
1087                break
1088            buf += char
1089        self.assertEqual(buf, MSG)
1090
1091    def _testUnbufferedRead(self):
1092        self.cli_file.write(MSG)
1093        self.cli_file.flush()
1094
1095    def testReadline(self):
1096        # Performing file readline test
1097        line = self.serv_file.readline()
1098        self.assertEqual(line, MSG)
1099
1100    def _testReadline(self):
1101        self.cli_file.write(MSG)
1102        self.cli_file.flush()
1103
1104    def testReadlineAfterRead(self):
1105        a_baloo_is = self.serv_file.read(len("A baloo is"))
1106        self.assertEqual("A baloo is", a_baloo_is)
1107        _a_bear = self.serv_file.read(len(" a bear"))
1108        self.assertEqual(" a bear", _a_bear)
1109        line = self.serv_file.readline()
1110        self.assertEqual("\n", line)
1111        line = self.serv_file.readline()
1112        self.assertEqual("A BALOO IS A BEAR.\n", line)
1113        line = self.serv_file.readline()
1114        self.assertEqual(MSG, line)
1115
1116    def _testReadlineAfterRead(self):
1117        self.cli_file.write("A baloo is a bear\n")
1118        self.cli_file.write("A BALOO IS A BEAR.\n")
1119        self.cli_file.write(MSG)
1120        self.cli_file.flush()
1121
1122    def testReadlineAfterReadNoNewline(self):
1123        end_of_ = self.serv_file.read(len("End Of "))
1124        self.assertEqual("End Of ", end_of_)
1125        line = self.serv_file.readline()
1126        self.assertEqual("Line", line)
1127
1128    def _testReadlineAfterReadNoNewline(self):
1129        self.cli_file.write("End Of Line")
1130
1131    def testClosedAttr(self):
1132        self.assertTrue(not self.serv_file.closed)
1133
1134    def _testClosedAttr(self):
1135        self.assertTrue(not self.cli_file.closed)
1136
1137
1138class FileObjectInterruptedTestCase(unittest.TestCase):
1139    """Test that the file object correctly handles EINTR internally."""
1140
1141    class MockSocket(object):
1142        def __init__(self, recv_funcs=()):
1143            # A generator that returns callables that we'll call for each
1144            # call to recv().
1145            self._recv_step = iter(recv_funcs)
1146
1147        def recv(self, size):
1148            return self._recv_step.next()()
1149
1150    @staticmethod
1151    def _raise_eintr():
1152        raise socket.error(errno.EINTR)
1153
1154    def _test_readline(self, size=-1, **kwargs):
1155        mock_sock = self.MockSocket(recv_funcs=[
1156                lambda : "This is the first line\nAnd the sec",
1157                self._raise_eintr,
1158                lambda : "ond line is here\n",
1159                lambda : "",
1160            ])
1161        fo = socket._fileobject(mock_sock, **kwargs)
1162        self.assertEqual(fo.readline(size), "This is the first line\n")
1163        self.assertEqual(fo.readline(size), "And the second line is here\n")
1164
1165    def _test_read(self, size=-1, **kwargs):
1166        mock_sock = self.MockSocket(recv_funcs=[
1167                lambda : "This is the first line\nAnd the sec",
1168                self._raise_eintr,
1169                lambda : "ond line is here\n",
1170                lambda : "",
1171            ])
1172        fo = socket._fileobject(mock_sock, **kwargs)
1173        self.assertEqual(fo.read(size), "This is the first line\n"
1174                          "And the second line is here\n")
1175
1176    def test_default(self):
1177        self._test_readline()
1178        self._test_readline(size=100)
1179        self._test_read()
1180        self._test_read(size=100)
1181
1182    def test_with_1k_buffer(self):
1183        self._test_readline(bufsize=1024)
1184        self._test_readline(size=100, bufsize=1024)
1185        self._test_read(bufsize=1024)
1186        self._test_read(size=100, bufsize=1024)
1187
1188    def _test_readline_no_buffer(self, size=-1):
1189        mock_sock = self.MockSocket(recv_funcs=[
1190                lambda : "aa",
1191                lambda : "\n",
1192                lambda : "BB",
1193                self._raise_eintr,
1194                lambda : "bb",
1195                lambda : "",
1196            ])
1197        fo = socket._fileobject(mock_sock, bufsize=0)
1198        self.assertEqual(fo.readline(size), "aa\n")
1199        self.assertEqual(fo.readline(size), "BBbb")
1200
1201    def test_no_buffer(self):
1202        self._test_readline_no_buffer()
1203        self._test_readline_no_buffer(size=4)
1204        self._test_read(bufsize=0)
1205        self._test_read(size=100, bufsize=0)
1206
1207
1208class UnbufferedFileObjectClassTestCase(FileObjectClassTestCase):
1209
1210    """Repeat the tests from FileObjectClassTestCase with bufsize==0.
1211
1212    In this case (and in this case only), it should be possible to
1213    create a file object, read a line from it, create another file
1214    object, read another line from it, without loss of data in the
1215    first file object's buffer.  Note that httplib relies on this
1216    when reading multiple requests from the same socket."""
1217
1218    bufsize = 0 # Use unbuffered mode
1219
1220    def testUnbufferedReadline(self):
1221        # Read a line, create a new file object, read another line with it
1222        line = self.serv_file.readline() # first line
1223        self.assertEqual(line, "A. " + MSG) # first line
1224        self.serv_file = self.cli_conn.makefile('rb', 0)
1225        line = self.serv_file.readline() # second line
1226        self.assertEqual(line, "B. " + MSG) # second line
1227
1228    def _testUnbufferedReadline(self):
1229        self.cli_file.write("A. " + MSG)
1230        self.cli_file.write("B. " + MSG)
1231        self.cli_file.flush()
1232
1233class LineBufferedFileObjectClassTestCase(FileObjectClassTestCase):
1234
1235    bufsize = 1 # Default-buffered for reading; line-buffered for writing
1236
1237    class SocketMemo(object):
1238        """A wrapper to keep track of sent data, needed to examine write behaviour"""
1239        def __init__(self, sock):
1240            self._sock = sock
1241            self.sent = []
1242
1243        def send(self, data, flags=0):
1244            n = self._sock.send(data, flags)
1245            self.sent.append(data[:n])
1246            return n
1247
1248        def sendall(self, data, flags=0):
1249            self._sock.sendall(data, flags)
1250            self.sent.append(data)
1251
1252        def __getattr__(self, attr):
1253            return getattr(self._sock, attr)
1254
1255        def getsent(self):
1256            return [e.tobytes() if isinstance(e, memoryview) else e for e in self.sent]
1257
1258    def setUp(self):
1259        FileObjectClassTestCase.setUp(self)
1260        self.serv_file._sock = self.SocketMemo(self.serv_file._sock)
1261
1262    def testLinebufferedWrite(self):
1263        # Write two lines, in small chunks
1264        msg = MSG.strip()
1265        print >> self.serv_file, msg,
1266        print >> self.serv_file, msg
1267
1268        # second line:
1269        print >> self.serv_file, msg,
1270        print >> self.serv_file, msg,
1271        print >> self.serv_file, msg
1272
1273        # third line
1274        print >> self.serv_file, ''
1275
1276        self.serv_file.flush()
1277
1278        msg1 = "%s %s\n"%(msg, msg)
1279        msg2 =  "%s %s %s\n"%(msg, msg, msg)
1280        msg3 =  "\n"
1281        self.assertEqual(self.serv_file._sock.getsent(), [msg1, msg2, msg3])
1282
1283    def _testLinebufferedWrite(self):
1284        msg = MSG.strip()
1285        msg1 = "%s %s\n"%(msg, msg)
1286        msg2 =  "%s %s %s\n"%(msg, msg, msg)
1287        msg3 =  "\n"
1288        l1 = self.cli_file.readline()
1289        self.assertEqual(l1, msg1)
1290        l2 = self.cli_file.readline()
1291        self.assertEqual(l2, msg2)
1292        l3 = self.cli_file.readline()
1293        self.assertEqual(l3, msg3)
1294
1295
1296class SmallBufferedFileObjectClassTestCase(FileObjectClassTestCase):
1297
1298    bufsize = 2 # Exercise the buffering code
1299
1300
1301class NetworkConnectionTest(object):
1302    """Prove network connection."""
1303    def clientSetUp(self):
1304        # We're inherited below by BasicTCPTest2, which also inherits
1305        # BasicTCPTest, which defines self.port referenced below.
1306        self.cli = socket.create_connection((HOST, self.port))
1307        self.serv_conn = self.cli
1308
1309class BasicTCPTest2(NetworkConnectionTest, BasicTCPTest):
1310    """Tests that NetworkConnection does not break existing TCP functionality.
1311    """
1312
1313class NetworkConnectionNoServer(unittest.TestCase):
1314    class MockSocket(socket.socket):
1315        def connect(self, *args):
1316            raise socket.timeout('timed out')
1317
1318    @contextlib.contextmanager
1319    def mocked_socket_module(self):
1320        """Return a socket which times out on connect"""
1321        old_socket = socket.socket
1322        socket.socket = self.MockSocket
1323        try:
1324            yield
1325        finally:
1326            socket.socket = old_socket
1327
1328    def test_connect(self):
1329        port = test_support.find_unused_port()
1330        cli = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1331        self.addCleanup(cli.close)
1332        with self.assertRaises(socket.error) as cm:
1333            cli.connect((HOST, port))
1334        self.assertEqual(cm.exception.errno, errno.ECONNREFUSED)
1335
1336    def test_create_connection(self):
1337        # Issue #9792: errors raised by create_connection() should have
1338        # a proper errno attribute.
1339        port = test_support.find_unused_port()
1340        with self.assertRaises(socket.error) as cm:
1341            socket.create_connection((HOST, port))
1342
1343        # Issue #16257: create_connection() calls getaddrinfo() against
1344        # 'localhost'.  This may result in an IPV6 addr being returned
1345        # as well as an IPV4 one:
1346        #   >>> socket.getaddrinfo('localhost', port, 0, SOCK_STREAM)
1347        #   >>> [(2,  2, 0, '', ('127.0.0.1', 41230)),
1348        #        (26, 2, 0, '', ('::1', 41230, 0, 0))]
1349        #
1350        # create_connection() enumerates through all the addresses returned
1351        # and if it doesn't successfully bind to any of them, it propagates
1352        # the last exception it encountered.
1353        #
1354        # On Solaris, ENETUNREACH is returned in this circumstance instead
1355        # of ECONNREFUSED.  So, if that errno exists, add it to our list of
1356        # expected errnos.
1357        expected_errnos = [ errno.ECONNREFUSED, ]
1358        if hasattr(errno, 'ENETUNREACH'):
1359            expected_errnos.append(errno.ENETUNREACH)
1360
1361        self.assertIn(cm.exception.errno, expected_errnos)
1362
1363    def test_create_connection_timeout(self):
1364        # Issue #9792: create_connection() should not recast timeout errors
1365        # as generic socket errors.
1366        with self.mocked_socket_module():
1367            with self.assertRaises(socket.timeout):
1368                socket.create_connection((HOST, 1234))
1369
1370
1371@unittest.skipUnless(thread, 'Threading required for this test.')
1372class NetworkConnectionAttributesTest(SocketTCPTest, ThreadableTest):
1373
1374    def __init__(self, methodName='runTest'):
1375        SocketTCPTest.__init__(self, methodName=methodName)
1376        ThreadableTest.__init__(self)
1377
1378    def clientSetUp(self):
1379        self.source_port = test_support.find_unused_port()
1380
1381    def clientTearDown(self):
1382        self.cli.close()
1383        self.cli = None
1384        ThreadableTest.clientTearDown(self)
1385
1386    def _justAccept(self):
1387        conn, addr = self.serv.accept()
1388        conn.close()
1389
1390    testFamily = _justAccept
1391    def _testFamily(self):
1392        self.cli = socket.create_connection((HOST, self.port), timeout=30)
1393        self.addCleanup(self.cli.close)
1394        self.assertEqual(self.cli.family, 2)
1395
1396    testSourceAddress = _justAccept
1397    def _testSourceAddress(self):
1398        self.cli = socket.create_connection((HOST, self.port), timeout=30,
1399                source_address=('', self.source_port))
1400        self.addCleanup(self.cli.close)
1401        self.assertEqual(self.cli.getsockname()[1], self.source_port)
1402        # The port number being used is sufficient to show that the bind()
1403        # call happened.
1404
1405    testTimeoutDefault = _justAccept
1406    def _testTimeoutDefault(self):
1407        # passing no explicit timeout uses socket's global default
1408        self.assertTrue(socket.getdefaulttimeout() is None)
1409        socket.setdefaulttimeout(42)
1410        try:
1411            self.cli = socket.create_connection((HOST, self.port))
1412            self.addCleanup(self.cli.close)
1413        finally:
1414            socket.setdefaulttimeout(None)
1415        self.assertEqual(self.cli.gettimeout(), 42)
1416
1417    testTimeoutNone = _justAccept
1418    def _testTimeoutNone(self):
1419        # None timeout means the same as sock.settimeout(None)
1420        self.assertTrue(socket.getdefaulttimeout() is None)
1421        socket.setdefaulttimeout(30)
1422        try:
1423            self.cli = socket.create_connection((HOST, self.port), timeout=None)
1424            self.addCleanup(self.cli.close)
1425        finally:
1426            socket.setdefaulttimeout(None)
1427        self.assertEqual(self.cli.gettimeout(), None)
1428
1429    testTimeoutValueNamed = _justAccept
1430    def _testTimeoutValueNamed(self):
1431        self.cli = socket.create_connection((HOST, self.port), timeout=30)
1432        self.assertEqual(self.cli.gettimeout(), 30)
1433
1434    testTimeoutValueNonamed = _justAccept
1435    def _testTimeoutValueNonamed(self):
1436        self.cli = socket.create_connection((HOST, self.port), 30)
1437        self.addCleanup(self.cli.close)
1438        self.assertEqual(self.cli.gettimeout(), 30)
1439
1440@unittest.skipUnless(thread, 'Threading required for this test.')
1441class NetworkConnectionBehaviourTest(SocketTCPTest, ThreadableTest):
1442
1443    def __init__(self, methodName='runTest'):
1444        SocketTCPTest.__init__(self, methodName=methodName)
1445        ThreadableTest.__init__(self)
1446
1447    def clientSetUp(self):
1448        pass
1449
1450    def clientTearDown(self):
1451        self.cli.close()
1452        self.cli = None
1453        ThreadableTest.clientTearDown(self)
1454
1455    def testInsideTimeout(self):
1456        conn, addr = self.serv.accept()
1457        self.addCleanup(conn.close)
1458        time.sleep(3)
1459        conn.send("done!")
1460    testOutsideTimeout = testInsideTimeout
1461
1462    def _testInsideTimeout(self):
1463        self.cli = sock = socket.create_connection((HOST, self.port))
1464        data = sock.recv(5)
1465        self.assertEqual(data, "done!")
1466
1467    def _testOutsideTimeout(self):
1468        self.cli = sock = socket.create_connection((HOST, self.port), timeout=1)
1469        self.assertRaises(socket.timeout, lambda: sock.recv(5))
1470
1471
1472class Urllib2FileobjectTest(unittest.TestCase):
1473
1474    # urllib2.HTTPHandler has "borrowed" socket._fileobject, and requires that
1475    # it close the socket if the close c'tor argument is true
1476
1477    def testClose(self):
1478        class MockSocket:
1479            closed = False
1480            def flush(self): pass
1481            def close(self): self.closed = True
1482
1483        # must not close unless we request it: the original use of _fileobject
1484        # by module socket requires that the underlying socket not be closed until
1485        # the _socketobject that created the _fileobject is closed
1486        s = MockSocket()
1487        f = socket._fileobject(s)
1488        f.close()
1489        self.assertTrue(not s.closed)
1490
1491        s = MockSocket()
1492        f = socket._fileobject(s, close=True)
1493        f.close()
1494        self.assertTrue(s.closed)
1495
1496class TCPTimeoutTest(SocketTCPTest):
1497
1498    def testTCPTimeout(self):
1499        def raise_timeout(*args, **kwargs):
1500            self.serv.settimeout(1.0)
1501            self.serv.accept()
1502        self.assertRaises(socket.timeout, raise_timeout,
1503                              "Error generating a timeout exception (TCP)")
1504
1505    def testTimeoutZero(self):
1506        ok = False
1507        try:
1508            self.serv.settimeout(0.0)
1509            foo = self.serv.accept()
1510        except socket.timeout:
1511            self.fail("caught timeout instead of error (TCP)")
1512        except socket.error:
1513            ok = True
1514        except:
1515            self.fail("caught unexpected exception (TCP)")
1516        if not ok:
1517            self.fail("accept() returned success when we did not expect it")
1518
1519    @unittest.skipUnless(hasattr(signal, 'alarm'),
1520                         'test needs signal.alarm()')
1521    def testInterruptedTimeout(self):
1522        # XXX I don't know how to do this test on MSWindows or any other
1523        # plaform that doesn't support signal.alarm() or os.kill(), though
1524        # the bug should have existed on all platforms.
1525        self.serv.settimeout(5.0)   # must be longer than alarm
1526        class Alarm(Exception):
1527            pass
1528        def alarm_handler(signal, frame):
1529            raise Alarm
1530        old_alarm = signal.signal(signal.SIGALRM, alarm_handler)
1531        try:
1532            signal.alarm(2)    # POSIX allows alarm to be up to 1 second early
1533            try:
1534                foo = self.serv.accept()
1535            except socket.timeout:
1536                self.fail("caught timeout instead of Alarm")
1537            except Alarm:
1538                pass
1539            except:
1540                self.fail("caught other exception instead of Alarm:"
1541                          " %s(%s):\n%s" %
1542                          (sys.exc_info()[:2] + (traceback.format_exc(),)))
1543            else:
1544                self.fail("nothing caught")
1545            finally:
1546                signal.alarm(0)         # shut off alarm
1547        except Alarm:
1548            self.fail("got Alarm in wrong place")
1549        finally:
1550            # no alarm can be pending.  Safe to restore old handler.
1551            signal.signal(signal.SIGALRM, old_alarm)
1552
1553class UDPTimeoutTest(SocketUDPTest):
1554
1555    def testUDPTimeout(self):
1556        def raise_timeout(*args, **kwargs):
1557            self.serv.settimeout(1.0)
1558            self.serv.recv(1024)
1559        self.assertRaises(socket.timeout, raise_timeout,
1560                              "Error generating a timeout exception (UDP)")
1561
1562    def testTimeoutZero(self):
1563        ok = False
1564        try:
1565            self.serv.settimeout(0.0)
1566            foo = self.serv.recv(1024)
1567        except socket.timeout:
1568            self.fail("caught timeout instead of error (UDP)")
1569        except socket.error:
1570            ok = True
1571        except:
1572            self.fail("caught unexpected exception (UDP)")
1573        if not ok:
1574            self.fail("recv() returned success when we did not expect it")
1575
1576class TestExceptions(unittest.TestCase):
1577
1578    def testExceptionTree(self):
1579        self.assertTrue(issubclass(socket.error, Exception))
1580        self.assertTrue(issubclass(socket.herror, socket.error))
1581        self.assertTrue(issubclass(socket.gaierror, socket.error))
1582        self.assertTrue(issubclass(socket.timeout, socket.error))
1583
1584@unittest.skipUnless(sys.platform == 'linux', 'Linux specific test')
1585class TestLinuxAbstractNamespace(unittest.TestCase):
1586
1587    UNIX_PATH_MAX = 108
1588
1589    def testLinuxAbstractNamespace(self):
1590        address = "\x00python-test-hello\x00\xff"
1591        s1 = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
1592        s1.bind(address)
1593        s1.listen(1)
1594        s2 = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
1595        s2.connect(s1.getsockname())
1596        s1.accept()
1597        self.assertEqual(s1.getsockname(), address)
1598        self.assertEqual(s2.getpeername(), address)
1599
1600    def testMaxName(self):
1601        address = "\x00" + "h" * (self.UNIX_PATH_MAX - 1)
1602        s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
1603        s.bind(address)
1604        self.assertEqual(s.getsockname(), address)
1605
1606    def testNameOverflow(self):
1607        address = "\x00" + "h" * self.UNIX_PATH_MAX
1608        s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
1609        self.assertRaises(socket.error, s.bind, address)
1610
1611
1612@unittest.skipUnless(thread, 'Threading required for this test.')
1613class BufferIOTest(SocketConnectedTest):
1614    """
1615    Test the buffer versions of socket.recv() and socket.send().
1616    """
1617    def __init__(self, methodName='runTest'):
1618        SocketConnectedTest.__init__(self, methodName=methodName)
1619
1620    def testRecvIntoArray(self):
1621        buf = array.array('c', ' '*1024)
1622        nbytes = self.cli_conn.recv_into(buf)
1623        self.assertEqual(nbytes, len(MSG))
1624        msg = buf.tostring()[:len(MSG)]
1625        self.assertEqual(msg, MSG)
1626
1627    def _testRecvIntoArray(self):
1628        with test_support.check_py3k_warnings():
1629            buf = buffer(MSG)
1630        self.serv_conn.send(buf)
1631
1632    def testRecvIntoBytearray(self):
1633        buf = bytearray(1024)
1634        nbytes = self.cli_conn.recv_into(buf)
1635        self.assertEqual(nbytes, len(MSG))
1636        msg = buf[:len(MSG)]
1637        self.assertEqual(msg, MSG)
1638
1639    _testRecvIntoBytearray = _testRecvIntoArray
1640
1641    def testRecvIntoMemoryview(self):
1642        buf = bytearray(1024)
1643        nbytes = self.cli_conn.recv_into(memoryview(buf))
1644        self.assertEqual(nbytes, len(MSG))
1645        msg = buf[:len(MSG)]
1646        self.assertEqual(msg, MSG)
1647
1648    _testRecvIntoMemoryview = _testRecvIntoArray
1649
1650    def testRecvFromIntoArray(self):
1651        buf = array.array('c', ' '*1024)
1652        nbytes, addr = self.cli_conn.recvfrom_into(buf)
1653        self.assertEqual(nbytes, len(MSG))
1654        msg = buf.tostring()[:len(MSG)]
1655        self.assertEqual(msg, MSG)
1656
1657    def _testRecvFromIntoArray(self):
1658        with test_support.check_py3k_warnings():
1659            buf = buffer(MSG)
1660        self.serv_conn.send(buf)
1661
1662    def testRecvFromIntoBytearray(self):
1663        buf = bytearray(1024)
1664        nbytes, addr = self.cli_conn.recvfrom_into(buf)
1665        self.assertEqual(nbytes, len(MSG))
1666        msg = buf[:len(MSG)]
1667        self.assertEqual(msg, MSG)
1668
1669    _testRecvFromIntoBytearray = _testRecvFromIntoArray
1670
1671    def testRecvFromIntoMemoryview(self):
1672        buf = bytearray(1024)
1673        nbytes, addr = self.cli_conn.recvfrom_into(memoryview(buf))
1674        self.assertEqual(nbytes, len(MSG))
1675        msg = buf[:len(MSG)]
1676        self.assertEqual(msg, MSG)
1677
1678    _testRecvFromIntoMemoryview = _testRecvFromIntoArray
1679
1680    def testRecvFromIntoSmallBuffer(self):
1681        # See issue #20246.
1682        buf = bytearray(8)
1683        self.assertRaises(ValueError, self.cli_conn.recvfrom_into, buf, 1024)
1684
1685    def _testRecvFromIntoSmallBuffer(self):
1686        with test_support.check_py3k_warnings():
1687            buf = buffer(MSG)
1688        self.serv_conn.send(buf)
1689
1690    def testRecvFromIntoEmptyBuffer(self):
1691        buf = bytearray()
1692        self.cli_conn.recvfrom_into(buf)
1693        self.cli_conn.recvfrom_into(buf, 0)
1694
1695    _testRecvFromIntoEmptyBuffer = _testRecvFromIntoArray
1696
1697
# Name-sequence parameters shared by the TIPC tests: the server type and
# the lower/upper instance bounds the server binds with TIPC_ADDR_NAMESEQ.
TIPC_STYPE, TIPC_LOWER, TIPC_UPPER = 2000, 200, 210
1701
def isTipcAvailable():
    """Check if the TIPC module is loaded

    The TIPC module is not loaded automatically on Ubuntu and probably
    other Linux distros.

    Returns True only when the socket module exposes AF_TIPC *and*
    /proc/modules lists a module named "tipc".  Any failure to open
    /proc/modules (missing file, permission denied) is treated as "not
    available" instead of propagating.  This function is evaluated by the
    unittest.skipUnless decorators when the TIPC test classes are defined,
    so an exception here would abort loading the whole test module.
    """
    if not hasattr(socket, "AF_TIPC"):
        return False
    # EAFP: try the open directly rather than checking isfile() first,
    # which was racy and did not cover unreadable files.
    try:
        f = open("/proc/modules")
    except (IOError, OSError):
        return False
    with f:
        # The trailing space matches the module name "tipc" exactly, not
        # names that merely start with "tipc".
        return any(line.startswith("tipc ") for line in f)
1717
@unittest.skipUnless(isTipcAvailable(),
                     "TIPC module is not loaded, please 'sudo modprobe tipc'")
class TIPCTest(unittest.TestCase):
    """Connectionless (SOCK_RDM) round trip over an AF_TIPC name sequence."""

    def testRDM(self):
        srv = socket.socket(socket.AF_TIPC, socket.SOCK_RDM)
        # Release both sockets even if a later call or assertion fails.
        self.addCleanup(srv.close)
        cli = socket.socket(socket.AF_TIPC, socket.SOCK_RDM)
        self.addCleanup(cli.close)

        srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        srvaddr = (socket.TIPC_ADDR_NAMESEQ, TIPC_STYPE,
                TIPC_LOWER, TIPC_UPPER)
        srv.bind(srvaddr)

        # Address one instance in the middle of the bound range.  Floor
        # division keeps the instance an int under true division as well.
        sendaddr = (socket.TIPC_ADDR_NAME, TIPC_STYPE,
                TIPC_LOWER + (TIPC_UPPER - TIPC_LOWER) // 2, 0)
        cli.sendto(MSG, sendaddr)

        msg, recvaddr = srv.recvfrom(1024)

        self.assertEqual(cli.getsockname(), recvaddr)
        self.assertEqual(msg, MSG)
1738
1739
@unittest.skipUnless(isTipcAvailable(),
                     "TIPC module is not loaded, please 'sudo modprobe tipc'")
class TIPCThreadableTest(unittest.TestCase, ThreadableTest):
    """Connection-oriented (SOCK_STREAM) TIPC test with a client thread.

    setUp() binds and listens on the TIPC name sequence and accepts one
    connection.  clientSetUp() connects to an instance inside that range,
    and the _testXxx half sends MSG over the connection.
    """

    def __init__(self, methodName='runTest'):
        unittest.TestCase.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def setUp(self):
        self.srv = socket.socket(socket.AF_TIPC, socket.SOCK_STREAM)
        # Close the listening socket once the test is over.
        self.addCleanup(self.srv.close)
        self.srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        srvaddr = (socket.TIPC_ADDR_NAMESEQ, TIPC_STYPE,
                TIPC_LOWER, TIPC_UPPER)
        self.srv.bind(srvaddr)
        self.srv.listen(5)
        self.serverExplicitReady()
        self.conn, self.connaddr = self.srv.accept()
        # Close the accepted connection as well.
        self.addCleanup(self.conn.close)

    def clientSetUp(self):
        # There is a hittable race between serverExplicitReady() and the
        # accept() call; sleep a little while to avoid it, otherwise
        # we could get an exception
        time.sleep(0.1)
        self.cli = socket.socket(socket.AF_TIPC, socket.SOCK_STREAM)
        # _testStream() closes the client socket on its normal path.  This
        # cleanup also covers the case where connect() fails first.
        self.addCleanup(self.cli.close)
        # Connect to one instance in the middle of the server's range.
        # Floor division keeps the instance an int.
        addr = (socket.TIPC_ADDR_NAME, TIPC_STYPE,
                TIPC_LOWER + (TIPC_UPPER - TIPC_LOWER) // 2, 0)
        self.cli.connect(addr)
        self.cliaddr = self.cli.getsockname()

    def testStream(self):
        msg = self.conn.recv(1024)
        self.assertEqual(msg, MSG)
        self.assertEqual(self.cliaddr, self.connaddr)

    def _testStream(self):
        self.cli.send(MSG)
        self.cli.close()
1776
1777
def test_main():
    """Run every socket test case defined in this module.

    threading_setup()/threading_cleanup() bracket the run.  The cleanup
    sits in a ``finally`` so it still happens when run_unittest() raises,
    for example because a test failed.
    """
    tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest,
             TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest,
             UDPTimeoutTest]

    tests.extend([
        NonBlockingTCPTests,
        FileObjectClassTestCase,
        FileObjectInterruptedTestCase,
        UnbufferedFileObjectClassTestCase,
        LineBufferedFileObjectClassTestCase,
        SmallBufferedFileObjectClassTestCase,
        Urllib2FileobjectTest,
        NetworkConnectionNoServer,
        NetworkConnectionAttributesTest,
        NetworkConnectionBehaviourTest,
    ])
    tests.append(BasicSocketPairTest)
    tests.append(TestLinuxAbstractNamespace)
    tests.extend([TIPCTest, TIPCThreadableTest])

    thread_info = test_support.threading_setup()
    try:
        test_support.run_unittest(*tests)
    finally:
        test_support.threading_cleanup(*thread_info)
1802
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    test_main()
1805