From 4eb8afe516c5eda57f041d3a374d2e7565f8db9d Mon Sep 17 00:00:00 2001 From: callebtc <93376500+callebtc@users.noreply.github.com> Date: Tue, 11 Jul 2023 00:40:18 +0200 Subject: [PATCH] Wallet: refactor split (#283) * refactor split * add comments * add witnesses inside /split * more comments --- cashu/wallet/wallet.py | 105 +++++++++++++++++++++++++---------------- 1 file changed, 64 insertions(+), 41 deletions(-) diff --git a/cashu/wallet/wallet.py b/cashu/wallet/wallet.py index 9eeaf09..8568631 100644 --- a/cashu/wallet/wallet.py +++ b/cashu/wallet/wallet.py @@ -431,8 +431,14 @@ class LedgerAPI: @async_set_requests async def split( - self, proofs, amount, secret_lock: Optional[Secret] = None - ) -> Tuple[List[Proof], List[Proof]]: + self, + proofs: List[Proof], + outputs: List[BlindedMessage], + secrets: List[str], + rs: List[PrivateKey], + amount: int, + secret_lock: Optional[Secret] = None, + ) -> Tuple[List[BlindedSignature], List[BlindedSignature]]: """Consume proofs and create new promises based on amount split. If secret_lock is None, random secrets will be generated for the tokens to keep (frst_outputs) @@ -441,30 +447,6 @@ class LedgerAPI: If secret_lock is provided, the wallet will create blinded secrets with those to attach a predefined spending condition to the tokens they want to send.""" logger.debug("Calling split. POST /split") - total = sum_proofs(proofs) - frst_amt, scnd_amt = total - amount, amount - frst_outputs = amount_split(frst_amt) - scnd_outputs = amount_split(scnd_amt) - - amounts = frst_outputs + scnd_outputs - if secret_lock is None: - secrets = [self._generate_secret() for _ in range(len(amounts))] - else: - secret_locks = [secret_lock.serialize() for i in range(len(scnd_outputs))] - logger.debug(f"Creating proofs with custom secrets: {secret_locks}") - assert len(secret_locks) == len( - scnd_outputs - ), "number of secret_locks does not match number of ouptus." - # append predefined secrets (to send) to random secrets (to keep) - secrets = [ - self._generate_secret() for s in range(len(frst_outputs)) - ] + secret_locks - - assert len(secrets) == len( - amounts - ), "number of secrets does not match number of outputs" - await self._check_used_secrets(secrets) - outputs, rs = self._construct_outputs(amounts, secrets) split_payload = PostSplitRequest(proofs=proofs, amount=amount, outputs=outputs) # construct payload @@ -487,15 +469,11 @@ class LedgerAPI: promises_fst = [BlindedSignature(**p) for p in promises_dict["fst"]] promises_snd = [BlindedSignature(**p) for p in promises_dict["snd"]] - # Construct proofs from promises (i.e., unblind signatures) - frst_proofs = self._construct_proofs( - promises_fst, secrets[: len(promises_fst)], rs[: len(promises_fst)] - ) - scnd_proofs = self._construct_proofs( - promises_snd, secrets[len(promises_fst) :], rs[len(promises_fst) :] - ) - return frst_proofs, scnd_proofs + if len(promises_fst) == 0 and len(promises_snd) == 0: + raise Exception("received no splits.") + + return promises_fst, promises_snd @async_set_requests async def check_proof_state(self, proofs: List[Proof]): @@ -690,7 +668,7 @@ class Wallet(LedgerAPI): elif all( [Secret.deserialize(p.secret).kind == SecretKind.P2PK for p in proofs] ): - p2pk_signatures = await self.sign_p2pk_with_privatekey(proofs) + p2pk_signatures = await self.sign_p2pk_proofs(proofs) logger.debug(f"Unlock signature: {p2pk_signatures}") # attach unlock signatures to proofs @@ -704,7 +682,6 @@ class Wallet(LedgerAPI): self, proofs: List[Proof], ): - proofs = await self.add_witnesses_to_proofs(proofs) return await self.split(proofs, sum_proofs(proofs)) async def split( @@ -714,17 +691,63 @@ class Wallet(LedgerAPI): secret_lock: Optional[Secret] = None, ): assert len(proofs) > 0, ValueError("no proofs provided.") - frst_proofs, scnd_proofs = await super().split(proofs, amount, secret_lock) - if len(frst_proofs) == 0 and len(scnd_proofs) == 0: - raise Exception("received no splits.") + # potentially add witnesses to unlock provided proofs (if they indicate one) + proofs = await self.add_witnesses_to_proofs(proofs) + + # create a suitable amount split based on the proofs provided + total = sum_proofs(proofs) + frst_amt, scnd_amt = total - amount, amount + frst_outputs = amount_split(frst_amt) + scnd_outputs = amount_split(scnd_amt) + + amounts = frst_outputs + scnd_outputs + # generate secrets for new outputs + if secret_lock is None: + # generate all secrets randomly + secrets = [self._generate_secret() for _ in range(len(amounts))] + else: + # use provided secret_lock to generate secrets + secret_locks = [secret_lock.serialize() for i in range(len(scnd_outputs))] + logger.debug(f"Creating proofs with custom secrets: {secret_locks}") + assert len(secret_locks) == len( + scnd_outputs + ), "number of secret_locks does not match number of ouptus." + # append custom locks (to send) to randomly generated secrets (to keep) + secrets = [ + self._generate_secret() for s in range(len(frst_outputs)) + ] + secret_locks + + assert len(secrets) == len( + amounts + ), "number of secrets does not match number of outputs" + # verify that we didn't accidentally reuse a secret + await self._check_used_secrets(secrets) + + # construct outputs + outputs, rs = self._construct_outputs(amounts, secrets) + + # Call /split API + promises_fst, promises_snd = await super().split( + proofs, outputs, secrets, rs, amount, secret_lock + ) + + # Construct proofs from returned promises (i.e., unblind the signatures) + frst_proofs = self._construct_proofs( + promises_fst, secrets[: len(promises_fst)], rs[: len(promises_fst)] + ) + scnd_proofs = self._construct_proofs( + promises_snd, secrets[len(promises_fst) :], rs[len(promises_fst) :] + ) # remove used proofs from wallet and add new ones used_secrets = [p.secret for p in proofs] self.proofs = list(filter(lambda p: p.secret not in used_secrets, self.proofs)) + # add new proofs to wallet self.proofs += frst_proofs + scnd_proofs + # store new proofs in database await self._store_proofs(frst_proofs + scnd_proofs) - # invalidate used proofs + # invalidate used proofs in database for proof in proofs: await invalidate_proof(proof, db=self.db) return frst_proofs, scnd_proofs @@ -1087,7 +1110,7 @@ class Wallet(LedgerAPI): tags=tags, ) - async def sign_p2pk_with_privatekey(self, proofs: List[Proof]) -> List[str]: + async def sign_p2pk_proofs(self, proofs: List[Proof]) -> List[str]: assert ( self.private_key ), "No private key set in settings. Set NOSTR_PRIVATE_KEY in .env"