1 # GNU MediaGoblin -- federated, autonomous media hosting
2 # Copyright (C) 2011,2012 MediaGoblin contributors. See AUTHORS.
4 # This program is free software: you can redistribute it and/or modify
5 # it under the terms of the GNU Affero General Public License as published by
6 # the Free Software Foundation, either version 3 of the License, or
7 # (at your option) any later version.
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU Affero General Public License for more details.
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
18 from sqlalchemy
.orm
import scoped_session
, sessionmaker
, object_session
19 from sqlalchemy
.orm
.query
import Query
20 from sqlalchemy
.sql
.expression
import desc
21 from mediagoblin
.db
.sql
.fake
import DESCENDING
def _get_query_model(query):
    """Return the mapped model behind a single-entity query.

    Looks at ``query.column_descriptions`` and hands back the ``"type"``
    entry of its first (and only expected) description.

    Raises AssertionError, while assertions are enabled, when the query
    does not describe exactly one entry.
    """
    descriptions = query.column_descriptions
    assert len(descriptions) == 1, "These functions work only on simple queries"
    first_description = descriptions[0]
    return first_description["type"]
class GMGQuery(Query):
    """SQLAlchemy ``Query`` subclass providing ``sort``/``skip`` helpers.

    Both helpers are thin wrappers over ``order_by`` and ``offset`` and
    return the resulting query, so calls can be chained.
    """

    def sort(self, key, direction):
        """Return this query ordered by the attribute *key* of its model.

        The model comes from ``_get_query_model(self)``, which accepts only
        single-entity queries. If *direction* is the ``DESCENDING`` object,
        the column is wrapped in ``desc()``. Any other value orders by the
        bare column.
        """
        model = _get_query_model(self)
        column = getattr(model, key)
        # Identity check: DESCENDING is presumably a unique sentinel object
        # from mediagoblin.db.sql.fake (NOTE(review): confirm there).
        ordering = desc(column) if direction is DESCENDING else column
        return self.order_by(ordering)

    def skip(self, amount):
        """Return this query with its first *amount* rows skipped (OFFSET)."""
        return self.offset(amount)
# Scoped session registry whose sessions are built by a sessionmaker
# configured with GMGQuery, so ``Session.query(...)`` results expose the
# ``sort``/``skip`` helpers defined above.
Session = scoped_session(
    session_factory=sessionmaker(query_cls=GMGQuery),
)
def _fix_query_dict(query_dict):
    """Rename an ``'_id'`` key of *query_dict* to ``'id'``, in place.

    Dictionaries without ``'_id'`` are left untouched. An existing ``'id'``
    entry is overwritten by the ``'_id'`` value. Returns None.
    """
    # Presumably lets callers use an '_id'-style key (Mongo-like API)
    # against the SQL 'id' column -- NOTE(review): confirm with callers.
    if '_id' not in query_dict:
        return
    query_dict['id'] = query_dict.pop('_id')
49 class GMGTableBase(object):
50 query
= Session
.query_property()
53 def find(cls
, query_dict
={}):
54 _fix_query_dict(query_dict
)
55 return cls
.query
.filter_by(**query_dict
)
58 def find_one(cls
, query_dict
={}):
59 _fix_query_dict(query_dict
)
60 return cls
.query
.filter_by(**query_dict
).first()
63 def one(cls
, query_dict
):
64 return cls
.find(query_dict
).one()
67 return getattr(self
, key
)
69 def save(self
, validate
=True):
71 sess
= object_session(self
)