Merge branch 'master' into follow_int
author Harmon <Harmon758@gmail.com>
Sun, 21 Feb 2021 21:45:42 +0000 (15:45 -0600)
committer GitHub <noreply@github.com>
Sun, 21 Feb 2021 21:45:42 +0000 (15:45 -0600)
1  2 
tweepy/streaming.py

index eed633987a019c3520079440f9141d23bacec881,00cb745085bc52052654add6569168ddfa26bbeb..0b20fd9d727943445f7876c40d7b9bef4f69030b
@@@ -14,461 -12,286 +12,286 @@@ import ss
  from threading import Thread
  from time import sleep
  
- import six
- import ssl
+ import requests
+ from requests_oauthlib import OAuth1
+ import urllib3
  
- from tweepy.models import Status
- from tweepy.api import API
+ import tweepy
  from tweepy.error import TweepError
+ from tweepy.models import Status
  
- from tweepy.utils import import_simplejson
- json = import_simplejson()
- STREAM_VERSION = '1.1'
- class StreamListener(object):
-     def __init__(self, api=None):
-         self.api = api or API()
+ log = logging.getLogger(__name__)
  
-     def on_connect(self):
-         """Called once connected to streaming server.
  
-         This will be invoked once a successful response
-         is received from the server. Allows the listener
-         to perform some work prior to entering the read loop.
-         """
-         pass
+ class Stream:
  
-     def on_data(self, raw_data):
-         """Called when raw data is received from connection.
+     def __init__(self, consumer_key, consumer_secret, access_token,
+                  access_token_secret, *, chunk_size=512, daemon=False,
+                  max_retries=inf, proxy=None, verify=True):
+         self.consumer_key = consumer_key
+         self.consumer_secret = consumer_secret
+         self.access_token = access_token
+         self.access_token_secret = access_token_secret
+         # The default socket.read size. Default to less than half the size of
+         # a tweet so that it reads tweets with the minimal latency of 2 reads
+         # per tweet. Values higher than ~1kb will increase latency by waiting
+         # for more data to arrive but may also increase throughput by doing
+         # fewer socket read calls.
+         self.chunk_size = chunk_size
+         self.daemon = daemon
+         self.max_retries = max_retries
+         self.proxies = {"https": proxy} if proxy else {}
+         self.verify = verify
  
-         Override this method if you wish to manually handle
-         the stream data. Return False to stop stream and close connection.
-         """
-         data = json.loads(raw_data)
+         self.running = False
+         self.session = None
+         self.thread = None
+         self.user_agent = (
+             f"Python/{python_version()} "
+             f"Requests/{requests.__version__} "
+             f"Tweepy/{tweepy.__version__}"
+         )
+     def _connect(self, method, endpoint, params=None, headers=None, body=None):
+         self.running = True
  
-         if 'in_reply_to_status_id' in data:
-             status = Status.parse(self.api, data)
-             if self.on_status(status) is False:
-                 return False
-         elif 'delete' in data:
-             delete = data['delete']['status']
-             if self.on_delete(delete['id'], delete['user_id']) is False:
-                 return False
-         elif 'event' in data:
-             status = Status.parse(self.api, data)
-             if self.on_event(status) is False:
-                 return False
-         elif 'direct_message' in data:
-             status = Status.parse(self.api, data)
-             if self.on_direct_message(status) is False:
-                 return False
-         elif 'friends' in data:
-             if self.on_friends(data['friends']) is False:
-                 return False
-         elif 'limit' in data:
-             if self.on_limit(data['limit']['track']) is False:
-                 return False
-         elif 'disconnect' in data:
-             if self.on_disconnect(data['disconnect']) is False:
-                 return False
-         elif 'warning' in data:
-             if self.on_warning(data['warning']) is False:
-                 return False
-         else:
-             logging.error("Unknown message type: " + str(raw_data))
+         error_count = 0
+         # https://developer.twitter.com/en/docs/twitter-api/v1/tweets/filter-realtime/guides/connecting
+         stall_timeout = 90
+         network_error_wait = network_error_wait_step = 0.25
+         network_error_wait_max = 16
+         http_error_wait = http_error_wait_start = 5
+         http_error_wait_max = 320
+         http_420_error_wait_start = 60
+         auth = OAuth1(self.consumer_key, self.consumer_secret,
+                       self.access_token, self.access_token_secret)
+         if self.session is None:
+             self.session = requests.Session()
+             self.session.headers["User-Agent"] = self.user_agent
+         url = f"https://stream.twitter.com/1.1/{endpoint}.json"
+         try:
+             while self.running and error_count <= self.max_retries:
+                 try:
+                     with self.session.request(
+                         method, url, params=params, headers=headers, data=body,
+                         timeout=stall_timeout, stream=True, auth=auth,
+                         verify=self.verify, proxies=self.proxies
+                     ) as resp:
+                         if resp.status_code == 200:
+                             error_count = 0
+                             http_error_wait = http_error_wait_start
+                             network_error_wait = network_error_wait_step
+                             self.on_connect()
+                             if not self.running:
+                                 break
+                             for line in resp.iter_lines(
+                                 chunk_size=self.chunk_size
+                             ):
+                                 if line:
+                                     self.on_data(line)
+                                 else:
+                                     self.on_keep_alive()
+                                 if not self.running:
+                                     break
+                             if resp.raw.closed:
+                                 self.on_closed(resp)
+                         else:
+                             self.on_request_error(resp.status_code)
+                             if not self.running:
+                                 break
+                             error_count += 1
+                             if resp.status_code == 420:
+                                 if http_error_wait < http_420_error_wait_start:
+                                     http_error_wait = http_420_error_wait_start
+                             sleep(http_error_wait)
+                             http_error_wait *= 2
+                             if http_error_wait > http_error_wait_max:
+                                 http_error_wait = http_error_wait_max
+                 except (requests.ConnectionError, requests.Timeout,
+                         ssl.SSLError, urllib3.exceptions.ReadTimeoutError,
+                         urllib3.exceptions.ProtocolError) as exc:
+                     # This is still necessary, as a SSLError can actually be
+                     # thrown when using Requests
+                     # If it's not time out treat it like any other exception
+                     if isinstance(exc, ssl.SSLError):
+                         if not (exc.args and "timed out" in str(exc.args[0])):
+                             raise
+                     self.on_connection_error()
+                     if not self.running:
+                         break
  
-     def keep_alive(self):
-         """Called when a keep-alive arrived"""
-         return
+                     sleep(network_error_wait)
  
-     def on_status(self, status):
-         """Called when a new status arrives"""
-         return
+                     network_error_wait += network_error_wait_step
+                     if network_error_wait > network_error_wait_max:
+                         network_error_wait = network_error_wait_max
+         except Exception as exc:
+             self.on_exception(exc)
+         finally:
+             self.session.close()
+             self.running = False
+             self.on_disconnect()
  
-     def on_exception(self, exception):
-         """Called when an unhandled exception occurs."""
-         return
+     def _threaded_connect(self, *args, **kwargs):
+         self.thread = Thread(target=self._connect, name="Tweepy Stream",
+                              args=args, kwargs=kwargs, daemon=self.daemon)
+         self.thread.start()
+         return self.thread
  
-     def on_delete(self, status_id, user_id):
-         """Called when a delete notice arrives for a status"""
-         return
+     def filter(self, *, follow=None, track=None, locations=None,
+                filter_level=None, languages=None, stall_warnings=False,
+                threaded=False):
+         if self.running:
+             raise TweepError("Stream is already connected")
  
-     def on_event(self, status):
-         """Called when a new event arrives"""
-         return
+         method = "POST"
+         endpoint = "statuses/filter"
+         headers = {"Content-Type": "application/x-www-form-urlencoded"}
  
-     def on_direct_message(self, status):
-         """Called when a new direct message arrives"""
-         return
+         body = {}
+         if follow:
 -            body["follow"] = ','.join(follow)
++            body["follow"] = ','.join(map(str, follow))
+         if track:
 -            body["track"] = ','.join(track)
++            body["track"] = ','.join(map(str, track))
+         if locations and len(locations) > 0:
+             if len(locations) % 4:
+                 raise TweepError(
+                     "Number of location coordinates should be a multiple of 4"
+                 )
+             body["locations"] = ','.join(f"{l:.4f}" for l in locations)
+         if filter_level:
+             body["filter_level"] = filter_level
+         if languages:
+             body["language"] = ','.join(map(str, languages))
+         if stall_warnings:
+             body["stall_warnings"] = stall_warnings
  
-     def on_friends(self, friends):
-         """Called when a friends list arrives.
+         if threaded:
+             return self._threaded_connect(method, endpoint, headers=headers,
+                                           body=body)
+         else:
+             self._connect(method, endpoint, headers=headers, body=body)
  
-         friends is a list that contains user_id
-         """
-         return
+     def sample(self, *, languages=None, stall_warnings=False, threaded=False):
+         if self.running:
+             raise TweepError("Stream is already connected")
  
-     def on_limit(self, track):
-         """Called when a limitation notice arrives"""
-         return
+         method = "GET"
+         endpoint = "statuses/sample"
  
-     def on_error(self, status_code):
-         """Called when a non-200 status code is returned"""
-         return False
+         params = {}
+         if languages:
+             params["language"] = ','.join(map(str, languages))
+         if stall_warnings:
+             params["stall_warnings"] = "true"
  
-     def on_timeout(self):
-         """Called when stream connection times out"""
-         return
+         if threaded:
+             return self._threaded_connect(method, endpoint, params=params)
+         else:
+             self._connect(method, endpoint, params=params)
  
-     def on_disconnect(self, notice):
-         """Called when twitter sends a disconnect notice
+     def disconnect(self):
+         self.running = False
  
-         Disconnect codes are listed here:
-         https://dev.twitter.com/docs/streaming-apis/messages#Disconnect_messages_disconnect
-         """
-         return
+     def on_closed(self, resp):
+         """This is called when the stream has been closed by Twitter."""
+         log.error("Stream connection closed by Twitter")
  
-     def on_warning(self, notice):
-         """Called when a disconnection warning message arrives"""
-         return
- class ReadBuffer(object):
-     """Buffer data from the response in a smarter way than httplib/requests can.
-     Tweets are roughly in the 2-12kb range, averaging around 3kb.
-     Requests/urllib3/httplib/socket all use socket.read, which blocks
-     until enough data is returned. On some systems (eg google appengine), socket
-     reads are quite slow. To combat this latency we can read big chunks,
-     but the blocking part means we won't get results until enough tweets
-     have arrived. That may not be a big deal for high throughput systems.
-     For low throughput systems we don't want to sacrafice latency, so we
-     use small chunks so it can read the length and the tweet in 2 read calls.
-     """
-     def __init__(self, stream, chunk_size, encoding='utf-8'):
-         self._stream = stream
-         self._buffer = six.b('')
-         self._chunk_size = chunk_size
-         self._encoding = encoding
-     def read_len(self, length):
-         while not self._stream.closed:
-             if len(self._buffer) >= length:
-                 return self._pop(length)
-             read_len = max(self._chunk_size, length - len(self._buffer))
-             self._buffer += self._stream.read(read_len)
-         return six.b('')
-     def read_line(self, sep=six.b('\n')):
-         """Read the data stream until a given separator is found (default \n)
-         :param sep: Separator to read until. Must by of the bytes type (str in python 2,
-             bytes in python 3)
-         :return: The str of the data read until sep
+     def on_connect(self):
+         """This is called after successfully connecting to the streaming API.
          """
-         start = 0
-         while not self._stream.closed:
-             loc = self._buffer.find(sep, start)
-             if loc >= 0:
-                 return self._pop(loc + len(sep))
-             else:
-                 start = len(self._buffer)
-             self._buffer += self._stream.read(self._chunk_size)
-         return six.b('')
-     def _pop(self, length):
-         r = self._buffer[:length]
-         self._buffer = self._buffer[length:]
-         return r.decode(self._encoding)
+         log.info("Stream connected")
  
+     def on_connection_error(self):
+         """This is called when the stream connection errors or times out."""
+         log.error("Stream connection has errored or timed out")
  
- class Stream(object):
+     def on_disconnect(self):
+         """This is called when the stream has disconnected."""
+         log.info("Stream disconnected")
  
-     host = 'stream.twitter.com'
-     def __init__(self, auth, listener, **options):
-         self.auth = auth
-         self.listener = listener
-         self.running = False
-         self.timeout = options.get("timeout", 300.0)
-         self.retry_count = options.get("retry_count")
-         # values according to
-         # https://dev.twitter.com/docs/streaming-apis/connecting#Reconnecting
-         self.retry_time_start = options.get("retry_time", 5.0)
-         self.retry_420_start = options.get("retry_420", 60.0)
-         self.retry_time_cap = options.get("retry_time_cap", 320.0)
-         self.snooze_time_step = options.get("snooze_time", 0.25)
-         self.snooze_time_cap = options.get("snooze_time_cap", 16)
+     def on_exception(self, exception):
+         """This is called when an unhandled exception occurs."""
+         log.exception("Stream encountered an exception")
  
-         # The default socket.read size. Default to less than half the size of
-         # a tweet so that it reads tweets with the minimal latency of 2 reads
-         # per tweet. Values higher than ~1kb will increase latency by waiting
-         # for more data to arrive but may also increase throughput by doing
-         # fewer socket read calls.
-         self.chunk_size = options.get("chunk_size",  512)
-         self.verify = options.get("verify", True)
-         self.api = API()
-         self.headers = options.get("headers") or {}
-         self.new_session()
-         self.body = None
-         self.retry_time = self.retry_time_start
-         self.snooze_time = self.snooze_time_step
-     def new_session(self):
-         self.session = requests.Session()
-         self.session.headers = self.headers
-         self.session.params = None
-     def _run(self):
-         # Authenticate
-         url = "https://%s%s" % (self.host, self.url)
-         # Connect and process the stream
-         error_counter = 0
-         resp = None
-         exc_info = None
-         while self.running:
-             if self.retry_count is not None:
-                 if error_counter > self.retry_count:
-                     # quit if error count greater than retry count
-                     break
-             try:
-                 auth = self.auth.apply_auth()
-                 resp = self.session.request('POST',
-                                             url,
-                                             data=self.body,
-                                             timeout=self.timeout,
-                                             stream=True,
-                                             auth=auth,
-                                             verify=self.verify)
-                 if resp.status_code != 200:
-                     if self.listener.on_error(resp.status_code) is False:
-                         break
-                     error_counter += 1
-                     if resp.status_code == 420:
-                         self.retry_time = max(self.retry_420_start,
-                                               self.retry_time)
-                     sleep(self.retry_time)
-                     self.retry_time = min(self.retry_time * 2,
-                                           self.retry_time_cap)
-                 else:
-                     error_counter = 0
-                     self.retry_time = self.retry_time_start
-                     self.snooze_time = self.snooze_time_step
-                     self.listener.on_connect()
-                     self._read_loop(resp)
-             except (Timeout, ssl.SSLError) as exc:
-                 # This is still necessary, as a SSLError can actually be
-                 # thrown when using Requests
-                 # If it's not time out treat it like any other exception
-                 if isinstance(exc, ssl.SSLError):
-                     if not (exc.args and 'timed out' in str(exc.args[0])):
-                         exc_info = sys.exc_info()
-                         break
-                 if self.listener.on_timeout() is False:
-                     break
-                 if self.running is False:
-                     break
-                 sleep(self.snooze_time)
-                 self.snooze_time = min(self.snooze_time + self.snooze_time_step,
-                                        self.snooze_time_cap)
-             except Exception as exc:
-                 exc_info = sys.exc_info()
-                 # any other exception is fatal, so kill loop
-                 break
-         # cleanup
-         self.running = False
-         if resp:
-             resp.close()
+     def on_keep_alive(self):
+         """This is called when a keep-alive signal is received."""
+         log.debug("Received keep-alive signal")
  
-         self.new_session()
+     def on_request_error(self, status_code):
+         """This is called when a non-200 HTTP status code is encountered."""
+         log.error("Stream encountered HTTP error: %d", status_code)
  
-         if exc_info:
-             # call a handler first so that the exception can be logged.
-             self.listener.on_exception(exc_info[1])
-             six.reraise(*exc_info)
+     def on_data(self, raw_data):
+         """This is called when raw data is received from the stream.
+         This method handles sending the data to other methods, depending on the
+         message type.
  
-     def _data(self, data):
-         if self.listener.on_data(data) is False:
-             self.running = False
+         https://developer.twitter.com/en/docs/twitter-api/v1/tweets/filter-realtime/guides/streaming-message-types
+         """
+         data = json.loads(raw_data)
  
-     def _read_loop(self, resp):
-         charset = resp.headers.get('content-type', default='')
-         enc_search = re.search('charset=(?P<enc>\S*)', charset)
-         if enc_search is not None:
-             encoding = enc_search.group('enc')
-         else:
-             encoding = 'utf-8'
-         buf = ReadBuffer(resp.raw, self.chunk_size, encoding=encoding)
-         while self.running and not resp.raw.closed:
-             length = 0
-             while not resp.raw.closed:
-                 line = buf.read_line().strip()
-                 if not line:
-                     self.listener.keep_alive()  # keep-alive new lines are expected
-                 elif line.isdigit():
-                     length = int(line)
-                     break
-                 else:
-                     raise TweepError('Expecting length, unexpected value found')
-             next_status_obj = buf.read_len(length)
-             if self.running and next_status_obj:
-                 self._data(next_status_obj)
-             # # Note: keep-alive newlines might be inserted before each length value.
-             # # read until we get a digit...
-             # c = b'\n'
-             # for c in resp.iter_content(decode_unicode=True):
-             #     if c == b'\n':
-             #         continue
-             #     break
-             #
-             # delimited_string = c
-             #
-             # # read rest of delimiter length..
-             # d = b''
-             # for d in resp.iter_content(decode_unicode=True):
-             #     if d != b'\n':
-             #         delimited_string += d
-             #         continue
-             #     break
-             #
-             # # read the next twitter status object
-             # if delimited_string.decode('utf-8').strip().isdigit():
-             #     status_id = int(delimited_string)
-             #     next_status_obj = resp.raw.read(status_id)
-             #     if self.running:
-             #         self._data(next_status_obj.decode('utf-8'))
-         if resp.raw.closed:
-             self.on_closed(resp)
-     def _start(self, async):
-         self.running = True
-         if async:
-             self._thread = Thread(target=self._run)
-             self._thread.start()
-         else:
-             self._run()
+         if "in_reply_to_status_id" in data:
+             status = Status.parse(None, data)
+             return self.on_status(status)
+         if "delete" in data:
+             delete = data["delete"]["status"]
+             return self.on_delete(delete["id"], delete["user_id"])
+         if "disconnect" in data:
+             return self.on_disconnect_message(data["disconnect"])
+         if "limit" in data:
+             return self.on_limit(data["limit"]["track"])
+         if "scrub_geo" in data:
+             return self.on_scrub_geo(data["scrub_geo"])
+         if "status_withheld" in data:
+             return self.on_status_withheld(data["status_withheld"])
+         if "user_withheld" in data:
+             return self.on_user_withheld(data["user_withheld"])
+         if "warning" in data:
+             return self.on_warning(data["warning"])
+         log.error("Received unknown message type: %s", raw_data)
  
-     def on_closed(self, resp):
-         """ Called when the response has been closed by Twitter """
-         pass
-     def userstream(self,
-                    stall_warnings=False,
-                    _with=None,
-                    replies=None,
-                    track=None,
-                    locations=None,
-                    async=False,
-                    encoding='utf8'):
-         self.session.params = {'delimited': 'length'}
-         if self.running:
-             raise TweepError('Stream object already connected!')
-         self.url = '/%s/user.json' % STREAM_VERSION
-         self.host = 'userstream.twitter.com'
-         if stall_warnings:
-             self.session.params['stall_warnings'] = stall_warnings
-         if _with:
-             self.session.params['with'] = _with
-         if replies:
-             self.session.params['replies'] = replies
-         if locations and len(locations) > 0:
-             if len(locations) % 4 != 0:
-                 raise TweepError("Wrong number of locations points, "
-                                  "it has to be a multiple of 4")
-             self.session.params['locations'] = ','.join(['%.2f' % l for l in locations])
-         if track:
-             self.session.params['track'] = u','.join(track).encode(encoding)
+     def on_status(self, status):
+         """This is called when a status is received."""
+         log.debug("Received status: %d", status.id)
  
-         self._start(async)
+     def on_delete(self, status_id, user_id):
+         """This is called when a status deletion notice is received."""
+         log.debug("Received status deletion notice: %d", status_id)
  
-     def firehose(self, count=None, async=False):
-         self.session.params = {'delimited': 'length'}
-         if self.running:
-             raise TweepError('Stream object already connected!')
-         self.url = '/%s/statuses/firehose.json' % STREAM_VERSION
-         if count:
-             self.url += '&count=%s' % count
-         self._start(async)
-     def retweet(self, async=False):
-         self.session.params = {'delimited': 'length'}
-         if self.running:
-             raise TweepError('Stream object already connected!')
-         self.url = '/%s/statuses/retweet.json' % STREAM_VERSION
-         self._start(async)
+     def on_disconnect_message(self, notice):
+         """This is called when a disconnect message is received."""
+         log.warning("Received disconnect message: %s", notice)
  
-     def sample(self, async=False, languages=None, stall_warnings=False):
-         self.session.params = {'delimited': 'length'}
-         if self.running:
-             raise TweepError('Stream object already connected!')
-         self.url = '/%s/statuses/sample.json' % STREAM_VERSION
-         if languages:
-             self.session.params['language'] = ','.join(map(str, languages))
-         if stall_warnings:
-             self.session.params['stall_warnings'] = 'true'
-         self._start(async)
+     def on_limit(self, track):
+         """This is called when a limit notice is received."""
+         log.debug("Received limit notice: %d", track)
  
-     def filter(self, follow=None, track=None, async=False, locations=None,
-                stall_warnings=False, languages=None, encoding='utf8', filter_level=None):
-         self.body = {}
-         self.session.headers['Content-type'] = "application/x-www-form-urlencoded"
-         if self.running:
-             raise TweepError('Stream object already connected!')
-         self.url = '/%s/statuses/filter.json' % STREAM_VERSION
+     def on_scrub_geo(self, notice):
+         """This is called when a location deletion notice is received."""
+         log.debug("Received location deletion notice: %s", notice)
  
-         if follow:
-             self.body['follow'] = u','.join(map(str, follow)).encode(encoding)
+     def on_status_withheld(self, notice):
+         """This is called when a status withheld content notice is received."""
+         log.debug("Received status withheld content notice: %s", notice)
  
-         if track:
-             self.body['track'] = u','.join(map(str, track)).encode(encoding)
+     def on_user_withheld(self, notice):
+         """This is called when a user withheld content notice is received."""
+         log.debug("Received user withheld content notice: %s", notice)
  
-         if locations and len(locations) > 0:
-             if len(locations) % 4 != 0:
-                 raise TweepError("Wrong number of locations points, "
-                                  "it has to be a multiple of 4")
-             self.body['locations'] = u','.join(['%.4f' % l for l in locations])
-         if stall_warnings:
-             self.body['stall_warnings'] = stall_warnings
-         if languages:
-             self.body['language'] = u','.join(map(str, languages))
-         if filter_level:
-             self.body['filter_level'] = filter_level.encode(encoding)
-         self.session.params = {'delimited': 'length'}
-         self.host = 'stream.twitter.com'
-         self._start(async)
-     def sitestream(self, follow, stall_warnings=False,
-                    with_='user', replies=False, async=False):
-         self.body = {}
-         if self.running:
-             raise TweepError('Stream object already connected!')
-         self.url = '/%s/site.json' % STREAM_VERSION
-         self.body['follow'] = u','.join(map(six.text_type, follow))
-         self.body['delimited'] = 'length'
-         if stall_warnings:
-             self.body['stall_warnings'] = stall_warnings
-         if with_:
-             self.body['with'] = with_
-         if replies:
-             self.body['replies'] = replies
-         self._start(async)
-     def disconnect(self):
-         if self.running is False:
-             return
-         self.running = False
+     def on_warning(self, notice):
+         """This is called when a stall warning message is received."""
+         log.warning("Received stall warning: %s", notice)