#!/usr/bin/python
# Copyright 2017 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import os
import shutil
import socket
import tempfile
import unittest

import common
from autotest_lib.site_utils import lxc
from autotest_lib.site_utils.lxc import unittest_setup
from autotest_lib.site_utils.lxc.container_pool import async_listener
from autotest_lib.site_utils.lxc.container_pool import client


# Timeout for tests.
TIMEOUT = 30


class ClientTests(unittest.TestCase):
    """Unit tests for the Client class."""

    @classmethod
    def setUpClass(cls):
        """Creates a directory for running the unit tests."""
        # Explicitly use /tmp as the tmpdir. Board specific TMPDIRs inside of
        # the chroot are set to a path that causes the socket address to exceed
        # the maximum allowable length.
        cls.test_dir = tempfile.mkdtemp(prefix='client_unittest_', dir='/tmp')


    @classmethod
    def tearDownClass(cls):
        """Deletes the test directory."""
        shutil.rmtree(cls.test_dir)


    def setUp(self):
        """Per-test setup."""
        # Put each test in its own test dir, so it's hermetic.
        self.test_dir = tempfile.mkdtemp(dir=ClientTests.test_dir)
        self.address = os.path.join(self.test_dir,
                                    lxc.DEFAULT_CONTAINER_POOL_SOCKET)
        self.listener = async_listener.AsyncListener(self.address)
        self.listener.start()


    def tearDown(self):
        """Per-test teardown: closes the listener started in setUp."""
        self.listener.close()


    def testConnection(self):
        """Tests a basic client connection."""
        # Verify that no connections are pending.
        self.assertIsNone(self.listener.get_connection())

        # Connect a client, then verify that the host connection is established.
        host = None
        with client.Client.connect(self.address, TIMEOUT):
            host = self.listener.get_connection(TIMEOUT)
            self.assertIsNotNone(host)
            # Close the host end of the connection when the test finishes so
            # it is not leaked, even if a later assertion fails.
            self.addCleanup(host.close)

        # Client closed - check that the host connection also closed.
        self.assertTrue(host.poll(TIMEOUT))
        with self.assertRaises(EOFError):
            host.recv()


    def testConnection_badAddress(self):
        """Tests that connecting to a bad address fails."""
        # Make a bogus address, then assert that the client fails.
        address = '%s.foobar' % self.address
        with self.assertRaises(socket.error):
            client.Client(address, 0)


    def testConnection_timeout(self):
        """Tests that connection attempts time out properly."""
        # The temp file exists at a path but is not a listening socket.
        with tempfile.NamedTemporaryFile(dir=self.test_dir) as tmp:
            with self.assertRaises(socket.timeout):
                client.Client(tmp.name, 0)


    def testConnection_deadLine(self):
        """Tests that get_container returns None when nothing answers.

        The listener in this test never services the request, so the
        get_container call is expected to give up after its timeout.
        """
        container_id = 3
        # Use a timeout shorter than the one the client connected with.
        # Floor division keeps this an integer on both Python 2 and 3.
        short_timeout = TIMEOUT // 2
        with client.Client.connect(self.address, TIMEOUT) as c:
            self.assertIsNone(c.get_container(container_id, short_timeout))


if __name__ == '__main__':
    unittest_setup.setup(require_sudo=False)
    unittest.main()