mirror of
https://github.com/aljazceru/nutshell.git
synced 2025-12-20 10:34:20 +01:00
Bump SQLAlchemy to 2.0 (#626)
* `SQLALCHEMY_WARN_20=1` fixed all removed warnings. * fix some mypy errors * fix fetchone * make format * ignore annotations * let's try like this? * remove * make format * Update pyproject.toml Co-authored-by: Pavol Rusnak <pavol@rusnak.io> * extract _mapping in fetchone() and fetchall() + fix poetry lock * fix * make format * fix integer indexing of row fields * Update cashu/mint/crud.py --------- Co-authored-by: Pavol Rusnak <pavol@rusnak.io> Co-authored-by: callebtc <93376500+callebtc@users.noreply.github.com>
This commit is contained in:
@@ -647,7 +647,6 @@ class WalletKeyset:
|
||||
int(amount): PublicKey(bytes.fromhex(hex_key), raw=True)
|
||||
for amount, hex_key in dict(json.loads(serialized)).items()
|
||||
}
|
||||
|
||||
return cls(
|
||||
id=row["id"],
|
||||
unit=row["unit"],
|
||||
|
||||
@@ -10,7 +10,8 @@ from loguru import logger
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import NullPool, QueuePool
|
||||
from sqlalchemy.pool import AsyncAdaptedQueuePool, NullPool
|
||||
from sqlalchemy.sql.expression import TextClause
|
||||
|
||||
from cashu.core.settings import settings
|
||||
|
||||
@@ -64,7 +65,7 @@ class Compat:
|
||||
def table_with_schema(self, table: str):
|
||||
return f"{self.references_schema if self.schema else ''}{table}"
|
||||
|
||||
|
||||
# https://docs.sqlalchemy.org/en/14/core/connections.html#sqlalchemy.engine.CursorResult
|
||||
class Connection(Compat):
|
||||
def __init__(self, conn: AsyncSession, txn, typ, name, schema):
|
||||
self.conn = conn
|
||||
@@ -73,19 +74,20 @@ class Connection(Compat):
|
||||
self.name = name
|
||||
self.schema = schema
|
||||
|
||||
def rewrite_query(self, query) -> str:
|
||||
def rewrite_query(self, query) -> TextClause:
|
||||
if self.type in {POSTGRES, COCKROACH}:
|
||||
query = query.replace("%", "%%")
|
||||
query = query.replace("?", "%s")
|
||||
return text(query)
|
||||
|
||||
async def fetchall(self, query: str, values: dict = {}) -> list:
|
||||
async def fetchall(self, query: str, values: dict = {}):
|
||||
result = await self.conn.execute(self.rewrite_query(query), values)
|
||||
return result.all()
|
||||
return [r._mapping for r in result.all()] # will return [] if result list is empty
|
||||
|
||||
async def fetchone(self, query: str, values: dict = {}):
|
||||
result = await self.conn.execute(self.rewrite_query(query), values)
|
||||
return result.fetchone()
|
||||
r = result.fetchone()
|
||||
return r._mapping if r is not None else None
|
||||
|
||||
async def execute(self, query: str, values: dict = {}):
|
||||
return await self.conn.execute(self.rewrite_query(query), values)
|
||||
@@ -132,9 +134,9 @@ class Database(Compat):
|
||||
if not settings.db_connection_pool:
|
||||
kwargs["poolclass"] = NullPool
|
||||
elif self.type == POSTGRES:
|
||||
kwargs["poolclass"] = QueuePool
|
||||
kwargs["pool_size"] = 50
|
||||
kwargs["max_overflow"] = 100
|
||||
kwargs["poolclass"] = AsyncAdaptedQueuePool # type: ignore[assignment]
|
||||
kwargs["pool_size"] = 50 # type: ignore[assignment]
|
||||
kwargs["max_overflow"] = 100 # type: ignore[assignment]
|
||||
|
||||
self.engine = create_async_engine(database_uri, **kwargs)
|
||||
self.async_session = sessionmaker(
|
||||
@@ -281,12 +283,13 @@ class Database(Compat):
|
||||
async def fetchall(self, query: str, values: dict = {}) -> list:
|
||||
async with self.connect() as conn:
|
||||
result = await conn.execute(query, values)
|
||||
return result.all()
|
||||
return [r._mapping for r in result.all()]
|
||||
|
||||
async def fetchone(self, query: str, values: dict = {}):
|
||||
async with self.connect() as conn:
|
||||
result = await conn.execute(query, values)
|
||||
return result.fetchone()
|
||||
r = result.fetchone()
|
||||
return r._mapping if r is not None else None
|
||||
|
||||
async def execute(self, query: str, values: dict = {}):
|
||||
async with self.connect() as conn:
|
||||
|
||||
@@ -104,6 +104,6 @@ async def migrate_databases(db: Database, migrations_module):
|
||||
f"SELECT * FROM {db.table_with_schema('dbversions')}"
|
||||
)
|
||||
rows = result.all()
|
||||
current_versions = {row["db"]: row["version"] for row in rows}
|
||||
current_versions = {row._mapping["db"]: row._mapping["version"] for row in rows}
|
||||
matcher = re.compile(r"^m(\d\d\d)_")
|
||||
await run_migration(db, migrations_module)
|
||||
|
||||
@@ -691,7 +691,10 @@ class LedgerCrudSqlite(LedgerCrud):
|
||||
"""
|
||||
)
|
||||
assert row, "Balance not found"
|
||||
return int(row[0])
|
||||
|
||||
# sqlalchemy index of first element
|
||||
key = next(iter(row))
|
||||
return int(row[key])
|
||||
|
||||
async def get_keyset(
|
||||
self,
|
||||
|
||||
@@ -377,7 +377,7 @@ async def bump_secret_derivation(
|
||||
)
|
||||
counter = 0
|
||||
else:
|
||||
counter = int(rows[0])
|
||||
counter = int(rows["counter"])
|
||||
|
||||
if not skip:
|
||||
await (conn or db).execute(
|
||||
@@ -437,8 +437,8 @@ async def get_seed_and_mnemonic(
|
||||
)
|
||||
return (
|
||||
(
|
||||
row[0],
|
||||
row[1],
|
||||
row["seed"],
|
||||
row["mnemonic"],
|
||||
)
|
||||
if row
|
||||
else None
|
||||
|
||||
1291
poetry.lock
generated
1291
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -7,7 +7,7 @@ license = "MIT"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.8.1"
|
||||
SQLAlchemy = {version = "1.4.52", extras = ["asyncio"]}
|
||||
SQLAlchemy = {version = "^2.0.35", extras = ["asyncio"]}
|
||||
click = "^8.1.7"
|
||||
pydantic = "^1.10.2"
|
||||
bech32 = "^1.2.0"
|
||||
|
||||
Reference in New Issue
Block a user