Keyset ID: fix deserialization edge-case, add unit tests

This commit is contained in:
ok300
2024-10-23 16:24:51 +02:00
committed by thesimplekid
parent 58e7226cff
commit 09b5a55239
2 changed files with 60 additions and 60 deletions

View File

@@ -35,7 +35,7 @@ reqwest = { version = "0.12", default-features = false, features = [
], optional = true }
serde = { version = "1", default-features = false, features = ["derive"] }
serde_json = "1"
serde_with = "3.1"
serde_with = "3"
tracing = { version = "0.1", default-features = false, features = ["attributes", "log"] }
thiserror = "1"
futures = { version = "0.3.28", default-features = false, optional = true }

View File

@@ -18,7 +18,7 @@ use bitcoin::hashes::Hash;
use bitcoin::key::Secp256k1;
#[cfg(feature = "mint")]
use bitcoin::secp256k1;
use serde::{Deserialize, Deserializer, Serialize};
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, VecSkipError};
use thiserror::Error;
@@ -86,10 +86,11 @@ impl fmt::Display for KeySetVersion {
/// A keyset ID is an identifier for a specific keyset. It can be derived by
/// anyone who knows the set of public keys of a mint. The keyset ID **CAN**
/// be stored in a Cashu token such that the token can be used to identify
/// be stored in a Cashu token such that the token can be used to identify
/// which mint or keyset it was generated from.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(into = "String", try_from = "String")]
#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema), schema(as = String))]
pub struct Id {
version: KeySetVersion,
id: [u8; Self::BYTELEN],
@@ -130,17 +131,16 @@ impl fmt::Display for Id {
}
}
impl FromStr for Id {
type Err = Error;
impl TryFrom<String> for Id {
type Error = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
// Check if the string length is valid
fn try_from(s: String) -> Result<Self, Self::Error> {
if s.len() != 16 {
return Err(Error::Length);
}
Ok(Self {
version: KeySetVersion::Version00,
version: KeySetVersion::from_byte(&hex::decode(&s[..2])?[0])?,
id: hex::decode(&s[2..])?
.try_into()
.map_err(|_| Error::Length)?,
@@ -148,63 +148,29 @@ impl FromStr for Id {
}
}
impl Serialize for Id {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(&self.to_string())
impl FromStr for Id {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Self::try_from(s.to_string())
}
}
impl<'de> Deserialize<'de> for Id {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct IdVisitor;
impl<'de> serde::de::Visitor<'de> for IdVisitor {
type Value = Id;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("Expecting a 14 char hex string")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Id::from_str(v).map_err(|e| match e {
Error::Length => E::custom(format!(
"Invalid Length: Expected {}, got {}:
{}",
Id::STRLEN,
v.len(),
v
)),
_ => E::custom(e),
})
}
}
deserializer.deserialize_str(IdVisitor)
impl From<Id> for String {
fn from(value: Id) -> Self {
value.to_string()
}
}
impl From<&Keys> for Id {
/// As per NUT-02:
/// 1. sort public keys by their amount in ascending order
/// 2. concatenate all public keys to one string
/// 3. HASH_SHA256 the concatenated public keys
/// 4. take the first 14 characters of the hex-encoded hash
/// 5. prefix it with a keyset ID version byte
fn from(map: &Keys) -> Self {
// REVIEW: Is it 16 or 14 bytes
/* NUT-02
1 - sort public keys by their amount in ascending order
2 - concatenate all public keys to one string
3 - HASH_SHA256 the concatenated public keys
4 - take the first 14 characters of the hex-encoded hash
5 - prefix it with a keyset ID version byte
*/
let mut keys: Vec<(&AmountStr, &super::PublicKey)> = map.iter().collect();
keys.sort_by_key(|(amt, _v)| *amt);
let pubkeys_concat: Vec<u8> = keys
@@ -400,12 +366,14 @@ impl From<&MintKeys> for Id {
#[cfg(test)]
mod test {
use std::str::FromStr;
use rand::RngCore;
use super::{KeySetInfo, Keys, KeysetResponse};
use crate::nuts::nut02::Id;
use crate::nuts::nut02::{Error, Id};
use crate::nuts::KeysResponse;
use crate::util::hex;
const SHORT_KEYSET_ID: &str = "00456a94ab4e1c46";
const SHORT_KEYSET: &str = r#"
@@ -547,4 +515,36 @@ mod test {
assert_eq!(keys_response.keysets.len(), 2);
}
fn generate_random_id() -> Id {
let mut rand_bytes = vec![0u8; 8];
rand::thread_rng().fill_bytes(&mut rand_bytes[1..]);
Id::from_bytes(&rand_bytes)
.unwrap_or_else(|e| panic!("Failed to create Id from {}: {e}", hex::encode(rand_bytes)))
}
#[test]
fn test_id_serialization() {
let id = generate_random_id();
let id_str = id.to_string();
assert!(id_str.chars().all(|c| c.is_ascii_hexdigit()));
assert_eq!(16, id_str.len());
assert_eq!(id_str.to_lowercase(), id_str);
}
#[test]
fn test_id_deserialization() {
let id_from_short_str = Id::from_str("00123");
assert!(matches!(id_from_short_str, Err(Error::Length)));
let id_from_non_hex_str = Id::from_str(&SHORT_KEYSET_ID.replace('a', "x"));
assert!(matches!(id_from_non_hex_str, Err(Error::HexError(_))));
let id_invalid_version = Id::from_str(&SHORT_KEYSET_ID.replace("00", "99"));
assert!(matches!(id_invalid_version, Err(Error::UnknownVersion)));
let id_from_uppercase = Id::from_str(&SHORT_KEYSET_ID.to_uppercase());
assert!(id_from_uppercase.is_ok());
}
}