Merge branch 'main' into clean-parser-4

This commit is contained in:
TcMits
2025-08-24 13:15:05 +07:00
15 changed files with 663 additions and 195 deletions

View File

@@ -176,6 +176,11 @@ public final class JDBC4PreparedStatement extends JDBC4Statement implements Prep
// TODO
}
@Override
public void addBatch(String sql) throws SQLException {
  // A PreparedStatement batches bound parameter sets, never raw SQL strings;
  // the JDBC specification requires this call to be rejected.
  final String message = "addBatch(String) cannot be called on a PreparedStatement";
  throw new SQLException(message);
}
@Override
// NOTE(review): intentionally an empty stub — the supplied reader is silently
// dropped. Confirm whether this should throw SQLFeatureNotSupportedException
// (like other unimplemented setters) until a real implementation lands.
public void setCharacterStream(int parameterIndex, Reader reader, int length)
throws SQLException {}

View File

@@ -2,12 +2,16 @@ package tech.turso.jdbc4;
import static java.util.Objects.requireNonNull;
import java.sql.BatchUpdateException;
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.SQLWarning;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.locks.ReentrantLock;
import java.util.regex.Pattern;
import tech.turso.annotations.Nullable;
import tech.turso.annotations.SkipNullableCheck;
import tech.turso.core.TursoResultSet;
@@ -15,6 +19,20 @@ import tech.turso.core.TursoStatement;
public class JDBC4Statement implements Statement {
private static final Pattern BATCH_COMPATIBLE_PATTERN =
Pattern.compile(
"^\\s*"
+ // Leading whitespace
"(?:/\\*.*?\\*/\\s*)*"
+ // Optional C-style comments
"(?:--[^\\n]*\\n\\s*)*"
+ // Optional SQL line comments
"(?:"
+ // Start of keywords group
"INSERT|UPDATE|DELETE"
+ ")\\b",
Pattern.CASE_INSENSITIVE | Pattern.DOTALL);
private final JDBC4Connection connection;
@Nullable protected TursoStatement statement = null;
@@ -33,6 +51,12 @@ public class JDBC4Statement implements Statement {
private ReentrantLock connectionLock = new ReentrantLock();
/**
* List of SQL statements to be executed as a batch. Used for batch processing as per JDBC
* specification.
*/
private List<String> batchCommands = new ArrayList<>();
public JDBC4Statement(JDBC4Connection connection) {
this(
connection,
@@ -232,18 +256,82 @@ public class JDBC4Statement implements Statement {
@Override
public void addBatch(String sql) throws SQLException {
  // Queue the command for a later executeBatch(); the statement must be open
  // and null commands are rejected up front.
  ensureOpen();
  if (sql != null) {
    batchCommands.add(sql);
    return;
  }
  throw new SQLException("SQL command cannot be null");
}
@Override
public void clearBatch() throws SQLException {
  // Discard every queued command; only legal while the statement is open.
  ensureOpen();
  if (!batchCommands.isEmpty()) {
    batchCommands.clear();
  }
}
// TODO: let's make this batch operation atomic
/**
 * Executes every queued command in order and returns their update counts.
 *
 * <p>On the first failure the batch is cleared and a {@link BatchUpdateException} is thrown
 * carrying the counts gathered so far (the failing slot holds {@link #EXECUTE_FAILED};
 * never-executed slots stay 0, as permitted by the JDBC spec).
 *
 * @return one update count per queued command, in submission order
 * @throws BatchUpdateException if a command is not batch-compatible or fails to execute
 * @throws SQLException if the statement is closed
 */
@Override
public int[] executeBatch() throws SQLException {
  ensureOpen();
  int[] updateCounts = new int[batchCommands.size()];
  for (int i = 0; i < batchCommands.size(); i++) {
    String sql = batchCommands.get(i);
    try {
      if (!isBatchCompatibleStatement(sql)) {
        updateCounts[i] = EXECUTE_FAILED;
        BatchUpdateException bue =
            new BatchUpdateException(
                "Batch entry "
                    + i
                    + " ("
                    + sql
                    + ") was aborted. "
                    + "Batch commands cannot return result sets.",
                "HY000", // General error SQL state
                0,
                updateCounts);
        // Clear the batch after failure
        clearBatch();
        throw bue;
      }
      execute(sql);
      // For DML statements, get the update count
      updateCounts[i] = getUpdateCount();
    } catch (BatchUpdateException e) {
      // Already wrapped (incompatible-statement path above); without this
      // clause the catch below would double-wrap it and clear the batch twice.
      throw e;
    } catch (SQLException e) {
      updateCounts[i] = EXECUTE_FAILED;
      // Create a BatchUpdateException with the partial results, chaining the
      // original exception itself (not its possibly-null cause) so no stack
      // trace is lost.
      BatchUpdateException bue =
          new BatchUpdateException(
              "Batch entry " + i + " (" + sql + ") failed: " + e.getMessage(),
              e.getSQLState(),
              e.getErrorCode(),
              updateCounts,
              e);
      // Clear the batch after failure
      clearBatch();
      throw bue;
    }
  }
  // Clear the batch after successful execution
  clearBatch();
  return updateCounts;
}
/**
 * Returns whether {@code sql} may be executed in a batch: a non-blank INSERT, UPDATE or
 * DELETE statement, optionally preceded by whitespace and SQL comments.
 */
boolean isBatchCompatibleStatement(String sql) {
  return sql != null
      && !sql.trim().isEmpty()
      && BATCH_COMPATIBLE_PATTERN.matcher(sql).find();
}
@Override

View File

@@ -7,6 +7,7 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.sql.BatchUpdateException;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
@@ -120,4 +121,277 @@ class JDBC4StatementTest {
assertThat(stmt.executeUpdate("DELETE FROM s1")).isEqualTo(3);
}
/** Tests for batch processing functionality */
@Test
void testAddBatch_single_statement() throws SQLException {
stmt.execute("CREATE TABLE batch_test (id INTEGER PRIMARY KEY, value TEXT);");
stmt.addBatch("INSERT INTO batch_test VALUES (1, 'test1');");
int[] updateCounts = stmt.executeBatch();
assertThat(updateCounts).hasSize(1);
assertThat(updateCounts[0]).isEqualTo(1);
ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM batch_test;");
assertTrue(rs.next());
assertThat(rs.getInt(1)).isEqualTo(1);
}
@Test
void testAddBatch_multiple_statements() throws SQLException {
stmt.execute("CREATE TABLE batch_test (id INTEGER PRIMARY KEY, value TEXT);");
stmt.addBatch("INSERT INTO batch_test VALUES (1, 'test1');");
stmt.addBatch("INSERT INTO batch_test VALUES (2, 'test2');");
stmt.addBatch("INSERT INTO batch_test VALUES (3, 'test3');");
int[] updateCounts = stmt.executeBatch();
assertThat(updateCounts).hasSize(3);
assertThat(updateCounts[0]).isEqualTo(1);
assertThat(updateCounts[1]).isEqualTo(1);
assertThat(updateCounts[2]).isEqualTo(1);
ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM batch_test;");
assertTrue(rs.next());
assertThat(rs.getInt(1)).isEqualTo(3);
}
@Test
void testAddBatch_with_updates_and_deletes() throws SQLException {
stmt.execute("CREATE TABLE batch_test (id INTEGER PRIMARY KEY, value TEXT);");
stmt.execute(
"INSERT INTO batch_test VALUES (1, 'initial1'), (2, 'initial2'), (3, 'initial3');");
stmt.addBatch("UPDATE batch_test SET value = 'updated';");
stmt.addBatch("DELETE FROM batch_test WHERE id = 2;");
stmt.addBatch("INSERT INTO batch_test VALUES (4, 'new');");
int[] updateCounts = stmt.executeBatch();
assertThat(updateCounts).hasSize(3);
assertThat(updateCounts[0]).isEqualTo(3); // UPDATE affected 3 rows
assertThat(updateCounts[1]).isEqualTo(1); // DELETE affected 1 row
assertThat(updateCounts[2]).isEqualTo(1); // INSERT affected 1 row
// Verify final state
ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM batch_test;");
assertTrue(rs.next());
assertThat(rs.getInt(1)).isEqualTo(3); // 3 initial - 1 deleted + 1 inserted = 3
}
@Test
void testClearBatch() throws SQLException {
stmt.execute("CREATE TABLE batch_test (id INTEGER PRIMARY KEY, value TEXT);");
stmt.addBatch("INSERT INTO batch_test VALUES (1, 'test1');");
stmt.addBatch("INSERT INTO batch_test VALUES (2, 'test2');");
stmt.clearBatch();
int[] updateCounts = stmt.executeBatch();
assertThat(updateCounts).isEmpty();
ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM batch_test;");
assertTrue(rs.next());
assertThat(rs.getInt(1)).isEqualTo(0);
}
@Test
void testBatch_with_SELECT_should_throw_exception() throws SQLException {
stmt.execute("CREATE TABLE batch_test (id INTEGER PRIMARY KEY, value TEXT);");
stmt.execute("INSERT INTO batch_test VALUES (1, 'test1');");
stmt.addBatch("INSERT INTO batch_test VALUES (2, 'test2');");
stmt.addBatch("SELECT * FROM batch_test;"); // This should cause an exception
stmt.addBatch("INSERT INTO batch_test VALUES (3, 'test3');");
BatchUpdateException exception =
assertThrows(BatchUpdateException.class, () -> stmt.executeBatch());
assertTrue(exception.getMessage().contains("Batch commands cannot return result sets"));
int[] updateCounts = exception.getUpdateCounts();
assertThat(updateCounts).hasSize(3);
assertThat(updateCounts[0]).isEqualTo(1); // First INSERT succeeded
assertThat(updateCounts[1]).isEqualTo(Statement.EXECUTE_FAILED); // SELECT failed
// NOTE(review): updateCounts[2] is never asserted — the third command is not
// executed, so it presumably stays 0; consider pinning that here.
}
@Test
void testBatch_with_null_command_should_throw_exception() {
assertThrows(SQLException.class, () -> stmt.addBatch(null));
}
@Test
void testBatch_operations_on_closed_statement_should_throw_exception() throws SQLException {
stmt.close();
assertThrows(SQLException.class, () -> stmt.addBatch("INSERT INTO test VALUES (1);"));
assertThrows(SQLException.class, () -> stmt.clearBatch());
assertThrows(SQLException.class, () -> stmt.executeBatch());
}
@Test
void testBatch_with_syntax_error_should_throw_exception() throws SQLException {
stmt.execute("CREATE TABLE batch_test (id INTEGER PRIMARY KEY, value TEXT);");
stmt.addBatch("INSERT INTO batch_test VALUES (1, 'test1');");
stmt.addBatch("INVALID SQL SYNTAX;"); // This should cause an exception
stmt.addBatch("INSERT INTO batch_test VALUES (3, 'test3');");
BatchUpdateException exception =
assertThrows(BatchUpdateException.class, () -> stmt.executeBatch());
int[] updateCounts = exception.getUpdateCounts();
assertThat(updateCounts).hasSize(3);
assertThat(updateCounts[0]).isEqualTo(1); // First INSERT succeeded
assertThat(updateCounts[1]).isEqualTo(Statement.EXECUTE_FAILED); // Invalid SQL failed
}
@Test
void testBatch_empty_batch_returns_empty_array() throws SQLException {
int[] updateCounts = stmt.executeBatch();
assertThat(updateCounts).isEmpty();
}
@Test
void testBatch_clears_after_successful_execution() throws SQLException {
stmt.execute("CREATE TABLE batch_test (id INTEGER PRIMARY KEY, value TEXT);");
stmt.addBatch("INSERT INTO batch_test VALUES (1, 'test1');");
stmt.executeBatch();
int[] updateCounts = stmt.executeBatch();
assertThat(updateCounts).isEmpty();
}
@Test
void testBatch_clears_after_failed_execution() throws SQLException {
stmt.execute("CREATE TABLE batch_test (id INTEGER PRIMARY KEY, value TEXT);");
stmt.addBatch("SELECT * FROM batch_test;");
assertThrows(BatchUpdateException.class, () -> stmt.executeBatch());
int[] updateCounts = stmt.executeBatch();
assertThat(updateCounts).isEmpty();
}
/** Tests for isBatchCompatibleStatement method */
@Test
void testIsBatchCompatibleStatement_compatible_statements() {
JDBC4Statement jdbc4Stmt = (JDBC4Statement) stmt;
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("INSERT INTO table VALUES (1, 2);"));
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("insert into table values (1, 2);"));
assertTrue(
jdbc4Stmt.isBatchCompatibleStatement("INSERT INTO table (col1, col2) VALUES (1, 2);"));
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("INSERT OR REPLACE INTO table VALUES (1);"));
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("INSERT OR IGNORE INTO table VALUES (1);"));
assertTrue(jdbc4Stmt.isBatchCompatibleStatement(" INSERT INTO table VALUES (1);"));
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("\t\nINSERT INTO table VALUES (1);"));
assertTrue(jdbc4Stmt.isBatchCompatibleStatement(" \n\t INSERT INTO table VALUES (1);"));
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("/* comment */ INSERT INTO table VALUES (1);"));
assertTrue(
jdbc4Stmt.isBatchCompatibleStatement(
"/* multi\nline\ncomment */ INSERT INTO table VALUES (1);"));
assertTrue(
jdbc4Stmt.isBatchCompatibleStatement("-- line comment\nINSERT INTO table VALUES (1);"));
assertTrue(
jdbc4Stmt.isBatchCompatibleStatement(
"-- comment 1\n-- comment 2\nINSERT INTO table VALUES (1);"));
assertTrue(
jdbc4Stmt.isBatchCompatibleStatement(
" /* comment */ -- another\n INSERT INTO table VALUES (1);"));
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("UPDATE table SET col = 1;"));
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("update table set col = 1;"));
assertTrue(
jdbc4Stmt.isBatchCompatibleStatement("UPDATE table SET col1 = 1, col2 = 2 WHERE id = 3;"));
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("UPDATE OR REPLACE table SET col = 1;"));
assertTrue(jdbc4Stmt.isBatchCompatibleStatement(" UPDATE table SET col = 1;"));
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("\t\nUPDATE table SET col = 1;"));
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("/* comment */ UPDATE table SET col = 1;"));
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("-- comment\nUPDATE table SET col = 1;"));
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("DELETE FROM table;"));
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("delete from table;"));
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("DELETE FROM table WHERE id = 1;"));
assertTrue(jdbc4Stmt.isBatchCompatibleStatement(" DELETE FROM table;"));
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("\t\nDELETE FROM table;"));
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("/* comment */ DELETE FROM table;"));
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("-- comment\nDELETE FROM table;"));
}
@Test
void testIsBatchCompatibleStatement_non_compatible_statements() {
JDBC4Statement jdbc4Stmt = (JDBC4Statement) stmt;
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("SELECT * FROM table;"));
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("select * from table;"));
assertFalse(jdbc4Stmt.isBatchCompatibleStatement(" SELECT * FROM table;"));
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("/* comment */ SELECT * FROM table;"));
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("-- comment\nSELECT * FROM table;"));
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("EXPLAIN SELECT * FROM table;"));
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("EXPLAIN QUERY PLAN SELECT * FROM table;"));
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("PRAGMA table_info(table);"));
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("PRAGMA foreign_keys = ON;"));
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("ANALYZE;"));
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("ANALYZE table;"));
assertFalse(
jdbc4Stmt.isBatchCompatibleStatement(
"WITH cte AS (SELECT * FROM table) SELECT * FROM cte;"));
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("VACUUM;"));
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("VALUES (1, 2), (3, 4);"));
}
@Test
void testIsBatchCompatibleStatement_edge_cases() {
JDBC4Statement jdbc4Stmt = (JDBC4Statement) stmt;
assertFalse(jdbc4Stmt.isBatchCompatibleStatement(null));
assertFalse(jdbc4Stmt.isBatchCompatibleStatement(""));
assertFalse(jdbc4Stmt.isBatchCompatibleStatement(" "));
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("\t\n"));
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("/* comment only */"));
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("-- comment only"));
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("/* comment */ -- another comment"));
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("SELECT * FROM table WHERE name = 'INSERT';"));
assertFalse(
jdbc4Stmt.isBatchCompatibleStatement("SELECT * FROM table WHERE action = 'DELETE';"));
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("INSER INTO table VALUES (1);"));
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("UPDAT table SET col = 1;"));
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("DELET FROM table;"));
}
@Test
void testIsBatchCompatibleStatement_case_insensitive() {
JDBC4Statement jdbc4Stmt = (JDBC4Statement) stmt;
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("Insert INTO table VALUES (1);"));
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("InSeRt INTO table VALUES (1);"));
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("UPDATE table SET col = 1;"));
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("UpDaTe table SET col = 1;"));
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("Delete FROM table;"));
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("DeLeTe FROM table;"));
}
}

View File

@@ -1 +0,0 @@
{"rustc_fingerprint":11551670960185020797,"outputs":{"14427667104029986310":{"success":true,"status":"","code":0,"stdout":"rustc 1.83.0 (90b35a623 2024-11-26)\nbinary: rustc\ncommit-hash: 90b35a6239c3d8bdabc530a6a0816f7ff89a0aaf\ncommit-date: 2024-11-26\nhost: x86_64-unknown-linux-gnu\nrelease: 1.83.0\nLLVM version: 19.1.1\n","stderr":""},"11399821309745579047":{"success":true,"status":"","code":0,"stdout":"___\nlib___.rlib\nlib___.so\nlib___.so\nlib___.a\nlib___.so\n/home/merlin/.rustup/toolchains/1.83.0-x86_64-unknown-linux-gnu\noff\npacked\nunpacked\n___\ndebug_assertions\npanic=\"unwind\"\nproc_macro\ntarget_abi=\"\"\ntarget_arch=\"x86_64\"\ntarget_endian=\"little\"\ntarget_env=\"gnu\"\ntarget_family=\"unix\"\ntarget_feature=\"fxsr\"\ntarget_feature=\"sse\"\ntarget_feature=\"sse2\"\ntarget_has_atomic=\"16\"\ntarget_has_atomic=\"32\"\ntarget_has_atomic=\"64\"\ntarget_has_atomic=\"8\"\ntarget_has_atomic=\"ptr\"\ntarget_os=\"linux\"\ntarget_pointer_width=\"64\"\ntarget_vendor=\"unknown\"\nunix\n","stderr":""}},"successes":{}}

View File

@@ -569,7 +569,7 @@ impl turso_core::DatabaseStorage for DatabaseFile {
fn read_page(
&self,
page_idx: usize,
_key: Option<&turso_core::EncryptionKey>,
_encryption_ctx: Option<&turso_core::EncryptionContext>,
c: turso_core::Completion,
) -> turso_core::Result<turso_core::Completion> {
let r = c.as_read();
@@ -586,7 +586,7 @@ impl turso_core::DatabaseStorage for DatabaseFile {
&self,
page_idx: usize,
buffer: Arc<turso_core::Buffer>,
_key: Option<&turso_core::EncryptionKey>,
_encryption_ctx: Option<&turso_core::EncryptionContext>,
c: turso_core::Completion,
) -> turso_core::Result<turso_core::Completion> {
let size = buffer.len();
@@ -599,7 +599,7 @@ impl turso_core::DatabaseStorage for DatabaseFile {
first_page_idx: usize,
page_size: usize,
buffers: Vec<Arc<turso_core::Buffer>>,
_key: Option<&turso_core::EncryptionKey>,
_encryption_ctx: Option<&turso_core::EncryptionContext>,
c: turso_core::Completion,
) -> turso_core::Result<turso_core::Completion> {
let pos = first_page_idx.saturating_sub(1) * page_size;

View File

@@ -75,7 +75,7 @@ use std::{
};
#[cfg(feature = "fs")]
use storage::database::DatabaseFile;
pub use storage::encryption::EncryptionKey;
pub use storage::encryption::{EncryptionKey, EncryptionContext};
use storage::page_cache::DumbLruPageCache;
use storage::pager::{AtomicDbState, DbState};
use storage::sqlite3_ondisk::PageSize;
@@ -1955,11 +1955,11 @@ impl Connection {
self.syms.borrow().vtab_modules.keys().cloned().collect()
}
pub fn set_encryption_key(&self, key: Option<EncryptionKey>) {
pub fn set_encryption_key(&self, key: EncryptionKey) {
tracing::trace!("setting encryption key for connection");
*self.encryption_key.borrow_mut() = key.clone();
*self.encryption_key.borrow_mut() = Some(key.clone());
let pager = self.pager.borrow();
pager.set_encryption_key(key);
pager.set_encryption_context(&key);
}
}

View File

@@ -255,7 +255,8 @@ impl Schema {
}
/// Returns whether any index is registered for `table_name`.
/// Identifiers are stored normalized, so the probe is normalized before lookup.
pub fn table_has_indexes(&self, table_name: &str) -> bool {
    self.has_indexes.contains(&normalize_ident(table_name))
}
pub fn table_set_has_index(&mut self, table_name: &str) {

View File

@@ -1,5 +1,5 @@
use crate::error::LimboError;
use crate::storage::encryption::{decrypt_page, encrypt_page, EncryptionKey};
use crate::storage::encryption::EncryptionContext;
use crate::{io::Completion, Buffer, CompletionError, Result};
use std::sync::Arc;
use tracing::{instrument, Level};
@@ -15,14 +15,14 @@ pub trait DatabaseStorage: Send + Sync {
fn read_page(
&self,
page_idx: usize,
encryption_key: Option<&EncryptionKey>,
encryption_ctx: Option<&EncryptionContext>,
c: Completion,
) -> Result<Completion>;
fn write_page(
&self,
page_idx: usize,
buffer: Arc<Buffer>,
encryption_key: Option<&EncryptionKey>,
encryption_ctx: Option<&EncryptionContext>,
c: Completion,
) -> Result<Completion>;
fn write_pages(
@@ -30,7 +30,7 @@ pub trait DatabaseStorage: Send + Sync {
first_page_idx: usize,
page_size: usize,
buffers: Vec<Arc<Buffer>>,
encryption_key: Option<&EncryptionKey>,
encryption_ctx: Option<&EncryptionContext>,
c: Completion,
) -> Result<Completion>;
fn sync(&self, c: Completion) -> Result<Completion>;
@@ -59,7 +59,7 @@ impl DatabaseStorage for DatabaseFile {
fn read_page(
&self,
page_idx: usize,
encryption_key: Option<&EncryptionKey>,
encryption_ctx: Option<&EncryptionContext>,
c: Completion,
) -> Result<Completion> {
let r = c.as_read();
@@ -70,8 +70,8 @@ impl DatabaseStorage for DatabaseFile {
}
let pos = (page_idx - 1) * size;
if let Some(key) = encryption_key {
let key_clone = key.clone();
if let Some(ctx) = encryption_ctx {
let encryption_ctx = ctx.clone();
let read_buffer = r.buf_arc();
let original_c = c.clone();
@@ -81,7 +81,7 @@ impl DatabaseStorage for DatabaseFile {
return;
};
if bytes_read > 0 {
match decrypt_page(buf.as_slice(), page_idx, &key_clone) {
match encryption_ctx.decrypt_page(buf.as_slice(), page_idx) {
Ok(decrypted_data) => {
let original_buf = original_c.as_read().buf();
original_buf.as_mut_slice().copy_from_slice(&decrypted_data);
@@ -111,7 +111,7 @@ impl DatabaseStorage for DatabaseFile {
&self,
page_idx: usize,
buffer: Arc<Buffer>,
encryption_key: Option<&EncryptionKey>,
encryption_ctx: Option<&EncryptionContext>,
c: Completion,
) -> Result<Completion> {
let buffer_size = buffer.len();
@@ -121,8 +121,8 @@ impl DatabaseStorage for DatabaseFile {
assert_eq!(buffer_size & (buffer_size - 1), 0);
let pos = (page_idx - 1) * buffer_size;
let buffer = {
if let Some(key) = encryption_key {
encrypt_buffer(page_idx, buffer, key)
if let Some(ctx) = encryption_ctx {
encrypt_buffer(page_idx, buffer, ctx)
} else {
buffer
}
@@ -135,7 +135,7 @@ impl DatabaseStorage for DatabaseFile {
first_page_idx: usize,
page_size: usize,
buffers: Vec<Arc<Buffer>>,
encryption_key: Option<&EncryptionKey>,
encryption_key: Option<&EncryptionContext>,
c: Completion,
) -> Result<Completion> {
assert!(first_page_idx > 0);
@@ -145,11 +145,11 @@ impl DatabaseStorage for DatabaseFile {
let pos = (first_page_idx - 1) * page_size;
let buffers = {
if let Some(key) = encryption_key {
if let Some(ctx) = encryption_key {
buffers
.into_iter()
.enumerate()
.map(|(i, buffer)| encrypt_buffer(first_page_idx + i, buffer, key))
.map(|(i, buffer)| encrypt_buffer(first_page_idx + i, buffer, ctx))
.collect::<Vec<_>>()
} else {
buffers
@@ -184,7 +184,11 @@ impl DatabaseFile {
}
}
fn encrypt_buffer(page_idx: usize, buffer: Arc<Buffer>, key: &EncryptionKey) -> Arc<Buffer> {
let encrypted_data = encrypt_page(buffer.as_slice(), page_idx, key).unwrap();
fn encrypt_buffer(
page_idx: usize,
buffer: Arc<Buffer>,
ctx: &EncryptionContext,
) -> Arc<Buffer> {
let encrypted_data = ctx.encrypt_page(buffer.as_slice(), page_idx).unwrap();
Arc::new(Buffer::new(encrypted_data.to_vec()))
}

View File

@@ -1,7 +1,5 @@
#![allow(unused_variables, dead_code)]
#[cfg(not(feature = "encryption"))]
use crate::LimboError;
use crate::Result;
use crate::{LimboError, Result};
use aes_gcm::{
aead::{Aead, AeadCore, KeyInit, OsRng},
Aes256Gcm, Key, Nonce,
@@ -11,6 +9,7 @@ use std::ops::Deref;
pub const ENCRYPTION_METADATA_SIZE: usize = 28;
pub const ENCRYPTED_PAGE_SIZE: usize = 4096;
pub const ENCRYPTION_NONCE_SIZE: usize = 12;
pub const ENCRYPTION_TAG_SIZE: usize = 16;
#[repr(transparent)]
#[derive(Clone)]
@@ -71,106 +70,195 @@ impl Drop for EncryptionKey {
}
}
#[cfg(not(feature = "encryption"))]
pub fn encrypt_page(page: &[u8], page_id: usize, key: &EncryptionKey) -> Result<Vec<u8>> {
Err(LimboError::InvalidArgument(
"encryption is not enabled, cannot encrypt page. enable via passing `--features encryption`".into(),
))
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CipherMode {
Aes256Gcm,
}
#[cfg(feature = "encryption")]
pub fn encrypt_page(page: &[u8], page_id: usize, key: &EncryptionKey) -> Result<Vec<u8>> {
if page_id == 1 {
tracing::debug!("skipping encryption for page 1 (database header)");
return Ok(page.to_vec());
impl CipherMode {
/// Every cipher requires a specific key size. For 256-bit algorithms, this is 32 bytes.
/// For 128-bit algorithms, it would be 16 bytes, etc.
pub fn required_key_size(&self) -> usize {
match self {
CipherMode::Aes256Gcm => 32,
}
}
tracing::debug!("encrypting page {}", page_id);
assert_eq!(
page.len(),
ENCRYPTED_PAGE_SIZE,
"Page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes"
);
let reserved_bytes = &page[ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE..];
let reserved_bytes_zeroed = reserved_bytes.iter().all(|&b| b == 0);
assert!(
reserved_bytes_zeroed,
"last reserved bytes must be empty/zero, but found non-zero bytes"
);
let payload = &page[..ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE];
let (encrypted, nonce) = encrypt(payload, key)?;
assert_eq!(
encrypted.len(),
ENCRYPTED_PAGE_SIZE - nonce.len(),
"Encrypted page must be exactly {} bytes",
ENCRYPTED_PAGE_SIZE - nonce.len()
);
let mut result = Vec::with_capacity(ENCRYPTED_PAGE_SIZE);
result.extend_from_slice(&encrypted);
result.extend_from_slice(&nonce);
assert_eq!(
result.len(),
ENCRYPTED_PAGE_SIZE,
"Encrypted page must be exactly {ENCRYPTED_PAGE_SIZE} bytes"
);
Ok(result)
}
#[cfg(not(feature = "encryption"))]
pub fn decrypt_page(encrypted_page: &[u8], page_id: usize, key: &EncryptionKey) -> Result<Vec<u8>> {
Err(LimboError::InvalidArgument(
"encryption is not enabled, cannot decrypt page. enable via passing `--features encryption`".into(),
))
}
#[cfg(feature = "encryption")]
pub fn decrypt_page(encrypted_page: &[u8], page_id: usize, key: &EncryptionKey) -> Result<Vec<u8>> {
if page_id == 1 {
tracing::debug!("skipping decryption for page 1 (database header)");
return Ok(encrypted_page.to_vec());
/// Returns the nonce size for this cipher mode. Though most AEAD ciphers use 12-byte nonces.
pub fn nonce_size(&self) -> usize {
match self {
CipherMode::Aes256Gcm => ENCRYPTION_NONCE_SIZE,
}
}
tracing::debug!("decrypting page {}", page_id);
assert_eq!(
encrypted_page.len(),
ENCRYPTED_PAGE_SIZE,
"Encrypted page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes"
);
let nonce_start = encrypted_page.len() - ENCRYPTION_NONCE_SIZE;
let payload = &encrypted_page[..nonce_start];
let nonce = &encrypted_page[nonce_start..];
let decrypted_data = decrypt(payload, nonce, key)?;
assert_eq!(
decrypted_data.len(),
ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE,
"Decrypted page data must be exactly {} bytes",
ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE
);
let mut result = Vec::with_capacity(ENCRYPTED_PAGE_SIZE);
result.extend_from_slice(&decrypted_data);
result.resize(ENCRYPTED_PAGE_SIZE, 0);
assert_eq!(
result.len(),
ENCRYPTED_PAGE_SIZE,
"Decrypted page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes"
);
Ok(result)
/// Returns the authentication tag size for this cipher mode. All common AEAD ciphers use 16-byte tags.
pub fn tag_size(&self) -> usize {
match self {
CipherMode::Aes256Gcm => ENCRYPTION_TAG_SIZE,
}
}
}
fn encrypt(plaintext: &[u8], key: &EncryptionKey) -> Result<(Vec<u8>, Vec<u8>)> {
let key: &Key<Aes256Gcm> = key.as_ref().into();
let cipher = Aes256Gcm::new(key);
let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
let ciphertext = cipher.encrypt(&nonce, plaintext).unwrap();
Ok((ciphertext, nonce.to_vec()))
#[derive(Clone)]
pub enum Cipher {
Aes256Gcm(Box<Aes256Gcm>),
}
fn decrypt(ciphertext: &[u8], nonce: &[u8], key: &EncryptionKey) -> Result<Vec<u8>> {
let key: &Key<Aes256Gcm> = key.as_ref().into();
let cipher = Aes256Gcm::new(key);
let nonce = Nonce::from_slice(nonce);
let plaintext = cipher.decrypt(nonce, ciphertext).unwrap();
Ok(plaintext)
impl std::fmt::Debug for Cipher {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Cipher::Aes256Gcm(_) => write!(f, "Cipher::Aes256Gcm"),
}
}
}
#[derive(Clone)]
pub struct EncryptionContext {
cipher_mode: CipherMode,
cipher: Cipher,
}
impl EncryptionContext {
/// Builds an encryption context from `key`, validating the key length against
/// the cipher's requirement (32 bytes for AES-256-GCM) before constructing the
/// cipher instance.
///
/// # Errors
/// Returns `LimboError::InvalidArgument` when the key length does not match
/// the cipher mode's required key size.
pub fn new(key: &EncryptionKey) -> Result<Self> {
// AES-256-GCM is currently the only supported cipher mode.
let cipher_mode = CipherMode::Aes256Gcm;
let required_size = cipher_mode.required_key_size();
if key.as_slice().len() != required_size {
return Err(crate::LimboError::InvalidArgument(format!(
"Invalid key size for {:?}: expected {} bytes, got {}",
cipher_mode,
required_size,
key.as_slice().len()
)));
}
let cipher = match cipher_mode {
CipherMode::Aes256Gcm => {
let cipher_key: &Key<Aes256Gcm> = key.as_ref().into();
Cipher::Aes256Gcm(Box::new(Aes256Gcm::new(cipher_key)))
}
};
Ok(Self {
cipher_mode,
cipher,
})
}
pub fn cipher_mode(&self) -> CipherMode {
self.cipher_mode
}
#[cfg(feature = "encryption")]
pub fn encrypt_page(&self, page: &[u8], page_id: usize) -> Result<Vec<u8>> {
if page_id == 1 {
tracing::debug!("skipping encryption for page 1 (database header)");
return Ok(page.to_vec());
}
tracing::debug!("encrypting page {}", page_id);
assert_eq!(
page.len(),
ENCRYPTED_PAGE_SIZE,
"Page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes"
);
let reserved_bytes = &page[ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE..];
let reserved_bytes_zeroed = reserved_bytes.iter().all(|&b| b == 0);
assert!(
reserved_bytes_zeroed,
"last reserved bytes must be empty/zero, but found non-zero bytes"
);
let payload = &page[..ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE];
let (encrypted, nonce) = self.encrypt_raw(payload)?;
assert_eq!(
encrypted.len(),
ENCRYPTED_PAGE_SIZE - nonce.len(),
"Encrypted page must be exactly {} bytes",
ENCRYPTED_PAGE_SIZE - nonce.len()
);
let mut result = Vec::with_capacity(ENCRYPTED_PAGE_SIZE);
result.extend_from_slice(&encrypted);
result.extend_from_slice(&nonce);
assert_eq!(
result.len(),
ENCRYPTED_PAGE_SIZE,
"Encrypted page must be exactly {ENCRYPTED_PAGE_SIZE} bytes"
);
Ok(result)
}
#[cfg(feature = "encryption")]
pub fn decrypt_page(&self, encrypted_page: &[u8], page_id: usize) -> Result<Vec<u8>> {
if page_id == 1 {
tracing::debug!("skipping decryption for page 1 (database header)");
return Ok(encrypted_page.to_vec());
}
tracing::debug!("decrypting page {}", page_id);
assert_eq!(
encrypted_page.len(),
ENCRYPTED_PAGE_SIZE,
"Encrypted page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes"
);
let nonce_start = encrypted_page.len() - ENCRYPTION_NONCE_SIZE;
let payload = &encrypted_page[..nonce_start];
let nonce = &encrypted_page[nonce_start..];
let decrypted_data = self.decrypt_raw(payload, nonce)?;
assert_eq!(
decrypted_data.len(),
ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE,
"Decrypted page data must be exactly {} bytes",
ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE
);
let mut result = Vec::with_capacity(ENCRYPTED_PAGE_SIZE);
result.extend_from_slice(&decrypted_data);
result.resize(ENCRYPTED_PAGE_SIZE, 0);
assert_eq!(
result.len(),
ENCRYPTED_PAGE_SIZE,
"Decrypted page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes"
);
Ok(result)
}
/// encrypts raw data using the configured cipher, returns ciphertext and nonce
fn encrypt_raw(&self, plaintext: &[u8]) -> Result<(Vec<u8>, Vec<u8>)> {
match &self.cipher {
Cipher::Aes256Gcm(cipher) => {
let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
let ciphertext = cipher
.encrypt(&nonce, plaintext)
.map_err(|e| LimboError::InternalError(format!("Encryption failed: {e:?}")))?;
Ok((ciphertext, nonce.to_vec()))
}
}
}
fn decrypt_raw(&self, ciphertext: &[u8], nonce: &[u8]) -> Result<Vec<u8>> {
match &self.cipher {
Cipher::Aes256Gcm(cipher) => {
let nonce = Nonce::from_slice(nonce);
let plaintext = cipher.decrypt(nonce, ciphertext).map_err(|e| {
crate::LimboError::InternalError(format!("Decryption failed: {e:?}"))
})?;
Ok(plaintext)
}
}
}
#[cfg(not(feature = "encryption"))]
pub fn encrypt_page(&self, _page: &[u8], _page_id: usize) -> Result<Vec<u8>> {
Err(LimboError::InvalidArgument(
"encryption is not enabled, cannot encrypt page. enable via passing `--features encryption`".into(),
))
}
#[cfg(not(feature = "encryption"))]
pub fn decrypt_page(&self, _encrypted_page: &[u8], _page_id: usize) -> Result<Vec<u8>> {
Err(LimboError::InvalidArgument(
"encryption is not enabled, cannot decrypt page. enable via passing `--features encryption`".into(),
))
}
}
#[cfg(test)]
@@ -193,14 +281,15 @@ mod tests {
};
let key = EncryptionKey::from_string("alice and bob use encryption on database");
let ctx = EncryptionContext::new(&key).unwrap();
let page_id = 42;
let encrypted = encrypt_page(&page_data, page_id, &key).unwrap();
let encrypted = ctx.encrypt_page(&page_data, page_id).unwrap();
assert_eq!(encrypted.len(), ENCRYPTED_PAGE_SIZE);
assert_ne!(&encrypted[..data_size], &page_data[..data_size]);
assert_ne!(&encrypted[..], &page_data[..]);
let decrypted = decrypt_page(&encrypted, page_id, &key).unwrap();
let decrypted = ctx.decrypt_page(&encrypted, page_id).unwrap();
assert_eq!(decrypted.len(), ENCRYPTED_PAGE_SIZE);
assert_eq!(decrypted, page_data);
}

View File

@@ -28,7 +28,9 @@ use super::btree::{btree_init_page, BTreePage};
use super::page_cache::{CacheError, CacheResizeResult, DumbLruPageCache, PageCacheKey};
use super::sqlite3_ondisk::begin_write_btree_page;
use super::wal::CheckpointMode;
use crate::storage::encryption::{EncryptionKey, ENCRYPTION_METADATA_SIZE};
use crate::storage::encryption::{
EncryptionKey, EncryptionContext, ENCRYPTION_METADATA_SIZE,
};
/// SQLite's default maximum page count
const DEFAULT_MAX_PAGE_COUNT: u32 = 0xfffffffe;
@@ -347,7 +349,7 @@ pub enum BtreePageAllocMode {
/// This will keep track of the state of current cache commit in order to not repeat work
struct CommitInfo {
state: CommitState,
state: Cell<CommitState>,
}
/// Track the state of the auto-vacuum mode.
@@ -460,10 +462,10 @@ pub struct Pager {
pub io: Arc<dyn crate::io::IO>,
dirty_pages: Rc<RefCell<HashSet<usize, hash::BuildHasherDefault<hash::DefaultHasher>>>>,
commit_info: RefCell<CommitInfo>,
commit_info: CommitInfo,
checkpoint_state: RefCell<CheckpointState>,
syncing: Rc<Cell<bool>>,
auto_vacuum_mode: RefCell<AutoVacuumMode>,
auto_vacuum_mode: Cell<AutoVacuumMode>,
/// 0 -> Database is empty,
/// 1 -> Database is being initialized,
/// 2 -> Database is initialized and ready for use.
@@ -491,7 +493,7 @@ pub struct Pager {
header_ref_state: RefCell<HeaderRefState>,
#[cfg(not(feature = "omit_autovacuum"))]
btree_create_vacuum_full_state: Cell<BtreeCreateVacuumFullState>,
pub(crate) encryption_key: RefCell<Option<EncryptionKey>>,
pub(crate) encryption_ctx: RefCell<Option<EncryptionContext>>,
}
#[derive(Debug, Clone)]
@@ -571,13 +573,13 @@ impl Pager {
dirty_pages: Rc::new(RefCell::new(HashSet::with_hasher(
hash::BuildHasherDefault::new(),
))),
commit_info: RefCell::new(CommitInfo {
state: CommitState::Start,
}),
commit_info: CommitInfo {
state: CommitState::Start.into(),
},
syncing: Rc::new(Cell::new(false)),
checkpoint_state: RefCell::new(CheckpointState::Checkpoint),
buffer_pool,
auto_vacuum_mode: RefCell::new(AutoVacuumMode::None),
auto_vacuum_mode: Cell::new(AutoVacuumMode::None),
db_state,
init_lock,
allocate_page1_state,
@@ -593,7 +595,7 @@ impl Pager {
header_ref_state: RefCell::new(HeaderRefState::Start),
#[cfg(not(feature = "omit_autovacuum"))]
btree_create_vacuum_full_state: Cell::new(BtreeCreateVacuumFullState::Start),
encryption_key: RefCell::new(None),
encryption_ctx: RefCell::new(None),
})
}
@@ -620,11 +622,11 @@ impl Pager {
}
pub fn get_auto_vacuum_mode(&self) -> AutoVacuumMode {
*self.auto_vacuum_mode.borrow()
self.auto_vacuum_mode.get()
}
pub fn set_auto_vacuum_mode(&self, mode: AutoVacuumMode) {
*self.auto_vacuum_mode.borrow_mut() = mode;
self.auto_vacuum_mode.set(mode);
}
/// Retrieves the pointer map entry for a given database page.
@@ -840,8 +842,8 @@ impl Pager {
// If autovacuum is enabled, we need to allocate a new page number that is greater than the largest root page number
#[cfg(not(feature = "omit_autovacuum"))]
{
let auto_vacuum_mode = self.auto_vacuum_mode.borrow();
match *auto_vacuum_mode {
let auto_vacuum_mode = self.auto_vacuum_mode.get();
match auto_vacuum_mode {
AutoVacuumMode::None => {
let page =
return_if_io!(self.do_allocate_page(page_type, 0, BtreePageAllocMode::Any));
@@ -1092,7 +1094,7 @@ impl Pager {
page_idx,
page.clone(),
allow_empty_read,
self.encryption_key.borrow().as_ref(),
self.encryption_ctx.borrow().as_ref(),
)?;
return Ok((page, c));
};
@@ -1110,7 +1112,7 @@ impl Pager {
page_idx,
page.clone(),
allow_empty_read,
self.encryption_key.borrow().as_ref(),
self.encryption_ctx.borrow().as_ref(),
)?;
Ok((page, c))
}
@@ -1135,7 +1137,7 @@ impl Pager {
page_idx: usize,
page: PageRef,
allow_empty_read: bool,
encryption_key: Option<&EncryptionKey>,
encryption_key: Option<&EncryptionContext>,
) -> Result<Completion> {
sqlite3_ondisk::begin_read_page(
self.db_file.clone(),
@@ -1299,7 +1301,7 @@ impl Pager {
};
let mut checkpoint_result = CheckpointResult::default();
let res = loop {
let state = self.commit_info.borrow().state;
let state = self.commit_info.state.get();
trace!(?state);
match state {
CommitState::Start => {
@@ -1354,35 +1356,35 @@ impl Pager {
if completions.is_empty() {
return Ok(IOResult::Done(PagerCommitResult::WalWritten));
} else {
self.commit_info.borrow_mut().state = CommitState::SyncWal;
self.commit_info.state.set(CommitState::SyncWal);
io_yield_many!(completions);
}
}
CommitState::SyncWal => {
self.commit_info.borrow_mut().state = CommitState::AfterSyncWal;
self.commit_info.state.set(CommitState::AfterSyncWal);
let c = wal.borrow_mut().sync()?;
io_yield_one!(c);
}
CommitState::AfterSyncWal => {
turso_assert!(!wal.borrow().is_syncing(), "wal should have synced");
if wal_auto_checkpoint_disabled || !wal.borrow().should_checkpoint() {
self.commit_info.borrow_mut().state = CommitState::Start;
self.commit_info.state.set(CommitState::Start);
break PagerCommitResult::WalWritten;
}
self.commit_info.borrow_mut().state = CommitState::Checkpoint;
self.commit_info.state.set(CommitState::Checkpoint);
}
CommitState::Checkpoint => {
checkpoint_result = return_if_io!(self.checkpoint());
self.commit_info.borrow_mut().state = CommitState::SyncDbFile;
self.commit_info.state.set(CommitState::SyncDbFile);
}
CommitState::SyncDbFile => {
let c = sqlite3_ondisk::begin_sync(self.db_file.clone(), self.syncing.clone())?;
self.commit_info.borrow_mut().state = CommitState::AfterSyncDbFile;
self.commit_info.state.set(CommitState::AfterSyncDbFile);
io_yield_one!(c);
}
CommitState::AfterSyncDbFile => {
turso_assert!(!self.syncing.get(), "should have finished syncing");
self.commit_info.borrow_mut().state = CommitState::Start;
self.commit_info.state.set(CommitState::Start);
break PagerCommitResult::Checkpointed(checkpoint_result);
}
}
@@ -1724,7 +1726,7 @@ impl Pager {
default_header.database_size = 1.into();
// if a key is set, then we will reserve space for encryption metadata
if self.encryption_key.borrow().is_some() {
if self.encryption_ctx.borrow().is_some() {
default_header.reserved_space = ENCRYPTION_METADATA_SIZE as u8;
}
@@ -1817,7 +1819,7 @@ impl Pager {
// If the following conditions are met, allocate a pointer map page, add to cache and increment the database size
// - autovacuum is enabled
// - the last page is a pointer map page
if matches!(*self.auto_vacuum_mode.borrow(), AutoVacuumMode::Full)
if matches!(self.auto_vacuum_mode.get(), AutoVacuumMode::Full)
&& is_ptrmap_page(new_db_size + 1, header.page_size.get() as usize)
{
// we will allocate a ptrmap page, so increment size
@@ -2083,9 +2085,7 @@ impl Pager {
fn reset_internal_states(&self) {
self.checkpoint_state.replace(CheckpointState::Checkpoint);
self.syncing.replace(false);
self.commit_info.replace(CommitInfo {
state: CommitState::Start,
});
self.commit_info.state.set(CommitState::Start);
self.allocate_page_state.replace(AllocatePageState::Start);
self.free_page_state.replace(FreePageState::Start);
#[cfg(not(feature = "omit_autovacuum"))]
@@ -2111,10 +2111,11 @@ impl Pager {
Ok(IOResult::Done(f(header)))
}
pub fn set_encryption_key(&self, key: Option<EncryptionKey>) {
self.encryption_key.replace(key.clone());
pub fn set_encryption_context(&self, key: &EncryptionKey) {
let encryption_ctx = EncryptionContext::new(key).unwrap();
self.encryption_ctx.replace(Some(encryption_ctx.clone()));
let Some(wal) = self.wal.as_ref() else { return };
wal.borrow_mut().set_encryption_key(key)
wal.borrow_mut().set_encryption_context(encryption_ctx)
}
}

View File

@@ -59,7 +59,7 @@ use crate::storage::btree::offset::{
use crate::storage::btree::{payload_overflow_threshold_max, payload_overflow_threshold_min};
use crate::storage::buffer_pool::BufferPool;
use crate::storage::database::DatabaseStorage;
use crate::storage::encryption::EncryptionKey;
use crate::storage::encryption::EncryptionContext;
use crate::storage::pager::Pager;
use crate::storage::wal::READMARK_NOT_USED;
use crate::types::{RawSlice, RefValue, SerialType, SerialTypeKind, TextRef, TextSubtype};
@@ -870,7 +870,7 @@ pub fn begin_read_page(
page: PageRef,
page_idx: usize,
allow_empty_read: bool,
encryption_key: Option<&EncryptionKey>,
encryption_key: Option<&EncryptionContext>,
) -> Result<Completion> {
tracing::trace!("begin_read_btree_page(page_idx = {})", page_idx);
let buf = buffer_pool.get_page();
@@ -965,7 +965,7 @@ pub fn write_pages_vectored(
pager: &Pager,
batch: BTreeMap<usize, Arc<Buffer>>,
done_flag: Arc<AtomicBool>,
encryption_key: Option<&EncryptionKey>,
encryption_key: Option<&EncryptionContext>,
) -> Result<Vec<Completion>> {
if batch.is_empty() {
done_flag.store(true, Ordering::Relaxed);

View File

@@ -17,7 +17,7 @@ use super::sqlite3_ondisk::{self, checksum_wal, WalHeader, WAL_MAGIC_BE, WAL_MAG
use crate::fast_lock::SpinLock;
use crate::io::{clock, File, IO};
use crate::result::LimboResult;
use crate::storage::encryption::{decrypt_page, encrypt_page, EncryptionKey};
use crate::storage::encryption::EncryptionContext;
use crate::storage::sqlite3_ondisk::{
begin_read_wal_frame, begin_read_wal_frame_raw, finish_read_page, prepare_wal_frame,
write_pages_vectored, PageSize, WAL_FRAME_HEADER_SIZE, WAL_HEADER_SIZE,
@@ -297,7 +297,7 @@ pub trait Wal: Debug {
/// Return unique set of pages changed **after** frame_watermark position and until current WAL session max_frame_no
fn changed_pages_after(&self, frame_watermark: u64) -> Result<Vec<u32>>;
fn set_encryption_key(&mut self, key: Option<EncryptionKey>);
fn set_encryption_context(&mut self, ctx: EncryptionContext);
#[cfg(debug_assertions)]
fn as_any(&self) -> &dyn std::any::Any;
@@ -568,7 +568,7 @@ pub struct WalFile {
/// Manages locks needed for checkpointing
checkpoint_guard: Option<CheckpointLocks>,
encryption_key: RefCell<Option<EncryptionKey>>,
encryption_ctx: RefCell<Option<EncryptionContext>>,
}
impl fmt::Debug for WalFile {
@@ -1034,7 +1034,7 @@ impl Wal for WalFile {
page.set_locked();
let frame = page.clone();
let page_idx = page.get().id;
let key = self.encryption_key.borrow().clone();
let encryption_ctx = self.encryption_ctx.borrow().clone();
let seq = self.header.checkpoint_seq;
let complete = Box::new(move |res: Result<(Arc<Buffer>, i32), CompletionError>| {
let Ok((buf, bytes_read)) = res else {
@@ -1047,8 +1047,8 @@ impl Wal for WalFile {
"read({bytes_read}) less than expected({buf_len}): frame_id={frame_id}"
);
let cloned = frame.clone();
if let Some(key) = key.clone() {
match decrypt_page(buf.as_slice(), page_idx, &key) {
if let Some(ctx) = encryption_ctx.clone() {
match ctx.decrypt_page(buf.as_slice(), page_idx) {
Ok(decrypted_data) => {
buf.as_mut_slice().copy_from_slice(&decrypted_data);
}
@@ -1213,15 +1213,15 @@ impl Wal for WalFile {
let page_content = page.get_contents();
let page_buf = page_content.as_ptr();
let key = self.encryption_key.borrow();
let encryption_ctx = self.encryption_ctx.borrow();
let encrypted_data = {
if let Some(key) = key.as_ref() {
Some(encrypt_page(page_buf, page_id, key)?)
if let Some(key) = encryption_ctx.as_ref() {
Some(key.encrypt_page(page_buf, page_id)?)
} else {
None
}
};
let data_to_write = if key.as_ref().is_some() {
let data_to_write = if encryption_ctx.as_ref().is_some() {
encrypted_data.as_ref().unwrap().as_slice()
} else {
page_buf
@@ -1374,8 +1374,8 @@ impl Wal for WalFile {
self
}
fn set_encryption_key(&mut self, key: Option<EncryptionKey>) {
self.encryption_key.replace(key);
fn set_encryption_context(&mut self, ctx: EncryptionContext) {
self.encryption_ctx.replace(Some(ctx));
}
}
@@ -1413,7 +1413,7 @@ impl WalFile {
prev_checkpoint: CheckpointResult::default(),
checkpoint_guard: None,
header: *header,
encryption_key: RefCell::new(None),
encryption_ctx: RefCell::new(None),
}
}
@@ -1665,7 +1665,7 @@ impl WalFile {
pager,
batch_map,
done_flag,
self.encryption_key.borrow().as_ref(),
self.encryption_ctx.borrow().as_ref(),
)?);
}
}

View File

@@ -315,7 +315,7 @@ fn update_pragma(
PragmaName::EncryptionKey => {
let value = parse_string(&value)?;
let key = EncryptionKey::from_string(&value);
connection.set_encryption_key(Some(key));
connection.set_encryption_key(key);
Ok((program, TransactionMode::None))
}
}

View File

@@ -132,6 +132,7 @@ pub fn prepare_update_plan(
Some(table) => table,
None => bail_parse_error!("Parse error: no such table: {}", table_name),
};
let table_name = table.get_name();
let iter_dir = body
.order_by
.first()
@@ -149,7 +150,7 @@ pub fn prepare_update_plan(
Table::BTree(btree_table) => Table::BTree(btree_table.clone()),
_ => unreachable!(),
},
identifier: table_name.as_str().to_string(),
identifier: table_name.to_string(),
internal_id: program.table_reference_counter.next(),
op: build_scan_op(&table, iter_dir),
join_info: None,
@@ -235,7 +236,7 @@ pub fn prepare_update_plan(
Table::BTree(btree_table) => Table::BTree(btree_table.clone()),
_ => unreachable!(),
},
identifier: table_name.as_str().to_string(),
identifier: table_name.to_string(),
internal_id,
op: build_scan_op(&table, iter_dir),
join_info: None,
@@ -334,7 +335,7 @@ pub fn prepare_update_plan(
// Check what indexes will need to be updated by checking set_clauses and see
// if a column is contained in an index.
let indexes = schema.get_indices(table_name.as_str());
let indexes = schema.get_indices(table_name);
let rowid_alias_used = set_clauses
.iter()
.any(|(idx, _)| columns[*idx].is_rowid_alias);

View File

@@ -31,7 +31,7 @@ macro_rules! peek_expect {
expected: &[
$($x,)*
],
got: token.token_type.unwrap(),
got: tt,
})
}
}
@@ -223,10 +223,11 @@ impl<'a> Parser<'a> {
}
Some(token) => {
if !found_semi {
let tt = token.token_type.unwrap();
return Err(Error::ParseUnexpectedToken {
parsed_offset: (self.offset(), 1).into(),
expected: &[TK_SEMI],
got: token.token_type.unwrap(),
got: tt,
});
}
@@ -253,7 +254,7 @@ impl<'a> Parser<'a> {
}
}
fn next_token(&mut self) -> Result<Option<Token<'a>>> {
fn next_token(&mut self) -> Result<Option<&Token<'a>>> {
debug_assert!(!self.peekable);
let mut next = self.consume_lexer_without_whitespaces_or_comments();
@@ -479,9 +480,9 @@ impl<'a> Parser<'a> {
match next {
None => Ok(None), // EOF
Some(Ok(tok)) => {
self.current_token = tok.clone();
self.current_token = tok;
self.peekable = true;
Ok(Some(tok))
Ok(Some(&self.current_token))
}
Some(Err(err)) => Err(err),
}
@@ -520,16 +521,21 @@ impl<'a> Parser<'a> {
/// Get the next token from the lexer
#[inline]
fn eat(&mut self) -> Result<Option<Token<'a>>> {
let result = self.peek()?;
self.peekable = false; // Clear the peek mark after consuming
Ok(result)
match self.peek()? {
None => Ok(None),
Some(tok) => {
let result = tok.clone();
self.peekable = false; // Clear the peek mark after consuming
Ok(Some(result))
}
}
}
/// Peek at the next token without consuming it
#[inline]
fn peek(&mut self) -> Result<Option<Token<'a>>> {
fn peek(&mut self) -> Result<Option<&Token<'a>>> {
if self.peekable {
return Ok(Some(self.current_token.clone()));
return Ok(Some(&self.current_token));
}
self.next_token()
@@ -544,7 +550,7 @@ impl<'a> Parser<'a> {
}
#[inline]
fn peek_no_eof(&mut self) -> Result<Token<'a>> {
fn peek_no_eof(&mut self) -> Result<&Token<'a>> {
match self.peek()? {
None => Err(Error::ParseUnexpectedEOF),
Some(token) => Ok(token),
@@ -966,7 +972,7 @@ impl<'a> Parser<'a> {
let mut type_name = if let Some(tok) = self.peek()? {
match tok.token_type.unwrap().fallback_id_if_ok() {
TK_ID | TK_STRING => {
eat_assert!(self, TK_ID, TK_STRING);
let tok = eat_assert!(self, TK_ID, TK_STRING);
from_bytes(tok.value)
}
_ => return Ok(None),
@@ -978,7 +984,7 @@ impl<'a> Parser<'a> {
while let Some(tok) = self.peek()? {
match tok.token_type.unwrap().fallback_id_if_ok() {
TK_ID | TK_STRING => {
eat_assert!(self, TK_ID, TK_STRING);
let tok = eat_assert!(self, TK_ID, TK_STRING);
type_name.push(' ');
type_name.push_str(from_bytes_as_str(tok.value));
}
@@ -1324,25 +1330,25 @@ impl<'a> Parser<'a> {
Ok(Box::new(Expr::Literal(Literal::Null)))
}
TK_BLOB => {
eat_assert!(self, TK_BLOB);
let tok = eat_assert!(self, TK_BLOB);
Ok(Box::new(Expr::Literal(Literal::Blob(from_bytes(
tok.value,
)))))
}
TK_FLOAT => {
eat_assert!(self, TK_FLOAT);
let tok = eat_assert!(self, TK_FLOAT);
Ok(Box::new(Expr::Literal(Literal::Numeric(from_bytes(
tok.value,
)))))
}
TK_INTEGER => {
eat_assert!(self, TK_INTEGER);
let tok = eat_assert!(self, TK_INTEGER);
Ok(Box::new(Expr::Literal(Literal::Numeric(from_bytes(
tok.value,
)))))
}
TK_VARIABLE => {
eat_assert!(self, TK_VARIABLE);
let tok = eat_assert!(self, TK_VARIABLE);
Ok(Box::new(Expr::Variable(from_bytes(tok.value))))
}
TK_CAST => {