Replace custom serialization with a saner version
The Materialized View code used hand-rolled serialization, written so that we could move the code forward. Now that we have many operators and the views work, replace it with something saner. The main insight is that if we transform the AggregateState into Values before serialization, we can just use the standard SQLite serialization for the values. We then only have to add sizes, type codes for the functions, etc., which are themselves represented as Values.
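In sketch form, the round trip now looks like this (a hypothetical helper assembled from names that appear in the diff below — to_value_vector, from_blob, ImmutableRecord — rather than code taken verbatim from the commit):

    // A sketch, not committed code: round-trips an AggregateState through the
    // new Value-based serialization.
    fn round_trip(
        state: &AggregateState,
        aggregates: &[AggregateFunction],
        group_key: &[Value],
    ) -> Result<()> {
        // Everything becomes plain Values first: group-key size, the group key
        // itself, then [count, num_aggregates, (agg_metadata, agg_state)...].
        let mut all_values = vec![Value::Integer(group_key.len() as i64)];
        all_values.extend_from_slice(group_key);
        all_values.extend(state.to_value_vector(aggregates));

        // The standard SQLite record serialization handles the byte-level work.
        let record = ImmutableRecord::from_values(&all_values, all_values.len());
        let blob = record.as_blob().clone();

        // And it decodes straight back, with no bespoke binary format involved.
        let (restored, restored_key) = AggregateState::from_blob(&blob)?;
        assert_eq!(restored_key, group_key);
        assert_eq!(restored.count, state.count);
        Ok(())
    }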
@@ -16,6 +16,13 @@ use std::sync::{Arc, Mutex};
 pub const AGG_TYPE_REGULAR: u8 = 0b00; // COUNT/SUM/AVG
 pub const AGG_TYPE_MINMAX: u8 = 0b01; // MIN/MAX (BTree ordering gives both)
 
+// Serialization type codes for aggregate functions
+const AGG_FUNC_COUNT: i64 = 0;
+const AGG_FUNC_SUM: i64 = 1;
+const AGG_FUNC_AVG: i64 = 2;
+const AGG_FUNC_MIN: i64 = 3;
+const AGG_FUNC_MAX: i64 = 4;
+
 #[derive(Debug, Clone, PartialEq)]
 pub enum AggregateFunction {
     Count,
@@ -44,6 +51,104 @@ impl AggregateFunction {
         self.to_string()
     }
 
+    /// Serialize this aggregate function to a Value
+    /// Returns a vector of values: [type_code, optional_column_index]
+    pub fn to_values(&self) -> Vec<Value> {
+        match self {
+            AggregateFunction::Count => vec![Value::Integer(AGG_FUNC_COUNT)],
+            AggregateFunction::Sum(idx) => {
+                vec![Value::Integer(AGG_FUNC_SUM), Value::Integer(*idx as i64)]
+            }
+            AggregateFunction::Avg(idx) => {
+                vec![Value::Integer(AGG_FUNC_AVG), Value::Integer(*idx as i64)]
+            }
+            AggregateFunction::Min(idx) => {
+                vec![Value::Integer(AGG_FUNC_MIN), Value::Integer(*idx as i64)]
+            }
+            AggregateFunction::Max(idx) => {
+                vec![Value::Integer(AGG_FUNC_MAX), Value::Integer(*idx as i64)]
+            }
+        }
+    }
+
+    /// Deserialize an aggregate function from values
+    /// Consumes values from the cursor and returns the aggregate function
+    pub fn from_values(values: &[Value], cursor: &mut usize) -> Result<Self> {
+        let type_code = values
+            .get(*cursor)
+            .ok_or_else(|| LimboError::InternalError("Missing aggregate type code".into()))?;
+
+        let agg_fn = match type_code {
+            Value::Integer(AGG_FUNC_COUNT) => {
+                *cursor += 1;
+                AggregateFunction::Count
+            }
+            Value::Integer(AGG_FUNC_SUM) => {
+                *cursor += 1;
+                let idx = values
+                    .get(*cursor)
+                    .ok_or_else(|| LimboError::InternalError("Missing SUM column index".into()))?;
+                if let Value::Integer(idx) = idx {
+                    *cursor += 1;
+                    AggregateFunction::Sum(*idx as usize)
+                } else {
+                    return Err(LimboError::InternalError(format!(
+                        "Expected Integer for SUM column index, got {idx:?}"
+                    )));
+                }
+            }
+            Value::Integer(AGG_FUNC_AVG) => {
+                *cursor += 1;
+                let idx = values
+                    .get(*cursor)
+                    .ok_or_else(|| LimboError::InternalError("Missing AVG column index".into()))?;
+                if let Value::Integer(idx) = idx {
+                    *cursor += 1;
+                    AggregateFunction::Avg(*idx as usize)
+                } else {
+                    return Err(LimboError::InternalError(format!(
+                        "Expected Integer for AVG column index, got {idx:?}"
+                    )));
+                }
+            }
+            Value::Integer(AGG_FUNC_MIN) => {
+                *cursor += 1;
+                let idx = values
+                    .get(*cursor)
+                    .ok_or_else(|| LimboError::InternalError("Missing MIN column index".into()))?;
+                if let Value::Integer(idx) = idx {
+                    *cursor += 1;
+                    AggregateFunction::Min(*idx as usize)
+                } else {
+                    return Err(LimboError::InternalError(format!(
+                        "Expected Integer for MIN column index, got {idx:?}"
+                    )));
+                }
+            }
+            Value::Integer(AGG_FUNC_MAX) => {
+                *cursor += 1;
+                let idx = values
+                    .get(*cursor)
+                    .ok_or_else(|| LimboError::InternalError("Missing MAX column index".into()))?;
+                if let Value::Integer(idx) = idx {
+                    *cursor += 1;
+                    AggregateFunction::Max(*idx as usize)
+                } else {
+                    return Err(LimboError::InternalError(format!(
+                        "Expected Integer for MAX column index, got {idx:?}"
+                    )));
+                }
+            }
+            _ => {
+                return Err(LimboError::InternalError(format!(
+                    "Unknown aggregate type code: {type_code:?}"
+                )))
+            }
+        };
+
+        Ok(agg_fn)
+    }
+
     /// Create an AggregateFunction from a SQL function and its arguments
     /// Returns None if the function is not a supported aggregate
     pub fn from_sql_function(
@@ -77,42 +182,6 @@ pub struct AggColumnInfo {
     pub has_max: bool,
 }
 
-/// Serialize a Value using SQLite's serial type format
-/// This is used for MIN/MAX values that need to be stored in a compact, sortable format
-pub fn serialize_value(value: &Value, blob: &mut Vec<u8>) {
-    let serial_type = crate::types::SerialType::from(value);
-    let serial_type_u64: u64 = serial_type.into();
-    crate::storage::sqlite3_ondisk::write_varint_to_vec(serial_type_u64, blob);
-    value.serialize_serial(blob);
-}
-
-/// Deserialize a Value using SQLite's serial type format
-/// Returns the deserialized value and the number of bytes consumed
-pub fn deserialize_value(blob: &[u8]) -> Option<(Value, usize)> {
-    let mut cursor = 0;
-
-    // Read the serial type
-    let (serial_type, varint_size) = crate::storage::sqlite3_ondisk::read_varint(blob).ok()?;
-    cursor += varint_size;
-
-    let serial_type_obj = crate::types::SerialType::try_from(serial_type).ok()?;
-    let expected_size = serial_type_obj.size();
-
-    // Read the value
-    let (value, actual_size) =
-        crate::storage::sqlite3_ondisk::read_value(&blob[cursor..], serial_type_obj).ok()?;
-
-    // Verify that the actual size matches what we expected from the serial type
-    if actual_size != expected_size {
-        return None; // Data corruption - size mismatch
-    }
-
-    cursor += actual_size;
-
-    // Convert RefValue to Value
-    Some((value.to_owned(), cursor))
-}
-
 // group_key_str -> (group_key, state)
 type ComputedStates = HashMap<String, (Vec<Value>, AggregateState)>;
 // group_key_str -> (column_index, value_as_hashable_row) -> accumulated_weight
@@ -198,9 +267,9 @@ pub struct AggregateState {
     // For COUNT: just the count
     pub count: i64,
     // For SUM: column_index -> sum value
-    sums: HashMap<usize, f64>,
+    pub sums: HashMap<usize, f64>,
     // For AVG: column_index -> (sum, count) for computing average
-    avgs: HashMap<usize, (f64, i64)>,
+    pub avgs: HashMap<usize, (f64, i64)>,
    // For MIN: column_index -> minimum value
    pub mins: HashMap<usize, Value>,
    // For MAX: column_index -> maximum value
@@ -306,11 +375,9 @@ impl AggregateEvalState {
                 // Only try to read if we have a rowid
                 if let Some(rowid) = rowid {
                     let key = SeekKey::TableRowId(*rowid);
-                    let state = return_if_io!(read_record_state.read_record(
-                        key,
-                        &operator.aggregates,
-                        &mut cursors.table_cursor
-                    ));
+                    let state = return_if_io!(
+                        read_record_state.read_record(key, &mut cursors.table_cursor)
+                    );
                     // Process the fetched state
                     if let Some(state) = state {
                         let mut old_row = group_key.clone();
@@ -368,196 +435,249 @@ impl AggregateState {
         Self::default()
     }
 
-    // Serialize the aggregate state to a binary blob including group key values
-    // The reason we serialize it like this, instead of just writing the actual values, is that
-    // The same table may have different aggregators in the circuit. They will all have different
-    // columns.
-    fn to_blob(&self, aggregates: &[AggregateFunction], group_key: &[Value]) -> Vec<u8> {
-        let mut blob = Vec::new();
+    /// Convert the aggregate state to a vector of Values for unified serialization
+    /// Format: [count, num_aggregates, (agg_metadata, agg_state)...]
+    /// Each aggregate includes its type and column index for proper deserialization
+    pub fn to_value_vector(&self, aggregates: &[AggregateFunction]) -> Vec<Value> {
+        let mut values = Vec::new();
 
-        // Write version byte for future compatibility
-        blob.push(1u8);
+        // Include count first
+        values.push(Value::Integer(self.count));
 
-        // Write number of group key values
-        blob.extend_from_slice(&(group_key.len() as u32).to_le_bytes());
+        // Store number of aggregates
+        values.push(Value::Integer(aggregates.len() as i64));
 
-        // Write each group key value
-        for value in group_key {
-            // Write value type tag
-            match value {
-                Value::Null => blob.push(0u8),
-                Value::Integer(i) => {
-                    blob.push(1u8);
-                    blob.extend_from_slice(&i.to_le_bytes());
-                }
-                Value::Float(f) => {
-                    blob.push(2u8);
-                    blob.extend_from_slice(&f.to_le_bytes());
-                }
-                Value::Text(s) => {
-                    blob.push(3u8);
-                    let text_str = s.as_str();
-                    let bytes = text_str.as_bytes();
-                    blob.extend_from_slice(&(bytes.len() as u32).to_le_bytes());
-                    blob.extend_from_slice(bytes);
-                }
-                Value::Blob(b) => {
-                    blob.push(4u8);
-                    blob.extend_from_slice(&(b.len() as u32).to_le_bytes());
-                    blob.extend_from_slice(b);
-                }
-            }
-        }
-
-        // Write count as 8 bytes (little-endian)
-        blob.extend_from_slice(&self.count.to_le_bytes());
-
-        // Write each aggregate's state
+        // Add each aggregate's metadata and state
         for agg in aggregates {
+            // First, add the aggregate function metadata (type and column index)
+            values.extend(agg.to_values());
+
+            // Then add the state for this aggregate
             match agg {
-                AggregateFunction::Sum(col_name) => {
-                    let sum = self.sums.get(col_name).copied().unwrap_or(0.0);
-                    blob.extend_from_slice(&sum.to_le_bytes());
-                }
-                AggregateFunction::Avg(col_name) => {
-                    let (sum, count) = self.avgs.get(col_name).copied().unwrap_or((0.0, 0));
-                    blob.extend_from_slice(&sum.to_le_bytes());
-                    blob.extend_from_slice(&count.to_le_bytes());
-                }
                 AggregateFunction::Count => {
-                    // Count is already written above
+                    // Count state is already stored at the beginning
                 }
-                AggregateFunction::Min(col_name) => {
-                    // Write whether we have a MIN value (1 byte)
-                    if let Some(min_val) = self.mins.get(col_name) {
-                        blob.push(1u8); // Has value
-                        serialize_value(min_val, &mut blob);
+                AggregateFunction::Sum(col_idx) => {
+                    let sum = self.sums.get(col_idx).copied().unwrap_or(0.0);
+                    values.push(Value::Float(sum));
+                }
+                AggregateFunction::Avg(col_idx) => {
+                    let (sum, count) = self.avgs.get(col_idx).copied().unwrap_or((0.0, 0));
+                    values.push(Value::Float(sum));
+                    values.push(Value::Integer(count));
+                }
+                AggregateFunction::Min(col_idx) => {
+                    if let Some(min_val) = self.mins.get(col_idx) {
+                        values.push(Value::Integer(1)); // Has value
+                        values.push(min_val.clone());
                     } else {
-                        blob.push(0u8); // No value
+                        values.push(Value::Integer(0)); // No value
                     }
                 }
-                AggregateFunction::Max(col_name) => {
-                    // Write whether we have a MAX value (1 byte)
-                    if let Some(max_val) = self.maxs.get(col_name) {
-                        blob.push(1u8); // Has value
-                        serialize_value(max_val, &mut blob);
+                AggregateFunction::Max(col_idx) => {
+                    if let Some(max_val) = self.maxs.get(col_idx) {
+                        values.push(Value::Integer(1)); // Has value
+                        values.push(max_val.clone());
                     } else {
-                        blob.push(0u8); // No value
+                        values.push(Value::Integer(0)); // No value
                    }
                }
            }
        }
 
-        blob
+        values
     }
 
-    /// Deserialize aggregate state from a binary blob
-    /// Returns the aggregate state and the group key values
-    pub fn from_blob(blob: &[u8], aggregates: &[AggregateFunction]) -> Option<(Self, Vec<Value>)> {
+    /// Reconstruct aggregate state from a vector of Values
+    pub fn from_value_vector(values: &[Value]) -> Result<Self> {
         let mut cursor = 0;
 
-        // Check version byte
-        if blob.get(cursor) != Some(&1u8) {
-            return None;
-        }
-        cursor += 1;
-
-        // Read number of group key values
-        let num_group_keys =
-            u32::from_le_bytes(blob.get(cursor..cursor + 4)?.try_into().ok()?) as usize;
-        cursor += 4;
-
-        // Read group key values
-        let mut group_key = Vec::new();
-        for _ in 0..num_group_keys {
-            let value_type = *blob.get(cursor)?;
-            cursor += 1;
-
-            let value = match value_type {
-                0 => Value::Null,
-                1 => {
-                    let i = i64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?);
-                    cursor += 8;
-                    Value::Integer(i)
-                }
-                2 => {
-                    let f = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?);
-                    cursor += 8;
-                    Value::Float(f)
-                }
-                3 => {
-                    let len =
-                        u32::from_le_bytes(blob.get(cursor..cursor + 4)?.try_into().ok()?) as usize;
-                    cursor += 4;
-                    let bytes = blob.get(cursor..cursor + len)?;
-                    cursor += len;
-                    let text_str = std::str::from_utf8(bytes).ok()?;
-                    Value::Text(text_str.to_string().into())
-                }
-                4 => {
-                    let len =
-                        u32::from_le_bytes(blob.get(cursor..cursor + 4)?.try_into().ok()?) as usize;
-                    cursor += 4;
-                    let bytes = blob.get(cursor..cursor + len)?;
-                    cursor += len;
-                    Value::Blob(bytes.to_vec())
-                }
-                _ => return None,
-            };
-            group_key.push(value);
-        }
+        let mut state = Self::new();
 
         // Read count
-        let count = i64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?);
-        cursor += 8;
-
-        let mut state = Self::new();
-        state.count = count;
+        let count = values
+            .get(cursor)
+            .ok_or_else(|| LimboError::InternalError("Aggregate state missing count".into()))?;
+        if let Value::Integer(count) = count {
+            state.count = *count;
+            cursor += 1;
+        } else {
+            return Err(LimboError::InternalError(format!(
+                "Expected Integer for count, got {count:?}"
+            )));
+        }
 
-        // Read each aggregate's state
-        for agg in aggregates {
-            match agg {
-                AggregateFunction::Sum(col_name) => {
-                    let sum = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?);
-                    cursor += 8;
-                    state.sums.insert(*col_name, sum);
-                }
-                AggregateFunction::Avg(col_name) => {
-                    let sum = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?);
-                    cursor += 8;
-                    let count = i64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?);
-                    cursor += 8;
-                    state.avgs.insert(*col_name, (sum, count));
-                }
+        // Read number of aggregates
+        let num_aggregates = values
+            .get(cursor)
+            .ok_or_else(|| LimboError::InternalError("Missing number of aggregates".into()))?;
+        let num_aggregates = match num_aggregates {
+            Value::Integer(n) => *n as usize,
+            _ => {
+                return Err(LimboError::InternalError(format!(
+                    "Expected Integer for aggregate count, got {num_aggregates:?}"
+                )))
+            }
+        };
+        cursor += 1;
+
+        // Read each aggregate's state with type and column index
+        for _ in 0..num_aggregates {
+            // Deserialize the aggregate function metadata
+            let agg_fn = AggregateFunction::from_values(values, &mut cursor)?;
+
+            // Read the state for this aggregate
+            match agg_fn {
                 AggregateFunction::Count => {
-                    // Count was already read above
+                    // Count state is already stored at the beginning
                 }
-                AggregateFunction::Min(col_name) => {
-                    // Read whether we have a MIN value
-                    let has_value = *blob.get(cursor)?;
-                    cursor += 1;
-
-                    if has_value == 1 {
-                        let (min_value, bytes_consumed) = deserialize_value(&blob[cursor..])?;
-                        cursor += bytes_consumed;
-                        state.mins.insert(*col_name, min_value);
+                AggregateFunction::Sum(col_idx) => {
+                    let sum = values
+                        .get(cursor)
+                        .ok_or_else(|| LimboError::InternalError("Missing SUM value".into()))?;
+                    if let Value::Float(sum) = sum {
+                        state.sums.insert(col_idx, *sum);
+                        cursor += 1;
+                    } else {
+                        return Err(LimboError::InternalError(format!(
+                            "Expected Float for SUM value, got {sum:?}"
+                        )));
                     }
                 }
-                AggregateFunction::Max(col_name) => {
-                    // Read whether we have a MAX value
-                    let has_value = *blob.get(cursor)?;
-                    cursor += 1;
-
-                    if has_value == 1 {
-                        let (max_value, bytes_consumed) = deserialize_value(&blob[cursor..])?;
-                        cursor += bytes_consumed;
-                        state.maxs.insert(*col_name, max_value);
+                AggregateFunction::Avg(col_idx) => {
+                    let sum = values
+                        .get(cursor)
+                        .ok_or_else(|| LimboError::InternalError("Missing AVG sum value".into()))?;
+                    let sum = match sum {
+                        Value::Float(f) => *f,
+                        _ => {
+                            return Err(LimboError::InternalError(format!(
+                                "Expected Float for AVG sum, got {sum:?}"
+                            )))
+                        }
+                    };
+                    cursor += 1;
+
+                    let count = values.get(cursor).ok_or_else(|| {
+                        LimboError::InternalError("Missing AVG count value".into())
+                    })?;
+                    let count = match count {
+                        Value::Integer(i) => *i,
+                        _ => {
+                            return Err(LimboError::InternalError(format!(
+                                "Expected Integer for AVG count, got {count:?}"
+                            )))
+                        }
+                    };
+                    cursor += 1;
+
+                    state.avgs.insert(col_idx, (sum, count));
+                }
+                AggregateFunction::Min(col_idx) => {
+                    let has_value = values.get(cursor).ok_or_else(|| {
+                        LimboError::InternalError("Missing MIN has_value flag".into())
+                    })?;
+                    if let Value::Integer(has_value) = has_value {
+                        cursor += 1;
+                        if *has_value == 1 {
+                            let min_val = values
+                                .get(cursor)
+                                .ok_or_else(|| {
+                                    LimboError::InternalError("Missing MIN value".into())
+                                })?
+                                .clone();
+                            cursor += 1;
+                            state.mins.insert(col_idx, min_val);
+                        }
+                    } else {
+                        return Err(LimboError::InternalError(format!(
+                            "Expected Integer for MIN has_value flag, got {has_value:?}"
+                        )));
+                    }
+                }
+                AggregateFunction::Max(col_idx) => {
+                    let has_value = values.get(cursor).ok_or_else(|| {
+                        LimboError::InternalError("Missing MAX has_value flag".into())
+                    })?;
+                    if let Value::Integer(has_value) = has_value {
+                        cursor += 1;
+                        if *has_value == 1 {
+                            let max_val = values
+                                .get(cursor)
+                                .ok_or_else(|| {
+                                    LimboError::InternalError("Missing MAX value".into())
+                                })?
+                                .clone();
+                            cursor += 1;
+                            state.maxs.insert(col_idx, max_val);
+                        }
+                    } else {
+                        return Err(LimboError::InternalError(format!(
+                            "Expected Integer for MAX has_value flag, got {has_value:?}"
+                        )));
+                    }
                 }
             }
         }
 
-        Some((state, group_key))
+        Ok(state)
     }
 
+    fn to_blob(&self, aggregates: &[AggregateFunction], group_key: &[Value]) -> Vec<u8> {
+        let mut all_values = Vec::new();
+        // Store the group key size first
+        all_values.push(Value::Integer(group_key.len() as i64));
+        all_values.extend_from_slice(group_key);
+        all_values.extend(self.to_value_vector(aggregates));
+
+        let record = ImmutableRecord::from_values(&all_values, all_values.len());
+        record.as_blob().clone()
+    }
+
+    pub fn from_blob(blob: &[u8]) -> Result<(Self, Vec<Value>)> {
+        let record = ImmutableRecord::from_bin_record(blob.to_vec());
+        let ref_values = record.get_values();
+        let mut all_values: Vec<Value> = ref_values.into_iter().map(|rv| rv.to_owned()).collect();
+
+        if all_values.is_empty() {
+            return Err(LimboError::InternalError(
+                "Aggregate state blob is empty".into(),
+            ));
+        }
+
+        // Read the group key size
+        let group_key_count = match &all_values[0] {
+            Value::Integer(n) if *n >= 0 => *n as usize,
+            Value::Integer(n) => {
+                return Err(LimboError::InternalError(format!(
+                    "Negative group key count: {n}"
+                )))
+            }
+            other => {
+                return Err(LimboError::InternalError(format!(
+                    "Expected Integer for group key count, got {other:?}"
+                )))
+            }
+        };
+
+        // Remove the group key count from the values
+        all_values.remove(0);
+
+        if all_values.len() < group_key_count {
+            return Err(LimboError::InternalError(format!(
+                "Blob too short: expected at least {} values for group key, got {}",
+                group_key_count,
+                all_values.len()
+            )));
+        }
+
+        // Split into group key and state values
+        let group_key = all_values[..group_key_count].to_vec();
+        let state_values = &all_values[group_key_count..];
+
+        // Reconstruct the aggregate state
+        let state = Self::from_value_vector(state_values)?;
+
+        Ok((state, group_key))
+    }
+
     /// Apply a delta to this aggregate state
@@ -518,124 +518,44 @@ impl JoinOperator {
     }
 }
 
 // Helper to deserialize a HashableRow from a blob
 fn deserialize_hashable_row(blob: &[u8]) -> Result<HashableRow> {
-    // Simple deserialization - this needs to match how we serialize in commit
-    // Format: [rowid:8 bytes][num_values:4 bytes][values...]
-    if blob.len() < 12 {
+    use crate::types::ImmutableRecord;
+
+    let record = ImmutableRecord::from_bin_record(blob.to_vec());
+    let ref_values = record.get_values();
+    let all_values: Vec<Value> = ref_values.into_iter().map(|rv| rv.to_owned()).collect();
+
+    if all_values.is_empty() {
         return Err(crate::LimboError::InternalError(
-            "Invalid blob size".to_string(),
+            "HashableRow blob must contain at least rowid".to_string(),
        ));
    }
 
-    let rowid = i64::from_le_bytes(blob[0..8].try_into().unwrap());
-    let num_values = u32::from_le_bytes(blob[8..12].try_into().unwrap()) as usize;
-
-    let mut values = Vec::new();
-    let mut offset = 12;
-
-    for _ in 0..num_values {
-        if offset >= blob.len() {
-            break;
-        }
-
-        let type_tag = blob[offset];
-        offset += 1;
-
-        match type_tag {
-            0 => values.push(Value::Null),
-            1 => {
-                if offset + 8 <= blob.len() {
-                    let i = i64::from_le_bytes(blob[offset..offset + 8].try_into().unwrap());
-                    values.push(Value::Integer(i));
-                    offset += 8;
-                }
-            }
-            2 => {
-                if offset + 8 <= blob.len() {
-                    let f = f64::from_le_bytes(blob[offset..offset + 8].try_into().unwrap());
-                    values.push(Value::Float(f));
-                    offset += 8;
-                }
-            }
-            3 => {
-                if offset + 4 <= blob.len() {
-                    let len =
-                        u32::from_le_bytes(blob[offset..offset + 4].try_into().unwrap()) as usize;
-                    offset += 4;
-                    if offset + len < blob.len() {
-                        let text_bytes = blob[offset..offset + len].to_vec();
-                        offset += len;
-                        let subtype = match blob[offset] {
-                            0 => crate::types::TextSubtype::Text,
-                            1 => crate::types::TextSubtype::Json,
-                            _ => crate::types::TextSubtype::Text,
-                        };
-                        offset += 1;
-                        values.push(Value::Text(crate::types::Text {
-                            value: text_bytes,
-                            subtype,
-                        }));
-                    }
-                }
-            }
-            4 => {
-                if offset + 4 <= blob.len() {
-                    let len =
-                        u32::from_le_bytes(blob[offset..offset + 4].try_into().unwrap()) as usize;
-                    offset += 4;
-                    if offset + len <= blob.len() {
-                        let blob_data = blob[offset..offset + len].to_vec();
-                        values.push(Value::Blob(blob_data));
-                        offset += len;
-                    }
-                }
-            }
-            _ => break, // Unknown type tag
-        }
-    }
+    // First value is the rowid
+    let rowid = match &all_values[0] {
+        Value::Integer(i) => *i,
+        _ => {
+            return Err(crate::LimboError::InternalError(
+                "First value must be rowid (integer)".to_string(),
+            ))
+        }
+    };
+
+    // Rest are the row values
+    let values = all_values[1..].to_vec();
 
     Ok(HashableRow::new(rowid, values))
 }
 
 // Helper to serialize a HashableRow to a blob
 fn serialize_hashable_row(row: &HashableRow) -> Vec<u8> {
-    let mut blob = Vec::new();
+    use crate::types::ImmutableRecord;
 
-    // Write rowid
-    blob.extend_from_slice(&row.rowid.to_le_bytes());
+    let mut all_values = Vec::with_capacity(row.values.len() + 1);
+    all_values.push(Value::Integer(row.rowid));
+    all_values.extend_from_slice(&row.values);
 
-    // Write number of values
-    blob.extend_from_slice(&(row.values.len() as u32).to_le_bytes());
-
-    // Write each value directly with type tags (like AggregateState does)
-    for value in &row.values {
-        match value {
-            Value::Null => blob.push(0u8),
-            Value::Integer(i) => {
-                blob.push(1u8);
-                blob.extend_from_slice(&i.to_le_bytes());
-            }
-            Value::Float(f) => {
-                blob.push(2u8);
-                blob.extend_from_slice(&f.to_le_bytes());
-            }
-            Value::Text(s) => {
-                blob.push(3u8);
-                let bytes = &s.value;
-                blob.extend_from_slice(&(bytes.len() as u32).to_le_bytes());
-                blob.extend_from_slice(bytes);
-                blob.push(s.subtype as u8);
-            }
-            Value::Blob(b) => {
-                blob.push(4u8);
-                blob.extend_from_slice(&(b.len() as u32).to_le_bytes());
-                blob.extend_from_slice(b);
-            }
-        }
-    }
-
-    blob
+    let record = ImmutableRecord::from_values(&all_values, all_values.len());
+    record.as_blob().clone()
 }
 
 impl IncrementalOperator for JoinOperator {
@@ -332,20 +332,25 @@ mod tests {
         // Get the blob data from column 3 (value column)
         if let Some(Value::Blob(blob)) = values.get(3) {
             // Deserialize the state
-            if let Some((state, group_key)) =
-                AggregateState::from_blob(blob, &agg.aggregates)
-            {
-                // Should not have made it this far.
-                assert!(state.count != 0);
-                // Build output row: group_by columns + aggregate values
-                let mut output_values = group_key.clone();
-                output_values.extend(state.to_values(&agg.aggregates));
+            match AggregateState::from_blob(blob) {
+                Ok((state, group_key)) => {
+                    // Should not have made it this far.
+                    assert!(state.count != 0);
+                    // Build output row: group_by columns + aggregate values
+                    let mut output_values = group_key.clone();
+                    output_values.extend(state.to_values(&agg.aggregates));
 
-                let group_key_str = AggregateOperator::group_key_to_string(&group_key);
-                let rowid = agg.generate_group_rowid(&group_key_str);
+                    let group_key_str = AggregateOperator::group_key_to_string(&group_key);
+                    let rowid = agg.generate_group_rowid(&group_key_str);
 
-                let output_row = HashableRow::new(rowid, output_values);
-                result.changes.push((output_row, 1));
+                    let output_row = HashableRow::new(rowid, output_values);
+                    result.changes.push((output_row, 1));
+                }
+                Err(e) => {
+                    // Log or handle the deserialization error
+                    // For now, we'll skip this entry
+                    eprintln!("Failed to deserialize aggregate state: {e}");
+                }
             }
         }
     }
@@ -4011,4 +4016,115 @@ mod tests {
             panic!("Expected Done result");
         }
     }
+
+    #[test]
+    fn test_aggregate_serialization_with_different_column_indices() {
+        // Test that aggregate state serialization correctly preserves column indices
+        // when multiple aggregates operate on different columns
+        let (pager, table_root_page_id, index_root_page_id) = create_test_pager();
+        let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root_page_id, 5);
+        let index_def = create_dbsp_state_index(index_root_page_id);
+        let index_cursor =
+            BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4);
+        let mut cursors = DbspStateCursors::new(table_cursor, index_cursor);
+
+        // Create first operator with SUM(col1), MIN(col3) GROUP BY col0
+        let mut agg1 = AggregateOperator::new(
+            1,
+            vec![0],
+            vec![AggregateFunction::Sum(1), AggregateFunction::Min(3)],
+            vec![
+                "group".to_string(),
+                "val1".to_string(),
+                "val2".to_string(),
+                "val3".to_string(),
+            ],
+        );
+
+        // Add initial data
+        let mut delta = Delta::new();
+        delta.insert(
+            1,
+            vec![
+                Value::Text("A".into()),
+                Value::Integer(10),
+                Value::Integer(100),
+                Value::Integer(5),
+            ],
+        );
+        delta.insert(
+            2,
+            vec![
+                Value::Text("A".into()),
+                Value::Integer(15),
+                Value::Integer(200),
+                Value::Integer(3),
+            ],
+        );
+
+        let result1 = pager
+            .io
+            .block(|| agg1.commit((&delta).into(), &mut cursors))
+            .unwrap();
+
+        assert_eq!(result1.changes.len(), 1);
+        let (row1, _) = &result1.changes[0];
+        assert_eq!(row1.values[0], Value::Text("A".into()));
+        assert_eq!(row1.values[1], Value::Integer(25)); // SUM(val1) = 10 + 15
+        assert_eq!(row1.values[2], Value::Integer(3)); // MIN(val3) = min(5, 3)
+
+        // Create operator with same ID but different column mappings: SUM(col3), MIN(col1)
+        let mut agg2 = AggregateOperator::new(
+            1, // Same operator_id
+            vec![0],
+            vec![AggregateFunction::Sum(3), AggregateFunction::Min(1)],
+            vec![
+                "group".to_string(),
+                "val1".to_string(),
+                "val2".to_string(),
+                "val3".to_string(),
+            ],
+        );
+
+        // Process new data
+        let mut delta2 = Delta::new();
+        delta2.insert(
+            3,
+            vec![
+                Value::Text("A".into()),
+                Value::Integer(20),
+                Value::Integer(300),
+                Value::Integer(4),
+            ],
+        );
+
+        let result2 = pager
+            .io
+            .block(|| agg2.commit((&delta2).into(), &mut cursors))
+            .unwrap();
+
+        // Find the positive weight row for group A (the updated aggregate)
+        let row2 = result2
+            .changes
+            .iter()
+            .find(|(row, weight)| row.values[0] == Value::Text("A".into()) && *weight > 0)
+            .expect("Should have a positive weight row for group A");
+        let (row2, _) = row2;
+
+        // Verify that column indices are preserved correctly in serialization
+        // When agg2 processes the data with different column mappings:
+        // - It reads the existing state which has SUM(col1)=25 and MIN(col3)=3
+        // - For SUM(col3), there's no existing state, so it starts fresh: 4
+        // - For MIN(col1), there's no existing state, so it starts fresh: 20
+        assert_eq!(
+            row2.values[1],
+            Value::Integer(4),
+            "SUM(col3) should be 4 (new data only)"
+        );
+        assert_eq!(
+            row2.values[2],
+            Value::Integer(20),
+            "MIN(col1) should be 20 (new data only)"
+        );
+    }
 }
@@ -1,4 +1,4 @@
-use crate::incremental::operator::{AggregateFunction, AggregateState, DbspStateCursors};
+use crate::incremental::operator::{AggregateState, DbspStateCursors};
 use crate::storage::btree::{BTreeCursor, BTreeKey};
 use crate::types::{IOResult, ImmutableRecord, SeekKey, SeekOp, SeekResult};
 use crate::{return_if_io, LimboError, Result, Value};
@@ -20,7 +20,6 @@ impl ReadRecord {
     pub fn read_record(
         &mut self,
         key: SeekKey,
-        aggregates: &[AggregateFunction],
        cursor: &mut BTreeCursor,
    ) -> Result<IOResult<Option<AggregateState>>> {
        loop {
@@ -41,12 +40,7 @@ impl ReadRecord {
             let blob = values[3].to_owned();
 
             let (state, _group_key) = match blob {
-                Value::Blob(blob) => AggregateState::from_blob(&blob, aggregates)
-                    .ok_or_else(|| {
-                        LimboError::InternalError(format!(
-                            "Cannot deserialize aggregate state {blob:?}",
-                        ))
-                    }),
+                Value::Blob(blob) => AggregateState::from_blob(&blob),
                 _ => Err(LimboError::ParseError(
                     "Value in aggregator not blob".to_string(),
                 )),