diff --git a/CHANGELOG.md b/CHANGELOG.md index e8206190..22dda38c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,9 +29,11 @@ ### Changed ### Added +- cdk(NUT-11): Add `Copy` on `SigFlag` ([thesimplekid]). ### Fixed - cdk(mint): `SIG_ALL` is not allowed in `melt` ([thesimplekid]). +- cdk(mint): On `swap` verify correct number of sigs on outputs when `SigAll` ([thesimplekid]). ### Removed diff --git a/crates/cdk/src/mint/mod.rs b/crates/cdk/src/mint/mod.rs index 1f1346d1..b50a2e24 100644 --- a/crates/cdk/src/mint/mod.rs +++ b/crates/cdk/src/mint/mod.rs @@ -11,6 +11,7 @@ use tokio::sync::RwLock; use tracing::instrument; use self::nut05::QuoteState; +use self::nut11::EnforceSigFlag; use crate::cdk_database::{self, MintDatabase}; use crate::dhke::{hash_to_curve, sign_message, verify_message}; use crate::nuts::nut11::enforce_sig_flag; @@ -656,12 +657,16 @@ impl Mint { return Err(Error::MultipleUnits); } - let (sig_flag, pubkeys) = enforce_sig_flag(swap_request.inputs.clone()); + let EnforceSigFlag { + sig_flag, + pubkeys, + sigs_required, + } = enforce_sig_flag(swap_request.inputs.clone()); if sig_flag.eq(&SigFlag::SigAll) { let pubkeys = pubkeys.into_iter().collect(); - for blinded_messaage in &swap_request.outputs { - blinded_messaage.verify_p2pk(&pubkeys, 1)?; + for blinded_message in &swap_request.outputs { + blinded_message.verify_p2pk(&pubkeys, sigs_required)?; } } @@ -819,7 +824,7 @@ impl Mint { } if let Some(outputs) = &melt_request.outputs { - let (sig_flag, _pubkeys) = enforce_sig_flag(melt_request.inputs.clone()); + let EnforceSigFlag { sig_flag, .. } = enforce_sig_flag(melt_request.inputs.clone()); if sig_flag.eq(&SigFlag::SigAll) { return Err(Error::SigAllUsedInMelt); diff --git a/crates/cdk/src/nuts/nut11/mod.rs b/crates/cdk/src/nuts/nut11/mod.rs index 541c9b4e..267d4456 100644 --- a/crates/cdk/src/nuts/nut11/mod.rs +++ b/crates/cdk/src/nuts/nut11/mod.rs @@ -578,7 +578,9 @@ where /// Signature flag /// /// Defined in [NUT11](https://github.com/cashubtc/nuts/blob/main/11.md) -#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord, Hash)] +#[derive( + Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord, Hash, +)] pub enum SigFlag { #[default] /// Requires valid signatures on all inputs. @@ -609,9 +611,10 @@ impl FromStr for SigFlag { } /// Get the signature flag that should be enforced for a set of proofs and the public keys that signatures are valid for -pub fn enforce_sig_flag(proofs: Proofs) -> (SigFlag, HashSet) { +pub(crate) fn enforce_sig_flag(proofs: Proofs) -> EnforceSigFlag { let mut sig_flag = SigFlag::SigInputs; let mut pubkeys = HashSet::new(); + let mut sigs_required = 1; for proof in proofs { if let Ok(secret) = Nut10Secret::try_from(proof.secret) { if secret.kind.eq(&Kind::P2PK) { @@ -626,6 +629,12 @@ pub fn enforce_sig_flag(proofs: Proofs) -> (SigFlag, HashSet) { sig_flag = SigFlag::SigAll; } + if let Some(sigs) = conditions.num_sigs { + if sigs > sigs_required { + sigs_required = sigs; + } + } + if let Some(pubs) = conditions.pubkeys { pubkeys.extend(pubs); } @@ -634,7 +643,22 @@ pub fn enforce_sig_flag(proofs: Proofs) -> (SigFlag, HashSet) { } } - (sig_flag, pubkeys) + EnforceSigFlag { + sig_flag, + pubkeys, + sigs_required, + } +} + +/// Enforce Sigflag info +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct EnforceSigFlag { + /// Sigflag required for proofs + pub sig_flag: SigFlag, + /// Pubkeys that can sign for proofs + pub pubkeys: HashSet, + /// Number of sigs required for proofs + pub sigs_required: u64, } /// Tag