mirror of
https://github.com/aljazceru/turso.git
synced 2026-02-23 08:55:40 +01:00
Merge 'Implement Aggregations for DBSP views' from Glauber Costa
``` turso> create table t(a, b); turso> insert into t(a,b) values (2,2), (3,3); turso> insert into t(a,b) values (6,6), (7,7); turso> insert into t(a,b) values (6,6), (7,7); turso> create view tt as select b, sum(a) from t where b > 2 group by b; turso> select * from tt; ┌───┬─────────┐ │ b │ sum (a) │ ├───┼─────────┤ │ 3 │ 3 │ ├───┼─────────┤ │ 6 │ 12 │ ├───┼─────────┤ │ 7 │ 14 │ └───┴─────────┘ turso> insert into t(a,b) values (1,3); turso> select * from tt; ┌───┬─────────┐ │ b │ sum (a) │ ├───┼─────────┤ │ 3 │ 4 │ ├───┼─────────┤ │ 6 │ 12 │ ├───┼─────────┤ │ 7 │ 14 │ └───┴─────────┘ turso> ``` Closes #2547
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
use super::dbsp::{RowKeyStream, RowKeyZSet};
|
||||
use super::operator::{
|
||||
AggregateFunction, Delta, FilterOperator, FilterPredicate, ProjectColumn, ProjectOperator,
|
||||
AggregateFunction, AggregateOperator, ComputationTracker, Delta, FilterOperator,
|
||||
FilterPredicate, IncrementalOperator, ProjectColumn, ProjectOperator,
|
||||
};
|
||||
use crate::schema::{BTreeTable, Column, Schema};
|
||||
use crate::types::{IOResult, Value};
|
||||
@@ -9,7 +10,7 @@ use crate::{LimboError, Result, Statement};
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use std::collections::BTreeMap;
|
||||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use turso_sqlite3_parser::{
|
||||
ast::{Cmd, Stmt},
|
||||
lexer::sql::Parser,
|
||||
@@ -78,12 +79,18 @@ pub struct IncrementalView {
|
||||
filter_operator: Option<FilterOperator>,
|
||||
// Internal project operator for value transformation
|
||||
project_operator: Option<ProjectOperator>,
|
||||
// Internal aggregate operator for GROUP BY and aggregations
|
||||
aggregate_operator: Option<AggregateOperator>,
|
||||
// Tables referenced by this view (extracted from FROM clause and JOINs)
|
||||
base_table: Arc<BTreeTable>,
|
||||
// The view's output columns with their types
|
||||
pub columns: Vec<Column>,
|
||||
// State machine for population
|
||||
populate_state: PopulateState,
|
||||
// Computation tracker for statistics
|
||||
// We will use this one day to export rows_read, but for now, will just test that we're doing the expected amount of compute
|
||||
#[cfg_attr(not(test), allow(dead_code))]
|
||||
pub tracker: Arc<Mutex<ComputationTracker>>,
|
||||
}
|
||||
|
||||
impl IncrementalView {
|
||||
@@ -95,11 +102,6 @@ impl IncrementalView {
|
||||
) -> Result<()> {
|
||||
// Check for aggregations
|
||||
let (group_by_columns, aggregate_functions, _) = Self::extract_aggregation_info(select);
|
||||
if !group_by_columns.is_empty() || !aggregate_functions.is_empty() {
|
||||
return Err(LimboError::ParseError(
|
||||
"aggregations in views are not yet supported".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Check for JOINs
|
||||
let (join_tables, join_condition) = Self::extract_join_info(select);
|
||||
@@ -127,8 +129,10 @@ impl IncrementalView {
|
||||
.map(|(i, col)| col.name.clone().unwrap_or_else(|| format!("column_{i}")))
|
||||
.collect();
|
||||
|
||||
// Validate columns are a strict subset
|
||||
Self::validate_view_columns(select, &base_table_column_names)?;
|
||||
// For non-aggregated views, validate columns are a strict subset
|
||||
if group_by_columns.is_empty() && aggregate_functions.is_empty() {
|
||||
Self::validate_view_columns(select, &base_table_column_names)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -203,12 +207,6 @@ impl IncrementalView {
|
||||
let (group_by_columns, aggregate_functions, _old_output_names) =
|
||||
Self::extract_aggregation_info(&select);
|
||||
|
||||
if !group_by_columns.is_empty() || !aggregate_functions.is_empty() {
|
||||
return Err(LimboError::ParseError(
|
||||
"aggregations in views are not yet supported".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let (join_tables, join_condition) = Self::extract_join_info(&select);
|
||||
if join_tables.is_some() || join_condition.is_some() {
|
||||
return Err(LimboError::ParseError(
|
||||
@@ -238,7 +236,7 @@ impl IncrementalView {
|
||||
.map(|(i, col)| col.name.clone().unwrap_or_else(|| format!("column_{i}")))
|
||||
.collect();
|
||||
|
||||
Ok(Self::new(
|
||||
Self::new(
|
||||
name,
|
||||
Vec::new(), // Empty initial data
|
||||
where_predicate,
|
||||
@@ -246,9 +244,12 @@ impl IncrementalView {
|
||||
base_table,
|
||||
base_table_column_names,
|
||||
view_columns,
|
||||
))
|
||||
group_by_columns,
|
||||
aggregate_functions,
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
name: String,
|
||||
initial_data: Vec<(i64, Vec<Value>)>,
|
||||
@@ -257,7 +258,9 @@ impl IncrementalView {
|
||||
base_table: Arc<BTreeTable>,
|
||||
base_table_column_names: Vec<String>,
|
||||
columns: Vec<Column>,
|
||||
) -> Self {
|
||||
group_by_columns: Vec<String>,
|
||||
aggregate_functions: Vec<AggregateFunction>,
|
||||
) -> Result<Self> {
|
||||
let mut records = BTreeMap::new();
|
||||
|
||||
for (row_key, values) in initial_data {
|
||||
@@ -272,17 +275,37 @@ impl IncrementalView {
|
||||
zset.insert(row, 1);
|
||||
}
|
||||
|
||||
// Create the tracker that will be shared by all operators
|
||||
let tracker = Arc::new(Mutex::new(ComputationTracker::new()));
|
||||
|
||||
// Create filter operator if we have a predicate
|
||||
let filter_operator = if !matches!(where_predicate, FilterPredicate::None) {
|
||||
Some(FilterOperator::new(
|
||||
where_predicate.clone(),
|
||||
base_table_column_names.clone(),
|
||||
))
|
||||
let mut filter_op =
|
||||
FilterOperator::new(where_predicate.clone(), base_table_column_names.clone());
|
||||
filter_op.set_tracker(tracker.clone());
|
||||
Some(filter_op)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let project_operator = {
|
||||
// Check if this is an aggregated view
|
||||
let is_aggregated = !group_by_columns.is_empty() || !aggregate_functions.is_empty();
|
||||
|
||||
// Create aggregate operator if needed
|
||||
let aggregate_operator = if is_aggregated {
|
||||
let mut agg_op = AggregateOperator::new(
|
||||
group_by_columns,
|
||||
aggregate_functions,
|
||||
base_table_column_names.clone(),
|
||||
);
|
||||
agg_op.set_tracker(tracker.clone());
|
||||
Some(agg_op)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Only create project operator for non-aggregated views
|
||||
let project_operator = if !is_aggregated {
|
||||
let columns = Self::extract_project_columns(&select_stmt, &base_table_column_names)
|
||||
.unwrap_or_else(|| {
|
||||
// If we can't extract columns, default to projecting all columns
|
||||
@@ -291,13 +314,14 @@ impl IncrementalView {
|
||||
.map(|name| ProjectColumn::Column(name.to_string()))
|
||||
.collect()
|
||||
});
|
||||
Some(ProjectOperator::new(
|
||||
columns,
|
||||
base_table_column_names.clone(),
|
||||
))
|
||||
let mut proj_op = ProjectOperator::new(columns, base_table_column_names.clone());
|
||||
proj_op.set_tracker(tracker.clone());
|
||||
Some(proj_op)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Self {
|
||||
Ok(Self {
|
||||
stream: RowKeyStream::from_zset(zset),
|
||||
name,
|
||||
records,
|
||||
@@ -305,10 +329,12 @@ impl IncrementalView {
|
||||
select_stmt,
|
||||
filter_operator,
|
||||
project_operator,
|
||||
aggregate_operator,
|
||||
base_table,
|
||||
columns,
|
||||
populate_state: PopulateState::Start,
|
||||
}
|
||||
tracker,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn name(&self) -> &str {
|
||||
@@ -526,13 +552,15 @@ impl IncrementalView {
|
||||
stmt,
|
||||
rows_processed,
|
||||
} => {
|
||||
// Process rows in batches to allow for IO interruption
|
||||
// Collect rows into a delta batch
|
||||
let mut batch_delta = Delta::new();
|
||||
let mut batch_count = 0;
|
||||
|
||||
loop {
|
||||
if batch_count >= BATCH_SIZE {
|
||||
// Process this batch through the standard pipeline
|
||||
self.merge_delta(&batch_delta);
|
||||
// Yield control after processing a batch
|
||||
// The statement maintains its position, so we'll resume from here
|
||||
return Ok(IOResult::IO);
|
||||
}
|
||||
|
||||
@@ -560,38 +588,15 @@ impl IncrementalView {
|
||||
// Get all values except the rowid
|
||||
let values = all_values[..all_values.len() - 1].to_vec();
|
||||
|
||||
// Apply filter if we have one
|
||||
// Pure DBSP would ingest the entire stream and then apply filter operators.
|
||||
// However, for initial population, we adopt a hybrid approach where we filter at
|
||||
// the query result level for efficiency. This avoids reading millions of rows just
|
||||
// to filter them down to a few. We only do this optimization for filters, not for
|
||||
// other operators like projections or aggregations.
|
||||
// TODO: We should further optimize by pushing the filter into the SQL WHERE clause.
|
||||
|
||||
// Check filter first (we need to do this before accessing self mutably)
|
||||
let passes_filter =
|
||||
if let Some(ref filter_op) = self.filter_operator {
|
||||
filter_op.evaluate_predicate(&values)
|
||||
} else {
|
||||
true
|
||||
};
|
||||
|
||||
if passes_filter {
|
||||
// Store the row with its original rowid
|
||||
self.records.insert(rowid, values.clone());
|
||||
|
||||
// Update the ZSet stream with weight +1
|
||||
let mut delta = RowKeyZSet::new();
|
||||
use crate::incremental::hashable_row::HashableRow;
|
||||
let row = HashableRow::new(rowid, values);
|
||||
delta.insert(row, 1);
|
||||
self.stream.apply_delta(&delta);
|
||||
}
|
||||
// Add to batch delta - let merge_delta handle filtering and aggregation
|
||||
batch_delta.insert(rowid, values);
|
||||
|
||||
*rows_processed += 1;
|
||||
batch_count += 1;
|
||||
}
|
||||
crate::vdbe::StepResult::Done => {
|
||||
// Process any remaining rows in the batch
|
||||
self.merge_delta(&batch_delta);
|
||||
// All rows processed, move to Done state
|
||||
self.populate_state = PopulateState::Done;
|
||||
return Ok(IOResult::Done(()));
|
||||
@@ -600,8 +605,9 @@ impl IncrementalView {
|
||||
return Err(LimboError::Busy);
|
||||
}
|
||||
crate::vdbe::StepResult::IO => {
|
||||
// Process current batch before yielding
|
||||
self.merge_delta(&batch_delta);
|
||||
// The Statement needs to wait for IO
|
||||
// When we return here, the Statement maintains its position
|
||||
return Ok(IOResult::IO);
|
||||
}
|
||||
}
|
||||
@@ -903,28 +909,49 @@ impl IncrementalView {
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply filter operator to a delta if present
|
||||
fn apply_filter_to_delta(&mut self, delta: Delta) -> Delta {
|
||||
if let Some(ref mut filter_op) = self.filter_operator {
|
||||
filter_op.process_delta(delta)
|
||||
} else {
|
||||
delta
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply aggregation operator to a delta if this is an aggregated view
|
||||
fn apply_aggregation_to_delta(&mut self, delta: Delta) -> Delta {
|
||||
if let Some(ref mut agg_op) = self.aggregate_operator {
|
||||
agg_op.process_delta(delta)
|
||||
} else {
|
||||
delta
|
||||
}
|
||||
}
|
||||
|
||||
/// Merge a delta of changes into the view's current state
|
||||
pub fn merge_delta(&mut self, delta: &Delta) {
|
||||
// Create a Z-set of changes to apply to the stream
|
||||
// Early return if delta is empty
|
||||
if delta.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Apply operators in pipeline
|
||||
let mut current_delta = delta.clone();
|
||||
current_delta = self.apply_filter_to_delta(current_delta);
|
||||
current_delta = self.apply_aggregation_to_delta(current_delta);
|
||||
|
||||
// Update records and stream with the processed delta
|
||||
let mut zset_delta = RowKeyZSet::new();
|
||||
|
||||
// Apply the delta changes to the records
|
||||
for (row, weight) in &delta.changes {
|
||||
for (row, weight) in &current_delta.changes {
|
||||
if *weight > 0 {
|
||||
// Insert
|
||||
if self.apply_filter(&row.values) {
|
||||
self.records.insert(row.rowid, row.values.clone());
|
||||
zset_delta.insert(row.clone(), 1);
|
||||
}
|
||||
self.records.insert(row.rowid, row.values.clone());
|
||||
zset_delta.insert(row.clone(), 1);
|
||||
} else if *weight < 0 {
|
||||
// Delete
|
||||
if self.records.remove(&row.rowid).is_some() {
|
||||
zset_delta.insert(row.clone(), -1);
|
||||
}
|
||||
self.records.remove(&row.rowid);
|
||||
zset_delta.insert(row.clone(), -1);
|
||||
}
|
||||
}
|
||||
|
||||
// Apply all changes to the stream at once
|
||||
self.stream.apply_delta(&zset_delta);
|
||||
}
|
||||
}
|
||||
@@ -1206,4 +1233,165 @@ mod tests {
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_aggregation_count_with_group_by() {
|
||||
let schema = create_test_schema();
|
||||
let sql = "CREATE VIEW v AS SELECT a, COUNT(*) FROM t GROUP BY a";
|
||||
|
||||
let mut view = IncrementalView::from_sql(sql, &schema).unwrap();
|
||||
|
||||
// Verify the view has an aggregate operator
|
||||
assert!(view.aggregate_operator.is_some());
|
||||
|
||||
// Insert some test data
|
||||
let mut delta = Delta::new();
|
||||
delta.insert(
|
||||
1,
|
||||
vec![Value::Integer(1), Value::Integer(10), Value::Integer(100)],
|
||||
);
|
||||
delta.insert(
|
||||
2,
|
||||
vec![Value::Integer(2), Value::Integer(20), Value::Integer(200)],
|
||||
);
|
||||
delta.insert(
|
||||
3,
|
||||
vec![Value::Integer(1), Value::Integer(30), Value::Integer(300)],
|
||||
);
|
||||
|
||||
// Process the delta
|
||||
view.merge_delta(&delta);
|
||||
|
||||
// Verify we only processed the 3 rows we inserted
|
||||
assert_eq!(view.tracker.lock().unwrap().aggregation_updates, 3);
|
||||
|
||||
// Check the aggregated results
|
||||
let results = view.current_data(None);
|
||||
|
||||
// Should have 2 groups: a=1 with count=2, a=2 with count=1
|
||||
assert_eq!(results.len(), 2);
|
||||
|
||||
// Find the group with a=1
|
||||
let group1 = results
|
||||
.iter()
|
||||
.find(|(_, vals)| vals[0] == Value::Integer(1))
|
||||
.unwrap();
|
||||
assert_eq!(group1.1[0], Value::Integer(1)); // a=1
|
||||
assert_eq!(group1.1[1], Value::Integer(2)); // COUNT(*)=2
|
||||
|
||||
// Find the group with a=2
|
||||
let group2 = results
|
||||
.iter()
|
||||
.find(|(_, vals)| vals[0] == Value::Integer(2))
|
||||
.unwrap();
|
||||
assert_eq!(group2.1[0], Value::Integer(2)); // a=2
|
||||
assert_eq!(group2.1[1], Value::Integer(1)); // COUNT(*)=1
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_aggregation_sum_with_filter() {
|
||||
let schema = create_test_schema();
|
||||
let sql = "CREATE VIEW v AS SELECT SUM(b) FROM t WHERE a > 1";
|
||||
|
||||
let mut view = IncrementalView::from_sql(sql, &schema).unwrap();
|
||||
|
||||
assert!(view.aggregate_operator.is_some());
|
||||
assert!(view.filter_operator.is_some());
|
||||
|
||||
let mut delta = Delta::new();
|
||||
delta.insert(
|
||||
1,
|
||||
vec![Value::Integer(1), Value::Integer(10), Value::Integer(100)],
|
||||
);
|
||||
delta.insert(
|
||||
2,
|
||||
vec![Value::Integer(2), Value::Integer(20), Value::Integer(200)],
|
||||
);
|
||||
delta.insert(
|
||||
3,
|
||||
vec![Value::Integer(3), Value::Integer(30), Value::Integer(300)],
|
||||
);
|
||||
|
||||
view.merge_delta(&delta);
|
||||
|
||||
// Should filter all 3 rows
|
||||
assert_eq!(view.tracker.lock().unwrap().filter_evaluations, 3);
|
||||
// But only aggregate the 2 that passed the filter (a > 1)
|
||||
assert_eq!(view.tracker.lock().unwrap().aggregation_updates, 2);
|
||||
|
||||
let results = view.current_data(None);
|
||||
|
||||
// Should have 1 row with sum of b where a > 1
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].1[0], Value::Integer(50)); // SUM(b) = 20 + 30
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_aggregation_incremental_updates() {
|
||||
let schema = create_test_schema();
|
||||
let sql = "CREATE VIEW v AS SELECT a, COUNT(*), SUM(b) FROM t GROUP BY a";
|
||||
|
||||
let mut view = IncrementalView::from_sql(sql, &schema).unwrap();
|
||||
|
||||
// Initial insert
|
||||
let mut delta1 = Delta::new();
|
||||
delta1.insert(
|
||||
1,
|
||||
vec![Value::Integer(1), Value::Integer(10), Value::Integer(100)],
|
||||
);
|
||||
delta1.insert(
|
||||
2,
|
||||
vec![Value::Integer(1), Value::Integer(20), Value::Integer(200)],
|
||||
);
|
||||
|
||||
view.merge_delta(&delta1);
|
||||
|
||||
// Verify we processed exactly 2 rows for the first batch
|
||||
assert_eq!(view.tracker.lock().unwrap().aggregation_updates, 2);
|
||||
|
||||
// Check initial state
|
||||
let results1 = view.current_data(None);
|
||||
assert_eq!(results1.len(), 1);
|
||||
assert_eq!(results1[0].1[1], Value::Integer(2)); // COUNT(*)=2
|
||||
assert_eq!(results1[0].1[2], Value::Integer(30)); // SUM(b)=30
|
||||
|
||||
// Reset counter to track second batch separately
|
||||
view.tracker.lock().unwrap().aggregation_updates = 0;
|
||||
|
||||
// Add more data
|
||||
let mut delta2 = Delta::new();
|
||||
delta2.insert(
|
||||
3,
|
||||
vec![Value::Integer(1), Value::Integer(5), Value::Integer(300)],
|
||||
);
|
||||
delta2.insert(
|
||||
4,
|
||||
vec![Value::Integer(2), Value::Integer(15), Value::Integer(400)],
|
||||
);
|
||||
|
||||
view.merge_delta(&delta2);
|
||||
|
||||
// Should only process the 2 new rows, not recompute everything
|
||||
assert_eq!(view.tracker.lock().unwrap().aggregation_updates, 2);
|
||||
|
||||
// Check updated state
|
||||
let results2 = view.current_data(None);
|
||||
assert_eq!(results2.len(), 2);
|
||||
|
||||
// Group a=1
|
||||
let group1 = results2
|
||||
.iter()
|
||||
.find(|(_, vals)| vals[0] == Value::Integer(1))
|
||||
.unwrap();
|
||||
assert_eq!(group1.1[1], Value::Integer(3)); // COUNT(*)=3
|
||||
assert_eq!(group1.1[2], Value::Integer(35)); // SUM(b)=35
|
||||
|
||||
// Group a=2
|
||||
let group2 = results2
|
||||
.iter()
|
||||
.find(|(_, vals)| vals[0] == Value::Integer(2))
|
||||
.unwrap();
|
||||
assert_eq!(group2.1[1], Value::Integer(1)); // COUNT(*)=1
|
||||
assert_eq!(group2.1[2], Value::Integer(15)); // SUM(b)=15
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user