Merge commit '9408938' from 565_workbench_cleanup (spaetz)
[mediagoblin.git] / mediagoblin / db / migration_tools.py
1 # GNU MediaGoblin -- federated, autonomous media hosting
2 # Copyright (C) 2011, 2012 MediaGoblin contributors. See AUTHORS.
3 #
4 # This program is free software: you can redistribute it and/or modify
5 # it under the terms of the GNU Affero General Public License as published by
6 # the Free Software Foundation, either version 3 of the License, or
7 # (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU Affero General Public License for more details.
13 #
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16
17 from mediagoblin.tools.common import simple_printer
18 from sqlalchemy import Table
19
20
class MigrationManager(object):
    """
    Migration handling tool.

    Takes information about a database, lets you update the database
    to the latest migrations, etc.

    Progress for one section of the database is tracked in a single
    MigrationData row looked up by ``name``; that row's ``version``
    holds the number of the last migration recorded as applied.
    """

    def __init__(self, name, models, migration_registry, session,
                 printer=simple_printer):
        """
        Args:
         - name: identifier of this section of the database
         - models: model classes whose ``__table__`` is created by
           init_tables() when this section is first initialized
         - migration_registry: mapping of migration number to migration
           function; where we should find all migrations to run
         - session: session we're going to migrate
         - printer: callable that receives (unicode) progress messages
        """
        self.name = unicode(name)
        self.models = models
        self.session = session
        self.migration_registry = migration_registry
        self._sorted_migrations = None
        self.printer = printer

        # For convenience. Imported here rather than at module level --
        # presumably to avoid a circular import; TODO confirm.
        from mediagoblin.db.models import MigrationData

        self.migration_model = MigrationData
        self.migration_table = MigrationData.__table__

    @property
    def sorted_migrations(self):
        """
        Return the registered migrations as a list of
        (migration_number, migration_func) tuples in ascending order.

        A non-empty result is cached in self._sorted_migrations.
        """
        if not self._sorted_migrations:
            self._sorted_migrations = sorted(
                self.migration_registry.items(),
                # sort on the key... the migration number
                key=lambda migration_tuple: migration_tuple[0])

        return self._sorted_migrations

    @property
    def migration_data(self):
        """
        Get the migration row associated with this object, if any.

        Every access issues a new query; returns None when no row with
        this manager's name exists.
        """
        return self.session.query(
            self.migration_model).filter_by(name=self.name).first()

    @property
    def latest_migration(self):
        """
        Return a migration number for the latest migration, or 0 if
        there are no migrations.
        """
        if self.sorted_migrations:
            return self.sorted_migrations[-1][0]
        # If no migrations have been set, we start at 0.
        return 0

    @property
    def database_current_migration(self):
        """
        Return the current migration number recorded in the database,
        or None if the migrations table or this section's row is missing.
        """
        # If the table doesn't even exist, return None.
        if not self.migration_table.exists(self.session.bind):
            return None

        # Fetch the row only once: the migration_data property runs a
        # fresh query on every access.
        migration_data = self.migration_data
        if migration_data is None:
            return None

        return migration_data.version

    def set_current_migration(self, migration_number=None):
        """
        Set the migration in the database to migration_number
        (or, when migration_number is None, the latest available)
        and commit the session.

        The migration row for this section must already exist;
        otherwise an AttributeError is raised on None.
        """
        # Test against None explicitly so that an explicit 0 is honoured
        # instead of being silently replaced by the latest migration.
        if migration_number is None:
            migration_number = self.latest_migration
        self.migration_data.version = migration_number
        self.session.commit()

    def migrations_to_run(self):
        """
        Get a list of (migration_number, migration_func) tuples for the
        migrations newer than the one recorded in the database.

        Note that this will fail if there's no migration record for
        this class!
        """
        # Compute once: each access of the property hits the database.
        db_current_migration = self.database_current_migration
        assert db_current_migration is not None, \
            "No migration record found for %s" % self.name

        return [
            (migration_number, migration_func)
            for migration_number, migration_func in self.sorted_migrations
            if migration_number > db_current_migration]

    def init_tables(self):
        """
        Create all tables relative to this package
        """
        # sanity check before we proceed, none of these should be created
        for model in self.models:
            # Maybe in the future just print out a "Yikes!" or something?
            assert not model.__table__.exists(self.session.bind), \
                "Table %s already exists" % model.__table__.name

        self.migration_model.metadata.create_all(
            self.session.bind,
            tables=[model.__table__ for model in self.models])

    def create_new_migration_record(self):
        """
        Create a new migration record for this migration set, stamped
        with the latest migration number, and commit it.
        """
        migration_record = self.migration_model(
            name=self.name,
            version=self.latest_migration)
        self.session.add(migration_record)
        self.session.commit()

    def dry_run(self):
        """
        Print out a dry run of what we would have upgraded.

        Returns u'inited' if this section would be initialized,
        u'migrated' if migrations would be run, and None otherwise.
        Only reads from the database.
        """
        if self.database_current_migration is None:
            self.printer(
                u'~> Woulda initialized: %s\n' % self.name_for_printing())
            return u'inited'

        migrations_to_run = self.migrations_to_run()
        if migrations_to_run:
            self.printer(
                u'~> Woulda updated %s:\n' % self.name_for_printing())

            # migrations_to_run is a list; it must not be called.
            for migration_number, migration_func in migrations_to_run:
                self.printer(
                    u' + Would update %s, "%s"\n' % (
                        migration_number, migration_func.__name__))

            return u'migrated'

        return None

    def name_for_printing(self):
        """
        Return a human-readable description of this section for
        progress messages.
        """
        if self.name == u'__main__':
            return u"main mediagoblin tables"
        else:
            # TODO: Use the friendlier media manager "human readable" name
            return u'media type "%s"' % self.name

    def init_or_migrate(self):
        """
        Initialize the database or migrate if appropriate.

        Returns information about whether or not we initialized
        ('inited'), migrated ('migrated'), or did nothing (None)
        """
        assure_migrations_table_setup(self.session)

        # Find out what migration number, if any, this database data is at,
        # and what the latest is.
        migration_number = self.database_current_migration

        # Is this our first time? Is there even a table entry for
        # this identifier?
        # If so:
        # - create all tables
        # - create record in migrations registry
        # - print / inform the user
        # - return 'inited'
        if migration_number is None:
            self.printer(u"-> Initializing %s... " % self.name_for_printing())

            self.init_tables()
            # auto-set at latest migration number
            self.create_new_migration_record()

            self.printer(u"done.\n")
            self.set_current_migration()
            return u'inited'

        # Run migrations, if appropriate.
        migrations_to_run = self.migrations_to_run()
        if migrations_to_run:
            self.printer(
                u'-> Updating %s:\n' % self.name_for_printing())
            for migration_number, migration_func in migrations_to_run:
                self.printer(
                    u' + Running migration %s, "%s"... ' % (
                        migration_number, migration_func.__name__))
                migration_func(self.session)
                # Commit the version after each step so the recorded
                # version always reflects the last migration that finished.
                self.set_current_migration(migration_number)
                self.printer(u'done.\n')

            return u'migrated'

        # Otherwise return None. Well it would do this anyway, but
        # for clarity... ;)
        return None
224
225
class RegisterMigration(object):
    """
    Tool for registering migrations

    Call like:

        @RegisterMigration(33, MIGRATIONS)
        def update_dwarves(db_conn):
            [...]

    This registers your migration in the migration_registry passed as
    the second argument: a mapping (e.g. a dict) from migration number
    to migration function. There is no default registry, so the
    argument is required. MigrationManager.init_or_migrate() later
    calls each pending migration function with its session as the only
    argument.

    The decorated function is returned unchanged.

    Note, the number of your migration should NEVER be 0 or less than
    0. 0 is the default "no migrations" state!
    """
    def __init__(self, migration_number, migration_registry):
        # NOTE: these checks use assert and are therefore skipped when
        # Python runs with -O.
        assert migration_number > 0, "Migration number must be > 0!"
        assert migration_number not in migration_registry, \
            "Duplicate migration numbers detected! That's not allowed!"

        self.migration_number = migration_number
        self.migration_registry = migration_registry

    def __call__(self, migration):
        """Record ``migration`` under this number and return it as-is."""
        self.migration_registry[self.migration_number] = migration
        return migration
254
255
def assure_migrations_table_setup(db):
    """
    Create the MigrationData table if the database does not have it yet.

    Only ``db.bind`` is used (MigrationManager passes its session here).
    When the table already exists this is a no-op.
    """
    from mediagoblin.db.models import MigrationData

    migrations_table = MigrationData.__table__
    if migrations_table.exists(db.bind):
        return

    MigrationData.metadata.create_all(
        db.bind, tables=[migrations_table])
265
266
def inspect_table(metadata, table_name):
    """
    Return a reference to the already existing table ``table_name``.

    The table definition is reflected (autoload) using ``metadata.bind``
    as the connectable. This presumably expects ``metadata`` to be
    bound to an engine; confirm at call sites.
    """
    connectable = metadata.bind
    return Table(table_name, metadata,
                 autoload=True, autoload_with=connectable)