diff --git a/broker/src/conn.rs b/broker/src/conn.rs index a58f283..ee963da 100644 --- a/broker/src/conn.rs +++ b/broker/src/conn.rs @@ -8,7 +8,7 @@ use std::collections::HashMap; pub struct Connections { pub pubkey: Option, pub clients: HashMap, - pub current: Option, + pub current: Option<(String, SignerType)>, } impl Connections { @@ -25,17 +25,19 @@ impl Connections { pub fn set_pubkey(&mut self, pk: &str) { self.pubkey = Some(pk.to_string()) } - pub fn set_current(&mut self, cid: String) { - self.current = Some(cid); + pub fn set_current(&mut self, cid: String, signer_type: SignerType) { + self.current = Some((cid, signer_type)); } pub fn add_client(&mut self, cid: &str, signer_type: SignerType) { self.clients.insert(cid.to_string(), signer_type); - self.current = Some(cid.to_string()); + self.current = Some((cid.to_string(), signer_type)); } pub fn remove_client(&mut self, cid: &str) { self.clients.remove(cid); - if self.current == Some(cid.to_string()) { - self.current = None; + if let Some((id, _)) = &self.current { + if id == cid { + self.current = None; + } } } } diff --git a/broker/src/looper.rs b/broker/src/looper.rs index 808eb3f..2072882 100644 --- a/broker/src/looper.rs +++ b/broker/src/looper.rs @@ -4,7 +4,10 @@ use bitcoin::blockdata::constants::ChainHash; use log::*; use rocket::tokio::sync::mpsc; use secp256k1::PublicKey; -use sphinx_signer::{parser, sphinx_glyph::topics}; +use sphinx_signer::{ + parser, + sphinx_glyph::{topics, types::SignerType}, +}; use std::sync::atomic::{AtomicBool, AtomicU16, Ordering}; use std::thread; use std::time::Duration; @@ -123,7 +126,7 @@ impl SignerLoop { } msg => { let mut catch_init = false; - if let Message::HsmdInit(m) = msg { + if let Message::HsmdInit(ref m) = msg { catch_init = true; if let Some(set) = settings { if ChainHash::using_genesis_block(set.network).as_bytes() @@ -135,7 +138,14 @@ impl SignerLoop { panic!("Got HsmdInit without settings - likely because HsmdInit was sent after startup"); } } - let reply = self.handle_message(raw_msg, catch_init)?; + let reply = if let Message::PreapproveInvoice(_) + | Message::PreapproveKeysend(_) = msg + { + self.handle_message(raw_msg, catch_init, Some(SignerType::ReceiveSend))? + } else { + // None for signer_type means no restrictions on which signer type to send the message to + self.handle_message(raw_msg, catch_init, None)? + }; // Write the reply to CLN self.client.write_vec(reply)?; } @@ -143,7 +153,12 @@ impl SignerLoop { } } - fn handle_message(&mut self, message: Vec, catch_init: bool) -> Result> { + fn handle_message( + &mut self, + message: Vec, + catch_init: bool, + signer_type: Option, + ) -> Result> { // wait until not busy loop { match try_to_get_busy() { @@ -166,7 +181,7 @@ impl SignerLoop { )?; // send to signer log::info!("SEND ON {}", topics::VLS); - let (res_topic, res) = self.send_request_wait(topics::VLS, md)?; + let (res_topic, res) = self.send_request_wait(topics::VLS, md, signer_type)?; log::info!("GOT ON {}", res_topic); let the_res = if res_topic == topics::LSS_RES { // send reply to LSS to store muts @@ -174,7 +189,7 @@ impl SignerLoop { log::info!("LSS REPLY LEN {}", &lss_reply.len()); // send to signer for HMAC validation, and get final reply log::info!("SEND ON {}", topics::LSS_MSG); - let (res_topic2, res2) = self.send_request_wait(topics::LSS_MSG, lss_reply)?; + let (res_topic2, res2) = self.send_request_wait(topics::LSS_MSG, lss_reply, None)?; log::info!("GOT ON {}, send to CLN", res_topic2); if res_topic2 != topics::VLS_RES { log::warn!("got a topic NOT on {}", topics::VLS_RES); @@ -213,9 +228,17 @@ impl SignerLoop { // returns (topic, payload) // might halt if signer is offline - fn send_request_wait(&mut self, topic: &str, message: Vec) -> Result<(String, Vec)> { + fn send_request_wait( + &mut self, + topic: &str, + message: Vec, + signer_type: Option, + ) -> Result<(String, Vec)> { // Send a request to the MQTT handler to send to signer - let (request, reply_rx) = ChannelRequest::new(topic, message); + let (request, reply_rx) = match signer_type { + Some(st) => ChannelRequest::new_for_type(st, topic, message), + None => ChannelRequest::new(topic, message), + }; // This can fail if MQTT shuts down self.chan .sender diff --git a/broker/src/main.rs b/broker/src/main.rs index 032bec4..210a4c2 100644 --- a/broker/src/main.rs +++ b/broker/src/main.rs @@ -182,7 +182,11 @@ pub async fn broker_setup( false }); if dance_complete { - log::info!("adding client to the list: {}, type: {:?}", &cid, signer_type); + log::info!( + "adding client to the list: {}, type: {:?}", + &cid, + signer_type + ); let mut cs = conns_.lock().unwrap(); cs.add_client(&cid, signer_type); drop(cs); diff --git a/broker/src/mqtt.rs b/broker/src/mqtt.rs index 2ec6174..fccb3f2 100644 --- a/broker/src/mqtt.rs +++ b/broker/src/mqtt.rs @@ -124,7 +124,11 @@ pub fn start_broker( // This is the ReceiveSend signer type None => SignerType::default(), }; - log::debug!("caught hello message for id: {}, type: {:?}", cid, signer_type); + log::debug!( + "caught hello message for id: {}, type: {:?}", + cid, + signer_type + ); let _ = internal_status_tx.send((true, cid, Some(signer_type))); } else if topic.ends_with(topics::BYE) { let _ = internal_status_tx.send((false, cid, None)); @@ -187,15 +191,25 @@ fn pub_and_wait( } else { let current = current.unwrap(); // Try the current connection - let mut rep = pub_timeout(¤t, &msg.topic, &msg.message, &msg_rx, link_tx); + // This returns None if 1) signer_type is set, and not equal to the current signer + // 2) If pub_timeout times out + let mut rep = if current.1 == msg.signer_type.unwrap_or(current.1) { + pub_timeout(¤t.0, &msg.topic, &msg.message, &msg_rx, link_tx) + } else { + None + }; + // If that failed, try looking for some other signer if rep.is_none() { - for cid in client_list.into_keys().filter(|k| k != ¤t) { + // If signer_type is set, we also filter for only these types + for (cid, signer_type) in client_list.into_iter().filter(|(k, v)| { + k != ¤t.0 && v == msg.signer_type.as_ref().unwrap_or(v) + }) { rep = pub_timeout(&cid, &msg.topic, &msg.message, &msg_rx, link_tx); if rep.is_some() { let mut cs = conns_.lock().unwrap(); log::debug!("got the list lock!"); - cs.set_current(cid.to_string()); + cs.set_current(cid.to_string(), signer_type); drop(cs); break; } @@ -212,6 +226,7 @@ fn pub_and_wait( break; } else { log::debug!("couldn't reach any clients..."); + std::thread::sleep(Duration::from_secs(1)); } if let Some(max) = retries { log::debug!("counter: {}, retries: {}", counter, max);