diff --git a/Cargo.lock b/Cargo.lock index 79c477a01..f774018b0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1100,6 +1100,19 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +[[package]] +name = "encryption-throughput" +version = "0.1.0" +dependencies = [ + "clap", + "futures", + "hex", + "rand 0.9.2", + "tokio", + "tracing-subscriber", + "turso", +] + [[package]] name = "endian-type" version = "0.1.2" diff --git a/Cargo.toml b/Cargo.toml index 33352e546..cad11bee3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,8 +32,11 @@ members = [ "whopper", "perf/throughput/turso", "perf/throughput/rusqlite", + "perf/encryption" +] +exclude = [ + "perf/latency/limbo", ] -exclude = ["perf/latency/limbo"] [workspace.package] version = "0.2.0-pre.3" diff --git a/bindings/rust/Cargo.toml b/bindings/rust/Cargo.toml index d799b5320..e50304f01 100644 --- a/bindings/rust/Cargo.toml +++ b/bindings/rust/Cargo.toml @@ -15,6 +15,7 @@ conn_raw_api = ["turso_core/conn_raw_api"] experimental_indexes = [] antithesis = ["turso_core/antithesis"] tracing_release = ["turso_core/tracing_release"] +encryption = ["turso_core/encryption"] [dependencies] turso_core = { workspace = true, features = ["io_uring"] } diff --git a/bindings/rust/src/lib.rs b/bindings/rust/src/lib.rs index 39706fdd5..15ae191f7 100644 --- a/bindings/rust/src/lib.rs +++ b/bindings/rust/src/lib.rs @@ -413,7 +413,7 @@ impl Connection { .inner .lock() .map_err(|e| Error::MutexError(e.to_string()))?; - conn.busy_timeout(duration); + conn.set_busy_timeout(duration); Ok(()) } } diff --git a/core/fast_lock.rs b/core/fast_lock.rs index 8abda6a17..a02d617ba 100644 --- a/core/fast_lock.rs +++ b/core/fast_lock.rs @@ -34,7 +34,6 @@ impl DerefMut for SpinLockGuard<'_, T> { } } -unsafe impl Send for SpinLock {} unsafe impl Sync for SpinLock {} impl SpinLock { diff --git a/core/incremental/aggregate_operator.rs b/core/incremental/aggregate_operator.rs new file mode 100644 index 000000000..9f25a84f5 --- /dev/null +++ b/core/incremental/aggregate_operator.rs @@ -0,0 +1,1762 @@ +// Aggregate operator for DBSP-style incremental computation + +use crate::function::{AggFunc, Func}; +use crate::incremental::dbsp::{Delta, DeltaPair, HashableRow}; +use crate::incremental::operator::{ + generate_storage_id, ComputationTracker, DbspStateCursors, EvalState, IncrementalOperator, +}; +use crate::incremental::persistence::{ReadRecord, WriteRow}; +use crate::types::{IOResult, ImmutableRecord, RefValue, SeekKey, SeekOp, SeekResult}; +use crate::{return_and_restore_if_io, return_if_io, LimboError, Result, Value}; +use std::collections::{BTreeMap, HashMap, HashSet}; +use std::fmt::{self, Display}; +use std::sync::{Arc, Mutex}; + +/// Constants for aggregate type encoding in storage IDs (2 bits) +pub const AGG_TYPE_REGULAR: u8 = 0b00; // COUNT/SUM/AVG +pub const AGG_TYPE_MINMAX: u8 = 0b01; // MIN/MAX (BTree ordering gives both) + +#[derive(Debug, Clone, PartialEq)] +pub enum AggregateFunction { + Count, + Sum(usize), // Column index + Avg(usize), // Column index + Min(usize), // Column index + Max(usize), // Column index +} + +impl Display for AggregateFunction { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AggregateFunction::Count => write!(f, "COUNT(*)"), + AggregateFunction::Sum(idx) => write!(f, "SUM(col{idx})"), + AggregateFunction::Avg(idx) => write!(f, "AVG(col{idx})"), + AggregateFunction::Min(idx) => write!(f, "MIN(col{idx})"), + 
AggregateFunction::Max(idx) => write!(f, "MAX(col{idx})"), + } + } +} + +impl AggregateFunction { + /// Get the default output column name for this aggregate function + #[inline] + pub fn default_output_name(&self) -> String { + self.to_string() + } + + /// Create an AggregateFunction from a SQL function and its arguments + /// Returns None if the function is not a supported aggregate + pub fn from_sql_function( + func: &crate::function::Func, + input_column_idx: Option, + ) -> Option { + match func { + Func::Agg(agg_func) => { + match agg_func { + AggFunc::Count | AggFunc::Count0 => Some(AggregateFunction::Count), + AggFunc::Sum => input_column_idx.map(AggregateFunction::Sum), + AggFunc::Avg => input_column_idx.map(AggregateFunction::Avg), + AggFunc::Min => input_column_idx.map(AggregateFunction::Min), + AggFunc::Max => input_column_idx.map(AggregateFunction::Max), + _ => None, // Other aggregate functions not yet supported in DBSP + } + } + _ => None, // Not an aggregate function + } + } +} + +/// Information about a column that has MIN/MAX aggregations +#[derive(Debug, Clone)] +pub struct AggColumnInfo { + /// Index used for storage key generation + pub index: usize, + /// Whether this column has a MIN aggregate + pub has_min: bool, + /// Whether this column has a MAX aggregate + pub has_max: bool, +} + +/// Serialize a Value using SQLite's serial type format +/// This is used for MIN/MAX values that need to be stored in a compact, sortable format +pub fn serialize_value(value: &Value, blob: &mut Vec) { + let serial_type = crate::types::SerialType::from(value); + let serial_type_u64: u64 = serial_type.into(); + crate::storage::sqlite3_ondisk::write_varint_to_vec(serial_type_u64, blob); + value.serialize_serial(blob); +} + +/// Deserialize a Value using SQLite's serial type format +/// Returns the deserialized value and the number of bytes consumed +pub fn deserialize_value(blob: &[u8]) -> Option<(Value, usize)> { + let mut cursor = 0; + + // Read the serial type + let (serial_type, varint_size) = crate::storage::sqlite3_ondisk::read_varint(blob).ok()?; + cursor += varint_size; + + let serial_type_obj = crate::types::SerialType::try_from(serial_type).ok()?; + let expected_size = serial_type_obj.size(); + + // Read the value + let (value, actual_size) = + crate::storage::sqlite3_ondisk::read_value(&blob[cursor..], serial_type_obj).ok()?; + + // Verify that the actual size matches what we expected from the serial type + if actual_size != expected_size { + return None; // Data corruption - size mismatch + } + + cursor += actual_size; + + // Convert RefValue to Value + Some((value.to_owned(), cursor)) +} + +// group_key_str -> (group_key, state) +type ComputedStates = HashMap, AggregateState)>; +// group_key_str -> (column_index, value_as_hashable_row) -> accumulated_weight +pub type MinMaxDeltas = HashMap>; + +#[derive(Debug)] +enum AggregateCommitState { + Idle, + Eval { + eval_state: EvalState, + }, + PersistDelta { + delta: Delta, + computed_states: ComputedStates, + current_idx: usize, + write_row: WriteRow, + min_max_deltas: MinMaxDeltas, + }, + PersistMinMax { + delta: Delta, + min_max_persist_state: MinMaxPersistState, + }, + Done { + delta: Delta, + }, + Invalid, +} + +// Aggregate-specific eval states +#[derive(Debug)] +pub enum AggregateEvalState { + FetchKey { + delta: Delta, // Keep original delta for merge operation + current_idx: usize, + groups_to_read: Vec<(String, Vec)>, // Changed to Vec for index-based access + existing_groups: HashMap, + old_values: HashMap>, + }, + 
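+    /// Reads the serialized AggregateState for the rowid that FetchKey located
+    /// (skipped when the rowid is None), then advances to the next group via FetchKey.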
FetchData { + delta: Delta, // Keep original delta for merge operation + current_idx: usize, + groups_to_read: Vec<(String, Vec)>, // Changed to Vec for index-based access + existing_groups: HashMap, + old_values: HashMap>, + rowid: Option, // Rowid found by FetchKey (None if not found) + read_record_state: Box, + }, + RecomputeMinMax { + delta: Delta, + existing_groups: HashMap, + old_values: HashMap>, + recompute_state: Box, + }, + Done { + output: (Delta, ComputedStates), + }, +} + +/// Note that the AggregateOperator essentially implements a ZSet, even +/// though the ZSet structure is never used explicitly. The on-disk btree +/// plays the role of the set! +#[derive(Debug)] +pub struct AggregateOperator { + // Unique operator ID for indexing in persistent storage + pub operator_id: usize, + // GROUP BY column indices + group_by: Vec, + // Aggregate functions to compute (including MIN/MAX) + pub aggregates: Vec, + // Column names from input + pub input_column_names: Vec, + // Map from column index to aggregate info for quick lookup + pub column_min_max: HashMap, + tracker: Option>>, + + // State machine for commit operation + commit_state: AggregateCommitState, +} + +/// State for a single group's aggregates +#[derive(Debug, Clone, Default)] +pub struct AggregateState { + // For COUNT: just the count + pub count: i64, + // For SUM: column_index -> sum value + sums: HashMap, + // For AVG: column_index -> (sum, count) for computing average + avgs: HashMap, + // For MIN: column_index -> minimum value + pub mins: HashMap, + // For MAX: column_index -> maximum value + pub maxs: HashMap, +} + +impl AggregateEvalState { + fn process_delta( + &mut self, + operator: &mut AggregateOperator, + cursors: &mut DbspStateCursors, + ) -> Result> { + loop { + match self { + AggregateEvalState::FetchKey { + delta, + current_idx, + groups_to_read, + existing_groups, + old_values, + } => { + if *current_idx >= groups_to_read.len() { + // All groups have been fetched, move to RecomputeMinMax + // Extract MIN/MAX deltas from the input delta + let min_max_deltas = operator.extract_min_max_deltas(delta); + + let recompute_state = Box::new(RecomputeMinMax::new( + min_max_deltas, + existing_groups, + operator, + )); + + *self = AggregateEvalState::RecomputeMinMax { + delta: std::mem::take(delta), + existing_groups: std::mem::take(existing_groups), + old_values: std::mem::take(old_values), + recompute_state, + }; + } else { + // Get the current group to read + let (group_key_str, _group_key) = &groups_to_read[*current_idx]; + + // Build the key for the index: (operator_id, zset_id, element_id) + // For regular aggregates, use column_index=0 and AGG_TYPE_REGULAR + let operator_storage_id = + generate_storage_id(operator.operator_id, 0, AGG_TYPE_REGULAR); + let zset_id = operator.generate_group_rowid(group_key_str); + let element_id = 0i64; // Always 0 for aggregators + + // Create index key values + let index_key_values = vec![ + Value::Integer(operator_storage_id), + Value::Integer(zset_id), + Value::Integer(element_id), + ]; + + // Create an immutable record for the index key + let index_record = + ImmutableRecord::from_values(&index_key_values, index_key_values.len()); + + // Seek in the index to find if this row exists + let seek_result = return_if_io!(cursors.index_cursor.seek( + SeekKey::IndexKey(&index_record), + SeekOp::GE { eq_only: true } + )); + + let rowid = if matches!(seek_result, SeekResult::Found) { + // Found in index, get the table rowid + // The btree code handles extracting the rowid from the 
index record for has_rowid indexes + return_if_io!(cursors.index_cursor.rowid()) + } else { + // Not found in index, no existing state + None + }; + + // Always transition to FetchData + let taken_existing = std::mem::take(existing_groups); + let taken_old_values = std::mem::take(old_values); + let next_state = AggregateEvalState::FetchData { + delta: std::mem::take(delta), + current_idx: *current_idx, + groups_to_read: std::mem::take(groups_to_read), + existing_groups: taken_existing, + old_values: taken_old_values, + rowid, + read_record_state: Box::new(ReadRecord::new()), + }; + *self = next_state; + } + } + AggregateEvalState::FetchData { + delta, + current_idx, + groups_to_read, + existing_groups, + old_values, + rowid, + read_record_state, + } => { + // Get the current group to read + let (group_key_str, group_key) = &groups_to_read[*current_idx]; + + // Only try to read if we have a rowid + if let Some(rowid) = rowid { + let key = SeekKey::TableRowId(*rowid); + let state = return_if_io!(read_record_state.read_record( + key, + &operator.aggregates, + &mut cursors.table_cursor + )); + // Process the fetched state + if let Some(state) = state { + let mut old_row = group_key.clone(); + old_row.extend(state.to_values(&operator.aggregates)); + old_values.insert(group_key_str.clone(), old_row); + existing_groups.insert(group_key_str.clone(), state.clone()); + } + } else { + // No rowid for this group, skipping read + } + // If no rowid, there's no existing state for this group + + // Move to next group + let next_idx = *current_idx + 1; + let taken_existing = std::mem::take(existing_groups); + let taken_old_values = std::mem::take(old_values); + let next_state = AggregateEvalState::FetchKey { + delta: std::mem::take(delta), + current_idx: next_idx, + groups_to_read: std::mem::take(groups_to_read), + existing_groups: taken_existing, + old_values: taken_old_values, + }; + *self = next_state; + } + AggregateEvalState::RecomputeMinMax { + delta, + existing_groups, + old_values, + recompute_state, + } => { + if operator.has_min_max() { + // Process MIN/MAX recomputation - this will update existing_groups with correct MIN/MAX + return_if_io!(recompute_state.process(existing_groups, operator, cursors)); + } + + // Now compute final output with updated MIN/MAX values + let (output_delta, computed_states) = + operator.merge_delta_with_existing(delta, existing_groups, old_values); + + *self = AggregateEvalState::Done { + output: (output_delta, computed_states), + }; + } + AggregateEvalState::Done { output } => { + return Ok(IOResult::Done(output.clone())); + } + } + } + } +} + +impl AggregateState { + pub fn new() -> Self { + Self::default() + } + + // Serialize the aggregate state to a binary blob including group key values + // The reason we serialize it like this, instead of just writing the actual values, is that + // The same table may have different aggregators in the circuit. They will all have different + // columns. 
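+    //
+    // Blob layout produced by to_blob (and consumed by from_blob below):
+    //   [version: u8 = 1]
+    //   [num_group_keys: u32 LE], then per key value a type tag (0=Null, 1=Integer,
+    //   2=Float, 3=Text, 4=Blob) followed by its payload
+    //   [count: i64 LE]
+    //   then, in aggregate order: SUM -> f64 LE; AVG -> f64 LE + i64 LE;
+    //   MIN/MAX -> presence byte (0/1) followed by a serial-type encoded value.
+    // Illustrative example: COUNT(*) grouped by the integer key 7 with count = 3
+    // serializes as [01] [01 00 00 00] [01] [07 00 00 00 00 00 00 00] [03 00 00 00 00 00 00 00].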
+ fn to_blob(&self, aggregates: &[AggregateFunction], group_key: &[Value]) -> Vec { + let mut blob = Vec::new(); + + // Write version byte for future compatibility + blob.push(1u8); + + // Write number of group key values + blob.extend_from_slice(&(group_key.len() as u32).to_le_bytes()); + + // Write each group key value + for value in group_key { + // Write value type tag + match value { + Value::Null => blob.push(0u8), + Value::Integer(i) => { + blob.push(1u8); + blob.extend_from_slice(&i.to_le_bytes()); + } + Value::Float(f) => { + blob.push(2u8); + blob.extend_from_slice(&f.to_le_bytes()); + } + Value::Text(s) => { + blob.push(3u8); + let text_str = s.as_str(); + let bytes = text_str.as_bytes(); + blob.extend_from_slice(&(bytes.len() as u32).to_le_bytes()); + blob.extend_from_slice(bytes); + } + Value::Blob(b) => { + blob.push(4u8); + blob.extend_from_slice(&(b.len() as u32).to_le_bytes()); + blob.extend_from_slice(b); + } + } + } + + // Write count as 8 bytes (little-endian) + blob.extend_from_slice(&self.count.to_le_bytes()); + + // Write each aggregate's state + for agg in aggregates { + match agg { + AggregateFunction::Sum(col_name) => { + let sum = self.sums.get(col_name).copied().unwrap_or(0.0); + blob.extend_from_slice(&sum.to_le_bytes()); + } + AggregateFunction::Avg(col_name) => { + let (sum, count) = self.avgs.get(col_name).copied().unwrap_or((0.0, 0)); + blob.extend_from_slice(&sum.to_le_bytes()); + blob.extend_from_slice(&count.to_le_bytes()); + } + AggregateFunction::Count => { + // Count is already written above + } + AggregateFunction::Min(col_name) => { + // Write whether we have a MIN value (1 byte) + if let Some(min_val) = self.mins.get(col_name) { + blob.push(1u8); // Has value + serialize_value(min_val, &mut blob); + } else { + blob.push(0u8); // No value + } + } + AggregateFunction::Max(col_name) => { + // Write whether we have a MAX value (1 byte) + if let Some(max_val) = self.maxs.get(col_name) { + blob.push(1u8); // Has value + serialize_value(max_val, &mut blob); + } else { + blob.push(0u8); // No value + } + } + } + } + + blob + } + + /// Deserialize aggregate state from a binary blob + /// Returns the aggregate state and the group key values + pub fn from_blob(blob: &[u8], aggregates: &[AggregateFunction]) -> Option<(Self, Vec)> { + let mut cursor = 0; + + // Check version byte + if blob.get(cursor) != Some(&1u8) { + return None; + } + cursor += 1; + + // Read number of group key values + let num_group_keys = + u32::from_le_bytes(blob.get(cursor..cursor + 4)?.try_into().ok()?) as usize; + cursor += 4; + + // Read group key values + let mut group_key = Vec::new(); + for _ in 0..num_group_keys { + let value_type = *blob.get(cursor)?; + cursor += 1; + + let value = match value_type { + 0 => Value::Null, + 1 => { + let i = i64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); + cursor += 8; + Value::Integer(i) + } + 2 => { + let f = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); + cursor += 8; + Value::Float(f) + } + 3 => { + let len = + u32::from_le_bytes(blob.get(cursor..cursor + 4)?.try_into().ok()?) as usize; + cursor += 4; + let bytes = blob.get(cursor..cursor + len)?; + cursor += len; + let text_str = std::str::from_utf8(bytes).ok()?; + Value::Text(text_str.to_string().into()) + } + 4 => { + let len = + u32::from_le_bytes(blob.get(cursor..cursor + 4)?.try_into().ok()?) 
as usize; + cursor += 4; + let bytes = blob.get(cursor..cursor + len)?; + cursor += len; + Value::Blob(bytes.to_vec()) + } + _ => return None, + }; + group_key.push(value); + } + + // Read count + let count = i64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); + cursor += 8; + + let mut state = Self::new(); + state.count = count; + + // Read each aggregate's state + for agg in aggregates { + match agg { + AggregateFunction::Sum(col_name) => { + let sum = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); + cursor += 8; + state.sums.insert(*col_name, sum); + } + AggregateFunction::Avg(col_name) => { + let sum = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); + cursor += 8; + let count = i64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); + cursor += 8; + state.avgs.insert(*col_name, (sum, count)); + } + AggregateFunction::Count => { + // Count was already read above + } + AggregateFunction::Min(col_name) => { + // Read whether we have a MIN value + let has_value = *blob.get(cursor)?; + cursor += 1; + + if has_value == 1 { + let (min_value, bytes_consumed) = deserialize_value(&blob[cursor..])?; + cursor += bytes_consumed; + state.mins.insert(*col_name, min_value); + } + } + AggregateFunction::Max(col_name) => { + // Read whether we have a MAX value + let has_value = *blob.get(cursor)?; + cursor += 1; + + if has_value == 1 { + let (max_value, bytes_consumed) = deserialize_value(&blob[cursor..])?; + cursor += bytes_consumed; + state.maxs.insert(*col_name, max_value); + } + } + } + } + + Some((state, group_key)) + } + + /// Apply a delta to this aggregate state + fn apply_delta( + &mut self, + values: &[Value], + weight: isize, + aggregates: &[AggregateFunction], + _column_names: &[String], // No longer needed + ) { + // Update COUNT + self.count += weight as i64; + + // Update other aggregates + for agg in aggregates { + match agg { + AggregateFunction::Count => { + // Already handled above + } + AggregateFunction::Sum(col_idx) => { + if let Some(val) = values.get(*col_idx) { + let num_val = match val { + Value::Integer(i) => *i as f64, + Value::Float(f) => *f, + _ => 0.0, + }; + *self.sums.entry(*col_idx).or_insert(0.0) += num_val * weight as f64; + } + } + AggregateFunction::Avg(col_idx) => { + if let Some(val) = values.get(*col_idx) { + let num_val = match val { + Value::Integer(i) => *i as f64, + Value::Float(f) => *f, + _ => 0.0, + }; + let (sum, count) = self.avgs.entry(*col_idx).or_insert((0.0, 0)); + *sum += num_val * weight as f64; + *count += weight as i64; + } + } + AggregateFunction::Min(_col_name) | AggregateFunction::Max(_col_name) => { + // MIN/MAX cannot be handled incrementally in apply_delta because: + // + // 1. For insertions: We can't just keep the minimum/maximum value. + // We need to track ALL values to handle future deletions correctly. + // + // 2. For deletions (retractions): If we delete the current MIN/MAX, + // we need to find the next best value, which requires knowing all + // other values in the group. 
+ // + // Example: Consider MIN(price) with values [10, 20, 30] + // - Current MIN = 10 + // - Delete 10 (weight = -1) + // - New MIN should be 20, but we can't determine this without + // having tracked all values [20, 30] + // + // Therefore, MIN/MAX processing is handled separately: + // - All input values are persisted to the index via persist_min_max() + // - When aggregates have MIN/MAX, we unconditionally transition to + // the RecomputeMinMax state machine (see EvalState::RecomputeMinMax) + // - RecomputeMinMax checks if the current MIN/MAX was deleted, and if so, + // scans the index to find the new MIN/MAX from remaining values + // + // This ensures correctness for incremental computation at the cost of + // additional I/O for MIN/MAX operations. + } + } + } + } + + /// Convert aggregate state to output values + pub fn to_values(&self, aggregates: &[AggregateFunction]) -> Vec { + let mut result = Vec::new(); + + for agg in aggregates { + match agg { + AggregateFunction::Count => { + result.push(Value::Integer(self.count)); + } + AggregateFunction::Sum(col_idx) => { + let sum = self.sums.get(col_idx).copied().unwrap_or(0.0); + // Return as integer if it's a whole number, otherwise as float + if sum.fract() == 0.0 { + result.push(Value::Integer(sum as i64)); + } else { + result.push(Value::Float(sum)); + } + } + AggregateFunction::Avg(col_idx) => { + if let Some((sum, count)) = self.avgs.get(col_idx) { + if *count > 0 { + result.push(Value::Float(sum / *count as f64)); + } else { + result.push(Value::Null); + } + } else { + result.push(Value::Null); + } + } + AggregateFunction::Min(col_idx) => { + // Return the MIN value from our state + result.push(self.mins.get(col_idx).cloned().unwrap_or(Value::Null)); + } + AggregateFunction::Max(col_idx) => { + // Return the MAX value from our state + result.push(self.maxs.get(col_idx).cloned().unwrap_or(Value::Null)); + } + } + } + + result + } +} + +impl AggregateOperator { + pub fn new( + operator_id: usize, + group_by: Vec, + aggregates: Vec, + input_column_names: Vec, + ) -> Self { + // Build map of column indices to their MIN/MAX info + let mut column_min_max = HashMap::new(); + let mut storage_indices = HashMap::new(); + let mut current_index = 0; + + // First pass: assign storage indices to unique MIN/MAX columns + for agg in &aggregates { + match agg { + AggregateFunction::Min(col_idx) | AggregateFunction::Max(col_idx) => { + storage_indices.entry(*col_idx).or_insert_with(|| { + let idx = current_index; + current_index += 1; + idx + }); + } + _ => {} + } + } + + // Second pass: build the column info map + for agg in &aggregates { + match agg { + AggregateFunction::Min(col_idx) => { + let storage_index = *storage_indices.get(col_idx).unwrap(); + let entry = column_min_max.entry(*col_idx).or_insert(AggColumnInfo { + index: storage_index, + has_min: false, + has_max: false, + }); + entry.has_min = true; + } + AggregateFunction::Max(col_idx) => { + let storage_index = *storage_indices.get(col_idx).unwrap(); + let entry = column_min_max.entry(*col_idx).or_insert(AggColumnInfo { + index: storage_index, + has_min: false, + has_max: false, + }); + entry.has_max = true; + } + _ => {} + } + } + + Self { + operator_id, + group_by, + aggregates, + input_column_names, + column_min_max, + tracker: None, + commit_state: AggregateCommitState::Idle, + } + } + + pub fn has_min_max(&self) -> bool { + !self.column_min_max.is_empty() + } + + fn eval_internal( + &mut self, + state: &mut EvalState, + cursors: &mut DbspStateCursors, + ) -> Result> { + match 
state { + EvalState::Uninitialized => { + panic!("Cannot eval AggregateOperator with Uninitialized state"); + } + EvalState::Init { deltas } => { + // Aggregate operators only use left_delta, right_delta must be empty + assert!( + deltas.right.is_empty(), + "AggregateOperator expects right_delta to be empty" + ); + + if deltas.left.changes.is_empty() { + *state = EvalState::Done; + return Ok(IOResult::Done((Delta::new(), HashMap::new()))); + } + + let mut groups_to_read = BTreeMap::new(); + for (row, _weight) in &deltas.left.changes { + let group_key = self.extract_group_key(&row.values); + let group_key_str = Self::group_key_to_string(&group_key); + groups_to_read.insert(group_key_str, group_key); + } + + let delta = std::mem::take(&mut deltas.left); + *state = EvalState::Aggregate(Box::new(AggregateEvalState::FetchKey { + delta, + current_idx: 0, + groups_to_read: groups_to_read.into_iter().collect(), + existing_groups: HashMap::new(), + old_values: HashMap::new(), + })); + } + EvalState::Aggregate(_agg_state) => { + // Already in progress, continue processing below. + } + EvalState::Done => { + panic!("unreachable state! should have returned"); + } + EvalState::Join(_) => { + panic!("Join state should not appear in aggregate operator"); + } + } + + // Process the delta through the aggregate state machine + match state { + EvalState::Aggregate(agg_state) => { + let result = return_if_io!(agg_state.process_delta(self, cursors)); + Ok(IOResult::Done(result)) + } + _ => panic!("Invalid state for aggregate processing"), + } + } + + fn merge_delta_with_existing( + &mut self, + delta: &Delta, + existing_groups: &mut HashMap, + old_values: &mut HashMap>, + ) -> (Delta, HashMap, AggregateState)>) { + let mut output_delta = Delta::new(); + let mut temp_keys: HashMap> = HashMap::new(); + + // Process each change in the delta + for (row, weight) in &delta.changes { + if let Some(tracker) = &self.tracker { + tracker.lock().unwrap().record_aggregation(); + } + + // Extract group key + let group_key = self.extract_group_key(&row.values); + let group_key_str = Self::group_key_to_string(&group_key); + + let state = existing_groups.entry(group_key_str.clone()).or_default(); + + temp_keys.insert(group_key_str.clone(), group_key.clone()); + + // Apply the delta to the temporary state + state.apply_delta( + &row.values, + *weight, + &self.aggregates, + &self.input_column_names, + ); + } + + // Generate output delta from temporary states and collect final states + let mut final_states = HashMap::new(); + + for (group_key_str, state) in existing_groups { + let group_key = temp_keys.get(group_key_str).cloned().unwrap_or_default(); + + // Generate a unique rowid for this group + let result_key = self.generate_group_rowid(group_key_str); + + if let Some(old_row_values) = old_values.get(group_key_str) { + let old_row = HashableRow::new(result_key, old_row_values.clone()); + output_delta.changes.push((old_row, -1)); + } + + // Always store the state for persistence (even if count=0, we need to delete it) + final_states.insert(group_key_str.clone(), (group_key.clone(), state.clone())); + + // Only include groups with count > 0 in the output delta + if state.count > 0 { + // Build output row: group_by columns + aggregate values + let mut output_values = group_key.clone(); + let aggregate_values = state.to_values(&self.aggregates); + output_values.extend(aggregate_values); + + let output_row = HashableRow::new(result_key, output_values.clone()); + output_delta.changes.push((output_row, 1)); + } + } + (output_delta, 
final_states) + } + + /// Extract MIN/MAX values from delta changes for persistence to index + fn extract_min_max_deltas(&self, delta: &Delta) -> MinMaxDeltas { + let mut min_max_deltas: MinMaxDeltas = HashMap::new(); + + for (row, weight) in &delta.changes { + let group_key = self.extract_group_key(&row.values); + let group_key_str = Self::group_key_to_string(&group_key); + + for agg in &self.aggregates { + match agg { + AggregateFunction::Min(col_idx) | AggregateFunction::Max(col_idx) => { + if let Some(val) = row.values.get(*col_idx) { + // Skip NULL values - they don't participate in MIN/MAX + if val == &Value::Null { + continue; + } + // Create a HashableRow with just this value + // Use 0 as rowid since we only care about the value for comparison + let hashable_value = HashableRow::new(0, vec![val.clone()]); + let key = (*col_idx, hashable_value); + + let group_entry = + min_max_deltas.entry(group_key_str.clone()).or_default(); + + let value_entry = group_entry.entry(key).or_insert(0); + + // Accumulate the weight + *value_entry += weight; + } + } + _ => {} // Ignore non-MIN/MAX aggregates + } + } + } + + min_max_deltas + } + + pub fn set_tracker(&mut self, tracker: Arc>) { + self.tracker = Some(tracker); + } + + /// Generate a rowid for a group + /// For no GROUP BY: always returns 0 + /// For GROUP BY: returns a hash of the group key string + pub fn generate_group_rowid(&self, group_key_str: &str) -> i64 { + if self.group_by.is_empty() { + 0 + } else { + group_key_str + .bytes() + .fold(0i64, |acc, b| acc.wrapping_mul(31).wrapping_add(b as i64)) + } + } + + /// Extract group key values from a row + pub fn extract_group_key(&self, values: &[Value]) -> Vec { + let mut key = Vec::new(); + + for &idx in &self.group_by { + if let Some(val) = values.get(idx) { + key.push(val.clone()); + } else { + key.push(Value::Null); + } + } + + key + } + + /// Convert group key to string for indexing (since Value doesn't implement Hash) + pub fn group_key_to_string(key: &[Value]) -> String { + key.iter() + .map(|v| format!("{v:?}")) + .collect::>() + .join(",") + } +} + +impl IncrementalOperator for AggregateOperator { + fn eval( + &mut self, + state: &mut EvalState, + cursors: &mut DbspStateCursors, + ) -> Result> { + let (delta, _) = return_if_io!(self.eval_internal(state, cursors)); + Ok(IOResult::Done(delta)) + } + + fn commit( + &mut self, + mut deltas: DeltaPair, + cursors: &mut DbspStateCursors, + ) -> Result> { + // Aggregate operator only uses left delta, right must be empty + assert!( + deltas.right.is_empty(), + "AggregateOperator expects right delta to be empty in commit" + ); + let delta = std::mem::take(&mut deltas.left); + loop { + // Note: because we std::mem::replace here (without it, the borrow checker goes nuts, + // because we call self.eval_interval, which requires a mutable borrow), we have to + // restore the state if we return I/O. So we can't use return_if_io! + let mut state = + std::mem::replace(&mut self.commit_state, AggregateCommitState::Invalid); + match &mut state { + AggregateCommitState::Invalid => { + panic!("Reached invalid state! 
State was replaced, and not replaced back"); + } + AggregateCommitState::Idle => { + let eval_state = EvalState::from_delta(delta.clone()); + self.commit_state = AggregateCommitState::Eval { eval_state }; + } + AggregateCommitState::Eval { ref mut eval_state } => { + // Extract input delta before eval for MIN/MAX processing + let input_delta = eval_state.extract_delta(); + + // Extract MIN/MAX deltas before any I/O operations + let min_max_deltas = self.extract_min_max_deltas(&input_delta); + + // Create a new eval state with the same delta + *eval_state = EvalState::from_delta(input_delta.clone()); + + let (output_delta, computed_states) = return_and_restore_if_io!( + &mut self.commit_state, + state, + self.eval_internal(eval_state, cursors) + ); + + self.commit_state = AggregateCommitState::PersistDelta { + delta: output_delta, + computed_states, + current_idx: 0, + write_row: WriteRow::new(), + min_max_deltas, // Store for later use + }; + } + AggregateCommitState::PersistDelta { + delta, + computed_states, + current_idx, + write_row, + min_max_deltas, + } => { + let states_vec: Vec<_> = computed_states.iter().collect(); + + if *current_idx >= states_vec.len() { + // Use the min_max_deltas we extracted earlier from the input delta + self.commit_state = AggregateCommitState::PersistMinMax { + delta: delta.clone(), + min_max_persist_state: MinMaxPersistState::new(min_max_deltas.clone()), + }; + } else { + let (group_key_str, (group_key, agg_state)) = states_vec[*current_idx]; + + // Build the key components for the new table structure + // For regular aggregates, use column_index=0 and AGG_TYPE_REGULAR + let operator_storage_id = + generate_storage_id(self.operator_id, 0, AGG_TYPE_REGULAR); + let zset_id = self.generate_group_rowid(group_key_str); + let element_id = 0i64; + + // Determine weight: -1 to delete (cancels existing weight=1), 1 to insert/update + let weight = if agg_state.count == 0 { -1 } else { 1 }; + + // Serialize the aggregate state with group key (even for deletion, we need a row) + let state_blob = agg_state.to_blob(&self.aggregates, group_key); + let blob_value = Value::Blob(state_blob); + + // Build the aggregate storage format: [operator_id, zset_id, element_id, value, weight] + let operator_id_val = Value::Integer(operator_storage_id); + let zset_id_val = Value::Integer(zset_id); + let element_id_val = Value::Integer(element_id); + let blob_val = blob_value.clone(); + + // Create index key - the first 3 columns of our primary key + let index_key = vec![ + operator_id_val.clone(), + zset_id_val.clone(), + element_id_val.clone(), + ]; + + // Record values (without weight) + let record_values = + vec![operator_id_val, zset_id_val, element_id_val, blob_val]; + + return_and_restore_if_io!( + &mut self.commit_state, + state, + write_row.write_row(cursors, index_key, record_values, weight) + ); + + let delta = std::mem::take(delta); + let computed_states = std::mem::take(computed_states); + let min_max_deltas = std::mem::take(min_max_deltas); + + self.commit_state = AggregateCommitState::PersistDelta { + delta, + computed_states, + current_idx: *current_idx + 1, + write_row: WriteRow::new(), // Reset for next write + min_max_deltas, + }; + } + } + AggregateCommitState::PersistMinMax { + delta, + min_max_persist_state, + } => { + if !self.has_min_max() { + let delta = std::mem::take(delta); + self.commit_state = AggregateCommitState::Done { delta }; + } else { + return_and_restore_if_io!( + &mut self.commit_state, + state, + min_max_persist_state.persist_min_max( + 
self.operator_id, + &self.column_min_max, + cursors, + |group_key_str| self.generate_group_rowid(group_key_str) + ) + ); + + let delta = std::mem::take(delta); + self.commit_state = AggregateCommitState::Done { delta }; + } + } + AggregateCommitState::Done { delta } => { + self.commit_state = AggregateCommitState::Idle; + let delta = std::mem::take(delta); + return Ok(IOResult::Done(delta)); + } + } + } + } + + fn set_tracker(&mut self, tracker: Arc>) { + self.tracker = Some(tracker); + } +} + +/// State machine for recomputing MIN/MAX values after deletion +#[derive(Debug)] +pub enum RecomputeMinMax { + ProcessElements { + /// Current column being processed + current_column_idx: usize, + /// Columns to process (combined MIN and MAX) + columns_to_process: Vec<(String, usize, bool)>, // (group_key, column_name, is_min) + /// MIN/MAX deltas for checking values and weights + min_max_deltas: MinMaxDeltas, + }, + Scan { + /// Columns still to process + columns_to_process: Vec<(String, usize, bool)>, + /// Current index in columns_to_process (will resume from here) + current_column_idx: usize, + /// MIN/MAX deltas for checking values and weights + min_max_deltas: MinMaxDeltas, + /// Current group key being processed + group_key: String, + /// Current column name being processed + column_name: usize, + /// Whether we're looking for MIN (true) or MAX (false) + is_min: bool, + /// The scan state machine for finding the new MIN/MAX + scan_state: Box, + }, + Done, +} + +impl RecomputeMinMax { + pub fn new( + min_max_deltas: MinMaxDeltas, + existing_groups: &HashMap, + operator: &AggregateOperator, + ) -> Self { + let mut groups_to_check: HashSet<(String, usize, bool)> = HashSet::new(); + + // Remember the min_max_deltas are essentially just the only column that is affected by + // this min/max, in delta (actually ZSet - consolidated delta) format. This makes it easier + // for us to consume it in here. + // + // The most challenging case is the case where there is a retraction, since we need to go + // back to the index. 
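+        // For example, with MIN(price) per department and an existing group whose
+        // current minimum is 10, a delta entry for (price_col, 10) with weight -1
+        // retracts that minimum, so the loop below queues (group, price_col, is_min=true)
+        // and the Scan state later seeks the MIN/MAX index for the next surviving value.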
+ for (group_key_str, values) in &min_max_deltas { + for ((col_name, hashable_row), weight) in values { + let col_info = operator.column_min_max.get(col_name); + + let value = &hashable_row.values[0]; + + if *weight < 0 { + // Deletion detected - check if it's the current MIN/MAX + if let Some(state) = existing_groups.get(group_key_str) { + // Check for MIN + if let Some(current_min) = state.mins.get(col_name) { + if current_min == value { + groups_to_check.insert((group_key_str.clone(), *col_name, true)); + } + } + // Check for MAX + if let Some(current_max) = state.maxs.get(col_name) { + if current_max == value { + groups_to_check.insert((group_key_str.clone(), *col_name, false)); + } + } + } + } else if *weight > 0 { + // If it is not found in the existing groups, then we only need to care + // about this if this is a new record being inserted + if let Some(info) = col_info { + if info.has_min { + groups_to_check.insert((group_key_str.clone(), *col_name, true)); + } + if info.has_max { + groups_to_check.insert((group_key_str.clone(), *col_name, false)); + } + } + } + } + } + + if groups_to_check.is_empty() { + // No recomputation or initialization needed + Self::Done + } else { + // Convert HashSet to Vec for indexed processing + let groups_to_check_vec: Vec<_> = groups_to_check.into_iter().collect(); + Self::ProcessElements { + current_column_idx: 0, + columns_to_process: groups_to_check_vec, + min_max_deltas, + } + } + } + + pub fn process( + &mut self, + existing_groups: &mut HashMap, + operator: &AggregateOperator, + cursors: &mut DbspStateCursors, + ) -> Result> { + loop { + match self { + RecomputeMinMax::ProcessElements { + current_column_idx, + columns_to_process, + min_max_deltas, + } => { + if *current_column_idx >= columns_to_process.len() { + *self = RecomputeMinMax::Done; + return Ok(IOResult::Done(())); + } + + let (group_key, column_name, is_min) = + columns_to_process[*current_column_idx].clone(); + + // Column name is already the index + // Get the storage index from column_min_max map + let column_info = operator + .column_min_max + .get(&column_name) + .expect("Column should exist in column_min_max map"); + let storage_index = column_info.index; + + // Get current value from existing state + let current_value = existing_groups.get(&group_key).and_then(|state| { + if is_min { + state.mins.get(&column_name).cloned() + } else { + state.maxs.get(&column_name).cloned() + } + }); + + // Create storage keys for index lookup + let storage_id = + generate_storage_id(operator.operator_id, storage_index, AGG_TYPE_MINMAX); + let zset_id = operator.generate_group_rowid(&group_key); + + // Get the values for this group from min_max_deltas + let group_values = min_max_deltas.get(&group_key).cloned().unwrap_or_default(); + + let columns_to_process = std::mem::take(columns_to_process); + let min_max_deltas = std::mem::take(min_max_deltas); + + let scan_state = if is_min { + Box::new(ScanState::new_for_min( + current_value, + group_key.clone(), + column_name, + storage_id, + zset_id, + group_values, + )) + } else { + Box::new(ScanState::new_for_max( + current_value, + group_key.clone(), + column_name, + storage_id, + zset_id, + group_values, + )) + }; + + *self = RecomputeMinMax::Scan { + columns_to_process, + current_column_idx: *current_column_idx, + min_max_deltas, + group_key, + column_name, + is_min, + scan_state, + }; + } + RecomputeMinMax::Scan { + columns_to_process, + current_column_idx, + min_max_deltas, + group_key, + column_name, + is_min, + scan_state, + } => { + // Find 
new value using the scan state machine + let new_value = return_if_io!(scan_state.find_new_value(cursors)); + + // Update the state with new value (create if doesn't exist) + let state = existing_groups.entry(group_key.clone()).or_default(); + + if *is_min { + if let Some(min_val) = new_value { + state.mins.insert(*column_name, min_val); + } else { + state.mins.remove(column_name); + } + } else if let Some(max_val) = new_value { + state.maxs.insert(*column_name, max_val); + } else { + state.maxs.remove(column_name); + } + + // Move to next column + let min_max_deltas = std::mem::take(min_max_deltas); + let columns_to_process = std::mem::take(columns_to_process); + *self = RecomputeMinMax::ProcessElements { + current_column_idx: *current_column_idx + 1, + columns_to_process, + min_max_deltas, + }; + } + RecomputeMinMax::Done => { + return Ok(IOResult::Done(())); + } + } + } + } +} + +/// State machine for scanning through the index to find new MIN/MAX values +#[derive(Debug)] +pub enum ScanState { + CheckCandidate { + /// Current candidate value for MIN/MAX + candidate: Option, + /// Group key being processed + group_key: String, + /// Column name being processed + column_name: usize, + /// Storage ID for the index seek + storage_id: i64, + /// ZSet ID for the group + zset_id: i64, + /// Group values from MinMaxDeltas: (column_name, HashableRow) -> weight + group_values: HashMap<(usize, HashableRow), isize>, + /// Whether we're looking for MIN (true) or MAX (false) + is_min: bool, + }, + FetchNextCandidate { + /// Current candidate to seek past + current_candidate: Value, + /// Group key being processed + group_key: String, + /// Column name being processed + column_name: usize, + /// Storage ID for the index seek + storage_id: i64, + /// ZSet ID for the group + zset_id: i64, + /// Group values from MinMaxDeltas: (column_name, HashableRow) -> weight + group_values: HashMap<(usize, HashableRow), isize>, + /// Whether we're looking for MIN (true) or MAX (false) + is_min: bool, + }, + Done { + /// The final MIN/MAX value found + result: Option, + }, +} + +impl ScanState { + pub fn new_for_min( + current_min: Option, + group_key: String, + column_name: usize, + storage_id: i64, + zset_id: i64, + group_values: HashMap<(usize, HashableRow), isize>, + ) -> Self { + Self::CheckCandidate { + candidate: current_min, + group_key, + column_name, + storage_id, + zset_id, + group_values, + is_min: true, + } + } + + // Extract a new candidate from the index. It is possible that, when searching, + // we end up going into a different operator altogether. 
That means we have + // exhausted this operator (or group) entirely, and no good candidate was found + fn extract_new_candidate( + cursors: &mut DbspStateCursors, + index_record: &ImmutableRecord, + seek_op: SeekOp, + storage_id: i64, + zset_id: i64, + ) -> Result>> { + let seek_result = return_if_io!(cursors + .index_cursor + .seek(SeekKey::IndexKey(index_record), seek_op)); + if !matches!(seek_result, SeekResult::Found) { + return Ok(IOResult::Done(None)); + } + + let record = return_if_io!(cursors.index_cursor.record()).ok_or_else(|| { + LimboError::InternalError( + "Record found on the cursor, but could not be read".to_string(), + ) + })?; + + let values = record.get_values(); + if values.len() < 3 { + return Ok(IOResult::Done(None)); + } + + let Some(rec_storage_id) = values.first() else { + return Ok(IOResult::Done(None)); + }; + let Some(rec_zset_id) = values.get(1) else { + return Ok(IOResult::Done(None)); + }; + + // Check if we're still in the same group + if let (RefValue::Integer(rec_sid), RefValue::Integer(rec_zid)) = + (rec_storage_id, rec_zset_id) + { + if *rec_sid != storage_id || *rec_zid != zset_id { + return Ok(IOResult::Done(None)); + } + } else { + return Ok(IOResult::Done(None)); + } + + // Get the value (3rd element) + Ok(IOResult::Done(values.get(2).map(|v| v.to_owned()))) + } + + pub fn new_for_max( + current_max: Option, + group_key: String, + column_name: usize, + storage_id: i64, + zset_id: i64, + group_values: HashMap<(usize, HashableRow), isize>, + ) -> Self { + Self::CheckCandidate { + candidate: current_max, + group_key, + column_name, + storage_id, + zset_id, + group_values, + is_min: false, + } + } + + pub fn find_new_value( + &mut self, + cursors: &mut DbspStateCursors, + ) -> Result>> { + loop { + match self { + ScanState::CheckCandidate { + candidate, + group_key, + column_name, + storage_id, + zset_id, + group_values, + is_min, + } => { + // First, check if we have a candidate + if let Some(cand_val) = candidate { + // Check if the candidate is retracted (weight <= 0) + // Create a HashableRow to look up the weight + let hashable_cand = HashableRow::new(0, vec![cand_val.clone()]); + let key = (*column_name, hashable_cand); + let is_retracted = + group_values.get(&key).is_some_and(|weight| *weight <= 0); + + if is_retracted { + // Candidate is retracted, need to fetch next from index + *self = ScanState::FetchNextCandidate { + current_candidate: cand_val.clone(), + group_key: std::mem::take(group_key), + column_name: std::mem::take(column_name), + storage_id: *storage_id, + zset_id: *zset_id, + group_values: std::mem::take(group_values), + is_min: *is_min, + }; + continue; + } + } + + // Candidate is valid or we have no candidate + // Now find the best value from insertions in group_values + let mut best_from_zset = None; + for ((col, hashable_val), weight) in group_values.iter() { + if col == column_name && *weight > 0 { + let value = &hashable_val.values[0]; + // Skip NULL values - they don't participate in MIN/MAX + if value == &Value::Null { + continue; + } + // This is an insertion for our column + if let Some(ref current_best) = best_from_zset { + if *is_min { + if value.cmp(current_best) == std::cmp::Ordering::Less { + best_from_zset = Some(value.clone()); + } + } else if value.cmp(current_best) == std::cmp::Ordering::Greater { + best_from_zset = Some(value.clone()); + } + } else { + best_from_zset = Some(value.clone()); + } + } + } + + // Compare candidate with best from ZSet, filtering out NULLs + let result = match (&candidate, 
&best_from_zset) { + (Some(cand), Some(zset_val)) if cand != &Value::Null => { + if *is_min { + if zset_val.cmp(cand) == std::cmp::Ordering::Less { + Some(zset_val.clone()) + } else { + Some(cand.clone()) + } + } else if zset_val.cmp(cand) == std::cmp::Ordering::Greater { + Some(zset_val.clone()) + } else { + Some(cand.clone()) + } + } + (Some(cand), None) if cand != &Value::Null => Some(cand.clone()), + (None, Some(zset_val)) => Some(zset_val.clone()), + (Some(cand), Some(_)) if cand == &Value::Null => best_from_zset, + _ => None, + }; + + *self = ScanState::Done { result }; + } + + ScanState::FetchNextCandidate { + current_candidate, + group_key, + column_name, + storage_id, + zset_id, + group_values, + is_min, + } => { + // Seek to the next value in the index + let index_key = vec![ + Value::Integer(*storage_id), + Value::Integer(*zset_id), + current_candidate.clone(), + ]; + let index_record = ImmutableRecord::from_values(&index_key, index_key.len()); + + let seek_op = if *is_min { + SeekOp::GT // For MIN, seek greater than current + } else { + SeekOp::LT // For MAX, seek less than current + }; + + let new_candidate = return_if_io!(Self::extract_new_candidate( + cursors, + &index_record, + seek_op, + *storage_id, + *zset_id + )); + + *self = ScanState::CheckCandidate { + candidate: new_candidate, + group_key: std::mem::take(group_key), + column_name: std::mem::take(column_name), + storage_id: *storage_id, + zset_id: *zset_id, + group_values: std::mem::take(group_values), + is_min: *is_min, + }; + } + + ScanState::Done { result } => { + return Ok(IOResult::Done(result.clone())); + } + } + } + } +} + +/// State machine for persisting Min/Max values to storage +#[derive(Debug)] +pub enum MinMaxPersistState { + Init { + min_max_deltas: MinMaxDeltas, + group_keys: Vec, + }, + ProcessGroup { + min_max_deltas: MinMaxDeltas, + group_keys: Vec, + group_idx: usize, + value_idx: usize, + }, + WriteValue { + min_max_deltas: MinMaxDeltas, + group_keys: Vec, + group_idx: usize, + value_idx: usize, + value: Value, + column_name: usize, + weight: isize, + write_row: WriteRow, + }, + Done, +} + +impl MinMaxPersistState { + pub fn new(min_max_deltas: MinMaxDeltas) -> Self { + let group_keys: Vec = min_max_deltas.keys().cloned().collect(); + Self::Init { + min_max_deltas, + group_keys, + } + } + + pub fn persist_min_max( + &mut self, + operator_id: usize, + column_min_max: &HashMap, + cursors: &mut DbspStateCursors, + generate_group_rowid: impl Fn(&str) -> i64, + ) -> Result> { + loop { + match self { + MinMaxPersistState::Init { + min_max_deltas, + group_keys, + } => { + let min_max_deltas = std::mem::take(min_max_deltas); + let group_keys = std::mem::take(group_keys); + *self = MinMaxPersistState::ProcessGroup { + min_max_deltas, + group_keys, + group_idx: 0, + value_idx: 0, + }; + } + MinMaxPersistState::ProcessGroup { + min_max_deltas, + group_keys, + group_idx, + value_idx, + } => { + // Check if we're past all groups + if *group_idx >= group_keys.len() { + *self = MinMaxPersistState::Done; + continue; + } + + let group_key_str = &group_keys[*group_idx]; + let values = &min_max_deltas[group_key_str]; // This should always exist + + // Convert HashMap to Vec for indexed access + let values_vec: Vec<_> = values.iter().collect(); + + // Check if we have more values in current group + if *value_idx >= values_vec.len() { + *group_idx += 1; + *value_idx = 0; + // Continue to check if we're past all groups now + continue; + } + + // Process current value and extract what we need before taking ownership + 
let ((column_name, hashable_row), weight) = values_vec[*value_idx]; + let column_name = *column_name; + let value = hashable_row.values[0].clone(); // Extract the Value from HashableRow + let weight = *weight; + + let min_max_deltas = std::mem::take(min_max_deltas); + let group_keys = std::mem::take(group_keys); + *self = MinMaxPersistState::WriteValue { + min_max_deltas, + group_keys, + group_idx: *group_idx, + value_idx: *value_idx, + column_name, + value, + weight, + write_row: WriteRow::new(), + }; + } + MinMaxPersistState::WriteValue { + min_max_deltas, + group_keys, + group_idx, + value_idx, + value, + column_name, + weight, + write_row, + } => { + // Should have exited in the previous state + assert!(*group_idx < group_keys.len()); + + let group_key_str = &group_keys[*group_idx]; + + // Get the column info from the pre-computed map + let column_info = column_min_max + .get(column_name) + .expect("Column should exist in column_min_max map"); + let column_index = column_info.index; + + // Build the key components for MinMax storage using new encoding + let storage_id = + generate_storage_id(operator_id, column_index, AGG_TYPE_MINMAX); + let zset_id = generate_group_rowid(group_key_str); + + // element_id is the actual value for Min/Max + let element_id_val = value.clone(); + + // Create index key + let index_key = vec![ + Value::Integer(storage_id), + Value::Integer(zset_id), + element_id_val.clone(), + ]; + + // Record values (operator_id, zset_id, element_id, unused_placeholder) + // For MIN/MAX, the element_id IS the value, so we use NULL for the 4th column + let record_values = vec![ + Value::Integer(storage_id), + Value::Integer(zset_id), + element_id_val.clone(), + Value::Null, // Placeholder - not used for MIN/MAX + ]; + + return_if_io!(write_row.write_row( + cursors, + index_key.clone(), + record_values, + *weight + )); + + // Move to next value + let min_max_deltas = std::mem::take(min_max_deltas); + let group_keys = std::mem::take(group_keys); + *self = MinMaxPersistState::ProcessGroup { + min_max_deltas, + group_keys, + group_idx: *group_idx, + value_idx: *value_idx + 1, + }; + } + MinMaxPersistState::Done => { + return Ok(IOResult::Done(())); + } + } + } + } +} diff --git a/core/incremental/compiler.rs b/core/incremental/compiler.rs index 972d6797b..8c8189261 100644 --- a/core/incremental/compiler.rs +++ b/core/incremental/compiler.rs @@ -9,12 +9,14 @@ use crate::incremental::dbsp::{Delta, DeltaPair}; use crate::incremental::expr_compiler::CompiledExpression; use crate::incremental::operator::{ create_dbsp_state_index, DbspStateCursors, EvalState, FilterOperator, FilterPredicate, - IncrementalOperator, InputOperator, ProjectOperator, + IncrementalOperator, InputOperator, JoinOperator, JoinType, ProjectOperator, }; +use crate::schema::Type; use crate::storage::btree::{BTreeCursor, BTreeKey}; // Note: logical module must be made pub(crate) in translate/mod.rs use crate::translate::logical::{ - BinaryOperator, LogicalExpr, LogicalPlan, LogicalSchema, SchemaRef, + BinaryOperator, Column, ColumnInfo, JoinType as LogicalJoinType, LogicalExpr, LogicalPlan, + LogicalSchema, SchemaRef, }; use crate::types::{IOResult, ImmutableRecord, SeekKey, SeekOp, SeekResult, Value}; use crate::Pager; @@ -288,6 +290,12 @@ pub enum DbspOperator { aggr_exprs: Vec, schema: SchemaRef, }, + /// Join operator (⋈) - joins two relations + Join { + join_type: JoinType, + on_exprs: Vec<(DbspExpr, DbspExpr)>, + schema: SchemaRef, + }, /// Input operator - source of data Input { name: String, schema: 
SchemaRef }, } @@ -789,6 +797,13 @@ impl DbspCircuit { "{indent}Aggregate[{node_id}]: GROUP BY {group_exprs:?}, AGGR {aggr_exprs:?}" )?; } + DbspOperator::Join { + join_type, + on_exprs, + .. + } => { + writeln!(f, "{indent}Join[{node_id}]: {join_type:?} ON {on_exprs:?}")?; + } DbspOperator::Input { name, .. } => { writeln!(f, "{indent}Input[{node_id}]: {name}")?; } @@ -841,7 +856,7 @@ impl DbspCompiler { // Get input column names for the ProjectOperator let input_schema = proj.input.schema(); let input_column_names: Vec = input_schema.columns.iter() - .map(|(name, _)| name.clone()) + .map(|col| col.name.clone()) .collect(); // Convert logical expressions to DBSP expressions @@ -853,14 +868,14 @@ impl DbspCompiler { let mut compiled_exprs = Vec::new(); let mut aliases = Vec::new(); for expr in &proj.exprs { - let (compiled, alias) = Self::compile_expression(expr, &input_column_names)?; + let (compiled, alias) = Self::compile_expression(expr, input_schema)?; compiled_exprs.push(compiled); aliases.push(alias); } // Get output column names from the projection schema let output_column_names: Vec = proj.schema.columns.iter() - .map(|(name, _)| name.clone()) + .map(|col| col.name.clone()) .collect(); // Create the ProjectOperator @@ -882,29 +897,154 @@ impl DbspCompiler { // Compile the input first let input_id = self.compile_plan(&filter.input)?; - // Get column names from input schema + // Get input schema for column resolution let input_schema = filter.input.schema(); - let column_names: Vec = input_schema.columns.iter() - .map(|(name, _)| name.clone()) - .collect(); - // Convert predicate to DBSP expression - let dbsp_predicate = Self::compile_expr(&filter.predicate)?; + // Check if the predicate contains expressions that need to be computed + if Self::predicate_needs_projection(&filter.predicate) { + // Complex expression in WHERE clause - need to add projection first + // 1. 
Create projection that adds the computed expression as a new column - // Convert to FilterPredicate - let filter_predicate = Self::compile_filter_predicate(&filter.predicate)?; + // First, get all existing columns + let mut projection_exprs = Vec::new(); + let mut dbsp_exprs = Vec::new(); - // Create executable operator - let executable: Box = - Box::new(FilterOperator::new(filter_predicate, column_names)); + for col in &input_schema.columns { + projection_exprs.push(LogicalExpr::Column(Column { + name: col.name.clone(), + table: None, + })); + dbsp_exprs.push(DbspExpr::Column(col.name.clone())); + } - // Create filter node - let node_id = self.circuit.add_node( - DbspOperator::Filter { predicate: dbsp_predicate }, - vec![input_id], - executable, - ); - Ok(node_id) + // Now add the expression as a computed column + let temp_column_name = "__temp_filter_expr"; + let computed_expr = Self::extract_expression_from_predicate(&filter.predicate)?; + projection_exprs.push(computed_expr.clone()); + + // Compile the projection expressions + let mut compiled_exprs = Vec::new(); + let mut aliases = Vec::new(); + let mut output_names = Vec::new(); + for (i, expr) in projection_exprs.iter().enumerate() { + let (compiled, _alias) = Self::compile_expression(expr, input_schema)?; + compiled_exprs.push(compiled); + if i < input_schema.columns.len() { + aliases.push(None); + output_names.push(input_schema.columns[i].name.clone()); + } else { + aliases.push(Some(temp_column_name.to_string())); + output_names.push(temp_column_name.to_string()); + } + } + + // Get input column names for ProjectOperator + let input_column_names: Vec = input_schema.columns.iter() + .map(|col| col.name.clone()) + .collect(); + + // Create projection operator + let proj_executable: Box = + Box::new(ProjectOperator::from_compiled( + compiled_exprs.clone(), + aliases.clone(), + input_column_names.clone(), + output_names.clone() + )?); + + // Create updated schema for the projection output + let mut proj_schema_columns = input_schema.columns.clone(); + proj_schema_columns.push(ColumnInfo { + name: temp_column_name.to_string(), + table: None, + database: None, + table_alias: None, + ty: Type::Integer, // Computed expressions default to Integer + }); + let proj_schema = SchemaRef::new(LogicalSchema { + columns: proj_schema_columns, + }); + + // Add projection node + let proj_id = self.circuit.add_node( + DbspOperator::Projection { + exprs: dbsp_exprs.clone(), + schema: proj_schema.clone(), + }, + vec![input_id], + proj_executable, + ); + + // Now create a filter that replaces the complex expression with the temp column + // but keeps all other conditions intact + let replaced_predicate = Self::replace_complex_with_temp(&filter.predicate, temp_column_name)?; + let filter_predicate = Self::compile_filter_predicate(&replaced_predicate, &proj_schema)?; + + let filter_executable: Box = + Box::new(FilterOperator::new(filter_predicate)); + + // Create filter node + let filter_id = self.circuit.add_node( + DbspOperator::Filter { predicate: Self::compile_expr(&replaced_predicate)? 
}, + vec![proj_id], + filter_executable, + ); + + // Finally, project again to remove the temporary column + let mut final_exprs = Vec::new(); + let mut final_aliases = Vec::new(); + let mut final_names = Vec::new(); + let mut final_dbsp_exprs = Vec::new(); + + for (i, column) in input_schema.columns.iter().enumerate() { + let col_name = &column.name; + final_exprs.push(compiled_exprs[i].clone()); + final_aliases.push(None); + final_names.push(col_name.clone()); + final_dbsp_exprs.push(DbspExpr::Column(col_name.clone())); + } + + // Input names for the final projection include the temp column + let filter_output_names = output_names.clone(); + + let final_proj_executable: Box = + Box::new(ProjectOperator::from_compiled( + final_exprs, + final_aliases, + filter_output_names, + final_names.clone() + )?); + + let final_id = self.circuit.add_node( + DbspOperator::Projection { + exprs: final_dbsp_exprs, + schema: input_schema.clone(), // Back to original schema + }, + vec![filter_id], + final_proj_executable, + ); + + Ok(final_id) + } else { + // Simple filter - use existing implementation + // Convert predicate to DBSP expression + let dbsp_predicate = Self::compile_expr(&filter.predicate)?; + + // Convert to FilterPredicate + let filter_predicate = Self::compile_filter_predicate(&filter.predicate, input_schema)?; + + // Create executable operator + let executable: Box = + Box::new(FilterOperator::new(filter_predicate)); + + // Create filter node + let node_id = self.circuit.add_node( + DbspOperator::Filter { predicate: dbsp_predicate }, + vec![input_id], + executable, + ); + Ok(node_id) + } } LogicalPlan::Aggregate(agg) => { // Compile the input first @@ -913,16 +1053,21 @@ impl DbspCompiler { // Get input column names let input_schema = agg.input.schema(); let input_column_names: Vec = input_schema.columns.iter() - .map(|(name, _)| name.clone()) + .map(|col| col.name.clone()) .collect(); - // Compile group by expressions to column names - let mut group_by_columns = Vec::new(); + // Compile group by expressions to column indices + let mut group_by_indices = Vec::new(); let mut dbsp_group_exprs = Vec::new(); for expr in &agg.group_expr { // For now, only support simple column references in GROUP BY if let LogicalExpr::Column(col) = expr { - group_by_columns.push(col.name.clone()); + // Find the column index in the input schema using qualified lookup + let (col_idx, _) = input_schema.find_column(&col.name, col.table.as_deref()) + .ok_or_else(|| LimboError::ParseError( + format!("GROUP BY column '{}' not found in input", col.name) + ))?; + group_by_indices.push(col_idx); dbsp_group_exprs.push(DbspExpr::Column(col.name.clone())); } else { return Err(LimboError::ParseError( @@ -936,7 +1081,7 @@ impl DbspCompiler { for expr in &agg.aggr_expr { if let LogicalExpr::AggregateFunction { fun, args, .. 
} = expr { use crate::function::AggFunc; - use crate::incremental::operator::AggregateFunction; + use crate::incremental::aggregate_operator::AggregateFunction; match fun { AggFunc::Count | AggFunc::Count0 => { @@ -946,9 +1091,13 @@ impl DbspCompiler { if args.is_empty() { return Err(LimboError::ParseError("SUM requires an argument".to_string())); } - // Extract column name from the argument + // Extract column index from the argument if let LogicalExpr::Column(col) = &args[0] { - aggregate_functions.push(AggregateFunction::Sum(col.name.clone())); + let (col_idx, _) = input_schema.find_column(&col.name, col.table.as_deref()) + .ok_or_else(|| LimboError::ParseError( + format!("SUM column '{}' not found in input", col.name) + ))?; + aggregate_functions.push(AggregateFunction::Sum(col_idx)); } else { return Err(LimboError::ParseError( "Only column references are supported in aggregate functions for incremental views".to_string() @@ -960,7 +1109,11 @@ impl DbspCompiler { return Err(LimboError::ParseError("AVG requires an argument".to_string())); } if let LogicalExpr::Column(col) = &args[0] { - aggregate_functions.push(AggregateFunction::Avg(col.name.clone())); + let (col_idx, _) = input_schema.find_column(&col.name, col.table.as_deref()) + .ok_or_else(|| LimboError::ParseError( + format!("AVG column '{}' not found in input", col.name) + ))?; + aggregate_functions.push(AggregateFunction::Avg(col_idx)); } else { return Err(LimboError::ParseError( "Only column references are supported in aggregate functions for incremental views".to_string() @@ -972,7 +1125,11 @@ impl DbspCompiler { return Err(LimboError::ParseError("MIN requires an argument".to_string())); } if let LogicalExpr::Column(col) = &args[0] { - aggregate_functions.push(AggregateFunction::Min(col.name.clone())); + let (col_idx, _) = input_schema.find_column(&col.name, col.table.as_deref()) + .ok_or_else(|| LimboError::ParseError( + format!("MIN column '{}' not found in input", col.name) + ))?; + aggregate_functions.push(AggregateFunction::Min(col_idx)); } else { return Err(LimboError::ParseError( "Only column references are supported in MIN for incremental views".to_string() @@ -984,7 +1141,11 @@ impl DbspCompiler { return Err(LimboError::ParseError("MAX requires an argument".to_string())); } if let LogicalExpr::Column(col) = &args[0] { - aggregate_functions.push(AggregateFunction::Max(col.name.clone())); + let (col_idx, _) = input_schema.find_column(&col.name, col.table.as_deref()) + .ok_or_else(|| LimboError::ParseError( + format!("MAX column '{}' not found in input", col.name) + ))?; + aggregate_functions.push(AggregateFunction::Max(col_idx)); } else { return Err(LimboError::ParseError( "Only column references are supported in MAX for incremental views".to_string() @@ -1006,10 +1167,10 @@ impl DbspCompiler { let operator_id = self.circuit.next_id; - use crate::incremental::operator::AggregateOperator; + use crate::incremental::aggregate_operator::AggregateOperator; let executable: Box = Box::new(AggregateOperator::new( operator_id, - group_by_columns.clone(), + group_by_indices.clone(), aggregate_functions.clone(), input_column_names.clone(), )); @@ -1026,6 +1187,90 @@ impl DbspCompiler { Ok(result_node_id) } + LogicalPlan::Join(join) => { + // Compile left and right inputs + let left_id = self.compile_plan(&join.left)?; + let right_id = self.compile_plan(&join.right)?; + + // Get schemas from inputs + let left_schema = join.left.schema(); + let right_schema = join.right.schema(); + + // Get column names from left and right + let 
left_columns: Vec = left_schema.columns.iter() + .map(|col| col.name.clone()) + .collect(); + let right_columns: Vec = right_schema.columns.iter() + .map(|col| col.name.clone()) + .collect(); + + // Extract join key indices from join conditions + // For now, we only support equijoin conditions + let mut left_key_indices = Vec::new(); + let mut right_key_indices = Vec::new(); + let mut dbsp_on_exprs = Vec::new(); + + for (left_expr, right_expr) in &join.on { + // Extract column indices from join expressions + // We expect simple column references in join conditions + if let (LogicalExpr::Column(left_col), LogicalExpr::Column(right_col)) = (left_expr, right_expr) { + // Find indices in respective schemas using qualified lookup + let (left_idx, _) = left_schema.find_column(&left_col.name, left_col.table.as_deref()) + .ok_or_else(|| LimboError::ParseError( + format!("Join column '{}' not found in left input", left_col.name) + ))?; + let (right_idx, _) = right_schema.find_column(&right_col.name, right_col.table.as_deref()) + .ok_or_else(|| LimboError::ParseError( + format!("Join column '{}' not found in right input", right_col.name) + ))?; + + left_key_indices.push(left_idx); + right_key_indices.push(right_idx); + + // Convert to DBSP expressions + dbsp_on_exprs.push(( + DbspExpr::Column(left_col.name.clone()), + DbspExpr::Column(right_col.name.clone()) + )); + } else { + return Err(LimboError::ParseError( + "Only simple column references are supported in join conditions for incremental views".to_string() + )); + } + } + + // Convert logical join type to operator join type + let operator_join_type = match join.join_type { + LogicalJoinType::Inner => JoinType::Inner, + LogicalJoinType::Left => JoinType::Left, + LogicalJoinType::Right => JoinType::Right, + LogicalJoinType::Full => JoinType::Full, + LogicalJoinType::Cross => JoinType::Cross, + }; + + // Create JoinOperator + let operator_id = self.circuit.next_id; + let executable: Box = Box::new(JoinOperator::new( + operator_id, + operator_join_type.clone(), + left_key_indices, + right_key_indices, + left_columns, + right_columns, + )?); + + // Create join node + let node_id = self.circuit.add_node( + DbspOperator::Join { + join_type: operator_join_type, + on_exprs: dbsp_on_exprs, + schema: join.schema.clone(), + }, + vec![left_id, right_id], + executable, + ); + Ok(node_id) + } LogicalPlan::TableScan(scan) => { // Create input node with InputOperator for uniform handling let executable: Box = @@ -1042,7 +1287,7 @@ impl DbspCompiler { Ok(node_id) } _ => Err(LimboError::ParseError( - format!("Unsupported operator in DBSP compiler: only Filter, Projection and Aggregate are supported, got: {:?}", + format!("Unsupported operator in DBSP compiler: only Filter, Projection, Join and Aggregate are supported, got: {:?}", match plan { LogicalPlan::Sort(_) => "Sort", LogicalPlan::Limit(_) => "Limit", @@ -1095,17 +1340,24 @@ impl DbspCompiler { /// Compile a logical expression to a CompiledExpression and optional alias fn compile_expression( expr: &LogicalExpr, - input_column_names: &[String], + input_schema: &LogicalSchema, ) -> Result<(CompiledExpression, Option)> { // Check for alias first if let LogicalExpr::Alias { expr, alias } = expr { // For aliases, compile the underlying expression and return with alias - let (compiled, _) = Self::compile_expression(expr, input_column_names)?; + let (compiled, _) = Self::compile_expression(expr, input_schema)?; return Ok((compiled, Some(alias.clone()))); } - // Convert LogicalExpr to AST Expr - let ast_expr = 
Self::logical_to_ast_expr(expr)?; + // Convert LogicalExpr to AST Expr with proper column resolution + let ast_expr = Self::logical_to_ast_expr_with_schema(expr, input_schema)?; + + // Extract column names from schema for CompiledExpression::compile + let input_column_names: Vec = input_schema + .columns + .iter() + .map(|col| col.name.clone()) + .collect(); // For all expressions (simple or complex), use CompiledExpression::compile // This handles both trivial cases and complex VDBE compilation @@ -1129,7 +1381,7 @@ impl DbspCompiler { // Compile the expression using the existing CompiledExpression::compile let compiled = CompiledExpression::compile( &ast_expr, - input_column_names, + &input_column_names, &schema, &temp_syms, internal_conn, @@ -1138,25 +1390,45 @@ impl DbspCompiler { Ok((compiled, None)) } - /// Convert LogicalExpr to AST Expr - fn logical_to_ast_expr(expr: &LogicalExpr) -> Result { + /// Convert LogicalExpr to AST Expr with qualified column resolution + fn logical_to_ast_expr_with_schema( + expr: &LogicalExpr, + schema: &LogicalSchema, + ) -> Result { use turso_parser::ast; match expr { - LogicalExpr::Column(col) => Ok(ast::Expr::Id(ast::Name::Ident(col.name.clone()))), + LogicalExpr::Column(col) => { + // Find the column index using qualified lookup + let (idx, _) = schema + .find_column(&col.name, col.table.as_deref()) + .ok_or_else(|| { + LimboError::ParseError(format!( + "Column '{}' with table {:?} not found in schema", + col.name, col.table + )) + })?; + // Return a Register expression with the correct index + Ok(ast::Expr::Register(idx)) + } LogicalExpr::Literal(val) => { let lit = match val { Value::Integer(i) => ast::Literal::Numeric(i.to_string()), Value::Float(f) => ast::Literal::Numeric(f.to_string()), - Value::Text(t) => ast::Literal::String(t.to_string()), + Value::Text(t) => { + // Add quotes for string literals as translate_expr expects them + // Also escape any single quotes in the string + let escaped = t.to_string().replace('\'', "''"); + ast::Literal::String(format!("'{escaped}'")) + } Value::Blob(b) => ast::Literal::Blob(format!("{b:?}")), Value::Null => ast::Literal::Null, }; Ok(ast::Expr::Literal(lit)) } LogicalExpr::BinaryExpr { left, op, right } => { - let left_expr = Self::logical_to_ast_expr(left)?; - let right_expr = Self::logical_to_ast_expr(right)?; + let left_expr = Self::logical_to_ast_expr_with_schema(left, schema)?; + let right_expr = Self::logical_to_ast_expr_with_schema(right, schema)?; Ok(ast::Expr::Binary( Box::new(left_expr), *op, @@ -1164,7 +1436,10 @@ impl DbspCompiler { )) } LogicalExpr::ScalarFunction { fun, args } => { - let ast_args: Result> = args.iter().map(Self::logical_to_ast_expr).collect(); + let ast_args: Result> = args + .iter() + .map(|arg| Self::logical_to_ast_expr_with_schema(arg, schema)) + .collect(); let ast_args: Vec> = ast_args?.into_iter().map(Box::new).collect(); Ok(ast::Expr::FunctionCall { name: ast::Name::Ident(fun.clone()), @@ -1179,7 +1454,7 @@ impl DbspCompiler { } LogicalExpr::Alias { expr, .. 
} => { // For conversion to AST, ignore the alias and convert the inner expression - Self::logical_to_ast_expr(expr) + Self::logical_to_ast_expr_with_schema(expr, schema) } LogicalExpr::AggregateFunction { fun, @@ -1187,7 +1462,10 @@ impl DbspCompiler { distinct, } => { // Convert aggregate function to AST - let ast_args: Result> = args.iter().map(Self::logical_to_ast_expr).collect(); + let ast_args: Result> = args + .iter() + .map(|arg| Self::logical_to_ast_expr_with_schema(arg, schema)) + .collect(); let ast_args: Vec> = ast_args?.into_iter().map(Box::new).collect(); // Get the function name based on the aggregate type @@ -1225,43 +1503,235 @@ impl DbspCompiler { } } + /// Check if a predicate contains expressions that need projection + fn predicate_needs_projection(expr: &LogicalExpr) -> bool { + match expr { + LogicalExpr::BinaryExpr { left, op, right } => { + match (left.as_ref(), right.as_ref()) { + // Simple column to literal - OK + (LogicalExpr::Column(_), LogicalExpr::Literal(_)) => false, + // Simple column to column - OK + (LogicalExpr::Column(_), LogicalExpr::Column(_)) => false, + // AND/OR of simple expressions - check recursively + _ if matches!(op, BinaryOperator::And | BinaryOperator::Or) => { + Self::predicate_needs_projection(left) + || Self::predicate_needs_projection(right) + } + // Any other pattern needs projection + _ => true, + } + } + _ => false, + } + } + + /// Extract the expression part from a predicate that needs to be computed + fn extract_expression_from_predicate(expr: &LogicalExpr) -> Result { + match expr { + LogicalExpr::BinaryExpr { left, op, right } => { + // Handle AND/OR - recursively find the complex expression + if matches!(op, BinaryOperator::And | BinaryOperator::Or) { + // Check left side first + if Self::predicate_needs_projection(left) { + return Self::extract_expression_from_predicate(left); + } + // Then check right side + if Self::predicate_needs_projection(right) { + return Self::extract_expression_from_predicate(right); + } + // Neither side needs projection (shouldn't happen if predicate_needs_projection was true) + return Ok(expr.clone()); + } + + // For expressions like (age * 2) > 30, we want to extract (age * 2) + if matches!( + op, + BinaryOperator::Greater + | BinaryOperator::GreaterEquals + | BinaryOperator::Less + | BinaryOperator::LessEquals + | BinaryOperator::Equals + | BinaryOperator::NotEquals + ) { + // Return the left side if it's not a simple column + if !matches!(left.as_ref(), LogicalExpr::Column(_)) { + Ok((**left).clone()) + } else { + // Must be the whole expression then + Ok(expr.clone()) + } + } else { + Ok(expr.clone()) + } + } + _ => Ok(expr.clone()), + } + } + + /// Replace complex expressions in the predicate with references to the temp column + fn replace_complex_with_temp( + expr: &LogicalExpr, + temp_column_name: &str, + ) -> Result { + match expr { + LogicalExpr::BinaryExpr { left, op, right } => { + // Handle AND/OR - recursively process both sides + if matches!(op, BinaryOperator::And | BinaryOperator::Or) { + let new_left = Self::replace_complex_with_temp(left, temp_column_name)?; + let new_right = Self::replace_complex_with_temp(right, temp_column_name)?; + return Ok(LogicalExpr::BinaryExpr { + left: Box::new(new_left), + op: *op, + right: Box::new(new_right), + }); + } + + // Check if this is a complex comparison that needs replacement + if Self::predicate_needs_projection(expr) { + // Replace the complex expression (left side) with the temp column + return Ok(LogicalExpr::BinaryExpr { + left: 
Box::new(LogicalExpr::Column(Column { + name: temp_column_name.to_string(), + table: None, + })), + op: *op, + right: right.clone(), + }); + } + + // Simple comparison - keep as is + Ok(expr.clone()) + } + _ => Ok(expr.clone()), + } + } + /// Compile a logical expression to a FilterPredicate for execution - fn compile_filter_predicate(expr: &LogicalExpr) -> Result { + fn compile_filter_predicate( + expr: &LogicalExpr, + schema: &LogicalSchema, + ) -> Result { match expr { LogicalExpr::BinaryExpr { left, op, right } => { // Extract column name and value for simple predicates - if let (LogicalExpr::Column(col), LogicalExpr::Literal(val)) = + // First check for column-to-column comparisons + if let (LogicalExpr::Column(left_col), LogicalExpr::Column(right_col)) = (left.as_ref(), right.as_ref()) { + // Resolve both column names to indices + let left_idx = schema + .columns + .iter() + .position(|c| c.name == left_col.name) + .ok_or_else(|| { + crate::LimboError::ParseError(format!( + "Column '{}' not found in schema for filter", + left_col.name + )) + })?; + + let right_idx = schema + .columns + .iter() + .position(|c| c.name == right_col.name) + .ok_or_else(|| { + crate::LimboError::ParseError(format!( + "Column '{}' not found in schema for filter", + right_col.name + )) + })?; + + match op { + BinaryOperator::Equals => Ok(FilterPredicate::ColumnEquals { + left_idx, + right_idx, + }), + BinaryOperator::NotEquals => Ok(FilterPredicate::ColumnNotEquals { + left_idx, + right_idx, + }), + BinaryOperator::Greater => Ok(FilterPredicate::ColumnGreaterThan { + left_idx, + right_idx, + }), + BinaryOperator::GreaterEquals => { + Ok(FilterPredicate::ColumnGreaterThanOrEqual { + left_idx, + right_idx, + }) + } + BinaryOperator::Less => Ok(FilterPredicate::ColumnLessThan { + left_idx, + right_idx, + }), + BinaryOperator::LessEquals => Ok(FilterPredicate::ColumnLessThanOrEqual { + left_idx, + right_idx, + }), + BinaryOperator::And | BinaryOperator::Or => { + // Handle logical operators recursively + let left_pred = Self::compile_filter_predicate(left, schema)?; + let right_pred = Self::compile_filter_predicate(right, schema)?; + match op { + BinaryOperator::And => Ok(FilterPredicate::And( + Box::new(left_pred), + Box::new(right_pred), + )), + BinaryOperator::Or => Ok(FilterPredicate::Or( + Box::new(left_pred), + Box::new(right_pred), + )), + _ => unreachable!(), + } + } + _ => Err(LimboError::ParseError(format!( + "Unsupported operator in filter: {op:?}" + ))), + } + } else if let (LogicalExpr::Column(col), LogicalExpr::Literal(val)) = + (left.as_ref(), right.as_ref()) + { + // Column-to-literal comparisons + let column_idx = schema + .columns + .iter() + .position(|c| c.name == col.name) + .ok_or_else(|| { + crate::LimboError::ParseError(format!( + "Column '{}' not found in schema for filter", + col.name + )) + })?; + match op { BinaryOperator::Equals => Ok(FilterPredicate::Equals { - column: col.name.clone(), + column_idx, value: val.clone(), }), BinaryOperator::NotEquals => Ok(FilterPredicate::NotEquals { - column: col.name.clone(), + column_idx, value: val.clone(), }), BinaryOperator::Greater => Ok(FilterPredicate::GreaterThan { - column: col.name.clone(), + column_idx, value: val.clone(), }), BinaryOperator::GreaterEquals => Ok(FilterPredicate::GreaterThanOrEqual { - column: col.name.clone(), + column_idx, value: val.clone(), }), BinaryOperator::Less => Ok(FilterPredicate::LessThan { - column: col.name.clone(), + column_idx, value: val.clone(), }), BinaryOperator::LessEquals => 
Ok(FilterPredicate::LessThanOrEqual { - column: col.name.clone(), + column_idx, value: val.clone(), }), BinaryOperator::And => { // Handle AND of two predicates - let left_pred = Self::compile_filter_predicate(left)?; - let right_pred = Self::compile_filter_predicate(right)?; + let left_pred = Self::compile_filter_predicate(left, schema)?; + let right_pred = Self::compile_filter_predicate(right, schema)?; Ok(FilterPredicate::And( Box::new(left_pred), Box::new(right_pred), @@ -1269,8 +1739,8 @@ impl DbspCompiler { } BinaryOperator::Or => { // Handle OR of two predicates - let left_pred = Self::compile_filter_predicate(left)?; - let right_pred = Self::compile_filter_predicate(right)?; + let left_pred = Self::compile_filter_predicate(left, schema)?; + let right_pred = Self::compile_filter_predicate(right, schema)?; Ok(FilterPredicate::Or( Box::new(left_pred), Box::new(right_pred), @@ -1282,8 +1752,8 @@ impl DbspCompiler { } } else if matches!(op, BinaryOperator::And | BinaryOperator::Or) { // Handle logical operators - let left_pred = Self::compile_filter_predicate(left)?; - let right_pred = Self::compile_filter_predicate(right)?; + let left_pred = Self::compile_filter_predicate(left, schema)?; + let right_pred = Self::compile_filter_predicate(right, schema)?; match op { BinaryOperator::And => Ok(FilterPredicate::And( Box::new(left_pred), @@ -1297,7 +1767,7 @@ impl DbspCompiler { } } else { Err(LimboError::ParseError( - "Filter predicate must be column op value".to_string(), + "Filter predicate must be column op value or column op column".to_string(), )) } } @@ -1315,8 +1785,7 @@ mod tests { use crate::incremental::operator::{FilterOperator, FilterPredicate}; use crate::schema::{BTreeTable, Column as SchemaColumn, Schema, Type}; use crate::storage::pager::CreateBTreeFlags; - use crate::translate::logical::LogicalPlanBuilder; - use crate::translate::logical::LogicalSchema; + use crate::translate::logical::{ColumnInfo, LogicalPlanBuilder, LogicalSchema}; use crate::util::IOExt; use crate::{Database, MemoryIO, Pager, IO}; use std::sync::Arc; @@ -1374,6 +1843,270 @@ mod tests { unique_sets: vec![], }; schema.add_btree_table(Arc::new(users_table)); + + // Add products table for join tests + let products_table = BTreeTable { + name: "products".to_string(), + root_page: 3, + primary_key_columns: vec![( + "product_id".to_string(), + turso_parser::ast::SortOrder::Asc, + )], + columns: vec![ + SchemaColumn { + name: Some("product_id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: true, + is_rowid_alias: true, + notnull: true, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("product_name".to_string()), + ty: Type::Text, + ty_str: "TEXT".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("price".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + ], + has_rowid: true, + is_strict: false, + unique_sets: vec![], + }; + schema.add_btree_table(Arc::new(products_table)); + + // Add orders table for join tests + let orders_table = BTreeTable { + name: "orders".to_string(), + root_page: 4, + primary_key_columns: vec![( + "order_id".to_string(), + turso_parser::ast::SortOrder::Asc, + )], + columns: vec![ + SchemaColumn { + name: 
Some("order_id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: true, + is_rowid_alias: true, + notnull: true, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("user_id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("product_id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("quantity".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + ], + has_rowid: true, + is_strict: false, + unique_sets: vec![], + }; + schema.add_btree_table(Arc::new(orders_table)); + + // Add customers table with id and name for testing column ambiguity + let customers_table = BTreeTable { + name: "customers".to_string(), + root_page: 6, + primary_key_columns: vec![("id".to_string(), turso_parser::ast::SortOrder::Asc)], + columns: vec![ + SchemaColumn { + name: Some("id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: true, + is_rowid_alias: true, + notnull: true, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("name".to_string()), + ty: Type::Text, + ty_str: "TEXT".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + ], + has_rowid: true, + is_strict: false, + unique_sets: vec![], + }; + schema.add_btree_table(Arc::new(customers_table)); + + // Add purchases table (junction table for three-way join) + let purchases_table = BTreeTable { + name: "purchases".to_string(), + root_page: 7, + primary_key_columns: vec![("id".to_string(), turso_parser::ast::SortOrder::Asc)], + columns: vec![ + SchemaColumn { + name: Some("id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: true, + is_rowid_alias: true, + notnull: true, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("customer_id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("vendor_id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("quantity".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + ], + has_rowid: true, + is_strict: false, + unique_sets: vec![], + }; + schema.add_btree_table(Arc::new(purchases_table)); + + // Add vendors table with id, name, and price (ambiguous columns with customers) + let vendors_table = BTreeTable { + name: "vendors".to_string(), + root_page: 8, + primary_key_columns: vec![("id".to_string(), 
turso_parser::ast::SortOrder::Asc)], + columns: vec![ + SchemaColumn { + name: Some("id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: true, + is_rowid_alias: true, + notnull: true, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("name".to_string()), + ty: Type::Text, + ty_str: "TEXT".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("price".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + ], + has_rowid: true, + is_strict: false, + unique_sets: vec![], + }; + schema.add_btree_table(Arc::new(vendors_table)); + let sales_table = BTreeTable { name: "sales".to_string(), root_page: 2, @@ -3342,8 +4075,20 @@ mod tests { // Create a simple filter node let schema = Arc::new(LogicalSchema::new(vec![ - ("id".to_string(), Type::Integer), - ("value".to_string(), Type::Integer), + ColumnInfo { + name: "id".to_string(), + ty: Type::Integer, + database: None, + table: None, + table_alias: None, + }, + ColumnInfo { + name: "value".to_string(), + ty: Type::Integer, + database: None, + table: None, + table_alias: None, + }, ])); // First create an input node with InputOperator @@ -3356,13 +4101,10 @@ mod tests { Box::new(InputOperator::new("test".to_string())), ); - let filter_op = FilterOperator::new( - FilterPredicate::GreaterThan { - column: "value".to_string(), - value: Value::Integer(10), - }, - vec!["id".to_string(), "value".to_string()], - ); + let filter_op = FilterOperator::new(FilterPredicate::GreaterThan { + column_idx: 1, // "value" is at index 1 + value: Value::Integer(10), + }); // Create the filter predicate using DbspExpr let predicate = DbspExpr::BinaryExpr { @@ -3486,4 +4228,1140 @@ mod tests { "Row should still exist with multiplicity 1" ); } + + #[test] + fn test_join_with_aggregation() { + // Test join followed by aggregation - verifying actual output + let (mut circuit, pager) = compile_sql!( + "SELECT u.name, SUM(o.quantity) as total_quantity + FROM users u + JOIN orders o ON u.id = o.user_id + GROUP BY u.name" + ); + + // Create test data for users + let mut users_delta = Delta::new(); + users_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".into()), + Value::Integer(30), + ], + ); + users_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Bob".into()), + Value::Integer(25), + ], + ); + + // Create test data for orders (order_id, user_id, product_id, quantity) + let mut orders_delta = Delta::new(); + orders_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Integer(1), + Value::Integer(101), + Value::Integer(5), + ], + ); // Alice: 5 + orders_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Integer(1), + Value::Integer(102), + Value::Integer(3), + ], + ); // Alice: 3 + orders_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Integer(2), + Value::Integer(101), + Value::Integer(7), + ], + ); // Bob: 7 + orders_delta.insert( + 4, + vec![ + Value::Integer(4), + Value::Integer(1), + Value::Integer(103), + Value::Integer(2), + ], + ); // Alice: 2 + + let inputs = HashMap::from([ + ("users".to_string(), users_delta), + ("orders".to_string(), orders_delta), + ]); + + let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap(); + + // 
Should have 2 results: Alice with total 10, Bob with total 7 + assert_eq!( + result.len(), + 2, + "Should have aggregated results for Alice and Bob" + ); + + // Check the results + let mut results_map: HashMap = HashMap::new(); + for (row, weight) in result.changes { + assert_eq!(weight, 1); + assert_eq!(row.values.len(), 2); // name and total_quantity + + if let (Value::Text(name), Value::Integer(total)) = (&row.values[0], &row.values[1]) { + results_map.insert(name.to_string(), *total); + } else { + panic!("Unexpected value types in result"); + } + } + + assert_eq!( + results_map.get("Alice"), + Some(&10), + "Alice should have total quantity 10" + ); + assert_eq!( + results_map.get("Bob"), + Some(&7), + "Bob should have total quantity 7" + ); + } + + #[test] + fn test_join_aggregate_with_filter() { + // Test complex query with join, filter, and aggregation - verifying output + let (mut circuit, pager) = compile_sql!( + "SELECT u.name, SUM(o.quantity) as total + FROM users u + JOIN orders o ON u.id = o.user_id + WHERE u.age > 18 + GROUP BY u.name" + ); + + // Create test data for users + let mut users_delta = Delta::new(); + users_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".into()), + Value::Integer(30), + ], + ); // age > 18 + users_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Bob".into()), + Value::Integer(17), + ], + ); // age <= 18 + users_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Text("Charlie".into()), + Value::Integer(25), + ], + ); // age > 18 + + // Create test data for orders (order_id, user_id, product_id, quantity) + let mut orders_delta = Delta::new(); + orders_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Integer(1), + Value::Integer(101), + Value::Integer(5), + ], + ); // Alice: 5 + orders_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Integer(2), + Value::Integer(102), + Value::Integer(10), + ], + ); // Bob: 10 (should be filtered) + orders_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Integer(3), + Value::Integer(101), + Value::Integer(7), + ], + ); // Charlie: 7 + orders_delta.insert( + 4, + vec![ + Value::Integer(4), + Value::Integer(1), + Value::Integer(103), + Value::Integer(3), + ], + ); // Alice: 3 + + let inputs = HashMap::from([ + ("users".to_string(), users_delta), + ("orders".to_string(), orders_delta), + ]); + + let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap(); + + // Should only have results for Alice and Charlie (Bob filtered out due to age <= 18) + assert_eq!( + result.len(), + 2, + "Should only have results for users with age > 18" + ); + + // Check the results + let mut results_map: HashMap = HashMap::new(); + for (row, weight) in result.changes { + assert_eq!(weight, 1); + assert_eq!(row.values.len(), 2); // name and total + + if let (Value::Text(name), Value::Integer(total)) = (&row.values[0], &row.values[1]) { + results_map.insert(name.to_string(), *total); + } + } + + assert_eq!( + results_map.get("Alice"), + Some(&8), + "Alice should have total 8" + ); + assert_eq!( + results_map.get("Charlie"), + Some(&7), + "Charlie should have total 7" + ); + assert_eq!(results_map.get("Bob"), None, "Bob should be filtered out"); + } + + #[test] + fn test_three_way_join_execution() { + // Test executing a 3-way join with aggregation + let (mut circuit, pager) = compile_sql!( + "SELECT u.name, p.product_name, SUM(o.quantity) as total + FROM users u + JOIN orders o ON u.id = o.user_id + JOIN products p ON o.product_id = p.product_id + GROUP BY 
u.name, p.product_name" + ); + + // Create test data for users + let mut users_delta = Delta::new(); + users_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".into()), + Value::Integer(25), + ], + ); + users_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Bob".into()), + Value::Integer(30), + ], + ); + + // Create test data for products + let mut products_delta = Delta::new(); + products_delta.insert( + 100, + vec![ + Value::Integer(100), + Value::Text("Widget".into()), + Value::Integer(50), + ], + ); + products_delta.insert( + 101, + vec![ + Value::Integer(101), + Value::Text("Gadget".into()), + Value::Integer(75), + ], + ); + products_delta.insert( + 102, + vec![ + Value::Integer(102), + Value::Text("Doohickey".into()), + Value::Integer(25), + ], + ); + + // Create test data for orders joining users and products + let mut orders_delta = Delta::new(); + // Alice orders 5 Widgets + orders_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Integer(1), + Value::Integer(100), + Value::Integer(5), + ], + ); + // Alice orders 3 Gadgets + orders_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Integer(1), + Value::Integer(101), + Value::Integer(3), + ], + ); + // Bob orders 7 Widgets + orders_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Integer(2), + Value::Integer(100), + Value::Integer(7), + ], + ); + // Bob orders 2 Doohickeys + orders_delta.insert( + 4, + vec![ + Value::Integer(4), + Value::Integer(2), + Value::Integer(102), + Value::Integer(2), + ], + ); + // Alice orders 4 more Widgets + orders_delta.insert( + 5, + vec![ + Value::Integer(5), + Value::Integer(1), + Value::Integer(100), + Value::Integer(4), + ], + ); + + let mut inputs = HashMap::new(); + inputs.insert("users".to_string(), users_delta); + inputs.insert("products".to_string(), products_delta); + inputs.insert("orders".to_string(), orders_delta); + + // Execute the 3-way join with aggregation + let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap(); + + // We should get aggregated results for each user-product combination + // Expected results: + // - Alice, Widget: 9 (5 + 4) + // - Alice, Gadget: 3 + // - Bob, Widget: 7 + // - Bob, Doohickey: 2 + assert_eq!(result.len(), 4, "Should have 4 aggregated results"); + + // Verify aggregation results + let mut found_results = std::collections::HashSet::new(); + for (row, weight) in result.changes.iter() { + assert_eq!(*weight, 1); + // Row should have name, product_name, and sum columns + assert_eq!(row.values.len(), 3); + + if let (Value::Text(name), Value::Text(product), Value::Integer(total)) = + (&row.values[0], &row.values[1], &row.values[2]) + { + let key = format!("{}-{}", name.as_ref(), product.as_ref()); + found_results.insert(key.clone()); + + match key.as_str() { + "Alice-Widget" => { + assert_eq!(*total, 9, "Alice should have ordered 9 Widgets total") + } + "Alice-Gadget" => assert_eq!(*total, 3, "Alice should have ordered 3 Gadgets"), + "Bob-Widget" => assert_eq!(*total, 7, "Bob should have ordered 7 Widgets"), + "Bob-Doohickey" => { + assert_eq!(*total, 2, "Bob should have ordered 2 Doohickeys") + } + _ => panic!("Unexpected result: {key}"), + } + } else { + panic!("Unexpected value types in result"); + } + } + + // Ensure we found all expected combinations + assert!(found_results.contains("Alice-Widget")); + assert!(found_results.contains("Alice-Gadget")); + assert!(found_results.contains("Bob-Widget")); + assert!(found_results.contains("Bob-Doohickey")); + } + + #[test] + fn 
test_join_execution() { + let (mut circuit, pager) = compile_sql!( + "SELECT u.name, o.quantity FROM users u JOIN orders o ON u.id = o.user_id" + ); + + // Create test data for users + let mut users_delta = Delta::new(); + users_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".into()), + Value::Integer(25), + ], + ); + users_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Bob".into()), + Value::Integer(30), + ], + ); + + // Create test data for orders + let mut orders_delta = Delta::new(); + orders_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Integer(1), + Value::Integer(100), + Value::Integer(5), + ], + ); + orders_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Integer(1), + Value::Integer(101), + Value::Integer(3), + ], + ); + orders_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Integer(2), + Value::Integer(102), + Value::Integer(7), + ], + ); + + let mut inputs = HashMap::new(); + inputs.insert("users".to_string(), users_delta); + inputs.insert("orders".to_string(), orders_delta); + + // Execute the join + let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap(); + + // We should get 3 results (2 orders for Alice, 1 for Bob) + assert_eq!(result.len(), 3, "Should have 3 join results"); + + // Verify the join results contain the correct data + let results: Vec<_> = result.changes.iter().collect(); + + // Check that we have the expected joined rows + for (row, weight) in results { + assert_eq!(*weight, 1); // All weights should be 1 for insertions + // Row should have name and quantity columns + assert_eq!(row.values.len(), 2); + } + } + + #[test] + fn test_three_way_join_with_column_ambiguity() { + // Test three-way join with aggregation where multiple tables have columns with the same name + // Ensures that column references are correctly resolved to their respective tables + // Tables: customers(id, name), purchases(id, customer_id, vendor_id, quantity), vendors(id, name, price) + // Note: both customers and vendors have 'id' and 'name' columns which can cause ambiguity + + let sql = "SELECT c.name as customer_name, v.name as vendor_name, + SUM(p.quantity) as total_quantity, + SUM(p.quantity * v.price) as total_value + FROM customers c + JOIN purchases p ON c.id = p.customer_id + JOIN vendors v ON p.vendor_id = v.id + GROUP BY c.name, v.name"; + + let (mut circuit, pager) = compile_sql!(sql); + + // Create test data for customers (id, name) + let mut customers_delta = Delta::new(); + customers_delta.insert(1, vec![Value::Integer(1), Value::Text("Alice".into())]); + customers_delta.insert(2, vec![Value::Integer(2), Value::Text("Bob".into())]); + + // Create test data for vendors (id, name, price) + let mut vendors_delta = Delta::new(); + vendors_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Widget Co".into()), + Value::Integer(10), + ], + ); + vendors_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Gadget Inc".into()), + Value::Integer(20), + ], + ); + + // Create test data for purchases (id, customer_id, vendor_id, quantity) + let mut purchases_delta = Delta::new(); + // Alice purchases 5 units from Widget Co + purchases_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Integer(1), // customer_id: Alice + Value::Integer(1), // vendor_id: Widget Co + Value::Integer(5), + ], + ); + // Alice purchases 3 units from Gadget Inc + purchases_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Integer(1), // customer_id: Alice + Value::Integer(2), // vendor_id: 
Gadget Inc + Value::Integer(3), + ], + ); + // Bob purchases 2 units from Widget Co + purchases_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Integer(2), // customer_id: Bob + Value::Integer(1), // vendor_id: Widget Co + Value::Integer(2), + ], + ); + // Alice purchases 4 more units from Widget Co + purchases_delta.insert( + 4, + vec![ + Value::Integer(4), + Value::Integer(1), // customer_id: Alice + Value::Integer(1), // vendor_id: Widget Co + Value::Integer(4), + ], + ); + + let inputs = HashMap::from([ + ("customers".to_string(), customers_delta), + ("purchases".to_string(), purchases_delta), + ("vendors".to_string(), vendors_delta), + ]); + + let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap(); + + // Expected results: + // Alice|Gadget Inc|3|60 (3 units * 20 price = 60) + // Alice|Widget Co|9|90 (9 units * 10 price = 90) + // Bob|Widget Co|2|20 (2 units * 10 price = 20) + + assert_eq!(result.len(), 3, "Should have 3 aggregated results"); + + // Sort results for consistent testing + let mut results: Vec<_> = result.changes.into_iter().collect(); + results.sort_by(|a, b| { + let a_cust = &a.0.values[0]; + let a_vend = &a.0.values[1]; + let b_cust = &b.0.values[0]; + let b_vend = &b.0.values[1]; + (a_cust, a_vend).cmp(&(b_cust, b_vend)) + }); + + // Verify Alice's Gadget Inc purchases + assert_eq!(results[0].0.values[0], Value::Text("Alice".into())); + assert_eq!(results[0].0.values[1], Value::Text("Gadget Inc".into())); + assert_eq!(results[0].0.values[2], Value::Integer(3)); // total_quantity + assert_eq!(results[0].0.values[3], Value::Integer(60)); // total_value + + // Verify Alice's Widget Co purchases + assert_eq!(results[1].0.values[0], Value::Text("Alice".into())); + assert_eq!(results[1].0.values[1], Value::Text("Widget Co".into())); + assert_eq!(results[1].0.values[2], Value::Integer(9)); // total_quantity + assert_eq!(results[1].0.values[3], Value::Integer(90)); // total_value + + // Verify Bob's Widget Co purchases + assert_eq!(results[2].0.values[0], Value::Text("Bob".into())); + assert_eq!(results[2].0.values[1], Value::Text("Widget Co".into())); + assert_eq!(results[2].0.values[2], Value::Integer(2)); // total_quantity + assert_eq!(results[2].0.values[3], Value::Integer(20)); // total_value + } + + #[test] + fn test_projection_with_function_and_ambiguous_columns() { + // Test projection with functions operating on potentially ambiguous columns + // Uses HEX() function on sum of columns from different tables with same names + // Tables: customers(id, name), vendors(id, name, price), purchases(id, customer_id, vendor_id, quantity) + // This test ensures column references are correctly resolved to their respective tables + + let sql = "SELECT HEX(c.id + v.id) as hex_sum, + UPPER(c.name) as customer_upper, + LOWER(v.name) as vendor_lower, + c.id * v.price as product_value + FROM customers c + JOIN vendors v ON c.id = v.id"; + + let (mut circuit, pager) = compile_sql!(sql); + + // Create test data for customers (id, name) + let mut customers_delta = Delta::new(); + customers_delta.insert(1, vec![Value::Integer(1), Value::Text("Alice".into())]); + customers_delta.insert(2, vec![Value::Integer(2), Value::Text("Bob".into())]); + customers_delta.insert(3, vec![Value::Integer(3), Value::Text("Charlie".into())]); + + // Create test data for vendors (id, name, price) + let mut vendors_delta = Delta::new(); + vendors_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Widget Co".into()), + Value::Integer(10), + ], + ); + 
vendors_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Gadget Inc".into()), + Value::Integer(20), + ], + ); + vendors_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Text("Tool Corp".into()), + Value::Integer(30), + ], + ); + + let inputs = HashMap::from([ + ("customers".to_string(), customers_delta), + ("vendors".to_string(), vendors_delta), + ]); + + let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap(); + + // Expected results: + // For customer 1 (Alice) + vendor 1: + // - HEX(1 + 1) = HEX(2) = "32" + // - UPPER("Alice") = "ALICE" + // - LOWER("Widget Co") = "widget co" + // - 1 * 10 = 10 + assert_eq!(result.len(), 3, "Should have 3 join results"); + + let mut results = result.changes.clone(); + results.sort_by_key(|(row, _)| { + // Sort by the product_value column for predictable ordering + match &row.values[3] { + Value::Integer(n) => *n, + _ => 0, + } + }); + + // First result: Alice + Widget Co + assert_eq!(results[0].0.values[0], Value::Text("32".into())); // HEX(2) + assert_eq!(results[0].0.values[1], Value::Text("ALICE".into())); + assert_eq!(results[0].0.values[2], Value::Text("widget co".into())); + assert_eq!(results[0].0.values[3], Value::Integer(10)); // 1 * 10 + + // Second result: Bob + Gadget Inc + assert_eq!(results[1].0.values[0], Value::Text("34".into())); // HEX(4) + assert_eq!(results[1].0.values[1], Value::Text("BOB".into())); + assert_eq!(results[1].0.values[2], Value::Text("gadget inc".into())); + assert_eq!(results[1].0.values[3], Value::Integer(40)); // 2 * 20 + + // Third result: Charlie + Tool Corp + assert_eq!(results[2].0.values[0], Value::Text("36".into())); // HEX(6) + assert_eq!(results[2].0.values[1], Value::Text("CHARLIE".into())); + assert_eq!(results[2].0.values[2], Value::Text("tool corp".into())); + assert_eq!(results[2].0.values[3], Value::Integer(90)); // 3 * 30 + } + + #[test] + fn test_projection_column_selection_after_join() { + // Test selecting specific columns after a join, especially with overlapping column names + // This ensures the projection correctly picks columns by their qualified references + + let sql = "SELECT c.id as customer_id, + c.name as customer_name, + o.order_id, + o.quantity, + p.product_name + FROM users c + JOIN orders o ON c.id = o.user_id + JOIN products p ON o.product_id = p.product_id + WHERE o.quantity > 2"; + + let (mut circuit, pager) = compile_sql!(sql); + + // Create test data for users (id, name, age) + let mut users_delta = Delta::new(); + users_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".into()), + Value::Integer(25), + ], + ); + users_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Bob".into()), + Value::Integer(30), + ], + ); + + // Create test data for orders (order_id, user_id, product_id, quantity) + let mut orders_delta = Delta::new(); + orders_delta.insert( + 1, + vec![ + Value::Integer(101), + Value::Integer(1), // Alice + Value::Integer(201), // Widget + Value::Integer(5), // quantity > 2 + ], + ); + orders_delta.insert( + 2, + vec![ + Value::Integer(102), + Value::Integer(2), // Bob + Value::Integer(202), // Gadget + Value::Integer(1), // quantity <= 2, filtered out + ], + ); + orders_delta.insert( + 3, + vec![ + Value::Integer(103), + Value::Integer(1), // Alice + Value::Integer(202), // Gadget + Value::Integer(3), // quantity > 2 + ], + ); + + // Create test data for products (product_id, product_name, price) + let mut products_delta = Delta::new(); + products_delta.insert( + 201, + vec![ + 
Value::Integer(201), + Value::Text("Widget".into()), + Value::Integer(10), + ], + ); + products_delta.insert( + 202, + vec![ + Value::Integer(202), + Value::Text("Gadget".into()), + Value::Integer(20), + ], + ); + + let inputs = HashMap::from([ + ("users".to_string(), users_delta), + ("orders".to_string(), orders_delta), + ("products".to_string(), products_delta), + ]); + + let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap(); + + // Should have 2 results (orders with quantity > 2) + assert_eq!(result.len(), 2, "Should have 2 results after filtering"); + + let mut results = result.changes.clone(); + results.sort_by_key(|(row, _)| { + match &row.values[2] { + // Sort by order_id + Value::Integer(n) => *n, + _ => 0, + } + }); + + // First result: Alice's order 101 for Widget + assert_eq!(results[0].0.values[0], Value::Integer(1)); // customer_id + assert_eq!(results[0].0.values[1], Value::Text("Alice".into())); // customer_name + assert_eq!(results[0].0.values[2], Value::Integer(101)); // order_id + assert_eq!(results[0].0.values[3], Value::Integer(5)); // quantity + assert_eq!(results[0].0.values[4], Value::Text("Widget".into())); // product_name + + // Second result: Alice's order 103 for Gadget + assert_eq!(results[1].0.values[0], Value::Integer(1)); // customer_id + assert_eq!(results[1].0.values[1], Value::Text("Alice".into())); // customer_name + assert_eq!(results[1].0.values[2], Value::Integer(103)); // order_id + assert_eq!(results[1].0.values[3], Value::Integer(3)); // quantity + assert_eq!(results[1].0.values[4], Value::Text("Gadget".into())); // product_name + } + + #[test] + fn test_projection_column_reordering_and_duplication() { + // Test that projection can reorder columns and select the same column multiple times + // This is important for views that need specific column arrangements + + let sql = "SELECT o.quantity, + u.name, + u.id, + o.quantity * 2 as double_quantity, + u.id as user_id_again + FROM users u + JOIN orders o ON u.id = o.user_id + WHERE u.id = 1"; + + let (mut circuit, pager) = compile_sql!(sql); + + // Create test data for users + let mut users_delta = Delta::new(); + users_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".into()), + Value::Integer(25), + ], + ); + + // Create test data for orders + let mut orders_delta = Delta::new(); + orders_delta.insert( + 1, + vec![ + Value::Integer(101), + Value::Integer(1), // user_id + Value::Integer(201), // product_id + Value::Integer(5), // quantity + ], + ); + orders_delta.insert( + 2, + vec![ + Value::Integer(102), + Value::Integer(1), // user_id + Value::Integer(202), // product_id + Value::Integer(3), // quantity + ], + ); + + let inputs = HashMap::from([ + ("users".to_string(), users_delta), + ("orders".to_string(), orders_delta), + ]); + + let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap(); + + assert_eq!(result.len(), 2, "Should have 2 results for user 1"); + + // Check that columns are in the right order and values are correct + for (row, _) in &result.changes { + // Column 0: o.quantity (5 or 3) + assert!(matches!( + row.values[0], + Value::Integer(5) | Value::Integer(3) + )); + // Column 1: u.name + assert_eq!(row.values[1], Value::Text("Alice".into())); + // Column 2: u.id + assert_eq!(row.values[2], Value::Integer(1)); + // Column 3: o.quantity * 2 (10 or 6) + assert!(matches!( + row.values[3], + Value::Integer(10) | Value::Integer(6) + )); + // Column 4: u.id again + assert_eq!(row.values[4], Value::Integer(1)); + } + } + + 
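// ---------------------------------------------------------------------------
// Illustrative sketch only (hedged, simplified types; not part of the change
// set above). The join/projection tests in this module rely on the compiler
// resolving possibly ambiguous column references (customers.name vs.
// vendors.name) to positions through the schema's qualified lookup,
// find_column(name, table). SimpleSchema/ColInfo below are hypothetical
// stand-ins, not the real LogicalSchema/ColumnInfo from translate/logical.rs.

struct ColInfo {
    name: String,
    table: Option<String>, // table (or alias) the column originated from
}

struct SimpleSchema {
    columns: Vec<ColInfo>,
}

impl SimpleSchema {
    /// Return (index, column) for `name`, optionally qualified by `table`.
    /// An unqualified lookup matches the first column with that name.
    fn find_column(&self, name: &str, table: Option<&str>) -> Option<(usize, &ColInfo)> {
        self.columns.iter().enumerate().find(|(_, c)| {
            c.name == name && table.map_or(true, |t| c.table.as_deref() == Some(t))
        })
    }
}

fn main() {
    // Two joined tables both expose `id` and `name`; the qualifier disambiguates.
    let schema = SimpleSchema {
        columns: vec![
            ColInfo { name: "id".into(), table: Some("customers".into()) },
            ColInfo { name: "name".into(), table: Some("customers".into()) },
            ColInfo { name: "id".into(), table: Some("vendors".into()) },
            ColInfo { name: "name".into(), table: Some("vendors".into()) },
        ],
    };
    assert_eq!(schema.find_column("name", Some("customers")).unwrap().0, 1);
    assert_eq!(schema.find_column("name", Some("vendors")).unwrap().0, 3);
    assert_eq!(schema.find_column("id", None).unwrap().0, 0);
}
// ---------------------------------------------------------------------------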
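// ---------------------------------------------------------------------------
// Illustrative sketch only (hedged, simplified types). The filter changes in
// this diff move FilterPredicate from column names to column indices so that
// predicates can be evaluated positionally over joined rows. Below is a
// stripped-down evaluator over plain i64 rows; it is not the actual
// FilterPredicate/Value machinery in core/incremental/filter_operator.rs.

#[allow(dead_code)]
enum Pred {
    GreaterThan { column_idx: usize, value: i64 },      // row[column_idx] > value
    ColumnEquals { left_idx: usize, right_idx: usize }, // row[left_idx] == row[right_idx]
    And(Box<Pred>, Box<Pred>),
    Or(Box<Pred>, Box<Pred>),
    None, // accept all rows
}

fn eval(pred: &Pred, row: &[i64]) -> bool {
    match pred {
        Pred::GreaterThan { column_idx, value } => row[*column_idx] > *value,
        Pred::ColumnEquals { left_idx, right_idx } => row[*left_idx] == row[*right_idx],
        Pred::And(l, r) => eval(l, row) && eval(r, row),
        Pred::Or(l, r) => eval(l, row) || eval(r, row),
        Pred::None => true,
    }
}

fn main() {
    // e.g. "WHERE u.id > 1 AND u.id = o.user_id" compiled against a joined row
    // laid out as [u.id, u.age, o.user_id]; column names are gone by this point.
    let p = Pred::And(
        Box::new(Pred::GreaterThan { column_idx: 0, value: 1 }),
        Box::new(Pred::ColumnEquals { left_idx: 0, right_idx: 2 }),
    );
    assert!(eval(&p, &[2, 30, 2]));
    assert!(!eval(&p, &[1, 25, 1]));
}
// ---------------------------------------------------------------------------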
#[test] + fn test_join_with_aggregate_execution() { + let (mut circuit, pager) = compile_sql!( + "SELECT u.name, SUM(o.quantity) as total_quantity + FROM users u + JOIN orders o ON u.id = o.user_id + GROUP BY u.name" + ); + + // Create test data for users + let mut users_delta = Delta::new(); + users_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".into()), + Value::Integer(25), + ], + ); + users_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Bob".into()), + Value::Integer(30), + ], + ); + + // Create test data for orders + let mut orders_delta = Delta::new(); + orders_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Integer(1), + Value::Integer(100), + Value::Integer(5), + ], + ); + orders_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Integer(1), + Value::Integer(101), + Value::Integer(3), + ], + ); + orders_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Integer(2), + Value::Integer(102), + Value::Integer(7), + ], + ); + + let mut inputs = HashMap::new(); + inputs.insert("users".to_string(), users_delta); + inputs.insert("orders".to_string(), orders_delta); + + // Execute the join with aggregation + let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap(); + + // We should get 2 aggregated results (one for Alice, one for Bob) + assert_eq!(result.len(), 2, "Should have 2 aggregated results"); + + // Verify aggregation results + for (row, weight) in result.changes.iter() { + assert_eq!(*weight, 1); + // Row should have name and sum columns + assert_eq!(row.values.len(), 2); + + // Check the aggregated values + if let Value::Text(name) = &row.values[0] { + if name.as_ref() == "Alice" { + // Alice should have total quantity of 8 (5 + 3) + assert_eq!(row.values[1], Value::Integer(8)); + } else if name.as_ref() == "Bob" { + // Bob should have total quantity of 7 + assert_eq!(row.values[1], Value::Integer(7)); + } + } + } + } + + #[test] + fn test_filter_with_qualified_columns_in_join() { + // Test that filters correctly handle qualified column names in joins + // when multiple tables have columns with the SAME names. + // Both users and customers tables have 'id' and 'name' columns which can be ambiguous. 
+ + let (mut circuit, pager) = compile_sql!( + "SELECT users.id, users.name, customers.id, customers.name + FROM users + JOIN customers ON users.id = customers.id + WHERE users.id > 1 AND customers.id < 100" + ); + + // Create test data + let mut users_delta = Delta::new(); + let mut customers_delta = Delta::new(); + + // Users data: (id, name, age) + users_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".into()), + Value::Integer(30), + ], + ); // id = 1 + users_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Bob".into()), + Value::Integer(25), + ], + ); // id = 2 + users_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Text("Charlie".into()), + Value::Integer(35), + ], + ); // id = 3 + + // Customers data: (id, name, email) + customers_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Customer Alice".into()), + Value::Text("alice@example.com".into()), + ], + ); // id = 1 + customers_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Customer Bob".into()), + Value::Text("bob@example.com".into()), + ], + ); // id = 2 + customers_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Text("Customer Charlie".into()), + Value::Text("charlie@example.com".into()), + ], + ); // id = 3 + + let mut inputs = HashMap::new(); + inputs.insert("users".to_string(), users_delta); + inputs.insert("customers".to_string(), customers_delta); + + let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap(); + + // Should get rows where users.id > 1 AND customers.id < 100 + // - users.id=2 (> 1) AND customers.id=2 (< 100) ✓ + // - users.id=3 (> 1) AND customers.id=3 (< 100) ✓ + // Alice excluded: users.id=1 (NOT > 1) + assert_eq!(result.len(), 2, "Should have 2 filtered results"); + + let (row, weight) = &result.changes[0]; + assert_eq!(*weight, 1); + assert_eq!(row.values.len(), 4, "Should have 4 columns"); + + // Verify the filter correctly used qualified columns for Bob + assert_eq!(row.values[0], Value::Integer(2), "users.id should be 2"); + assert_eq!( + row.values[1], + Value::Text("Bob".into()), + "users.name should be Bob" + ); + assert_eq!(row.values[2], Value::Integer(2), "customers.id should be 2"); + assert_eq!( + row.values[3], + Value::Text("Customer Bob".into()), + "customers.name should be Customer Bob" + ); + } + + #[test] + fn test_expression_in_where_clause() { + // Test expressions in WHERE clauses like (quantity * price) >= 400 + let (mut circuit, pager) = compile_sql!("SELECT * FROM users WHERE (age * 2) > 30"); + + // Create test data + let mut input_delta = Delta::new(); + input_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".into()), + Value::Integer(20), // age * 2 = 40 > 30, should pass + ], + ); + input_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Bob".into()), + Value::Integer(10), // age * 2 = 20 <= 30, should be filtered out + ], + ); + input_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Text("Charlie".into()), + Value::Integer(16), // age * 2 = 32 > 30, should pass + ], + ); + + // Create input map + let mut inputs = HashMap::new(); + inputs.insert("users".to_string(), input_delta); + + let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap(); + + // Should only have Alice and Charlie (age * 2 > 30) + assert_eq!( + result.changes.len(), + 2, + "Should have 2 rows after filtering" + ); + + // Check Alice + let alice = result + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Integer(1)) + .expect("Alice 
should be in result"); + assert_eq!(alice.0.values[1], Value::Text("Alice".into())); + assert_eq!(alice.0.values[2], Value::Integer(20)); + + // Check Charlie + let charlie = result + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Integer(3)) + .expect("Charlie should be in result"); + assert_eq!(charlie.0.values[1], Value::Text("Charlie".into())); + assert_eq!(charlie.0.values[2], Value::Integer(16)); + + // Bob should not be in result + let bob = result + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Integer(2)); + assert!(bob.is_none(), "Bob should be filtered out"); + } } diff --git a/core/incremental/cursor.rs b/core/incremental/cursor.rs index 9bf39f53d..bf500b450 100644 --- a/core/incremental/cursor.rs +++ b/core/incremental/cursor.rs @@ -355,7 +355,7 @@ mod tests { "View not materialized".to_string(), )); } - let num_columns = view.columns.len(); + let num_columns = view.column_schema.columns.len(); drop(view); // Create a btree cursor diff --git a/core/incremental/dbsp.rs b/core/incremental/dbsp.rs index 363ac1142..eeab315d3 100644 --- a/core/incremental/dbsp.rs +++ b/core/incremental/dbsp.rs @@ -75,6 +75,10 @@ impl HashableRow { hasher.finish() } + + pub fn cached_hash(&self) -> u64 { + self.cached_hash + } } impl Hash for HashableRow { @@ -168,7 +172,7 @@ impl Delta { } /// A pair of deltas for operators that process two inputs -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct DeltaPair { pub left: Delta, pub right: Delta, @@ -400,4 +404,57 @@ mod tests { let weight = zset.iter().find(|(k, _)| **k == 1).map(|(_, w)| w); assert_eq!(weight, Some(1)); } + + #[test] + fn test_hashable_row_delta_operations() { + let mut delta = Delta::new(); + + // Test INSERT + delta.insert(1, vec![Value::Integer(1), Value::Integer(100)]); + assert_eq!(delta.len(), 1); + + // Test UPDATE (DELETE + INSERT) - order matters! 
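+ // An UPDATE is modeled as a delete (weight -1) of the old row followed by an insert (weight +1) of the new row; + // consolidate() later cancels the matching pair for the old values, which the assertions below verify.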
+ delta.delete(1, vec![Value::Integer(1), Value::Integer(100)]); + delta.insert(1, vec![Value::Integer(1), Value::Integer(200)]); + assert_eq!(delta.len(), 3); // Should have 3 operations before consolidation + + // Verify order is preserved + let ops: Vec<_> = delta.changes.iter().collect(); + assert_eq!(ops[0].1, 1); // First insert + assert_eq!(ops[1].1, -1); // Delete + assert_eq!(ops[2].1, 1); // Second insert + + // Test consolidation + delta.consolidate(); + // After consolidation, the first insert and delete should cancel out + // leaving only the second insert + assert_eq!(delta.len(), 1); + + let final_row = &delta.changes[0]; + assert_eq!(final_row.0.rowid, 1); + assert_eq!( + final_row.0.values, + vec![Value::Integer(1), Value::Integer(200)] + ); + assert_eq!(final_row.1, 1); + } + + #[test] + fn test_duplicate_row_consolidation() { + let mut delta = Delta::new(); + + // Insert same row twice + delta.insert(2, vec![Value::Integer(2), Value::Integer(300)]); + delta.insert(2, vec![Value::Integer(2), Value::Integer(300)]); + + assert_eq!(delta.len(), 2); + + delta.consolidate(); + assert_eq!(delta.len(), 1); + + // Weight should be 2 (sum of both inserts) + let final_row = &delta.changes[0]; + assert_eq!(final_row.0.rowid, 2); + assert_eq!(final_row.1, 2); + } } diff --git a/core/incremental/filter_operator.rs b/core/incremental/filter_operator.rs new file mode 100644 index 000000000..84a3c53ce --- /dev/null +++ b/core/incremental/filter_operator.rs @@ -0,0 +1,295 @@ +#![allow(dead_code)] +// Filter operator for DBSP-style incremental computation +// This operator filters rows based on predicates + +use crate::incremental::dbsp::{Delta, DeltaPair}; +use crate::incremental::operator::{ + ComputationTracker, DbspStateCursors, EvalState, IncrementalOperator, +}; +use crate::types::IOResult; +use crate::{Result, Value}; +use std::sync::{Arc, Mutex}; + +/// Filter predicate for filtering rows +#[derive(Debug, Clone)] +pub enum FilterPredicate { + /// Column = value (using column index) + Equals { column_idx: usize, value: Value }, + /// Column != value (using column index) + NotEquals { column_idx: usize, value: Value }, + /// Column > value (using column index) + GreaterThan { column_idx: usize, value: Value }, + /// Column >= value (using column index) + GreaterThanOrEqual { column_idx: usize, value: Value }, + /// Column < value (using column index) + LessThan { column_idx: usize, value: Value }, + /// Column <= value (using column index) + LessThanOrEqual { column_idx: usize, value: Value }, + + /// Column = Column comparisons + ColumnEquals { left_idx: usize, right_idx: usize }, + /// Column != Column comparisons + ColumnNotEquals { left_idx: usize, right_idx: usize }, + /// Column > Column comparisons + ColumnGreaterThan { left_idx: usize, right_idx: usize }, + /// Column >= Column comparisons + ColumnGreaterThanOrEqual { left_idx: usize, right_idx: usize }, + /// Column < Column comparisons + ColumnLessThan { left_idx: usize, right_idx: usize }, + /// Column <= Column comparisons + ColumnLessThanOrEqual { left_idx: usize, right_idx: usize }, + + /// Logical AND of two predicates + And(Box<FilterPredicate>, Box<FilterPredicate>), + /// Logical OR of two predicates + Or(Box<FilterPredicate>, Box<FilterPredicate>), + /// No predicate (accept all rows) + None, +} + +/// Filter operator - filters rows based on predicate +#[derive(Debug)] +pub struct FilterOperator { + predicate: FilterPredicate, + tracker: Option<Arc<Mutex<ComputationTracker>>>, +} + +impl FilterOperator { + pub fn new(predicate: FilterPredicate) -> Self { + Self { + predicate, + tracker: None, + } + } + + ///
Get the predicate for this filter + pub fn predicate(&self) -> &FilterPredicate { + &self.predicate + } + + pub fn evaluate_predicate(&self, values: &[Value]) -> bool { + match &self.predicate { + FilterPredicate::None => true, + FilterPredicate::Equals { column_idx, value } => { + if let Some(v) = values.get(*column_idx) { + return v == value; + } + false + } + FilterPredicate::NotEquals { column_idx, value } => { + if let Some(v) = values.get(*column_idx) { + return v != value; + } + false + } + FilterPredicate::GreaterThan { column_idx, value } => { + if let Some(v) = values.get(*column_idx) { + // Compare based on value types + match (v, value) { + (Value::Integer(a), Value::Integer(b)) => return a > b, + (Value::Float(a), Value::Float(b)) => return a > b, + (Value::Text(a), Value::Text(b)) => return a.as_str() > b.as_str(), + _ => {} + } + } + false + } + FilterPredicate::GreaterThanOrEqual { column_idx, value } => { + if let Some(v) = values.get(*column_idx) { + match (v, value) { + (Value::Integer(a), Value::Integer(b)) => return a >= b, + (Value::Float(a), Value::Float(b)) => return a >= b, + (Value::Text(a), Value::Text(b)) => return a.as_str() >= b.as_str(), + _ => {} + } + } + false + } + FilterPredicate::LessThan { column_idx, value } => { + if let Some(v) = values.get(*column_idx) { + match (v, value) { + (Value::Integer(a), Value::Integer(b)) => return a < b, + (Value::Float(a), Value::Float(b)) => return a < b, + (Value::Text(a), Value::Text(b)) => return a.as_str() < b.as_str(), + _ => {} + } + } + false + } + FilterPredicate::LessThanOrEqual { column_idx, value } => { + if let Some(v) = values.get(*column_idx) { + match (v, value) { + (Value::Integer(a), Value::Integer(b)) => return a <= b, + (Value::Float(a), Value::Float(b)) => return a <= b, + (Value::Text(a), Value::Text(b)) => return a.as_str() <= b.as_str(), + _ => {} + } + } + false + } + FilterPredicate::And(left, right) => { + // Temporarily create sub-filters to evaluate + let left_filter = FilterOperator::new((**left).clone()); + let right_filter = FilterOperator::new((**right).clone()); + left_filter.evaluate_predicate(values) && right_filter.evaluate_predicate(values) + } + FilterPredicate::Or(left, right) => { + let left_filter = FilterOperator::new((**left).clone()); + let right_filter = FilterOperator::new((**right).clone()); + left_filter.evaluate_predicate(values) || right_filter.evaluate_predicate(values) + } + + // Column-to-column comparisons + FilterPredicate::ColumnEquals { + left_idx, + right_idx, + } => { + if let (Some(left), Some(right)) = (values.get(*left_idx), values.get(*right_idx)) { + return left == right; + } + false + } + FilterPredicate::ColumnNotEquals { + left_idx, + right_idx, + } => { + if let (Some(left), Some(right)) = (values.get(*left_idx), values.get(*right_idx)) { + return left != right; + } + false + } + FilterPredicate::ColumnGreaterThan { + left_idx, + right_idx, + } => { + if let (Some(left), Some(right)) = (values.get(*left_idx), values.get(*right_idx)) { + match (left, right) { + (Value::Integer(a), Value::Integer(b)) => return a > b, + (Value::Float(a), Value::Float(b)) => return a > b, + (Value::Text(a), Value::Text(b)) => return a.as_str() > b.as_str(), + _ => {} + } + } + false + } + FilterPredicate::ColumnGreaterThanOrEqual { + left_idx, + right_idx, + } => { + if let (Some(left), Some(right)) = (values.get(*left_idx), values.get(*right_idx)) { + match (left, right) { + (Value::Integer(a), Value::Integer(b)) => return a >= b, + (Value::Float(a), Value::Float(b)) => 
return a >= b, + (Value::Text(a), Value::Text(b)) => return a.as_str() >= b.as_str(), + _ => {} + } + } + false + } + FilterPredicate::ColumnLessThan { + left_idx, + right_idx, + } => { + if let (Some(left), Some(right)) = (values.get(*left_idx), values.get(*right_idx)) { + match (left, right) { + (Value::Integer(a), Value::Integer(b)) => return a < b, + (Value::Float(a), Value::Float(b)) => return a < b, + (Value::Text(a), Value::Text(b)) => return a.as_str() < b.as_str(), + _ => {} + } + } + false + } + FilterPredicate::ColumnLessThanOrEqual { + left_idx, + right_idx, + } => { + if let (Some(left), Some(right)) = (values.get(*left_idx), values.get(*right_idx)) { + match (left, right) { + (Value::Integer(a), Value::Integer(b)) => return a <= b, + (Value::Float(a), Value::Float(b)) => return a <= b, + (Value::Text(a), Value::Text(b)) => return a.as_str() <= b.as_str(), + _ => {} + } + } + false + } + } + } +} + +impl IncrementalOperator for FilterOperator { + fn eval( + &mut self, + state: &mut EvalState, + _cursors: &mut DbspStateCursors, + ) -> Result> { + let delta = match state { + EvalState::Init { deltas } => { + // Filter operators only use left_delta, right_delta must be empty + assert!( + deltas.right.is_empty(), + "FilterOperator expects right_delta to be empty" + ); + std::mem::take(&mut deltas.left) + } + _ => unreachable!( + "FilterOperator doesn't execute the state machine. Should be in Init state" + ), + }; + + let mut output_delta = Delta::new(); + + // Process the delta through the filter + for (row, weight) in delta.changes { + if let Some(tracker) = &self.tracker { + tracker.lock().unwrap().record_filter(); + } + + // Only pass through rows that satisfy the filter predicate + // For deletes (weight < 0), we only pass them if the row values + // would have passed the filter (meaning it was in the view) + if self.evaluate_predicate(&row.values) { + output_delta.changes.push((row, weight)); + } + } + + *state = EvalState::Done; + Ok(IOResult::Done(output_delta)) + } + + fn commit( + &mut self, + deltas: DeltaPair, + _cursors: &mut DbspStateCursors, + ) -> Result> { + // Filter operator only uses left delta, right must be empty + assert!( + deltas.right.is_empty(), + "FilterOperator expects right delta to be empty in commit" + ); + + let mut output_delta = Delta::new(); + + // Commit the delta to our internal state + // Only pass through and track rows that satisfy the filter predicate + for (row, weight) in deltas.left.changes { + if let Some(tracker) = &self.tracker { + tracker.lock().unwrap().record_filter(); + } + + // Only track and output rows that pass the filter + // For deletes, this means the row was in the view (its values pass the filter) + // For inserts, this means the row should be in the view + if self.evaluate_predicate(&row.values) { + output_delta.changes.push((row, weight)); + } + } + + Ok(IOResult::Done(output_delta)) + } + + fn set_tracker(&mut self, tracker: Arc>) { + self.tracker = Some(tracker); + } +} diff --git a/core/incremental/input_operator.rs b/core/incremental/input_operator.rs new file mode 100644 index 000000000..b9a6eeb01 --- /dev/null +++ b/core/incremental/input_operator.rs @@ -0,0 +1,66 @@ +// Input operator for DBSP-style incremental computation +// This operator serves as the entry point for data into the incremental computation pipeline + +use crate::incremental::dbsp::{Delta, DeltaPair}; +use crate::incremental::operator::{ + ComputationTracker, DbspStateCursors, EvalState, IncrementalOperator, +}; +use crate::types::IOResult; +use 
crate::Result; +use std::sync::{Arc, Mutex}; + +/// Input operator - source of data for the circuit +/// Represents base relations/tables that receive external updates +#[derive(Debug)] +pub struct InputOperator { + #[allow(dead_code)] + name: String, +} + +impl InputOperator { + pub fn new(name: String) -> Self { + Self { name } + } +} + +impl IncrementalOperator for InputOperator { + fn eval( + &mut self, + state: &mut EvalState, + _cursors: &mut DbspStateCursors, + ) -> Result> { + match state { + EvalState::Init { deltas } => { + // Input operators only use left_delta, right_delta must be empty + assert!( + deltas.right.is_empty(), + "InputOperator expects right_delta to be empty" + ); + let output = std::mem::take(&mut deltas.left); + *state = EvalState::Done; + Ok(IOResult::Done(output)) + } + _ => unreachable!( + "InputOperator doesn't execute the state machine. Should be in Init state" + ), + } + } + + fn commit( + &mut self, + deltas: DeltaPair, + _cursors: &mut DbspStateCursors, + ) -> Result> { + // Input operator only uses left delta, right must be empty + assert!( + deltas.right.is_empty(), + "InputOperator expects right delta to be empty in commit" + ); + // Input operator passes through the delta unchanged during commit + Ok(IOResult::Done(deltas.left)) + } + + fn set_tracker(&mut self, _tracker: Arc>) { + // Input operator doesn't need tracking + } +} diff --git a/core/incremental/join_operator.rs b/core/incremental/join_operator.rs new file mode 100644 index 000000000..f5ffb9b55 --- /dev/null +++ b/core/incremental/join_operator.rs @@ -0,0 +1,787 @@ +#![allow(dead_code)] + +use crate::incremental::dbsp::{Delta, DeltaPair, HashableRow}; +use crate::incremental::operator::{ + generate_storage_id, ComputationTracker, DbspStateCursors, EvalState, IncrementalOperator, +}; +use crate::incremental::persistence::WriteRow; +use crate::types::{IOResult, ImmutableRecord, SeekKey, SeekOp, SeekResult}; +use crate::{return_and_restore_if_io, return_if_io, Result, Value}; +use std::sync::{Arc, Mutex}; + +#[derive(Debug, Clone, PartialEq)] +pub enum JoinType { + Inner, + Left, + Right, + Full, + Cross, +} + +// Helper function to read the next row from the BTree for joins +fn read_next_join_row( + storage_id: i64, + join_key: &HashableRow, + last_element_id: i64, + cursors: &mut DbspStateCursors, +) -> Result>> { + // Build the index key: (storage_id, zset_id, element_id) + // zset_id is the hash of the join key + let zset_id = join_key.cached_hash() as i64; + + let index_key_values = vec![ + Value::Integer(storage_id), + Value::Integer(zset_id), + Value::Integer(last_element_id), + ]; + + let index_record = ImmutableRecord::from_values(&index_key_values, index_key_values.len()); + let seek_result = return_if_io!(cursors + .index_cursor + .seek(SeekKey::IndexKey(&index_record), SeekOp::GT)); + + if !matches!(seek_result, SeekResult::Found) { + return Ok(IOResult::Done(None)); + } + + // Check if we're still in the same (storage_id, zset_id) range + let current_record = return_if_io!(cursors.index_cursor.record()); + + // Extract all needed values from the record before dropping it + let (found_storage_id, found_zset_id, element_id) = if let Some(rec) = current_record { + let values = rec.get_values(); + + // Index has 4 values: storage_id, zset_id, element_id, rowid (appended by WriteRow) + if values.len() >= 3 { + let found_storage_id = match &values[0].to_owned() { + Value::Integer(id) => *id, + _ => return Ok(IOResult::Done(None)), + }; + let found_zset_id = match &values[1].to_owned() 
{ + Value::Integer(id) => *id, + _ => return Ok(IOResult::Done(None)), + }; + let element_id = match &values[2].to_owned() { + Value::Integer(id) => *id, + _ => { + return Ok(IOResult::Done(None)); + } + }; + (found_storage_id, found_zset_id, element_id) + } else { + return Ok(IOResult::Done(None)); + } + } else { + return Ok(IOResult::Done(None)); + }; + + // Now we can safely check if we're in the right range + // If we've moved to a different storage_id or zset_id, we're done + if found_storage_id != storage_id || found_zset_id != zset_id { + return Ok(IOResult::Done(None)); + } + + // Now get the actual row from the table using the rowid from the index + let rowid = return_if_io!(cursors.index_cursor.rowid()); + if let Some(rowid) = rowid { + return_if_io!(cursors + .table_cursor + .seek(SeekKey::TableRowId(rowid), SeekOp::GE { eq_only: true })); + + let table_record = return_if_io!(cursors.table_cursor.record()); + if let Some(rec) = table_record { + let table_values = rec.get_values(); + // Table format: [storage_id, zset_id, element_id, value_blob, weight] + if table_values.len() >= 5 { + // Deserialize the row from the blob + let value_at_3 = table_values[3].to_owned(); + let blob = match value_at_3 { + Value::Blob(ref b) => b, + _ => return Ok(IOResult::Done(None)), + }; + + // The blob contains the serialized HashableRow + // For now, let's deserialize it simply + let row = deserialize_hashable_row(blob)?; + + let weight = match &table_values[4].to_owned() { + Value::Integer(w) => *w as isize, + _ => return Ok(IOResult::Done(None)), + }; + + return Ok(IOResult::Done(Some((element_id, row, weight)))); + } + } + } + Ok(IOResult::Done(None)) +} + +// Join-specific eval states +#[derive(Debug)] +pub enum JoinEvalState { + ProcessDeltaJoin { + deltas: DeltaPair, + output: Delta, + }, + ProcessLeftJoin { + deltas: DeltaPair, + output: Delta, + current_idx: usize, + last_row_scanned: i64, + }, + ProcessRightJoin { + deltas: DeltaPair, + output: Delta, + current_idx: usize, + last_row_scanned: i64, + }, + Done { + output: Delta, + }, +} + +impl JoinEvalState { + fn combine_rows( + left_row: &HashableRow, + left_weight: i64, + right_row: &HashableRow, + right_weight: i64, + output: &mut Delta, + ) { + // Combine the rows + let mut combined_values = left_row.values.clone(); + combined_values.extend(right_row.values.clone()); + // Use hash of the combined values as rowid to ensure uniqueness + let temp_row = HashableRow::new(0, combined_values.clone()); + let joined_rowid = temp_row.cached_hash() as i64; + let joined_row = HashableRow::new(joined_rowid, combined_values); + + // Add to output with combined weight + let combined_weight = left_weight * right_weight; + output.changes.push((joined_row, combined_weight as isize)); + } + + fn process_join_state( + &mut self, + cursors: &mut DbspStateCursors, + left_key_indices: &[usize], + right_key_indices: &[usize], + left_storage_id: i64, + right_storage_id: i64, + ) -> Result> { + loop { + match self { + JoinEvalState::ProcessDeltaJoin { deltas, output } => { + // Move to ProcessLeftJoin + *self = JoinEvalState::ProcessLeftJoin { + deltas: std::mem::take(deltas), + output: std::mem::take(output), + current_idx: 0, + last_row_scanned: i64::MIN, + }; + } + JoinEvalState::ProcessLeftJoin { + deltas, + output, + current_idx, + last_row_scanned, + } => { + if *current_idx >= deltas.left.changes.len() { + *self = JoinEvalState::ProcessRightJoin { + deltas: std::mem::take(deltas), + output: std::mem::take(output), + current_idx: 0, + 
last_row_scanned: i64::MIN, + }; + } else { + let (left_row, left_weight) = &deltas.left.changes[*current_idx]; + // Extract join key using provided indices + let key_values: Vec = left_key_indices + .iter() + .map(|&idx| left_row.values.get(idx).cloned().unwrap_or(Value::Null)) + .collect(); + let left_key = HashableRow::new(0, key_values); + + let next_row = return_if_io!(read_next_join_row( + right_storage_id, + &left_key, + *last_row_scanned, + cursors + )); + match next_row { + Some((element_id, right_row, right_weight)) => { + Self::combine_rows( + left_row, + (*left_weight) as i64, + &right_row, + right_weight as i64, + output, + ); + // Continue scanning with this left row + *self = JoinEvalState::ProcessLeftJoin { + deltas: std::mem::take(deltas), + output: std::mem::take(output), + current_idx: *current_idx, + last_row_scanned: element_id, + }; + } + None => { + // No more matches for this left row, move to next + *self = JoinEvalState::ProcessLeftJoin { + deltas: std::mem::take(deltas), + output: std::mem::take(output), + current_idx: *current_idx + 1, + last_row_scanned: i64::MIN, + }; + } + } + } + } + JoinEvalState::ProcessRightJoin { + deltas, + output, + current_idx, + last_row_scanned, + } => { + if *current_idx >= deltas.right.changes.len() { + *self = JoinEvalState::Done { + output: std::mem::take(output), + }; + } else { + let (right_row, right_weight) = &deltas.right.changes[*current_idx]; + // Extract join key using provided indices + let key_values: Vec = right_key_indices + .iter() + .map(|&idx| right_row.values.get(idx).cloned().unwrap_or(Value::Null)) + .collect(); + let right_key = HashableRow::new(0, key_values); + + let next_row = return_if_io!(read_next_join_row( + left_storage_id, + &right_key, + *last_row_scanned, + cursors + )); + match next_row { + Some((element_id, left_row, left_weight)) => { + Self::combine_rows( + &left_row, + left_weight as i64, + right_row, + (*right_weight) as i64, + output, + ); + // Continue scanning with this right row + *self = JoinEvalState::ProcessRightJoin { + deltas: std::mem::take(deltas), + output: std::mem::take(output), + current_idx: *current_idx, + last_row_scanned: element_id, + }; + } + None => { + // No more matches for this right row, move to next + *self = JoinEvalState::ProcessRightJoin { + deltas: std::mem::take(deltas), + output: std::mem::take(output), + current_idx: *current_idx + 1, + last_row_scanned: i64::MIN, + }; + } + } + } + } + JoinEvalState::Done { output } => { + return Ok(IOResult::Done(std::mem::take(output))); + } + } + } + } +} + +#[derive(Debug)] +enum JoinCommitState { + Idle, + Eval { + eval_state: EvalState, + }, + CommitLeftDelta { + deltas: DeltaPair, + output: Delta, + current_idx: usize, + write_row: WriteRow, + }, + CommitRightDelta { + deltas: DeltaPair, + output: Delta, + current_idx: usize, + write_row: WriteRow, + }, + Invalid, +} + +/// Join operator - performs incremental join between two relations +/// Implements the DBSP formula: δ(R ⋈ S) = (δR ⋈ S) ∪ (R ⋈ δS) ∪ (δR ⋈ δS) +#[derive(Debug)] +pub struct JoinOperator { + /// Unique operator ID for indexing in persistent storage + operator_id: usize, + /// Type of join to perform + join_type: JoinType, + /// Column indices for extracting join keys from left input + left_key_indices: Vec, + /// Column indices for extracting join keys from right input + right_key_indices: Vec, + /// Column names from left input + left_columns: Vec, + /// Column names from right input + right_columns: Vec, + /// Tracker for computation statistics + 
tracker: Option>>, + + commit_state: JoinCommitState, +} + +impl JoinOperator { + pub fn new( + operator_id: usize, + join_type: JoinType, + left_key_indices: Vec, + right_key_indices: Vec, + left_columns: Vec, + right_columns: Vec, + ) -> Result { + // Check for unsupported join types + match join_type { + JoinType::Left => { + return Err(crate::LimboError::ParseError( + "LEFT OUTER JOIN is not yet supported in incremental views".to_string(), + )) + } + JoinType::Right => { + return Err(crate::LimboError::ParseError( + "RIGHT OUTER JOIN is not yet supported in incremental views".to_string(), + )) + } + JoinType::Full => { + return Err(crate::LimboError::ParseError( + "FULL OUTER JOIN is not yet supported in incremental views".to_string(), + )) + } + JoinType::Cross => { + return Err(crate::LimboError::ParseError( + "CROSS JOIN is not yet supported in incremental views".to_string(), + )) + } + JoinType::Inner => {} // Inner join is supported + } + + Ok(Self { + operator_id, + join_type, + left_key_indices, + right_key_indices, + left_columns, + right_columns, + tracker: None, + commit_state: JoinCommitState::Idle, + }) + } + + /// Extract join key from row values using the specified indices + fn extract_join_key(&self, values: &[Value], indices: &[usize]) -> HashableRow { + let key_values: Vec = indices + .iter() + .map(|&idx| values.get(idx).cloned().unwrap_or(Value::Null)) + .collect(); + // Use 0 as a dummy rowid for join keys. They don't come from a table, + // so they don't need a rowid. Their key will be the hash of the row values. + HashableRow::new(0, key_values) + } + + /// Generate storage ID for left table + fn left_storage_id(&self) -> i64 { + // Use column_index=0 for left side + generate_storage_id(self.operator_id, 0, 0) + } + + /// Generate storage ID for right table + fn right_storage_id(&self) -> i64 { + // Use column_index=1 for right side + generate_storage_id(self.operator_id, 1, 0) + } + + /// SQL-compliant comparison for join keys + /// Returns true if keys match according to SQL semantics (NULL != NULL) + fn sql_keys_equal(left_key: &HashableRow, right_key: &HashableRow) -> bool { + if left_key.values.len() != right_key.values.len() { + return false; + } + + for (left_val, right_val) in left_key.values.iter().zip(right_key.values.iter()) { + // In SQL, NULL never equals NULL + if matches!(left_val, Value::Null) || matches!(right_val, Value::Null) { + return false; + } + + // For non-NULL values, use regular comparison + if left_val != right_val { + return false; + } + } + + true + } + + fn process_join_state( + &mut self, + state: &mut EvalState, + cursors: &mut DbspStateCursors, + ) -> Result> { + // Get the join state out of the enum + match state { + EvalState::Join(js) => js.process_join_state( + cursors, + &self.left_key_indices, + &self.right_key_indices, + self.left_storage_id(), + self.right_storage_id(), + ), + _ => panic!("process_join_state called with non-join state"), + } + } + + fn eval_internal( + &mut self, + state: &mut EvalState, + cursors: &mut DbspStateCursors, + ) -> Result> { + loop { + let loop_state = std::mem::replace(state, EvalState::Uninitialized); + match loop_state { + EvalState::Uninitialized => { + panic!("Cannot eval JoinOperator with Uninitialized state"); + } + EvalState::Init { deltas } => { + let mut output = Delta::new(); + + // Component 3: δR ⋈ δS (left delta join right delta) + for (left_row, left_weight) in &deltas.left.changes { + let left_key = + self.extract_join_key(&left_row.values, &self.left_key_indices); + + for 
(right_row, right_weight) in &deltas.right.changes { + let right_key = + self.extract_join_key(&right_row.values, &self.right_key_indices); + + if Self::sql_keys_equal(&left_key, &right_key) { + if let Some(tracker) = &self.tracker { + tracker.lock().unwrap().record_join_lookup(); + } + + // Combine the rows + let mut combined_values = left_row.values.clone(); + combined_values.extend(right_row.values.clone()); + + // Create the joined row with a unique rowid + // Use hash of the combined values to ensure uniqueness + let temp_row = HashableRow::new(0, combined_values.clone()); + let joined_rowid = temp_row.cached_hash() as i64; + let joined_row = + HashableRow::new(joined_rowid, combined_values.clone()); + + // Add to output with combined weight + let combined_weight = left_weight * right_weight; + output.changes.push((joined_row, combined_weight)); + } + } + } + + *state = EvalState::Join(Box::new(JoinEvalState::ProcessDeltaJoin { + deltas, + output, + })); + } + EvalState::Join(join_state) => { + *state = EvalState::Join(join_state); + let output = return_if_io!(self.process_join_state(state, cursors)); + return Ok(IOResult::Done(output)); + } + EvalState::Done => { + return Ok(IOResult::Done(Delta::new())); + } + EvalState::Aggregate(_) => { + panic!("Aggregate state should not appear in join operator"); + } + } + } + } +} + +// Helper to deserialize a HashableRow from a blob +fn deserialize_hashable_row(blob: &[u8]) -> Result { + // Simple deserialization - this needs to match how we serialize in commit + // Format: [rowid:8 bytes][num_values:4 bytes][values...] + if blob.len() < 12 { + return Err(crate::LimboError::InternalError( + "Invalid blob size".to_string(), + )); + } + + let rowid = i64::from_le_bytes(blob[0..8].try_into().unwrap()); + let num_values = u32::from_le_bytes(blob[8..12].try_into().unwrap()) as usize; + + let mut values = Vec::new(); + let mut offset = 12; + + for _ in 0..num_values { + if offset >= blob.len() { + break; + } + + let type_tag = blob[offset]; + offset += 1; + + match type_tag { + 0 => values.push(Value::Null), + 1 => { + if offset + 8 <= blob.len() { + let i = i64::from_le_bytes(blob[offset..offset + 8].try_into().unwrap()); + values.push(Value::Integer(i)); + offset += 8; + } + } + 2 => { + if offset + 8 <= blob.len() { + let f = f64::from_le_bytes(blob[offset..offset + 8].try_into().unwrap()); + values.push(Value::Float(f)); + offset += 8; + } + } + 3 => { + if offset + 4 <= blob.len() { + let len = + u32::from_le_bytes(blob[offset..offset + 4].try_into().unwrap()) as usize; + offset += 4; + if offset + len < blob.len() { + let text_bytes = blob[offset..offset + len].to_vec(); + offset += len; + let subtype = match blob[offset] { + 0 => crate::types::TextSubtype::Text, + 1 => crate::types::TextSubtype::Json, + _ => crate::types::TextSubtype::Text, + }; + offset += 1; + values.push(Value::Text(crate::types::Text { + value: text_bytes, + subtype, + })); + } + } + } + 4 => { + if offset + 4 <= blob.len() { + let len = + u32::from_le_bytes(blob[offset..offset + 4].try_into().unwrap()) as usize; + offset += 4; + if offset + len <= blob.len() { + let blob_data = blob[offset..offset + len].to_vec(); + values.push(Value::Blob(blob_data)); + offset += len; + } + } + } + _ => break, // Unknown type tag + } + } + + Ok(HashableRow::new(rowid, values)) +} + +// Helper to serialize a HashableRow to a blob +fn serialize_hashable_row(row: &HashableRow) -> Vec { + let mut blob = Vec::new(); + + // Write rowid + blob.extend_from_slice(&row.rowid.to_le_bytes()); + + 
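+ // Full blob layout (kept in sync with deserialize_hashable_row above): + // [rowid: 8 bytes LE][num_values: u32 LE][per value: 1-byte type tag + payload]. + // For illustration, a row with rowid 7 and values [Integer(42)] serializes to + // 07 00 00 00 00 00 00 00 | 01 00 00 00 | 01 2A 00 00 00 00 00 00 00.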
// Write number of values + blob.extend_from_slice(&(row.values.len() as u32).to_le_bytes()); + + // Write each value directly with type tags (like AggregateState does) + for value in &row.values { + match value { + Value::Null => blob.push(0u8), + Value::Integer(i) => { + blob.push(1u8); + blob.extend_from_slice(&i.to_le_bytes()); + } + Value::Float(f) => { + blob.push(2u8); + blob.extend_from_slice(&f.to_le_bytes()); + } + Value::Text(s) => { + blob.push(3u8); + let bytes = &s.value; + blob.extend_from_slice(&(bytes.len() as u32).to_le_bytes()); + blob.extend_from_slice(bytes); + blob.push(s.subtype as u8); + } + Value::Blob(b) => { + blob.push(4u8); + blob.extend_from_slice(&(b.len() as u32).to_le_bytes()); + blob.extend_from_slice(b); + } + } + } + + blob +} + +impl IncrementalOperator for JoinOperator { + fn eval( + &mut self, + state: &mut EvalState, + cursors: &mut DbspStateCursors, + ) -> Result> { + let delta = return_if_io!(self.eval_internal(state, cursors)); + Ok(IOResult::Done(delta)) + } + + fn commit( + &mut self, + deltas: DeltaPair, + cursors: &mut DbspStateCursors, + ) -> Result> { + loop { + let mut state = std::mem::replace(&mut self.commit_state, JoinCommitState::Invalid); + match &mut state { + JoinCommitState::Idle => { + self.commit_state = JoinCommitState::Eval { + eval_state: deltas.clone().into(), + } + } + JoinCommitState::Eval { ref mut eval_state } => { + let output = return_and_restore_if_io!( + &mut self.commit_state, + state, + self.eval(eval_state, cursors) + ); + self.commit_state = JoinCommitState::CommitLeftDelta { + deltas: deltas.clone(), + output, + current_idx: 0, + write_row: WriteRow::new(), + }; + } + JoinCommitState::CommitLeftDelta { + deltas, + output, + current_idx, + ref mut write_row, + } => { + if *current_idx >= deltas.left.changes.len() { + self.commit_state = JoinCommitState::CommitRightDelta { + deltas: std::mem::take(deltas), + output: std::mem::take(output), + current_idx: 0, + write_row: WriteRow::new(), + }; + continue; + } + + let (row, weight) = &deltas.left.changes[*current_idx]; + // Extract join key from the left row + let join_key = self.extract_join_key(&row.values, &self.left_key_indices); + + // The index key: (storage_id, zset_id, element_id) + // zset_id is the hash of the join key, element_id is hash of the row + let storage_id = self.left_storage_id(); + let zset_id = join_key.cached_hash() as i64; + let element_id = row.cached_hash() as i64; + let index_key = vec![ + Value::Integer(storage_id), + Value::Integer(zset_id), + Value::Integer(element_id), + ]; + + // The record values: we'll store the serialized row as a blob + let row_blob = serialize_hashable_row(row); + let record_values = vec![ + Value::Integer(self.left_storage_id()), + Value::Integer(join_key.cached_hash() as i64), + Value::Integer(row.cached_hash() as i64), + Value::Blob(row_blob), + ]; + + // Use return_and_restore_if_io to handle I/O properly + return_and_restore_if_io!( + &mut self.commit_state, + state, + write_row.write_row(cursors, index_key, record_values, *weight) + ); + + self.commit_state = JoinCommitState::CommitLeftDelta { + deltas: deltas.clone(), + output: output.clone(), + current_idx: *current_idx + 1, + write_row: WriteRow::new(), + }; + } + JoinCommitState::CommitRightDelta { + deltas, + output, + current_idx, + ref mut write_row, + } => { + if *current_idx >= deltas.right.changes.len() { + // Reset to Idle state for next commit + self.commit_state = JoinCommitState::Idle; + return Ok(IOResult::Done(output.clone())); + } + + let 
(row, weight) = &deltas.right.changes[*current_idx]; + // Extract join key from the right row + let join_key = self.extract_join_key(&row.values, &self.right_key_indices); + + // The index key: (storage_id, zset_id, element_id) + let index_key = vec![ + Value::Integer(self.right_storage_id()), + Value::Integer(join_key.cached_hash() as i64), + Value::Integer(row.cached_hash() as i64), + ]; + + // The record values: we'll store the serialized row as a blob + let row_blob = serialize_hashable_row(row); + let record_values = vec![ + Value::Integer(self.right_storage_id()), + Value::Integer(join_key.cached_hash() as i64), + Value::Integer(row.cached_hash() as i64), + Value::Blob(row_blob), + ]; + + // Use return_and_restore_if_io to handle I/O properly + return_and_restore_if_io!( + &mut self.commit_state, + state, + write_row.write_row(cursors, index_key, record_values, *weight) + ); + + self.commit_state = JoinCommitState::CommitRightDelta { + deltas: std::mem::take(deltas), + output: std::mem::take(output), + current_idx: *current_idx + 1, + write_row: WriteRow::new(), + }; + } + JoinCommitState::Invalid => { + panic!("Invalid join commit state"); + } + } + } + } + + fn set_tracker(&mut self, tracker: Arc>) { + self.tracker = Some(tracker); + } +} diff --git a/core/incremental/mod.rs b/core/incremental/mod.rs index 2c69e050b..67eed60e2 100644 --- a/core/incremental/mod.rs +++ b/core/incremental/mod.rs @@ -1,7 +1,12 @@ +pub mod aggregate_operator; pub mod compiler; pub mod cursor; pub mod dbsp; pub mod expr_compiler; +pub mod filter_operator; +pub mod input_operator; +pub mod join_operator; pub mod operator; pub mod persistence; +pub mod project_operator; pub mod view; diff --git a/core/incremental/operator.rs b/core/incremental/operator.rs index 7c402db93..2af512504 100644 --- a/core/incremental/operator.rs +++ b/core/incremental/operator.rs @@ -2,21 +2,21 @@ // Operator DAG for DBSP-style incremental computation // Based on Feldera DBSP design but adapted for Turso's architecture -use crate::function::{AggFunc, Func}; -use crate::incremental::dbsp::{Delta, DeltaPair, HashableRow}; -use crate::incremental::expr_compiler::CompiledExpression; -use crate::incremental::persistence::{MinMaxPersistState, ReadRecord, RecomputeMinMax, WriteRow}; +pub use crate::incremental::aggregate_operator::{ + AggregateEvalState, AggregateFunction, AggregateState, +}; +pub use crate::incremental::filter_operator::{FilterOperator, FilterPredicate}; +pub use crate::incremental::input_operator::InputOperator; +pub use crate::incremental::join_operator::{JoinEvalState, JoinOperator, JoinType}; +pub use crate::incremental::project_operator::{ProjectColumn, ProjectOperator}; + +use crate::incremental::dbsp::{Delta, DeltaPair}; use crate::schema::{Index, IndexColumn}; use crate::storage::btree::BTreeCursor; -use crate::types::{IOResult, ImmutableRecord, SeekKey, SeekOp, SeekResult, Text}; -use crate::{ - return_and_restore_if_io, return_if_io, Connection, Database, Result, SymbolTable, Value, -}; -use std::collections::{BTreeMap, HashMap}; -use std::fmt::{self, Debug, Display}; +use crate::types::IOResult; +use crate::Result; +use std::fmt::Debug; use std::sync::{Arc, Mutex}; -use turso_macros::match_ignore_ascii_case; -use turso_parser::ast::{As, Expr, Literal, Name, OneSelect, Operator, ResultColumn}; /// Struct to hold both table and index cursors for DBSP state operations pub struct DbspStateCursors { @@ -72,12 +72,6 @@ pub fn create_dbsp_state_index(root_page: usize) -> Index { } } -/// Constants for aggregate 
type encoding in storage IDs (2 bits) -pub const AGG_TYPE_REGULAR: u8 = 0b00; // COUNT/SUM/AVG -pub const AGG_TYPE_MINMAX: u8 = 0b01; // MIN/MAX (BTree ordering gives both) -pub const AGG_TYPE_RESERVED1: u8 = 0b10; // Reserved for future use -pub const AGG_TYPE_RESERVED2: u8 = 0b11; // Reserved for future use - /// Generate a storage ID with column index and operation type encoding /// Storage ID = (operator_id << 16) | (column_index << 2) | operation_type /// Bit layout (64-bit integer): @@ -91,64 +85,13 @@ pub fn generate_storage_id(operator_id: usize, column_index: usize, op_type: u8) ((operator_id as i64) << 16) | ((column_index as i64) << 2) | (op_type as i64) } -// group_key_str -> (group_key, state) -type ComputedStates = HashMap, AggregateState)>; -// group_key_str -> (column_name, value_as_hashable_row) -> accumulated_weight -pub type MinMaxDeltas = HashMap>; - -#[derive(Debug)] -enum AggregateCommitState { - Idle, - Eval { - eval_state: EvalState, - }, - PersistDelta { - delta: Delta, - computed_states: ComputedStates, - current_idx: usize, - write_row: WriteRow, - min_max_deltas: MinMaxDeltas, - }, - PersistMinMax { - delta: Delta, - min_max_persist_state: MinMaxPersistState, - }, - Done { - delta: Delta, - }, - Invalid, -} - -// eval() has uncommitted data, so it can't be a member attribute of the Operator. -// The state has to be kept by the caller +// Generic eval state that delegates to operator-specific states #[derive(Debug)] pub enum EvalState { Uninitialized, - Init { - deltas: DeltaPair, - }, - FetchKey { - delta: Delta, // Keep original delta for merge operation - current_idx: usize, - groups_to_read: Vec<(String, Vec)>, // Changed to Vec for index-based access - existing_groups: HashMap, - old_values: HashMap>, - }, - FetchData { - delta: Delta, // Keep original delta for merge operation - current_idx: usize, - groups_to_read: Vec<(String, Vec)>, // Changed to Vec for index-based access - existing_groups: HashMap, - old_values: HashMap>, - rowid: Option, // Rowid found by FetchKey (None if not found) - read_record_state: Box, - }, - RecomputeMinMax { - delta: Delta, - existing_groups: HashMap, - old_values: HashMap>, - recompute_state: Box, - }, + Init { deltas: DeltaPair }, + Aggregate(Box), + Join(Box), Done, } @@ -189,182 +132,6 @@ impl EvalState { _ => panic!("extract_delta() can only be called when in Init state"), } } - - fn advance(&mut self, groups_to_read: BTreeMap>) { - let delta = match self { - EvalState::Init { deltas } => std::mem::take(&mut deltas.left), - _ => panic!("advance() can only be called when in Init state, current state: {self:?}"), - }; - - let _ = std::mem::replace( - self, - EvalState::FetchKey { - delta, - current_idx: 0, - groups_to_read: groups_to_read.into_iter().collect(), // Convert BTreeMap to Vec - existing_groups: HashMap::new(), - old_values: HashMap::new(), - }, - ); - } - fn process_delta( - &mut self, - operator: &mut AggregateOperator, - cursors: &mut DbspStateCursors, - ) -> Result> { - loop { - match self { - EvalState::Uninitialized => { - panic!("Cannot process_delta with Uninitialized state"); - } - EvalState::Init { .. } => { - panic!("State machine not supposed to reach the init state! 
advance() should have been called"); - } - EvalState::FetchKey { - delta, - current_idx, - groups_to_read, - existing_groups, - old_values, - } => { - if *current_idx >= groups_to_read.len() { - // All groups have been fetched, move to RecomputeMinMax - // Extract MIN/MAX deltas from the input delta - let min_max_deltas = operator.extract_min_max_deltas(delta); - - let recompute_state = Box::new(RecomputeMinMax::new( - min_max_deltas, - existing_groups, - operator, - )); - - *self = EvalState::RecomputeMinMax { - delta: std::mem::take(delta), - existing_groups: std::mem::take(existing_groups), - old_values: std::mem::take(old_values), - recompute_state, - }; - } else { - // Get the current group to read - let (group_key_str, _group_key) = &groups_to_read[*current_idx]; - - // Build the key for the index: (operator_id, zset_id, element_id) - // For regular aggregates, use column_index=0 and AGG_TYPE_REGULAR - let operator_storage_id = - generate_storage_id(operator.operator_id, 0, AGG_TYPE_REGULAR); - let zset_id = operator.generate_group_rowid(group_key_str); - let element_id = 0i64; // Always 0 for aggregators - - // Create index key values - let index_key_values = vec![ - Value::Integer(operator_storage_id), - Value::Integer(zset_id), - Value::Integer(element_id), - ]; - - // Create an immutable record for the index key - let index_record = - ImmutableRecord::from_values(&index_key_values, index_key_values.len()); - - // Seek in the index to find if this row exists - let seek_result = return_if_io!(cursors.index_cursor.seek( - SeekKey::IndexKey(&index_record), - SeekOp::GE { eq_only: true } - )); - - let rowid = if matches!(seek_result, SeekResult::Found) { - // Found in index, get the table rowid - // The btree code handles extracting the rowid from the index record for has_rowid indexes - return_if_io!(cursors.index_cursor.rowid()) - } else { - // Not found in index, no existing state - None - }; - - // Always transition to FetchData - let taken_existing = std::mem::take(existing_groups); - let taken_old_values = std::mem::take(old_values); - let next_state = EvalState::FetchData { - delta: std::mem::take(delta), - current_idx: *current_idx, - groups_to_read: std::mem::take(groups_to_read), - existing_groups: taken_existing, - old_values: taken_old_values, - rowid, - read_record_state: Box::new(ReadRecord::new()), - }; - *self = next_state; - } - } - EvalState::FetchData { - delta, - current_idx, - groups_to_read, - existing_groups, - old_values, - rowid, - read_record_state, - } => { - // Get the current group to read - let (group_key_str, group_key) = &groups_to_read[*current_idx]; - - // Only try to read if we have a rowid - if let Some(rowid) = rowid { - let key = SeekKey::TableRowId(*rowid); - let state = return_if_io!(read_record_state.read_record( - key, - &operator.aggregates, - &mut cursors.table_cursor - )); - // Process the fetched state - if let Some(state) = state { - let mut old_row = group_key.clone(); - old_row.extend(state.to_values(&operator.aggregates)); - old_values.insert(group_key_str.clone(), old_row); - existing_groups.insert(group_key_str.clone(), state.clone()); - } - } else { - // No rowid for this group, skipping read - } - // If no rowid, there's no existing state for this group - - // Move to next group - let next_idx = *current_idx + 1; - let taken_existing = std::mem::take(existing_groups); - let taken_old_values = std::mem::take(old_values); - let next_state = EvalState::FetchKey { - delta: std::mem::take(delta), - current_idx: next_idx, - 
groups_to_read: std::mem::take(groups_to_read), - existing_groups: taken_existing, - old_values: taken_old_values, - }; - *self = next_state; - } - EvalState::RecomputeMinMax { - delta, - existing_groups, - old_values, - recompute_state, - } => { - if operator.has_min_max() { - // Process MIN/MAX recomputation - this will update existing_groups with correct MIN/MAX - return_if_io!(recompute_state.process(existing_groups, operator, cursors)); - } - - // Now compute final output with updated MIN/MAX values - let (output_delta, computed_states) = - operator.merge_delta_with_existing(delta, existing_groups, old_values); - - *self = EvalState::Done; - return Ok(IOResult::Done((output_delta, computed_states))); - } - EvalState::Done => { - return Ok(IOResult::Done((Delta::new(), HashMap::new()))); - } - } - } - } } /// Tracks computation counts to verify incremental behavior (for tests now), and in the future @@ -411,64 +178,6 @@ impl ComputationTracker { } } -#[cfg(test)] -mod dbsp_types_tests { - use super::*; - - #[test] - fn test_hashable_row_delta_operations() { - let mut delta = Delta::new(); - - // Test INSERT - delta.insert(1, vec![Value::Integer(1), Value::Integer(100)]); - assert_eq!(delta.len(), 1); - - // Test UPDATE (DELETE + INSERT) - order matters! - delta.delete(1, vec![Value::Integer(1), Value::Integer(100)]); - delta.insert(1, vec![Value::Integer(1), Value::Integer(200)]); - assert_eq!(delta.len(), 3); // Should have 3 operations before consolidation - - // Verify order is preserved - let ops: Vec<_> = delta.changes.iter().collect(); - assert_eq!(ops[0].1, 1); // First insert - assert_eq!(ops[1].1, -1); // Delete - assert_eq!(ops[2].1, 1); // Second insert - - // Test consolidation - delta.consolidate(); - // After consolidation, the first insert and delete should cancel out - // leaving only the second insert - assert_eq!(delta.len(), 1); - - let final_row = &delta.changes[0]; - assert_eq!(final_row.0.rowid, 1); - assert_eq!( - final_row.0.values, - vec![Value::Integer(1), Value::Integer(200)] - ); - assert_eq!(final_row.1, 1); - } - - #[test] - fn test_duplicate_row_consolidation() { - let mut delta = Delta::new(); - - // Insert same row twice - delta.insert(2, vec![Value::Integer(2), Value::Integer(300)]); - delta.insert(2, vec![Value::Integer(2), Value::Integer(300)]); - - assert_eq!(delta.len(), 2); - - delta.consolidate(); - assert_eq!(delta.len(), 1); - - // Weight should be 2 (sum of both inserts) - let final_row = &delta.changes[0]; - assert_eq!(final_row.0.rowid, 2); - assert_eq!(final_row.1, 2); - } -} - /// Represents an operator in the dataflow graph #[derive(Debug, Clone)] pub enum QueryOperator { @@ -506,198 +215,6 @@ pub enum QueryOperator { }, } -#[derive(Debug, Clone)] -pub enum FilterPredicate { - /// Column = value - Equals { column: String, value: Value }, - /// Column != value - NotEquals { column: String, value: Value }, - /// Column > value - GreaterThan { column: String, value: Value }, - /// Column >= value - GreaterThanOrEqual { column: String, value: Value }, - /// Column < value - LessThan { column: String, value: Value }, - /// Column <= value - LessThanOrEqual { column: String, value: Value }, - /// Logical AND of two predicates - And(Box, Box), - /// Logical OR of two predicates - Or(Box, Box), - /// No predicate (accept all rows) - None, -} - -impl FilterPredicate { - /// Parse a SQL AST expression into a FilterPredicate - /// This centralizes all SQL-to-predicate parsing logic - pub fn from_sql_expr(expr: &turso_parser::ast::Expr) -> 
crate::Result { - let Expr::Binary(lhs, op, rhs) = expr else { - return Err(crate::LimboError::ParseError( - "Unsupported WHERE clause for incremental views: not a binary expression" - .to_string(), - )); - }; - - // Handle AND/OR logical operators - match op { - Operator::And => { - let left = Self::from_sql_expr(lhs)?; - let right = Self::from_sql_expr(rhs)?; - return Ok(FilterPredicate::And(Box::new(left), Box::new(right))); - } - Operator::Or => { - let left = Self::from_sql_expr(lhs)?; - let right = Self::from_sql_expr(rhs)?; - return Ok(FilterPredicate::Or(Box::new(left), Box::new(right))); - } - _ => {} - } - - // Handle comparison operators - let Expr::Id(column_name) = &**lhs else { - return Err(crate::LimboError::ParseError( - "Unsupported WHERE clause for incremental views: left-hand-side is not a column reference".to_string(), - )); - }; - - let column = column_name.as_str().to_string(); - - // Parse the right-hand side value - let value = match &**rhs { - Expr::Literal(Literal::String(s)) => { - // Strip quotes from string literals - let cleaned = s.trim_matches('\'').trim_matches('"'); - Value::Text(Text::new(cleaned)) - } - Expr::Literal(Literal::Numeric(n)) => { - // Try to parse as integer first, then float - if let Ok(i) = n.parse::() { - Value::Integer(i) - } else if let Ok(f) = n.parse::() { - Value::Float(f) - } else { - return Err(crate::LimboError::ParseError( - "Unsupported WHERE clause for incremental views: right-hand-side is not a numeric literal".to_string(), - )); - } - } - Expr::Literal(Literal::Null) => Value::Null, - Expr::Literal(Literal::Blob(_)) => { - // Blob comparison not yet supported - return Err(crate::LimboError::ParseError( - "Unsupported WHERE clause for incremental views: comparison with blob literals is not supported".to_string(), - )); - } - other => { - // Complex expressions not yet supported - return Err(crate::LimboError::ParseError( - format!("Unsupported WHERE clause for incremental views: comparison with {other:?} is not supported"), - )); - } - }; - - // Create the appropriate predicate based on operator - match op { - Operator::Equals => Ok(FilterPredicate::Equals { column, value }), - Operator::NotEquals => Ok(FilterPredicate::NotEquals { column, value }), - Operator::Greater => Ok(FilterPredicate::GreaterThan { column, value }), - Operator::GreaterEquals => Ok(FilterPredicate::GreaterThanOrEqual { column, value }), - Operator::Less => Ok(FilterPredicate::LessThan { column, value }), - Operator::LessEquals => Ok(FilterPredicate::LessThanOrEqual { column, value }), - other => Err(crate::LimboError::ParseError( - format!("Unsupported WHERE clause for incremental views: comparison operator {other:?} is not supported"), - )), - } - } - - /// Parse a WHERE clause from a SELECT statement - pub fn from_select(select: &turso_parser::ast::Select) -> crate::Result { - if let OneSelect::Select { - ref where_clause, .. 
- } = select.body.select - { - if let Some(where_clause) = where_clause { - Self::from_sql_expr(where_clause) - } else { - Ok(FilterPredicate::None) - } - } else { - Err(crate::LimboError::ParseError( - "Unsupported WHERE clause for incremental views: not a single SELECT statement" - .to_string(), - )) - } - } -} - -#[derive(Debug, Clone)] -pub struct ProjectColumn { - /// The original SQL expression (for debugging/fallback) - pub expr: turso_parser::ast::Expr, - /// Optional alias for the column - pub alias: Option, - /// Compiled expression (handles both trivial columns and complex expressions) - pub compiled: CompiledExpression, -} - -#[derive(Debug, Clone)] -pub enum JoinType { - Inner, - Left, - Right, -} - -#[derive(Debug, Clone, PartialEq)] -pub enum AggregateFunction { - Count, - Sum(String), - Avg(String), - Min(String), - Max(String), -} - -impl Display for AggregateFunction { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - AggregateFunction::Count => write!(f, "COUNT(*)"), - AggregateFunction::Sum(col) => write!(f, "SUM({col})"), - AggregateFunction::Avg(col) => write!(f, "AVG({col})"), - AggregateFunction::Min(col) => write!(f, "MIN({col})"), - AggregateFunction::Max(col) => write!(f, "MAX({col})"), - } - } -} - -impl AggregateFunction { - /// Get the default output column name for this aggregate function - #[inline] - pub fn default_output_name(&self) -> String { - self.to_string() - } - - /// Create an AggregateFunction from a SQL function and its arguments - /// Returns None if the function is not a supported aggregate - pub fn from_sql_function( - func: &crate::function::Func, - input_column: Option, - ) -> Option { - match func { - Func::Agg(agg_func) => { - match agg_func { - AggFunc::Count | AggFunc::Count0 => Some(AggregateFunction::Count), - AggFunc::Sum => input_column.map(AggregateFunction::Sum), - AggFunc::Avg => input_column.map(AggregateFunction::Avg), - AggFunc::Min => input_column.map(AggregateFunction::Min), - AggFunc::Max => input_column.map(AggregateFunction::Max), - _ => None, // Other aggregate functions not yet supported in DBSP - } - } - _ => None, // Not an aggregate function - } - } -} - /// Operator DAG (Directed Acyclic Graph) /// Base trait for incremental operators pub trait IncrementalOperator: Debug { @@ -731,1506 +248,11 @@ pub trait IncrementalOperator: Debug { fn set_tracker(&mut self, tracker: Arc>); } -/// Input operator - passes through input data unchanged -/// This operator is used for input nodes in the circuit to provide a uniform interface -#[derive(Debug)] -pub struct InputOperator { - name: String, -} - -impl InputOperator { - pub fn new(name: String) -> Self { - Self { name } - } - - pub fn name(&self) -> &str { - &self.name - } -} - -impl IncrementalOperator for InputOperator { - fn eval( - &mut self, - state: &mut EvalState, - _cursors: &mut DbspStateCursors, - ) -> Result> { - match state { - EvalState::Init { deltas } => { - // Input operators only use left_delta, right_delta must be empty - assert!( - deltas.right.is_empty(), - "InputOperator expects right_delta to be empty" - ); - let output = std::mem::take(&mut deltas.left); - *state = EvalState::Done; - Ok(IOResult::Done(output)) - } - _ => unreachable!( - "InputOperator doesn't execute the state machine. 
Should be in Init state" - ), - } - } - - fn commit( - &mut self, - deltas: DeltaPair, - _cursors: &mut DbspStateCursors, - ) -> Result> { - // Input operator only uses left delta, right must be empty - assert!( - deltas.right.is_empty(), - "InputOperator expects right delta to be empty in commit" - ); - // Input operator passes through the delta unchanged during commit - Ok(IOResult::Done(deltas.left)) - } - - fn set_tracker(&mut self, _tracker: Arc>) { - // Input operator doesn't need tracking - } -} - -/// Filter operator - filters rows based on predicate -#[derive(Debug)] -pub struct FilterOperator { - predicate: FilterPredicate, - column_names: Vec, - tracker: Option>>, -} - -impl FilterOperator { - pub fn new(predicate: FilterPredicate, column_names: Vec) -> Self { - Self { - predicate, - column_names, - tracker: None, - } - } - - /// Get the predicate for this filter - pub fn predicate(&self) -> &FilterPredicate { - &self.predicate - } - - pub fn evaluate_predicate(&self, values: &[Value]) -> bool { - match &self.predicate { - FilterPredicate::None => true, - FilterPredicate::Equals { column, value } => { - if let Some(idx) = self.column_names.iter().position(|c| c == column) { - if let Some(v) = values.get(idx) { - return v == value; - } - } - false - } - FilterPredicate::NotEquals { column, value } => { - if let Some(idx) = self.column_names.iter().position(|c| c == column) { - if let Some(v) = values.get(idx) { - return v != value; - } - } - false - } - FilterPredicate::GreaterThan { column, value } => { - if let Some(idx) = self.column_names.iter().position(|c| c == column) { - if let Some(v) = values.get(idx) { - // Compare based on value types - match (v, value) { - (Value::Integer(a), Value::Integer(b)) => return a > b, - (Value::Float(a), Value::Float(b)) => return a > b, - (Value::Text(a), Value::Text(b)) => return a.as_str() > b.as_str(), - _ => {} - } - } - } - false - } - FilterPredicate::GreaterThanOrEqual { column, value } => { - if let Some(idx) = self.column_names.iter().position(|c| c == column) { - if let Some(v) = values.get(idx) { - match (v, value) { - (Value::Integer(a), Value::Integer(b)) => return a >= b, - (Value::Float(a), Value::Float(b)) => return a >= b, - (Value::Text(a), Value::Text(b)) => return a.as_str() >= b.as_str(), - _ => {} - } - } - } - false - } - FilterPredicate::LessThan { column, value } => { - if let Some(idx) = self.column_names.iter().position(|c| c == column) { - if let Some(v) = values.get(idx) { - match (v, value) { - (Value::Integer(a), Value::Integer(b)) => return a < b, - (Value::Float(a), Value::Float(b)) => return a < b, - (Value::Text(a), Value::Text(b)) => return a.as_str() < b.as_str(), - _ => {} - } - } - } - false - } - FilterPredicate::LessThanOrEqual { column, value } => { - if let Some(idx) = self.column_names.iter().position(|c| c == column) { - if let Some(v) = values.get(idx) { - match (v, value) { - (Value::Integer(a), Value::Integer(b)) => return a <= b, - (Value::Float(a), Value::Float(b)) => return a <= b, - (Value::Text(a), Value::Text(b)) => return a.as_str() <= b.as_str(), - _ => {} - } - } - } - false - } - FilterPredicate::And(left, right) => { - // Temporarily create sub-filters to evaluate - let left_filter = FilterOperator::new((**left).clone(), self.column_names.clone()); - let right_filter = - FilterOperator::new((**right).clone(), self.column_names.clone()); - left_filter.evaluate_predicate(values) && right_filter.evaluate_predicate(values) - } - FilterPredicate::Or(left, right) => { - let left_filter = 
FilterOperator::new((**left).clone(), self.column_names.clone()); - let right_filter = - FilterOperator::new((**right).clone(), self.column_names.clone()); - left_filter.evaluate_predicate(values) || right_filter.evaluate_predicate(values) - } - } - } -} - -impl IncrementalOperator for FilterOperator { - fn eval( - &mut self, - state: &mut EvalState, - _cursors: &mut DbspStateCursors, - ) -> Result> { - let delta = match state { - EvalState::Init { deltas } => { - // Filter operators only use left_delta, right_delta must be empty - assert!( - deltas.right.is_empty(), - "FilterOperator expects right_delta to be empty" - ); - std::mem::take(&mut deltas.left) - } - _ => unreachable!( - "FilterOperator doesn't execute the state machine. Should be in Init state" - ), - }; - - let mut output_delta = Delta::new(); - - // Process the delta through the filter - for (row, weight) in delta.changes { - if let Some(tracker) = &self.tracker { - tracker.lock().unwrap().record_filter(); - } - - // Only pass through rows that satisfy the filter predicate - // For deletes (weight < 0), we only pass them if the row values - // would have passed the filter (meaning it was in the view) - if self.evaluate_predicate(&row.values) { - output_delta.changes.push((row, weight)); - } - } - - *state = EvalState::Done; - Ok(IOResult::Done(output_delta)) - } - - fn commit( - &mut self, - deltas: DeltaPair, - _cursors: &mut DbspStateCursors, - ) -> Result> { - // Filter operator only uses left delta, right must be empty - assert!( - deltas.right.is_empty(), - "FilterOperator expects right delta to be empty in commit" - ); - - let mut output_delta = Delta::new(); - - // Commit the delta to our internal state - // Only pass through and track rows that satisfy the filter predicate - for (row, weight) in deltas.left.changes { - if let Some(tracker) = &self.tracker { - tracker.lock().unwrap().record_filter(); - } - - // Only track and output rows that pass the filter - // For deletes, this means the row was in the view (its values pass the filter) - // For inserts, this means the row should be in the view - if self.evaluate_predicate(&row.values) { - output_delta.changes.push((row, weight)); - } - } - - Ok(IOResult::Done(output_delta)) - } - - fn set_tracker(&mut self, tracker: Arc>) { - self.tracker = Some(tracker); - } -} - -/// Project operator - selects/transforms columns -#[derive(Clone)] -pub struct ProjectOperator { - columns: Vec, - input_column_names: Vec, - output_column_names: Vec, - tracker: Option>>, - // Internal in-memory connection for expression evaluation - // Programs are very dependent on having a connection, so give it one. - // - // We could in theory pass the current connection, but there are a host of problems with that. - // For example: during a write transaction, where views are usually updated, we have autocommit - // on. When the program we are executing calls Halt, it will try to commit the current - // transaction, which is absolutely incorrect. - // - // There are other ways to solve this, but a read-only connection to an empty in-memory - // database gives us the closest environment we need to execute expressions. 
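The comment above is the heart of this design: expression programs are evaluated against a throwaway connection so that a Halt inside the program can never touch the caller's transaction. A minimal sketch of that setup, mirroring the exact calls used in from_select()/from_compiled() below and assuming the surrounding module's imports (MemoryIO, Database); it is an illustration of the pattern, not a public API:

    // Sketch only: a detached, read-only, in-memory connection used purely for
    // expression evaluation. Mirrors the setup in from_select() below.
    let io = Arc::new(crate::MemoryIO::new());
    let db = Database::open_file(io, ":memory:", false /* no MVCC */, false /* no indexes */)?;
    let conn = db.connect()?;
    conn.query_only.set(true);   // expressions must never write
    conn.auto_commit.set(false); // Halt in the program must not commit the caller's work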
- internal_conn: Arc, -} - -impl std::fmt::Debug for ProjectOperator { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("ProjectOperator") - .field("columns", &self.columns) - .field("input_column_names", &self.input_column_names) - .field("output_column_names", &self.output_column_names) - .field("tracker", &self.tracker) - .finish_non_exhaustive() - } -} - -impl ProjectOperator { - /// Create a new ProjectOperator from a SELECT statement, extracting projection columns - pub fn from_select( - select: &turso_parser::ast::Select, - input_column_names: Vec, - schema: &crate::schema::Schema, - ) -> crate::Result { - // Set up internal connection for expression evaluation - let io = Arc::new(crate::MemoryIO::new()); - let db = Database::open_file( - io, ":memory:", false, // no MVCC needed for expression evaluation - false, // no indexes needed - )?; - let internal_conn = db.connect()?; - // Set to read-only mode and disable auto-commit since we're only evaluating expressions - internal_conn.query_only.set(true); - internal_conn.auto_commit.set(false); - - let temp_syms = SymbolTable::new(); - - // Extract columns from SELECT statement - let columns = if let OneSelect::Select { - columns: ref select_columns, - .. - } = &select.body.select - { - let mut columns = Vec::new(); - for result_col in select_columns { - match result_col { - ResultColumn::Expr(expr, alias) => { - let alias_str = if let Some(As::As(alias_name)) = alias { - Some(alias_name.as_str().to_string()) - } else { - None - }; - // Try to compile the expression (handles both columns and complex expressions) - let compiled = CompiledExpression::compile( - expr, - &input_column_names, - schema, - &temp_syms, - internal_conn.clone(), - )?; - columns.push(ProjectColumn { - expr: (**expr).clone(), - alias: alias_str, - compiled, - }); - } - ResultColumn::Star => { - // Select all columns - create trivial column references - for name in &input_column_names { - // Create an Id expression for the column - let expr = Expr::Id(Name::Ident(name.clone())); - let compiled = CompiledExpression::compile( - &expr, - &input_column_names, - schema, - &temp_syms, - internal_conn.clone(), - )?; - columns.push(ProjectColumn { - expr, - alias: None, - compiled, - }); - } - } - x => { - return Err(crate::LimboError::ParseError(format!( - "Unsupported {x:?} clause when compiling project operator", - ))); - } - } - } - - if columns.is_empty() { - return Err(crate::LimboError::ParseError( - "No columns found when compiling project operator".to_string(), - )); - } - columns - } else { - return Err(crate::LimboError::ParseError( - "Expression is not a valid SELECT expression".to_string(), - )); - }; - - // Generate output column names based on aliases or expressions - let output_column_names = columns - .iter() - .map(|c| { - c.alias.clone().unwrap_or_else(|| match &c.expr { - Expr::Id(name) => name.as_str().to_string(), - Expr::Qualified(table, column) => { - format!("{}.{}", table.as_str(), column.as_str()) - } - Expr::DoublyQualified(db, table, column) => { - format!("{}.{}.{}", db.as_str(), table.as_str(), column.as_str()) - } - _ => c.expr.to_string(), - }) - }) - .collect(); - - Ok(Self { - columns, - input_column_names, - output_column_names, - tracker: None, - internal_conn, - }) - } - - /// Create a ProjectOperator from pre-compiled expressions - pub fn from_compiled( - compiled_exprs: Vec, - aliases: Vec>, - input_column_names: Vec, - output_column_names: Vec, - ) -> crate::Result { - // Set up internal 
connection for expression evaluation - let io = Arc::new(crate::MemoryIO::new()); - let db = Database::open_file( - io, ":memory:", false, // no MVCC needed for expression evaluation - false, // no indexes needed - )?; - let internal_conn = db.connect()?; - // Set to read-only mode and disable auto-commit since we're only evaluating expressions - internal_conn.query_only.set(true); - internal_conn.auto_commit.set(false); - - // Create ProjectColumn structs from compiled expressions - let columns: Vec = compiled_exprs - .into_iter() - .zip(aliases) - .map(|(compiled, alias)| ProjectColumn { - // Create a placeholder AST expression since we already have the compiled version - expr: turso_parser::ast::Expr::Literal(turso_parser::ast::Literal::Null), - alias, - compiled, - }) - .collect(); - - Ok(Self { - columns, - input_column_names, - output_column_names, - tracker: None, - internal_conn, - }) - } - - /// Get the columns for this projection - pub fn columns(&self) -> &[ProjectColumn] { - &self.columns - } - - fn project_values(&self, values: &[Value]) -> Vec { - let mut output = Vec::new(); - - for col in &self.columns { - // Use the internal connection's pager for expression evaluation - let internal_pager = self.internal_conn.pager.borrow().clone(); - - // Execute the compiled expression (handles both columns and complex expressions) - let result = col - .compiled - .execute(values, internal_pager) - .expect("Failed to execute compiled expression for the Project operator"); - output.push(result); - } - - output - } - - fn evaluate_expression(&self, expr: &turso_parser::ast::Expr, values: &[Value]) -> Value { - match expr { - Expr::Id(name) => { - if let Some(idx) = self - .input_column_names - .iter() - .position(|c| c == name.as_str()) - { - if let Some(v) = values.get(idx) { - return v.clone(); - } - } - Value::Null - } - Expr::Literal(lit) => { - match lit { - Literal::Numeric(n) => { - if let Ok(i) = n.parse::() { - Value::Integer(i) - } else if let Ok(f) = n.parse::() { - Value::Float(f) - } else { - Value::Null - } - } - Literal::String(s) => { - let cleaned = s.trim_matches('\'').trim_matches('"'); - Value::Text(Text::new(cleaned)) - } - Literal::Null => Value::Null, - Literal::Blob(_) - | Literal::Keyword(_) - | Literal::CurrentDate - | Literal::CurrentTime - | Literal::CurrentTimestamp => Value::Null, // Not supported yet - } - } - Expr::Binary(left, op, right) => { - let left_val = self.evaluate_expression(left, values); - let right_val = self.evaluate_expression(right, values); - - match op { - Operator::Add => match (&left_val, &right_val) { - (Value::Integer(a), Value::Integer(b)) => Value::Integer(a + b), - (Value::Float(a), Value::Float(b)) => Value::Float(a + b), - (Value::Integer(a), Value::Float(b)) => Value::Float(*a as f64 + b), - (Value::Float(a), Value::Integer(b)) => Value::Float(a + *b as f64), - _ => Value::Null, - }, - Operator::Subtract => match (&left_val, &right_val) { - (Value::Integer(a), Value::Integer(b)) => Value::Integer(a - b), - (Value::Float(a), Value::Float(b)) => Value::Float(a - b), - (Value::Integer(a), Value::Float(b)) => Value::Float(*a as f64 - b), - (Value::Float(a), Value::Integer(b)) => Value::Float(a - *b as f64), - _ => Value::Null, - }, - Operator::Multiply => match (&left_val, &right_val) { - (Value::Integer(a), Value::Integer(b)) => Value::Integer(a * b), - (Value::Float(a), Value::Float(b)) => Value::Float(a * b), - (Value::Integer(a), Value::Float(b)) => Value::Float(*a as f64 * b), - (Value::Float(a), Value::Integer(b)) => 
Value::Float(a * *b as f64), - _ => Value::Null, - }, - Operator::Divide => match (&left_val, &right_val) { - (Value::Integer(a), Value::Integer(b)) => { - if *b != 0 { - Value::Integer(a / b) - } else { - Value::Null - } - } - (Value::Float(a), Value::Float(b)) => { - if *b != 0.0 { - Value::Float(a / b) - } else { - Value::Null - } - } - (Value::Integer(a), Value::Float(b)) => { - if *b != 0.0 { - Value::Float(*a as f64 / b) - } else { - Value::Null - } - } - (Value::Float(a), Value::Integer(b)) => { - if *b != 0 { - Value::Float(a / *b as f64) - } else { - Value::Null - } - } - _ => Value::Null, - }, - _ => Value::Null, // Other operators not supported yet - } - } - Expr::FunctionCall { name, args, .. } => { - let name_bytes = name.as_str().as_bytes(); - match_ignore_ascii_case!(match name_bytes { - b"hex" => { - if args.len() == 1 { - let arg_val = self.evaluate_expression(&args[0], values); - match arg_val { - Value::Integer(i) => Value::Text(Text::new(&format!("{i:X}"))), - _ => Value::Null, - } - } else { - Value::Null - } - } - _ => Value::Null, // Other functions not supported yet - }) - } - Expr::Parenthesized(inner) => { - assert!( - inner.len() <= 1, - "Parenthesized expressions with multiple elements are not supported" - ); - if !inner.is_empty() { - self.evaluate_expression(&inner[0], values) - } else { - Value::Null - } - } - _ => Value::Null, // Other expression types not supported yet - } - } -} - -impl IncrementalOperator for ProjectOperator { - fn eval( - &mut self, - state: &mut EvalState, - _cursors: &mut DbspStateCursors, - ) -> Result> { - let delta = match state { - EvalState::Init { deltas } => { - // Project operators only use left_delta, right_delta must be empty - assert!( - deltas.right.is_empty(), - "ProjectOperator expects right_delta to be empty" - ); - std::mem::take(&mut deltas.left) - } - _ => unreachable!( - "ProjectOperator doesn't execute the state machine. 
Should be in Init state" - ), - }; - - let mut output_delta = Delta::new(); - - for (row, weight) in delta.changes { - if let Some(tracker) = &self.tracker { - tracker.lock().unwrap().record_project(); - } - - let projected = self.project_values(&row.values); - let projected_row = HashableRow::new(row.rowid, projected); - output_delta.changes.push((projected_row, weight)); - } - - *state = EvalState::Done; - Ok(IOResult::Done(output_delta)) - } - - fn commit( - &mut self, - deltas: DeltaPair, - _cursors: &mut DbspStateCursors, - ) -> Result> { - // Project operator only uses left delta, right must be empty - assert!( - deltas.right.is_empty(), - "ProjectOperator expects right delta to be empty in commit" - ); - - let mut output_delta = Delta::new(); - - // Commit the delta to our internal state and build output - for (row, weight) in &deltas.left.changes { - if let Some(tracker) = &self.tracker { - tracker.lock().unwrap().record_project(); - } - let projected = self.project_values(&row.values); - let projected_row = HashableRow::new(row.rowid, projected); - output_delta.changes.push((projected_row, *weight)); - } - - Ok(crate::types::IOResult::Done(output_delta)) - } - - fn set_tracker(&mut self, tracker: Arc>) { - self.tracker = Some(tracker); - } -} - -/// Aggregate operator - performs incremental aggregation with GROUP BY -/// Maintains running totals/counts that are updated incrementally -/// -/// Information about a column that has MIN/MAX aggregations -#[derive(Debug, Clone)] -pub struct AggColumnInfo { - /// Index used for storage key generation - pub index: usize, - /// Whether this column has a MIN aggregate - pub has_min: bool, - /// Whether this column has a MAX aggregate - pub has_max: bool, -} - -/// Note that the AggregateOperator essentially implements a ZSet, even -/// though the ZSet structure is never used explicitly. The on-disk btree -/// plays the role of the set! 
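For readers unfamiliar with the DBSP terminology: a Z-set is a multiset with signed multiplicities, where inserts carry weight +1, deletions weight -1, and a row whose accumulated weight reaches zero is simply gone. A minimal, std-only sketch of the bookkeeping the on-disk btree performs for this operator (names here are illustrative, not the crate's API):

    use std::collections::{btree_map::Entry, BTreeMap};

    // `zset` maps a row key to its accumulated weight.
    fn apply_weight<K: Ord>(zset: &mut BTreeMap<K, i64>, key: K, weight: i64) {
        match zset.entry(key) {
            Entry::Vacant(slot) => {
                slot.insert(weight);
            }
            Entry::Occupied(mut slot) => {
                *slot.get_mut() += weight;
                if *slot.get() == 0 {
                    slot.remove(); // fully retracted rows disappear from the set
                }
            }
        }
    }

    // apply_weight(&mut zs, "group A", 1)  records an insert;
    // apply_weight(&mut zs, "group A", -1) retracts it and the entry vanishes.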
-#[derive(Debug)] -pub struct AggregateOperator { - // Unique operator ID for indexing in persistent storage - pub operator_id: usize, - // GROUP BY columns - group_by: Vec, - // Aggregate functions to compute (including MIN/MAX) - pub aggregates: Vec, - // Column names from input - pub input_column_names: Vec, - // Map from column name to aggregate info for quick lookup - pub column_min_max: HashMap, - tracker: Option>>, - - // State machine for commit operation - commit_state: AggregateCommitState, -} - -/// State for a single group's aggregates -#[derive(Debug, Clone, Default)] -pub struct AggregateState { - // For COUNT: just the count - count: i64, - // For SUM: column_name -> sum value - sums: HashMap, - // For AVG: column_name -> (sum, count) for computing average - avgs: HashMap, - // For MIN: column_name -> minimum value - pub mins: HashMap, - // For MAX: column_name -> maximum value - pub maxs: HashMap, -} - -/// Serialize a Value using SQLite's serial type format -/// This is used for MIN/MAX values that need to be stored in a compact, sortable format -pub fn serialize_value(value: &Value, blob: &mut Vec) { - let serial_type = crate::types::SerialType::from(value); - let serial_type_u64: u64 = serial_type.into(); - crate::storage::sqlite3_ondisk::write_varint_to_vec(serial_type_u64, blob); - value.serialize_serial(blob); -} - -/// Deserialize a Value using SQLite's serial type format -/// Returns the deserialized value and the number of bytes consumed -pub fn deserialize_value(blob: &[u8]) -> Option<(Value, usize)> { - let mut cursor = 0; - - // Read the serial type - let (serial_type, varint_size) = crate::storage::sqlite3_ondisk::read_varint(blob).ok()?; - cursor += varint_size; - - let serial_type_obj = crate::types::SerialType::try_from(serial_type).ok()?; - let expected_size = serial_type_obj.size(); - - // Read the value - let (value, actual_size) = - crate::storage::sqlite3_ondisk::read_value(&blob[cursor..], serial_type_obj).ok()?; - - // Verify that the actual size matches what we expected from the serial type - if actual_size != expected_size { - return None; // Data corruption - size mismatch - } - - cursor += actual_size; - - // Convert RefValue to Value - Some((value.to_owned(), cursor)) -} - -impl AggregateState { - pub fn new() -> Self { - Self::default() - } - - // Serialize the aggregate state to a binary blob including group key values - // The reason we serialize it like this, instead of just writing the actual values, is that - // The same table may have different aggregators in the circuit. They will all have different - // columns. 
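The serialize_value/deserialize_value helpers above lean on SQLite's record encoding: a varint serial type followed by a big-endian payload. A simplified, self-contained illustration of that shape (assumption: it always uses the 8-byte integer and float serial types 6 and 7, whereas the real code picks the smallest width that fits; serial types below 128 encode as a one-byte varint):

    enum Val { Null, Int(i64), Real(f64) }

    fn encode(v: &Val, out: &mut Vec<u8>) {
        match v {
            Val::Null => out.push(0), // serial type 0: NULL, no payload
            Val::Int(i) => {
                out.push(6); // serial type 6: 8-byte signed int, big-endian
                out.extend_from_slice(&i.to_be_bytes());
            }
            Val::Real(f) => {
                out.push(7); // serial type 7: IEEE-754 double, big-endian
                out.extend_from_slice(&f.to_be_bytes());
            }
        }
    }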
- fn to_blob(&self, aggregates: &[AggregateFunction], group_key: &[Value]) -> Vec { - let mut blob = Vec::new(); - - // Write version byte for future compatibility - blob.push(1u8); - - // Write number of group key values - blob.extend_from_slice(&(group_key.len() as u32).to_le_bytes()); - - // Write each group key value - for value in group_key { - // Write value type tag - match value { - Value::Null => blob.push(0u8), - Value::Integer(i) => { - blob.push(1u8); - blob.extend_from_slice(&i.to_le_bytes()); - } - Value::Float(f) => { - blob.push(2u8); - blob.extend_from_slice(&f.to_le_bytes()); - } - Value::Text(s) => { - blob.push(3u8); - let text_str = s.as_str(); - let bytes = text_str.as_bytes(); - blob.extend_from_slice(&(bytes.len() as u32).to_le_bytes()); - blob.extend_from_slice(bytes); - } - Value::Blob(b) => { - blob.push(4u8); - blob.extend_from_slice(&(b.len() as u32).to_le_bytes()); - blob.extend_from_slice(b); - } - } - } - - // Write count as 8 bytes (little-endian) - blob.extend_from_slice(&self.count.to_le_bytes()); - - // Write each aggregate's state - for agg in aggregates { - match agg { - AggregateFunction::Sum(col_name) => { - let sum = self.sums.get(col_name).copied().unwrap_or(0.0); - blob.extend_from_slice(&sum.to_le_bytes()); - } - AggregateFunction::Avg(col_name) => { - let (sum, count) = self.avgs.get(col_name).copied().unwrap_or((0.0, 0)); - blob.extend_from_slice(&sum.to_le_bytes()); - blob.extend_from_slice(&count.to_le_bytes()); - } - AggregateFunction::Count => { - // Count is already written above - } - AggregateFunction::Min(col_name) => { - // Write whether we have a MIN value (1 byte) - if let Some(min_val) = self.mins.get(col_name) { - blob.push(1u8); // Has value - serialize_value(min_val, &mut blob); - } else { - blob.push(0u8); // No value - } - } - AggregateFunction::Max(col_name) => { - // Write whether we have a MAX value (1 byte) - if let Some(max_val) = self.maxs.get(col_name) { - blob.push(1u8); // Has value - serialize_value(max_val, &mut blob); - } else { - blob.push(0u8); // No value - } - } - } - } - - blob - } - - /// Deserialize aggregate state from a binary blob - /// Returns the aggregate state and the group key values - pub fn from_blob(blob: &[u8], aggregates: &[AggregateFunction]) -> Option<(Self, Vec)> { - let mut cursor = 0; - - // Check version byte - if blob.get(cursor) != Some(&1u8) { - return None; - } - cursor += 1; - - // Read number of group key values - let num_group_keys = - u32::from_le_bytes(blob.get(cursor..cursor + 4)?.try_into().ok()?) as usize; - cursor += 4; - - // Read group key values - let mut group_key = Vec::new(); - for _ in 0..num_group_keys { - let value_type = *blob.get(cursor)?; - cursor += 1; - - let value = match value_type { - 0 => Value::Null, - 1 => { - let i = i64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); - cursor += 8; - Value::Integer(i) - } - 2 => { - let f = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); - cursor += 8; - Value::Float(f) - } - 3 => { - let len = - u32::from_le_bytes(blob.get(cursor..cursor + 4)?.try_into().ok()?) as usize; - cursor += 4; - let bytes = blob.get(cursor..cursor + len)?; - cursor += len; - let text_str = std::str::from_utf8(bytes).ok()?; - Value::Text(text_str.to_string().into()) - } - 4 => { - let len = - u32::from_le_bytes(blob.get(cursor..cursor + 4)?.try_into().ok()?) 
as usize; - cursor += 4; - let bytes = blob.get(cursor..cursor + len)?; - cursor += len; - Value::Blob(bytes.to_vec()) - } - _ => return None, - }; - group_key.push(value); - } - - // Read count - let count = i64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); - cursor += 8; - - let mut state = Self::new(); - state.count = count; - - // Read each aggregate's state - for agg in aggregates { - match agg { - AggregateFunction::Sum(col_name) => { - let sum = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); - cursor += 8; - state.sums.insert(col_name.clone(), sum); - } - AggregateFunction::Avg(col_name) => { - let sum = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); - cursor += 8; - let count = i64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); - cursor += 8; - state.avgs.insert(col_name.clone(), (sum, count)); - } - AggregateFunction::Count => { - // Count was already read above - } - AggregateFunction::Min(col_name) => { - // Read whether we have a MIN value - let has_value = *blob.get(cursor)?; - cursor += 1; - - if has_value == 1 { - let (min_value, bytes_consumed) = deserialize_value(&blob[cursor..])?; - cursor += bytes_consumed; - state.mins.insert(col_name.clone(), min_value); - } - } - AggregateFunction::Max(col_name) => { - // Read whether we have a MAX value - let has_value = *blob.get(cursor)?; - cursor += 1; - - if has_value == 1 { - let (max_value, bytes_consumed) = deserialize_value(&blob[cursor..])?; - cursor += bytes_consumed; - state.maxs.insert(col_name.clone(), max_value); - } - } - } - } - - Some((state, group_key)) - } - - /// Apply a delta to this aggregate state - fn apply_delta( - &mut self, - values: &[Value], - weight: isize, - aggregates: &[AggregateFunction], - column_names: &[String], - ) { - // Update COUNT - self.count += weight as i64; - - // Update other aggregates - for agg in aggregates { - match agg { - AggregateFunction::Count => { - // Already handled above - } - AggregateFunction::Sum(col_name) => { - if let Some(idx) = column_names.iter().position(|c| c == col_name) { - if let Some(val) = values.get(idx) { - let num_val = match val { - Value::Integer(i) => *i as f64, - Value::Float(f) => *f, - _ => 0.0, - }; - *self.sums.entry(col_name.clone()).or_insert(0.0) += - num_val * weight as f64; - } - } - } - AggregateFunction::Avg(col_name) => { - if let Some(idx) = column_names.iter().position(|c| c == col_name) { - if let Some(val) = values.get(idx) { - let num_val = match val { - Value::Integer(i) => *i as f64, - Value::Float(f) => *f, - _ => 0.0, - }; - let (sum, count) = - self.avgs.entry(col_name.clone()).or_insert((0.0, 0)); - *sum += num_val * weight as f64; - *count += weight as i64; - } - } - } - AggregateFunction::Min(_col_name) | AggregateFunction::Max(_col_name) => { - // MIN/MAX cannot be handled incrementally in apply_delta because: - // - // 1. For insertions: We can't just keep the minimum/maximum value. - // We need to track ALL values to handle future deletions correctly. - // - // 2. For deletions (retractions): If we delete the current MIN/MAX, - // we need to find the next best value, which requires knowing all - // other values in the group. 
- // - // Example: Consider MIN(price) with values [10, 20, 30] - // - Current MIN = 10 - // - Delete 10 (weight = -1) - // - New MIN should be 20, but we can't determine this without - // having tracked all values [20, 30] - // - // Therefore, MIN/MAX processing is handled separately: - // - All input values are persisted to the index via persist_min_max() - // - When aggregates have MIN/MAX, we unconditionally transition to - // the RecomputeMinMax state machine (see EvalState::RecomputeMinMax) - // - RecomputeMinMax checks if the current MIN/MAX was deleted, and if so, - // scans the index to find the new MIN/MAX from remaining values - // - // This ensures correctness for incremental computation at the cost of - // additional I/O for MIN/MAX operations. - } - } - } - } - - /// Convert aggregate state to output values - pub fn to_values(&self, aggregates: &[AggregateFunction]) -> Vec { - let mut result = Vec::new(); - - for agg in aggregates { - match agg { - AggregateFunction::Count => { - result.push(Value::Integer(self.count)); - } - AggregateFunction::Sum(col_name) => { - let sum = self.sums.get(col_name).copied().unwrap_or(0.0); - // Return as integer if it's a whole number, otherwise as float - if sum.fract() == 0.0 { - result.push(Value::Integer(sum as i64)); - } else { - result.push(Value::Float(sum)); - } - } - AggregateFunction::Avg(col_name) => { - if let Some((sum, count)) = self.avgs.get(col_name) { - if *count > 0 { - result.push(Value::Float(sum / *count as f64)); - } else { - result.push(Value::Null); - } - } else { - result.push(Value::Null); - } - } - AggregateFunction::Min(col_name) => { - // Return the MIN value from our state - result.push(self.mins.get(col_name).cloned().unwrap_or(Value::Null)); - } - AggregateFunction::Max(col_name) => { - // Return the MAX value from our state - result.push(self.maxs.get(col_name).cloned().unwrap_or(Value::Null)); - } - } - } - - result - } -} - -impl AggregateOperator { - pub fn new( - operator_id: usize, - group_by: Vec, - aggregates: Vec, - input_column_names: Vec, - ) -> Self { - // Build map of column names to their MIN/MAX info with indices - let mut column_min_max = HashMap::new(); - let mut column_indices = HashMap::new(); - let mut current_index = 0; - - // First pass: assign indices to unique MIN/MAX columns - for agg in &aggregates { - match agg { - AggregateFunction::Min(col) | AggregateFunction::Max(col) => { - column_indices.entry(col.clone()).or_insert_with(|| { - let idx = current_index; - current_index += 1; - idx - }); - } - _ => {} - } - } - - // Second pass: build the column info map - for agg in &aggregates { - match agg { - AggregateFunction::Min(col) => { - let index = *column_indices.get(col).unwrap(); - let entry = column_min_max.entry(col.clone()).or_insert(AggColumnInfo { - index, - has_min: false, - has_max: false, - }); - entry.has_min = true; - } - AggregateFunction::Max(col) => { - let index = *column_indices.get(col).unwrap(); - let entry = column_min_max.entry(col.clone()).or_insert(AggColumnInfo { - index, - has_min: false, - has_max: false, - }); - entry.has_max = true; - } - _ => {} - } - } - - Self { - operator_id, - group_by, - aggregates, - input_column_names, - column_min_max, - tracker: None, - commit_state: AggregateCommitState::Idle, - } - } - - pub fn has_min_max(&self) -> bool { - !self.column_min_max.is_empty() - } - - fn eval_internal( - &mut self, - state: &mut EvalState, - cursors: &mut DbspStateCursors, - ) -> Result> { - match state { - EvalState::Uninitialized => { - 
panic!("Cannot eval AggregateOperator with Uninitialized state"); - } - EvalState::Init { deltas } => { - // Aggregate operators only use left_delta, right_delta must be empty - assert!( - deltas.right.is_empty(), - "AggregateOperator expects right_delta to be empty" - ); - - if deltas.left.changes.is_empty() { - *state = EvalState::Done; - return Ok(IOResult::Done((Delta::new(), HashMap::new()))); - } - - let mut groups_to_read = BTreeMap::new(); - for (row, _weight) in &deltas.left.changes { - // Extract group key using cloned fields - let group_key = self.extract_group_key(&row.values); - let group_key_str = Self::group_key_to_string(&group_key); - groups_to_read.insert(group_key_str, group_key); - } - state.advance(groups_to_read); - } - EvalState::FetchKey { .. } - | EvalState::FetchData { .. } - | EvalState::RecomputeMinMax { .. } => { - // Already in progress, continue processing on process_delta below. - } - EvalState::Done => { - panic!("unreachable state! should have returned"); - } - } - - // Process the delta through the state machine - let result = return_if_io!(state.process_delta(self, cursors)); - Ok(IOResult::Done(result)) - } - - fn merge_delta_with_existing( - &mut self, - delta: &Delta, - existing_groups: &mut HashMap, - old_values: &mut HashMap>, - ) -> (Delta, HashMap, AggregateState)>) { - let mut output_delta = Delta::new(); - let mut temp_keys: HashMap> = HashMap::new(); - - // Process each change in the delta - for (row, weight) in &delta.changes { - if let Some(tracker) = &self.tracker { - tracker.lock().unwrap().record_aggregation(); - } - - // Extract group key - let group_key = self.extract_group_key(&row.values); - let group_key_str = Self::group_key_to_string(&group_key); - - let state = existing_groups.entry(group_key_str.clone()).or_default(); - - temp_keys.insert(group_key_str.clone(), group_key.clone()); - - // Apply the delta to the temporary state - state.apply_delta( - &row.values, - *weight, - &self.aggregates, - &self.input_column_names, - ); - } - - // Generate output delta from temporary states and collect final states - let mut final_states = HashMap::new(); - - for (group_key_str, state) in existing_groups { - let group_key = temp_keys.get(group_key_str).cloned().unwrap_or_default(); - - // Generate a unique rowid for this group - let result_key = self.generate_group_rowid(group_key_str); - - if let Some(old_row_values) = old_values.get(group_key_str) { - let old_row = HashableRow::new(result_key, old_row_values.clone()); - output_delta.changes.push((old_row, -1)); - } - - // Always store the state for persistence (even if count=0, we need to delete it) - final_states.insert(group_key_str.clone(), (group_key.clone(), state.clone())); - - // Only include groups with count > 0 in the output delta - if state.count > 0 { - // Build output row: group_by columns + aggregate values - let mut output_values = group_key.clone(); - let aggregate_values = state.to_values(&self.aggregates); - output_values.extend(aggregate_values); - - let output_row = HashableRow::new(result_key, output_values.clone()); - output_delta.changes.push((output_row, 1)); - } - } - (output_delta, final_states) - } - - /// Extract MIN/MAX values from delta changes for persistence to index - fn extract_min_max_deltas(&self, delta: &Delta) -> MinMaxDeltas { - let mut min_max_deltas: MinMaxDeltas = HashMap::new(); - - for (row, weight) in &delta.changes { - let group_key = self.extract_group_key(&row.values); - let group_key_str = Self::group_key_to_string(&group_key); - - for agg in 
&self.aggregates { - match agg { - AggregateFunction::Min(col_name) | AggregateFunction::Max(col_name) => { - if let Some(idx) = - self.input_column_names.iter().position(|c| c == col_name) - { - if let Some(val) = row.values.get(idx) { - // Skip NULL values - they don't participate in MIN/MAX - if val == &Value::Null { - continue; - } - // Create a HashableRow with just this value - // Use 0 as rowid since we only care about the value for comparison - let hashable_value = HashableRow::new(0, vec![val.clone()]); - let key = (col_name.clone(), hashable_value); - - let group_entry = - min_max_deltas.entry(group_key_str.clone()).or_default(); - - let value_entry = group_entry.entry(key).or_insert(0); - - // Accumulate the weight - *value_entry += weight; - } - } - } - _ => {} // Ignore non-MIN/MAX aggregates - } - } - } - - min_max_deltas - } - - pub fn set_tracker(&mut self, tracker: Arc>) { - self.tracker = Some(tracker); - } - - /// Generate a rowid for a group - /// For no GROUP BY: always returns 0 - /// For GROUP BY: returns a hash of the group key string - pub fn generate_group_rowid(&self, group_key_str: &str) -> i64 { - if self.group_by.is_empty() { - 0 - } else { - group_key_str - .bytes() - .fold(0i64, |acc, b| acc.wrapping_mul(31).wrapping_add(b as i64)) - } - } - - /// Generate the composite key for BTree storage - /// Combines operator_id and group hash - fn generate_storage_key(&self, group_key_str: &str) -> i64 { - let group_hash = self.generate_group_rowid(group_key_str); - (self.operator_id as i64) << 32 | (group_hash & 0xFFFFFFFF) - } - - /// Extract group key values from a row - pub fn extract_group_key(&self, values: &[Value]) -> Vec { - let mut key = Vec::new(); - - for group_col in &self.group_by { - if let Some(idx) = self.input_column_names.iter().position(|c| c == group_col) { - if let Some(val) = values.get(idx) { - key.push(val.clone()); - } else { - key.push(Value::Null); - } - } else { - key.push(Value::Null); - } - } - - key - } - - /// Convert group key to string for indexing (since Value doesn't implement Hash) - pub fn group_key_to_string(key: &[Value]) -> String { - key.iter() - .map(|v| format!("{v:?}")) - .collect::>() - .join(",") - } - - fn seek_key_from_str(&self, group_key_str: &str) -> SeekKey<'_> { - // Calculate the composite key for seeking - let key_i64 = self.generate_storage_key(group_key_str); - SeekKey::TableRowId(key_i64) - } - - fn seek_key(&self, row: HashableRow) -> SeekKey<'_> { - // Extract group key for first row - let group_key = self.extract_group_key(&row.values); - let group_key_str = Self::group_key_to_string(&group_key); - self.seek_key_from_str(&group_key_str) - } -} - -impl IncrementalOperator for AggregateOperator { - fn eval( - &mut self, - state: &mut EvalState, - cursors: &mut DbspStateCursors, - ) -> Result> { - let (delta, _) = return_if_io!(self.eval_internal(state, cursors)); - Ok(IOResult::Done(delta)) - } - - fn commit( - &mut self, - mut deltas: DeltaPair, - cursors: &mut DbspStateCursors, - ) -> Result> { - // Aggregate operator only uses left delta, right must be empty - assert!( - deltas.right.is_empty(), - "AggregateOperator expects right delta to be empty in commit" - ); - let delta = std::mem::take(&mut deltas.left); - loop { - // Note: because we std::mem::replace here (without it, the borrow checker goes nuts, - // because we call self.eval_interval, which requires a mutable borrow), we have to - // restore the state if we return I/O. So we can't use return_if_io! 
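The commit loop below follows the discipline the comment just described: the state is moved out with std::mem::replace so that eval_internal can take a second mutable borrow of self, and it must be put back before yielding on I/O. A stripped-down, std-only sketch of that pattern (all names hypothetical):

    use std::mem;

    enum Io<T> { Done(T), Pending }
    enum State { Idle, Eval(u32), Invalid }

    fn step(slot: &mut State, run: impl Fn(&mut State) -> Io<u32>) -> Io<u32> {
        // Move the state out of its slot; in the operator this is what frees
        // `self` for the mutable borrow that eval_internal needs.
        let mut state = mem::replace(slot, State::Invalid);
        match run(&mut state) {
            Io::Pending => {
                *slot = state; // restore before returning, or the next call sees Invalid
                Io::Pending
            }
            Io::Done(v) => {
                *slot = State::Idle;
                Io::Done(v)
            }
        }
    }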
- let mut state = - std::mem::replace(&mut self.commit_state, AggregateCommitState::Invalid); - match &mut state { - AggregateCommitState::Invalid => { - panic!("Reached invalid state! State was replaced, and not replaced back"); - } - AggregateCommitState::Idle => { - let eval_state = EvalState::from_delta(delta.clone()); - self.commit_state = AggregateCommitState::Eval { eval_state }; - } - AggregateCommitState::Eval { ref mut eval_state } => { - // Extract input delta before eval for MIN/MAX processing - let input_delta = eval_state.extract_delta(); - - // Extract MIN/MAX deltas before any I/O operations - let min_max_deltas = self.extract_min_max_deltas(&input_delta); - - // Create a new eval state with the same delta - *eval_state = EvalState::from_delta(input_delta.clone()); - - let (output_delta, computed_states) = return_and_restore_if_io!( - &mut self.commit_state, - state, - self.eval_internal(eval_state, cursors) - ); - - self.commit_state = AggregateCommitState::PersistDelta { - delta: output_delta, - computed_states, - current_idx: 0, - write_row: WriteRow::new(), - min_max_deltas, // Store for later use - }; - } - AggregateCommitState::PersistDelta { - delta, - computed_states, - current_idx, - write_row, - min_max_deltas, - } => { - let states_vec: Vec<_> = computed_states.iter().collect(); - - if *current_idx >= states_vec.len() { - // Use the min_max_deltas we extracted earlier from the input delta - self.commit_state = AggregateCommitState::PersistMinMax { - delta: delta.clone(), - min_max_persist_state: MinMaxPersistState::new(min_max_deltas.clone()), - }; - } else { - let (group_key_str, (group_key, agg_state)) = states_vec[*current_idx]; - - // Build the key components for the new table structure - // For regular aggregates, use column_index=0 and AGG_TYPE_REGULAR - let operator_storage_id = - generate_storage_id(self.operator_id, 0, AGG_TYPE_REGULAR); - let zset_id = self.generate_group_rowid(group_key_str); - let element_id = 0i64; - - // Determine weight: -1 to delete (cancels existing weight=1), 1 to insert/update - let weight = if agg_state.count == 0 { -1 } else { 1 }; - - // Serialize the aggregate state with group key (even for deletion, we need a row) - let state_blob = agg_state.to_blob(&self.aggregates, group_key); - let blob_value = Value::Blob(state_blob); - - // Build the aggregate storage format: [operator_id, zset_id, element_id, value, weight] - let operator_id_val = Value::Integer(operator_storage_id); - let zset_id_val = Value::Integer(zset_id); - let element_id_val = Value::Integer(element_id); - let blob_val = blob_value.clone(); - - // Create index key - the first 3 columns of our primary key - let index_key = vec![ - operator_id_val.clone(), - zset_id_val.clone(), - element_id_val.clone(), - ]; - - // Record values (without weight) - let record_values = - vec![operator_id_val, zset_id_val, element_id_val, blob_val]; - - return_and_restore_if_io!( - &mut self.commit_state, - state, - write_row.write_row(cursors, index_key, record_values, weight) - ); - - let delta = std::mem::take(delta); - let computed_states = std::mem::take(computed_states); - let min_max_deltas = std::mem::take(min_max_deltas); - - self.commit_state = AggregateCommitState::PersistDelta { - delta, - computed_states, - current_idx: *current_idx + 1, - write_row: WriteRow::new(), // Reset for next write - min_max_deltas, - }; - } - } - AggregateCommitState::PersistMinMax { - delta, - min_max_persist_state, - } => { - if !self.has_min_max() { - let delta = std::mem::take(delta); - 
self.commit_state = AggregateCommitState::Done { delta }; - } else { - return_and_restore_if_io!( - &mut self.commit_state, - state, - min_max_persist_state.persist_min_max( - self.operator_id, - &self.column_min_max, - cursors, - |group_key_str| self.generate_group_rowid(group_key_str) - ) - ); - - let delta = std::mem::take(delta); - self.commit_state = AggregateCommitState::Done { delta }; - } - } - AggregateCommitState::Done { delta } => { - self.commit_state = AggregateCommitState::Idle; - let delta = std::mem::take(delta); - return Ok(IOResult::Done(delta)); - } - } - } - } - - fn set_tracker(&mut self, tracker: Arc>) { - self.tracker = Some(tracker); - } -} - #[cfg(test)] mod tests { use super::*; + use crate::incremental::aggregate_operator::{AggregateOperator, AGG_TYPE_REGULAR}; + use crate::incremental::dbsp::HashableRow; use crate::storage::pager::CreateBTreeFlags; use crate::types::Text; use crate::util::IOExt; @@ -2373,9 +395,9 @@ mod tests { // Create an aggregate operator for SUM(age) with no GROUP BY let mut agg = AggregateOperator::new( - 1, // operator_id for testing - vec![], // No GROUP BY - vec![AggregateFunction::Sum("age".to_string())], + 1, // operator_id for testing + vec![], // No GROUP BY + vec![AggregateFunction::Sum(2)], // age is at index 2 vec!["id".to_string(), "name".to_string(), "age".to_string()], ); @@ -2492,9 +514,9 @@ mod tests { let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( - 1, // operator_id for testing - vec!["team".to_string()], // GROUP BY team - vec![AggregateFunction::Sum("score".to_string())], + 1, // operator_id for testing + vec![1], // GROUP BY team (index 1) + vec![AggregateFunction::Sum(3)], // score is at index 3 vec![ "id".to_string(), "team".to_string(), @@ -2644,8 +666,8 @@ mod tests { // Create COUNT(*) GROUP BY category let mut agg = AggregateOperator::new( - 1, // operator_id for testing - vec!["category".to_string()], + 1, // operator_id for testing + vec![1], // category is at index 1 vec![AggregateFunction::Count], vec![ "item_id".to_string(), @@ -2724,9 +746,9 @@ mod tests { let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( - 1, // operator_id for testing - vec!["product".to_string()], - vec![AggregateFunction::Sum("amount".to_string())], + 1, // operator_id for testing + vec![1], // product is at index 1 + vec![AggregateFunction::Sum(2)], // amount is at index 2 vec![ "sale_id".to_string(), "product".to_string(), @@ -2821,11 +843,11 @@ mod tests { let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( - 1, // operator_id for testing - vec!["user_id".to_string()], + 1, // operator_id for testing + vec![1], // user_id is at index 1 vec![ AggregateFunction::Count, - AggregateFunction::Sum("amount".to_string()), + AggregateFunction::Sum(2), // amount is at index 2 ], vec![ "order_id".to_string(), @@ -2913,9 +935,9 @@ mod tests { let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( - 1, // operator_id for testing - vec!["category".to_string()], - vec![AggregateFunction::Avg("value".to_string())], + 1, // operator_id for testing + vec![1], // category is at index 1 + vec![AggregateFunction::Avg(2)], // value is at index 2 vec![ "id".to_string(), "category".to_string(), @@ -3013,11 +1035,11 @@ mod tests { let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( - 
1, // operator_id for testing - vec!["category".to_string()], + 1, // operator_id for testing + vec![1], // category is at index 1 vec![ AggregateFunction::Count, - AggregateFunction::Sum("value".to_string()), + AggregateFunction::Sum(2), // value is at index 2 ], vec![ "id".to_string(), @@ -3086,7 +1108,7 @@ mod tests { #[test] fn test_count_aggregation_with_deletions() { let aggregates = vec![AggregateFunction::Count]; - let group_by = vec!["category".to_string()]; + let group_by = vec![0]; // category is at index 0 let input_columns = vec!["category".to_string(), "value".to_string()]; // Create a persistent pager for the test @@ -3175,8 +1197,8 @@ mod tests { #[test] fn test_sum_aggregation_with_deletions() { - let aggregates = vec![AggregateFunction::Sum("value".to_string())]; - let group_by = vec!["category".to_string()]; + let aggregates = vec![AggregateFunction::Sum(1)]; // value is at index 1 + let group_by = vec![0]; // category is at index 0 let input_columns = vec!["category".to_string(), "value".to_string()]; // Create a persistent pager for the test @@ -3259,8 +1281,8 @@ mod tests { #[test] fn test_avg_aggregation_with_deletions() { - let aggregates = vec![AggregateFunction::Avg("value".to_string())]; - let group_by = vec!["category".to_string()]; + let aggregates = vec![AggregateFunction::Avg(1)]; // value is at index 1 + let group_by = vec![0]; // category is at index 0 let input_columns = vec!["category".to_string(), "value".to_string()]; // Create a persistent pager for the test @@ -3326,10 +1348,10 @@ mod tests { // Test COUNT, SUM, and AVG together let aggregates = vec![ AggregateFunction::Count, - AggregateFunction::Sum("value".to_string()), - AggregateFunction::Avg("value".to_string()), + AggregateFunction::Sum(1), // value is at index 1 + AggregateFunction::Avg(1), // value is at index 1 ]; - let group_by = vec!["category".to_string()]; + let group_by = vec![0]; // category is at index 0 let input_columns = vec!["category".to_string(), "value".to_string()]; // Create a persistent pager for the test @@ -3428,13 +1450,10 @@ mod tests { BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); - let mut filter = FilterOperator::new( - FilterPredicate::GreaterThan { - column: "b".to_string(), - value: Value::Integer(2), - }, - vec!["a".to_string(), "b".to_string()], - ); + let mut filter = FilterOperator::new(FilterPredicate::GreaterThan { + column_idx: 1, // "b" is at index 1 + value: Value::Integer(2), + }); // Initialize with a row (rowid=3, values=[3, 3]) let mut init_data = Delta::new(); @@ -3490,13 +1509,10 @@ mod tests { BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); - let mut filter = FilterOperator::new( - FilterPredicate::GreaterThan { - column: "age".to_string(), - value: Value::Integer(25), - }, - vec!["id".to_string(), "name".to_string(), "age".to_string()], - ); + let mut filter = FilterOperator::new(FilterPredicate::GreaterThan { + column_idx: 2, // "age" is at index 2 + value: Value::Integer(25), + }); // Initialize with some data let mut init_data = Delta::new(); @@ -3585,11 +1601,11 @@ mod tests { let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( - 1, // operator_id for testing - vec!["category".to_string()], + 1, // operator_id for testing + vec![1], // category is at index 1 vec![ AggregateFunction::Count, 
- AggregateFunction::Sum("amount".to_string()), + AggregateFunction::Sum(2), // amount is at index 2 ], vec![ "id".to_string(), @@ -3759,7 +1775,7 @@ mod tests { vec![], // No GROUP BY vec![ AggregateFunction::Count, - AggregateFunction::Sum("value".to_string()), + AggregateFunction::Sum(1), // value is at index 1 ], vec!["id".to_string(), "value".to_string()], ); @@ -3837,8 +1853,8 @@ mod tests { let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( - 1, // operator_id for testing - vec!["type".to_string()], + 1, // operator_id for testing + vec![1], // type is at index 1 vec![AggregateFunction::Count], vec!["id".to_string(), "type".to_string()], ); @@ -3954,8 +1970,8 @@ mod tests { 1, // operator_id vec![], // No GROUP BY vec![ - AggregateFunction::Min("price".to_string()), - AggregateFunction::Max("price".to_string()), + AggregateFunction::Min(2), // price is at index 2 + AggregateFunction::Max(2), // price is at index 2 ], vec!["id".to_string(), "name".to_string(), "price".to_string()], ); @@ -4022,8 +2038,8 @@ mod tests { 1, // operator_id vec![], // No GROUP BY vec![ - AggregateFunction::Min("price".to_string()), - AggregateFunction::Max("price".to_string()), + AggregateFunction::Min(2), // price is at index 2 + AggregateFunction::Max(2), // price is at index 2 ], vec!["id".to_string(), "name".to_string(), "price".to_string()], ); @@ -4112,8 +2128,8 @@ mod tests { 1, // operator_id vec![], // No GROUP BY vec![ - AggregateFunction::Min("price".to_string()), - AggregateFunction::Max("price".to_string()), + AggregateFunction::Min(2), // price is at index 2 + AggregateFunction::Max(2), // price is at index 2 ], vec!["id".to_string(), "name".to_string(), "price".to_string()], ); @@ -4202,8 +2218,8 @@ mod tests { 1, // operator_id vec![], // No GROUP BY vec![ - AggregateFunction::Min("price".to_string()), - AggregateFunction::Max("price".to_string()), + AggregateFunction::Min(2), // price is at index 2 + AggregateFunction::Max(2), // price is at index 2 ], vec!["id".to_string(), "name".to_string(), "price".to_string()], ); @@ -4284,8 +2300,8 @@ mod tests { 1, // operator_id vec![], // No GROUP BY vec![ - AggregateFunction::Min("price".to_string()), - AggregateFunction::Max("price".to_string()), + AggregateFunction::Min(2), // price is at index 2 + AggregateFunction::Max(2), // price is at index 2 ], vec!["id".to_string(), "name".to_string(), "price".to_string()], ); @@ -4366,8 +2382,8 @@ mod tests { 1, // operator_id vec![], // No GROUP BY vec![ - AggregateFunction::Min("price".to_string()), - AggregateFunction::Max("price".to_string()), + AggregateFunction::Min(2), // price is at index 2 + AggregateFunction::Max(2), // price is at index 2 ], vec!["id".to_string(), "name".to_string(), "price".to_string()], ); @@ -4453,11 +2469,11 @@ mod tests { let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( - 1, // operator_id - vec!["category".to_string()], // GROUP BY category + 1, // operator_id + vec![1], // GROUP BY category (index 1) vec![ - AggregateFunction::Min("price".to_string()), - AggregateFunction::Max("price".to_string()), + AggregateFunction::Min(3), // price is at index 3 + AggregateFunction::Max(3), // price is at index 3 ], vec![ "id".to_string(), @@ -4558,8 +2574,8 @@ mod tests { 1, // operator_id vec![], // No GROUP BY vec![ - AggregateFunction::Min("price".to_string()), - AggregateFunction::Max("price".to_string()), + AggregateFunction::Min(2), // price is at index 2 + 
AggregateFunction::Max(2), // price is at index 2 ], vec!["id".to_string(), "name".to_string(), "price".to_string()], ); @@ -4634,8 +2650,8 @@ mod tests { 1, // operator_id vec![], // No GROUP BY vec![ - AggregateFunction::Min("score".to_string()), - AggregateFunction::Max("score".to_string()), + AggregateFunction::Min(2), // score is at index 2 + AggregateFunction::Max(2), // score is at index 2 ], vec!["id".to_string(), "name".to_string(), "score".to_string()], ); @@ -4702,8 +2718,8 @@ mod tests { 1, // operator_id vec![], // No GROUP BY vec![ - AggregateFunction::Min("name".to_string()), - AggregateFunction::Max("name".to_string()), + AggregateFunction::Min(1), // name is at index 1 + AggregateFunction::Max(1), // name is at index 1 ], vec!["id".to_string(), "name".to_string()], ); @@ -4742,10 +2758,10 @@ mod tests { vec![], // No GROUP BY vec![ AggregateFunction::Count, - AggregateFunction::Sum("value".to_string()), - AggregateFunction::Min("value".to_string()), - AggregateFunction::Max("value".to_string()), - AggregateFunction::Avg("value".to_string()), + AggregateFunction::Sum(1), // value is at index 1 + AggregateFunction::Min(1), // value is at index 1 + AggregateFunction::Max(1), // value is at index 1 + AggregateFunction::Avg(1), // value is at index 1 ], vec!["id".to_string(), "value".to_string()], ); @@ -4833,9 +2849,9 @@ mod tests { 1, // operator_id vec![], // No GROUP BY vec![ - AggregateFunction::Min("col1".to_string()), - AggregateFunction::Max("col2".to_string()), - AggregateFunction::Min("col3".to_string()), + AggregateFunction::Min(0), // col1 is at index 0 + AggregateFunction::Max(1), // col2 is at index 1 + AggregateFunction::Min(2), // col3 is at index 2 ], vec!["col1".to_string(), "col2".to_string(), "col3".to_string()], ); @@ -4897,4 +2913,765 @@ mod tests { assert_eq!(row_ins.values[1], Value::Integer(150)); // New MAX(col2) assert_eq!(row_ins.values[2], Value::Integer(500)); // MIN(col3) unchanged } + + #[test] + fn test_join_operator_inner() { + // Test INNER JOIN with incremental updates + let (pager, table_page_id, index_page_id) = create_test_pager(); + let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_page_id, 10); + let index_def = create_dbsp_state_index(index_page_id); + let index_cursor = + BTreeCursor::new_index(None, pager.clone(), index_page_id, &index_def, 10); + let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); + + let mut join = JoinOperator::new( + 1, // operator_id + JoinType::Inner, + vec![0], // Join on first column + vec![0], + vec!["customer_id".to_string(), "amount".to_string()], + vec!["id".to_string(), "name".to_string()], + ) + .unwrap(); + + // FIRST COMMIT: Initialize with data + let mut left_delta = Delta::new(); + left_delta.insert(1, vec![Value::Integer(1), Value::Float(100.0)]); + left_delta.insert(2, vec![Value::Integer(2), Value::Float(200.0)]); + left_delta.insert(3, vec![Value::Integer(3), Value::Float(300.0)]); // No match initially + + let mut right_delta = Delta::new(); + right_delta.insert(1, vec![Value::Integer(1), Value::Text("Alice".into())]); + right_delta.insert(2, vec![Value::Integer(2), Value::Text("Bob".into())]); + right_delta.insert(4, vec![Value::Integer(4), Value::Text("David".into())]); // No match initially + + let delta_pair = DeltaPair::new(left_delta, right_delta); + let result = pager + .io + .block(|| join.commit(delta_pair.clone(), &mut cursors)) + .unwrap(); + + // Should have 2 matches (customer 1 and 2) + assert_eq!( + result.changes.len(), + 2, + "First commit 
should produce 2 matches" + ); + + let mut results: Vec<_> = result.changes.clone(); + results.sort_by_key(|r| r.0.values[0].clone()); + + assert_eq!(results[0].0.values[0], Value::Integer(1)); + assert_eq!(results[0].0.values[3], Value::Text("Alice".into())); + assert_eq!(results[1].0.values[0], Value::Integer(2)); + assert_eq!(results[1].0.values[3], Value::Text("Bob".into())); + + // SECOND COMMIT: Add incremental data that should join with persisted state + // Add a new left row that should match existing right row (customer 4) + let mut left_delta2 = Delta::new(); + left_delta2.insert(5, vec![Value::Integer(4), Value::Float(400.0)]); // Should match David from persisted state + + // Add a new right row that should match existing left row (customer 3) + let mut right_delta2 = Delta::new(); + right_delta2.insert(6, vec![Value::Integer(3), Value::Text("Charlie".into())]); // Should match customer 3 from persisted state + + let delta_pair2 = DeltaPair::new(left_delta2, right_delta2); + let result2 = pager + .io + .block(|| join.commit(delta_pair2.clone(), &mut cursors)) + .unwrap(); + + // The second commit should produce: + // 1. New left (customer_id=4) joins with persisted right (id=4, David) + // 2. Persisted left (customer_id=3) joins with new right (id=3, Charlie) + + assert_eq!( + result2.changes.len(), + 2, + "Second commit should produce 2 new matches from incremental join. Got: {:?}", + result2.changes + ); + + // Verify the incremental results + let mut results2: Vec<_> = result2.changes.clone(); + results2.sort_by_key(|r| r.0.values[0].clone()); + + // Check for customer 3 joined with Charlie (existing left + new right) + let charlie_match = results2 + .iter() + .find(|(row, _)| row.values[0] == Value::Integer(3)) + .expect("Should find customer 3 joined with new Charlie"); + assert_eq!(charlie_match.0.values[2], Value::Integer(3)); + assert_eq!(charlie_match.0.values[3], Value::Text("Charlie".into())); + + // Check for customer 4 joined with David (new left + existing right) + let david_match = results2 + .iter() + .find(|(row, _)| row.values[0] == Value::Integer(4)) + .expect("Should find new customer 4 joined with existing David"); + assert_eq!(david_match.0.values[0], Value::Integer(4)); + assert_eq!(david_match.0.values[3], Value::Text("David".into())); + } + + #[test] + fn test_join_operator_with_deletions() { + // Test INNER JOIN with deletions (negative weights) + let (pager, table_page_id, index_page_id) = create_test_pager(); + let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_page_id, 10); + let index_def = create_dbsp_state_index(index_page_id); + let index_cursor = + BTreeCursor::new_index(None, pager.clone(), index_page_id, &index_def, 10); + let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); + + let mut join = JoinOperator::new( + 1, // operator_id + JoinType::Inner, + vec![0], // Join on first column + vec![0], + vec!["customer_id".to_string(), "amount".to_string()], + vec!["id".to_string(), "name".to_string()], + ) + .unwrap(); + + // FIRST COMMIT: Add initial data + let mut left_delta = Delta::new(); + left_delta.insert(1, vec![Value::Integer(1), Value::Float(100.0)]); + left_delta.insert(2, vec![Value::Integer(2), Value::Float(200.0)]); + left_delta.insert(3, vec![Value::Integer(3), Value::Float(300.0)]); + + let mut right_delta = Delta::new(); + right_delta.insert(1, vec![Value::Integer(1), Value::Text("Alice".into())]); + right_delta.insert(2, vec![Value::Integer(2), Value::Text("Bob".into())]); + right_delta.insert(3, 
vec![Value::Integer(3), Value::Text("Charlie".into())]); + + let delta_pair = DeltaPair::new(left_delta, right_delta); + + let result = pager + .io + .block(|| join.commit(delta_pair.clone(), &mut cursors)) + .unwrap(); + + assert_eq!(result.changes.len(), 3, "Should have 3 initial joins"); + + // SECOND COMMIT: Delete customer 2 from left side + let mut left_delta2 = Delta::new(); + left_delta2.delete(2, vec![Value::Integer(2), Value::Float(200.0)]); + + let empty_right = Delta::new(); + let delta_pair2 = DeltaPair::new(left_delta2, empty_right); + + let result2 = pager + .io + .block(|| join.commit(delta_pair2.clone(), &mut cursors)) + .unwrap(); + + // Should produce 1 deletion (retraction) of the join for customer 2 + assert_eq!( + result2.changes.len(), + 1, + "Should produce 1 retraction for deleted customer 2" + ); + assert_eq!( + result2.changes[0].1, -1, + "Should have weight -1 for deletion" + ); + assert_eq!(result2.changes[0].0.values[0], Value::Integer(2)); + assert_eq!(result2.changes[0].0.values[3], Value::Text("Bob".into())); + + // THIRD COMMIT: Delete customer 3 from right side + let empty_left = Delta::new(); + let mut right_delta3 = Delta::new(); + right_delta3.delete(3, vec![Value::Integer(3), Value::Text("Charlie".into())]); + + let delta_pair3 = DeltaPair::new(empty_left, right_delta3); + + let result3 = pager + .io + .block(|| join.commit(delta_pair3.clone(), &mut cursors)) + .unwrap(); + + // Should produce 1 deletion (retraction) of the join for customer 3 + assert_eq!( + result3.changes.len(), + 1, + "Should produce 1 retraction for deleted customer 3" + ); + assert_eq!( + result3.changes[0].1, -1, + "Should have weight -1 for deletion" + ); + assert_eq!(result3.changes[0].0.values[0], Value::Integer(3)); + assert_eq!(result3.changes[0].0.values[2], Value::Integer(3)); + } + + #[test] + fn test_join_operator_one_to_many() { + // Test one-to-many relationship: one customer with multiple orders + let (pager, table_page_id, index_page_id) = create_test_pager(); + let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_page_id, 10); + let index_def = create_dbsp_state_index(index_page_id); + let index_cursor = + BTreeCursor::new_index(None, pager.clone(), index_page_id, &index_def, 10); + let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); + + let mut join = JoinOperator::new( + 1, // operator_id + JoinType::Inner, + vec![0], // Join on first column (customer_id for orders) + vec![0], // Join on first column (id for customers) + vec![ + "customer_id".to_string(), + "order_id".to_string(), + "amount".to_string(), + ], + vec!["id".to_string(), "name".to_string()], + ) + .unwrap(); + + // FIRST COMMIT: Add one customer + let left_delta = Delta::new(); // Empty orders initially + let mut right_delta = Delta::new(); + right_delta.insert(1, vec![Value::Integer(100), Value::Text("Alice".into())]); + + let delta_pair = DeltaPair::new(left_delta, right_delta); + let result = pager + .io + .block(|| join.commit(delta_pair.clone(), &mut cursors)) + .unwrap(); + + // No joins yet (customer exists but no orders) + assert_eq!( + result.changes.len(), + 0, + "Should have no joins with customer but no orders" + ); + + // SECOND COMMIT: Add multiple orders for the same customer + let mut left_delta2 = Delta::new(); + left_delta2.insert( + 1, + vec![ + Value::Integer(100), + Value::Integer(1001), + Value::Float(50.0), + ], + ); // order 1001 + left_delta2.insert( + 2, + vec![ + Value::Integer(100), + Value::Integer(1002), + Value::Float(75.0), + ], + ); 
// order 1002 + left_delta2.insert( + 3, + vec![ + Value::Integer(100), + Value::Integer(1003), + Value::Float(100.0), + ], + ); // order 1003 + + let right_delta2 = Delta::new(); // No new customers + + let delta_pair2 = DeltaPair::new(left_delta2, right_delta2); + let result2 = pager + .io + .block(|| join.commit(delta_pair2.clone(), &mut cursors)) + .unwrap(); + + // Should produce 3 joins (3 orders × 1 customer) + assert_eq!( + result2.changes.len(), + 3, + "Should produce 3 joins for 3 orders with same customer. Got: {:?}", + result2.changes + ); + + // Verify all three joins have the same customer but different orders + for (row, weight) in &result2.changes { + assert_eq!(*weight, 1, "Weight should be 1 for insertion"); + assert_eq!( + row.values[0], + Value::Integer(100), + "Customer ID should be 100" + ); + assert_eq!( + row.values[4], + Value::Text("Alice".into()), + "Customer name should be Alice" + ); + + // Check order IDs are different + let order_id = match &row.values[1] { + Value::Integer(id) => *id, + _ => panic!("Expected integer order ID"), + }; + assert!( + (1001..=1003).contains(&order_id), + "Order ID {order_id} should be between 1001 and 1003" + ); + } + + // THIRD COMMIT: Delete one order + let mut left_delta3 = Delta::new(); + left_delta3.delete( + 2, + vec![ + Value::Integer(100), + Value::Integer(1002), + Value::Float(75.0), + ], + ); + + let delta_pair3 = DeltaPair::new(left_delta3, Delta::new()); + let result3 = pager + .io + .block(|| join.commit(delta_pair3.clone(), &mut cursors)) + .unwrap(); + + // Should produce 1 retraction for the deleted order + assert_eq!(result3.changes.len(), 1, "Should produce 1 retraction"); + assert_eq!(result3.changes[0].1, -1, "Should be a deletion"); + assert_eq!( + result3.changes[0].0.values[1], + Value::Integer(1002), + "Should delete order 1002" + ); + } + + #[test] + fn test_join_operator_many_to_many() { + // Test many-to-many: multiple rows with same key on both sides + let (pager, table_page_id, index_page_id) = create_test_pager(); + let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_page_id, 10); + let index_def = create_dbsp_state_index(index_page_id); + let index_cursor = + BTreeCursor::new_index(None, pager.clone(), index_page_id, &index_def, 10); + let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); + + let mut join = JoinOperator::new( + 1, // operator_id + JoinType::Inner, + vec![0], // Join on category_id + vec![0], // Join on id + vec![ + "category_id".to_string(), + "product_name".to_string(), + "price".to_string(), + ], + vec!["id".to_string(), "category_name".to_string()], + ) + .unwrap(); + + // FIRST COMMIT: Add multiple products in same category + let mut left_delta = Delta::new(); + left_delta.insert( + 1, + vec![ + Value::Integer(10), + Value::Text("Laptop".into()), + Value::Float(1000.0), + ], + ); + left_delta.insert( + 2, + vec![ + Value::Integer(10), + Value::Text("Mouse".into()), + Value::Float(50.0), + ], + ); + left_delta.insert( + 3, + vec![ + Value::Integer(10), + Value::Text("Keyboard".into()), + Value::Float(100.0), + ], + ); + + // Add multiple categories with same ID (simulating denormalized data or versioning) + let mut right_delta = Delta::new(); + right_delta.insert( + 1, + vec![Value::Integer(10), Value::Text("Electronics".into())], + ); + right_delta.insert(2, vec![Value::Integer(10), Value::Text("Computers".into())]); // Same category ID, different name + + let delta_pair = DeltaPair::new(left_delta, right_delta); + let result = pager + .io + 
.block(|| join.commit(delta_pair.clone(), &mut cursors)) + .unwrap(); + + // Should produce 3 products × 2 categories = 6 joins + assert_eq!( + result.changes.len(), + 6, + "Should produce 6 joins (3 products × 2 category records). Got: {:?}", + result.changes + ); + + // Verify we have all combinations + let mut found_combinations = std::collections::HashSet::new(); + for (row, weight) in &result.changes { + assert_eq!(*weight, 1); + let product = row.values[1].to_string(); + let category = row.values[4].to_string(); + found_combinations.insert((product, category)); + } + + assert_eq!( + found_combinations.len(), + 6, + "Should have 6 unique combinations" + ); + + // SECOND COMMIT: Add one more product in the same category + let mut left_delta2 = Delta::new(); + left_delta2.insert( + 4, + vec![ + Value::Integer(10), + Value::Text("Monitor".into()), + Value::Float(500.0), + ], + ); + + let delta_pair2 = DeltaPair::new(left_delta2, Delta::new()); + let result2 = pager + .io + .block(|| join.commit(delta_pair2.clone(), &mut cursors)) + .unwrap(); + + // New product should join with both existing category records + assert_eq!( + result2.changes.len(), + 2, + "New product should join with 2 existing category records" + ); + + for (row, _) in &result2.changes { + assert_eq!(row.values[1], Value::Text("Monitor".into())); + } + } + + #[test] + fn test_join_operator_update_in_one_to_many() { + // Test updates in one-to-many scenarios + let (pager, table_page_id, index_page_id) = create_test_pager(); + let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_page_id, 10); + let index_def = create_dbsp_state_index(index_page_id); + let index_cursor = + BTreeCursor::new_index(None, pager.clone(), index_page_id, &index_def, 10); + let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); + + let mut join = JoinOperator::new( + 1, // operator_id + JoinType::Inner, + vec![0], // Join on customer_id + vec![0], // Join on id + vec![ + "customer_id".to_string(), + "order_id".to_string(), + "amount".to_string(), + ], + vec!["id".to_string(), "name".to_string()], + ) + .unwrap(); + + // FIRST COMMIT: Setup one customer with multiple orders + let mut left_delta = Delta::new(); + left_delta.insert( + 1, + vec![ + Value::Integer(100), + Value::Integer(1001), + Value::Float(50.0), + ], + ); + left_delta.insert( + 2, + vec![ + Value::Integer(100), + Value::Integer(1002), + Value::Float(75.0), + ], + ); + left_delta.insert( + 3, + vec![ + Value::Integer(100), + Value::Integer(1003), + Value::Float(100.0), + ], + ); + + let mut right_delta = Delta::new(); + right_delta.insert(1, vec![Value::Integer(100), Value::Text("Alice".into())]); + + let delta_pair = DeltaPair::new(left_delta, right_delta); + let result = pager + .io + .block(|| join.commit(delta_pair.clone(), &mut cursors)) + .unwrap(); + + assert_eq!(result.changes.len(), 3, "Should have 3 initial joins"); + + // SECOND COMMIT: Update the customer name (affects all 3 joins) + let mut right_delta2 = Delta::new(); + // Delete old customer record + right_delta2.delete(1, vec![Value::Integer(100), Value::Text("Alice".into())]); + // Insert updated customer record + right_delta2.insert( + 1, + vec![Value::Integer(100), Value::Text("Alice Smith".into())], + ); + + let delta_pair2 = DeltaPair::new(Delta::new(), right_delta2); + let result2 = pager + .io + .block(|| join.commit(delta_pair2.clone(), &mut cursors)) + .unwrap(); + + // Should produce 3 deletions and 3 insertions (one for each order) + assert_eq!(result2.changes.len(), 6, + 
"Should produce 6 changes (3 deletions + 3 insertions) when updating customer with 3 orders"); + + let deletions: Vec<_> = result2.changes.iter().filter(|(_, w)| *w == -1).collect(); + let insertions: Vec<_> = result2.changes.iter().filter(|(_, w)| *w == 1).collect(); + + assert_eq!(deletions.len(), 3, "Should have 3 deletions"); + assert_eq!(insertions.len(), 3, "Should have 3 insertions"); + + // Check all deletions have old name + for (row, _) in &deletions { + assert_eq!( + row.values[4], + Value::Text("Alice".into()), + "Deletions should have old name" + ); + } + + // Check all insertions have new name + for (row, _) in &insertions { + assert_eq!( + row.values[4], + Value::Text("Alice Smith".into()), + "Insertions should have new name" + ); + } + + // Verify we still have all three order IDs in the insertions + let mut order_ids = std::collections::HashSet::new(); + for (row, _) in &insertions { + if let Value::Integer(order_id) = &row.values[1] { + order_ids.insert(*order_id); + } + } + assert_eq!( + order_ids.len(), + 3, + "Should still have all 3 order IDs after update" + ); + assert!(order_ids.contains(&1001)); + assert!(order_ids.contains(&1002)); + assert!(order_ids.contains(&1003)); + } + + #[test] + fn test_join_operator_weight_accumulation_complex() { + // Test complex weight accumulation with multiple identical rows + let (pager, table_page_id, index_page_id) = create_test_pager(); + let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_page_id, 10); + let index_def = create_dbsp_state_index(index_page_id); + let index_cursor = + BTreeCursor::new_index(None, pager.clone(), index_page_id, &index_def, 10); + let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); + + let mut join = JoinOperator::new( + 1, // operator_id + JoinType::Inner, + vec![0], // Join on first column + vec![0], + vec!["key".to_string(), "val_left".to_string()], + vec!["key".to_string(), "val_right".to_string()], + ) + .unwrap(); + + // FIRST COMMIT: Add identical rows multiple times (simulating duplicates) + let mut left_delta = Delta::new(); + // Same key-value pair inserted 3 times with different rowids + left_delta.insert(1, vec![Value::Integer(10), Value::Text("A".into())]); + left_delta.insert(2, vec![Value::Integer(10), Value::Text("A".into())]); + left_delta.insert(3, vec![Value::Integer(10), Value::Text("A".into())]); + + let mut right_delta = Delta::new(); + // Same key-value pair inserted 2 times + right_delta.insert(4, vec![Value::Integer(10), Value::Text("B".into())]); + right_delta.insert(5, vec![Value::Integer(10), Value::Text("B".into())]); + + let delta_pair = DeltaPair::new(left_delta, right_delta); + let result = pager + .io + .block(|| join.commit(delta_pair.clone(), &mut cursors)) + .unwrap(); + + // Should produce 3 × 2 = 6 join results (cartesian product) + assert_eq!( + result.changes.len(), + 6, + "Should produce 6 joins (3 left rows × 2 right rows)" + ); + + // All should have weight 1 + for (_, weight) in &result.changes { + assert_eq!(*weight, 1); + } + + // SECOND COMMIT: Delete one instance from left + let mut left_delta2 = Delta::new(); + left_delta2.delete(2, vec![Value::Integer(10), Value::Text("A".into())]); + + let delta_pair2 = DeltaPair::new(left_delta2, Delta::new()); + let result2 = pager + .io + .block(|| join.commit(delta_pair2.clone(), &mut cursors)) + .unwrap(); + + // Should produce 2 retractions (1 deleted left row × 2 right rows) + assert_eq!( + result2.changes.len(), + 2, + "Should produce 2 retractions when deleting 1 of 3 
identical left rows" + ); + + for (_, weight) in &result2.changes { + assert_eq!(*weight, -1, "Should be retractions"); + } + } + + #[test] + fn test_join_produces_all_expected_results() { + // Test that a join produces ALL expected output rows + // This reproduces the issue where only 1 of 3 expected rows appears in the final result + + // Create a join operator similar to: SELECT u.name, o.quantity FROM users u JOIN orders o ON u.id = o.user_id + let mut join = JoinOperator::new( + 0, + JoinType::Inner, + vec![0], // Join on first column (id) + vec![0], // Join on first column (user_id) + vec!["id".to_string(), "name".to_string()], + vec![ + "user_id".to_string(), + "product_id".to_string(), + "quantity".to_string(), + ], + ) + .unwrap(); + + // Create test data matching the example that fails: + // users: (1, 'Alice'), (2, 'Bob') + // orders: (1, 5), (1, 3), (2, 7) -- user_id, quantity + let left_delta = Delta { + changes: vec![ + ( + HashableRow::new(1, vec![Value::Integer(1), Value::Text(Text::from("Alice"))]), + 1, + ), + ( + HashableRow::new(2, vec![Value::Integer(2), Value::Text(Text::from("Bob"))]), + 1, + ), + ], + }; + + // Orders: Alice has 2 orders, Bob has 1 + let right_delta = Delta { + changes: vec![ + ( + HashableRow::new( + 1, + vec![Value::Integer(1), Value::Integer(100), Value::Integer(5)], + ), + 1, + ), + ( + HashableRow::new( + 2, + vec![Value::Integer(1), Value::Integer(101), Value::Integer(3)], + ), + 1, + ), + ( + HashableRow::new( + 3, + vec![Value::Integer(2), Value::Integer(100), Value::Integer(7)], + ), + 1, + ), + ], + }; + + // Evaluate the join + let delta_pair = DeltaPair::new(left_delta, right_delta); + let mut state = EvalState::Init { deltas: delta_pair }; + + let (pager, table_root, index_root) = create_test_pager(); + let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root, 5); + let index_def = create_dbsp_state_index(index_root); + let index_cursor = BTreeCursor::new_index(None, pager.clone(), index_root, &index_def, 4); + let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); + + let result = pager + .io + .block(|| join.eval(&mut state, &mut cursors)) + .unwrap(); + + // Should produce 3 results: Alice with 2 orders, Bob with 1 order + assert_eq!( + result.changes.len(), + 3, + "Should produce 3 joined rows (Alice×2 + Bob×1)" + ); + + // Verify the actual content of the results + let mut expected_results = std::collections::HashSet::new(); + // Expected: (Alice, 5), (Alice, 3), (Bob, 7) + expected_results.insert(("Alice".to_string(), 5)); + expected_results.insert(("Alice".to_string(), 3)); + expected_results.insert(("Bob".to_string(), 7)); + + let mut actual_results = std::collections::HashSet::new(); + for (row, weight) in &result.changes { + assert_eq!(*weight, 1, "All results should have weight 1"); + + // Extract name (column 1 from left) and quantity (column 3 from right) + let name = match &row.values[1] { + Value::Text(t) => t.as_str().to_string(), + _ => panic!("Expected text value for name"), + }; + let quantity = match &row.values[4] { + Value::Integer(q) => *q, + _ => panic!("Expected integer value for quantity"), + }; + + actual_results.insert((name, quantity)); + } + + assert_eq!( + expected_results, actual_results, + "Join should produce all expected results. 
Expected: {expected_results:?}, Got: {actual_results:?}", + ); + + // Also verify that rowids are unique (this is important for btree storage) + let mut seen_rowids = std::collections::HashSet::new(); + for (row, _) in &result.changes { + let was_new = seen_rowids.insert(row.rowid); + assert!(was_new, "Duplicate rowid found: {}. This would cause rows to overwrite each other in btree storage!", row.rowid); + } + } } diff --git a/core/incremental/persistence.rs b/core/incremental/persistence.rs index eca26cd7c..5cf41b94a 100644 --- a/core/incremental/persistence.rs +++ b/core/incremental/persistence.rs @@ -1,12 +1,7 @@ -use crate::incremental::dbsp::HashableRow; -use crate::incremental::operator::{ - generate_storage_id, AggColumnInfo, AggregateFunction, AggregateOperator, AggregateState, - DbspStateCursors, MinMaxDeltas, AGG_TYPE_MINMAX, -}; +use crate::incremental::operator::{AggregateFunction, AggregateState, DbspStateCursors}; use crate::storage::btree::{BTreeCursor, BTreeKey}; -use crate::types::{IOResult, ImmutableRecord, RefValue, SeekKey, SeekOp, SeekResult}; +use crate::types::{IOResult, ImmutableRecord, SeekKey, SeekOp, SeekResult}; use crate::{return_if_io, LimboError, Result, Value}; -use std::collections::{HashMap, HashSet}; #[derive(Debug, Default)] pub enum ReadRecord { @@ -290,672 +285,3 @@ impl WriteRow { } } } - -/// State machine for recomputing MIN/MAX values after deletion -#[derive(Debug)] -pub enum RecomputeMinMax { - ProcessElements { - /// Current column being processed - current_column_idx: usize, - /// Columns to process (combined MIN and MAX) - columns_to_process: Vec<(String, String, bool)>, // (group_key, column_name, is_min) - /// MIN/MAX deltas for checking values and weights - min_max_deltas: MinMaxDeltas, - }, - Scan { - /// Columns still to process - columns_to_process: Vec<(String, String, bool)>, - /// Current index in columns_to_process (will resume from here) - current_column_idx: usize, - /// MIN/MAX deltas for checking values and weights - min_max_deltas: MinMaxDeltas, - /// Current group key being processed - group_key: String, - /// Current column name being processed - column_name: String, - /// Whether we're looking for MIN (true) or MAX (false) - is_min: bool, - /// The scan state machine for finding the new MIN/MAX - scan_state: Box, - }, - Done, -} - -impl RecomputeMinMax { - pub fn new( - min_max_deltas: MinMaxDeltas, - existing_groups: &HashMap, - operator: &AggregateOperator, - ) -> Self { - let mut groups_to_check: HashSet<(String, String, bool)> = HashSet::new(); - - // Remember the min_max_deltas are essentially just the only column that is affected by - // this min/max, in delta (actually ZSet - consolidated delta) format. This makes it easier - // for us to consume it in here. - // - // The most challenging case is the case where there is a retraction, since we need to go - // back to the index. 
- for (group_key_str, values) in &min_max_deltas { - for ((col_name, hashable_row), weight) in values { - let col_info = operator.column_min_max.get(col_name); - - let value = &hashable_row.values[0]; - - if *weight < 0 { - // Deletion detected - check if it's the current MIN/MAX - if let Some(state) = existing_groups.get(group_key_str) { - // Check for MIN - if let Some(current_min) = state.mins.get(col_name) { - if current_min == value { - groups_to_check.insert(( - group_key_str.clone(), - col_name.clone(), - true, - )); - } - } - // Check for MAX - if let Some(current_max) = state.maxs.get(col_name) { - if current_max == value { - groups_to_check.insert(( - group_key_str.clone(), - col_name.clone(), - false, - )); - } - } - } - } else if *weight > 0 { - // If it is not found in the existing groups, then we only need to care - // about this if this is a new record being inserted - if let Some(info) = col_info { - if info.has_min { - groups_to_check.insert((group_key_str.clone(), col_name.clone(), true)); - } - if info.has_max { - groups_to_check.insert(( - group_key_str.clone(), - col_name.clone(), - false, - )); - } - } - } - } - } - - if groups_to_check.is_empty() { - // No recomputation or initialization needed - Self::Done - } else { - // Convert HashSet to Vec for indexed processing - let groups_to_check_vec: Vec<_> = groups_to_check.into_iter().collect(); - Self::ProcessElements { - current_column_idx: 0, - columns_to_process: groups_to_check_vec, - min_max_deltas, - } - } - } - - pub fn process( - &mut self, - existing_groups: &mut HashMap, - operator: &AggregateOperator, - cursors: &mut DbspStateCursors, - ) -> Result> { - loop { - match self { - RecomputeMinMax::ProcessElements { - current_column_idx, - columns_to_process, - min_max_deltas, - } => { - if *current_column_idx >= columns_to_process.len() { - *self = RecomputeMinMax::Done; - return Ok(IOResult::Done(())); - } - - let (group_key, column_name, is_min) = - columns_to_process[*current_column_idx].clone(); - - // Get column index from pre-computed info - let column_index = operator - .column_min_max - .get(&column_name) - .map(|info| info.index) - .unwrap(); // Should always exist since we're processing known columns - - // Get current value from existing state - let current_value = existing_groups.get(&group_key).and_then(|state| { - if is_min { - state.mins.get(&column_name).cloned() - } else { - state.maxs.get(&column_name).cloned() - } - }); - - // Create storage keys for index lookup - let storage_id = - generate_storage_id(operator.operator_id, column_index, AGG_TYPE_MINMAX); - let zset_id = operator.generate_group_rowid(&group_key); - - // Get the values for this group from min_max_deltas - let group_values = min_max_deltas.get(&group_key).cloned().unwrap_or_default(); - - let columns_to_process = std::mem::take(columns_to_process); - let min_max_deltas = std::mem::take(min_max_deltas); - - let scan_state = if is_min { - Box::new(ScanState::new_for_min( - current_value, - group_key.clone(), - column_name.clone(), - storage_id, - zset_id, - group_values, - )) - } else { - Box::new(ScanState::new_for_max( - current_value, - group_key.clone(), - column_name.clone(), - storage_id, - zset_id, - group_values, - )) - }; - - *self = RecomputeMinMax::Scan { - columns_to_process, - current_column_idx: *current_column_idx, - min_max_deltas, - group_key, - column_name, - is_min, - scan_state, - }; - } - RecomputeMinMax::Scan { - columns_to_process, - current_column_idx, - min_max_deltas, - group_key, - column_name, - is_min, 
- scan_state, - } => { - // Find new value using the scan state machine - let new_value = return_if_io!(scan_state.find_new_value(cursors)); - - // Update the state with new value (create if doesn't exist) - let state = existing_groups.entry(group_key.clone()).or_default(); - - if *is_min { - if let Some(min_val) = new_value { - state.mins.insert(column_name.clone(), min_val); - } else { - state.mins.remove(column_name); - } - } else if let Some(max_val) = new_value { - state.maxs.insert(column_name.clone(), max_val); - } else { - state.maxs.remove(column_name); - } - - // Move to next column - let min_max_deltas = std::mem::take(min_max_deltas); - let columns_to_process = std::mem::take(columns_to_process); - *self = RecomputeMinMax::ProcessElements { - current_column_idx: *current_column_idx + 1, - columns_to_process, - min_max_deltas, - }; - } - RecomputeMinMax::Done => { - return Ok(IOResult::Done(())); - } - } - } - } -} - -/// State machine for scanning through the index to find new MIN/MAX values -#[derive(Debug)] -pub enum ScanState { - CheckCandidate { - /// Current candidate value for MIN/MAX - candidate: Option, - /// Group key being processed - group_key: String, - /// Column name being processed - column_name: String, - /// Storage ID for the index seek - storage_id: i64, - /// ZSet ID for the group - zset_id: i64, - /// Group values from MinMaxDeltas: (column_name, HashableRow) -> weight - group_values: HashMap<(String, HashableRow), isize>, - /// Whether we're looking for MIN (true) or MAX (false) - is_min: bool, - }, - FetchNextCandidate { - /// Current candidate to seek past - current_candidate: Value, - /// Group key being processed - group_key: String, - /// Column name being processed - column_name: String, - /// Storage ID for the index seek - storage_id: i64, - /// ZSet ID for the group - zset_id: i64, - /// Group values from MinMaxDeltas: (column_name, HashableRow) -> weight - group_values: HashMap<(String, HashableRow), isize>, - /// Whether we're looking for MIN (true) or MAX (false) - is_min: bool, - }, - Done { - /// The final MIN/MAX value found - result: Option, - }, -} - -impl ScanState { - pub fn new_for_min( - current_min: Option, - group_key: String, - column_name: String, - storage_id: i64, - zset_id: i64, - group_values: HashMap<(String, HashableRow), isize>, - ) -> Self { - Self::CheckCandidate { - candidate: current_min, - group_key, - column_name, - storage_id, - zset_id, - group_values, - is_min: true, - } - } - - // Extract a new candidate from the index. It is possible that, when searching, - // we end up going into a different operator altogether. 
That means we have - // exhausted this operator (or group) entirely, and no good candidate was found - fn extract_new_candidate( - cursors: &mut DbspStateCursors, - index_record: &ImmutableRecord, - seek_op: SeekOp, - storage_id: i64, - zset_id: i64, - ) -> Result>> { - let seek_result = return_if_io!(cursors - .index_cursor - .seek(SeekKey::IndexKey(index_record), seek_op)); - if !matches!(seek_result, SeekResult::Found) { - return Ok(IOResult::Done(None)); - } - - let record = return_if_io!(cursors.index_cursor.record()).ok_or_else(|| { - LimboError::InternalError( - "Record found on the cursor, but could not be read".to_string(), - ) - })?; - - let values = record.get_values(); - if values.len() < 3 { - return Ok(IOResult::Done(None)); - } - - let Some(rec_storage_id) = values.first() else { - return Ok(IOResult::Done(None)); - }; - let Some(rec_zset_id) = values.get(1) else { - return Ok(IOResult::Done(None)); - }; - - // Check if we're still in the same group - if let (RefValue::Integer(rec_sid), RefValue::Integer(rec_zid)) = - (rec_storage_id, rec_zset_id) - { - if *rec_sid != storage_id || *rec_zid != zset_id { - return Ok(IOResult::Done(None)); - } - } else { - return Ok(IOResult::Done(None)); - } - - // Get the value (3rd element) - Ok(IOResult::Done(values.get(2).map(|v| v.to_owned()))) - } - - pub fn new_for_max( - current_max: Option, - group_key: String, - column_name: String, - storage_id: i64, - zset_id: i64, - group_values: HashMap<(String, HashableRow), isize>, - ) -> Self { - Self::CheckCandidate { - candidate: current_max, - group_key, - column_name, - storage_id, - zset_id, - group_values, - is_min: false, - } - } - - pub fn find_new_value( - &mut self, - cursors: &mut DbspStateCursors, - ) -> Result>> { - loop { - match self { - ScanState::CheckCandidate { - candidate, - group_key, - column_name, - storage_id, - zset_id, - group_values, - is_min, - } => { - // First, check if we have a candidate - if let Some(cand_val) = candidate { - // Check if the candidate is retracted (weight <= 0) - // Create a HashableRow to look up the weight - let hashable_cand = HashableRow::new(0, vec![cand_val.clone()]); - let key = (column_name.clone(), hashable_cand); - let is_retracted = - group_values.get(&key).is_some_and(|weight| *weight <= 0); - - if is_retracted { - // Candidate is retracted, need to fetch next from index - *self = ScanState::FetchNextCandidate { - current_candidate: cand_val.clone(), - group_key: std::mem::take(group_key), - column_name: std::mem::take(column_name), - storage_id: *storage_id, - zset_id: *zset_id, - group_values: std::mem::take(group_values), - is_min: *is_min, - }; - continue; - } - } - - // Candidate is valid or we have no candidate - // Now find the best value from insertions in group_values - let mut best_from_zset = None; - for ((col, hashable_val), weight) in group_values.iter() { - if col == column_name && *weight > 0 { - let value = &hashable_val.values[0]; - // Skip NULL values - they don't participate in MIN/MAX - if value == &Value::Null { - continue; - } - // This is an insertion for our column - if let Some(ref current_best) = best_from_zset { - if *is_min { - if value.cmp(current_best) == std::cmp::Ordering::Less { - best_from_zset = Some(value.clone()); - } - } else if value.cmp(current_best) == std::cmp::Ordering::Greater { - best_from_zset = Some(value.clone()); - } - } else { - best_from_zset = Some(value.clone()); - } - } - } - - // Compare candidate with best from ZSet, filtering out NULLs - let result = match (&candidate, 
&best_from_zset) { - (Some(cand), Some(zset_val)) if cand != &Value::Null => { - if *is_min { - if zset_val.cmp(cand) == std::cmp::Ordering::Less { - Some(zset_val.clone()) - } else { - Some(cand.clone()) - } - } else if zset_val.cmp(cand) == std::cmp::Ordering::Greater { - Some(zset_val.clone()) - } else { - Some(cand.clone()) - } - } - (Some(cand), None) if cand != &Value::Null => Some(cand.clone()), - (None, Some(zset_val)) => Some(zset_val.clone()), - (Some(cand), Some(_)) if cand == &Value::Null => best_from_zset, - _ => None, - }; - - *self = ScanState::Done { result }; - } - - ScanState::FetchNextCandidate { - current_candidate, - group_key, - column_name, - storage_id, - zset_id, - group_values, - is_min, - } => { - // Seek to the next value in the index - let index_key = vec![ - Value::Integer(*storage_id), - Value::Integer(*zset_id), - current_candidate.clone(), - ]; - let index_record = ImmutableRecord::from_values(&index_key, index_key.len()); - - let seek_op = if *is_min { - SeekOp::GT // For MIN, seek greater than current - } else { - SeekOp::LT // For MAX, seek less than current - }; - - let new_candidate = return_if_io!(Self::extract_new_candidate( - cursors, - &index_record, - seek_op, - *storage_id, - *zset_id - )); - - *self = ScanState::CheckCandidate { - candidate: new_candidate, - group_key: std::mem::take(group_key), - column_name: std::mem::take(column_name), - storage_id: *storage_id, - zset_id: *zset_id, - group_values: std::mem::take(group_values), - is_min: *is_min, - }; - } - - ScanState::Done { result } => { - return Ok(IOResult::Done(result.clone())); - } - } - } - } -} - -/// State machine for persisting Min/Max values to storage -#[derive(Debug)] -pub enum MinMaxPersistState { - Init { - min_max_deltas: MinMaxDeltas, - group_keys: Vec, - }, - ProcessGroup { - min_max_deltas: MinMaxDeltas, - group_keys: Vec, - group_idx: usize, - value_idx: usize, - }, - WriteValue { - min_max_deltas: MinMaxDeltas, - group_keys: Vec, - group_idx: usize, - value_idx: usize, - value: Value, - column_name: String, - weight: isize, - write_row: WriteRow, - }, - Done, -} - -impl MinMaxPersistState { - pub fn new(min_max_deltas: MinMaxDeltas) -> Self { - let group_keys: Vec = min_max_deltas.keys().cloned().collect(); - Self::Init { - min_max_deltas, - group_keys, - } - } - - pub fn persist_min_max( - &mut self, - operator_id: usize, - column_min_max: &HashMap, - cursors: &mut DbspStateCursors, - generate_group_rowid: impl Fn(&str) -> i64, - ) -> Result> { - loop { - match self { - MinMaxPersistState::Init { - min_max_deltas, - group_keys, - } => { - let min_max_deltas = std::mem::take(min_max_deltas); - let group_keys = std::mem::take(group_keys); - *self = MinMaxPersistState::ProcessGroup { - min_max_deltas, - group_keys, - group_idx: 0, - value_idx: 0, - }; - } - MinMaxPersistState::ProcessGroup { - min_max_deltas, - group_keys, - group_idx, - value_idx, - } => { - // Check if we're past all groups - if *group_idx >= group_keys.len() { - *self = MinMaxPersistState::Done; - continue; - } - - let group_key_str = &group_keys[*group_idx]; - let values = &min_max_deltas[group_key_str]; // This should always exist - - // Convert HashMap to Vec for indexed access - let values_vec: Vec<_> = values.iter().collect(); - - // Check if we have more values in current group - if *value_idx >= values_vec.len() { - *group_idx += 1; - *value_idx = 0; - // Continue to check if we're past all groups now - continue; - } - - // Process current value and extract what we need before taking ownership - 
let ((column_name, hashable_row), weight) = values_vec[*value_idx]; - let column_name = column_name.clone(); - let value = hashable_row.values[0].clone(); // Extract the Value from HashableRow - let weight = *weight; - - let min_max_deltas = std::mem::take(min_max_deltas); - let group_keys = std::mem::take(group_keys); - *self = MinMaxPersistState::WriteValue { - min_max_deltas, - group_keys, - group_idx: *group_idx, - value_idx: *value_idx, - column_name, - value, - weight, - write_row: WriteRow::new(), - }; - } - MinMaxPersistState::WriteValue { - min_max_deltas, - group_keys, - group_idx, - value_idx, - value, - column_name, - weight, - write_row, - } => { - // Should have exited in the previous state - assert!(*group_idx < group_keys.len()); - - let group_key_str = &group_keys[*group_idx]; - - // Get the column index from the pre-computed map - let column_info = column_min_max - .get(&*column_name) - .expect("Column should exist in column_min_max map"); - let column_index = column_info.index; - - // Build the key components for MinMax storage using new encoding - let storage_id = - generate_storage_id(operator_id, column_index, AGG_TYPE_MINMAX); - let zset_id = generate_group_rowid(group_key_str); - - // element_id is the actual value for Min/Max - let element_id_val = value.clone(); - - // Create index key - let index_key = vec![ - Value::Integer(storage_id), - Value::Integer(zset_id), - element_id_val.clone(), - ]; - - // Record values (operator_id, zset_id, element_id, unused_placeholder) - // For MIN/MAX, the element_id IS the value, so we use NULL for the 4th column - let record_values = vec![ - Value::Integer(storage_id), - Value::Integer(zset_id), - element_id_val.clone(), - Value::Null, // Placeholder - not used for MIN/MAX - ]; - - return_if_io!(write_row.write_row( - cursors, - index_key.clone(), - record_values, - *weight - )); - - // Move to next value - let min_max_deltas = std::mem::take(min_max_deltas); - let group_keys = std::mem::take(group_keys); - *self = MinMaxPersistState::ProcessGroup { - min_max_deltas, - group_keys, - group_idx: *group_idx, - value_idx: *value_idx + 1, - }; - } - MinMaxPersistState::Done => { - return Ok(IOResult::Done(())); - } - } - } - } -} diff --git a/core/incremental/project_operator.rs b/core/incremental/project_operator.rs new file mode 100644 index 000000000..b1d9fc9ed --- /dev/null +++ b/core/incremental/project_operator.rs @@ -0,0 +1,168 @@ +// Project operator for DBSP-style incremental computation +// This operator projects/transforms columns in a relational stream + +use crate::incremental::dbsp::{Delta, DeltaPair, HashableRow}; +use crate::incremental::expr_compiler::CompiledExpression; +use crate::incremental::operator::{ + ComputationTracker, DbspStateCursors, EvalState, IncrementalOperator, +}; +use crate::types::IOResult; +use crate::{Connection, Database, Result, Value}; +use std::sync::{Arc, Mutex}; + +#[derive(Debug, Clone)] +pub struct ProjectColumn { + /// Compiled expression (handles both trivial columns and complex expressions) + pub compiled: CompiledExpression, +} + +/// Project operator - selects/transforms columns +#[derive(Clone)] +pub struct ProjectOperator { + columns: Vec, + input_column_names: Vec, + output_column_names: Vec, + tracker: Option>>, + // Internal in-memory connection for expression evaluation + // Programs are very dependent on having a connection, so give it one. + // + // We could in theory pass the current connection, but there are a host of problems with that. 
+ // For example: during a write transaction, where views are usually updated, we have autocommit + // on. When the program we are executing calls Halt, it will try to commit the current + // transaction, which is absolutely incorrect. + // + // There are other ways to solve this, but a read-only connection to an empty in-memory + // database gives us the closest environment we need to execute expressions. + internal_conn: Arc, +} + +impl std::fmt::Debug for ProjectOperator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ProjectOperator") + .field("columns", &self.columns) + .field("input_column_names", &self.input_column_names) + .field("output_column_names", &self.output_column_names) + .finish() + } +} + +impl ProjectOperator { + /// Create a ProjectOperator from pre-compiled expressions + pub fn from_compiled( + compiled_exprs: Vec, + aliases: Vec>, + input_column_names: Vec, + output_column_names: Vec, + ) -> crate::Result { + // Set up internal connection for expression evaluation + let io = Arc::new(crate::MemoryIO::new()); + let db = Database::open_file( + io, ":memory:", false, // no MVCC needed for expression evaluation + false, // no indexes needed + )?; + let internal_conn = db.connect()?; + // Set to read-only mode and disable auto-commit since we're only evaluating expressions + internal_conn.query_only.set(true); + internal_conn.auto_commit.set(false); + + // Create ProjectColumn structs from compiled expressions + let columns: Vec = compiled_exprs + .into_iter() + .zip(aliases) + .map(|(compiled, _alias)| ProjectColumn { compiled }) + .collect(); + + Ok(Self { + columns, + input_column_names, + output_column_names, + tracker: None, + internal_conn, + }) + } + + fn project_values(&self, values: &[Value]) -> Vec { + let mut output = Vec::new(); + + for col in &self.columns { + // Use the internal connection's pager for expression evaluation + let internal_pager = self.internal_conn.pager.borrow().clone(); + + // Execute the compiled expression (handles both columns and complex expressions) + let result = col + .compiled + .execute(values, internal_pager) + .expect("Failed to execute compiled expression for the Project operator"); + output.push(result); + } + + output + } +} + +impl IncrementalOperator for ProjectOperator { + fn eval( + &mut self, + state: &mut EvalState, + _cursors: &mut DbspStateCursors, + ) -> Result> { + let delta = match state { + EvalState::Init { deltas } => { + // Project operators only use left_delta, right_delta must be empty + assert!( + deltas.right.is_empty(), + "ProjectOperator expects right_delta to be empty" + ); + std::mem::take(&mut deltas.left) + } + _ => unreachable!( + "ProjectOperator doesn't execute the state machine. 
Should be in Init state" + ), + }; + + let mut output_delta = Delta::new(); + + for (row, weight) in delta.changes { + if let Some(tracker) = &self.tracker { + tracker.lock().unwrap().record_project(); + } + + let projected = self.project_values(&row.values); + let projected_row = HashableRow::new(row.rowid, projected); + output_delta.changes.push((projected_row, weight)); + } + + *state = EvalState::Done; + Ok(IOResult::Done(output_delta)) + } + + fn commit( + &mut self, + deltas: DeltaPair, + _cursors: &mut DbspStateCursors, + ) -> Result> { + // Project operator only uses left delta, right must be empty + assert!( + deltas.right.is_empty(), + "ProjectOperator expects right delta to be empty in commit" + ); + + let mut output_delta = Delta::new(); + + // Commit the delta to our internal state and build output + for (row, weight) in &deltas.left.changes { + if let Some(tracker) = &self.tracker { + tracker.lock().unwrap().record_project(); + } + let projected = self.project_values(&row.values); + let projected_row = HashableRow::new(row.rowid, projected); + output_delta.changes.push((projected_row, *weight)); + } + + Ok(crate::types::IOResult::Done(output_delta)) + } + + fn set_tracker(&mut self, tracker: Arc>) { + self.tracker = Some(tracker); + } +} diff --git a/core/incremental/view.rs b/core/incremental/view.rs index 9a200c830..fd7b3988a 100644 --- a/core/incremental/view.rs +++ b/core/incremental/view.rs @@ -1,11 +1,11 @@ use super::compiler::{DbspCircuit, DbspCompiler, DeltaSet}; use super::dbsp::Delta; -use super::operator::{ComputationTracker, FilterPredicate}; -use crate::schema::{BTreeTable, Column, Schema}; +use super::operator::ComputationTracker; +use crate::schema::{BTreeTable, Schema}; use crate::storage::btree::BTreeCursor; use crate::translate::logical::LogicalPlanBuilder; use crate::types::{IOResult, Value}; -use crate::util::extract_view_columns; +use crate::util::{extract_view_columns, ViewColumnSchema}; use crate::{return_if_io, LimboError, Pager, Result, Statement}; use std::cell::RefCell; use std::collections::HashMap; @@ -22,8 +22,15 @@ use turso_parser::{ pub enum PopulateState { /// Initial state - need to prepare the query Start, + /// All tables that need to be populated + ProcessingAllTables { + queries: Vec, + current_idx: usize, + }, /// Actively processing rows from the query - Processing { + ProcessingOneTable { + queries: Vec, + current_idx: usize, stmt: Box, rows_processed: usize, /// If we're in the middle of processing a row (merge_delta returned I/O) @@ -38,14 +45,26 @@ impl fmt::Debug for PopulateState { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { PopulateState::Start => write!(f, "Start"), - PopulateState::Processing { + PopulateState::ProcessingAllTables { + current_idx, + queries, + } => f + .debug_struct("ProcessingAllTables") + .field("current_idx", current_idx) + .field("num_queries", &queries.len()) + .finish(), + PopulateState::ProcessingOneTable { + current_idx, rows_processed, pending_row, + queries, .. 
} => f - .debug_struct("Processing") + .debug_struct("ProcessingOneTable") + .field("current_idx", current_idx) .field("rows_processed", rows_processed) .field("has_pending", &pending_row.is_some()) + .field("total_queries", &queries.len()) .finish(), PopulateState::Done => write!(f, "Done"), } @@ -163,8 +182,6 @@ impl AllViewsTxState { #[derive(Debug)] pub struct IncrementalView { name: String, - // WHERE clause predicate for filtering (kept for compatibility) - pub where_predicate: FilterPredicate, // The SELECT statement that defines how to transform input data pub select_stmt: ast::Select, @@ -173,8 +190,13 @@ pub struct IncrementalView { // All tables referenced by this view (from FROM clause and JOINs) referenced_tables: Vec>, - // The view's output columns with their types - pub columns: Vec, + // Mapping from table aliases to actual table names (e.g., "c" -> "customers") + table_aliases: HashMap, + // Mapping from table name to fully qualified name (e.g., "customers" -> "main.customers") + // This preserves database qualification from the original query + qualified_table_names: HashMap, + // The view's column schema with table relationships + pub column_schema: ViewColumnSchema, // State machine for population populate_state: PopulateState, // Computation tracker for statistics @@ -186,20 +208,6 @@ pub struct IncrementalView { } impl IncrementalView { - /// Validate that a CREATE MATERIALIZED VIEW statement can be handled by IncrementalView - /// This should be called early, before updating sqlite_master - pub fn can_create_view(select: &ast::Select) -> Result<()> { - // Check for JOINs - let (join_tables, join_condition) = Self::extract_join_info(select); - if join_tables.is_some() || join_condition.is_some() { - return Err(LimboError::ParseError( - "JOINs in views are not yet supported".to_string(), - )); - } - - Ok(()) - } - /// Try to compile the SELECT statement into a DBSP circuit fn try_compile_circuit( select: &ast::Select, @@ -227,11 +235,16 @@ impl IncrementalView { /// Get an iterator over column names, using enumerated naming for unnamed columns pub fn column_names(&self) -> impl Iterator + '_ { - self.columns.iter().enumerate().map(|(i, col)| { - col.name - .clone() - .unwrap_or_else(|| format!("column{}", i + 1)) - }) + self.column_schema + .columns + .iter() + .enumerate() + .map(|(i, vc)| { + vc.column + .name + .clone() + .unwrap_or_else(|| format!("column{}", i + 1)) + }) } /// Check if this view has the same SQL definition as the provided SQL string @@ -251,24 +264,9 @@ impl IncrementalView { pub fn validate_and_extract_columns( select: &ast::Select, schema: &Schema, - ) -> Result> { - // For now, just extract columns from a simple select - // This will need to be expanded to handle joins, aggregates, etc. 
- - // Get the base table name - let base_table_name = Self::extract_base_table(select).ok_or_else(|| { - LimboError::ParseError("Cannot extract base table from SELECT".to_string()) - })?; - - // Get the table from schema - let table = schema - .get_table(&base_table_name) - .and_then(|t| t.btree()) - .ok_or_else(|| LimboError::ParseError(format!("Table {base_table_name} not found")))?; - - // For now, return all columns from the base table - // In the future, this should parse the select list and handle projections - Ok(table.columns.clone()) + ) -> Result { + // Use the shared function to extract columns with full table context + extract_view_columns(select, schema) } pub fn from_sql( @@ -311,27 +309,20 @@ impl IncrementalView { ) -> Result { let name = view_name.name.as_str().to_string(); - let where_predicate = FilterPredicate::from_select(&select)?; - // Extract output columns using the shared function - let view_columns = extract_view_columns(&select, schema); + let column_schema = extract_view_columns(&select, schema)?; - let (join_tables, join_condition) = Self::extract_join_info(&select); - if join_tables.is_some() || join_condition.is_some() { - return Err(LimboError::ParseError( - "JOINs in views are not yet supported".to_string(), - )); - } - - // Get all tables from FROM clause and JOINs - let referenced_tables = Self::extract_all_tables(&select, schema)?; + // Get all tables from FROM clause and JOINs, along with their aliases + let (referenced_tables, table_aliases, qualified_table_names) = + Self::extract_all_tables(&select, schema)?; Self::new( name, - where_predicate, select.clone(), referenced_tables, - view_columns, + table_aliases, + qualified_table_names, + column_schema, schema, main_data_root, internal_state_root, @@ -342,10 +333,11 @@ impl IncrementalView { #[allow(clippy::too_many_arguments)] pub fn new( name: String, - where_predicate: FilterPredicate, select_stmt: ast::Select, referenced_tables: Vec>, - columns: Vec, + table_aliases: HashMap, + qualified_table_names: HashMap, + column_schema: ViewColumnSchema, schema: &Schema, main_data_root: usize, internal_state_root: usize, @@ -365,11 +357,12 @@ impl IncrementalView { Ok(Self { name, - where_predicate, select_stmt, circuit, referenced_tables, - columns, + table_aliases, + qualified_table_names, + column_schema, populate_state: PopulateState::Start, tracker, root_page: main_data_root, @@ -412,9 +405,22 @@ impl IncrementalView { self.referenced_tables.clone() } - /// Extract all table names from a SELECT statement (including JOINs) - fn extract_all_tables(select: &ast::Select, schema: &Schema) -> Result>> { + /// Extract all tables and their aliases from the SELECT statement + /// Returns a tuple of (tables, alias_map, qualified_names) + /// where alias_map is alias -> table_name + /// and qualified_names is table_name -> fully_qualified_name + #[allow(clippy::type_complexity)] + fn extract_all_tables( + select: &ast::Select, + schema: &Schema, + ) -> Result<( + Vec>, + HashMap, + HashMap, + )> { let mut tables = Vec::new(); + let mut aliases = HashMap::new(); + let mut qualified_names = HashMap::new(); if let ast::OneSelect::Select { from: Some(ref from), @@ -422,10 +428,24 @@ impl IncrementalView { } = select.body.select { // Get the main table from FROM clause - if let ast::SelectTable::Table(name, _, _) = from.select.as_ref() { + if let ast::SelectTable::Table(name, alias, _) = from.select.as_ref() { let table_name = name.name.as_str(); + + // Build the fully qualified name + let qualified_name = if let 
Some(ref db) = name.db_name { + format!("{db}.{table_name}") + } else { + table_name.to_string() + }; + if let Some(table) = schema.get_btree_table(table_name) { tables.push(table.clone()); + qualified_names.insert(table_name.to_string(), qualified_name); + + // Store the alias mapping if there is an alias + if let Some(alias_name) = alias { + aliases.insert(alias_name.to_string(), table_name.to_string()); + } } else { return Err(LimboError::ParseError(format!( "Table '{table_name}' not found in schema" @@ -435,10 +455,24 @@ impl IncrementalView { // Get all tables from JOIN clauses for join in &from.joins { - if let ast::SelectTable::Table(name, _, _) = join.table.as_ref() { + if let ast::SelectTable::Table(name, alias, _) = join.table.as_ref() { let table_name = name.name.as_str(); + + // Build the fully qualified name + let qualified_name = if let Some(ref db) = name.db_name { + format!("{db}.{table_name}") + } else { + table_name.to_string() + }; + if let Some(table) = schema.get_btree_table(table_name) { tables.push(table.clone()); + qualified_names.insert(table_name.to_string(), qualified_name); + + // Store the alias mapping if there is an alias + if let Some(alias_name) = alias { + aliases.insert(alias_name.to_string(), table_name.to_string()); + } } else { return Err(LimboError::ParseError(format!( "Table '{table_name}' not found in schema" @@ -454,106 +488,391 @@ impl IncrementalView { )); } - Ok(tables) + Ok((tables, aliases, qualified_names)) } - /// Extract the base table name from a SELECT statement (for non-join cases) - fn extract_base_table(select: &ast::Select) -> Option { - if let ast::OneSelect::Select { - from: Some(ref from), - .. - } = select.body.select - { - if let ast::SelectTable::Table(name, _, _) = from.select.as_ref() { - return Some(name.name.as_str().to_string()); - } - } - None - } - - /// Generate the SQL query for populating the view from its source table - fn sql_for_populate(&self) -> crate::Result { - // Get the first table from referenced tables + /// Generate SQL queries for populating the view from each source table + /// Returns a vector of SQL statements, one for each referenced table + /// Each query includes only the WHERE conditions relevant to that specific table + fn sql_for_populate(&self) -> crate::Result> { if self.referenced_tables.is_empty() { return Err(LimboError::ParseError( "No tables to populate from".to_string(), )); } - let table = &self.referenced_tables[0]; - // Check if the table has a rowid alias (INTEGER PRIMARY KEY column) - let has_rowid_alias = table.columns.iter().any(|col| col.is_rowid_alias); + let mut queries = Vec::new(); - // For now, select all columns since we don't have the static operators - // The circuit will handle filtering and projection - // If there's a rowid alias, we don't need to select rowid separately - let select_clause = if has_rowid_alias { - "*".to_string() - } else { - "*, rowid".to_string() - }; + for table in &self.referenced_tables { + // Check if the table has a rowid alias (INTEGER PRIMARY KEY column) + let has_rowid_alias = table.columns.iter().any(|col| col.is_rowid_alias); - // Build WHERE clause from the where_predicate - let where_clause = self.build_where_clause(&self.where_predicate)?; + // For now, select all columns since we don't have the static operators + // The circuit will handle filtering and projection + // If there's a rowid alias, we don't need to select rowid separately + let select_clause = if has_rowid_alias { + "*".to_string() + } else { + "*, rowid".to_string() + }; - 
// Construct the final query - let query = if where_clause.is_empty() { - format!("SELECT {} FROM {}", select_clause, table.name) - } else { - format!( - "SELECT {} FROM {} WHERE {}", - select_clause, table.name, where_clause - ) - }; - Ok(query) + // Extract WHERE conditions for this specific table + let where_clause = self.extract_where_clause_for_table(&table.name)?; + + // Use the qualified table name if available, otherwise just the table name + let table_name = self + .qualified_table_names + .get(&table.name) + .cloned() + .unwrap_or_else(|| table.name.clone()); + + // Construct the query for this table + let query = if where_clause.is_empty() { + format!("SELECT {select_clause} FROM {table_name}") + } else { + format!("SELECT {select_clause} FROM {table_name} WHERE {where_clause}") + }; + queries.push(query); + } + + Ok(queries) } - /// Build a WHERE clause from a FilterPredicate - fn build_where_clause(&self, predicate: &FilterPredicate) -> crate::Result { - match predicate { - FilterPredicate::None => Ok(String::new()), - FilterPredicate::Equals { column, value } => { - Ok(format!("{} = {}", column, self.value_to_sql(value))) + /// Extract WHERE conditions that apply to a specific table + /// This analyzes the WHERE clause in the SELECT statement and returns + /// only the conditions that reference the given table + fn extract_where_clause_for_table(&self, table_name: &str) -> crate::Result { + // For single table queries, return the entire WHERE clause (already unqualified) + if self.referenced_tables.len() == 1 { + if let ast::OneSelect::Select { + where_clause: Some(ref where_expr), + .. + } = self.select_stmt.body.select + { + // For single table, the expression should already be unqualified or qualified with the single table + // We need to unqualify it for the single-table query + let unqualified = self.unqualify_expression(where_expr, table_name); + return Ok(unqualified.to_string()); } - FilterPredicate::NotEquals { column, value } => { - Ok(format!("{} != {}", column, self.value_to_sql(value))) + return Ok(String::new()); + } + + // For multi-table queries (JOINs), extract conditions for the specific table + if let ast::OneSelect::Select { + where_clause: Some(ref where_expr), + .. 
+ } = self.select_stmt.body.select + { + // Extract conditions that reference only the specified table + let table_conditions = self.extract_table_conditions(where_expr, table_name)?; + if let Some(conditions) = table_conditions { + // Unqualify the expression for single-table query + let unqualified = self.unqualify_expression(&conditions, table_name); + return Ok(unqualified.to_string()); } - FilterPredicate::GreaterThan { column, value } => { - Ok(format!("{} > {}", column, self.value_to_sql(value))) + } + + Ok(String::new()) + } + + /// Extract conditions from an expression that reference only the specified table + fn extract_table_conditions( + &self, + expr: &ast::Expr, + table_name: &str, + ) -> crate::Result> { + match expr { + ast::Expr::Binary(left, op, right) => { + match op { + ast::Operator::And => { + // For AND, we can extract conditions independently + let left_cond = self.extract_table_conditions(left, table_name)?; + let right_cond = self.extract_table_conditions(right, table_name)?; + + match (left_cond, right_cond) { + (Some(l), Some(r)) => { + // Both conditions apply to this table + Ok(Some(ast::Expr::Binary( + Box::new(l), + ast::Operator::And, + Box::new(r), + ))) + } + (Some(l), None) => Ok(Some(l)), + (None, Some(r)) => Ok(Some(r)), + (None, None) => Ok(None), + } + } + ast::Operator::Or => { + // For OR, both sides must reference the same table(s) + // If either side references multiple tables, we can't extract it + let left_tables = self.get_referenced_tables_in_expr(left)?; + let right_tables = self.get_referenced_tables_in_expr(right)?; + + // If both sides only reference our table, include the whole OR + if left_tables.len() == 1 + && left_tables.contains(&table_name.to_string()) + && right_tables.len() == 1 + && right_tables.contains(&table_name.to_string()) + { + Ok(Some(expr.clone())) + } else { + // OR condition involves multiple tables, can't extract + Ok(None) + } + } + _ => { + // For comparison operators, check if this condition references only our table + // AND is simple enough to be pushed down (no complex expressions) + let referenced_tables = self.get_referenced_tables_in_expr(expr)?; + if referenced_tables.len() == 1 + && referenced_tables.contains(&table_name.to_string()) + { + // Check if this is a simple comparison that can be pushed down + // Complex expressions like (a * b) >= c should be handled by the circuit + if self.is_simple_comparison(expr) { + Ok(Some(expr.clone())) + } else { + // Complex expression - let the circuit handle it + Ok(None) + } + } else { + Ok(None) + } + } + } } - FilterPredicate::GreaterThanOrEqual { column, value } => { - Ok(format!("{} >= {}", column, self.value_to_sql(value))) + ast::Expr::Parenthesized(exprs) => { + if exprs.len() == 1 { + self.extract_table_conditions(&exprs[0], table_name) + } else { + Ok(None) + } } - FilterPredicate::LessThan { column, value } => { - Ok(format!("{} < {}", column, self.value_to_sql(value))) - } - FilterPredicate::LessThanOrEqual { column, value } => { - Ok(format!("{} <= {}", column, self.value_to_sql(value))) - } - FilterPredicate::And(left, right) => { - let left_clause = self.build_where_clause(left)?; - let right_clause = self.build_where_clause(right)?; - Ok(format!("({left_clause} AND {right_clause})")) - } - FilterPredicate::Or(left, right) => { - let left_clause = self.build_where_clause(left)?; - let right_clause = self.build_where_clause(right)?; - Ok(format!("({left_clause} OR {right_clause})")) + _ => { + // For other expressions, check if they reference only our 
table + // AND are simple enough to be pushed down + let referenced_tables = self.get_referenced_tables_in_expr(expr)?; + if referenced_tables.len() == 1 + && referenced_tables.contains(&table_name.to_string()) + && self.is_simple_comparison(expr) + { + Ok(Some(expr.clone())) + } else { + Ok(None) + } } } } - /// Convert a Value to SQL literal representation - fn value_to_sql(&self, value: &Value) -> String { - match value { - Value::Null => "NULL".to_string(), - Value::Integer(i) => i.to_string(), - Value::Float(f) => f.to_string(), - Value::Text(t) => format!("'{}'", t.as_str().replace('\'', "''")), - Value::Blob(_) => "NULL".to_string(), // Blob literals not supported in WHERE clause yet + /// Check if an expression is a simple comparison that can be pushed down to table scan + /// Returns true for simple comparisons like "column = value" or "column > value" + /// Returns false for complex expressions like "(a * b) > value" + fn is_simple_comparison(&self, expr: &ast::Expr) -> bool { + match expr { + ast::Expr::Binary(left, op, right) => { + // Check if it's a comparison operator + matches!( + op, + ast::Operator::Equals + | ast::Operator::NotEquals + | ast::Operator::Greater + | ast::Operator::GreaterEquals + | ast::Operator::Less + | ast::Operator::LessEquals + ) && self.is_simple_operand(left) + && self.is_simple_operand(right) + } + _ => false, } } + /// Check if an operand is simple (column reference or literal) + fn is_simple_operand(&self, expr: &ast::Expr) -> bool { + matches!( + expr, + ast::Expr::Id(_) + | ast::Expr::Qualified(_, _) + | ast::Expr::DoublyQualified(_, _, _) + | ast::Expr::Literal(_) + ) + } + + /// Get the set of table names referenced in an expression + fn get_referenced_tables_in_expr(&self, expr: &ast::Expr) -> crate::Result> { + let mut tables = Vec::new(); + self.collect_referenced_tables(expr, &mut tables)?; + // Deduplicate + tables.sort(); + tables.dedup(); + Ok(tables) + } + + /// Recursively collect table references from an expression + fn collect_referenced_tables( + &self, + expr: &ast::Expr, + tables: &mut Vec, + ) -> crate::Result<()> { + match expr { + ast::Expr::Binary(left, _, right) => { + self.collect_referenced_tables(left, tables)?; + self.collect_referenced_tables(right, tables)?; + } + ast::Expr::Qualified(table, _) => { + // This is a qualified column reference (table.column or alias.column) + // We need to resolve aliases to actual table names + let actual_table = self.resolve_table_alias(table.as_str()); + tables.push(actual_table); + } + ast::Expr::Id(column) => { + // Unqualified column reference + if self.referenced_tables.len() > 1 { + // In a JOIN context, check which tables have this column + let mut tables_with_column = Vec::new(); + for table in &self.referenced_tables { + if table + .columns + .iter() + .any(|c| c.name.as_ref() == Some(&column.to_string())) + { + tables_with_column.push(table.name.clone()); + } + } + + if tables_with_column.len() > 1 { + // Ambiguous column - this should have been caught earlier + // Return error to be safe + return Err(crate::LimboError::ParseError(format!( + "Ambiguous column name '{}' in WHERE clause - exists in tables: {}", + column, + tables_with_column.join(", ") + ))); + } else if tables_with_column.len() == 1 { + // Unambiguous - only one table has this column + // This is allowed by SQLite + tables.push(tables_with_column[0].clone()); + } else { + // Column doesn't exist in any table - this is an error + // but should be caught during compilation + return 
Err(crate::LimboError::ParseError(format!( + "Column '{column}' not found in any table" + ))); + } + } else { + // Single table context - unqualified columns belong to that table + if let Some(table) = self.referenced_tables.first() { + tables.push(table.name.clone()); + } + } + } + ast::Expr::DoublyQualified(_database, table, _column) => { + // For database.table.column, resolve the table name + let table_str = table.as_str(); + let actual_table = self.resolve_table_alias(table_str); + tables.push(actual_table); + } + ast::Expr::Parenthesized(exprs) => { + for e in exprs { + self.collect_referenced_tables(e, tables)?; + } + } + _ => { + // Literals and other expressions don't reference tables + } + } + Ok(()) + } + + /// Convert a qualified expression to unqualified for single-table queries + /// This removes table prefixes from column references since they're not needed + /// when querying a single table + fn unqualify_expression(&self, expr: &ast::Expr, table_name: &str) -> ast::Expr { + match expr { + ast::Expr::Binary(left, op, right) => { + // Recursively unqualify both sides + ast::Expr::Binary( + Box::new(self.unqualify_expression(left, table_name)), + *op, + Box::new(self.unqualify_expression(right, table_name)), + ) + } + ast::Expr::Qualified(table, column) => { + // Convert qualified column to unqualified if it's for our table + // Handle both "table.column" and "database.table.column" cases + let table_str = table.as_str(); + + // Check if this is a database.table reference + let actual_table = if table_str.contains('.') { + // Split on '.' and take the last part as the table name + table_str + .split('.') + .next_back() + .unwrap_or(table_str) + .to_string() + } else { + // Could be an alias or direct table name + self.resolve_table_alias(table_str) + }; + + if actual_table == table_name { + // Just return the column name without qualification + ast::Expr::Id(column.clone()) + } else { + // This shouldn't happen if extract_table_conditions worked correctly + // but keep it qualified just in case + expr.clone() + } + } + ast::Expr::DoublyQualified(_database, table, column) => { + // This is database.table.column format + // Check if the table matches our target table + let table_str = table.as_str(); + let actual_table = self.resolve_table_alias(table_str); + + if actual_table == table_name { + // Just return the column name without qualification + ast::Expr::Id(column.clone()) + } else { + // Keep it qualified if it's for a different table + expr.clone() + } + } + ast::Expr::Parenthesized(exprs) => { + // Recursively unqualify expressions in parentheses + let unqualified_exprs: Vec> = exprs + .iter() + .map(|e| Box::new(self.unqualify_expression(e, table_name))) + .collect(); + ast::Expr::Parenthesized(unqualified_exprs) + } + _ => { + // Other expression types (literals, unqualified columns, etc.) 
stay as-is + expr.clone() + } + } + } + + /// Resolve a table alias to the actual table name + fn resolve_table_alias(&self, alias: &str) -> String { + // Check if there's an alias mapping in the FROM/JOIN clauses + // For now, we'll do a simple check - if the alias matches a table name, use it + // Otherwise, try to find it in the FROM clause + + // First check if it's an actual table name + if self.referenced_tables.iter().any(|t| t.name == alias) { + return alias.to_string(); + } + + // Check if it's an alias that maps to a table + if let Some(table_name) = self.table_aliases.get(alias) { + return table_name.clone(); + } + + // If we can't resolve it, return as-is (it might be a table name we don't know about) + alias.to_string() + } + /// Populate the view by scanning the source table using a state machine /// This can be called multiple times and will resume from where it left off /// This method is only for materialized views and will persist data to the btree @@ -563,279 +882,242 @@ impl IncrementalView { pager: &std::sync::Arc, _btree_cursor: &mut BTreeCursor, ) -> crate::Result> { - // If already populated, return immediately - if matches!(self.populate_state, PopulateState::Done) { - return Ok(IOResult::Done(())); - } - // Assert that this is a materialized view with a root page assert!( self.root_page != 0, "populate_from_table should only be called for materialized views with root_page" ); - loop { - // To avoid borrow checker issues, we need to handle state transitions carefully - let needs_start = matches!(self.populate_state, PopulateState::Start); + 'outer: loop { + match std::mem::replace(&mut self.populate_state, PopulateState::Done) { + PopulateState::Start => { + // Generate the SQL query for populating the view + // It is best to use a standard query than a cursor for two reasons: + // 1) Using a sql query will allow us to be much more efficient in cases where we only want + // some rows, in particular for indexed filters + // 2) There are two types of cursors: index and table. In some situations (like for example + // if the table has an integer primary key), the key will be exclusively in the index + // btree and not in the table btree. Using cursors would force us to be aware of this + // distinction (and others), and ultimately lead to reimplementing the whole query + // machinery (next step is which index is best to use, etc) + let queries = self.sql_for_populate()?; - if needs_start { - // Generate the SQL query for populating the view - // It is best to use a standard query than a cursor for two reasons: - // 1) Using a sql query will allow us to be much more efficient in cases where we only want - // some rows, in particular for indexed filters - // 2) There are two types of cursors: index and table. In some situations (like for example - // if the table has an integer primary key), the key will be exclusively in the index - // btree and not in the table btree. 
Using cursors would force us to be aware of this - // distinction (and others), and ultimately lead to reimplementing the whole query - // machinery (next step is which index is best to use, etc) - let query = self.sql_for_populate()?; - - // Prepare the statement - let stmt = conn.prepare(&query)?; - - self.populate_state = PopulateState::Processing { - stmt: Box::new(stmt), - rows_processed: 0, - pending_row: None, - }; - // Continue to next state - continue; - } - - // Handle Done state - if matches!(self.populate_state, PopulateState::Done) { - return Ok(IOResult::Done(())); - } - - // Handle Processing state - extract state to avoid borrow issues - let (mut stmt, mut rows_processed, pending_row) = - match std::mem::replace(&mut self.populate_state, PopulateState::Done) { - PopulateState::Processing { - stmt, - rows_processed, - pending_row, - } => (stmt, rows_processed, pending_row), - _ => unreachable!("We already handled Start and Done states"), - }; - - // If we have a pending row from a previous I/O interruption, process it first - if let Some((rowid, values)) = pending_row { - // Create a single-row delta for the pending row - let mut single_row_delta = Delta::new(); - single_row_delta.insert(rowid, values.clone()); - - // Create a DeltaSet with this delta for the first table (for now) - let mut delta_set = DeltaSet::new(); - // TODO: When we support JOINs, determine which table this row came from - delta_set.insert(self.referenced_tables[0].name.clone(), single_row_delta); - - // Process the pending row with the pager - match self.merge_delta(delta_set, pager.clone())? { - IOResult::Done(_) => { - // Row processed successfully, continue to next row - rows_processed += 1; - // Continue to fetch next row from statement - } - IOResult::IO(io) => { - // Still not done, save state with pending row - self.populate_state = PopulateState::Processing { - stmt, - rows_processed, - pending_row: Some((rowid, values)), // Keep the pending row - }; - return Ok(IOResult::IO(io)); - } + self.populate_state = PopulateState::ProcessingAllTables { + queries, + current_idx: 0, + }; } - } - // Process rows one at a time - no batching - loop { - // This step() call resumes from where the statement left off - match stmt.step()? 
{ - crate::vdbe::StepResult::Row => { - // Get the row - let row = stmt.row().unwrap(); + PopulateState::ProcessingAllTables { + queries, + current_idx, + } => { + if current_idx >= queries.len() { + self.populate_state = PopulateState::Done; + return Ok(IOResult::Done(())); + } - // Extract values from the row - let all_values: Vec = - row.get_values().cloned().collect(); + let query = queries[current_idx].clone(); + // Create a new connection for reading to avoid transaction conflicts + // This allows us to read from tables while the parent transaction is writing the view + // The statement holds a reference to this connection, keeping it alive + let read_conn = conn.db.connect()?; - // Determine how to extract the rowid - // If there's a rowid alias (INTEGER PRIMARY KEY), the rowid is one of the columns - // Otherwise, it's the last value we explicitly selected - let (rowid, values) = if let Some((idx, _)) = - self.referenced_tables[0].get_rowid_alias_column() - { - // The rowid is the value at the rowid alias column index - let rowid = match all_values.get(idx) { - Some(crate::types::Value::Integer(id)) => *id, - _ => { - // This shouldn't happen - rowid alias must be an integer - rows_processed += 1; - continue; - } - }; - // All values are table columns (no separate rowid was selected) - (rowid, all_values) - } else { - // The last value is the explicitly selected rowid - let rowid = match all_values.last() { - Some(crate::types::Value::Integer(id)) => *id, - _ => { - // This shouldn't happen - rowid must be an integer - rows_processed += 1; - continue; - } - }; - // Get all values except the rowid - let values = all_values[..all_values.len() - 1].to_vec(); - (rowid, values) - }; + // Prepare the statement using the read connection + let stmt = read_conn.prepare(&query)?; - // Create a single-row delta and process it immediately - let mut single_row_delta = Delta::new(); - single_row_delta.insert(rowid, values.clone()); + self.populate_state = PopulateState::ProcessingOneTable { + queries, + current_idx, + stmt: Box::new(stmt), + rows_processed: 0, + pending_row: None, + }; + } - // Create a DeltaSet with this delta for the first table (for now) - let mut delta_set = DeltaSet::new(); - // TODO: When we support JOINs, determine which table this row came from - delta_set.insert(self.referenced_tables[0].name.clone(), single_row_delta); - - // Process this single row through merge_delta with the pager - match self.merge_delta(delta_set, pager.clone())? { + PopulateState::ProcessingOneTable { + queries, + current_idx, + mut stmt, + mut rows_processed, + pending_row, + } => { + // If we have a pending row from a previous I/O interruption, process it first + if let Some((rowid, values)) = pending_row { + match self.process_one_row( + rowid, + values.clone(), + current_idx, + pager.clone(), + )? { IOResult::Done(_) => { // Row processed successfully, continue to next row rows_processed += 1; } IOResult::IO(io) => { - // Save state and return I/O - // We'll resume at the SAME row when called again (don't increment rows_processed) - // The circuit still has unfinished work for this row - self.populate_state = PopulateState::Processing { + // Still not done, restore state with pending row and return + self.populate_state = PopulateState::ProcessingOneTable { + queries, + current_idx, stmt, - rows_processed, // Don't increment - row not done yet! 
- pending_row: Some((rowid, values)), // Save the row for resumption + rows_processed, + pending_row: Some((rowid, values)), }; return Ok(IOResult::IO(io)); } } } - crate::vdbe::StepResult::Done => { - // All rows processed, we're done - self.populate_state = PopulateState::Done; - return Ok(IOResult::Done(())); - } + // Process rows one at a time - no batching + loop { + // This step() call resumes from where the statement left off + match stmt.step()? { + crate::vdbe::StepResult::Row => { + // Get the row + let row = stmt.row().unwrap(); - crate::vdbe::StepResult::Interrupt | crate::vdbe::StepResult::Busy => { - // Save state before returning error - self.populate_state = PopulateState::Processing { - stmt, - rows_processed, - pending_row: None, // No pending row when interrupted between rows - }; - return Err(LimboError::Busy); - } + // Extract values from the row + let all_values: Vec = + row.get_values().cloned().collect(); - crate::vdbe::StepResult::IO => { - // Statement needs I/O - save state and return - self.populate_state = PopulateState::Processing { - stmt, - rows_processed, - pending_row: None, // No pending row when interrupted between rows - }; - // TODO: Get the actual I/O completion from the statement - let completion = crate::io::Completion::new_dummy(); - return Ok(IOResult::IO(crate::types::IOCompletions::Single( - completion, - ))); - } - } - } - } - } + // Extract rowid and values using helper + let (rowid, values) = + match self.extract_rowid_and_values(all_values, current_idx) { + Some(result) => result, + None => { + // Invalid rowid, skip this row + rows_processed += 1; + continue; + } + }; - /// Extract JOIN information from SELECT statement - #[allow(clippy::type_complexity)] - pub fn extract_join_info( - select: &ast::Select, - ) -> (Option<(String, String)>, Option<(String, String)>) { - use turso_parser::ast::*; + // Process this row + match self.process_one_row( + rowid, + values.clone(), + current_idx, + pager.clone(), + )? { + IOResult::Done(_) => { + // Row processed successfully, continue to next row + rows_processed += 1; + } + IOResult::IO(io) => { + // Save state and return I/O + // We'll resume at the SAME row when called again (don't increment rows_processed) + // The circuit still has unfinished work for this row + self.populate_state = PopulateState::ProcessingOneTable { + queries, + current_idx, + stmt, + rows_processed, // Don't increment - row not done yet! + pending_row: Some((rowid, values)), // Save the row for resumption + }; + return Ok(IOResult::IO(io)); + } + } + } - if let OneSelect::Select { - from: Some(ref from), - .. 
- } = select.body.select - { - // Check if there are any joins - if !from.joins.is_empty() { - // Get the first (left) table name - let left_table = match from.select.as_ref() { - SelectTable::Table(name, _, _) => Some(name.name.as_str().to_string()), - _ => None, - }; + crate::vdbe::StepResult::Done => { + // All rows processed from this table + // Move to next table + self.populate_state = PopulateState::ProcessingAllTables { + queries, + current_idx: current_idx + 1, + }; + continue 'outer; + } - // Get the first join (right) table and condition - if let Some(first_join) = from.joins.first() { - let right_table = match &first_join.table.as_ref() { - SelectTable::Table(name, _, _) => Some(name.name.as_str().to_string()), - _ => None, - }; + crate::vdbe::StepResult::Interrupt | crate::vdbe::StepResult::Busy => { + // Save state before returning error + self.populate_state = PopulateState::ProcessingOneTable { + queries, + current_idx, + stmt, + rows_processed, + pending_row: None, // No pending row when interrupted between rows + }; + return Err(LimboError::Busy); + } - // Extract join condition (simplified - assumes single equality) - let join_condition = if let Some(ref constraint) = &first_join.constraint { - match constraint { - JoinConstraint::On(expr) => Self::extract_join_columns_from_expr(expr), - _ => None, + crate::vdbe::StepResult::IO => { + // Statement needs I/O - save state and return + self.populate_state = PopulateState::ProcessingOneTable { + queries, + current_idx, + stmt, + rows_processed, + pending_row: None, // No pending row when interrupted between rows + }; + // TODO: Get the actual I/O completion from the statement + let completion = crate::io::Completion::new_dummy(); + return Ok(IOResult::IO(crate::types::IOCompletions::Single( + completion, + ))); + } } - } else { - None - }; - - if let (Some(left), Some(right)) = (left_table, right_table) { - return (Some((left, right)), join_condition); } } - } - } - (None, None) - } - - /// Extract join column names from a join condition expression - fn extract_join_columns_from_expr(expr: &ast::Expr) -> Option<(String, String)> { - use turso_parser::ast::*; - - // Look for expressions like: t1.col = t2.col - if let Expr::Binary(left, op, right) = expr { - if matches!(op, Operator::Equals) { - // Extract column names from both sides - let left_col = match &**left { - Expr::Qualified(name, _) => Some(name.as_str().to_string()), - Expr::Id(name) => Some(name.as_str().to_string()), - _ => None, - }; - - let right_col = match &**right { - Expr::Qualified(name, _) => Some(name.as_str().to_string()), - Expr::Id(name) => Some(name.as_str().to_string()), - _ => None, - }; - - if let (Some(l), Some(r)) = (left_col, right_col) { - return Some((l, r)); + PopulateState::Done => { + return Ok(IOResult::Done(())); } } } + } - None + /// Process a single row through the circuit + fn process_one_row( + &mut self, + rowid: i64, + values: Vec, + table_idx: usize, + pager: Arc, + ) -> crate::Result> { + // Create a single-row delta + let mut single_row_delta = Delta::new(); + single_row_delta.insert(rowid, values); + + // Create a DeltaSet with this delta for the current table + let mut delta_set = DeltaSet::new(); + let table_name = self.referenced_tables[table_idx].name.clone(); + delta_set.insert(table_name, single_row_delta); + + // Process through merge_delta + self.merge_delta(delta_set, pager) + } + + /// Extract rowid and values from a row + fn extract_rowid_and_values( + &self, + all_values: Vec, + table_idx: usize, + ) -> 
Option<(i64, Vec)> { + if let Some((idx, _)) = self.referenced_tables[table_idx].get_rowid_alias_column() { + // The rowid is the value at the rowid alias column index + let rowid = match all_values.get(idx) { + Some(Value::Integer(id)) => *id, + _ => return None, // Invalid rowid + }; + // All values are table columns (no separate rowid was selected) + Some((rowid, all_values)) + } else { + // The last value is the explicitly selected rowid + let rowid = match all_values.last() { + Some(Value::Integer(id)) => *id, + _ => return None, // Invalid rowid + }; + // Get all values except the rowid + let values = all_values[..all_values.len() - 1].to_vec(); + Some((rowid, values)) + } } /// Merge a delta set of changes into the view's current state pub fn merge_delta( &mut self, delta_set: DeltaSet, - pager: std::sync::Arc, + pager: Arc, ) -> crate::Result> { // Early return if all deltas are empty if delta_set.is_empty() { @@ -977,15 +1259,76 @@ mod tests { collation: None, hidden: false, }, + SchemaColumn { + name: Some("price".to_string()), + ty: Type::Real, + ty_str: "REAL".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, ], has_rowid: true, is_strict: false, unique_sets: vec![], }; + // Create logs table - without a rowid alias (no INTEGER PRIMARY KEY) + let logs_table = BTreeTable { + name: "logs".to_string(), + root_page: 5, + primary_key_columns: vec![], // No primary key, so no rowid alias + columns: vec![ + SchemaColumn { + name: Some("message".to_string()), + ty: Type::Text, + ty_str: "TEXT".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("level".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("timestamp".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + ], + has_rowid: true, // Has implicit rowid but no alias + is_strict: false, + unique_sets: vec![], + }; + schema.add_btree_table(Arc::new(customers_table)); schema.add_btree_table(Arc::new(orders_table)); schema.add_btree_table(Arc::new(products_table)); + schema.add_btree_table(Arc::new(logs_table)); schema } @@ -1004,7 +1347,7 @@ mod tests { let schema = create_test_schema(); let select = parse_select("SELECT * FROM customers"); - let tables = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, _, _) = IncrementalView::extract_all_tables(&select, &schema).unwrap(); assert_eq!(tables.len(), 1); assert_eq!(tables[0].name, "customers"); @@ -1017,7 +1360,7 @@ mod tests { "SELECT * FROM customers INNER JOIN orders ON customers.id = orders.customer_id", ); - let tables = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, _, _) = IncrementalView::extract_all_tables(&select, &schema).unwrap(); assert_eq!(tables.len(), 2); assert_eq!(tables[0].name, "customers"); @@ -1033,7 +1376,7 @@ mod tests { INNER JOIN products ON orders.id = products.id", ); - let tables = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, _, _) = IncrementalView::extract_all_tables(&select, 
&schema).unwrap(); assert_eq!(tables.len(), 3); assert_eq!(tables[0].name, "customers"); @@ -1048,7 +1391,7 @@ mod tests { "SELECT * FROM customers LEFT JOIN orders ON customers.id = orders.customer_id", ); - let tables = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, _, _) = IncrementalView::extract_all_tables(&select, &schema).unwrap(); assert_eq!(tables.len(), 2); assert_eq!(tables[0].name, "customers"); @@ -1060,7 +1403,7 @@ mod tests { let schema = create_test_schema(); let select = parse_select("SELECT * FROM customers CROSS JOIN orders"); - let tables = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, _, _) = IncrementalView::extract_all_tables(&select, &schema).unwrap(); assert_eq!(tables.len(), 2); assert_eq!(tables[0].name, "customers"); @@ -1073,7 +1416,7 @@ mod tests { let select = parse_select("SELECT * FROM customers c INNER JOIN orders o ON c.id = o.customer_id"); - let tables = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, _, _) = IncrementalView::extract_all_tables(&select, &schema).unwrap(); // Should still extract the actual table names, not aliases assert_eq!(tables.len(), 2); @@ -1086,7 +1429,8 @@ mod tests { let schema = create_test_schema(); let select = parse_select("SELECT * FROM nonexistent"); - let result = IncrementalView::extract_all_tables(&select, &schema); + let result = + IncrementalView::extract_all_tables(&select, &schema).map(|(tables, _, _)| tables); assert!(result.is_err()); assert!(result @@ -1102,7 +1446,8 @@ mod tests { "SELECT * FROM customers INNER JOIN nonexistent ON customers.id = nonexistent.id", ); - let result = IncrementalView::extract_all_tables(&select, &schema); + let result = + IncrementalView::extract_all_tables(&select, &schema).map(|(tables, _, _)| tables); assert!(result.is_err()); assert!(result @@ -1110,4 +1455,526 @@ mod tests { .to_string() .contains("Table 'nonexistent' not found")); } + + #[test] + fn test_sql_for_populate_simple_query_no_where() { + // Test simple query with no WHERE clause + let schema = create_test_schema(); + let select = parse_select("SELECT * FROM customers"); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names, + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 1); + // customers has id as rowid alias, so no need for explicit rowid + assert_eq!(queries[0], "SELECT * FROM customers"); + } + + #[test] + fn test_sql_for_populate_simple_query_with_where() { + // Test simple query with WHERE clause + let schema = create_test_schema(); + let select = parse_select("SELECT * FROM customers WHERE id > 10"); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names, + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 1); + // For single-table queries, we should get the full WHERE clause + 
assert_eq!(queries[0], "SELECT * FROM customers WHERE id > 10"); + } + + #[test] + fn test_sql_for_populate_join_with_where_on_both_tables() { + // Test JOIN query with WHERE conditions on both tables + let schema = create_test_schema(); + let select = parse_select( + "SELECT * FROM customers c \ + JOIN orders o ON c.id = o.customer_id \ + WHERE c.id > 10 AND o.total > 100", + ); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names, + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 2); + + // With per-table WHERE extraction: + // - customers table gets: c.id > 10 + // - orders table gets: o.total > 100 + assert_eq!(queries[0], "SELECT * FROM customers WHERE id > 10"); + assert_eq!(queries[1], "SELECT * FROM orders WHERE total > 100"); + } + + #[test] + fn test_sql_for_populate_complex_join_with_mixed_conditions() { + // Test complex JOIN with WHERE conditions mixing both tables + let schema = create_test_schema(); + let select = parse_select( + "SELECT * FROM customers c \ + JOIN orders o ON c.id = o.customer_id \ + WHERE c.id > 10 AND o.total > 100 AND c.name = 'John' \ + AND o.customer_id = 5 AND (c.id = 15 OR o.total = 200)", + ); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names, + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 2); + + // With per-table WHERE extraction: + // - customers gets: c.id > 10 AND c.name = 'John' + // - orders gets: o.total > 100 AND o.customer_id = 5 + // Note: The OR condition (c.id = 15 OR o.total = 200) involves both tables, + // so it cannot be extracted to either table individually + assert_eq!( + queries[0], + "SELECT * FROM customers WHERE id > 10 AND name = 'John'" + ); + assert_eq!( + queries[1], + "SELECT * FROM orders WHERE total > 100 AND customer_id = 5" + ); + } + + #[test] + fn test_where_extraction_for_three_tables() { + // Test that WHERE clause extraction correctly separates conditions for 3+ tables + // This addresses the concern about conditions "piling up" as joins increase + + // Simulate a three-table scenario + let schema = create_test_schema(); + + // Parse a WHERE clause with conditions for three different tables + let select = parse_select( + "SELECT * FROM customers WHERE c.id > 10 AND o.total > 100 AND p.price > 50", + ); + + // Get the WHERE expression + if let ast::OneSelect::Select { + where_clause: Some(ref where_expr), + .. 
+ } = select.body.select + { + // Create a view with three tables to test extraction + let tables = vec![ + schema.get_btree_table("customers").unwrap(), + schema.get_btree_table("orders").unwrap(), + schema.get_btree_table("products").unwrap(), + ]; + + let mut aliases = HashMap::new(); + aliases.insert("c".to_string(), "customers".to_string()); + aliases.insert("o".to_string(), "orders".to_string()); + aliases.insert("p".to_string(), "products".to_string()); + + // Create a minimal view just to test extraction logic + let view = IncrementalView { + name: "test".to_string(), + select_stmt: select.clone(), + circuit: DbspCircuit::new(1, 2, 3), + referenced_tables: tables, + table_aliases: aliases, + qualified_table_names: HashMap::new(), + column_schema: ViewColumnSchema { + columns: vec![], + tables: vec![], + }, + populate_state: PopulateState::Start, + tracker: Arc::new(Mutex::new(ComputationTracker::new())), + root_page: 0, + }; + + // Test extraction for each table + let customers_conds = view + .extract_table_conditions(where_expr, "customers") + .unwrap(); + let orders_conds = view.extract_table_conditions(where_expr, "orders").unwrap(); + let products_conds = view + .extract_table_conditions(where_expr, "products") + .unwrap(); + + // Verify each table only gets its conditions + if let Some(cond) = customers_conds { + let sql = cond.to_string(); + assert!(sql.contains("id > 10")); + assert!(!sql.contains("total")); + assert!(!sql.contains("price")); + } + + if let Some(cond) = orders_conds { + let sql = cond.to_string(); + assert!(sql.contains("total > 100")); + assert!(!sql.contains("id > 10")); // From customers + assert!(!sql.contains("price")); + } + + if let Some(cond) = products_conds { + let sql = cond.to_string(); + assert!(sql.contains("price > 50")); + assert!(!sql.contains("id > 10")); // From customers + assert!(!sql.contains("total")); + } + } else { + panic!("Failed to parse WHERE clause"); + } + } + + #[test] + fn test_alias_resolution_works_correctly() { + // Test that alias resolution properly maps aliases to table names + let schema = create_test_schema(); + let select = parse_select( + "SELECT * FROM customers c \ + JOIN orders o ON c.id = o.customer_id \ + WHERE c.id > 10 AND o.total > 100", + ); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names, + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + // Verify that alias mappings were extracted correctly + assert_eq!(view.table_aliases.get("c"), Some(&"customers".to_string())); + assert_eq!(view.table_aliases.get("o"), Some(&"orders".to_string())); + + // Verify that SQL generation uses the aliases correctly + let queries = view.sql_for_populate().unwrap(); + assert_eq!(queries.len(), 2); + + // Each query should use the actual table name, not the alias + assert!(queries[0].contains("FROM customers") || queries[1].contains("FROM customers")); + assert!(queries[0].contains("FROM orders") || queries[1].contains("FROM orders")); + } + + #[test] + fn test_sql_for_populate_table_without_rowid_alias() { + // Test that tables without a rowid alias properly include rowid in SELECT + let schema = create_test_schema(); + let select = parse_select("SELECT * FROM logs WHERE level > 2"); + + let (tables, aliases, qualified_names) = + 
IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names, + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 1); + // logs table has no rowid alias, so we need to explicitly select rowid + assert_eq!(queries[0], "SELECT *, rowid FROM logs WHERE level > 2"); + } + + #[test] + fn test_sql_for_populate_join_with_and_without_rowid_alias() { + // Test JOIN between a table with rowid alias and one without + let schema = create_test_schema(); + let select = parse_select( + "SELECT * FROM customers c \ + JOIN logs l ON c.id = l.level \ + WHERE c.id > 10 AND l.level > 2", + ); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names, + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 2); + // customers has rowid alias (id), logs doesn't + assert_eq!(queries[0], "SELECT * FROM customers WHERE id > 10"); + assert_eq!(queries[1], "SELECT *, rowid FROM logs WHERE level > 2"); + } + + #[test] + fn test_sql_for_populate_with_database_qualified_names() { + // Test that database.table.column references are handled correctly + // The table name in FROM should keep the database prefix, + // but column names in WHERE should be unqualified + let schema = create_test_schema(); + + // Test with single table using database qualification + let select = parse_select("SELECT * FROM main.customers WHERE main.customers.id > 10"); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names, + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 1); + // The FROM clause should preserve the database qualification, + // but the WHERE clause should have unqualified column names + assert_eq!(queries[0], "SELECT * FROM main.customers WHERE id > 10"); + } + + #[test] + fn test_sql_for_populate_join_with_database_qualified_names() { + // Test JOIN with database-qualified table and column references + let schema = create_test_schema(); + + let select = parse_select( + "SELECT * FROM main.customers c \ + JOIN main.orders o ON c.id = o.customer_id \ + WHERE main.customers.id > 10 AND main.orders.total > 100", + ); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names, + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 2); + // The FROM 
clauses should preserve database qualification, + // but WHERE clauses should have unqualified column names + assert_eq!(queries[0], "SELECT * FROM main.customers WHERE id > 10"); + assert_eq!(queries[1], "SELECT * FROM main.orders WHERE total > 100"); + } + + #[test] + fn test_sql_for_populate_unambiguous_unqualified_column() { + // Test that unambiguous unqualified columns ARE extracted + let schema = create_test_schema(); + let select = parse_select( + "SELECT * FROM customers c \ + JOIN orders o ON c.id = o.customer_id \ + WHERE total > 100", // 'total' only exists in orders table + ); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names, + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 2); + + // 'total' is unambiguous (only in orders), so it should be extracted + assert_eq!(queries[0], "SELECT * FROM customers"); + assert_eq!(queries[1], "SELECT * FROM orders WHERE total > 100"); + } + + #[test] + fn test_database_qualified_table_names() { + let schema = create_test_schema(); + + // Test with database-qualified table names + let select = parse_select( + "SELECT c.id, c.name, o.id, o.total + FROM main.customers c + JOIN main.orders o ON c.id = o.customer_id + WHERE c.id > 10", + ); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + + // Check that qualified names are preserved + assert!(qualified_names.contains_key("customers")); + assert_eq!(qualified_names.get("customers").unwrap(), "main.customers"); + assert!(qualified_names.contains_key("orders")); + assert_eq!(qualified_names.get("orders").unwrap(), "main.orders"); + + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names.clone(), + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 2); + + // The FROM clause should contain the database-qualified name + // But the WHERE clause should use unqualified column names + assert_eq!(queries[0], "SELECT * FROM main.customers WHERE id > 10"); + assert_eq!(queries[1], "SELECT * FROM main.orders"); + } + + #[test] + fn test_mixed_qualified_unqualified_tables() { + let schema = create_test_schema(); + + // Test with a mix of qualified and unqualified table names + let select = parse_select( + "SELECT c.id, c.name, o.id, o.total + FROM main.customers c + JOIN orders o ON c.id = o.customer_id + WHERE c.id > 10 AND o.total < 1000", + ); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + + // Check that qualified names are preserved where specified + assert_eq!(qualified_names.get("customers").unwrap(), "main.customers"); + // Unqualified tables should not have an entry (or have the bare name) + assert!( + !qualified_names.contains_key("orders") + || qualified_names.get("orders").unwrap() == "orders" + ); + + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names.clone(), + 
extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 2); + + // The FROM clause should preserve qualification where specified + assert_eq!(queries[0], "SELECT * FROM main.customers WHERE id > 10"); + assert_eq!(queries[1], "SELECT * FROM orders WHERE total < 1000"); + } } diff --git a/core/io/memory.rs b/core/io/memory.rs index c69d87dcf..fc0549ca7 100644 --- a/core/io/memory.rs +++ b/core/io/memory.rs @@ -12,7 +12,6 @@ use tracing::debug; pub struct MemoryIO { files: Arc>>>, } -unsafe impl Send for MemoryIO {} // TODO: page size flag const PAGE_SIZE: usize = 4096; @@ -76,7 +75,7 @@ pub struct MemoryFile { pages: UnsafeCell>, size: Cell, } -unsafe impl Send for MemoryFile {} + unsafe impl Sync for MemoryFile {} impl File for MemoryFile { diff --git a/core/io/unix.rs b/core/io/unix.rs index a3cfd6f2f..b0d47f30f 100644 --- a/core/io/unix.rs +++ b/core/io/unix.rs @@ -17,9 +17,6 @@ use tracing::{instrument, trace, Level}; pub struct UnixIO {} -unsafe impl Send for UnixIO {} -unsafe impl Sync for UnixIO {} - impl UnixIO { #[cfg(feature = "fs")] pub fn new() -> Result { @@ -128,8 +125,6 @@ impl IO for UnixIO { pub struct UnixFile { file: Arc>, } -unsafe impl Send for UnixFile {} -unsafe impl Sync for UnixFile {} impl File for UnixFile { fn lock_file(&self, exclusive: bool) -> Result<()> { diff --git a/core/lib.rs b/core/lib.rs index ae8fb31ab..79999a7bb 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -2174,10 +2174,14 @@ impl Connection { /// 5. Step through query -> returns Busy -> return Busy to user /// /// This slight api change demonstrated a better throughtput in `perf/throughput/turso` benchmark - pub fn busy_timeout(&self, mut duration: Option) { + pub fn set_busy_timeout(&self, mut duration: Option) { duration = duration.filter(|duration| !duration.is_zero()); self.busy_timeout.set(duration); } + + pub fn get_busy_timeout(&self) -> Option { + self.busy_timeout.get() + } } #[derive(Debug, Default)] diff --git a/core/mvcc/database/mod.rs b/core/mvcc/database/mod.rs index 54cd1d0cf..547cd9e95 100644 --- a/core/mvcc/database/mod.rs +++ b/core/mvcc/database/mod.rs @@ -1,6 +1,5 @@ use crate::mvcc::clock::LogicalClock; use crate::mvcc::persistent_storage::Storage; -use crate::return_if_io; use crate::state_machine::StateMachine; use crate::state_machine::StateTransition; use crate::state_machine::TransitionResult; @@ -542,24 +541,11 @@ impl StateTransition for CommitStateMachine { if mvcc_store.is_exclusive_tx(&self.tx_id) { mvcc_store.release_exclusive_tx(&self.tx_id); self.commit_coordinator.pager_commit_lock.unlock(); - if !mvcc_store.storage.is_logical_log() { - // FIXME: this function isnt re-entrant - self.pager - .io - .block(|| self.pager.end_tx(false, &self.connection))?; - } - } else if !mvcc_store.storage.is_logical_log() { - self.pager.end_read_tx()?; } self.finalize(mvcc_store)?; return Ok(TransitionResult::Done(())); } - if mvcc_store.storage.is_logical_log() { - self.state = CommitState::Commit { end_ts }; - return Ok(TransitionResult::Continue); - } else { - self.state = CommitState::BeginPagerTxn { end_ts }; - } + self.state = CommitState::Commit { end_ts }; Ok(TransitionResult::Continue) } CommitState::BeginPagerTxn { end_ts } => { @@ -851,7 +837,6 @@ impl StateTransition for CommitStateMachine { return Ok(TransitionResult::Continue); } CommitState::BeginCommitLogicalLog { end_ts, 
log_record } => { - assert!(mvcc_store.storage.is_logical_log()); if !mvcc_store.is_exclusive_tx(&self.tx_id) { // logical log needs to be serialized let locked = self.commit_coordinator.pager_commit_lock.write(); @@ -866,10 +851,6 @@ impl StateTransition for CommitStateMachine { match result { IOResult::Done(_) => {} IOResult::IO(io) => { - assert!( - mvcc_store.storage.is_logical_log(), - "for now logical log is the only storage that can return IO" - ); if !io.finished() { return Ok(TransitionResult::Io(io)); } @@ -897,13 +878,11 @@ impl StateTransition for CommitStateMachine { let schema = connection.schema.borrow().clone(); connection.db.update_schema_if_newer(schema)?; } - if mvcc_store.storage.is_logical_log() { - let tx = mvcc_store.txs.get(&self.tx_id).unwrap(); - let tx_unlocked = tx.value(); - self.header.write().replace(*tx_unlocked.header.borrow()); - tracing::trace!("end_commit_logical_log(tx_id={})", self.tx_id); - self.commit_coordinator.pager_commit_lock.unlock(); - } + let tx = mvcc_store.txs.get(&self.tx_id).unwrap(); + let tx_unlocked = tx.value(); + self.header.write().replace(*tx_unlocked.header.borrow()); + tracing::trace!("end_commit_logical_log(tx_id={})", self.tx_id); + self.commit_coordinator.pager_commit_lock.unlock(); self.state = CommitState::CommitEnd { end_ts: *end_ts }; return Ok(TransitionResult::Continue); } @@ -1422,38 +1401,12 @@ impl MvStore { /// /// This is used for IMMEDIATE and EXCLUSIVE transaction types where we need /// to ensure exclusive write access as per SQLite semantics. + #[instrument(skip_all, level = Level::DEBUG)] pub fn begin_exclusive_tx( &self, pager: Arc, maybe_existing_tx_id: Option, ) -> Result> { - self._begin_exclusive_tx(pager, false, maybe_existing_tx_id) - } - - /// Upgrades a read transaction to an exclusive write transaction. - /// - /// This is used for IMMEDIATE and EXCLUSIVE transaction types where we need - /// to ensure exclusive write access as per SQLite semantics. - pub fn upgrade_to_exclusive_tx( - &self, - pager: Arc, - maybe_existing_tx_id: Option, - ) -> Result> { - self._begin_exclusive_tx(pager, true, maybe_existing_tx_id) - } - - /// Begins an exclusive write transaction that prevents concurrent writes. - /// - /// This is used for IMMEDIATE and EXCLUSIVE transaction types where we need - /// to ensure exclusive write access as per SQLite semantics. 
- #[instrument(skip_all, level = Level::DEBUG)] - fn _begin_exclusive_tx( - &self, - pager: Arc, - is_upgrade_from_read: bool, - maybe_existing_tx_id: Option, - ) -> Result> { - let is_logical_log = self.storage.is_logical_log(); let tx_id = maybe_existing_tx_id.unwrap_or_else(|| self.get_tx_id()); let begin_ts = if let Some(tx_id) = maybe_existing_tx_id { self.txs.get(&tx_id).unwrap().value().begin_ts @@ -1463,16 +1416,6 @@ impl MvStore { self.acquire_exclusive_tx(&tx_id)?; - // Try to acquire the pager read lock - if !is_upgrade_from_read && !is_logical_log { - pager.begin_read_tx().inspect_err(|_| { - tracing::debug!( - "begin_exclusive_tx: tx_id={} failed with Busy on pager_read_lock", - tx_id - ); - self.release_exclusive_tx(&tx_id); - })?; - } let locked = self.commit_coordinator.pager_commit_lock.write(); if !locked { tracing::debug!( @@ -1480,46 +1423,18 @@ impl MvStore { tx_id ); self.release_exclusive_tx(&tx_id); - pager.end_read_tx()?; return Err(LimboError::Busy); } let header = self.get_new_transaction_database_header(&pager); - if is_logical_log { - let tx = Transaction::new(tx_id, begin_ts, header); - tracing::trace!( - "begin_exclusive_tx(tx_id={}) - exclusive write logical log transaction", - tx_id - ); - tracing::debug!("begin_exclusive_tx: tx_id={} succeeded", tx_id); - self.txs.insert(tx_id, tx); - return Ok(IOResult::Done(tx_id)); - } - // Try to acquire the pager write lock - let begin_w_tx_res = pager.begin_write_tx(); - if let Err(LimboError::Busy) = begin_w_tx_res { - tracing::debug!("begin_exclusive_tx: tx_id={} failed with Busy", tx_id); - // Failed to get pager lock - release our exclusive lock - self.commit_coordinator.pager_commit_lock.unlock(); - self.release_exclusive_tx(&tx_id); - if maybe_existing_tx_id.is_none() { - // If we were upgrading an existing non-CONCURRENT mvcc transaction to write, we don't end the read tx on Busy. - // But if we were beginning a completely new non-CONCURRENT mvcc transaction, we do end it because the next time the connection - // attempts to do something, it will open a new read tx, which will fail if we don't end this one here. - pager.end_read_tx()?; - } - return Err(LimboError::Busy); - } - return_if_io!(begin_w_tx_res); let tx = Transaction::new(tx_id, begin_ts, header); tracing::trace!( - "begin_exclusive_tx(tx_id={}) - exclusive write transaction", + "begin_exclusive_tx(tx_id={}) - exclusive write logical log transaction", tx_id ); tracing::debug!("begin_exclusive_tx: tx_id={} succeeded", tx_id); self.txs.insert(tx_id, tx); - Ok(IOResult::Done(tx_id)) } @@ -1532,12 +1447,6 @@ impl MvStore { let tx_id = self.get_tx_id(); let begin_ts = self.get_timestamp(); - // TODO: we need to tie a pager's read transaction to a transaction ID, so that future refactors to read - // pages from WAL/DB read from a consistent state to maintiain snapshot isolation. 
- if !self.storage.is_logical_log() { - pager.begin_read_tx()?; - } - // Set txn's header to the global header let header = self.get_new_transaction_database_header(&pager); let tx = Transaction::new(tx_id, begin_ts, header); diff --git a/core/mvcc/persistent_storage/mod.rs b/core/mvcc/persistent_storage/mod.rs index b92bf081e..cfe977a5f 100644 --- a/core/mvcc/persistent_storage/mod.rs +++ b/core/mvcc/persistent_storage/mod.rs @@ -1,6 +1,5 @@ -use std::cell::RefCell; use std::fmt::Debug; -use std::sync::Arc; +use std::sync::{Arc, RwLock}; mod logical_log; use crate::mvcc::database::LogRecord; @@ -9,32 +8,28 @@ use crate::types::IOResult; use crate::{File, Result}; pub struct Storage { - logical_log: RefCell, + logical_log: RwLock, } impl Storage { pub fn new(file: Arc) -> Self { Self { - logical_log: RefCell::new(LogicalLog::new(file)), + logical_log: RwLock::new(LogicalLog::new(file)), } } } impl Storage { pub fn log_tx(&self, m: &LogRecord) -> Result> { - self.logical_log.borrow_mut().log_tx(m) + self.logical_log.write().unwrap().log_tx(m) } pub fn read_tx_log(&self) -> Result> { todo!() } - pub fn is_logical_log(&self) -> bool { - true - } - pub fn sync(&self) -> Result> { - self.logical_log.borrow_mut().sync() + self.logical_log.write().unwrap().sync() } } diff --git a/core/pragma.rs b/core/pragma.rs index edcfd21b9..c83509a69 100644 --- a/core/pragma.rs +++ b/core/pragma.rs @@ -102,6 +102,10 @@ pub fn pragma_for(pragma: &PragmaName) -> Pragma { PragmaFlags::NoColumns1 | PragmaFlags::Result0, &["auto_vacuum"], ), + BusyTimeout => Pragma::new( + PragmaFlags::NoColumns1 | PragmaFlags::Result0, + &["busy_timeout"], + ), IntegrityCheck => Pragma::new( PragmaFlags::NeedSchema | PragmaFlags::ReadOnly | PragmaFlags::Result0, &["message"], diff --git a/core/schema.rs b/core/schema.rs index 6d510b3a3..71cbb4932 100644 --- a/core/schema.rs +++ b/core/schema.rs @@ -527,7 +527,7 @@ impl Schema { let table = Arc::new(Table::BTree(Arc::new(BTreeTable { name: view_name.clone(), root_page: main_root, - columns: incremental_view.columns.clone(), + columns: incremental_view.column_schema.flat_columns(), primary_key_columns: Vec::new(), has_rowid: true, is_strict: false, @@ -673,11 +673,12 @@ impl Schema { .. 
} => { // Extract actual columns from the SELECT statement - let view_columns = crate::util::extract_view_columns(&select, self); + let view_column_schema = + crate::util::extract_view_columns(&select, self)?; // If column names were provided in CREATE VIEW (col1, col2, ...), // use them to rename the columns - let mut final_columns = view_columns; + let mut final_columns = view_column_schema.flat_columns(); for (i, indexed_col) in column_names.iter().enumerate() { if let Some(col) = final_columns.get_mut(i) { col.name = Some(indexed_col.col_name.to_string()); diff --git a/core/storage/btree.rs b/core/storage/btree.rs index e9c888e80..21806cb27 100644 --- a/core/storage/btree.rs +++ b/core/storage/btree.rs @@ -3631,6 +3631,7 @@ impl BTreeCursor { ); let divider_cell_insert_idx_in_parent = balance_info.first_divider_cell + sibling_page_idx; + #[cfg(debug_assertions)] let overflow_cell_count_before = parent_contents.overflow_cells.len(); insert_into_cell( parent_contents, @@ -3638,9 +3639,9 @@ impl BTreeCursor { divider_cell_insert_idx_in_parent, usable_space, )?; - let overflow_cell_count_after = parent_contents.overflow_cells.len(); #[cfg(debug_assertions)] { + let overflow_cell_count_after = parent_contents.overflow_cells.len(); let divider_cell_is_overflow_cell = overflow_cell_count_after > overflow_cell_count_before; @@ -6664,6 +6665,20 @@ pub fn btree_init_page(page: &PageRef, page_type: PageType, offset: usize, usabl contents.write_fragmented_bytes_count(0); contents.write_rightmost_ptr(0); + + #[cfg(debug_assertions)] + { + // we might get already used page from the pool. generally this is not a problem because + // b tree access is very controlled. However, for encrypted pages (and also checksums) we want + // to ensure that there are no reserved bytes that contain old data. + let buffer_len = contents.buffer.len(); + turso_assert!( + usable_space <= buffer_len, + "usable_space must be <= buffer_len" + ); + // this is no op if usable_space == buffer_len + contents.as_ptr()[usable_space..buffer_len].fill(0); + } } fn to_static_buf(buf: &mut [u8]) -> &'static mut [u8] { diff --git a/core/storage/database.rs b/core/storage/database.rs index e7aceebbf..3cbc42b9f 100644 --- a/core/storage/database.rs +++ b/core/storage/database.rs @@ -88,11 +88,6 @@ pub struct DatabaseFile { file: Arc, } -#[cfg(feature = "fs")] -unsafe impl Send for DatabaseFile {} -#[cfg(feature = "fs")] -unsafe impl Sync for DatabaseFile {} - #[cfg(feature = "fs")] impl DatabaseStorage for DatabaseFile { #[instrument(skip_all, level = Level::DEBUG)] diff --git a/core/storage/encryption.rs b/core/storage/encryption.rs index fb5406b85..c43a6d660 100644 --- a/core/storage/encryption.rs +++ b/core/storage/encryption.rs @@ -440,11 +440,19 @@ impl EncryptionContext { }; let metadata_size = self.cipher_mode.metadata_size(); let reserved_bytes = &page[self.page_size - metadata_size..]; - let reserved_bytes_zeroed = reserved_bytes.iter().all(|&b| b == 0); - assert!( - reserved_bytes_zeroed, - "last reserved bytes must be empty/zero, but found non-zero bytes" - ); + + #[cfg(debug_assertions)] + { + use crate::turso_assert; + // In debug builds, ensure that the reserved bytes are zeroed out. So even when we are + // reusing a page from buffer pool, we zero out in debug build so that we can be + // sure that b tree layer is not writing any data into the reserved space. 
+ let reserved_bytes_zeroed = reserved_bytes.iter().all(|&b| b == 0); + turso_assert!( + reserved_bytes_zeroed, + "last reserved bytes must be empty/zero, but found non-zero bytes" + ); + } let payload = &page[encryption_start_offset..self.page_size - metadata_size]; let (encrypted, nonce) = self.encrypt_raw(payload)?; diff --git a/core/translate/alter.rs b/core/translate/alter.rs index e3ecf43ea..d3170fdbe 100644 --- a/core/translate/alter.rs +++ b/core/translate/alter.rs @@ -48,6 +48,15 @@ pub fn translate_alter_table( ))); }; + // Check if this table has dependent materialized views + let dependent_views = schema.get_dependent_materialized_views(table_name); + if !dependent_views.is_empty() { + return Err(LimboError::ParseError(format!( + "cannot alter table \"{table_name}\": it has dependent materialized view(s): {}", + dependent_views.join(", ") + ))); + } + let mut btree = (*original_btree).clone(); Ok(match alter_table { diff --git a/core/translate/delete.rs b/core/translate/delete.rs index dee30b2af..c2a76f9ec 100644 --- a/core/translate/delete.rs +++ b/core/translate/delete.rs @@ -1,5 +1,6 @@ use crate::schema::Table; use crate::translate::emitter::emit_program; +use crate::translate::expr::ParamState; use crate::translate::optimizer::optimize_plan; use crate::translate::plan::{DeletePlan, Operation, Plan}; use crate::translate::planner::{parse_limit, parse_where}; @@ -108,6 +109,7 @@ pub fn prepare_delete_plan( let mut table_references = TableReferences::new(joined_tables, vec![]); let mut where_predicates = vec![]; + let mut param_ctx = ParamState::default(); // Parse the WHERE clause parse_where( @@ -116,11 +118,13 @@ pub fn prepare_delete_plan( None, &mut where_predicates, connection, + &mut param_ctx, )?; // Parse the LIMIT/OFFSET clause - let (resolved_limit, resolved_offset) = - limit.map_or(Ok((None, None)), |mut l| parse_limit(&mut l, connection))?; + let (resolved_limit, resolved_offset) = limit.map_or(Ok((None, None)), |mut l| { + parse_limit(&mut l, connection, &mut param_ctx) + })?; let plan = DeletePlan { table_references, diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 95ac93d95..38eca1ae4 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use tracing::{instrument, Level}; use turso_parser::ast::{self, As, Expr, UnaryOperator}; @@ -8,8 +10,12 @@ use super::plan::TableReferences; use crate::function::JsonFunc; use crate::function::{Func, FuncCtx, MathFuncArity, ScalarFunc, VectorFunc}; use crate::functions::datetime; +use crate::parameters::PARAM_PREFIX; use crate::schema::{affinity, Affinity, Table, Type}; -use crate::util::{exprs_are_equivalent, parse_numeric_literal}; +use crate::translate::optimizer::TakeOwnership; +use crate::translate::plan::ResultSetColumn; +use crate::translate::planner::parse_row_id; +use crate::util::{exprs_are_equivalent, normalize_ident, parse_numeric_literal}; use crate::vdbe::builder::CursorKey; use crate::vdbe::{ builder::ProgramBuilder, @@ -3244,6 +3250,296 @@ where Ok(WalkControl::Continue) } +/// Context needed to walk all expressions in a INSERT|UPDATE|SELECT|DELETE body, +/// in the order they are encountered, to ensure that the parameters are rewritten from +/// anonymous ("?") to our internal named scheme so when the columns are re-ordered we are able +/// to bind the proper parameter values. 
+pub struct ParamState { + /// ALWAYS starts at 1 + pub next_param_idx: usize, +} + +impl Default for ParamState { + fn default() -> Self { + Self { next_param_idx: 1 } + } +} + +/// Rewrite ast::Expr in place, binding Column references/rewriting Expr::Id -> Expr::Column +/// using the provided TableReferences, and replacing anonymous parameters with internal named +/// ones, as well as normalizing any DoublyQualified/Qualified quoted identifiers. +pub fn bind_and_rewrite_expr<'a>( + top_level_expr: &mut ast::Expr, + mut referenced_tables: Option<&'a mut TableReferences>, + result_columns: Option<&'a [ResultSetColumn]>, + connection: &'a Arc, + param_state: &mut ParamState, +) -> Result { + walk_expr_mut( + top_level_expr, + &mut |expr: &mut ast::Expr| -> Result { + match expr { + ast::Expr::Id(ast::Name::Ident(n)) if n.eq_ignore_ascii_case("true") => { + *expr = ast::Expr::Literal(ast::Literal::Numeric("1".to_string())); + } + ast::Expr::Id(ast::Name::Ident(n)) if n.eq_ignore_ascii_case("false") => { + *expr = ast::Expr::Literal(ast::Literal::Numeric("0".to_string())); + } + // Rewrite anonymous variables in encounter order. + ast::Expr::Variable(var) if var.is_empty() => { + *expr = ast::Expr::Variable(format!( + "{}{}", + PARAM_PREFIX, param_state.next_param_idx + )); + param_state.next_param_idx += 1; + } + ast::Expr::Qualified(ast::Name::Quoted(ns), ast::Name::Quoted(c)) + | ast::Expr::DoublyQualified(_, ast::Name::Quoted(ns), ast::Name::Quoted(c)) => { + *expr = ast::Expr::Qualified( + ast::Name::Ident(normalize_ident(ns.as_str())), + ast::Name::Ident(normalize_ident(c.as_str())), + ); + } + ast::Expr::Between { + lhs, + not, + start, + end, + } => { + let (lower_op, upper_op) = if *not { + (ast::Operator::Greater, ast::Operator::Greater) + } else { + (ast::Operator::LessEquals, ast::Operator::LessEquals) + }; + let start = start.take_ownership(); + let lhs_v = lhs.take_ownership(); + let end = end.take_ownership(); + + let lower = + ast::Expr::Binary(Box::new(start), lower_op, Box::new(lhs_v.clone())); + let upper = ast::Expr::Binary(Box::new(lhs_v), upper_op, Box::new(end)); + + *expr = if *not { + ast::Expr::Binary(Box::new(lower), ast::Operator::Or, Box::new(upper)) + } else { + ast::Expr::Binary(Box::new(lower), ast::Operator::And, Box::new(upper)) + }; + } + _ => {} + } + if let Some(referenced_tables) = &mut referenced_tables { + match expr { + // Unqualified identifier binding (including rowid aliases, outer refs, result-column fallback). + Expr::Id(id) => { + let normalized_id = normalize_ident(id.as_str()); + if !referenced_tables.joined_tables().is_empty() { + if let Some(row_id_expr) = parse_row_id( + &normalized_id, + referenced_tables.joined_tables()[0].internal_id, + || referenced_tables.joined_tables().len() != 1, + )? { + *expr = row_id_expr; + + return Ok(WalkControl::Continue); + } + } + let mut match_result = None; + + // First check joined tables + for joined_table in referenced_tables.joined_tables().iter() { + let col_idx = joined_table.table.columns().iter().position(|c| { + c.name + .as_ref() + .is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id)) + }); + if col_idx.is_some() { + if match_result.is_some() { + crate::bail_parse_error!("Column {} is ambiguous", id.as_str()); + } + let col = + joined_table.table.columns().get(col_idx.unwrap()).unwrap(); + match_result = Some(( + joined_table.internal_id, + col_idx.unwrap(), + col.is_rowid_alias, + )); + } + } + + // Then check outer query references, if we still didn't find something. 
+ // Normally finding multiple matches for a non-qualified column is an error (column x is ambiguous) + // but in the case of subqueries, the inner query takes precedence. + // For example: + // SELECT * FROM t WHERE x = (SELECT x FROM t2) + // In this case, there is no ambiguity: + // - x in the outer query refers to t.x, + // - x in the inner query refers to t2.x. + if match_result.is_none() { + for outer_ref in referenced_tables.outer_query_refs().iter() { + let col_idx = outer_ref.table.columns().iter().position(|c| { + c.name.as_ref().is_some_and(|name| { + name.eq_ignore_ascii_case(&normalized_id) + }) + }); + if col_idx.is_some() { + if match_result.is_some() { + crate::bail_parse_error!( + "Column {} is ambiguous", + id.as_str() + ); + } + let col = + outer_ref.table.columns().get(col_idx.unwrap()).unwrap(); + match_result = Some(( + outer_ref.internal_id, + col_idx.unwrap(), + col.is_rowid_alias, + )); + } + } + } + + if let Some((table_id, col_idx, is_rowid_alias)) = match_result { + *expr = Expr::Column { + database: None, // TODO: support different databases + table: table_id, + column: col_idx, + is_rowid_alias, + }; + referenced_tables.mark_column_used(table_id, col_idx); + return Ok(WalkControl::Continue); + } + + if let Some(result_columns) = result_columns { + for result_column in result_columns.iter() { + if result_column + .name(referenced_tables) + .is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id)) + { + *expr = result_column.expr.clone(); + return Ok(WalkControl::Continue); + } + } + } + // SQLite behavior: Only double-quoted identifiers get fallback to string literals + // Single quotes are handled as literals earlier, unquoted identifiers must resolve to columns + if id.is_double_quoted() { + // Convert failed double-quoted identifier to string literal + *expr = Expr::Literal(ast::Literal::String(id.as_str().to_string())); + return Ok(WalkControl::Continue); + } else { + // Unquoted identifiers must resolve to columns - no fallback + crate::bail_parse_error!("no such column: {}", id.as_str()) + } + } + Expr::Qualified(tbl, id) => { + let normalized_table_name = normalize_ident(tbl.as_str()); + let matching_tbl = referenced_tables + .find_table_and_internal_id_by_identifier(&normalized_table_name); + if matching_tbl.is_none() { + crate::bail_parse_error!("no such table: {}", normalized_table_name); + } + let (tbl_id, tbl) = matching_tbl.unwrap(); + let normalized_id = normalize_ident(id.as_str()); + + if let Some(row_id_expr) = parse_row_id(&normalized_id, tbl_id, || false)? 
{ + *expr = row_id_expr; + + return Ok(WalkControl::Continue); + } + let col_idx = tbl.columns().iter().position(|c| { + c.name + .as_ref() + .is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id)) + }); + let Some(col_idx) = col_idx else { + crate::bail_parse_error!("no such column: {}", normalized_id); + }; + let col = tbl.columns().get(col_idx).unwrap(); + *expr = Expr::Column { + database: None, // TODO: support different databases + table: tbl_id, + column: col_idx, + is_rowid_alias: col.is_rowid_alias, + }; + referenced_tables.mark_column_used(tbl_id, col_idx); + return Ok(WalkControl::Continue); + } + Expr::DoublyQualified(db_name, tbl_name, col_name) => { + let normalized_col_name = normalize_ident(col_name.as_str()); + + // Create a QualifiedName and use existing resolve_database_id method + let qualified_name = ast::QualifiedName { + db_name: Some(db_name.clone()), + name: tbl_name.clone(), + alias: None, + }; + let database_id = connection.resolve_database_id(&qualified_name)?; + + // Get the table from the specified database + let table = connection + .with_schema(database_id, |schema| schema.get_table(tbl_name.as_str())) + .ok_or_else(|| { + crate::LimboError::ParseError(format!( + "no such table: {}.{}", + db_name.as_str(), + tbl_name.as_str() + )) + })?; + + // Find the column in the table + let col_idx = table + .columns() + .iter() + .position(|c| { + c.name.as_ref().is_some_and(|name| { + name.eq_ignore_ascii_case(&normalized_col_name) + }) + }) + .ok_or_else(|| { + crate::LimboError::ParseError(format!( + "Column: {}.{}.{} not found", + db_name.as_str(), + tbl_name.as_str(), + col_name.as_str() + )) + })?; + + let col = table.columns().get(col_idx).unwrap(); + + // Check if this is a rowid alias + let is_rowid_alias = col.is_rowid_alias; + + // Convert to Column expression - since this is a cross-database reference, + // we need to create a synthetic table reference for it + // For now, we'll error if the table isn't already in the referenced tables + let normalized_tbl_name = normalize_ident(tbl_name.as_str()); + let matching_tbl = referenced_tables + .find_table_and_internal_id_by_identifier(&normalized_tbl_name); + + if let Some((tbl_id, _)) = matching_tbl { + // Table is already in referenced tables, use existing internal ID + *expr = Expr::Column { + database: Some(database_id), + table: tbl_id, + column: col_idx, + is_rowid_alias, + }; + referenced_tables.mark_column_used(tbl_id, col_idx); + } else { + return Err(crate::LimboError::ParseError(format!( + "table {normalized_tbl_name} is not in FROM clause - cross-database column references require the table to be explicitly joined" + ))); + } + } + _ => {} + } + } + Ok(WalkControl::Continue) + }, + ) +} + /// Recursively walks a mutable expression, applying a function to each sub-expression. 
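The unqualified-identifier arm of bind_and_rewrite_expr above resolves names with a fixed precedence: the rowid name, then columns of joined tables, then outer-query references, then result-column aliases, with double-quoted identifiers falling back to string literals and everything else failing with "no such column". A rough sketch of that precedence follows; the types and helper are hypothetical simplifications, and the rowid handling and ambiguity checks are omitted.

// Sketch of the resolution order only; not the real planner types.
#[derive(Debug, PartialEq)]
enum Resolved {
    Column { table: String, index: usize },
    ResultAlias(String),
    StringLiteral(String),
}

fn resolve(
    name: &str,
    double_quoted: bool,
    joined: &[(&str, &[&str])], // joined tables: (name, columns)
    outer: &[(&str, &[&str])],  // outer query references
    result_aliases: &[&str],    // result-column aliases
) -> Result<Resolved, String> {
    // 1) joined tables, 2) outer query refs (the inner query wins over the outer one)
    for (table, cols) in joined.iter().chain(outer.iter()) {
        if let Some(index) = cols.iter().position(|c| c.eq_ignore_ascii_case(name)) {
            return Ok(Resolved::Column { table: table.to_string(), index });
        }
    }
    // 3) result-column aliases
    if result_aliases.iter().any(|a| a.eq_ignore_ascii_case(name)) {
        return Ok(Resolved::ResultAlias(name.to_string()));
    }
    // 4) only double-quoted identifiers fall back to string literals
    if double_quoted {
        Ok(Resolved::StringLiteral(name.to_string()))
    } else {
        Err(format!("no such column: {name}"))
    }
}

fn main() {
    let joined: &[(&str, &[&str])] = &[("t", &["id", "name"])];
    let outer: &[(&str, &[&str])] = &[("t2", &["x"])];
    assert_eq!(
        resolve("x", false, joined, outer, &[]),
        Ok(Resolved::Column { table: "t2".to_string(), index: 0 })
    );
    assert!(resolve("missing", false, joined, outer, &[]).is_err());
    assert_eq!(
        resolve("missing", true, joined, outer, &[]),
        Ok(Resolved::StringLiteral("missing".to_string()))
    );
}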
pub fn walk_expr_mut(expr: &mut ast::Expr, func: &mut F) -> Result where @@ -3709,12 +4005,12 @@ pub fn process_returning_clause( table_name: &str, program: &mut ProgramBuilder, connection: &std::sync::Arc, + param_ctx: &mut ParamState, ) -> Result<( Vec, super::plan::TableReferences, )> { use super::plan::{ColumnUsedMask, JoinedTable, Operation, ResultSetColumn, TableReferences}; - use super::planner::bind_column_references; let mut result_columns = vec![]; @@ -3741,7 +4037,13 @@ pub fn process_returning_clause( ast::ResultColumn::Expr(expr, alias) => { let column_alias = determine_column_alias(expr, alias, table); - bind_column_references(expr, &mut table_references, None, connection)?; + bind_and_rewrite_expr( + expr, + Some(&mut table_references), + None, + connection, + param_ctx, + )?; result_columns.push(ResultSetColumn { expr: *expr.clone(), diff --git a/core/translate/insert.rs b/core/translate/insert.rs index 83c176a77..ddcf00755 100644 --- a/core/translate/insert.rs +++ b/core/translate/insert.rs @@ -10,7 +10,8 @@ use crate::translate::emitter::{ emit_cdc_insns, emit_cdc_patch_record, prepare_cdc_if_necessary, OperationMode, }; use crate::translate::expr::{ - emit_returning_results, process_returning_clause, ReturningValueRegisters, + bind_and_rewrite_expr, emit_returning_results, process_returning_clause, ParamState, + ReturningValueRegisters, }; use crate::translate::planner::ROWID; use crate::translate::upsert::{ @@ -31,7 +32,6 @@ use crate::{Result, SymbolTable, VirtualTable}; use super::emitter::Resolver; use super::expr::{translate_expr, translate_expr_no_constant_opt, NoConstantOptReason}; -use super::optimizer::rewrite_expr; use super::plan::QueryDestination; use super::select::translate_select; @@ -118,7 +118,7 @@ pub fn translate_insert( let mut values: Option>> = None; let mut upsert_opt: Option = None; - let mut param_idx = 1; + let mut param_ctx = ParamState::default(); let mut inserting_multiple_rows = false; if let InsertBody::Select(select, upsert) = &mut body { match &mut select.body.select { @@ -144,7 +144,7 @@ pub fn translate_insert( } _ => {} } - rewrite_expr(expr, &mut param_idx)?; + bind_and_rewrite_expr(expr, None, None, connection, &mut param_ctx)?; } values = values_expr.pop(); } @@ -157,10 +157,10 @@ pub fn translate_insert( } = &mut upsert.do_clause { for set in sets.iter_mut() { - rewrite_expr(set.expr.as_mut(), &mut param_idx)?; + bind_and_rewrite_expr(&mut set.expr, None, None, connection, &mut param_ctx)?; } if let Some(ref mut where_expr) = where_clause { - rewrite_expr(where_expr.as_mut(), &mut param_idx)?; + bind_and_rewrite_expr(where_expr, None, None, connection, &mut param_ctx)?; } } } @@ -180,6 +180,7 @@ pub fn translate_insert( table_name.as_str(), &mut program, connection, + &mut param_ctx, )?; let mut yield_reg_opt = None; diff --git a/core/translate/logical.rs b/core/translate/logical.rs index aa71f047b..b11e2df4f 100644 --- a/core/translate/logical.rs +++ b/core/translate/logical.rs @@ -19,23 +19,35 @@ use turso_parser::ast; /// Result type for preprocessing aggregate expressions type PreprocessAggregateResult = ( - bool, // needs_pre_projection - Vec, // pre_projection_exprs - Vec<(String, Type)>, // pre_projection_schema - Vec, // modified_aggr_exprs + bool, // needs_pre_projection + Vec, // pre_projection_exprs + Vec, // pre_projection_schema + Vec, // modified_aggr_exprs ); +/// Result type for parsing join conditions +type JoinConditionsResult = (Vec<(LogicalExpr, LogicalExpr)>, Option); + +/// Information about a column in a 
logical schema +#[derive(Debug, Clone, PartialEq)] +pub struct ColumnInfo { + pub name: String, + pub ty: Type, + pub database: Option, + pub table: Option, + pub table_alias: Option, +} + /// Schema information for logical plan nodes #[derive(Debug, Clone, PartialEq)] pub struct LogicalSchema { - /// Column names and types - pub columns: Vec<(String, Type)>, + pub columns: Vec, } /// A reference to a schema that can be shared between nodes pub type SchemaRef = Arc; impl LogicalSchema { - pub fn new(columns: Vec<(String, Type)>) -> Self { + pub fn new(columns: Vec) -> Self { Self { columns } } @@ -49,11 +61,42 @@ impl LogicalSchema { self.columns.len() } - pub fn find_column(&self, name: &str) -> Option<(usize, &Type)> { - self.columns - .iter() - .position(|(n, _)| n == name) - .map(|idx| (idx, &self.columns[idx].1)) + pub fn find_column(&self, name: &str, table: Option<&str>) -> Option<(usize, &ColumnInfo)> { + if let Some(table_ref) = table { + // Check if it's a database.table format + if table_ref.contains('.') { + let parts: Vec<&str> = table_ref.splitn(2, '.').collect(); + if parts.len() == 2 { + let db = parts[0]; + let tbl = parts[1]; + return self + .columns + .iter() + .position(|c| { + c.name == name + && c.database.as_deref() == Some(db) + && c.table.as_deref() == Some(tbl) + }) + .map(|idx| (idx, &self.columns[idx])); + } + } + + // Try to match against table alias first, then table name + self.columns + .iter() + .position(|c| { + c.name == name + && (c.table_alias.as_deref() == Some(table_ref) + || c.table.as_deref() == Some(table_ref)) + }) + .map(|idx| (idx, &self.columns[idx])) + } else { + // Unqualified lookup - just match by name + self.columns + .iter() + .position(|c| c.name == name) + .map(|idx| (idx, &self.columns[idx])) + } } } @@ -66,8 +109,8 @@ pub enum LogicalPlan { Filter(Filter), /// Aggregate - GROUP BY with aggregate functions Aggregate(Aggregate), - // TODO: Join - combining two relations (not yet implemented) - // Join(Join), + /// Join - combining two relations + Join(Join), /// Sort - ORDER BY clause Sort(Sort), /// Limit - LIMIT/OFFSET clause @@ -95,7 +138,7 @@ impl LogicalPlan { LogicalPlan::Projection(p) => &p.schema, LogicalPlan::Filter(f) => f.input.schema(), LogicalPlan::Aggregate(a) => &a.schema, - // LogicalPlan::Join(j) => &j.schema, + LogicalPlan::Join(j) => &j.schema, LogicalPlan::Sort(s) => s.input.schema(), LogicalPlan::Limit(l) => l.input.schema(), LogicalPlan::TableScan(t) => &t.schema, @@ -133,26 +176,26 @@ pub struct Aggregate { pub schema: SchemaRef, } -// TODO: Join operator (not yet implemented) -// #[derive(Debug, Clone, PartialEq)] -// pub struct Join { -// pub left: Arc, -// pub right: Arc, -// pub join_type: JoinType, -// pub on: Vec<(LogicalExpr, LogicalExpr)>, // Equijoin conditions -// pub filter: Option, // Additional filter conditions -// pub schema: SchemaRef, -// } +/// Types of joins +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum JoinType { + Inner, + Left, + Right, + Full, + Cross, +} -// TODO: Types of joins (not yet implemented) -// #[derive(Debug, Clone, Copy, PartialEq)] -// pub enum JoinType { -// Inner, -// Left, -// Right, -// Full, -// Cross, -// } +/// Join operator - combines two relations +#[derive(Debug, Clone, PartialEq)] +pub struct Join { + pub left: Arc, + pub right: Arc, + pub join_type: JoinType, + pub on: Vec<(LogicalExpr, LogicalExpr)>, // Equijoin conditions (left_expr, right_expr) + pub filter: Option, // Additional filter conditions + pub schema: SchemaRef, +} /// Sort operator - ORDER BY 
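The parser exposes join keywords as a set of JoinType flags, and build_join further below folds those flags into the single logical JoinType declared above (a bare comma in the FROM list is mapped straight to Cross without consulting flags). The mapping is easier to see in isolation, so here is a hedged sketch with stand-in flag bits rather than the real ast::JoinType bitflags; note that in the actual code a NATURAL join does not get its own variant but is routed to build_natural_join, which derives equijoin conditions from the common columns.

// Stand-in flag bits; the real flags live on ast::JoinType.
const NATURAL: u8 = 1 << 0;
const LEFT: u8 = 1 << 1;
const RIGHT: u8 = 1 << 2;
const OUTER: u8 = 1 << 3;
const INNER: u8 = 1 << 4;
const CROSS: u8 = 1 << 5;

#[derive(Debug, PartialEq)]
enum JoinKind {
    Inner,
    Left,
    Right,
    Full,
    Cross,
    Natural, // stand-in; the real code handles NATURAL separately
}

fn decode(flags: u8) -> JoinKind {
    let has = |f: u8| flags & f != 0;
    if has(NATURAL) {
        JoinKind::Natural
    } else if has(LEFT) && has(RIGHT) && has(OUTER) {
        JoinKind::Full // FULL OUTER JOIN carries LEFT, RIGHT and OUTER
    } else if has(LEFT) {
        JoinKind::Left // LEFT and LEFT OUTER
    } else if has(RIGHT) {
        JoinKind::Right // RIGHT and RIGHT OUTER
    } else if has(OUTER) {
        JoinKind::Full // a bare OUTER JOIN is treated as FULL
    } else if has(CROSS) {
        JoinKind::Cross
    } else {
        JoinKind::Inner // plain JOIN / INNER JOIN
    }
}

fn main() {
    assert_eq!(decode(LEFT | OUTER), JoinKind::Left);
    assert_eq!(decode(RIGHT | OUTER), JoinKind::Right);
    assert_eq!(decode(LEFT | RIGHT | OUTER), JoinKind::Full);
    assert_eq!(decode(CROSS), JoinKind::Cross);
    assert_eq!(decode(INNER), JoinKind::Inner);
    assert_eq!(decode(0), JoinKind::Inner);
}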
#[derive(Debug, Clone, PartialEq)] @@ -545,14 +588,14 @@ impl<'a> LogicalPlanBuilder<'a> { } // Regular table scan - let table_schema = self.get_table_schema(&table_name)?; let table_alias = alias.as_ref().map(|a| match a { ast::As::As(name) => Self::name_to_string(name), ast::As::Elided(name) => Self::name_to_string(name), }); + let table_schema = self.get_table_schema(&table_name, table_alias.as_deref())?; Ok(LogicalPlan::TableScan(TableScan { table_name, - alias: table_alias, + alias: table_alias.clone(), schema: table_schema, projection: None, })) @@ -570,14 +613,291 @@ impl<'a> LogicalPlanBuilder<'a> { // Build JOIN fn build_join( &mut self, - _left: LogicalPlan, - _right: LogicalPlan, - _op: &ast::JoinOperator, - _constraint: &Option, + left: LogicalPlan, + right: LogicalPlan, + op: &ast::JoinOperator, + constraint: &Option, ) -> Result { - Err(LimboError::ParseError( - "JOINs are not yet supported in logical plans".to_string(), - )) + // Determine join type + let join_type = match op { + ast::JoinOperator::Comma => JoinType::Cross, // Comma is essentially a cross join + ast::JoinOperator::TypedJoin(Some(jt)) => { + // Check the join type flags + // Note: JoinType can have multiple flags set + if jt.contains(ast::JoinType::NATURAL) { + // Natural joins need special handling - find common columns + return self.build_natural_join(left, right, JoinType::Inner); + } else if jt.contains(ast::JoinType::LEFT) + && jt.contains(ast::JoinType::RIGHT) + && jt.contains(ast::JoinType::OUTER) + { + // FULL OUTER JOIN (has LEFT, RIGHT, and OUTER) + JoinType::Full + } else if jt.contains(ast::JoinType::LEFT) && jt.contains(ast::JoinType::OUTER) { + JoinType::Left + } else if jt.contains(ast::JoinType::RIGHT) && jt.contains(ast::JoinType::OUTER) { + JoinType::Right + } else if jt.contains(ast::JoinType::OUTER) + && !jt.contains(ast::JoinType::LEFT) + && !jt.contains(ast::JoinType::RIGHT) + { + // Plain OUTER JOIN should also be FULL + JoinType::Full + } else if jt.contains(ast::JoinType::LEFT) { + JoinType::Left + } else if jt.contains(ast::JoinType::RIGHT) { + JoinType::Right + } else if jt.contains(ast::JoinType::CROSS) + || (jt.contains(ast::JoinType::INNER) && jt.contains(ast::JoinType::CROSS)) + { + JoinType::Cross + } else { + JoinType::Inner // Default to inner + } + } + ast::JoinOperator::TypedJoin(None) => JoinType::Inner, // Default JOIN is INNER JOIN + }; + + // Build join conditions + let (on_conditions, filter) = match constraint { + Some(ast::JoinConstraint::On(expr)) => { + // Parse ON clause into equijoin conditions and filters + self.parse_join_conditions(expr, left.schema(), right.schema())? 
+ } + Some(ast::JoinConstraint::Using(columns)) => { + // Build equijoin conditions from USING clause + let on = self.build_using_conditions(columns, left.schema(), right.schema())?; + (on, None) + } + None => { + // Cross join or natural join + (Vec::new(), None) + } + }; + + // Build combined schema + let schema = self.build_join_schema(&left, &right, &join_type)?; + + Ok(LogicalPlan::Join(Join { + left: Arc::new(left), + right: Arc::new(right), + join_type, + on: on_conditions, + filter, + schema, + })) + } + + // Helper: Parse join conditions into equijoins and filters + fn parse_join_conditions( + &mut self, + expr: &ast::Expr, + left_schema: &SchemaRef, + right_schema: &SchemaRef, + ) -> Result { + // For now, we'll handle simple equality conditions + // More complex conditions will go into the filter + let mut equijoins = Vec::new(); + let mut filters = Vec::new(); + + // Try to extract equijoin conditions from the expression + self.extract_equijoin_conditions( + expr, + left_schema, + right_schema, + &mut equijoins, + &mut filters, + )?; + + let filter = if filters.is_empty() { + None + } else { + // Combine multiple filters with AND + Some( + filters + .into_iter() + .reduce(|acc, e| LogicalExpr::BinaryExpr { + left: Box::new(acc), + op: BinaryOperator::And, + right: Box::new(e), + }) + .unwrap(), + ) + }; + + Ok((equijoins, filter)) + } + + // Helper: Extract equijoin conditions from expression + fn extract_equijoin_conditions( + &mut self, + expr: &ast::Expr, + left_schema: &SchemaRef, + right_schema: &SchemaRef, + equijoins: &mut Vec<(LogicalExpr, LogicalExpr)>, + filters: &mut Vec, + ) -> Result<()> { + match expr { + ast::Expr::Binary(lhs, ast::Operator::Equals, rhs) => { + // Check if this is an equijoin condition (left.col = right.col) + let left_expr = self.build_expr(lhs, left_schema)?; + let right_expr = self.build_expr(rhs, right_schema)?; + + // For simplicity, we'll check if one references left and one references right + // In a real implementation, we'd need more sophisticated column resolution + equijoins.push((left_expr, right_expr)); + } + ast::Expr::Binary(lhs, ast::Operator::And, rhs) => { + // Recursively extract from AND conditions + self.extract_equijoin_conditions( + lhs, + left_schema, + right_schema, + equijoins, + filters, + )?; + self.extract_equijoin_conditions( + rhs, + left_schema, + right_schema, + equijoins, + filters, + )?; + } + _ => { + // Other conditions go into the filter + // We need a combined schema to build the expression + let combined_schema = self.combine_schemas(left_schema, right_schema)?; + let filter_expr = self.build_expr(expr, &combined_schema)?; + filters.push(filter_expr); + } + } + Ok(()) + } + + // Helper: Build equijoin conditions from USING clause + fn build_using_conditions( + &mut self, + columns: &[ast::Name], + left_schema: &SchemaRef, + right_schema: &SchemaRef, + ) -> Result> { + let mut conditions = Vec::new(); + + for col_name in columns { + let name = Self::name_to_string(col_name); + + // Find the column in both schemas + let _left_idx = left_schema + .columns + .iter() + .position(|col| col.name == name) + .ok_or_else(|| { + LimboError::ParseError(format!("Column {name} not found in left table")) + })?; + let _right_idx = right_schema + .columns + .iter() + .position(|col| col.name == name) + .ok_or_else(|| { + LimboError::ParseError(format!("Column {name} not found in right table")) + })?; + + conditions.push(( + LogicalExpr::Column(Column { + name: name.clone(), + table: None, // Will be resolved later + }), + 
LogicalExpr::Column(Column { + name, + table: None, // Will be resolved later + }), + )); + } + + Ok(conditions) + } + + // Helper: Build natural join by finding common columns + fn build_natural_join( + &mut self, + left: LogicalPlan, + right: LogicalPlan, + join_type: JoinType, + ) -> Result { + let left_schema = left.schema(); + let right_schema = right.schema(); + + // Find common column names + let mut common_columns = Vec::new(); + for left_col in &left_schema.columns { + if right_schema + .columns + .iter() + .any(|col| col.name == left_col.name) + { + common_columns.push(ast::Name::Ident(left_col.name.clone())); + } + } + + if common_columns.is_empty() { + // Natural join with no common columns becomes a cross join + let schema = self.build_join_schema(&left, &right, &JoinType::Cross)?; + return Ok(LogicalPlan::Join(Join { + left: Arc::new(left), + right: Arc::new(right), + join_type: JoinType::Cross, + on: Vec::new(), + filter: None, + schema, + })); + } + + // Build equijoin conditions for common columns + let on = self.build_using_conditions(&common_columns, left_schema, right_schema)?; + let schema = self.build_join_schema(&left, &right, &join_type)?; + + Ok(LogicalPlan::Join(Join { + left: Arc::new(left), + right: Arc::new(right), + join_type, + on, + filter: None, + schema, + })) + } + + // Helper: Build schema for join result + fn build_join_schema( + &self, + left: &LogicalPlan, + right: &LogicalPlan, + _join_type: &JoinType, + ) -> Result { + let left_schema = left.schema(); + let right_schema = right.schema(); + + // Concatenate the schemas, preserving all column information + let mut columns = Vec::new(); + + // Keep all columns from left with their table info + for col in &left_schema.columns { + columns.push(col.clone()); + } + + // Keep all columns from right with their table info + for col in &right_schema.columns { + columns.push(col.clone()); + } + + Ok(Arc::new(LogicalSchema::new(columns))) + } + + // Helper: Combine two schemas for expression building + fn combine_schemas(&self, left: &SchemaRef, right: &SchemaRef) -> Result { + let mut columns = left.columns.clone(); + columns.extend(right.columns.clone()); + Ok(Arc::new(LogicalSchema::new(columns))) } // Build projection @@ -602,7 +922,13 @@ impl<'a> LogicalPlanBuilder<'a> { }; let col_type = Self::infer_expr_type(&logical_expr, input_schema)?; - schema_columns.push((col_name.clone(), col_type)); + schema_columns.push(ColumnInfo { + name: col_name.clone(), + ty: col_type, + database: None, + table: None, + table_alias: None, + }); if let Some(as_alias) = alias { let alias_name = match as_alias { @@ -618,21 +944,21 @@ impl<'a> LogicalPlanBuilder<'a> { } ast::ResultColumn::Star => { // Expand * to all columns - for (name, typ) in &input_schema.columns { - proj_exprs.push(LogicalExpr::Column(Column::new(name.clone()))); - schema_columns.push((name.clone(), *typ)); + for col in &input_schema.columns { + proj_exprs.push(LogicalExpr::Column(Column::new(col.name.clone()))); + schema_columns.push(col.clone()); } } ast::ResultColumn::TableStar(table) => { // Expand table.* to all columns from that table let table_name = Self::name_to_string(table); - for (name, typ) in &input_schema.columns { + for col in &input_schema.columns { // Simple check - would need proper table tracking in real implementation proj_exprs.push(LogicalExpr::Column(Column::with_table( - name.clone(), + col.name.clone(), table_name.clone(), ))); - schema_columns.push((name.clone(), *typ)); + schema_columns.push(col.clone()); } } } @@ -670,7 +996,13 
@@ impl<'a> LogicalPlanBuilder<'a> { if let LogicalExpr::Column(col) = expr { pre_projection_exprs.push(expr.clone()); let col_type = Self::infer_expr_type(expr, input_schema)?; - pre_projection_schema.push((col.name.clone(), col_type)); + pre_projection_schema.push(ColumnInfo { + name: col.name.clone(), + ty: col_type, + database: None, + table: col.table.clone(), + table_alias: None, + }); } else { // Complex group by expression - project it needs_pre_projection = true; @@ -678,7 +1010,13 @@ impl<'a> LogicalPlanBuilder<'a> { projected_col_counter += 1; pre_projection_exprs.push(expr.clone()); let col_type = Self::infer_expr_type(expr, input_schema)?; - pre_projection_schema.push((proj_col_name.clone(), col_type)); + pre_projection_schema.push(ColumnInfo { + name: proj_col_name.clone(), + ty: col_type, + database: None, + table: None, + table_alias: None, + }); } } @@ -702,7 +1040,13 @@ impl<'a> LogicalPlanBuilder<'a> { pre_projection_exprs.push(arg.clone()); let col_type = Self::infer_expr_type(arg, input_schema)?; if let LogicalExpr::Column(col) = arg { - pre_projection_schema.push((col.name.clone(), col_type)); + pre_projection_schema.push(ColumnInfo { + name: col.name.clone(), + ty: col_type, + database: None, + table: col.table.clone(), + table_alias: None, + }); } } } @@ -715,7 +1059,13 @@ impl<'a> LogicalPlanBuilder<'a> { // Add the expression to the pre-projection pre_projection_exprs.push(arg.clone()); let col_type = Self::infer_expr_type(arg, input_schema)?; - pre_projection_schema.push((proj_col_name.clone(), col_type)); + pre_projection_schema.push(ColumnInfo { + name: proj_col_name.clone(), + ty: col_type, + database: None, + table: None, + table_alias: None, + }); // In the aggregate, reference the projected column modified_args.push(LogicalExpr::Column(Column::new(proj_col_name))); @@ -789,15 +1139,39 @@ impl<'a> LogicalPlanBuilder<'a> { // First, add GROUP BY columns to the aggregate output schema // These are always part of the aggregate operator's output for group_expr in &group_exprs { - let col_name = match group_expr { - LogicalExpr::Column(col) => col.name.clone(), + match group_expr { + LogicalExpr::Column(col) => { + // For column references in GROUP BY, preserve the original column info + if let Some((_, col_info)) = + input_schema.find_column(&col.name, col.table.as_deref()) + { + // Preserve the column with all its table information + aggregate_schema_columns.push(col_info.clone()); + } else { + // Fallback if column not found (shouldn't happen) + let col_type = Self::infer_expr_type(group_expr, input_schema)?; + aggregate_schema_columns.push(ColumnInfo { + name: col.name.clone(), + ty: col_type, + database: None, + table: col.table.clone(), + table_alias: None, + }); + } + } _ => { // For complex GROUP BY expressions, generate a name - format!("__group_{}", aggregate_schema_columns.len()) + let col_name = format!("__group_{}", aggregate_schema_columns.len()); + let col_type = Self::infer_expr_type(group_expr, input_schema)?; + aggregate_schema_columns.push(ColumnInfo { + name: col_name, + ty: col_type, + database: None, + table: None, + table_alias: None, + }); } - }; - let col_type = Self::infer_expr_type(group_expr, input_schema)?; - aggregate_schema_columns.push((col_name, col_type)); + } } // Track aggregates we've already seen to avoid duplicates @@ -830,7 +1204,13 @@ impl<'a> LogicalPlanBuilder<'a> { } else { // New aggregate - add it let col_type = Self::infer_expr_type(&logical_expr, input_schema)?; - aggregate_schema_columns.push((col_name.clone(), 
col_type)); + aggregate_schema_columns.push(ColumnInfo { + name: col_name.clone(), + ty: col_type, + database: None, + table: None, + table_alias: None, + }); aggr_exprs.push(logical_expr); aggregate_map.insert(agg_key, col_name.clone()); col_name.clone() @@ -854,7 +1234,13 @@ impl<'a> LogicalPlanBuilder<'a> { // Add only new aggregates for (agg_expr, agg_name) in extracted_aggs { let agg_type = Self::infer_expr_type(&agg_expr, input_schema)?; - aggregate_schema_columns.push((agg_name, agg_type)); + aggregate_schema_columns.push(ColumnInfo { + name: agg_name, + ty: agg_type, + database: None, + table: None, + table_alias: None, + }); aggr_exprs.push(agg_expr); } @@ -929,7 +1315,13 @@ impl<'a> LogicalPlanBuilder<'a> { // For type inference, we need the aggregate schema for column references let aggregate_schema = LogicalSchema::new(aggregate_schema_columns.clone()); let col_type = Self::infer_expr_type(expr, &Arc::new(aggregate_schema))?; - projection_schema_columns.push((col_name, col_type)); + projection_schema_columns.push(ColumnInfo { + name: col_name, + ty: col_type, + database: None, + table: None, + table_alias: None, + }); } // Create the input plan (with pre-projection if needed) @@ -952,11 +1344,11 @@ impl<'a> LogicalPlanBuilder<'a> { // Check if we need the outer projection // We need a projection if: - // 1. Any expression is more complex than a simple column reference (e.g., abs(sum(id))) - // 2. We're selecting a different set of columns than what the aggregate outputs - // 3. Columns are renamed or reordered + // 1. We have expressions that compute new values (e.g., SUM(x) * 2) + // 2. We're selecting a different set of columns than GROUP BY + aggregates + // 3. We're reordering columns from their natural aggregate output order let needs_outer_projection = { - // Check if any expression is more complex than a simple column reference + // Check for complex expressions let has_complex_exprs = projection_exprs .iter() .any(|expr| !matches!(expr, LogicalExpr::Column(_))); @@ -964,17 +1356,29 @@ impl<'a> LogicalPlanBuilder<'a> { if has_complex_exprs { true } else { - // All are simple columns - check if we're selecting exactly what the aggregate outputs - // The projection might be selecting a subset (e.g., only aggregates without group columns) - // or reordering columns, or using different names + // Check if we're selecting exactly what aggregate outputs in the same order + // The aggregate outputs: all GROUP BY columns, then all aggregate expressions + // The projection might select a subset or reorder these - // For now, keep it simple: if schemas don't match exactly, we need projection - // This handles all cases: subset selection, reordering, renaming - projection_schema_columns != aggregate_schema_columns + if projection_exprs.len() != aggregate_schema_columns.len() { + // Different number of columns + true + } else { + // Check if columns match in order and name + !projection_exprs.iter().zip(&aggregate_schema_columns).all( + |(expr, agg_col)| { + if let LogicalExpr::Column(col) = expr { + col.name == agg_col.name + } else { + false + } + }, + ) + } } }; - // Create the aggregate node + // Create the aggregate node with its natural schema let aggregate_plan = LogicalPlan::Aggregate(Aggregate { input: aggregate_input, group_expr: group_exprs, @@ -989,7 +1393,7 @@ impl<'a> LogicalPlanBuilder<'a> { schema: Arc::new(LogicalSchema::new(projection_schema_columns)), })) } else { - // No projection needed - the aggregate output is exactly what we want + // No projection needed - 
aggregate output matches what we want Ok(aggregate_plan) } } @@ -1007,7 +1411,13 @@ impl<'a> LogicalPlanBuilder<'a> { // Infer schema from first row let mut schema_columns = Vec::new(); for (i, _) in values[0].iter().enumerate() { - schema_columns.push((format!("column{}", i + 1), Type::Text)); + schema_columns.push(ColumnInfo { + name: format!("column{}", i + 1), + ty: Type::Text, + database: None, + table: None, + table_alias: None, + }); } for row in values { @@ -1735,17 +2145,31 @@ impl<'a> LogicalPlanBuilder<'a> { } // Get table schema - fn get_table_schema(&self, table_name: &str) -> Result { + fn get_table_schema(&self, table_name: &str, alias: Option<&str>) -> Result { // Look up table in schema let table = self .schema .get_table(table_name) .ok_or_else(|| LimboError::ParseError(format!("Table '{table_name}' not found")))?; + // Parse table_name which might be "db.table" for attached databases + let (database, actual_table) = if table_name.contains('.') { + let parts: Vec<&str> = table_name.splitn(2, '.').collect(); + (Some(parts[0].to_string()), parts[1].to_string()) + } else { + (None, table_name.to_string()) + }; + let mut columns = Vec::new(); for col in table.columns() { if let Some(ref name) = col.name { - columns.push((name.clone(), col.ty)); + columns.push(ColumnInfo { + name: name.clone(), + ty: col.ty, + database: database.clone(), + table: Some(actual_table.clone()), + table_alias: alias.map(|s| s.to_string()), + }); } } @@ -1756,8 +2180,8 @@ impl<'a> LogicalPlanBuilder<'a> { fn infer_expr_type(expr: &LogicalExpr, schema: &SchemaRef) -> Result { match expr { LogicalExpr::Column(col) => { - if let Some((_, typ)) = schema.find_column(&col.name) { - Ok(*typ) + if let Some((_, col_info)) = schema.find_column(&col.name, col.table.as_deref()) { + Ok(col_info.ty) } else { Ok(Type::Text) } @@ -1974,6 +2398,67 @@ mod tests { }; schema.add_btree_table(Arc::new(orders_table)); + // Create products table + let products_table = BTreeTable { + name: "products".to_string(), + root_page: 4, + primary_key_columns: vec![("id".to_string(), turso_parser::ast::SortOrder::Asc)], + columns: vec![ + SchemaColumn { + name: Some("id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: true, + is_rowid_alias: true, + notnull: true, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("name".to_string()), + ty: Type::Text, + ty_str: "TEXT".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("price".to_string()), + ty: Type::Real, + ty_str: "REAL".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("product_id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + ], + has_rowid: true, + is_strict: false, + unique_sets: vec![], + }; + schema.add_btree_table(Arc::new(products_table)); + schema } @@ -3086,4 +3571,381 @@ mod tests { _ => panic!("Expected Projection as top-level operator, got: {plan:?}"), } } + + // ===== JOIN TESTS ===== + + #[test] + fn test_inner_join() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users u INNER JOIN orders o ON u.id = o.user_id"; + let plan = 
parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + match &*proj.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Inner); + assert!(!join.on.is_empty(), "Should have join conditions"); + + // Check left input is users + match &*join.left { + LogicalPlan::TableScan(scan) => { + assert_eq!(scan.table_name, "users"); + } + _ => panic!("Expected TableScan for left input"), + } + + // Check right input is orders + match &*join.right { + LogicalPlan::TableScan(scan) => { + assert_eq!(scan.table_name, "orders"); + } + _ => panic!("Expected TableScan for right input"), + } + } + _ => panic!("Expected Join under Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_left_join() { + let schema = create_test_schema(); + let sql = "SELECT u.name, o.amount FROM users u LEFT JOIN orders o ON u.id = o.user_id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 2); // name and amount + match &*proj.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Left); + assert!(!join.on.is_empty(), "Should have join conditions"); + } + _ => panic!("Expected Join under Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_right_join() { + let schema = create_test_schema(); + let sql = "SELECT * FROM orders o RIGHT JOIN users u ON o.user_id = u.id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Right); + assert!(!join.on.is_empty(), "Should have join conditions"); + } + _ => panic!("Expected Join under Projection"), + }, + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_full_outer_join() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users u FULL OUTER JOIN orders o ON u.id = o.user_id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Full); + assert!(!join.on.is_empty(), "Should have join conditions"); + } + _ => panic!("Expected Join under Projection"), + }, + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_cross_join() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users CROSS JOIN orders"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Cross); + assert!(join.on.is_empty(), "Cross join should have no conditions"); + assert!(join.filter.is_none(), "Cross join should have no filter"); + } + _ => panic!("Expected Join under Projection"), + }, + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_join_with_multiple_conditions() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users u JOIN orders o ON u.id = o.user_id AND u.age > 18"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + match &*proj.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Inner); + // Should have at least one equijoin condition + assert!(!join.on.is_empty(), "Should have join conditions"); + // 
Additional conditions may be in filter + // The exact distribution depends on our implementation + } + _ => panic!("Expected Join under Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_join_using_clause() { + let schema = create_test_schema(); + // Note: Both tables should have an 'id' column for this to work + let sql = "SELECT * FROM users JOIN orders USING (id)"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Inner); + assert!( + !join.on.is_empty(), + "USING clause should create join conditions" + ); + } + _ => panic!("Expected Join under Projection"), + }, + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_natural_join() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users NATURAL JOIN orders"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + match &*proj.input { + LogicalPlan::Join(join) => { + // Natural join finds common columns (id in this case) + // If no common columns, it becomes a cross join + assert!( + !join.on.is_empty() || join.join_type == JoinType::Cross, + "Natural join should either find common columns or become cross join" + ); + } + _ => panic!("Expected Join under Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_three_way_join() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users u + JOIN orders o ON u.id = o.user_id + JOIN products p ON o.product_id = p.id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + match &*proj.input { + LogicalPlan::Join(join2) => { + // Second join (with products) + assert_eq!(join2.join_type, JoinType::Inner); + match &*join2.left { + LogicalPlan::Join(join1) => { + // First join (users with orders) + assert_eq!(join1.join_type, JoinType::Inner); + } + _ => panic!("Expected nested Join for three-way join"), + } + } + _ => panic!("Expected Join under Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_mixed_join_types() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users u + LEFT JOIN orders o ON u.id = o.user_id + INNER JOIN products p ON o.product_id = p.id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + match &*proj.input { + LogicalPlan::Join(join2) => { + // Second join should be INNER + assert_eq!(join2.join_type, JoinType::Inner); + match &*join2.left { + LogicalPlan::Join(join1) => { + // First join should be LEFT + assert_eq!(join1.join_type, JoinType::Left); + } + _ => panic!("Expected nested Join"), + } + } + _ => panic!("Expected Join under Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_join_with_filter() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users u JOIN orders o ON u.id = o.user_id WHERE o.amount > 100"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + match &*proj.input { + LogicalPlan::Filter(filter) => { + // WHERE clause creates a Filter above the Join + match &*filter.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Inner); + } + _ => panic!("Expected Join under Filter"), + } + } + 
_ => panic!("Expected Filter under Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_join_with_projection() { + let schema = create_test_schema(); + let sql = "SELECT u.name, o.amount FROM users u JOIN orders o ON u.id = o.user_id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 2); // u.name and o.amount + match &*proj.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Inner); + } + _ => panic!("Expected Join under Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_join_with_aggregation() { + let schema = create_test_schema(); + let sql = "SELECT u.name, SUM(o.amount) + FROM users u JOIN orders o ON u.id = o.user_id + GROUP BY u.name"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.group_expr.len(), 1); // GROUP BY u.name + assert_eq!(agg.aggr_expr.len(), 1); // SUM(o.amount) + match &*agg.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Inner); + } + _ => panic!("Expected Join under Aggregate"), + } + } + _ => panic!("Expected Aggregate"), + } + } + + #[test] + fn test_join_with_order_by() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users u JOIN orders o ON u.id = o.user_id ORDER BY o.amount DESC"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Sort(sort) => { + assert_eq!(sort.exprs.len(), 1); + assert!(!sort.exprs[0].asc); // DESC + match &*sort.input { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Inner); + } + _ => panic!("Expected Join under Projection"), + }, + _ => panic!("Expected Projection under Sort"), + } + } + _ => panic!("Expected Sort at top level"), + } + } + + #[test] + fn test_join_in_subquery() { + let schema = create_test_schema(); + let sql = "SELECT * FROM ( + SELECT u.id, u.name, o.amount + FROM users u JOIN orders o ON u.id = o.user_id + ) WHERE amount > 100"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(outer_proj) => match &*outer_proj.input { + LogicalPlan::Filter(filter) => match &*filter.input { + LogicalPlan::Projection(inner_proj) => match &*inner_proj.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Inner); + } + _ => panic!("Expected Join in subquery"), + }, + _ => panic!("Expected Projection for subquery"), + }, + _ => panic!("Expected Filter"), + }, + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_join_ambiguous_column() { + let schema = create_test_schema(); + // Both users and orders have an 'id' column + let sql = "SELECT id FROM users JOIN orders ON users.id = orders.user_id"; + let result = parse_and_build(sql, &schema); + // This might error or succeed depending on how we handle ambiguous columns + // For now, just check that parsing completes + match result { + Ok(_) => { + // If successful, the implementation handles ambiguous columns somehow + } + Err(_) => { + // If error, the implementation rejects ambiguous columns + } + } + } } diff --git a/core/translate/optimizer/mod.rs b/core/translate/optimizer/mod.rs index 6fc2dbe6f..b9df7c698 100644 --- a/core/translate/optimizer/mod.rs +++ b/core/translate/optimizer/mod.rs @@ -8,15 +8,13 @@ use join::{compute_best_join_order, 
BestJoinOrderResult}; use lift_common_subexpressions::lift_common_subexpressions_from_binary_or_terms; use order::{compute_order_target, plan_satisfies_order_target, EliminatesSortBy}; use turso_ext::{ConstraintInfo, ConstraintUsage}; -use turso_macros::match_ignore_ascii_case; use turso_parser::ast::{self, Expr, SortOrder}; use crate::{ - parameters::PARAM_PREFIX, schema::{Index, IndexColumn, Schema, Table}, translate::{ - expr::walk_expr_mut, expr::WalkControl, optimizer::access_method::AccessMethodParams, - optimizer::constraints::TableConstraints, plan::Scan, plan::TerminationKey, + optimizer::access_method::AccessMethodParams, optimizer::constraints::TableConstraints, + plan::Scan, plan::TerminationKey, }, types::SeekOp, LimboError, Result, @@ -64,7 +62,7 @@ pub fn optimize_plan(plan: &mut Plan, schema: &Schema) -> Result<()> { */ pub fn optimize_select_plan(plan: &mut SelectPlan, schema: &Schema) -> Result<()> { optimize_subqueries(plan, schema)?; - rewrite_exprs_select(plan)?; + lift_common_subexpressions_from_binary_or_terms(&mut plan.where_clause)?; if let ConstantConditionEliminationResult::ImpossibleCondition = eliminate_constant_conditions(&mut plan.where_clause)? { @@ -89,7 +87,7 @@ pub fn optimize_select_plan(plan: &mut SelectPlan, schema: &Schema) -> Result<() } fn optimize_delete_plan(plan: &mut DeletePlan, schema: &Schema) -> Result<()> { - rewrite_exprs_delete(plan)?; + lift_common_subexpressions_from_binary_or_terms(&mut plan.where_clause)?; if let ConstantConditionEliminationResult::ImpossibleCondition = eliminate_constant_conditions(&mut plan.where_clause)? { @@ -110,7 +108,7 @@ fn optimize_delete_plan(plan: &mut DeletePlan, schema: &Schema) -> Result<()> { } fn optimize_update_plan(plan: &mut UpdatePlan, schema: &Schema) -> Result<()> { - rewrite_exprs_update(plan)?; + lift_common_subexpressions_from_binary_or_terms(&mut plan.where_clause)?; if let ConstantConditionEliminationResult::ImpossibleCondition = eliminate_constant_conditions(&mut plan.where_clause)? 
{ @@ -558,62 +556,6 @@ fn eliminate_constant_conditions( Ok(ConstantConditionEliminationResult::Continue) } -fn rewrite_exprs_select(plan: &mut SelectPlan) -> Result<()> { - let mut param_count = 1; - for rc in plan.result_columns.iter_mut() { - rewrite_expr(&mut rc.expr, &mut param_count)?; - } - for agg in plan.aggregates.iter_mut() { - rewrite_expr(&mut agg.original_expr, &mut param_count)?; - } - lift_common_subexpressions_from_binary_or_terms(&mut plan.where_clause)?; - for cond in plan.where_clause.iter_mut() { - rewrite_expr(&mut cond.expr, &mut param_count)?; - } - if let Some(group_by) = &mut plan.group_by { - for expr in group_by.exprs.iter_mut() { - rewrite_expr(expr, &mut param_count)?; - } - } - for (expr, _) in plan.order_by.iter_mut() { - rewrite_expr(expr, &mut param_count)?; - } - if let Some(window) = &mut plan.window { - for func in window.functions.iter_mut() { - rewrite_expr(&mut func.original_expr, &mut param_count)?; - } - } - - Ok(()) -} - -fn rewrite_exprs_delete(plan: &mut DeletePlan) -> Result<()> { - let mut param_idx = 1; - for cond in plan.where_clause.iter_mut() { - rewrite_expr(&mut cond.expr, &mut param_idx)?; - } - Ok(()) -} - -fn rewrite_exprs_update(plan: &mut UpdatePlan) -> Result<()> { - let mut param_idx = 1; - for (_, expr) in plan.set_clauses.iter_mut() { - rewrite_expr(expr, &mut param_idx)?; - } - for cond in plan.where_clause.iter_mut() { - rewrite_expr(&mut cond.expr, &mut param_idx)?; - } - for (expr, _) in plan.order_by.iter_mut() { - rewrite_expr(expr, &mut param_idx)?; - } - if let Some(rc) = plan.returning.as_mut() { - for rc in rc.iter_mut() { - rewrite_expr(&mut rc.expr, &mut param_idx)?; - } - } - Ok(()) -} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum AlwaysTrueOrFalse { AlwaysTrue, @@ -1449,77 +1391,7 @@ fn build_seek_def( }) } -pub fn rewrite_expr(top_level_expr: &mut ast::Expr, param_idx: &mut usize) -> Result { - walk_expr_mut( - top_level_expr, - &mut |expr: &mut ast::Expr| -> Result { - match expr { - ast::Expr::Id(id) => { - // Convert "true" and "false" to 1 and 0 - let id_bytes = id.as_str().as_bytes(); - match_ignore_ascii_case!(match id_bytes { - b"true" => { - *expr = ast::Expr::Literal(ast::Literal::Numeric("1".to_owned())); - } - b"false" => { - *expr = ast::Expr::Literal(ast::Literal::Numeric("0".to_owned())); - } - _ => {} - }) - } - ast::Expr::Variable(var) => { - if var.is_empty() { - // rewrite anonymous variables only, ensure that the `param_idx` starts at 1 and - // all the expressions are rewritten in the order they come in the statement - *expr = ast::Expr::Variable(format!("{PARAM_PREFIX}{param_idx}")); - *param_idx += 1; - } - } - ast::Expr::Between { - lhs, - not, - start, - end, - } => { - // Convert `y NOT BETWEEN x AND z` to `x > y OR y > z` - let (lower_op, upper_op) = if *not { - (ast::Operator::Greater, ast::Operator::Greater) - } else { - // Convert `y BETWEEN x AND z` to `x <= y AND y <= z` - (ast::Operator::LessEquals, ast::Operator::LessEquals) - }; - - let start = start.take_ownership(); - let lhs = lhs.take_ownership(); - let end = end.take_ownership(); - - let lower_bound = - ast::Expr::Binary(Box::new(start), lower_op, Box::new(lhs.clone())); - let upper_bound = ast::Expr::Binary(Box::new(lhs), upper_op, Box::new(end)); - - if *not { - *expr = ast::Expr::Binary( - Box::new(lower_bound), - ast::Operator::Or, - Box::new(upper_bound), - ); - } else { - *expr = ast::Expr::Binary( - Box::new(lower_bound), - ast::Operator::And, - Box::new(upper_bound), - ); - } - } - _ => {} - } - - 
Ok(WalkControl::Continue) - }, - ) -} - -trait TakeOwnership { +pub trait TakeOwnership { fn take_ownership(&mut self) -> Self; } diff --git a/core/translate/planner.rs b/core/translate/planner.rs index 14f422860..e9f62e9e8 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -11,23 +11,23 @@ use super::{ select::prepare_select_plan, SymbolTable, }; -use crate::function::{AggFunc, ExtFunc}; use crate::translate::expr::WalkControl; use crate::translate::plan::{Window, WindowFunction}; use crate::{ ast::Limit, function::Func, schema::{Schema, Table}, - translate::expr::walk_expr_mut, util::{exprs_are_equivalent, normalize_ident}, vdbe::builder::TableRefIdCounter, Result, }; -use turso_macros::match_ignore_ascii_case; +use crate::{ + function::{AggFunc, ExtFunc}, + translate::expr::{bind_and_rewrite_expr, ParamState}, +}; use turso_parser::ast::Literal::Null; use turso_parser::ast::{ - self, As, Expr, FromClause, JoinType, Literal, Materialized, Over, QualifiedName, - TableInternalId, With, + self, As, Expr, FromClause, JoinType, Materialized, Over, QualifiedName, TableInternalId, With, }; pub const ROWID: &str = "rowid"; @@ -262,231 +262,6 @@ fn add_aggregate_if_not_exists( Ok(()) } -pub fn bind_column_references( - top_level_expr: &mut Expr, - referenced_tables: &mut TableReferences, - result_columns: Option<&[ResultSetColumn]>, - connection: &Arc, -) -> Result { - walk_expr_mut( - top_level_expr, - &mut |expr: &mut Expr| -> Result { - match expr { - Expr::Id(id) => { - // true and false are special constants that are effectively aliases for 1 and 0 - // and not identifiers of columns - let id_bytes = id.as_str().as_bytes(); - match_ignore_ascii_case!(match id_bytes { - b"true" | b"false" => { - return Ok(WalkControl::Continue); - } - _ => {} - }); - let normalized_id = normalize_ident(id.as_str()); - - if !referenced_tables.joined_tables().is_empty() { - if let Some(row_id_expr) = parse_row_id( - &normalized_id, - referenced_tables.joined_tables()[0].internal_id, - || referenced_tables.joined_tables().len() != 1, - )? { - *expr = row_id_expr; - - return Ok(WalkControl::Continue); - } - } - let mut match_result = None; - - // First check joined tables - for joined_table in referenced_tables.joined_tables().iter() { - let col_idx = joined_table.table.columns().iter().position(|c| { - c.name - .as_ref() - .is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id)) - }); - if col_idx.is_some() { - if match_result.is_some() { - crate::bail_parse_error!("Column {} is ambiguous", id.as_str()); - } - let col = joined_table.table.columns().get(col_idx.unwrap()).unwrap(); - match_result = Some(( - joined_table.internal_id, - col_idx.unwrap(), - col.is_rowid_alias, - )); - } - } - - // Then check outer query references, if we still didn't find something. - // Normally finding multiple matches for a non-qualified column is an error (column x is ambiguous) - // but in the case of subqueries, the inner query takes precedence. - // For example: - // SELECT * FROM t WHERE x = (SELECT x FROM t2) - // In this case, there is no ambiguity: - // - x in the outer query refers to t.x, - // - x in the inner query refers to t2.x. 
- if match_result.is_none() { - for outer_ref in referenced_tables.outer_query_refs().iter() { - let col_idx = outer_ref.table.columns().iter().position(|c| { - c.name - .as_ref() - .is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id)) - }); - if col_idx.is_some() { - if match_result.is_some() { - crate::bail_parse_error!("Column {} is ambiguous", id.as_str()); - } - let col = outer_ref.table.columns().get(col_idx.unwrap()).unwrap(); - match_result = Some(( - outer_ref.internal_id, - col_idx.unwrap(), - col.is_rowid_alias, - )); - } - } - } - - if let Some((table_id, col_idx, is_rowid_alias)) = match_result { - *expr = Expr::Column { - database: None, // TODO: support different databases - table: table_id, - column: col_idx, - is_rowid_alias, - }; - referenced_tables.mark_column_used(table_id, col_idx); - return Ok(WalkControl::Continue); - } - - if let Some(result_columns) = result_columns { - for result_column in result_columns.iter() { - if result_column - .name(referenced_tables) - .is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id)) - { - *expr = result_column.expr.clone(); - return Ok(WalkControl::Continue); - } - } - } - // SQLite behavior: Only double-quoted identifiers get fallback to string literals - // Single quotes are handled as literals earlier, unquoted identifiers must resolve to columns - if id.is_double_quoted() { - // Convert failed double-quoted identifier to string literal - *expr = Expr::Literal(Literal::String(id.as_str().to_string())); - Ok(WalkControl::Continue) - } else { - // Unquoted identifiers must resolve to columns - no fallback - crate::bail_parse_error!("no such column: {}", id.as_str()) - } - } - Expr::Qualified(tbl, id) => { - let normalized_table_name = normalize_ident(tbl.as_str()); - let matching_tbl = referenced_tables - .find_table_and_internal_id_by_identifier(&normalized_table_name); - if matching_tbl.is_none() { - crate::bail_parse_error!("no such table: {}", normalized_table_name); - } - let (tbl_id, tbl) = matching_tbl.unwrap(); - let normalized_id = normalize_ident(id.as_str()); - - if let Some(row_id_expr) = parse_row_id(&normalized_id, tbl_id, || false)? 
{ - *expr = row_id_expr; - - return Ok(WalkControl::Continue); - } - let col_idx = tbl.columns().iter().position(|c| { - c.name - .as_ref() - .is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id)) - }); - let Some(col_idx) = col_idx else { - crate::bail_parse_error!("no such column: {}", normalized_id); - }; - let col = tbl.columns().get(col_idx).unwrap(); - *expr = Expr::Column { - database: None, // TODO: support different databases - table: tbl_id, - column: col_idx, - is_rowid_alias: col.is_rowid_alias, - }; - referenced_tables.mark_column_used(tbl_id, col_idx); - Ok(WalkControl::Continue) - } - Expr::DoublyQualified(db_name, tbl_name, col_name) => { - let normalized_col_name = normalize_ident(col_name.as_str()); - - // Create a QualifiedName and use existing resolve_database_id method - let qualified_name = ast::QualifiedName { - db_name: Some(db_name.clone()), - name: tbl_name.clone(), - alias: None, - }; - let database_id = connection.resolve_database_id(&qualified_name)?; - - // Get the table from the specified database - let table = connection - .with_schema(database_id, |schema| schema.get_table(tbl_name.as_str())) - .ok_or_else(|| { - crate::LimboError::ParseError(format!( - "no such table: {}.{}", - db_name.as_str(), - tbl_name.as_str() - )) - })?; - - // Find the column in the table - let col_idx = table - .columns() - .iter() - .position(|c| { - c.name - .as_ref() - .is_some_and(|name| name.eq_ignore_ascii_case(&normalized_col_name)) - }) - .ok_or_else(|| { - crate::LimboError::ParseError(format!( - "Column: {}.{}.{} not found", - db_name.as_str(), - tbl_name.as_str(), - col_name.as_str() - )) - })?; - - let col = table.columns().get(col_idx).unwrap(); - - // Check if this is a rowid alias - let is_rowid_alias = col.is_rowid_alias; - - // Convert to Column expression - since this is a cross-database reference, - // we need to create a synthetic table reference for it - // For now, we'll error if the table isn't already in the referenced tables - let normalized_tbl_name = normalize_ident(tbl_name.as_str()); - let matching_tbl = referenced_tables - .find_table_and_internal_id_by_identifier(&normalized_tbl_name); - - if let Some((tbl_id, _)) = matching_tbl { - // Table is already in referenced tables, use existing internal ID - *expr = Expr::Column { - database: Some(database_id), - table: tbl_id, - column: col_idx, - is_rowid_alias, - }; - referenced_tables.mark_column_used(tbl_id, col_idx); - } else { - return Err(crate::LimboError::ParseError(format!( - "table {normalized_tbl_name} is not in FROM clause - cross-database column references require the table to be explicitly joined" - ))); - } - - Ok(WalkControl::Continue) - } - _ => Ok(WalkControl::Continue), - } - }, - ) -} - #[allow(clippy::too_many_arguments)] fn parse_from_clause_table( schema: &Schema, @@ -663,7 +438,7 @@ fn parse_table( let btree_table = Arc::new(crate::schema::BTreeTable { name: view_guard.name().to_string(), root_page, - columns: view_guard.columns.clone(), + columns: view_guard.column_schema.flat_columns(), primary_key_columns: Vec::new(), has_rowid: true, is_strict: false, @@ -776,6 +551,7 @@ pub fn parse_from( table_references: &mut TableReferences, table_ref_counter: &mut TableRefIdCounter, connection: &Arc, + param_ctx: &mut ParamState, ) -> Result<()> { if from.is_none() { return Ok(()); @@ -874,6 +650,7 @@ pub fn parse_from( table_references, table_ref_counter, connection, + param_ctx, )?; } @@ -886,12 +663,19 @@ pub fn parse_where( result_columns: Option<&[ResultSetColumn]>, 
out_where_clause: &mut Vec, connection: &Arc, + param_ctx: &mut ParamState, ) -> Result<()> { if let Some(where_expr) = where_clause { let start_idx = out_where_clause.len(); break_predicate_at_and_boundaries(where_expr, out_where_clause); for expr in out_where_clause[start_idx..].iter_mut() { - bind_column_references(&mut expr.expr, table_references, result_columns, connection)?; + bind_and_rewrite_expr( + &mut expr.expr, + Some(table_references), + result_columns, + connection, + param_ctx, + )?; } Ok(()) } else { @@ -1084,6 +868,7 @@ fn parse_join( table_references: &mut TableReferences, table_ref_counter: &mut TableRefIdCounter, connection: &Arc, + param_ctx: &mut ParamState, ) -> Result<()> { let ast::JoinedSelectTable { operator: join_operator, @@ -1171,11 +956,12 @@ fn parse_join( } else { None }; - bind_column_references( + bind_and_rewrite_expr( &mut predicate.expr, - table_references, + Some(table_references), None, connection, + param_ctx, )?; } } @@ -1290,7 +1076,7 @@ pub fn break_predicate_at_and_boundaries>( } } -fn parse_row_id( +pub fn parse_row_id( column_name: &str, table_id: TableInternalId, fn_check: F, @@ -1315,11 +1101,11 @@ where pub fn parse_limit( limit: &mut Limit, connection: &std::sync::Arc, + param_ctx: &mut ParamState, ) -> Result<(Option>, Option>)> { - let mut empty_refs = TableReferences::new(Vec::new(), Vec::new()); - bind_column_references(&mut limit.expr, &mut empty_refs, None, connection)?; + bind_and_rewrite_expr(&mut limit.expr, None, None, connection, param_ctx)?; if let Some(ref mut off_expr) = limit.offset { - bind_column_references(off_expr, &mut empty_refs, None, connection)?; + bind_and_rewrite_expr(off_expr, None, None, connection, param_ctx)?; } Ok((Some(limit.expr.clone()), limit.offset.clone())) } diff --git a/core/translate/pragma.rs b/core/translate/pragma.rs index 6336faf25..fa4274ed3 100644 --- a/core/translate/pragma.rs +++ b/core/translate/pragma.rs @@ -99,7 +99,7 @@ fn update_pragma( let app_id_value = match data { Value::Integer(i) => i as i32, Value::Float(f) => f as i32, - _ => unreachable!(), + _ => bail_parse_error!("expected integer, got {:?}", data), }; program.emit_insn(Insn::SetCookie { @@ -110,6 +110,19 @@ fn update_pragma( }); Ok((program, TransactionMode::Write)) } + PragmaName::BusyTimeout => { + let data = parse_signed_number(&value)?; + let busy_timeout_ms = match data { + Value::Integer(i) => i as i32, + Value::Float(f) => f as i32, + _ => bail_parse_error!("expected integer, got {:?}", data), + }; + let busy_timeout_ms = busy_timeout_ms.max(0); + connection.set_busy_timeout(Some(std::time::Duration::from_millis( + busy_timeout_ms as u64, + ))); + Ok((program, TransactionMode::Write)) + } PragmaName::CacheSize => { let cache_size = match parse_signed_number(&value)? 
{ Value::Integer(size) => size, @@ -388,6 +401,18 @@ fn query_pragma( program.emit_result_row(register, 1); Ok((program, TransactionMode::Read)) } + PragmaName::BusyTimeout => { + program.emit_int( + connection + .get_busy_timeout() + .map(|t| t.as_millis() as i64) + .unwrap_or_default(), + register, + ); + program.emit_result_row(register, 1); + program.add_pragma_result_column(pragma.to_string()); + Ok((program, TransactionMode::None)) + } PragmaName::CacheSize => { program.emit_int(connection.get_cache_size() as i64, register); program.emit_result_row(register, 1); @@ -508,7 +533,8 @@ fn query_pragma( emit_columns_for_table_info(&mut program, table.columns(), base_reg); } else if let Some(view_mutex) = schema.get_materialized_view(&name) { let view = view_mutex.lock().unwrap(); - emit_columns_for_table_info(&mut program, &view.columns, base_reg); + let flat_columns = view.column_schema.flat_columns(); + emit_columns_for_table_info(&mut program, &flat_columns, base_reg); } else if let Some(view) = schema.get_view(&name) { emit_columns_for_table_info(&mut program, &view.columns, base_reg); } diff --git a/core/translate/select.rs b/core/translate/select.rs index e13eed952..a1ec15abe 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -4,11 +4,12 @@ use super::plan::{ Search, TableReferences, WhereTerm, Window, }; use crate::schema::Table; +use crate::translate::expr::{bind_and_rewrite_expr, ParamState}; use crate::translate::optimizer::optimize_plan; use crate::translate::plan::{GroupBy, Plan, ResultSetColumn, SelectPlan}; use crate::translate::planner::{ - bind_column_references, break_predicate_at_and_boundaries, parse_from, parse_limit, - parse_where, resolve_window_and_aggregate_functions, + break_predicate_at_and_boundaries, parse_from, parse_limit, parse_where, + resolve_window_and_aggregate_functions, }; use crate::translate::window::plan_windows; use crate::util::normalize_ident; @@ -98,6 +99,7 @@ pub fn prepare_select_plan( connection: &Arc, ) -> Result { let compounds = select.body.compounds; + let mut param_ctx = ParamState::default(); match compounds.is_empty() { true => Ok(Plan::Select(prepare_one_select_plan( schema, @@ -110,6 +112,7 @@ pub fn prepare_select_plan( table_ref_counter, query_destination, connection, + &mut param_ctx, )?)), false => { let mut last = prepare_one_select_plan( @@ -123,6 +126,7 @@ pub fn prepare_select_plan( table_ref_counter, query_destination.clone(), connection, + &mut param_ctx, )?; let mut left = Vec::with_capacity(compounds.len()); @@ -139,6 +143,7 @@ pub fn prepare_select_plan( table_ref_counter, query_destination.clone(), connection, + &mut param_ctx, )?; } @@ -149,9 +154,9 @@ pub fn prepare_select_plan( crate::bail_parse_error!("SELECTs to the left and right of {} do not have the same number of result columns", operator); } } - let (limit, offset) = select - .limit - .map_or(Ok((None, None)), |mut l| parse_limit(&mut l, connection))?; + let (limit, offset) = select.limit.map_or(Ok((None, None)), |mut l| { + parse_limit(&mut l, connection, &mut param_ctx) + })?; // FIXME: handle ORDER BY for compound selects if !select.order_by.is_empty() { @@ -184,6 +189,7 @@ fn prepare_one_select_plan( table_ref_counter: &mut TableRefIdCounter, query_destination: QueryDestination, connection: &Arc, + param_ctx: &mut ParamState, ) -> Result { match select { ast::OneSelect::Select { @@ -230,6 +236,7 @@ fn prepare_one_select_plan( &mut table_references, table_ref_counter, connection, + param_ctx, )?; // Preallocate space for the result 
columns @@ -255,7 +262,6 @@ fn prepare_one_select_plan( }) .sum(), ); - let mut plan = SelectPlan { join_order: table_references .joined_tables() @@ -288,19 +294,21 @@ fn prepare_one_select_plan( let mut window = Window::new(Some(name), &window_def.window)?; for expr in window.partition_by.iter_mut() { - bind_column_references( + bind_and_rewrite_expr( expr, - &mut plan.table_references, - Some(&plan.result_columns), + Some(&mut plan.table_references), + None, connection, + param_ctx, )?; } for (expr, _) in window.order_by.iter_mut() { - bind_column_references( + bind_and_rewrite_expr( expr, - &mut plan.table_references, - Some(&plan.result_columns), + Some(&mut plan.table_references), + None, connection, + param_ctx, )?; } @@ -357,11 +365,12 @@ fn prepare_one_select_plan( } } ResultColumn::Expr(ref mut expr, maybe_alias) => { - bind_column_references( + bind_and_rewrite_expr( expr, - &mut plan.table_references, - Some(&plan.result_columns), + Some(&mut plan.table_references), + None, connection, + param_ctx, )?; let contains_aggregates = resolve_window_and_aggregate_functions( schema, @@ -385,7 +394,12 @@ fn prepare_one_select_plan( // This step can only be performed at this point, because all table references are now available. // Virtual table predicates may depend on column bindings from tables to the right in the join order, // so we must wait until the full set of references has been collected. - add_vtab_predicates_to_where_clause(&mut vtab_predicates, &mut plan, connection)?; + add_vtab_predicates_to_where_clause( + &mut vtab_predicates, + &mut plan, + connection, + param_ctx, + )?; // Parse the actual WHERE clause and add its conditions to the plan WHERE clause that already contains the join conditions. parse_where( @@ -394,16 +408,18 @@ fn prepare_one_select_plan( Some(&plan.result_columns), &mut plan.where_clause, connection, + param_ctx, )?; if let Some(mut group_by) = group_by { for expr in group_by.exprs.iter_mut() { replace_column_number_with_copy_of_column_expr(expr, &plan.result_columns)?; - bind_column_references( + bind_and_rewrite_expr( expr, - &mut plan.table_references, + Some(&mut plan.table_references), Some(&plan.result_columns), connection, + param_ctx, )?; } @@ -414,11 +430,12 @@ fn prepare_one_select_plan( let mut predicates = vec![]; break_predicate_at_and_boundaries(&having, &mut predicates); for expr in predicates.iter_mut() { - bind_column_references( + bind_and_rewrite_expr( expr, - &mut plan.table_references, + Some(&mut plan.table_references), Some(&plan.result_columns), connection, + param_ctx, )?; let contains_aggregates = resolve_window_and_aggregate_functions( schema, @@ -452,11 +469,12 @@ fn prepare_one_select_plan( for mut o in order_by { replace_column_number_with_copy_of_column_expr(&mut o.expr, &plan.result_columns)?; - bind_column_references( + bind_and_rewrite_expr( &mut o.expr, - &mut plan.table_references, + Some(&mut plan.table_references), Some(&plan.result_columns), connection, + param_ctx, )?; resolve_window_and_aggregate_functions( schema, @@ -471,8 +489,9 @@ fn prepare_one_select_plan( plan.order_by = key; // Parse the LIMIT/OFFSET clause - (plan.limit, plan.offset) = - limit.map_or(Ok((None, None)), |mut l| parse_limit(&mut l, connection))?; + (plan.limit, plan.offset) = limit.map_or(Ok((None, None)), |mut l| { + parse_limit(&mut l, connection, param_ctx) + })?; if !windows.is_empty() { plan_windows(schema, syms, &mut plan, table_ref_counter, &mut windows)?; @@ -521,13 +540,15 @@ fn add_vtab_predicates_to_where_clause( vtab_predicates: 
&mut Vec, plan: &mut SelectPlan, connection: &Arc, + param_ctx: &mut ParamState, ) -> Result<()> { for expr in vtab_predicates.iter_mut() { - bind_column_references( + bind_and_rewrite_expr( expr, - &mut plan.table_references, + Some(&mut plan.table_references), Some(&plan.result_columns), connection, + param_ctx, )?; } for expr in vtab_predicates.drain(..) { diff --git a/core/translate/update.rs b/core/translate/update.rs index 6ca366049..feb3d926d 100644 --- a/core/translate/update.rs +++ b/core/translate/update.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use std::sync::Arc; use crate::schema::{BTreeTable, Column, Type}; +use crate::translate::expr::{bind_and_rewrite_expr, ParamState}; use crate::translate::optimizer::optimize_select_plan; use crate::translate::plan::{Operation, QueryDestination, Scan, Search, SelectPlan}; use crate::translate::planner::parse_limit; @@ -22,7 +23,7 @@ use super::plan::{ ColumnUsedMask, IterationDirection, JoinedTable, Plan, ResultSetColumn, TableReferences, UpdatePlan, }; -use super::planner::{bind_column_references, parse_where}; +use super::planner::parse_where; /* * Update is simple. By default we scan the table, and for each row, we check the WHERE * clause. If it evaluates to true, we build the new record with the updated value and insert. @@ -90,7 +91,6 @@ pub fn translate_update_for_schema_change( } optimize_plan(&mut plan, schema)?; - // TODO: freestyling these numbers let opts = ProgramBuilderOpts { num_cursors: 1, approx_num_insns: 20, @@ -181,11 +181,18 @@ pub fn prepare_update_plan( .collect(); let mut set_clauses = Vec::with_capacity(body.sets.len()); + let mut param_idx = ParamState::default(); // Process each SET assignment and map column names to expressions // e.g the statement `SET x = 1, y = 2, z = 3` has 3 set assigments for set in &mut body.sets { - bind_column_references(&mut set.expr, &mut table_references, None, connection)?; + bind_and_rewrite_expr( + &mut set.expr, + Some(&mut table_references), + None, + connection, + &mut param_idx, + )?; let values = match set.expr.as_ref() { Expr::Parenthesized(vals) => vals.clone(), @@ -222,12 +229,22 @@ pub fn prepare_update_plan( body.tbl_name.name.as_str(), program, connection, + &mut param_idx, )?; let order_by = body .order_by - .iter() - .map(|o| (o.expr.clone(), o.order.unwrap_or(SortOrder::Asc))) + .iter_mut() + .map(|o| { + let _ = bind_and_rewrite_expr( + &mut o.expr, + Some(&mut table_references), + Some(&result_columns), + connection, + &mut param_idx, + ); + (o.expr.clone(), o.order.unwrap_or(SortOrder::Asc)) + }) .collect(); // Sqlite determines we should create an ephemeral table if we do not have a FROM clause @@ -266,6 +283,7 @@ pub fn prepare_update_plan( Some(&result_columns), &mut where_clause, connection, + &mut param_idx, )?; let table = Arc::new(BTreeTable { @@ -342,14 +360,14 @@ pub fn prepare_update_plan( Some(&result_columns), &mut where_clause, connection, + &mut param_idx, )?; }; // Parse the LIMIT/OFFSET clause - let (limit, offset) = body - .limit - .as_mut() - .map_or(Ok((None, None)), |l| parse_limit(l, connection))?; + let (limit, offset) = body.limit.as_mut().map_or(Ok((None, None)), |l| { + parse_limit(l, connection, &mut param_idx) + })?; // Check what indexes will need to be updated by checking set_clauses and see // if a column is contained in an index. 
diff --git a/core/translate/view.rs b/core/translate/view.rs index f89f29817..9ff8e6c89 100644 --- a/core/translate/view.rs +++ b/core/translate/view.rs @@ -42,7 +42,8 @@ pub fn translate_create_materialized_view( // storing invalid view definitions use crate::incremental::view::IncrementalView; use crate::schema::BTreeTable; - let view_columns = IncrementalView::validate_and_extract_columns(select_stmt, schema)?; + let view_column_schema = IncrementalView::validate_and_extract_columns(select_stmt, schema)?; + let view_columns = view_column_schema.flat_columns(); // Reconstruct the SQL string for storage let sql = create_materialized_view_to_str(view_name, select_stmt); diff --git a/core/util.rs b/core/util.rs index 2d945ec11..faffc72cf 100644 --- a/core/util.rs +++ b/core/util.rs @@ -1066,9 +1066,59 @@ pub fn extract_column_name_from_expr(expr: impl AsRef) -> Option, +} + +/// Information about a column in the view's output +#[derive(Debug, Clone)] +pub struct ViewColumn { + /// Index into ViewColumnSchema.tables indicating which table this column comes from + /// For computed columns or constants, this will be usize::MAX + pub table_index: usize, + /// The actual column definition + pub column: Column, +} + +/// Schema information for a view, tracking which columns come from which tables +#[derive(Debug, Clone)] +pub struct ViewColumnSchema { + /// All tables referenced by the view (in order of appearance) + pub tables: Vec, + /// The view's output columns with their table associations + pub columns: Vec, +} + +impl ViewColumnSchema { + /// Get all columns as a flat vector (without table association info) + pub fn flat_columns(&self) -> Vec { + self.columns.iter().map(|vc| vc.column.clone()).collect() + } + + /// Get columns that belong to a specific table + pub fn table_columns(&self, table_index: usize) -> Vec { + self.columns + .iter() + .filter(|vc| vc.table_index == table_index) + .map(|vc| vc.column.clone()) + .collect() + } +} + /// Extract column information from a SELECT statement for view creation -pub fn extract_view_columns(select_stmt: &ast::Select, schema: &Schema) -> Vec { +pub fn extract_view_columns( + select_stmt: &ast::Select, + schema: &Schema, +) -> Result { + let mut tables = Vec::new(); let mut columns = Vec::new(); + let mut column_name_counts: HashMap = HashMap::new(); + // Navigate to the first SELECT in the statement if let ast::OneSelect::Select { ref from, @@ -1076,23 +1126,85 @@ pub fn extract_view_columns(select_stmt: &ast::Select, schema: &Schema) -> Vec { + let table_name = if qualified_name.db_name.is_some() { + // Include database qualifier if present + qualified_name.to_string() + } else { + normalize_ident(qualified_name.name.as_str()) + }; + tables.push(ViewTable { + name: table_name.clone(), + alias: alias.as_ref().map(|a| match a { + ast::As::As(name) => normalize_ident(name.as_str()), + ast::As::Elided(name) => normalize_ident(name.as_str()), + }), + }); + } + _ => { + // Handle other types like subqueries if needed + } } - } else { - None + + // Add tables from JOINs + for join in &from.joins { + match join.table.as_ref() { + ast::SelectTable::Table(qualified_name, alias, _) => { + let table_name = if qualified_name.db_name.is_some() { + // Include database qualifier if present + qualified_name.to_string() + } else { + normalize_ident(qualified_name.name.as_str()) + }; + tables.push(ViewTable { + name: table_name.clone(), + alias: alias.as_ref().map(|a| match a { + ast::As::As(name) => normalize_ident(name.as_str()), + ast::As::Elided(name) => 
normalize_ident(name.as_str()), + }), + }); + } + _ => { + // Handle other types like subqueries if needed + } + } + } + } + + // Helper function to find table index by name or alias + let find_table_index = |name: &str| -> Option { + tables + .iter() + .position(|t| t.name == name || t.alias.as_ref().is_some_and(|a| a == name)) }; - // Get the table for column resolution - let _table = table_name.as_ref().and_then(|name| schema.get_table(name)); + // Process each column in the SELECT list - for (i, result_col) in select_columns.iter().enumerate() { + for result_col in select_columns.iter() { match result_col { ast::ResultColumn::Expr(expr, alias) => { - let name = alias + // Figure out which table this expression comes from + let table_index = match expr.as_ref() { + ast::Expr::Qualified(table_ref, _col_name) => { + // Column qualified with table name + find_table_index(table_ref.as_str()) + } + ast::Expr::Id(_col_name) => { + // Unqualified column - would need to resolve based on schema + // For now, assume it's from the first table if there is one + if !tables.is_empty() { + Some(0) + } else { + None + } + } + _ => None, // Expression, literal, etc. + }; + + let col_name = alias .as_ref() .map(|a| match a { ast::As::Elided(name) => name.as_str().to_string(), @@ -1103,41 +1215,65 @@ pub fn extract_view_columns(select_stmt: &ast::Select, schema: &Schema) -> Vec { - // For SELECT *, expand to all columns from the table - if let Some(ref table_name) = table_name { - if let Some(table) = schema.get_table(table_name) { - // Copy all columns from the table, but adjust for view constraints - for table_column in table.columns() { - columns.push(Column { - name: table_column.name.clone(), - ty: table_column.ty, - ty_str: table_column.ty_str.clone(), - primary_key: false, // Views don't have primary keys - is_rowid_alias: false, - notnull: false, // Views typically don't enforce NOT NULL - default: None, // Views don't have default values - unique: false, - collation: table_column.collation, - hidden: false, + // For SELECT *, expand to all columns from all tables + for (table_idx, table) in tables.iter().enumerate() { + if let Some(table_obj) = schema.get_table(&table.name) { + for table_column in table_obj.columns() { + let col_name = + table_column.name.clone().unwrap_or_else(|| "?".to_string()); + + // Handle duplicate column names by adding suffix + let final_name = + if let Some(count) = column_name_counts.get_mut(&col_name) { + *count += 1; + format!("{}:{}", col_name, *count - 1) + } else { + column_name_counts.insert(col_name.clone(), 1); + col_name.clone() + }; + + columns.push(ViewColumn { + table_index: table_idx, + column: Column { + name: Some(final_name), + ty: table_column.ty, + ty_str: table_column.ty_str.clone(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: table_column.collation, + hidden: false, + }, }); } - } else { - // Table not found, create placeholder - columns.push(Column { + } + } + + // If no tables, create a placeholder + if tables.is_empty() { + columns.push(ViewColumn { + table_index: usize::MAX, + column: Column { name: Some("*".to_string()), ty: Type::Text, ty_str: "TEXT".to_string(), @@ -1148,63 +1284,70 @@ pub fn extract_view_columns(select_stmt: &ast::Select, schema: &Schema) -> Vec { + ast::ResultColumn::TableStar(table_ref) => { // For table.*, expand to all columns from the specified table - let table_name_str = normalize_ident(table_name.as_str()); - if let Some(table) = 
schema.get_table(&table_name_str) { - // Copy all columns from the table, but adjust for view constraints - for table_column in table.columns() { - columns.push(Column { - name: table_column.name.clone(), - ty: table_column.ty, - ty_str: table_column.ty_str.clone(), - primary_key: false, - is_rowid_alias: false, - notnull: false, - default: None, - unique: false, - collation: table_column.collation, - hidden: false, + let table_name_str = normalize_ident(table_ref.as_str()); + if let Some(table_idx) = find_table_index(&table_name_str) { + if let Some(table) = schema.get_table(&tables[table_idx].name) { + for table_column in table.columns() { + let col_name = + table_column.name.clone().unwrap_or_else(|| "?".to_string()); + + // Handle duplicate column names by adding suffix + let final_name = + if let Some(count) = column_name_counts.get_mut(&col_name) { + *count += 1; + format!("{}:{}", col_name, *count - 1) + } else { + column_name_counts.insert(col_name.clone(), 1); + col_name.clone() + }; + + columns.push(ViewColumn { + table_index: table_idx, + column: Column { + name: Some(final_name), + ty: table_column.ty, + ty_str: table_column.ty_str.clone(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: table_column.collation, + hidden: false, + }, + }); + } + } else { + // Table not found, create placeholder + columns.push(ViewColumn { + table_index: usize::MAX, + column: Column { + name: Some(format!("{table_name_str}.*")), + ty: Type::Text, + ty_str: "TEXT".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, }); } - } else { - // Table not found, create placeholder - columns.push(Column { - name: Some(format!("{table_name_str}.*")), - ty: Type::Text, - ty_str: "TEXT".to_string(), - primary_key: false, - is_rowid_alias: false, - notnull: false, - default: None, - unique: false, - collation: None, - hidden: false, - }); } } } } } - columns + + Ok(ViewColumnSchema { tables, columns }) } #[cfg(test)] diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index 16f32ba79..5099f7831 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -2276,16 +2276,8 @@ pub fn op_transaction_inner( if matches!(new_transaction_state, TransactionState::Write { .. 
}) && matches!(actual_tx_mode, TransactionMode::Write) { - let (tx_id, mv_tx_mode) = program.connection.mv_tx.get().unwrap(); - if mv_tx_mode == TransactionMode::Read { - return_if_io!( - mv_store.upgrade_to_exclusive_tx(pager.clone(), Some(tx_id)) - ); - } else { - return_if_io!( - mv_store.begin_exclusive_tx(pager.clone(), Some(tx_id)) - ); - } + let (tx_id, _) = program.connection.mv_tx.get().unwrap(); + return_if_io!(mv_store.begin_exclusive_tx(pager.clone(), Some(tx_id))); } } } else { diff --git a/parser/src/ast.rs b/parser/src/ast.rs index 3c331107a..ed58d0bba 100644 --- a/parser/src/ast.rs +++ b/parser/src/ast.rs @@ -1312,6 +1312,8 @@ pub enum PragmaName { ApplicationId, /// set the autovacuum mode AutoVacuum, + /// set the busy_timeout (see https://www.sqlite.org/pragma.html#pragma_busy_timeout) + BusyTimeout, /// `cache_size` pragma CacheSize, /// encryption cipher algorithm name for encrypted databases diff --git a/parser/src/parser.rs b/parser/src/parser.rs index fa2230373..2fde878e4 100644 --- a/parser/src/parser.rs +++ b/parser/src/parser.rs @@ -749,7 +749,7 @@ impl<'a> Parser<'a> { fn parse_create_materialized_view(&mut self) -> Result { eat_assert!(self, TK_MATERIALIZED); - eat_assert!(self, TK_VIEW); + eat_expect!(self, TK_VIEW); let if_not_exists = self.parse_if_not_exists()?; let view_name = self.parse_fullname(false)?; let columns = self.parse_eid_list(false)?; diff --git a/perf/encryption/Cargo.toml b/perf/encryption/Cargo.toml new file mode 100644 index 000000000..e769c5b0b --- /dev/null +++ b/perf/encryption/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "encryption-throughput" +version = "0.1.0" +edition = "2021" + +[[bin]] +name = "encryption-throughput" +path = "src/main.rs" + +[dependencies] +turso = { workspace = true, features = ["encryption"] } +clap = { workspace = true, features = ["derive"] } +tokio = { workspace = true, default-features = true, features = ["full"] } +futures = { workspace = true } +tracing-subscriber = { workspace = true } +rand = { workspace = true, features = ["small_rng"] } +hex = { workspace = true } \ No newline at end of file diff --git a/perf/encryption/README.md b/perf/encryption/README.md new file mode 100644 index 000000000..0ec611258 --- /dev/null +++ b/perf/encryption/README.md @@ -0,0 +1,28 @@ +# Encryption Throughput Benchmarking + +```shell +$ cargo run --release -- --help + +Usage: encryption-throughput [OPTIONS] + +Options: + -t, --threads [default: 1] + -b, --batch-size [default: 100] + -i, --iterations [default: 10] + -r, --read-ratio Percentage of operations that should be reads (0-100) + -w, --write-ratio Percentage of operations that should be writes (0-100) + --encryption Enable database encryption + --cipher Encryption cipher to use (only relevant if --encryption is set) [default: aegis-256] + --think Per transaction think time (ms) [default: 0] + --timeout Busy timeout in milliseconds [default: 30000] + --seed Random seed for reproducible workloads [default: 2167532792061351037] + -h, --help Print help +``` + +```shell +# try these: + +cargo run --release -- -b 100 -i 25000 --read-ratio 75 + +cargo run --release -- -b 100 -i 25000 --read-ratio 75 --encryption +``` \ No newline at end of file diff --git a/perf/encryption/src/main.rs b/perf/encryption/src/main.rs new file mode 100644 index 000000000..7736055c5 --- /dev/null +++ b/perf/encryption/src/main.rs @@ -0,0 +1,457 @@ +use clap::Parser; +use rand::rngs::SmallRng; +use rand::{Rng, RngCore, SeedableRng}; +use std::sync::atomic::{AtomicU64, Ordering}; +use 
std::sync::{Arc, Barrier};
+use std::time::{Duration, Instant};
+use turso::{Builder, Database, Result};
+
+#[derive(Debug, Clone)]
+struct EncryptionOpts {
+    cipher: String,
+    hexkey: String,
+}
+
+#[derive(Parser)]
+#[command(name = "encryption-throughput")]
+#[command(about = "Encryption throughput benchmark on Turso DB")]
+struct Args {
+    /// More than one thread does not work yet
+    #[arg(short = 't', long = "threads", default_value = "1")]
+    threads: usize,
+
+    /// the number of operations per transaction
+    #[arg(short = 'b', long = "batch-size", default_value = "100")]
+    batch_size: usize,
+
+    /// number of transactions per thread
+    #[arg(short = 'i', long = "iterations", default_value = "10")]
+    iterations: usize,
+
+    #[arg(
+        short = 'r',
+        long = "read-ratio",
+        help = "Percentage of operations that should be reads (0-100)"
+    )]
+    read_ratio: Option<u8>,
+
+    #[arg(
+        short = 'w',
+        long = "write-ratio",
+        help = "Percentage of operations that should be writes (0-100)"
+    )]
+    write_ratio: Option<u8>,
+
+    #[arg(
+        long = "encryption",
+        action = clap::ArgAction::SetTrue,
+        help = "Enable database encryption"
+    )]
+    encryption: bool,
+
+    #[arg(
+        long = "cipher",
+        default_value = "aegis-256",
+        help = "Encryption cipher to use (only relevant if --encryption is set)"
+    )]
+    cipher: String,
+
+    #[arg(
+        long = "think",
+        default_value = "0",
+        help = "Per transaction think time (ms)"
+    )]
+    think: u64,
+
+    #[arg(
+        long = "timeout",
+        default_value = "30000",
+        help = "Busy timeout in milliseconds"
+    )]
+    timeout: u64,
+
+    #[arg(
+        long = "seed",
+        default_value = "2167532792061351037",
+        help = "Random seed for reproducible workloads"
+    )]
+    seed: u64,
+}
+
+#[derive(Debug)]
+struct WorkerStats {
+    transactions_completed: u64,
+    reads_completed: u64,
+    writes_completed: u64,
+    reads_found: u64,
+    reads_not_found: u64,
+    total_transaction_time: Duration,
+}
+
+#[derive(Debug, Clone)]
+struct SharedState {
+    max_inserted_id: Arc<AtomicU64>,
+}
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    let _ = tracing_subscriber::fmt::try_init();
+    let args = Args::parse();
+
+    let read_ratio = match (args.read_ratio, args.write_ratio) {
+        (Some(_), Some(_)) => {
+            eprintln!("Error: Cannot specify both --read-ratio and --write-ratio");
+            std::process::exit(1);
+        }
+        (Some(r), None) => {
+            if r > 100 {
+                eprintln!("Error: read-ratio must be between 0 and 100");
+                std::process::exit(1);
+            }
+            r
+        }
+        (None, Some(w)) => {
+            if w > 100 {
+                eprintln!("Error: write-ratio must be between 0 and 100");
+                std::process::exit(1);
+            }
+            100 - w
+        }
+        // let's default to 0% reads (100% writes)
+        (None, None) => 0,
+    };
+
+    println!(
+        "Running encryption throughput benchmark with {} threads, {} batch size, {} iterations",
+        args.threads, args.batch_size, args.iterations
+    );
+    println!(
+        "Read/Write ratio: {}% reads, {}% writes",
+        read_ratio,
+        100 - read_ratio
+    );
+    println!("Encryption enabled: {}", args.encryption);
+    println!("Random seed: {}", args.seed);
+
+    let encryption_opts = if args.encryption {
+        let mut key_rng = SmallRng::seed_from_u64(args.seed);
+        let key_size = get_key_size_for_cipher(&args.cipher);
+        let mut key = vec![0u8; key_size];
+        key_rng.fill_bytes(&mut key);
+
+        let config = EncryptionOpts {
+            cipher: args.cipher.clone(),
+            hexkey: hex::encode(&key),
+        };
+
+        println!("Cipher: {}", config.cipher);
+        println!("Hexkey: {}", config.hexkey);
+        Some(config)
+    } else {
+        None
+    };
+
+    let db_path = "encryption_throughput_test.db";
+    if std::path::Path::new(db_path).exists() {
+        std::fs::remove_file(db_path).expect("Failed to remove existing database");
+    }
+    let wal_path = "encryption_throughput_test.db-wal";
+    if std::path::Path::new(wal_path).exists() {
+        std::fs::remove_file(wal_path).expect("Failed to remove existing WAL file");
+    }
+
+    let db = setup_database(db_path, &encryption_opts).await?;
+
+    // create a value that is shared between all the threads; we use it to track the
+    // max inserted id so that reads only target rows that exist
+    let shared_state = SharedState {
+        max_inserted_id: Arc::new(AtomicU64::new(0)),
+    };
+
+    let start_barrier = Arc::new(Barrier::new(args.threads));
+    let mut handles = Vec::new();
+
+    let timeout = Duration::from_millis(args.timeout);
+    let overall_start = Instant::now();
+
+    for thread_id in 0..args.threads {
+        let db_clone = db.clone();
+        let barrier = Arc::clone(&start_barrier);
+        let encryption_opts_clone = encryption_opts.clone();
+        let shared_state_clone = shared_state.clone();
+
+        let handle = tokio::task::spawn(worker_thread(
+            thread_id,
+            db_clone,
+            args.batch_size,
+            args.iterations,
+            barrier,
+            read_ratio,
+            encryption_opts_clone,
+            args.think,
+            timeout,
+            shared_state_clone,
+            args.seed,
+        ));
+
+        handles.push(handle);
+    }
+
+    let mut total_transactions = 0;
+    let mut total_reads = 0;
+    let mut total_writes = 0;
+    let mut total_reads_found = 0;
+    let mut total_reads_not_found = 0;
+
+    for (idx, handle) in handles.into_iter().enumerate() {
+        match handle.await {
+            Ok(Ok(stats)) => {
+                total_transactions += stats.transactions_completed;
+                total_reads += stats.reads_completed;
+                total_writes += stats.writes_completed;
+                total_reads_found += stats.reads_found;
+                total_reads_not_found += stats.reads_not_found;
+            }
+            Ok(Err(e)) => {
+                eprintln!("Thread error {idx}: {e}");
+                return Err(e);
+            }
+            Err(_) => {
+                eprintln!("Thread panicked");
+                std::process::exit(1);
+            }
+        }
+    }
+
+    let overall_elapsed = overall_start.elapsed();
+    let total_operations = total_reads + total_writes;
+
+    let transaction_throughput = (total_transactions as f64) / overall_elapsed.as_secs_f64();
+    let operation_throughput = (total_operations as f64) / overall_elapsed.as_secs_f64();
+    let read_throughput = if total_reads > 0 {
+        (total_reads as f64) / overall_elapsed.as_secs_f64()
+    } else {
+        0.0
+    };
+    let write_throughput = if total_writes > 0 {
+        (total_writes as f64) / overall_elapsed.as_secs_f64()
+    } else {
+        0.0
+    };
+    let avg_ops_per_txn = (total_operations as f64) / (total_transactions as f64);
+
+    println!("\n=== BENCHMARK RESULTS ===");
+    println!("Total transactions: {total_transactions}");
+    println!("Total operations: {total_operations}");
+    println!("Operations per transaction: {avg_ops_per_txn:.1}");
+    println!("Total time: {:.2}s", overall_elapsed.as_secs_f64());
+    println!();
+    println!("Transaction throughput: {transaction_throughput:.2} txns/sec");
+    println!("Operation throughput: {operation_throughput:.2} ops/sec");
+
+    // not found should be zero since we track the max inserted id
+    // todo(v): probably handle the not found error and remove max id
+    if total_reads > 0 {
+        println!(
+            " - Read operations: {total_reads} ({total_reads_found} found, {total_reads_not_found} not found)"
+        );
+        println!(" - Read throughput: {read_throughput:.2} reads/sec");
+    }
+    if total_writes > 0 {
+        println!(" - Write operations: {total_writes}");
+        println!(" - Write throughput: {write_throughput:.2} writes/sec");
+    }
+
+    println!("\nConfiguration:");
+    println!("Threads: {}", args.threads);
+    println!("Batch size: {}", args.batch_size);
+
println!("Iterations per thread: {}", args.iterations); + println!("Encryption: {}", args.encryption); + println!("Seed: {}", args.seed); + + if let Ok(metadata) = std::fs::metadata(db_path) { + println!("Database file size: {} bytes", metadata.len()); + } + + Ok(()) +} + +fn get_key_size_for_cipher(cipher: &str) -> usize { + match cipher.to_lowercase().as_str() { + "aes-128-gcm" | "aegis-128l" | "aegis-128x2" | "aegis-128x4" => 16, + "aes-256-gcm" | "aegis-256" | "aegis-256x2" | "aegis-256x4" => 32, + _ => 32, // default to 256-bit key + } +} + +async fn setup_database( + db_path: &str, + encryption_opts: &Option, +) -> Result { + let builder = Builder::new_local(db_path); + let db = builder.build().await?; + let conn = db.connect()?; + + if let Some(config) = encryption_opts { + conn.execute(&format!("PRAGMA cipher='{}'", config.cipher), ()) + .await?; + conn.execute(&format!("PRAGMA hexkey='{}'", config.hexkey), ()) + .await?; + } + + // todo(v): probably store blobs and then have option of randomblob size + conn.execute( + "CREATE TABLE IF NOT EXISTS test_table ( + id INTEGER PRIMARY KEY, + data TEXT NOT NULL + )", + (), + ) + .await?; + + println!("Database created at: {db_path}"); + Ok(db) +} + +#[allow(clippy::too_many_arguments)] +async fn worker_thread( + thread_id: usize, + db: Database, + batch_size: usize, + iterations: usize, + start_barrier: Arc, + read_ratio: u8, + encryption_opts: Option, + think_ms: u64, + timeout: Duration, + shared_state: SharedState, + base_seed: u64, +) -> Result { + start_barrier.wait(); + + let start_time = Instant::now(); + let mut stats = WorkerStats { + transactions_completed: 0, + reads_completed: 0, + writes_completed: 0, + reads_found: 0, + reads_not_found: 0, + total_transaction_time: Duration::ZERO, + }; + + let thread_seed = base_seed.wrapping_add(thread_id as u64); + let mut rng = SmallRng::seed_from_u64(thread_seed); + + for iteration in 0..iterations { + let conn = db.connect()?; + + if let Some(config) = &encryption_opts { + conn.execute(&format!("PRAGMA cipher='{}'", config.cipher), ()) + .await?; + conn.execute(&format!("PRAGMA hexkey='{}'", config.hexkey), ()) + .await?; + } + + conn.busy_timeout(Some(timeout))?; + + let mut insert_stmt = conn + .prepare("INSERT INTO test_table (id, data) VALUES (?, ?)") + .await?; + + let transaction_start = Instant::now(); + conn.execute("BEGIN", ()).await?; + + for i in 0..batch_size { + let should_read = rng.random_range(0..100) < read_ratio; + + if should_read { + // only attempt reads if we have inserted some data + let max_id = shared_state.max_inserted_id.load(Ordering::Relaxed); + if max_id > 0 { + let read_id = rng.random_range(1..=max_id); + let row = conn + .query( + "SELECT data FROM test_table WHERE id = ?", + turso::params::Params::Positional(vec![turso::Value::Integer( + read_id as i64, + )]), + ) + .await; + + match row { + Ok(_) => stats.reads_found += 1, + Err(turso::Error::QueryReturnedNoRows) => stats.reads_not_found += 1, + Err(e) => return Err(e), + }; + stats.reads_completed += 1; + } else { + // if no data inserted yet, convert to a write + let id = thread_id * iterations * batch_size + iteration * batch_size + i + 1; + insert_stmt + .execute(turso::params::Params::Positional(vec![ + turso::Value::Integer(id as i64), + turso::Value::Text(format!("data_{id}")), + ])) + .await?; + + shared_state + .max_inserted_id + .fetch_max(id as u64, Ordering::Relaxed); + stats.writes_completed += 1; + } + } else { + let id = thread_id * iterations * batch_size + iteration * batch_size + i + 
1; + insert_stmt + .execute(turso::params::Params::Positional(vec![ + turso::Value::Integer(id as i64), + turso::Value::Text(format!("data_{id}")), + ])) + .await?; + + shared_state + .max_inserted_id + .fetch_max(id as u64, Ordering::Relaxed); + stats.writes_completed += 1; + } + } + + if think_ms > 0 { + tokio::time::sleep(Duration::from_millis(think_ms)).await; + } + + conn.execute("COMMIT", ()).await?; + + let transaction_elapsed = transaction_start.elapsed(); + stats.transactions_completed += 1; + stats.total_transaction_time += transaction_elapsed; + } + + let elapsed = start_time.elapsed(); + let total_ops = stats.reads_completed + stats.writes_completed; + let transaction_throughput = (stats.transactions_completed as f64) / elapsed.as_secs_f64(); + let operation_throughput = (total_ops as f64) / elapsed.as_secs_f64(); + let avg_txn_latency = + stats.total_transaction_time.as_secs_f64() * 1000.0 / stats.transactions_completed as f64; + + println!( + "Thread {}: {} txns ({} ops: {} reads, {} writes) in {:.2}s ({:.2} txns/sec, {:.2} ops/sec, {:.2}ms avg latency)", + thread_id, + stats.transactions_completed, + total_ops, + stats.reads_completed, + stats.writes_completed, + elapsed.as_secs_f64(), + transaction_throughput, + operation_throughput, + avg_txn_latency + ); + + if stats.reads_completed > 0 { + println!( + " Thread {} reads: {} found, {} not found", + thread_id, stats.reads_found, stats.reads_not_found + ); + } + + Ok(stats) +} diff --git a/testing/materialized_views.test b/testing/materialized_views.test index 5d226b016..15229a48c 100755 --- a/testing/materialized_views.test +++ b/testing/materialized_views.test @@ -44,13 +44,13 @@ do_execsql_test_on_specific_db {:memory:} matview-aggregation-population { do_execsql_test_on_specific_db {:memory:} matview-filter-with-groupby { CREATE TABLE t(a INTEGER, b INTEGER); INSERT INTO t(a,b) VALUES (2,2), (3,3), (6,6), (7,7); - + CREATE MATERIALIZED VIEW v AS SELECT b as yourb, SUM(a) as mysum, COUNT(a) as mycount FROM t WHERE b > 2 GROUP BY b; - + SELECT * FROM v ORDER BY yourb; } {3|3|1 6|6|1 @@ -63,13 +63,13 @@ do_execsql_test_on_specific_db {:memory:} matview-insert-maintenance { FROM t WHERE b > 2 GROUP BY b; - + INSERT INTO t VALUES (3,3), (6,6); SELECT * FROM v ORDER BY b; - + INSERT INTO t VALUES (4,3), (5,6); SELECT * FROM v ORDER BY b; - + INSERT INTO t VALUES (1,1), (2,2); SELECT * FROM v ORDER BY b; } {3|3|1 @@ -87,17 +87,17 @@ do_execsql_test_on_specific_db {:memory:} matview-delete-maintenance { (3, 'A', 30), (4, 'B', 40), (5, 'A', 50); - + CREATE MATERIALIZED VIEW category_sums AS SELECT category, SUM(amount) as total, COUNT(*) as cnt FROM items GROUP BY category; - + SELECT * FROM category_sums ORDER BY category; - + DELETE FROM items WHERE id = 3; SELECT * FROM category_sums ORDER BY category; - + DELETE FROM items WHERE category = 'B'; SELECT * FROM category_sums ORDER BY category; } {A|90|3 @@ -113,17 +113,17 @@ do_execsql_test_on_specific_db {:memory:} matview-update-maintenance { (2, 200, 2), (3, 300, 1), (4, 400, 2); - + CREATE MATERIALIZED VIEW status_totals AS SELECT status, SUM(value) as total, COUNT(*) as cnt FROM records GROUP BY status; - + SELECT * FROM status_totals ORDER BY status; - + UPDATE records SET value = 150 WHERE id = 1; SELECT * FROM status_totals ORDER BY status; - + UPDATE records SET status = 2 WHERE id = 3; SELECT * FROM status_totals ORDER BY status; } {1|400|2 @@ -136,10 +136,10 @@ do_execsql_test_on_specific_db {:memory:} matview-update-maintenance { do_execsql_test_on_specific_db 
{:memory:} matview-integer-primary-key-basic { CREATE TABLE t(a INTEGER PRIMARY KEY, b INTEGER); INSERT INTO t(a,b) VALUES (2,2), (3,3), (6,6), (7,7); - + CREATE MATERIALIZED VIEW v AS SELECT * FROM t WHERE b > 2; - + SELECT * FROM v ORDER BY a; } {3|3 6|6 @@ -148,15 +148,15 @@ do_execsql_test_on_specific_db {:memory:} matview-integer-primary-key-basic { do_execsql_test_on_specific_db {:memory:} matview-integer-primary-key-update-rowid { CREATE TABLE t(a INTEGER PRIMARY KEY, b INTEGER); INSERT INTO t(a,b) VALUES (2,2), (3,3), (6,6), (7,7); - + CREATE MATERIALIZED VIEW v AS SELECT * FROM t WHERE b > 2; - + SELECT * FROM v ORDER BY a; - + UPDATE t SET a = 1 WHERE b = 3; SELECT * FROM v ORDER BY a; - + UPDATE t SET a = 10 WHERE a = 6; SELECT * FROM v ORDER BY a; } {3|3 @@ -172,15 +172,15 @@ do_execsql_test_on_specific_db {:memory:} matview-integer-primary-key-update-row do_execsql_test_on_specific_db {:memory:} matview-integer-primary-key-update-value { CREATE TABLE t(a INTEGER PRIMARY KEY, b INTEGER); INSERT INTO t(a,b) VALUES (2,2), (3,3), (6,6), (7,7); - + CREATE MATERIALIZED VIEW v AS SELECT * FROM t WHERE b > 2; - + SELECT * FROM v ORDER BY a; - + UPDATE t SET b = 1 WHERE a = 6; SELECT * FROM v ORDER BY a; - + UPDATE t SET b = 5 WHERE a = 2; SELECT * FROM v ORDER BY a; } {3|3 @@ -200,18 +200,18 @@ do_execsql_test_on_specific_db {:memory:} matview-integer-primary-key-with-aggre (3, 20, 300), (4, 20, 400), (5, 10, 500); - + CREATE MATERIALIZED VIEW v AS SELECT b, SUM(c) as total, COUNT(*) as cnt FROM t WHERE a > 2 GROUP BY b; - + SELECT * FROM v ORDER BY b; - + UPDATE t SET a = 6 WHERE a = 1; SELECT * FROM v ORDER BY b; - + DELETE FROM t WHERE a = 3; SELECT * FROM v ORDER BY b; } {10|500|1 @@ -228,7 +228,7 @@ do_execsql_test_on_specific_db {:memory:} matview-complex-filter-aggregation { amount INTEGER, type INTEGER ); - + INSERT INTO transactions VALUES (1, 100, 50, 1), (2, 100, 30, 2), @@ -236,21 +236,21 @@ do_execsql_test_on_specific_db {:memory:} matview-complex-filter-aggregation { (4, 100, 20, 1), (5, 200, 40, 2), (6, 300, 60, 1); - + CREATE MATERIALIZED VIEW account_deposits AS SELECT account, SUM(amount) as total_deposits, COUNT(*) as deposit_count FROM transactions WHERE type = 1 GROUP BY account; - + SELECT * FROM account_deposits ORDER BY account; - + INSERT INTO transactions VALUES (7, 100, 25, 1); SELECT * FROM account_deposits ORDER BY account; - + UPDATE transactions SET amount = 80 WHERE id = 1; SELECT * FROM account_deposits ORDER BY account; - + DELETE FROM transactions WHERE id = 3; SELECT * FROM account_deposits ORDER BY account; } {100|70|2 @@ -273,19 +273,19 @@ do_execsql_test_on_specific_db {:memory:} matview-sum-count-only { (3, 30, 2), (4, 40, 2), (5, 50, 1); - + CREATE MATERIALIZED VIEW category_stats AS SELECT category, SUM(value) as sum_val, COUNT(*) as cnt FROM data GROUP BY category; - + SELECT * FROM category_stats ORDER BY category; - + INSERT INTO data VALUES (6, 5, 1); SELECT * FROM category_stats ORDER BY category; - + UPDATE data SET value = 35 WHERE id = 3; SELECT * FROM category_stats ORDER BY category; } {1|80|3 @@ -302,9 +302,9 @@ do_execsql_test_on_specific_db {:memory:} matview-empty-table-population { FROM t WHERE b > 5 GROUP BY b; - + SELECT COUNT(*) FROM v; - + INSERT INTO t VALUES (1, 3), (2, 7), (3, 9); SELECT * FROM v ORDER BY b; } {0 @@ -314,15 +314,15 @@ do_execsql_test_on_specific_db {:memory:} matview-empty-table-population { do_execsql_test_on_specific_db {:memory:} matview-all-rows-filtered { CREATE TABLE t(a INTEGER, b INTEGER); INSERT 
INTO t VALUES (1, 1), (2, 2), (3, 3); - + CREATE MATERIALIZED VIEW v AS SELECT * FROM t WHERE b > 10; - + SELECT COUNT(*) FROM v; - + INSERT INTO t VALUES (11, 11); SELECT * FROM v; - + UPDATE t SET b = 1 WHERE a = 11; SELECT COUNT(*) FROM v; } {0 @@ -335,26 +335,26 @@ do_execsql_test_on_specific_db {:memory:} matview-mixed-operations-sequence { customer_id INTEGER, amount INTEGER ); - + INSERT INTO orders VALUES (1, 100, 50); INSERT INTO orders VALUES (2, 200, 75); - + CREATE MATERIALIZED VIEW customer_totals AS SELECT customer_id, SUM(amount) as total, COUNT(*) as order_count FROM orders GROUP BY customer_id; - + SELECT * FROM customer_totals ORDER BY customer_id; - + INSERT INTO orders VALUES (3, 100, 25); SELECT * FROM customer_totals ORDER BY customer_id; - + UPDATE orders SET amount = 100 WHERE order_id = 2; SELECT * FROM customer_totals ORDER BY customer_id; - + DELETE FROM orders WHERE order_id = 1; SELECT * FROM customer_totals ORDER BY customer_id; - + INSERT INTO orders VALUES (4, 300, 150); SELECT * FROM customer_totals ORDER BY customer_id; } {100|50|1 @@ -389,17 +389,17 @@ do_execsql_test_on_specific_db {:memory:} matview-projections { do_execsql_test_on_specific_db {:memory:} matview-rollback-insert { CREATE TABLE t(a INTEGER, b INTEGER); INSERT INTO t VALUES (1, 10), (2, 20), (3, 30); - + CREATE MATERIALIZED VIEW v AS SELECT * FROM t WHERE b > 15; - + SELECT * FROM v ORDER BY a; - + BEGIN; INSERT INTO t VALUES (4, 40), (5, 50); SELECT * FROM v ORDER BY a; ROLLBACK; - + SELECT * FROM v ORDER BY a; } {2|20 3|30 @@ -413,17 +413,17 @@ do_execsql_test_on_specific_db {:memory:} matview-rollback-insert { do_execsql_test_on_specific_db {:memory:} matview-rollback-delete { CREATE TABLE t(a INTEGER, b INTEGER); INSERT INTO t VALUES (1, 10), (2, 20), (3, 30), (4, 40); - + CREATE MATERIALIZED VIEW v AS SELECT * FROM t WHERE b > 15; - + SELECT * FROM v ORDER BY a; - + BEGIN; DELETE FROM t WHERE a IN (2, 3); SELECT * FROM v ORDER BY a; ROLLBACK; - + SELECT * FROM v ORDER BY a; } {2|20 3|30 @@ -436,18 +436,18 @@ do_execsql_test_on_specific_db {:memory:} matview-rollback-delete { do_execsql_test_on_specific_db {:memory:} matview-rollback-update { CREATE TABLE t(a INTEGER, b INTEGER); INSERT INTO t VALUES (1, 10), (2, 20), (3, 30); - + CREATE MATERIALIZED VIEW v AS SELECT * FROM t WHERE b > 15; - + SELECT * FROM v ORDER BY a; - + BEGIN; UPDATE t SET b = 5 WHERE a = 2; UPDATE t SET b = 35 WHERE a = 1; SELECT * FROM v ORDER BY a; ROLLBACK; - + SELECT * FROM v ORDER BY a; } {2|20 3|30 @@ -459,19 +459,19 @@ do_execsql_test_on_specific_db {:memory:} matview-rollback-update { do_execsql_test_on_specific_db {:memory:} matview-rollback-aggregation { CREATE TABLE sales(product_id INTEGER, amount INTEGER); INSERT INTO sales VALUES (1, 100), (1, 200), (2, 150), (2, 250); - + CREATE MATERIALIZED VIEW product_totals AS SELECT product_id, SUM(amount) as total, COUNT(*) as cnt FROM sales GROUP BY product_id; - + SELECT * FROM product_totals ORDER BY product_id; - + BEGIN; INSERT INTO sales VALUES (1, 50), (3, 300); SELECT * FROM product_totals ORDER BY product_id; ROLLBACK; - + SELECT * FROM product_totals ORDER BY product_id; } {1|300|2 2|400|2 @@ -484,21 +484,21 @@ do_execsql_test_on_specific_db {:memory:} matview-rollback-aggregation { do_execsql_test_on_specific_db {:memory:} matview-rollback-mixed-operations { CREATE TABLE orders(id INTEGER PRIMARY KEY, customer INTEGER, amount INTEGER); INSERT INTO orders VALUES (1, 100, 50), (2, 200, 75), (3, 100, 25); - + CREATE MATERIALIZED VIEW customer_totals 
AS SELECT customer, SUM(amount) as total, COUNT(*) as cnt FROM orders GROUP BY customer; - + SELECT * FROM customer_totals ORDER BY customer; - + BEGIN; INSERT INTO orders VALUES (4, 100, 100); UPDATE orders SET amount = 150 WHERE id = 2; DELETE FROM orders WHERE id = 3; SELECT * FROM customer_totals ORDER BY customer; ROLLBACK; - + SELECT * FROM customer_totals ORDER BY customer; } {100|75|2 200|75|1 @@ -514,22 +514,22 @@ do_execsql_test_on_specific_db {:memory:} matview-rollback-filtered-aggregation (2, 100, 30, 'withdraw'), (3, 200, 100, 'deposit'), (4, 200, 40, 'withdraw'); - + CREATE MATERIALIZED VIEW deposits AS SELECT account, SUM(amount) as total_deposits, COUNT(*) as cnt FROM transactions WHERE type = 'deposit' GROUP BY account; - + SELECT * FROM deposits ORDER BY account; - + BEGIN; INSERT INTO transactions VALUES (5, 100, 75, 'deposit'); UPDATE transactions SET amount = 60 WHERE id = 1; DELETE FROM transactions WHERE id = 3; SELECT * FROM deposits ORDER BY account; ROLLBACK; - + SELECT * FROM deposits ORDER BY account; } {100|50|1 200|100|1 @@ -540,12 +540,12 @@ do_execsql_test_on_specific_db {:memory:} matview-rollback-filtered-aggregation do_execsql_test_on_specific_db {:memory:} matview-rollback-empty-view { CREATE TABLE t(a INTEGER, b INTEGER); INSERT INTO t VALUES (1, 5), (2, 8); - + CREATE MATERIALIZED VIEW v AS SELECT * FROM t WHERE b > 10; - + SELECT COUNT(*) FROM v; - + BEGIN; INSERT INTO t VALUES (3, 15), (4, 20); SELECT * FROM v ORDER BY a; @@ -556,3 +556,538 @@ do_execsql_test_on_specific_db {:memory:} matview-rollback-empty-view { 3|15 4|20 0} + +# Join tests for materialized views + +do_execsql_test_on_specific_db {:memory:} matview-simple-join { + CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT, age INTEGER); + CREATE TABLE orders(order_id INTEGER PRIMARY KEY, user_id INTEGER, product_id INTEGER, quantity INTEGER); + + INSERT INTO users VALUES (1, 'Alice', 25), (2, 'Bob', 30), (3, 'Charlie', 35); + INSERT INTO orders VALUES (1, 1, 100, 5), (2, 1, 101, 3), (3, 2, 100, 7); + + CREATE MATERIALIZED VIEW user_orders AS + SELECT u.name, o.quantity + FROM users u + JOIN orders o ON u.id = o.user_id; + + SELECT * FROM user_orders ORDER BY name, quantity; +} {Alice|3 +Alice|5 +Bob|7} + +do_execsql_test_on_specific_db {:memory:} matview-join-with-aggregation { + CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT); + CREATE TABLE orders(order_id INTEGER PRIMARY KEY, user_id INTEGER, amount INTEGER); + + INSERT INTO users VALUES (1, 'Alice'), (2, 'Bob'); + INSERT INTO orders VALUES (1, 1, 100), (2, 1, 150), (3, 2, 200), (4, 2, 50); + + CREATE MATERIALIZED VIEW user_totals AS + SELECT u.name, SUM(o.amount) as total_amount + FROM users u + JOIN orders o ON u.id = o.user_id + GROUP BY u.name; + + SELECT * FROM user_totals ORDER BY name; +} {Alice|250 +Bob|250} + +do_execsql_test_on_specific_db {:memory:} matview-three-way-join { + CREATE TABLE customers(id INTEGER PRIMARY KEY, name TEXT, city TEXT); + CREATE TABLE orders(id INTEGER PRIMARY KEY, customer_id INTEGER, product_id INTEGER, quantity INTEGER); + CREATE TABLE products(id INTEGER PRIMARY KEY, name TEXT, price INTEGER); + + INSERT INTO customers VALUES (1, 'Alice', 'NYC'), (2, 'Bob', 'LA'); + INSERT INTO products VALUES (1, 'Widget', 10), (2, 'Gadget', 20); + INSERT INTO orders VALUES (1, 1, 1, 5), (2, 1, 2, 3), (3, 2, 1, 2); + + CREATE MATERIALIZED VIEW sales_summary AS + SELECT c.name as customer_name, p.name as product_name, o.quantity + FROM customers c + JOIN orders o ON c.id = o.customer_id + JOIN products p 
ON o.product_id = p.id; + + SELECT * FROM sales_summary ORDER BY customer_name, product_name; +} {Alice|Gadget|3 +Alice|Widget|5 +Bob|Widget|2} + +do_execsql_test_on_specific_db {:memory:} matview-three-way-join-with-aggregation { + CREATE TABLE customers(id INTEGER PRIMARY KEY, name TEXT); + CREATE TABLE orders(id INTEGER PRIMARY KEY, customer_id INTEGER, product_id INTEGER, quantity INTEGER); + CREATE TABLE products(id INTEGER PRIMARY KEY, name TEXT, price INTEGER); + + INSERT INTO customers VALUES (1, 'Alice'), (2, 'Bob'); + INSERT INTO products VALUES (1, 'Widget', 10), (2, 'Gadget', 20); + INSERT INTO orders VALUES (1, 1, 1, 5), (2, 1, 2, 3), (3, 2, 1, 2), (4, 1, 1, 4); + + CREATE MATERIALIZED VIEW sales_totals AS + SELECT c.name as customer_name, p.name as product_name, + SUM(o.quantity) as total_quantity, + SUM(o.quantity * p.price) as total_value + FROM customers c + JOIN orders o ON c.id = o.customer_id + JOIN products p ON o.product_id = p.id + GROUP BY c.name, p.name; + + SELECT * FROM sales_totals ORDER BY customer_name, product_name; +} {Alice|Gadget|3|60 +Alice|Widget|9|90 +Bob|Widget|2|20} + +do_execsql_test_on_specific_db {:memory:} matview-join-incremental-insert { + CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT); + CREATE TABLE orders(order_id INTEGER PRIMARY KEY, user_id INTEGER, amount INTEGER); + + INSERT INTO users VALUES (1, 'Alice'); + INSERT INTO orders VALUES (1, 1, 100); + + CREATE MATERIALIZED VIEW user_orders AS + SELECT u.name, o.amount + FROM users u + JOIN orders o ON u.id = o.user_id; + + SELECT COUNT(*) FROM user_orders; + + INSERT INTO orders VALUES (2, 1, 150); + SELECT COUNT(*) FROM user_orders; + + INSERT INTO users VALUES (2, 'Bob'); + INSERT INTO orders VALUES (3, 2, 200); + SELECT COUNT(*) FROM user_orders; +} {1 +2 +3} + +do_execsql_test_on_specific_db {:memory:} matview-join-incremental-delete { + CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT); + CREATE TABLE orders(order_id INTEGER PRIMARY KEY, user_id INTEGER, amount INTEGER); + + INSERT INTO users VALUES (1, 'Alice'), (2, 'Bob'); + INSERT INTO orders VALUES (1, 1, 100), (2, 1, 150), (3, 2, 200); + + CREATE MATERIALIZED VIEW user_orders AS + SELECT u.name, o.amount + FROM users u + JOIN orders o ON u.id = o.user_id; + + SELECT COUNT(*) FROM user_orders; + + DELETE FROM orders WHERE order_id = 2; + SELECT COUNT(*) FROM user_orders; + + DELETE FROM users WHERE id = 2; + SELECT COUNT(*) FROM user_orders; +} {3 +2 +1} + +do_execsql_test_on_specific_db {:memory:} matview-join-incremental-update { + CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT); + CREATE TABLE orders(order_id INTEGER PRIMARY KEY, user_id INTEGER, amount INTEGER); + + INSERT INTO users VALUES (1, 'Alice'), (2, 'Bob'); + INSERT INTO orders VALUES (1, 1, 100), (2, 2, 200); + + CREATE MATERIALIZED VIEW user_orders AS + SELECT u.name, o.amount + FROM users u + JOIN orders o ON u.id = o.user_id; + + SELECT * FROM user_orders ORDER BY name; + + UPDATE orders SET amount = 150 WHERE order_id = 1; + SELECT * FROM user_orders ORDER BY name; + + UPDATE users SET name = 'Robert' WHERE id = 2; + SELECT * FROM user_orders ORDER BY name; +} {Alice|100 +Bob|200 +Alice|150 +Bob|200 +Alice|150 +Robert|200} + +do_execsql_test_on_specific_db {:memory:} matview-join-with-filter { + CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT, age INTEGER); + CREATE TABLE orders(order_id INTEGER PRIMARY KEY, user_id INTEGER, amount INTEGER); + + INSERT INTO users VALUES (1, 'Alice', 25), (2, 'Bob', 35), (3, 'Charlie', 20); + INSERT INTO orders 
VALUES (1, 1, 100), (2, 2, 200), (3, 3, 150);
+
+    CREATE MATERIALIZED VIEW adult_orders AS
+    SELECT u.name, o.amount
+    FROM users u
+    JOIN orders o ON u.id = o.user_id
+    WHERE u.age > 21;
+
+    SELECT * FROM adult_orders ORDER BY name;
+} {Alice|100
+Bob|200}
+
+do_execsql_test_on_specific_db {:memory:} matview-join-rollback {
+    CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT);
+    CREATE TABLE orders(order_id INTEGER PRIMARY KEY, user_id INTEGER, amount INTEGER);
+
+    INSERT INTO users VALUES (1, 'Alice'), (2, 'Bob');
+    INSERT INTO orders VALUES (1, 1, 100), (2, 2, 200);
+
+    CREATE MATERIALIZED VIEW user_orders AS
+    SELECT u.name, o.amount
+    FROM users u
+    JOIN orders o ON u.id = o.user_id;
+
+    SELECT COUNT(*) FROM user_orders;
+
+    BEGIN;
+    INSERT INTO users VALUES (3, 'Charlie');
+    INSERT INTO orders VALUES (3, 3, 300);
+    SELECT COUNT(*) FROM user_orders;
+    ROLLBACK;
+
+    SELECT COUNT(*) FROM user_orders;
+} {2
+3
+2}
+
+# ===== COMPREHENSIVE JOIN TESTS =====
+
+# Test 1: Join with filter BEFORE the join (on base tables)
+do_execsql_test_on_specific_db {:memory:} matview-join-with-pre-filter {
+    CREATE TABLE employees(id INTEGER PRIMARY KEY, name TEXT, department TEXT, salary INTEGER);
+    CREATE TABLE departments(id INTEGER PRIMARY KEY, dept_name TEXT, budget INTEGER);
+
+    INSERT INTO employees VALUES
+        (1, 'Alice', 'Engineering', 80000),
+        (2, 'Bob', 'Engineering', 90000),
+        (3, 'Charlie', 'Sales', 60000),
+        (4, 'David', 'Sales', 65000),
+        (5, 'Eve', 'HR', 70000);
+
+    INSERT INTO departments VALUES
+        (1, 'Engineering', 500000),
+        (2, 'Sales', 300000),
+        (3, 'HR', 200000);
+
+    -- View: Join only high-salary employees with their departments
+    CREATE MATERIALIZED VIEW high_earners_by_dept AS
+    SELECT e.name, e.salary, d.dept_name, d.budget
+    FROM employees e
+    JOIN departments d ON e.department = d.dept_name
+    WHERE e.salary > 70000;
+
+    SELECT * FROM high_earners_by_dept ORDER BY salary DESC;
+} {Bob|90000|Engineering|500000
+Alice|80000|Engineering|500000}
+
+# Test 2: Join with filter AFTER the join
+do_execsql_test_on_specific_db {:memory:} matview-join-with-post-filter {
+    CREATE TABLE products(id INTEGER PRIMARY KEY, name TEXT, category_id INTEGER, price INTEGER);
+    CREATE TABLE categories(id INTEGER PRIMARY KEY, name TEXT, min_price INTEGER);
+
+    INSERT INTO products VALUES
+        (1, 'Laptop', 1, 1200),
+        (2, 'Mouse', 1, 25),
+        (3, 'Shirt', 2, 50),
+        (4, 'Shoes', 2, 120);
+
+    INSERT INTO categories VALUES
+        (1, 'Electronics', 100),
+        (2, 'Clothing', 30);
+
+    -- View: Products that meet or exceed their category's minimum price
+    CREATE MATERIALIZED VIEW premium_products AS
+    SELECT p.name as product, c.name as category, p.price, c.min_price
+    FROM products p
+    JOIN categories c ON p.category_id = c.id
+    WHERE p.price >= c.min_price;
+
+    SELECT * FROM premium_products ORDER BY price DESC;
+} {Laptop|Electronics|1200|100
+Shoes|Clothing|120|30
+Shirt|Clothing|50|30}
+
+# Test 3: Join with aggregation BEFORE the join
+do_execsql_test_on_specific_db {:memory:} matview-aggregation-before-join {
+    CREATE TABLE orders(id INTEGER PRIMARY KEY, customer_id INTEGER, product_id INTEGER, quantity INTEGER, order_date INTEGER);
+    CREATE TABLE customers(id INTEGER PRIMARY KEY, name TEXT, tier TEXT);
+
+    INSERT INTO orders VALUES
+        (1, 1, 101, 2, 1),
+        (2, 1, 102, 1, 1),
+        (3, 2, 101, 5, 1),
+        (4, 1, 101, 3, 2),
+        (5, 2, 103, 2, 2),
+        (6, 3, 102, 1, 2);
+
+    INSERT INTO customers VALUES
+        (1, 'Alice', 'Gold'),
+        (2, 'Bob', 'Silver'),
+        (3, 'Charlie', 'Bronze');
+
+    -- View: Customer order counts joined with customer details
+    -- Note: Simplified to avoid subquery issues with DBSP compiler
+    CREATE MATERIALIZED VIEW customer_order_summary AS
+    SELECT c.name, c.tier, COUNT(o.id) as order_count, SUM(o.quantity) as total_quantity
+    FROM customers c
+    JOIN orders o ON c.id = o.customer_id
+    GROUP BY c.id, c.name, c.tier;
+
+    SELECT * FROM customer_order_summary ORDER BY total_quantity DESC;
+} {Bob|Silver|2|7
+Alice|Gold|3|6
+Charlie|Bronze|1|1}
+
+# Test 4: Join with aggregation AFTER the join
+do_execsql_test_on_specific_db {:memory:} matview-aggregation-after-join {
+    CREATE TABLE sales(id INTEGER PRIMARY KEY, product_id INTEGER, store_id INTEGER, units_sold INTEGER, revenue INTEGER);
+    CREATE TABLE stores(id INTEGER PRIMARY KEY, name TEXT, region TEXT);
+
+    INSERT INTO sales VALUES
+        (1, 1, 1, 10, 1000),
+        (2, 1, 2, 15, 1500),
+        (3, 2, 1, 5, 250),
+        (4, 2, 2, 8, 400),
+        (5, 1, 3, 12, 1200),
+        (6, 2, 3, 6, 300);
+
+    INSERT INTO stores VALUES
+        (1, 'StoreA', 'North'),
+        (2, 'StoreB', 'North'),
+        (3, 'StoreC', 'South');
+
+    -- View: Regional sales summary (aggregate after joining)
+    CREATE MATERIALIZED VIEW regional_sales AS
+    SELECT st.region, SUM(s.units_sold) as total_units, SUM(s.revenue) as total_revenue
+    FROM sales s
+    JOIN stores st ON s.store_id = st.id
+    GROUP BY st.region;
+
+    SELECT * FROM regional_sales ORDER BY total_revenue DESC;
+} {North|38|3150
+South|18|1500}
+
+# Test 5: Modifying both tables in same transaction
+do_execsql_test_on_specific_db {:memory:} matview-join-both-tables-modified {
+    CREATE TABLE authors(id INTEGER PRIMARY KEY, name TEXT);
+    CREATE TABLE books(id INTEGER PRIMARY KEY, title TEXT, author_id INTEGER, year INTEGER);
+
+    INSERT INTO authors VALUES (1, 'Orwell'), (2, 'Asimov');
+    INSERT INTO books VALUES (1, '1984', 1, 1949), (2, 'Foundation', 2, 1951);
+
+    CREATE MATERIALIZED VIEW author_books AS
+    SELECT a.name, b.title, b.year
+    FROM authors a
+    JOIN books b ON a.id = b.author_id;
+
+    SELECT COUNT(*) FROM author_books;
+
+    BEGIN;
+    INSERT INTO authors VALUES (3, 'Herbert');
+    INSERT INTO books VALUES (3, 'Dune', 3, 1965);
+    SELECT COUNT(*) FROM author_books;
+    COMMIT;
+
+    SELECT * FROM author_books ORDER BY year;
+} {2
+3
+Orwell|1984|1949
+Asimov|Foundation|1951
+Herbert|Dune|1965}
+
+# Test 6: Modifying only one table in transaction
+do_execsql_test_on_specific_db {:memory:} matview-join-single-table-modified {
+    CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT, active INTEGER);
+    CREATE TABLE posts(id INTEGER PRIMARY KEY, user_id INTEGER, content TEXT);
+
+    INSERT INTO users VALUES (1, 'Alice', 1), (2, 'Bob', 1), (3, 'Charlie', 0);
+    INSERT INTO posts VALUES (1, 1, 'Hello'), (2, 1, 'World'), (3, 2, 'Test');
+
+    CREATE MATERIALIZED VIEW active_user_posts AS
+    SELECT u.name, p.content
+    FROM users u
+    JOIN posts p ON u.id = p.user_id
+    WHERE u.active = 1;
+
+    SELECT COUNT(*) FROM active_user_posts;
+
+    -- Add posts for existing user (modify only posts table)
+    BEGIN;
+    INSERT INTO posts VALUES (4, 1, 'NewPost'), (5, 2, 'Another');
+    SELECT COUNT(*) FROM active_user_posts;
+    COMMIT;
+
+    SELECT * FROM active_user_posts ORDER BY name, content;
+} {3
+5
+Alice|Hello
+Alice|NewPost
+Alice|World
+Bob|Another
+Bob|Test}
+
+
+do_execsql_test_on_specific_db {:memory:} matview-three-way-incremental {
+    CREATE TABLE students(id INTEGER PRIMARY KEY, name TEXT, major TEXT);
+    CREATE TABLE courses(id INTEGER PRIMARY KEY, name TEXT, department TEXT, credits INTEGER);
+    CREATE TABLE enrollments(student_id INTEGER, course_id INTEGER, grade TEXT, PRIMARY KEY(student_id, course_id));
+
+    INSERT INTO students VALUES (1, 'Alice', 'CS'), (2, 'Bob', 'Math');
+    INSERT INTO courses VALUES (1, 'DatabaseSystems', 'CS', 3), (2, 'Calculus', 'Math', 4);
+    INSERT INTO enrollments VALUES (1, 1, 'A'), (2, 2, 'B');
+
+    CREATE MATERIALIZED VIEW student_transcripts AS
+    SELECT s.name as student, c.name as course, c.credits, e.grade
+    FROM students s
+    JOIN enrollments e ON s.id = e.student_id
+    JOIN courses c ON e.course_id = c.id;
+
+    SELECT COUNT(*) FROM student_transcripts;
+
+    -- Add new student
+    INSERT INTO students VALUES (3, 'Charlie', 'CS');
+    SELECT COUNT(*) FROM student_transcripts;
+
+    -- Enroll new student
+    INSERT INTO enrollments VALUES (3, 1, 'A'), (3, 2, 'A');
+    SELECT COUNT(*) FROM student_transcripts;
+
+    -- Add new course
+    INSERT INTO courses VALUES (3, 'Algorithms', 'CS', 3);
+    SELECT COUNT(*) FROM student_transcripts;
+
+    -- Enroll existing students in new course
+    INSERT INTO enrollments VALUES (1, 3, 'B'), (3, 3, 'A');
+    SELECT COUNT(*) FROM student_transcripts;
+
+    SELECT * FROM student_transcripts ORDER BY student, course;
+} {2
+2
+4
+4
+6
+Alice|Algorithms|3|B
+Alice|DatabaseSystems|3|A
+Bob|Calculus|4|B
+Charlie|Algorithms|3|A
+Charlie|Calculus|4|A
+Charlie|DatabaseSystems|3|A}
+
+do_execsql_test_on_specific_db {:memory:} matview-self-join {
+    CREATE TABLE employees(id INTEGER PRIMARY KEY, name TEXT, manager_id INTEGER, salary INTEGER);
+
+    INSERT INTO employees VALUES
+        (1, 'CEO', NULL, 150000),
+        (2, 'VPSales', 1, 120000),
+        (3, 'VPEngineering', 1, 130000),
+        (4, 'Engineer1', 3, 90000),
+        (5, 'Engineer2', 3, 85000),
+        (6, 'SalesRep', 2, 70000);
+
+    CREATE MATERIALIZED VIEW org_chart AS
+    SELECT e.name as employee, m.name as manager, e.salary
+    FROM employees e
+    JOIN employees m ON e.manager_id = m.id;
+
+    SELECT * FROM org_chart ORDER BY salary DESC;
+} {VPEngineering|CEO|130000
+VPSales|CEO|120000
+Engineer1|VPEngineering|90000
+Engineer2|VPEngineering|85000
+SalesRep|VPSales|70000}
+
+do_execsql_test_on_specific_db {:memory:} matview-join-cascade-update {
+    CREATE TABLE categories(id INTEGER PRIMARY KEY, name TEXT, discount_rate INTEGER);
+    CREATE TABLE products(id INTEGER PRIMARY KEY, name TEXT, category_id INTEGER, base_price INTEGER);
+
+    INSERT INTO categories VALUES (1, 'Electronics', 10), (2, 'Books', 5);
+    INSERT INTO products VALUES
+        (1, 'Laptop', 1, 1000),
+        (2, 'Phone', 1, 500),
+        (3, 'Novel', 2, 20),
+        (4, 'Textbook', 2, 80);
+
+    CREATE MATERIALIZED VIEW discounted_prices AS
+    SELECT p.name as product, c.name as category,
+           p.base_price, c.discount_rate,
+           (p.base_price * (100 - c.discount_rate) / 100) as final_price
+    FROM products p
+    JOIN categories c ON p.category_id = c.id;
+
+    SELECT * FROM discounted_prices ORDER BY final_price DESC;
+
+    -- Update discount rate for Electronics
+    UPDATE categories SET discount_rate = 20 WHERE id = 1;
+
+    SELECT * FROM discounted_prices ORDER BY final_price DESC;
+} {Laptop|Electronics|1000|10|900
+Phone|Electronics|500|10|450
+Textbook|Books|80|5|76
+Novel|Books|20|5|19
+Laptop|Electronics|1000|20|800
+Phone|Electronics|500|20|400
+Textbook|Books|80|5|76
+Novel|Books|20|5|19}
+
+do_execsql_test_on_specific_db {:memory:} matview-join-delete-cascade {
+    CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT, active INTEGER);
+    CREATE TABLE sessions(id INTEGER PRIMARY KEY, user_id INTEGER, duration INTEGER);
+
+    INSERT INTO users VALUES (1, 'Alice', 1), (2, 'Bob', 1), (3, 'Charlie', 0);
+    INSERT INTO sessions VALUES
+        (1, 1, 30),
+        (2, 1, 45),
+        (3, 2, 60),
+        (4, 3, 15),
+        (5, 2, 90);
+
+    CREATE MATERIALIZED VIEW active_sessions AS
+    SELECT u.name, s.duration
+    FROM users u
+    JOIN sessions s ON u.id = s.user_id
+    WHERE u.active = 1;
+
+    SELECT COUNT(*) FROM active_sessions;
+
+    -- Delete Bob's sessions
+    DELETE FROM sessions WHERE user_id = 2;
+
+    SELECT COUNT(*) FROM active_sessions;
+    SELECT * FROM active_sessions ORDER BY name, duration;
+} {4
+2
+Alice|30
+Alice|45}
+
+do_execsql_test_on_specific_db {:memory:} matview-join-complex-where {
+    CREATE TABLE orders(id INTEGER PRIMARY KEY, customer_id INTEGER, product_id INTEGER, quantity INTEGER, price INTEGER, order_date INTEGER);
+    CREATE TABLE customers(id INTEGER PRIMARY KEY, name TEXT, tier TEXT, country TEXT);
+
+    INSERT INTO customers VALUES
+        (1, 'Alice', 'Gold', 'USA'),
+        (2, 'Bob', 'Silver', 'Canada'),
+        (3, 'Charlie', 'Gold', 'USA'),
+        (4, 'David', 'Bronze', 'UK');
+
+    INSERT INTO orders VALUES
+        (1, 1, 1, 5, 100, 20240101),
+        (2, 2, 2, 3, 50, 20240102),
+        (3, 3, 1, 10, 100, 20240103),
+        (4, 4, 3, 2, 75, 20240104),
+        (5, 1, 2, 4, 50, 20240105),
+        (6, 3, 3, 6, 75, 20240106);
+
+    -- View: Gold tier USA customers with high-value orders
+    CREATE MATERIALIZED VIEW premium_usa_orders AS
+    SELECT c.name, o.quantity, o.price, (o.quantity * o.price) as total
+    FROM customers c
+    JOIN orders o ON c.id = o.customer_id
+    WHERE c.tier = 'Gold'
+      AND c.country = 'USA'
+      AND (o.quantity * o.price) >= 400;
+
+    SELECT * FROM premium_usa_orders ORDER BY total DESC;
+} {Charlie|10|100|1000
+Alice|5|100|500
+Charlie|6|75|450}
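# --- Editor's sketch, not part of the patch above ---
# The suite exercises DELETEs against the fact table (matview-join-delete-cascade)
# but not an UPDATE on the joined dimension table that flips a row out of the
# view's WHERE filter. Below is a hedged sketch of such a case, reusing the
# users/sessions schema and the same test helper from this file; the test name
# matview-join-filter-flip and the expected values are illustrative assumptions
# computed from plain SQL semantics, not results taken from the patch.
do_execsql_test_on_specific_db {:memory:} matview-join-filter-flip {
    CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT, active INTEGER);
    CREATE TABLE sessions(id INTEGER PRIMARY KEY, user_id INTEGER, duration INTEGER);

    INSERT INTO users VALUES (1, 'Alice', 1), (2, 'Bob', 1);
    INSERT INTO sessions VALUES (1, 1, 30), (2, 1, 45), (3, 2, 60);

    CREATE MATERIALIZED VIEW active_sessions AS
    SELECT u.name, s.duration
    FROM users u
    JOIN sessions s ON u.id = s.user_id
    WHERE u.active = 1;

    SELECT COUNT(*) FROM active_sessions;

    -- Deactivating Bob should retract his joined row from the view
    UPDATE users SET active = 0 WHERE id = 2;

    SELECT COUNT(*) FROM active_sessions;
    SELECT * FROM active_sessions ORDER BY name, duration;
} {3
2
Alice|30
Alice|45}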