Files
nutshell/cashu/mint/events/client.py
lollerfirst 3b4f5b56a0 fix: subscription re-init bug (#781)
* fix subscription re-initialization bug

* test for regression issue

* format

* Update cashu/mint/events/client.py

Co-authored-by: callebtc <93376500+callebtc@users.noreply.github.com>

---------

Co-authored-by: callebtc <93376500+callebtc@users.noreply.github.com>
2025-09-08 18:20:21 +02:00

233 lines
8.4 KiB
Python

import asyncio
import json
from typing import List, Union
from fastapi import WebSocket, WebSocketDisconnect
from loguru import logger
from ...core.base import MeltQuote, MintQuote, ProofState
from ...core.db import Database
from ...core.json_rpc.base import (
JSONRPCError,
JSONRPCErrorCode,
JSONRPCErrorResponse,
JSONRPCMethods,
JSONRPCNotficationParams,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
JSONRPCStatus,
JSONRPCSubscribeParams,
JSONRPCSubscriptionKinds,
JSONRPCUnsubscribeParams,
JSONRRPCSubscribeResponse,
)
from ...core.models import PostMeltQuoteResponse, PostMintQuoteResponse
from ...core.settings import settings
from ..crud import LedgerCrud
from ..db.read import DbReadHelper
from ..limit import limit_websocket
from .event_model import LedgerEvent
class LedgerEventClientManager:
    """Serve the JSON-RPC subscription session of a single websocket client.

    The manager accepts the websocket, reads JSON-RPC requests in a loop,
    keeps track of which subscription ids are registered for which
    ``(kind, filter)`` pair and pushes notifications to the client.
    """

    websocket: WebSocket
    # NOTE: the class-level dict is shadowed by a fresh per-instance dict in
    # ``__init__``; the default is kept only for backward compatibility.
    subscriptions: dict[
        JSONRPCSubscriptionKinds, dict[str, List[str]]
    ] = {}  # [kind, [filter, List[subId]]]
    # Upper bound on the number of distinct filters tracked per kind.
    max_subscriptions = 1000
    db_read: DbReadHelper

    def __init__(self, websocket: WebSocket, db: Database, crud: LedgerCrud):
        """Bind the manager to ``websocket`` and build its DB read helper."""
        self.websocket = websocket
        self.subscriptions = {}
        self.db_read = DbReadHelper(db, crud)
        # Strong references to in-flight initialisation tasks. The event loop
        # only keeps weak references to tasks, so an unreferenced task may be
        # garbage-collected before it finishes.
        self._init_tasks: set[asyncio.Task] = set()

    async def start(self):
        """Accept the websocket and process incoming requests until it ends.

        Each received message is rate-limited, decoded as JSON, validated as
        a ``JSONRPCRequest`` with a known method and then dispatched. Every
        failure along the way is reported to the client as a JSON-RPC error
        response and the loop keeps running.

        Raises:
            WebSocketDisconnect: re-raised when it occurs while handling a
                request or sending its response.
            Exception: anything raised by the read itself (including the
                timeout raised by ``asyncio.wait_for``) propagates to the
                caller and ends the loop.
        """
        await self.websocket.accept()
        while True:
            message = await asyncio.wait_for(
                self.websocket.receive(),
                timeout=settings.mint_websocket_read_timeout,
            )
            message_text = message.get("text")

            # Check the rate limit
            try:
                limit_websocket(self.websocket)
            except Exception as e:
                logger.error(f"Error: {e}")
                await self._send_error(
                    JSONRPCErrorCode.SERVER_ERROR, f"Error: {e}", 0
                )
                continue

            # Check if message contains text
            if not message_text:
                continue

            # Parse the JSON data
            try:
                data = json.loads(message_text)
            except json.JSONDecodeError as e:
                logger.error(f"Error decoding JSON: {e}")
                await self._send_error(
                    JSONRPCErrorCode.PARSE_ERROR, f"Error: {e}", 0
                )
                continue

            # Parse the JSONRPCRequest
            try:
                req = JSONRPCRequest.parse_obj(data)
            except Exception as e:
                await self._send_error(
                    JSONRPCErrorCode.INVALID_REQUEST, f"Error: {e}", 0
                )
                logger.warning(f"Error handling websocket message: {e}")
                continue

            # Check if the method is valid
            try:
                JSONRPCMethods(req.method)
            except ValueError:
                await self._send_error(
                    JSONRPCErrorCode.METHOD_NOT_FOUND,
                    f"Method not found: {req.method}",
                    req.id,
                )
                continue

            # Handle the request
            try:
                logger.debug(f"Request: {req.json()}")
                resp = await self._handle_request(req)
                # Send the response
                await self._send_msg(resp)
            except WebSocketDisconnect as e:
                logger.debug(f"Websocket disconnected: {e}")
                raise
            except Exception as e:
                await self._send_error(
                    JSONRPCErrorCode.INTERNAL_ERROR, f"Error: {e}", req.id
                )
                continue

    async def _send_error(self, code: JSONRPCErrorCode, message: str, req_id) -> None:
        """Send a JSON-RPC error response with the given code, text and id."""
        await self._send_msg(
            JSONRPCErrorResponse(
                error=JSONRPCError(code=code, message=message),
                id=req_id,
            )
        )

    async def _handle_request(self, data: JSONRPCRequest) -> JSONRPCResponse:
        """Execute a subscribe/unsubscribe request and build its response.

        Raises:
            ValueError: if the method is neither subscribe nor unsubscribe,
                or if the subscription bookkeeping rejects the request.
        """
        logger.debug(f"Received websocket message: {data}")
        if data.method == JSONRPCMethods.SUBSCRIBE.value:
            subscribe_params = JSONRPCSubscribeParams.parse_obj(data.params)
            self.add_subscription(
                subscribe_params.kind, subscribe_params.filters, subscribe_params.subId
            )
            result = JSONRRPCSubscribeResponse(
                status=JSONRPCStatus.OK,
                subId=subscribe_params.subId,
            )
            return JSONRPCResponse(result=result.dict(), id=data.id)
        elif data.method == JSONRPCMethods.UNSUBSCRIBE.value:
            unsubscribe_params = JSONRPCUnsubscribeParams.parse_obj(data.params)
            self.remove_subscription(unsubscribe_params.subId)
            result = JSONRRPCSubscribeResponse(
                status=JSONRPCStatus.OK,
                subId=unsubscribe_params.subId,
            )
            return JSONRPCResponse(result=result.dict(), id=data.id)
        else:
            raise ValueError(f"Invalid method: {data.method}")

    async def _send_obj(self, data: dict, subId: str):
        """Wrap ``data`` in a notification for ``subId`` and send it."""
        resp = JSONRPCNotification(
            method=JSONRPCMethods.SUBSCRIBE.value,
            params=JSONRPCNotficationParams(subId=subId, payload=data).dict(),
        )
        await self._send_msg(resp)

    async def _send_msg(
        self, data: Union[JSONRPCResponse, JSONRPCNotification, JSONRPCErrorResponse]
    ):
        """Serialise ``data`` to JSON and send it as a websocket text frame."""
        payload = data.json()
        logger.debug(f"Sending websocket message: {payload}")
        await self.websocket.send_text(payload)

    def add_subscription(
        self,
        kind: JSONRPCSubscriptionKinds,
        filters: List[str],
        subId: str,
    ) -> None:
        """Register ``subId`` for every filter in ``filters`` under ``kind``.

        For each filter a background task is scheduled that sends the current
        state of the subscribed object to the client. Must be called while an
        event loop is running, because it uses ``asyncio.create_task``.

        Raises:
            ValueError: if the number of distinct filters tracked for
                ``kind`` has reached, or would exceed, ``max_subscriptions``.
                Nothing is registered in that case.
        """
        known_filters = self.subscriptions.setdefault(kind, {})
        # Count only filters that would create a new entry, so that a single
        # request carrying many filters cannot push the map past the limit.
        new_filter_count = len({f for f in filters if f not in known_filters})
        if (
            len(known_filters) >= self.max_subscriptions
            or len(known_filters) + new_filter_count > self.max_subscriptions
        ):
            raise ValueError("Max subscriptions reached")

        for f in filters:
            logger.debug(f"Adding subscription {subId} for filter {f}")
            known_filters.setdefault(f, []).append(subId)
            # Initialize the subscription
            task = asyncio.create_task(self._init_subscription(subId, f, kind))
            # Keep the task alive until it completes, then drop the reference.
            self._init_tasks.add(task)
            task.add_done_callback(self._init_tasks.discard)

    def remove_subscription(self, subId: str) -> None:
        """Remove ``subId`` from every filter it is registered for.

        Filters left without any subscriber are deleted so they no longer
        count towards ``max_subscriptions``.

        Raises:
            ValueError: if ``subId`` is not registered for any filter.
        """
        found = False
        for sub_filters in self.subscriptions.values():
            # Iterate over a snapshot of the keys: entries may be deleted below.
            for filter_key in list(sub_filters):
                subs = sub_filters[filter_key]
                if subId not in subs:
                    continue
                logger.debug(f"Removing subscription {subId} for filter {filter_key}")
                # Mutate in place so holders of the list see the removal.
                subs[:] = [sub for sub in subs if sub != subId]
                if not subs:
                    del sub_filters[filter_key]
                found = True
        if not found:
            raise ValueError(f"Subscription not found: {subId}")

    def serialize_event(self, event: LedgerEvent) -> dict:
        """Convert a ledger event into its dict payload.

        Mint and melt quotes are converted through their response models;
        proof states are dumped without unset or ``None`` fields.

        Raises:
            TypeError: if ``event`` is not a ``MintQuote``, ``MeltQuote`` or
                ``ProofState``.
        """
        if isinstance(event, MintQuote):
            return PostMintQuoteResponse.parse_obj(event.dict()).dict()
        if isinstance(event, MeltQuote):
            return PostMeltQuoteResponse.parse_obj(event.dict()).dict()
        if isinstance(event, ProofState):
            return event.dict(exclude_unset=True, exclude_none=True)
        raise TypeError(f"Unsupported event type: {type(event).__name__}")

    async def _init_subscription(
        self, subId: str, filter: str, kind: JSONRPCSubscriptionKinds
    ):
        """Send the current state of the object matching ``filter`` to ``subId``.

        Nothing is sent when the lookup returns no object or when ``kind`` is
        not one of the three handled kinds.
        """
        # NOTE(review): payloads here are the raw ``.dict()`` of the stored
        # object rather than the output of ``serialize_event`` - confirm that
        # this difference is intended.
        if kind == JSONRPCSubscriptionKinds.BOLT11_MINT_QUOTE:
            mint_quote = await self.db_read.crud.get_mint_quote(
                quote_id=filter, db=self.db_read.db
            )
            if mint_quote:
                await self._send_obj(mint_quote.dict(), subId)
        elif kind == JSONRPCSubscriptionKinds.BOLT11_MELT_QUOTE:
            melt_quote = await self.db_read.crud.get_melt_quote(
                quote_id=filter, db=self.db_read.db
            )
            if melt_quote:
                await self._send_obj(melt_quote.dict(), subId)
        elif kind == JSONRPCSubscriptionKinds.PROOF_STATE:
            proofs = await self.db_read.get_proofs_states(Ys=[filter])
            if proofs:
                await self._send_obj(proofs[0].dict(), subId)