Commit | Line | Data |
---|---|---|
a050e776 E |
1 | # GNU MediaGoblin -- federated, autonomous media hosting |
2 | # Copyright (C) 2011, 2012 MediaGoblin contributors. See AUTHORS. | |
3 | # | |
4 | # This program is free software: you can redistribute it and/or modify | |
5 | # it under the terms of the GNU Affero General Public License as published by | |
6 | # the Free Software Foundation, either version 3 of the License, or | |
7 | # (at your option) any later version. | |
8 | # | |
9 | # This program is distributed in the hope that it will be useful, | |
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of | |
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
12 | # GNU Affero General Public License for more details. | |
13 | # | |
14 | # You should have received a copy of the GNU Affero General Public License | |
15 | # along with this program. If not, see <http://www.gnu.org/licenses/>. | |
16 | ||
7f342c72 BP |
17 | from __future__ import unicode_literals |
18 | ||
de51eca5 | 19 | import logging |
65f20ca4 BP |
20 | import os |
21 | ||
22 | from alembic import command | |
23 | from alembic.config import Config | |
de51eca5 | 24 | from alembic.migration import MigrationContext |
65f20ca4 BP |
25 | |
26 | from mediagoblin.db.base import Base | |
a050e776 | 27 | from mediagoblin.tools.common import simple_printer |
c4466cb4 | 28 | from sqlalchemy import Table |
e5196ff0 | 29 | from sqlalchemy.sql import select |
a050e776 | 30 | |
de51eca5 BP |
# Module-level logger, named after this module via __name__.
log = logging.getLogger(__name__)
32 | ||
33 | ||
7e4a87dc CAW |
class TableAlreadyExists(Exception):
    """Raised when a table that was about to be created already exists."""
36 | ||
a050e776 | 37 | |
65f20ca4 BP |
class AlembicMigrationManager(object):
    """
    Small wrapper around the Alembic commands for one database session.

    The Alembic configuration is loaded from the ``alembic.ini`` file
    located two directories above the one containing this module.
    """

    def __init__(self, session):
        """
        Args:
         - session: session whose ``bind`` is inspected and initialized
        """
        module_dir = os.path.dirname(__file__)
        root_dir = os.path.abspath(
            os.path.dirname(os.path.dirname(module_dir)))
        self.alembic_cfg = Config(os.path.join(root_dir, 'alembic.ini'))
        self.session = session

    def get_current_revision(self):
        """
        Return the revision Alembic reports for the session's bind.
        """
        migration_context = MigrationContext.configure(self.session.bind)
        return migration_context.get_current_revision()

    def upgrade(self, version):
        """
        Upgrade to ``version``; a falsy ``version`` means 'head'.
        """
        target = version if version else 'head'
        return command.upgrade(self.alembic_cfg, target)

    def downgrade(self, version):
        """
        Downgrade to ``version``.

        ``None``, integers and all-digit strings are all mapped to
        'base' before being handed to Alembic.
        """
        # The .isdigit() check comes last so it is only reached for
        # values that are neither None nor an int.
        numeric_or_missing = (
            version is None
            or isinstance(version, int)
            or version.isdigit())
        target = 'base' if numeric_or_missing else version
        return command.downgrade(self.alembic_cfg, target)

    def stamp(self, revision):
        """
        Stamp the database with ``revision``.
        """
        return command.stamp(self.alembic_cfg, revision=revision)

    def init_tables(self):
        """
        Create every table known to ``Base.metadata`` and stamp 'head'.
        """
        Base.metadata.create_all(self.session.bind)
        # Stamping generates the Alembic version table and marks the
        # database as being at the most recent revision.
        command.stamp(self.alembic_cfg, 'head')

    def init_or_migrate(self, version=None):
        """
        Initialize the tables when no revision is recorded, else upgrade.
        """
        if self.get_current_revision() is not None:
            self.upgrade(version)
            return

        log.info('Initializing tables and stamping it with '
                 'the most recent migration...')
        self.init_tables()
65f20ca4 BP |
75 | |
76 | ||
a050e776 E |
class MigrationManager(object):
    """
    Migration handling tool.

    Takes information about a database, lets you update the database
    to the latest migrations, etc.
    """

    def __init__(self, name, models, foundations, migration_registry, session,
                 printer=simple_printer):
        """
        Args:
         - name: identifier of this section of the database
         - models: model classes whose tables init_tables() creates
         - foundations: mapping of model class -> iterable of
           keyword-argument dicts, one per default row to insert
         - migration_registry: where we should find all migrations to
           run (mapping of migration number -> migration callable)
         - session: session we're going to migrate
         - printer: callable that receives progress messages
        """
        self.name = name
        self.models = models
        self.foundations = foundations
        self.session = session
        self.migration_registry = migration_registry
        self._sorted_migrations = None
        self.printer = printer

        # For convenience
        from mediagoblin.db.models import MigrationData

        self.migration_model = MigrationData
        self.migration_table = MigrationData.__table__

    @property
    def sorted_migrations(self):
        """
        Sort migrations if necessary and store in self._sorted_migrations

        Returns a list of (migration_number, migration_func) tuples in
        ascending migration number order.
        """
        if not self._sorted_migrations:
            self._sorted_migrations = sorted(
                self.migration_registry.items(),
                # sort on the key... the migration number
                key=lambda migration_tuple: migration_tuple[0])

        return self._sorted_migrations

    @property
    def migration_data(self):
        """
        Get the migration row associated with this object, if any.

        Every access issues a new query against the session.
        """
        return self.session.query(
            self.migration_model).filter_by(name=self.name).first()

    @property
    def latest_migration(self):
        """
        Return a migration number for the latest migration, or 0 if
        there are no migrations.
        """
        if self.sorted_migrations:
            return self.sorted_migrations[-1][0]

        # If no migrations have been set, we start at 0.
        return 0

    @property
    def database_current_migration(self):
        """
        Return the current migration in the database.

        Returns None when the migration table does not exist or when
        it holds no row for this manager's name.
        """
        # If the table doesn't even exist, return None.
        if not self.migration_table.exists(self.session.bind):
            return None

        # Fetch the row once: each access to self.migration_data runs
        # a query of its own.
        migration_row = self.migration_data
        if migration_row is None:
            return None

        return migration_row.version

    def set_current_migration(self, migration_number=None):
        """
        Set the migration in the database to migration_number
        (or, the latest available) and commit the session.
        """
        self.migration_data.version = migration_number or self.latest_migration
        self.session.commit()

    def migrations_to_run(self):
        """
        Get a list of migrations to run still, if any.

        Returns a list of (migration_number, migration_func) tuples
        whose number is greater than the one stored in the database.

        Note that this will fail if there's no migration record for
        this class!
        """
        # Read the property once so the database is only queried once.
        db_current_migration = self.database_current_migration
        assert db_current_migration is not None

        return [
            (migration_number, migration_func)
            for migration_number, migration_func in self.sorted_migrations
            if migration_number > db_current_migration]

    def init_tables(self):
        """
        Create all tables relative to this package

        Raises TableAlreadyExists if any of the tables is already
        present in the database.
        """
        # sanity check before we proceed, none of these should be created
        for model in self.models:
            # Maybe in the future just print out a "Yikes!" or something?
            if model.__table__.exists(self.session.bind):
                raise TableAlreadyExists(
                    u"Intended to create table '%s' but it already exists" %
                    model.__table__.name)

        self.migration_model.metadata.create_all(
            self.session.bind,
            tables=[model.__table__ for model in self.models])

    def populate_table_foundations(self):
        """
        Create the table foundations (default rows) as laid out in the
        foundations mapping given to the constructor.

        The rows are only added to the session here; committing is
        left to the caller.
        """
        for Model, rows in self.foundations.items():
            self.printer(u' + Laying foundations for %s table\n' %
                         (Model.__name__))
            for parameters in rows:
                new_row = Model(**parameters)
                self.session.add(new_row)

    def create_new_migration_record(self):
        """
        Create a new migration record for this migration set, set to
        the latest migration number, and commit it.
        """
        migration_record = self.migration_model(
            name=self.name,
            version=self.latest_migration)
        self.session.add(migration_record)
        self.session.commit()

    def dry_run(self):
        """
        Print out a dry run of what we would have upgraded.

        Returns 'inited' if the database would be initialized,
        'migrated' if there are migrations that would be run, and
        None otherwise.
        """
        if self.database_current_migration is None:
            self.printer(
                u'~> Woulda initialized: %s\n' % self.name_for_printing())
            return u'inited'

        migrations_to_run = self.migrations_to_run()
        if migrations_to_run:
            self.printer(
                u'~> Woulda updated %s:\n' % self.name_for_printing())

            # migrations_to_run is a plain list of (number, function)
            # pairs, so it is iterated directly rather than called.
            for migration_number, migration_func in migrations_to_run:
                self.printer(
                    u' + Would update %s, "%s"\n' % (
                        # __name__ exists on both Python 2 and 3
                        # functions (func_name is Python 2 only).
                        migration_number, migration_func.__name__))

            return u'migrated'

        return None

    def name_for_printing(self):
        """
        Return a human readable label for this migration set.
        """
        if self.name == u'__main__':
            return u"main mediagoblin tables"

        return u'plugin "%s"' % self.name

    def init_or_migrate(self):
        """
        Initialize the database or migrate if appropriate.

        Returns information about whether or not we initialized
        ('inited'), migrated ('migrated'), or did nothing (None)
        """
        assure_migrations_table_setup(self.session)

        # Find out what migration number, if any, this database data is at,
        # and what the latest is.
        migration_number = self.database_current_migration

        # Is this our first time?  Is there even a table entry for
        # this identifier?
        # If so:
        #  - create all tables
        #  - create record in migrations registry
        #  - print / inform the user
        #  - return 'inited'
        if migration_number is None:
            self.printer(u"-> Initializing %s... " % self.name_for_printing())

            self.init_tables()
            # auto-set at latest migration number
            self.create_new_migration_record()
            self.printer(u"done.\n")
            self.populate_table_foundations()
            # Also commits the foundation rows added just above.
            self.set_current_migration()
            return u'inited'

        # Run migrations, if appropriate.
        migrations_to_run = self.migrations_to_run()
        if migrations_to_run:
            self.printer(
                u'-> Updating %s:\n' % self.name_for_printing())
            for migration_number, migration_func in migrations_to_run:
                self.printer(
                    u' + Running migration %s, "%s"... ' % (
                        migration_number, migration_func.__name__))
                migration_func(self.session)
                # Record progress after every migration.
                self.set_current_migration(migration_number)
                self.printer('done.\n')

            return u'migrated'

        # Otherwise return None.  Well it would do this anyway, but
        # for clarity... ;)
        return None
295 | ||
296 | ||
class RegisterMigration(object):
    """
    Decorator that records a migration function in a registry.

    Call like:

        @RegisterMigration(33, MIGRATIONS)
        def update_dwarves(database):
            [...]

    The decorated function is stored in the given migration_registry
    under migration_number and is handed back unchanged.

    Note, the number of your migration should NEVER be 0 or less than
    0.  0 is the default "no migrations" state!
    """
    def __init__(self, migration_number, migration_registry):
        assert migration_number > 0, "Migration number must be > 0!"
        assert migration_number not in migration_registry, (
            "Duplicate migration numbers detected! That's not allowed!")

        self.migration_registry = migration_registry
        self.migration_number = migration_number

    def __call__(self, migration):
        registry = self.migration_registry
        registry[self.migration_number] = migration
        return migration
325 | ||
326 | ||
def assure_migrations_table_setup(db):
    """
    Create the migrations table in the database unless it already exists.
    """
    from mediagoblin.db.models import MigrationData

    migration_table = MigrationData.__table__
    if migration_table.exists(db.bind):
        # Nothing to do, the table is already there.
        return

    MigrationData.metadata.create_all(db.bind, tables=[migration_table])
c4466cb4 E |
336 | |
337 | ||
def inspect_table(metadata, table_name):
    """Return a Table for table_name, autoloaded through metadata.bind"""
    autoload_options = dict(autoload=True, autoload_with=metadata.bind)
    return Table(table_name, metadata, **autoload_options)
e5196ff0 | 342 | |
0c875e1e CAW |
def replace_table_hack(db, old_table, replacement_table):
    """
    Fully replace a current table with a new one during a migration.

    This is necessary because some changes are tricky in some
    situations; for example, dropping a boolean column in sqlite is
    impossible without this method.

    :param old_table            A ref to the old table, gotten through
                                inspect_table

    :param replacement_table    A ref to the new table, gotten through
                                inspect_table

    Users are encouraged to use sqlalchemy-migrate replace table
    solutions, unless that is not possible... in which case, this
    solution works, at least for sqlite.
    """
    kept_column_names = replacement_table.columns.keys()
    final_table_name = old_table.name

    # Only copy the columns that still exist in the replacement table.
    kept_columns = [
        column for column in old_table.columns
        if column.name in kept_column_names]

    for row in db.execute(select(kept_columns)):
        db.execute(replacement_table.insert().values(**row))
    db.commit()

    old_table.drop()
    db.commit()

    # Give the replacement the name the old table used to have.
    replacement_table.rename(final_table_name)
    db.commit()