Fix migrations on Python 2.
[mediagoblin.git] / mediagoblin / db / migration_tools.py
CommitLineData
a050e776
E
1# GNU MediaGoblin -- federated, autonomous media hosting
2# Copyright (C) 2011, 2012 MediaGoblin contributors. See AUTHORS.
3#
4# This program is free software: you can redistribute it and/or modify
5# it under the terms of the GNU Affero General Public License as published by
6# the Free Software Foundation, either version 3 of the License, or
7# (at your option) any later version.
8#
9# This program is distributed in the hope that it will be useful,
10# but WITHOUT ANY WARRANTY; without even the implied warranty of
11# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12# GNU Affero General Public License for more details.
13#
14# You should have received a copy of the GNU Affero General Public License
15# along with this program. If not, see <http://www.gnu.org/licenses/>.
16
7f342c72
BP
17from __future__ import unicode_literals
18
de51eca5 19import logging
65f20ca4
BP
20import os
21
22from alembic import command
23from alembic.config import Config
de51eca5 24from alembic.migration import MigrationContext
65f20ca4
BP
25
26from mediagoblin.db.base import Base
a050e776 27from mediagoblin.tools.common import simple_printer
c4466cb4 28from sqlalchemy import Table
e5196ff0 29from sqlalchemy.sql import select
a050e776 30
de51eca5
BP
# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)
32
33
7e4a87dc
CAW
class TableAlreadyExists(Exception):
    """Raised by MigrationManager.init_tables when a table it is about to
    create is already present in the database."""
36
a050e776 37
65f20ca4
BP
class AlembicMigrationManager(object):
    """
    Thin wrapper around Alembic's ``command`` API.

    Loads ``alembic.ini`` from the project root (three directory levels
    above this module) and offers helpers to read, upgrade, downgrade and
    stamp the schema revision of the database bound to ``session``.
    """

    def __init__(self, session):
        # dirname(__file__) is .../mediagoblin/db; two more levels up is
        # the directory that holds alembic.ini.
        here = os.path.dirname(__file__)
        root_dir = os.path.abspath(os.path.dirname(os.path.dirname(here)))
        self.alembic_cfg = Config(os.path.join(root_dir, 'alembic.ini'))
        self.session = session

    def get_current_revision(self):
        """Return the Alembic revision recorded in the database (or None)."""
        return MigrationContext.configure(
            self.session.bind).get_current_revision()

    def upgrade(self, version):
        """Upgrade to ``version``; a falsy ``version`` means 'head'."""
        target = version if version else 'head'
        return command.upgrade(self.alembic_cfg, target)

    def downgrade(self, version):
        """
        Downgrade to ``version``.

        None, ints and all-digit strings are not treated as Alembic
        revision ids; they all mean "go all the way down to 'base'".
        """
        numeric = (version is None or isinstance(version, int)
                   or version.isdigit())
        target = 'base' if numeric else version
        return command.downgrade(self.alembic_cfg, target)

    def stamp(self, revision):
        """Record ``revision`` via Alembic's ``stamp`` command."""
        return command.stamp(self.alembic_cfg, revision=revision)

    def init_tables(self):
        """Create every table known to ``Base`` and stamp it at 'head'."""
        Base.metadata.create_all(self.session.bind)
        # A freshly created schema corresponds to the newest revision, so
        # generate the version table and "stamp" it with 'head'.
        command.stamp(self.alembic_cfg, 'head')

    def init_or_migrate(self, version=None):
        """
        Initialize a fresh database, or upgrade an existing one.

        With no recorded revision the tables are created and stamped at
        'head'; otherwise ``upgrade(version)`` runs.
        """
        if self.get_current_revision() is not None:
            self.upgrade(version)
            return
        log.info('Initializing tables and stamping it with '
                 'the most recent migration...')
        self.init_tables()
65f20ca4
BP
75
76
a050e776
E
class MigrationManager(object):
    """
    Migration handling tool.

    Takes information about a database, lets you update the database
    to the latest migrations, etc.
    """

    def __init__(self, name, models, foundations, migration_registry, session,
                 printer=simple_printer):
        """
        Args:
         - name: identifier of this section of the database; also the
           ``name`` of this section's row in the MigrationData table
         - models: model classes whose tables init_tables() creates
         - foundations: mapping of model class -> list of keyword dicts;
           populate_table_foundations() inserts one row per dict
         - migration_registry: where we should find all migrations to
           run (mapping of migration number -> migration function)
         - session: session we're going to migrate
         - printer: callable used for progress output
        """
        self.name = name
        self.models = models
        self.foundations = foundations
        self.session = session
        self.migration_registry = migration_registry
        self._sorted_migrations = None
        self.printer = printer

        # For convenience
        from mediagoblin.db.models import MigrationData

        self.migration_model = MigrationData
        self.migration_table = MigrationData.__table__

    @property
    def sorted_migrations(self):
        """
        Sort migrations if necessary and store in self._sorted_migrations
        """
        if not self._sorted_migrations:
            self._sorted_migrations = sorted(
                self.migration_registry.items(),
                # sort on the key... the migration number
                key=lambda migration_tuple: migration_tuple[0])

        return self._sorted_migrations

    @property
    def migration_data(self):
        """
        Get the migration row associated with this object, if any.
        """
        return self.session.query(
            self.migration_model).filter_by(name=self.name).first()

    @property
    def latest_migration(self):
        """
        Return a migration number for the latest migration, or 0 if
        there are no migrations.
        """
        if self.sorted_migrations:
            return self.sorted_migrations[-1][0]
        else:
            # If no migrations have been set, we start at 0.
            return 0

    @property
    def database_current_migration(self):
        """
        Return the current migration in the database.

        None if the migrations table or this section's row is missing.
        """
        # If the table doesn't even exist, return None.
        if not self.migration_table.exists(self.session.bind):
            return None

        # migration_data runs a query on every access; fetch it only once.
        migration_data = self.migration_data
        if migration_data is None:
            return None

        return migration_data.version

    def set_current_migration(self, migration_number=None):
        """
        Set the migration in the database to migration_number
        (or, the latest available) and commit.

        Requires that this section's migration record already exists.
        """
        self.migration_data.version = migration_number or self.latest_migration
        self.session.commit()

    def migrations_to_run(self):
        """
        Get a list of (number, function) migrations still to run, if any.

        Note that this will fail if there's no migration record for
        this class!
        """
        # Evaluate once: the property hits the database on every access.
        db_current_migration = self.database_current_migration
        assert db_current_migration is not None

        return [
            (migration_number, migration_func)
            for migration_number, migration_func in self.sorted_migrations
            if migration_number > db_current_migration]

    def init_tables(self):
        """
        Create all tables relative to this package.

        Raises TableAlreadyExists if any of them is already present.
        """
        # sanity check before we proceed, none of these should be created
        for model in self.models:
            # Maybe in the future just print out a "Yikes!" or something?
            if model.__table__.exists(self.session.bind):
                raise TableAlreadyExists(
                    u"Intended to create table '%s' but it already exists" %
                    model.__table__.name)

        self.migration_model.metadata.create_all(
            self.session.bind,
            tables=[model.__table__ for model in self.models])

    def populate_table_foundations(self):
        """
        Create the table foundations (default rows) as laid out in
        self.foundations. Rows are added to the session, not committed.
        """
        for Model, rows in self.foundations.items():
            self.printer(u' + Laying foundations for %s table\n' %
                         (Model.__name__))
            for parameters in rows:
                new_row = Model(**parameters)
                self.session.add(new_row)

    def create_new_migration_record(self):
        """
        Create a new migration record for this migration set, at the
        latest migration number, and commit it.
        """
        migration_record = self.migration_model(
            name=self.name,
            version=self.latest_migration)
        self.session.add(migration_record)
        self.session.commit()

    def dry_run(self):
        """
        Print out a dry run of what we would have upgraded.

        Returns 'inited' if this section would be initialized, 'migrated'
        if there are pending migrations, or None if nothing would be done
        (mirroring init_or_migrate()).
        """
        if self.database_current_migration is None:
            self.printer(
                u'~> Woulda initialized: %s\n' % self.name_for_printing())
            return u'inited'

        migrations_to_run = self.migrations_to_run()
        if migrations_to_run:
            self.printer(
                u'~> Woulda updated %s:\n' % self.name_for_printing())

            # migrations_to_run is a list, not a callable; and func_name
            # only exists on Python 2 whereas __name__ works everywhere.
            for migration_number, migration_func in migrations_to_run:
                self.printer(
                    u' + Would update %s, "%s"\n' % (
                        migration_number, migration_func.__name__))

            return u'migrated'

        return None

    def name_for_printing(self):
        """Human-readable label for this section in progress output."""
        if self.name == u'__main__':
            return u"main mediagoblin tables"
        else:
            return u'plugin "%s"' % self.name

    def init_or_migrate(self):
        """
        Initialize the database or migrate if appropriate.

        Returns information about whether or not we initialized
        ('inited'), migrated ('migrated'), or did nothing (None)
        """
        assure_migrations_table_setup(self.session)

        # Find out what migration number, if any, this database data is at,
        # and what the latest is.
        migration_number = self.database_current_migration

        # Is this our first time? Is there even a table entry for
        # this identifier?
        # If so:
        #  - create all tables
        #  - create record in migrations registry
        #  - print / inform the user
        #  - return 'inited'
        if migration_number is None:
            self.printer(u"-> Initializing %s... " % self.name_for_printing())

            self.init_tables()
            # auto-set at latest migration number
            self.create_new_migration_record()
            self.printer(u"done.\n")
            self.populate_table_foundations()
            # Commits the foundation rows added above.
            self.set_current_migration()
            return u'inited'

        # Run migrations, if appropriate.
        migrations_to_run = self.migrations_to_run()
        if migrations_to_run:
            self.printer(
                u'-> Updating %s:\n' % self.name_for_printing())
            for migration_number, migration_func in migrations_to_run:
                self.printer(
                    u' + Running migration %s, "%s"... ' % (
                        migration_number, migration_func.__name__))
                migration_func(self.session)
                self.set_current_migration(migration_number)
                self.printer('done.\n')

            return u'migrated'

        # Otherwise return None.  Well it would do this anyway, but
        # for clarity... ;)
        return None
295
296
class RegisterMigration(object):
    """
    Decorator that records a migration function in a migration registry.

    Usage::

        @RegisterMigration(33, MIGRATIONS)
        def update_dwarves(db_conn):
            [...]

    ``migration_registry`` is required: a dict-like mapping of migration
    number -> migration function, into which the decorated function is
    stored under ``migration_number``.

    The migration number must be greater than 0 (0 is the default
    "no migrations" state) and must not already be in the registry.
    """
    def __init__(self, migration_number, migration_registry):
        assert migration_number > 0, "Migration number must be > 0!"
        assert migration_number not in migration_registry, \
            "Duplicate migration numbers detected! That's not allowed!"

        self.migration_registry = migration_registry
        self.migration_number = migration_number

    def __call__(self, migration):
        # Register, then hand the function back unchanged so the decorated
        # name stays bound to the original function.
        registry, number = self.migration_registry, self.migration_number
        registry[number] = migration
        return migration
325
326
def assure_migrations_table_setup(db):
    """
    Create the MigrationData table if it is not in the database yet.

    :param db: session-like object; only its ``bind`` attribute is used.
    """
    from mediagoblin.db.models import MigrationData

    table = MigrationData.__table__
    if table.exists(db.bind):
        return
    MigrationData.metadata.create_all(db.bind, tables=[table])
c4466cb4
E
336
337
def inspect_table(metadata, table_name):
    """
    Return a Table object for an already existing table.

    The table is reflected (autoloaded) using ``metadata.bind`` as the
    engine/connection.
    """
    return Table(
        table_name,
        metadata,
        autoload=True,
        autoload_with=metadata.bind,
    )
e5196ff0 342
0c875e1e
CAW
def replace_table_hack(db, old_table, replacement_table):
    """
    Fully replace a current table with a new one for migrations.

    This is necessary because some changes are tricky in some situations;
    for example, dropping a boolean column in SQLite is impossible without
    this method. Each row of ``old_table`` is copied into
    ``replacement_table`` (only the columns the replacement still has).
    The old table is then dropped and the replacement is renamed to the
    old table's name, with a commit after each step.

    :param db: session used to execute the statements and commit
    :param old_table: a ref to the old table, gotten through inspect_table
    :param replacement_table: a ref to the new table, gotten through
        inspect_table

    Users are encouraged to use sqlalchemy-migrate's replace-table
    solutions, unless that is not possible... in which case, this solution
    works, at least for SQLite.
    """
    keep = set(replacement_table.columns.keys())
    original_name = old_table.name

    copied_columns = [col for col in old_table.columns if col.name in keep]
    for row in db.execute(select(copied_columns)):
        db.execute(replacement_table.insert().values(**row))
    db.commit()

    old_table.drop()
    db.commit()

    replacement_table.rename(original_name)
    db.commit()