Bump SQLAlchemy to 2.0 (#626)

* `SQLALCHEMY_WARN_20=1` fixed all removed warnings.

* fix some mypy errors

* fix fetchone

* make format

* ignore annotations

* let's try like this?

* remove

* make format

* Update pyproject.toml

Co-authored-by: Pavol Rusnak <pavol@rusnak.io>

* extract _mapping in fetchone() and fetchall() + fix poetry lock

* fix

* make format

* fix integer indexing of row fields

* Update cashu/mint/crud.py

---------

Co-authored-by: Pavol Rusnak <pavol@rusnak.io>
Co-authored-by: callebtc <93376500+callebtc@users.noreply.github.com>
This commit is contained in:
lollerfirst
2024-10-05 13:32:32 +02:00
committed by GitHub
parent 7fdca3b1a1
commit c5ccf65e4d
7 changed files with 719 additions and 613 deletions

View File

@@ -647,7 +647,6 @@ class WalletKeyset:
int(amount): PublicKey(bytes.fromhex(hex_key), raw=True)
for amount, hex_key in dict(json.loads(serialized)).items()
}
return cls(
id=row["id"],
unit=row["unit"],

View File

@@ -10,7 +10,8 @@ from loguru import logger
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool, QueuePool
from sqlalchemy.pool import AsyncAdaptedQueuePool, NullPool
from sqlalchemy.sql.expression import TextClause
from cashu.core.settings import settings
@@ -64,7 +65,7 @@ class Compat:
def table_with_schema(self, table: str):
return f"{self.references_schema if self.schema else ''}{table}"
# https://docs.sqlalchemy.org/en/14/core/connections.html#sqlalchemy.engine.CursorResult
class Connection(Compat):
def __init__(self, conn: AsyncSession, txn, typ, name, schema):
self.conn = conn
@@ -73,19 +74,20 @@ class Connection(Compat):
self.name = name
self.schema = schema
def rewrite_query(self, query) -> str:
def rewrite_query(self, query) -> TextClause:
if self.type in {POSTGRES, COCKROACH}:
query = query.replace("%", "%%")
query = query.replace("?", "%s")
return text(query)
async def fetchall(self, query: str, values: dict = {}) -> list:
async def fetchall(self, query: str, values: dict = {}):
result = await self.conn.execute(self.rewrite_query(query), values)
return result.all()
return [r._mapping for r in result.all()] # will return [] if result list is empty
async def fetchone(self, query: str, values: dict = {}):
result = await self.conn.execute(self.rewrite_query(query), values)
return result.fetchone()
r = result.fetchone()
return r._mapping if r is not None else None
async def execute(self, query: str, values: dict = {}):
return await self.conn.execute(self.rewrite_query(query), values)
@@ -132,9 +134,9 @@ class Database(Compat):
if not settings.db_connection_pool:
kwargs["poolclass"] = NullPool
elif self.type == POSTGRES:
kwargs["poolclass"] = QueuePool
kwargs["pool_size"] = 50
kwargs["max_overflow"] = 100
kwargs["poolclass"] = AsyncAdaptedQueuePool # type: ignore[assignment]
kwargs["pool_size"] = 50 # type: ignore[assignment]
kwargs["max_overflow"] = 100 # type: ignore[assignment]
self.engine = create_async_engine(database_uri, **kwargs)
self.async_session = sessionmaker(
@@ -281,12 +283,13 @@ class Database(Compat):
async def fetchall(self, query: str, values: dict = {}) -> list:
async with self.connect() as conn:
result = await conn.execute(query, values)
return result.all()
return [r._mapping for r in result.all()]
async def fetchone(self, query: str, values: dict = {}):
async with self.connect() as conn:
result = await conn.execute(query, values)
return result.fetchone()
r = result.fetchone()
return r._mapping if r is not None else None
async def execute(self, query: str, values: dict = {}):
async with self.connect() as conn:

View File

@@ -104,6 +104,6 @@ async def migrate_databases(db: Database, migrations_module):
f"SELECT * FROM {db.table_with_schema('dbversions')}"
)
rows = result.all()
current_versions = {row["db"]: row["version"] for row in rows}
current_versions = {row._mapping["db"]: row._mapping["version"] for row in rows}
matcher = re.compile(r"^m(\d\d\d)_")
await run_migration(db, migrations_module)

View File

@@ -691,7 +691,10 @@ class LedgerCrudSqlite(LedgerCrud):
"""
)
assert row, "Balance not found"
return int(row[0])
# sqlalchemy index of first element
key = next(iter(row))
return int(row[key])
async def get_keyset(
self,

View File

@@ -377,7 +377,7 @@ async def bump_secret_derivation(
)
counter = 0
else:
counter = int(rows[0])
counter = int(rows["counter"])
if not skip:
await (conn or db).execute(
@@ -437,8 +437,8 @@ async def get_seed_and_mnemonic(
)
return (
(
row[0],
row[1],
row["seed"],
row["mnemonic"],
)
if row
else None

1291
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -7,7 +7,7 @@ license = "MIT"
[tool.poetry.dependencies]
python = "^3.8.1"
SQLAlchemy = {version = "1.4.52", extras = ["asyncio"]}
SQLAlchemy = {version = "^2.0.35", extras = ["asyncio"]}
click = "^8.1.7"
pydantic = "^1.10.2"
bech32 = "^1.2.0"