From 9f54f60d458be90678756d946872494cb95f0bc7 Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Sat, 20 Sep 2025 20:06:22 -0300 Subject: [PATCH 1/3] make sure that complex select statements are captured by MV populate The population code extracts table information from the select statement so it can populate the materialized view. But the code, as written today, is naive. It doesn't capture table information correctly if there is more than one select statement (such in the case of a union query). --- core/incremental/view.rs | 1761 ++++++++++++++++++++++++++------------ 1 file changed, 1237 insertions(+), 524 deletions(-) diff --git a/core/incremental/view.rs b/core/incremental/view.rs index fd7b3988a..65fa5e2bb 100644 --- a/core/incremental/view.rs +++ b/core/incremental/view.rs @@ -8,7 +8,7 @@ use crate::types::{IOResult, Value}; use crate::util::{extract_view_columns, ViewColumnSchema}; use crate::{return_if_io, LimboError, Pager, Result, Statement}; use std::cell::RefCell; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::fmt; use std::rc::Rc; use std::sync::{Arc, Mutex}; @@ -195,6 +195,9 @@ pub struct IncrementalView { // Mapping from table name to fully qualified name (e.g., "customers" -> "main.customers") // This preserves database qualification from the original query qualified_table_names: HashMap, + // WHERE conditions for each table (accumulated from all occurrences) + // Multiple conditions from UNION branches or duplicate references are stored as a vector + table_conditions: HashMap>>, // The view's column schema with table relationships pub column_schema: ViewColumnSchema, // State machine for population @@ -312,9 +315,18 @@ impl IncrementalView { // Extract output columns using the shared function let column_schema = extract_view_columns(&select, schema)?; - // Get all tables from FROM clause and JOINs, along with their aliases - let (referenced_tables, table_aliases, qualified_table_names) = - Self::extract_all_tables(&select, schema)?; + let mut referenced_tables = Vec::new(); + let mut table_aliases = HashMap::new(); + let mut qualified_table_names = HashMap::new(); + let mut table_conditions = HashMap::new(); + Self::extract_all_tables( + &select, + schema, + &mut referenced_tables, + &mut table_aliases, + &mut qualified_table_names, + &mut table_conditions, + )?; Self::new( name, @@ -322,6 +334,7 @@ impl IncrementalView { referenced_tables, table_aliases, qualified_table_names, + table_conditions, column_schema, schema, main_data_root, @@ -337,6 +350,7 @@ impl IncrementalView { referenced_tables: Vec>, table_aliases: HashMap, qualified_table_names: HashMap, + table_conditions: HashMap>>, column_schema: ViewColumnSchema, schema: &Schema, main_data_root: usize, @@ -362,6 +376,7 @@ impl IncrementalView { referenced_tables, table_aliases, qualified_table_names, + table_conditions, column_schema, populate_state: PopulateState::Start, tracker, @@ -405,97 +420,249 @@ impl IncrementalView { self.referenced_tables.clone() } - /// Extract all tables and their aliases from the SELECT statement - /// Returns a tuple of (tables, alias_map, qualified_names) - /// where alias_map is alias -> table_name - /// and qualified_names is table_name -> fully_qualified_name - #[allow(clippy::type_complexity)] - fn extract_all_tables( - select: &ast::Select, + /// Process a single table reference from a FROM or JOIN clause + fn process_table_reference( + name: &ast::QualifiedName, + alias: &Option, schema: &Schema, - ) -> Result<( - Vec>, - HashMap, - 
HashMap, - )> { - let mut tables = Vec::new(); - let mut aliases = HashMap::new(); - let mut qualified_names = HashMap::new(); + table_map: &mut HashMap>, + aliases: &mut HashMap, + qualified_names: &mut HashMap, + cte_names: &HashSet, + ) -> Result<()> { + let table_name = name.name.as_str(); + // Build the fully qualified name + let qualified_name = if let Some(ref db) = name.db_name { + format!("{db}.{table_name}") + } else { + table_name.to_string() + }; + + // Skip CTEs - they're not real tables + if !cte_names.contains(table_name) { + if let Some(table) = schema.get_btree_table(table_name) { + table_map.insert(table_name.to_string(), table.clone()); + qualified_names.insert(table_name.to_string(), qualified_name); + + // Store the alias mapping if there is an alias + if let Some(alias_enum) = alias { + let alias_name = match alias_enum { + ast::As::As(name) | ast::As::Elided(name) => match name { + ast::Name::Ident(s) | ast::Name::Quoted(s) => s, + }, + }; + aliases.insert(alias_name.to_string(), table_name.to_string()); + } + } else { + return Err(LimboError::ParseError(format!( + "Table '{table_name}' not found in schema" + ))); + } + } + Ok(()) + } + + fn extract_one_statement( + select: &ast::OneSelect, + schema: &Schema, + table_map: &mut HashMap>, + aliases: &mut HashMap, + qualified_names: &mut HashMap, + table_conditions: &mut HashMap>>, + cte_names: &HashSet, + ) -> Result<()> { if let ast::OneSelect::Select { from: Some(ref from), .. - } = select.body.select + } = select { // Get the main table from FROM clause if let ast::SelectTable::Table(name, alias, _) = from.select.as_ref() { - let table_name = name.name.as_str(); - - // Build the fully qualified name - let qualified_name = if let Some(ref db) = name.db_name { - format!("{db}.{table_name}") - } else { - table_name.to_string() - }; - - if let Some(table) = schema.get_btree_table(table_name) { - tables.push(table.clone()); - qualified_names.insert(table_name.to_string(), qualified_name); - - // Store the alias mapping if there is an alias - if let Some(alias_name) = alias { - aliases.insert(alias_name.to_string(), table_name.to_string()); - } - } else { - return Err(LimboError::ParseError(format!( - "Table '{table_name}' not found in schema" - ))); - } + Self::process_table_reference( + name, + alias, + schema, + table_map, + aliases, + qualified_names, + cte_names, + )?; } // Get all tables from JOIN clauses for join in &from.joins { if let ast::SelectTable::Table(name, alias, _) = join.table.as_ref() { - let table_name = name.name.as_str(); - - // Build the fully qualified name - let qualified_name = if let Some(ref db) = name.db_name { - format!("{db}.{table_name}") - } else { - table_name.to_string() - }; - - if let Some(table) = schema.get_btree_table(table_name) { - tables.push(table.clone()); - qualified_names.insert(table_name.to_string(), qualified_name); - - // Store the alias mapping if there is an alias - if let Some(alias_name) = alias { - aliases.insert(alias_name.to_string(), table_name.to_string()); - } - } else { - return Err(LimboError::ParseError(format!( - "Table '{table_name}' not found in schema" - ))); - } + Self::process_table_reference( + name, + alias, + schema, + table_map, + aliases, + qualified_names, + cte_names, + )?; } } } + // Extract WHERE conditions for this SELECT + let where_expr = if let ast::OneSelect::Select { + where_clause: Some(ref where_expr), + .. 
+ } = select + { + Some(where_expr.as_ref().clone()) + } else { + None + }; - if tables.is_empty() { - return Err(LimboError::ParseError( - "No tables found in SELECT statement".to_string(), - )); + // Ensure all tables have an entry in table_conditions (even if empty) + for table_name in table_map.keys() { + table_conditions.entry(table_name.clone()).or_default(); } - Ok((tables, aliases, qualified_names)) + // Extract and store table-specific conditions from the WHERE clause + if let Some(ref where_expr) = where_expr { + for table_name in table_map.keys() { + let all_tables: Vec = table_map.keys().cloned().collect(); + let table_specific_condition = Self::extract_conditions_for_table( + where_expr, + table_name, + aliases, + &all_tables, + schema, + ); + // Only add if there's actually a condition for this table + if let Some(condition) = table_specific_condition { + let conditions = table_conditions.get_mut(table_name).unwrap(); + conditions.push(Some(condition)); + } + } + } else { + // No WHERE clause - push None for all tables in this SELECT. It is a way + // of signaling that we need all rows in the table. It is important we signal this + // explicitly, because the same table may appear in many conditions - some of which + // have filters that would otherwise be applied. + for table_name in table_map.keys() { + let conditions = table_conditions.get_mut(table_name).unwrap(); + conditions.push(None); + } + } + + Ok(()) + } + + /// Extract all tables and their aliases from the SELECT statement, handling CTEs + /// Deduplicates tables and accumulates WHERE conditions + fn extract_all_tables( + select: &ast::Select, + schema: &Schema, + tables: &mut Vec>, + aliases: &mut HashMap, + qualified_names: &mut HashMap, + table_conditions: &mut HashMap>>, + ) -> Result<()> { + let mut table_map = HashMap::new(); + Self::extract_all_tables_inner( + select, + schema, + &mut table_map, + aliases, + qualified_names, + table_conditions, + &HashSet::new(), + )?; + + // Convert deduplicated table map to vector + for (_name, table) in table_map { + tables.push(table); + } + + Ok(()) + } + + fn extract_all_tables_inner( + select: &ast::Select, + schema: &Schema, + table_map: &mut HashMap>, + aliases: &mut HashMap, + qualified_names: &mut HashMap, + table_conditions: &mut HashMap>>, + parent_cte_names: &HashSet, + ) -> Result<()> { + let mut cte_names = parent_cte_names.clone(); + + // First, collect CTE names and process any CTEs (WITH clauses) + if let Some(ref with) = select.with { + // First pass: collect all CTE names (needed for recursive CTEs) + for cte in &with.ctes { + cte_names.insert(cte.tbl_name.as_str().to_string()); + } + + // Second pass: extract tables from each CTE's SELECT statement + for cte in &with.ctes { + // Recursively extract tables from each CTE's SELECT statement + Self::extract_all_tables_inner( + &cte.select, + schema, + table_map, + aliases, + qualified_names, + table_conditions, + &cte_names, + )?; + } + } + + // Then process the main SELECT body + Self::extract_one_statement( + &select.body.select, + schema, + table_map, + aliases, + qualified_names, + table_conditions, + &cte_names, + )?; + + // Process any compound selects (UNION, etc.) + for c in &select.body.compounds { + let ast::CompoundSelect { select, .. 
} = c; + Self::extract_one_statement( + select, + schema, + table_map, + aliases, + qualified_names, + table_conditions, + &cte_names, + )?; + } + + Ok(()) } /// Generate SQL queries for populating the view from each source table /// Returns a vector of SQL statements, one for each referenced table - /// Each query includes only the WHERE conditions relevant to that specific table + /// Each query includes the WHERE conditions accumulated from all occurrences fn sql_for_populate(&self) -> crate::Result> { - if self.referenced_tables.is_empty() { + Self::generate_populate_queries( + &self.select_stmt, + &self.referenced_tables, + &self.table_aliases, + &self.qualified_table_names, + &self.table_conditions, + ) + } + + pub fn generate_populate_queries( + select_stmt: &ast::Select, + referenced_tables: &[Arc], + table_aliases: &HashMap, + qualified_table_names: &HashMap, + table_conditions: &HashMap>>, + ) -> crate::Result> { + if referenced_tables.is_empty() { return Err(LimboError::ParseError( "No tables to populate from".to_string(), )); @@ -503,12 +670,11 @@ impl IncrementalView { let mut queries = Vec::new(); - for table in &self.referenced_tables { + for table in referenced_tables { // Check if the table has a rowid alias (INTEGER PRIMARY KEY column) let has_rowid_alias = table.columns.iter().any(|col| col.is_rowid_alias); - // For now, select all columns since we don't have the static operators - // The circuit will handle filtering and projection + // Select all columns. The circuit will handle filtering and projection // If there's a rowid alias, we don't need to select rowid separately let select_clause = if has_rowid_alias { "*".to_string() @@ -516,12 +682,22 @@ impl IncrementalView { "*, rowid".to_string() }; - // Extract WHERE conditions for this specific table - let where_clause = self.extract_where_clause_for_table(&table.name)?; + // Get accumulated WHERE conditions for this table + let where_clause = if let Some(conditions) = table_conditions.get(&table.name) { + // Combine multiple conditions with OR if there are multiple occurrences + Self::combine_conditions( + select_stmt, + conditions, + &table.name, + referenced_tables, + table_aliases, + )? + } else { + String::new() + }; // Use the qualified table name if available, otherwise just the table name - let table_name = self - .qualified_table_names + let table_name = qualified_table_names .get(&table.name) .cloned() .unwrap_or_else(|| table.name.clone()); @@ -532,347 +708,405 @@ impl IncrementalView { } else { format!("SELECT {select_clause} FROM {table_name} WHERE {where_clause}") }; + tracing::debug!("populating materialized view with `{query}`"); queries.push(query); } Ok(queries) } - /// Extract WHERE conditions that apply to a specific table - /// This analyzes the WHERE clause in the SELECT statement and returns - /// only the conditions that reference the given table - fn extract_where_clause_for_table(&self, table_name: &str) -> crate::Result { - // For single table queries, return the entire WHERE clause (already unqualified) - if self.referenced_tables.len() == 1 { - if let ast::OneSelect::Select { - where_clause: Some(ref where_expr), - .. 
- } = self.select_stmt.body.select - { - // For single table, the expression should already be unqualified or qualified with the single table - // We need to unqualify it for the single-table query - let unqualified = self.unqualify_expression(where_expr, table_name); - return Ok(unqualified.to_string()); - } + fn combine_conditions( + _select_stmt: &ast::Select, + conditions: &[Option], + table_name: &str, + _referenced_tables: &[Arc], + table_aliases: &HashMap, + ) -> crate::Result { + // Check if any conditions are None (SELECTs without WHERE) + let has_none = conditions.iter().any(|c| c.is_none()); + let non_empty: Vec<_> = conditions.iter().filter_map(|c| c.as_ref()).collect(); + + // If we have both Some and None conditions, that means in some of the expressions where + // this table appear we want all rows. So we need to fetch all rows. + if has_none && !non_empty.is_empty() { return Ok(String::new()); } - // For multi-table queries (JOINs), extract conditions for the specific table - if let ast::OneSelect::Select { - where_clause: Some(ref where_expr), - .. - } = self.select_stmt.body.select - { - // Extract conditions that reference only the specified table - let table_conditions = self.extract_table_conditions(where_expr, table_name)?; - if let Some(conditions) = table_conditions { - // Unqualify the expression for single-table query - let unqualified = self.unqualify_expression(&conditions, table_name); - return Ok(unqualified.to_string()); - } + if non_empty.is_empty() { + return Ok(String::new()); } - Ok(String::new()) + if non_empty.len() == 1 { + // Unqualify the expression before converting to string + let unqualified = Self::unqualify_expression(non_empty[0], table_name, table_aliases); + return Ok(unqualified.to_string()); + } + + // Multiple conditions - combine with OR + // This happens in UNION ALL when the same table appears multiple times + let mut combined_parts = Vec::new(); + for condition in non_empty { + let unqualified = Self::unqualify_expression(condition, table_name, table_aliases); + // Wrap each condition in parentheses to preserve precedence + combined_parts.push(format!("({unqualified})")); + } + + // Join all conditions with OR + Ok(combined_parts.join(" OR ")) + } + /// Resolve a table alias to the actual table name + /// Check if an expression is a simple comparison that can be safely extracted + /// This excludes subqueries, CASE expressions, function calls, etc. + fn is_simple_comparison(expr: &ast::Expr) -> bool { + match expr { + // Simple column references and literals are OK + ast::Expr::Column { .. 
} | ast::Expr::Literal(_) => true, + + // Simple binary operations between simple expressions are OK + ast::Expr::Binary(left, op, right) => { + match op { + // Logical operators + ast::Operator::And | ast::Operator::Or => { + Self::is_simple_comparison(left) && Self::is_simple_comparison(right) + } + // Comparison operators + ast::Operator::Equals + | ast::Operator::NotEquals + | ast::Operator::Less + | ast::Operator::LessEquals + | ast::Operator::Greater + | ast::Operator::GreaterEquals + | ast::Operator::Is + | ast::Operator::IsNot => { + Self::is_simple_comparison(left) && Self::is_simple_comparison(right) + } + // String concatenation and other operations are NOT simple + ast::Operator::Concat => false, + // Arithmetic might be OK if operands are simple + ast::Operator::Add + | ast::Operator::Subtract + | ast::Operator::Multiply + | ast::Operator::Divide + | ast::Operator::Modulus => { + Self::is_simple_comparison(left) && Self::is_simple_comparison(right) + } + _ => false, + } + } + + // Unary operations might be OK + ast::Expr::Unary( + ast::UnaryOperator::Not + | ast::UnaryOperator::Negative + | ast::UnaryOperator::Positive, + inner, + ) => Self::is_simple_comparison(inner), + ast::Expr::Unary(_, _) => false, + + // Complex expressions are NOT simple + ast::Expr::Case { .. } => false, + ast::Expr::Cast { .. } => false, + ast::Expr::Collate { .. } => false, + ast::Expr::Exists(_) => false, + ast::Expr::FunctionCall { .. } => false, + ast::Expr::InList { .. } => false, + ast::Expr::InSelect { .. } => false, + ast::Expr::Like { .. } => false, + ast::Expr::NotNull(_) => true, // IS NOT NULL is simple enough + ast::Expr::Parenthesized(exprs) => { + // Parenthesized expression can contain multiple expressions + // Only consider it simple if it has exactly one simple expression + exprs.len() == 1 && Self::is_simple_comparison(&exprs[0]) + } + ast::Expr::Subquery(_) => false, + + // BETWEEN might be OK if all operands are simple + ast::Expr::Between { .. } => { + // BETWEEN has a different structure, for safety just exclude it + false + } + + // Qualified references are simple + ast::Expr::DoublyQualified(..) 
=> true, + ast::Expr::Qualified(_, _) => true, + + // These are simple + ast::Expr::Id(_) => true, + ast::Expr::Name(_) => true, + + // Anything else is not simple + _ => false, + } } - /// Extract conditions from an expression that reference only the specified table - fn extract_table_conditions( - &self, + /// Extract conditions from a WHERE clause that apply to a specific table + fn extract_conditions_for_table( expr: &ast::Expr, table_name: &str, - ) -> crate::Result> { + aliases: &HashMap, + all_tables: &[String], + schema: &Schema, + ) -> Option { match expr { ast::Expr::Binary(left, op, right) => { match op { ast::Operator::And => { // For AND, we can extract conditions independently - let left_cond = self.extract_table_conditions(left, table_name)?; - let right_cond = self.extract_table_conditions(right, table_name)?; + let left_cond = Self::extract_conditions_for_table( + left, table_name, aliases, all_tables, schema, + ); + let right_cond = Self::extract_conditions_for_table( + right, table_name, aliases, all_tables, schema, + ); match (left_cond, right_cond) { - (Some(l), Some(r)) => { - // Both conditions apply to this table - Ok(Some(ast::Expr::Binary( - Box::new(l), - ast::Operator::And, - Box::new(r), - ))) - } - (Some(l), None) => Ok(Some(l)), - (None, Some(r)) => Ok(Some(r)), - (None, None) => Ok(None), + (Some(l), Some(r)) => Some(ast::Expr::Binary( + Box::new(l), + ast::Operator::And, + Box::new(r), + )), + (Some(l), None) => Some(l), + (None, Some(r)) => Some(r), + (None, None) => None, } } ast::Operator::Or => { - // For OR, both sides must reference the same table(s) - // If either side references multiple tables, we can't extract it - let left_tables = self.get_referenced_tables_in_expr(left)?; - let right_tables = self.get_referenced_tables_in_expr(right)?; + // For OR, both sides must reference only our table + let left_tables = + Self::get_tables_in_expr(left, aliases, all_tables, schema); + let right_tables = + Self::get_tables_in_expr(right, aliases, all_tables, schema); - // If both sides only reference our table, include the whole OR if left_tables.len() == 1 && left_tables.contains(&table_name.to_string()) && right_tables.len() == 1 && right_tables.contains(&table_name.to_string()) + && Self::is_simple_comparison(expr) { - Ok(Some(expr.clone())) + Some(expr.clone()) } else { - // OR condition involves multiple tables, can't extract - Ok(None) + None } } _ => { - // For comparison operators, check if this condition references only our table - // AND is simple enough to be pushed down (no complex expressions) - let referenced_tables = self.get_referenced_tables_in_expr(expr)?; + // For comparison operators, check if this condition only references our table + let referenced_tables = + Self::get_tables_in_expr(expr, aliases, all_tables, schema); if referenced_tables.len() == 1 && referenced_tables.contains(&table_name.to_string()) + && Self::is_simple_comparison(expr) { - // Check if this is a simple comparison that can be pushed down - // Complex expressions like (a * b) >= c should be handled by the circuit - if self.is_simple_comparison(expr) { - Ok(Some(expr.clone())) - } else { - // Complex expression - let the circuit handle it - Ok(None) - } + Some(expr.clone()) } else { - Ok(None) + None } } } } - ast::Expr::Parenthesized(exprs) => { - if exprs.len() == 1 { - self.extract_table_conditions(&exprs[0], table_name) - } else { - Ok(None) - } - } _ => { - // For other expressions, check if they reference only our table - // AND are simple enough to be pushed 
down - let referenced_tables = self.get_referenced_tables_in_expr(expr)?; + // For other expressions, check if they only reference our table + let referenced_tables = Self::get_tables_in_expr(expr, aliases, all_tables, schema); if referenced_tables.len() == 1 && referenced_tables.contains(&table_name.to_string()) - && self.is_simple_comparison(expr) + && Self::is_simple_comparison(expr) { - Ok(Some(expr.clone())) + Some(expr.clone()) } else { - Ok(None) + None } } } } - /// Check if an expression is a simple comparison that can be pushed down to table scan - /// Returns true for simple comparisons like "column = value" or "column > value" - /// Returns false for complex expressions like "(a * b) > value" - fn is_simple_comparison(&self, expr: &ast::Expr) -> bool { - match expr { - ast::Expr::Binary(left, op, right) => { - // Check if it's a comparison operator - matches!( - op, - ast::Operator::Equals - | ast::Operator::NotEquals - | ast::Operator::Greater - | ast::Operator::GreaterEquals - | ast::Operator::Less - | ast::Operator::LessEquals - ) && self.is_simple_operand(left) - && self.is_simple_operand(right) - } - _ => false, - } - } - - /// Check if an operand is simple (column reference or literal) - fn is_simple_operand(&self, expr: &ast::Expr) -> bool { - matches!( - expr, - ast::Expr::Id(_) - | ast::Expr::Qualified(_, _) - | ast::Expr::DoublyQualified(_, _, _) - | ast::Expr::Literal(_) - ) - } - - /// Get the set of table names referenced in an expression - fn get_referenced_tables_in_expr(&self, expr: &ast::Expr) -> crate::Result> { - let mut tables = Vec::new(); - self.collect_referenced_tables(expr, &mut tables)?; - // Deduplicate - tables.sort(); - tables.dedup(); - Ok(tables) - } - - /// Recursively collect table references from an expression - fn collect_referenced_tables( - &self, + /// Unqualify column references in an expression + /// Removes table/alias prefixes from qualified column names + fn unqualify_expression( expr: &ast::Expr, - tables: &mut Vec, - ) -> crate::Result<()> { + table_name: &str, + aliases: &HashMap, + ) -> ast::Expr { match expr { - ast::Expr::Binary(left, _, right) => { - self.collect_referenced_tables(left, tables)?; - self.collect_referenced_tables(right, tables)?; - } - ast::Expr::Qualified(table, _) => { - // This is a qualified column reference (table.column or alias.column) - // We need to resolve aliases to actual table names - let actual_table = self.resolve_table_alias(table.as_str()); - tables.push(actual_table); - } - ast::Expr::Id(column) => { - // Unqualified column reference - if self.referenced_tables.len() > 1 { - // In a JOIN context, check which tables have this column - let mut tables_with_column = Vec::new(); - for table in &self.referenced_tables { - if table - .columns - .iter() - .any(|c| c.name.as_ref() == Some(&column.to_string())) - { - tables_with_column.push(table.name.clone()); - } - } - - if tables_with_column.len() > 1 { - // Ambiguous column - this should have been caught earlier - // Return error to be safe - return Err(crate::LimboError::ParseError(format!( - "Ambiguous column name '{}' in WHERE clause - exists in tables: {}", - column, - tables_with_column.join(", ") - ))); - } else if tables_with_column.len() == 1 { - // Unambiguous - only one table has this column - // This is allowed by SQLite - tables.push(tables_with_column[0].clone()); - } else { - // Column doesn't exist in any table - this is an error - // but should be caught during compilation - return Err(crate::LimboError::ParseError(format!( - "Column 
'{column}' not found in any table" - ))); - } - } else { - // Single table context - unqualified columns belong to that table - if let Some(table) = self.referenced_tables.first() { - tables.push(table.name.clone()); - } - } - } - ast::Expr::DoublyQualified(_database, table, _column) => { - // For database.table.column, resolve the table name - let table_str = table.as_str(); - let actual_table = self.resolve_table_alias(table_str); - tables.push(actual_table); - } - ast::Expr::Parenthesized(exprs) => { - for e in exprs { - self.collect_referenced_tables(e, tables)?; - } - } - _ => { - // Literals and other expressions don't reference tables - } - } - Ok(()) - } - - /// Convert a qualified expression to unqualified for single-table queries - /// This removes table prefixes from column references since they're not needed - /// when querying a single table - fn unqualify_expression(&self, expr: &ast::Expr, table_name: &str) -> ast::Expr { - match expr { - ast::Expr::Binary(left, op, right) => { - // Recursively unqualify both sides - ast::Expr::Binary( - Box::new(self.unqualify_expression(left, table_name)), - *op, - Box::new(self.unqualify_expression(right, table_name)), - ) - } - ast::Expr::Qualified(table, column) => { - // Convert qualified column to unqualified if it's for our table - // Handle both "table.column" and "database.table.column" cases - let table_str = table.as_str(); - - // Check if this is a database.table reference - let actual_table = if table_str.contains('.') { - // Split on '.' and take the last part as the table name + ast::Expr::Binary(left, op, right) => ast::Expr::Binary( + Box::new(Self::unqualify_expression(left, table_name, aliases)), + *op, + Box::new(Self::unqualify_expression(right, table_name, aliases)), + ), + ast::Expr::Qualified(table_or_alias, column) => { + // Check if this qualification refers to our table + let table_str = table_or_alias.as_str(); + let actual_table = if let Some(actual) = aliases.get(table_str) { + actual.clone() + } else if table_str.contains('.') { + // Handle database.table format table_str .split('.') .next_back() .unwrap_or(table_str) .to_string() } else { - // Could be an alias or direct table name - self.resolve_table_alias(table_str) + table_str.to_string() }; if actual_table == table_name { - // Just return the column name without qualification + // Remove the qualification ast::Expr::Id(column.clone()) } else { - // This shouldn't happen if extract_table_conditions worked correctly - // but keep it qualified just in case + // Keep the qualification (shouldn't happen if extraction worked correctly) expr.clone() } } ast::Expr::DoublyQualified(_database, table, column) => { - // This is database.table.column format - // Check if the table matches our target table - let table_str = table.as_str(); - let actual_table = self.resolve_table_alias(table_str); - - if actual_table == table_name { - // Just return the column name without qualification + // Check if this refers to our table + if table.as_str() == table_name { + // Remove the qualification, keep just the column ast::Expr::Id(column.clone()) } else { - // Keep it qualified if it's for a different table + // Keep the qualification (shouldn't happen if extraction worked correctly) expr.clone() } } - ast::Expr::Parenthesized(exprs) => { - // Recursively unqualify expressions in parentheses - let unqualified_exprs: Vec> = exprs + ast::Expr::Unary(op, inner) => ast::Expr::Unary( + *op, + Box::new(Self::unqualify_expression(inner, table_name, aliases)), + ), + 
ast::Expr::FunctionCall { + name, + args, + distinctness, + filter_over, + order_by, + } => ast::Expr::FunctionCall { + name: name.clone(), + args: args .iter() - .map(|e| Box::new(self.unqualify_expression(e, table_name))) - .collect(); - ast::Expr::Parenthesized(unqualified_exprs) + .map(|arg| Box::new(Self::unqualify_expression(arg, table_name, aliases))) + .collect(), + distinctness: *distinctness, + filter_over: filter_over.clone(), + order_by: order_by.clone(), + }, + ast::Expr::InList { lhs, not, rhs } => ast::Expr::InList { + lhs: Box::new(Self::unqualify_expression(lhs, table_name, aliases)), + not: *not, + rhs: rhs + .iter() + .map(|item| Box::new(Self::unqualify_expression(item, table_name, aliases))) + .collect(), + }, + ast::Expr::Between { + lhs, + not, + start, + end, + } => ast::Expr::Between { + lhs: Box::new(Self::unqualify_expression(lhs, table_name, aliases)), + not: *not, + start: Box::new(Self::unqualify_expression(start, table_name, aliases)), + end: Box::new(Self::unqualify_expression(end, table_name, aliases)), + }, + _ => expr.clone(), + } + } + + /// Get all tables referenced in an expression + fn get_tables_in_expr( + expr: &ast::Expr, + aliases: &HashMap, + all_tables: &[String], + schema: &Schema, + ) -> Vec { + let mut tables = Vec::new(); + Self::collect_tables_in_expr(expr, aliases, all_tables, schema, &mut tables); + tables.sort(); + tables.dedup(); + tables + } + + /// Recursively collect table references from an expression + fn collect_tables_in_expr( + expr: &ast::Expr, + aliases: &HashMap, + all_tables: &[String], + schema: &Schema, + tables: &mut Vec, + ) { + match expr { + ast::Expr::Binary(left, _, right) => { + Self::collect_tables_in_expr(left, aliases, all_tables, schema, tables); + Self::collect_tables_in_expr(right, aliases, all_tables, schema, tables); + } + ast::Expr::Qualified(table_or_alias, _) => { + // Handle database.table or just table/alias + let table_str = table_or_alias.as_str(); + let table_name = if let Some(actual_table) = aliases.get(table_str) { + // It's an alias + actual_table.clone() + } else if table_str.contains('.') { + // It might be database.table format, extract just the table name + table_str + .split('.') + .next_back() + .unwrap_or(table_str) + .to_string() + } else { + // It's a direct table name + table_str.to_string() + }; + tables.push(table_name); + } + ast::Expr::DoublyQualified(_database, table, _column) => { + // For database.table.column, extract the table name + tables.push(table.to_string()); + } + ast::Expr::Id(column) => { + // Unqualified column - try to find which table has this column + if all_tables.len() == 1 { + tables.push(all_tables[0].clone()); + } else { + // Check which table has this column + for table_name in all_tables { + if let Some(table) = schema.get_btree_table(table_name) { + if table + .columns + .iter() + .any(|col| col.name.as_deref() == Some(column.as_str())) + { + tables.push(table_name.clone()); + break; // Found the table, stop looking + } + } + } + } + } + ast::Expr::FunctionCall { args, .. } => { + for arg in args { + Self::collect_tables_in_expr(arg, aliases, all_tables, schema, tables); + } + } + ast::Expr::InList { lhs, rhs, .. } => { + Self::collect_tables_in_expr(lhs, aliases, all_tables, schema, tables); + for item in rhs { + Self::collect_tables_in_expr(item, aliases, all_tables, schema, tables); + } + } + ast::Expr::InSelect { lhs, .. } => { + Self::collect_tables_in_expr(lhs, aliases, all_tables, schema, tables); + } + ast::Expr::Between { + lhs, start, end, .. 
+ } => { + Self::collect_tables_in_expr(lhs, aliases, all_tables, schema, tables); + Self::collect_tables_in_expr(start, aliases, all_tables, schema, tables); + Self::collect_tables_in_expr(end, aliases, all_tables, schema, tables); + } + ast::Expr::Unary(_, expr) => { + Self::collect_tables_in_expr(expr, aliases, all_tables, schema, tables); } _ => { - // Other expression types (literals, unqualified columns, etc.) stay as-is - expr.clone() + // Literals, etc. don't reference tables } } } - - /// Resolve a table alias to the actual table name - fn resolve_table_alias(&self, alias: &str) -> String { - // Check if there's an alias mapping in the FROM/JOIN clauses - // For now, we'll do a simple check - if the alias matches a table name, use it - // Otherwise, try to find it in the FROM clause - - // First check if it's an actual table name - if self.referenced_tables.iter().any(|t| t.name == alias) { - return alias.to_string(); - } - - // Check if it's an alias that maps to a table - if let Some(table_name) = self.table_aliases.get(alias) { - return table_name.clone(); - } - - // If we can't resolve it, return as-is (it might be a table name we don't know about) - alias.to_string() - } - /// Populate the view by scanning the source table using a state machine /// This can be called multiple times and will resume from where it left off /// This method is only for materialized views and will persist data to the btree @@ -1342,17 +1576,58 @@ mod tests { } } + // Type alias for the complex return type of extract_all_tables + type ExtractedTableInfo = ( + Vec>, + HashMap, + HashMap, + HashMap>>, + ); + + fn extract_all_tables(select: &ast::Select, schema: &Schema) -> Result { + let mut referenced_tables = Vec::new(); + let mut table_aliases = HashMap::new(); + let mut qualified_table_names = HashMap::new(); + let mut table_conditions = HashMap::new(); + IncrementalView::extract_all_tables( + select, + schema, + &mut referenced_tables, + &mut table_aliases, + &mut qualified_table_names, + &mut table_conditions, + )?; + Ok(( + referenced_tables, + table_aliases, + qualified_table_names, + table_conditions, + )) + } + #[test] fn test_extract_single_table() { let schema = create_test_schema(); let select = parse_select("SELECT * FROM customers"); - let (tables, _, _) = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, _, _, _table_conditions) = extract_all_tables(&select, &schema).unwrap(); assert_eq!(tables.len(), 1); assert_eq!(tables[0].name, "customers"); } + #[test] + fn test_tables_from_union() { + let schema = create_test_schema(); + let select = parse_select("SELECT name FROM customers union SELECT name from products"); + + let (tables, _, _, table_conditions) = extract_all_tables(&select, &schema).unwrap(); + + assert_eq!(tables.len(), 2); + assert!(table_conditions.contains_key("customers")); + assert!(table_conditions.contains_key("products")); + } + #[test] fn test_extract_tables_from_inner_join() { let schema = create_test_schema(); @@ -1360,11 +1635,11 @@ mod tests { "SELECT * FROM customers INNER JOIN orders ON customers.id = orders.customer_id", ); - let (tables, _, _) = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, _, _, table_conditions) = extract_all_tables(&select, &schema).unwrap(); assert_eq!(tables.len(), 2); - assert_eq!(tables[0].name, "customers"); - assert_eq!(tables[1].name, "orders"); + assert!(table_conditions.contains_key("customers")); + assert!(table_conditions.contains_key("orders")); } #[test] @@ 
-1376,12 +1651,12 @@ mod tests { INNER JOIN products ON orders.id = products.id", ); - let (tables, _, _) = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, _, _, table_conditions) = extract_all_tables(&select, &schema).unwrap(); assert_eq!(tables.len(), 3); - assert_eq!(tables[0].name, "customers"); - assert_eq!(tables[1].name, "orders"); - assert_eq!(tables[2].name, "products"); + assert!(table_conditions.contains_key("customers")); + assert!(table_conditions.contains_key("orders")); + assert!(table_conditions.contains_key("products")); } #[test] @@ -1391,11 +1666,11 @@ mod tests { "SELECT * FROM customers LEFT JOIN orders ON customers.id = orders.customer_id", ); - let (tables, _, _) = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, _, _, table_conditions) = extract_all_tables(&select, &schema).unwrap(); assert_eq!(tables.len(), 2); - assert_eq!(tables[0].name, "customers"); - assert_eq!(tables[1].name, "orders"); + assert!(table_conditions.contains_key("customers")); + assert!(table_conditions.contains_key("orders")); } #[test] @@ -1403,11 +1678,11 @@ mod tests { let schema = create_test_schema(); let select = parse_select("SELECT * FROM customers CROSS JOIN orders"); - let (tables, _, _) = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, _, _, table_conditions) = extract_all_tables(&select, &schema).unwrap(); assert_eq!(tables.len(), 2); - assert_eq!(tables[0].name, "customers"); - assert_eq!(tables[1].name, "orders"); + assert!(table_conditions.contains_key("customers")); + assert!(table_conditions.contains_key("orders")); } #[test] @@ -1416,12 +1691,17 @@ mod tests { let select = parse_select("SELECT * FROM customers c INNER JOIN orders o ON c.id = o.customer_id"); - let (tables, _, _) = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, aliases, _, _table_conditions) = extract_all_tables(&select, &schema).unwrap(); // Should still extract the actual table names, not aliases assert_eq!(tables.len(), 2); - assert_eq!(tables[0].name, "customers"); - assert_eq!(tables[1].name, "orders"); + let table_names: Vec<&str> = tables.iter().map(|t| t.name.as_str()).collect(); + assert!(table_names.contains(&"customers")); + assert!(table_names.contains(&"orders")); + + // Check that aliases are correctly mapped + assert_eq!(aliases.get("c"), Some(&"customers".to_string())); + assert_eq!(aliases.get("o"), Some(&"orders".to_string())); } #[test] @@ -1429,8 +1709,7 @@ mod tests { let schema = create_test_schema(); let select = parse_select("SELECT * FROM nonexistent"); - let result = - IncrementalView::extract_all_tables(&select, &schema).map(|(tables, _, _)| tables); + let result = extract_all_tables(&select, &schema).map(|(tables, _, _, _)| tables); assert!(result.is_err()); assert!(result @@ -1446,8 +1725,7 @@ mod tests { "SELECT * FROM customers INNER JOIN nonexistent ON customers.id = nonexistent.id", ); - let result = - IncrementalView::extract_all_tables(&select, &schema).map(|(tables, _, _)| tables); + let result = extract_all_tables(&select, &schema).map(|(tables, _, _, _)| tables); assert!(result.is_err()); assert!(result @@ -1462,14 +1740,15 @@ mod tests { let schema = create_test_schema(); let select = parse_select("SELECT * FROM customers"); - let (tables, aliases, qualified_names) = - IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, aliases, qualified_names, table_conditions) = + extract_all_tables(&select, &schema).unwrap(); let view = 
IncrementalView::new( "test_view".to_string(), select.clone(), tables, aliases, qualified_names, + table_conditions, extract_view_columns(&select, &schema).unwrap(), &schema, 1, // main_data_root @@ -1491,14 +1770,15 @@ mod tests { let schema = create_test_schema(); let select = parse_select("SELECT * FROM customers WHERE id > 10"); - let (tables, aliases, qualified_names) = - IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, aliases, qualified_names, table_conditions) = + extract_all_tables(&select, &schema).unwrap(); let view = IncrementalView::new( "test_view".to_string(), select.clone(), tables, aliases, qualified_names, + table_conditions, extract_view_columns(&select, &schema).unwrap(), &schema, 1, // main_data_root @@ -1524,14 +1804,15 @@ mod tests { WHERE c.id > 10 AND o.total > 100", ); - let (tables, aliases, qualified_names) = - IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, aliases, qualified_names, table_conditions) = + extract_all_tables(&select, &schema).unwrap(); let view = IncrementalView::new( "test_view".to_string(), select.clone(), tables, aliases, qualified_names, + table_conditions, extract_view_columns(&select, &schema).unwrap(), &schema, 1, // main_data_root @@ -1547,8 +1828,12 @@ mod tests { // With per-table WHERE extraction: // - customers table gets: c.id > 10 // - orders table gets: o.total > 100 - assert_eq!(queries[0], "SELECT * FROM customers WHERE id > 10"); - assert_eq!(queries[1], "SELECT * FROM orders WHERE total > 100"); + assert!(queries + .iter() + .any(|q| q == "SELECT * FROM customers WHERE id > 10")); + assert!(queries + .iter() + .any(|q| q == "SELECT * FROM orders WHERE total > 100")); } #[test] @@ -1562,14 +1847,15 @@ mod tests { AND o.customer_id = 5 AND (c.id = 15 OR o.total = 200)", ); - let (tables, aliases, qualified_names) = - IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, aliases, qualified_names, table_conditions) = + extract_all_tables(&select, &schema).unwrap(); let view = IncrementalView::new( "test_view".to_string(), select.clone(), tables, aliases, qualified_names, + table_conditions, extract_view_columns(&select, &schema).unwrap(), &schema, 1, // main_data_root @@ -1587,152 +1873,27 @@ mod tests { // - orders gets: o.total > 100 AND o.customer_id = 5 // Note: The OR condition (c.id = 15 OR o.total = 200) involves both tables, // so it cannot be extracted to either table individually - assert_eq!( - queries[0], - "SELECT * FROM customers WHERE id > 10 AND name = 'John'" - ); - assert_eq!( - queries[1], - "SELECT * FROM orders WHERE total > 100 AND customer_id = 5" - ); - } - - #[test] - fn test_where_extraction_for_three_tables() { - // Test that WHERE clause extraction correctly separates conditions for 3+ tables - // This addresses the concern about conditions "piling up" as joins increase - - // Simulate a three-table scenario - let schema = create_test_schema(); - - // Parse a WHERE clause with conditions for three different tables - let select = parse_select( - "SELECT * FROM customers WHERE c.id > 10 AND o.total > 100 AND p.price > 50", - ); - - // Get the WHERE expression - if let ast::OneSelect::Select { - where_clause: Some(ref where_expr), - .. 
- } = select.body.select - { - // Create a view with three tables to test extraction - let tables = vec![ - schema.get_btree_table("customers").unwrap(), - schema.get_btree_table("orders").unwrap(), - schema.get_btree_table("products").unwrap(), - ]; - - let mut aliases = HashMap::new(); - aliases.insert("c".to_string(), "customers".to_string()); - aliases.insert("o".to_string(), "orders".to_string()); - aliases.insert("p".to_string(), "products".to_string()); - - // Create a minimal view just to test extraction logic - let view = IncrementalView { - name: "test".to_string(), - select_stmt: select.clone(), - circuit: DbspCircuit::new(1, 2, 3), - referenced_tables: tables, - table_aliases: aliases, - qualified_table_names: HashMap::new(), - column_schema: ViewColumnSchema { - columns: vec![], - tables: vec![], - }, - populate_state: PopulateState::Start, - tracker: Arc::new(Mutex::new(ComputationTracker::new())), - root_page: 0, - }; - - // Test extraction for each table - let customers_conds = view - .extract_table_conditions(where_expr, "customers") - .unwrap(); - let orders_conds = view.extract_table_conditions(where_expr, "orders").unwrap(); - let products_conds = view - .extract_table_conditions(where_expr, "products") - .unwrap(); - - // Verify each table only gets its conditions - if let Some(cond) = customers_conds { - let sql = cond.to_string(); - assert!(sql.contains("id > 10")); - assert!(!sql.contains("total")); - assert!(!sql.contains("price")); - } - - if let Some(cond) = orders_conds { - let sql = cond.to_string(); - assert!(sql.contains("total > 100")); - assert!(!sql.contains("id > 10")); // From customers - assert!(!sql.contains("price")); - } - - if let Some(cond) = products_conds { - let sql = cond.to_string(); - assert!(sql.contains("price > 50")); - assert!(!sql.contains("id > 10")); // From customers - assert!(!sql.contains("total")); - } - } else { - panic!("Failed to parse WHERE clause"); - } - } - - #[test] - fn test_alias_resolution_works_correctly() { - // Test that alias resolution properly maps aliases to table names - let schema = create_test_schema(); - let select = parse_select( - "SELECT * FROM customers c \ - JOIN orders o ON c.id = o.customer_id \ - WHERE c.id > 10 AND o.total > 100", - ); - - let (tables, aliases, qualified_names) = - IncrementalView::extract_all_tables(&select, &schema).unwrap(); - let view = IncrementalView::new( - "test_view".to_string(), - select.clone(), - tables, - aliases, - qualified_names, - extract_view_columns(&select, &schema).unwrap(), - &schema, - 1, // main_data_root - 2, // internal_state_root - 3, // internal_state_index_root - ) - .unwrap(); - - // Verify that alias mappings were extracted correctly - assert_eq!(view.table_aliases.get("c"), Some(&"customers".to_string())); - assert_eq!(view.table_aliases.get("o"), Some(&"orders".to_string())); - - // Verify that SQL generation uses the aliases correctly - let queries = view.sql_for_populate().unwrap(); - assert_eq!(queries.len(), 2); - - // Each query should use the actual table name, not the alias - assert!(queries[0].contains("FROM customers") || queries[1].contains("FROM customers")); - assert!(queries[0].contains("FROM orders") || queries[1].contains("FROM orders")); + // Check both queries exist (order doesn't matter) + assert!(queries + .contains(&"SELECT * FROM customers WHERE id > 10 AND name = 'John'".to_string())); + assert!(queries + .contains(&"SELECT * FROM orders WHERE total > 100 AND customer_id = 5".to_string())); } #[test] fn 
test_sql_for_populate_table_without_rowid_alias() { - // Test that tables without a rowid alias properly include rowid in SELECT let schema = create_test_schema(); let select = parse_select("SELECT * FROM logs WHERE level > 2"); - let (tables, aliases, qualified_names) = - IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, aliases, qualified_names, table_conditions) = + extract_all_tables(&select, &schema).unwrap(); let view = IncrementalView::new( "test_view".to_string(), select.clone(), tables, aliases, qualified_names, + table_conditions, extract_view_columns(&select, &schema).unwrap(), &schema, 1, // main_data_root @@ -1758,14 +1919,15 @@ mod tests { WHERE c.id > 10 AND l.level > 2", ); - let (tables, aliases, qualified_names) = - IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, aliases, qualified_names, table_conditions) = + extract_all_tables(&select, &schema).unwrap(); let view = IncrementalView::new( "test_view".to_string(), select.clone(), tables, aliases, qualified_names, + table_conditions, extract_view_columns(&select, &schema).unwrap(), &schema, 1, // main_data_root @@ -1778,8 +1940,8 @@ mod tests { assert_eq!(queries.len(), 2); // customers has rowid alias (id), logs doesn't - assert_eq!(queries[0], "SELECT * FROM customers WHERE id > 10"); - assert_eq!(queries[1], "SELECT *, rowid FROM logs WHERE level > 2"); + assert!(queries.contains(&"SELECT * FROM customers WHERE id > 10".to_string())); + assert!(queries.contains(&"SELECT *, rowid FROM logs WHERE level > 2".to_string())); } #[test] @@ -1792,14 +1954,15 @@ mod tests { // Test with single table using database qualification let select = parse_select("SELECT * FROM main.customers WHERE main.customers.id > 10"); - let (tables, aliases, qualified_names) = - IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, aliases, qualified_names, table_conditions) = + extract_all_tables(&select, &schema).unwrap(); let view = IncrementalView::new( "test_view".to_string(), select.clone(), tables, aliases, qualified_names, + table_conditions, extract_view_columns(&select, &schema).unwrap(), &schema, 1, // main_data_root @@ -1827,14 +1990,15 @@ mod tests { WHERE main.customers.id > 10 AND main.orders.total > 100", ); - let (tables, aliases, qualified_names) = - IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, aliases, qualified_names, table_conditions) = + extract_all_tables(&select, &schema).unwrap(); let view = IncrementalView::new( "test_view".to_string(), select.clone(), tables, aliases, qualified_names, + table_conditions, extract_view_columns(&select, &schema).unwrap(), &schema, 1, // main_data_root @@ -1848,8 +2012,93 @@ mod tests { assert_eq!(queries.len(), 2); // The FROM clauses should preserve database qualification, // but WHERE clauses should have unqualified column names - assert_eq!(queries[0], "SELECT * FROM main.customers WHERE id > 10"); - assert_eq!(queries[1], "SELECT * FROM main.orders WHERE total > 100"); + assert!(queries.contains(&"SELECT * FROM main.customers WHERE id > 10".to_string())); + assert!(queries.contains(&"SELECT * FROM main.orders WHERE total > 100".to_string())); + } + + #[test] + fn test_where_extraction_for_three_tables_with_aliases() { + // Test that WHERE clause extraction correctly separates conditions for 3+ tables + // This addresses the concern about conditions "piling up" as joins increase + let schema = create_test_schema(); + let select = parse_select( + "SELECT * FROM customers c + 
JOIN orders o ON c.id = o.customer_id + JOIN products p ON p.id = o.product_id + WHERE c.id > 10 AND o.total > 100 AND p.price > 50", + ); + + let (tables, aliases, qualified_names, table_conditions) = + extract_all_tables(&select, &schema).unwrap(); + + // Verify we extracted all three tables + assert_eq!(tables.len(), 3); + let table_names: Vec<&str> = tables.iter().map(|t| t.name.as_str()).collect(); + assert!(table_names.contains(&"customers")); + assert!(table_names.contains(&"orders")); + assert!(table_names.contains(&"products")); + + // Verify aliases are correctly mapped + assert_eq!(aliases.get("c"), Some(&"customers".to_string())); + assert_eq!(aliases.get("o"), Some(&"orders".to_string())); + assert_eq!(aliases.get("p"), Some(&"products".to_string())); + + // Generate populate queries to verify each table gets its own conditions + let queries = IncrementalView::generate_populate_queries( + &select, + &tables, + &aliases, + &qualified_names, + &table_conditions, + ) + .unwrap(); + + assert_eq!(queries.len(), 3); + + // Verify the exact queries generated for each table + // The order might vary, so check all possibilities + let expected_queries = vec![ + "SELECT * FROM customers WHERE id > 10", + "SELECT * FROM orders WHERE total > 100", + "SELECT * FROM products WHERE price > 50", + ]; + + for expected in &expected_queries { + assert!( + queries.contains(&expected.to_string()), + "Missing expected query: {expected}. Got: {queries:?}" + ); + } + } + + #[test] + fn test_sql_for_populate_complex_expressions_not_included() { + // Test that complex expressions (subqueries, CASE, string concat) are NOT included in populate queries + let schema = create_test_schema(); + let select = parse_select( + "SELECT * FROM customers + WHERE id > (SELECT MAX(customer_id) FROM orders) + AND name || ' Customer' = 'John Customer' + AND CASE WHEN id > 10 THEN 1 ELSE 0 END = 1 + AND EXISTS (SELECT 1 FROM orders WHERE customer_id = customers.id)", + ); + + let (tables, aliases, qualified_names, table_conditions) = + extract_all_tables(&select, &schema).unwrap(); + + let queries = IncrementalView::generate_populate_queries( + &select, + &tables, + &aliases, + &qualified_names, + &table_conditions, + ) + .unwrap(); + + assert_eq!(queries.len(), 1); + // Since customers table has an INTEGER PRIMARY KEY (id), we should get SELECT * + // without rowid and without WHERE clause (all conditions are complex) + assert_eq!(queries[0], "SELECT * FROM customers"); } #[test] @@ -1862,14 +2111,15 @@ mod tests { WHERE total > 100", // 'total' only exists in orders table ); - let (tables, aliases, qualified_names) = - IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, aliases, qualified_names, table_conditions) = + extract_all_tables(&select, &schema).unwrap(); let view = IncrementalView::new( "test_view".to_string(), select.clone(), tables, aliases, qualified_names, + table_conditions, extract_view_columns(&select, &schema).unwrap(), &schema, 1, // main_data_root @@ -1883,8 +2133,8 @@ mod tests { assert_eq!(queries.len(), 2); // 'total' is unambiguous (only in orders), so it should be extracted - assert_eq!(queries[0], "SELECT * FROM customers"); - assert_eq!(queries[1], "SELECT * FROM orders WHERE total > 100"); + assert!(queries.contains(&"SELECT * FROM customers".to_string())); + assert!(queries.contains(&"SELECT * FROM orders WHERE total > 100".to_string())); } #[test] @@ -1899,8 +2149,8 @@ mod tests { WHERE c.id > 10", ); - let (tables, aliases, qualified_names) = - 
IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, aliases, qualified_names, table_conditions) = + extract_all_tables(&select, &schema).unwrap(); // Check that qualified names are preserved assert!(qualified_names.contains_key("customers")); @@ -1914,6 +2164,7 @@ mod tests { tables, aliases, qualified_names.clone(), + table_conditions, extract_view_columns(&select, &schema).unwrap(), &schema, 1, // main_data_root @@ -1928,8 +2179,8 @@ mod tests { // The FROM clause should contain the database-qualified name // But the WHERE clause should use unqualified column names - assert_eq!(queries[0], "SELECT * FROM main.customers WHERE id > 10"); - assert_eq!(queries[1], "SELECT * FROM main.orders"); + assert!(queries.contains(&"SELECT * FROM main.customers WHERE id > 10".to_string())); + assert!(queries.contains(&"SELECT * FROM main.orders".to_string())); } #[test] @@ -1944,8 +2195,8 @@ mod tests { WHERE c.id > 10 AND o.total < 1000", ); - let (tables, aliases, qualified_names) = - IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, aliases, qualified_names, table_conditions) = + extract_all_tables(&select, &schema).unwrap(); // Check that qualified names are preserved where specified assert_eq!(qualified_names.get("customers").unwrap(), "main.customers"); @@ -1961,6 +2212,7 @@ mod tests { tables, aliases, qualified_names.clone(), + table_conditions, extract_view_columns(&select, &schema).unwrap(), &schema, 1, // main_data_root @@ -1974,7 +2226,468 @@ mod tests { assert_eq!(queries.len(), 2); // The FROM clause should preserve qualification where specified - assert_eq!(queries[0], "SELECT * FROM main.customers WHERE id > 10"); - assert_eq!(queries[1], "SELECT * FROM orders WHERE total < 1000"); + assert!(queries.contains(&"SELECT * FROM main.customers WHERE id > 10".to_string())); + assert!(queries.contains(&"SELECT * FROM orders WHERE total < 1000".to_string())); + } + + #[test] + fn test_extract_tables_with_simple_cte() { + let schema = create_test_schema(); + let select = parse_select( + "WITH customer_totals AS ( + SELECT c.id, c.name, SUM(o.total) as total_spent + FROM customers c + JOIN orders o ON c.id = o.customer_id + GROUP BY c.id, c.name + ) + SELECT * FROM customer_totals WHERE total_spent > 1000", + ); + + let (tables, aliases, _qualified_names, _table_conditions) = + extract_all_tables(&select, &schema).unwrap(); + + // Check that we found both tables from the CTE + assert_eq!(tables.len(), 2); + let table_names: Vec<&str> = tables.iter().map(|t| t.name.as_str()).collect(); + assert!(table_names.contains(&"customers")); + assert!(table_names.contains(&"orders")); + + // Check aliases from the CTE + assert_eq!(aliases.get("c"), Some(&"customers".to_string())); + assert_eq!(aliases.get("o"), Some(&"orders".to_string())); + } + + #[test] + fn test_extract_tables_with_multiple_ctes() { + let schema = create_test_schema(); + let select = parse_select( + "WITH + high_value_customers AS ( + SELECT id, name + FROM customers + WHERE id IN (SELECT customer_id FROM orders WHERE total > 500) + ), + recent_orders AS ( + SELECT id, customer_id, total + FROM orders + WHERE id > 100 + ) + SELECT hvc.name, ro.total + FROM high_value_customers hvc + JOIN recent_orders ro ON hvc.id = ro.customer_id", + ); + + let (tables, _aliases, _qualified_names, _table_conditions) = + extract_all_tables(&select, &schema).unwrap(); + + // Check that we found both tables from both CTEs + assert_eq!(tables.len(), 2); + let table_names: Vec<&str> = 
tables.iter().map(|t| t.name.as_str()).collect(); + assert!(table_names.contains(&"customers")); + assert!(table_names.contains(&"orders")); + } + + #[test] + fn test_sql_for_populate_union_mixed_conditions() { + // Test UNION where same table appears with and without WHERE clause + // This should drop ALL conditions to ensure we get all rows + let schema = create_test_schema(); + + let select = parse_select( + "SELECT * FROM customers WHERE id > 10 + UNION ALL + SELECT * FROM customers", + ); + + let (tables, aliases, qualified_names, table_conditions) = + extract_all_tables(&select, &schema).unwrap(); + + let view = IncrementalView::new( + "union_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names, + table_conditions, + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 1); + // When the same table appears with and without WHERE conditions in a UNION, + // we must fetch ALL rows (no WHERE clause) because the conditions are incompatible + assert_eq!( + queries[0], "SELECT * FROM customers", + "UNION with mixed conditions (some with WHERE, some without) should fetch ALL rows" + ); + } + + #[test] + fn test_extract_tables_with_nested_cte() { + let schema = create_test_schema(); + let select = parse_select( + "WITH RECURSIVE customer_hierarchy AS ( + SELECT id, name, 0 as level + FROM customers + WHERE id = 1 + UNION ALL + SELECT c.id, c.name, ch.level + 1 + FROM customers c + JOIN orders o ON c.id = o.customer_id + JOIN customer_hierarchy ch ON o.customer_id = ch.id + WHERE ch.level < 3 + ) + SELECT * FROM customer_hierarchy", + ); + + let (tables, _aliases, _qualified_names, _table_conditions) = + extract_all_tables(&select, &schema).unwrap(); + + // Check that we found the tables referenced in the recursive CTE + let table_names: Vec<&str> = tables.iter().map(|t| t.name.as_str()).collect(); + + // We're finding duplicates because "customers" appears twice in the recursive CTE + // Let's deduplicate + let unique_tables: std::collections::HashSet<&str> = table_names.iter().cloned().collect(); + assert_eq!(unique_tables.len(), 2); + assert!(unique_tables.contains("customers")); + assert!(unique_tables.contains("orders")); + } + + #[test] + fn test_extract_tables_with_cte_and_main_query() { + let schema = create_test_schema(); + let select = parse_select( + "WITH customer_stats AS ( + SELECT customer_id, COUNT(*) as order_count + FROM orders + GROUP BY customer_id + ) + SELECT c.name, cs.order_count, p.name as product_name + FROM customers c + JOIN customer_stats cs ON c.id = cs.customer_id + JOIN products p ON p.id = 1", + ); + + let (tables, aliases, _qualified_names, _table_conditions) = + extract_all_tables(&select, &schema).unwrap(); + + // Check that we found tables from both the CTE and the main query + assert_eq!(tables.len(), 3); + let table_names: Vec<&str> = tables.iter().map(|t| t.name.as_str()).collect(); + assert!(table_names.contains(&"customers")); + assert!(table_names.contains(&"orders")); + assert!(table_names.contains(&"products")); + + // Check aliases from main query + assert_eq!(aliases.get("c"), Some(&"customers".to_string())); + assert_eq!(aliases.get("p"), Some(&"products".to_string())); + } + + #[test] + fn test_sql_for_populate_simple_union() { + let schema = create_test_schema(); + let select = parse_select( + "SELECT * FROM orders WHERE total > 1000 
+ UNION ALL + SELECT * FROM orders WHERE total < 100", + ); + + let (tables, aliases, qualified_names, table_conditions) = + extract_all_tables(&select, &schema).unwrap(); + + // Generate populate queries + let queries = IncrementalView::generate_populate_queries( + &select, + &tables, + &aliases, + &qualified_names, + &table_conditions, + ) + .unwrap(); + + // We should have deduplicated to a single table + assert_eq!(tables.len(), 1, "Should have one unique table"); + assert_eq!(tables[0].name, "orders"); // Single table, order doesn't matter + + // Should have collected two conditions + assert_eq!(table_conditions.get("orders").unwrap().len(), 2); + + // Should combine multiple conditions with OR + assert_eq!(queries.len(), 1); + // Conditions are combined with OR + assert_eq!( + queries[0], + "SELECT * FROM orders WHERE (total > 1000) OR (total < 100)" + ); + } + + #[test] + fn test_sql_for_populate_with_union_and_filters() { + let schema = create_test_schema(); + + // Test UNION with different WHERE conditions on the same table + let select = parse_select( + "SELECT * FROM orders WHERE total > 1000 + UNION ALL + SELECT * FROM orders WHERE total < 100", + ); + + let view = IncrementalView::from_stmt( + ast::QualifiedName { + db_name: None, + name: ast::Name::Ident("test_view".to_string()), + alias: None, + }, + select, + &schema, + 1, + 2, + 3, + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + // We deduplicate tables, so we get 1 query for orders + assert_eq!(queries.len(), 1); + + // Multiple conditions on the same table are combined with OR + assert_eq!( + queries[0], + "SELECT * FROM orders WHERE (total > 1000) OR (total < 100)" + ); + } + + #[test] + fn test_sql_for_populate_with_union_mixed_tables() { + let schema = create_test_schema(); + + // Test UNION with different tables + let select = parse_select( + "SELECT id, name FROM customers WHERE id > 10 + UNION ALL + SELECT customer_id as id, 'Order' as name FROM orders WHERE total > 500", + ); + + let view = IncrementalView::from_stmt( + ast::QualifiedName { + db_name: None, + name: ast::Name::Ident("test_view".to_string()), + alias: None, + }, + select, + &schema, + 1, + 2, + 3, + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 2, "Should have one query per table"); + + // Check that each table gets its appropriate WHERE clause + let customers_query = queries + .iter() + .find(|q| q.contains("FROM customers")) + .unwrap(); + let orders_query = queries.iter().find(|q| q.contains("FROM orders")).unwrap(); + + assert!(customers_query.contains("WHERE id > 10")); + assert!(orders_query.contains("WHERE total > 500")); + } + + #[test] + fn test_sql_for_populate_duplicate_tables_conflicting_filters() { + // This tests what happens when we have duplicate table references with different filters + // We need to manually construct a view to simulate what would happen with CTEs + let schema = create_test_schema(); + + // Get the orders table twice (simulating what would happen with CTEs) + let orders_table = schema.get_btree_table("orders").unwrap(); + + let referenced_tables = vec![orders_table.clone(), orders_table.clone()]; + + // Create a SELECT that would have conflicting WHERE conditions + let select = parse_select( + "SELECT * FROM orders WHERE total > 1000", // This is just for the AST + ); + + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + referenced_tables, + HashMap::new(), + HashMap::new(), + HashMap::new(), + 
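+            // The three empty maps above stand in for table_aliases,
+            // qualified_table_names and table_conditions: this test builds the
+            // view by hand (simulating duplicate table references) instead of
+            // going through extract_all_tables.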
extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, + 2, + 3, + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + // With duplicates, we should get 2 identical queries + assert_eq!(queries.len(), 2); + + // Both should be the same since they're from the same table reference + assert_eq!(queries[0], queries[1]); + } + + #[test] + fn test_table_extraction_with_nested_ctes_complex_conditions() { + let schema = create_test_schema(); + let select = parse_select( + "WITH + customer_orders AS ( + SELECT c.*, o.total + FROM customers c + JOIN orders o ON c.id = o.customer_id + WHERE c.name LIKE 'A%' AND o.total > 100 + ), + top_customers AS ( + SELECT * FROM customer_orders WHERE total > 500 + ) + SELECT * FROM top_customers", + ); + + // Test table extraction directly without creating a view + let mut tables = Vec::new(); + let mut aliases = HashMap::new(); + let mut qualified_names = HashMap::new(); + let mut table_conditions = HashMap::new(); + + IncrementalView::extract_all_tables( + &select, + &schema, + &mut tables, + &mut aliases, + &mut qualified_names, + &mut table_conditions, + ) + .unwrap(); + + let table_names: Vec<&str> = tables.iter().map(|t| t.name.as_str()).collect(); + + // Should have one reference to each table + assert_eq!(table_names.len(), 2, "Should have 2 table references"); + assert!(table_names.contains(&"customers")); + assert!(table_names.contains(&"orders")); + + // Check aliases + assert_eq!(aliases.get("c"), Some(&"customers".to_string())); + assert_eq!(aliases.get("o"), Some(&"orders".to_string())); + } + + #[test] + fn test_union_all_populate_queries() { + // Test that UNION ALL generates correct populate queries + let schema = create_test_schema(); + + // Create a UNION ALL query that references the same table twice with different WHERE conditions + let sql = " + SELECT id, name FROM customers WHERE id < 5 + UNION ALL + SELECT id, name FROM customers WHERE id > 10 + "; + + let mut parser = Parser::new(sql.as_bytes()); + let cmd = parser.next_cmd().unwrap(); + let select_stmt = match cmd.unwrap() { + turso_parser::ast::Cmd::Stmt(ast::Stmt::Select(select)) => select, + _ => panic!("Expected SELECT statement"), + }; + + // Extract tables and conditions + let (tables, aliases, qualified_names, conditions) = + extract_all_tables(&select_stmt, &schema).unwrap(); + + // Generate populate queries + let queries = IncrementalView::generate_populate_queries( + &select_stmt, + &tables, + &aliases, + &qualified_names, + &conditions, + ) + .unwrap(); + + // Expected query - assuming customers table has INTEGER PRIMARY KEY + // so we don't need to select rowid separately + let expected = "SELECT * FROM customers WHERE (id < 5) OR (id > 10)"; + + assert_eq!( + queries.len(), + 1, + "Should generate exactly 1 query for UNION ALL with same table" + ); + assert_eq!(queries[0], expected, "Query should match expected format"); + } + + #[test] + fn test_union_all_different_tables_populate_queries() { + // Test UNION ALL with different tables + let schema = create_test_schema(); + + let sql = " + SELECT id, name FROM customers WHERE id < 5 + UNION ALL + SELECT id, product_name FROM orders WHERE amount > 100 + "; + + let mut parser = Parser::new(sql.as_bytes()); + let cmd = parser.next_cmd().unwrap(); + let select_stmt = match cmd.unwrap() { + turso_parser::ast::Cmd::Stmt(ast::Stmt::Select(select)) => select, + _ => panic!("Expected SELECT statement"), + }; + + // Extract tables and conditions + let (tables, aliases, qualified_names, conditions) = + 
extract_all_tables(&select_stmt, &schema).unwrap(); + + // Generate populate queries + let queries = IncrementalView::generate_populate_queries( + &select_stmt, + &tables, + &aliases, + &qualified_names, + &conditions, + ) + .unwrap(); + + // Should generate separate queries for each table + assert_eq!( + queries.len(), + 2, + "Should generate 2 queries for different tables" + ); + + // Check we have queries for both tables + let has_customers = queries.iter().any(|q| q.contains("customers")); + let has_orders = queries.iter().any(|q| q.contains("orders")); + assert!(has_customers, "Should have a query for customers table"); + assert!(has_orders, "Should have a query for orders table"); + + // Verify the customers query has its WHERE clause + let customers_query = queries + .iter() + .find(|q| q.contains("customers")) + .expect("Should have customers query"); + assert!( + customers_query.contains("WHERE"), + "Customers query should have WHERE clause" + ); } } From b419db489a0ca32646d3704b7220a26f9ec18950 Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Fri, 19 Sep 2025 05:23:10 -0500 Subject: [PATCH 2/3] Implement the DBSP merge operator The Merge operator is a stateless operator that merges two deltas. There are two modes: Distinct, where we merge together values that are the same, and All, where we preserve all values. We use the rowid of the hashable row to guarantee that: In Distinct mode, the rowid is set to 0 in both sides. If they values are the same, they will hash to the same thing. For All, the rowids are different. The merge operator is used for the UNION statement, which is a cornerstone of Recursive CTEs. --- core/incremental/merge_operator.rs | 187 ++++++++++++++++ core/incremental/mod.rs | 1 + core/incremental/operator.rs | 336 +++++++++++++++++++++++++++++ 3 files changed, 524 insertions(+) create mode 100644 core/incremental/merge_operator.rs diff --git a/core/incremental/merge_operator.rs b/core/incremental/merge_operator.rs new file mode 100644 index 000000000..c8547028f --- /dev/null +++ b/core/incremental/merge_operator.rs @@ -0,0 +1,187 @@ +// Merge operator for DBSP - combines two delta streams +// Used in recursive CTEs and UNION operations + +use crate::incremental::dbsp::{Delta, DeltaPair, HashableRow}; +use crate::incremental::operator::{ + ComputationTracker, DbspStateCursors, EvalState, IncrementalOperator, +}; +use crate::types::IOResult; +use crate::Result; +use std::collections::{hash_map::DefaultHasher, HashMap}; +use std::fmt::{self, Display}; +use std::hash::{Hash, Hasher}; +use std::sync::{Arc, Mutex}; + +/// How the merge operator should handle rowids when combining deltas +#[derive(Debug, Clone)] +pub enum UnionMode { + /// For UNION (distinct) - hash values only to merge duplicates + Distinct, + /// For UNION ALL - include source table name in hash to keep duplicates separate + All { + left_table: String, + right_table: String, + }, +} + +/// Merge operator that combines two input deltas into one output delta +/// Handles both recursive CTEs and UNION/UNION ALL operations +#[derive(Debug)] +pub struct MergeOperator { + operator_id: usize, + union_mode: UnionMode, + /// For UNION: tracks seen value hashes with their assigned rowids + /// For UNION ALL: tracks (source_id, original_rowid) -> assigned_rowid mappings + seen_rows: HashMap, // hash -> assigned_rowid + /// Next rowid to assign for new rows + next_rowid: i64, +} + +impl MergeOperator { + /// Create a new merge operator with specified union mode + pub fn new(operator_id: usize, mode: 
UnionMode) -> Self { + Self { + operator_id, + union_mode: mode, + seen_rows: HashMap::new(), + next_rowid: 1, + } + } + + /// Transform a delta's rowids based on the union mode with state tracking + fn transform_delta(&mut self, delta: Delta, is_left: bool) -> Delta { + match &self.union_mode { + UnionMode::Distinct => { + // For UNION distinct, track seen values and deduplicate + let mut output = Delta::new(); + for (row, weight) in delta.changes { + // Hash only the values (not rowid) for deduplication + let temp_row = HashableRow::new(0, row.values.clone()); + let value_hash = temp_row.cached_hash(); + + // Check if we've seen this value before + let assigned_rowid = + if let Some(&existing_rowid) = self.seen_rows.get(&value_hash) { + // Value already seen - use existing rowid + existing_rowid + } else { + // New value - assign new rowid and remember it + let new_rowid = self.next_rowid; + self.next_rowid += 1; + self.seen_rows.insert(value_hash, new_rowid); + new_rowid + }; + + // Output the row with the assigned rowid + let final_row = HashableRow::new(assigned_rowid, temp_row.values); + output.changes.push((final_row, weight)); + } + output + } + UnionMode::All { + left_table, + right_table, + } => { + // For UNION ALL, maintain consistent rowid mapping per source + let table = if is_left { left_table } else { right_table }; + let mut source_hasher = DefaultHasher::new(); + table.hash(&mut source_hasher); + let source_id = source_hasher.finish(); + + let mut output = Delta::new(); + for (row, weight) in delta.changes { + // Create a unique key for this (source, rowid) pair + let mut key_hasher = DefaultHasher::new(); + source_id.hash(&mut key_hasher); + row.rowid.hash(&mut key_hasher); + let key_hash = key_hasher.finish(); + + // Check if we've seen this (source, rowid) before + let assigned_rowid = + if let Some(&existing_rowid) = self.seen_rows.get(&key_hash) { + // Use existing rowid for this (source, rowid) pair + existing_rowid + } else { + // New row - assign new rowid + let new_rowid = self.next_rowid; + self.next_rowid += 1; + self.seen_rows.insert(key_hash, new_rowid); + new_rowid + }; + + // Create output row with consistent rowid + let final_row = HashableRow::new(assigned_rowid, row.values.clone()); + output.changes.push((final_row, weight)); + } + output + } + } + } +} + +impl Display for MergeOperator { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match &self.union_mode { + UnionMode::Distinct => write!(f, "MergeOperator({}, UNION)", self.operator_id), + UnionMode::All { .. 
} => write!(f, "MergeOperator({}, UNION ALL)", self.operator_id), + } + } +} + +impl IncrementalOperator for MergeOperator { + fn eval( + &mut self, + input: &mut EvalState, + _cursors: &mut DbspStateCursors, + ) -> Result> { + match input { + EvalState::Init { deltas } => { + // Extract deltas from the evaluation state + let delta_pair = std::mem::take(deltas); + + // Transform deltas based on union mode (with state tracking) + let left_transformed = self.transform_delta(delta_pair.left, true); + let right_transformed = self.transform_delta(delta_pair.right, false); + + // Merge the transformed deltas + let mut output = Delta::new(); + output.merge(&left_transformed); + output.merge(&right_transformed); + + // Move to Done state + *input = EvalState::Done; + + Ok(IOResult::Done(output)) + } + EvalState::Aggregate(_) | EvalState::Join(_) | EvalState::Uninitialized => { + // Merge operator only handles Init state + unreachable!("MergeOperator only handles Init state") + } + EvalState::Done => { + // Already evaluated + Ok(IOResult::Done(Delta::new())) + } + } + } + + fn commit( + &mut self, + deltas: DeltaPair, + _cursors: &mut DbspStateCursors, + ) -> Result> { + // Transform deltas based on union mode + let left_transformed = self.transform_delta(deltas.left, true); + let right_transformed = self.transform_delta(deltas.right, false); + + // Merge the transformed deltas + let mut output = Delta::new(); + output.merge(&left_transformed); + output.merge(&right_transformed); + + Ok(IOResult::Done(output)) + } + + fn set_tracker(&mut self, _tracker: Arc>) { + // Merge operator doesn't need tracking for now + } +} diff --git a/core/incremental/mod.rs b/core/incremental/mod.rs index 67eed60e2..5ac635cce 100644 --- a/core/incremental/mod.rs +++ b/core/incremental/mod.rs @@ -6,6 +6,7 @@ pub mod expr_compiler; pub mod filter_operator; pub mod input_operator; pub mod join_operator; +pub mod merge_operator; pub mod operator; pub mod persistence; pub mod project_operator; diff --git a/core/incremental/operator.rs b/core/incremental/operator.rs index 2af512504..53a5b1949 100644 --- a/core/incremental/operator.rs +++ b/core/incremental/operator.rs @@ -3674,4 +3674,340 @@ mod tests { assert!(was_new, "Duplicate rowid found: {}. 
This would cause rows to overwrite each other in btree storage!", row.rowid); } } + + // Merge operator tests + use crate::incremental::merge_operator::{MergeOperator, UnionMode}; + + #[test] + fn test_merge_operator_basic() { + let (_pager, table_root_page_id, index_root_page_id) = create_test_pager(); + let table_cursor = BTreeCursor::new_table(None, _pager.clone(), table_root_page_id, 5); + let index_def = create_dbsp_state_index(index_root_page_id); + let index_cursor = + BTreeCursor::new_index(None, _pager.clone(), index_root_page_id, &index_def, 4); + let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); + + let mut merge_op = MergeOperator::new( + 1, + UnionMode::All { + left_table: "table1".to_string(), + right_table: "table2".to_string(), + }, + ); + + // Create two deltas + let mut left_delta = Delta::new(); + left_delta.insert(1, vec![Value::Integer(1)]); + left_delta.insert(2, vec![Value::Integer(2)]); + + let mut right_delta = Delta::new(); + right_delta.insert(3, vec![Value::Integer(3)]); + right_delta.insert(4, vec![Value::Integer(4)]); + + let delta_pair = DeltaPair::new(left_delta, right_delta); + + // Evaluate merge + let result = merge_op.commit(delta_pair, &mut cursors).unwrap(); + + if let IOResult::Done(merged) = result { + // Should have all 4 entries + assert_eq!(merged.len(), 4); + + // Check that all values are present + let values: Vec = merged + .changes + .iter() + .filter_map(|(row, weight)| { + if *weight > 0 && !row.values.is_empty() { + if let Value::Integer(n) = &row.values[0] { + Some(*n) + } else { + None + } + } else { + None + } + }) + .collect(); + + assert!(values.contains(&1)); + assert!(values.contains(&2)); + assert!(values.contains(&3)); + assert!(values.contains(&4)); + } else { + panic!("Expected Done result"); + } + } + + #[test] + fn test_merge_operator_stateful_distinct() { + let (_pager, table_root_page_id, index_root_page_id) = create_test_pager(); + let table_cursor = BTreeCursor::new_table(None, _pager.clone(), table_root_page_id, 5); + let index_def = create_dbsp_state_index(index_root_page_id); + let index_cursor = + BTreeCursor::new_index(None, _pager.clone(), index_root_page_id, &index_def, 4); + let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); + + // Test that UNION (distinct) properly deduplicates across multiple operations + let mut merge_op = MergeOperator::new(7, UnionMode::Distinct); + + // First operation: insert values 1, 2, 3 from left and 2, 3, 4 from right + let mut left_delta1 = Delta::new(); + left_delta1.insert(1, vec![Value::Integer(1)]); + left_delta1.insert(2, vec![Value::Integer(2)]); + left_delta1.insert(3, vec![Value::Integer(3)]); + + let mut right_delta1 = Delta::new(); + right_delta1.insert(4, vec![Value::Integer(2)]); // Duplicate value 2 + right_delta1.insert(5, vec![Value::Integer(3)]); // Duplicate value 3 + right_delta1.insert(6, vec![Value::Integer(4)]); + + let result1 = merge_op + .commit(DeltaPair::new(left_delta1, right_delta1), &mut cursors) + .unwrap(); + if let IOResult::Done(merged1) = result1 { + // Should have 4 unique values (1, 2, 3, 4) + // But 6 total entries (3 from left + 3 from right) + assert_eq!(merged1.len(), 6); + + // Collect unique rowids - should be 4 + let unique_rowids: std::collections::HashSet = + merged1.changes.iter().map(|(row, _)| row.rowid).collect(); + assert_eq!( + unique_rowids.len(), + 4, + "Should have 4 unique rowids for 4 unique values" + ); + } else { + panic!("Expected Done result"); + } + + // Second operation: insert value 2 
again from left, and value 5 from right + let mut left_delta2 = Delta::new(); + left_delta2.insert(7, vec![Value::Integer(2)]); // Duplicate of existing value + + let mut right_delta2 = Delta::new(); + right_delta2.insert(8, vec![Value::Integer(5)]); // New value + + let result2 = merge_op + .commit(DeltaPair::new(left_delta2, right_delta2), &mut cursors) + .unwrap(); + if let IOResult::Done(merged2) = result2 { + assert_eq!(merged2.len(), 2, "Should have 2 entries in delta"); + + // Check that value 2 got the same rowid as before + let has_existing_rowid = merged2 + .changes + .iter() + .any(|(row, _)| row.values == vec![Value::Integer(2)] && row.rowid <= 4); + assert!(has_existing_rowid, "Value 2 should reuse existing rowid"); + + // Check that value 5 got a new rowid + let has_new_rowid = merged2 + .changes + .iter() + .any(|(row, _)| row.values == vec![Value::Integer(5)] && row.rowid > 4); + assert!(has_new_rowid, "Value 5 should get a new rowid"); + } else { + panic!("Expected Done result"); + } + } + + #[test] + fn test_merge_operator_single_sided_inputs_union_all() { + let (_pager, table_root_page_id, index_root_page_id) = create_test_pager(); + let table_cursor = BTreeCursor::new_table(None, _pager.clone(), table_root_page_id, 5); + let index_def = create_dbsp_state_index(index_root_page_id); + let index_cursor = + BTreeCursor::new_index(None, _pager.clone(), index_root_page_id, &index_def, 4); + let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); + + // Test UNION ALL with inputs coming from only one side at a time + let mut merge_op = MergeOperator::new( + 10, + UnionMode::All { + left_table: "orders".to_string(), + right_table: "archived_orders".to_string(), + }, + ); + + // First: only left side (orders) has data + let mut left_delta1 = Delta::new(); + left_delta1.insert(100, vec![Value::Integer(1001)]); + left_delta1.insert(101, vec![Value::Integer(1002)]); + + let right_delta1 = Delta::new(); // Empty right side + + let result1 = merge_op + .commit(DeltaPair::new(left_delta1, right_delta1), &mut cursors) + .unwrap(); + + let first_rowids = if let IOResult::Done(ref merged1) = result1 { + assert_eq!(merged1.len(), 2, "Should have 2 entries from left only"); + merged1 + .changes + .iter() + .map(|(row, _)| row.rowid) + .collect::>() + } else { + panic!("Expected Done result"); + }; + + // Second: only right side (archived_orders) has data + let left_delta2 = Delta::new(); // Empty left side + + let mut right_delta2 = Delta::new(); + right_delta2.insert(100, vec![Value::Integer(2001)]); // Same rowid as left, different table + right_delta2.insert(102, vec![Value::Integer(2002)]); + + let result2 = merge_op + .commit(DeltaPair::new(left_delta2, right_delta2), &mut cursors) + .unwrap(); + let second_result_rowid_100 = if let IOResult::Done(ref merged2) = result2 { + assert_eq!(merged2.len(), 2, "Should have 2 entries from right only"); + + // Rowids should be different from the left side even though original rowid 100 is the same + let second_rowids: Vec = + merged2.changes.iter().map(|(row, _)| row.rowid).collect(); + for rowid in &second_rowids { + assert!( + !first_rowids.contains(rowid), + "Right side rowids should be different from left side rowids" + ); + } + + // Save rowid for archived_orders.100 + merged2 + .changes + .iter() + .find(|(row, _)| row.values == vec![Value::Integer(2001)]) + .map(|(row, _)| row.rowid) + .unwrap() + } else { + panic!("Expected Done result"); + }; + + // Third: left side again with same rowids as before + let mut left_delta3 = 
Delta::new(); + left_delta3.insert(100, vec![Value::Integer(1003)]); // Same rowid 100 from orders + left_delta3.insert(101, vec![Value::Integer(1004)]); // Same rowid 101 from orders + + let right_delta3 = Delta::new(); // Empty right side + + let result3 = merge_op + .commit(DeltaPair::new(left_delta3, right_delta3), &mut cursors) + .unwrap(); + if let IOResult::Done(merged3) = result3 { + assert_eq!(merged3.len(), 2, "Should have 2 entries from left"); + + // Should get the same assigned rowids as the first operation + let third_rowids: Vec = merged3.changes.iter().map(|(row, _)| row.rowid).collect(); + assert_eq!( + first_rowids, third_rowids, + "Same (table, rowid) pairs should get same assigned rowids" + ); + } else { + panic!("Expected Done result"); + } + + // Fourth: right side again with rowid 100 + let left_delta4 = Delta::new(); // Empty left side + + let mut right_delta4 = Delta::new(); + right_delta4.insert(100, vec![Value::Integer(2003)]); // Same rowid 100 from archived_orders + + let result4 = merge_op + .commit(DeltaPair::new(left_delta4, right_delta4), &mut cursors) + .unwrap(); + if let IOResult::Done(merged4) = result4 { + assert_eq!(merged4.len(), 1, "Should have 1 entry from right"); + + // Should get same assigned rowid as second operation for archived_orders.100 + let fourth_rowid = merged4.changes[0].0.rowid; + assert_eq!( + fourth_rowid, second_result_rowid_100, + "archived_orders rowid 100 should consistently map to same assigned rowid" + ); + } else { + panic!("Expected Done result"); + } + } + + #[test] + fn test_merge_operator_both_sides_empty() { + let (_pager, table_root_page_id, index_root_page_id) = create_test_pager(); + let table_cursor = BTreeCursor::new_table(None, _pager.clone(), table_root_page_id, 5); + let index_def = create_dbsp_state_index(index_root_page_id); + let index_cursor = + BTreeCursor::new_index(None, _pager.clone(), index_root_page_id, &index_def, 4); + let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); + + // Test that both sides being empty works correctly + let mut merge_op = MergeOperator::new( + 12, + UnionMode::All { + left_table: "t1".to_string(), + right_table: "t2".to_string(), + }, + ); + + // First: insert some data to establish state + let mut left_delta1 = Delta::new(); + left_delta1.insert(1, vec![Value::Integer(100)]); + let mut right_delta1 = Delta::new(); + right_delta1.insert(1, vec![Value::Integer(200)]); + + let result1 = merge_op + .commit(DeltaPair::new(left_delta1, right_delta1), &mut cursors) + .unwrap(); + let original_t1_rowid = if let IOResult::Done(ref merged1) = result1 { + assert_eq!(merged1.len(), 2, "Should have 2 entries initially"); + // Save the rowid for t1.rowid=1 + merged1 + .changes + .iter() + .find(|(row, _)| row.values == vec![Value::Integer(100)]) + .map(|(row, _)| row.rowid) + .unwrap() + } else { + panic!("Expected Done result"); + }; + + // Second: both sides empty - should produce empty output + let empty_left = Delta::new(); + let empty_right = Delta::new(); + + let result2 = merge_op + .commit(DeltaPair::new(empty_left, empty_right), &mut cursors) + .unwrap(); + if let IOResult::Done(merged2) = result2 { + assert_eq!( + merged2.len(), + 0, + "Both empty sides should produce empty output" + ); + } else { + panic!("Expected Done result"); + } + + // Third: add more data to verify state is still intact + let mut left_delta3 = Delta::new(); + left_delta3.insert(1, vec![Value::Integer(101)]); // Same rowid as before + let right_delta3 = Delta::new(); + + let result3 = 
merge_op + .commit(DeltaPair::new(left_delta3, right_delta3), &mut cursors) + .unwrap(); + if let IOResult::Done(merged3) = result3 { + assert_eq!(merged3.len(), 1, "Should have 1 entry"); + // Should reuse the same assigned rowid for t1.rowid=1 + let rowid = merged3.changes[0].0.rowid; + assert_eq!( + rowid, original_t1_rowid, + "Should maintain consistent rowid mapping after empty operation" + ); + } else { + panic!("Expected Done result"); + } + } } From 2627ad44de1cd23dc96a7e00d6bb17afbd1ab3f4 Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Fri, 19 Sep 2025 05:18:44 -0500 Subject: [PATCH 3/3] support union statements in the DBSP circuit compiler --- core/incremental/compiler.rs | 125 +++++++++++- testing/materialized_views.test | 337 ++++++++++++++++++++++++++++++++ 2 files changed, 461 insertions(+), 1 deletion(-) diff --git a/core/incremental/compiler.rs b/core/incremental/compiler.rs index 8c8189261..cec950f35 100644 --- a/core/incremental/compiler.rs +++ b/core/incremental/compiler.rs @@ -298,6 +298,8 @@ pub enum DbspOperator { }, /// Input operator - source of data Input { name: String, schema: SchemaRef }, + /// Merge operator for combining streams (used in recursive CTEs and UNION) + Merge { schema: SchemaRef }, } /// Represents an expression in DBSP @@ -807,6 +809,13 @@ impl DbspCircuit { DbspOperator::Input { name, .. } => { writeln!(f, "{indent}Input[{node_id}]: {name}")?; } + DbspOperator::Merge { schema } => { + writeln!( + f, + "{indent}Merge[{node_id}]: UNION/Recursive (schema: {} columns)", + schema.columns.len() + )?; + } } for input_id in &node.inputs { @@ -1286,8 +1295,12 @@ impl DbspCompiler { ); Ok(node_id) } + LogicalPlan::Union(union) => { + // Handle UNION and UNION ALL + self.compile_union(union) + } _ => Err(LimboError::ParseError( - format!("Unsupported operator in DBSP compiler: only Filter, Projection, Join and Aggregate are supported, got: {:?}", + format!("Unsupported operator in DBSP compiler: only Filter, Projection, Join, Aggregate, and Union are supported, got: {:?}", match plan { LogicalPlan::Sort(_) => "Sort", LogicalPlan::Limit(_) => "Limit", @@ -1304,6 +1317,116 @@ impl DbspCompiler { } } + /// Extract a representative table name from a logical plan (for UNION ALL identification) + /// Returns a string that uniquely identifies the source of the data + fn extract_source_identifier(plan: &LogicalPlan) -> String { + match plan { + LogicalPlan::TableScan(scan) => { + // Direct table scan - use the table name + scan.table_name.clone() + } + LogicalPlan::Projection(proj) => { + // Pass through to input + Self::extract_source_identifier(&proj.input) + } + LogicalPlan::Filter(filter) => { + // Pass through to input + Self::extract_source_identifier(&filter.input) + } + LogicalPlan::Aggregate(agg) => { + // Aggregate of a table + format!("agg_{}", Self::extract_source_identifier(&agg.input)) + } + LogicalPlan::Sort(sort) => { + // Pass through to input + Self::extract_source_identifier(&sort.input) + } + LogicalPlan::Limit(limit) => { + // Pass through to input + Self::extract_source_identifier(&limit.input) + } + LogicalPlan::Join(join) => { + // Join of two sources - combine their identifiers + let left_id = Self::extract_source_identifier(&join.left); + let right_id = Self::extract_source_identifier(&join.right); + format!("join_{left_id}_{right_id}") + } + LogicalPlan::Union(union) => { + // Union of multiple sources + if union.inputs.is_empty() { + "union_empty".to_string() + } else { + let identifiers: Vec = union + .inputs + .iter() + .map(|input| 
Self::extract_source_identifier(input)) + .collect(); + format!("union_{}", identifiers.join("_")) + } + } + LogicalPlan::Distinct(distinct) => { + // Distinct of a source + format!( + "distinct_{}", + Self::extract_source_identifier(&distinct.input) + ) + } + LogicalPlan::WithCTE(with_cte) => { + // CTE body + Self::extract_source_identifier(&with_cte.body) + } + LogicalPlan::CTERef(cte_ref) => { + // CTE reference - use the CTE name + format!("cte_{}", cte_ref.name) + } + LogicalPlan::EmptyRelation(_) => "empty".to_string(), + LogicalPlan::Values(_) => "values".to_string(), + } + } + + /// Compile a UNION operator + fn compile_union(&mut self, union: &crate::translate::logical::Union) -> Result { + if union.inputs.len() != 2 { + return Err(LimboError::ParseError(format!( + "UNION requires exactly 2 inputs, got {}", + union.inputs.len() + ))); + } + + // Extract source identifiers from each input (for UNION ALL) + let left_source = Self::extract_source_identifier(&union.inputs[0]); + let right_source = Self::extract_source_identifier(&union.inputs[1]); + + // Compile left and right inputs + let left_id = self.compile_plan(&union.inputs[0])?; + let right_id = self.compile_plan(&union.inputs[1])?; + + use crate::incremental::merge_operator::{MergeOperator, UnionMode}; + + // Create a merge operator that handles the rowid transformation + let operator_id = self.circuit.next_id; + let mode = if union.all { + // For UNION ALL, pass the source identifiers + UnionMode::All { + left_table: left_source, + right_table: right_source, + } + } else { + UnionMode::Distinct + }; + let merge_operator = Box::new(MergeOperator::new(operator_id, mode)); + + let merge_id = self.circuit.add_node( + DbspOperator::Merge { + schema: union.schema.clone(), + }, + vec![left_id, right_id], + merge_operator, + ); + + Ok(merge_id) + } + /// Convert a logical expression to a DBSP expression fn compile_expr(expr: &LogicalExpr) -> Result { match expr { diff --git a/testing/materialized_views.test b/testing/materialized_views.test index 15229a48c..354f65d39 100755 --- a/testing/materialized_views.test +++ b/testing/materialized_views.test @@ -1091,3 +1091,340 @@ do_execsql_test_on_specific_db {:memory:} matview-join-complex-where { } {Charlie|10|100|1000 Alice|5|100|500 Charlie|6|75|450} + +# Test UNION queries in materialized views +do_execsql_test_on_specific_db {:memory:} matview-union-simple { + CREATE TABLE sales_online(id INTEGER, product TEXT, amount INTEGER); + CREATE TABLE sales_store(id INTEGER, product TEXT, amount INTEGER); + + INSERT INTO sales_online VALUES + (1, 'Laptop', 1200), + (2, 'Mouse', 25), + (3, 'Monitor', 400); + + INSERT INTO sales_store VALUES + (1, 'Keyboard', 75), + (2, 'Chair', 150), + (3, 'Desk', 350); + + -- Create a view that combines both sources + CREATE MATERIALIZED VIEW all_sales AS + SELECT product, amount FROM sales_online + UNION ALL + SELECT product, amount FROM sales_store; + + SELECT * FROM all_sales ORDER BY product; +} {Chair|150 +Desk|350 +Keyboard|75 +Laptop|1200 +Monitor|400 +Mouse|25} + +do_execsql_test_on_specific_db {:memory:} matview-union-with-where { + CREATE TABLE employees(id INTEGER, name TEXT, dept TEXT, salary INTEGER); + CREATE TABLE contractors(id INTEGER, name TEXT, dept TEXT, rate INTEGER); + + INSERT INTO employees VALUES + (1, 'Alice', 'Engineering', 90000), + (2, 'Bob', 'Sales', 60000), + (3, 'Charlie', 'Engineering', 85000); + + INSERT INTO contractors VALUES + (1, 'David', 'Engineering', 150), + (2, 'Eve', 'Marketing', 120), + (3, 'Frank', 'Engineering', 
180); + + -- High-earning staff from both categories + CREATE MATERIALIZED VIEW high_earners AS + SELECT name, dept, salary as compensation FROM employees WHERE salary > 80000 + UNION ALL + SELECT name, dept, rate * 2000 as compensation FROM contractors WHERE rate > 140; + + SELECT * FROM high_earners ORDER BY name; +} {Alice|Engineering|90000 +Charlie|Engineering|85000 +David|Engineering|300000 +Frank|Engineering|360000} + +do_execsql_test_on_specific_db {:memory:} matview-union-same-table-different-filters { + CREATE TABLE orders(id INTEGER, customer_id INTEGER, product TEXT, amount INTEGER, status TEXT); + + INSERT INTO orders VALUES + (1, 1, 'Laptop', 1200, 'completed'), + (2, 2, 'Mouse', 25, 'pending'), + (3, 1, 'Monitor', 400, 'completed'), + (4, 3, 'Keyboard', 75, 'cancelled'), + (5, 2, 'Desk', 350, 'completed'), + (6, 3, 'Chair', 150, 'pending'); + + -- View showing priority orders: high-value OR pending status + CREATE MATERIALIZED VIEW priority_orders AS + SELECT id, customer_id, product, amount FROM orders WHERE amount > 300 + UNION ALL + SELECT id, customer_id, product, amount FROM orders WHERE status = 'pending'; + + SELECT * FROM priority_orders ORDER BY id; +} {1|1|Laptop|1200 +2|2|Mouse|25 +3|1|Monitor|400 +5|2|Desk|350 +6|3|Chair|150} + +do_execsql_test_on_specific_db {:memory:} matview-union-with-aggregation { + CREATE TABLE q1_sales(product TEXT, quantity INTEGER, revenue INTEGER); + CREATE TABLE q2_sales(product TEXT, quantity INTEGER, revenue INTEGER); + + INSERT INTO q1_sales VALUES + ('Laptop', 10, 12000), + ('Mouse', 50, 1250), + ('Monitor', 8, 3200); + + INSERT INTO q2_sales VALUES + ('Laptop', 15, 18000), + ('Mouse', 60, 1500), + ('Keyboard', 30, 2250); + + -- Combined quarterly summary + CREATE MATERIALIZED VIEW half_year_summary AS + SELECT 'Q1' as quarter, SUM(quantity) as total_units, SUM(revenue) as total_revenue + FROM q1_sales + UNION ALL + SELECT 'Q2' as quarter, SUM(quantity) as total_units, SUM(revenue) as total_revenue + FROM q2_sales; + + SELECT * FROM half_year_summary ORDER BY quarter; +} {Q1|68|16450 +Q2|105|21750} + +do_execsql_test_on_specific_db {:memory:} matview-union-with-join { + CREATE TABLE customers(id INTEGER PRIMARY KEY, name TEXT, type TEXT); + CREATE TABLE orders(id INTEGER PRIMARY KEY, customer_id INTEGER, amount INTEGER); + CREATE TABLE quotes(id INTEGER PRIMARY KEY, customer_id INTEGER, amount INTEGER); + + INSERT INTO customers VALUES + (1, 'Alice', 'premium'), + (2, 'Bob', 'regular'), + (3, 'Charlie', 'premium'); + + INSERT INTO orders VALUES + (1, 1, 1000), + (2, 2, 500), + (3, 3, 1500); + + INSERT INTO quotes VALUES + (1, 1, 800), + (2, 2, 300), + (3, 3, 2000); + + -- All premium customer transactions (orders and quotes) + CREATE MATERIALIZED VIEW premium_transactions AS + SELECT c.name, 'order' as type, o.amount + FROM customers c + JOIN orders o ON c.id = o.customer_id + WHERE c.type = 'premium' + UNION ALL + SELECT c.name, 'quote' as type, q.amount + FROM customers c + JOIN quotes q ON c.id = q.customer_id + WHERE c.type = 'premium'; + + SELECT * FROM premium_transactions ORDER BY name, type, amount; +} {Alice|order|1000 +Alice|quote|800 +Charlie|order|1500 +Charlie|quote|2000} + +do_execsql_test_on_specific_db {:memory:} matview-union-distinct { + CREATE TABLE active_users(id INTEGER, name TEXT, email TEXT); + CREATE TABLE inactive_users(id INTEGER, name TEXT, email TEXT); + + INSERT INTO active_users VALUES + (1, 'Alice', 'alice@example.com'), + (2, 'Bob', 'bob@example.com'), + (3, 'Charlie', 'charlie@example.com'); + + 
INSERT INTO inactive_users VALUES + (4, 'David', 'david@example.com'), + (2, 'Bob', 'bob@example.com'), -- Bob appears in both + (5, 'Eve', 'eve@example.com'); + + -- All unique users (using UNION to deduplicate) + CREATE MATERIALIZED VIEW all_users AS + SELECT id, name, email FROM active_users + UNION + SELECT id, name, email FROM inactive_users; + + SELECT * FROM all_users ORDER BY id; +} {1|Alice|alice@example.com +2|Bob|bob@example.com +3|Charlie|charlie@example.com +4|David|david@example.com +5|Eve|eve@example.com} + +do_execsql_test_on_specific_db {:memory:} matview-union-complex-multiple-branches { + CREATE TABLE products(id INTEGER, name TEXT, category TEXT, price INTEGER); + + INSERT INTO products VALUES + (1, 'Laptop', 'Electronics', 1200), + (2, 'Mouse', 'Electronics', 25), + (3, 'Desk', 'Furniture', 350), + (4, 'Chair', 'Furniture', 150), + (5, 'Monitor', 'Electronics', 400), + (6, 'Keyboard', 'Electronics', 75), + (7, 'Bookshelf', 'Furniture', 200), + (8, 'Tablet', 'Electronics', 600); + + -- Products of interest: expensive electronics, all furniture, or very cheap items + CREATE MATERIALIZED VIEW featured_products AS + SELECT name, category, price, 'PremiumElectronic' as tag + FROM products + WHERE category = 'Electronics' AND price > 500 + UNION ALL + SELECT name, category, price, 'Furniture' as tag + FROM products + WHERE category = 'Furniture' + UNION ALL + SELECT name, category, price, 'Budget' as tag + FROM products + WHERE price < 50; + + SELECT * FROM featured_products ORDER BY tag, name; +} {Mouse|Electronics|25|Budget +Bookshelf|Furniture|200|Furniture +Chair|Furniture|150|Furniture +Desk|Furniture|350|Furniture +Laptop|Electronics|1200|PremiumElectronic +Tablet|Electronics|600|PremiumElectronic} + +do_execsql_test_on_specific_db {:memory:} matview-union-maintenance-insert { + CREATE TABLE t1(id INTEGER, value INTEGER); + CREATE TABLE t2(id INTEGER, value INTEGER); + + INSERT INTO t1 VALUES (1, 100), (2, 200); + INSERT INTO t2 VALUES (3, 300), (4, 400); + + CREATE MATERIALIZED VIEW combined AS + SELECT id, value FROM t1 WHERE value > 150 + UNION ALL + SELECT id, value FROM t2 WHERE value > 350; + + SELECT * FROM combined ORDER BY id; + + -- Insert into t1 + INSERT INTO t1 VALUES (5, 500); + SELECT * FROM combined ORDER BY id; + + -- Insert into t2 + INSERT INTO t2 VALUES (6, 600); + SELECT * FROM combined ORDER BY id; +} {2|200 +4|400 +2|200 +4|400 +5|500 +2|200 +4|400 +5|500 +6|600} + +do_execsql_test_on_specific_db {:memory:} matview-union-maintenance-delete { + CREATE TABLE source1(id INTEGER PRIMARY KEY, data TEXT); + CREATE TABLE source2(id INTEGER PRIMARY KEY, data TEXT); + + INSERT INTO source1 VALUES (1, 'A'), (2, 'B'), (3, 'C'); + INSERT INTO source2 VALUES (4, 'D'), (5, 'E'), (6, 'F'); + + CREATE MATERIALIZED VIEW merged AS + SELECT id, data FROM source1 + UNION ALL + SELECT id, data FROM source2; + + SELECT COUNT(*) FROM merged; + + DELETE FROM source1 WHERE id = 2; + SELECT COUNT(*) FROM merged; + + DELETE FROM source2 WHERE id > 4; + SELECT COUNT(*) FROM merged; +} {6 +5 +3} + +do_execsql_test_on_specific_db {:memory:} matview-union-maintenance-update { + CREATE TABLE high_priority(id INTEGER PRIMARY KEY, task TEXT, priority INTEGER); + CREATE TABLE normal_priority(id INTEGER PRIMARY KEY, task TEXT, priority INTEGER); + + INSERT INTO high_priority VALUES (1, 'Task A', 10), (2, 'Task B', 9); + INSERT INTO normal_priority VALUES (3, 'Task C', 5), (4, 'Task D', 6); + + CREATE MATERIALIZED VIEW active_tasks AS + SELECT id, task, priority FROM high_priority 
WHERE priority >= 9 + UNION ALL + SELECT id, task, priority FROM normal_priority WHERE priority >= 5; + + SELECT COUNT(*) FROM active_tasks; + + -- Update drops a high priority task below threshold + UPDATE high_priority SET priority = 8 WHERE id = 2; + SELECT COUNT(*) FROM active_tasks; + + -- Update brings a normal task above threshold + UPDATE normal_priority SET priority = 3 WHERE id = 3; + SELECT COUNT(*) FROM active_tasks; +} {4 +3 +2} + +# Test UNION ALL with same table and different WHERE conditions +do_execsql_test_on_specific_db {:memory:} matview-union-all-same-table { + CREATE TABLE test(id INTEGER PRIMARY KEY, value INTEGER); + INSERT INTO test VALUES (1, 10), (2, 20); + + -- This UNION ALL should return both rows + CREATE MATERIALIZED VIEW union_view AS + SELECT id, value FROM test WHERE value < 15 + UNION ALL + SELECT id, value FROM test WHERE value > 15; + + -- Should return 2 rows: (1,10) and (2,20) + SELECT * FROM union_view ORDER BY id; +} {1|10 +2|20} + +# Test UNION ALL preserves all rows in count +do_execsql_test_on_specific_db {:memory:} matview-union-all-row-count { + CREATE TABLE data(id INTEGER PRIMARY KEY, num INTEGER); + INSERT INTO data VALUES (1, 5), (2, 15), (3, 25); + + CREATE MATERIALIZED VIEW split_view AS + SELECT id, num FROM data WHERE num <= 10 + UNION ALL + SELECT id, num FROM data WHERE num > 10; + + -- Should return count of 3 + SELECT COUNT(*) FROM split_view; +} {3} + +# Test UNION ALL with text columns and filtering +do_execsql_test_on_specific_db {:memory:} matview-union-all-text-filter { + CREATE TABLE items(id INTEGER PRIMARY KEY, category TEXT, price INTEGER); + INSERT INTO items VALUES + (1, 'cheap', 10), + (2, 'expensive', 100), + (3, 'cheap', 20), + (4, 'expensive', 200); + + CREATE MATERIALIZED VIEW price_categories AS + SELECT id, category, price FROM items WHERE category = 'cheap' + UNION ALL + SELECT id, category, price FROM items WHERE category = 'expensive'; + + -- Should return all 4 items + SELECT COUNT(*) FROM price_categories; + SELECT id FROM price_categories ORDER BY id; +} {4 +1 +2 +3 +4}
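+
+# A possible follow-up check (sketch only; table and view names are illustrative):
+# UNION ALL should keep a row that appears verbatim in two different source tables,
+# since the merge operator assigns distinct rowids per source rather than hashing
+# values, while plain UNION (tested above) would collapse it.
+do_execsql_test_on_specific_db {:memory:} matview-union-all-duplicate-rows {
+    CREATE TABLE list_a(id INTEGER PRIMARY KEY, name TEXT);
+    CREATE TABLE list_b(id INTEGER PRIMARY KEY, name TEXT);
+
+    INSERT INTO list_a VALUES (1, 'Alice'), (2, 'Bob');
+    INSERT INTO list_b VALUES (1, 'Alice'), (3, 'Carol');
+
+    CREATE MATERIALIZED VIEW everyone AS
+    SELECT name FROM list_a
+    UNION ALL
+    SELECT name FROM list_b;
+
+    -- 'Alice' should appear twice: UNION ALL must not deduplicate across sources
+    SELECT COUNT(*) FROM everyone;
+    SELECT name FROM everyone ORDER BY name;
+} {4
+Alice
+Alice
+Bob
+Carol}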