mirror of
https://github.com/aljazceru/turso.git
synced 2026-02-15 21:14:21 +01:00
encrypt/decrypt when writing/reading from DB
This commit is contained in:
@@ -280,6 +280,10 @@ impl ReadCompletion {
|
||||
pub fn callback(&self, bytes_read: Result<i32, CompletionError>) {
|
||||
(self.complete)(bytes_read.map(|b| (self.buf.clone(), b)));
|
||||
}
|
||||
|
||||
pub fn buf_arc(&self) -> Arc<Buffer> {
|
||||
self.buf.clone()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct WriteCompletion {
|
||||
|
||||
@@ -8539,6 +8539,11 @@ mod tests {
|
||||
let c = Completion::new_write(move |_| {
|
||||
let _ = _buf.clone();
|
||||
});
|
||||
#[cfg(feature = "encryption")]
|
||||
let _c = pager
|
||||
.db_file
|
||||
.write_page(current_page as usize, buf.clone(), None, c)?;
|
||||
#[cfg(not(feature = "encryption"))]
|
||||
let _c = pager
|
||||
.db_file
|
||||
.write_page(current_page as usize, buf.clone(), c)?;
|
||||
|
||||
@@ -58,7 +58,12 @@ impl DatabaseStorage for DatabaseFile {
|
||||
self.file.pread(0, c)
|
||||
}
|
||||
#[instrument(skip_all, level = Level::DEBUG)]
|
||||
fn read_page(&self, page_idx: usize, c: Completion) -> Result<Completion> {
|
||||
fn read_page(
|
||||
&self,
|
||||
page_idx: usize,
|
||||
#[cfg(feature = "encryption")] encryption_key: Option<&EncryptionKey>,
|
||||
c: Completion,
|
||||
) -> Result<Completion> {
|
||||
let r = c.as_read();
|
||||
let size = r.buf().len();
|
||||
assert!(page_idx > 0);
|
||||
@@ -66,7 +71,42 @@ impl DatabaseStorage for DatabaseFile {
|
||||
return Err(LimboError::NotADB);
|
||||
}
|
||||
let pos = (page_idx - 1) * size;
|
||||
self.file.pread(pos, c)
|
||||
|
||||
#[cfg(feature = "encryption")]
|
||||
{
|
||||
if let Some(key) = encryption_key {
|
||||
let key_clone = key.clone();
|
||||
let read_buffer = r.buf_arc();
|
||||
let original_c = c.clone();
|
||||
|
||||
let decrypt_complete = Box::new(move |buf: Arc<Buffer>, bytes_read: i32| {
|
||||
if bytes_read > 0 {
|
||||
match decrypt_page(buf.as_slice(), page_idx, &key_clone) {
|
||||
Ok(decrypted_data) => {
|
||||
let original_buf = original_c.as_read().buf();
|
||||
original_buf.as_mut_slice().copy_from_slice(&decrypted_data);
|
||||
original_c.complete(bytes_read);
|
||||
}
|
||||
Err(_) => {
|
||||
original_c.complete(-1);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
original_c.complete(bytes_read);
|
||||
}
|
||||
});
|
||||
|
||||
let new_completion = Completion::new_read(read_buffer, decrypt_complete);
|
||||
self.file.pread(pos, new_completion)
|
||||
} else {
|
||||
self.file.pread(pos, c)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "encryption"))]
|
||||
{
|
||||
self.file.pread(pos, c)
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip_all, level = Level::DEBUG)]
|
||||
@@ -74,6 +114,7 @@ impl DatabaseStorage for DatabaseFile {
|
||||
&self,
|
||||
page_idx: usize,
|
||||
buffer: Arc<Buffer>,
|
||||
#[cfg(feature = "encryption")] encryption_key: Option<&EncryptionKey>,
|
||||
c: Completion,
|
||||
) -> Result<Completion> {
|
||||
let buffer_size = buffer.len();
|
||||
@@ -82,21 +123,57 @@ impl DatabaseStorage for DatabaseFile {
|
||||
assert!(buffer_size <= 65536);
|
||||
assert_eq!(buffer_size & (buffer_size - 1), 0);
|
||||
let pos = (page_idx - 1) * buffer_size;
|
||||
let buffer = {
|
||||
#[cfg(feature = "encryption")]
|
||||
{
|
||||
if let Some(key) = encryption_key {
|
||||
encrypt_buffer(page_idx, buffer, key)
|
||||
} else {
|
||||
buffer
|
||||
}
|
||||
}
|
||||
#[cfg(not(feature = "encryption"))]
|
||||
{
|
||||
buffer
|
||||
}
|
||||
};
|
||||
self.file.pwrite(pos, buffer, c)
|
||||
}
|
||||
|
||||
fn write_pages(
|
||||
&self,
|
||||
page_idx: usize,
|
||||
first_page_idx: usize,
|
||||
page_size: usize,
|
||||
buffers: Vec<Arc<Buffer>>,
|
||||
#[cfg(feature = "encryption")] encryption_key: Option<&EncryptionKey>,
|
||||
c: Completion,
|
||||
) -> Result<Completion> {
|
||||
assert!(page_idx > 0);
|
||||
assert!(first_page_idx > 0);
|
||||
assert!(page_size >= 512);
|
||||
assert!(page_size <= 65536);
|
||||
assert_eq!(page_size & (page_size - 1), 0);
|
||||
let pos = (page_idx - 1) * page_size;
|
||||
|
||||
let pos = (first_page_idx - 1) * page_size;
|
||||
|
||||
let buffers = {
|
||||
#[cfg(feature = "encryption")]
|
||||
{
|
||||
if let Some(key) = encryption_key {
|
||||
buffers
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(i, buffer)| encrypt_buffer(first_page_idx + i, buffer, key))
|
||||
.collect::<Vec<_>>()
|
||||
} else {
|
||||
buffers
|
||||
}
|
||||
}
|
||||
#[cfg(not(feature = "encryption"))]
|
||||
{
|
||||
buffers
|
||||
}
|
||||
};
|
||||
|
||||
let c = self.file.pwritev(pos, buffers, c)?;
|
||||
Ok(c)
|
||||
}
|
||||
@@ -124,3 +201,9 @@ impl DatabaseFile {
|
||||
Self { file }
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "encryption")]
|
||||
fn encrypt_buffer(page_idx: usize, buffer: Arc<Buffer>, key: &EncryptionKey) -> Arc<Buffer> {
|
||||
let encrypted_data = encrypt_page(buffer.as_slice(), page_idx, key).unwrap();
|
||||
Arc::new(Buffer::new(encrypted_data.to_vec()))
|
||||
}
|
||||
|
||||
@@ -633,8 +633,8 @@ impl Pager {
|
||||
// Check if the calculated offset for the entry is within the bounds of the actual page data length.
|
||||
if offset_in_ptrmap_page + PTRMAP_ENTRY_SIZE > actual_data_length {
|
||||
return Err(LimboError::InternalError(format!(
|
||||
"Ptrmap offset {offset_in_ptrmap_page} + entry size {PTRMAP_ENTRY_SIZE} out of bounds for page {ptrmap_pg_no} (actual data len {actual_data_length})"
|
||||
)));
|
||||
"Ptrmap offset {offset_in_ptrmap_page} + entry size {PTRMAP_ENTRY_SIZE} out of bounds for page {ptrmap_pg_no} (actual data len {actual_data_length})"
|
||||
)));
|
||||
}
|
||||
|
||||
let entry_slice = &ptrmap_page_data_slice
|
||||
@@ -676,8 +676,8 @@ impl Pager {
|
||||
|| is_ptrmap_page(db_page_no_to_update, page_size)
|
||||
{
|
||||
return Err(LimboError::InternalError(format!(
|
||||
"Cannot set ptrmap entry for page {db_page_no_to_update}: it's a header/ptrmap page or invalid."
|
||||
)));
|
||||
"Cannot set ptrmap entry for page {db_page_no_to_update}: it's a header/ptrmap page or invalid."
|
||||
)));
|
||||
}
|
||||
|
||||
let ptrmap_pg_no = get_ptrmap_page_no_for_db_page(db_page_no_to_update, page_size);
|
||||
@@ -1012,8 +1012,15 @@ impl Pager {
|
||||
matches!(frame_watermark, Some(0) | None),
|
||||
"frame_watermark must be either None or Some(0) because DB has no WAL and read with other watermark is invalid"
|
||||
);
|
||||
|
||||
page.set_locked();
|
||||
let c = self.begin_read_disk_page(page_idx, page.clone(), allow_empty_read)?;
|
||||
let c = self.begin_read_disk_page(
|
||||
page_idx,
|
||||
page.clone(),
|
||||
allow_empty_read,
|
||||
#[cfg(feature = "encryption")]
|
||||
self.encryption_key.borrow().as_ref(),
|
||||
)?;
|
||||
return Ok((page, c));
|
||||
};
|
||||
|
||||
@@ -1026,7 +1033,13 @@ impl Pager {
|
||||
return Ok((page, c));
|
||||
}
|
||||
|
||||
let c = self.begin_read_disk_page(page_idx, page.clone(), allow_empty_read)?;
|
||||
let c = self.begin_read_disk_page(
|
||||
page_idx,
|
||||
page.clone(),
|
||||
allow_empty_read,
|
||||
#[cfg(feature = "encryption")]
|
||||
self.encryption_key.borrow().as_ref(),
|
||||
)?;
|
||||
Ok((page, c))
|
||||
}
|
||||
|
||||
@@ -2087,10 +2100,10 @@ mod ptrmap {
|
||||
pub fn serialize(&self, buffer: &mut [u8]) -> Result<()> {
|
||||
if buffer.len() < PTRMAP_ENTRY_SIZE {
|
||||
return Err(LimboError::InternalError(format!(
|
||||
"Buffer too small to serialize ptrmap entry. Expected at least {} bytes, got {}",
|
||||
PTRMAP_ENTRY_SIZE,
|
||||
buffer.len()
|
||||
)));
|
||||
"Buffer too small to serialize ptrmap entry. Expected at least {} bytes, got {}",
|
||||
PTRMAP_ENTRY_SIZE,
|
||||
buffer.len()
|
||||
)));
|
||||
}
|
||||
buffer[0] = self.entry_type as u8;
|
||||
buffer[1..5].copy_from_slice(&self.parent_page_no.to_be_bytes());
|
||||
|
||||
@@ -874,6 +874,7 @@ pub fn begin_read_page(
|
||||
page: PageRef,
|
||||
page_idx: usize,
|
||||
allow_empty_read: bool,
|
||||
encryption_key: Option<&EncryptionKey>,
|
||||
) -> Result<Completion> {
|
||||
tracing::trace!("begin_read_btree_page(page_idx = {})", page_idx);
|
||||
let buf = buffer_pool.get_page();
|
||||
@@ -895,7 +896,7 @@ pub fn begin_read_page(
|
||||
finish_read_page(page_idx, buf, page.clone());
|
||||
});
|
||||
let c = Completion::new_read(buf, complete);
|
||||
db_file.read_page(page_idx, c)
|
||||
db_file.read_page(page_idx, encryption_key, c)
|
||||
}
|
||||
|
||||
#[instrument(skip_all, level = Level::INFO)]
|
||||
@@ -946,7 +947,7 @@ pub fn begin_write_btree_page(pager: &Pager, page: &PageRef) -> Result<Completio
|
||||
})
|
||||
};
|
||||
let c = Completion::new_write(write_complete);
|
||||
page_source.write_page(page_id, buffer.clone(), c)
|
||||
page_source.write_page(page_id, buffer.clone(), None, c)
|
||||
}
|
||||
|
||||
#[instrument(skip_all, level = Level::DEBUG)]
|
||||
@@ -964,6 +965,7 @@ pub fn write_pages_vectored(
|
||||
pager: &Pager,
|
||||
batch: BTreeMap<usize, Arc<Buffer>>,
|
||||
flush: &PendingFlush,
|
||||
encryption_key: Option<&EncryptionKey>,
|
||||
) -> Result<Vec<Completion>> {
|
||||
if batch.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
@@ -1043,6 +1045,7 @@ pub fn write_pages_vectored(
|
||||
start_id,
|
||||
page_sz,
|
||||
std::mem::replace(&mut run_bufs, Vec::with_capacity(EST_BUFF_CAPACITY)),
|
||||
encryption_key,
|
||||
c,
|
||||
) {
|
||||
Ok(c) => {
|
||||
|
||||
Reference in New Issue
Block a user