Give a more useful error if a table already exists and so we can't create it during...
[mediagoblin.git] / mediagoblin / db / migration_tools.py
1 # GNU MediaGoblin -- federated, autonomous media hosting
2 # Copyright (C) 2011, 2012 MediaGoblin contributors. See AUTHORS.
3 #
4 # This program is free software: you can redistribute it and/or modify
5 # it under the terms of the GNU Affero General Public License as published by
6 # the Free Software Foundation, either version 3 of the License, or
7 # (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU Affero General Public License for more details.
13 #
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16
17 from mediagoblin.tools.common import simple_printer
18 from sqlalchemy import Table
19
class TableAlreadyExists(Exception):
    """Raised when a table we intend to create is already in the database."""
22
23
class MigrationManager(object):
    """
    Migration handling tool.

    Takes information about a database, lets you update the database
    to the latest migrations, etc.
    """

    def __init__(self, name, models, migration_registry, session,
                 printer=simple_printer):
        """
        Args:
         - name: identifier of this section of the database
         - models: model classes whose tables init_tables() creates
         - migration_registry: where we should find all migrations to
           run; a mapping of migration number -> migration function
           (each function is called with the session)
         - session: session we're going to migrate
         - printer: callable that receives progress messages
        """
        self.name = unicode(name)
        self.models = models
        self.session = session
        self.migration_registry = migration_registry
        self._sorted_migrations = None
        self.printer = printer

        # For convenience.  Imported here rather than at module level,
        # presumably to avoid an import cycle with the models module.
        from mediagoblin.db.models import MigrationData

        self.migration_model = MigrationData
        self.migration_table = MigrationData.__table__

    @property
    def sorted_migrations(self):
        """
        Return (migration_number, migration_func) tuples sorted by
        migration number.

        The sorted list is cached in self._sorted_migrations; an empty
        result is not treated as cached and is recomputed on next access.
        """
        if not self._sorted_migrations:
            self._sorted_migrations = sorted(
                self.migration_registry.items(),
                # sort on the key... the migration number
                key=lambda migration_tuple: migration_tuple[0])

        return self._sorted_migrations

    @property
    def migration_data(self):
        """
        Get the migration row associated with this object, if any.

        Runs a fresh query on every access; returns None when no row
        with this manager's name exists.
        """
        return self.session.query(
            self.migration_model).filter_by(name=self.name).first()

    @property
    def latest_migration(self):
        """
        Return a migration number for the latest migration, or 0 if
        there are no migrations.
        """
        if self.sorted_migrations:
            return self.sorted_migrations[-1][0]
        # If no migrations have been set, we start at 0.
        return 0

    @property
    def database_current_migration(self):
        """
        Return the current migration number stored in the database, or
        None if the migrations table or this manager's row is missing.
        """
        # If the table doesn't even exist, return None.
        if not self.migration_table.exists(self.session.bind):
            return None

        # migration_data runs a query on each access, so fetch it once.
        migration_data = self.migration_data
        if migration_data is None:
            return None

        return migration_data.version

    def set_current_migration(self, migration_number=None):
        """
        Set the migration in the database to migration_number
        (or, the latest available) and commit.

        Requires that the migration record for this manager exists.
        """
        self.migration_data.version = migration_number or self.latest_migration
        self.session.commit()

    def migrations_to_run(self):
        """
        Get a list of (migration_number, migration_func) tuples newer
        than the database's current migration, in ascending order.

        Note that this will fail if there's no migration record for
        this class!
        """
        # Evaluate once; the property queries the database each time.
        db_current_migration = self.database_current_migration
        assert db_current_migration is not None

        return [
            (migration_number, migration_func)
            for migration_number, migration_func in self.sorted_migrations
            if migration_number > db_current_migration]

    def init_tables(self):
        """
        Create all tables relative to this package.

        Raises TableAlreadyExists if any of them is already present; the
        check runs before anything is created.
        """
        # sanity check before we proceed, none of these should be created
        for model in self.models:
            # Maybe in the future just print out a "Yikes!" or something?
            if model.__table__.exists(self.session.bind):
                raise TableAlreadyExists(
                    u"Intended to create table '%s' but it already exists" %
                    model.__table__.name)

        self.migration_model.metadata.create_all(
            self.session.bind,
            tables=[model.__table__ for model in self.models])

    def create_new_migration_record(self):
        """
        Create a new migration record for this migration set, stamped
        with the latest migration number, and commit.
        """
        migration_record = self.migration_model(
            name=self.name,
            version=self.latest_migration)
        self.session.add(migration_record)
        self.session.commit()

    def dry_run(self):
        """
        Print out a dry run of what we would have upgraded.

        Returns u'inited' if this section would be initialized,
        u'migrated' if migrations would be run, or None otherwise.
        """
        if self.database_current_migration is None:
            self.printer(
                u'~> Woulda initialized: %s\n' % self.name_for_printing())
            return u'inited'

        migrations_to_run = self.migrations_to_run()
        if migrations_to_run:
            self.printer(
                u'~> Woulda updated %s:\n' % self.name_for_printing())

            # migrations_to_run is a list: iterate it, don't call it.
            for migration_number, migration_func in migrations_to_run:
                self.printer(
                    u' + Would update %s, "%s"\n' % (
                        migration_number, migration_func.__name__))

            return u'migrated'

        return None

    def name_for_printing(self):
        """Return a human-readable description of this database section."""
        if self.name == u'__main__':
            return u"main mediagoblin tables"
        # TODO: Use the friendlier media manager "human readable" name
        return u'media type "%s"' % self.name

    def init_or_migrate(self):
        """
        Initialize the database or migrate if appropriate.

        Returns information about whether or not we initialized
        ('inited'), migrated ('migrated'), or did nothing (None)
        """
        assure_migrations_table_setup(self.session)

        # Find out what migration number, if any, this database data is at,
        # and what the latest is.
        migration_number = self.database_current_migration

        # Is this our first time? Is there even a table entry for
        # this identifier?
        # If so:
        #  - create all tables
        #  - create record in migrations registry
        #  - print / inform the user
        #  - return 'inited'
        if migration_number is None:
            self.printer(u"-> Initializing %s... " % self.name_for_printing())

            self.init_tables()
            # auto-set at latest migration number, so no separate
            # set_current_migration() call is needed afterwards.
            self.create_new_migration_record()

            self.printer(u"done.\n")
            return u'inited'

        # Run migrations, if appropriate.
        migrations_to_run = self.migrations_to_run()
        if migrations_to_run:
            self.printer(
                u'-> Updating %s:\n' % self.name_for_printing())
            for migration_number, migration_func in migrations_to_run:
                self.printer(
                    u' + Running migration %s, "%s"... ' % (
                        migration_number, migration_func.__name__))
                migration_func(self.session)
                # Commit progress after each migration so that steps
                # already applied stay recorded if a later one fails.
                self.set_current_migration(migration_number)
                self.printer('done.\n')

            return u'migrated'

        # Otherwise return None. Well it would do this anyway, but
        # for clarity... ;)
        return None
230
231
class RegisterMigration(object):
    """
    Tool for registering migrations

    Call like:

    @RegisterMigration(33, MIGRATIONS)
    def update_dwarves(database):
        [...]

    where MIGRATIONS is the migration registry (a dict) to add to.
    Both arguments are required: this stores your migration function
    in that registry under the given migration number, and returns the
    function unchanged.

    Note, the number of your migration should NEVER be 0 or less than
    0. 0 is the default "no migrations" state!
    """
    def __init__(self, migration_number, migration_registry):
        # Validate when the decorator is created, before the function
        # is registered.
        assert migration_number > 0, "Migration number must be > 0!"
        assert migration_number not in migration_registry, \
            "Duplicate migration numbers detected! That's not allowed!"

        self.migration_number = migration_number
        self.migration_registry = migration_registry

    def __call__(self, migration):
        """Register `migration` under our number and return it as-is."""
        self.migration_registry[self.migration_number] = migration
        return migration
260
261
def assure_migrations_table_setup(db):
    """
    Create the migrations bookkeeping table if the database lacks it.

    `db` must expose a `bind` attribute (the session is passed in by
    callers in this module).
    """
    from mediagoblin.db.models import MigrationData

    migrations_table = MigrationData.__table__
    engine = db.bind
    if migrations_table.exists(engine):
        # Already there; nothing to do.
        return
    MigrationData.metadata.create_all(engine, tables=[migrations_table])
271
272
def inspect_table(metadata, table_name):
    """
    Return a Table for an already existing database table.

    The columns are reflected from the database that `metadata` is
    bound to.
    """
    bound_engine = metadata.bind
    return Table(table_name, metadata,
                 autoload_with=bound_engine, autoload=True)