Commit: merge main
Mirror of https://github.com/aljazceru/turso.git (synced 2026-01-08)
@@ -1,5 +1,7 @@
use thiserror::Error;

use crate::storage::page_cache::CacheError;

#[derive(Debug, Clone, Error, miette::Diagnostic)]
pub enum LimboError {
    #[error("Corrupt database: {0}")]
@@ -8,8 +10,8 @@ pub enum LimboError {
    NotADB,
    #[error("Internal error: {0}")]
    InternalError(String),
    #[error("Page cache is full")]
    CacheFull,
    #[error(transparent)]
    CacheError(#[from] CacheError),
    #[error("Database is full: {0}")]
    DatabaseFull(String),
    #[error("Parse error: {0}")]
@@ -583,6 +583,7 @@ impl Display for MathFunc {
#[derive(Debug)]
pub enum AlterTableFunc {
    RenameTable,
    AlterColumn,
    RenameColumn,
}

@@ -591,6 +592,7 @@ impl Display for AlterTableFunc {
        match self {
            AlterTableFunc::RenameTable => write!(f, "limbo_rename_table"),
            AlterTableFunc::RenameColumn => write!(f, "limbo_rename_column"),
            AlterTableFunc::AlterColumn => write!(f, "limbo_alter_column"),
        }
    }
}
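The new `CacheError(#[from] CacheError)` variant above forwards both `Display` and `source()` to the inner error, and `#[from]` is what lets `?` convert a `CacheError` into a `LimboError` at call sites. A minimal sketch of those thiserror mechanics (the `Pinned` variant is invented for illustration):

```rust
use thiserror::Error;

#[derive(Debug, Error)]
pub enum CacheError {
    #[error("page {0} is pinned")]
    Pinned(usize), // hypothetical variant, for illustration only
}

#[derive(Debug, Error)]
pub enum LimboError {
    #[error("Page cache is full")]
    CacheFull,
    #[error(transparent)]
    CacheError(#[from] CacheError),
}

fn evict() -> Result<(), CacheError> {
    Err(CacheError::Pinned(1))
}

fn caller() -> Result<(), LimboError> {
    evict()?; // `#[from]` lets `?` convert CacheError into LimboError
    Ok(())
}

fn main() {
    // `transparent` forwards Display to the inner error.
    assert_eq!(caller().unwrap_err().to_string(), "page 1 is pinned");
}
```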
@@ -1,8 +1,40 @@
use core::f64;

use crate::types::Value;
use crate::vdbe::Register;
use crate::LimboError;

// TODO: Support %!.3s %i, %x, %X, %o, %e, %E, %c. flags: - + 0 ! ,
fn get_exponential_formatted_str(number: &f64, uppercase: bool) -> crate::Result<String> {
    let pre_formatted = format!("{number:.6e}");
    let mut parts = pre_formatted.split("e");

    let maybe_base = parts.next();
    let maybe_exponent = parts.next();

    let mut result = String::new();
    match (maybe_base, maybe_exponent) {
        (Some(base), Some(exponent)) => {
            result.push_str(base);
            result.push_str(if uppercase { "E" } else { "e" });

            match exponent.parse::<i32>() {
                Ok(exponent_number) => {
                    let exponent_fmt = format!("{exponent_number:+03}");
                    result.push_str(&exponent_fmt);
                    Ok(result)
                }
                Err(_) => Err(LimboError::InternalError(
                    "unable to parse exponential expression's exponent".into(),
                )),
            }
        }
        (_, _) => Err(LimboError::InternalError(
            "unable to parse exponential expression".into(),
        )),
    }
}
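The helper leans on two std formatting behaviors: `{:.6e}` fixes six fractional digits but leaves the exponent unpadded and unsigned, and `{:+03}` re-pads it C-style. A quick check:

```rust
fn main() {
    // `{:.6e}` gives six fractional digits but an unpadded exponent...
    assert_eq!(format!("{:.6e}", 23000000.0_f64), "2.300000e7");
    // ...so the function re-pads it with `{:+03}` to match C's printf.
    assert_eq!(format!("{:+03}", 7), "+07");
    assert_eq!(format!("{:+03}", -4), "-04");
    // Put together: "2.300000e+07".
}
```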
// TODO: Support %!.3s. flags: - + 0 ! ,
#[inline(always)]
pub fn exec_printf(values: &[Register]) -> crate::Result<Value> {
    if values.is_empty() {
@@ -40,6 +72,20 @@ pub fn exec_printf(values: &[Register]) -> crate::Result<Value> {
            }
            args_index += 1;
        }
        Some('u') => {
            if args_index >= values.len() {
                return Err(LimboError::InvalidArgument("not enough arguments".into()));
            }
            let value = &values[args_index].get_value();
            match value {
                Value::Integer(_) => {
                    let converted_value = value.as_uint();
                    result.push_str(&format!("{converted_value}"))
                }
                _ => result.push('0'),
            }
            args_index += 1;
        }
        Some('s') => {
            if args_index >= values.len() {
                return Err(LimboError::InvalidArgument("not enough arguments".into()));
@@ -63,6 +109,119 @@ pub fn exec_printf(values: &[Register]) -> crate::Result<Value> {
            }
            args_index += 1;
        }
        Some('e') => {
            if args_index >= values.len() {
                return Err(LimboError::InvalidArgument("not enough arguments".into()));
            }
            let value = &values[args_index].get_value();
            match value {
                Value::Float(f) => match get_exponential_formatted_str(f, false) {
                    Ok(str) => result.push_str(&str),
                    Err(e) => return Err(e),
                },
                Value::Integer(i) => {
                    let f = *i as f64;
                    match get_exponential_formatted_str(&f, false) {
                        Ok(str) => result.push_str(&str),
                        Err(e) => return Err(e),
                    }
                }
                Value::Text(s) => {
                    let number: f64 = s
                        .as_str()
                        .trim_start()
                        .trim_end_matches(|c: char| !c.is_numeric())
                        .parse()
                        .unwrap_or(0.0);
                    match get_exponential_formatted_str(&number, false) {
                        Ok(str) => result.push_str(&str),
                        Err(e) => return Err(e),
                    };
                }
                _ => result.push_str("0.000000e+00"),
            }
            args_index += 1;
        }
        Some('E') => {
            if args_index >= values.len() {
                return Err(LimboError::InvalidArgument("not enough arguments".into()));
            }
            let value = &values[args_index].get_value();
            match value {
                Value::Float(f) => match get_exponential_formatted_str(f, true) {
                    Ok(str) => result.push_str(&str),
                    Err(e) => return Err(e),
                },
                Value::Integer(i) => {
                    let f = *i as f64;
                    match get_exponential_formatted_str(&f, true) {
                        Ok(str) => result.push_str(&str),
                        Err(e) => return Err(e),
                    }
                }
                Value::Text(s) => {
                    let number: f64 = s
                        .as_str()
                        .trim_start()
                        .trim_end_matches(|c: char| !c.is_numeric())
                        .parse()
                        .unwrap_or(0.0);
                    match get_exponential_formatted_str(&number, true) {
                        Ok(str) => result.push_str(&str),
                        Err(e) => return Err(e),
                    };
                }
                _ => result.push_str("0.000000E+00"),
            }
            args_index += 1;
        }
        Some('c') => {
            if args_index >= values.len() {
                return Err(LimboError::InvalidArgument("not enough arguments".into()));
            }
            let value = &values[args_index].get_value();
            let value_str: String = format!("{value}");
            if !value_str.is_empty() {
                result.push_str(&value_str[0..1]);
            }
            args_index += 1;
        }
        Some('x') => {
            if args_index >= values.len() {
                return Err(LimboError::InvalidArgument("not enough arguments".into()));
            }
            let value = &values[args_index].get_value();
            match value {
                Value::Float(f) => result.push_str(&format!("{:x}", *f as i64)),
                Value::Integer(i) => result.push_str(&format!("{i:x}")),
                _ => result.push('0'),
            }
            args_index += 1;
        }
        Some('X') => {
            if args_index >= values.len() {
                return Err(LimboError::InvalidArgument("not enough arguments".into()));
            }
            let value = &values[args_index].get_value();
            match value {
                Value::Float(f) => result.push_str(&format!("{:X}", *f as i64)),
                Value::Integer(i) => result.push_str(&format!("{i:X}")),
                _ => result.push('0'),
            }
            args_index += 1;
        }
        Some('o') => {
            if args_index >= values.len() {
                return Err(LimboError::InvalidArgument("not enough arguments".into()));
            }
            let value = &values[args_index].get_value();
            match value {
                Value::Float(f) => result.push_str(&format!("{:o}", *f as i64)),
                Value::Integer(i) => result.push_str(&format!("{i:o}")),
                _ => result.push('0'),
            }
            args_index += 1;
        }
        None => {
            return Err(LimboError::InvalidArgument(
                "incomplete format specifier".into(),
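`%u`, and `%x`/`%o` on negative inputs, all reinterpret the two's-complement bit pattern rather than printing a sign; Rust's standard formatters already behave this way, which is what the expected strings in the tests below rely on. A quick check:

```rust
fn main() {
    // %u reinterprets the i64 bit pattern as u64:
    assert_eq!((-1i64 as u64).to_string(), "18446744073709551615");
    // %x/%o on a negative float first truncates to i64, then prints the
    // two's-complement bits, matching the test expectations further down:
    assert_eq!(format!("{:x}", -42.5_f64 as i64), "ffffffffffffffd6");
    assert_eq!(format!("{:o}", -42.5_f64 as i64), "1777777777777777777726");
}
```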
@@ -159,6 +318,29 @@ mod tests {
        }
    }

    #[test]
    fn test_printf_unsigned_integer_formatting() {
        let test_cases = vec![
            // Basic
            (vec![text("Number: %u"), integer(42)], text("Number: 42")),
            // Multiple numbers
            (
                vec![text("%u + %u = %u"), integer(2), integer(3), integer(5)],
                text("2 + 3 = 5"),
            ),
            // Negative number should be represented as its uint representation
            (
                vec![text("Negative: %u"), integer(-1)],
                text("Negative: 18446744073709551615"),
            ),
            // Non-numeric value defaults to 0
            (vec![text("NaN: %u"), text("not a number")], text("NaN: 0")),
        ];
        for (input, output) in test_cases {
            assert_eq!(exec_printf(&input).unwrap(), *output.get_value())
        }
    }

    #[test]
    fn test_printf_float_formatting() {
        let test_cases = vec![
@@ -194,6 +376,178 @@ mod tests {
        }
    }

    #[test]
    fn test_printf_character_formatting() {
        let test_cases = vec![
            // Simple character
            (vec![text("character: %c"), text("a")], text("character: a")),
            // Character with string
            (
                vec![text("character: %c"), text("this is a test")],
                text("character: t"),
            ),
            // Character with empty
            (vec![text("character: %c"), text("")], text("character: ")),
            // Character with integer
            (
                vec![text("character: %c"), integer(123)],
                text("character: 1"),
            ),
            // Character with float
            (
                vec![text("character: %c"), float(42.5)],
                text("character: 4"),
            ),
        ];

        for (input, expected) in test_cases {
            assert_eq!(exec_printf(&input).unwrap(), *expected.get_value());
        }
    }

    #[test]
    fn test_printf_exponential_formatting() {
        let test_cases = vec![
            // Simple number
            (
                vec![text("Exp: %e"), float(23000000.0)],
                text("Exp: 2.300000e+07"),
            ),
            // Negative number
            (
                vec![text("Exp: %e"), float(-23000000.0)],
                text("Exp: -2.300000e+07"),
            ),
            // Non-integer float
            (
                vec![text("Exp: %e"), float(250.375)],
                text("Exp: 2.503750e+02"),
            ),
            // Positive, but smaller than one
            (
                vec![text("Exp: %e"), float(0.0003235)],
                text("Exp: 3.235000e-04"),
            ),
            // Zero
            (vec![text("Exp: %e"), float(0.0)], text("Exp: 0.000000e+00")),
            // Uppercase "E"
            (
                vec![text("Exp: %E"), float(0.0003235)],
                text("Exp: 3.235000E-04"),
            ),
            // String with integer number
            (
                vec![text("Exp: %e"), text("123")],
                text("Exp: 1.230000e+02"),
            ),
            // String with floating point number
            (
                vec![text("Exp: %e"), text("123.45")],
                text("Exp: 1.234500e+02"),
            ),
            // String with leading zeroes
            (
                vec![text("Exp: %e"), text("00123")],
                text("Exp: 1.230000e+02"),
            ),
            // String with text
            (
                vec![text("Exp: %e"), text("test")],
                text("Exp: 0.000000e+00"),
            ),
            // String starting with a number but ending with text
            (
                vec![text("Exp: %e"), text("123ab")],
                text("Exp: 1.230000e+02"),
            ),
            // String starting with text but ending with a number
            (
                vec![text("Exp: %e"), text("ab123")],
                text("Exp: 0.000000e+00"),
            ),
            // String with exponential representation
            (
                vec![text("Exp: %e"), text("1.230000e+02")],
                text("Exp: 1.230000e+02"),
            ),
        ];

        for (input, expected) in test_cases {
            assert_eq!(exec_printf(&input).unwrap(), *expected.get_value());
        }
    }
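The `Value::Text` arm only accepts a numeric prefix: leading whitespace is trimmed, trailing non-numeric characters are stripped, and anything that still fails to parse becomes 0.0. The same logic in isolation:

```rust
fn main() {
    // Mirrors the `%e` text handling above.
    let to_f64 = |s: &str| -> f64 {
        s.trim_start()
            .trim_end_matches(|c: char| !c.is_numeric())
            .parse()
            .unwrap_or(0.0)
    };
    assert_eq!(to_f64("00123"), 123.0);
    assert_eq!(to_f64("123ab"), 123.0); // trailing text stripped
    assert_eq!(to_f64("ab123"), 0.0); // leading text fails the parse
    assert_eq!(to_f64("1.230000e+02"), 123.0);
}
```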
    #[test]
    fn test_printf_hexadecimal_formatting() {
        let test_cases = vec![
            // Simple number
            (vec![text("hex: %x"), integer(4)], text("hex: 4")),
            // Bigger Number
            (
                vec![text("hex: %x"), integer(15565303546)],
                text("hex: 39fc3aefa"),
            ),
            // Uppercase letters
            (
                vec![text("hex: %X"), integer(15565303546)],
                text("hex: 39FC3AEFA"),
            ),
            // Negative
            (
                vec![text("hex: %x"), integer(-15565303546)],
                text("hex: fffffffc603c5106"),
            ),
            // Float
            (vec![text("hex: %x"), float(42.5)], text("hex: 2a")),
            // Negative Float
            (
                vec![text("hex: %x"), float(-42.5)],
                text("hex: ffffffffffffffd6"),
            ),
            // Text
            (vec![text("hex: %x"), text("42")], text("hex: 0")),
            // Empty Text
            (vec![text("hex: %x"), text("")], text("hex: 0")),
        ];

        for (input, expected) in test_cases {
            assert_eq!(exec_printf(&input).unwrap(), *expected.get_value());
        }
    }

    #[test]
    fn test_printf_octal_formatting() {
        let test_cases = vec![
            // Simple number
            (vec![text("octal: %o"), integer(4)], text("octal: 4")),
            // Bigger Number
            (
                vec![text("octal: %o"), integer(15565303546)],
                text("octal: 163760727372"),
            ),
            // Negative
            (
                vec![text("octal: %o"), integer(-15565303546)],
                text("octal: 1777777777614017050406"),
            ),
            // Float
            (vec![text("octal: %o"), float(42.5)], text("octal: 52")),
            // Negative Float
            (
                vec![text("octal: %o"), float(-42.5)],
                text("octal: 1777777777777777777726"),
            ),
            // Text
            (vec![text("octal: %o"), text("42")], text("octal: 0")),
            // Empty Text
            (vec![text("octal: %o"), text("")], text("octal: 0")),
        ];

        for (input, expected) in test_cases {
            assert_eq!(exec_printf(&input).unwrap(), *expected.get_value());
        }
    }

    #[test]
    fn test_printf_mixed_formatting() {
        let test_cases = vec![
@@ -68,9 +68,9 @@ impl File for GenericFile {
    }

    #[instrument(skip(self, c), level = Level::TRACE)]
    fn pread(&self, pos: usize, c: Completion) -> Result<Completion> {
    fn pread(&self, pos: u64, c: Completion) -> Result<Completion> {
        let mut file = self.file.write();
        file.seek(std::io::SeekFrom::Start(pos as u64))?;
        file.seek(std::io::SeekFrom::Start(pos))?;
        let nr = {
            let r = c.as_read();
            let buf = r.buf();
@@ -83,9 +83,9 @@ impl File for GenericFile {
    }

    #[instrument(skip(self, c, buffer), level = Level::TRACE)]
    fn pwrite(&self, pos: usize, buffer: Arc<crate::Buffer>, c: Completion) -> Result<Completion> {
    fn pwrite(&self, pos: u64, buffer: Arc<crate::Buffer>, c: Completion) -> Result<Completion> {
        let mut file = self.file.write();
        file.seek(std::io::SeekFrom::Start(pos as u64))?;
        file.seek(std::io::SeekFrom::Start(pos))?;
        let buf = buffer.as_slice();
        file.write_all(buf)?;
        c.complete(buffer.len() as i32);
@@ -101,9 +101,9 @@ impl File for GenericFile {
    }

    #[instrument(err, skip_all, level = Level::TRACE)]
    fn truncate(&self, len: usize, c: Completion) -> Result<Completion> {
    fn truncate(&self, len: u64, c: Completion) -> Result<Completion> {
        let file = self.file.write();
        file.set_len(len as u64)?;
        file.set_len(len)?;
        c.complete(0);
        Ok(c)
    }

@@ -182,7 +182,7 @@ struct WritevState {
    /// File descriptor/id of the file we are writing to
    file_id: Fd,
    /// absolute file offset for next submit
    file_pos: usize,
    file_pos: u64,
    /// current buffer index in `bufs`
    current_buffer_idx: usize,
    /// intra-buffer offset
@@ -198,7 +198,7 @@ struct WritevState {
}

impl WritevState {
    fn new(file: &UringFile, pos: usize, bufs: Vec<Arc<crate::Buffer>>) -> Self {
    fn new(file: &UringFile, pos: u64, bufs: Vec<Arc<crate::Buffer>>) -> Self {
        let file_id = file
            .id()
            .map(Fd::Fixed)
@@ -223,23 +223,23 @@ impl WritevState {

    /// Advance (idx, off, pos) after written bytes
    #[inline(always)]
    fn advance(&mut self, written: usize) {
    fn advance(&mut self, written: u64) {
        let mut remaining = written;
        while remaining > 0 {
            let current_buf_len = self.bufs[self.current_buffer_idx].len();
            let left = current_buf_len - self.current_buffer_offset;
            if remaining < left {
                self.current_buffer_offset += remaining;
            if remaining < left as u64 {
                self.current_buffer_offset += remaining as usize;
                self.file_pos += remaining;
                remaining = 0;
            } else {
                remaining -= left;
                self.file_pos += left;
                remaining -= left as u64;
                self.file_pos += left as u64;
                self.current_buffer_idx += 1;
                self.current_buffer_offset = 0;
            }
        }
        self.total_written += written;
        self.total_written += written as usize;
    }
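The rewritten `advance` walks the triple (buffer index, intra-buffer offset, absolute file position) forward after a possibly short `writev` completion. A standalone sketch of the same bookkeeping, with plain lengths instead of `Arc<crate::Buffer>`:

```rust
// Standalone sketch of WritevState::advance's cursor arithmetic.
struct Cursor {
    buf_idx: usize, // current buffer
    buf_off: usize, // offset within that buffer
    file_pos: u64,  // absolute file offset of the next submit
}

fn advance(c: &mut Cursor, lens: &[usize], written: u64) {
    let mut remaining = written;
    while remaining > 0 {
        let left = (lens[c.buf_idx] - c.buf_off) as u64;
        if remaining < left {
            // Short write lands inside the current buffer.
            c.buf_off += remaining as usize;
            c.file_pos += remaining;
            remaining = 0;
        } else {
            // Current buffer fully consumed; move to the next one.
            remaining -= left;
            c.file_pos += left;
            c.buf_idx += 1;
            c.buf_off = 0;
        }
    }
}

fn main() {
    let mut c = Cursor { buf_idx: 0, buf_off: 0, file_pos: 100 };
    // A 5-byte write over [4, 4]-byte buffers lands 1 byte into buffer 1.
    advance(&mut c, &[4, 4], 5);
    assert_eq!((c.buf_idx, c.buf_off, c.file_pos), (1, 1, 105));
}
```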
    #[inline(always)]
@@ -400,7 +400,7 @@ impl WrappedIOUring {
                iov_allocation[0].iov_len as u32,
                id as u16,
            )
            .offset(st.file_pos as u64)
            .offset(st.file_pos)
            .build()
            .user_data(key)
        } else {
@@ -409,7 +409,7 @@ impl WrappedIOUring {
                iov_allocation[0].iov_base as *const u8,
                iov_allocation[0].iov_len as u32,
            )
            .offset(st.file_pos as u64)
            .offset(st.file_pos)
            .build()
            .user_data(key)
        }
@@ -425,7 +425,7 @@ impl WrappedIOUring {

        let entry = with_fd!(st.file_id, |fd| {
            io_uring::opcode::Writev::new(fd, ptr, iov_count as u32)
                .offset(st.file_pos as u64)
                .offset(st.file_pos)
                .build()
                .user_data(key)
        });
@@ -443,8 +443,8 @@ impl WrappedIOUring {
            return;
        }

        let written = result as usize;
        state.advance(written);
        let written = result;
        state.advance(written as u64);
        match state.remaining() {
            0 => {
                tracing::info!(
@@ -643,7 +643,7 @@ impl File for UringFile {
        Ok(())
    }

    fn pread(&self, pos: usize, c: Completion) -> Result<Completion> {
    fn pread(&self, pos: u64, c: Completion) -> Result<Completion> {
        let r = c.as_read();
        let mut io = self.io.borrow_mut();
        let read_e = {
@@ -663,14 +663,14 @@ impl File for UringFile {
                io.debug_check_fixed(idx, ptr, len);
            }
            io_uring::opcode::ReadFixed::new(fd, ptr, len as u32, idx as u16)
                .offset(pos as u64)
                .offset(pos)
                .build()
                .user_data(get_key(c.clone()))
        } else {
            trace!("pread(pos = {}, length = {})", pos, len);
            // Use Read opcode if fixed buffer is not available
            io_uring::opcode::Read::new(fd, buf.as_mut_ptr(), len as u32)
                .offset(pos as u64)
                .offset(pos)
                .build()
                .user_data(get_key(c.clone()))
        }
@@ -680,7 +680,7 @@ impl File for UringFile {
        Ok(c)
    }

    fn pwrite(&self, pos: usize, buffer: Arc<crate::Buffer>, c: Completion) -> Result<Completion> {
    fn pwrite(&self, pos: u64, buffer: Arc<crate::Buffer>, c: Completion) -> Result<Completion> {
        let mut io = self.io.borrow_mut();
        let write = {
            let ptr = buffer.as_ptr();
@@ -698,13 +698,13 @@ impl File for UringFile {
                io.debug_check_fixed(idx, ptr, len);
            }
            io_uring::opcode::WriteFixed::new(fd, ptr, len as u32, idx as u16)
                .offset(pos as u64)
                .offset(pos)
                .build()
                .user_data(get_key(c.clone()))
        } else {
            trace!("pwrite(pos = {}, length = {})", pos, buffer.len());
            io_uring::opcode::Write::new(fd, ptr, len as u32)
                .offset(pos as u64)
                .offset(pos)
                .build()
                .user_data(get_key(c.clone()))
        }
@@ -728,7 +728,7 @@ impl File for UringFile {

    fn pwritev(
        &self,
        pos: usize,
        pos: u64,
        bufs: Vec<Arc<crate::Buffer>>,
        c: Completion,
    ) -> Result<Completion> {
@@ -748,10 +748,10 @@ impl File for UringFile {
        Ok(self.file.metadata()?.len())
    }

    fn truncate(&self, len: usize, c: Completion) -> Result<Completion> {
    fn truncate(&self, len: u64, c: Completion) -> Result<Completion> {
        let mut io = self.io.borrow_mut();
        let truncate = with_fd!(self, |fd| {
            io_uring::opcode::Ftruncate::new(fd, len as u64)
            io_uring::opcode::Ftruncate::new(fd, len)
                .build()
                .user_data(get_key(c.clone()))
        });
@@ -69,17 +69,12 @@ impl IO for MemoryIO {
        files.remove(path);
        Ok(())
    }

    fn run_once(&self) -> Result<()> {
        // nop
        Ok(())
    }
}

pub struct MemoryFile {
    path: String,
    pages: UnsafeCell<BTreeMap<usize, MemPage>>,
    size: Cell<usize>,
    size: Cell<u64>,
}
unsafe impl Send for MemoryFile {}
unsafe impl Sync for MemoryFile {}
@@ -92,10 +87,10 @@ impl File for MemoryFile {
        Ok(())
    }

    fn pread(&self, pos: usize, c: Completion) -> Result<Completion> {
    fn pread(&self, pos: u64, c: Completion) -> Result<Completion> {
        tracing::debug!("pread(path={}): pos={}", self.path, pos);
        let r = c.as_read();
        let buf_len = r.buf().len();
        let buf_len = r.buf().len() as u64;
        if buf_len == 0 {
            c.complete(0);
            return Ok(c);
@@ -110,8 +105,8 @@ impl File for MemoryFile {
        let read_len = buf_len.min(file_size - pos);
        {
            let read_buf = r.buf();
            let mut offset = pos;
            let mut remaining = read_len;
            let mut offset = pos as usize;
            let mut remaining = read_len as usize;
            let mut buf_offset = 0;

            while remaining > 0 {
@@ -134,7 +129,7 @@ impl File for MemoryFile {
        Ok(c)
    }

    fn pwrite(&self, pos: usize, buffer: Arc<Buffer>, c: Completion) -> Result<Completion> {
    fn pwrite(&self, pos: u64, buffer: Arc<Buffer>, c: Completion) -> Result<Completion> {
        tracing::debug!(
            "pwrite(path={}): pos={}, size={}",
            self.path,
@@ -147,7 +142,7 @@ impl File for MemoryFile {
            return Ok(c);
        }

        let mut offset = pos;
        let mut offset = pos as usize;
        let mut remaining = buf_len;
        let mut buf_offset = 0;
        let data = &buffer.as_slice();
@@ -158,7 +153,7 @@ impl File for MemoryFile {
            let bytes_to_write = remaining.min(PAGE_SIZE - page_offset);

            {
                let page = self.get_or_allocate_page(page_no);
                let page = self.get_or_allocate_page(page_no as u64);
                page[page_offset..page_offset + bytes_to_write]
                    .copy_from_slice(&data[buf_offset..buf_offset + bytes_to_write]);
            }
@@ -169,7 +164,7 @@ impl File for MemoryFile {
        }

        self.size
            .set(core::cmp::max(pos + buf_len, self.size.get()));
            .set(core::cmp::max(pos + buf_len as u64, self.size.get()));

        c.complete(buf_len as i32);
        Ok(c)
@@ -182,13 +177,13 @@ impl File for MemoryFile {
        Ok(c)
    }

    fn truncate(&self, len: usize, c: Completion) -> Result<Completion> {
    fn truncate(&self, len: u64, c: Completion) -> Result<Completion> {
        tracing::debug!("truncate(path={}): len={}", self.path, len);
        if len < self.size.get() {
            // Truncate pages
            unsafe {
                let pages = &mut *self.pages.get();
                pages.retain(|&k, _| k * PAGE_SIZE < len);
                pages.retain(|&k, _| k * PAGE_SIZE < len as usize);
            }
        }
        self.size.set(len);
@@ -196,14 +191,14 @@ impl File for MemoryFile {
        Ok(c)
    }

    fn pwritev(&self, pos: usize, buffers: Vec<Arc<Buffer>>, c: Completion) -> Result<Completion> {
    fn pwritev(&self, pos: u64, buffers: Vec<Arc<Buffer>>, c: Completion) -> Result<Completion> {
        tracing::debug!(
            "pwritev(path={}): pos={}, buffers={:?}",
            self.path,
            pos,
            buffers.iter().map(|x| x.len()).collect::<Vec<_>>()
        );
        let mut offset = pos;
        let mut offset = pos as usize;
        let mut total_written = 0;

        for buffer in buffers {
@@ -222,7 +217,7 @@ impl File for MemoryFile {
            let bytes_to_write = remaining.min(PAGE_SIZE - page_offset);

            {
                let page = self.get_or_allocate_page(page_no);
                let page = self.get_or_allocate_page(page_no as u64);
                page[page_offset..page_offset + bytes_to_write]
                    .copy_from_slice(&data[buf_offset..buf_offset + bytes_to_write]);
            }
@@ -235,23 +230,23 @@ impl File for MemoryFile {
        }
        c.complete(total_written as i32);
        self.size
            .set(core::cmp::max(pos + total_written, self.size.get()));
            .set(core::cmp::max(pos + total_written as u64, self.size.get()));
        Ok(c)
    }

    fn size(&self) -> Result<u64> {
        tracing::debug!("size(path={}): {}", self.path, self.size.get());
        Ok(self.size.get() as u64)
        Ok(self.size.get())
    }
}

impl MemoryFile {
    #[allow(clippy::mut_from_ref)]
    fn get_or_allocate_page(&self, page_no: usize) -> &mut MemPage {
    fn get_or_allocate_page(&self, page_no: u64) -> &mut MemPage {
        unsafe {
            let pages = &mut *self.pages.get();
            pages
                .entry(page_no)
                .entry(page_no as usize)
                .or_insert_with(|| Box::new([0; PAGE_SIZE]))
        }
    }
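`pwrite`/`pwritev` split each write across MemoryFile's fixed-size pages with the usual div/mod arithmetic. A self-contained sketch (the PAGE_SIZE value is assumed):

```rust
const PAGE_SIZE: usize = 4096; // assumed; the real constant lives in memory.rs

fn main() {
    // A 100-byte write starting 10 bytes before a page boundary
    // touches two pages.
    let mut offset = PAGE_SIZE - 10;
    let mut remaining = 100;
    let mut chunks = Vec::new();
    while remaining > 0 {
        let page_no = offset / PAGE_SIZE;
        let page_offset = offset % PAGE_SIZE;
        let n = remaining.min(PAGE_SIZE - page_offset);
        chunks.push((page_no, page_offset, n));
        offset += n;
        remaining -= n;
    }
    assert_eq!(chunks, vec![(0, PAGE_SIZE - 10, 10), (1, 0, 90)]);
}
```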
@@ -12,10 +12,10 @@ use std::{fmt::Debug, pin::Pin};
pub trait File: Send + Sync {
    fn lock_file(&self, exclusive: bool) -> Result<()>;
    fn unlock_file(&self) -> Result<()>;
    fn pread(&self, pos: usize, c: Completion) -> Result<Completion>;
    fn pwrite(&self, pos: usize, buffer: Arc<Buffer>, c: Completion) -> Result<Completion>;
    fn pread(&self, pos: u64, c: Completion) -> Result<Completion>;
    fn pwrite(&self, pos: u64, buffer: Arc<Buffer>, c: Completion) -> Result<Completion>;
    fn sync(&self, c: Completion) -> Result<Completion>;
    fn pwritev(&self, pos: usize, buffers: Vec<Arc<Buffer>>, c: Completion) -> Result<Completion> {
    fn pwritev(&self, pos: u64, buffers: Vec<Arc<Buffer>>, c: Completion) -> Result<Completion> {
        use std::sync::atomic::{AtomicUsize, Ordering};
        if buffers.is_empty() {
            c.complete(0);
@@ -56,12 +56,12 @@ pub trait File: Send + Sync {
                c.abort();
                return Err(e);
            }
            pos += len;
            pos += len as u64;
        }
        Ok(c)
    }
    fn size(&self) -> Result<u64>;
    fn truncate(&self, len: usize, c: Completion) -> Result<Completion>;
    fn truncate(&self, len: u64, c: Completion) -> Result<Completion>;
}

#[derive(Debug, Copy, Clone, PartialEq)]
@@ -87,7 +87,9 @@ pub trait IO: Clock + Send + Sync {
    // remove_file is used in the sync-engine
    fn remove_file(&self, path: &str) -> Result<()>;

    fn run_once(&self) -> Result<()>;
    fn run_once(&self) -> Result<()> {
        Ok(())
    }

    fn wait_for_completion(&self, c: Completion) -> Result<()> {
        while !c.finished() {
@@ -214,6 +216,10 @@ impl Completion {
        self.inner.result.get().is_some_and(|val| val.is_some())
    }

    pub fn get_error(&self) -> Option<CompletionError> {
        self.inner.result.get().and_then(|res| *res)
    }

    /// Checks if the Completion completed or errored
    pub fn finished(&self) -> bool {
        self.inner.result.get().is_some()
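The common thread of this patch is moving file offsets and lengths in the `File` trait from `usize` to `u64`. On 64-bit targets the two are the same width, but `usize` is only 32 bits on 32-bit targets and cannot address files past 4 GiB; `u64` makes the trait portable. A minimal illustration:

```rust
fn main() {
    // A database file can exceed 4 GiB; on a 32-bit target a usize offset
    // cannot represent such positions, while u64 is the same everywhere.
    let five_gib: u64 = 5 * 1024 * 1024 * 1024;
    #[cfg(target_pointer_width = "32")]
    assert!(five_gib > usize::MAX as u64); // not representable as usize here
    assert_eq!(five_gib, 5_368_709_120);
}
```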
@@ -15,8 +15,6 @@ use std::{io::ErrorKind, sync::Arc};
use tracing::debug;
use tracing::{instrument, trace, Level};

/// UnixIO lives longer than any of the files it creates, so it is
/// safe to store references to its internals in the UnixFiles
pub struct UnixIO {}

unsafe impl Send for UnixIO {}
@@ -127,24 +125,6 @@ impl IO for UnixIO {
    }
}

// enum CompletionCallback {
//     Read(Arc<Mutex<std::fs::File>>, Completion, usize),
//     Write(
//         Arc<Mutex<std::fs::File>>,
//         Completion,
//         Arc<crate::Buffer>,
//         usize,
//     ),
//     Writev(
//         Arc<Mutex<std::fs::File>>,
//         Completion,
//         Vec<Arc<crate::Buffer>>,
//         usize, // absolute file offset
//         usize, // buf index
//         usize, // intra-buf offset
//     ),
// }

pub struct UnixFile {
    file: Arc<Mutex<std::fs::File>>,
}
@@ -192,7 +172,7 @@ impl File for UnixFile {
    }

    #[instrument(err, skip_all, level = Level::TRACE)]
    fn pread(&self, pos: usize, c: Completion) -> Result<Completion> {
    fn pread(&self, pos: u64, c: Completion) -> Result<Completion> {
        let file = self.file.lock();
        let result = unsafe {
            let r = c.as_read();
@@ -217,7 +197,7 @@ impl File for UnixFile {
    }

    #[instrument(err, skip_all, level = Level::TRACE)]
    fn pwrite(&self, pos: usize, buffer: Arc<crate::Buffer>, c: Completion) -> Result<Completion> {
    fn pwrite(&self, pos: u64, buffer: Arc<crate::Buffer>, c: Completion) -> Result<Completion> {
        let file = self.file.lock();
        let result = unsafe {
            libc::pwrite(
@@ -241,7 +221,7 @@ impl File for UnixFile {
    #[instrument(err, skip_all, level = Level::TRACE)]
    fn pwritev(
        &self,
        pos: usize,
        pos: u64,
        buffers: Vec<Arc<crate::Buffer>>,
        c: Completion,
    ) -> Result<Completion> {
@@ -251,7 +231,7 @@ impl File for UnixFile {
        }
        let file = self.file.lock();

        match try_pwritev_raw(file.as_raw_fd(), pos as u64, &buffers, 0, 0) {
        match try_pwritev_raw(file.as_raw_fd(), pos, &buffers, 0, 0) {
            Ok(written) => {
                trace!("pwritev wrote {written}");
                c.complete(written as i32);
@@ -268,24 +248,21 @@ impl File for UnixFile {
        let file = self.file.lock();

        let result = unsafe {
            #[cfg(not(any(target_os = "macos", target_os = "ios")))]
            {
                libc::fsync(file.as_raw_fd())
            }

            #[cfg(any(target_os = "macos", target_os = "ios"))]
            {
                libc::fcntl(file.as_raw_fd(), libc::F_FULLFSYNC)
            }
        };

        if result == -1 {
            let e = std::io::Error::last_os_error();
            Err(e.into())
        } else {
            #[cfg(not(any(target_os = "macos", target_os = "ios")))]
            trace!("fsync");

@@ -304,9 +281,9 @@ impl File for UnixFile {
    }

    #[instrument(err, skip_all, level = Level::INFO)]
    fn truncate(&self, len: usize, c: Completion) -> Result<Completion> {
    fn truncate(&self, len: u64, c: Completion) -> Result<Completion> {
        let file = self.file.lock();
        let result = file.set_len(len as u64);
        let result = file.set_len(len);
        match result {
            Ok(()) => {
                trace!("file truncated to len=({})", len);
@@ -81,8 +81,6 @@ impl VfsMod {
    }
}

// #Safety:
// the callback wrapper in the extension library is FnOnce, so we know
/// # Safety
/// the callback wrapper in the extension library is FnOnce, so we know
/// that the into_raw/from_raw contract will hold
@@ -121,7 +119,7 @@ impl File for VfsFileImpl {
        Ok(())
    }

    fn pread(&self, pos: usize, c: Completion) -> Result<Completion> {
    fn pread(&self, pos: u64, c: Completion) -> Result<Completion> {
        if self.vfs.is_null() {
            c.complete(-1);
            return Err(LimboError::ExtensionError("VFS is null".to_string()));
@@ -145,7 +143,7 @@ impl File for VfsFileImpl {
        Ok(c)
    }

    fn pwrite(&self, pos: usize, buffer: Arc<Buffer>, c: Completion) -> Result<Completion> {
    fn pwrite(&self, pos: u64, buffer: Arc<Buffer>, c: Completion) -> Result<Completion> {
        if self.vfs.is_null() {
            c.complete(-1);
            return Err(LimboError::ExtensionError("VFS is null".to_string()));
@@ -192,7 +190,7 @@ impl File for VfsFileImpl {
        }
    }

    fn truncate(&self, len: usize, c: Completion) -> Result<Completion> {
    fn truncate(&self, len: u64, c: Completion) -> Result<Completion> {
        if self.vfs.is_null() {
            c.complete(-1);
            return Err(LimboError::ExtensionError("VFS is null".to_string()));
core/io/windows.rs (new file, 115 lines)
@@ -0,0 +1,115 @@
use crate::{Clock, Completion, File, Instant, LimboError, OpenFlags, Result, IO};
use parking_lot::RwLock;
use std::io::{Read, Seek, Write};
use std::sync::Arc;
use tracing::{debug, instrument, trace, Level};
pub struct WindowsIO {}

impl WindowsIO {
    pub fn new() -> Result<Self> {
        debug!("Using IO backend 'syscall'");
        Ok(Self {})
    }
}

impl IO for WindowsIO {
    #[instrument(err, skip_all, level = Level::TRACE)]
    fn open_file(&self, path: &str, flags: OpenFlags, direct: bool) -> Result<Arc<dyn File>> {
        trace!("open_file(path = {})", path);
        let mut file = std::fs::File::options();
        file.read(true);

        if !flags.contains(OpenFlags::ReadOnly) {
            file.write(true);
            file.create(flags.contains(OpenFlags::Create));
        }

        let file = file.open(path)?;
        Ok(Arc::new(WindowsFile {
            file: RwLock::new(file),
        }))
    }

    #[instrument(err, skip_all, level = Level::TRACE)]
    fn remove_file(&self, path: &str) -> Result<()> {
        trace!("remove_file(path = {})", path);
        Ok(std::fs::remove_file(path)?)
    }

    #[instrument(err, skip_all, level = Level::TRACE)]
    fn run_once(&self) -> Result<()> {
        Ok(())
    }
}

impl Clock for WindowsIO {
    fn now(&self) -> Instant {
        let now = chrono::Local::now();
        Instant {
            secs: now.timestamp(),
            micros: now.timestamp_subsec_micros(),
        }
    }
}

pub struct WindowsFile {
    file: RwLock<std::fs::File>,
}

impl File for WindowsFile {
    #[instrument(err, skip_all, level = Level::TRACE)]
    fn lock_file(&self, exclusive: bool) -> Result<()> {
        unimplemented!()
    }

    #[instrument(err, skip_all, level = Level::TRACE)]
    fn unlock_file(&self) -> Result<()> {
        unimplemented!()
    }

    #[instrument(skip(self, c), level = Level::TRACE)]
    fn pread(&self, pos: u64, c: Completion) -> Result<Completion> {
        let mut file = self.file.write();
        file.seek(std::io::SeekFrom::Start(pos))?;
        let nr = {
            let r = c.as_read();
            let buf = r.buf();
            let buf = buf.as_mut_slice();
            file.read_exact(buf)?;
            buf.len() as i32
        };
        c.complete(nr);
        Ok(c)
    }

    #[instrument(skip(self, c, buffer), level = Level::TRACE)]
    fn pwrite(&self, pos: u64, buffer: Arc<crate::Buffer>, c: Completion) -> Result<Completion> {
        let mut file = self.file.write();
        file.seek(std::io::SeekFrom::Start(pos))?;
        let buf = buffer.as_slice();
        file.write_all(buf)?;
        c.complete(buffer.len() as i32);
        Ok(c)
    }

    #[instrument(err, skip_all, level = Level::TRACE)]
    fn sync(&self, c: Completion) -> Result<Completion> {
        let file = self.file.write();
        file.sync_all()?;
        c.complete(0);
        Ok(c)
    }

    #[instrument(err, skip_all, level = Level::TRACE)]
    fn truncate(&self, len: u64, c: Completion) -> Result<Completion> {
        let file = self.file.write();
        file.set_len(len)?;
        c.complete(0);
        Ok(c)
    }

    fn size(&self) -> Result<u64> {
        let file = self.file.read();
        Ok(file.metadata().unwrap().len())
    }
}
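WindowsFile emulates positional I/O by seeking then reading under the RwLock's write guard, even for reads. std also exposes positional I/O on Windows; a hypothetical alternative (not what this patch does) would let reads take the shared guard instead:

```rust
// Hypothetical alternative using std's positional I/O: `seek_read` takes
// `&File`, so a read would not need the RwLock write guard. Per the std docs
// the Windows file cursor is still moved by this call, so mixing it with
// cursor-based reads on the same handle needs care.
#[cfg(windows)]
fn pread_at(file: &std::fs::File, pos: u64, buf: &mut [u8]) -> std::io::Result<usize> {
    use std::os::windows::fs::FileExt;
    file.seek_read(buf, pos)
}
```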
@@ -1277,6 +1277,12 @@ impl Connection {
            std::fs::set_permissions(&opts.path, perms.permissions())?;
        }
        let conn = db.connect()?;
        if let Some(cipher) = opts.cipher {
            let _ = conn.pragma_update("cipher", format!("'{cipher}'"));
        }
        if let Some(hexkey) = opts.hexkey {
            let _ = conn.pragma_update("hexkey", format!("'{hexkey}'"));
        }
        Ok((io, conn))
    }
@@ -1047,8 +1047,8 @@ impl Column {
    }

// TODO: This might replace some of util::columns_from_create_table_body
impl From<ColumnDefinition> for Column {
    fn from(value: ColumnDefinition) -> Self {
impl From<&ColumnDefinition> for Column {
    fn from(value: &ColumnDefinition) -> Self {
        let name = value.col_name.as_str();

        let mut default = None;
@@ -1057,13 +1057,13 @@ impl From<ColumnDefinition> for Column {
        let mut unique = false;
        let mut collation = None;

        for ast::NamedColumnConstraint { constraint, .. } in value.constraints {
        for ast::NamedColumnConstraint { constraint, .. } in &value.constraints {
            match constraint {
                ast::ColumnConstraint::PrimaryKey { .. } => primary_key = true,
                ast::ColumnConstraint::NotNull { .. } => notnull = true,
                ast::ColumnConstraint::Unique(..) => unique = true,
                ast::ColumnConstraint::Default(expr) => {
                    default.replace(expr);
                    default.replace(expr.clone());
                }
                ast::ColumnConstraint::Collate { collation_name } => {
                    collation.replace(
@@ -1082,11 +1082,14 @@ impl From<ColumnDefinition> for Column {

        let ty_str = value
            .col_type
            .as_ref()
            .map(|t| t.name.to_string())
            .unwrap_or_default();

        let hidden = ty_str.contains("HIDDEN");

        Column {
            name: Some(name.to_string()),
            name: Some(normalize_ident(name)),
            ty,
            default,
            notnull,
@@ -1095,7 +1098,7 @@ impl From<ColumnDefinition> for Column {
            is_rowid_alias: primary_key && matches!(ty, Type::Integer),
            unique,
            collation,
            hidden: false,
            hidden,
        }
    }
}
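Switching the impl from `From<ColumnDefinition>` to `From<&ColumnDefinition>` means the conversion borrows instead of consuming, at the cost of one explicit `expr.clone()`. A minimal illustration of the trade-off (simplified types, not the real AST):

```rust
#[derive(Clone)]
struct ColumnDefinition {
    col_name: String,
}

struct Column {
    name: Option<String>,
}

impl From<&ColumnDefinition> for Column {
    fn from(value: &ColumnDefinition) -> Self {
        Column {
            // Borrowing impl: clone only the fields we keep.
            name: Some(value.col_name.clone()),
        }
    }
}

fn main() {
    let def = ColumnDefinition { col_name: "id".into() };
    let col = Column::from(&def);
    // `def` is still usable here; a consuming `From<ColumnDefinition>`
    // would have moved it.
    assert_eq!(def.col_name, "id");
    assert_eq!(col.name.as_deref(), Some("id"));
}
```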
@@ -91,7 +91,9 @@ impl DatabaseStorage for DatabaseFile {
        if !(512..=65536).contains(&size) || size & (size - 1) != 0 {
            return Err(LimboError::NotADB);
        }
        let pos = (page_idx - 1) * size;
        let Some(pos) = (page_idx as u64 - 1).checked_mul(size as u64) else {
            return Err(LimboError::IntegerOverflow);
        };

        if let Some(ctx) = io_ctx.encryption_context() {
            let encryption_ctx = ctx.clone();
@@ -145,7 +147,9 @@ impl DatabaseStorage for DatabaseFile {
        assert!(buffer_size >= 512);
        assert!(buffer_size <= 65536);
        assert_eq!(buffer_size & (buffer_size - 1), 0);
        let pos = (page_idx - 1) * buffer_size;
        let Some(pos) = (page_idx as u64 - 1).checked_mul(buffer_size as u64) else {
            return Err(LimboError::IntegerOverflow);
        };
        let buffer = {
            if let Some(ctx) = io_ctx.encryption_context() {
                encrypt_buffer(page_idx, buffer, ctx)
@@ -169,7 +173,9 @@ impl DatabaseStorage for DatabaseFile {
        assert!(page_size <= 65536);
        assert_eq!(page_size & (page_size - 1), 0);

        let pos = (first_page_idx - 1) * page_size;
        let Some(pos) = (first_page_idx as u64 - 1).checked_mul(page_size as u64) else {
            return Err(LimboError::IntegerOverflow);
        };
        let buffers = {
            if let Some(ctx) = io_ctx.encryption_context() {
                buffers
@@ -198,7 +204,7 @@ impl DatabaseStorage for DatabaseFile {

    #[instrument(skip_all, level = Level::INFO)]
    fn truncate(&self, len: usize, c: Completion) -> Result<Completion> {
        let c = self.file.truncate(len, c)?;
        let c = self.file.truncate(len as u64, c)?;
        Ok(c)
    }
}
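The `checked_mul` pattern above turns a silent wraparound (release builds) or panic (debug builds) on a huge page index into a recoverable error. A small demonstration:

```rust
fn page_offset(page_idx: u64, page_size: u64) -> Option<u64> {
    // Mirrors the hunk: byte offset of a 1-based page index, refusing overflow.
    (page_idx - 1).checked_mul(page_size)
}

fn main() {
    assert_eq!(page_offset(2, 4096), Some(4096));
    // A corrupt or hostile page index would otherwise wrap around:
    assert_eq!(page_offset(u64::MAX, 4096), None);
}
```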
@@ -1,14 +1,13 @@
#![allow(unused_variables, dead_code)]
use crate::{LimboError, Result};
use aegis::aegis256::Aegis256;
use aes_gcm::{
    aead::{Aead, AeadCore, KeyInit, OsRng},
    Aes256Gcm, Key, Nonce,
};
use aes_gcm::aead::{AeadCore, OsRng};
use std::ops::Deref;
use turso_macros::match_ignore_ascii_case;

pub const ENCRYPTED_PAGE_SIZE: usize = 4096;
// AEGIS-256 supports both 16 and 32 byte tags, we use the 16 byte variant, it is faster
// and provides sufficient security for our use case.
const AEGIS_TAG_SIZE: usize = 16;
const AES256GCM_TAG_SIZE: usize = 16;

#[repr(transparent)]
#[derive(Clone)]
@@ -74,10 +73,25 @@ impl Drop for EncryptionKey {
    }
}

pub trait AeadCipher {
    fn encrypt(&self, plaintext: &[u8], ad: &[u8]) -> Result<(Vec<u8>, Vec<u8>)>;
    fn decrypt(&self, ciphertext: &[u8], nonce: &[u8], ad: &[u8]) -> Result<Vec<u8>>;

    fn encrypt_detached(&self, plaintext: &[u8], ad: &[u8]) -> Result<(Vec<u8>, Vec<u8>, Vec<u8>)>;

    fn decrypt_detached(
        &self,
        ciphertext: &[u8],
        nonce: &[u8],
        tag: &[u8],
        ad: &[u8],
    ) -> Result<Vec<u8>>;
}
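The trait distinguishes attached output (ciphertext with the tag appended) from detached output (ciphertext, tag, and nonce returned separately). The framing can be sketched without any real cipher:

```rust
// Framing sketch only; no cryptography, just the attached/detached layouts.
const TAG_SIZE: usize = 16;

fn attach(ciphertext: Vec<u8>, tag: [u8; TAG_SIZE]) -> Vec<u8> {
    let mut out = ciphertext;
    out.extend_from_slice(&tag); // attached: ct || tag
    out
}

fn detach(blob: &[u8]) -> Option<(&[u8], &[u8])> {
    if blob.len() < TAG_SIZE {
        return None; // the "Ciphertext too short" case in `decrypt`
    }
    Some(blob.split_at(blob.len() - TAG_SIZE)) // (ct, tag)
}

fn main() {
    let blob = attach(vec![1, 2, 3], [0u8; TAG_SIZE]);
    let (ct, tag) = detach(&blob).unwrap();
    assert_eq!((ct, tag.len()), (&[1u8, 2, 3][..], TAG_SIZE));
}
```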
// wrapper struct for AEGIS-256 cipher, because the crate we use is a bit low-level and we add
// some nice abstractions here
// note, the AEGIS has many variants and support for hardware acceleration. Here we just use the
// vanilla version, which is still order of maginitudes faster than AES-GCM in software. Hardware
// vanilla version, which is still orders of magnitude faster than AES-GCM in software. Hardware
// based compilation is left for future work.
#[derive(Clone)]
pub struct Aegis256Cipher {
@@ -85,39 +99,154 @@ pub struct Aegis256Cipher {
}

impl Aegis256Cipher {
    // AEGIS-256 supports both 16 and 32 byte tags, we use the 16 byte variant, it is faster
    // and provides sufficient security for our use case.
    const TAG_SIZE: usize = 16;
    fn new(key: &EncryptionKey) -> Self {
        Self { key: key.clone() }
    }
}

    fn encrypt(&self, plaintext: &[u8], ad: &[u8]) -> Result<(Vec<u8>, [u8; 32])> {
impl AeadCipher for Aegis256Cipher {
    fn encrypt(&self, plaintext: &[u8], ad: &[u8]) -> Result<(Vec<u8>, Vec<u8>)> {
        let nonce = generate_secure_nonce();
        let (ciphertext, tag) =
            Aegis256::<16>::new(self.key.as_bytes(), &nonce).encrypt(plaintext, ad);
            Aegis256::<AEGIS_TAG_SIZE>::new(self.key.as_bytes(), &nonce).encrypt(plaintext, ad);

        let mut result = ciphertext;
        result.extend_from_slice(&tag);
        Ok((result, nonce))
        Ok((result, nonce.to_vec()))
    }

    fn decrypt(&self, ciphertext: &[u8], nonce: &[u8; 32], ad: &[u8]) -> Result<Vec<u8>> {
        if ciphertext.len() < Self::TAG_SIZE {
            return Err(LimboError::InternalError(
                "Ciphertext too short for AEGIS-256".into(),
            ));
    fn decrypt(&self, ciphertext: &[u8], nonce: &[u8], ad: &[u8]) -> Result<Vec<u8>> {
        if ciphertext.len() < AEGIS_TAG_SIZE {
            return Err(LimboError::InternalError("Ciphertext too short".into()));
        }
        let (ct, tag) = ciphertext.split_at(ciphertext.len() - Self::TAG_SIZE);
        let tag_array: [u8; 16] = tag
            .try_into()
            .map_err(|_| LimboError::InternalError("Invalid tag size for AEGIS-256".into()))?;
        let (ct, tag) = ciphertext.split_at(ciphertext.len() - AEGIS_TAG_SIZE);
        let tag_array: [u8; AEGIS_TAG_SIZE] = tag.try_into().map_err(|_| {
            LimboError::InternalError(format!("Invalid tag size for AEGIS-256 {AEGIS_TAG_SIZE}"))
        })?;

        let plaintext = Aegis256::<16>::new(self.key.as_bytes(), nonce)
        let nonce_array: [u8; 32] = nonce
            .try_into()
            .map_err(|_| LimboError::InternalError("Invalid nonce size for AEGIS-256".into()))?;

        Aegis256::<AEGIS_TAG_SIZE>::new(self.key.as_bytes(), &nonce_array)
            .decrypt(ct, &tag_array, ad)
            .map_err(|_| {
                LimboError::InternalError("AEGIS-256 decryption failed: invalid tag".into())
            })?;
        Ok(plaintext)
            .map_err(|_| LimboError::InternalError("AEGIS-256 decryption failed".into()))
    }

    fn encrypt_detached(&self, plaintext: &[u8], ad: &[u8]) -> Result<(Vec<u8>, Vec<u8>, Vec<u8>)> {
        let nonce = generate_secure_nonce();
        let (ciphertext, tag) =
            Aegis256::<AEGIS_TAG_SIZE>::new(self.key.as_bytes(), &nonce).encrypt(plaintext, ad);

        Ok((ciphertext, tag.to_vec(), nonce.to_vec()))
    }

    fn decrypt_detached(
        &self,
        ciphertext: &[u8],
        nonce: &[u8],
        tag: &[u8],
        ad: &[u8],
    ) -> Result<Vec<u8>> {
        let tag_array: [u8; AEGIS_TAG_SIZE] = tag.try_into().map_err(|_| {
            LimboError::InternalError(format!("Invalid tag size for AEGIS-256 {AEGIS_TAG_SIZE}"))
        })?;
        let nonce_array: [u8; 32] = nonce
            .try_into()
            .map_err(|_| LimboError::InternalError("Invalid nonce size for AEGIS-256".into()))?;

        Aegis256::<AEGIS_TAG_SIZE>::new(self.key.as_bytes(), &nonce_array)
            .decrypt(ciphertext, &tag_array, ad)
            .map_err(|_| LimboError::InternalError("AEGIS-256 decrypt_detached failed".into()))
    }
}
#[derive(Clone)]
pub struct Aes256GcmCipher {
    key: EncryptionKey,
}

impl Aes256GcmCipher {
    fn new(key: &EncryptionKey) -> Self {
        Self { key: key.clone() }
    }
}

impl AeadCipher for Aes256GcmCipher {
    fn encrypt(&self, plaintext: &[u8], _ad: &[u8]) -> Result<(Vec<u8>, Vec<u8>)> {
        use aes_gcm::aead::{AeadInPlace, KeyInit};
        use aes_gcm::Aes256Gcm;

        let cipher = Aes256Gcm::new_from_slice(self.key.as_bytes())
            .map_err(|_| LimboError::InternalError("Bad AES key".into()))?;
        let nonce = Aes256Gcm::generate_nonce(&mut rand::thread_rng());
        let mut buffer = plaintext.to_vec();

        let tag = cipher
            .encrypt_in_place_detached(&nonce, b"", &mut buffer)
            .map_err(|_| LimboError::InternalError("AES-GCM encrypt failed".into()))?;

        buffer.extend_from_slice(&tag[..AES256GCM_TAG_SIZE]);
        Ok((buffer, nonce.to_vec()))
    }

    fn decrypt(&self, ciphertext: &[u8], nonce: &[u8], ad: &[u8]) -> Result<Vec<u8>> {
        use aes_gcm::aead::{AeadInPlace, KeyInit};
        use aes_gcm::{Aes256Gcm, Nonce};

        if ciphertext.len() < AES256GCM_TAG_SIZE {
            return Err(LimboError::InternalError("Ciphertext too short".into()));
        }
        let (ct, tag) = ciphertext.split_at(ciphertext.len() - AES256GCM_TAG_SIZE);

        let cipher = Aes256Gcm::new_from_slice(self.key.as_bytes())
            .map_err(|_| LimboError::InternalError("Bad AES key".into()))?;
        let nonce = Nonce::from_slice(nonce);

        let mut buffer = ct.to_vec();
        cipher
            .decrypt_in_place_detached(nonce, ad, &mut buffer, tag.into())
            .map_err(|_| LimboError::InternalError("AES-GCM decrypt failed".into()))?;

        Ok(buffer)
    }

    fn encrypt_detached(&self, plaintext: &[u8], ad: &[u8]) -> Result<(Vec<u8>, Vec<u8>, Vec<u8>)> {
        use aes_gcm::aead::{AeadInPlace, KeyInit};
        use aes_gcm::Aes256Gcm;

        let cipher = Aes256Gcm::new_from_slice(self.key.as_bytes())
            .map_err(|_| LimboError::InternalError("Bad AES key".into()))?;
        let nonce = Aes256Gcm::generate_nonce(&mut rand::thread_rng());

        let mut buffer = plaintext.to_vec();
        let tag = cipher
            .encrypt_in_place_detached(&nonce, ad, &mut buffer)
            .map_err(|_| LimboError::InternalError("AES-GCM encrypt_detached failed".into()))?;

        Ok((buffer, nonce.to_vec(), tag.to_vec()))
    }

    fn decrypt_detached(
        &self,
        ciphertext: &[u8],
        nonce: &[u8],
        tag: &[u8],
        ad: &[u8],
    ) -> Result<Vec<u8>> {
        use aes_gcm::aead::{AeadInPlace, KeyInit};
        use aes_gcm::{Aes256Gcm, Nonce};

        let cipher = Aes256Gcm::new_from_slice(self.key.as_bytes())
            .map_err(|_| LimboError::InternalError("Bad AES key".into()))?;
        let nonce = Nonce::from_slice(nonce);

        let mut buffer = ciphertext.to_vec();
        cipher
            .decrypt_in_place_detached(nonce, ad, &mut buffer, tag.into())
            .map_err(|_| LimboError::InternalError("AES-GCM decrypt_detached failed".into()))?;

        Ok(buffer)
    }
}
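`Aes256GcmCipher` builds on the aes-gcm crate's in-place detached API. A self-contained round trip using the same calls (aes-gcm 0.10-style names; the fixed key is for brevity only, not how keys are derived here):

```rust
use aes_gcm::aead::{AeadCore, AeadInPlace, KeyInit, OsRng};
use aes_gcm::Aes256Gcm;

fn main() {
    let key = [0x42u8; 32];
    let cipher = Aes256Gcm::new_from_slice(&key).unwrap();
    let nonce = Aes256Gcm::generate_nonce(&mut OsRng); // 96-bit GCM nonce

    let mut buffer = b"page bytes".to_vec();
    // Detached mode leaves only ciphertext in `buffer` and returns the tag.
    let tag = cipher
        .encrypt_in_place_detached(&nonce, b"ad", &mut buffer)
        .unwrap();
    cipher
        .decrypt_in_place_detached(&nonce, b"ad", &mut buffer, &tag)
        .unwrap();
    assert_eq!(buffer, b"page bytes");
}
```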
@@ -180,8 +309,8 @@ impl CipherMode {
    /// Returns the authentication tag size for this cipher mode.
    pub fn tag_size(&self) -> usize {
        match self {
            CipherMode::Aes256Gcm => 16,
            CipherMode::Aegis256 => 16,
            CipherMode::Aes256Gcm => AES256GCM_TAG_SIZE,
            CipherMode::Aegis256 => AEGIS_TAG_SIZE,
        }
    }

@@ -193,8 +322,17 @@ impl CipherMode {

#[derive(Clone)]
pub enum Cipher {
    Aes256Gcm(Box<Aes256Gcm>),
    Aegis256(Box<Aegis256Cipher>),
    Aes256Gcm(Aes256GcmCipher),
    Aegis256(Aegis256Cipher),
}

impl Cipher {
    fn as_aead(&self) -> &dyn AeadCipher {
        match self {
            Cipher::Aes256Gcm(c) => c,
            Cipher::Aegis256(c) => c,
        }
    }
}
impl std::fmt::Debug for Cipher {
@@ -210,10 +348,11 @@ impl std::fmt::Debug for Cipher {
pub struct EncryptionContext {
    cipher_mode: CipherMode,
    cipher: Cipher,
    page_size: usize,
}

impl EncryptionContext {
    pub fn new(cipher_mode: CipherMode, key: &EncryptionKey) -> Result<Self> {
    pub fn new(cipher_mode: CipherMode, key: &EncryptionKey, page_size: usize) -> Result<Self> {
        let required_size = cipher_mode.required_key_size();
        if key.as_slice().len() != required_size {
            return Err(crate::LimboError::InvalidArgument(format!(
@@ -225,15 +364,13 @@ impl EncryptionContext {
        }

        let cipher = match cipher_mode {
            CipherMode::Aes256Gcm => {
                let cipher_key: &Key<Aes256Gcm> = key.as_ref().into();
                Cipher::Aes256Gcm(Box::new(Aes256Gcm::new(cipher_key)))
            }
            CipherMode::Aegis256 => Cipher::Aegis256(Box::new(Aegis256Cipher::new(key))),
            CipherMode::Aes256Gcm => Cipher::Aes256Gcm(Aes256GcmCipher::new(key)),
            CipherMode::Aegis256 => Cipher::Aegis256(Aegis256Cipher::new(key)),
        };
        Ok(Self {
            cipher_mode,
            cipher,
            page_size,
        })
    }

@@ -255,36 +392,38 @@ impl EncryptionContext {
        tracing::debug!("encrypting page {}", page_id);
        assert_eq!(
            page.len(),
            ENCRYPTED_PAGE_SIZE,
            "Page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes"
            self.page_size,
            "Page data must be exactly {} bytes",
            self.page_size
        );

        let metadata_size = self.cipher_mode.metadata_size();
        let reserved_bytes = &page[ENCRYPTED_PAGE_SIZE - metadata_size..];
        let reserved_bytes = &page[self.page_size - metadata_size..];
        let reserved_bytes_zeroed = reserved_bytes.iter().all(|&b| b == 0);
        assert!(
            reserved_bytes_zeroed,
            "last reserved bytes must be empty/zero, but found non-zero bytes"
        );

        let payload = &page[..ENCRYPTED_PAGE_SIZE - metadata_size];
        let payload = &page[..self.page_size - metadata_size];
        let (encrypted, nonce) = self.encrypt_raw(payload)?;

        let nonce_size = self.cipher_mode.nonce_size();
        assert_eq!(
            encrypted.len(),
            ENCRYPTED_PAGE_SIZE - nonce_size,
            self.page_size - nonce_size,
            "Encrypted page must be exactly {} bytes",
            ENCRYPTED_PAGE_SIZE - nonce_size
            self.page_size - nonce_size
        );

        let mut result = Vec::with_capacity(ENCRYPTED_PAGE_SIZE);
        let mut result = Vec::with_capacity(self.page_size);
        result.extend_from_slice(&encrypted);
        result.extend_from_slice(&nonce);
        assert_eq!(
            result.len(),
            ENCRYPTED_PAGE_SIZE,
            "Encrypted page must be exactly {ENCRYPTED_PAGE_SIZE} bytes"
            self.page_size,
            "Encrypted page must be exactly {} bytes",
            self.page_size
        );
        Ok(result)
    }
@@ -298,8 +437,9 @@ impl EncryptionContext {
        tracing::debug!("decrypting page {}", page_id);
        assert_eq!(
            encrypted_page.len(),
            ENCRYPTED_PAGE_SIZE,
            "Encrypted page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes"
            self.page_size,
            "Encrypted page data must be exactly {} bytes",
            self.page_size
        );

        let nonce_size = self.cipher_mode.nonce_size();
@@ -312,60 +452,40 @@ impl EncryptionContext {
        let metadata_size = self.cipher_mode.metadata_size();
        assert_eq!(
            decrypted_data.len(),
            ENCRYPTED_PAGE_SIZE - metadata_size,
            self.page_size - metadata_size,
            "Decrypted page data must be exactly {} bytes",
            ENCRYPTED_PAGE_SIZE - metadata_size
            self.page_size - metadata_size
        );

        let mut result = Vec::with_capacity(ENCRYPTED_PAGE_SIZE);
        let mut result = Vec::with_capacity(self.page_size);
        result.extend_from_slice(&decrypted_data);
        result.resize(ENCRYPTED_PAGE_SIZE, 0);
        result.resize(self.page_size, 0);
        assert_eq!(
            result.len(),
            ENCRYPTED_PAGE_SIZE,
            "Decrypted page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes"
            self.page_size,
            "Decrypted page data must be exactly {} bytes",
            self.page_size
        );
        Ok(result)
    }

    /// encrypts raw data using the configured cipher, returns ciphertext and nonce
    fn encrypt_raw(&self, plaintext: &[u8]) -> Result<(Vec<u8>, Vec<u8>)> {
        match &self.cipher {
            Cipher::Aes256Gcm(cipher) => {
                let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
                let ciphertext = cipher
                    .encrypt(&nonce, plaintext)
                    .map_err(|e| LimboError::InternalError(format!("Encryption failed: {e:?}")))?;
                Ok((ciphertext, nonce.to_vec()))
            }
            Cipher::Aegis256(cipher) => {
                let ad = b"";
                let (ciphertext, nonce) = cipher.encrypt(plaintext, ad)?;
                Ok((ciphertext, nonce.to_vec()))
            }
        }
        self.cipher.as_aead().encrypt(plaintext, b"")
    }

    fn decrypt_raw(&self, ciphertext: &[u8], nonce: &[u8]) -> Result<Vec<u8>> {
        match &self.cipher {
            Cipher::Aes256Gcm(cipher) => {
                let nonce = Nonce::from_slice(nonce);
                let plaintext = cipher.decrypt(nonce, ciphertext).map_err(|e| {
                    crate::LimboError::InternalError(format!("Decryption failed: {e:?}"))
                })?;
                Ok(plaintext)
            }
            Cipher::Aegis256(cipher) => {
                let nonce_array: [u8; 32] = nonce.try_into().map_err(|_| {
                    LimboError::InternalError(format!(
                        "Invalid nonce size for AEGIS-256: expected 32, got {}",
                        nonce.len()
                    ))
                })?;
                let ad = b"";
                cipher.decrypt(ciphertext, &nonce_array, ad)
            }
        }
        self.cipher.as_aead().decrypt(ciphertext, nonce, b"")
    }

    fn encrypt_raw_detached(&self, plaintext: &[u8]) -> Result<(Vec<u8>, Vec<u8>, Vec<u8>)> {
        self.cipher.as_aead().encrypt_detached(plaintext, b"")
    }

    fn decrypt_raw_detached(&self, ciphertext: &[u8], nonce: &[u8], tag: &[u8]) -> Result<Vec<u8>> {
        self.cipher
            .as_aead()
            .decrypt_detached(ciphertext, nonce, tag, b"")
    }
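After this refactor a page is laid out as ciphertext || nonce, with the nonce-plus-tag metadata carried in the page's reserved tail, so the plaintext payload shrinks by `metadata_size`. A size check under the sizes used above (32-byte AEGIS-256 nonce from `generate_secure_nonce`, 16-byte tag; the metadata composition is an assumption inferred from the asserts):

```rust
fn main() {
    // Assumed sizes: AEGIS-256 uses a 32-byte nonce and a 16-byte tag here.
    let page_size = 4096usize;
    let (nonce_size, tag_size) = (32usize, 16usize);
    let metadata_size = nonce_size + tag_size;

    // encrypt_page: payload || tag || nonce must fill the page exactly.
    let payload = page_size - metadata_size;
    assert_eq!(payload + tag_size + nonce_size, page_size);
    assert_eq!(payload, 4048);
}
```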
#[cfg(not(feature = "encryption"))]
|
||||
@@ -391,10 +511,12 @@ fn generate_secure_nonce() -> [u8; 32] {
nonce
}

#[cfg(feature = "encryption")]
#[cfg(test)]
mod tests {
use super::*;
use rand::Rng;
const DEFAULT_ENCRYPTED_PAGE_SIZE: usize = 4096;

fn generate_random_hex_key() -> String {
let mut rng = rand::thread_rng();
@@ -404,15 +526,14 @@ mod tests {
}

#[test]
#[cfg(feature = "encryption")]
fn test_aes_encrypt_decrypt_round_trip() {
let mut rng = rand::thread_rng();
let cipher_mode = CipherMode::Aes256Gcm;
let metadata_size = cipher_mode.metadata_size();
let data_size = ENCRYPTED_PAGE_SIZE - metadata_size;
let data_size = DEFAULT_ENCRYPTED_PAGE_SIZE - metadata_size;

let page_data = {
let mut page = vec![0u8; ENCRYPTED_PAGE_SIZE];
let mut page = vec![0u8; DEFAULT_ENCRYPTED_PAGE_SIZE];
page.iter_mut()
.take(data_size)
.for_each(|byte| *byte = rng.gen());
@@ -420,21 +541,21 @@ mod tests {
};

let key = EncryptionKey::from_hex_string(&generate_random_hex_key()).unwrap();
let ctx = EncryptionContext::new(CipherMode::Aes256Gcm, &key).unwrap();
let ctx = EncryptionContext::new(CipherMode::Aes256Gcm, &key, DEFAULT_ENCRYPTED_PAGE_SIZE)
.unwrap();

let page_id = 42;
let encrypted = ctx.encrypt_page(&page_data, page_id).unwrap();
assert_eq!(encrypted.len(), ENCRYPTED_PAGE_SIZE);
assert_eq!(encrypted.len(), DEFAULT_ENCRYPTED_PAGE_SIZE);
assert_ne!(&encrypted[..data_size], &page_data[..data_size]);
assert_ne!(&encrypted[..], &page_data[..]);

let decrypted = ctx.decrypt_page(&encrypted, page_id).unwrap();
assert_eq!(decrypted.len(), ENCRYPTED_PAGE_SIZE);
assert_eq!(decrypted.len(), DEFAULT_ENCRYPTED_PAGE_SIZE);
assert_eq!(decrypted, page_data);
}

#[test]
#[cfg(feature = "encryption")]
fn test_aegis256_cipher_wrapper() {
let key = EncryptionKey::from_hex_string(&generate_random_hex_key()).unwrap();
let cipher = Aegis256Cipher::new(&key);
@@ -451,10 +572,10 @@ mod tests {
}

#[test]
#[cfg(feature = "encryption")]
fn test_aegis256_raw_encryption() {
let key = EncryptionKey::from_hex_string(&generate_random_hex_key()).unwrap();
let ctx = EncryptionContext::new(CipherMode::Aegis256, &key).unwrap();
let ctx = EncryptionContext::new(CipherMode::Aegis256, &key, DEFAULT_ENCRYPTED_PAGE_SIZE)
.unwrap();

let plaintext = b"Hello, AEGIS-256!";
let (ciphertext, nonce) = ctx.encrypt_raw(plaintext).unwrap();
@@ -467,15 +588,14 @@ mod tests {
}

#[test]
#[cfg(feature = "encryption")]
fn test_aegis256_encrypt_decrypt_round_trip() {
let mut rng = rand::thread_rng();
let cipher_mode = CipherMode::Aegis256;
let metadata_size = cipher_mode.metadata_size();
let data_size = ENCRYPTED_PAGE_SIZE - metadata_size;
let data_size = DEFAULT_ENCRYPTED_PAGE_SIZE - metadata_size;

let page_data = {
let mut page = vec![0u8; ENCRYPTED_PAGE_SIZE];
let mut page = vec![0u8; DEFAULT_ENCRYPTED_PAGE_SIZE];
page.iter_mut()
.take(data_size)
.for_each(|byte| *byte = rng.gen());
@@ -483,15 +603,16 @@ mod tests {
};

let key = EncryptionKey::from_hex_string(&generate_random_hex_key()).unwrap();
let ctx = EncryptionContext::new(CipherMode::Aegis256, &key).unwrap();
let ctx = EncryptionContext::new(CipherMode::Aegis256, &key, DEFAULT_ENCRYPTED_PAGE_SIZE)
.unwrap();

let page_id = 42;
let encrypted = ctx.encrypt_page(&page_data, page_id).unwrap();
assert_eq!(encrypted.len(), ENCRYPTED_PAGE_SIZE);
assert_eq!(encrypted.len(), DEFAULT_ENCRYPTED_PAGE_SIZE);
assert_ne!(&encrypted[..data_size], &page_data[..data_size]);

let decrypted = ctx.decrypt_page(&encrypted, page_id).unwrap();
assert_eq!(decrypted.len(), ENCRYPTED_PAGE_SIZE);
assert_eq!(decrypted.len(), DEFAULT_ENCRYPTED_PAGE_SIZE);
assert_eq!(decrypted, page_data);
}
}

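Threading the page size through `EncryptionContext::new` means the round-trip invariants exercised above hold for any configured page size, not just the 4096-byte default. A sketch under the same types used in the tests (the 8192 size is an arbitrary illustration):

fn round_trip_8k(key_hex: &str) -> crate::Result<()> {
    let key = EncryptionKey::from_hex_string(key_hex)?;
    let ctx = EncryptionContext::new(CipherMode::Aegis256, &key, 8192)?;
    let page = vec![0u8; 8192];
    let encrypted = ctx.encrypt_page(&page, 1)?;
    assert_eq!(encrypted.len(), 8192); // ciphertext plus metadata still fits one page
    assert_eq!(ctx.decrypt_page(&encrypted, 1)?, page);
    Ok(())
}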
@@ -12,7 +12,7 @@ use super::pager::PageRef;
const DEFAULT_PAGE_CACHE_SIZE_IN_PAGES_MAKE_ME_SMALLER_ONCE_WAL_SPILL_IS_IMPLEMENTED: usize =
100000;

#[derive(Debug, Eq, Hash, PartialEq, Clone)]
#[derive(Debug, Eq, Hash, PartialEq, Clone, Copy)]
pub struct PageCacheKey {
pgno: usize,
}
@@ -47,14 +47,21 @@ struct HashMapNode {
value: NonNull<PageCacheEntry>,
}

#[derive(Debug, PartialEq)]
#[derive(Debug, Clone, PartialEq, thiserror::Error)]
pub enum CacheError {
#[error("{0}")]
InternalError(String),
#[error("page {pgno} is locked")]
Locked { pgno: usize },
#[error("page {pgno} is dirty")]
Dirty { pgno: usize },
#[error("page {pgno} is pinned")]
Pinned { pgno: usize },
#[error("cache active refs")]
ActiveRefs,
#[error("Page cache is full")]
Full,
#[error("key already exists")]
KeyExists,
}

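Deriving `thiserror::Error` here is what enables the `#[error(transparent)] CacheError(#[from] CacheError)` variant on `LimboError`: cache failures now convert with `?` or `.into()` instead of a hand-written match per variant. An illustrative call-site pattern (the function itself is hypothetical):

// `?` and .into() convert CacheError into LimboError via the #[from] impl.
fn evict(cache: &mut DumbLruPageCache, key: PageCacheKey) -> crate::Result<()> {
    match cache.delete(key) {
        // A pinned page cannot be evicted; report which page blocked us.
        Err(CacheError::Pinned { pgno }) => Err(crate::LimboError::InternalError(
            format!("cannot evict pinned page {pgno}"),
        )),
        Err(e) => Err(e.into()), // every other variant converts directly
        Ok(_) => Ok(()),
    }
}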
@@ -105,7 +112,7 @@ impl DumbLruPageCache {
trace!("insert(key={:?})", key);
// Check first if page already exists in cache
if !ignore_exists {
if let Some(existing_page_ref) = self.get(&key) {
if let Some(existing_page_ref) = self.get(&key)? {
assert!(
Arc::ptr_eq(&value, &existing_page_ref),
"Attempted to insert different page with same key: {key:?}"
@@ -115,7 +122,7 @@ impl DumbLruPageCache {
}
self.make_room_for(1)?;
let entry = Box::new(PageCacheEntry {
key: key.clone(),
key,
next: None,
prev: None,
page: value,
@@ -156,8 +163,21 @@ impl DumbLruPageCache {
ptr.copied()
}

pub fn get(&mut self, key: &PageCacheKey) -> Option<PageRef> {
self.peek(key, true)
pub fn get(&mut self, key: &PageCacheKey) -> Result<Option<PageRef>, CacheError> {
if let Some(page) = self.peek(key, true) {
// Because we can abort a read_page completion, this means a page can be in the cache but be unloaded and unlocked.
// However, if we do not evict that page from the page cache, we will return an unloaded page later which will trigger
// assertions later on. This is worsened by the fact that page cache is not per `Statement`, so you can abort a completion
// in one Statement, and trigger some error in the next one if we don't evict the page here.
if !page.is_loaded() && !page.is_locked() {
self.delete(*key)?;
Ok(None)
} else {
Ok(Some(page))
}
} else {
Ok(None)
}
}

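The signature change from `Option<PageRef>` to `Result<Option<PageRef>, CacheError>` forces callers to separate "cache miss" from "cache operation failed": the self-eviction above can itself fail, and that error must propagate. The resulting call-site shape, sketched:

// Illustrative caller: `?` surfaces eviction errors, the inner Option still
// models a plain miss.
fn lookup(cache: &mut DumbLruPageCache, key: PageCacheKey) -> Result<(), CacheError> {
    match cache.get(&key)? {
        // Hit: get() has already evicted unloaded-and-unlocked entries, so a
        // returned page is either loaded or mid-read (locked).
        Some(page) => debug_assert!(page.is_loaded() || page.is_locked()),
        // Miss, or an aborted-completion leftover that get() just evicted.
        None => {}
    }
    Ok(())
}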
/// Get page without promoting entry
@@ -309,7 +329,7 @@ impl DumbLruPageCache {
let entry = unsafe { current.as_ref() };
// Pick prev before modifying entry
current_opt = entry.prev;
match self.delete(entry.key.clone()) {
match self.delete(entry.key) {
Err(_) => {}
Ok(_) => need_to_evict -= 1,
}
@@ -396,7 +416,7 @@ impl DumbLruPageCache {
let mut current = head_ptr;
while let Some(node) = current {
unsafe {
this_keys.push(node.as_ref().key.clone());
this_keys.push(node.as_ref().key);
let node_ref = node.as_ref();
current = node_ref.next;
}
@@ -647,7 +667,7 @@ impl PageHashMap {
pub fn rehash(&self, new_capacity: usize) -> PageHashMap {
let mut new_hash_map = PageHashMap::new(new_capacity);
for node in self.iter() {
new_hash_map.insert(node.key.clone(), node.value);
new_hash_map.insert(node.key, node.value);
}
new_hash_map
}
@@ -698,7 +718,7 @@ mod tests {
fn insert_page(cache: &mut DumbLruPageCache, id: usize) -> PageCacheKey {
let key = create_key(id);
let page = page_with_content(id);
assert!(cache.insert(key.clone(), page).is_ok());
assert!(cache.insert(key, page).is_ok());
key
}

@@ -712,7 +732,7 @@ mod tests {
) -> (PageCacheKey, NonNull<PageCacheEntry>) {
let key = create_key(id);
let page = page_with_content(id);
assert!(cache.insert(key.clone(), page).is_ok());
assert!(cache.insert(key, page).is_ok());
let entry = cache.get_ptr(&key).expect("Entry should exist");
(key, entry)
}
@@ -727,7 +747,7 @@ mod tests {
assert!(cache.tail.borrow().is_some());
assert_eq!(*cache.head.borrow(), *cache.tail.borrow());

assert!(cache.delete(key1.clone()).is_ok());
assert!(cache.delete(key1).is_ok());

assert_eq!(
cache.len(),
@@ -759,7 +779,7 @@ mod tests {
"Initial head check"
);

assert!(cache.delete(key3.clone()).is_ok());
assert!(cache.delete(key3).is_ok());

assert_eq!(cache.len(), 2, "Length should be 2 after deleting head");
assert!(
@@ -803,7 +823,7 @@ mod tests {
"Initial tail check"
);

assert!(cache.delete(key1.clone()).is_ok()); // Delete tail
assert!(cache.delete(key1).is_ok()); // Delete tail

assert_eq!(cache.len(), 2, "Length should be 2 after deleting tail");
assert!(
@@ -854,7 +874,7 @@ mod tests {
let head_ptr_before = cache.head.borrow().unwrap();
let tail_ptr_before = cache.tail.borrow().unwrap();

assert!(cache.delete(key2.clone()).is_ok()); // Detach a middle element (key2)
assert!(cache.delete(key2).is_ok()); // Detach a middle element (key2)

assert_eq!(cache.len(), 3, "Length should be 3 after deleting middle");
assert!(
@@ -895,11 +915,11 @@ mod tests {
let mut cache = DumbLruPageCache::default();
let key1 = create_key(1);
let page1 = page_with_content(1);
assert!(cache.insert(key1.clone(), page1.clone()).is_ok());
assert!(cache.insert(key1, page1.clone()).is_ok());
assert!(page_has_content(&page1));
cache.verify_list_integrity();

let result = cache.delete(key1.clone());
let result = cache.delete(key1);
assert!(result.is_err());
assert_eq!(result.unwrap_err(), CacheError::ActiveRefs);
assert_eq!(cache.len(), 1);
@@ -918,10 +938,10 @@ mod tests {
let key1 = create_key(1);
let page1_v1 = page_with_content(1);
let page1_v2 = page_with_content(1);
assert!(cache.insert(key1.clone(), page1_v1.clone()).is_ok());
assert!(cache.insert(key1, page1_v1.clone()).is_ok());
assert_eq!(cache.len(), 1);
cache.verify_list_integrity();
let _ = cache.insert(key1.clone(), page1_v2.clone()); // Panic
let _ = cache.insert(key1, page1_v2.clone()); // Panic
}

#[test]
@@ -929,7 +949,7 @@ mod tests {
let mut cache = DumbLruPageCache::default();
let key_nonexist = create_key(99);

assert!(cache.delete(key_nonexist.clone()).is_ok()); // no-op
assert!(cache.delete(key_nonexist).is_ok()); // no-op
}

#[test]
@@ -937,8 +957,8 @@ mod tests {
let mut cache = DumbLruPageCache::new(1);
let key1 = insert_page(&mut cache, 1);
let key2 = insert_page(&mut cache, 2);
assert_eq!(cache.get(&key2).unwrap().get().id, 2);
assert!(cache.get(&key1).is_none());
assert_eq!(cache.get(&key2).unwrap().unwrap().get().id, 2);
assert!(cache.get(&key1).unwrap().is_none());
}

#[test]
@@ -1002,7 +1022,7 @@ mod tests {
fn test_detach_with_cleaning() {
let mut cache = DumbLruPageCache::default();
let (key, entry) = insert_and_get_entry(&mut cache, 1);
let page = cache.get(&key).expect("Page should exist");
let page = cache.get(&key).unwrap().expect("Page should exist");
assert!(page_has_content(&page));
drop(page);
assert!(cache.detach(entry, true).is_ok());
@@ -1034,8 +1054,8 @@ mod tests {
let (key1, _) = insert_and_get_entry(&mut cache, 1);
let (key2, entry2) = insert_and_get_entry(&mut cache, 2);
let (key3, _) = insert_and_get_entry(&mut cache, 3);
let head_key = unsafe { cache.head.borrow().unwrap().as_ref().key.clone() };
let tail_key = unsafe { cache.tail.borrow().unwrap().as_ref().key.clone() };
let head_key = unsafe { cache.head.borrow().unwrap().as_ref().key };
let tail_key = unsafe { cache.tail.borrow().unwrap().as_ref().key };
assert_eq!(head_key, key3, "Head should be key3");
assert_eq!(tail_key, key1, "Tail should be key1");
assert!(cache.detach(entry2, false).is_ok());
@@ -1044,12 +1064,12 @@ mod tests {
assert_eq!(head_entry.key, key3, "Head should still be key3");
assert_eq!(tail_entry.key, key1, "Tail should still be key1");
assert_eq!(
unsafe { head_entry.next.unwrap().as_ref().key.clone() },
unsafe { head_entry.next.unwrap().as_ref().key },
key1,
"Head's next should point to tail after middle element detached"
);
assert_eq!(
unsafe { tail_entry.prev.unwrap().as_ref().key.clone() },
unsafe { tail_entry.prev.unwrap().as_ref().key },
key3,
"Tail's prev should point to head after middle element detached"
);
@@ -1085,7 +1105,7 @@ mod tests {
continue; // skip duplicate page ids
}
tracing::debug!("inserting page {:?}", key);
match cache.insert(key.clone(), page.clone()) {
match cache.insert(key, page.clone()) {
Err(CacheError::Full | CacheError::ActiveRefs) => {} // Ignore
Err(err) => {
// Any other error should fail the test
@@ -1106,7 +1126,7 @@ mod tests {
PageCacheKey::new(id_page as usize)
} else {
let i = rng.next_u64() as usize % lru.len();
let key: PageCacheKey = lru.iter().nth(i).unwrap().0.clone();
let key: PageCacheKey = *lru.iter().nth(i).unwrap().0;
key
};
tracing::debug!("removing page {:?}", key);
@@ -1133,7 +1153,7 @@ mod tests {
let this_keys = cache.keys();
let mut lru_keys = Vec::new();
for (lru_key, _) in lru {
lru_keys.push(lru_key.clone());
lru_keys.push(*lru_key);
}
if this_keys != lru_keys {
cache.print();
@@ -1149,8 +1169,8 @@ mod tests {
let mut cache = DumbLruPageCache::default();
let key1 = insert_page(&mut cache, 1);
let key2 = insert_page(&mut cache, 2);
assert_eq!(cache.get(&key1).unwrap().get().id, 1);
assert_eq!(cache.get(&key2).unwrap().get().id, 2);
assert_eq!(cache.get(&key1).unwrap().unwrap().get().id, 1);
assert_eq!(cache.get(&key2).unwrap().unwrap().get().id, 2);
}

#[test]
@@ -1159,17 +1179,17 @@ mod tests {
let key1 = insert_page(&mut cache, 1);
let key2 = insert_page(&mut cache, 2);
let key3 = insert_page(&mut cache, 3);
assert!(cache.get(&key1).is_none());
assert_eq!(cache.get(&key2).unwrap().get().id, 2);
assert_eq!(cache.get(&key3).unwrap().get().id, 3);
assert!(cache.get(&key1).unwrap().is_none());
assert_eq!(cache.get(&key2).unwrap().unwrap().get().id, 2);
assert_eq!(cache.get(&key3).unwrap().unwrap().get().id, 3);
}

#[test]
fn test_page_cache_delete() {
let mut cache = DumbLruPageCache::default();
let key1 = insert_page(&mut cache, 1);
assert!(cache.delete(key1.clone()).is_ok());
assert!(cache.get(&key1).is_none());
assert!(cache.delete(key1).is_ok());
assert!(cache.get(&key1).unwrap().is_none());
}

#[test]
@@ -1178,8 +1198,8 @@ mod tests {
let key1 = insert_page(&mut cache, 1);
let key2 = insert_page(&mut cache, 2);
assert!(cache.clear().is_ok());
assert!(cache.get(&key1).is_none());
assert!(cache.get(&key2).is_none());
assert!(cache.get(&key1).unwrap().is_none());
assert!(cache.get(&key2).unwrap().is_none());
}

#[test]
@@ -1216,8 +1236,8 @@ mod tests {
assert_eq!(result, CacheResizeResult::Done);
assert_eq!(cache.len(), 2);
assert_eq!(cache.capacity, 5);
assert!(cache.get(&create_key(1)).is_some());
assert!(cache.get(&create_key(2)).is_some());
assert!(cache.get(&create_key(1)).unwrap().is_some());
assert!(cache.get(&create_key(2)).unwrap().is_some());
for i in 3..=5 {
let _ = insert_page(&mut cache, i);
}

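The mechanical `key.clone()` → `key` rewrites throughout the tests above all stem from the one-line change at the top of the file: `PageCacheKey` now derives `Copy`. Since the key wraps a single `usize`, copies are free; a standalone illustration:

// Standalone sketch: a Copy key passes by value without .clone() ceremony.
#[derive(Debug, Eq, Hash, PartialEq, Clone, Copy)]
struct Key { pgno: usize }

fn consume(k: Key) -> usize { k.pgno }

fn demo() {
    let k = Key { pgno: 7 };
    let _a = consume(k); // k is copied into the call...
    let _b = consume(k); // ...so it remains usable afterwards
}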
@@ -1123,7 +1123,7 @@ impl Pager {
tracing::trace!("read_page(page_idx = {})", page_idx);
let mut page_cache = self.page_cache.write();
let page_key = PageCacheKey::new(page_idx);
if let Some(page) = page_cache.get(&page_key) {
if let Some(page) = page_cache.get(&page_key)? {
tracing::trace!("read_page(page_idx = {}) = cached", page_idx);
return Ok((page.clone(), None));
}
@@ -1158,25 +1158,20 @@ impl Pager {
let page_key = PageCacheKey::new(page_idx);
match page_cache.insert(page_key, page.clone()) {
Ok(_) => {}
Err(CacheError::Full) => return Err(LimboError::CacheFull),
Err(CacheError::KeyExists) => {
unreachable!("Page should not exist in cache after get() miss")
}
Err(e) => {
return Err(LimboError::InternalError(format!(
"Failed to insert page into cache: {e:?}"
)))
}
Err(e) => return Err(e.into()),
}
Ok(())
}

// Get a page from the cache, if it exists.
pub fn cache_get(&self, page_idx: usize) -> Option<PageRef> {
pub fn cache_get(&self, page_idx: usize) -> Result<Option<PageRef>> {
tracing::trace!("read_page(page_idx = {})", page_idx);
let mut page_cache = self.page_cache.write();
let page_key = PageCacheKey::new(page_idx);
page_cache.get(&page_key)
Ok(page_cache.get(&page_key)?)
}

/// Get a page from cache only if it matches the target frame
@@ -1185,10 +1180,10 @@ impl Pager {
page_idx: usize,
target_frame: u64,
seq: u32,
) -> Option<PageRef> {
) -> Result<Option<PageRef>> {
let mut page_cache = self.page_cache.write();
let page_key = PageCacheKey::new(page_idx);
page_cache.get(&page_key).and_then(|page| {
let page = page_cache.get(&page_key)?.and_then(|page| {
if page.is_valid_for_checkpoint(target_frame, seq) {
tracing::trace!(
"cache_get_for_checkpoint: page {} frame {} is valid",
@@ -1207,7 +1202,8 @@ impl Pager {
);
None
}
})
});
Ok(page)
}

/// Changes the size of the page cache.
@@ -1261,7 +1257,7 @@ impl Pager {
let page = {
let mut cache = self.page_cache.write();
let page_key = PageCacheKey::new(*page_id);
let page = cache.get(&page_key).expect("we somehow added a page to dirty list but we didn't mark it as dirty, causing cache to drop it.");
let page = cache.get(&page_key)?.expect("we somehow added a page to dirty list but we didn't mark it as dirty, causing cache to drop it.");
let page_type = page.get_contents().maybe_page_type();
trace!("cacheflush(page={}, page_type={:?}", page_id, page_type);
page
@@ -1344,7 +1340,7 @@ impl Pager {
let page = {
let mut cache = self.page_cache.write();
let page_key = PageCacheKey::new(page_id);
let page = cache.get(&page_key).expect(
let page = cache.get(&page_key)?.expect(
"dirty list contained a page that cache dropped (page={page_id})",
);
trace!(
@@ -1482,7 +1478,7 @@ impl Pager {
header.db_size as u64,
raw_page,
)?;
if let Some(page) = self.cache_get(header.page_number as usize) {
if let Some(page) = self.cache_get(header.page_number as usize)? {
let content = page.get_contents();
content.as_ptr().copy_from_slice(raw_page);
turso_assert!(
@@ -1505,7 +1501,7 @@ impl Pager {
for page_id in self.dirty_pages.borrow().iter() {
let page_key = PageCacheKey::new(*page_id);
let mut cache = self.page_cache.write();
let page = cache.get(&page_key).expect("we somehow added a page to dirty list but we didn't mark it as dirty, causing cache to drop it.");
let page = cache.get(&page_key)?.expect("we somehow added a page to dirty list but we didn't mark it as dirty, causing cache to drop it.");
page.clear_dirty();
}
self.dirty_pages.borrow_mut().clear();
@@ -1902,15 +1898,7 @@ impl Pager {
self.add_dirty(&page);
let page_key = PageCacheKey::new(page.get().id);
let mut cache = self.page_cache.write();
match cache.insert(page_key, page.clone()) {
Ok(_) => (),
Err(CacheError::Full) => return Err(LimboError::CacheFull),
Err(_) => {
return Err(LimboError::InternalError(
"Unknown error inserting page to cache".into(),
))
}
}
cache.insert(page_key, page.clone())?;
}
}

@@ -2081,15 +2069,7 @@ impl Pager {
{
// Run in separate block to avoid deadlock on page cache write lock
let mut cache = self.page_cache.write();
match cache.insert(page_key, page.clone()) {
Err(CacheError::Full) => return Err(LimboError::CacheFull),
Err(_) => {
return Err(LimboError::InternalError(
"Unknown error inserting page to cache".into(),
))
}
Ok(_) => {}
};
cache.insert(page_key, page.clone())?;
}
header.database_size = new_db_size.into();
*state = AllocatePageState::Start;
@@ -2186,7 +2166,8 @@ impl Pager {
}

pub fn set_encryption_context(&self, cipher_mode: CipherMode, key: &EncryptionKey) {
let encryption_ctx = EncryptionContext::new(cipher_mode, key).unwrap();
let page_size = self.page_size.get().unwrap().get() as usize;
let encryption_ctx = EncryptionContext::new(cipher_mode, key, page_size).unwrap();
{
let mut io_ctx = self.io_ctx.borrow_mut();
io_ctx.set_encryption(encryption_ctx);
@@ -2428,13 +2409,16 @@ mod tests {
std::thread::spawn(move || {
let mut cache = cache.write();
let page_key = PageCacheKey::new(1);
cache.insert(page_key, Arc::new(Page::new(1))).unwrap();
let page = Page::new(1);
// Set loaded so that we avoid eviction, as we evict the page from cache if it is not locked and not loaded
page.set_loaded();
cache.insert(page_key, Arc::new(page)).unwrap();
})
};
let _ = thread.join();
let mut cache = cache.write();
let page_key = PageCacheKey::new(1);
let page = cache.get(&page_key);
let page = cache.get(&page_key).unwrap();
assert_eq!(page.unwrap().get().id, 1);
}
}

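Both eight-line `match` ladders deleted above reduce to `cache.insert(page_key, page.clone())?` because `LimboError` now wraps `CacheError` transparently. The equivalence, sketched (the function itself is illustrative):

// insert() returns Result<(), CacheError>; `?` maps any failure into
// LimboError::CacheError via #[from], preserving the variant's message.
fn put(cache: &mut DumbLruPageCache, key: PageCacheKey, page: PageRef) -> crate::Result<()> {
    cache.insert(key, page)?;
    Ok(())
}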
@@ -1838,7 +1838,7 @@ pub fn read_entire_wal_dumb(file: &Arc<dyn File>) -> Result<Arc<UnsafeCell<WalFi
pub fn begin_read_wal_frame_raw(
buffer_pool: &Arc<BufferPool>,
io: &Arc<dyn File>,
offset: usize,
offset: u64,
complete: Box<ReadComplete>,
) -> Result<Completion> {
tracing::trace!("begin_read_wal_frame_raw(offset={})", offset);
@@ -1851,7 +1851,7 @@ pub fn begin_read_wal_frame_raw(

pub fn begin_read_wal_frame(
io: &Arc<dyn File>,
offset: usize,
offset: u64,
buffer_pool: Arc<BufferPool>,
complete: Box<ReadComplete>,
page_idx: usize,

@@ -1082,7 +1082,7 @@ impl Wal for WalFile {
});
begin_read_wal_frame(
&self.get_shared().file,
offset + WAL_FRAME_HEADER_SIZE,
offset + WAL_FRAME_HEADER_SIZE as u64,
buffer_pool,
complete,
page_idx,
@@ -1095,6 +1095,11 @@ impl Wal for WalFile {
tracing::debug!("read_frame({})", frame_id);
let offset = self.frame_offset(frame_id);
let (frame_ptr, frame_len) = (frame.as_mut_ptr(), frame.len());

let encryption_ctx = {
let io_ctx = self.io_ctx.borrow();
io_ctx.encryption_context().cloned()
};
let complete = Box::new(move |res: Result<(Arc<Buffer>, i32), CompletionError>| {
let Ok((buf, bytes_read)) = res else {
return;
@@ -1104,10 +1109,34 @@ impl Wal for WalFile {
bytes_read == buf_len as i32,
"read({bytes_read}) != expected({buf_len})"
);
let buf_ptr = buf.as_mut_ptr();
let buf_ptr = buf.as_ptr();
let frame_ref: &mut [u8] =
unsafe { std::slice::from_raw_parts_mut(frame_ptr, frame_len) };

// Copy the just-read WAL frame into the destination buffer
unsafe {
std::ptr::copy_nonoverlapping(buf_ptr, frame_ptr, frame_len);
}

// Now parse the header from the freshly-copied data
let (header, raw_page) = sqlite3_ondisk::parse_wal_frame_header(frame_ref);

if let Some(ctx) = encryption_ctx.clone() {
match ctx.decrypt_page(raw_page, header.page_number as usize) {
Ok(decrypted_data) => {
turso_assert!(
(frame_len - WAL_FRAME_HEADER_SIZE) == decrypted_data.len(),
"frame_len - header_size({}) != expected({})",
frame_len - WAL_FRAME_HEADER_SIZE,
decrypted_data.len()
);
frame_ref[WAL_FRAME_HEADER_SIZE..].copy_from_slice(&decrypted_data);
}
Err(_) => {
tracing::error!("Failed to decrypt page data for frame_id={frame_id}");
}
}
}
});
let c =
begin_read_wal_frame_raw(&self.buffer_pool, &self.get_shared().file, offset, complete)?;
@@ -1167,7 +1196,7 @@ impl Wal for WalFile {
});
let c = begin_read_wal_frame(
&self.get_shared().file,
offset + WAL_FRAME_HEADER_SIZE,
offset + WAL_FRAME_HEADER_SIZE as u64,
buffer_pool,
complete,
page_id as usize,
@@ -1431,14 +1460,14 @@ impl Wal for WalFile {
let mut next_frame_id = self.max_frame + 1;
// Build every frame in order, updating the rolling checksum
for (idx, page) in pages.iter().enumerate() {
let page_id = page.get().id as u64;
let page_id = page.get().id;
let plain = page.get_contents().as_ptr();

let data_to_write: std::borrow::Cow<[u8]> = {
let io_ctx = self.io_ctx.borrow();
let ectx = io_ctx.encryption_context();
if let Some(ctx) = ectx.as_ref() {
Cow::Owned(ctx.encrypt_page(plain, page_id as usize)?)
Cow::Owned(ctx.encrypt_page(plain, page_id)?)
} else {
Cow::Borrowed(plain)
}
@@ -1552,11 +1581,10 @@ impl WalFile {
self.get_shared().wal_header.lock().page_size
}

fn frame_offset(&self, frame_id: u64) -> usize {
fn frame_offset(&self, frame_id: u64) -> u64 {
assert!(frame_id > 0, "Frame ID must be 1-based");
let page_offset = (frame_id - 1) * (self.page_size() + WAL_FRAME_HEADER_SIZE as u32) as u64;
let offset = WAL_HEADER_SIZE as u64 + page_offset;
offset as usize
WAL_HEADER_SIZE as u64 + page_offset
}

#[allow(clippy::mut_from_ref)]
@@ -1748,7 +1776,7 @@ impl WalFile {

// Try cache first, if enabled
if let Some(cached_page) =
pager.cache_get_for_checkpoint(page_id as usize, target_frame, seq)
pager.cache_get_for_checkpoint(page_id as usize, target_frame, seq)?
{
let contents = cached_page.get_contents();
let buffer = contents.buffer.clone();
@@ -1805,7 +1833,7 @@ impl WalFile {
self.ongoing_checkpoint.pages_to_checkpoint.iter()
{
if *cached {
let page = pager.cache_get((*page_id) as usize);
let page = pager.cache_get((*page_id) as usize)?;
turso_assert!(
page.is_some(),
"page should still exist in the page cache"
@@ -2102,7 +2130,7 @@ impl WalFile {
// schedule read of the page payload
let c = begin_read_wal_frame(
&self.get_shared().file,
offset + WAL_FRAME_HEADER_SIZE,
offset + WAL_FRAME_HEADER_SIZE as u64,
self.buffer_pool.clone(),
complete,
page_id,
@@ -2288,7 +2316,7 @@ pub mod test {
let done = Rc::new(Cell::new(false));
let _done = done.clone();
let _ = file.file.truncate(
WAL_HEADER_SIZE,
WAL_HEADER_SIZE as u64,
Completion::new_trunc(move |_| {
let done = _done.clone();
done.set(true);

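Widening the WAL offsets from `usize` to `u64` removes a truncation hazard on 32-bit targets, where a WAL past 4 GiB would wrap. The offset arithmetic itself is easy to verify by hand; a standalone sketch using SQLite's WAL layout constants (32-byte file header, 24-byte frame header):

const WAL_HEADER_SIZE: u64 = 32;
const WAL_FRAME_HEADER_SIZE: u64 = 24;

// Mirrors the 1-based frame_offset() above: frame 1 starts right after the
// WAL file header, and each frame is one frame header plus one page.
fn frame_offset(frame_id: u64, page_size: u64) -> u64 {
    assert!(frame_id > 0, "Frame ID must be 1-based");
    WAL_HEADER_SIZE + (frame_id - 1) * (WAL_FRAME_HEADER_SIZE + page_size)
}

fn main() {
    assert_eq!(frame_offset(1, 4096), 32);
    assert_eq!(frame_offset(2, 4096), 32 + 24 + 4096);
}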
@@ -125,27 +125,161 @@ pub fn handle_distinct(program: &mut ProgramBuilder, agg: &Aggregate, agg_arg_re
});
}

/// Emits the bytecode for processing an aggregate step.
/// E.g. in `SELECT SUM(price) FROM t`, 'price' is evaluated for every row, and the result is added to the accumulator.
/// Enum representing the source of the aggregate function arguments
///
/// This is distinct from the final step, which is called after the main loop has finished processing
/// Aggregate arguments can come from different sources, depending on how the aggregation
/// is evaluated:
/// * In the common grouped case, the aggregate function arguments are first inserted
/// into a sorter in the main loop, and in the group by aggregation phase we read
/// the data from the sorter.
/// * In grouped cases where no sorting is required, arguments are retrieved directly
/// from registers allocated in the main loop.
/// * In ungrouped cases, arguments are computed directly from the `args` expressions.
pub enum AggArgumentSource<'a> {
/// The aggregate function arguments are retrieved from a pseudo cursor
/// which reads from the GROUP BY sorter.
PseudoCursor {
cursor_id: usize,
col_start: usize,
dest_reg_start: usize,
aggregate: &'a Aggregate,
},
/// The aggregate function arguments are retrieved from a contiguous block of registers
/// allocated in the main loop for that given aggregate function.
Register {
src_reg_start: usize,
aggregate: &'a Aggregate,
},
/// The aggregate function arguments are retrieved by evaluating expressions.
Expression { aggregate: &'a Aggregate },
}

impl<'a> AggArgumentSource<'a> {
/// Create a new [AggArgumentSource] that retrieves the values from a GROUP BY sorter.
pub fn new_from_cursor(
program: &mut ProgramBuilder,
cursor_id: usize,
col_start: usize,
aggregate: &'a Aggregate,
) -> Self {
let dest_reg_start = program.alloc_registers(aggregate.args.len());
Self::PseudoCursor {
cursor_id,
col_start,
dest_reg_start,
aggregate,
}
}
/// Create a new [AggArgumentSource] that retrieves the values directly from an already
/// populated register or registers.
pub fn new_from_registers(src_reg_start: usize, aggregate: &'a Aggregate) -> Self {
Self::Register {
src_reg_start,
aggregate,
}
}

/// Create a new [AggArgumentSource] that retrieves the values by evaluating `args` expressions.
pub fn new_from_expression(aggregate: &'a Aggregate) -> Self {
Self::Expression { aggregate }
}

pub fn aggregate(&self) -> &Aggregate {
match self {
AggArgumentSource::PseudoCursor { aggregate, .. } => aggregate,
AggArgumentSource::Register { aggregate, .. } => aggregate,
AggArgumentSource::Expression { aggregate } => aggregate,
}
}

pub fn agg_func(&self) -> &AggFunc {
match self {
AggArgumentSource::PseudoCursor { aggregate, .. } => &aggregate.func,
AggArgumentSource::Register { aggregate, .. } => &aggregate.func,
AggArgumentSource::Expression { aggregate } => &aggregate.func,
}
}
pub fn args(&self) -> &[ast::Expr] {
match self {
AggArgumentSource::PseudoCursor { aggregate, .. } => &aggregate.args,
AggArgumentSource::Register { aggregate, .. } => &aggregate.args,
AggArgumentSource::Expression { aggregate } => &aggregate.args,
}
}
pub fn num_args(&self) -> usize {
match self {
AggArgumentSource::PseudoCursor { aggregate, .. } => aggregate.args.len(),
AggArgumentSource::Register { aggregate, .. } => aggregate.args.len(),
AggArgumentSource::Expression { aggregate } => aggregate.args.len(),
}
}
/// Read the value of an aggregate function argument
pub fn translate(
&self,
program: &mut ProgramBuilder,
referenced_tables: &TableReferences,
resolver: &Resolver,
arg_idx: usize,
) -> Result<usize> {
match self {
AggArgumentSource::PseudoCursor {
cursor_id,
col_start,
dest_reg_start,
..
} => {
program.emit_column_or_rowid(
*cursor_id,
*col_start + arg_idx,
dest_reg_start + arg_idx,
);
Ok(dest_reg_start + arg_idx)
}
AggArgumentSource::Register {
src_reg_start: start_reg,
..
} => Ok(*start_reg + arg_idx),
AggArgumentSource::Expression { aggregate } => {
let dest_reg = program.alloc_register();
translate_expr(
program,
Some(referenced_tables),
&aggregate.args[arg_idx],
dest_reg,
resolver,
)
}
}
}
}

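The three constructors map one-to-one onto the query shapes listed in the enum's doc comment, so the caller's choice of variant is a direct function of the plan. A compressed sketch of that selection (the surrounding bindings `program`, `agg`, `pseudo_cursor`, `col`, and `src_reg` are illustrative, not from this hunk):

let source = if let Some(cursor_id) = pseudo_cursor {
    // Grouped with sorting: arguments come back out of the GROUP BY sorter.
    AggArgumentSource::new_from_cursor(program, cursor_id, col, agg)
} else if let Some(start) = src_reg {
    // Grouped without sorting: arguments already sit in main-loop registers.
    AggArgumentSource::new_from_registers(start, agg)
} else {
    // Ungrouped: evaluate the argument expressions on the spot.
    AggArgumentSource::new_from_expression(agg)
};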
/// Emits the bytecode for processing an aggregate step.
///
/// This is distinct from the final step, which is called after a single group has been entirely accumulated,
/// and the actual result value of the aggregation is materialized.
///
/// Ungrouped aggregation is a special case of grouped aggregation that involves a single group.
///
/// Examples:
/// * In `SELECT SUM(price) FROM t`, `price` is evaluated for each row and added to the accumulator.
/// * In `SELECT product_category, SUM(price) FROM t GROUP BY product_category`, `price` is evaluated for
/// each row in the group and added to that group’s accumulator.
pub fn translate_aggregation_step(
program: &mut ProgramBuilder,
referenced_tables: &TableReferences,
agg: &Aggregate,
agg_arg_source: AggArgumentSource,
target_register: usize,
resolver: &Resolver,
) -> Result<usize> {
let dest = match agg.func {
let num_args = agg_arg_source.num_args();
let func = agg_arg_source.agg_func();
let dest = match func {
AggFunc::Avg => {
if agg.args.len() != 1 {
if num_args != 1 {
crate::bail_parse_error!("avg bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
@@ -155,20 +289,16 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::Count | AggFunc::Count0 => {
let expr_reg = if agg.args.is_empty() {
program.alloc_register()
} else {
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
expr_reg
};
handle_distinct(program, agg, expr_reg);
if num_args != 1 {
crate::bail_parse_error!("count bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: if matches!(agg.func, AggFunc::Count0) {
func: if matches!(func, AggFunc::Count0) {
AggFunc::Count0
} else {
AggFunc::Count
@@ -177,18 +307,16 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::GroupConcat => {
if agg.args.len() != 1 && agg.args.len() != 2 {
if num_args != 1 && num_args != 2 {
crate::bail_parse_error!("group_concat bad number of arguments");
}

let expr_reg = program.alloc_register();
let delimiter_reg = program.alloc_register();

let expr = &agg.args[0];
let delimiter_expr: ast::Expr;

if agg.args.len() == 2 {
match &agg.args[1] {
if num_args == 2 {
match &agg_arg_source.args()[1] {
arg @ ast::Expr::Column { .. } => {
delimiter_expr = arg.clone();
}
@@ -201,8 +329,8 @@ pub fn translate_aggregation_step(
delimiter_expr = ast::Expr::Literal(ast::Literal::String(String::from("\",\"")));
}

translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
translate_expr(
program,
Some(referenced_tables),
@@ -221,13 +349,12 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::Max => {
if agg.args.len() != 1 {
if num_args != 1 {
crate::bail_parse_error!("max bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
let expr = &agg_arg_source.args()[0];
emit_collseq_if_needed(program, referenced_tables, expr);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
@@ -238,13 +365,12 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::Min => {
if agg.args.len() != 1 {
if num_args != 1 {
crate::bail_parse_error!("min bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
let expr = &agg_arg_source.args()[0];
emit_collseq_if_needed(program, referenced_tables, expr);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
@@ -256,23 +382,12 @@ pub fn translate_aggregation_step(
}
#[cfg(feature = "json")]
AggFunc::JsonGroupObject | AggFunc::JsonbGroupObject => {
if agg.args.len() != 2 {
if num_args != 2 {
crate::bail_parse_error!("max bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let value_expr = &agg.args[1];
let value_reg = program.alloc_register();

let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let _ = translate_expr(
program,
Some(referenced_tables),
value_expr,
value_reg,
resolver,
)?;
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
let value_reg = agg_arg_source.translate(program, referenced_tables, resolver, 1)?;

program.emit_insn(Insn::AggStep {
acc_reg: target_register,
@@ -284,13 +399,11 @@ pub fn translate_aggregation_step(
}
#[cfg(feature = "json")]
AggFunc::JsonGroupArray | AggFunc::JsonbGroupArray => {
if agg.args.len() != 1 {
if num_args != 1 {
crate::bail_parse_error!("max bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
@@ -300,15 +413,13 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::StringAgg => {
if agg.args.len() != 2 {
if num_args != 2 {
crate::bail_parse_error!("string_agg bad number of arguments");
}

let expr_reg = program.alloc_register();
let delimiter_reg = program.alloc_register();

let expr = &agg.args[0];
let delimiter_expr = match &agg.args[1] {
let delimiter_expr = match &agg_arg_source.args()[1] {
arg @ ast::Expr::Column { .. } => arg.clone(),
ast::Expr::Literal(ast::Literal::String(s)) => {
ast::Expr::Literal(ast::Literal::String(s.to_string()))
@@ -316,7 +427,7 @@ pub fn translate_aggregation_step(
_ => crate::bail_parse_error!("Incorrect delimiter parameter"),
};

translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
translate_expr(
program,
Some(referenced_tables),
@@ -335,13 +446,11 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::Sum => {
if agg.args.len() != 1 {
if num_args != 1 {
crate::bail_parse_error!("sum bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
@@ -351,13 +460,11 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::Total => {
if agg.args.len() != 1 {
if num_args != 1 {
crate::bail_parse_error!("total bad number of arguments");
}
let expr = &agg.args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, Some(referenced_tables), expr, expr_reg, resolver)?;
handle_distinct(program, agg, expr_reg);
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
@@ -367,31 +474,24 @@ pub fn translate_aggregation_step(
target_register
}
AggFunc::External(ref func) => {
let expr_reg = program.alloc_register();
let argc = func.agg_args().map_err(|_| {
LimboError::ExtensionError(
"External aggregate function called with wrong number of arguments".to_string(),
)
})?;
if argc != agg.args.len() {
if argc != num_args {
crate::bail_parse_error!(
"External aggregate function called with wrong number of arguments"
);
}
let expr_reg = agg_arg_source.translate(program, referenced_tables, resolver, 0)?;
for i in 0..argc {
if i != 0 {
let _ = program.alloc_register();
let _ = agg_arg_source.translate(program, referenced_tables, resolver, i)?;
}
let _ = translate_expr(
program,
Some(referenced_tables),
&agg.args[i],
expr_reg + i,
resolver,
)?;
// invariant: distinct aggregates are only supported for single-argument functions
if argc == 1 {
handle_distinct(program, agg, expr_reg + i);
handle_distinct(program, agg_arg_source.aggregate(), expr_reg + i);
}
}
program.emit_insn(Insn::AggStep {

@@ -1,5 +1,8 @@
use std::sync::Arc;
use turso_parser::{ast, parser::Parser};
use turso_parser::{
ast::{self, fmt::ToTokens as _},
parser::Parser,
};

use crate::{
function::{AlterTableFunc, Func},
@@ -166,7 +169,7 @@ pub fn translate_alter_table(
)?
}
ast::AlterTableBody::AddColumn(col_def) => {
let column = Column::from(col_def);
let column = Column::from(&col_def);

if let Some(default) = &column.default {
if !matches!(
@@ -233,97 +236,6 @@ pub fn translate_alter_table(
},
)?
}
ast::AlterTableBody::RenameColumn { old, new } => {
let rename_from = old.as_str();
let rename_to = new.as_str();

let Some((column_index, _)) = btree.get_column(rename_from) else {
return Err(LimboError::ParseError(format!(
"no such column: \"{rename_from}\""
)));
};

if btree.get_column(rename_to).is_some() {
return Err(LimboError::ParseError(format!(
"duplicate column name: \"{rename_from}\""
)));
};

let sqlite_schema = schema
.get_btree_table(SQLITE_TABLEID)
.expect("sqlite_schema should be on schema");

let cursor_id = program.alloc_cursor_id(CursorType::BTreeTable(sqlite_schema.clone()));

program.emit_insn(Insn::OpenWrite {
cursor_id,
root_page: RegisterOrLiteral::Literal(sqlite_schema.root_page),
db: 0,
});

program.cursor_loop(cursor_id, |program, rowid| {
let sqlite_schema_column_len = sqlite_schema.columns.len();
assert_eq!(sqlite_schema_column_len, 5);

let first_column = program.alloc_registers(sqlite_schema_column_len);

for i in 0..sqlite_schema_column_len {
program.emit_column_or_rowid(cursor_id, i, first_column + i);
}

program.emit_string8_new_reg(table_name.to_string());
program.mark_last_insn_constant();

program.emit_string8_new_reg(rename_from.to_string());
program.mark_last_insn_constant();

program.emit_string8_new_reg(rename_to.to_string());
program.mark_last_insn_constant();

let out = program.alloc_registers(sqlite_schema_column_len);

program.emit_insn(Insn::Function {
constant_mask: 0,
start_reg: first_column,
dest: out,
func: crate::function::FuncCtx {
func: Func::AlterTable(AlterTableFunc::RenameColumn),
arg_count: 8,
},
});

let record = program.alloc_register();

program.emit_insn(Insn::MakeRecord {
start_reg: out,
count: sqlite_schema_column_len,
dest_reg: record,
index_name: None,
});

program.emit_insn(Insn::Insert {
cursor: cursor_id,
key_reg: rowid,
record_reg: record,
flag: crate::vdbe::insn::InsertFlags(0),
table_name: table_name.to_string(),
});
});

program.emit_insn(Insn::SetCookie {
db: 0,
cookie: Cookie::SchemaVersion,
value: schema.schema_version as i32 + 1,
p5: 0,
});
program.emit_insn(Insn::RenameColumn {
table: table_name.to_owned(),
column_index,
name: rename_to.to_owned(),
});

program
}
ast::AlterTableBody::RenameTo(new_name) => {
let new_name = new_name.as_str();

@@ -409,6 +321,148 @@ pub fn translate_alter_table(
to: new_name.to_owned(),
});

program
}
body @ (ast::AlterTableBody::AlterColumn { .. }
| ast::AlterTableBody::RenameColumn { .. }) => {
let from;
let definition;
let col_name;
let rename;

match body {
ast::AlterTableBody::AlterColumn { old, new } => {
from = old;
definition = new;
col_name = definition.col_name.clone();
rename = false;
}
ast::AlterTableBody::RenameColumn { old, new } => {
from = old;
definition = ast::ColumnDefinition {
col_name: new.clone(),
col_type: None,
constraints: vec![],
};
col_name = new;
rename = true;
}
_ => unreachable!(),
}

let from = from.as_str();
let col_name = col_name.as_str();

let Some((column_index, _)) = btree.get_column(from) else {
return Err(LimboError::ParseError(format!(
"no such column: \"{from}\""
)));
};

if btree.get_column(col_name).is_some() {
return Err(LimboError::ParseError(format!(
"duplicate column name: \"{col_name}\""
)));
};

if definition
.constraints
.iter()
.any(|c| matches!(c.constraint, ast::ColumnConstraint::PrimaryKey { .. }))
{
return Err(LimboError::ParseError(
"PRIMARY KEY constraint cannot be altered".to_string(),
));
}

if definition
.constraints
.iter()
.any(|c| matches!(c.constraint, ast::ColumnConstraint::Unique { .. }))
{
return Err(LimboError::ParseError(
"UNIQUE constraint cannot be altered".to_string(),
));
}

let sqlite_schema = schema
.get_btree_table(SQLITE_TABLEID)
.expect("sqlite_schema should be on schema");

let cursor_id = program.alloc_cursor_id(CursorType::BTreeTable(sqlite_schema.clone()));

program.emit_insn(Insn::OpenWrite {
cursor_id,
root_page: RegisterOrLiteral::Literal(sqlite_schema.root_page),
db: 0,
});

program.cursor_loop(cursor_id, |program, rowid| {
let sqlite_schema_column_len = sqlite_schema.columns.len();
assert_eq!(sqlite_schema_column_len, 5);

let first_column = program.alloc_registers(sqlite_schema_column_len);

for i in 0..sqlite_schema_column_len {
program.emit_column_or_rowid(cursor_id, i, first_column + i);
}

program.emit_string8_new_reg(table_name.to_string());
program.mark_last_insn_constant();

program.emit_string8_new_reg(from.to_string());
program.mark_last_insn_constant();

program.emit_string8_new_reg(definition.format().unwrap());
program.mark_last_insn_constant();

let out = program.alloc_registers(sqlite_schema_column_len);

program.emit_insn(Insn::Function {
constant_mask: 0,
start_reg: first_column,
dest: out,
func: crate::function::FuncCtx {
func: Func::AlterTable(if rename {
AlterTableFunc::RenameColumn
} else {
AlterTableFunc::AlterColumn
}),
arg_count: 8,
},
});

let record = program.alloc_register();

program.emit_insn(Insn::MakeRecord {
start_reg: out,
count: sqlite_schema_column_len,
dest_reg: record,
index_name: None,
});

program.emit_insn(Insn::Insert {
cursor: cursor_id,
key_reg: rowid,
record_reg: record,
flag: crate::vdbe::insn::InsertFlags(0),
table_name: table_name.to_string(),
});
});

program.emit_insn(Insn::SetCookie {
db: 0,
cookie: Cookie::SchemaVersion,
value: schema.schema_version as i32 + 1,
p5: 0,
});
program.emit_insn(Insn::AlterColumn {
table: table_name.to_owned(),
column_index,
definition,
rename,
});

program
}
})

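The merged match arm above works because a RENAME COLUMN is representable as a degenerate ALTER COLUMN: a `ColumnDefinition` that carries only the new name. Sketched in isolation (assuming the column name type used by `turso_parser` is `ast::Name`):

// A rename becomes an alter whose new definition changes nothing but the name.
fn rename_as_alter(new: ast::Name) -> ast::ColumnDefinition {
    ast::ColumnDefinition {
        col_name: new,       // only the name changes
        col_type: None,      // the type is left untouched
        constraints: vec![], // RENAME cannot add constraints
    }
}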
@@ -1,9 +1,16 @@
|
||||
use turso_parser::ast;
|
||||
|
||||
use super::{
|
||||
emitter::TranslateCtx,
|
||||
expr::{translate_condition_expr, translate_expr, ConditionMetadata},
|
||||
order_by::order_by_sorter_insert,
|
||||
plan::{Distinctness, GroupBy, SelectPlan},
|
||||
result_row::emit_select_result,
|
||||
};
|
||||
use crate::translate::aggregation::{translate_aggregation_step, AggArgumentSource};
|
||||
use crate::translate::expr::{walk_expr, WalkControl};
|
||||
use crate::translate::plan::ResultSetColumn;
|
||||
use crate::{
|
||||
function::AggFunc,
|
||||
schema::PseudoCursorType,
|
||||
translate::collate::CollationSeq,
|
||||
util::exprs_are_equivalent,
|
||||
@@ -15,15 +22,6 @@ use crate::{
|
||||
Result,
|
||||
};
|
||||
|
||||
use super::{
|
||||
aggregation::handle_distinct,
|
||||
emitter::{Resolver, TranslateCtx},
|
||||
expr::{translate_condition_expr, translate_expr, ConditionMetadata},
|
||||
order_by::order_by_sorter_insert,
|
||||
plan::{Aggregate, Distinctness, GroupBy, SelectPlan, TableReferences},
|
||||
result_row::emit_select_result,
};

/// Labels needed for various jumps in GROUP BY handling.
#[derive(Debug)]
pub struct GroupByLabels {
@@ -394,102 +392,6 @@ pub enum GroupByRowSource {
},
}

/// Enum representing the source of the aggregate function arguments
/// emitted for a group by aggregation.
/// In the common case, the aggregate function arguments are first inserted
/// into a sorter in the main loop, and in the group by aggregation phase
/// we read the data from the sorter.
///
/// In the alternative case, no sorting is required for group by,
/// and the aggregate function arguments are retrieved directly from
/// registers allocated in the main loop.
pub enum GroupByAggArgumentSource<'a> {
/// The aggregate function arguments are retrieved from a pseudo cursor
/// which reads from the GROUP BY sorter.
PseudoCursor {
cursor_id: usize,
col_start: usize,
dest_reg_start: usize,
aggregate: &'a Aggregate,
},
/// The aggregate function arguments are retrieved from a contiguous block of registers
/// allocated in the main loop for that given aggregate function.
Register {
src_reg_start: usize,
aggregate: &'a Aggregate,
},
}

impl<'a> GroupByAggArgumentSource<'a> {
/// Create a new [GroupByAggArgumentSource] that retrieves the values from a GROUP BY sorter.
pub fn new_from_cursor(
program: &mut ProgramBuilder,
cursor_id: usize,
col_start: usize,
aggregate: &'a Aggregate,
) -> Self {
let dest_reg_start = program.alloc_registers(aggregate.args.len());
Self::PseudoCursor {
cursor_id,
col_start,
dest_reg_start,
aggregate,
}
}
/// Create a new [GroupByAggArgumentSource] that retrieves the values directly from an already
/// populated register or registers.
pub fn new_from_registers(src_reg_start: usize, aggregate: &'a Aggregate) -> Self {
Self::Register {
src_reg_start,
aggregate,
}
}

pub fn aggregate(&self) -> &Aggregate {
match self {
GroupByAggArgumentSource::PseudoCursor { aggregate, .. } => aggregate,
GroupByAggArgumentSource::Register { aggregate, .. } => aggregate,
}
}

pub fn agg_func(&self) -> &AggFunc {
match self {
GroupByAggArgumentSource::PseudoCursor { aggregate, .. } => &aggregate.func,
GroupByAggArgumentSource::Register { aggregate, .. } => &aggregate.func,
}
}
pub fn args(&self) -> &[ast::Expr] {
match self {
GroupByAggArgumentSource::PseudoCursor { aggregate, .. } => &aggregate.args,
GroupByAggArgumentSource::Register { aggregate, .. } => &aggregate.args,
}
}
pub fn num_args(&self) -> usize {
match self {
GroupByAggArgumentSource::PseudoCursor { aggregate, .. } => aggregate.args.len(),
GroupByAggArgumentSource::Register { aggregate, .. } => aggregate.args.len(),
}
}
/// Read the value of an aggregate function argument either from sorter data or directly from a register.
pub fn translate(&self, program: &mut ProgramBuilder, arg_idx: usize) -> Result<usize> {
match self {
GroupByAggArgumentSource::PseudoCursor {
cursor_id,
col_start,
dest_reg_start,
..
} => {
program.emit_column_or_rowid(*cursor_id, *col_start, dest_reg_start + arg_idx);
Ok(dest_reg_start + arg_idx)
}
GroupByAggArgumentSource::Register {
src_reg_start: start_reg,
..
} => Ok(*start_reg + arg_idx),
}
}
}
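
Editor's note on the shape being removed here: the two variants differ only in where an argument's value lives when the aggregate step runs. Cursor-backed arguments need an explicit load into a destination register, while register-backed ones are already addressable. A self-contained sketch of that dispatch (names are hypothetical stand-ins, not the engine's types):

// Hypothetical stand-ins for the two-source dispatch above: a cursor-backed
// argument needs a load step before it has a register; a register-backed
// argument is just an offset computation.
enum ArgSource {
    Cursor { col_start: usize, dest_reg_start: usize },
    Register { src_reg_start: usize },
}

impl ArgSource {
    // Return the register holding argument `arg_idx`; println stands in for
    // emitting the column-read instruction.
    fn resolve(&self, arg_idx: usize) -> usize {
        match self {
            ArgSource::Cursor { col_start, dest_reg_start } => {
                println!("emit: read col {} into reg {}", col_start + arg_idx, dest_reg_start + arg_idx);
                dest_reg_start + arg_idx
            }
            ArgSource::Register { src_reg_start } => src_reg_start + arg_idx,
        }
    }
}

fn main() {
    assert_eq!(ArgSource::Register { src_reg_start: 10 }.resolve(1), 11);
    assert_eq!(ArgSource::Cursor { col_start: 3, dest_reg_start: 20 }.resolve(0), 20);
}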

/// Emits bytecode for processing a single GROUP BY group.
pub fn group_by_process_single_group(
program: &mut ProgramBuilder,
@@ -593,21 +495,19 @@ pub fn group_by_process_single_group(
.expect("aggregate registers must be initialized");
let agg_result_reg = start_reg + i;
let agg_arg_source = match &row_source {
GroupByRowSource::Sorter { pseudo_cursor, .. } => {
GroupByAggArgumentSource::new_from_cursor(
program,
*pseudo_cursor,
cursor_index + offset,
agg,
)
}
GroupByRowSource::Sorter { pseudo_cursor, .. } => AggArgumentSource::new_from_cursor(
program,
*pseudo_cursor,
cursor_index + offset,
agg,
),
GroupByRowSource::MainLoop { start_reg_src, .. } => {
// Aggregation arguments are always placed in the registers that follow any scalars.
let start_reg_aggs = start_reg_src + t_ctx.non_aggregate_expressions.len();
GroupByAggArgumentSource::new_from_registers(start_reg_aggs + offset, agg)
AggArgumentSource::new_from_registers(start_reg_aggs + offset, agg)
}
};
translate_aggregation_step_groupby(
translate_aggregation_step(
program,
&plan.table_references,
agg_arg_source,
@@ -897,220 +797,3 @@ pub fn group_by_emit_row_phase<'a>(
program.preassign_label_to_next_insn(labels.label_group_by_end);
Ok(())
}

/// Emits the bytecode for processing an aggregate step within a GROUP BY clause.
/// E.g. in `SELECT product_category, SUM(price) FROM t GROUP BY product_category`, 'price' is evaluated for every row
/// where the 'product_category' is the same, and the result is added to the accumulator for that category.
///
/// This is distinct from the final step, which is called after a single group has been entirely accumulated,
/// and the actual result value of the aggregation is materialized.
pub fn translate_aggregation_step_groupby(
program: &mut ProgramBuilder,
referenced_tables: &TableReferences,
agg_arg_source: GroupByAggArgumentSource,
target_register: usize,
resolver: &Resolver,
) -> Result<usize> {
let num_args = agg_arg_source.num_args();
let dest = match agg_arg_source.agg_func() {
AggFunc::Avg => {
if num_args != 1 {
crate::bail_parse_error!("avg bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::Avg,
});
target_register
}
AggFunc::Count | AggFunc::Count0 => {
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: if matches!(agg_arg_source.agg_func(), AggFunc::Count0) {
AggFunc::Count0
} else {
AggFunc::Count
},
});
target_register
}
AggFunc::GroupConcat => {
let num_args = agg_arg_source.num_args();
if num_args != 1 && num_args != 2 {
crate::bail_parse_error!("group_concat bad number of arguments");
}

let delimiter_reg = program.alloc_register();

let delimiter_expr: ast::Expr;

if num_args == 2 {
match &agg_arg_source.args()[1] {
arg @ ast::Expr::Column { .. } => {
delimiter_expr = arg.clone();
}
ast::Expr::Literal(ast::Literal::String(s)) => {
delimiter_expr = ast::Expr::Literal(ast::Literal::String(s.to_string()));
}
_ => crate::bail_parse_error!("Incorrect delimiter parameter"),
};
} else {
delimiter_expr = ast::Expr::Literal(ast::Literal::String(String::from("\",\"")));
}

let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
translate_expr(
program,
Some(referenced_tables),
&delimiter_expr,
delimiter_reg,
resolver,
)?;

program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: delimiter_reg,
func: AggFunc::GroupConcat,
});

target_register
}
AggFunc::Max => {
if num_args != 1 {
crate::bail_parse_error!("max bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::Max,
});
target_register
}
AggFunc::Min => {
if num_args != 1 {
crate::bail_parse_error!("min bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::Min,
});
target_register
}
#[cfg(feature = "json")]
AggFunc::JsonGroupArray | AggFunc::JsonbGroupArray => {
if num_args != 1 {
crate::bail_parse_error!("min bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::JsonGroupArray,
});
target_register
}
#[cfg(feature = "json")]
AggFunc::JsonGroupObject | AggFunc::JsonbGroupObject => {
if num_args != 2 {
crate::bail_parse_error!("max bad number of arguments");
}

let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
let value_reg = agg_arg_source.translate(program, 1)?;

program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: value_reg,
func: AggFunc::JsonGroupObject,
});
target_register
}
AggFunc::StringAgg => {
if num_args != 2 {
crate::bail_parse_error!("string_agg bad number of arguments");
}

let delimiter_reg = program.alloc_register();

let delimiter_expr = match &agg_arg_source.args()[1] {
arg @ ast::Expr::Column { .. } => arg.clone(),
ast::Expr::Literal(ast::Literal::String(s)) => {
ast::Expr::Literal(ast::Literal::String(s.to_string()))
}
_ => crate::bail_parse_error!("Incorrect delimiter parameter"),
};

let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
translate_expr(
program,
Some(referenced_tables),
&delimiter_expr,
delimiter_reg,
resolver,
)?;

program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: delimiter_reg,
func: AggFunc::StringAgg,
});

target_register
}
AggFunc::Sum => {
if num_args != 1 {
crate::bail_parse_error!("sum bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::Sum,
});
target_register
}
AggFunc::Total => {
if num_args != 1 {
crate::bail_parse_error!("total bad number of arguments");
}
let expr_reg = agg_arg_source.translate(program, 0)?;
handle_distinct(program, agg_arg_source.aggregate(), expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::Total,
});
target_register
}
AggFunc::External(_) => {
todo!("External aggregate functions are not yet supported in GROUP BY");
}
};
Ok(dest)
}
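
As the doc comment on this function notes, AggStep only folds one row into the accumulator; producing the group's result is a separate finalize phase. A minimal sketch of that two-phase split for AVG (hypothetical accumulator, not the engine's; shown only to make the step/finalize distinction concrete):

// Hypothetical accumulator mirroring the step/finalize split for AVG.
struct AvgAcc { sum: f64, count: u64 }

impl AvgAcc {
    fn new() -> Self { AvgAcc { sum: 0.0, count: 0 } }
    // One call per row in the group: the AggStep analogue.
    fn step(&mut self, v: f64) { self.sum += v; self.count += 1; }
    // One call per group, after all rows: the final-step analogue.
    fn finalize(&self) -> Option<f64> {
        (self.count > 0).then(|| self.sum / self.count as f64)
    }
}

fn main() {
    let mut acc = AvgAcc::new();
    for price in [10.0, 20.0, 40.0] { acc.step(price); }
    assert_eq!(acc.finalize(), Some(70.0 / 3.0));
}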

@@ -19,7 +19,7 @@ use crate::{
};

use super::{
aggregation::translate_aggregation_step,
aggregation::{translate_aggregation_step, AggArgumentSource},
emitter::{OperationMode, TranslateCtx},
expr::{
translate_condition_expr, translate_expr, translate_expr_no_constant_opt,
@@ -868,7 +868,7 @@ fn emit_loop_source(
translate_aggregation_step(
program,
&plan.table_references,
agg,
AggArgumentSource::new_from_expression(agg),
reg,
&t_ctx.resolver,
)?;

@@ -1048,6 +1048,24 @@ pub struct Aggregate {
}

impl Aggregate {
pub fn new(func: AggFunc, args: &[Box<Expr>], expr: &Expr, distinctness: Distinctness) -> Self {
let agg_args = if args.is_empty() {
// The AggStep instruction requires at least one argument. For functions that accept
// zero arguments (e.g. COUNT()), we insert a dummy literal so that AggStep remains valid.
// This does not cause ambiguity: the resolver has already verified that the function
// takes zero arguments, so the dummy value will be ignored.
vec![Expr::Literal(ast::Literal::Numeric("1".to_string()))]
} else {
args.iter().map(|arg| *arg.clone()).collect()
};
Aggregate {
func,
args: agg_args,
original_expr: expr.clone(),
distinctness,
}
}

pub fn is_distinct(&self) -> bool {
self.distinctness.is_distinct()
}
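
The dummy-literal rule in Aggregate::new above is worth seeing in isolation: padding zero-argument aggregates preserves the one-argument invariant that AggStep expects, without special-casing COUNT() at every call site. A self-contained sketch of just that rule (String stands in for the ast expression type):

// Sketch of the zero-arg padding rule from Aggregate::new, with String as a
// stand-in for the expression type.
fn agg_args(args: &[String]) -> Vec<String> {
    if args.is_empty() {
        // Dummy literal "1": never evaluated for zero-arg functions.
        vec!["1".to_string()]
    } else {
        args.to_vec()
    }
}

fn main() {
    assert_eq!(agg_args(&[]), vec!["1".to_string()]); // COUNT()
    assert_eq!(agg_args(&["price".to_string()]), vec!["price".to_string()]); // SUM(price)
}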

@@ -73,12 +73,7 @@ pub fn resolve_aggregates(
"DISTINCT aggregate functions must have exactly one argument"
);
}
aggs.push(Aggregate {
func: f,
args: args.iter().map(|arg| *arg.clone()).collect(),
original_expr: expr.clone(),
distinctness,
});
aggs.push(Aggregate::new(f, args, expr, distinctness));
contains_aggregates = true;
}
_ => {
@@ -95,12 +90,7 @@ pub fn resolve_aggregates(
);
}
if let Ok(Func::Agg(f)) = Func::resolve_function(name.as_str(), 0) {
aggs.push(Aggregate {
func: f,
args: vec![],
original_expr: expr.clone(),
distinctness: Distinctness::NonDistinct,
});
aggs.push(Aggregate::new(f, &[], expr, Distinctness::NonDistinct));
contains_aggregates = true;
}
}

@@ -371,27 +371,7 @@ fn prepare_one_select_plan(
}
match Func::resolve_function(name.as_str(), args_count) {
Ok(Func::Agg(f)) => {
let agg_args = match (args.is_empty(), &f) {
(true, crate::function::AggFunc::Count0) => {
// COUNT() case
vec![ast::Expr::Literal(ast::Literal::Numeric(
"1".to_string(),
))
.into()]
}
(true, _) => crate::bail_parse_error!(
"Aggregate function {} requires arguments",
name.as_str()
),
(false, _) => args.clone(),
};

let agg = Aggregate {
func: f,
args: agg_args.iter().map(|arg| *arg.clone()).collect(),
original_expr: *expr.clone(),
distinctness,
};
let agg = Aggregate::new(f, args, expr, distinctness);
aggregate_expressions.push(agg);
plan.result_columns.push(ResultSetColumn {
alias: maybe_alias.as_ref().map(|alias| match alias {
@@ -446,15 +426,12 @@ fn prepare_one_select_plan(
contains_aggregates,
});
} else {
let agg = Aggregate {
func: AggFunc::External(f.func.clone().into()),
args: args
.iter()
.map(|arg| *arg.clone())
.collect(),
original_expr: *expr.clone(),
let agg = Aggregate::new(
AggFunc::External(f.func.clone().into()),
args,
expr,
distinctness,
};
);
aggregate_expressions.push(agg);
plan.result_columns.push(ResultSetColumn {
alias: maybe_alias.as_ref().map(|alias| {
@@ -488,14 +465,8 @@ fn prepare_one_select_plan(
}
match Func::resolve_function(name.as_str(), 0) {
Ok(Func::Agg(f)) => {
let agg = Aggregate {
func: f,
args: vec![ast::Expr::Literal(ast::Literal::Numeric(
"1".to_string(),
))],
original_expr: *expr.clone(),
distinctness: Distinctness::NonDistinct,
};
let agg =
Aggregate::new(f, &[], expr, Distinctness::NonDistinct);
aggregate_expressions.push(agg);
plan.result_columns.push(ResultSetColumn {
alias: maybe_alias.as_ref().map(|alias| match alias {

@@ -14,7 +14,7 @@ use crate::translate::plan::IterationDirection;
use crate::vdbe::sorter::Sorter;
use crate::vdbe::Register;
use crate::vtab::VirtualTableCursor;
use crate::{turso_assert, Completion, Result, IO};
use crate::{turso_assert, Completion, CompletionError, Result, IO};
use std::fmt::{Debug, Display};

const MAX_REAL_SIZE: u8 = 15;
@@ -350,6 +350,13 @@ impl Value {
}
}

pub fn as_uint(&self) -> u64 {
match self {
Value::Integer(i) => (*i).cast_unsigned(),
_ => 0,
}
}
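
cast_unsigned reinterprets the i64 bit pattern as u64 rather than clamping, so negative integers come back as large unsigned values. A standalone check using a plain `as` cast, which has the same bit-reinterpreting semantics:

fn as_uint(i: i64) -> u64 {
    i as u64 // same reinterpretation as i64::cast_unsigned
}

fn main() {
    assert_eq!(as_uint(5), 5);
    // -1 is all ones in two's complement, so it maps to u64::MAX.
    assert_eq!(as_uint(-1), u64::MAX);
}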

pub fn from_text(text: &str) -> Self {
Value::Text(Text::new(text))
}
@@ -2502,6 +2509,13 @@ impl IOCompletions {
IOCompletions::Many(completions) => completions.iter().for_each(|c| c.abort()),
}
}

pub fn get_error(&self) -> Option<CompletionError> {
match self {
IOCompletions::Single(c) => c.get_error(),
IOCompletions::Many(completions) => completions.iter().find_map(|c| c.get_error()),
}
}
}
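
find_map returns the first Some it encounters, so for the Many case the earliest-listed completion's error wins and later completions are not consulted. The same first-error-wins shape, self-contained:

// Model of the first-error-wins lookup in get_error above.
fn first_error(completions: &[Option<&'static str>]) -> Option<&'static str> {
    completions.iter().find_map(|c| *c)
}

fn main() {
    assert_eq!(first_error(&[None, Some("short read"), Some("eof")]), Some("short read"));
    assert_eq!(first_error(&[None, None]), None);
}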

#[derive(Debug)]

core/util.rs
@@ -703,58 +703,7 @@ pub fn columns_from_create_table_body(

use turso_parser::ast;

Ok(columns
.iter()
.map(
|ast::ColumnDefinition {
col_name: name,
col_type,
constraints,
}| {
Column {
name: Some(normalize_ident(name.as_str())),
ty: match col_type {
Some(ref data_type) => type_from_name(data_type.name.as_str()).0,
None => Type::Null,
},
default: constraints.iter().find_map(|c| match &c.constraint {
ast::ColumnConstraint::Default(val) => Some(val.clone()),
_ => None,
}),
notnull: constraints
.iter()
.any(|c| matches!(c.constraint, ast::ColumnConstraint::NotNull { .. })),
ty_str: col_type
.clone()
.map(|t| t.name.to_string())
.unwrap_or_default(),
primary_key: constraints
.iter()
.any(|c| matches!(c.constraint, ast::ColumnConstraint::PrimaryKey { .. })),
is_rowid_alias: false,
unique: constraints
.iter()
.any(|c| matches!(c.constraint, ast::ColumnConstraint::Unique(..))),
collation: constraints.iter().find_map(|c| match &c.constraint {
// TODO: check whether this is the correct behavior.
// Currently there cannot be any user-defined collation sequences.
// But in the future, if a user defines a collation sequence, creates a table with it,
// then closes the db and opens it again, this may panic if the collation seq
// has not been registered before the columns are read.
ast::ColumnConstraint::Collate { collation_name } => Some(
CollationSeq::new(collation_name.as_str())
.expect("collation should have been set correctly in create table"),
),
_ => None,
}),
hidden: col_type
.as_ref()
.map(|data_type| data_type.name.as_str().contains("HIDDEN"))
.unwrap_or(false),
}
},
)
.collect::<Vec<_>>())
Ok(columns.iter().map(Into::into).collect())
}
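
The one-liner that replaces the closure relies on Into::into, which implies a From<&ast::ColumnDefinition> for Column impl added elsewhere in this commit (not shown in these hunks). A hypothetical, heavily reduced sketch of that shape with stand-in types, only to make the Into::into call concrete:

// Stand-in types; the real impl would carry over name, type, constraints,
// collation and hidden-ness exactly as the deleted closure did.
struct ColumnDef { name: String, not_null: bool }
struct Column { name: Option<String>, notnull: bool }

impl From<&ColumnDef> for Column {
    fn from(def: &ColumnDef) -> Self {
        Column { name: Some(def.name.clone()), notnull: def.not_null }
    }
}

fn main() {
    let defs = vec![ColumnDef { name: "id".to_string(), not_null: true }];
    // Same call shape as `columns.iter().map(Into::into).collect()` above.
    let cols: Vec<Column> = defs.iter().map(Into::into).collect();
    assert!(cols[0].notnull && cols[0].name.as_deref() == Some("id"));
}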

/// This function checks if a given expression is a constant value that can be pushed down to the database engine.
@@ -803,6 +752,10 @@ pub struct OpenOptions<'a> {
pub cache: CacheMode,
/// immutable=1|0 specifies that the database is stored on read-only media
pub immutable: bool,
// The encryption cipher
pub cipher: Option<String>,
// The encryption key in hex format
pub hexkey: Option<String>,
}

pub const MEMORY_PATH: &str = ":memory:";
@@ -954,6 +907,8 @@ fn parse_query_params(query: &str, opts: &mut OpenOptions) -> Result<()> {
"cache" => opts.cache = decoded_value.as_str().into(),
"immutable" => opts.immutable = decoded_value == "1",
"vfs" => opts.vfs = Some(decoded_value),
"cipher" => opts.cipher = Some(decoded_value),
"hexkey" => opts.hexkey = Some(decoded_value),
_ => {}
}
}
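
With cipher and hexkey recognized here, both options ride in as URI query parameters alongside cache, immutable, and vfs. A self-contained sketch of the parsing shape (placeholder cipher name and key, not verified against the supported list; the real function also percent-decodes values, which this skips):

fn main() {
    // Placeholder values for illustration only.
    let query = "cipher=aegis256&hexkey=aabbccddeeff00112233445566778899";
    let (mut cipher, mut hexkey) = (None, None);
    for pair in query.split('&') {
        match pair.split_once('=') {
            Some(("cipher", v)) => cipher = Some(v.to_string()),
            Some(("hexkey", v)) => hexkey = Some(v.to_string()),
            _ => {} // unknown keys ignored, as above
        }
    }
    assert_eq!(cipher.as_deref(), Some("aegis256"));
    assert_eq!(hexkey.map(|k| k.len()), Some(32));
}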

@@ -4461,10 +4461,16 @@ pub fn op_function(
}
}
ScalarFunc::SqliteVersion => {
let version_integer =
return_if_io!(pager.with_header(|header| header.version_number)).get() as i64;
let version = execute_sqlite_version(version_integer);
state.registers[*dest] = Register::Value(Value::build_text(version));
if !program.connection.is_db_initialized() {
state.registers[*dest] =
Register::Value(Value::build_text(info::build::PKG_VERSION));
} else {
let version_integer =
return_if_io!(pager.with_header(|header| header.version_number)).get()
as i64;
let version = execute_sqlite_version(version_integer);
state.registers[*dest] = Register::Value(Value::build_text(version));
}
}
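
The fallback matters because there is no header to read before the database file exists; once initialized, the stored version_number is decoded instead. SQLite packs its version as major*1_000_000 + minor*1_000 + patch, which is presumably what execute_sqlite_version unpacks; a sketch with a hypothetical decoder:

// Hypothetical stand-in for execute_sqlite_version: unpack SQLite's
// version-number encoding, major*1_000_000 + minor*1_000 + patch.
fn decode_version(v: i64) -> String {
    format!("{}.{}.{}", v / 1_000_000, v / 1_000 % 1_000, v % 1_000)
}

fn main() {
    assert_eq!(decode_version(3_042_000), "3.42.0");
}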
ScalarFunc::SqliteSourceId => {
let src_id = format!(
@@ -4852,10 +4858,10 @@ pub fn op_function(

match stmt {
ast::Stmt::CreateIndex {
tbl_name,
unique,
if_not_exists,
idx_name,
tbl_name,
columns,
where_clause,
} => {
@@ -4867,10 +4873,10 @@ pub fn op_function(

Some(
ast::Stmt::CreateIndex {
tbl_name: ast::Name::new(&rename_to),
unique,
if_not_exists,
idx_name,
tbl_name: ast::Name::new(&rename_to),
columns,
where_clause,
}
@@ -4879,9 +4885,9 @@ pub fn op_function(
)
}
ast::Stmt::CreateTable {
tbl_name,
temporary,
if_not_exists,
tbl_name,
body,
} => {
let table_name = normalize_ident(tbl_name.name.as_str());
@@ -4892,13 +4898,13 @@ pub fn op_function(

Some(
ast::Stmt::CreateTable {
temporary,
if_not_exists,
tbl_name: ast::QualifiedName {
db_name: None,
name: ast::Name::new(&rename_to),
alias: None,
},
temporary,
if_not_exists,
body,
}
.format()
@@ -4911,7 +4917,7 @@ pub fn op_function(

(new_name, new_tbl_name, new_sql)
}
AlterTableFunc::RenameColumn => {
AlterTableFunc::AlterColumn | AlterTableFunc::RenameColumn => {
let table = {
match &state.registers[*start_reg + 5].get_value() {
Value::Text(rename_to) => normalize_ident(rename_to.as_str()),
@@ -4926,13 +4932,17 @@ pub fn op_function(
}
};

let rename_to = {
let column_def = {
match &state.registers[*start_reg + 7].get_value() {
Value::Text(rename_to) => normalize_ident(rename_to.as_str()),
Value::Text(column_def) => column_def.as_str(),
_ => panic!("rename_to parameter should be TEXT"),
}
};

let column_def = Parser::new(column_def.as_bytes())
.parse_column_definition(true)
.unwrap();

let new_sql = 'sql: {
if table != tbl_name {
break 'sql None;
@@ -4949,11 +4959,11 @@ pub fn op_function(

match stmt {
ast::Stmt::CreateIndex {
tbl_name,
mut columns,
unique,
if_not_exists,
idx_name,
tbl_name,
mut columns,
where_clause,
} => {
if table != normalize_ident(tbl_name.as_str()) {
@@ -4965,7 +4975,7 @@ pub fn op_function(
ast::Expr::Id(ast::Name::Ident(id))
if normalize_ident(id) == rename_from =>
{
*id = rename_to.clone();
*id = column_def.col_name.as_str().to_owned();
}
_ => {}
}
@@ -4973,11 +4983,11 @@ pub fn op_function(

Some(
ast::Stmt::CreateIndex {
tbl_name,
columns,
unique,
if_not_exists,
idx_name,
tbl_name,
columns,
where_clause,
}
.format()
@@ -4985,10 +4995,10 @@ pub fn op_function(
)
}
ast::Stmt::CreateTable {
temporary,
if_not_exists,
tbl_name,
body,
temporary,
if_not_exists,
} => {
if table != normalize_ident(tbl_name.name.as_str()) {
break 'sql None;
@@ -5008,18 +5018,24 @@ pub fn op_function(
.find(|column| column.col_name == ast::Name::new(&rename_from))
.expect("column being renamed should be present");

column.col_name = ast::Name::new(&rename_to);
match alter_func {
AlterTableFunc::AlterColumn => *column = column_def,
AlterTableFunc::RenameColumn => {
column.col_name = column_def.col_name
}
_ => unreachable!(),
}

Some(
ast::Stmt::CreateTable {
temporary,
if_not_exists,
tbl_name,
body: ast::CreateTableBody::ColumnsAndConstraints {
columns,
constraints,
options,
},
temporary,
if_not_exists,
}
.format()
.unwrap(),
@@ -7303,7 +7319,7 @@ pub fn op_add_column(
Ok(InsnFunctionStepResult::Step)
}

pub fn op_rename_column(
pub fn op_alter_column(
program: &Program,
state: &mut ProgramState,
insn: &Insn,
@@ -7311,16 +7327,19 @@ pub fn op_rename_column(
mv_store: Option<&Arc<MvStore>>,
) -> Result<InsnFunctionStepResult> {
load_insn!(
RenameColumn {
AlterColumn {
table: table_name,
column_index,
name
definition,
rename,
},
insn
);

let conn = program.connection.clone();

let new_column = crate::schema::Column::from(definition);

conn.with_schema_mut(|schema| {
let table = schema
.tables
@@ -7347,13 +7366,17 @@ pub fn op_rename_column(
if index_column.name
== *column.name.as_ref().expect("btree column should be named")
{
index_column.name = name.to_owned();
index_column.name = definition.col_name.as_str().to_owned();
}
}
}
}

column.name = Some(name.to_owned());
if *rename {
column.name = new_column.name;
} else {
*column = new_column;
}
});

state.pc += 1;

@@ -1672,14 +1672,14 @@ pub fn insn_to_str(
0,
format!("add_column({table}, {column:?})"),
),
Insn::RenameColumn { table, column_index, name } => (
"RenameColumn",
Insn::AlterColumn { table, column_index, definition: column, rename } => (
"AlterColumn",
0,
0,
0,
Value::build_text(""),
0,
format!("rename_column({table}, {column_index}, {name})"),
format!("alter_column({table}, {column_index}, {column:?}, {rename:?})"),
),
Insn::MaxPgcnt { db, dest, new_max } => (
"MaxPgcnt",

@@ -1053,10 +1053,11 @@ pub enum Insn {
table: String,
column: Column,
},
RenameColumn {
AlterColumn {
table: String,
column_index: usize,
name: String,
definition: turso_parser::ast::ColumnDefinition,
rename: bool,
},
/// Try to set the maximum page count for database P1 to the value in P3.
/// Do not let the maximum page count fall below the current page count and
@@ -1209,7 +1210,7 @@ impl Insn {
Insn::RenameTable { .. } => execute::op_rename_table,
Insn::DropColumn { .. } => execute::op_drop_column,
Insn::AddColumn { .. } => execute::op_add_column,
Insn::RenameColumn { .. } => execute::op_rename_column,
Insn::AlterColumn { .. } => execute::op_alter_column,
Insn::MaxPgcnt { .. } => execute::op_max_pgcnt,
Insn::JournalMode { .. } => execute::op_journal_mode,
}

@@ -460,6 +460,11 @@ impl Program {
if !io.finished() {
return Ok(StepResult::IO);
}
if let Some(err) = io.get_error() {
let err = err.into();
handle_program_error(&pager, &self.connection, &err)?;
return Err(err);
}
state.io_completions = None;
}
// invalidate row

@@ -370,7 +370,7 @@ struct SortedChunk {
/// The chunk file.
file: Arc<dyn File>,
/// Offset of the start of chunk in file
start_offset: usize,
start_offset: u64,
/// The size of this chunk file in bytes.
chunk_size: usize,
/// The read buffer.
@@ -391,7 +391,7 @@ impl SortedChunk {
fn new(file: Arc<dyn File>, start_offset: usize, buffer_size: usize) -> Self {
Self {
file,
start_offset,
start_offset: start_offset as u64,
chunk_size: 0,
buffer: Rc::new(RefCell::new(vec![0; buffer_size])),
buffer_len: Rc::new(Cell::new(0)),
@@ -522,7 +522,7 @@ impl SortedChunk {
let c = Completion::new_read(read_buffer_ref, read_complete);
let c = self
.file
.pread(self.start_offset + self.total_bytes_read.get(), c)?;
.pread(self.start_offset + self.total_bytes_read.get() as u64, c)?;
Ok(c)
}
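
The usize-to-u64 change in this last hunk is about 32-bit targets, where usize caps file offsets at 4 GiB even though spilled sort chunks can sit past that point in the file. A small illustration of the arithmetic the new type makes safe (sizes arbitrary):

fn main() {
    // An offset past 4 GiB: always representable in u64, but it would not
    // fit in a 32-bit usize.
    let start_offset: u64 = 5 * 1024 * 1024 * 1024; // 5 GiB
    assert!(start_offset > u32::MAX as u64);
    // The pattern from the pread call above: widen the per-read progress
    // before adding it to the 64-bit base offset.
    let total_bytes_read: usize = 8192;
    assert_eq!(start_offset + total_bytes_read as u64, 5_368_717_312);
}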