merge main

TcMits
2025-09-02 18:25:20 +07:00
94 changed files with 3786 additions and 1684 deletions

View File

@@ -1,5 +1,7 @@
use thiserror::Error;
use crate::storage::page_cache::CacheError;
#[derive(Debug, Clone, Error, miette::Diagnostic)]
pub enum LimboError {
#[error("Corrupt database: {0}")]
@@ -8,8 +10,8 @@ pub enum LimboError {
NotADB,
#[error("Internal error: {0}")]
InternalError(String),
#[error("Page cache is full")]
CacheFull,
#[error(transparent)]
CacheError(#[from] CacheError),
#[error("Database is full: {0}")]
DatabaseFull(String),
#[error("Parse error: {0}")]

View File

@@ -583,6 +583,7 @@ impl Display for MathFunc {
#[derive(Debug)]
pub enum AlterTableFunc {
RenameTable,
AlterColumn,
RenameColumn,
}
@@ -591,6 +592,7 @@ impl Display for AlterTableFunc {
match self {
AlterTableFunc::RenameTable => write!(f, "limbo_rename_table"),
AlterTableFunc::RenameColumn => write!(f, "limbo_rename_column"),
AlterTableFunc::AlterColumn => write!(f, "limbo_alter_column"),
}
}
}
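For reference, an illustrative check (not in this commit) of what the new variant renders as through the `Display` impl above:
#[test]
fn alter_table_func_display_sketch() {
    assert_eq!(AlterTableFunc::AlterColumn.to_string(), "limbo_alter_column");
    assert_eq!(AlterTableFunc::RenameColumn.to_string(), "limbo_rename_column");
}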

View File

@@ -1,8 +1,40 @@
use core::f64;
use crate::types::Value;
use crate::vdbe::Register;
use crate::LimboError;
// TODO: Support %!.3s %i, %x, %X, %o, %e, %E, %c. flags: - + 0 ! ,
fn get_exponential_formatted_str(number: &f64, uppercase: bool) -> crate::Result<String> {
let pre_formatted = format!("{number:.6e}");
let mut parts = pre_formatted.split("e");
let maybe_base = parts.next();
let maybe_exponent = parts.next();
let mut result = String::new();
match (maybe_base, maybe_exponent) {
(Some(base), Some(exponent)) => {
result.push_str(base);
result.push_str(if uppercase { "E" } else { "e" });
match exponent.parse::<i32>() {
Ok(exponent_number) => {
let exponent_fmt = format!("{exponent_number:+03}");
result.push_str(&exponent_fmt);
Ok(result)
}
Err(_) => Err(LimboError::InternalError(
"unable to parse exponential expression's exponent".into(),
)),
}
}
(_, _) => Err(LimboError::InternalError(
"unable to parse exponential expression".into(),
)),
}
}
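// Illustrative sketch, not part of this commit: Rust's `{:.6e}` renders
// 23000000.0 as "2.300000e7"; the helper above re-emits the exponent as a
// signed, zero-padded value so the output matches SQLite-style `%e`/`%E`.
#[cfg(test)]
mod exponential_helper_sketch {
    use super::get_exponential_formatted_str;

    #[test]
    fn pads_and_signs_the_exponent() {
        assert_eq!(format!("{:.6e}", 23000000.0_f64), "2.300000e7");
        assert_eq!(
            get_exponential_formatted_str(&23000000.0, false).unwrap(),
            "2.300000e+07"
        );
        assert_eq!(
            get_exponential_formatted_str(&0.0003235, true).unwrap(),
            "3.235000E-04"
        );
    }
}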
// TODO: Support %!.3s. flags: - + 0 ! ,
#[inline(always)]
pub fn exec_printf(values: &[Register]) -> crate::Result<Value> {
if values.is_empty() {
@@ -40,6 +72,20 @@ pub fn exec_printf(values: &[Register]) -> crate::Result<Value> {
}
args_index += 1;
}
Some('u') => {
if args_index >= values.len() {
return Err(LimboError::InvalidArgument("not enough arguments".into()));
}
let value = &values[args_index].get_value();
match value {
Value::Integer(_) => {
let converted_value = value.as_uint();
result.push_str(&format!("{converted_value}"))
}
_ => result.push('0'),
}
args_index += 1;
}
Some('s') => {
if args_index >= values.len() {
return Err(LimboError::InvalidArgument("not enough arguments".into()));
@@ -63,6 +109,119 @@ pub fn exec_printf(values: &[Register]) -> crate::Result<Value> {
}
args_index += 1;
}
Some('e') => {
if args_index >= values.len() {
return Err(LimboError::InvalidArgument("not enough arguments".into()));
}
let value = &values[args_index].get_value();
match value {
Value::Float(f) => match get_exponential_formatted_str(f, false) {
Ok(str) => result.push_str(&str),
Err(e) => return Err(e),
},
Value::Integer(i) => {
let f = *i as f64;
match get_exponential_formatted_str(&f, false) {
Ok(str) => result.push_str(&str),
Err(e) => return Err(e),
}
}
Value::Text(s) => {
let number: f64 = s
.as_str()
.trim_start()
.trim_end_matches(|c: char| !c.is_numeric())
.parse()
.unwrap_or(0.0);
match get_exponential_formatted_str(&number, false) {
Ok(str) => result.push_str(&str),
Err(e) => return Err(e),
};
}
_ => result.push_str("0.000000e+00"),
}
args_index += 1;
}
Some('E') => {
if args_index >= values.len() {
return Err(LimboError::InvalidArgument("not enough arguments".into()));
}
let value = &values[args_index].get_value();
match value {
Value::Float(f) => match get_exponential_formatted_str(f, true) {
Ok(str) => result.push_str(&str),
Err(e) => return Err(e),
},
Value::Integer(i) => {
let f = *i as f64;
match get_exponential_formatted_str(&f, true) {
Ok(str) => result.push_str(&str),
Err(e) => return Err(e),
}
}
Value::Text(s) => {
let number: f64 = s
.as_str()
.trim_start()
.trim_end_matches(|c: char| !c.is_numeric())
.parse()
.unwrap_or(0.0);
match get_exponential_formatted_str(&number, true) {
Ok(str) => result.push_str(&str),
Err(e) => return Err(e),
};
}
_ => result.push_str("0.000000e+00"),
}
args_index += 1;
}
Some('c') => {
if args_index >= values.len() {
return Err(LimboError::InvalidArgument("not enough arguments".into()));
}
let value = &values[args_index].get_value();
let value_str: String = format!("{value}");
if let Some(first_char) = value_str.chars().next() {
result.push(first_char);
}
args_index += 1;
}
Some('x') => {
if args_index >= values.len() {
return Err(LimboError::InvalidArgument("not enough arguments".into()));
}
let value = &values[args_index].get_value();
match value {
Value::Float(f) => result.push_str(&format!("{:x}", *f as i64)),
Value::Integer(i) => result.push_str(&format!("{i:x}")),
_ => result.push('0'),
}
args_index += 1;
}
Some('X') => {
if args_index >= values.len() {
return Err(LimboError::InvalidArgument("not enough arguments".into()));
}
let value = &values[args_index].get_value();
match value {
Value::Float(f) => result.push_str(&format!("{:X}", *f as i64)),
Value::Integer(i) => result.push_str(&format!("{i:X}")),
_ => result.push('0'),
}
args_index += 1;
}
Some('o') => {
if args_index >= values.len() {
return Err(LimboError::InvalidArgument("not enough arguments".into()));
}
let value = &values[args_index].get_value();
match value {
Value::Float(f) => result.push_str(&format!("{:o}", *f as i64)),
Value::Integer(i) => result.push_str(&format!("{i:o}")),
_ => result.push('0'),
}
args_index += 1;
}
None => {
return Err(LimboError::InvalidArgument(
"incomplete format specifier".into(),
@@ -159,6 +318,29 @@ mod tests {
}
}
#[test]
fn test_printf_unsigned_integer_formatting() {
let test_cases = vec![
// Basic
(vec![text("Number: %u"), integer(42)], text("Number: 42")),
// Multiple numbers
(
vec![text("%u + %u = %u"), integer(2), integer(3), integer(5)],
text("2 + 3 = 5"),
),
// Negative number should be represented as its uint representation
(
vec![text("Negative: %u"), integer(-1)],
text("Negative: 18446744073709551615"),
),
// Non-numeric value defaults to 0
(vec![text("NaN: %u"), text("not a number")], text("NaN: 0")),
];
for (input, output) in test_cases {
assert_eq!(exec_printf(&input).unwrap(), *output.get_value())
}
}
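// Sketch, not in the original commit: the reinterpretation behind `%u`.
// Negative integers are printed as their unsigned 64-bit two's-complement
// value, which is what the "18446744073709551615" expectation above encodes.
#[test]
fn unsigned_reinterpretation_sketch() {
    assert_eq!((-1i64) as u64, u64::MAX);
    assert_eq!(u64::MAX.to_string(), "18446744073709551615");
}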
#[test]
fn test_printf_float_formatting() {
let test_cases = vec![
@@ -194,6 +376,178 @@ mod tests {
}
}
#[test]
fn test_printf_character_formatting() {
let test_cases = vec![
// Simple character
(vec![text("character: %c"), text("a")], text("character: a")),
// Character with string
(
vec![text("character: %c"), text("this is a test")],
text("character: t"),
),
// Character with empty
(vec![text("character: %c"), text("")], text("character: ")),
// Character with integer
(
vec![text("character: %c"), integer(123)],
text("character: 1"),
),
// Character with float
(
vec![text("character: %c"), float(42.5)],
text("character: 4"),
),
];
for (input, expected) in test_cases {
assert_eq!(exec_printf(&input).unwrap(), *expected.get_value());
}
}
#[test]
fn test_printf_exponential_formatting() {
let test_cases = vec![
// Simple number
(
vec![text("Exp: %e"), float(23000000.0)],
text("Exp: 2.300000e+07"),
),
// Negative number
(
vec![text("Exp: %e"), float(-23000000.0)],
text("Exp: -2.300000e+07"),
),
// Non integer float
(
vec![text("Exp: %e"), float(250.375)],
text("Exp: 2.503750e+02"),
),
// Positive, but smaller than zero
(
vec![text("Exp: %e"), float(0.0003235)],
text("Exp: 3.235000e-04"),
),
// Zero
(vec![text("Exp: %e"), float(0.0)], text("Exp: 0.000000e+00")),
// Uppercase "E"
(
vec![text("Exp: %E"), float(0.0003235)],
text("Exp: 3.235000E-04"),
),
// String with integer number
(
vec![text("Exp: %e"), text("123")],
text("Exp: 1.230000e+02"),
),
// String with floating point number
(
vec![text("Exp: %e"), text("123.45")],
text("Exp: 1.234500e+02"),
),
// String with number with leftmost zeroes
(
vec![text("Exp: %e"), text("00123")],
text("Exp: 1.230000e+02"),
),
// String with text
(
vec![text("Exp: %e"), text("test")],
text("Exp: 0.000000e+00"),
),
// String starting with number, but with text on the end
(
vec![text("Exp: %e"), text("123ab")],
text("Exp: 1.230000e+02"),
),
// String starting with text, but with number on the end
(
vec![text("Exp: %e"), text("ab123")],
text("Exp: 0.000000e+00"),
),
// String with exponential representation
(
vec![text("Exp: %e"), text("1.230000e+02")],
text("Exp: 1.230000e+02"),
),
];
for (input, expected) in test_cases {
assert_eq!(exec_printf(&input).unwrap(), *expected.get_value());
}
}
#[test]
fn test_printf_hexadecimal_formatting() {
let test_cases = vec![
// Simple number
(vec![text("hex: %x"), integer(4)], text("hex: 4")),
// Bigger Number
(
vec![text("hex: %x"), integer(15565303546)],
text("hex: 39fc3aefa"),
),
// Uppercase letters
(
vec![text("hex: %X"), integer(15565303546)],
text("hex: 39FC3AEFA"),
),
// Negative
(
vec![text("hex: %x"), integer(-15565303546)],
text("hex: fffffffc603c5106"),
),
// Float
(vec![text("hex: %x"), float(42.5)], text("hex: 2a")),
// Negative Float
(
vec![text("hex: %x"), float(-42.5)],
text("hex: ffffffffffffffd6"),
),
// Text
(vec![text("hex: %x"), text("42")], text("hex: 0")),
// Empty Text
(vec![text("hex: %x"), text("")], text("hex: 0")),
];
for (input, expected) in test_cases {
assert_eq!(exec_printf(&input).unwrap(), *expected.get_value());
}
}
#[test]
fn test_printf_octal_formatting() {
let test_cases = vec![
// Simple number
(vec![text("octal: %o"), integer(4)], text("octal: 4")),
// Bigger Number
(
vec![text("octal: %o"), integer(15565303546)],
text("octal: 163760727372"),
),
// Negative
(
vec![text("octal: %o"), integer(-15565303546)],
text("octal: 1777777777614017050406"),
),
// Float
(vec![text("octal: %o"), float(42.5)], text("octal: 52")),
// Negative Float
(
vec![text("octal: %o"), float(-42.5)],
text("octal: 1777777777777777777726"),
),
// Text
(vec![text("octal: %o"), text("42")], text("octal: 0")),
// Empty Text
(vec![text("octal: %o"), text("")], text("octal: 0")),
];
for (input, expected) in test_cases {
assert_eq!(exec_printf(&input).unwrap(), *expected.get_value());
}
}
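// Sketch, not in the original commit: `%x`/`%X`/`%o` print the two's-complement
// bit pattern of the value as an i64, and floats are truncated toward zero first,
// which is where the long negative expectations above come from.
#[test]
fn hex_octal_bit_pattern_sketch() {
    assert_eq!(-42.5_f64 as i64, -42);
    assert_eq!(format!("{:x}", -42i64), "ffffffffffffffd6");
    assert_eq!(format!("{:o}", -42i64), "1777777777777777777726");
}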
#[test]
fn test_printf_mixed_formatting() {
let test_cases = vec![

View File

@@ -68,9 +68,9 @@ impl File for GenericFile {
}
#[instrument(skip(self, c), level = Level::TRACE)]
fn pread(&self, pos: usize, c: Completion) -> Result<Completion> {
fn pread(&self, pos: u64, c: Completion) -> Result<Completion> {
let mut file = self.file.write();
file.seek(std::io::SeekFrom::Start(pos as u64))?;
file.seek(std::io::SeekFrom::Start(pos))?;
let nr = {
let r = c.as_read();
let buf = r.buf();
@@ -83,9 +83,9 @@ impl File for GenericFile {
}
#[instrument(skip(self, c, buffer), level = Level::TRACE)]
fn pwrite(&self, pos: usize, buffer: Arc<crate::Buffer>, c: Completion) -> Result<Completion> {
fn pwrite(&self, pos: u64, buffer: Arc<crate::Buffer>, c: Completion) -> Result<Completion> {
let mut file = self.file.write();
file.seek(std::io::SeekFrom::Start(pos as u64))?;
file.seek(std::io::SeekFrom::Start(pos))?;
let buf = buffer.as_slice();
file.write_all(buf)?;
c.complete(buffer.len() as i32);
@@ -101,9 +101,9 @@ impl File for GenericFile {
}
#[instrument(err, skip_all, level = Level::TRACE)]
fn truncate(&self, len: usize, c: Completion) -> Result<Completion> {
fn truncate(&self, len: u64, c: Completion) -> Result<Completion> {
let file = self.file.write();
file.set_len(len as u64)?;
file.set_len(len)?;
c.complete(0);
Ok(c)
}

View File

@@ -182,7 +182,7 @@ struct WritevState {
/// File descriptor/id of the file we are writing to
file_id: Fd,
/// absolute file offset for next submit
file_pos: usize,
file_pos: u64,
/// current buffer index in `bufs`
current_buffer_idx: usize,
/// intra-buffer offset
@@ -198,7 +198,7 @@ struct WritevState {
}
impl WritevState {
fn new(file: &UringFile, pos: usize, bufs: Vec<Arc<crate::Buffer>>) -> Self {
fn new(file: &UringFile, pos: u64, bufs: Vec<Arc<crate::Buffer>>) -> Self {
let file_id = file
.id()
.map(Fd::Fixed)
@@ -223,23 +223,23 @@ impl WritevState {
/// Advance (idx, off, pos) after written bytes
#[inline(always)]
fn advance(&mut self, written: usize) {
fn advance(&mut self, written: u64) {
let mut remaining = written;
while remaining > 0 {
let current_buf_len = self.bufs[self.current_buffer_idx].len();
let left = current_buf_len - self.current_buffer_offset;
if remaining < left {
self.current_buffer_offset += remaining;
if remaining < left as u64 {
self.current_buffer_offset += remaining as usize;
self.file_pos += remaining;
remaining = 0;
} else {
remaining -= left;
self.file_pos += left;
remaining -= left as u64;
self.file_pos += left as u64;
self.current_buffer_idx += 1;
self.current_buffer_offset = 0;
}
}
self.total_written += written;
self.total_written += written as usize;
}
#[inline(always)]
@@ -400,7 +400,7 @@ impl WrappedIOUring {
iov_allocation[0].iov_len as u32,
id as u16,
)
.offset(st.file_pos as u64)
.offset(st.file_pos)
.build()
.user_data(key)
} else {
@@ -409,7 +409,7 @@ impl WrappedIOUring {
iov_allocation[0].iov_base as *const u8,
iov_allocation[0].iov_len as u32,
)
.offset(st.file_pos as u64)
.offset(st.file_pos)
.build()
.user_data(key)
}
@@ -425,7 +425,7 @@ impl WrappedIOUring {
let entry = with_fd!(st.file_id, |fd| {
io_uring::opcode::Writev::new(fd, ptr, iov_count as u32)
.offset(st.file_pos as u64)
.offset(st.file_pos)
.build()
.user_data(key)
});
@@ -443,8 +443,8 @@ impl WrappedIOUring {
return;
}
let written = result as usize;
state.advance(written);
let written = result;
state.advance(written as u64);
match state.remaining() {
0 => {
tracing::info!(
@@ -643,7 +643,7 @@ impl File for UringFile {
Ok(())
}
fn pread(&self, pos: usize, c: Completion) -> Result<Completion> {
fn pread(&self, pos: u64, c: Completion) -> Result<Completion> {
let r = c.as_read();
let mut io = self.io.borrow_mut();
let read_e = {
@@ -663,14 +663,14 @@ impl File for UringFile {
io.debug_check_fixed(idx, ptr, len);
}
io_uring::opcode::ReadFixed::new(fd, ptr, len as u32, idx as u16)
.offset(pos as u64)
.offset(pos)
.build()
.user_data(get_key(c.clone()))
} else {
trace!("pread(pos = {}, length = {})", pos, len);
// Use Read opcode if fixed buffer is not available
io_uring::opcode::Read::new(fd, buf.as_mut_ptr(), len as u32)
.offset(pos as u64)
.offset(pos)
.build()
.user_data(get_key(c.clone()))
}
@@ -680,7 +680,7 @@ impl File for UringFile {
Ok(c)
}
fn pwrite(&self, pos: usize, buffer: Arc<crate::Buffer>, c: Completion) -> Result<Completion> {
fn pwrite(&self, pos: u64, buffer: Arc<crate::Buffer>, c: Completion) -> Result<Completion> {
let mut io = self.io.borrow_mut();
let write = {
let ptr = buffer.as_ptr();
@@ -698,13 +698,13 @@ impl File for UringFile {
io.debug_check_fixed(idx, ptr, len);
}
io_uring::opcode::WriteFixed::new(fd, ptr, len as u32, idx as u16)
.offset(pos as u64)
.offset(pos)
.build()
.user_data(get_key(c.clone()))
} else {
trace!("pwrite(pos = {}, length = {})", pos, buffer.len());
io_uring::opcode::Write::new(fd, ptr, len as u32)
.offset(pos as u64)
.offset(pos)
.build()
.user_data(get_key(c.clone()))
}
@@ -728,7 +728,7 @@ impl File for UringFile {
fn pwritev(
&self,
pos: usize,
pos: u64,
bufs: Vec<Arc<crate::Buffer>>,
c: Completion,
) -> Result<Completion> {
@@ -748,10 +748,10 @@ impl File for UringFile {
Ok(self.file.metadata()?.len())
}
fn truncate(&self, len: usize, c: Completion) -> Result<Completion> {
fn truncate(&self, len: u64, c: Completion) -> Result<Completion> {
let mut io = self.io.borrow_mut();
let truncate = with_fd!(self, |fd| {
io_uring::opcode::Ftruncate::new(fd, len as u64)
io_uring::opcode::Ftruncate::new(fd, len)
.build()
.user_data(get_key(c.clone()))
});
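A self-contained sketch (buffers as plain byte vectors; names are hypothetical) of the bookkeeping `WritevState::advance` performs after a short write:
/// Walk (buffer index, intra-buffer offset, file position) forward by `written`
/// bytes, spilling into the next buffer when the current one is exhausted.
fn advance(bufs: &[Vec<u8>], idx: &mut usize, off: &mut usize, pos: &mut u64, written: u64) {
    let mut remaining = written;
    while remaining > 0 {
        let left = (bufs[*idx].len() - *off) as u64;
        if remaining < left {
            *off += remaining as usize;
            *pos += remaining;
            remaining = 0;
        } else {
            remaining -= left;
            *pos += left;
            *idx += 1;
            *off = 0;
        }
    }
}

#[test]
fn advance_sketch() {
    // two 4-byte buffers, 6 bytes written: ends in buffer 1 at offset 2, file pos 6
    let bufs = vec![vec![0u8; 4], vec![0u8; 4]];
    let (mut idx, mut off, mut pos) = (0usize, 0usize, 0u64);
    advance(&bufs, &mut idx, &mut off, &mut pos, 6);
    assert_eq!((idx, off, pos), (1, 2, 6));
}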

View File

@@ -69,17 +69,12 @@ impl IO for MemoryIO {
files.remove(path);
Ok(())
}
fn run_once(&self) -> Result<()> {
// nop
Ok(())
}
}
pub struct MemoryFile {
path: String,
pages: UnsafeCell<BTreeMap<usize, MemPage>>,
size: Cell<usize>,
size: Cell<u64>,
}
unsafe impl Send for MemoryFile {}
unsafe impl Sync for MemoryFile {}
@@ -92,10 +87,10 @@ impl File for MemoryFile {
Ok(())
}
fn pread(&self, pos: usize, c: Completion) -> Result<Completion> {
fn pread(&self, pos: u64, c: Completion) -> Result<Completion> {
tracing::debug!("pread(path={}): pos={}", self.path, pos);
let r = c.as_read();
let buf_len = r.buf().len();
let buf_len = r.buf().len() as u64;
if buf_len == 0 {
c.complete(0);
return Ok(c);
@@ -110,8 +105,8 @@ impl File for MemoryFile {
let read_len = buf_len.min(file_size - pos);
{
let read_buf = r.buf();
let mut offset = pos;
let mut remaining = read_len;
let mut offset = pos as usize;
let mut remaining = read_len as usize;
let mut buf_offset = 0;
while remaining > 0 {
@@ -134,7 +129,7 @@ impl File for MemoryFile {
Ok(c)
}
fn pwrite(&self, pos: usize, buffer: Arc<Buffer>, c: Completion) -> Result<Completion> {
fn pwrite(&self, pos: u64, buffer: Arc<Buffer>, c: Completion) -> Result<Completion> {
tracing::debug!(
"pwrite(path={}): pos={}, size={}",
self.path,
@@ -147,7 +142,7 @@ impl File for MemoryFile {
return Ok(c);
}
let mut offset = pos;
let mut offset = pos as usize;
let mut remaining = buf_len;
let mut buf_offset = 0;
let data = &buffer.as_slice();
@@ -158,7 +153,7 @@ impl File for MemoryFile {
let bytes_to_write = remaining.min(PAGE_SIZE - page_offset);
{
let page = self.get_or_allocate_page(page_no);
let page = self.get_or_allocate_page(page_no as u64);
page[page_offset..page_offset + bytes_to_write]
.copy_from_slice(&data[buf_offset..buf_offset + bytes_to_write]);
}
@@ -169,7 +164,7 @@ impl File for MemoryFile {
}
self.size
.set(core::cmp::max(pos + buf_len, self.size.get()));
.set(core::cmp::max(pos + buf_len as u64, self.size.get()));
c.complete(buf_len as i32);
Ok(c)
@@ -182,13 +177,13 @@ impl File for MemoryFile {
Ok(c)
}
fn truncate(&self, len: usize, c: Completion) -> Result<Completion> {
fn truncate(&self, len: u64, c: Completion) -> Result<Completion> {
tracing::debug!("truncate(path={}): len={}", self.path, len);
if len < self.size.get() {
// Truncate pages
unsafe {
let pages = &mut *self.pages.get();
pages.retain(|&k, _| k * PAGE_SIZE < len);
pages.retain(|&k, _| k * PAGE_SIZE < len as usize);
}
}
self.size.set(len);
@@ -196,14 +191,14 @@ impl File for MemoryFile {
Ok(c)
}
fn pwritev(&self, pos: usize, buffers: Vec<Arc<Buffer>>, c: Completion) -> Result<Completion> {
fn pwritev(&self, pos: u64, buffers: Vec<Arc<Buffer>>, c: Completion) -> Result<Completion> {
tracing::debug!(
"pwritev(path={}): pos={}, buffers={:?}",
self.path,
pos,
buffers.iter().map(|x| x.len()).collect::<Vec<_>>()
);
let mut offset = pos;
let mut offset = pos as usize;
let mut total_written = 0;
for buffer in buffers {
@@ -222,7 +217,7 @@ impl File for MemoryFile {
let bytes_to_write = remaining.min(PAGE_SIZE - page_offset);
{
let page = self.get_or_allocate_page(page_no);
let page = self.get_or_allocate_page(page_no as u64);
page[page_offset..page_offset + bytes_to_write]
.copy_from_slice(&data[buf_offset..buf_offset + bytes_to_write]);
}
@@ -235,23 +230,23 @@ impl File for MemoryFile {
}
c.complete(total_written as i32);
self.size
.set(core::cmp::max(pos + total_written, self.size.get()));
.set(core::cmp::max(pos + total_written as u64, self.size.get()));
Ok(c)
}
fn size(&self) -> Result<u64> {
tracing::debug!("size(path={}): {}", self.path, self.size.get());
Ok(self.size.get() as u64)
Ok(self.size.get())
}
}
impl MemoryFile {
#[allow(clippy::mut_from_ref)]
fn get_or_allocate_page(&self, page_no: usize) -> &mut MemPage {
fn get_or_allocate_page(&self, page_no: u64) -> &mut MemPage {
unsafe {
let pages = &mut *self.pages.get();
pages
.entry(page_no)
.entry(page_no as usize)
.or_insert_with(|| Box::new([0; PAGE_SIZE]))
}
}
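A small sketch of the page/offset arithmetic `MemoryFile` relies on; the `PAGE_SIZE` value and helper name here are assumptions for illustration:
const PAGE_SIZE: usize = 4096; // assumed value; the real constant lives in this module

/// Split an absolute offset and remaining byte count into
/// (page number, offset within that page, bytes that fit in this page).
fn split_write(offset: usize, remaining: usize) -> (usize, usize, usize) {
    let page_no = offset / PAGE_SIZE;
    let page_offset = offset % PAGE_SIZE;
    (page_no, page_offset, remaining.min(PAGE_SIZE - page_offset))
}

#[test]
fn split_write_sketch() {
    // a 10_000-byte write starting at offset 4100 touches page 1 first,
    // at intra-page offset 4, and can place 4092 bytes there
    assert_eq!(split_write(4100, 10_000), (1, 4, 4092));
}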

View File

@@ -12,10 +12,10 @@ use std::{fmt::Debug, pin::Pin};
pub trait File: Send + Sync {
fn lock_file(&self, exclusive: bool) -> Result<()>;
fn unlock_file(&self) -> Result<()>;
fn pread(&self, pos: usize, c: Completion) -> Result<Completion>;
fn pwrite(&self, pos: usize, buffer: Arc<Buffer>, c: Completion) -> Result<Completion>;
fn pread(&self, pos: u64, c: Completion) -> Result<Completion>;
fn pwrite(&self, pos: u64, buffer: Arc<Buffer>, c: Completion) -> Result<Completion>;
fn sync(&self, c: Completion) -> Result<Completion>;
fn pwritev(&self, pos: usize, buffers: Vec<Arc<Buffer>>, c: Completion) -> Result<Completion> {
fn pwritev(&self, pos: u64, buffers: Vec<Arc<Buffer>>, c: Completion) -> Result<Completion> {
use std::sync::atomic::{AtomicUsize, Ordering};
if buffers.is_empty() {
c.complete(0);
@@ -56,12 +56,12 @@ pub trait File: Send + Sync {
c.abort();
return Err(e);
}
pos += len;
pos += len as u64;
}
Ok(c)
}
fn size(&self) -> Result<u64>;
fn truncate(&self, len: usize, c: Completion) -> Result<Completion>;
fn truncate(&self, len: u64, c: Completion) -> Result<Completion>;
}
#[derive(Debug, Copy, Clone, PartialEq)]
@@ -87,7 +87,9 @@ pub trait IO: Clock + Send + Sync {
// remove_file is used in the sync-engine
fn remove_file(&self, path: &str) -> Result<()>;
fn run_once(&self) -> Result<()>;
fn run_once(&self) -> Result<()> {
Ok(())
}
fn wait_for_completion(&self, c: Completion) -> Result<()> {
while !c.finished() {
@@ -214,6 +216,10 @@ impl Completion {
self.inner.result.get().is_some_and(|val| val.is_some())
}
pub fn get_error(&self) -> Option<CompletionError> {
self.inner.result.get().and_then(|res| *res)
}
/// Checks if the Completion completed or errored
pub fn finished(&self) -> bool {
self.inner.result.get().is_some()
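With `run_once` now defaulted on the `IO` trait, backends that have no event loop can simply omit it; a minimal sketch of the pattern (trait and type names here are hypothetical):
trait IoLike {
    // no-op default, mirroring IO::run_once above
    fn run_once(&self) -> Result<(), ()> {
        Ok(())
    }
}

struct MemoryBackend;
impl IoLike for MemoryBackend {} // inherits the default; nothing to drive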

View File

@@ -15,8 +15,6 @@ use std::{io::ErrorKind, sync::Arc};
use tracing::debug;
use tracing::{instrument, trace, Level};
/// UnixIO lives longer than any of the files it creates, so it is
/// safe to store references to it's internals in the UnixFiles
pub struct UnixIO {}
unsafe impl Send for UnixIO {}
@@ -127,24 +125,6 @@ impl IO for UnixIO {
}
}
// enum CompletionCallback {
// Read(Arc<Mutex<std::fs::File>>, Completion, usize),
// Write(
// Arc<Mutex<std::fs::File>>,
// Completion,
// Arc<crate::Buffer>,
// usize,
// ),
// Writev(
// Arc<Mutex<std::fs::File>>,
// Completion,
// Vec<Arc<crate::Buffer>>,
// usize, // absolute file offset
// usize, // buf index
// usize, // intra-buf offset
// ),
// }
pub struct UnixFile {
file: Arc<Mutex<std::fs::File>>,
}
@@ -192,7 +172,7 @@ impl File for UnixFile {
}
#[instrument(err, skip_all, level = Level::TRACE)]
fn pread(&self, pos: usize, c: Completion) -> Result<Completion> {
fn pread(&self, pos: u64, c: Completion) -> Result<Completion> {
let file = self.file.lock();
let result = unsafe {
let r = c.as_read();
@@ -217,7 +197,7 @@ impl File for UnixFile {
}
#[instrument(err, skip_all, level = Level::TRACE)]
fn pwrite(&self, pos: usize, buffer: Arc<crate::Buffer>, c: Completion) -> Result<Completion> {
fn pwrite(&self, pos: u64, buffer: Arc<crate::Buffer>, c: Completion) -> Result<Completion> {
let file = self.file.lock();
let result = unsafe {
libc::pwrite(
@@ -241,7 +221,7 @@ impl File for UnixFile {
#[instrument(err, skip_all, level = Level::TRACE)]
fn pwritev(
&self,
pos: usize,
pos: u64,
buffers: Vec<Arc<crate::Buffer>>,
c: Completion,
) -> Result<Completion> {
@@ -251,7 +231,7 @@ impl File for UnixFile {
}
let file = self.file.lock();
match try_pwritev_raw(file.as_raw_fd(), pos as u64, &buffers, 0, 0) {
match try_pwritev_raw(file.as_raw_fd(), pos, &buffers, 0, 0) {
Ok(written) => {
trace!("pwritev wrote {written}");
c.complete(written as i32);
@@ -268,24 +248,21 @@ impl File for UnixFile {
let file = self.file.lock();
let result = unsafe {
#[cfg(not(any(target_os = "macos", target_os = "ios")))]
{
libc::fsync(file.as_raw_fd())
}
#[cfg(any(target_os = "macos", target_os = "ios"))]
{
libc::fcntl(file.as_raw_fd(), libc::F_FULLFSYNC)
}
};
if result == -1 {
let e = std::io::Error::last_os_error();
Err(e.into())
} else {
#[cfg(not(any(target_os = "macos", target_os = "ios")))]
trace!("fsync");
@@ -304,9 +281,9 @@ impl File for UnixFile {
}
#[instrument(err, skip_all, level = Level::INFO)]
fn truncate(&self, len: usize, c: Completion) -> Result<Completion> {
fn truncate(&self, len: u64, c: Completion) -> Result<Completion> {
let file = self.file.lock();
let result = file.set_len(len as u64);
let result = file.set_len(len);
match result {
Ok(()) => {
trace!("file truncated to len=({})", len);

View File

@@ -81,8 +81,6 @@ impl VfsMod {
}
}
// #Safety:
// the callback wrapper in the extension library is FnOnce, so we know
/// # Safety
/// the callback wrapper in the extension library is FnOnce, so we know
/// that the into_raw/from_raw contract will hold
@@ -121,7 +119,7 @@ impl File for VfsFileImpl {
Ok(())
}
fn pread(&self, pos: usize, c: Completion) -> Result<Completion> {
fn pread(&self, pos: u64, c: Completion) -> Result<Completion> {
if self.vfs.is_null() {
c.complete(-1);
return Err(LimboError::ExtensionError("VFS is null".to_string()));
@@ -145,7 +143,7 @@ impl File for VfsFileImpl {
Ok(c)
}
fn pwrite(&self, pos: usize, buffer: Arc<Buffer>, c: Completion) -> Result<Completion> {
fn pwrite(&self, pos: u64, buffer: Arc<Buffer>, c: Completion) -> Result<Completion> {
if self.vfs.is_null() {
c.complete(-1);
return Err(LimboError::ExtensionError("VFS is null".to_string()));
@@ -192,7 +190,7 @@ impl File for VfsFileImpl {
}
}
fn truncate(&self, len: usize, c: Completion) -> Result<Completion> {
fn truncate(&self, len: u64, c: Completion) -> Result<Completion> {
if self.vfs.is_null() {
c.complete(-1);
return Err(LimboError::ExtensionError("VFS is null".to_string()));

core/io/windows.rs (new file, 115 lines)
View File

@@ -0,0 +1,115 @@
use crate::{Clock, Completion, File, Instant, LimboError, OpenFlags, Result, IO};
use parking_lot::RwLock;
use std::io::{Read, Seek, Write};
use std::sync::Arc;
use tracing::{debug, instrument, trace, Level};
pub struct WindowsIO {}
impl WindowsIO {
pub fn new() -> Result<Self> {
debug!("Using IO backend 'syscall'");
Ok(Self {})
}
}
impl IO for WindowsIO {
#[instrument(err, skip_all, level = Level::TRACE)]
fn open_file(&self, path: &str, flags: OpenFlags, direct: bool) -> Result<Arc<dyn File>> {
trace!("open_file(path = {})", path);
let mut file = std::fs::File::options();
file.read(true);
if !flags.contains(OpenFlags::ReadOnly) {
file.write(true);
file.create(flags.contains(OpenFlags::Create));
}
let file = file.open(path)?;
Ok(Arc::new(WindowsFile {
file: RwLock::new(file),
}))
}
#[instrument(err, skip_all, level = Level::TRACE)]
fn remove_file(&self, path: &str) -> Result<()> {
trace!("remove_file(path = {})", path);
Ok(std::fs::remove_file(path)?)
}
#[instrument(err, skip_all, level = Level::TRACE)]
fn run_once(&self) -> Result<()> {
Ok(())
}
}
impl Clock for WindowsIO {
fn now(&self) -> Instant {
let now = chrono::Local::now();
Instant {
secs: now.timestamp(),
micros: now.timestamp_subsec_micros(),
}
}
}
pub struct WindowsFile {
file: RwLock<std::fs::File>,
}
impl File for WindowsFile {
#[instrument(err, skip_all, level = Level::TRACE)]
fn lock_file(&self, exclusive: bool) -> Result<()> {
unimplemented!()
}
#[instrument(err, skip_all, level = Level::TRACE)]
fn unlock_file(&self) -> Result<()> {
unimplemented!()
}
#[instrument(skip(self, c), level = Level::TRACE)]
fn pread(&self, pos: u64, c: Completion) -> Result<Completion> {
let mut file = self.file.write();
file.seek(std::io::SeekFrom::Start(pos))?;
let nr = {
let r = c.as_read();
let buf = r.buf();
let buf = buf.as_mut_slice();
file.read_exact(buf)?;
buf.len() as i32
};
c.complete(nr);
Ok(c)
}
#[instrument(skip(self, c, buffer), level = Level::TRACE)]
fn pwrite(&self, pos: u64, buffer: Arc<crate::Buffer>, c: Completion) -> Result<Completion> {
let mut file = self.file.write();
file.seek(std::io::SeekFrom::Start(pos))?;
let buf = buffer.as_slice();
file.write_all(buf)?;
c.complete(buffer.len() as i32);
Ok(c)
}
#[instrument(err, skip_all, level = Level::TRACE)]
fn sync(&self, c: Completion) -> Result<Completion> {
let file = self.file.write();
file.sync_all()?;
c.complete(0);
Ok(c)
}
#[instrument(err, skip_all, level = Level::TRACE)]
fn truncate(&self, len: u64, c: Completion) -> Result<Completion> {
let file = self.file.write();
file.set_len(len)?;
c.complete(0);
Ok(c)
}
fn size(&self) -> Result<u64> {
let file = self.file.read();
Ok(file.metadata()?.len())
}
}
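A hypothetical usage sketch of the new backend; the path and flag are example values, not taken from this commit:
fn open_sketch() -> Result<()> {
    let io = WindowsIO::new()?;
    let file = io.open_file("local.db", OpenFlags::Create, false)?;
    let _len = file.size()?; // 0 for a freshly created file
    Ok(())
}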

View File

@@ -1277,6 +1277,12 @@ impl Connection {
std::fs::set_permissions(&opts.path, perms.permissions())?;
}
let conn = db.connect()?;
if let Some(cipher) = opts.cipher {
let _ = conn.pragma_update("cipher", format!("'{cipher}'"));
}
if let Some(hexkey) = opts.hexkey {
let _ = conn.pragma_update("hexkey", format!("'{hexkey}'"));
}
Ok((io, conn))
}
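The block above is equivalent to issuing the pragmas by hand after connecting; a hedged sketch with placeholder values (the cipher name and key are not taken from this commit):
fn apply_encryption_pragmas(conn: &Connection, cipher: &str, hexkey: &str) {
    // e.g. cipher = "aegis256", hexkey = a 64-character hex string (placeholders)
    let _ = conn.pragma_update("cipher", format!("'{cipher}'"));
    let _ = conn.pragma_update("hexkey", format!("'{hexkey}'"));
}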

View File

@@ -1047,8 +1047,8 @@ impl Column {
}
// TODO: This might replace some of util::columns_from_create_table_body
impl From<ColumnDefinition> for Column {
fn from(value: ColumnDefinition) -> Self {
impl From<&ColumnDefinition> for Column {
fn from(value: &ColumnDefinition) -> Self {
let name = value.col_name.as_str();
let mut default = None;
@@ -1057,13 +1057,13 @@ impl From<ColumnDefinition> for Column {
let mut unique = false;
let mut collation = None;
for ast::NamedColumnConstraint { constraint, .. } in value.constraints {
for ast::NamedColumnConstraint { constraint, .. } in &value.constraints {
match constraint {
ast::ColumnConstraint::PrimaryKey { .. } => primary_key = true,
ast::ColumnConstraint::NotNull { .. } => notnull = true,
ast::ColumnConstraint::Unique(..) => unique = true,
ast::ColumnConstraint::Default(expr) => {
default.replace(expr);
default.replace(expr.clone());
}
ast::ColumnConstraint::Collate { collation_name } => {
collation.replace(
@@ -1082,11 +1082,14 @@ impl From<ColumnDefinition> for Column {
let ty_str = value
.col_type
.as_ref()
.map(|t| t.name.to_string())
.unwrap_or_default();
let hidden = ty_str.contains("HIDDEN");
Column {
name: Some(name.to_string()),
name: Some(normalize_ident(name)),
ty,
default,
notnull,
@@ -1095,7 +1098,7 @@ impl From<ColumnDefinition> for Column {
is_rowid_alias: primary_key && matches!(ty, Type::Integer),
unique,
collation,
hidden: false,
hidden,
}
}
}

View File

@@ -91,7 +91,9 @@ impl DatabaseStorage for DatabaseFile {
if !(512..=65536).contains(&size) || size & (size - 1) != 0 {
return Err(LimboError::NotADB);
}
let pos = (page_idx - 1) * size;
let Some(pos) = (page_idx as u64 - 1).checked_mul(size as u64) else {
return Err(LimboError::IntegerOverflow);
};
if let Some(ctx) = io_ctx.encryption_context() {
let encryption_ctx = ctx.clone();
@@ -145,7 +147,9 @@ impl DatabaseStorage for DatabaseFile {
assert!(buffer_size >= 512);
assert!(buffer_size <= 65536);
assert_eq!(buffer_size & (buffer_size - 1), 0);
let pos = (page_idx - 1) * buffer_size;
let Some(pos) = (page_idx as u64 - 1).checked_mul(buffer_size as u64) else {
return Err(LimboError::IntegerOverflow);
};
let buffer = {
if let Some(ctx) = io_ctx.encryption_context() {
encrypt_buffer(page_idx, buffer, ctx)
@@ -169,7 +173,9 @@ impl DatabaseStorage for DatabaseFile {
assert!(page_size <= 65536);
assert_eq!(page_size & (page_size - 1), 0);
let pos = (first_page_idx - 1) * page_size;
let Some(pos) = (first_page_idx as u64 - 1).checked_mul(page_size as u64) else {
return Err(LimboError::IntegerOverflow);
};
let buffers = {
if let Some(ctx) = io_ctx.encryption_context() {
buffers
@@ -198,7 +204,7 @@ impl DatabaseStorage for DatabaseFile {
#[instrument(skip_all, level = Level::INFO)]
fn truncate(&self, len: usize, c: Completion) -> Result<Completion> {
let c = self.file.truncate(len, c)?;
let c = self.file.truncate(len as u64, c)?;
Ok(c)
}
}
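A short sketch of the overflow-safe offset computation introduced above; the helper name is hypothetical:
/// Page numbers are 1-based; the byte offset of page N is (N - 1) * page_size,
/// computed in u64 with checked_mul so a huge page number cannot silently wrap.
fn page_offset(page_idx: u64, page_size: u64) -> Option<u64> {
    (page_idx - 1).checked_mul(page_size)
}

#[test]
fn page_offset_sketch() {
    assert_eq!(page_offset(1, 4096), Some(0));
    assert_eq!(page_offset(3, 4096), Some(8192));
    assert_eq!(page_offset(u64::MAX, 65536), None); // would overflow, rejected
}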

View File

@@ -1,14 +1,13 @@
#![allow(unused_variables, dead_code)]
use crate::{LimboError, Result};
use aegis::aegis256::Aegis256;
use aes_gcm::{
aead::{Aead, AeadCore, KeyInit, OsRng},
Aes256Gcm, Key, Nonce,
};
use aes_gcm::aead::{AeadCore, OsRng};
use std::ops::Deref;
use turso_macros::match_ignore_ascii_case;
pub const ENCRYPTED_PAGE_SIZE: usize = 4096;
// AEGIS-256 supports both 16- and 32-byte tags; we use the 16-byte variant because it is faster
// and provides sufficient security for our use case.
const AEGIS_TAG_SIZE: usize = 16;
const AES256GCM_TAG_SIZE: usize = 16;
#[repr(transparent)]
#[derive(Clone)]
@@ -74,10 +73,25 @@ impl Drop for EncryptionKey {
}
}
pub trait AeadCipher {
fn encrypt(&self, plaintext: &[u8], ad: &[u8]) -> Result<(Vec<u8>, Vec<u8>)>;
fn decrypt(&self, ciphertext: &[u8], nonce: &[u8], ad: &[u8]) -> Result<Vec<u8>>;
fn encrypt_detached(&self, plaintext: &[u8], ad: &[u8]) -> Result<(Vec<u8>, Vec<u8>, Vec<u8>)>;
fn decrypt_detached(
&self,
ciphertext: &[u8],
nonce: &[u8],
tag: &[u8],
ad: &[u8],
) -> Result<Vec<u8>>;
}
// wrapper struct for AEGIS-256 cipher, because the crate we use is a bit low-level and we add
// some nice abstractions here
// note, the AEGIS has many variants and support for hardware acceleration. Here we just use the
// vanilla version, which is still order of maginitudes faster than AES-GCM in software. Hardware
// vanilla version, which is still order of magnitudes faster than AES-GCM in software. Hardware
// based compilation is left for future work.
#[derive(Clone)]
pub struct Aegis256Cipher {
@@ -85,39 +99,154 @@ pub struct Aegis256Cipher {
}
impl Aegis256Cipher {
// AEGIS-256 supports both 16 and 32 byte tags, we use the 16 byte variant, it is faster
// and provides sufficient security for our use case.
const TAG_SIZE: usize = 16;
fn new(key: &EncryptionKey) -> Self {
Self { key: key.clone() }
}
}
fn encrypt(&self, plaintext: &[u8], ad: &[u8]) -> Result<(Vec<u8>, [u8; 32])> {
impl AeadCipher for Aegis256Cipher {
fn encrypt(&self, plaintext: &[u8], ad: &[u8]) -> Result<(Vec<u8>, Vec<u8>)> {
let nonce = generate_secure_nonce();
let (ciphertext, tag) =
Aegis256::<16>::new(self.key.as_bytes(), &nonce).encrypt(plaintext, ad);
Aegis256::<AEGIS_TAG_SIZE>::new(self.key.as_bytes(), &nonce).encrypt(plaintext, ad);
let mut result = ciphertext;
result.extend_from_slice(&tag);
Ok((result, nonce))
Ok((result, nonce.to_vec()))
}
fn decrypt(&self, ciphertext: &[u8], nonce: &[u8; 32], ad: &[u8]) -> Result<Vec<u8>> {
if ciphertext.len() < Self::TAG_SIZE {
return Err(LimboError::InternalError(
"Ciphertext too short for AEGIS-256".into(),
));
fn decrypt(&self, ciphertext: &[u8], nonce: &[u8], ad: &[u8]) -> Result<Vec<u8>> {
if ciphertext.len() < AEGIS_TAG_SIZE {
return Err(LimboError::InternalError("Ciphertext too short".into()));
}
let (ct, tag) = ciphertext.split_at(ciphertext.len() - Self::TAG_SIZE);
let tag_array: [u8; 16] = tag
.try_into()
.map_err(|_| LimboError::InternalError("Invalid tag size for AEGIS-256".into()))?;
let (ct, tag) = ciphertext.split_at(ciphertext.len() - AEGIS_TAG_SIZE);
let tag_array: [u8; AEGIS_TAG_SIZE] = tag.try_into().map_err(|_| {
LimboError::InternalError(format!("Invalid tag size for AEGIS-256 {AEGIS_TAG_SIZE}"))
})?;
let plaintext = Aegis256::<16>::new(self.key.as_bytes(), nonce)
let nonce_array: [u8; 32] = nonce
.try_into()
.map_err(|_| LimboError::InternalError("Invalid nonce size for AEGIS-256".into()))?;
Aegis256::<AEGIS_TAG_SIZE>::new(self.key.as_bytes(), &nonce_array)
.decrypt(ct, &tag_array, ad)
.map_err(|_| {
LimboError::InternalError("AEGIS-256 decryption failed: invalid tag".into())
})?;
Ok(plaintext)
.map_err(|_| LimboError::InternalError("AEGIS-256 decryption failed".into()))
}
fn encrypt_detached(&self, plaintext: &[u8], ad: &[u8]) -> Result<(Vec<u8>, Vec<u8>, Vec<u8>)> {
let nonce = generate_secure_nonce();
let (ciphertext, tag) =
Aegis256::<AEGIS_TAG_SIZE>::new(self.key.as_bytes(), &nonce).encrypt(plaintext, ad);
Ok((ciphertext, tag.to_vec(), nonce.to_vec()))
}
fn decrypt_detached(
&self,
ciphertext: &[u8],
nonce: &[u8],
tag: &[u8],
ad: &[u8],
) -> Result<Vec<u8>> {
let tag_array: [u8; AEGIS_TAG_SIZE] = tag.try_into().map_err(|_| {
LimboError::InternalError(format!("Invalid tag size for AEGIS-256 {AEGIS_TAG_SIZE}"))
})?;
let nonce_array: [u8; 32] = nonce
.try_into()
.map_err(|_| LimboError::InternalError("Invalid nonce size for AEGIS-256".into()))?;
Aegis256::<AEGIS_TAG_SIZE>::new(self.key.as_bytes(), &nonce_array)
.decrypt(ciphertext, &tag_array, ad)
.map_err(|_| LimboError::InternalError("AEGIS-256 decrypt_detached failed".into()))
}
}
#[derive(Clone)]
pub struct Aes256GcmCipher {
key: EncryptionKey,
}
impl Aes256GcmCipher {
fn new(key: &EncryptionKey) -> Self {
Self { key: key.clone() }
}
}
impl AeadCipher for Aes256GcmCipher {
fn encrypt(&self, plaintext: &[u8], _ad: &[u8]) -> Result<(Vec<u8>, Vec<u8>)> {
use aes_gcm::aead::{AeadInPlace, KeyInit};
use aes_gcm::Aes256Gcm;
let cipher = Aes256Gcm::new_from_slice(self.key.as_bytes())
.map_err(|_| LimboError::InternalError("Bad AES key".into()))?;
let nonce = Aes256Gcm::generate_nonce(&mut rand::thread_rng());
let mut buffer = plaintext.to_vec();
let tag = cipher
.encrypt_in_place_detached(&nonce, b"", &mut buffer)
.map_err(|_| LimboError::InternalError("AES-GCM encrypt failed".into()))?;
buffer.extend_from_slice(&tag[..AES256GCM_TAG_SIZE]);
Ok((buffer, nonce.to_vec()))
}
fn decrypt(&self, ciphertext: &[u8], nonce: &[u8], ad: &[u8]) -> Result<Vec<u8>> {
use aes_gcm::aead::{AeadInPlace, KeyInit};
use aes_gcm::{Aes256Gcm, Nonce};
if ciphertext.len() < AES256GCM_TAG_SIZE {
return Err(LimboError::InternalError("Ciphertext too short".into()));
}
let (ct, tag) = ciphertext.split_at(ciphertext.len() - AES256GCM_TAG_SIZE);
let cipher = Aes256Gcm::new_from_slice(self.key.as_bytes())
.map_err(|_| LimboError::InternalError("Bad AES key".into()))?;
let nonce = Nonce::from_slice(nonce);
let mut buffer = ct.to_vec();
cipher
.decrypt_in_place_detached(nonce, ad, &mut buffer, tag.into())
.map_err(|_| LimboError::InternalError("AES-GCM decrypt failed".into()))?;
Ok(buffer)
}
fn encrypt_detached(&self, plaintext: &[u8], ad: &[u8]) -> Result<(Vec<u8>, Vec<u8>, Vec<u8>)> {
use aes_gcm::aead::{AeadInPlace, KeyInit};
use aes_gcm::Aes256Gcm;
let cipher = Aes256Gcm::new_from_slice(self.key.as_bytes())
.map_err(|_| LimboError::InternalError("Bad AES key".into()))?;
let nonce = Aes256Gcm::generate_nonce(&mut rand::thread_rng());
let mut buffer = plaintext.to_vec();
let tag = cipher
.encrypt_in_place_detached(&nonce, ad, &mut buffer)
.map_err(|_| LimboError::InternalError("AES-GCM encrypt_detached failed".into()))?;
Ok((buffer, nonce.to_vec(), tag.to_vec()))
}
fn decrypt_detached(
&self,
ciphertext: &[u8],
nonce: &[u8],
tag: &[u8],
ad: &[u8],
) -> Result<Vec<u8>> {
use aes_gcm::aead::{AeadInPlace, KeyInit};
use aes_gcm::{Aes256Gcm, Nonce};
let cipher = Aes256Gcm::new_from_slice(self.key.as_bytes())
.map_err(|_| LimboError::InternalError("Bad AES key".into()))?;
let nonce = Nonce::from_slice(nonce);
let mut buffer = ciphertext.to_vec();
cipher
.decrypt_in_place_detached(nonce, ad, &mut buffer, tag.into())
.map_err(|_| LimboError::InternalError("AES-GCM decrypt_detached failed".into()))?;
Ok(buffer)
}
}
@@ -180,8 +309,8 @@ impl CipherMode {
/// Returns the authentication tag size for this cipher mode.
pub fn tag_size(&self) -> usize {
match self {
CipherMode::Aes256Gcm => 16,
CipherMode::Aegis256 => 16,
CipherMode::Aes256Gcm => AES256GCM_TAG_SIZE,
CipherMode::Aegis256 => AEGIS_TAG_SIZE,
}
}
@@ -193,8 +322,17 @@ impl CipherMode {
#[derive(Clone)]
pub enum Cipher {
Aes256Gcm(Box<Aes256Gcm>),
Aegis256(Box<Aegis256Cipher>),
Aes256Gcm(Aes256GcmCipher),
Aegis256(Aegis256Cipher),
}
impl Cipher {
fn as_aead(&self) -> &dyn AeadCipher {
match self {
Cipher::Aes256Gcm(c) => c,
Cipher::Aegis256(c) => c,
}
}
}
impl std::fmt::Debug for Cipher {
@@ -210,10 +348,11 @@ impl std::fmt::Debug for Cipher {
pub struct EncryptionContext {
cipher_mode: CipherMode,
cipher: Cipher,
page_size: usize,
}
impl EncryptionContext {
pub fn new(cipher_mode: CipherMode, key: &EncryptionKey) -> Result<Self> {
pub fn new(cipher_mode: CipherMode, key: &EncryptionKey, page_size: usize) -> Result<Self> {
let required_size = cipher_mode.required_key_size();
if key.as_slice().len() != required_size {
return Err(crate::LimboError::InvalidArgument(format!(
@@ -225,15 +364,13 @@ impl EncryptionContext {
}
let cipher = match cipher_mode {
CipherMode::Aes256Gcm => {
let cipher_key: &Key<Aes256Gcm> = key.as_ref().into();
Cipher::Aes256Gcm(Box::new(Aes256Gcm::new(cipher_key)))
}
CipherMode::Aegis256 => Cipher::Aegis256(Box::new(Aegis256Cipher::new(key))),
CipherMode::Aes256Gcm => Cipher::Aes256Gcm(Aes256GcmCipher::new(key)),
CipherMode::Aegis256 => Cipher::Aegis256(Aegis256Cipher::new(key)),
};
Ok(Self {
cipher_mode,
cipher,
page_size,
})
}
@@ -255,36 +392,38 @@ impl EncryptionContext {
tracing::debug!("encrypting page {}", page_id);
assert_eq!(
page.len(),
ENCRYPTED_PAGE_SIZE,
"Page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes"
self.page_size,
"Page data must be exactly {} bytes",
self.page_size
);
let metadata_size = self.cipher_mode.metadata_size();
let reserved_bytes = &page[ENCRYPTED_PAGE_SIZE - metadata_size..];
let reserved_bytes = &page[self.page_size - metadata_size..];
let reserved_bytes_zeroed = reserved_bytes.iter().all(|&b| b == 0);
assert!(
reserved_bytes_zeroed,
"last reserved bytes must be empty/zero, but found non-zero bytes"
);
let payload = &page[..ENCRYPTED_PAGE_SIZE - metadata_size];
let payload = &page[..self.page_size - metadata_size];
let (encrypted, nonce) = self.encrypt_raw(payload)?;
let nonce_size = self.cipher_mode.nonce_size();
assert_eq!(
encrypted.len(),
ENCRYPTED_PAGE_SIZE - nonce_size,
self.page_size - nonce_size,
"Encrypted page must be exactly {} bytes",
ENCRYPTED_PAGE_SIZE - nonce_size
self.page_size - nonce_size
);
let mut result = Vec::with_capacity(ENCRYPTED_PAGE_SIZE);
let mut result = Vec::with_capacity(self.page_size);
result.extend_from_slice(&encrypted);
result.extend_from_slice(&nonce);
assert_eq!(
result.len(),
ENCRYPTED_PAGE_SIZE,
"Encrypted page must be exactly {ENCRYPTED_PAGE_SIZE} bytes"
self.page_size,
"Encrypted page must be exactly {} bytes",
self.page_size
);
Ok(result)
}
@@ -298,8 +437,9 @@ impl EncryptionContext {
tracing::debug!("decrypting page {}", page_id);
assert_eq!(
encrypted_page.len(),
ENCRYPTED_PAGE_SIZE,
"Encrypted page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes"
self.page_size,
"Encrypted page data must be exactly {} bytes",
self.page_size
);
let nonce_size = self.cipher_mode.nonce_size();
@@ -312,60 +452,40 @@ impl EncryptionContext {
let metadata_size = self.cipher_mode.metadata_size();
assert_eq!(
decrypted_data.len(),
ENCRYPTED_PAGE_SIZE - metadata_size,
self.page_size - metadata_size,
"Decrypted page data must be exactly {} bytes",
ENCRYPTED_PAGE_SIZE - metadata_size
self.page_size - metadata_size
);
let mut result = Vec::with_capacity(ENCRYPTED_PAGE_SIZE);
let mut result = Vec::with_capacity(self.page_size);
result.extend_from_slice(&decrypted_data);
result.resize(ENCRYPTED_PAGE_SIZE, 0);
result.resize(self.page_size, 0);
assert_eq!(
result.len(),
ENCRYPTED_PAGE_SIZE,
"Decrypted page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes"
self.page_size,
"Decrypted page data must be exactly {} bytes",
self.page_size
);
Ok(result)
}
/// encrypts raw data using the configured cipher, returns ciphertext and nonce
fn encrypt_raw(&self, plaintext: &[u8]) -> Result<(Vec<u8>, Vec<u8>)> {
match &self.cipher {
Cipher::Aes256Gcm(cipher) => {
let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
let ciphertext = cipher
.encrypt(&nonce, plaintext)
.map_err(|e| LimboError::InternalError(format!("Encryption failed: {e:?}")))?;
Ok((ciphertext, nonce.to_vec()))
}
Cipher::Aegis256(cipher) => {
let ad = b"";
let (ciphertext, nonce) = cipher.encrypt(plaintext, ad)?;
Ok((ciphertext, nonce.to_vec()))
}
}
self.cipher.as_aead().encrypt(plaintext, b"")
}
fn decrypt_raw(&self, ciphertext: &[u8], nonce: &[u8]) -> Result<Vec<u8>> {
match &self.cipher {
Cipher::Aes256Gcm(cipher) => {
let nonce = Nonce::from_slice(nonce);
let plaintext = cipher.decrypt(nonce, ciphertext).map_err(|e| {
crate::LimboError::InternalError(format!("Decryption failed: {e:?}"))
})?;
Ok(plaintext)
}
Cipher::Aegis256(cipher) => {
let nonce_array: [u8; 32] = nonce.try_into().map_err(|_| {
LimboError::InternalError(format!(
"Invalid nonce size for AEGIS-256: expected 32, got {}",
nonce.len()
))
})?;
let ad = b"";
cipher.decrypt(ciphertext, &nonce_array, ad)
}
}
self.cipher.as_aead().decrypt(ciphertext, nonce, b"")
}
fn encrypt_raw_detached(&self, plaintext: &[u8]) -> Result<(Vec<u8>, Vec<u8>, Vec<u8>)> {
self.cipher.as_aead().encrypt_detached(plaintext, b"")
}
fn decrypt_raw_detached(&self, ciphertext: &[u8], nonce: &[u8], tag: &[u8]) -> Result<Vec<u8>> {
self.cipher
.as_aead()
.decrypt_detached(ciphertext, nonce, tag, b"")
}
#[cfg(not(feature = "encryption"))]
@@ -391,10 +511,12 @@ fn generate_secure_nonce() -> [u8; 32] {
nonce
}
#[cfg(feature = "encryption")]
#[cfg(test)]
mod tests {
use super::*;
use rand::Rng;
const DEFAULT_ENCRYPTED_PAGE_SIZE: usize = 4096;
fn generate_random_hex_key() -> String {
let mut rng = rand::thread_rng();
@@ -404,15 +526,14 @@ mod tests {
}
#[test]
#[cfg(feature = "encryption")]
fn test_aes_encrypt_decrypt_round_trip() {
let mut rng = rand::thread_rng();
let cipher_mode = CipherMode::Aes256Gcm;
let metadata_size = cipher_mode.metadata_size();
let data_size = ENCRYPTED_PAGE_SIZE - metadata_size;
let data_size = DEFAULT_ENCRYPTED_PAGE_SIZE - metadata_size;
let page_data = {
let mut page = vec![0u8; ENCRYPTED_PAGE_SIZE];
let mut page = vec![0u8; DEFAULT_ENCRYPTED_PAGE_SIZE];
page.iter_mut()
.take(data_size)
.for_each(|byte| *byte = rng.gen());
@@ -420,21 +541,21 @@ mod tests {
};
let key = EncryptionKey::from_hex_string(&generate_random_hex_key()).unwrap();
let ctx = EncryptionContext::new(CipherMode::Aes256Gcm, &key).unwrap();
let ctx = EncryptionContext::new(CipherMode::Aes256Gcm, &key, DEFAULT_ENCRYPTED_PAGE_SIZE)
.unwrap();
let page_id = 42;
let encrypted = ctx.encrypt_page(&page_data, page_id).unwrap();
assert_eq!(encrypted.len(), ENCRYPTED_PAGE_SIZE);
assert_eq!(encrypted.len(), DEFAULT_ENCRYPTED_PAGE_SIZE);
assert_ne!(&encrypted[..data_size], &page_data[..data_size]);
assert_ne!(&encrypted[..], &page_data[..]);
let decrypted = ctx.decrypt_page(&encrypted, page_id).unwrap();
assert_eq!(decrypted.len(), ENCRYPTED_PAGE_SIZE);
assert_eq!(decrypted.len(), DEFAULT_ENCRYPTED_PAGE_SIZE);
assert_eq!(decrypted, page_data);
}
#[test]
#[cfg(feature = "encryption")]
fn test_aegis256_cipher_wrapper() {
let key = EncryptionKey::from_hex_string(&generate_random_hex_key()).unwrap();
let cipher = Aegis256Cipher::new(&key);
@@ -451,10 +572,10 @@ mod tests {
}
#[test]
#[cfg(feature = "encryption")]
fn test_aegis256_raw_encryption() {
let key = EncryptionKey::from_hex_string(&generate_random_hex_key()).unwrap();
let ctx = EncryptionContext::new(CipherMode::Aegis256, &key).unwrap();
let ctx = EncryptionContext::new(CipherMode::Aegis256, &key, DEFAULT_ENCRYPTED_PAGE_SIZE)
.unwrap();
let plaintext = b"Hello, AEGIS-256!";
let (ciphertext, nonce) = ctx.encrypt_raw(plaintext).unwrap();
@@ -467,15 +588,14 @@ mod tests {
}
#[test]
#[cfg(feature = "encryption")]
fn test_aegis256_encrypt_decrypt_round_trip() {
let mut rng = rand::thread_rng();
let cipher_mode = CipherMode::Aegis256;
let metadata_size = cipher_mode.metadata_size();
let data_size = ENCRYPTED_PAGE_SIZE - metadata_size;
let data_size = DEFAULT_ENCRYPTED_PAGE_SIZE - metadata_size;
let page_data = {
let mut page = vec![0u8; ENCRYPTED_PAGE_SIZE];
let mut page = vec![0u8; DEFAULT_ENCRYPTED_PAGE_SIZE];
page.iter_mut()
.take(data_size)
.for_each(|byte| *byte = rng.gen());
@@ -483,15 +603,16 @@ mod tests {
};
let key = EncryptionKey::from_hex_string(&generate_random_hex_key()).unwrap();
let ctx = EncryptionContext::new(CipherMode::Aegis256, &key).unwrap();
let ctx = EncryptionContext::new(CipherMode::Aegis256, &key, DEFAULT_ENCRYPTED_PAGE_SIZE)
.unwrap();
let page_id = 42;
let encrypted = ctx.encrypt_page(&page_data, page_id).unwrap();
assert_eq!(encrypted.len(), ENCRYPTED_PAGE_SIZE);
assert_eq!(encrypted.len(), DEFAULT_ENCRYPTED_PAGE_SIZE);
assert_ne!(&encrypted[..data_size], &page_data[..data_size]);
let decrypted = ctx.decrypt_page(&encrypted, page_id).unwrap();
assert_eq!(decrypted.len(), ENCRYPTED_PAGE_SIZE);
assert_eq!(decrypted.len(), DEFAULT_ENCRYPTED_PAGE_SIZE);
assert_eq!(decrypted, page_data);
}
}
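A minimal sketch of the attached/detached relationship the `AeadCipher` trait encodes, assuming a `cipher: &dyn AeadCipher` obtained elsewhere (the concrete constructors above are module-private):
fn aead_layout_sketch(cipher: &dyn AeadCipher) -> Result<()> {
    let (attached, nonce) = cipher.encrypt(b"page bytes", b"")?;
    let (ct, tag, nonce2) = cipher.encrypt_detached(b"page bytes", b"")?;
    // `encrypt` appends the tag to the ciphertext; `encrypt_detached` keeps them apart
    assert_eq!(attached.len(), ct.len() + tag.len());
    // both forms round-trip through their matching decrypt
    assert_eq!(cipher.decrypt(&attached, &nonce, b"")?, b"page bytes".to_vec());
    assert_eq!(cipher.decrypt_detached(&ct, &nonce2, &tag, b"")?, b"page bytes".to_vec());
    Ok(())
}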

View File

@@ -12,7 +12,7 @@ use super::pager::PageRef;
const DEFAULT_PAGE_CACHE_SIZE_IN_PAGES_MAKE_ME_SMALLER_ONCE_WAL_SPILL_IS_IMPLEMENTED: usize =
100000;
#[derive(Debug, Eq, Hash, PartialEq, Clone)]
#[derive(Debug, Eq, Hash, PartialEq, Clone, Copy)]
pub struct PageCacheKey {
pgno: usize,
}
@@ -47,14 +47,21 @@ struct HashMapNode {
value: NonNull<PageCacheEntry>,
}
#[derive(Debug, PartialEq)]
#[derive(Debug, Clone, PartialEq, thiserror::Error)]
pub enum CacheError {
#[error("{0}")]
InternalError(String),
#[error("page {pgno} is locked")]
Locked { pgno: usize },
#[error("page {pgno} is dirty")]
Dirty { pgno: usize },
#[error("page {pgno} is pinned")]
Pinned { pgno: usize },
#[error("cache active refs")]
ActiveRefs,
#[error("Page cache is full")]
Full,
#[error("key already exists")]
KeyExists,
}
@@ -105,7 +112,7 @@ impl DumbLruPageCache {
trace!("insert(key={:?})", key);
// Check first if page already exists in cache
if !ignore_exists {
if let Some(existing_page_ref) = self.get(&key) {
if let Some(existing_page_ref) = self.get(&key)? {
assert!(
Arc::ptr_eq(&value, &existing_page_ref),
"Attempted to insert different page with same key: {key:?}"
@@ -115,7 +122,7 @@ impl DumbLruPageCache {
}
self.make_room_for(1)?;
let entry = Box::new(PageCacheEntry {
key: key.clone(),
key,
next: None,
prev: None,
page: value,
@@ -156,8 +163,21 @@ impl DumbLruPageCache {
ptr.copied()
}
pub fn get(&mut self, key: &PageCacheKey) -> Option<PageRef> {
self.peek(key, true)
pub fn get(&mut self, key: &PageCacheKey) -> Result<Option<PageRef>, CacheError> {
if let Some(page) = self.peek(key, true) {
// Because a read_page completion can be aborted, a page may sit in the cache while being unloaded and unlocked.
// If we did not evict such a page, we would later hand out an unloaded page and trip assertions downstream.
// This is made worse by the page cache not being per-`Statement`: aborting a completion in one Statement
// could surface an error in the next one if the page were not evicted here.
if !page.is_loaded() && !page.is_locked() {
self.delete(*key)?;
Ok(None)
} else {
Ok(Some(page))
}
} else {
Ok(None)
}
}
/// Get page without promoting entry
@@ -309,7 +329,7 @@ impl DumbLruPageCache {
let entry = unsafe { current.as_ref() };
// Pick prev before modifying entry
current_opt = entry.prev;
match self.delete(entry.key.clone()) {
match self.delete(entry.key) {
Err(_) => {}
Ok(_) => need_to_evict -= 1,
}
@@ -396,7 +416,7 @@ impl DumbLruPageCache {
let mut current = head_ptr;
while let Some(node) = current {
unsafe {
this_keys.push(node.as_ref().key.clone());
this_keys.push(node.as_ref().key);
let node_ref = node.as_ref();
current = node_ref.next;
}
@@ -647,7 +667,7 @@ impl PageHashMap {
pub fn rehash(&self, new_capacity: usize) -> PageHashMap {
let mut new_hash_map = PageHashMap::new(new_capacity);
for node in self.iter() {
new_hash_map.insert(node.key.clone(), node.value);
new_hash_map.insert(node.key, node.value);
}
new_hash_map
}
@@ -698,7 +718,7 @@ mod tests {
fn insert_page(cache: &mut DumbLruPageCache, id: usize) -> PageCacheKey {
let key = create_key(id);
let page = page_with_content(id);
assert!(cache.insert(key.clone(), page).is_ok());
assert!(cache.insert(key, page).is_ok());
key
}
@@ -712,7 +732,7 @@ mod tests {
) -> (PageCacheKey, NonNull<PageCacheEntry>) {
let key = create_key(id);
let page = page_with_content(id);
assert!(cache.insert(key.clone(), page).is_ok());
assert!(cache.insert(key, page).is_ok());
let entry = cache.get_ptr(&key).expect("Entry should exist");
(key, entry)
}
@@ -727,7 +747,7 @@ mod tests {
assert!(cache.tail.borrow().is_some());
assert_eq!(*cache.head.borrow(), *cache.tail.borrow());
assert!(cache.delete(key1.clone()).is_ok());
assert!(cache.delete(key1).is_ok());
assert_eq!(
cache.len(),
@@ -759,7 +779,7 @@ mod tests {
"Initial head check"
);
assert!(cache.delete(key3.clone()).is_ok());
assert!(cache.delete(key3).is_ok());
assert_eq!(cache.len(), 2, "Length should be 2 after deleting head");
assert!(
@@ -803,7 +823,7 @@ mod tests {
"Initial tail check"
);
assert!(cache.delete(key1.clone()).is_ok()); // Delete tail
assert!(cache.delete(key1).is_ok()); // Delete tail
assert_eq!(cache.len(), 2, "Length should be 2 after deleting tail");
assert!(
@@ -854,7 +874,7 @@ mod tests {
let head_ptr_before = cache.head.borrow().unwrap();
let tail_ptr_before = cache.tail.borrow().unwrap();
assert!(cache.delete(key2.clone()).is_ok()); // Detach a middle element (key2)
assert!(cache.delete(key2).is_ok()); // Detach a middle element (key2)
assert_eq!(cache.len(), 3, "Length should be 3 after deleting middle");
assert!(
@@ -895,11 +915,11 @@ mod tests {
let mut cache = DumbLruPageCache::default();
let key1 = create_key(1);
let page1 = page_with_content(1);
assert!(cache.insert(key1.clone(), page1.clone()).is_ok());
assert!(cache.insert(key1, page1.clone()).is_ok());
assert!(page_has_content(&page1));
cache.verify_list_integrity();
let result = cache.delete(key1.clone());
let result = cache.delete(key1);
assert!(result.is_err());
assert_eq!(result.unwrap_err(), CacheError::ActiveRefs);
assert_eq!(cache.len(), 1);
@@ -918,10 +938,10 @@ mod tests {
let key1 = create_key(1);
let page1_v1 = page_with_content(1);
let page1_v2 = page_with_content(1);
assert!(cache.insert(key1.clone(), page1_v1.clone()).is_ok());
assert!(cache.insert(key1, page1_v1.clone()).is_ok());
assert_eq!(cache.len(), 1);
cache.verify_list_integrity();
let _ = cache.insert(key1.clone(), page1_v2.clone()); // Panic
let _ = cache.insert(key1, page1_v2.clone()); // Panic
}
#[test]
@@ -929,7 +949,7 @@ mod tests {
let mut cache = DumbLruPageCache::default();
let key_nonexist = create_key(99);
assert!(cache.delete(key_nonexist.clone()).is_ok()); // no-op
assert!(cache.delete(key_nonexist).is_ok()); // no-op
}
#[test]
@@ -937,8 +957,8 @@ mod tests {
let mut cache = DumbLruPageCache::new(1);
let key1 = insert_page(&mut cache, 1);
let key2 = insert_page(&mut cache, 2);
assert_eq!(cache.get(&key2).unwrap().get().id, 2);
assert!(cache.get(&key1).is_none());
assert_eq!(cache.get(&key2).unwrap().unwrap().get().id, 2);
assert!(cache.get(&key1).unwrap().is_none());
}
#[test]
@@ -1002,7 +1022,7 @@ mod tests {
fn test_detach_with_cleaning() {
let mut cache = DumbLruPageCache::default();
let (key, entry) = insert_and_get_entry(&mut cache, 1);
let page = cache.get(&key).expect("Page should exist");
let page = cache.get(&key).unwrap().expect("Page should exist");
assert!(page_has_content(&page));
drop(page);
assert!(cache.detach(entry, true).is_ok());
@@ -1034,8 +1054,8 @@ mod tests {
let (key1, _) = insert_and_get_entry(&mut cache, 1);
let (key2, entry2) = insert_and_get_entry(&mut cache, 2);
let (key3, _) = insert_and_get_entry(&mut cache, 3);
let head_key = unsafe { cache.head.borrow().unwrap().as_ref().key.clone() };
let tail_key = unsafe { cache.tail.borrow().unwrap().as_ref().key.clone() };
let head_key = unsafe { cache.head.borrow().unwrap().as_ref().key };
let tail_key = unsafe { cache.tail.borrow().unwrap().as_ref().key };
assert_eq!(head_key, key3, "Head should be key3");
assert_eq!(tail_key, key1, "Tail should be key1");
assert!(cache.detach(entry2, false).is_ok());
@@ -1044,12 +1064,12 @@ mod tests {
assert_eq!(head_entry.key, key3, "Head should still be key3");
assert_eq!(tail_entry.key, key1, "Tail should still be key1");
assert_eq!(
unsafe { head_entry.next.unwrap().as_ref().key.clone() },
unsafe { head_entry.next.unwrap().as_ref().key },
key1,
"Head's next should point to tail after middle element detached"
);
assert_eq!(
unsafe { tail_entry.prev.unwrap().as_ref().key.clone() },
unsafe { tail_entry.prev.unwrap().as_ref().key },
key3,
"Tail's prev should point to head after middle element detached"
);
@@ -1085,7 +1105,7 @@ mod tests {
continue; // skip duplicate page ids
}
tracing::debug!("inserting page {:?}", key);
match cache.insert(key.clone(), page.clone()) {
match cache.insert(key, page.clone()) {
Err(CacheError::Full | CacheError::ActiveRefs) => {} // Ignore
Err(err) => {
// Any other error should fail the test
@@ -1106,7 +1126,7 @@ mod tests {
PageCacheKey::new(id_page as usize)
} else {
let i = rng.next_u64() as usize % lru.len();
let key: PageCacheKey = lru.iter().nth(i).unwrap().0.clone();
let key: PageCacheKey = *lru.iter().nth(i).unwrap().0;
key
};
tracing::debug!("removing page {:?}", key);
@@ -1133,7 +1153,7 @@ mod tests {
let this_keys = cache.keys();
let mut lru_keys = Vec::new();
for (lru_key, _) in lru {
lru_keys.push(lru_key.clone());
lru_keys.push(*lru_key);
}
if this_keys != lru_keys {
cache.print();
@@ -1149,8 +1169,8 @@ mod tests {
let mut cache = DumbLruPageCache::default();
let key1 = insert_page(&mut cache, 1);
let key2 = insert_page(&mut cache, 2);
assert_eq!(cache.get(&key1).unwrap().get().id, 1);
assert_eq!(cache.get(&key2).unwrap().get().id, 2);
assert_eq!(cache.get(&key1).unwrap().unwrap().get().id, 1);
assert_eq!(cache.get(&key2).unwrap().unwrap().get().id, 2);
}
#[test]
@@ -1159,17 +1179,17 @@ mod tests {
let key1 = insert_page(&mut cache, 1);
let key2 = insert_page(&mut cache, 2);
let key3 = insert_page(&mut cache, 3);
assert!(cache.get(&key1).is_none());
assert_eq!(cache.get(&key2).unwrap().get().id, 2);
assert_eq!(cache.get(&key3).unwrap().get().id, 3);
assert!(cache.get(&key1).unwrap().is_none());
assert_eq!(cache.get(&key2).unwrap().unwrap().get().id, 2);
assert_eq!(cache.get(&key3).unwrap().unwrap().get().id, 3);
}
#[test]
fn test_page_cache_delete() {
let mut cache = DumbLruPageCache::default();
let key1 = insert_page(&mut cache, 1);
assert!(cache.delete(key1.clone()).is_ok());
assert!(cache.get(&key1).is_none());
assert!(cache.delete(key1).is_ok());
assert!(cache.get(&key1).unwrap().is_none());
}
#[test]
@@ -1178,8 +1198,8 @@ mod tests {
let key1 = insert_page(&mut cache, 1);
let key2 = insert_page(&mut cache, 2);
assert!(cache.clear().is_ok());
assert!(cache.get(&key1).is_none());
assert!(cache.get(&key2).is_none());
assert!(cache.get(&key1).unwrap().is_none());
assert!(cache.get(&key2).unwrap().is_none());
}
#[test]
@@ -1216,8 +1236,8 @@ mod tests {
assert_eq!(result, CacheResizeResult::Done);
assert_eq!(cache.len(), 2);
assert_eq!(cache.capacity, 5);
assert!(cache.get(&create_key(1)).is_some());
assert!(cache.get(&create_key(2)).is_some());
assert!(cache.get(&create_key(1)).unwrap().is_some());
assert!(cache.get(&create_key(2)).unwrap().is_some());
for i in 3..=5 {
let _ = insert_page(&mut cache, i);
}

View File

@@ -1123,7 +1123,7 @@ impl Pager {
tracing::trace!("read_page(page_idx = {})", page_idx);
let mut page_cache = self.page_cache.write();
let page_key = PageCacheKey::new(page_idx);
if let Some(page) = page_cache.get(&page_key) {
if let Some(page) = page_cache.get(&page_key)? {
tracing::trace!("read_page(page_idx = {}) = cached", page_idx);
return Ok((page.clone(), None));
}
@@ -1158,25 +1158,20 @@ impl Pager {
let page_key = PageCacheKey::new(page_idx);
match page_cache.insert(page_key, page.clone()) {
Ok(_) => {}
Err(CacheError::Full) => return Err(LimboError::CacheFull),
Err(CacheError::KeyExists) => {
unreachable!("Page should not exist in cache after get() miss")
}
Err(e) => {
return Err(LimboError::InternalError(format!(
"Failed to insert page into cache: {e:?}"
)))
}
Err(e) => return Err(e.into()),
}
Ok(())
}
// Get a page from the cache, if it exists.
pub fn cache_get(&self, page_idx: usize) -> Option<PageRef> {
pub fn cache_get(&self, page_idx: usize) -> Result<Option<PageRef>> {
tracing::trace!("read_page(page_idx = {})", page_idx);
let mut page_cache = self.page_cache.write();
let page_key = PageCacheKey::new(page_idx);
page_cache.get(&page_key)
Ok(page_cache.get(&page_key)?)
}
/// Get a page from cache only if it matches the target frame
@@ -1185,10 +1180,10 @@ impl Pager {
page_idx: usize,
target_frame: u64,
seq: u32,
) -> Option<PageRef> {
) -> Result<Option<PageRef>> {
let mut page_cache = self.page_cache.write();
let page_key = PageCacheKey::new(page_idx);
page_cache.get(&page_key).and_then(|page| {
let page = page_cache.get(&page_key)?.and_then(|page| {
if page.is_valid_for_checkpoint(target_frame, seq) {
tracing::trace!(
"cache_get_for_checkpoint: page {} frame {} is valid",
@@ -1207,7 +1202,8 @@ impl Pager {
);
None
}
})
});
Ok(page)
}
/// Changes the size of the page cache.
@@ -1261,7 +1257,7 @@ impl Pager {
let page = {
let mut cache = self.page_cache.write();
let page_key = PageCacheKey::new(*page_id);
let page = cache.get(&page_key).expect("we somehow added a page to dirty list but we didn't mark it as dirty, causing cache to drop it.");
let page = cache.get(&page_key)?.expect("we somehow added a page to dirty list but we didn't mark it as dirty, causing cache to drop it.");
let page_type = page.get_contents().maybe_page_type();
trace!("cacheflush(page={}, page_type={:?}", page_id, page_type);
page
@@ -1344,7 +1340,7 @@ impl Pager {
let page = {
let mut cache = self.page_cache.write();
let page_key = PageCacheKey::new(page_id);
let page = cache.get(&page_key).expect(
let page = cache.get(&page_key)?.expect(
"dirty list contained a page that cache dropped (page={page_id})",
);
trace!(
@@ -1482,7 +1478,7 @@ impl Pager {
header.db_size as u64,
raw_page,
)?;
if let Some(page) = self.cache_get(header.page_number as usize) {
if let Some(page) = self.cache_get(header.page_number as usize)? {
let content = page.get_contents();
content.as_ptr().copy_from_slice(raw_page);
turso_assert!(
@@ -1505,7 +1501,7 @@ impl Pager {
for page_id in self.dirty_pages.borrow().iter() {
let page_key = PageCacheKey::new(*page_id);
let mut cache = self.page_cache.write();
let page = cache.get(&page_key).expect("we somehow added a page to dirty list but we didn't mark it as dirty, causing cache to drop it.");
let page = cache.get(&page_key)?.expect("we somehow added a page to dirty list but we didn't mark it as dirty, causing cache to drop it.");
page.clear_dirty();
}
self.dirty_pages.borrow_mut().clear();
@@ -1902,15 +1898,7 @@ impl Pager {
self.add_dirty(&page);
let page_key = PageCacheKey::new(page.get().id);
let mut cache = self.page_cache.write();
match cache.insert(page_key, page.clone()) {
Ok(_) => (),
Err(CacheError::Full) => return Err(LimboError::CacheFull),
Err(_) => {
return Err(LimboError::InternalError(
"Unknown error inserting page to cache".into(),
))
}
}
cache.insert(page_key, page.clone())?;
}
}
@@ -2081,15 +2069,7 @@ impl Pager {
{
// Run in separate block to avoid deadlock on page cache write lock
let mut cache = self.page_cache.write();
match cache.insert(page_key, page.clone()) {
Err(CacheError::Full) => return Err(LimboError::CacheFull),
Err(_) => {
return Err(LimboError::InternalError(
"Unknown error inserting page to cache".into(),
))
}
Ok(_) => {}
};
cache.insert(page_key, page.clone())?;
}
header.database_size = new_db_size.into();
*state = AllocatePageState::Start;
@@ -2186,7 +2166,8 @@ impl Pager {
}
pub fn set_encryption_context(&self, cipher_mode: CipherMode, key: &EncryptionKey) {
let encryption_ctx = EncryptionContext::new(cipher_mode, key).unwrap();
let page_size = self.page_size.get().unwrap().get() as usize;
let encryption_ctx = EncryptionContext::new(cipher_mode, key, page_size).unwrap();
{
let mut io_ctx = self.io_ctx.borrow_mut();
io_ctx.set_encryption(encryption_ctx);
@@ -2428,13 +2409,16 @@ mod tests {
std::thread::spawn(move || {
let mut cache = cache.write();
let page_key = PageCacheKey::new(1);
cache.insert(page_key, Arc::new(Page::new(1))).unwrap();
let page = Page::new(1);
// Mark the page as loaded so it is not evicted: the cache drops pages that are neither locked nor loaded
page.set_loaded();
cache.insert(page_key, Arc::new(page)).unwrap();
})
};
let _ = thread.join();
let mut cache = cache.write();
let page_key = PageCacheKey::new(1);
let page = cache.get(&page_key);
let page = cache.get(&page_key).unwrap();
assert_eq!(page.unwrap().get().id, 1);
}
}

View File

@@ -1838,7 +1838,7 @@ pub fn read_entire_wal_dumb(file: &Arc<dyn File>) -> Result<Arc<UnsafeCell<WalFi
pub fn begin_read_wal_frame_raw(
buffer_pool: &Arc<BufferPool>,
io: &Arc<dyn File>,
offset: usize,
offset: u64,
complete: Box<ReadComplete>,
) -> Result<Completion> {
tracing::trace!("begin_read_wal_frame_raw(offset={})", offset);
@@ -1851,7 +1851,7 @@ pub fn begin_read_wal_frame_raw(
pub fn begin_read_wal_frame(
io: &Arc<dyn File>,
offset: usize,
offset: u64,
buffer_pool: Arc<BufferPool>,
complete: Box<ReadComplete>,
page_idx: usize,

View File

@@ -1082,7 +1082,7 @@ impl Wal for WalFile {
});
begin_read_wal_frame(
&self.get_shared().file,
offset + WAL_FRAME_HEADER_SIZE,
offset + WAL_FRAME_HEADER_SIZE as u64,
buffer_pool,
complete,
page_idx,
@@ -1095,6 +1095,11 @@ impl Wal for WalFile {
tracing::debug!("read_frame({})", frame_id);
let offset = self.frame_offset(frame_id);
let (frame_ptr, frame_len) = (frame.as_mut_ptr(), frame.len());
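// Clone the encryption context before building the completion callback, so the callback owns everything it needs.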
let encryption_ctx = {
let io_ctx = self.io_ctx.borrow();
io_ctx.encryption_context().cloned()
};
let complete = Box::new(move |res: Result<(Arc<Buffer>, i32), CompletionError>| {
let Ok((buf, bytes_read)) = res else {
return;
@@ -1104,10 +1109,34 @@ impl Wal for WalFile {
bytes_read == buf_len as i32,
"read({bytes_read}) != expected({buf_len})"
);
let buf_ptr = buf.as_mut_ptr();
let buf_ptr = buf.as_ptr();
let frame_ref: &mut [u8] =
unsafe { std::slice::from_raw_parts_mut(frame_ptr, frame_len) };
// Copy the just-read WAL frame into the destination buffer
unsafe {
std::ptr::copy_nonoverlapping(buf_ptr, frame_ptr, frame_len);
}
// Now parse the header from the freshly-copied data
let (header, raw_page) = sqlite3_ondisk::parse_wal_frame_header(frame_ref);
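// If the WAL is encrypted, the page payload on disk is ciphertext; decrypt it and overwrite the payload portion of the frame with the plaintext page.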
if let Some(ctx) = encryption_ctx.clone() {
match ctx.decrypt_page(raw_page, header.page_number as usize) {
Ok(decrypted_data) => {
turso_assert!(
(frame_len - WAL_FRAME_HEADER_SIZE) == decrypted_data.len(),
"frame_len - header_size({}) != expected({})",
frame_len - WAL_FRAME_HEADER_SIZE,
decrypted_data.len()
);
frame_ref[WAL_FRAME_HEADER_SIZE..].copy_from_slice(&decrypted_data);
}
Err(_) => {
tracing::error!("Failed to decrypt page data for frame_id={frame_id}");
}
}
}
});
let c =
begin_read_wal_frame_raw(&self.buffer_pool, &self.get_shared().file, offset, complete)?;
@@ -1167,7 +1196,7 @@ impl Wal for WalFile {
});
let c = begin_read_wal_frame(
&self.get_shared().file,
offset + WAL_FRAME_HEADER_SIZE,
offset + WAL_FRAME_HEADER_SIZE as u64,
buffer_pool,
complete,
page_id as usize,
@@ -1431,14 +1460,14 @@ impl Wal for WalFile {
let mut next_frame_id = self.max_frame + 1;
// Build every frame in order, updating the rolling checksum
for (idx, page) in pages.iter().enumerate() {
let page_id = page.get().id as u64;
let page_id = page.get().id;
let plain = page.get_contents().as_ptr();
let data_to_write: std::borrow::Cow<[u8]> = {
let io_ctx = self.io_ctx.borrow();
let ectx = io_ctx.encryption_context();
if let Some(ctx) = ectx.as_ref() {
Cow::Owned(ctx.encrypt_page(plain, page_id as usize)?)
Cow::Owned(ctx.encrypt_page(plain, page_id)?)
} else {
Cow::Borrowed(plain)
}
@@ -1552,11 +1581,10 @@ impl WalFile {
self.get_shared().wal_header.lock().page_size
}
fn frame_offset(&self, frame_id: u64) -> usize {
fn frame_offset(&self, frame_id: u64) -> u64 {
assert!(frame_id > 0, "Frame ID must be 1-based");
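// Layout: the WAL file starts with a fixed-size header, followed by frames of (frame header + one page) each.
// E.g. with a 4096-byte page, frame 2 starts at WAL_HEADER_SIZE + 1 * (WAL_FRAME_HEADER_SIZE + 4096).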
let page_offset = (frame_id - 1) * (self.page_size() + WAL_FRAME_HEADER_SIZE as u32) as u64;
let offset = WAL_HEADER_SIZE as u64 + page_offset;
offset as usize
WAL_HEADER_SIZE as u64 + page_offset
}
#[allow(clippy::mut_from_ref)]
@@ -1748,7 +1776,7 @@ impl WalFile {
// Try cache first, if enabled
if let Some(cached_page) =
pager.cache_get_for_checkpoint(page_id as usize, target_frame, seq)
pager.cache_get_for_checkpoint(page_id as usize, target_frame, seq)?
{
let contents = cached_page.get_contents();
let buffer = contents.buffer.clone();
@@ -1805,7 +1833,7 @@ impl WalFile {
self.ongoing_checkpoint.pages_to_checkpoint.iter()
{
if *cached {
let page = pager.cache_get((*page_id) as usize);
let page = pager.cache_get((*page_id) as usize)?;
turso_assert!(
page.is_some(),
"page should still exist in the page cache"
@@ -2102,7 +2130,7 @@ impl WalFile {
// schedule read of the page payload
let c = begin_read_wal_frame(
&self.get_shared().file,
offset + WAL_FRAME_HEADER_SIZE,
offset + WAL_FRAME_HEADER_SIZE as u64,
self.buffer_pool.clone(),
complete,
page_id,
@@ -2288,7 +2316,7 @@ pub mod test {
let done = Rc::new(Cell::new(false));
let _done = done.clone();
let _ = file.file.truncate(
WAL_HEADER_SIZE,
WAL_HEADER_SIZE as u64,
Completion::new_trunc(move |_| {
let done = _done.clone();
done.set(true);

View File

@@ -125,27 +125,161 @@ pub fn handle_distinct(program: &mut ProgramBuilder, agg: &Aggregate, agg_arg_re
});
}
/// Emits the bytecode for processing an aggregate step.
/// E.g. in `SELECT SUM(price) FROM t`, 'price' is evaluated for every row, and the result is added to the accumulator.
/// Enum representing the source of the aggregate function arguments
///
/// This is distinct from the final step, which is called after the main loop has finished processing
/// Aggregate arguments can come from different sources, depending on how the aggregation
/// is evaluated:
/// * In the common grouped case, the aggregate function arguments are first inserted
/// into a sorter in the main loop, and in the group by aggregation phase we read
/// the data from the sorter.
/// * In grouped cases where no sorting is required, arguments are retrieved directly
/// from registers allocated in the main loop.
/// * In ungrouped cases, arguments are computed directly from the `args` expressions.
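///
/// For example, `SELECT sum(price) FROM t` evaluates its argument via the `Expression`
/// source, while `SELECT category, sum(price) FROM t GROUP BY category` usually routes
/// the argument through the GROUP BY sorter and reads it back via `PseudoCursor`.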
pub enum AggArgumentSource<'a> {
/// The aggregate function arguments are retrieved from a pseudo cursor
/// which reads from the GROUP BY sorter.
PseudoCursor {
cursor_id: usize,
col_start: usize,
dest_reg_start: usize,
aggregate: &'a Aggregate,
},
/// The aggregate function arguments are retrieved from a contiguous block of registers
/// allocated in the main loop for that given aggregate function.
Register {
src_reg_start: usize,
aggregate: &'a Aggregate,
},
/// The aggregate function arguments are retrieved by evaluating expressions.
Expression { aggregate: &'a Aggregate },
}
impl<'a> AggArgumentSource<'a> {
/// Create a new [AggArgumentSource] that retrieves the values from a GROUP BY sorter.
pub fn new_from_cursor(
program: &mut ProgramBuilder,
cursor_id: usize,
col_start: usize,
aggregate: &'a Aggregate,
) -> Self {
let dest_reg_start = program.alloc_registers(aggregate.args.len());
Self::PseudoCursor {
cursor_id,
col_start,
dest_reg_start,
aggregate,
}
}
/// Create a new [AggArgumentSource] that retrieves the values directly from an already
/// populated register or registers.
pub fn new_from_registers(src_reg_start: usize, aggregate: &'a Aggregate) -> Self {
Self::Register {
src_reg_start,
aggregate,
}
}
/// Create a new [AggArgumentSource] that retrieves the values by evaluating `args` expressions.
pub fn new_from_expression(aggregate: &'a Aggregate) -> Self {
Self::Expression { aggregate }
}
pub fn aggregate(&self) -> &Aggregate {
match self {
AggArgumentSource::PseudoCursor { aggregate, .. } => aggregate,
AggArgumentSource::Register { aggregate, .. } => aggregate,
AggArgumentSource::Expression { aggregate } => aggregate,
}
}
pub fn agg_func(&self) -> &AggFunc {
match self {
AggArgumentSource::PseudoCursor { aggregate, .. } => &aggregate.func,
AggArgumentSource::Register { aggregate, .. } => &aggregate.func,
AggArgumentSource::Expression { aggregate } => &aggregate.func,
}
}
pub fn args(&self) -> &[ast::Expr] {
match self {
AggArgumentSource::PseudoCursor { aggregate, .. } => &aggregate.args,
AggArgumentSource::Register { aggregate, .. } => &aggregate.args,
AggArgumentSource::Expression { aggregate } => &aggregate.args,
}
}
pub fn num_args(&self) -> usize {
match self {
AggArgumentSource::PseudoCursor { aggregate, .. } => aggregate.args.len(),
AggArgumentSource::Register { aggregate, .. } => aggregate.args.len(),
AggArgumentSource::Expression { aggregate } => aggregate.args.len(),
}
}
/// Make the value of an aggregate function argument available in a register and return that register.
pub fn translate(
&self,
program: &mut ProgramBuilder,
referenced_tables: &TableReferences,
resolver: &Resolver,
arg_idx: usize,
) -> Result<usize> {
match self {
AggArgumentSource::PseudoCursor {
cursor_id,
col_start,
dest_reg_start,
..
} => {
program.emit_column_or_rowid(
*cursor_id,
*col_start + arg_idx,
dest_reg_start + arg_idx,
);
Ok(dest_reg_start + arg_idx)
}
AggArgumentSource::Register {
src_reg_start: start_reg,
..
} => Ok(*start_reg + arg_idx),
AggArgumentSource::Expression { aggregate } => {
let dest_reg = program.alloc_register();
translate_expr(
program,
Some(referenced_tables),
&aggregate.args[arg_idx],
dest_reg,
resolver,
)
}
}
}
}
/// Emits the bytecode for processing an aggregate step.
///
/// This is distinct from the final step, which is called after a single group has been entirely accumulated,
/// and the actual result value of the aggregation is materialized.
///
/// Ungrouped aggregation is a special case of grouped aggregation that involves a single group.
///
/// Examples:
/// * In `SELECT SUM(price) FROM t`, `price` is evaluated for each row and added to the accumulator.
/// * In `SELECT product_category, SUM(price) FROM t GROUP BY product_category`, `price` is evaluated for
/// each row in the group and added to that group's accumulator.
pub fn translate_aggregation_step(
program: &mut ProgramBuilder,
referenced_tables: &TableReferences,
agg: &Aggregate,
agg_arg_source: AggArgumentSource,
target_register: usize,
resolver: &Resolver,
) -> Result<usize> {
let dest = match agg.func {
let num_args = agg_arg_source.num_args();
let func = agg_arg_source.agg_func();
let dest = match func {
AggFunc::Avg => {
if agg.args.len() != 1 {
if num_args != 1 {
crate::bail_parse_error!("avg bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
@@ -155,20 +289,16 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::Count | AggFunc::Count0 => {
let expr_reg = if agg.args.is_empty() {
program.alloc_register()
} else {
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
expr_reg
};
handle_distinct(program, agg, expr_reg);
if num_args != 1 {
crate::bail_parse_error!("count bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: if matches!(agg.func, AggFunc::Count0) {
func: if matches!(func, AggFunc::Count0) {
AggFunc::Count0
} else {
AggFunc::Count
@@ -177,18 +307,16 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::GroupConcat => {
if agg.args.len() != 1 && agg.args.len() != 2 {
if num_args != 1 && num_args != 2 {
crate::bail_parse_error!("group_concat bad number of arguments");
}
let expr_reg = program.alloc_register();
let delimiter_reg = program.alloc_register();
let expr = &agg.args[0];
let delimiter_expr: ast::Expr;
if agg.args.len() == 2 {
match &agg.args[1] {
if num_args == 2 {
match &agg_arg_source.args()[1] {
arg @ ast::Expr::Column { .. } => {
delimiter_expr = arg.clone();
}
@@ -201,8 +329,8 @@ pub fn translate_aggregation_step(
delimiter_expr = ast::Expr::Literal(ast::Literal::String(String::from("\",\"")));
}
translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
translate_expr(
program,
Some(referenced_tables),
@@ -221,13 +349,12 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::Max => {
if agg.args.len() != 1 {
if num_args != 1 {
crate::bail_parse_error!("max bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
let expr = &agg_arg_source.args()[0];
emit_collseq_if_needed(program, referenced_tables, expr);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
@@ -238,13 +365,12 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::Min => {
if agg.args.len() != 1 {
if num_args != 1 {
crate::bail_parse_error!("min bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
let expr = &agg_arg_source.args()[0];
emit_collseq_if_needed(program, referenced_tables, expr);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
@@ -256,23 +382,12 @@ pub fn translate_aggregation_step(
}
#[cfg(feature = "json")]
AggFunc::JsonGroupObject | AggFunc::JsonbGroupObject => {
if agg.args.len() != 2 {
if num_args != 2 {
crate::bail_parse_error!("max bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let value_expr = &agg.args[1];
let value_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let _ = translate_expr(
program,
Some(referenced_tables),
value_expr,
value_reg,
resolver,
)?;
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
let value_reg = agg_arg_source.translate(program, referenced_tables, resolver, 1)?;
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
@@ -284,13 +399,11 @@ pub fn translate_aggregation_step(
}
#[cfg(feature = "json")]
AggFunc::JsonGroupArray | AggFunc::JsonbGroupArray => {
if agg.args.len() != 1 {
if num_args != 1 {
crate::bail_parse_error!("max bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
@@ -300,15 +413,13 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::StringAgg => {
if agg.args.len() != 2 {
if num_args != 2 {
crate::bail_parse_error!("string_agg bad number of arguments");
}
let expr_reg = program.alloc_register();
let delimiter_reg = program.alloc_register();
let expr = &agg.args[0];
let delimiter_expr = match &agg.args[1] {
let delimiter_expr = match &agg_arg_source.args()[1] {
arg @ ast::Expr::Column { .. } => arg.clone(),
ast::Expr::Literal(ast::Literal::String(s)) => {
ast::Expr::Literal(ast::Literal::String(s.to_string()))
@@ -316,7 +427,7 @@ pub fn translate_aggregation_step(
_ => crate::bail_parse_error!("Incorrect delimiter parameter"),
};
translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
translate_expr(
program,
Some(referenced_tables),
@@ -335,13 +446,11 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::Sum => {
if agg.args.len() != 1 {
if num_args != 1 {
crate::bail_parse_error!("sum bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
@@ -351,13 +460,11 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::Total => {
if agg.args.len() != 1 {
if num_args != 1 {
crate::bail_parse_error!("total bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
@@ -367,31 +474,24 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::External(ref func) => {
let expr_reg = program.alloc_register();
let argc = func.agg_args().map_err(|_| {
LimboError::ExtensionError(
"External aggregate function called with wrong number of arguments".to_string(),
)
})?;
if argc != agg.args.len() {
if argc != num_args {
crate::bail_parse_error!(
"External aggregate function called with wrong number of arguments"
);
}
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
for i in 0..argc {
if i != 0 {
let _ = program.alloc_register();
let _ = agg_arg_source.translate(program, referenced_tables, resolver, i)?;
}
let _ = translate_expr(
program,
Some(referenced_tables),
&agg.args[i],
expr_reg + i,
resolver,
)?;
// invariant: distinct aggregates are only supported for single-argument functions
if argc == 1 {
handle_distinct(program, agg, expr_reg + i);
handle_distinct(program, agg_arg_source.aggregate(), expr_reg + i);
}
}
program.emit_insn(Insn::AggStep {

View File

@@ -1,5 +1,8 @@
use std::sync::Arc;
use turso_parser::{ast, parser::Parser};
use turso_parser::{
ast::{self, fmt::ToTokens as _},
parser::Parser,
};
use crate::{
function::{AlterTableFunc, Func},
@@ -166,7 +169,7 @@ pub fn translate_alter_table(
)?
}
ast::AlterTableBody::AddColumn(col_def) => {
let column = Column::from(col_def);
let column = Column::from(&col_def);
if let Some(default) = &column.default {
if !matches!(
@@ -233,97 +236,6 @@ pub fn translate_alter_table(
},
)?
}
ast::AlterTableBody::RenameColumn { old, new } => {
let rename_from = old.as_str();
let rename_to = new.as_str();
let Some((column_index, _)) = btree.get_column(rename_from) else {
return Err(LimboError::ParseError(format!(
"no such column: \"{rename_from}\""
)));
};
if btree.get_column(rename_to).is_some() {
return Err(LimboError::ParseError(format!(
"duplicate column name: \"{rename_from}\""
)));
};
let sqlite_schema = schema
.get_btree_table(SQLITE_TABLEID)
.expect("sqlite_schema should be on schema");
let cursor_id = program.alloc_cursor_id(CursorType::BTreeTable(sqlite_schema.clone()));
program.emit_insn(Insn::OpenWrite {
cursor_id,
root_page: RegisterOrLiteral::Literal(sqlite_schema.root_page),
db: 0,
});
program.cursor_loop(cursor_id, |program, rowid| {
let sqlite_schema_column_len = sqlite_schema.columns.len();
assert_eq!(sqlite_schema_column_len, 5);
let first_column = program.alloc_registers(sqlite_schema_column_len);
for i in 0..sqlite_schema_column_len {
program.emit_column_or_rowid(cursor_id, i, first_column + i);
}
program.emit_string8_new_reg(table_name.to_string());
program.mark_last_insn_constant();
program.emit_string8_new_reg(rename_from.to_string());
program.mark_last_insn_constant();
program.emit_string8_new_reg(rename_to.to_string());
program.mark_last_insn_constant();
let out = program.alloc_registers(sqlite_schema_column_len);
program.emit_insn(Insn::Function {
constant_mask: 0,
start_reg: first_column,
dest: out,
func: crate::function::FuncCtx {
func: Func::AlterTable(AlterTableFunc::RenameColumn),
arg_count: 8,
},
});
let record = program.alloc_register();
program.emit_insn(Insn::MakeRecord {
start_reg: out,
count: sqlite_schema_column_len,
dest_reg: record,
index_name: None,
});
program.emit_insn(Insn::Insert {
cursor: cursor_id,
key_reg: rowid,
record_reg: record,
flag: crate::vdbe::insn::InsertFlags(0),
table_name: table_name.to_string(),
});
});
program.emit_insn(Insn::SetCookie {
db: 0,
cookie: Cookie::SchemaVersion,
value: schema.schema_version as i32 + 1,
p5: 0,
});
program.emit_insn(Insn::RenameColumn {
table: table_name.to_owned(),
column_index,
name: rename_to.to_owned(),
});
program
}
ast::AlterTableBody::RenameTo(new_name) => {
let new_name = new_name.as_str();
@@ -409,6 +321,148 @@ pub fn translate_alter_table(
to: new_name.to_owned(),
});
program
}
body @ (ast::AlterTableBody::AlterColumn { .. }
| ast::AlterTableBody::RenameColumn { .. }) => {
let from;
let definition;
let col_name;
let rename;
match body {
ast::AlterTableBody::AlterColumn { old, new } => {
from = old;
definition = new;
col_name = definition.col_name.clone();
rename = false;
}
ast::AlterTableBody::RenameColumn { old, new } => {
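// Model RENAME COLUMN as an alter that carries only the new name; the `rename` flag below makes later stages keep the column's existing type and constraints.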
from = old;
definition = ast::ColumnDefinition {
col_name: new.clone(),
col_type: None,
constraints: vec![],
};
col_name = new;
rename = true;
}
_ => unreachable!(),
}
let from = from.as_str();
let col_name = col_name.as_str();
let Some((column_index, _)) = btree.get_column(from) else {
return Err(LimboError::ParseError(format!(
"no such column: \"{from}\""
)));
};
if btree.get_column(col_name).is_some() {
return Err(LimboError::ParseError(format!(
"duplicate column name: \"{col_name}\""
)));
};
if definition
.constraints
.iter()
.any(|c| matches!(c.constraint, ast::ColumnConstraint::PrimaryKey { .. }))
{
return Err(LimboError::ParseError(
"PRIMARY KEY constraint cannot be altered".to_string(),
));
}
if definition
.constraints
.iter()
.any(|c| matches!(c.constraint, ast::ColumnConstraint::Unique { .. }))
{
return Err(LimboError::ParseError(
"UNIQUE constraint cannot be altered".to_string(),
));
}
let sqlite_schema = schema
.get_btree_table(SQLITE_TABLEID)
.expect("sqlite_schema should be on schema");
let cursor_id = program.alloc_cursor_id(CursorType::BTreeTable(sqlite_schema.clone()));
program.emit_insn(Insn::OpenWrite {
cursor_id,
root_page: RegisterOrLiteral::Literal(sqlite_schema.root_page),
db: 0,
});
program.cursor_loop(cursor_id, |program, rowid| {
let sqlite_schema_column_len = sqlite_schema.columns.len();
assert_eq!(sqlite_schema_column_len, 5);
let first_column = program.alloc_registers(sqlite_schema_column_len);
for i in 0..sqlite_schema_column_len {
program.emit_column_or_rowid(cursor_id, i, first_column + i);
}
program.emit_string8_new_reg(table_name.to_string());
program.mark_last_insn_constant();
program.emit_string8_new_reg(from.to_string());
program.mark_last_insn_constant();
program.emit_string8_new_reg(definition.format().unwrap());
program.mark_last_insn_constant();
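// The rewrite function receives 8 arguments: the five sqlite_schema columns plus the table name, the old column name, and the new column definition rendered as SQL.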
let out = program.alloc_registers(sqlite_schema_column_len);
program.emit_insn(Insn::Function {
constant_mask: 0,
start_reg: first_column,
dest: out,
func: crate::function::FuncCtx {
func: Func::AlterTable(if rename {
AlterTableFunc::RenameColumn
} else {
AlterTableFunc::AlterColumn
}),
arg_count: 8,
},
});
let record = program.alloc_register();
program.emit_insn(Insn::MakeRecord {
start_reg: out,
count: sqlite_schema_column_len,
dest_reg: record,
index_name: None,
});
program.emit_insn(Insn::Insert {
cursor: cursor_id,
key_reg: rowid,
record_reg: record,
flag: crate::vdbe::insn::InsertFlags(0),
table_name: table_name.to_string(),
});
});
program.emit_insn(Insn::SetCookie {
db: 0,
cookie: Cookie::SchemaVersion,
value: schema.schema_version as i32 + 1,
p5: 0,
});
program.emit_insn(Insn::AlterColumn {
table: table_name.to_owned(),
column_index,
definition,
rename,
});
program
}
})

View File

@@ -1,9 +1,16 @@
use turso_parser::ast;
use super::{
emitter::TranslateCtx,
expr::{translate_condition_expr, translate_expr, ConditionMetadata},
order_by::order_by_sorter_insert,
plan::{Distinctness, GroupBy, SelectPlan},
result_row::emit_select_result,
};
use crate::translate::aggregation::{translate_aggregation_step, AggArgumentSource};
use crate::translate::expr::{walk_expr, WalkControl};
use crate::translate::plan::ResultSetColumn;
use crate::{
function::AggFunc,
schema::PseudoCursorType,
translate::collate::CollationSeq,
util::exprs_are_equivalent,
@@ -15,15 +22,6 @@ use crate::{
Result,
};
use super::{
aggregation::handle_distinct,
emitter::{Resolver, TranslateCtx},
expr::{translate_condition_expr, translate_expr, ConditionMetadata},
order_by::order_by_sorter_insert,
plan::{Aggregate, Distinctness, GroupBy, SelectPlan, TableReferences},
result_row::emit_select_result,
};
/// Labels needed for various jumps in GROUP BY handling.
#[derive(Debug)]
pub struct GroupByLabels {
@@ -394,102 +392,6 @@ pub enum GroupByRowSource {
},
}
/// Enum representing the source of the aggregate function arguments
/// emitted for a group by aggregation.
/// In the common case, the aggregate function arguments are first inserted
/// into a sorter in the main loop, and in the group by aggregation phase
/// we read the data from the sorter.
///
/// In the alternative case, no sorting is required for group by,
/// and the aggregate function arguments are retrieved directly from
/// registers allocated in the main loop.
pub enum GroupByAggArgumentSource<'a> {
/// The aggregate function arguments are retrieved from a pseudo cursor
/// which reads from the GROUP BY sorter.
PseudoCursor {
cursor_id: usize,
col_start: usize,
dest_reg_start: usize,
aggregate: &'a Aggregate,
},
/// The aggregate function arguments are retrieved from a contiguous block of registers
/// allocated in the main loop for that given aggregate function.
Register {
src_reg_start: usize,
aggregate: &'a Aggregate,
},
}
impl<'a> GroupByAggArgumentSource<'a> {
/// Create a new [GroupByAggArgumentSource] that retrieves the values from a GROUP BY sorter.
pub fn new_from_cursor(
program: &mut ProgramBuilder,
cursor_id: usize,
col_start: usize,
aggregate: &'a Aggregate,
) -> Self {
let dest_reg_start = program.alloc_registers(aggregate.args.len());
Self::PseudoCursor {
cursor_id,
col_start,
dest_reg_start,
aggregate,
}
}
/// Create a new [GroupByAggArgumentSource] that retrieves the values directly from an already
/// populated register or registers.
pub fn new_from_registers(src_reg_start: usize, aggregate: &'a Aggregate) -> Self {
Self::Register {
src_reg_start,
aggregate,
}
}
pub fn aggregate(&self) -> &Aggregate {
match self {
GroupByAggArgumentSource::PseudoCursor { aggregate, .. } => aggregate,
GroupByAggArgumentSource::Register { aggregate, .. } => aggregate,
}
}
pub fn agg_func(&self) -> &AggFunc {
match self {
GroupByAggArgumentSource::PseudoCursor { aggregate, .. } => &aggregate.func,
GroupByAggArgumentSource::Register { aggregate, .. } => &aggregate.func,
}
}
pub fn args(&self) -> &[ast::Expr] {
match self {
GroupByAggArgumentSource::PseudoCursor { aggregate, .. } => &aggregate.args,
GroupByAggArgumentSource::Register { aggregate, .. } => &aggregate.args,
}
}
pub fn num_args(&self) -> usize {
match self {
GroupByAggArgumentSource::PseudoCursor { aggregate, .. } => aggregate.args.len(),
GroupByAggArgumentSource::Register { aggregate, .. } => aggregate.args.len(),
}
}
/// Read the value of an aggregate function argument either from sorter data or directly from a register.
pub fn translate(&self, program: &mut ProgramBuilder, arg_idx: usize) -> Result<usize> {
match self {
GroupByAggArgumentSource::PseudoCursor {
cursor_id,
col_start,
dest_reg_start,
..
} => {
program.emit_column_or_rowid(*cursor_id, *col_start, dest_reg_start + arg_idx);
Ok(dest_reg_start + arg_idx)
}
GroupByAggArgumentSource::Register {
src_reg_start: start_reg,
..
} => Ok(*start_reg + arg_idx),
}
}
}
/// Emits bytecode for processing a single GROUP BY group.
pub fn group_by_process_single_group(
program: &mut ProgramBuilder,
@@ -593,21 +495,19 @@ pub fn group_by_process_single_group(
.expect("aggregate registers must be initialized");
let agg_result_reg = start_reg + i;
let agg_arg_source = match &row_source {
GroupByRowSource::Sorter { pseudo_cursor, .. } => {
GroupByAggArgumentSource::new_from_cursor(
program,
*pseudo_cursor,
cursor_index + offset,
agg,
)
}
GroupByRowSource::Sorter { pseudo_cursor, .. } => AggArgumentSource::new_from_cursor(
program,
*pseudo_cursor,
cursor_index + offset,
agg,
),
GroupByRowSource::MainLoop { start_reg_src, .. } => {
// Aggregation arguments are always placed in the registers that follow any scalars.
let start_reg_aggs = start_reg_src + t_ctx.non_aggregate_expressions.len();
GroupByAggArgumentSource::new_from_registers(start_reg_aggs + offset, agg)
AggArgumentSource::new_from_registers(start_reg_aggs + offset, agg)
}
};
translate_aggregation_step_groupby(
translate_aggregation_step(
program,
&plan.table_references,
agg_arg_source,
@@ -897,220 +797,3 @@ pub fn group_by_emit_row_phase<'a>(
program.preassign_label_to_next_insn(labels.label_group_by_end);
Ok(())
}
/// Emits the bytecode for processing an aggregate step within a GROUP BY clause.
/// Eg. in `SELECT product_category, SUM(price) FROM t GROUP BY line_item`, 'price' is evaluated for every row
/// where the 'product_category' is the same, and the result is added to the accumulator for that category.
///
/// This is distinct from the final step, which is called after a single group has been entirely accumulated,
/// and the actual result value of the aggregation is materialized.
pub fn translate_aggregation_step_groupby(
program: &mut ProgramBuilder,
referenced_tables: &TableReferences,
agg_arg_source: GroupByAggArgumentSource,
target_register: usize,
resolver: &Resolver,
) -> Result<usize> {
let num_args = agg_arg_source.num_args();
let dest = match agg_arg_source.agg_func() {
AggFunc::Avg => {
if num_args != 1 {
crate::bail_parse_error!("avg bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::Avg,
});
target_register
}
AggFunc::Count | AggFunc::Count0 => {
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: if matches!(agg_arg_source.agg_func(), AggFunc::Count0) {
AggFunc::Count0
} else {
AggFunc::Count
},
});
target_register
}
AggFunc::GroupConcat => {
let num_args = agg_arg_source.num_args();
if num_args != 1 && num_args != 2 {
crate::bail_parse_error!("group_concat bad number of arguments");
}
let delimiter_reg = program.alloc_register();
let delimiter_expr: ast::Expr;
if num_args == 2 {
match &agg_arg_source.args()[1] {
arg @ ast::Expr::Column { .. } => {
delimiter_expr = arg.clone();
}
ast::Expr::Literal(ast::Literal::String(s)) => {
delimiter_expr = ast::Expr::Literal(ast::Literal::String(s.to_string()));
}
_ => crate::bail_parse_error!("Incorrect delimiter parameter"),
};
} else {
delimiter_expr = ast::Expr::Literal(ast::Literal::String(String::from("\",\"")));
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
translate_expr(
program,
Some(referenced_tables),
&delimiter_expr,
delimiter_reg,
resolver,
)?;
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: delimiter_reg,
func: AggFunc::GroupConcat,
});
target_register
}
AggFunc::Max => {
if num_args != 1 {
crate::bail_parse_error!("max bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::Max,
});
target_register
}
AggFunc::Min => {
if num_args != 1 {
crate::bail_parse_error!("min bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::Min,
});
target_register
}
#[cfg(feature = "json")]
AggFunc::JsonGroupArray | AggFunc::JsonbGroupArray => {
if num_args != 1 {
crate::bail_parse_error!("min bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::JsonGroupArray,
});
target_register
}
#[cfg(feature = "json")]
AggFunc::JsonGroupObject | AggFunc::JsonbGroupObject => {
if num_args != 2 {
crate::bail_parse_error!("max bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
let value_reg = agg_arg_source.translate(program, 1)?;
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: value_reg,
func: AggFunc::JsonGroupObject,
});
target_register
}
AggFunc::StringAgg => {
if num_args != 2 {
crate::bail_parse_error!("string_agg bad number of arguments");
}
let delimiter_reg = program.alloc_register();
let delimiter_expr = match &agg_arg_source.args()[1] {
arg @ ast::Expr::Column { .. } => arg.clone(),
ast::Expr::Literal(ast::Literal::String(s)) => {
ast::Expr::Literal(ast::Literal::String(s.to_string()))
}
_ => crate::bail_parse_error!("Incorrect delimiter parameter"),
};
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
translate_expr(
program,
Some(referenced_tables),
&delimiter_expr,
delimiter_reg,
resolver,
)?;
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: delimiter_reg,
func: AggFunc::StringAgg,
});
target_register
}
AggFunc::Sum => {
if num_args != 1 {
crate::bail_parse_error!("sum bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::Sum,
});
target_register
}
AggFunc::Total => {
if num_args != 1 {
crate::bail_parse_error!("total bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::Total,
});
target_register
}
AggFunc::External(_) => {
todo!("External aggregate functions are not yet supported in GROUP BY");
}
};
Ok(dest)
}

View File

@@ -19,7 +19,7 @@ use crate::{
};
use super::{
aggregation::translate_aggregation_step,
aggregation::{translate_aggregation_step, AggArgumentSource},
emitter::{OperationMode, TranslateCtx},
expr::{
translate_condition_expr, translate_expr, translate_expr_no_constant_opt,
@@ -868,7 +868,7 @@ fn emit_loop_source(
translate_aggregation_step(
program,
&plan.table_references,
agg,
AggArgumentSource::new_from_expression(agg),
reg,
&t_ctx.resolver,
)?;

View File

@@ -1048,6 +1048,24 @@ pub struct Aggregate {
}
impl Aggregate {
pub fn new(func: AggFunc, args: &[Box<Expr>], expr: &Expr, distinctness: Distinctness) -> Self {
let agg_args = if args.is_empty() {
// The AggStep instruction requires at least one argument. For functions that accept
// zero arguments (e.g. COUNT()), we insert a dummy literal so that AggStep remains valid.
// This does not cause ambiguity: the resolver has already verified that the function
// takes zero arguments, so the dummy value will be ignored.
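// E.g. `COUNT()` is compiled as if it were written `COUNT(1)`; the literal is never read.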
vec![Expr::Literal(ast::Literal::Numeric("1".to_string()))]
} else {
args.iter().map(|arg| *arg.clone()).collect()
};
Aggregate {
func,
args: agg_args,
original_expr: expr.clone(),
distinctness,
}
}
pub fn is_distinct(&self) -> bool {
self.distinctness.is_distinct()
}

View File

@@ -73,12 +73,7 @@ pub fn resolve_aggregates(
"DISTINCT aggregate functions must have exactly one argument"
);
}
aggs.push(Aggregate {
func: f,
args: args.iter().map(|arg| *arg.clone()).collect(),
original_expr: expr.clone(),
distinctness,
});
aggs.push(Aggregate::new(f, args, expr, distinctness));
contains_aggregates = true;
}
_ => {
@@ -95,12 +90,7 @@ pub fn resolve_aggregates(
);
}
if let Ok(Func::Agg(f)) = Func::resolve_function(name.as_str(), 0) {
aggs.push(Aggregate {
func: f,
args: vec![],
original_expr: expr.clone(),
distinctness: Distinctness::NonDistinct,
});
aggs.push(Aggregate::new(f, &[], expr, Distinctness::NonDistinct));
contains_aggregates = true;
}
}

View File

@@ -371,27 +371,7 @@ fn prepare_one_select_plan(
}
match Func::resolve_function(name.as_str(), args_count) {
Ok(Func::Agg(f)) => {
let agg_args = match (args.is_empty(), &f) {
(true, crate::function::AggFunc::Count0) => {
// COUNT() case
vec![ast::Expr::Literal(ast::Literal::Numeric(
"1".to_string(),
))
.into()]
}
(true, _) => crate::bail_parse_error!(
"Aggregate function {} requires arguments",
name.as_str()
),
(false, _) => args.clone(),
};
let agg = Aggregate {
func: f,
args: agg_args.iter().map(|arg| *arg.clone()).collect(),
original_expr: *expr.clone(),
distinctness,
};
let agg = Aggregate::new(f, args, expr, distinctness);
aggregate_expressions.push(agg);
plan.result_columns.push(ResultSetColumn {
alias: maybe_alias.as_ref().map(|alias| match alias {
@@ -446,15 +426,12 @@ fn prepare_one_select_plan(
contains_aggregates,
});
} else {
let agg = Aggregate {
func: AggFunc::External(f.func.clone().into()),
args: args
.iter()
.map(|arg| *arg.clone())
.collect(),
original_expr: *expr.clone(),
let agg = Aggregate::new(
AggFunc::External(f.func.clone().into()),
args,
expr,
distinctness,
};
);
aggregate_expressions.push(agg);
plan.result_columns.push(ResultSetColumn {
alias: maybe_alias.as_ref().map(|alias| {
@@ -488,14 +465,8 @@ fn prepare_one_select_plan(
}
match Func::resolve_function(name.as_str(), 0) {
Ok(Func::Agg(f)) => {
let agg = Aggregate {
func: f,
args: vec![ast::Expr::Literal(ast::Literal::Numeric(
"1".to_string(),
))],
original_expr: *expr.clone(),
distinctness: Distinctness::NonDistinct,
};
let agg =
Aggregate::new(f, &[], expr, Distinctness::NonDistinct);
aggregate_expressions.push(agg);
plan.result_columns.push(ResultSetColumn {
alias: maybe_alias.as_ref().map(|alias| match alias {

View File

@@ -14,7 +14,7 @@ use crate::translate::plan::IterationDirection;
use crate::vdbe::sorter::Sorter;
use crate::vdbe::Register;
use crate::vtab::VirtualTableCursor;
use crate::{turso_assert, Completion, Result, IO};
use crate::{turso_assert, Completion, CompletionError, Result, IO};
use std::fmt::{Debug, Display};
const MAX_REAL_SIZE: u8 = 15;
@@ -350,6 +350,13 @@ impl Value {
}
}
pub fn as_uint(&self) -> u64 {
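// Reinterpret the integer's two's-complement bits as u64 (e.g. -1 becomes 18446744073709551615); non-integer values map to 0.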
match self {
Value::Integer(i) => (*i).cast_unsigned(),
_ => 0,
}
}
pub fn from_text(text: &str) -> Self {
Value::Text(Text::new(text))
}
@@ -2502,6 +2509,13 @@ impl IOCompletions {
IOCompletions::Many(completions) => completions.iter().for_each(|c| c.abort()),
}
}
pub fn get_error(&self) -> Option<CompletionError> {
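// Return the first error reported by any of the underlying completions, if one failed.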
match self {
IOCompletions::Single(c) => c.get_error(),
IOCompletions::Many(completions) => completions.iter().find_map(|c| c.get_error()),
}
}
}
#[derive(Debug)]

View File

@@ -703,58 +703,7 @@ pub fn columns_from_create_table_body(
use turso_parser::ast;
Ok(columns
.iter()
.map(
|ast::ColumnDefinition {
col_name: name,
col_type,
constraints,
}| {
Column {
name: Some(normalize_ident(name.as_str())),
ty: match col_type {
Some(ref data_type) => type_from_name(data_type.name.as_str()).0,
None => Type::Null,
},
default: constraints.iter().find_map(|c| match &c.constraint {
ast::ColumnConstraint::Default(val) => Some(val.clone()),
_ => None,
}),
notnull: constraints
.iter()
.any(|c| matches!(c.constraint, ast::ColumnConstraint::NotNull { .. })),
ty_str: col_type
.clone()
.map(|t| t.name.to_string())
.unwrap_or_default(),
primary_key: constraints
.iter()
.any(|c| matches!(c.constraint, ast::ColumnConstraint::PrimaryKey { .. })),
is_rowid_alias: false,
unique: constraints
.iter()
.any(|c| matches!(c.constraint, ast::ColumnConstraint::Unique(..))),
collation: constraints.iter().find_map(|c| match &c.constraint {
// TODO: see if this should be the correct behavior
// currently there cannot be any user defined collation sequences.
// But in the future, when a user defines a collation sequence, creates a table with it,
// then closes the db and opens it again. This may panic here if the collation seq is not registered
// before reading the columns
ast::ColumnConstraint::Collate { collation_name } => Some(
CollationSeq::new(collation_name.as_str())
.expect("collation should have been set correctly in create table"),
),
_ => None,
}),
hidden: col_type
.as_ref()
.map(|data_type| data_type.name.as_str().contains("HIDDEN"))
.unwrap_or(false),
}
},
)
.collect::<Vec<_>>())
Ok(columns.iter().map(Into::into).collect())
}
/// This function checks if a given expression is a constant value that can be pushed down to the database engine.
@@ -803,6 +752,10 @@ pub struct OpenOptions<'a> {
pub cache: CacheMode,
/// immutable=1|0 specifies that the database is stored on read-only media
pub immutable: bool,
/// The encryption cipher
pub cipher: Option<String>,
/// The encryption key in hex format
pub hexkey: Option<String>,
}
pub const MEMORY_PATH: &str = ":memory:";
@@ -954,6 +907,8 @@ fn parse_query_params(query: &str, opts: &mut OpenOptions) -> Result<()> {
"cache" => opts.cache = decoded_value.as_str().into(),
"immutable" => opts.immutable = decoded_value == "1",
"vfs" => opts.vfs = Some(decoded_value),
"cipher" => opts.cipher = Some(decoded_value),
"hexkey" => opts.hexkey = Some(decoded_value),
_ => {}
}
}

View File

@@ -4461,10 +4461,16 @@ pub fn op_function(
}
}
ScalarFunc::SqliteVersion => {
let version_integer =
return_if_io!(pager.with_header(|header| header.version_number)).get() as i64;
let version = execute_sqlite_version(version_integer);
state.registers[*dest] = Register::Value(Value::build_text(version));
if !program.connection.is_db_initialized() {
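// The database file has no header to read yet, so report the build's package version instead.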
state.registers[*dest] =
Register::Value(Value::build_text(info::build::PKG_VERSION));
} else {
let version_integer =
return_if_io!(pager.with_header(|header| header.version_number)).get()
as i64;
let version = execute_sqlite_version(version_integer);
state.registers[*dest] = Register::Value(Value::build_text(version));
}
}
ScalarFunc::SqliteSourceId => {
let src_id = format!(
@@ -4852,10 +4858,10 @@ pub fn op_function(
match stmt {
ast::Stmt::CreateIndex {
tbl_name,
unique,
if_not_exists,
idx_name,
tbl_name,
columns,
where_clause,
} => {
@@ -4867,10 +4873,10 @@ pub fn op_function(
Some(
ast::Stmt::CreateIndex {
tbl_name: ast::Name::new(&rename_to),
unique,
if_not_exists,
idx_name,
tbl_name: ast::Name::new(&rename_to),
columns,
where_clause,
}
@@ -4879,9 +4885,9 @@ pub fn op_function(
)
}
ast::Stmt::CreateTable {
tbl_name,
temporary,
if_not_exists,
tbl_name,
body,
} => {
let table_name = normalize_ident(tbl_name.name.as_str());
@@ -4892,13 +4898,13 @@ pub fn op_function(
Some(
ast::Stmt::CreateTable {
temporary,
if_not_exists,
tbl_name: ast::QualifiedName {
db_name: None,
name: ast::Name::new(&rename_to),
alias: None,
},
temporary,
if_not_exists,
body,
}
.format()
@@ -4911,7 +4917,7 @@ pub fn op_function(
(new_name, new_tbl_name, new_sql)
}
AlterTableFunc::RenameColumn => {
AlterTableFunc::AlterColumn | AlterTableFunc::RenameColumn => {
let table = {
match &state.registers[*start_reg + 5].get_value() {
Value::Text(rename_to) => normalize_ident(rename_to.as_str()),
@@ -4926,13 +4932,17 @@ pub fn op_function(
}
};
let rename_to = {
let column_def = {
match &state.registers[*start_reg + 7].get_value() {
Value::Text(rename_to) => normalize_ident(rename_to.as_str()),
Value::Text(column_def) => column_def.as_str(),
_ => panic!("rename_to parameter should be TEXT"),
}
};
let column_def = Parser::new(column_def.as_bytes())
.parse_column_definition(true)
.unwrap();
let new_sql = 'sql: {
if table != tbl_name {
break 'sql None;
@@ -4949,11 +4959,11 @@ pub fn op_function(
match stmt {
ast::Stmt::CreateIndex {
tbl_name,
mut columns,
unique,
if_not_exists,
idx_name,
tbl_name,
mut columns,
where_clause,
} => {
if table != normalize_ident(tbl_name.as_str()) {
@@ -4965,7 +4975,7 @@ pub fn op_function(
ast::Expr::Id(ast::Name::Ident(id))
if normalize_ident(id) == rename_from =>
{
*id = rename_to.clone();
*id = column_def.col_name.as_str().to_owned();
}
_ => {}
}
@@ -4973,11 +4983,11 @@ pub fn op_function(
Some(
ast::Stmt::CreateIndex {
tbl_name,
columns,
unique,
if_not_exists,
idx_name,
tbl_name,
columns,
where_clause,
}
.format()
@@ -4985,10 +4995,10 @@ pub fn op_function(
)
}
ast::Stmt::CreateTable {
temporary,
if_not_exists,
tbl_name,
body,
temporary,
if_not_exists,
} => {
if table != normalize_ident(tbl_name.name.as_str()) {
break 'sql None;
@@ -5008,18 +5018,24 @@ pub fn op_function(
.find(|column| column.col_name == ast::Name::new(&rename_from))
.expect("column being renamed should be present");
column.col_name = ast::Name::new(&rename_to);
match alter_func {
AlterTableFunc::AlterColumn => *column = column_def,
AlterTableFunc::RenameColumn => {
column.col_name = column_def.col_name
}
_ => unreachable!(),
}
Some(
ast::Stmt::CreateTable {
temporary,
if_not_exists,
tbl_name,
body: ast::CreateTableBody::ColumnsAndConstraints {
columns,
constraints,
options,
},
temporary,
if_not_exists,
}
.format()
.unwrap(),
@@ -7303,7 +7319,7 @@ pub fn op_add_column(
Ok(InsnFunctionStepResult::Step)
}
pub fn op_rename_column(
pub fn op_alter_column(
program: &Program,
state: &mut ProgramState,
insn: &Insn,
@@ -7311,16 +7327,19 @@ pub fn op_rename_column(
mv_store: Option<&Arc<MvStore>>,
) -> Result<InsnFunctionStepResult> {
load_insn!(
RenameColumn {
AlterColumn {
table: table_name,
column_index,
name
definition,
rename,
},
insn
);
let conn = program.connection.clone();
let new_column = crate::schema::Column::from(definition);
conn.with_schema_mut(|schema| {
let table = schema
.tables
@@ -7347,13 +7366,17 @@ pub fn op_rename_column(
if index_column.name
== *column.name.as_ref().expect("btree column should be named")
{
index_column.name = name.to_owned();
index_column.name = definition.col_name.as_str().to_owned();
}
}
}
}
column.name = Some(name.to_owned());
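// RENAME COLUMN only swaps the in-memory column name; ALTER COLUMN replaces the whole column definition.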
if *rename {
column.name = new_column.name;
} else {
*column = new_column;
}
});
state.pc += 1;

View File

@@ -1672,14 +1672,14 @@ pub fn insn_to_str(
0,
format!("add_column({table}, {column:?})"),
),
Insn::RenameColumn { table, column_index, name } => (
"RenameColumn",
Insn::AlterColumn { table, column_index, definition: column, rename } => (
"AlterColumn",
0,
0,
0,
Value::build_text(""),
0,
format!("rename_column({table}, {column_index}, {name})"),
format!("alter_column({table}, {column_index}, {column:?}, {rename:?})"),
),
Insn::MaxPgcnt { db, dest, new_max } => (
"MaxPgcnt",

View File

@@ -1053,10 +1053,11 @@ pub enum Insn {
table: String,
column: Column,
},
RenameColumn {
AlterColumn {
table: String,
column_index: usize,
name: String,
definition: turso_parser::ast::ColumnDefinition,
rename: bool,
},
/// Try to set the maximum page count for database P1 to the value in P3.
/// Do not let the maximum page count fall below the current page count and
@@ -1209,7 +1210,7 @@ impl Insn {
Insn::RenameTable { .. } => execute::op_rename_table,
Insn::DropColumn { .. } => execute::op_drop_column,
Insn::AddColumn { .. } => execute::op_add_column,
Insn::RenameColumn { .. } => execute::op_rename_column,
Insn::AlterColumn { .. } => execute::op_alter_column,
Insn::MaxPgcnt { .. } => execute::op_max_pgcnt,
Insn::JournalMode { .. } => execute::op_journal_mode,
}

View File

@@ -460,6 +460,11 @@ impl Program {
if !io.finished() {
return Ok(StepResult::IO);
}
if let Some(err) = io.get_error() {
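// The I/O finished but reported an error; run the shared error handling path before surfacing it to the caller.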
let err = err.into();
handle_program_error(&pager, &self.connection, &err)?;
return Err(err);
}
state.io_completions = None;
}
// invalidate row

View File

@@ -370,7 +370,7 @@ struct SortedChunk {
/// The chunk file.
file: Arc<dyn File>,
/// Offset of the start of chunk in file
start_offset: usize,
start_offset: u64,
/// The size of this chunk file in bytes.
chunk_size: usize,
/// The read buffer.
@@ -391,7 +391,7 @@ impl SortedChunk {
fn new(file: Arc<dyn File>, start_offset: usize, buffer_size: usize) -> Self {
Self {
file,
start_offset,
start_offset: start_offset as u64,
chunk_size: 0,
buffer: Rc::new(RefCell::new(vec![0; buffer_size])),
buffer_len: Rc::new(Cell::new(0)),
@@ -522,7 +522,7 @@ impl SortedChunk {
let c = Completion::new_read(read_buffer_ref, read_complete);
let c = self
.file
.pread(self.start_offset + self.total_bytes_read.get(), c)?;
.pread(self.start_offset + self.total_bytes_read.get() as u64, c)?;
Ok(c)
}