Bump SQLAlchemy to 2.0 (#626)

* `SQLALCHEMY_WARN_20=1` fixed all removed warnings.

* fix some mypy errors

* fix fetchone

* make format

* ignore annotations

* let's try like this?

* remove

* make format

* Update pyproject.toml

Co-authored-by: Pavol Rusnak <pavol@rusnak.io>

* extract _mapping in fetchone() and fetchall() + fix poetry lock

* fix

* make format

* fix integer indexing of row fields

* Update cashu/mint/crud.py

---------

Co-authored-by: Pavol Rusnak <pavol@rusnak.io>
Co-authored-by: callebtc <93376500+callebtc@users.noreply.github.com>
This commit is contained in:
lollerfirst
2024-10-05 13:32:32 +02:00
committed by GitHub
parent 7fdca3b1a1
commit c5ccf65e4d
7 changed files with 719 additions and 613 deletions

View File

@@ -647,7 +647,6 @@ class WalletKeyset:
int(amount): PublicKey(bytes.fromhex(hex_key), raw=True) int(amount): PublicKey(bytes.fromhex(hex_key), raw=True)
for amount, hex_key in dict(json.loads(serialized)).items() for amount, hex_key in dict(json.loads(serialized)).items()
} }
return cls( return cls(
id=row["id"], id=row["id"],
unit=row["unit"], unit=row["unit"],

View File

@@ -10,7 +10,8 @@ from loguru import logger
from sqlalchemy import text from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool, QueuePool from sqlalchemy.pool import AsyncAdaptedQueuePool, NullPool
from sqlalchemy.sql.expression import TextClause
from cashu.core.settings import settings from cashu.core.settings import settings
@@ -64,7 +65,7 @@ class Compat:
def table_with_schema(self, table: str): def table_with_schema(self, table: str):
return f"{self.references_schema if self.schema else ''}{table}" return f"{self.references_schema if self.schema else ''}{table}"
# https://docs.sqlalchemy.org/en/14/core/connections.html#sqlalchemy.engine.CursorResult
class Connection(Compat): class Connection(Compat):
def __init__(self, conn: AsyncSession, txn, typ, name, schema): def __init__(self, conn: AsyncSession, txn, typ, name, schema):
self.conn = conn self.conn = conn
@@ -73,19 +74,20 @@ class Connection(Compat):
self.name = name self.name = name
self.schema = schema self.schema = schema
def rewrite_query(self, query) -> str: def rewrite_query(self, query) -> TextClause:
if self.type in {POSTGRES, COCKROACH}: if self.type in {POSTGRES, COCKROACH}:
query = query.replace("%", "%%") query = query.replace("%", "%%")
query = query.replace("?", "%s") query = query.replace("?", "%s")
return text(query) return text(query)
async def fetchall(self, query: str, values: dict = {}) -> list: async def fetchall(self, query: str, values: dict = {}):
result = await self.conn.execute(self.rewrite_query(query), values) result = await self.conn.execute(self.rewrite_query(query), values)
return result.all() return [r._mapping for r in result.all()] # will return [] if result list is empty
async def fetchone(self, query: str, values: dict = {}): async def fetchone(self, query: str, values: dict = {}):
result = await self.conn.execute(self.rewrite_query(query), values) result = await self.conn.execute(self.rewrite_query(query), values)
return result.fetchone() r = result.fetchone()
return r._mapping if r is not None else None
async def execute(self, query: str, values: dict = {}): async def execute(self, query: str, values: dict = {}):
return await self.conn.execute(self.rewrite_query(query), values) return await self.conn.execute(self.rewrite_query(query), values)
@@ -132,9 +134,9 @@ class Database(Compat):
if not settings.db_connection_pool: if not settings.db_connection_pool:
kwargs["poolclass"] = NullPool kwargs["poolclass"] = NullPool
elif self.type == POSTGRES: elif self.type == POSTGRES:
kwargs["poolclass"] = QueuePool kwargs["poolclass"] = AsyncAdaptedQueuePool # type: ignore[assignment]
kwargs["pool_size"] = 50 kwargs["pool_size"] = 50 # type: ignore[assignment]
kwargs["max_overflow"] = 100 kwargs["max_overflow"] = 100 # type: ignore[assignment]
self.engine = create_async_engine(database_uri, **kwargs) self.engine = create_async_engine(database_uri, **kwargs)
self.async_session = sessionmaker( self.async_session = sessionmaker(
@@ -281,12 +283,13 @@ class Database(Compat):
async def fetchall(self, query: str, values: dict = {}) -> list: async def fetchall(self, query: str, values: dict = {}) -> list:
async with self.connect() as conn: async with self.connect() as conn:
result = await conn.execute(query, values) result = await conn.execute(query, values)
return result.all() return [r._mapping for r in result.all()]
async def fetchone(self, query: str, values: dict = {}): async def fetchone(self, query: str, values: dict = {}):
async with self.connect() as conn: async with self.connect() as conn:
result = await conn.execute(query, values) result = await conn.execute(query, values)
return result.fetchone() r = result.fetchone()
return r._mapping if r is not None else None
async def execute(self, query: str, values: dict = {}): async def execute(self, query: str, values: dict = {}):
async with self.connect() as conn: async with self.connect() as conn:

View File

@@ -104,6 +104,6 @@ async def migrate_databases(db: Database, migrations_module):
f"SELECT * FROM {db.table_with_schema('dbversions')}" f"SELECT * FROM {db.table_with_schema('dbversions')}"
) )
rows = result.all() rows = result.all()
current_versions = {row["db"]: row["version"] for row in rows} current_versions = {row._mapping["db"]: row._mapping["version"] for row in rows}
matcher = re.compile(r"^m(\d\d\d)_") matcher = re.compile(r"^m(\d\d\d)_")
await run_migration(db, migrations_module) await run_migration(db, migrations_module)

View File

@@ -691,7 +691,10 @@ class LedgerCrudSqlite(LedgerCrud):
""" """
) )
assert row, "Balance not found" assert row, "Balance not found"
return int(row[0])
# sqlalchemy index of first element
key = next(iter(row))
return int(row[key])
async def get_keyset( async def get_keyset(
self, self,

View File

@@ -377,7 +377,7 @@ async def bump_secret_derivation(
) )
counter = 0 counter = 0
else: else:
counter = int(rows[0]) counter = int(rows["counter"])
if not skip: if not skip:
await (conn or db).execute( await (conn or db).execute(
@@ -437,8 +437,8 @@ async def get_seed_and_mnemonic(
) )
return ( return (
( (
row[0], row["seed"],
row[1], row["mnemonic"],
) )
if row if row
else None else None

1291
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -7,7 +7,7 @@ license = "MIT"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.8.1" python = "^3.8.1"
SQLAlchemy = {version = "1.4.52", extras = ["asyncio"]} SQLAlchemy = {version = "^2.0.35", extras = ["asyncio"]}
click = "^8.1.7" click = "^8.1.7"
pydantic = "^1.10.2" pydantic = "^1.10.2"
bech32 = "^1.2.0" bech32 = "^1.2.0"