Merge pull request #35 from callebtc/refactor/cleanup

add tests
This commit is contained in:
calle
2022-10-11 01:15:01 +02:00
committed by GitHub
16 changed files with 197 additions and 133 deletions

View File

@@ -1,4 +1,4 @@
name: formatting
name: checks
on:
push:
@@ -7,7 +7,7 @@ on:
branches: [main]
jobs:
poetry:
formatting:
runs-on: ubuntu-latest
strategy:
matrix:
@@ -29,3 +29,25 @@ jobs:
run: poetry run black --check .
- name: Check isort
run: poetry run isort --profile black --check-only .
linting:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9"]
poetry-version: ["1.2.1"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Set up Poetry ${{ matrix.poetry-version }}
uses: abatilo/actions-poetry@v2
with:
poetry-version: ${{ matrix.poetry-version }}
- name: Install packages
run: poetry install --with dev
- name: Setup mypy
run: yes | poetry run mypy cashu --install-types || true
- name: Run mypy
run: poetry run mypy cashu --ignore-missing

View File

@@ -37,3 +37,5 @@ jobs:
MINT_PORT: 3338
run: |
poetry run pytest tests
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3

View File

@@ -166,7 +166,6 @@ class CheckFeesResponse(BaseModel):
class MeltRequest(BaseModel):
proofs: List[Proof]
amount: int = None # deprecated
invoice: str
@@ -233,7 +232,7 @@ class MintKeyset:
id: str
derivation_path: str
private_keys: Dict[int, PrivateKey]
public_keys: Dict[int, PublicKey] = None
public_keys: Dict[int, PublicKey] = {}
valid_from: Union[str, None] = None
valid_to: Union[str, None] = None
first_seen: Union[str, None] = None
@@ -247,9 +246,9 @@ class MintKeyset:
valid_to=None,
first_seen=None,
active=None,
seed: Union[None, str] = None,
derivation_path: str = None,
version: str = None,
seed: str = "",
derivation_path: str = "",
version: str = "",
):
self.derivation_path = derivation_path
self.id = id

View File

@@ -27,7 +27,7 @@ def derive_pubkeys(keys: Dict[int, PrivateKey]):
return {amt: keys[amt].pubkey for amt in [2**i for i in range(MAX_ORDER)]}
def derive_keyset_id(keys: Dict[str, PublicKey]):
def derive_keyset_id(keys: Dict[int, PublicKey]):
"""Deterministic derivation keyset_id from set of public keys."""
pubkeys_concat = "".join([p.serialize().hex() for _, p in keys.items()])
return base64.b64encode(

View File

@@ -7,13 +7,13 @@ from environs import Env # type: ignore
env = Env()
ENV_FILE: Union[str, None] = os.path.join(str(Path.home()), ".cashu", ".env")
ENV_FILE = os.path.join(str(Path.home()), ".cashu", ".env")
if not os.path.isfile(ENV_FILE):
ENV_FILE = os.path.join(os.getcwd(), ".env")
if os.path.isfile(ENV_FILE):
env.read_env(ENV_FILE)
else:
ENV_FILE = None
ENV_FILE = ""
env.read_env()
DEBUG = env.bool("DEBUG", default=False)

View File

@@ -79,9 +79,9 @@ class Wallet(ABC):
) -> Coroutine[None, None, PaymentStatus]:
pass
@abstractmethod
def paid_invoices_stream(self) -> AsyncGenerator[str, None]:
pass
# @abstractmethod
# def paid_invoices_stream(self) -> AsyncGenerator[str, None]:
# pass
class Unsupported(Exception):

View File

@@ -1,8 +1,5 @@
import asyncio
import hashlib
import json
from os import getenv
from typing import AsyncGenerator, Dict, Optional
from typing import Dict, Optional
import requests
@@ -133,26 +130,26 @@ class LNbitsWallet(Wallet):
return PaymentStatus(data["paid"], data["details"]["fee"], data["preimage"])
async def paid_invoices_stream(self) -> AsyncGenerator[str, None]:
url = f"{self.endpoint}/api/v1/payments/sse"
# async def paid_invoices_stream(self) -> AsyncGenerator[str, None]:
# url = f"{self.endpoint}/api/v1/payments/sse"
while True:
try:
async with requests.stream("GET", url) as r:
async for line in r.aiter_lines():
if line.startswith("data:"):
try:
data = json.loads(line[5:])
except json.decoder.JSONDecodeError:
continue
# while True:
# try:
# async with requests.stream("GET", url) as r:
# async for line in r.aiter_lines():
# if line.startswith("data:"):
# try:
# data = json.loads(line[5:])
# except json.decoder.JSONDecodeError:
# continue
if type(data) is not dict:
continue
# if type(data) is not dict:
# continue
yield data["payment_hash"] # payment_hash
# yield data["payment_hash"] # payment_hash
except:
pass
# except:
# pass
print("lost connection to lnbits /payments/sse, retrying in 5 seconds")
await asyncio.sleep(5)
# print("lost connection to lnbits /payments/sse, retrying in 5 seconds")
# await asyncio.sleep(5)

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Any, List, Optional
from cashu.core.base import Invoice, MintKeyset, Proof
from cashu.core.db import Connection, Database
@@ -118,7 +118,7 @@ async def store_keyset(
conn: Optional[Connection] = None,
):
await (conn or db).execute(
await (conn or db).execute( # type: ignore
"""
INSERT INTO keysets
(id, derivation_path, valid_from, valid_to, first_seen, active, version)
@@ -138,12 +138,12 @@ async def store_keyset(
async def get_keyset(
id: str = None,
derivation_path: str = None,
derivation_path: str = "",
db: Database = None,
conn: Optional[Connection] = None,
):
clauses = []
values = []
values: List[Any] = []
clauses.append("active = ?")
values.append(True)
if id:
@@ -156,7 +156,7 @@ async def get_keyset(
if clauses:
where = f"WHERE {' AND '.join(clauses)}"
rows = await (conn or db).fetchall(
rows = await (conn or db).fetchall( # type: ignore
f"""
SELECT * from keysets
{where}

View File

@@ -116,17 +116,7 @@ class Ledger:
C = PublicKey(bytes.fromhex(proof.C), raw=True)
# backwards compatibility with old hash_to_curve
logger.debug(f"Client version {context.get('client-version')}")
if self.keysets.keysets.get(proof.id):
logger.debug(
f"Token keyset: {self.keysets.keysets.get(proof.id)}, token version: {self.keysets.keysets[proof.id].version}"
)
# if not context.get("client-version") or (
# self.keysets.keysets.get(proof.id)
# and not self.keysets.keysets[proof.id].version
# ):
# return legacy.verify_pre_0_3_3(secret_key, C, proof.secret)
# backwards compatibility with old hash_to_curve < 0.3.3
try:
ret = legacy.verify_pre_0_3_3(secret_key, C, proof.secret)
if ret:

View File

@@ -1,9 +1,10 @@
from typing import Union
from typing import Dict, List, Union
from fastapi import APIRouter
from secp256k1 import PublicKey
from cashu.core.base import (
BlindedSignature,
CheckFeesRequest,
CheckFeesResponse,
CheckRequest,
@@ -20,24 +21,20 @@ from cashu.mint import ledger
router: APIRouter = APIRouter()
from starlette.requests import Request
from starlette_context import context
@router.get("/keys")
def keys():
def keys() -> dict[int, str]:
"""Get the public keys of the mint"""
return ledger.get_keyset()
@router.get("/keysets")
def keysets():
def keysets() -> dict[str, list[str]]:
"""Get all active keysets of the mint"""
return {"keysets": ledger.keysets.get_ids()}
@router.get("/mint")
async def request_mint(amount: int = 0):
async def request_mint(amount: int = 0) -> GetMintResponse:
"""
Request minting of new tokens. The mint responds with a Lightning invoice.
This endpoint can be used for a Lightning invoice UX flow.
@@ -54,7 +51,7 @@ async def request_mint(amount: int = 0):
async def mint(
payloads: MintRequest,
payment_hash: Union[str, None] = None,
):
) -> Union[List[BlindedSignature], CashuError]:
"""
Requests the minting of tokens belonging to a paid payment request.
@@ -73,7 +70,7 @@ async def mint(
@router.post("/melt")
async def melt(request: Request, payload: MeltRequest):
async def melt(payload: MeltRequest) -> GetMeltResponse:
"""
Requests tokens to be destroyed and sent out via Lightning.
"""
@@ -83,13 +80,13 @@ async def melt(request: Request, payload: MeltRequest):
@router.post("/check")
async def check_spendable(payload: CheckRequest):
async def check_spendable(payload: CheckRequest) -> Dict[int, bool]:
"""Check whether a secret has been spent already or not."""
return await ledger.check_spendable(payload.proofs)
@router.post("/checkfees")
async def check_fees(payload: CheckFeesRequest):
async def check_fees(payload: CheckFeesRequest) -> CheckFeesResponse:
"""
Responds with the fees necessary to pay a Lightning invoice.
Used by wallets for figuring out the fees they need to supply.
@@ -100,7 +97,9 @@ async def check_fees(payload: CheckFeesRequest):
@router.post("/split")
async def split(request: Request, payload: SplitRequest):
async def split(
payload: SplitRequest,
) -> Union[CashuError, PostSplitResponse]:
"""
Requetst a set of tokens with amount "total" to be split into two
newly minted sets with amount "split" and "total-split".
@@ -108,12 +107,14 @@ async def split(request: Request, payload: SplitRequest):
proofs = payload.proofs
amount = payload.amount
outputs = payload.outputs.blinded_messages if payload.outputs else None
# backwards compatibility with clients < v0.2.2
assert outputs, Exception("no outputs provided.")
try:
split_return = await ledger.split(proofs, amount, outputs)
except Exception as exc:
return CashuError(error=str(exc))
if not split_return:
return {"error": "there was a problem with the split."}
return CashuError(error="there was an error with the split")
frst_promises, scnd_promises = split_return
resp = PostSplitResponse(fst=frst_promises, snd=scnd_promises)
return resp

View File

@@ -1,3 +1,3 @@
import sys
sys.tracebacklimit = None
sys.tracebacklimit = None # type: ignore

View File

@@ -121,21 +121,31 @@ async def mint(ctx, amount: int, hash: str):
@cli.command("pay", help="Pay Lightning invoice.")
@click.argument("invoice", type=str)
@click.option(
"--yes", "-y", default=False, is_flag=True, help="Skip confirmation.", type=bool
)
@click.pass_context
@coro
async def pay(ctx, invoice: str):
async def pay(ctx, invoice: str, yes: bool):
wallet: Wallet = ctx.obj["WALLET"]
await wallet.load_mint()
wallet.status()
decoded_invoice: Invoice = bolt11.decode(invoice)
# check if it's an internal payment
fees = (await wallet.check_fees(invoice))["fee"]
amount = math.ceil(
(decoded_invoice.amount_msat + fees * 1000) / 1000
) # 1% fee for Lightning
print(
f"Paying Lightning invoice of {decoded_invoice.amount_msat//1000} sat ({amount} sat incl. fees)"
)
if not yes:
click.confirm(
f"Pay {decoded_invoice.amount_msat//1000} sat ({amount} sat incl. fees)?",
abort=True,
default=True,
)
print(f"Paying Lightning invoice ...")
assert amount > 0, "amount is not positive"
if wallet.available_balance < amount:
print("Error: Balance too low.")

View File

@@ -93,7 +93,7 @@ async def update_proof_reserved(
clauses.append("time_reserved = ?")
values.append(int(time.time()))
await (conn or db).execute(
await (conn or db).execute( # type: ignore
f"UPDATE proofs SET {', '.join(clauses)} WHERE secret = ?",
(*values, str(proof.secret)),
)
@@ -155,7 +155,7 @@ async def get_unused_locks(
if clause:
where = f"WHERE {' AND '.join(clause)}"
rows = await (conn or db).fetchall(
rows = await (conn or db).fetchall( # type: ignore
f"""
SELECT * from p2sh
{where}
@@ -176,7 +176,7 @@ async def update_p2sh_used(
clauses.append("used = ?")
values.append(used)
await (conn or db).execute(
await (conn or db).execute( # type: ignore
f"UPDATE proofs SET {', '.join(clauses)} WHERE address = ?",
(*values, str(p2sh.address)),
)
@@ -189,7 +189,7 @@ async def store_keyset(
conn: Optional[Connection] = None,
):
await (conn or db).execute(
await (conn or db).execute( # type: ignore
"""
INSERT INTO keysets
(id, mint_url, valid_from, valid_to, first_seen, active)
@@ -213,7 +213,7 @@ async def get_keyset(
conn: Optional[Connection] = None,
):
clauses = []
values = []
values: List[Any] = []
clauses.append("active = ?")
values.append(True)
if id:
@@ -226,7 +226,7 @@ async def get_keyset(
if clauses:
where = f"WHERE {' AND '.join(clauses)}"
row = await (conn or db).fetchone(
row = await (conn or db).fetchone( # type: ignore
f"""
SELECT * from keysets
{where}

View File

@@ -50,28 +50,8 @@ class LedgerAPI:
def __init__(self, url):
self.url = url
async def _get_keys(self, url):
resp = requests.get(
url + "/keys",
headers={"Client-version": VERSION},
).json()
keys = resp
assert len(keys), Exception("did not receive any keys")
keyset_keys = {
int(amt): PublicKey(bytes.fromhex(val), raw=True)
for amt, val in keys.items()
}
keyset = WalletKeyset(pubkeys=keyset_keys, mint_url=url)
return keyset
async def _get_keysets(self, url):
keysets = requests.get(
url + "/keysets",
headers={"Client-version": VERSION},
).json()
assert len(keysets), Exception("did not receive any keysets")
return keysets
self.s = requests.Session()
self.s.headers.update({"Client-version": VERSION})
@staticmethod
def _get_output_split(amount):
@@ -100,6 +80,11 @@ class LedgerAPI:
proofs.append(proof)
return proofs
@staticmethod
def raise_on_error(resp_dict):
if "error" in resp_dict:
raise Exception("Mint Error: {}".format(resp_dict["error"]))
@staticmethod
def _generate_secret(randombits=128):
"""Returns base64 encoded random string."""
@@ -138,12 +123,6 @@ class LedgerAPI:
self.keys = keyset.public_keys
self.keyset_id = keyset.id
def request_mint(self, amount):
"""Requests a mint from the server and returns Lightning invoice."""
r = requests.get(self.url + "/mint", params={"amount": amount})
r.raise_for_status()
return r.json()
@staticmethod
def _construct_outputs(amounts: List[int], secrets: List[str]):
"""Takes a list of amounts and secrets and returns outputs.
@@ -173,25 +152,55 @@ class LedgerAPI:
return [f"{secret}:{self._generate_secret()}" for i in range(n)]
return [f"{i}:{secret}" for i in range(n)]
"""
ENDPOINTS
"""
async def _get_keys(self, url):
resp = self.s.get(
url + "/keys",
)
resp.raise_for_status()
keys = resp.json()
assert len(keys), Exception("did not receive any keys")
keyset_keys = {
int(amt): PublicKey(bytes.fromhex(val), raw=True)
for amt, val in keys.items()
}
keyset = WalletKeyset(pubkeys=keyset_keys, mint_url=url)
return keyset
async def _get_keysets(self, url):
resp = self.s.get(
url + "/keysets",
).json()
resp.raise_for_status()
keysets = resp.json()
assert len(keysets), Exception("did not receive any keysets")
return keysets
def request_mint(self, amount):
"""Requests a mint from the server and returns Lightning invoice."""
resp = self.s.get(self.url + "/mint", params={"amount": amount})
resp.raise_for_status()
return_dict = resp.json()
self.raise_on_error(return_dict)
return return_dict
async def mint(self, amounts, payment_hash=None):
"""Mints new coins and returns a proof of promise."""
secrets = [self._generate_secret() for s in range(len(amounts))]
await self._check_used_secrets(secrets)
payloads, rs = self._construct_outputs(amounts, secrets)
resp = requests.post(
resp = self.s.post(
self.url + "/mint",
json=payloads.dict(),
params={"payment_hash": payment_hash},
headers={"Client-version": VERSION},
)
resp.raise_for_status()
try:
promises_list = resp.json()
except:
raise Exception("Unkown mint error.")
if "error" in promises_list:
raise Exception("Error: {}".format(promises_list["error"]))
promises_list = resp.json()
self.raise_on_error(promises_list)
promises = [BlindedSignature.from_dict(p) for p in promises_list]
return self._construct_proofs(promises, secrets, rs)
@@ -239,18 +248,14 @@ class LedgerAPI:
"proofs": {i: proofs_include for i in range(len(proofs))},
}
resp = requests.post(
resp = self.s.post(
self.url + "/split",
json=split_payload.dict(include=_splitrequest_include_fields(proofs)),
headers={"Client-version": VERSION},
)
resp.raise_for_status()
try:
promises_dict = resp.json()
except:
raise Exception("Unkown mint error.")
if "error" in promises_dict:
raise Exception("Mint Error: {}".format(promises_dict["error"]))
promises_dict = resp.json()
self.raise_on_error(promises_dict)
promises_fst = [BlindedSignature.from_dict(p) for p in promises_dict["fst"]]
promises_snd = [BlindedSignature.from_dict(p) for p in promises_dict["snd"]]
# Construct proofs from promises (i.e., unblind signatures)
@@ -264,31 +269,35 @@ class LedgerAPI:
return frst_proofs, scnd_proofs
async def check_spendable(self, proofs: List[Proof]):
"""
Cheks whether the secrets in proofs are already spent or not and returns a list of booleans.
"""
payload = CheckRequest(proofs=proofs)
resp = requests.post(
resp = self.s.post(
self.url + "/check",
json=payload.dict(),
headers={"Client-version": VERSION},
)
resp.raise_for_status()
return_dict = resp.json()
self.raise_on_error(return_dict)
return return_dict
async def check_fees(self, payment_request: str):
"""Checks whether the Lightning payment is internal."""
payload = CheckFeesRequest(pr=payment_request)
resp = requests.post(
resp = self.s.post(
self.url + "/checkfees",
json=payload.dict(),
headers={"Client-version": VERSION},
)
resp.raise_for_status()
return_dict = resp.json()
self.raise_on_error(return_dict)
return return_dict
async def pay_lightning(self, proofs: List[Proof], invoice: str):
"""
Accepts proofs and a lightning invoice to pay in exchange.
"""
payload = MeltRequest(proofs=proofs, invoice=invoice)
def _meltequest_include_fields(proofs):
@@ -300,14 +309,13 @@ class LedgerAPI:
"proofs": {i: proofs_include for i in range(len(proofs))},
}
resp = requests.post(
resp = self.s.post(
self.url + "/melt",
json=payload.dict(include=_meltequest_include_fields(proofs)),
headers={"Client-version": VERSION},
)
resp.raise_for_status()
return_dict = resp.json()
self.raise_on_error(return_dict)
return return_dict

37
tests/test_crypto.py Normal file
View File

@@ -0,0 +1,37 @@
import pytest
from cashu.core.b_dhke import hash_to_curve
def test_hash_to_curve():
result = hash_to_curve(
bytes.fromhex(
"0000000000000000000000000000000000000000000000000000000000000000"
)
)
assert (
result.serialize().hex()
== "0266687aadf862bd776c8fc18b8e9f8e20089714856ee233b3902a591d0d5f2925"
)
result = hash_to_curve(
bytes.fromhex(
"0000000000000000000000000000000000000000000000000000000000000001"
)
)
assert (
result.serialize().hex()
== "02ec4916dd28fc4c10d78e287ca5d9cc51ee1ae73cbfde08c6b37324cbfaac8bc5"
)
def test_hash_to_curve_iteration():
result = hash_to_curve(
bytes.fromhex(
"0000000000000000000000000000000000000000000000000000000000000002"
)
)
assert (
result.serialize().hex()
== "02076c988b353fcbb748178ecb286bc9d0b4acf474d4ba31ba62334e46c97c416a"
)

View File

@@ -1,6 +1,4 @@
import time
from distutils.command.build_scripts import first_line_re
from re import S
from typing import List
import pytest
@@ -27,9 +25,9 @@ async def assert_err(f, msg):
)
def assert_amt(proofs, expected):
def assert_amt(proofs: List[Proof], expected: int):
"""Assert amounts the proofs contain."""
assert [p["amount"] for p in proofs] == expected
assert [p.amount for p in proofs] == expected
@pytest_asyncio.fixture(scope="function")