Correcting a couple of spelling errors. Thanks elesa, for finding them!
[mediagoblin.git] / mediagoblin / decorators.py
index 229664d79cd2bc6963326c0d9c076dc3ca0a6a5f..9be9d4cc1246b6d8348abab0bfabb6ba3b9f672a 100644 (file)
@@ -1,5 +1,5 @@
 # GNU MediaGoblin -- federated, autonomous media hosting
-# Copyright (C) 2011 MediaGoblin contributors.  See AUTHORS.
+# Copyright (C) 2011, 2012 MediaGoblin contributors.  See AUTHORS.
 #
 # This program is free software: you can redistribute it and/or modify
 # it under the terms of the GNU Affero General Public License as published by
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
+from functools import wraps
+
+from urlparse import urljoin
+from urllib import urlencode
 
 from webob import exc
 
-from mediagoblin.tools.response import redirect, render_404
 from mediagoblin.db.util import ObjectId, InvalidId
-
-
-def _make_safe(decorator, original):
-    """
-    Copy the function data from the old function to the decorator.
-    """
-    decorator.__name__ = original.__name__
-    decorator.__dict__ = original.__dict__
-    decorator.__doc__ = original.__doc__
-    return decorator
+from mediagoblin.tools.response import redirect, render_404
 
 
 def require_active_login(controller):
     """
     Require an active login from the user.
     """
+    @wraps(controller)
     def new_controller_func(request, *args, **kwargs):
         if request.user and \
                 request.user.get('status') == u'needs_email_verification':
@@ -42,36 +37,61 @@ def require_active_login(controller):
                 request, 'mediagoblin.user_pages.user_home',
                 user=request.user.username)
         elif not request.user or request.user.get('status') != u'active':
+            next_url = urljoin(
+                    request.urlgen('mediagoblin.auth.login',
+                        qualified=True),
+                    request.url)
+
             return exc.HTTPFound(
-                location="%s?next=%s" % (
-                    request.urlgen("mediagoblin.auth.login"),
-                    request.full_path))
+                location='?'.join([
+                    request.urlgen('mediagoblin.auth.login'),
+                    urlencode({
+                        'next': next_url})]))
 
         return controller(request, *args, **kwargs)
 
-    return _make_safe(new_controller_func, controller)
+    return new_controller_func
 
 
 def user_may_delete_media(controller):
     """
     Require user ownership of the MediaEntry to delete.
     """
+    @wraps(controller)
     def wrapper(request, *args, **kwargs):
-        uploader = request.db.MediaEntry.find_one(
-            {'_id': ObjectId(request.matchdict['media'])}).get_uploader()
+        uploader_id = request.db.MediaEntry.find_one(
+            {'_id': ObjectId(request.matchdict['media'])}).uploader
         if not (request.user.is_admin or
-                request.user._id == uploader._id):
+                request.user._id == uploader_id):
             return exc.HTTPForbidden()
 
         return controller(request, *args, **kwargs)
 
-    return _make_safe(wrapper, controller)
+    return wrapper
+
+
+def user_may_alter_collection(controller):
+    """
+    Require user ownership of the Collection to modify.
+    """
+    @wraps(controller)
+    def wrapper(request, *args, **kwargs):
+        creator_id = request.db.User.find_one(
+            {'username': request.matchdict['user']}).id
+        if not (request.user.is_admin or
+                request.user._id == creator_id):
+            return exc.HTTPForbidden()
+
+        return controller(request, *args, **kwargs)
+
+    return wrapper
 
 
 def uses_pagination(controller):
     """
     Check request GET 'page' key for wrong values
     """
+    @wraps(controller)
     def wrapper(request, *args, **kwargs):
         try:
             page = int(request.GET.get('page', 1))
@@ -82,13 +102,14 @@ def uses_pagination(controller):
 
         return controller(request, page=page, *args, **kwargs)
 
-    return _make_safe(wrapper, controller)
+    return wrapper
 
 
 def get_user_media_entry(controller):
     """
     Pass in a MediaEntry based off of a url component
     """
+    @wraps(controller)
     def wrapper(request, *args, **kwargs):
         user = request.db.User.find_one(
             {'username': request.matchdict['user']})
@@ -97,7 +118,7 @@ def get_user_media_entry(controller):
             return render_404(request)
         media = request.db.MediaEntry.find_one(
             {'slug': request.matchdict['media'],
-             'state': 'processed',
+             'state': u'processed',
              'uploader': user._id})
 
         # no media via slug?  Grab it via ObjectId
@@ -105,7 +126,7 @@ def get_user_media_entry(controller):
             try:
                 media = request.db.MediaEntry.find_one(
                     {'_id': ObjectId(request.matchdict['media']),
-                     'state': 'processed',
+                     'state': u'processed',
                      'uploader': user._id})
             except InvalidId:
                 return render_404(request)
@@ -116,18 +137,72 @@ def get_user_media_entry(controller):
 
         return controller(request, media=media, *args, **kwargs)
 
-    return _make_safe(wrapper, controller)
+    return wrapper
+
+
+def get_user_collection(controller):
+    """
+    Pass in a Collection based off of a url component
+    """
+    @wraps(controller)
+    def wrapper(request, *args, **kwargs):
+        user = request.db.User.find_one(
+            {'username': request.matchdict['user']})
+
+        if not user:
+            return render_404(request)
+
+        collection = request.db.Collection.find_one(
+            {'slug': request.matchdict['collection'],
+             'creator': user._id})
+
+        # Still no collection?  Okay, 404.
+        if not collection:
+            return render_404(request)
+
+        return controller(request, collection=collection, *args, **kwargs)
+
+    return wrapper
+
+
+def get_user_collection_item(controller):
+    """
+    Pass in a CollectionItem based off of a url component
+    """
+    @wraps(controller)
+    def wrapper(request, *args, **kwargs):
+        user = request.db.User.find_one(
+            {'username': request.matchdict['user']})
+
+        if not user:
+            return render_404(request)
+
+        collection = request.db.Collection.find_one(
+            {'slug': request.matchdict['collection'],
+             'creator': user._id})
+
+        collection_item = request.db.CollectionItem.find_one(
+            {'_id': request.matchdict['collection_item'] })
+
+        # Still no collection item?  Okay, 404.
+        if not collection_item:
+            return render_404(request)
+
+        return controller(request, collection_item=collection_item, *args, **kwargs)
+
+    return wrapper
 
 
 def get_media_entry_by_id(controller):
     """
     Pass in a MediaEntry based off of a url component
     """
+    @wraps(controller)
     def wrapper(request, *args, **kwargs):
         try:
             media = request.db.MediaEntry.find_one(
                 {'_id': ObjectId(request.matchdict['media']),
-                 'state': 'processed'})
+                 'state': u'processed'})
         except InvalidId:
             return render_404(request)
 
@@ -137,4 +212,4 @@ def get_media_entry_by_id(controller):
 
         return controller(request, media=media, *args, **kwargs)
 
-    return _make_safe(wrapper, controller)
+    return wrapper