Small changes in connection
[diaspy.git] / diaspy / connection.py
index af978d674123027feeef9fb4e488fd7575699d64..8654b068ede0c4c2acc4e7884fbe8ec3c404592e 100644 (file)
@@ -3,17 +3,28 @@
 import re
 import requests
 import json
+import warnings
+
+
+"""This module abstracts connection to pod.
+"""
 
 
 class LoginError(Exception):
     pass
 
 
+class TokenError(Exception):
+    pass
+
+
 class Connection():
-    """Object representing connection with the server.
-    It is pushed around internally and is considered private.
+    """Object representing connection with the pod.
     """
-    def __init__(self, pod, username='', password=''):
+    _token_regex = re.compile(r'content="(.*?)"\s+name="csrf-token')
+    _userinfo_regex = re.compile(r'window.current_user_attributes = ({.*})')
+
+    def __init__(self, pod, username='', password='', schema='https'):
         """
         :param pod: The complete url of the diaspora pod to use.
         :type pod: str
@@ -24,12 +35,25 @@ class Connection():
         """
         self.pod = pod
         self.session = requests.Session()
-        self._token_regex = re.compile(r'content="(.*?)"\s+name="csrf-token')
-        self._userinfo_regex = re.compile(r'window.current_user_attributes = ({.*})')
         self.login_data = {}
-        self._setlogin(username, password)
+        self.token = ''
+        try: self._setlogin(username, password)
+        except requests.exceptions.MissingSchema:
+            self.pod = '{0}://{1}'.format(schema, self.pod)
+            warnings.warn('schema was missing')
+        finally: pass
+        try: self._setlogin(username, password)
+        except Exception as e: raise LoginError('cannot create login data (caused by: {0})'.format(e))
+
+    def __repr__(self):
+        """Returns token string.
+        It will be easier to change backend if programs will just use:
+            repr(connection)
+        instead of calling a specified method.
+        """
+        return self.get_token()
 
-    def get(self, string):
+    def get(self, string, headers={}, params={}):
         """This method gets data from session.
         Performs additional checks if needed.
 
@@ -39,7 +63,7 @@ class Connection():
         :param string: URL to get without the pod's URL and slash eg. 'stream'.
         :type string: str
         """
-        return self.session.get('{0}/{1}'.format(self.pod, string))
+        return self.session.get('{0}/{1}'.format(self.pod, string), params=params, headers=headers)
 
     def post(self, string, data, headers={}, params={}):
         """This method posts data to session.
@@ -57,14 +81,15 @@ class Connection():
         :type params: dict
         """
         string = '{0}/{1}'.format(self.pod, string)
-        if headers and params:
-            request = self.session.post(string, data=data, headers=headers, params=params)
-        elif headers and not params:
-            request = self.session.post(string, data=data, headers=headers)
-        elif not headers and params:
-            request = self.session.post(string, data=data, params=params)
-        else:
-            request = self.session.post(string, data=data)
+        request = self.session.post(string, data, headers=headers, params=params)
+        return request
+
+    def put(self, string, data=None, headers={}, params={}):
+        """This method PUTs to session.
+        """
+        string = '{0}/{1}'.format(self.pod, string)
+        if data is not None: request = self.session.put(string, data, headers=headers, params=params)
+        else: request = self.session.put(string, headers=headers, params=params)
         return request
 
     def delete(self, string, data, headers={}):
@@ -78,21 +103,19 @@ class Connection():
         :type headers: dict
         """
         string = '{0}/{1}'.format(self.pod, string)
-        if headers:
-            request = self.session.delete(string, data=data, headers=headers)
-        else:
-            request = self.session.delete(string, data=data)
+        request = self.session.delete(string, data=data, headers=headers)
         return request
 
     def _setlogin(self, username, password):
         """This function is used to set data for login.
+
         .. note::
             It should be called before _login() function.
         """
         self.username, self.password = username, password
         self.login_data = {'user[username]': self.username,
                            'user[password]': self.password,
-                           'authenticity_token': self.getToken()}
+                           'authenticity_token': self._fetchtoken()}
 
     def _login(self):
         """Handles actual login request.
@@ -102,7 +125,7 @@ class Connection():
                             data=self.login_data,
                             headers={'accept': 'application/json'})
         if request.status_code != 201:
-            raise LoginError('{0}: Login failed.'.format(request.status_code))
+            raise LoginError('{0}: login failed'.format(request.status_code))
 
     def login(self, username='', password=''):
         """This function is used to log in to a pod.
@@ -112,6 +135,15 @@ class Connection():
         if not self.username or not self.password: raise LoginError('password or username not specified')
         self._login()
 
+    def logout(self):
+        """Logs out from a pod.
+        When logged out you can't do anything.
+        """
+        self.get('users/sign_out')
+        self.username = ''
+        self.token = ''
+        self.password = ''
+
     def podswitch(self, pod):
         """Switches pod from current to another one.
         """
@@ -124,14 +156,37 @@ class Connection():
         :returns: dict -- json formatted user info.
         """
         request = self.get('bookmarklet')
-        userdata = json.loads(self._userinfo_regex.search(request.text).group(1))
+        try:
+            userdata = json.loads(self._userinfo_regex.search(request.text).group(1))
+        except AttributeError:
+            raise errors.DiaspyError('cannot find user data')
         return userdata
 
-    def getToken(self):
+    def _fetchtoken(self):
+        """This method tries to get token string needed for authentication on D*.
+
+        :returns: token string
+        """
+        request = self.get('stream')
+        token = self._token_regex.search(request.text).group(1)
+        self.token = token
+        return token
+
+    def get_token(self, fetch=False):
         """This function returns a token needed for authentication in most cases.
+        Each time it is run a _fetchtoken() is called and refreshed token is stored.
+
+        It is more safe to use than _fetchtoken().
+        By setting new you can request new token or decide to get stored one.
+        If no token is stored new one will be fatched anyway.
 
         :returns: string -- token used to authenticate
         """
-        r = self.get('stream')
-        token = self._token_regex.search(r.text).group(1)
-        return token
+        try:
+            if fetch: self._fetchtoken()
+            if not self.token: self._fetchtoken()
+        except requests.exceptions.ConnectionError as e:
+            warnings.warn('{0} was cought: reusing old token'.format(e))
+        finally:
+            if not self.token: raise TokenError('cannot obtain token and no previous token found for reuse')
+        return self.token