try to speed up count(*) where 1 = 1

This commit is contained in:
Nikita Sivukhin
2025-09-02 03:07:18 +04:00
parent c374cf0c93
commit db7c6b3370
4 changed files with 120 additions and 87 deletions

View File

@@ -721,6 +721,7 @@ impl BTreeCursor {
continue;
}
}
if cell_idx >= cell_count as i32 {
self.stack.set_cell_index(cell_count as i32 - 1);
} else if !self.stack.current_cell_index_less_than_min() {
@@ -756,6 +757,7 @@ impl BTreeCursor {
continue;
}
if contents.is_leaf() {
self.going_upwards = false;
return Ok(IOResult::Done(true));
}
@@ -1204,6 +1206,14 @@ impl BTreeCursor {
}
}
loop {
let cell_idx = self.stack.current_cell_index();
let cell_count = self.stack.leaf_cell_count();
if cell_idx != -1 && cell_count.is_some() && cell_idx + 1 < cell_count.unwrap() {
self.stack.advance();
return Ok(IOResult::Done(true));
}
self.stack.set_leaf_cell_count(None);
let mem_page = self.stack.top();
let contents = mem_page.get_contents();
let cell_count = contents.cell_count();
@@ -1273,6 +1283,7 @@ impl BTreeCursor {
);
if contents.is_leaf() {
self.stack.set_leaf_cell_count(Some(cell_count as i32));
return Ok(IOResult::Done(true));
}
if is_index && self.going_upwards {
@@ -4203,7 +4214,6 @@ impl BTreeCursor {
let cursor_has_record = return_if_io!(self.get_next_record());
self.has_record.replace(cursor_has_record);
self.invalidate_record();
self.advance_state = AdvanceState::Start;
return Ok(IOResult::Done(cursor_has_record));
}
}
@@ -4231,7 +4241,6 @@ impl BTreeCursor {
let cursor_has_record = return_if_io!(self.get_prev_record());
self.has_record.replace(cursor_has_record);
self.invalidate_record();
self.advance_state = AdvanceState::Start;
return Ok(IOResult::Done(cursor_has_record));
}
}
@@ -5148,7 +5157,8 @@ impl BTreeCursor {
}
fn get_immutable_record_or_create(&self) -> std::cell::RefMut<'_, Option<ImmutableRecord>> {
if self.reusable_immutable_record.borrow().is_none() {
let mut reusable_immutable_record = self.reusable_immutable_record.borrow_mut();
if reusable_immutable_record.is_none() {
let page_size = self
.pager
.page_size
@@ -5156,9 +5166,9 @@ impl BTreeCursor {
.expect("page size is not set")
.get();
let record = ImmutableRecord::new(page_size as usize);
self.reusable_immutable_record.replace(Some(record));
reusable_immutable_record.replace(record);
}
self.reusable_immutable_record.borrow_mut()
reusable_immutable_record
}
fn get_immutable_record(&self) -> std::cell::RefMut<'_, Option<ImmutableRecord>> {
@@ -5304,7 +5314,7 @@ impl BTreeCursor {
self.context = None;
self.valid_state = CursorValidState::Valid;
return Ok(IOResult::Done(()));
}
};
let ctx = self.context.take().unwrap();
let seek_key = match ctx.key {
CursorContextKey::TableRowId(rowid) => SeekKey::TableRowId(rowid),
@@ -5985,9 +5995,12 @@ impl PageStack {
page.unpin();
}
assert!(current > 0);
self.node_states[current] = BTreeNodeState::default();
self.stack[current] = None;
assert!(current > 0);
// cell_count must be unset by default for the page that becomes the new top of the stack
// (otherwise the caller may assume it is at a leaf page and take the hot-path optimization)
self.node_states[current - 1].cell_count = None;
self.current_page -= 1;
}
@@ -6003,6 +6016,7 @@ impl PageStack {
}
/// Current page pointer being used
#[inline(always)]
fn current(&self) -> usize {
assert!(self.current_page >= 0);
let current = self.current_page as usize;
@@ -6015,6 +6029,20 @@ impl PageStack {
self.node_states[current].cell_idx
}
/// Returns the cell count recorded in the node state of the current page, if any.
/// The caller must ensure that this method is only called for a leaf page.
fn leaf_cell_count(&self) -> Option<i32> {
    self.node_states[self.current()].cell_count
}
/// Stores `cell_count` (or clears it when `None`) in the node state of the current page.
/// The caller must ensure that this method is only called for a leaf page.
fn set_leaf_cell_count(&mut self, cell_count: Option<i32>) {
    let top = self.current();
    self.node_states[top].cell_count = cell_count;
}
/// Check if the current cell index is less than 0.
/// This means we have been iterating backwards and have reached the start of the page.
fn current_cell_index_less_than_min(&self) -> bool {
@@ -6024,13 +6052,14 @@ impl PageStack {
/// Advance the current cell index of the current page to the next cell.
/// We usually advance after traversing to a new page
#[instrument(skip(self), level = Level::DEBUG, name = "pagestack::advance",)]
// #[instrument(skip(self), level = Level::DEBUG, name = "pagestack::advance",)]
#[inline(always)]
fn advance(&mut self) {
let current = self.current();
tracing::trace!(
curr_cell_index = self.node_states[current].cell_idx,
node_states = ?self.node_states.iter().map(|state| state.cell_idx).collect::<Vec<_>>(),
);
// tracing::trace!(
// curr_cell_index = self.node_states[current].cell_idx,
// node_states = ?self.node_states.iter().map(|state| state.cell_idx).collect::<Vec<_>>(),
// );
self.node_states[current].cell_idx += 1;
}

View File

@@ -759,83 +759,89 @@ impl Ord for Value {
impl std::ops::Add<Value> for Value {
type Output = Value;
fn add(self, rhs: Self) -> Self::Output {
match (self, rhs) {
(Self::Integer(int_left), Self::Integer(int_right)) => {
Self::Integer(int_left + int_right)
}
(Self::Integer(int_left), Self::Float(float_right)) => {
Self::Float(int_left as f64 + float_right)
}
(Self::Float(float_left), Self::Integer(int_right)) => {
Self::Float(float_left + int_right as f64)
}
(Self::Float(float_left), Self::Float(float_right)) => {
Self::Float(float_left + float_right)
}
(Self::Text(string_left), Self::Text(string_right)) => {
Self::build_text(&(string_left.as_str().to_string() + string_right.as_str()))
}
(Self::Text(string_left), Self::Integer(int_right)) => {
Self::build_text(&(string_left.as_str().to_string() + &int_right.to_string()))
}
(Self::Integer(int_left), Self::Text(string_right)) => {
Self::build_text(&(int_left.to_string() + string_right.as_str()))
}
(Self::Text(string_left), Self::Float(float_right)) => {
let string_right = Self::Float(float_right).to_string();
Self::build_text(&(string_left.as_str().to_string() + &string_right))
}
(Self::Float(float_left), Self::Text(string_right)) => {
let string_left = Self::Float(float_left).to_string();
Self::build_text(&(string_left + string_right.as_str()))
}
(lhs, Self::Null) => lhs,
(Self::Null, rhs) => rhs,
_ => Self::Float(0.0),
}
fn add(mut self, rhs: Self) -> Self::Output {
self += rhs;
self
}
}
impl std::ops::Add<f64> for Value {
type Output = Value;
fn add(self, rhs: f64) -> Self::Output {
match self {
Self::Integer(int_left) => Self::Float(int_left as f64 + rhs),
Self::Float(float_left) => Self::Float(float_left + rhs),
_ => unreachable!(),
}
fn add(mut self, rhs: f64) -> Self::Output {
self += rhs;
self
}
}
impl std::ops::Add<i64> for Value {
type Output = Value;
fn add(self, rhs: i64) -> Self::Output {
match self {
Self::Integer(int_left) => Self::Integer(int_left + rhs),
Self::Float(float_left) => Self::Float(float_left + rhs as f64),
_ => unreachable!(),
}
fn add(mut self, rhs: i64) -> Self::Output {
self += rhs;
self
}
}
impl std::ops::AddAssign for Value {
fn add_assign(&mut self, rhs: Self) {
*self = self.clone() + rhs;
fn add_assign(mut self: &mut Self, rhs: Self) {
match (&mut self, rhs) {
(Self::Integer(int_left), Self::Integer(int_right)) => *int_left += int_right,
(Self::Integer(int_left), Self::Float(float_right)) => {
*self = Self::Float(*int_left as f64 + float_right)
}
(Self::Float(float_left), Self::Integer(int_right)) => {
*self = Self::Float(*float_left + int_right as f64)
}
(Self::Float(float_left), Self::Float(float_right)) => {
*float_left += float_right;
}
(Self::Text(string_left), Self::Text(string_right)) => {
string_left.value.extend_from_slice(&string_right.value);
string_left.subtype = TextSubtype::Text;
}
(Self::Text(string_left), Self::Integer(int_right)) => {
let string_right = int_right.to_string();
string_left.value.extend_from_slice(string_right.as_bytes());
string_left.subtype = TextSubtype::Text;
}
(Self::Integer(int_left), Self::Text(string_right)) => {
let string_left = int_left.to_string();
*self = Self::build_text(&(string_left + string_right.as_str()));
}
(Self::Text(string_left), Self::Float(float_right)) => {
let string_right = Self::Float(float_right).to_string();
string_left.value.extend_from_slice(string_right.as_bytes());
string_left.subtype = TextSubtype::Text;
}
(Self::Float(float_left), Self::Text(string_right)) => {
let string_left = Self::Float(*float_left).to_string();
*self = Self::build_text(&(string_left + string_right.as_str()));
}
(_, Self::Null) => {}
(Self::Null, rhs) => *self = rhs,
_ => *self = Self::Float(0.0),
}
}
}
impl std::ops::AddAssign<i64> for Value {
fn add_assign(&mut self, rhs: i64) {
*self = self.clone() + rhs;
match self {
Self::Integer(int_left) => *int_left += rhs,
Self::Float(float_left) => *float_left += rhs as f64,
_ => unreachable!(),
}
}
}
impl std::ops::AddAssign<f64> for Value {
fn add_assign(&mut self, rhs: f64) {
*self = self.clone() + rhs;
match self {
Self::Integer(int_left) => *self = Self::Float(*int_left as f64 + rhs),
Self::Float(float_left) => *float_left += rhs,
_ => unreachable!(),
}
}
}
@@ -2475,9 +2481,10 @@ impl<T> IOResult<T> {
#[macro_export]
macro_rules! return_if_io {
($expr:expr) => {
match $expr? {
IOResult::Done(v) => v,
IOResult::IO(io) => return Ok(IOResult::IO(io)),
match $expr {
Ok(IOResult::Done(v)) => v,
Ok(IOResult::IO(io)) => return Ok(IOResult::IO(io)),
Err(err) => return Err(err),
}
};
}

View File

@@ -118,9 +118,10 @@ macro_rules! load_insn {
macro_rules! return_if_io {
($expr:expr) => {
match $expr? {
IOResult::Done(v) => v,
IOResult::IO(io) => return Ok(InsnFunctionStepResult::IO(io)),
match $expr {
Ok(IOResult::Done(v)) => v,
Ok(IOResult::IO(io)) => return Ok(InsnFunctionStepResult::IO(io)),
Err(err) => return Err(err),
}
};
}
@@ -3494,22 +3495,22 @@ pub fn op_agg_step(
}
}
AggFunc::Count | AggFunc::Count0 => {
let col = state.registers[*col].get_value().clone();
let skip = (matches!(func, AggFunc::Count)
&& matches!(state.registers[*col].get_value(), Value::Null));
if matches!(&state.registers[*acc_reg], Register::Value(Value::Null)) {
state.registers[*acc_reg] =
Register::Aggregate(AggContext::Count(Value::Integer(0)));
}
let Register::Aggregate(agg) = state.registers[*acc_reg].borrow_mut() else {
let Register::Aggregate(agg) = &mut state.registers[*acc_reg] else {
panic!(
"Unexpected value {:?} in AggStep at register {}",
state.registers[*acc_reg], *acc_reg
);
};
let AggContext::Count(count) = agg.borrow_mut() else {
let AggContext::Count(count) = agg else {
unreachable!();
};
if !(matches!(func, AggFunc::Count) && matches!(col, Value::Null)) {
if !skip {
*count += 1;
};
}

View File

@@ -415,15 +415,11 @@ impl Register {
/// Resolves `$cursor_id` to its cursor after checking that the cursor type is accepted.
///
/// The cursor type is looked up in `$cursor_ref`. For `CursorType::BTreeTable`,
/// `CursorType::BTreeIndex` and `CursorType::MaterializedView` the macro evaluates to
/// `$state.get_cursor($cursor_id)`.
///
/// # Panics
/// Panics if `$cursor_id` is not present in `$cursor_ref` (via `unwrap`), or if the cursor
/// has any other type; `$insn_name` is included in that panic message.
macro_rules! must_be_btree_cursor {
    ($cursor_id:expr, $cursor_ref:expr, $state:expr, $insn_name:expr) => {{
        let (_, cursor_type) = $cursor_ref.get($cursor_id).unwrap();
        // Every accepted variant has to be an alternative inside the single `matches!`
        // pattern; an alternative placed after the closing parenthesis is parsed as an
        // expression operand of `|`, not as a pattern.
        if matches!(
            cursor_type,
            CursorType::BTreeTable(_)
                | CursorType::BTreeIndex(_)
                | CursorType::MaterializedView(_, _)
        ) {
            // NOTE(review): assumes `$state` exposes a `get_cursor` method — confirm.
            $state.get_cursor($cursor_id)
        } else {
            panic!("{} on unexpected cursor", $insn_name)
        }
    }};
}
@@ -471,6 +467,7 @@ impl Program {
mv_store: Option<Arc<MvStore>>,
pager: Rc<Pager>,
) -> Result<StepResult> {
let enable_tracing = tracing::enabled!(tracing::Level::TRACE);
loop {
if self.connection.closed.get() {
// Connection is closed for whatever reason, rollback the transaction.
@@ -497,7 +494,9 @@ impl Program {
// invalidate row
let _ = state.result_row.take();
let (insn, insn_function) = &self.insns[state.pc as usize];
trace_insn(self, state.pc as InsnReference, insn);
if enable_tracing {
trace_insn(self, state.pc as InsnReference, insn);
}
// Always increment VM steps for every loop iteration
state.metrics.vm_steps = state.metrics.vm_steps.saturating_add(1);
@@ -832,9 +831,6 @@ pub fn registers_to_ref_values(registers: &[Register]) -> Vec<RefValue> {
#[instrument(skip(program), level = Level::DEBUG)]
fn trace_insn(program: &Program, addr: InsnReference, insn: &Insn) {
if !tracing::enabled!(tracing::Level::TRACE) {
return;
}
tracing::trace!(
"\n{}",
explain::insn_to_str(