diff --git a/core/incremental/cursor.rs b/core/incremental/cursor.rs index a486e9fb3..df4f78d9d 100644 --- a/core/incremental/cursor.rs +++ b/core/incremental/cursor.rs @@ -95,7 +95,11 @@ impl MaterializedViewCursor { // Process the delta through the circuit to get materialized changes let mut uncommitted = DeltaSet::new(); - uncommitted.insert(view_guard.base_table().name.clone(), tx_delta); + // Get the first table name from the view's referenced tables + let table_names = view_guard.get_referenced_table_names(); + if !table_names.is_empty() { + uncommitted.insert(table_names[0].clone(), tx_delta); + } let processed_delta = return_if_io!(view_guard.execute_with_uncommitted( uncommitted, diff --git a/core/incremental/view.rs b/core/incremental/view.rs index b15faf847..b2d0df606 100644 --- a/core/incremental/view.rs +++ b/core/incremental/view.rs @@ -166,8 +166,8 @@ pub struct IncrementalView { // DBSP circuit that encapsulates the computation circuit: DbspCircuit, - // Tables referenced by this view (extracted from FROM clause and JOINs) - base_table: Arc, + // All tables referenced by this view (from FROM clause and JOINs) + referenced_tables: Vec>, // The view's output columns with their types pub columns: Vec, // State machine for population @@ -199,7 +199,6 @@ impl IncrementalView { fn try_compile_circuit( select: &ast::Select, schema: &Schema, - _base_table: &Arc, main_data_root: usize, internal_state_root: usize, ) -> Result { @@ -311,26 +310,14 @@ impl IncrementalView { )); } - // Get the base table from FROM clause (when no joins) - let base_table = if let Some(base_table_name) = Self::extract_base_table(&select) { - if let Some(table) = schema.get_btree_table(&base_table_name) { - table.clone() - } else { - return Err(LimboError::ParseError(format!( - "Table '{base_table_name}' not found in schema" - ))); - } - } else { - return Err(LimboError::ParseError( - "views without a base table not supported yet".to_string(), - )); - }; + // Get all tables from FROM clause and JOINs + let referenced_tables = Self::extract_all_tables(&select, schema)?; Self::new( name, where_predicate, select.clone(), - base_table, + referenced_tables, view_columns, schema, main_data_root, @@ -343,7 +330,7 @@ impl IncrementalView { name: String, where_predicate: FilterPredicate, select_stmt: ast::Select, - base_table: Arc, + referenced_tables: Vec>, columns: Vec, schema: &Schema, main_data_root: usize, @@ -353,20 +340,15 @@ impl IncrementalView { let tracker = Arc::new(Mutex::new(ComputationTracker::new())); // Compile the SELECT statement into a DBSP circuit - let circuit = Self::try_compile_circuit( - &select_stmt, - schema, - &base_table, - main_data_root, - internal_state_root, - )?; + let circuit = + Self::try_compile_circuit(&select_stmt, schema, main_data_root, internal_state_root)?; Ok(Self { name, where_predicate, select_stmt, circuit, - base_table, + referenced_tables, columns, populate_state: PopulateState::Start, tracker, @@ -378,10 +360,6 @@ impl IncrementalView { &self.name } - pub fn base_table(&self) -> &Arc { - &self.base_table - } - /// Execute the circuit with uncommitted changes to get processed delta pub fn execute_with_uncommitted( &mut self, @@ -403,12 +381,60 @@ impl IncrementalView { /// Get all table names referenced by this view pub fn get_referenced_table_names(&self) -> Vec { - vec![self.base_table.name.clone()] + self.referenced_tables + .iter() + .map(|t| t.name.clone()) + .collect() } /// Get all tables referenced by this view pub fn get_referenced_tables(&self) -> Vec> { - vec![self.base_table.clone()] + self.referenced_tables.clone() + } + + /// Extract all table names from a SELECT statement (including JOINs) + fn extract_all_tables(select: &ast::Select, schema: &Schema) -> Result>> { + let mut tables = Vec::new(); + + if let ast::OneSelect::Select { + from: Some(ref from), + .. + } = select.body.select + { + // Get the main table from FROM clause + if let ast::SelectTable::Table(name, _, _) = from.select.as_ref() { + let table_name = name.name.as_str(); + if let Some(table) = schema.get_btree_table(table_name) { + tables.push(table.clone()); + } else { + return Err(LimboError::ParseError(format!( + "Table '{table_name}' not found in schema" + ))); + } + } + + // Get all tables from JOIN clauses + for join in &from.joins { + if let ast::SelectTable::Table(name, _, _) = join.table.as_ref() { + let table_name = name.name.as_str(); + if let Some(table) = schema.get_btree_table(table_name) { + tables.push(table.clone()); + } else { + return Err(LimboError::ParseError(format!( + "Table '{table_name}' not found in schema" + ))); + } + } + } + } + + if tables.is_empty() { + return Err(LimboError::ParseError( + "No tables found in SELECT statement".to_string(), + )); + } + + Ok(tables) } /// Extract the base table name from a SELECT statement (for non-join cases) @@ -427,8 +453,13 @@ impl IncrementalView { /// Generate the SQL query for populating the view from its source table fn sql_for_populate(&self) -> crate::Result { - // Get the base table from referenced tables - let table = &self.base_table; + // Get the first table from referenced tables + if self.referenced_tables.is_empty() { + return Err(LimboError::ParseError( + "No tables to populate from".to_string(), + )); + } + let table = &self.referenced_tables[0]; // Check if the table has a rowid alias (INTEGER PRIMARY KEY column) let has_rowid_alias = table.columns.iter().any(|col| col.is_rowid_alias); @@ -607,33 +638,34 @@ impl IncrementalView { // Determine how to extract the rowid // If there's a rowid alias (INTEGER PRIMARY KEY), the rowid is one of the columns // Otherwise, it's the last value we explicitly selected - let (rowid, values) = - if let Some((idx, _)) = self.base_table.get_rowid_alias_column() { - // The rowid is the value at the rowid alias column index - let rowid = match all_values.get(idx) { - Some(crate::types::Value::Integer(id)) => *id, - _ => { - // This shouldn't happen - rowid alias must be an integer - rows_processed += 1; - continue; - } - }; - // All values are table columns (no separate rowid was selected) - (rowid, all_values) - } else { - // The last value is the explicitly selected rowid - let rowid = match all_values.last() { - Some(crate::types::Value::Integer(id)) => *id, - _ => { - // This shouldn't happen - rowid must be an integer - rows_processed += 1; - continue; - } - }; - // Get all values except the rowid - let values = all_values[..all_values.len() - 1].to_vec(); - (rowid, values) + let (rowid, values) = if let Some((idx, _)) = + self.referenced_tables[0].get_rowid_alias_column() + { + // The rowid is the value at the rowid alias column index + let rowid = match all_values.get(idx) { + Some(crate::types::Value::Integer(id)) => *id, + _ => { + // This shouldn't happen - rowid alias must be an integer + rows_processed += 1; + continue; + } }; + // All values are table columns (no separate rowid was selected) + (rowid, all_values) + } else { + // The last value is the explicitly selected rowid + let rowid = match all_values.last() { + Some(crate::types::Value::Integer(id)) => *id, + _ => { + // This shouldn't happen - rowid must be an integer + rows_processed += 1; + continue; + } + }; + // Get all values except the rowid + let values = all_values[..all_values.len() - 1].to_vec(); + (rowid, values) + }; // Create a single-row delta and process it immediately let mut single_row_delta = Delta::new(); @@ -782,10 +814,275 @@ impl IncrementalView { // Use the circuit to process the delta and write to btree let mut input_data = HashMap::new(); - input_data.insert(self.base_table.name.clone(), delta.clone()); + // For now, assume the delta applies to the first table + // TODO: This needs to be improved to handle deltas for multiple tables + if !self.referenced_tables.is_empty() { + input_data.insert(self.referenced_tables[0].name.clone(), delta.clone()); + } // The circuit now handles all btree I/O internally with the provided pager let _delta = return_if_io!(self.circuit.commit(input_data, pager)); Ok(IOResult::Done(())) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::schema::{BTreeTable, Column as SchemaColumn, Schema, Type}; + use std::sync::Arc; + use turso_parser::ast; + use turso_parser::parser::Parser; + + // Helper function to create a test schema with multiple tables + fn create_test_schema() -> Schema { + let mut schema = Schema::new(false); + + // Create customers table + let customers_table = BTreeTable { + name: "customers".to_string(), + root_page: 2, + primary_key_columns: vec![("id".to_string(), ast::SortOrder::Asc)], + columns: vec![ + SchemaColumn { + name: Some("id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: true, + is_rowid_alias: true, + notnull: true, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("name".to_string()), + ty: Type::Text, + ty_str: "TEXT".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + ], + has_rowid: true, + is_strict: false, + unique_sets: None, + }; + + // Create orders table + let orders_table = BTreeTable { + name: "orders".to_string(), + root_page: 3, + primary_key_columns: vec![("id".to_string(), ast::SortOrder::Asc)], + columns: vec![ + SchemaColumn { + name: Some("id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: true, + is_rowid_alias: true, + notnull: true, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("customer_id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("total".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + ], + has_rowid: true, + is_strict: false, + unique_sets: None, + }; + + // Create products table + let products_table = BTreeTable { + name: "products".to_string(), + root_page: 4, + primary_key_columns: vec![("id".to_string(), ast::SortOrder::Asc)], + columns: vec![ + SchemaColumn { + name: Some("id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: true, + is_rowid_alias: true, + notnull: true, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("name".to_string()), + ty: Type::Text, + ty_str: "TEXT".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + ], + has_rowid: true, + is_strict: false, + unique_sets: None, + }; + + schema.add_btree_table(Arc::new(customers_table)); + schema.add_btree_table(Arc::new(orders_table)); + schema.add_btree_table(Arc::new(products_table)); + schema + } + + // Helper to parse SQL and extract the SELECT statement + fn parse_select(sql: &str) -> ast::Select { + let mut parser = Parser::new(sql.as_bytes()); + let cmd = parser.next().unwrap().unwrap(); + match cmd { + ast::Cmd::Stmt(ast::Stmt::Select(select)) => select, + _ => panic!("Expected SELECT statement"), + } + } + + #[test] + fn test_extract_single_table() { + let schema = create_test_schema(); + let select = parse_select("SELECT * FROM customers"); + + let tables = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + + assert_eq!(tables.len(), 1); + assert_eq!(tables[0].name, "customers"); + } + + #[test] + fn test_extract_tables_from_inner_join() { + let schema = create_test_schema(); + let select = parse_select( + "SELECT * FROM customers INNER JOIN orders ON customers.id = orders.customer_id", + ); + + let tables = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + + assert_eq!(tables.len(), 2); + assert_eq!(tables[0].name, "customers"); + assert_eq!(tables[1].name, "orders"); + } + + #[test] + fn test_extract_tables_from_multiple_joins() { + let schema = create_test_schema(); + let select = parse_select( + "SELECT * FROM customers + INNER JOIN orders ON customers.id = orders.customer_id + INNER JOIN products ON orders.id = products.id", + ); + + let tables = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + + assert_eq!(tables.len(), 3); + assert_eq!(tables[0].name, "customers"); + assert_eq!(tables[1].name, "orders"); + assert_eq!(tables[2].name, "products"); + } + + #[test] + fn test_extract_tables_from_left_join() { + let schema = create_test_schema(); + let select = parse_select( + "SELECT * FROM customers LEFT JOIN orders ON customers.id = orders.customer_id", + ); + + let tables = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + + assert_eq!(tables.len(), 2); + assert_eq!(tables[0].name, "customers"); + assert_eq!(tables[1].name, "orders"); + } + + #[test] + fn test_extract_tables_from_cross_join() { + let schema = create_test_schema(); + let select = parse_select("SELECT * FROM customers CROSS JOIN orders"); + + let tables = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + + assert_eq!(tables.len(), 2); + assert_eq!(tables[0].name, "customers"); + assert_eq!(tables[1].name, "orders"); + } + + #[test] + fn test_extract_tables_with_aliases() { + let schema = create_test_schema(); + let select = + parse_select("SELECT * FROM customers c INNER JOIN orders o ON c.id = o.customer_id"); + + let tables = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + + // Should still extract the actual table names, not aliases + assert_eq!(tables.len(), 2); + assert_eq!(tables[0].name, "customers"); + assert_eq!(tables[1].name, "orders"); + } + + #[test] + fn test_extract_tables_nonexistent_table_error() { + let schema = create_test_schema(); + let select = parse_select("SELECT * FROM nonexistent"); + + let result = IncrementalView::extract_all_tables(&select, &schema); + + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("Table 'nonexistent' not found")); + } + + #[test] + fn test_extract_tables_nonexistent_join_table_error() { + let schema = create_test_schema(); + let select = parse_select( + "SELECT * FROM customers INNER JOIN nonexistent ON customers.id = nonexistent.id", + ); + + let result = IncrementalView::extract_all_tables(&select, &schema); + + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("Table 'nonexistent' not found")); + } +}