diff --git a/core/incremental/aggregate_operator.rs b/core/incremental/aggregate_operator.rs
index 9f25a84f5..df4d1a18a 100644
--- a/core/incremental/aggregate_operator.rs
+++ b/core/incremental/aggregate_operator.rs
@@ -16,6 +16,13 @@ use std::sync::{Arc, Mutex};
 pub const AGG_TYPE_REGULAR: u8 = 0b00; // COUNT/SUM/AVG
 pub const AGG_TYPE_MINMAX: u8 = 0b01; // MIN/MAX (BTree ordering gives both)
 
+// Serialization type codes for aggregate functions
+const AGG_FUNC_COUNT: i64 = 0;
+const AGG_FUNC_SUM: i64 = 1;
+const AGG_FUNC_AVG: i64 = 2;
+const AGG_FUNC_MIN: i64 = 3;
+const AGG_FUNC_MAX: i64 = 4;
+
 #[derive(Debug, Clone, PartialEq)]
 pub enum AggregateFunction {
     Count,
@@ -44,6 +51,104 @@ impl AggregateFunction {
         self.to_string()
     }
 
+    /// Serialize this aggregate function to a sequence of Values.
+    /// Returns a vector of values: [type_code, optional_column_index]
+    pub fn to_values(&self) -> Vec<Value> {
+        match self {
+            AggregateFunction::Count => vec![Value::Integer(AGG_FUNC_COUNT)],
+            AggregateFunction::Sum(idx) => {
+                vec![Value::Integer(AGG_FUNC_SUM), Value::Integer(*idx as i64)]
+            }
+            AggregateFunction::Avg(idx) => {
+                vec![Value::Integer(AGG_FUNC_AVG), Value::Integer(*idx as i64)]
+            }
+            AggregateFunction::Min(idx) => {
+                vec![Value::Integer(AGG_FUNC_MIN), Value::Integer(*idx as i64)]
+            }
+            AggregateFunction::Max(idx) => {
+                vec![Value::Integer(AGG_FUNC_MAX), Value::Integer(*idx as i64)]
+            }
+        }
+    }
+
+    /// Deserialize an aggregate function from values.
+    /// Consumes values from the cursor and returns the aggregate function.
+    pub fn from_values(values: &[Value], cursor: &mut usize) -> Result<Self> {
+        let type_code = values
+            .get(*cursor)
+            .ok_or_else(|| LimboError::InternalError("Missing aggregate type code".into()))?;
+
+        let agg_fn = match type_code {
+            Value::Integer(AGG_FUNC_COUNT) => {
+                *cursor += 1;
+                AggregateFunction::Count
+            }
+            Value::Integer(AGG_FUNC_SUM) => {
+                *cursor += 1;
+                let idx = values
+                    .get(*cursor)
+                    .ok_or_else(|| LimboError::InternalError("Missing SUM column index".into()))?;
+                if let Value::Integer(idx) = idx {
+                    *cursor += 1;
+                    AggregateFunction::Sum(*idx as usize)
+                } else {
+                    return Err(LimboError::InternalError(format!(
+                        "Expected Integer for SUM column index, got {idx:?}"
+                    )));
+                }
+            }
+            Value::Integer(AGG_FUNC_AVG) => {
+                *cursor += 1;
+                let idx = values
+                    .get(*cursor)
+                    .ok_or_else(|| LimboError::InternalError("Missing AVG column index".into()))?;
+                if let Value::Integer(idx) = idx {
+                    *cursor += 1;
+                    AggregateFunction::Avg(*idx as usize)
+                } else {
+                    return Err(LimboError::InternalError(format!(
+                        "Expected Integer for AVG column index, got {idx:?}"
+                    )));
+                }
+            }
+            Value::Integer(AGG_FUNC_MIN) => {
+                *cursor += 1;
+                let idx = values
+                    .get(*cursor)
+                    .ok_or_else(|| LimboError::InternalError("Missing MIN column index".into()))?;
+                if let Value::Integer(idx) = idx {
+                    *cursor += 1;
+                    AggregateFunction::Min(*idx as usize)
+                } else {
+                    return Err(LimboError::InternalError(format!(
+                        "Expected Integer for MIN column index, got {idx:?}"
+                    )));
+                }
+            }
+            Value::Integer(AGG_FUNC_MAX) => {
+                *cursor += 1;
+                let idx = values
+                    .get(*cursor)
+                    .ok_or_else(|| LimboError::InternalError("Missing MAX column index".into()))?;
+                if let Value::Integer(idx) = idx {
+                    *cursor += 1;
+                    AggregateFunction::Max(*idx as usize)
+                } else {
+                    return Err(LimboError::InternalError(format!(
+                        "Expected Integer for MAX column index, got {idx:?}"
+                    )));
+                }
+            }
+            _ => {
+                return Err(LimboError::InternalError(format!(
+                    "Unknown aggregate type code: {type_code:?}"
+                )))
+            }
+        };
+
+        Ok(agg_fn)
+    }
+
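A minimal round-trip sketch of the encoding above, using only the names introduced in this hunk (the `?` assumes a `Result` context):

```rust
// SUM over column 2 encodes as [type_code, column_index].
let agg = AggregateFunction::Sum(2);
let encoded = agg.to_values(); // [Value::Integer(AGG_FUNC_SUM), Value::Integer(2)]

// from_values consumes exactly the slots it wrote and advances the cursor.
let mut cursor = 0;
let decoded = AggregateFunction::from_values(&encoded, &mut cursor)?;
assert_eq!(decoded, AggregateFunction::Sum(2));
assert_eq!(cursor, 2);
```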
     /// Create an AggregateFunction from a SQL function and its arguments
     /// Returns None if the function is not a supported aggregate
     pub fn from_sql_function(
@@ -77,42 +182,6 @@ pub struct AggColumnInfo {
     pub has_max: bool,
 }
 
-/// Serialize a Value using SQLite's serial type format
-/// This is used for MIN/MAX values that need to be stored in a compact, sortable format
-pub fn serialize_value(value: &Value, blob: &mut Vec<u8>) {
-    let serial_type = crate::types::SerialType::from(value);
-    let serial_type_u64: u64 = serial_type.into();
-    crate::storage::sqlite3_ondisk::write_varint_to_vec(serial_type_u64, blob);
-    value.serialize_serial(blob);
-}
-
-/// Deserialize a Value using SQLite's serial type format
-/// Returns the deserialized value and the number of bytes consumed
-pub fn deserialize_value(blob: &[u8]) -> Option<(Value, usize)> {
-    let mut cursor = 0;
-
-    // Read the serial type
-    let (serial_type, varint_size) = crate::storage::sqlite3_ondisk::read_varint(blob).ok()?;
-    cursor += varint_size;
-
-    let serial_type_obj = crate::types::SerialType::try_from(serial_type).ok()?;
-    let expected_size = serial_type_obj.size();
-
-    // Read the value
-    let (value, actual_size) =
-        crate::storage::sqlite3_ondisk::read_value(&blob[cursor..], serial_type_obj).ok()?;
-
-    // Verify that the actual size matches what we expected from the serial type
-    if actual_size != expected_size {
-        return None; // Data corruption - size mismatch
-    }
-
-    cursor += actual_size;
-
-    // Convert RefValue to Value
-    Some((value.to_owned(), cursor))
-}
-
 // group_key_str -> (group_key, state)
 type ComputedStates = HashMap<String, (Vec<Value>, AggregateState)>;
 // group_key_str -> (column_index, value_as_hashable_row) -> accumulated_weight
@@ -198,9 +267,9 @@ pub struct AggregateState {
     // For COUNT: just the count
     pub count: i64,
     // For SUM: column_index -> sum value
-    sums: HashMap<usize, f64>,
+    pub sums: HashMap<usize, f64>,
     // For AVG: column_index -> (sum, count) for computing average
-    avgs: HashMap<usize, (f64, i64)>,
+    pub avgs: HashMap<usize, (f64, i64)>,
     // For MIN: column_index -> minimum value
     pub mins: HashMap<usize, Value>,
     // For MAX: column_index -> maximum value
@@ -306,11 +375,9 @@ impl AggregateEvalState {
                     // Only try to read if we have a rowid
                     if let Some(rowid) = rowid {
                         let key = SeekKey::TableRowId(*rowid);
-                        let state = return_if_io!(read_record_state.read_record(
-                            key,
-                            &operator.aggregates,
-                            &mut cursors.table_cursor
-                        ));
+                        let state = return_if_io!(
+                            read_record_state.read_record(key, &mut cursors.table_cursor)
+                        );
                         // Process the fetched state
                         if let Some(state) = state {
                             let mut old_row = group_key.clone();
@@ -368,196 +435,249 @@ impl AggregateState {
         Self::default()
     }
 
-    // Serialize the aggregate state to a binary blob including group key values
-    // The reason we serialize it like this, instead of just writing the actual values, is that
-    // The same table may have different aggregators in the circuit. They will all have different
-    // columns.
-    fn to_blob(&self, aggregates: &[AggregateFunction], group_key: &[Value]) -> Vec<u8> {
-        let mut blob = Vec::new();
+    /// Convert the aggregate state to a vector of Values for unified serialization
+    /// Format: [count, num_aggregates, (agg_metadata, agg_state)...]
+    /// Each aggregate includes its type and column index for proper deserialization
+    pub fn to_value_vector(&self, aggregates: &[AggregateFunction]) -> Vec<Value> {
+        let mut values = Vec::new();
 
-        // Write version byte for future compatibility
-        blob.push(1u8);
+        // Include count first
+        values.push(Value::Integer(self.count));
 
-        // Write number of group key values
-        blob.extend_from_slice(&(group_key.len() as u32).to_le_bytes());
+        // Store number of aggregates
+        values.push(Value::Integer(aggregates.len() as i64));
 
-        // Write each group key value
-        for value in group_key {
-            // Write value type tag
-            match value {
-                Value::Null => blob.push(0u8),
-                Value::Integer(i) => {
-                    blob.push(1u8);
-                    blob.extend_from_slice(&i.to_le_bytes());
-                }
-                Value::Float(f) => {
-                    blob.push(2u8);
-                    blob.extend_from_slice(&f.to_le_bytes());
-                }
-                Value::Text(s) => {
-                    blob.push(3u8);
-                    let text_str = s.as_str();
-                    let bytes = text_str.as_bytes();
-                    blob.extend_from_slice(&(bytes.len() as u32).to_le_bytes());
-                    blob.extend_from_slice(bytes);
-                }
-                Value::Blob(b) => {
-                    blob.push(4u8);
-                    blob.extend_from_slice(&(b.len() as u32).to_le_bytes());
-                    blob.extend_from_slice(b);
-                }
-            }
-        }
-
-        // Write count as 8 bytes (little-endian)
-        blob.extend_from_slice(&self.count.to_le_bytes());
-
-        // Write each aggregate's state
+        // Add each aggregate's metadata and state
         for agg in aggregates {
+            // First, add the aggregate function metadata (type and column index)
+            values.extend(agg.to_values());
+
+            // Then add the state for this aggregate
             match agg {
-                AggregateFunction::Sum(col_name) => {
-                    let sum = self.sums.get(col_name).copied().unwrap_or(0.0);
-                    blob.extend_from_slice(&sum.to_le_bytes());
-                }
-                AggregateFunction::Avg(col_name) => {
-                    let (sum, count) = self.avgs.get(col_name).copied().unwrap_or((0.0, 0));
-                    blob.extend_from_slice(&sum.to_le_bytes());
-                    blob.extend_from_slice(&count.to_le_bytes());
-                }
                 AggregateFunction::Count => {
-                    // Count is already written above
+                    // Count state is already stored at the beginning
                 }
-                AggregateFunction::Min(col_name) => {
-                    // Write whether we have a MIN value (1 byte)
-                    if let Some(min_val) = self.mins.get(col_name) {
-                        blob.push(1u8); // Has value
-                        serialize_value(min_val, &mut blob);
+                AggregateFunction::Sum(col_idx) => {
+                    let sum = self.sums.get(col_idx).copied().unwrap_or(0.0);
+                    values.push(Value::Float(sum));
+                }
+                AggregateFunction::Avg(col_idx) => {
+                    let (sum, count) = self.avgs.get(col_idx).copied().unwrap_or((0.0, 0));
+                    values.push(Value::Float(sum));
+                    values.push(Value::Integer(count));
+                }
+                AggregateFunction::Min(col_idx) => {
+                    if let Some(min_val) = self.mins.get(col_idx) {
+                        values.push(Value::Integer(1)); // Has value
+                        values.push(min_val.clone());
                     } else {
-                        blob.push(0u8); // No value
+                        values.push(Value::Integer(0)); // No value
                     }
                 }
-                AggregateFunction::Max(col_name) => {
-                    // Write whether we have a MAX value (1 byte)
-                    if let Some(max_val) = self.maxs.get(col_name) {
-                        blob.push(1u8); // Has value
-                        serialize_value(max_val, &mut blob);
+                AggregateFunction::Max(col_idx) => {
+                    if let Some(max_val) = self.maxs.get(col_idx) {
+                        values.push(Value::Integer(1)); // Has value
+                        values.push(max_val.clone());
                     } else {
-                        blob.push(0u8); // No value
+                        values.push(Value::Integer(0)); // No value
                     }
                 }
             }
         }
 
-        blob
+        values
     }
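To make the `[count, num_aggregates, (agg_metadata, agg_state)...]` layout concrete, a state with count = 2, SUM(col 1) = 25.0, and MIN(col 3) = 3 would flatten as follows (a sketch derived from the code above, not output from a real run):

```rust
let expected = vec![
    Value::Integer(2),            // self.count
    Value::Integer(2),            // number of aggregates
    Value::Integer(AGG_FUNC_SUM), // SUM metadata: type code...
    Value::Integer(1),            // ...and column index
    Value::Float(25.0),           // SUM state
    Value::Integer(AGG_FUNC_MIN), // MIN metadata: type code...
    Value::Integer(3),            // ...and column index
    Value::Integer(1),            // MIN has a value
    Value::Integer(3),            // MIN state
];
```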
 
-    /// Deserialize aggregate state from a binary blob
-    /// Returns the aggregate state and the group key values
-    pub fn from_blob(blob: &[u8], aggregates: &[AggregateFunction]) -> Option<(Self, Vec<Value>)> {
+    /// Reconstruct aggregate state from a vector of Values
+    pub fn from_value_vector(values: &[Value]) -> Result<Self> {
         let mut cursor = 0;
-
-        // Check version byte
-        if blob.get(cursor) != Some(&1u8) {
-            return None;
-        }
-        cursor += 1;
-
-        // Read number of group key values
-        let num_group_keys =
-            u32::from_le_bytes(blob.get(cursor..cursor + 4)?.try_into().ok()?) as usize;
-        cursor += 4;
-
-        // Read group key values
-        let mut group_key = Vec::new();
-        for _ in 0..num_group_keys {
-            let value_type = *blob.get(cursor)?;
-            cursor += 1;
-
-            let value = match value_type {
-                0 => Value::Null,
-                1 => {
-                    let i = i64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?);
-                    cursor += 8;
-                    Value::Integer(i)
-                }
-                2 => {
-                    let f = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?);
-                    cursor += 8;
-                    Value::Float(f)
-                }
-                3 => {
-                    let len =
-                        u32::from_le_bytes(blob.get(cursor..cursor + 4)?.try_into().ok()?) as usize;
-                    cursor += 4;
-                    let bytes = blob.get(cursor..cursor + len)?;
-                    cursor += len;
-                    let text_str = std::str::from_utf8(bytes).ok()?;
-                    Value::Text(text_str.to_string().into())
-                }
-                4 => {
-                    let len =
-                        u32::from_le_bytes(blob.get(cursor..cursor + 4)?.try_into().ok()?) as usize;
-                    cursor += 4;
-                    let bytes = blob.get(cursor..cursor + len)?;
-                    cursor += len;
-                    Value::Blob(bytes.to_vec())
-                }
-                _ => return None,
-            };
-            group_key.push(value);
-        }
+        let mut state = Self::new();
 
         // Read count
-        let count = i64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?);
-        cursor += 8;
+        let count = values
+            .get(cursor)
+            .ok_or_else(|| LimboError::InternalError("Aggregate state missing count".into()))?;
+        if let Value::Integer(count) = count {
+            state.count = *count;
+            cursor += 1;
+        } else {
+            return Err(LimboError::InternalError(format!(
+                "Expected Integer for count, got {count:?}"
+            )));
+        }
 
-        let mut state = Self::new();
-        state.count = count;
+        // Read number of aggregates
+        let num_aggregates = values
+            .get(cursor)
+            .ok_or_else(|| LimboError::InternalError("Missing number of aggregates".into()))?;
+        let num_aggregates = match num_aggregates {
+            Value::Integer(n) => *n as usize,
+            _ => {
+                return Err(LimboError::InternalError(format!(
+                    "Expected Integer for aggregate count, got {num_aggregates:?}"
+                )))
+            }
+        };
+        cursor += 1;
 
-        // Read each aggregate's state
-        for agg in aggregates {
-            match agg {
-                AggregateFunction::Sum(col_name) => {
-                    let sum = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?);
-                    cursor += 8;
-                    state.sums.insert(*col_name, sum);
-                }
-                AggregateFunction::Avg(col_name) => {
-                    let sum = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?);
-                    cursor += 8;
-                    let count = i64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?);
-                    cursor += 8;
-                    state.avgs.insert(*col_name, (sum, count));
-                }
+        // Read each aggregate's state with type and column index
+        for _ in 0..num_aggregates {
+            // Deserialize the aggregate function metadata
+            let agg_fn = AggregateFunction::from_values(values, &mut cursor)?;
+
+            // Read the state for this aggregate
+            match agg_fn {
                 AggregateFunction::Count => {
-                    // Count was already read above
+                    // Count state is already stored at the beginning
                 }
-                AggregateFunction::Min(col_name) => {
-                    // Read whether we have a MIN value
-                    let has_value = *blob.get(cursor)?;
-                    cursor += 1;
-
-                    if has_value == 1 {
-                        let (min_value, bytes_consumed) = deserialize_value(&blob[cursor..])?;
-                        cursor += bytes_consumed;
-                        state.mins.insert(*col_name, min_value);
+                AggregateFunction::Sum(col_idx) => {
+                    let sum = values
+                        .get(cursor)
+                        .ok_or_else(|| LimboError::InternalError("Missing SUM value".into()))?;
+                    if let Value::Float(sum) = sum {
+                        state.sums.insert(col_idx, *sum);
+                        cursor += 1;
+                    } else {
+                        return Err(LimboError::InternalError(format!(
+                            "Expected Float for SUM value, got {sum:?}"
+                        )));
                     }
                 }
-                AggregateFunction::Max(col_name) => {
-                    // Read whether we have a MAX value
-                    let has_value = *blob.get(cursor)?;
+                AggregateFunction::Avg(col_idx) => {
+                    let sum = values
+                        .get(cursor)
+                        .ok_or_else(|| LimboError::InternalError("Missing AVG sum value".into()))?;
+                    let sum = match sum {
+                        Value::Float(f) => *f,
+                        _ => {
+                            return Err(LimboError::InternalError(format!(
+                                "Expected Float for AVG sum, got {sum:?}"
+                            )))
+                        }
+                    };
                     cursor += 1;
 
-                    if has_value == 1 {
-                        let (max_value, bytes_consumed) = deserialize_value(&blob[cursor..])?;
-                        cursor += bytes_consumed;
-                        state.maxs.insert(*col_name, max_value);
+                    let count = values.get(cursor).ok_or_else(|| {
+                        LimboError::InternalError("Missing AVG count value".into())
+                    })?;
+                    let count = match count {
+                        Value::Integer(i) => *i,
+                        _ => {
+                            return Err(LimboError::InternalError(format!(
+                                "Expected Integer for AVG count, got {count:?}"
+                            )))
+                        }
+                    };
+                    cursor += 1;
+
+                    state.avgs.insert(col_idx, (sum, count));
+                }
+                AggregateFunction::Min(col_idx) => {
+                    let has_value = values.get(cursor).ok_or_else(|| {
+                        LimboError::InternalError("Missing MIN has_value flag".into())
+                    })?;
+                    if let Value::Integer(has_value) = has_value {
+                        cursor += 1;
+                        if *has_value == 1 {
+                            let min_val = values
+                                .get(cursor)
+                                .ok_or_else(|| {
+                                    LimboError::InternalError("Missing MIN value".into())
+                                })?
+                                .clone();
+                            cursor += 1;
+                            state.mins.insert(col_idx, min_val);
+                        }
+                    } else {
+                        return Err(LimboError::InternalError(format!(
+                            "Expected Integer for MIN has_value flag, got {has_value:?}"
+                        )));
+                    }
+                }
+                AggregateFunction::Max(col_idx) => {
+                    let has_value = values.get(cursor).ok_or_else(|| {
+                        LimboError::InternalError("Missing MAX has_value flag".into())
+                    })?;
+                    if let Value::Integer(has_value) = has_value {
+                        cursor += 1;
+                        if *has_value == 1 {
+                            let max_val = values
+                                .get(cursor)
+                                .ok_or_else(|| {
+                                    LimboError::InternalError("Missing MAX value".into())
+                                })?
+                                .clone();
+                            cursor += 1;
+                            state.maxs.insert(col_idx, max_val);
+                        }
+                    } else {
+                        return Err(LimboError::InternalError(format!(
+                            "Expected Integer for MAX has_value flag, got {has_value:?}"
+                        )));
                     }
                 }
             }
         }
 
-        Some((state, group_key))
+        Ok(state)
+    }
+
+    fn to_blob(&self, aggregates: &[AggregateFunction], group_key: &[Value]) -> Vec<u8> {
+        let mut all_values = Vec::new();
+        // Store the group key size first
+        all_values.push(Value::Integer(group_key.len() as i64));
+        all_values.extend_from_slice(group_key);
+        all_values.extend(self.to_value_vector(aggregates));
+
+        let record = ImmutableRecord::from_values(&all_values, all_values.len());
+        record.as_blob().clone()
+    }
+
+    pub fn from_blob(blob: &[u8]) -> Result<(Self, Vec<Value>)> {
+        let record = ImmutableRecord::from_bin_record(blob.to_vec());
+        let ref_values = record.get_values();
+        let mut all_values: Vec<Value> = ref_values.into_iter().map(|rv| rv.to_owned()).collect();
+
+        if all_values.is_empty() {
+            return Err(LimboError::InternalError(
+                "Aggregate state blob is empty".into(),
+            ));
+        }
+
+        // Read the group key size
+        let group_key_count = match &all_values[0] {
+            Value::Integer(n) if *n >= 0 => *n as usize,
+            Value::Integer(n) => {
+                return Err(LimboError::InternalError(format!(
+                    "Negative group key count: {n}"
+                )))
+            }
+            other => {
+                return Err(LimboError::InternalError(format!(
+                    "Expected Integer for group key count, got {other:?}"
+                )))
+            }
+        };
+
+        // Remove the group key count from the values
+        all_values.remove(0);
+
+        if all_values.len() < group_key_count {
+            return Err(LimboError::InternalError(format!(
+                "Blob too short: expected at least {} values for group key, got {}",
+                group_key_count,
+                all_values.len()
+            )));
+        }
+
+        // Split into group key and state values
+        let group_key = all_values[..group_key_count].to_vec();
+        let state_values = &all_values[group_key_count..];
+
+        // Reconstruct the aggregate state
+        let state = Self::from_value_vector(state_values)?;
+
+        Ok((state, group_key))
     }
 
     /// Apply a delta to this aggregate state
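A round-trip through the new blob format, sketched under the assumption of an in-module caller with a populated `state` (note `to_blob` stays private while `from_blob` is `pub`):

```rust
let aggregates = vec![AggregateFunction::Count, AggregateFunction::Sum(1)];
let group_key = vec![Value::Text("A".into())];

// Record layout: [group_key_len, group_key..., count, num_aggregates, ...]
let blob = state.to_blob(&aggregates, &group_key);

// No aggregate list needed on the way back: the metadata travels in the blob.
let (restored, restored_key) = AggregateState::from_blob(&blob)?;
assert_eq!(restored_key, group_key);
assert_eq!(restored.count, state.count);
```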
diff --git a/core/incremental/join_operator.rs b/core/incremental/join_operator.rs
index f5ffb9b55..d0b799a8d 100644
--- a/core/incremental/join_operator.rs
+++ b/core/incremental/join_operator.rs
@@ -518,124 +518,44 @@ impl JoinOperator {
     }
 }
 
-// Helper to deserialize a HashableRow from a blob
 fn deserialize_hashable_row(blob: &[u8]) -> Result<HashableRow> {
-    // Simple deserialization - this needs to match how we serialize in commit
-    // Format: [rowid:8 bytes][num_values:4 bytes][values...]
-    if blob.len() < 12 {
+    use crate::types::ImmutableRecord;
+
+    let record = ImmutableRecord::from_bin_record(blob.to_vec());
+    let ref_values = record.get_values();
+    let all_values: Vec<Value> = ref_values.into_iter().map(|rv| rv.to_owned()).collect();
+
+    if all_values.is_empty() {
         return Err(crate::LimboError::InternalError(
-            "Invalid blob size".to_string(),
+            "HashableRow blob must contain at least rowid".to_string(),
         ));
     }
 
-    let rowid = i64::from_le_bytes(blob[0..8].try_into().unwrap());
-    let num_values = u32::from_le_bytes(blob[8..12].try_into().unwrap()) as usize;
-
-    let mut values = Vec::new();
-    let mut offset = 12;
-
-    for _ in 0..num_values {
-        if offset >= blob.len() {
-            break;
+    // First value is the rowid
+    let rowid = match &all_values[0] {
+        Value::Integer(i) => *i,
+        _ => {
+            return Err(crate::LimboError::InternalError(
+                "First value must be rowid (integer)".to_string(),
+            ))
         }
+    };
 
-        let type_tag = blob[offset];
-        offset += 1;
-
-        match type_tag {
-            0 => values.push(Value::Null),
-            1 => {
-                if offset + 8 <= blob.len() {
-                    let i = i64::from_le_bytes(blob[offset..offset + 8].try_into().unwrap());
-                    values.push(Value::Integer(i));
-                    offset += 8;
-                }
-            }
-            2 => {
-                if offset + 8 <= blob.len() {
-                    let f = f64::from_le_bytes(blob[offset..offset + 8].try_into().unwrap());
-                    values.push(Value::Float(f));
-                    offset += 8;
-                }
-            }
-            3 => {
-                if offset + 4 <= blob.len() {
-                    let len =
-                        u32::from_le_bytes(blob[offset..offset + 4].try_into().unwrap()) as usize;
-                    offset += 4;
-                    if offset + len < blob.len() {
-                        let text_bytes = blob[offset..offset + len].to_vec();
-                        offset += len;
-                        let subtype = match blob[offset] {
-                            0 => crate::types::TextSubtype::Text,
-                            1 => crate::types::TextSubtype::Json,
-                            _ => crate::types::TextSubtype::Text,
-                        };
-                        offset += 1;
-                        values.push(Value::Text(crate::types::Text {
-                            value: text_bytes,
-                            subtype,
-                        }));
-                    }
-                }
-            }
-            4 => {
-                if offset + 4 <= blob.len() {
-                    let len =
-                        u32::from_le_bytes(blob[offset..offset + 4].try_into().unwrap()) as usize;
-                    offset += 4;
-                    if offset + len <= blob.len() {
-                        let blob_data = blob[offset..offset + len].to_vec();
-                        values.push(Value::Blob(blob_data));
-                        offset += len;
-                    }
-                }
-            }
-            _ => break, // Unknown type tag
-        }
-    }
+    // Rest are the row values
+    let values = all_values[1..].to_vec();
 
     Ok(HashableRow::new(rowid, values))
 }
 
-// Helper to serialize a HashableRow to a blob
 fn serialize_hashable_row(row: &HashableRow) -> Vec<u8> {
-    let mut blob = Vec::new();
+    use crate::types::ImmutableRecord;
 
-    // Write rowid
-    blob.extend_from_slice(&row.rowid.to_le_bytes());
+    let mut all_values = Vec::with_capacity(row.values.len() + 1);
+    all_values.push(Value::Integer(row.rowid));
+    all_values.extend_from_slice(&row.values);
 
-    // Write number of values
-    blob.extend_from_slice(&(row.values.len() as u32).to_le_bytes());
-
-    // Write each value directly with type tags (like AggregateState does)
-    for value in &row.values {
-        match value {
-            Value::Null => blob.push(0u8),
-            Value::Integer(i) => {
-                blob.push(1u8);
-                blob.extend_from_slice(&i.to_le_bytes());
-            }
-            Value::Float(f) => {
-                blob.push(2u8);
-                blob.extend_from_slice(&f.to_le_bytes());
-            }
-            Value::Text(s) => {
-                blob.push(3u8);
-                let bytes = &s.value;
-                blob.extend_from_slice(&(bytes.len() as u32).to_le_bytes());
-                blob.extend_from_slice(bytes);
-                blob.push(s.subtype as u8);
-            }
-            Value::Blob(b) => {
-                blob.push(4u8);
-                blob.extend_from_slice(&(b.len() as u32).to_le_bytes());
-                blob.extend_from_slice(b);
-            }
-        }
-    }
-
-    blob
+    let record = ImmutableRecord::from_values(&all_values, all_values.len());
+    record.as_blob().clone()
 }
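The join-side helpers now share the same `ImmutableRecord` framing; a round-trip sketch:

```rust
// The rowid rides along as the first record value, so no hand-rolled
// length-prefixed framing (or TextSubtype byte) is needed anymore.
let row = HashableRow::new(42, vec![Value::Integer(7), Value::Text("x".into())]);
let blob = serialize_hashable_row(&row);
let decoded = deserialize_hashable_row(&blob)?;
assert_eq!(decoded.rowid, 42);
assert_eq!(decoded.values, row.values);
```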
 
 impl IncrementalOperator for JoinOperator {
diff --git a/core/incremental/operator.rs b/core/incremental/operator.rs
index 1b3c5a487..0dec705f9 100644
--- a/core/incremental/operator.rs
+++ b/core/incremental/operator.rs
@@ -332,20 +332,25 @@ mod tests {
                 // Get the blob data from column 3 (value column)
                 if let Some(Value::Blob(blob)) = values.get(3) {
                     // Deserialize the state
-                    if let Some((state, group_key)) =
-                        AggregateState::from_blob(blob, &agg.aggregates)
-                    {
-                        // Should not have made it this far.
-                        assert!(state.count != 0);
-                        // Build output row: group_by columns + aggregate values
-                        let mut output_values = group_key.clone();
-                        output_values.extend(state.to_values(&agg.aggregates));
+                    match AggregateState::from_blob(blob) {
+                        Ok((state, group_key)) => {
+                            // A state with zero count should not have made it this far.
+                            assert!(state.count != 0);
+                            // Build output row: group_by columns + aggregate values
+                            let mut output_values = group_key.clone();
+                            output_values.extend(state.to_values(&agg.aggregates));
 
-                        let group_key_str = AggregateOperator::group_key_to_string(&group_key);
-                        let rowid = agg.generate_group_rowid(&group_key_str);
+                            let group_key_str = AggregateOperator::group_key_to_string(&group_key);
+                            let rowid = agg.generate_group_rowid(&group_key_str);
 
-                        let output_row = HashableRow::new(rowid, output_values);
-                        result.changes.push((output_row, 1));
+                            let output_row = HashableRow::new(rowid, output_values);
+                            result.changes.push((output_row, 1));
+                        }
+                        Err(e) => {
+                            // Log the deserialization error and skip this entry
+                            eprintln!("Failed to deserialize aggregate state: {e}");
+                        }
                     }
                 }
@@ -4011,4 +4016,115 @@ mod tests {
         panic!("Expected Done result");
     }
 }
+
+    #[test]
+    fn test_aggregate_serialization_with_different_column_indices() {
+        // Test that aggregate state serialization correctly preserves column indices
+        // when multiple aggregates operate on different columns
+        let (pager, table_root_page_id, index_root_page_id) = create_test_pager();
+        let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root_page_id, 5);
+        let index_def = create_dbsp_state_index(index_root_page_id);
+        let index_cursor =
+            BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4);
+        let mut cursors = DbspStateCursors::new(table_cursor, index_cursor);
+
+        // Create first operator with SUM(col1), MIN(col3) GROUP BY col0
+        let mut agg1 = AggregateOperator::new(
+            1,
+            vec![0],
+            vec![AggregateFunction::Sum(1), AggregateFunction::Min(3)],
+            vec![
+                "group".to_string(),
+                "val1".to_string(),
+                "val2".to_string(),
+                "val3".to_string(),
+            ],
+        );
+
+        // Add initial data
+        let mut delta = Delta::new();
+        delta.insert(
+            1,
+            vec![
+                Value::Text("A".into()),
+                Value::Integer(10),
+                Value::Integer(100),
+                Value::Integer(5),
+            ],
+        );
+        delta.insert(
+            2,
+            vec![
+                Value::Text("A".into()),
+                Value::Integer(15),
+                Value::Integer(200),
+                Value::Integer(3),
+            ],
+        );
+
+        let result1 = pager
+            .io
+            .block(|| agg1.commit((&delta).into(), &mut cursors))
+            .unwrap();
+
+        assert_eq!(result1.changes.len(), 1);
+        let (row1, _) = &result1.changes[0];
+        assert_eq!(row1.values[0], Value::Text("A".into()));
+        assert_eq!(row1.values[1], Value::Integer(25)); // SUM(val1) = 10 + 15
+        assert_eq!(row1.values[2], Value::Integer(3)); // MIN(val3) = min(5, 3)
+
+        // Create operator with same ID but different column mappings: SUM(col3), MIN(col1)
+        let mut agg2 = AggregateOperator::new(
+            1, // Same operator_id
+            vec![0],
+            vec![AggregateFunction::Sum(3), AggregateFunction::Min(1)],
+            vec![
+                "group".to_string(),
+                "val1".to_string(),
+                "val2".to_string(),
+                "val3".to_string(),
+            ],
+        );
+
+        // Process new data
+        let mut delta2 = Delta::new();
+        delta2.insert(
+            3,
+            vec![
+                Value::Text("A".into()),
+                Value::Integer(20),
+                Value::Integer(300),
+                Value::Integer(4),
+            ],
+        );
+
+        let result2 = pager
+            .io
+            .block(|| agg2.commit((&delta2).into(), &mut cursors))
+            .unwrap();
+
+        // Find the positive weight row for group A (the updated aggregate)
+        let row2 = result2
+            .changes
+            .iter()
+            .find(|(row, weight)| row.values[0] == Value::Text("A".into()) && *weight > 0)
+            .expect("Should have a positive weight row for group A");
+        let (row2, _) = row2;
+
+        // Verify that column indices are preserved correctly in serialization
+        // When agg2 processes the data with different column mappings:
+        // - It reads the existing state which has SUM(col1)=25 and MIN(col3)=3
+        // - For SUM(col3), there's no existing state, so it starts fresh: 4
+        // - For MIN(col1), there's no existing state, so it starts fresh: 20
+        assert_eq!(
+            row2.values[1],
+            Value::Integer(4),
+            "SUM(col3) should be 4 (new data only)"
+        );
+        assert_eq!(
+            row2.values[2],
+            Value::Integer(20),
+            "MIN(col1) should be 20 (new data only)"
+        );
+    }
 }
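The property this test pins down, restated as a sketch (assuming `blob` holds the state agg1 committed): decoding no longer trusts the caller's aggregate list, so a differently-configured operator cannot mis-assign persisted values.

```rust
// The per-aggregate metadata in the blob keys each value by its original column.
let (state, _group_key) = AggregateState::from_blob(&blob)?;
assert_eq!(state.sums.get(&1), Some(&25.0));              // SUM was over column 1
assert_eq!(state.mins.get(&3), Some(&Value::Integer(3))); // MIN was over column 3
```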
diff --git a/core/incremental/persistence.rs b/core/incremental/persistence.rs
index 5cf41b94a..81d0837c2 100644
--- a/core/incremental/persistence.rs
+++ b/core/incremental/persistence.rs
@@ -1,4 +1,4 @@
-use crate::incremental::operator::{AggregateFunction, AggregateState, DbspStateCursors};
+use crate::incremental::operator::{AggregateState, DbspStateCursors};
 use crate::storage::btree::{BTreeCursor, BTreeKey};
 use crate::types::{IOResult, ImmutableRecord, SeekKey, SeekOp, SeekResult};
 use crate::{return_if_io, LimboError, Result, Value};
@@ -20,7 +20,6 @@ impl ReadRecord {
     pub fn read_record(
         &mut self,
         key: SeekKey,
-        aggregates: &[AggregateFunction],
         cursor: &mut BTreeCursor,
     ) -> Result<IOResult<Option<AggregateState>>> {
         loop {
@@ -41,12 +40,7 @@
                     let blob = values[3].to_owned();
 
                     let (state, _group_key) = match blob {
-                        Value::Blob(blob) => AggregateState::from_blob(&blob, aggregates)
-                            .ok_or_else(|| {
-                                LimboError::InternalError(format!(
-                                    "Cannot deserialize aggregate state {blob:?}",
-                                ))
-                            }),
+                        Value::Blob(blob) => AggregateState::from_blob(&blob),
                         _ => Err(LimboError::ParseError(
                             "Value in aggregator not blob".to_string(),
                         )),
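For context, the resulting call shape (as already visible in the aggregate_operator.rs hunk above): callers seek purely by key, and deserialization failures from `from_blob` now propagate as `Err(...)` instead of being collapsed into a generic `Option`-based message.

```rust
let state = return_if_io!(read_record_state.read_record(key, &mut cursors.table_cursor));
```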