Files
nutshell/cashu/mint/events/client.py
callebtc e846acf946 Mint: add websockets for quote updates (#413)
* add websockets for quote updates

* add test (not working)

* wip: emit events to everyone

* wip: emit events to everyone

* wip, lots of things broken but invoice callback works

* wip

* add wip files

* tests almost passing

* add task

* refactor nut constants

* startup fix

* works with old mints

* wip cli

* fix mypy

* remove automatic invoice test now with websockets

* remove comment

* better logging

* send back response

* add rate limiter to websocket

* add rate limiter to subscriptions

* refactor websocket ratelimit

* websocket tests

* subscription kinds

* doesn't start

* remove circular import

* update

* fix mypy

* move test file in test because it fails if it runs later... dunno why

* adjust websocket NUT-06 settings

* local import and small fix

* disable websockets in CLI if "no_check" is selected

* move subscription test to where it was

* check proof state with callback, add tests

* tests: run mint fixture per module instead of per session

* subscription command name fix

* test per session again

* update test race conditions

* fix tests

* clean up

* tmp

* fix db issues and remove cached secrets

* fix tests

* blindly try pipeline

* remove comments

* comments
2024-06-25 19:20:03 +02:00

232 lines
8.4 KiB
Python

import asyncio
import json
from typing import List, Union
from fastapi import WebSocket
from loguru import logger
from ...core.base import MeltQuote, MintQuote, ProofState
from ...core.db import Database
from ...core.json_rpc.base import (
JSONRPCError,
JSONRPCErrorCode,
JSONRPCErrorResponse,
JSONRPCMethods,
JSONRPCNotficationParams,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
JSONRPCStatus,
JSONRPCSubscribeParams,
JSONRPCSubscriptionKinds,
JSONRPCUnsubscribeParams,
JSONRRPCSubscribeResponse,
)
from ...core.models import PostMeltQuoteResponse, PostMintQuoteResponse
from ...core.settings import settings
from ..crud import LedgerCrud
from ..db.read import DbReadHelper
from ..limit import limit_websocket
from .event_model import LedgerEvent
class LedgerEventClientManager:
    """Serve one websocket client: JSON-RPC requests in, notifications out.

    The manager accepts the websocket, reads JSON-RPC requests from it
    (subscribe / unsubscribe), keeps track of which subscription ids are
    registered for which filter of which subscription kind, and pushes
    notifications back to the client.
    """

    websocket: WebSocket
    # NOTE: this class-level default is replaced by a fresh dict per instance
    # in __init__, so the mapping is never shared between clients.
    subscriptions: dict[
        JSONRPCSubscriptionKinds, dict[str, List[str]]
    ] = {}  # [kind, [filter, List[subId]]]
    # Upper bound on the number of distinct filters tracked per kind.
    max_subscriptions = 1000
    db_read: DbReadHelper

    def __init__(self, websocket: WebSocket, db: Database, crud: LedgerCrud):
        """Bind the manager to `websocket` and build its database read helper."""
        self.websocket = websocket
        self.subscriptions = {}
        self.db_read = DbReadHelper(db, crud)
        # Strong references to in-flight background tasks. The event loop only
        # keeps weak references, so an unreferenced task may be garbage
        # collected before it finishes.
        self._background_tasks: set = set()

    async def start(self):
        """Accept the websocket and process incoming requests forever.

        Each received message goes through, in order: the rate-limit check,
        an empty-text check, JSON decoding, `JSONRPCRequest` validation,
        method validation, and finally `_handle_request`. A failure at any of
        these stages is reported to the client as a JSON-RPC error response
        and the loop moves on to the next message.

        The loop contains no `break`/`return`: this coroutine only ends when
        an exception propagates out of it, e.g. from `asyncio.wait_for` when
        no message arrives within `settings.mint_websocket_read_timeout`, or
        from the websocket itself.
        """
        await self.websocket.accept()

        while True:
            message = await asyncio.wait_for(
                self.websocket.receive(),
                timeout=settings.mint_websocket_read_timeout,
            )
            message_text = message.get("text")

            # Check the rate limit. Any exception raised by the limiter is
            # reported to the client and the message is dropped.
            try:
                limit_websocket(self.websocket)
            except Exception as e:
                logger.error(f"Error: {e}")
                await self._send_msg(
                    self._error_response(
                        JSONRPCErrorCode.SERVER_ERROR, f"Error: {e}"
                    )
                )
                continue

            # Ignore messages that carry no text payload.
            if not message_text:
                continue

            # Parse the JSON data.
            try:
                data = json.loads(message_text)
            except json.JSONDecodeError as e:
                logger.error(f"Error decoding JSON: {e}")
                await self._send_msg(
                    self._error_response(
                        JSONRPCErrorCode.PARSE_ERROR, f"Error: {e}"
                    )
                )
                continue

            # Parse the JSONRPCRequest. No request id is available yet, so
            # the error is sent with the default id 0.
            try:
                req = JSONRPCRequest.parse_obj(data)
            except Exception as e:
                await self._send_msg(
                    self._error_response(
                        JSONRPCErrorCode.INVALID_REQUEST, f"Error: {e}"
                    )
                )
                logger.warning(f"Error handling websocket message: {e}")
                continue

            # Check that the method is one of the known JSON-RPC methods.
            try:
                JSONRPCMethods(req.method)
            except ValueError:
                await self._send_msg(
                    self._error_response(
                        JSONRPCErrorCode.METHOD_NOT_FOUND,
                        f"Method not found: {req.method}",
                        req.id,
                    )
                )
                continue

            # Handle the request and send back the response.
            try:
                logger.debug(f"Request: {req.json()}")
                resp = await self._handle_request(req)
                await self._send_msg(resp)
            except Exception as e:
                await self._send_msg(
                    self._error_response(
                        JSONRPCErrorCode.INTERNAL_ERROR, f"Error: {e}", req.id
                    )
                )
                continue

    @staticmethod
    def _error_response(
        code: JSONRPCErrorCode, message: str, request_id=0
    ) -> JSONRPCErrorResponse:
        """Build a JSON-RPC error response with the given code, message and id."""
        return JSONRPCErrorResponse(
            error=JSONRPCError(code=code, message=message),
            id=request_id,
        )

    async def _handle_request(self, data: JSONRPCRequest) -> JSONRPCResponse:
        """Execute a validated request and return the response to send.

        Args:
            data: Request whose `method` is either subscribe or unsubscribe.

        Returns:
            A `JSONRPCResponse` carrying an OK status and the subscription id.

        Raises:
            ValueError: If the method is not handled here, or propagated from
                `add_subscription` / `remove_subscription`.
        """
        logger.debug(f"Received websocket message: {data}")

        if data.method == JSONRPCMethods.SUBSCRIBE.value:
            subscribe_params = JSONRPCSubscribeParams.parse_obj(data.params)
            self.add_subscription(
                subscribe_params.kind,
                subscribe_params.filters,
                subscribe_params.subId,
            )
            result = JSONRRPCSubscribeResponse(
                status=JSONRPCStatus.OK,
                subId=subscribe_params.subId,
            )
            return JSONRPCResponse(result=result.dict(), id=data.id)

        if data.method == JSONRPCMethods.UNSUBSCRIBE.value:
            unsubscribe_params = JSONRPCUnsubscribeParams.parse_obj(data.params)
            self.remove_subscription(unsubscribe_params.subId)
            result = JSONRRPCSubscribeResponse(
                status=JSONRPCStatus.OK,
                subId=unsubscribe_params.subId,
            )
            return JSONRPCResponse(result=result.dict(), id=data.id)

        raise ValueError(f"Invalid method: {data.method}")

    async def _send_obj(self, data: dict, subId: str):
        """Send `data` to the client as a notification for subscription `subId`."""
        resp = JSONRPCNotification(
            method=JSONRPCMethods.SUBSCRIBE.value,
            params=JSONRPCNotficationParams(subId=subId, payload=data).dict(),
        )
        await self._send_msg(resp)

    async def _send_msg(
        self, data: Union[JSONRPCResponse, JSONRPCNotification, JSONRPCErrorResponse]
    ):
        """Serialize `data` to JSON and send it as a websocket text frame."""
        payload = data.json()
        logger.debug(f"Sending websocket message: {payload}")
        await self.websocket.send_text(payload)

    def add_subscription(
        self,
        kind: JSONRPCSubscriptionKinds,
        filters: List[str],
        subId: str,
    ) -> None:
        """Register `subId` for every filter in `filters` under `kind`.

        For each filter a background task is started that sends the current
        state of the filtered object to the client (see `_init_subscription`).
        Must be called while an event loop is running.

        Args:
            kind: Subscription kind the filters belong to.
            filters: Filter values (e.g. quote ids) to subscribe to.
            subId: Client-chosen subscription id.

        Raises:
            ValueError: If the per-kind filter limit (`max_subscriptions`) is
                already reached or would be exceeded by this request. Nothing
                is registered in that case.
        """
        kind_subs = self.subscriptions.setdefault(kind, {})

        if len(kind_subs) >= self.max_subscriptions:
            raise ValueError("Max subscriptions reached")

        # Validate the whole request up front so that a single call with many
        # filters cannot push the table past the limit, and a rejected call
        # leaves no partial registration behind.
        new_filters = {f for f in filters if f not in kind_subs}
        if len(kind_subs) + len(new_filters) > self.max_subscriptions:
            raise ValueError("Max subscriptions reached")

        for sub_filter in filters:
            # Look the filter up in the per-kind table (not the top-level
            # one), so ids already subscribed to this filter are kept.
            sub_ids = kind_subs.setdefault(sub_filter, [])
            logger.debug(f"Adding subscription {subId} for filter {sub_filter}")
            if subId not in sub_ids:
                sub_ids.append(subId)
            # Initialize the subscription
            self._spawn_background(self._init_subscription(subId, sub_filter, kind))

    def _spawn_background(self, coro) -> None:
        """Schedule coroutine `coro` as a task and keep it referenced until done."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._on_background_task_done)

    def _on_background_task_done(self, task: asyncio.Task) -> None:
        """Drop the reference to a finished task and log its failure, if any."""
        self._background_tasks.discard(task)
        if task.cancelled():
            return
        error = task.exception()
        if error is not None:
            logger.warning(f"Error initializing subscription: {error}")

    def remove_subscription(self, subId: str) -> None:
        """Unregister `subId` from every filter of every kind.

        Filters left without any subscription id are deleted so they no
        longer count towards `max_subscriptions`.

        Raises:
            ValueError: If `subId` is not registered anywhere.
        """
        found = False
        for sub_filters in self.subscriptions.values():
            # Iterate over a snapshot because empty entries are deleted below.
            for sub_filter, sub_ids in list(sub_filters.items()):
                if subId not in sub_ids:
                    continue
                found = True
                logger.debug(f"Removing subscription {subId} for filter {sub_filter}")
                sub_ids[:] = [sub for sub in sub_ids if sub != subId]
                if not sub_ids:
                    del sub_filters[sub_filter]

        if not found:
            raise ValueError(f"Subscription not found: {subId}")

    def serialize_event(self, event: LedgerEvent) -> dict:
        """Convert a ledger event into the dict sent to subscribers.

        Mint and melt quotes are converted through their respective response
        models; proof states are dumped without unset or `None` fields.

        Raises:
            ValueError: If `event` is none of the supported event types.
        """
        if isinstance(event, MintQuote):
            return PostMintQuoteResponse.parse_obj(event.dict()).dict()
        if isinstance(event, MeltQuote):
            return PostMeltQuoteResponse.parse_obj(event.dict()).dict()
        if isinstance(event, ProofState):
            return event.dict(exclude_unset=True, exclude_none=True)
        raise ValueError(f"Unsupported event type: {type(event).__name__}")

    async def _init_subscription(
        self, subId: str, filter: str, kind: JSONRPCSubscriptionKinds
    ):
        """Send the current state of the object matching `filter` to the client.

        Looks up the mint quote, melt quote or proof state (depending on
        `kind`) and, if one is found, sends it as a notification for `subId`.
        Nothing is sent when the lookup returns no result or `kind` is not
        one of the three handled kinds.
        """
        # NOTE(review): quotes are sent here as the raw `.dict()` of the
        # stored object, whereas `serialize_event` converts them through the
        # response models first - confirm this difference is intended.
        if kind == JSONRPCSubscriptionKinds.BOLT11_MINT_QUOTE:
            mint_quote = await self.db_read.crud.get_mint_quote(
                quote_id=filter, db=self.db_read.db
            )
            if mint_quote:
                await self._send_obj(mint_quote.dict(), subId)
        elif kind == JSONRPCSubscriptionKinds.BOLT11_MELT_QUOTE:
            melt_quote = await self.db_read.crud.get_melt_quote(
                quote_id=filter, db=self.db_read.db
            )
            if melt_quote:
                await self._send_obj(melt_quote.dict(), subId)
        elif kind == JSONRPCSubscriptionKinds.PROOF_STATE:
            proofs = await self.db_read.get_proofs_states(Ys=[filter])
            if len(proofs):
                await self._send_obj(proofs[0].dict(), subId)