diff --git a/core/lib.rs b/core/lib.rs index 840b4cfc8..c72ee9a3d 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -422,6 +422,12 @@ impl Connection { fn update_last_rowid(&self, rowid: u64) { self.last_insert_rowid.set(rowid); } + + pub fn set_changes(&self, nchange: i64) { + self.last_change.set(nchange); + let prev_total_changes = self.total_changes.get(); + self.total_changes.set(prev_total_changes + nchange); + } } pub struct Statement { diff --git a/core/translate/mod.rs b/core/translate/mod.rs index 4a13ec1fd..28324e943 100644 --- a/core/translate/mod.rs +++ b/core/translate/mod.rs @@ -48,6 +48,7 @@ pub fn translate( syms: &SymbolTable, ) -> Result { let mut program = ProgramBuilder::new(); + let mut change_cnt_on = false; match stmt { ast::Stmt::AlterTable(_, _) => bail_parse_error!("ALTER TABLE not supported yet"), @@ -79,6 +80,7 @@ pub fn translate( limit, .. } => { + change_cnt_on = true; translate_delete(&mut program, schema, &tbl_name, where_clause, limit, syms)?; } ast::Stmt::Detach(_) => bail_parse_error!("DETACH not supported yet"), @@ -106,6 +108,7 @@ pub fn translate( body, returning, } => { + change_cnt_on = true; translate_insert( &mut program, schema, @@ -120,7 +123,7 @@ pub fn translate( } } - Ok(program.build(database_header, connection)) + Ok(program.build(database_header, connection, change_cnt_on)) } /* Example: diff --git a/core/vdbe/builder.rs b/core/vdbe/builder.rs index 82be12f1c..07ef77307 100644 --- a/core/vdbe/builder.rs +++ b/core/vdbe/builder.rs @@ -1,5 +1,5 @@ use std::{ - cell::RefCell, + cell::{Cell, RefCell}, collections::HashMap, rc::{Rc, Weak}, }; @@ -327,6 +327,7 @@ impl ProgramBuilder { mut self, database_header: Rc>, connection: Weak, + change_cnt_on: bool, ) -> Program { self.resolve_labels(); assert!( @@ -343,6 +344,8 @@ impl ProgramBuilder { connection, auto_commit: true, parameters: self.parameters, + n_change: Cell::new(0), + change_cnt_on, } } } diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 6e134a6bc..cc64c68a1 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -55,7 +55,7 @@ use rand::{thread_rng, Rng}; use regex::{Regex, RegexBuilder}; use sorter::Sorter; use std::borrow::BorrowMut; -use std::cell::RefCell; +use std::cell::{Cell, RefCell}; use std::collections::{BTreeMap, HashMap}; use std::num::NonZero; use std::rc::{Rc, Weak}; @@ -282,6 +282,8 @@ pub struct Program { pub parameters: crate::parameters::Parameters, pub connection: Weak, pub auto_commit: bool, + pub n_change: Cell, + pub change_cnt_on: bool, } impl Program { @@ -892,11 +894,24 @@ impl Program { return if self.auto_commit { match pager.end_tx() { Ok(crate::storage::wal::CheckpointStatus::IO) => Ok(StepResult::IO), - Ok(crate::storage::wal::CheckpointStatus::Done) => Ok(StepResult::Done), + Ok(crate::storage::wal::CheckpointStatus::Done) => { + if self.change_cnt_on { + self.connection + .upgrade() + .unwrap() + .set_changes(self.n_change.get()); + } + Ok(StepResult::Done) + } Err(e) => Err(e), } } else { - Ok(StepResult::Done) + if self.change_cnt_on { + if let Some(conn) = self.connection.upgrade() { + conn.set_changes(self.n_change.get()); + } + } + return Ok(StepResult::Done); }; } Insn::Transaction { write } => { @@ -2076,10 +2091,9 @@ impl Program { if let Some(rowid) = cursor.rowid()? { if let Some(conn) = self.connection.upgrade() { conn.update_last_rowid(rowid); - let prev_total_changes = conn.total_changes.get(); - conn.last_change.set(1); - conn.total_changes.set(prev_total_changes + 1); } + let prev_changes = self.n_change.get(); + self.n_change.set(prev_changes + 1); } } state.pc += 1;