mirror of
https://github.com/aljazceru/turso.git
synced 2026-02-23 17:05:36 +01:00
Merge branch 'main' into clean-parser-4
This commit is contained in:
@@ -176,6 +176,11 @@ public final class JDBC4PreparedStatement extends JDBC4Statement implements Prep
|
||||
// TODO
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addBatch(String sql) throws SQLException {
|
||||
throw new SQLException("addBatch(String) cannot be called on a PreparedStatement");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setCharacterStream(int parameterIndex, Reader reader, int length)
|
||||
throws SQLException {}
|
||||
|
||||
@@ -2,12 +2,16 @@ package tech.turso.jdbc4;
|
||||
|
||||
import static java.util.Objects.requireNonNull;
|
||||
|
||||
import java.sql.BatchUpdateException;
|
||||
import java.sql.Connection;
|
||||
import java.sql.ResultSet;
|
||||
import java.sql.SQLException;
|
||||
import java.sql.SQLWarning;
|
||||
import java.sql.Statement;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.locks.ReentrantLock;
|
||||
import java.util.regex.Pattern;
|
||||
import tech.turso.annotations.Nullable;
|
||||
import tech.turso.annotations.SkipNullableCheck;
|
||||
import tech.turso.core.TursoResultSet;
|
||||
@@ -15,6 +19,20 @@ import tech.turso.core.TursoStatement;
|
||||
|
||||
public class JDBC4Statement implements Statement {
|
||||
|
||||
private static final Pattern BATCH_COMPATIBLE_PATTERN =
|
||||
Pattern.compile(
|
||||
"^\\s*"
|
||||
+ // Leading whitespace
|
||||
"(?:/\\*.*?\\*/\\s*)*"
|
||||
+ // Optional C-style comments
|
||||
"(?:--[^\\n]*\\n\\s*)*"
|
||||
+ // Optional SQL line comments
|
||||
"(?:"
|
||||
+ // Start of keywords group
|
||||
"INSERT|UPDATE|DELETE"
|
||||
+ ")\\b",
|
||||
Pattern.CASE_INSENSITIVE | Pattern.DOTALL);
|
||||
|
||||
private final JDBC4Connection connection;
|
||||
|
||||
@Nullable protected TursoStatement statement = null;
|
||||
@@ -33,6 +51,12 @@ public class JDBC4Statement implements Statement {
|
||||
|
||||
private ReentrantLock connectionLock = new ReentrantLock();
|
||||
|
||||
/**
|
||||
* List of SQL statements to be executed as a batch. Used for batch processing as per JDBC
|
||||
* specification.
|
||||
*/
|
||||
private List<String> batchCommands = new ArrayList<>();
|
||||
|
||||
public JDBC4Statement(JDBC4Connection connection) {
|
||||
this(
|
||||
connection,
|
||||
@@ -232,18 +256,82 @@ public class JDBC4Statement implements Statement {
|
||||
|
||||
@Override
|
||||
public void addBatch(String sql) throws SQLException {
|
||||
// TODO
|
||||
ensureOpen();
|
||||
if (sql == null) {
|
||||
throw new SQLException("SQL command cannot be null");
|
||||
}
|
||||
batchCommands.add(sql);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void clearBatch() throws SQLException {
|
||||
// TODO
|
||||
ensureOpen();
|
||||
batchCommands.clear();
|
||||
}
|
||||
|
||||
// TODO: let's make this batch operation atomic
|
||||
@Override
|
||||
public int[] executeBatch() throws SQLException {
|
||||
// TODO
|
||||
return new int[0];
|
||||
ensureOpen();
|
||||
|
||||
int[] updateCounts = new int[batchCommands.size()];
|
||||
List<String> failedCommands = new ArrayList<>();
|
||||
|
||||
// Execute each command in the batch
|
||||
for (int i = 0; i < batchCommands.size(); i++) {
|
||||
String sql = batchCommands.get(i);
|
||||
try {
|
||||
if (!isBatchCompatibleStatement(sql)) {
|
||||
failedCommands.add(sql);
|
||||
updateCounts[i] = EXECUTE_FAILED;
|
||||
BatchUpdateException bue =
|
||||
new BatchUpdateException(
|
||||
"Batch entry "
|
||||
+ i
|
||||
+ " ("
|
||||
+ sql
|
||||
+ ") was aborted. "
|
||||
+ "Batch commands cannot return result sets.",
|
||||
"HY000", // General error SQL state
|
||||
0,
|
||||
updateCounts);
|
||||
// Clear the batch after failure
|
||||
clearBatch();
|
||||
throw bue;
|
||||
}
|
||||
|
||||
execute(sql);
|
||||
// For DML statements, get the update count
|
||||
updateCounts[i] = getUpdateCount();
|
||||
} catch (SQLException e) {
|
||||
failedCommands.add(sql);
|
||||
updateCounts[i] = EXECUTE_FAILED;
|
||||
|
||||
// Create a BatchUpdateException with the partial results
|
||||
BatchUpdateException bue =
|
||||
new BatchUpdateException(
|
||||
"Batch entry " + i + " (" + sql + ") failed: " + e.getMessage(),
|
||||
e.getSQLState(),
|
||||
e.getErrorCode(),
|
||||
updateCounts,
|
||||
e.getCause());
|
||||
// Clear the batch after failure
|
||||
clearBatch();
|
||||
throw bue;
|
||||
}
|
||||
}
|
||||
|
||||
// Clear the batch after successful execution
|
||||
clearBatch();
|
||||
return updateCounts;
|
||||
}
|
||||
|
||||
boolean isBatchCompatibleStatement(String sql) {
|
||||
if (sql == null || sql.trim().isEmpty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return BATCH_COMPATIBLE_PATTERN.matcher(sql).find();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -7,6 +7,7 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
import java.sql.BatchUpdateException;
|
||||
import java.sql.ResultSet;
|
||||
import java.sql.SQLException;
|
||||
import java.sql.Statement;
|
||||
@@ -120,4 +121,277 @@ class JDBC4StatementTest {
|
||||
|
||||
assertThat(stmt.executeUpdate("DELETE FROM s1")).isEqualTo(3);
|
||||
}
|
||||
|
||||
/** Tests for batch processing functionality */
|
||||
@Test
|
||||
void testAddBatch_single_statement() throws SQLException {
|
||||
stmt.execute("CREATE TABLE batch_test (id INTEGER PRIMARY KEY, value TEXT);");
|
||||
|
||||
stmt.addBatch("INSERT INTO batch_test VALUES (1, 'test1');");
|
||||
|
||||
int[] updateCounts = stmt.executeBatch();
|
||||
|
||||
assertThat(updateCounts).hasSize(1);
|
||||
assertThat(updateCounts[0]).isEqualTo(1);
|
||||
|
||||
ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM batch_test;");
|
||||
assertTrue(rs.next());
|
||||
assertThat(rs.getInt(1)).isEqualTo(1);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testAddBatch_multiple_statements() throws SQLException {
|
||||
stmt.execute("CREATE TABLE batch_test (id INTEGER PRIMARY KEY, value TEXT);");
|
||||
|
||||
stmt.addBatch("INSERT INTO batch_test VALUES (1, 'test1');");
|
||||
stmt.addBatch("INSERT INTO batch_test VALUES (2, 'test2');");
|
||||
stmt.addBatch("INSERT INTO batch_test VALUES (3, 'test3');");
|
||||
|
||||
int[] updateCounts = stmt.executeBatch();
|
||||
|
||||
assertThat(updateCounts).hasSize(3);
|
||||
assertThat(updateCounts[0]).isEqualTo(1);
|
||||
assertThat(updateCounts[1]).isEqualTo(1);
|
||||
assertThat(updateCounts[2]).isEqualTo(1);
|
||||
|
||||
ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM batch_test;");
|
||||
assertTrue(rs.next());
|
||||
assertThat(rs.getInt(1)).isEqualTo(3);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testAddBatch_with_updates_and_deletes() throws SQLException {
|
||||
stmt.execute("CREATE TABLE batch_test (id INTEGER PRIMARY KEY, value TEXT);");
|
||||
|
||||
stmt.execute(
|
||||
"INSERT INTO batch_test VALUES (1, 'initial1'), (2, 'initial2'), (3, 'initial3');");
|
||||
|
||||
stmt.addBatch("UPDATE batch_test SET value = 'updated';");
|
||||
stmt.addBatch("DELETE FROM batch_test WHERE id = 2;");
|
||||
stmt.addBatch("INSERT INTO batch_test VALUES (4, 'new');");
|
||||
|
||||
int[] updateCounts = stmt.executeBatch();
|
||||
|
||||
assertThat(updateCounts).hasSize(3);
|
||||
assertThat(updateCounts[0]).isEqualTo(3); // UPDATE affected 3 row
|
||||
assertThat(updateCounts[1]).isEqualTo(1); // DELETE affected 1 row
|
||||
assertThat(updateCounts[2]).isEqualTo(1); // INSERT affected 1 row
|
||||
|
||||
// Verify final state
|
||||
ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM batch_test;");
|
||||
assertTrue(rs.next());
|
||||
assertThat(rs.getInt(1)).isEqualTo(3); // 3 initial - 1 deleted + 1 inserted = 3
|
||||
}
|
||||
|
||||
@Test
|
||||
void testClearBatch() throws SQLException {
|
||||
stmt.execute("CREATE TABLE batch_test (id INTEGER PRIMARY KEY, value TEXT);");
|
||||
|
||||
stmt.addBatch("INSERT INTO batch_test VALUES (1, 'test1');");
|
||||
stmt.addBatch("INSERT INTO batch_test VALUES (2, 'test2');");
|
||||
|
||||
stmt.clearBatch();
|
||||
|
||||
int[] updateCounts = stmt.executeBatch();
|
||||
assertThat(updateCounts).isEmpty();
|
||||
|
||||
ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM batch_test;");
|
||||
assertTrue(rs.next());
|
||||
assertThat(rs.getInt(1)).isEqualTo(0);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testBatch_with_SELECT_should_throw_exception() throws SQLException {
|
||||
stmt.execute("CREATE TABLE batch_test (id INTEGER PRIMARY KEY, value TEXT);");
|
||||
stmt.execute("INSERT INTO batch_test VALUES (1, 'test1');");
|
||||
|
||||
stmt.addBatch("INSERT INTO batch_test VALUES (2, 'test2');");
|
||||
stmt.addBatch("SELECT * FROM batch_test;"); // This should cause an exception
|
||||
stmt.addBatch("INSERT INTO batch_test VALUES (3, 'test3');");
|
||||
|
||||
BatchUpdateException exception =
|
||||
assertThrows(BatchUpdateException.class, () -> stmt.executeBatch());
|
||||
|
||||
assertTrue(exception.getMessage().contains("Batch commands cannot return result sets"));
|
||||
|
||||
int[] updateCounts = exception.getUpdateCounts();
|
||||
assertThat(updateCounts).hasSize(3);
|
||||
assertThat(updateCounts[0]).isEqualTo(1); // First INSERT succeeded
|
||||
assertThat(updateCounts[1]).isEqualTo(Statement.EXECUTE_FAILED); // SELECT failed
|
||||
}
|
||||
|
||||
@Test
|
||||
void testBatch_with_null_command_should_throw_exception() {
|
||||
assertThrows(SQLException.class, () -> stmt.addBatch(null));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testBatch_operations_on_closed_statement_should_throw_exception() throws SQLException {
|
||||
stmt.close();
|
||||
|
||||
assertThrows(SQLException.class, () -> stmt.addBatch("INSERT INTO test VALUES (1);"));
|
||||
assertThrows(SQLException.class, () -> stmt.clearBatch());
|
||||
assertThrows(SQLException.class, () -> stmt.executeBatch());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testBatch_with_syntax_error_should_throw_exception() throws SQLException {
|
||||
stmt.execute("CREATE TABLE batch_test (id INTEGER PRIMARY KEY, value TEXT);");
|
||||
|
||||
stmt.addBatch("INSERT INTO batch_test VALUES (1, 'test1');");
|
||||
stmt.addBatch("INVALID SQL SYNTAX;"); // This should cause an exception
|
||||
stmt.addBatch("INSERT INTO batch_test VALUES (3, 'test3');");
|
||||
|
||||
BatchUpdateException exception =
|
||||
assertThrows(BatchUpdateException.class, () -> stmt.executeBatch());
|
||||
|
||||
int[] updateCounts = exception.getUpdateCounts();
|
||||
assertThat(updateCounts).hasSize(3);
|
||||
assertThat(updateCounts[0]).isEqualTo(1); // First INSERT succeeded
|
||||
assertThat(updateCounts[1]).isEqualTo(Statement.EXECUTE_FAILED); // Invalid SQL failed
|
||||
}
|
||||
|
||||
@Test
|
||||
void testBatch_empty_batch_returns_empty_array() throws SQLException {
|
||||
int[] updateCounts = stmt.executeBatch();
|
||||
assertThat(updateCounts).isEmpty();
|
||||
}
|
||||
|
||||
@Test
|
||||
void testBatch_clears_after_successful_execution() throws SQLException {
|
||||
stmt.execute("CREATE TABLE batch_test (id INTEGER PRIMARY KEY, value TEXT);");
|
||||
|
||||
stmt.addBatch("INSERT INTO batch_test VALUES (1, 'test1');");
|
||||
stmt.executeBatch();
|
||||
|
||||
int[] updateCounts = stmt.executeBatch();
|
||||
assertThat(updateCounts).isEmpty();
|
||||
}
|
||||
|
||||
@Test
|
||||
void testBatch_clears_after_failed_execution() throws SQLException {
|
||||
stmt.execute("CREATE TABLE batch_test (id INTEGER PRIMARY KEY, value TEXT);");
|
||||
|
||||
stmt.addBatch("SELECT * FROM batch_test;");
|
||||
|
||||
assertThrows(BatchUpdateException.class, () -> stmt.executeBatch());
|
||||
|
||||
int[] updateCounts = stmt.executeBatch();
|
||||
assertThat(updateCounts).isEmpty();
|
||||
}
|
||||
|
||||
/** Tests for isBatchCompatibleStatement method */
|
||||
@Test
|
||||
void testIsBatchCompatibleStatement_compatible_statements() {
|
||||
JDBC4Statement jdbc4Stmt = (JDBC4Statement) stmt;
|
||||
|
||||
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("INSERT INTO table VALUES (1, 2);"));
|
||||
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("insert into table values (1, 2);"));
|
||||
assertTrue(
|
||||
jdbc4Stmt.isBatchCompatibleStatement("INSERT INTO table (col1, col2) VALUES (1, 2);"));
|
||||
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("INSERT OR REPLACE INTO table VALUES (1);"));
|
||||
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("INSERT OR IGNORE INTO table VALUES (1);"));
|
||||
|
||||
assertTrue(jdbc4Stmt.isBatchCompatibleStatement(" INSERT INTO table VALUES (1);"));
|
||||
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("\t\nINSERT INTO table VALUES (1);"));
|
||||
assertTrue(jdbc4Stmt.isBatchCompatibleStatement(" \n\t INSERT INTO table VALUES (1);"));
|
||||
|
||||
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("/* comment */ INSERT INTO table VALUES (1);"));
|
||||
assertTrue(
|
||||
jdbc4Stmt.isBatchCompatibleStatement(
|
||||
"/* multi\nline\ncomment */ INSERT INTO table VALUES (1);"));
|
||||
assertTrue(
|
||||
jdbc4Stmt.isBatchCompatibleStatement("-- line comment\nINSERT INTO table VALUES (1);"));
|
||||
assertTrue(
|
||||
jdbc4Stmt.isBatchCompatibleStatement(
|
||||
"-- comment 1\n-- comment 2\nINSERT INTO table VALUES (1);"));
|
||||
|
||||
assertTrue(
|
||||
jdbc4Stmt.isBatchCompatibleStatement(
|
||||
" /* comment */ -- another\n INSERT INTO table VALUES (1);"));
|
||||
|
||||
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("UPDATE table SET col = 1;"));
|
||||
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("update table set col = 1;"));
|
||||
assertTrue(
|
||||
jdbc4Stmt.isBatchCompatibleStatement("UPDATE table SET col1 = 1, col2 = 2 WHERE id = 3;"));
|
||||
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("UPDATE OR REPLACE table SET col = 1;"));
|
||||
|
||||
assertTrue(jdbc4Stmt.isBatchCompatibleStatement(" UPDATE table SET col = 1;"));
|
||||
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("\t\nUPDATE table SET col = 1;"));
|
||||
|
||||
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("/* comment */ UPDATE table SET col = 1;"));
|
||||
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("-- comment\nUPDATE table SET col = 1;"));
|
||||
|
||||
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("DELETE FROM table;"));
|
||||
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("delete from table;"));
|
||||
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("DELETE FROM table WHERE id = 1;"));
|
||||
|
||||
assertTrue(jdbc4Stmt.isBatchCompatibleStatement(" DELETE FROM table;"));
|
||||
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("\t\nDELETE FROM table;"));
|
||||
|
||||
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("/* comment */ DELETE FROM table;"));
|
||||
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("-- comment\nDELETE FROM table;"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testIsBatchCompatibleStatement_non_compatible_statements() {
|
||||
JDBC4Statement jdbc4Stmt = (JDBC4Statement) stmt;
|
||||
|
||||
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("SELECT * FROM table;"));
|
||||
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("select * from table;"));
|
||||
assertFalse(jdbc4Stmt.isBatchCompatibleStatement(" SELECT * FROM table;"));
|
||||
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("/* comment */ SELECT * FROM table;"));
|
||||
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("-- comment\nSELECT * FROM table;"));
|
||||
|
||||
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("EXPLAIN SELECT * FROM table;"));
|
||||
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("EXPLAIN QUERY PLAN SELECT * FROM table;"));
|
||||
|
||||
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("PRAGMA table_info(table);"));
|
||||
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("PRAGMA foreign_keys = ON;"));
|
||||
|
||||
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("ANALYZE;"));
|
||||
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("ANALYZE table;"));
|
||||
|
||||
assertFalse(
|
||||
jdbc4Stmt.isBatchCompatibleStatement(
|
||||
"WITH cte AS (SELECT * FROM table) SELECT * FROM cte;"));
|
||||
|
||||
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("VACUUM;"));
|
||||
|
||||
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("VALUES (1, 2), (3, 4);"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testIsBatchCompatibleStatement_edge_cases() {
|
||||
JDBC4Statement jdbc4Stmt = (JDBC4Statement) stmt;
|
||||
|
||||
assertFalse(jdbc4Stmt.isBatchCompatibleStatement(null));
|
||||
assertFalse(jdbc4Stmt.isBatchCompatibleStatement(""));
|
||||
assertFalse(jdbc4Stmt.isBatchCompatibleStatement(" "));
|
||||
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("\t\n"));
|
||||
|
||||
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("/* comment only */"));
|
||||
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("-- comment only"));
|
||||
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("/* comment */ -- another comment"));
|
||||
|
||||
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("SELECT * FROM table WHERE name = 'INSERT';"));
|
||||
assertFalse(
|
||||
jdbc4Stmt.isBatchCompatibleStatement("SELECT * FROM table WHERE action = 'DELETE';"));
|
||||
|
||||
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("INSER INTO table VALUES (1);"));
|
||||
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("UPDAT table SET col = 1;"));
|
||||
assertFalse(jdbc4Stmt.isBatchCompatibleStatement("DELET FROM table;"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testIsBatchCompatibleStatement_case_insensitive() {
|
||||
JDBC4Statement jdbc4Stmt = (JDBC4Statement) stmt;
|
||||
|
||||
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("Insert INTO table VALUES (1);"));
|
||||
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("InSeRt INTO table VALUES (1);"));
|
||||
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("UPDATE table SET col = 1;"));
|
||||
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("UpDaTe table SET col = 1;"));
|
||||
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("Delete FROM table;"));
|
||||
assertTrue(jdbc4Stmt.isBatchCompatibleStatement("DeLeTe FROM table;"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
{"rustc_fingerprint":11551670960185020797,"outputs":{"14427667104029986310":{"success":true,"status":"","code":0,"stdout":"rustc 1.83.0 (90b35a623 2024-11-26)\nbinary: rustc\ncommit-hash: 90b35a6239c3d8bdabc530a6a0816f7ff89a0aaf\ncommit-date: 2024-11-26\nhost: x86_64-unknown-linux-gnu\nrelease: 1.83.0\nLLVM version: 19.1.1\n","stderr":""},"11399821309745579047":{"success":true,"status":"","code":0,"stdout":"___\nlib___.rlib\nlib___.so\nlib___.so\nlib___.a\nlib___.so\n/home/merlin/.rustup/toolchains/1.83.0-x86_64-unknown-linux-gnu\noff\npacked\nunpacked\n___\ndebug_assertions\npanic=\"unwind\"\nproc_macro\ntarget_abi=\"\"\ntarget_arch=\"x86_64\"\ntarget_endian=\"little\"\ntarget_env=\"gnu\"\ntarget_family=\"unix\"\ntarget_feature=\"fxsr\"\ntarget_feature=\"sse\"\ntarget_feature=\"sse2\"\ntarget_has_atomic=\"16\"\ntarget_has_atomic=\"32\"\ntarget_has_atomic=\"64\"\ntarget_has_atomic=\"8\"\ntarget_has_atomic=\"ptr\"\ntarget_os=\"linux\"\ntarget_pointer_width=\"64\"\ntarget_vendor=\"unknown\"\nunix\n","stderr":""}},"successes":{}}
|
||||
@@ -569,7 +569,7 @@ impl turso_core::DatabaseStorage for DatabaseFile {
|
||||
fn read_page(
|
||||
&self,
|
||||
page_idx: usize,
|
||||
_key: Option<&turso_core::EncryptionKey>,
|
||||
_encryption_ctx: Option<&turso_core::EncryptionContext>,
|
||||
c: turso_core::Completion,
|
||||
) -> turso_core::Result<turso_core::Completion> {
|
||||
let r = c.as_read();
|
||||
@@ -586,7 +586,7 @@ impl turso_core::DatabaseStorage for DatabaseFile {
|
||||
&self,
|
||||
page_idx: usize,
|
||||
buffer: Arc<turso_core::Buffer>,
|
||||
_key: Option<&turso_core::EncryptionKey>,
|
||||
_encryption_ctx: Option<&turso_core::EncryptionContext>,
|
||||
c: turso_core::Completion,
|
||||
) -> turso_core::Result<turso_core::Completion> {
|
||||
let size = buffer.len();
|
||||
@@ -599,7 +599,7 @@ impl turso_core::DatabaseStorage for DatabaseFile {
|
||||
first_page_idx: usize,
|
||||
page_size: usize,
|
||||
buffers: Vec<Arc<turso_core::Buffer>>,
|
||||
_key: Option<&turso_core::EncryptionKey>,
|
||||
_encryption_ctx: Option<&turso_core::EncryptionContext>,
|
||||
c: turso_core::Completion,
|
||||
) -> turso_core::Result<turso_core::Completion> {
|
||||
let pos = first_page_idx.saturating_sub(1) * page_size;
|
||||
|
||||
@@ -75,7 +75,7 @@ use std::{
|
||||
};
|
||||
#[cfg(feature = "fs")]
|
||||
use storage::database::DatabaseFile;
|
||||
pub use storage::encryption::EncryptionKey;
|
||||
pub use storage::encryption::{EncryptionKey, EncryptionContext};
|
||||
use storage::page_cache::DumbLruPageCache;
|
||||
use storage::pager::{AtomicDbState, DbState};
|
||||
use storage::sqlite3_ondisk::PageSize;
|
||||
@@ -1955,11 +1955,11 @@ impl Connection {
|
||||
self.syms.borrow().vtab_modules.keys().cloned().collect()
|
||||
}
|
||||
|
||||
pub fn set_encryption_key(&self, key: Option<EncryptionKey>) {
|
||||
pub fn set_encryption_key(&self, key: EncryptionKey) {
|
||||
tracing::trace!("setting encryption key for connection");
|
||||
*self.encryption_key.borrow_mut() = key.clone();
|
||||
*self.encryption_key.borrow_mut() = Some(key.clone());
|
||||
let pager = self.pager.borrow();
|
||||
pager.set_encryption_key(key);
|
||||
pager.set_encryption_context(&key);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -255,7 +255,8 @@ impl Schema {
|
||||
}
|
||||
|
||||
pub fn table_has_indexes(&self, table_name: &str) -> bool {
|
||||
self.has_indexes.contains(table_name)
|
||||
let name = normalize_ident(table_name);
|
||||
self.has_indexes.contains(&name)
|
||||
}
|
||||
|
||||
pub fn table_set_has_index(&mut self, table_name: &str) {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::error::LimboError;
|
||||
use crate::storage::encryption::{decrypt_page, encrypt_page, EncryptionKey};
|
||||
use crate::storage::encryption::EncryptionContext;
|
||||
use crate::{io::Completion, Buffer, CompletionError, Result};
|
||||
use std::sync::Arc;
|
||||
use tracing::{instrument, Level};
|
||||
@@ -15,14 +15,14 @@ pub trait DatabaseStorage: Send + Sync {
|
||||
fn read_page(
|
||||
&self,
|
||||
page_idx: usize,
|
||||
encryption_key: Option<&EncryptionKey>,
|
||||
encryption_ctx: Option<&EncryptionContext>,
|
||||
c: Completion,
|
||||
) -> Result<Completion>;
|
||||
fn write_page(
|
||||
&self,
|
||||
page_idx: usize,
|
||||
buffer: Arc<Buffer>,
|
||||
encryption_key: Option<&EncryptionKey>,
|
||||
encryption_ctx: Option<&EncryptionContext>,
|
||||
c: Completion,
|
||||
) -> Result<Completion>;
|
||||
fn write_pages(
|
||||
@@ -30,7 +30,7 @@ pub trait DatabaseStorage: Send + Sync {
|
||||
first_page_idx: usize,
|
||||
page_size: usize,
|
||||
buffers: Vec<Arc<Buffer>>,
|
||||
encryption_key: Option<&EncryptionKey>,
|
||||
encryption_ctx: Option<&EncryptionContext>,
|
||||
c: Completion,
|
||||
) -> Result<Completion>;
|
||||
fn sync(&self, c: Completion) -> Result<Completion>;
|
||||
@@ -59,7 +59,7 @@ impl DatabaseStorage for DatabaseFile {
|
||||
fn read_page(
|
||||
&self,
|
||||
page_idx: usize,
|
||||
encryption_key: Option<&EncryptionKey>,
|
||||
encryption_ctx: Option<&EncryptionContext>,
|
||||
c: Completion,
|
||||
) -> Result<Completion> {
|
||||
let r = c.as_read();
|
||||
@@ -70,8 +70,8 @@ impl DatabaseStorage for DatabaseFile {
|
||||
}
|
||||
let pos = (page_idx - 1) * size;
|
||||
|
||||
if let Some(key) = encryption_key {
|
||||
let key_clone = key.clone();
|
||||
if let Some(ctx) = encryption_ctx {
|
||||
let encryption_ctx = ctx.clone();
|
||||
let read_buffer = r.buf_arc();
|
||||
let original_c = c.clone();
|
||||
|
||||
@@ -81,7 +81,7 @@ impl DatabaseStorage for DatabaseFile {
|
||||
return;
|
||||
};
|
||||
if bytes_read > 0 {
|
||||
match decrypt_page(buf.as_slice(), page_idx, &key_clone) {
|
||||
match encryption_ctx.decrypt_page(buf.as_slice(), page_idx) {
|
||||
Ok(decrypted_data) => {
|
||||
let original_buf = original_c.as_read().buf();
|
||||
original_buf.as_mut_slice().copy_from_slice(&decrypted_data);
|
||||
@@ -111,7 +111,7 @@ impl DatabaseStorage for DatabaseFile {
|
||||
&self,
|
||||
page_idx: usize,
|
||||
buffer: Arc<Buffer>,
|
||||
encryption_key: Option<&EncryptionKey>,
|
||||
encryption_ctx: Option<&EncryptionContext>,
|
||||
c: Completion,
|
||||
) -> Result<Completion> {
|
||||
let buffer_size = buffer.len();
|
||||
@@ -121,8 +121,8 @@ impl DatabaseStorage for DatabaseFile {
|
||||
assert_eq!(buffer_size & (buffer_size - 1), 0);
|
||||
let pos = (page_idx - 1) * buffer_size;
|
||||
let buffer = {
|
||||
if let Some(key) = encryption_key {
|
||||
encrypt_buffer(page_idx, buffer, key)
|
||||
if let Some(ctx) = encryption_ctx {
|
||||
encrypt_buffer(page_idx, buffer, ctx)
|
||||
} else {
|
||||
buffer
|
||||
}
|
||||
@@ -135,7 +135,7 @@ impl DatabaseStorage for DatabaseFile {
|
||||
first_page_idx: usize,
|
||||
page_size: usize,
|
||||
buffers: Vec<Arc<Buffer>>,
|
||||
encryption_key: Option<&EncryptionKey>,
|
||||
encryption_key: Option<&EncryptionContext>,
|
||||
c: Completion,
|
||||
) -> Result<Completion> {
|
||||
assert!(first_page_idx > 0);
|
||||
@@ -145,11 +145,11 @@ impl DatabaseStorage for DatabaseFile {
|
||||
|
||||
let pos = (first_page_idx - 1) * page_size;
|
||||
let buffers = {
|
||||
if let Some(key) = encryption_key {
|
||||
if let Some(ctx) = encryption_key {
|
||||
buffers
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(i, buffer)| encrypt_buffer(first_page_idx + i, buffer, key))
|
||||
.map(|(i, buffer)| encrypt_buffer(first_page_idx + i, buffer, ctx))
|
||||
.collect::<Vec<_>>()
|
||||
} else {
|
||||
buffers
|
||||
@@ -184,7 +184,11 @@ impl DatabaseFile {
|
||||
}
|
||||
}
|
||||
|
||||
fn encrypt_buffer(page_idx: usize, buffer: Arc<Buffer>, key: &EncryptionKey) -> Arc<Buffer> {
|
||||
let encrypted_data = encrypt_page(buffer.as_slice(), page_idx, key).unwrap();
|
||||
fn encrypt_buffer(
|
||||
page_idx: usize,
|
||||
buffer: Arc<Buffer>,
|
||||
ctx: &EncryptionContext,
|
||||
) -> Arc<Buffer> {
|
||||
let encrypted_data = ctx.encrypt_page(buffer.as_slice(), page_idx).unwrap();
|
||||
Arc::new(Buffer::new(encrypted_data.to_vec()))
|
||||
}
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
#![allow(unused_variables, dead_code)]
|
||||
#[cfg(not(feature = "encryption"))]
|
||||
use crate::LimboError;
|
||||
use crate::Result;
|
||||
use crate::{LimboError, Result};
|
||||
use aes_gcm::{
|
||||
aead::{Aead, AeadCore, KeyInit, OsRng},
|
||||
Aes256Gcm, Key, Nonce,
|
||||
@@ -11,6 +9,7 @@ use std::ops::Deref;
|
||||
pub const ENCRYPTION_METADATA_SIZE: usize = 28;
|
||||
pub const ENCRYPTED_PAGE_SIZE: usize = 4096;
|
||||
pub const ENCRYPTION_NONCE_SIZE: usize = 12;
|
||||
pub const ENCRYPTION_TAG_SIZE: usize = 16;
|
||||
|
||||
#[repr(transparent)]
|
||||
#[derive(Clone)]
|
||||
@@ -71,106 +70,195 @@ impl Drop for EncryptionKey {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "encryption"))]
|
||||
pub fn encrypt_page(page: &[u8], page_id: usize, key: &EncryptionKey) -> Result<Vec<u8>> {
|
||||
Err(LimboError::InvalidArgument(
|
||||
"encryption is not enabled, cannot encrypt page. enable via passing `--features encryption`".into(),
|
||||
))
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub enum CipherMode {
|
||||
Aes256Gcm,
|
||||
}
|
||||
|
||||
#[cfg(feature = "encryption")]
|
||||
pub fn encrypt_page(page: &[u8], page_id: usize, key: &EncryptionKey) -> Result<Vec<u8>> {
|
||||
if page_id == 1 {
|
||||
tracing::debug!("skipping encryption for page 1 (database header)");
|
||||
return Ok(page.to_vec());
|
||||
impl CipherMode {
|
||||
/// Every cipher requires a specific key size. For 256-bit algorithms, this is 32 bytes.
|
||||
/// For 128-bit algorithms, it would be 16 bytes, etc.
|
||||
pub fn required_key_size(&self) -> usize {
|
||||
match self {
|
||||
CipherMode::Aes256Gcm => 32,
|
||||
}
|
||||
}
|
||||
tracing::debug!("encrypting page {}", page_id);
|
||||
assert_eq!(
|
||||
page.len(),
|
||||
ENCRYPTED_PAGE_SIZE,
|
||||
"Page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes"
|
||||
);
|
||||
let reserved_bytes = &page[ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE..];
|
||||
let reserved_bytes_zeroed = reserved_bytes.iter().all(|&b| b == 0);
|
||||
assert!(
|
||||
reserved_bytes_zeroed,
|
||||
"last reserved bytes must be empty/zero, but found non-zero bytes"
|
||||
);
|
||||
let payload = &page[..ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE];
|
||||
let (encrypted, nonce) = encrypt(payload, key)?;
|
||||
assert_eq!(
|
||||
encrypted.len(),
|
||||
ENCRYPTED_PAGE_SIZE - nonce.len(),
|
||||
"Encrypted page must be exactly {} bytes",
|
||||
ENCRYPTED_PAGE_SIZE - nonce.len()
|
||||
);
|
||||
let mut result = Vec::with_capacity(ENCRYPTED_PAGE_SIZE);
|
||||
result.extend_from_slice(&encrypted);
|
||||
result.extend_from_slice(&nonce);
|
||||
assert_eq!(
|
||||
result.len(),
|
||||
ENCRYPTED_PAGE_SIZE,
|
||||
"Encrypted page must be exactly {ENCRYPTED_PAGE_SIZE} bytes"
|
||||
);
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "encryption"))]
|
||||
pub fn decrypt_page(encrypted_page: &[u8], page_id: usize, key: &EncryptionKey) -> Result<Vec<u8>> {
|
||||
Err(LimboError::InvalidArgument(
|
||||
"encryption is not enabled, cannot decrypt page. enable via passing `--features encryption`".into(),
|
||||
))
|
||||
}
|
||||
|
||||
#[cfg(feature = "encryption")]
|
||||
pub fn decrypt_page(encrypted_page: &[u8], page_id: usize, key: &EncryptionKey) -> Result<Vec<u8>> {
|
||||
if page_id == 1 {
|
||||
tracing::debug!("skipping decryption for page 1 (database header)");
|
||||
return Ok(encrypted_page.to_vec());
|
||||
/// Returns the nonce size for this cipher mode. Though most AEAD ciphers use 12-byte nonces.
|
||||
pub fn nonce_size(&self) -> usize {
|
||||
match self {
|
||||
CipherMode::Aes256Gcm => ENCRYPTION_NONCE_SIZE,
|
||||
}
|
||||
}
|
||||
tracing::debug!("decrypting page {}", page_id);
|
||||
assert_eq!(
|
||||
encrypted_page.len(),
|
||||
ENCRYPTED_PAGE_SIZE,
|
||||
"Encrypted page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes"
|
||||
);
|
||||
|
||||
let nonce_start = encrypted_page.len() - ENCRYPTION_NONCE_SIZE;
|
||||
let payload = &encrypted_page[..nonce_start];
|
||||
let nonce = &encrypted_page[nonce_start..];
|
||||
|
||||
let decrypted_data = decrypt(payload, nonce, key)?;
|
||||
assert_eq!(
|
||||
decrypted_data.len(),
|
||||
ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE,
|
||||
"Decrypted page data must be exactly {} bytes",
|
||||
ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE
|
||||
);
|
||||
let mut result = Vec::with_capacity(ENCRYPTED_PAGE_SIZE);
|
||||
result.extend_from_slice(&decrypted_data);
|
||||
result.resize(ENCRYPTED_PAGE_SIZE, 0);
|
||||
assert_eq!(
|
||||
result.len(),
|
||||
ENCRYPTED_PAGE_SIZE,
|
||||
"Decrypted page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes"
|
||||
);
|
||||
Ok(result)
|
||||
/// Returns the authentication tag size for this cipher mode. All common AEAD ciphers use 16-byte tags.
|
||||
pub fn tag_size(&self) -> usize {
|
||||
match self {
|
||||
CipherMode::Aes256Gcm => ENCRYPTION_TAG_SIZE,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn encrypt(plaintext: &[u8], key: &EncryptionKey) -> Result<(Vec<u8>, Vec<u8>)> {
|
||||
let key: &Key<Aes256Gcm> = key.as_ref().into();
|
||||
let cipher = Aes256Gcm::new(key);
|
||||
let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
|
||||
let ciphertext = cipher.encrypt(&nonce, plaintext).unwrap();
|
||||
Ok((ciphertext, nonce.to_vec()))
|
||||
#[derive(Clone)]
|
||||
pub enum Cipher {
|
||||
Aes256Gcm(Box<Aes256Gcm>),
|
||||
}
|
||||
|
||||
fn decrypt(ciphertext: &[u8], nonce: &[u8], key: &EncryptionKey) -> Result<Vec<u8>> {
|
||||
let key: &Key<Aes256Gcm> = key.as_ref().into();
|
||||
let cipher = Aes256Gcm::new(key);
|
||||
let nonce = Nonce::from_slice(nonce);
|
||||
let plaintext = cipher.decrypt(nonce, ciphertext).unwrap();
|
||||
Ok(plaintext)
|
||||
impl std::fmt::Debug for Cipher {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Cipher::Aes256Gcm(_) => write!(f, "Cipher::Aes256Gcm"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct EncryptionContext {
|
||||
cipher_mode: CipherMode,
|
||||
cipher: Cipher,
|
||||
}
|
||||
|
||||
impl EncryptionContext {
|
||||
pub fn new(key: &EncryptionKey) -> Result<Self> {
|
||||
let cipher_mode = CipherMode::Aes256Gcm;
|
||||
let required_size = cipher_mode.required_key_size();
|
||||
if key.as_slice().len() != required_size {
|
||||
return Err(crate::LimboError::InvalidArgument(format!(
|
||||
"Invalid key size for {:?}: expected {} bytes, got {}",
|
||||
cipher_mode,
|
||||
required_size,
|
||||
key.as_slice().len()
|
||||
)));
|
||||
}
|
||||
|
||||
let cipher = match cipher_mode {
|
||||
CipherMode::Aes256Gcm => {
|
||||
let cipher_key: &Key<Aes256Gcm> = key.as_ref().into();
|
||||
Cipher::Aes256Gcm(Box::new(Aes256Gcm::new(cipher_key)))
|
||||
}
|
||||
};
|
||||
Ok(Self {
|
||||
cipher_mode,
|
||||
cipher,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn cipher_mode(&self) -> CipherMode {
|
||||
self.cipher_mode
|
||||
}
|
||||
|
||||
#[cfg(feature = "encryption")]
|
||||
pub fn encrypt_page(&self, page: &[u8], page_id: usize) -> Result<Vec<u8>> {
|
||||
if page_id == 1 {
|
||||
tracing::debug!("skipping encryption for page 1 (database header)");
|
||||
return Ok(page.to_vec());
|
||||
}
|
||||
tracing::debug!("encrypting page {}", page_id);
|
||||
assert_eq!(
|
||||
page.len(),
|
||||
ENCRYPTED_PAGE_SIZE,
|
||||
"Page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes"
|
||||
);
|
||||
let reserved_bytes = &page[ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE..];
|
||||
let reserved_bytes_zeroed = reserved_bytes.iter().all(|&b| b == 0);
|
||||
assert!(
|
||||
reserved_bytes_zeroed,
|
||||
"last reserved bytes must be empty/zero, but found non-zero bytes"
|
||||
);
|
||||
let payload = &page[..ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE];
|
||||
let (encrypted, nonce) = self.encrypt_raw(payload)?;
|
||||
|
||||
assert_eq!(
|
||||
encrypted.len(),
|
||||
ENCRYPTED_PAGE_SIZE - nonce.len(),
|
||||
"Encrypted page must be exactly {} bytes",
|
||||
ENCRYPTED_PAGE_SIZE - nonce.len()
|
||||
);
|
||||
let mut result = Vec::with_capacity(ENCRYPTED_PAGE_SIZE);
|
||||
result.extend_from_slice(&encrypted);
|
||||
result.extend_from_slice(&nonce);
|
||||
assert_eq!(
|
||||
result.len(),
|
||||
ENCRYPTED_PAGE_SIZE,
|
||||
"Encrypted page must be exactly {ENCRYPTED_PAGE_SIZE} bytes"
|
||||
);
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[cfg(feature = "encryption")]
|
||||
pub fn decrypt_page(&self, encrypted_page: &[u8], page_id: usize) -> Result<Vec<u8>> {
|
||||
if page_id == 1 {
|
||||
tracing::debug!("skipping decryption for page 1 (database header)");
|
||||
return Ok(encrypted_page.to_vec());
|
||||
}
|
||||
tracing::debug!("decrypting page {}", page_id);
|
||||
assert_eq!(
|
||||
encrypted_page.len(),
|
||||
ENCRYPTED_PAGE_SIZE,
|
||||
"Encrypted page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes"
|
||||
);
|
||||
|
||||
let nonce_start = encrypted_page.len() - ENCRYPTION_NONCE_SIZE;
|
||||
let payload = &encrypted_page[..nonce_start];
|
||||
let nonce = &encrypted_page[nonce_start..];
|
||||
|
||||
let decrypted_data = self.decrypt_raw(payload, nonce)?;
|
||||
|
||||
assert_eq!(
|
||||
decrypted_data.len(),
|
||||
ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE,
|
||||
"Decrypted page data must be exactly {} bytes",
|
||||
ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE
|
||||
);
|
||||
let mut result = Vec::with_capacity(ENCRYPTED_PAGE_SIZE);
|
||||
result.extend_from_slice(&decrypted_data);
|
||||
result.resize(ENCRYPTED_PAGE_SIZE, 0);
|
||||
assert_eq!(
|
||||
result.len(),
|
||||
ENCRYPTED_PAGE_SIZE,
|
||||
"Decrypted page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes"
|
||||
);
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// encrypts raw data using the configured cipher, returns ciphertext and nonce
|
||||
fn encrypt_raw(&self, plaintext: &[u8]) -> Result<(Vec<u8>, Vec<u8>)> {
|
||||
match &self.cipher {
|
||||
Cipher::Aes256Gcm(cipher) => {
|
||||
let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
|
||||
let ciphertext = cipher
|
||||
.encrypt(&nonce, plaintext)
|
||||
.map_err(|e| LimboError::InternalError(format!("Encryption failed: {e:?}")))?;
|
||||
Ok((ciphertext, nonce.to_vec()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn decrypt_raw(&self, ciphertext: &[u8], nonce: &[u8]) -> Result<Vec<u8>> {
|
||||
match &self.cipher {
|
||||
Cipher::Aes256Gcm(cipher) => {
|
||||
let nonce = Nonce::from_slice(nonce);
|
||||
let plaintext = cipher.decrypt(nonce, ciphertext).map_err(|e| {
|
||||
crate::LimboError::InternalError(format!("Decryption failed: {e:?}"))
|
||||
})?;
|
||||
Ok(plaintext)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "encryption"))]
|
||||
pub fn encrypt_page(&self, _page: &[u8], _page_id: usize) -> Result<Vec<u8>> {
|
||||
Err(LimboError::InvalidArgument(
|
||||
"encryption is not enabled, cannot encrypt page. enable via passing `--features encryption`".into(),
|
||||
))
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "encryption"))]
|
||||
pub fn decrypt_page(&self, _encrypted_page: &[u8], _page_id: usize) -> Result<Vec<u8>> {
|
||||
Err(LimboError::InvalidArgument(
|
||||
"encryption is not enabled, cannot decrypt page. enable via passing `--features encryption`".into(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -193,14 +281,15 @@ mod tests {
|
||||
};
|
||||
|
||||
let key = EncryptionKey::from_string("alice and bob use encryption on database");
|
||||
let ctx = EncryptionContext::new(&key).unwrap();
|
||||
|
||||
let page_id = 42;
|
||||
let encrypted = encrypt_page(&page_data, page_id, &key).unwrap();
|
||||
let encrypted = ctx.encrypt_page(&page_data, page_id).unwrap();
|
||||
assert_eq!(encrypted.len(), ENCRYPTED_PAGE_SIZE);
|
||||
assert_ne!(&encrypted[..data_size], &page_data[..data_size]);
|
||||
assert_ne!(&encrypted[..], &page_data[..]);
|
||||
|
||||
let decrypted = decrypt_page(&encrypted, page_id, &key).unwrap();
|
||||
let decrypted = ctx.decrypt_page(&encrypted, page_id).unwrap();
|
||||
assert_eq!(decrypted.len(), ENCRYPTED_PAGE_SIZE);
|
||||
assert_eq!(decrypted, page_data);
|
||||
}
|
||||
|
||||
@@ -28,7 +28,9 @@ use super::btree::{btree_init_page, BTreePage};
|
||||
use super::page_cache::{CacheError, CacheResizeResult, DumbLruPageCache, PageCacheKey};
|
||||
use super::sqlite3_ondisk::begin_write_btree_page;
|
||||
use super::wal::CheckpointMode;
|
||||
use crate::storage::encryption::{EncryptionKey, ENCRYPTION_METADATA_SIZE};
|
||||
use crate::storage::encryption::{
|
||||
EncryptionKey, EncryptionContext, ENCRYPTION_METADATA_SIZE,
|
||||
};
|
||||
|
||||
/// SQLite's default maximum page count
|
||||
const DEFAULT_MAX_PAGE_COUNT: u32 = 0xfffffffe;
|
||||
@@ -347,7 +349,7 @@ pub enum BtreePageAllocMode {
|
||||
|
||||
/// This will keep track of the state of current cache commit in order to not repeat work
|
||||
struct CommitInfo {
|
||||
state: CommitState,
|
||||
state: Cell<CommitState>,
|
||||
}
|
||||
|
||||
/// Track the state of the auto-vacuum mode.
|
||||
@@ -460,10 +462,10 @@ pub struct Pager {
|
||||
pub io: Arc<dyn crate::io::IO>,
|
||||
dirty_pages: Rc<RefCell<HashSet<usize, hash::BuildHasherDefault<hash::DefaultHasher>>>>,
|
||||
|
||||
commit_info: RefCell<CommitInfo>,
|
||||
commit_info: CommitInfo,
|
||||
checkpoint_state: RefCell<CheckpointState>,
|
||||
syncing: Rc<Cell<bool>>,
|
||||
auto_vacuum_mode: RefCell<AutoVacuumMode>,
|
||||
auto_vacuum_mode: Cell<AutoVacuumMode>,
|
||||
/// 0 -> Database is empty,
|
||||
/// 1 -> Database is being initialized,
|
||||
/// 2 -> Database is initialized and ready for use.
|
||||
@@ -491,7 +493,7 @@ pub struct Pager {
|
||||
header_ref_state: RefCell<HeaderRefState>,
|
||||
#[cfg(not(feature = "omit_autovacuum"))]
|
||||
btree_create_vacuum_full_state: Cell<BtreeCreateVacuumFullState>,
|
||||
pub(crate) encryption_key: RefCell<Option<EncryptionKey>>,
|
||||
pub(crate) encryption_ctx: RefCell<Option<EncryptionContext>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -571,13 +573,13 @@ impl Pager {
|
||||
dirty_pages: Rc::new(RefCell::new(HashSet::with_hasher(
|
||||
hash::BuildHasherDefault::new(),
|
||||
))),
|
||||
commit_info: RefCell::new(CommitInfo {
|
||||
state: CommitState::Start,
|
||||
}),
|
||||
commit_info: CommitInfo {
|
||||
state: CommitState::Start.into(),
|
||||
},
|
||||
syncing: Rc::new(Cell::new(false)),
|
||||
checkpoint_state: RefCell::new(CheckpointState::Checkpoint),
|
||||
buffer_pool,
|
||||
auto_vacuum_mode: RefCell::new(AutoVacuumMode::None),
|
||||
auto_vacuum_mode: Cell::new(AutoVacuumMode::None),
|
||||
db_state,
|
||||
init_lock,
|
||||
allocate_page1_state,
|
||||
@@ -593,7 +595,7 @@ impl Pager {
|
||||
header_ref_state: RefCell::new(HeaderRefState::Start),
|
||||
#[cfg(not(feature = "omit_autovacuum"))]
|
||||
btree_create_vacuum_full_state: Cell::new(BtreeCreateVacuumFullState::Start),
|
||||
encryption_key: RefCell::new(None),
|
||||
encryption_ctx: RefCell::new(None),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -620,11 +622,11 @@ impl Pager {
|
||||
}
|
||||
|
||||
pub fn get_auto_vacuum_mode(&self) -> AutoVacuumMode {
|
||||
*self.auto_vacuum_mode.borrow()
|
||||
self.auto_vacuum_mode.get()
|
||||
}
|
||||
|
||||
pub fn set_auto_vacuum_mode(&self, mode: AutoVacuumMode) {
|
||||
*self.auto_vacuum_mode.borrow_mut() = mode;
|
||||
self.auto_vacuum_mode.set(mode);
|
||||
}
|
||||
|
||||
/// Retrieves the pointer map entry for a given database page.
|
||||
@@ -840,8 +842,8 @@ impl Pager {
|
||||
// If autovacuum is enabled, we need to allocate a new page number that is greater than the largest root page number
|
||||
#[cfg(not(feature = "omit_autovacuum"))]
|
||||
{
|
||||
let auto_vacuum_mode = self.auto_vacuum_mode.borrow();
|
||||
match *auto_vacuum_mode {
|
||||
let auto_vacuum_mode = self.auto_vacuum_mode.get();
|
||||
match auto_vacuum_mode {
|
||||
AutoVacuumMode::None => {
|
||||
let page =
|
||||
return_if_io!(self.do_allocate_page(page_type, 0, BtreePageAllocMode::Any));
|
||||
@@ -1092,7 +1094,7 @@ impl Pager {
|
||||
page_idx,
|
||||
page.clone(),
|
||||
allow_empty_read,
|
||||
self.encryption_key.borrow().as_ref(),
|
||||
self.encryption_ctx.borrow().as_ref(),
|
||||
)?;
|
||||
return Ok((page, c));
|
||||
};
|
||||
@@ -1110,7 +1112,7 @@ impl Pager {
|
||||
page_idx,
|
||||
page.clone(),
|
||||
allow_empty_read,
|
||||
self.encryption_key.borrow().as_ref(),
|
||||
self.encryption_ctx.borrow().as_ref(),
|
||||
)?;
|
||||
Ok((page, c))
|
||||
}
|
||||
@@ -1135,7 +1137,7 @@ impl Pager {
|
||||
page_idx: usize,
|
||||
page: PageRef,
|
||||
allow_empty_read: bool,
|
||||
encryption_key: Option<&EncryptionKey>,
|
||||
encryption_key: Option<&EncryptionContext>,
|
||||
) -> Result<Completion> {
|
||||
sqlite3_ondisk::begin_read_page(
|
||||
self.db_file.clone(),
|
||||
@@ -1299,7 +1301,7 @@ impl Pager {
|
||||
};
|
||||
let mut checkpoint_result = CheckpointResult::default();
|
||||
let res = loop {
|
||||
let state = self.commit_info.borrow().state;
|
||||
let state = self.commit_info.state.get();
|
||||
trace!(?state);
|
||||
match state {
|
||||
CommitState::Start => {
|
||||
@@ -1354,35 +1356,35 @@ impl Pager {
|
||||
if completions.is_empty() {
|
||||
return Ok(IOResult::Done(PagerCommitResult::WalWritten));
|
||||
} else {
|
||||
self.commit_info.borrow_mut().state = CommitState::SyncWal;
|
||||
self.commit_info.state.set(CommitState::SyncWal);
|
||||
io_yield_many!(completions);
|
||||
}
|
||||
}
|
||||
CommitState::SyncWal => {
|
||||
self.commit_info.borrow_mut().state = CommitState::AfterSyncWal;
|
||||
self.commit_info.state.set(CommitState::AfterSyncWal);
|
||||
let c = wal.borrow_mut().sync()?;
|
||||
io_yield_one!(c);
|
||||
}
|
||||
CommitState::AfterSyncWal => {
|
||||
turso_assert!(!wal.borrow().is_syncing(), "wal should have synced");
|
||||
if wal_auto_checkpoint_disabled || !wal.borrow().should_checkpoint() {
|
||||
self.commit_info.borrow_mut().state = CommitState::Start;
|
||||
self.commit_info.state.set(CommitState::Start);
|
||||
break PagerCommitResult::WalWritten;
|
||||
}
|
||||
self.commit_info.borrow_mut().state = CommitState::Checkpoint;
|
||||
self.commit_info.state.set(CommitState::Checkpoint);
|
||||
}
|
||||
CommitState::Checkpoint => {
|
||||
checkpoint_result = return_if_io!(self.checkpoint());
|
||||
self.commit_info.borrow_mut().state = CommitState::SyncDbFile;
|
||||
self.commit_info.state.set(CommitState::SyncDbFile);
|
||||
}
|
||||
CommitState::SyncDbFile => {
|
||||
let c = sqlite3_ondisk::begin_sync(self.db_file.clone(), self.syncing.clone())?;
|
||||
self.commit_info.borrow_mut().state = CommitState::AfterSyncDbFile;
|
||||
self.commit_info.state.set(CommitState::AfterSyncDbFile);
|
||||
io_yield_one!(c);
|
||||
}
|
||||
CommitState::AfterSyncDbFile => {
|
||||
turso_assert!(!self.syncing.get(), "should have finished syncing");
|
||||
self.commit_info.borrow_mut().state = CommitState::Start;
|
||||
self.commit_info.state.set(CommitState::Start);
|
||||
break PagerCommitResult::Checkpointed(checkpoint_result);
|
||||
}
|
||||
}
|
||||
@@ -1724,7 +1726,7 @@ impl Pager {
|
||||
default_header.database_size = 1.into();
|
||||
|
||||
// if a key is set, then we will reserve space for encryption metadata
|
||||
if self.encryption_key.borrow().is_some() {
|
||||
if self.encryption_ctx.borrow().is_some() {
|
||||
default_header.reserved_space = ENCRYPTION_METADATA_SIZE as u8;
|
||||
}
|
||||
|
||||
@@ -1817,7 +1819,7 @@ impl Pager {
|
||||
// If the following conditions are met, allocate a pointer map page, add to cache and increment the database size
|
||||
// - autovacuum is enabled
|
||||
// - the last page is a pointer map page
|
||||
if matches!(*self.auto_vacuum_mode.borrow(), AutoVacuumMode::Full)
|
||||
if matches!(self.auto_vacuum_mode.get(), AutoVacuumMode::Full)
|
||||
&& is_ptrmap_page(new_db_size + 1, header.page_size.get() as usize)
|
||||
{
|
||||
// we will allocate a ptrmap page, so increment size
|
||||
@@ -2083,9 +2085,7 @@ impl Pager {
|
||||
fn reset_internal_states(&self) {
|
||||
self.checkpoint_state.replace(CheckpointState::Checkpoint);
|
||||
self.syncing.replace(false);
|
||||
self.commit_info.replace(CommitInfo {
|
||||
state: CommitState::Start,
|
||||
});
|
||||
self.commit_info.state.set(CommitState::Start);
|
||||
self.allocate_page_state.replace(AllocatePageState::Start);
|
||||
self.free_page_state.replace(FreePageState::Start);
|
||||
#[cfg(not(feature = "omit_autovacuum"))]
|
||||
@@ -2111,10 +2111,11 @@ impl Pager {
|
||||
Ok(IOResult::Done(f(header)))
|
||||
}
|
||||
|
||||
pub fn set_encryption_key(&self, key: Option<EncryptionKey>) {
|
||||
self.encryption_key.replace(key.clone());
|
||||
pub fn set_encryption_context(&self, key: &EncryptionKey) {
|
||||
let encryption_ctx = EncryptionContext::new(key).unwrap();
|
||||
self.encryption_ctx.replace(Some(encryption_ctx.clone()));
|
||||
let Some(wal) = self.wal.as_ref() else { return };
|
||||
wal.borrow_mut().set_encryption_key(key)
|
||||
wal.borrow_mut().set_encryption_context(encryption_ctx)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ use crate::storage::btree::offset::{
|
||||
use crate::storage::btree::{payload_overflow_threshold_max, payload_overflow_threshold_min};
|
||||
use crate::storage::buffer_pool::BufferPool;
|
||||
use crate::storage::database::DatabaseStorage;
|
||||
use crate::storage::encryption::EncryptionKey;
|
||||
use crate::storage::encryption::EncryptionContext;
|
||||
use crate::storage::pager::Pager;
|
||||
use crate::storage::wal::READMARK_NOT_USED;
|
||||
use crate::types::{RawSlice, RefValue, SerialType, SerialTypeKind, TextRef, TextSubtype};
|
||||
@@ -870,7 +870,7 @@ pub fn begin_read_page(
|
||||
page: PageRef,
|
||||
page_idx: usize,
|
||||
allow_empty_read: bool,
|
||||
encryption_key: Option<&EncryptionKey>,
|
||||
encryption_key: Option<&EncryptionContext>,
|
||||
) -> Result<Completion> {
|
||||
tracing::trace!("begin_read_btree_page(page_idx = {})", page_idx);
|
||||
let buf = buffer_pool.get_page();
|
||||
@@ -965,7 +965,7 @@ pub fn write_pages_vectored(
|
||||
pager: &Pager,
|
||||
batch: BTreeMap<usize, Arc<Buffer>>,
|
||||
done_flag: Arc<AtomicBool>,
|
||||
encryption_key: Option<&EncryptionKey>,
|
||||
encryption_key: Option<&EncryptionContext>,
|
||||
) -> Result<Vec<Completion>> {
|
||||
if batch.is_empty() {
|
||||
done_flag.store(true, Ordering::Relaxed);
|
||||
|
||||
@@ -17,7 +17,7 @@ use super::sqlite3_ondisk::{self, checksum_wal, WalHeader, WAL_MAGIC_BE, WAL_MAG
|
||||
use crate::fast_lock::SpinLock;
|
||||
use crate::io::{clock, File, IO};
|
||||
use crate::result::LimboResult;
|
||||
use crate::storage::encryption::{decrypt_page, encrypt_page, EncryptionKey};
|
||||
use crate::storage::encryption::EncryptionContext;
|
||||
use crate::storage::sqlite3_ondisk::{
|
||||
begin_read_wal_frame, begin_read_wal_frame_raw, finish_read_page, prepare_wal_frame,
|
||||
write_pages_vectored, PageSize, WAL_FRAME_HEADER_SIZE, WAL_HEADER_SIZE,
|
||||
@@ -297,7 +297,7 @@ pub trait Wal: Debug {
|
||||
/// Return unique set of pages changed **after** frame_watermark position and until current WAL session max_frame_no
|
||||
fn changed_pages_after(&self, frame_watermark: u64) -> Result<Vec<u32>>;
|
||||
|
||||
fn set_encryption_key(&mut self, key: Option<EncryptionKey>);
|
||||
fn set_encryption_context(&mut self, ctx: EncryptionContext);
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
fn as_any(&self) -> &dyn std::any::Any;
|
||||
@@ -568,7 +568,7 @@ pub struct WalFile {
|
||||
/// Manages locks needed for checkpointing
|
||||
checkpoint_guard: Option<CheckpointLocks>,
|
||||
|
||||
encryption_key: RefCell<Option<EncryptionKey>>,
|
||||
encryption_ctx: RefCell<Option<EncryptionContext>>,
|
||||
}
|
||||
|
||||
impl fmt::Debug for WalFile {
|
||||
@@ -1034,7 +1034,7 @@ impl Wal for WalFile {
|
||||
page.set_locked();
|
||||
let frame = page.clone();
|
||||
let page_idx = page.get().id;
|
||||
let key = self.encryption_key.borrow().clone();
|
||||
let encryption_ctx = self.encryption_ctx.borrow().clone();
|
||||
let seq = self.header.checkpoint_seq;
|
||||
let complete = Box::new(move |res: Result<(Arc<Buffer>, i32), CompletionError>| {
|
||||
let Ok((buf, bytes_read)) = res else {
|
||||
@@ -1047,8 +1047,8 @@ impl Wal for WalFile {
|
||||
"read({bytes_read}) less than expected({buf_len}): frame_id={frame_id}"
|
||||
);
|
||||
let cloned = frame.clone();
|
||||
if let Some(key) = key.clone() {
|
||||
match decrypt_page(buf.as_slice(), page_idx, &key) {
|
||||
if let Some(ctx) = encryption_ctx.clone() {
|
||||
match ctx.decrypt_page(buf.as_slice(), page_idx) {
|
||||
Ok(decrypted_data) => {
|
||||
buf.as_mut_slice().copy_from_slice(&decrypted_data);
|
||||
}
|
||||
@@ -1213,15 +1213,15 @@ impl Wal for WalFile {
|
||||
let page_content = page.get_contents();
|
||||
let page_buf = page_content.as_ptr();
|
||||
|
||||
let key = self.encryption_key.borrow();
|
||||
let encryption_ctx = self.encryption_ctx.borrow();
|
||||
let encrypted_data = {
|
||||
if let Some(key) = key.as_ref() {
|
||||
Some(encrypt_page(page_buf, page_id, key)?)
|
||||
if let Some(key) = encryption_ctx.as_ref() {
|
||||
Some(key.encrypt_page(page_buf, page_id)?)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
let data_to_write = if key.as_ref().is_some() {
|
||||
let data_to_write = if encryption_ctx.as_ref().is_some() {
|
||||
encrypted_data.as_ref().unwrap().as_slice()
|
||||
} else {
|
||||
page_buf
|
||||
@@ -1374,8 +1374,8 @@ impl Wal for WalFile {
|
||||
self
|
||||
}
|
||||
|
||||
fn set_encryption_key(&mut self, key: Option<EncryptionKey>) {
|
||||
self.encryption_key.replace(key);
|
||||
fn set_encryption_context(&mut self, ctx: EncryptionContext) {
|
||||
self.encryption_ctx.replace(Some(ctx));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1413,7 +1413,7 @@ impl WalFile {
|
||||
prev_checkpoint: CheckpointResult::default(),
|
||||
checkpoint_guard: None,
|
||||
header: *header,
|
||||
encryption_key: RefCell::new(None),
|
||||
encryption_ctx: RefCell::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1665,7 +1665,7 @@ impl WalFile {
|
||||
pager,
|
||||
batch_map,
|
||||
done_flag,
|
||||
self.encryption_key.borrow().as_ref(),
|
||||
self.encryption_ctx.borrow().as_ref(),
|
||||
)?);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -315,7 +315,7 @@ fn update_pragma(
|
||||
PragmaName::EncryptionKey => {
|
||||
let value = parse_string(&value)?;
|
||||
let key = EncryptionKey::from_string(&value);
|
||||
connection.set_encryption_key(Some(key));
|
||||
connection.set_encryption_key(key);
|
||||
Ok((program, TransactionMode::None))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -132,6 +132,7 @@ pub fn prepare_update_plan(
|
||||
Some(table) => table,
|
||||
None => bail_parse_error!("Parse error: no such table: {}", table_name),
|
||||
};
|
||||
let table_name = table.get_name();
|
||||
let iter_dir = body
|
||||
.order_by
|
||||
.first()
|
||||
@@ -149,7 +150,7 @@ pub fn prepare_update_plan(
|
||||
Table::BTree(btree_table) => Table::BTree(btree_table.clone()),
|
||||
_ => unreachable!(),
|
||||
},
|
||||
identifier: table_name.as_str().to_string(),
|
||||
identifier: table_name.to_string(),
|
||||
internal_id: program.table_reference_counter.next(),
|
||||
op: build_scan_op(&table, iter_dir),
|
||||
join_info: None,
|
||||
@@ -235,7 +236,7 @@ pub fn prepare_update_plan(
|
||||
Table::BTree(btree_table) => Table::BTree(btree_table.clone()),
|
||||
_ => unreachable!(),
|
||||
},
|
||||
identifier: table_name.as_str().to_string(),
|
||||
identifier: table_name.to_string(),
|
||||
internal_id,
|
||||
op: build_scan_op(&table, iter_dir),
|
||||
join_info: None,
|
||||
@@ -334,7 +335,7 @@ pub fn prepare_update_plan(
|
||||
|
||||
// Check what indexes will need to be updated by checking set_clauses and see
|
||||
// if a column is contained in an index.
|
||||
let indexes = schema.get_indices(table_name.as_str());
|
||||
let indexes = schema.get_indices(table_name);
|
||||
let rowid_alias_used = set_clauses
|
||||
.iter()
|
||||
.any(|(idx, _)| columns[*idx].is_rowid_alias);
|
||||
|
||||
@@ -31,7 +31,7 @@ macro_rules! peek_expect {
|
||||
expected: &[
|
||||
$($x,)*
|
||||
],
|
||||
got: token.token_type.unwrap(),
|
||||
got: tt,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -223,10 +223,11 @@ impl<'a> Parser<'a> {
|
||||
}
|
||||
Some(token) => {
|
||||
if !found_semi {
|
||||
let tt = token.token_type.unwrap();
|
||||
return Err(Error::ParseUnexpectedToken {
|
||||
parsed_offset: (self.offset(), 1).into(),
|
||||
expected: &[TK_SEMI],
|
||||
got: token.token_type.unwrap(),
|
||||
got: tt,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -253,7 +254,7 @@ impl<'a> Parser<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
fn next_token(&mut self) -> Result<Option<Token<'a>>> {
|
||||
fn next_token(&mut self) -> Result<Option<&Token<'a>>> {
|
||||
debug_assert!(!self.peekable);
|
||||
let mut next = self.consume_lexer_without_whitespaces_or_comments();
|
||||
|
||||
@@ -479,9 +480,9 @@ impl<'a> Parser<'a> {
|
||||
match next {
|
||||
None => Ok(None), // EOF
|
||||
Some(Ok(tok)) => {
|
||||
self.current_token = tok.clone();
|
||||
self.current_token = tok;
|
||||
self.peekable = true;
|
||||
Ok(Some(tok))
|
||||
Ok(Some(&self.current_token))
|
||||
}
|
||||
Some(Err(err)) => Err(err),
|
||||
}
|
||||
@@ -520,16 +521,21 @@ impl<'a> Parser<'a> {
|
||||
/// Get the next token from the lexer
|
||||
#[inline]
|
||||
fn eat(&mut self) -> Result<Option<Token<'a>>> {
|
||||
let result = self.peek()?;
|
||||
self.peekable = false; // Clear the peek mark after consuming
|
||||
Ok(result)
|
||||
match self.peek()? {
|
||||
None => Ok(None),
|
||||
Some(tok) => {
|
||||
let result = tok.clone();
|
||||
self.peekable = false; // Clear the peek mark after consuming
|
||||
Ok(Some(result))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Peek at the next token without consuming it
|
||||
#[inline]
|
||||
fn peek(&mut self) -> Result<Option<Token<'a>>> {
|
||||
fn peek(&mut self) -> Result<Option<&Token<'a>>> {
|
||||
if self.peekable {
|
||||
return Ok(Some(self.current_token.clone()));
|
||||
return Ok(Some(&self.current_token));
|
||||
}
|
||||
|
||||
self.next_token()
|
||||
@@ -544,7 +550,7 @@ impl<'a> Parser<'a> {
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn peek_no_eof(&mut self) -> Result<Token<'a>> {
|
||||
fn peek_no_eof(&mut self) -> Result<&Token<'a>> {
|
||||
match self.peek()? {
|
||||
None => Err(Error::ParseUnexpectedEOF),
|
||||
Some(token) => Ok(token),
|
||||
@@ -966,7 +972,7 @@ impl<'a> Parser<'a> {
|
||||
let mut type_name = if let Some(tok) = self.peek()? {
|
||||
match tok.token_type.unwrap().fallback_id_if_ok() {
|
||||
TK_ID | TK_STRING => {
|
||||
eat_assert!(self, TK_ID, TK_STRING);
|
||||
let tok = eat_assert!(self, TK_ID, TK_STRING);
|
||||
from_bytes(tok.value)
|
||||
}
|
||||
_ => return Ok(None),
|
||||
@@ -978,7 +984,7 @@ impl<'a> Parser<'a> {
|
||||
while let Some(tok) = self.peek()? {
|
||||
match tok.token_type.unwrap().fallback_id_if_ok() {
|
||||
TK_ID | TK_STRING => {
|
||||
eat_assert!(self, TK_ID, TK_STRING);
|
||||
let tok = eat_assert!(self, TK_ID, TK_STRING);
|
||||
type_name.push(' ');
|
||||
type_name.push_str(from_bytes_as_str(tok.value));
|
||||
}
|
||||
@@ -1324,25 +1330,25 @@ impl<'a> Parser<'a> {
|
||||
Ok(Box::new(Expr::Literal(Literal::Null)))
|
||||
}
|
||||
TK_BLOB => {
|
||||
eat_assert!(self, TK_BLOB);
|
||||
let tok = eat_assert!(self, TK_BLOB);
|
||||
Ok(Box::new(Expr::Literal(Literal::Blob(from_bytes(
|
||||
tok.value,
|
||||
)))))
|
||||
}
|
||||
TK_FLOAT => {
|
||||
eat_assert!(self, TK_FLOAT);
|
||||
let tok = eat_assert!(self, TK_FLOAT);
|
||||
Ok(Box::new(Expr::Literal(Literal::Numeric(from_bytes(
|
||||
tok.value,
|
||||
)))))
|
||||
}
|
||||
TK_INTEGER => {
|
||||
eat_assert!(self, TK_INTEGER);
|
||||
let tok = eat_assert!(self, TK_INTEGER);
|
||||
Ok(Box::new(Expr::Literal(Literal::Numeric(from_bytes(
|
||||
tok.value,
|
||||
)))))
|
||||
}
|
||||
TK_VARIABLE => {
|
||||
eat_assert!(self, TK_VARIABLE);
|
||||
let tok = eat_assert!(self, TK_VARIABLE);
|
||||
Ok(Box::new(Expr::Variable(from_bytes(tok.value))))
|
||||
}
|
||||
TK_CAST => {
|
||||
|
||||
Reference in New Issue
Block a user