diff --git a/core/incremental/view.rs b/core/incremental/view.rs index b590e6d8b..e163f55de 100644 --- a/core/incremental/view.rs +++ b/core/incremental/view.rs @@ -1,6 +1,7 @@ use super::dbsp::{RowKeyStream, RowKeyZSet}; use super::operator::{ - AggregateFunction, Delta, FilterOperator, FilterPredicate, ProjectColumn, ProjectOperator, + AggregateFunction, AggregateOperator, ComputationTracker, Delta, FilterOperator, + FilterPredicate, IncrementalOperator, ProjectColumn, ProjectOperator, }; use crate::schema::{BTreeTable, Column, Schema}; use crate::types::{IOResult, Value}; @@ -9,7 +10,7 @@ use crate::{LimboError, Result, Statement}; use fallible_iterator::FallibleIterator; use std::collections::BTreeMap; use std::fmt; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use turso_sqlite3_parser::{ ast::{Cmd, Stmt}, lexer::sql::Parser, @@ -78,12 +79,18 @@ pub struct IncrementalView { filter_operator: Option<FilterOperator>, // Internal project operator for value transformation project_operator: Option<ProjectOperator>, + // Internal aggregate operator for GROUP BY and aggregations + aggregate_operator: Option<AggregateOperator>, // Tables referenced by this view (extracted from FROM clause and JOINs) base_table: Arc<BTreeTable>, // The view's output columns with their types pub columns: Vec<Column>, // State machine for population populate_state: PopulateState, + // Computation tracker for statistics + // We will use this one day to export rows_read, but for now, will just test that we're doing the expected amount of compute + #[cfg_attr(not(test), allow(dead_code))] + pub tracker: Arc<Mutex<ComputationTracker>>, } impl IncrementalView { @@ -95,11 +102,6 @@ impl IncrementalView { ) -> Result<()> { // Check for aggregations let (group_by_columns, aggregate_functions, _) = Self::extract_aggregation_info(select); - if !group_by_columns.is_empty() || !aggregate_functions.is_empty() { - return Err(LimboError::ParseError( - "aggregations in views are not yet supported".to_string(), - )); - } // Check for JOINs let (join_tables, join_condition) =
Self::extract_join_info(select); @@ -127,8 +129,10 @@ impl IncrementalView { .map(|(i, col)| col.name.clone().unwrap_or_else(|| format!("column_{i}"))) .collect(); - // Validate columns are a strict subset - Self::validate_view_columns(select, &base_table_column_names)?; + // For non-aggregated views, validate columns are a strict subset + if group_by_columns.is_empty() && aggregate_functions.is_empty() { + Self::validate_view_columns(select, &base_table_column_names)?; + } Ok(()) } @@ -203,12 +207,6 @@ impl IncrementalView { let (group_by_columns, aggregate_functions, _old_output_names) = Self::extract_aggregation_info(&select); - if !group_by_columns.is_empty() || !aggregate_functions.is_empty() { - return Err(LimboError::ParseError( - "aggregations in views are not yet supported".to_string(), - )); - } - let (join_tables, join_condition) = Self::extract_join_info(&select); if join_tables.is_some() || join_condition.is_some() { return Err(LimboError::ParseError( @@ -238,7 +236,7 @@ impl IncrementalView { .map(|(i, col)| col.name.clone().unwrap_or_else(|| format!("column_{i}"))) .collect(); - Ok(Self::new( + Self::new( name, Vec::new(), // Empty initial data where_predicate, @@ -246,9 +244,12 @@ impl IncrementalView { base_table, base_table_column_names, view_columns, - )) + group_by_columns, + aggregate_functions, + ) } + #[allow(clippy::too_many_arguments)] pub fn new( name: String, initial_data: Vec<(i64, Vec<Value>)>, @@ -257,7 +258,9 @@ impl IncrementalView { base_table: Arc<BTreeTable>, base_table_column_names: Vec<String>, columns: Vec<Column>, - ) -> Self { + group_by_columns: Vec<String>, + aggregate_functions: Vec<AggregateFunction>, + ) -> Result<Self> { let mut records = BTreeMap::new(); for (row_key, values) in initial_data { @@ -272,17 +275,37 @@ impl IncrementalView { zset.insert(row, 1); } + // Create the tracker that will be shared by all operators + let tracker = Arc::new(Mutex::new(ComputationTracker::new())); + // Create filter operator if we have a predicate let filter_operator = if !matches!(where_predicate,
FilterPredicate::None) { - Some(FilterOperator::new( - where_predicate.clone(), - base_table_column_names.clone(), - )) + let mut filter_op = + FilterOperator::new(where_predicate.clone(), base_table_column_names.clone()); + filter_op.set_tracker(tracker.clone()); + Some(filter_op) } else { None }; - let project_operator = { + // Check if this is an aggregated view + let is_aggregated = !group_by_columns.is_empty() || !aggregate_functions.is_empty(); + + // Create aggregate operator if needed + let aggregate_operator = if is_aggregated { + let mut agg_op = AggregateOperator::new( + group_by_columns, + aggregate_functions, + base_table_column_names.clone(), + ); + agg_op.set_tracker(tracker.clone()); + Some(agg_op) + } else { + None + }; + + // Only create project operator for non-aggregated views + let project_operator = if !is_aggregated { let columns = Self::extract_project_columns(&select_stmt, &base_table_column_names) .unwrap_or_else(|| { // If we can't extract columns, default to projecting all columns @@ -291,13 +314,14 @@ impl IncrementalView { .map(|name| ProjectColumn::Column(name.to_string())) .collect() }); - Some(ProjectOperator::new( - columns, - base_table_column_names.clone(), - )) + let mut proj_op = ProjectOperator::new(columns, base_table_column_names.clone()); + proj_op.set_tracker(tracker.clone()); + Some(proj_op) + } else { + None }; - Self { + Ok(Self { stream: RowKeyStream::from_zset(zset), name, records, @@ -305,10 +329,12 @@ impl IncrementalView { select_stmt, filter_operator, project_operator, + aggregate_operator, base_table, columns, populate_state: PopulateState::Start, - } + tracker, + }) } pub fn name(&self) -> &str { @@ -526,13 +552,15 @@ impl IncrementalView { stmt, rows_processed, } => { - // Process rows in batches to allow for IO interruption + // Collect rows into a delta batch + let mut batch_delta = Delta::new(); let mut batch_count = 0; loop { if batch_count >= BATCH_SIZE { + // Process this batch through the standard 
pipeline + self.merge_delta(&batch_delta); // Yield control after processing a batch - // The statement maintains its position, so we'll resume from here return Ok(IOResult::IO); } @@ -560,38 +588,15 @@ impl IncrementalView { // Get all values except the rowid let values = all_values[..all_values.len() - 1].to_vec(); - // Apply filter if we have one - // Pure DBSP would ingest the entire stream and then apply filter operators. - // However, for initial population, we adopt a hybrid approach where we filter at - // the query result level for efficiency. This avoids reading millions of rows just - // to filter them down to a few. We only do this optimization for filters, not for - // other operators like projections or aggregations. - // TODO: We should further optimize by pushing the filter into the SQL WHERE clause. - - // Check filter first (we need to do this before accessing self mutably) - let passes_filter = - if let Some(ref filter_op) = self.filter_operator { - filter_op.evaluate_predicate(&values) - } else { - true - }; - - if passes_filter { - // Store the row with its original rowid - self.records.insert(rowid, values.clone()); - - // Update the ZSet stream with weight +1 - let mut delta = RowKeyZSet::new(); - use crate::incremental::hashable_row::HashableRow; - let row = HashableRow::new(rowid, values); - delta.insert(row, 1); - self.stream.apply_delta(&delta); - } + // Add to batch delta - let merge_delta handle filtering and aggregation + batch_delta.insert(rowid, values); *rows_processed += 1; batch_count += 1; } crate::vdbe::StepResult::Done => { + // Process any remaining rows in the batch + self.merge_delta(&batch_delta); // All rows processed, move to Done state self.populate_state = PopulateState::Done; return Ok(IOResult::Done(())); @@ -600,8 +605,9 @@ impl IncrementalView { return Err(LimboError::Busy); } crate::vdbe::StepResult::IO => { + // Process current batch before yielding + self.merge_delta(&batch_delta); // The Statement needs to wait 
for IO - // When we return here, the Statement maintains its position return Ok(IOResult::IO); } } } @@ -903,28 +909,49 @@ impl IncrementalView { } } + /// Apply filter operator to a delta if present + fn apply_filter_to_delta(&mut self, delta: Delta) -> Delta { + if let Some(ref mut filter_op) = self.filter_operator { + filter_op.process_delta(delta) + } else { + delta + } + } + + /// Apply aggregation operator to a delta if this is an aggregated view + fn apply_aggregation_to_delta(&mut self, delta: Delta) -> Delta { + if let Some(ref mut agg_op) = self.aggregate_operator { + agg_op.process_delta(delta) + } else { + delta + } + } + /// Merge a delta of changes into the view's current state pub fn merge_delta(&mut self, delta: &Delta) { - // Create a Z-set of changes to apply to the stream + // Early return if delta is empty + if delta.is_empty() { + return; + } + + // Apply operators in pipeline + let mut current_delta = delta.clone(); + current_delta = self.apply_filter_to_delta(current_delta); + current_delta = self.apply_aggregation_to_delta(current_delta); + + // Update records and stream with the processed delta let mut zset_delta = RowKeyZSet::new(); - // Apply the delta changes to the records - for (row, weight) in &delta.changes { + for (row, weight) in &current_delta.changes { if *weight > 0 { - // Insert - if self.apply_filter(&row.values) { - self.records.insert(row.rowid, row.values.clone()); - zset_delta.insert(row.clone(), 1); - } + self.records.insert(row.rowid, row.values.clone()); + zset_delta.insert(row.clone(), 1); } else if *weight < 0 { - // Delete - if self.records.remove(&row.rowid).is_some() { - zset_delta.insert(row.clone(), -1); - } + self.records.remove(&row.rowid); + zset_delta.insert(row.clone(), -1); } } - // Apply all changes to the stream at once self.stream.apply_delta(&zset_delta); } } @@ -1206,4 +1233,165 @@ mod tests { ] ); } + + #[test] + fn test_aggregation_count_with_group_by() { + let schema = create_test_schema(); + let sql =
"CREATE VIEW v AS SELECT a, COUNT(*) FROM t GROUP BY a"; + + let mut view = IncrementalView::from_sql(sql, &schema).unwrap(); + + // Verify the view has an aggregate operator + assert!(view.aggregate_operator.is_some()); + + // Insert some test data + let mut delta = Delta::new(); + delta.insert( + 1, + vec![Value::Integer(1), Value::Integer(10), Value::Integer(100)], + ); + delta.insert( + 2, + vec![Value::Integer(2), Value::Integer(20), Value::Integer(200)], + ); + delta.insert( + 3, + vec![Value::Integer(1), Value::Integer(30), Value::Integer(300)], + ); + + // Process the delta + view.merge_delta(&delta); + + // Verify we only processed the 3 rows we inserted + assert_eq!(view.tracker.lock().unwrap().aggregation_updates, 3); + + // Check the aggregated results + let results = view.current_data(None); + + // Should have 2 groups: a=1 with count=2, a=2 with count=1 + assert_eq!(results.len(), 2); + + // Find the group with a=1 + let group1 = results + .iter() + .find(|(_, vals)| vals[0] == Value::Integer(1)) + .unwrap(); + assert_eq!(group1.1[0], Value::Integer(1)); // a=1 + assert_eq!(group1.1[1], Value::Integer(2)); // COUNT(*)=2 + + // Find the group with a=2 + let group2 = results + .iter() + .find(|(_, vals)| vals[0] == Value::Integer(2)) + .unwrap(); + assert_eq!(group2.1[0], Value::Integer(2)); // a=2 + assert_eq!(group2.1[1], Value::Integer(1)); // COUNT(*)=1 + } + + #[test] + fn test_aggregation_sum_with_filter() { + let schema = create_test_schema(); + let sql = "CREATE VIEW v AS SELECT SUM(b) FROM t WHERE a > 1"; + + let mut view = IncrementalView::from_sql(sql, &schema).unwrap(); + + assert!(view.aggregate_operator.is_some()); + assert!(view.filter_operator.is_some()); + + let mut delta = Delta::new(); + delta.insert( + 1, + vec![Value::Integer(1), Value::Integer(10), Value::Integer(100)], + ); + delta.insert( + 2, + vec![Value::Integer(2), Value::Integer(20), Value::Integer(200)], + ); + delta.insert( + 3, + vec![Value::Integer(3), 
Value::Integer(30), Value::Integer(300)], + ); + + view.merge_delta(&delta); + + // Should filter all 3 rows + assert_eq!(view.tracker.lock().unwrap().filter_evaluations, 3); + // But only aggregate the 2 that passed the filter (a > 1) + assert_eq!(view.tracker.lock().unwrap().aggregation_updates, 2); + + let results = view.current_data(None); + + // Should have 1 row with sum of b where a > 1 + assert_eq!(results.len(), 1); + assert_eq!(results[0].1[0], Value::Integer(50)); // SUM(b) = 20 + 30 + } + + #[test] + fn test_aggregation_incremental_updates() { + let schema = create_test_schema(); + let sql = "CREATE VIEW v AS SELECT a, COUNT(*), SUM(b) FROM t GROUP BY a"; + + let mut view = IncrementalView::from_sql(sql, &schema).unwrap(); + + // Initial insert + let mut delta1 = Delta::new(); + delta1.insert( + 1, + vec![Value::Integer(1), Value::Integer(10), Value::Integer(100)], + ); + delta1.insert( + 2, + vec![Value::Integer(1), Value::Integer(20), Value::Integer(200)], + ); + + view.merge_delta(&delta1); + + // Verify we processed exactly 2 rows for the first batch + assert_eq!(view.tracker.lock().unwrap().aggregation_updates, 2); + + // Check initial state + let results1 = view.current_data(None); + assert_eq!(results1.len(), 1); + assert_eq!(results1[0].1[1], Value::Integer(2)); // COUNT(*)=2 + assert_eq!(results1[0].1[2], Value::Integer(30)); // SUM(b)=30 + + // Reset counter to track second batch separately + view.tracker.lock().unwrap().aggregation_updates = 0; + + // Add more data + let mut delta2 = Delta::new(); + delta2.insert( + 3, + vec![Value::Integer(1), Value::Integer(5), Value::Integer(300)], + ); + delta2.insert( + 4, + vec![Value::Integer(2), Value::Integer(15), Value::Integer(400)], + ); + + view.merge_delta(&delta2); + + // Should only process the 2 new rows, not recompute everything + assert_eq!(view.tracker.lock().unwrap().aggregation_updates, 2); + + // Check updated state + let results2 = view.current_data(None); + assert_eq!(results2.len(), 
2); + + // Group a=1 + let group1 = results2 + .iter() + .find(|(_, vals)| vals[0] == Value::Integer(1)) + .unwrap(); + assert_eq!(group1.1[1], Value::Integer(3)); // COUNT(*)=3 + assert_eq!(group1.1[2], Value::Integer(35)); // SUM(b)=35 + + // Group a=2 + let group2 = results2 + .iter() + .find(|(_, vals)| vals[0] == Value::Integer(2)) + .unwrap(); + assert_eq!(group2.1[1], Value::Integer(1)); // COUNT(*)=1 + assert_eq!(group2.1[2], Value::Integer(15)); // SUM(b)=15 + } }