From aa8fcdbe54546cce13897507820ab4858293dbe2 Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Fri, 19 Sep 2025 03:58:35 -0500 Subject: [PATCH] move the aggregate operator to its own file. The code is becoming impossible to reason about with everything in operator.rs --- core/incremental/aggregate_operator.rs | 1787 ++++++++++++++++++++++++ core/incremental/mod.rs | 1 + core/incremental/operator.rs | 1151 +-------------- core/incremental/persistence.rs | 678 +-------- 4 files changed, 1796 insertions(+), 1821 deletions(-) create mode 100644 core/incremental/aggregate_operator.rs diff --git a/core/incremental/aggregate_operator.rs b/core/incremental/aggregate_operator.rs new file mode 100644 index 000000000..f4c8ece0a --- /dev/null +++ b/core/incremental/aggregate_operator.rs @@ -0,0 +1,1787 @@ +// Aggregate operator for DBSP-style incremental computation + +use crate::function::{AggFunc, Func}; +use crate::incremental::dbsp::{Delta, DeltaPair, HashableRow}; +use crate::incremental::operator::{ + generate_storage_id, ComputationTracker, DbspStateCursors, EvalState, IncrementalOperator, +}; +use crate::incremental::persistence::{ReadRecord, WriteRow}; +use crate::types::{IOResult, ImmutableRecord, RefValue, SeekKey, SeekOp, SeekResult}; +use crate::{return_and_restore_if_io, return_if_io, LimboError, Result, Value}; +use std::collections::{BTreeMap, HashMap, HashSet}; +use std::fmt::{self, Display}; +use std::sync::{Arc, Mutex}; + +/// Constants for aggregate type encoding in storage IDs (2 bits) +pub const AGG_TYPE_REGULAR: u8 = 0b00; // COUNT/SUM/AVG +pub const AGG_TYPE_MINMAX: u8 = 0b01; // MIN/MAX (BTree ordering gives both) + +#[derive(Debug, Clone, PartialEq)] +pub enum AggregateFunction { + Count, + Sum(String), + Avg(String), + Min(String), + Max(String), +} + +impl Display for AggregateFunction { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AggregateFunction::Count => write!(f, "COUNT(*)"), + AggregateFunction::Sum(col) => write!(f, "SUM({col})"), + AggregateFunction::Avg(col) => write!(f, "AVG({col})"), + AggregateFunction::Min(col) => write!(f, "MIN({col})"), + AggregateFunction::Max(col) => write!(f, "MAX({col})"), + } + } +} + +impl AggregateFunction { + /// Get the default output column name for this aggregate function + #[inline] + pub fn default_output_name(&self) -> String { + self.to_string() + } + + /// Create an AggregateFunction from a SQL function and its arguments + /// Returns None if the function is not a supported aggregate + pub fn from_sql_function( + func: &crate::function::Func, + input_column: Option, + ) -> Option { + match func { + Func::Agg(agg_func) => { + match agg_func { + AggFunc::Count | AggFunc::Count0 => Some(AggregateFunction::Count), + AggFunc::Sum => input_column.map(AggregateFunction::Sum), + AggFunc::Avg => input_column.map(AggregateFunction::Avg), + AggFunc::Min => input_column.map(AggregateFunction::Min), + AggFunc::Max => input_column.map(AggregateFunction::Max), + _ => None, // Other aggregate functions not yet supported in DBSP + } + } + _ => None, // Not an aggregate function + } + } +} + +/// Information about a column that has MIN/MAX aggregations +#[derive(Debug, Clone)] +pub struct AggColumnInfo { + /// Index used for storage key generation + pub index: usize, + /// Whether this column has a MIN aggregate + pub has_min: bool, + /// Whether this column has a MAX aggregate + pub has_max: bool, +} + +/// Serialize a Value using SQLite's serial type format +/// This is used for MIN/MAX values that need to 
be stored in a compact, sortable format
+pub fn serialize_value(value: &Value, blob: &mut Vec<u8>) {
+    let serial_type = crate::types::SerialType::from(value);
+    let serial_type_u64: u64 = serial_type.into();
+    crate::storage::sqlite3_ondisk::write_varint_to_vec(serial_type_u64, blob);
+    value.serialize_serial(blob);
+}
+
+/// Deserialize a Value using SQLite's serial type format
+/// Returns the deserialized value and the number of bytes consumed
+pub fn deserialize_value(blob: &[u8]) -> Option<(Value, usize)> {
+    let mut cursor = 0;
+
+    // Read the serial type
+    let (serial_type, varint_size) = crate::storage::sqlite3_ondisk::read_varint(blob).ok()?;
+    cursor += varint_size;
+
+    let serial_type_obj = crate::types::SerialType::try_from(serial_type).ok()?;
+    let expected_size = serial_type_obj.size();
+
+    // Read the value
+    let (value, actual_size) =
+        crate::storage::sqlite3_ondisk::read_value(&blob[cursor..], serial_type_obj).ok()?;
+
+    // Verify that the actual size matches what we expected from the serial type
+    if actual_size != expected_size {
+        return None; // Data corruption - size mismatch
+    }
+
+    cursor += actual_size;
+
+    // Convert RefValue to Value
+    Some((value.to_owned(), cursor))
+}
+
+// group_key_str -> (group_key, state)
+type ComputedStates = HashMap<String, (Vec<Value>, AggregateState)>;
+// group_key_str -> (column_name, value_as_hashable_row) -> accumulated_weight
+pub type MinMaxDeltas = HashMap<String, HashMap<(String, HashableRow), isize>>;
+
+#[derive(Debug)]
+enum AggregateCommitState {
+    Idle,
+    Eval {
+        eval_state: EvalState,
+    },
+    PersistDelta {
+        delta: Delta,
+        computed_states: ComputedStates,
+        current_idx: usize,
+        write_row: WriteRow,
+        min_max_deltas: MinMaxDeltas,
+    },
+    PersistMinMax {
+        delta: Delta,
+        min_max_persist_state: MinMaxPersistState,
+    },
+    Done {
+        delta: Delta,
+    },
+    Invalid,
+}
+
+// Aggregate-specific eval states
+#[derive(Debug)]
+pub enum AggregateEvalState {
+    FetchKey {
+        delta: Delta, // Keep original delta for merge operation
+        current_idx: usize,
+        groups_to_read: Vec<(String, Vec<Value>)>, // Changed to Vec for index-based access
+        existing_groups: HashMap<String, AggregateState>,
+        old_values: HashMap<String, Vec<Value>>,
+    },
+    FetchData {
+        delta: Delta, // Keep original delta for merge operation
+        current_idx: usize,
+        groups_to_read: Vec<(String, Vec<Value>)>, // Changed to Vec for index-based access
+        existing_groups: HashMap<String, AggregateState>,
+        old_values: HashMap<String, Vec<Value>>,
+        rowid: Option<i64>, // Rowid found by FetchKey (None if not found)
+        read_record_state: Box<ReadRecord>,
+    },
+    RecomputeMinMax {
+        delta: Delta,
+        existing_groups: HashMap<String, AggregateState>,
+        old_values: HashMap<String, Vec<Value>>,
+        recompute_state: Box<RecomputeMinMax>,
+    },
+    Done {
+        output: (Delta, ComputedStates),
+    },
+}
+
+/// Note that the AggregateOperator essentially implements a ZSet, even
+/// though the ZSet structure is never used explicitly. The on-disk btree
+/// plays the role of the set!
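+///
+/// A sketch of the persisted shape (illustrative, based on the PersistDelta
+/// step below, not an additional format definition): each group becomes one
+/// btree row,
+///
+///   index key:  (storage_id, zset_id, element_id = 0)
+///   table row:  [storage_id, zset_id, element_id, state_blob], plus a weight
+///
+/// where storage_id encodes (operator_id, column_index, aggregate type) via
+/// generate_storage_id, and zset_id is the hashed group key. The set of rows
+/// under one storage_id is exactly the ZSet for this operator.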
+#[derive(Debug)]
+pub struct AggregateOperator {
+    // Unique operator ID for indexing in persistent storage
+    pub operator_id: usize,
+    // GROUP BY columns
+    group_by: Vec<String>,
+    // Aggregate functions to compute (including MIN/MAX)
+    pub aggregates: Vec<AggregateFunction>,
+    // Column names from input
+    pub input_column_names: Vec<String>,
+    // Map from column name to aggregate info for quick lookup
+    pub column_min_max: HashMap<String, AggColumnInfo>,
+    tracker: Option<Arc<Mutex<ComputationTracker>>>,
+
+    // State machine for commit operation
+    commit_state: AggregateCommitState,
+}
+
+/// State for a single group's aggregates
+#[derive(Debug, Clone, Default)]
+pub struct AggregateState {
+    // For COUNT: just the count
+    pub count: i64,
+    // For SUM: column_name -> sum value
+    sums: HashMap<String, f64>,
+    // For AVG: column_name -> (sum, count) for computing average
+    avgs: HashMap<String, (f64, i64)>,
+    // For MIN: column_name -> minimum value
+    pub mins: HashMap<String, Value>,
+    // For MAX: column_name -> maximum value
+    pub maxs: HashMap<String, Value>,
+}
+
+impl AggregateEvalState {
+    fn process_delta(
+        &mut self,
+        operator: &mut AggregateOperator,
+        cursors: &mut DbspStateCursors,
+    ) -> Result<IOResult<(Delta, ComputedStates)>> {
+        loop {
+            match self {
+                AggregateEvalState::FetchKey {
+                    delta,
+                    current_idx,
+                    groups_to_read,
+                    existing_groups,
+                    old_values,
+                } => {
+                    if *current_idx >= groups_to_read.len() {
+                        // All groups have been fetched, move to RecomputeMinMax
+                        // Extract MIN/MAX deltas from the input delta
+                        let min_max_deltas = operator.extract_min_max_deltas(delta);
+
+                        let recompute_state = Box::new(RecomputeMinMax::new(
+                            min_max_deltas,
+                            existing_groups,
+                            operator,
+                        ));
+
+                        *self = AggregateEvalState::RecomputeMinMax {
+                            delta: std::mem::take(delta),
+                            existing_groups: std::mem::take(existing_groups),
+                            old_values: std::mem::take(old_values),
+                            recompute_state,
+                        };
+                    } else {
+                        // Get the current group to read
+                        let (group_key_str, _group_key) = &groups_to_read[*current_idx];
+
+                        // Build the key for the index: (operator_id, zset_id, element_id)
+                        // For regular aggregates, use column_index=0 and AGG_TYPE_REGULAR
+                        let operator_storage_id =
+                            generate_storage_id(operator.operator_id, 0, AGG_TYPE_REGULAR);
+                        let zset_id = operator.generate_group_rowid(group_key_str);
+                        let element_id = 0i64; // Always 0 for aggregators
+
+                        // Create index key values
+                        let index_key_values = vec![
+                            Value::Integer(operator_storage_id),
+                            Value::Integer(zset_id),
+                            Value::Integer(element_id),
+                        ];
+
+                        // Create an immutable record for the index key
+                        let index_record =
+                            ImmutableRecord::from_values(&index_key_values, index_key_values.len());
+
+                        // Seek in the index to find if this row exists
+                        let seek_result = return_if_io!(cursors.index_cursor.seek(
+                            SeekKey::IndexKey(&index_record),
+                            SeekOp::GE { eq_only: true }
+                        ));
+
+                        let rowid = if matches!(seek_result, SeekResult::Found) {
+                            // Found in index, get the table rowid
+                            // The btree code handles extracting the rowid from the index record for has_rowid indexes
+                            return_if_io!(cursors.index_cursor.rowid())
+                        } else {
+                            // Not found in index, no existing state
+                            None
+                        };
+
+                        // Always transition to FetchData
+                        let taken_existing = std::mem::take(existing_groups);
+                        let taken_old_values = std::mem::take(old_values);
+                        let next_state = AggregateEvalState::FetchData {
+                            delta: std::mem::take(delta),
+                            current_idx: *current_idx,
+                            groups_to_read: std::mem::take(groups_to_read),
+                            existing_groups: taken_existing,
+                            old_values: taken_old_values,
+                            rowid,
+                            read_record_state: Box::new(ReadRecord::new()),
+                        };
+                        *self = next_state;
+                    }
+                }
+                AggregateEvalState::FetchData {
+                    delta,
+                    current_idx,
+                    groups_to_read,
+                    existing_groups,
+                    old_values,
+                    rowid,
+                    read_record_state,
+                } => {
+                    // Get the current group to read
+                    let (group_key_str, group_key) = &groups_to_read[*current_idx];
+
+                    // Only try to read if we have a rowid
+                    if let Some(rowid) = rowid {
+                        let key = SeekKey::TableRowId(*rowid);
+                        let state = return_if_io!(read_record_state.read_record(
+                            key,
+                            &operator.aggregates,
+                            &mut cursors.table_cursor
+                        ));
+                        // Process the fetched state
+                        if let Some(state) = state {
+                            let mut old_row = group_key.clone();
+                            old_row.extend(state.to_values(&operator.aggregates));
+                            old_values.insert(group_key_str.clone(), old_row);
+                            existing_groups.insert(group_key_str.clone(), state.clone());
+                        }
+                    } else {
+                        // No rowid for this group: there is no existing state to read
+                    }
+
+                    // Move to next group
+                    let next_idx = *current_idx + 1;
+                    let taken_existing = std::mem::take(existing_groups);
+                    let taken_old_values = std::mem::take(old_values);
+                    let next_state = AggregateEvalState::FetchKey {
+                        delta: std::mem::take(delta),
+                        current_idx: next_idx,
+                        groups_to_read: std::mem::take(groups_to_read),
+                        existing_groups: taken_existing,
+                        old_values: taken_old_values,
+                    };
+                    *self = next_state;
+                }
+                AggregateEvalState::RecomputeMinMax {
+                    delta,
+                    existing_groups,
+                    old_values,
+                    recompute_state,
+                } => {
+                    if operator.has_min_max() {
+                        // Process MIN/MAX recomputation - this will update existing_groups with correct MIN/MAX
+                        return_if_io!(recompute_state.process(existing_groups, operator, cursors));
+                    }
+
+                    // Now compute final output with updated MIN/MAX values
+                    let (output_delta, computed_states) =
+                        operator.merge_delta_with_existing(delta, existing_groups, old_values);
+
+                    *self = AggregateEvalState::Done {
+                        output: (output_delta, computed_states),
+                    };
+                }
+                AggregateEvalState::Done { output } => {
+                    return Ok(IOResult::Done(output.clone()));
+                }
+            }
+        }
+    }
+}
+
+impl AggregateState {
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    // Serialize the aggregate state to a binary blob, including the group key values.
+    // The reason we serialize it like this, instead of just writing the actual values, is
+    // that the same table may have different aggregators in the circuit, and they will all
+    // have different columns.
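+    //
+    // A worked example (hypothetical values): a group keyed by Text("US"), with
+    // aggregates [COUNT(*), SUM(price)], count = 3 and sum = 42.5, serializes as:
+    //
+    //   [1]                        version byte
+    //   [1, 0, 0, 0]               number of group key values (u32, little-endian)
+    //   [3][2, 0, 0, 0][b"US"]     tag 3 = Text, byte length, UTF-8 bytes
+    //   [3, 0, 0, 0, 0, 0, 0, 0]   count (i64, little-endian)
+    //   [8 bytes]                  SUM(price) running sum (f64, little-endian)
+    //
+    // COUNT writes nothing extra (the count is already in the header), and MIN/MAX
+    // values are prefixed with a presence byte before their serial-type encoding.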
+ fn to_blob(&self, aggregates: &[AggregateFunction], group_key: &[Value]) -> Vec { + let mut blob = Vec::new(); + + // Write version byte for future compatibility + blob.push(1u8); + + // Write number of group key values + blob.extend_from_slice(&(group_key.len() as u32).to_le_bytes()); + + // Write each group key value + for value in group_key { + // Write value type tag + match value { + Value::Null => blob.push(0u8), + Value::Integer(i) => { + blob.push(1u8); + blob.extend_from_slice(&i.to_le_bytes()); + } + Value::Float(f) => { + blob.push(2u8); + blob.extend_from_slice(&f.to_le_bytes()); + } + Value::Text(s) => { + blob.push(3u8); + let text_str = s.as_str(); + let bytes = text_str.as_bytes(); + blob.extend_from_slice(&(bytes.len() as u32).to_le_bytes()); + blob.extend_from_slice(bytes); + } + Value::Blob(b) => { + blob.push(4u8); + blob.extend_from_slice(&(b.len() as u32).to_le_bytes()); + blob.extend_from_slice(b); + } + } + } + + // Write count as 8 bytes (little-endian) + blob.extend_from_slice(&self.count.to_le_bytes()); + + // Write each aggregate's state + for agg in aggregates { + match agg { + AggregateFunction::Sum(col_name) => { + let sum = self.sums.get(col_name).copied().unwrap_or(0.0); + blob.extend_from_slice(&sum.to_le_bytes()); + } + AggregateFunction::Avg(col_name) => { + let (sum, count) = self.avgs.get(col_name).copied().unwrap_or((0.0, 0)); + blob.extend_from_slice(&sum.to_le_bytes()); + blob.extend_from_slice(&count.to_le_bytes()); + } + AggregateFunction::Count => { + // Count is already written above + } + AggregateFunction::Min(col_name) => { + // Write whether we have a MIN value (1 byte) + if let Some(min_val) = self.mins.get(col_name) { + blob.push(1u8); // Has value + serialize_value(min_val, &mut blob); + } else { + blob.push(0u8); // No value + } + } + AggregateFunction::Max(col_name) => { + // Write whether we have a MAX value (1 byte) + if let Some(max_val) = self.maxs.get(col_name) { + blob.push(1u8); // Has value + serialize_value(max_val, &mut blob); + } else { + blob.push(0u8); // No value + } + } + } + } + + blob + } + + /// Deserialize aggregate state from a binary blob + /// Returns the aggregate state and the group key values + pub fn from_blob(blob: &[u8], aggregates: &[AggregateFunction]) -> Option<(Self, Vec)> { + let mut cursor = 0; + + // Check version byte + if blob.get(cursor) != Some(&1u8) { + return None; + } + cursor += 1; + + // Read number of group key values + let num_group_keys = + u32::from_le_bytes(blob.get(cursor..cursor + 4)?.try_into().ok()?) as usize; + cursor += 4; + + // Read group key values + let mut group_key = Vec::new(); + for _ in 0..num_group_keys { + let value_type = *blob.get(cursor)?; + cursor += 1; + + let value = match value_type { + 0 => Value::Null, + 1 => { + let i = i64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); + cursor += 8; + Value::Integer(i) + } + 2 => { + let f = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); + cursor += 8; + Value::Float(f) + } + 3 => { + let len = + u32::from_le_bytes(blob.get(cursor..cursor + 4)?.try_into().ok()?) as usize; + cursor += 4; + let bytes = blob.get(cursor..cursor + len)?; + cursor += len; + let text_str = std::str::from_utf8(bytes).ok()?; + Value::Text(text_str.to_string().into()) + } + 4 => { + let len = + u32::from_le_bytes(blob.get(cursor..cursor + 4)?.try_into().ok()?) 
as usize; + cursor += 4; + let bytes = blob.get(cursor..cursor + len)?; + cursor += len; + Value::Blob(bytes.to_vec()) + } + _ => return None, + }; + group_key.push(value); + } + + // Read count + let count = i64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); + cursor += 8; + + let mut state = Self::new(); + state.count = count; + + // Read each aggregate's state + for agg in aggregates { + match agg { + AggregateFunction::Sum(col_name) => { + let sum = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); + cursor += 8; + state.sums.insert(col_name.clone(), sum); + } + AggregateFunction::Avg(col_name) => { + let sum = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); + cursor += 8; + let count = i64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); + cursor += 8; + state.avgs.insert(col_name.clone(), (sum, count)); + } + AggregateFunction::Count => { + // Count was already read above + } + AggregateFunction::Min(col_name) => { + // Read whether we have a MIN value + let has_value = *blob.get(cursor)?; + cursor += 1; + + if has_value == 1 { + let (min_value, bytes_consumed) = deserialize_value(&blob[cursor..])?; + cursor += bytes_consumed; + state.mins.insert(col_name.clone(), min_value); + } + } + AggregateFunction::Max(col_name) => { + // Read whether we have a MAX value + let has_value = *blob.get(cursor)?; + cursor += 1; + + if has_value == 1 { + let (max_value, bytes_consumed) = deserialize_value(&blob[cursor..])?; + cursor += bytes_consumed; + state.maxs.insert(col_name.clone(), max_value); + } + } + } + } + + Some((state, group_key)) + } + + /// Apply a delta to this aggregate state + fn apply_delta( + &mut self, + values: &[Value], + weight: isize, + aggregates: &[AggregateFunction], + column_names: &[String], + ) { + // Update COUNT + self.count += weight as i64; + + // Update other aggregates + for agg in aggregates { + match agg { + AggregateFunction::Count => { + // Already handled above + } + AggregateFunction::Sum(col_name) => { + if let Some(idx) = column_names.iter().position(|c| c == col_name) { + if let Some(val) = values.get(idx) { + let num_val = match val { + Value::Integer(i) => *i as f64, + Value::Float(f) => *f, + _ => 0.0, + }; + *self.sums.entry(col_name.clone()).or_insert(0.0) += + num_val * weight as f64; + } + } + } + AggregateFunction::Avg(col_name) => { + if let Some(idx) = column_names.iter().position(|c| c == col_name) { + if let Some(val) = values.get(idx) { + let num_val = match val { + Value::Integer(i) => *i as f64, + Value::Float(f) => *f, + _ => 0.0, + }; + let (sum, count) = + self.avgs.entry(col_name.clone()).or_insert((0.0, 0)); + *sum += num_val * weight as f64; + *count += weight as i64; + } + } + } + AggregateFunction::Min(_col_name) | AggregateFunction::Max(_col_name) => { + // MIN/MAX cannot be handled incrementally in apply_delta because: + // + // 1. For insertions: We can't just keep the minimum/maximum value. + // We need to track ALL values to handle future deletions correctly. + // + // 2. For deletions (retractions): If we delete the current MIN/MAX, + // we need to find the next best value, which requires knowing all + // other values in the group. 
+ // + // Example: Consider MIN(price) with values [10, 20, 30] + // - Current MIN = 10 + // - Delete 10 (weight = -1) + // - New MIN should be 20, but we can't determine this without + // having tracked all values [20, 30] + // + // Therefore, MIN/MAX processing is handled separately: + // - All input values are persisted to the index via persist_min_max() + // - When aggregates have MIN/MAX, we unconditionally transition to + // the RecomputeMinMax state machine (see EvalState::RecomputeMinMax) + // - RecomputeMinMax checks if the current MIN/MAX was deleted, and if so, + // scans the index to find the new MIN/MAX from remaining values + // + // This ensures correctness for incremental computation at the cost of + // additional I/O for MIN/MAX operations. + } + } + } + } + + /// Convert aggregate state to output values + pub fn to_values(&self, aggregates: &[AggregateFunction]) -> Vec { + let mut result = Vec::new(); + + for agg in aggregates { + match agg { + AggregateFunction::Count => { + result.push(Value::Integer(self.count)); + } + AggregateFunction::Sum(col_name) => { + let sum = self.sums.get(col_name).copied().unwrap_or(0.0); + // Return as integer if it's a whole number, otherwise as float + if sum.fract() == 0.0 { + result.push(Value::Integer(sum as i64)); + } else { + result.push(Value::Float(sum)); + } + } + AggregateFunction::Avg(col_name) => { + if let Some((sum, count)) = self.avgs.get(col_name) { + if *count > 0 { + result.push(Value::Float(sum / *count as f64)); + } else { + result.push(Value::Null); + } + } else { + result.push(Value::Null); + } + } + AggregateFunction::Min(col_name) => { + // Return the MIN value from our state + result.push(self.mins.get(col_name).cloned().unwrap_or(Value::Null)); + } + AggregateFunction::Max(col_name) => { + // Return the MAX value from our state + result.push(self.maxs.get(col_name).cloned().unwrap_or(Value::Null)); + } + } + } + + result + } +} + +impl AggregateOperator { + pub fn new( + operator_id: usize, + group_by: Vec, + aggregates: Vec, + input_column_names: Vec, + ) -> Self { + // Build map of column names to their MIN/MAX info with indices + let mut column_min_max = HashMap::new(); + let mut column_indices = HashMap::new(); + let mut current_index = 0; + + // First pass: assign indices to unique MIN/MAX columns + for agg in &aggregates { + match agg { + AggregateFunction::Min(col) | AggregateFunction::Max(col) => { + column_indices.entry(col.clone()).or_insert_with(|| { + let idx = current_index; + current_index += 1; + idx + }); + } + _ => {} + } + } + + // Second pass: build the column info map + for agg in &aggregates { + match agg { + AggregateFunction::Min(col) => { + let index = *column_indices.get(col).unwrap(); + let entry = column_min_max.entry(col.clone()).or_insert(AggColumnInfo { + index, + has_min: false, + has_max: false, + }); + entry.has_min = true; + } + AggregateFunction::Max(col) => { + let index = *column_indices.get(col).unwrap(); + let entry = column_min_max.entry(col.clone()).or_insert(AggColumnInfo { + index, + has_min: false, + has_max: false, + }); + entry.has_max = true; + } + _ => {} + } + } + + Self { + operator_id, + group_by, + aggregates, + input_column_names, + column_min_max, + tracker: None, + commit_state: AggregateCommitState::Idle, + } + } + + pub fn has_min_max(&self) -> bool { + !self.column_min_max.is_empty() + } + + fn eval_internal( + &mut self, + state: &mut EvalState, + cursors: &mut DbspStateCursors, + ) -> Result> { + match state { + EvalState::Uninitialized => { + 
panic!("Cannot eval AggregateOperator with Uninitialized state"); + } + EvalState::Init { deltas } => { + // Aggregate operators only use left_delta, right_delta must be empty + assert!( + deltas.right.is_empty(), + "AggregateOperator expects right_delta to be empty" + ); + + if deltas.left.changes.is_empty() { + *state = EvalState::Done; + return Ok(IOResult::Done((Delta::new(), HashMap::new()))); + } + + let mut groups_to_read = BTreeMap::new(); + for (row, _weight) in &deltas.left.changes { + let group_key = self.extract_group_key(&row.values); + let group_key_str = Self::group_key_to_string(&group_key); + groups_to_read.insert(group_key_str, group_key); + } + + let delta = std::mem::take(&mut deltas.left); + *state = EvalState::Aggregate(Box::new(AggregateEvalState::FetchKey { + delta, + current_idx: 0, + groups_to_read: groups_to_read.into_iter().collect(), + existing_groups: HashMap::new(), + old_values: HashMap::new(), + })); + } + EvalState::Aggregate(_agg_state) => { + // Already in progress, continue processing below. + } + EvalState::Done => { + panic!("unreachable state! should have returned"); + } + EvalState::Join(_) => { + panic!("Join state should not appear in aggregate operator"); + } + } + + // Process the delta through the aggregate state machine + match state { + EvalState::Aggregate(agg_state) => { + let result = return_if_io!(agg_state.process_delta(self, cursors)); + Ok(IOResult::Done(result)) + } + _ => panic!("Invalid state for aggregate processing"), + } + } + + fn merge_delta_with_existing( + &mut self, + delta: &Delta, + existing_groups: &mut HashMap, + old_values: &mut HashMap>, + ) -> (Delta, HashMap, AggregateState)>) { + let mut output_delta = Delta::new(); + let mut temp_keys: HashMap> = HashMap::new(); + + // Process each change in the delta + for (row, weight) in &delta.changes { + if let Some(tracker) = &self.tracker { + tracker.lock().unwrap().record_aggregation(); + } + + // Extract group key + let group_key = self.extract_group_key(&row.values); + let group_key_str = Self::group_key_to_string(&group_key); + + let state = existing_groups.entry(group_key_str.clone()).or_default(); + + temp_keys.insert(group_key_str.clone(), group_key.clone()); + + // Apply the delta to the temporary state + state.apply_delta( + &row.values, + *weight, + &self.aggregates, + &self.input_column_names, + ); + } + + // Generate output delta from temporary states and collect final states + let mut final_states = HashMap::new(); + + for (group_key_str, state) in existing_groups { + let group_key = temp_keys.get(group_key_str).cloned().unwrap_or_default(); + + // Generate a unique rowid for this group + let result_key = self.generate_group_rowid(group_key_str); + + if let Some(old_row_values) = old_values.get(group_key_str) { + let old_row = HashableRow::new(result_key, old_row_values.clone()); + output_delta.changes.push((old_row, -1)); + } + + // Always store the state for persistence (even if count=0, we need to delete it) + final_states.insert(group_key_str.clone(), (group_key.clone(), state.clone())); + + // Only include groups with count > 0 in the output delta + if state.count > 0 { + // Build output row: group_by columns + aggregate values + let mut output_values = group_key.clone(); + let aggregate_values = state.to_values(&self.aggregates); + output_values.extend(aggregate_values); + + let output_row = HashableRow::new(result_key, output_values.clone()); + output_delta.changes.push((output_row, 1)); + } + } + (output_delta, final_states) + } + + /// Extract MIN/MAX 
values from delta changes for persistence to the index
+    fn extract_min_max_deltas(&self, delta: &Delta) -> MinMaxDeltas {
+        let mut min_max_deltas: MinMaxDeltas = HashMap::new();
+
+        for (row, weight) in &delta.changes {
+            let group_key = self.extract_group_key(&row.values);
+            let group_key_str = Self::group_key_to_string(&group_key);
+
+            for agg in &self.aggregates {
+                match agg {
+                    AggregateFunction::Min(col_name) | AggregateFunction::Max(col_name) => {
+                        if let Some(idx) =
+                            self.input_column_names.iter().position(|c| c == col_name)
+                        {
+                            if let Some(val) = row.values.get(idx) {
+                                // Skip NULL values - they don't participate in MIN/MAX
+                                if val == &Value::Null {
+                                    continue;
+                                }
+                                // Create a HashableRow with just this value
+                                // Use 0 as rowid since we only care about the value for comparison
+                                let hashable_value = HashableRow::new(0, vec![val.clone()]);
+                                let key = (col_name.clone(), hashable_value);
+
+                                let group_entry =
+                                    min_max_deltas.entry(group_key_str.clone()).or_default();
+
+                                let value_entry = group_entry.entry(key).or_insert(0);
+
+                                // Accumulate the weight
+                                *value_entry += weight;
+                            }
+                        }
+                    }
+                    _ => {} // Ignore non-MIN/MAX aggregates
+                }
+            }
+        }
+
+        min_max_deltas
+    }
+
+    pub fn set_tracker(&mut self, tracker: Arc<Mutex<ComputationTracker>>) {
+        self.tracker = Some(tracker);
+    }
+
+    /// Generate a rowid for a group
+    /// For no GROUP BY: always returns 0
+    /// For GROUP BY: returns a hash of the group key string
+    pub fn generate_group_rowid(&self, group_key_str: &str) -> i64 {
+        if self.group_by.is_empty() {
+            0
+        } else {
+            group_key_str
+                .bytes()
+                .fold(0i64, |acc, b| acc.wrapping_mul(31).wrapping_add(b as i64))
+        }
+    }
+
+    /// Extract group key values from a row
+    pub fn extract_group_key(&self, values: &[Value]) -> Vec<Value> {
+        let mut key = Vec::new();
+
+        for group_col in &self.group_by {
+            if let Some(idx) = self.input_column_names.iter().position(|c| c == group_col) {
+                if let Some(val) = values.get(idx) {
+                    key.push(val.clone());
+                } else {
+                    key.push(Value::Null);
+                }
+            } else {
+                key.push(Value::Null);
+            }
+        }
+
+        key
+    }
+
+    /// Convert group key to string for indexing (since Value doesn't implement Hash)
+    pub fn group_key_to_string(key: &[Value]) -> String {
+        key.iter()
+            .map(|v| format!("{v:?}"))
+            .collect::<Vec<_>>()
+            .join(",")
+    }
+}
+
+impl IncrementalOperator for AggregateOperator {
+    fn eval(
+        &mut self,
+        state: &mut EvalState,
+        cursors: &mut DbspStateCursors,
+    ) -> Result<IOResult<Delta>> {
+        let (delta, _) = return_if_io!(self.eval_internal(state, cursors));
+        Ok(IOResult::Done(delta))
+    }
+
+    fn commit(
+        &mut self,
+        mut deltas: DeltaPair,
+        cursors: &mut DbspStateCursors,
+    ) -> Result<IOResult<Delta>> {
+        // Aggregate operator only uses left delta, right must be empty
+        assert!(
+            deltas.right.is_empty(),
+            "AggregateOperator expects right delta to be empty in commit"
+        );
+        let delta = std::mem::take(&mut deltas.left);
+        loop {
+            // Note: we std::mem::replace the state out here (without it, the borrow
+            // checker rejects the call to self.eval_internal, which needs a mutable
+            // borrow of self), so we must restore the state before returning on I/O.
+            // That is why we can't use return_if_io! directly.
+            let mut state =
+                std::mem::replace(&mut self.commit_state, AggregateCommitState::Invalid);
+            match &mut state {
+                AggregateCommitState::Invalid => {
+                    panic!("Reached invalid state! 
State was replaced, and not replaced back"); + } + AggregateCommitState::Idle => { + let eval_state = EvalState::from_delta(delta.clone()); + self.commit_state = AggregateCommitState::Eval { eval_state }; + } + AggregateCommitState::Eval { ref mut eval_state } => { + // Extract input delta before eval for MIN/MAX processing + let input_delta = eval_state.extract_delta(); + + // Extract MIN/MAX deltas before any I/O operations + let min_max_deltas = self.extract_min_max_deltas(&input_delta); + + // Create a new eval state with the same delta + *eval_state = EvalState::from_delta(input_delta.clone()); + + let (output_delta, computed_states) = return_and_restore_if_io!( + &mut self.commit_state, + state, + self.eval_internal(eval_state, cursors) + ); + + self.commit_state = AggregateCommitState::PersistDelta { + delta: output_delta, + computed_states, + current_idx: 0, + write_row: WriteRow::new(), + min_max_deltas, // Store for later use + }; + } + AggregateCommitState::PersistDelta { + delta, + computed_states, + current_idx, + write_row, + min_max_deltas, + } => { + let states_vec: Vec<_> = computed_states.iter().collect(); + + if *current_idx >= states_vec.len() { + // Use the min_max_deltas we extracted earlier from the input delta + self.commit_state = AggregateCommitState::PersistMinMax { + delta: delta.clone(), + min_max_persist_state: MinMaxPersistState::new(min_max_deltas.clone()), + }; + } else { + let (group_key_str, (group_key, agg_state)) = states_vec[*current_idx]; + + // Build the key components for the new table structure + // For regular aggregates, use column_index=0 and AGG_TYPE_REGULAR + let operator_storage_id = + generate_storage_id(self.operator_id, 0, AGG_TYPE_REGULAR); + let zset_id = self.generate_group_rowid(group_key_str); + let element_id = 0i64; + + // Determine weight: -1 to delete (cancels existing weight=1), 1 to insert/update + let weight = if agg_state.count == 0 { -1 } else { 1 }; + + // Serialize the aggregate state with group key (even for deletion, we need a row) + let state_blob = agg_state.to_blob(&self.aggregates, group_key); + let blob_value = Value::Blob(state_blob); + + // Build the aggregate storage format: [operator_id, zset_id, element_id, value, weight] + let operator_id_val = Value::Integer(operator_storage_id); + let zset_id_val = Value::Integer(zset_id); + let element_id_val = Value::Integer(element_id); + let blob_val = blob_value.clone(); + + // Create index key - the first 3 columns of our primary key + let index_key = vec![ + operator_id_val.clone(), + zset_id_val.clone(), + element_id_val.clone(), + ]; + + // Record values (without weight) + let record_values = + vec![operator_id_val, zset_id_val, element_id_val, blob_val]; + + return_and_restore_if_io!( + &mut self.commit_state, + state, + write_row.write_row(cursors, index_key, record_values, weight) + ); + + let delta = std::mem::take(delta); + let computed_states = std::mem::take(computed_states); + let min_max_deltas = std::mem::take(min_max_deltas); + + self.commit_state = AggregateCommitState::PersistDelta { + delta, + computed_states, + current_idx: *current_idx + 1, + write_row: WriteRow::new(), // Reset for next write + min_max_deltas, + }; + } + } + AggregateCommitState::PersistMinMax { + delta, + min_max_persist_state, + } => { + if !self.has_min_max() { + let delta = std::mem::take(delta); + self.commit_state = AggregateCommitState::Done { delta }; + } else { + return_and_restore_if_io!( + &mut self.commit_state, + state, + min_max_persist_state.persist_min_max( + 
self.operator_id,
+                                &self.column_min_max,
+                                cursors,
+                                |group_key_str| self.generate_group_rowid(group_key_str)
+                            )
+                        );
+
+                        let delta = std::mem::take(delta);
+                        self.commit_state = AggregateCommitState::Done { delta };
+                    }
+                }
+                AggregateCommitState::Done { delta } => {
+                    self.commit_state = AggregateCommitState::Idle;
+                    let delta = std::mem::take(delta);
+                    return Ok(IOResult::Done(delta));
+                }
+            }
+        }
+    }
+
+    fn set_tracker(&mut self, tracker: Arc<Mutex<ComputationTracker>>) {
+        self.tracker = Some(tracker);
+    }
+}
+
+/// State machine for recomputing MIN/MAX values after deletion
+#[derive(Debug)]
+pub enum RecomputeMinMax {
+    ProcessElements {
+        /// Current column being processed
+        current_column_idx: usize,
+        /// Columns to process (combined MIN and MAX)
+        columns_to_process: Vec<(String, String, bool)>, // (group_key, column_name, is_min)
+        /// MIN/MAX deltas for checking values and weights
+        min_max_deltas: MinMaxDeltas,
+    },
+    Scan {
+        /// Columns still to process
+        columns_to_process: Vec<(String, String, bool)>,
+        /// Current index in columns_to_process (will resume from here)
+        current_column_idx: usize,
+        /// MIN/MAX deltas for checking values and weights
+        min_max_deltas: MinMaxDeltas,
+        /// Current group key being processed
+        group_key: String,
+        /// Current column name being processed
+        column_name: String,
+        /// Whether we're looking for MIN (true) or MAX (false)
+        is_min: bool,
+        /// The scan state machine for finding the new MIN/MAX
+        scan_state: Box<ScanState>,
+    },
+    Done,
+}
+
+impl RecomputeMinMax {
+    pub fn new(
+        min_max_deltas: MinMaxDeltas,
+        existing_groups: &HashMap<String, AggregateState>,
+        operator: &AggregateOperator,
+    ) -> Self {
+        let mut groups_to_check: HashSet<(String, String, bool)> = HashSet::new();
+
+        // Remember: min_max_deltas holds, per group, only the columns affected by
+        // MIN/MAX, with their accumulated weights, in delta form (really a ZSet - a
+        // consolidated delta). That shape makes it easy to consume here.
+        //
+        // The most challenging case is a retraction, since then we need to go
+        // back to the index.
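+        //
+        // A worked example (hypothetical weights): suppose group "g" currently has
+        // MIN(price) = 10, and min_max_deltas carries ("price", 10) -> -1 and
+        // ("price", 15) -> +1. The retraction matches the stored MIN, so
+        // (g, "price", is_min=true) must be scheduled for a rescan below. The
+        // insertion of 15 alone could never lower an existing MIN, but insertions
+        // are still scheduled so that brand-new groups get initialized.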
+ for (group_key_str, values) in &min_max_deltas { + for ((col_name, hashable_row), weight) in values { + let col_info = operator.column_min_max.get(col_name); + + let value = &hashable_row.values[0]; + + if *weight < 0 { + // Deletion detected - check if it's the current MIN/MAX + if let Some(state) = existing_groups.get(group_key_str) { + // Check for MIN + if let Some(current_min) = state.mins.get(col_name) { + if current_min == value { + groups_to_check.insert(( + group_key_str.clone(), + col_name.clone(), + true, + )); + } + } + // Check for MAX + if let Some(current_max) = state.maxs.get(col_name) { + if current_max == value { + groups_to_check.insert(( + group_key_str.clone(), + col_name.clone(), + false, + )); + } + } + } + } else if *weight > 0 { + // If it is not found in the existing groups, then we only need to care + // about this if this is a new record being inserted + if let Some(info) = col_info { + if info.has_min { + groups_to_check.insert((group_key_str.clone(), col_name.clone(), true)); + } + if info.has_max { + groups_to_check.insert(( + group_key_str.clone(), + col_name.clone(), + false, + )); + } + } + } + } + } + + if groups_to_check.is_empty() { + // No recomputation or initialization needed + Self::Done + } else { + // Convert HashSet to Vec for indexed processing + let groups_to_check_vec: Vec<_> = groups_to_check.into_iter().collect(); + Self::ProcessElements { + current_column_idx: 0, + columns_to_process: groups_to_check_vec, + min_max_deltas, + } + } + } + + pub fn process( + &mut self, + existing_groups: &mut HashMap, + operator: &AggregateOperator, + cursors: &mut DbspStateCursors, + ) -> Result> { + loop { + match self { + RecomputeMinMax::ProcessElements { + current_column_idx, + columns_to_process, + min_max_deltas, + } => { + if *current_column_idx >= columns_to_process.len() { + *self = RecomputeMinMax::Done; + return Ok(IOResult::Done(())); + } + + let (group_key, column_name, is_min) = + columns_to_process[*current_column_idx].clone(); + + // Get column index from pre-computed info + let column_index = operator + .column_min_max + .get(&column_name) + .map(|info| info.index) + .unwrap(); // Should always exist since we're processing known columns + + // Get current value from existing state + let current_value = existing_groups.get(&group_key).and_then(|state| { + if is_min { + state.mins.get(&column_name).cloned() + } else { + state.maxs.get(&column_name).cloned() + } + }); + + // Create storage keys for index lookup + let storage_id = + generate_storage_id(operator.operator_id, column_index, AGG_TYPE_MINMAX); + let zset_id = operator.generate_group_rowid(&group_key); + + // Get the values for this group from min_max_deltas + let group_values = min_max_deltas.get(&group_key).cloned().unwrap_or_default(); + + let columns_to_process = std::mem::take(columns_to_process); + let min_max_deltas = std::mem::take(min_max_deltas); + + let scan_state = if is_min { + Box::new(ScanState::new_for_min( + current_value, + group_key.clone(), + column_name.clone(), + storage_id, + zset_id, + group_values, + )) + } else { + Box::new(ScanState::new_for_max( + current_value, + group_key.clone(), + column_name.clone(), + storage_id, + zset_id, + group_values, + )) + }; + + *self = RecomputeMinMax::Scan { + columns_to_process, + current_column_idx: *current_column_idx, + min_max_deltas, + group_key, + column_name, + is_min, + scan_state, + }; + } + RecomputeMinMax::Scan { + columns_to_process, + current_column_idx, + min_max_deltas, + group_key, + column_name, + is_min, 
+ scan_state, + } => { + // Find new value using the scan state machine + let new_value = return_if_io!(scan_state.find_new_value(cursors)); + + // Update the state with new value (create if doesn't exist) + let state = existing_groups.entry(group_key.clone()).or_default(); + + if *is_min { + if let Some(min_val) = new_value { + state.mins.insert(column_name.clone(), min_val); + } else { + state.mins.remove(column_name); + } + } else if let Some(max_val) = new_value { + state.maxs.insert(column_name.clone(), max_val); + } else { + state.maxs.remove(column_name); + } + + // Move to next column + let min_max_deltas = std::mem::take(min_max_deltas); + let columns_to_process = std::mem::take(columns_to_process); + *self = RecomputeMinMax::ProcessElements { + current_column_idx: *current_column_idx + 1, + columns_to_process, + min_max_deltas, + }; + } + RecomputeMinMax::Done => { + return Ok(IOResult::Done(())); + } + } + } + } +} + +/// State machine for scanning through the index to find new MIN/MAX values +#[derive(Debug)] +pub enum ScanState { + CheckCandidate { + /// Current candidate value for MIN/MAX + candidate: Option, + /// Group key being processed + group_key: String, + /// Column name being processed + column_name: String, + /// Storage ID for the index seek + storage_id: i64, + /// ZSet ID for the group + zset_id: i64, + /// Group values from MinMaxDeltas: (column_name, HashableRow) -> weight + group_values: HashMap<(String, HashableRow), isize>, + /// Whether we're looking for MIN (true) or MAX (false) + is_min: bool, + }, + FetchNextCandidate { + /// Current candidate to seek past + current_candidate: Value, + /// Group key being processed + group_key: String, + /// Column name being processed + column_name: String, + /// Storage ID for the index seek + storage_id: i64, + /// ZSet ID for the group + zset_id: i64, + /// Group values from MinMaxDeltas: (column_name, HashableRow) -> weight + group_values: HashMap<(String, HashableRow), isize>, + /// Whether we're looking for MIN (true) or MAX (false) + is_min: bool, + }, + Done { + /// The final MIN/MAX value found + result: Option, + }, +} + +impl ScanState { + pub fn new_for_min( + current_min: Option, + group_key: String, + column_name: String, + storage_id: i64, + zset_id: i64, + group_values: HashMap<(String, HashableRow), isize>, + ) -> Self { + Self::CheckCandidate { + candidate: current_min, + group_key, + column_name, + storage_id, + zset_id, + group_values, + is_min: true, + } + } + + // Extract a new candidate from the index. It is possible that, when searching, + // we end up going into a different operator altogether. 
That means we have + // exhausted this operator (or group) entirely, and no good candidate was found + fn extract_new_candidate( + cursors: &mut DbspStateCursors, + index_record: &ImmutableRecord, + seek_op: SeekOp, + storage_id: i64, + zset_id: i64, + ) -> Result>> { + let seek_result = return_if_io!(cursors + .index_cursor + .seek(SeekKey::IndexKey(index_record), seek_op)); + if !matches!(seek_result, SeekResult::Found) { + return Ok(IOResult::Done(None)); + } + + let record = return_if_io!(cursors.index_cursor.record()).ok_or_else(|| { + LimboError::InternalError( + "Record found on the cursor, but could not be read".to_string(), + ) + })?; + + let values = record.get_values(); + if values.len() < 3 { + return Ok(IOResult::Done(None)); + } + + let Some(rec_storage_id) = values.first() else { + return Ok(IOResult::Done(None)); + }; + let Some(rec_zset_id) = values.get(1) else { + return Ok(IOResult::Done(None)); + }; + + // Check if we're still in the same group + if let (RefValue::Integer(rec_sid), RefValue::Integer(rec_zid)) = + (rec_storage_id, rec_zset_id) + { + if *rec_sid != storage_id || *rec_zid != zset_id { + return Ok(IOResult::Done(None)); + } + } else { + return Ok(IOResult::Done(None)); + } + + // Get the value (3rd element) + Ok(IOResult::Done(values.get(2).map(|v| v.to_owned()))) + } + + pub fn new_for_max( + current_max: Option, + group_key: String, + column_name: String, + storage_id: i64, + zset_id: i64, + group_values: HashMap<(String, HashableRow), isize>, + ) -> Self { + Self::CheckCandidate { + candidate: current_max, + group_key, + column_name, + storage_id, + zset_id, + group_values, + is_min: false, + } + } + + pub fn find_new_value( + &mut self, + cursors: &mut DbspStateCursors, + ) -> Result>> { + loop { + match self { + ScanState::CheckCandidate { + candidate, + group_key, + column_name, + storage_id, + zset_id, + group_values, + is_min, + } => { + // First, check if we have a candidate + if let Some(cand_val) = candidate { + // Check if the candidate is retracted (weight <= 0) + // Create a HashableRow to look up the weight + let hashable_cand = HashableRow::new(0, vec![cand_val.clone()]); + let key = (column_name.clone(), hashable_cand); + let is_retracted = + group_values.get(&key).is_some_and(|weight| *weight <= 0); + + if is_retracted { + // Candidate is retracted, need to fetch next from index + *self = ScanState::FetchNextCandidate { + current_candidate: cand_val.clone(), + group_key: std::mem::take(group_key), + column_name: std::mem::take(column_name), + storage_id: *storage_id, + zset_id: *zset_id, + group_values: std::mem::take(group_values), + is_min: *is_min, + }; + continue; + } + } + + // Candidate is valid or we have no candidate + // Now find the best value from insertions in group_values + let mut best_from_zset = None; + for ((col, hashable_val), weight) in group_values.iter() { + if col == column_name && *weight > 0 { + let value = &hashable_val.values[0]; + // Skip NULL values - they don't participate in MIN/MAX + if value == &Value::Null { + continue; + } + // This is an insertion for our column + if let Some(ref current_best) = best_from_zset { + if *is_min { + if value.cmp(current_best) == std::cmp::Ordering::Less { + best_from_zset = Some(value.clone()); + } + } else if value.cmp(current_best) == std::cmp::Ordering::Greater { + best_from_zset = Some(value.clone()); + } + } else { + best_from_zset = Some(value.clone()); + } + } + } + + // Compare candidate with best from ZSet, filtering out NULLs + let result = match (&candidate, 
&best_from_zset) { + (Some(cand), Some(zset_val)) if cand != &Value::Null => { + if *is_min { + if zset_val.cmp(cand) == std::cmp::Ordering::Less { + Some(zset_val.clone()) + } else { + Some(cand.clone()) + } + } else if zset_val.cmp(cand) == std::cmp::Ordering::Greater { + Some(zset_val.clone()) + } else { + Some(cand.clone()) + } + } + (Some(cand), None) if cand != &Value::Null => Some(cand.clone()), + (None, Some(zset_val)) => Some(zset_val.clone()), + (Some(cand), Some(_)) if cand == &Value::Null => best_from_zset, + _ => None, + }; + + *self = ScanState::Done { result }; + } + + ScanState::FetchNextCandidate { + current_candidate, + group_key, + column_name, + storage_id, + zset_id, + group_values, + is_min, + } => { + // Seek to the next value in the index + let index_key = vec![ + Value::Integer(*storage_id), + Value::Integer(*zset_id), + current_candidate.clone(), + ]; + let index_record = ImmutableRecord::from_values(&index_key, index_key.len()); + + let seek_op = if *is_min { + SeekOp::GT // For MIN, seek greater than current + } else { + SeekOp::LT // For MAX, seek less than current + }; + + let new_candidate = return_if_io!(Self::extract_new_candidate( + cursors, + &index_record, + seek_op, + *storage_id, + *zset_id + )); + + *self = ScanState::CheckCandidate { + candidate: new_candidate, + group_key: std::mem::take(group_key), + column_name: std::mem::take(column_name), + storage_id: *storage_id, + zset_id: *zset_id, + group_values: std::mem::take(group_values), + is_min: *is_min, + }; + } + + ScanState::Done { result } => { + return Ok(IOResult::Done(result.clone())); + } + } + } + } +} + +/// State machine for persisting Min/Max values to storage +#[derive(Debug)] +pub enum MinMaxPersistState { + Init { + min_max_deltas: MinMaxDeltas, + group_keys: Vec, + }, + ProcessGroup { + min_max_deltas: MinMaxDeltas, + group_keys: Vec, + group_idx: usize, + value_idx: usize, + }, + WriteValue { + min_max_deltas: MinMaxDeltas, + group_keys: Vec, + group_idx: usize, + value_idx: usize, + value: Value, + column_name: String, + weight: isize, + write_row: WriteRow, + }, + Done, +} + +impl MinMaxPersistState { + pub fn new(min_max_deltas: MinMaxDeltas) -> Self { + let group_keys: Vec = min_max_deltas.keys().cloned().collect(); + Self::Init { + min_max_deltas, + group_keys, + } + } + + pub fn persist_min_max( + &mut self, + operator_id: usize, + column_min_max: &HashMap, + cursors: &mut DbspStateCursors, + generate_group_rowid: impl Fn(&str) -> i64, + ) -> Result> { + loop { + match self { + MinMaxPersistState::Init { + min_max_deltas, + group_keys, + } => { + let min_max_deltas = std::mem::take(min_max_deltas); + let group_keys = std::mem::take(group_keys); + *self = MinMaxPersistState::ProcessGroup { + min_max_deltas, + group_keys, + group_idx: 0, + value_idx: 0, + }; + } + MinMaxPersistState::ProcessGroup { + min_max_deltas, + group_keys, + group_idx, + value_idx, + } => { + // Check if we're past all groups + if *group_idx >= group_keys.len() { + *self = MinMaxPersistState::Done; + continue; + } + + let group_key_str = &group_keys[*group_idx]; + let values = &min_max_deltas[group_key_str]; // This should always exist + + // Convert HashMap to Vec for indexed access + let values_vec: Vec<_> = values.iter().collect(); + + // Check if we have more values in current group + if *value_idx >= values_vec.len() { + *group_idx += 1; + *value_idx = 0; + // Continue to check if we're past all groups now + continue; + } + + // Process current value and extract what we need before taking ownership + 
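// (The value extracted here becomes element_id in the index key written by
+                    // WriteValue below: since the index orders by (storage_id, zset_id, value),
+                    // the smallest surviving value in a group is its first entry and the
+                    // largest is its last - that BTree ordering is why one AGG_TYPE_MINMAX
+                    // entry kind serves both MIN and MAX.)
+                    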
let ((column_name, hashable_row), weight) = values_vec[*value_idx]; + let column_name = column_name.clone(); + let value = hashable_row.values[0].clone(); // Extract the Value from HashableRow + let weight = *weight; + + let min_max_deltas = std::mem::take(min_max_deltas); + let group_keys = std::mem::take(group_keys); + *self = MinMaxPersistState::WriteValue { + min_max_deltas, + group_keys, + group_idx: *group_idx, + value_idx: *value_idx, + column_name, + value, + weight, + write_row: WriteRow::new(), + }; + } + MinMaxPersistState::WriteValue { + min_max_deltas, + group_keys, + group_idx, + value_idx, + value, + column_name, + weight, + write_row, + } => { + // Should have exited in the previous state + assert!(*group_idx < group_keys.len()); + + let group_key_str = &group_keys[*group_idx]; + + // Get the column index from the pre-computed map + let column_info = column_min_max + .get(&*column_name) + .expect("Column should exist in column_min_max map"); + let column_index = column_info.index; + + // Build the key components for MinMax storage using new encoding + let storage_id = + generate_storage_id(operator_id, column_index, AGG_TYPE_MINMAX); + let zset_id = generate_group_rowid(group_key_str); + + // element_id is the actual value for Min/Max + let element_id_val = value.clone(); + + // Create index key + let index_key = vec![ + Value::Integer(storage_id), + Value::Integer(zset_id), + element_id_val.clone(), + ]; + + // Record values (operator_id, zset_id, element_id, unused_placeholder) + // For MIN/MAX, the element_id IS the value, so we use NULL for the 4th column + let record_values = vec![ + Value::Integer(storage_id), + Value::Integer(zset_id), + element_id_val.clone(), + Value::Null, // Placeholder - not used for MIN/MAX + ]; + + return_if_io!(write_row.write_row( + cursors, + index_key.clone(), + record_values, + *weight + )); + + // Move to next value + let min_max_deltas = std::mem::take(min_max_deltas); + let group_keys = std::mem::take(group_keys); + *self = MinMaxPersistState::ProcessGroup { + min_max_deltas, + group_keys, + group_idx: *group_idx, + value_idx: *value_idx + 1, + }; + } + MinMaxPersistState::Done => { + return Ok(IOResult::Done(())); + } + } + } + } +} diff --git a/core/incremental/mod.rs b/core/incremental/mod.rs index 0e45b3194..a747809d9 100644 --- a/core/incremental/mod.rs +++ b/core/incremental/mod.rs @@ -1,3 +1,4 @@ +pub mod aggregate_operator; pub mod compiler; pub mod cursor; pub mod dbsp; diff --git a/core/incremental/operator.rs b/core/incremental/operator.rs index 43ad8f67c..92b35d5f1 100644 --- a/core/incremental/operator.rs +++ b/core/incremental/operator.rs @@ -2,19 +2,20 @@ // Operator DAG for DBSP-style incremental computation // Based on Feldera DBSP design but adapted for Turso's architecture +pub use crate::incremental::aggregate_operator::{ + AggregateEvalState, AggregateFunction, AggregateOperator, AggregateState, +}; pub use crate::incremental::filter_operator::{FilterOperator, FilterPredicate}; pub use crate::incremental::input_operator::InputOperator; pub use crate::incremental::project_operator::{ProjectColumn, ProjectOperator}; -use crate::function::{AggFunc, Func}; use crate::incremental::dbsp::{Delta, DeltaPair, HashableRow}; -use crate::incremental::persistence::{MinMaxPersistState, ReadRecord, RecomputeMinMax, WriteRow}; +use crate::incremental::persistence::WriteRow; use crate::schema::{Index, IndexColumn}; use crate::storage::btree::BTreeCursor; use crate::types::{IOResult, ImmutableRecord, SeekKey, SeekOp, SeekResult}; 
use crate::{return_and_restore_if_io, return_if_io, Result, Value}; -use std::collections::{BTreeMap, HashMap}; -use std::fmt::{self, Debug, Display}; +use std::fmt::Debug; use std::sync::{Arc, Mutex}; /// Struct to hold both table and index cursors for DBSP state operations @@ -71,12 +72,6 @@ pub fn create_dbsp_state_index(root_page: usize) -> Index { } } -/// Constants for aggregate type encoding in storage IDs (2 bits) -pub const AGG_TYPE_REGULAR: u8 = 0b00; // COUNT/SUM/AVG -pub const AGG_TYPE_MINMAX: u8 = 0b01; // MIN/MAX (BTree ordering gives both) -pub const AGG_TYPE_RESERVED1: u8 = 0b10; // Reserved for future use -pub const AGG_TYPE_RESERVED2: u8 = 0b11; // Reserved for future use - /// Generate a storage ID with column index and operation type encoding /// Storage ID = (operator_id << 16) | (column_index << 2) | operation_type /// Bit layout (64-bit integer): @@ -90,64 +85,6 @@ pub fn generate_storage_id(operator_id: usize, column_index: usize, op_type: u8) ((operator_id as i64) << 16) | ((column_index as i64) << 2) | (op_type as i64) } -// group_key_str -> (group_key, state) -type ComputedStates = HashMap, AggregateState)>; -// group_key_str -> (column_name, value_as_hashable_row) -> accumulated_weight -pub type MinMaxDeltas = HashMap>; - -#[derive(Debug)] -enum AggregateCommitState { - Idle, - Eval { - eval_state: EvalState, - }, - PersistDelta { - delta: Delta, - computed_states: ComputedStates, - current_idx: usize, - write_row: WriteRow, - min_max_deltas: MinMaxDeltas, - }, - PersistMinMax { - delta: Delta, - min_max_persist_state: MinMaxPersistState, - }, - Done { - delta: Delta, - }, - Invalid, -} - -// Aggregate-specific eval states -#[derive(Debug)] -pub enum AggregateEvalState { - FetchKey { - delta: Delta, // Keep original delta for merge operation - current_idx: usize, - groups_to_read: Vec<(String, Vec)>, // Changed to Vec for index-based access - existing_groups: HashMap, - old_values: HashMap>, - }, - FetchData { - delta: Delta, // Keep original delta for merge operation - current_idx: usize, - groups_to_read: Vec<(String, Vec)>, // Changed to Vec for index-based access - existing_groups: HashMap, - old_values: HashMap>, - rowid: Option, // Rowid found by FetchKey (None if not found) - read_record_state: Box, - }, - RecomputeMinMax { - delta: Delta, - existing_groups: HashMap, - old_values: HashMap>, - recompute_state: Box, - }, - Done { - output: (Delta, ComputedStates), - }, -} - // Helper function to read the next row from the BTree for joins fn read_next_join_row( storage_id: i64, @@ -476,180 +413,6 @@ impl EvalState { _ => panic!("extract_delta() can only be called when in Init state"), } } - - fn advance_aggregate(&mut self, groups_to_read: BTreeMap>) { - let delta = match self { - EvalState::Init { deltas } => std::mem::take(&mut deltas.left), - _ => panic!("advance_aggregate() can only be called when in Init state, current state: {self:?}"), - }; - - let _ = std::mem::replace( - self, - EvalState::Aggregate(Box::new(AggregateEvalState::FetchKey { - delta, - current_idx: 0, - groups_to_read: groups_to_read.into_iter().collect(), // Convert BTreeMap to Vec - existing_groups: HashMap::new(), - old_values: HashMap::new(), - })), - ); - } -} - -impl AggregateEvalState { - fn process_delta( - &mut self, - operator: &mut AggregateOperator, - cursors: &mut DbspStateCursors, - ) -> Result> { - loop { - match self { - AggregateEvalState::FetchKey { - delta, - current_idx, - groups_to_read, - existing_groups, - old_values, - } => { - if *current_idx >= 
groups_to_read.len() {
-                    // All groups have been fetched, move to RecomputeMinMax
-                    // Extract MIN/MAX deltas from the input delta
-                    let min_max_deltas = operator.extract_min_max_deltas(delta);
-
-                    let recompute_state = Box::new(RecomputeMinMax::new(
-                        min_max_deltas,
-                        existing_groups,
-                        operator,
-                    ));
-
-                    *self = AggregateEvalState::RecomputeMinMax {
-                        delta: std::mem::take(delta),
-                        existing_groups: std::mem::take(existing_groups),
-                        old_values: std::mem::take(old_values),
-                        recompute_state,
-                    };
-                } else {
-                    // Get the current group to read
-                    let (group_key_str, _group_key) = &groups_to_read[*current_idx];
-
-                    // Build the key for the index: (operator_id, zset_id, element_id)
-                    // For regular aggregates, use column_index=0 and AGG_TYPE_REGULAR
-                    let operator_storage_id =
-                        generate_storage_id(operator.operator_id, 0, AGG_TYPE_REGULAR);
-                    let zset_id = operator.generate_group_rowid(group_key_str);
-                    let element_id = 0i64; // Always 0 for aggregators
-
-                    // Create index key values
-                    let index_key_values = vec![
-                        Value::Integer(operator_storage_id),
-                        Value::Integer(zset_id),
-                        Value::Integer(element_id),
-                    ];
-
-                    // Create an immutable record for the index key
-                    let index_record =
-                        ImmutableRecord::from_values(&index_key_values, index_key_values.len());
-
-                    // Seek in the index to find if this row exists
-                    let seek_result = return_if_io!(cursors.index_cursor.seek(
-                        SeekKey::IndexKey(&index_record),
-                        SeekOp::GE { eq_only: true }
-                    ));
-
-                    let rowid = if matches!(seek_result, SeekResult::Found) {
-                        // Found in index, get the table rowid
-                        // The btree code handles extracting the rowid from the index record for has_rowid indexes
-                        return_if_io!(cursors.index_cursor.rowid())
-                    } else {
-                        // Not found in index, no existing state
-                        None
-                    };
-
-                    // Always transition to FetchData
-                    let taken_existing = std::mem::take(existing_groups);
-                    let taken_old_values = std::mem::take(old_values);
-                    let next_state = AggregateEvalState::FetchData {
-                        delta: std::mem::take(delta),
-                        current_idx: *current_idx,
-                        groups_to_read: std::mem::take(groups_to_read),
-                        existing_groups: taken_existing,
-                        old_values: taken_old_values,
-                        rowid,
-                        read_record_state: Box::new(ReadRecord::new()),
-                    };
-                    *self = next_state;
-                }
-            }
-            AggregateEvalState::FetchData {
-                delta,
-                current_idx,
-                groups_to_read,
-                existing_groups,
-                old_values,
-                rowid,
-                read_record_state,
-            } => {
-                // Get the current group to read
-                let (group_key_str, group_key) = &groups_to_read[*current_idx];
-
-                // Only try to read if we have a rowid
-                if let Some(rowid) = rowid {
-                    let key = SeekKey::TableRowId(*rowid);
-                    let state = return_if_io!(read_record_state.read_record(
-                        key,
-                        &operator.aggregates,
-                        &mut cursors.table_cursor
-                    ));
-                    // Process the fetched state
-                    if let Some(state) = state {
-                        let mut old_row = group_key.clone();
-                        old_row.extend(state.to_values(&operator.aggregates));
-                        old_values.insert(group_key_str.clone(), old_row);
-                        existing_groups.insert(group_key_str.clone(), state.clone());
-                    }
-                } else {
-                    // No rowid for this group, skipping read
-                }
-                // If no rowid, there's no existing state for this group
-
-                // Move to next group
-                let next_idx = *current_idx + 1;
-                let taken_existing = std::mem::take(existing_groups);
-                let taken_old_values = std::mem::take(old_values);
-                let next_state = AggregateEvalState::FetchKey {
-                    delta: std::mem::take(delta),
-                    current_idx: next_idx,
-                    groups_to_read: std::mem::take(groups_to_read),
-                    existing_groups: taken_existing,
-                    old_values: taken_old_values,
-                };
-                *self = next_state;
-            }
-            AggregateEvalState::RecomputeMinMax {
-                delta,
-                existing_groups,
-                old_values,
-                recompute_state,
-            } => {
-                if operator.has_min_max() {
-                    // Process MIN/MAX recomputation - this will update existing_groups with correct MIN/MAX
-                    return_if_io!(recompute_state.process(existing_groups, operator, cursors));
-                }
-
-                // Now compute final output with updated MIN/MAX values
-                let (output_delta, computed_states) =
-                    operator.merge_delta_with_existing(delta, existing_groups, old_values);
-
-                *self = AggregateEvalState::Done {
-                    output: (output_delta, computed_states),
-                };
-            }
-            AggregateEvalState::Done { output } => {
-                return Ok(IOResult::Done(output.clone()));
-            }
-        }
-    }
-}
 
 /// Tracks computation counts to verify incremental behavior (for tests now), and in the future
@@ -800,56 +563,6 @@ pub enum JoinType {
     Cross,
 }
 
-#[derive(Debug, Clone, PartialEq)]
-pub enum AggregateFunction {
-    Count,
-    Sum(String),
-    Avg(String),
-    Min(String),
-    Max(String),
-}
-
-impl Display for AggregateFunction {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        match self {
-            AggregateFunction::Count => write!(f, "COUNT(*)"),
-            AggregateFunction::Sum(col) => write!(f, "SUM({col})"),
-            AggregateFunction::Avg(col) => write!(f, "AVG({col})"),
-            AggregateFunction::Min(col) => write!(f, "MIN({col})"),
-            AggregateFunction::Max(col) => write!(f, "MAX({col})"),
-        }
-    }
-}
-
-impl AggregateFunction {
-    /// Get the default output column name for this aggregate function
-    #[inline]
-    pub fn default_output_name(&self) -> String {
-        self.to_string()
-    }
-
-    /// Create an AggregateFunction from a SQL function and its arguments
-    /// Returns None if the function is not a supported aggregate
-    pub fn from_sql_function(
-        func: &crate::function::Func,
-        input_column: Option<String>,
-    ) -> Option<Self> {
-        match func {
-            Func::Agg(agg_func) => {
-                match agg_func {
-                    AggFunc::Count | AggFunc::Count0 => Some(AggregateFunction::Count),
-                    AggFunc::Sum => input_column.map(AggregateFunction::Sum),
-                    AggFunc::Avg => input_column.map(AggregateFunction::Avg),
-                    AggFunc::Min => input_column.map(AggregateFunction::Min),
-                    AggFunc::Max => input_column.map(AggregateFunction::Max),
-                    _ => None, // Other aggregate functions not yet supported in DBSP
                }
-            }
-            _ => None, // Not an aggregate function
-        }
-    }
-}
-
 /// Operator DAG (Directed Acyclic Graph)
 /// Base trait for incremental operators
 pub trait IncrementalOperator: Debug {
@@ -883,859 +596,6 @@ pub trait IncrementalOperator: Debug {
     fn set_tracker(&mut self, tracker: Arc<Mutex<ComputationTracker>>);
 }
 
-/// Aggregate operator - performs incremental aggregation with GROUP BY
-/// Maintains running totals/counts that are updated incrementally
-///
-/// Information about a column that has MIN/MAX aggregations
-#[derive(Debug, Clone)]
-pub struct AggColumnInfo {
-    /// Index used for storage key generation
-    pub index: usize,
-    /// Whether this column has a MIN aggregate
-    pub has_min: bool,
-    /// Whether this column has a MAX aggregate
-    pub has_max: bool,
-}
-
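For intuition, the ZSet bookkeeping described in the note below amounts to rows paired with signed weights: an insertion contributes +1, a deletion -1, and consolidation drops rows whose weights cancel out. A minimal standalone sketch of that idea, with illustrative names rather than the crate's API:

use std::collections::HashMap;

// A ZSet maps rows to signed weights; insert = +1, delete = -1.
fn consolidate(changes: &[(String, isize)]) -> HashMap<String, isize> {
    let mut zset: HashMap<String, isize> = HashMap::new();
    for (row, w) in changes {
        *zset.entry(row.clone()).or_insert(0) += w;
    }
    // Weight 0 means the row was inserted and deleted: it vanishes.
    zset.retain(|_, w| *w != 0);
    zset
}

fn main() {
    let z = consolidate(&[
        ("a".into(), 1),
        ("a".into(), 1),
        ("a".into(), -1), // one copy of "a" retracted
        ("b".into(), -1), // a pure retraction of "b"
    ]);
    assert_eq!(z.get("a"), Some(&1));
    assert_eq!(z.get("b"), Some(&-1));
}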
-/// Note that the AggregateOperator essentially implements a ZSet, even
-/// though the ZSet structure is never used explicitly. The on-disk btree
-/// plays the role of the set!
-#[derive(Debug)]
-pub struct AggregateOperator {
-    // Unique operator ID for indexing in persistent storage
-    pub operator_id: usize,
-    // GROUP BY columns
-    group_by: Vec<String>,
-    // Aggregate functions to compute (including MIN/MAX)
-    pub aggregates: Vec<AggregateFunction>,
-    // Column names from input
-    pub input_column_names: Vec<String>,
-    // Map from column name to aggregate info for quick lookup
-    pub column_min_max: HashMap<String, AggColumnInfo>,
-    tracker: Option<Arc<Mutex<ComputationTracker>>>,
-
-    // State machine for commit operation
-    commit_state: AggregateCommitState,
-}
-
-/// State for a single group's aggregates
-#[derive(Debug, Clone, Default)]
-pub struct AggregateState {
-    // For COUNT: just the count
-    count: i64,
-    // For SUM: column_name -> sum value
-    sums: HashMap<String, f64>,
-    // For AVG: column_name -> (sum, count) for computing average
-    avgs: HashMap<String, (f64, i64)>,
-    // For MIN: column_name -> minimum value
-    pub mins: HashMap<String, Value>,
-    // For MAX: column_name -> maximum value
-    pub maxs: HashMap<String, Value>,
-}
-
-/// Serialize a Value using SQLite's serial type format
-/// This is used for MIN/MAX values that need to be stored in a compact, sortable format
-pub fn serialize_value(value: &Value, blob: &mut Vec<u8>) {
-    let serial_type = crate::types::SerialType::from(value);
-    let serial_type_u64: u64 = serial_type.into();
-    crate::storage::sqlite3_ondisk::write_varint_to_vec(serial_type_u64, blob);
-    value.serialize_serial(blob);
-}
-
-/// Deserialize a Value using SQLite's serial type format
-/// Returns the deserialized value and the number of bytes consumed
-pub fn deserialize_value(blob: &[u8]) -> Option<(Value, usize)> {
-    let mut cursor = 0;
-
-    // Read the serial type
-    let (serial_type, varint_size) = crate::storage::sqlite3_ondisk::read_varint(blob).ok()?;
-    cursor += varint_size;
-
-    let serial_type_obj = crate::types::SerialType::try_from(serial_type).ok()?;
-    let expected_size = serial_type_obj.size();
-
-    // Read the value
-    let (value, actual_size) =
-        crate::storage::sqlite3_ondisk::read_value(&blob[cursor..], serial_type_obj).ok()?;
-
-    // Verify that the actual size matches what we expected from the serial type
-    if actual_size != expected_size {
-        return None; // Data corruption - size mismatch
-    }
-
-    cursor += actual_size;
-
-    // Convert RefValue to Value
-    Some((value.to_owned(), cursor))
-}
-
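The pair of functions above frames every stored extremum as a self-describing header followed by a payload whose length the header implies, which is what the size check in deserialize_value guards. A simplified standalone stand-in for that framing (a single header byte in place of the varint-encoded serial type; only the 8-byte integer case, SQLite serial type 6, is handled):

// Encode: header byte declaring the type, then a fixed-size payload.
fn encode(v: i64, out: &mut Vec<u8>) {
    out.push(6); // serial type 6 = 8-byte big-endian integer in SQLite
    out.extend_from_slice(&v.to_be_bytes());
}

// Decode: reject unknown headers and payloads shorter than declared.
fn decode(buf: &[u8]) -> Option<(i64, usize)> {
    if *buf.first()? != 6 {
        return None;
    }
    let payload = buf.get(1..9)?; // size must match the declared type
    Some((i64::from_be_bytes(payload.try_into().ok()?), 9))
}

fn main() {
    let mut blob = Vec::new();
    encode(-42, &mut blob);
    assert_eq!(decode(&blob), Some((-42, 9)));
}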
-impl AggregateState {
-    pub fn new() -> Self {
-        Self::default()
-    }
-
-    // Serialize the aggregate state to a binary blob including group key values
-    // The reason we serialize it like this, instead of just writing the actual values, is that
-    // The same table may have different aggregators in the circuit. They will all have different
-    // columns.
-    fn to_blob(&self, aggregates: &[AggregateFunction], group_key: &[Value]) -> Vec<u8> {
-        let mut blob = Vec::new();
-
-        // Write version byte for future compatibility
-        blob.push(1u8);
-
-        // Write number of group key values
-        blob.extend_from_slice(&(group_key.len() as u32).to_le_bytes());
-
-        // Write each group key value
-        for value in group_key {
-            // Write value type tag
-            match value {
-                Value::Null => blob.push(0u8),
-                Value::Integer(i) => {
-                    blob.push(1u8);
-                    blob.extend_from_slice(&i.to_le_bytes());
-                }
-                Value::Float(f) => {
-                    blob.push(2u8);
-                    blob.extend_from_slice(&f.to_le_bytes());
-                }
-                Value::Text(s) => {
-                    blob.push(3u8);
-                    let text_str = s.as_str();
-                    let bytes = text_str.as_bytes();
-                    blob.extend_from_slice(&(bytes.len() as u32).to_le_bytes());
-                    blob.extend_from_slice(bytes);
-                }
-                Value::Blob(b) => {
-                    blob.push(4u8);
-                    blob.extend_from_slice(&(b.len() as u32).to_le_bytes());
-                    blob.extend_from_slice(b);
-                }
-            }
-        }
-
-        // Write count as 8 bytes (little-endian)
-        blob.extend_from_slice(&self.count.to_le_bytes());
-
-        // Write each aggregate's state
-        for agg in aggregates {
-            match agg {
-                AggregateFunction::Sum(col_name) => {
-                    let sum = self.sums.get(col_name).copied().unwrap_or(0.0);
-                    blob.extend_from_slice(&sum.to_le_bytes());
-                }
-                AggregateFunction::Avg(col_name) => {
-                    let (sum, count) = self.avgs.get(col_name).copied().unwrap_or((0.0, 0));
-                    blob.extend_from_slice(&sum.to_le_bytes());
-                    blob.extend_from_slice(&count.to_le_bytes());
-                }
-                AggregateFunction::Count => {
-                    // Count is already written above
-                }
-                AggregateFunction::Min(col_name) => {
-                    // Write whether we have a MIN value (1 byte)
-                    if let Some(min_val) = self.mins.get(col_name) {
-                        blob.push(1u8); // Has value
-                        serialize_value(min_val, &mut blob);
-                    } else {
-                        blob.push(0u8); // No value
-                    }
-                }
-                AggregateFunction::Max(col_name) => {
-                    // Write whether we have a MAX value (1 byte)
-                    if let Some(max_val) = self.maxs.get(col_name) {
-                        blob.push(1u8); // Has value
-                        serialize_value(max_val, &mut blob);
-                    } else {
-                        blob.push(0u8); // No value
-                    }
-                }
-            }
-        }
-
-        blob
-    }
-
-    /// Deserialize aggregate state from a binary blob
-    /// Returns the aggregate state and the group key values
-    pub fn from_blob(blob: &[u8], aggregates: &[AggregateFunction]) -> Option<(Self, Vec<Value>)> {
-        let mut cursor = 0;
-
-        // Check version byte
-        if blob.get(cursor) != Some(&1u8) {
-            return None;
-        }
-        cursor += 1;
-
-        // Read number of group key values
-        let num_group_keys =
-            u32::from_le_bytes(blob.get(cursor..cursor + 4)?.try_into().ok()?) as usize;
-        cursor += 4;
-
-        // Read group key values
-        let mut group_key = Vec::new();
-        for _ in 0..num_group_keys {
-            let value_type = *blob.get(cursor)?;
-            cursor += 1;
-
-            let value = match value_type {
-                0 => Value::Null,
-                1 => {
-                    let i = i64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?);
-                    cursor += 8;
-                    Value::Integer(i)
-                }
-                2 => {
-                    let f = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?);
-                    cursor += 8;
-                    Value::Float(f)
-                }
-                3 => {
-                    let len =
-                        u32::from_le_bytes(blob.get(cursor..cursor + 4)?.try_into().ok()?) as usize;
-                    cursor += 4;
-                    let bytes = blob.get(cursor..cursor + len)?;
-                    cursor += len;
-                    let text_str = std::str::from_utf8(bytes).ok()?;
-                    Value::Text(text_str.to_string().into())
-                }
-                4 => {
-                    let len =
-                        u32::from_le_bytes(blob.get(cursor..cursor + 4)?.try_into().ok()?) as usize;
-                    cursor += 4;
-                    let bytes = blob.get(cursor..cursor + len)?;
-                    cursor += len;
-                    Value::Blob(bytes.to_vec())
-                }
-                _ => return None,
-            };
-            group_key.push(value);
-        }
-
-        // Read count
-        let count = i64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?);
-        cursor += 8;
-
-        let mut state = Self::new();
-        state.count = count;
-
-        // Read each aggregate's state
-        for agg in aggregates {
-            match agg {
-                AggregateFunction::Sum(col_name) => {
-                    let sum = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?);
-                    cursor += 8;
-                    state.sums.insert(col_name.clone(), sum);
-                }
-                AggregateFunction::Avg(col_name) => {
-                    let sum = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?);
-                    cursor += 8;
-                    let count = i64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?);
-                    cursor += 8;
-                    state.avgs.insert(col_name.clone(), (sum, count));
-                }
-                AggregateFunction::Count => {
-                    // Count was already read above
-                }
-                AggregateFunction::Min(col_name) => {
-                    // Read whether we have a MIN value
-                    let has_value = *blob.get(cursor)?;
-                    cursor += 1;
-
-                    if has_value == 1 {
-                        let (min_value, bytes_consumed) = deserialize_value(&blob[cursor..])?;
-                        cursor += bytes_consumed;
-                        state.mins.insert(col_name.clone(), min_value);
-                    }
-                }
-                AggregateFunction::Max(col_name) => {
-                    // Read whether we have a MAX value
-                    let has_value = *blob.get(cursor)?;
-                    cursor += 1;
-
-                    if has_value == 1 {
-                        let (max_value, bytes_consumed) = deserialize_value(&blob[cursor..])?;
-                        cursor += bytes_consumed;
-                        state.maxs.insert(col_name.clone(), max_value);
-                    }
-                }
-            }
-        }
-
-        Some((state, group_key))
-    }
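The fn that follows folds each incoming change into the running state linearly in its weight, which is exactly why COUNT, SUM, and AVG absorb retractions cheaply while MIN/MAX (see the long comment inside it) cannot. A standalone sketch of that arithmetic, simplified to a single f64 column:

#[derive(Default)]
struct Running {
    count: i64,
    sum: f64,
    avg_count: i64,
}

impl Running {
    // weight is +1 for an insert, -1 for a delete; the update is linear in it.
    fn apply(&mut self, value: f64, weight: i64) {
        self.count += weight;
        self.sum += value * weight as f64;
        self.avg_count += weight;
    }

    fn avg(&self) -> Option<f64> {
        (self.avg_count > 0).then(|| self.sum / self.avg_count as f64)
    }
}

fn main() {
    let mut s = Running::default();
    s.apply(10.0, 1);
    s.apply(20.0, 1);
    assert_eq!(s.avg(), Some(15.0));
    s.apply(10.0, -1); // retraction rolls SUM and COUNT back exactly
    assert_eq!((s.count, s.sum), (1, 20.0));
    assert_eq!(s.avg(), Some(20.0));
}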
-
-    /// Apply a delta to this aggregate state
-    fn apply_delta(
-        &mut self,
-        values: &[Value],
-        weight: isize,
-        aggregates: &[AggregateFunction],
-        column_names: &[String],
-    ) {
-        // Update COUNT
-        self.count += weight as i64;
-
-        // Update other aggregates
-        for agg in aggregates {
-            match agg {
-                AggregateFunction::Count => {
-                    // Already handled above
-                }
-                AggregateFunction::Sum(col_name) => {
-                    if let Some(idx) = column_names.iter().position(|c| c == col_name) {
-                        if let Some(val) = values.get(idx) {
-                            let num_val = match val {
-                                Value::Integer(i) => *i as f64,
-                                Value::Float(f) => *f,
-                                _ => 0.0,
-                            };
-                            *self.sums.entry(col_name.clone()).or_insert(0.0) +=
-                                num_val * weight as f64;
-                        }
-                    }
-                }
-                AggregateFunction::Avg(col_name) => {
-                    if let Some(idx) = column_names.iter().position(|c| c == col_name) {
-                        if let Some(val) = values.get(idx) {
-                            let num_val = match val {
-                                Value::Integer(i) => *i as f64,
-                                Value::Float(f) => *f,
-                                _ => 0.0,
-                            };
-                            let (sum, count) =
-                                self.avgs.entry(col_name.clone()).or_insert((0.0, 0));
-                            *sum += num_val * weight as f64;
-                            *count += weight as i64;
-                        }
-                    }
-                }
-                AggregateFunction::Min(_col_name) | AggregateFunction::Max(_col_name) => {
-                    // MIN/MAX cannot be handled incrementally in apply_delta because:
-                    //
-                    // 1. For insertions: We can't just keep the minimum/maximum value.
-                    //    We need to track ALL values to handle future deletions correctly.
-                    //
-                    // 2. For deletions (retractions): If we delete the current MIN/MAX,
-                    //    we need to find the next best value, which requires knowing all
-                    //    other values in the group.
-                    //
-                    // Example: Consider MIN(price) with values [10, 20, 30]
-                    //   - Current MIN = 10
-                    //   - Delete 10 (weight = -1)
-                    //   - New MIN should be 20, but we can't determine this without
-                    //     having tracked all values [20, 30]
-                    //
-                    // Therefore, MIN/MAX processing is handled separately:
-                    // - All input values are persisted to the index via persist_min_max()
-                    // - When aggregates have MIN/MAX, we unconditionally transition to
-                    //   the RecomputeMinMax state machine (see EvalState::RecomputeMinMax)
-                    // - RecomputeMinMax checks if the current MIN/MAX was deleted, and if so,
-                    //   scans the index to find the new MIN/MAX from remaining values
-                    //
-                    // This ensures correctness for incremental computation at the cost of
-                    //   additional I/O for MIN/MAX operations.
-                }
-            }
-        }
-    }
-
-    /// Convert aggregate state to output values
-    pub fn to_values(&self, aggregates: &[AggregateFunction]) -> Vec<Value> {
-        let mut result = Vec::new();
-
-        for agg in aggregates {
-            match agg {
-                AggregateFunction::Count => {
-                    result.push(Value::Integer(self.count));
-                }
-                AggregateFunction::Sum(col_name) => {
-                    let sum = self.sums.get(col_name).copied().unwrap_or(0.0);
-                    // Return as integer if it's a whole number, otherwise as float
-                    if sum.fract() == 0.0 {
-                        result.push(Value::Integer(sum as i64));
-                    } else {
-                        result.push(Value::Float(sum));
-                    }
-                }
-                AggregateFunction::Avg(col_name) => {
-                    if let Some((sum, count)) = self.avgs.get(col_name) {
-                        if *count > 0 {
-                            result.push(Value::Float(sum / *count as f64));
-                        } else {
-                            result.push(Value::Null);
-                        }
-                    } else {
-                        result.push(Value::Null);
-                    }
-                }
-                AggregateFunction::Min(col_name) => {
-                    // Return the MIN value from our state
-                    result.push(self.mins.get(col_name).cloned().unwrap_or(Value::Null));
-                }
-                AggregateFunction::Max(col_name) => {
-                    // Return the MAX value from our state
-                    result.push(self.maxs.get(col_name).cloned().unwrap_or(Value::Null));
-                }
-            }
-        }
-
-        result
-    }
-}
-
-impl AggregateOperator {
-    pub fn new(
-        operator_id: usize,
-        group_by: Vec<String>,
-        aggregates: Vec<AggregateFunction>,
-        input_column_names: Vec<String>,
-    ) -> Self {
-        // Build map of column names to their MIN/MAX info with indices
-        let mut column_min_max = HashMap::new();
-        let mut column_indices = HashMap::new();
-        let mut current_index = 0;
-
-        // First pass: assign indices to unique MIN/MAX columns
-        for agg in &aggregates {
-            match agg {
-                AggregateFunction::Min(col) | AggregateFunction::Max(col) => {
-                    column_indices.entry(col.clone()).or_insert_with(|| {
-                        let idx = current_index;
-                        current_index += 1;
-                        idx
-                    });
-                }
-                _ => {}
-            }
-        }
-
-        // Second pass: build the column info map
-        for agg in &aggregates {
-            match agg {
-                AggregateFunction::Min(col) => {
-                    let index = *column_indices.get(col).unwrap();
-                    let entry = column_min_max.entry(col.clone()).or_insert(AggColumnInfo {
-                        index,
-                        has_min: false,
-                        has_max: false,
-                    });
-                    entry.has_min = true;
-                }
-                AggregateFunction::Max(col) => {
-                    let index = *column_indices.get(col).unwrap();
-                    let entry = column_min_max.entry(col.clone()).or_insert(AggColumnInfo {
-                        index,
-                        has_min: false,
-                        has_max: false,
-                    });
-                    entry.has_max = true;
-                }
-                _ => {}
-            }
-        }
-
-        Self {
-            operator_id,
-            group_by,
-            aggregates,
-            input_column_names,
-            column_min_max,
-            tracker: None,
-            commit_state: AggregateCommitState::Idle,
-        }
-    }
-
-    pub fn has_min_max(&self) -> bool {
-        !self.column_min_max.is_empty()
-    }
-
-    fn eval_internal(
-        &mut self,
-        state: &mut EvalState,
-        cursors: &mut DbspStateCursors,
-    ) -> Result<IOResult<(Delta, HashMap<String, (Vec<Value>, AggregateState)>)>> {
-        match state {
-            EvalState::Uninitialized => {
-                panic!("Cannot eval AggregateOperator with Uninitialized state");
-            }
-            EvalState::Init { deltas } => {
-                // Aggregate operators only use left_delta, right_delta must be empty
-                assert!(
-                    deltas.right.is_empty(),
-                    "AggregateOperator expects right_delta to be empty"
-                );
-
-                if deltas.left.changes.is_empty() {
-                    *state = EvalState::Done;
-                    return Ok(IOResult::Done((Delta::new(), HashMap::new())));
-                }
-
-                let mut groups_to_read = BTreeMap::new();
-                for (row, _weight) in &deltas.left.changes {
-                    // Extract group key using cloned fields
-                    let group_key = self.extract_group_key(&row.values);
-                    let group_key_str = Self::group_key_to_string(&group_key);
-                    groups_to_read.insert(group_key_str, group_key);
-                }
-                state.advance_aggregate(groups_to_read);
-            }
-            EvalState::Aggregate(_agg_state) => {
-                // Already in progress, continue processing below.
-            }
-            EvalState::Done => {
-                panic!("unreachable state! should have returned");
-            }
-            EvalState::Join(_) => {
-                panic!("Join state should not appear in aggregate operator");
-            }
-        }
-
-        // Process the delta through the aggregate state machine
-        match state {
-            EvalState::Aggregate(agg_state) => {
-                let result = return_if_io!(agg_state.process_delta(self, cursors));
-                Ok(IOResult::Done(result))
-            }
-            _ => panic!("Invalid state for aggregate processing"),
-        }
-    }
-
-    fn merge_delta_with_existing(
-        &mut self,
-        delta: &Delta,
-        existing_groups: &mut HashMap<String, AggregateState>,
-        old_values: &mut HashMap<String, Vec<Value>>,
-    ) -> (Delta, HashMap<String, (Vec<Value>, AggregateState)>) {
-        let mut output_delta = Delta::new();
-        let mut temp_keys: HashMap<String, Vec<Value>> = HashMap::new();
-
-        // Process each change in the delta
-        for (row, weight) in &delta.changes {
-            if let Some(tracker) = &self.tracker {
-                tracker.lock().unwrap().record_aggregation();
-            }
-
-            // Extract group key
-            let group_key = self.extract_group_key(&row.values);
-            let group_key_str = Self::group_key_to_string(&group_key);
-
-            let state = existing_groups.entry(group_key_str.clone()).or_default();
-
-            temp_keys.insert(group_key_str.clone(), group_key.clone());
-
-            // Apply the delta to the temporary state
-            state.apply_delta(
-                &row.values,
-                *weight,
-                &self.aggregates,
-                &self.input_column_names,
-            );
-        }
-
-        // Generate output delta from temporary states and collect final states
-        let mut final_states = HashMap::new();
-
-        for (group_key_str, state) in existing_groups {
-            let group_key = temp_keys.get(group_key_str).cloned().unwrap_or_default();
-
-            // Generate a unique rowid for this group
-            let result_key = self.generate_group_rowid(group_key_str);
-
-            if let Some(old_row_values) = old_values.get(group_key_str) {
-                let old_row = HashableRow::new(result_key, old_row_values.clone());
-                output_delta.changes.push((old_row, -1));
-            }
-
-            // Always store the state for persistence (even if count=0, we need to delete it)
-            final_states.insert(group_key_str.clone(), (group_key.clone(), state.clone()));
-
-            // Only include groups with count > 0 in the output delta
-            if state.count > 0 {
-                // Build output row: group_by columns + aggregate values
-                let mut output_values = group_key.clone();
-                let aggregate_values = state.to_values(&self.aggregates);
-                output_values.extend(aggregate_values);
-
-                let output_row = HashableRow::new(result_key, output_values.clone());
-                output_delta.changes.push((output_row, 1));
-            }
-        }
-        (output_delta, final_states)
-    }
-
-    /// Extract MIN/MAX values from delta changes for persistence to index
-    fn extract_min_max_deltas(&self, delta: &Delta) -> MinMaxDeltas {
-        let mut min_max_deltas: MinMaxDeltas = HashMap::new();
-
-        for (row, weight) in &delta.changes {
-            let group_key = self.extract_group_key(&row.values);
-            let group_key_str = Self::group_key_to_string(&group_key);
-
-            for agg in &self.aggregates {
-                match agg {
-                    AggregateFunction::Min(col_name) | AggregateFunction::Max(col_name) => {
-                        if let Some(idx) =
-                            self.input_column_names.iter().position(|c| c == col_name)
-                        {
-                            if let Some(val) = row.values.get(idx) {
-                                // Skip NULL values - they don't participate in MIN/MAX
-                                if val == &Value::Null {
-                                    continue;
-                                }
-                                // Create a HashableRow with just this value
-                                // Use 0 as rowid since we only care about the value for comparison
-                                let hashable_value = HashableRow::new(0, vec![val.clone()]);
-                                let key = (col_name.clone(), hashable_value);
-
-                                let group_entry =
-                                    min_max_deltas.entry(group_key_str.clone()).or_default();
-
-                                let value_entry = group_entry.entry(key).or_insert(0);
-
-                                // Accumulate the weight
-                                *value_entry += weight;
-                            }
-                        }
-                    }
-                    _ => {} // Ignore non-MIN/MAX aggregates
-                }
-            }
-        }
-
-        min_max_deltas
-    }
-
-    pub fn set_tracker(&mut self, tracker: Arc<Mutex<ComputationTracker>>) {
-        self.tracker = Some(tracker);
-    }
-
-    /// Generate a rowid for a group
-    /// For no GROUP BY: always returns 0
-    /// For GROUP BY: returns a hash of the group key string
-    pub fn generate_group_rowid(&self, group_key_str: &str) -> i64 {
-        if self.group_by.is_empty() {
-            0
-        } else {
-            group_key_str
-                .bytes()
-                .fold(0i64, |acc, b| acc.wrapping_mul(31).wrapping_add(b as i64))
-        }
-    }
-
-    /// Generate the composite key for BTree storage
-    /// Combines operator_id and group hash
-    fn generate_storage_key(&self, group_key_str: &str) -> i64 {
-        let group_hash = self.generate_group_rowid(group_key_str);
-        (self.operator_id as i64) << 32 | (group_hash & 0xFFFFFFFF)
-    }
-
-    /// Extract group key values from a row
-    pub fn extract_group_key(&self, values: &[Value]) -> Vec<Value> {
-        let mut key = Vec::new();
-
-        for group_col in &self.group_by {
-            if let Some(idx) = self.input_column_names.iter().position(|c| c == group_col) {
-                if let Some(val) = values.get(idx) {
-                    key.push(val.clone());
-                } else {
-                    key.push(Value::Null);
-                }
-            } else {
-                key.push(Value::Null);
-            }
-        }
-
-        key
-    }
-
-    /// Convert group key to string for indexing (since Value doesn't implement Hash)
-    pub fn group_key_to_string(key: &[Value]) -> String {
-        key.iter()
-            .map(|v| format!("{v:?}"))
-            .collect::<Vec<_>>()
-            .join(",")
-    }
-
-    fn seek_key_from_str(&self, group_key_str: &str) -> SeekKey<'_> {
-        // Calculate the composite key for seeking
-        let key_i64 = self.generate_storage_key(group_key_str);
-        SeekKey::TableRowId(key_i64)
-    }
-
-    fn seek_key(&self, row: HashableRow) -> SeekKey<'_> {
-        // Extract group key for first row
-        let group_key = self.extract_group_key(&row.values);
-        let group_key_str = Self::group_key_to_string(&group_key);
-        self.seek_key_from_str(&group_key_str)
-    }
-}
-
-impl IncrementalOperator for AggregateOperator {
-    fn eval(
-        &mut self,
-        state: &mut EvalState,
-        cursors: &mut DbspStateCursors,
-    ) -> Result<IOResult<Delta>> {
-        let (delta, _) = return_if_io!(self.eval_internal(state, cursors));
-        Ok(IOResult::Done(delta))
-    }
-
-    fn commit(
-        &mut self,
-        mut deltas: DeltaPair,
-        cursors: &mut DbspStateCursors,
-    ) -> Result<IOResult<Delta>> {
-        // Aggregate operator only uses left delta, right must be empty
-        assert!(
-            deltas.right.is_empty(),
-            "AggregateOperator expects right delta to be empty in commit"
-        );
-        let delta = std::mem::take(&mut deltas.left);
-        loop {
-            // Note: because we std::mem::replace here (without it, the borrow checker goes nuts,
-            // because we call self.eval_interval, which requires a mutable borrow), we have to
-            // restore the state if we return I/O. So we can't use return_if_io!
-            let mut state =
-                std::mem::replace(&mut self.commit_state, AggregateCommitState::Invalid);
-            match &mut state {
-                AggregateCommitState::Invalid => {
-                    panic!("Reached invalid state! State was replaced, and not replaced back");
-                }
-                AggregateCommitState::Idle => {
-                    let eval_state = EvalState::from_delta(delta.clone());
-                    self.commit_state = AggregateCommitState::Eval { eval_state };
-                }
-                AggregateCommitState::Eval { ref mut eval_state } => {
-                    // Extract input delta before eval for MIN/MAX processing
-                    let input_delta = eval_state.extract_delta();
-
-                    // Extract MIN/MAX deltas before any I/O operations
-                    let min_max_deltas = self.extract_min_max_deltas(&input_delta);
-
-                    // Create a new eval state with the same delta
-                    *eval_state = EvalState::from_delta(input_delta.clone());
-
-                    let (output_delta, computed_states) = return_and_restore_if_io!(
-                        &mut self.commit_state,
-                        state,
-                        self.eval_internal(eval_state, cursors)
-                    );
-
-                    self.commit_state = AggregateCommitState::PersistDelta {
-                        delta: output_delta,
-                        computed_states,
-                        current_idx: 0,
-                        write_row: WriteRow::new(),
-                        min_max_deltas, // Store for later use
-                    };
-                }
-                AggregateCommitState::PersistDelta {
-                    delta,
-                    computed_states,
-                    current_idx,
-                    write_row,
-                    min_max_deltas,
-                } => {
-                    let states_vec: Vec<_> = computed_states.iter().collect();
-
-                    if *current_idx >= states_vec.len() {
-                        // Use the min_max_deltas we extracted earlier from the input delta
-                        self.commit_state = AggregateCommitState::PersistMinMax {
-                            delta: delta.clone(),
-                            min_max_persist_state: MinMaxPersistState::new(min_max_deltas.clone()),
-                        };
-                    } else {
-                        let (group_key_str, (group_key, agg_state)) = states_vec[*current_idx];
-
-                        // Build the key components for the new table structure
-                        // For regular aggregates, use column_index=0 and AGG_TYPE_REGULAR
-                        let operator_storage_id =
-                            generate_storage_id(self.operator_id, 0, AGG_TYPE_REGULAR);
-                        let zset_id = self.generate_group_rowid(group_key_str);
-                        let element_id = 0i64;
-
-                        // Determine weight: -1 to delete (cancels existing weight=1), 1 to insert/update
-                        let weight = if agg_state.count == 0 { -1 } else { 1 };
-
-                        // Serialize the aggregate state with group key (even for deletion, we need a row)
-                        let state_blob = agg_state.to_blob(&self.aggregates, group_key);
-                        let blob_value = Value::Blob(state_blob);
-
-                        // Build the aggregate storage format: [operator_id, zset_id, element_id, value, weight]
-                        let operator_id_val = Value::Integer(operator_storage_id);
-                        let zset_id_val = Value::Integer(zset_id);
-                        let element_id_val = Value::Integer(element_id);
-                        let blob_val = blob_value.clone();
-
-                        // Create index key - the first 3 columns of our primary key
-                        let index_key = vec![
-                            operator_id_val.clone(),
-                            zset_id_val.clone(),
-                            element_id_val.clone(),
-                        ];
-
-                        // Record values (without weight)
-                        let record_values =
-                            vec![operator_id_val, zset_id_val, element_id_val, blob_val];
-
-                        return_and_restore_if_io!(
-                            &mut self.commit_state,
-                            state,
-                            write_row.write_row(cursors, index_key, record_values, weight)
-                        );
-
-                        let delta = std::mem::take(delta);
-                        let computed_states = std::mem::take(computed_states);
-                        let min_max_deltas = std::mem::take(min_max_deltas);
-
-                        self.commit_state = AggregateCommitState::PersistDelta {
-                            delta,
-                            computed_states,
-                            current_idx: *current_idx + 1,
-                            write_row: WriteRow::new(), // Reset for next write
-                            min_max_deltas,
-                        };
-                    }
-                }
-                AggregateCommitState::PersistMinMax {
-                    delta,
-                    min_max_persist_state,
-                } => {
-                    if !self.has_min_max() {
-                        let delta = std::mem::take(delta);
-                        self.commit_state = AggregateCommitState::Done { delta };
-                    } else {
-                        return_and_restore_if_io!(
-                            &mut self.commit_state,
-                            state,
-                            min_max_persist_state.persist_min_max(
-                                self.operator_id,
-                                &self.column_min_max,
-                                cursors,
-                                |group_key_str| self.generate_group_rowid(group_key_str)
-                            )
-                        );
-
-                        let delta = std::mem::take(delta);
-                        self.commit_state = AggregateCommitState::Done { delta };
-                    }
-                }
-                AggregateCommitState::Done { delta } => {
-                    self.commit_state = AggregateCommitState::Idle;
-                    let delta = std::mem::take(delta);
-                    return Ok(IOResult::Done(delta));
-                }
-            }
-        }
-    }
-
-    fn set_tracker(&mut self, tracker: Arc<Mutex<ComputationTracker>>) {
-        self.tracker = Some(tracker);
-    }
-}
-
 #[derive(Debug)]
 enum JoinCommitState {
     Idle,
@@ -2226,6 +1086,7 @@ impl IncrementalOperator for JoinOperator {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::incremental::aggregate_operator::AGG_TYPE_REGULAR;
     use crate::storage::pager::CreateBTreeFlags;
     use crate::types::Text;
     use crate::util::IOExt;
diff --git a/core/incremental/persistence.rs b/core/incremental/persistence.rs
index eca26cd7c..5cf41b94a 100644
--- a/core/incremental/persistence.rs
+++ b/core/incremental/persistence.rs
@@ -1,12 +1,7 @@
-use crate::incremental::dbsp::HashableRow;
-use crate::incremental::operator::{
-    generate_storage_id, AggColumnInfo, AggregateFunction, AggregateOperator, AggregateState,
-    DbspStateCursors, MinMaxDeltas, AGG_TYPE_MINMAX,
-};
+use crate::incremental::operator::{AggregateFunction, AggregateState, DbspStateCursors};
 use crate::storage::btree::{BTreeCursor, BTreeKey};
-use crate::types::{IOResult, ImmutableRecord, RefValue, SeekKey, SeekOp, SeekResult};
+use crate::types::{IOResult, ImmutableRecord, SeekKey, SeekOp, SeekResult};
 use crate::{return_if_io, LimboError, Result, Value};
-use std::collections::{HashMap, HashSet};
 
 #[derive(Debug, Default)]
 pub enum ReadRecord {
@@ -290,672 +285,3 @@ impl WriteRow {
         }
     }
 }
-
-/// State machine for recomputing MIN/MAX values after deletion
-#[derive(Debug)]
-pub enum RecomputeMinMax {
-    ProcessElements {
-        /// Current column being processed
-        current_column_idx: usize,
-        /// Columns to process (combined MIN and MAX)
-        columns_to_process: Vec<(String, String, bool)>, // (group_key, column_name, is_min)
-        /// MIN/MAX deltas for checking values and weights
-        min_max_deltas: MinMaxDeltas,
-    },
-    Scan {
-        /// Columns still to process
-        columns_to_process: Vec<(String, String, bool)>,
-        /// Current index in columns_to_process (will resume from here)
-        current_column_idx: usize,
-        /// MIN/MAX deltas for checking values and weights
-        min_max_deltas: MinMaxDeltas,
-        /// Current group key being processed
-        group_key: String,
-        /// Current column name being processed
-        column_name: String,
-        /// Whether we're looking for MIN (true) or MAX (false)
-        is_min: bool,
-        /// The scan state machine for finding the new MIN/MAX
-        scan_state: Box<ScanState>,
-    },
-    Done,
-}
-
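The constructor that follows reduces to a compact rule: an insertion can only improve a cached extremum, but a retraction of the cached extremum itself invalidates it and forces a scan of the remaining values. An illustrative standalone reduction of that rule, using i64 in place of Value:

// An insert folds into the cached MIN directly.
fn merge_insert(current_min: Option<i64>, inserted: i64) -> Option<i64> {
    Some(current_min.map_or(inserted, |m| m.min(inserted)))
}

// A rescan is needed only when the retracted value IS the cached MIN.
fn needs_rescan(current_min: Option<i64>, deleted: i64) -> bool {
    current_min == Some(deleted)
}

fn main() {
    // MIN(price) over [10, 20, 30], as in the example earlier in this patch.
    let min = merge_insert(merge_insert(merge_insert(None, 10), 20), 30);
    assert_eq!(min, Some(10));
    // Deleting 10 invalidates the cached MIN; the remaining values
    // (on disk, the index) must be consulted to learn the new MIN = 20.
    assert!(needs_rescan(min, 10));
    // Deleting 30 leaves the cached MIN untouched.
    assert!(!needs_rescan(min, 30));
}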
-impl RecomputeMinMax {
-    pub fn new(
-        min_max_deltas: MinMaxDeltas,
-        existing_groups: &HashMap<String, AggregateState>,
-        operator: &AggregateOperator,
-    ) -> Self {
-        let mut groups_to_check: HashSet<(String, String, bool)> = HashSet::new();
-
-        // Remember the min_max_deltas are essentially just the only column that is affected by
-        // this min/max, in delta (actually ZSet - consolidated delta) format. This makes it easier
-        // for us to consume it in here.
-        //
-        // The most challenging case is the case where there is a retraction, since we need to go
-        // back to the index.
-        for (group_key_str, values) in &min_max_deltas {
-            for ((col_name, hashable_row), weight) in values {
-                let col_info = operator.column_min_max.get(col_name);
-
-                let value = &hashable_row.values[0];
-
-                if *weight < 0 {
-                    // Deletion detected - check if it's the current MIN/MAX
-                    if let Some(state) = existing_groups.get(group_key_str) {
-                        // Check for MIN
-                        if let Some(current_min) = state.mins.get(col_name) {
-                            if current_min == value {
-                                groups_to_check.insert((
-                                    group_key_str.clone(),
-                                    col_name.clone(),
-                                    true,
-                                ));
-                            }
-                        }
-                        // Check for MAX
-                        if let Some(current_max) = state.maxs.get(col_name) {
-                            if current_max == value {
-                                groups_to_check.insert((
-                                    group_key_str.clone(),
-                                    col_name.clone(),
-                                    false,
-                                ));
-                            }
-                        }
-                    }
-                } else if *weight > 0 {
-                    // If it is not found in the existing groups, then we only need to care
-                    // about this if this is a new record being inserted
-                    if let Some(info) = col_info {
-                        if info.has_min {
-                            groups_to_check.insert((group_key_str.clone(), col_name.clone(), true));
-                        }
-                        if info.has_max {
-                            groups_to_check.insert((
-                                group_key_str.clone(),
-                                col_name.clone(),
-                                false,
-                            ));
-                        }
-                    }
-                }
-            }
-        }
-
-        if groups_to_check.is_empty() {
-            // No recomputation or initialization needed
-            Self::Done
-        } else {
-            // Convert HashSet to Vec for indexed processing
-            let groups_to_check_vec: Vec<_> = groups_to_check.into_iter().collect();
-            Self::ProcessElements {
-                current_column_idx: 0,
-                columns_to_process: groups_to_check_vec,
-                min_max_deltas,
-            }
-        }
-    }
-
-    pub fn process(
-        &mut self,
-        existing_groups: &mut HashMap<String, AggregateState>,
-        operator: &AggregateOperator,
-        cursors: &mut DbspStateCursors,
-    ) -> Result<IOResult<()>> {
-        loop {
-            match self {
-                RecomputeMinMax::ProcessElements {
-                    current_column_idx,
-                    columns_to_process,
-                    min_max_deltas,
-                } => {
-                    if *current_column_idx >= columns_to_process.len() {
-                        *self = RecomputeMinMax::Done;
-                        return Ok(IOResult::Done(()));
-                    }
-
-                    let (group_key, column_name, is_min) =
-                        columns_to_process[*current_column_idx].clone();
-
-                    // Get column index from pre-computed info
-                    let column_index = operator
-                        .column_min_max
-                        .get(&column_name)
-                        .map(|info| info.index)
-                        .unwrap(); // Should always exist since we're processing known columns
-
-                    // Get current value from existing state
-                    let current_value = existing_groups.get(&group_key).and_then(|state| {
-                        if is_min {
-                            state.mins.get(&column_name).cloned()
-                        } else {
-                            state.maxs.get(&column_name).cloned()
-                        }
-                    });
-
-                    // Create storage keys for index lookup
-                    let storage_id =
-                        generate_storage_id(operator.operator_id, column_index, AGG_TYPE_MINMAX);
-                    let zset_id = operator.generate_group_rowid(&group_key);
-
-                    // Get the values for this group from min_max_deltas
-                    let group_values = min_max_deltas.get(&group_key).cloned().unwrap_or_default();
-
-                    let columns_to_process = std::mem::take(columns_to_process);
-                    let min_max_deltas = std::mem::take(min_max_deltas);
-
-                    let scan_state = if is_min {
-                        Box::new(ScanState::new_for_min(
-                            current_value,
-                            group_key.clone(),
-                            column_name.clone(),
-                            storage_id,
-                            zset_id,
-                            group_values,
-                        ))
-                    } else {
-                        Box::new(ScanState::new_for_max(
-                            current_value,
-                            group_key.clone(),
-                            column_name.clone(),
-                            storage_id,
-                            zset_id,
-                            group_values,
-                        ))
-                    };
-
-                    *self = RecomputeMinMax::Scan {
-                        columns_to_process,
-                        current_column_idx: *current_column_idx,
-                        min_max_deltas,
-                        group_key,
-                        column_name,
-                        is_min,
-                        scan_state,
-                    };
-                }
-                RecomputeMinMax::Scan {
-                    columns_to_process,
-                    current_column_idx,
-                    min_max_deltas,
-                    group_key,
-                    column_name,
-                    is_min,
-                    scan_state,
-                } => {
-                    // Find new value using the scan state machine
-                    let new_value = return_if_io!(scan_state.find_new_value(cursors));
-
-                    // Update the state with new value (create if doesn't exist)
-                    let state = existing_groups.entry(group_key.clone()).or_default();
-
-                    if *is_min {
-                        if let Some(min_val) = new_value {
-                            state.mins.insert(column_name.clone(), min_val);
-                        } else {
-                            state.mins.remove(column_name);
-                        }
-                    } else if let Some(max_val) = new_value {
-                        state.maxs.insert(column_name.clone(), max_val);
-                    } else {
-                        state.maxs.remove(column_name);
-                    }
-
-                    // Move to next column
-                    let min_max_deltas = std::mem::take(min_max_deltas);
-                    let columns_to_process = std::mem::take(columns_to_process);
-                    *self = RecomputeMinMax::ProcessElements {
-                        current_column_idx: *current_column_idx + 1,
-                        columns_to_process,
-                        min_max_deltas,
-                    };
-                }
-                RecomputeMinMax::Done => {
-                    return Ok(IOResult::Done(()));
-                }
-            }
-        }
-    }
-}
-
-/// State machine for scanning through the index to find new MIN/MAX values
-#[derive(Debug)]
-pub enum ScanState {
-    CheckCandidate {
-        /// Current candidate value for MIN/MAX
-        candidate: Option<Value>,
-        /// Group key being processed
-        group_key: String,
-        /// Column name being processed
-        column_name: String,
-        /// Storage ID for the index seek
-        storage_id: i64,
-        /// ZSet ID for the group
-        zset_id: i64,
-        /// Group values from MinMaxDeltas: (column_name, HashableRow) -> weight
-        group_values: HashMap<(String, HashableRow), isize>,
-        /// Whether we're looking for MIN (true) or MAX (false)
-        is_min: bool,
-    },
-    FetchNextCandidate {
-        /// Current candidate to seek past
-        current_candidate: Value,
-        /// Group key being processed
-        group_key: String,
-        /// Column name being processed
-        column_name: String,
-        /// Storage ID for the index seek
-        storage_id: i64,
-        /// ZSet ID for the group
-        zset_id: i64,
-        /// Group values from MinMaxDeltas: (column_name, HashableRow) -> weight
-        group_values: HashMap<(String, HashableRow), isize>,
-        /// Whether we're looking for MIN (true) or MAX (false)
-        is_min: bool,
-    },
-    Done {
-        /// The final MIN/MAX value found
-        result: Option<Value>,
-    },
-}
-
-impl ScanState {
-    pub fn new_for_min(
-        current_min: Option<Value>,
-        group_key: String,
-        column_name: String,
-        storage_id: i64,
-        zset_id: i64,
-        group_values: HashMap<(String, HashableRow), isize>,
-    ) -> Self {
-        Self::CheckCandidate {
-            candidate: current_min,
-            group_key,
-            column_name,
-            storage_id,
-            zset_id,
-            group_values,
-            is_min: true,
-        }
-    }
-
-    // Extract a new candidate from the index. It is possible that, when searching,
-    // we end up going into a different operator altogether. That means we have
-    // exhausted this operator (or group) entirely, and no good candidate was found
-    fn extract_new_candidate(
-        cursors: &mut DbspStateCursors,
-        index_record: &ImmutableRecord,
-        seek_op: SeekOp,
-        storage_id: i64,
-        zset_id: i64,
-    ) -> Result<IOResult<Option<Value>>> {
-        let seek_result = return_if_io!(cursors
-            .index_cursor
-            .seek(SeekKey::IndexKey(index_record), seek_op));
-        if !matches!(seek_result, SeekResult::Found) {
-            return Ok(IOResult::Done(None));
-        }
-
-        let record = return_if_io!(cursors.index_cursor.record()).ok_or_else(|| {
-            LimboError::InternalError(
-                "Record found on the cursor, but could not be read".to_string(),
-            )
-        })?;
-
-        let values = record.get_values();
-        if values.len() < 3 {
-            return Ok(IOResult::Done(None));
-        }
-
-        let Some(rec_storage_id) = values.first() else {
-            return Ok(IOResult::Done(None));
-        };
-        let Some(rec_zset_id) = values.get(1) else {
-            return Ok(IOResult::Done(None));
-        };
-
-        // Check if we're still in the same group
-        if let (RefValue::Integer(rec_sid), RefValue::Integer(rec_zid)) =
-            (rec_storage_id, rec_zset_id)
-        {
-            if *rec_sid != storage_id || *rec_zid != zset_id {
-                return Ok(IOResult::Done(None));
-            }
-        } else {
-            return Ok(IOResult::Done(None));
-        }
-
-        // Get the value (3rd element)
-        Ok(IOResult::Done(values.get(2).map(|v| v.to_owned())))
-    }
-
-    pub fn new_for_max(
-        current_max: Option<Value>,
-        group_key: String,
-        column_name: String,
-        storage_id: i64,
-        zset_id: i64,
-        group_values: HashMap<(String, HashableRow), isize>,
-    ) -> Self {
-        Self::CheckCandidate {
-            candidate: current_max,
-            group_key,
-            column_name,
-            storage_id,
-            zset_id,
-            group_values,
-            is_min: false,
-        }
-    }
-
-    pub fn find_new_value(
-        &mut self,
-        cursors: &mut DbspStateCursors,
-    ) -> Result<IOResult<Option<Value>>> {
-        loop {
-            match self {
-                ScanState::CheckCandidate {
-                    candidate,
-                    group_key,
-                    column_name,
-                    storage_id,
-                    zset_id,
-                    group_values,
-                    is_min,
-                } => {
-                    // First, check if we have a candidate
-                    if let Some(cand_val) = candidate {
-                        // Check if the candidate is retracted (weight <= 0)
-                        // Create a HashableRow to look up the weight
-                        let hashable_cand = HashableRow::new(0, vec![cand_val.clone()]);
-                        let key = (column_name.clone(), hashable_cand);
-                        let is_retracted =
-                            group_values.get(&key).is_some_and(|weight| *weight <= 0);
-
-                        if is_retracted {
-                            // Candidate is retracted, need to fetch next from index
-                            *self = ScanState::FetchNextCandidate {
-                                current_candidate: cand_val.clone(),
-                                group_key: std::mem::take(group_key),
-                                column_name: std::mem::take(column_name),
-                                storage_id: *storage_id,
-                                zset_id: *zset_id,
-                                group_values: std::mem::take(group_values),
-                                is_min: *is_min,
-                            };
-                            continue;
-                        }
-                    }
-
-                    // Candidate is valid or we have no candidate
-                    // Now find the best value from insertions in group_values
-                    let mut best_from_zset = None;
-                    for ((col, hashable_val), weight) in group_values.iter() {
-                        if col == column_name && *weight > 0 {
-                            let value = &hashable_val.values[0];
-                            // Skip NULL values - they don't participate in MIN/MAX
-                            if value == &Value::Null {
-                                continue;
-                            }
-                            // This is an insertion for our column
-                            if let Some(ref current_best) = best_from_zset {
-                                if *is_min {
-                                    if value.cmp(current_best) == std::cmp::Ordering::Less {
-                                        best_from_zset = Some(value.clone());
-                                    }
-                                } else if value.cmp(current_best) == std::cmp::Ordering::Greater {
-                                    best_from_zset = Some(value.clone());
-                                }
-                            } else {
-                                best_from_zset = Some(value.clone());
-                            }
-                        }
-                    }
-
-                    // Compare candidate with best from ZSet, filtering out NULLs
-                    let result = match (&candidate, &best_from_zset) {
-                        (Some(cand), Some(zset_val)) if cand != &Value::Null => {
-                            if *is_min {
-                                if zset_val.cmp(cand) == std::cmp::Ordering::Less {
-                                    Some(zset_val.clone())
-                                } else {
-                                    Some(cand.clone())
-                                }
-                            } else if zset_val.cmp(cand) == std::cmp::Ordering::Greater {
-                                Some(zset_val.clone())
-                            } else {
-                                Some(cand.clone())
-                            }
-                        }
-                        (Some(cand), None) if cand != &Value::Null => Some(cand.clone()),
-                        (None, Some(zset_val)) => Some(zset_val.clone()),
-                        (Some(cand), Some(_)) if cand == &Value::Null => best_from_zset,
-                        _ => None,
-                    };
-
-                    *self = ScanState::Done { result };
-                }
-
-                ScanState::FetchNextCandidate {
-                    current_candidate,
-                    group_key,
-                    column_name,
-                    storage_id,
-                    zset_id,
-                    group_values,
-                    is_min,
-                } => {
-                    // Seek to the next value in the index
-                    let index_key = vec![
-                        Value::Integer(*storage_id),
-                        Value::Integer(*zset_id),
-                        current_candidate.clone(),
-                    ];
-                    let index_record = ImmutableRecord::from_values(&index_key, index_key.len());
-
-                    let seek_op = if *is_min {
-                        SeekOp::GT // For MIN, seek greater than current
-                    } else {
-                        SeekOp::LT // For MAX, seek less than current
-                    };
-
-                    let new_candidate = return_if_io!(Self::extract_new_candidate(
-                        cursors,
-                        &index_record,
-                        seek_op,
-                        *storage_id,
-                        *zset_id
-                    ));
-
-                    *self = ScanState::CheckCandidate {
-                        candidate: new_candidate,
-                        group_key: std::mem::take(group_key),
-                        column_name: std::mem::take(column_name),
-                        storage_id: *storage_id,
-                        zset_id: *zset_id,
-                        group_values: std::mem::take(group_values),
-                        is_min: *is_min,
-                    };
-                }
-
-                ScanState::Done { result } => {
-                    return Ok(IOResult::Done(result.clone()));
-                }
-            }
-        }
-    }
-}
-
-/// State machine for persisting Min/Max values to storage
-#[derive(Debug)]
-pub enum MinMaxPersistState {
-    Init {
-        min_max_deltas: MinMaxDeltas,
-        group_keys: Vec<String>,
-    },
-    ProcessGroup {
-        min_max_deltas: MinMaxDeltas,
-        group_keys: Vec<String>,
-        group_idx: usize,
-        value_idx: usize,
-    },
-    WriteValue {
-        min_max_deltas: MinMaxDeltas,
-        group_keys: Vec<String>,
-        group_idx: usize,
-        value_idx: usize,
-        value: Value,
-        column_name: String,
-        weight: isize,
-        write_row: WriteRow,
-    },
-    Done,
-}
-
-impl MinMaxPersistState {
-    pub fn new(min_max_deltas: MinMaxDeltas) -> Self {
-        let group_keys: Vec<String> = min_max_deltas.keys().cloned().collect();
-        Self::Init {
-            min_max_deltas,
-            group_keys,
-        }
-    }
-
-    pub fn persist_min_max(
-        &mut self,
-        operator_id: usize,
-        column_min_max: &HashMap<String, AggColumnInfo>,
-        cursors: &mut DbspStateCursors,
-        generate_group_rowid: impl Fn(&str) -> i64,
-    ) -> Result<IOResult<()>> {
-        loop {
-            match self {
-                MinMaxPersistState::Init {
-                    min_max_deltas,
-                    group_keys,
-                } => {
-                    let min_max_deltas = std::mem::take(min_max_deltas);
-                    let group_keys = std::mem::take(group_keys);
-                    *self = MinMaxPersistState::ProcessGroup {
-                        min_max_deltas,
-                        group_keys,
-                        group_idx: 0,
-                        value_idx: 0,
-                    };
-                }
-                MinMaxPersistState::ProcessGroup {
-                    min_max_deltas,
-                    group_keys,
-                    group_idx,
-                    value_idx,
-                } => {
-                    // Check if we're past all groups
-                    if *group_idx >= group_keys.len() {
-                        *self = MinMaxPersistState::Done;
-                        continue;
-                    }
-
-                    let group_key_str = &group_keys[*group_idx];
-                    let values = &min_max_deltas[group_key_str]; // This should always exist
-
-                    // Convert HashMap to Vec for indexed access
-                    let values_vec: Vec<_> = values.iter().collect();
-
-                    // Check if we have more values in current group
-                    if *value_idx >= values_vec.len() {
-                        *group_idx += 1;
-                        *value_idx = 0;
-                        // Continue to check if we're past all groups now
-                        continue;
-                    }
-
-                    // Process current value and extract what we need before taking ownership
-                    let ((column_name, hashable_row), weight) = values_vec[*value_idx];
-                    let column_name = column_name.clone();
-                    let value = hashable_row.values[0].clone(); // Extract the Value from HashableRow
-                    let weight = *weight;
-
-                    let min_max_deltas = std::mem::take(min_max_deltas);
-                    let group_keys = std::mem::take(group_keys);
-                    *self = MinMaxPersistState::WriteValue {
-                        min_max_deltas,
-                        group_keys,
-                        group_idx: *group_idx,
-                        value_idx: *value_idx,
-                        column_name,
-                        value,
-                        weight,
-                        write_row: WriteRow::new(),
-                    };
-                }
-                MinMaxPersistState::WriteValue {
-                    min_max_deltas,
-                    group_keys,
-                    group_idx,
-                    value_idx,
-                    value,
-                    column_name,
-                    weight,
-                    write_row,
-                } => {
-                    // Should have exited in the previous state
-                    assert!(*group_idx < group_keys.len());
-
-                    let group_key_str = &group_keys[*group_idx];
-
-                    // Get the column index from the pre-computed map
-                    let column_info = column_min_max
-                        .get(&*column_name)
-                        .expect("Column should exist in column_min_max map");
-                    let column_index = column_info.index;
-
-                    // Build the key components for MinMax storage using new encoding
-                    let storage_id =
-                        generate_storage_id(operator_id, column_index, AGG_TYPE_MINMAX);
-                    let zset_id = generate_group_rowid(group_key_str);
-
-                    // element_id is the actual value for Min/Max
-                    let element_id_val = value.clone();
-
-                    // Create index key
-                    let index_key = vec![
-                        Value::Integer(storage_id),
-                        Value::Integer(zset_id),
-                        element_id_val.clone(),
-                    ];
-
-                    // Record values (operator_id, zset_id, element_id, unused_placeholder)
-                    // For MIN/MAX, the element_id IS the value, so we use NULL for the 4th column
-                    let record_values = vec![
-                        Value::Integer(storage_id),
-                        Value::Integer(zset_id),
-                        element_id_val.clone(),
-                        Value::Null, // Placeholder - not used for MIN/MAX
-                    ];
-
-                    return_if_io!(write_row.write_row(
-                        cursors,
-                        index_key.clone(),
-                        record_values,
-                        *weight
-                    ));
-
-                    // Move to next value
-                    let min_max_deltas = std::mem::take(min_max_deltas);
-                    let group_keys = std::mem::take(group_keys);
-                    *self = MinMaxPersistState::ProcessGroup {
-                        min_max_deltas,
-                        group_keys,
-                        group_idx: *group_idx,
-                        value_idx: *value_idx + 1,
-                    };
-                }
-                MinMaxPersistState::Done => {
-                    return Ok(IOResult::Done(()));
-                }
-            }
-        }
-    }
-}
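One closing observation on the layout this file relies on: because the index keys (storage_id, zset_id, value) compare lexicographically in the btree, replacing a retracted MIN is a single ordered seek past the old value, which is what ScanState does with SeekOp::GT. A standalone sketch over a BTreeSet, assuming i64 stand-ins for the real Values (old + 1 emulates a strictly-greater seek on integers):

use std::collections::BTreeSet;

fn next_min_after(index: &BTreeSet<(i64, i64, i64)>, sid: i64, zid: i64, old: i64) -> Option<i64> {
    index
        .range((sid, zid, old + 1)..) // seek strictly past the retracted value
        .next()
        .filter(|&&(s, z, _)| s == sid && z == zid) // stop at the group boundary
        .map(|&(_, _, v)| v)
}

fn main() {
    let mut index = BTreeSet::new();
    for v in [10, 20, 30] {
        index.insert((7, 1, v)); // storage_id = 7, one group with zset_id = 1
    }
    index.remove(&(7, 1, 10)); // retract the current MIN
    assert_eq!(next_min_after(&index, 7, 1, 10), Some(20));
    assert_eq!(next_min_after(&index, 7, 1, 30), None); // group exhausted
}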