diff --git a/core/incremental/compiler.rs b/core/incremental/compiler.rs index 7c215f2f1..902dfe4f8 100644 --- a/core/incremental/compiler.rs +++ b/core/incremental/compiler.rs @@ -298,6 +298,8 @@ pub enum DbspOperator { }, /// Input operator - source of data Input { name: String, schema: SchemaRef }, + /// Merge operator for combining streams (used in recursive CTEs and UNION) + Merge { schema: SchemaRef }, } /// Represents an expression in DBSP @@ -807,6 +809,13 @@ impl DbspCircuit { DbspOperator::Input { name, .. } => { writeln!(f, "{indent}Input[{node_id}]: {name}")?; } + DbspOperator::Merge { schema } => { + writeln!( + f, + "{indent}Merge[{node_id}]: UNION/Recursive (schema: {} columns)", + schema.columns.len() + )?; + } } for input_id in &node.inputs { @@ -1300,8 +1309,12 @@ impl DbspCompiler { ); Ok(node_id) } + LogicalPlan::Union(union) => { + // Handle UNION and UNION ALL + self.compile_union(union) + } _ => Err(LimboError::ParseError( - format!("Unsupported operator in DBSP compiler: only Filter, Projection, Join and Aggregate are supported, got: {:?}", + format!("Unsupported operator in DBSP compiler: only Filter, Projection, Join, Aggregate, and Union are supported, got: {:?}", match plan { LogicalPlan::Sort(_) => "Sort", LogicalPlan::Limit(_) => "Limit", @@ -1318,6 +1331,116 @@ impl DbspCompiler { } } + /// Extract a representative table name from a logical plan (for UNION ALL identification) + /// Returns a string that uniquely identifies the source of the data + fn extract_source_identifier(plan: &LogicalPlan) -> String { + match plan { + LogicalPlan::TableScan(scan) => { + // Direct table scan - use the table name + scan.table_name.clone() + } + LogicalPlan::Projection(proj) => { + // Pass through to input + Self::extract_source_identifier(&proj.input) + } + LogicalPlan::Filter(filter) => { + // Pass through to input + Self::extract_source_identifier(&filter.input) + } + LogicalPlan::Aggregate(agg) => { + // Aggregate of a table + format!("agg_{}", Self::extract_source_identifier(&agg.input)) + } + LogicalPlan::Sort(sort) => { + // Pass through to input + Self::extract_source_identifier(&sort.input) + } + LogicalPlan::Limit(limit) => { + // Pass through to input + Self::extract_source_identifier(&limit.input) + } + LogicalPlan::Join(join) => { + // Join of two sources - combine their identifiers + let left_id = Self::extract_source_identifier(&join.left); + let right_id = Self::extract_source_identifier(&join.right); + format!("join_{left_id}_{right_id}") + } + LogicalPlan::Union(union) => { + // Union of multiple sources + if union.inputs.is_empty() { + "union_empty".to_string() + } else { + let identifiers: Vec = union + .inputs + .iter() + .map(|input| Self::extract_source_identifier(input)) + .collect(); + format!("union_{}", identifiers.join("_")) + } + } + LogicalPlan::Distinct(distinct) => { + // Distinct of a source + format!( + "distinct_{}", + Self::extract_source_identifier(&distinct.input) + ) + } + LogicalPlan::WithCTE(with_cte) => { + // CTE body + Self::extract_source_identifier(&with_cte.body) + } + LogicalPlan::CTERef(cte_ref) => { + // CTE reference - use the CTE name + format!("cte_{}", cte_ref.name) + } + LogicalPlan::EmptyRelation(_) => "empty".to_string(), + LogicalPlan::Values(_) => "values".to_string(), + } + } + + /// Compile a UNION operator + fn compile_union(&mut self, union: &crate::translate::logical::Union) -> Result { + if union.inputs.len() != 2 { + return Err(LimboError::ParseError(format!( + "UNION requires exactly 2 inputs, got {}", + 
union.inputs.len() + ))); + } + + // Extract source identifiers from each input (for UNION ALL) + let left_source = Self::extract_source_identifier(&union.inputs[0]); + let right_source = Self::extract_source_identifier(&union.inputs[1]); + + // Compile left and right inputs + let left_id = self.compile_plan(&union.inputs[0])?; + let right_id = self.compile_plan(&union.inputs[1])?; + + use crate::incremental::merge_operator::{MergeOperator, UnionMode}; + + // Create a merge operator that handles the rowid transformation + let operator_id = self.circuit.next_id; + let mode = if union.all { + // For UNION ALL, pass the source identifiers + UnionMode::All { + left_table: left_source, + right_table: right_source, + } + } else { + UnionMode::Distinct + }; + let merge_operator = Box::new(MergeOperator::new(operator_id, mode)); + + let merge_id = self.circuit.add_node( + DbspOperator::Merge { + schema: union.schema.clone(), + }, + vec![left_id, right_id], + merge_operator, + ); + + Ok(merge_id) + } + /// Convert a logical expression to a DBSP expression fn compile_expr(expr: &LogicalExpr) -> Result { match expr { diff --git a/core/incremental/merge_operator.rs b/core/incremental/merge_operator.rs new file mode 100644 index 000000000..c8547028f --- /dev/null +++ b/core/incremental/merge_operator.rs @@ -0,0 +1,187 @@ +// Merge operator for DBSP - combines two delta streams +// Used in recursive CTEs and UNION operations + +use crate::incremental::dbsp::{Delta, DeltaPair, HashableRow}; +use crate::incremental::operator::{ + ComputationTracker, DbspStateCursors, EvalState, IncrementalOperator, +}; +use crate::types::IOResult; +use crate::Result; +use std::collections::{hash_map::DefaultHasher, HashMap}; +use std::fmt::{self, Display}; +use std::hash::{Hash, Hasher}; +use std::sync::{Arc, Mutex}; + +/// How the merge operator should handle rowids when combining deltas +#[derive(Debug, Clone)] +pub enum UnionMode { + /// For UNION (distinct) - hash values only to merge duplicates + Distinct, + /// For UNION ALL - include source table name in hash to keep duplicates separate + All { + left_table: String, + right_table: String, + }, +} + +/// Merge operator that combines two input deltas into one output delta +/// Handles both recursive CTEs and UNION/UNION ALL operations +#[derive(Debug)] +pub struct MergeOperator { + operator_id: usize, + union_mode: UnionMode, + /// For UNION: tracks seen value hashes with their assigned rowids + /// For UNION ALL: tracks (source_id, original_rowid) -> assigned_rowid mappings + seen_rows: HashMap, // hash -> assigned_rowid + /// Next rowid to assign for new rows + next_rowid: i64, +} + +impl MergeOperator { + /// Create a new merge operator with specified union mode + pub fn new(operator_id: usize, mode: UnionMode) -> Self { + Self { + operator_id, + union_mode: mode, + seen_rows: HashMap::new(), + next_rowid: 1, + } + } + + /// Transform a delta's rowids based on the union mode with state tracking + fn transform_delta(&mut self, delta: Delta, is_left: bool) -> Delta { + match &self.union_mode { + UnionMode::Distinct => { + // For UNION distinct, track seen values and deduplicate + let mut output = Delta::new(); + for (row, weight) in delta.changes { + // Hash only the values (not rowid) for deduplication + let temp_row = HashableRow::new(0, row.values.clone()); + let value_hash = temp_row.cached_hash(); + + // Check if we've seen this value before + let assigned_rowid = + if let Some(&existing_rowid) = self.seen_rows.get(&value_hash) { + // Value already seen - use 
existing rowid + existing_rowid + } else { + // New value - assign new rowid and remember it + let new_rowid = self.next_rowid; + self.next_rowid += 1; + self.seen_rows.insert(value_hash, new_rowid); + new_rowid + }; + + // Output the row with the assigned rowid + let final_row = HashableRow::new(assigned_rowid, temp_row.values); + output.changes.push((final_row, weight)); + } + output + } + UnionMode::All { + left_table, + right_table, + } => { + // For UNION ALL, maintain consistent rowid mapping per source + let table = if is_left { left_table } else { right_table }; + let mut source_hasher = DefaultHasher::new(); + table.hash(&mut source_hasher); + let source_id = source_hasher.finish(); + + let mut output = Delta::new(); + for (row, weight) in delta.changes { + // Create a unique key for this (source, rowid) pair + let mut key_hasher = DefaultHasher::new(); + source_id.hash(&mut key_hasher); + row.rowid.hash(&mut key_hasher); + let key_hash = key_hasher.finish(); + + // Check if we've seen this (source, rowid) before + let assigned_rowid = + if let Some(&existing_rowid) = self.seen_rows.get(&key_hash) { + // Use existing rowid for this (source, rowid) pair + existing_rowid + } else { + // New row - assign new rowid + let new_rowid = self.next_rowid; + self.next_rowid += 1; + self.seen_rows.insert(key_hash, new_rowid); + new_rowid + }; + + // Create output row with consistent rowid + let final_row = HashableRow::new(assigned_rowid, row.values.clone()); + output.changes.push((final_row, weight)); + } + output + } + } + } +} + +impl Display for MergeOperator { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match &self.union_mode { + UnionMode::Distinct => write!(f, "MergeOperator({}, UNION)", self.operator_id), + UnionMode::All { .. } => write!(f, "MergeOperator({}, UNION ALL)", self.operator_id), + } + } +} + +impl IncrementalOperator for MergeOperator { + fn eval( + &mut self, + input: &mut EvalState, + _cursors: &mut DbspStateCursors, + ) -> Result> { + match input { + EvalState::Init { deltas } => { + // Extract deltas from the evaluation state + let delta_pair = std::mem::take(deltas); + + // Transform deltas based on union mode (with state tracking) + let left_transformed = self.transform_delta(delta_pair.left, true); + let right_transformed = self.transform_delta(delta_pair.right, false); + + // Merge the transformed deltas + let mut output = Delta::new(); + output.merge(&left_transformed); + output.merge(&right_transformed); + + // Move to Done state + *input = EvalState::Done; + + Ok(IOResult::Done(output)) + } + EvalState::Aggregate(_) | EvalState::Join(_) | EvalState::Uninitialized => { + // Merge operator only handles Init state + unreachable!("MergeOperator only handles Init state") + } + EvalState::Done => { + // Already evaluated + Ok(IOResult::Done(Delta::new())) + } + } + } + + fn commit( + &mut self, + deltas: DeltaPair, + _cursors: &mut DbspStateCursors, + ) -> Result> { + // Transform deltas based on union mode + let left_transformed = self.transform_delta(deltas.left, true); + let right_transformed = self.transform_delta(deltas.right, false); + + // Merge the transformed deltas + let mut output = Delta::new(); + output.merge(&left_transformed); + output.merge(&right_transformed); + + Ok(IOResult::Done(output)) + } + + fn set_tracker(&mut self, _tracker: Arc>) { + // Merge operator doesn't need tracking for now + } +} diff --git a/core/incremental/mod.rs b/core/incremental/mod.rs index 67eed60e2..5ac635cce 100644 --- a/core/incremental/mod.rs +++ 
b/core/incremental/mod.rs @@ -6,6 +6,7 @@ pub mod expr_compiler; pub mod filter_operator; pub mod input_operator; pub mod join_operator; +pub mod merge_operator; pub mod operator; pub mod persistence; pub mod project_operator; diff --git a/core/incremental/operator.rs b/core/incremental/operator.rs index 2af512504..53a5b1949 100644 --- a/core/incremental/operator.rs +++ b/core/incremental/operator.rs @@ -3674,4 +3674,340 @@ mod tests { assert!(was_new, "Duplicate rowid found: {}. This would cause rows to overwrite each other in btree storage!", row.rowid); } } + + // Merge operator tests + use crate::incremental::merge_operator::{MergeOperator, UnionMode}; + + #[test] + fn test_merge_operator_basic() { + let (_pager, table_root_page_id, index_root_page_id) = create_test_pager(); + let table_cursor = BTreeCursor::new_table(None, _pager.clone(), table_root_page_id, 5); + let index_def = create_dbsp_state_index(index_root_page_id); + let index_cursor = + BTreeCursor::new_index(None, _pager.clone(), index_root_page_id, &index_def, 4); + let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); + + let mut merge_op = MergeOperator::new( + 1, + UnionMode::All { + left_table: "table1".to_string(), + right_table: "table2".to_string(), + }, + ); + + // Create two deltas + let mut left_delta = Delta::new(); + left_delta.insert(1, vec![Value::Integer(1)]); + left_delta.insert(2, vec![Value::Integer(2)]); + + let mut right_delta = Delta::new(); + right_delta.insert(3, vec![Value::Integer(3)]); + right_delta.insert(4, vec![Value::Integer(4)]); + + let delta_pair = DeltaPair::new(left_delta, right_delta); + + // Evaluate merge + let result = merge_op.commit(delta_pair, &mut cursors).unwrap(); + + if let IOResult::Done(merged) = result { + // Should have all 4 entries + assert_eq!(merged.len(), 4); + + // Check that all values are present + let values: Vec = merged + .changes + .iter() + .filter_map(|(row, weight)| { + if *weight > 0 && !row.values.is_empty() { + if let Value::Integer(n) = &row.values[0] { + Some(*n) + } else { + None + } + } else { + None + } + }) + .collect(); + + assert!(values.contains(&1)); + assert!(values.contains(&2)); + assert!(values.contains(&3)); + assert!(values.contains(&4)); + } else { + panic!("Expected Done result"); + } + } + + #[test] + fn test_merge_operator_stateful_distinct() { + let (_pager, table_root_page_id, index_root_page_id) = create_test_pager(); + let table_cursor = BTreeCursor::new_table(None, _pager.clone(), table_root_page_id, 5); + let index_def = create_dbsp_state_index(index_root_page_id); + let index_cursor = + BTreeCursor::new_index(None, _pager.clone(), index_root_page_id, &index_def, 4); + let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); + + // Test that UNION (distinct) properly deduplicates across multiple operations + let mut merge_op = MergeOperator::new(7, UnionMode::Distinct); + + // First operation: insert values 1, 2, 3 from left and 2, 3, 4 from right + let mut left_delta1 = Delta::new(); + left_delta1.insert(1, vec![Value::Integer(1)]); + left_delta1.insert(2, vec![Value::Integer(2)]); + left_delta1.insert(3, vec![Value::Integer(3)]); + + let mut right_delta1 = Delta::new(); + right_delta1.insert(4, vec![Value::Integer(2)]); // Duplicate value 2 + right_delta1.insert(5, vec![Value::Integer(3)]); // Duplicate value 3 + right_delta1.insert(6, vec![Value::Integer(4)]); + + let result1 = merge_op + .commit(DeltaPair::new(left_delta1, right_delta1), &mut cursors) + .unwrap(); + if let IOResult::Done(merged1) = 
result1 { + // Should have 4 unique values (1, 2, 3, 4) + // But 6 total entries (3 from left + 3 from right) + assert_eq!(merged1.len(), 6); + + // Collect unique rowids - should be 4 + let unique_rowids: std::collections::HashSet = + merged1.changes.iter().map(|(row, _)| row.rowid).collect(); + assert_eq!( + unique_rowids.len(), + 4, + "Should have 4 unique rowids for 4 unique values" + ); + } else { + panic!("Expected Done result"); + } + + // Second operation: insert value 2 again from left, and value 5 from right + let mut left_delta2 = Delta::new(); + left_delta2.insert(7, vec![Value::Integer(2)]); // Duplicate of existing value + + let mut right_delta2 = Delta::new(); + right_delta2.insert(8, vec![Value::Integer(5)]); // New value + + let result2 = merge_op + .commit(DeltaPair::new(left_delta2, right_delta2), &mut cursors) + .unwrap(); + if let IOResult::Done(merged2) = result2 { + assert_eq!(merged2.len(), 2, "Should have 2 entries in delta"); + + // Check that value 2 got the same rowid as before + let has_existing_rowid = merged2 + .changes + .iter() + .any(|(row, _)| row.values == vec![Value::Integer(2)] && row.rowid <= 4); + assert!(has_existing_rowid, "Value 2 should reuse existing rowid"); + + // Check that value 5 got a new rowid + let has_new_rowid = merged2 + .changes + .iter() + .any(|(row, _)| row.values == vec![Value::Integer(5)] && row.rowid > 4); + assert!(has_new_rowid, "Value 5 should get a new rowid"); + } else { + panic!("Expected Done result"); + } + } + + #[test] + fn test_merge_operator_single_sided_inputs_union_all() { + let (_pager, table_root_page_id, index_root_page_id) = create_test_pager(); + let table_cursor = BTreeCursor::new_table(None, _pager.clone(), table_root_page_id, 5); + let index_def = create_dbsp_state_index(index_root_page_id); + let index_cursor = + BTreeCursor::new_index(None, _pager.clone(), index_root_page_id, &index_def, 4); + let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); + + // Test UNION ALL with inputs coming from only one side at a time + let mut merge_op = MergeOperator::new( + 10, + UnionMode::All { + left_table: "orders".to_string(), + right_table: "archived_orders".to_string(), + }, + ); + + // First: only left side (orders) has data + let mut left_delta1 = Delta::new(); + left_delta1.insert(100, vec![Value::Integer(1001)]); + left_delta1.insert(101, vec![Value::Integer(1002)]); + + let right_delta1 = Delta::new(); // Empty right side + + let result1 = merge_op + .commit(DeltaPair::new(left_delta1, right_delta1), &mut cursors) + .unwrap(); + + let first_rowids = if let IOResult::Done(ref merged1) = result1 { + assert_eq!(merged1.len(), 2, "Should have 2 entries from left only"); + merged1 + .changes + .iter() + .map(|(row, _)| row.rowid) + .collect::>() + } else { + panic!("Expected Done result"); + }; + + // Second: only right side (archived_orders) has data + let left_delta2 = Delta::new(); // Empty left side + + let mut right_delta2 = Delta::new(); + right_delta2.insert(100, vec![Value::Integer(2001)]); // Same rowid as left, different table + right_delta2.insert(102, vec![Value::Integer(2002)]); + + let result2 = merge_op + .commit(DeltaPair::new(left_delta2, right_delta2), &mut cursors) + .unwrap(); + let second_result_rowid_100 = if let IOResult::Done(ref merged2) = result2 { + assert_eq!(merged2.len(), 2, "Should have 2 entries from right only"); + + // Rowids should be different from the left side even though original rowid 100 is the same + let second_rowids: Vec = + 
merged2.changes.iter().map(|(row, _)| row.rowid).collect(); + for rowid in &second_rowids { + assert!( + !first_rowids.contains(rowid), + "Right side rowids should be different from left side rowids" + ); + } + + // Save rowid for archived_orders.100 + merged2 + .changes + .iter() + .find(|(row, _)| row.values == vec![Value::Integer(2001)]) + .map(|(row, _)| row.rowid) + .unwrap() + } else { + panic!("Expected Done result"); + }; + + // Third: left side again with same rowids as before + let mut left_delta3 = Delta::new(); + left_delta3.insert(100, vec![Value::Integer(1003)]); // Same rowid 100 from orders + left_delta3.insert(101, vec![Value::Integer(1004)]); // Same rowid 101 from orders + + let right_delta3 = Delta::new(); // Empty right side + + let result3 = merge_op + .commit(DeltaPair::new(left_delta3, right_delta3), &mut cursors) + .unwrap(); + if let IOResult::Done(merged3) = result3 { + assert_eq!(merged3.len(), 2, "Should have 2 entries from left"); + + // Should get the same assigned rowids as the first operation + let third_rowids: Vec = merged3.changes.iter().map(|(row, _)| row.rowid).collect(); + assert_eq!( + first_rowids, third_rowids, + "Same (table, rowid) pairs should get same assigned rowids" + ); + } else { + panic!("Expected Done result"); + } + + // Fourth: right side again with rowid 100 + let left_delta4 = Delta::new(); // Empty left side + + let mut right_delta4 = Delta::new(); + right_delta4.insert(100, vec![Value::Integer(2003)]); // Same rowid 100 from archived_orders + + let result4 = merge_op + .commit(DeltaPair::new(left_delta4, right_delta4), &mut cursors) + .unwrap(); + if let IOResult::Done(merged4) = result4 { + assert_eq!(merged4.len(), 1, "Should have 1 entry from right"); + + // Should get same assigned rowid as second operation for archived_orders.100 + let fourth_rowid = merged4.changes[0].0.rowid; + assert_eq!( + fourth_rowid, second_result_rowid_100, + "archived_orders rowid 100 should consistently map to same assigned rowid" + ); + } else { + panic!("Expected Done result"); + } + } + + #[test] + fn test_merge_operator_both_sides_empty() { + let (_pager, table_root_page_id, index_root_page_id) = create_test_pager(); + let table_cursor = BTreeCursor::new_table(None, _pager.clone(), table_root_page_id, 5); + let index_def = create_dbsp_state_index(index_root_page_id); + let index_cursor = + BTreeCursor::new_index(None, _pager.clone(), index_root_page_id, &index_def, 4); + let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); + + // Test that both sides being empty works correctly + let mut merge_op = MergeOperator::new( + 12, + UnionMode::All { + left_table: "t1".to_string(), + right_table: "t2".to_string(), + }, + ); + + // First: insert some data to establish state + let mut left_delta1 = Delta::new(); + left_delta1.insert(1, vec![Value::Integer(100)]); + let mut right_delta1 = Delta::new(); + right_delta1.insert(1, vec![Value::Integer(200)]); + + let result1 = merge_op + .commit(DeltaPair::new(left_delta1, right_delta1), &mut cursors) + .unwrap(); + let original_t1_rowid = if let IOResult::Done(ref merged1) = result1 { + assert_eq!(merged1.len(), 2, "Should have 2 entries initially"); + // Save the rowid for t1.rowid=1 + merged1 + .changes + .iter() + .find(|(row, _)| row.values == vec![Value::Integer(100)]) + .map(|(row, _)| row.rowid) + .unwrap() + } else { + panic!("Expected Done result"); + }; + + // Second: both sides empty - should produce empty output + let empty_left = Delta::new(); + let empty_right = Delta::new(); + + 
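+        // With no pending changes on either side, transform_delta sees empty deltas and
+        // leaves the operator's seen_rows / next_rowid state untouched, so the commit
+        // below must return an empty delta and the later insert must reuse its old rowid.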
let result2 = merge_op + .commit(DeltaPair::new(empty_left, empty_right), &mut cursors) + .unwrap(); + if let IOResult::Done(merged2) = result2 { + assert_eq!( + merged2.len(), + 0, + "Both empty sides should produce empty output" + ); + } else { + panic!("Expected Done result"); + } + + // Third: add more data to verify state is still intact + let mut left_delta3 = Delta::new(); + left_delta3.insert(1, vec![Value::Integer(101)]); // Same rowid as before + let right_delta3 = Delta::new(); + + let result3 = merge_op + .commit(DeltaPair::new(left_delta3, right_delta3), &mut cursors) + .unwrap(); + if let IOResult::Done(merged3) = result3 { + assert_eq!(merged3.len(), 1, "Should have 1 entry"); + // Should reuse the same assigned rowid for t1.rowid=1 + let rowid = merged3.changes[0].0.rowid; + assert_eq!( + rowid, original_t1_rowid, + "Should maintain consistent rowid mapping after empty operation" + ); + } else { + panic!("Expected Done result"); + } + } } diff --git a/core/incremental/view.rs b/core/incremental/view.rs index fd7b3988a..65fa5e2bb 100644 --- a/core/incremental/view.rs +++ b/core/incremental/view.rs @@ -8,7 +8,7 @@ use crate::types::{IOResult, Value}; use crate::util::{extract_view_columns, ViewColumnSchema}; use crate::{return_if_io, LimboError, Pager, Result, Statement}; use std::cell::RefCell; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::fmt; use std::rc::Rc; use std::sync::{Arc, Mutex}; @@ -195,6 +195,9 @@ pub struct IncrementalView { // Mapping from table name to fully qualified name (e.g., "customers" -> "main.customers") // This preserves database qualification from the original query qualified_table_names: HashMap, + // WHERE conditions for each table (accumulated from all occurrences) + // Multiple conditions from UNION branches or duplicate references are stored as a vector + table_conditions: HashMap>>, // The view's column schema with table relationships pub column_schema: ViewColumnSchema, // State machine for population @@ -312,9 +315,18 @@ impl IncrementalView { // Extract output columns using the shared function let column_schema = extract_view_columns(&select, schema)?; - // Get all tables from FROM clause and JOINs, along with their aliases - let (referenced_tables, table_aliases, qualified_table_names) = - Self::extract_all_tables(&select, schema)?; + let mut referenced_tables = Vec::new(); + let mut table_aliases = HashMap::new(); + let mut qualified_table_names = HashMap::new(); + let mut table_conditions = HashMap::new(); + Self::extract_all_tables( + &select, + schema, + &mut referenced_tables, + &mut table_aliases, + &mut qualified_table_names, + &mut table_conditions, + )?; Self::new( name, @@ -322,6 +334,7 @@ impl IncrementalView { referenced_tables, table_aliases, qualified_table_names, + table_conditions, column_schema, schema, main_data_root, @@ -337,6 +350,7 @@ impl IncrementalView { referenced_tables: Vec>, table_aliases: HashMap, qualified_table_names: HashMap, + table_conditions: HashMap>>, column_schema: ViewColumnSchema, schema: &Schema, main_data_root: usize, @@ -362,6 +376,7 @@ impl IncrementalView { referenced_tables, table_aliases, qualified_table_names, + table_conditions, column_schema, populate_state: PopulateState::Start, tracker, @@ -405,97 +420,249 @@ impl IncrementalView { self.referenced_tables.clone() } - /// Extract all tables and their aliases from the SELECT statement - /// Returns a tuple of (tables, alias_map, qualified_names) - /// where alias_map is alias -> table_name - /// and 
qualified_names is table_name -> fully_qualified_name - #[allow(clippy::type_complexity)] - fn extract_all_tables( - select: &ast::Select, + /// Process a single table reference from a FROM or JOIN clause + fn process_table_reference( + name: &ast::QualifiedName, + alias: &Option, schema: &Schema, - ) -> Result<( - Vec>, - HashMap, - HashMap, - )> { - let mut tables = Vec::new(); - let mut aliases = HashMap::new(); - let mut qualified_names = HashMap::new(); + table_map: &mut HashMap>, + aliases: &mut HashMap, + qualified_names: &mut HashMap, + cte_names: &HashSet, + ) -> Result<()> { + let table_name = name.name.as_str(); + // Build the fully qualified name + let qualified_name = if let Some(ref db) = name.db_name { + format!("{db}.{table_name}") + } else { + table_name.to_string() + }; + + // Skip CTEs - they're not real tables + if !cte_names.contains(table_name) { + if let Some(table) = schema.get_btree_table(table_name) { + table_map.insert(table_name.to_string(), table.clone()); + qualified_names.insert(table_name.to_string(), qualified_name); + + // Store the alias mapping if there is an alias + if let Some(alias_enum) = alias { + let alias_name = match alias_enum { + ast::As::As(name) | ast::As::Elided(name) => match name { + ast::Name::Ident(s) | ast::Name::Quoted(s) => s, + }, + }; + aliases.insert(alias_name.to_string(), table_name.to_string()); + } + } else { + return Err(LimboError::ParseError(format!( + "Table '{table_name}' not found in schema" + ))); + } + } + Ok(()) + } + + fn extract_one_statement( + select: &ast::OneSelect, + schema: &Schema, + table_map: &mut HashMap>, + aliases: &mut HashMap, + qualified_names: &mut HashMap, + table_conditions: &mut HashMap>>, + cte_names: &HashSet, + ) -> Result<()> { if let ast::OneSelect::Select { from: Some(ref from), .. 
- } = select.body.select + } = select { // Get the main table from FROM clause if let ast::SelectTable::Table(name, alias, _) = from.select.as_ref() { - let table_name = name.name.as_str(); - - // Build the fully qualified name - let qualified_name = if let Some(ref db) = name.db_name { - format!("{db}.{table_name}") - } else { - table_name.to_string() - }; - - if let Some(table) = schema.get_btree_table(table_name) { - tables.push(table.clone()); - qualified_names.insert(table_name.to_string(), qualified_name); - - // Store the alias mapping if there is an alias - if let Some(alias_name) = alias { - aliases.insert(alias_name.to_string(), table_name.to_string()); - } - } else { - return Err(LimboError::ParseError(format!( - "Table '{table_name}' not found in schema" - ))); - } + Self::process_table_reference( + name, + alias, + schema, + table_map, + aliases, + qualified_names, + cte_names, + )?; } // Get all tables from JOIN clauses for join in &from.joins { if let ast::SelectTable::Table(name, alias, _) = join.table.as_ref() { - let table_name = name.name.as_str(); - - // Build the fully qualified name - let qualified_name = if let Some(ref db) = name.db_name { - format!("{db}.{table_name}") - } else { - table_name.to_string() - }; - - if let Some(table) = schema.get_btree_table(table_name) { - tables.push(table.clone()); - qualified_names.insert(table_name.to_string(), qualified_name); - - // Store the alias mapping if there is an alias - if let Some(alias_name) = alias { - aliases.insert(alias_name.to_string(), table_name.to_string()); - } - } else { - return Err(LimboError::ParseError(format!( - "Table '{table_name}' not found in schema" - ))); - } + Self::process_table_reference( + name, + alias, + schema, + table_map, + aliases, + qualified_names, + cte_names, + )?; } } } + // Extract WHERE conditions for this SELECT + let where_expr = if let ast::OneSelect::Select { + where_clause: Some(ref where_expr), + .. + } = select + { + Some(where_expr.as_ref().clone()) + } else { + None + }; - if tables.is_empty() { - return Err(LimboError::ParseError( - "No tables found in SELECT statement".to_string(), - )); + // Ensure all tables have an entry in table_conditions (even if empty) + for table_name in table_map.keys() { + table_conditions.entry(table_name.clone()).or_default(); } - Ok((tables, aliases, qualified_names)) + // Extract and store table-specific conditions from the WHERE clause + if let Some(ref where_expr) = where_expr { + for table_name in table_map.keys() { + let all_tables: Vec = table_map.keys().cloned().collect(); + let table_specific_condition = Self::extract_conditions_for_table( + where_expr, + table_name, + aliases, + &all_tables, + schema, + ); + // Only add if there's actually a condition for this table + if let Some(condition) = table_specific_condition { + let conditions = table_conditions.get_mut(table_name).unwrap(); + conditions.push(Some(condition)); + } + } + } else { + // No WHERE clause - push None for all tables in this SELECT. It is a way + // of signaling that we need all rows in the table. It is important we signal this + // explicitly, because the same table may appear in many conditions - some of which + // have filters that would otherwise be applied. 
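+                // Hypothetical illustration (not taken from this change): for a view such as
+                //   SELECT x FROM t WHERE x > 10 UNION ALL SELECT x FROM t
+                // the first SELECT pushes Some(x > 10) for "t" and the second pushes None,
+                // so combine_conditions later sees [Some(x > 10), None] and emits no WHERE
+                // clause at all, scanning every row of t during population.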
+ for table_name in table_map.keys() { + let conditions = table_conditions.get_mut(table_name).unwrap(); + conditions.push(None); + } + } + + Ok(()) + } + + /// Extract all tables and their aliases from the SELECT statement, handling CTEs + /// Deduplicates tables and accumulates WHERE conditions + fn extract_all_tables( + select: &ast::Select, + schema: &Schema, + tables: &mut Vec>, + aliases: &mut HashMap, + qualified_names: &mut HashMap, + table_conditions: &mut HashMap>>, + ) -> Result<()> { + let mut table_map = HashMap::new(); + Self::extract_all_tables_inner( + select, + schema, + &mut table_map, + aliases, + qualified_names, + table_conditions, + &HashSet::new(), + )?; + + // Convert deduplicated table map to vector + for (_name, table) in table_map { + tables.push(table); + } + + Ok(()) + } + + fn extract_all_tables_inner( + select: &ast::Select, + schema: &Schema, + table_map: &mut HashMap>, + aliases: &mut HashMap, + qualified_names: &mut HashMap, + table_conditions: &mut HashMap>>, + parent_cte_names: &HashSet, + ) -> Result<()> { + let mut cte_names = parent_cte_names.clone(); + + // First, collect CTE names and process any CTEs (WITH clauses) + if let Some(ref with) = select.with { + // First pass: collect all CTE names (needed for recursive CTEs) + for cte in &with.ctes { + cte_names.insert(cte.tbl_name.as_str().to_string()); + } + + // Second pass: extract tables from each CTE's SELECT statement + for cte in &with.ctes { + // Recursively extract tables from each CTE's SELECT statement + Self::extract_all_tables_inner( + &cte.select, + schema, + table_map, + aliases, + qualified_names, + table_conditions, + &cte_names, + )?; + } + } + + // Then process the main SELECT body + Self::extract_one_statement( + &select.body.select, + schema, + table_map, + aliases, + qualified_names, + table_conditions, + &cte_names, + )?; + + // Process any compound selects (UNION, etc.) + for c in &select.body.compounds { + let ast::CompoundSelect { select, .. } = c; + Self::extract_one_statement( + select, + schema, + table_map, + aliases, + qualified_names, + table_conditions, + &cte_names, + )?; + } + + Ok(()) } /// Generate SQL queries for populating the view from each source table /// Returns a vector of SQL statements, one for each referenced table - /// Each query includes only the WHERE conditions relevant to that specific table + /// Each query includes the WHERE conditions accumulated from all occurrences fn sql_for_populate(&self) -> crate::Result> { - if self.referenced_tables.is_empty() { + Self::generate_populate_queries( + &self.select_stmt, + &self.referenced_tables, + &self.table_aliases, + &self.qualified_table_names, + &self.table_conditions, + ) + } + + pub fn generate_populate_queries( + select_stmt: &ast::Select, + referenced_tables: &[Arc], + table_aliases: &HashMap, + qualified_table_names: &HashMap, + table_conditions: &HashMap>>, + ) -> crate::Result> { + if referenced_tables.is_empty() { return Err(LimboError::ParseError( "No tables to populate from".to_string(), )); @@ -503,12 +670,11 @@ impl IncrementalView { let mut queries = Vec::new(); - for table in &self.referenced_tables { + for table in referenced_tables { // Check if the table has a rowid alias (INTEGER PRIMARY KEY column) let has_rowid_alias = table.columns.iter().any(|col| col.is_rowid_alias); - // For now, select all columns since we don't have the static operators - // The circuit will handle filtering and projection + // Select all columns. 
The circuit will handle filtering and projection // If there's a rowid alias, we don't need to select rowid separately let select_clause = if has_rowid_alias { "*".to_string() @@ -516,12 +682,22 @@ impl IncrementalView { "*, rowid".to_string() }; - // Extract WHERE conditions for this specific table - let where_clause = self.extract_where_clause_for_table(&table.name)?; + // Get accumulated WHERE conditions for this table + let where_clause = if let Some(conditions) = table_conditions.get(&table.name) { + // Combine multiple conditions with OR if there are multiple occurrences + Self::combine_conditions( + select_stmt, + conditions, + &table.name, + referenced_tables, + table_aliases, + )? + } else { + String::new() + }; // Use the qualified table name if available, otherwise just the table name - let table_name = self - .qualified_table_names + let table_name = qualified_table_names .get(&table.name) .cloned() .unwrap_or_else(|| table.name.clone()); @@ -532,347 +708,405 @@ impl IncrementalView { } else { format!("SELECT {select_clause} FROM {table_name} WHERE {where_clause}") }; + tracing::debug!("populating materialized view with `{query}`"); queries.push(query); } Ok(queries) } - /// Extract WHERE conditions that apply to a specific table - /// This analyzes the WHERE clause in the SELECT statement and returns - /// only the conditions that reference the given table - fn extract_where_clause_for_table(&self, table_name: &str) -> crate::Result { - // For single table queries, return the entire WHERE clause (already unqualified) - if self.referenced_tables.len() == 1 { - if let ast::OneSelect::Select { - where_clause: Some(ref where_expr), - .. - } = self.select_stmt.body.select - { - // For single table, the expression should already be unqualified or qualified with the single table - // We need to unqualify it for the single-table query - let unqualified = self.unqualify_expression(where_expr, table_name); - return Ok(unqualified.to_string()); - } + fn combine_conditions( + _select_stmt: &ast::Select, + conditions: &[Option], + table_name: &str, + _referenced_tables: &[Arc], + table_aliases: &HashMap, + ) -> crate::Result { + // Check if any conditions are None (SELECTs without WHERE) + let has_none = conditions.iter().any(|c| c.is_none()); + let non_empty: Vec<_> = conditions.iter().filter_map(|c| c.as_ref()).collect(); + + // If we have both Some and None conditions, that means in some of the expressions where + // this table appear we want all rows. So we need to fetch all rows. + if has_none && !non_empty.is_empty() { return Ok(String::new()); } - // For multi-table queries (JOINs), extract conditions for the specific table - if let ast::OneSelect::Select { - where_clause: Some(ref where_expr), - .. 
- } = self.select_stmt.body.select - { - // Extract conditions that reference only the specified table - let table_conditions = self.extract_table_conditions(where_expr, table_name)?; - if let Some(conditions) = table_conditions { - // Unqualify the expression for single-table query - let unqualified = self.unqualify_expression(&conditions, table_name); - return Ok(unqualified.to_string()); - } + if non_empty.is_empty() { + return Ok(String::new()); } - Ok(String::new()) + if non_empty.len() == 1 { + // Unqualify the expression before converting to string + let unqualified = Self::unqualify_expression(non_empty[0], table_name, table_aliases); + return Ok(unqualified.to_string()); + } + + // Multiple conditions - combine with OR + // This happens in UNION ALL when the same table appears multiple times + let mut combined_parts = Vec::new(); + for condition in non_empty { + let unqualified = Self::unqualify_expression(condition, table_name, table_aliases); + // Wrap each condition in parentheses to preserve precedence + combined_parts.push(format!("({unqualified})")); + } + + // Join all conditions with OR + Ok(combined_parts.join(" OR ")) + } + /// Resolve a table alias to the actual table name + /// Check if an expression is a simple comparison that can be safely extracted + /// This excludes subqueries, CASE expressions, function calls, etc. + fn is_simple_comparison(expr: &ast::Expr) -> bool { + match expr { + // Simple column references and literals are OK + ast::Expr::Column { .. } | ast::Expr::Literal(_) => true, + + // Simple binary operations between simple expressions are OK + ast::Expr::Binary(left, op, right) => { + match op { + // Logical operators + ast::Operator::And | ast::Operator::Or => { + Self::is_simple_comparison(left) && Self::is_simple_comparison(right) + } + // Comparison operators + ast::Operator::Equals + | ast::Operator::NotEquals + | ast::Operator::Less + | ast::Operator::LessEquals + | ast::Operator::Greater + | ast::Operator::GreaterEquals + | ast::Operator::Is + | ast::Operator::IsNot => { + Self::is_simple_comparison(left) && Self::is_simple_comparison(right) + } + // String concatenation and other operations are NOT simple + ast::Operator::Concat => false, + // Arithmetic might be OK if operands are simple + ast::Operator::Add + | ast::Operator::Subtract + | ast::Operator::Multiply + | ast::Operator::Divide + | ast::Operator::Modulus => { + Self::is_simple_comparison(left) && Self::is_simple_comparison(right) + } + _ => false, + } + } + + // Unary operations might be OK + ast::Expr::Unary( + ast::UnaryOperator::Not + | ast::UnaryOperator::Negative + | ast::UnaryOperator::Positive, + inner, + ) => Self::is_simple_comparison(inner), + ast::Expr::Unary(_, _) => false, + + // Complex expressions are NOT simple + ast::Expr::Case { .. } => false, + ast::Expr::Cast { .. } => false, + ast::Expr::Collate { .. } => false, + ast::Expr::Exists(_) => false, + ast::Expr::FunctionCall { .. } => false, + ast::Expr::InList { .. } => false, + ast::Expr::InSelect { .. } => false, + ast::Expr::Like { .. } => false, + ast::Expr::NotNull(_) => true, // IS NOT NULL is simple enough + ast::Expr::Parenthesized(exprs) => { + // Parenthesized expression can contain multiple expressions + // Only consider it simple if it has exactly one simple expression + exprs.len() == 1 && Self::is_simple_comparison(&exprs[0]) + } + ast::Expr::Subquery(_) => false, + + // BETWEEN might be OK if all operands are simple + ast::Expr::Between { .. 
} => { + // BETWEEN has a different structure, for safety just exclude it + false + } + + // Qualified references are simple + ast::Expr::DoublyQualified(..) => true, + ast::Expr::Qualified(_, _) => true, + + // These are simple + ast::Expr::Id(_) => true, + ast::Expr::Name(_) => true, + + // Anything else is not simple + _ => false, + } } - /// Extract conditions from an expression that reference only the specified table - fn extract_table_conditions( - &self, + /// Extract conditions from a WHERE clause that apply to a specific table + fn extract_conditions_for_table( expr: &ast::Expr, table_name: &str, - ) -> crate::Result> { + aliases: &HashMap, + all_tables: &[String], + schema: &Schema, + ) -> Option { match expr { ast::Expr::Binary(left, op, right) => { match op { ast::Operator::And => { // For AND, we can extract conditions independently - let left_cond = self.extract_table_conditions(left, table_name)?; - let right_cond = self.extract_table_conditions(right, table_name)?; + let left_cond = Self::extract_conditions_for_table( + left, table_name, aliases, all_tables, schema, + ); + let right_cond = Self::extract_conditions_for_table( + right, table_name, aliases, all_tables, schema, + ); match (left_cond, right_cond) { - (Some(l), Some(r)) => { - // Both conditions apply to this table - Ok(Some(ast::Expr::Binary( - Box::new(l), - ast::Operator::And, - Box::new(r), - ))) - } - (Some(l), None) => Ok(Some(l)), - (None, Some(r)) => Ok(Some(r)), - (None, None) => Ok(None), + (Some(l), Some(r)) => Some(ast::Expr::Binary( + Box::new(l), + ast::Operator::And, + Box::new(r), + )), + (Some(l), None) => Some(l), + (None, Some(r)) => Some(r), + (None, None) => None, } } ast::Operator::Or => { - // For OR, both sides must reference the same table(s) - // If either side references multiple tables, we can't extract it - let left_tables = self.get_referenced_tables_in_expr(left)?; - let right_tables = self.get_referenced_tables_in_expr(right)?; + // For OR, both sides must reference only our table + let left_tables = + Self::get_tables_in_expr(left, aliases, all_tables, schema); + let right_tables = + Self::get_tables_in_expr(right, aliases, all_tables, schema); - // If both sides only reference our table, include the whole OR if left_tables.len() == 1 && left_tables.contains(&table_name.to_string()) && right_tables.len() == 1 && right_tables.contains(&table_name.to_string()) + && Self::is_simple_comparison(expr) { - Ok(Some(expr.clone())) + Some(expr.clone()) } else { - // OR condition involves multiple tables, can't extract - Ok(None) + None } } _ => { - // For comparison operators, check if this condition references only our table - // AND is simple enough to be pushed down (no complex expressions) - let referenced_tables = self.get_referenced_tables_in_expr(expr)?; + // For comparison operators, check if this condition only references our table + let referenced_tables = + Self::get_tables_in_expr(expr, aliases, all_tables, schema); if referenced_tables.len() == 1 && referenced_tables.contains(&table_name.to_string()) + && Self::is_simple_comparison(expr) { - // Check if this is a simple comparison that can be pushed down - // Complex expressions like (a * b) >= c should be handled by the circuit - if self.is_simple_comparison(expr) { - Ok(Some(expr.clone())) - } else { - // Complex expression - let the circuit handle it - Ok(None) - } + Some(expr.clone()) } else { - Ok(None) + None } } } } - ast::Expr::Parenthesized(exprs) => { - if exprs.len() == 1 { - self.extract_table_conditions(&exprs[0], 
table_name) - } else { - Ok(None) - } - } _ => { - // For other expressions, check if they reference only our table - // AND are simple enough to be pushed down - let referenced_tables = self.get_referenced_tables_in_expr(expr)?; + // For other expressions, check if they only reference our table + let referenced_tables = Self::get_tables_in_expr(expr, aliases, all_tables, schema); if referenced_tables.len() == 1 && referenced_tables.contains(&table_name.to_string()) - && self.is_simple_comparison(expr) + && Self::is_simple_comparison(expr) { - Ok(Some(expr.clone())) + Some(expr.clone()) } else { - Ok(None) + None } } } } - /// Check if an expression is a simple comparison that can be pushed down to table scan - /// Returns true for simple comparisons like "column = value" or "column > value" - /// Returns false for complex expressions like "(a * b) > value" - fn is_simple_comparison(&self, expr: &ast::Expr) -> bool { - match expr { - ast::Expr::Binary(left, op, right) => { - // Check if it's a comparison operator - matches!( - op, - ast::Operator::Equals - | ast::Operator::NotEquals - | ast::Operator::Greater - | ast::Operator::GreaterEquals - | ast::Operator::Less - | ast::Operator::LessEquals - ) && self.is_simple_operand(left) - && self.is_simple_operand(right) - } - _ => false, - } - } - - /// Check if an operand is simple (column reference or literal) - fn is_simple_operand(&self, expr: &ast::Expr) -> bool { - matches!( - expr, - ast::Expr::Id(_) - | ast::Expr::Qualified(_, _) - | ast::Expr::DoublyQualified(_, _, _) - | ast::Expr::Literal(_) - ) - } - - /// Get the set of table names referenced in an expression - fn get_referenced_tables_in_expr(&self, expr: &ast::Expr) -> crate::Result> { - let mut tables = Vec::new(); - self.collect_referenced_tables(expr, &mut tables)?; - // Deduplicate - tables.sort(); - tables.dedup(); - Ok(tables) - } - - /// Recursively collect table references from an expression - fn collect_referenced_tables( - &self, + /// Unqualify column references in an expression + /// Removes table/alias prefixes from qualified column names + fn unqualify_expression( expr: &ast::Expr, - tables: &mut Vec, - ) -> crate::Result<()> { + table_name: &str, + aliases: &HashMap, + ) -> ast::Expr { match expr { - ast::Expr::Binary(left, _, right) => { - self.collect_referenced_tables(left, tables)?; - self.collect_referenced_tables(right, tables)?; - } - ast::Expr::Qualified(table, _) => { - // This is a qualified column reference (table.column or alias.column) - // We need to resolve aliases to actual table names - let actual_table = self.resolve_table_alias(table.as_str()); - tables.push(actual_table); - } - ast::Expr::Id(column) => { - // Unqualified column reference - if self.referenced_tables.len() > 1 { - // In a JOIN context, check which tables have this column - let mut tables_with_column = Vec::new(); - for table in &self.referenced_tables { - if table - .columns - .iter() - .any(|c| c.name.as_ref() == Some(&column.to_string())) - { - tables_with_column.push(table.name.clone()); - } - } - - if tables_with_column.len() > 1 { - // Ambiguous column - this should have been caught earlier - // Return error to be safe - return Err(crate::LimboError::ParseError(format!( - "Ambiguous column name '{}' in WHERE clause - exists in tables: {}", - column, - tables_with_column.join(", ") - ))); - } else if tables_with_column.len() == 1 { - // Unambiguous - only one table has this column - // This is allowed by SQLite - tables.push(tables_with_column[0].clone()); - } else { - // 
Column doesn't exist in any table - this is an error - // but should be caught during compilation - return Err(crate::LimboError::ParseError(format!( - "Column '{column}' not found in any table" - ))); - } - } else { - // Single table context - unqualified columns belong to that table - if let Some(table) = self.referenced_tables.first() { - tables.push(table.name.clone()); - } - } - } - ast::Expr::DoublyQualified(_database, table, _column) => { - // For database.table.column, resolve the table name - let table_str = table.as_str(); - let actual_table = self.resolve_table_alias(table_str); - tables.push(actual_table); - } - ast::Expr::Parenthesized(exprs) => { - for e in exprs { - self.collect_referenced_tables(e, tables)?; - } - } - _ => { - // Literals and other expressions don't reference tables - } - } - Ok(()) - } - - /// Convert a qualified expression to unqualified for single-table queries - /// This removes table prefixes from column references since they're not needed - /// when querying a single table - fn unqualify_expression(&self, expr: &ast::Expr, table_name: &str) -> ast::Expr { - match expr { - ast::Expr::Binary(left, op, right) => { - // Recursively unqualify both sides - ast::Expr::Binary( - Box::new(self.unqualify_expression(left, table_name)), - *op, - Box::new(self.unqualify_expression(right, table_name)), - ) - } - ast::Expr::Qualified(table, column) => { - // Convert qualified column to unqualified if it's for our table - // Handle both "table.column" and "database.table.column" cases - let table_str = table.as_str(); - - // Check if this is a database.table reference - let actual_table = if table_str.contains('.') { - // Split on '.' and take the last part as the table name + ast::Expr::Binary(left, op, right) => ast::Expr::Binary( + Box::new(Self::unqualify_expression(left, table_name, aliases)), + *op, + Box::new(Self::unqualify_expression(right, table_name, aliases)), + ), + ast::Expr::Qualified(table_or_alias, column) => { + // Check if this qualification refers to our table + let table_str = table_or_alias.as_str(); + let actual_table = if let Some(actual) = aliases.get(table_str) { + actual.clone() + } else if table_str.contains('.') { + // Handle database.table format table_str .split('.') .next_back() .unwrap_or(table_str) .to_string() } else { - // Could be an alias or direct table name - self.resolve_table_alias(table_str) + table_str.to_string() }; if actual_table == table_name { - // Just return the column name without qualification + // Remove the qualification ast::Expr::Id(column.clone()) } else { - // This shouldn't happen if extract_table_conditions worked correctly - // but keep it qualified just in case + // Keep the qualification (shouldn't happen if extraction worked correctly) expr.clone() } } ast::Expr::DoublyQualified(_database, table, column) => { - // This is database.table.column format - // Check if the table matches our target table - let table_str = table.as_str(); - let actual_table = self.resolve_table_alias(table_str); - - if actual_table == table_name { - // Just return the column name without qualification + // Check if this refers to our table + if table.as_str() == table_name { + // Remove the qualification, keep just the column ast::Expr::Id(column.clone()) } else { - // Keep it qualified if it's for a different table + // Keep the qualification (shouldn't happen if extraction worked correctly) expr.clone() } } - ast::Expr::Parenthesized(exprs) => { - // Recursively unqualify expressions in parentheses - let unqualified_exprs: 
Vec> = exprs + ast::Expr::Unary(op, inner) => ast::Expr::Unary( + *op, + Box::new(Self::unqualify_expression(inner, table_name, aliases)), + ), + ast::Expr::FunctionCall { + name, + args, + distinctness, + filter_over, + order_by, + } => ast::Expr::FunctionCall { + name: name.clone(), + args: args .iter() - .map(|e| Box::new(self.unqualify_expression(e, table_name))) - .collect(); - ast::Expr::Parenthesized(unqualified_exprs) + .map(|arg| Box::new(Self::unqualify_expression(arg, table_name, aliases))) + .collect(), + distinctness: *distinctness, + filter_over: filter_over.clone(), + order_by: order_by.clone(), + }, + ast::Expr::InList { lhs, not, rhs } => ast::Expr::InList { + lhs: Box::new(Self::unqualify_expression(lhs, table_name, aliases)), + not: *not, + rhs: rhs + .iter() + .map(|item| Box::new(Self::unqualify_expression(item, table_name, aliases))) + .collect(), + }, + ast::Expr::Between { + lhs, + not, + start, + end, + } => ast::Expr::Between { + lhs: Box::new(Self::unqualify_expression(lhs, table_name, aliases)), + not: *not, + start: Box::new(Self::unqualify_expression(start, table_name, aliases)), + end: Box::new(Self::unqualify_expression(end, table_name, aliases)), + }, + _ => expr.clone(), + } + } + + /// Get all tables referenced in an expression + fn get_tables_in_expr( + expr: &ast::Expr, + aliases: &HashMap, + all_tables: &[String], + schema: &Schema, + ) -> Vec { + let mut tables = Vec::new(); + Self::collect_tables_in_expr(expr, aliases, all_tables, schema, &mut tables); + tables.sort(); + tables.dedup(); + tables + } + + /// Recursively collect table references from an expression + fn collect_tables_in_expr( + expr: &ast::Expr, + aliases: &HashMap, + all_tables: &[String], + schema: &Schema, + tables: &mut Vec, + ) { + match expr { + ast::Expr::Binary(left, _, right) => { + Self::collect_tables_in_expr(left, aliases, all_tables, schema, tables); + Self::collect_tables_in_expr(right, aliases, all_tables, schema, tables); + } + ast::Expr::Qualified(table_or_alias, _) => { + // Handle database.table or just table/alias + let table_str = table_or_alias.as_str(); + let table_name = if let Some(actual_table) = aliases.get(table_str) { + // It's an alias + actual_table.clone() + } else if table_str.contains('.') { + // It might be database.table format, extract just the table name + table_str + .split('.') + .next_back() + .unwrap_or(table_str) + .to_string() + } else { + // It's a direct table name + table_str.to_string() + }; + tables.push(table_name); + } + ast::Expr::DoublyQualified(_database, table, _column) => { + // For database.table.column, extract the table name + tables.push(table.to_string()); + } + ast::Expr::Id(column) => { + // Unqualified column - try to find which table has this column + if all_tables.len() == 1 { + tables.push(all_tables[0].clone()); + } else { + // Check which table has this column + for table_name in all_tables { + if let Some(table) = schema.get_btree_table(table_name) { + if table + .columns + .iter() + .any(|col| col.name.as_deref() == Some(column.as_str())) + { + tables.push(table_name.clone()); + break; // Found the table, stop looking + } + } + } + } + } + ast::Expr::FunctionCall { args, .. } => { + for arg in args { + Self::collect_tables_in_expr(arg, aliases, all_tables, schema, tables); + } + } + ast::Expr::InList { lhs, rhs, .. 
} => { + Self::collect_tables_in_expr(lhs, aliases, all_tables, schema, tables); + for item in rhs { + Self::collect_tables_in_expr(item, aliases, all_tables, schema, tables); + } + } + ast::Expr::InSelect { lhs, .. } => { + Self::collect_tables_in_expr(lhs, aliases, all_tables, schema, tables); + } + ast::Expr::Between { + lhs, start, end, .. + } => { + Self::collect_tables_in_expr(lhs, aliases, all_tables, schema, tables); + Self::collect_tables_in_expr(start, aliases, all_tables, schema, tables); + Self::collect_tables_in_expr(end, aliases, all_tables, schema, tables); + } + ast::Expr::Unary(_, expr) => { + Self::collect_tables_in_expr(expr, aliases, all_tables, schema, tables); } _ => { - // Other expression types (literals, unqualified columns, etc.) stay as-is - expr.clone() + // Literals, etc. don't reference tables } } } - - /// Resolve a table alias to the actual table name - fn resolve_table_alias(&self, alias: &str) -> String { - // Check if there's an alias mapping in the FROM/JOIN clauses - // For now, we'll do a simple check - if the alias matches a table name, use it - // Otherwise, try to find it in the FROM clause - - // First check if it's an actual table name - if self.referenced_tables.iter().any(|t| t.name == alias) { - return alias.to_string(); - } - - // Check if it's an alias that maps to a table - if let Some(table_name) = self.table_aliases.get(alias) { - return table_name.clone(); - } - - // If we can't resolve it, return as-is (it might be a table name we don't know about) - alias.to_string() - } - /// Populate the view by scanning the source table using a state machine /// This can be called multiple times and will resume from where it left off /// This method is only for materialized views and will persist data to the btree @@ -1342,17 +1576,58 @@ mod tests { } } + // Type alias for the complex return type of extract_all_tables + type ExtractedTableInfo = ( + Vec>, + HashMap, + HashMap, + HashMap>>, + ); + + fn extract_all_tables(select: &ast::Select, schema: &Schema) -> Result { + let mut referenced_tables = Vec::new(); + let mut table_aliases = HashMap::new(); + let mut qualified_table_names = HashMap::new(); + let mut table_conditions = HashMap::new(); + IncrementalView::extract_all_tables( + select, + schema, + &mut referenced_tables, + &mut table_aliases, + &mut qualified_table_names, + &mut table_conditions, + )?; + Ok(( + referenced_tables, + table_aliases, + qualified_table_names, + table_conditions, + )) + } + #[test] fn test_extract_single_table() { let schema = create_test_schema(); let select = parse_select("SELECT * FROM customers"); - let (tables, _, _) = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, _, _, _table_conditions) = extract_all_tables(&select, &schema).unwrap(); assert_eq!(tables.len(), 1); assert_eq!(tables[0].name, "customers"); } + #[test] + fn test_tables_from_union() { + let schema = create_test_schema(); + let select = parse_select("SELECT name FROM customers union SELECT name from products"); + + let (tables, _, _, table_conditions) = extract_all_tables(&select, &schema).unwrap(); + + assert_eq!(tables.len(), 2); + assert!(table_conditions.contains_key("customers")); + assert!(table_conditions.contains_key("products")); + } + #[test] fn test_extract_tables_from_inner_join() { let schema = create_test_schema(); @@ -1360,11 +1635,11 @@ mod tests { "SELECT * FROM customers INNER JOIN orders ON customers.id = orders.customer_id", ); - let (tables, _, _) = IncrementalView::extract_all_tables(&select, 
&schema).unwrap(); + let (tables, _, _, table_conditions) = extract_all_tables(&select, &schema).unwrap(); assert_eq!(tables.len(), 2); - assert_eq!(tables[0].name, "customers"); - assert_eq!(tables[1].name, "orders"); + assert!(table_conditions.contains_key("customers")); + assert!(table_conditions.contains_key("orders")); } #[test] @@ -1376,12 +1651,12 @@ mod tests { INNER JOIN products ON orders.id = products.id", ); - let (tables, _, _) = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, _, _, table_conditions) = extract_all_tables(&select, &schema).unwrap(); assert_eq!(tables.len(), 3); - assert_eq!(tables[0].name, "customers"); - assert_eq!(tables[1].name, "orders"); - assert_eq!(tables[2].name, "products"); + assert!(table_conditions.contains_key("customers")); + assert!(table_conditions.contains_key("orders")); + assert!(table_conditions.contains_key("products")); } #[test] @@ -1391,11 +1666,11 @@ mod tests { "SELECT * FROM customers LEFT JOIN orders ON customers.id = orders.customer_id", ); - let (tables, _, _) = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, _, _, table_conditions) = extract_all_tables(&select, &schema).unwrap(); assert_eq!(tables.len(), 2); - assert_eq!(tables[0].name, "customers"); - assert_eq!(tables[1].name, "orders"); + assert!(table_conditions.contains_key("customers")); + assert!(table_conditions.contains_key("orders")); } #[test] @@ -1403,11 +1678,11 @@ mod tests { let schema = create_test_schema(); let select = parse_select("SELECT * FROM customers CROSS JOIN orders"); - let (tables, _, _) = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, _, _, table_conditions) = extract_all_tables(&select, &schema).unwrap(); assert_eq!(tables.len(), 2); - assert_eq!(tables[0].name, "customers"); - assert_eq!(tables[1].name, "orders"); + assert!(table_conditions.contains_key("customers")); + assert!(table_conditions.contains_key("orders")); } #[test] @@ -1416,12 +1691,17 @@ mod tests { let select = parse_select("SELECT * FROM customers c INNER JOIN orders o ON c.id = o.customer_id"); - let (tables, _, _) = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, aliases, _, _table_conditions) = extract_all_tables(&select, &schema).unwrap(); // Should still extract the actual table names, not aliases assert_eq!(tables.len(), 2); - assert_eq!(tables[0].name, "customers"); - assert_eq!(tables[1].name, "orders"); + let table_names: Vec<&str> = tables.iter().map(|t| t.name.as_str()).collect(); + assert!(table_names.contains(&"customers")); + assert!(table_names.contains(&"orders")); + + // Check that aliases are correctly mapped + assert_eq!(aliases.get("c"), Some(&"customers".to_string())); + assert_eq!(aliases.get("o"), Some(&"orders".to_string())); } #[test] @@ -1429,8 +1709,7 @@ mod tests { let schema = create_test_schema(); let select = parse_select("SELECT * FROM nonexistent"); - let result = - IncrementalView::extract_all_tables(&select, &schema).map(|(tables, _, _)| tables); + let result = extract_all_tables(&select, &schema).map(|(tables, _, _, _)| tables); assert!(result.is_err()); assert!(result @@ -1446,8 +1725,7 @@ mod tests { "SELECT * FROM customers INNER JOIN nonexistent ON customers.id = nonexistent.id", ); - let result = - IncrementalView::extract_all_tables(&select, &schema).map(|(tables, _, _)| tables); + let result = extract_all_tables(&select, &schema).map(|(tables, _, _, _)| tables); assert!(result.is_err()); assert!(result @@ -1462,14 
+1740,15 @@ mod tests { let schema = create_test_schema(); let select = parse_select("SELECT * FROM customers"); - let (tables, aliases, qualified_names) = - IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, aliases, qualified_names, table_conditions) = + extract_all_tables(&select, &schema).unwrap(); let view = IncrementalView::new( "test_view".to_string(), select.clone(), tables, aliases, qualified_names, + table_conditions, extract_view_columns(&select, &schema).unwrap(), &schema, 1, // main_data_root @@ -1491,14 +1770,15 @@ mod tests { let schema = create_test_schema(); let select = parse_select("SELECT * FROM customers WHERE id > 10"); - let (tables, aliases, qualified_names) = - IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, aliases, qualified_names, table_conditions) = + extract_all_tables(&select, &schema).unwrap(); let view = IncrementalView::new( "test_view".to_string(), select.clone(), tables, aliases, qualified_names, + table_conditions, extract_view_columns(&select, &schema).unwrap(), &schema, 1, // main_data_root @@ -1524,14 +1804,15 @@ mod tests { WHERE c.id > 10 AND o.total > 100", ); - let (tables, aliases, qualified_names) = - IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, aliases, qualified_names, table_conditions) = + extract_all_tables(&select, &schema).unwrap(); let view = IncrementalView::new( "test_view".to_string(), select.clone(), tables, aliases, qualified_names, + table_conditions, extract_view_columns(&select, &schema).unwrap(), &schema, 1, // main_data_root @@ -1547,8 +1828,12 @@ mod tests { // With per-table WHERE extraction: // - customers table gets: c.id > 10 // - orders table gets: o.total > 100 - assert_eq!(queries[0], "SELECT * FROM customers WHERE id > 10"); - assert_eq!(queries[1], "SELECT * FROM orders WHERE total > 100"); + assert!(queries + .iter() + .any(|q| q == "SELECT * FROM customers WHERE id > 10")); + assert!(queries + .iter() + .any(|q| q == "SELECT * FROM orders WHERE total > 100")); } #[test] @@ -1562,14 +1847,15 @@ mod tests { AND o.customer_id = 5 AND (c.id = 15 OR o.total = 200)", ); - let (tables, aliases, qualified_names) = - IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, aliases, qualified_names, table_conditions) = + extract_all_tables(&select, &schema).unwrap(); let view = IncrementalView::new( "test_view".to_string(), select.clone(), tables, aliases, qualified_names, + table_conditions, extract_view_columns(&select, &schema).unwrap(), &schema, 1, // main_data_root @@ -1587,152 +1873,27 @@ mod tests { // - orders gets: o.total > 100 AND o.customer_id = 5 // Note: The OR condition (c.id = 15 OR o.total = 200) involves both tables, // so it cannot be extracted to either table individually - assert_eq!( - queries[0], - "SELECT * FROM customers WHERE id > 10 AND name = 'John'" - ); - assert_eq!( - queries[1], - "SELECT * FROM orders WHERE total > 100 AND customer_id = 5" - ); - } - - #[test] - fn test_where_extraction_for_three_tables() { - // Test that WHERE clause extraction correctly separates conditions for 3+ tables - // This addresses the concern about conditions "piling up" as joins increase - - // Simulate a three-table scenario - let schema = create_test_schema(); - - // Parse a WHERE clause with conditions for three different tables - let select = parse_select( - "SELECT * FROM customers WHERE c.id > 10 AND o.total > 100 AND p.price > 50", - ); - - // Get the WHERE expression - if let 
ast::OneSelect::Select { - where_clause: Some(ref where_expr), - .. - } = select.body.select - { - // Create a view with three tables to test extraction - let tables = vec![ - schema.get_btree_table("customers").unwrap(), - schema.get_btree_table("orders").unwrap(), - schema.get_btree_table("products").unwrap(), - ]; - - let mut aliases = HashMap::new(); - aliases.insert("c".to_string(), "customers".to_string()); - aliases.insert("o".to_string(), "orders".to_string()); - aliases.insert("p".to_string(), "products".to_string()); - - // Create a minimal view just to test extraction logic - let view = IncrementalView { - name: "test".to_string(), - select_stmt: select.clone(), - circuit: DbspCircuit::new(1, 2, 3), - referenced_tables: tables, - table_aliases: aliases, - qualified_table_names: HashMap::new(), - column_schema: ViewColumnSchema { - columns: vec![], - tables: vec![], - }, - populate_state: PopulateState::Start, - tracker: Arc::new(Mutex::new(ComputationTracker::new())), - root_page: 0, - }; - - // Test extraction for each table - let customers_conds = view - .extract_table_conditions(where_expr, "customers") - .unwrap(); - let orders_conds = view.extract_table_conditions(where_expr, "orders").unwrap(); - let products_conds = view - .extract_table_conditions(where_expr, "products") - .unwrap(); - - // Verify each table only gets its conditions - if let Some(cond) = customers_conds { - let sql = cond.to_string(); - assert!(sql.contains("id > 10")); - assert!(!sql.contains("total")); - assert!(!sql.contains("price")); - } - - if let Some(cond) = orders_conds { - let sql = cond.to_string(); - assert!(sql.contains("total > 100")); - assert!(!sql.contains("id > 10")); // From customers - assert!(!sql.contains("price")); - } - - if let Some(cond) = products_conds { - let sql = cond.to_string(); - assert!(sql.contains("price > 50")); - assert!(!sql.contains("id > 10")); // From customers - assert!(!sql.contains("total")); - } - } else { - panic!("Failed to parse WHERE clause"); - } - } - - #[test] - fn test_alias_resolution_works_correctly() { - // Test that alias resolution properly maps aliases to table names - let schema = create_test_schema(); - let select = parse_select( - "SELECT * FROM customers c \ - JOIN orders o ON c.id = o.customer_id \ - WHERE c.id > 10 AND o.total > 100", - ); - - let (tables, aliases, qualified_names) = - IncrementalView::extract_all_tables(&select, &schema).unwrap(); - let view = IncrementalView::new( - "test_view".to_string(), - select.clone(), - tables, - aliases, - qualified_names, - extract_view_columns(&select, &schema).unwrap(), - &schema, - 1, // main_data_root - 2, // internal_state_root - 3, // internal_state_index_root - ) - .unwrap(); - - // Verify that alias mappings were extracted correctly - assert_eq!(view.table_aliases.get("c"), Some(&"customers".to_string())); - assert_eq!(view.table_aliases.get("o"), Some(&"orders".to_string())); - - // Verify that SQL generation uses the aliases correctly - let queries = view.sql_for_populate().unwrap(); - assert_eq!(queries.len(), 2); - - // Each query should use the actual table name, not the alias - assert!(queries[0].contains("FROM customers") || queries[1].contains("FROM customers")); - assert!(queries[0].contains("FROM orders") || queries[1].contains("FROM orders")); + // Check both queries exist (order doesn't matter) + assert!(queries + .contains(&"SELECT * FROM customers WHERE id > 10 AND name = 'John'".to_string())); + assert!(queries + .contains(&"SELECT * FROM orders WHERE total > 100 AND 
customer_id = 5".to_string())); } #[test] fn test_sql_for_populate_table_without_rowid_alias() { - // Test that tables without a rowid alias properly include rowid in SELECT let schema = create_test_schema(); let select = parse_select("SELECT * FROM logs WHERE level > 2"); - let (tables, aliases, qualified_names) = - IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, aliases, qualified_names, table_conditions) = + extract_all_tables(&select, &schema).unwrap(); let view = IncrementalView::new( "test_view".to_string(), select.clone(), tables, aliases, qualified_names, + table_conditions, extract_view_columns(&select, &schema).unwrap(), &schema, 1, // main_data_root @@ -1758,14 +1919,15 @@ mod tests { WHERE c.id > 10 AND l.level > 2", ); - let (tables, aliases, qualified_names) = - IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, aliases, qualified_names, table_conditions) = + extract_all_tables(&select, &schema).unwrap(); let view = IncrementalView::new( "test_view".to_string(), select.clone(), tables, aliases, qualified_names, + table_conditions, extract_view_columns(&select, &schema).unwrap(), &schema, 1, // main_data_root @@ -1778,8 +1940,8 @@ mod tests { assert_eq!(queries.len(), 2); // customers has rowid alias (id), logs doesn't - assert_eq!(queries[0], "SELECT * FROM customers WHERE id > 10"); - assert_eq!(queries[1], "SELECT *, rowid FROM logs WHERE level > 2"); + assert!(queries.contains(&"SELECT * FROM customers WHERE id > 10".to_string())); + assert!(queries.contains(&"SELECT *, rowid FROM logs WHERE level > 2".to_string())); } #[test] @@ -1792,14 +1954,15 @@ mod tests { // Test with single table using database qualification let select = parse_select("SELECT * FROM main.customers WHERE main.customers.id > 10"); - let (tables, aliases, qualified_names) = - IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, aliases, qualified_names, table_conditions) = + extract_all_tables(&select, &schema).unwrap(); let view = IncrementalView::new( "test_view".to_string(), select.clone(), tables, aliases, qualified_names, + table_conditions, extract_view_columns(&select, &schema).unwrap(), &schema, 1, // main_data_root @@ -1827,14 +1990,15 @@ mod tests { WHERE main.customers.id > 10 AND main.orders.total > 100", ); - let (tables, aliases, qualified_names) = - IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, aliases, qualified_names, table_conditions) = + extract_all_tables(&select, &schema).unwrap(); let view = IncrementalView::new( "test_view".to_string(), select.clone(), tables, aliases, qualified_names, + table_conditions, extract_view_columns(&select, &schema).unwrap(), &schema, 1, // main_data_root @@ -1848,8 +2012,93 @@ mod tests { assert_eq!(queries.len(), 2); // The FROM clauses should preserve database qualification, // but WHERE clauses should have unqualified column names - assert_eq!(queries[0], "SELECT * FROM main.customers WHERE id > 10"); - assert_eq!(queries[1], "SELECT * FROM main.orders WHERE total > 100"); + assert!(queries.contains(&"SELECT * FROM main.customers WHERE id > 10".to_string())); + assert!(queries.contains(&"SELECT * FROM main.orders WHERE total > 100".to_string())); + } + + #[test] + fn test_where_extraction_for_three_tables_with_aliases() { + // Test that WHERE clause extraction correctly separates conditions for 3+ tables + // This addresses the concern about conditions "piling up" as joins increase + let schema = create_test_schema(); + let select = 
parse_select( + "SELECT * FROM customers c + JOIN orders o ON c.id = o.customer_id + JOIN products p ON p.id = o.product_id + WHERE c.id > 10 AND o.total > 100 AND p.price > 50", + ); + + let (tables, aliases, qualified_names, table_conditions) = + extract_all_tables(&select, &schema).unwrap(); + + // Verify we extracted all three tables + assert_eq!(tables.len(), 3); + let table_names: Vec<&str> = tables.iter().map(|t| t.name.as_str()).collect(); + assert!(table_names.contains(&"customers")); + assert!(table_names.contains(&"orders")); + assert!(table_names.contains(&"products")); + + // Verify aliases are correctly mapped + assert_eq!(aliases.get("c"), Some(&"customers".to_string())); + assert_eq!(aliases.get("o"), Some(&"orders".to_string())); + assert_eq!(aliases.get("p"), Some(&"products".to_string())); + + // Generate populate queries to verify each table gets its own conditions + let queries = IncrementalView::generate_populate_queries( + &select, + &tables, + &aliases, + &qualified_names, + &table_conditions, + ) + .unwrap(); + + assert_eq!(queries.len(), 3); + + // Verify the exact queries generated for each table + // The order might vary, so check all possibilities + let expected_queries = vec![ + "SELECT * FROM customers WHERE id > 10", + "SELECT * FROM orders WHERE total > 100", + "SELECT * FROM products WHERE price > 50", + ]; + + for expected in &expected_queries { + assert!( + queries.contains(&expected.to_string()), + "Missing expected query: {expected}. Got: {queries:?}" + ); + } + } + + #[test] + fn test_sql_for_populate_complex_expressions_not_included() { + // Test that complex expressions (subqueries, CASE, string concat) are NOT included in populate queries + let schema = create_test_schema(); + let select = parse_select( + "SELECT * FROM customers + WHERE id > (SELECT MAX(customer_id) FROM orders) + AND name || ' Customer' = 'John Customer' + AND CASE WHEN id > 10 THEN 1 ELSE 0 END = 1 + AND EXISTS (SELECT 1 FROM orders WHERE customer_id = customers.id)", + ); + + let (tables, aliases, qualified_names, table_conditions) = + extract_all_tables(&select, &schema).unwrap(); + + let queries = IncrementalView::generate_populate_queries( + &select, + &tables, + &aliases, + &qualified_names, + &table_conditions, + ) + .unwrap(); + + assert_eq!(queries.len(), 1); + // Since customers table has an INTEGER PRIMARY KEY (id), we should get SELECT * + // without rowid and without WHERE clause (all conditions are complex) + assert_eq!(queries[0], "SELECT * FROM customers"); } #[test] @@ -1862,14 +2111,15 @@ mod tests { WHERE total > 100", // 'total' only exists in orders table ); - let (tables, aliases, qualified_names) = - IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, aliases, qualified_names, table_conditions) = + extract_all_tables(&select, &schema).unwrap(); let view = IncrementalView::new( "test_view".to_string(), select.clone(), tables, aliases, qualified_names, + table_conditions, extract_view_columns(&select, &schema).unwrap(), &schema, 1, // main_data_root @@ -1883,8 +2133,8 @@ mod tests { assert_eq!(queries.len(), 2); // 'total' is unambiguous (only in orders), so it should be extracted - assert_eq!(queries[0], "SELECT * FROM customers"); - assert_eq!(queries[1], "SELECT * FROM orders WHERE total > 100"); + assert!(queries.contains(&"SELECT * FROM customers".to_string())); + assert!(queries.contains(&"SELECT * FROM orders WHERE total > 100".to_string())); } #[test] @@ -1899,8 +2149,8 @@ mod tests { WHERE c.id > 10", ); - let (tables, 
aliases, qualified_names) = - IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, aliases, qualified_names, table_conditions) = + extract_all_tables(&select, &schema).unwrap(); // Check that qualified names are preserved assert!(qualified_names.contains_key("customers")); @@ -1914,6 +2164,7 @@ mod tests { tables, aliases, qualified_names.clone(), + table_conditions, extract_view_columns(&select, &schema).unwrap(), &schema, 1, // main_data_root @@ -1928,8 +2179,8 @@ mod tests { // The FROM clause should contain the database-qualified name // But the WHERE clause should use unqualified column names - assert_eq!(queries[0], "SELECT * FROM main.customers WHERE id > 10"); - assert_eq!(queries[1], "SELECT * FROM main.orders"); + assert!(queries.contains(&"SELECT * FROM main.customers WHERE id > 10".to_string())); + assert!(queries.contains(&"SELECT * FROM main.orders".to_string())); } #[test] @@ -1944,8 +2195,8 @@ mod tests { WHERE c.id > 10 AND o.total < 1000", ); - let (tables, aliases, qualified_names) = - IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, aliases, qualified_names, table_conditions) = + extract_all_tables(&select, &schema).unwrap(); // Check that qualified names are preserved where specified assert_eq!(qualified_names.get("customers").unwrap(), "main.customers"); @@ -1961,6 +2212,7 @@ mod tests { tables, aliases, qualified_names.clone(), + table_conditions, extract_view_columns(&select, &schema).unwrap(), &schema, 1, // main_data_root @@ -1974,7 +2226,468 @@ mod tests { assert_eq!(queries.len(), 2); // The FROM clause should preserve qualification where specified - assert_eq!(queries[0], "SELECT * FROM main.customers WHERE id > 10"); - assert_eq!(queries[1], "SELECT * FROM orders WHERE total < 1000"); + assert!(queries.contains(&"SELECT * FROM main.customers WHERE id > 10".to_string())); + assert!(queries.contains(&"SELECT * FROM orders WHERE total < 1000".to_string())); + } + + #[test] + fn test_extract_tables_with_simple_cte() { + let schema = create_test_schema(); + let select = parse_select( + "WITH customer_totals AS ( + SELECT c.id, c.name, SUM(o.total) as total_spent + FROM customers c + JOIN orders o ON c.id = o.customer_id + GROUP BY c.id, c.name + ) + SELECT * FROM customer_totals WHERE total_spent > 1000", + ); + + let (tables, aliases, _qualified_names, _table_conditions) = + extract_all_tables(&select, &schema).unwrap(); + + // Check that we found both tables from the CTE + assert_eq!(tables.len(), 2); + let table_names: Vec<&str> = tables.iter().map(|t| t.name.as_str()).collect(); + assert!(table_names.contains(&"customers")); + assert!(table_names.contains(&"orders")); + + // Check aliases from the CTE + assert_eq!(aliases.get("c"), Some(&"customers".to_string())); + assert_eq!(aliases.get("o"), Some(&"orders".to_string())); + } + + #[test] + fn test_extract_tables_with_multiple_ctes() { + let schema = create_test_schema(); + let select = parse_select( + "WITH + high_value_customers AS ( + SELECT id, name + FROM customers + WHERE id IN (SELECT customer_id FROM orders WHERE total > 500) + ), + recent_orders AS ( + SELECT id, customer_id, total + FROM orders + WHERE id > 100 + ) + SELECT hvc.name, ro.total + FROM high_value_customers hvc + JOIN recent_orders ro ON hvc.id = ro.customer_id", + ); + + let (tables, _aliases, _qualified_names, _table_conditions) = + extract_all_tables(&select, &schema).unwrap(); + + // Check that we found both tables from both CTEs + assert_eq!(tables.len(), 2); + let table_names: 
Vec<&str> = tables.iter().map(|t| t.name.as_str()).collect(); + assert!(table_names.contains(&"customers")); + assert!(table_names.contains(&"orders")); + } + + #[test] + fn test_sql_for_populate_union_mixed_conditions() { + // Test UNION where same table appears with and without WHERE clause + // This should drop ALL conditions to ensure we get all rows + let schema = create_test_schema(); + + let select = parse_select( + "SELECT * FROM customers WHERE id > 10 + UNION ALL + SELECT * FROM customers", + ); + + let (tables, aliases, qualified_names, table_conditions) = + extract_all_tables(&select, &schema).unwrap(); + + let view = IncrementalView::new( + "union_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names, + table_conditions, + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 1); + // When the same table appears with and without WHERE conditions in a UNION, + // we must fetch ALL rows (no WHERE clause) because the conditions are incompatible + assert_eq!( + queries[0], "SELECT * FROM customers", + "UNION with mixed conditions (some with WHERE, some without) should fetch ALL rows" + ); + } + + #[test] + fn test_extract_tables_with_nested_cte() { + let schema = create_test_schema(); + let select = parse_select( + "WITH RECURSIVE customer_hierarchy AS ( + SELECT id, name, 0 as level + FROM customers + WHERE id = 1 + UNION ALL + SELECT c.id, c.name, ch.level + 1 + FROM customers c + JOIN orders o ON c.id = o.customer_id + JOIN customer_hierarchy ch ON o.customer_id = ch.id + WHERE ch.level < 3 + ) + SELECT * FROM customer_hierarchy", + ); + + let (tables, _aliases, _qualified_names, _table_conditions) = + extract_all_tables(&select, &schema).unwrap(); + + // Check that we found the tables referenced in the recursive CTE + let table_names: Vec<&str> = tables.iter().map(|t| t.name.as_str()).collect(); + + // We're finding duplicates because "customers" appears twice in the recursive CTE + // Let's deduplicate + let unique_tables: std::collections::HashSet<&str> = table_names.iter().cloned().collect(); + assert_eq!(unique_tables.len(), 2); + assert!(unique_tables.contains("customers")); + assert!(unique_tables.contains("orders")); + } + + #[test] + fn test_extract_tables_with_cte_and_main_query() { + let schema = create_test_schema(); + let select = parse_select( + "WITH customer_stats AS ( + SELECT customer_id, COUNT(*) as order_count + FROM orders + GROUP BY customer_id + ) + SELECT c.name, cs.order_count, p.name as product_name + FROM customers c + JOIN customer_stats cs ON c.id = cs.customer_id + JOIN products p ON p.id = 1", + ); + + let (tables, aliases, _qualified_names, _table_conditions) = + extract_all_tables(&select, &schema).unwrap(); + + // Check that we found tables from both the CTE and the main query + assert_eq!(tables.len(), 3); + let table_names: Vec<&str> = tables.iter().map(|t| t.name.as_str()).collect(); + assert!(table_names.contains(&"customers")); + assert!(table_names.contains(&"orders")); + assert!(table_names.contains(&"products")); + + // Check aliases from main query + assert_eq!(aliases.get("c"), Some(&"customers".to_string())); + assert_eq!(aliases.get("p"), Some(&"products".to_string())); + } + + #[test] + fn test_sql_for_populate_simple_union() { + let schema = create_test_schema(); + let select = parse_select( + "SELECT * FROM orders WHERE 
total > 1000 + UNION ALL + SELECT * FROM orders WHERE total < 100", + ); + + let (tables, aliases, qualified_names, table_conditions) = + extract_all_tables(&select, &schema).unwrap(); + + // Generate populate queries + let queries = IncrementalView::generate_populate_queries( + &select, + &tables, + &aliases, + &qualified_names, + &table_conditions, + ) + .unwrap(); + + // We should have deduplicated to a single table + assert_eq!(tables.len(), 1, "Should have one unique table"); + assert_eq!(tables[0].name, "orders"); // Single table, order doesn't matter + + // Should have collected two conditions + assert_eq!(table_conditions.get("orders").unwrap().len(), 2); + + // Should combine multiple conditions with OR + assert_eq!(queries.len(), 1); + // Conditions are combined with OR + assert_eq!( + queries[0], + "SELECT * FROM orders WHERE (total > 1000) OR (total < 100)" + ); + } + + #[test] + fn test_sql_for_populate_with_union_and_filters() { + let schema = create_test_schema(); + + // Test UNION with different WHERE conditions on the same table + let select = parse_select( + "SELECT * FROM orders WHERE total > 1000 + UNION ALL + SELECT * FROM orders WHERE total < 100", + ); + + let view = IncrementalView::from_stmt( + ast::QualifiedName { + db_name: None, + name: ast::Name::Ident("test_view".to_string()), + alias: None, + }, + select, + &schema, + 1, + 2, + 3, + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + // We deduplicate tables, so we get 1 query for orders + assert_eq!(queries.len(), 1); + + // Multiple conditions on the same table are combined with OR + assert_eq!( + queries[0], + "SELECT * FROM orders WHERE (total > 1000) OR (total < 100)" + ); + } + + #[test] + fn test_sql_for_populate_with_union_mixed_tables() { + let schema = create_test_schema(); + + // Test UNION with different tables + let select = parse_select( + "SELECT id, name FROM customers WHERE id > 10 + UNION ALL + SELECT customer_id as id, 'Order' as name FROM orders WHERE total > 500", + ); + + let view = IncrementalView::from_stmt( + ast::QualifiedName { + db_name: None, + name: ast::Name::Ident("test_view".to_string()), + alias: None, + }, + select, + &schema, + 1, + 2, + 3, + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 2, "Should have one query per table"); + + // Check that each table gets its appropriate WHERE clause + let customers_query = queries + .iter() + .find(|q| q.contains("FROM customers")) + .unwrap(); + let orders_query = queries.iter().find(|q| q.contains("FROM orders")).unwrap(); + + assert!(customers_query.contains("WHERE id > 10")); + assert!(orders_query.contains("WHERE total > 500")); + } + + #[test] + fn test_sql_for_populate_duplicate_tables_conflicting_filters() { + // This tests what happens when we have duplicate table references with different filters + // We need to manually construct a view to simulate what would happen with CTEs + let schema = create_test_schema(); + + // Get the orders table twice (simulating what would happen with CTEs) + let orders_table = schema.get_btree_table("orders").unwrap(); + + let referenced_tables = vec![orders_table.clone(), orders_table.clone()]; + + // Create a SELECT that would have conflicting WHERE conditions + let select = parse_select( + "SELECT * FROM orders WHERE total > 1000", // This is just for the AST + ); + + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + referenced_tables, + HashMap::new(), + HashMap::new(), + HashMap::new(), + 
extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, + 2, + 3, + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + // With duplicates, we should get 2 identical queries + assert_eq!(queries.len(), 2); + + // Both should be the same since they're from the same table reference + assert_eq!(queries[0], queries[1]); + } + + #[test] + fn test_table_extraction_with_nested_ctes_complex_conditions() { + let schema = create_test_schema(); + let select = parse_select( + "WITH + customer_orders AS ( + SELECT c.*, o.total + FROM customers c + JOIN orders o ON c.id = o.customer_id + WHERE c.name LIKE 'A%' AND o.total > 100 + ), + top_customers AS ( + SELECT * FROM customer_orders WHERE total > 500 + ) + SELECT * FROM top_customers", + ); + + // Test table extraction directly without creating a view + let mut tables = Vec::new(); + let mut aliases = HashMap::new(); + let mut qualified_names = HashMap::new(); + let mut table_conditions = HashMap::new(); + + IncrementalView::extract_all_tables( + &select, + &schema, + &mut tables, + &mut aliases, + &mut qualified_names, + &mut table_conditions, + ) + .unwrap(); + + let table_names: Vec<&str> = tables.iter().map(|t| t.name.as_str()).collect(); + + // Should have one reference to each table + assert_eq!(table_names.len(), 2, "Should have 2 table references"); + assert!(table_names.contains(&"customers")); + assert!(table_names.contains(&"orders")); + + // Check aliases + assert_eq!(aliases.get("c"), Some(&"customers".to_string())); + assert_eq!(aliases.get("o"), Some(&"orders".to_string())); + } + + #[test] + fn test_union_all_populate_queries() { + // Test that UNION ALL generates correct populate queries + let schema = create_test_schema(); + + // Create a UNION ALL query that references the same table twice with different WHERE conditions + let sql = " + SELECT id, name FROM customers WHERE id < 5 + UNION ALL + SELECT id, name FROM customers WHERE id > 10 + "; + + let mut parser = Parser::new(sql.as_bytes()); + let cmd = parser.next_cmd().unwrap(); + let select_stmt = match cmd.unwrap() { + turso_parser::ast::Cmd::Stmt(ast::Stmt::Select(select)) => select, + _ => panic!("Expected SELECT statement"), + }; + + // Extract tables and conditions + let (tables, aliases, qualified_names, conditions) = + extract_all_tables(&select_stmt, &schema).unwrap(); + + // Generate populate queries + let queries = IncrementalView::generate_populate_queries( + &select_stmt, + &tables, + &aliases, + &qualified_names, + &conditions, + ) + .unwrap(); + + // Expected query - assuming customers table has INTEGER PRIMARY KEY + // so we don't need to select rowid separately + let expected = "SELECT * FROM customers WHERE (id < 5) OR (id > 10)"; + + assert_eq!( + queries.len(), + 1, + "Should generate exactly 1 query for UNION ALL with same table" + ); + assert_eq!(queries[0], expected, "Query should match expected format"); + } + + #[test] + fn test_union_all_different_tables_populate_queries() { + // Test UNION ALL with different tables + let schema = create_test_schema(); + + let sql = " + SELECT id, name FROM customers WHERE id < 5 + UNION ALL + SELECT id, product_name FROM orders WHERE amount > 100 + "; + + let mut parser = Parser::new(sql.as_bytes()); + let cmd = parser.next_cmd().unwrap(); + let select_stmt = match cmd.unwrap() { + turso_parser::ast::Cmd::Stmt(ast::Stmt::Select(select)) => select, + _ => panic!("Expected SELECT statement"), + }; + + // Extract tables and conditions + let (tables, aliases, qualified_names, conditions) = + 
extract_all_tables(&select_stmt, &schema).unwrap(); + + // Generate populate queries + let queries = IncrementalView::generate_populate_queries( + &select_stmt, + &tables, + &aliases, + &qualified_names, + &conditions, + ) + .unwrap(); + + // Should generate separate queries for each table + assert_eq!( + queries.len(), + 2, + "Should generate 2 queries for different tables" + ); + + // Check we have queries for both tables + let has_customers = queries.iter().any(|q| q.contains("customers")); + let has_orders = queries.iter().any(|q| q.contains("orders")); + assert!(has_customers, "Should have a query for customers table"); + assert!(has_orders, "Should have a query for orders table"); + + // Verify the customers query has its WHERE clause + let customers_query = queries + .iter() + .find(|q| q.contains("customers")) + .expect("Should have customers query"); + assert!( + customers_query.contains("WHERE"), + "Customers query should have WHERE clause" + ); } } diff --git a/testing/materialized_views.test b/testing/materialized_views.test index 15229a48c..354f65d39 100755 --- a/testing/materialized_views.test +++ b/testing/materialized_views.test @@ -1091,3 +1091,340 @@ do_execsql_test_on_specific_db {:memory:} matview-join-complex-where { } {Charlie|10|100|1000 Alice|5|100|500 Charlie|6|75|450} + +# Test UNION queries in materialized views +do_execsql_test_on_specific_db {:memory:} matview-union-simple { + CREATE TABLE sales_online(id INTEGER, product TEXT, amount INTEGER); + CREATE TABLE sales_store(id INTEGER, product TEXT, amount INTEGER); + + INSERT INTO sales_online VALUES + (1, 'Laptop', 1200), + (2, 'Mouse', 25), + (3, 'Monitor', 400); + + INSERT INTO sales_store VALUES + (1, 'Keyboard', 75), + (2, 'Chair', 150), + (3, 'Desk', 350); + + -- Create a view that combines both sources + CREATE MATERIALIZED VIEW all_sales AS + SELECT product, amount FROM sales_online + UNION ALL + SELECT product, amount FROM sales_store; + + SELECT * FROM all_sales ORDER BY product; +} {Chair|150 +Desk|350 +Keyboard|75 +Laptop|1200 +Monitor|400 +Mouse|25} + +do_execsql_test_on_specific_db {:memory:} matview-union-with-where { + CREATE TABLE employees(id INTEGER, name TEXT, dept TEXT, salary INTEGER); + CREATE TABLE contractors(id INTEGER, name TEXT, dept TEXT, rate INTEGER); + + INSERT INTO employees VALUES + (1, 'Alice', 'Engineering', 90000), + (2, 'Bob', 'Sales', 60000), + (3, 'Charlie', 'Engineering', 85000); + + INSERT INTO contractors VALUES + (1, 'David', 'Engineering', 150), + (2, 'Eve', 'Marketing', 120), + (3, 'Frank', 'Engineering', 180); + + -- High-earning staff from both categories + CREATE MATERIALIZED VIEW high_earners AS + SELECT name, dept, salary as compensation FROM employees WHERE salary > 80000 + UNION ALL + SELECT name, dept, rate * 2000 as compensation FROM contractors WHERE rate > 140; + + SELECT * FROM high_earners ORDER BY name; +} {Alice|Engineering|90000 +Charlie|Engineering|85000 +David|Engineering|300000 +Frank|Engineering|360000} + +do_execsql_test_on_specific_db {:memory:} matview-union-same-table-different-filters { + CREATE TABLE orders(id INTEGER, customer_id INTEGER, product TEXT, amount INTEGER, status TEXT); + + INSERT INTO orders VALUES + (1, 1, 'Laptop', 1200, 'completed'), + (2, 2, 'Mouse', 25, 'pending'), + (3, 1, 'Monitor', 400, 'completed'), + (4, 3, 'Keyboard', 75, 'cancelled'), + (5, 2, 'Desk', 350, 'completed'), + (6, 3, 'Chair', 150, 'pending'); + + -- View showing priority orders: high-value OR pending status + CREATE MATERIALIZED VIEW priority_orders AS + 
SELECT id, customer_id, product, amount FROM orders WHERE amount > 300 + UNION ALL + SELECT id, customer_id, product, amount FROM orders WHERE status = 'pending'; + + SELECT * FROM priority_orders ORDER BY id; +} {1|1|Laptop|1200 +2|2|Mouse|25 +3|1|Monitor|400 +5|2|Desk|350 +6|3|Chair|150} + +do_execsql_test_on_specific_db {:memory:} matview-union-with-aggregation { + CREATE TABLE q1_sales(product TEXT, quantity INTEGER, revenue INTEGER); + CREATE TABLE q2_sales(product TEXT, quantity INTEGER, revenue INTEGER); + + INSERT INTO q1_sales VALUES + ('Laptop', 10, 12000), + ('Mouse', 50, 1250), + ('Monitor', 8, 3200); + + INSERT INTO q2_sales VALUES + ('Laptop', 15, 18000), + ('Mouse', 60, 1500), + ('Keyboard', 30, 2250); + + -- Combined quarterly summary + CREATE MATERIALIZED VIEW half_year_summary AS + SELECT 'Q1' as quarter, SUM(quantity) as total_units, SUM(revenue) as total_revenue + FROM q1_sales + UNION ALL + SELECT 'Q2' as quarter, SUM(quantity) as total_units, SUM(revenue) as total_revenue + FROM q2_sales; + + SELECT * FROM half_year_summary ORDER BY quarter; +} {Q1|68|16450 +Q2|105|21750} + +do_execsql_test_on_specific_db {:memory:} matview-union-with-join { + CREATE TABLE customers(id INTEGER PRIMARY KEY, name TEXT, type TEXT); + CREATE TABLE orders(id INTEGER PRIMARY KEY, customer_id INTEGER, amount INTEGER); + CREATE TABLE quotes(id INTEGER PRIMARY KEY, customer_id INTEGER, amount INTEGER); + + INSERT INTO customers VALUES + (1, 'Alice', 'premium'), + (2, 'Bob', 'regular'), + (3, 'Charlie', 'premium'); + + INSERT INTO orders VALUES + (1, 1, 1000), + (2, 2, 500), + (3, 3, 1500); + + INSERT INTO quotes VALUES + (1, 1, 800), + (2, 2, 300), + (3, 3, 2000); + + -- All premium customer transactions (orders and quotes) + CREATE MATERIALIZED VIEW premium_transactions AS + SELECT c.name, 'order' as type, o.amount + FROM customers c + JOIN orders o ON c.id = o.customer_id + WHERE c.type = 'premium' + UNION ALL + SELECT c.name, 'quote' as type, q.amount + FROM customers c + JOIN quotes q ON c.id = q.customer_id + WHERE c.type = 'premium'; + + SELECT * FROM premium_transactions ORDER BY name, type, amount; +} {Alice|order|1000 +Alice|quote|800 +Charlie|order|1500 +Charlie|quote|2000} + +do_execsql_test_on_specific_db {:memory:} matview-union-distinct { + CREATE TABLE active_users(id INTEGER, name TEXT, email TEXT); + CREATE TABLE inactive_users(id INTEGER, name TEXT, email TEXT); + + INSERT INTO active_users VALUES + (1, 'Alice', 'alice@example.com'), + (2, 'Bob', 'bob@example.com'), + (3, 'Charlie', 'charlie@example.com'); + + INSERT INTO inactive_users VALUES + (4, 'David', 'david@example.com'), + (2, 'Bob', 'bob@example.com'), -- Bob appears in both + (5, 'Eve', 'eve@example.com'); + + -- All unique users (using UNION to deduplicate) + CREATE MATERIALIZED VIEW all_users AS + SELECT id, name, email FROM active_users + UNION + SELECT id, name, email FROM inactive_users; + + SELECT * FROM all_users ORDER BY id; +} {1|Alice|alice@example.com +2|Bob|bob@example.com +3|Charlie|charlie@example.com +4|David|david@example.com +5|Eve|eve@example.com} + +do_execsql_test_on_specific_db {:memory:} matview-union-complex-multiple-branches { + CREATE TABLE products(id INTEGER, name TEXT, category TEXT, price INTEGER); + + INSERT INTO products VALUES + (1, 'Laptop', 'Electronics', 1200), + (2, 'Mouse', 'Electronics', 25), + (3, 'Desk', 'Furniture', 350), + (4, 'Chair', 'Furniture', 150), + (5, 'Monitor', 'Electronics', 400), + (6, 'Keyboard', 'Electronics', 75), + (7, 'Bookshelf', 'Furniture', 200), + (8, 
'Tablet', 'Electronics', 600); + + -- Products of interest: expensive electronics, all furniture, or very cheap items + CREATE MATERIALIZED VIEW featured_products AS + SELECT name, category, price, 'PremiumElectronic' as tag + FROM products + WHERE category = 'Electronics' AND price > 500 + UNION ALL + SELECT name, category, price, 'Furniture' as tag + FROM products + WHERE category = 'Furniture' + UNION ALL + SELECT name, category, price, 'Budget' as tag + FROM products + WHERE price < 50; + + SELECT * FROM featured_products ORDER BY tag, name; +} {Mouse|Electronics|25|Budget +Bookshelf|Furniture|200|Furniture +Chair|Furniture|150|Furniture +Desk|Furniture|350|Furniture +Laptop|Electronics|1200|PremiumElectronic +Tablet|Electronics|600|PremiumElectronic} + +do_execsql_test_on_specific_db {:memory:} matview-union-maintenance-insert { + CREATE TABLE t1(id INTEGER, value INTEGER); + CREATE TABLE t2(id INTEGER, value INTEGER); + + INSERT INTO t1 VALUES (1, 100), (2, 200); + INSERT INTO t2 VALUES (3, 300), (4, 400); + + CREATE MATERIALIZED VIEW combined AS + SELECT id, value FROM t1 WHERE value > 150 + UNION ALL + SELECT id, value FROM t2 WHERE value > 350; + + SELECT * FROM combined ORDER BY id; + + -- Insert into t1 + INSERT INTO t1 VALUES (5, 500); + SELECT * FROM combined ORDER BY id; + + -- Insert into t2 + INSERT INTO t2 VALUES (6, 600); + SELECT * FROM combined ORDER BY id; +} {2|200 +4|400 +2|200 +4|400 +5|500 +2|200 +4|400 +5|500 +6|600} + +do_execsql_test_on_specific_db {:memory:} matview-union-maintenance-delete { + CREATE TABLE source1(id INTEGER PRIMARY KEY, data TEXT); + CREATE TABLE source2(id INTEGER PRIMARY KEY, data TEXT); + + INSERT INTO source1 VALUES (1, 'A'), (2, 'B'), (3, 'C'); + INSERT INTO source2 VALUES (4, 'D'), (5, 'E'), (6, 'F'); + + CREATE MATERIALIZED VIEW merged AS + SELECT id, data FROM source1 + UNION ALL + SELECT id, data FROM source2; + + SELECT COUNT(*) FROM merged; + + DELETE FROM source1 WHERE id = 2; + SELECT COUNT(*) FROM merged; + + DELETE FROM source2 WHERE id > 4; + SELECT COUNT(*) FROM merged; +} {6 +5 +3} + +do_execsql_test_on_specific_db {:memory:} matview-union-maintenance-update { + CREATE TABLE high_priority(id INTEGER PRIMARY KEY, task TEXT, priority INTEGER); + CREATE TABLE normal_priority(id INTEGER PRIMARY KEY, task TEXT, priority INTEGER); + + INSERT INTO high_priority VALUES (1, 'Task A', 10), (2, 'Task B', 9); + INSERT INTO normal_priority VALUES (3, 'Task C', 5), (4, 'Task D', 6); + + CREATE MATERIALIZED VIEW active_tasks AS + SELECT id, task, priority FROM high_priority WHERE priority >= 9 + UNION ALL + SELECT id, task, priority FROM normal_priority WHERE priority >= 5; + + SELECT COUNT(*) FROM active_tasks; + + -- Update drops a high priority task below threshold + UPDATE high_priority SET priority = 8 WHERE id = 2; + SELECT COUNT(*) FROM active_tasks; + + -- Update brings a normal task above threshold + UPDATE normal_priority SET priority = 3 WHERE id = 3; + SELECT COUNT(*) FROM active_tasks; +} {4 +3 +2} + +# Test UNION ALL with same table and different WHERE conditions +do_execsql_test_on_specific_db {:memory:} matview-union-all-same-table { + CREATE TABLE test(id INTEGER PRIMARY KEY, value INTEGER); + INSERT INTO test VALUES (1, 10), (2, 20); + + -- This UNION ALL should return both rows + CREATE MATERIALIZED VIEW union_view AS + SELECT id, value FROM test WHERE value < 15 + UNION ALL + SELECT id, value FROM test WHERE value > 15; + + -- Should return 2 rows: (1,10) and (2,20) + SELECT * FROM union_view ORDER BY id; +} {1|10 
+2|20} + +# Test UNION ALL preserves all rows in count +do_execsql_test_on_specific_db {:memory:} matview-union-all-row-count { + CREATE TABLE data(id INTEGER PRIMARY KEY, num INTEGER); + INSERT INTO data VALUES (1, 5), (2, 15), (3, 25); + + CREATE MATERIALIZED VIEW split_view AS + SELECT id, num FROM data WHERE num <= 10 + UNION ALL + SELECT id, num FROM data WHERE num > 10; + + -- Should return count of 3 + SELECT COUNT(*) FROM split_view; +} {3} + +# Test UNION ALL with text columns and filtering +do_execsql_test_on_specific_db {:memory:} matview-union-all-text-filter { + CREATE TABLE items(id INTEGER PRIMARY KEY, category TEXT, price INTEGER); + INSERT INTO items VALUES + (1, 'cheap', 10), + (2, 'expensive', 100), + (3, 'cheap', 20), + (4, 'expensive', 200); + + CREATE MATERIALIZED VIEW price_categories AS + SELECT id, category, price FROM items WHERE category = 'cheap' + UNION ALL + SELECT id, category, price FROM items WHERE category = 'expensive'; + + -- Should return all 4 items + SELECT COUNT(*) FROM price_categories; + SELECT id FROM price_categories ORDER BY id; +} {4 +1 +2 +3 +4}
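
The populate-query shape asserted by the Rust tests in this patch (multiple conditions collected for the same table are OR-combined, the WHERE clause is dropped entirely when any UNION branch scans the table unfiltered, and tables without a rowid alias also select rowid) can be sketched in isolation. This is a minimal standalone approximation, not the actual IncrementalView::generate_populate_queries implementation; the helper name, signature, and string-based condition type below are illustrative only.

// Illustrative sketch only; not part of this patch.
fn build_populate_sql(
    from_clause: &str,             // e.g. "customers" or "main.orders"
    has_rowid_alias: bool,         // table has an INTEGER PRIMARY KEY acting as rowid
    conditions: &[Option<String>], // one entry per UNION branch referencing the table
) -> String {
    // Tables without a rowid alias must select the rowid explicitly.
    let select = if has_rowid_alias { "SELECT *" } else { "SELECT *, rowid" };

    // If any branch scans the table without an extractable condition,
    // every row is needed, so no WHERE clause is emitted at all.
    if conditions.is_empty() || conditions.iter().any(|c| c.is_none()) {
        return format!("{select} FROM {from_clause}");
    }

    // A single condition is emitted as-is; multiple conditions come from
    // different UNION branches and are alternatives, so they are OR-combined.
    if conditions.len() == 1 {
        let only = conditions[0].as_ref().unwrap();
        format!("{select} FROM {from_clause} WHERE {only}")
    } else {
        let combined = conditions
            .iter()
            .map(|c| format!("({})", c.as_ref().unwrap()))
            .collect::<Vec<_>>()
            .join(" OR ");
        format!("{select} FROM {from_clause} WHERE {combined}")
    }
}

fn main() {
    // Mirrors test_sql_for_populate_simple_union.
    assert_eq!(
        build_populate_sql(
            "orders",
            true,
            &[Some("total > 1000".to_string()), Some("total < 100".to_string())],
        ),
        "SELECT * FROM orders WHERE (total > 1000) OR (total < 100)"
    );
    // Mirrors test_sql_for_populate_union_mixed_conditions.
    assert_eq!(
        build_populate_sql("customers", true, &[Some("id > 10".to_string()), None]),
        "SELECT * FROM customers"
    );
    // Mirrors test_sql_for_populate_table_without_rowid_alias.
    assert_eq!(
        build_populate_sql("logs", false, &[Some("level > 2".to_string())]),
        "SELECT *, rowid FROM logs WHERE level > 2"
    );
}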