diff --git a/src/agent/src/device.rs b/src/agent/src/device.rs index 60fca50df..09ead0e49 100644 --- a/src/agent/src/device.rs +++ b/src/agent/src/device.rs @@ -206,12 +206,49 @@ pub async fn get_virtio_blk_pci_device_name( Ok(format!("{}/{}", SYSTEM_DEV_PATH, &uev.devname)) } -pub async fn get_pmem_device_name( - sandbox: &Arc>, - pmem_devname: &str, -) -> Result { - let dev_sub_path = format!("/{}/{}", SCSI_BLOCK_SUFFIX, pmem_devname); - get_device_name(sandbox, &dev_sub_path).await +#[derive(Debug)] +struct PmemBlockMatcher { + suffix: String, +} + +impl PmemBlockMatcher { + fn new(devname: &str) -> Result { + let suffix = format!(r"/block/{}", devname); + + Ok(PmemBlockMatcher { suffix }) + } +} + +impl UeventMatcher for PmemBlockMatcher { + fn is_match(&self, uev: &Uevent) -> bool { + uev.subsystem == "block" + && uev.devpath.starts_with(ACPI_DEV_PATH) + && uev.devpath.ends_with(&self.suffix) + && !uev.devname.is_empty() + } +} + +pub async fn wait_for_pmem_device(sandbox: &Arc>, devpath: &str) -> Result<()> { + let devname = match devpath.strip_prefix("/dev/") { + Some(dev) => dev, + None => { + return Err(anyhow!( + "Storage source '{}' must start with /dev/", + devpath + )) + } + }; + + let matcher = PmemBlockMatcher::new(devname)?; + let uev = wait_for_uevent(sandbox, matcher).await?; + if uev.devname != devname { + return Err(anyhow!( + "Unexpected device name {} for pmem device (expected {})", + uev.devname, + devname + )); + } + Ok(()) } /// Scan SCSI bus for the given SCSI address(SCSI-Id and LUN) diff --git a/src/agent/src/mount.rs b/src/agent/src/mount.rs index d0b91a740..0f649773e 100644 --- a/src/agent/src/mount.rs +++ b/src/agent/src/mount.rs @@ -24,7 +24,7 @@ use std::fs::File; use std::io::{BufRead, BufReader}; use crate::device::{ - get_pmem_device_name, get_scsi_device_name, get_virtio_blk_pci_device_name, online_device, + get_scsi_device_name, get_virtio_blk_pci_device_name, online_device, wait_for_pmem_device, }; use crate::linux_abi::*; use crate::pci; @@ -377,22 +377,10 @@ async fn nvdimm_storage_handler( storage: &Storage, sandbox: Arc>, ) -> Result { - let mut storage = storage.clone(); - // If hot-plugged, get the device node path based on the PCI address else - // use the virt path provided in Storage Source - let pmem_devname = match storage.source.strip_prefix("/dev/") { - Some(dev) => dev, - None => { - return Err(anyhow!( - "Storage source '{}' must start with /dev/", - storage.source - )) - } - }; + let storage = storage.clone(); // Retrieve the device path from NVDIMM address. - let dev_path = get_pmem_device_name(&sandbox, pmem_devname).await?; - storage.source = dev_path; + wait_for_pmem_device(&sandbox, &storage.source).await?; common_storage_handler(logger, &storage) }