diff --git a/src/agent/Cargo.lock b/src/agent/Cargo.lock index a7de84906..f747dc827 100644 --- a/src/agent/Cargo.lock +++ b/src/agent/Cargo.lock @@ -387,6 +387,15 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "hermit-abi" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "322f4de77956e22ed0e5032c359a0f1273f1f7f0d79bfa3b8ffbc730d7fbcc5c" +dependencies = [ + "libc", +] + [[package]] name = "hex" version = "0.4.2" @@ -753,6 +762,16 @@ dependencies = [ "autocfg", ] +[[package]] +name = "num_cpus" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3" +dependencies = [ + "hermit-abi", + "libc", +] + [[package]] name = "object" version = "0.22.0" @@ -1483,6 +1502,7 @@ dependencies = [ "libc", "memchr", "mio", + "num_cpus", "once_cell", "pin-project-lite 0.2.4", "signal-hook-registry", diff --git a/src/agent/Cargo.toml b/src/agent/Cargo.toml index 8f86d2aba..e632a30c3 100644 --- a/src/agent/Cargo.toml +++ b/src/agent/Cargo.toml @@ -21,7 +21,7 @@ scopeguard = "1.0.0" regex = "1" async-trait = "0.1.42" -tokio = { version = "1.2.0", features = ["rt", "sync", "macros", "io-util", "time", "signal", "io-std", "process"] } +tokio = { version = "1.2.0", features = ["rt", "rt-multi-thread", "sync", "macros", "io-util", "time", "signal", "io-std", "process", "fs"] } futures = "0.3.12" netlink-sys = { version = "0.6.0", features = ["tokio_socket",]} tokio-vsock = "0.3.0" diff --git a/src/agent/src/main.rs b/src/agent/src/main.rs index ad8cc6565..595f4968a 100644 --- a/src/agent/src/main.rs +++ b/src/agent/src/main.rs @@ -55,6 +55,7 @@ mod sandbox; #[cfg(test)] mod test_utils; mod uevent; +mod util; mod version; use mount::{cgroups_mount, general_mount}; @@ -70,7 +71,11 @@ use rustjail::pipestream::PipeStream; use tokio::{ io::AsyncWrite, signal::unix::{signal, SignalKind}, - sync::{oneshot::Sender, Mutex, 
RwLock}, + sync::{ + oneshot::Sender, + watch::{channel, Receiver}, + Mutex, RwLock, + }, task::JoinHandle, }; use tokio_vsock::{Incoming, VsockListener, VsockStream}; @@ -126,7 +131,7 @@ async fn get_vsock_stream(fd: RawFd) -> Result { // Create a thread to handle reading from the logger pipe. The thread will // output to the vsock port specified, or stdout. -async fn create_logger_task(rfd: RawFd, vsock_port: u32) -> Result<()> { +async fn create_logger_task(rfd: RawFd, vsock_port: u32, shutdown: Receiver) -> Result<()> { let mut reader = PipeStream::from_fd(rfd); let mut writer: Box; @@ -147,7 +152,7 @@ async fn create_logger_task(rfd: RawFd, vsock_port: u32) -> Result<()> { writer = Box::new(tokio::io::stdout()); } - let _ = tokio::io::copy(&mut reader, &mut writer).await; + let _ = util::interruptable_io_copier(&mut reader, &mut writer, shutdown).await; Ok(()) } @@ -165,6 +170,8 @@ async fn real_main() -> std::result::Result<(), Box> { // support vsock log let (rfd, wfd) = unistd::pipe2(OFlag::O_CLOEXEC)?; + let (shutdown_tx, shutdown_rx) = channel(true); + let agent_config = AGENT_CONFIG.clone(); let init_mode = unistd::getpid() == Pid::from_raw(1); @@ -203,7 +210,7 @@ async fn real_main() -> std::result::Result<(), Box> { let log_vport = config.log_vport as u32; - let log_handle = tokio::spawn(create_logger_task(rfd, log_vport)); + let log_handle = tokio::spawn(create_logger_task(rfd, log_vport, shutdown_rx.clone())); tasks.push(log_handle); @@ -226,8 +233,15 @@ async fn real_main() -> std::result::Result<(), Box> { _log_guard = Ok(slog_stdlog::init().map_err(|e| e)?); } + // Start the sandbox and wait for its ttRPC server to end start_sandbox(&logger, &config, init_mode).await?; + // Trigger a controlled shutdown + shutdown_tx + .send(true) + .map_err(|e| anyhow!(e).context("failed to request shutdown"))?; + + // Wait for all threads to finish let results = join_all(tasks).await; for result in results { @@ -236,6 +250,8 @@ async fn real_main() -> 
std::result::Result<(), Box> { } } + eprintln!("{} shutdown complete", NAME); + Ok(()) } @@ -259,7 +275,7 @@ fn main() -> std::result::Result<(), Box> { exit(0); } - let rt = tokio::runtime::Builder::new_current_thread() + let rt = tokio::runtime::Builder::new_multi_thread() .enable_all() .build()?; diff --git a/src/agent/src/util.rs b/src/agent/src/util.rs new file mode 100644 index 000000000..314d05a25 --- /dev/null +++ b/src/agent/src/util.rs @@ -0,0 +1,342 @@ +// Copyright (c) 2021 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 +// + +use std::io; +use std::io::ErrorKind; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::sync::watch::Receiver; + +// Size of I/O read buffer +const BUF_SIZE: usize = 8192; + +// Interruptable I/O copy using readers and writers (an interruptable version +// of "io::copy()"); returns the bytes copied once EOF or a shutdown is seen. +pub async fn interruptable_io_copier( + mut reader: R, + mut writer: W, + mut shutdown: Receiver, +) -> io::Result +where + R: tokio::io::AsyncRead + Unpin, + W: tokio::io::AsyncWrite + Unpin, +{ + let mut total_bytes: u64 = 0; + + let mut buf: [u8; BUF_SIZE] = [0; BUF_SIZE]; + + loop { + tokio::select! 
{ + _ = shutdown.changed() => { + eprintln!("INFO: interruptable_io_copier: got shutdown request"); + break; + }, + + result = reader.read(&mut buf) => { + let bytes = match result { + Ok(0) => return Ok(total_bytes), + Ok(len) => len, + Err(ref e) if e.kind() == ErrorKind::Interrupted => continue, + Err(e) => return Err(e), + }; + + total_bytes += bytes as u64; + + // Actually copy the data ;) + writer.write_all(&buf[..bytes]).await?; + }, + }; + } + + Ok(total_bytes) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::io; + use std::io::Cursor; + use std::io::Write; + use std::pin::Pin; + use std::sync::{Arc, Mutex}; + use std::task::{Context, Poll, Poll::Ready}; + use tokio::pin; + use tokio::select; + use tokio::sync::watch::channel; + use tokio::task::JoinError; + use tokio::time::Duration; + + #[derive(Debug, Default, Clone)] + struct BufWriter { + data: Arc>>, + slow_write: bool, + write_delay: Duration, + } + + impl BufWriter { + fn new() -> Self { + BufWriter { + data: Arc::new(Mutex::new(Vec::::new())), + slow_write: false, + write_delay: Duration::new(0, 0), + } + } + + fn write_vec(&mut self, buf: &[u8]) -> io::Result { + let vec_ref = self.data.clone(); + + let mut vec_locked = vec_ref.lock(); + + let mut v = vec_locked.as_deref_mut().unwrap(); + + if self.write_delay.as_nanos() > 0 { + std::thread::sleep(self.write_delay); + } + + std::io::Write::write(&mut v, buf) + } + } + + impl Write for BufWriter { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.write_vec(buf) + } + + fn flush(&mut self) -> io::Result<()> { + let vec_ref = self.data.clone(); + + let mut vec_locked = vec_ref.lock(); + + let v = vec_locked.as_deref_mut().unwrap(); + + std::io::Write::flush(v) + } + } + + impl tokio::io::AsyncWrite for BufWriter { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let result = self.write_vec(buf); + + Ready(result) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut 
Context<'_>, + ) -> Poll> { + // NOP + Ready(Ok(())) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + // NOP + Ready(Ok(())) + } + } + + impl ToString for BufWriter { + fn to_string(&self) -> String { + let data_ref = self.data.clone(); + let output = data_ref.lock().unwrap(); + let s = (*output).clone(); + + String::from_utf8(s).unwrap() + } + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_interruptable_io_copier_reader() { + #[derive(Debug)] + struct TestData { + reader_value: String, + result: io::Result, + } + + let tests = &[ + TestData { + reader_value: "".into(), + result: Ok(0), + }, + TestData { + reader_value: "a".into(), + result: Ok(1), + }, + TestData { + reader_value: "foo".into(), + result: Ok(3), + }, + TestData { + reader_value: "b".repeat(BUF_SIZE - 1), + result: Ok((BUF_SIZE - 1) as u64), + }, + TestData { + reader_value: "c".repeat(BUF_SIZE), + result: Ok((BUF_SIZE) as u64), + }, + TestData { + reader_value: "d".repeat(BUF_SIZE + 1), + result: Ok((BUF_SIZE + 1) as u64), + }, + TestData { + reader_value: "e".repeat((2 * BUF_SIZE) - 1), + result: Ok(((2 * BUF_SIZE) - 1) as u64), + }, + TestData { + reader_value: "f".repeat(2 * BUF_SIZE), + result: Ok((2 * BUF_SIZE) as u64), + }, + TestData { + reader_value: "g".repeat((2 * BUF_SIZE) + 1), + result: Ok(((2 * BUF_SIZE) + 1) as u64), + }, + ]; + + for (i, d) in tests.iter().enumerate() { + // Create a string containing details of the test + let msg = format!("test[{}]: {:?}", i, d); + + let (tx, rx) = channel(true); + let reader = Cursor::new(d.reader_value.clone()); + let writer = BufWriter::new(); + + // XXX: Pass a copy of the writer to the copier to allow the + // result of the write operation to be checked below. + let handle = tokio::spawn(interruptable_io_copier(reader, writer.clone(), rx)); + + // Allow time for the thread to be spawned. 
+ tokio::time::sleep(Duration::from_secs(1)).await; + + let timeout = tokio::time::sleep(Duration::from_secs(1)); + pin!(timeout); + + // Since the readers only specify a small number of bytes, the + // copier will quickly read zero and kill the task, closing the + // Receiver. + assert!(tx.is_closed(), "{}", msg); + + let spawn_result: std::result::Result< + std::result::Result, + JoinError, + >; + + let result: std::result::Result; + + select! { + res = handle => spawn_result = res, + _ = &mut timeout => panic!("timed out"), + } + + assert!(spawn_result.is_ok()); + + result = spawn_result.unwrap(); + + assert!(result.is_ok()); + + let byte_count = result.unwrap() as usize; + assert_eq!(byte_count, d.reader_value.len(), "{}", msg); + + let value = writer.to_string(); + assert_eq!(value, d.reader_value, "{}", msg); + } + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_interruptable_io_copier_eof() { + // Create an async reader that always returns EOF + let reader = tokio::io::empty(); + + let (tx, rx) = channel(true); + let writer = BufWriter::new(); + + let handle = tokio::spawn(interruptable_io_copier(reader, writer.clone(), rx)); + + // Allow time for the thread to be spawned. + tokio::time::sleep(Duration::from_secs(1)).await; + + let timeout = tokio::time::sleep(Duration::from_secs(1)); + pin!(timeout); + + assert!(tx.is_closed()); + + let spawn_result: std::result::Result, JoinError>; + + let result: std::result::Result; + + select! 
{ + res = handle => spawn_result = res, + _ = &mut timeout => panic!("timed out"), + } + + assert!(spawn_result.is_ok()); + + result = spawn_result.unwrap(); + + assert!(result.is_ok()); + + let byte_count = result.unwrap(); + assert_eq!(byte_count, 0); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_interruptable_io_copier_shutdown() { + // Create an async reader that creates an infinite stream of bytes + // (which allows us to interrupt it, since we know it is always busy ;) + const REPEAT_CHAR: u8 = b'r'; + + let reader = tokio::io::repeat(REPEAT_CHAR); + + let (tx, rx) = channel(true); + let writer = BufWriter::new(); + + let handle = tokio::spawn(interruptable_io_copier(reader, writer.clone(), rx)); + + // Allow time for the thread to be spawned. + tokio::time::sleep(Duration::from_secs(1)).await; + + let timeout = tokio::time::sleep(Duration::from_secs(1)); + pin!(timeout); + + assert!(!tx.is_closed()); + + tx.send(true).expect("failed to request shutdown"); + + let spawn_result: std::result::Result, JoinError>; + + let result: std::result::Result; + + select! { + res = handle => spawn_result = res, + _ = &mut timeout => panic!("timed out"), + } + + assert!(spawn_result.is_ok()); + + result = spawn_result.unwrap(); + + assert!(result.is_ok()); + + let byte_count = result.unwrap(); + + let value = writer.to_string(); + + let writer_byte_count = value.len() as u64; + + assert_eq!(byte_count, writer_byte_count); + + // Remove the char used as a payload. If anything else remains, + // something went wrong. + let mut remainder = value; + + remainder.retain(|c| c != REPEAT_CHAR as char); + + assert_eq!(remainder.len(), 0); + } +}