diff --git a/core/function.rs b/core/function.rs index c0baf807a..1c5790251 100644 --- a/core/function.rs +++ b/core/function.rs @@ -84,6 +84,7 @@ pub enum ScalarFunc { Hex, Unhex, ZeroBlob, + LastInsertRowid, } impl Display for ScalarFunc { @@ -124,6 +125,7 @@ impl Display for ScalarFunc { ScalarFunc::Hex => "hex".to_string(), ScalarFunc::Unhex => "unhex".to_string(), ScalarFunc::ZeroBlob => "zeroblob".to_string(), + ScalarFunc::LastInsertRowid => "last_insert_rowid".to_string(), }; write!(f, "{}", str) } @@ -192,6 +194,7 @@ impl Func { "date" => Ok(Func::Scalar(ScalarFunc::Date)), "time" => Ok(Func::Scalar(ScalarFunc::Time)), "typeof" => Ok(Func::Scalar(ScalarFunc::Typeof)), + "last_insert_rowid" => Ok(Func::Scalar(ScalarFunc::LastInsertRowid)), "unicode" => Ok(Func::Scalar(ScalarFunc::Unicode)), "quote" => Ok(Func::Scalar(ScalarFunc::Quote)), "sqlite_version" => Ok(Func::Scalar(ScalarFunc::SqliteVersion)), diff --git a/core/lib.rs b/core/lib.rs index bc97a7c7b..83210483a 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -20,6 +20,7 @@ use log::trace; use schema::Schema; use sqlite3_parser::ast; use sqlite3_parser::{ast::Cmd, lexer::sql::Parser}; +use std::cell::Cell; use std::rc::Weak; use std::sync::{Arc, OnceLock}; use std::{cell::RefCell, rc::Rc}; @@ -105,6 +106,7 @@ impl Database { schema: bootstrap_schema.clone(), header: db_header.clone(), db: Weak::new(), + last_insert_rowid: Cell::new(0), }); let mut schema = Schema::new(); let rows = conn.query("SELECT * FROM sqlite_schema")?; @@ -125,6 +127,7 @@ impl Database { schema: self.schema.clone(), header: self.header.clone(), db: Rc::downgrade(self), + last_insert_rowid: Cell::new(0), }) } } @@ -175,6 +178,7 @@ pub struct Connection { schema: Rc>, header: Rc>, db: Weak, // backpointer to the database holding this connection + last_insert_rowid: Cell, } impl Connection { @@ -310,6 +314,14 @@ impl Connection { }; } } + + pub fn last_insert_rowid(&self) -> u64 { + self.last_insert_rowid.get() + } + + fn update_last_rowid(&self, rowid: u64) { + self.last_insert_rowid.set(rowid); + } } pub struct Statement { diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 6c0b4437d..ca87f2686 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -858,6 +858,16 @@ pub fn translate_expr( Ok(target_register) } + ScalarFunc::LastInsertRowid => { + let regs = program.alloc_register(); + program.emit_insn(Insn::Function { + constant_mask: 0, + start_reg: regs, + dest: target_register, + func: func_ctx, + }); + Ok(target_register) + } ScalarFunc::Concat => { let args = if let Some(args) = args { args diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 906f43f73..c1fa2621a 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -2105,6 +2105,13 @@ impl Program { state.registers[*dest] = result; } ScalarFunc::IfNull => {} + ScalarFunc::LastInsertRowid => { + if let Some(conn) = self.connection.upgrade() { + state.registers[*dest] = OwnedValue::Integer(conn.last_insert_rowid() as i64); + } else { + state.registers[*dest] = OwnedValue::Null; + } + } ScalarFunc::Instr => { let reg_value = &state.registers[*start_reg]; let pattern_value = &state.registers[*start_reg + 1]; @@ -2314,6 +2321,12 @@ impl Program { Insn::InsertAwait { cursor_id } => { let cursor = cursors.get_mut(cursor_id).unwrap(); cursor.wait_for_completion()?; + if let Some(rowid) = cursor.rowid()? { + if let Some(conn) = self.connection.upgrade() { + println!("rowid: {}", rowid); + conn.update_last_rowid(rowid); + } + } state.pc += 1; } Insn::NewRowid {