Merge remote-tracking branch 'gsoc2016/Subtitle-1'
[mediagoblin.git] / mediagoblin / db / migration_tools.py
1 # GNU MediaGoblin -- federated, autonomous media hosting
2 # Copyright (C) 2011, 2012 MediaGoblin contributors. See AUTHORS.
3 #
4 # This program is free software: you can redistribute it and/or modify
5 # it under the terms of the GNU Affero General Public License as published by
6 # the Free Software Foundation, either version 3 of the License, or
7 # (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU Affero General Public License for more details.
13 #
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16
17 from __future__ import unicode_literals
18
19 import logging
20 import os
21 import pkg_resources
22
23 from alembic import command
24 from alembic.config import Config
25 from alembic.migration import MigrationContext
26
27 from mediagoblin.db.base import Base
28 from mediagoblin.tools.common import simple_printer
29 from sqlalchemy import Table
30 from sqlalchemy.sql import select
31
32 log = logging.getLogger(__name__)
33
34
class TableAlreadyExists(Exception):
    """Raised when a table that was about to be created already exists."""
37
38
class MigrationManager(object):
    """
    Migration handling tool.

    Takes information about a database, lets you update the database
    to the latest migrations, etc.
    """

    def __init__(self, name, models, migration_registry, session,
                 printer=simple_printer):
        """
        Args:
         - name: identifier of this section of the database
         - models: model classes whose tables init_tables() creates
         - migration_registry: mapping of migration number to migration
           callable; where we should find all migrations to run
         - session: session we're going to migrate
         - printer: callable taking one string, used for progress output
        """
        self.name = name
        self.models = models
        self.session = session
        self.migration_registry = migration_registry
        self._sorted_migrations = None
        self.printer = printer

        # For convenience
        from mediagoblin.db.models import MigrationData

        self.migration_model = MigrationData
        self.migration_table = MigrationData.__table__

    @property
    def sorted_migrations(self):
        """
        Sort migrations if necessary and store in self._sorted_migrations

        Returns a list of (migration_number, migration_func) tuples in
        ascending migration_number order.
        """
        if not self._sorted_migrations:
            self._sorted_migrations = sorted(
                self.migration_registry.items(),
                # sort on the key... the migration number
                key=lambda migration_tuple: migration_tuple[0])

        return self._sorted_migrations

    @property
    def migration_data(self):
        """
        Get the migration row associated with this object, if any.

        Every access issues a fresh query; None is returned when no row
        with this manager's name exists.
        """
        return self.session.query(
            self.migration_model).filter_by(name=self.name).first()

    @property
    def latest_migration(self):
        """
        Return a migration number for the latest migration, or 0 if
        there are no migrations.
        """
        if self.sorted_migrations:
            return self.sorted_migrations[-1][0]
        else:
            # If no migrations have been set, we start at 0.
            return 0

    @property
    def database_current_migration(self):
        """
        Return the current migration in the database.

        Returns None if the migration table does not exist or holds no
        row for this manager's name.
        """
        # If the table doesn't even exist, return None.
        if not self.migration_table.exists(self.session.bind):
            return None

        # Query once and reuse the row: migration_data hits the database
        # on every access.
        migration_row = self.migration_data

        # Also return None if there is no migration row.
        if migration_row is None:
            return None

        return migration_row.version

    def set_current_migration(self, migration_number=None):
        """
        Set the migration in the database to migration_number
        (or, the latest available) and commit the session.
        """
        self.migration_data.version = migration_number or self.latest_migration
        self.session.commit()

    def migrations_to_run(self):
        """
        Get a list of migrations to run still, if any.

        Returns a list of (migration_number, migration_func) tuples whose
        number is greater than the version recorded in the database.

        Note that this will fail if there's no migration record for
        this class!
        """
        # Read the property once; each access queries the database.
        db_current_migration = self.database_current_migration
        assert db_current_migration is not None

        return [
            (migration_number, migration_func)
            for migration_number, migration_func in self.sorted_migrations
            if migration_number > db_current_migration]

    def init_tables(self):
        """
        Create all tables relative to this package

        Raises TableAlreadyExists if any of the tables is already present.
        """
        # sanity check before we proceed, none of these should be created
        for model in self.models:
            # Maybe in the future just print out a "Yikes!" or something?
            if model.__table__.exists(self.session.bind):
                raise TableAlreadyExists(
                    u"Intended to create table '%s' but it already exists" %
                    model.__table__.name)

        self.migration_model.metadata.create_all(
            self.session.bind,
            tables=[model.__table__ for model in self.models])

    def create_new_migration_record(self):
        """
        Create a new migration record for this migration set, stamped
        with the latest migration number, and commit it.
        """
        migration_record = self.migration_model(
            name=self.name,
            version=self.latest_migration)
        self.session.add(migration_record)
        self.session.commit()

    def dry_run(self):
        """
        Print out a dry run of what we would have upgraded.

        Returns 'inited' if the database would be initialized,
        'migrated' if migrations would be run, or None if there is
        nothing to do.
        """
        if self.database_current_migration is None:
            self.printer(
                u'~> Woulda initialized: %s\n' % self.name_for_printing())
            return u'inited'

        migrations_to_run = self.migrations_to_run()
        if migrations_to_run:
            self.printer(
                u'~> Woulda updated %s:\n' % self.name_for_printing())

            # migrations_to_run is already a list here; iterate it
            # directly rather than calling it.
            for migration_number, migration_func in migrations_to_run:
                self.printer(
                    u' + Would update %s, "%s"\n' % (
                        # __name__ exists on both Python 2 and 3, unlike
                        # the Python 2-only func_name attribute.
                        migration_number, migration_func.__name__))

            return u'migrated'

        return None

    def name_for_printing(self):
        """
        Return a human-readable label for this migration set.
        """
        if self.name == u'__main__':
            return u"main mediagoblin tables"
        else:
            return u'plugin "%s"' % self.name

    def init_or_migrate(self):
        """
        Initialize the database or migrate if appropriate.

        Returns information about whether or not we initialized
        ('inited'), migrated ('migrated'), or did nothing (None)
        """
        assure_migrations_table_setup(self.session)

        # Find out what migration number, if any, this database data is at,
        # and what the latest is.
        migration_number = self.database_current_migration

        # Is this our first time?  Is there even a table entry for
        # this identifier?
        # If so:
        #  - create all tables
        #  - create record in migrations registry
        #  - print / inform the user
        #  - return 'inited'
        if migration_number is None:
            self.printer(u"-> Initializing %s... " % self.name_for_printing())

            self.init_tables()
            # auto-set at latest migration number
            self.create_new_migration_record()
            self.printer(u"done.\n")
            self.set_current_migration()
            return u'inited'

        # Run migrations, if appropriate.
        migrations_to_run = self.migrations_to_run()
        if migrations_to_run:
            self.printer(
                u'-> Updating %s:\n' % self.name_for_printing())
            for migration_number, migration_func in migrations_to_run:
                self.printer(
                    u' + Running migration %s, "%s"... ' % (
                        migration_number, migration_func.__name__))
                migration_func(self.session)
                # Record progress after every migration so the stored
                # version always matches the last one that completed.
                self.set_current_migration(migration_number)
                self.printer('done.\n')

            return u'migrated'

        # Otherwise return None.  Well it would do this anyway, but
        # for clarity... ;)
        return None
243
244
class RegisterMigration(object):
    """
    Tool for registering migrations

    Call like:

      @RegisterMigration(33, MIGRATIONS)
      def update_dwarves(database):
          [...]

    This will register your migration in the migration registry you
    pass as the second argument (a mapping of migration number to
    migration callable).

    Note, the number of your migration should NEVER be 0 or less than
    0.  0 is the default "no migrations" state!

    Raises AssertionError if the number is not positive, is already
    present in the registry, or is greater than 44.
    """
    # Highest migration number accepted here; anything newer is
    # expected to be written as an Alembic migration instead.
    _LAST_LEGACY_MIGRATION = 44

    def __init__(self, migration_number, migration_registry):
        # Raise explicitly instead of using ``assert`` statements so the
        # checks still run under ``python -O``.  AssertionError is kept
        # so existing callers see the same exception type.
        if not migration_number > 0:
            raise AssertionError("Migration number must be > 0!")
        if migration_number in migration_registry:
            raise AssertionError(
                "Duplicate migration numbers detected!  "
                "That's not allowed!")
        if not migration_number <= self._LAST_LEGACY_MIGRATION:
            raise AssertionError(
                'Alembic should be used for new migrations')

        self.migration_number = migration_number
        self.migration_registry = migration_registry

    def __call__(self, migration):
        """
        Store ``migration`` in the registry under this decorator's
        number and return it unchanged.
        """
        self.migration_registry[self.migration_number] = migration
        return migration
275
276
def assure_migrations_table_setup(db):
    """
    Create the migrations bookkeeping table unless it is already there.

    ``db`` is the session whose ``bind`` is used both for the existence
    check and for creating the table.
    """
    from mediagoblin.db.models import MigrationData

    migration_table = MigrationData.__table__
    if migration_table.exists(db.bind):
        # Already set up, nothing to create.
        return

    MigrationData.metadata.create_all(db.bind, tables=[migration_table])
286
287
def inspect_table(metadata, table_name):
    """Return a Table for ``table_name``, autoloaded via ``metadata.bind``."""
    reflection_options = {
        'autoload': True,
        'autoload_with': metadata.bind,
    }
    return Table(table_name, metadata, **reflection_options)
292
def replace_table_hack(db, old_table, replacement_table):
    """
    Fully replace an existing table with a new one during a migration.

    This is necessary because some changes are tricky in some
    situations; for example, dropping a boolean column in sqlite is
    impossible without this method.

    The rows of ``old_table`` are copied into ``replacement_table``
    (only the columns that exist in both), ``old_table`` is dropped and
    ``replacement_table`` is renamed to the old table's name.  The
    session is committed after each of those three steps.

    :param old_table            A ref to the old table, gotten through
                                inspect_table

    :param replacement_table    A ref to the new table, gotten through
                                inspect_table

    Users are encouraged to use sqlalchemy-migrate's replace table
    solutions, unless that is not possible... in which case, this
    solution works, at least for sqlite.
    """
    kept_names = replacement_table.columns.keys()
    # Remember the name before the old table is dropped.
    original_name = old_table.name

    # Only columns present in the replacement table are carried over.
    kept_columns = [
        col for col in old_table.columns if col.name in kept_names]

    for record in db.execute(select(kept_columns)):
        db.execute(replacement_table.insert().values(**record))
    db.commit()

    old_table.drop()
    db.commit()

    replacement_table.rename(original_name)
    db.commit()
324
def model_iteration_hack(db, query):
    """
    Execute ``query`` and return something safe to iterate over while
    committing inside the loop.

    For sqlite this returns a list with all the results; for any other
    backend (e.g. postgres) it returns the result of ``db.execute(query)``
    as is.  This is because in migrations it seems sqlite can't deal
    with concurrent queries, so if you're iterating over models and
    doing a commit inside the loop, you will run into an exception which
    says you've closed the connection on your iteration query.  This
    fixes it.

    NB: This loads all of the query results into memory, there isn't a
    good way around this, we're assuming sqlite users have small
    databases.
    """
    # The driver name may carry an explicit DBAPI suffix (for example
    # "sqlite+pysqlite"); compare only the backend part before the "+"
    # so those URLs get the sqlite treatment too.
    backend_name = db.bind.url.drivername.partition("+")[0]

    # If it's SQLite just return all the objects
    if backend_name == "sqlite":
        return list(db.execute(query))

    # Postgres return the query as it knows how to deal with it.
    return db.execute(query)
343
344
def populate_table_foundations(session, foundations, name,
                               printer=simple_printer):
    """
    Insert the foundation (default) rows described by ``foundations``.

    ``foundations`` maps a model class to an iterable of keyword-argument
    dicts; one instance per dict is added to ``session``.  Progress is
    reported through ``printer`` and the session is committed once all
    rows have been added.
    """
    printer(u'Laying foundations for %s:\n' % name)

    for model_class, row_specs in foundations.items():
        printer(
            u' + Laying foundations for %s table\n' % (model_class.__name__))
        for row_kwargs in row_specs:
            session.add(model_class(**row_kwargs))

    session.commit()
360
361
def build_alembic_config(global_config, cmd_options, session):
    """
    Assemble the alembic ``Config`` for this installation.

    The config is loaded from the ``alembic.ini`` in the ``migrations``
    directory next to this module.  ``session`` is stored under the
    config's "session" attribute and its bind URL becomes the
    "sqlalchemy.url" option.  "version_locations" lists the core
    versions directory followed by the ``migrations`` directory of each
    plugin named in ``global_config["plugins"]`` that has one.
    """
    migrations_root = os.path.join(os.path.dirname(__file__), 'migrations')
    cfg = Config(
        os.path.join(migrations_root, 'alembic.ini'),
        cmd_opts=cmd_options)
    cfg.attributes["session"] = session
    cfg.set_main_option("sqlalchemy.url", str(session.get_bind().url))

    # Core migrations always come first.
    version_locations = [
        pkg_resources.resource_filename(
            "mediagoblin.db", os.path.join("migrations", "versions")),
    ]

    for plugin in global_config.get("plugins", []):
        candidate = pkg_resources.resource_filename(plugin, "migrations")
        # isdir() is False for missing paths, so it also covers the
        # existence check.
        if os.path.isdir(candidate):
            version_locations.append(candidate)

    cfg.set_main_option("version_locations", " ".join(version_locations))

    return cfg
392
393 return cfg