-
+ 97EEBB27E6EC219B66D69A66FFE37962046B5F76A2A30A11AAD259DE4985789E61837290C9F327E22D2D04E7894DF0AEC767D67AAB4E86F7C81DF22EB162AD91
blatta/lib/caribou.py
(0 . 0)(1 . 271)
199 """
200 Caribou is a simple SQLite database migrations library, built
201 to manage the evoluton of client side databases over multiple releases
202 of an application.
203 """
204
205 from __future__ import with_statement
206
207 __author__ = 'clutchski@gmail.com'
208
209 import contextlib
210 import datetime
211 import glob
212 import imp
213 import os.path
214 import sqlite3
215 import traceback
216
217 # statics
218
219 VERSION_TABLE = 'migration_version'
220 UTC_LENGTH = 14
221
222 # errors
223
224 class Error(Exception):
225 """ Base class for all Caribou errors. """
226 pass
227
228 class InvalidMigrationError(Error):
229 """ Thrown when a client migration contains an error. """
230 pass
231
232 class InvalidNameError(Error):
233 """ Thrown when a client migration has an invalid filename. """
234
235 def __init__(self, filename):
236 msg = 'Migration filenames must start with a UTC timestamp. ' \
237 'The following file has an invalid name: %s' % filename
238 super(InvalidNameError, self).__init__(msg)
239
240 # code
241
242 @contextlib.contextmanager
243 def execute(conn, sql, params=None):
244 params = [] if params is None else params
245 cursor = conn.execute(sql, params)
246 try:
247 yield cursor
248 finally:
249 cursor.close()
250
251 @contextlib.contextmanager
252 def transaction(conn):
253 try:
254 yield
255 conn.commit()
256 except:
257 conn.rollback()
258 msg = "Error in transaction: %s" % traceback.format_exc()
259 raise Error(msg)
260
261 def has_method(an_object, method_name):
262 return hasattr(an_object, method_name) and \
263 callable(getattr(an_object, method_name))
264
265 def is_directory(path):
266 return os.path.exists(path) and os.path.isdir(path)
267
268 class Migration(object):
269 """ This class represents a migration version. """
270
271 def __init__(self, path):
272 self.path = path
273 self.filename = os.path.basename(path)
274 self.module_name, _ = os.path.splitext(self.filename)
275 self.get_version() # will assert the filename is valid
276 self.name = self.module_name[UTC_LENGTH:]
277 while self.name.startswith('_'):
278 self.name = self.name[1:]
279 try:
280 self.module = imp.load_source(self.module_name, path)
281 except:
282 msg = "Invalid migration %s: %s" % (path, traceback.format_exc())
283 raise InvalidMigrationError(msg)
284 # assert the migration has the needed methods
285 missing = [m for m in ['upgrade', 'downgrade']
286 if not has_method(self.module, m)]
287 if missing:
288 msg = 'Migration %s is missing required methods: %s.' % (
289 self.path, ', '.join(missing))
290 raise InvalidMigrationError(msg)
291
292 def get_version(self):
293 if len(self.filename) < UTC_LENGTH:
294 raise InvalidNameError(self.filename)
295 timestamp = self.filename[:UTC_LENGTH]
296 #FIXME: is this test sufficient?
297 if not timestamp.isdigit():
298 raise InvalidNameError(self.filename)
299 return timestamp
300
301 def upgrade(self, conn):
302 self.module.upgrade(conn)
303
304 def downgrade(self, conn):
305 self.module.downgrade(conn)
306
307 def __repr__(self):
308 return 'Migration(%s)' % self.filename
309
310 class Database(object):
311
312 def __init__(self, db_url):
313 self.db_url = db_url
314 self.conn = sqlite3.connect(db_url)
315
316 def close(self):
317 self.conn.close()
318
319 def is_version_controlled(self):
320 sql = """select *
321 from sqlite_master
322 where type = 'table'
323 and name = :1"""
324 with execute(self.conn, sql, [VERSION_TABLE]) as cursor:
325 return bool(cursor.fetchall())
326
327 def upgrade(self, migrations, target_version=None):
328 if target_version:
329 _assert_migration_exists(migrations, target_version)
330
331 migrations.sort(key=lambda x: x.get_version())
332 database_version = self.get_version()
333
334 for migration in migrations:
335 current_version = migration.get_version()
336 if current_version <= database_version:
337 continue
338 if target_version and current_version > target_version:
339 break
340 migration.upgrade(self.conn)
341 new_version = migration.get_version()
342 self.update_version(new_version)
343
344 def downgrade(self, migrations, target_version):
345 if target_version not in (0, '0'):
346 _assert_migration_exists(migrations, target_version)
347
348 migrations.sort(key=lambda x: x.get_version(), reverse=True)
349 database_version = self.get_version()
350
351 for i, migration in enumerate(migrations):
352 current_version = migration.get_version()
353 if current_version > database_version:
354 continue
355 if current_version <= target_version:
356 break
357 migration.downgrade(self.conn)
358 next_version = 0
359 # if an earlier migration exists, set the db version to
360 # its version number
361 if i < len(migrations) - 1:
362 next_migration = migrations[i + 1]
363 next_version = next_migration.get_version()
364 self.update_version(next_version)
365
366 def get_version(self):
367 """ Return the database's version, or None if it is not under version
368 control.
369 """
370 if not self.is_version_controlled():
371 return None
372 sql = 'select version from %s' % VERSION_TABLE
373 with execute(self.conn, sql) as cursor:
374 result = cursor.fetchall()
375 return result[0][0] if result else 0
376
377 def update_version(self, version):
378 sql = 'update %s set version = :1' % VERSION_TABLE
379 with transaction(self.conn):
380 self.conn.execute(sql, [version])
381
382 def initialize_version_control(self):
383 sql = """ create table if not exists %s
384 ( version text ) """ % VERSION_TABLE
385 with transaction(self.conn):
386 self.conn.execute(sql)
387 self.conn.execute('insert into %s values (0)' % VERSION_TABLE)
388
389 def __repr__(self):
390 return 'Database("%s")' % self.db_url
391
392 def _assert_migration_exists(migrations, version):
393 if version not in (m.get_version() for m in migrations):
394 raise Error('No migration with version %s exists.' % version)
395
396 def load_migrations(directory):
397 """ Return the migrations contained in the given directory. """
398 if not is_directory(directory):
399 msg = "%s is not a directory." % directory
400 raise Error(msg)
401 wildcard = os.path.join(directory, '*.py')
402 migration_files = glob.glob(wildcard)
403 return [Migration(f) for f in migration_files]
404
405 def upgrade(db_url, migration_dir, version=None):
406 """ Upgrade the given database with the migrations contained in the
407 migrations directory. If a version is not specified, upgrade
408 to the most recent version.
409 """
410 with contextlib.closing(Database(db_url)) as db:
411 db = Database(db_url)
412 if not db.is_version_controlled():
413 db.initialize_version_control()
414 migrations = load_migrations(migration_dir)
415 db.upgrade(migrations, version)
416
417 def downgrade(db_url, migration_dir, version):
418 """ Downgrade the database to the given version with the migrations
419 contained in the given migration directory.
420 """
421 with contextlib.closing(Database(db_url)) as db:
422 if not db.is_version_controlled():
423 msg = "The database %s is not version controlled." % (db_url)
424 raise Error(msg)
425 migrations = load_migrations(migration_dir)
426 db.downgrade(migrations, version)
427
428 def get_version(db_url):
429 """ Return the migration version of the given database. """
430 with contextlib.closing(Database(db_url)) as db:
431 return db.get_version()
432
433 def create_migration(name, directory=None):
434 """ Create a migration with the given name. If no directory is specified,
435 the current working directory will be used.
436 """
437 directory = directory if directory else '.'
438 if not is_directory(directory):
439 msg = '%s is not a directory.' % directory
440 raise Error(msg)
441
442 now = datetime.datetime.now()
443 version = now.strftime("%Y%m%d%H%M%S")
444
445 contents = MIGRATION_TEMPLATE % {'name':name, 'version':version}
446
447 name = name.replace(' ', '_')
448 filename = "%s_%s.py" % (version, name)
449 path = os.path.join(directory, filename)
450 with open(path, 'w') as migration_file:
451 migration_file.write(contents)
452 return path
453
454 MIGRATION_TEMPLATE = """\
455 \"\"\"
456 This module contains a Caribou migration.
457
458 Migration Name: %(name)s
459 Migration Version: %(version)s
460 \"\"\"
461
462 def upgrade(connection):
463 # add your upgrade step here
464 pass
465
466 def downgrade(connection):
467 # add your downgrade step here
468 pass
469 """