formatting: apply black formatter

This commit is contained in:
Cameron Yick
2021-08-14 21:03:11 -04:00
parent fc654d1e48
commit f4a9a04b2e
7 changed files with 320 additions and 263 deletions

View File

@@ -50,6 +50,10 @@ clean-test: ## remove test and coverage artifacts
lint: ## check style with flake8 lint: ## check style with flake8
flake8 tiingo tests flake8 tiingo tests
format: ## apply opinionated formatting
black tiingo/
test: ## run tests quickly with the default Python test: ## run tests quickly with the default Python
py.test py.test

View File

@@ -3,4 +3,4 @@ from tiingo.api import TiingoClient
from tiingo.wsclient import TiingoWebsocketClient from tiingo.wsclient import TiingoWebsocketClient
__author__ = """Cameron Yick""" __author__ = """Cameron Yick"""
__email__ = 'cameron.yick@enigma.com' __email__ = "cameron.yick@enigma.com"

View File

@@ -1,2 +1,2 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
__version__ = '0.14.0' __version__ = "0.14.0"

View File

@@ -16,10 +16,12 @@ from tiingo.exceptions import (
InstallPandasException, InstallPandasException,
APIColumnNameError, APIColumnNameError,
InvalidFrequencyError, InvalidFrequencyError,
MissingRequiredArgumentError) MissingRequiredArgumentError,
)
try: try:
import pandas as pd import pandas as pd
pandas_is_installed = True pandas_is_installed = True
except ImportError: except ImportError:
pandas_is_installed = False pandas_is_installed = False
@@ -40,11 +42,13 @@ def get_zipfile_from_response(response):
def get_buffer_from_zipfile(zipfile, filename): def get_buffer_from_zipfile(zipfile, filename):
if sys.version_info < (3, 0): # python 2 if sys.version_info < (3, 0): # python 2
from StringIO import StringIO from StringIO import StringIO
return StringIO(zipfile.read(filename)) return StringIO(zipfile.read(filename))
else: # python 3 else: # python 3
# Source: # Source:
# https://stackoverflow.com/questions/5627954/py3k-how-do-you-read-a-file-inside-a-zip-file-as-text-not-bytes # https://stackoverflow.com/questions/5627954/py3k-how-do-you-read-a-file-inside-a-zip-file-as-text-not-bytes
from io import (TextIOWrapper, BytesIO) from io import TextIOWrapper, BytesIO
return TextIOWrapper(BytesIO(zipfile.read(filename))) return TextIOWrapper(BytesIO(zipfile.read(filename)))
@@ -52,9 +56,9 @@ def dict_to_object(item, object_name):
"""Converts a python dict to a namedtuple, saving memory.""" """Converts a python dict to a namedtuple, saving memory."""
fields = item.keys() fields = item.keys()
values = item.values() values = item.values()
return json.loads(json.dumps(item), return json.loads(
object_hook=lambda d: json.dumps(item), object_hook=lambda d: namedtuple(object_name, fields)(*values)
namedtuple(object_name, fields)(*values)) )
class TiingoClient(RestClient): class TiingoClient(RestClient):
@@ -69,28 +73,30 @@ class TiingoClient(RestClient):
self._base_url = "https://api.tiingo.com" self._base_url = "https://api.tiingo.com"
try: try:
api_key = self._config['api_key'] api_key = self._config["api_key"]
except KeyError: except KeyError:
api_key = os.environ.get('TIINGO_API_KEY') api_key = os.environ.get("TIINGO_API_KEY")
self._api_key = api_key self._api_key = api_key
if not(api_key): if not (api_key):
raise RuntimeError("Tiingo API Key not provided. Please provide" raise RuntimeError(
" via environment variable or config argument.") "Tiingo API Key not provided. Please provide"
" via environment variable or config argument."
)
self._headers = { self._headers = {
'Authorization': "Token {}".format(api_key), "Authorization": "Token {}".format(api_key),
'Content-Type': 'application/json', "Content-Type": "application/json",
'User-Agent': 'tiingo-python-client {}'.format(VERSION) "User-Agent": "tiingo-python-client {}".format(VERSION),
} }
self._frequency_pattern = re.compile('^[0-9]+(min|hour)$', re.IGNORECASE) self._frequency_pattern = re.compile("^[0-9]+(min|hour)$", re.IGNORECASE)
def __repr__(self): def __repr__(self):
return '<TiingoClient(url="{}")>'.format(self._base_url) return '<TiingoClient(url="{}")>'.format(self._base_url)
def _is_eod_frequency(self,frequency): def _is_eod_frequency(self, frequency):
return frequency.lower() in ['daily', 'weekly', 'monthly', 'annually'] return frequency.lower() in ["daily", "weekly", "monthly", "annually"]
# TICKER PRICE ENDPOINTS # TICKER PRICE ENDPOINTS
# https://api.tiingo.com/docs/tiingo/daily # https://api.tiingo.com/docs/tiingo/daily
@@ -102,29 +108,30 @@ class TiingoClient(RestClient):
Tickers for unrelated products are omitted. Tickers for unrelated products are omitted.
https://apimedia.tiingo.com/docs/tiingo/daily/supported_tickers.zip https://apimedia.tiingo.com/docs/tiingo/daily/supported_tickers.zip
""" """
listing_file_url = "https://apimedia.tiingo.com/docs/tiingo/daily/supported_tickers.zip" listing_file_url = (
"https://apimedia.tiingo.com/docs/tiingo/daily/supported_tickers.zip"
)
response = requests.get(listing_file_url) response = requests.get(listing_file_url)
zipdata = get_zipfile_from_response(response) zipdata = get_zipfile_from_response(response)
raw_csv = get_buffer_from_zipfile(zipdata, 'supported_tickers.csv') raw_csv = get_buffer_from_zipfile(zipdata, "supported_tickers.csv")
reader = csv.DictReader(raw_csv) reader = csv.DictReader(raw_csv)
if not len(assetTypes): if not len(assetTypes):
return [row for row in reader] return [row for row in reader]
assetTypesSet = set(assetTypes) assetTypesSet = set(assetTypes)
return [row for row in reader return [row for row in reader if row.get("assetType") in assetTypesSet]
if row.get('assetType') in assetTypesSet]
def list_stock_tickers(self): def list_stock_tickers(self):
return self.list_tickers(['Stock']) return self.list_tickers(["Stock"])
def list_etf_tickers(self): def list_etf_tickers(self):
return self.list_tickers(['ETF']) return self.list_tickers(["ETF"])
def list_fund_tickers(self): def list_fund_tickers(self):
return self.list_tickers(['Mutual Fund']) return self.list_tickers(["Mutual Fund"])
def get_ticker_metadata(self, ticker, fmt='json'): def get_ticker_metadata(self, ticker, fmt="json"):
"""Return metadata for 1 ticker """Return metadata for 1 ticker
Use TiingoClient.list_tickers() to get available options Use TiingoClient.list_tickers() to get available options
@@ -132,11 +139,11 @@ class TiingoClient(RestClient):
ticker (str) : Unique identifier for stock ticker (str) : Unique identifier for stock
""" """
url = "tiingo/daily/{}".format(ticker) url = "tiingo/daily/{}".format(ticker)
response = self._request('GET', url) response = self._request("GET", url)
data = response.json() data = response.json()
if fmt == 'json': if fmt == "json":
return data return data
elif fmt == 'object': elif fmt == "object":
return dict_to_object(data, "Ticker") return dict_to_object(data, "Ticker")
def _invalid_frequency(self, frequency): def _invalid_frequency(self, frequency):
@@ -145,7 +152,9 @@ class TiingoClient(RestClient):
:param frequency (string): frequency string :param frequency (string): frequency string
:return (boolean): :return (boolean):
""" """
is_valid = self._is_eod_frequency(frequency) or re.match(self._frequency_pattern, frequency) is_valid = self._is_eod_frequency(frequency) or re.match(
self._frequency_pattern, frequency
)
return not is_valid return not is_valid
def _get_url(self, ticker, frequency): def _get_url(self, ticker, frequency):
@@ -157,8 +166,10 @@ class TiingoClient(RestClient):
:return (string): url :return (string): url
""" """
if self._invalid_frequency(frequency): if self._invalid_frequency(frequency):
etext = ("Error: {} is an invalid frequency. Check Tiingo API documentation " etext = (
"for valid EOD or intraday frequency format.") "Error: {} is an invalid frequency. Check Tiingo API documentation "
"for valid EOD or intraday frequency format."
)
raise InvalidFrequencyError(etext.format(frequency)) raise InvalidFrequencyError(etext.format(frequency))
else: else:
if self._is_eod_frequency(frequency): if self._is_eod_frequency(frequency):
@@ -179,19 +190,19 @@ class TiingoClient(RestClient):
all of the available data will be returned. In the event of a list of tickers, all of the available data will be returned. In the event of a list of tickers,
this parameter is required. this parameter is required.
""" """
url = self._get_url(ticker, params['resampleFreq']) url = self._get_url(ticker, params["resampleFreq"])
response = self._request('GET', url, params=params) response = self._request("GET", url, params=params)
if params['format'] == 'csv': if params["format"] == "csv":
if sys.version_info < (3, 0): # python 2 if sys.version_info < (3, 0): # python 2
from StringIO import StringIO from StringIO import StringIO
else: # python 3 else: # python 3
from io import StringIO from io import StringIO
df = pd.read_csv(StringIO(response.content.decode('utf-8'))) df = pd.read_csv(StringIO(response.content.decode("utf-8")))
else: else:
df = pd.DataFrame(response.json()) df = pd.DataFrame(response.json())
df.set_index('date', inplace=True) df.set_index("date", inplace=True)
if metric_name is not None: if metric_name is not None:
prices = df[metric_name] prices = df[metric_name]
@@ -203,13 +214,13 @@ class TiingoClient(RestClient):
# Localize to UTC to ensure equivalence between data returned in json format and # Localize to UTC to ensure equivalence between data returned in json format and
# csv format. Tiingo daily data requested in csv format does not include a timezone. # csv format. Tiingo daily data requested in csv format does not include a timezone.
if prices.index.tz is None: if prices.index.tz is None:
prices.index = prices.index.tz_localize('UTC') prices.index = prices.index.tz_localize("UTC")
return prices return prices
def get_ticker_price(self, ticker, def get_ticker_price(
startDate=None, endDate=None, self, ticker, startDate=None, endDate=None, fmt="json", frequency="daily"
fmt='json', frequency='daily'): ):
"""By default, return latest EOD Composite Price for a stock ticker. """By default, return latest EOD Composite Price for a stock ticker.
On average, each feed contains 3 data sources. On average, each feed contains 3 data sources.
@@ -225,18 +236,18 @@ class TiingoClient(RestClient):
""" """
url = self._get_url(ticker, frequency) url = self._get_url(ticker, frequency)
params = { params = {
'format': fmt if fmt != "object" else 'json', # conversion local "format": fmt if fmt != "object" else "json", # conversion local
'resampleFreq': frequency "resampleFreq": frequency,
} }
if startDate: if startDate:
params['startDate'] = startDate params["startDate"] = startDate
if endDate: if endDate:
params['endDate'] = endDate params["endDate"] = endDate
# TODO: evaluate whether to stream CSV to cache on disk, or # TODO: evaluate whether to stream CSV to cache on disk, or
# load as array in memory, or just pass plain text # load as array in memory, or just pass plain text
response = self._request('GET', url, params=params) response = self._request("GET", url, params=params)
if fmt == "json": if fmt == "json":
return response.json() return response.json()
elif fmt == "object": elif fmt == "object":
@@ -245,11 +256,17 @@ class TiingoClient(RestClient):
else: else:
return response.content.decode("utf-8") return response.content.decode("utf-8")
def get_dataframe(self, tickers, def get_dataframe(
startDate=None, endDate=None, metric_name=None, self,
frequency='daily', fmt='json'): tickers,
startDate=None,
endDate=None,
metric_name=None,
frequency="daily",
fmt="json",
):
""" Return a pandas.DataFrame of historical prices for one or more ticker symbols. """Return a pandas.DataFrame of historical prices for one or more ticker symbols.
By default, return latest EOD Composite Price for a list of stock tickers. By default, return latest EOD Composite Price for a list of stock tickers.
On average, each feed contains 3 data sources. On average, each feed contains 3 data sources.
@@ -270,53 +287,77 @@ class TiingoClient(RestClient):
fmt (string): 'csv' or 'json' fmt (string): 'csv' or 'json'
""" """
valid_columns = {'open', 'high', 'low', 'close', 'volume', 'adjOpen', 'adjHigh', 'adjLow', valid_columns = {
'adjClose', 'adjVolume', 'divCash', 'splitFactor'} "open",
"high",
"low",
"close",
"volume",
"adjOpen",
"adjHigh",
"adjLow",
"adjClose",
"adjVolume",
"divCash",
"splitFactor",
}
if metric_name is not None and metric_name not in valid_columns: if metric_name is not None and metric_name not in valid_columns:
raise APIColumnNameError('Valid data items are: ' + str(valid_columns)) raise APIColumnNameError("Valid data items are: " + str(valid_columns))
if metric_name is None and isinstance(tickers, list): if metric_name is None and isinstance(tickers, list):
raise MissingRequiredArgumentError("""When tickers is provided as a list, metric_name is a required argument. raise MissingRequiredArgumentError(
Please provide a metric_name, or call this method with one ticker at a time.""") """When tickers is provided as a list, metric_name is a required argument.
Please provide a metric_name, or call this method with one ticker at a time."""
)
params = { params = {"format": fmt, "resampleFreq": frequency}
'format': fmt,
'resampleFreq': frequency
}
if startDate: if startDate:
params['startDate'] = startDate params["startDate"] = startDate
if endDate: if endDate:
params['endDate'] = endDate params["endDate"] = endDate
if pandas_is_installed: if pandas_is_installed:
if type(tickers) is str: if type(tickers) is str:
prices = self._request_pandas( prices = self._request_pandas(
ticker=tickers, params=params, metric_name=metric_name) ticker=tickers, params=params, metric_name=metric_name
)
else: else:
prices = pd.DataFrame() prices = pd.DataFrame()
for stock in tickers: for stock in tickers:
ticker_series = self._request_pandas( ticker_series = self._request_pandas(
ticker=stock, params=params, metric_name=metric_name) ticker=stock, params=params, metric_name=metric_name
)
ticker_series = ticker_series.rename(stock) ticker_series = ticker_series.rename(stock)
prices = pd.concat([prices, ticker_series], axis=1, sort=True) prices = pd.concat([prices, ticker_series], axis=1, sort=True)
return prices return prices
else: else:
error_message = ("Pandas is not installed, but .get_ticker_price() was " error_message = (
"Pandas is not installed, but .get_ticker_price() was "
"called with fmt=pandas. In order to install tiingo with " "called with fmt=pandas. In order to install tiingo with "
"pandas, reinstall with pandas as an optional dependency. \n" "pandas, reinstall with pandas as an optional dependency. \n"
"Install tiingo with pandas dependency: \'pip install tiingo[pandas]\'\n" "Install tiingo with pandas dependency: 'pip install tiingo[pandas]'\n"
"Alternatively, just install pandas: pip install pandas.") "Alternatively, just install pandas: pip install pandas."
)
raise InstallPandasException(error_message) raise InstallPandasException(error_message)
# NEWS FEEDS # NEWS FEEDS
# tiingo/news # tiingo/news
def get_news(self, tickers=[], tags=[], sources=[], startDate=None, def get_news(
endDate=None, limit=100, offset=0, sortBy="publishedDate", self,
tickers=[],
tags=[],
sources=[],
startDate=None,
endDate=None,
limit=100,
offset=0,
sortBy="publishedDate",
onlyWithTickers=False, onlyWithTickers=False,
fmt='json'): fmt="json",
):
"""Return list of news articles matching given search terms """Return list of news articles matching given search terms
https://api.tiingo.com/docs/tiingo/news https://api.tiingo.com/docs/tiingo/news
@@ -334,24 +375,24 @@ class TiingoClient(RestClient):
""" """
url = "tiingo/news" url = "tiingo/news"
params = { params = {
'limit': limit, "limit": limit,
'offset': offset, "offset": offset,
'sortBy': sortBy, "sortBy": sortBy,
'tickers': tickers, "tickers": tickers,
'source': (",").join(sources) if sources else None, "source": (",").join(sources) if sources else None,
'tags': tags, "tags": tags,
'startDate': startDate, "startDate": startDate,
'endDate': endDate, "endDate": endDate,
'onlyWithTickers': onlyWithTickers "onlyWithTickers": onlyWithTickers,
} }
response = self._request('GET', url, params=params) response = self._request("GET", url, params=params)
data = response.json() data = response.json()
if fmt == 'json': if fmt == "json":
return data return data
elif fmt == 'object': elif fmt == "object":
return [dict_to_object(item, "NewsArticle") for item in data] return [dict_to_object(item, "NewsArticle") for item in data]
def get_bulk_news(self, file_id=None, fmt='json'): def get_bulk_news(self, file_id=None, fmt="json"):
"""Only available to institutional clients. """Only available to institutional clients.
If ID is NOT provided, return array of available file_ids. If ID is NOT provided, return array of available file_ids.
If ID is provided, provides URL which you can use to download your If ID is provided, provides URL which you can use to download your
@@ -362,76 +403,85 @@ class TiingoClient(RestClient):
else: else:
url = "tiingo/news/bulk_download" url = "tiingo/news/bulk_download"
response = self._request('GET', url) response = self._request("GET", url)
data = response.json() data = response.json()
if fmt == 'json': if fmt == "json":
return data return data
elif fmt == 'object': elif fmt == "object":
return dict_to_object(data, "BulkNews") return dict_to_object(data, "BulkNews")
# Crypto # Crypto
# tiingo/crypto # tiingo/crypto
def get_crypto_top_of_book(self, tickers=[], exchanges=[], def get_crypto_top_of_book(
includeRawExchangeData=False, convertCurrency=None): self,
url = 'tiingo/crypto/top' tickers=[],
params = { exchanges=[],
'tickers': ','.join(tickers) includeRawExchangeData=False,
} convertCurrency=None,
):
url = "tiingo/crypto/top"
params = {"tickers": ",".join(tickers)}
if len(exchanges): if len(exchanges):
params['exchanges'] = ','.join(exchanges) params["exchanges"] = ",".join(exchanges)
if includeRawExchangeData is True: if includeRawExchangeData is True:
params['includeRawExchangeData'] = True params["includeRawExchangeData"] = True
if convertCurrency: if convertCurrency:
params['convertCurrency'] = convertCurrency params["convertCurrency"] = convertCurrency
response = self._request('GET', url, params=params) response = self._request("GET", url, params=params)
return response.json() return response.json()
def get_crypto_price_history(self, tickers=[], baseCurrency=None, def get_crypto_price_history(
startDate=None, endDate=None, exchanges=[], self,
consolidateBaseCurrency=False, includeRawExchangeData=False, tickers=[],
resampleFreq=None, convertCurrency=None): baseCurrency=None,
url = 'tiingo/crypto/prices' startDate=None,
params = { endDate=None,
'tickers': ','.join(tickers) exchanges=[],
} consolidateBaseCurrency=False,
includeRawExchangeData=False,
resampleFreq=None,
convertCurrency=None,
):
url = "tiingo/crypto/prices"
params = {"tickers": ",".join(tickers)}
if startDate: if startDate:
params['startDate'] = startDate params["startDate"] = startDate
if endDate: if endDate:
params['endDate'] = endDate params["endDate"] = endDate
if len(exchanges): if len(exchanges):
params['exchanges'] = ','.join(exchanges) params["exchanges"] = ",".join(exchanges)
if consolidateBaseCurrency is True: if consolidateBaseCurrency is True:
params['consolidateBaseCurrency'] = ','.join(consolidateBaseCurrency) params["consolidateBaseCurrency"] = ",".join(consolidateBaseCurrency)
if includeRawExchangeData is True: if includeRawExchangeData is True:
params['includeRawExchangeData'] = includeRawExchangeData params["includeRawExchangeData"] = includeRawExchangeData
if resampleFreq: if resampleFreq:
params['resampleFreq'] = resampleFreq params["resampleFreq"] = resampleFreq
if convertCurrency: if convertCurrency:
params['convertCurrency'] = convertCurrency params["convertCurrency"] = convertCurrency
response = self._request('GET', url, params=params) response = self._request("GET", url, params=params)
return response.json() return response.json()
def get_crypto_metadata(self, tickers=[], fmt='json'): def get_crypto_metadata(self, tickers=[], fmt="json"):
url = 'tiingo/crypto' url = "tiingo/crypto"
params = { params = {
'tickers': ','.join(tickers), "tickers": ",".join(tickers),
'format': fmt, "format": fmt,
} }
response = self._request('GET', url, params=params) response = self._request("GET", url, params=params)
if fmt == 'csv': if fmt == "csv":
return response.content.decode("utf-8") return response.content.decode("utf-8")
else: else:
return response.json() return response.json()
# FUNDAMENTAL DEFINITIONS # FUNDAMENTAL DEFINITIONS
# tiingo/fundamentals/definitions # tiingo/fundamentals/definitions
def get_fundamentals_definitions(self, tickers=[], fmt='json'): def get_fundamentals_definitions(self, tickers=[], fmt="json"):
"""Return definitions for fundamentals for specified tickers """Return definitions for fundamentals for specified tickers
https://api.tiingo.com/documentation/fundamentals https://api.tiingo.com/documentation/fundamentals
@@ -440,20 +490,16 @@ class TiingoClient(RestClient):
fmt (string): 'csv' or 'json' fmt (string): 'csv' or 'json'
""" """
url = "tiingo/fundamentals/definitions" url = "tiingo/fundamentals/definitions"
params = { params = {"tickers": tickers, "format": fmt}
'tickers': tickers, response = self._request("GET", url, params=params)
'format': fmt if fmt == "json":
}
response = self._request('GET', url, params=params)
if fmt == 'json':
return response.json() return response.json()
elif fmt == 'csv': elif fmt == "csv":
return response.content.decode("utf-8") return response.content.decode("utf-8")
# FUNDAMENTAL DAILY # FUNDAMENTAL DAILY
# tiingo/fundamentals/<ticker>/daily # tiingo/fundamentals/<ticker>/daily
def get_fundamentals_daily(self, ticker, fmt='json', def get_fundamentals_daily(self, ticker, fmt="json", startDate=None, endDate=None):
startDate=None, endDate=None):
"""Returns metrics which rely on daily price-updates """Returns metrics which rely on daily price-updates
https://api.tiingo.com/documentation/fundamentals https://api.tiingo.com/documentation/fundamentals
@@ -464,22 +510,19 @@ class TiingoClient(RestClient):
startDate, endDate [date]: Boundaries of search window startDate, endDate [date]: Boundaries of search window
fmt (string): 'csv' or 'json' fmt (string): 'csv' or 'json'
""" """
url = 'tiingo/fundamentals/{}/daily'.format(ticker) url = "tiingo/fundamentals/{}/daily".format(ticker)
params = { params = {"startDate": startDate, "endDate": endDate, "format": fmt}
'startDate': startDate, response = self._request("GET", url, params=params)
'endDate': endDate, if fmt == "json":
'format': fmt
}
response = self._request('GET', url, params=params)
if fmt == 'json':
return response.json() return response.json()
elif fmt == 'csv': elif fmt == "csv":
return response.content.decode("utf-8") return response.content.decode("utf-8")
# FUNDAMENTAL STATEMENTS # FUNDAMENTAL STATEMENTS
# tiingo/fundamentals/<ticker>/statements # tiingo/fundamentals/<ticker>/statements
def get_fundamentals_statements(self, ticker, asReported=False, fmt='json', def get_fundamentals_statements(
startDate=None, endDate=None): self, ticker, asReported=False, fmt="json", startDate=None, endDate=None
):
"""Returns data that is extracted from quarterly and annual statements. """Returns data that is extracted from quarterly and annual statements.
https://api.tiingo.com/documentation/fundamentals https://api.tiingo.com/documentation/fundamentals
@@ -494,19 +537,19 @@ class TiingoClient(RestClient):
fmt (string): 'csv' or 'json' fmt (string): 'csv' or 'json'
""" """
if asReported: if asReported:
asReported = 'true' asReported = "true"
else: else:
asReported = 'false' asReported = "false"
url = 'tiingo/fundamentals/{}/statements'.format(ticker) url = "tiingo/fundamentals/{}/statements".format(ticker)
params = { params = {
'startDate': startDate, "startDate": startDate,
'endDate': endDate, "endDate": endDate,
'asReported': asReported, "asReported": asReported,
'format': fmt "format": fmt,
} }
response = self._request('GET', url, params=params) response = self._request("GET", url, params=params)
if fmt == 'json': if fmt == "json":
return response.json() return response.json()
elif fmt == 'csv': elif fmt == "csv":
return response.content.decode("utf-8") return response.content.decode("utf-8")

View File

@@ -10,5 +10,6 @@ class APIColumnNameError(Exception):
class InvalidFrequencyError(Exception): class InvalidFrequencyError(Exception):
pass pass
class MissingRequiredArgumentError(Exception): class MissingRequiredArgumentError(Exception):
pass pass

View File

@@ -12,7 +12,6 @@ class RestClientError(Exception):
class RestClient(object): class RestClient(object):
def __init__(self, config={}): def __init__(self, config={}):
"""Base class for interacting with RESTful APIs """Base class for interacting with RESTful APIs
Child class MUST have a ._base_url property! Child class MUST have a ._base_url property!
@@ -28,7 +27,7 @@ class RestClient(object):
self._headers = {} self._headers = {}
self._base_url = "" self._base_url = ""
if config.get('session'): if config.get("session"):
self._session = requests.Session() self._session = requests.Session()
else: else:
self._session = requests self._session = requests
@@ -44,10 +43,9 @@ class RestClient(object):
url (str): path appended to the base_url to create request url (str): path appended to the base_url to create request
**kwargs: passed directly to a requests.request object **kwargs: passed directly to a requests.request object
""" """
resp = self._session.request(method, resp = self._session.request(
'{}/{}'.format(self._base_url, url), method, "{}/{}".format(self._base_url, url), headers=self._headers, **kwargs
headers=self._headers, )
**kwargs)
try: try:
resp.raise_for_status() resp.raise_for_status()

View File

@@ -3,8 +3,9 @@ import websocket
import json import json
from tiingo.exceptions import MissingRequiredArgumentError from tiingo.exceptions import MissingRequiredArgumentError
class TiingoWebsocketClient: class TiingoWebsocketClient:
''' """
from tiingo import TiingoWebsocketClient from tiingo import TiingoWebsocketClient
def cb_fn(msg): def cb_fn(msg):
@@ -36,56 +37,66 @@ class TiingoWebsocketClient:
# any logic should be implemented in the callback function # any logic should be implemented in the callback function
TiingoWebsocketClient(subscribe,endpoint="iex",on_msg_cb=cb_fn) TiingoWebsocketClient(subscribe,endpoint="iex",on_msg_cb=cb_fn)
while True:pass while True:pass
''' """
def __init__(self,config=None,endpoint=None,on_msg_cb=None): def __init__(self, config=None, endpoint=None, on_msg_cb=None):
self._base_url = "wss://api.tiingo.com" self._base_url = "wss://api.tiingo.com"
self.config = {} if config is None else config self.config = {} if config is None else config
try: try:
api_key = self.config['authorization'] api_key = self.config["authorization"]
except KeyError: except KeyError:
api_key = os.environ.get('TIINGO_API_KEY') api_key = os.environ.get("TIINGO_API_KEY")
self.config.update({"authorization":api_key}) self.config.update({"authorization": api_key})
self._api_key = api_key self._api_key = api_key
if not(api_key): if not (api_key):
raise RuntimeError("Tiingo API Key not provided. Please provide" raise RuntimeError(
"Tiingo API Key not provided. Please provide"
" via environment variable or config argument." " via environment variable or config argument."
"Notice that this config dict takes the API Key as authorization ") "Notice that this config dict takes the API Key as authorization "
)
self.endpoint = endpoint self.endpoint = endpoint
if not (self.endpoint=="iex" or self.endpoint=="fx" or self.endpoint=="crypto"): if not (
self.endpoint == "iex" or self.endpoint == "fx" or self.endpoint == "crypto"
):
raise AttributeError("Endpoint must be defined as either (iex,fx,crypto) ") raise AttributeError("Endpoint must be defined as either (iex,fx,crypto) ")
self.on_msg_cb = on_msg_cb self.on_msg_cb = on_msg_cb
if not self.on_msg_cb: if not self.on_msg_cb:
raise MissingRequiredArgumentError("please define on_msg_cb It's a callback that gets called when new messages arrive " raise MissingRequiredArgumentError(
"please define on_msg_cb It's a callback that gets called when new messages arrive "
"Example:" "Example:"
"def cb_fn(msg):" "def cb_fn(msg):"
" print(msg)") " print(msg)"
)
websocket.enableTrace(False) websocket.enableTrace(False)
ws = websocket.WebSocketApp("{0}/{1}".format(self._base_url,self.endpoint), ws = websocket.WebSocketApp(
on_message = self.get_on_msg_cb(), "{0}/{1}".format(self._base_url, self.endpoint),
on_error = self.on_error, on_message=self.get_on_msg_cb(),
on_close = self.on_close, on_error=self.on_error,
on_open = self.get_on_open(self.config)) on_close=self.on_close,
on_open=self.get_on_open(self.config),
)
ws.run_forever() ws.run_forever()
def get_on_open(self,config): def get_on_open(self, config):
# the methods passed to websocketClient have to be unbounded if we want WebSocketApp to pass everything correctly # the methods passed to websocketClient have to be unbounded if we want WebSocketApp to pass everything correctly
# see websocket-client/#471 # see websocket-client/#471
def on_open(ws): def on_open(ws):
ws.send(json.dumps(config)) ws.send(json.dumps(config))
return on_open return on_open
def get_on_msg_cb(self): def get_on_msg_cb(self):
def on_msg_cb_local(ws,msg): def on_msg_cb_local(ws, msg):
self.on_msg_cb(msg) self.on_msg_cb(msg)
return return
return on_msg_cb_local return on_msg_cb_local
# since methods need to be unbound in order for websocketClient these methods don't have a self as their first parameter # since methods need to be unbound in order for websocketClient these methods don't have a self as their first parameter