Fixes for small bugs
[mediagoblin.git] / mediagoblin / oauth / views.py
index 14c8ab140c9bdcac3b54673ed9c606e53a54341c..ef91eb911f867f228635b01d4da7bc0a3149d5af 100644 (file)
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
 import datetime
+import urllib
 
+import six
+
+from oauthlib.oauth1.rfc5849.utils import UNICODE_ASCII_CHARACTER_SET
 from oauthlib.oauth1 import (RequestTokenEndpoint, AuthorizationEndpoint,
                              AccessTokenEndpoint)
 
@@ -35,7 +39,7 @@ from mediagoblin.oauth.tools.forms import WTFormData
 from mediagoblin.db.models import NonceTimestamp, Client, RequestToken
 
 # possible client types
-client_types = ["web", "native"] # currently what pump supports
+CLIENT_TYPES = ["web", "native"] # currently what pump supports
 
 @csrf_exempt
 def client_register(request):
@@ -53,7 +57,7 @@ def client_register(request):
     if "type" not in data:
         error = "No registration type provided."
         return json_response({"error": error}, status=400)
-    if data.get("application_type", None) not in client_types:
+    if data.get("application_type", None) not in CLIENT_TYPES:
         error = "Unknown application_type."
         return json_response({"error": error}, status=400)
 
@@ -88,7 +92,7 @@ def client_register(request):
                 )
 
         app_name = ("application_type", client.application_name)
-        if app_name in client_types:
+        if app_name in CLIENT_TYPES:
             client.application_name = app_name
 
     elif client_type == "client_associate":
@@ -104,8 +108,8 @@ def client_register(request):
             return json_response({"error": error}, status=400)
 
         # generate the client_id and client_secret
-        client_id = random_string(22) # seems to be what pump uses
-        client_secret = random_string(43) # again, seems to be what pump uses
+        client_id = random_string(22, UNICODE_ASCII_CHARACTER_SET)
+        client_secret = random_string(43, UNICODE_ASCII_CHARACTER_SET)
         expirey = 0 # for now, lets not have it expire
         expirey_db = None if expirey == 0 else expirey
         application_type = data["application_type"]
@@ -122,21 +126,21 @@ def client_register(request):
         error = "Invalid registration type"
         return json_response({"error": error}, status=400)
 
-    logo_url = data.get("logo_url", client.logo_url)
-    if logo_url is not None and not validate_url(logo_url):
-        error = "Logo URL {0} is not a valid URL.".format(logo_url)
+    logo_uri = data.get("logo_uri", client.logo_url)
+    if logo_uri is not None and not validate_url(logo_uri):
+        error = "Logo URI {0} is not a valid URI.".format(logo_uri)
         return json_response(
                 {"error": error},
                 status=400
                 )
     else:
-        client.logo_url = logo_url
+        client.logo_url = logo_uri
 
     client.application_name = data.get("application_name", None)
 
     contacts = data.get("contacts", None)
     if contacts is not None:
-        if type(contacts) is not unicode:
+        if not isinstance(contacts, six.text_type):
             error = "Contacts must be a string of space-seporated email addresses."
             return json_response({"error": error}, status=400)
 
@@ -152,7 +156,7 @@ def client_register(request):
 
     redirect_uris = data.get("redirect_uris", None)
     if redirect_uris is not None:
-        if type(redirect_uris) is not unicode:
+        if not isinstance(redirect_uris, six.text_type):
             error = "redirect_uris must be space-seporated URLs."
             return json_response({"error": error}, status=400)
 
@@ -187,10 +191,6 @@ def request_token(request):
         error = "Could not decode data."
         return json_response({"error": error}, status=400)
 
-    if data == "":
-        error = "Unknown Content-Type"
-        return json_response({"error": error}, status=400)
-
     if not data and request.headers:
         data = request.headers
 
@@ -211,7 +211,7 @@ def request_token(request):
         error = "Invalid client_id"
         return json_response({"error": error}, status=400)
 
-   # make request token and return to client
+    # make request token and return to client
     request_validator = GMGRequestValidator(authorization)
     rv = RequestTokenEndpoint(request_validator)
     tokens = rv.create_request_token(request, authorization)
@@ -249,12 +249,13 @@ def authorize(request):
 
     if oauth_request.verifier is None:
         orequest = GMGRequest(request)
+        orequest.resource_owner_key = token
         request_validator = GMGRequestValidator()
         auth_endpoint = AuthorizationEndpoint(request_validator)
         verifier = auth_endpoint.create_verifier(orequest, {})
         oauth_request.verifier = verifier["oauth_verifier"]
 
-    oauth_request.user = request.user.id
+    oauth_request.actor = request.user.id
     oauth_request.save()
 
     # find client & build context
@@ -313,10 +314,13 @@ def authorize_finish(request):
             oauth_request.verifier
             )
 
+    # It's come from the OAuth headers so it'll be encoded.
+    redirect_url = urllib.unquote(oauth_request.callback)
+
     return redirect(
             request,
             querystring=querystring,
-            location=oauth_request.callback
+            location=redirect_url
             )
 
 @csrf_exempt
@@ -330,10 +334,19 @@ def access_token(request):
         error = "Missing required parameter."
         return json_response({"error": error}, status=400)
 
-
+    request.resource_owner_key = parsed_tokens["oauth_consumer_key"]
     request.oauth_token = parsed_tokens["oauth_token"]
     request_validator = GMGRequestValidator(data)
+
+    # Check that the verifier is valid
+    verifier_valid = request_validator.validate_verifier(
+        token=request.oauth_token,
+        verifier=parsed_tokens["oauth_verifier"]
+    )
+    if not verifier_valid:
+        error = "Verifier code or token incorrect"
+        return json_response({"error": error}, status=401)
+
     av = AccessTokenEndpoint(request_validator)
     tokens = av.create_access_token(request, {})
     return form_response(tokens)
-