add websocket multiplexer

This commit is contained in:
Jacob Plaster
2019-06-18 14:42:19 +08:00
committed by Jacob Plaster
parent d027de9964
commit 13e7d505f8
9 changed files with 240 additions and 76 deletions

View File

@@ -20,10 +20,10 @@ class Client:
def __init__(self, API_KEY=None, API_SECRET=None, rest_host=REST_HOST,
ws_host=WS_HOST, loop=None, logLevel='INFO', dead_man_switch=False,
*args, **kwargs):
ws_capacity=25, *args, **kwargs):
self.loop = loop or asyncio.get_event_loop()
self.ws = BfxWebsocket(API_KEY=API_KEY, API_SECRET=API_SECRET, host=ws_host,
loop=self.loop, logLevel=logLevel, dead_man_switch=dead_man_switch,
*args, **kwargs)
ws_capacity=ws_capacity, *args, **kwargs)
self.rest = BfxRest(API_KEY=API_KEY, API_SECRET=API_SECRET, host=rest_host,
loop=self.loop, logLevel=logLevel, *args, **kwargs)

View File

@@ -5,7 +5,7 @@ This module is used to interact with the bitfinex api
from .client import Client
from .models import (Order, Trade, OrderBook, Subscription, Wallet,
Position, FundingLoan, FundingOffer, FundingCredit)
from .websockets.GenericWebsocket import GenericWebsocket
from .websockets.GenericWebsocket import GenericWebsocket, Socket
from .websockets.BfxWebsocket import BfxWebsocket
from .utils.Decimal import Decimal

View File

@@ -20,10 +20,10 @@ class Client:
def __init__(self, API_KEY=None, API_SECRET=None, rest_host=REST_HOST,
ws_host=WS_HOST, loop=None, logLevel='INFO', dead_man_switch=False,
*args, **kwargs):
ws_capacity=25, *args, **kwargs):
self.loop = loop or asyncio.get_event_loop()
self.ws = BfxWebsocket(API_KEY=API_KEY, API_SECRET=API_SECRET, host=ws_host,
loop=self.loop, logLevel=logLevel, dead_man_switch=dead_man_switch,
*args, **kwargs)
ws_capacity=ws_capacity, *args, **kwargs)
self.rest = BfxRest(API_KEY=API_KEY, API_SECRET=API_SECRET, host=rest_host,
loop=self.loop, logLevel=logLevel, *args, **kwargs)

View File

@@ -21,8 +21,8 @@ class Subscription:
such as unsibscribe and subscribe.
"""
def __init__(self, bfxapi, channel_name, symbol, timeframe=None, **kwargs):
self.bfxapi = bfxapi
def __init__(self, socket, channel_name, symbol, timeframe=None, **kwargs):
self.socket = socket
self.channel_name = channel_name
self.symbol = symbol
self.timeframe = timeframe
@@ -34,6 +34,12 @@ class Subscription:
self.sub_id = generate_sub_id()
self.send_payload = self._generate_payload(**kwargs)
def get_key(self):
"""
Generates a unique key string for the subscription
"""
return "{}_{}".format(self.channel_name, self.key or self.symbol)
def confirm_subscription(self, chan_id):
"""
Update the subscription to confirmed state
@@ -48,13 +54,13 @@ class Subscription:
if not self.is_subscribed():
raise Exception("Subscription is not subscribed to websocket")
payload = {'event': 'unsubscribe', 'chanId': self.chan_id}
await self.bfxapi.get_ws().send(json.dumps(payload))
await self.socket.ws.send(json.dumps(payload))
async def subscribe(self):
"""
Send a subscription request to the bitfinex socket
"""
await self.bfxapi.get_ws().send(json.dumps(self._get_send_payload()))
await self.socket.ws.send(json.dumps(self._get_send_payload()))
def confirm_unsubscribe(self):
"""

View File

@@ -21,8 +21,8 @@ class Subscription:
such as unsibscribe and subscribe.
"""
def __init__(self, bfxapi, channel_name, symbol, timeframe=None, **kwargs):
self.bfxapi = bfxapi
def __init__(self, socket, channel_name, symbol, timeframe=None, **kwargs):
self.socket = socket
self.channel_name = channel_name
self.symbol = symbol
self.timeframe = timeframe
@@ -34,6 +34,12 @@ class Subscription:
self.sub_id = generate_sub_id()
self.send_payload = self._generate_payload(**kwargs)
def get_key(self):
"""
Generates a unique key string for the subscription
"""
return "{}_{}".format(self.channel_name, self.key or self.symbol)
def confirm_subscription(self, chan_id):
"""
Update the subscription to confirmed state
@@ -48,13 +54,13 @@ class Subscription:
if not self.is_subscribed():
raise Exception("Subscription is not subscribed to websocket")
payload = {'event': 'unsubscribe', 'chanId': self.chan_id}
await self.bfxapi.get_ws().send(json.dumps(payload))
await self.socket.ws.send(json.dumps(payload))
async def subscribe(self):
"""
Send a subscription request to the bitfinex socket
"""
await self.bfxapi.get_ws().send(json.dumps(self._get_send_payload()))
await self.socket.ws.send(json.dumps(self._get_send_payload()))
def confirm_unsubscribe(self):
"""

View File

@@ -98,7 +98,7 @@ class BfxWebsocket(GenericWebsocket):
}
def __init__(self, API_KEY=None, API_SECRET=None, host='wss://api-pub.bitfinex.com/ws/2',
manageOrderBooks=False, dead_man_switch=False, logLevel='INFO', parse_float=float,
manageOrderBooks=False, dead_man_switch=False, ws_capacity=25, logLevel='INFO', parse_float=float,
*args, **kwargs):
self.API_KEY = API_KEY
self.API_SECRET = API_SECRET
@@ -106,12 +106,11 @@ class BfxWebsocket(GenericWebsocket):
self.dead_man_switch = dead_man_switch
self.pendingOrders = {}
self.orderBooks = {}
self.ws_capacity = ws_capacity
# How should we store float values? could also be bfxapi.Decimal
# which is slower but has higher precision.
self.parse_float = parse_float
super(BfxWebsocket, self).__init__(
host, logLevel=logLevel, *args, **kwargs)
super(BfxWebsocket, self).__init__(host, logLevel=logLevel, *args, **kwargs)
self.subscriptionManager = SubscriptionManager(self, logLevel=logLevel)
self.orderManager = OrderManager(self, logLevel=logLevel)
self.wallets = WalletManager()
@@ -145,15 +144,15 @@ class BfxWebsocket(GenericWebsocket):
'conf': self._system_conf_handler
}
async def _ws_system_handler(self, msg):
async def _ws_system_handler(self, socketId, msg):
eType = msg.get('event')
if eType in self._WS_SYSTEM_HANDLERS:
await self._WS_SYSTEM_HANDLERS[eType](msg)
await self._WS_SYSTEM_HANDLERS[eType](socketId, msg)
else:
self.logger.warn(
"Unknown websocket event: '{}' {}".format(eType, msg))
"Unknown websocket event (socketId={}): '{}' {}".format(socketId, eType, msg))
async def _ws_data_handler(self, data, raw_message_str):
async def _ws_data_handler(self, socketId, data, raw_message_str):
dataEvent = data[1]
chan_id = data[0]
@@ -172,13 +171,13 @@ class BfxWebsocket(GenericWebsocket):
self.logger.warn(
"Unknown data event: '{}' {}".format(dataEvent, data))
async def _system_info_handler(self, data):
async def _system_info_handler(self, socketId, data):
self.logger.info(data)
if data.get('serverId', None):
# connection has been established
await self.on_open()
await self.on_open(socketId)
async def _system_conf_handler(self, data):
async def _system_conf_handler(self, socketId, data):
flag = data.get('flags')
status = data.get('status')
if flag not in Flags.strings:
@@ -191,19 +190,21 @@ class BfxWebsocket(GenericWebsocket):
self.logger.error(
"Unable to enable config flag {}".format(flagString))
async def _system_subscribed_handler(self, data):
await self.subscriptionManager.confirm_subscription(data)
async def _system_subscribed_handler(self, socket_id, data):
await self.subscriptionManager.confirm_subscription(socket_id, data)
async def _system_unsubscribe_handler(self, data):
await self.subscriptionManager.confirm_unsubscribe(data)
async def _system_unsubscribe_handler(self, socket_id, data):
await self.subscriptionManager.confirm_unsubscribe(socket_id, data)
async def _system_error_handler(self, data):
async def _system_error_handler(self, socketId, data):
err_string = self.ERRORS[data.get('code', 10000)]
err_string = "{} - {}".format(self.ERRORS[data.get('code', 10000)],
data.get("msg", ""))
err_string = "(socketId={}) {} - {}".format(
socketId,
self.ERRORS[data.get('code', 10000)],
data.get("msg", ""))
self._emit('error', err_string)
async def _system_auth_handler(self, data):
async def _system_auth_handler(self, socketId, data):
if data.get('status') == 'FAILED':
raise AuthError(self.ERRORS[data.get('code')])
else:
@@ -359,53 +360,84 @@ class BfxWebsocket(GenericWebsocket):
self.orderBooks[symbol].update_with(obInfo, orig_raw_message)
self._emit('order_book_update', {'symbol': symbol, 'data': obInfo})
async def on_message(self, message):
async def on_message(self, socketId, message):
self.logger.debug(message)
# convert float values to decimal
msg = json.loads(message, parse_float=self.parse_float)
self._emit('all', msg)
if type(msg) is dict:
# System messages are received as json
await self._ws_system_handler(msg)
await self._ws_system_handler(socketId, msg)
elif type(msg) is list:
# All data messages are received as a list
await self._ws_data_handler(msg, message)
await self._ws_data_handler(socketId, msg, message)
else:
self.logger.warn('Unknown websocket response: {}'.format(msg))
self.logger.warn('Unknown (socketId={}) websocket response: {}'.format(socketId, msg))
async def _ws_authenticate_socket(self):
async def _ws_authenticate_socket(self, socketId):
socket = self.sockets[socketId]
socket.set_authenticated()
jdata = generate_auth_payload(self.API_KEY, self.API_SECRET)
if self.dead_man_switch:
jdata['dms'] = 4
await self.get_ws().send(json.dumps(jdata))
await socket.ws.send(json.dumps(jdata))
async def on_open(self):
async def on_open(self, socket_id):
self.logger.info("Websocket opened.")
self._emit('connected')
if len(self.sockets) == 1:
## only call on first connection
self._emit('connected')
# Orders are simulated in backtest mode
if self.API_KEY and self.API_SECRET:
await self._ws_authenticate_socket()
if self.API_KEY and self.API_SECRET and self.get_authenticated_socket() == None:
await self._ws_authenticate_socket(socket_id)
# enable order book checksums
if self.manageOrderBooks:
await self.enable_flag(Flags.CHECKSUM)
# set any existing subscriptions to not subscribed
self.subscriptionManager.set_all_unsubscribed()
self.subscriptionManager.set_unsubscribed_by_socket(socket_id)
# re-subscribe to existing channels
await self.subscriptionManager.resubscribe_all()
await self.subscriptionManager.resubscribe_by_socket(socket_id)
async def _send_auth_command(self, channel_name, data):
payload = [0, channel_name, None, data]
await self.get_ws().send(json.dumps(payload))
socket = self.get_authenticated_socket()
if socket == None:
raise ValueError("authenticated socket connection not found")
if not socket.isConnected:
raise ValueError("authenticated socket not connected")
await socket.ws.send(json.dumps(payload))
def get_orderbook(self, symbol):
return self.orderBooks.get(symbol, None)
def get_socket_capacity(self, socket_id):
return self.ws_capacity - self.subscriptionManager.get_sub_count_by_socket(socket_id)
def get_most_available_socket(self):
bestId = None
bestCount = 0
for socketId in self.sockets:
cap = self.get_socket_capacity(socketId)
if bestId == None or cap > bestCount:
bestId = socketId
bestCount = cap
return self.sockets[socketId]
def get_total_available_capcity(self):
total = 0
for socketId in self.sockets:
total += self.get_socket_capacity(socketId)
return total
async def enable_flag(self, flag):
payload = {
"event": 'conf',
"flags": flag
}
await self.get_ws().send(json.dumps(payload))
def get_orderbook(self, symbol):
return self.orderBooks.get(symbol, None)
# enable on all sockets
for socket in self.sockets.values():
if socket.isConnected:
await socket.ws.send(json.dumps(payload))
async def subscribe(self, *args, **kwargs):
return await self.subscriptionManager.subscribe(*args, **kwargs)

View File

@@ -6,6 +6,8 @@ import asyncio
import websockets
import socket
import json
import time
from threading import Thread
from pyee import EventEmitter
from ..utils.CustomLogger import CustomLogger
@@ -26,55 +28,126 @@ def is_json(myjson):
return False
return True
class Socket():
def __init__(self, sId):
self.ws = None
self.isConnected = False
self.isAuthenticated = False
self.id = sId
def set_connected(self):
self.isConnected = True
def set_disconnected(self):
self.isConnected = False
def set_authenticated(self):
self.isAuthenticated = True
def set_websocket(self, ws):
self.ws = ws
def _start_event_worker():
async def event_sleep_process():
"""
sleeping process for event emitter to schedule on
"""
while True:
await asyncio.sleep(0)
def start_loop(loop):
asyncio.set_event_loop(loop)
loop.run_until_complete(event_sleep_process())
event_loop = asyncio.new_event_loop()
worker = Thread(target=start_loop, args=(event_loop,))
worker.start()
return event_loop
class GenericWebsocket:
"""
Websocket object used to contain the base functionality of a websocket.
Inlcudes an event emitter and a standard websocket client.
"""
def __init__(self, host, logLevel='INFO', loop=None, max_retries=5):
def __init__(self, host, logLevel='INFO', loop=None, max_retries=5,
create_event_emitter=_start_event_worker):
self.host = host
self.logger = CustomLogger('BfxWebsocket', logLevel=logLevel)
self.loop = loop or asyncio.get_event_loop()
self.events = EventEmitter(
scheduler=asyncio.ensure_future, loop=self.loop)
# overide 'error' event to stop it raising an exception
# self.events.on('error', self.on_error)
self.ws = None
self.max_retries = max_retries
self.attempt_retry = True
self.sockets = {}
# start seperate process for the even emitter
eventLoop = create_event_emitter()
self.events = EventEmitter(scheduler=asyncio.ensure_future, loop=eventLoop)
def run(self):
"""
Run the websocket connection indefinitely
Starte the websocket connection. This functions spawns the initial socket
thread and connection.
"""
self.loop.run_until_complete(self._main(self.host))
self._start_new_socket()
def get_task_executable(self):
"""
Get the run indefinitely asyncio task
"""
return self._main(self.host)
return self._run_socket()
async def _connect(self, host):
async with websockets.connect(host) as websocket:
self.ws = websocket
self.logger.info("Wesocket connected to {}".format(host))
def _start_new_socket(self, socketId=None):
if not socketId:
socketId = len(self.sockets)
def start_loop(loop):
asyncio.set_event_loop(loop)
loop.run_until_complete(self._run_socket())
worker_loop = asyncio.new_event_loop()
worker = Thread(target=start_loop, args=(worker_loop,))
worker.start()
return socketId
def _wait_for_socket(self, socket_id):
"""
Block until the given socket connection is open
"""
while True:
socket = self.sockets.get(socket_id, False)
if socket:
if socket.isConnected and socket.ws:
return
time.sleep(0.01)
async def _connect(self, socket):
async with websockets.connect(self.host) as websocket:
self.sockets[socket.id].set_websocket(websocket)
self.sockets[socket.id].set_connected()
self.logger.info("Wesocket connected to {}".format(self.host))
while True:
await asyncio.sleep(0)
message = await websocket.recv()
await self.on_message(message)
await self.on_message(socket.id, message)
def get_ws(self):
return self.ws
def get_socket(self, socketId):
return self.sockets[socketId]
async def _main(self, host):
def get_authenticated_socket(self):
for socketId in self.sockets:
if self.sockets[socketId].isAuthenticated:
return self.sockets[socketId]
return None
async def _run_socket(self):
retries = 0
sId = len(self.sockets)
s = Socket(sId)
self.sockets[sId] = s
while retries < self.max_retries and self.attempt_retry:
try:
await self._connect(host)
await self._connect(s)
retries = 0
except (ConnectionClosed, socket.error) as e:
self.sockets[sId].set_disconnected()
self._emit('disconnected')
if (not self.attempt_retry):
return

View File

@@ -251,7 +251,6 @@ class OrderManager:
key = None
for k in callback_storage.keys():
if k in idents:
print (callback_storage[k])
key = k
# call all callbacks associated with identifier
for callback in callback_storage[k]:

View File

@@ -22,6 +22,16 @@ class SubscriptionManager:
self.bfxapi = bfxapi
self.logger = CustomLogger('BfxSubscriptionManager', logLevel=logLevel)
def get_sub_count_by_socket(self, socket_id):
count = 0
for sub in self.subscriptions_chanid.values():
if sub.socket.id == socket_id and sub.is_subscribed():
count += 1
for sub in self.pending_subscriptions.values():
if sub.socket.id == socket_id:
count += 1
return count
async def subscribe(self, channel_name, symbol, timeframe=None, **kwargs):
"""
Subscribe to a new channel
@@ -31,41 +41,57 @@ class SubscriptionManager:
@param timeframe: sepecifies the data timeframe between each candle (only required
for the candles channel)
"""
# make sure we dont over subscribe the connection
if self.channel_count() >= MAX_CHANNEL_COUNT:
raise Exception("Subscribe error - max channel count ({0}) reached".format(self.channel_count()))
return False
if self.bfxapi.get_total_available_capcity() < 2:
sId = self.bfxapi._start_new_socket()
self.bfxapi._wait_for_socket(sId)
soc = self.bfxapi.sockets[sId]
socket = self.bfxapi.sockets[sId]
else:
# get the socket with the least amount of subscriptions
socket = self.bfxapi.get_most_available_socket()
# create a new subscription
subscription = Subscription(
self.bfxapi, channel_name, symbol, timeframe, **kwargs)
socket, channel_name, symbol, timeframe, **kwargs)
self.logger.info("Subscribing to channel {}".format(channel_name))
key = "{}_{}".format(channel_name, subscription.key or symbol)
self.pending_subscriptions[key] = subscription
self.pending_subscriptions[subscription.get_key()] = subscription
await subscription.subscribe()
async def confirm_subscription(self, raw_ws_data):
async def confirm_subscription(self, socket_id, raw_ws_data):
symbol = raw_ws_data.get("symbol", None)
channel = raw_ws_data.get("channel")
chan_id = raw_ws_data.get("chanId")
key = raw_ws_data.get("key", None)
get_key = "{}_{}".format(channel, key or symbol)
if chan_id in self.subscriptions_chanid:
# subscription has already existed in the past
p_sub = self.subscriptions_chanid[chan_id]
else:
elif get_key in self.pending_subscriptions:
# has just been created and is pending
p_sub = self.pending_subscriptions[get_key]
# remove from pending list
del self.pending_subscriptions[get_key]
else:
# might have been disconnected, so we need to check if exists
# as subscribed but with a new channel ID
for sub in self.subscriptions_chanid.values():
if sub.get_key() == get_key and not sub.is_subscribed():
# delete old channelId
del self.subscriptions_chanid[sub.chan_id]
p_sub = sub
break
if p_sub is None:
# no sub matches confirmation
self.logger.warn("unknown subscription confirmed {}".format(get_key))
return
p_sub.confirm_subscription(chan_id)
# add to confirmed list
self.subscriptions_chanid[chan_id] = p_sub
self.subscriptions_subid[p_sub.sub_id] = p_sub
self.bfxapi._emit('subscribed', p_sub)
async def confirm_unsubscribe(self, raw_ws_data):
async def confirm_unsubscribe(self, socket_id, raw_ws_data):
chan_id = raw_ws_data.get("chanId")
sub = self.subscriptions_chanid[chan_id]
sub.confirm_unsubscribe()
@@ -78,6 +104,14 @@ class SubscriptionManager:
def get(self, chan_id):
return self.subscriptions_chanid[chan_id]
def set_unsubscribed_by_socket(self, socket_id):
"""
Sets all f the subscriptions ot state 'unsubscribed'
"""
for sub in self.subscriptions_chanid.values():
if sub.socket.id == socket_id:
sub.confirm_unsubscribe()
def set_all_unsubscribed(self):
"""
Sets all f the subscriptions ot state 'unsubscribed'
@@ -144,6 +178,20 @@ class SubscriptionManager:
return
await asyncio.wait(*[task_batch])
async def resubscribe_by_socket(self, socket_id):
"""
Unsubscribe channels on socket and then subscribe to all channels
"""
task_batch = []
for sub in self.subscriptions_chanid.values():
if sub.socket.id == socket_id:
task_batch += [
asyncio.ensure_future(self.resubscribe(sub.chan_id))
]
if len(task_batch) == 0:
return
await asyncio.wait(*[task_batch])
async def resubscribe_all(self):
"""
Unsubscribe and then subscribe to all channels