Merge 'Implement Aggregations for DBSP views' from Glauber Costa

```
turso> create table t(a, b);
turso> insert into t(a,b) values (2,2), (3,3);
turso> insert into t(a,b) values (6,6), (7,7);
turso> insert into t(a,b) values (6,6), (7,7);
turso> create view tt as select b, sum(a) from t where b > 2 group by b;
turso> select * from tt;
┌───┬─────────┐
│ b │ sum (a) │
├───┼─────────┤
│ 3 │       3 │
├───┼─────────┤
│ 6 │      12 │
├───┼─────────┤
│ 7 │      14 │
└───┴─────────┘
turso> insert into t(a,b) values (1,3);
turso> select * from tt;
┌───┬─────────┐
│ b │ sum (a) │
├───┼─────────┤
│ 3 │       4 │
├───┼─────────┤
│ 6 │      12 │
├───┼─────────┤
│ 7 │      14 │
└───┴─────────┘
turso>
```

Closes #2547
This commit is contained in:
Pekka Enberg
2025-08-12 09:52:22 +03:00
committed by GitHub

View File

@@ -1,6 +1,7 @@
use super::dbsp::{RowKeyStream, RowKeyZSet};
use super::operator::{
AggregateFunction, Delta, FilterOperator, FilterPredicate, ProjectColumn, ProjectOperator,
AggregateFunction, AggregateOperator, ComputationTracker, Delta, FilterOperator,
FilterPredicate, IncrementalOperator, ProjectColumn, ProjectOperator,
};
use crate::schema::{BTreeTable, Column, Schema};
use crate::types::{IOResult, Value};
@@ -9,7 +10,7 @@ use crate::{LimboError, Result, Statement};
use fallible_iterator::FallibleIterator;
use std::collections::BTreeMap;
use std::fmt;
use std::sync::Arc;
use std::sync::{Arc, Mutex};
use turso_sqlite3_parser::{
ast::{Cmd, Stmt},
lexer::sql::Parser,
@@ -78,12 +79,18 @@ pub struct IncrementalView {
filter_operator: Option<FilterOperator>,
// Internal project operator for value transformation
project_operator: Option<ProjectOperator>,
// Internal aggregate operator for GROUP BY and aggregations
aggregate_operator: Option<AggregateOperator>,
// Tables referenced by this view (extracted from FROM clause and JOINs)
base_table: Arc<BTreeTable>,
// The view's output columns with their types
pub columns: Vec<Column>,
// State machine for population
populate_state: PopulateState,
// Computation tracker for statistics
// We will use this one day to export rows_read, but for now, will just test that we're doing the expected amount of compute
#[cfg_attr(not(test), allow(dead_code))]
pub tracker: Arc<Mutex<ComputationTracker>>,
}
impl IncrementalView {
@@ -95,11 +102,6 @@ impl IncrementalView {
) -> Result<()> {
// Check for aggregations
let (group_by_columns, aggregate_functions, _) = Self::extract_aggregation_info(select);
if !group_by_columns.is_empty() || !aggregate_functions.is_empty() {
return Err(LimboError::ParseError(
"aggregations in views are not yet supported".to_string(),
));
}
// Check for JOINs
let (join_tables, join_condition) = Self::extract_join_info(select);
@@ -127,8 +129,10 @@ impl IncrementalView {
.map(|(i, col)| col.name.clone().unwrap_or_else(|| format!("column_{i}")))
.collect();
// Validate columns are a strict subset
Self::validate_view_columns(select, &base_table_column_names)?;
// For non-aggregated views, validate columns are a strict subset
if group_by_columns.is_empty() && aggregate_functions.is_empty() {
Self::validate_view_columns(select, &base_table_column_names)?;
}
Ok(())
}
@@ -203,12 +207,6 @@ impl IncrementalView {
let (group_by_columns, aggregate_functions, _old_output_names) =
Self::extract_aggregation_info(&select);
if !group_by_columns.is_empty() || !aggregate_functions.is_empty() {
return Err(LimboError::ParseError(
"aggregations in views are not yet supported".to_string(),
));
}
let (join_tables, join_condition) = Self::extract_join_info(&select);
if join_tables.is_some() || join_condition.is_some() {
return Err(LimboError::ParseError(
@@ -238,7 +236,7 @@ impl IncrementalView {
.map(|(i, col)| col.name.clone().unwrap_or_else(|| format!("column_{i}")))
.collect();
Ok(Self::new(
Self::new(
name,
Vec::new(), // Empty initial data
where_predicate,
@@ -246,9 +244,12 @@ impl IncrementalView {
base_table,
base_table_column_names,
view_columns,
))
group_by_columns,
aggregate_functions,
)
}
#[allow(clippy::too_many_arguments)]
pub fn new(
name: String,
initial_data: Vec<(i64, Vec<Value>)>,
@@ -257,7 +258,9 @@ impl IncrementalView {
base_table: Arc<BTreeTable>,
base_table_column_names: Vec<String>,
columns: Vec<Column>,
) -> Self {
group_by_columns: Vec<String>,
aggregate_functions: Vec<AggregateFunction>,
) -> Result<Self> {
let mut records = BTreeMap::new();
for (row_key, values) in initial_data {
@@ -272,17 +275,37 @@ impl IncrementalView {
zset.insert(row, 1);
}
// Create the tracker that will be shared by all operators
let tracker = Arc::new(Mutex::new(ComputationTracker::new()));
// Create filter operator if we have a predicate
let filter_operator = if !matches!(where_predicate, FilterPredicate::None) {
Some(FilterOperator::new(
where_predicate.clone(),
base_table_column_names.clone(),
))
let mut filter_op =
FilterOperator::new(where_predicate.clone(), base_table_column_names.clone());
filter_op.set_tracker(tracker.clone());
Some(filter_op)
} else {
None
};
let project_operator = {
// Check if this is an aggregated view
let is_aggregated = !group_by_columns.is_empty() || !aggregate_functions.is_empty();
// Create aggregate operator if needed
let aggregate_operator = if is_aggregated {
let mut agg_op = AggregateOperator::new(
group_by_columns,
aggregate_functions,
base_table_column_names.clone(),
);
agg_op.set_tracker(tracker.clone());
Some(agg_op)
} else {
None
};
// Only create project operator for non-aggregated views
let project_operator = if !is_aggregated {
let columns = Self::extract_project_columns(&select_stmt, &base_table_column_names)
.unwrap_or_else(|| {
// If we can't extract columns, default to projecting all columns
@@ -291,13 +314,14 @@ impl IncrementalView {
.map(|name| ProjectColumn::Column(name.to_string()))
.collect()
});
Some(ProjectOperator::new(
columns,
base_table_column_names.clone(),
))
let mut proj_op = ProjectOperator::new(columns, base_table_column_names.clone());
proj_op.set_tracker(tracker.clone());
Some(proj_op)
} else {
None
};
Self {
Ok(Self {
stream: RowKeyStream::from_zset(zset),
name,
records,
@@ -305,10 +329,12 @@ impl IncrementalView {
select_stmt,
filter_operator,
project_operator,
aggregate_operator,
base_table,
columns,
populate_state: PopulateState::Start,
}
tracker,
})
}
pub fn name(&self) -> &str {
@@ -526,13 +552,15 @@ impl IncrementalView {
stmt,
rows_processed,
} => {
// Process rows in batches to allow for IO interruption
// Collect rows into a delta batch
let mut batch_delta = Delta::new();
let mut batch_count = 0;
loop {
if batch_count >= BATCH_SIZE {
// Process this batch through the standard pipeline
self.merge_delta(&batch_delta);
// Yield control after processing a batch
// The statement maintains its position, so we'll resume from here
return Ok(IOResult::IO);
}
@@ -560,38 +588,15 @@ impl IncrementalView {
// Get all values except the rowid
let values = all_values[..all_values.len() - 1].to_vec();
// Apply filter if we have one
// Pure DBSP would ingest the entire stream and then apply filter operators.
// However, for initial population, we adopt a hybrid approach where we filter at
// the query result level for efficiency. This avoids reading millions of rows just
// to filter them down to a few. We only do this optimization for filters, not for
// other operators like projections or aggregations.
// TODO: We should further optimize by pushing the filter into the SQL WHERE clause.
// Check filter first (we need to do this before accessing self mutably)
let passes_filter =
if let Some(ref filter_op) = self.filter_operator {
filter_op.evaluate_predicate(&values)
} else {
true
};
if passes_filter {
// Store the row with its original rowid
self.records.insert(rowid, values.clone());
// Update the ZSet stream with weight +1
let mut delta = RowKeyZSet::new();
use crate::incremental::hashable_row::HashableRow;
let row = HashableRow::new(rowid, values);
delta.insert(row, 1);
self.stream.apply_delta(&delta);
}
// Add to batch delta - let merge_delta handle filtering and aggregation
batch_delta.insert(rowid, values);
*rows_processed += 1;
batch_count += 1;
}
crate::vdbe::StepResult::Done => {
// Process any remaining rows in the batch
self.merge_delta(&batch_delta);
// All rows processed, move to Done state
self.populate_state = PopulateState::Done;
return Ok(IOResult::Done(()));
@@ -600,8 +605,9 @@ impl IncrementalView {
return Err(LimboError::Busy);
}
crate::vdbe::StepResult::IO => {
// Process current batch before yielding
self.merge_delta(&batch_delta);
// The Statement needs to wait for IO
// When we return here, the Statement maintains its position
return Ok(IOResult::IO);
}
}
@@ -903,28 +909,49 @@ impl IncrementalView {
}
}
/// Run the delta through the filter operator when one is configured;
/// with no filter the delta passes through unchanged.
fn apply_filter_to_delta(&mut self, delta: Delta) -> Delta {
    match self.filter_operator {
        Some(ref mut filter_op) => filter_op.process_delta(delta),
        None => delta,
    }
}
/// Run the delta through the aggregate operator when this view is
/// aggregated; with no aggregation the delta passes through unchanged.
fn apply_aggregation_to_delta(&mut self, delta: Delta) -> Delta {
    match self.aggregate_operator {
        Some(ref mut agg_op) => agg_op.process_delta(delta),
        None => delta,
    }
}
/// Merge a delta of changes into the view's current state
pub fn merge_delta(&mut self, delta: &Delta) {
// Create a Z-set of changes to apply to the stream
// Early return if delta is empty
if delta.is_empty() {
return;
}
// Apply operators in pipeline
let mut current_delta = delta.clone();
current_delta = self.apply_filter_to_delta(current_delta);
current_delta = self.apply_aggregation_to_delta(current_delta);
// Update records and stream with the processed delta
let mut zset_delta = RowKeyZSet::new();
// Apply the delta changes to the records
for (row, weight) in &delta.changes {
for (row, weight) in &current_delta.changes {
if *weight > 0 {
// Insert
if self.apply_filter(&row.values) {
self.records.insert(row.rowid, row.values.clone());
zset_delta.insert(row.clone(), 1);
}
self.records.insert(row.rowid, row.values.clone());
zset_delta.insert(row.clone(), 1);
} else if *weight < 0 {
// Delete
if self.records.remove(&row.rowid).is_some() {
zset_delta.insert(row.clone(), -1);
}
self.records.remove(&row.rowid);
zset_delta.insert(row.clone(), -1);
}
}
// Apply all changes to the stream at once
self.stream.apply_delta(&zset_delta);
}
}
@@ -1206,4 +1233,165 @@ mod tests {
]
);
}
#[test]
fn test_aggregation_count_with_group_by() {
    let schema = create_test_schema();
    let sql = "CREATE VIEW v AS SELECT a, COUNT(*) FROM t GROUP BY a";
    let mut view = IncrementalView::from_sql(sql, &schema).unwrap();
    // A GROUP BY view must be backed by an aggregate operator.
    assert!(view.aggregate_operator.is_some());
    // Seed three rows: two fall into group a=1, one into group a=2.
    let mut delta = Delta::new();
    for (rowid, a, b, c) in [(1, 1, 10, 100), (2, 2, 20, 200), (3, 1, 30, 300)] {
        delta.insert(
            rowid,
            vec![Value::Integer(a), Value::Integer(b), Value::Integer(c)],
        );
    }
    view.merge_delta(&delta);
    // Exactly the 3 inserted rows should have been aggregated.
    assert_eq!(view.tracker.lock().unwrap().aggregation_updates, 3);
    let results = view.current_data(None);
    // Two groups expected: a=1 with count 2, a=2 with count 1.
    assert_eq!(results.len(), 2);
    let group_for = |key: i64| {
        results
            .iter()
            .find(|(_, vals)| vals[0] == Value::Integer(key))
            .unwrap()
    };
    let group1 = group_for(1);
    assert_eq!(group1.1[0], Value::Integer(1)); // group key a=1
    assert_eq!(group1.1[1], Value::Integer(2)); // COUNT(*)=2
    let group2 = group_for(2);
    assert_eq!(group2.1[0], Value::Integer(2)); // group key a=2
    assert_eq!(group2.1[1], Value::Integer(1)); // COUNT(*)=1
}
#[test]
fn test_aggregation_sum_with_filter() {
    let schema = create_test_schema();
    let sql = "CREATE VIEW v AS SELECT SUM(b) FROM t WHERE a > 1";
    let mut view = IncrementalView::from_sql(sql, &schema).unwrap();
    // Both a filter (WHERE) and an aggregate (SUM) operator are expected.
    assert!(view.aggregate_operator.is_some());
    assert!(view.filter_operator.is_some());
    let mut delta = Delta::new();
    for (rowid, a, b, c) in [(1, 1, 10, 100), (2, 2, 20, 200), (3, 3, 30, 300)] {
        delta.insert(
            rowid,
            vec![Value::Integer(a), Value::Integer(b), Value::Integer(c)],
        );
    }
    view.merge_delta(&delta);
    // The filter must examine all 3 rows...
    assert_eq!(view.tracker.lock().unwrap().filter_evaluations, 3);
    // ...but only the 2 rows with a > 1 reach the aggregator.
    assert_eq!(view.tracker.lock().unwrap().aggregation_updates, 2);
    let results = view.current_data(None);
    // A single aggregate row: SUM(b) over the rows that passed the filter.
    assert_eq!(results.len(), 1);
    assert_eq!(results[0].1[0], Value::Integer(50)); // SUM(b) = 20 + 30
}
#[test]
fn test_aggregation_incremental_updates() {
    let schema = create_test_schema();
    let sql = "CREATE VIEW v AS SELECT a, COUNT(*), SUM(b) FROM t GROUP BY a";
    let mut view = IncrementalView::from_sql(sql, &schema).unwrap();
    // First batch: two rows, both in group a=1.
    let mut delta1 = Delta::new();
    for (rowid, a, b, c) in [(1, 1, 10, 100), (2, 1, 20, 200)] {
        delta1.insert(
            rowid,
            vec![Value::Integer(a), Value::Integer(b), Value::Integer(c)],
        );
    }
    view.merge_delta(&delta1);
    // Exactly 2 rows aggregated for the first batch.
    assert_eq!(view.tracker.lock().unwrap().aggregation_updates, 2);
    let results1 = view.current_data(None);
    assert_eq!(results1.len(), 1);
    assert_eq!(results1[0].1[1], Value::Integer(2)); // COUNT(*)=2
    assert_eq!(results1[0].1[2], Value::Integer(30)); // SUM(b)=10+20
    // Reset the counter so the second batch is measured on its own.
    view.tracker.lock().unwrap().aggregation_updates = 0;
    // Second batch: one more row for a=1 plus a brand-new group a=2.
    let mut delta2 = Delta::new();
    for (rowid, a, b, c) in [(3, 1, 5, 300), (4, 2, 15, 400)] {
        delta2.insert(
            rowid,
            vec![Value::Integer(a), Value::Integer(b), Value::Integer(c)],
        );
    }
    view.merge_delta(&delta2);
    // Incrementality: only the 2 new rows are processed, nothing is recomputed.
    assert_eq!(view.tracker.lock().unwrap().aggregation_updates, 2);
    let results2 = view.current_data(None);
    assert_eq!(results2.len(), 2);
    let group_for = |key: i64| {
        results2
            .iter()
            .find(|(_, vals)| vals[0] == Value::Integer(key))
            .unwrap()
    };
    let group1 = group_for(1);
    assert_eq!(group1.1[1], Value::Integer(3)); // COUNT(*)=3
    assert_eq!(group1.1[2], Value::Integer(35)); // SUM(b)=10+20+5
    let group2 = group_for(2);
    assert_eq!(group2.1[1], Value::Integer(1)); // COUNT(*)=1
    assert_eq!(group2.1[2], Value::Integer(15)); // SUM(b)=15
}
}