mirror of
https://github.com/aljazceru/nutshell.git
synced 2025-12-21 19:14:19 +01:00
* fix subscription re-initialization bug * test for regression issue * format * Update cashu/mint/events/client.py Co-authored-by: callebtc <93376500+callebtc@users.noreply.github.com> --------- Co-authored-by: callebtc <93376500+callebtc@users.noreply.github.com>
233 lines
8.4 KiB
Python
233 lines
8.4 KiB
Python
import asyncio
|
|
import json
|
|
from typing import List, Union
|
|
|
|
from fastapi import WebSocket, WebSocketDisconnect
|
|
from loguru import logger
|
|
|
|
from ...core.base import MeltQuote, MintQuote, ProofState
|
|
from ...core.db import Database
|
|
from ...core.json_rpc.base import (
|
|
JSONRPCError,
|
|
JSONRPCErrorCode,
|
|
JSONRPCErrorResponse,
|
|
JSONRPCMethods,
|
|
JSONRPCNotficationParams,
|
|
JSONRPCNotification,
|
|
JSONRPCRequest,
|
|
JSONRPCResponse,
|
|
JSONRPCStatus,
|
|
JSONRPCSubscribeParams,
|
|
JSONRPCSubscriptionKinds,
|
|
JSONRPCUnsubscribeParams,
|
|
JSONRRPCSubscribeResponse,
|
|
)
|
|
from ...core.models import PostMeltQuoteResponse, PostMintQuoteResponse
|
|
from ...core.settings import settings
|
|
from ..crud import LedgerCrud
|
|
from ..db.read import DbReadHelper
|
|
from ..limit import limit_websocket
|
|
from .event_model import LedgerEvent
|
|
|
|
|
|
class LedgerEventClientManager:
|
|
websocket: WebSocket
|
|
subscriptions: dict[
|
|
JSONRPCSubscriptionKinds, dict[str, List[str]]
|
|
] = {} # [kind, [filter, List[subId]]]
|
|
max_subscriptions = 1000
|
|
db_read: DbReadHelper
|
|
|
|
def __init__(self, websocket: WebSocket, db: Database, crud: LedgerCrud):
|
|
self.websocket = websocket
|
|
self.subscriptions = {}
|
|
self.db_read = DbReadHelper(db, crud)
|
|
|
|
async def start(self):
|
|
await self.websocket.accept()
|
|
|
|
while True:
|
|
message = await asyncio.wait_for(
|
|
self.websocket.receive(),
|
|
timeout=settings.mint_websocket_read_timeout,
|
|
)
|
|
message_text = message.get("text")
|
|
|
|
# Check the rate limit
|
|
try:
|
|
limit_websocket(self.websocket)
|
|
except Exception as e:
|
|
logger.error(f"Error: {e}")
|
|
err = JSONRPCErrorResponse(
|
|
error=JSONRPCError(
|
|
code=JSONRPCErrorCode.SERVER_ERROR,
|
|
message=f"Error: {e}",
|
|
),
|
|
id=0,
|
|
)
|
|
await self._send_msg(err)
|
|
continue
|
|
|
|
# Check if message contains text
|
|
if not message_text:
|
|
continue
|
|
|
|
# Parse the JSON data
|
|
try:
|
|
data = json.loads(message_text)
|
|
except json.JSONDecodeError as e:
|
|
logger.error(f"Error decoding JSON: {e}")
|
|
err = JSONRPCErrorResponse(
|
|
error=JSONRPCError(
|
|
code=JSONRPCErrorCode.PARSE_ERROR,
|
|
message=f"Error: {e}",
|
|
),
|
|
id=0,
|
|
)
|
|
await self._send_msg(err)
|
|
continue
|
|
|
|
# Parse the JSONRPCRequest
|
|
try:
|
|
req = JSONRPCRequest.parse_obj(data)
|
|
except Exception as e:
|
|
err = JSONRPCErrorResponse(
|
|
error=JSONRPCError(
|
|
code=JSONRPCErrorCode.INVALID_REQUEST,
|
|
message=f"Error: {e}",
|
|
),
|
|
id=0,
|
|
)
|
|
await self._send_msg(err)
|
|
logger.warning(f"Error handling websocket message: {e}")
|
|
continue
|
|
|
|
# Check if the method is valid
|
|
try:
|
|
JSONRPCMethods(req.method)
|
|
except ValueError:
|
|
err = JSONRPCErrorResponse(
|
|
error=JSONRPCError(
|
|
code=JSONRPCErrorCode.METHOD_NOT_FOUND,
|
|
message=f"Method not found: {req.method}",
|
|
),
|
|
id=req.id,
|
|
)
|
|
await self._send_msg(err)
|
|
continue
|
|
|
|
# Handle the request
|
|
try:
|
|
logger.debug(f"Request: {req.json()}")
|
|
resp = await self._handle_request(req)
|
|
# Send the response
|
|
await self._send_msg(resp)
|
|
except WebSocketDisconnect as e:
|
|
logger.debug(f"Websocket disconnected: {e}")
|
|
raise e
|
|
except Exception as e:
|
|
err = JSONRPCErrorResponse(
|
|
error=JSONRPCError(
|
|
code=JSONRPCErrorCode.INTERNAL_ERROR,
|
|
message=f"Error: {e}",
|
|
),
|
|
id=req.id,
|
|
)
|
|
await self._send_msg(err)
|
|
continue
|
|
|
|
async def _handle_request(self, data: JSONRPCRequest) -> JSONRPCResponse:
|
|
logger.debug(f"Received websocket message: {data}")
|
|
if data.method == JSONRPCMethods.SUBSCRIBE.value:
|
|
subscribe_params = JSONRPCSubscribeParams.parse_obj(data.params)
|
|
self.add_subscription(
|
|
subscribe_params.kind, subscribe_params.filters, subscribe_params.subId
|
|
)
|
|
result = JSONRRPCSubscribeResponse(
|
|
status=JSONRPCStatus.OK,
|
|
subId=subscribe_params.subId,
|
|
)
|
|
return JSONRPCResponse(result=result.dict(), id=data.id)
|
|
elif data.method == JSONRPCMethods.UNSUBSCRIBE.value:
|
|
unsubscribe_params = JSONRPCUnsubscribeParams.parse_obj(data.params)
|
|
self.remove_subscription(unsubscribe_params.subId)
|
|
result = JSONRRPCSubscribeResponse(
|
|
status=JSONRPCStatus.OK,
|
|
subId=unsubscribe_params.subId,
|
|
)
|
|
return JSONRPCResponse(result=result.dict(), id=data.id)
|
|
else:
|
|
raise ValueError(f"Invalid method: {data.method}")
|
|
|
|
async def _send_obj(self, data: dict, subId: str):
|
|
resp = JSONRPCNotification(
|
|
method=JSONRPCMethods.SUBSCRIBE.value,
|
|
params=JSONRPCNotficationParams(subId=subId, payload=data).dict(),
|
|
)
|
|
await self._send_msg(resp)
|
|
|
|
async def _send_msg(
|
|
self, data: Union[JSONRPCResponse, JSONRPCNotification, JSONRPCErrorResponse]
|
|
):
|
|
logger.debug(f"Sending websocket message: {data.json()}")
|
|
await self.websocket.send_text(data.json())
|
|
|
|
def add_subscription(
|
|
self,
|
|
kind: JSONRPCSubscriptionKinds,
|
|
filters: List[str],
|
|
subId: str,
|
|
) -> None:
|
|
if kind not in self.subscriptions:
|
|
self.subscriptions[kind] = {}
|
|
|
|
if len(self.subscriptions[kind]) >= self.max_subscriptions:
|
|
raise ValueError("Max subscriptions reached")
|
|
|
|
for f in filters:
|
|
logger.debug(f"Adding subscription {subId} for filter {f}")
|
|
self.subscriptions[kind].setdefault(f, []).append(subId)
|
|
# Initialize the subscription
|
|
asyncio.create_task(self._init_subscription(subId, f, kind))
|
|
|
|
def remove_subscription(self, subId: str) -> None:
|
|
for kind, sub_filters in self.subscriptions.items():
|
|
for filter, subs in sub_filters.items():
|
|
for sub in subs:
|
|
if sub == subId:
|
|
logger.debug(
|
|
f"Removing subscription {subId} for filter {filter}"
|
|
)
|
|
self.subscriptions[kind][filter].remove(sub)
|
|
return
|
|
raise ValueError(f"Subscription not found: {subId}")
|
|
|
|
def serialize_event(self, event: LedgerEvent) -> dict:
|
|
if isinstance(event, MintQuote):
|
|
return_dict = PostMintQuoteResponse.parse_obj(event.dict()).dict()
|
|
elif isinstance(event, MeltQuote):
|
|
return_dict = PostMeltQuoteResponse.parse_obj(event.dict()).dict()
|
|
elif isinstance(event, ProofState):
|
|
return_dict = event.dict(exclude_unset=True, exclude_none=True)
|
|
return return_dict
|
|
|
|
async def _init_subscription(
|
|
self, subId: str, filter: str, kind: JSONRPCSubscriptionKinds
|
|
):
|
|
if kind == JSONRPCSubscriptionKinds.BOLT11_MINT_QUOTE:
|
|
mint_quote = await self.db_read.crud.get_mint_quote(
|
|
quote_id=filter, db=self.db_read.db
|
|
)
|
|
if mint_quote:
|
|
await self._send_obj(mint_quote.dict(), subId)
|
|
elif kind == JSONRPCSubscriptionKinds.BOLT11_MELT_QUOTE:
|
|
melt_quote = await self.db_read.crud.get_melt_quote(
|
|
quote_id=filter, db=self.db_read.db
|
|
)
|
|
if melt_quote:
|
|
await self._send_obj(melt_quote.dict(), subId)
|
|
elif kind == JSONRPCSubscriptionKinds.PROOF_STATE:
|
|
proofs = await self.db_read.get_proofs_states(Ys=[filter])
|
|
if len(proofs):
|
|
await self._send_obj(proofs[0].dict(), subId)
|