#!/usr/bin/python -u import os, sys, re, tempfile from optparse import OptionParser import common from autotest_lib.client.common_lib import utils from autotest_lib.database import database_connection MIGRATE_TABLE = 'migrate_info' _AUTODIR = os.path.join(os.path.dirname(__file__), '..') _MIGRATIONS_DIRS = { 'AUTOTEST_WEB': os.path.join(_AUTODIR, 'frontend', 'migrations'), 'TKO': os.path.join(_AUTODIR, 'tko', 'migrations'), 'AUTOTEST_SERVER_DB': os.path.join(_AUTODIR, 'database', 'server_db_migrations'), } _DEFAULT_MIGRATIONS_DIR = 'migrations' # use CWD class Migration(object): """Represents a database migration.""" _UP_ATTRIBUTES = ('migrate_up', 'UP_SQL') _DOWN_ATTRIBUTES = ('migrate_down', 'DOWN_SQL') def __init__(self, name, version, module): self.name = name self.version = version self.module = module self._check_attributes(self._UP_ATTRIBUTES) self._check_attributes(self._DOWN_ATTRIBUTES) @classmethod def from_file(cls, filename): """Instantiates a Migration from a file. @param filename: Name of a migration file. @return An instantiated Migration object. """ version = int(filename[:3]) name = filename[:-3] module = __import__(name, globals(), locals(), []) return cls(name, version, module) def _check_attributes(self, attributes): method_name, sql_name = attributes assert (hasattr(self.module, method_name) or hasattr(self.module, sql_name)) def _execute_migration(self, attributes, manager): method_name, sql_name = attributes method = getattr(self.module, method_name, None) if method: assert callable(method) method(manager) else: sql = getattr(self.module, sql_name) assert isinstance(sql, basestring) manager.execute_script(sql) def migrate_up(self, manager): """Performs an up migration (to a newer version). @param manager: A MigrationManager object. """ self._execute_migration(self._UP_ATTRIBUTES, manager) def migrate_down(self, manager): """Performs a down migration (to an older version). @param manager: A MigrationManager object. """ self._execute_migration(self._DOWN_ATTRIBUTES, manager) class MigrationManager(object): """Managest database migrations.""" connection = None cursor = None migrations_dir = None def __init__(self, database_connection, migrations_dir=None, force=False): self._database = database_connection self.force = force # A boolean, this will only be set to True if this migration should be # simulated rather than actually taken. For use with migrations that # may make destructive queries self.simulate = False self._set_migrations_dir(migrations_dir) def _set_migrations_dir(self, migrations_dir=None): config_section = self._config_section() if migrations_dir is None: migrations_dir = os.path.abspath( _MIGRATIONS_DIRS.get(config_section, _DEFAULT_MIGRATIONS_DIR)) self.migrations_dir = migrations_dir sys.path.append(migrations_dir) assert os.path.exists(migrations_dir), migrations_dir + " doesn't exist" def _config_section(self): return self._database.global_config_section def get_db_name(self): """Gets the database name.""" return self._database.get_database_info()['db_name'] def execute(self, query, *parameters): """Executes a database query. @param query: The query to execute. @param parameters: Associated parameters for the query. @return The result of the query. """ return self._database.execute(query, parameters) def execute_script(self, script): """Executes a set of database queries. @param script: A string of semicolon-separated queries. """ sql_statements = [statement.strip() for statement in script.split(';') if statement.strip()] for statement in sql_statements: self.execute(statement) def check_migrate_table_exists(self): """Checks whether the migration table exists.""" try: self.execute("SELECT * FROM %s" % MIGRATE_TABLE) return True except self._database.DatabaseError, exc: # we can't check for more specifics due to differences between DB # backends (we can't even check for a subclass of DatabaseError) return False def create_migrate_table(self): """Creates the migration table.""" if not self.check_migrate_table_exists(): self.execute("CREATE TABLE %s (`version` integer)" % MIGRATE_TABLE) else: self.execute("DELETE FROM %s" % MIGRATE_TABLE) self.execute("INSERT INTO %s VALUES (0)" % MIGRATE_TABLE) assert self._database.rowcount == 1 def set_db_version(self, version): """Sets the database version. @param version: The version to which to set the database. """ assert isinstance(version, int) self.execute("UPDATE %s SET version=%%s" % MIGRATE_TABLE, version) assert self._database.rowcount == 1 def get_db_version(self): """Gets the database version. @return The database version. """ if not self.check_migrate_table_exists(): return 0 rows = self.execute("SELECT * FROM %s" % MIGRATE_TABLE) if len(rows) == 0: return 0 assert len(rows) == 1 and len(rows[0]) == 1 return rows[0][0] def get_migrations(self, minimum_version=None, maximum_version=None): """Gets the list of migrations to perform. @param minimum_version: The minimum database version. @param maximum_version: The maximum database version. @return A list of Migration objects. """ migrate_files = [filename for filename in os.listdir(self.migrations_dir) if re.match(r'^\d\d\d_.*\.py$', filename)] migrate_files.sort() migrations = [Migration.from_file(filename) for filename in migrate_files] if minimum_version is not None: migrations = [migration for migration in migrations if migration.version >= minimum_version] if maximum_version is not None: migrations = [migration for migration in migrations if migration.version <= maximum_version] return migrations def do_migration(self, migration, migrate_up=True): """Performs a migration. @param migration: The Migration to perform. @param migrate_up: Whether to migrate up (if not, then migrates down). """ print 'Applying migration %s' % migration.name, # no newline if migrate_up: print 'up' assert self.get_db_version() == migration.version - 1 migration.migrate_up(self) new_version = migration.version else: print 'down' assert self.get_db_version() == migration.version migration.migrate_down(self) new_version = migration.version - 1 self.set_db_version(new_version) def migrate_to_version(self, version): """Performs a migration to a specified version. @param version: The version to which to migrate the database. """ current_version = self.get_db_version() if current_version == 0 and self._config_section() == 'AUTOTEST_WEB': self._migrate_from_base() current_version = self.get_db_version() if current_version < version: lower, upper = current_version, version migrate_up = True else: lower, upper = version, current_version migrate_up = False migrations = self.get_migrations(lower + 1, upper) if not migrate_up: migrations.reverse() for migration in migrations: self.do_migration(migration, migrate_up) assert self.get_db_version() == version print 'At version', version def _migrate_from_base(self): """Initialize the AFE database. """ self.confirm_initialization() migration_script = utils.read_file( os.path.join(os.path.dirname(__file__), 'schema_051.sql')) migration_script = migration_script % ( dict(username=self._database.get_database_info()['username'])) self.execute_script(migration_script) self.create_migrate_table() self.set_db_version(51) def confirm_initialization(self): """Confirms with the user that we should initialize the database. @raises Exception, if the user chooses to abort the migration. """ if not self.force: response = raw_input( 'Your %s database does not appear to be initialized. Do you ' 'want to recreate it (this will result in loss of any existing ' 'data) (yes/No)? ' % self.get_db_name()) if response != 'yes': raise Exception('User has chosen to abort migration') def get_latest_version(self): """Gets the latest database version.""" migrations = self.get_migrations() return migrations[-1].version def migrate_to_latest(self): """Migrates the database to the latest version.""" latest_version = self.get_latest_version() self.migrate_to_version(latest_version) def initialize_test_db(self): """Initializes a test database.""" db_name = self.get_db_name() test_db_name = 'test_' + db_name # first, connect to no DB so we can create a test DB self._database.connect(db_name='') print 'Creating test DB', test_db_name self.execute('CREATE DATABASE ' + test_db_name) self._database.disconnect() # now connect to the test DB self._database.connect(db_name=test_db_name) def remove_test_db(self): """Removes a test database.""" print 'Removing test DB' self.execute('DROP DATABASE ' + self.get_db_name()) # reset connection back to real DB self._database.disconnect() self._database.connect() def get_mysql_args(self): """Returns the mysql arguments as a string.""" return ('-u %(username)s -p%(password)s -h %(host)s %(db_name)s' % self._database.get_database_info()) def migrate_to_version_or_latest(self, version): """Migrates to either a specified version, or the latest version. @param version: The version to which to migrate the database, or None in order to migrate to the latest version. """ if version is None: self.migrate_to_latest() else: self.migrate_to_version(version) def do_sync_db(self, version=None): """Migrates the database. @param version: The version to which to migrate the database. """ print 'Migration starting for database', self.get_db_name() self.migrate_to_version_or_latest(version) print 'Migration complete' def test_sync_db(self, version=None): """Create a fresh database and run all migrations on it. @param version: The version to which to migrate the database. """ self.initialize_test_db() try: print 'Starting migration test on DB', self.get_db_name() self.migrate_to_version_or_latest(version) # show schema to the user os.system('mysqldump %s --no-data=true ' '--add-drop-table=false' % self.get_mysql_args()) finally: self.remove_test_db() print 'Test finished successfully' def simulate_sync_db(self, version=None): """Creates a fresh DB, copies existing DB to it, then synchronizes it. @param version: The version to which to migrate the database. """ db_version = self.get_db_version() # don't do anything if we're already at the latest version if db_version == self.get_latest_version(): print 'Skipping simulation, already at latest version' return # get existing data self.initialize_and_fill_test_db() try: print 'Starting migration test on DB', self.get_db_name() self.migrate_to_version_or_latest(version) finally: self.remove_test_db() print 'Test finished successfully' def initialize_and_fill_test_db(self): """Initializes and fills up a test database.""" print 'Dumping existing data' dump_fd, dump_file = tempfile.mkstemp('.migrate_dump') os.system('mysqldump %s >%s' % (self.get_mysql_args(), dump_file)) # fill in test DB self.initialize_test_db() print 'Filling in test DB' os.system('mysql %s <%s' % (self.get_mysql_args(), dump_file)) os.close(dump_fd) os.remove(dump_file) USAGE = """\ %s [options] sync|test|simulate|safesync [version] Options: -d --database Which database to act on -f --force Don't ask for confirmation --debug Print all DB queries"""\ % sys.argv[0] def main(): """Main function for the migration script.""" parser = OptionParser() parser.add_option("-d", "--database", help="which database to act on", dest="database", default="AUTOTEST_WEB") parser.add_option("-f", "--force", help="don't ask for confirmation", action="store_true") parser.add_option('--debug', help='print all DB queries', action='store_true') (options, args) = parser.parse_args() manager = get_migration_manager(db_name=options.database, debug=options.debug, force=options.force) if len(args) > 0: if len(args) > 1: version = int(args[1]) else: version = None if args[0] == 'sync': manager.do_sync_db(version) elif args[0] == 'test': manager.simulate=True manager.test_sync_db(version) elif args[0] == 'simulate': manager.simulate=True manager.simulate_sync_db(version) elif args[0] == 'safesync': print 'Simluating migration' manager.simulate=True manager.simulate_sync_db(version) print 'Performing real migration' manager.simulate=False manager.do_sync_db(version) else: print USAGE return print USAGE def get_migration_manager(db_name, debug, force): """Creates a MigrationManager object. @param db_name: The database name. @param debug: Whether to print debug messages. @param force: Whether to force migration without asking for confirmation. @return A created MigrationManager object. """ database = database_connection.DatabaseConnection(db_name) database.debug = debug database.reconnect_enabled = False database.connect() return MigrationManager(database, force=force) if __name__ == '__main__': main()