Added migration for license field, resolved conflict in db/sql/models.py
[mediagoblin.git] / mediagoblin / db / sql / base.py
1 # GNU MediaGoblin -- federated, autonomous media hosting
2 # Copyright (C) 2011,2012 MediaGoblin contributors. See AUTHORS.
3 #
4 # This program is free software: you can redistribute it and/or modify
5 # it under the terms of the GNU Affero General Public License as published by
6 # the Free Software Foundation, either version 3 of the License, or
7 # (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU Affero General Public License for more details.
13 #
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16
17
18 from sqlalchemy.orm import scoped_session, sessionmaker, object_session
19 from sqlalchemy.orm.query import Query
20 from sqlalchemy.sql.expression import desc
21 from mediagoblin.db.sql.fake import DESCENDING
22
23
def _get_query_model(query):
    """Return the ``"type"`` entry of the query's only column description.

    Reads ``query.column_descriptions`` and requires it to contain exactly
    one entry.

    Raises:
        AssertionError: if the query describes zero or several columns.
    """
    cols = query.column_descriptions
    # Raised explicitly rather than via `assert`, so the check is not
    # stripped when Python runs with -O.
    if len(cols) != 1:
        raise AssertionError("These functions work only on simple queries")
    return cols[0]["type"]
28
29
class GMGQuery(Query):
    """Query subclass that adds ``sort`` and ``skip`` helper methods."""

    def sort(self, key, direction):
        """Return this query ordered by the model attribute named *key*.

        The attribute is looked up on the model reported by
        ``_get_query_model(self)``.  When *direction* is the ``DESCENDING``
        constant the column is wrapped in ``desc()``; any other value
        leaves the column unwrapped.
        """
        model = _get_query_model(self)
        column = getattr(model, key)
        ordering = desc(column) if direction is DESCENDING else column
        return self.order_by(ordering)

    def skip(self, amount):
        """Return this query with ``offset(amount)`` applied."""
        return self.offset(amount)
39
40
# Scoped session registry; its session factory is configured to build
# GMGQuery objects as the query class.
Session = scoped_session(
    sessionmaker(query_cls=GMGQuery),
)
42
43
def _fix_query_dict(query_dict):
    """Rename the ``'_id'`` key of *query_dict* to ``'id'``, in place.

    Dicts without an ``'_id'`` key are left untouched.  Nothing is returned.
    """
    if '_id' not in query_dict:
        return
    id_value = query_dict.pop('_id')
    query_dict['id'] = id_value
47
48
class GMGTableBase(object):
    """Base class providing dict-style lookup helpers and ``save``."""

    # Query property obtained from the module-level scoped Session.
    query = Session.query_property()

    @classmethod
    def find(cls, query_dict=None):
        """Return a query filtered by the key/value pairs in *query_dict*.

        An ``'_id'`` key is treated as ``'id'``.  The caller's dict is never
        modified.  Omitting *query_dict* (or passing ``None``) applies no
        filter criteria.
        """
        # Work on a copy: _fix_query_dict renames keys in place, and a
        # lookup must not alter the dict the caller handed in.
        criteria = dict(query_dict) if query_dict is not None else {}
        _fix_query_dict(criteria)
        return cls.query.filter_by(**criteria)

    @classmethod
    def find_one(cls, query_dict=None):
        """Return ``.first()`` of the query built by ``find(query_dict)``."""
        return cls.find(query_dict).first()

    @classmethod
    def one(cls, query_dict):
        """Return ``.one()`` of the query built by ``find(query_dict)``."""
        return cls.find(query_dict).one()

    def get(self, key):
        """Return the attribute named *key*.

        Raises:
            AttributeError: if this object has no such attribute.
        """
        return getattr(self, key)

    def save(self, validate=True):
        """Add this object to a session and commit that session.

        The session already associated with the object is used when there
        is one; otherwise a session is obtained from ``Session()``.

        Raises:
            AssertionError: if *validate* is false.
        """
        # Raised explicitly rather than via `assert`, so the guard is not
        # stripped when Python runs with -O.
        if not validate:
            raise AssertionError("save() requires validate to be true")
        sess = object_session(self)
        if sess is None:
            sess = Session()
        sess.add(self)
        sess.commit()