Mirror of https://github.com/aljazceru/nutshell.git (synced 2025-12-20 10:34:20 +01:00)
Bump SQLAlchemy to 2.0 (#626)
* `SQLALCHEMY_WARN_20=1` fixed all removed warnings.
* fix some mypy errors
* fix fetchone
* make format
* ignore annotations
* let's try like this?
* remove
* make format
* Update pyproject.toml

Co-authored-by: Pavol Rusnak <pavol@rusnak.io>

* extract _mapping in fetchone() and fetchall() + fix poetry lock
* fix
* make format
* fix integer indexing of row fields
* Update cashu/mint/crud.py

---------

Co-authored-by: Pavol Rusnak <pavol@rusnak.io>
Co-authored-by: callebtc <93376500+callebtc@users.noreply.github.com>
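The recurring theme in the diff below is row access: SQLAlchemy 2.0 removed the 1.x-style string indexing on Row objects, so dict-style reads now go through Row._mapping, while positional (tuple-style) indexing on the Row itself still works. A minimal illustrative sketch, assuming only SQLAlchemy 2.0 and its bundled SQLite driver (table and column names are made up, not the Cashu schema):

# Illustrative only: how row access changes between SQLAlchemy 1.4 and 2.0.
from sqlalchemy import create_engine, text

engine = create_engine("sqlite://")  # in-memory database
with engine.connect() as conn:
    conn.execute(text("CREATE TABLE keysets (id TEXT, unit TEXT)"))
    conn.execute(text("INSERT INTO keysets VALUES ('009a1f29', 'sat')"))
    row = conn.execute(text("SELECT * FROM keysets")).fetchone()

    assert row[0] == "009a1f29"           # positional access still works on a Row
    assert row._mapping["unit"] == "sat"  # dict-style access goes through RowMapping
    # row["unit"]                         # 1.4-style string indexing raises in 2.0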
@@ -647,7 +647,6 @@ class WalletKeyset:
                 int(amount): PublicKey(bytes.fromhex(hex_key), raw=True)
                 for amount, hex_key in dict(json.loads(serialized)).items()
             }
-
         return cls(
             id=row["id"],
             unit=row["unit"],
@@ -10,7 +10,8 @@ from loguru import logger
 from sqlalchemy import text
 from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
 from sqlalchemy.orm import sessionmaker
-from sqlalchemy.pool import NullPool, QueuePool
+from sqlalchemy.pool import AsyncAdaptedQueuePool, NullPool
+from sqlalchemy.sql.expression import TextClause
 
 from cashu.core.settings import settings
 
@@ -64,7 +65,7 @@ class Compat:
     def table_with_schema(self, table: str):
         return f"{self.references_schema if self.schema else ''}{table}"
 
-
+# https://docs.sqlalchemy.org/en/14/core/connections.html#sqlalchemy.engine.CursorResult
 class Connection(Compat):
     def __init__(self, conn: AsyncSession, txn, typ, name, schema):
         self.conn = conn
@@ -73,19 +74,20 @@ class Connection(Compat):
         self.name = name
         self.schema = schema
 
-    def rewrite_query(self, query) -> str:
+    def rewrite_query(self, query) -> TextClause:
         if self.type in {POSTGRES, COCKROACH}:
             query = query.replace("%", "%%")
             query = query.replace("?", "%s")
         return text(query)
 
-    async def fetchall(self, query: str, values: dict = {}) -> list:
+    async def fetchall(self, query: str, values: dict = {}):
         result = await self.conn.execute(self.rewrite_query(query), values)
-        return result.all()
+        return [r._mapping for r in result.all()]  # will return [] if result list is empty
 
     async def fetchone(self, query: str, values: dict = {}):
         result = await self.conn.execute(self.rewrite_query(query), values)
-        return result.fetchone()
+        r = result.fetchone()
+        return r._mapping if r is not None else None
 
     async def execute(self, query: str, values: dict = {}):
         return await self.conn.execute(self.rewrite_query(query), values)
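A hedged, self-contained sketch of the contract these wrappers now provide: fetchall() yields a list of RowMapping objects ([] when the result set is empty) and fetchone() yields a RowMapping or None. The aiosqlite driver and the table name below are assumptions for illustration, not code from the repository:

import asyncio

from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine


async def main():
    engine = create_async_engine("sqlite+aiosqlite://")  # in-memory database
    async with engine.connect() as conn:
        await conn.execute(text("CREATE TABLE keysets (id TEXT, unit TEXT)"))
        await conn.execute(text("INSERT INTO keysets VALUES ('abc', 'sat')"))

        result = await conn.execute(text("SELECT * FROM keysets"))
        rows = [r._mapping for r in result.all()]  # [] if nothing matched
        assert rows and rows[0]["unit"] == "sat"

        result = await conn.execute(text("SELECT * FROM keysets WHERE id = :id"), {"id": "missing"})
        r = result.fetchone()
        assert (r._mapping if r is not None else None) is None
    await engine.dispose()


asyncio.run(main())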
@@ -132,9 +134,9 @@ class Database(Compat):
         if not settings.db_connection_pool:
             kwargs["poolclass"] = NullPool
         elif self.type == POSTGRES:
-            kwargs["poolclass"] = QueuePool
-            kwargs["pool_size"] = 50
-            kwargs["max_overflow"] = 100
+            kwargs["poolclass"] = AsyncAdaptedQueuePool  # type: ignore[assignment]
+            kwargs["pool_size"] = 50  # type: ignore[assignment]
+            kwargs["max_overflow"] = 100  # type: ignore[assignment]
 
         self.engine = create_async_engine(database_uri, **kwargs)
         self.async_session = sessionmaker(
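The pool swap is needed because the synchronous QueuePool cannot be used with an asyncio engine in SQLAlchemy 2.0; AsyncAdaptedQueuePool is the async-compatible counterpart (and the default pool for create_async_engine). A minimal sketch, assuming the asyncpg driver is installed and using a placeholder DSN; no connection is opened until the engine is first used:

from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.pool import AsyncAdaptedQueuePool

engine = create_async_engine(
    "postgresql+asyncpg://user:pass@localhost:5432/cashu",  # placeholder DSN
    poolclass=AsyncAdaptedQueuePool,  # a sync QueuePool would be rejected here
    pool_size=50,
    max_overflow=100,
)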
@@ -281,12 +283,13 @@ class Database(Compat):
     async def fetchall(self, query: str, values: dict = {}) -> list:
         async with self.connect() as conn:
             result = await conn.execute(query, values)
-            return result.all()
+            return [r._mapping for r in result.all()]
 
     async def fetchone(self, query: str, values: dict = {}):
         async with self.connect() as conn:
             result = await conn.execute(query, values)
-            return result.fetchone()
+            r = result.fetchone()
+            return r._mapping if r is not None else None
 
     async def execute(self, query: str, values: dict = {}):
         async with self.connect() as conn:
@@ -104,6 +104,6 @@ async def migrate_databases(db: Database, migrations_module):
             f"SELECT * FROM {db.table_with_schema('dbversions')}"
         )
         rows = result.all()
-        current_versions = {row["db"]: row["version"] for row in rows}
+        current_versions = {row._mapping["db"]: row._mapping["version"] for row in rows}
         matcher = re.compile(r"^m(\d\d\d)_")
         await run_migration(db, migrations_module)
@@ -691,7 +691,10 @@ class LedgerCrudSqlite(LedgerCrud):
             """
         )
         assert row, "Balance not found"
-        return int(row[0])
+
+        # sqlalchemy index of first element
+        key = next(iter(row))
+        return int(row[key])
 
     async def get_keyset(
         self,
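Because the patched fetchone() returns a RowMapping keyed by column label, integer indexing such as row[0] no longer works; iterating the mapping yields its keys, so next(iter(row)) picks up the label of the first (and only) column. A self-contained sketch of that pattern, assuming plain SQLAlchemy 2.0 with the bundled SQLite driver (table and column names are illustrative, not the mint schema):

from sqlalchemy import create_engine, text

engine = create_engine("sqlite://")  # in-memory database
with engine.connect() as conn:
    conn.execute(text("CREATE TABLE promises (amount INTEGER)"))
    conn.execute(text("INSERT INTO promises VALUES (21)"))
    r = conn.execute(text("SELECT SUM(amount) AS balance FROM promises")).fetchone()

    row = r._mapping            # what the patched fetchone() hands back
    key = next(iter(row))       # first column label, "balance" here
    assert int(row[key]) == 21  # row[0] would raise KeyError on a mapping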
@@ -377,7 +377,7 @@ async def bump_secret_derivation(
         )
         counter = 0
     else:
-        counter = int(rows[0])
+        counter = int(rows["counter"])
 
     if not skip:
         await (conn or db).execute(
@@ -437,8 +437,8 @@ async def get_seed_and_mnemonic(
     )
     return (
         (
-            row[0],
-            row[1],
+            row["seed"],
+            row["mnemonic"],
         )
         if row
         else None
poetry.lock: 1291 changed lines (generated file); diff suppressed because it is too large.
@@ -7,7 +7,7 @@ license = "MIT"
|
|||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = "^3.8.1"
|
python = "^3.8.1"
|
||||||
SQLAlchemy = {version = "1.4.52", extras = ["asyncio"]}
|
SQLAlchemy = {version = "^2.0.35", extras = ["asyncio"]}
|
||||||
click = "^8.1.7"
|
click = "^8.1.7"
|
||||||
pydantic = "^1.10.2"
|
pydantic = "^1.10.2"
|
||||||
bech32 = "^1.2.0"
|
bech32 = "^1.2.0"
|
||||||
|
|||||||