diff --git a/core/incremental/aggregate_operator.rs b/core/incremental/aggregate_operator.rs
index f4c8ece0a..9f25a84f5 100644
--- a/core/incremental/aggregate_operator.rs
+++ b/core/incremental/aggregate_operator.rs
@@ -19,20 +19,20 @@ pub const AGG_TYPE_MINMAX: u8 = 0b01; // MIN/MAX (BTree ordering gives both)
 #[derive(Debug, Clone, PartialEq)]
 pub enum AggregateFunction {
     Count,
-    Sum(String),
-    Avg(String),
-    Min(String),
-    Max(String),
+    Sum(usize), // Column index
+    Avg(usize), // Column index
+    Min(usize), // Column index
+    Max(usize), // Column index
 }
 
 impl Display for AggregateFunction {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         match self {
             AggregateFunction::Count => write!(f, "COUNT(*)"),
-            AggregateFunction::Sum(col) => write!(f, "SUM({col})"),
-            AggregateFunction::Avg(col) => write!(f, "AVG({col})"),
-            AggregateFunction::Min(col) => write!(f, "MIN({col})"),
-            AggregateFunction::Max(col) => write!(f, "MAX({col})"),
+            AggregateFunction::Sum(idx) => write!(f, "SUM(col{idx})"),
+            AggregateFunction::Avg(idx) => write!(f, "AVG(col{idx})"),
+            AggregateFunction::Min(idx) => write!(f, "MIN(col{idx})"),
+            AggregateFunction::Max(idx) => write!(f, "MAX(col{idx})"),
         }
     }
 }
@@ -48,16 +48,16 @@ impl AggregateFunction {
     /// Returns None if the function is not a supported aggregate
     pub fn from_sql_function(
         func: &crate::function::Func,
-        input_column: Option<String>,
+        input_column_idx: Option<usize>,
     ) -> Option<Self> {
         match func {
             Func::Agg(agg_func) => {
                 match agg_func {
                     AggFunc::Count | AggFunc::Count0 => Some(AggregateFunction::Count),
-                    AggFunc::Sum => input_column.map(AggregateFunction::Sum),
-                    AggFunc::Avg => input_column.map(AggregateFunction::Avg),
-                    AggFunc::Min => input_column.map(AggregateFunction::Min),
-                    AggFunc::Max => input_column.map(AggregateFunction::Max),
+                    AggFunc::Sum => input_column_idx.map(AggregateFunction::Sum),
+                    AggFunc::Avg => input_column_idx.map(AggregateFunction::Avg),
+                    AggFunc::Min => input_column_idx.map(AggregateFunction::Min),
+                    AggFunc::Max => input_column_idx.map(AggregateFunction::Max),
                     _ => None, // Other aggregate functions not yet supported in DBSP
                 }
             }
@@ -115,8 +115,8 @@ pub fn deserialize_value(blob: &[u8]) -> Option<(Value, usize)> {
 // group_key_str -> (group_key, state)
 type ComputedStates = HashMap<String, (Vec<Value>, AggregateState)>;
 
-// group_key_str -> (column_name, value_as_hashable_row) -> accumulated_weight
-pub type MinMaxDeltas = HashMap<String, HashMap<(String, HashableRow), isize>>;
+// group_key_str -> (column_index, value_as_hashable_row) -> accumulated_weight
+pub type MinMaxDeltas = HashMap<String, HashMap<(usize, HashableRow), isize>>;
 
 #[derive(Debug)]
 enum AggregateCommitState {
@@ -178,14 +178,14 @@ pub enum AggregateEvalState {
 pub struct AggregateOperator {
     // Unique operator ID for indexing in persistent storage
     pub operator_id: usize,
-    // GROUP BY columns
-    group_by: Vec<String>,
+    // GROUP BY column indices
+    group_by: Vec<usize>,
     // Aggregate functions to compute (including MIN/MAX)
     pub aggregates: Vec<AggregateFunction>,
     // Column names from input
     pub input_column_names: Vec<String>,
-    // Map from column name to aggregate info for quick lookup
-    pub column_min_max: HashMap<String, AggColumnInfo>,
+    // Map from column index to aggregate info for quick lookup
+    pub column_min_max: HashMap<usize, AggColumnInfo>,
     tracker: Option<Arc<Mutex<ComputationTracker>>>,
 
     // State machine for commit operation
@@ -197,14 +197,14 @@ pub struct AggregateOperator {
 pub struct AggregateState {
     // For COUNT: just the count
     pub count: i64,
-    // For SUM: column_name -> sum value
-    sums: HashMap<String, f64>,
-    // For AVG: column_name -> (sum, count) for computing average
-    avgs: HashMap<String, (f64, i64)>,
-    // For MIN: column_name -> minimum value
-    pub mins: HashMap<String, Value>,
-    // For MAX: column_name -> maximum value
-    pub maxs: HashMap<String, Value>,
+    // For SUM: column_index -> sum value
+    sums: HashMap<usize, f64>,
+    // For AVG: column_index -> (sum, count) for computing average
+    avgs: HashMap<usize, (f64, i64)>,
+    // For MIN: column_index -> minimum value
+    pub mins: HashMap<usize, Value>,
+    // For MAX: column_index -> maximum value
+    pub maxs: HashMap<usize, Value>,
 }
 
 impl AggregateEvalState {
@@ -520,14 +520,14 @@ impl AggregateState {
                 AggregateFunction::Sum(col_name) => {
                     let sum = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?);
                     cursor += 8;
-                    state.sums.insert(col_name.clone(), sum);
+                    state.sums.insert(*col_name, sum);
                 }
                 AggregateFunction::Avg(col_name) => {
                     let sum = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?);
                     cursor += 8;
                     let count = i64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?);
                     cursor += 8;
-                    state.avgs.insert(col_name.clone(), (sum, count));
+                    state.avgs.insert(*col_name, (sum, count));
                 }
                 AggregateFunction::Count => {
                     // Count was already read above
@@ -540,7 +540,7 @@ impl AggregateState {
                     if has_value == 1 {
                         let (min_value, bytes_consumed) = deserialize_value(&blob[cursor..])?;
                         cursor += bytes_consumed;
-                        state.mins.insert(col_name.clone(), min_value);
+                        state.mins.insert(*col_name, min_value);
                     }
                 }
                 AggregateFunction::Max(col_name) => {
@@ -551,7 +551,7 @@ impl AggregateState {
                     if has_value == 1 {
                         let (max_value, bytes_consumed) = deserialize_value(&blob[cursor..])?;
                         cursor += bytes_consumed;
-                        state.maxs.insert(col_name.clone(), max_value);
+                        state.maxs.insert(*col_name, max_value);
                     }
                 }
             }
@@ -566,7 +566,7 @@ impl AggregateState {
         values: &[Value],
         weight: isize,
         aggregates: &[AggregateFunction],
-        column_names: &[String],
+        _column_names: &[String], // No longer needed
     ) {
         // Update COUNT
         self.count += weight as i64;
@@ -577,32 +577,26 @@ impl AggregateState {
                 AggregateFunction::Count => {
                     // Already handled above
                 }
-                AggregateFunction::Sum(col_name) => {
-                    if let Some(idx) = column_names.iter().position(|c| c == col_name) {
-                        if let Some(val) = values.get(idx) {
-                            let num_val = match val {
-                                Value::Integer(i) => *i as f64,
-                                Value::Float(f) => *f,
-                                _ => 0.0,
-                            };
-                            *self.sums.entry(col_name.clone()).or_insert(0.0) +=
-                                num_val * weight as f64;
-                        }
+                AggregateFunction::Sum(col_idx) => {
+                    if let Some(val) = values.get(*col_idx) {
+                        let num_val = match val {
+                            Value::Integer(i) => *i as f64,
+                            Value::Float(f) => *f,
+                            _ => 0.0,
+                        };
+                        *self.sums.entry(*col_idx).or_insert(0.0) += num_val * weight as f64;
                     }
                 }
-                AggregateFunction::Avg(col_name) => {
-                    if let Some(idx) = column_names.iter().position(|c| c == col_name) {
-                        if let Some(val) = values.get(idx) {
-                            let num_val = match val {
-                                Value::Integer(i) => *i as f64,
-                                Value::Float(f) => *f,
-                                _ => 0.0,
-                            };
-                            let (sum, count) =
-                                self.avgs.entry(col_name.clone()).or_insert((0.0, 0));
-                            *sum += num_val * weight as f64;
-                            *count += weight as i64;
-                        }
+                AggregateFunction::Avg(col_idx) => {
+                    if let Some(val) = values.get(*col_idx) {
+                        let num_val = match val {
+                            Value::Integer(i) => *i as f64,
+                            Value::Float(f) => *f,
+                            _ => 0.0,
+                        };
+                        let (sum, count) = self.avgs.entry(*col_idx).or_insert((0.0, 0));
+                        *sum += num_val * weight as f64;
+                        *count += weight as i64;
                     }
                 }
                 AggregateFunction::Min(_col_name) | AggregateFunction::Max(_col_name) => {
@@ -644,8 +638,8 @@ impl AggregateState {
             AggregateFunction::Count => {
                 result.push(Value::Integer(self.count));
             }
-            AggregateFunction::Sum(col_name) => {
-                let sum = self.sums.get(col_name).copied().unwrap_or(0.0);
+            AggregateFunction::Sum(col_idx) => {
+                let sum = self.sums.get(col_idx).copied().unwrap_or(0.0);
                 // Return as integer if it's a whole number, otherwise as float
                 if sum.fract() == 0.0 {
                     result.push(Value::Integer(sum as i64));
@@ -653,8 +647,8 @@ impl AggregateState {
                     result.push(Value::Float(sum));
                 }
             }
-            AggregateFunction::Avg(col_name) => {
-                if let Some((sum, count)) = self.avgs.get(col_name) {
+            AggregateFunction::Avg(col_idx) => {
+                if let Some((sum, count)) = self.avgs.get(col_idx) {
                     if *count > 0 {
                         result.push(Value::Float(sum / *count as f64));
                     } else {
@@ -664,13 +658,13 @@ impl AggregateState {
                     result.push(Value::Null);
                 }
             }
-            AggregateFunction::Min(col_name) => {
+            AggregateFunction::Min(col_idx) => {
                 // Return the MIN value from our state
-                result.push(self.mins.get(col_name).cloned().unwrap_or(Value::Null));
+                result.push(self.mins.get(col_idx).cloned().unwrap_or(Value::Null));
             }
-            AggregateFunction::Max(col_name) => {
+            AggregateFunction::Max(col_idx) => {
                 // Return the MAX value from our state
-                result.push(self.maxs.get(col_name).cloned().unwrap_or(Value::Null));
+                result.push(self.maxs.get(col_idx).cloned().unwrap_or(Value::Null));
             }
         }
     }
@@ -682,20 +676,20 @@ impl AggregateState {
 impl AggregateOperator {
     pub fn new(
         operator_id: usize,
-        group_by: Vec<String>,
+        group_by: Vec<usize>,
         aggregates: Vec<AggregateFunction>,
         input_column_names: Vec<String>,
     ) -> Self {
-        // Build map of column names to their MIN/MAX info with indices
+        // Build map of column indices to their MIN/MAX info
         let mut column_min_max = HashMap::new();
-        let mut column_indices = HashMap::new();
+        let mut storage_indices = HashMap::new();
         let mut current_index = 0;
 
-        // First pass: assign indices to unique MIN/MAX columns
+        // First pass: assign storage indices to unique MIN/MAX columns
         for agg in &aggregates {
             match agg {
-                AggregateFunction::Min(col) | AggregateFunction::Max(col) => {
-                    column_indices.entry(col.clone()).or_insert_with(|| {
+                AggregateFunction::Min(col_idx) | AggregateFunction::Max(col_idx) => {
+                    storage_indices.entry(*col_idx).or_insert_with(|| {
                         let idx = current_index;
                         current_index += 1;
                         idx
@@ -708,19 +702,19 @@ impl AggregateOperator {
         // Second pass: build the column info map
         for agg in &aggregates {
             match agg {
-                AggregateFunction::Min(col) => {
-                    let index = *column_indices.get(col).unwrap();
-                    let entry = column_min_max.entry(col.clone()).or_insert(AggColumnInfo {
-                        index,
+                AggregateFunction::Min(col_idx) => {
+                    let storage_index = *storage_indices.get(col_idx).unwrap();
+                    let entry = column_min_max.entry(*col_idx).or_insert(AggColumnInfo {
+                        index: storage_index,
                         has_min: false,
                         has_max: false,
                     });
                     entry.has_min = true;
                 }
-                AggregateFunction::Max(col) => {
-                    let index = *column_indices.get(col).unwrap();
-                    let entry = column_min_max.entry(col.clone()).or_insert(AggColumnInfo {
-                        index,
+                AggregateFunction::Max(col_idx) => {
+                    let storage_index = *storage_indices.get(col_idx).unwrap();
+                    let entry = column_min_max.entry(*col_idx).or_insert(AggColumnInfo {
+                        index: storage_index,
                         has_min: false,
                         has_max: false,
                     });
@@ -876,28 +870,24 @@ impl AggregateOperator {
 
         for agg in &self.aggregates {
             match agg {
-                AggregateFunction::Min(col_name) | AggregateFunction::Max(col_name) => {
-                    if let Some(idx) =
-                        self.input_column_names.iter().position(|c| c == col_name)
-                    {
-                        if let Some(val) = row.values.get(idx) {
-                            // Skip NULL values - they don't participate in MIN/MAX
-                            if val == &Value::Null {
-                                continue;
-                            }
-                            // Create a HashableRow with just this value
-                            // Use 0 as rowid since we only care about the value for comparison
-                            let hashable_value = HashableRow::new(0, vec![val.clone()]);
-                            let key = (col_name.clone(), hashable_value);
-
-                            let group_entry =
-                                min_max_deltas.entry(group_key_str.clone()).or_default();
-
-                            let value_entry = group_entry.entry(key).or_insert(0);
-
-                            // Accumulate the weight
-                            *value_entry += weight;
+                AggregateFunction::Min(col_idx) | AggregateFunction::Max(col_idx) => {
+                    if let Some(val) = row.values.get(*col_idx) {
+                        // Skip NULL values - they don't participate in MIN/MAX
+                        if val == &Value::Null {
+                            continue;
                         }
+                        // Create a HashableRow with just this value
+                        // Use 0 as rowid since we only care about the value for comparison
+                        let hashable_value = HashableRow::new(0, vec![val.clone()]);
+                        let key = (*col_idx, hashable_value);
+
+                        let group_entry =
+                            min_max_deltas.entry(group_key_str.clone()).or_default();
+
+                        let value_entry = group_entry.entry(key).or_insert(0);
+
+                        // Accumulate the weight
+                        *value_entry += weight;
                     }
                 }
                 _ => {} // Ignore non-MIN/MAX aggregates
@@ -929,13 +919,9 @@ impl AggregateOperator {
     pub fn extract_group_key(&self, values: &[Value]) -> Vec<Value> {
         let mut key = Vec::new();
-        for group_col in &self.group_by {
-            if let Some(idx) = self.input_column_names.iter().position(|c| c == group_col) {
-                if let Some(val) = values.get(idx) {
-                    key.push(val.clone());
-                } else {
-                    key.push(Value::Null);
-                }
+        for &idx in &self.group_by {
+            if let Some(val) = values.get(idx) {
+                key.push(val.clone());
             } else {
                 key.push(Value::Null);
             }
         }
@@ -1124,13 +1110,13 @@ pub enum RecomputeMinMax {
         /// Current column being processed
         current_column_idx: usize,
         /// Columns to process (combined MIN and MAX)
-        columns_to_process: Vec<(String, String, bool)>, // (group_key, column_name, is_min)
+        columns_to_process: Vec<(String, usize, bool)>, // (group_key, column_index, is_min)
        /// MIN/MAX deltas for checking values and weights
         min_max_deltas: MinMaxDeltas,
     },
     Scan {
         /// Columns still to process
-        columns_to_process: Vec<(String, String, bool)>,
+        columns_to_process: Vec<(String, usize, bool)>,
         /// Current index in columns_to_process (will resume from here)
         current_column_idx: usize,
         /// MIN/MAX deltas for checking values and weights
@@ -1138,7 +1124,7 @@ pub enum RecomputeMinMax {
         /// Current group key being processed
         group_key: String,
         /// Current column name being processed
-        column_name: String,
+        column_name: usize,
         /// Whether we're looking for MIN (true) or MAX (false)
         is_min: bool,
         /// The scan state machine for finding the new MIN/MAX
@@ -1153,7 +1139,7 @@ impl RecomputeMinMax {
         existing_groups: &HashMap<String, AggregateState>,
         operator: &AggregateOperator,
     ) -> Self {
-        let mut groups_to_check: HashSet<(String, String, bool)> = HashSet::new();
+        let mut groups_to_check: HashSet<(String, usize, bool)> = HashSet::new();
 
         // Remember the min_max_deltas are essentially just the only column that is affected by
        // this min/max, in delta (actually ZSet - consolidated delta) format. This makes it easier
@@ -1173,21 +1159,13 @@ impl RecomputeMinMax {
                     // Check for MIN
                     if let Some(current_min) = state.mins.get(col_name) {
                         if current_min == value {
-                            groups_to_check.insert((
-                                group_key_str.clone(),
-                                col_name.clone(),
-                                true,
-                            ));
+                            groups_to_check.insert((group_key_str.clone(), *col_name, true));
                         }
                     }
                     // Check for MAX
                     if let Some(current_max) = state.maxs.get(col_name) {
                         if current_max == value {
-                            groups_to_check.insert((
-                                group_key_str.clone(),
-                                col_name.clone(),
-                                false,
-                            ));
+                            groups_to_check.insert((group_key_str.clone(), *col_name, false));
                         }
                     }
                 }
@@ -1196,14 +1174,10 @@ impl RecomputeMinMax {
                 // about this if this is a new record being inserted
                 if let Some(info) = col_info {
                     if info.has_min {
-                        groups_to_check.insert((group_key_str.clone(), col_name.clone(), true));
+                        groups_to_check.insert((group_key_str.clone(), *col_name, true));
                     }
                     if info.has_max {
-                        groups_to_check.insert((
-                            group_key_str.clone(),
-                            col_name.clone(),
-                            false,
-                        ));
+                        groups_to_check.insert((group_key_str.clone(), *col_name, false));
                     }
                 }
             }
@@ -1245,12 +1219,13 @@ impl RecomputeMinMax {
                 let (group_key, column_name, is_min) =
                     columns_to_process[*current_column_idx].clone();
 
-                // Get column index from pre-computed info
-                let column_index = operator
+                // Column name is already the index
+                // Get the storage index from column_min_max map
+                let column_info = operator
                     .column_min_max
                     .get(&column_name)
-                    .map(|info| info.index)
-                    .unwrap(); // Should always exist since we're processing known columns
+                    .expect("Column should exist in column_min_max map");
+                let storage_index = column_info.index;
 
                 // Get current value from existing state
                 let current_value = existing_groups.get(&group_key).and_then(|state| {
@@ -1263,7 +1238,7 @@ impl RecomputeMinMax {
                 // Create storage keys for index lookup
                 let storage_id =
-                    generate_storage_id(operator.operator_id, column_index, AGG_TYPE_MINMAX);
+                    generate_storage_id(operator.operator_id, storage_index, AGG_TYPE_MINMAX);
                 let zset_id = operator.generate_group_rowid(&group_key);
 
                 // Get the values for this group from min_max_deltas
@@ -1276,7 +1251,7 @@ impl RecomputeMinMax {
                     Box::new(ScanState::new_for_min(
                         current_value,
                         group_key.clone(),
-                        column_name.clone(),
+                        column_name,
                         storage_id,
                         zset_id,
                         group_values,
@@ -1285,7 +1260,7 @@ impl RecomputeMinMax {
                     Box::new(ScanState::new_for_max(
                         current_value,
                         group_key.clone(),
-                        column_name.clone(),
+                        column_name,
                         storage_id,
                         zset_id,
                         group_values,
@@ -1319,12 +1294,12 @@ impl RecomputeMinMax {
 
                     if *is_min {
                         if let Some(min_val) = new_value {
-                            state.mins.insert(column_name.clone(), min_val);
+                            state.mins.insert(*column_name, min_val);
                         } else {
                             state.mins.remove(column_name);
                         }
                     } else if let Some(max_val) = new_value {
-                        state.maxs.insert(column_name.clone(), max_val);
+                        state.maxs.insert(*column_name, max_val);
                     } else {
                         state.maxs.remove(column_name);
                     }
@@ -1355,13 +1330,13 @@ pub enum ScanState {
         /// Group key being processed
         group_key: String,
         /// Column name being processed
-        column_name: String,
+        column_name: usize,
         /// Storage ID for the index seek
         storage_id: i64,
         /// ZSet ID for the group
         zset_id: i64,
         /// Group values from MinMaxDeltas: (column_index, HashableRow) -> weight
-        group_values: HashMap<(String, HashableRow), isize>,
+        group_values: HashMap<(usize, HashableRow), isize>,
         /// Whether we're looking for MIN (true) or MAX (false)
         is_min: bool,
     },
@@ -1371,13 +1346,13 @@ pub enum ScanState {
         /// Group key being processed
         group_key: String,
         /// Column name being processed
-        column_name: String,
+        column_name: usize,
         /// Storage ID for the index seek
         storage_id: i64,
         /// ZSet ID for the group
         zset_id: i64,
         /// Group values from MinMaxDeltas: (column_index, HashableRow) -> weight
-        group_values: HashMap<(String, HashableRow), isize>,
+        group_values: HashMap<(usize, HashableRow), isize>,
         /// Whether we're looking for MIN (true) or MAX (false)
         is_min: bool,
     },
@@ -1391,10 +1366,10 @@ impl ScanState {
     pub fn new_for_min(
         current_min: Option<Value>,
         group_key: String,
-        column_name: String,
+        column_name: usize,
         storage_id: i64,
         zset_id: i64,
-        group_values: HashMap<(String, HashableRow), isize>,
+        group_values: HashMap<(usize, HashableRow), isize>,
     ) -> Self {
         Self::CheckCandidate {
             candidate: current_min,
@@ -1460,10 +1435,10 @@ impl ScanState {
     pub fn new_for_max(
         current_max: Option<Value>,
         group_key: String,
-        column_name: String,
+        column_name: usize,
         storage_id: i64,
         zset_id: i64,
-        group_values: HashMap<(String, HashableRow), isize>,
+        group_values: HashMap<(usize, HashableRow), isize>,
     ) -> Self {
         Self::CheckCandidate {
             candidate: current_max,
@@ -1496,7 +1471,7 @@ impl ScanState {
                     // Check if the candidate is retracted (weight <= 0)
                     // Create a HashableRow to look up the weight
                     let hashable_cand = HashableRow::new(0, vec![cand_val.clone()]);
-                    let key = (column_name.clone(), hashable_cand);
+                    let key = (*column_name, hashable_cand);
                     let is_retracted =
                         group_values.get(&key).is_some_and(|weight| *weight <= 0);
@@ -1633,7 +1608,7 @@ pub enum MinMaxPersistState {
         group_idx: usize,
         value_idx: usize,
         value: Value,
-        column_name: String,
+        column_name: usize,
         weight: isize,
         write_row: WriteRow,
     },
@@ -1652,7 +1627,7 @@ impl MinMaxPersistState {
     pub fn persist_min_max(
         &mut self,
         operator_id: usize,
-        column_min_max: &HashMap<String, AggColumnInfo>,
+        column_min_max: &HashMap<usize, AggColumnInfo>,
         cursors: &mut DbspStateCursors,
         generate_group_rowid: impl Fn(&str) -> i64,
     ) -> Result<IOResult<()>> {
@@ -1699,7 +1674,7 @@ impl MinMaxPersistState {
                     // Process current value and extract what we need before taking ownership
                     let ((column_name, hashable_row), weight) = values_vec[*value_idx];
-                    let column_name = column_name.clone();
+                    let column_name = *column_name;
                     let value = hashable_row.values[0].clone(); // Extract the Value from HashableRow
                     let weight = *weight;
@@ -1731,9 +1706,9 @@ impl MinMaxPersistState {
 
                     let group_key_str = &group_keys[*group_idx];
 
-                    // Get the column index from the pre-computed map
+                    // Get the column info from the pre-computed map
                     let column_info = column_min_max
-                        .get(&*column_name)
+                        .get(column_name)
                         .expect("Column should exist in column_min_max map");
                     let column_index = column_info.index;
diff --git a/core/incremental/compiler.rs b/core/incremental/compiler.rs
index 972d6797b..c8899a02e 100644
--- a/core/incremental/compiler.rs
+++ b/core/incremental/compiler.rs
@@ -9,12 +9,12 @@ use crate::incremental::dbsp::{Delta, DeltaPair};
 use crate::incremental::expr_compiler::CompiledExpression;
 use crate::incremental::operator::{
     create_dbsp_state_index, DbspStateCursors, EvalState, FilterOperator, FilterPredicate,
-    IncrementalOperator, InputOperator, ProjectOperator,
+    IncrementalOperator, InputOperator, JoinOperator, JoinType, ProjectOperator,
 };
 use crate::storage::btree::{BTreeCursor, BTreeKey};
 // Note: logical module must be made pub(crate) in translate/mod.rs
 use crate::translate::logical::{
-    BinaryOperator, LogicalExpr, LogicalPlan, LogicalSchema, SchemaRef,
+    BinaryOperator, JoinType as LogicalJoinType, LogicalExpr, LogicalPlan, LogicalSchema, SchemaRef,
 };
 use crate::types::{IOResult, ImmutableRecord, SeekKey, SeekOp,
 SeekResult, Value};
 use crate::Pager;
@@ -288,6 +288,12 @@ pub enum DbspOperator {
         aggr_exprs: Vec<DbspExpr>,
         schema: SchemaRef,
     },
+    /// Join operator (⋈) - joins two relations
+    Join {
+        join_type: JoinType,
+        on_exprs: Vec<(DbspExpr, DbspExpr)>,
+        schema: SchemaRef,
+    },
     /// Input operator - source of data
     Input { name: String, schema: SchemaRef },
 }
@@ -789,6 +795,13 @@ impl DbspCircuit {
                     "{indent}Aggregate[{node_id}]: GROUP BY {group_exprs:?}, AGGR {aggr_exprs:?}"
                 )?;
             }
+            DbspOperator::Join {
+                join_type,
+                on_exprs,
+                ..
+            } => {
+                writeln!(f, "{indent}Join[{node_id}]: {join_type:?} ON {on_exprs:?}")?;
+            }
             DbspOperator::Input { name, .. } => {
                 writeln!(f, "{indent}Input[{node_id}]: {name}")?;
             }
@@ -841,7 +854,7 @@ impl DbspCompiler {
                 // Get input column names for the ProjectOperator
                 let input_schema = proj.input.schema();
                 let input_column_names: Vec<String> = input_schema.columns.iter()
-                    .map(|(name, _)| name.clone())
+                    .map(|col| col.name.clone())
                     .collect();
 
                 // Convert logical expressions to DBSP expressions
                 let mut compiled_exprs = Vec::new();
                 let mut aliases = Vec::new();
                 for expr in &proj.exprs {
-                    let (compiled, alias) = Self::compile_expression(expr, &input_column_names)?;
+                    let (compiled, alias) = Self::compile_expression(expr, input_schema)?;
                     compiled_exprs.push(compiled);
                     aliases.push(alias);
                 }
 
                 // Get output column names from the projection schema
                 let output_column_names: Vec<String> = proj.schema.columns.iter()
-                    .map(|(name, _)| name.clone())
+                    .map(|col| col.name.clone())
                     .collect();
 
                 // Create the ProjectOperator
@@ -885,7 +898,7 @@ impl DbspCompiler {
                 // Get column names from input schema
                 let input_schema = filter.input.schema();
                 let column_names: Vec<String> = input_schema.columns.iter()
-                    .map(|(name, _)| name.clone())
+                    .map(|col| col.name.clone())
                     .collect();
 
                 // Convert predicate to DBSP expression
@@ -913,16 +926,21 @@ impl DbspCompiler {
                 // Get input column names
                 let input_schema = agg.input.schema();
                 let input_column_names: Vec<String> = input_schema.columns.iter()
-                    .map(|(name, _)| name.clone())
+                    .map(|col| col.name.clone())
                     .collect();
 
-                // Compile group by expressions to column names
-                let mut group_by_columns = Vec::new();
+                // Compile group by expressions to column indices
+                let mut group_by_indices = Vec::new();
                 let mut dbsp_group_exprs = Vec::new();
                 for expr in &agg.group_expr {
                     // For now, only support simple column references in GROUP BY
                     if let LogicalExpr::Column(col) = expr {
-                        group_by_columns.push(col.name.clone());
+                        // Find the column index in the input schema using qualified lookup
+                        let (col_idx, _) = input_schema.find_column(&col.name, col.table.as_deref())
+                            .ok_or_else(|| LimboError::ParseError(
+                                format!("GROUP BY column '{}' not found in input", col.name)
+                            ))?;
+                        group_by_indices.push(col_idx);
                         dbsp_group_exprs.push(DbspExpr::Column(col.name.clone()));
                     } else {
                         return Err(LimboError::ParseError(
@@ -936,7 +954,7 @@ impl DbspCompiler {
                 for expr in &agg.aggr_expr {
                     if let LogicalExpr::AggregateFunction { fun, args, ..
 } = expr {
                         use crate::function::AggFunc;
-                        use crate::incremental::operator::AggregateFunction;
+                        use crate::incremental::aggregate_operator::AggregateFunction;
 
                         match fun {
                             AggFunc::Count | AggFunc::Count0 => {
@@ -946,9 +964,13 @@ impl DbspCompiler {
                                 if args.is_empty() {
                                     return Err(LimboError::ParseError("SUM requires an argument".to_string()));
                                 }
-                                // Extract column name from the argument
+                                // Extract column index from the argument
                                 if let LogicalExpr::Column(col) = &args[0] {
-                                    aggregate_functions.push(AggregateFunction::Sum(col.name.clone()));
+                                    let (col_idx, _) = input_schema.find_column(&col.name, col.table.as_deref())
+                                        .ok_or_else(|| LimboError::ParseError(
+                                            format!("SUM column '{}' not found in input", col.name)
+                                        ))?;
+                                    aggregate_functions.push(AggregateFunction::Sum(col_idx));
                                 } else {
                                     return Err(LimboError::ParseError(
                                         "Only column references are supported in aggregate functions for incremental views".to_string()
@@ -960,7 +982,11 @@ impl DbspCompiler {
                                     return Err(LimboError::ParseError("AVG requires an argument".to_string()));
                                 }
                                 if let LogicalExpr::Column(col) = &args[0] {
-                                    aggregate_functions.push(AggregateFunction::Avg(col.name.clone()));
+                                    let (col_idx, _) = input_schema.find_column(&col.name, col.table.as_deref())
+                                        .ok_or_else(|| LimboError::ParseError(
+                                            format!("AVG column '{}' not found in input", col.name)
+                                        ))?;
+                                    aggregate_functions.push(AggregateFunction::Avg(col_idx));
                                 } else {
                                     return Err(LimboError::ParseError(
                                         "Only column references are supported in aggregate functions for incremental views".to_string()
@@ -972,7 +998,11 @@ impl DbspCompiler {
                                     return Err(LimboError::ParseError("MIN requires an argument".to_string()));
                                 }
                                 if let LogicalExpr::Column(col) = &args[0] {
-                                    aggregate_functions.push(AggregateFunction::Min(col.name.clone()));
+                                    let (col_idx, _) = input_schema.find_column(&col.name, col.table.as_deref())
+                                        .ok_or_else(|| LimboError::ParseError(
+                                            format!("MIN column '{}' not found in input", col.name)
+                                        ))?;
+                                    aggregate_functions.push(AggregateFunction::Min(col_idx));
                                 } else {
                                     return Err(LimboError::ParseError(
                                         "Only column references are supported in MIN for incremental views".to_string()
@@ -984,7 +1014,11 @@ impl DbspCompiler {
                                     return Err(LimboError::ParseError("MAX requires an argument".to_string()));
                                 }
                                 if let LogicalExpr::Column(col) = &args[0] {
-                                    aggregate_functions.push(AggregateFunction::Max(col.name.clone()));
+                                    let (col_idx, _) = input_schema.find_column(&col.name, col.table.as_deref())
+                                        .ok_or_else(|| LimboError::ParseError(
+                                            format!("MAX column '{}' not found in input", col.name)
+                                        ))?;
+                                    aggregate_functions.push(AggregateFunction::Max(col_idx));
                                 } else {
                                     return Err(LimboError::ParseError(
                                         "Only column references are supported in MAX for incremental views".to_string()
@@ -1006,10 +1040,10 @@ impl DbspCompiler {
 
                 let operator_id = self.circuit.next_id;
 
-                use crate::incremental::operator::AggregateOperator;
+                use crate::incremental::aggregate_operator::AggregateOperator;
                 let executable: Box<dyn IncrementalOperator> = Box::new(AggregateOperator::new(
                     operator_id,
-                    group_by_columns.clone(),
+                    group_by_indices.clone(),
                     aggregate_functions.clone(),
                     input_column_names.clone(),
                 ));
@@ -1026,6 +1060,90 @@ impl DbspCompiler {
 
                 Ok(result_node_id)
             }
+            LogicalPlan::Join(join) => {
+                // Compile left and right inputs
+                let left_id = self.compile_plan(&join.left)?;
+                let right_id = self.compile_plan(&join.right)?;
+
+                // Get schemas from inputs
+                let left_schema = join.left.schema();
+                let right_schema = join.right.schema();
+
+                // Get column names from left and right
+                let left_columns: Vec<String> = left_schema.columns.iter()
+                    .map(|col| col.name.clone())
+                    .collect();
+                let right_columns: Vec<String> = right_schema.columns.iter()
+                    .map(|col| col.name.clone())
+                    .collect();
+
+                // Extract join key indices from join conditions
+                // For now, we only support equijoin conditions
+                let mut left_key_indices = Vec::new();
+                let mut right_key_indices = Vec::new();
+                let mut dbsp_on_exprs = Vec::new();
+
+                for (left_expr, right_expr) in &join.on {
+                    // Extract column indices from join expressions
+                    // We expect simple column references in join conditions
+                    if let (LogicalExpr::Column(left_col), LogicalExpr::Column(right_col)) = (left_expr, right_expr) {
+                        // Find indices in respective schemas using qualified lookup
+                        let (left_idx, _) = left_schema.find_column(&left_col.name, left_col.table.as_deref())
+                            .ok_or_else(|| LimboError::ParseError(
+                                format!("Join column '{}' not found in left input", left_col.name)
+                            ))?;
+                        let (right_idx, _) = right_schema.find_column(&right_col.name, right_col.table.as_deref())
+                            .ok_or_else(|| LimboError::ParseError(
+                                format!("Join column '{}' not found in right input", right_col.name)
+                            ))?;
+
+                        left_key_indices.push(left_idx);
+                        right_key_indices.push(right_idx);
+
+                        // Convert to DBSP expressions
+                        dbsp_on_exprs.push((
+                            DbspExpr::Column(left_col.name.clone()),
+                            DbspExpr::Column(right_col.name.clone())
+                        ));
+                    } else {
+                        return Err(LimboError::ParseError(
+                            "Only simple column references are supported in join conditions for incremental views".to_string()
+                        ));
+                    }
+                }
+
+                // Convert logical join type to operator join type
+                let operator_join_type = match join.join_type {
+                    LogicalJoinType::Inner => JoinType::Inner,
+                    LogicalJoinType::Left => JoinType::Left,
+                    LogicalJoinType::Right => JoinType::Right,
+                    LogicalJoinType::Full => JoinType::Full,
+                    LogicalJoinType::Cross => JoinType::Cross,
+                };
+
+                // Create JoinOperator
+                let operator_id = self.circuit.next_id;
+                let executable: Box<dyn IncrementalOperator> = Box::new(JoinOperator::new(
+                    operator_id,
+                    operator_join_type.clone(),
+                    left_key_indices,
+                    right_key_indices,
+                    left_columns,
+                    right_columns,
+                )?);
+
+                // Create join node
+                let node_id = self.circuit.add_node(
+                    DbspOperator::Join {
+                        join_type: operator_join_type,
+                        on_exprs: dbsp_on_exprs,
+                        schema: join.schema.clone(),
+                    },
+                    vec![left_id, right_id],
+                    executable,
+                );
+                Ok(node_id)
+            }
             LogicalPlan::TableScan(scan) => {
                 // Create input node with InputOperator for uniform handling
                 let executable: Box<dyn IncrementalOperator> =
@@ -1042,7 +1160,7 @@ impl DbspCompiler {
                 Ok(node_id)
             }
             _ => Err(LimboError::ParseError(
-                format!("Unsupported operator in DBSP compiler: only Filter, Projection and Aggregate are supported, got: {:?}",
+                format!("Unsupported operator in DBSP compiler: only Filter, Projection, Join and Aggregate are supported, got: {:?}",
                     match plan {
                         LogicalPlan::Sort(_) => "Sort",
                         LogicalPlan::Limit(_) => "Limit",
@@ -1095,17 +1213,24 @@ impl DbspCompiler {
     /// Compile a logical expression to a CompiledExpression and optional alias
     fn compile_expression(
         expr: &LogicalExpr,
-        input_column_names: &[String],
+        input_schema: &LogicalSchema,
     ) -> Result<(CompiledExpression, Option<String>)> {
         // Check for alias first
         if let LogicalExpr::Alias { expr, alias } = expr {
             // For aliases, compile the underlying expression and return with alias
-            let (compiled, _) = Self::compile_expression(expr, input_column_names)?;
+            let (compiled, _) = Self::compile_expression(expr, input_schema)?;
             return Ok((compiled, Some(alias.clone())));
         }
 
-        // Convert LogicalExpr to AST Expr
-        let ast_expr = Self::logical_to_ast_expr(expr)?;
+        // Convert LogicalExpr to AST Expr with proper column resolution
+        let ast_expr = Self::logical_to_ast_expr_with_schema(expr, input_schema)?;
+
+        // Extract column names from schema for CompiledExpression::compile
+        let input_column_names: Vec<String> = input_schema
+            .columns
+            .iter()
+            .map(|col| col.name.clone())
+            .collect();
 
         // For all expressions (simple or complex), use CompiledExpression::compile
         // This handles both trivial cases and complex VDBE compilation
@@ -1129,7 +1254,7 @@ impl DbspCompiler {
         // Compile the expression using the existing CompiledExpression::compile
         let compiled = CompiledExpression::compile(
             &ast_expr,
-            input_column_names,
+            &input_column_names,
             &schema,
             &temp_syms,
             internal_conn,
@@ -1138,12 +1263,27 @@ impl DbspCompiler {
         Ok((compiled, None))
     }
 
-    /// Convert LogicalExpr to AST Expr
-    fn logical_to_ast_expr(expr: &LogicalExpr) -> Result<ast::Expr> {
+    /// Convert LogicalExpr to AST Expr with qualified column resolution
+    fn logical_to_ast_expr_with_schema(
+        expr: &LogicalExpr,
+        schema: &LogicalSchema,
+    ) -> Result<ast::Expr> {
         use turso_parser::ast;
 
         match expr {
-            LogicalExpr::Column(col) => Ok(ast::Expr::Id(ast::Name::Ident(col.name.clone()))),
+            LogicalExpr::Column(col) => {
+                // Find the column index using qualified lookup
+                let (idx, _) = schema
+                    .find_column(&col.name, col.table.as_deref())
+                    .ok_or_else(|| {
+                        LimboError::ParseError(format!(
+                            "Column '{}' with table {:?} not found in schema",
+                            col.name, col.table
+                        ))
+                    })?;
+                // Return a Register expression with the correct index
+                Ok(ast::Expr::Register(idx))
+            }
             LogicalExpr::Literal(val) => {
                 let lit = match val {
                     Value::Integer(i) => ast::Literal::Numeric(i.to_string()),
@@ -1155,8 +1295,8 @@ impl DbspCompiler {
                 Ok(ast::Expr::Literal(lit))
             }
             LogicalExpr::BinaryExpr { left, op, right } => {
-                let left_expr = Self::logical_to_ast_expr(left)?;
-                let right_expr = Self::logical_to_ast_expr(right)?;
+                let left_expr = Self::logical_to_ast_expr_with_schema(left, schema)?;
+                let right_expr = Self::logical_to_ast_expr_with_schema(right, schema)?;
                 Ok(ast::Expr::Binary(
                     Box::new(left_expr),
                     *op,
@@ -1164,7 +1304,10 @@ impl DbspCompiler {
                 ))
             }
             LogicalExpr::ScalarFunction { fun, args } => {
-                let ast_args: Result<Vec<ast::Expr>> = args.iter().map(Self::logical_to_ast_expr).collect();
+                let ast_args: Result<Vec<ast::Expr>> = args
+                    .iter()
+                    .map(|arg| Self::logical_to_ast_expr_with_schema(arg, schema))
+                    .collect();
                 let ast_args: Vec<Box<ast::Expr>> = ast_args?.into_iter().map(Box::new).collect();
                 Ok(ast::Expr::FunctionCall {
                     name: ast::Name::Ident(fun.clone()),
@@ -1179,7 +1322,7 @@ impl DbspCompiler {
             }
             LogicalExpr::Alias { expr, ..
 } => {
                 // For conversion to AST, ignore the alias and convert the inner expression
-                Self::logical_to_ast_expr(expr)
+                Self::logical_to_ast_expr_with_schema(expr, schema)
             }
             LogicalExpr::AggregateFunction {
                 fun,
                 args,
                 distinct,
             } => {
                 // Convert aggregate function to AST
-                let ast_args: Result<Vec<ast::Expr>> = args.iter().map(Self::logical_to_ast_expr).collect();
+                let ast_args: Result<Vec<ast::Expr>> = args
+                    .iter()
+                    .map(|arg| Self::logical_to_ast_expr_with_schema(arg, schema))
+                    .collect();
                 let ast_args: Vec<Box<ast::Expr>> = ast_args?.into_iter().map(Box::new).collect();
 
                 // Get the function name based on the aggregate type
@@ -1315,8 +1461,7 @@ mod tests {
     use crate::incremental::operator::{FilterOperator, FilterPredicate};
     use crate::schema::{BTreeTable, Column as SchemaColumn, Schema, Type};
     use crate::storage::pager::CreateBTreeFlags;
-    use crate::translate::logical::LogicalPlanBuilder;
-    use crate::translate::logical::LogicalSchema;
+    use crate::translate::logical::{ColumnInfo, LogicalPlanBuilder, LogicalSchema};
     use crate::util::IOExt;
     use crate::{Database, MemoryIO, Pager, IO};
     use std::sync::Arc;
@@ -1374,6 +1519,270 @@ mod tests {
             unique_sets: vec![],
         };
         schema.add_btree_table(Arc::new(users_table));
+
+        // Add products table for join tests
+        let products_table = BTreeTable {
+            name: "products".to_string(),
+            root_page: 3,
+            primary_key_columns: vec![(
+                "product_id".to_string(),
+                turso_parser::ast::SortOrder::Asc,
+            )],
+            columns: vec![
+                SchemaColumn {
+                    name: Some("product_id".to_string()),
+                    ty: Type::Integer,
+                    ty_str: "INTEGER".to_string(),
+                    primary_key: true,
+                    is_rowid_alias: true,
+                    notnull: true,
+                    default: None,
+                    unique: false,
+                    collation: None,
+                    hidden: false,
+                },
+                SchemaColumn {
+                    name: Some("product_name".to_string()),
+                    ty: Type::Text,
+                    ty_str: "TEXT".to_string(),
+                    primary_key: false,
+                    is_rowid_alias: false,
+                    notnull: false,
+                    default: None,
+                    unique: false,
+                    collation: None,
+                    hidden: false,
+                },
+                SchemaColumn {
+                    name: Some("price".to_string()),
+                    ty: Type::Integer,
+                    ty_str: "INTEGER".to_string(),
+                    primary_key: false,
+                    is_rowid_alias: false,
+                    notnull: false,
+                    default: None,
+                    unique: false,
+                    collation: None,
+                    hidden: false,
+                },
+            ],
+            has_rowid: true,
+            is_strict: false,
+            unique_sets: vec![],
+        };
+        schema.add_btree_table(Arc::new(products_table));
+
+        // Add orders table for join tests
+        let orders_table = BTreeTable {
+            name: "orders".to_string(),
+            root_page: 4,
+            primary_key_columns: vec![(
+                "order_id".to_string(),
+                turso_parser::ast::SortOrder::Asc,
+            )],
+            columns: vec![
+                SchemaColumn {
+                    name: Some("order_id".to_string()),
+                    ty: Type::Integer,
+                    ty_str: "INTEGER".to_string(),
+                    primary_key: true,
+                    is_rowid_alias: true,
+                    notnull: true,
+                    default: None,
+                    unique: false,
+                    collation: None,
+                    hidden: false,
+                },
+                SchemaColumn {
+                    name: Some("user_id".to_string()),
+                    ty: Type::Integer,
+                    ty_str: "INTEGER".to_string(),
+                    primary_key: false,
+                    is_rowid_alias: false,
+                    notnull: false,
+                    default: None,
+                    unique: false,
+                    collation: None,
+                    hidden: false,
+                },
+                SchemaColumn {
+                    name: Some("product_id".to_string()),
+                    ty: Type::Integer,
+                    ty_str: "INTEGER".to_string(),
+                    primary_key: false,
+                    is_rowid_alias: false,
+                    notnull: false,
+                    default: None,
+                    unique: false,
+                    collation: None,
+                    hidden: false,
+                },
+                SchemaColumn {
+                    name: Some("quantity".to_string()),
+                    ty: Type::Integer,
+                    ty_str: "INTEGER".to_string(),
+                    primary_key: false,
+                    is_rowid_alias: false,
+                    notnull: false,
+                    default: None,
+                    unique: false,
+                    collation: None,
+                    hidden:
false, + }, + ], + has_rowid: true, + is_strict: false, + unique_sets: vec![], + }; + schema.add_btree_table(Arc::new(orders_table)); + + // Add customers table with id and name for testing column ambiguity + let customers_table = BTreeTable { + name: "customers".to_string(), + root_page: 6, + primary_key_columns: vec![("id".to_string(), turso_parser::ast::SortOrder::Asc)], + columns: vec![ + SchemaColumn { + name: Some("id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: true, + is_rowid_alias: true, + notnull: true, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("name".to_string()), + ty: Type::Text, + ty_str: "TEXT".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + ], + has_rowid: true, + is_strict: false, + unique_sets: vec![], + }; + schema.add_btree_table(Arc::new(customers_table)); + + // Add purchases table (junction table for three-way join) + let purchases_table = BTreeTable { + name: "purchases".to_string(), + root_page: 7, + primary_key_columns: vec![("id".to_string(), turso_parser::ast::SortOrder::Asc)], + columns: vec![ + SchemaColumn { + name: Some("id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: true, + is_rowid_alias: true, + notnull: true, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("customer_id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("vendor_id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("quantity".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + ], + has_rowid: true, + is_strict: false, + unique_sets: vec![], + }; + schema.add_btree_table(Arc::new(purchases_table)); + + // Add vendors table with id, name, and price (ambiguous columns with customers) + let vendors_table = BTreeTable { + name: "vendors".to_string(), + root_page: 8, + primary_key_columns: vec![("id".to_string(), turso_parser::ast::SortOrder::Asc)], + columns: vec![ + SchemaColumn { + name: Some("id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: true, + is_rowid_alias: true, + notnull: true, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("name".to_string()), + ty: Type::Text, + ty_str: "TEXT".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("price".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + ], + has_rowid: true, + is_strict: false, + unique_sets: vec![], + }; + schema.add_btree_table(Arc::new(vendors_table)); + let sales_table = BTreeTable { name: "sales".to_string(), root_page: 2, 
@@ -3342,8 +3751,20 @@ mod tests {
 
         // Create a simple filter node
         let schema = Arc::new(LogicalSchema::new(vec![
-            ("id".to_string(), Type::Integer),
-            ("value".to_string(), Type::Integer),
+            ColumnInfo {
+                name: "id".to_string(),
+                ty: Type::Integer,
+                database: None,
+                table: None,
+                table_alias: None,
+            },
+            ColumnInfo {
+                name: "value".to_string(),
+                ty: Type::Integer,
+                database: None,
+                table: None,
+                table_alias: None,
+            },
         ]));
 
         // First create an input node with InputOperator
@@ -3486,4 +3907,767 @@ mod tests {
             "Row should still exist with multiplicity 1"
         );
     }
+
+    #[test]
+    fn test_join_with_aggregation() {
+        // Test join followed by aggregation - verifying actual output
+        let (mut circuit, pager) = compile_sql!(
+            "SELECT u.name, SUM(o.quantity) as total_quantity
+             FROM users u
+             JOIN orders o ON u.id = o.user_id
+             GROUP BY u.name"
+        );
+
+        // Create test data for users
+        let mut users_delta = Delta::new();
+        users_delta.insert(
+            1,
+            vec![
+                Value::Integer(1),
+                Value::Text("Alice".into()),
+                Value::Integer(30),
+            ],
+        );
+        users_delta.insert(
+            2,
+            vec![
+                Value::Integer(2),
+                Value::Text("Bob".into()),
+                Value::Integer(25),
+            ],
+        );
+
+        // Create test data for orders (order_id, user_id, product_id, quantity)
+        let mut orders_delta = Delta::new();
+        orders_delta.insert(
+            1,
+            vec![
+                Value::Integer(1),
+                Value::Integer(1),
+                Value::Integer(101),
+                Value::Integer(5),
+            ],
+        ); // Alice: 5
+        orders_delta.insert(
+            2,
+            vec![
+                Value::Integer(2),
+                Value::Integer(1),
+                Value::Integer(102),
+                Value::Integer(3),
+            ],
+        ); // Alice: 3
+        orders_delta.insert(
+            3,
+            vec![
+                Value::Integer(3),
+                Value::Integer(2),
+                Value::Integer(101),
+                Value::Integer(7),
+            ],
+        ); // Bob: 7
+        orders_delta.insert(
+            4,
+            vec![
+                Value::Integer(4),
+                Value::Integer(1),
+                Value::Integer(103),
+                Value::Integer(2),
+            ],
+        ); // Alice: 2
+
+        let inputs = HashMap::from([
+            ("users".to_string(), users_delta),
+            ("orders".to_string(), orders_delta),
+        ]);
+
+        let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap();
+
+        // Should have 2 results: Alice with total 10, Bob with total 7
+        assert_eq!(
+            result.len(),
+            2,
+            "Should have aggregated results for Alice and Bob"
+        );
+
+        // Check the results
+        let mut results_map: HashMap<String, i64> = HashMap::new();
+        for (row, weight) in result.changes {
+            assert_eq!(weight, 1);
+            assert_eq!(row.values.len(), 2); // name and total_quantity
+
+            if let (Value::Text(name), Value::Integer(total)) = (&row.values[0], &row.values[1]) {
+                results_map.insert(name.to_string(), *total);
+            } else {
+                panic!("Unexpected value types in result");
+            }
+        }
+
+        assert_eq!(
+            results_map.get("Alice"),
+            Some(&10),
+            "Alice should have total quantity 10"
+        );
+        assert_eq!(
+            results_map.get("Bob"),
+            Some(&7),
+            "Bob should have total quantity 7"
+        );
+    }
+
+    #[test]
+    fn test_join_aggregate_with_filter() {
+        // Test complex query with join, filter, and aggregation - verifying output
+        let (mut circuit, pager) = compile_sql!(
+            "SELECT u.name, SUM(o.quantity) as total
+             FROM users u
+             JOIN orders o ON u.id = o.user_id
+             WHERE u.age > 18
+             GROUP BY u.name"
+        );
+
+        // Create test data for users
+        let mut users_delta = Delta::new();
+        users_delta.insert(
+            1,
+            vec![
+                Value::Integer(1),
+                Value::Text("Alice".into()),
+                Value::Integer(30),
+            ],
+        ); // age > 18
+        users_delta.insert(
+            2,
+            vec![
+                Value::Integer(2),
+                Value::Text("Bob".into()),
+                Value::Integer(17),
+            ],
+        ); // age <= 18
+        users_delta.insert(
+            3,
+            vec![
+                Value::Integer(3),
Value::Text("Charlie".into()), + Value::Integer(25), + ], + ); // age > 18 + + // Create test data for orders (order_id, user_id, product_id, quantity) + let mut orders_delta = Delta::new(); + orders_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Integer(1), + Value::Integer(101), + Value::Integer(5), + ], + ); // Alice: 5 + orders_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Integer(2), + Value::Integer(102), + Value::Integer(10), + ], + ); // Bob: 10 (should be filtered) + orders_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Integer(3), + Value::Integer(101), + Value::Integer(7), + ], + ); // Charlie: 7 + orders_delta.insert( + 4, + vec![ + Value::Integer(4), + Value::Integer(1), + Value::Integer(103), + Value::Integer(3), + ], + ); // Alice: 3 + + let inputs = HashMap::from([ + ("users".to_string(), users_delta), + ("orders".to_string(), orders_delta), + ]); + + let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap(); + + // Should only have results for Alice and Charlie (Bob filtered out due to age <= 18) + assert_eq!( + result.len(), + 2, + "Should only have results for users with age > 18" + ); + + // Check the results + let mut results_map: HashMap = HashMap::new(); + for (row, weight) in result.changes { + assert_eq!(weight, 1); + assert_eq!(row.values.len(), 2); // name and total + + if let (Value::Text(name), Value::Integer(total)) = (&row.values[0], &row.values[1]) { + results_map.insert(name.to_string(), *total); + } + } + + assert_eq!( + results_map.get("Alice"), + Some(&8), + "Alice should have total 8" + ); + assert_eq!( + results_map.get("Charlie"), + Some(&7), + "Charlie should have total 7" + ); + assert_eq!(results_map.get("Bob"), None, "Bob should be filtered out"); + } + + #[test] + fn test_three_way_join_execution() { + // Test executing a 3-way join with aggregation + let (mut circuit, pager) = compile_sql!( + "SELECT u.name, p.product_name, SUM(o.quantity) as total + FROM users u + JOIN orders o ON u.id = o.user_id + JOIN products p ON o.product_id = p.product_id + GROUP BY u.name, p.product_name" + ); + + // Create test data for users + let mut users_delta = Delta::new(); + users_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".into()), + Value::Integer(25), + ], + ); + users_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Bob".into()), + Value::Integer(30), + ], + ); + + // Create test data for products + let mut products_delta = Delta::new(); + products_delta.insert( + 100, + vec![ + Value::Integer(100), + Value::Text("Widget".into()), + Value::Integer(50), + ], + ); + products_delta.insert( + 101, + vec![ + Value::Integer(101), + Value::Text("Gadget".into()), + Value::Integer(75), + ], + ); + products_delta.insert( + 102, + vec![ + Value::Integer(102), + Value::Text("Doohickey".into()), + Value::Integer(25), + ], + ); + + // Create test data for orders joining users and products + let mut orders_delta = Delta::new(); + // Alice orders 5 Widgets + orders_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Integer(1), + Value::Integer(100), + Value::Integer(5), + ], + ); + // Alice orders 3 Gadgets + orders_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Integer(1), + Value::Integer(101), + Value::Integer(3), + ], + ); + // Bob orders 7 Widgets + orders_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Integer(2), + Value::Integer(100), + Value::Integer(7), + ], + ); + // Bob orders 2 Doohickeys + orders_delta.insert( + 4, + vec![ + Value::Integer(4), 
+                Value::Integer(2),
+                Value::Integer(102),
+                Value::Integer(2),
+            ],
+        );
+        // Alice orders 4 more Widgets
+        orders_delta.insert(
+            5,
+            vec![
+                Value::Integer(5),
+                Value::Integer(1),
+                Value::Integer(100),
+                Value::Integer(4),
+            ],
+        );
+
+        let mut inputs = HashMap::new();
+        inputs.insert("users".to_string(), users_delta);
+        inputs.insert("products".to_string(), products_delta);
+        inputs.insert("orders".to_string(), orders_delta);
+
+        // Execute the 3-way join with aggregation
+        let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap();
+
+        // We should get aggregated results for each user-product combination
+        // Expected results:
+        // - Alice, Widget: 9 (5 + 4)
+        // - Alice, Gadget: 3
+        // - Bob, Widget: 7
+        // - Bob, Doohickey: 2
+        assert_eq!(result.len(), 4, "Should have 4 aggregated results");
+
+        // Verify aggregation results
+        let mut found_results = std::collections::HashSet::new();
+        for (row, weight) in result.changes.iter() {
+            assert_eq!(*weight, 1);
+            // Row should have name, product_name, and sum columns
+            assert_eq!(row.values.len(), 3);
+
+            if let (Value::Text(name), Value::Text(product), Value::Integer(total)) =
+                (&row.values[0], &row.values[1], &row.values[2])
+            {
+                let key = format!("{}-{}", name.as_ref(), product.as_ref());
+                found_results.insert(key.clone());
+
+                match key.as_str() {
+                    "Alice-Widget" => {
+                        assert_eq!(*total, 9, "Alice should have ordered 9 Widgets total")
+                    }
+                    "Alice-Gadget" => assert_eq!(*total, 3, "Alice should have ordered 3 Gadgets"),
+                    "Bob-Widget" => assert_eq!(*total, 7, "Bob should have ordered 7 Widgets"),
+                    "Bob-Doohickey" => {
+                        assert_eq!(*total, 2, "Bob should have ordered 2 Doohickeys")
+                    }
+                    _ => panic!("Unexpected result: {key}"),
+                }
+            } else {
+                panic!("Unexpected value types in result");
+            }
+        }
+
+        // Ensure we found all expected combinations
+        assert!(found_results.contains("Alice-Widget"));
+        assert!(found_results.contains("Alice-Gadget"));
+        assert!(found_results.contains("Bob-Widget"));
+        assert!(found_results.contains("Bob-Doohickey"));
+    }
+
+    #[test]
+    fn test_join_execution() {
+        let (mut circuit, pager) = compile_sql!(
+            "SELECT u.name, o.quantity FROM users u JOIN orders o ON u.id = o.user_id"
+        );
+
+        // Create test data for users
+        let mut users_delta = Delta::new();
+        users_delta.insert(
+            1,
+            vec![
+                Value::Integer(1),
+                Value::Text("Alice".into()),
+                Value::Integer(25),
+            ],
+        );
+        users_delta.insert(
+            2,
+            vec![
+                Value::Integer(2),
+                Value::Text("Bob".into()),
+                Value::Integer(30),
+            ],
+        );
+
+        // Create test data for orders
+        let mut orders_delta = Delta::new();
+        orders_delta.insert(
+            1,
+            vec![
+                Value::Integer(1),
+                Value::Integer(1),
+                Value::Integer(100),
+                Value::Integer(5),
+            ],
+        );
+        orders_delta.insert(
+            2,
+            vec![
+                Value::Integer(2),
+                Value::Integer(1),
+                Value::Integer(101),
+                Value::Integer(3),
+            ],
+        );
+        orders_delta.insert(
+            3,
+            vec![
+                Value::Integer(3),
+                Value::Integer(2),
+                Value::Integer(102),
+                Value::Integer(7),
+            ],
+        );
+
+        let mut inputs = HashMap::new();
+        inputs.insert("users".to_string(), users_delta);
+        inputs.insert("orders".to_string(), orders_delta);
+
+        // Execute the join
+        let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap();
+
+        // We should get 3 results (2 orders for Alice, 1 for Bob)
+        assert_eq!(result.len(), 3, "Should have 3 join results");
+
+        // Verify the join results contain the correct data
+        let results: Vec<_> = result.changes.iter().collect();
+
+        // Check that we have the expected joined rows
+        for (row, weight) in results {
+            assert_eq!(*weight, 1); // All weights should be 1 for insertions
+            // Row should have name and quantity columns
+            assert_eq!(row.values.len(), 2);
+        }
+    }
+
+    #[test]
+    fn test_three_way_join_with_column_ambiguity() {
+        // Test three-way join with aggregation where multiple tables have columns with the same name
+        // Ensures that column references are correctly resolved to their respective tables
+        // Tables: customers(id, name), purchases(id, customer_id, vendor_id, quantity), vendors(id, name, price)
+        // Note: both customers and vendors have 'id' and 'name' columns which can cause ambiguity
+
+        let sql = "SELECT c.name as customer_name, v.name as vendor_name,
+                          SUM(p.quantity) as total_quantity,
+                          SUM(p.quantity * v.price) as total_value
+                   FROM customers c
+                   JOIN purchases p ON c.id = p.customer_id
+                   JOIN vendors v ON p.vendor_id = v.id
+                   GROUP BY c.name, v.name";
+
+        let (mut circuit, pager) = compile_sql!(sql);
+
+        // Create test data for customers (id, name)
+        let mut customers_delta = Delta::new();
+        customers_delta.insert(1, vec![Value::Integer(1), Value::Text("Alice".into())]);
+        customers_delta.insert(2, vec![Value::Integer(2), Value::Text("Bob".into())]);
+
+        // Create test data for vendors (id, name, price)
+        let mut vendors_delta = Delta::new();
+        vendors_delta.insert(
+            1,
+            vec![
+                Value::Integer(1),
+                Value::Text("Widget Co".into()),
+                Value::Integer(10),
+            ],
+        );
+        vendors_delta.insert(
+            2,
+            vec![
+                Value::Integer(2),
+                Value::Text("Gadget Inc".into()),
+                Value::Integer(20),
+            ],
+        );
+
+        // Create test data for purchases (id, customer_id, vendor_id, quantity)
+        let mut purchases_delta = Delta::new();
+        // Alice purchases 5 units from Widget Co
+        purchases_delta.insert(
+            1,
+            vec![
+                Value::Integer(1),
+                Value::Integer(1), // customer_id: Alice
+                Value::Integer(1), // vendor_id: Widget Co
+                Value::Integer(5),
+            ],
+        );
+        // Alice purchases 3 units from Gadget Inc
+        purchases_delta.insert(
+            2,
+            vec![
+                Value::Integer(2),
+                Value::Integer(1), // customer_id: Alice
+                Value::Integer(2), // vendor_id: Gadget Inc
+                Value::Integer(3),
+            ],
+        );
+        // Bob purchases 2 units from Widget Co
+        purchases_delta.insert(
+            3,
+            vec![
+                Value::Integer(3),
+                Value::Integer(2), // customer_id: Bob
+                Value::Integer(1), // vendor_id: Widget Co
+                Value::Integer(2),
+            ],
+        );
+        // Alice purchases 4 more units from Widget Co
+        purchases_delta.insert(
+            4,
+            vec![
+                Value::Integer(4),
+                Value::Integer(1), // customer_id: Alice
+                Value::Integer(1), // vendor_id: Widget Co
+                Value::Integer(4),
+            ],
+        );
+
+        let inputs = HashMap::from([
+            ("customers".to_string(), customers_delta),
+            ("purchases".to_string(), purchases_delta),
+            ("vendors".to_string(), vendors_delta),
+        ]);
+
+        let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap();
+
+        // Expected results:
+        // Alice|Gadget Inc|3|60  (3 units * 20 price = 60)
+        // Alice|Widget Co|9|90   (9 units * 10 price = 90)
+        // Bob|Widget Co|2|20     (2 units * 10 price = 20)
+
+        assert_eq!(result.len(), 3, "Should have 3 aggregated results");
+
+        // Sort results for consistent testing
+        let mut results: Vec<_> = result.changes.into_iter().collect();
+        results.sort_by(|a, b| {
+            let a_cust = &a.0.values[0];
+            let a_vend = &a.0.values[1];
+            let b_cust = &b.0.values[0];
+            let b_vend = &b.0.values[1];
+            (a_cust, a_vend).cmp(&(b_cust, b_vend))
+        });
+
+        // Verify Alice's Gadget Inc purchases
+        assert_eq!(results[0].0.values[0], Value::Text("Alice".into()));
+        assert_eq!(results[0].0.values[1],
Value::Text("Gadget Inc".into())); + assert_eq!(results[0].0.values[2], Value::Integer(3)); // total_quantity + assert_eq!(results[0].0.values[3], Value::Integer(60)); // total_value + + // Verify Alice's Widget Co purchases + assert_eq!(results[1].0.values[0], Value::Text("Alice".into())); + assert_eq!(results[1].0.values[1], Value::Text("Widget Co".into())); + assert_eq!(results[1].0.values[2], Value::Integer(9)); // total_quantity + assert_eq!(results[1].0.values[3], Value::Integer(90)); // total_value + + // Verify Bob's Widget Co purchases + assert_eq!(results[2].0.values[0], Value::Text("Bob".into())); + assert_eq!(results[2].0.values[1], Value::Text("Widget Co".into())); + assert_eq!(results[2].0.values[2], Value::Integer(2)); // total_quantity + assert_eq!(results[2].0.values[3], Value::Integer(20)); // total_value + } + + #[test] + fn test_join_with_aggregate_execution() { + let (mut circuit, pager) = compile_sql!( + "SELECT u.name, SUM(o.quantity) as total_quantity + FROM users u + JOIN orders o ON u.id = o.user_id + GROUP BY u.name" + ); + + // Create test data for users + let mut users_delta = Delta::new(); + users_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".into()), + Value::Integer(25), + ], + ); + users_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Bob".into()), + Value::Integer(30), + ], + ); + + // Create test data for orders + let mut orders_delta = Delta::new(); + orders_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Integer(1), + Value::Integer(100), + Value::Integer(5), + ], + ); + orders_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Integer(1), + Value::Integer(101), + Value::Integer(3), + ], + ); + orders_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Integer(2), + Value::Integer(102), + Value::Integer(7), + ], + ); + + let mut inputs = HashMap::new(); + inputs.insert("users".to_string(), users_delta); + inputs.insert("orders".to_string(), orders_delta); + + // Execute the join with aggregation + let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap(); + + // We should get 2 aggregated results (one for Alice, one for Bob) + assert_eq!(result.len(), 2, "Should have 2 aggregated results"); + + // Verify aggregation results + for (row, weight) in result.changes.iter() { + assert_eq!(*weight, 1); + // Row should have name and sum columns + assert_eq!(row.values.len(), 2); + + // Check the aggregated values + if let Value::Text(name) = &row.values[0] { + if name.as_ref() == "Alice" { + // Alice should have total quantity of 8 (5 + 3) + assert_eq!(row.values[1], Value::Integer(8)); + } else if name.as_ref() == "Bob" { + // Bob should have total quantity of 7 + assert_eq!(row.values[1], Value::Integer(7)); + } + } + } + } + + #[test] + fn test_filter_with_qualified_columns_in_join() { + // Test that filters correctly handle qualified column names in joins + // when multiple tables have columns with the SAME names. + // Both users and sales tables have an 'id' column which can be ambiguous. 
+
+    #[test]
+    fn test_filter_with_qualified_columns_in_join() {
+        // Test that filters correctly handle qualified column names in joins
+        // when multiple tables have columns with the SAME names.
+        // Both users and sales tables have an 'id' column, which can be ambiguous.
+
+        let (mut circuit, pager) = compile_sql!(
+            "SELECT users.id, users.name, sales.id, sales.amount
+             FROM users
+             JOIN sales ON users.id = sales.customer_id
+             WHERE users.id > 1 AND sales.id < 100"
+        );
+
+        // Create test data
+        let mut users_delta = Delta::new();
+        let mut sales_delta = Delta::new();
+
+        // Users data: (id, name, age)
+        users_delta.insert(
+            1,
+            vec![
+                Value::Integer(1),
+                Value::Text("Alice".into()),
+                Value::Integer(30),
+            ],
+        ); // id = 1
+        users_delta.insert(
+            2,
+            vec![
+                Value::Integer(2),
+                Value::Text("Bob".into()),
+                Value::Integer(25),
+            ],
+        ); // id = 2
+        users_delta.insert(
+            3,
+            vec![
+                Value::Integer(3),
+                Value::Text("Charlie".into()),
+                Value::Integer(35),
+            ],
+        ); // id = 3
+
+        // Sales data: (id, customer_id, amount)
+        sales_delta.insert(
+            50,
+            vec![Value::Integer(50), Value::Integer(1), Value::Integer(100)],
+        ); // sales.id = 50, customer_id = 1
+        sales_delta.insert(
+            99,
+            vec![Value::Integer(99), Value::Integer(2), Value::Integer(200)],
+        ); // sales.id = 99, customer_id = 2
+        sales_delta.insert(
+            150,
+            vec![Value::Integer(150), Value::Integer(3), Value::Integer(300)],
+        ); // sales.id = 150, customer_id = 3
+
+        let mut inputs = HashMap::new();
+        inputs.insert("users".to_string(), users_delta);
+        inputs.insert("sales".to_string(), sales_delta);
+
+        let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap();
+
+        // Should only get the row for Bob (users.id=2, sales.id=99):
+        // - users.id=2 (> 1) AND sales.id=99 (< 100) ✓
+        // Alice excluded: users.id=1 (NOT > 1)
+        // Charlie excluded: sales.id=150 (NOT < 100)
+        assert_eq!(result.len(), 1, "Should have 1 filtered result");
+
+        let (row, weight) = &result.changes[0];
+        assert_eq!(*weight, 1);
+        assert_eq!(row.values.len(), 4, "Should have 4 columns");
+
+        // Verify the filter correctly used qualified columns
+        assert_eq!(row.values[0], Value::Integer(2), "users.id should be 2");
+        assert_eq!(
+            row.values[1],
+            Value::Text("Bob".into()),
+            "users.name should be Bob"
+        );
+        assert_eq!(row.values[2], Value::Integer(99), "sales.id should be 99");
+        assert_eq!(
+            row.values[3],
+            Value::Integer(200),
+            "sales.amount should be 200"
+        );
+    }
 }
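Aside: these filter tests exercise same-named columns because, after the changes to core/translate/logical.rs further below, qualified references are resolved to positions at plan-build time, so runtime operators only ever see indices. A toy sketch of that resolution (the resolve helper and tuple schema are illustrative, not the patch's types):

    // Each column carries its originating table; lookups may be qualified.
    fn resolve(schema: &[(&str, &str)], qualifier: Option<&str>, name: &str) -> Option<usize> {
        schema
            .iter()
            .position(|(table, col)| *col == name && qualifier.map_or(true, |q| *table == q))
    }

    fn main() {
        // users JOIN sales: both sides contribute an "id" column.
        let schema = [
            ("users", "id"),
            ("users", "name"),
            ("sales", "id"),
            ("sales", "amount"),
        ];
        assert_eq!(resolve(&schema, Some("users"), "id"), Some(0));
        assert_eq!(resolve(&schema, Some("sales"), "id"), Some(2));
        // Unqualified "id" silently picks the first match; this is the
        // ambiguity the qualified lookups guard against.
        assert_eq!(resolve(&schema, None, "id"), Some(0));
    }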
"age".to_string()], ); @@ -514,9 +514,9 @@ mod tests { let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( - 1, // operator_id for testing - vec!["team".to_string()], // GROUP BY team - vec![AggregateFunction::Sum("score".to_string())], + 1, // operator_id for testing + vec![1], // GROUP BY team (index 1) + vec![AggregateFunction::Sum(3)], // score is at index 3 vec![ "id".to_string(), "team".to_string(), @@ -666,8 +666,8 @@ mod tests { // Create COUNT(*) GROUP BY category let mut agg = AggregateOperator::new( - 1, // operator_id for testing - vec!["category".to_string()], + 1, // operator_id for testing + vec![1], // category is at index 1 vec![AggregateFunction::Count], vec![ "item_id".to_string(), @@ -746,9 +746,9 @@ mod tests { let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( - 1, // operator_id for testing - vec!["product".to_string()], - vec![AggregateFunction::Sum("amount".to_string())], + 1, // operator_id for testing + vec![1], // product is at index 1 + vec![AggregateFunction::Sum(2)], // amount is at index 2 vec![ "sale_id".to_string(), "product".to_string(), @@ -843,11 +843,11 @@ mod tests { let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( - 1, // operator_id for testing - vec!["user_id".to_string()], + 1, // operator_id for testing + vec![1], // user_id is at index 1 vec![ AggregateFunction::Count, - AggregateFunction::Sum("amount".to_string()), + AggregateFunction::Sum(2), // amount is at index 2 ], vec![ "order_id".to_string(), @@ -935,9 +935,9 @@ mod tests { let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( - 1, // operator_id for testing - vec!["category".to_string()], - vec![AggregateFunction::Avg("value".to_string())], + 1, // operator_id for testing + vec![1], // category is at index 1 + vec![AggregateFunction::Avg(2)], // value is at index 2 vec![ "id".to_string(), "category".to_string(), @@ -1035,11 +1035,11 @@ mod tests { let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( - 1, // operator_id for testing - vec!["category".to_string()], + 1, // operator_id for testing + vec![1], // category is at index 1 vec![ AggregateFunction::Count, - AggregateFunction::Sum("value".to_string()), + AggregateFunction::Sum(2), // value is at index 2 ], vec![ "id".to_string(), @@ -1108,7 +1108,7 @@ mod tests { #[test] fn test_count_aggregation_with_deletions() { let aggregates = vec![AggregateFunction::Count]; - let group_by = vec!["category".to_string()]; + let group_by = vec![0]; // category is at index 0 let input_columns = vec!["category".to_string(), "value".to_string()]; // Create a persistent pager for the test @@ -1197,8 +1197,8 @@ mod tests { #[test] fn test_sum_aggregation_with_deletions() { - let aggregates = vec![AggregateFunction::Sum("value".to_string())]; - let group_by = vec!["category".to_string()]; + let aggregates = vec![AggregateFunction::Sum(1)]; // value is at index 1 + let group_by = vec![0]; // category is at index 0 let input_columns = vec!["category".to_string(), "value".to_string()]; // Create a persistent pager for the test @@ -1281,8 +1281,8 @@ mod tests { #[test] fn test_avg_aggregation_with_deletions() { - let aggregates = vec![AggregateFunction::Avg("value".to_string())]; - let group_by = vec!["category".to_string()]; + let aggregates = vec![AggregateFunction::Avg(1)]; // 
         let input_columns = vec!["category".to_string(), "value".to_string()];
 
         // Create a persistent pager for the test
@@ -1348,10 +1348,10 @@ mod tests {
         // Test COUNT, SUM, and AVG together
         let aggregates = vec![
             AggregateFunction::Count,
-            AggregateFunction::Sum("value".to_string()),
-            AggregateFunction::Avg("value".to_string()),
+            AggregateFunction::Sum(1), // value is at index 1
+            AggregateFunction::Avg(1), // value is at index 1
         ];
-        let group_by = vec!["category".to_string()];
+        let group_by = vec![0]; // category is at index 0
         let input_columns = vec!["category".to_string(), "value".to_string()];
 
         // Create a persistent pager for the test
@@ -1607,11 +1607,11 @@ mod tests {
         let mut cursors = DbspStateCursors::new(table_cursor, index_cursor);
 
         let mut agg = AggregateOperator::new(
-            1, // operator_id for testing
-            vec!["category".to_string()],
+            1,       // operator_id for testing
+            vec![1], // category is at index 1
             vec![
                 AggregateFunction::Count,
-                AggregateFunction::Sum("amount".to_string()),
+                AggregateFunction::Sum(2), // amount is at index 2
             ],
             vec![
                 "id".to_string(),
@@ -1781,7 +1781,7 @@ mod tests {
             vec![], // No GROUP BY
             vec![
                 AggregateFunction::Count,
-                AggregateFunction::Sum("value".to_string()),
+                AggregateFunction::Sum(1), // value is at index 1
             ],
             vec!["id".to_string(), "value".to_string()],
         );
@@ -1859,8 +1859,8 @@ mod tests {
         let mut cursors = DbspStateCursors::new(table_cursor, index_cursor);
 
         let mut agg = AggregateOperator::new(
-            1, // operator_id for testing
-            vec!["type".to_string()],
+            1,       // operator_id for testing
+            vec![1], // type is at index 1
             vec![AggregateFunction::Count],
             vec!["id".to_string(), "type".to_string()],
         );
@@ -1976,8 +1976,8 @@ mod tests {
             1,      // operator_id
             vec![], // No GROUP BY
             vec![
-                AggregateFunction::Min("price".to_string()),
-                AggregateFunction::Max("price".to_string()),
+                AggregateFunction::Min(2), // price is at index 2
+                AggregateFunction::Max(2), // price is at index 2
             ],
             vec!["id".to_string(), "name".to_string(), "price".to_string()],
         );
@@ -2044,8 +2044,8 @@ mod tests {
             1,      // operator_id
             vec![], // No GROUP BY
             vec![
-                AggregateFunction::Min("price".to_string()),
-                AggregateFunction::Max("price".to_string()),
+                AggregateFunction::Min(2), // price is at index 2
+                AggregateFunction::Max(2), // price is at index 2
             ],
             vec!["id".to_string(), "name".to_string(), "price".to_string()],
         );
@@ -2134,8 +2134,8 @@ mod tests {
             1,      // operator_id
             vec![], // No GROUP BY
             vec![
-                AggregateFunction::Min("price".to_string()),
-                AggregateFunction::Max("price".to_string()),
+                AggregateFunction::Min(2), // price is at index 2
+                AggregateFunction::Max(2), // price is at index 2
             ],
             vec!["id".to_string(), "name".to_string(), "price".to_string()],
         );
@@ -2224,8 +2224,8 @@ mod tests {
             1,      // operator_id
             vec![], // No GROUP BY
             vec![
-                AggregateFunction::Min("price".to_string()),
-                AggregateFunction::Max("price".to_string()),
+                AggregateFunction::Min(2), // price is at index 2
+                AggregateFunction::Max(2), // price is at index 2
             ],
             vec!["id".to_string(), "name".to_string(), "price".to_string()],
         );
@@ -2306,8 +2306,8 @@ mod tests {
             1,      // operator_id
             vec![], // No GROUP BY
             vec![
-                AggregateFunction::Min("price".to_string()),
-                AggregateFunction::Max("price".to_string()),
+                AggregateFunction::Min(2), // price is at index 2
+                AggregateFunction::Max(2), // price is at index 2
             ],
             vec!["id".to_string(), "name".to_string(), "price".to_string()],
         );
@@ -2388,8 +2388,8 @@ mod tests {
             1,      // operator_id
             vec![], // No GROUP BY
             vec![
-                AggregateFunction::Min("price".to_string()),
-                AggregateFunction::Max("price".to_string()),
+                AggregateFunction::Min(2), // price is at index 2
+                AggregateFunction::Max(2), // price is at index 2
             ],
             vec!["id".to_string(), "name".to_string(), "price".to_string()],
         );
@@ -2475,11 +2475,11 @@ mod tests {
         let mut cursors = DbspStateCursors::new(table_cursor, index_cursor);
 
         let mut agg = AggregateOperator::new(
-            1,                            // operator_id
-            vec!["category".to_string()], // GROUP BY category
+            1,       // operator_id
+            vec![1], // GROUP BY category (index 1)
             vec![
-                AggregateFunction::Min("price".to_string()),
-                AggregateFunction::Max("price".to_string()),
+                AggregateFunction::Min(3), // price is at index 3
+                AggregateFunction::Max(3), // price is at index 3
             ],
             vec![
                 "id".to_string(),
@@ -2580,8 +2580,8 @@ mod tests {
             1,      // operator_id
             vec![], // No GROUP BY
             vec![
-                AggregateFunction::Min("price".to_string()),
-                AggregateFunction::Max("price".to_string()),
+                AggregateFunction::Min(2), // price is at index 2
+                AggregateFunction::Max(2), // price is at index 2
             ],
             vec!["id".to_string(), "name".to_string(), "price".to_string()],
         );
@@ -2656,8 +2656,8 @@ mod tests {
             1,      // operator_id
             vec![], // No GROUP BY
             vec![
-                AggregateFunction::Min("score".to_string()),
-                AggregateFunction::Max("score".to_string()),
+                AggregateFunction::Min(2), // score is at index 2
+                AggregateFunction::Max(2), // score is at index 2
             ],
             vec!["id".to_string(), "name".to_string(), "score".to_string()],
        );
@@ -2724,8 +2724,8 @@ mod tests {
             1,      // operator_id
             vec![], // No GROUP BY
             vec![
-                AggregateFunction::Min("name".to_string()),
-                AggregateFunction::Max("name".to_string()),
+                AggregateFunction::Min(1), // name is at index 1
+                AggregateFunction::Max(1), // name is at index 1
             ],
             vec!["id".to_string(), "name".to_string()],
         );
@@ -2764,10 +2764,10 @@ mod tests {
             vec![], // No GROUP BY
             vec![
                 AggregateFunction::Count,
-                AggregateFunction::Sum("value".to_string()),
-                AggregateFunction::Min("value".to_string()),
-                AggregateFunction::Max("value".to_string()),
-                AggregateFunction::Avg("value".to_string()),
+                AggregateFunction::Sum(1), // value is at index 1
+                AggregateFunction::Min(1), // value is at index 1
+                AggregateFunction::Max(1), // value is at index 1
+                AggregateFunction::Avg(1), // value is at index 1
             ],
             vec!["id".to_string(), "value".to_string()],
         );
@@ -2855,9 +2855,9 @@ mod tests {
             1,      // operator_id
             vec![], // No GROUP BY
             vec![
-                AggregateFunction::Min("col1".to_string()),
-                AggregateFunction::Max("col2".to_string()),
-                AggregateFunction::Min("col3".to_string()),
+                AggregateFunction::Min(0), // col1 is at index 0
+                AggregateFunction::Max(1), // col2 is at index 1
+                AggregateFunction::Min(2), // col3 is at index 2
             ],
             vec!["col1".to_string(), "col2".to_string(), "col3".to_string()],
         );
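Aside: the MIN/MAX tests above construct operators from a column index alone because the value tracking lives in ordered persistent state. Unlike SUM, a MIN or MAX cannot be repaired from a single scalar when the current extremum is deleted; a common approach, sketched here with illustrative names (WeightedMinMax) rather than the patch's actual structures, is a weighted multiset over an ordered map, where one structure answers both MIN (first key) and MAX (last key):

    use std::collections::BTreeMap;

    // value -> accumulated weight; entries whose weight reaches 0 disappear.
    struct WeightedMinMax {
        values: BTreeMap<i64, i64>,
    }

    impl WeightedMinMax {
        fn apply(&mut self, value: i64, weight: i64) {
            let w = self.values.entry(value).or_insert(0);
            *w += weight;
            if *w == 0 {
                self.values.remove(&value);
            }
        }
        fn min(&self) -> Option<i64> {
            self.values.keys().next().copied()
        }
        fn max(&self) -> Option<i64> {
            self.values.keys().next_back().copied()
        }
    }

    fn main() {
        let mut m = WeightedMinMax { values: BTreeMap::new() };
        m.apply(10, 1);
        m.apply(5, 1);
        m.apply(5, -1); // deleting the current minimum falls back to 10
        assert_eq!((m.min(), m.max()), (Some(10), Some(10)));
    }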
diff --git a/core/translate/logical.rs b/core/translate/logical.rs
index 6a8b0a6c2..b11e2df4f 100644
--- a/core/translate/logical.rs
+++ b/core/translate/logical.rs
@@ -19,26 +19,35 @@ use turso_parser::ast;
 
 /// Result type for preprocessing aggregate expressions
 type PreprocessAggregateResult = (
-    bool,                // needs_pre_projection
-    Vec<LogicalExpr>,    // pre_projection_exprs
-    Vec<(String, Type)>, // pre_projection_schema
-    Vec<LogicalExpr>,    // modified_aggr_exprs
+    bool,             // needs_pre_projection
+    Vec<LogicalExpr>, // pre_projection_exprs
+    Vec<ColumnInfo>,  // pre_projection_schema
+    Vec<LogicalExpr>, // modified_aggr_exprs
 );
 
 /// Result type for parsing join conditions
 type JoinConditionsResult = (Vec<(LogicalExpr, LogicalExpr)>, Option<LogicalExpr>);
 
+/// Information about a column in a logical schema
+#[derive(Debug, Clone, PartialEq)]
+pub struct ColumnInfo {
+    pub name: String,
+    pub ty: Type,
+    pub database: Option<String>,
+    pub table: Option<String>,
+    pub table_alias: Option<String>,
+}
+
 /// Schema information for logical plan nodes
 #[derive(Debug, Clone, PartialEq)]
 pub struct LogicalSchema {
-    /// Column names and types
-    pub columns: Vec<(String, Type)>,
+    pub columns: Vec<ColumnInfo>,
 }
 
 /// A reference to a schema that can be shared between nodes
 pub type SchemaRef = Arc<LogicalSchema>;
 
 impl LogicalSchema {
-    pub fn new(columns: Vec<(String, Type)>) -> Self {
+    pub fn new(columns: Vec<ColumnInfo>) -> Self {
         Self { columns }
     }
 
@@ -52,11 +61,42 @@ impl LogicalSchema {
         self.columns.len()
     }
 
-    pub fn find_column(&self, name: &str) -> Option<(usize, &Type)> {
-        self.columns
-            .iter()
-            .position(|(n, _)| n == name)
-            .map(|idx| (idx, &self.columns[idx].1))
+    pub fn find_column(&self, name: &str, table: Option<&str>) -> Option<(usize, &ColumnInfo)> {
+        if let Some(table_ref) = table {
+            // Check if the qualifier is in database.table format
+            if table_ref.contains('.') {
+                let parts: Vec<&str> = table_ref.splitn(2, '.').collect();
+                if parts.len() == 2 {
+                    let db = parts[0];
+                    let tbl = parts[1];
+                    return self
+                        .columns
+                        .iter()
+                        .position(|c| {
+                            c.name == name
+                                && c.database.as_deref() == Some(db)
+                                && c.table.as_deref() == Some(tbl)
+                        })
+                        .map(|idx| (idx, &self.columns[idx]));
+                }
+            }
+
+            // Match the qualifier against either the table alias or the table name
+            self.columns
+                .iter()
+                .position(|c| {
+                    c.name == name
+                        && (c.table_alias.as_deref() == Some(table_ref)
+                            || c.table.as_deref() == Some(table_ref))
+                })
+                .map(|idx| (idx, &self.columns[idx]))
+        } else {
+            // Unqualified lookup - just match by name
+            self.columns
+                .iter()
+                .position(|c| c.name == name)
+                .map(|idx| (idx, &self.columns[idx]))
+        }
     }
 }
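Aside: a usage sketch of the new lookup on a joined schema where both sides expose an "id" column. The schema is built by hand purely for illustration, and Type::Integer is assumed to be a variant of the Type in scope here:

    let schema = LogicalSchema::new(vec![
        ColumnInfo {
            name: "id".to_string(),
            ty: Type::Integer,
            database: None,
            table: Some("users".to_string()),
            table_alias: Some("u".to_string()),
        },
        ColumnInfo {
            name: "id".to_string(),
            ty: Type::Integer,
            database: None,
            table: Some("sales".to_string()),
            table_alias: None,
        },
    ]);
    assert_eq!(schema.find_column("id", Some("sales")).map(|(i, _)| i), Some(1));
    assert_eq!(schema.find_column("id", Some("u")).map(|(i, _)| i), Some(0)); // alias matches too
    assert_eq!(schema.find_column("id", None).map(|(i, _)| i), Some(0)); // unqualified: first match wins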
@@ -548,14 +588,14 @@ impl<'a> LogicalPlanBuilder<'a> {
         }
 
         // Regular table scan
-        let table_schema = self.get_table_schema(&table_name)?;
         let table_alias = alias.as_ref().map(|a| match a {
             ast::As::As(name) => Self::name_to_string(name),
             ast::As::Elided(name) => Self::name_to_string(name),
         });
+        let table_schema = self.get_table_schema(&table_name, table_alias.as_deref())?;
 
         Ok(LogicalPlan::TableScan(TableScan {
             table_name,
-            alias: table_alias,
+            alias: table_alias.clone(),
             schema: table_schema,
             projection: None,
         }))
@@ -751,14 +791,14 @@ impl<'a> LogicalPlanBuilder<'a> {
                 let _left_idx = left_schema
                     .columns
                     .iter()
-                    .position(|(n, _)| n == &name)
+                    .position(|col| col.name == name)
                     .ok_or_else(|| {
                         LimboError::ParseError(format!("Column {name} not found in left table"))
                     })?;
                 let _right_idx = right_schema
                     .columns
                     .iter()
-                    .position(|(n, _)| n == &name)
+                    .position(|col| col.name == name)
                     .ok_or_else(|| {
                         LimboError::ParseError(format!("Column {name} not found in right table"))
                     })?;
@@ -790,9 +830,13 @@ impl<'a> LogicalPlanBuilder<'a> {
         // Find common column names
         let mut common_columns = Vec::new();
-        for (left_name, _) in &left_schema.columns {
-            if right_schema.columns.iter().any(|(n, _)| n == left_name) {
-                common_columns.push(ast::Name::Ident(left_name.clone()));
+        for left_col in &left_schema.columns {
+            if right_schema
+                .columns
+                .iter()
+                .any(|col| col.name == left_col.name)
+            {
+                common_columns.push(ast::Name::Ident(left_col.name.clone()));
             }
         }
 
@@ -833,10 +877,18 @@ impl<'a> LogicalPlanBuilder<'a> {
         let left_schema = left.schema();
         let right_schema = right.schema();
 
-        // For now, simply concatenate the schemas
-        // In a real implementation, we'd handle column name conflicts and nullable columns
-        let mut columns = left_schema.columns.clone();
-        columns.extend(right_schema.columns.clone());
+        // Concatenate the schemas, preserving all column information
+        let mut columns = Vec::new();
+
+        // Keep all columns from left with their table info
+        for col in &left_schema.columns {
+            columns.push(col.clone());
+        }
+
+        // Keep all columns from right with their table info
+        for col in &right_schema.columns {
+            columns.push(col.clone());
+        }
 
         Ok(Arc::new(LogicalSchema::new(columns)))
     }
@@ -870,7 +922,13 @@ impl<'a> LogicalPlanBuilder<'a> {
                     };
 
                     let col_type = Self::infer_expr_type(&logical_expr, input_schema)?;
-                    schema_columns.push((col_name.clone(), col_type));
+                    schema_columns.push(ColumnInfo {
+                        name: col_name.clone(),
+                        ty: col_type,
+                        database: None,
+                        table: None,
+                        table_alias: None,
+                    });
 
                     if let Some(as_alias) = alias {
                         let alias_name = match as_alias {
@@ -886,21 +944,21 @@ impl<'a> LogicalPlanBuilder<'a> {
                 }
                 ast::ResultColumn::Star => {
                     // Expand * to all columns
-                    for (name, typ) in &input_schema.columns {
-                        proj_exprs.push(LogicalExpr::Column(Column::new(name.clone())));
-                        schema_columns.push((name.clone(), *typ));
+                    for col in &input_schema.columns {
+                        proj_exprs.push(LogicalExpr::Column(Column::new(col.name.clone())));
+                        schema_columns.push(col.clone());
                     }
                 }
                 ast::ResultColumn::TableStar(table) => {
                     // Expand table.* to all columns from that table
                     let table_name = Self::name_to_string(table);
-                    for (name, typ) in &input_schema.columns {
+                    for col in &input_schema.columns {
                         // Simple check - would need proper table tracking in real implementation
                         proj_exprs.push(LogicalExpr::Column(Column::with_table(
-                            name.clone(),
+                            col.name.clone(),
                             table_name.clone(),
                         )));
-                        schema_columns.push((name.clone(), *typ));
+                        schema_columns.push(col.clone());
                     }
                 }
             }
@@ -938,7 +996,13 @@ impl<'a> LogicalPlanBuilder<'a> {
             if let LogicalExpr::Column(col) = expr {
                 pre_projection_exprs.push(expr.clone());
                 let col_type = Self::infer_expr_type(expr, input_schema)?;
-                pre_projection_schema.push((col.name.clone(), col_type));
+                pre_projection_schema.push(ColumnInfo {
+                    name: col.name.clone(),
+                    ty: col_type,
+                    database: None,
+                    table: col.table.clone(),
+                    table_alias: None,
+                });
             } else {
                 // Complex group by expression - project it
                 needs_pre_projection = true;
@@ -946,7 +1010,13 @@ impl<'a> LogicalPlanBuilder<'a> {
                 projected_col_counter += 1;
                 pre_projection_exprs.push(expr.clone());
                 let col_type = Self::infer_expr_type(expr, input_schema)?;
-                pre_projection_schema.push((proj_col_name.clone(), col_type));
+                pre_projection_schema.push(ColumnInfo {
+                    name: proj_col_name.clone(),
+                    ty: col_type,
+                    database: None,
+                    table: None,
+                    table_alias: None,
+                });
             }
         }
 
@@ -970,7 +1040,13 @@ impl<'a> LogicalPlanBuilder<'a> {
                     pre_projection_exprs.push(arg.clone());
                     let col_type = Self::infer_expr_type(arg, input_schema)?;
                     if let LogicalExpr::Column(col) = arg {
-                        pre_projection_schema.push((col.name.clone(), col_type));
+                        pre_projection_schema.push(ColumnInfo {
+                            name: col.name.clone(),
+                            ty: col_type,
+                            database: None,
+                            table: col.table.clone(),
+                            table_alias: None,
+                        });
                     }
                 }
             }
@@ -983,7 +1059,13 @@ impl<'a> LogicalPlanBuilder<'a> {
                 // Add the expression to the pre-projection
                 pre_projection_exprs.push(arg.clone());
                 let col_type = Self::infer_expr_type(arg, input_schema)?;
-                pre_projection_schema.push((proj_col_name.clone(), col_type));
+                pre_projection_schema.push(ColumnInfo {
+                    name: proj_col_name.clone(),
+                    ty: col_type,
+                    database: None,
+                    table: None,
+                    table_alias: None,
+                });
 
                 // In the aggregate, reference the projected column
                modified_args.push(LogicalExpr::Column(Column::new(proj_col_name)));
@@ -1057,15 +1139,39 @@ impl<'a> LogicalPlanBuilder<'a> {
         // First, add GROUP BY columns to the aggregate output schema
         // These are always part of the aggregate operator's output
         for group_expr in &group_exprs {
-            let col_name = match group_expr {
-                LogicalExpr::Column(col) => col.name.clone(),
+            match group_expr {
+                LogicalExpr::Column(col) => {
+                    // For column references in GROUP BY, preserve the original column info
+                    if let Some((_, col_info)) =
+                        input_schema.find_column(&col.name, col.table.as_deref())
+                    {
+                        // Preserve the column with all its table information
+                        aggregate_schema_columns.push(col_info.clone());
+                    } else {
+                        // Fallback if column not found (shouldn't happen)
+                        let col_type = Self::infer_expr_type(group_expr, input_schema)?;
+                        aggregate_schema_columns.push(ColumnInfo {
+                            name: col.name.clone(),
+                            ty: col_type,
+                            database: None,
+                            table: col.table.clone(),
+                            table_alias: None,
+                        });
+                    }
+                }
                 _ => {
                     // For complex GROUP BY expressions, generate a name
-                    format!("__group_{}", aggregate_schema_columns.len())
+                    let col_name = format!("__group_{}", aggregate_schema_columns.len());
+                    let col_type = Self::infer_expr_type(group_expr, input_schema)?;
+                    aggregate_schema_columns.push(ColumnInfo {
+                        name: col_name,
+                        ty: col_type,
+                        database: None,
+                        table: None,
+                        table_alias: None,
+                    });
                 }
-            };
-            let col_type = Self::infer_expr_type(group_expr, input_schema)?;
-            aggregate_schema_columns.push((col_name, col_type));
+            }
         }
 
         // Track aggregates we've already seen to avoid duplicates
@@ -1098,7 +1204,13 @@ impl<'a> LogicalPlanBuilder<'a> {
                 } else {
                     // New aggregate - add it
                     let col_type = Self::infer_expr_type(&logical_expr, input_schema)?;
-                    aggregate_schema_columns.push((col_name.clone(), col_type));
+                    aggregate_schema_columns.push(ColumnInfo {
+                        name: col_name.clone(),
+                        ty: col_type,
+                        database: None,
+                        table: None,
+                        table_alias: None,
+                    });
                     aggr_exprs.push(logical_expr);
                     aggregate_map.insert(agg_key, col_name.clone());
                     col_name.clone()
@@ -1122,7 +1234,13 @@ impl<'a> LogicalPlanBuilder<'a> {
         // Add only new aggregates
         for (agg_expr, agg_name) in extracted_aggs {
             let agg_type = Self::infer_expr_type(&agg_expr, input_schema)?;
-            aggregate_schema_columns.push((agg_name, agg_type));
+            aggregate_schema_columns.push(ColumnInfo {
+                name: agg_name,
+                ty: agg_type,
+                database: None,
+                table: None,
+                table_alias: None,
+            });
             aggr_exprs.push(agg_expr);
         }
 
@@ -1197,7 +1315,13 @@ impl<'a> LogicalPlanBuilder<'a> {
             // For type inference, we need the aggregate schema for column references
             let aggregate_schema = LogicalSchema::new(aggregate_schema_columns.clone());
             let col_type = Self::infer_expr_type(expr, &Arc::new(aggregate_schema))?;
-            projection_schema_columns.push((col_name, col_type));
+            projection_schema_columns.push(ColumnInfo {
+                name: col_name,
+                ty: col_type,
+                database: None,
+                table: None,
+                table_alias: None,
+            });
         }
 
         // Create the input plan (with pre-projection if needed)
@@ -1220,11 +1344,11 @@ impl<'a> LogicalPlanBuilder<'a> {
         // Check if we need the outer projection
         // We need a projection if:
-        // 1. Any expression is more complex than a simple column reference (e.g., abs(sum(id)))
-        // 2. We're selecting a different set of columns than what the aggregate outputs
-        // 3. Columns are renamed or reordered
+        // 1. We have expressions that compute new values (e.g., SUM(x) * 2)
+        // 2. We're selecting a different set of columns than GROUP BY + aggregates
+        // 3. We're reordering columns from their natural aggregate output order
         let needs_outer_projection = {
-            // Check if any expression is more complex than a simple column reference
+            // Check for complex expressions
             let has_complex_exprs = projection_exprs
                 .iter()
                 .any(|expr| !matches!(expr, LogicalExpr::Column(_)));
 
             if has_complex_exprs {
                 true
             } else {
@@ -1232,17 +1356,29 @@ impl<'a> LogicalPlanBuilder<'a> {
-                // All are simple columns - check if we're selecting exactly what the aggregate outputs
-                // The projection might be selecting a subset (e.g., only aggregates without group columns)
-                // or reordering columns, or using different names
+                // Check if we're selecting exactly what the aggregate outputs, in the same order
+                // The aggregate outputs: all GROUP BY columns, then all aggregate expressions
+                // The projection might select a subset or reorder these
 
-                // For now, keep it simple: if schemas don't match exactly, we need projection
-                // This handles all cases: subset selection, reordering, renaming
-                projection_schema_columns != aggregate_schema_columns
+                if projection_exprs.len() != aggregate_schema_columns.len() {
+                    // Different number of columns
+                    true
+                } else {
+                    // Check if columns match in order and name
+                    !projection_exprs.iter().zip(&aggregate_schema_columns).all(
+                        |(expr, agg_col)| {
+                            if let LogicalExpr::Column(col) = expr {
+                                col.name == agg_col.name
+                            } else {
+                                false
+                            }
+                        },
+                    )
+                }
             }
         };
 
-        // Create the aggregate node
+        // Create the aggregate node with its natural schema
         let aggregate_plan = LogicalPlan::Aggregate(Aggregate {
             input: aggregate_input,
             group_expr: group_exprs,
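Aside: a worked instance of the check above. The aggregate's natural output order is all GROUP BY columns followed by all aggregates, so SELECT category, COUNT(*) ... GROUP BY category matches it exactly and needs no outer projection, while SELECT COUNT(*), category reorders and does. On a toy input the zip test reduces to this standalone sketch (not the builder's code):

    let aggregate_out = ["category", "COUNT(*)"];
    let selected = ["COUNT(*)", "category"]; // reordered
    let needs_projection = selected.len() != aggregate_out.len()
        || !selected.iter().zip(&aggregate_out).all(|(s, a)| s == a);
    assert!(needs_projection);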
@@ -1257,7 +1393,7 @@ impl<'a> LogicalPlanBuilder<'a> {
                 schema: Arc::new(LogicalSchema::new(projection_schema_columns)),
             }))
         } else {
-            // No projection needed - the aggregate output is exactly what we want
+            // No projection needed - aggregate output matches what we want
             Ok(aggregate_plan)
         }
     }
@@ -1275,7 +1411,13 @@ impl<'a> LogicalPlanBuilder<'a> {
         // Infer schema from first row
         let mut schema_columns = Vec::new();
         for (i, _) in values[0].iter().enumerate() {
-            schema_columns.push((format!("column{}", i + 1), Type::Text));
+            schema_columns.push(ColumnInfo {
+                name: format!("column{}", i + 1),
+                ty: Type::Text,
+                database: None,
+                table: None,
+                table_alias: None,
+            });
         }
 
         for row in values {
@@ -2003,17 +2145,31 @@ impl<'a> LogicalPlanBuilder<'a> {
     }
 
     // Get table schema
-    fn get_table_schema(&self, table_name: &str) -> Result<SchemaRef> {
+    fn get_table_schema(&self, table_name: &str, alias: Option<&str>) -> Result<SchemaRef> {
         // Look up table in schema
         let table = self
             .schema
             .get_table(table_name)
             .ok_or_else(|| LimboError::ParseError(format!("Table '{table_name}' not found")))?;
 
+        // Parse table_name, which might be "db.table" for attached databases
+        let (database, actual_table) = if table_name.contains('.') {
+            let parts: Vec<&str> = table_name.splitn(2, '.').collect();
+            (Some(parts[0].to_string()), parts[1].to_string())
+        } else {
+            (None, table_name.to_string())
+        };
+
         let mut columns = Vec::new();
         for col in table.columns() {
             if let Some(ref name) = col.name {
-                columns.push((name.clone(), col.ty));
+                columns.push(ColumnInfo {
+                    name: name.clone(),
+                    ty: col.ty,
+                    database: database.clone(),
+                    table: Some(actual_table.clone()),
+                    table_alias: alias.map(|s| s.to_string()),
+                });
             }
         }
 
@@ -2024,8 +2180,8 @@ impl<'a> LogicalPlanBuilder<'a> {
     fn infer_expr_type(expr: &LogicalExpr, schema: &SchemaRef) -> Result<Type> {
         match expr {
             LogicalExpr::Column(col) => {
-                if let Some((_, typ)) = schema.find_column(&col.name) {
-                    Ok(*typ)
+                if let Some((_, col_info)) = schema.find_column(&col.name, col.table.as_deref()) {
+                    Ok(col_info.ty)
                 } else {
                     Ok(Type::Text)
                 }