1#!/usr/bin/python -u 2 3import os, sys, re, tempfile 4from optparse import OptionParser 5import common 6from autotest_lib.client.common_lib import utils 7from autotest_lib.database import database_connection 8 9MIGRATE_TABLE = 'migrate_info' 10 11_AUTODIR = os.path.join(os.path.dirname(__file__), '..') 12_MIGRATIONS_DIRS = { 13 'AUTOTEST_WEB': os.path.join(_AUTODIR, 'frontend', 'migrations'), 14 'TKO': os.path.join(_AUTODIR, 'tko', 'migrations'), 15 'AUTOTEST_SERVER_DB': os.path.join(_AUTODIR, 'database', 16 'server_db_migrations'), 17} 18_DEFAULT_MIGRATIONS_DIR = 'migrations' # use CWD 19 20class Migration(object): 21 """Represents a database migration.""" 22 _UP_ATTRIBUTES = ('migrate_up', 'UP_SQL') 23 _DOWN_ATTRIBUTES = ('migrate_down', 'DOWN_SQL') 24 25 def __init__(self, name, version, module): 26 self.name = name 27 self.version = version 28 self.module = module 29 self._check_attributes(self._UP_ATTRIBUTES) 30 self._check_attributes(self._DOWN_ATTRIBUTES) 31 32 33 @classmethod 34 def from_file(cls, filename): 35 """Instantiates a Migration from a file. 36 37 @param filename: Name of a migration file. 38 39 @return An instantiated Migration object. 40 41 """ 42 version = int(filename[:3]) 43 name = filename[:-3] 44 module = __import__(name, globals(), locals(), []) 45 return cls(name, version, module) 46 47 48 def _check_attributes(self, attributes): 49 method_name, sql_name = attributes 50 assert (hasattr(self.module, method_name) or 51 hasattr(self.module, sql_name)) 52 53 54 def _execute_migration(self, attributes, manager): 55 method_name, sql_name = attributes 56 method = getattr(self.module, method_name, None) 57 if method: 58 assert callable(method) 59 method(manager) 60 else: 61 sql = getattr(self.module, sql_name) 62 assert isinstance(sql, basestring) 63 manager.execute_script(sql) 64 65 66 def migrate_up(self, manager): 67 """Performs an up migration (to a newer version). 68 69 @param manager: A MigrationManager object. 70 71 """ 72 self._execute_migration(self._UP_ATTRIBUTES, manager) 73 74 75 def migrate_down(self, manager): 76 """Performs a down migration (to an older version). 77 78 @param manager: A MigrationManager object. 79 80 """ 81 self._execute_migration(self._DOWN_ATTRIBUTES, manager) 82 83 84class MigrationManager(object): 85 """Managest database migrations.""" 86 connection = None 87 cursor = None 88 migrations_dir = None 89 90 def __init__(self, database_connection, migrations_dir=None, force=False): 91 self._database = database_connection 92 self.force = force 93 # A boolean, this will only be set to True if this migration should be 94 # simulated rather than actually taken. For use with migrations that 95 # may make destructive queries 96 self.simulate = False 97 self._set_migrations_dir(migrations_dir) 98 99 100 def _set_migrations_dir(self, migrations_dir=None): 101 config_section = self._config_section() 102 if migrations_dir is None: 103 migrations_dir = os.path.abspath( 104 _MIGRATIONS_DIRS.get(config_section, _DEFAULT_MIGRATIONS_DIR)) 105 self.migrations_dir = migrations_dir 106 sys.path.append(migrations_dir) 107 assert os.path.exists(migrations_dir), migrations_dir + " doesn't exist" 108 109 110 def _config_section(self): 111 return self._database.global_config_section 112 113 114 def get_db_name(self): 115 """Gets the database name.""" 116 return self._database.get_database_info()['db_name'] 117 118 119 def execute(self, query, *parameters): 120 """Executes a database query. 121 122 @param query: The query to execute. 123 @param parameters: Associated parameters for the query. 124 125 @return The result of the query. 126 127 """ 128 return self._database.execute(query, parameters) 129 130 131 def execute_script(self, script): 132 """Executes a set of database queries. 133 134 @param script: A string of semicolon-separated queries. 135 136 """ 137 sql_statements = [statement.strip() 138 for statement in script.split(';') 139 if statement.strip()] 140 for statement in sql_statements: 141 self.execute(statement) 142 143 144 def check_migrate_table_exists(self): 145 """Checks whether the migration table exists.""" 146 try: 147 self.execute("SELECT * FROM %s" % MIGRATE_TABLE) 148 return True 149 except self._database.DatabaseError, exc: 150 # we can't check for more specifics due to differences between DB 151 # backends (we can't even check for a subclass of DatabaseError) 152 return False 153 154 155 def create_migrate_table(self): 156 """Creates the migration table.""" 157 if not self.check_migrate_table_exists(): 158 self.execute("CREATE TABLE %s (`version` integer)" % 159 MIGRATE_TABLE) 160 else: 161 self.execute("DELETE FROM %s" % MIGRATE_TABLE) 162 self.execute("INSERT INTO %s VALUES (0)" % MIGRATE_TABLE) 163 assert self._database.rowcount == 1 164 165 166 def set_db_version(self, version): 167 """Sets the database version. 168 169 @param version: The version to which to set the database. 170 171 """ 172 assert isinstance(version, int) 173 self.execute("UPDATE %s SET version=%%s" % MIGRATE_TABLE, 174 version) 175 assert self._database.rowcount == 1 176 177 178 def get_db_version(self): 179 """Gets the database version. 180 181 @return The database version. 182 183 """ 184 if not self.check_migrate_table_exists(): 185 return 0 186 rows = self.execute("SELECT * FROM %s" % MIGRATE_TABLE) 187 if len(rows) == 0: 188 return 0 189 assert len(rows) == 1 and len(rows[0]) == 1 190 return rows[0][0] 191 192 193 def get_migrations(self, minimum_version=None, maximum_version=None): 194 """Gets the list of migrations to perform. 195 196 @param minimum_version: The minimum database version. 197 @param maximum_version: The maximum database version. 198 199 @return A list of Migration objects. 200 201 """ 202 migrate_files = [filename for filename 203 in os.listdir(self.migrations_dir) 204 if re.match(r'^\d\d\d_.*\.py$', filename)] 205 migrate_files.sort() 206 migrations = [Migration.from_file(filename) 207 for filename in migrate_files] 208 if minimum_version is not None: 209 migrations = [migration for migration in migrations 210 if migration.version >= minimum_version] 211 if maximum_version is not None: 212 migrations = [migration for migration in migrations 213 if migration.version <= maximum_version] 214 return migrations 215 216 217 def do_migration(self, migration, migrate_up=True): 218 """Performs a migration. 219 220 @param migration: The Migration to perform. 221 @param migrate_up: Whether to migrate up (if not, then migrates down). 222 223 """ 224 print 'Applying migration %s' % migration.name, # no newline 225 if migrate_up: 226 print 'up' 227 assert self.get_db_version() == migration.version - 1 228 migration.migrate_up(self) 229 new_version = migration.version 230 else: 231 print 'down' 232 assert self.get_db_version() == migration.version 233 migration.migrate_down(self) 234 new_version = migration.version - 1 235 self.set_db_version(new_version) 236 237 238 def migrate_to_version(self, version): 239 """Performs a migration to a specified version. 240 241 @param version: The version to which to migrate the database. 242 243 """ 244 current_version = self.get_db_version() 245 if current_version == 0 and self._config_section() == 'AUTOTEST_WEB': 246 self._migrate_from_base() 247 current_version = self.get_db_version() 248 249 if current_version < version: 250 lower, upper = current_version, version 251 migrate_up = True 252 else: 253 lower, upper = version, current_version 254 migrate_up = False 255 256 migrations = self.get_migrations(lower + 1, upper) 257 if not migrate_up: 258 migrations.reverse() 259 for migration in migrations: 260 self.do_migration(migration, migrate_up) 261 262 assert self.get_db_version() == version 263 print 'At version', version 264 265 266 def _migrate_from_base(self): 267 """Initialize the AFE database. 268 """ 269 self.confirm_initialization() 270 271 migration_script = utils.read_file( 272 os.path.join(os.path.dirname(__file__), 'schema_051.sql')) 273 migration_script = migration_script % ( 274 dict(username=self._database.get_database_info()['username'])) 275 self.execute_script(migration_script) 276 277 self.create_migrate_table() 278 self.set_db_version(51) 279 280 281 def confirm_initialization(self): 282 """Confirms with the user that we should initialize the database. 283 284 @raises Exception, if the user chooses to abort the migration. 285 286 """ 287 if not self.force: 288 response = raw_input( 289 'Your %s database does not appear to be initialized. Do you ' 290 'want to recreate it (this will result in loss of any existing ' 291 'data) (yes/No)? ' % self.get_db_name()) 292 if response != 'yes': 293 raise Exception('User has chosen to abort migration') 294 295 296 def get_latest_version(self): 297 """Gets the latest database version.""" 298 migrations = self.get_migrations() 299 return migrations[-1].version 300 301 302 def migrate_to_latest(self): 303 """Migrates the database to the latest version.""" 304 latest_version = self.get_latest_version() 305 self.migrate_to_version(latest_version) 306 307 308 def initialize_test_db(self): 309 """Initializes a test database.""" 310 db_name = self.get_db_name() 311 test_db_name = 'test_' + db_name 312 # first, connect to no DB so we can create a test DB 313 self._database.connect(db_name='') 314 print 'Creating test DB', test_db_name 315 self.execute('CREATE DATABASE ' + test_db_name) 316 self._database.disconnect() 317 # now connect to the test DB 318 self._database.connect(db_name=test_db_name) 319 320 321 def remove_test_db(self): 322 """Removes a test database.""" 323 print 'Removing test DB' 324 self.execute('DROP DATABASE ' + self.get_db_name()) 325 # reset connection back to real DB 326 self._database.disconnect() 327 self._database.connect() 328 329 330 def get_mysql_args(self): 331 """Returns the mysql arguments as a string.""" 332 return ('-u %(username)s -p%(password)s -h %(host)s %(db_name)s' % 333 self._database.get_database_info()) 334 335 336 def migrate_to_version_or_latest(self, version): 337 """Migrates to either a specified version, or the latest version. 338 339 @param version: The version to which to migrate the database, 340 or None in order to migrate to the latest version. 341 342 """ 343 if version is None: 344 self.migrate_to_latest() 345 else: 346 self.migrate_to_version(version) 347 348 349 def do_sync_db(self, version=None): 350 """Migrates the database. 351 352 @param version: The version to which to migrate the database. 353 354 """ 355 print 'Migration starting for database', self.get_db_name() 356 self.migrate_to_version_or_latest(version) 357 print 'Migration complete' 358 359 360 def test_sync_db(self, version=None): 361 """Create a fresh database and run all migrations on it. 362 363 @param version: The version to which to migrate the database. 364 365 """ 366 self.initialize_test_db() 367 try: 368 print 'Starting migration test on DB', self.get_db_name() 369 self.migrate_to_version_or_latest(version) 370 # show schema to the user 371 os.system('mysqldump %s --no-data=true ' 372 '--add-drop-table=false' % 373 self.get_mysql_args()) 374 finally: 375 self.remove_test_db() 376 print 'Test finished successfully' 377 378 379 def simulate_sync_db(self, version=None): 380 """Creates a fresh DB, copies existing DB to it, then synchronizes it. 381 382 @param version: The version to which to migrate the database. 383 384 """ 385 db_version = self.get_db_version() 386 # don't do anything if we're already at the latest version 387 if db_version == self.get_latest_version(): 388 print 'Skipping simulation, already at latest version' 389 return 390 # get existing data 391 self.initialize_and_fill_test_db() 392 try: 393 print 'Starting migration test on DB', self.get_db_name() 394 self.migrate_to_version_or_latest(version) 395 finally: 396 self.remove_test_db() 397 print 'Test finished successfully' 398 399 400 def initialize_and_fill_test_db(self): 401 """Initializes and fills up a test database.""" 402 print 'Dumping existing data' 403 dump_fd, dump_file = tempfile.mkstemp('.migrate_dump') 404 os.system('mysqldump %s >%s' % 405 (self.get_mysql_args(), dump_file)) 406 # fill in test DB 407 self.initialize_test_db() 408 print 'Filling in test DB' 409 os.system('mysql %s <%s' % (self.get_mysql_args(), dump_file)) 410 os.close(dump_fd) 411 os.remove(dump_file) 412 413 414USAGE = """\ 415%s [options] sync|test|simulate|safesync [version] 416Options: 417 -d --database Which database to act on 418 -f --force Don't ask for confirmation 419 --debug Print all DB queries"""\ 420 % sys.argv[0] 421 422 423def main(): 424 """Main function for the migration script.""" 425 parser = OptionParser() 426 parser.add_option("-d", "--database", 427 help="which database to act on", 428 dest="database", 429 default="AUTOTEST_WEB") 430 parser.add_option("-f", "--force", help="don't ask for confirmation", 431 action="store_true") 432 parser.add_option('--debug', help='print all DB queries', 433 action='store_true') 434 (options, args) = parser.parse_args() 435 manager = get_migration_manager(db_name=options.database, 436 debug=options.debug, force=options.force) 437 438 if len(args) > 0: 439 if len(args) > 1: 440 version = int(args[1]) 441 else: 442 version = None 443 if args[0] == 'sync': 444 manager.do_sync_db(version) 445 elif args[0] == 'test': 446 manager.simulate=True 447 manager.test_sync_db(version) 448 elif args[0] == 'simulate': 449 manager.simulate=True 450 manager.simulate_sync_db(version) 451 elif args[0] == 'safesync': 452 print 'Simluating migration' 453 manager.simulate=True 454 manager.simulate_sync_db(version) 455 print 'Performing real migration' 456 manager.simulate=False 457 manager.do_sync_db(version) 458 else: 459 print USAGE 460 return 461 462 print USAGE 463 464 465def get_migration_manager(db_name, debug, force): 466 """Creates a MigrationManager object. 467 468 @param db_name: The database name. 469 @param debug: Whether to print debug messages. 470 @param force: Whether to force migration without asking for confirmation. 471 472 @return A created MigrationManager object. 473 474 """ 475 database = database_connection.DatabaseConnection(db_name) 476 database.debug = debug 477 database.reconnect_enabled = False 478 database.connect() 479 return MigrationManager(database, force=force) 480 481 482if __name__ == '__main__': 483 main() 484