Wallet: deprecate old h2c (#459)

* wallet: deprecate old hash to curve

* fix order

* added migration: untested

* recompute Y always
This commit is contained in:
callebtc
2024-02-26 23:07:13 +01:00
committed by GitHub
parent 53cd8ff016
commit 29be002495
6 changed files with 233 additions and 150 deletions

View File

@@ -101,16 +101,15 @@ class Proof(BaseModel):
time_created: Union[None, str] = "" time_created: Union[None, str] = ""
time_reserved: Union[None, str] = "" time_reserved: Union[None, str] = ""
derivation_path: Union[None, str] = "" # derivation path of the proof derivation_path: Union[None, str] = "" # derivation path of the proof
mint_id: Union[None, str] = ( mint_id: Union[
None # holds the id of the mint operation that created this proof None, str
) ] = None # holds the id of the mint operation that created this proof
melt_id: Union[None, str] = ( melt_id: Union[
None # holds the id of the melt operation that destroyed this proof None, str
) ] = None # holds the id of the melt operation that destroyed this proof
def __init__(self, **data): def __init__(self, **data):
super().__init__(**data) super().__init__(**data)
if not self.Y:
self.Y = hash_to_curve(self.secret.encode("utf-8")).serialize().hex() self.Y = hash_to_curve(self.secret.encode("utf-8")).serialize().hex()
@classmethod @classmethod
@@ -274,7 +273,6 @@ class MintQuote(BaseModel):
@classmethod @classmethod
def from_row(cls, row: Row): def from_row(cls, row: Row):
try: try:
# SQLITE: row is timestamp (string) # SQLITE: row is timestamp (string)
created_time = int(row["created_time"]) if row["created_time"] else None created_time = int(row["created_time"]) if row["created_time"] else None
@@ -664,9 +662,9 @@ class WalletKeyset:
self.id = id self.id = id
def serialize(self): def serialize(self):
return json.dumps({ return json.dumps(
amount: key.serialize().hex() for amount, key in self.public_keys.items() {amount: key.serialize().hex() for amount, key in self.public_keys.items()}
}) )
@classmethod @classmethod
def from_row(cls, row: Row): def from_row(cls, row: Row):

View File

@@ -55,26 +55,10 @@ from typing import Optional, Tuple
from secp256k1 import PrivateKey, PublicKey from secp256k1 import PrivateKey, PublicKey
def hash_to_curve(message: bytes) -> PublicKey:
"""Generates a point from the message hash and checks if the point lies on the curve.
If it does not, iteratively tries to compute a new point from the hash."""
point = None
msg_to_hash = message
while point is None:
_hash = hashlib.sha256(msg_to_hash).digest()
try:
# will error if point does not lie on curve
point = PublicKey(b"\x02" + _hash, raw=True)
except Exception:
msg_to_hash = _hash
return point
DOMAIN_SEPARATOR = b"Secp256k1_HashToCurve_Cashu_" DOMAIN_SEPARATOR = b"Secp256k1_HashToCurve_Cashu_"
def hash_to_curve_domain_separated(message: bytes) -> PublicKey: def hash_to_curve(message: bytes) -> PublicKey:
"""Generates a secp256k1 point from a message. """Generates a secp256k1 point from a message.
The point is generated by hashing the message with a domain separator and then The point is generated by hashing the message with a domain separator and then
@@ -110,15 +94,6 @@ def step1_alice(
return B_, r return B_, r
def step1_alice_domain_separated(
secret_msg: str, blinding_factor: Optional[PrivateKey] = None
) -> tuple[PublicKey, PrivateKey]:
Y: PublicKey = hash_to_curve_domain_separated(secret_msg.encode("utf-8"))
r = blinding_factor or PrivateKey()
B_: PublicKey = Y + r.pubkey # type: ignore
return B_, r
def step2_bob(B_: PublicKey, a: PrivateKey) -> Tuple[PublicKey, PrivateKey, PrivateKey]: def step2_bob(B_: PublicKey, a: PrivateKey) -> Tuple[PublicKey, PrivateKey, PrivateKey]:
C_: PublicKey = B_.mult(a) # type: ignore C_: PublicKey = B_.mult(a) # type: ignore
# produce dleq proof # produce dleq proof
@@ -136,17 +111,11 @@ def verify(a: PrivateKey, C: PublicKey, secret_msg: str) -> bool:
valid = C == Y.mult(a) # type: ignore valid = C == Y.mult(a) # type: ignore
# BEGIN: BACKWARDS COMPATIBILITY < 0.15.1 # BEGIN: BACKWARDS COMPATIBILITY < 0.15.1
if not valid: if not valid:
valid = verify_domain_separated(a, C, secret_msg) valid = verify_deprecated(a, C, secret_msg)
# END: BACKWARDS COMPATIBILITY < 0.15.1 # END: BACKWARDS COMPATIBILITY < 0.15.1
return valid return valid
def verify_domain_separated(a: PrivateKey, C: PublicKey, secret_msg: str) -> bool:
Y: PublicKey = hash_to_curve_domain_separated(secret_msg.encode("utf-8"))
valid = C == Y.mult(a) # type: ignore
return valid
def hash_e(*publickeys: PublicKey) -> bytes: def hash_e(*publickeys: PublicKey) -> bytes:
e_ = "" e_ = ""
for p in publickeys: for p in publickeys:
@@ -202,12 +171,45 @@ def carol_verify_dleq(
valid = alice_verify_dleq(B_, C_, e, s, A) valid = alice_verify_dleq(B_, C_, e, s, A)
# BEGIN: BACKWARDS COMPATIBILITY < 0.15.1 # BEGIN: BACKWARDS COMPATIBILITY < 0.15.1
if not valid: if not valid:
return carol_verify_dleq_domain_separated(secret_msg, r, C, e, s, A) return carol_verify_dleq_deprecated(secret_msg, r, C, e, s, A)
# END: BACKWARDS COMPATIBILITY < 0.15.1 # END: BACKWARDS COMPATIBILITY < 0.15.1
return valid return valid
def carol_verify_dleq_domain_separated( # -------- Deprecated hash_to_curve before 0.15.0 --------
def hash_to_curve_deprecated(message: bytes) -> PublicKey:
"""Generates a point from the message hash and checks if the point lies on the curve.
If it does not, iteratively tries to compute a new point from the hash."""
point = None
msg_to_hash = message
while point is None:
_hash = hashlib.sha256(msg_to_hash).digest()
try:
# will error if point does not lie on curve
point = PublicKey(b"\x02" + _hash, raw=True)
except Exception:
msg_to_hash = _hash
return point
def step1_alice_deprecated(
secret_msg: str, blinding_factor: Optional[PrivateKey] = None
) -> tuple[PublicKey, PrivateKey]:
Y: PublicKey = hash_to_curve_deprecated(secret_msg.encode("utf-8"))
r = blinding_factor or PrivateKey()
B_: PublicKey = Y + r.pubkey # type: ignore
return B_, r
def verify_deprecated(a: PrivateKey, C: PublicKey, secret_msg: str) -> bool:
Y: PublicKey = hash_to_curve_deprecated(secret_msg.encode("utf-8"))
valid = C == Y.mult(a) # type: ignore
return valid
def carol_verify_dleq_deprecated(
secret_msg: str, secret_msg: str,
r: PrivateKey, r: PrivateKey,
C: PublicKey, C: PublicKey,
@@ -215,7 +217,7 @@ def carol_verify_dleq_domain_separated(
s: PrivateKey, s: PrivateKey,
A: PublicKey, A: PublicKey,
) -> bool: ) -> bool:
Y: PublicKey = hash_to_curve_domain_separated(secret_msg.encode("utf-8")) Y: PublicKey = hash_to_curve_deprecated(secret_msg.encode("utf-8"))
C_: PublicKey = C + A.mult(r) # type: ignore C_: PublicKey = C + A.mult(r) # type: ignore
B_: PublicKey = Y + r.pubkey # type: ignore B_: PublicKey = Y + r.pubkey # type: ignore
valid = alice_verify_dleq(B_, C_, e, s, A) valid = alice_verify_dleq(B_, C_, e, s, A)

View File

@@ -123,7 +123,7 @@ class WalletSettings(CashuSettings):
mint_port: int = Field(default=3338) mint_port: int = Field(default=3338)
wallet_name: str = Field(default="wallet") wallet_name: str = Field(default="wallet")
wallet_unit: str = Field(default="sat") wallet_unit: str = Field(default="sat")
wallet_domain_separation: bool = Field(default=False) wallet_use_deprecated_h2c: bool = Field(default=False)
api_port: int = Field(default=4448) api_port: int = Field(default=4448)
api_host: str = Field(default="127.0.0.1") api_host: str = Field(default="127.0.0.1")

View File

@@ -4,17 +4,20 @@ from ..core.settings import settings
async def m000_create_migrations_table(conn: Connection): async def m000_create_migrations_table(conn: Connection):
await conn.execute(f""" await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {table_with_schema(conn, 'dbversions')} ( CREATE TABLE IF NOT EXISTS {table_with_schema(conn, 'dbversions')} (
db TEXT PRIMARY KEY, db TEXT PRIMARY KEY,
version INT NOT NULL version INT NOT NULL
) )
""") """
)
async def m001_initial(db: Database): async def m001_initial(db: Database):
async with db.connect() as conn: async with db.connect() as conn:
await conn.execute(f""" await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {table_with_schema(db, 'promises')} ( CREATE TABLE IF NOT EXISTS {table_with_schema(db, 'promises')} (
amount {db.big_int} NOT NULL, amount {db.big_int} NOT NULL,
B_b TEXT NOT NULL, B_b TEXT NOT NULL,
@@ -23,9 +26,11 @@ async def m001_initial(db: Database):
UNIQUE (B_b) UNIQUE (B_b)
); );
""") """
)
await conn.execute(f""" await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {table_with_schema(db, 'proofs_used')} ( CREATE TABLE IF NOT EXISTS {table_with_schema(db, 'proofs_used')} (
amount {db.big_int} NOT NULL, amount {db.big_int} NOT NULL,
C TEXT NOT NULL, C TEXT NOT NULL,
@@ -34,9 +39,11 @@ async def m001_initial(db: Database):
UNIQUE (secret) UNIQUE (secret)
); );
""") """
)
await conn.execute(f""" await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {table_with_schema(db, 'invoices')} ( CREATE TABLE IF NOT EXISTS {table_with_schema(db, 'invoices')} (
amount {db.big_int} NOT NULL, amount {db.big_int} NOT NULL,
pr TEXT NOT NULL, pr TEXT NOT NULL,
@@ -46,7 +53,8 @@ async def m001_initial(db: Database):
UNIQUE (hash) UNIQUE (hash)
); );
""") """
)
async def drop_balance_views(db: Database, conn: Connection): async def drop_balance_views(db: Database, conn: Connection):
@@ -58,32 +66,38 @@ async def drop_balance_views(db: Database, conn: Connection):
async def create_balance_views(db: Database, conn: Connection): async def create_balance_views(db: Database, conn: Connection):
await conn.execute(f""" await conn.execute(
f"""
CREATE VIEW {table_with_schema(db, 'balance_issued')} AS CREATE VIEW {table_with_schema(db, 'balance_issued')} AS
SELECT COALESCE(SUM(s), 0) AS balance FROM ( SELECT COALESCE(SUM(s), 0) AS balance FROM (
SELECT SUM(amount) AS s SELECT SUM(amount) AS s
FROM {table_with_schema(db, 'promises')} FROM {table_with_schema(db, 'promises')}
WHERE amount > 0 WHERE amount > 0
) AS balance_issued; ) AS balance_issued;
""") """
)
await conn.execute(f""" await conn.execute(
f"""
CREATE VIEW {table_with_schema(db, 'balance_redeemed')} AS CREATE VIEW {table_with_schema(db, 'balance_redeemed')} AS
SELECT COALESCE(SUM(s), 0) AS balance FROM ( SELECT COALESCE(SUM(s), 0) AS balance FROM (
SELECT SUM(amount) AS s SELECT SUM(amount) AS s
FROM {table_with_schema(db, 'proofs_used')} FROM {table_with_schema(db, 'proofs_used')}
WHERE amount > 0 WHERE amount > 0
) AS balance_redeemed; ) AS balance_redeemed;
""") """
)
await conn.execute(f""" await conn.execute(
f"""
CREATE VIEW {table_with_schema(db, 'balance')} AS CREATE VIEW {table_with_schema(db, 'balance')} AS
SELECT s_issued - s_used FROM ( SELECT s_issued - s_used FROM (
SELECT bi.balance AS s_issued, bu.balance AS s_used SELECT bi.balance AS s_issued, bu.balance AS s_used
FROM {table_with_schema(db, 'balance_issued')} bi FROM {table_with_schema(db, 'balance_issued')} bi
CROSS JOIN {table_with_schema(db, 'balance_redeemed')} bu CROSS JOIN {table_with_schema(db, 'balance_redeemed')} bu
) AS balance; ) AS balance;
""") """
)
async def m002_add_balance_views(db: Database): async def m002_add_balance_views(db: Database):
@@ -96,7 +110,8 @@ async def m003_mint_keysets(db: Database):
Stores mint keysets from different mints and epochs. Stores mint keysets from different mints and epochs.
""" """
async with db.connect() as conn: async with db.connect() as conn:
await conn.execute(f""" await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {table_with_schema(db, 'keysets')} ( CREATE TABLE IF NOT EXISTS {table_with_schema(db, 'keysets')} (
id TEXT NOT NULL, id TEXT NOT NULL,
derivation_path TEXT, derivation_path TEXT,
@@ -108,8 +123,10 @@ async def m003_mint_keysets(db: Database):
UNIQUE (derivation_path) UNIQUE (derivation_path)
); );
""") """
await conn.execute(f""" )
await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {table_with_schema(db, 'mint_pubkeys')} ( CREATE TABLE IF NOT EXISTS {table_with_schema(db, 'mint_pubkeys')} (
id TEXT NOT NULL, id TEXT NOT NULL,
amount INTEGER NOT NULL, amount INTEGER NOT NULL,
@@ -118,7 +135,8 @@ async def m003_mint_keysets(db: Database):
UNIQUE (id, pubkey) UNIQUE (id, pubkey)
); );
""") """
)
async def m004_keysets_add_version(db: Database): async def m004_keysets_add_version(db: Database):
@@ -136,7 +154,8 @@ async def m005_pending_proofs_table(db: Database) -> None:
Store pending proofs. Store pending proofs.
""" """
async with db.connect() as conn: async with db.connect() as conn:
await conn.execute(f""" await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {table_with_schema(db, 'proofs_pending')} ( CREATE TABLE IF NOT EXISTS {table_with_schema(db, 'proofs_pending')} (
amount INTEGER NOT NULL, amount INTEGER NOT NULL,
C TEXT NOT NULL, C TEXT NOT NULL,
@@ -145,7 +164,8 @@ async def m005_pending_proofs_table(db: Database) -> None:
UNIQUE (secret) UNIQUE (secret)
); );
""") """
)
async def m006_invoices_add_payment_hash(db: Database): async def m006_invoices_add_payment_hash(db: Database):
@@ -255,7 +275,8 @@ async def m011_add_quote_tables(db: Database):
f" '{settings.mint_private_key}', unit = 'sat'" f" '{settings.mint_private_key}', unit = 'sat'"
) )
await conn.execute(f""" await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {table_with_schema(db, 'mint_quotes')} ( CREATE TABLE IF NOT EXISTS {table_with_schema(db, 'mint_quotes')} (
quote TEXT NOT NULL, quote TEXT NOT NULL,
method TEXT NOT NULL, method TEXT NOT NULL,
@@ -271,9 +292,11 @@ async def m011_add_quote_tables(db: Database):
UNIQUE (quote) UNIQUE (quote)
); );
""") """
)
await conn.execute(f""" await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {table_with_schema(db, 'melt_quotes')} ( CREATE TABLE IF NOT EXISTS {table_with_schema(db, 'melt_quotes')} (
quote TEXT NOT NULL, quote TEXT NOT NULL,
method TEXT NOT NULL, method TEXT NOT NULL,
@@ -291,7 +314,8 @@ async def m011_add_quote_tables(db: Database):
UNIQUE (quote) UNIQUE (quote)
); );
""") """
)
await conn.execute( await conn.execute(
f"INSERT INTO {table_with_schema(db, 'mint_quotes')} (quote, method," f"INSERT INTO {table_with_schema(db, 'mint_quotes')} (quote, method,"
@@ -318,7 +342,8 @@ async def m012_keysets_uniqueness_with_seed(db: Database):
f" SELECT * FROM {table_with_schema(db, 'keysets')}" f" SELECT * FROM {table_with_schema(db, 'keysets')}"
) )
await conn.execute(f"DROP TABLE {table_with_schema(db, 'keysets')}") await conn.execute(f"DROP TABLE {table_with_schema(db, 'keysets')}")
await conn.execute(f""" await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {table_with_schema(db, 'keysets')} ( CREATE TABLE IF NOT EXISTS {table_with_schema(db, 'keysets')} (
id TEXT NOT NULL, id TEXT NOT NULL,
derivation_path TEXT, derivation_path TEXT,
@@ -333,7 +358,8 @@ async def m012_keysets_uniqueness_with_seed(db: Database):
UNIQUE (seed, derivation_path) UNIQUE (seed, derivation_path)
); );
""") """
)
await conn.execute( await conn.execute(
f"INSERT INTO {table_with_schema(db, 'keysets')} (id," f"INSERT INTO {table_with_schema(db, 'keysets')} (id,"
" derivation_path, valid_from, valid_to, first_seen," " derivation_path, valid_from, valid_to, first_seen,"
@@ -358,7 +384,8 @@ async def m013_keysets_add_encrypted_seed(db: Database):
f" SELECT * FROM {table_with_schema(db, 'keysets')}" f" SELECT * FROM {table_with_schema(db, 'keysets')}"
) )
await conn.execute(f"DROP TABLE {table_with_schema(db, 'keysets')}") await conn.execute(f"DROP TABLE {table_with_schema(db, 'keysets')}")
await conn.execute(f""" await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {table_with_schema(db, 'keysets')} ( CREATE TABLE IF NOT EXISTS {table_with_schema(db, 'keysets')} (
id TEXT NOT NULL, id TEXT NOT NULL,
derivation_path TEXT, derivation_path TEXT,
@@ -373,7 +400,8 @@ async def m013_keysets_add_encrypted_seed(db: Database):
UNIQUE (id) UNIQUE (id)
); );
""") """
)
await conn.execute( await conn.execute(
f"INSERT INTO {table_with_schema(db, 'keysets')} (id," f"INSERT INTO {table_with_schema(db, 'keysets')} (id,"
" derivation_path, valid_from, valid_to, first_seen," " derivation_path, valid_from, valid_to, first_seen,"
@@ -430,7 +458,8 @@ async def m014_proofs_add_Y_column(db: Database):
f" SELECT * FROM {table_with_schema(db, 'proofs_used')}" f" SELECT * FROM {table_with_schema(db, 'proofs_used')}"
) )
await conn.execute(f"DROP TABLE {table_with_schema(db, 'proofs_used')}") await conn.execute(f"DROP TABLE {table_with_schema(db, 'proofs_used')}")
await conn.execute(f""" await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {table_with_schema(db, 'proofs_used')} ( CREATE TABLE IF NOT EXISTS {table_with_schema(db, 'proofs_used')} (
amount INTEGER NOT NULL, amount INTEGER NOT NULL,
C TEXT NOT NULL, C TEXT NOT NULL,
@@ -443,7 +472,8 @@ async def m014_proofs_add_Y_column(db: Database):
UNIQUE (Y) UNIQUE (Y)
); );
""") """
)
await conn.execute( await conn.execute(
f"INSERT INTO {table_with_schema(db, 'proofs_used')} (amount, C, " f"INSERT INTO {table_with_schema(db, 'proofs_used')} (amount, C, "
"secret, id, Y, created, witness) SELECT amount, C, secret, id, Y," "secret, id, Y, created, witness) SELECT amount, C, secret, id, Y,"
@@ -474,7 +504,8 @@ async def m014_proofs_add_Y_column(db: Database):
) )
await conn.execute(f"DROP TABLE {table_with_schema(db, 'proofs_pending')}") await conn.execute(f"DROP TABLE {table_with_schema(db, 'proofs_pending')}")
await conn.execute(f""" await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {table_with_schema(db, 'proofs_pending')} ( CREATE TABLE IF NOT EXISTS {table_with_schema(db, 'proofs_pending')} (
amount INTEGER NOT NULL, amount INTEGER NOT NULL,
C TEXT NOT NULL, C TEXT NOT NULL,
@@ -486,7 +517,8 @@ async def m014_proofs_add_Y_column(db: Database):
UNIQUE (Y) UNIQUE (Y)
); );
""") """
)
await conn.execute( await conn.execute(
f"INSERT INTO {table_with_schema(db, 'proofs_pending')} (amount, C, " f"INSERT INTO {table_with_schema(db, 'proofs_pending')} (amount, C, "
"secret, Y, id, created) SELECT amount, C, secret, Y, id, created" "secret, Y, id, created) SELECT amount, C, secret, Y, id, created"
@@ -507,3 +539,31 @@ async def m015_add_index_Y_to_proofs_used(db: Database):
" proofs_used_Y_idx ON" " proofs_used_Y_idx ON"
f" {table_with_schema(db, 'proofs_used')} (Y)" f" {table_with_schema(db, 'proofs_used')} (Y)"
) )
async def m016_recompute_Y_with_new_h2c(db: Database):
# get all proofs_used and proofs_pending from the database and compute Y for each of them
async with db.connect() as conn:
rows = await conn.fetchall(
f"SELECT * FROM {table_with_schema(db, 'proofs_used')}"
)
# Proof() will compute Y from secret upon initialization
proofs_used = [Proof(**r) for r in rows]
rows = await conn.fetchall(
f"SELECT * FROM {table_with_schema(db, 'proofs_pending')}"
)
proofs_pending = [Proof(**r) for r in rows]
# overwrite the old Y columns with the new Y
async with db.connect() as conn:
for proof in proofs_used:
await conn.execute(
f"UPDATE {table_with_schema(db, 'proofs_used')} SET Y = '{proof.Y}'"
f" WHERE secret = '{proof.secret}'"
)
for proof in proofs_pending:
await conn.execute(
f"UPDATE {table_with_schema(db, 'proofs_pending')} SET Y = '{proof.Y}'"
f" WHERE secret = '{proof.secret}'"
)

View File

@@ -998,9 +998,11 @@ class Wallet(LedgerAPI, WalletP2PK, WalletHTLC, WalletSecrets):
# NUT-08, the mint will imprint these outputs with a value depending on the # NUT-08, the mint will imprint these outputs with a value depending on the
# amount of fees we overpaid. # amount of fees we overpaid.
n_change_outputs = calculate_number_of_blank_outputs(fee_reserve_sat) n_change_outputs = calculate_number_of_blank_outputs(fee_reserve_sat)
change_secrets, change_rs, change_derivation_paths = ( (
await self.generate_n_secrets(n_change_outputs) change_secrets,
) change_rs,
change_derivation_paths,
) = await self.generate_n_secrets(n_change_outputs)
change_outputs, change_rs = self._construct_outputs( change_outputs, change_rs = self._construct_outputs(
n_change_outputs * [1], change_secrets, change_rs n_change_outputs * [1], change_secrets, change_rs
) )
@@ -1126,14 +1128,15 @@ class Wallet(LedgerAPI, WalletP2PK, WalletHTLC, WalletSecrets):
C = b_dhke.step3_alice( C = b_dhke.step3_alice(
C_, r, self.keysets[promise.id].public_keys[promise.amount] C_, r, self.keysets[promise.id].public_keys[promise.amount]
) )
# BEGIN: BACKWARDS COMPATIBILITY < 0.15.1
if not settings.wallet_domain_separation: if not settings.wallet_use_deprecated_h2c:
B_, r = b_dhke.step1_alice(secret, r) # recompute B_ for dleq proofs B_, r = b_dhke.step1_alice(secret, r) # recompute B_ for dleq proofs
# END: BACKWARDS COMPATIBILITY < 0.15.1 # BEGIN: BACKWARDS COMPATIBILITY < 0.15.1
else: else:
B_, r = b_dhke.step1_alice_domain_separated( B_, r = b_dhke.step1_alice_deprecated(
secret, r secret, r
) # recompute B_ for dleq proofs ) # recompute B_ for dleq proofs
# END: BACKWARDS COMPATIBILITY < 0.15.1
proof = Proof( proof = Proof(
id=promise.id, id=promise.id,
@@ -1196,12 +1199,13 @@ class Wallet(LedgerAPI, WalletP2PK, WalletHTLC, WalletSecrets):
rs_ = [None] * len(amounts) if not rs else rs rs_ = [None] * len(amounts) if not rs else rs
rs_return: List[PrivateKey] = [] rs_return: List[PrivateKey] = []
for secret, amount, r in zip(secrets, amounts, rs_): for secret, amount, r in zip(secrets, amounts, rs_):
# BEGIN: BACKWARDS COMPATIBILITY < 0.15.1 if not settings.wallet_use_deprecated_h2c:
if not settings.wallet_domain_separation:
B_, r = b_dhke.step1_alice(secret, r or None) B_, r = b_dhke.step1_alice(secret, r or None)
# END: BACKWARDS COMPATIBILITY < 0.15.1 # BEGIN: BACKWARDS COMPATIBILITY < 0.15.1
else: else:
B_, r = b_dhke.step1_alice_domain_separated(secret, r or None) B_, r = b_dhke.step1_alice_deprecated(secret, r or None)
# END: BACKWARDS COMPATIBILITY < 0.15.1
rs_return.append(r) rs_return.append(r)
output = BlindedMessage( output = BlindedMessage(
amount=amount, B_=B_.serialize().hex(), id=self.keyset_id amount=amount, B_=B_.serialize().hex(), id=self.keyset_id

View File

@@ -1,12 +1,11 @@
from cashu.core.crypto.b_dhke import ( from cashu.core.crypto.b_dhke import (
alice_verify_dleq, alice_verify_dleq,
carol_verify_dleq, carol_verify_dleq,
carol_verify_dleq_domain_separated,
hash_e, hash_e,
hash_to_curve, hash_to_curve,
hash_to_curve_domain_separated, hash_to_curve_deprecated,
step1_alice, step1_alice,
step1_alice_domain_separated, step1_alice_deprecated,
step2_bob, step2_bob,
step2_bob_dleq, step2_bob_dleq,
step3_alice, step3_alice,
@@ -22,9 +21,11 @@ def test_hash_to_curve():
) )
assert ( assert (
result.serialize().hex() result.serialize().hex()
== "0266687aadf862bd776c8fc18b8e9f8e20089714856ee233b3902a591d0d5f2925" == "024cce997d3b518f739663b757deaec95bcd9473c30a14ac2fd04023a739d1a725"
) )
def test_hash_to_curve_iteration():
result = hash_to_curve( result = hash_to_curve(
bytes.fromhex( bytes.fromhex(
"0000000000000000000000000000000000000000000000000000000000000001" "0000000000000000000000000000000000000000000000000000000000000001"
@@ -32,20 +33,7 @@ def test_hash_to_curve():
) )
assert ( assert (
result.serialize().hex() result.serialize().hex()
== "02ec4916dd28fc4c10d78e287ca5d9cc51ee1ae73cbfde08c6b37324cbfaac8bc5" == "022e7158e11c9506f1aa4248bf531298daa7febd6194f003edcd9b93ade6253acf"
)
def test_hash_to_curve_iteration():
"""This input causes multiple rounds of the hash_to_curve algorithm."""
result = hash_to_curve(
bytes.fromhex(
"0000000000000000000000000000000000000000000000000000000000000002"
)
)
assert (
result.serialize().hex()
== "02076c988b353fcbb748178ecb286bc9d0b4acf474d4ba31ba62334e46c97c416a"
) )
@@ -62,7 +50,7 @@ def test_step1():
assert ( assert (
B_.serialize().hex() B_.serialize().hex()
== "02a9acc1e48c25eeeb9289b5031cc57da9fe72f3fe2861d264bdc074209b107ba2" == "025cc16fe33b953e2ace39653efb3e7a7049711ae1d8a2f7a9108753f1cdea742b"
) )
assert blinding_factor.private_key == bytes.fromhex( assert blinding_factor.private_key == bytes.fromhex(
"0000000000000000000000000000000000000000000000000000000000000001" "0000000000000000000000000000000000000000000000000000000000000001"
@@ -88,7 +76,7 @@ def test_step2():
C_, e, s = step2_bob(B_, a) C_, e, s = step2_bob(B_, a)
assert ( assert (
C_.serialize().hex() C_.serialize().hex()
== "02a9acc1e48c25eeeb9289b5031cc57da9fe72f3fe2861d264bdc074209b107ba2" == "025cc16fe33b953e2ace39653efb3e7a7049711ae1d8a2f7a9108753f1cdea742b"
) )
@@ -97,7 +85,7 @@ def test_step3():
# C_ from test_step2 # C_ from test_step2
C_ = PublicKey( C_ = PublicKey(
bytes.fromhex( bytes.fromhex(
"02a9acc1e48c25eeeb9289b5031cc57da9fe72f3fe2861d264bdc074209b107ba2" "025cc16fe33b953e2ace39653efb3e7a7049711ae1d8a2f7a9108753f1cdea742b"
), ),
raw=True, raw=True,
) )
@@ -118,7 +106,7 @@ def test_step3():
assert ( assert (
C.serialize().hex() C.serialize().hex()
== "03c724d7e6a5443b39ac8acf11f40420adc4f99a02e7cc1b57703d9391f6d129cd" == "0271bf0d702dbad86cbe0af3ab2bfba70a0338f22728e412d88a830ed0580b9de4"
) )
@@ -176,11 +164,11 @@ def test_dleq_step2_bob_dleq():
e, s = step2_bob_dleq(B_, a, p_bytes) e, s = step2_bob_dleq(B_, a, p_bytes)
assert ( assert (
e.serialize() e.serialize()
== "9818e061ee51d5c8edc3342369a554998ff7b4381c8652d724cdf46429be73d9" == "a608ae30a54c6d878c706240ee35d4289b68cfe99454bbfa6578b503bce2dbe1"
) )
assert ( assert (
s.serialize() s.serialize()
== "9818e061ee51d5c8edc3342369a554998ff7b4381c8652d724cdf46429be73da" == "a608ae30a54c6d878c706240ee35d4289b68cfe99454bbfa6578b503bce2dbe2"
) # differs from e only in least significant byte because `a = 0x1` ) # differs from e only in least significant byte because `a = 0x1`
# change `a` # change `a`
@@ -193,11 +181,11 @@ def test_dleq_step2_bob_dleq():
e, s = step2_bob_dleq(B_, a, p_bytes) e, s = step2_bob_dleq(B_, a, p_bytes)
assert ( assert (
e.serialize() e.serialize()
== "df1984d5c22f7e17afe33b8669f02f530f286ae3b00a1978edaf900f4721f65e" == "076cbdda4f368053c33056c438df014d1875eb3c8b28120bece74b6d0e6381bb"
) )
assert ( assert (
s.serialize() s.serialize()
== "828404170c86f240c50ae0f5fc17bb6b82612d46b355e046d7cd84b0a3c934a0" == "b6d41ac1e12415862bf8cace95e5355e9262eab8a11d201dadd3b6e41584ea6e"
) )
@@ -306,36 +294,47 @@ def test_dleq_carol_verify_from_bob():
assert carol_verify_dleq(secret_msg=secret_msg, C=C, r=r, e=e, s=s, A=A) assert carol_verify_dleq(secret_msg=secret_msg, C=C, r=r, e=e, s=s, A=A)
# TESTS FOR DOMAIN SEPARATED HASH TO CURVE # TESTS FOR DEPRECATED HASH TO CURVE
def test_hash_to_curve_domain_separated(): def test_hash_to_curve_deprecated():
result = hash_to_curve_domain_separated( result = hash_to_curve_deprecated(
bytes.fromhex( bytes.fromhex(
"0000000000000000000000000000000000000000000000000000000000000000" "0000000000000000000000000000000000000000000000000000000000000000"
) )
) )
assert ( assert (
result.serialize().hex() result.serialize().hex()
== "024cce997d3b518f739663b757deaec95bcd9473c30a14ac2fd04023a739d1a725" == "0266687aadf862bd776c8fc18b8e9f8e20089714856ee233b3902a591d0d5f2925"
) )
result = hash_to_curve_deprecated(
def test_hash_to_curve_domain_separated_iterative():
result = hash_to_curve_domain_separated(
bytes.fromhex( bytes.fromhex(
"0000000000000000000000000000000000000000000000000000000000000001" "0000000000000000000000000000000000000000000000000000000000000001"
) )
) )
assert ( assert (
result.serialize().hex() result.serialize().hex()
== "022e7158e11c9506f1aa4248bf531298daa7febd6194f003edcd9b93ade6253acf" == "02ec4916dd28fc4c10d78e287ca5d9cc51ee1ae73cbfde08c6b37324cbfaac8bc5"
) )
def test_step1_domain_separated(): def test_hash_to_curve_iteration_deprecated():
"""This input causes multiple rounds of the hash_to_curve algorithm."""
result = hash_to_curve_deprecated(
bytes.fromhex(
"0000000000000000000000000000000000000000000000000000000000000002"
)
)
assert (
result.serialize().hex()
== "02076c988b353fcbb748178ecb286bc9d0b4acf474d4ba31ba62334e46c97c416a"
)
def test_step1_deprecated():
secret_msg = "test_message" secret_msg = "test_message"
B_, blinding_factor = step1_alice_domain_separated( B_, blinding_factor = step1_alice_deprecated(
secret_msg, secret_msg,
blinding_factor=PrivateKey( blinding_factor=PrivateKey(
privkey=bytes.fromhex( privkey=bytes.fromhex(
@@ -346,15 +345,15 @@ def test_step1_domain_separated():
assert ( assert (
B_.serialize().hex() B_.serialize().hex()
== "025cc16fe33b953e2ace39653efb3e7a7049711ae1d8a2f7a9108753f1cdea742b" == "02a9acc1e48c25eeeb9289b5031cc57da9fe72f3fe2861d264bdc074209b107ba2"
) )
assert blinding_factor.private_key == bytes.fromhex( assert blinding_factor.private_key == bytes.fromhex(
"0000000000000000000000000000000000000000000000000000000000000001" "0000000000000000000000000000000000000000000000000000000000000001"
) )
def test_step2_domain_separated(): def test_step2_deprecated():
B_, _ = step1_alice_domain_separated( B_, _ = step1_alice_deprecated(
"test_message", "test_message",
blinding_factor=PrivateKey( blinding_factor=PrivateKey(
privkey=bytes.fromhex( privkey=bytes.fromhex(
@@ -372,16 +371,16 @@ def test_step2_domain_separated():
C_, e, s = step2_bob(B_, a) C_, e, s = step2_bob(B_, a)
assert ( assert (
C_.serialize().hex() C_.serialize().hex()
== "025cc16fe33b953e2ace39653efb3e7a7049711ae1d8a2f7a9108753f1cdea742b" == "02a9acc1e48c25eeeb9289b5031cc57da9fe72f3fe2861d264bdc074209b107ba2"
) )
def test_step3_domain_separated(): def test_step3_deprecated():
# C = C_ - A.mult(r) # C = C_ - A.mult(r)
# C_ from test_step2 # C_ from test_step2_deprecated
C_ = PublicKey( C_ = PublicKey(
bytes.fromhex( bytes.fromhex(
"025cc16fe33b953e2ace39653efb3e7a7049711ae1d8a2f7a9108753f1cdea742b" "02a9acc1e48c25eeeb9289b5031cc57da9fe72f3fe2861d264bdc074209b107ba2"
), ),
raw=True, raw=True,
) )
@@ -402,32 +401,52 @@ def test_step3_domain_separated():
assert ( assert (
C.serialize().hex() C.serialize().hex()
== "0271bf0d702dbad86cbe0af3ab2bfba70a0338f22728e412d88a830ed0580b9de4" == "03c724d7e6a5443b39ac8acf11f40420adc4f99a02e7cc1b57703d9391f6d129cd"
) )
def test_dleq_carol_verify_from_bob_domain_separated(): def test_dleq_step2_bob_dleq_deprecated():
B_, _ = step1_alice_deprecated(
"test_message",
blinding_factor=PrivateKey(
privkey=bytes.fromhex(
"0000000000000000000000000000000000000000000000000000000000000001"
),
raw=True,
),
)
a = PrivateKey( a = PrivateKey(
privkey=bytes.fromhex( privkey=bytes.fromhex(
"0000000000000000000000000000000000000000000000000000000000000001" "0000000000000000000000000000000000000000000000000000000000000001"
), ),
raw=True, raw=True,
) )
A = a.pubkey p_bytes = bytes.fromhex(
assert A
secret_msg = "test_message"
r = PrivateKey(
privkey=bytes.fromhex(
"0000000000000000000000000000000000000000000000000000000000000001" "0000000000000000000000000000000000000000000000000000000000000001"
) # 32 bytes
e, s = step2_bob_dleq(B_, a, p_bytes)
assert (
e.serialize()
== "9818e061ee51d5c8edc3342369a554998ff7b4381c8652d724cdf46429be73d9"
)
assert (
s.serialize()
== "9818e061ee51d5c8edc3342369a554998ff7b4381c8652d724cdf46429be73da"
) # differs from e only in least significant byte because `a = 0x1`
# change `a`
a = PrivateKey(
privkey=bytes.fromhex(
"0000000000000000000000000000000000000000000000000000000000001111"
), ),
raw=True, raw=True,
) )
B_, _ = step1_alice_domain_separated(secret_msg, r) e, s = step2_bob_dleq(B_, a, p_bytes)
C_, e, s = step2_bob(B_, a) assert (
assert alice_verify_dleq(B_, C_, e, s, A) e.serialize()
C = step3_alice(C_, r, A) == "df1984d5c22f7e17afe33b8669f02f530f286ae3b00a1978edaf900f4721f65e"
)
# carol does not know B_ and C_, but she receives C and r from Alice assert (
assert carol_verify_dleq_domain_separated( s.serialize()
secret_msg=secret_msg, C=C, r=r, e=e, s=s, A=A == "828404170c86f240c50ae0f5fc17bb6b82612d46b355e046d7cd84b0a3c934a0"
) )