Use custom query class
[mediagoblin.git] / mediagoblin / db / sql / base.py
1 from sqlalchemy.orm import scoped_session, sessionmaker, object_session
2 from sqlalchemy.orm.query import Query
3 from sqlalchemy.sql.expression import desc
4 from mediagoblin.db.sql.fake import DESCENDING
5
6
7 def _get_query_model(query):
8 cols = query.column_descriptions
9 assert len(cols) == 1, "These functions work only on simple queries"
10 return cols[0]["type"]
11
12
class GMGQuery(Query):
    """SQLAlchemy ``Query`` with extra ``sort`` and ``skip`` helpers.

    The method names presumably mirror the pymongo cursor API used
    elsewhere in the project -- confirm against callers.
    """

    def sort(self, key, direction):
        """Order the query by the model attribute named *key*.

        The ordering is descending when *direction* is the ``DESCENDING``
        marker object (compared by identity) and ascending otherwise.
        """
        model = _get_query_model(self)
        column = getattr(model, key)
        ordering = desc(column) if direction is DESCENDING else column
        return self.order_by(ordering)

    def skip(self, amount):
        """Return the query with its first *amount* rows skipped (``offset``)."""
        return self.offset(amount)
22
23
# Scoped session registry.  Every session it creates is configured with
# GMGQuery as its query class, so queries gain the ``sort``/``skip`` helpers.
Session = scoped_session(
    sessionmaker(query_cls=GMGQuery),
)
25
26
27 def _fix_query_dict(query_dict):
28 if '_id' in query_dict:
29 query_dict['id'] = query_dict.pop('_id')
30
31
class GMGTableBase(object):
    """Mixin giving mapped classes a small dict-based query/save API.

    ``find``/``find_one``/``one`` take a dict of equality conditions
    (``'_id'`` is accepted as an alias for ``'id'``); ``save`` adds the
    object to a session and commits.
    """

    # Class-level property producing a query against the mapped class from
    # the module-level scoped ``Session``.
    query = Session.query_property()

    @classmethod
    def find(cls, query_dict=None):
        """Return a query filtered by the conditions in *query_dict*.

        :param query_dict: mapping of column name to required value;
            ``None`` or empty means no filtering.  A ``'_id'`` key is
            treated as ``'id'``.  The caller's mapping is not modified.
        """
        # Work on a private copy: _fix_query_dict renames keys in place,
        # which used to rewrite the caller's dict as a side effect.  The
        # ``None`` default also avoids a shared mutable default argument.
        conditions = dict(query_dict or {})
        _fix_query_dict(conditions)
        return cls.query.filter_by(**conditions)

    @classmethod
    def find_one(cls, query_dict=None):
        """Return the first object matching *query_dict*, or ``None``."""
        return cls.find(query_dict).first()

    @classmethod
    def one(cls, query_dict):
        """Return exactly one object matching *query_dict*.

        Whatever ``Query.one()`` raises when zero or several rows match
        is propagated unchanged.
        """
        return cls.find(query_dict).one()

    def get(self, key):
        """Return the attribute named *key* (dict-style accessor)."""
        return getattr(self, key)

    def save(self, validate=True):
        """Add this object to its session (or a new scoped one) and commit.

        :param validate: only ``True`` is supported.
        :raises NotImplementedError: if *validate* is false.
        """
        # Explicit check instead of ``assert`` so the guard still applies
        # when Python runs with -O.
        if not validate:
            raise NotImplementedError("save(validate=False) is not supported")
        sess = object_session(self)
        if sess is None:
            sess = Session()
        sess.add(self)
        sess.commit()