diff --git a/bfxapi/models/Subscription.py b/bfxapi/models/Subscription.py index bc32ba4..72da70c 100644 --- a/bfxapi/models/Subscription.py +++ b/bfxapi/models/Subscription.py @@ -13,8 +13,8 @@ class Subscription: such as unsibscribe and subscribe. """ - def __init__(self, ws, channel_name, symbol, timeframe=None, **kwargs): - self._ws = ws + def __init__(self, bfxapi, channel_name, symbol, timeframe=None, **kwargs): + self.bfxapi = bfxapi self.channel_name = channel_name self.symbol = symbol self.timeframe = timeframe @@ -40,13 +40,13 @@ class Subscription: if not self.is_subscribed(): raise Exception("Subscription is not subscribed to websocket") payload = {'event': 'unsubscribe', 'chanId': self.chan_id} - await self._ws.send(json.dumps(payload)) + await self.bfxapi.get_ws().send(json.dumps(payload)) async def subscribe(self): """ Send a subscription request to the bitfinex socket """ - await self._ws.send(json.dumps(self._get_send_payload())) + await self.bfxapi.get_ws().send(json.dumps(self._get_send_payload())) def confirm_unsubscribe(self): """ diff --git a/bfxapi/models/subscription.py b/bfxapi/models/subscription.py index bc32ba4..72da70c 100644 --- a/bfxapi/models/subscription.py +++ b/bfxapi/models/subscription.py @@ -13,8 +13,8 @@ class Subscription: such as unsibscribe and subscribe. """ - def __init__(self, ws, channel_name, symbol, timeframe=None, **kwargs): - self._ws = ws + def __init__(self, bfxapi, channel_name, symbol, timeframe=None, **kwargs): + self.bfxapi = bfxapi self.channel_name = channel_name self.symbol = symbol self.timeframe = timeframe @@ -40,13 +40,13 @@ class Subscription: if not self.is_subscribed(): raise Exception("Subscription is not subscribed to websocket") payload = {'event': 'unsubscribe', 'chanId': self.chan_id} - await self._ws.send(json.dumps(payload)) + await self.bfxapi.get_ws().send(json.dumps(payload)) async def subscribe(self): """ Send a subscription request to the bitfinex socket """ - await self._ws.send(json.dumps(self._get_send_payload())) + await self.bfxapi.get_ws().send(json.dumps(self._get_send_payload())) def confirm_unsubscribe(self): """ diff --git a/bfxapi/websockets/BfxWebsocket.py b/bfxapi/websockets/BfxWebsocket.py index a03097b..2babff9 100644 --- a/bfxapi/websockets/BfxWebsocket.py +++ b/bfxapi/websockets/BfxWebsocket.py @@ -369,7 +369,7 @@ class BfxWebsocket(GenericWebsocket): jdata = generate_auth_payload(self.API_KEY, self.API_SECRET) if self.dead_man_switch: jdata['dms'] = 4 - await self.ws.send(json.dumps(jdata)) + await self.get_ws().send(json.dumps(jdata)) async def on_open(self): self.logger.info("Websocket opened.") @@ -380,17 +380,19 @@ class BfxWebsocket(GenericWebsocket): # enable order book checksums if self.manageOrderBooks: await self.enable_flag(Flags.CHECKSUM) + # resubscribe to any channels + await self.subscriptionManager.resubscribe_all() async def _send_auth_command(self, channel_name, data): payload = [0, channel_name, None, data] - await self.ws.send(json.dumps(payload)) + await self.get_ws().send(json.dumps(payload)) async def enable_flag(self, flag): payload = { "event": 'conf', "flags": flag } - await self.ws.send(json.dumps(payload)) + await self.get_ws().send(json.dumps(payload)) def get_orderbook(self, symbol): return self.orderBooks.get(symbol, None) diff --git a/bfxapi/websockets/GenericWebsocket.py b/bfxapi/websockets/GenericWebsocket.py index 0801fbf..c80da3d 100644 --- a/bfxapi/websockets/GenericWebsocket.py +++ b/bfxapi/websockets/GenericWebsocket.py @@ -4,11 +4,14 @@ Module used as a interfeace to describe a generick websocket client import asyncio import websockets +import socket import json from pyee import EventEmitter from ..utils.CustomLogger import CustomLogger +# websocket exceptions +from websockets.exceptions import ConnectionClosed class AuthError(Exception): """ @@ -16,7 +19,6 @@ class AuthError(Exception): """ pass - def is_json(myjson): try: json_object = json.loads(myjson) @@ -24,20 +26,20 @@ def is_json(myjson): return False return True - class GenericWebsocket: """ Websocket object used to contain the base functionality of a websocket. Inlcudes an event emitter and a standard websocket client. """ - def __init__(self, host, logLevel='INFO', loop=None): + def __init__(self, host, logLevel='INFO', loop=None, max_retries=5): self.host = host self.logger = CustomLogger('BfxWebsocket', logLevel=logLevel) self.loop = loop or asyncio.get_event_loop() self.events = EventEmitter( scheduler=asyncio.ensure_future, loop=self.loop) self.ws = None + self.max_retries = max_retries def run(self): """ @@ -51,15 +53,33 @@ class GenericWebsocket: """ return self._main(self.host) - async def _main(self, host): + async def _connect(self, host): async with websockets.connect(host) as websocket: self.ws = websocket - self.logger.info("Wesocket connectedt to {}".format(self.host)) + self.logger.info("Wesocket connected to {}".format(self.host)) while True: await asyncio.sleep(0) message = await websocket.recv() await self.on_message(message) + def get_ws(self): + return self.ws + + async def _main(self, host): + retries = 0 + while retries < self.max_retries: + try: + await self._connect(host) + retries = 0 + except (ConnectionClosed, socket.error) as e: + self.logger.error(str(e)) + retries += 1 + # wait 5 seconds befor retrying + self.logger.info("Waiting 5 seconds befor retrying...") + await asyncio.sleep(5) + self.logger.info("Reconnect attempt {}/{}".format(retries, self.max_retries)) + self.logger.info("Unable to connect to websocket.") + def remove_all_listeners(self, event): """ Remove all listeners from event emitter diff --git a/bfxapi/websockets/SubscriptionManager.py b/bfxapi/websockets/SubscriptionManager.py index df961ee..cf2fe87 100644 --- a/bfxapi/websockets/SubscriptionManager.py +++ b/bfxapi/websockets/SubscriptionManager.py @@ -32,7 +32,7 @@ class SubscriptionManager: """ # create a new subscription subscription = Subscription( - self.bfxapi.ws, channel_name, symbol, timeframe, **kwargs) + self.bfxapi, channel_name, symbol, timeframe, **kwargs) self.logger.info("Subscribing to channel {}".format(channel_name)) key = "{}_{}".format(channel_name, subscription.key or symbol) self.pending_subscriptions[key] = subscription @@ -121,6 +121,8 @@ class SubscriptionManager: task_batch += [ asyncio.ensure_future(self.unsubscribe(chan_id)) ] + if len(task_batch) == 0: + return await asyncio.wait(*[task_batch]) async def resubscribe_all(self): @@ -132,4 +134,6 @@ class SubscriptionManager: task_batch += [ asyncio.ensure_future(self.resubscribe(chan_id)) ] + if len(task_batch) == 0: + return await asyncio.wait(*[task_batch])