merge with remote master branch of cweb.
[mediagoblin.git] / mediagoblin / db / migration_tools.py
1 # GNU MediaGoblin -- federated, autonomous media hosting
2 # Copyright (C) 2011, 2012 MediaGoblin contributors. See AUTHORS.
3 #
4 # This program is free software: you can redistribute it and/or modify
5 # it under the terms of the GNU Affero General Public License as published by
6 # the Free Software Foundation, either version 3 of the License, or
7 # (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU Affero General Public License for more details.
13 #
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16
17 from mediagoblin.tools.common import simple_printer
18 from sqlalchemy import Table
19
class TableAlreadyExists(Exception):
    """Raised when a table that is about to be created is already present."""
22
23
class MigrationManager(object):
    """
    Migration handling tool.

    Takes information about a database, lets you update the database
    to the latest migrations, etc.
    """

    def __init__(self, name, models, migration_registry, session,
                 printer=simple_printer):
        """
        Args:
         - name: identifier of this section of the database
           (u'__main__' is printed as the main mediagoblin tables,
           anything else as a plugin name)
         - models: model classes whose tables init_tables() creates
         - migration_registry: mapping of migration number ->
           migration function; where we should find all migrations to
           run
         - session: session we're going to migrate
         - printer: callable taking a single string, used to report
           progress
        """
        self.name = unicode(name)
        self.models = models
        self.session = session
        self.migration_registry = migration_registry
        self._sorted_migrations = None
        self.printer = printer

        # For convenience
        from mediagoblin.db.models import MigrationData

        self.migration_model = MigrationData
        self.migration_table = MigrationData.__table__

    @property
    def sorted_migrations(self):
        """
        Return the registry's (migration_number, migration_func) pairs
        sorted by migration number.

        The result is computed on first access and cached in
        self._sorted_migrations.
        """
        # Compare against None (not truthiness) so that an empty registry
        # is cached as well instead of being re-sorted on every access.
        if self._sorted_migrations is None:
            self._sorted_migrations = sorted(
                self.migration_registry.items(),
                # sort on the key... the migration number
                key=lambda migration_tuple: migration_tuple[0])

        return self._sorted_migrations

    @property
    def migration_data(self):
        """
        Get the migration row associated with this object, if any.

        Issues a query each time it is accessed.
        """
        return self.session.query(
            self.migration_model).filter_by(name=self.name).first()

    @property
    def latest_migration(self):
        """
        Return a migration number for the latest migration, or 0 if
        there are no migrations.
        """
        if self.sorted_migrations:
            return self.sorted_migrations[-1][0]
        else:
            # If no migrations have been set, we start at 0.
            return 0

    @property
    def database_current_migration(self):
        """
        Return the current migration in the database.

        Returns None if the migrations table does not exist or if there
        is no migration record for self.name.
        """
        # If the table doesn't even exist, return None.
        if not self.migration_table.exists(self.session.bind):
            return None

        # Query the record only once (migration_data hits the database
        # on every access); return None if there is no record yet.
        migration_data = self.migration_data
        if migration_data is None:
            return None

        return migration_data.version

    def set_current_migration(self, migration_number=None):
        """
        Set the migration in the database to migration_number
        (or, the latest available, when migration_number is falsy)
        and commit.

        A migration record for self.name must already exist (see
        create_new_migration_record); otherwise migration_data is None
        and this raises AttributeError.
        """
        self.migration_data.version = migration_number or self.latest_migration
        self.session.commit()

    def migrations_to_run(self):
        """
        Get a list of (migration_number, migration_func) pairs that are
        newer than the database's current migration, in ascending order.

        Note that this will fail (AssertionError) if there's no migration
        record for this class!
        """
        # Evaluate the property once: it runs a table-existence check and
        # a query each time it is read.
        db_current_migration = self.database_current_migration
        assert db_current_migration is not None

        return [
            (migration_number, migration_func)
            for migration_number, migration_func in self.sorted_migrations
            if migration_number > db_current_migration]

    def init_tables(self):
        """
        Create all tables relative to this package

        Raises:
         - TableAlreadyExists if any of self.models' tables already
           exists; in that case nothing is created.
        """
        # sanity check before we proceed, none of these should be created
        for model in self.models:
            # Maybe in the future just print out a "Yikes!" or something?
            if model.__table__.exists(self.session.bind):
                raise TableAlreadyExists(
                    u"Intended to create table '%s' but it already exists" %
                    model.__table__.name)

        self.migration_model.metadata.create_all(
            self.session.bind,
            tables=[model.__table__ for model in self.models])

    def create_new_migration_record(self):
        """
        Create a new migration record for this migration set, stamped
        with the latest known migration number, and commit it.
        """
        migration_record = self.migration_model(
            name=self.name,
            version=self.latest_migration)
        self.session.add(migration_record)
        self.session.commit()

    def dry_run(self):
        """
        Print out a dry run of what we would have upgraded.

        Returns u'inited' if the database would be initialized,
        u'migrated' if migrations would run, or None otherwise.
        """
        if self.database_current_migration is None:
            self.printer(
                u'~> Woulda initialized: %s\n' % self.name_for_printing())
            return u'inited'

        migrations_to_run = self.migrations_to_run()
        if migrations_to_run:
            self.printer(
                u'~> Woulda updated %s:\n' % self.name_for_printing())

            # migrations_to_run is already a list -- iterate it directly
            # (calling it raised TypeError).  __name__ is the same as
            # func_name on Python 2 and also exists on Python 3.
            for migration_number, migration_func in migrations_to_run:
                self.printer(
                    u' + Would update %s, "%s"\n' % (
                        migration_number, migration_func.__name__))

            return u'migrated'

        return None

    def name_for_printing(self):
        """Return a human-readable label for this migration set."""
        if self.name == u'__main__':
            return u"main mediagoblin tables"
        else:
            return u'plugin "%s"' % self.name

    def init_or_migrate(self):
        """
        Initialize the database or migrate if appropriate.

        Returns information about whether or not we initialized
        ('inited'), migrated ('migrated'), or did nothing (None)
        """
        assure_migrations_table_setup(self.session)

        # Find out what migration number, if any, this database data is at,
        # and what the latest is.
        current_migration = self.database_current_migration

        # Is this our first time? Is there even a table entry for
        # this identifier?
        # If so:
        #  - create all tables
        #  - create record in migrations registry
        #  - print / inform the user
        #  - return 'inited'
        if current_migration is None:
            self.printer(u"-> Initializing %s... " % self.name_for_printing())

            self.init_tables()
            # auto-set at latest migration number; the record is created
            # and committed at that version, so no further
            # set_current_migration() round trip is needed.
            self.create_new_migration_record()

            self.printer(u"done.\n")
            return u'inited'

        # Run migrations, if appropriate.
        migrations_to_run = self.migrations_to_run()
        if migrations_to_run:
            self.printer(
                u'-> Updating %s:\n' % self.name_for_printing())
            for migration_number, migration_func in migrations_to_run:
                self.printer(
                    u' + Running migration %s, "%s"... ' % (
                        migration_number, migration_func.__name__))
                migration_func(self.session)
                # Record progress after each step so a later failure
                # does not cause already-applied migrations to re-run.
                self.set_current_migration(migration_number)
                self.printer(u'done.\n')

            return u'migrated'

        # Otherwise return None. Well it would do this anyway, but
        # for clarity... ;)
        return None
229
230
class RegisterMigration(object):
    """
    Decorator that records a migration function in a migration registry.

    Use it like:

        @RegisterMigration(33, MIGRATIONS)
        def update_dwarves(db):
            [...]

    where MIGRATIONS is the registry (a dict-like mapping keyed by
    migration number) the function should be stored in.  The registry
    argument is required; there is no default registry.

    Migration numbers must be greater than 0 -- 0 is the default
    "no migrations" state -- and must be unique within a registry.
    """
    def __init__(self, migration_number, migration_registry):
        assert migration_number > 0, (
            "Migration number must be > 0!")
        assert migration_number not in migration_registry, (
            "Duplicate migration numbers detected! That's not allowed!")

        self.migration_registry = migration_registry
        self.migration_number = migration_number

    def __call__(self, migration):
        # Store the function, then hand it back unchanged so the
        # decorated name still refers to the original function.
        registry = self.migration_registry
        registry[self.migration_number] = migration
        return migration
259
260
def assure_migrations_table_setup(db):
    """
    Create the migrations table if the database does not have it yet.

    `db` is the session whose `.bind` is used both to check for the
    table and to create it; nothing is done if the table already exists.
    """
    from mediagoblin.db.models import MigrationData

    migrations_table = MigrationData.__table__
    if migrations_table.exists(db.bind):
        return

    MigrationData.metadata.create_all(db.bind, tables=[migrations_table])
270
271
def inspect_table(metadata, table_name):
    """
    Return a reference to an already existing table.

    The table's definition is reflected (autoloaded) through the bind
    that `metadata` is attached to.
    """
    return Table(
        table_name,
        metadata,
        autoload=True,
        autoload_with=metadata.bind,
    )