Merge branch 'master' into joar-skip_transcoding
[mediagoblin.git] / mediagoblin / decorators.py
index 2955c9276688095208016b3c839b3eaf94e76aad..fbf7b1887c9a15c439a0748820165ab812a2b238 100644 (file)
 from functools import wraps
 
 from urlparse import urljoin
-from urllib import urlencode
+from werkzeug.exceptions import Forbidden, NotFound
+from werkzeug.urls import url_quote
 
-from webob import exc
-
-from mediagoblin.db.util import ObjectId, InvalidId
-from mediagoblin.db.sql.models import User
+from mediagoblin import mg_globals as mgg
+from mediagoblin.db.models import MediaEntry, User
 from mediagoblin.tools.response import redirect, render_404
 
 
@@ -33,21 +32,18 @@ def require_active_login(controller):
     @wraps(controller)
     def new_controller_func(request, *args, **kwargs):
         if request.user and \
-                request.user.get('status') == u'needs_email_verification':
+                request.user.status == u'needs_email_verification':
             return redirect(
                 request, 'mediagoblin.user_pages.user_home',
                 user=request.user.username)
-        elif not request.user or request.user.get('status') != u'active':
+        elif not request.user or request.user.status != u'active':
             next_url = urljoin(
                     request.urlgen('mediagoblin.auth.login',
                         qualified=True),
                     request.url)
 
-            return exc.HTTPFound(
-                location='?'.join([
-                    request.urlgen('mediagoblin.auth.login'),
-                    urlencode({
-                        'next': next_url})]))
+            return redirect(request, 'mediagoblin.auth.login',
+                            next=url_quote(next_url))
 
         return controller(request, *args, **kwargs)
 
@@ -74,11 +70,10 @@ def user_may_delete_media(controller):
     """
     @wraps(controller)
     def wrapper(request, *args, **kwargs):
-        uploader_id = request.db.MediaEntry.find_one(
-            {'id': ObjectId(request.matchdict['media'])}).uploader
+        uploader_id = kwargs['media'].uploader
         if not (request.user.is_admin or
                 request.user.id == uploader_id):
-            return exc.HTTPForbidden()
+            raise Forbidden()
 
         return controller(request, *args, **kwargs)
 
@@ -95,7 +90,7 @@ def user_may_alter_collection(controller):
             {'username': request.matchdict['user']}).id
         if not (request.user.is_admin or
                 request.user.id == creator_id):
-            return exc.HTTPForbidden()
+            raise Forbidden()
 
         return controller(request, *args, **kwargs)
 
@@ -126,29 +121,34 @@ def get_user_media_entry(controller):
     """
     @wraps(controller)
     def wrapper(request, *args, **kwargs):
-        user = request.db.User.find_one(
-            {'username': request.matchdict['user']})
-
+        user = User.query.filter_by(username=request.matchdict['user']).first()
         if not user:
-            return render_404(request)
-        media = request.db.MediaEntry.find_one(
-            {'slug': request.matchdict['media'],
-             'state': u'processed',
-             'uploader': user.id})
+            raise NotFound()
 
-        # no media via slug?  Grab it via ObjectId
-        if not media:
+        media = None
+
+        # might not be a slug, might be an id, but whatever
+        media_slug = request.matchdict['media']
+
+        # if it starts with id: it actually isn't a slug, it's an id.
+        if media_slug.startswith(u'id:'):
             try:
-                media = request.db.MediaEntry.find_one(
-                    {'id': ObjectId(request.matchdict['media']),
-                     'state': u'processed',
-                     'uploader': user.id})
-            except InvalidId:
-                return render_404(request)
+                media = MediaEntry.query.filter_by(
+                    id=int(media_slug[3:]),
+                    state=u'processed',
+                    uploader=user.id).first()
+            except ValueError:
+                raise NotFound()
+        else:
+            # no magical id: stuff?  It's a slug!
+            media = MediaEntry.query.filter_by(
+                slug=media_slug,
+                state=u'processed',
+                uploader=user.id).first()
 
-            # Still no media?  Okay, 404.
-            if not media:
-                return render_404(request)
+        if not media:
+            # Didn't find anything?  Okay, 404.
+            raise NotFound()
 
         return controller(request, media=media, *args, **kwargs)
 
@@ -192,10 +192,6 @@ def get_user_collection_item(controller):
         if not user:
             return render_404(request)
 
-        collection = request.db.Collection.find_one(
-            {'slug': request.matchdict['collection'],
-             'creator': user.id})
-
         collection_item = request.db.CollectionItem.find_one(
             {'id': request.matchdict['collection_item'] })
 
@@ -214,17 +210,28 @@ def get_media_entry_by_id(controller):
     """
     @wraps(controller)
     def wrapper(request, *args, **kwargs):
-        try:
-            media = request.db.MediaEntry.find_one(
-                {'id': ObjectId(request.matchdict['media']),
-                 'state': u'processed'})
-        except InvalidId:
-            return render_404(request)
-
+        media = MediaEntry.query.filter_by(
+                id=request.matchdict['media_id'],
+                state=u'processed').first()
         # Still no media?  Okay, 404.
         if not media:
             return render_404(request)
 
+        given_username = request.matchdict.get('user')
+        if given_username and (given_username != media.get_uploader.username):
+            return render_404(request)
+
         return controller(request, media=media, *args, **kwargs)
 
     return wrapper
+
+
+def get_workbench(func):
+    """Decorator, passing in a workbench as kwarg which is cleaned up afterwards"""
+
+    @wraps(func)
+    def new_func(*args, **kwargs):
+        with mgg.workbench_manager.create() as workbench:
+            return func(*args, workbench=workbench, **kwargs)
+
+    return new_func