diff --git a/src/agent/src/main.rs b/src/agent/src/main.rs index a4603eebf..c86a03254 100644 --- a/src/agent/src/main.rs +++ b/src/agent/src/main.rs @@ -340,7 +340,9 @@ async fn start_sandbox( tasks.push(signal_handler_task); - watch_uevents(sandbox.clone()).await; + let uevents_handler_task = tokio::spawn(watch_uevents(sandbox.clone(), shutdown.clone())); + + tasks.push(uevents_handler_task); let (tx, rx) = tokio::sync::oneshot::channel(); sandbox.lock().await.sender = Some(tx); diff --git a/src/agent/src/uevent.rs b/src/agent/src/uevent.rs index 94c4253da..c4067ed5a 100644 --- a/src/agent/src/uevent.rs +++ b/src/agent/src/uevent.rs @@ -9,10 +9,13 @@ use crate::sandbox::Sandbox; use crate::GLOBAL_DEVICE_WATCHER; use slog::Logger; +use anyhow::Result; use netlink_sys::{protocols, SocketAddr, TokioSocket}; use nix::errno::Errno; use std::os::unix::io::FromRawFd; use std::sync::Arc; +use tokio::select; +use tokio::sync::watch::Receiver; use tokio::sync::Mutex; #[derive(Debug, Default)] @@ -132,49 +135,67 @@ impl Uevent { } } -pub async fn watch_uevents(sandbox: Arc>) { +pub async fn watch_uevents( + sandbox: Arc>, + mut shutdown: Receiver, +) -> Result<()> { let sref = sandbox.clone(); let s = sref.lock().await; let logger = s.logger.new(o!("subsystem" => "uevent")); - tokio::spawn(async move { - let mut socket; - unsafe { - let fd = libc::socket( - libc::AF_NETLINK, - libc::SOCK_DGRAM | libc::SOCK_CLOEXEC, - protocols::NETLINK_KOBJECT_UEVENT as libc::c_int, - ); - socket = TokioSocket::from_raw_fd(fd); - } - socket.bind(&SocketAddr::new(0, 1)).unwrap(); + // Unlock the sandbox to allow a successful shutdown + drop(s); - loop { - match socket.recv_from_full().await { - Err(e) => { - error!(logger, "receive uevent message failed"; "error" => format!("{}", e)) - } - Ok((buf, addr)) => { - if addr.port_number() != 0 { - // not our netlink message - let err_msg = format!("{:?}", nix::Error::Sys(Errno::EBADMSG)); - error!(logger, "receive uevent message failed"; "error" => err_msg); - return; + info!(logger, "starting uevents handler"); + + let mut socket; + + unsafe { + let fd = libc::socket( + libc::AF_NETLINK, + libc::SOCK_DGRAM | libc::SOCK_CLOEXEC, + protocols::NETLINK_KOBJECT_UEVENT as libc::c_int, + ); + socket = TokioSocket::from_raw_fd(fd); + } + + socket.bind(&SocketAddr::new(0, 1))?; + + loop { + select! { + _ = shutdown.changed() => { + info!(logger, "got shutdown request"); + break; + } + result = socket.recv_from_full() => { + match result { + Err(e) => { + error!(logger, "failed to receive uevent"; "error" => format!("{}", e)) } - - let text = String::from_utf8(buf); - match text { - Err(e) => { - error!(logger, "failed to convert bytes to text"; "error" => format!("{}", e)) + Ok((buf, addr)) => { + if addr.port_number() != 0 { + // not our netlink message + let err_msg = format!("{:?}", nix::Error::Sys(Errno::EBADMSG)); + error!(logger, "receive uevent message failed"; "error" => err_msg); + continue; } - Ok(text) => { - let event = Uevent::new(&text); - info!(logger, "got uevent message"; "event" => format!("{:?}", event)); - event.process(&logger, &sandbox).await; + + let text = String::from_utf8(buf); + match text { + Err(e) => { + error!(logger, "failed to convert bytes to text"; "error" => format!("{}", e)) + } + Ok(text) => { + let event = Uevent::new(&text); + info!(logger, "got uevent message"; "event" => format!("{:?}", event)); + event.process(&logger, &sandbox).await; + } } } } } } - }); + } + + Ok(()) }