#!/usr/bin/python

"""Unit tests for database_connection.DatabaseConnection.

Every test drives a DatabaseConnection whose backend is a mock created by
mock.mock_god, and time.sleep is stubbed so reconnect delays are recorded
as expectations instead of actually blocking.
"""

import time
import unittest

import common
from autotest_lib.client.common_lib import global_config
from autotest_lib.client.common_lib.test_utils import mock
from autotest_lib.database import database_connection

# Global-config section that _override_config() populates with the values
# below.
_CONFIG_SECTION = 'AUTOTEST_WEB'
_HOST = 'myhost'
_USER = 'myuser'
_PASS = 'mypass'
_DB_NAME = 'mydb'
_DB_TYPE = 'mydbtype'

# Keyword arguments the fake backend's connect() is expected to receive.
_CONNECT_KWARGS = dict(host=_HOST, username=_USER, password=_PASS,
                       db_name=_DB_NAME)
# Assigned to DatabaseConnection.reconnect_delay_sec; each reconnect attempt
# is expected to call time.sleep() with this value.
_RECONNECT_DELAY = 10


class FakeDatabaseError(Exception):
    """Error raised by the fake backend in place of real database errors.

    _get_database_connection() installs this class on the fake backend under
    every name listed in database_connection._DB_EXCEPTIONS.
    """


class DatabaseConnectionTest(unittest.TestCase):
    """Tests connect, retry, disconnect and execute on DatabaseConnection."""

    def setUp(self):
        self.god = mock.mock_god()
        # Stub sleep so reconnect delays don't block and can be asserted on.
        self.god.stub_function(time, 'sleep')


    def tearDown(self):
        global_config.global_config.reset_config_values()
        self.god.unstub_all()


    def _get_database_connection(self, config_section=_CONFIG_SECTION):
        """Build a DatabaseConnection wired to a mock backend.

        When config_section is the default test section, the global config
        is first overridden with the module's test constants.  The
        connection's _get_backend is stubbed so that it records the requested
        db_type in self._db_type and returns self._fake_backend.

        @param config_section: Config section passed to DatabaseConnection.
        @returns: The DatabaseConnection, with reconnect_delay_sec set to
                  _RECONNECT_DELAY.
        """
        if config_section == _CONFIG_SECTION:
            self._override_config()
        db = database_connection.DatabaseConnection(config_section)

        self._fake_backend = self.god.create_mock_class(
            database_connection._GenericBackend, 'fake_backend')
        # Expose FakeDatabaseError under each backend exception name so that
        # raising it is treated like a real database error.
        for exception in database_connection._DB_EXCEPTIONS:
            setattr(self._fake_backend, exception, FakeDatabaseError)
        self._fake_backend.rowcount = 0

        def get_fake_backend(db_type):
            self._db_type = db_type
            return self._fake_backend
        self.god.stub_with(db, '_get_backend', get_fake_backend)

        db.reconnect_delay_sec = _RECONNECT_DELAY
        return db


    def _override_config(self):
        """Point the test config section at the module's test constants."""
        c = global_config.global_config
        c.override_config_value(_CONFIG_SECTION, 'host', _HOST)
        c.override_config_value(_CONFIG_SECTION, 'user', _USER)
        c.override_config_value(_CONFIG_SECTION, 'password', _PASS)
        c.override_config_value(_CONFIG_SECTION, 'database', _DB_NAME)
        c.override_config_value(_CONFIG_SECTION, 'db_type', _DB_TYPE)


    def test_connect(self):
        """connect() with explicit arguments needs no config section."""
        db = self._get_database_connection(config_section=None)
        self._fake_backend.connect.expect_call(**_CONNECT_KWARGS)

        db.connect(db_type=_DB_TYPE, host=_HOST, username=_USER,
                   password=_PASS, db_name=_DB_NAME)

        self.assertEqual(self._db_type, _DB_TYPE)
        self.god.check_playback()


    def test_global_config(self):
        """connect() with no arguments reads its settings from global config."""
        db = self._get_database_connection()
        self._fake_backend.connect.expect_call(**_CONNECT_KWARGS)

        db.connect()

        self.assertEqual(self._db_type, _DB_TYPE)
        self.god.check_playback()


    def _expect_reconnect(self, fail=False):
        """Record one reconnect attempt: a disconnect followed by a connect.

        @param fail: If True, the connect raises FakeDatabaseError.
        """
        self._fake_backend.disconnect.expect_call()
        call = self._fake_backend.connect.expect_call(**_CONNECT_KWARGS)
        if fail:
            call.and_raises(FakeDatabaseError())


    def _expect_fail_and_reconnect(self, num_reconnects, fail_last=False):
        """Record a failing connect followed by sleep-and-reconnect cycles.

        @param num_reconnects: Number of reconnect attempts to expect; each
                               is preceded by a sleep of _RECONNECT_DELAY.
        @param fail_last: Whether the final reconnect also fails.  Every
                          reconnect before the last one always fails.
        """
        self._fake_backend.connect.expect_call(**_CONNECT_KWARGS).and_raises(
            FakeDatabaseError())
        for attempt in range(num_reconnects):
            time.sleep.expect_call(_RECONNECT_DELAY)
            is_last = (attempt == num_reconnects - 1)
            self._expect_reconnect(fail=fail_last if is_last else True)


    def test_connect_retry(self):
        """A failed connect is retried unless reconnecting is disabled."""
        db = self._get_database_connection()
        self._expect_fail_and_reconnect(1)

        db.connect()
        self.god.check_playback()

        # The connection is open now, so connect() is expected to disconnect
        # before connecting again.
        self._fake_backend.disconnect.expect_call()
        self._expect_fail_and_reconnect(0)
        self.assertRaises(FakeDatabaseError, db.connect,
                          try_reconnecting=False)
        self.god.check_playback()

        db.reconnect_enabled = False
        self._fake_backend.disconnect.expect_call()
        self._expect_fail_and_reconnect(0)
        self.assertRaises(FakeDatabaseError, db.connect)
        self.god.check_playback()


    def test_max_reconnect(self):
        """The error propagates once max_reconnect_attempts all fail."""
        db = self._get_database_connection()
        db.max_reconnect_attempts = 5
        self._expect_fail_and_reconnect(5, fail_last=True)

        self.assertRaises(FakeDatabaseError, db.connect)
        self.god.check_playback()


    def test_reconnect_forever(self):
        """With RECONNECT_FOREVER, connect() keeps retrying until success."""
        db = self._get_database_connection()
        db.max_reconnect_attempts = database_connection.RECONNECT_FOREVER
        self._expect_fail_and_reconnect(30)

        db.connect()
        self.god.check_playback()


    def _simple_connect(self, db):
        """Connect db successfully on the first attempt and verify it."""
        self._fake_backend.connect.expect_call(**_CONNECT_KWARGS)
        db.connect()
        self.god.check_playback()


    def test_disconnect(self):
        db = self._get_database_connection()
        self._simple_connect(db)
        self._fake_backend.disconnect.expect_call()

        db.disconnect()
        self.god.check_playback()


    def test_execute(self):
        """execute() forwards the query and parameters to the backend."""
        db = self._get_database_connection()
        self._simple_connect(db)
        params = object()
        self._fake_backend.execute.expect_call('query', params)

        db.execute('query', params)
        self.god.check_playback()


    def test_execute_retry(self):
        """A failed execute() reconnects and retries unless told not to."""
        db = self._get_database_connection()
        self._simple_connect(db)
        self._fake_backend.execute.expect_call('query', None).and_raises(
            FakeDatabaseError())
        self._expect_reconnect()
        self._fake_backend.execute.expect_call('query', None)

        db.execute('query')
        self.god.check_playback()

        self._fake_backend.execute.expect_call('query', None).and_raises(
            FakeDatabaseError())
        self.assertRaises(FakeDatabaseError, db.execute, 'query',
                          try_reconnecting=False)
        # Verify no reconnect was attempted and every expectation was used.
        self.god.check_playback()


if __name__ == '__main__':
    unittest.main()