Files
nutshell/cashu/core/crypto/b_dhke.py
callebtc e2c8f7f694 Add tests for domain separated h2c (#451)
* add tests for domain separated h2c

* refactor b_dhke and add domain separated test
2024-02-21 11:10:50 +01:00

244 lines
7.0 KiB
Python

# Don't trust me with cryptography.
"""
Implementation of https://gist.github.com/RubenSomsen/be7a4760dd4596d06963d67baf140406
Bob (Mint):
A = a*G
return A
Alice (Client):
Y = hash_to_curve(secret_message)
r = random blinding factor
B'= Y + r*G
return B'
Bob:
C' = a*B'
(= a*Y + a*r*G)
return C'
Alice:
C = C' - r*A
(= C' - a*r*G)
(= a*Y)
return C, secret_message
Bob:
Y = hash_to_curve(secret_message)
C == a*Y
If true, C must have originated from Bob
# DLEQ Proof
(These steps occur once Bob returns C')
Bob:
p = random nonce (distinct from Alice's blinding factor r)
R1 = p*G
R2 = p*B'
e = hash(R1,R2,A,C')
s = p + e*a
return e, s
Alice:
R1 = s*G - e*A
R2 = s*B' - e*C'
e == hash(R1,R2,A,C')
If true, a in A = a*G must be equal to a in C' = a*B'
"""
import hashlib
from typing import Optional, Tuple
from secp256k1 import PrivateKey, PublicKey
def hash_to_curve(message: bytes) -> PublicKey:
    """Map ``message`` onto a secp256k1 point (legacy, no domain separation).

    The SHA-256 digest of ``message`` is tried as the x-coordinate of a
    compressed point with the ``0x02`` prefix. If ``PublicKey`` rejects it,
    the digest itself is hashed again and the attempt repeats until a valid
    point is produced.

    Args:
        message: Arbitrary bytes to map onto the curve.

    Returns:
        The first ``PublicKey`` parsed from ``0x02 || sha256^k(message)``,
        for ``k = 1, 2, ...``.
    """
    digest = hashlib.sha256(message).digest()
    while True:
        try:
            # PublicKey raises when 0x02 || digest is not a point on the curve.
            return PublicKey(b"\x02" + digest, raw=True)
        except Exception:
            # Re-hash the previous digest (not the original message).
            digest = hashlib.sha256(digest).digest()
# Prefix hashed together with every message in hash_to_curve_domain_separated().
# Hex: 536563703235366b315f48617368546f43757276655f43617368755f
DOMAIN_SEPARATOR = b"Secp256k1_HashToCurve_Cashu_"


def hash_to_curve_domain_separated(message: bytes) -> PublicKey:
    """Map ``message`` onto a secp256k1 point using a domain separator.

    First ``msg_hash = sha256(DOMAIN_SEPARATOR || message)`` is computed. Then,
    for ``counter = 0, 1, ..., 2**16 - 1``, the candidate
    ``sha256(msg_hash || counter)`` is tried as the x-coordinate of a
    compressed point with prefix ``0x02``. Here ``counter`` is encoded as a
    4-byte little-endian unsigned integer. The first candidate that parses
    as a valid point is returned.

    Each attempt succeeds with roughly 50% probability, so exhausting all
    2**16 counters should never happen in practice.

    Args:
        message: Arbitrary bytes to map onto the curve.

    Returns:
        The resulting ``PublicKey``.

    Raises:
        ValueError: If none of the 2**16 candidates is a valid point.
    """
    msg_hash = hashlib.sha256(DOMAIN_SEPARATOR + message).digest()
    for counter in range(2**16):
        suffix = counter.to_bytes(4, "little")
        candidate = hashlib.sha256(msg_hash + suffix).digest()
        try:
            # PublicKey raises when 0x02 || candidate is not on the curve.
            return PublicKey(b"\x02" + candidate, raw=True)
        except Exception:
            continue
    # Unreachable in practice (see docstring).
    raise ValueError("No valid point found")
def step1_alice(
    secret_msg: str, blinding_factor: Optional[PrivateKey] = None
) -> tuple[PublicKey, PrivateKey]:
    """Blind a secret for the mint using the legacy ``hash_to_curve``.

    Computes ``B' = Y + r*G`` with ``Y = hash_to_curve(secret_msg)``.

    Args:
        secret_msg: The secret; UTF-8 encoded before hashing to the curve.
        blinding_factor: Optional fixed blinding factor ``r``. When omitted
            (or falsy), a fresh ``PrivateKey()`` is generated.

    Returns:
        ``(B', r)``: the blinded message and the blinding factor used.
    """
    point_y: PublicKey = hash_to_curve(secret_msg.encode("utf-8"))
    if not blinding_factor:
        blinding_factor = PrivateKey()
    return point_y + blinding_factor.pubkey, blinding_factor  # type: ignore
def step1_alice_domain_separated(
    secret_msg: str, blinding_factor: Optional[PrivateKey] = None
) -> tuple[PublicKey, PrivateKey]:
    """Blind a secret for the mint using ``hash_to_curve_domain_separated``.

    Computes ``B' = Y + r*G`` with
    ``Y = hash_to_curve_domain_separated(secret_msg)``.

    Args:
        secret_msg: The secret; UTF-8 encoded before hashing to the curve.
        blinding_factor: Optional fixed blinding factor ``r``. When omitted
            (or falsy), a fresh ``PrivateKey()`` is generated.

    Returns:
        ``(B', r)``: the blinded message and the blinding factor used.
    """
    point_y: PublicKey = hash_to_curve_domain_separated(secret_msg.encode("utf-8"))
    if not blinding_factor:
        blinding_factor = PrivateKey()
    return point_y + blinding_factor.pubkey, blinding_factor  # type: ignore
def step2_bob(B_: PublicKey, a: PrivateKey) -> Tuple[PublicKey, PrivateKey, PrivateKey]:
    """Sign a blinded message as the mint and attach a DLEQ proof.

    Args:
        B_: The blinded message ``B'`` received from the client.
        a: The mint's private key.

    Returns:
        ``(C', e, s)``. ``C' = a*B'`` is the blind signature and ``(e, s)`` is
        the DLEQ proof produced by ``step2_bob_dleq``.
    """
    blinded_sig: PublicKey = B_.mult(a)  # type: ignore
    proof_e, proof_s = step2_bob_dleq(B_, a)
    return blinded_sig, proof_e, proof_s
def step3_alice(C_: PublicKey, r: PrivateKey, A: PublicKey) -> PublicKey:
    """Unblind the mint's signature.

    Args:
        C_: The blind signature ``C'`` returned by the mint.
        r: The blinding factor used in step 1.
        A: The mint's public key.

    Returns:
        ``C = C' - r*A``, which equals ``a*Y``.
    """
    r_times_a = A.mult(r)  # type: ignore
    return C_ - r_times_a  # type: ignore
def verify(a: PrivateKey, C: PublicKey, secret_msg: str) -> bool:
    """Check that ``C == a*Y`` for the point ``Y`` derived from ``secret_msg``.

    ``Y`` is first derived with the legacy ``hash_to_curve``. If that
    comparison is falsy, the result of ``verify_domain_separated`` is
    returned instead.

    Args:
        a: The mint's private key.
        C: The unblinded signature to check.
        secret_msg: The secret the signature is claimed to be over.

    Returns:
        Whether either derivation of ``Y`` matches.
    """
    point_y: PublicKey = hash_to_curve(secret_msg.encode("utf-8"))
    valid = C == point_y.mult(a)  # type: ignore
    if valid:
        return valid
    # BEGIN: BACKWARDS COMPATIBILITY < 0.15.1
    return verify_domain_separated(a, C, secret_msg)
    # END: BACKWARDS COMPATIBILITY < 0.15.1
def verify_domain_separated(a: PrivateKey, C: PublicKey, secret_msg: str) -> bool:
    """Check ``C == a*Y`` with ``Y = hash_to_curve_domain_separated(secret_msg)``.

    Args:
        a: The mint's private key.
        C: The unblinded signature to check.
        secret_msg: The secret; UTF-8 encoded before hashing to the curve.

    Returns:
        The result of the point comparison.
    """
    encoded = secret_msg.encode("utf-8")
    expected = hash_to_curve_domain_separated(encoded).mult(a)  # type: ignore
    return C == expected
def hash_e(*publickeys: PublicKey) -> bytes:
    """Compute the DLEQ challenge hash over an ordered list of public keys.

    Each key is serialized uncompressed (``compressed=False``) and hex-encoded.
    The hex strings are concatenated in argument order, and the SHA-256 digest
    of the UTF-8 encoding of that text is returned.

    Args:
        *publickeys: Points to commit to, e.g. ``R1, R2, A, C'`` in the DLEQ proof.

    Returns:
        The 32-byte SHA-256 digest.
    """
    # The hash is taken over the hex *text*, not the raw serialized bytes.
    # alice_verify_dleq recomputes this value to check proofs from
    # step2_bob_dleq, so the exact encoding must stay stable.
    e_hex = "".join(p.serialize(compressed=False).hex() for p in publickeys)
    return hashlib.sha256(e_hex.encode("utf-8")).digest()
def step2_bob_dleq(
    B_: PublicKey, a: PrivateKey, p_bytes: bytes = b""
) -> Tuple[PrivateKey, PrivateKey]:
    """Prove that ``C' = a*B'`` uses the same ``a`` as the mint key ``A = a*G``.

    Args:
        B_: The blinded message ``B'``.
        a: The mint's private key.
        p_bytes: Optional raw nonce bytes. When non-empty they are used as the
            nonce ``p`` (deterministic, for testing). Otherwise ``p`` is a
            fresh ``PrivateKey()``.

    Returns:
        ``(e, s)`` as ``PrivateKey`` objects, with
        ``e = hash_e(R1, R2, A, C')`` and ``s = p + e*a``.
    """
    nonce = PrivateKey(privkey=p_bytes, raw=True) if p_bytes else PrivateKey()

    R1 = nonce.pubkey  # R1 = p*G
    assert R1
    R2: PublicKey = B_.mult(nonce)  # type: ignore  # R2 = p*B'
    C_: PublicKey = B_.mult(a)  # type: ignore  # C' = a*B'
    A = a.pubkey  # A = a*G
    assert A

    challenge = hash_e(R1, R2, A, C_)  # e = hash(R1, R2, A, C')
    response = nonce.tweak_add(a.tweak_mul(challenge))  # s = p + e*a

    s_key = PrivateKey(response, raw=True)
    e_key = PrivateKey(challenge, raw=True)
    return e_key, s_key
def alice_verify_dleq(
    B_: PublicKey, C_: PublicKey, e: PrivateKey, s: PrivateKey, A: PublicKey
) -> bool:
    """Verify the mint's DLEQ proof ``(e, s)`` for ``C' = a*B'`` against ``A``.

    Recomputes ``R1 = s*G - e*A`` and ``R2 = s*B' - e*C'``, then checks that
    ``e`` equals ``hash_e(R1, R2, A, C')``.

    Returns:
        True if the recomputed challenge matches ``e``.
    """
    recomputed_r1 = s.pubkey - A.mult(e)  # type: ignore
    recomputed_r2 = B_.mult(s) - C_.mult(e)  # type: ignore
    challenge = hash_e(recomputed_r1, recomputed_r2, A, C_)
    return e.private_key == challenge
def carol_verify_dleq(
    secret_msg: str,
    r: PrivateKey,
    C: PublicKey,
    e: PrivateKey,
    s: PrivateKey,
    A: PublicKey,
) -> bool:
    """Verify a DLEQ proof from unblinded token data (legacy hash first).

    Rebuilds ``Y = hash_to_curve(secret_msg)``, ``C' = C + r*A`` and
    ``B' = Y + r*G``, then checks the proof with ``alice_verify_dleq``. If
    that fails, the result of ``carol_verify_dleq_domain_separated`` is
    returned instead.
    """
    point_y: PublicKey = hash_to_curve(secret_msg.encode("utf-8"))
    blinded_sig: PublicKey = C + A.mult(r)  # type: ignore
    blinded_msg: PublicKey = point_y + r.pubkey  # type: ignore
    valid = alice_verify_dleq(blinded_msg, blinded_sig, e, s, A)
    if valid:
        return valid
    # BEGIN: BACKWARDS COMPATIBILITY < 0.15.1
    return carol_verify_dleq_domain_separated(secret_msg, r, C, e, s, A)
    # END: BACKWARDS COMPATIBILITY < 0.15.1
def carol_verify_dleq_domain_separated(
    secret_msg: str,
    r: PrivateKey,
    C: PublicKey,
    e: PrivateKey,
    s: PrivateKey,
    A: PublicKey,
) -> bool:
    """Verify a DLEQ proof from unblinded token data (domain-separated hash).

    Rebuilds ``Y = hash_to_curve_domain_separated(secret_msg)``,
    ``C' = C + r*A`` and ``B' = Y + r*G``, then delegates to
    ``alice_verify_dleq``.
    """
    point_y: PublicKey = hash_to_curve_domain_separated(secret_msg.encode("utf-8"))
    blinded_sig: PublicKey = C + A.mult(r)  # type: ignore
    blinded_msg: PublicKey = point_y + r.pubkey  # type: ignore
    return alice_verify_dleq(blinded_msg, blinded_sig, e, s, A)
# Example usage: a simple positive and negative case (run manually)
# # Bob's (mint) keys
# a = PrivateKey()
# A = a.pubkey
# secret_msg = "test"
# B_, r = step1_alice(secret_msg)
# C_, e, s = step2_bob(B_, a)  # blind signature plus DLEQ proof (e, s)
# C = step3_alice(C_, r, A)
# print("C:{}, secret_msg:{}".format(C, secret_msg))
# assert verify(a, C, secret_msg)
# assert not verify(a, C + C, secret_msg)  # adding C twice shouldn't pass
# assert not verify(a, A, secret_msg)  # A shouldn't pass
# # Test operations
# b = PrivateKey()
# B = b.pubkey
# assert -A -A + A == -A # neg
# assert B.mult(a) == A.mult(b) # a*B = A*b