views: implement aggregations

Hook up the AggregateOperator. Also wire up the tracker, allowing us to
verify how much work was done.
This commit is contained in:
Glauber Costa
2025-08-11 14:37:49 -05:00
parent 9ae6a57980
commit 27c22a64b3

View File

@@ -1,6 +1,7 @@
use super::dbsp::{RowKeyStream, RowKeyZSet};
use super::operator::{
AggregateFunction, Delta, FilterOperator, FilterPredicate, ProjectColumn, ProjectOperator,
AggregateFunction, AggregateOperator, ComputationTracker, Delta, FilterOperator,
FilterPredicate, IncrementalOperator, ProjectColumn, ProjectOperator,
};
use crate::schema::{BTreeTable, Column, Schema};
use crate::types::{IOResult, Value};
@@ -9,7 +10,7 @@ use crate::{LimboError, Result, Statement};
use fallible_iterator::FallibleIterator;
use std::collections::BTreeMap;
use std::fmt;
use std::sync::Arc;
use std::sync::{Arc, Mutex};
use turso_sqlite3_parser::{
ast::{Cmd, Stmt},
lexer::sql::Parser,
@@ -78,12 +79,18 @@ pub struct IncrementalView {
filter_operator: Option<FilterOperator>,
// Internal project operator for value transformation
project_operator: Option<ProjectOperator>,
// Internal aggregate operator for GROUP BY and aggregations
aggregate_operator: Option<AggregateOperator>,
// Tables referenced by this view (extracted from FROM clause and JOINs)
base_table: Arc<BTreeTable>,
// The view's output columns with their types
pub columns: Vec<Column>,
// State machine for population
populate_state: PopulateState,
// Computation tracker for statistics
// We will use this one day to export rows_read; for now, we just test that we're doing the expected amount of compute
#[cfg_attr(not(test), allow(dead_code))]
pub tracker: Arc<Mutex<ComputationTracker>>,
}
impl IncrementalView {
@@ -95,11 +102,6 @@ impl IncrementalView {
) -> Result<()> {
// Check for aggregations
let (group_by_columns, aggregate_functions, _) = Self::extract_aggregation_info(select);
if !group_by_columns.is_empty() || !aggregate_functions.is_empty() {
return Err(LimboError::ParseError(
"aggregations in views are not yet supported".to_string(),
));
}
// Check for JOINs
let (join_tables, join_condition) = Self::extract_join_info(select);
@@ -127,8 +129,10 @@ impl IncrementalView {
.map(|(i, col)| col.name.clone().unwrap_or_else(|| format!("column_{i}")))
.collect();
// Validate columns are a strict subset
Self::validate_view_columns(select, &base_table_column_names)?;
// For non-aggregated views, validate columns are a strict subset
if group_by_columns.is_empty() && aggregate_functions.is_empty() {
Self::validate_view_columns(select, &base_table_column_names)?;
}
Ok(())
}
@@ -203,12 +207,6 @@ impl IncrementalView {
let (group_by_columns, aggregate_functions, _old_output_names) =
Self::extract_aggregation_info(&select);
if !group_by_columns.is_empty() || !aggregate_functions.is_empty() {
return Err(LimboError::ParseError(
"aggregations in views are not yet supported".to_string(),
));
}
let (join_tables, join_condition) = Self::extract_join_info(&select);
if join_tables.is_some() || join_condition.is_some() {
return Err(LimboError::ParseError(
@@ -238,7 +236,7 @@ impl IncrementalView {
.map(|(i, col)| col.name.clone().unwrap_or_else(|| format!("column_{i}")))
.collect();
Ok(Self::new(
Self::new(
name,
Vec::new(), // Empty initial data
where_predicate,
@@ -246,9 +244,12 @@ impl IncrementalView {
base_table,
base_table_column_names,
view_columns,
))
group_by_columns,
aggregate_functions,
)
}
#[allow(clippy::too_many_arguments)]
pub fn new(
name: String,
initial_data: Vec<(i64, Vec<Value>)>,
@@ -257,7 +258,9 @@ impl IncrementalView {
base_table: Arc<BTreeTable>,
base_table_column_names: Vec<String>,
columns: Vec<Column>,
) -> Self {
group_by_columns: Vec<String>,
aggregate_functions: Vec<AggregateFunction>,
) -> Result<Self> {
let mut records = BTreeMap::new();
for (row_key, values) in initial_data {
@@ -272,17 +275,37 @@ impl IncrementalView {
zset.insert(row, 1);
}
// Create the tracker that will be shared by all operators
let tracker = Arc::new(Mutex::new(ComputationTracker::new()));
// Create filter operator if we have a predicate
let filter_operator = if !matches!(where_predicate, FilterPredicate::None) {
Some(FilterOperator::new(
where_predicate.clone(),
base_table_column_names.clone(),
))
let mut filter_op =
FilterOperator::new(where_predicate.clone(), base_table_column_names.clone());
filter_op.set_tracker(tracker.clone());
Some(filter_op)
} else {
None
};
let project_operator = {
// Check if this is an aggregated view
let is_aggregated = !group_by_columns.is_empty() || !aggregate_functions.is_empty();
// Create aggregate operator if needed
let aggregate_operator = if is_aggregated {
let mut agg_op = AggregateOperator::new(
group_by_columns,
aggregate_functions,
base_table_column_names.clone(),
);
agg_op.set_tracker(tracker.clone());
Some(agg_op)
} else {
None
};
// Only create project operator for non-aggregated views
let project_operator = if !is_aggregated {
let columns = Self::extract_project_columns(&select_stmt, &base_table_column_names)
.unwrap_or_else(|| {
// If we can't extract columns, default to projecting all columns
@@ -291,13 +314,14 @@ impl IncrementalView {
.map(|name| ProjectColumn::Column(name.to_string()))
.collect()
});
Some(ProjectOperator::new(
columns,
base_table_column_names.clone(),
))
let mut proj_op = ProjectOperator::new(columns, base_table_column_names.clone());
proj_op.set_tracker(tracker.clone());
Some(proj_op)
} else {
None
};
Self {
Ok(Self {
stream: RowKeyStream::from_zset(zset),
name,
records,
@@ -305,10 +329,12 @@ impl IncrementalView {
select_stmt,
filter_operator,
project_operator,
aggregate_operator,
base_table,
columns,
populate_state: PopulateState::Start,
}
tracker,
})
}
pub fn name(&self) -> &str {
@@ -903,28 +929,49 @@ impl IncrementalView {
}
}
/// Apply filter operator to a delta if present
fn apply_filter_to_delta(&mut self, delta: Delta) -> Delta {
if let Some(ref mut filter_op) = self.filter_operator {
filter_op.process_delta(delta)
} else {
delta
}
}
/// Apply aggregation operator to a delta if this is an aggregated view
fn apply_aggregation_to_delta(&mut self, delta: Delta) -> Delta {
if let Some(ref mut agg_op) = self.aggregate_operator {
agg_op.process_delta(delta)
} else {
delta
}
}
/// Merge a delta of changes into the view's current state
pub fn merge_delta(&mut self, delta: &Delta) {
// Create a Z-set of changes to apply to the stream
// Early return if delta is empty
if delta.is_empty() {
return;
}
// Apply operators in pipeline
let mut current_delta = delta.clone();
current_delta = self.apply_filter_to_delta(current_delta);
current_delta = self.apply_aggregation_to_delta(current_delta);
// Update records and stream with the processed delta
let mut zset_delta = RowKeyZSet::new();
// Apply the delta changes to the records
for (row, weight) in &delta.changes {
for (row, weight) in &current_delta.changes {
if *weight > 0 {
// Insert
if self.apply_filter(&row.values) {
self.records.insert(row.rowid, row.values.clone());
zset_delta.insert(row.clone(), 1);
}
self.records.insert(row.rowid, row.values.clone());
zset_delta.insert(row.clone(), 1);
} else if *weight < 0 {
// Delete
if self.records.remove(&row.rowid).is_some() {
zset_delta.insert(row.clone(), -1);
}
self.records.remove(&row.rowid);
zset_delta.insert(row.clone(), -1);
}
}
// Apply all changes to the stream at once
self.stream.apply_delta(&zset_delta);
}
}
@@ -1206,4 +1253,165 @@ mod tests {
]
);
}
#[test]
fn test_aggregation_count_with_group_by() {
    let schema = create_test_schema();
    let sql = "CREATE VIEW v AS SELECT a, COUNT(*) FROM t GROUP BY a";
    let mut view = IncrementalView::from_sql(sql, &schema).unwrap();

    // An aggregated view must carry an aggregate operator.
    assert!(view.aggregate_operator.is_some());

    // Three rows: two land in group a=1, one in group a=2.
    let rows = [
        (1i64, 1i64, 10i64, 100i64),
        (2, 2, 20, 200),
        (3, 1, 30, 300),
    ];
    let mut delta = Delta::new();
    for (rowid, a, b, c) in rows {
        delta.insert(
            rowid,
            vec![Value::Integer(a), Value::Integer(b), Value::Integer(c)],
        );
    }

    // Process the delta
    view.merge_delta(&delta);

    // Exactly the 3 inserted rows should have reached the aggregator.
    assert_eq!(view.tracker.lock().unwrap().aggregation_updates, 3);

    // Check the aggregated results: 2 groups, a=1 with count 2 and a=2 with count 1.
    let results = view.current_data(None);
    assert_eq!(results.len(), 2);

    let group_for = |key: i64| {
        results
            .iter()
            .find(|(_, vals)| vals[0] == Value::Integer(key))
            .unwrap()
    };

    let group1 = group_for(1);
    assert_eq!(group1.1[0], Value::Integer(1)); // a=1
    assert_eq!(group1.1[1], Value::Integer(2)); // COUNT(*)=2

    let group2 = group_for(2);
    assert_eq!(group2.1[0], Value::Integer(2)); // a=2
    assert_eq!(group2.1[1], Value::Integer(1)); // COUNT(*)=1
}
#[test]
fn test_aggregation_sum_with_filter() {
    let schema = create_test_schema();
    let sql = "CREATE VIEW v AS SELECT SUM(b) FROM t WHERE a > 1";
    let mut view = IncrementalView::from_sql(sql, &schema).unwrap();

    // Both operators should exist: SUM -> aggregate, WHERE -> filter.
    assert!(view.aggregate_operator.is_some());
    assert!(view.filter_operator.is_some());

    let mut delta = Delta::new();
    for (rowid, a, b, c) in [
        (1i64, 1i64, 10i64, 100i64),
        (2, 2, 20, 200),
        (3, 3, 30, 300),
    ] {
        delta.insert(
            rowid,
            vec![Value::Integer(a), Value::Integer(b), Value::Integer(c)],
        );
    }
    view.merge_delta(&delta);

    // All 3 rows are evaluated by the filter...
    assert_eq!(view.tracker.lock().unwrap().filter_evaluations, 3);
    // ...but only the 2 that satisfy a > 1 reach the aggregator.
    assert_eq!(view.tracker.lock().unwrap().aggregation_updates, 2);

    // A single ungrouped row holding SUM(b) over the surviving rows.
    let results = view.current_data(None);
    assert_eq!(results.len(), 1);
    assert_eq!(results[0].1[0], Value::Integer(50)); // SUM(b) = 20 + 30
}
#[test]
fn test_aggregation_incremental_updates() {
    let schema = create_test_schema();
    let sql = "CREATE VIEW v AS SELECT a, COUNT(*), SUM(b) FROM t GROUP BY a";
    let mut view = IncrementalView::from_sql(sql, &schema).unwrap();

    // First batch: two rows, both in group a=1.
    let mut delta1 = Delta::new();
    for (rowid, a, b, c) in [(1i64, 1i64, 10i64, 100i64), (2, 1, 20, 200)] {
        delta1.insert(
            rowid,
            vec![Value::Integer(a), Value::Integer(b), Value::Integer(c)],
        );
    }
    view.merge_delta(&delta1);

    // Exactly 2 rows processed for the first batch.
    assert_eq!(view.tracker.lock().unwrap().aggregation_updates, 2);

    // One group so far: COUNT(*)=2, SUM(b)=30.
    let results1 = view.current_data(None);
    assert_eq!(results1.len(), 1);
    assert_eq!(results1[0].1[1], Value::Integer(2));
    assert_eq!(results1[0].1[2], Value::Integer(30));

    // Reset the counter so the second batch is measured on its own.
    view.tracker.lock().unwrap().aggregation_updates = 0;

    // Second batch: one more row for group a=1, a first row for a=2.
    let mut delta2 = Delta::new();
    for (rowid, a, b, c) in [(3i64, 1i64, 5i64, 300i64), (4, 2, 15, 400)] {
        delta2.insert(
            rowid,
            vec![Value::Integer(a), Value::Integer(b), Value::Integer(c)],
        );
    }
    view.merge_delta(&delta2);

    // Incremental: only the 2 new rows are processed, nothing is recomputed.
    assert_eq!(view.tracker.lock().unwrap().aggregation_updates, 2);

    let results2 = view.current_data(None);
    assert_eq!(results2.len(), 2);

    let group_for = |key: i64| {
        results2
            .iter()
            .find(|(_, vals)| vals[0] == Value::Integer(key))
            .unwrap()
    };

    let group1 = group_for(1);
    assert_eq!(group1.1[1], Value::Integer(3)); // COUNT(*)=3
    assert_eq!(group1.1[2], Value::Integer(35)); // SUM(b)=35

    let group2 = group_for(2);
    assert_eq!(group2.1[1], Value::Integer(1)); // COUNT(*)=1
    assert_eq!(group2.1[2], Value::Integer(15)); // SUM(b)=15
}
}