Updates so that dbupdate command works
[mediagoblin.git] / mediagoblin / db / sql / util.py
1 # GNU MediaGoblin -- federated, autonomous media hosting
2 # Copyright (C) 2011 MediaGoblin contributors. See AUTHORS.
3 #
4 # This program is free software: you can redistribute it and/or modify
5 # it under the terms of the GNU Affero General Public License as published by
6 # the Free Software Foundation, either version 3 of the License, or
7 # (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU Affero General Public License for more details.
13 #
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16
17
18 import sys
19
def _simple_printer(string):
    """
    Write ``string`` to standard output exactly as given.

    Unlike ``print``, no trailing newline is appended.  The stream is
    flushed right away so partial progress lines show up immediately.
    """
    stream = sys.stdout
    stream.write(string)
    stream.flush()
26
27
class MigrationManager(object):
    """
    Migration handling tool.

    Takes information about a database, lets you update the database
    to the latest migrations, etc.
    """

    def __init__(self, name, models, migration_registry, session,
                 printer=_simple_printer):
        """
        Args:
         - name: identifier of this section of the database
         - models: model classes whose tables init_tables() creates
         - migration_registry: mapping of migration number to
           migration function; where we should find all migrations
           to run
         - session: session we're going to migrate
         - printer: callable taking one string, used for all progress
           output
        """
        self.name = name
        self.models = models
        self.session = session
        self.migration_registry = migration_registry
        self._sorted_migrations = None
        self.printer = printer

        # For convenience
        from mediagoblin.db.sql.models import MigrationData

        self.migration_model = MigrationData
        self.migration_table = MigrationData.__table__

    @property
    def sorted_migrations(self):
        """
        Sort migrations if necessary and store in self._sorted_migrations

        Returns a list of (migration_number, migration_func) tuples in
        ascending migration number order.
        """
        # An empty cached list is falsy too, so an empty registry is
        # simply re-sorted on each access until it has entries.
        if not self._sorted_migrations:
            self._sorted_migrations = sorted(
                self.migration_registry.items(),
                # sort on the key... the migration number
                key=lambda migration_tuple: migration_tuple[0])

        return self._sorted_migrations

    @property
    def migration_data(self):
        """
        Get the migration row associated with this object, if any.

        Returns None when no row with this manager's name exists.
        """
        return self.session.query(
            self.migration_model).filter_by(name=self.name).first()

    @property
    def latest_migration(self):
        """
        Return a migration number for the latest migration, or 0 if
        there are no migrations.
        """
        if self.sorted_migrations:
            return self.sorted_migrations[-1][0]
        else:
            # If no migrations have been set, we start at 0.
            return 0

    @property
    def database_current_migration(self):
        """
        Return the current migration in the database.

        Returns None if either the migrations table or this manager's
        row in it does not exist.
        """
        # If the table doesn't even exist, return None.
        if not self.migration_table.exists(self.session.bind):
            return None

        # Also return None if self.migration_data is None.
        if self.migration_data is None:
            return None

        return self.migration_data.version

    def set_current_migration(self, migration_number=None):
        """
        Set the migration in the database to migration_number
        (or, the latest available)

        Any falsy migration_number (None or 0) selects the latest
        available migration.  The change is committed immediately.
        """
        self.migration_data.version = migration_number or self.latest_migration
        self.session.commit()

    def migrations_to_run(self):
        """
        Get a list of migrations to run still, if any.

        Returns (migration_number, migration_func) tuples, in ascending
        order, for every migration newer than the database's version.

        Note that this will fail if there's no migration record for
        this class!
        """
        # Read the value once: each access to the property hits the
        # database.
        db_current_migration = self.database_current_migration
        assert db_current_migration is not None

        return [
            (migration_number, migration_func)
            for migration_number, migration_func in self.sorted_migrations
            if migration_number > db_current_migration]

    def init_tables(self):
        """
        Create all tables relative to this package
        """
        # sanity check before we proceed, none of these should be created
        for model in self.models:
            # Maybe in the future just print out a "Yikes!" or something?
            assert not model.__table__.exists(self.session.bind)

        self.migration_model.metadata.create_all(
            self.session.bind,
            tables=[model.__table__ for model in self.models])

    def create_new_migration_record(self):
        """
        Create a new migration record for this migration set

        The record starts at the latest known migration number and is
        committed immediately.
        """
        migration_record = self.migration_model(
            name=self.name,
            version=self.latest_migration)
        self.session.add(migration_record)
        self.session.commit()

    def dry_run(self):
        """
        Print out a dry run of what we would have upgraded.

        Returns u'inited' if the tables would have been initialized,
        u'migrated' if migrations would have been run, or None if
        there is nothing to do.
        """
        if self.database_current_migration is None:
            self.printer(
                u'~> Woulda initialized: %s\n' % self.name_for_printing())
            return u'inited'

        migrations_to_run = self.migrations_to_run()
        if migrations_to_run:
            self.printer(
                u'~> Woulda updated %s:\n' % self.name_for_printing())

            # migrations_to_run is already a list here; iterate it
            # directly rather than calling it.
            for migration_number, migration_func in migrations_to_run:
                self.printer(
                    u' + Would update %s, "%s"\n' % (
                        migration_number, migration_func.__name__))

            return u'migrated'

        # Nothing to initialize and nothing to migrate.
        return None

    def name_for_printing(self):
        """
        Return a human readable label for this migration set.
        """
        if self.name == u'__main__':
            return u"main mediagoblin tables"
        else:
            # TODO: Use the friendlier media manager "human readable" name
            return u'media type "%s"' % self.name

    def init_or_migrate(self):
        """
        Initialize the database or migrate if appropriate.

        Returns information about whether or not we initialized
        ('inited'), migrated ('migrated'), or did nothing (None)
        """
        assure_migrations_table_setup(self.session)

        # Find out what migration number, if any, this database data is at,
        # and what the latest is.
        migration_number = self.database_current_migration

        # Is this our first time?  Is there even a table entry for
        # this identifier?
        # If so:
        #  - create all tables
        #  - create record in migrations registry
        #  - print / inform the user
        #  - return 'inited'
        if migration_number is None:
            self.printer(u"-> Initializing %s... " % self.name_for_printing())

            self.init_tables()
            # auto-set at latest migration number
            self.create_new_migration_record()

            self.printer(u"done.\n")
            # The new record already holds the latest migration number;
            # this stores that same value again and commits.
            self.set_current_migration()
            return u'inited'

        # Run migrations, if appropriate.
        migrations_to_run = self.migrations_to_run()
        if migrations_to_run:
            self.printer(
                u'-> Updating %s:\n' % self.name_for_printing())
            for pending_number, migration_func in migrations_to_run:
                self.printer(
                    u' + Running migration %s, "%s"... ' % (
                        pending_number, migration_func.__name__))
                migration_func(self.session)
                self.printer('done.\n')

            # Record that the database is now at the latest migration.
            self.set_current_migration()
            return u'migrated'

        # Otherwise return None.  Well it would do this anyway, but
        # for clarity... ;)
        return None
231
232
class RegisterMigration(object):
    """
    Decorator that files a migration function under a number.

    Use it like so, handing over the registry (a mapping of migration
    number to migration function) the migration belongs in:

        @RegisterMigration(33, MIGRATIONS)
        def update_dwarves(database):
            [...]

    Migration numbers must be strictly greater than 0, since 0 means
    "no migrations have been run yet", and each number may be used
    only once per registry.
    """
    def __init__(self, migration_number, migration_registry):
        """
        Validate the number against the registry and remember both.

        Fails an assertion if the number is not positive or is already
        present in the registry.
        """
        assert migration_number > 0, (
            "Migration number must be > 0!")
        assert migration_number not in migration_registry, (
            "Duplicate migration numbers detected! That's not allowed!")

        self.migration_registry = migration_registry
        self.migration_number = migration_number

    def __call__(self, migration):
        """
        Store ``migration`` in the registry and hand it back unchanged.
        """
        registry = self.migration_registry
        registry[self.migration_number] = migration
        return migration
261
262
def assure_migrations_table_setup(db):
    """
    Create the migrations bookkeeping table unless it already exists.

    ``db`` is the session whose ``bind`` is checked and, if needed,
    used to create the MigrationData table.
    """
    from mediagoblin.db.sql.models import MigrationData

    migrations_table = MigrationData.__table__
    if migrations_table.exists(db.bind):
        # Already there, nothing to set up.
        return

    MigrationData.metadata.create_all(db.bind, tables=[migrations_table])