diff --git a/broker/src/error_log.rs b/broker/src/error_log.rs index 868a7e2..ddad888 100644 --- a/broker/src/error_log.rs +++ b/broker/src/error_log.rs @@ -4,9 +4,12 @@ use std::{env, fs}; const DEFAULT_ERROR_LOG_PATH: &str = "/root/.lightning/broker_errors.log"; -pub fn log_errors(mut error_rx: tokio::sync::broadcast::Receiver>) { +pub fn log_errors( + mut error_rx: tokio::sync::broadcast::Receiver>, + task_set: &mut tokio::task::JoinSet<()>, +) { // collect errors - tokio::spawn(async move { + task_set.spawn(async move { let err_log_path = env::var("BROKER_ERROR_LOG_PATH").unwrap_or(DEFAULT_ERROR_LOG_PATH.to_string()); if let Ok(mut file) = fs::OpenOptions::new() diff --git a/broker/src/handle.rs b/broker/src/handle.rs index 5951795..c0d0bd5 100644 --- a/broker/src/handle.rs +++ b/broker/src/handle.rs @@ -5,7 +5,6 @@ use crate::looper::ClientId; use rocket::tokio::sync::mpsc; use sphinx_signer::{parser, sphinx_glyph::topics}; use std::sync::atomic::{AtomicU16, Ordering}; -use std::thread; use std::time::Duration; use vls_protocol::{Error, Result}; @@ -37,7 +36,7 @@ pub fn handle_message( if is_my_turn(ticket) { break; } else { - thread::sleep(Duration::from_millis(96)); + std::thread::sleep(Duration::from_millis(96)); } } @@ -45,12 +44,12 @@ pub fn handle_message( let (cid, is_synced) = current_client_and_synced(); if cid.is_none() { log::debug!("no client yet... retry"); - thread::sleep(Duration::from_millis(96)); + std::thread::sleep(Duration::from_millis(96)); continue; } if !is_synced { log::debug!("current client still syncing..."); - thread::sleep(Duration::from_millis(96)); + std::thread::sleep(Duration::from_millis(96)); continue; } let cid = cid.unwrap(); @@ -62,7 +61,7 @@ pub fn handle_message( Err(e) => { log::warn!("error handle_message_inner, trying again... {:?}", e); cycle_clients(&cid); - thread::sleep(Duration::from_millis(96)); + std::thread::sleep(Duration::from_millis(96)); } } }; diff --git a/broker/src/looper.rs b/broker/src/looper.rs index 4780333..9cd97b6 100644 --- a/broker/src/looper.rs +++ b/broker/src/looper.rs @@ -3,10 +3,9 @@ use crate::handle::handle_message; use crate::secp256k1::PublicKey; use log::*; use lru::LruCache; -use rocket::tokio::sync::mpsc; +use rocket::tokio::{self, sync::mpsc}; use sphinx_signer::lightning_signer::bitcoin::hashes::{sha256::Hash as Sha256Hash, Hash}; use std::num::NonZeroUsize; -use std::thread; use std::time::Duration; use std::time::SystemTime; use vls_protocol::{msgs, msgs::Message, Error, Result}; @@ -106,7 +105,7 @@ impl SignerLoop { self.vls_tx.clone(), client_id, ); - thread::spawn(move || new_loop.start()); + tokio::task::spawn_blocking(move || new_loop.start()); } Message::Memleak(_) => { // info!("Memleak"); diff --git a/broker/src/lss.rs b/broker/src/lss.rs index 9ed9a04..4838e92 100644 --- a/broker/src/lss.rs +++ b/broker/src/lss.rs @@ -7,6 +7,7 @@ use rumqttd::oneshot as std_oneshot; use sphinx_signer::parser; use sphinx_signer::sphinx_glyph::topics; use tokio::sync::mpsc; +use tokio::task::JoinSet; use vls_protocol::msgs::{self, Message, SerBolt}; use vls_proxy::client::{Client, UnixClient}; @@ -17,8 +18,9 @@ pub fn lss_tasks( init_tx: mpsc::Sender, mut cln_client: UnixClient, mut hsmd_raw: Vec, + task_set: &mut JoinSet<()>, ) { - tokio::task::spawn(async move { + task_set.spawn(async move { // first connection - initializes lssbroker let (lss_conn, hsmd_init_reply) = loop { let (cid, dance_complete_tx) = conn_rx.recv().await.unwrap(); diff --git a/broker/src/main.rs b/broker/src/main.rs index cac8f35..c32e11f 100644 --- a/broker/src/main.rs +++ b/broker/src/main.rs @@ -19,9 +19,11 @@ use crate::mqtt::{check_auth, start_broker}; use crate::util::{read_broker_config, Settings}; use clap::{arg, App}; use rocket::tokio::{ - self, + self, select, sync::{broadcast, mpsc}, + task::JoinSet, }; +use rocket::{Build, Rocket}; use rumqttd::{oneshot as std_oneshot, AuthMsg, AuthType}; use std::env; use std::sync::Arc; @@ -33,8 +35,23 @@ use vls_proxy::connection::{open_parent_fd, UnixConnection}; use vls_proxy::portfront::SignerPortFront; use vls_proxy::util::{add_hsmd_args, handle_hsmd_version}; -#[rocket::launch] -async fn rocket() -> _ { +#[rocket::main] +async fn main() { + let mut task_set: JoinSet<()> = JoinSet::new(); + let web_server = rocket(&mut task_set).await; + select! { + _ = task_set.join_next() => { + log::warn!("a spawned task shut down"); + } + _ = web_server.launch() => { + log::warn!("the rocket web server shut down"); + } + + }; + log::info!("shutting down"); +} + +async fn rocket(task_set: &mut JoinSet<()>) -> Rocket { let parent_fd = open_parent_fd(); util::setup_logging("hsmd ", "info"); @@ -57,7 +74,7 @@ async fn rocket() -> _ { } else if matches.is_present("test") { run_test::run_test() } else { - run_main(parent_fd) + run_main(parent_fd, task_set) } } @@ -68,17 +85,24 @@ fn make_clap_app() -> App<'static> { add_hsmd_args(app) } -fn run_main(parent_fd: i32) -> rocket::Rocket { +fn run_main(parent_fd: i32, task_set: &mut JoinSet<()>) -> rocket::Rocket { let settings = read_broker_config(); let (mqtt_tx, mqtt_rx) = mpsc::channel(10000); let (init_tx, init_rx) = mpsc::channel(10000); let (error_tx, error_rx) = broadcast::channel(10000); - error_log::log_errors(error_rx); + error_log::log_errors(error_rx, task_set); let (conn_tx, conn_rx) = mpsc::channel::<(String, std_oneshot::Sender)>(10000); - broker_setup(settings, mqtt_rx, init_rx, conn_tx, error_tx.clone()); + broker_setup( + settings, + mqtt_rx, + init_rx, + conn_tx, + error_tx.clone(), + task_set, + ); let mut cln_client_a = UnixClient::new(UnixConnection::new(parent_fd)); let hsmd_raw = cln_client_a.read_raw().unwrap(); @@ -93,7 +117,15 @@ fn run_main(parent_fd: i32) -> rocket::Rocket { // TODO: add a validation here of the uri setting to make sure LSS is running if let Ok(lss_uri) = env::var("VLS_LSS") { log::info!("Spawning lss tasks..."); - lss::lss_tasks(lss_uri, lss_rx, conn_rx, init_tx, cln_client_a, hsmd_raw); + lss::lss_tasks( + lss_uri, + lss_rx, + conn_rx, + init_tx, + cln_client_a, + hsmd_raw, + task_set, + ); } else { log::warn!("running without LSS"); } @@ -125,7 +157,7 @@ fn run_main(parent_fd: i32) -> rocket::Rocket { // TODO pass status_rx into SignerLoop? let mut signer_loop = SignerLoop::new(cln_client, mqtt_tx.clone(), lss_tx); // spawn CLN listener - std::thread::spawn(move || { + task_set.spawn_blocking(move || { signer_loop.start(); }); @@ -139,12 +171,13 @@ pub fn broker_setup( init_rx: mpsc::Receiver, conn_tx: mpsc::Sender<(String, std_oneshot::Sender)>, error_tx: broadcast::Sender>, + task_set: &mut JoinSet<()>, ) { let (auth_tx, auth_rx) = std::sync::mpsc::channel::(); let (status_tx, status_rx) = std::sync::mpsc::channel(); // authenticator - std::thread::spawn(move || { + task_set.spawn_blocking(move || { while let Ok(am) = auth_rx.recv() { let pubkey = current_pubkey(); let (ok, new_pubkey) = match am.msg { @@ -160,11 +193,13 @@ pub fn broker_setup( // broker log::info!("=> start broker on network: {}", settings.network); - start_broker(settings, mqtt_rx, init_rx, status_tx, error_tx, auth_tx) - .expect("BROKER FAILED TO START"); + start_broker( + settings, mqtt_rx, init_rx, status_tx, error_tx, auth_tx, task_set, + ) + .expect("BROKER FAILED TO START"); // client connections state - std::thread::spawn(move || { + task_set.spawn_blocking(move || { log::info!("=> waiting first connection..."); while let Ok((cid, connected)) = status_rx.recv() { log::info!("=> connection status: {}: {}", cid, connected); diff --git a/broker/src/mqtt.rs b/broker/src/mqtt.rs index dea6ba9..aefdfef 100644 --- a/broker/src/mqtt.rs +++ b/broker/src/mqtt.rs @@ -1,6 +1,6 @@ use crate::conn::{ChannelReply, ChannelRequest}; use crate::util::Settings; -use rocket::tokio::{sync::broadcast, sync::mpsc}; +use rocket::tokio::{sync::broadcast, sync::mpsc, task::JoinSet}; use rumqttd::{local::LinkTx, AuthMsg, Broker, Config, Notification}; use sphinx_signer::sphinx_glyph::sphinx_auther::token::Token; use sphinx_signer::sphinx_glyph::topics; @@ -16,6 +16,7 @@ pub fn start_broker( status_sender: std::sync::mpsc::Sender<(String, bool)>, error_sender: broadcast::Sender>, auth_sender: std::sync::mpsc::Sender, + task_set: &mut JoinSet<()>, ) -> anyhow::Result<()> { let conf = config(settings); // println!("CONF {:?}", conf); @@ -27,7 +28,7 @@ pub fn start_broker( let _ = link_tx.subscribe(format!("+/{}", topics::HELLO)); let _ = link_tx.subscribe(format!("+/{}", topics::BYE)); - std::thread::spawn(move || { + task_set.spawn_blocking(move || { broker.start().expect("could not start broker"); }); @@ -36,7 +37,7 @@ pub fn start_broker( // track connections let link_tx_ = link_tx.clone(); - let _conns_task = std::thread::spawn(move || { + let _conns_task = task_set.spawn_blocking(move || { while let Ok((is, cid)) = internal_status_rx.recv() { if is { subs(&cid, link_tx_.clone()); @@ -52,7 +53,7 @@ pub fn start_broker( let mut link_tx_ = link_tx.clone(); // receive replies from LSS initialization - let _init_task = std::thread::spawn(move || { + let _init_task = task_set.spawn_blocking(move || { while let Some(msg) = init_receiver.blocking_recv() { // Retry three times pub_and_wait(msg, &init_rx, &mut link_tx_, Some(3)); @@ -63,7 +64,7 @@ pub fn start_broker( let (msg_tx, msg_rx) = std::sync::mpsc::channel::<(String, String, Vec)>(); // receive from CLN, Frontend, Controller, or LSS - let _relay_task = std::thread::spawn(move || { + let _relay_task = task_set.spawn_blocking(move || { while let Some(msg) = receiver.blocking_recv() { log::debug!("Received message here: {:?}", msg); let retries = if msg.topic == topics::CONTROL { @@ -78,7 +79,7 @@ pub fn start_broker( }); // receive replies back from glyph - let _sub_task = std::thread::spawn(move || { + let _sub_task = task_set.spawn_blocking(move || { while let Ok(message) = link_rx.recv() { if message.is_none() { continue; diff --git a/broker/src/run_test.rs b/broker/src/run_test.rs index 97bf825..1d421ba 100644 --- a/broker/src/run_test.rs +++ b/broker/src/run_test.rs @@ -1,7 +1,7 @@ use crate::conn::ChannelRequest; use crate::routes::launch_rocket; use crate::util::Settings; -use rocket::tokio::{self, sync::broadcast, sync::mpsc}; +use rocket::tokio::{self, sync::broadcast, sync::mpsc, task::JoinSet}; use sphinx_signer::vls_protocol::{msgs, msgs::Message}; use sphinx_signer::{parser, sphinx_glyph::topics}; use vls_protocol::serde_bolt::WireString; @@ -9,6 +9,7 @@ use vls_protocol::serde_bolt::WireString; // const CLIENT_ID: &str = "test-1"; pub fn run_test() -> rocket::Rocket { + let mut task_set = JoinSet::<()>::new(); log::info!("TEST..."); // let mut id = 0u16; @@ -20,14 +21,21 @@ pub fn run_test() -> rocket::Rocket { let (error_tx, error_rx) = broadcast::channel(10000); let (conn_tx, _conn_rx) = mpsc::channel(10000); - crate::error_log::log_errors(error_rx); + crate::error_log::log_errors(error_rx, &mut task_set); // block until connection - crate::broker_setup(settings, mqtt_rx, init_rx, conn_tx, error_tx.clone()); + crate::broker_setup( + settings, + mqtt_rx, + init_rx, + conn_tx, + error_tx.clone(), + &mut task_set, + ); log::info!("=> off to the races!"); let tx_ = mqtt_tx.clone(); - tokio::spawn(async move { + task_set.spawn(async move { let mut id = 0; let mut sequence = 0; loop {