Provide a better manager API for Alembic.
authorBerker Peksag <berker.peksag@gmail.com>
Fri, 15 Aug 2014 12:39:45 +0000 (15:39 +0300)
committerBerker Peksag <berker.peksag@gmail.com>
Fri, 15 Aug 2014 12:39:45 +0000 (15:39 +0300)
mediagoblin/db/migration_tools.py

index 2d7b999a48048c9abfdbe31436579093b276a6ac..ab4487d29a708c24611eec375d891e060fbbf6d0 100644 (file)
 
 from __future__ import unicode_literals
 
+import logging
 import os
 
 from alembic import command
 from alembic.config import Config
+from alembic.migration import MigrationContext
 
 from mediagoblin.db.base import Base
 from mediagoblin.tools.common import simple_printer
 from sqlalchemy import Table
 from sqlalchemy.sql import select
 
+log = logging.getLogger(__name__)
+
+
 class TableAlreadyExists(Exception):
     pass
 
@@ -39,18 +44,34 @@ class AlembicMigrationManager(object):
         self.alembic_cfg = Config(alembic_cfg_path)
         self.session = session
 
+    def get_current_revision(self):
+        context = MigrationContext.configure(self.session.bind)
+        return context.get_current_revision()
+
+    def upgrade(self, version):
+        return command.upgrade(self.alembic_cfg, version or 'head')
+
+    def downgrade(self, version):
+        if isinstance(version, int) or version is None or version.isdigit():
+            version = 'base'
+        return command.downgrade(self.alembic_cfg, version)
+
+    def stamp(self, revision):
+        return command.stamp(self.alembic_cfg, revision=revision)
+
     def init_tables(self):
         Base.metadata.create_all(self.session.bind)
         # load the Alembic configuration and generate the
         # version table, "stamping" it with the most recent rev:
         command.stamp(self.alembic_cfg, 'head')
 
-    def init_or_migrate(self, version='head'):
-        # TODO(berker): Check this
-        # http://alembic.readthedocs.org/en/latest/api.html#alembic.migration.MigrationContext
-        # current_rev = context.get_current_revision()
-        # Call self.init_tables() first if current_rev is None?
-        command.upgrade(self.alembic_cfg, version)
+    def init_or_migrate(self, version=None):
+        if self.get_current_revision() is None:
+            log.info('Initializing tables and stamping it with '
+                     'the most recent migration...')
+            self.init_tables()
+        else:
+            self.upgrade(version)
 
 
 class MigrationManager(object):