Use Requests instead of httplib
author Aaron Hill <aa1ronham@gmail.com>
Thu, 11 Jul 2013 13:49:49 +0000 (09:49 -0400)
committer Aaron Hill <aa1ronham@gmail.com>
Sun, 27 Apr 2014 01:22:38 +0000 (21:22 -0400)
tweepy/binder.py
tweepy/parsers.py
tweepy/streaming.py

index 37c2effd162d4ecd47d8f8b249dac37a9014e770..13ba0d9f66deb5f10b98966b2c8108a1dfc1b18f 100644 (file)
@@ -2,7 +2,7 @@
 # Copyright 2009-2010 Joshua Roesslein
 # See LICENSE for details.
 
-import httplib
+import requests
 import urllib
 import time
 import re
@@ -28,6 +28,7 @@ def bind_api(**config):
         require_auth = config.get('require_auth', False)
         search_api = config.get('search_api', False)
         use_cache = config.get('use_cache', True)
+        session = requests.Session()
 
         def __init__(self, api, args, kargs):
             # If authentication is required and no credentials
@@ -43,7 +44,7 @@ def bind_api(**config):
             self.wait_on_rate_limit = kargs.pop('wait_on_rate_limit', api.wait_on_rate_limit)
             self.wait_on_rate_limit_notify = kargs.pop('wait_on_rate_limit_notify', api.wait_on_rate_limit_notify)
             self.parser = kargs.pop('parser', api.parser)
-            self.headers = kargs.pop('headers', {})
+            self.session.headers = kargs.pop('headers', {})
             self.build_parameters(args, kargs)
 
             # Pick correct URL root to use
@@ -69,42 +70,43 @@ def bind_api(**config):
             # or older where Host is set including the 443 port.
             # This causes Twitter to issue 301 redirect.
             # See Issue https://github.com/tweepy/tweepy/issues/12
-            self.headers['Host'] = self.host
+
+            self.session.headers['Host'] = self.host
             # Monitoring rate limits
             self._remaining_calls = None
             self._reset_time = None
 
         def build_parameters(self, args, kargs):
-            self.parameters = {}
+            self.session.params = {}
             for idx, arg in enumerate(args):
                 if arg is None:
                     continue
                 try:
-                    self.parameters[self.allowed_param[idx]] = convert_to_utf8_str(arg)
+                    self.session.params[self.allowed_param[idx]] = convert_to_utf8_str(arg)
                 except IndexError:
                     raise TweepError('Too many parameters supplied!')
 
             for k, arg in kargs.items():
                 if arg is None:
                     continue
-                if k in self.parameters:
+                if k in self.session.params:
                     raise TweepError('Multiple values for parameter %s supplied!' % k)
 
-                self.parameters[k] = convert_to_utf8_str(arg)
+                self.session.params[k] = convert_to_utf8_str(arg)
 
         def build_path(self):
             for variable in re_path_template.findall(self.path):
                 name = variable.strip('{}')
 
-                if name == 'user' and 'user' not in self.parameters and self.api.auth:
+                if name == 'user' and 'user' not in self.session.params and self.api.auth:
                     # No 'user' parameter provided, fetch it from Auth instead.
                     value = self.api.auth.get_username()
                 else:
                     try:
-                        value = urllib.quote(self.parameters[name])
+                        value = urllib.quote(self.session.params[name])
                     except KeyError:
                         raise TweepError('No parameter value found for path variable: %s' % name)
-                    del self.parameters[name]
+                    del self.session.params[name]
 
                 self.path = self.path.replace(variable, value)
 
@@ -113,8 +115,7 @@ def bind_api(**config):
 
             # Build the request URL
             url = self.api_root + self.path
-            if len(self.parameters):
-                url = '%s?%s' % (url, urllib.urlencode(self.parameters))
+            full_url = self.scheme + self.host + url
 
             # Query the cache if one is available
             # and this request uses a GET method.
@@ -146,28 +147,21 @@ def bind_api(**config):
                             print "Max retries reached. Sleeping for: " + str(sleep_time)
                         time.sleep(sleep_time + 5) # sleep for few extra sec
 
-                # Open connection
-                if self.api.secure:
-                    conn = httplib.HTTPSConnection(self.host, timeout=self.api.timeout)
-                else:
-                    conn = httplib.HTTPConnection(self.host, timeout=self.api.timeout)
-
                 # Apply authentication
                 if self.api.auth:
                     self.api.auth.apply_auth(
-                            self.scheme + self.host + url,
-                            self.method, self.headers, self.parameters
+                            full_url,
+                            self.method, self.session.headers, self.session.params
                     )
 
                 # Request compression if configured
                 if self.api.compression:
-                    self.headers['Accept-encoding'] = 'gzip'
+                    self.session.headers['Accept-encoding'] = 'gzip'
 
                 # Execute request
                 try:
-                    conn.request(self.method, url, headers=self.headers, body=self.post_data)
-                    resp = conn.getresponse()
-                except Exception as e:
+                    resp = self.session.request(self.method, full_url, data=self.post_data, timeout=self.api.timeout)
+                except Exception, e:
                     raise TweepError('Failed to send request: %s' % e)
                 rem_calls = resp.getheader('x-rate-limit-remaining')
                 if rem_calls is not None:
@@ -181,12 +175,12 @@ def bind_api(**config):
                     continue
                 retry_delay = self.retry_delay
                 # Exit request loop if non-retry error code
-                if resp.status == 200:
+                if resp.status_code == 200:
                     break
-                elif (resp.status == 429 or resp.status == 420) and self.wait_on_rate_limit:
-                    if 'retry-after' in resp.msg:
-                        retry_delay = float(resp.msg['retry-after'])
-                elif self.retry_errors and resp.status not in self.retry_errors:
+                elif (resp.status_code == 429 or resp.status_code == 420) and self.wait_on_rate_limit:
+                    if 'retry-after' in resp.headers:
+                        retry_delay = float(resp.headers['retry-after'])
+                elif self.retry_errors and resp.status_code not in self.retry_errors:
                     break
 
                 # Sleep before retrying request again
@@ -197,14 +191,14 @@ def bind_api(**config):
             self.api.last_response = resp
             if resp.status and not 200 <= resp.status < 300:
                 try:
-                    error_msg = self.parser.parse_error(resp.read())
+                    error_msg = self.parser.parse_error(resp.text)
                 except Exception:
-                    error_msg = "Twitter error response: status code = %s" % resp.status
+                    error_msg = "Twitter error response: status code = %s" % resp.status_code
                 raise TweepError(error_msg, resp)
 
             # Parse the response payload
-            body = resp.read()
-            if resp.getheader('Content-Encoding', '') == 'gzip':
+            body = resp.text
+            if resp.headers.get('Content-Encoding', '') == 'gzip':
                 try:
                     zipper = gzip.GzipFile(fileobj=StringIO(body))
                     body = zipper.read()
@@ -213,8 +207,6 @@ def bind_api(**config):
             
             result = self.parser.parse(self, body)
 
-            conn.close()
-
             # Store result into cache if one is available.
             if self.use_cache and self.api.cache and self.method == 'GET' and result:
                 self.api.cache.store(url, result)
index 31e002204565e3484e3c959164ff5b16804b0a48..c8b3a771bb11b95f8697f678a2b93814d0757c13 100644 (file)
@@ -51,7 +51,7 @@ class JSONParser(Parser):
         except Exception as e:
             raise TweepError('Failed to parse JSON payload: %s' % e)
 
-        needsCursors = method.parameters.has_key('cursor')
+        needsCursors = method.session.params.has_key('cursor')
         if needsCursors and isinstance(json, dict) and 'previous_cursor' in json and 'next_cursor' in json:
             cursors = json['previous_cursor'], json['next_cursor']
             return json, cursors
index 694f96671705a63256e51aa4daeb6eb64ba968c9..8b89d8f907745d03107427808250a1b857c87c5d 100644 (file)
@@ -3,8 +3,8 @@
 # See LICENSE for details.
 
 import logging
-import httplib
-from socket import timeout
+import requests
+from requests.exceptions import Timeout
 from threading import Thread
 from time import sleep
 import ssl
@@ -130,8 +130,9 @@ class Stream(object):
             self.scheme = "http"
 
         self.api = API()
-        self.headers = options.get("headers") or {}
-        self.parameters = None
+        self.session = requests.Session()
+        self.session.headers = options.get("headers") or {}
+        self.session.params = None
         self.body = None
         self.retry_time = self.retry_time_start
         self.snooze_time = self.snooze_time_step
@@ -142,23 +143,17 @@ class Stream(object):
 
         # Connect and process the stream
         error_counter = 0
-        conn = None
+        resp = None
         exception = None
         while self.running:
             if self.retry_count is not None and error_counter > self.retry_count:
                 # quit if error count greater than retry count
                 break
             try:
-                if self.scheme == "http":
-                    conn = httplib.HTTPConnection(self.host, timeout=self.timeout)
-                else:
-                    conn = httplib.HTTPSConnection(self.host, timeout=self.timeout)
-                self.auth.apply_auth(url, 'POST', self.headers, self.parameters)
-                conn.connect()
-                conn.request('POST', self.url, self.body, headers=self.headers)
-                resp = conn.getresponse()
-                if resp.status != 200:
-                    if self.listener.on_error(resp.status) is False:
+                self.auth.apply_auth(url, 'POST', self.session.headers, self.session.params)
+                resp = self.session.request('POST', url, data=self.body, timeout=self.timeout, stream=True)
+                if resp.status_code != 200:
+                    if self.listener.on_error(resp.status_code) is False:
                         break
                     error_counter += 1
                     if resp.status == 420:
@@ -181,7 +176,6 @@ class Stream(object):
                     break
                 if self.running is False:
                     break
-                conn.close()
                 sleep(self.snooze_time)
                 self.snooze_time = min(self.snooze_time + self.snooze_time_step,
                                        self.snooze_time_cap)
@@ -191,8 +185,8 @@ class Stream(object):
 
         # cleanup
         self.running = False
-        if conn:
-            conn.close()
+        if resp:
+            resp.close()
 
         if exception:
             # call a handler first so that the exception can be logged.
@@ -205,28 +199,33 @@ class Stream(object):
 
     def _read_loop(self, resp):
 
-        while self.running and not resp.isclosed():
+        while self.running:
 
             # Note: keep-alive newlines might be inserted before each length value.
             # read until we get a digit...
             c = '\n'
-            while c == '\n' and self.running and not resp.isclosed():
-                c = resp.read(1)
+            for c in resp.iter_content():
+                if c == '\n':
+                    continue
+                break
+
             delimited_string = c
 
             # read rest of delimiter length..
             d = ''
-            while d != '\n' and self.running and not resp.isclosed():
-                d = resp.read(1)
-                delimited_string += d
+            for d in resp.iter_content():
+                if d != '\n':
+                    delimited_string += d
+                    continue
+                break
 
             # read the next twitter status object
             if delimited_string.strip().isdigit():
-                next_status_obj = resp.read( int(delimited_string) )
+                next_status_obj = resp.raw.read( int(delimited_string) )
                 if self.running:
                     self._data(next_status_obj)
 
-        if resp.isclosed():
+        if resp.raw._fp.isclosed():
             self.on_closed(resp)
 
     def _start(self, async):
@@ -264,26 +263,27 @@ class Stream(object):
 
         self.body = urlencode_noplus(self.parameters)
         self.url = self.url + '?' + self.body
+
         self._start(async)
 
     def firehose(self, count=None, async=False):
-        self.parameters = {'delimited': 'length'}
+        self.session.params = {'delimited': 'length'}
         if self.running:
             raise TweepError('Stream object already connected!')
-        self.url = '/%s/statuses/firehose.json?delimited=length' % STREAM_VERSION
+        self.url = '/%s/statuses/firehose.json' % STREAM_VERSION
         if count:
             self.url += '&count=%s' % count
         self._start(async)
 
     def retweet(self, async=False):
-        self.parameters = {'delimited': 'length'}
+        self.session.params = {'delimited': 'length'}
         if self.running:
             raise TweepError('Stream object already connected!')
-        self.url = '/%s/statuses/retweet.json?delimited=length' % STREAM_VERSION
+        self.url = '/%s/statuses/retweet.json' % STREAM_VERSION
         self._start(async)
 
     def sample(self, async=False):
-        self.parameters = {'delimited': 'length'}
+        self.session.params = {'delimited': 'length'}
         if self.running:
             raise TweepError('Stream object already connected!')
         self.url = '/%s/statuses/sample.json?delimited=length' % STREAM_VERSION
@@ -291,28 +291,28 @@ class Stream(object):
 
     def filter(self, follow=None, track=None, async=False, locations=None,
                stall_warnings=False, languages=None, encoding='utf8'):
-        self.parameters = {}
+        self.session.params = {}
         self.headers['Content-type'] = "application/x-www-form-urlencoded"
         if self.running:
             raise TweepError('Stream object already connected!')
-        self.url = '/%s/statuses/filter.json?delimited=length' % STREAM_VERSION
+        self.url = '/%s/statuses/filter.json' % STREAM_VERSION
         if follow:
             encoded_follow = [s.encode(encoding) for s in follow]
-            self.parameters['follow'] = ','.join(encoded_follow)
+            self.session.params['follow'] = ','.join(map(str, follow))
         if track:
-            encoded_track = [s.encode(encoding) for s in track]
-            self.parameters['track'] = ','.join(encoded_track)
+            self.session.params['track'] = ','.join(map(str, track))
         if locations and len(locations) > 0:
             if len(locations) % 4 != 0:
                 raise TweepError("Wrong number of locations points, "
                                  "it has to be a multiple of 4")
-            self.parameters['locations'] = ','.join(['%.4f' % l for l in locations])
+            self.session.params['locations'] = ','.join(['%.4f' % l for l in locations])
         if stall_warnings:
-            self.parameters['stall_warnings'] = stall_warnings
+            self.session.params['stall_warnings'] = stall_warnings
         if languages:
-            self.parameters['language'] = ','.join(map(str, languages))
-        self.body = urlencode_noplus(self.parameters)
-        self.parameters['delimited'] = 'length'
+            self.session.params['language'] = ','.join(map(str, languages))
+        self.body = urlencode_noplus(self.session.params)
+        self.session.params['delimited'] = 'length'
+        self.host = 'stream.twitter.com'
         self._start(async)
 
     def disconnect(self):