diff --git a/persister/src/lib.rs b/persister/src/lib.rs index 3e67bed..48dd52f 100644 --- a/persister/src/lib.rs +++ b/persister/src/lib.rs @@ -22,7 +22,7 @@ pub struct FsPersister { channels: DoubleBucket, allowlist: Bucket, chaintracker: Bucket, - pubkeys: Bucket, + pubkeys: Bucket>, } impl FsPersister { @@ -44,6 +44,36 @@ fn get_channel_key(channel_id: &[u8]) -> &[u8] { channel_id.get(length - 11..length - 7).unwrap() } +const PUBKEY_KEY: &str = "PKS"; + +impl FsPersister { + fn get_pubkeys(&self) -> Vec { + match self.pubkeys.get(PUBKEY_KEY) { + Ok(ps) => ps, + Err(_) => Vec::new(), + } + } + fn add_pubkey(&self, pubkey: &PublicKey) { + let mut pks: Vec = match self.pubkeys.get(PUBKEY_KEY) { + Ok(ps) => ps, + Err(_) => Vec::new(), + }; + let pk = pubkey.clone(); + if !pks.contains(&pk) { + pks.push(pk); + let _ = self.pubkeys.put(PUBKEY_KEY, pks); + } + } + fn remove_pubkey(&self, pk: &PublicKey) { + let pks: Vec = match self.pubkeys.get(PUBKEY_KEY) { + Ok(ps) => ps, + Err(_) => Vec::new(), + }; + let newpks = pks.iter().filter(|p| *p != pk).map(|p| p.clone()).collect(); + let _ = self.pubkeys.put(PUBKEY_KEY, newpks); + } +} + impl Persist for FsPersister { fn new_node(&self, node_id: &PublicKey, config: &NodeConfig, seed: &[u8]) { let pk = hex::encode(node_id.serialize()); @@ -53,14 +83,14 @@ impl Persist for FsPersister { network: config.network.to_string(), }; let _ = self.nodes.put(&pk, entry); - let _ = self.pubkeys.put(&pk, node_id.clone()); + self.add_pubkey(node_id); } fn delete_node(&self, node_id: &PublicKey) { let pk = hex::encode(node_id.serialize()); // clear all channel entries within "pk" sub-bucket let _ = self.channels.clear(&pk); let _ = self.nodes.remove(&pk); - let _ = self.pubkeys.remove(&pk); + self.remove_pubkey(node_id); } fn new_channel(&self, node_id: &PublicKey, stub: &ChannelStub) -> Result<(), ()> { let pk = hex::encode(node_id.serialize()); @@ -171,15 +201,11 @@ impl Persist for FsPersister { } fn get_nodes(&self) -> Vec<(PublicKey, CoreNodeEntry)> { let mut res = Vec::new(); - let list = match self.nodes.list() { - Ok(ns) => ns, - Err(_) => return res, - }; - for pk in list { - if let Ok(pubkey) = self.pubkeys.get(&pk) { - if let Ok(node) = self.nodes.get(&pk) { - res.push((pubkey, node.into())); - } + let list = self.get_pubkeys(); + for pubkey in list { + let pk = hex::encode(pubkey.serialize()); + if let Ok(node) = self.nodes.get(&pk) { + res.push((pubkey, node.into())); } } res