From e80dd8e5e1b2ea7871b2cd4911c80256ca48f0d1 Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Thu, 18 Sep 2025 10:14:40 -0500 Subject: [PATCH] move the filter operator to accept indexes instead of names We already did similarly for the AggregateOperator: for joins you can have the same column name in many tables. And passing schema information to the operator is a layering violation (the operator may be operating on the result of a previous node, and at that point there is no more "schema"). Therefore we pass indexes into the column set the operator has. The FilterOperator has a complication: we are using it to generate the SQL for the populate statement, and that needs column names. However, we should *not* be using the FilterOperator for that, and that is a relic from the time where we had operator information directly inside the IncrementalView. To enable moving the FilterOperator to index-based, we rework that code. For joins, we'll need to populate many tables anyway, so we take the time to do that work here. --- core/incremental/compiler.rs | 127 ++-- core/incremental/filter_operator.rs | 238 ++---- core/incremental/operator.rs | 22 +- core/incremental/view.rs | 1065 ++++++++++++++++++++++++--- 4 files changed, 1121 insertions(+), 331 deletions(-) diff --git a/core/incremental/compiler.rs b/core/incremental/compiler.rs index c8899a02e..07fd8f83c 100644 --- a/core/incremental/compiler.rs +++ b/core/incremental/compiler.rs @@ -895,21 +895,18 @@ impl DbspCompiler { // Compile the input first let input_id = self.compile_plan(&filter.input)?; - // Get column names from input schema + // Get input schema for column resolution let input_schema = filter.input.schema(); - let column_names: Vec = input_schema.columns.iter() - .map(|col| col.name.clone()) - .collect(); // Convert predicate to DBSP expression let dbsp_predicate = Self::compile_expr(&filter.predicate)?; // Convert to FilterPredicate - let filter_predicate = Self::compile_filter_predicate(&filter.predicate)?; + let filter_predicate = Self::compile_filter_predicate(&filter.predicate, input_schema)?; // Create executable operator let executable: Box = - Box::new(FilterOperator::new(filter_predicate, column_names)); + Box::new(FilterOperator::new(filter_predicate)); // Create filter node let node_id = self.circuit.add_node( @@ -1372,42 +1369,57 @@ impl DbspCompiler { } /// Compile a logical expression to a FilterPredicate for execution - fn compile_filter_predicate(expr: &LogicalExpr) -> Result { + fn compile_filter_predicate( + expr: &LogicalExpr, + schema: &LogicalSchema, + ) -> Result { match expr { LogicalExpr::BinaryExpr { left, op, right } => { // Extract column name and value for simple predicates if let (LogicalExpr::Column(col), LogicalExpr::Literal(val)) = (left.as_ref(), right.as_ref()) { + // Resolve column name to index using the schema + let column_idx = schema + .columns + .iter() + .position(|c| c.name == col.name) + .ok_or_else(|| { + crate::LimboError::ParseError(format!( + "Column '{}' not found in schema for filter", + col.name + )) + })?; + match op { BinaryOperator::Equals => Ok(FilterPredicate::Equals { - column: col.name.clone(), + column_idx, value: val.clone(), }), BinaryOperator::NotEquals => Ok(FilterPredicate::NotEquals { - column: col.name.clone(), + column_idx, value: val.clone(), }), BinaryOperator::Greater => Ok(FilterPredicate::GreaterThan { - column: col.name.clone(), + column_idx, value: val.clone(), }), BinaryOperator::GreaterEquals => Ok(FilterPredicate::GreaterThanOrEqual { - column: col.name.clone(), + column_idx, value: val.clone(), }), BinaryOperator::Less => Ok(FilterPredicate::LessThan { - column: col.name.clone(), + column_idx, value: val.clone(), }), BinaryOperator::LessEquals => Ok(FilterPredicate::LessThanOrEqual { - column: col.name.clone(), + column_idx, value: val.clone(), }), BinaryOperator::And => { // Handle AND of two predicates - let left_pred = Self::compile_filter_predicate(left)?; - let right_pred = Self::compile_filter_predicate(right)?; + let left_pred = Self::compile_filter_predicate(left, schema)?; + let right_pred = Self::compile_filter_predicate(right, schema)?; Ok(FilterPredicate::And( Box::new(left_pred), Box::new(right_pred), @@ -1415,8 +1427,8 @@ impl DbspCompiler { } BinaryOperator::Or => { // Handle OR of two predicates - let left_pred = Self::compile_filter_predicate(left)?; - let right_pred = Self::compile_filter_predicate(right)?; + let left_pred = Self::compile_filter_predicate(left, schema)?; + let right_pred = Self::compile_filter_predicate(right, schema)?; Ok(FilterPredicate::Or( Box::new(left_pred), Box::new(right_pred), @@ -1428,8 +1440,8 @@ impl DbspCompiler { } } else if matches!(op, BinaryOperator::And | BinaryOperator::Or) { // Handle logical operators - let left_pred = Self::compile_filter_predicate(left)?; - let right_pred = Self::compile_filter_predicate(right)?; + let left_pred = Self::compile_filter_predicate(left, schema)?; + let right_pred = Self::compile_filter_predicate(right, schema)?; match op { BinaryOperator::And => Ok(FilterPredicate::And( Box::new(left_pred), @@ -3777,13 +3789,10 @@ mod tests { Box::new(InputOperator::new("test".to_string())), ); - let filter_op = FilterOperator::new( - FilterPredicate::GreaterThan { - column: "value".to_string(), - value: Value::Integer(10), - }, - vec!["id".to_string(), "value".to_string()], - ); + let filter_op = FilterOperator::new(FilterPredicate::GreaterThan { + column_idx: 1, // "value" is at index 1 + value: Value::Integer(10), + }); // Create the filter predicate using DbspExpr let predicate = DbspExpr::BinaryExpr { @@ -4587,18 +4596,18 @@ mod tests { fn test_filter_with_qualified_columns_in_join() { // Test that filters correctly handle qualified column names in joins // when multiple tables have columns with the SAME names. - // Both users and sales tables have an 'id' column which can be ambiguous. + // Both users and customers tables have 'id' and 'name' columns which can be ambiguous. let (mut circuit, pager) = compile_sql!( - "SELECT users.id, users.name, sales.id, sales.amount + "SELECT users.id, users.name, customers.id, customers.name FROM users - JOIN sales ON users.id = sales.customer_id - WHERE users.id > 1 AND sales.id < 100" + JOIN customers ON users.id = customers.id + WHERE users.id > 1 AND customers.id < 100" ); // Create test data let mut users_delta = Delta::new(); - let mut sales_delta = Delta::new(); + let mut customers_delta = Delta::new(); // Users data: (id, name, age) users_delta.insert( @@ -4626,48 +4635,60 @@ mod tests { ], ); // id = 3 - // Sales data: (id, customer_id, amount) - sales_delta.insert( - 50, - vec![Value::Integer(50), Value::Integer(1), Value::Integer(100)], - ); // sales.id = 50, customer_id = 1 - sales_delta.insert( - 99, - vec![Value::Integer(99), Value::Integer(2), Value::Integer(200)], - ); // sales.id = 99, customer_id = 2 - sales_delta.insert( - 150, - vec![Value::Integer(150), Value::Integer(3), Value::Integer(300)], - ); // sales.id = 150, customer_id = 3 + // Customers data: (id, name, email) + customers_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Customer Alice".into()), + Value::Text("alice@example.com".into()), + ], + ); // id = 1 + customers_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Customer Bob".into()), + Value::Text("bob@example.com".into()), + ], + ); // id = 2 + customers_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Text("Customer Charlie".into()), + Value::Text("charlie@example.com".into()), + ], + ); // id = 3 let mut inputs = HashMap::new(); inputs.insert("users".to_string(), users_delta); - inputs.insert("sales".to_string(), sales_delta); + inputs.insert("customers".to_string(), customers_delta); let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap(); - // Should only get row with Bob (users.id=2, sales.id=99): - // - users.id=2 (> 1) AND sales.id=99 (< 100) ✓ + // Should get rows where users.id > 1 AND customers.id < 100 + // - users.id=2 (> 1) AND customers.id=2 (< 100) ✓ + // - users.id=3 (> 1) AND customers.id=3 (< 100) ✓ // Alice excluded: users.id=1 (NOT > 1) - // Charlie excluded: sales.id=150 (NOT < 100) - assert_eq!(result.len(), 1, "Should have 1 filtered result"); + assert_eq!(result.len(), 2, "Should have 2 filtered results"); let (row, weight) = &result.changes[0]; assert_eq!(*weight, 1); assert_eq!(row.values.len(), 4, "Should have 4 columns"); - // Verify the filter correctly used qualified columns + // Verify the filter correctly used qualified columns for Bob assert_eq!(row.values[0], Value::Integer(2), "users.id should be 2"); assert_eq!( row.values[1], Value::Text("Bob".into()), "users.name should be Bob" ); - assert_eq!(row.values[2], Value::Integer(99), "sales.id should be 99"); + assert_eq!(row.values[2], Value::Integer(2), "customers.id should be 2"); assert_eq!( row.values[3], - Value::Integer(200), - "sales.amount should be 200" + Value::Text("Customer Bob".into()), + "customers.name should be Customer Bob" ); } } diff --git a/core/incremental/filter_operator.rs b/core/incremental/filter_operator.rs index f836f4897..a0179f9d4 100644 --- a/core/incremental/filter_operator.rs +++ b/core/incremental/filter_operator.rs @@ -6,26 +6,25 @@ use crate::incremental::dbsp::{Delta, DeltaPair}; use crate::incremental::operator::{ ComputationTracker, DbspStateCursors, EvalState, IncrementalOperator, }; -use crate::types::{IOResult, Text}; +use crate::types::IOResult; use crate::{Result, Value}; use std::sync::{Arc, Mutex}; -use turso_parser::ast::{Expr, Literal, OneSelect, Operator}; /// Filter predicate for filtering rows #[derive(Debug, Clone)] pub enum FilterPredicate { - /// Column = value - Equals { column: String, value: Value }, - /// Column != value - NotEquals { column: String, value: Value }, - /// Column > value - GreaterThan { column: String, value: Value }, - /// Column >= value - GreaterThanOrEqual { column: String, value: Value }, - /// Column < value - LessThan { column: String, value: Value }, - /// Column <= value - LessThanOrEqual { column: String, value: Value }, + /// Column = value (using column index) + Equals { column_idx: usize, value: Value }, + /// Column != value (using column index) + NotEquals { column_idx: usize, value: Value }, + /// Column > value (using column index) + GreaterThan { column_idx: usize, value: Value }, + /// Column >= value (using column index) + GreaterThanOrEqual { column_idx: usize, value: Value }, + /// Column < value (using column index) + LessThan { column_idx: usize, value: Value }, + /// Column <= value (using column index) + LessThanOrEqual { column_idx: usize, value: Value }, /// Logical AND of two predicates And(Box, Box), /// Logical OR of two predicates @@ -34,122 +33,17 @@ pub enum FilterPredicate { None, } -impl FilterPredicate { - /// Parse a SQL AST expression into a FilterPredicate - /// This centralizes all SQL-to-predicate parsing logic - pub fn from_sql_expr(expr: &turso_parser::ast::Expr) -> crate::Result { - let Expr::Binary(lhs, op, rhs) = expr else { - return Err(crate::LimboError::ParseError( - "Unsupported WHERE clause for incremental views: not a binary expression" - .to_string(), - )); - }; - - // Handle AND/OR logical operators - match op { - Operator::And => { - let left = Self::from_sql_expr(lhs)?; - let right = Self::from_sql_expr(rhs)?; - return Ok(FilterPredicate::And(Box::new(left), Box::new(right))); - } - Operator::Or => { - let left = Self::from_sql_expr(lhs)?; - let right = Self::from_sql_expr(rhs)?; - return Ok(FilterPredicate::Or(Box::new(left), Box::new(right))); - } - _ => {} - } - - // Handle comparison operators - let Expr::Id(column_name) = &**lhs else { - return Err(crate::LimboError::ParseError( - "Unsupported WHERE clause for incremental views: left-hand-side is not a column reference".to_string(), - )); - }; - - let column = column_name.as_str().to_string(); - - // Parse the right-hand side value - let value = match &**rhs { - Expr::Literal(Literal::String(s)) => { - // Strip quotes from string literals - let cleaned = s.trim_matches('\'').trim_matches('"'); - Value::Text(Text::new(cleaned)) - } - Expr::Literal(Literal::Numeric(n)) => { - // Try to parse as integer first, then float - if let Ok(i) = n.parse::() { - Value::Integer(i) - } else if let Ok(f) = n.parse::() { - Value::Float(f) - } else { - return Err(crate::LimboError::ParseError( - "Unsupported WHERE clause for incremental views: right-hand-side is not a numeric literal".to_string(), - )); - } - } - Expr::Literal(Literal::Null) => Value::Null, - Expr::Literal(Literal::Blob(_)) => { - // Blob comparison not yet supported - return Err(crate::LimboError::ParseError( - "Unsupported WHERE clause for incremental views: comparison with blob literals is not supported".to_string(), - )); - } - other => { - // Complex expressions not yet supported - return Err(crate::LimboError::ParseError( - format!("Unsupported WHERE clause for incremental views: comparison with {other:?} is not supported"), - )); - } - }; - - // Create the appropriate predicate based on operator - match op { - Operator::Equals => Ok(FilterPredicate::Equals { column, value }), - Operator::NotEquals => Ok(FilterPredicate::NotEquals { column, value }), - Operator::Greater => Ok(FilterPredicate::GreaterThan { column, value }), - Operator::GreaterEquals => Ok(FilterPredicate::GreaterThanOrEqual { column, value }), - Operator::Less => Ok(FilterPredicate::LessThan { column, value }), - Operator::LessEquals => Ok(FilterPredicate::LessThanOrEqual { column, value }), - other => Err(crate::LimboError::ParseError( - format!("Unsupported WHERE clause for incremental views: comparison operator {other:?} is not supported"), - )), - } - } - - /// Parse a WHERE clause from a SELECT statement - pub fn from_select(select: &turso_parser::ast::Select) -> crate::Result { - if let OneSelect::Select { - ref where_clause, .. - } = select.body.select - { - if let Some(where_clause) = where_clause { - Self::from_sql_expr(where_clause) - } else { - Ok(FilterPredicate::None) - } - } else { - Err(crate::LimboError::ParseError( - "Unsupported WHERE clause for incremental views: not a single SELECT statement" - .to_string(), - )) - } - } -} - /// Filter operator - filters rows based on predicate #[derive(Debug)] pub struct FilterOperator { predicate: FilterPredicate, - column_names: Vec, tracker: Option>>, } impl FilterOperator { - pub fn new(predicate: FilterPredicate, column_names: Vec) -> Self { + pub fn new(predicate: FilterPredicate) -> Self { Self { predicate, - column_names, tracker: None, } } @@ -162,86 +56,72 @@ impl FilterOperator { pub fn evaluate_predicate(&self, values: &[Value]) -> bool { match &self.predicate { FilterPredicate::None => true, - FilterPredicate::Equals { column, value } => { - if let Some(idx) = self.column_names.iter().position(|c| c == column) { - if let Some(v) = values.get(idx) { - return v == value; + FilterPredicate::Equals { column_idx, value } => { + if let Some(v) = values.get(*column_idx) { + return v == value; + } + false + } + FilterPredicate::NotEquals { column_idx, value } => { + if let Some(v) = values.get(*column_idx) { + return v != value; + } + false + } + FilterPredicate::GreaterThan { column_idx, value } => { + if let Some(v) = values.get(*column_idx) { + // Compare based on value types + match (v, value) { + (Value::Integer(a), Value::Integer(b)) => return a > b, + (Value::Float(a), Value::Float(b)) => return a > b, + (Value::Text(a), Value::Text(b)) => return a.as_str() > b.as_str(), + _ => {} } } false } - FilterPredicate::NotEquals { column, value } => { - if let Some(idx) = self.column_names.iter().position(|c| c == column) { - if let Some(v) = values.get(idx) { - return v != value; + FilterPredicate::GreaterThanOrEqual { column_idx, value } => { + if let Some(v) = values.get(*column_idx) { + match (v, value) { + (Value::Integer(a), Value::Integer(b)) => return a >= b, + (Value::Float(a), Value::Float(b)) => return a >= b, + (Value::Text(a), Value::Text(b)) => return a.as_str() >= b.as_str(), + _ => {} } } false } - FilterPredicate::GreaterThan { column, value } => { - if let Some(idx) = self.column_names.iter().position(|c| c == column) { - if let Some(v) = values.get(idx) { - // Compare based on value types - match (v, value) { - (Value::Integer(a), Value::Integer(b)) => return a > b, - (Value::Float(a), Value::Float(b)) => return a > b, - (Value::Text(a), Value::Text(b)) => return a.as_str() > b.as_str(), - _ => {} - } + FilterPredicate::LessThan { column_idx, value } => { + if let Some(v) = values.get(*column_idx) { + match (v, value) { + (Value::Integer(a), Value::Integer(b)) => return a < b, + (Value::Float(a), Value::Float(b)) => return a < b, + (Value::Text(a), Value::Text(b)) => return a.as_str() < b.as_str(), + _ => {} } } false } - FilterPredicate::GreaterThanOrEqual { column, value } => { - if let Some(idx) = self.column_names.iter().position(|c| c == column) { - if let Some(v) = values.get(idx) { - match (v, value) { - (Value::Integer(a), Value::Integer(b)) => return a >= b, - (Value::Float(a), Value::Float(b)) => return a >= b, - (Value::Text(a), Value::Text(b)) => return a.as_str() >= b.as_str(), - _ => {} - } - } - } - false - } - FilterPredicate::LessThan { column, value } => { - if let Some(idx) = self.column_names.iter().position(|c| c == column) { - if let Some(v) = values.get(idx) { - match (v, value) { - (Value::Integer(a), Value::Integer(b)) => return a < b, - (Value::Float(a), Value::Float(b)) => return a < b, - (Value::Text(a), Value::Text(b)) => return a.as_str() < b.as_str(), - _ => {} - } - } - } - false - } - FilterPredicate::LessThanOrEqual { column, value } => { - if let Some(idx) = self.column_names.iter().position(|c| c == column) { - if let Some(v) = values.get(idx) { - match (v, value) { - (Value::Integer(a), Value::Integer(b)) => return a <= b, - (Value::Float(a), Value::Float(b)) => return a <= b, - (Value::Text(a), Value::Text(b)) => return a.as_str() <= b.as_str(), - _ => {} - } + FilterPredicate::LessThanOrEqual { column_idx, value } => { + if let Some(v) = values.get(*column_idx) { + match (v, value) { + (Value::Integer(a), Value::Integer(b)) => return a <= b, + (Value::Float(a), Value::Float(b)) => return a <= b, + (Value::Text(a), Value::Text(b)) => return a.as_str() <= b.as_str(), + _ => {} } } false } FilterPredicate::And(left, right) => { // Temporarily create sub-filters to evaluate - let left_filter = FilterOperator::new((**left).clone(), self.column_names.clone()); - let right_filter = - FilterOperator::new((**right).clone(), self.column_names.clone()); + let left_filter = FilterOperator::new((**left).clone()); + let right_filter = FilterOperator::new((**right).clone()); left_filter.evaluate_predicate(values) && right_filter.evaluate_predicate(values) } FilterPredicate::Or(left, right) => { - let left_filter = FilterOperator::new((**left).clone(), self.column_names.clone()); - let right_filter = - FilterOperator::new((**right).clone(), self.column_names.clone()); + let left_filter = FilterOperator::new((**left).clone()); + let right_filter = FilterOperator::new((**right).clone()); left_filter.evaluate_predicate(values) || right_filter.evaluate_predicate(values) } } diff --git a/core/incremental/operator.rs b/core/incremental/operator.rs index 72ed7bc0c..2af512504 100644 --- a/core/incremental/operator.rs +++ b/core/incremental/operator.rs @@ -1450,13 +1450,10 @@ mod tests { BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); - let mut filter = FilterOperator::new( - FilterPredicate::GreaterThan { - column: "b".to_string(), - value: Value::Integer(2), - }, - vec!["a".to_string(), "b".to_string()], - ); + let mut filter = FilterOperator::new(FilterPredicate::GreaterThan { + column_idx: 1, // "b" is at index 1 + value: Value::Integer(2), + }); // Initialize with a row (rowid=3, values=[3, 3]) let mut init_data = Delta::new(); @@ -1512,13 +1509,10 @@ mod tests { BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); - let mut filter = FilterOperator::new( - FilterPredicate::GreaterThan { - column: "age".to_string(), - value: Value::Integer(25), - }, - vec!["id".to_string(), "name".to_string(), "age".to_string()], - ); + let mut filter = FilterOperator::new(FilterPredicate::GreaterThan { + column_idx: 2, // "age" is at index 2 + value: Value::Integer(25), + }); // Initialize with some data let mut init_data = Delta::new(); diff --git a/core/incremental/view.rs b/core/incremental/view.rs index 8b32c5dcc..77f1d0217 100644 --- a/core/incremental/view.rs +++ b/core/incremental/view.rs @@ -1,6 +1,6 @@ use super::compiler::{DbspCircuit, DbspCompiler, DeltaSet}; use super::dbsp::Delta; -use super::operator::{ComputationTracker, FilterPredicate}; +use super::operator::ComputationTracker; use crate::schema::{BTreeTable, Schema}; use crate::storage::btree::BTreeCursor; use crate::translate::logical::LogicalPlanBuilder; @@ -163,8 +163,6 @@ impl AllViewsTxState { #[derive(Debug)] pub struct IncrementalView { name: String, - // WHERE clause predicate for filtering (kept for compatibility) - pub where_predicate: FilterPredicate, // The SELECT statement that defines how to transform input data pub select_stmt: ast::Select, @@ -173,6 +171,11 @@ pub struct IncrementalView { // All tables referenced by this view (from FROM clause and JOINs) referenced_tables: Vec>, + // Mapping from table aliases to actual table names (e.g., "c" -> "customers") + table_aliases: HashMap, + // Mapping from table name to fully qualified name (e.g., "customers" -> "main.customers") + // This preserves database qualification from the original query + qualified_table_names: HashMap, // The view's column schema with table relationships pub column_schema: ViewColumnSchema, // State machine for population @@ -301,8 +304,6 @@ impl IncrementalView { ) -> Result { let name = view_name.name.as_str().to_string(); - let where_predicate = FilterPredicate::from_select(&select)?; - // Extract output columns using the shared function let column_schema = extract_view_columns(&select, schema)?; @@ -313,14 +314,16 @@ impl IncrementalView { )); } - // Get all tables from FROM clause and JOINs - let referenced_tables = Self::extract_all_tables(&select, schema)?; + // Get all tables from FROM clause and JOINs, along with their aliases + let (referenced_tables, table_aliases, qualified_table_names) = + Self::extract_all_tables(&select, schema)?; Self::new( name, - where_predicate, select.clone(), referenced_tables, + table_aliases, + qualified_table_names, column_schema, schema, main_data_root, @@ -332,9 +335,10 @@ impl IncrementalView { #[allow(clippy::too_many_arguments)] pub fn new( name: String, - where_predicate: FilterPredicate, select_stmt: ast::Select, referenced_tables: Vec>, + table_aliases: HashMap, + qualified_table_names: HashMap, column_schema: ViewColumnSchema, schema: &Schema, main_data_root: usize, @@ -355,10 +359,11 @@ impl IncrementalView { Ok(Self { name, - where_predicate, select_stmt, circuit, referenced_tables, + table_aliases, + qualified_table_names, column_schema, populate_state: PopulateState::Start, tracker, @@ -402,9 +407,22 @@ impl IncrementalView { self.referenced_tables.clone() } - /// Extract all table names from a SELECT statement (including JOINs) - fn extract_all_tables(select: &ast::Select, schema: &Schema) -> Result>> { + /// Extract all tables and their aliases from the SELECT statement + /// Returns a tuple of (tables, alias_map, qualified_names) + /// where alias_map is alias -> table_name + /// and qualified_names is table_name -> fully_qualified_name + #[allow(clippy::type_complexity)] + fn extract_all_tables( + select: &ast::Select, + schema: &Schema, + ) -> Result<( + Vec>, + HashMap, + HashMap, + )> { let mut tables = Vec::new(); + let mut aliases = HashMap::new(); + let mut qualified_names = HashMap::new(); if let ast::OneSelect::Select { from: Some(ref from), @@ -412,10 +430,24 @@ impl IncrementalView { } = select.body.select { // Get the main table from FROM clause - if let ast::SelectTable::Table(name, _, _) = from.select.as_ref() { + if let ast::SelectTable::Table(name, alias, _) = from.select.as_ref() { let table_name = name.name.as_str(); + + // Build the fully qualified name + let qualified_name = if let Some(ref db) = name.db_name { + format!("{db}.{table_name}") + } else { + table_name.to_string() + }; + if let Some(table) = schema.get_btree_table(table_name) { tables.push(table.clone()); + qualified_names.insert(table_name.to_string(), qualified_name); + + // Store the alias mapping if there is an alias + if let Some(alias_name) = alias { + aliases.insert(alias_name.to_string(), table_name.to_string()); + } } else { return Err(LimboError::ParseError(format!( "Table '{table_name}' not found in schema" @@ -425,10 +457,24 @@ impl IncrementalView { // Get all tables from JOIN clauses for join in &from.joins { - if let ast::SelectTable::Table(name, _, _) = join.table.as_ref() { + if let ast::SelectTable::Table(name, alias, _) = join.table.as_ref() { let table_name = name.name.as_str(); + + // Build the fully qualified name + let qualified_name = if let Some(ref db) = name.db_name { + format!("{db}.{table_name}") + } else { + table_name.to_string() + }; + if let Some(table) = schema.get_btree_table(table_name) { tables.push(table.clone()); + qualified_names.insert(table_name.to_string(), qualified_name); + + // Store the alias mapping if there is an alias + if let Some(alias_name) = alias { + aliases.insert(alias_name.to_string(), table_name.to_string()); + } } else { return Err(LimboError::ParseError(format!( "Table '{table_name}' not found in schema" @@ -444,90 +490,346 @@ impl IncrementalView { )); } - Ok(tables) + Ok((tables, aliases, qualified_names)) } - /// Generate the SQL query for populating the view from its source table - fn sql_for_populate(&self) -> crate::Result { - // Get the first table from referenced tables + /// Generate SQL queries for populating the view from each source table + /// Returns a vector of SQL statements, one for each referenced table + /// Each query includes only the WHERE conditions relevant to that specific table + fn sql_for_populate(&self) -> crate::Result> { if self.referenced_tables.is_empty() { return Err(LimboError::ParseError( "No tables to populate from".to_string(), )); } - let table = &self.referenced_tables[0]; - // Check if the table has a rowid alias (INTEGER PRIMARY KEY column) - let has_rowid_alias = table.columns.iter().any(|col| col.is_rowid_alias); + let mut queries = Vec::new(); - // For now, select all columns since we don't have the static operators - // The circuit will handle filtering and projection - // If there's a rowid alias, we don't need to select rowid separately - let select_clause = if has_rowid_alias { - "*".to_string() - } else { - "*, rowid".to_string() - }; + for table in &self.referenced_tables { + // Check if the table has a rowid alias (INTEGER PRIMARY KEY column) + let has_rowid_alias = table.columns.iter().any(|col| col.is_rowid_alias); - // Build WHERE clause from the where_predicate - let where_clause = self.build_where_clause(&self.where_predicate)?; + // For now, select all columns since we don't have the static operators + // The circuit will handle filtering and projection + // If there's a rowid alias, we don't need to select rowid separately + let select_clause = if has_rowid_alias { + "*".to_string() + } else { + "*, rowid".to_string() + }; - // Construct the final query - let query = if where_clause.is_empty() { - format!("SELECT {} FROM {}", select_clause, table.name) - } else { - format!( - "SELECT {} FROM {} WHERE {}", - select_clause, table.name, where_clause - ) - }; - Ok(query) + // Extract WHERE conditions for this specific table + let where_clause = self.extract_where_clause_for_table(&table.name)?; + + // Use the qualified table name if available, otherwise just the table name + let table_name = self + .qualified_table_names + .get(&table.name) + .cloned() + .unwrap_or_else(|| table.name.clone()); + + // Construct the query for this table + let query = if where_clause.is_empty() { + format!("SELECT {select_clause} FROM {table_name}") + } else { + format!("SELECT {select_clause} FROM {table_name} WHERE {where_clause}") + }; + queries.push(query); + } + + Ok(queries) } - /// Build a WHERE clause from a FilterPredicate - fn build_where_clause(&self, predicate: &FilterPredicate) -> crate::Result { - match predicate { - FilterPredicate::None => Ok(String::new()), - FilterPredicate::Equals { column, value } => { - Ok(format!("{} = {}", column, self.value_to_sql(value))) + /// Extract WHERE conditions that apply to a specific table + /// This analyzes the WHERE clause in the SELECT statement and returns + /// only the conditions that reference the given table + fn extract_where_clause_for_table(&self, table_name: &str) -> crate::Result { + // For single table queries, return the entire WHERE clause (already unqualified) + if self.referenced_tables.len() == 1 { + if let ast::OneSelect::Select { + where_clause: Some(ref where_expr), + .. + } = self.select_stmt.body.select + { + // For single table, the expression should already be unqualified or qualified with the single table + // We need to unqualify it for the single-table query + let unqualified = self.unqualify_expression(where_expr, table_name); + return Ok(unqualified.to_string()); } - FilterPredicate::NotEquals { column, value } => { - Ok(format!("{} != {}", column, self.value_to_sql(value))) + return Ok(String::new()); + } + + // For multi-table queries (JOINs), extract conditions for the specific table + if let ast::OneSelect::Select { + where_clause: Some(ref where_expr), + .. + } = self.select_stmt.body.select + { + // Extract conditions that reference only the specified table + let table_conditions = self.extract_table_conditions(where_expr, table_name)?; + if let Some(conditions) = table_conditions { + // Unqualify the expression for single-table query + let unqualified = self.unqualify_expression(&conditions, table_name); + return Ok(unqualified.to_string()); } - FilterPredicate::GreaterThan { column, value } => { - Ok(format!("{} > {}", column, self.value_to_sql(value))) + } + + Ok(String::new()) + } + + /// Extract conditions from an expression that reference only the specified table + fn extract_table_conditions( + &self, + expr: &ast::Expr, + table_name: &str, + ) -> crate::Result> { + match expr { + ast::Expr::Binary(left, op, right) => { + match op { + ast::Operator::And => { + // For AND, we can extract conditions independently + let left_cond = self.extract_table_conditions(left, table_name)?; + let right_cond = self.extract_table_conditions(right, table_name)?; + + match (left_cond, right_cond) { + (Some(l), Some(r)) => { + // Both conditions apply to this table + Ok(Some(ast::Expr::Binary( + Box::new(l), + ast::Operator::And, + Box::new(r), + ))) + } + (Some(l), None) => Ok(Some(l)), + (None, Some(r)) => Ok(Some(r)), + (None, None) => Ok(None), + } + } + ast::Operator::Or => { + // For OR, both sides must reference the same table(s) + // If either side references multiple tables, we can't extract it + let left_tables = self.get_referenced_tables_in_expr(left)?; + let right_tables = self.get_referenced_tables_in_expr(right)?; + + // If both sides only reference our table, include the whole OR + if left_tables.len() == 1 + && left_tables.contains(&table_name.to_string()) + && right_tables.len() == 1 + && right_tables.contains(&table_name.to_string()) + { + Ok(Some(expr.clone())) + } else { + // OR condition involves multiple tables, can't extract + Ok(None) + } + } + _ => { + // For comparison operators, check if this condition references only our table + let referenced_tables = self.get_referenced_tables_in_expr(expr)?; + if referenced_tables.len() == 1 + && referenced_tables.contains(&table_name.to_string()) + { + Ok(Some(expr.clone())) + } else { + Ok(None) + } + } + } } - FilterPredicate::GreaterThanOrEqual { column, value } => { - Ok(format!("{} >= {}", column, self.value_to_sql(value))) + ast::Expr::Parenthesized(exprs) => { + if exprs.len() == 1 { + self.extract_table_conditions(&exprs[0], table_name) + } else { + Ok(None) + } } - FilterPredicate::LessThan { column, value } => { - Ok(format!("{} < {}", column, self.value_to_sql(value))) - } - FilterPredicate::LessThanOrEqual { column, value } => { - Ok(format!("{} <= {}", column, self.value_to_sql(value))) - } - FilterPredicate::And(left, right) => { - let left_clause = self.build_where_clause(left)?; - let right_clause = self.build_where_clause(right)?; - Ok(format!("({left_clause} AND {right_clause})")) - } - FilterPredicate::Or(left, right) => { - let left_clause = self.build_where_clause(left)?; - let right_clause = self.build_where_clause(right)?; - Ok(format!("({left_clause} OR {right_clause})")) + _ => { + // For other expressions, check if they reference only our table + let referenced_tables = self.get_referenced_tables_in_expr(expr)?; + if referenced_tables.len() == 1 + && referenced_tables.contains(&table_name.to_string()) + { + Ok(Some(expr.clone())) + } else { + Ok(None) + } } } } - /// Convert a Value to SQL literal representation - fn value_to_sql(&self, value: &Value) -> String { - match value { - Value::Null => "NULL".to_string(), - Value::Integer(i) => i.to_string(), - Value::Float(f) => f.to_string(), - Value::Text(t) => format!("'{}'", t.as_str().replace('\'', "''")), - Value::Blob(_) => "NULL".to_string(), // Blob literals not supported in WHERE clause yet + /// Get the set of table names referenced in an expression + fn get_referenced_tables_in_expr(&self, expr: &ast::Expr) -> crate::Result> { + let mut tables = Vec::new(); + self.collect_referenced_tables(expr, &mut tables)?; + // Deduplicate + tables.sort(); + tables.dedup(); + Ok(tables) + } + + /// Recursively collect table references from an expression + fn collect_referenced_tables( + &self, + expr: &ast::Expr, + tables: &mut Vec, + ) -> crate::Result<()> { + match expr { + ast::Expr::Binary(left, _, right) => { + self.collect_referenced_tables(left, tables)?; + self.collect_referenced_tables(right, tables)?; + } + ast::Expr::Qualified(table, _) => { + // This is a qualified column reference (table.column or alias.column) + // We need to resolve aliases to actual table names + let actual_table = self.resolve_table_alias(table.as_str()); + tables.push(actual_table); + } + ast::Expr::Id(column) => { + // Unqualified column reference + if self.referenced_tables.len() > 1 { + // In a JOIN context, check which tables have this column + let mut tables_with_column = Vec::new(); + for table in &self.referenced_tables { + if table + .columns + .iter() + .any(|c| c.name.as_ref() == Some(&column.to_string())) + { + tables_with_column.push(table.name.clone()); + } + } + + if tables_with_column.len() > 1 { + // Ambiguous column - this should have been caught earlier + // Return error to be safe + return Err(crate::LimboError::ParseError(format!( + "Ambiguous column name '{}' in WHERE clause - exists in tables: {}", + column, + tables_with_column.join(", ") + ))); + } else if tables_with_column.len() == 1 { + // Unambiguous - only one table has this column + // This is allowed by SQLite + tables.push(tables_with_column[0].clone()); + } else { + // Column doesn't exist in any table - this is an error + // but should be caught during compilation + return Err(crate::LimboError::ParseError(format!( + "Column '{column}' not found in any table" + ))); + } + } else { + // Single table context - unqualified columns belong to that table + if let Some(table) = self.referenced_tables.first() { + tables.push(table.name.clone()); + } + } + } + ast::Expr::DoublyQualified(_database, table, _column) => { + // For database.table.column, resolve the table name + let table_str = table.as_str(); + let actual_table = self.resolve_table_alias(table_str); + tables.push(actual_table); + } + ast::Expr::Parenthesized(exprs) => { + for e in exprs { + self.collect_referenced_tables(e, tables)?; + } + } + _ => { + // Literals and other expressions don't reference tables + } } + Ok(()) + } + + /// Convert a qualified expression to unqualified for single-table queries + /// This removes table prefixes from column references since they're not needed + /// when querying a single table + fn unqualify_expression(&self, expr: &ast::Expr, table_name: &str) -> ast::Expr { + match expr { + ast::Expr::Binary(left, op, right) => { + // Recursively unqualify both sides + ast::Expr::Binary( + Box::new(self.unqualify_expression(left, table_name)), + *op, + Box::new(self.unqualify_expression(right, table_name)), + ) + } + ast::Expr::Qualified(table, column) => { + // Convert qualified column to unqualified if it's for our table + // Handle both "table.column" and "database.table.column" cases + let table_str = table.as_str(); + + // Check if this is a database.table reference + let actual_table = if table_str.contains('.') { + // Split on '.' and take the last part as the table name + table_str + .split('.') + .next_back() + .unwrap_or(table_str) + .to_string() + } else { + // Could be an alias or direct table name + self.resolve_table_alias(table_str) + }; + + if actual_table == table_name { + // Just return the column name without qualification + ast::Expr::Id(column.clone()) + } else { + // This shouldn't happen if extract_table_conditions worked correctly + // but keep it qualified just in case + expr.clone() + } + } + ast::Expr::DoublyQualified(_database, table, column) => { + // This is database.table.column format + // Check if the table matches our target table + let table_str = table.as_str(); + let actual_table = self.resolve_table_alias(table_str); + + if actual_table == table_name { + // Just return the column name without qualification + ast::Expr::Id(column.clone()) + } else { + // Keep it qualified if it's for a different table + expr.clone() + } + } + ast::Expr::Parenthesized(exprs) => { + // Recursively unqualify expressions in parentheses + let unqualified_exprs: Vec> = exprs + .iter() + .map(|e| Box::new(self.unqualify_expression(e, table_name))) + .collect(); + ast::Expr::Parenthesized(unqualified_exprs) + } + _ => { + // Other expression types (literals, unqualified columns, etc.) stay as-is + expr.clone() + } + } + } + + /// Resolve a table alias to the actual table name + fn resolve_table_alias(&self, alias: &str) -> String { + // Check if there's an alias mapping in the FROM/JOIN clauses + // For now, we'll do a simple check - if the alias matches a table name, use it + // Otherwise, try to find it in the FROM clause + + // First check if it's an actual table name + if self.referenced_tables.iter().any(|t| t.name == alias) { + return alias.to_string(); + } + + // Check if it's an alias that maps to a table + if let Some(table_name) = self.table_aliases.get(alias) { + return table_name.clone(); + } + + // If we can't resolve it, return as-is (it might be a table name we don't know about) + alias.to_string() } /// Populate the view by scanning the source table using a state machine @@ -564,7 +866,15 @@ impl IncrementalView { // btree and not in the table btree. Using cursors would force us to be aware of this // distinction (and others), and ultimately lead to reimplementing the whole query // machinery (next step is which index is best to use, etc) - let query = self.sql_for_populate()?; + let queries = self.sql_for_populate()?; + + // For now, only use the first query (single table population) + if queries.is_empty() { + return Err(LimboError::ParseError( + "No populate queries generated".to_string(), + )); + } + let query = queries[0].clone(); // Create a new connection for reading to avoid transaction conflicts // This allows us to read from tables while the parent transaction is writing the view @@ -958,15 +1268,76 @@ mod tests { collation: None, hidden: false, }, + SchemaColumn { + name: Some("price".to_string()), + ty: Type::Real, + ty_str: "REAL".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, ], has_rowid: true, is_strict: false, unique_sets: vec![], }; + // Create logs table - without a rowid alias (no INTEGER PRIMARY KEY) + let logs_table = BTreeTable { + name: "logs".to_string(), + root_page: 5, + primary_key_columns: vec![], // No primary key, so no rowid alias + columns: vec![ + SchemaColumn { + name: Some("message".to_string()), + ty: Type::Text, + ty_str: "TEXT".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("level".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("timestamp".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + ], + has_rowid: true, // Has implicit rowid but no alias + is_strict: false, + unique_sets: vec![], + }; + schema.add_btree_table(Arc::new(customers_table)); schema.add_btree_table(Arc::new(orders_table)); schema.add_btree_table(Arc::new(products_table)); + schema.add_btree_table(Arc::new(logs_table)); schema } @@ -985,7 +1356,7 @@ mod tests { let schema = create_test_schema(); let select = parse_select("SELECT * FROM customers"); - let tables = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, _, _) = IncrementalView::extract_all_tables(&select, &schema).unwrap(); assert_eq!(tables.len(), 1); assert_eq!(tables[0].name, "customers"); @@ -998,7 +1369,7 @@ mod tests { "SELECT * FROM customers INNER JOIN orders ON customers.id = orders.customer_id", ); - let tables = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, _, _) = IncrementalView::extract_all_tables(&select, &schema).unwrap(); assert_eq!(tables.len(), 2); assert_eq!(tables[0].name, "customers"); @@ -1014,7 +1385,7 @@ mod tests { INNER JOIN products ON orders.id = products.id", ); - let tables = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, _, _) = IncrementalView::extract_all_tables(&select, &schema).unwrap(); assert_eq!(tables.len(), 3); assert_eq!(tables[0].name, "customers"); @@ -1029,7 +1400,7 @@ mod tests { "SELECT * FROM customers LEFT JOIN orders ON customers.id = orders.customer_id", ); - let tables = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, _, _) = IncrementalView::extract_all_tables(&select, &schema).unwrap(); assert_eq!(tables.len(), 2); assert_eq!(tables[0].name, "customers"); @@ -1041,7 +1412,7 @@ mod tests { let schema = create_test_schema(); let select = parse_select("SELECT * FROM customers CROSS JOIN orders"); - let tables = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, _, _) = IncrementalView::extract_all_tables(&select, &schema).unwrap(); assert_eq!(tables.len(), 2); assert_eq!(tables[0].name, "customers"); @@ -1054,7 +1425,7 @@ mod tests { let select = parse_select("SELECT * FROM customers c INNER JOIN orders o ON c.id = o.customer_id"); - let tables = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, _, _) = IncrementalView::extract_all_tables(&select, &schema).unwrap(); // Should still extract the actual table names, not aliases assert_eq!(tables.len(), 2); @@ -1067,7 +1438,8 @@ mod tests { let schema = create_test_schema(); let select = parse_select("SELECT * FROM nonexistent"); - let result = IncrementalView::extract_all_tables(&select, &schema); + let result = + IncrementalView::extract_all_tables(&select, &schema).map(|(tables, _, _)| tables); assert!(result.is_err()); assert!(result @@ -1083,7 +1455,8 @@ mod tests { "SELECT * FROM customers INNER JOIN nonexistent ON customers.id = nonexistent.id", ); - let result = IncrementalView::extract_all_tables(&select, &schema); + let result = + IncrementalView::extract_all_tables(&select, &schema).map(|(tables, _, _)| tables); assert!(result.is_err()); assert!(result @@ -1091,4 +1464,526 @@ mod tests { .to_string() .contains("Table 'nonexistent' not found")); } + + #[test] + fn test_sql_for_populate_simple_query_no_where() { + // Test simple query with no WHERE clause + let schema = create_test_schema(); + let select = parse_select("SELECT * FROM customers"); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names, + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 1); + // customers has id as rowid alias, so no need for explicit rowid + assert_eq!(queries[0], "SELECT * FROM customers"); + } + + #[test] + fn test_sql_for_populate_simple_query_with_where() { + // Test simple query with WHERE clause + let schema = create_test_schema(); + let select = parse_select("SELECT * FROM customers WHERE id > 10"); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names, + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 1); + // For single-table queries, we should get the full WHERE clause + assert_eq!(queries[0], "SELECT * FROM customers WHERE id > 10"); + } + + #[test] + fn test_sql_for_populate_join_with_where_on_both_tables() { + // Test JOIN query with WHERE conditions on both tables + let schema = create_test_schema(); + let select = parse_select( + "SELECT * FROM customers c \ + JOIN orders o ON c.id = o.customer_id \ + WHERE c.id > 10 AND o.total > 100", + ); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names, + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 2); + + // With per-table WHERE extraction: + // - customers table gets: c.id > 10 + // - orders table gets: o.total > 100 + assert_eq!(queries[0], "SELECT * FROM customers WHERE id > 10"); + assert_eq!(queries[1], "SELECT * FROM orders WHERE total > 100"); + } + + #[test] + fn test_sql_for_populate_complex_join_with_mixed_conditions() { + // Test complex JOIN with WHERE conditions mixing both tables + let schema = create_test_schema(); + let select = parse_select( + "SELECT * FROM customers c \ + JOIN orders o ON c.id = o.customer_id \ + WHERE c.id > 10 AND o.total > 100 AND c.name = 'John' \ + AND o.customer_id = 5 AND (c.id = 15 OR o.total = 200)", + ); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names, + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 2); + + // With per-table WHERE extraction: + // - customers gets: c.id > 10 AND c.name = 'John' + // - orders gets: o.total > 100 AND o.customer_id = 5 + // Note: The OR condition (c.id = 15 OR o.total = 200) involves both tables, + // so it cannot be extracted to either table individually + assert_eq!( + queries[0], + "SELECT * FROM customers WHERE id > 10 AND name = 'John'" + ); + assert_eq!( + queries[1], + "SELECT * FROM orders WHERE total > 100 AND customer_id = 5" + ); + } + + #[test] + fn test_where_extraction_for_three_tables() { + // Test that WHERE clause extraction correctly separates conditions for 3+ tables + // This addresses the concern about conditions "piling up" as joins increase + + // Simulate a three-table scenario + let schema = create_test_schema(); + + // Parse a WHERE clause with conditions for three different tables + let select = parse_select( + "SELECT * FROM customers WHERE c.id > 10 AND o.total > 100 AND p.price > 50", + ); + + // Get the WHERE expression + if let ast::OneSelect::Select { + where_clause: Some(ref where_expr), + .. + } = select.body.select + { + // Create a view with three tables to test extraction + let tables = vec![ + schema.get_btree_table("customers").unwrap(), + schema.get_btree_table("orders").unwrap(), + schema.get_btree_table("products").unwrap(), + ]; + + let mut aliases = HashMap::new(); + aliases.insert("c".to_string(), "customers".to_string()); + aliases.insert("o".to_string(), "orders".to_string()); + aliases.insert("p".to_string(), "products".to_string()); + + // Create a minimal view just to test extraction logic + let view = IncrementalView { + name: "test".to_string(), + select_stmt: select.clone(), + circuit: DbspCircuit::new(1, 2, 3), + referenced_tables: tables, + table_aliases: aliases, + qualified_table_names: HashMap::new(), + column_schema: ViewColumnSchema { + columns: vec![], + tables: vec![], + }, + populate_state: PopulateState::Start, + tracker: Arc::new(Mutex::new(ComputationTracker::new())), + root_page: 0, + }; + + // Test extraction for each table + let customers_conds = view + .extract_table_conditions(where_expr, "customers") + .unwrap(); + let orders_conds = view.extract_table_conditions(where_expr, "orders").unwrap(); + let products_conds = view + .extract_table_conditions(where_expr, "products") + .unwrap(); + + // Verify each table only gets its conditions + if let Some(cond) = customers_conds { + let sql = cond.to_string(); + assert!(sql.contains("id > 10")); + assert!(!sql.contains("total")); + assert!(!sql.contains("price")); + } + + if let Some(cond) = orders_conds { + let sql = cond.to_string(); + assert!(sql.contains("total > 100")); + assert!(!sql.contains("id > 10")); // From customers + assert!(!sql.contains("price")); + } + + if let Some(cond) = products_conds { + let sql = cond.to_string(); + assert!(sql.contains("price > 50")); + assert!(!sql.contains("id > 10")); // From customers + assert!(!sql.contains("total")); + } + } else { + panic!("Failed to parse WHERE clause"); + } + } + + #[test] + fn test_alias_resolution_works_correctly() { + // Test that alias resolution properly maps aliases to table names + let schema = create_test_schema(); + let select = parse_select( + "SELECT * FROM customers c \ + JOIN orders o ON c.id = o.customer_id \ + WHERE c.id > 10 AND o.total > 100", + ); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names, + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + // Verify that alias mappings were extracted correctly + assert_eq!(view.table_aliases.get("c"), Some(&"customers".to_string())); + assert_eq!(view.table_aliases.get("o"), Some(&"orders".to_string())); + + // Verify that SQL generation uses the aliases correctly + let queries = view.sql_for_populate().unwrap(); + assert_eq!(queries.len(), 2); + + // Each query should use the actual table name, not the alias + assert!(queries[0].contains("FROM customers") || queries[1].contains("FROM customers")); + assert!(queries[0].contains("FROM orders") || queries[1].contains("FROM orders")); + } + + #[test] + fn test_sql_for_populate_table_without_rowid_alias() { + // Test that tables without a rowid alias properly include rowid in SELECT + let schema = create_test_schema(); + let select = parse_select("SELECT * FROM logs WHERE level > 2"); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names, + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 1); + // logs table has no rowid alias, so we need to explicitly select rowid + assert_eq!(queries[0], "SELECT *, rowid FROM logs WHERE level > 2"); + } + + #[test] + fn test_sql_for_populate_join_with_and_without_rowid_alias() { + // Test JOIN between a table with rowid alias and one without + let schema = create_test_schema(); + let select = parse_select( + "SELECT * FROM customers c \ + JOIN logs l ON c.id = l.level \ + WHERE c.id > 10 AND l.level > 2", + ); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names, + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 2); + // customers has rowid alias (id), logs doesn't + assert_eq!(queries[0], "SELECT * FROM customers WHERE id > 10"); + assert_eq!(queries[1], "SELECT *, rowid FROM logs WHERE level > 2"); + } + + #[test] + fn test_sql_for_populate_with_database_qualified_names() { + // Test that database.table.column references are handled correctly + // The table name in FROM should keep the database prefix, + // but column names in WHERE should be unqualified + let schema = create_test_schema(); + + // Test with single table using database qualification + let select = parse_select("SELECT * FROM main.customers WHERE main.customers.id > 10"); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names, + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 1); + // The FROM clause should preserve the database qualification, + // but the WHERE clause should have unqualified column names + assert_eq!(queries[0], "SELECT * FROM main.customers WHERE id > 10"); + } + + #[test] + fn test_sql_for_populate_join_with_database_qualified_names() { + // Test JOIN with database-qualified table and column references + let schema = create_test_schema(); + + let select = parse_select( + "SELECT * FROM main.customers c \ + JOIN main.orders o ON c.id = o.customer_id \ + WHERE main.customers.id > 10 AND main.orders.total > 100", + ); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names, + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 2); + // The FROM clauses should preserve database qualification, + // but WHERE clauses should have unqualified column names + assert_eq!(queries[0], "SELECT * FROM main.customers WHERE id > 10"); + assert_eq!(queries[1], "SELECT * FROM main.orders WHERE total > 100"); + } + + #[test] + fn test_sql_for_populate_unambiguous_unqualified_column() { + // Test that unambiguous unqualified columns ARE extracted + let schema = create_test_schema(); + let select = parse_select( + "SELECT * FROM customers c \ + JOIN orders o ON c.id = o.customer_id \ + WHERE total > 100", // 'total' only exists in orders table + ); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names, + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 2); + + // 'total' is unambiguous (only in orders), so it should be extracted + assert_eq!(queries[0], "SELECT * FROM customers"); + assert_eq!(queries[1], "SELECT * FROM orders WHERE total > 100"); + } + + #[test] + fn test_database_qualified_table_names() { + let schema = create_test_schema(); + + // Test with database-qualified table names + let select = parse_select( + "SELECT c.id, c.name, o.id, o.total + FROM main.customers c + JOIN main.orders o ON c.id = o.customer_id + WHERE c.id > 10", + ); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + + // Check that qualified names are preserved + assert!(qualified_names.contains_key("customers")); + assert_eq!(qualified_names.get("customers").unwrap(), "main.customers"); + assert!(qualified_names.contains_key("orders")); + assert_eq!(qualified_names.get("orders").unwrap(), "main.orders"); + + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names.clone(), + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 2); + + // The FROM clause should contain the database-qualified name + // But the WHERE clause should use unqualified column names + assert_eq!(queries[0], "SELECT * FROM main.customers WHERE id > 10"); + assert_eq!(queries[1], "SELECT * FROM main.orders"); + } + + #[test] + fn test_mixed_qualified_unqualified_tables() { + let schema = create_test_schema(); + + // Test with a mix of qualified and unqualified table names + let select = parse_select( + "SELECT c.id, c.name, o.id, o.total + FROM main.customers c + JOIN orders o ON c.id = o.customer_id + WHERE c.id > 10 AND o.total < 1000", + ); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + + // Check that qualified names are preserved where specified + assert_eq!(qualified_names.get("customers").unwrap(), "main.customers"); + // Unqualified tables should not have an entry (or have the bare name) + assert!( + !qualified_names.contains_key("orders") + || qualified_names.get("orders").unwrap() == "orders" + ); + + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names.clone(), + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 2); + + // The FROM clause should preserve qualification where specified + assert_eq!(queries[0], "SELECT * FROM main.customers WHERE id > 10"); + assert_eq!(queries[1], "SELECT * FROM orders WHERE total < 1000"); + } }