diff --git a/cashu/core/b_dhke.py b/cashu/core/b_dhke.py index 5e9d2e5..5c83ddc 100644 --- a/cashu/core/b_dhke.py +++ b/cashu/core/b_dhke.py @@ -44,9 +44,12 @@ def hash_to_curve(message: bytes): return point -def step1_alice(secret_msg: str): +def step1_alice(secret_msg: str, blinding_factor: bytes = None): Y = hash_to_curve(secret_msg.encode("utf-8")) - r = PrivateKey() + if blinding_factor: + r = PrivateKey(privkey=blinding_factor, raw=True) + else: + r = PrivateKey() B_ = Y + r.pubkey return B_, r diff --git a/tests/test_crypto.py b/tests/test_crypto.py index 947dc2a..e10f340 100644 --- a/tests/test_crypto.py +++ b/tests/test_crypto.py @@ -1,6 +1,6 @@ import pytest - -from cashu.core.b_dhke import hash_to_curve +from cashu.core.secp import PublicKey, PrivateKey +from cashu.core.b_dhke import hash_to_curve, step1_alice, step2_bob, step3_alice def test_hash_to_curve(): @@ -26,6 +26,7 @@ def test_hash_to_curve(): def test_hash_to_curve_iteration(): + """This input causes multiple rounds of the hash_to_curve algorithm.""" result = hash_to_curve( bytes.fromhex( "0000000000000000000000000000000000000000000000000000000000000002" @@ -35,3 +36,65 @@ def test_hash_to_curve_iteration(): result.serialize().hex() == "02076c988b353fcbb748178ecb286bc9d0b4acf474d4ba31ba62334e46c97c416a" ) + + +def test_step1(): + """""" + B_, blinding_factor = step1_alice( + "test_message", blinding_factor=b"00000000000000000000000000000001" # 32 bytes + ) + + assert ( + B_.serialize().hex() + == "0243379106c73dfc635cd1422f406e83fbfa25be83bb3620aefc08f2b89d02d777" + ) + assert blinding_factor.private_key == b"00000000000000000000000000000001" + + +def test_step2(): + B_, _ = step1_alice( + "test_message", + blinding_factor=bytes.fromhex( + "0000000000000000000000000000000000000000000000000000000000000001" + ), # 32 bytes + ) + a = PrivateKey( + privkey=bytes.fromhex( + "0000000000000000000000000000000000000000000000000000000000000001" + ), + raw=True, + ) + C_ = B_.mult(a) + assert ( + C_.serialize().hex() + == "02a9acc1e48c25eeeb9289b5031cc57da9fe72f3fe2861d264bdc074209b107ba2" + ) + + +def test_step3(): + # C = C_ - A.mult(r) + C_ = PublicKey( + bytes.fromhex( + "02b15f14ae9259c101cdbc437e8877b1ca5d4af3a0c0684866b38d8c8d0b6f6374" + ), + raw=True, + ) + r = PrivateKey( + privkey=bytes.fromhex( + "0000000000000000000000000000000000000000000000000000000000000001" + ) + ) + + A = PublicKey( + pubkey=b"\x02" + + bytes.fromhex( + "0000000000000000000000000000000000000000000000000000000000000001", + ), + raw=True, + ) + C = step3_alice(C_, r, A) + + assert ( + C.serialize().hex() + == "03398f7153b381ce54d57962a5e03ce0a4f3b79755e882c972b788e8488e59b0c9" + )