raw
9982-getdata            1 """
9982-getdata 2 Caribou is a simple SQLite database migrations library, built
9982-getdata 3 to manage the evoluton of client side databases over multiple releases
9982-getdata 4 of an application.
9982-getdata 5 """
9982-getdata 6
9982-getdata 7 from __future__ import with_statement
9982-getdata 8
9982-getdata 9 __author__ = 'clutchski@gmail.com'
9982-getdata 10
9982-getdata 11 import contextlib
9982-getdata 12 import datetime
9982-getdata 13 import glob
9982-getdata 14 import imp
9982-getdata 15 import os.path
9982-getdata 16 import sqlite3
9982-getdata 17 import traceback
9982-getdata 18
9982-getdata 19 # statics
9982-getdata 20
9982-getdata 21 VERSION_TABLE = 'migration_version'
9982-getdata 22 UTC_LENGTH = 14
9982-getdata 23
9982-getdata 24 # errors
9982-getdata 25
9982-getdata 26 class Error(Exception):
9982-getdata 27 """ Base class for all Caribou errors. """
9982-getdata 28 pass
9982-getdata 29
9982-getdata 30 class InvalidMigrationError(Error):
9982-getdata 31 """ Thrown when a client migration contains an error. """
9982-getdata 32 pass
9982-getdata 33
9982-getdata 34 class InvalidNameError(Error):
9982-getdata 35 """ Thrown when a client migration has an invalid filename. """
9982-getdata 36
9982-getdata 37 def __init__(self, filename):
9982-getdata 38 msg = 'Migration filenames must start with a UTC timestamp. ' \
9982-getdata 39 'The following file has an invalid name: %s' % filename
9982-getdata 40 super(InvalidNameError, self).__init__(msg)
9982-getdata 41
9982-getdata 42 # code
9982-getdata 43
9982-getdata 44 @contextlib.contextmanager
9982-getdata 45 def execute(conn, sql, params=None):
9982-getdata 46 params = [] if params is None else params
9982-getdata 47 cursor = conn.execute(sql, params)
9982-getdata 48 try:
9982-getdata 49 yield cursor
9982-getdata 50 finally:
9982-getdata 51 cursor.close()
9982-getdata 52
9982-getdata 53 @contextlib.contextmanager
9982-getdata 54 def transaction(conn):
9982-getdata 55 try:
9982-getdata 56 yield
9982-getdata 57 conn.commit()
9982-getdata 58 except:
9982-getdata 59 conn.rollback()
9982-getdata 60 msg = "Error in transaction: %s" % traceback.format_exc()
9982-getdata 61 raise Error(msg)
9982-getdata 62
9982-getdata 63 def has_method(an_object, method_name):
9982-getdata 64 return hasattr(an_object, method_name) and \
9982-getdata 65 callable(getattr(an_object, method_name))
9982-getdata 66
9982-getdata 67 def is_directory(path):
9982-getdata 68 return os.path.exists(path) and os.path.isdir(path)
9982-getdata 69
9982-getdata 70 class Migration(object):
9982-getdata 71 """ This class represents a migration version. """
9982-getdata 72
9982-getdata 73 def __init__(self, path):
9982-getdata 74 self.path = path
9982-getdata 75 self.filename = os.path.basename(path)
9982-getdata 76 self.module_name, _ = os.path.splitext(self.filename)
9982-getdata 77 self.get_version() # will assert the filename is valid
9982-getdata 78 self.name = self.module_name[UTC_LENGTH:]
9982-getdata 79 while self.name.startswith('_'):
9982-getdata 80 self.name = self.name[1:]
9982-getdata 81 try:
9982-getdata 82 self.module = imp.load_source(self.module_name, path)
9982-getdata 83 except:
9982-getdata 84 msg = "Invalid migration %s: %s" % (path, traceback.format_exc())
9982-getdata 85 raise InvalidMigrationError(msg)
9982-getdata 86 # assert the migration has the needed methods
9982-getdata 87 missing = [m for m in ['upgrade', 'downgrade']
9982-getdata 88 if not has_method(self.module, m)]
9982-getdata 89 if missing:
9982-getdata 90 msg = 'Migration %s is missing required methods: %s.' % (
9982-getdata 91 self.path, ', '.join(missing))
9982-getdata 92 raise InvalidMigrationError(msg)
9982-getdata 93
9982-getdata 94 def get_version(self):
9982-getdata 95 if len(self.filename) < UTC_LENGTH:
9982-getdata 96 raise InvalidNameError(self.filename)
9982-getdata 97 timestamp = self.filename[:UTC_LENGTH]
9982-getdata 98 #FIXME: is this test sufficient?
9982-getdata 99 if not timestamp.isdigit():
9982-getdata 100 raise InvalidNameError(self.filename)
9982-getdata 101 return timestamp
9982-getdata 102
9982-getdata 103 def upgrade(self, conn):
9982-getdata 104 self.module.upgrade(conn)
9982-getdata 105
9982-getdata 106 def downgrade(self, conn):
9982-getdata 107 self.module.downgrade(conn)
9982-getdata 108
9982-getdata 109 def __repr__(self):
9982-getdata 110 return 'Migration(%s)' % self.filename
9982-getdata 111
9982-getdata 112 class Database(object):
9982-getdata 113
9982-getdata 114 def __init__(self, db_url):
9982-getdata 115 self.db_url = db_url
9982-getdata 116 self.conn = sqlite3.connect(db_url)
9982-getdata 117
9982-getdata 118 def close(self):
9982-getdata 119 self.conn.close()
9982-getdata 120
9982-getdata 121 def is_version_controlled(self):
9982-getdata 122 sql = """select *
9982-getdata 123 from sqlite_master
9982-getdata 124 where type = 'table'
9982-getdata 125 and name = :1"""
9982-getdata 126 with execute(self.conn, sql, [VERSION_TABLE]) as cursor:
9982-getdata 127 return bool(cursor.fetchall())
9982-getdata 128
9982-getdata 129 def upgrade(self, migrations, target_version=None):
9982-getdata 130 if target_version:
9982-getdata 131 _assert_migration_exists(migrations, target_version)
9982-getdata 132
9982-getdata 133 migrations.sort(key=lambda x: x.get_version())
9982-getdata 134 database_version = self.get_version()
9982-getdata 135
9982-getdata 136 for migration in migrations:
9982-getdata 137 current_version = migration.get_version()
9982-getdata 138 if current_version <= database_version:
9982-getdata 139 continue
9982-getdata 140 if target_version and current_version > target_version:
9982-getdata 141 break
9982-getdata 142 migration.upgrade(self.conn)
9982-getdata 143 new_version = migration.get_version()
9982-getdata 144 self.update_version(new_version)
9982-getdata 145
9982-getdata 146 def downgrade(self, migrations, target_version):
9982-getdata 147 if target_version not in (0, '0'):
9982-getdata 148 _assert_migration_exists(migrations, target_version)
9982-getdata 149
9982-getdata 150 migrations.sort(key=lambda x: x.get_version(), reverse=True)
9982-getdata 151 database_version = self.get_version()
9982-getdata 152
9982-getdata 153 for i, migration in enumerate(migrations):
9982-getdata 154 current_version = migration.get_version()
9982-getdata 155 if current_version > database_version:
9982-getdata 156 continue
9982-getdata 157 if current_version <= target_version:
9982-getdata 158 break
9982-getdata 159 migration.downgrade(self.conn)
9982-getdata 160 next_version = 0
9982-getdata 161 # if an earlier migration exists, set the db version to
9982-getdata 162 # its version number
9982-getdata 163 if i < len(migrations) - 1:
9982-getdata 164 next_migration = migrations[i + 1]
9982-getdata 165 next_version = next_migration.get_version()
9982-getdata 166 self.update_version(next_version)
9982-getdata 167
9982-getdata 168 def get_version(self):
9982-getdata 169 """ Return the database's version, or None if it is not under version
9982-getdata 170 control.
9982-getdata 171 """
9982-getdata 172 if not self.is_version_controlled():
9982-getdata 173 return None
9982-getdata 174 sql = 'select version from %s' % VERSION_TABLE
9982-getdata 175 with execute(self.conn, sql) as cursor:
9982-getdata 176 result = cursor.fetchall()
9982-getdata 177 return result[0][0] if result else 0
9982-getdata 178
9982-getdata 179 def update_version(self, version):
9982-getdata 180 sql = 'update %s set version = :1' % VERSION_TABLE
9982-getdata 181 with transaction(self.conn):
9982-getdata 182 self.conn.execute(sql, [version])
9982-getdata 183
9982-getdata 184 def initialize_version_control(self):
9982-getdata 185 sql = """ create table if not exists %s
9982-getdata 186 ( version text ) """ % VERSION_TABLE
9982-getdata 187 with transaction(self.conn):
9982-getdata 188 self.conn.execute(sql)
9982-getdata 189 self.conn.execute('insert into %s values (0)' % VERSION_TABLE)
9982-getdata 190
9982-getdata 191 def __repr__(self):
9982-getdata 192 return 'Database("%s")' % self.db_url
9982-getdata 193
9982-getdata 194 def _assert_migration_exists(migrations, version):
9982-getdata 195 if version not in (m.get_version() for m in migrations):
9982-getdata 196 raise Error('No migration with version %s exists.' % version)
9982-getdata 197
9982-getdata 198 def load_migrations(directory):
9982-getdata 199 """ Return the migrations contained in the given directory. """
9982-getdata 200 if not is_directory(directory):
9982-getdata 201 msg = "%s is not a directory." % directory
9982-getdata 202 raise Error(msg)
9982-getdata 203 wildcard = os.path.join(directory, '*.py')
9982-getdata 204 migration_files = glob.glob(wildcard)
9982-getdata 205 return [Migration(f) for f in migration_files]
9982-getdata 206
9982-getdata 207 def upgrade(db_url, migration_dir, version=None):
9982-getdata 208 """ Upgrade the given database with the migrations contained in the
9982-getdata 209 migrations directory. If a version is not specified, upgrade
9982-getdata 210 to the most recent version.
9982-getdata 211 """
9982-getdata 212 with contextlib.closing(Database(db_url)) as db:
9982-getdata 213 db = Database(db_url)
9982-getdata 214 if not db.is_version_controlled():
9982-getdata 215 db.initialize_version_control()
9982-getdata 216 migrations = load_migrations(migration_dir)
9982-getdata 217 db.upgrade(migrations, version)
9982-getdata 218
9982-getdata 219 def downgrade(db_url, migration_dir, version):
9982-getdata 220 """ Downgrade the database to the given version with the migrations
9982-getdata 221 contained in the given migration directory.
9982-getdata 222 """
9982-getdata 223 with contextlib.closing(Database(db_url)) as db:
9982-getdata 224 if not db.is_version_controlled():
9982-getdata 225 msg = "The database %s is not version controlled." % (db_url)
9982-getdata 226 raise Error(msg)
9982-getdata 227 migrations = load_migrations(migration_dir)
9982-getdata 228 db.downgrade(migrations, version)
9982-getdata 229
9982-getdata 230 def get_version(db_url):
9982-getdata 231 """ Return the migration version of the given database. """
9982-getdata 232 with contextlib.closing(Database(db_url)) as db:
9982-getdata 233 return db.get_version()
9982-getdata 234
9982-getdata 235 def create_migration(name, directory=None):
9982-getdata 236 """ Create a migration with the given name. If no directory is specified,
9982-getdata 237 the current working directory will be used.
9982-getdata 238 """
9982-getdata 239 directory = directory if directory else '.'
9982-getdata 240 if not is_directory(directory):
9982-getdata 241 msg = '%s is not a directory.' % directory
9982-getdata 242 raise Error(msg)
9982-getdata 243
9982-getdata 244 now = datetime.datetime.now()
9982-getdata 245 version = now.strftime("%Y%m%d%H%M%S")
9982-getdata 246
9982-getdata 247 contents = MIGRATION_TEMPLATE % {'name':name, 'version':version}
9982-getdata 248
9982-getdata 249 name = name.replace(' ', '_')
9982-getdata 250 filename = "%s_%s.py" % (version, name)
9982-getdata 251 path = os.path.join(directory, filename)
9982-getdata 252 with open(path, 'w') as migration_file:
9982-getdata 253 migration_file.write(contents)
9982-getdata 254 return path
9982-getdata 255
9982-getdata 256 MIGRATION_TEMPLATE = """\
9982-getdata 257 \"\"\"
9982-getdata 258 This module contains a Caribou migration.
9982-getdata 259
9982-getdata 260 Migration Name: %(name)s
9982-getdata 261 Migration Version: %(version)s
9982-getdata 262 \"\"\"
9982-getdata 263
9982-getdata 264 def upgrade(connection):
9982-getdata 265 # add your upgrade step here
9982-getdata 266 pass
9982-getdata 267
9982-getdata 268 def downgrade(connection):
9982-getdata 269 # add your downgrade step here
9982-getdata 270 pass
9982-getdata 271 """