diff --git a/Cargo.lock b/Cargo.lock index 983012525..dd15bab8d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4451,6 +4451,7 @@ dependencies = [ "rand_chacha 0.9.0", "regex", "regex-syntax", + "roaring", "rstest", "rusqlite", "rustix 1.0.7", diff --git a/core/Cargo.toml b/core/Cargo.toml index 7fa6afb78..a795f9d8c 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -83,6 +83,7 @@ turso_parser = { workspace = true } aegis = "0.9.0" twox-hash = "2.1.1" intrusive-collections = "0.9.7" +roaring = "0.11.2" [build-dependencies] chrono = { workspace = true, default-features = false } diff --git a/core/translate/optimizer/mod.rs b/core/translate/optimizer/mod.rs index f053e7e7e..81dde810c 100644 --- a/core/translate/optimizer/mod.rs +++ b/core/translate/optimizer/mod.rs @@ -187,6 +187,12 @@ fn optimize_table_access( order_by: &mut Vec<(Box, SortOrder)>, group_by: &mut Option, ) -> Result>> { + if table_references.joined_tables().len() > TableReferences::MAX_JOINED_TABLES { + crate::bail_parse_error!( + "Only up to {} tables can be joined", + TableReferences::MAX_JOINED_TABLES + ); + } let access_methods_arena = RefCell::new(Vec::new()); let maybe_order_target = compute_order_target(order_by, group_by.as_mut()); let constraints_per_table = diff --git a/core/translate/plan.rs b/core/translate/plan.rs index 21fa84b69..98cd8e2e2 100644 --- a/core/translate/plan.rs +++ b/core/translate/plan.rs @@ -583,6 +583,11 @@ pub struct TableReferences { } impl TableReferences { + /// The maximum number of tables that can be joined together in a query. + /// This limit is arbitrary, although we currently use a u128 to represent the [crate::translate::planner::TableMask], + /// which can represent up to 128 tables. + /// Even at 63 tables we currently cannot handle the optimization performantly, hence the arbitrary cap. Queries that join more tables than this are rejected with a parse error in `optimize_table_access`. 
+ pub const MAX_JOINED_TABLES: usize = 63; pub fn new( joined_tables: Vec, outer_query_refs: Vec, @@ -752,33 +757,25 @@ impl TableReferences { } } -#[derive(Clone, Debug, Default, PartialEq, Eq)] +#[derive(Clone, Debug, Default, PartialEq)] #[repr(transparent)] -pub struct ColumnUsedMask(u128); +pub struct ColumnUsedMask(roaring::RoaringBitmap); impl ColumnUsedMask { pub fn set(&mut self, index: usize) { - assert!( - index < 128, - "ColumnUsedMask only supports up to 128 columns" - ); - self.0 |= 1 << index; + self.0.insert(index as u32); } pub fn get(&self, index: usize) -> bool { - assert!( - index < 128, - "ColumnUsedMask only supports up to 128 columns" - ); - self.0 & (1 << index) != 0 + self.0.contains(index as u32) } pub fn contains_all_set_bits_of(&self, other: &Self) -> bool { - self.0 & other.0 == other.0 + other.0.is_subset(&self.0) } pub fn is_empty(&self) -> bool { - self.0 == 0 + self.0.is_empty() } } @@ -1261,3 +1258,127 @@ pub struct WindowFunction { /// The expression from which the function was resolved. 
pub original_expr: Expr, } + +#[cfg(test)] +mod tests { + use super::*; + use rand_chacha::{ + rand_core::{RngCore, SeedableRng}, + ChaCha8Rng, + }; + + #[test] + fn test_column_used_mask_empty() { + let mask = ColumnUsedMask::default(); + assert!(mask.is_empty()); + + let mut mask2 = ColumnUsedMask::default(); + mask2.set(0); + assert!(!mask2.is_empty()); + } + + #[test] + fn test_column_used_mask_set_and_get() { + let mut mask = ColumnUsedMask::default(); + + let max_columns = 10000; + let mut set_indices = Vec::new(); + let mut rng = ChaCha8Rng::seed_from_u64( + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(), + ); + + for i in 0..max_columns { + if rng.next_u32() % 3 == 0 { + set_indices.push(i); + mask.set(i); + } + } + + // Verify set bits are present + for &i in &set_indices { + assert!(mask.get(i), "Expected bit {i} to be set"); + } + + // Verify unset bits are not present + for i in 0..max_columns { + if !set_indices.contains(&i) { + assert!(!mask.get(i), "Expected bit {i} to not be set"); + } + } + } + + #[test] + fn test_column_used_mask_subset_relationship() { + let mut full_mask = ColumnUsedMask::default(); + let mut subset_mask = ColumnUsedMask::default(); + + let max_columns = 5000; + let mut rng = ChaCha8Rng::seed_from_u64( + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(), + ); + + // Create a pattern where subset has fewer bits + for i in 0..max_columns { + if rng.next_u32() % 5 == 0 { + full_mask.set(i); + if i % 2 == 0 { + subset_mask.set(i); + } + } + } + + // full_mask contains all bits of subset_mask + assert!(full_mask.contains_all_set_bits_of(&subset_mask)); + + // subset_mask does not contain all bits of full_mask + assert!(!subset_mask.contains_all_set_bits_of(&full_mask)); + + // A mask contains itself + assert!(full_mask.contains_all_set_bits_of(&full_mask)); + assert!(subset_mask.contains_all_set_bits_of(&subset_mask)); + } + + #[test] + fn 
test_column_used_mask_empty_subset() { + let mut mask = ColumnUsedMask::default(); + for i in (0..1000).step_by(7) { + mask.set(i); + } + + let empty_mask = ColumnUsedMask::default(); + + // Empty mask is subset of everything + assert!(mask.contains_all_set_bits_of(&empty_mask)); + assert!(empty_mask.contains_all_set_bits_of(&empty_mask)); + } + + #[test] + fn test_column_used_mask_sparse_indices() { + let mut sparse_mask = ColumnUsedMask::default(); + + // Test with very sparse, large indices + let sparse_indices = vec![0, 137, 1042, 5389, 10000, 50000, 100000, 500000, 1000000]; + + for &idx in &sparse_indices { + sparse_mask.set(idx); + } + + for &idx in &sparse_indices { + assert!(sparse_mask.get(idx), "Expected bit {idx} to be set"); + } + + // Check some indices that shouldn't be set + let unset_indices = vec![1, 100, 1000, 5000, 25000, 75000, 250000, 750000]; + for &idx in &unset_indices { + assert!(!sparse_mask.get(idx), "Expected bit {idx} to not be set"); + } + + assert!(!sparse_mask.is_empty()); + } +} diff --git a/tests/integration/query_processing/test_read_path.rs b/tests/integration/query_processing/test_read_path.rs index 98874283b..4b08df3b1 100644 --- a/tests/integration/query_processing/test_read_path.rs +++ b/tests/integration/query_processing/test_read_path.rs @@ -1,5 +1,5 @@ use crate::common::{limbo_exec_rows, TempDatabase}; -use turso_core::{StepResult, Value}; +use turso_core::{LimboError, StepResult, Value}; #[test] fn test_statement_reset_bind() -> anyhow::Result<()> { @@ -897,3 +897,97 @@ fn test_multiple_connections_visibility() -> anyhow::Result<()> { assert_eq!(rows, vec![vec![rusqlite::types::Value::Integer(2)]]); Ok(()) } + +#[test] +/// Test that we can only join up to 63 tables, and trying to join more should fail with an error instead of panicking. 
+fn test_max_joined_tables_limit() { + let tmp_db = TempDatabase::new("test_max_joined_tables_limit", false); + let conn = tmp_db.connect_limbo(); + + // Create 64 tables + for i in 0..64 { + conn.execute(format!("CREATE TABLE t{i} (id INTEGER)")) + .unwrap(); + } + + // Try to join 64 tables - should fail + let mut sql = String::from("SELECT * FROM t0"); + for i in 1..64 { + sql.push_str(&format!(" JOIN t{i} ON t{i}.id = t0.id")); + } + + let Err(LimboError::ParseError(result)) = conn.prepare(&sql) else { + panic!("Expected an error but got no error"); + }; + assert!(result.contains("Only up to 63 tables can be joined")); +} + +#[test] +/// Test that we can create and select from a table with 1000 columns. +fn test_many_columns() { + let mut create_sql = String::from("CREATE TABLE test ("); + for i in 0..1000 { + if i > 0 { + create_sql.push_str(", "); + } + create_sql.push_str(&format!("col{i} INTEGER")); + } + create_sql.push(')'); + + let tmp_db = TempDatabase::new("test_many_columns", false); + let conn = tmp_db.connect_limbo(); + conn.execute(&create_sql).unwrap(); + + // Insert a row with values 0-999 + let mut insert_sql = String::from("INSERT INTO test VALUES ("); + for i in 0..1000 { + if i > 0 { + insert_sql.push_str(", "); + } + insert_sql.push_str(&i.to_string()); + } + insert_sql.push(')'); + conn.execute(&insert_sql).unwrap(); + + // Select every 100th column + let mut select_sql = String::from("SELECT "); + let mut first = true; + for i in (0..1000).step_by(100) { + if !first { + select_sql.push_str(", "); + } + select_sql.push_str(&format!("col{i}")); + first = false; + } + select_sql.push_str(" FROM test"); + + let mut rows = Vec::new(); + let mut stmt = conn.prepare(&select_sql).unwrap(); + loop { + match stmt.step().unwrap() { + StepResult::Row => { + let row = stmt.row().unwrap(); + rows.push(row.get_values().cloned().collect::>()); + } + StepResult::IO => stmt.run_once().unwrap(), + _ => break, + } + } + + // Verify we got values 
0,100,200,...,900 + assert_eq!( + rows, + vec![vec![ + turso_core::Value::Integer(0), + turso_core::Value::Integer(100), + turso_core::Value::Integer(200), + turso_core::Value::Integer(300), + turso_core::Value::Integer(400), + turso_core::Value::Integer(500), + turso_core::Value::Integer(600), + turso_core::Value::Integer(700), + turso_core::Value::Integer(800), + turso_core::Value::Integer(900), + ]] + ); +}