OAuth: Support refresh tokens, etc
authorJoar Wandborg <joar@wandborg.se>
Sun, 10 Mar 2013 21:52:07 +0000 (22:52 +0100)
committerJoar Wandborg <joar@wandborg.se>
Sat, 6 Apr 2013 20:17:27 +0000 (22:17 +0200)
Initially I was going to write a failing test for refresh tokens. Thus
this fix includes an orphaned 'expect_failure' method in test utils.

I ended up writing support for OAuth refresh tokens, as well as a lot of
cleanup (hopefully) in the OAuth plugin code.

**Rebase**: While waiting for this stuff to be merged, the testing
framework changed, it comes with batteries included regarding fails.
Removed legacy nosetest helper.

Also added a lot of backref=backref([...], cascade='all, delete-orphan')

mediagoblin/plugins/oauth/__init__.py
mediagoblin/plugins/oauth/migrations.py
mediagoblin/plugins/oauth/models.py
mediagoblin/plugins/oauth/tools.py
mediagoblin/plugins/oauth/views.py
mediagoblin/tests/test_oauth.py
mediagoblin/tests/tools.py

index 4714d95d69beab4e9f10bc68f4f67cd6c477aca0..5762379dea39c68f9fdddca1dd645424b23bb7b4 100644 (file)
@@ -34,7 +34,7 @@ def setup_plugin():
     _log.debug('OAuth config: {0}'.format(config))
 
     routes = [
-       ('mediagoblin.plugins.oauth.authorize',
+        ('mediagoblin.plugins.oauth.authorize',
             '/oauth/authorize',
             'mediagoblin.plugins.oauth.views:authorize'),
         ('mediagoblin.plugins.oauth.authorize_client',
index 6aa0d7cb51124f67a9ea8cf6cd892b279166518e..d7b89da34acbbc2622533f11c04fe86966f4e7ea 100644 (file)
@@ -102,6 +102,21 @@ class OAuthCode_v0(declarative_base()):
     client_id = Column(Integer, ForeignKey(OAuthClient_v0.id), nullable=False)
 
 
+class OAuthRefreshToken_v0(declarative_base()):
+    __tablename__ = 'oauth__refresh_tokens'
+
+    id = Column(Integer, primary_key=True)
+    created = Column(DateTime, nullable=False,
+                     default=datetime.now)
+
+    token = Column(Unicode, index=True)
+
+    user_id = Column(Integer, ForeignKey(User.id), nullable=False)
+
+    # XXX: Is it OK to use OAuthClient_v0.id in this way?
+    client_id = Column(Integer, ForeignKey(OAuthClient_v0.id), nullable=False)
+
+
 @RegisterMigration(1, MIGRATIONS)
 def remove_and_replace_token_and_code(db):
     metadata = MetaData(bind=db.bind)
@@ -122,3 +137,22 @@ def remove_and_replace_token_and_code(db):
     OAuthCode_v0.__table__.create(db.bind)
 
     db.commit()
+
+
+@RegisterMigration(2, MIGRATIONS)
+def remove_refresh_token_field(db):
+    metadata = MetaData(bind=db.bind)
+
+    token_table = Table('oauth__tokens', metadata, autoload=True,
+                        autoload_with=db.bind)
+
+    refresh_token = token_table.columns['refresh_token']
+
+    refresh_token.drop()
+    db.commit()
+
+@RegisterMigration(3, MIGRATIONS)
+def create_refresh_token_table(db):
+    OAuthRefreshToken_v0.__table__.create(db.bind)
+
+    db.commit()
index 695dad3192fb1b8ec254cfb5199de6e7fe6a2679..439424d369a3f1dbf2fd6d39ad20bb1415da171e 100644 (file)
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
-import uuid
-import bcrypt
 
 from datetime import datetime, timedelta
 
-from mediagoblin.db.base import Base
-from mediagoblin.db.models import User
 
 from sqlalchemy import (
         Column, Unicode, Integer, DateTime, ForeignKey, Enum)
-from sqlalchemy.orm import relationship
+from sqlalchemy.orm import relationship, backref
+from mediagoblin.db.base import Base
+from mediagoblin.db.models import User
+from mediagoblin.plugins.oauth.tools import generate_identifier, \
+    generate_secret, generate_token, generate_code, generate_refresh_token
 
 # Don't remove this, I *think* it applies sqlalchemy-migrate functionality onto
 # the models.
@@ -41,11 +41,14 @@ class OAuthClient(Base):
     name = Column(Unicode)
     description = Column(Unicode)
 
-    identifier = Column(Unicode, unique=True, index=True)
-    secret = Column(Unicode, index=True)
+    identifier = Column(Unicode, unique=True, index=True,
+                        default=generate_identifier)
+    secret = Column(Unicode, index=True, default=generate_secret)
 
     owner_id = Column(Integer, ForeignKey(User.id))
-    owner = relationship(User, backref='registered_clients')
+    owner = relationship(
+        User,
+        backref=backref('registered_clients', cascade='all, delete-orphan'))
 
     redirect_uri = Column(Unicode)
 
@@ -54,14 +57,8 @@ class OAuthClient(Base):
         u'public',
         name=u'oauth__client_type'))
 
-    def generate_identifier(self):
-        self.identifier = unicode(uuid.uuid4())
-
-    def generate_secret(self):
-        self.secret = unicode(
-                bcrypt.hashpw(
-                    unicode(uuid.uuid4()),
-                    bcrypt.gensalt()))
+    def update_secret(self):
+        self.secret = generate_secret()
 
     def __repr__(self):
         return '<{0} {1}:{2} ({3})>'.format(
@@ -76,10 +73,15 @@ class OAuthUserClient(Base):
     id = Column(Integer, primary_key=True)
 
     user_id = Column(Integer, ForeignKey(User.id))
-    user = relationship(User, backref='oauth_clients')
+    user = relationship(
+        User,
+        backref=backref('oauth_client_relations',
+                        cascade='all, delete-orphan'))
 
     client_id = Column(Integer, ForeignKey(OAuthClient.id))
-    client = relationship(OAuthClient, backref='users')
+    client = relationship(
+        OAuthClient,
+        backref=backref('oauth_user_relations', cascade='all, delete-orphan'))
 
     state = Column(Enum(
         u'approved',
@@ -103,15 +105,18 @@ class OAuthToken(Base):
             default=datetime.now)
     expires = Column(DateTime, nullable=False,
             default=lambda: datetime.now() + timedelta(days=30))
-    token = Column(Unicode, index=True)
-    refresh_token = Column(Unicode, index=True)
+    token = Column(Unicode, index=True, default=generate_token)
 
     user_id = Column(Integer, ForeignKey(User.id), nullable=False,
             index=True)
-    user = relationship(User)
+    user = relationship(
+        User,
+        backref=backref('oauth_tokens', cascade='all, delete-orphan'))
 
     client_id = Column(Integer, ForeignKey(OAuthClient.id), nullable=False)
-    client = relationship(OAuthClient)
+    client = relationship(
+        OAuthClient,
+        backref=backref('oauth_tokens', cascade='all, delete-orphan'))
 
     def __repr__(self):
         return '<{0} #{1} expires {2} [{3}, {4}]>'.format(
@@ -121,6 +126,34 @@ class OAuthToken(Base):
                 self.user,
                 self.client)
 
+class OAuthRefreshToken(Base):
+    __tablename__ = 'oauth__refresh_tokens'
+
+    id = Column(Integer, primary_key=True)
+    created = Column(DateTime, nullable=False,
+                     default=datetime.now)
+
+    token = Column(Unicode, index=True,
+                   default=generate_refresh_token)
+
+    user_id = Column(Integer, ForeignKey(User.id), nullable=False)
+
+    user = relationship(User, backref=backref('oauth_refresh_tokens',
+                                              cascade='all, delete-orphan'))
+
+    client_id = Column(Integer, ForeignKey(OAuthClient.id), nullable=False)
+    client = relationship(OAuthClient,
+                          backref=backref(
+                              'oauth_refresh_tokens',
+                              cascade='all, delete-orphan'))
+
+    def __repr__(self):
+        return '<{0} #{1} [{3}, {4}]>'.format(
+                self.__class__.__name__,
+                self.id,
+                self.user,
+                self.client)
+
 
 class OAuthCode(Base):
     __tablename__ = 'oauth__codes'
@@ -130,14 +163,17 @@ class OAuthCode(Base):
             default=datetime.now)
     expires = Column(DateTime, nullable=False,
             default=lambda: datetime.now() + timedelta(minutes=5))
-    code = Column(Unicode, index=True)
+    code = Column(Unicode, index=True, default=generate_code)
 
     user_id = Column(Integer, ForeignKey(User.id), nullable=False,
             index=True)
-    user = relationship(User)
+    user = relationship(User, backref=backref('oauth_codes',
+                                              cascade='all, delete-orphan'))
 
     client_id = Column(Integer, ForeignKey(OAuthClient.id), nullable=False)
-    client = relationship(OAuthClient)
+    client = relationship(OAuthClient, backref=backref(
+        'oauth_codes',
+        cascade='all, delete-orphan'))
 
     def __repr__(self):
         return '<{0} #{1} expires {2} [{3}, {4}]>'.format(
@@ -150,6 +186,7 @@ class OAuthCode(Base):
 
 MODELS = [
         OAuthToken,
+        OAuthRefreshToken,
         OAuthCode,
         OAuthClient,
         OAuthUserClient]
index d21c8a5bf9344c92698aea576f3a05c13e2e9245..27ff32b488a20411fc796381ac3072cdc4ea4a88 100644 (file)
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
 # GNU MediaGoblin -- federated, autonomous media hosting
 # Copyright (C) 2011, 2012 MediaGoblin contributors.  See AUTHORS.
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
+import uuid
+
+from random import getrandbits
+
+from datetime import datetime
+
 from functools import wraps
 
-from mediagoblin.plugins.oauth.models import OAuthClient
 from mediagoblin.plugins.api.tools import json_response
 
 
 def require_client_auth(controller):
+    '''
+    View decorator
+
+    - Requires the presence of ``?client_id``
+    '''
+    # Avoid circular import
+    from mediagoblin.plugins.oauth.models import OAuthClient
+
     @wraps(controller)
     def wrapper(request, *args, **kw):
         if not request.GET.get('client_id'):
@@ -41,3 +55,60 @@ def require_client_auth(controller):
         return controller(request, client)
 
     return wrapper
+
+
+def create_token(client, user):
+    '''
+    Create an OAuthToken and an OAuthRefreshToken entry in the database
+
+    Returns the data structure expected by the OAuth clients.
+    '''
+    from mediagoblin.plugins.oauth.models import OAuthToken, OAuthRefreshToken
+
+    token = OAuthToken()
+    token.user = user
+    token.client = client
+    token.save()
+
+    refresh_token = OAuthRefreshToken()
+    refresh_token.user = user
+    refresh_token.client = client
+    refresh_token.save()
+
+    # expire time of token in full seconds
+    # timedelta.total_seconds is python >= 2.7 or we would use that
+    td = token.expires - datetime.now()
+    exp_in = 86400*td.days + td.seconds # just ignore µsec
+
+    return {'access_token': token.token, 'token_type': 'bearer',
+            'refresh_token': refresh_token.token, 'expires_in': exp_in}
+
+
+def generate_identifier():
+    ''' Generates a ``uuid.uuid4()`` '''
+    return unicode(uuid.uuid4())
+
+
+def generate_token():
+    ''' Uses generate_identifier '''
+    return generate_identifier()
+
+
+def generate_refresh_token():
+    ''' Uses generate_identifier '''
+    return generate_identifier()
+
+
+def generate_code():
+    ''' Uses generate_identifier '''
+    return generate_identifier()
+
+
+def generate_secret():
+    '''
+    Generate a long string of pseudo-random characters
+    '''
+    # XXX: We might not want it to use bcrypt, since bcrypt takes its time to
+    # generate the result.
+    return unicode(getrandbits(192))
+
index ea45c209ef1d089585732ef5819d3e914dc0f5ee..d6fd314f78810ce6816cefb0fdc38a7c40b19e3c 100644 (file)
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
 import logging
-import json
 
 from urllib import urlencode
-from uuid import uuid4
-from datetime import datetime
+
+from werkzeug.exceptions import BadRequest
 
 from mediagoblin.tools.response import render_to_response, redirect
 from mediagoblin.decorators import require_active_login
-from mediagoblin.messages import add_message, SUCCESS, ERROR
+from mediagoblin.messages import add_message, SUCCESS
 from mediagoblin.tools.translate import pass_to_ugettext as _
-from mediagoblin.plugins.oauth.models import OAuthCode, OAuthToken, \
-        OAuthClient, OAuthUserClient
+from mediagoblin.plugins.oauth.models import OAuthCode, OAuthClient, \
+        OAuthUserClient, OAuthRefreshToken
 from mediagoblin.plugins.oauth.forms import ClientRegistrationForm, \
         AuthorizationForm
-from mediagoblin.plugins.oauth.tools import require_client_auth
+from mediagoblin.plugins.oauth.tools import require_client_auth, \
+        create_token
 from mediagoblin.plugins.api.tools import json_response
 
 _log = logging.getLogger(__name__)
@@ -51,9 +51,6 @@ def register_client(request):
         client.owner_id = request.user.id
         client.redirect_uri = unicode(form.redirect_uri.data)
 
-        client.generate_identifier()
-        client.generate_secret()
-
         client.save()
 
         add_message(request, SUCCESS, _('The client {0} has been registered!')\
@@ -92,9 +89,9 @@ def authorize_client(request):
         form.client_id.data).first()
 
     if not client:
-        _log.error('''No such client id as received from client authorization
-                form.''')
-        return BadRequest()
+        _log.error('No such client id as received from client authorization \
+form.')
+        raise BadRequest()
 
     if form.validate():
         relation = OAuthUserClient()
@@ -105,7 +102,7 @@ def authorize_client(request):
         elif form.deny.data:
             relation.state = u'rejected'
         else:
-            return BadRequest
+            raise BadRequest()
 
         relation.save()
 
@@ -136,7 +133,7 @@ def authorize(request, client):
                 return json_response({
                     'status': 400,
                     'errors':
-                        [u'Public clients MUST have a redirect_uri pre-set']},
+                        [u'Public clients should have a redirect_uri pre-set.']},
                         _disable_cors=True)
 
             redirect_uri = client.redirect_uri
@@ -146,11 +143,10 @@ def authorize(request, client):
             if not redirect_uri:
                 return json_response({
                     'status': 400,
-                    'errors': [u'Can not find a redirect_uri for client: {0}'\
-                            .format(client.name)]}, _disable_cors=True)
+                    'errors': [u'No redirect_uri supplied!']},
+                    _disable_cors=True)
 
         code = OAuthCode()
-        code.code = unicode(uuid4())
         code.user = request.user
         code.client = client
         code.save()
@@ -180,59 +176,79 @@ def authorize(request, client):
 
 
 def access_token(request):
+    '''
+    Access token endpoint provides access tokens to any clients that have the
+    right grants/credentials
+    '''
+
+    client = None
+    user = None
+
     if request.GET.get('code'):
+        # Validate the code arg, then get the client object from the db.
         code = OAuthCode.query.filter(OAuthCode.code ==
                 request.GET.get('code')).first()
 
-        if code:
-            if code.client.type == u'confidential':
-                client_identifier = request.GET.get('client_id')
-
-                if not client_identifier:
-                    return json_response({
-                        'error': 'invalid_request',
-                        'error_description':
-                            'Missing client_id in request'})
-
-                client_secret = request.GET.get('client_secret')
-
-                if not client_secret:
-                    return json_response({
-                        'error': 'invalid_request',
-                        'error_description':
-                            'Missing client_secret in request'})
-
-                if not client_secret == code.client.secret or \
-                        not client_identifier == code.client.identifier:
-                    return json_response({
-                        'error': 'invalid_client',
-                        'error_description':
-                            'The client_id or client_secret does not match the'
-                            ' code'})
-
-            token = OAuthToken()
-            token.token = unicode(uuid4())
-            token.user = code.user
-            token.client = code.client
-            token.save()
-
-            # expire time of token in full seconds
-            # timedelta.total_seconds is python >= 2.7 or we would use that
-            td = token.expires - datetime.now()
-            exp_in = 86400*td.days + td.seconds # just ignore µsec
-
-            access_token_data = {
-                'access_token': token.token,
-                'token_type': 'bearer',
-                'expires_in': exp_in}
-            return json_response(access_token_data, _disable_cors=True)
-        else:
+        if not code:
             return json_response({
                 'error': 'invalid_request',
                 'error_description':
-                    'Invalid code'})
-    else:
-        return json_response({
-            'error': 'invalid_request',
-            'error_descriptin':
-                'Missing `code` parameter in request'})
+                    'Invalid code.'})
+
+        client = code.client
+        user = code.user
+
+    elif request.args.get('refresh_token'):
+        # Validate a refresh token, then get the client object from the db.
+        refresh_token = OAuthRefreshToken.query.filter(
+            OAuthRefreshToken.token ==
+            request.args.get('refresh_token')).first()
+
+        if not refresh_token:
+            return json_response({
+                'error': 'invalid_request',
+                'error_description':
+                    'Invalid refresh token.'})
+
+        client = refresh_token.client
+        user = refresh_token.user
+
+    if client:
+        client_identifier = request.GET.get('client_id')
+
+        if not client_identifier:
+            return json_response({
+                'error': 'invalid_request',
+                'error_description':
+                    'Missing client_id in request.'})
+
+        if not client_identifier == client.identifier:
+            return json_response({
+                'error': 'invalid_client',
+                'error_description':
+                    'Mismatching client credentials.'})
+
+        if client.type == u'confidential':
+            client_secret = request.GET.get('client_secret')
+
+            if not client_secret:
+                return json_response({
+                    'error': 'invalid_request',
+                    'error_description':
+                        'Missing client_secret in request.'})
+
+            if not client_secret == client.secret:
+                return json_response({
+                    'error': 'invalid_client',
+                    'error_description':
+                        'Mismatching client credentials.'})
+
+
+        access_token_data = create_token(client, user)
+
+        return json_response(access_token_data, _disable_cors=True)
+
+    return json_response({
+        'error': 'invalid_request',
+        'error_description':
+            'Missing `code` or `refresh_token` parameter in request.'})
index 901556fe229b945dcf32108ae04876b437d6f4a3..7ad984598fe4a1170d46c97f0508467a8b0d6c0f 100644 (file)
@@ -71,7 +71,7 @@ class TestOAuth(object):
         assert response.status_int == 200
 
         # Should display an error
-        assert ctx['form'].redirect_uri.errors
+        assert len(ctx['form'].redirect_uri.errors)
 
         # Should not pass through
         assert not client
@@ -79,12 +79,16 @@ class TestOAuth(object):
     def test_2_successful_public_client_registration(self, test_app):
         ''' Successfully register a public client '''
         self._setup(test_app)
+        uri = 'http://foo.example'
         self.register_client(test_app, u'OMGOMG', 'public', 'OMG!',
-                'http://foo.example')
+                uri)
 
         client = self.db.OAuthClient.query.filter(
                 self.db.OAuthClient.name == u'OMGOMG').first()
 
+        # redirect_uri should be set
+        assert client.redirect_uri == uri
+
         # Client should have been registered
         assert client
 
@@ -116,7 +120,7 @@ class TestOAuth(object):
         redirect_uri = 'https://foo.example'
         response = test_app.get('/oauth/authorize', {
                 'client_id': client.identifier,
-                'scope': 'admin',
+                'scope': 'all',
                 'redirect_uri': redirect_uri})
 
         # User-agent should NOT be redirected
@@ -142,6 +146,7 @@ class TestOAuth(object):
         return authorization_response, client_identifier
 
     def get_code_from_redirect_uri(self, uri):
+        ''' Get the value of ?code= from an URI '''
         return parse_qs(urlparse(uri).query)['code'][0]
 
     def test_token_endpoint_successful_confidential_request(self, test_app):
@@ -170,6 +175,11 @@ code={1}&client_secret={2}'.format(client_id, code, client.secret))
         assert type(token_data['expires_in']) == int
         assert token_data['expires_in'] > 0
 
+        # There should be a refresh token provided in the token data
+        assert len(token_data['refresh_token'])
+
+        return client_id, token_data
+
     def test_token_endpont_missing_id_confidential_request(self, test_app):
         ''' Unsuccessful request against token endpoint, missing client_id '''
         self._setup(test_app)
@@ -192,4 +202,30 @@ code={0}&client_secret={1}'.format(code, client.secret))
         assert 'error' in token_data
         assert not 'access_token' in token_data
         assert token_data['error'] == 'invalid_request'
-        assert token_data['error_description'] == 'Missing client_id in request'
+        assert len(token_data['error_description'])
+
+    def test_refresh_token(self, test_app):
+        ''' Try to get a new access token using the refresh token '''
+        # Get an access token and a refresh token
+        client_id, token_data =\
+            self.test_token_endpoint_successful_confidential_request(test_app)
+
+        client = self.db.OAuthClient.query.filter(
+            self.db.OAuthClient.identifier == client_id).first()
+
+        token_res = test_app.get('/oauth/access_token',
+                     {'refresh_token': token_data['refresh_token'],
+                      'client_id': client_id,
+                      'client_secret': client.secret
+                      })
+
+        assert token_res.status_int == 200
+
+        new_token_data = json.loads(token_res.body)
+
+        assert not 'error' in new_token_data
+        assert 'access_token' in new_token_data
+        assert 'token_type' in new_token_data
+        assert 'expires_in' in new_token_data
+        assert type(new_token_data['expires_in']) == int
+        assert new_token_data['expires_in'] > 0
index 2e47cb5c234d69db109519a183cc27ce8223d29a..b68d55e81a44af2e8f491da1f737fe415ebf75c2 100644 (file)
@@ -15,6 +15,7 @@
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
 
+import sys
 import os
 import pkg_resources
 import shutil
@@ -28,7 +29,6 @@ from mediagoblin import mg_globals
 from mediagoblin.db.models import User, MediaEntry, Collection
 from mediagoblin.tools import testing
 from mediagoblin.init.config import read_mediagoblin_config
-from mediagoblin.db.open import setup_connection_and_db_from_config
 from mediagoblin.db.base import Session
 from mediagoblin.meddleware import BaseMeddleware
 from mediagoblin.auth.lib import bcrypt_gen_password_hash
@@ -50,9 +50,10 @@ USER_DEV_DIRECTORIES_TO_SETUP = [
     'beaker/sessions/data', 'beaker/sessions/lock']
 
 BAD_CELERY_MESSAGE = """\
-Sorry, you *absolutely* must run nosetests with the
+Sorry, you *absolutely* must run tests with the
 mediagoblin.init.celery.from_tests module.  Like so:
-$ CELERY_CONFIG_MODULE=mediagoblin.init.celery.from_tests ./bin/nosetests"""
+$ CELERY_CONFIG_MODULE=mediagoblin.init.celery.from_tests {0}\
+""".format(sys.argv[0])
 
 
 class BadCeleryEnviron(Exception): pass
@@ -232,7 +233,7 @@ def fixture_media_entry(title=u"Some title", slug=None,
     entry.slug = slug
     entry.uploader = uploader or fixture_add_user().id
     entry.media_type = u'image'
-    
+
     if gen_slug:
         entry.generate_slug()
     if save: