From f149b40e75685b352fb4df35b548e05ff19321bc Mon Sep 17 00:00:00 2001
From: Glauber Costa
Date: Tue, 16 Sep 2025 16:00:15 -0500
Subject: [PATCH] Implement JOINs in the DBSP circuit

This PR improves the DBSP circuit so that it handles the JOIN operator.

The JOIN operator exposes a weakness of our current model: we usually pass a list of columns between operators and find the right column by name when needed. But with JOINs, multiple tables can have columns with the same name. The operators will then find the wrong column (same name, different table) and produce incorrect results.

To fix this, we must do two things:

1) Change the Logical Plan so that it tracks table provenance.
2) Fix the aggregators so that they operate on indexes, not names.

For the aggregators, note that table provenance is the wrong abstraction: the aggregator is usually working with a logical table that is the result of previous nodes in the circuit, so we just need to be able to tell it which index in the column array it should use.
---
 core/incremental/aggregate_operator.rs |  273 +++--
 core/incremental/compiler.rs           | 1256 +++++++++++++++++++++++-
 core/incremental/operator.rs           |  130 +--
 core/translate/logical.rs              |  276 ++++--
 4 files changed, 1625 insertions(+), 310 deletions(-)

diff --git a/core/incremental/aggregate_operator.rs b/core/incremental/aggregate_operator.rs index f4c8ece0a..9f25a84f5 100644 --- a/core/incremental/aggregate_operator.rs +++ b/core/incremental/aggregate_operator.rs @@ -19,20 +19,20 @@ pub const AGG_TYPE_MINMAX: u8 = 0b01; // MIN/MAX (BTree ordering gives both) #[derive(Debug, Clone, PartialEq)] pub enum AggregateFunction { Count, - Sum(String), - Avg(String), - Min(String), - Max(String), + Sum(usize), // Column index + Avg(usize), // Column index + Min(usize), // Column index + Max(usize), // Column index } impl Display for AggregateFunction { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { AggregateFunction::Count => write!(f, "COUNT(*)"), - AggregateFunction::Sum(col) => write!(f, "SUM({col})"), - AggregateFunction::Avg(col) => write!(f, "AVG({col})"), - AggregateFunction::Min(col) => write!(f, "MIN({col})"), - AggregateFunction::Max(col) => write!(f, "MAX({col})"), + AggregateFunction::Sum(idx) => write!(f, "SUM(col{idx})"), + AggregateFunction::Avg(idx) => write!(f, "AVG(col{idx})"), + AggregateFunction::Min(idx) => write!(f, "MIN(col{idx})"), + AggregateFunction::Max(idx) => write!(f, "MAX(col{idx})"), } } } @@ -48,16 +48,16 @@ impl AggregateFunction { /// Returns None if the function is not a supported aggregate pub fn from_sql_function( func: &crate::function::Func, - input_column: Option<String>, + input_column_idx: Option<usize>, ) -> Option<Self> { match func { Func::Agg(agg_func) => { match agg_func { AggFunc::Count | AggFunc::Count0 => Some(AggregateFunction::Count), - AggFunc::Sum => input_column.map(AggregateFunction::Sum), - AggFunc::Avg => input_column.map(AggregateFunction::Avg), - AggFunc::Min => input_column.map(AggregateFunction::Min), - AggFunc::Max => input_column.map(AggregateFunction::Max), + AggFunc::Sum => input_column_idx.map(AggregateFunction::Sum), + AggFunc::Avg => input_column_idx.map(AggregateFunction::Avg), + AggFunc::Min => input_column_idx.map(AggregateFunction::Min), + AggFunc::Max => input_column_idx.map(AggregateFunction::Max), _ => None, // Other aggregate functions not yet supported in DBSP } } @@ -115,8 +115,8 @@ pub fn deserialize_value(blob: &[u8]) -> Option<(Value, usize)> { // group_key_str -> (group_key, state) type ComputedStates = HashMap<String, (Vec<Value>, AggregateState)>;
-// group_key_str -> (column_name, value_as_hashable_row) -> accumulated_weight -pub type MinMaxDeltas = HashMap<String, HashMap<(String, HashableRow), isize>>; +// group_key_str -> (column_index, value_as_hashable_row) -> accumulated_weight +pub type MinMaxDeltas = HashMap<String, HashMap<(usize, HashableRow), isize>>; #[derive(Debug)] enum AggregateCommitState { @@ -178,14 +178,14 @@ pub enum AggregateEvalState { pub struct AggregateOperator { // Unique operator ID for indexing in persistent storage pub operator_id: usize, - // GROUP BY columns - group_by: Vec<String>, + // GROUP BY column indices + group_by: Vec<usize>, // Aggregate functions to compute (including MIN/MAX) pub aggregates: Vec<AggregateFunction>, // Column names from input pub input_column_names: Vec<String>, - // Map from column name to aggregate info for quick lookup - pub column_min_max: HashMap<String, AggColumnInfo>, + // Map from column index to aggregate info for quick lookup + pub column_min_max: HashMap<usize, AggColumnInfo>, tracker: Option>>, // State machine for commit operation @@ -197,14 +197,14 @@ pub struct AggregateState { // For COUNT: just the count pub count: i64, - // For SUM: column_name -> sum value - sums: HashMap<String, f64>, - // For AVG: column_name -> (sum, count) for computing average - avgs: HashMap<String, (f64, i64)>, - // For MIN: column_name -> minimum value - pub mins: HashMap<String, Value>, - // For MAX: column_name -> maximum value - pub maxs: HashMap<String, Value>, + // For SUM: column_index -> sum value + sums: HashMap<usize, f64>, + // For AVG: column_index -> (sum, count) for computing average + avgs: HashMap<usize, (f64, i64)>, + // For MIN: column_index -> minimum value + pub mins: HashMap<usize, Value>, + // For MAX: column_index -> maximum value + pub maxs: HashMap<usize, Value>, } impl AggregateEvalState { @@ -520,14 +520,14 @@ impl AggregateState { AggregateFunction::Sum(col_name) => { let sum = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); cursor += 8; - state.sums.insert(col_name.clone(), sum); + state.sums.insert(*col_name, sum); } AggregateFunction::Avg(col_name) => { let sum = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); cursor += 8; let count = i64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); cursor += 8; - state.avgs.insert(col_name.clone(), (sum, count)); + state.avgs.insert(*col_name, (sum, count)); } AggregateFunction::Count => { // Count was already read above } @@ -540,7 +540,7 @@ if has_value == 1 { let (min_value, bytes_consumed) = deserialize_value(&blob[cursor..])?; cursor += bytes_consumed; - state.mins.insert(col_name.clone(), min_value); + state.mins.insert(*col_name, min_value); } } AggregateFunction::Max(col_name) => { @@ -551,7 +551,7 @@ if has_value == 1 { let (max_value, bytes_consumed) = deserialize_value(&blob[cursor..])?; cursor += bytes_consumed; - state.maxs.insert(col_name.clone(), max_value); + state.maxs.insert(*col_name, max_value); } } } @@ -566,7 +566,7 @@ values: &[Value], weight: isize, aggregates: &[AggregateFunction], - column_names: &[String], + _column_names: &[String], // No longer needed ) { // Update COUNT self.count += weight as i64; @@ -577,32 +577,26 @@ AggregateFunction::Count => { // Already handled above } - AggregateFunction::Sum(col_name) => { - if let Some(idx) = column_names.iter().position(|c| c == col_name) { - if let Some(val) = values.get(idx) { - let num_val = match val { - Value::Integer(i) => *i as f64, - Value::Float(f) => *f, - _ => 0.0, - }; - *self.sums.entry(col_name.clone()).or_insert(0.0) += - num_val * weight as f64; - } + AggregateFunction::Sum(col_idx) => { + if let Some(val) = values.get(*col_idx) {
+ let num_val = match val { + Value::Integer(i) => *i as f64, + Value::Float(f) => *f, + _ => 0.0, + }; + *self.sums.entry(*col_idx).or_insert(0.0) += num_val * weight as f64; } } - AggregateFunction::Avg(col_name) => { - if let Some(idx) = column_names.iter().position(|c| c == col_name) { - if let Some(val) = values.get(idx) { - let num_val = match val { - Value::Integer(i) => *i as f64, - Value::Float(f) => *f, - _ => 0.0, - }; - let (sum, count) = - self.avgs.entry(col_name.clone()).or_insert((0.0, 0)); - *sum += num_val * weight as f64; - *count += weight as i64; - } + AggregateFunction::Avg(col_idx) => { + if let Some(val) = values.get(*col_idx) { + let num_val = match val { + Value::Integer(i) => *i as f64, + Value::Float(f) => *f, + _ => 0.0, + }; + let (sum, count) = self.avgs.entry(*col_idx).or_insert((0.0, 0)); + *sum += num_val * weight as f64; + *count += weight as i64; } } AggregateFunction::Min(_col_name) | AggregateFunction::Max(_col_name) => { @@ -644,8 +638,8 @@ impl AggregateState { AggregateFunction::Count => { result.push(Value::Integer(self.count)); } - AggregateFunction::Sum(col_name) => { - let sum = self.sums.get(col_name).copied().unwrap_or(0.0); + AggregateFunction::Sum(col_idx) => { + let sum = self.sums.get(col_idx).copied().unwrap_or(0.0); // Return as integer if it's a whole number, otherwise as float if sum.fract() == 0.0 { result.push(Value::Integer(sum as i64)); } else { result.push(Value::Float(sum)); } } - AggregateFunction::Avg(col_name) => { - if let Some((sum, count)) = self.avgs.get(col_name) { + AggregateFunction::Avg(col_idx) => { + if let Some((sum, count)) = self.avgs.get(col_idx) { if *count > 0 { result.push(Value::Float(sum / *count as f64)); } else { result.push(Value::Null); } } else { result.push(Value::Null); } } - AggregateFunction::Min(col_name) => { + AggregateFunction::Min(col_idx) => { // Return the MIN value from our state - result.push(self.mins.get(col_name).cloned().unwrap_or(Value::Null)); + result.push(self.mins.get(col_idx).cloned().unwrap_or(Value::Null)); } - AggregateFunction::Max(col_name) => { + AggregateFunction::Max(col_idx) => { // Return the MAX value from our state - result.push(self.maxs.get(col_name).cloned().unwrap_or(Value::Null)); + result.push(self.maxs.get(col_idx).cloned().unwrap_or(Value::Null)); } } } @@ -682,20 +676,20 @@ impl AggregateState { impl AggregateOperator { pub fn new( operator_id: usize, - group_by: Vec<String>, + group_by: Vec<usize>, aggregates: Vec<AggregateFunction>, input_column_names: Vec<String>, ) -> Self { - // Build map of column names to their MIN/MAX info with indices + // Build map of column indices to their MIN/MAX info let mut column_min_max = HashMap::new(); - let mut column_indices = HashMap::new(); + let mut storage_indices = HashMap::new(); let mut current_index = 0; - // First pass: assign indices to unique MIN/MAX columns + // First pass: assign storage indices to unique MIN/MAX columns for agg in &aggregates { match agg { - AggregateFunction::Min(col) | AggregateFunction::Max(col) => { - column_indices.entry(col.clone()).or_insert_with(|| { + AggregateFunction::Min(col_idx) | AggregateFunction::Max(col_idx) => { + storage_indices.entry(*col_idx).or_insert_with(|| { let idx = current_index; current_index += 1; idx }); } _ => {} } } // Second pass: build the column info map for agg in &aggregates { match agg { - AggregateFunction::Min(col) => { - let index = *column_indices.get(col).unwrap(); - let entry =
column_min_max.entry(col.clone()).or_insert(AggColumnInfo { - index, + AggregateFunction::Min(col_idx) => { + let storage_index = *storage_indices.get(col_idx).unwrap(); + let entry = column_min_max.entry(*col_idx).or_insert(AggColumnInfo { + index: storage_index, has_min: false, has_max: false, }); entry.has_min = true; } - AggregateFunction::Max(col) => { - let index = *column_indices.get(col).unwrap(); - let entry = column_min_max.entry(col.clone()).or_insert(AggColumnInfo { - index, + AggregateFunction::Max(col_idx) => { + let storage_index = *storage_indices.get(col_idx).unwrap(); + let entry = column_min_max.entry(*col_idx).or_insert(AggColumnInfo { + index: storage_index, has_min: false, has_max: false, }); entry.has_max = true; } @@ -876,28 +870,24 @@ impl AggregateOperator { for agg in &self.aggregates { match agg { - AggregateFunction::Min(col_name) | AggregateFunction::Max(col_name) => { - if let Some(idx) = - self.input_column_names.iter().position(|c| c == col_name) - { - if let Some(val) = row.values.get(idx) { - // Skip NULL values - they don't participate in MIN/MAX - if val == &Value::Null { - continue; - } - // Create a HashableRow with just this value - // Use 0 as rowid since we only care about the value for comparison - let hashable_value = HashableRow::new(0, vec![val.clone()]); - let key = (col_name.clone(), hashable_value); - - let group_entry = - min_max_deltas.entry(group_key_str.clone()).or_default(); - - let value_entry = group_entry.entry(key).or_insert(0); - - // Accumulate the weight - *value_entry += weight; + AggregateFunction::Min(col_idx) | AggregateFunction::Max(col_idx) => { + if let Some(val) = row.values.get(*col_idx) { + // Skip NULL values - they don't participate in MIN/MAX + if val == &Value::Null { + continue; } + // Create a HashableRow with just this value + // Use 0 as rowid since we only care about the value for comparison + let hashable_value = HashableRow::new(0, vec![val.clone()]); + let key = (*col_idx, hashable_value); + + let group_entry = + min_max_deltas.entry(group_key_str.clone()).or_default(); + + let value_entry = group_entry.entry(key).or_insert(0); + + // Accumulate the weight + *value_entry += weight; } } _ => {} // Ignore non-MIN/MAX aggregates @@ -929,13 +919,9 @@ pub fn extract_group_key(&self, values: &[Value]) -> Vec<Value> { let mut key = Vec::new(); - for group_col in &self.group_by { - if let Some(idx) = self.input_column_names.iter().position(|c| c == group_col) { - if let Some(val) = values.get(idx) { - key.push(val.clone()); - } else { - key.push(Value::Null); - } + for &idx in &self.group_by { + if let Some(val) = values.get(idx) { + key.push(val.clone()); } else { key.push(Value::Null); } } @@ -1124,13 +1110,13 @@ pub enum RecomputeMinMax { /// Current column being processed current_column_idx: usize, /// Columns to process (combined MIN and MAX) - columns_to_process: Vec<(String, String, bool)>, // (group_key, column_name, is_min) + columns_to_process: Vec<(String, usize, bool)>, // (group_key, column_index, is_min) /// MIN/MAX deltas for checking values and weights min_max_deltas: MinMaxDeltas, }, Scan { /// Columns still to process - columns_to_process: Vec<(String, String, bool)>, + columns_to_process: Vec<(String, usize, bool)>, /// Current index in columns_to_process (will resume from here) current_column_idx: usize, /// MIN/MAX deltas for checking values and weights min_max_deltas: MinMaxDeltas, /// Current group key being processed group_key: String, /// Current column name being processed -
column_name: String, + column_name: usize, /// Whether we're looking for MIN (true) or MAX (false) is_min: bool, /// The scan state machine for finding the new MIN/MAX scan_state: Box<ScanState>, }, Done, } impl RecomputeMinMax { pub fn new( min_max_deltas: MinMaxDeltas, existing_groups: &HashMap<String, AggregateState>, operator: &AggregateOperator, ) -> Self { - let mut groups_to_check: HashSet<(String, String, bool)> = HashSet::new(); + let mut groups_to_check: HashSet<(String, usize, bool)> = HashSet::new(); // Remember the min_max_deltas are essentially just the only column that is affected by // this min/max, in delta (actually ZSet - consolidated delta) format. This makes it easier @@ // Check for MIN if let Some(current_min) = state.mins.get(col_name) { if current_min == value { - groups_to_check.insert(( - group_key_str.clone(), - col_name.clone(), - true, - )); + groups_to_check.insert((group_key_str.clone(), *col_name, true)); } } // Check for MAX if let Some(current_max) = state.maxs.get(col_name) { if current_max == value { - groups_to_check.insert(( - group_key_str.clone(), - col_name.clone(), - false, - )); + groups_to_check.insert((group_key_str.clone(), *col_name, false)); } } } @@ // about this if this is a new record being inserted if let Some(info) = col_info { if info.has_min { - groups_to_check.insert((group_key_str.clone(), col_name.clone(), true)); + groups_to_check.insert((group_key_str.clone(), *col_name, true)); } if info.has_max { - groups_to_check.insert(( - group_key_str.clone(), - col_name.clone(), - false, - )); + groups_to_check.insert((group_key_str.clone(), *col_name, false)); } } } @@ -1245,12 +1219,13 @@ let (group_key, column_name, is_min) = columns_to_process[*current_column_idx].clone(); - // Get column index from pre-computed info - let column_index = operator + // The column is already identified by its index + // Get the storage index from the column_min_max map + let column_info = operator .column_min_max .get(&column_name) - .map(|info| info.index) - .unwrap(); // Should always exist since we're processing known columns + .expect("Column should exist in column_min_max map"); + let storage_index = column_info.index; // Get current value from existing state let current_value = existing_groups.get(&group_key).and_then(|state| { @@ // Create storage keys for index lookup let storage_id = - generate_storage_id(operator.operator_id, column_index, AGG_TYPE_MINMAX); + generate_storage_id(operator.operator_id, storage_index, AGG_TYPE_MINMAX); let zset_id = operator.generate_group_rowid(&group_key); // Get the values for this group from min_max_deltas @@ Box::new(ScanState::new_for_min( current_value, group_key.clone(), - column_name.clone(), + column_name, storage_id, zset_id, group_values, )) @@ Box::new(ScanState::new_for_max( current_value, group_key.clone(), - column_name.clone(), + column_name, storage_id, zset_id, group_values, )) @@ -1319,12 +1294,12 @@ if *is_min { if let Some(min_val) = new_value { - state.mins.insert(column_name.clone(), min_val); + state.mins.insert(*column_name, min_val); } else { state.mins.remove(column_name); } } else if let Some(max_val) = new_value { - state.maxs.insert(column_name.clone(), max_val); + state.maxs.insert(*column_name, max_val); } else { state.maxs.remove(column_name); } @@ -1355,13 +1330,13 @@ pub enum ScanState { /// Group key being processed group_key:
String, /// Column name being processed - column_name: String, + column_name: usize, /// Storage ID for the index seek storage_id: i64, /// ZSet ID for the group zset_id: i64, /// Group values from MinMaxDeltas: (column_name, HashableRow) -> weight - group_values: HashMap<(String, HashableRow), isize>, + group_values: HashMap<(usize, HashableRow), isize>, /// Whether we're looking for MIN (true) or MAX (false) is_min: bool, }, @@ -1371,13 +1346,13 @@ /// Group key being processed group_key: String, /// Column name being processed - column_name: String, + column_name: usize, /// Storage ID for the index seek storage_id: i64, /// ZSet ID for the group zset_id: i64, /// Group values from MinMaxDeltas: (column_name, HashableRow) -> weight - group_values: HashMap<(String, HashableRow), isize>, + group_values: HashMap<(usize, HashableRow), isize>, /// Whether we're looking for MIN (true) or MAX (false) is_min: bool, }, @@ -1391,10 +1366,10 @@ impl ScanState { pub fn new_for_min( current_min: Option<Value>, group_key: String, - column_name: String, + column_name: usize, storage_id: i64, zset_id: i64, - group_values: HashMap<(String, HashableRow), isize>, + group_values: HashMap<(usize, HashableRow), isize>, ) -> Self { Self::CheckCandidate { candidate: current_min, @@ -1460,10 +1435,10 @@ pub fn new_for_max( current_max: Option<Value>, group_key: String, - column_name: String, + column_name: usize, storage_id: i64, zset_id: i64, - group_values: HashMap<(String, HashableRow), isize>, + group_values: HashMap<(usize, HashableRow), isize>, ) -> Self { Self::CheckCandidate { candidate: current_max, @@ -1496,7 +1471,7 @@ // Check if the candidate is retracted (weight <= 0) // Create a HashableRow to look up the weight let hashable_cand = HashableRow::new(0, vec![cand_val.clone()]); - let key = (column_name.clone(), hashable_cand); + let key = (*column_name, hashable_cand); let is_retracted = group_values.get(&key).is_some_and(|weight| *weight <= 0); @@ -1633,7 +1608,7 @@ pub enum MinMaxPersistState { group_idx: usize, value_idx: usize, value: Value, - column_name: String, + column_name: usize, weight: isize, write_row: WriteRow, }, @@ -1652,7 +1627,7 @@ impl MinMaxPersistState { pub fn persist_min_max( &mut self, operator_id: usize, - column_min_max: &HashMap<String, AggColumnInfo>, + column_min_max: &HashMap<usize, AggColumnInfo>, cursors: &mut DbspStateCursors, generate_group_rowid: impl Fn(&str) -> i64, ) -> Result> { @@ -1699,7 +1674,7 @@ // Process current value and extract what we need before taking ownership let ((column_name, hashable_row), weight) = values_vec[*value_idx]; - let column_name = column_name.clone(); + let column_name = *column_name; let value = hashable_row.values[0].clone(); // Extract the Value from HashableRow let weight = *weight; @@ -1731,9 +1706,9 @@ let group_key_str = &group_keys[*group_idx]; - // Get the column index from the pre-computed map + // Get the column info from the pre-computed map let column_info = column_min_max - .get(&*column_name) + .get(column_name) .expect("Column should exist in column_min_max map"); let column_index = column_info.index; diff --git a/core/incremental/compiler.rs b/core/incremental/compiler.rs index 972d6797b..c8899a02e 100644 --- a/core/incremental/compiler.rs +++ b/core/incremental/compiler.rs @@ -9,12 +9,12 @@ use crate::incremental::dbsp::{Delta, DeltaPair}; use crate::incremental::expr_compiler::CompiledExpression; use crate::incremental::operator::{ create_dbsp_state_index,
DbspStateCursors, EvalState, FilterOperator, FilterPredicate, - IncrementalOperator, InputOperator, ProjectOperator, + IncrementalOperator, InputOperator, JoinOperator, JoinType, ProjectOperator, }; use crate::storage::btree::{BTreeCursor, BTreeKey}; // Note: logical module must be made pub(crate) in translate/mod.rs use crate::translate::logical::{ - BinaryOperator, LogicalExpr, LogicalPlan, LogicalSchema, SchemaRef, + BinaryOperator, JoinType as LogicalJoinType, LogicalExpr, LogicalPlan, LogicalSchema, SchemaRef, }; use crate::types::{IOResult, ImmutableRecord, SeekKey, SeekOp, SeekResult, Value}; use crate::Pager; @@ -288,6 +288,12 @@ pub enum DbspOperator { aggr_exprs: Vec, schema: SchemaRef, }, + /// Join operator (⋈) - joins two relations + Join { + join_type: JoinType, + on_exprs: Vec<(DbspExpr, DbspExpr)>, + schema: SchemaRef, + }, /// Input operator - source of data Input { name: String, schema: SchemaRef }, } @@ -789,6 +795,13 @@ impl DbspCircuit { "{indent}Aggregate[{node_id}]: GROUP BY {group_exprs:?}, AGGR {aggr_exprs:?}" )?; } + DbspOperator::Join { + join_type, + on_exprs, + .. + } => { + writeln!(f, "{indent}Join[{node_id}]: {join_type:?} ON {on_exprs:?}")?; + } DbspOperator::Input { name, .. } => { writeln!(f, "{indent}Input[{node_id}]: {name}")?; } @@ -841,7 +854,7 @@ impl DbspCompiler { // Get input column names for the ProjectOperator let input_schema = proj.input.schema(); let input_column_names: Vec<String> = input_schema.columns.iter() - .map(|(name, _)| name.clone()) + .map(|col| col.name.clone()) .collect(); // Convert logical expressions to DBSP expressions @@ -853,14 +866,14 @@ let mut compiled_exprs = Vec::new(); let mut aliases = Vec::new(); for expr in &proj.exprs { - let (compiled, alias) = Self::compile_expression(expr, &input_column_names)?; + let (compiled, alias) = Self::compile_expression(expr, input_schema)?; compiled_exprs.push(compiled); aliases.push(alias); } // Get output column names from the projection schema let output_column_names: Vec<String> = proj.schema.columns.iter() - .map(|(name, _)| name.clone()) + .map(|col| col.name.clone()) .collect(); // Create the ProjectOperator @@ -885,7 +898,7 @@ // Get column names from input schema let input_schema = filter.input.schema(); let column_names: Vec<String> = input_schema.columns.iter() - .map(|(name, _)| name.clone()) + .map(|col| col.name.clone()) .collect(); // Convert predicate to DBSP expression @@ -913,16 +926,21 @@ // Get input column names let input_schema = agg.input.schema(); let input_column_names: Vec<String> = input_schema.columns.iter() - .map(|(name, _)| name.clone()) + .map(|col| col.name.clone()) .collect(); - // Compile group by expressions to column names - let mut group_by_columns = Vec::new(); + // Compile group by expressions to column indices + let mut group_by_indices = Vec::new(); let mut dbsp_group_exprs = Vec::new(); for expr in &agg.group_expr { // For now, only support simple column references in GROUP BY if let LogicalExpr::Column(col) = expr { - group_by_columns.push(col.name.clone()); + // Find the column index in the input schema using qualified lookup + let (col_idx, _) = input_schema.find_column(&col.name, col.table.as_deref()) + .ok_or_else(|| LimboError::ParseError( + format!("GROUP BY column '{}' not found in input", col.name) + ))?; + group_by_indices.push(col_idx); dbsp_group_exprs.push(DbspExpr::Column(col.name.clone())); } else { return Err(LimboError::ParseError( @@ -936,7 +954,7 @@ for expr in &agg.aggr_expr
{ if let LogicalExpr::AggregateFunction { fun, args, .. } = expr { use crate::function::AggFunc; - use crate::incremental::operator::AggregateFunction; + use crate::incremental::aggregate_operator::AggregateFunction; match fun { AggFunc::Count | AggFunc::Count0 => { @@ -946,9 +964,13 @@ if args.is_empty() { return Err(LimboError::ParseError("SUM requires an argument".to_string())); } - // Extract column name from the argument + // Extract column index from the argument if let LogicalExpr::Column(col) = &args[0] { - aggregate_functions.push(AggregateFunction::Sum(col.name.clone())); + let (col_idx, _) = input_schema.find_column(&col.name, col.table.as_deref()) + .ok_or_else(|| LimboError::ParseError( + format!("SUM column '{}' not found in input", col.name) + ))?; + aggregate_functions.push(AggregateFunction::Sum(col_idx)); } else { return Err(LimboError::ParseError( "Only column references are supported in aggregate functions for incremental views".to_string() @@ -960,7 +982,11 @@ return Err(LimboError::ParseError("AVG requires an argument".to_string())); } if let LogicalExpr::Column(col) = &args[0] { - aggregate_functions.push(AggregateFunction::Avg(col.name.clone())); + let (col_idx, _) = input_schema.find_column(&col.name, col.table.as_deref()) + .ok_or_else(|| LimboError::ParseError( + format!("AVG column '{}' not found in input", col.name) + ))?; + aggregate_functions.push(AggregateFunction::Avg(col_idx)); } else { return Err(LimboError::ParseError( "Only column references are supported in aggregate functions for incremental views".to_string() @@ -972,7 +998,11 @@ return Err(LimboError::ParseError("MIN requires an argument".to_string())); } if let LogicalExpr::Column(col) = &args[0] { - aggregate_functions.push(AggregateFunction::Min(col.name.clone())); + let (col_idx, _) = input_schema.find_column(&col.name, col.table.as_deref()) + .ok_or_else(|| LimboError::ParseError( + format!("MIN column '{}' not found in input", col.name) + ))?; + aggregate_functions.push(AggregateFunction::Min(col_idx)); } else { return Err(LimboError::ParseError( "Only column references are supported in MIN for incremental views".to_string() @@ -984,7 +1014,11 @@ return Err(LimboError::ParseError("MAX requires an argument".to_string())); } if let LogicalExpr::Column(col) = &args[0] { - aggregate_functions.push(AggregateFunction::Max(col.name.clone())); + let (col_idx, _) = input_schema.find_column(&col.name, col.table.as_deref()) + .ok_or_else(|| LimboError::ParseError( + format!("MAX column '{}' not found in input", col.name) + ))?; + aggregate_functions.push(AggregateFunction::Max(col_idx)); } else { return Err(LimboError::ParseError( "Only column references are supported in MAX for incremental views".to_string() @@ -1006,10 +1040,10 @@ let operator_id = self.circuit.next_id; - use crate::incremental::operator::AggregateOperator; + use crate::incremental::aggregate_operator::AggregateOperator; let executable: Box<dyn IncrementalOperator> = Box::new(AggregateOperator::new( operator_id, - group_by_columns.clone(), + group_by_indices.clone(), aggregate_functions.clone(), input_column_names.clone(), )); @@ -1026,6 +1060,90 @@ Ok(result_node_id) } + LogicalPlan::Join(join) => { + // Compile left and right inputs + let left_id = self.compile_plan(&join.left)?; + let right_id = self.compile_plan(&join.right)?; + + // Get schemas from inputs + let left_schema = join.left.schema(); + let right_schema = join.right.schema();
+ + // Get column names from left and right + let left_columns: Vec<String> = left_schema.columns.iter() + .map(|col| col.name.clone()) + .collect(); + let right_columns: Vec<String> = right_schema.columns.iter() + .map(|col| col.name.clone()) + .collect(); + + // Extract join key indices from join conditions + // For now, we only support equijoin conditions + let mut left_key_indices = Vec::new(); + let mut right_key_indices = Vec::new(); + let mut dbsp_on_exprs = Vec::new(); + + for (left_expr, right_expr) in &join.on { + // Extract column indices from join expressions + // We expect simple column references in join conditions + if let (LogicalExpr::Column(left_col), LogicalExpr::Column(right_col)) = (left_expr, right_expr) { + // Find indices in respective schemas using qualified lookup + let (left_idx, _) = left_schema.find_column(&left_col.name, left_col.table.as_deref()) + .ok_or_else(|| LimboError::ParseError( + format!("Join column '{}' not found in left input", left_col.name) + ))?; + let (right_idx, _) = right_schema.find_column(&right_col.name, right_col.table.as_deref()) + .ok_or_else(|| LimboError::ParseError( + format!("Join column '{}' not found in right input", right_col.name) + ))?; + + left_key_indices.push(left_idx); + right_key_indices.push(right_idx); + + // Convert to DBSP expressions + dbsp_on_exprs.push(( + DbspExpr::Column(left_col.name.clone()), + DbspExpr::Column(right_col.name.clone()) + )); + } else { + return Err(LimboError::ParseError( + "Only simple column references are supported in join conditions for incremental views".to_string() + )); + } + } + + // Convert logical join type to operator join type + let operator_join_type = match join.join_type { + LogicalJoinType::Inner => JoinType::Inner, + LogicalJoinType::Left => JoinType::Left, + LogicalJoinType::Right => JoinType::Right, + LogicalJoinType::Full => JoinType::Full, + LogicalJoinType::Cross => JoinType::Cross, + }; + + // Create JoinOperator + let operator_id = self.circuit.next_id; + let executable: Box<dyn IncrementalOperator> = Box::new(JoinOperator::new( + operator_id, + operator_join_type.clone(), + left_key_indices, + right_key_indices, + left_columns, + right_columns, + )?); + + // Create join node + let node_id = self.circuit.add_node( + DbspOperator::Join { + join_type: operator_join_type, + on_exprs: dbsp_on_exprs, + schema: join.schema.clone(), + }, + vec![left_id, right_id], + executable, + ); + Ok(node_id) + } LogicalPlan::TableScan(scan) => { // Create input node with InputOperator for uniform handling let executable: Box<dyn IncrementalOperator> = @@ Ok(node_id) } _ => Err(LimboError::ParseError( - format!("Unsupported operator in DBSP compiler: only Filter, Projection and Aggregate are supported, got: {:?}", + format!("Unsupported operator in DBSP compiler: only Filter, Projection, Join and Aggregate are supported, got: {:?}", match plan { LogicalPlan::Sort(_) => "Sort", LogicalPlan::Limit(_) => "Limit", @@ -1095,17 +1213,24 @@ /// Compile a logical expression to a CompiledExpression and optional alias fn compile_expression( expr: &LogicalExpr, - input_column_names: &[String], + input_schema: &LogicalSchema, ) -> Result<(CompiledExpression, Option<String>)> { // Check for alias first if let LogicalExpr::Alias { expr, alias } = expr { // For aliases, compile the underlying expression and return with alias - let (compiled, _) = Self::compile_expression(expr, input_column_names)?; + let (compiled, _) = Self::compile_expression(expr, input_schema)?; return Ok((compiled, Some(alias.clone()))); } - //
Convert LogicalExpr to AST Expr - let ast_expr = Self::logical_to_ast_expr(expr)?; + // Convert LogicalExpr to AST Expr with proper column resolution + let ast_expr = Self::logical_to_ast_expr_with_schema(expr, input_schema)?; + + // Extract column names from schema for CompiledExpression::compile + let input_column_names: Vec<String> = input_schema + .columns + .iter() + .map(|col| col.name.clone()) + .collect(); // For all expressions (simple or complex), use CompiledExpression::compile // This handles both trivial cases and complex VDBE compilation @@ // Compile the expression using the existing CompiledExpression::compile let compiled = CompiledExpression::compile( &ast_expr, - input_column_names, + &input_column_names, &schema, &temp_syms, internal_conn, )?; Ok((compiled, None)) } - /// Convert LogicalExpr to AST Expr - fn logical_to_ast_expr(expr: &LogicalExpr) -> Result<ast::Expr> { + /// Convert LogicalExpr to AST Expr with qualified column resolution + fn logical_to_ast_expr_with_schema( + expr: &LogicalExpr, + schema: &LogicalSchema, + ) -> Result<ast::Expr> { use turso_parser::ast; match expr { - LogicalExpr::Column(col) => Ok(ast::Expr::Id(ast::Name::Ident(col.name.clone()))), + LogicalExpr::Column(col) => { + // Find the column index using qualified lookup + let (idx, _) = schema + .find_column(&col.name, col.table.as_deref()) + .ok_or_else(|| { + LimboError::ParseError(format!( + "Column '{}' with table {:?} not found in schema", + col.name, col.table + )) + })?; + // Return a Register expression with the correct index + Ok(ast::Expr::Register(idx)) + } LogicalExpr::Literal(val) => { let lit = match val { Value::Integer(i) => ast::Literal::Numeric(i.to_string()), @@ }; Ok(ast::Expr::Literal(lit)) } LogicalExpr::BinaryExpr { left, op, right } => { - let left_expr = Self::logical_to_ast_expr(left)?; - let right_expr = Self::logical_to_ast_expr(right)?; + let left_expr = Self::logical_to_ast_expr_with_schema(left, schema)?; + let right_expr = Self::logical_to_ast_expr_with_schema(right, schema)?; Ok(ast::Expr::Binary( Box::new(left_expr), *op, Box::new(right_expr), )) } LogicalExpr::ScalarFunction { fun, args } => { - let ast_args: Result<Vec<ast::Expr>> = args.iter().map(Self::logical_to_ast_expr).collect(); + let ast_args: Result<Vec<ast::Expr>> = args + .iter() + .map(|arg| Self::logical_to_ast_expr_with_schema(arg, schema)) + .collect(); let ast_args: Vec<Box<ast::Expr>> = ast_args?.into_iter().map(Box::new).collect(); Ok(ast::Expr::FunctionCall { name: ast::Name::Ident(fun.clone()), @@ } LogicalExpr::Alias { expr, ..
} => { // For conversion to AST, ignore the alias and convert the inner expression - Self::logical_to_ast_expr(expr) + Self::logical_to_ast_expr_with_schema(expr, schema) } LogicalExpr::AggregateFunction { fun, args, distinct, } => { // Convert aggregate function to AST - let ast_args: Result<Vec<ast::Expr>> = args.iter().map(Self::logical_to_ast_expr).collect(); + let ast_args: Result<Vec<ast::Expr>> = args + .iter() + .map(|arg| Self::logical_to_ast_expr_with_schema(arg, schema)) + .collect(); let ast_args: Vec<Box<ast::Expr>> = ast_args?.into_iter().map(Box::new).collect(); // Get the function name based on the aggregate type @@ -1315,8 +1461,7 @@ mod tests { use crate::incremental::operator::{FilterOperator, FilterPredicate}; use crate::schema::{BTreeTable, Column as SchemaColumn, Schema, Type}; use crate::storage::pager::CreateBTreeFlags; - use crate::translate::logical::LogicalPlanBuilder; - use crate::translate::logical::LogicalSchema; + use crate::translate::logical::{ColumnInfo, LogicalPlanBuilder, LogicalSchema}; use crate::util::IOExt; use crate::{Database, MemoryIO, Pager, IO}; use std::sync::Arc; @@ -1374,6 +1519,270 @@ unique_sets: vec![], }; schema.add_btree_table(Arc::new(users_table)); + + // Add products table for join tests + let products_table = BTreeTable { + name: "products".to_string(), + root_page: 3, + primary_key_columns: vec![( + "product_id".to_string(), + turso_parser::ast::SortOrder::Asc, + )], + columns: vec![ + SchemaColumn { + name: Some("product_id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: true, + is_rowid_alias: true, + notnull: true, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("product_name".to_string()), + ty: Type::Text, + ty_str: "TEXT".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("price".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden:
false, + }, + ], + has_rowid: true, + is_strict: false, + unique_sets: vec![], + }; + schema.add_btree_table(Arc::new(orders_table)); + + // Add customers table with id and name for testing column ambiguity + let customers_table = BTreeTable { + name: "customers".to_string(), + root_page: 6, + primary_key_columns: vec![("id".to_string(), turso_parser::ast::SortOrder::Asc)], + columns: vec![ + SchemaColumn { + name: Some("id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: true, + is_rowid_alias: true, + notnull: true, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("name".to_string()), + ty: Type::Text, + ty_str: "TEXT".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + ], + has_rowid: true, + is_strict: false, + unique_sets: vec![], + }; + schema.add_btree_table(Arc::new(customers_table)); + + // Add purchases table (junction table for three-way join) + let purchases_table = BTreeTable { + name: "purchases".to_string(), + root_page: 7, + primary_key_columns: vec![("id".to_string(), turso_parser::ast::SortOrder::Asc)], + columns: vec![ + SchemaColumn { + name: Some("id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: true, + is_rowid_alias: true, + notnull: true, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("customer_id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("vendor_id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("quantity".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + ], + has_rowid: true, + is_strict: false, + unique_sets: vec![], + }; + schema.add_btree_table(Arc::new(purchases_table)); + + // Add vendors table with id, name, and price (ambiguous columns with customers) + let vendors_table = BTreeTable { + name: "vendors".to_string(), + root_page: 8, + primary_key_columns: vec![("id".to_string(), turso_parser::ast::SortOrder::Asc)], + columns: vec![ + SchemaColumn { + name: Some("id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: true, + is_rowid_alias: true, + notnull: true, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("name".to_string()), + ty: Type::Text, + ty_str: "TEXT".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("price".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + ], + has_rowid: true, + is_strict: false, + unique_sets: vec![], + }; + schema.add_btree_table(Arc::new(vendors_table)); + let sales_table = BTreeTable { name: "sales".to_string(), root_page: 2, 
@@ -3342,8 +3751,20 @@ mod tests { // Create a simple filter node let schema = Arc::new(LogicalSchema::new(vec![ - ("id".to_string(), Type::Integer), - ("value".to_string(), Type::Integer), + ColumnInfo { + name: "id".to_string(), + ty: Type::Integer, + database: None, + table: None, + table_alias: None, + }, + ColumnInfo { + name: "value".to_string(), + ty: Type::Integer, + database: None, + table: None, + table_alias: None, + }, ])); // First create an input node with InputOperator @@ -3486,4 +3907,767 @@ mod tests { "Row should still exist with multiplicity 1" ); } + + #[test] + fn test_join_with_aggregation() { + // Test join followed by aggregation - verifying actual output + let (mut circuit, pager) = compile_sql!( + "SELECT u.name, SUM(o.quantity) as total_quantity + FROM users u + JOIN orders o ON u.id = o.user_id + GROUP BY u.name" + ); + + // Create test data for users + let mut users_delta = Delta::new(); + users_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".into()), + Value::Integer(30), + ], + ); + users_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Bob".into()), + Value::Integer(25), + ], + ); + + // Create test data for orders (order_id, user_id, product_id, quantity) + let mut orders_delta = Delta::new(); + orders_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Integer(1), + Value::Integer(101), + Value::Integer(5), + ], + ); // Alice: 5 + orders_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Integer(1), + Value::Integer(102), + Value::Integer(3), + ], + ); // Alice: 3 + orders_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Integer(2), + Value::Integer(101), + Value::Integer(7), + ], + ); // Bob: 7 + orders_delta.insert( + 4, + vec![ + Value::Integer(4), + Value::Integer(1), + Value::Integer(103), + Value::Integer(2), + ], + ); // Alice: 2 + + let inputs = HashMap::from([ + ("users".to_string(), users_delta), + ("orders".to_string(), orders_delta), + ]); + + let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap(); + + // Should have 2 results: Alice with total 10, Bob with total 7 + assert_eq!( + result.len(), + 2, + "Should have aggregated results for Alice and Bob" + ); + + // Check the results + let mut results_map: HashMap = HashMap::new(); + for (row, weight) in result.changes { + assert_eq!(weight, 1); + assert_eq!(row.values.len(), 2); // name and total_quantity + + if let (Value::Text(name), Value::Integer(total)) = (&row.values[0], &row.values[1]) { + results_map.insert(name.to_string(), *total); + } else { + panic!("Unexpected value types in result"); + } + } + + assert_eq!( + results_map.get("Alice"), + Some(&10), + "Alice should have total quantity 10" + ); + assert_eq!( + results_map.get("Bob"), + Some(&7), + "Bob should have total quantity 7" + ); + } + + #[test] + fn test_join_aggregate_with_filter() { + // Test complex query with join, filter, and aggregation - verifying output + let (mut circuit, pager) = compile_sql!( + "SELECT u.name, SUM(o.quantity) as total + FROM users u + JOIN orders o ON u.id = o.user_id + WHERE u.age > 18 + GROUP BY u.name" + ); + + // Create test data for users + let mut users_delta = Delta::new(); + users_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".into()), + Value::Integer(30), + ], + ); // age > 18 + users_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Bob".into()), + Value::Integer(17), + ], + ); // age <= 18 + users_delta.insert( + 3, + vec![ + Value::Integer(3), + 
Value::Text("Charlie".into()), + Value::Integer(25), + ], + ); // age > 18 + + // Create test data for orders (order_id, user_id, product_id, quantity) + let mut orders_delta = Delta::new(); + orders_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Integer(1), + Value::Integer(101), + Value::Integer(5), + ], + ); // Alice: 5 + orders_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Integer(2), + Value::Integer(102), + Value::Integer(10), + ], + ); // Bob: 10 (should be filtered) + orders_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Integer(3), + Value::Integer(101), + Value::Integer(7), + ], + ); // Charlie: 7 + orders_delta.insert( + 4, + vec![ + Value::Integer(4), + Value::Integer(1), + Value::Integer(103), + Value::Integer(3), + ], + ); // Alice: 3 + + let inputs = HashMap::from([ + ("users".to_string(), users_delta), + ("orders".to_string(), orders_delta), + ]); + + let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap(); + + // Should only have results for Alice and Charlie (Bob filtered out due to age <= 18) + assert_eq!( + result.len(), + 2, + "Should only have results for users with age > 18" + ); + + // Check the results + let mut results_map: HashMap = HashMap::new(); + for (row, weight) in result.changes { + assert_eq!(weight, 1); + assert_eq!(row.values.len(), 2); // name and total + + if let (Value::Text(name), Value::Integer(total)) = (&row.values[0], &row.values[1]) { + results_map.insert(name.to_string(), *total); + } + } + + assert_eq!( + results_map.get("Alice"), + Some(&8), + "Alice should have total 8" + ); + assert_eq!( + results_map.get("Charlie"), + Some(&7), + "Charlie should have total 7" + ); + assert_eq!(results_map.get("Bob"), None, "Bob should be filtered out"); + } + + #[test] + fn test_three_way_join_execution() { + // Test executing a 3-way join with aggregation + let (mut circuit, pager) = compile_sql!( + "SELECT u.name, p.product_name, SUM(o.quantity) as total + FROM users u + JOIN orders o ON u.id = o.user_id + JOIN products p ON o.product_id = p.product_id + GROUP BY u.name, p.product_name" + ); + + // Create test data for users + let mut users_delta = Delta::new(); + users_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".into()), + Value::Integer(25), + ], + ); + users_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Bob".into()), + Value::Integer(30), + ], + ); + + // Create test data for products + let mut products_delta = Delta::new(); + products_delta.insert( + 100, + vec![ + Value::Integer(100), + Value::Text("Widget".into()), + Value::Integer(50), + ], + ); + products_delta.insert( + 101, + vec![ + Value::Integer(101), + Value::Text("Gadget".into()), + Value::Integer(75), + ], + ); + products_delta.insert( + 102, + vec![ + Value::Integer(102), + Value::Text("Doohickey".into()), + Value::Integer(25), + ], + ); + + // Create test data for orders joining users and products + let mut orders_delta = Delta::new(); + // Alice orders 5 Widgets + orders_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Integer(1), + Value::Integer(100), + Value::Integer(5), + ], + ); + // Alice orders 3 Gadgets + orders_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Integer(1), + Value::Integer(101), + Value::Integer(3), + ], + ); + // Bob orders 7 Widgets + orders_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Integer(2), + Value::Integer(100), + Value::Integer(7), + ], + ); + // Bob orders 2 Doohickeys + orders_delta.insert( + 4, + vec![ + Value::Integer(4), 
+ Value::Integer(2), + Value::Integer(102), + Value::Integer(2), + ], + ); + // Alice orders 4 more Widgets + orders_delta.insert( + 5, + vec![ + Value::Integer(5), + Value::Integer(1), + Value::Integer(100), + Value::Integer(4), + ], + ); + + let mut inputs = HashMap::new(); + inputs.insert("users".to_string(), users_delta); + inputs.insert("products".to_string(), products_delta); + inputs.insert("orders".to_string(), orders_delta); + + // Execute the 3-way join with aggregation + let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap(); + + // We should get aggregated results for each user-product combination + // Expected results: + // - Alice, Widget: 9 (5 + 4) + // - Alice, Gadget: 3 + // - Bob, Widget: 7 + // - Bob, Doohickey: 2 + assert_eq!(result.len(), 4, "Should have 4 aggregated results"); + + // Verify aggregation results + let mut found_results = std::collections::HashSet::new(); + for (row, weight) in result.changes.iter() { + assert_eq!(*weight, 1); + // Row should have name, product_name, and sum columns + assert_eq!(row.values.len(), 3); + + if let (Value::Text(name), Value::Text(product), Value::Integer(total)) = + (&row.values[0], &row.values[1], &row.values[2]) + { + let key = format!("{}-{}", name.as_ref(), product.as_ref()); + found_results.insert(key.clone()); + + match key.as_str() { + "Alice-Widget" => { + assert_eq!(*total, 9, "Alice should have ordered 9 Widgets total") + } + "Alice-Gadget" => assert_eq!(*total, 3, "Alice should have ordered 3 Gadgets"), + "Bob-Widget" => assert_eq!(*total, 7, "Bob should have ordered 7 Widgets"), + "Bob-Doohickey" => { + assert_eq!(*total, 2, "Bob should have ordered 2 Doohickeys") + } + _ => panic!("Unexpected result: {key}"), + } + } else { + panic!("Unexpected value types in result"); + } + } + + // Ensure we found all expected combinations + assert!(found_results.contains("Alice-Widget")); + assert!(found_results.contains("Alice-Gadget")); + assert!(found_results.contains("Bob-Widget")); + assert!(found_results.contains("Bob-Doohickey")); + } + + #[test] + fn test_join_execution() { + let (mut circuit, pager) = compile_sql!( + "SELECT u.name, o.quantity FROM users u JOIN orders o ON u.id = o.user_id" + ); + + // Create test data for users + let mut users_delta = Delta::new(); + users_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".into()), + Value::Integer(25), + ], + ); + users_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Bob".into()), + Value::Integer(30), + ], + ); + + // Create test data for orders + let mut orders_delta = Delta::new(); + orders_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Integer(1), + Value::Integer(100), + Value::Integer(5), + ], + ); + orders_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Integer(1), + Value::Integer(101), + Value::Integer(3), + ], + ); + orders_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Integer(2), + Value::Integer(102), + Value::Integer(7), + ], + ); + + let mut inputs = HashMap::new(); + inputs.insert("users".to_string(), users_delta); + inputs.insert("orders".to_string(), orders_delta); + + // Execute the join + let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap(); + + // We should get 3 results (2 orders for Alice, 1 for Bob) + assert_eq!(result.len(), 3, "Should have 3 join results"); + + // Verify the join results contain the correct data + let results: Vec<_> = result.changes.iter().collect(); + + // Check that we have the expected joined rows + for 
(row, weight) in results { + assert_eq!(*weight, 1); // All weights should be 1 for insertions + // Row should have name and quantity columns + assert_eq!(row.values.len(), 2); + } + } + + #[test] + fn test_three_way_join_with_column_ambiguity() { + // Test three-way join with aggregation where multiple tables have columns with the same name + // Ensures that column references are correctly resolved to their respective tables + // Tables: customers(id, name), purchases(id, customer_id, vendor_id, quantity), vendors(id, name, price) + // Note: both customers and vendors have 'id' and 'name' columns which can cause ambiguity + + let sql = "SELECT c.name as customer_name, v.name as vendor_name, + SUM(p.quantity) as total_quantity, + SUM(p.quantity * v.price) as total_value + FROM customers c + JOIN purchases p ON c.id = p.customer_id + JOIN vendors v ON p.vendor_id = v.id + GROUP BY c.name, v.name"; + + let (mut circuit, pager) = compile_sql!(sql); + + // Create test data for customers (id, name) + let mut customers_delta = Delta::new(); + customers_delta.insert(1, vec![Value::Integer(1), Value::Text("Alice".into())]); + customers_delta.insert(2, vec![Value::Integer(2), Value::Text("Bob".into())]); + + // Create test data for vendors (id, name, price) + let mut vendors_delta = Delta::new(); + vendors_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Widget Co".into()), + Value::Integer(10), + ], + ); + vendors_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Gadget Inc".into()), + Value::Integer(20), + ], + ); + + // Create test data for purchases (id, customer_id, vendor_id, quantity) + let mut purchases_delta = Delta::new(); + // Alice purchases 5 units from Widget Co + purchases_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Integer(1), // customer_id: Alice + Value::Integer(1), // vendor_id: Widget Co + Value::Integer(5), + ], + ); + // Alice purchases 3 units from Gadget Inc + purchases_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Integer(1), // customer_id: Alice + Value::Integer(2), // vendor_id: Gadget Inc + Value::Integer(3), + ], + ); + // Bob purchases 2 units from Widget Co + purchases_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Integer(2), // customer_id: Bob + Value::Integer(1), // vendor_id: Widget Co + Value::Integer(2), + ], + ); + // Alice purchases 4 more units from Widget Co + purchases_delta.insert( + 4, + vec![ + Value::Integer(4), + Value::Integer(1), // customer_id: Alice + Value::Integer(1), // vendor_id: Widget Co + Value::Integer(4), + ], + ); + + let inputs = HashMap::from([ + ("customers".to_string(), customers_delta), + ("purchases".to_string(), purchases_delta), + ("vendors".to_string(), vendors_delta), + ]); + + let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap(); + + // Expected results: + // Alice|Gadget Inc|3|60 (3 units * 20 price = 60) + // Alice|Widget Co|9|90 (9 units * 10 price = 90) + // Bob|Widget Co|2|20 (2 units * 10 price = 20) + + assert_eq!(result.len(), 3, "Should have 3 aggregated results"); + + // Sort results for consistent testing + let mut results: Vec<_> = result.changes.into_iter().collect(); + results.sort_by(|a, b| { + let a_cust = &a.0.values[0]; + let a_vend = &a.0.values[1]; + let b_cust = &b.0.values[0]; + let b_vend = &b.0.values[1]; + (a_cust, a_vend).cmp(&(b_cust, b_vend)) + }); + + // Verify Alice's Gadget Inc purchases + assert_eq!(results[0].0.values[0], Value::Text("Alice".into())); + assert_eq!(results[0].0.values[1], 
Value::Text("Gadget Inc".into())); + assert_eq!(results[0].0.values[2], Value::Integer(3)); // total_quantity + assert_eq!(results[0].0.values[3], Value::Integer(60)); // total_value + + // Verify Alice's Widget Co purchases + assert_eq!(results[1].0.values[0], Value::Text("Alice".into())); + assert_eq!(results[1].0.values[1], Value::Text("Widget Co".into())); + assert_eq!(results[1].0.values[2], Value::Integer(9)); // total_quantity + assert_eq!(results[1].0.values[3], Value::Integer(90)); // total_value + + // Verify Bob's Widget Co purchases + assert_eq!(results[2].0.values[0], Value::Text("Bob".into())); + assert_eq!(results[2].0.values[1], Value::Text("Widget Co".into())); + assert_eq!(results[2].0.values[2], Value::Integer(2)); // total_quantity + assert_eq!(results[2].0.values[3], Value::Integer(20)); // total_value + } + + #[test] + fn test_join_with_aggregate_execution() { + let (mut circuit, pager) = compile_sql!( + "SELECT u.name, SUM(o.quantity) as total_quantity + FROM users u + JOIN orders o ON u.id = o.user_id + GROUP BY u.name" + ); + + // Create test data for users + let mut users_delta = Delta::new(); + users_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".into()), + Value::Integer(25), + ], + ); + users_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Bob".into()), + Value::Integer(30), + ], + ); + + // Create test data for orders + let mut orders_delta = Delta::new(); + orders_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Integer(1), + Value::Integer(100), + Value::Integer(5), + ], + ); + orders_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Integer(1), + Value::Integer(101), + Value::Integer(3), + ], + ); + orders_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Integer(2), + Value::Integer(102), + Value::Integer(7), + ], + ); + + let mut inputs = HashMap::new(); + inputs.insert("users".to_string(), users_delta); + inputs.insert("orders".to_string(), orders_delta); + + // Execute the join with aggregation + let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap(); + + // We should get 2 aggregated results (one for Alice, one for Bob) + assert_eq!(result.len(), 2, "Should have 2 aggregated results"); + + // Verify aggregation results + for (row, weight) in result.changes.iter() { + assert_eq!(*weight, 1); + // Row should have name and sum columns + assert_eq!(row.values.len(), 2); + + // Check the aggregated values + if let Value::Text(name) = &row.values[0] { + if name.as_ref() == "Alice" { + // Alice should have total quantity of 8 (5 + 3) + assert_eq!(row.values[1], Value::Integer(8)); + } else if name.as_ref() == "Bob" { + // Bob should have total quantity of 7 + assert_eq!(row.values[1], Value::Integer(7)); + } + } + } + } + + #[test] + fn test_filter_with_qualified_columns_in_join() { + // Test that filters correctly handle qualified column names in joins + // when multiple tables have columns with the SAME names. + // Both users and sales tables have an 'id' column which can be ambiguous. 
+
+        let (mut circuit, pager) = compile_sql!(
+            "SELECT users.id, users.name, sales.id, sales.amount
+             FROM users
+             JOIN sales ON users.id = sales.customer_id
+             WHERE users.id > 1 AND sales.id < 100"
+        );
+
+        // Create test data
+        let mut users_delta = Delta::new();
+        let mut sales_delta = Delta::new();
+
+        // Users data: (id, name, age)
+        users_delta.insert(
+            1,
+            vec![
+                Value::Integer(1),
+                Value::Text("Alice".into()),
+                Value::Integer(30),
+            ],
+        ); // id = 1
+        users_delta.insert(
+            2,
+            vec![
+                Value::Integer(2),
+                Value::Text("Bob".into()),
+                Value::Integer(25),
+            ],
+        ); // id = 2
+        users_delta.insert(
+            3,
+            vec![
+                Value::Integer(3),
+                Value::Text("Charlie".into()),
+                Value::Integer(35),
+            ],
+        ); // id = 3
+
+        // Sales data: (id, customer_id, amount)
+        sales_delta.insert(
+            50,
+            vec![Value::Integer(50), Value::Integer(1), Value::Integer(100)],
+        ); // sales.id = 50, customer_id = 1
+        sales_delta.insert(
+            99,
+            vec![Value::Integer(99), Value::Integer(2), Value::Integer(200)],
+        ); // sales.id = 99, customer_id = 2
+        sales_delta.insert(
+            150,
+            vec![Value::Integer(150), Value::Integer(3), Value::Integer(300)],
+        ); // sales.id = 150, customer_id = 3
+
+        let mut inputs = HashMap::new();
+        inputs.insert("users".to_string(), users_delta);
+        inputs.insert("sales".to_string(), sales_delta);
+
+        let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap();
+
+        // Should only get row with Bob (users.id=2, sales.id=99):
+        // - users.id=2 (> 1) AND sales.id=99 (< 100) ✓
+        // Alice excluded: users.id=1 (NOT > 1)
+        // Charlie excluded: sales.id=150 (NOT < 100)
+        assert_eq!(result.len(), 1, "Should have 1 filtered result");
+
+        let (row, weight) = &result.changes[0];
+        assert_eq!(*weight, 1);
+        assert_eq!(row.values.len(), 4, "Should have 4 columns");
+
+        // Verify the filter correctly used qualified columns
+        assert_eq!(row.values[0], Value::Integer(2), "users.id should be 2");
+        assert_eq!(
+            row.values[1],
+            Value::Text("Bob".into()),
+            "users.name should be Bob"
+        );
+        assert_eq!(row.values[2], Value::Integer(99), "sales.id should be 99");
+        assert_eq!(
+            row.values[3],
+            Value::Integer(200),
+            "sales.amount should be 200"
+        );
+    }
 }
diff --git a/core/incremental/operator.rs b/core/incremental/operator.rs
index 54cd7e0a0..72ed7bc0c 100644
--- a/core/incremental/operator.rs
+++ b/core/incremental/operator.rs
@@ -3,7 +3,7 @@
 // Based on Feldera DBSP design but adapted for Turso's architecture
 
 pub use crate::incremental::aggregate_operator::{
-    AggregateEvalState, AggregateFunction, AggregateOperator, AggregateState,
+    AggregateEvalState, AggregateFunction, AggregateState,
 };
 pub use crate::incremental::filter_operator::{FilterOperator, FilterPredicate};
 pub use crate::incremental::input_operator::InputOperator;
@@ -251,7 +251,7 @@ pub trait IncrementalOperator: Debug {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::incremental::aggregate_operator::AGG_TYPE_REGULAR;
+    use crate::incremental::aggregate_operator::{AggregateOperator, AGG_TYPE_REGULAR};
     use crate::incremental::dbsp::HashableRow;
     use crate::storage::pager::CreateBTreeFlags;
     use crate::types::Text;
@@ -395,9 +395,9 @@ mod tests {
         // Create an aggregate operator for SUM(age) with no GROUP BY
         let mut agg = AggregateOperator::new(
-            1,      // operator_id for testing
-            vec![], // No GROUP BY
-            vec![AggregateFunction::Sum("age".to_string())],
+            1,      // operator_id for testing
+            vec![], // No GROUP BY
+            vec![AggregateFunction::Sum(2)], // age is at index 2
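+            // The index is the column's position in the `input_column_names`
+            // vec below (["id", "name", "age"]), so "age" maps to 2; the
+            // operator no longer looks columns up by name at runtime.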
"age".to_string()], ); @@ -514,9 +514,9 @@ mod tests { let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( - 1, // operator_id for testing - vec!["team".to_string()], // GROUP BY team - vec![AggregateFunction::Sum("score".to_string())], + 1, // operator_id for testing + vec![1], // GROUP BY team (index 1) + vec![AggregateFunction::Sum(3)], // score is at index 3 vec![ "id".to_string(), "team".to_string(), @@ -666,8 +666,8 @@ mod tests { // Create COUNT(*) GROUP BY category let mut agg = AggregateOperator::new( - 1, // operator_id for testing - vec!["category".to_string()], + 1, // operator_id for testing + vec![1], // category is at index 1 vec![AggregateFunction::Count], vec![ "item_id".to_string(), @@ -746,9 +746,9 @@ mod tests { let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( - 1, // operator_id for testing - vec!["product".to_string()], - vec![AggregateFunction::Sum("amount".to_string())], + 1, // operator_id for testing + vec![1], // product is at index 1 + vec![AggregateFunction::Sum(2)], // amount is at index 2 vec![ "sale_id".to_string(), "product".to_string(), @@ -843,11 +843,11 @@ mod tests { let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( - 1, // operator_id for testing - vec!["user_id".to_string()], + 1, // operator_id for testing + vec![1], // user_id is at index 1 vec![ AggregateFunction::Count, - AggregateFunction::Sum("amount".to_string()), + AggregateFunction::Sum(2), // amount is at index 2 ], vec![ "order_id".to_string(), @@ -935,9 +935,9 @@ mod tests { let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( - 1, // operator_id for testing - vec!["category".to_string()], - vec![AggregateFunction::Avg("value".to_string())], + 1, // operator_id for testing + vec![1], // category is at index 1 + vec![AggregateFunction::Avg(2)], // value is at index 2 vec![ "id".to_string(), "category".to_string(), @@ -1035,11 +1035,11 @@ mod tests { let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( - 1, // operator_id for testing - vec!["category".to_string()], + 1, // operator_id for testing + vec![1], // category is at index 1 vec![ AggregateFunction::Count, - AggregateFunction::Sum("value".to_string()), + AggregateFunction::Sum(2), // value is at index 2 ], vec![ "id".to_string(), @@ -1108,7 +1108,7 @@ mod tests { #[test] fn test_count_aggregation_with_deletions() { let aggregates = vec![AggregateFunction::Count]; - let group_by = vec!["category".to_string()]; + let group_by = vec![0]; // category is at index 0 let input_columns = vec!["category".to_string(), "value".to_string()]; // Create a persistent pager for the test @@ -1197,8 +1197,8 @@ mod tests { #[test] fn test_sum_aggregation_with_deletions() { - let aggregates = vec![AggregateFunction::Sum("value".to_string())]; - let group_by = vec!["category".to_string()]; + let aggregates = vec![AggregateFunction::Sum(1)]; // value is at index 1 + let group_by = vec![0]; // category is at index 0 let input_columns = vec!["category".to_string(), "value".to_string()]; // Create a persistent pager for the test @@ -1281,8 +1281,8 @@ mod tests { #[test] fn test_avg_aggregation_with_deletions() { - let aggregates = vec![AggregateFunction::Avg("value".to_string())]; - let group_by = vec!["category".to_string()]; + let aggregates = vec![AggregateFunction::Avg(1)]; // 
value is at index 1 + let group_by = vec![0]; // category is at index 0 let input_columns = vec!["category".to_string(), "value".to_string()]; // Create a persistent pager for the test @@ -1348,10 +1348,10 @@ mod tests { // Test COUNT, SUM, and AVG together let aggregates = vec![ AggregateFunction::Count, - AggregateFunction::Sum("value".to_string()), - AggregateFunction::Avg("value".to_string()), + AggregateFunction::Sum(1), // value is at index 1 + AggregateFunction::Avg(1), // value is at index 1 ]; - let group_by = vec!["category".to_string()]; + let group_by = vec![0]; // category is at index 0 let input_columns = vec!["category".to_string(), "value".to_string()]; // Create a persistent pager for the test @@ -1607,11 +1607,11 @@ mod tests { let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( - 1, // operator_id for testing - vec!["category".to_string()], + 1, // operator_id for testing + vec![1], // category is at index 1 vec![ AggregateFunction::Count, - AggregateFunction::Sum("amount".to_string()), + AggregateFunction::Sum(2), // amount is at index 2 ], vec![ "id".to_string(), @@ -1781,7 +1781,7 @@ mod tests { vec![], // No GROUP BY vec![ AggregateFunction::Count, - AggregateFunction::Sum("value".to_string()), + AggregateFunction::Sum(1), // value is at index 1 ], vec!["id".to_string(), "value".to_string()], ); @@ -1859,8 +1859,8 @@ mod tests { let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( - 1, // operator_id for testing - vec!["type".to_string()], + 1, // operator_id for testing + vec![1], // type is at index 1 vec![AggregateFunction::Count], vec!["id".to_string(), "type".to_string()], ); @@ -1976,8 +1976,8 @@ mod tests { 1, // operator_id vec![], // No GROUP BY vec![ - AggregateFunction::Min("price".to_string()), - AggregateFunction::Max("price".to_string()), + AggregateFunction::Min(2), // price is at index 2 + AggregateFunction::Max(2), // price is at index 2 ], vec!["id".to_string(), "name".to_string(), "price".to_string()], ); @@ -2044,8 +2044,8 @@ mod tests { 1, // operator_id vec![], // No GROUP BY vec![ - AggregateFunction::Min("price".to_string()), - AggregateFunction::Max("price".to_string()), + AggregateFunction::Min(2), // price is at index 2 + AggregateFunction::Max(2), // price is at index 2 ], vec!["id".to_string(), "name".to_string(), "price".to_string()], ); @@ -2134,8 +2134,8 @@ mod tests { 1, // operator_id vec![], // No GROUP BY vec![ - AggregateFunction::Min("price".to_string()), - AggregateFunction::Max("price".to_string()), + AggregateFunction::Min(2), // price is at index 2 + AggregateFunction::Max(2), // price is at index 2 ], vec!["id".to_string(), "name".to_string(), "price".to_string()], ); @@ -2224,8 +2224,8 @@ mod tests { 1, // operator_id vec![], // No GROUP BY vec![ - AggregateFunction::Min("price".to_string()), - AggregateFunction::Max("price".to_string()), + AggregateFunction::Min(2), // price is at index 2 + AggregateFunction::Max(2), // price is at index 2 ], vec!["id".to_string(), "name".to_string(), "price".to_string()], ); @@ -2306,8 +2306,8 @@ mod tests { 1, // operator_id vec![], // No GROUP BY vec![ - AggregateFunction::Min("price".to_string()), - AggregateFunction::Max("price".to_string()), + AggregateFunction::Min(2), // price is at index 2 + AggregateFunction::Max(2), // price is at index 2 ], vec!["id".to_string(), "name".to_string(), "price".to_string()], ); @@ -2388,8 +2388,8 @@ mod tests { 1, // operator_id vec![], // 
No GROUP BY vec![ - AggregateFunction::Min("price".to_string()), - AggregateFunction::Max("price".to_string()), + AggregateFunction::Min(2), // price is at index 2 + AggregateFunction::Max(2), // price is at index 2 ], vec!["id".to_string(), "name".to_string(), "price".to_string()], ); @@ -2475,11 +2475,11 @@ mod tests { let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( - 1, // operator_id - vec!["category".to_string()], // GROUP BY category + 1, // operator_id + vec![1], // GROUP BY category (index 1) vec![ - AggregateFunction::Min("price".to_string()), - AggregateFunction::Max("price".to_string()), + AggregateFunction::Min(3), // price is at index 3 + AggregateFunction::Max(3), // price is at index 3 ], vec![ "id".to_string(), @@ -2580,8 +2580,8 @@ mod tests { 1, // operator_id vec![], // No GROUP BY vec![ - AggregateFunction::Min("price".to_string()), - AggregateFunction::Max("price".to_string()), + AggregateFunction::Min(2), // price is at index 2 + AggregateFunction::Max(2), // price is at index 2 ], vec!["id".to_string(), "name".to_string(), "price".to_string()], ); @@ -2656,8 +2656,8 @@ mod tests { 1, // operator_id vec![], // No GROUP BY vec![ - AggregateFunction::Min("score".to_string()), - AggregateFunction::Max("score".to_string()), + AggregateFunction::Min(2), // score is at index 2 + AggregateFunction::Max(2), // score is at index 2 ], vec!["id".to_string(), "name".to_string(), "score".to_string()], ); @@ -2724,8 +2724,8 @@ mod tests { 1, // operator_id vec![], // No GROUP BY vec![ - AggregateFunction::Min("name".to_string()), - AggregateFunction::Max("name".to_string()), + AggregateFunction::Min(1), // name is at index 1 + AggregateFunction::Max(1), // name is at index 1 ], vec!["id".to_string(), "name".to_string()], ); @@ -2764,10 +2764,10 @@ mod tests { vec![], // No GROUP BY vec![ AggregateFunction::Count, - AggregateFunction::Sum("value".to_string()), - AggregateFunction::Min("value".to_string()), - AggregateFunction::Max("value".to_string()), - AggregateFunction::Avg("value".to_string()), + AggregateFunction::Sum(1), // value is at index 1 + AggregateFunction::Min(1), // value is at index 1 + AggregateFunction::Max(1), // value is at index 1 + AggregateFunction::Avg(1), // value is at index 1 ], vec!["id".to_string(), "value".to_string()], ); @@ -2855,9 +2855,9 @@ mod tests { 1, // operator_id vec![], // No GROUP BY vec![ - AggregateFunction::Min("col1".to_string()), - AggregateFunction::Max("col2".to_string()), - AggregateFunction::Min("col3".to_string()), + AggregateFunction::Min(0), // col1 is at index 0 + AggregateFunction::Max(1), // col2 is at index 1 + AggregateFunction::Min(2), // col3 is at index 2 ], vec!["col1".to_string(), "col2".to_string(), "col3".to_string()], ); diff --git a/core/translate/logical.rs b/core/translate/logical.rs index 6a8b0a6c2..b11e2df4f 100644 --- a/core/translate/logical.rs +++ b/core/translate/logical.rs @@ -19,26 +19,35 @@ use turso_parser::ast; /// Result type for preprocessing aggregate expressions type PreprocessAggregateResult = ( - bool, // needs_pre_projection - Vec, // pre_projection_exprs - Vec<(String, Type)>, // pre_projection_schema - Vec, // modified_aggr_exprs + bool, // needs_pre_projection + Vec, // pre_projection_exprs + Vec, // pre_projection_schema + Vec, // modified_aggr_exprs ); /// Result type for parsing join conditions type JoinConditionsResult = (Vec<(LogicalExpr, LogicalExpr)>, Option); +/// Information about a column in a logical schema +#[derive(Debug, 
+pub struct ColumnInfo {
+    pub name: String,
+    pub ty: Type,
+    pub database: Option<String>,
+    pub table: Option<String>,
+    pub table_alias: Option<String>,
+}
+
 /// Schema information for logical plan nodes
 #[derive(Debug, Clone, PartialEq)]
 pub struct LogicalSchema {
-    /// Column names and types
-    pub columns: Vec<(String, Type)>,
+    pub columns: Vec<ColumnInfo>,
 }
 
 /// A reference to a schema that can be shared between nodes
 pub type SchemaRef = Arc<LogicalSchema>;
 
 impl LogicalSchema {
-    pub fn new(columns: Vec<(String, Type)>) -> Self {
+    pub fn new(columns: Vec<ColumnInfo>) -> Self {
         Self { columns }
     }
 
@@ -52,11 +61,42 @@ impl LogicalSchema {
         self.columns.len()
     }
 
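+    /// Resolves a (possibly qualified) column reference to its index and info.
+    /// Lookup order, as implemented below: a "db.table" qualifier matches on
+    /// database + table; a plain qualifier matches the table alias first, then
+    /// the table name; an unqualified name matches the first column with that
+    /// name, so duplicated names across joined tables should stay qualified.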
-    pub fn find_column(&self, name: &str) -> Option<(usize, &Type)> {
-        self.columns
-            .iter()
-            .position(|(n, _)| n == name)
-            .map(|idx| (idx, &self.columns[idx].1))
+    pub fn find_column(&self, name: &str, table: Option<&str>) -> Option<(usize, &ColumnInfo)> {
+        if let Some(table_ref) = table {
+            // Check if it's a database.table format
+            if table_ref.contains('.') {
+                let parts: Vec<&str> = table_ref.splitn(2, '.').collect();
+                if parts.len() == 2 {
+                    let db = parts[0];
+                    let tbl = parts[1];
+                    return self
+                        .columns
+                        .iter()
+                        .position(|c| {
+                            c.name == name
+                                && c.database.as_deref() == Some(db)
+                                && c.table.as_deref() == Some(tbl)
+                        })
+                        .map(|idx| (idx, &self.columns[idx]));
+                }
+            }
+
+            // Try to match against table alias first, then table name
+            self.columns
+                .iter()
+                .position(|c| {
+                    c.name == name
+                        && (c.table_alias.as_deref() == Some(table_ref)
+                            || c.table.as_deref() == Some(table_ref))
+                })
+                .map(|idx| (idx, &self.columns[idx]))
+        } else {
+            // Unqualified lookup - just match by name
+            self.columns
+                .iter()
+                .position(|c| c.name == name)
+                .map(|idx| (idx, &self.columns[idx]))
+        }
+    }
 }
@@ -548,14 +588,14 @@ impl<'a> LogicalPlanBuilder<'a> {
         }
 
         // Regular table scan
-        let table_schema = self.get_table_schema(&table_name)?;
         let table_alias = alias.as_ref().map(|a| match a {
             ast::As::As(name) => Self::name_to_string(name),
             ast::As::Elided(name) => Self::name_to_string(name),
         });
+        let table_schema = self.get_table_schema(&table_name, table_alias.as_deref())?;
 
         Ok(LogicalPlan::TableScan(TableScan {
             table_name,
-            alias: table_alias,
+            alias: table_alias.clone(),
             schema: table_schema,
             projection: None,
         }))
@@ -751,14 +791,14 @@ impl<'a> LogicalPlanBuilder<'a> {
             let _left_idx = left_schema
                 .columns
                 .iter()
-                .position(|(n, _)| n == &name)
+                .position(|col| col.name == name)
                 .ok_or_else(|| {
                     LimboError::ParseError(format!("Column {name} not found in left table"))
                 })?;
             let _right_idx = right_schema
                 .columns
                 .iter()
-                .position(|(n, _)| n == &name)
+                .position(|col| col.name == name)
                 .ok_or_else(|| {
                     LimboError::ParseError(format!("Column {name} not found in right table"))
                 })?;
@@ -790,9 +830,13 @@ impl<'a> LogicalPlanBuilder<'a> {
         // Find common column names
         let mut common_columns = Vec::new();
-        for (left_name, _) in &left_schema.columns {
-            if right_schema.columns.iter().any(|(n, _)| n == left_name) {
-                common_columns.push(ast::Name::Ident(left_name.clone()));
+        for left_col in &left_schema.columns {
+            if right_schema
+                .columns
+                .iter()
+                .any(|col| col.name == left_col.name)
+            {
+                common_columns.push(ast::Name::Ident(left_col.name.clone()));
             }
         }
@@ -833,10 +877,18 @@ impl<'a> LogicalPlanBuilder<'a> {
         let left_schema = left.schema();
         let right_schema = right.schema();
 
-        // For now, simply concatenate the schemas
-        // In a real implementation, we'd handle column name conflicts and nullable columns
-        let mut columns = left_schema.columns.clone();
-        columns.extend(right_schema.columns.clone());
+        // Concatenate the schemas, preserving all column information
+        let mut columns = Vec::new();
+
+        // Keep all columns from left with their table info
+        for col in &left_schema.columns {
+            columns.push(col.clone());
+        }
+
+        // Keep all columns from right with their table info
+        for col in &right_schema.columns {
+            columns.push(col.clone());
+        }
 
         Ok(Arc::new(LogicalSchema::new(columns)))
     }
@@ -870,7 +922,13 @@
                 };
 
                 let col_type = Self::infer_expr_type(&logical_expr, input_schema)?;
-                schema_columns.push((col_name.clone(), col_type));
+                schema_columns.push(ColumnInfo {
+                    name: col_name.clone(),
+                    ty: col_type,
+                    database: None,
+                    table: None,
+                    table_alias: None,
+                });
 
                 if let Some(as_alias) = alias {
                     let alias_name = match as_alias {
@@ -886,21 +944,21 @@
             }
             ast::ResultColumn::Star => {
                 // Expand * to all columns
-                for (name, typ) in &input_schema.columns {
-                    proj_exprs.push(LogicalExpr::Column(Column::new(name.clone())));
-                    schema_columns.push((name.clone(), *typ));
+                for col in &input_schema.columns {
+                    proj_exprs.push(LogicalExpr::Column(Column::new(col.name.clone())));
+                    schema_columns.push(col.clone());
                 }
             }
             ast::ResultColumn::TableStar(table) => {
                 // Expand table.* to all columns from that table
                 let table_name = Self::name_to_string(table);
-                for (name, typ) in &input_schema.columns {
+                for col in &input_schema.columns {
                     // Simple check - would need proper table tracking in real implementation
                     proj_exprs.push(LogicalExpr::Column(Column::with_table(
-                        name.clone(),
+                        col.name.clone(),
                         table_name.clone(),
                     )));
-                    schema_columns.push((name.clone(), *typ));
+                    schema_columns.push(col.clone());
                 }
             }
         }
@@ -938,7 +996,13 @@
             if let LogicalExpr::Column(col) = expr {
                 pre_projection_exprs.push(expr.clone());
                 let col_type = Self::infer_expr_type(expr, input_schema)?;
-                pre_projection_schema.push((col.name.clone(), col_type));
+                pre_projection_schema.push(ColumnInfo {
+                    name: col.name.clone(),
+                    ty: col_type,
+                    database: None,
+                    table: col.table.clone(),
+                    table_alias: None,
+                });
             } else {
                 // Complex group by expression - project it
                 needs_pre_projection = true;
@@ -946,7 +1010,13 @@
                 projected_col_counter += 1;
                 pre_projection_exprs.push(expr.clone());
                 let col_type = Self::infer_expr_type(expr, input_schema)?;
-                pre_projection_schema.push((proj_col_name.clone(), col_type));
+                pre_projection_schema.push(ColumnInfo {
+                    name: proj_col_name.clone(),
+                    ty: col_type,
+                    database: None,
+                    table: None,
+                    table_alias: None,
+                });
             }
         }
@@ -970,7 +1040,13 @@
                         pre_projection_exprs.push(arg.clone());
                         let col_type = Self::infer_expr_type(arg, input_schema)?;
                         if let LogicalExpr::Column(col) = arg {
-                            pre_projection_schema.push((col.name.clone(), col_type));
+                            pre_projection_schema.push(ColumnInfo {
+                                name: col.name.clone(),
+                                ty: col_type,
+                                database: None,
+                                table: col.table.clone(),
+                                table_alias: None,
+                            });
                        }
                     }
                 }
@@ -983,7 +1059,13 @@
                         // Add the expression to the pre-projection
                         pre_projection_exprs.push(arg.clone());
                         let col_type = Self::infer_expr_type(arg, input_schema)?;
-                        pre_projection_schema.push((proj_col_name.clone(), col_type));
+                        pre_projection_schema.push(ColumnInfo {
+                            name: proj_col_name.clone(),
+                            ty: col_type,
+                            database: None,
+                            table: None,
+                            table_alias: None,
+                        });
 
                         // In the aggregate, reference the projected column
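+                        // (e.g. for SUM(price * qty) the pre-projection emits the
+                        // computed `price * qty` under the synthesized name, and
+                        // the aggregate argument is rewritten to read that column)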
                         modified_args.push(LogicalExpr::Column(Column::new(proj_col_name)));
@@ -1057,15 +1139,39 @@
         // First, add GROUP BY columns to the aggregate output schema
         // These are always part of the aggregate operator's output
         for group_expr in &group_exprs {
-            let col_name = match group_expr {
-                LogicalExpr::Column(col) => col.name.clone(),
+            match group_expr {
+                LogicalExpr::Column(col) => {
+                    // For column references in GROUP BY, preserve the original column info
+                    if let Some((_, col_info)) =
+                        input_schema.find_column(&col.name, col.table.as_deref())
+                    {
+                        // Preserve the column with all its table information
+                        aggregate_schema_columns.push(col_info.clone());
+                    } else {
+                        // Fallback if column not found (shouldn't happen)
+                        let col_type = Self::infer_expr_type(group_expr, input_schema)?;
+                        aggregate_schema_columns.push(ColumnInfo {
+                            name: col.name.clone(),
+                            ty: col_type,
+                            database: None,
+                            table: col.table.clone(),
+                            table_alias: None,
+                        });
+                    }
+                }
                 _ => {
                     // For complex GROUP BY expressions, generate a name
-                    format!("__group_{}", aggregate_schema_columns.len())
+                    let col_name = format!("__group_{}", aggregate_schema_columns.len());
+                    let col_type = Self::infer_expr_type(group_expr, input_schema)?;
+                    aggregate_schema_columns.push(ColumnInfo {
+                        name: col_name,
+                        ty: col_type,
+                        database: None,
+                        table: None,
+                        table_alias: None,
+                    });
                 }
-            };
-            let col_type = Self::infer_expr_type(group_expr, input_schema)?;
-            aggregate_schema_columns.push((col_name, col_type));
+            }
         }
 
         // Track aggregates we've already seen to avoid duplicates
@@ -1098,7 +1204,13 @@
                     } else {
                         // New aggregate - add it
                         let col_type = Self::infer_expr_type(&logical_expr, input_schema)?;
-                        aggregate_schema_columns.push((col_name.clone(), col_type));
+                        aggregate_schema_columns.push(ColumnInfo {
+                            name: col_name.clone(),
+                            ty: col_type,
+                            database: None,
+                            table: None,
+                            table_alias: None,
+                        });
                         aggr_exprs.push(logical_expr);
                         aggregate_map.insert(agg_key, col_name.clone());
                         col_name.clone()
                     }
@@ -1122,7 +1234,13 @@
         // Add only new aggregates
         for (agg_expr, agg_name) in extracted_aggs {
             let agg_type = Self::infer_expr_type(&agg_expr, input_schema)?;
-            aggregate_schema_columns.push((agg_name, agg_type));
+            aggregate_schema_columns.push(ColumnInfo {
+                name: agg_name,
+                ty: agg_type,
+                database: None,
+                table: None,
+                table_alias: None,
+            });
             aggr_exprs.push(agg_expr);
         }
@@ -1197,7 +1315,13 @@
             // For type inference, we need the aggregate schema for column references
             let aggregate_schema = LogicalSchema::new(aggregate_schema_columns.clone());
             let col_type = Self::infer_expr_type(expr, &Arc::new(aggregate_schema))?;
-            projection_schema_columns.push((col_name, col_type));
+            projection_schema_columns.push(ColumnInfo {
+                name: col_name,
+                ty: col_type,
+                database: None,
+                table: None,
+                table_alias: None,
+            });
         }
 
         // Create the input plan (with pre-projection if needed)
@@ -1220,11 +1344,11 @@
         // Check if we need the outer projection
         // We need a projection if:
-        // 1. Any expression is more complex than a simple column reference (e.g., abs(sum(id)))
-        // 2. We're selecting a different set of columns than what the aggregate outputs
-        // 3. Columns are renamed or reordered
+        // 1. We have expressions that compute new values (e.g., SUM(x) * 2)
+        // 2. We're selecting a different set of columns than GROUP BY + aggregates
+        // 3. We're reordering columns from their natural aggregate output order
         let needs_outer_projection = {
-            // Check if any expression is more complex than a simple column reference
+            // Check for complex expressions
             let has_complex_exprs = projection_exprs
                 .iter()
                 .any(|expr| !matches!(expr, LogicalExpr::Column(_)));
 
             if has_complex_exprs {
                 true
             } else {
-                // All are simple columns - check if we're selecting exactly what the aggregate outputs
-                // The projection might be selecting a subset (e.g., only aggregates without group columns)
-                // or reordering columns, or using different names
+                // Check if we're selecting exactly what aggregate outputs in the same order
+                // The aggregate outputs: all GROUP BY columns, then all aggregate expressions
+                // The projection might select a subset or reorder these
 
-                // For now, keep it simple: if schemas don't match exactly, we need projection
-                // This handles all cases: subset selection, reordering, renaming
-                projection_schema_columns != aggregate_schema_columns
+                if projection_exprs.len() != aggregate_schema_columns.len() {
+                    // Different number of columns
+                    true
+                } else {
+                    // Check if columns match in order and name
+                    !projection_exprs.iter().zip(&aggregate_schema_columns).all(
+                        |(expr, agg_col)| {
+                            if let LogicalExpr::Column(col) = expr {
+                                col.name == agg_col.name
+                            } else {
+                                false
+                            }
+                        },
+                    )
+                }
             }
         };
 
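+        // For example, `SELECT SUM(x) FROM t GROUP BY y` still needs the outer
+        // projection: the aggregate's natural output is [y, SUM(x)] and the
+        // projection selects only the aggregate column.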
-        // Create the aggregate node
+        // Create the aggregate node with its natural schema
         let aggregate_plan = LogicalPlan::Aggregate(Aggregate {
             input: aggregate_input,
             group_expr: group_exprs,
@@ -1257,7 +1393,7 @@
                 schema: Arc::new(LogicalSchema::new(projection_schema_columns)),
             }))
         } else {
-            // No projection needed - the aggregate output is exactly what we want
+            // No projection needed - aggregate output matches what we want
             Ok(aggregate_plan)
         }
     }
@@ -1275,7 +1411,13 @@
         // Infer schema from first row
         let mut schema_columns = Vec::new();
         for (i, _) in values[0].iter().enumerate() {
-            schema_columns.push((format!("column{}", i + 1), Type::Text));
+            schema_columns.push(ColumnInfo {
+                name: format!("column{}", i + 1),
+                ty: Type::Text,
+                database: None,
+                table: None,
+                table_alias: None,
+            });
         }
 
         for row in values {
@@ -2003,17 +2145,31 @@
     }
 
     // Get table schema
-    fn get_table_schema(&self, table_name: &str) -> Result<SchemaRef> {
+    fn get_table_schema(&self, table_name: &str, alias: Option<&str>) -> Result<SchemaRef> {
         // Look up table in schema
         let table = self
             .schema
             .get_table(table_name)
             .ok_or_else(|| LimboError::ParseError(format!("Table '{table_name}' not found")))?;
 
+        // Parse table_name which might be "db.table" for attached databases
+        let (database, actual_table) = if table_name.contains('.') {
+            let parts: Vec<&str> = table_name.splitn(2, '.').collect();
+            (Some(parts[0].to_string()), parts[1].to_string())
+        } else {
+            (None, table_name.to_string())
+        };
+
         let mut columns = Vec::new();
         for col in table.columns() {
             if let Some(ref name) = col.name {
-                columns.push((name.clone(), col.ty));
+                columns.push(ColumnInfo {
+                    name: name.clone(),
+                    ty: col.ty,
+                    database: database.clone(),
+                    table: Some(actual_table.clone()),
+                    table_alias: alias.map(|s| s.to_string()),
+                });
             }
         }
 
@@ -2024,8 +2180,8 @@
     fn infer_expr_type(expr: &LogicalExpr, schema: &SchemaRef) -> Result<Type> {
         match expr {
             LogicalExpr::Column(col) => {
-                if let Some((_, typ)) = schema.find_column(&col.name) {
-                    Ok(*typ)
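+                // Qualified references thread col.table into find_column;
+                // unqualified names still take the first match by name, so
+                // joined schemas with duplicate names rely on qualification.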
+                if let Some((_, col_info)) = schema.find_column(&col.name, col.table.as_deref()) {
+                    Ok(col_info.ty)
                 } else {
                     Ok(Type::Text)
                 }