From 91e2a679b94febbf01c165ac23a008579d92e1cf Mon Sep 17 00:00:00 2001
From: Avinash Sajjanshetty
Date: Thu, 18 Sep 2025 17:51:30 +0530
Subject: [PATCH 01/34] bugfix: clear reserved space for a reused page

---
 core/storage/btree.rs | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/core/storage/btree.rs b/core/storage/btree.rs
index 78b92fa80..9fa6488a6 100644
--- a/core/storage/btree.rs
+++ b/core/storage/btree.rs
@@ -6565,6 +6565,20 @@ pub fn btree_init_page(page: &PageRef, page_type: PageType, offset: usize, usabl
     contents.write_fragmented_bytes_count(0);
     contents.write_rightmost_ptr(0);
+
+    #[cfg(debug_assertions)]
+    {
+        // We might get an already-used page from the pool. Generally this is not a problem
+        // because B-tree access is very controlled. However, for encrypted pages (and also
+        // checksums) we want to ensure that there are no reserved bytes containing old data.
+        let buffer_len = contents.buffer.len();
+        turso_assert!(
+            usable_space <= buffer_len,
+            "usable_space must be <= buffer_len"
+        );
+        // This is a no-op if usable_space == buffer_len.
+        contents.as_ptr()[usable_space..buffer_len].fill(0);
+    }
 }

 fn to_static_buf(buf: &mut [u8]) -> &'static mut [u8] {

From f0d705946ca3c7f51b8ccf5f4ff6793d09569a17 Mon Sep 17 00:00:00 2001
From: Avinash Sajjanshetty
Date: Thu, 18 Sep 2025 19:04:27 +0530
Subject: [PATCH 02/34] keep the reserved bytes check behind the debug_assertions flag

---
 core/storage/encryption.rs | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/core/storage/encryption.rs b/core/storage/encryption.rs
index fb5406b85..c43a6d660 100644
--- a/core/storage/encryption.rs
+++ b/core/storage/encryption.rs
@@ -440,11 +440,19 @@ impl EncryptionContext {
         };
         let metadata_size = self.cipher_mode.metadata_size();
         let reserved_bytes = &page[self.page_size - metadata_size..];
-        let reserved_bytes_zeroed = reserved_bytes.iter().all(|&b| b == 0);
-        assert!(
-            reserved_bytes_zeroed,
-            "last reserved bytes must be empty/zero, but found non-zero bytes"
-        );
+
+        #[cfg(debug_assertions)]
+        {
+            use crate::turso_assert;
+            // In debug builds, ensure that the reserved bytes are zeroed out. Even when we are
+            // reusing a page from the buffer pool, it is zeroed in debug builds so that we can
+            // be sure the B-tree layer is not writing any data into the reserved space.
+ let reserved_bytes_zeroed = reserved_bytes.iter().all(|&b| b == 0); + turso_assert!( + reserved_bytes_zeroed, + "last reserved bytes must be empty/zero, but found non-zero bytes" + ); + } let payload = &page[encryption_start_offset..self.page_size - metadata_size]; let (encrypted, nonce) = self.encrypt_raw(payload)?; From ffd1f87682a7704adcfa7e4c372859652842c14c Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Thu, 18 Sep 2025 18:37:43 -0400 Subject: [PATCH 03/34] Centralize most of the AST traversal by binding columns and rewriting exprs together --- core/translate/delete.rs | 8 +- core/translate/expr.rs | 274 +++++++++++++++++++++++++++++++- core/translate/insert.rs | 13 +- core/translate/optimizer/mod.rs | 140 +--------------- core/translate/planner.rs | 22 ++- core/translate/select.rs | 78 +++++---- core/translate/update.rs | 36 +++-- 7 files changed, 381 insertions(+), 190 deletions(-) diff --git a/core/translate/delete.rs b/core/translate/delete.rs index dee30b2af..c2a76f9ec 100644 --- a/core/translate/delete.rs +++ b/core/translate/delete.rs @@ -1,5 +1,6 @@ use crate::schema::Table; use crate::translate::emitter::emit_program; +use crate::translate::expr::ParamState; use crate::translate::optimizer::optimize_plan; use crate::translate::plan::{DeletePlan, Operation, Plan}; use crate::translate::planner::{parse_limit, parse_where}; @@ -108,6 +109,7 @@ pub fn prepare_delete_plan( let mut table_references = TableReferences::new(joined_tables, vec![]); let mut where_predicates = vec![]; + let mut param_ctx = ParamState::default(); // Parse the WHERE clause parse_where( @@ -116,11 +118,13 @@ pub fn prepare_delete_plan( None, &mut where_predicates, connection, + &mut param_ctx, )?; // Parse the LIMIT/OFFSET clause - let (resolved_limit, resolved_offset) = - limit.map_or(Ok((None, None)), |mut l| parse_limit(&mut l, connection))?; + let (resolved_limit, resolved_offset) = limit.map_or(Ok((None, None)), |mut l| { + parse_limit(&mut l, connection, &mut param_ctx) + })?; let plan = DeletePlan { table_references, diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 95ac93d95..61cc65faf 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -1,5 +1,7 @@ +use std::sync::Arc; + use tracing::{instrument, Level}; -use turso_parser::ast::{self, As, Expr, UnaryOperator}; +use turso_parser::ast::{self, As, Expr, TableInternalId, UnaryOperator}; use super::emitter::Resolver; use super::optimizer::Optimizable; @@ -8,8 +10,12 @@ use super::plan::TableReferences; use crate::function::JsonFunc; use crate::function::{Func, FuncCtx, MathFuncArity, ScalarFunc, VectorFunc}; use crate::functions::datetime; +use crate::parameters::PARAM_PREFIX; use crate::schema::{affinity, Affinity, Table, Type}; -use crate::util::{exprs_are_equivalent, parse_numeric_literal}; +use crate::translate::optimizer::TakeOwnership; +use crate::translate::plan::ResultSetColumn; +use crate::translate::planner::parse_row_id; +use crate::util::{exprs_are_equivalent, normalize_ident, parse_numeric_literal}; use crate::vdbe::builder::CursorKey; use crate::vdbe::{ builder::ProgramBuilder, @@ -3244,6 +3250,260 @@ where Ok(WalkControl::Continue) } +pub struct ParamState { + /// ALWAYS starts at 1 + pub next_param_idx: usize, +} + +impl Default for ParamState { + fn default() -> Self { + Self { next_param_idx: 1 } + } +} + +pub fn bind_and_rewrite_expr<'a>( + top_level_expr: &mut ast::Expr, + mut referenced_tables: Option<&'a mut TableReferences>, + result_columns: Option<&'a [ResultSetColumn]>, + connection: &'a Arc, 
+ param_state: &mut ParamState, +) -> Result { + walk_expr_mut( + top_level_expr, + &mut |expr: &mut ast::Expr| -> Result { + match expr { + // Rewrite anonymous variables in encounter order. + ast::Expr::Variable(var) if var.is_empty() => { + *expr = ast::Expr::Variable(format!( + "{}{}", + PARAM_PREFIX, param_state.next_param_idx + )); + param_state.next_param_idx += 1; + } + ast::Expr::Qualified(ast::Name::Quoted(ns), ast::Name::Quoted(c)) + | ast::Expr::DoublyQualified(_, ast::Name::Quoted(ns), ast::Name::Quoted(c)) => { + *expr = ast::Expr::Qualified( + ast::Name::Ident(normalize_ident(ns.as_str())), + ast::Name::Ident(normalize_ident(c.as_str())), + ); + } + // Expand BETWEEN to binary ops (kept identical to your logic). + ast::Expr::Between { + lhs, + not, + start, + end, + } => { + let (lower_op, upper_op) = if *not { + (ast::Operator::Greater, ast::Operator::Greater) + } else { + (ast::Operator::LessEquals, ast::Operator::LessEquals) + }; + let start = start.take_ownership(); + let lhs_v = lhs.take_ownership(); + let end = end.take_ownership(); + + let lower = + ast::Expr::Binary(Box::new(start), lower_op, Box::new(lhs_v.clone())); + let upper = ast::Expr::Binary(Box::new(lhs_v), upper_op, Box::new(end)); + + *expr = if *not { + ast::Expr::Binary(Box::new(lower), ast::Operator::Or, Box::new(upper)) + } else { + ast::Expr::Binary(Box::new(lower), ast::Operator::And, Box::new(upper)) + }; + } + _ => {} + } + + if let Some(referenced_tables) = &mut referenced_tables { + match expr { + // Unqualified identifier binding (including rowid aliases, outer refs, result-column fallback). + ast::Expr::Id(id) => { + let ident = normalize_ident(id.as_str()); + // Optional fast-path for rowid on simple FROM t + if !referenced_tables.joined_tables().is_empty() { + if let Some(row_id_expr) = parse_row_id( + &ident, + referenced_tables.joined_tables()[0].internal_id, + || referenced_tables.joined_tables().len() != 1, + )? { + *expr = row_id_expr; + return Ok(WalkControl::Continue); + } + } + + // Search joined tables + let mut match_result: Option<(TableInternalId, usize, bool)> = None; + for jt in referenced_tables.joined_tables().iter() { + if let Some(col_idx) = jt.table.columns().iter().position(|c| { + c.name + .as_ref() + .is_some_and(|n| n.eq_ignore_ascii_case(&ident)) + }) { + if match_result.is_some() { + crate::bail_parse_error!("Column {} is ambiguous", id.as_str()); + } + let col = jt.table.columns()[col_idx].clone(); + match_result = Some((jt.internal_id, col_idx, col.is_rowid_alias)); + } + } + + // If not found, search outer query references (outer scope wins for inner queries) + if match_result.is_none() { + for outer in referenced_tables.outer_query_refs().iter() { + if let Some(col_idx) = outer.table.columns().iter().position(|c| { + c.name + .as_ref() + .is_some_and(|n| n.eq_ignore_ascii_case(&ident)) + }) { + if match_result.is_some() { + crate::bail_parse_error!( + "Column {} is ambiguous", + id.as_str() + ); + } + let col = outer.table.columns()[col_idx].clone(); + match_result = + Some((outer.internal_id, col_idx, col.is_rowid_alias)); + } + } + } + + if let Some((tbl_id, col_idx, is_rowid_alias)) = match_result { + *expr = ast::Expr::Column { + database: None, + table: tbl_id, + column: col_idx, + is_rowid_alias, + }; + referenced_tables.mark_column_used(tbl_id, col_idx); + return Ok(WalkControl::Continue); + } + + // Result-column fallback (e.g. SELECT ... 
WHERE name; name is a result alias) + if let Some(rcs) = result_columns { + for rc in rcs { + if rc + .name(referenced_tables) + .is_some_and(|n| n.eq_ignore_ascii_case(&ident)) + { + *expr = rc.expr.clone(); + return Ok(WalkControl::Continue); + } + } + } + + // Double-quoted unresolved, string literal, others must resolve + if id.is_double_quoted() { + *expr = + ast::Expr::Literal(ast::Literal::String(id.as_str().to_string())); + return Ok(WalkControl::Continue); + } else { + crate::bail_parse_error!("no such column: {}", id.as_str()) + } + } + + ast::Expr::Qualified(tbl, id) => { + let tbl_name = normalize_ident(tbl.as_str()); + let Some((tbl_id, tbl_ref)) = + referenced_tables.find_table_and_internal_id_by_identifier(&tbl_name) + else { + crate::bail_parse_error!("no such table: {}", tbl_name); + }; + + let ident = normalize_ident(id.as_str()); + + if let Some(row_id_expr) = parse_row_id(&ident, tbl_id, || false)? { + *expr = row_id_expr; + return Ok(WalkControl::Continue); + } + + let Some(col_idx) = tbl_ref.columns().iter().position(|c| { + c.name + .as_ref() + .is_some_and(|n| n.eq_ignore_ascii_case(&ident)) + }) else { + crate::bail_parse_error!("no such column: {}", ident); + }; + + let col = &tbl_ref.columns()[col_idx]; + *expr = ast::Expr::Column { + database: None, + table: tbl_id, + column: col_idx, + is_rowid_alias: col.is_rowid_alias, + }; + referenced_tables.mark_column_used(tbl_id, col_idx); + return Ok(WalkControl::Continue); + } + + // db.t.x (requires table to already be present in FROM) + ast::Expr::DoublyQualified(db_name, tbl_name, col_name) => { + let qn = ast::QualifiedName { + db_name: Some(db_name.clone()), + name: ast::Name::Ident(normalize_ident(tbl_name.as_str())), + alias: None, + }; + let db_id = connection.resolve_database_id(&qn)?; + + let table = connection + .with_schema(db_id, |schema| schema.get_table(tbl_name.as_str())) + .ok_or_else(|| { + crate::LimboError::ParseError(format!( + "no such table: {}.{}", + db_name.as_str(), + tbl_name.as_str() + )) + })?; + + let ident = normalize_ident(col_name.as_str()); + let col_idx = table + .columns() + .iter() + .position(|c| { + c.name + .as_ref() + .is_some_and(|n| n.eq_ignore_ascii_case(&ident)) + }) + .ok_or_else(|| { + crate::LimboError::ParseError(format!( + "Column: {}.{}.{} not found", + db_name.as_str(), + tbl_name.as_str(), + col_name.as_str() + )) + })?; + + let is_rowid_alias = table.columns()[col_idx].is_rowid_alias; + + // Only allow if the table is already in the FROM clause + let normalized_tbl = normalize_ident(tbl_name.as_str()); + let Some((tbl_id, _)) = referenced_tables + .find_table_and_internal_id_by_identifier(&normalized_tbl) + else { + return Err(crate::LimboError::ParseError(format!( + "table {normalized_tbl} is not in FROM clause - cross-database column references require the table to be explicitly joined" + ))); + }; + + *expr = ast::Expr::Column { + database: Some(db_id), + table: tbl_id, + column: col_idx, + is_rowid_alias, + }; + referenced_tables.mark_column_used(tbl_id, col_idx); + return Ok(WalkControl::Continue); + } + _ => {} + } + } + Ok(WalkControl::Continue) + }, + ) +} + /// Recursively walks a mutable expression, applying a function to each sub-expression. 
pub fn walk_expr_mut(expr: &mut ast::Expr, func: &mut F) -> Result where @@ -3709,12 +3969,12 @@ pub fn process_returning_clause( table_name: &str, program: &mut ProgramBuilder, connection: &std::sync::Arc, + param_ctx: &mut ParamState, ) -> Result<( Vec, super::plan::TableReferences, )> { use super::plan::{ColumnUsedMask, JoinedTable, Operation, ResultSetColumn, TableReferences}; - use super::planner::bind_column_references; let mut result_columns = vec![]; @@ -3741,7 +4001,13 @@ pub fn process_returning_clause( ast::ResultColumn::Expr(expr, alias) => { let column_alias = determine_column_alias(expr, alias, table); - bind_column_references(expr, &mut table_references, None, connection)?; + bind_and_rewrite_expr( + expr, + Some(&mut table_references), + None, + connection, + param_ctx, + )?; result_columns.push(ResultSetColumn { expr: *expr.clone(), diff --git a/core/translate/insert.rs b/core/translate/insert.rs index 83c176a77..ddcf00755 100644 --- a/core/translate/insert.rs +++ b/core/translate/insert.rs @@ -10,7 +10,8 @@ use crate::translate::emitter::{ emit_cdc_insns, emit_cdc_patch_record, prepare_cdc_if_necessary, OperationMode, }; use crate::translate::expr::{ - emit_returning_results, process_returning_clause, ReturningValueRegisters, + bind_and_rewrite_expr, emit_returning_results, process_returning_clause, ParamState, + ReturningValueRegisters, }; use crate::translate::planner::ROWID; use crate::translate::upsert::{ @@ -31,7 +32,6 @@ use crate::{Result, SymbolTable, VirtualTable}; use super::emitter::Resolver; use super::expr::{translate_expr, translate_expr_no_constant_opt, NoConstantOptReason}; -use super::optimizer::rewrite_expr; use super::plan::QueryDestination; use super::select::translate_select; @@ -118,7 +118,7 @@ pub fn translate_insert( let mut values: Option>> = None; let mut upsert_opt: Option = None; - let mut param_idx = 1; + let mut param_ctx = ParamState::default(); let mut inserting_multiple_rows = false; if let InsertBody::Select(select, upsert) = &mut body { match &mut select.body.select { @@ -144,7 +144,7 @@ pub fn translate_insert( } _ => {} } - rewrite_expr(expr, &mut param_idx)?; + bind_and_rewrite_expr(expr, None, None, connection, &mut param_ctx)?; } values = values_expr.pop(); } @@ -157,10 +157,10 @@ pub fn translate_insert( } = &mut upsert.do_clause { for set in sets.iter_mut() { - rewrite_expr(set.expr.as_mut(), &mut param_idx)?; + bind_and_rewrite_expr(&mut set.expr, None, None, connection, &mut param_ctx)?; } if let Some(ref mut where_expr) = where_clause { - rewrite_expr(where_expr.as_mut(), &mut param_idx)?; + bind_and_rewrite_expr(where_expr, None, None, connection, &mut param_ctx)?; } } } @@ -180,6 +180,7 @@ pub fn translate_insert( table_name.as_str(), &mut program, connection, + &mut param_ctx, )?; let mut yield_reg_opt = None; diff --git a/core/translate/optimizer/mod.rs b/core/translate/optimizer/mod.rs index 6fc2dbe6f..b9df7c698 100644 --- a/core/translate/optimizer/mod.rs +++ b/core/translate/optimizer/mod.rs @@ -8,15 +8,13 @@ use join::{compute_best_join_order, BestJoinOrderResult}; use lift_common_subexpressions::lift_common_subexpressions_from_binary_or_terms; use order::{compute_order_target, plan_satisfies_order_target, EliminatesSortBy}; use turso_ext::{ConstraintInfo, ConstraintUsage}; -use turso_macros::match_ignore_ascii_case; use turso_parser::ast::{self, Expr, SortOrder}; use crate::{ - parameters::PARAM_PREFIX, schema::{Index, IndexColumn, Schema, Table}, translate::{ - expr::walk_expr_mut, expr::WalkControl, 
optimizer::access_method::AccessMethodParams, - optimizer::constraints::TableConstraints, plan::Scan, plan::TerminationKey, + optimizer::access_method::AccessMethodParams, optimizer::constraints::TableConstraints, + plan::Scan, plan::TerminationKey, }, types::SeekOp, LimboError, Result, @@ -64,7 +62,7 @@ pub fn optimize_plan(plan: &mut Plan, schema: &Schema) -> Result<()> { */ pub fn optimize_select_plan(plan: &mut SelectPlan, schema: &Schema) -> Result<()> { optimize_subqueries(plan, schema)?; - rewrite_exprs_select(plan)?; + lift_common_subexpressions_from_binary_or_terms(&mut plan.where_clause)?; if let ConstantConditionEliminationResult::ImpossibleCondition = eliminate_constant_conditions(&mut plan.where_clause)? { @@ -89,7 +87,7 @@ pub fn optimize_select_plan(plan: &mut SelectPlan, schema: &Schema) -> Result<() } fn optimize_delete_plan(plan: &mut DeletePlan, schema: &Schema) -> Result<()> { - rewrite_exprs_delete(plan)?; + lift_common_subexpressions_from_binary_or_terms(&mut plan.where_clause)?; if let ConstantConditionEliminationResult::ImpossibleCondition = eliminate_constant_conditions(&mut plan.where_clause)? { @@ -110,7 +108,7 @@ fn optimize_delete_plan(plan: &mut DeletePlan, schema: &Schema) -> Result<()> { } fn optimize_update_plan(plan: &mut UpdatePlan, schema: &Schema) -> Result<()> { - rewrite_exprs_update(plan)?; + lift_common_subexpressions_from_binary_or_terms(&mut plan.where_clause)?; if let ConstantConditionEliminationResult::ImpossibleCondition = eliminate_constant_conditions(&mut plan.where_clause)? { @@ -558,62 +556,6 @@ fn eliminate_constant_conditions( Ok(ConstantConditionEliminationResult::Continue) } -fn rewrite_exprs_select(plan: &mut SelectPlan) -> Result<()> { - let mut param_count = 1; - for rc in plan.result_columns.iter_mut() { - rewrite_expr(&mut rc.expr, &mut param_count)?; - } - for agg in plan.aggregates.iter_mut() { - rewrite_expr(&mut agg.original_expr, &mut param_count)?; - } - lift_common_subexpressions_from_binary_or_terms(&mut plan.where_clause)?; - for cond in plan.where_clause.iter_mut() { - rewrite_expr(&mut cond.expr, &mut param_count)?; - } - if let Some(group_by) = &mut plan.group_by { - for expr in group_by.exprs.iter_mut() { - rewrite_expr(expr, &mut param_count)?; - } - } - for (expr, _) in plan.order_by.iter_mut() { - rewrite_expr(expr, &mut param_count)?; - } - if let Some(window) = &mut plan.window { - for func in window.functions.iter_mut() { - rewrite_expr(&mut func.original_expr, &mut param_count)?; - } - } - - Ok(()) -} - -fn rewrite_exprs_delete(plan: &mut DeletePlan) -> Result<()> { - let mut param_idx = 1; - for cond in plan.where_clause.iter_mut() { - rewrite_expr(&mut cond.expr, &mut param_idx)?; - } - Ok(()) -} - -fn rewrite_exprs_update(plan: &mut UpdatePlan) -> Result<()> { - let mut param_idx = 1; - for (_, expr) in plan.set_clauses.iter_mut() { - rewrite_expr(expr, &mut param_idx)?; - } - for cond in plan.where_clause.iter_mut() { - rewrite_expr(&mut cond.expr, &mut param_idx)?; - } - for (expr, _) in plan.order_by.iter_mut() { - rewrite_expr(expr, &mut param_idx)?; - } - if let Some(rc) = plan.returning.as_mut() { - for rc in rc.iter_mut() { - rewrite_expr(&mut rc.expr, &mut param_idx)?; - } - } - Ok(()) -} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum AlwaysTrueOrFalse { AlwaysTrue, @@ -1449,77 +1391,7 @@ fn build_seek_def( }) } -pub fn rewrite_expr(top_level_expr: &mut ast::Expr, param_idx: &mut usize) -> Result { - walk_expr_mut( - top_level_expr, - &mut |expr: &mut ast::Expr| -> Result { - match expr { - 
ast::Expr::Id(id) => { - // Convert "true" and "false" to 1 and 0 - let id_bytes = id.as_str().as_bytes(); - match_ignore_ascii_case!(match id_bytes { - b"true" => { - *expr = ast::Expr::Literal(ast::Literal::Numeric("1".to_owned())); - } - b"false" => { - *expr = ast::Expr::Literal(ast::Literal::Numeric("0".to_owned())); - } - _ => {} - }) - } - ast::Expr::Variable(var) => { - if var.is_empty() { - // rewrite anonymous variables only, ensure that the `param_idx` starts at 1 and - // all the expressions are rewritten in the order they come in the statement - *expr = ast::Expr::Variable(format!("{PARAM_PREFIX}{param_idx}")); - *param_idx += 1; - } - } - ast::Expr::Between { - lhs, - not, - start, - end, - } => { - // Convert `y NOT BETWEEN x AND z` to `x > y OR y > z` - let (lower_op, upper_op) = if *not { - (ast::Operator::Greater, ast::Operator::Greater) - } else { - // Convert `y BETWEEN x AND z` to `x <= y AND y <= z` - (ast::Operator::LessEquals, ast::Operator::LessEquals) - }; - - let start = start.take_ownership(); - let lhs = lhs.take_ownership(); - let end = end.take_ownership(); - - let lower_bound = - ast::Expr::Binary(Box::new(start), lower_op, Box::new(lhs.clone())); - let upper_bound = ast::Expr::Binary(Box::new(lhs), upper_op, Box::new(end)); - - if *not { - *expr = ast::Expr::Binary( - Box::new(lower_bound), - ast::Operator::Or, - Box::new(upper_bound), - ); - } else { - *expr = ast::Expr::Binary( - Box::new(lower_bound), - ast::Operator::And, - Box::new(upper_bound), - ); - } - } - _ => {} - } - - Ok(WalkControl::Continue) - }, - ) -} - -trait TakeOwnership { +pub trait TakeOwnership { fn take_ownership(&mut self) -> Self; } diff --git a/core/translate/planner.rs b/core/translate/planner.rs index 14f422860..3a94af2b8 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -11,7 +11,6 @@ use super::{ select::prepare_select_plan, SymbolTable, }; -use crate::function::{AggFunc, ExtFunc}; use crate::translate::expr::WalkControl; use crate::translate::plan::{Window, WindowFunction}; use crate::{ @@ -23,6 +22,10 @@ use crate::{ vdbe::builder::TableRefIdCounter, Result, }; +use crate::{ + function::{AggFunc, ExtFunc}, + translate::expr::{bind_and_rewrite_expr, ParamState}, +}; use turso_macros::match_ignore_ascii_case; use turso_parser::ast::Literal::Null; use turso_parser::ast::{ @@ -886,12 +889,19 @@ pub fn parse_where( result_columns: Option<&[ResultSetColumn]>, out_where_clause: &mut Vec, connection: &Arc, + param_ctx: &mut ParamState, ) -> Result<()> { if let Some(where_expr) = where_clause { let start_idx = out_where_clause.len(); break_predicate_at_and_boundaries(where_expr, out_where_clause); for expr in out_where_clause[start_idx..].iter_mut() { - bind_column_references(&mut expr.expr, table_references, result_columns, connection)?; + bind_and_rewrite_expr( + &mut expr.expr, + Some(table_references), + result_columns, + connection, + param_ctx, + )?; } Ok(()) } else { @@ -1290,7 +1300,7 @@ pub fn break_predicate_at_and_boundaries>( } } -fn parse_row_id( +pub fn parse_row_id( column_name: &str, table_id: TableInternalId, fn_check: F, @@ -1315,11 +1325,11 @@ where pub fn parse_limit( limit: &mut Limit, connection: &std::sync::Arc, + param_ctx: &mut ParamState, ) -> Result<(Option>, Option>)> { - let mut empty_refs = TableReferences::new(Vec::new(), Vec::new()); - bind_column_references(&mut limit.expr, &mut empty_refs, None, connection)?; + bind_and_rewrite_expr(&mut limit.expr, None, None, connection, param_ctx)?; if let Some(ref mut off_expr) = 
limit.offset { - bind_column_references(off_expr, &mut empty_refs, None, connection)?; + bind_and_rewrite_expr(off_expr, None, None, connection, param_ctx)?; } Ok((Some(limit.expr.clone()), limit.offset.clone())) } diff --git a/core/translate/select.rs b/core/translate/select.rs index e13eed952..20ee62659 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -4,11 +4,12 @@ use super::plan::{ Search, TableReferences, WhereTerm, Window, }; use crate::schema::Table; +use crate::translate::expr::{bind_and_rewrite_expr, ParamState}; use crate::translate::optimizer::optimize_plan; use crate::translate::plan::{GroupBy, Plan, ResultSetColumn, SelectPlan}; use crate::translate::planner::{ - bind_column_references, break_predicate_at_and_boundaries, parse_from, parse_limit, - parse_where, resolve_window_and_aggregate_functions, + break_predicate_at_and_boundaries, parse_from, parse_limit, parse_where, + resolve_window_and_aggregate_functions, }; use crate::translate::window::plan_windows; use crate::util::normalize_ident; @@ -98,6 +99,7 @@ pub fn prepare_select_plan( connection: &Arc, ) -> Result { let compounds = select.body.compounds; + let mut param_ctx = ParamState::default(); match compounds.is_empty() { true => Ok(Plan::Select(prepare_one_select_plan( schema, @@ -110,6 +112,7 @@ pub fn prepare_select_plan( table_ref_counter, query_destination, connection, + &mut param_ctx, )?)), false => { let mut last = prepare_one_select_plan( @@ -123,6 +126,7 @@ pub fn prepare_select_plan( table_ref_counter, query_destination.clone(), connection, + &mut param_ctx, )?; let mut left = Vec::with_capacity(compounds.len()); @@ -139,6 +143,7 @@ pub fn prepare_select_plan( table_ref_counter, query_destination.clone(), connection, + &mut param_ctx, )?; } @@ -149,9 +154,9 @@ pub fn prepare_select_plan( crate::bail_parse_error!("SELECTs to the left and right of {} do not have the same number of result columns", operator); } } - let (limit, offset) = select - .limit - .map_or(Ok((None, None)), |mut l| parse_limit(&mut l, connection))?; + let (limit, offset) = select.limit.map_or(Ok((None, None)), |mut l| { + parse_limit(&mut l, connection, &mut param_ctx) + })?; // FIXME: handle ORDER BY for compound selects if !select.order_by.is_empty() { @@ -184,6 +189,7 @@ fn prepare_one_select_plan( table_ref_counter: &mut TableRefIdCounter, query_destination: QueryDestination, connection: &Arc, + param_ctx: &mut ParamState, ) -> Result { match select { ast::OneSelect::Select { @@ -255,7 +261,6 @@ fn prepare_one_select_plan( }) .sum(), ); - let mut plan = SelectPlan { join_order: table_references .joined_tables() @@ -288,19 +293,21 @@ fn prepare_one_select_plan( let mut window = Window::new(Some(name), &window_def.window)?; for expr in window.partition_by.iter_mut() { - bind_column_references( + bind_and_rewrite_expr( expr, - &mut plan.table_references, - Some(&plan.result_columns), + Some(&mut plan.table_references), + None, connection, + param_ctx, )?; } for (expr, _) in window.order_by.iter_mut() { - bind_column_references( + bind_and_rewrite_expr( expr, - &mut plan.table_references, - Some(&plan.result_columns), + Some(&mut plan.table_references), + None, connection, + param_ctx, )?; } @@ -357,11 +364,12 @@ fn prepare_one_select_plan( } } ResultColumn::Expr(ref mut expr, maybe_alias) => { - bind_column_references( + bind_and_rewrite_expr( expr, - &mut plan.table_references, - Some(&plan.result_columns), + Some(&mut plan.table_references), + None, connection, + param_ctx, )?; let contains_aggregates = 
resolve_window_and_aggregate_functions( schema, @@ -385,7 +393,12 @@ fn prepare_one_select_plan( // This step can only be performed at this point, because all table references are now available. // Virtual table predicates may depend on column bindings from tables to the right in the join order, // so we must wait until the full set of references has been collected. - add_vtab_predicates_to_where_clause(&mut vtab_predicates, &mut plan, connection)?; + add_vtab_predicates_to_where_clause( + &mut vtab_predicates, + &mut plan, + connection, + param_ctx, + )?; // Parse the actual WHERE clause and add its conditions to the plan WHERE clause that already contains the join conditions. parse_where( @@ -394,16 +407,18 @@ fn prepare_one_select_plan( Some(&plan.result_columns), &mut plan.where_clause, connection, + param_ctx, )?; if let Some(mut group_by) = group_by { for expr in group_by.exprs.iter_mut() { replace_column_number_with_copy_of_column_expr(expr, &plan.result_columns)?; - bind_column_references( + bind_and_rewrite_expr( expr, - &mut plan.table_references, - Some(&plan.result_columns), + Some(&mut plan.table_references), + Some(&mut plan.result_columns), connection, + param_ctx, )?; } @@ -414,11 +429,12 @@ fn prepare_one_select_plan( let mut predicates = vec![]; break_predicate_at_and_boundaries(&having, &mut predicates); for expr in predicates.iter_mut() { - bind_column_references( + bind_and_rewrite_expr( expr, - &mut plan.table_references, - Some(&plan.result_columns), + Some(&mut plan.table_references), + Some(&mut plan.result_columns), connection, + param_ctx, )?; let contains_aggregates = resolve_window_and_aggregate_functions( schema, @@ -452,11 +468,12 @@ fn prepare_one_select_plan( for mut o in order_by { replace_column_number_with_copy_of_column_expr(&mut o.expr, &plan.result_columns)?; - bind_column_references( + bind_and_rewrite_expr( &mut o.expr, - &mut plan.table_references, - Some(&plan.result_columns), + Some(&mut plan.table_references), + Some(&mut plan.result_columns), connection, + param_ctx, )?; resolve_window_and_aggregate_functions( schema, @@ -471,8 +488,9 @@ fn prepare_one_select_plan( plan.order_by = key; // Parse the LIMIT/OFFSET clause - (plan.limit, plan.offset) = - limit.map_or(Ok((None, None)), |mut l| parse_limit(&mut l, connection))?; + (plan.limit, plan.offset) = limit.map_or(Ok((None, None)), |mut l| { + parse_limit(&mut l, connection, param_ctx) + })?; if !windows.is_empty() { plan_windows(schema, syms, &mut plan, table_ref_counter, &mut windows)?; @@ -521,13 +539,15 @@ fn add_vtab_predicates_to_where_clause( vtab_predicates: &mut Vec, plan: &mut SelectPlan, connection: &Arc, + param_ctx: &mut ParamState, ) -> Result<()> { for expr in vtab_predicates.iter_mut() { - bind_column_references( + bind_and_rewrite_expr( expr, - &mut plan.table_references, + Some(&mut plan.table_references), Some(&plan.result_columns), connection, + param_ctx, )?; } for expr in vtab_predicates.drain(..) 
{ diff --git a/core/translate/update.rs b/core/translate/update.rs index 6ca366049..feb3d926d 100644 --- a/core/translate/update.rs +++ b/core/translate/update.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use std::sync::Arc; use crate::schema::{BTreeTable, Column, Type}; +use crate::translate::expr::{bind_and_rewrite_expr, ParamState}; use crate::translate::optimizer::optimize_select_plan; use crate::translate::plan::{Operation, QueryDestination, Scan, Search, SelectPlan}; use crate::translate::planner::parse_limit; @@ -22,7 +23,7 @@ use super::plan::{ ColumnUsedMask, IterationDirection, JoinedTable, Plan, ResultSetColumn, TableReferences, UpdatePlan, }; -use super::planner::{bind_column_references, parse_where}; +use super::planner::parse_where; /* * Update is simple. By default we scan the table, and for each row, we check the WHERE * clause. If it evaluates to true, we build the new record with the updated value and insert. @@ -90,7 +91,6 @@ pub fn translate_update_for_schema_change( } optimize_plan(&mut plan, schema)?; - // TODO: freestyling these numbers let opts = ProgramBuilderOpts { num_cursors: 1, approx_num_insns: 20, @@ -181,11 +181,18 @@ pub fn prepare_update_plan( .collect(); let mut set_clauses = Vec::with_capacity(body.sets.len()); + let mut param_idx = ParamState::default(); // Process each SET assignment and map column names to expressions // e.g the statement `SET x = 1, y = 2, z = 3` has 3 set assigments for set in &mut body.sets { - bind_column_references(&mut set.expr, &mut table_references, None, connection)?; + bind_and_rewrite_expr( + &mut set.expr, + Some(&mut table_references), + None, + connection, + &mut param_idx, + )?; let values = match set.expr.as_ref() { Expr::Parenthesized(vals) => vals.clone(), @@ -222,12 +229,22 @@ pub fn prepare_update_plan( body.tbl_name.name.as_str(), program, connection, + &mut param_idx, )?; let order_by = body .order_by - .iter() - .map(|o| (o.expr.clone(), o.order.unwrap_or(SortOrder::Asc))) + .iter_mut() + .map(|o| { + let _ = bind_and_rewrite_expr( + &mut o.expr, + Some(&mut table_references), + Some(&result_columns), + connection, + &mut param_idx, + ); + (o.expr.clone(), o.order.unwrap_or(SortOrder::Asc)) + }) .collect(); // Sqlite determines we should create an ephemeral table if we do not have a FROM clause @@ -266,6 +283,7 @@ pub fn prepare_update_plan( Some(&result_columns), &mut where_clause, connection, + &mut param_idx, )?; let table = Arc::new(BTreeTable { @@ -342,14 +360,14 @@ pub fn prepare_update_plan( Some(&result_columns), &mut where_clause, connection, + &mut param_idx, )?; }; // Parse the LIMIT/OFFSET clause - let (limit, offset) = body - .limit - .as_mut() - .map_or(Ok((None, None)), |l| parse_limit(l, connection))?; + let (limit, offset) = body.limit.as_mut().map_or(Ok((None, None)), |l| { + parse_limit(l, connection, &mut param_idx) + })?; // Check what indexes will need to be updated by checking set_clauses and see // if a column is contained in an index. 
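This patch folds the old two-pass approach (rewrite_expr for parameters and BETWEEN, plus bind_column_references for column resolution) into the single bind_and_rewrite_expr walk, threading a ParamState so that anonymous `?` placeholders are numbered in the order they appear in the statement. The standalone sketch below illustrates only that numbering idea: Expr, ParamState and the walker here are simplified stand-ins, not the crate's real turso_parser::ast::Expr or translate::expr types, and the '?' prefix is an assumed placeholder for whatever PARAM_PREFIX actually is.

// Simplified stand-ins for the AST and parameter state; the real types are richer.
#[derive(Debug)]
enum Expr {
    Variable(String),              // "" models an anonymous `?` placeholder
    Id(String),                    // an unresolved identifier
    Binary(Box<Expr>, Box<Expr>),  // any compound expression
}

struct ParamState {
    next_param_idx: usize, // always starts at 1, mirroring ParamState::default()
}

// Assumed prefix; the crate's parameters::PARAM_PREFIX may differ.
const PARAM_PREFIX: char = '?';

// Walk the tree depth-first and rename every anonymous variable to an
// internal named parameter in the order it is encountered.
fn bind_and_rewrite(expr: &mut Expr, state: &mut ParamState) {
    match expr {
        Expr::Variable(name) if name.is_empty() => {
            *name = format!("{}{}", PARAM_PREFIX, state.next_param_idx);
            state.next_param_idx += 1;
        }
        Expr::Binary(lhs, rhs) => {
            bind_and_rewrite(lhs, state);
            bind_and_rewrite(rhs, state);
        }
        _ => {}
    }
}

fn main() {
    // Roughly `? AND (x = ?)`: both placeholders get stable, ordered names.
    let mut e = Expr::Binary(
        Box::new(Expr::Variable(String::new())),
        Box::new(Expr::Binary(
            Box::new(Expr::Id("x".into())),
            Box::new(Expr::Variable(String::new())),
        )),
    );
    let mut state = ParamState { next_param_idx: 1 };
    bind_and_rewrite(&mut e, &mut state);
    println!("{e:?}"); // first placeholder becomes "?1", second becomes "?2"
}

Numbering in encounter order means that even if later planning stages reorder predicates or columns, each rewritten name still maps back to the same user-supplied bind value.
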
From 38096ffc9ed417e79aef170b1ab0e6f5df515ab4 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Thu, 18 Sep 2025 18:44:35 -0400 Subject: [PATCH 04/34] Rewrite true/false to 0/1 even tho its also done in the parser now --- core/translate/expr.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 61cc65faf..36770b46a 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -3272,6 +3272,12 @@ pub fn bind_and_rewrite_expr<'a>( top_level_expr, &mut |expr: &mut ast::Expr| -> Result { match expr { + ast::Expr::Id(ast::Name::Ident(n)) if n.eq_ignore_ascii_case("true") => { + *expr = ast::Expr::Literal(ast::Literal::Numeric("1".to_string())); + } + ast::Expr::Id(ast::Name::Ident(n)) if n.eq_ignore_ascii_case("false") => { + *expr = ast::Expr::Literal(ast::Literal::Numeric("0".to_string())); + } // Rewrite anonymous variables in encounter order. ast::Expr::Variable(var) if var.is_empty() => { *expr = ast::Expr::Variable(format!( From 6f446aaf4824ca5baea25bceadc997a481f2f262 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Thu, 18 Sep 2025 18:59:28 -0400 Subject: [PATCH 05/34] remove bind_column_references method and its last usages --- core/translate/expr.rs | 196 +++++++++++++++++-------------- core/translate/planner.rs | 238 ++------------------------------------ core/translate/select.rs | 1 + 3 files changed, 118 insertions(+), 317 deletions(-) diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 36770b46a..515d0a263 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -3321,121 +3321,137 @@ pub fn bind_and_rewrite_expr<'a>( } _ => {} } - if let Some(referenced_tables) = &mut referenced_tables { match expr { // Unqualified identifier binding (including rowid aliases, outer refs, result-column fallback). - ast::Expr::Id(id) => { - let ident = normalize_ident(id.as_str()); - // Optional fast-path for rowid on simple FROM t + Expr::Id(id) => { + let normalized_id = normalize_ident(id.as_str()); if !referenced_tables.joined_tables().is_empty() { if let Some(row_id_expr) = parse_row_id( - &ident, + &normalized_id, referenced_tables.joined_tables()[0].internal_id, || referenced_tables.joined_tables().len() != 1, )? { *expr = row_id_expr; + return Ok(WalkControl::Continue); } } + let mut match_result = None; - // Search joined tables - let mut match_result: Option<(TableInternalId, usize, bool)> = None; - for jt in referenced_tables.joined_tables().iter() { - if let Some(col_idx) = jt.table.columns().iter().position(|c| { + // First check joined tables + for joined_table in referenced_tables.joined_tables().iter() { + let col_idx = joined_table.table.columns().iter().position(|c| { c.name .as_ref() - .is_some_and(|n| n.eq_ignore_ascii_case(&ident)) - }) { + .is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id)) + }); + if col_idx.is_some() { if match_result.is_some() { crate::bail_parse_error!("Column {} is ambiguous", id.as_str()); } - let col = jt.table.columns()[col_idx].clone(); - match_result = Some((jt.internal_id, col_idx, col.is_rowid_alias)); + let col = + joined_table.table.columns().get(col_idx.unwrap()).unwrap(); + match_result = Some(( + joined_table.internal_id, + col_idx.unwrap(), + col.is_rowid_alias, + )); } } - // If not found, search outer query references (outer scope wins for inner queries) + // Then check outer query references, if we still didn't find something. 
+ // Normally finding multiple matches for a non-qualified column is an error (column x is ambiguous) + // but in the case of subqueries, the inner query takes precedence. + // For example: + // SELECT * FROM t WHERE x = (SELECT x FROM t2) + // In this case, there is no ambiguity: + // - x in the outer query refers to t.x, + // - x in the inner query refers to t2.x. if match_result.is_none() { - for outer in referenced_tables.outer_query_refs().iter() { - if let Some(col_idx) = outer.table.columns().iter().position(|c| { - c.name - .as_ref() - .is_some_and(|n| n.eq_ignore_ascii_case(&ident)) - }) { + for outer_ref in referenced_tables.outer_query_refs().iter() { + let col_idx = outer_ref.table.columns().iter().position(|c| { + c.name.as_ref().is_some_and(|name| { + name.eq_ignore_ascii_case(&normalized_id) + }) + }); + if col_idx.is_some() { if match_result.is_some() { crate::bail_parse_error!( "Column {} is ambiguous", id.as_str() ); } - let col = outer.table.columns()[col_idx].clone(); - match_result = - Some((outer.internal_id, col_idx, col.is_rowid_alias)); + let col = + outer_ref.table.columns().get(col_idx.unwrap()).unwrap(); + match_result = Some(( + outer_ref.internal_id, + col_idx.unwrap(), + col.is_rowid_alias, + )); } } } - if let Some((tbl_id, col_idx, is_rowid_alias)) = match_result { - *expr = ast::Expr::Column { - database: None, - table: tbl_id, + if let Some((table_id, col_idx, is_rowid_alias)) = match_result { + *expr = Expr::Column { + database: None, // TODO: support different databases + table: table_id, column: col_idx, is_rowid_alias, }; - referenced_tables.mark_column_used(tbl_id, col_idx); + referenced_tables.mark_column_used(table_id, col_idx); return Ok(WalkControl::Continue); } - // Result-column fallback (e.g. SELECT ... 
WHERE name; name is a result alias) - if let Some(rcs) = result_columns { - for rc in rcs { - if rc + if let Some(result_columns) = result_columns { + for result_column in result_columns.iter() { + if result_column .name(referenced_tables) - .is_some_and(|n| n.eq_ignore_ascii_case(&ident)) + .is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id)) { - *expr = rc.expr.clone(); + *expr = result_column.expr.clone(); return Ok(WalkControl::Continue); } } } - - // Double-quoted unresolved, string literal, others must resolve + // SQLite behavior: Only double-quoted identifiers get fallback to string literals + // Single quotes are handled as literals earlier, unquoted identifiers must resolve to columns if id.is_double_quoted() { - *expr = - ast::Expr::Literal(ast::Literal::String(id.as_str().to_string())); + // Convert failed double-quoted identifier to string literal + *expr = Expr::Literal(ast::Literal::String(id.as_str().to_string())); return Ok(WalkControl::Continue); } else { + // Unquoted identifiers must resolve to columns - no fallback crate::bail_parse_error!("no such column: {}", id.as_str()) } } + Expr::Qualified(tbl, id) => { + let normalized_table_name = normalize_ident(tbl.as_str()); + let matching_tbl = referenced_tables + .find_table_and_internal_id_by_identifier(&normalized_table_name); + if matching_tbl.is_none() { + crate::bail_parse_error!("no such table: {}", normalized_table_name); + } + let (tbl_id, tbl) = matching_tbl.unwrap(); + let normalized_id = normalize_ident(id.as_str()); - ast::Expr::Qualified(tbl, id) => { - let tbl_name = normalize_ident(tbl.as_str()); - let Some((tbl_id, tbl_ref)) = - referenced_tables.find_table_and_internal_id_by_identifier(&tbl_name) - else { - crate::bail_parse_error!("no such table: {}", tbl_name); - }; - - let ident = normalize_ident(id.as_str()); - - if let Some(row_id_expr) = parse_row_id(&ident, tbl_id, || false)? { + if let Some(row_id_expr) = parse_row_id(&normalized_id, tbl_id, || false)? 
{ *expr = row_id_expr; + return Ok(WalkControl::Continue); } - - let Some(col_idx) = tbl_ref.columns().iter().position(|c| { + let col_idx = tbl.columns().iter().position(|c| { c.name .as_ref() - .is_some_and(|n| n.eq_ignore_ascii_case(&ident)) - }) else { - crate::bail_parse_error!("no such column: {}", ident); + .is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id)) + }); + let Some(col_idx) = col_idx else { + crate::bail_parse_error!("no such column: {}", normalized_id); }; - - let col = &tbl_ref.columns()[col_idx]; - *expr = ast::Expr::Column { - database: None, + let col = tbl.columns().get(col_idx).unwrap(); + *expr = Expr::Column { + database: None, // TODO: support different databases table: tbl_id, column: col_idx, is_rowid_alias: col.is_rowid_alias, @@ -3443,18 +3459,20 @@ pub fn bind_and_rewrite_expr<'a>( referenced_tables.mark_column_used(tbl_id, col_idx); return Ok(WalkControl::Continue); } + Expr::DoublyQualified(db_name, tbl_name, col_name) => { + let normalized_col_name = normalize_ident(col_name.as_str()); - // db.t.x (requires table to already be present in FROM) - ast::Expr::DoublyQualified(db_name, tbl_name, col_name) => { - let qn = ast::QualifiedName { + // Create a QualifiedName and use existing resolve_database_id method + let qualified_name = ast::QualifiedName { db_name: Some(db_name.clone()), - name: ast::Name::Ident(normalize_ident(tbl_name.as_str())), + name: tbl_name.clone(), alias: None, }; - let db_id = connection.resolve_database_id(&qn)?; + let database_id = connection.resolve_database_id(&qualified_name)?; + // Get the table from the specified database let table = connection - .with_schema(db_id, |schema| schema.get_table(tbl_name.as_str())) + .with_schema(database_id, |schema| schema.get_table(tbl_name.as_str())) .ok_or_else(|| { crate::LimboError::ParseError(format!( "no such table: {}.{}", @@ -3463,14 +3481,14 @@ pub fn bind_and_rewrite_expr<'a>( )) })?; - let ident = normalize_ident(col_name.as_str()); + // Find the column in the table let col_idx = table .columns() .iter() .position(|c| { - c.name - .as_ref() - .is_some_and(|n| n.eq_ignore_ascii_case(&ident)) + c.name.as_ref().is_some_and(|name| { + name.eq_ignore_ascii_case(&normalized_col_name) + }) }) .ok_or_else(|| { crate::LimboError::ParseError(format!( @@ -3481,26 +3499,32 @@ pub fn bind_and_rewrite_expr<'a>( )) })?; - let is_rowid_alias = table.columns()[col_idx].is_rowid_alias; + let col = table.columns().get(col_idx).unwrap(); - // Only allow if the table is already in the FROM clause - let normalized_tbl = normalize_ident(tbl_name.as_str()); - let Some((tbl_id, _)) = referenced_tables - .find_table_and_internal_id_by_identifier(&normalized_tbl) - else { + // Check if this is a rowid alias + let is_rowid_alias = col.is_rowid_alias; + + // Convert to Column expression - since this is a cross-database reference, + // we need to create a synthetic table reference for it + // For now, we'll error if the table isn't already in the referenced tables + let normalized_tbl_name = normalize_ident(tbl_name.as_str()); + let matching_tbl = referenced_tables + .find_table_and_internal_id_by_identifier(&normalized_tbl_name); + + if let Some((tbl_id, _)) = matching_tbl { + // Table is already in referenced tables, use existing internal ID + *expr = Expr::Column { + database: Some(database_id), + table: tbl_id, + column: col_idx, + is_rowid_alias, + }; + referenced_tables.mark_column_used(tbl_id, col_idx); + } else { return Err(crate::LimboError::ParseError(format!( - "table {normalized_tbl} is not 
in FROM clause - cross-database column references require the table to be explicitly joined" - ))); - }; - - *expr = ast::Expr::Column { - database: Some(db_id), - table: tbl_id, - column: col_idx, - is_rowid_alias, - }; - referenced_tables.mark_column_used(tbl_id, col_idx); - return Ok(WalkControl::Continue); + "table {normalized_tbl_name} is not in FROM clause - cross-database column references require the table to be explicitly joined" + ))); + } } _ => {} } diff --git a/core/translate/planner.rs b/core/translate/planner.rs index 3a94af2b8..4b5270dc6 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -17,7 +17,6 @@ use crate::{ ast::Limit, function::Func, schema::{Schema, Table}, - translate::expr::walk_expr_mut, util::{exprs_are_equivalent, normalize_ident}, vdbe::builder::TableRefIdCounter, Result, @@ -26,11 +25,9 @@ use crate::{ function::{AggFunc, ExtFunc}, translate::expr::{bind_and_rewrite_expr, ParamState}, }; -use turso_macros::match_ignore_ascii_case; use turso_parser::ast::Literal::Null; use turso_parser::ast::{ - self, As, Expr, FromClause, JoinType, Literal, Materialized, Over, QualifiedName, - TableInternalId, With, + self, As, Expr, FromClause, JoinType, Materialized, Over, QualifiedName, TableInternalId, With, }; pub const ROWID: &str = "rowid"; @@ -265,231 +262,6 @@ fn add_aggregate_if_not_exists( Ok(()) } -pub fn bind_column_references( - top_level_expr: &mut Expr, - referenced_tables: &mut TableReferences, - result_columns: Option<&[ResultSetColumn]>, - connection: &Arc, -) -> Result { - walk_expr_mut( - top_level_expr, - &mut |expr: &mut Expr| -> Result { - match expr { - Expr::Id(id) => { - // true and false are special constants that are effectively aliases for 1 and 0 - // and not identifiers of columns - let id_bytes = id.as_str().as_bytes(); - match_ignore_ascii_case!(match id_bytes { - b"true" | b"false" => { - return Ok(WalkControl::Continue); - } - _ => {} - }); - let normalized_id = normalize_ident(id.as_str()); - - if !referenced_tables.joined_tables().is_empty() { - if let Some(row_id_expr) = parse_row_id( - &normalized_id, - referenced_tables.joined_tables()[0].internal_id, - || referenced_tables.joined_tables().len() != 1, - )? { - *expr = row_id_expr; - - return Ok(WalkControl::Continue); - } - } - let mut match_result = None; - - // First check joined tables - for joined_table in referenced_tables.joined_tables().iter() { - let col_idx = joined_table.table.columns().iter().position(|c| { - c.name - .as_ref() - .is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id)) - }); - if col_idx.is_some() { - if match_result.is_some() { - crate::bail_parse_error!("Column {} is ambiguous", id.as_str()); - } - let col = joined_table.table.columns().get(col_idx.unwrap()).unwrap(); - match_result = Some(( - joined_table.internal_id, - col_idx.unwrap(), - col.is_rowid_alias, - )); - } - } - - // Then check outer query references, if we still didn't find something. - // Normally finding multiple matches for a non-qualified column is an error (column x is ambiguous) - // but in the case of subqueries, the inner query takes precedence. - // For example: - // SELECT * FROM t WHERE x = (SELECT x FROM t2) - // In this case, there is no ambiguity: - // - x in the outer query refers to t.x, - // - x in the inner query refers to t2.x. 
- if match_result.is_none() { - for outer_ref in referenced_tables.outer_query_refs().iter() { - let col_idx = outer_ref.table.columns().iter().position(|c| { - c.name - .as_ref() - .is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id)) - }); - if col_idx.is_some() { - if match_result.is_some() { - crate::bail_parse_error!("Column {} is ambiguous", id.as_str()); - } - let col = outer_ref.table.columns().get(col_idx.unwrap()).unwrap(); - match_result = Some(( - outer_ref.internal_id, - col_idx.unwrap(), - col.is_rowid_alias, - )); - } - } - } - - if let Some((table_id, col_idx, is_rowid_alias)) = match_result { - *expr = Expr::Column { - database: None, // TODO: support different databases - table: table_id, - column: col_idx, - is_rowid_alias, - }; - referenced_tables.mark_column_used(table_id, col_idx); - return Ok(WalkControl::Continue); - } - - if let Some(result_columns) = result_columns { - for result_column in result_columns.iter() { - if result_column - .name(referenced_tables) - .is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id)) - { - *expr = result_column.expr.clone(); - return Ok(WalkControl::Continue); - } - } - } - // SQLite behavior: Only double-quoted identifiers get fallback to string literals - // Single quotes are handled as literals earlier, unquoted identifiers must resolve to columns - if id.is_double_quoted() { - // Convert failed double-quoted identifier to string literal - *expr = Expr::Literal(Literal::String(id.as_str().to_string())); - Ok(WalkControl::Continue) - } else { - // Unquoted identifiers must resolve to columns - no fallback - crate::bail_parse_error!("no such column: {}", id.as_str()) - } - } - Expr::Qualified(tbl, id) => { - let normalized_table_name = normalize_ident(tbl.as_str()); - let matching_tbl = referenced_tables - .find_table_and_internal_id_by_identifier(&normalized_table_name); - if matching_tbl.is_none() { - crate::bail_parse_error!("no such table: {}", normalized_table_name); - } - let (tbl_id, tbl) = matching_tbl.unwrap(); - let normalized_id = normalize_ident(id.as_str()); - - if let Some(row_id_expr) = parse_row_id(&normalized_id, tbl_id, || false)? 
{ - *expr = row_id_expr; - - return Ok(WalkControl::Continue); - } - let col_idx = tbl.columns().iter().position(|c| { - c.name - .as_ref() - .is_some_and(|name| name.eq_ignore_ascii_case(&normalized_id)) - }); - let Some(col_idx) = col_idx else { - crate::bail_parse_error!("no such column: {}", normalized_id); - }; - let col = tbl.columns().get(col_idx).unwrap(); - *expr = Expr::Column { - database: None, // TODO: support different databases - table: tbl_id, - column: col_idx, - is_rowid_alias: col.is_rowid_alias, - }; - referenced_tables.mark_column_used(tbl_id, col_idx); - Ok(WalkControl::Continue) - } - Expr::DoublyQualified(db_name, tbl_name, col_name) => { - let normalized_col_name = normalize_ident(col_name.as_str()); - - // Create a QualifiedName and use existing resolve_database_id method - let qualified_name = ast::QualifiedName { - db_name: Some(db_name.clone()), - name: tbl_name.clone(), - alias: None, - }; - let database_id = connection.resolve_database_id(&qualified_name)?; - - // Get the table from the specified database - let table = connection - .with_schema(database_id, |schema| schema.get_table(tbl_name.as_str())) - .ok_or_else(|| { - crate::LimboError::ParseError(format!( - "no such table: {}.{}", - db_name.as_str(), - tbl_name.as_str() - )) - })?; - - // Find the column in the table - let col_idx = table - .columns() - .iter() - .position(|c| { - c.name - .as_ref() - .is_some_and(|name| name.eq_ignore_ascii_case(&normalized_col_name)) - }) - .ok_or_else(|| { - crate::LimboError::ParseError(format!( - "Column: {}.{}.{} not found", - db_name.as_str(), - tbl_name.as_str(), - col_name.as_str() - )) - })?; - - let col = table.columns().get(col_idx).unwrap(); - - // Check if this is a rowid alias - let is_rowid_alias = col.is_rowid_alias; - - // Convert to Column expression - since this is a cross-database reference, - // we need to create a synthetic table reference for it - // For now, we'll error if the table isn't already in the referenced tables - let normalized_tbl_name = normalize_ident(tbl_name.as_str()); - let matching_tbl = referenced_tables - .find_table_and_internal_id_by_identifier(&normalized_tbl_name); - - if let Some((tbl_id, _)) = matching_tbl { - // Table is already in referenced tables, use existing internal ID - *expr = Expr::Column { - database: Some(database_id), - table: tbl_id, - column: col_idx, - is_rowid_alias, - }; - referenced_tables.mark_column_used(tbl_id, col_idx); - } else { - return Err(crate::LimboError::ParseError(format!( - "table {normalized_tbl_name} is not in FROM clause - cross-database column references require the table to be explicitly joined" - ))); - } - - Ok(WalkControl::Continue) - } - _ => Ok(WalkControl::Continue), - } - }, - ) -} - #[allow(clippy::too_many_arguments)] fn parse_from_clause_table( schema: &Schema, @@ -779,6 +551,7 @@ pub fn parse_from( table_references: &mut TableReferences, table_ref_counter: &mut TableRefIdCounter, connection: &Arc, + param_ctx: &mut ParamState, ) -> Result<()> { if from.is_none() { return Ok(()); @@ -877,6 +650,7 @@ pub fn parse_from( table_references, table_ref_counter, connection, + param_ctx, )?; } @@ -1094,6 +868,7 @@ fn parse_join( table_references: &mut TableReferences, table_ref_counter: &mut TableRefIdCounter, connection: &Arc, + param_ctx: &mut ParamState, ) -> Result<()> { let ast::JoinedSelectTable { operator: join_operator, @@ -1181,11 +956,12 @@ fn parse_join( } else { None }; - bind_column_references( + bind_and_rewrite_expr( &mut predicate.expr, - table_references, + 
Some(table_references), None, connection, + param_ctx, )?; } } diff --git a/core/translate/select.rs b/core/translate/select.rs index 20ee62659..7f95fd1d7 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -236,6 +236,7 @@ fn prepare_one_select_plan( &mut table_references, table_ref_counter, connection, + param_ctx, )?; // Preallocate space for the result columns From 1a3a41997c2e970957265ddc8da463bb0e197bbf Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Thu, 18 Sep 2025 19:04:13 -0400 Subject: [PATCH 06/34] Clippy warning, fix needless mut refs and remove import --- core/translate/expr.rs | 2 +- core/translate/select.rs | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 515d0a263..68ab0bafc 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use tracing::{instrument, Level}; -use turso_parser::ast::{self, As, Expr, TableInternalId, UnaryOperator}; +use turso_parser::ast::{self, As, Expr, UnaryOperator}; use super::emitter::Resolver; use super::optimizer::Optimizable; diff --git a/core/translate/select.rs b/core/translate/select.rs index 7f95fd1d7..a1ec15abe 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -417,7 +417,7 @@ fn prepare_one_select_plan( bind_and_rewrite_expr( expr, Some(&mut plan.table_references), - Some(&mut plan.result_columns), + Some(&plan.result_columns), connection, param_ctx, )?; @@ -433,7 +433,7 @@ fn prepare_one_select_plan( bind_and_rewrite_expr( expr, Some(&mut plan.table_references), - Some(&mut plan.result_columns), + Some(&plan.result_columns), connection, param_ctx, )?; @@ -472,7 +472,7 @@ fn prepare_one_select_plan( bind_and_rewrite_expr( &mut o.expr, Some(&mut plan.table_references), - Some(&mut plan.result_columns), + Some(&plan.result_columns), connection, param_ctx, )?; From b86f321ecae3a54d1bcdc8ff8069fe6030350894 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Thu, 18 Sep 2025 19:15:14 -0400 Subject: [PATCH 07/34] Add comments to bind_and_rewrite_expr --- core/translate/expr.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 68ab0bafc..56649f715 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -3250,6 +3250,10 @@ where Ok(WalkControl::Continue) } +/// Context needed to walk all expressions in a INSERT|UPDATE|SELECT|DELETE body, +/// in the order they are encountered, to ensure that the parameters are rewritten from +/// anonymous ("?") to our internal named scheme so when the columns are re-ordered we are able +/// to bind the proper parameter values. pub struct ParamState { /// ALWAYS starts at 1 pub next_param_idx: usize, @@ -3261,6 +3265,9 @@ impl Default for ParamState { } } +/// Rewrite ast::Expr in place, binding Column references/rewriting Expr::Id -> Expr::Column +/// using the provided TableReferences, and replacing anonymous parameters with internal named +/// ones, as well as normalizing any DoublyQualified/Qualified quoted identifiers. 
pub fn bind_and_rewrite_expr<'a>( top_level_expr: &mut ast::Expr, mut referenced_tables: Option<&'a mut TableReferences>, From c77f523bfebf5d7d50f13928e98feed4bc2286c4 Mon Sep 17 00:00:00 2001 From: Pekka Enberg Date: Fri, 19 Sep 2025 09:02:58 +0300 Subject: [PATCH 08/34] core/mvcc: Wrap LogicalLog in RwLock --- core/mvcc/persistent_storage/mod.rs | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/core/mvcc/persistent_storage/mod.rs b/core/mvcc/persistent_storage/mod.rs index b92bf081e..c8af16273 100644 --- a/core/mvcc/persistent_storage/mod.rs +++ b/core/mvcc/persistent_storage/mod.rs @@ -1,6 +1,5 @@ -use std::cell::RefCell; use std::fmt::Debug; -use std::sync::Arc; +use std::sync::{Arc, RwLock}; mod logical_log; use crate::mvcc::database::LogRecord; @@ -9,20 +8,20 @@ use crate::types::IOResult; use crate::{File, Result}; pub struct Storage { - logical_log: RefCell, + logical_log: RwLock, } impl Storage { pub fn new(file: Arc) -> Self { Self { - logical_log: RefCell::new(LogicalLog::new(file)), + logical_log: RwLock::new(LogicalLog::new(file)), } } } impl Storage { pub fn log_tx(&self, m: &LogRecord) -> Result> { - self.logical_log.borrow_mut().log_tx(m) + self.logical_log.write().unwrap().log_tx(m) } pub fn read_tx_log(&self) -> Result> { @@ -34,7 +33,7 @@ impl Storage { } pub fn sync(&self) -> Result> { - self.logical_log.borrow_mut().sync() + self.logical_log.write().unwrap().sync() } } From 0b3317d449fab4fea8de64ad6cf0c28410f82199 Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Mon, 8 Sep 2025 19:59:33 -0700 Subject: [PATCH 09/34] extract columns from all tables in case of joins. Our code for view needs to extract the list of columns used in the view. We currently extract only from "the base table", but once we have joins, we need a more complex structure, that keeps the mapping of (tables, columns). This actually affects both views and materialized views: for views, the queries with joins work just fine, because views are just aliases for a query. But the list of columns returned by pragma table_info on the view is incorrect. We add a test to make sure it is fixed. For materialized views, we add extensive tests to make sure that the columns are extracted correctly. 
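A minimal sketch of the (tables, columns) mapping described above, using simplified stand-in types: the names ViewColumn, ViewColumnSchema, flat_columns() and column_names() follow the diff below, but the real definitions live in core/util.rs and wrap crate::schema::Column rather than the toy Column used here.

// Simplified stand-ins; the real ViewColumnSchema keeps full schema::Column data
// plus the table each output column originates from.
#[derive(Debug, Clone)]
struct Column {
    name: Option<String>,
}

#[derive(Debug, Clone)]
struct ViewColumn {
    table: String,  // which FROM/JOIN table the column comes from
    column: Column, // the column definition itself
}

#[derive(Debug, Default)]
struct ViewColumnSchema {
    columns: Vec<ViewColumn>,
}

impl ViewColumnSchema {
    // Flatten back to a plain column list, e.g. when the view is registered
    // as a BTreeTable in the schema (see flat_columns() in the diff below).
    fn flat_columns(&self) -> Vec<Column> {
        self.columns.iter().map(|vc| vc.column.clone()).collect()
    }

    // Enumerated naming for unnamed columns, mirroring IncrementalView::column_names().
    fn column_names(&self) -> Vec<String> {
        self.columns
            .iter()
            .enumerate()
            .map(|(i, vc)| {
                vc.column
                    .name
                    .clone()
                    .unwrap_or_else(|| format!("column{}", i + 1))
            })
            .collect()
    }
}

fn main() {
    // A view over `t1 JOIN t2` keeps per-table provenance for every output column.
    let schema = ViewColumnSchema {
        columns: vec![
            ViewColumn { table: "t1".into(), column: Column { name: Some("a".into()) } },
            ViewColumn { table: "t2".into(), column: Column { name: None } },
        ],
    };
    assert_eq!(schema.column_names(), vec!["a".to_string(), "column2".to_string()]);
    assert_eq!(schema.flat_columns().len(), 2);
    println!("{schema:?}");
}

Keeping the originating table alongside each column is what lets PRAGMA table_info report the correct column list for a view over a join, while flat_columns() preserves the old flat shape for the schema layer that still expects plain columns.
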
--- core/incremental/cursor.rs | 2 +- core/incremental/view.rs | 66 +++----- core/schema.rs | 7 +- core/translate/planner.rs | 2 +- core/translate/pragma.rs | 3 +- core/translate/view.rs | 3 +- core/util.rs | 321 +++++++++++++++++++++++++++---------- 7 files changed, 263 insertions(+), 141 deletions(-) diff --git a/core/incremental/cursor.rs b/core/incremental/cursor.rs index 9bf39f53d..bf500b450 100644 --- a/core/incremental/cursor.rs +++ b/core/incremental/cursor.rs @@ -355,7 +355,7 @@ mod tests { "View not materialized".to_string(), )); } - let num_columns = view.columns.len(); + let num_columns = view.column_schema.columns.len(); drop(view); // Create a btree cursor diff --git a/core/incremental/view.rs b/core/incremental/view.rs index 9a200c830..f2acabcdc 100644 --- a/core/incremental/view.rs +++ b/core/incremental/view.rs @@ -1,11 +1,11 @@ use super::compiler::{DbspCircuit, DbspCompiler, DeltaSet}; use super::dbsp::Delta; use super::operator::{ComputationTracker, FilterPredicate}; -use crate::schema::{BTreeTable, Column, Schema}; +use crate::schema::{BTreeTable, Schema}; use crate::storage::btree::BTreeCursor; use crate::translate::logical::LogicalPlanBuilder; use crate::types::{IOResult, Value}; -use crate::util::extract_view_columns; +use crate::util::{extract_view_columns, ViewColumnSchema}; use crate::{return_if_io, LimboError, Pager, Result, Statement}; use std::cell::RefCell; use std::collections::HashMap; @@ -173,8 +173,8 @@ pub struct IncrementalView { // All tables referenced by this view (from FROM clause and JOINs) referenced_tables: Vec>, - // The view's output columns with their types - pub columns: Vec, + // The view's column schema with table relationships + pub column_schema: ViewColumnSchema, // State machine for population populate_state: PopulateState, // Computation tracker for statistics @@ -227,11 +227,16 @@ impl IncrementalView { /// Get an iterator over column names, using enumerated naming for unnamed columns pub fn column_names(&self) -> impl Iterator + '_ { - self.columns.iter().enumerate().map(|(i, col)| { - col.name - .clone() - .unwrap_or_else(|| format!("column{}", i + 1)) - }) + self.column_schema + .columns + .iter() + .enumerate() + .map(|(i, vc)| { + vc.column + .name + .clone() + .unwrap_or_else(|| format!("column{}", i + 1)) + }) } /// Check if this view has the same SQL definition as the provided SQL string @@ -251,24 +256,9 @@ impl IncrementalView { pub fn validate_and_extract_columns( select: &ast::Select, schema: &Schema, - ) -> Result> { - // For now, just extract columns from a simple select - // This will need to be expanded to handle joins, aggregates, etc. 
- - // Get the base table name - let base_table_name = Self::extract_base_table(select).ok_or_else(|| { - LimboError::ParseError("Cannot extract base table from SELECT".to_string()) - })?; - - // Get the table from schema - let table = schema - .get_table(&base_table_name) - .and_then(|t| t.btree()) - .ok_or_else(|| LimboError::ParseError(format!("Table {base_table_name} not found")))?; - - // For now, return all columns from the base table - // In the future, this should parse the select list and handle projections - Ok(table.columns.clone()) + ) -> Result { + // Use the shared function to extract columns with full table context + extract_view_columns(select, schema) } pub fn from_sql( @@ -314,7 +304,7 @@ impl IncrementalView { let where_predicate = FilterPredicate::from_select(&select)?; // Extract output columns using the shared function - let view_columns = extract_view_columns(&select, schema); + let column_schema = extract_view_columns(&select, schema)?; let (join_tables, join_condition) = Self::extract_join_info(&select); if join_tables.is_some() || join_condition.is_some() { @@ -331,7 +321,7 @@ impl IncrementalView { where_predicate, select.clone(), referenced_tables, - view_columns, + column_schema, schema, main_data_root, internal_state_root, @@ -345,7 +335,7 @@ impl IncrementalView { where_predicate: FilterPredicate, select_stmt: ast::Select, referenced_tables: Vec>, - columns: Vec, + column_schema: ViewColumnSchema, schema: &Schema, main_data_root: usize, internal_state_root: usize, @@ -369,7 +359,7 @@ impl IncrementalView { select_stmt, circuit, referenced_tables, - columns, + column_schema, populate_state: PopulateState::Start, tracker, root_page: main_data_root, @@ -457,20 +447,6 @@ impl IncrementalView { Ok(tables) } - /// Extract the base table name from a SELECT statement (for non-join cases) - fn extract_base_table(select: &ast::Select) -> Option { - if let ast::OneSelect::Select { - from: Some(ref from), - .. - } = select.body.select - { - if let ast::SelectTable::Table(name, _, _) = from.select.as_ref() { - return Some(name.name.as_str().to_string()); - } - } - None - } - /// Generate the SQL query for populating the view from its source table fn sql_for_populate(&self) -> crate::Result { // Get the first table from referenced tables diff --git a/core/schema.rs b/core/schema.rs index 6d510b3a3..71cbb4932 100644 --- a/core/schema.rs +++ b/core/schema.rs @@ -527,7 +527,7 @@ impl Schema { let table = Arc::new(Table::BTree(Arc::new(BTreeTable { name: view_name.clone(), root_page: main_root, - columns: incremental_view.columns.clone(), + columns: incremental_view.column_schema.flat_columns(), primary_key_columns: Vec::new(), has_rowid: true, is_strict: false, @@ -673,11 +673,12 @@ impl Schema { .. 
} => { // Extract actual columns from the SELECT statement - let view_columns = crate::util::extract_view_columns(&select, self); + let view_column_schema = + crate::util::extract_view_columns(&select, self)?; // If column names were provided in CREATE VIEW (col1, col2, ...), // use them to rename the columns - let mut final_columns = view_columns; + let mut final_columns = view_column_schema.flat_columns(); for (i, indexed_col) in column_names.iter().enumerate() { if let Some(col) = final_columns.get_mut(i) { col.name = Some(indexed_col.col_name.to_string()); diff --git a/core/translate/planner.rs b/core/translate/planner.rs index 14f422860..36a99aaa8 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -663,7 +663,7 @@ fn parse_table( let btree_table = Arc::new(crate::schema::BTreeTable { name: view_guard.name().to_string(), root_page, - columns: view_guard.columns.clone(), + columns: view_guard.column_schema.flat_columns(), primary_key_columns: Vec::new(), has_rowid: true, is_strict: false, diff --git a/core/translate/pragma.rs b/core/translate/pragma.rs index 6336faf25..7fa74e9ca 100644 --- a/core/translate/pragma.rs +++ b/core/translate/pragma.rs @@ -508,7 +508,8 @@ fn query_pragma( emit_columns_for_table_info(&mut program, table.columns(), base_reg); } else if let Some(view_mutex) = schema.get_materialized_view(&name) { let view = view_mutex.lock().unwrap(); - emit_columns_for_table_info(&mut program, &view.columns, base_reg); + let flat_columns = view.column_schema.flat_columns(); + emit_columns_for_table_info(&mut program, &flat_columns, base_reg); } else if let Some(view) = schema.get_view(&name) { emit_columns_for_table_info(&mut program, &view.columns, base_reg); } diff --git a/core/translate/view.rs b/core/translate/view.rs index f89f29817..9ff8e6c89 100644 --- a/core/translate/view.rs +++ b/core/translate/view.rs @@ -42,7 +42,8 @@ pub fn translate_create_materialized_view( // storing invalid view definitions use crate::incremental::view::IncrementalView; use crate::schema::BTreeTable; - let view_columns = IncrementalView::validate_and_extract_columns(select_stmt, schema)?; + let view_column_schema = IncrementalView::validate_and_extract_columns(select_stmt, schema)?; + let view_columns = view_column_schema.flat_columns(); // Reconstruct the SQL string for storage let sql = create_materialized_view_to_str(view_name, select_stmt); diff --git a/core/util.rs b/core/util.rs index 2d945ec11..faffc72cf 100644 --- a/core/util.rs +++ b/core/util.rs @@ -1066,9 +1066,59 @@ pub fn extract_column_name_from_expr(expr: impl AsRef) -> Option, +} + +/// Information about a column in the view's output +#[derive(Debug, Clone)] +pub struct ViewColumn { + /// Index into ViewColumnSchema.tables indicating which table this column comes from + /// For computed columns or constants, this will be usize::MAX + pub table_index: usize, + /// The actual column definition + pub column: Column, +} + +/// Schema information for a view, tracking which columns come from which tables +#[derive(Debug, Clone)] +pub struct ViewColumnSchema { + /// All tables referenced by the view (in order of appearance) + pub tables: Vec, + /// The view's output columns with their table associations + pub columns: Vec, +} + +impl ViewColumnSchema { + /// Get all columns as a flat vector (without table association info) + pub fn flat_columns(&self) -> Vec { + self.columns.iter().map(|vc| vc.column.clone()).collect() + } + + /// Get columns that belong to a specific table + pub fn table_columns(&self, 
table_index: usize) -> Vec { + self.columns + .iter() + .filter(|vc| vc.table_index == table_index) + .map(|vc| vc.column.clone()) + .collect() + } +} + /// Extract column information from a SELECT statement for view creation -pub fn extract_view_columns(select_stmt: &ast::Select, schema: &Schema) -> Vec { +pub fn extract_view_columns( + select_stmt: &ast::Select, + schema: &Schema, +) -> Result { + let mut tables = Vec::new(); let mut columns = Vec::new(); + let mut column_name_counts: HashMap = HashMap::new(); + // Navigate to the first SELECT in the statement if let ast::OneSelect::Select { ref from, @@ -1076,23 +1126,85 @@ pub fn extract_view_columns(select_stmt: &ast::Select, schema: &Schema) -> Vec { + let table_name = if qualified_name.db_name.is_some() { + // Include database qualifier if present + qualified_name.to_string() + } else { + normalize_ident(qualified_name.name.as_str()) + }; + tables.push(ViewTable { + name: table_name.clone(), + alias: alias.as_ref().map(|a| match a { + ast::As::As(name) => normalize_ident(name.as_str()), + ast::As::Elided(name) => normalize_ident(name.as_str()), + }), + }); + } + _ => { + // Handle other types like subqueries if needed + } } - } else { - None + + // Add tables from JOINs + for join in &from.joins { + match join.table.as_ref() { + ast::SelectTable::Table(qualified_name, alias, _) => { + let table_name = if qualified_name.db_name.is_some() { + // Include database qualifier if present + qualified_name.to_string() + } else { + normalize_ident(qualified_name.name.as_str()) + }; + tables.push(ViewTable { + name: table_name.clone(), + alias: alias.as_ref().map(|a| match a { + ast::As::As(name) => normalize_ident(name.as_str()), + ast::As::Elided(name) => normalize_ident(name.as_str()), + }), + }); + } + _ => { + // Handle other types like subqueries if needed + } + } + } + } + + // Helper function to find table index by name or alias + let find_table_index = |name: &str| -> Option { + tables + .iter() + .position(|t| t.name == name || t.alias.as_ref().is_some_and(|a| a == name)) }; - // Get the table for column resolution - let _table = table_name.as_ref().and_then(|name| schema.get_table(name)); + // Process each column in the SELECT list - for (i, result_col) in select_columns.iter().enumerate() { + for result_col in select_columns.iter() { match result_col { ast::ResultColumn::Expr(expr, alias) => { - let name = alias + // Figure out which table this expression comes from + let table_index = match expr.as_ref() { + ast::Expr::Qualified(table_ref, _col_name) => { + // Column qualified with table name + find_table_index(table_ref.as_str()) + } + ast::Expr::Id(_col_name) => { + // Unqualified column - would need to resolve based on schema + // For now, assume it's from the first table if there is one + if !tables.is_empty() { + Some(0) + } else { + None + } + } + _ => None, // Expression, literal, etc. 
+ }; + + let col_name = alias .as_ref() .map(|a| match a { ast::As::Elided(name) => name.as_str().to_string(), @@ -1103,41 +1215,65 @@ pub fn extract_view_columns(select_stmt: &ast::Select, schema: &Schema) -> Vec { - // For SELECT *, expand to all columns from the table - if let Some(ref table_name) = table_name { - if let Some(table) = schema.get_table(table_name) { - // Copy all columns from the table, but adjust for view constraints - for table_column in table.columns() { - columns.push(Column { - name: table_column.name.clone(), - ty: table_column.ty, - ty_str: table_column.ty_str.clone(), - primary_key: false, // Views don't have primary keys - is_rowid_alias: false, - notnull: false, // Views typically don't enforce NOT NULL - default: None, // Views don't have default values - unique: false, - collation: table_column.collation, - hidden: false, + // For SELECT *, expand to all columns from all tables + for (table_idx, table) in tables.iter().enumerate() { + if let Some(table_obj) = schema.get_table(&table.name) { + for table_column in table_obj.columns() { + let col_name = + table_column.name.clone().unwrap_or_else(|| "?".to_string()); + + // Handle duplicate column names by adding suffix + let final_name = + if let Some(count) = column_name_counts.get_mut(&col_name) { + *count += 1; + format!("{}:{}", col_name, *count - 1) + } else { + column_name_counts.insert(col_name.clone(), 1); + col_name.clone() + }; + + columns.push(ViewColumn { + table_index: table_idx, + column: Column { + name: Some(final_name), + ty: table_column.ty, + ty_str: table_column.ty_str.clone(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: table_column.collation, + hidden: false, + }, }); } - } else { - // Table not found, create placeholder - columns.push(Column { + } + } + + // If no tables, create a placeholder + if tables.is_empty() { + columns.push(ViewColumn { + table_index: usize::MAX, + column: Column { name: Some("*".to_string()), ty: Type::Text, ty_str: "TEXT".to_string(), @@ -1148,63 +1284,70 @@ pub fn extract_view_columns(select_stmt: &ast::Select, schema: &Schema) -> Vec { + ast::ResultColumn::TableStar(table_ref) => { // For table.*, expand to all columns from the specified table - let table_name_str = normalize_ident(table_name.as_str()); - if let Some(table) = schema.get_table(&table_name_str) { - // Copy all columns from the table, but adjust for view constraints - for table_column in table.columns() { - columns.push(Column { - name: table_column.name.clone(), - ty: table_column.ty, - ty_str: table_column.ty_str.clone(), - primary_key: false, - is_rowid_alias: false, - notnull: false, - default: None, - unique: false, - collation: table_column.collation, - hidden: false, + let table_name_str = normalize_ident(table_ref.as_str()); + if let Some(table_idx) = find_table_index(&table_name_str) { + if let Some(table) = schema.get_table(&tables[table_idx].name) { + for table_column in table.columns() { + let col_name = + table_column.name.clone().unwrap_or_else(|| "?".to_string()); + + // Handle duplicate column names by adding suffix + let final_name = + if let Some(count) = column_name_counts.get_mut(&col_name) { + *count += 1; + format!("{}:{}", col_name, *count - 1) + } else { + column_name_counts.insert(col_name.clone(), 1); + col_name.clone() + }; + + columns.push(ViewColumn { + table_index: table_idx, + column: Column { + name: Some(final_name), + ty: table_column.ty, + ty_str: table_column.ty_str.clone(), + primary_key: 
false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: table_column.collation, + hidden: false, + }, + }); + } + } else { + // Table not found, create placeholder + columns.push(ViewColumn { + table_index: usize::MAX, + column: Column { + name: Some(format!("{table_name_str}.*")), + ty: Type::Text, + ty_str: "TEXT".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, }); } - } else { - // Table not found, create placeholder - columns.push(Column { - name: Some(format!("{table_name_str}.*")), - ty: Type::Text, - ty_str: "TEXT".to_string(), - primary_key: false, - is_rowid_alias: false, - notnull: false, - default: None, - unique: false, - collation: None, - hidden: false, - }); } } } } } - columns + + Ok(ViewColumnSchema { tables, columns }) } #[cfg(test)] From 5b4a6e5c2d33c17745aeb5041a62c4880fdaa27c Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Mon, 8 Sep 2025 17:28:15 -0700 Subject: [PATCH 10/34] view: catch all tables mentioned, instead of just one. Ahead of the implementation of JOINs, we need to evolve the IncrementalView, which currently only accepts a single base table, to keep a list of tables mentioned in the statement. --- core/incremental/view.rs | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/core/incremental/view.rs b/core/incremental/view.rs index f2acabcdc..8b32c5dcc 100644 --- a/core/incremental/view.rs +++ b/core/incremental/view.rs @@ -566,8 +566,13 @@ impl IncrementalView { // machinery (next step is which index is best to use, etc) let query = self.sql_for_populate()?; - // Prepare the statement - let stmt = conn.prepare(&query)?; + // Create a new connection for reading to avoid transaction conflicts + // This allows us to read from tables while the parent transaction is writing the view + // The statement holds a reference to this connection, keeping it alive + let read_conn = conn.db.connect()?; + + // Prepare the statement using the read connection + let stmt = read_conn.prepare(&query)?; self.populate_state = PopulateState::Processing { stmt: Box::new(stmt), From 2e7a45559b346c65967325628d06e2152f932553 Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Thu, 4 Sep 2025 19:35:49 -0500 Subject: [PATCH 11/34] add joins to the logical plan --- core/translate/logical.rs | 764 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 735 insertions(+), 29 deletions(-) diff --git a/core/translate/logical.rs b/core/translate/logical.rs index aa71f047b..6a8b0a6c2 100644 --- a/core/translate/logical.rs +++ b/core/translate/logical.rs @@ -25,6 +25,9 @@ type PreprocessAggregateResult = ( Vec, // modified_aggr_exprs ); +/// Result type for parsing join conditions +type JoinConditionsResult = (Vec<(LogicalExpr, LogicalExpr)>, Option); + /// Schema information for logical plan nodes #[derive(Debug, Clone, PartialEq)] pub struct LogicalSchema { @@ -66,8 +69,8 @@ pub enum LogicalPlan { Filter(Filter), /// Aggregate - GROUP BY with aggregate functions Aggregate(Aggregate), - // TODO: Join - combining two relations (not yet implemented) - // Join(Join), + /// Join - combining two relations + Join(Join), /// Sort - ORDER BY clause Sort(Sort), /// Limit - LIMIT/OFFSET clause @@ -95,7 +98,7 @@ impl LogicalPlan { LogicalPlan::Projection(p) => &p.schema, LogicalPlan::Filter(f) => f.input.schema(), LogicalPlan::Aggregate(a) => &a.schema, - // LogicalPlan::Join(j) => &j.schema, + LogicalPlan::Join(j) => &j.schema, 
LogicalPlan::Sort(s) => s.input.schema(), LogicalPlan::Limit(l) => l.input.schema(), LogicalPlan::TableScan(t) => &t.schema, @@ -133,26 +136,26 @@ pub struct Aggregate { pub schema: SchemaRef, } -// TODO: Join operator (not yet implemented) -// #[derive(Debug, Clone, PartialEq)] -// pub struct Join { -// pub left: Arc, -// pub right: Arc, -// pub join_type: JoinType, -// pub on: Vec<(LogicalExpr, LogicalExpr)>, // Equijoin conditions -// pub filter: Option, // Additional filter conditions -// pub schema: SchemaRef, -// } +/// Types of joins +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum JoinType { + Inner, + Left, + Right, + Full, + Cross, +} -// TODO: Types of joins (not yet implemented) -// #[derive(Debug, Clone, Copy, PartialEq)] -// pub enum JoinType { -// Inner, -// Left, -// Right, -// Full, -// Cross, -// } +/// Join operator - combines two relations +#[derive(Debug, Clone, PartialEq)] +pub struct Join { + pub left: Arc, + pub right: Arc, + pub join_type: JoinType, + pub on: Vec<(LogicalExpr, LogicalExpr)>, // Equijoin conditions (left_expr, right_expr) + pub filter: Option, // Additional filter conditions + pub schema: SchemaRef, +} /// Sort operator - ORDER BY #[derive(Debug, Clone, PartialEq)] @@ -570,14 +573,279 @@ impl<'a> LogicalPlanBuilder<'a> { // Build JOIN fn build_join( &mut self, - _left: LogicalPlan, - _right: LogicalPlan, - _op: &ast::JoinOperator, - _constraint: &Option, + left: LogicalPlan, + right: LogicalPlan, + op: &ast::JoinOperator, + constraint: &Option, ) -> Result { - Err(LimboError::ParseError( - "JOINs are not yet supported in logical plans".to_string(), - )) + // Determine join type + let join_type = match op { + ast::JoinOperator::Comma => JoinType::Cross, // Comma is essentially a cross join + ast::JoinOperator::TypedJoin(Some(jt)) => { + // Check the join type flags + // Note: JoinType can have multiple flags set + if jt.contains(ast::JoinType::NATURAL) { + // Natural joins need special handling - find common columns + return self.build_natural_join(left, right, JoinType::Inner); + } else if jt.contains(ast::JoinType::LEFT) + && jt.contains(ast::JoinType::RIGHT) + && jt.contains(ast::JoinType::OUTER) + { + // FULL OUTER JOIN (has LEFT, RIGHT, and OUTER) + JoinType::Full + } else if jt.contains(ast::JoinType::LEFT) && jt.contains(ast::JoinType::OUTER) { + JoinType::Left + } else if jt.contains(ast::JoinType::RIGHT) && jt.contains(ast::JoinType::OUTER) { + JoinType::Right + } else if jt.contains(ast::JoinType::OUTER) + && !jt.contains(ast::JoinType::LEFT) + && !jt.contains(ast::JoinType::RIGHT) + { + // Plain OUTER JOIN should also be FULL + JoinType::Full + } else if jt.contains(ast::JoinType::LEFT) { + JoinType::Left + } else if jt.contains(ast::JoinType::RIGHT) { + JoinType::Right + } else if jt.contains(ast::JoinType::CROSS) + || (jt.contains(ast::JoinType::INNER) && jt.contains(ast::JoinType::CROSS)) + { + JoinType::Cross + } else { + JoinType::Inner // Default to inner + } + } + ast::JoinOperator::TypedJoin(None) => JoinType::Inner, // Default JOIN is INNER JOIN + }; + + // Build join conditions + let (on_conditions, filter) = match constraint { + Some(ast::JoinConstraint::On(expr)) => { + // Parse ON clause into equijoin conditions and filters + self.parse_join_conditions(expr, left.schema(), right.schema())? 
+ } + Some(ast::JoinConstraint::Using(columns)) => { + // Build equijoin conditions from USING clause + let on = self.build_using_conditions(columns, left.schema(), right.schema())?; + (on, None) + } + None => { + // Cross join or natural join + (Vec::new(), None) + } + }; + + // Build combined schema + let schema = self.build_join_schema(&left, &right, &join_type)?; + + Ok(LogicalPlan::Join(Join { + left: Arc::new(left), + right: Arc::new(right), + join_type, + on: on_conditions, + filter, + schema, + })) + } + + // Helper: Parse join conditions into equijoins and filters + fn parse_join_conditions( + &mut self, + expr: &ast::Expr, + left_schema: &SchemaRef, + right_schema: &SchemaRef, + ) -> Result { + // For now, we'll handle simple equality conditions + // More complex conditions will go into the filter + let mut equijoins = Vec::new(); + let mut filters = Vec::new(); + + // Try to extract equijoin conditions from the expression + self.extract_equijoin_conditions( + expr, + left_schema, + right_schema, + &mut equijoins, + &mut filters, + )?; + + let filter = if filters.is_empty() { + None + } else { + // Combine multiple filters with AND + Some( + filters + .into_iter() + .reduce(|acc, e| LogicalExpr::BinaryExpr { + left: Box::new(acc), + op: BinaryOperator::And, + right: Box::new(e), + }) + .unwrap(), + ) + }; + + Ok((equijoins, filter)) + } + + // Helper: Extract equijoin conditions from expression + fn extract_equijoin_conditions( + &mut self, + expr: &ast::Expr, + left_schema: &SchemaRef, + right_schema: &SchemaRef, + equijoins: &mut Vec<(LogicalExpr, LogicalExpr)>, + filters: &mut Vec, + ) -> Result<()> { + match expr { + ast::Expr::Binary(lhs, ast::Operator::Equals, rhs) => { + // Check if this is an equijoin condition (left.col = right.col) + let left_expr = self.build_expr(lhs, left_schema)?; + let right_expr = self.build_expr(rhs, right_schema)?; + + // For simplicity, we'll check if one references left and one references right + // In a real implementation, we'd need more sophisticated column resolution + equijoins.push((left_expr, right_expr)); + } + ast::Expr::Binary(lhs, ast::Operator::And, rhs) => { + // Recursively extract from AND conditions + self.extract_equijoin_conditions( + lhs, + left_schema, + right_schema, + equijoins, + filters, + )?; + self.extract_equijoin_conditions( + rhs, + left_schema, + right_schema, + equijoins, + filters, + )?; + } + _ => { + // Other conditions go into the filter + // We need a combined schema to build the expression + let combined_schema = self.combine_schemas(left_schema, right_schema)?; + let filter_expr = self.build_expr(expr, &combined_schema)?; + filters.push(filter_expr); + } + } + Ok(()) + } + + // Helper: Build equijoin conditions from USING clause + fn build_using_conditions( + &mut self, + columns: &[ast::Name], + left_schema: &SchemaRef, + right_schema: &SchemaRef, + ) -> Result> { + let mut conditions = Vec::new(); + + for col_name in columns { + let name = Self::name_to_string(col_name); + + // Find the column in both schemas + let _left_idx = left_schema + .columns + .iter() + .position(|(n, _)| n == &name) + .ok_or_else(|| { + LimboError::ParseError(format!("Column {name} not found in left table")) + })?; + let _right_idx = right_schema + .columns + .iter() + .position(|(n, _)| n == &name) + .ok_or_else(|| { + LimboError::ParseError(format!("Column {name} not found in right table")) + })?; + + conditions.push(( + LogicalExpr::Column(Column { + name: name.clone(), + table: None, // Will be resolved later + }), + 
LogicalExpr::Column(Column { + name, + table: None, // Will be resolved later + }), + )); + } + + Ok(conditions) + } + + // Helper: Build natural join by finding common columns + fn build_natural_join( + &mut self, + left: LogicalPlan, + right: LogicalPlan, + join_type: JoinType, + ) -> Result { + let left_schema = left.schema(); + let right_schema = right.schema(); + + // Find common column names + let mut common_columns = Vec::new(); + for (left_name, _) in &left_schema.columns { + if right_schema.columns.iter().any(|(n, _)| n == left_name) { + common_columns.push(ast::Name::Ident(left_name.clone())); + } + } + + if common_columns.is_empty() { + // Natural join with no common columns becomes a cross join + let schema = self.build_join_schema(&left, &right, &JoinType::Cross)?; + return Ok(LogicalPlan::Join(Join { + left: Arc::new(left), + right: Arc::new(right), + join_type: JoinType::Cross, + on: Vec::new(), + filter: None, + schema, + })); + } + + // Build equijoin conditions for common columns + let on = self.build_using_conditions(&common_columns, left_schema, right_schema)?; + let schema = self.build_join_schema(&left, &right, &join_type)?; + + Ok(LogicalPlan::Join(Join { + left: Arc::new(left), + right: Arc::new(right), + join_type, + on, + filter: None, + schema, + })) + } + + // Helper: Build schema for join result + fn build_join_schema( + &self, + left: &LogicalPlan, + right: &LogicalPlan, + _join_type: &JoinType, + ) -> Result { + let left_schema = left.schema(); + let right_schema = right.schema(); + + // For now, simply concatenate the schemas + // In a real implementation, we'd handle column name conflicts and nullable columns + let mut columns = left_schema.columns.clone(); + columns.extend(right_schema.columns.clone()); + + Ok(Arc::new(LogicalSchema::new(columns))) + } + + // Helper: Combine two schemas for expression building + fn combine_schemas(&self, left: &SchemaRef, right: &SchemaRef) -> Result { + let mut columns = left.columns.clone(); + columns.extend(right.columns.clone()); + Ok(Arc::new(LogicalSchema::new(columns))) } // Build projection @@ -1974,6 +2242,67 @@ mod tests { }; schema.add_btree_table(Arc::new(orders_table)); + // Create products table + let products_table = BTreeTable { + name: "products".to_string(), + root_page: 4, + primary_key_columns: vec![("id".to_string(), turso_parser::ast::SortOrder::Asc)], + columns: vec![ + SchemaColumn { + name: Some("id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: true, + is_rowid_alias: true, + notnull: true, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("name".to_string()), + ty: Type::Text, + ty_str: "TEXT".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("price".to_string()), + ty: Type::Real, + ty_str: "REAL".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("product_id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + ], + has_rowid: true, + is_strict: false, + unique_sets: vec![], + }; + schema.add_btree_table(Arc::new(products_table)); + schema } @@ -3086,4 +3415,381 @@ mod tests { _ => 
panic!("Expected Projection as top-level operator, got: {plan:?}"), } } + + // ===== JOIN TESTS ===== + + #[test] + fn test_inner_join() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users u INNER JOIN orders o ON u.id = o.user_id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + match &*proj.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Inner); + assert!(!join.on.is_empty(), "Should have join conditions"); + + // Check left input is users + match &*join.left { + LogicalPlan::TableScan(scan) => { + assert_eq!(scan.table_name, "users"); + } + _ => panic!("Expected TableScan for left input"), + } + + // Check right input is orders + match &*join.right { + LogicalPlan::TableScan(scan) => { + assert_eq!(scan.table_name, "orders"); + } + _ => panic!("Expected TableScan for right input"), + } + } + _ => panic!("Expected Join under Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_left_join() { + let schema = create_test_schema(); + let sql = "SELECT u.name, o.amount FROM users u LEFT JOIN orders o ON u.id = o.user_id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 2); // name and amount + match &*proj.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Left); + assert!(!join.on.is_empty(), "Should have join conditions"); + } + _ => panic!("Expected Join under Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_right_join() { + let schema = create_test_schema(); + let sql = "SELECT * FROM orders o RIGHT JOIN users u ON o.user_id = u.id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Right); + assert!(!join.on.is_empty(), "Should have join conditions"); + } + _ => panic!("Expected Join under Projection"), + }, + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_full_outer_join() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users u FULL OUTER JOIN orders o ON u.id = o.user_id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Full); + assert!(!join.on.is_empty(), "Should have join conditions"); + } + _ => panic!("Expected Join under Projection"), + }, + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_cross_join() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users CROSS JOIN orders"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Cross); + assert!(join.on.is_empty(), "Cross join should have no conditions"); + assert!(join.filter.is_none(), "Cross join should have no filter"); + } + _ => panic!("Expected Join under Projection"), + }, + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_join_with_multiple_conditions() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users u JOIN orders o ON u.id = o.user_id AND u.age > 18"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan 
{ + LogicalPlan::Projection(proj) => { + match &*proj.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Inner); + // Should have at least one equijoin condition + assert!(!join.on.is_empty(), "Should have join conditions"); + // Additional conditions may be in filter + // The exact distribution depends on our implementation + } + _ => panic!("Expected Join under Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_join_using_clause() { + let schema = create_test_schema(); + // Note: Both tables should have an 'id' column for this to work + let sql = "SELECT * FROM users JOIN orders USING (id)"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Inner); + assert!( + !join.on.is_empty(), + "USING clause should create join conditions" + ); + } + _ => panic!("Expected Join under Projection"), + }, + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_natural_join() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users NATURAL JOIN orders"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + match &*proj.input { + LogicalPlan::Join(join) => { + // Natural join finds common columns (id in this case) + // If no common columns, it becomes a cross join + assert!( + !join.on.is_empty() || join.join_type == JoinType::Cross, + "Natural join should either find common columns or become cross join" + ); + } + _ => panic!("Expected Join under Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_three_way_join() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users u + JOIN orders o ON u.id = o.user_id + JOIN products p ON o.product_id = p.id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + match &*proj.input { + LogicalPlan::Join(join2) => { + // Second join (with products) + assert_eq!(join2.join_type, JoinType::Inner); + match &*join2.left { + LogicalPlan::Join(join1) => { + // First join (users with orders) + assert_eq!(join1.join_type, JoinType::Inner); + } + _ => panic!("Expected nested Join for three-way join"), + } + } + _ => panic!("Expected Join under Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_mixed_join_types() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users u + LEFT JOIN orders o ON u.id = o.user_id + INNER JOIN products p ON o.product_id = p.id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + match &*proj.input { + LogicalPlan::Join(join2) => { + // Second join should be INNER + assert_eq!(join2.join_type, JoinType::Inner); + match &*join2.left { + LogicalPlan::Join(join1) => { + // First join should be LEFT + assert_eq!(join1.join_type, JoinType::Left); + } + _ => panic!("Expected nested Join"), + } + } + _ => panic!("Expected Join under Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_join_with_filter() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users u JOIN orders o ON u.id = o.user_id WHERE o.amount > 100"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + match 
&*proj.input { + LogicalPlan::Filter(filter) => { + // WHERE clause creates a Filter above the Join + match &*filter.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Inner); + } + _ => panic!("Expected Join under Filter"), + } + } + _ => panic!("Expected Filter under Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_join_with_projection() { + let schema = create_test_schema(); + let sql = "SELECT u.name, o.amount FROM users u JOIN orders o ON u.id = o.user_id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 2); // u.name and o.amount + match &*proj.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Inner); + } + _ => panic!("Expected Join under Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_join_with_aggregation() { + let schema = create_test_schema(); + let sql = "SELECT u.name, SUM(o.amount) + FROM users u JOIN orders o ON u.id = o.user_id + GROUP BY u.name"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.group_expr.len(), 1); // GROUP BY u.name + assert_eq!(agg.aggr_expr.len(), 1); // SUM(o.amount) + match &*agg.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Inner); + } + _ => panic!("Expected Join under Aggregate"), + } + } + _ => panic!("Expected Aggregate"), + } + } + + #[test] + fn test_join_with_order_by() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users u JOIN orders o ON u.id = o.user_id ORDER BY o.amount DESC"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Sort(sort) => { + assert_eq!(sort.exprs.len(), 1); + assert!(!sort.exprs[0].asc); // DESC + match &*sort.input { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Inner); + } + _ => panic!("Expected Join under Projection"), + }, + _ => panic!("Expected Projection under Sort"), + } + } + _ => panic!("Expected Sort at top level"), + } + } + + #[test] + fn test_join_in_subquery() { + let schema = create_test_schema(); + let sql = "SELECT * FROM ( + SELECT u.id, u.name, o.amount + FROM users u JOIN orders o ON u.id = o.user_id + ) WHERE amount > 100"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(outer_proj) => match &*outer_proj.input { + LogicalPlan::Filter(filter) => match &*filter.input { + LogicalPlan::Projection(inner_proj) => match &*inner_proj.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Inner); + } + _ => panic!("Expected Join in subquery"), + }, + _ => panic!("Expected Projection for subquery"), + }, + _ => panic!("Expected Filter"), + }, + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_join_ambiguous_column() { + let schema = create_test_schema(); + // Both users and orders have an 'id' column + let sql = "SELECT id FROM users JOIN orders ON users.id = orders.user_id"; + let result = parse_and_build(sql, &schema); + // This might error or succeed depending on how we handle ambiguous columns + // For now, just check that parsing completes + match result { + Ok(_) => { + // If successful, the implementation handles ambiguous columns somehow + } + Err(_) => { + // If error, the implementation rejects ambiguous columns + 
} + } + } } From 6be5eb74d9daa96b5fedcb5d56af75ddbcc1b891 Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Tue, 16 Sep 2025 12:10:00 -0500 Subject: [PATCH 12/34] Implement the Join Operator The join operator is also a stateful operator. It keeps the input deltas stored in the state, for both the left and right branches of the join. JOINs extract a join key, which is the values that were used in the join's equality statement. That key is now our zset_id, and it points to a collection of rows. --- core/incremental/dbsp.rs | 6 +- core/incremental/operator.rs | 1612 +++++++++++++++++++++++++++++++++- 2 files changed, 1582 insertions(+), 36 deletions(-) diff --git a/core/incremental/dbsp.rs b/core/incremental/dbsp.rs index 363ac1142..d4862b70a 100644 --- a/core/incremental/dbsp.rs +++ b/core/incremental/dbsp.rs @@ -75,6 +75,10 @@ impl HashableRow { hasher.finish() } + + pub fn cached_hash(&self) -> u64 { + self.cached_hash + } } impl Hash for HashableRow { @@ -168,7 +172,7 @@ impl Delta { } /// A pair of deltas for operators that process two inputs -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct DeltaPair { pub left: Delta, pub right: Delta, diff --git a/core/incremental/operator.rs b/core/incremental/operator.rs index 7c402db93..78d2ca60d 100644 --- a/core/incremental/operator.rs +++ b/core/incremental/operator.rs @@ -119,14 +119,9 @@ enum AggregateCommitState { Invalid, } -// eval() has uncommitted data, so it can't be a member attribute of the Operator. -// The state has to be kept by the caller +// Aggregate-specific eval states #[derive(Debug)] -pub enum EvalState { - Uninitialized, - Init { - deltas: DeltaPair, - }, +pub enum AggregateEvalState { FetchKey { delta: Delta, // Keep original delta for merge operation current_idx: usize, @@ -149,6 +144,299 @@ pub enum EvalState { old_values: HashMap>, recompute_state: Box, }, + Done { + output: (Delta, ComputedStates), + }, +} + +// Helper function to read the next row from the BTree for joins +fn read_next_join_row( + storage_id: i64, + join_key: &HashableRow, + last_element_id: i64, + cursors: &mut DbspStateCursors, +) -> Result>> { + // Build the index key: (storage_id, zset_id, element_id) + // zset_id is the hash of the join key + let zset_id = join_key.cached_hash() as i64; + + let index_key_values = vec![ + Value::Integer(storage_id), + Value::Integer(zset_id), + Value::Integer(last_element_id), + ]; + + let index_record = ImmutableRecord::from_values(&index_key_values, index_key_values.len()); + let seek_result = return_if_io!(cursors + .index_cursor + .seek(SeekKey::IndexKey(&index_record), SeekOp::GT)); + + if !matches!(seek_result, SeekResult::Found) { + return Ok(IOResult::Done(None)); + } + + // Check if we're still in the same (storage_id, zset_id) range + let current_record = return_if_io!(cursors.index_cursor.record()); + + // Extract all needed values from the record before dropping it + let (found_storage_id, found_zset_id, element_id) = if let Some(rec) = current_record { + let values = rec.get_values(); + + // Index has 4 values: storage_id, zset_id, element_id, rowid (appended by WriteRow) + if values.len() >= 3 { + let found_storage_id = match &values[0].to_owned() { + Value::Integer(id) => *id, + _ => return Ok(IOResult::Done(None)), + }; + let found_zset_id = match &values[1].to_owned() { + Value::Integer(id) => *id, + _ => return Ok(IOResult::Done(None)), + }; + let element_id = match &values[2].to_owned() { + Value::Integer(id) => *id, + _ => { + return Ok(IOResult::Done(None)); + } + }; + 
(found_storage_id, found_zset_id, element_id) + } else { + return Ok(IOResult::Done(None)); + } + } else { + return Ok(IOResult::Done(None)); + }; + + // Now we can safely check if we're in the right range + // If we've moved to a different storage_id or zset_id, we're done + if found_storage_id != storage_id || found_zset_id != zset_id { + return Ok(IOResult::Done(None)); + } + + // Now get the actual row from the table using the rowid from the index + let rowid = return_if_io!(cursors.index_cursor.rowid()); + if let Some(rowid) = rowid { + return_if_io!(cursors + .table_cursor + .seek(SeekKey::TableRowId(rowid), SeekOp::GE { eq_only: true })); + + let table_record = return_if_io!(cursors.table_cursor.record()); + if let Some(rec) = table_record { + let table_values = rec.get_values(); + // Table format: [storage_id, zset_id, element_id, value_blob, weight] + if table_values.len() >= 5 { + // Deserialize the row from the blob + let value_at_3 = table_values[3].to_owned(); + let blob = match value_at_3 { + Value::Blob(ref b) => b, + _ => return Ok(IOResult::Done(None)), + }; + + // The blob contains the serialized HashableRow + // For now, let's deserialize it simply + let row = deserialize_hashable_row(blob)?; + + let weight = match &table_values[4].to_owned() { + Value::Integer(w) => *w as isize, + _ => return Ok(IOResult::Done(None)), + }; + + return Ok(IOResult::Done(Some((element_id, row, weight)))); + } + } + } + Ok(IOResult::Done(None)) +} + +// Join-specific eval states +#[derive(Debug)] +pub enum JoinEvalState { + ProcessDeltaJoin { + deltas: DeltaPair, + output: Delta, + }, + ProcessLeftJoin { + deltas: DeltaPair, + output: Delta, + current_idx: usize, + last_row_scanned: i64, + }, + ProcessRightJoin { + deltas: DeltaPair, + output: Delta, + current_idx: usize, + last_row_scanned: i64, + }, + Done { + output: Delta, + }, +} + +impl JoinEvalState { + fn combine_rows( + left_row: &HashableRow, + left_weight: i64, + right_row: &HashableRow, + right_weight: i64, + output: &mut Delta, + ) { + // Combine the rows + let mut combined_values = left_row.values.clone(); + combined_values.extend(right_row.values.clone()); + // Use hash of the combined values as rowid to ensure uniqueness + let temp_row = HashableRow::new(0, combined_values.clone()); + let joined_rowid = temp_row.cached_hash() as i64; + let joined_row = HashableRow::new(joined_rowid, combined_values); + + // Add to output with combined weight + let combined_weight = left_weight * right_weight; + output.changes.push((joined_row, combined_weight as isize)); + } + + fn process_join_state( + &mut self, + cursors: &mut DbspStateCursors, + left_key_indices: &[usize], + right_key_indices: &[usize], + left_storage_id: i64, + right_storage_id: i64, + ) -> Result> { + loop { + match self { + JoinEvalState::ProcessDeltaJoin { deltas, output } => { + // Move to ProcessLeftJoin + *self = JoinEvalState::ProcessLeftJoin { + deltas: std::mem::take(deltas), + output: std::mem::take(output), + current_idx: 0, + last_row_scanned: i64::MIN, + }; + } + JoinEvalState::ProcessLeftJoin { + deltas, + output, + current_idx, + last_row_scanned, + } => { + if *current_idx >= deltas.left.changes.len() { + *self = JoinEvalState::ProcessRightJoin { + deltas: std::mem::take(deltas), + output: std::mem::take(output), + current_idx: 0, + last_row_scanned: i64::MIN, + }; + } else { + let (left_row, left_weight) = &deltas.left.changes[*current_idx]; + // Extract join key using provided indices + let key_values: Vec = left_key_indices + .iter() + .map(|&idx| 
left_row.values.get(idx).cloned().unwrap_or(Value::Null)) + .collect(); + let left_key = HashableRow::new(0, key_values); + + let next_row = return_if_io!(read_next_join_row( + right_storage_id, + &left_key, + *last_row_scanned, + cursors + )); + match next_row { + Some((element_id, right_row, right_weight)) => { + Self::combine_rows( + left_row, + (*left_weight) as i64, + &right_row, + right_weight as i64, + output, + ); + // Continue scanning with this left row + *self = JoinEvalState::ProcessLeftJoin { + deltas: std::mem::take(deltas), + output: std::mem::take(output), + current_idx: *current_idx, + last_row_scanned: element_id, + }; + } + None => { + // No more matches for this left row, move to next + *self = JoinEvalState::ProcessLeftJoin { + deltas: std::mem::take(deltas), + output: std::mem::take(output), + current_idx: *current_idx + 1, + last_row_scanned: i64::MIN, + }; + } + } + } + } + JoinEvalState::ProcessRightJoin { + deltas, + output, + current_idx, + last_row_scanned, + } => { + if *current_idx >= deltas.right.changes.len() { + *self = JoinEvalState::Done { + output: std::mem::take(output), + }; + } else { + let (right_row, right_weight) = &deltas.right.changes[*current_idx]; + // Extract join key using provided indices + let key_values: Vec = right_key_indices + .iter() + .map(|&idx| right_row.values.get(idx).cloned().unwrap_or(Value::Null)) + .collect(); + let right_key = HashableRow::new(0, key_values); + + let next_row = return_if_io!(read_next_join_row( + left_storage_id, + &right_key, + *last_row_scanned, + cursors + )); + match next_row { + Some((element_id, left_row, left_weight)) => { + Self::combine_rows( + &left_row, + left_weight as i64, + right_row, + (*right_weight) as i64, + output, + ); + // Continue scanning with this right row + *self = JoinEvalState::ProcessRightJoin { + deltas: std::mem::take(deltas), + output: std::mem::take(output), + current_idx: *current_idx, + last_row_scanned: element_id, + }; + } + None => { + // No more matches for this right row, move to next + *self = JoinEvalState::ProcessRightJoin { + deltas: std::mem::take(deltas), + output: std::mem::take(output), + current_idx: *current_idx + 1, + last_row_scanned: i64::MIN, + }; + } + } + } + } + JoinEvalState::Done { output } => { + return Ok(IOResult::Done(std::mem::take(output))); + } + } + } + } +} + +// Generic eval state that delegates to operator-specific states +#[derive(Debug)] +pub enum EvalState { + Uninitialized, + Init { deltas: DeltaPair }, + Aggregate(Box), + Join(Box), Done, } @@ -190,23 +478,26 @@ impl EvalState { } } - fn advance(&mut self, groups_to_read: BTreeMap>) { + fn advance_aggregate(&mut self, groups_to_read: BTreeMap>) { let delta = match self { EvalState::Init { deltas } => std::mem::take(&mut deltas.left), - _ => panic!("advance() can only be called when in Init state, current state: {self:?}"), + _ => panic!("advance_aggregate() can only be called when in Init state, current state: {self:?}"), }; let _ = std::mem::replace( self, - EvalState::FetchKey { + EvalState::Aggregate(Box::new(AggregateEvalState::FetchKey { delta, current_idx: 0, groups_to_read: groups_to_read.into_iter().collect(), // Convert BTreeMap to Vec existing_groups: HashMap::new(), old_values: HashMap::new(), - }, + })), ); } +} + +impl AggregateEvalState { fn process_delta( &mut self, operator: &mut AggregateOperator, @@ -214,13 +505,7 @@ impl EvalState { ) -> Result> { loop { match self { - EvalState::Uninitialized => { - panic!("Cannot process_delta with Uninitialized state"); - } - 
EvalState::Init { .. } => { - panic!("State machine not supposed to reach the init state! advance() should have been called"); - } - EvalState::FetchKey { + AggregateEvalState::FetchKey { delta, current_idx, groups_to_read, @@ -238,7 +523,7 @@ impl EvalState { operator, )); - *self = EvalState::RecomputeMinMax { + *self = AggregateEvalState::RecomputeMinMax { delta: std::mem::take(delta), existing_groups: std::mem::take(existing_groups), old_values: std::mem::take(old_values), @@ -284,7 +569,7 @@ impl EvalState { // Always transition to FetchData let taken_existing = std::mem::take(existing_groups); let taken_old_values = std::mem::take(old_values); - let next_state = EvalState::FetchData { + let next_state = AggregateEvalState::FetchData { delta: std::mem::take(delta), current_idx: *current_idx, groups_to_read: std::mem::take(groups_to_read), @@ -296,7 +581,7 @@ impl EvalState { *self = next_state; } } - EvalState::FetchData { + AggregateEvalState::FetchData { delta, current_idx, groups_to_read, @@ -332,7 +617,7 @@ impl EvalState { let next_idx = *current_idx + 1; let taken_existing = std::mem::take(existing_groups); let taken_old_values = std::mem::take(old_values); - let next_state = EvalState::FetchKey { + let next_state = AggregateEvalState::FetchKey { delta: std::mem::take(delta), current_idx: next_idx, groups_to_read: std::mem::take(groups_to_read), @@ -341,7 +626,7 @@ impl EvalState { }; *self = next_state; } - EvalState::RecomputeMinMax { + AggregateEvalState::RecomputeMinMax { delta, existing_groups, old_values, @@ -356,11 +641,12 @@ impl EvalState { let (output_delta, computed_states) = operator.merge_delta_with_existing(delta, existing_groups, old_values); - *self = EvalState::Done; - return Ok(IOResult::Done((output_delta, computed_states))); + *self = AggregateEvalState::Done { + output: (output_delta, computed_states), + }; } - EvalState::Done => { - return Ok(IOResult::Done((Delta::new(), HashMap::new()))); + AggregateEvalState::Done { output } => { + return Ok(IOResult::Done(output.clone())); } } } @@ -646,6 +932,8 @@ pub enum JoinType { Inner, Left, Right, + Full, + Cross, } #[derive(Debug, Clone, PartialEq)] @@ -1877,21 +2165,27 @@ impl AggregateOperator { let group_key_str = Self::group_key_to_string(&group_key); groups_to_read.insert(group_key_str, group_key); } - state.advance(groups_to_read); + state.advance_aggregate(groups_to_read); } - EvalState::FetchKey { .. } - | EvalState::FetchData { .. } - | EvalState::RecomputeMinMax { .. } => { - // Already in progress, continue processing on process_delta below. + EvalState::Aggregate(_agg_state) => { + // Already in progress, continue processing below. } EvalState::Done => { panic!("unreachable state! 
should have returned"); } + EvalState::Join(_) => { + panic!("Join state should not appear in aggregate operator"); + } } - // Process the delta through the state machine - let result = return_if_io!(state.process_delta(self, cursors)); - Ok(IOResult::Done(result)) + // Process the delta through the aggregate state machine + match state { + EvalState::Aggregate(agg_state) => { + let result = return_if_io!(agg_state.process_delta(self, cursors)); + Ok(IOResult::Done(result)) + } + _ => panic!("Invalid state for aggregate processing"), + } } fn merge_delta_with_existing( @@ -2228,6 +2522,493 @@ impl IncrementalOperator for AggregateOperator { } } +#[derive(Debug)] +enum JoinCommitState { + Idle, + Eval { + eval_state: EvalState, + }, + CommitLeftDelta { + deltas: DeltaPair, + output: Delta, + current_idx: usize, + write_row: WriteRow, + }, + CommitRightDelta { + deltas: DeltaPair, + output: Delta, + current_idx: usize, + write_row: WriteRow, + }, + Invalid, +} + +/// Join operator - performs incremental join between two relations +/// Implements the DBSP formula: δ(R ⋈ S) = (δR ⋈ S) ∪ (R ⋈ δS) ∪ (δR ⋈ δS) +#[derive(Debug)] +pub struct JoinOperator { + /// Unique operator ID for indexing in persistent storage + operator_id: usize, + /// Type of join to perform + join_type: JoinType, + /// Column indices for extracting join keys from left input + left_key_indices: Vec, + /// Column indices for extracting join keys from right input + right_key_indices: Vec, + /// Column names from left input + left_columns: Vec, + /// Column names from right input + right_columns: Vec, + /// Tracker for computation statistics + tracker: Option>>, + + commit_state: JoinCommitState, +} + +impl JoinOperator { + pub fn new( + operator_id: usize, + join_type: JoinType, + left_key_indices: Vec, + right_key_indices: Vec, + left_columns: Vec, + right_columns: Vec, + ) -> Result { + // Check for unsupported join types + match join_type { + JoinType::Left => { + return Err(crate::LimboError::ParseError( + "LEFT OUTER JOIN is not yet supported in incremental views".to_string(), + )) + } + JoinType::Right => { + return Err(crate::LimboError::ParseError( + "RIGHT OUTER JOIN is not yet supported in incremental views".to_string(), + )) + } + JoinType::Full => { + return Err(crate::LimboError::ParseError( + "FULL OUTER JOIN is not yet supported in incremental views".to_string(), + )) + } + JoinType::Cross => { + return Err(crate::LimboError::ParseError( + "CROSS JOIN is not yet supported in incremental views".to_string(), + )) + } + JoinType::Inner => {} // Inner join is supported + } + + Ok(Self { + operator_id, + join_type, + left_key_indices, + right_key_indices, + left_columns, + right_columns, + tracker: None, + commit_state: JoinCommitState::Idle, + }) + } + + /// Extract join key from row values using the specified indices + fn extract_join_key(&self, values: &[Value], indices: &[usize]) -> HashableRow { + let key_values: Vec = indices + .iter() + .map(|&idx| values.get(idx).cloned().unwrap_or(Value::Null)) + .collect(); + // Use 0 as a dummy rowid for join keys. They don't come from a table, + // so they don't need a rowid. Their key will be the hash of the row values. 
+ HashableRow::new(0, key_values) + } + + /// Generate storage ID for left table + fn left_storage_id(&self) -> i64 { + // Use column_index=0 for left side + generate_storage_id(self.operator_id, 0, 0) + } + + /// Generate storage ID for right table + fn right_storage_id(&self) -> i64 { + // Use column_index=1 for right side + generate_storage_id(self.operator_id, 1, 0) + } + + /// SQL-compliant comparison for join keys + /// Returns true if keys match according to SQL semantics (NULL != NULL) + fn sql_keys_equal(left_key: &HashableRow, right_key: &HashableRow) -> bool { + if left_key.values.len() != right_key.values.len() { + return false; + } + + for (left_val, right_val) in left_key.values.iter().zip(right_key.values.iter()) { + // In SQL, NULL never equals NULL + if matches!(left_val, Value::Null) || matches!(right_val, Value::Null) { + return false; + } + + // For non-NULL values, use regular comparison + if left_val != right_val { + return false; + } + } + + true + } + + fn process_join_state( + &mut self, + state: &mut EvalState, + cursors: &mut DbspStateCursors, + ) -> Result> { + // Get the join state out of the enum + match state { + EvalState::Join(js) => js.process_join_state( + cursors, + &self.left_key_indices, + &self.right_key_indices, + self.left_storage_id(), + self.right_storage_id(), + ), + _ => panic!("process_join_state called with non-join state"), + } + } + + fn eval_internal( + &mut self, + state: &mut EvalState, + cursors: &mut DbspStateCursors, + ) -> Result> { + loop { + let loop_state = std::mem::replace(state, EvalState::Uninitialized); + match loop_state { + EvalState::Uninitialized => { + panic!("Cannot eval JoinOperator with Uninitialized state"); + } + EvalState::Init { deltas } => { + let mut output = Delta::new(); + + // Component 3: δR ⋈ δS (left delta join right delta) + for (left_row, left_weight) in &deltas.left.changes { + let left_key = + self.extract_join_key(&left_row.values, &self.left_key_indices); + + for (right_row, right_weight) in &deltas.right.changes { + let right_key = + self.extract_join_key(&right_row.values, &self.right_key_indices); + + if Self::sql_keys_equal(&left_key, &right_key) { + if let Some(tracker) = &self.tracker { + tracker.lock().unwrap().record_join_lookup(); + } + + // Combine the rows + let mut combined_values = left_row.values.clone(); + combined_values.extend(right_row.values.clone()); + + // Create the joined row with a unique rowid + // Use hash of the combined values to ensure uniqueness + let temp_row = HashableRow::new(0, combined_values.clone()); + let joined_rowid = temp_row.cached_hash() as i64; + let joined_row = + HashableRow::new(joined_rowid, combined_values.clone()); + + // Add to output with combined weight + let combined_weight = left_weight * right_weight; + output.changes.push((joined_row, combined_weight)); + } + } + } + + *state = EvalState::Join(Box::new(JoinEvalState::ProcessDeltaJoin { + deltas, + output, + })); + } + EvalState::Join(join_state) => { + *state = EvalState::Join(join_state); + let output = return_if_io!(self.process_join_state(state, cursors)); + return Ok(IOResult::Done(output)); + } + EvalState::Done => { + return Ok(IOResult::Done(Delta::new())); + } + EvalState::Aggregate(_) => { + panic!("Aggregate state should not appear in join operator"); + } + } + } + } +} + +// Helper to deserialize a HashableRow from a blob +fn deserialize_hashable_row(blob: &[u8]) -> Result { + // Simple deserialization - this needs to match how we serialize in commit + // Format: [rowid:8 
bytes][num_values:4 bytes][values...] + if blob.len() < 12 { + return Err(crate::LimboError::InternalError( + "Invalid blob size".to_string(), + )); + } + + let rowid = i64::from_le_bytes(blob[0..8].try_into().unwrap()); + let num_values = u32::from_le_bytes(blob[8..12].try_into().unwrap()) as usize; + + let mut values = Vec::new(); + let mut offset = 12; + + for _ in 0..num_values { + if offset >= blob.len() { + break; + } + + let type_tag = blob[offset]; + offset += 1; + + match type_tag { + 0 => values.push(Value::Null), + 1 => { + if offset + 8 <= blob.len() { + let i = i64::from_le_bytes(blob[offset..offset + 8].try_into().unwrap()); + values.push(Value::Integer(i)); + offset += 8; + } + } + 2 => { + if offset + 8 <= blob.len() { + let f = f64::from_le_bytes(blob[offset..offset + 8].try_into().unwrap()); + values.push(Value::Float(f)); + offset += 8; + } + } + 3 => { + if offset + 4 <= blob.len() { + let len = + u32::from_le_bytes(blob[offset..offset + 4].try_into().unwrap()) as usize; + offset += 4; + if offset + len < blob.len() { + let text_bytes = blob[offset..offset + len].to_vec(); + offset += len; + let subtype = match blob[offset] { + 0 => crate::types::TextSubtype::Text, + 1 => crate::types::TextSubtype::Json, + _ => crate::types::TextSubtype::Text, + }; + offset += 1; + values.push(Value::Text(crate::types::Text { + value: text_bytes, + subtype, + })); + } + } + } + 4 => { + if offset + 4 <= blob.len() { + let len = + u32::from_le_bytes(blob[offset..offset + 4].try_into().unwrap()) as usize; + offset += 4; + if offset + len <= blob.len() { + let blob_data = blob[offset..offset + len].to_vec(); + values.push(Value::Blob(blob_data)); + offset += len; + } + } + } + _ => break, // Unknown type tag + } + } + + Ok(HashableRow::new(rowid, values)) +} + +// Helper to serialize a HashableRow to a blob +fn serialize_hashable_row(row: &HashableRow) -> Vec { + let mut blob = Vec::new(); + + // Write rowid + blob.extend_from_slice(&row.rowid.to_le_bytes()); + + // Write number of values + blob.extend_from_slice(&(row.values.len() as u32).to_le_bytes()); + + // Write each value directly with type tags (like AggregateState does) + for value in &row.values { + match value { + Value::Null => blob.push(0u8), + Value::Integer(i) => { + blob.push(1u8); + blob.extend_from_slice(&i.to_le_bytes()); + } + Value::Float(f) => { + blob.push(2u8); + blob.extend_from_slice(&f.to_le_bytes()); + } + Value::Text(s) => { + blob.push(3u8); + let bytes = &s.value; + blob.extend_from_slice(&(bytes.len() as u32).to_le_bytes()); + blob.extend_from_slice(bytes); + blob.push(s.subtype as u8); + } + Value::Blob(b) => { + blob.push(4u8); + blob.extend_from_slice(&(b.len() as u32).to_le_bytes()); + blob.extend_from_slice(b); + } + } + } + + blob +} + +impl IncrementalOperator for JoinOperator { + fn eval( + &mut self, + state: &mut EvalState, + cursors: &mut DbspStateCursors, + ) -> Result> { + let delta = return_if_io!(self.eval_internal(state, cursors)); + Ok(IOResult::Done(delta)) + } + + fn commit( + &mut self, + deltas: DeltaPair, + cursors: &mut DbspStateCursors, + ) -> Result> { + loop { + let mut state = std::mem::replace(&mut self.commit_state, JoinCommitState::Invalid); + match &mut state { + JoinCommitState::Idle => { + self.commit_state = JoinCommitState::Eval { + eval_state: deltas.clone().into(), + } + } + JoinCommitState::Eval { ref mut eval_state } => { + let output = return_and_restore_if_io!( + &mut self.commit_state, + state, + self.eval(eval_state, cursors) + ); + self.commit_state = 
JoinCommitState::CommitLeftDelta { + deltas: deltas.clone(), + output, + current_idx: 0, + write_row: WriteRow::new(), + }; + } + JoinCommitState::CommitLeftDelta { + deltas, + output, + current_idx, + ref mut write_row, + } => { + if *current_idx >= deltas.left.changes.len() { + self.commit_state = JoinCommitState::CommitRightDelta { + deltas: std::mem::take(deltas), + output: std::mem::take(output), + current_idx: 0, + write_row: WriteRow::new(), + }; + continue; + } + + let (row, weight) = &deltas.left.changes[*current_idx]; + // Extract join key from the left row + let join_key = self.extract_join_key(&row.values, &self.left_key_indices); + + // The index key: (storage_id, zset_id, element_id) + // zset_id is the hash of the join key, element_id is hash of the row + let storage_id = self.left_storage_id(); + let zset_id = join_key.cached_hash() as i64; + let element_id = row.cached_hash() as i64; + let index_key = vec![ + Value::Integer(storage_id), + Value::Integer(zset_id), + Value::Integer(element_id), + ]; + + // The record values: we'll store the serialized row as a blob + let row_blob = serialize_hashable_row(row); + let record_values = vec![ + Value::Integer(self.left_storage_id()), + Value::Integer(join_key.cached_hash() as i64), + Value::Integer(row.cached_hash() as i64), + Value::Blob(row_blob), + ]; + + // Use return_and_restore_if_io to handle I/O properly + return_and_restore_if_io!( + &mut self.commit_state, + state, + write_row.write_row(cursors, index_key, record_values, *weight) + ); + + self.commit_state = JoinCommitState::CommitLeftDelta { + deltas: deltas.clone(), + output: output.clone(), + current_idx: *current_idx + 1, + write_row: WriteRow::new(), + }; + } + JoinCommitState::CommitRightDelta { + deltas, + output, + current_idx, + ref mut write_row, + } => { + if *current_idx >= deltas.right.changes.len() { + // Reset to Idle state for next commit + self.commit_state = JoinCommitState::Idle; + return Ok(IOResult::Done(output.clone())); + } + + let (row, weight) = &deltas.right.changes[*current_idx]; + // Extract join key from the right row + let join_key = self.extract_join_key(&row.values, &self.right_key_indices); + + // The index key: (storage_id, zset_id, element_id) + let index_key = vec![ + Value::Integer(self.right_storage_id()), + Value::Integer(join_key.cached_hash() as i64), + Value::Integer(row.cached_hash() as i64), + ]; + + // The record values: we'll store the serialized row as a blob + let row_blob = serialize_hashable_row(row); + let record_values = vec![ + Value::Integer(self.right_storage_id()), + Value::Integer(join_key.cached_hash() as i64), + Value::Integer(row.cached_hash() as i64), + Value::Blob(row_blob), + ]; + + // Use return_and_restore_if_io to handle I/O properly + return_and_restore_if_io!( + &mut self.commit_state, + state, + write_row.write_row(cursors, index_key, record_values, *weight) + ); + + self.commit_state = JoinCommitState::CommitRightDelta { + deltas: std::mem::take(deltas), + output: std::mem::take(output), + current_idx: *current_idx + 1, + write_row: WriteRow::new(), + }; + } + JoinCommitState::Invalid => { + panic!("Invalid join commit state"); + } + } + } + } + + fn set_tracker(&mut self, tracker: Arc>) { + self.tracker = Some(tracker); + } +} + #[cfg(test)] mod tests { use super::*; @@ -4897,4 +5678,765 @@ mod tests { assert_eq!(row_ins.values[1], Value::Integer(150)); // New MAX(col2) assert_eq!(row_ins.values[2], Value::Integer(500)); // MIN(col3) unchanged } + + #[test] + fn test_join_operator_inner() { + // Test 
INNER JOIN with incremental updates + let (pager, table_page_id, index_page_id) = create_test_pager(); + let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_page_id, 10); + let index_def = create_dbsp_state_index(index_page_id); + let index_cursor = + BTreeCursor::new_index(None, pager.clone(), index_page_id, &index_def, 10); + let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); + + let mut join = JoinOperator::new( + 1, // operator_id + JoinType::Inner, + vec![0], // Join on first column + vec![0], + vec!["customer_id".to_string(), "amount".to_string()], + vec!["id".to_string(), "name".to_string()], + ) + .unwrap(); + + // FIRST COMMIT: Initialize with data + let mut left_delta = Delta::new(); + left_delta.insert(1, vec![Value::Integer(1), Value::Float(100.0)]); + left_delta.insert(2, vec![Value::Integer(2), Value::Float(200.0)]); + left_delta.insert(3, vec![Value::Integer(3), Value::Float(300.0)]); // No match initially + + let mut right_delta = Delta::new(); + right_delta.insert(1, vec![Value::Integer(1), Value::Text("Alice".into())]); + right_delta.insert(2, vec![Value::Integer(2), Value::Text("Bob".into())]); + right_delta.insert(4, vec![Value::Integer(4), Value::Text("David".into())]); // No match initially + + let delta_pair = DeltaPair::new(left_delta, right_delta); + let result = pager + .io + .block(|| join.commit(delta_pair.clone(), &mut cursors)) + .unwrap(); + + // Should have 2 matches (customer 1 and 2) + assert_eq!( + result.changes.len(), + 2, + "First commit should produce 2 matches" + ); + + let mut results: Vec<_> = result.changes.clone(); + results.sort_by_key(|r| r.0.values[0].clone()); + + assert_eq!(results[0].0.values[0], Value::Integer(1)); + assert_eq!(results[0].0.values[3], Value::Text("Alice".into())); + assert_eq!(results[1].0.values[0], Value::Integer(2)); + assert_eq!(results[1].0.values[3], Value::Text("Bob".into())); + + // SECOND COMMIT: Add incremental data that should join with persisted state + // Add a new left row that should match existing right row (customer 4) + let mut left_delta2 = Delta::new(); + left_delta2.insert(5, vec![Value::Integer(4), Value::Float(400.0)]); // Should match David from persisted state + + // Add a new right row that should match existing left row (customer 3) + let mut right_delta2 = Delta::new(); + right_delta2.insert(6, vec![Value::Integer(3), Value::Text("Charlie".into())]); // Should match customer 3 from persisted state + + let delta_pair2 = DeltaPair::new(left_delta2, right_delta2); + let result2 = pager + .io + .block(|| join.commit(delta_pair2.clone(), &mut cursors)) + .unwrap(); + + // The second commit should produce: + // 1. New left (customer_id=4) joins with persisted right (id=4, David) + // 2. Persisted left (customer_id=3) joins with new right (id=3, Charlie) + + assert_eq!( + result2.changes.len(), + 2, + "Second commit should produce 2 new matches from incremental join. 
Got: {:?}", + result2.changes + ); + + // Verify the incremental results + let mut results2: Vec<_> = result2.changes.clone(); + results2.sort_by_key(|r| r.0.values[0].clone()); + + // Check for customer 3 joined with Charlie (existing left + new right) + let charlie_match = results2 + .iter() + .find(|(row, _)| row.values[0] == Value::Integer(3)) + .expect("Should find customer 3 joined with new Charlie"); + assert_eq!(charlie_match.0.values[2], Value::Integer(3)); + assert_eq!(charlie_match.0.values[3], Value::Text("Charlie".into())); + + // Check for customer 4 joined with David (new left + existing right) + let david_match = results2 + .iter() + .find(|(row, _)| row.values[0] == Value::Integer(4)) + .expect("Should find new customer 4 joined with existing David"); + assert_eq!(david_match.0.values[0], Value::Integer(4)); + assert_eq!(david_match.0.values[3], Value::Text("David".into())); + } + + #[test] + fn test_join_operator_with_deletions() { + // Test INNER JOIN with deletions (negative weights) + let (pager, table_page_id, index_page_id) = create_test_pager(); + let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_page_id, 10); + let index_def = create_dbsp_state_index(index_page_id); + let index_cursor = + BTreeCursor::new_index(None, pager.clone(), index_page_id, &index_def, 10); + let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); + + let mut join = JoinOperator::new( + 1, // operator_id + JoinType::Inner, + vec![0], // Join on first column + vec![0], + vec!["customer_id".to_string(), "amount".to_string()], + vec!["id".to_string(), "name".to_string()], + ) + .unwrap(); + + // FIRST COMMIT: Add initial data + let mut left_delta = Delta::new(); + left_delta.insert(1, vec![Value::Integer(1), Value::Float(100.0)]); + left_delta.insert(2, vec![Value::Integer(2), Value::Float(200.0)]); + left_delta.insert(3, vec![Value::Integer(3), Value::Float(300.0)]); + + let mut right_delta = Delta::new(); + right_delta.insert(1, vec![Value::Integer(1), Value::Text("Alice".into())]); + right_delta.insert(2, vec![Value::Integer(2), Value::Text("Bob".into())]); + right_delta.insert(3, vec![Value::Integer(3), Value::Text("Charlie".into())]); + + let delta_pair = DeltaPair::new(left_delta, right_delta); + + let result = pager + .io + .block(|| join.commit(delta_pair.clone(), &mut cursors)) + .unwrap(); + + assert_eq!(result.changes.len(), 3, "Should have 3 initial joins"); + + // SECOND COMMIT: Delete customer 2 from left side + let mut left_delta2 = Delta::new(); + left_delta2.delete(2, vec![Value::Integer(2), Value::Float(200.0)]); + + let empty_right = Delta::new(); + let delta_pair2 = DeltaPair::new(left_delta2, empty_right); + + let result2 = pager + .io + .block(|| join.commit(delta_pair2.clone(), &mut cursors)) + .unwrap(); + + // Should produce 1 deletion (retraction) of the join for customer 2 + assert_eq!( + result2.changes.len(), + 1, + "Should produce 1 retraction for deleted customer 2" + ); + assert_eq!( + result2.changes[0].1, -1, + "Should have weight -1 for deletion" + ); + assert_eq!(result2.changes[0].0.values[0], Value::Integer(2)); + assert_eq!(result2.changes[0].0.values[3], Value::Text("Bob".into())); + + // THIRD COMMIT: Delete customer 3 from right side + let empty_left = Delta::new(); + let mut right_delta3 = Delta::new(); + right_delta3.delete(3, vec![Value::Integer(3), Value::Text("Charlie".into())]); + + let delta_pair3 = DeltaPair::new(empty_left, right_delta3); + + let result3 = pager + .io + .block(|| join.commit(delta_pair3.clone(), 
&mut cursors)) + .unwrap(); + + // Should produce 1 deletion (retraction) of the join for customer 3 + assert_eq!( + result3.changes.len(), + 1, + "Should produce 1 retraction for deleted customer 3" + ); + assert_eq!( + result3.changes[0].1, -1, + "Should have weight -1 for deletion" + ); + assert_eq!(result3.changes[0].0.values[0], Value::Integer(3)); + assert_eq!(result3.changes[0].0.values[2], Value::Integer(3)); + } + + #[test] + fn test_join_operator_one_to_many() { + // Test one-to-many relationship: one customer with multiple orders + let (pager, table_page_id, index_page_id) = create_test_pager(); + let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_page_id, 10); + let index_def = create_dbsp_state_index(index_page_id); + let index_cursor = + BTreeCursor::new_index(None, pager.clone(), index_page_id, &index_def, 10); + let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); + + let mut join = JoinOperator::new( + 1, // operator_id + JoinType::Inner, + vec![0], // Join on first column (customer_id for orders) + vec![0], // Join on first column (id for customers) + vec![ + "customer_id".to_string(), + "order_id".to_string(), + "amount".to_string(), + ], + vec!["id".to_string(), "name".to_string()], + ) + .unwrap(); + + // FIRST COMMIT: Add one customer + let left_delta = Delta::new(); // Empty orders initially + let mut right_delta = Delta::new(); + right_delta.insert(1, vec![Value::Integer(100), Value::Text("Alice".into())]); + + let delta_pair = DeltaPair::new(left_delta, right_delta); + let result = pager + .io + .block(|| join.commit(delta_pair.clone(), &mut cursors)) + .unwrap(); + + // No joins yet (customer exists but no orders) + assert_eq!( + result.changes.len(), + 0, + "Should have no joins with customer but no orders" + ); + + // SECOND COMMIT: Add multiple orders for the same customer + let mut left_delta2 = Delta::new(); + left_delta2.insert( + 1, + vec![ + Value::Integer(100), + Value::Integer(1001), + Value::Float(50.0), + ], + ); // order 1001 + left_delta2.insert( + 2, + vec![ + Value::Integer(100), + Value::Integer(1002), + Value::Float(75.0), + ], + ); // order 1002 + left_delta2.insert( + 3, + vec![ + Value::Integer(100), + Value::Integer(1003), + Value::Float(100.0), + ], + ); // order 1003 + + let right_delta2 = Delta::new(); // No new customers + + let delta_pair2 = DeltaPair::new(left_delta2, right_delta2); + let result2 = pager + .io + .block(|| join.commit(delta_pair2.clone(), &mut cursors)) + .unwrap(); + + // Should produce 3 joins (3 orders × 1 customer) + assert_eq!( + result2.changes.len(), + 3, + "Should produce 3 joins for 3 orders with same customer. 
Got: {:?}", + result2.changes + ); + + // Verify all three joins have the same customer but different orders + for (row, weight) in &result2.changes { + assert_eq!(*weight, 1, "Weight should be 1 for insertion"); + assert_eq!( + row.values[0], + Value::Integer(100), + "Customer ID should be 100" + ); + assert_eq!( + row.values[4], + Value::Text("Alice".into()), + "Customer name should be Alice" + ); + + // Check order IDs are different + let order_id = match &row.values[1] { + Value::Integer(id) => *id, + _ => panic!("Expected integer order ID"), + }; + assert!( + (1001..=1003).contains(&order_id), + "Order ID {order_id} should be between 1001 and 1003" + ); + } + + // THIRD COMMIT: Delete one order + let mut left_delta3 = Delta::new(); + left_delta3.delete( + 2, + vec![ + Value::Integer(100), + Value::Integer(1002), + Value::Float(75.0), + ], + ); + + let delta_pair3 = DeltaPair::new(left_delta3, Delta::new()); + let result3 = pager + .io + .block(|| join.commit(delta_pair3.clone(), &mut cursors)) + .unwrap(); + + // Should produce 1 retraction for the deleted order + assert_eq!(result3.changes.len(), 1, "Should produce 1 retraction"); + assert_eq!(result3.changes[0].1, -1, "Should be a deletion"); + assert_eq!( + result3.changes[0].0.values[1], + Value::Integer(1002), + "Should delete order 1002" + ); + } + + #[test] + fn test_join_operator_many_to_many() { + // Test many-to-many: multiple rows with same key on both sides + let (pager, table_page_id, index_page_id) = create_test_pager(); + let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_page_id, 10); + let index_def = create_dbsp_state_index(index_page_id); + let index_cursor = + BTreeCursor::new_index(None, pager.clone(), index_page_id, &index_def, 10); + let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); + + let mut join = JoinOperator::new( + 1, // operator_id + JoinType::Inner, + vec![0], // Join on category_id + vec![0], // Join on id + vec![ + "category_id".to_string(), + "product_name".to_string(), + "price".to_string(), + ], + vec!["id".to_string(), "category_name".to_string()], + ) + .unwrap(); + + // FIRST COMMIT: Add multiple products in same category + let mut left_delta = Delta::new(); + left_delta.insert( + 1, + vec![ + Value::Integer(10), + Value::Text("Laptop".into()), + Value::Float(1000.0), + ], + ); + left_delta.insert( + 2, + vec![ + Value::Integer(10), + Value::Text("Mouse".into()), + Value::Float(50.0), + ], + ); + left_delta.insert( + 3, + vec![ + Value::Integer(10), + Value::Text("Keyboard".into()), + Value::Float(100.0), + ], + ); + + // Add multiple categories with same ID (simulating denormalized data or versioning) + let mut right_delta = Delta::new(); + right_delta.insert( + 1, + vec![Value::Integer(10), Value::Text("Electronics".into())], + ); + right_delta.insert(2, vec![Value::Integer(10), Value::Text("Computers".into())]); // Same category ID, different name + + let delta_pair = DeltaPair::new(left_delta, right_delta); + let result = pager + .io + .block(|| join.commit(delta_pair.clone(), &mut cursors)) + .unwrap(); + + // Should produce 3 products × 2 categories = 6 joins + assert_eq!( + result.changes.len(), + 6, + "Should produce 6 joins (3 products × 2 category records). 
Got: {:?}", + result.changes + ); + + // Verify we have all combinations + let mut found_combinations = std::collections::HashSet::new(); + for (row, weight) in &result.changes { + assert_eq!(*weight, 1); + let product = row.values[1].to_string(); + let category = row.values[4].to_string(); + found_combinations.insert((product, category)); + } + + assert_eq!( + found_combinations.len(), + 6, + "Should have 6 unique combinations" + ); + + // SECOND COMMIT: Add one more product in the same category + let mut left_delta2 = Delta::new(); + left_delta2.insert( + 4, + vec![ + Value::Integer(10), + Value::Text("Monitor".into()), + Value::Float(500.0), + ], + ); + + let delta_pair2 = DeltaPair::new(left_delta2, Delta::new()); + let result2 = pager + .io + .block(|| join.commit(delta_pair2.clone(), &mut cursors)) + .unwrap(); + + // New product should join with both existing category records + assert_eq!( + result2.changes.len(), + 2, + "New product should join with 2 existing category records" + ); + + for (row, _) in &result2.changes { + assert_eq!(row.values[1], Value::Text("Monitor".into())); + } + } + + #[test] + fn test_join_operator_update_in_one_to_many() { + // Test updates in one-to-many scenarios + let (pager, table_page_id, index_page_id) = create_test_pager(); + let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_page_id, 10); + let index_def = create_dbsp_state_index(index_page_id); + let index_cursor = + BTreeCursor::new_index(None, pager.clone(), index_page_id, &index_def, 10); + let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); + + let mut join = JoinOperator::new( + 1, // operator_id + JoinType::Inner, + vec![0], // Join on customer_id + vec![0], // Join on id + vec![ + "customer_id".to_string(), + "order_id".to_string(), + "amount".to_string(), + ], + vec!["id".to_string(), "name".to_string()], + ) + .unwrap(); + + // FIRST COMMIT: Setup one customer with multiple orders + let mut left_delta = Delta::new(); + left_delta.insert( + 1, + vec![ + Value::Integer(100), + Value::Integer(1001), + Value::Float(50.0), + ], + ); + left_delta.insert( + 2, + vec![ + Value::Integer(100), + Value::Integer(1002), + Value::Float(75.0), + ], + ); + left_delta.insert( + 3, + vec![ + Value::Integer(100), + Value::Integer(1003), + Value::Float(100.0), + ], + ); + + let mut right_delta = Delta::new(); + right_delta.insert(1, vec![Value::Integer(100), Value::Text("Alice".into())]); + + let delta_pair = DeltaPair::new(left_delta, right_delta); + let result = pager + .io + .block(|| join.commit(delta_pair.clone(), &mut cursors)) + .unwrap(); + + assert_eq!(result.changes.len(), 3, "Should have 3 initial joins"); + + // SECOND COMMIT: Update the customer name (affects all 3 joins) + let mut right_delta2 = Delta::new(); + // Delete old customer record + right_delta2.delete(1, vec![Value::Integer(100), Value::Text("Alice".into())]); + // Insert updated customer record + right_delta2.insert( + 1, + vec![Value::Integer(100), Value::Text("Alice Smith".into())], + ); + + let delta_pair2 = DeltaPair::new(Delta::new(), right_delta2); + let result2 = pager + .io + .block(|| join.commit(delta_pair2.clone(), &mut cursors)) + .unwrap(); + + // Should produce 3 deletions and 3 insertions (one for each order) + assert_eq!(result2.changes.len(), 6, + "Should produce 6 changes (3 deletions + 3 insertions) when updating customer with 3 orders"); + + let deletions: Vec<_> = result2.changes.iter().filter(|(_, w)| *w == -1).collect(); + let insertions: Vec<_> = 
result2.changes.iter().filter(|(_, w)| *w == 1).collect(); + + assert_eq!(deletions.len(), 3, "Should have 3 deletions"); + assert_eq!(insertions.len(), 3, "Should have 3 insertions"); + + // Check all deletions have old name + for (row, _) in &deletions { + assert_eq!( + row.values[4], + Value::Text("Alice".into()), + "Deletions should have old name" + ); + } + + // Check all insertions have new name + for (row, _) in &insertions { + assert_eq!( + row.values[4], + Value::Text("Alice Smith".into()), + "Insertions should have new name" + ); + } + + // Verify we still have all three order IDs in the insertions + let mut order_ids = std::collections::HashSet::new(); + for (row, _) in &insertions { + if let Value::Integer(order_id) = &row.values[1] { + order_ids.insert(*order_id); + } + } + assert_eq!( + order_ids.len(), + 3, + "Should still have all 3 order IDs after update" + ); + assert!(order_ids.contains(&1001)); + assert!(order_ids.contains(&1002)); + assert!(order_ids.contains(&1003)); + } + + #[test] + fn test_join_operator_weight_accumulation_complex() { + // Test complex weight accumulation with multiple identical rows + let (pager, table_page_id, index_page_id) = create_test_pager(); + let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_page_id, 10); + let index_def = create_dbsp_state_index(index_page_id); + let index_cursor = + BTreeCursor::new_index(None, pager.clone(), index_page_id, &index_def, 10); + let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); + + let mut join = JoinOperator::new( + 1, // operator_id + JoinType::Inner, + vec![0], // Join on first column + vec![0], + vec!["key".to_string(), "val_left".to_string()], + vec!["key".to_string(), "val_right".to_string()], + ) + .unwrap(); + + // FIRST COMMIT: Add identical rows multiple times (simulating duplicates) + let mut left_delta = Delta::new(); + // Same key-value pair inserted 3 times with different rowids + left_delta.insert(1, vec![Value::Integer(10), Value::Text("A".into())]); + left_delta.insert(2, vec![Value::Integer(10), Value::Text("A".into())]); + left_delta.insert(3, vec![Value::Integer(10), Value::Text("A".into())]); + + let mut right_delta = Delta::new(); + // Same key-value pair inserted 2 times + right_delta.insert(4, vec![Value::Integer(10), Value::Text("B".into())]); + right_delta.insert(5, vec![Value::Integer(10), Value::Text("B".into())]); + + let delta_pair = DeltaPair::new(left_delta, right_delta); + let result = pager + .io + .block(|| join.commit(delta_pair.clone(), &mut cursors)) + .unwrap(); + + // Should produce 3 × 2 = 6 join results (cartesian product) + assert_eq!( + result.changes.len(), + 6, + "Should produce 6 joins (3 left rows × 2 right rows)" + ); + + // All should have weight 1 + for (_, weight) in &result.changes { + assert_eq!(*weight, 1); + } + + // SECOND COMMIT: Delete one instance from left + let mut left_delta2 = Delta::new(); + left_delta2.delete(2, vec![Value::Integer(10), Value::Text("A".into())]); + + let delta_pair2 = DeltaPair::new(left_delta2, Delta::new()); + let result2 = pager + .io + .block(|| join.commit(delta_pair2.clone(), &mut cursors)) + .unwrap(); + + // Should produce 2 retractions (1 deleted left row × 2 right rows) + assert_eq!( + result2.changes.len(), + 2, + "Should produce 2 retractions when deleting 1 of 3 identical left rows" + ); + + for (_, weight) in &result2.changes { + assert_eq!(*weight, -1, "Should be retractions"); + } + } + + #[test] + fn test_join_produces_all_expected_results() { + // Test that a join 
produces ALL expected output rows + // This reproduces the issue where only 1 of 3 expected rows appears in the final result + + // Create a join operator similar to: SELECT u.name, o.quantity FROM users u JOIN orders o ON u.id = o.user_id + let mut join = JoinOperator::new( + 0, + JoinType::Inner, + vec![0], // Join on first column (id) + vec![0], // Join on first column (user_id) + vec!["id".to_string(), "name".to_string()], + vec![ + "user_id".to_string(), + "product_id".to_string(), + "quantity".to_string(), + ], + ) + .unwrap(); + + // Create test data matching the example that fails: + // users: (1, 'Alice'), (2, 'Bob') + // orders: (1, 5), (1, 3), (2, 7) -- user_id, quantity + let left_delta = Delta { + changes: vec![ + ( + HashableRow::new(1, vec![Value::Integer(1), Value::Text(Text::from("Alice"))]), + 1, + ), + ( + HashableRow::new(2, vec![Value::Integer(2), Value::Text(Text::from("Bob"))]), + 1, + ), + ], + }; + + // Orders: Alice has 2 orders, Bob has 1 + let right_delta = Delta { + changes: vec![ + ( + HashableRow::new( + 1, + vec![Value::Integer(1), Value::Integer(100), Value::Integer(5)], + ), + 1, + ), + ( + HashableRow::new( + 2, + vec![Value::Integer(1), Value::Integer(101), Value::Integer(3)], + ), + 1, + ), + ( + HashableRow::new( + 3, + vec![Value::Integer(2), Value::Integer(100), Value::Integer(7)], + ), + 1, + ), + ], + }; + + // Evaluate the join + let delta_pair = DeltaPair::new(left_delta, right_delta); + let mut state = EvalState::Init { deltas: delta_pair }; + + let (pager, table_root, index_root) = create_test_pager(); + let table_cursor = BTreeCursor::new_table(None, pager.clone(), table_root, 5); + let index_def = create_dbsp_state_index(index_root); + let index_cursor = BTreeCursor::new_index(None, pager.clone(), index_root, &index_def, 4); + let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); + + let result = pager + .io + .block(|| join.eval(&mut state, &mut cursors)) + .unwrap(); + + // Should produce 3 results: Alice with 2 orders, Bob with 1 order + assert_eq!( + result.changes.len(), + 3, + "Should produce 3 joined rows (Alice×2 + Bob×1)" + ); + + // Verify the actual content of the results + let mut expected_results = std::collections::HashSet::new(); + // Expected: (Alice, 5), (Alice, 3), (Bob, 7) + expected_results.insert(("Alice".to_string(), 5)); + expected_results.insert(("Alice".to_string(), 3)); + expected_results.insert(("Bob".to_string(), 7)); + + let mut actual_results = std::collections::HashSet::new(); + for (row, weight) in &result.changes { + assert_eq!(*weight, 1, "All results should have weight 1"); + + // Extract name (column 1 from left) and quantity (column 3 from right) + let name = match &row.values[1] { + Value::Text(t) => t.as_str().to_string(), + _ => panic!("Expected text value for name"), + }; + let quantity = match &row.values[4] { + Value::Integer(q) => *q, + _ => panic!("Expected integer value for quantity"), + }; + + actual_results.insert((name, quantity)); + } + + assert_eq!( + expected_results, actual_results, + "Join should produce all expected results. Expected: {expected_results:?}, Got: {actual_results:?}", + ); + + // Also verify that rowids are unique (this is important for btree storage) + let mut seen_rowids = std::collections::HashSet::new(); + for (row, _) in &result.changes { + let was_new = seen_rowids.insert(row.rowid); + assert!(was_new, "Duplicate rowid found: {}. 
This would cause rows to overwrite each other in btree storage!", row.rowid); + } + } } From 9747d6c6b61d91211929d24cb82659ebf3ec0a73 Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Wed, 17 Sep 2025 10:45:12 -0500 Subject: [PATCH 13/34] move the input operator to its own file. The code is becoming impossible to reason about with everything in operator.rs --- core/incremental/input_operator.rs | 66 ++++++++++++++++++++++++++++++ core/incremental/mod.rs | 1 + core/incremental/operator.rs | 61 +-------------------------- 3 files changed, 69 insertions(+), 59 deletions(-) create mode 100644 core/incremental/input_operator.rs diff --git a/core/incremental/input_operator.rs b/core/incremental/input_operator.rs new file mode 100644 index 000000000..b9a6eeb01 --- /dev/null +++ b/core/incremental/input_operator.rs @@ -0,0 +1,66 @@ +// Input operator for DBSP-style incremental computation +// This operator serves as the entry point for data into the incremental computation pipeline + +use crate::incremental::dbsp::{Delta, DeltaPair}; +use crate::incremental::operator::{ + ComputationTracker, DbspStateCursors, EvalState, IncrementalOperator, +}; +use crate::types::IOResult; +use crate::Result; +use std::sync::{Arc, Mutex}; + +/// Input operator - source of data for the circuit +/// Represents base relations/tables that receive external updates +#[derive(Debug)] +pub struct InputOperator { + #[allow(dead_code)] + name: String, +} + +impl InputOperator { + pub fn new(name: String) -> Self { + Self { name } + } +} + +impl IncrementalOperator for InputOperator { + fn eval( + &mut self, + state: &mut EvalState, + _cursors: &mut DbspStateCursors, + ) -> Result> { + match state { + EvalState::Init { deltas } => { + // Input operators only use left_delta, right_delta must be empty + assert!( + deltas.right.is_empty(), + "InputOperator expects right_delta to be empty" + ); + let output = std::mem::take(&mut deltas.left); + *state = EvalState::Done; + Ok(IOResult::Done(output)) + } + _ => unreachable!( + "InputOperator doesn't execute the state machine. 
Should be in Init state" + ), + } + } + + fn commit( + &mut self, + deltas: DeltaPair, + _cursors: &mut DbspStateCursors, + ) -> Result> { + // Input operator only uses left delta, right must be empty + assert!( + deltas.right.is_empty(), + "InputOperator expects right delta to be empty in commit" + ); + // Input operator passes through the delta unchanged during commit + Ok(IOResult::Done(deltas.left)) + } + + fn set_tracker(&mut self, _tracker: Arc>) { + // Input operator doesn't need tracking + } +} diff --git a/core/incremental/mod.rs b/core/incremental/mod.rs index 2c69e050b..ef8cbfb29 100644 --- a/core/incremental/mod.rs +++ b/core/incremental/mod.rs @@ -2,6 +2,7 @@ pub mod compiler; pub mod cursor; pub mod dbsp; pub mod expr_compiler; +pub mod input_operator; pub mod operator; pub mod persistence; pub mod view; diff --git a/core/incremental/operator.rs b/core/incremental/operator.rs index 78d2ca60d..03aae440d 100644 --- a/core/incremental/operator.rs +++ b/core/incremental/operator.rs @@ -2,6 +2,8 @@ // Operator DAG for DBSP-style incremental computation // Based on Feldera DBSP design but adapted for Turso's architecture +pub use crate::incremental::input_operator::InputOperator; + use crate::function::{AggFunc, Func}; use crate::incremental::dbsp::{Delta, DeltaPair, HashableRow}; use crate::incremental::expr_compiler::CompiledExpression; @@ -1019,65 +1021,6 @@ pub trait IncrementalOperator: Debug { fn set_tracker(&mut self, tracker: Arc>); } -/// Input operator - passes through input data unchanged -/// This operator is used for input nodes in the circuit to provide a uniform interface -#[derive(Debug)] -pub struct InputOperator { - name: String, -} - -impl InputOperator { - pub fn new(name: String) -> Self { - Self { name } - } - - pub fn name(&self) -> &str { - &self.name - } -} - -impl IncrementalOperator for InputOperator { - fn eval( - &mut self, - state: &mut EvalState, - _cursors: &mut DbspStateCursors, - ) -> Result> { - match state { - EvalState::Init { deltas } => { - // Input operators only use left_delta, right_delta must be empty - assert!( - deltas.right.is_empty(), - "InputOperator expects right_delta to be empty" - ); - let output = std::mem::take(&mut deltas.left); - *state = EvalState::Done; - Ok(IOResult::Done(output)) - } - _ => unreachable!( - "InputOperator doesn't execute the state machine. Should be in Init state" - ), - } - } - - fn commit( - &mut self, - deltas: DeltaPair, - _cursors: &mut DbspStateCursors, - ) -> Result> { - // Input operator only uses left delta, right must be empty - assert!( - deltas.right.is_empty(), - "InputOperator expects right delta to be empty in commit" - ); - // Input operator passes through the delta unchanged during commit - Ok(IOResult::Done(deltas.left)) - } - - fn set_tracker(&mut self, _tracker: Arc>) { - // Input operator doesn't need tracking - } -} - /// Filter operator - filters rows based on predicate #[derive(Debug)] pub struct FilterOperator { From ee914fc54382f49e287e25ab3f956d24d9346ccf Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Wed, 17 Sep 2025 10:45:12 -0500 Subject: [PATCH 14/34] move the filter operator to its own file. 
The code is becoming impossible to reason about with everything in operator.rs --- core/incremental/filter_operator.rs | 325 ++++++++++++++++++++++++++++ core/incremental/mod.rs | 1 + core/incremental/operator.rs | 313 +-------------------------- 3 files changed, 327 insertions(+), 312 deletions(-) create mode 100644 core/incremental/filter_operator.rs diff --git a/core/incremental/filter_operator.rs b/core/incremental/filter_operator.rs new file mode 100644 index 000000000..f836f4897 --- /dev/null +++ b/core/incremental/filter_operator.rs @@ -0,0 +1,325 @@ +#![allow(dead_code)] +// Filter operator for DBSP-style incremental computation +// This operator filters rows based on predicates + +use crate::incremental::dbsp::{Delta, DeltaPair}; +use crate::incremental::operator::{ + ComputationTracker, DbspStateCursors, EvalState, IncrementalOperator, +}; +use crate::types::{IOResult, Text}; +use crate::{Result, Value}; +use std::sync::{Arc, Mutex}; +use turso_parser::ast::{Expr, Literal, OneSelect, Operator}; + +/// Filter predicate for filtering rows +#[derive(Debug, Clone)] +pub enum FilterPredicate { + /// Column = value + Equals { column: String, value: Value }, + /// Column != value + NotEquals { column: String, value: Value }, + /// Column > value + GreaterThan { column: String, value: Value }, + /// Column >= value + GreaterThanOrEqual { column: String, value: Value }, + /// Column < value + LessThan { column: String, value: Value }, + /// Column <= value + LessThanOrEqual { column: String, value: Value }, + /// Logical AND of two predicates + And(Box, Box), + /// Logical OR of two predicates + Or(Box, Box), + /// No predicate (accept all rows) + None, +} + +impl FilterPredicate { + /// Parse a SQL AST expression into a FilterPredicate + /// This centralizes all SQL-to-predicate parsing logic + pub fn from_sql_expr(expr: &turso_parser::ast::Expr) -> crate::Result { + let Expr::Binary(lhs, op, rhs) = expr else { + return Err(crate::LimboError::ParseError( + "Unsupported WHERE clause for incremental views: not a binary expression" + .to_string(), + )); + }; + + // Handle AND/OR logical operators + match op { + Operator::And => { + let left = Self::from_sql_expr(lhs)?; + let right = Self::from_sql_expr(rhs)?; + return Ok(FilterPredicate::And(Box::new(left), Box::new(right))); + } + Operator::Or => { + let left = Self::from_sql_expr(lhs)?; + let right = Self::from_sql_expr(rhs)?; + return Ok(FilterPredicate::Or(Box::new(left), Box::new(right))); + } + _ => {} + } + + // Handle comparison operators + let Expr::Id(column_name) = &**lhs else { + return Err(crate::LimboError::ParseError( + "Unsupported WHERE clause for incremental views: left-hand-side is not a column reference".to_string(), + )); + }; + + let column = column_name.as_str().to_string(); + + // Parse the right-hand side value + let value = match &**rhs { + Expr::Literal(Literal::String(s)) => { + // Strip quotes from string literals + let cleaned = s.trim_matches('\'').trim_matches('"'); + Value::Text(Text::new(cleaned)) + } + Expr::Literal(Literal::Numeric(n)) => { + // Try to parse as integer first, then float + if let Ok(i) = n.parse::() { + Value::Integer(i) + } else if let Ok(f) = n.parse::() { + Value::Float(f) + } else { + return Err(crate::LimboError::ParseError( + "Unsupported WHERE clause for incremental views: right-hand-side is not a numeric literal".to_string(), + )); + } + } + Expr::Literal(Literal::Null) => Value::Null, + Expr::Literal(Literal::Blob(_)) => { + // Blob comparison not yet supported + return 
Err(crate::LimboError::ParseError( + "Unsupported WHERE clause for incremental views: comparison with blob literals is not supported".to_string(), + )); + } + other => { + // Complex expressions not yet supported + return Err(crate::LimboError::ParseError( + format!("Unsupported WHERE clause for incremental views: comparison with {other:?} is not supported"), + )); + } + }; + + // Create the appropriate predicate based on operator + match op { + Operator::Equals => Ok(FilterPredicate::Equals { column, value }), + Operator::NotEquals => Ok(FilterPredicate::NotEquals { column, value }), + Operator::Greater => Ok(FilterPredicate::GreaterThan { column, value }), + Operator::GreaterEquals => Ok(FilterPredicate::GreaterThanOrEqual { column, value }), + Operator::Less => Ok(FilterPredicate::LessThan { column, value }), + Operator::LessEquals => Ok(FilterPredicate::LessThanOrEqual { column, value }), + other => Err(crate::LimboError::ParseError( + format!("Unsupported WHERE clause for incremental views: comparison operator {other:?} is not supported"), + )), + } + } + + /// Parse a WHERE clause from a SELECT statement + pub fn from_select(select: &turso_parser::ast::Select) -> crate::Result { + if let OneSelect::Select { + ref where_clause, .. + } = select.body.select + { + if let Some(where_clause) = where_clause { + Self::from_sql_expr(where_clause) + } else { + Ok(FilterPredicate::None) + } + } else { + Err(crate::LimboError::ParseError( + "Unsupported WHERE clause for incremental views: not a single SELECT statement" + .to_string(), + )) + } + } +} + +/// Filter operator - filters rows based on predicate +#[derive(Debug)] +pub struct FilterOperator { + predicate: FilterPredicate, + column_names: Vec, + tracker: Option>>, +} + +impl FilterOperator { + pub fn new(predicate: FilterPredicate, column_names: Vec) -> Self { + Self { + predicate, + column_names, + tracker: None, + } + } + + /// Get the predicate for this filter + pub fn predicate(&self) -> &FilterPredicate { + &self.predicate + } + + pub fn evaluate_predicate(&self, values: &[Value]) -> bool { + match &self.predicate { + FilterPredicate::None => true, + FilterPredicate::Equals { column, value } => { + if let Some(idx) = self.column_names.iter().position(|c| c == column) { + if let Some(v) = values.get(idx) { + return v == value; + } + } + false + } + FilterPredicate::NotEquals { column, value } => { + if let Some(idx) = self.column_names.iter().position(|c| c == column) { + if let Some(v) = values.get(idx) { + return v != value; + } + } + false + } + FilterPredicate::GreaterThan { column, value } => { + if let Some(idx) = self.column_names.iter().position(|c| c == column) { + if let Some(v) = values.get(idx) { + // Compare based on value types + match (v, value) { + (Value::Integer(a), Value::Integer(b)) => return a > b, + (Value::Float(a), Value::Float(b)) => return a > b, + (Value::Text(a), Value::Text(b)) => return a.as_str() > b.as_str(), + _ => {} + } + } + } + false + } + FilterPredicate::GreaterThanOrEqual { column, value } => { + if let Some(idx) = self.column_names.iter().position(|c| c == column) { + if let Some(v) = values.get(idx) { + match (v, value) { + (Value::Integer(a), Value::Integer(b)) => return a >= b, + (Value::Float(a), Value::Float(b)) => return a >= b, + (Value::Text(a), Value::Text(b)) => return a.as_str() >= b.as_str(), + _ => {} + } + } + } + false + } + FilterPredicate::LessThan { column, value } => { + if let Some(idx) = self.column_names.iter().position(|c| c == column) { + if let Some(v) = 
values.get(idx) { + match (v, value) { + (Value::Integer(a), Value::Integer(b)) => return a < b, + (Value::Float(a), Value::Float(b)) => return a < b, + (Value::Text(a), Value::Text(b)) => return a.as_str() < b.as_str(), + _ => {} + } + } + } + false + } + FilterPredicate::LessThanOrEqual { column, value } => { + if let Some(idx) = self.column_names.iter().position(|c| c == column) { + if let Some(v) = values.get(idx) { + match (v, value) { + (Value::Integer(a), Value::Integer(b)) => return a <= b, + (Value::Float(a), Value::Float(b)) => return a <= b, + (Value::Text(a), Value::Text(b)) => return a.as_str() <= b.as_str(), + _ => {} + } + } + } + false + } + FilterPredicate::And(left, right) => { + // Temporarily create sub-filters to evaluate + let left_filter = FilterOperator::new((**left).clone(), self.column_names.clone()); + let right_filter = + FilterOperator::new((**right).clone(), self.column_names.clone()); + left_filter.evaluate_predicate(values) && right_filter.evaluate_predicate(values) + } + FilterPredicate::Or(left, right) => { + let left_filter = FilterOperator::new((**left).clone(), self.column_names.clone()); + let right_filter = + FilterOperator::new((**right).clone(), self.column_names.clone()); + left_filter.evaluate_predicate(values) || right_filter.evaluate_predicate(values) + } + } + } +} + +impl IncrementalOperator for FilterOperator { + fn eval( + &mut self, + state: &mut EvalState, + _cursors: &mut DbspStateCursors, + ) -> Result> { + let delta = match state { + EvalState::Init { deltas } => { + // Filter operators only use left_delta, right_delta must be empty + assert!( + deltas.right.is_empty(), + "FilterOperator expects right_delta to be empty" + ); + std::mem::take(&mut deltas.left) + } + _ => unreachable!( + "FilterOperator doesn't execute the state machine. 
Should be in Init state" + ), + }; + + let mut output_delta = Delta::new(); + + // Process the delta through the filter + for (row, weight) in delta.changes { + if let Some(tracker) = &self.tracker { + tracker.lock().unwrap().record_filter(); + } + + // Only pass through rows that satisfy the filter predicate + // For deletes (weight < 0), we only pass them if the row values + // would have passed the filter (meaning it was in the view) + if self.evaluate_predicate(&row.values) { + output_delta.changes.push((row, weight)); + } + } + + *state = EvalState::Done; + Ok(IOResult::Done(output_delta)) + } + + fn commit( + &mut self, + deltas: DeltaPair, + _cursors: &mut DbspStateCursors, + ) -> Result> { + // Filter operator only uses left delta, right must be empty + assert!( + deltas.right.is_empty(), + "FilterOperator expects right delta to be empty in commit" + ); + + let mut output_delta = Delta::new(); + + // Commit the delta to our internal state + // Only pass through and track rows that satisfy the filter predicate + for (row, weight) in deltas.left.changes { + if let Some(tracker) = &self.tracker { + tracker.lock().unwrap().record_filter(); + } + + // Only track and output rows that pass the filter + // For deletes, this means the row was in the view (its values pass the filter) + // For inserts, this means the row should be in the view + if self.evaluate_predicate(&row.values) { + output_delta.changes.push((row, weight)); + } + } + + Ok(IOResult::Done(output_delta)) + } + + fn set_tracker(&mut self, tracker: Arc>) { + self.tracker = Some(tracker); + } +} diff --git a/core/incremental/mod.rs b/core/incremental/mod.rs index ef8cbfb29..8a6722370 100644 --- a/core/incremental/mod.rs +++ b/core/incremental/mod.rs @@ -2,6 +2,7 @@ pub mod compiler; pub mod cursor; pub mod dbsp; pub mod expr_compiler; +pub mod filter_operator; pub mod input_operator; pub mod operator; pub mod persistence; diff --git a/core/incremental/operator.rs b/core/incremental/operator.rs index 03aae440d..4374e27e5 100644 --- a/core/incremental/operator.rs +++ b/core/incremental/operator.rs @@ -2,6 +2,7 @@ // Operator DAG for DBSP-style incremental computation // Based on Feldera DBSP design but adapted for Turso's architecture +pub use crate::incremental::filter_operator::{FilterOperator, FilterPredicate}; pub use crate::incremental::input_operator::InputOperator; use crate::function::{AggFunc, Func}; @@ -794,131 +795,6 @@ pub enum QueryOperator { }, } -#[derive(Debug, Clone)] -pub enum FilterPredicate { - /// Column = value - Equals { column: String, value: Value }, - /// Column != value - NotEquals { column: String, value: Value }, - /// Column > value - GreaterThan { column: String, value: Value }, - /// Column >= value - GreaterThanOrEqual { column: String, value: Value }, - /// Column < value - LessThan { column: String, value: Value }, - /// Column <= value - LessThanOrEqual { column: String, value: Value }, - /// Logical AND of two predicates - And(Box, Box), - /// Logical OR of two predicates - Or(Box, Box), - /// No predicate (accept all rows) - None, -} - -impl FilterPredicate { - /// Parse a SQL AST expression into a FilterPredicate - /// This centralizes all SQL-to-predicate parsing logic - pub fn from_sql_expr(expr: &turso_parser::ast::Expr) -> crate::Result { - let Expr::Binary(lhs, op, rhs) = expr else { - return Err(crate::LimboError::ParseError( - "Unsupported WHERE clause for incremental views: not a binary expression" - .to_string(), - )); - }; - - // Handle AND/OR logical operators - match op { - 
Operator::And => { - let left = Self::from_sql_expr(lhs)?; - let right = Self::from_sql_expr(rhs)?; - return Ok(FilterPredicate::And(Box::new(left), Box::new(right))); - } - Operator::Or => { - let left = Self::from_sql_expr(lhs)?; - let right = Self::from_sql_expr(rhs)?; - return Ok(FilterPredicate::Or(Box::new(left), Box::new(right))); - } - _ => {} - } - - // Handle comparison operators - let Expr::Id(column_name) = &**lhs else { - return Err(crate::LimboError::ParseError( - "Unsupported WHERE clause for incremental views: left-hand-side is not a column reference".to_string(), - )); - }; - - let column = column_name.as_str().to_string(); - - // Parse the right-hand side value - let value = match &**rhs { - Expr::Literal(Literal::String(s)) => { - // Strip quotes from string literals - let cleaned = s.trim_matches('\'').trim_matches('"'); - Value::Text(Text::new(cleaned)) - } - Expr::Literal(Literal::Numeric(n)) => { - // Try to parse as integer first, then float - if let Ok(i) = n.parse::() { - Value::Integer(i) - } else if let Ok(f) = n.parse::() { - Value::Float(f) - } else { - return Err(crate::LimboError::ParseError( - "Unsupported WHERE clause for incremental views: right-hand-side is not a numeric literal".to_string(), - )); - } - } - Expr::Literal(Literal::Null) => Value::Null, - Expr::Literal(Literal::Blob(_)) => { - // Blob comparison not yet supported - return Err(crate::LimboError::ParseError( - "Unsupported WHERE clause for incremental views: comparison with blob literals is not supported".to_string(), - )); - } - other => { - // Complex expressions not yet supported - return Err(crate::LimboError::ParseError( - format!("Unsupported WHERE clause for incremental views: comparison with {other:?} is not supported"), - )); - } - }; - - // Create the appropriate predicate based on operator - match op { - Operator::Equals => Ok(FilterPredicate::Equals { column, value }), - Operator::NotEquals => Ok(FilterPredicate::NotEquals { column, value }), - Operator::Greater => Ok(FilterPredicate::GreaterThan { column, value }), - Operator::GreaterEquals => Ok(FilterPredicate::GreaterThanOrEqual { column, value }), - Operator::Less => Ok(FilterPredicate::LessThan { column, value }), - Operator::LessEquals => Ok(FilterPredicate::LessThanOrEqual { column, value }), - other => Err(crate::LimboError::ParseError( - format!("Unsupported WHERE clause for incremental views: comparison operator {other:?} is not supported"), - )), - } - } - - /// Parse a WHERE clause from a SELECT statement - pub fn from_select(select: &turso_parser::ast::Select) -> crate::Result { - if let OneSelect::Select { - ref where_clause, .. 
- } = select.body.select - { - if let Some(where_clause) = where_clause { - Self::from_sql_expr(where_clause) - } else { - Ok(FilterPredicate::None) - } - } else { - Err(crate::LimboError::ParseError( - "Unsupported WHERE clause for incremental views: not a single SELECT statement" - .to_string(), - )) - } - } -} - #[derive(Debug, Clone)] pub struct ProjectColumn { /// The original SQL expression (for debugging/fallback) @@ -1021,193 +897,6 @@ pub trait IncrementalOperator: Debug { fn set_tracker(&mut self, tracker: Arc>); } -/// Filter operator - filters rows based on predicate -#[derive(Debug)] -pub struct FilterOperator { - predicate: FilterPredicate, - column_names: Vec, - tracker: Option>>, -} - -impl FilterOperator { - pub fn new(predicate: FilterPredicate, column_names: Vec) -> Self { - Self { - predicate, - column_names, - tracker: None, - } - } - - /// Get the predicate for this filter - pub fn predicate(&self) -> &FilterPredicate { - &self.predicate - } - - pub fn evaluate_predicate(&self, values: &[Value]) -> bool { - match &self.predicate { - FilterPredicate::None => true, - FilterPredicate::Equals { column, value } => { - if let Some(idx) = self.column_names.iter().position(|c| c == column) { - if let Some(v) = values.get(idx) { - return v == value; - } - } - false - } - FilterPredicate::NotEquals { column, value } => { - if let Some(idx) = self.column_names.iter().position(|c| c == column) { - if let Some(v) = values.get(idx) { - return v != value; - } - } - false - } - FilterPredicate::GreaterThan { column, value } => { - if let Some(idx) = self.column_names.iter().position(|c| c == column) { - if let Some(v) = values.get(idx) { - // Compare based on value types - match (v, value) { - (Value::Integer(a), Value::Integer(b)) => return a > b, - (Value::Float(a), Value::Float(b)) => return a > b, - (Value::Text(a), Value::Text(b)) => return a.as_str() > b.as_str(), - _ => {} - } - } - } - false - } - FilterPredicate::GreaterThanOrEqual { column, value } => { - if let Some(idx) = self.column_names.iter().position(|c| c == column) { - if let Some(v) = values.get(idx) { - match (v, value) { - (Value::Integer(a), Value::Integer(b)) => return a >= b, - (Value::Float(a), Value::Float(b)) => return a >= b, - (Value::Text(a), Value::Text(b)) => return a.as_str() >= b.as_str(), - _ => {} - } - } - } - false - } - FilterPredicate::LessThan { column, value } => { - if let Some(idx) = self.column_names.iter().position(|c| c == column) { - if let Some(v) = values.get(idx) { - match (v, value) { - (Value::Integer(a), Value::Integer(b)) => return a < b, - (Value::Float(a), Value::Float(b)) => return a < b, - (Value::Text(a), Value::Text(b)) => return a.as_str() < b.as_str(), - _ => {} - } - } - } - false - } - FilterPredicate::LessThanOrEqual { column, value } => { - if let Some(idx) = self.column_names.iter().position(|c| c == column) { - if let Some(v) = values.get(idx) { - match (v, value) { - (Value::Integer(a), Value::Integer(b)) => return a <= b, - (Value::Float(a), Value::Float(b)) => return a <= b, - (Value::Text(a), Value::Text(b)) => return a.as_str() <= b.as_str(), - _ => {} - } - } - } - false - } - FilterPredicate::And(left, right) => { - // Temporarily create sub-filters to evaluate - let left_filter = FilterOperator::new((**left).clone(), self.column_names.clone()); - let right_filter = - FilterOperator::new((**right).clone(), self.column_names.clone()); - left_filter.evaluate_predicate(values) && right_filter.evaluate_predicate(values) - } - FilterPredicate::Or(left, right) => 
{ - let left_filter = FilterOperator::new((**left).clone(), self.column_names.clone()); - let right_filter = - FilterOperator::new((**right).clone(), self.column_names.clone()); - left_filter.evaluate_predicate(values) || right_filter.evaluate_predicate(values) - } - } - } -} - -impl IncrementalOperator for FilterOperator { - fn eval( - &mut self, - state: &mut EvalState, - _cursors: &mut DbspStateCursors, - ) -> Result> { - let delta = match state { - EvalState::Init { deltas } => { - // Filter operators only use left_delta, right_delta must be empty - assert!( - deltas.right.is_empty(), - "FilterOperator expects right_delta to be empty" - ); - std::mem::take(&mut deltas.left) - } - _ => unreachable!( - "FilterOperator doesn't execute the state machine. Should be in Init state" - ), - }; - - let mut output_delta = Delta::new(); - - // Process the delta through the filter - for (row, weight) in delta.changes { - if let Some(tracker) = &self.tracker { - tracker.lock().unwrap().record_filter(); - } - - // Only pass through rows that satisfy the filter predicate - // For deletes (weight < 0), we only pass them if the row values - // would have passed the filter (meaning it was in the view) - if self.evaluate_predicate(&row.values) { - output_delta.changes.push((row, weight)); - } - } - - *state = EvalState::Done; - Ok(IOResult::Done(output_delta)) - } - - fn commit( - &mut self, - deltas: DeltaPair, - _cursors: &mut DbspStateCursors, - ) -> Result> { - // Filter operator only uses left delta, right must be empty - assert!( - deltas.right.is_empty(), - "FilterOperator expects right delta to be empty in commit" - ); - - let mut output_delta = Delta::new(); - - // Commit the delta to our internal state - // Only pass through and track rows that satisfy the filter predicate - for (row, weight) in deltas.left.changes { - if let Some(tracker) = &self.tracker { - tracker.lock().unwrap().record_filter(); - } - - // Only track and output rows that pass the filter - // For deletes, this means the row was in the view (its values pass the filter) - // For inserts, this means the row should be in the view - if self.evaluate_predicate(&row.values) { - output_delta.changes.push((row, weight)); - } - } - - Ok(IOResult::Done(output_delta)) - } - - fn set_tracker(&mut self, tracker: Arc>) { - self.tracker = Some(tracker); - } -} - /// Project operator - selects/transforms columns #[derive(Clone)] pub struct ProjectOperator { From 7178d8d31c91a82310a746052bbdbe71b355fdd4 Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Wed, 17 Sep 2025 10:45:12 -0500 Subject: [PATCH 15/34] move the project operator to its own file. 
The code is becoming impossible to reason about with everything in operator.rs --- core/incremental/mod.rs | 1 + core/incremental/operator.rs | 424 +-------------------------- core/incremental/project_operator.rs | 168 +++++++++++ 3 files changed, 172 insertions(+), 421 deletions(-) create mode 100644 core/incremental/project_operator.rs diff --git a/core/incremental/mod.rs b/core/incremental/mod.rs index 8a6722370..0e45b3194 100644 --- a/core/incremental/mod.rs +++ b/core/incremental/mod.rs @@ -6,4 +6,5 @@ pub mod filter_operator; pub mod input_operator; pub mod operator; pub mod persistence; +pub mod project_operator; pub mod view; diff --git a/core/incremental/operator.rs b/core/incremental/operator.rs index 4374e27e5..43ad8f67c 100644 --- a/core/incremental/operator.rs +++ b/core/incremental/operator.rs @@ -4,22 +4,18 @@ pub use crate::incremental::filter_operator::{FilterOperator, FilterPredicate}; pub use crate::incremental::input_operator::InputOperator; +pub use crate::incremental::project_operator::{ProjectColumn, ProjectOperator}; use crate::function::{AggFunc, Func}; use crate::incremental::dbsp::{Delta, DeltaPair, HashableRow}; -use crate::incremental::expr_compiler::CompiledExpression; use crate::incremental::persistence::{MinMaxPersistState, ReadRecord, RecomputeMinMax, WriteRow}; use crate::schema::{Index, IndexColumn}; use crate::storage::btree::BTreeCursor; -use crate::types::{IOResult, ImmutableRecord, SeekKey, SeekOp, SeekResult, Text}; -use crate::{ - return_and_restore_if_io, return_if_io, Connection, Database, Result, SymbolTable, Value, -}; +use crate::types::{IOResult, ImmutableRecord, SeekKey, SeekOp, SeekResult}; +use crate::{return_and_restore_if_io, return_if_io, Result, Value}; use std::collections::{BTreeMap, HashMap}; use std::fmt::{self, Debug, Display}; use std::sync::{Arc, Mutex}; -use turso_macros::match_ignore_ascii_case; -use turso_parser::ast::{As, Expr, Literal, Name, OneSelect, Operator, ResultColumn}; /// Struct to hold both table and index cursors for DBSP state operations pub struct DbspStateCursors { @@ -795,16 +791,6 @@ pub enum QueryOperator { }, } -#[derive(Debug, Clone)] -pub struct ProjectColumn { - /// The original SQL expression (for debugging/fallback) - pub expr: turso_parser::ast::Expr, - /// Optional alias for the column - pub alias: Option, - /// Compiled expression (handles both trivial columns and complex expressions) - pub compiled: CompiledExpression, -} - #[derive(Debug, Clone)] pub enum JoinType { Inner, @@ -897,410 +883,6 @@ pub trait IncrementalOperator: Debug { fn set_tracker(&mut self, tracker: Arc>); } -/// Project operator - selects/transforms columns -#[derive(Clone)] -pub struct ProjectOperator { - columns: Vec, - input_column_names: Vec, - output_column_names: Vec, - tracker: Option>>, - // Internal in-memory connection for expression evaluation - // Programs are very dependent on having a connection, so give it one. - // - // We could in theory pass the current connection, but there are a host of problems with that. - // For example: during a write transaction, where views are usually updated, we have autocommit - // on. When the program we are executing calls Halt, it will try to commit the current - // transaction, which is absolutely incorrect. - // - // There are other ways to solve this, but a read-only connection to an empty in-memory - // database gives us the closest environment we need to execute expressions. 
- internal_conn: Arc, -} - -impl std::fmt::Debug for ProjectOperator { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("ProjectOperator") - .field("columns", &self.columns) - .field("input_column_names", &self.input_column_names) - .field("output_column_names", &self.output_column_names) - .field("tracker", &self.tracker) - .finish_non_exhaustive() - } -} - -impl ProjectOperator { - /// Create a new ProjectOperator from a SELECT statement, extracting projection columns - pub fn from_select( - select: &turso_parser::ast::Select, - input_column_names: Vec, - schema: &crate::schema::Schema, - ) -> crate::Result { - // Set up internal connection for expression evaluation - let io = Arc::new(crate::MemoryIO::new()); - let db = Database::open_file( - io, ":memory:", false, // no MVCC needed for expression evaluation - false, // no indexes needed - )?; - let internal_conn = db.connect()?; - // Set to read-only mode and disable auto-commit since we're only evaluating expressions - internal_conn.query_only.set(true); - internal_conn.auto_commit.set(false); - - let temp_syms = SymbolTable::new(); - - // Extract columns from SELECT statement - let columns = if let OneSelect::Select { - columns: ref select_columns, - .. - } = &select.body.select - { - let mut columns = Vec::new(); - for result_col in select_columns { - match result_col { - ResultColumn::Expr(expr, alias) => { - let alias_str = if let Some(As::As(alias_name)) = alias { - Some(alias_name.as_str().to_string()) - } else { - None - }; - // Try to compile the expression (handles both columns and complex expressions) - let compiled = CompiledExpression::compile( - expr, - &input_column_names, - schema, - &temp_syms, - internal_conn.clone(), - )?; - columns.push(ProjectColumn { - expr: (**expr).clone(), - alias: alias_str, - compiled, - }); - } - ResultColumn::Star => { - // Select all columns - create trivial column references - for name in &input_column_names { - // Create an Id expression for the column - let expr = Expr::Id(Name::Ident(name.clone())); - let compiled = CompiledExpression::compile( - &expr, - &input_column_names, - schema, - &temp_syms, - internal_conn.clone(), - )?; - columns.push(ProjectColumn { - expr, - alias: None, - compiled, - }); - } - } - x => { - return Err(crate::LimboError::ParseError(format!( - "Unsupported {x:?} clause when compiling project operator", - ))); - } - } - } - - if columns.is_empty() { - return Err(crate::LimboError::ParseError( - "No columns found when compiling project operator".to_string(), - )); - } - columns - } else { - return Err(crate::LimboError::ParseError( - "Expression is not a valid SELECT expression".to_string(), - )); - }; - - // Generate output column names based on aliases or expressions - let output_column_names = columns - .iter() - .map(|c| { - c.alias.clone().unwrap_or_else(|| match &c.expr { - Expr::Id(name) => name.as_str().to_string(), - Expr::Qualified(table, column) => { - format!("{}.{}", table.as_str(), column.as_str()) - } - Expr::DoublyQualified(db, table, column) => { - format!("{}.{}.{}", db.as_str(), table.as_str(), column.as_str()) - } - _ => c.expr.to_string(), - }) - }) - .collect(); - - Ok(Self { - columns, - input_column_names, - output_column_names, - tracker: None, - internal_conn, - }) - } - - /// Create a ProjectOperator from pre-compiled expressions - pub fn from_compiled( - compiled_exprs: Vec, - aliases: Vec>, - input_column_names: Vec, - output_column_names: Vec, - ) -> crate::Result { - // Set up internal 
connection for expression evaluation - let io = Arc::new(crate::MemoryIO::new()); - let db = Database::open_file( - io, ":memory:", false, // no MVCC needed for expression evaluation - false, // no indexes needed - )?; - let internal_conn = db.connect()?; - // Set to read-only mode and disable auto-commit since we're only evaluating expressions - internal_conn.query_only.set(true); - internal_conn.auto_commit.set(false); - - // Create ProjectColumn structs from compiled expressions - let columns: Vec = compiled_exprs - .into_iter() - .zip(aliases) - .map(|(compiled, alias)| ProjectColumn { - // Create a placeholder AST expression since we already have the compiled version - expr: turso_parser::ast::Expr::Literal(turso_parser::ast::Literal::Null), - alias, - compiled, - }) - .collect(); - - Ok(Self { - columns, - input_column_names, - output_column_names, - tracker: None, - internal_conn, - }) - } - - /// Get the columns for this projection - pub fn columns(&self) -> &[ProjectColumn] { - &self.columns - } - - fn project_values(&self, values: &[Value]) -> Vec { - let mut output = Vec::new(); - - for col in &self.columns { - // Use the internal connection's pager for expression evaluation - let internal_pager = self.internal_conn.pager.borrow().clone(); - - // Execute the compiled expression (handles both columns and complex expressions) - let result = col - .compiled - .execute(values, internal_pager) - .expect("Failed to execute compiled expression for the Project operator"); - output.push(result); - } - - output - } - - fn evaluate_expression(&self, expr: &turso_parser::ast::Expr, values: &[Value]) -> Value { - match expr { - Expr::Id(name) => { - if let Some(idx) = self - .input_column_names - .iter() - .position(|c| c == name.as_str()) - { - if let Some(v) = values.get(idx) { - return v.clone(); - } - } - Value::Null - } - Expr::Literal(lit) => { - match lit { - Literal::Numeric(n) => { - if let Ok(i) = n.parse::() { - Value::Integer(i) - } else if let Ok(f) = n.parse::() { - Value::Float(f) - } else { - Value::Null - } - } - Literal::String(s) => { - let cleaned = s.trim_matches('\'').trim_matches('"'); - Value::Text(Text::new(cleaned)) - } - Literal::Null => Value::Null, - Literal::Blob(_) - | Literal::Keyword(_) - | Literal::CurrentDate - | Literal::CurrentTime - | Literal::CurrentTimestamp => Value::Null, // Not supported yet - } - } - Expr::Binary(left, op, right) => { - let left_val = self.evaluate_expression(left, values); - let right_val = self.evaluate_expression(right, values); - - match op { - Operator::Add => match (&left_val, &right_val) { - (Value::Integer(a), Value::Integer(b)) => Value::Integer(a + b), - (Value::Float(a), Value::Float(b)) => Value::Float(a + b), - (Value::Integer(a), Value::Float(b)) => Value::Float(*a as f64 + b), - (Value::Float(a), Value::Integer(b)) => Value::Float(a + *b as f64), - _ => Value::Null, - }, - Operator::Subtract => match (&left_val, &right_val) { - (Value::Integer(a), Value::Integer(b)) => Value::Integer(a - b), - (Value::Float(a), Value::Float(b)) => Value::Float(a - b), - (Value::Integer(a), Value::Float(b)) => Value::Float(*a as f64 - b), - (Value::Float(a), Value::Integer(b)) => Value::Float(a - *b as f64), - _ => Value::Null, - }, - Operator::Multiply => match (&left_val, &right_val) { - (Value::Integer(a), Value::Integer(b)) => Value::Integer(a * b), - (Value::Float(a), Value::Float(b)) => Value::Float(a * b), - (Value::Integer(a), Value::Float(b)) => Value::Float(*a as f64 * b), - (Value::Float(a), Value::Integer(b)) => 
Value::Float(a * *b as f64), - _ => Value::Null, - }, - Operator::Divide => match (&left_val, &right_val) { - (Value::Integer(a), Value::Integer(b)) => { - if *b != 0 { - Value::Integer(a / b) - } else { - Value::Null - } - } - (Value::Float(a), Value::Float(b)) => { - if *b != 0.0 { - Value::Float(a / b) - } else { - Value::Null - } - } - (Value::Integer(a), Value::Float(b)) => { - if *b != 0.0 { - Value::Float(*a as f64 / b) - } else { - Value::Null - } - } - (Value::Float(a), Value::Integer(b)) => { - if *b != 0 { - Value::Float(a / *b as f64) - } else { - Value::Null - } - } - _ => Value::Null, - }, - _ => Value::Null, // Other operators not supported yet - } - } - Expr::FunctionCall { name, args, .. } => { - let name_bytes = name.as_str().as_bytes(); - match_ignore_ascii_case!(match name_bytes { - b"hex" => { - if args.len() == 1 { - let arg_val = self.evaluate_expression(&args[0], values); - match arg_val { - Value::Integer(i) => Value::Text(Text::new(&format!("{i:X}"))), - _ => Value::Null, - } - } else { - Value::Null - } - } - _ => Value::Null, // Other functions not supported yet - }) - } - Expr::Parenthesized(inner) => { - assert!( - inner.len() <= 1, - "Parenthesized expressions with multiple elements are not supported" - ); - if !inner.is_empty() { - self.evaluate_expression(&inner[0], values) - } else { - Value::Null - } - } - _ => Value::Null, // Other expression types not supported yet - } - } -} - -impl IncrementalOperator for ProjectOperator { - fn eval( - &mut self, - state: &mut EvalState, - _cursors: &mut DbspStateCursors, - ) -> Result> { - let delta = match state { - EvalState::Init { deltas } => { - // Project operators only use left_delta, right_delta must be empty - assert!( - deltas.right.is_empty(), - "ProjectOperator expects right_delta to be empty" - ); - std::mem::take(&mut deltas.left) - } - _ => unreachable!( - "ProjectOperator doesn't execute the state machine. 
Should be in Init state" - ), - }; - - let mut output_delta = Delta::new(); - - for (row, weight) in delta.changes { - if let Some(tracker) = &self.tracker { - tracker.lock().unwrap().record_project(); - } - - let projected = self.project_values(&row.values); - let projected_row = HashableRow::new(row.rowid, projected); - output_delta.changes.push((projected_row, weight)); - } - - *state = EvalState::Done; - Ok(IOResult::Done(output_delta)) - } - - fn commit( - &mut self, - deltas: DeltaPair, - _cursors: &mut DbspStateCursors, - ) -> Result> { - // Project operator only uses left delta, right must be empty - assert!( - deltas.right.is_empty(), - "ProjectOperator expects right delta to be empty in commit" - ); - - let mut output_delta = Delta::new(); - - // Commit the delta to our internal state and build output - for (row, weight) in &deltas.left.changes { - if let Some(tracker) = &self.tracker { - tracker.lock().unwrap().record_project(); - } - let projected = self.project_values(&row.values); - let projected_row = HashableRow::new(row.rowid, projected); - output_delta.changes.push((projected_row, *weight)); - } - - Ok(crate::types::IOResult::Done(output_delta)) - } - - fn set_tracker(&mut self, tracker: Arc>) { - self.tracker = Some(tracker); - } -} - /// Aggregate operator - performs incremental aggregation with GROUP BY /// Maintains running totals/counts that are updated incrementally /// diff --git a/core/incremental/project_operator.rs b/core/incremental/project_operator.rs new file mode 100644 index 000000000..b1d9fc9ed --- /dev/null +++ b/core/incremental/project_operator.rs @@ -0,0 +1,168 @@ +// Project operator for DBSP-style incremental computation +// This operator projects/transforms columns in a relational stream + +use crate::incremental::dbsp::{Delta, DeltaPair, HashableRow}; +use crate::incremental::expr_compiler::CompiledExpression; +use crate::incremental::operator::{ + ComputationTracker, DbspStateCursors, EvalState, IncrementalOperator, +}; +use crate::types::IOResult; +use crate::{Connection, Database, Result, Value}; +use std::sync::{Arc, Mutex}; + +#[derive(Debug, Clone)] +pub struct ProjectColumn { + /// Compiled expression (handles both trivial columns and complex expressions) + pub compiled: CompiledExpression, +} + +/// Project operator - selects/transforms columns +#[derive(Clone)] +pub struct ProjectOperator { + columns: Vec, + input_column_names: Vec, + output_column_names: Vec, + tracker: Option>>, + // Internal in-memory connection for expression evaluation + // Programs are very dependent on having a connection, so give it one. + // + // We could in theory pass the current connection, but there are a host of problems with that. + // For example: during a write transaction, where views are usually updated, we have autocommit + // on. When the program we are executing calls Halt, it will try to commit the current + // transaction, which is absolutely incorrect. + // + // There are other ways to solve this, but a read-only connection to an empty in-memory + // database gives us the closest environment we need to execute expressions. 
+ internal_conn: Arc, +} + +impl std::fmt::Debug for ProjectOperator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ProjectOperator") + .field("columns", &self.columns) + .field("input_column_names", &self.input_column_names) + .field("output_column_names", &self.output_column_names) + .finish() + } +} + +impl ProjectOperator { + /// Create a ProjectOperator from pre-compiled expressions + pub fn from_compiled( + compiled_exprs: Vec, + aliases: Vec>, + input_column_names: Vec, + output_column_names: Vec, + ) -> crate::Result { + // Set up internal connection for expression evaluation + let io = Arc::new(crate::MemoryIO::new()); + let db = Database::open_file( + io, ":memory:", false, // no MVCC needed for expression evaluation + false, // no indexes needed + )?; + let internal_conn = db.connect()?; + // Set to read-only mode and disable auto-commit since we're only evaluating expressions + internal_conn.query_only.set(true); + internal_conn.auto_commit.set(false); + + // Create ProjectColumn structs from compiled expressions + let columns: Vec = compiled_exprs + .into_iter() + .zip(aliases) + .map(|(compiled, _alias)| ProjectColumn { compiled }) + .collect(); + + Ok(Self { + columns, + input_column_names, + output_column_names, + tracker: None, + internal_conn, + }) + } + + fn project_values(&self, values: &[Value]) -> Vec { + let mut output = Vec::new(); + + for col in &self.columns { + // Use the internal connection's pager for expression evaluation + let internal_pager = self.internal_conn.pager.borrow().clone(); + + // Execute the compiled expression (handles both columns and complex expressions) + let result = col + .compiled + .execute(values, internal_pager) + .expect("Failed to execute compiled expression for the Project operator"); + output.push(result); + } + + output + } +} + +impl IncrementalOperator for ProjectOperator { + fn eval( + &mut self, + state: &mut EvalState, + _cursors: &mut DbspStateCursors, + ) -> Result> { + let delta = match state { + EvalState::Init { deltas } => { + // Project operators only use left_delta, right_delta must be empty + assert!( + deltas.right.is_empty(), + "ProjectOperator expects right_delta to be empty" + ); + std::mem::take(&mut deltas.left) + } + _ => unreachable!( + "ProjectOperator doesn't execute the state machine. 
Should be in Init state" + ), + }; + + let mut output_delta = Delta::new(); + + for (row, weight) in delta.changes { + if let Some(tracker) = &self.tracker { + tracker.lock().unwrap().record_project(); + } + + let projected = self.project_values(&row.values); + let projected_row = HashableRow::new(row.rowid, projected); + output_delta.changes.push((projected_row, weight)); + } + + *state = EvalState::Done; + Ok(IOResult::Done(output_delta)) + } + + fn commit( + &mut self, + deltas: DeltaPair, + _cursors: &mut DbspStateCursors, + ) -> Result> { + // Project operator only uses left delta, right must be empty + assert!( + deltas.right.is_empty(), + "ProjectOperator expects right delta to be empty in commit" + ); + + let mut output_delta = Delta::new(); + + // Commit the delta to our internal state and build output + for (row, weight) in &deltas.left.changes { + if let Some(tracker) = &self.tracker { + tracker.lock().unwrap().record_project(); + } + let projected = self.project_values(&row.values); + let projected_row = HashableRow::new(row.rowid, projected); + output_delta.changes.push((projected_row, *weight)); + } + + Ok(crate::types::IOResult::Done(output_delta)) + } + + fn set_tracker(&mut self, tracker: Arc>) { + self.tracker = Some(tracker); + } +} From aa8fcdbe54546cce13897507820ab4858293dbe2 Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Fri, 19 Sep 2025 03:58:35 -0500 Subject: [PATCH 16/34] move the aggregate operator to its own file. The code is becoming impossible to reason about with everything in operator.rs --- core/incremental/aggregate_operator.rs | 1787 ++++++++++++++++++++++++ core/incremental/mod.rs | 1 + core/incremental/operator.rs | 1151 +-------------- core/incremental/persistence.rs | 678 +-------- 4 files changed, 1796 insertions(+), 1821 deletions(-) create mode 100644 core/incremental/aggregate_operator.rs diff --git a/core/incremental/aggregate_operator.rs b/core/incremental/aggregate_operator.rs new file mode 100644 index 000000000..f4c8ece0a --- /dev/null +++ b/core/incremental/aggregate_operator.rs @@ -0,0 +1,1787 @@ +// Aggregate operator for DBSP-style incremental computation + +use crate::function::{AggFunc, Func}; +use crate::incremental::dbsp::{Delta, DeltaPair, HashableRow}; +use crate::incremental::operator::{ + generate_storage_id, ComputationTracker, DbspStateCursors, EvalState, IncrementalOperator, +}; +use crate::incremental::persistence::{ReadRecord, WriteRow}; +use crate::types::{IOResult, ImmutableRecord, RefValue, SeekKey, SeekOp, SeekResult}; +use crate::{return_and_restore_if_io, return_if_io, LimboError, Result, Value}; +use std::collections::{BTreeMap, HashMap, HashSet}; +use std::fmt::{self, Display}; +use std::sync::{Arc, Mutex}; + +/// Constants for aggregate type encoding in storage IDs (2 bits) +pub const AGG_TYPE_REGULAR: u8 = 0b00; // COUNT/SUM/AVG +pub const AGG_TYPE_MINMAX: u8 = 0b01; // MIN/MAX (BTree ordering gives both) + +#[derive(Debug, Clone, PartialEq)] +pub enum AggregateFunction { + Count, + Sum(String), + Avg(String), + Min(String), + Max(String), +} + +impl Display for AggregateFunction { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AggregateFunction::Count => write!(f, "COUNT(*)"), + AggregateFunction::Sum(col) => write!(f, "SUM({col})"), + AggregateFunction::Avg(col) => write!(f, "AVG({col})"), + AggregateFunction::Min(col) => write!(f, "MIN({col})"), + AggregateFunction::Max(col) => write!(f, "MAX({col})"), + } + } +} + +impl AggregateFunction { + /// Get the default output column 
name for this aggregate function + #[inline] + pub fn default_output_name(&self) -> String { + self.to_string() + } + + /// Create an AggregateFunction from a SQL function and its arguments + /// Returns None if the function is not a supported aggregate + pub fn from_sql_function( + func: &crate::function::Func, + input_column: Option, + ) -> Option { + match func { + Func::Agg(agg_func) => { + match agg_func { + AggFunc::Count | AggFunc::Count0 => Some(AggregateFunction::Count), + AggFunc::Sum => input_column.map(AggregateFunction::Sum), + AggFunc::Avg => input_column.map(AggregateFunction::Avg), + AggFunc::Min => input_column.map(AggregateFunction::Min), + AggFunc::Max => input_column.map(AggregateFunction::Max), + _ => None, // Other aggregate functions not yet supported in DBSP + } + } + _ => None, // Not an aggregate function + } + } +} + +/// Information about a column that has MIN/MAX aggregations +#[derive(Debug, Clone)] +pub struct AggColumnInfo { + /// Index used for storage key generation + pub index: usize, + /// Whether this column has a MIN aggregate + pub has_min: bool, + /// Whether this column has a MAX aggregate + pub has_max: bool, +} + +/// Serialize a Value using SQLite's serial type format +/// This is used for MIN/MAX values that need to be stored in a compact, sortable format +pub fn serialize_value(value: &Value, blob: &mut Vec) { + let serial_type = crate::types::SerialType::from(value); + let serial_type_u64: u64 = serial_type.into(); + crate::storage::sqlite3_ondisk::write_varint_to_vec(serial_type_u64, blob); + value.serialize_serial(blob); +} + +/// Deserialize a Value using SQLite's serial type format +/// Returns the deserialized value and the number of bytes consumed +pub fn deserialize_value(blob: &[u8]) -> Option<(Value, usize)> { + let mut cursor = 0; + + // Read the serial type + let (serial_type, varint_size) = crate::storage::sqlite3_ondisk::read_varint(blob).ok()?; + cursor += varint_size; + + let serial_type_obj = crate::types::SerialType::try_from(serial_type).ok()?; + let expected_size = serial_type_obj.size(); + + // Read the value + let (value, actual_size) = + crate::storage::sqlite3_ondisk::read_value(&blob[cursor..], serial_type_obj).ok()?; + + // Verify that the actual size matches what we expected from the serial type + if actual_size != expected_size { + return None; // Data corruption - size mismatch + } + + cursor += actual_size; + + // Convert RefValue to Value + Some((value.to_owned(), cursor)) +} + +// group_key_str -> (group_key, state) +type ComputedStates = HashMap, AggregateState)>; +// group_key_str -> (column_name, value_as_hashable_row) -> accumulated_weight +pub type MinMaxDeltas = HashMap>; + +#[derive(Debug)] +enum AggregateCommitState { + Idle, + Eval { + eval_state: EvalState, + }, + PersistDelta { + delta: Delta, + computed_states: ComputedStates, + current_idx: usize, + write_row: WriteRow, + min_max_deltas: MinMaxDeltas, + }, + PersistMinMax { + delta: Delta, + min_max_persist_state: MinMaxPersistState, + }, + Done { + delta: Delta, + }, + Invalid, +} + +// Aggregate-specific eval states +#[derive(Debug)] +pub enum AggregateEvalState { + FetchKey { + delta: Delta, // Keep original delta for merge operation + current_idx: usize, + groups_to_read: Vec<(String, Vec)>, // Changed to Vec for index-based access + existing_groups: HashMap, + old_values: HashMap>, + }, + FetchData { + delta: Delta, // Keep original delta for merge operation + current_idx: usize, + groups_to_read: Vec<(String, Vec)>, // Changed to Vec for 
index-based access + existing_groups: HashMap, + old_values: HashMap>, + rowid: Option, // Rowid found by FetchKey (None if not found) + read_record_state: Box, + }, + RecomputeMinMax { + delta: Delta, + existing_groups: HashMap, + old_values: HashMap>, + recompute_state: Box, + }, + Done { + output: (Delta, ComputedStates), + }, +} + +/// Note that the AggregateOperator essentially implements a ZSet, even +/// though the ZSet structure is never used explicitly. The on-disk btree +/// plays the role of the set! +#[derive(Debug)] +pub struct AggregateOperator { + // Unique operator ID for indexing in persistent storage + pub operator_id: usize, + // GROUP BY columns + group_by: Vec, + // Aggregate functions to compute (including MIN/MAX) + pub aggregates: Vec, + // Column names from input + pub input_column_names: Vec, + // Map from column name to aggregate info for quick lookup + pub column_min_max: HashMap, + tracker: Option>>, + + // State machine for commit operation + commit_state: AggregateCommitState, +} + +/// State for a single group's aggregates +#[derive(Debug, Clone, Default)] +pub struct AggregateState { + // For COUNT: just the count + pub count: i64, + // For SUM: column_name -> sum value + sums: HashMap, + // For AVG: column_name -> (sum, count) for computing average + avgs: HashMap, + // For MIN: column_name -> minimum value + pub mins: HashMap, + // For MAX: column_name -> maximum value + pub maxs: HashMap, +} + +impl AggregateEvalState { + fn process_delta( + &mut self, + operator: &mut AggregateOperator, + cursors: &mut DbspStateCursors, + ) -> Result> { + loop { + match self { + AggregateEvalState::FetchKey { + delta, + current_idx, + groups_to_read, + existing_groups, + old_values, + } => { + if *current_idx >= groups_to_read.len() { + // All groups have been fetched, move to RecomputeMinMax + // Extract MIN/MAX deltas from the input delta + let min_max_deltas = operator.extract_min_max_deltas(delta); + + let recompute_state = Box::new(RecomputeMinMax::new( + min_max_deltas, + existing_groups, + operator, + )); + + *self = AggregateEvalState::RecomputeMinMax { + delta: std::mem::take(delta), + existing_groups: std::mem::take(existing_groups), + old_values: std::mem::take(old_values), + recompute_state, + }; + } else { + // Get the current group to read + let (group_key_str, _group_key) = &groups_to_read[*current_idx]; + + // Build the key for the index: (operator_id, zset_id, element_id) + // For regular aggregates, use column_index=0 and AGG_TYPE_REGULAR + let operator_storage_id = + generate_storage_id(operator.operator_id, 0, AGG_TYPE_REGULAR); + let zset_id = operator.generate_group_rowid(group_key_str); + let element_id = 0i64; // Always 0 for aggregators + + // Create index key values + let index_key_values = vec![ + Value::Integer(operator_storage_id), + Value::Integer(zset_id), + Value::Integer(element_id), + ]; + + // Create an immutable record for the index key + let index_record = + ImmutableRecord::from_values(&index_key_values, index_key_values.len()); + + // Seek in the index to find if this row exists + let seek_result = return_if_io!(cursors.index_cursor.seek( + SeekKey::IndexKey(&index_record), + SeekOp::GE { eq_only: true } + )); + + let rowid = if matches!(seek_result, SeekResult::Found) { + // Found in index, get the table rowid + // The btree code handles extracting the rowid from the index record for has_rowid indexes + return_if_io!(cursors.index_cursor.rowid()) + } else { + // Not found in index, no existing state + None + }; + + // Always 
transition to FetchData + let taken_existing = std::mem::take(existing_groups); + let taken_old_values = std::mem::take(old_values); + let next_state = AggregateEvalState::FetchData { + delta: std::mem::take(delta), + current_idx: *current_idx, + groups_to_read: std::mem::take(groups_to_read), + existing_groups: taken_existing, + old_values: taken_old_values, + rowid, + read_record_state: Box::new(ReadRecord::new()), + }; + *self = next_state; + } + } + AggregateEvalState::FetchData { + delta, + current_idx, + groups_to_read, + existing_groups, + old_values, + rowid, + read_record_state, + } => { + // Get the current group to read + let (group_key_str, group_key) = &groups_to_read[*current_idx]; + + // Only try to read if we have a rowid + if let Some(rowid) = rowid { + let key = SeekKey::TableRowId(*rowid); + let state = return_if_io!(read_record_state.read_record( + key, + &operator.aggregates, + &mut cursors.table_cursor + )); + // Process the fetched state + if let Some(state) = state { + let mut old_row = group_key.clone(); + old_row.extend(state.to_values(&operator.aggregates)); + old_values.insert(group_key_str.clone(), old_row); + existing_groups.insert(group_key_str.clone(), state.clone()); + } + } else { + // No rowid for this group, skipping read + } + // If no rowid, there's no existing state for this group + + // Move to next group + let next_idx = *current_idx + 1; + let taken_existing = std::mem::take(existing_groups); + let taken_old_values = std::mem::take(old_values); + let next_state = AggregateEvalState::FetchKey { + delta: std::mem::take(delta), + current_idx: next_idx, + groups_to_read: std::mem::take(groups_to_read), + existing_groups: taken_existing, + old_values: taken_old_values, + }; + *self = next_state; + } + AggregateEvalState::RecomputeMinMax { + delta, + existing_groups, + old_values, + recompute_state, + } => { + if operator.has_min_max() { + // Process MIN/MAX recomputation - this will update existing_groups with correct MIN/MAX + return_if_io!(recompute_state.process(existing_groups, operator, cursors)); + } + + // Now compute final output with updated MIN/MAX values + let (output_delta, computed_states) = + operator.merge_delta_with_existing(delta, existing_groups, old_values); + + *self = AggregateEvalState::Done { + output: (output_delta, computed_states), + }; + } + AggregateEvalState::Done { output } => { + return Ok(IOResult::Done(output.clone())); + } + } + } + } +} + +impl AggregateState { + pub fn new() -> Self { + Self::default() + } + + // Serialize the aggregate state to a binary blob including group key values + // The reason we serialize it like this, instead of just writing the actual values, is that + // The same table may have different aggregators in the circuit. They will all have different + // columns. 
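+ //
+ // Blob layout (version 1), matching the writes below:
+ //   [version: u8 = 1]
+ //   [group_key_len: u32 LE], then each group key value as a type tag + payload
+ //     (0=Null, 1=Integer i64 LE, 2=Float f64 LE, 3=Text u32 len + bytes, 4=Blob u32 len + bytes)
+ //   [count: i64 LE]
+ //   then per aggregate, in declaration order: SUM -> f64 LE; AVG -> f64 sum + i64 count;
+ //   COUNT -> nothing extra (already in count); MIN/MAX -> presence byte, then a
+ //   serial-type encoded value.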
+ fn to_blob(&self, aggregates: &[AggregateFunction], group_key: &[Value]) -> Vec { + let mut blob = Vec::new(); + + // Write version byte for future compatibility + blob.push(1u8); + + // Write number of group key values + blob.extend_from_slice(&(group_key.len() as u32).to_le_bytes()); + + // Write each group key value + for value in group_key { + // Write value type tag + match value { + Value::Null => blob.push(0u8), + Value::Integer(i) => { + blob.push(1u8); + blob.extend_from_slice(&i.to_le_bytes()); + } + Value::Float(f) => { + blob.push(2u8); + blob.extend_from_slice(&f.to_le_bytes()); + } + Value::Text(s) => { + blob.push(3u8); + let text_str = s.as_str(); + let bytes = text_str.as_bytes(); + blob.extend_from_slice(&(bytes.len() as u32).to_le_bytes()); + blob.extend_from_slice(bytes); + } + Value::Blob(b) => { + blob.push(4u8); + blob.extend_from_slice(&(b.len() as u32).to_le_bytes()); + blob.extend_from_slice(b); + } + } + } + + // Write count as 8 bytes (little-endian) + blob.extend_from_slice(&self.count.to_le_bytes()); + + // Write each aggregate's state + for agg in aggregates { + match agg { + AggregateFunction::Sum(col_name) => { + let sum = self.sums.get(col_name).copied().unwrap_or(0.0); + blob.extend_from_slice(&sum.to_le_bytes()); + } + AggregateFunction::Avg(col_name) => { + let (sum, count) = self.avgs.get(col_name).copied().unwrap_or((0.0, 0)); + blob.extend_from_slice(&sum.to_le_bytes()); + blob.extend_from_slice(&count.to_le_bytes()); + } + AggregateFunction::Count => { + // Count is already written above + } + AggregateFunction::Min(col_name) => { + // Write whether we have a MIN value (1 byte) + if let Some(min_val) = self.mins.get(col_name) { + blob.push(1u8); // Has value + serialize_value(min_val, &mut blob); + } else { + blob.push(0u8); // No value + } + } + AggregateFunction::Max(col_name) => { + // Write whether we have a MAX value (1 byte) + if let Some(max_val) = self.maxs.get(col_name) { + blob.push(1u8); // Has value + serialize_value(max_val, &mut blob); + } else { + blob.push(0u8); // No value + } + } + } + } + + blob + } + + /// Deserialize aggregate state from a binary blob + /// Returns the aggregate state and the group key values + pub fn from_blob(blob: &[u8], aggregates: &[AggregateFunction]) -> Option<(Self, Vec)> { + let mut cursor = 0; + + // Check version byte + if blob.get(cursor) != Some(&1u8) { + return None; + } + cursor += 1; + + // Read number of group key values + let num_group_keys = + u32::from_le_bytes(blob.get(cursor..cursor + 4)?.try_into().ok()?) as usize; + cursor += 4; + + // Read group key values + let mut group_key = Vec::new(); + for _ in 0..num_group_keys { + let value_type = *blob.get(cursor)?; + cursor += 1; + + let value = match value_type { + 0 => Value::Null, + 1 => { + let i = i64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); + cursor += 8; + Value::Integer(i) + } + 2 => { + let f = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); + cursor += 8; + Value::Float(f) + } + 3 => { + let len = + u32::from_le_bytes(blob.get(cursor..cursor + 4)?.try_into().ok()?) as usize; + cursor += 4; + let bytes = blob.get(cursor..cursor + len)?; + cursor += len; + let text_str = std::str::from_utf8(bytes).ok()?; + Value::Text(text_str.to_string().into()) + } + 4 => { + let len = + u32::from_le_bytes(blob.get(cursor..cursor + 4)?.try_into().ok()?) 
as usize; + cursor += 4; + let bytes = blob.get(cursor..cursor + len)?; + cursor += len; + Value::Blob(bytes.to_vec()) + } + _ => return None, + }; + group_key.push(value); + } + + // Read count + let count = i64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); + cursor += 8; + + let mut state = Self::new(); + state.count = count; + + // Read each aggregate's state + for agg in aggregates { + match agg { + AggregateFunction::Sum(col_name) => { + let sum = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); + cursor += 8; + state.sums.insert(col_name.clone(), sum); + } + AggregateFunction::Avg(col_name) => { + let sum = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); + cursor += 8; + let count = i64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); + cursor += 8; + state.avgs.insert(col_name.clone(), (sum, count)); + } + AggregateFunction::Count => { + // Count was already read above + } + AggregateFunction::Min(col_name) => { + // Read whether we have a MIN value + let has_value = *blob.get(cursor)?; + cursor += 1; + + if has_value == 1 { + let (min_value, bytes_consumed) = deserialize_value(&blob[cursor..])?; + cursor += bytes_consumed; + state.mins.insert(col_name.clone(), min_value); + } + } + AggregateFunction::Max(col_name) => { + // Read whether we have a MAX value + let has_value = *blob.get(cursor)?; + cursor += 1; + + if has_value == 1 { + let (max_value, bytes_consumed) = deserialize_value(&blob[cursor..])?; + cursor += bytes_consumed; + state.maxs.insert(col_name.clone(), max_value); + } + } + } + } + + Some((state, group_key)) + } + + /// Apply a delta to this aggregate state + fn apply_delta( + &mut self, + values: &[Value], + weight: isize, + aggregates: &[AggregateFunction], + column_names: &[String], + ) { + // Update COUNT + self.count += weight as i64; + + // Update other aggregates + for agg in aggregates { + match agg { + AggregateFunction::Count => { + // Already handled above + } + AggregateFunction::Sum(col_name) => { + if let Some(idx) = column_names.iter().position(|c| c == col_name) { + if let Some(val) = values.get(idx) { + let num_val = match val { + Value::Integer(i) => *i as f64, + Value::Float(f) => *f, + _ => 0.0, + }; + *self.sums.entry(col_name.clone()).or_insert(0.0) += + num_val * weight as f64; + } + } + } + AggregateFunction::Avg(col_name) => { + if let Some(idx) = column_names.iter().position(|c| c == col_name) { + if let Some(val) = values.get(idx) { + let num_val = match val { + Value::Integer(i) => *i as f64, + Value::Float(f) => *f, + _ => 0.0, + }; + let (sum, count) = + self.avgs.entry(col_name.clone()).or_insert((0.0, 0)); + *sum += num_val * weight as f64; + *count += weight as i64; + } + } + } + AggregateFunction::Min(_col_name) | AggregateFunction::Max(_col_name) => { + // MIN/MAX cannot be handled incrementally in apply_delta because: + // + // 1. For insertions: We can't just keep the minimum/maximum value. + // We need to track ALL values to handle future deletions correctly. + // + // 2. For deletions (retractions): If we delete the current MIN/MAX, + // we need to find the next best value, which requires knowing all + // other values in the group. 
+ // + // Example: Consider MIN(price) with values [10, 20, 30] + // - Current MIN = 10 + // - Delete 10 (weight = -1) + // - New MIN should be 20, but we can't determine this without + // having tracked all values [20, 30] + // + // Therefore, MIN/MAX processing is handled separately: + // - All input values are persisted to the index via persist_min_max() + // - When aggregates have MIN/MAX, we unconditionally transition to + // the RecomputeMinMax state machine (see EvalState::RecomputeMinMax) + // - RecomputeMinMax checks if the current MIN/MAX was deleted, and if so, + // scans the index to find the new MIN/MAX from remaining values + // + // This ensures correctness for incremental computation at the cost of + // additional I/O for MIN/MAX operations. + } + } + } + } + + /// Convert aggregate state to output values + pub fn to_values(&self, aggregates: &[AggregateFunction]) -> Vec { + let mut result = Vec::new(); + + for agg in aggregates { + match agg { + AggregateFunction::Count => { + result.push(Value::Integer(self.count)); + } + AggregateFunction::Sum(col_name) => { + let sum = self.sums.get(col_name).copied().unwrap_or(0.0); + // Return as integer if it's a whole number, otherwise as float + if sum.fract() == 0.0 { + result.push(Value::Integer(sum as i64)); + } else { + result.push(Value::Float(sum)); + } + } + AggregateFunction::Avg(col_name) => { + if let Some((sum, count)) = self.avgs.get(col_name) { + if *count > 0 { + result.push(Value::Float(sum / *count as f64)); + } else { + result.push(Value::Null); + } + } else { + result.push(Value::Null); + } + } + AggregateFunction::Min(col_name) => { + // Return the MIN value from our state + result.push(self.mins.get(col_name).cloned().unwrap_or(Value::Null)); + } + AggregateFunction::Max(col_name) => { + // Return the MAX value from our state + result.push(self.maxs.get(col_name).cloned().unwrap_or(Value::Null)); + } + } + } + + result + } +} + +impl AggregateOperator { + pub fn new( + operator_id: usize, + group_by: Vec, + aggregates: Vec, + input_column_names: Vec, + ) -> Self { + // Build map of column names to their MIN/MAX info with indices + let mut column_min_max = HashMap::new(); + let mut column_indices = HashMap::new(); + let mut current_index = 0; + + // First pass: assign indices to unique MIN/MAX columns + for agg in &aggregates { + match agg { + AggregateFunction::Min(col) | AggregateFunction::Max(col) => { + column_indices.entry(col.clone()).or_insert_with(|| { + let idx = current_index; + current_index += 1; + idx + }); + } + _ => {} + } + } + + // Second pass: build the column info map + for agg in &aggregates { + match agg { + AggregateFunction::Min(col) => { + let index = *column_indices.get(col).unwrap(); + let entry = column_min_max.entry(col.clone()).or_insert(AggColumnInfo { + index, + has_min: false, + has_max: false, + }); + entry.has_min = true; + } + AggregateFunction::Max(col) => { + let index = *column_indices.get(col).unwrap(); + let entry = column_min_max.entry(col.clone()).or_insert(AggColumnInfo { + index, + has_min: false, + has_max: false, + }); + entry.has_max = true; + } + _ => {} + } + } + + Self { + operator_id, + group_by, + aggregates, + input_column_names, + column_min_max, + tracker: None, + commit_state: AggregateCommitState::Idle, + } + } + + pub fn has_min_max(&self) -> bool { + !self.column_min_max.is_empty() + } + + fn eval_internal( + &mut self, + state: &mut EvalState, + cursors: &mut DbspStateCursors, + ) -> Result> { + match state { + EvalState::Uninitialized => { + 
panic!("Cannot eval AggregateOperator with Uninitialized state"); + } + EvalState::Init { deltas } => { + // Aggregate operators only use left_delta, right_delta must be empty + assert!( + deltas.right.is_empty(), + "AggregateOperator expects right_delta to be empty" + ); + + if deltas.left.changes.is_empty() { + *state = EvalState::Done; + return Ok(IOResult::Done((Delta::new(), HashMap::new()))); + } + + let mut groups_to_read = BTreeMap::new(); + for (row, _weight) in &deltas.left.changes { + let group_key = self.extract_group_key(&row.values); + let group_key_str = Self::group_key_to_string(&group_key); + groups_to_read.insert(group_key_str, group_key); + } + + let delta = std::mem::take(&mut deltas.left); + *state = EvalState::Aggregate(Box::new(AggregateEvalState::FetchKey { + delta, + current_idx: 0, + groups_to_read: groups_to_read.into_iter().collect(), + existing_groups: HashMap::new(), + old_values: HashMap::new(), + })); + } + EvalState::Aggregate(_agg_state) => { + // Already in progress, continue processing below. + } + EvalState::Done => { + panic!("unreachable state! should have returned"); + } + EvalState::Join(_) => { + panic!("Join state should not appear in aggregate operator"); + } + } + + // Process the delta through the aggregate state machine + match state { + EvalState::Aggregate(agg_state) => { + let result = return_if_io!(agg_state.process_delta(self, cursors)); + Ok(IOResult::Done(result)) + } + _ => panic!("Invalid state for aggregate processing"), + } + } + + fn merge_delta_with_existing( + &mut self, + delta: &Delta, + existing_groups: &mut HashMap, + old_values: &mut HashMap>, + ) -> (Delta, HashMap, AggregateState)>) { + let mut output_delta = Delta::new(); + let mut temp_keys: HashMap> = HashMap::new(); + + // Process each change in the delta + for (row, weight) in &delta.changes { + if let Some(tracker) = &self.tracker { + tracker.lock().unwrap().record_aggregation(); + } + + // Extract group key + let group_key = self.extract_group_key(&row.values); + let group_key_str = Self::group_key_to_string(&group_key); + + let state = existing_groups.entry(group_key_str.clone()).or_default(); + + temp_keys.insert(group_key_str.clone(), group_key.clone()); + + // Apply the delta to the temporary state + state.apply_delta( + &row.values, + *weight, + &self.aggregates, + &self.input_column_names, + ); + } + + // Generate output delta from temporary states and collect final states + let mut final_states = HashMap::new(); + + for (group_key_str, state) in existing_groups { + let group_key = temp_keys.get(group_key_str).cloned().unwrap_or_default(); + + // Generate a unique rowid for this group + let result_key = self.generate_group_rowid(group_key_str); + + if let Some(old_row_values) = old_values.get(group_key_str) { + let old_row = HashableRow::new(result_key, old_row_values.clone()); + output_delta.changes.push((old_row, -1)); + } + + // Always store the state for persistence (even if count=0, we need to delete it) + final_states.insert(group_key_str.clone(), (group_key.clone(), state.clone())); + + // Only include groups with count > 0 in the output delta + if state.count > 0 { + // Build output row: group_by columns + aggregate values + let mut output_values = group_key.clone(); + let aggregate_values = state.to_values(&self.aggregates); + output_values.extend(aggregate_values); + + let output_row = HashableRow::new(result_key, output_values.clone()); + output_delta.changes.push((output_row, 1)); + } + } + (output_delta, final_states) + } + + /// Extract MIN/MAX 
values from delta changes for persistence to index + fn extract_min_max_deltas(&self, delta: &Delta) -> MinMaxDeltas { + let mut min_max_deltas: MinMaxDeltas = HashMap::new(); + + for (row, weight) in &delta.changes { + let group_key = self.extract_group_key(&row.values); + let group_key_str = Self::group_key_to_string(&group_key); + + for agg in &self.aggregates { + match agg { + AggregateFunction::Min(col_name) | AggregateFunction::Max(col_name) => { + if let Some(idx) = + self.input_column_names.iter().position(|c| c == col_name) + { + if let Some(val) = row.values.get(idx) { + // Skip NULL values - they don't participate in MIN/MAX + if val == &Value::Null { + continue; + } + // Create a HashableRow with just this value + // Use 0 as rowid since we only care about the value for comparison + let hashable_value = HashableRow::new(0, vec![val.clone()]); + let key = (col_name.clone(), hashable_value); + + let group_entry = + min_max_deltas.entry(group_key_str.clone()).or_default(); + + let value_entry = group_entry.entry(key).or_insert(0); + + // Accumulate the weight + *value_entry += weight; + } + } + } + _ => {} // Ignore non-MIN/MAX aggregates + } + } + } + + min_max_deltas + } + + pub fn set_tracker(&mut self, tracker: Arc>) { + self.tracker = Some(tracker); + } + + /// Generate a rowid for a group + /// For no GROUP BY: always returns 0 + /// For GROUP BY: returns a hash of the group key string + pub fn generate_group_rowid(&self, group_key_str: &str) -> i64 { + if self.group_by.is_empty() { + 0 + } else { + group_key_str + .bytes() + .fold(0i64, |acc, b| acc.wrapping_mul(31).wrapping_add(b as i64)) + } + } + + /// Extract group key values from a row + pub fn extract_group_key(&self, values: &[Value]) -> Vec { + let mut key = Vec::new(); + + for group_col in &self.group_by { + if let Some(idx) = self.input_column_names.iter().position(|c| c == group_col) { + if let Some(val) = values.get(idx) { + key.push(val.clone()); + } else { + key.push(Value::Null); + } + } else { + key.push(Value::Null); + } + } + + key + } + + /// Convert group key to string for indexing (since Value doesn't implement Hash) + pub fn group_key_to_string(key: &[Value]) -> String { + key.iter() + .map(|v| format!("{v:?}")) + .collect::>() + .join(",") + } +} + +impl IncrementalOperator for AggregateOperator { + fn eval( + &mut self, + state: &mut EvalState, + cursors: &mut DbspStateCursors, + ) -> Result> { + let (delta, _) = return_if_io!(self.eval_internal(state, cursors)); + Ok(IOResult::Done(delta)) + } + + fn commit( + &mut self, + mut deltas: DeltaPair, + cursors: &mut DbspStateCursors, + ) -> Result> { + // Aggregate operator only uses left delta, right must be empty + assert!( + deltas.right.is_empty(), + "AggregateOperator expects right delta to be empty in commit" + ); + let delta = std::mem::take(&mut deltas.left); + loop { + // Note: because we std::mem::replace here (without it, the borrow checker goes nuts, + // because we call self.eval_interval, which requires a mutable borrow), we have to + // restore the state if we return I/O. So we can't use return_if_io! + let mut state = + std::mem::replace(&mut self.commit_state, AggregateCommitState::Invalid); + match &mut state { + AggregateCommitState::Invalid => { + panic!("Reached invalid state! 
State was replaced, and not replaced back"); + } + AggregateCommitState::Idle => { + let eval_state = EvalState::from_delta(delta.clone()); + self.commit_state = AggregateCommitState::Eval { eval_state }; + } + AggregateCommitState::Eval { ref mut eval_state } => { + // Extract input delta before eval for MIN/MAX processing + let input_delta = eval_state.extract_delta(); + + // Extract MIN/MAX deltas before any I/O operations + let min_max_deltas = self.extract_min_max_deltas(&input_delta); + + // Create a new eval state with the same delta + *eval_state = EvalState::from_delta(input_delta.clone()); + + let (output_delta, computed_states) = return_and_restore_if_io!( + &mut self.commit_state, + state, + self.eval_internal(eval_state, cursors) + ); + + self.commit_state = AggregateCommitState::PersistDelta { + delta: output_delta, + computed_states, + current_idx: 0, + write_row: WriteRow::new(), + min_max_deltas, // Store for later use + }; + } + AggregateCommitState::PersistDelta { + delta, + computed_states, + current_idx, + write_row, + min_max_deltas, + } => { + let states_vec: Vec<_> = computed_states.iter().collect(); + + if *current_idx >= states_vec.len() { + // Use the min_max_deltas we extracted earlier from the input delta + self.commit_state = AggregateCommitState::PersistMinMax { + delta: delta.clone(), + min_max_persist_state: MinMaxPersistState::new(min_max_deltas.clone()), + }; + } else { + let (group_key_str, (group_key, agg_state)) = states_vec[*current_idx]; + + // Build the key components for the new table structure + // For regular aggregates, use column_index=0 and AGG_TYPE_REGULAR + let operator_storage_id = + generate_storage_id(self.operator_id, 0, AGG_TYPE_REGULAR); + let zset_id = self.generate_group_rowid(group_key_str); + let element_id = 0i64; + + // Determine weight: -1 to delete (cancels existing weight=1), 1 to insert/update + let weight = if agg_state.count == 0 { -1 } else { 1 }; + + // Serialize the aggregate state with group key (even for deletion, we need a row) + let state_blob = agg_state.to_blob(&self.aggregates, group_key); + let blob_value = Value::Blob(state_blob); + + // Build the aggregate storage format: [operator_id, zset_id, element_id, value, weight] + let operator_id_val = Value::Integer(operator_storage_id); + let zset_id_val = Value::Integer(zset_id); + let element_id_val = Value::Integer(element_id); + let blob_val = blob_value.clone(); + + // Create index key - the first 3 columns of our primary key + let index_key = vec![ + operator_id_val.clone(), + zset_id_val.clone(), + element_id_val.clone(), + ]; + + // Record values (without weight) + let record_values = + vec![operator_id_val, zset_id_val, element_id_val, blob_val]; + + return_and_restore_if_io!( + &mut self.commit_state, + state, + write_row.write_row(cursors, index_key, record_values, weight) + ); + + let delta = std::mem::take(delta); + let computed_states = std::mem::take(computed_states); + let min_max_deltas = std::mem::take(min_max_deltas); + + self.commit_state = AggregateCommitState::PersistDelta { + delta, + computed_states, + current_idx: *current_idx + 1, + write_row: WriteRow::new(), // Reset for next write + min_max_deltas, + }; + } + } + AggregateCommitState::PersistMinMax { + delta, + min_max_persist_state, + } => { + if !self.has_min_max() { + let delta = std::mem::take(delta); + self.commit_state = AggregateCommitState::Done { delta }; + } else { + return_and_restore_if_io!( + &mut self.commit_state, + state, + min_max_persist_state.persist_min_max( + 
self.operator_id, + &self.column_min_max, + cursors, + |group_key_str| self.generate_group_rowid(group_key_str) + ) + ); + + let delta = std::mem::take(delta); + self.commit_state = AggregateCommitState::Done { delta }; + } + } + AggregateCommitState::Done { delta } => { + self.commit_state = AggregateCommitState::Idle; + let delta = std::mem::take(delta); + return Ok(IOResult::Done(delta)); + } + } + } + } + + fn set_tracker(&mut self, tracker: Arc>) { + self.tracker = Some(tracker); + } +} + +/// State machine for recomputing MIN/MAX values after deletion +#[derive(Debug)] +pub enum RecomputeMinMax { + ProcessElements { + /// Current column being processed + current_column_idx: usize, + /// Columns to process (combined MIN and MAX) + columns_to_process: Vec<(String, String, bool)>, // (group_key, column_name, is_min) + /// MIN/MAX deltas for checking values and weights + min_max_deltas: MinMaxDeltas, + }, + Scan { + /// Columns still to process + columns_to_process: Vec<(String, String, bool)>, + /// Current index in columns_to_process (will resume from here) + current_column_idx: usize, + /// MIN/MAX deltas for checking values and weights + min_max_deltas: MinMaxDeltas, + /// Current group key being processed + group_key: String, + /// Current column name being processed + column_name: String, + /// Whether we're looking for MIN (true) or MAX (false) + is_min: bool, + /// The scan state machine for finding the new MIN/MAX + scan_state: Box, + }, + Done, +} + +impl RecomputeMinMax { + pub fn new( + min_max_deltas: MinMaxDeltas, + existing_groups: &HashMap, + operator: &AggregateOperator, + ) -> Self { + let mut groups_to_check: HashSet<(String, String, bool)> = HashSet::new(); + + // Remember the min_max_deltas are essentially just the only column that is affected by + // this min/max, in delta (actually ZSet - consolidated delta) format. This makes it easier + // for us to consume it in here. + // + // The most challenging case is the case where there is a retraction, since we need to go + // back to the index. 
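+ //
+ // Concretely: a retraction (weight < 0) only forces a recompute when the retracted
+ // value equals the group's current MIN/MAX; an insertion (weight > 0) schedules the
+ // affected column for MIN and/or MAX so the new value gets compared in ScanState.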
+ for (group_key_str, values) in &min_max_deltas { + for ((col_name, hashable_row), weight) in values { + let col_info = operator.column_min_max.get(col_name); + + let value = &hashable_row.values[0]; + + if *weight < 0 { + // Deletion detected - check if it's the current MIN/MAX + if let Some(state) = existing_groups.get(group_key_str) { + // Check for MIN + if let Some(current_min) = state.mins.get(col_name) { + if current_min == value { + groups_to_check.insert(( + group_key_str.clone(), + col_name.clone(), + true, + )); + } + } + // Check for MAX + if let Some(current_max) = state.maxs.get(col_name) { + if current_max == value { + groups_to_check.insert(( + group_key_str.clone(), + col_name.clone(), + false, + )); + } + } + } + } else if *weight > 0 { + // If it is not found in the existing groups, then we only need to care + // about this if this is a new record being inserted + if let Some(info) = col_info { + if info.has_min { + groups_to_check.insert((group_key_str.clone(), col_name.clone(), true)); + } + if info.has_max { + groups_to_check.insert(( + group_key_str.clone(), + col_name.clone(), + false, + )); + } + } + } + } + } + + if groups_to_check.is_empty() { + // No recomputation or initialization needed + Self::Done + } else { + // Convert HashSet to Vec for indexed processing + let groups_to_check_vec: Vec<_> = groups_to_check.into_iter().collect(); + Self::ProcessElements { + current_column_idx: 0, + columns_to_process: groups_to_check_vec, + min_max_deltas, + } + } + } + + pub fn process( + &mut self, + existing_groups: &mut HashMap, + operator: &AggregateOperator, + cursors: &mut DbspStateCursors, + ) -> Result> { + loop { + match self { + RecomputeMinMax::ProcessElements { + current_column_idx, + columns_to_process, + min_max_deltas, + } => { + if *current_column_idx >= columns_to_process.len() { + *self = RecomputeMinMax::Done; + return Ok(IOResult::Done(())); + } + + let (group_key, column_name, is_min) = + columns_to_process[*current_column_idx].clone(); + + // Get column index from pre-computed info + let column_index = operator + .column_min_max + .get(&column_name) + .map(|info| info.index) + .unwrap(); // Should always exist since we're processing known columns + + // Get current value from existing state + let current_value = existing_groups.get(&group_key).and_then(|state| { + if is_min { + state.mins.get(&column_name).cloned() + } else { + state.maxs.get(&column_name).cloned() + } + }); + + // Create storage keys for index lookup + let storage_id = + generate_storage_id(operator.operator_id, column_index, AGG_TYPE_MINMAX); + let zset_id = operator.generate_group_rowid(&group_key); + + // Get the values for this group from min_max_deltas + let group_values = min_max_deltas.get(&group_key).cloned().unwrap_or_default(); + + let columns_to_process = std::mem::take(columns_to_process); + let min_max_deltas = std::mem::take(min_max_deltas); + + let scan_state = if is_min { + Box::new(ScanState::new_for_min( + current_value, + group_key.clone(), + column_name.clone(), + storage_id, + zset_id, + group_values, + )) + } else { + Box::new(ScanState::new_for_max( + current_value, + group_key.clone(), + column_name.clone(), + storage_id, + zset_id, + group_values, + )) + }; + + *self = RecomputeMinMax::Scan { + columns_to_process, + current_column_idx: *current_column_idx, + min_max_deltas, + group_key, + column_name, + is_min, + scan_state, + }; + } + RecomputeMinMax::Scan { + columns_to_process, + current_column_idx, + min_max_deltas, + group_key, + column_name, + is_min, 
+ scan_state, + } => { + // Find new value using the scan state machine + let new_value = return_if_io!(scan_state.find_new_value(cursors)); + + // Update the state with new value (create if doesn't exist) + let state = existing_groups.entry(group_key.clone()).or_default(); + + if *is_min { + if let Some(min_val) = new_value { + state.mins.insert(column_name.clone(), min_val); + } else { + state.mins.remove(column_name); + } + } else if let Some(max_val) = new_value { + state.maxs.insert(column_name.clone(), max_val); + } else { + state.maxs.remove(column_name); + } + + // Move to next column + let min_max_deltas = std::mem::take(min_max_deltas); + let columns_to_process = std::mem::take(columns_to_process); + *self = RecomputeMinMax::ProcessElements { + current_column_idx: *current_column_idx + 1, + columns_to_process, + min_max_deltas, + }; + } + RecomputeMinMax::Done => { + return Ok(IOResult::Done(())); + } + } + } + } +} + +/// State machine for scanning through the index to find new MIN/MAX values +#[derive(Debug)] +pub enum ScanState { + CheckCandidate { + /// Current candidate value for MIN/MAX + candidate: Option, + /// Group key being processed + group_key: String, + /// Column name being processed + column_name: String, + /// Storage ID for the index seek + storage_id: i64, + /// ZSet ID for the group + zset_id: i64, + /// Group values from MinMaxDeltas: (column_name, HashableRow) -> weight + group_values: HashMap<(String, HashableRow), isize>, + /// Whether we're looking for MIN (true) or MAX (false) + is_min: bool, + }, + FetchNextCandidate { + /// Current candidate to seek past + current_candidate: Value, + /// Group key being processed + group_key: String, + /// Column name being processed + column_name: String, + /// Storage ID for the index seek + storage_id: i64, + /// ZSet ID for the group + zset_id: i64, + /// Group values from MinMaxDeltas: (column_name, HashableRow) -> weight + group_values: HashMap<(String, HashableRow), isize>, + /// Whether we're looking for MIN (true) or MAX (false) + is_min: bool, + }, + Done { + /// The final MIN/MAX value found + result: Option, + }, +} + +impl ScanState { + pub fn new_for_min( + current_min: Option, + group_key: String, + column_name: String, + storage_id: i64, + zset_id: i64, + group_values: HashMap<(String, HashableRow), isize>, + ) -> Self { + Self::CheckCandidate { + candidate: current_min, + group_key, + column_name, + storage_id, + zset_id, + group_values, + is_min: true, + } + } + + // Extract a new candidate from the index. It is possible that, when searching, + // we end up going into a different operator altogether. 
That means we have + // exhausted this operator (or group) entirely, and no good candidate was found + fn extract_new_candidate( + cursors: &mut DbspStateCursors, + index_record: &ImmutableRecord, + seek_op: SeekOp, + storage_id: i64, + zset_id: i64, + ) -> Result>> { + let seek_result = return_if_io!(cursors + .index_cursor + .seek(SeekKey::IndexKey(index_record), seek_op)); + if !matches!(seek_result, SeekResult::Found) { + return Ok(IOResult::Done(None)); + } + + let record = return_if_io!(cursors.index_cursor.record()).ok_or_else(|| { + LimboError::InternalError( + "Record found on the cursor, but could not be read".to_string(), + ) + })?; + + let values = record.get_values(); + if values.len() < 3 { + return Ok(IOResult::Done(None)); + } + + let Some(rec_storage_id) = values.first() else { + return Ok(IOResult::Done(None)); + }; + let Some(rec_zset_id) = values.get(1) else { + return Ok(IOResult::Done(None)); + }; + + // Check if we're still in the same group + if let (RefValue::Integer(rec_sid), RefValue::Integer(rec_zid)) = + (rec_storage_id, rec_zset_id) + { + if *rec_sid != storage_id || *rec_zid != zset_id { + return Ok(IOResult::Done(None)); + } + } else { + return Ok(IOResult::Done(None)); + } + + // Get the value (3rd element) + Ok(IOResult::Done(values.get(2).map(|v| v.to_owned()))) + } + + pub fn new_for_max( + current_max: Option, + group_key: String, + column_name: String, + storage_id: i64, + zset_id: i64, + group_values: HashMap<(String, HashableRow), isize>, + ) -> Self { + Self::CheckCandidate { + candidate: current_max, + group_key, + column_name, + storage_id, + zset_id, + group_values, + is_min: false, + } + } + + pub fn find_new_value( + &mut self, + cursors: &mut DbspStateCursors, + ) -> Result>> { + loop { + match self { + ScanState::CheckCandidate { + candidate, + group_key, + column_name, + storage_id, + zset_id, + group_values, + is_min, + } => { + // First, check if we have a candidate + if let Some(cand_val) = candidate { + // Check if the candidate is retracted (weight <= 0) + // Create a HashableRow to look up the weight + let hashable_cand = HashableRow::new(0, vec![cand_val.clone()]); + let key = (column_name.clone(), hashable_cand); + let is_retracted = + group_values.get(&key).is_some_and(|weight| *weight <= 0); + + if is_retracted { + // Candidate is retracted, need to fetch next from index + *self = ScanState::FetchNextCandidate { + current_candidate: cand_val.clone(), + group_key: std::mem::take(group_key), + column_name: std::mem::take(column_name), + storage_id: *storage_id, + zset_id: *zset_id, + group_values: std::mem::take(group_values), + is_min: *is_min, + }; + continue; + } + } + + // Candidate is valid or we have no candidate + // Now find the best value from insertions in group_values + let mut best_from_zset = None; + for ((col, hashable_val), weight) in group_values.iter() { + if col == column_name && *weight > 0 { + let value = &hashable_val.values[0]; + // Skip NULL values - they don't participate in MIN/MAX + if value == &Value::Null { + continue; + } + // This is an insertion for our column + if let Some(ref current_best) = best_from_zset { + if *is_min { + if value.cmp(current_best) == std::cmp::Ordering::Less { + best_from_zset = Some(value.clone()); + } + } else if value.cmp(current_best) == std::cmp::Ordering::Greater { + best_from_zset = Some(value.clone()); + } + } else { + best_from_zset = Some(value.clone()); + } + } + } + + // Compare candidate with best from ZSet, filtering out NULLs + let result = match (&candidate, 
&best_from_zset) { + (Some(cand), Some(zset_val)) if cand != &Value::Null => { + if *is_min { + if zset_val.cmp(cand) == std::cmp::Ordering::Less { + Some(zset_val.clone()) + } else { + Some(cand.clone()) + } + } else if zset_val.cmp(cand) == std::cmp::Ordering::Greater { + Some(zset_val.clone()) + } else { + Some(cand.clone()) + } + } + (Some(cand), None) if cand != &Value::Null => Some(cand.clone()), + (None, Some(zset_val)) => Some(zset_val.clone()), + (Some(cand), Some(_)) if cand == &Value::Null => best_from_zset, + _ => None, + }; + + *self = ScanState::Done { result }; + } + + ScanState::FetchNextCandidate { + current_candidate, + group_key, + column_name, + storage_id, + zset_id, + group_values, + is_min, + } => { + // Seek to the next value in the index + let index_key = vec![ + Value::Integer(*storage_id), + Value::Integer(*zset_id), + current_candidate.clone(), + ]; + let index_record = ImmutableRecord::from_values(&index_key, index_key.len()); + + let seek_op = if *is_min { + SeekOp::GT // For MIN, seek greater than current + } else { + SeekOp::LT // For MAX, seek less than current + }; + + let new_candidate = return_if_io!(Self::extract_new_candidate( + cursors, + &index_record, + seek_op, + *storage_id, + *zset_id + )); + + *self = ScanState::CheckCandidate { + candidate: new_candidate, + group_key: std::mem::take(group_key), + column_name: std::mem::take(column_name), + storage_id: *storage_id, + zset_id: *zset_id, + group_values: std::mem::take(group_values), + is_min: *is_min, + }; + } + + ScanState::Done { result } => { + return Ok(IOResult::Done(result.clone())); + } + } + } + } +} + +/// State machine for persisting Min/Max values to storage +#[derive(Debug)] +pub enum MinMaxPersistState { + Init { + min_max_deltas: MinMaxDeltas, + group_keys: Vec, + }, + ProcessGroup { + min_max_deltas: MinMaxDeltas, + group_keys: Vec, + group_idx: usize, + value_idx: usize, + }, + WriteValue { + min_max_deltas: MinMaxDeltas, + group_keys: Vec, + group_idx: usize, + value_idx: usize, + value: Value, + column_name: String, + weight: isize, + write_row: WriteRow, + }, + Done, +} + +impl MinMaxPersistState { + pub fn new(min_max_deltas: MinMaxDeltas) -> Self { + let group_keys: Vec = min_max_deltas.keys().cloned().collect(); + Self::Init { + min_max_deltas, + group_keys, + } + } + + pub fn persist_min_max( + &mut self, + operator_id: usize, + column_min_max: &HashMap, + cursors: &mut DbspStateCursors, + generate_group_rowid: impl Fn(&str) -> i64, + ) -> Result> { + loop { + match self { + MinMaxPersistState::Init { + min_max_deltas, + group_keys, + } => { + let min_max_deltas = std::mem::take(min_max_deltas); + let group_keys = std::mem::take(group_keys); + *self = MinMaxPersistState::ProcessGroup { + min_max_deltas, + group_keys, + group_idx: 0, + value_idx: 0, + }; + } + MinMaxPersistState::ProcessGroup { + min_max_deltas, + group_keys, + group_idx, + value_idx, + } => { + // Check if we're past all groups + if *group_idx >= group_keys.len() { + *self = MinMaxPersistState::Done; + continue; + } + + let group_key_str = &group_keys[*group_idx]; + let values = &min_max_deltas[group_key_str]; // This should always exist + + // Convert HashMap to Vec for indexed access + let values_vec: Vec<_> = values.iter().collect(); + + // Check if we have more values in current group + if *value_idx >= values_vec.len() { + *group_idx += 1; + *value_idx = 0; + // Continue to check if we're past all groups now + continue; + } + + // Process current value and extract what we need before taking ownership + 
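                        // (values_vec borrows from min_max_deltas, so the column name, value
                        // and weight are cloned out of it first; only then can the maps be
                        // std::mem::take'n to build the WriteValue state below.)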
let ((column_name, hashable_row), weight) = values_vec[*value_idx]; + let column_name = column_name.clone(); + let value = hashable_row.values[0].clone(); // Extract the Value from HashableRow + let weight = *weight; + + let min_max_deltas = std::mem::take(min_max_deltas); + let group_keys = std::mem::take(group_keys); + *self = MinMaxPersistState::WriteValue { + min_max_deltas, + group_keys, + group_idx: *group_idx, + value_idx: *value_idx, + column_name, + value, + weight, + write_row: WriteRow::new(), + }; + } + MinMaxPersistState::WriteValue { + min_max_deltas, + group_keys, + group_idx, + value_idx, + value, + column_name, + weight, + write_row, + } => { + // Should have exited in the previous state + assert!(*group_idx < group_keys.len()); + + let group_key_str = &group_keys[*group_idx]; + + // Get the column index from the pre-computed map + let column_info = column_min_max + .get(&*column_name) + .expect("Column should exist in column_min_max map"); + let column_index = column_info.index; + + // Build the key components for MinMax storage using new encoding + let storage_id = + generate_storage_id(operator_id, column_index, AGG_TYPE_MINMAX); + let zset_id = generate_group_rowid(group_key_str); + + // element_id is the actual value for Min/Max + let element_id_val = value.clone(); + + // Create index key + let index_key = vec![ + Value::Integer(storage_id), + Value::Integer(zset_id), + element_id_val.clone(), + ]; + + // Record values (operator_id, zset_id, element_id, unused_placeholder) + // For MIN/MAX, the element_id IS the value, so we use NULL for the 4th column + let record_values = vec![ + Value::Integer(storage_id), + Value::Integer(zset_id), + element_id_val.clone(), + Value::Null, // Placeholder - not used for MIN/MAX + ]; + + return_if_io!(write_row.write_row( + cursors, + index_key.clone(), + record_values, + *weight + )); + + // Move to next value + let min_max_deltas = std::mem::take(min_max_deltas); + let group_keys = std::mem::take(group_keys); + *self = MinMaxPersistState::ProcessGroup { + min_max_deltas, + group_keys, + group_idx: *group_idx, + value_idx: *value_idx + 1, + }; + } + MinMaxPersistState::Done => { + return Ok(IOResult::Done(())); + } + } + } + } +} diff --git a/core/incremental/mod.rs b/core/incremental/mod.rs index 0e45b3194..a747809d9 100644 --- a/core/incremental/mod.rs +++ b/core/incremental/mod.rs @@ -1,3 +1,4 @@ +pub mod aggregate_operator; pub mod compiler; pub mod cursor; pub mod dbsp; diff --git a/core/incremental/operator.rs b/core/incremental/operator.rs index 43ad8f67c..92b35d5f1 100644 --- a/core/incremental/operator.rs +++ b/core/incremental/operator.rs @@ -2,19 +2,20 @@ // Operator DAG for DBSP-style incremental computation // Based on Feldera DBSP design but adapted for Turso's architecture +pub use crate::incremental::aggregate_operator::{ + AggregateEvalState, AggregateFunction, AggregateOperator, AggregateState, +}; pub use crate::incremental::filter_operator::{FilterOperator, FilterPredicate}; pub use crate::incremental::input_operator::InputOperator; pub use crate::incremental::project_operator::{ProjectColumn, ProjectOperator}; -use crate::function::{AggFunc, Func}; use crate::incremental::dbsp::{Delta, DeltaPair, HashableRow}; -use crate::incremental::persistence::{MinMaxPersistState, ReadRecord, RecomputeMinMax, WriteRow}; +use crate::incremental::persistence::WriteRow; use crate::schema::{Index, IndexColumn}; use crate::storage::btree::BTreeCursor; use crate::types::{IOResult, ImmutableRecord, SeekKey, SeekOp, SeekResult}; 
use crate::{return_and_restore_if_io, return_if_io, Result, Value}; -use std::collections::{BTreeMap, HashMap}; -use std::fmt::{self, Debug, Display}; +use std::fmt::Debug; use std::sync::{Arc, Mutex}; /// Struct to hold both table and index cursors for DBSP state operations @@ -71,12 +72,6 @@ pub fn create_dbsp_state_index(root_page: usize) -> Index { } } -/// Constants for aggregate type encoding in storage IDs (2 bits) -pub const AGG_TYPE_REGULAR: u8 = 0b00; // COUNT/SUM/AVG -pub const AGG_TYPE_MINMAX: u8 = 0b01; // MIN/MAX (BTree ordering gives both) -pub const AGG_TYPE_RESERVED1: u8 = 0b10; // Reserved for future use -pub const AGG_TYPE_RESERVED2: u8 = 0b11; // Reserved for future use - /// Generate a storage ID with column index and operation type encoding /// Storage ID = (operator_id << 16) | (column_index << 2) | operation_type /// Bit layout (64-bit integer): @@ -90,64 +85,6 @@ pub fn generate_storage_id(operator_id: usize, column_index: usize, op_type: u8) ((operator_id as i64) << 16) | ((column_index as i64) << 2) | (op_type as i64) } -// group_key_str -> (group_key, state) -type ComputedStates = HashMap, AggregateState)>; -// group_key_str -> (column_name, value_as_hashable_row) -> accumulated_weight -pub type MinMaxDeltas = HashMap>; - -#[derive(Debug)] -enum AggregateCommitState { - Idle, - Eval { - eval_state: EvalState, - }, - PersistDelta { - delta: Delta, - computed_states: ComputedStates, - current_idx: usize, - write_row: WriteRow, - min_max_deltas: MinMaxDeltas, - }, - PersistMinMax { - delta: Delta, - min_max_persist_state: MinMaxPersistState, - }, - Done { - delta: Delta, - }, - Invalid, -} - -// Aggregate-specific eval states -#[derive(Debug)] -pub enum AggregateEvalState { - FetchKey { - delta: Delta, // Keep original delta for merge operation - current_idx: usize, - groups_to_read: Vec<(String, Vec)>, // Changed to Vec for index-based access - existing_groups: HashMap, - old_values: HashMap>, - }, - FetchData { - delta: Delta, // Keep original delta for merge operation - current_idx: usize, - groups_to_read: Vec<(String, Vec)>, // Changed to Vec for index-based access - existing_groups: HashMap, - old_values: HashMap>, - rowid: Option, // Rowid found by FetchKey (None if not found) - read_record_state: Box, - }, - RecomputeMinMax { - delta: Delta, - existing_groups: HashMap, - old_values: HashMap>, - recompute_state: Box, - }, - Done { - output: (Delta, ComputedStates), - }, -} - // Helper function to read the next row from the BTree for joins fn read_next_join_row( storage_id: i64, @@ -476,180 +413,6 @@ impl EvalState { _ => panic!("extract_delta() can only be called when in Init state"), } } - - fn advance_aggregate(&mut self, groups_to_read: BTreeMap>) { - let delta = match self { - EvalState::Init { deltas } => std::mem::take(&mut deltas.left), - _ => panic!("advance_aggregate() can only be called when in Init state, current state: {self:?}"), - }; - - let _ = std::mem::replace( - self, - EvalState::Aggregate(Box::new(AggregateEvalState::FetchKey { - delta, - current_idx: 0, - groups_to_read: groups_to_read.into_iter().collect(), // Convert BTreeMap to Vec - existing_groups: HashMap::new(), - old_values: HashMap::new(), - })), - ); - } -} - -impl AggregateEvalState { - fn process_delta( - &mut self, - operator: &mut AggregateOperator, - cursors: &mut DbspStateCursors, - ) -> Result> { - loop { - match self { - AggregateEvalState::FetchKey { - delta, - current_idx, - groups_to_read, - existing_groups, - old_values, - } => { - if *current_idx >= 
groups_to_read.len() { - // All groups have been fetched, move to RecomputeMinMax - // Extract MIN/MAX deltas from the input delta - let min_max_deltas = operator.extract_min_max_deltas(delta); - - let recompute_state = Box::new(RecomputeMinMax::new( - min_max_deltas, - existing_groups, - operator, - )); - - *self = AggregateEvalState::RecomputeMinMax { - delta: std::mem::take(delta), - existing_groups: std::mem::take(existing_groups), - old_values: std::mem::take(old_values), - recompute_state, - }; - } else { - // Get the current group to read - let (group_key_str, _group_key) = &groups_to_read[*current_idx]; - - // Build the key for the index: (operator_id, zset_id, element_id) - // For regular aggregates, use column_index=0 and AGG_TYPE_REGULAR - let operator_storage_id = - generate_storage_id(operator.operator_id, 0, AGG_TYPE_REGULAR); - let zset_id = operator.generate_group_rowid(group_key_str); - let element_id = 0i64; // Always 0 for aggregators - - // Create index key values - let index_key_values = vec![ - Value::Integer(operator_storage_id), - Value::Integer(zset_id), - Value::Integer(element_id), - ]; - - // Create an immutable record for the index key - let index_record = - ImmutableRecord::from_values(&index_key_values, index_key_values.len()); - - // Seek in the index to find if this row exists - let seek_result = return_if_io!(cursors.index_cursor.seek( - SeekKey::IndexKey(&index_record), - SeekOp::GE { eq_only: true } - )); - - let rowid = if matches!(seek_result, SeekResult::Found) { - // Found in index, get the table rowid - // The btree code handles extracting the rowid from the index record for has_rowid indexes - return_if_io!(cursors.index_cursor.rowid()) - } else { - // Not found in index, no existing state - None - }; - - // Always transition to FetchData - let taken_existing = std::mem::take(existing_groups); - let taken_old_values = std::mem::take(old_values); - let next_state = AggregateEvalState::FetchData { - delta: std::mem::take(delta), - current_idx: *current_idx, - groups_to_read: std::mem::take(groups_to_read), - existing_groups: taken_existing, - old_values: taken_old_values, - rowid, - read_record_state: Box::new(ReadRecord::new()), - }; - *self = next_state; - } - } - AggregateEvalState::FetchData { - delta, - current_idx, - groups_to_read, - existing_groups, - old_values, - rowid, - read_record_state, - } => { - // Get the current group to read - let (group_key_str, group_key) = &groups_to_read[*current_idx]; - - // Only try to read if we have a rowid - if let Some(rowid) = rowid { - let key = SeekKey::TableRowId(*rowid); - let state = return_if_io!(read_record_state.read_record( - key, - &operator.aggregates, - &mut cursors.table_cursor - )); - // Process the fetched state - if let Some(state) = state { - let mut old_row = group_key.clone(); - old_row.extend(state.to_values(&operator.aggregates)); - old_values.insert(group_key_str.clone(), old_row); - existing_groups.insert(group_key_str.clone(), state.clone()); - } - } else { - // No rowid for this group, skipping read - } - // If no rowid, there's no existing state for this group - - // Move to next group - let next_idx = *current_idx + 1; - let taken_existing = std::mem::take(existing_groups); - let taken_old_values = std::mem::take(old_values); - let next_state = AggregateEvalState::FetchKey { - delta: std::mem::take(delta), - current_idx: next_idx, - groups_to_read: std::mem::take(groups_to_read), - existing_groups: taken_existing, - old_values: taken_old_values, - }; - *self = next_state; - } - 
AggregateEvalState::RecomputeMinMax { - delta, - existing_groups, - old_values, - recompute_state, - } => { - if operator.has_min_max() { - // Process MIN/MAX recomputation - this will update existing_groups with correct MIN/MAX - return_if_io!(recompute_state.process(existing_groups, operator, cursors)); - } - - // Now compute final output with updated MIN/MAX values - let (output_delta, computed_states) = - operator.merge_delta_with_existing(delta, existing_groups, old_values); - - *self = AggregateEvalState::Done { - output: (output_delta, computed_states), - }; - } - AggregateEvalState::Done { output } => { - return Ok(IOResult::Done(output.clone())); - } - } - } - } } /// Tracks computation counts to verify incremental behavior (for tests now), and in the future @@ -800,56 +563,6 @@ pub enum JoinType { Cross, } -#[derive(Debug, Clone, PartialEq)] -pub enum AggregateFunction { - Count, - Sum(String), - Avg(String), - Min(String), - Max(String), -} - -impl Display for AggregateFunction { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - AggregateFunction::Count => write!(f, "COUNT(*)"), - AggregateFunction::Sum(col) => write!(f, "SUM({col})"), - AggregateFunction::Avg(col) => write!(f, "AVG({col})"), - AggregateFunction::Min(col) => write!(f, "MIN({col})"), - AggregateFunction::Max(col) => write!(f, "MAX({col})"), - } - } -} - -impl AggregateFunction { - /// Get the default output column name for this aggregate function - #[inline] - pub fn default_output_name(&self) -> String { - self.to_string() - } - - /// Create an AggregateFunction from a SQL function and its arguments - /// Returns None if the function is not a supported aggregate - pub fn from_sql_function( - func: &crate::function::Func, - input_column: Option, - ) -> Option { - match func { - Func::Agg(agg_func) => { - match agg_func { - AggFunc::Count | AggFunc::Count0 => Some(AggregateFunction::Count), - AggFunc::Sum => input_column.map(AggregateFunction::Sum), - AggFunc::Avg => input_column.map(AggregateFunction::Avg), - AggFunc::Min => input_column.map(AggregateFunction::Min), - AggFunc::Max => input_column.map(AggregateFunction::Max), - _ => None, // Other aggregate functions not yet supported in DBSP - } - } - _ => None, // Not an aggregate function - } - } -} - /// Operator DAG (Directed Acyclic Graph) /// Base trait for incremental operators pub trait IncrementalOperator: Debug { @@ -883,859 +596,6 @@ pub trait IncrementalOperator: Debug { fn set_tracker(&mut self, tracker: Arc>); } -/// Aggregate operator - performs incremental aggregation with GROUP BY -/// Maintains running totals/counts that are updated incrementally -/// -/// Information about a column that has MIN/MAX aggregations -#[derive(Debug, Clone)] -pub struct AggColumnInfo { - /// Index used for storage key generation - pub index: usize, - /// Whether this column has a MIN aggregate - pub has_min: bool, - /// Whether this column has a MAX aggregate - pub has_max: bool, -} - -/// Note that the AggregateOperator essentially implements a ZSet, even -/// though the ZSet structure is never used explicitly. The on-disk btree -/// plays the role of the set! 
-#[derive(Debug)] -pub struct AggregateOperator { - // Unique operator ID for indexing in persistent storage - pub operator_id: usize, - // GROUP BY columns - group_by: Vec, - // Aggregate functions to compute (including MIN/MAX) - pub aggregates: Vec, - // Column names from input - pub input_column_names: Vec, - // Map from column name to aggregate info for quick lookup - pub column_min_max: HashMap, - tracker: Option>>, - - // State machine for commit operation - commit_state: AggregateCommitState, -} - -/// State for a single group's aggregates -#[derive(Debug, Clone, Default)] -pub struct AggregateState { - // For COUNT: just the count - count: i64, - // For SUM: column_name -> sum value - sums: HashMap, - // For AVG: column_name -> (sum, count) for computing average - avgs: HashMap, - // For MIN: column_name -> minimum value - pub mins: HashMap, - // For MAX: column_name -> maximum value - pub maxs: HashMap, -} - -/// Serialize a Value using SQLite's serial type format -/// This is used for MIN/MAX values that need to be stored in a compact, sortable format -pub fn serialize_value(value: &Value, blob: &mut Vec) { - let serial_type = crate::types::SerialType::from(value); - let serial_type_u64: u64 = serial_type.into(); - crate::storage::sqlite3_ondisk::write_varint_to_vec(serial_type_u64, blob); - value.serialize_serial(blob); -} - -/// Deserialize a Value using SQLite's serial type format -/// Returns the deserialized value and the number of bytes consumed -pub fn deserialize_value(blob: &[u8]) -> Option<(Value, usize)> { - let mut cursor = 0; - - // Read the serial type - let (serial_type, varint_size) = crate::storage::sqlite3_ondisk::read_varint(blob).ok()?; - cursor += varint_size; - - let serial_type_obj = crate::types::SerialType::try_from(serial_type).ok()?; - let expected_size = serial_type_obj.size(); - - // Read the value - let (value, actual_size) = - crate::storage::sqlite3_ondisk::read_value(&blob[cursor..], serial_type_obj).ok()?; - - // Verify that the actual size matches what we expected from the serial type - if actual_size != expected_size { - return None; // Data corruption - size mismatch - } - - cursor += actual_size; - - // Convert RefValue to Value - Some((value.to_owned(), cursor)) -} - -impl AggregateState { - pub fn new() -> Self { - Self::default() - } - - // Serialize the aggregate state to a binary blob including group key values - // The reason we serialize it like this, instead of just writing the actual values, is that - // The same table may have different aggregators in the circuit. They will all have different - // columns. 
- fn to_blob(&self, aggregates: &[AggregateFunction], group_key: &[Value]) -> Vec { - let mut blob = Vec::new(); - - // Write version byte for future compatibility - blob.push(1u8); - - // Write number of group key values - blob.extend_from_slice(&(group_key.len() as u32).to_le_bytes()); - - // Write each group key value - for value in group_key { - // Write value type tag - match value { - Value::Null => blob.push(0u8), - Value::Integer(i) => { - blob.push(1u8); - blob.extend_from_slice(&i.to_le_bytes()); - } - Value::Float(f) => { - blob.push(2u8); - blob.extend_from_slice(&f.to_le_bytes()); - } - Value::Text(s) => { - blob.push(3u8); - let text_str = s.as_str(); - let bytes = text_str.as_bytes(); - blob.extend_from_slice(&(bytes.len() as u32).to_le_bytes()); - blob.extend_from_slice(bytes); - } - Value::Blob(b) => { - blob.push(4u8); - blob.extend_from_slice(&(b.len() as u32).to_le_bytes()); - blob.extend_from_slice(b); - } - } - } - - // Write count as 8 bytes (little-endian) - blob.extend_from_slice(&self.count.to_le_bytes()); - - // Write each aggregate's state - for agg in aggregates { - match agg { - AggregateFunction::Sum(col_name) => { - let sum = self.sums.get(col_name).copied().unwrap_or(0.0); - blob.extend_from_slice(&sum.to_le_bytes()); - } - AggregateFunction::Avg(col_name) => { - let (sum, count) = self.avgs.get(col_name).copied().unwrap_or((0.0, 0)); - blob.extend_from_slice(&sum.to_le_bytes()); - blob.extend_from_slice(&count.to_le_bytes()); - } - AggregateFunction::Count => { - // Count is already written above - } - AggregateFunction::Min(col_name) => { - // Write whether we have a MIN value (1 byte) - if let Some(min_val) = self.mins.get(col_name) { - blob.push(1u8); // Has value - serialize_value(min_val, &mut blob); - } else { - blob.push(0u8); // No value - } - } - AggregateFunction::Max(col_name) => { - // Write whether we have a MAX value (1 byte) - if let Some(max_val) = self.maxs.get(col_name) { - blob.push(1u8); // Has value - serialize_value(max_val, &mut blob); - } else { - blob.push(0u8); // No value - } - } - } - } - - blob - } - - /// Deserialize aggregate state from a binary blob - /// Returns the aggregate state and the group key values - pub fn from_blob(blob: &[u8], aggregates: &[AggregateFunction]) -> Option<(Self, Vec)> { - let mut cursor = 0; - - // Check version byte - if blob.get(cursor) != Some(&1u8) { - return None; - } - cursor += 1; - - // Read number of group key values - let num_group_keys = - u32::from_le_bytes(blob.get(cursor..cursor + 4)?.try_into().ok()?) as usize; - cursor += 4; - - // Read group key values - let mut group_key = Vec::new(); - for _ in 0..num_group_keys { - let value_type = *blob.get(cursor)?; - cursor += 1; - - let value = match value_type { - 0 => Value::Null, - 1 => { - let i = i64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); - cursor += 8; - Value::Integer(i) - } - 2 => { - let f = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); - cursor += 8; - Value::Float(f) - } - 3 => { - let len = - u32::from_le_bytes(blob.get(cursor..cursor + 4)?.try_into().ok()?) as usize; - cursor += 4; - let bytes = blob.get(cursor..cursor + len)?; - cursor += len; - let text_str = std::str::from_utf8(bytes).ok()?; - Value::Text(text_str.to_string().into()) - } - 4 => { - let len = - u32::from_le_bytes(blob.get(cursor..cursor + 4)?.try_into().ok()?) 
as usize; - cursor += 4; - let bytes = blob.get(cursor..cursor + len)?; - cursor += len; - Value::Blob(bytes.to_vec()) - } - _ => return None, - }; - group_key.push(value); - } - - // Read count - let count = i64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); - cursor += 8; - - let mut state = Self::new(); - state.count = count; - - // Read each aggregate's state - for agg in aggregates { - match agg { - AggregateFunction::Sum(col_name) => { - let sum = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); - cursor += 8; - state.sums.insert(col_name.clone(), sum); - } - AggregateFunction::Avg(col_name) => { - let sum = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); - cursor += 8; - let count = i64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); - cursor += 8; - state.avgs.insert(col_name.clone(), (sum, count)); - } - AggregateFunction::Count => { - // Count was already read above - } - AggregateFunction::Min(col_name) => { - // Read whether we have a MIN value - let has_value = *blob.get(cursor)?; - cursor += 1; - - if has_value == 1 { - let (min_value, bytes_consumed) = deserialize_value(&blob[cursor..])?; - cursor += bytes_consumed; - state.mins.insert(col_name.clone(), min_value); - } - } - AggregateFunction::Max(col_name) => { - // Read whether we have a MAX value - let has_value = *blob.get(cursor)?; - cursor += 1; - - if has_value == 1 { - let (max_value, bytes_consumed) = deserialize_value(&blob[cursor..])?; - cursor += bytes_consumed; - state.maxs.insert(col_name.clone(), max_value); - } - } - } - } - - Some((state, group_key)) - } - - /// Apply a delta to this aggregate state - fn apply_delta( - &mut self, - values: &[Value], - weight: isize, - aggregates: &[AggregateFunction], - column_names: &[String], - ) { - // Update COUNT - self.count += weight as i64; - - // Update other aggregates - for agg in aggregates { - match agg { - AggregateFunction::Count => { - // Already handled above - } - AggregateFunction::Sum(col_name) => { - if let Some(idx) = column_names.iter().position(|c| c == col_name) { - if let Some(val) = values.get(idx) { - let num_val = match val { - Value::Integer(i) => *i as f64, - Value::Float(f) => *f, - _ => 0.0, - }; - *self.sums.entry(col_name.clone()).or_insert(0.0) += - num_val * weight as f64; - } - } - } - AggregateFunction::Avg(col_name) => { - if let Some(idx) = column_names.iter().position(|c| c == col_name) { - if let Some(val) = values.get(idx) { - let num_val = match val { - Value::Integer(i) => *i as f64, - Value::Float(f) => *f, - _ => 0.0, - }; - let (sum, count) = - self.avgs.entry(col_name.clone()).or_insert((0.0, 0)); - *sum += num_val * weight as f64; - *count += weight as i64; - } - } - } - AggregateFunction::Min(_col_name) | AggregateFunction::Max(_col_name) => { - // MIN/MAX cannot be handled incrementally in apply_delta because: - // - // 1. For insertions: We can't just keep the minimum/maximum value. - // We need to track ALL values to handle future deletions correctly. - // - // 2. For deletions (retractions): If we delete the current MIN/MAX, - // we need to find the next best value, which requires knowing all - // other values in the group. 
- // - // Example: Consider MIN(price) with values [10, 20, 30] - // - Current MIN = 10 - // - Delete 10 (weight = -1) - // - New MIN should be 20, but we can't determine this without - // having tracked all values [20, 30] - // - // Therefore, MIN/MAX processing is handled separately: - // - All input values are persisted to the index via persist_min_max() - // - When aggregates have MIN/MAX, we unconditionally transition to - // the RecomputeMinMax state machine (see EvalState::RecomputeMinMax) - // - RecomputeMinMax checks if the current MIN/MAX was deleted, and if so, - // scans the index to find the new MIN/MAX from remaining values - // - // This ensures correctness for incremental computation at the cost of - // additional I/O for MIN/MAX operations. - } - } - } - } - - /// Convert aggregate state to output values - pub fn to_values(&self, aggregates: &[AggregateFunction]) -> Vec { - let mut result = Vec::new(); - - for agg in aggregates { - match agg { - AggregateFunction::Count => { - result.push(Value::Integer(self.count)); - } - AggregateFunction::Sum(col_name) => { - let sum = self.sums.get(col_name).copied().unwrap_or(0.0); - // Return as integer if it's a whole number, otherwise as float - if sum.fract() == 0.0 { - result.push(Value::Integer(sum as i64)); - } else { - result.push(Value::Float(sum)); - } - } - AggregateFunction::Avg(col_name) => { - if let Some((sum, count)) = self.avgs.get(col_name) { - if *count > 0 { - result.push(Value::Float(sum / *count as f64)); - } else { - result.push(Value::Null); - } - } else { - result.push(Value::Null); - } - } - AggregateFunction::Min(col_name) => { - // Return the MIN value from our state - result.push(self.mins.get(col_name).cloned().unwrap_or(Value::Null)); - } - AggregateFunction::Max(col_name) => { - // Return the MAX value from our state - result.push(self.maxs.get(col_name).cloned().unwrap_or(Value::Null)); - } - } - } - - result - } -} - -impl AggregateOperator { - pub fn new( - operator_id: usize, - group_by: Vec, - aggregates: Vec, - input_column_names: Vec, - ) -> Self { - // Build map of column names to their MIN/MAX info with indices - let mut column_min_max = HashMap::new(); - let mut column_indices = HashMap::new(); - let mut current_index = 0; - - // First pass: assign indices to unique MIN/MAX columns - for agg in &aggregates { - match agg { - AggregateFunction::Min(col) | AggregateFunction::Max(col) => { - column_indices.entry(col.clone()).or_insert_with(|| { - let idx = current_index; - current_index += 1; - idx - }); - } - _ => {} - } - } - - // Second pass: build the column info map - for agg in &aggregates { - match agg { - AggregateFunction::Min(col) => { - let index = *column_indices.get(col).unwrap(); - let entry = column_min_max.entry(col.clone()).or_insert(AggColumnInfo { - index, - has_min: false, - has_max: false, - }); - entry.has_min = true; - } - AggregateFunction::Max(col) => { - let index = *column_indices.get(col).unwrap(); - let entry = column_min_max.entry(col.clone()).or_insert(AggColumnInfo { - index, - has_min: false, - has_max: false, - }); - entry.has_max = true; - } - _ => {} - } - } - - Self { - operator_id, - group_by, - aggregates, - input_column_names, - column_min_max, - tracker: None, - commit_state: AggregateCommitState::Idle, - } - } - - pub fn has_min_max(&self) -> bool { - !self.column_min_max.is_empty() - } - - fn eval_internal( - &mut self, - state: &mut EvalState, - cursors: &mut DbspStateCursors, - ) -> Result> { - match state { - EvalState::Uninitialized => { - 
panic!("Cannot eval AggregateOperator with Uninitialized state"); - } - EvalState::Init { deltas } => { - // Aggregate operators only use left_delta, right_delta must be empty - assert!( - deltas.right.is_empty(), - "AggregateOperator expects right_delta to be empty" - ); - - if deltas.left.changes.is_empty() { - *state = EvalState::Done; - return Ok(IOResult::Done((Delta::new(), HashMap::new()))); - } - - let mut groups_to_read = BTreeMap::new(); - for (row, _weight) in &deltas.left.changes { - // Extract group key using cloned fields - let group_key = self.extract_group_key(&row.values); - let group_key_str = Self::group_key_to_string(&group_key); - groups_to_read.insert(group_key_str, group_key); - } - state.advance_aggregate(groups_to_read); - } - EvalState::Aggregate(_agg_state) => { - // Already in progress, continue processing below. - } - EvalState::Done => { - panic!("unreachable state! should have returned"); - } - EvalState::Join(_) => { - panic!("Join state should not appear in aggregate operator"); - } - } - - // Process the delta through the aggregate state machine - match state { - EvalState::Aggregate(agg_state) => { - let result = return_if_io!(agg_state.process_delta(self, cursors)); - Ok(IOResult::Done(result)) - } - _ => panic!("Invalid state for aggregate processing"), - } - } - - fn merge_delta_with_existing( - &mut self, - delta: &Delta, - existing_groups: &mut HashMap, - old_values: &mut HashMap>, - ) -> (Delta, HashMap, AggregateState)>) { - let mut output_delta = Delta::new(); - let mut temp_keys: HashMap> = HashMap::new(); - - // Process each change in the delta - for (row, weight) in &delta.changes { - if let Some(tracker) = &self.tracker { - tracker.lock().unwrap().record_aggregation(); - } - - // Extract group key - let group_key = self.extract_group_key(&row.values); - let group_key_str = Self::group_key_to_string(&group_key); - - let state = existing_groups.entry(group_key_str.clone()).or_default(); - - temp_keys.insert(group_key_str.clone(), group_key.clone()); - - // Apply the delta to the temporary state - state.apply_delta( - &row.values, - *weight, - &self.aggregates, - &self.input_column_names, - ); - } - - // Generate output delta from temporary states and collect final states - let mut final_states = HashMap::new(); - - for (group_key_str, state) in existing_groups { - let group_key = temp_keys.get(group_key_str).cloned().unwrap_or_default(); - - // Generate a unique rowid for this group - let result_key = self.generate_group_rowid(group_key_str); - - if let Some(old_row_values) = old_values.get(group_key_str) { - let old_row = HashableRow::new(result_key, old_row_values.clone()); - output_delta.changes.push((old_row, -1)); - } - - // Always store the state for persistence (even if count=0, we need to delete it) - final_states.insert(group_key_str.clone(), (group_key.clone(), state.clone())); - - // Only include groups with count > 0 in the output delta - if state.count > 0 { - // Build output row: group_by columns + aggregate values - let mut output_values = group_key.clone(); - let aggregate_values = state.to_values(&self.aggregates); - output_values.extend(aggregate_values); - - let output_row = HashableRow::new(result_key, output_values.clone()); - output_delta.changes.push((output_row, 1)); - } - } - (output_delta, final_states) - } - - /// Extract MIN/MAX values from delta changes for persistence to index - fn extract_min_max_deltas(&self, delta: &Delta) -> MinMaxDeltas { - let mut min_max_deltas: MinMaxDeltas = HashMap::new(); - - for (row, 
weight) in &delta.changes { - let group_key = self.extract_group_key(&row.values); - let group_key_str = Self::group_key_to_string(&group_key); - - for agg in &self.aggregates { - match agg { - AggregateFunction::Min(col_name) | AggregateFunction::Max(col_name) => { - if let Some(idx) = - self.input_column_names.iter().position(|c| c == col_name) - { - if let Some(val) = row.values.get(idx) { - // Skip NULL values - they don't participate in MIN/MAX - if val == &Value::Null { - continue; - } - // Create a HashableRow with just this value - // Use 0 as rowid since we only care about the value for comparison - let hashable_value = HashableRow::new(0, vec![val.clone()]); - let key = (col_name.clone(), hashable_value); - - let group_entry = - min_max_deltas.entry(group_key_str.clone()).or_default(); - - let value_entry = group_entry.entry(key).or_insert(0); - - // Accumulate the weight - *value_entry += weight; - } - } - } - _ => {} // Ignore non-MIN/MAX aggregates - } - } - } - - min_max_deltas - } - - pub fn set_tracker(&mut self, tracker: Arc>) { - self.tracker = Some(tracker); - } - - /// Generate a rowid for a group - /// For no GROUP BY: always returns 0 - /// For GROUP BY: returns a hash of the group key string - pub fn generate_group_rowid(&self, group_key_str: &str) -> i64 { - if self.group_by.is_empty() { - 0 - } else { - group_key_str - .bytes() - .fold(0i64, |acc, b| acc.wrapping_mul(31).wrapping_add(b as i64)) - } - } - - /// Generate the composite key for BTree storage - /// Combines operator_id and group hash - fn generate_storage_key(&self, group_key_str: &str) -> i64 { - let group_hash = self.generate_group_rowid(group_key_str); - (self.operator_id as i64) << 32 | (group_hash & 0xFFFFFFFF) - } - - /// Extract group key values from a row - pub fn extract_group_key(&self, values: &[Value]) -> Vec { - let mut key = Vec::new(); - - for group_col in &self.group_by { - if let Some(idx) = self.input_column_names.iter().position(|c| c == group_col) { - if let Some(val) = values.get(idx) { - key.push(val.clone()); - } else { - key.push(Value::Null); - } - } else { - key.push(Value::Null); - } - } - - key - } - - /// Convert group key to string for indexing (since Value doesn't implement Hash) - pub fn group_key_to_string(key: &[Value]) -> String { - key.iter() - .map(|v| format!("{v:?}")) - .collect::>() - .join(",") - } - - fn seek_key_from_str(&self, group_key_str: &str) -> SeekKey<'_> { - // Calculate the composite key for seeking - let key_i64 = self.generate_storage_key(group_key_str); - SeekKey::TableRowId(key_i64) - } - - fn seek_key(&self, row: HashableRow) -> SeekKey<'_> { - // Extract group key for first row - let group_key = self.extract_group_key(&row.values); - let group_key_str = Self::group_key_to_string(&group_key); - self.seek_key_from_str(&group_key_str) - } -} - -impl IncrementalOperator for AggregateOperator { - fn eval( - &mut self, - state: &mut EvalState, - cursors: &mut DbspStateCursors, - ) -> Result> { - let (delta, _) = return_if_io!(self.eval_internal(state, cursors)); - Ok(IOResult::Done(delta)) - } - - fn commit( - &mut self, - mut deltas: DeltaPair, - cursors: &mut DbspStateCursors, - ) -> Result> { - // Aggregate operator only uses left delta, right must be empty - assert!( - deltas.right.is_empty(), - "AggregateOperator expects right delta to be empty in commit" - ); - let delta = std::mem::take(&mut deltas.left); - loop { - // Note: because we std::mem::replace here (without it, the borrow checker goes nuts, - // because we call self.eval_interval, 
which requires a mutable borrow), we have to - // restore the state if we return I/O. So we can't use return_if_io! - let mut state = - std::mem::replace(&mut self.commit_state, AggregateCommitState::Invalid); - match &mut state { - AggregateCommitState::Invalid => { - panic!("Reached invalid state! State was replaced, and not replaced back"); - } - AggregateCommitState::Idle => { - let eval_state = EvalState::from_delta(delta.clone()); - self.commit_state = AggregateCommitState::Eval { eval_state }; - } - AggregateCommitState::Eval { ref mut eval_state } => { - // Extract input delta before eval for MIN/MAX processing - let input_delta = eval_state.extract_delta(); - - // Extract MIN/MAX deltas before any I/O operations - let min_max_deltas = self.extract_min_max_deltas(&input_delta); - - // Create a new eval state with the same delta - *eval_state = EvalState::from_delta(input_delta.clone()); - - let (output_delta, computed_states) = return_and_restore_if_io!( - &mut self.commit_state, - state, - self.eval_internal(eval_state, cursors) - ); - - self.commit_state = AggregateCommitState::PersistDelta { - delta: output_delta, - computed_states, - current_idx: 0, - write_row: WriteRow::new(), - min_max_deltas, // Store for later use - }; - } - AggregateCommitState::PersistDelta { - delta, - computed_states, - current_idx, - write_row, - min_max_deltas, - } => { - let states_vec: Vec<_> = computed_states.iter().collect(); - - if *current_idx >= states_vec.len() { - // Use the min_max_deltas we extracted earlier from the input delta - self.commit_state = AggregateCommitState::PersistMinMax { - delta: delta.clone(), - min_max_persist_state: MinMaxPersistState::new(min_max_deltas.clone()), - }; - } else { - let (group_key_str, (group_key, agg_state)) = states_vec[*current_idx]; - - // Build the key components for the new table structure - // For regular aggregates, use column_index=0 and AGG_TYPE_REGULAR - let operator_storage_id = - generate_storage_id(self.operator_id, 0, AGG_TYPE_REGULAR); - let zset_id = self.generate_group_rowid(group_key_str); - let element_id = 0i64; - - // Determine weight: -1 to delete (cancels existing weight=1), 1 to insert/update - let weight = if agg_state.count == 0 { -1 } else { 1 }; - - // Serialize the aggregate state with group key (even for deletion, we need a row) - let state_blob = agg_state.to_blob(&self.aggregates, group_key); - let blob_value = Value::Blob(state_blob); - - // Build the aggregate storage format: [operator_id, zset_id, element_id, value, weight] - let operator_id_val = Value::Integer(operator_storage_id); - let zset_id_val = Value::Integer(zset_id); - let element_id_val = Value::Integer(element_id); - let blob_val = blob_value.clone(); - - // Create index key - the first 3 columns of our primary key - let index_key = vec![ - operator_id_val.clone(), - zset_id_val.clone(), - element_id_val.clone(), - ]; - - // Record values (without weight) - let record_values = - vec![operator_id_val, zset_id_val, element_id_val, blob_val]; - - return_and_restore_if_io!( - &mut self.commit_state, - state, - write_row.write_row(cursors, index_key, record_values, weight) - ); - - let delta = std::mem::take(delta); - let computed_states = std::mem::take(computed_states); - let min_max_deltas = std::mem::take(min_max_deltas); - - self.commit_state = AggregateCommitState::PersistDelta { - delta, - computed_states, - current_idx: *current_idx + 1, - write_row: WriteRow::new(), // Reset for next write - min_max_deltas, - }; - } - } - 
AggregateCommitState::PersistMinMax { - delta, - min_max_persist_state, - } => { - if !self.has_min_max() { - let delta = std::mem::take(delta); - self.commit_state = AggregateCommitState::Done { delta }; - } else { - return_and_restore_if_io!( - &mut self.commit_state, - state, - min_max_persist_state.persist_min_max( - self.operator_id, - &self.column_min_max, - cursors, - |group_key_str| self.generate_group_rowid(group_key_str) - ) - ); - - let delta = std::mem::take(delta); - self.commit_state = AggregateCommitState::Done { delta }; - } - } - AggregateCommitState::Done { delta } => { - self.commit_state = AggregateCommitState::Idle; - let delta = std::mem::take(delta); - return Ok(IOResult::Done(delta)); - } - } - } - } - - fn set_tracker(&mut self, tracker: Arc>) { - self.tracker = Some(tracker); - } -} - #[derive(Debug)] enum JoinCommitState { Idle, @@ -2226,6 +1086,7 @@ impl IncrementalOperator for JoinOperator { #[cfg(test)] mod tests { use super::*; + use crate::incremental::aggregate_operator::AGG_TYPE_REGULAR; use crate::storage::pager::CreateBTreeFlags; use crate::types::Text; use crate::util::IOExt; diff --git a/core/incremental/persistence.rs b/core/incremental/persistence.rs index eca26cd7c..5cf41b94a 100644 --- a/core/incremental/persistence.rs +++ b/core/incremental/persistence.rs @@ -1,12 +1,7 @@ -use crate::incremental::dbsp::HashableRow; -use crate::incremental::operator::{ - generate_storage_id, AggColumnInfo, AggregateFunction, AggregateOperator, AggregateState, - DbspStateCursors, MinMaxDeltas, AGG_TYPE_MINMAX, -}; +use crate::incremental::operator::{AggregateFunction, AggregateState, DbspStateCursors}; use crate::storage::btree::{BTreeCursor, BTreeKey}; -use crate::types::{IOResult, ImmutableRecord, RefValue, SeekKey, SeekOp, SeekResult}; +use crate::types::{IOResult, ImmutableRecord, SeekKey, SeekOp, SeekResult}; use crate::{return_if_io, LimboError, Result, Value}; -use std::collections::{HashMap, HashSet}; #[derive(Debug, Default)] pub enum ReadRecord { @@ -290,672 +285,3 @@ impl WriteRow { } } } - -/// State machine for recomputing MIN/MAX values after deletion -#[derive(Debug)] -pub enum RecomputeMinMax { - ProcessElements { - /// Current column being processed - current_column_idx: usize, - /// Columns to process (combined MIN and MAX) - columns_to_process: Vec<(String, String, bool)>, // (group_key, column_name, is_min) - /// MIN/MAX deltas for checking values and weights - min_max_deltas: MinMaxDeltas, - }, - Scan { - /// Columns still to process - columns_to_process: Vec<(String, String, bool)>, - /// Current index in columns_to_process (will resume from here) - current_column_idx: usize, - /// MIN/MAX deltas for checking values and weights - min_max_deltas: MinMaxDeltas, - /// Current group key being processed - group_key: String, - /// Current column name being processed - column_name: String, - /// Whether we're looking for MIN (true) or MAX (false) - is_min: bool, - /// The scan state machine for finding the new MIN/MAX - scan_state: Box, - }, - Done, -} - -impl RecomputeMinMax { - pub fn new( - min_max_deltas: MinMaxDeltas, - existing_groups: &HashMap, - operator: &AggregateOperator, - ) -> Self { - let mut groups_to_check: HashSet<(String, String, bool)> = HashSet::new(); - - // Remember the min_max_deltas are essentially just the only column that is affected by - // this min/max, in delta (actually ZSet - consolidated delta) format. This makes it easier - // for us to consume it in here. 
- // - // The most challenging case is the case where there is a retraction, since we need to go - // back to the index. - for (group_key_str, values) in &min_max_deltas { - for ((col_name, hashable_row), weight) in values { - let col_info = operator.column_min_max.get(col_name); - - let value = &hashable_row.values[0]; - - if *weight < 0 { - // Deletion detected - check if it's the current MIN/MAX - if let Some(state) = existing_groups.get(group_key_str) { - // Check for MIN - if let Some(current_min) = state.mins.get(col_name) { - if current_min == value { - groups_to_check.insert(( - group_key_str.clone(), - col_name.clone(), - true, - )); - } - } - // Check for MAX - if let Some(current_max) = state.maxs.get(col_name) { - if current_max == value { - groups_to_check.insert(( - group_key_str.clone(), - col_name.clone(), - false, - )); - } - } - } - } else if *weight > 0 { - // If it is not found in the existing groups, then we only need to care - // about this if this is a new record being inserted - if let Some(info) = col_info { - if info.has_min { - groups_to_check.insert((group_key_str.clone(), col_name.clone(), true)); - } - if info.has_max { - groups_to_check.insert(( - group_key_str.clone(), - col_name.clone(), - false, - )); - } - } - } - } - } - - if groups_to_check.is_empty() { - // No recomputation or initialization needed - Self::Done - } else { - // Convert HashSet to Vec for indexed processing - let groups_to_check_vec: Vec<_> = groups_to_check.into_iter().collect(); - Self::ProcessElements { - current_column_idx: 0, - columns_to_process: groups_to_check_vec, - min_max_deltas, - } - } - } - - pub fn process( - &mut self, - existing_groups: &mut HashMap, - operator: &AggregateOperator, - cursors: &mut DbspStateCursors, - ) -> Result> { - loop { - match self { - RecomputeMinMax::ProcessElements { - current_column_idx, - columns_to_process, - min_max_deltas, - } => { - if *current_column_idx >= columns_to_process.len() { - *self = RecomputeMinMax::Done; - return Ok(IOResult::Done(())); - } - - let (group_key, column_name, is_min) = - columns_to_process[*current_column_idx].clone(); - - // Get column index from pre-computed info - let column_index = operator - .column_min_max - .get(&column_name) - .map(|info| info.index) - .unwrap(); // Should always exist since we're processing known columns - - // Get current value from existing state - let current_value = existing_groups.get(&group_key).and_then(|state| { - if is_min { - state.mins.get(&column_name).cloned() - } else { - state.maxs.get(&column_name).cloned() - } - }); - - // Create storage keys for index lookup - let storage_id = - generate_storage_id(operator.operator_id, column_index, AGG_TYPE_MINMAX); - let zset_id = operator.generate_group_rowid(&group_key); - - // Get the values for this group from min_max_deltas - let group_values = min_max_deltas.get(&group_key).cloned().unwrap_or_default(); - - let columns_to_process = std::mem::take(columns_to_process); - let min_max_deltas = std::mem::take(min_max_deltas); - - let scan_state = if is_min { - Box::new(ScanState::new_for_min( - current_value, - group_key.clone(), - column_name.clone(), - storage_id, - zset_id, - group_values, - )) - } else { - Box::new(ScanState::new_for_max( - current_value, - group_key.clone(), - column_name.clone(), - storage_id, - zset_id, - group_values, - )) - }; - - *self = RecomputeMinMax::Scan { - columns_to_process, - current_column_idx: *current_column_idx, - min_max_deltas, - group_key, - column_name, - is_min, - scan_state, - }; - } - 
RecomputeMinMax::Scan { - columns_to_process, - current_column_idx, - min_max_deltas, - group_key, - column_name, - is_min, - scan_state, - } => { - // Find new value using the scan state machine - let new_value = return_if_io!(scan_state.find_new_value(cursors)); - - // Update the state with new value (create if doesn't exist) - let state = existing_groups.entry(group_key.clone()).or_default(); - - if *is_min { - if let Some(min_val) = new_value { - state.mins.insert(column_name.clone(), min_val); - } else { - state.mins.remove(column_name); - } - } else if let Some(max_val) = new_value { - state.maxs.insert(column_name.clone(), max_val); - } else { - state.maxs.remove(column_name); - } - - // Move to next column - let min_max_deltas = std::mem::take(min_max_deltas); - let columns_to_process = std::mem::take(columns_to_process); - *self = RecomputeMinMax::ProcessElements { - current_column_idx: *current_column_idx + 1, - columns_to_process, - min_max_deltas, - }; - } - RecomputeMinMax::Done => { - return Ok(IOResult::Done(())); - } - } - } - } -} - -/// State machine for scanning through the index to find new MIN/MAX values -#[derive(Debug)] -pub enum ScanState { - CheckCandidate { - /// Current candidate value for MIN/MAX - candidate: Option, - /// Group key being processed - group_key: String, - /// Column name being processed - column_name: String, - /// Storage ID for the index seek - storage_id: i64, - /// ZSet ID for the group - zset_id: i64, - /// Group values from MinMaxDeltas: (column_name, HashableRow) -> weight - group_values: HashMap<(String, HashableRow), isize>, - /// Whether we're looking for MIN (true) or MAX (false) - is_min: bool, - }, - FetchNextCandidate { - /// Current candidate to seek past - current_candidate: Value, - /// Group key being processed - group_key: String, - /// Column name being processed - column_name: String, - /// Storage ID for the index seek - storage_id: i64, - /// ZSet ID for the group - zset_id: i64, - /// Group values from MinMaxDeltas: (column_name, HashableRow) -> weight - group_values: HashMap<(String, HashableRow), isize>, - /// Whether we're looking for MIN (true) or MAX (false) - is_min: bool, - }, - Done { - /// The final MIN/MAX value found - result: Option, - }, -} - -impl ScanState { - pub fn new_for_min( - current_min: Option, - group_key: String, - column_name: String, - storage_id: i64, - zset_id: i64, - group_values: HashMap<(String, HashableRow), isize>, - ) -> Self { - Self::CheckCandidate { - candidate: current_min, - group_key, - column_name, - storage_id, - zset_id, - group_values, - is_min: true, - } - } - - // Extract a new candidate from the index. It is possible that, when searching, - // we end up going into a different operator altogether. 
That means we have - // exhausted this operator (or group) entirely, and no good candidate was found - fn extract_new_candidate( - cursors: &mut DbspStateCursors, - index_record: &ImmutableRecord, - seek_op: SeekOp, - storage_id: i64, - zset_id: i64, - ) -> Result>> { - let seek_result = return_if_io!(cursors - .index_cursor - .seek(SeekKey::IndexKey(index_record), seek_op)); - if !matches!(seek_result, SeekResult::Found) { - return Ok(IOResult::Done(None)); - } - - let record = return_if_io!(cursors.index_cursor.record()).ok_or_else(|| { - LimboError::InternalError( - "Record found on the cursor, but could not be read".to_string(), - ) - })?; - - let values = record.get_values(); - if values.len() < 3 { - return Ok(IOResult::Done(None)); - } - - let Some(rec_storage_id) = values.first() else { - return Ok(IOResult::Done(None)); - }; - let Some(rec_zset_id) = values.get(1) else { - return Ok(IOResult::Done(None)); - }; - - // Check if we're still in the same group - if let (RefValue::Integer(rec_sid), RefValue::Integer(rec_zid)) = - (rec_storage_id, rec_zset_id) - { - if *rec_sid != storage_id || *rec_zid != zset_id { - return Ok(IOResult::Done(None)); - } - } else { - return Ok(IOResult::Done(None)); - } - - // Get the value (3rd element) - Ok(IOResult::Done(values.get(2).map(|v| v.to_owned()))) - } - - pub fn new_for_max( - current_max: Option, - group_key: String, - column_name: String, - storage_id: i64, - zset_id: i64, - group_values: HashMap<(String, HashableRow), isize>, - ) -> Self { - Self::CheckCandidate { - candidate: current_max, - group_key, - column_name, - storage_id, - zset_id, - group_values, - is_min: false, - } - } - - pub fn find_new_value( - &mut self, - cursors: &mut DbspStateCursors, - ) -> Result>> { - loop { - match self { - ScanState::CheckCandidate { - candidate, - group_key, - column_name, - storage_id, - zset_id, - group_values, - is_min, - } => { - // First, check if we have a candidate - if let Some(cand_val) = candidate { - // Check if the candidate is retracted (weight <= 0) - // Create a HashableRow to look up the weight - let hashable_cand = HashableRow::new(0, vec![cand_val.clone()]); - let key = (column_name.clone(), hashable_cand); - let is_retracted = - group_values.get(&key).is_some_and(|weight| *weight <= 0); - - if is_retracted { - // Candidate is retracted, need to fetch next from index - *self = ScanState::FetchNextCandidate { - current_candidate: cand_val.clone(), - group_key: std::mem::take(group_key), - column_name: std::mem::take(column_name), - storage_id: *storage_id, - zset_id: *zset_id, - group_values: std::mem::take(group_values), - is_min: *is_min, - }; - continue; - } - } - - // Candidate is valid or we have no candidate - // Now find the best value from insertions in group_values - let mut best_from_zset = None; - for ((col, hashable_val), weight) in group_values.iter() { - if col == column_name && *weight > 0 { - let value = &hashable_val.values[0]; - // Skip NULL values - they don't participate in MIN/MAX - if value == &Value::Null { - continue; - } - // This is an insertion for our column - if let Some(ref current_best) = best_from_zset { - if *is_min { - if value.cmp(current_best) == std::cmp::Ordering::Less { - best_from_zset = Some(value.clone()); - } - } else if value.cmp(current_best) == std::cmp::Ordering::Greater { - best_from_zset = Some(value.clone()); - } - } else { - best_from_zset = Some(value.clone()); - } - } - } - - // Compare candidate with best from ZSet, filtering out NULLs - let result = match (&candidate, 
&best_from_zset) { - (Some(cand), Some(zset_val)) if cand != &Value::Null => { - if *is_min { - if zset_val.cmp(cand) == std::cmp::Ordering::Less { - Some(zset_val.clone()) - } else { - Some(cand.clone()) - } - } else if zset_val.cmp(cand) == std::cmp::Ordering::Greater { - Some(zset_val.clone()) - } else { - Some(cand.clone()) - } - } - (Some(cand), None) if cand != &Value::Null => Some(cand.clone()), - (None, Some(zset_val)) => Some(zset_val.clone()), - (Some(cand), Some(_)) if cand == &Value::Null => best_from_zset, - _ => None, - }; - - *self = ScanState::Done { result }; - } - - ScanState::FetchNextCandidate { - current_candidate, - group_key, - column_name, - storage_id, - zset_id, - group_values, - is_min, - } => { - // Seek to the next value in the index - let index_key = vec![ - Value::Integer(*storage_id), - Value::Integer(*zset_id), - current_candidate.clone(), - ]; - let index_record = ImmutableRecord::from_values(&index_key, index_key.len()); - - let seek_op = if *is_min { - SeekOp::GT // For MIN, seek greater than current - } else { - SeekOp::LT // For MAX, seek less than current - }; - - let new_candidate = return_if_io!(Self::extract_new_candidate( - cursors, - &index_record, - seek_op, - *storage_id, - *zset_id - )); - - *self = ScanState::CheckCandidate { - candidate: new_candidate, - group_key: std::mem::take(group_key), - column_name: std::mem::take(column_name), - storage_id: *storage_id, - zset_id: *zset_id, - group_values: std::mem::take(group_values), - is_min: *is_min, - }; - } - - ScanState::Done { result } => { - return Ok(IOResult::Done(result.clone())); - } - } - } - } -} - -/// State machine for persisting Min/Max values to storage -#[derive(Debug)] -pub enum MinMaxPersistState { - Init { - min_max_deltas: MinMaxDeltas, - group_keys: Vec, - }, - ProcessGroup { - min_max_deltas: MinMaxDeltas, - group_keys: Vec, - group_idx: usize, - value_idx: usize, - }, - WriteValue { - min_max_deltas: MinMaxDeltas, - group_keys: Vec, - group_idx: usize, - value_idx: usize, - value: Value, - column_name: String, - weight: isize, - write_row: WriteRow, - }, - Done, -} - -impl MinMaxPersistState { - pub fn new(min_max_deltas: MinMaxDeltas) -> Self { - let group_keys: Vec = min_max_deltas.keys().cloned().collect(); - Self::Init { - min_max_deltas, - group_keys, - } - } - - pub fn persist_min_max( - &mut self, - operator_id: usize, - column_min_max: &HashMap, - cursors: &mut DbspStateCursors, - generate_group_rowid: impl Fn(&str) -> i64, - ) -> Result> { - loop { - match self { - MinMaxPersistState::Init { - min_max_deltas, - group_keys, - } => { - let min_max_deltas = std::mem::take(min_max_deltas); - let group_keys = std::mem::take(group_keys); - *self = MinMaxPersistState::ProcessGroup { - min_max_deltas, - group_keys, - group_idx: 0, - value_idx: 0, - }; - } - MinMaxPersistState::ProcessGroup { - min_max_deltas, - group_keys, - group_idx, - value_idx, - } => { - // Check if we're past all groups - if *group_idx >= group_keys.len() { - *self = MinMaxPersistState::Done; - continue; - } - - let group_key_str = &group_keys[*group_idx]; - let values = &min_max_deltas[group_key_str]; // This should always exist - - // Convert HashMap to Vec for indexed access - let values_vec: Vec<_> = values.iter().collect(); - - // Check if we have more values in current group - if *value_idx >= values_vec.len() { - *group_idx += 1; - *value_idx = 0; - // Continue to check if we're past all groups now - continue; - } - - // Process current value and extract what we need before taking ownership - 
let ((column_name, hashable_row), weight) = values_vec[*value_idx]; - let column_name = column_name.clone(); - let value = hashable_row.values[0].clone(); // Extract the Value from HashableRow - let weight = *weight; - - let min_max_deltas = std::mem::take(min_max_deltas); - let group_keys = std::mem::take(group_keys); - *self = MinMaxPersistState::WriteValue { - min_max_deltas, - group_keys, - group_idx: *group_idx, - value_idx: *value_idx, - column_name, - value, - weight, - write_row: WriteRow::new(), - }; - } - MinMaxPersistState::WriteValue { - min_max_deltas, - group_keys, - group_idx, - value_idx, - value, - column_name, - weight, - write_row, - } => { - // Should have exited in the previous state - assert!(*group_idx < group_keys.len()); - - let group_key_str = &group_keys[*group_idx]; - - // Get the column index from the pre-computed map - let column_info = column_min_max - .get(&*column_name) - .expect("Column should exist in column_min_max map"); - let column_index = column_info.index; - - // Build the key components for MinMax storage using new encoding - let storage_id = - generate_storage_id(operator_id, column_index, AGG_TYPE_MINMAX); - let zset_id = generate_group_rowid(group_key_str); - - // element_id is the actual value for Min/Max - let element_id_val = value.clone(); - - // Create index key - let index_key = vec![ - Value::Integer(storage_id), - Value::Integer(zset_id), - element_id_val.clone(), - ]; - - // Record values (operator_id, zset_id, element_id, unused_placeholder) - // For MIN/MAX, the element_id IS the value, so we use NULL for the 4th column - let record_values = vec![ - Value::Integer(storage_id), - Value::Integer(zset_id), - element_id_val.clone(), - Value::Null, // Placeholder - not used for MIN/MAX - ]; - - return_if_io!(write_row.write_row( - cursors, - index_key.clone(), - record_values, - *weight - )); - - // Move to next value - let min_max_deltas = std::mem::take(min_max_deltas); - let group_keys = std::mem::take(group_keys); - *self = MinMaxPersistState::ProcessGroup { - min_max_deltas, - group_keys, - group_idx: *group_idx, - value_idx: *value_idx + 1, - }; - } - MinMaxPersistState::Done => { - return Ok(IOResult::Done(())); - } - } - } - } -} From e2f0e372a1a1453916a3978e9d84cb661327a07c Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Wed, 17 Sep 2025 10:45:12 -0500 Subject: [PATCH 17/34] move the join operator to its own file. 
The code is becoming impossible to reason about with everything in operator.rs --- core/incremental/join_operator.rs | 787 ++++++++++++++++++++++++++++++ core/incremental/mod.rs | 1 + core/incremental/operator.rs | 787 +----------------------------- 3 files changed, 794 insertions(+), 781 deletions(-) create mode 100644 core/incremental/join_operator.rs diff --git a/core/incremental/join_operator.rs b/core/incremental/join_operator.rs new file mode 100644 index 000000000..f5ffb9b55 --- /dev/null +++ b/core/incremental/join_operator.rs @@ -0,0 +1,787 @@ +#![allow(dead_code)] + +use crate::incremental::dbsp::{Delta, DeltaPair, HashableRow}; +use crate::incremental::operator::{ + generate_storage_id, ComputationTracker, DbspStateCursors, EvalState, IncrementalOperator, +}; +use crate::incremental::persistence::WriteRow; +use crate::types::{IOResult, ImmutableRecord, SeekKey, SeekOp, SeekResult}; +use crate::{return_and_restore_if_io, return_if_io, Result, Value}; +use std::sync::{Arc, Mutex}; + +#[derive(Debug, Clone, PartialEq)] +pub enum JoinType { + Inner, + Left, + Right, + Full, + Cross, +} + +// Helper function to read the next row from the BTree for joins +fn read_next_join_row( + storage_id: i64, + join_key: &HashableRow, + last_element_id: i64, + cursors: &mut DbspStateCursors, +) -> Result>> { + // Build the index key: (storage_id, zset_id, element_id) + // zset_id is the hash of the join key + let zset_id = join_key.cached_hash() as i64; + + let index_key_values = vec![ + Value::Integer(storage_id), + Value::Integer(zset_id), + Value::Integer(last_element_id), + ]; + + let index_record = ImmutableRecord::from_values(&index_key_values, index_key_values.len()); + let seek_result = return_if_io!(cursors + .index_cursor + .seek(SeekKey::IndexKey(&index_record), SeekOp::GT)); + + if !matches!(seek_result, SeekResult::Found) { + return Ok(IOResult::Done(None)); + } + + // Check if we're still in the same (storage_id, zset_id) range + let current_record = return_if_io!(cursors.index_cursor.record()); + + // Extract all needed values from the record before dropping it + let (found_storage_id, found_zset_id, element_id) = if let Some(rec) = current_record { + let values = rec.get_values(); + + // Index has 4 values: storage_id, zset_id, element_id, rowid (appended by WriteRow) + if values.len() >= 3 { + let found_storage_id = match &values[0].to_owned() { + Value::Integer(id) => *id, + _ => return Ok(IOResult::Done(None)), + }; + let found_zset_id = match &values[1].to_owned() { + Value::Integer(id) => *id, + _ => return Ok(IOResult::Done(None)), + }; + let element_id = match &values[2].to_owned() { + Value::Integer(id) => *id, + _ => { + return Ok(IOResult::Done(None)); + } + }; + (found_storage_id, found_zset_id, element_id) + } else { + return Ok(IOResult::Done(None)); + } + } else { + return Ok(IOResult::Done(None)); + }; + + // Now we can safely check if we're in the right range + // If we've moved to a different storage_id or zset_id, we're done + if found_storage_id != storage_id || found_zset_id != zset_id { + return Ok(IOResult::Done(None)); + } + + // Now get the actual row from the table using the rowid from the index + let rowid = return_if_io!(cursors.index_cursor.rowid()); + if let Some(rowid) = rowid { + return_if_io!(cursors + .table_cursor + .seek(SeekKey::TableRowId(rowid), SeekOp::GE { eq_only: true })); + + let table_record = return_if_io!(cursors.table_cursor.record()); + if let Some(rec) = table_record { + let table_values = rec.get_values(); + // Table format: 
[storage_id, zset_id, element_id, value_blob, weight] + if table_values.len() >= 5 { + // Deserialize the row from the blob + let value_at_3 = table_values[3].to_owned(); + let blob = match value_at_3 { + Value::Blob(ref b) => b, + _ => return Ok(IOResult::Done(None)), + }; + + // The blob contains the serialized HashableRow + // For now, let's deserialize it simply + let row = deserialize_hashable_row(blob)?; + + let weight = match &table_values[4].to_owned() { + Value::Integer(w) => *w as isize, + _ => return Ok(IOResult::Done(None)), + }; + + return Ok(IOResult::Done(Some((element_id, row, weight)))); + } + } + } + Ok(IOResult::Done(None)) +} + +// Join-specific eval states +#[derive(Debug)] +pub enum JoinEvalState { + ProcessDeltaJoin { + deltas: DeltaPair, + output: Delta, + }, + ProcessLeftJoin { + deltas: DeltaPair, + output: Delta, + current_idx: usize, + last_row_scanned: i64, + }, + ProcessRightJoin { + deltas: DeltaPair, + output: Delta, + current_idx: usize, + last_row_scanned: i64, + }, + Done { + output: Delta, + }, +} + +impl JoinEvalState { + fn combine_rows( + left_row: &HashableRow, + left_weight: i64, + right_row: &HashableRow, + right_weight: i64, + output: &mut Delta, + ) { + // Combine the rows + let mut combined_values = left_row.values.clone(); + combined_values.extend(right_row.values.clone()); + // Use hash of the combined values as rowid to ensure uniqueness + let temp_row = HashableRow::new(0, combined_values.clone()); + let joined_rowid = temp_row.cached_hash() as i64; + let joined_row = HashableRow::new(joined_rowid, combined_values); + + // Add to output with combined weight + let combined_weight = left_weight * right_weight; + output.changes.push((joined_row, combined_weight as isize)); + } + + fn process_join_state( + &mut self, + cursors: &mut DbspStateCursors, + left_key_indices: &[usize], + right_key_indices: &[usize], + left_storage_id: i64, + right_storage_id: i64, + ) -> Result> { + loop { + match self { + JoinEvalState::ProcessDeltaJoin { deltas, output } => { + // Move to ProcessLeftJoin + *self = JoinEvalState::ProcessLeftJoin { + deltas: std::mem::take(deltas), + output: std::mem::take(output), + current_idx: 0, + last_row_scanned: i64::MIN, + }; + } + JoinEvalState::ProcessLeftJoin { + deltas, + output, + current_idx, + last_row_scanned, + } => { + if *current_idx >= deltas.left.changes.len() { + *self = JoinEvalState::ProcessRightJoin { + deltas: std::mem::take(deltas), + output: std::mem::take(output), + current_idx: 0, + last_row_scanned: i64::MIN, + }; + } else { + let (left_row, left_weight) = &deltas.left.changes[*current_idx]; + // Extract join key using provided indices + let key_values: Vec = left_key_indices + .iter() + .map(|&idx| left_row.values.get(idx).cloned().unwrap_or(Value::Null)) + .collect(); + let left_key = HashableRow::new(0, key_values); + + let next_row = return_if_io!(read_next_join_row( + right_storage_id, + &left_key, + *last_row_scanned, + cursors + )); + match next_row { + Some((element_id, right_row, right_weight)) => { + Self::combine_rows( + left_row, + (*left_weight) as i64, + &right_row, + right_weight as i64, + output, + ); + // Continue scanning with this left row + *self = JoinEvalState::ProcessLeftJoin { + deltas: std::mem::take(deltas), + output: std::mem::take(output), + current_idx: *current_idx, + last_row_scanned: element_id, + }; + } + None => { + // No more matches for this left row, move to next + *self = JoinEvalState::ProcessLeftJoin { + deltas: std::mem::take(deltas), + output: 
std::mem::take(output), + current_idx: *current_idx + 1, + last_row_scanned: i64::MIN, + }; + } + } + } + } + JoinEvalState::ProcessRightJoin { + deltas, + output, + current_idx, + last_row_scanned, + } => { + if *current_idx >= deltas.right.changes.len() { + *self = JoinEvalState::Done { + output: std::mem::take(output), + }; + } else { + let (right_row, right_weight) = &deltas.right.changes[*current_idx]; + // Extract join key using provided indices + let key_values: Vec = right_key_indices + .iter() + .map(|&idx| right_row.values.get(idx).cloned().unwrap_or(Value::Null)) + .collect(); + let right_key = HashableRow::new(0, key_values); + + let next_row = return_if_io!(read_next_join_row( + left_storage_id, + &right_key, + *last_row_scanned, + cursors + )); + match next_row { + Some((element_id, left_row, left_weight)) => { + Self::combine_rows( + &left_row, + left_weight as i64, + right_row, + (*right_weight) as i64, + output, + ); + // Continue scanning with this right row + *self = JoinEvalState::ProcessRightJoin { + deltas: std::mem::take(deltas), + output: std::mem::take(output), + current_idx: *current_idx, + last_row_scanned: element_id, + }; + } + None => { + // No more matches for this right row, move to next + *self = JoinEvalState::ProcessRightJoin { + deltas: std::mem::take(deltas), + output: std::mem::take(output), + current_idx: *current_idx + 1, + last_row_scanned: i64::MIN, + }; + } + } + } + } + JoinEvalState::Done { output } => { + return Ok(IOResult::Done(std::mem::take(output))); + } + } + } + } +} + +#[derive(Debug)] +enum JoinCommitState { + Idle, + Eval { + eval_state: EvalState, + }, + CommitLeftDelta { + deltas: DeltaPair, + output: Delta, + current_idx: usize, + write_row: WriteRow, + }, + CommitRightDelta { + deltas: DeltaPair, + output: Delta, + current_idx: usize, + write_row: WriteRow, + }, + Invalid, +} + +/// Join operator - performs incremental join between two relations +/// Implements the DBSP formula: δ(R ⋈ S) = (δR ⋈ S) ∪ (R ⋈ δS) ∪ (δR ⋈ δS) +#[derive(Debug)] +pub struct JoinOperator { + /// Unique operator ID for indexing in persistent storage + operator_id: usize, + /// Type of join to perform + join_type: JoinType, + /// Column indices for extracting join keys from left input + left_key_indices: Vec, + /// Column indices for extracting join keys from right input + right_key_indices: Vec, + /// Column names from left input + left_columns: Vec, + /// Column names from right input + right_columns: Vec, + /// Tracker for computation statistics + tracker: Option>>, + + commit_state: JoinCommitState, +} + +impl JoinOperator { + pub fn new( + operator_id: usize, + join_type: JoinType, + left_key_indices: Vec, + right_key_indices: Vec, + left_columns: Vec, + right_columns: Vec, + ) -> Result { + // Check for unsupported join types + match join_type { + JoinType::Left => { + return Err(crate::LimboError::ParseError( + "LEFT OUTER JOIN is not yet supported in incremental views".to_string(), + )) + } + JoinType::Right => { + return Err(crate::LimboError::ParseError( + "RIGHT OUTER JOIN is not yet supported in incremental views".to_string(), + )) + } + JoinType::Full => { + return Err(crate::LimboError::ParseError( + "FULL OUTER JOIN is not yet supported in incremental views".to_string(), + )) + } + JoinType::Cross => { + return Err(crate::LimboError::ParseError( + "CROSS JOIN is not yet supported in incremental views".to_string(), + )) + } + JoinType::Inner => {} // Inner join is supported + } + + Ok(Self { + operator_id, + join_type, + left_key_indices, + 
right_key_indices, + left_columns, + right_columns, + tracker: None, + commit_state: JoinCommitState::Idle, + }) + } + + /// Extract join key from row values using the specified indices + fn extract_join_key(&self, values: &[Value], indices: &[usize]) -> HashableRow { + let key_values: Vec = indices + .iter() + .map(|&idx| values.get(idx).cloned().unwrap_or(Value::Null)) + .collect(); + // Use 0 as a dummy rowid for join keys. They don't come from a table, + // so they don't need a rowid. Their key will be the hash of the row values. + HashableRow::new(0, key_values) + } + + /// Generate storage ID for left table + fn left_storage_id(&self) -> i64 { + // Use column_index=0 for left side + generate_storage_id(self.operator_id, 0, 0) + } + + /// Generate storage ID for right table + fn right_storage_id(&self) -> i64 { + // Use column_index=1 for right side + generate_storage_id(self.operator_id, 1, 0) + } + + /// SQL-compliant comparison for join keys + /// Returns true if keys match according to SQL semantics (NULL != NULL) + fn sql_keys_equal(left_key: &HashableRow, right_key: &HashableRow) -> bool { + if left_key.values.len() != right_key.values.len() { + return false; + } + + for (left_val, right_val) in left_key.values.iter().zip(right_key.values.iter()) { + // In SQL, NULL never equals NULL + if matches!(left_val, Value::Null) || matches!(right_val, Value::Null) { + return false; + } + + // For non-NULL values, use regular comparison + if left_val != right_val { + return false; + } + } + + true + } + + fn process_join_state( + &mut self, + state: &mut EvalState, + cursors: &mut DbspStateCursors, + ) -> Result> { + // Get the join state out of the enum + match state { + EvalState::Join(js) => js.process_join_state( + cursors, + &self.left_key_indices, + &self.right_key_indices, + self.left_storage_id(), + self.right_storage_id(), + ), + _ => panic!("process_join_state called with non-join state"), + } + } + + fn eval_internal( + &mut self, + state: &mut EvalState, + cursors: &mut DbspStateCursors, + ) -> Result> { + loop { + let loop_state = std::mem::replace(state, EvalState::Uninitialized); + match loop_state { + EvalState::Uninitialized => { + panic!("Cannot eval JoinOperator with Uninitialized state"); + } + EvalState::Init { deltas } => { + let mut output = Delta::new(); + + // Component 3: δR ⋈ δS (left delta join right delta) + for (left_row, left_weight) in &deltas.left.changes { + let left_key = + self.extract_join_key(&left_row.values, &self.left_key_indices); + + for (right_row, right_weight) in &deltas.right.changes { + let right_key = + self.extract_join_key(&right_row.values, &self.right_key_indices); + + if Self::sql_keys_equal(&left_key, &right_key) { + if let Some(tracker) = &self.tracker { + tracker.lock().unwrap().record_join_lookup(); + } + + // Combine the rows + let mut combined_values = left_row.values.clone(); + combined_values.extend(right_row.values.clone()); + + // Create the joined row with a unique rowid + // Use hash of the combined values to ensure uniqueness + let temp_row = HashableRow::new(0, combined_values.clone()); + let joined_rowid = temp_row.cached_hash() as i64; + let joined_row = + HashableRow::new(joined_rowid, combined_values.clone()); + + // Add to output with combined weight + let combined_weight = left_weight * right_weight; + output.changes.push((joined_row, combined_weight)); + } + } + } + + *state = EvalState::Join(Box::new(JoinEvalState::ProcessDeltaJoin { + deltas, + output, + })); + } + EvalState::Join(join_state) => { + *state = 
EvalState::Join(join_state); + let output = return_if_io!(self.process_join_state(state, cursors)); + return Ok(IOResult::Done(output)); + } + EvalState::Done => { + return Ok(IOResult::Done(Delta::new())); + } + EvalState::Aggregate(_) => { + panic!("Aggregate state should not appear in join operator"); + } + } + } + } +} + +// Helper to deserialize a HashableRow from a blob +fn deserialize_hashable_row(blob: &[u8]) -> Result { + // Simple deserialization - this needs to match how we serialize in commit + // Format: [rowid:8 bytes][num_values:4 bytes][values...] + if blob.len() < 12 { + return Err(crate::LimboError::InternalError( + "Invalid blob size".to_string(), + )); + } + + let rowid = i64::from_le_bytes(blob[0..8].try_into().unwrap()); + let num_values = u32::from_le_bytes(blob[8..12].try_into().unwrap()) as usize; + + let mut values = Vec::new(); + let mut offset = 12; + + for _ in 0..num_values { + if offset >= blob.len() { + break; + } + + let type_tag = blob[offset]; + offset += 1; + + match type_tag { + 0 => values.push(Value::Null), + 1 => { + if offset + 8 <= blob.len() { + let i = i64::from_le_bytes(blob[offset..offset + 8].try_into().unwrap()); + values.push(Value::Integer(i)); + offset += 8; + } + } + 2 => { + if offset + 8 <= blob.len() { + let f = f64::from_le_bytes(blob[offset..offset + 8].try_into().unwrap()); + values.push(Value::Float(f)); + offset += 8; + } + } + 3 => { + if offset + 4 <= blob.len() { + let len = + u32::from_le_bytes(blob[offset..offset + 4].try_into().unwrap()) as usize; + offset += 4; + if offset + len < blob.len() { + let text_bytes = blob[offset..offset + len].to_vec(); + offset += len; + let subtype = match blob[offset] { + 0 => crate::types::TextSubtype::Text, + 1 => crate::types::TextSubtype::Json, + _ => crate::types::TextSubtype::Text, + }; + offset += 1; + values.push(Value::Text(crate::types::Text { + value: text_bytes, + subtype, + })); + } + } + } + 4 => { + if offset + 4 <= blob.len() { + let len = + u32::from_le_bytes(blob[offset..offset + 4].try_into().unwrap()) as usize; + offset += 4; + if offset + len <= blob.len() { + let blob_data = blob[offset..offset + len].to_vec(); + values.push(Value::Blob(blob_data)); + offset += len; + } + } + } + _ => break, // Unknown type tag + } + } + + Ok(HashableRow::new(rowid, values)) +} + +// Helper to serialize a HashableRow to a blob +fn serialize_hashable_row(row: &HashableRow) -> Vec { + let mut blob = Vec::new(); + + // Write rowid + blob.extend_from_slice(&row.rowid.to_le_bytes()); + + // Write number of values + blob.extend_from_slice(&(row.values.len() as u32).to_le_bytes()); + + // Write each value directly with type tags (like AggregateState does) + for value in &row.values { + match value { + Value::Null => blob.push(0u8), + Value::Integer(i) => { + blob.push(1u8); + blob.extend_from_slice(&i.to_le_bytes()); + } + Value::Float(f) => { + blob.push(2u8); + blob.extend_from_slice(&f.to_le_bytes()); + } + Value::Text(s) => { + blob.push(3u8); + let bytes = &s.value; + blob.extend_from_slice(&(bytes.len() as u32).to_le_bytes()); + blob.extend_from_slice(bytes); + blob.push(s.subtype as u8); + } + Value::Blob(b) => { + blob.push(4u8); + blob.extend_from_slice(&(b.len() as u32).to_le_bytes()); + blob.extend_from_slice(b); + } + } + } + + blob +} + +impl IncrementalOperator for JoinOperator { + fn eval( + &mut self, + state: &mut EvalState, + cursors: &mut DbspStateCursors, + ) -> Result> { + let delta = return_if_io!(self.eval_internal(state, cursors)); + Ok(IOResult::Done(delta)) + } + + fn 
commit( + &mut self, + deltas: DeltaPair, + cursors: &mut DbspStateCursors, + ) -> Result> { + loop { + let mut state = std::mem::replace(&mut self.commit_state, JoinCommitState::Invalid); + match &mut state { + JoinCommitState::Idle => { + self.commit_state = JoinCommitState::Eval { + eval_state: deltas.clone().into(), + } + } + JoinCommitState::Eval { ref mut eval_state } => { + let output = return_and_restore_if_io!( + &mut self.commit_state, + state, + self.eval(eval_state, cursors) + ); + self.commit_state = JoinCommitState::CommitLeftDelta { + deltas: deltas.clone(), + output, + current_idx: 0, + write_row: WriteRow::new(), + }; + } + JoinCommitState::CommitLeftDelta { + deltas, + output, + current_idx, + ref mut write_row, + } => { + if *current_idx >= deltas.left.changes.len() { + self.commit_state = JoinCommitState::CommitRightDelta { + deltas: std::mem::take(deltas), + output: std::mem::take(output), + current_idx: 0, + write_row: WriteRow::new(), + }; + continue; + } + + let (row, weight) = &deltas.left.changes[*current_idx]; + // Extract join key from the left row + let join_key = self.extract_join_key(&row.values, &self.left_key_indices); + + // The index key: (storage_id, zset_id, element_id) + // zset_id is the hash of the join key, element_id is hash of the row + let storage_id = self.left_storage_id(); + let zset_id = join_key.cached_hash() as i64; + let element_id = row.cached_hash() as i64; + let index_key = vec![ + Value::Integer(storage_id), + Value::Integer(zset_id), + Value::Integer(element_id), + ]; + + // The record values: we'll store the serialized row as a blob + let row_blob = serialize_hashable_row(row); + let record_values = vec![ + Value::Integer(self.left_storage_id()), + Value::Integer(join_key.cached_hash() as i64), + Value::Integer(row.cached_hash() as i64), + Value::Blob(row_blob), + ]; + + // Use return_and_restore_if_io to handle I/O properly + return_and_restore_if_io!( + &mut self.commit_state, + state, + write_row.write_row(cursors, index_key, record_values, *weight) + ); + + self.commit_state = JoinCommitState::CommitLeftDelta { + deltas: deltas.clone(), + output: output.clone(), + current_idx: *current_idx + 1, + write_row: WriteRow::new(), + }; + } + JoinCommitState::CommitRightDelta { + deltas, + output, + current_idx, + ref mut write_row, + } => { + if *current_idx >= deltas.right.changes.len() { + // Reset to Idle state for next commit + self.commit_state = JoinCommitState::Idle; + return Ok(IOResult::Done(output.clone())); + } + + let (row, weight) = &deltas.right.changes[*current_idx]; + // Extract join key from the right row + let join_key = self.extract_join_key(&row.values, &self.right_key_indices); + + // The index key: (storage_id, zset_id, element_id) + let index_key = vec![ + Value::Integer(self.right_storage_id()), + Value::Integer(join_key.cached_hash() as i64), + Value::Integer(row.cached_hash() as i64), + ]; + + // The record values: we'll store the serialized row as a blob + let row_blob = serialize_hashable_row(row); + let record_values = vec![ + Value::Integer(self.right_storage_id()), + Value::Integer(join_key.cached_hash() as i64), + Value::Integer(row.cached_hash() as i64), + Value::Blob(row_blob), + ]; + + // Use return_and_restore_if_io to handle I/O properly + return_and_restore_if_io!( + &mut self.commit_state, + state, + write_row.write_row(cursors, index_key, record_values, *weight) + ); + + self.commit_state = JoinCommitState::CommitRightDelta { + deltas: std::mem::take(deltas), + output: std::mem::take(output), + 
current_idx: *current_idx + 1, + write_row: WriteRow::new(), + }; + } + JoinCommitState::Invalid => { + panic!("Invalid join commit state"); + } + } + } + } + + fn set_tracker(&mut self, tracker: Arc>) { + self.tracker = Some(tracker); + } +} diff --git a/core/incremental/mod.rs b/core/incremental/mod.rs index a747809d9..67eed60e2 100644 --- a/core/incremental/mod.rs +++ b/core/incremental/mod.rs @@ -5,6 +5,7 @@ pub mod dbsp; pub mod expr_compiler; pub mod filter_operator; pub mod input_operator; +pub mod join_operator; pub mod operator; pub mod persistence; pub mod project_operator; diff --git a/core/incremental/operator.rs b/core/incremental/operator.rs index 92b35d5f1..278ce4ef9 100644 --- a/core/incremental/operator.rs +++ b/core/incremental/operator.rs @@ -7,14 +7,14 @@ pub use crate::incremental::aggregate_operator::{ }; pub use crate::incremental::filter_operator::{FilterOperator, FilterPredicate}; pub use crate::incremental::input_operator::InputOperator; +pub use crate::incremental::join_operator::{JoinEvalState, JoinOperator, JoinType}; pub use crate::incremental::project_operator::{ProjectColumn, ProjectOperator}; -use crate::incremental::dbsp::{Delta, DeltaPair, HashableRow}; -use crate::incremental::persistence::WriteRow; +use crate::incremental::dbsp::{Delta, DeltaPair}; use crate::schema::{Index, IndexColumn}; use crate::storage::btree::BTreeCursor; -use crate::types::{IOResult, ImmutableRecord, SeekKey, SeekOp, SeekResult}; -use crate::{return_and_restore_if_io, return_if_io, Result, Value}; +use crate::types::IOResult; +use crate::Result; use std::fmt::Debug; use std::sync::{Arc, Mutex}; @@ -85,287 +85,6 @@ pub fn generate_storage_id(operator_id: usize, column_index: usize, op_type: u8) ((operator_id as i64) << 16) | ((column_index as i64) << 2) | (op_type as i64) } -// Helper function to read the next row from the BTree for joins -fn read_next_join_row( - storage_id: i64, - join_key: &HashableRow, - last_element_id: i64, - cursors: &mut DbspStateCursors, -) -> Result>> { - // Build the index key: (storage_id, zset_id, element_id) - // zset_id is the hash of the join key - let zset_id = join_key.cached_hash() as i64; - - let index_key_values = vec![ - Value::Integer(storage_id), - Value::Integer(zset_id), - Value::Integer(last_element_id), - ]; - - let index_record = ImmutableRecord::from_values(&index_key_values, index_key_values.len()); - let seek_result = return_if_io!(cursors - .index_cursor - .seek(SeekKey::IndexKey(&index_record), SeekOp::GT)); - - if !matches!(seek_result, SeekResult::Found) { - return Ok(IOResult::Done(None)); - } - - // Check if we're still in the same (storage_id, zset_id) range - let current_record = return_if_io!(cursors.index_cursor.record()); - - // Extract all needed values from the record before dropping it - let (found_storage_id, found_zset_id, element_id) = if let Some(rec) = current_record { - let values = rec.get_values(); - - // Index has 4 values: storage_id, zset_id, element_id, rowid (appended by WriteRow) - if values.len() >= 3 { - let found_storage_id = match &values[0].to_owned() { - Value::Integer(id) => *id, - _ => return Ok(IOResult::Done(None)), - }; - let found_zset_id = match &values[1].to_owned() { - Value::Integer(id) => *id, - _ => return Ok(IOResult::Done(None)), - }; - let element_id = match &values[2].to_owned() { - Value::Integer(id) => *id, - _ => { - return Ok(IOResult::Done(None)); - } - }; - (found_storage_id, found_zset_id, element_id) - } else { - return Ok(IOResult::Done(None)); - } - } else { - return 
Ok(IOResult::Done(None)); - }; - - // Now we can safely check if we're in the right range - // If we've moved to a different storage_id or zset_id, we're done - if found_storage_id != storage_id || found_zset_id != zset_id { - return Ok(IOResult::Done(None)); - } - - // Now get the actual row from the table using the rowid from the index - let rowid = return_if_io!(cursors.index_cursor.rowid()); - if let Some(rowid) = rowid { - return_if_io!(cursors - .table_cursor - .seek(SeekKey::TableRowId(rowid), SeekOp::GE { eq_only: true })); - - let table_record = return_if_io!(cursors.table_cursor.record()); - if let Some(rec) = table_record { - let table_values = rec.get_values(); - // Table format: [storage_id, zset_id, element_id, value_blob, weight] - if table_values.len() >= 5 { - // Deserialize the row from the blob - let value_at_3 = table_values[3].to_owned(); - let blob = match value_at_3 { - Value::Blob(ref b) => b, - _ => return Ok(IOResult::Done(None)), - }; - - // The blob contains the serialized HashableRow - // For now, let's deserialize it simply - let row = deserialize_hashable_row(blob)?; - - let weight = match &table_values[4].to_owned() { - Value::Integer(w) => *w as isize, - _ => return Ok(IOResult::Done(None)), - }; - - return Ok(IOResult::Done(Some((element_id, row, weight)))); - } - } - } - Ok(IOResult::Done(None)) -} - -// Join-specific eval states -#[derive(Debug)] -pub enum JoinEvalState { - ProcessDeltaJoin { - deltas: DeltaPair, - output: Delta, - }, - ProcessLeftJoin { - deltas: DeltaPair, - output: Delta, - current_idx: usize, - last_row_scanned: i64, - }, - ProcessRightJoin { - deltas: DeltaPair, - output: Delta, - current_idx: usize, - last_row_scanned: i64, - }, - Done { - output: Delta, - }, -} - -impl JoinEvalState { - fn combine_rows( - left_row: &HashableRow, - left_weight: i64, - right_row: &HashableRow, - right_weight: i64, - output: &mut Delta, - ) { - // Combine the rows - let mut combined_values = left_row.values.clone(); - combined_values.extend(right_row.values.clone()); - // Use hash of the combined values as rowid to ensure uniqueness - let temp_row = HashableRow::new(0, combined_values.clone()); - let joined_rowid = temp_row.cached_hash() as i64; - let joined_row = HashableRow::new(joined_rowid, combined_values); - - // Add to output with combined weight - let combined_weight = left_weight * right_weight; - output.changes.push((joined_row, combined_weight as isize)); - } - - fn process_join_state( - &mut self, - cursors: &mut DbspStateCursors, - left_key_indices: &[usize], - right_key_indices: &[usize], - left_storage_id: i64, - right_storage_id: i64, - ) -> Result> { - loop { - match self { - JoinEvalState::ProcessDeltaJoin { deltas, output } => { - // Move to ProcessLeftJoin - *self = JoinEvalState::ProcessLeftJoin { - deltas: std::mem::take(deltas), - output: std::mem::take(output), - current_idx: 0, - last_row_scanned: i64::MIN, - }; - } - JoinEvalState::ProcessLeftJoin { - deltas, - output, - current_idx, - last_row_scanned, - } => { - if *current_idx >= deltas.left.changes.len() { - *self = JoinEvalState::ProcessRightJoin { - deltas: std::mem::take(deltas), - output: std::mem::take(output), - current_idx: 0, - last_row_scanned: i64::MIN, - }; - } else { - let (left_row, left_weight) = &deltas.left.changes[*current_idx]; - // Extract join key using provided indices - let key_values: Vec = left_key_indices - .iter() - .map(|&idx| left_row.values.get(idx).cloned().unwrap_or(Value::Null)) - .collect(); - let left_key = HashableRow::new(0, 
key_values); - - let next_row = return_if_io!(read_next_join_row( - right_storage_id, - &left_key, - *last_row_scanned, - cursors - )); - match next_row { - Some((element_id, right_row, right_weight)) => { - Self::combine_rows( - left_row, - (*left_weight) as i64, - &right_row, - right_weight as i64, - output, - ); - // Continue scanning with this left row - *self = JoinEvalState::ProcessLeftJoin { - deltas: std::mem::take(deltas), - output: std::mem::take(output), - current_idx: *current_idx, - last_row_scanned: element_id, - }; - } - None => { - // No more matches for this left row, move to next - *self = JoinEvalState::ProcessLeftJoin { - deltas: std::mem::take(deltas), - output: std::mem::take(output), - current_idx: *current_idx + 1, - last_row_scanned: i64::MIN, - }; - } - } - } - } - JoinEvalState::ProcessRightJoin { - deltas, - output, - current_idx, - last_row_scanned, - } => { - if *current_idx >= deltas.right.changes.len() { - *self = JoinEvalState::Done { - output: std::mem::take(output), - }; - } else { - let (right_row, right_weight) = &deltas.right.changes[*current_idx]; - // Extract join key using provided indices - let key_values: Vec = right_key_indices - .iter() - .map(|&idx| right_row.values.get(idx).cloned().unwrap_or(Value::Null)) - .collect(); - let right_key = HashableRow::new(0, key_values); - - let next_row = return_if_io!(read_next_join_row( - left_storage_id, - &right_key, - *last_row_scanned, - cursors - )); - match next_row { - Some((element_id, left_row, left_weight)) => { - Self::combine_rows( - &left_row, - left_weight as i64, - right_row, - (*right_weight) as i64, - output, - ); - // Continue scanning with this right row - *self = JoinEvalState::ProcessRightJoin { - deltas: std::mem::take(deltas), - output: std::mem::take(output), - current_idx: *current_idx, - last_row_scanned: element_id, - }; - } - None => { - // No more matches for this right row, move to next - *self = JoinEvalState::ProcessRightJoin { - deltas: std::mem::take(deltas), - output: std::mem::take(output), - current_idx: *current_idx + 1, - last_row_scanned: i64::MIN, - }; - } - } - } - } - JoinEvalState::Done { output } => { - return Ok(IOResult::Done(std::mem::take(output))); - } - } - } - } -} - // Generic eval state that delegates to operator-specific states #[derive(Debug)] pub enum EvalState { @@ -462,6 +181,7 @@ impl ComputationTracker { #[cfg(test)] mod dbsp_types_tests { use super::*; + use crate::Value; #[test] fn test_hashable_row_delta_operations() { @@ -554,15 +274,6 @@ pub enum QueryOperator { }, } -#[derive(Debug, Clone)] -pub enum JoinType { - Inner, - Left, - Right, - Full, - Cross, -} - /// Operator DAG (Directed Acyclic Graph) /// Base trait for incremental operators pub trait IncrementalOperator: Debug { @@ -596,497 +307,11 @@ pub trait IncrementalOperator: Debug { fn set_tracker(&mut self, tracker: Arc>); } -#[derive(Debug)] -enum JoinCommitState { - Idle, - Eval { - eval_state: EvalState, - }, - CommitLeftDelta { - deltas: DeltaPair, - output: Delta, - current_idx: usize, - write_row: WriteRow, - }, - CommitRightDelta { - deltas: DeltaPair, - output: Delta, - current_idx: usize, - write_row: WriteRow, - }, - Invalid, -} - -/// Join operator - performs incremental join between two relations -/// Implements the DBSP formula: δ(R ⋈ S) = (δR ⋈ S) ∪ (R ⋈ δS) ∪ (δR ⋈ δS) -#[derive(Debug)] -pub struct JoinOperator { - /// Unique operator ID for indexing in persistent storage - operator_id: usize, - /// Type of join to perform - join_type: JoinType, - /// Column indices for 
extracting join keys from left input - left_key_indices: Vec, - /// Column indices for extracting join keys from right input - right_key_indices: Vec, - /// Column names from left input - left_columns: Vec, - /// Column names from right input - right_columns: Vec, - /// Tracker for computation statistics - tracker: Option>>, - - commit_state: JoinCommitState, -} - -impl JoinOperator { - pub fn new( - operator_id: usize, - join_type: JoinType, - left_key_indices: Vec, - right_key_indices: Vec, - left_columns: Vec, - right_columns: Vec, - ) -> Result { - // Check for unsupported join types - match join_type { - JoinType::Left => { - return Err(crate::LimboError::ParseError( - "LEFT OUTER JOIN is not yet supported in incremental views".to_string(), - )) - } - JoinType::Right => { - return Err(crate::LimboError::ParseError( - "RIGHT OUTER JOIN is not yet supported in incremental views".to_string(), - )) - } - JoinType::Full => { - return Err(crate::LimboError::ParseError( - "FULL OUTER JOIN is not yet supported in incremental views".to_string(), - )) - } - JoinType::Cross => { - return Err(crate::LimboError::ParseError( - "CROSS JOIN is not yet supported in incremental views".to_string(), - )) - } - JoinType::Inner => {} // Inner join is supported - } - - Ok(Self { - operator_id, - join_type, - left_key_indices, - right_key_indices, - left_columns, - right_columns, - tracker: None, - commit_state: JoinCommitState::Idle, - }) - } - - /// Extract join key from row values using the specified indices - fn extract_join_key(&self, values: &[Value], indices: &[usize]) -> HashableRow { - let key_values: Vec = indices - .iter() - .map(|&idx| values.get(idx).cloned().unwrap_or(Value::Null)) - .collect(); - // Use 0 as a dummy rowid for join keys. They don't come from a table, - // so they don't need a rowid. Their key will be the hash of the row values. 
- HashableRow::new(0, key_values) - } - - /// Generate storage ID for left table - fn left_storage_id(&self) -> i64 { - // Use column_index=0 for left side - generate_storage_id(self.operator_id, 0, 0) - } - - /// Generate storage ID for right table - fn right_storage_id(&self) -> i64 { - // Use column_index=1 for right side - generate_storage_id(self.operator_id, 1, 0) - } - - /// SQL-compliant comparison for join keys - /// Returns true if keys match according to SQL semantics (NULL != NULL) - fn sql_keys_equal(left_key: &HashableRow, right_key: &HashableRow) -> bool { - if left_key.values.len() != right_key.values.len() { - return false; - } - - for (left_val, right_val) in left_key.values.iter().zip(right_key.values.iter()) { - // In SQL, NULL never equals NULL - if matches!(left_val, Value::Null) || matches!(right_val, Value::Null) { - return false; - } - - // For non-NULL values, use regular comparison - if left_val != right_val { - return false; - } - } - - true - } - - fn process_join_state( - &mut self, - state: &mut EvalState, - cursors: &mut DbspStateCursors, - ) -> Result> { - // Get the join state out of the enum - match state { - EvalState::Join(js) => js.process_join_state( - cursors, - &self.left_key_indices, - &self.right_key_indices, - self.left_storage_id(), - self.right_storage_id(), - ), - _ => panic!("process_join_state called with non-join state"), - } - } - - fn eval_internal( - &mut self, - state: &mut EvalState, - cursors: &mut DbspStateCursors, - ) -> Result> { - loop { - let loop_state = std::mem::replace(state, EvalState::Uninitialized); - match loop_state { - EvalState::Uninitialized => { - panic!("Cannot eval JoinOperator with Uninitialized state"); - } - EvalState::Init { deltas } => { - let mut output = Delta::new(); - - // Component 3: δR ⋈ δS (left delta join right delta) - for (left_row, left_weight) in &deltas.left.changes { - let left_key = - self.extract_join_key(&left_row.values, &self.left_key_indices); - - for (right_row, right_weight) in &deltas.right.changes { - let right_key = - self.extract_join_key(&right_row.values, &self.right_key_indices); - - if Self::sql_keys_equal(&left_key, &right_key) { - if let Some(tracker) = &self.tracker { - tracker.lock().unwrap().record_join_lookup(); - } - - // Combine the rows - let mut combined_values = left_row.values.clone(); - combined_values.extend(right_row.values.clone()); - - // Create the joined row with a unique rowid - // Use hash of the combined values to ensure uniqueness - let temp_row = HashableRow::new(0, combined_values.clone()); - let joined_rowid = temp_row.cached_hash() as i64; - let joined_row = - HashableRow::new(joined_rowid, combined_values.clone()); - - // Add to output with combined weight - let combined_weight = left_weight * right_weight; - output.changes.push((joined_row, combined_weight)); - } - } - } - - *state = EvalState::Join(Box::new(JoinEvalState::ProcessDeltaJoin { - deltas, - output, - })); - } - EvalState::Join(join_state) => { - *state = EvalState::Join(join_state); - let output = return_if_io!(self.process_join_state(state, cursors)); - return Ok(IOResult::Done(output)); - } - EvalState::Done => { - return Ok(IOResult::Done(Delta::new())); - } - EvalState::Aggregate(_) => { - panic!("Aggregate state should not appear in join operator"); - } - } - } - } -} - -// Helper to deserialize a HashableRow from a blob -fn deserialize_hashable_row(blob: &[u8]) -> Result { - // Simple deserialization - this needs to match how we serialize in commit - // Format: [rowid:8 
bytes][num_values:4 bytes][values...] - if blob.len() < 12 { - return Err(crate::LimboError::InternalError( - "Invalid blob size".to_string(), - )); - } - - let rowid = i64::from_le_bytes(blob[0..8].try_into().unwrap()); - let num_values = u32::from_le_bytes(blob[8..12].try_into().unwrap()) as usize; - - let mut values = Vec::new(); - let mut offset = 12; - - for _ in 0..num_values { - if offset >= blob.len() { - break; - } - - let type_tag = blob[offset]; - offset += 1; - - match type_tag { - 0 => values.push(Value::Null), - 1 => { - if offset + 8 <= blob.len() { - let i = i64::from_le_bytes(blob[offset..offset + 8].try_into().unwrap()); - values.push(Value::Integer(i)); - offset += 8; - } - } - 2 => { - if offset + 8 <= blob.len() { - let f = f64::from_le_bytes(blob[offset..offset + 8].try_into().unwrap()); - values.push(Value::Float(f)); - offset += 8; - } - } - 3 => { - if offset + 4 <= blob.len() { - let len = - u32::from_le_bytes(blob[offset..offset + 4].try_into().unwrap()) as usize; - offset += 4; - if offset + len < blob.len() { - let text_bytes = blob[offset..offset + len].to_vec(); - offset += len; - let subtype = match blob[offset] { - 0 => crate::types::TextSubtype::Text, - 1 => crate::types::TextSubtype::Json, - _ => crate::types::TextSubtype::Text, - }; - offset += 1; - values.push(Value::Text(crate::types::Text { - value: text_bytes, - subtype, - })); - } - } - } - 4 => { - if offset + 4 <= blob.len() { - let len = - u32::from_le_bytes(blob[offset..offset + 4].try_into().unwrap()) as usize; - offset += 4; - if offset + len <= blob.len() { - let blob_data = blob[offset..offset + len].to_vec(); - values.push(Value::Blob(blob_data)); - offset += len; - } - } - } - _ => break, // Unknown type tag - } - } - - Ok(HashableRow::new(rowid, values)) -} - -// Helper to serialize a HashableRow to a blob -fn serialize_hashable_row(row: &HashableRow) -> Vec { - let mut blob = Vec::new(); - - // Write rowid - blob.extend_from_slice(&row.rowid.to_le_bytes()); - - // Write number of values - blob.extend_from_slice(&(row.values.len() as u32).to_le_bytes()); - - // Write each value directly with type tags (like AggregateState does) - for value in &row.values { - match value { - Value::Null => blob.push(0u8), - Value::Integer(i) => { - blob.push(1u8); - blob.extend_from_slice(&i.to_le_bytes()); - } - Value::Float(f) => { - blob.push(2u8); - blob.extend_from_slice(&f.to_le_bytes()); - } - Value::Text(s) => { - blob.push(3u8); - let bytes = &s.value; - blob.extend_from_slice(&(bytes.len() as u32).to_le_bytes()); - blob.extend_from_slice(bytes); - blob.push(s.subtype as u8); - } - Value::Blob(b) => { - blob.push(4u8); - blob.extend_from_slice(&(b.len() as u32).to_le_bytes()); - blob.extend_from_slice(b); - } - } - } - - blob -} - -impl IncrementalOperator for JoinOperator { - fn eval( - &mut self, - state: &mut EvalState, - cursors: &mut DbspStateCursors, - ) -> Result> { - let delta = return_if_io!(self.eval_internal(state, cursors)); - Ok(IOResult::Done(delta)) - } - - fn commit( - &mut self, - deltas: DeltaPair, - cursors: &mut DbspStateCursors, - ) -> Result> { - loop { - let mut state = std::mem::replace(&mut self.commit_state, JoinCommitState::Invalid); - match &mut state { - JoinCommitState::Idle => { - self.commit_state = JoinCommitState::Eval { - eval_state: deltas.clone().into(), - } - } - JoinCommitState::Eval { ref mut eval_state } => { - let output = return_and_restore_if_io!( - &mut self.commit_state, - state, - self.eval(eval_state, cursors) - ); - self.commit_state = 
JoinCommitState::CommitLeftDelta { - deltas: deltas.clone(), - output, - current_idx: 0, - write_row: WriteRow::new(), - }; - } - JoinCommitState::CommitLeftDelta { - deltas, - output, - current_idx, - ref mut write_row, - } => { - if *current_idx >= deltas.left.changes.len() { - self.commit_state = JoinCommitState::CommitRightDelta { - deltas: std::mem::take(deltas), - output: std::mem::take(output), - current_idx: 0, - write_row: WriteRow::new(), - }; - continue; - } - - let (row, weight) = &deltas.left.changes[*current_idx]; - // Extract join key from the left row - let join_key = self.extract_join_key(&row.values, &self.left_key_indices); - - // The index key: (storage_id, zset_id, element_id) - // zset_id is the hash of the join key, element_id is hash of the row - let storage_id = self.left_storage_id(); - let zset_id = join_key.cached_hash() as i64; - let element_id = row.cached_hash() as i64; - let index_key = vec![ - Value::Integer(storage_id), - Value::Integer(zset_id), - Value::Integer(element_id), - ]; - - // The record values: we'll store the serialized row as a blob - let row_blob = serialize_hashable_row(row); - let record_values = vec![ - Value::Integer(self.left_storage_id()), - Value::Integer(join_key.cached_hash() as i64), - Value::Integer(row.cached_hash() as i64), - Value::Blob(row_blob), - ]; - - // Use return_and_restore_if_io to handle I/O properly - return_and_restore_if_io!( - &mut self.commit_state, - state, - write_row.write_row(cursors, index_key, record_values, *weight) - ); - - self.commit_state = JoinCommitState::CommitLeftDelta { - deltas: deltas.clone(), - output: output.clone(), - current_idx: *current_idx + 1, - write_row: WriteRow::new(), - }; - } - JoinCommitState::CommitRightDelta { - deltas, - output, - current_idx, - ref mut write_row, - } => { - if *current_idx >= deltas.right.changes.len() { - // Reset to Idle state for next commit - self.commit_state = JoinCommitState::Idle; - return Ok(IOResult::Done(output.clone())); - } - - let (row, weight) = &deltas.right.changes[*current_idx]; - // Extract join key from the right row - let join_key = self.extract_join_key(&row.values, &self.right_key_indices); - - // The index key: (storage_id, zset_id, element_id) - let index_key = vec![ - Value::Integer(self.right_storage_id()), - Value::Integer(join_key.cached_hash() as i64), - Value::Integer(row.cached_hash() as i64), - ]; - - // The record values: we'll store the serialized row as a blob - let row_blob = serialize_hashable_row(row); - let record_values = vec![ - Value::Integer(self.right_storage_id()), - Value::Integer(join_key.cached_hash() as i64), - Value::Integer(row.cached_hash() as i64), - Value::Blob(row_blob), - ]; - - // Use return_and_restore_if_io to handle I/O properly - return_and_restore_if_io!( - &mut self.commit_state, - state, - write_row.write_row(cursors, index_key, record_values, *weight) - ); - - self.commit_state = JoinCommitState::CommitRightDelta { - deltas: std::mem::take(deltas), - output: std::mem::take(output), - current_idx: *current_idx + 1, - write_row: WriteRow::new(), - }; - } - JoinCommitState::Invalid => { - panic!("Invalid join commit state"); - } - } - } - } - - fn set_tracker(&mut self, tracker: Arc>) { - self.tracker = Some(tracker); - } -} - #[cfg(test)] mod tests { use super::*; use crate::incremental::aggregate_operator::AGG_TYPE_REGULAR; + use crate::incremental::dbsp::HashableRow; use crate::storage::pager::CreateBTreeFlags; use crate::types::Text; use crate::util::IOExt; From 
9f3d119a5ac87cf101bbcbd5528627a9b9f32445 Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Wed, 17 Sep 2025 15:02:09 -0500 Subject: [PATCH 18/34] move hashable row tests to dbsp.rs The operator.rs file was so huge, that we didn't even notice there was a test block in the middle of the file that was testing things that were long moved to dbsp.rs (the HashableRow). Move the tests there now. --- core/incremental/dbsp.rs | 53 ++++++++++++++++++++++++++++++++ core/incremental/operator.rs | 59 ------------------------------------ 2 files changed, 53 insertions(+), 59 deletions(-) diff --git a/core/incremental/dbsp.rs b/core/incremental/dbsp.rs index d4862b70a..eeab315d3 100644 --- a/core/incremental/dbsp.rs +++ b/core/incremental/dbsp.rs @@ -404,4 +404,57 @@ mod tests { let weight = zset.iter().find(|(k, _)| **k == 1).map(|(_, w)| w); assert_eq!(weight, Some(1)); } + + #[test] + fn test_hashable_row_delta_operations() { + let mut delta = Delta::new(); + + // Test INSERT + delta.insert(1, vec![Value::Integer(1), Value::Integer(100)]); + assert_eq!(delta.len(), 1); + + // Test UPDATE (DELETE + INSERT) - order matters! + delta.delete(1, vec![Value::Integer(1), Value::Integer(100)]); + delta.insert(1, vec![Value::Integer(1), Value::Integer(200)]); + assert_eq!(delta.len(), 3); // Should have 3 operations before consolidation + + // Verify order is preserved + let ops: Vec<_> = delta.changes.iter().collect(); + assert_eq!(ops[0].1, 1); // First insert + assert_eq!(ops[1].1, -1); // Delete + assert_eq!(ops[2].1, 1); // Second insert + + // Test consolidation + delta.consolidate(); + // After consolidation, the first insert and delete should cancel out + // leaving only the second insert + assert_eq!(delta.len(), 1); + + let final_row = &delta.changes[0]; + assert_eq!(final_row.0.rowid, 1); + assert_eq!( + final_row.0.values, + vec![Value::Integer(1), Value::Integer(200)] + ); + assert_eq!(final_row.1, 1); + } + + #[test] + fn test_duplicate_row_consolidation() { + let mut delta = Delta::new(); + + // Insert same row twice + delta.insert(2, vec![Value::Integer(2), Value::Integer(300)]); + delta.insert(2, vec![Value::Integer(2), Value::Integer(300)]); + + assert_eq!(delta.len(), 2); + + delta.consolidate(); + assert_eq!(delta.len(), 1); + + // Weight should be 2 (sum of both inserts) + let final_row = &delta.changes[0]; + assert_eq!(final_row.0.rowid, 2); + assert_eq!(final_row.1, 2); + } } diff --git a/core/incremental/operator.rs b/core/incremental/operator.rs index 278ce4ef9..54cd7e0a0 100644 --- a/core/incremental/operator.rs +++ b/core/incremental/operator.rs @@ -178,65 +178,6 @@ impl ComputationTracker { } } -#[cfg(test)] -mod dbsp_types_tests { - use super::*; - use crate::Value; - - #[test] - fn test_hashable_row_delta_operations() { - let mut delta = Delta::new(); - - // Test INSERT - delta.insert(1, vec![Value::Integer(1), Value::Integer(100)]); - assert_eq!(delta.len(), 1); - - // Test UPDATE (DELETE + INSERT) - order matters! 
- delta.delete(1, vec![Value::Integer(1), Value::Integer(100)]); - delta.insert(1, vec![Value::Integer(1), Value::Integer(200)]); - assert_eq!(delta.len(), 3); // Should have 3 operations before consolidation - - // Verify order is preserved - let ops: Vec<_> = delta.changes.iter().collect(); - assert_eq!(ops[0].1, 1); // First insert - assert_eq!(ops[1].1, -1); // Delete - assert_eq!(ops[2].1, 1); // Second insert - - // Test consolidation - delta.consolidate(); - // After consolidation, the first insert and delete should cancel out - // leaving only the second insert - assert_eq!(delta.len(), 1); - - let final_row = &delta.changes[0]; - assert_eq!(final_row.0.rowid, 1); - assert_eq!( - final_row.0.values, - vec![Value::Integer(1), Value::Integer(200)] - ); - assert_eq!(final_row.1, 1); - } - - #[test] - fn test_duplicate_row_consolidation() { - let mut delta = Delta::new(); - - // Insert same row twice - delta.insert(2, vec![Value::Integer(2), Value::Integer(300)]); - delta.insert(2, vec![Value::Integer(2), Value::Integer(300)]); - - assert_eq!(delta.len(), 2); - - delta.consolidate(); - assert_eq!(delta.len(), 1); - - // Weight should be 2 (sum of both inserts) - let final_row = &delta.changes[0]; - assert_eq!(final_row.0.rowid, 2); - assert_eq!(final_row.1, 2); - } -} - /// Represents an operator in the dataflow graph #[derive(Debug, Clone)] pub enum QueryOperator { From f149b40e75685b352fb4df35b548e05ff19321bc Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Tue, 16 Sep 2025 16:00:15 -0500 Subject: [PATCH 19/34] Implement JOINs in the DBSP circuit This PR improves the DBSP circuit so that it handles the JOIN operator. The JOIN operator exposes a weakness of our current model: we usually pass a list of columns between operators, and find the right column by name when needed. But with JOINs, many tables can have the same columns. The operators will then find the wrong column (same name, different table), and produce incorrect results. To fix this, we must do two things: 1) Change the Logical Plan. It needs to track table provenance. 2) Fix the aggregators: it needs to operate on indexes, not names. For the aggregators, note that table provenance is the wrong abstraction. The aggregator is likely working with a logical table that is the result of previous nodes in the circuit. So we just need to be able to tell it which index in the column array it should use. 
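To make the column-provenance problem concrete, here is a minimal, standalone sketch (not part of this patch) of why name-based column lookup misbehaves after a join and why handing the aggregator a pre-resolved index avoids it. The toy `Value` enum and the users/orders column layout below are assumptions for illustration only, not the crate's actual types or schema.

    // Illustrative sketch only: name-based vs index-based column resolution.
    #[derive(Debug, Clone, PartialEq)]
    enum Value {
        Integer(i64),
        Text(String),
    }

    fn main() {
        // After `users JOIN orders`, both inputs contribute a column named "id".
        // Hypothetical joined layout: users.id, users.name, orders.id, orders.amount
        let joined_columns = ["id", "name", "id", "amount"];
        let joined_row = [
            Value::Integer(1),
            Value::Text("alice".into()),
            Value::Integer(42),
            Value::Integer(100),
        ];

        // Name-based lookup silently resolves to the first "id" (users.id),
        // even if the aggregate was written against orders.id.
        let by_name = joined_columns.iter().position(|c| *c == "id").unwrap();
        assert_eq!(joined_row[by_name], Value::Integer(1));

        // Index-based lookup is unambiguous: the planner resolves the column
        // once, with provenance, and the operator only ever sees the index.
        let orders_id_idx = 2;
        assert_eq!(joined_row[orders_id_idx], Value::Integer(42));
    }

In other words, the planner keeps the name/provenance bookkeeping, and operators downstream of a join work purely positionally, which is why the aggregate functions above now carry a `usize` column index instead of a `String` column name.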
--- core/incremental/aggregate_operator.rs | 273 +++-- core/incremental/compiler.rs | 1256 +++++++++++++++++++++++- core/incremental/operator.rs | 130 +-- core/translate/logical.rs | 276 ++++-- 4 files changed, 1625 insertions(+), 310 deletions(-) diff --git a/core/incremental/aggregate_operator.rs b/core/incremental/aggregate_operator.rs index f4c8ece0a..9f25a84f5 100644 --- a/core/incremental/aggregate_operator.rs +++ b/core/incremental/aggregate_operator.rs @@ -19,20 +19,20 @@ pub const AGG_TYPE_MINMAX: u8 = 0b01; // MIN/MAX (BTree ordering gives both) #[derive(Debug, Clone, PartialEq)] pub enum AggregateFunction { Count, - Sum(String), - Avg(String), - Min(String), - Max(String), + Sum(usize), // Column index + Avg(usize), // Column index + Min(usize), // Column index + Max(usize), // Column index } impl Display for AggregateFunction { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { AggregateFunction::Count => write!(f, "COUNT(*)"), - AggregateFunction::Sum(col) => write!(f, "SUM({col})"), - AggregateFunction::Avg(col) => write!(f, "AVG({col})"), - AggregateFunction::Min(col) => write!(f, "MIN({col})"), - AggregateFunction::Max(col) => write!(f, "MAX({col})"), + AggregateFunction::Sum(idx) => write!(f, "SUM(col{idx})"), + AggregateFunction::Avg(idx) => write!(f, "AVG(col{idx})"), + AggregateFunction::Min(idx) => write!(f, "MIN(col{idx})"), + AggregateFunction::Max(idx) => write!(f, "MAX(col{idx})"), } } } @@ -48,16 +48,16 @@ impl AggregateFunction { /// Returns None if the function is not a supported aggregate pub fn from_sql_function( func: &crate::function::Func, - input_column: Option, + input_column_idx: Option, ) -> Option { match func { Func::Agg(agg_func) => { match agg_func { AggFunc::Count | AggFunc::Count0 => Some(AggregateFunction::Count), - AggFunc::Sum => input_column.map(AggregateFunction::Sum), - AggFunc::Avg => input_column.map(AggregateFunction::Avg), - AggFunc::Min => input_column.map(AggregateFunction::Min), - AggFunc::Max => input_column.map(AggregateFunction::Max), + AggFunc::Sum => input_column_idx.map(AggregateFunction::Sum), + AggFunc::Avg => input_column_idx.map(AggregateFunction::Avg), + AggFunc::Min => input_column_idx.map(AggregateFunction::Min), + AggFunc::Max => input_column_idx.map(AggregateFunction::Max), _ => None, // Other aggregate functions not yet supported in DBSP } } @@ -115,8 +115,8 @@ pub fn deserialize_value(blob: &[u8]) -> Option<(Value, usize)> { // group_key_str -> (group_key, state) type ComputedStates = HashMap, AggregateState)>; -// group_key_str -> (column_name, value_as_hashable_row) -> accumulated_weight -pub type MinMaxDeltas = HashMap>; +// group_key_str -> (column_index, value_as_hashable_row) -> accumulated_weight +pub type MinMaxDeltas = HashMap>; #[derive(Debug)] enum AggregateCommitState { @@ -178,14 +178,14 @@ pub enum AggregateEvalState { pub struct AggregateOperator { // Unique operator ID for indexing in persistent storage pub operator_id: usize, - // GROUP BY columns - group_by: Vec, + // GROUP BY column indices + group_by: Vec, // Aggregate functions to compute (including MIN/MAX) pub aggregates: Vec, // Column names from input pub input_column_names: Vec, - // Map from column name to aggregate info for quick lookup - pub column_min_max: HashMap, + // Map from column index to aggregate info for quick lookup + pub column_min_max: HashMap, tracker: Option>>, // State machine for commit operation @@ -197,14 +197,14 @@ pub struct AggregateOperator { pub struct AggregateState { // For COUNT: just the 
count pub count: i64, - // For SUM: column_name -> sum value - sums: HashMap, - // For AVG: column_name -> (sum, count) for computing average - avgs: HashMap, - // For MIN: column_name -> minimum value - pub mins: HashMap, - // For MAX: column_name -> maximum value - pub maxs: HashMap, + // For SUM: column_index -> sum value + sums: HashMap, + // For AVG: column_index -> (sum, count) for computing average + avgs: HashMap, + // For MIN: column_index -> minimum value + pub mins: HashMap, + // For MAX: column_index -> maximum value + pub maxs: HashMap, } impl AggregateEvalState { @@ -520,14 +520,14 @@ impl AggregateState { AggregateFunction::Sum(col_name) => { let sum = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); cursor += 8; - state.sums.insert(col_name.clone(), sum); + state.sums.insert(*col_name, sum); } AggregateFunction::Avg(col_name) => { let sum = f64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); cursor += 8; let count = i64::from_le_bytes(blob.get(cursor..cursor + 8)?.try_into().ok()?); cursor += 8; - state.avgs.insert(col_name.clone(), (sum, count)); + state.avgs.insert(*col_name, (sum, count)); } AggregateFunction::Count => { // Count was already read above @@ -540,7 +540,7 @@ impl AggregateState { if has_value == 1 { let (min_value, bytes_consumed) = deserialize_value(&blob[cursor..])?; cursor += bytes_consumed; - state.mins.insert(col_name.clone(), min_value); + state.mins.insert(*col_name, min_value); } } AggregateFunction::Max(col_name) => { @@ -551,7 +551,7 @@ impl AggregateState { if has_value == 1 { let (max_value, bytes_consumed) = deserialize_value(&blob[cursor..])?; cursor += bytes_consumed; - state.maxs.insert(col_name.clone(), max_value); + state.maxs.insert(*col_name, max_value); } } } @@ -566,7 +566,7 @@ impl AggregateState { values: &[Value], weight: isize, aggregates: &[AggregateFunction], - column_names: &[String], + _column_names: &[String], // No longer needed ) { // Update COUNT self.count += weight as i64; @@ -577,32 +577,26 @@ impl AggregateState { AggregateFunction::Count => { // Already handled above } - AggregateFunction::Sum(col_name) => { - if let Some(idx) = column_names.iter().position(|c| c == col_name) { - if let Some(val) = values.get(idx) { - let num_val = match val { - Value::Integer(i) => *i as f64, - Value::Float(f) => *f, - _ => 0.0, - }; - *self.sums.entry(col_name.clone()).or_insert(0.0) += - num_val * weight as f64; - } + AggregateFunction::Sum(col_idx) => { + if let Some(val) = values.get(*col_idx) { + let num_val = match val { + Value::Integer(i) => *i as f64, + Value::Float(f) => *f, + _ => 0.0, + }; + *self.sums.entry(*col_idx).or_insert(0.0) += num_val * weight as f64; } } - AggregateFunction::Avg(col_name) => { - if let Some(idx) = column_names.iter().position(|c| c == col_name) { - if let Some(val) = values.get(idx) { - let num_val = match val { - Value::Integer(i) => *i as f64, - Value::Float(f) => *f, - _ => 0.0, - }; - let (sum, count) = - self.avgs.entry(col_name.clone()).or_insert((0.0, 0)); - *sum += num_val * weight as f64; - *count += weight as i64; - } + AggregateFunction::Avg(col_idx) => { + if let Some(val) = values.get(*col_idx) { + let num_val = match val { + Value::Integer(i) => *i as f64, + Value::Float(f) => *f, + _ => 0.0, + }; + let (sum, count) = self.avgs.entry(*col_idx).or_insert((0.0, 0)); + *sum += num_val * weight as f64; + *count += weight as i64; } } AggregateFunction::Min(_col_name) | AggregateFunction::Max(_col_name) => { @@ -644,8 +638,8 @@ impl AggregateState { 
AggregateFunction::Count => { result.push(Value::Integer(self.count)); } - AggregateFunction::Sum(col_name) => { - let sum = self.sums.get(col_name).copied().unwrap_or(0.0); + AggregateFunction::Sum(col_idx) => { + let sum = self.sums.get(col_idx).copied().unwrap_or(0.0); // Return as integer if it's a whole number, otherwise as float if sum.fract() == 0.0 { result.push(Value::Integer(sum as i64)); @@ -653,8 +647,8 @@ impl AggregateState { result.push(Value::Float(sum)); } } - AggregateFunction::Avg(col_name) => { - if let Some((sum, count)) = self.avgs.get(col_name) { + AggregateFunction::Avg(col_idx) => { + if let Some((sum, count)) = self.avgs.get(col_idx) { if *count > 0 { result.push(Value::Float(sum / *count as f64)); } else { @@ -664,13 +658,13 @@ impl AggregateState { result.push(Value::Null); } } - AggregateFunction::Min(col_name) => { + AggregateFunction::Min(col_idx) => { // Return the MIN value from our state - result.push(self.mins.get(col_name).cloned().unwrap_or(Value::Null)); + result.push(self.mins.get(col_idx).cloned().unwrap_or(Value::Null)); } - AggregateFunction::Max(col_name) => { + AggregateFunction::Max(col_idx) => { // Return the MAX value from our state - result.push(self.maxs.get(col_name).cloned().unwrap_or(Value::Null)); + result.push(self.maxs.get(col_idx).cloned().unwrap_or(Value::Null)); } } } @@ -682,20 +676,20 @@ impl AggregateState { impl AggregateOperator { pub fn new( operator_id: usize, - group_by: Vec, + group_by: Vec, aggregates: Vec, input_column_names: Vec, ) -> Self { - // Build map of column names to their MIN/MAX info with indices + // Build map of column indices to their MIN/MAX info let mut column_min_max = HashMap::new(); - let mut column_indices = HashMap::new(); + let mut storage_indices = HashMap::new(); let mut current_index = 0; - // First pass: assign indices to unique MIN/MAX columns + // First pass: assign storage indices to unique MIN/MAX columns for agg in &aggregates { match agg { - AggregateFunction::Min(col) | AggregateFunction::Max(col) => { - column_indices.entry(col.clone()).or_insert_with(|| { + AggregateFunction::Min(col_idx) | AggregateFunction::Max(col_idx) => { + storage_indices.entry(*col_idx).or_insert_with(|| { let idx = current_index; current_index += 1; idx @@ -708,19 +702,19 @@ impl AggregateOperator { // Second pass: build the column info map for agg in &aggregates { match agg { - AggregateFunction::Min(col) => { - let index = *column_indices.get(col).unwrap(); - let entry = column_min_max.entry(col.clone()).or_insert(AggColumnInfo { - index, + AggregateFunction::Min(col_idx) => { + let storage_index = *storage_indices.get(col_idx).unwrap(); + let entry = column_min_max.entry(*col_idx).or_insert(AggColumnInfo { + index: storage_index, has_min: false, has_max: false, }); entry.has_min = true; } - AggregateFunction::Max(col) => { - let index = *column_indices.get(col).unwrap(); - let entry = column_min_max.entry(col.clone()).or_insert(AggColumnInfo { - index, + AggregateFunction::Max(col_idx) => { + let storage_index = *storage_indices.get(col_idx).unwrap(); + let entry = column_min_max.entry(*col_idx).or_insert(AggColumnInfo { + index: storage_index, has_min: false, has_max: false, }); @@ -876,28 +870,24 @@ impl AggregateOperator { for agg in &self.aggregates { match agg { - AggregateFunction::Min(col_name) | AggregateFunction::Max(col_name) => { - if let Some(idx) = - self.input_column_names.iter().position(|c| c == col_name) - { - if let Some(val) = row.values.get(idx) { - // Skip NULL values - they don't 
participate in MIN/MAX - if val == &Value::Null { - continue; - } - // Create a HashableRow with just this value - // Use 0 as rowid since we only care about the value for comparison - let hashable_value = HashableRow::new(0, vec![val.clone()]); - let key = (col_name.clone(), hashable_value); - - let group_entry = - min_max_deltas.entry(group_key_str.clone()).or_default(); - - let value_entry = group_entry.entry(key).or_insert(0); - - // Accumulate the weight - *value_entry += weight; + AggregateFunction::Min(col_idx) | AggregateFunction::Max(col_idx) => { + if let Some(val) = row.values.get(*col_idx) { + // Skip NULL values - they don't participate in MIN/MAX + if val == &Value::Null { + continue; } + // Create a HashableRow with just this value + // Use 0 as rowid since we only care about the value for comparison + let hashable_value = HashableRow::new(0, vec![val.clone()]); + let key = (*col_idx, hashable_value); + + let group_entry = + min_max_deltas.entry(group_key_str.clone()).or_default(); + + let value_entry = group_entry.entry(key).or_insert(0); + + // Accumulate the weight + *value_entry += weight; } } _ => {} // Ignore non-MIN/MAX aggregates @@ -929,13 +919,9 @@ impl AggregateOperator { pub fn extract_group_key(&self, values: &[Value]) -> Vec { let mut key = Vec::new(); - for group_col in &self.group_by { - if let Some(idx) = self.input_column_names.iter().position(|c| c == group_col) { - if let Some(val) = values.get(idx) { - key.push(val.clone()); - } else { - key.push(Value::Null); - } + for &idx in &self.group_by { + if let Some(val) = values.get(idx) { + key.push(val.clone()); } else { key.push(Value::Null); } @@ -1124,13 +1110,13 @@ pub enum RecomputeMinMax { /// Current column being processed current_column_idx: usize, /// Columns to process (combined MIN and MAX) - columns_to_process: Vec<(String, String, bool)>, // (group_key, column_name, is_min) + columns_to_process: Vec<(String, usize, bool)>, // (group_key, column_name, is_min) /// MIN/MAX deltas for checking values and weights min_max_deltas: MinMaxDeltas, }, Scan { /// Columns still to process - columns_to_process: Vec<(String, String, bool)>, + columns_to_process: Vec<(String, usize, bool)>, /// Current index in columns_to_process (will resume from here) current_column_idx: usize, /// MIN/MAX deltas for checking values and weights @@ -1138,7 +1124,7 @@ pub enum RecomputeMinMax { /// Current group key being processed group_key: String, /// Current column name being processed - column_name: String, + column_name: usize, /// Whether we're looking for MIN (true) or MAX (false) is_min: bool, /// The scan state machine for finding the new MIN/MAX @@ -1153,7 +1139,7 @@ impl RecomputeMinMax { existing_groups: &HashMap, operator: &AggregateOperator, ) -> Self { - let mut groups_to_check: HashSet<(String, String, bool)> = HashSet::new(); + let mut groups_to_check: HashSet<(String, usize, bool)> = HashSet::new(); // Remember the min_max_deltas are essentially just the only column that is affected by // this min/max, in delta (actually ZSet - consolidated delta) format. 
This makes it easier @@ -1173,21 +1159,13 @@ impl RecomputeMinMax { // Check for MIN if let Some(current_min) = state.mins.get(col_name) { if current_min == value { - groups_to_check.insert(( - group_key_str.clone(), - col_name.clone(), - true, - )); + groups_to_check.insert((group_key_str.clone(), *col_name, true)); } } // Check for MAX if let Some(current_max) = state.maxs.get(col_name) { if current_max == value { - groups_to_check.insert(( - group_key_str.clone(), - col_name.clone(), - false, - )); + groups_to_check.insert((group_key_str.clone(), *col_name, false)); } } } @@ -1196,14 +1174,10 @@ impl RecomputeMinMax { // about this if this is a new record being inserted if let Some(info) = col_info { if info.has_min { - groups_to_check.insert((group_key_str.clone(), col_name.clone(), true)); + groups_to_check.insert((group_key_str.clone(), *col_name, true)); } if info.has_max { - groups_to_check.insert(( - group_key_str.clone(), - col_name.clone(), - false, - )); + groups_to_check.insert((group_key_str.clone(), *col_name, false)); } } } @@ -1245,12 +1219,13 @@ impl RecomputeMinMax { let (group_key, column_name, is_min) = columns_to_process[*current_column_idx].clone(); - // Get column index from pre-computed info - let column_index = operator + // Column name is already the index + // Get the storage index from column_min_max map + let column_info = operator .column_min_max .get(&column_name) - .map(|info| info.index) - .unwrap(); // Should always exist since we're processing known columns + .expect("Column should exist in column_min_max map"); + let storage_index = column_info.index; // Get current value from existing state let current_value = existing_groups.get(&group_key).and_then(|state| { @@ -1263,7 +1238,7 @@ impl RecomputeMinMax { // Create storage keys for index lookup let storage_id = - generate_storage_id(operator.operator_id, column_index, AGG_TYPE_MINMAX); + generate_storage_id(operator.operator_id, storage_index, AGG_TYPE_MINMAX); let zset_id = operator.generate_group_rowid(&group_key); // Get the values for this group from min_max_deltas @@ -1276,7 +1251,7 @@ impl RecomputeMinMax { Box::new(ScanState::new_for_min( current_value, group_key.clone(), - column_name.clone(), + column_name, storage_id, zset_id, group_values, @@ -1285,7 +1260,7 @@ impl RecomputeMinMax { Box::new(ScanState::new_for_max( current_value, group_key.clone(), - column_name.clone(), + column_name, storage_id, zset_id, group_values, @@ -1319,12 +1294,12 @@ impl RecomputeMinMax { if *is_min { if let Some(min_val) = new_value { - state.mins.insert(column_name.clone(), min_val); + state.mins.insert(*column_name, min_val); } else { state.mins.remove(column_name); } } else if let Some(max_val) = new_value { - state.maxs.insert(column_name.clone(), max_val); + state.maxs.insert(*column_name, max_val); } else { state.maxs.remove(column_name); } @@ -1355,13 +1330,13 @@ pub enum ScanState { /// Group key being processed group_key: String, /// Column name being processed - column_name: String, + column_name: usize, /// Storage ID for the index seek storage_id: i64, /// ZSet ID for the group zset_id: i64, /// Group values from MinMaxDeltas: (column_name, HashableRow) -> weight - group_values: HashMap<(String, HashableRow), isize>, + group_values: HashMap<(usize, HashableRow), isize>, /// Whether we're looking for MIN (true) or MAX (false) is_min: bool, }, @@ -1371,13 +1346,13 @@ pub enum ScanState { /// Group key being processed group_key: String, /// Column name being processed - column_name: String, + 
column_name: usize, /// Storage ID for the index seek storage_id: i64, /// ZSet ID for the group zset_id: i64, /// Group values from MinMaxDeltas: (column_name, HashableRow) -> weight - group_values: HashMap<(String, HashableRow), isize>, + group_values: HashMap<(usize, HashableRow), isize>, /// Whether we're looking for MIN (true) or MAX (false) is_min: bool, }, @@ -1391,10 +1366,10 @@ impl ScanState { pub fn new_for_min( current_min: Option, group_key: String, - column_name: String, + column_name: usize, storage_id: i64, zset_id: i64, - group_values: HashMap<(String, HashableRow), isize>, + group_values: HashMap<(usize, HashableRow), isize>, ) -> Self { Self::CheckCandidate { candidate: current_min, @@ -1460,10 +1435,10 @@ impl ScanState { pub fn new_for_max( current_max: Option, group_key: String, - column_name: String, + column_name: usize, storage_id: i64, zset_id: i64, - group_values: HashMap<(String, HashableRow), isize>, + group_values: HashMap<(usize, HashableRow), isize>, ) -> Self { Self::CheckCandidate { candidate: current_max, @@ -1496,7 +1471,7 @@ impl ScanState { // Check if the candidate is retracted (weight <= 0) // Create a HashableRow to look up the weight let hashable_cand = HashableRow::new(0, vec![cand_val.clone()]); - let key = (column_name.clone(), hashable_cand); + let key = (*column_name, hashable_cand); let is_retracted = group_values.get(&key).is_some_and(|weight| *weight <= 0); @@ -1633,7 +1608,7 @@ pub enum MinMaxPersistState { group_idx: usize, value_idx: usize, value: Value, - column_name: String, + column_name: usize, weight: isize, write_row: WriteRow, }, @@ -1652,7 +1627,7 @@ impl MinMaxPersistState { pub fn persist_min_max( &mut self, operator_id: usize, - column_min_max: &HashMap, + column_min_max: &HashMap, cursors: &mut DbspStateCursors, generate_group_rowid: impl Fn(&str) -> i64, ) -> Result> { @@ -1699,7 +1674,7 @@ impl MinMaxPersistState { // Process current value and extract what we need before taking ownership let ((column_name, hashable_row), weight) = values_vec[*value_idx]; - let column_name = column_name.clone(); + let column_name = *column_name; let value = hashable_row.values[0].clone(); // Extract the Value from HashableRow let weight = *weight; @@ -1731,9 +1706,9 @@ impl MinMaxPersistState { let group_key_str = &group_keys[*group_idx]; - // Get the column index from the pre-computed map + // Get the column info from the pre-computed map let column_info = column_min_max - .get(&*column_name) + .get(column_name) .expect("Column should exist in column_min_max map"); let column_index = column_info.index; diff --git a/core/incremental/compiler.rs b/core/incremental/compiler.rs index 972d6797b..c8899a02e 100644 --- a/core/incremental/compiler.rs +++ b/core/incremental/compiler.rs @@ -9,12 +9,12 @@ use crate::incremental::dbsp::{Delta, DeltaPair}; use crate::incremental::expr_compiler::CompiledExpression; use crate::incremental::operator::{ create_dbsp_state_index, DbspStateCursors, EvalState, FilterOperator, FilterPredicate, - IncrementalOperator, InputOperator, ProjectOperator, + IncrementalOperator, InputOperator, JoinOperator, JoinType, ProjectOperator, }; use crate::storage::btree::{BTreeCursor, BTreeKey}; // Note: logical module must be made pub(crate) in translate/mod.rs use crate::translate::logical::{ - BinaryOperator, LogicalExpr, LogicalPlan, LogicalSchema, SchemaRef, + BinaryOperator, JoinType as LogicalJoinType, LogicalExpr, LogicalPlan, LogicalSchema, SchemaRef, }; use crate::types::{IOResult, ImmutableRecord, SeekKey, SeekOp, 
SeekResult, Value}; use crate::Pager; @@ -288,6 +288,12 @@ pub enum DbspOperator { aggr_exprs: Vec, schema: SchemaRef, }, + /// Join operator (⋈) - joins two relations + Join { + join_type: JoinType, + on_exprs: Vec<(DbspExpr, DbspExpr)>, + schema: SchemaRef, + }, /// Input operator - source of data Input { name: String, schema: SchemaRef }, } @@ -789,6 +795,13 @@ impl DbspCircuit { "{indent}Aggregate[{node_id}]: GROUP BY {group_exprs:?}, AGGR {aggr_exprs:?}" )?; } + DbspOperator::Join { + join_type, + on_exprs, + .. + } => { + writeln!(f, "{indent}Join[{node_id}]: {join_type:?} ON {on_exprs:?}")?; + } DbspOperator::Input { name, .. } => { writeln!(f, "{indent}Input[{node_id}]: {name}")?; } @@ -841,7 +854,7 @@ impl DbspCompiler { // Get input column names for the ProjectOperator let input_schema = proj.input.schema(); let input_column_names: Vec = input_schema.columns.iter() - .map(|(name, _)| name.clone()) + .map(|col| col.name.clone()) .collect(); // Convert logical expressions to DBSP expressions @@ -853,14 +866,14 @@ impl DbspCompiler { let mut compiled_exprs = Vec::new(); let mut aliases = Vec::new(); for expr in &proj.exprs { - let (compiled, alias) = Self::compile_expression(expr, &input_column_names)?; + let (compiled, alias) = Self::compile_expression(expr, input_schema)?; compiled_exprs.push(compiled); aliases.push(alias); } // Get output column names from the projection schema let output_column_names: Vec = proj.schema.columns.iter() - .map(|(name, _)| name.clone()) + .map(|col| col.name.clone()) .collect(); // Create the ProjectOperator @@ -885,7 +898,7 @@ impl DbspCompiler { // Get column names from input schema let input_schema = filter.input.schema(); let column_names: Vec = input_schema.columns.iter() - .map(|(name, _)| name.clone()) + .map(|col| col.name.clone()) .collect(); // Convert predicate to DBSP expression @@ -913,16 +926,21 @@ impl DbspCompiler { // Get input column names let input_schema = agg.input.schema(); let input_column_names: Vec = input_schema.columns.iter() - .map(|(name, _)| name.clone()) + .map(|col| col.name.clone()) .collect(); - // Compile group by expressions to column names - let mut group_by_columns = Vec::new(); + // Compile group by expressions to column indices + let mut group_by_indices = Vec::new(); let mut dbsp_group_exprs = Vec::new(); for expr in &agg.group_expr { // For now, only support simple column references in GROUP BY if let LogicalExpr::Column(col) = expr { - group_by_columns.push(col.name.clone()); + // Find the column index in the input schema using qualified lookup + let (col_idx, _) = input_schema.find_column(&col.name, col.table.as_deref()) + .ok_or_else(|| LimboError::ParseError( + format!("GROUP BY column '{}' not found in input", col.name) + ))?; + group_by_indices.push(col_idx); dbsp_group_exprs.push(DbspExpr::Column(col.name.clone())); } else { return Err(LimboError::ParseError( @@ -936,7 +954,7 @@ impl DbspCompiler { for expr in &agg.aggr_expr { if let LogicalExpr::AggregateFunction { fun, args, .. 
} = expr { use crate::function::AggFunc; - use crate::incremental::operator::AggregateFunction; + use crate::incremental::aggregate_operator::AggregateFunction; match fun { AggFunc::Count | AggFunc::Count0 => { @@ -946,9 +964,13 @@ impl DbspCompiler { if args.is_empty() { return Err(LimboError::ParseError("SUM requires an argument".to_string())); } - // Extract column name from the argument + // Extract column index from the argument if let LogicalExpr::Column(col) = &args[0] { - aggregate_functions.push(AggregateFunction::Sum(col.name.clone())); + let (col_idx, _) = input_schema.find_column(&col.name, col.table.as_deref()) + .ok_or_else(|| LimboError::ParseError( + format!("SUM column '{}' not found in input", col.name) + ))?; + aggregate_functions.push(AggregateFunction::Sum(col_idx)); } else { return Err(LimboError::ParseError( "Only column references are supported in aggregate functions for incremental views".to_string() @@ -960,7 +982,11 @@ impl DbspCompiler { return Err(LimboError::ParseError("AVG requires an argument".to_string())); } if let LogicalExpr::Column(col) = &args[0] { - aggregate_functions.push(AggregateFunction::Avg(col.name.clone())); + let (col_idx, _) = input_schema.find_column(&col.name, col.table.as_deref()) + .ok_or_else(|| LimboError::ParseError( + format!("AVG column '{}' not found in input", col.name) + ))?; + aggregate_functions.push(AggregateFunction::Avg(col_idx)); } else { return Err(LimboError::ParseError( "Only column references are supported in aggregate functions for incremental views".to_string() @@ -972,7 +998,11 @@ impl DbspCompiler { return Err(LimboError::ParseError("MIN requires an argument".to_string())); } if let LogicalExpr::Column(col) = &args[0] { - aggregate_functions.push(AggregateFunction::Min(col.name.clone())); + let (col_idx, _) = input_schema.find_column(&col.name, col.table.as_deref()) + .ok_or_else(|| LimboError::ParseError( + format!("MIN column '{}' not found in input", col.name) + ))?; + aggregate_functions.push(AggregateFunction::Min(col_idx)); } else { return Err(LimboError::ParseError( "Only column references are supported in MIN for incremental views".to_string() @@ -984,7 +1014,11 @@ impl DbspCompiler { return Err(LimboError::ParseError("MAX requires an argument".to_string())); } if let LogicalExpr::Column(col) = &args[0] { - aggregate_functions.push(AggregateFunction::Max(col.name.clone())); + let (col_idx, _) = input_schema.find_column(&col.name, col.table.as_deref()) + .ok_or_else(|| LimboError::ParseError( + format!("MAX column '{}' not found in input", col.name) + ))?; + aggregate_functions.push(AggregateFunction::Max(col_idx)); } else { return Err(LimboError::ParseError( "Only column references are supported in MAX for incremental views".to_string() @@ -1006,10 +1040,10 @@ impl DbspCompiler { let operator_id = self.circuit.next_id; - use crate::incremental::operator::AggregateOperator; + use crate::incremental::aggregate_operator::AggregateOperator; let executable: Box = Box::new(AggregateOperator::new( operator_id, - group_by_columns.clone(), + group_by_indices.clone(), aggregate_functions.clone(), input_column_names.clone(), )); @@ -1026,6 +1060,90 @@ impl DbspCompiler { Ok(result_node_id) } + LogicalPlan::Join(join) => { + // Compile left and right inputs + let left_id = self.compile_plan(&join.left)?; + let right_id = self.compile_plan(&join.right)?; + + // Get schemas from inputs + let left_schema = join.left.schema(); + let right_schema = join.right.schema(); + + // Get column names from left and right + let 
left_columns: Vec = left_schema.columns.iter() + .map(|col| col.name.clone()) + .collect(); + let right_columns: Vec = right_schema.columns.iter() + .map(|col| col.name.clone()) + .collect(); + + // Extract join key indices from join conditions + // For now, we only support equijoin conditions + let mut left_key_indices = Vec::new(); + let mut right_key_indices = Vec::new(); + let mut dbsp_on_exprs = Vec::new(); + + for (left_expr, right_expr) in &join.on { + // Extract column indices from join expressions + // We expect simple column references in join conditions + if let (LogicalExpr::Column(left_col), LogicalExpr::Column(right_col)) = (left_expr, right_expr) { + // Find indices in respective schemas using qualified lookup + let (left_idx, _) = left_schema.find_column(&left_col.name, left_col.table.as_deref()) + .ok_or_else(|| LimboError::ParseError( + format!("Join column '{}' not found in left input", left_col.name) + ))?; + let (right_idx, _) = right_schema.find_column(&right_col.name, right_col.table.as_deref()) + .ok_or_else(|| LimboError::ParseError( + format!("Join column '{}' not found in right input", right_col.name) + ))?; + + left_key_indices.push(left_idx); + right_key_indices.push(right_idx); + + // Convert to DBSP expressions + dbsp_on_exprs.push(( + DbspExpr::Column(left_col.name.clone()), + DbspExpr::Column(right_col.name.clone()) + )); + } else { + return Err(LimboError::ParseError( + "Only simple column references are supported in join conditions for incremental views".to_string() + )); + } + } + + // Convert logical join type to operator join type + let operator_join_type = match join.join_type { + LogicalJoinType::Inner => JoinType::Inner, + LogicalJoinType::Left => JoinType::Left, + LogicalJoinType::Right => JoinType::Right, + LogicalJoinType::Full => JoinType::Full, + LogicalJoinType::Cross => JoinType::Cross, + }; + + // Create JoinOperator + let operator_id = self.circuit.next_id; + let executable: Box = Box::new(JoinOperator::new( + operator_id, + operator_join_type.clone(), + left_key_indices, + right_key_indices, + left_columns, + right_columns, + )?); + + // Create join node + let node_id = self.circuit.add_node( + DbspOperator::Join { + join_type: operator_join_type, + on_exprs: dbsp_on_exprs, + schema: join.schema.clone(), + }, + vec![left_id, right_id], + executable, + ); + Ok(node_id) + } LogicalPlan::TableScan(scan) => { // Create input node with InputOperator for uniform handling let executable: Box = @@ -1042,7 +1160,7 @@ impl DbspCompiler { Ok(node_id) } _ => Err(LimboError::ParseError( - format!("Unsupported operator in DBSP compiler: only Filter, Projection and Aggregate are supported, got: {:?}", + format!("Unsupported operator in DBSP compiler: only Filter, Projection, Join and Aggregate are supported, got: {:?}", match plan { LogicalPlan::Sort(_) => "Sort", LogicalPlan::Limit(_) => "Limit", @@ -1095,17 +1213,24 @@ impl DbspCompiler { /// Compile a logical expression to a CompiledExpression and optional alias fn compile_expression( expr: &LogicalExpr, - input_column_names: &[String], + input_schema: &LogicalSchema, ) -> Result<(CompiledExpression, Option)> { // Check for alias first if let LogicalExpr::Alias { expr, alias } = expr { // For aliases, compile the underlying expression and return with alias - let (compiled, _) = Self::compile_expression(expr, input_column_names)?; + let (compiled, _) = Self::compile_expression(expr, input_schema)?; return Ok((compiled, Some(alias.clone()))); } - // Convert LogicalExpr to AST Expr - let ast_expr = 
Self::logical_to_ast_expr(expr)?; + // Convert LogicalExpr to AST Expr with proper column resolution + let ast_expr = Self::logical_to_ast_expr_with_schema(expr, input_schema)?; + + // Extract column names from schema for CompiledExpression::compile + let input_column_names: Vec = input_schema + .columns + .iter() + .map(|col| col.name.clone()) + .collect(); // For all expressions (simple or complex), use CompiledExpression::compile // This handles both trivial cases and complex VDBE compilation @@ -1129,7 +1254,7 @@ impl DbspCompiler { // Compile the expression using the existing CompiledExpression::compile let compiled = CompiledExpression::compile( &ast_expr, - input_column_names, + &input_column_names, &schema, &temp_syms, internal_conn, @@ -1138,12 +1263,27 @@ impl DbspCompiler { Ok((compiled, None)) } - /// Convert LogicalExpr to AST Expr - fn logical_to_ast_expr(expr: &LogicalExpr) -> Result { + /// Convert LogicalExpr to AST Expr with qualified column resolution + fn logical_to_ast_expr_with_schema( + expr: &LogicalExpr, + schema: &LogicalSchema, + ) -> Result { use turso_parser::ast; match expr { - LogicalExpr::Column(col) => Ok(ast::Expr::Id(ast::Name::Ident(col.name.clone()))), + LogicalExpr::Column(col) => { + // Find the column index using qualified lookup + let (idx, _) = schema + .find_column(&col.name, col.table.as_deref()) + .ok_or_else(|| { + LimboError::ParseError(format!( + "Column '{}' with table {:?} not found in schema", + col.name, col.table + )) + })?; + // Return a Register expression with the correct index + Ok(ast::Expr::Register(idx)) + } LogicalExpr::Literal(val) => { let lit = match val { Value::Integer(i) => ast::Literal::Numeric(i.to_string()), @@ -1155,8 +1295,8 @@ impl DbspCompiler { Ok(ast::Expr::Literal(lit)) } LogicalExpr::BinaryExpr { left, op, right } => { - let left_expr = Self::logical_to_ast_expr(left)?; - let right_expr = Self::logical_to_ast_expr(right)?; + let left_expr = Self::logical_to_ast_expr_with_schema(left, schema)?; + let right_expr = Self::logical_to_ast_expr_with_schema(right, schema)?; Ok(ast::Expr::Binary( Box::new(left_expr), *op, @@ -1164,7 +1304,10 @@ impl DbspCompiler { )) } LogicalExpr::ScalarFunction { fun, args } => { - let ast_args: Result> = args.iter().map(Self::logical_to_ast_expr).collect(); + let ast_args: Result> = args + .iter() + .map(|arg| Self::logical_to_ast_expr_with_schema(arg, schema)) + .collect(); let ast_args: Vec> = ast_args?.into_iter().map(Box::new).collect(); Ok(ast::Expr::FunctionCall { name: ast::Name::Ident(fun.clone()), @@ -1179,7 +1322,7 @@ impl DbspCompiler { } LogicalExpr::Alias { expr, .. 
} => { // For conversion to AST, ignore the alias and convert the inner expression - Self::logical_to_ast_expr(expr) + Self::logical_to_ast_expr_with_schema(expr, schema) } LogicalExpr::AggregateFunction { fun, @@ -1187,7 +1330,10 @@ impl DbspCompiler { distinct, } => { // Convert aggregate function to AST - let ast_args: Result> = args.iter().map(Self::logical_to_ast_expr).collect(); + let ast_args: Result> = args + .iter() + .map(|arg| Self::logical_to_ast_expr_with_schema(arg, schema)) + .collect(); let ast_args: Vec> = ast_args?.into_iter().map(Box::new).collect(); // Get the function name based on the aggregate type @@ -1315,8 +1461,7 @@ mod tests { use crate::incremental::operator::{FilterOperator, FilterPredicate}; use crate::schema::{BTreeTable, Column as SchemaColumn, Schema, Type}; use crate::storage::pager::CreateBTreeFlags; - use crate::translate::logical::LogicalPlanBuilder; - use crate::translate::logical::LogicalSchema; + use crate::translate::logical::{ColumnInfo, LogicalPlanBuilder, LogicalSchema}; use crate::util::IOExt; use crate::{Database, MemoryIO, Pager, IO}; use std::sync::Arc; @@ -1374,6 +1519,270 @@ mod tests { unique_sets: vec![], }; schema.add_btree_table(Arc::new(users_table)); + + // Add products table for join tests + let products_table = BTreeTable { + name: "products".to_string(), + root_page: 3, + primary_key_columns: vec![( + "product_id".to_string(), + turso_parser::ast::SortOrder::Asc, + )], + columns: vec![ + SchemaColumn { + name: Some("product_id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: true, + is_rowid_alias: true, + notnull: true, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("product_name".to_string()), + ty: Type::Text, + ty_str: "TEXT".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("price".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + ], + has_rowid: true, + is_strict: false, + unique_sets: vec![], + }; + schema.add_btree_table(Arc::new(products_table)); + + // Add orders table for join tests + let orders_table = BTreeTable { + name: "orders".to_string(), + root_page: 4, + primary_key_columns: vec![( + "order_id".to_string(), + turso_parser::ast::SortOrder::Asc, + )], + columns: vec![ + SchemaColumn { + name: Some("order_id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: true, + is_rowid_alias: true, + notnull: true, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("user_id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("product_id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("quantity".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: 
false, + }, + ], + has_rowid: true, + is_strict: false, + unique_sets: vec![], + }; + schema.add_btree_table(Arc::new(orders_table)); + + // Add customers table with id and name for testing column ambiguity + let customers_table = BTreeTable { + name: "customers".to_string(), + root_page: 6, + primary_key_columns: vec![("id".to_string(), turso_parser::ast::SortOrder::Asc)], + columns: vec![ + SchemaColumn { + name: Some("id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: true, + is_rowid_alias: true, + notnull: true, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("name".to_string()), + ty: Type::Text, + ty_str: "TEXT".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + ], + has_rowid: true, + is_strict: false, + unique_sets: vec![], + }; + schema.add_btree_table(Arc::new(customers_table)); + + // Add purchases table (junction table for three-way join) + let purchases_table = BTreeTable { + name: "purchases".to_string(), + root_page: 7, + primary_key_columns: vec![("id".to_string(), turso_parser::ast::SortOrder::Asc)], + columns: vec![ + SchemaColumn { + name: Some("id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: true, + is_rowid_alias: true, + notnull: true, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("customer_id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("vendor_id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("quantity".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + ], + has_rowid: true, + is_strict: false, + unique_sets: vec![], + }; + schema.add_btree_table(Arc::new(purchases_table)); + + // Add vendors table with id, name, and price (ambiguous columns with customers) + let vendors_table = BTreeTable { + name: "vendors".to_string(), + root_page: 8, + primary_key_columns: vec![("id".to_string(), turso_parser::ast::SortOrder::Asc)], + columns: vec![ + SchemaColumn { + name: Some("id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: true, + is_rowid_alias: true, + notnull: true, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("name".to_string()), + ty: Type::Text, + ty_str: "TEXT".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("price".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + ], + has_rowid: true, + is_strict: false, + unique_sets: vec![], + }; + schema.add_btree_table(Arc::new(vendors_table)); + let sales_table = BTreeTable { name: "sales".to_string(), root_page: 2, 
@@ -3342,8 +3751,20 @@ mod tests { // Create a simple filter node let schema = Arc::new(LogicalSchema::new(vec![ - ("id".to_string(), Type::Integer), - ("value".to_string(), Type::Integer), + ColumnInfo { + name: "id".to_string(), + ty: Type::Integer, + database: None, + table: None, + table_alias: None, + }, + ColumnInfo { + name: "value".to_string(), + ty: Type::Integer, + database: None, + table: None, + table_alias: None, + }, ])); // First create an input node with InputOperator @@ -3486,4 +3907,767 @@ mod tests { "Row should still exist with multiplicity 1" ); } + + #[test] + fn test_join_with_aggregation() { + // Test join followed by aggregation - verifying actual output + let (mut circuit, pager) = compile_sql!( + "SELECT u.name, SUM(o.quantity) as total_quantity + FROM users u + JOIN orders o ON u.id = o.user_id + GROUP BY u.name" + ); + + // Create test data for users + let mut users_delta = Delta::new(); + users_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".into()), + Value::Integer(30), + ], + ); + users_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Bob".into()), + Value::Integer(25), + ], + ); + + // Create test data for orders (order_id, user_id, product_id, quantity) + let mut orders_delta = Delta::new(); + orders_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Integer(1), + Value::Integer(101), + Value::Integer(5), + ], + ); // Alice: 5 + orders_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Integer(1), + Value::Integer(102), + Value::Integer(3), + ], + ); // Alice: 3 + orders_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Integer(2), + Value::Integer(101), + Value::Integer(7), + ], + ); // Bob: 7 + orders_delta.insert( + 4, + vec![ + Value::Integer(4), + Value::Integer(1), + Value::Integer(103), + Value::Integer(2), + ], + ); // Alice: 2 + + let inputs = HashMap::from([ + ("users".to_string(), users_delta), + ("orders".to_string(), orders_delta), + ]); + + let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap(); + + // Should have 2 results: Alice with total 10, Bob with total 7 + assert_eq!( + result.len(), + 2, + "Should have aggregated results for Alice and Bob" + ); + + // Check the results + let mut results_map: HashMap = HashMap::new(); + for (row, weight) in result.changes { + assert_eq!(weight, 1); + assert_eq!(row.values.len(), 2); // name and total_quantity + + if let (Value::Text(name), Value::Integer(total)) = (&row.values[0], &row.values[1]) { + results_map.insert(name.to_string(), *total); + } else { + panic!("Unexpected value types in result"); + } + } + + assert_eq!( + results_map.get("Alice"), + Some(&10), + "Alice should have total quantity 10" + ); + assert_eq!( + results_map.get("Bob"), + Some(&7), + "Bob should have total quantity 7" + ); + } + + #[test] + fn test_join_aggregate_with_filter() { + // Test complex query with join, filter, and aggregation - verifying output + let (mut circuit, pager) = compile_sql!( + "SELECT u.name, SUM(o.quantity) as total + FROM users u + JOIN orders o ON u.id = o.user_id + WHERE u.age > 18 + GROUP BY u.name" + ); + + // Create test data for users + let mut users_delta = Delta::new(); + users_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".into()), + Value::Integer(30), + ], + ); // age > 18 + users_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Bob".into()), + Value::Integer(17), + ], + ); // age <= 18 + users_delta.insert( + 3, + vec![ + Value::Integer(3), + 
Value::Text("Charlie".into()), + Value::Integer(25), + ], + ); // age > 18 + + // Create test data for orders (order_id, user_id, product_id, quantity) + let mut orders_delta = Delta::new(); + orders_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Integer(1), + Value::Integer(101), + Value::Integer(5), + ], + ); // Alice: 5 + orders_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Integer(2), + Value::Integer(102), + Value::Integer(10), + ], + ); // Bob: 10 (should be filtered) + orders_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Integer(3), + Value::Integer(101), + Value::Integer(7), + ], + ); // Charlie: 7 + orders_delta.insert( + 4, + vec![ + Value::Integer(4), + Value::Integer(1), + Value::Integer(103), + Value::Integer(3), + ], + ); // Alice: 3 + + let inputs = HashMap::from([ + ("users".to_string(), users_delta), + ("orders".to_string(), orders_delta), + ]); + + let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap(); + + // Should only have results for Alice and Charlie (Bob filtered out due to age <= 18) + assert_eq!( + result.len(), + 2, + "Should only have results for users with age > 18" + ); + + // Check the results + let mut results_map: HashMap = HashMap::new(); + for (row, weight) in result.changes { + assert_eq!(weight, 1); + assert_eq!(row.values.len(), 2); // name and total + + if let (Value::Text(name), Value::Integer(total)) = (&row.values[0], &row.values[1]) { + results_map.insert(name.to_string(), *total); + } + } + + assert_eq!( + results_map.get("Alice"), + Some(&8), + "Alice should have total 8" + ); + assert_eq!( + results_map.get("Charlie"), + Some(&7), + "Charlie should have total 7" + ); + assert_eq!(results_map.get("Bob"), None, "Bob should be filtered out"); + } + + #[test] + fn test_three_way_join_execution() { + // Test executing a 3-way join with aggregation + let (mut circuit, pager) = compile_sql!( + "SELECT u.name, p.product_name, SUM(o.quantity) as total + FROM users u + JOIN orders o ON u.id = o.user_id + JOIN products p ON o.product_id = p.product_id + GROUP BY u.name, p.product_name" + ); + + // Create test data for users + let mut users_delta = Delta::new(); + users_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".into()), + Value::Integer(25), + ], + ); + users_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Bob".into()), + Value::Integer(30), + ], + ); + + // Create test data for products + let mut products_delta = Delta::new(); + products_delta.insert( + 100, + vec![ + Value::Integer(100), + Value::Text("Widget".into()), + Value::Integer(50), + ], + ); + products_delta.insert( + 101, + vec![ + Value::Integer(101), + Value::Text("Gadget".into()), + Value::Integer(75), + ], + ); + products_delta.insert( + 102, + vec![ + Value::Integer(102), + Value::Text("Doohickey".into()), + Value::Integer(25), + ], + ); + + // Create test data for orders joining users and products + let mut orders_delta = Delta::new(); + // Alice orders 5 Widgets + orders_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Integer(1), + Value::Integer(100), + Value::Integer(5), + ], + ); + // Alice orders 3 Gadgets + orders_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Integer(1), + Value::Integer(101), + Value::Integer(3), + ], + ); + // Bob orders 7 Widgets + orders_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Integer(2), + Value::Integer(100), + Value::Integer(7), + ], + ); + // Bob orders 2 Doohickeys + orders_delta.insert( + 4, + vec![ + Value::Integer(4), 
+ Value::Integer(2), + Value::Integer(102), + Value::Integer(2), + ], + ); + // Alice orders 4 more Widgets + orders_delta.insert( + 5, + vec![ + Value::Integer(5), + Value::Integer(1), + Value::Integer(100), + Value::Integer(4), + ], + ); + + let mut inputs = HashMap::new(); + inputs.insert("users".to_string(), users_delta); + inputs.insert("products".to_string(), products_delta); + inputs.insert("orders".to_string(), orders_delta); + + // Execute the 3-way join with aggregation + let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap(); + + // We should get aggregated results for each user-product combination + // Expected results: + // - Alice, Widget: 9 (5 + 4) + // - Alice, Gadget: 3 + // - Bob, Widget: 7 + // - Bob, Doohickey: 2 + assert_eq!(result.len(), 4, "Should have 4 aggregated results"); + + // Verify aggregation results + let mut found_results = std::collections::HashSet::new(); + for (row, weight) in result.changes.iter() { + assert_eq!(*weight, 1); + // Row should have name, product_name, and sum columns + assert_eq!(row.values.len(), 3); + + if let (Value::Text(name), Value::Text(product), Value::Integer(total)) = + (&row.values[0], &row.values[1], &row.values[2]) + { + let key = format!("{}-{}", name.as_ref(), product.as_ref()); + found_results.insert(key.clone()); + + match key.as_str() { + "Alice-Widget" => { + assert_eq!(*total, 9, "Alice should have ordered 9 Widgets total") + } + "Alice-Gadget" => assert_eq!(*total, 3, "Alice should have ordered 3 Gadgets"), + "Bob-Widget" => assert_eq!(*total, 7, "Bob should have ordered 7 Widgets"), + "Bob-Doohickey" => { + assert_eq!(*total, 2, "Bob should have ordered 2 Doohickeys") + } + _ => panic!("Unexpected result: {key}"), + } + } else { + panic!("Unexpected value types in result"); + } + } + + // Ensure we found all expected combinations + assert!(found_results.contains("Alice-Widget")); + assert!(found_results.contains("Alice-Gadget")); + assert!(found_results.contains("Bob-Widget")); + assert!(found_results.contains("Bob-Doohickey")); + } + + #[test] + fn test_join_execution() { + let (mut circuit, pager) = compile_sql!( + "SELECT u.name, o.quantity FROM users u JOIN orders o ON u.id = o.user_id" + ); + + // Create test data for users + let mut users_delta = Delta::new(); + users_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".into()), + Value::Integer(25), + ], + ); + users_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Bob".into()), + Value::Integer(30), + ], + ); + + // Create test data for orders + let mut orders_delta = Delta::new(); + orders_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Integer(1), + Value::Integer(100), + Value::Integer(5), + ], + ); + orders_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Integer(1), + Value::Integer(101), + Value::Integer(3), + ], + ); + orders_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Integer(2), + Value::Integer(102), + Value::Integer(7), + ], + ); + + let mut inputs = HashMap::new(); + inputs.insert("users".to_string(), users_delta); + inputs.insert("orders".to_string(), orders_delta); + + // Execute the join + let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap(); + + // We should get 3 results (2 orders for Alice, 1 for Bob) + assert_eq!(result.len(), 3, "Should have 3 join results"); + + // Verify the join results contain the correct data + let results: Vec<_> = result.changes.iter().collect(); + + // Check that we have the expected joined rows + for 
(row, weight) in results { + assert_eq!(*weight, 1); // All weights should be 1 for insertions + // Row should have name and quantity columns + assert_eq!(row.values.len(), 2); + } + } + + #[test] + fn test_three_way_join_with_column_ambiguity() { + // Test three-way join with aggregation where multiple tables have columns with the same name + // Ensures that column references are correctly resolved to their respective tables + // Tables: customers(id, name), purchases(id, customer_id, vendor_id, quantity), vendors(id, name, price) + // Note: both customers and vendors have 'id' and 'name' columns which can cause ambiguity + + let sql = "SELECT c.name as customer_name, v.name as vendor_name, + SUM(p.quantity) as total_quantity, + SUM(p.quantity * v.price) as total_value + FROM customers c + JOIN purchases p ON c.id = p.customer_id + JOIN vendors v ON p.vendor_id = v.id + GROUP BY c.name, v.name"; + + let (mut circuit, pager) = compile_sql!(sql); + + // Create test data for customers (id, name) + let mut customers_delta = Delta::new(); + customers_delta.insert(1, vec![Value::Integer(1), Value::Text("Alice".into())]); + customers_delta.insert(2, vec![Value::Integer(2), Value::Text("Bob".into())]); + + // Create test data for vendors (id, name, price) + let mut vendors_delta = Delta::new(); + vendors_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Widget Co".into()), + Value::Integer(10), + ], + ); + vendors_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Gadget Inc".into()), + Value::Integer(20), + ], + ); + + // Create test data for purchases (id, customer_id, vendor_id, quantity) + let mut purchases_delta = Delta::new(); + // Alice purchases 5 units from Widget Co + purchases_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Integer(1), // customer_id: Alice + Value::Integer(1), // vendor_id: Widget Co + Value::Integer(5), + ], + ); + // Alice purchases 3 units from Gadget Inc + purchases_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Integer(1), // customer_id: Alice + Value::Integer(2), // vendor_id: Gadget Inc + Value::Integer(3), + ], + ); + // Bob purchases 2 units from Widget Co + purchases_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Integer(2), // customer_id: Bob + Value::Integer(1), // vendor_id: Widget Co + Value::Integer(2), + ], + ); + // Alice purchases 4 more units from Widget Co + purchases_delta.insert( + 4, + vec![ + Value::Integer(4), + Value::Integer(1), // customer_id: Alice + Value::Integer(1), // vendor_id: Widget Co + Value::Integer(4), + ], + ); + + let inputs = HashMap::from([ + ("customers".to_string(), customers_delta), + ("purchases".to_string(), purchases_delta), + ("vendors".to_string(), vendors_delta), + ]); + + let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap(); + + // Expected results: + // Alice|Gadget Inc|3|60 (3 units * 20 price = 60) + // Alice|Widget Co|9|90 (9 units * 10 price = 90) + // Bob|Widget Co|2|20 (2 units * 10 price = 20) + + assert_eq!(result.len(), 3, "Should have 3 aggregated results"); + + // Sort results for consistent testing + let mut results: Vec<_> = result.changes.into_iter().collect(); + results.sort_by(|a, b| { + let a_cust = &a.0.values[0]; + let a_vend = &a.0.values[1]; + let b_cust = &b.0.values[0]; + let b_vend = &b.0.values[1]; + (a_cust, a_vend).cmp(&(b_cust, b_vend)) + }); + + // Verify Alice's Gadget Inc purchases + assert_eq!(results[0].0.values[0], Value::Text("Alice".into())); + assert_eq!(results[0].0.values[1], 
Value::Text("Gadget Inc".into())); + assert_eq!(results[0].0.values[2], Value::Integer(3)); // total_quantity + assert_eq!(results[0].0.values[3], Value::Integer(60)); // total_value + + // Verify Alice's Widget Co purchases + assert_eq!(results[1].0.values[0], Value::Text("Alice".into())); + assert_eq!(results[1].0.values[1], Value::Text("Widget Co".into())); + assert_eq!(results[1].0.values[2], Value::Integer(9)); // total_quantity + assert_eq!(results[1].0.values[3], Value::Integer(90)); // total_value + + // Verify Bob's Widget Co purchases + assert_eq!(results[2].0.values[0], Value::Text("Bob".into())); + assert_eq!(results[2].0.values[1], Value::Text("Widget Co".into())); + assert_eq!(results[2].0.values[2], Value::Integer(2)); // total_quantity + assert_eq!(results[2].0.values[3], Value::Integer(20)); // total_value + } + + #[test] + fn test_join_with_aggregate_execution() { + let (mut circuit, pager) = compile_sql!( + "SELECT u.name, SUM(o.quantity) as total_quantity + FROM users u + JOIN orders o ON u.id = o.user_id + GROUP BY u.name" + ); + + // Create test data for users + let mut users_delta = Delta::new(); + users_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".into()), + Value::Integer(25), + ], + ); + users_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Bob".into()), + Value::Integer(30), + ], + ); + + // Create test data for orders + let mut orders_delta = Delta::new(); + orders_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Integer(1), + Value::Integer(100), + Value::Integer(5), + ], + ); + orders_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Integer(1), + Value::Integer(101), + Value::Integer(3), + ], + ); + orders_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Integer(2), + Value::Integer(102), + Value::Integer(7), + ], + ); + + let mut inputs = HashMap::new(); + inputs.insert("users".to_string(), users_delta); + inputs.insert("orders".to_string(), orders_delta); + + // Execute the join with aggregation + let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap(); + + // We should get 2 aggregated results (one for Alice, one for Bob) + assert_eq!(result.len(), 2, "Should have 2 aggregated results"); + + // Verify aggregation results + for (row, weight) in result.changes.iter() { + assert_eq!(*weight, 1); + // Row should have name and sum columns + assert_eq!(row.values.len(), 2); + + // Check the aggregated values + if let Value::Text(name) = &row.values[0] { + if name.as_ref() == "Alice" { + // Alice should have total quantity of 8 (5 + 3) + assert_eq!(row.values[1], Value::Integer(8)); + } else if name.as_ref() == "Bob" { + // Bob should have total quantity of 7 + assert_eq!(row.values[1], Value::Integer(7)); + } + } + } + } + + #[test] + fn test_filter_with_qualified_columns_in_join() { + // Test that filters correctly handle qualified column names in joins + // when multiple tables have columns with the SAME names. + // Both users and sales tables have an 'id' column which can be ambiguous. 
+ + let (mut circuit, pager) = compile_sql!( + "SELECT users.id, users.name, sales.id, sales.amount + FROM users + JOIN sales ON users.id = sales.customer_id + WHERE users.id > 1 AND sales.id < 100" + ); + + // Create test data + let mut users_delta = Delta::new(); + let mut sales_delta = Delta::new(); + + // Users data: (id, name, age) + users_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".into()), + Value::Integer(30), + ], + ); // id = 1 + users_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Bob".into()), + Value::Integer(25), + ], + ); // id = 2 + users_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Text("Charlie".into()), + Value::Integer(35), + ], + ); // id = 3 + + // Sales data: (id, customer_id, amount) + sales_delta.insert( + 50, + vec![Value::Integer(50), Value::Integer(1), Value::Integer(100)], + ); // sales.id = 50, customer_id = 1 + sales_delta.insert( + 99, + vec![Value::Integer(99), Value::Integer(2), Value::Integer(200)], + ); // sales.id = 99, customer_id = 2 + sales_delta.insert( + 150, + vec![Value::Integer(150), Value::Integer(3), Value::Integer(300)], + ); // sales.id = 150, customer_id = 3 + + let mut inputs = HashMap::new(); + inputs.insert("users".to_string(), users_delta); + inputs.insert("sales".to_string(), sales_delta); + + let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap(); + + // Should only get row with Bob (users.id=2, sales.id=99): + // - users.id=2 (> 1) AND sales.id=99 (< 100) ✓ + // Alice excluded: users.id=1 (NOT > 1) + // Charlie excluded: sales.id=150 (NOT < 100) + assert_eq!(result.len(), 1, "Should have 1 filtered result"); + + let (row, weight) = &result.changes[0]; + assert_eq!(*weight, 1); + assert_eq!(row.values.len(), 4, "Should have 4 columns"); + + // Verify the filter correctly used qualified columns + assert_eq!(row.values[0], Value::Integer(2), "users.id should be 2"); + assert_eq!( + row.values[1], + Value::Text("Bob".into()), + "users.name should be Bob" + ); + assert_eq!(row.values[2], Value::Integer(99), "sales.id should be 99"); + assert_eq!( + row.values[3], + Value::Integer(200), + "sales.amount should be 200" + ); + } } diff --git a/core/incremental/operator.rs b/core/incremental/operator.rs index 54cd7e0a0..72ed7bc0c 100644 --- a/core/incremental/operator.rs +++ b/core/incremental/operator.rs @@ -3,7 +3,7 @@ // Based on Feldera DBSP design but adapted for Turso's architecture pub use crate::incremental::aggregate_operator::{ - AggregateEvalState, AggregateFunction, AggregateOperator, AggregateState, + AggregateEvalState, AggregateFunction, AggregateState, }; pub use crate::incremental::filter_operator::{FilterOperator, FilterPredicate}; pub use crate::incremental::input_operator::InputOperator; @@ -251,7 +251,7 @@ pub trait IncrementalOperator: Debug { #[cfg(test)] mod tests { use super::*; - use crate::incremental::aggregate_operator::AGG_TYPE_REGULAR; + use crate::incremental::aggregate_operator::{AggregateOperator, AGG_TYPE_REGULAR}; use crate::incremental::dbsp::HashableRow; use crate::storage::pager::CreateBTreeFlags; use crate::types::Text; @@ -395,9 +395,9 @@ mod tests { // Create an aggregate operator for SUM(age) with no GROUP BY let mut agg = AggregateOperator::new( - 1, // operator_id for testing - vec![], // No GROUP BY - vec![AggregateFunction::Sum("age".to_string())], + 1, // operator_id for testing + vec![], // No GROUP BY + vec![AggregateFunction::Sum(2)], // age is at index 2 vec!["id".to_string(), "name".to_string(), 
"age".to_string()], ); @@ -514,9 +514,9 @@ mod tests { let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( - 1, // operator_id for testing - vec!["team".to_string()], // GROUP BY team - vec![AggregateFunction::Sum("score".to_string())], + 1, // operator_id for testing + vec![1], // GROUP BY team (index 1) + vec![AggregateFunction::Sum(3)], // score is at index 3 vec![ "id".to_string(), "team".to_string(), @@ -666,8 +666,8 @@ mod tests { // Create COUNT(*) GROUP BY category let mut agg = AggregateOperator::new( - 1, // operator_id for testing - vec!["category".to_string()], + 1, // operator_id for testing + vec![1], // category is at index 1 vec![AggregateFunction::Count], vec![ "item_id".to_string(), @@ -746,9 +746,9 @@ mod tests { let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( - 1, // operator_id for testing - vec!["product".to_string()], - vec![AggregateFunction::Sum("amount".to_string())], + 1, // operator_id for testing + vec![1], // product is at index 1 + vec![AggregateFunction::Sum(2)], // amount is at index 2 vec![ "sale_id".to_string(), "product".to_string(), @@ -843,11 +843,11 @@ mod tests { let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( - 1, // operator_id for testing - vec!["user_id".to_string()], + 1, // operator_id for testing + vec![1], // user_id is at index 1 vec![ AggregateFunction::Count, - AggregateFunction::Sum("amount".to_string()), + AggregateFunction::Sum(2), // amount is at index 2 ], vec![ "order_id".to_string(), @@ -935,9 +935,9 @@ mod tests { let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( - 1, // operator_id for testing - vec!["category".to_string()], - vec![AggregateFunction::Avg("value".to_string())], + 1, // operator_id for testing + vec![1], // category is at index 1 + vec![AggregateFunction::Avg(2)], // value is at index 2 vec![ "id".to_string(), "category".to_string(), @@ -1035,11 +1035,11 @@ mod tests { let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( - 1, // operator_id for testing - vec!["category".to_string()], + 1, // operator_id for testing + vec![1], // category is at index 1 vec![ AggregateFunction::Count, - AggregateFunction::Sum("value".to_string()), + AggregateFunction::Sum(2), // value is at index 2 ], vec![ "id".to_string(), @@ -1108,7 +1108,7 @@ mod tests { #[test] fn test_count_aggregation_with_deletions() { let aggregates = vec![AggregateFunction::Count]; - let group_by = vec!["category".to_string()]; + let group_by = vec![0]; // category is at index 0 let input_columns = vec!["category".to_string(), "value".to_string()]; // Create a persistent pager for the test @@ -1197,8 +1197,8 @@ mod tests { #[test] fn test_sum_aggregation_with_deletions() { - let aggregates = vec![AggregateFunction::Sum("value".to_string())]; - let group_by = vec!["category".to_string()]; + let aggregates = vec![AggregateFunction::Sum(1)]; // value is at index 1 + let group_by = vec![0]; // category is at index 0 let input_columns = vec!["category".to_string(), "value".to_string()]; // Create a persistent pager for the test @@ -1281,8 +1281,8 @@ mod tests { #[test] fn test_avg_aggregation_with_deletions() { - let aggregates = vec![AggregateFunction::Avg("value".to_string())]; - let group_by = vec!["category".to_string()]; + let aggregates = vec![AggregateFunction::Avg(1)]; // 
value is at index 1 + let group_by = vec![0]; // category is at index 0 let input_columns = vec!["category".to_string(), "value".to_string()]; // Create a persistent pager for the test @@ -1348,10 +1348,10 @@ mod tests { // Test COUNT, SUM, and AVG together let aggregates = vec![ AggregateFunction::Count, - AggregateFunction::Sum("value".to_string()), - AggregateFunction::Avg("value".to_string()), + AggregateFunction::Sum(1), // value is at index 1 + AggregateFunction::Avg(1), // value is at index 1 ]; - let group_by = vec!["category".to_string()]; + let group_by = vec![0]; // category is at index 0 let input_columns = vec!["category".to_string(), "value".to_string()]; // Create a persistent pager for the test @@ -1607,11 +1607,11 @@ mod tests { let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( - 1, // operator_id for testing - vec!["category".to_string()], + 1, // operator_id for testing + vec![1], // category is at index 1 vec![ AggregateFunction::Count, - AggregateFunction::Sum("amount".to_string()), + AggregateFunction::Sum(2), // amount is at index 2 ], vec![ "id".to_string(), @@ -1781,7 +1781,7 @@ mod tests { vec![], // No GROUP BY vec![ AggregateFunction::Count, - AggregateFunction::Sum("value".to_string()), + AggregateFunction::Sum(1), // value is at index 1 ], vec!["id".to_string(), "value".to_string()], ); @@ -1859,8 +1859,8 @@ mod tests { let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( - 1, // operator_id for testing - vec!["type".to_string()], + 1, // operator_id for testing + vec![1], // type is at index 1 vec![AggregateFunction::Count], vec!["id".to_string(), "type".to_string()], ); @@ -1976,8 +1976,8 @@ mod tests { 1, // operator_id vec![], // No GROUP BY vec![ - AggregateFunction::Min("price".to_string()), - AggregateFunction::Max("price".to_string()), + AggregateFunction::Min(2), // price is at index 2 + AggregateFunction::Max(2), // price is at index 2 ], vec!["id".to_string(), "name".to_string(), "price".to_string()], ); @@ -2044,8 +2044,8 @@ mod tests { 1, // operator_id vec![], // No GROUP BY vec![ - AggregateFunction::Min("price".to_string()), - AggregateFunction::Max("price".to_string()), + AggregateFunction::Min(2), // price is at index 2 + AggregateFunction::Max(2), // price is at index 2 ], vec!["id".to_string(), "name".to_string(), "price".to_string()], ); @@ -2134,8 +2134,8 @@ mod tests { 1, // operator_id vec![], // No GROUP BY vec![ - AggregateFunction::Min("price".to_string()), - AggregateFunction::Max("price".to_string()), + AggregateFunction::Min(2), // price is at index 2 + AggregateFunction::Max(2), // price is at index 2 ], vec!["id".to_string(), "name".to_string(), "price".to_string()], ); @@ -2224,8 +2224,8 @@ mod tests { 1, // operator_id vec![], // No GROUP BY vec![ - AggregateFunction::Min("price".to_string()), - AggregateFunction::Max("price".to_string()), + AggregateFunction::Min(2), // price is at index 2 + AggregateFunction::Max(2), // price is at index 2 ], vec!["id".to_string(), "name".to_string(), "price".to_string()], ); @@ -2306,8 +2306,8 @@ mod tests { 1, // operator_id vec![], // No GROUP BY vec![ - AggregateFunction::Min("price".to_string()), - AggregateFunction::Max("price".to_string()), + AggregateFunction::Min(2), // price is at index 2 + AggregateFunction::Max(2), // price is at index 2 ], vec!["id".to_string(), "name".to_string(), "price".to_string()], ); @@ -2388,8 +2388,8 @@ mod tests { 1, // operator_id vec![], // 
No GROUP BY vec![ - AggregateFunction::Min("price".to_string()), - AggregateFunction::Max("price".to_string()), + AggregateFunction::Min(2), // price is at index 2 + AggregateFunction::Max(2), // price is at index 2 ], vec!["id".to_string(), "name".to_string(), "price".to_string()], ); @@ -2475,11 +2475,11 @@ mod tests { let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); let mut agg = AggregateOperator::new( - 1, // operator_id - vec!["category".to_string()], // GROUP BY category + 1, // operator_id + vec![1], // GROUP BY category (index 1) vec![ - AggregateFunction::Min("price".to_string()), - AggregateFunction::Max("price".to_string()), + AggregateFunction::Min(3), // price is at index 3 + AggregateFunction::Max(3), // price is at index 3 ], vec![ "id".to_string(), @@ -2580,8 +2580,8 @@ mod tests { 1, // operator_id vec![], // No GROUP BY vec![ - AggregateFunction::Min("price".to_string()), - AggregateFunction::Max("price".to_string()), + AggregateFunction::Min(2), // price is at index 2 + AggregateFunction::Max(2), // price is at index 2 ], vec!["id".to_string(), "name".to_string(), "price".to_string()], ); @@ -2656,8 +2656,8 @@ mod tests { 1, // operator_id vec![], // No GROUP BY vec![ - AggregateFunction::Min("score".to_string()), - AggregateFunction::Max("score".to_string()), + AggregateFunction::Min(2), // score is at index 2 + AggregateFunction::Max(2), // score is at index 2 ], vec!["id".to_string(), "name".to_string(), "score".to_string()], ); @@ -2724,8 +2724,8 @@ mod tests { 1, // operator_id vec![], // No GROUP BY vec![ - AggregateFunction::Min("name".to_string()), - AggregateFunction::Max("name".to_string()), + AggregateFunction::Min(1), // name is at index 1 + AggregateFunction::Max(1), // name is at index 1 ], vec!["id".to_string(), "name".to_string()], ); @@ -2764,10 +2764,10 @@ mod tests { vec![], // No GROUP BY vec![ AggregateFunction::Count, - AggregateFunction::Sum("value".to_string()), - AggregateFunction::Min("value".to_string()), - AggregateFunction::Max("value".to_string()), - AggregateFunction::Avg("value".to_string()), + AggregateFunction::Sum(1), // value is at index 1 + AggregateFunction::Min(1), // value is at index 1 + AggregateFunction::Max(1), // value is at index 1 + AggregateFunction::Avg(1), // value is at index 1 ], vec!["id".to_string(), "value".to_string()], ); @@ -2855,9 +2855,9 @@ mod tests { 1, // operator_id vec![], // No GROUP BY vec![ - AggregateFunction::Min("col1".to_string()), - AggregateFunction::Max("col2".to_string()), - AggregateFunction::Min("col3".to_string()), + AggregateFunction::Min(0), // col1 is at index 0 + AggregateFunction::Max(1), // col2 is at index 1 + AggregateFunction::Min(2), // col3 is at index 2 ], vec!["col1".to_string(), "col2".to_string(), "col3".to_string()], ); diff --git a/core/translate/logical.rs b/core/translate/logical.rs index 6a8b0a6c2..b11e2df4f 100644 --- a/core/translate/logical.rs +++ b/core/translate/logical.rs @@ -19,26 +19,35 @@ use turso_parser::ast; /// Result type for preprocessing aggregate expressions type PreprocessAggregateResult = ( - bool, // needs_pre_projection - Vec, // pre_projection_exprs - Vec<(String, Type)>, // pre_projection_schema - Vec, // modified_aggr_exprs + bool, // needs_pre_projection + Vec, // pre_projection_exprs + Vec, // pre_projection_schema + Vec, // modified_aggr_exprs ); /// Result type for parsing join conditions type JoinConditionsResult = (Vec<(LogicalExpr, LogicalExpr)>, Option); +/// Information about a column in a logical schema +#[derive(Debug, 
Clone, PartialEq)] +pub struct ColumnInfo { + pub name: String, + pub ty: Type, + pub database: Option, + pub table: Option, + pub table_alias: Option, +} + /// Schema information for logical plan nodes #[derive(Debug, Clone, PartialEq)] pub struct LogicalSchema { - /// Column names and types - pub columns: Vec<(String, Type)>, + pub columns: Vec, } /// A reference to a schema that can be shared between nodes pub type SchemaRef = Arc; impl LogicalSchema { - pub fn new(columns: Vec<(String, Type)>) -> Self { + pub fn new(columns: Vec) -> Self { Self { columns } } @@ -52,11 +61,42 @@ impl LogicalSchema { self.columns.len() } - pub fn find_column(&self, name: &str) -> Option<(usize, &Type)> { - self.columns - .iter() - .position(|(n, _)| n == name) - .map(|idx| (idx, &self.columns[idx].1)) + pub fn find_column(&self, name: &str, table: Option<&str>) -> Option<(usize, &ColumnInfo)> { + if let Some(table_ref) = table { + // Check if it's a database.table format + if table_ref.contains('.') { + let parts: Vec<&str> = table_ref.splitn(2, '.').collect(); + if parts.len() == 2 { + let db = parts[0]; + let tbl = parts[1]; + return self + .columns + .iter() + .position(|c| { + c.name == name + && c.database.as_deref() == Some(db) + && c.table.as_deref() == Some(tbl) + }) + .map(|idx| (idx, &self.columns[idx])); + } + } + + // Try to match against table alias first, then table name + self.columns + .iter() + .position(|c| { + c.name == name + && (c.table_alias.as_deref() == Some(table_ref) + || c.table.as_deref() == Some(table_ref)) + }) + .map(|idx| (idx, &self.columns[idx])) + } else { + // Unqualified lookup - just match by name + self.columns + .iter() + .position(|c| c.name == name) + .map(|idx| (idx, &self.columns[idx])) + } } } @@ -548,14 +588,14 @@ impl<'a> LogicalPlanBuilder<'a> { } // Regular table scan - let table_schema = self.get_table_schema(&table_name)?; let table_alias = alias.as_ref().map(|a| match a { ast::As::As(name) => Self::name_to_string(name), ast::As::Elided(name) => Self::name_to_string(name), }); + let table_schema = self.get_table_schema(&table_name, table_alias.as_deref())?; Ok(LogicalPlan::TableScan(TableScan { table_name, - alias: table_alias, + alias: table_alias.clone(), schema: table_schema, projection: None, })) @@ -751,14 +791,14 @@ impl<'a> LogicalPlanBuilder<'a> { let _left_idx = left_schema .columns .iter() - .position(|(n, _)| n == &name) + .position(|col| col.name == name) .ok_or_else(|| { LimboError::ParseError(format!("Column {name} not found in left table")) })?; let _right_idx = right_schema .columns .iter() - .position(|(n, _)| n == &name) + .position(|col| col.name == name) .ok_or_else(|| { LimboError::ParseError(format!("Column {name} not found in right table")) })?; @@ -790,9 +830,13 @@ impl<'a> LogicalPlanBuilder<'a> { // Find common column names let mut common_columns = Vec::new(); - for (left_name, _) in &left_schema.columns { - if right_schema.columns.iter().any(|(n, _)| n == left_name) { - common_columns.push(ast::Name::Ident(left_name.clone())); + for left_col in &left_schema.columns { + if right_schema + .columns + .iter() + .any(|col| col.name == left_col.name) + { + common_columns.push(ast::Name::Ident(left_col.name.clone())); } } @@ -833,10 +877,18 @@ impl<'a> LogicalPlanBuilder<'a> { let left_schema = left.schema(); let right_schema = right.schema(); - // For now, simply concatenate the schemas - // In a real implementation, we'd handle column name conflicts and nullable columns - let mut columns = left_schema.columns.clone(); - 
columns.extend(right_schema.columns.clone()); + // Concatenate the schemas, preserving all column information + let mut columns = Vec::new(); + + // Keep all columns from left with their table info + for col in &left_schema.columns { + columns.push(col.clone()); + } + + // Keep all columns from right with their table info + for col in &right_schema.columns { + columns.push(col.clone()); + } Ok(Arc::new(LogicalSchema::new(columns))) } @@ -870,7 +922,13 @@ impl<'a> LogicalPlanBuilder<'a> { }; let col_type = Self::infer_expr_type(&logical_expr, input_schema)?; - schema_columns.push((col_name.clone(), col_type)); + schema_columns.push(ColumnInfo { + name: col_name.clone(), + ty: col_type, + database: None, + table: None, + table_alias: None, + }); if let Some(as_alias) = alias { let alias_name = match as_alias { @@ -886,21 +944,21 @@ impl<'a> LogicalPlanBuilder<'a> { } ast::ResultColumn::Star => { // Expand * to all columns - for (name, typ) in &input_schema.columns { - proj_exprs.push(LogicalExpr::Column(Column::new(name.clone()))); - schema_columns.push((name.clone(), *typ)); + for col in &input_schema.columns { + proj_exprs.push(LogicalExpr::Column(Column::new(col.name.clone()))); + schema_columns.push(col.clone()); } } ast::ResultColumn::TableStar(table) => { // Expand table.* to all columns from that table let table_name = Self::name_to_string(table); - for (name, typ) in &input_schema.columns { + for col in &input_schema.columns { // Simple check - would need proper table tracking in real implementation proj_exprs.push(LogicalExpr::Column(Column::with_table( - name.clone(), + col.name.clone(), table_name.clone(), ))); - schema_columns.push((name.clone(), *typ)); + schema_columns.push(col.clone()); } } } @@ -938,7 +996,13 @@ impl<'a> LogicalPlanBuilder<'a> { if let LogicalExpr::Column(col) = expr { pre_projection_exprs.push(expr.clone()); let col_type = Self::infer_expr_type(expr, input_schema)?; - pre_projection_schema.push((col.name.clone(), col_type)); + pre_projection_schema.push(ColumnInfo { + name: col.name.clone(), + ty: col_type, + database: None, + table: col.table.clone(), + table_alias: None, + }); } else { // Complex group by expression - project it needs_pre_projection = true; @@ -946,7 +1010,13 @@ impl<'a> LogicalPlanBuilder<'a> { projected_col_counter += 1; pre_projection_exprs.push(expr.clone()); let col_type = Self::infer_expr_type(expr, input_schema)?; - pre_projection_schema.push((proj_col_name.clone(), col_type)); + pre_projection_schema.push(ColumnInfo { + name: proj_col_name.clone(), + ty: col_type, + database: None, + table: None, + table_alias: None, + }); } } @@ -970,7 +1040,13 @@ impl<'a> LogicalPlanBuilder<'a> { pre_projection_exprs.push(arg.clone()); let col_type = Self::infer_expr_type(arg, input_schema)?; if let LogicalExpr::Column(col) = arg { - pre_projection_schema.push((col.name.clone(), col_type)); + pre_projection_schema.push(ColumnInfo { + name: col.name.clone(), + ty: col_type, + database: None, + table: col.table.clone(), + table_alias: None, + }); } } } @@ -983,7 +1059,13 @@ impl<'a> LogicalPlanBuilder<'a> { // Add the expression to the pre-projection pre_projection_exprs.push(arg.clone()); let col_type = Self::infer_expr_type(arg, input_schema)?; - pre_projection_schema.push((proj_col_name.clone(), col_type)); + pre_projection_schema.push(ColumnInfo { + name: proj_col_name.clone(), + ty: col_type, + database: None, + table: None, + table_alias: None, + }); // In the aggregate, reference the projected column 
modified_args.push(LogicalExpr::Column(Column::new(proj_col_name))); @@ -1057,15 +1139,39 @@ impl<'a> LogicalPlanBuilder<'a> { // First, add GROUP BY columns to the aggregate output schema // These are always part of the aggregate operator's output for group_expr in &group_exprs { - let col_name = match group_expr { - LogicalExpr::Column(col) => col.name.clone(), + match group_expr { + LogicalExpr::Column(col) => { + // For column references in GROUP BY, preserve the original column info + if let Some((_, col_info)) = + input_schema.find_column(&col.name, col.table.as_deref()) + { + // Preserve the column with all its table information + aggregate_schema_columns.push(col_info.clone()); + } else { + // Fallback if column not found (shouldn't happen) + let col_type = Self::infer_expr_type(group_expr, input_schema)?; + aggregate_schema_columns.push(ColumnInfo { + name: col.name.clone(), + ty: col_type, + database: None, + table: col.table.clone(), + table_alias: None, + }); + } + } _ => { // For complex GROUP BY expressions, generate a name - format!("__group_{}", aggregate_schema_columns.len()) + let col_name = format!("__group_{}", aggregate_schema_columns.len()); + let col_type = Self::infer_expr_type(group_expr, input_schema)?; + aggregate_schema_columns.push(ColumnInfo { + name: col_name, + ty: col_type, + database: None, + table: None, + table_alias: None, + }); } - }; - let col_type = Self::infer_expr_type(group_expr, input_schema)?; - aggregate_schema_columns.push((col_name, col_type)); + } } // Track aggregates we've already seen to avoid duplicates @@ -1098,7 +1204,13 @@ impl<'a> LogicalPlanBuilder<'a> { } else { // New aggregate - add it let col_type = Self::infer_expr_type(&logical_expr, input_schema)?; - aggregate_schema_columns.push((col_name.clone(), col_type)); + aggregate_schema_columns.push(ColumnInfo { + name: col_name.clone(), + ty: col_type, + database: None, + table: None, + table_alias: None, + }); aggr_exprs.push(logical_expr); aggregate_map.insert(agg_key, col_name.clone()); col_name.clone() @@ -1122,7 +1234,13 @@ impl<'a> LogicalPlanBuilder<'a> { // Add only new aggregates for (agg_expr, agg_name) in extracted_aggs { let agg_type = Self::infer_expr_type(&agg_expr, input_schema)?; - aggregate_schema_columns.push((agg_name, agg_type)); + aggregate_schema_columns.push(ColumnInfo { + name: agg_name, + ty: agg_type, + database: None, + table: None, + table_alias: None, + }); aggr_exprs.push(agg_expr); } @@ -1197,7 +1315,13 @@ impl<'a> LogicalPlanBuilder<'a> { // For type inference, we need the aggregate schema for column references let aggregate_schema = LogicalSchema::new(aggregate_schema_columns.clone()); let col_type = Self::infer_expr_type(expr, &Arc::new(aggregate_schema))?; - projection_schema_columns.push((col_name, col_type)); + projection_schema_columns.push(ColumnInfo { + name: col_name, + ty: col_type, + database: None, + table: None, + table_alias: None, + }); } // Create the input plan (with pre-projection if needed) @@ -1220,11 +1344,11 @@ impl<'a> LogicalPlanBuilder<'a> { // Check if we need the outer projection // We need a projection if: - // 1. Any expression is more complex than a simple column reference (e.g., abs(sum(id))) - // 2. We're selecting a different set of columns than what the aggregate outputs - // 3. Columns are renamed or reordered + // 1. We have expressions that compute new values (e.g., SUM(x) * 2) + // 2. We're selecting a different set of columns than GROUP BY + aggregates + // 3. 
We're reordering columns from their natural aggregate output order let needs_outer_projection = { - // Check if any expression is more complex than a simple column reference + // Check for complex expressions let has_complex_exprs = projection_exprs .iter() .any(|expr| !matches!(expr, LogicalExpr::Column(_))); @@ -1232,17 +1356,29 @@ impl<'a> LogicalPlanBuilder<'a> { if has_complex_exprs { true } else { - // All are simple columns - check if we're selecting exactly what the aggregate outputs - // The projection might be selecting a subset (e.g., only aggregates without group columns) - // or reordering columns, or using different names + // Check if we're selecting exactly what aggregate outputs in the same order + // The aggregate outputs: all GROUP BY columns, then all aggregate expressions + // The projection might select a subset or reorder these - // For now, keep it simple: if schemas don't match exactly, we need projection - // This handles all cases: subset selection, reordering, renaming - projection_schema_columns != aggregate_schema_columns + if projection_exprs.len() != aggregate_schema_columns.len() { + // Different number of columns + true + } else { + // Check if columns match in order and name + !projection_exprs.iter().zip(&aggregate_schema_columns).all( + |(expr, agg_col)| { + if let LogicalExpr::Column(col) = expr { + col.name == agg_col.name + } else { + false + } + }, + ) + } } }; - // Create the aggregate node + // Create the aggregate node with its natural schema let aggregate_plan = LogicalPlan::Aggregate(Aggregate { input: aggregate_input, group_expr: group_exprs, @@ -1257,7 +1393,7 @@ impl<'a> LogicalPlanBuilder<'a> { schema: Arc::new(LogicalSchema::new(projection_schema_columns)), })) } else { - // No projection needed - the aggregate output is exactly what we want + // No projection needed - aggregate output matches what we want Ok(aggregate_plan) } } @@ -1275,7 +1411,13 @@ impl<'a> LogicalPlanBuilder<'a> { // Infer schema from first row let mut schema_columns = Vec::new(); for (i, _) in values[0].iter().enumerate() { - schema_columns.push((format!("column{}", i + 1), Type::Text)); + schema_columns.push(ColumnInfo { + name: format!("column{}", i + 1), + ty: Type::Text, + database: None, + table: None, + table_alias: None, + }); } for row in values { @@ -2003,17 +2145,31 @@ impl<'a> LogicalPlanBuilder<'a> { } // Get table schema - fn get_table_schema(&self, table_name: &str) -> Result { + fn get_table_schema(&self, table_name: &str, alias: Option<&str>) -> Result { // Look up table in schema let table = self .schema .get_table(table_name) .ok_or_else(|| LimboError::ParseError(format!("Table '{table_name}' not found")))?; + // Parse table_name which might be "db.table" for attached databases + let (database, actual_table) = if table_name.contains('.') { + let parts: Vec<&str> = table_name.splitn(2, '.').collect(); + (Some(parts[0].to_string()), parts[1].to_string()) + } else { + (None, table_name.to_string()) + }; + let mut columns = Vec::new(); for col in table.columns() { if let Some(ref name) = col.name { - columns.push((name.clone(), col.ty)); + columns.push(ColumnInfo { + name: name.clone(), + ty: col.ty, + database: database.clone(), + table: Some(actual_table.clone()), + table_alias: alias.map(|s| s.to_string()), + }); } } @@ -2024,8 +2180,8 @@ impl<'a> LogicalPlanBuilder<'a> { fn infer_expr_type(expr: &LogicalExpr, schema: &SchemaRef) -> Result { match expr { LogicalExpr::Column(col) => { - if let Some((_, typ)) = schema.find_column(&col.name) { - Ok(*typ) 
+ if let Some((_, col_info)) = schema.find_column(&col.name, col.table.as_deref()) { + Ok(col_info.ty) } else { Ok(Type::Text) } From cb7c04ffad35bfc555ec4e0599ec2b098923e203 Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Wed, 17 Sep 2025 21:19:10 -0500 Subject: [PATCH 20/34] return error instead of panic for invalid syntax on views I accidentally typed "create materialized views" and noticed that this panics instead of returning an error. Fix it. --- parser/src/parser.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parser/src/parser.rs b/parser/src/parser.rs index fa2230373..2fde878e4 100644 --- a/parser/src/parser.rs +++ b/parser/src/parser.rs @@ -749,7 +749,7 @@ impl<'a> Parser<'a> { fn parse_create_materialized_view(&mut self) -> Result { eat_assert!(self, TK_MATERIALIZED); - eat_assert!(self, TK_VIEW); + eat_expect!(self, TK_VIEW); let if_not_exists = self.parse_if_not_exists()?; let view_name = self.parse_fullname(false)?; let columns = self.parse_eid_list(false)?; From e80dd8e5e1b2ea7871b2cd4911c80256ca48f0d1 Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Thu, 18 Sep 2025 10:14:40 -0500 Subject: [PATCH 21/34] move the filter operator to accept indexes instead of names We already did the same for the AggregateOperator: for joins you can have the same column name in many tables. And passing schema information to the operator is a layering violation (the operator may be operating on the result of a previous node, and at that point there is no more "schema"). Therefore we pass indexes into the column set the operator has. The FilterOperator has a complication: we are using it to generate the SQL for the populate statement, and that needs column names. However, we should *not* be using the FilterOperator for that, and that is a relic from the time when we had operator information directly inside the IncrementalView. To enable moving the FilterOperator to index-based, we rework that code. For joins, we'll need to populate many tables anyway, so we take the time to do that work here.
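(Illustrative aside, not part of the patch: a minimal, self-contained Rust sketch of the index-based predicate idea described above. The names SimplePredicate, resolve_column, and matches are hypothetical and simplified; they are not the actual FilterOperator/FilterPredicate API. The point is that the name-to-index lookup happens once, where the schema is still known, and row evaluation is purely positional, which stays unambiguous when joined tables share a column name.)

// Hypothetical sketch; not the real FilterOperator API.
#[derive(Clone, Debug, PartialEq)]
enum Val {
    Int(i64),
    Text(String),
}

// The predicate stores a column *index*, resolved once when the plan is
// compiled, so evaluating it needs neither a schema nor column names.
struct SimplePredicate {
    column_idx: usize,
    value: Val,
}

// Name -> index resolution happens where the schema is still available.
fn resolve_column(schema: &[&str], name: &str) -> Option<usize> {
    schema.iter().position(|c| *c == name)
}

impl SimplePredicate {
    // Evaluate against a row by position only.
    fn matches(&self, row: &[Val]) -> bool {
        row.get(self.column_idx) == Some(&self.value)
    }
}

fn main() {
    // Two joined tables can both expose an "id" column; positions disambiguate.
    let schema = ["users.id", "users.name", "customers.id"];
    let pred = SimplePredicate {
        column_idx: resolve_column(&schema, "customers.id").expect("column exists"),
        value: Val::Int(2),
    };
    let row = [Val::Int(2), Val::Text("Bob".into()), Val::Int(2)];
    assert!(pred.matches(&row));
}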
--- core/incremental/compiler.rs | 127 ++-- core/incremental/filter_operator.rs | 238 ++---- core/incremental/operator.rs | 22 +- core/incremental/view.rs | 1065 ++++++++++++++++++++++++--- 4 files changed, 1121 insertions(+), 331 deletions(-) diff --git a/core/incremental/compiler.rs b/core/incremental/compiler.rs index c8899a02e..07fd8f83c 100644 --- a/core/incremental/compiler.rs +++ b/core/incremental/compiler.rs @@ -895,21 +895,18 @@ impl DbspCompiler { // Compile the input first let input_id = self.compile_plan(&filter.input)?; - // Get column names from input schema + // Get input schema for column resolution let input_schema = filter.input.schema(); - let column_names: Vec = input_schema.columns.iter() - .map(|col| col.name.clone()) - .collect(); // Convert predicate to DBSP expression let dbsp_predicate = Self::compile_expr(&filter.predicate)?; // Convert to FilterPredicate - let filter_predicate = Self::compile_filter_predicate(&filter.predicate)?; + let filter_predicate = Self::compile_filter_predicate(&filter.predicate, input_schema)?; // Create executable operator let executable: Box = - Box::new(FilterOperator::new(filter_predicate, column_names)); + Box::new(FilterOperator::new(filter_predicate)); // Create filter node let node_id = self.circuit.add_node( @@ -1372,42 +1369,57 @@ impl DbspCompiler { } /// Compile a logical expression to a FilterPredicate for execution - fn compile_filter_predicate(expr: &LogicalExpr) -> Result { + fn compile_filter_predicate( + expr: &LogicalExpr, + schema: &LogicalSchema, + ) -> Result { match expr { LogicalExpr::BinaryExpr { left, op, right } => { // Extract column name and value for simple predicates if let (LogicalExpr::Column(col), LogicalExpr::Literal(val)) = (left.as_ref(), right.as_ref()) { + // Resolve column name to index using the schema + let column_idx = schema + .columns + .iter() + .position(|c| c.name == col.name) + .ok_or_else(|| { + crate::LimboError::ParseError(format!( + "Column '{}' not found in schema for filter", + col.name + )) + })?; + match op { BinaryOperator::Equals => Ok(FilterPredicate::Equals { - column: col.name.clone(), + column_idx, value: val.clone(), }), BinaryOperator::NotEquals => Ok(FilterPredicate::NotEquals { - column: col.name.clone(), + column_idx, value: val.clone(), }), BinaryOperator::Greater => Ok(FilterPredicate::GreaterThan { - column: col.name.clone(), + column_idx, value: val.clone(), }), BinaryOperator::GreaterEquals => Ok(FilterPredicate::GreaterThanOrEqual { - column: col.name.clone(), + column_idx, value: val.clone(), }), BinaryOperator::Less => Ok(FilterPredicate::LessThan { - column: col.name.clone(), + column_idx, value: val.clone(), }), BinaryOperator::LessEquals => Ok(FilterPredicate::LessThanOrEqual { - column: col.name.clone(), + column_idx, value: val.clone(), }), BinaryOperator::And => { // Handle AND of two predicates - let left_pred = Self::compile_filter_predicate(left)?; - let right_pred = Self::compile_filter_predicate(right)?; + let left_pred = Self::compile_filter_predicate(left, schema)?; + let right_pred = Self::compile_filter_predicate(right, schema)?; Ok(FilterPredicate::And( Box::new(left_pred), Box::new(right_pred), @@ -1415,8 +1427,8 @@ impl DbspCompiler { } BinaryOperator::Or => { // Handle OR of two predicates - let left_pred = Self::compile_filter_predicate(left)?; - let right_pred = Self::compile_filter_predicate(right)?; + let left_pred = Self::compile_filter_predicate(left, schema)?; + let right_pred = Self::compile_filter_predicate(right, schema)?; 
Ok(FilterPredicate::Or( Box::new(left_pred), Box::new(right_pred), @@ -1428,8 +1440,8 @@ impl DbspCompiler { } } else if matches!(op, BinaryOperator::And | BinaryOperator::Or) { // Handle logical operators - let left_pred = Self::compile_filter_predicate(left)?; - let right_pred = Self::compile_filter_predicate(right)?; + let left_pred = Self::compile_filter_predicate(left, schema)?; + let right_pred = Self::compile_filter_predicate(right, schema)?; match op { BinaryOperator::And => Ok(FilterPredicate::And( Box::new(left_pred), @@ -3777,13 +3789,10 @@ mod tests { Box::new(InputOperator::new("test".to_string())), ); - let filter_op = FilterOperator::new( - FilterPredicate::GreaterThan { - column: "value".to_string(), - value: Value::Integer(10), - }, - vec!["id".to_string(), "value".to_string()], - ); + let filter_op = FilterOperator::new(FilterPredicate::GreaterThan { + column_idx: 1, // "value" is at index 1 + value: Value::Integer(10), + }); // Create the filter predicate using DbspExpr let predicate = DbspExpr::BinaryExpr { @@ -4587,18 +4596,18 @@ mod tests { fn test_filter_with_qualified_columns_in_join() { // Test that filters correctly handle qualified column names in joins // when multiple tables have columns with the SAME names. - // Both users and sales tables have an 'id' column which can be ambiguous. + // Both users and customers tables have 'id' and 'name' columns which can be ambiguous. let (mut circuit, pager) = compile_sql!( - "SELECT users.id, users.name, sales.id, sales.amount + "SELECT users.id, users.name, customers.id, customers.name FROM users - JOIN sales ON users.id = sales.customer_id - WHERE users.id > 1 AND sales.id < 100" + JOIN customers ON users.id = customers.id + WHERE users.id > 1 AND customers.id < 100" ); // Create test data let mut users_delta = Delta::new(); - let mut sales_delta = Delta::new(); + let mut customers_delta = Delta::new(); // Users data: (id, name, age) users_delta.insert( @@ -4626,48 +4635,60 @@ mod tests { ], ); // id = 3 - // Sales data: (id, customer_id, amount) - sales_delta.insert( - 50, - vec![Value::Integer(50), Value::Integer(1), Value::Integer(100)], - ); // sales.id = 50, customer_id = 1 - sales_delta.insert( - 99, - vec![Value::Integer(99), Value::Integer(2), Value::Integer(200)], - ); // sales.id = 99, customer_id = 2 - sales_delta.insert( - 150, - vec![Value::Integer(150), Value::Integer(3), Value::Integer(300)], - ); // sales.id = 150, customer_id = 3 + // Customers data: (id, name, email) + customers_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Customer Alice".into()), + Value::Text("alice@example.com".into()), + ], + ); // id = 1 + customers_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Customer Bob".into()), + Value::Text("bob@example.com".into()), + ], + ); // id = 2 + customers_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Text("Customer Charlie".into()), + Value::Text("charlie@example.com".into()), + ], + ); // id = 3 let mut inputs = HashMap::new(); inputs.insert("users".to_string(), users_delta); - inputs.insert("sales".to_string(), sales_delta); + inputs.insert("customers".to_string(), customers_delta); let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap(); - // Should only get row with Bob (users.id=2, sales.id=99): - // - users.id=2 (> 1) AND sales.id=99 (< 100) ✓ + // Should get rows where users.id > 1 AND customers.id < 100 + // - users.id=2 (> 1) AND customers.id=2 (< 100) ✓ + // - users.id=3 (> 1) AND customers.id=3 (< 100) ✓ // Alice 
excluded: users.id=1 (NOT > 1) - // Charlie excluded: sales.id=150 (NOT < 100) - assert_eq!(result.len(), 1, "Should have 1 filtered result"); + assert_eq!(result.len(), 2, "Should have 2 filtered results"); let (row, weight) = &result.changes[0]; assert_eq!(*weight, 1); assert_eq!(row.values.len(), 4, "Should have 4 columns"); - // Verify the filter correctly used qualified columns + // Verify the filter correctly used qualified columns for Bob assert_eq!(row.values[0], Value::Integer(2), "users.id should be 2"); assert_eq!( row.values[1], Value::Text("Bob".into()), "users.name should be Bob" ); - assert_eq!(row.values[2], Value::Integer(99), "sales.id should be 99"); + assert_eq!(row.values[2], Value::Integer(2), "customers.id should be 2"); assert_eq!( row.values[3], - Value::Integer(200), - "sales.amount should be 200" + Value::Text("Customer Bob".into()), + "customers.name should be Customer Bob" ); } } diff --git a/core/incremental/filter_operator.rs b/core/incremental/filter_operator.rs index f836f4897..a0179f9d4 100644 --- a/core/incremental/filter_operator.rs +++ b/core/incremental/filter_operator.rs @@ -6,26 +6,25 @@ use crate::incremental::dbsp::{Delta, DeltaPair}; use crate::incremental::operator::{ ComputationTracker, DbspStateCursors, EvalState, IncrementalOperator, }; -use crate::types::{IOResult, Text}; +use crate::types::IOResult; use crate::{Result, Value}; use std::sync::{Arc, Mutex}; -use turso_parser::ast::{Expr, Literal, OneSelect, Operator}; /// Filter predicate for filtering rows #[derive(Debug, Clone)] pub enum FilterPredicate { - /// Column = value - Equals { column: String, value: Value }, - /// Column != value - NotEquals { column: String, value: Value }, - /// Column > value - GreaterThan { column: String, value: Value }, - /// Column >= value - GreaterThanOrEqual { column: String, value: Value }, - /// Column < value - LessThan { column: String, value: Value }, - /// Column <= value - LessThanOrEqual { column: String, value: Value }, + /// Column = value (using column index) + Equals { column_idx: usize, value: Value }, + /// Column != value (using column index) + NotEquals { column_idx: usize, value: Value }, + /// Column > value (using column index) + GreaterThan { column_idx: usize, value: Value }, + /// Column >= value (using column index) + GreaterThanOrEqual { column_idx: usize, value: Value }, + /// Column < value (using column index) + LessThan { column_idx: usize, value: Value }, + /// Column <= value (using column index) + LessThanOrEqual { column_idx: usize, value: Value }, /// Logical AND of two predicates And(Box, Box), /// Logical OR of two predicates @@ -34,122 +33,17 @@ pub enum FilterPredicate { None, } -impl FilterPredicate { - /// Parse a SQL AST expression into a FilterPredicate - /// This centralizes all SQL-to-predicate parsing logic - pub fn from_sql_expr(expr: &turso_parser::ast::Expr) -> crate::Result { - let Expr::Binary(lhs, op, rhs) = expr else { - return Err(crate::LimboError::ParseError( - "Unsupported WHERE clause for incremental views: not a binary expression" - .to_string(), - )); - }; - - // Handle AND/OR logical operators - match op { - Operator::And => { - let left = Self::from_sql_expr(lhs)?; - let right = Self::from_sql_expr(rhs)?; - return Ok(FilterPredicate::And(Box::new(left), Box::new(right))); - } - Operator::Or => { - let left = Self::from_sql_expr(lhs)?; - let right = Self::from_sql_expr(rhs)?; - return Ok(FilterPredicate::Or(Box::new(left), Box::new(right))); - } - _ => {} - } - - // Handle comparison operators - 
let Expr::Id(column_name) = &**lhs else { - return Err(crate::LimboError::ParseError( - "Unsupported WHERE clause for incremental views: left-hand-side is not a column reference".to_string(), - )); - }; - - let column = column_name.as_str().to_string(); - - // Parse the right-hand side value - let value = match &**rhs { - Expr::Literal(Literal::String(s)) => { - // Strip quotes from string literals - let cleaned = s.trim_matches('\'').trim_matches('"'); - Value::Text(Text::new(cleaned)) - } - Expr::Literal(Literal::Numeric(n)) => { - // Try to parse as integer first, then float - if let Ok(i) = n.parse::() { - Value::Integer(i) - } else if let Ok(f) = n.parse::() { - Value::Float(f) - } else { - return Err(crate::LimboError::ParseError( - "Unsupported WHERE clause for incremental views: right-hand-side is not a numeric literal".to_string(), - )); - } - } - Expr::Literal(Literal::Null) => Value::Null, - Expr::Literal(Literal::Blob(_)) => { - // Blob comparison not yet supported - return Err(crate::LimboError::ParseError( - "Unsupported WHERE clause for incremental views: comparison with blob literals is not supported".to_string(), - )); - } - other => { - // Complex expressions not yet supported - return Err(crate::LimboError::ParseError( - format!("Unsupported WHERE clause for incremental views: comparison with {other:?} is not supported"), - )); - } - }; - - // Create the appropriate predicate based on operator - match op { - Operator::Equals => Ok(FilterPredicate::Equals { column, value }), - Operator::NotEquals => Ok(FilterPredicate::NotEquals { column, value }), - Operator::Greater => Ok(FilterPredicate::GreaterThan { column, value }), - Operator::GreaterEquals => Ok(FilterPredicate::GreaterThanOrEqual { column, value }), - Operator::Less => Ok(FilterPredicate::LessThan { column, value }), - Operator::LessEquals => Ok(FilterPredicate::LessThanOrEqual { column, value }), - other => Err(crate::LimboError::ParseError( - format!("Unsupported WHERE clause for incremental views: comparison operator {other:?} is not supported"), - )), - } - } - - /// Parse a WHERE clause from a SELECT statement - pub fn from_select(select: &turso_parser::ast::Select) -> crate::Result { - if let OneSelect::Select { - ref where_clause, .. 
- } = select.body.select - { - if let Some(where_clause) = where_clause { - Self::from_sql_expr(where_clause) - } else { - Ok(FilterPredicate::None) - } - } else { - Err(crate::LimboError::ParseError( - "Unsupported WHERE clause for incremental views: not a single SELECT statement" - .to_string(), - )) - } - } -} - /// Filter operator - filters rows based on predicate #[derive(Debug)] pub struct FilterOperator { predicate: FilterPredicate, - column_names: Vec, tracker: Option>>, } impl FilterOperator { - pub fn new(predicate: FilterPredicate, column_names: Vec) -> Self { + pub fn new(predicate: FilterPredicate) -> Self { Self { predicate, - column_names, tracker: None, } } @@ -162,86 +56,72 @@ impl FilterOperator { pub fn evaluate_predicate(&self, values: &[Value]) -> bool { match &self.predicate { FilterPredicate::None => true, - FilterPredicate::Equals { column, value } => { - if let Some(idx) = self.column_names.iter().position(|c| c == column) { - if let Some(v) = values.get(idx) { - return v == value; + FilterPredicate::Equals { column_idx, value } => { + if let Some(v) = values.get(*column_idx) { + return v == value; + } + false + } + FilterPredicate::NotEquals { column_idx, value } => { + if let Some(v) = values.get(*column_idx) { + return v != value; + } + false + } + FilterPredicate::GreaterThan { column_idx, value } => { + if let Some(v) = values.get(*column_idx) { + // Compare based on value types + match (v, value) { + (Value::Integer(a), Value::Integer(b)) => return a > b, + (Value::Float(a), Value::Float(b)) => return a > b, + (Value::Text(a), Value::Text(b)) => return a.as_str() > b.as_str(), + _ => {} } } false } - FilterPredicate::NotEquals { column, value } => { - if let Some(idx) = self.column_names.iter().position(|c| c == column) { - if let Some(v) = values.get(idx) { - return v != value; + FilterPredicate::GreaterThanOrEqual { column_idx, value } => { + if let Some(v) = values.get(*column_idx) { + match (v, value) { + (Value::Integer(a), Value::Integer(b)) => return a >= b, + (Value::Float(a), Value::Float(b)) => return a >= b, + (Value::Text(a), Value::Text(b)) => return a.as_str() >= b.as_str(), + _ => {} } } false } - FilterPredicate::GreaterThan { column, value } => { - if let Some(idx) = self.column_names.iter().position(|c| c == column) { - if let Some(v) = values.get(idx) { - // Compare based on value types - match (v, value) { - (Value::Integer(a), Value::Integer(b)) => return a > b, - (Value::Float(a), Value::Float(b)) => return a > b, - (Value::Text(a), Value::Text(b)) => return a.as_str() > b.as_str(), - _ => {} - } + FilterPredicate::LessThan { column_idx, value } => { + if let Some(v) = values.get(*column_idx) { + match (v, value) { + (Value::Integer(a), Value::Integer(b)) => return a < b, + (Value::Float(a), Value::Float(b)) => return a < b, + (Value::Text(a), Value::Text(b)) => return a.as_str() < b.as_str(), + _ => {} } } false } - FilterPredicate::GreaterThanOrEqual { column, value } => { - if let Some(idx) = self.column_names.iter().position(|c| c == column) { - if let Some(v) = values.get(idx) { - match (v, value) { - (Value::Integer(a), Value::Integer(b)) => return a >= b, - (Value::Float(a), Value::Float(b)) => return a >= b, - (Value::Text(a), Value::Text(b)) => return a.as_str() >= b.as_str(), - _ => {} - } - } - } - false - } - FilterPredicate::LessThan { column, value } => { - if let Some(idx) = self.column_names.iter().position(|c| c == column) { - if let Some(v) = values.get(idx) { - match (v, value) { - (Value::Integer(a), 
Value::Integer(b)) => return a < b, - (Value::Float(a), Value::Float(b)) => return a < b, - (Value::Text(a), Value::Text(b)) => return a.as_str() < b.as_str(), - _ => {} - } - } - } - false - } - FilterPredicate::LessThanOrEqual { column, value } => { - if let Some(idx) = self.column_names.iter().position(|c| c == column) { - if let Some(v) = values.get(idx) { - match (v, value) { - (Value::Integer(a), Value::Integer(b)) => return a <= b, - (Value::Float(a), Value::Float(b)) => return a <= b, - (Value::Text(a), Value::Text(b)) => return a.as_str() <= b.as_str(), - _ => {} - } + FilterPredicate::LessThanOrEqual { column_idx, value } => { + if let Some(v) = values.get(*column_idx) { + match (v, value) { + (Value::Integer(a), Value::Integer(b)) => return a <= b, + (Value::Float(a), Value::Float(b)) => return a <= b, + (Value::Text(a), Value::Text(b)) => return a.as_str() <= b.as_str(), + _ => {} } } false } FilterPredicate::And(left, right) => { // Temporarily create sub-filters to evaluate - let left_filter = FilterOperator::new((**left).clone(), self.column_names.clone()); - let right_filter = - FilterOperator::new((**right).clone(), self.column_names.clone()); + let left_filter = FilterOperator::new((**left).clone()); + let right_filter = FilterOperator::new((**right).clone()); left_filter.evaluate_predicate(values) && right_filter.evaluate_predicate(values) } FilterPredicate::Or(left, right) => { - let left_filter = FilterOperator::new((**left).clone(), self.column_names.clone()); - let right_filter = - FilterOperator::new((**right).clone(), self.column_names.clone()); + let left_filter = FilterOperator::new((**left).clone()); + let right_filter = FilterOperator::new((**right).clone()); left_filter.evaluate_predicate(values) || right_filter.evaluate_predicate(values) } } diff --git a/core/incremental/operator.rs b/core/incremental/operator.rs index 72ed7bc0c..2af512504 100644 --- a/core/incremental/operator.rs +++ b/core/incremental/operator.rs @@ -1450,13 +1450,10 @@ mod tests { BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); - let mut filter = FilterOperator::new( - FilterPredicate::GreaterThan { - column: "b".to_string(), - value: Value::Integer(2), - }, - vec!["a".to_string(), "b".to_string()], - ); + let mut filter = FilterOperator::new(FilterPredicate::GreaterThan { + column_idx: 1, // "b" is at index 1 + value: Value::Integer(2), + }); // Initialize with a row (rowid=3, values=[3, 3]) let mut init_data = Delta::new(); @@ -1512,13 +1509,10 @@ mod tests { BTreeCursor::new_index(None, pager.clone(), index_root_page_id, &index_def, 4); let mut cursors = DbspStateCursors::new(table_cursor, index_cursor); - let mut filter = FilterOperator::new( - FilterPredicate::GreaterThan { - column: "age".to_string(), - value: Value::Integer(25), - }, - vec!["id".to_string(), "name".to_string(), "age".to_string()], - ); + let mut filter = FilterOperator::new(FilterPredicate::GreaterThan { + column_idx: 2, // "age" is at index 2 + value: Value::Integer(25), + }); // Initialize with some data let mut init_data = Delta::new(); diff --git a/core/incremental/view.rs b/core/incremental/view.rs index 8b32c5dcc..77f1d0217 100644 --- a/core/incremental/view.rs +++ b/core/incremental/view.rs @@ -1,6 +1,6 @@ use super::compiler::{DbspCircuit, DbspCompiler, DeltaSet}; use super::dbsp::Delta; -use super::operator::{ComputationTracker, FilterPredicate}; +use super::operator::ComputationTracker; use 
crate::schema::{BTreeTable, Schema}; use crate::storage::btree::BTreeCursor; use crate::translate::logical::LogicalPlanBuilder; @@ -163,8 +163,6 @@ impl AllViewsTxState { #[derive(Debug)] pub struct IncrementalView { name: String, - // WHERE clause predicate for filtering (kept for compatibility) - pub where_predicate: FilterPredicate, // The SELECT statement that defines how to transform input data pub select_stmt: ast::Select, @@ -173,6 +171,11 @@ pub struct IncrementalView { // All tables referenced by this view (from FROM clause and JOINs) referenced_tables: Vec>, + // Mapping from table aliases to actual table names (e.g., "c" -> "customers") + table_aliases: HashMap, + // Mapping from table name to fully qualified name (e.g., "customers" -> "main.customers") + // This preserves database qualification from the original query + qualified_table_names: HashMap, // The view's column schema with table relationships pub column_schema: ViewColumnSchema, // State machine for population @@ -301,8 +304,6 @@ impl IncrementalView { ) -> Result { let name = view_name.name.as_str().to_string(); - let where_predicate = FilterPredicate::from_select(&select)?; - // Extract output columns using the shared function let column_schema = extract_view_columns(&select, schema)?; @@ -313,14 +314,16 @@ impl IncrementalView { )); } - // Get all tables from FROM clause and JOINs - let referenced_tables = Self::extract_all_tables(&select, schema)?; + // Get all tables from FROM clause and JOINs, along with their aliases + let (referenced_tables, table_aliases, qualified_table_names) = + Self::extract_all_tables(&select, schema)?; Self::new( name, - where_predicate, select.clone(), referenced_tables, + table_aliases, + qualified_table_names, column_schema, schema, main_data_root, @@ -332,9 +335,10 @@ impl IncrementalView { #[allow(clippy::too_many_arguments)] pub fn new( name: String, - where_predicate: FilterPredicate, select_stmt: ast::Select, referenced_tables: Vec>, + table_aliases: HashMap, + qualified_table_names: HashMap, column_schema: ViewColumnSchema, schema: &Schema, main_data_root: usize, @@ -355,10 +359,11 @@ impl IncrementalView { Ok(Self { name, - where_predicate, select_stmt, circuit, referenced_tables, + table_aliases, + qualified_table_names, column_schema, populate_state: PopulateState::Start, tracker, @@ -402,9 +407,22 @@ impl IncrementalView { self.referenced_tables.clone() } - /// Extract all table names from a SELECT statement (including JOINs) - fn extract_all_tables(select: &ast::Select, schema: &Schema) -> Result>> { + /// Extract all tables and their aliases from the SELECT statement + /// Returns a tuple of (tables, alias_map, qualified_names) + /// where alias_map is alias -> table_name + /// and qualified_names is table_name -> fully_qualified_name + #[allow(clippy::type_complexity)] + fn extract_all_tables( + select: &ast::Select, + schema: &Schema, + ) -> Result<( + Vec>, + HashMap, + HashMap, + )> { let mut tables = Vec::new(); + let mut aliases = HashMap::new(); + let mut qualified_names = HashMap::new(); if let ast::OneSelect::Select { from: Some(ref from), @@ -412,10 +430,24 @@ impl IncrementalView { } = select.body.select { // Get the main table from FROM clause - if let ast::SelectTable::Table(name, _, _) = from.select.as_ref() { + if let ast::SelectTable::Table(name, alias, _) = from.select.as_ref() { let table_name = name.name.as_str(); + + // Build the fully qualified name + let qualified_name = if let Some(ref db) = name.db_name { + format!("{db}.{table_name}") + } else { 
+ table_name.to_string() + }; + if let Some(table) = schema.get_btree_table(table_name) { tables.push(table.clone()); + qualified_names.insert(table_name.to_string(), qualified_name); + + // Store the alias mapping if there is an alias + if let Some(alias_name) = alias { + aliases.insert(alias_name.to_string(), table_name.to_string()); + } } else { return Err(LimboError::ParseError(format!( "Table '{table_name}' not found in schema" @@ -425,10 +457,24 @@ impl IncrementalView { // Get all tables from JOIN clauses for join in &from.joins { - if let ast::SelectTable::Table(name, _, _) = join.table.as_ref() { + if let ast::SelectTable::Table(name, alias, _) = join.table.as_ref() { let table_name = name.name.as_str(); + + // Build the fully qualified name + let qualified_name = if let Some(ref db) = name.db_name { + format!("{db}.{table_name}") + } else { + table_name.to_string() + }; + if let Some(table) = schema.get_btree_table(table_name) { tables.push(table.clone()); + qualified_names.insert(table_name.to_string(), qualified_name); + + // Store the alias mapping if there is an alias + if let Some(alias_name) = alias { + aliases.insert(alias_name.to_string(), table_name.to_string()); + } } else { return Err(LimboError::ParseError(format!( "Table '{table_name}' not found in schema" @@ -444,90 +490,346 @@ impl IncrementalView { )); } - Ok(tables) + Ok((tables, aliases, qualified_names)) } - /// Generate the SQL query for populating the view from its source table - fn sql_for_populate(&self) -> crate::Result { - // Get the first table from referenced tables + /// Generate SQL queries for populating the view from each source table + /// Returns a vector of SQL statements, one for each referenced table + /// Each query includes only the WHERE conditions relevant to that specific table + fn sql_for_populate(&self) -> crate::Result> { if self.referenced_tables.is_empty() { return Err(LimboError::ParseError( "No tables to populate from".to_string(), )); } - let table = &self.referenced_tables[0]; - // Check if the table has a rowid alias (INTEGER PRIMARY KEY column) - let has_rowid_alias = table.columns.iter().any(|col| col.is_rowid_alias); + let mut queries = Vec::new(); - // For now, select all columns since we don't have the static operators - // The circuit will handle filtering and projection - // If there's a rowid alias, we don't need to select rowid separately - let select_clause = if has_rowid_alias { - "*".to_string() - } else { - "*, rowid".to_string() - }; + for table in &self.referenced_tables { + // Check if the table has a rowid alias (INTEGER PRIMARY KEY column) + let has_rowid_alias = table.columns.iter().any(|col| col.is_rowid_alias); - // Build WHERE clause from the where_predicate - let where_clause = self.build_where_clause(&self.where_predicate)?; + // For now, select all columns since we don't have the static operators + // The circuit will handle filtering and projection + // If there's a rowid alias, we don't need to select rowid separately + let select_clause = if has_rowid_alias { + "*".to_string() + } else { + "*, rowid".to_string() + }; - // Construct the final query - let query = if where_clause.is_empty() { - format!("SELECT {} FROM {}", select_clause, table.name) - } else { - format!( - "SELECT {} FROM {} WHERE {}", - select_clause, table.name, where_clause - ) - }; - Ok(query) + // Extract WHERE conditions for this specific table + let where_clause = self.extract_where_clause_for_table(&table.name)?; + + // Use the qualified table name if available, otherwise just 
the table name + let table_name = self + .qualified_table_names + .get(&table.name) + .cloned() + .unwrap_or_else(|| table.name.clone()); + + // Construct the query for this table + let query = if where_clause.is_empty() { + format!("SELECT {select_clause} FROM {table_name}") + } else { + format!("SELECT {select_clause} FROM {table_name} WHERE {where_clause}") + }; + queries.push(query); + } + + Ok(queries) } - /// Build a WHERE clause from a FilterPredicate - fn build_where_clause(&self, predicate: &FilterPredicate) -> crate::Result { - match predicate { - FilterPredicate::None => Ok(String::new()), - FilterPredicate::Equals { column, value } => { - Ok(format!("{} = {}", column, self.value_to_sql(value))) + /// Extract WHERE conditions that apply to a specific table + /// This analyzes the WHERE clause in the SELECT statement and returns + /// only the conditions that reference the given table + fn extract_where_clause_for_table(&self, table_name: &str) -> crate::Result { + // For single table queries, return the entire WHERE clause (already unqualified) + if self.referenced_tables.len() == 1 { + if let ast::OneSelect::Select { + where_clause: Some(ref where_expr), + .. + } = self.select_stmt.body.select + { + // For single table, the expression should already be unqualified or qualified with the single table + // We need to unqualify it for the single-table query + let unqualified = self.unqualify_expression(where_expr, table_name); + return Ok(unqualified.to_string()); } - FilterPredicate::NotEquals { column, value } => { - Ok(format!("{} != {}", column, self.value_to_sql(value))) + return Ok(String::new()); + } + + // For multi-table queries (JOINs), extract conditions for the specific table + if let ast::OneSelect::Select { + where_clause: Some(ref where_expr), + .. 
+ } = self.select_stmt.body.select + { + // Extract conditions that reference only the specified table + let table_conditions = self.extract_table_conditions(where_expr, table_name)?; + if let Some(conditions) = table_conditions { + // Unqualify the expression for single-table query + let unqualified = self.unqualify_expression(&conditions, table_name); + return Ok(unqualified.to_string()); } - FilterPredicate::GreaterThan { column, value } => { - Ok(format!("{} > {}", column, self.value_to_sql(value))) + } + + Ok(String::new()) + } + + /// Extract conditions from an expression that reference only the specified table + fn extract_table_conditions( + &self, + expr: &ast::Expr, + table_name: &str, + ) -> crate::Result> { + match expr { + ast::Expr::Binary(left, op, right) => { + match op { + ast::Operator::And => { + // For AND, we can extract conditions independently + let left_cond = self.extract_table_conditions(left, table_name)?; + let right_cond = self.extract_table_conditions(right, table_name)?; + + match (left_cond, right_cond) { + (Some(l), Some(r)) => { + // Both conditions apply to this table + Ok(Some(ast::Expr::Binary( + Box::new(l), + ast::Operator::And, + Box::new(r), + ))) + } + (Some(l), None) => Ok(Some(l)), + (None, Some(r)) => Ok(Some(r)), + (None, None) => Ok(None), + } + } + ast::Operator::Or => { + // For OR, both sides must reference the same table(s) + // If either side references multiple tables, we can't extract it + let left_tables = self.get_referenced_tables_in_expr(left)?; + let right_tables = self.get_referenced_tables_in_expr(right)?; + + // If both sides only reference our table, include the whole OR + if left_tables.len() == 1 + && left_tables.contains(&table_name.to_string()) + && right_tables.len() == 1 + && right_tables.contains(&table_name.to_string()) + { + Ok(Some(expr.clone())) + } else { + // OR condition involves multiple tables, can't extract + Ok(None) + } + } + _ => { + // For comparison operators, check if this condition references only our table + let referenced_tables = self.get_referenced_tables_in_expr(expr)?; + if referenced_tables.len() == 1 + && referenced_tables.contains(&table_name.to_string()) + { + Ok(Some(expr.clone())) + } else { + Ok(None) + } + } + } } - FilterPredicate::GreaterThanOrEqual { column, value } => { - Ok(format!("{} >= {}", column, self.value_to_sql(value))) + ast::Expr::Parenthesized(exprs) => { + if exprs.len() == 1 { + self.extract_table_conditions(&exprs[0], table_name) + } else { + Ok(None) + } } - FilterPredicate::LessThan { column, value } => { - Ok(format!("{} < {}", column, self.value_to_sql(value))) - } - FilterPredicate::LessThanOrEqual { column, value } => { - Ok(format!("{} <= {}", column, self.value_to_sql(value))) - } - FilterPredicate::And(left, right) => { - let left_clause = self.build_where_clause(left)?; - let right_clause = self.build_where_clause(right)?; - Ok(format!("({left_clause} AND {right_clause})")) - } - FilterPredicate::Or(left, right) => { - let left_clause = self.build_where_clause(left)?; - let right_clause = self.build_where_clause(right)?; - Ok(format!("({left_clause} OR {right_clause})")) + _ => { + // For other expressions, check if they reference only our table + let referenced_tables = self.get_referenced_tables_in_expr(expr)?; + if referenced_tables.len() == 1 + && referenced_tables.contains(&table_name.to_string()) + { + Ok(Some(expr.clone())) + } else { + Ok(None) + } } } } - /// Convert a Value to SQL literal representation - fn value_to_sql(&self, value: &Value) -> String 
{ - match value { - Value::Null => "NULL".to_string(), - Value::Integer(i) => i.to_string(), - Value::Float(f) => f.to_string(), - Value::Text(t) => format!("'{}'", t.as_str().replace('\'', "''")), - Value::Blob(_) => "NULL".to_string(), // Blob literals not supported in WHERE clause yet + /// Get the set of table names referenced in an expression + fn get_referenced_tables_in_expr(&self, expr: &ast::Expr) -> crate::Result> { + let mut tables = Vec::new(); + self.collect_referenced_tables(expr, &mut tables)?; + // Deduplicate + tables.sort(); + tables.dedup(); + Ok(tables) + } + + /// Recursively collect table references from an expression + fn collect_referenced_tables( + &self, + expr: &ast::Expr, + tables: &mut Vec, + ) -> crate::Result<()> { + match expr { + ast::Expr::Binary(left, _, right) => { + self.collect_referenced_tables(left, tables)?; + self.collect_referenced_tables(right, tables)?; + } + ast::Expr::Qualified(table, _) => { + // This is a qualified column reference (table.column or alias.column) + // We need to resolve aliases to actual table names + let actual_table = self.resolve_table_alias(table.as_str()); + tables.push(actual_table); + } + ast::Expr::Id(column) => { + // Unqualified column reference + if self.referenced_tables.len() > 1 { + // In a JOIN context, check which tables have this column + let mut tables_with_column = Vec::new(); + for table in &self.referenced_tables { + if table + .columns + .iter() + .any(|c| c.name.as_ref() == Some(&column.to_string())) + { + tables_with_column.push(table.name.clone()); + } + } + + if tables_with_column.len() > 1 { + // Ambiguous column - this should have been caught earlier + // Return error to be safe + return Err(crate::LimboError::ParseError(format!( + "Ambiguous column name '{}' in WHERE clause - exists in tables: {}", + column, + tables_with_column.join(", ") + ))); + } else if tables_with_column.len() == 1 { + // Unambiguous - only one table has this column + // This is allowed by SQLite + tables.push(tables_with_column[0].clone()); + } else { + // Column doesn't exist in any table - this is an error + // but should be caught during compilation + return Err(crate::LimboError::ParseError(format!( + "Column '{column}' not found in any table" + ))); + } + } else { + // Single table context - unqualified columns belong to that table + if let Some(table) = self.referenced_tables.first() { + tables.push(table.name.clone()); + } + } + } + ast::Expr::DoublyQualified(_database, table, _column) => { + // For database.table.column, resolve the table name + let table_str = table.as_str(); + let actual_table = self.resolve_table_alias(table_str); + tables.push(actual_table); + } + ast::Expr::Parenthesized(exprs) => { + for e in exprs { + self.collect_referenced_tables(e, tables)?; + } + } + _ => { + // Literals and other expressions don't reference tables + } } + Ok(()) + } + + /// Convert a qualified expression to unqualified for single-table queries + /// This removes table prefixes from column references since they're not needed + /// when querying a single table + fn unqualify_expression(&self, expr: &ast::Expr, table_name: &str) -> ast::Expr { + match expr { + ast::Expr::Binary(left, op, right) => { + // Recursively unqualify both sides + ast::Expr::Binary( + Box::new(self.unqualify_expression(left, table_name)), + *op, + Box::new(self.unqualify_expression(right, table_name)), + ) + } + ast::Expr::Qualified(table, column) => { + // Convert qualified column to unqualified if it's for our table + // Handle both "table.column" 
and "database.table.column" cases + let table_str = table.as_str(); + + // Check if this is a database.table reference + let actual_table = if table_str.contains('.') { + // Split on '.' and take the last part as the table name + table_str + .split('.') + .next_back() + .unwrap_or(table_str) + .to_string() + } else { + // Could be an alias or direct table name + self.resolve_table_alias(table_str) + }; + + if actual_table == table_name { + // Just return the column name without qualification + ast::Expr::Id(column.clone()) + } else { + // This shouldn't happen if extract_table_conditions worked correctly + // but keep it qualified just in case + expr.clone() + } + } + ast::Expr::DoublyQualified(_database, table, column) => { + // This is database.table.column format + // Check if the table matches our target table + let table_str = table.as_str(); + let actual_table = self.resolve_table_alias(table_str); + + if actual_table == table_name { + // Just return the column name without qualification + ast::Expr::Id(column.clone()) + } else { + // Keep it qualified if it's for a different table + expr.clone() + } + } + ast::Expr::Parenthesized(exprs) => { + // Recursively unqualify expressions in parentheses + let unqualified_exprs: Vec> = exprs + .iter() + .map(|e| Box::new(self.unqualify_expression(e, table_name))) + .collect(); + ast::Expr::Parenthesized(unqualified_exprs) + } + _ => { + // Other expression types (literals, unqualified columns, etc.) stay as-is + expr.clone() + } + } + } + + /// Resolve a table alias to the actual table name + fn resolve_table_alias(&self, alias: &str) -> String { + // Check if there's an alias mapping in the FROM/JOIN clauses + // For now, we'll do a simple check - if the alias matches a table name, use it + // Otherwise, try to find it in the FROM clause + + // First check if it's an actual table name + if self.referenced_tables.iter().any(|t| t.name == alias) { + return alias.to_string(); + } + + // Check if it's an alias that maps to a table + if let Some(table_name) = self.table_aliases.get(alias) { + return table_name.clone(); + } + + // If we can't resolve it, return as-is (it might be a table name we don't know about) + alias.to_string() } /// Populate the view by scanning the source table using a state machine @@ -564,7 +866,15 @@ impl IncrementalView { // btree and not in the table btree. 
Using cursors would force us to be aware of this // distinction (and others), and ultimately lead to reimplementing the whole query // machinery (next step is which index is best to use, etc) - let query = self.sql_for_populate()?; + let queries = self.sql_for_populate()?; + + // For now, only use the first query (single table population) + if queries.is_empty() { + return Err(LimboError::ParseError( + "No populate queries generated".to_string(), + )); + } + let query = queries[0].clone(); // Create a new connection for reading to avoid transaction conflicts // This allows us to read from tables while the parent transaction is writing the view @@ -958,15 +1268,76 @@ mod tests { collation: None, hidden: false, }, + SchemaColumn { + name: Some("price".to_string()), + ty: Type::Real, + ty_str: "REAL".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, ], has_rowid: true, is_strict: false, unique_sets: vec![], }; + // Create logs table - without a rowid alias (no INTEGER PRIMARY KEY) + let logs_table = BTreeTable { + name: "logs".to_string(), + root_page: 5, + primary_key_columns: vec![], // No primary key, so no rowid alias + columns: vec![ + SchemaColumn { + name: Some("message".to_string()), + ty: Type::Text, + ty_str: "TEXT".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("level".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("timestamp".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + ], + has_rowid: true, // Has implicit rowid but no alias + is_strict: false, + unique_sets: vec![], + }; + schema.add_btree_table(Arc::new(customers_table)); schema.add_btree_table(Arc::new(orders_table)); schema.add_btree_table(Arc::new(products_table)); + schema.add_btree_table(Arc::new(logs_table)); schema } @@ -985,7 +1356,7 @@ mod tests { let schema = create_test_schema(); let select = parse_select("SELECT * FROM customers"); - let tables = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, _, _) = IncrementalView::extract_all_tables(&select, &schema).unwrap(); assert_eq!(tables.len(), 1); assert_eq!(tables[0].name, "customers"); @@ -998,7 +1369,7 @@ mod tests { "SELECT * FROM customers INNER JOIN orders ON customers.id = orders.customer_id", ); - let tables = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, _, _) = IncrementalView::extract_all_tables(&select, &schema).unwrap(); assert_eq!(tables.len(), 2); assert_eq!(tables[0].name, "customers"); @@ -1014,7 +1385,7 @@ mod tests { INNER JOIN products ON orders.id = products.id", ); - let tables = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, _, _) = IncrementalView::extract_all_tables(&select, &schema).unwrap(); assert_eq!(tables.len(), 3); assert_eq!(tables[0].name, "customers"); @@ -1029,7 +1400,7 @@ mod tests { "SELECT * FROM customers LEFT JOIN orders ON customers.id = orders.customer_id", ); - let tables = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let 
(tables, _, _) = IncrementalView::extract_all_tables(&select, &schema).unwrap(); assert_eq!(tables.len(), 2); assert_eq!(tables[0].name, "customers"); @@ -1041,7 +1412,7 @@ mod tests { let schema = create_test_schema(); let select = parse_select("SELECT * FROM customers CROSS JOIN orders"); - let tables = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, _, _) = IncrementalView::extract_all_tables(&select, &schema).unwrap(); assert_eq!(tables.len(), 2); assert_eq!(tables[0].name, "customers"); @@ -1054,7 +1425,7 @@ mod tests { let select = parse_select("SELECT * FROM customers c INNER JOIN orders o ON c.id = o.customer_id"); - let tables = IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let (tables, _, _) = IncrementalView::extract_all_tables(&select, &schema).unwrap(); // Should still extract the actual table names, not aliases assert_eq!(tables.len(), 2); @@ -1067,7 +1438,8 @@ mod tests { let schema = create_test_schema(); let select = parse_select("SELECT * FROM nonexistent"); - let result = IncrementalView::extract_all_tables(&select, &schema); + let result = + IncrementalView::extract_all_tables(&select, &schema).map(|(tables, _, _)| tables); assert!(result.is_err()); assert!(result @@ -1083,7 +1455,8 @@ mod tests { "SELECT * FROM customers INNER JOIN nonexistent ON customers.id = nonexistent.id", ); - let result = IncrementalView::extract_all_tables(&select, &schema); + let result = + IncrementalView::extract_all_tables(&select, &schema).map(|(tables, _, _)| tables); assert!(result.is_err()); assert!(result @@ -1091,4 +1464,526 @@ mod tests { .to_string() .contains("Table 'nonexistent' not found")); } + + #[test] + fn test_sql_for_populate_simple_query_no_where() { + // Test simple query with no WHERE clause + let schema = create_test_schema(); + let select = parse_select("SELECT * FROM customers"); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names, + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 1); + // customers has id as rowid alias, so no need for explicit rowid + assert_eq!(queries[0], "SELECT * FROM customers"); + } + + #[test] + fn test_sql_for_populate_simple_query_with_where() { + // Test simple query with WHERE clause + let schema = create_test_schema(); + let select = parse_select("SELECT * FROM customers WHERE id > 10"); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names, + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 1); + // For single-table queries, we should get the full WHERE clause + assert_eq!(queries[0], "SELECT * FROM customers WHERE id > 10"); + } + + #[test] + fn test_sql_for_populate_join_with_where_on_both_tables() { + // Test JOIN query with WHERE conditions on both tables + let schema = create_test_schema(); + let select = parse_select( + "SELECT * FROM customers c \ + 
JOIN orders o ON c.id = o.customer_id \ + WHERE c.id > 10 AND o.total > 100", + ); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names, + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 2); + + // With per-table WHERE extraction: + // - customers table gets: c.id > 10 + // - orders table gets: o.total > 100 + assert_eq!(queries[0], "SELECT * FROM customers WHERE id > 10"); + assert_eq!(queries[1], "SELECT * FROM orders WHERE total > 100"); + } + + #[test] + fn test_sql_for_populate_complex_join_with_mixed_conditions() { + // Test complex JOIN with WHERE conditions mixing both tables + let schema = create_test_schema(); + let select = parse_select( + "SELECT * FROM customers c \ + JOIN orders o ON c.id = o.customer_id \ + WHERE c.id > 10 AND o.total > 100 AND c.name = 'John' \ + AND o.customer_id = 5 AND (c.id = 15 OR o.total = 200)", + ); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names, + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 2); + + // With per-table WHERE extraction: + // - customers gets: c.id > 10 AND c.name = 'John' + // - orders gets: o.total > 100 AND o.customer_id = 5 + // Note: The OR condition (c.id = 15 OR o.total = 200) involves both tables, + // so it cannot be extracted to either table individually + assert_eq!( + queries[0], + "SELECT * FROM customers WHERE id > 10 AND name = 'John'" + ); + assert_eq!( + queries[1], + "SELECT * FROM orders WHERE total > 100 AND customer_id = 5" + ); + } + + #[test] + fn test_where_extraction_for_three_tables() { + // Test that WHERE clause extraction correctly separates conditions for 3+ tables + // This addresses the concern about conditions "piling up" as joins increase + + // Simulate a three-table scenario + let schema = create_test_schema(); + + // Parse a WHERE clause with conditions for three different tables + let select = parse_select( + "SELECT * FROM customers WHERE c.id > 10 AND o.total > 100 AND p.price > 50", + ); + + // Get the WHERE expression + if let ast::OneSelect::Select { + where_clause: Some(ref where_expr), + .. 
+ } = select.body.select + { + // Create a view with three tables to test extraction + let tables = vec![ + schema.get_btree_table("customers").unwrap(), + schema.get_btree_table("orders").unwrap(), + schema.get_btree_table("products").unwrap(), + ]; + + let mut aliases = HashMap::new(); + aliases.insert("c".to_string(), "customers".to_string()); + aliases.insert("o".to_string(), "orders".to_string()); + aliases.insert("p".to_string(), "products".to_string()); + + // Create a minimal view just to test extraction logic + let view = IncrementalView { + name: "test".to_string(), + select_stmt: select.clone(), + circuit: DbspCircuit::new(1, 2, 3), + referenced_tables: tables, + table_aliases: aliases, + qualified_table_names: HashMap::new(), + column_schema: ViewColumnSchema { + columns: vec![], + tables: vec![], + }, + populate_state: PopulateState::Start, + tracker: Arc::new(Mutex::new(ComputationTracker::new())), + root_page: 0, + }; + + // Test extraction for each table + let customers_conds = view + .extract_table_conditions(where_expr, "customers") + .unwrap(); + let orders_conds = view.extract_table_conditions(where_expr, "orders").unwrap(); + let products_conds = view + .extract_table_conditions(where_expr, "products") + .unwrap(); + + // Verify each table only gets its conditions + if let Some(cond) = customers_conds { + let sql = cond.to_string(); + assert!(sql.contains("id > 10")); + assert!(!sql.contains("total")); + assert!(!sql.contains("price")); + } + + if let Some(cond) = orders_conds { + let sql = cond.to_string(); + assert!(sql.contains("total > 100")); + assert!(!sql.contains("id > 10")); // From customers + assert!(!sql.contains("price")); + } + + if let Some(cond) = products_conds { + let sql = cond.to_string(); + assert!(sql.contains("price > 50")); + assert!(!sql.contains("id > 10")); // From customers + assert!(!sql.contains("total")); + } + } else { + panic!("Failed to parse WHERE clause"); + } + } + + #[test] + fn test_alias_resolution_works_correctly() { + // Test that alias resolution properly maps aliases to table names + let schema = create_test_schema(); + let select = parse_select( + "SELECT * FROM customers c \ + JOIN orders o ON c.id = o.customer_id \ + WHERE c.id > 10 AND o.total > 100", + ); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names, + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + // Verify that alias mappings were extracted correctly + assert_eq!(view.table_aliases.get("c"), Some(&"customers".to_string())); + assert_eq!(view.table_aliases.get("o"), Some(&"orders".to_string())); + + // Verify that SQL generation uses the aliases correctly + let queries = view.sql_for_populate().unwrap(); + assert_eq!(queries.len(), 2); + + // Each query should use the actual table name, not the alias + assert!(queries[0].contains("FROM customers") || queries[1].contains("FROM customers")); + assert!(queries[0].contains("FROM orders") || queries[1].contains("FROM orders")); + } + + #[test] + fn test_sql_for_populate_table_without_rowid_alias() { + // Test that tables without a rowid alias properly include rowid in SELECT + let schema = create_test_schema(); + let select = parse_select("SELECT * FROM logs WHERE level > 2"); + + let (tables, aliases, qualified_names) = + 
IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names, + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 1); + // logs table has no rowid alias, so we need to explicitly select rowid + assert_eq!(queries[0], "SELECT *, rowid FROM logs WHERE level > 2"); + } + + #[test] + fn test_sql_for_populate_join_with_and_without_rowid_alias() { + // Test JOIN between a table with rowid alias and one without + let schema = create_test_schema(); + let select = parse_select( + "SELECT * FROM customers c \ + JOIN logs l ON c.id = l.level \ + WHERE c.id > 10 AND l.level > 2", + ); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names, + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 2); + // customers has rowid alias (id), logs doesn't + assert_eq!(queries[0], "SELECT * FROM customers WHERE id > 10"); + assert_eq!(queries[1], "SELECT *, rowid FROM logs WHERE level > 2"); + } + + #[test] + fn test_sql_for_populate_with_database_qualified_names() { + // Test that database.table.column references are handled correctly + // The table name in FROM should keep the database prefix, + // but column names in WHERE should be unqualified + let schema = create_test_schema(); + + // Test with single table using database qualification + let select = parse_select("SELECT * FROM main.customers WHERE main.customers.id > 10"); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names, + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 1); + // The FROM clause should preserve the database qualification, + // but the WHERE clause should have unqualified column names + assert_eq!(queries[0], "SELECT * FROM main.customers WHERE id > 10"); + } + + #[test] + fn test_sql_for_populate_join_with_database_qualified_names() { + // Test JOIN with database-qualified table and column references + let schema = create_test_schema(); + + let select = parse_select( + "SELECT * FROM main.customers c \ + JOIN main.orders o ON c.id = o.customer_id \ + WHERE main.customers.id > 10 AND main.orders.total > 100", + ); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names, + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 2); + // The FROM 
clauses should preserve database qualification, + // but WHERE clauses should have unqualified column names + assert_eq!(queries[0], "SELECT * FROM main.customers WHERE id > 10"); + assert_eq!(queries[1], "SELECT * FROM main.orders WHERE total > 100"); + } + + #[test] + fn test_sql_for_populate_unambiguous_unqualified_column() { + // Test that unambiguous unqualified columns ARE extracted + let schema = create_test_schema(); + let select = parse_select( + "SELECT * FROM customers c \ + JOIN orders o ON c.id = o.customer_id \ + WHERE total > 100", // 'total' only exists in orders table + ); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names, + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 2); + + // 'total' is unambiguous (only in orders), so it should be extracted + assert_eq!(queries[0], "SELECT * FROM customers"); + assert_eq!(queries[1], "SELECT * FROM orders WHERE total > 100"); + } + + #[test] + fn test_database_qualified_table_names() { + let schema = create_test_schema(); + + // Test with database-qualified table names + let select = parse_select( + "SELECT c.id, c.name, o.id, o.total + FROM main.customers c + JOIN main.orders o ON c.id = o.customer_id + WHERE c.id > 10", + ); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + + // Check that qualified names are preserved + assert!(qualified_names.contains_key("customers")); + assert_eq!(qualified_names.get("customers").unwrap(), "main.customers"); + assert!(qualified_names.contains_key("orders")); + assert_eq!(qualified_names.get("orders").unwrap(), "main.orders"); + + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names.clone(), + extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 2); + + // The FROM clause should contain the database-qualified name + // But the WHERE clause should use unqualified column names + assert_eq!(queries[0], "SELECT * FROM main.customers WHERE id > 10"); + assert_eq!(queries[1], "SELECT * FROM main.orders"); + } + + #[test] + fn test_mixed_qualified_unqualified_tables() { + let schema = create_test_schema(); + + // Test with a mix of qualified and unqualified table names + let select = parse_select( + "SELECT c.id, c.name, o.id, o.total + FROM main.customers c + JOIN orders o ON c.id = o.customer_id + WHERE c.id > 10 AND o.total < 1000", + ); + + let (tables, aliases, qualified_names) = + IncrementalView::extract_all_tables(&select, &schema).unwrap(); + + // Check that qualified names are preserved where specified + assert_eq!(qualified_names.get("customers").unwrap(), "main.customers"); + // Unqualified tables should not have an entry (or have the bare name) + assert!( + !qualified_names.contains_key("orders") + || qualified_names.get("orders").unwrap() == "orders" + ); + + let view = IncrementalView::new( + "test_view".to_string(), + select.clone(), + tables, + aliases, + qualified_names.clone(), + 
extract_view_columns(&select, &schema).unwrap(), + &schema, + 1, // main_data_root + 2, // internal_state_root + 3, // internal_state_index_root + ) + .unwrap(); + + let queries = view.sql_for_populate().unwrap(); + + assert_eq!(queries.len(), 2); + + // The FROM clause should preserve qualification where specified + assert_eq!(queries[0], "SELECT * FROM main.customers WHERE id > 10"); + assert_eq!(queries[1], "SELECT * FROM orders WHERE total < 1000"); + } } From 47097fbec69c8167eda88557020928da6a849e11 Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Thu, 18 Sep 2025 10:35:21 -0500 Subject: [PATCH 22/34] Add tests for project operator working with ambiguous columns Unlike the other operators, project works just fine with ambiguous columsn, because it works with compiled expressions. We don't need to patch it, but let's make sure it keeps working by writing a test. --- core/incremental/compiler.rs | 290 +++++++++++++++++++++++++++++++++++ 1 file changed, 290 insertions(+) diff --git a/core/incremental/compiler.rs b/core/incremental/compiler.rs index 07fd8f83c..52f383617 100644 --- a/core/incremental/compiler.rs +++ b/core/incremental/compiler.rs @@ -4505,6 +4505,296 @@ mod tests { assert_eq!(results[2].0.values[3], Value::Integer(20)); // total_value } + #[test] + fn test_projection_with_function_and_ambiguous_columns() { + // Test projection with functions operating on potentially ambiguous columns + // Uses HEX() function on sum of columns from different tables with same names + // Tables: customers(id, name), vendors(id, name, price), purchases(id, customer_id, vendor_id, quantity) + // This test ensures column references are correctly resolved to their respective tables + + let sql = "SELECT HEX(c.id + v.id) as hex_sum, + UPPER(c.name) as customer_upper, + LOWER(v.name) as vendor_lower, + c.id * v.price as product_value + FROM customers c + JOIN vendors v ON c.id = v.id"; + + let (mut circuit, pager) = compile_sql!(sql); + + // Create test data for customers (id, name) + let mut customers_delta = Delta::new(); + customers_delta.insert(1, vec![Value::Integer(1), Value::Text("Alice".into())]); + customers_delta.insert(2, vec![Value::Integer(2), Value::Text("Bob".into())]); + customers_delta.insert(3, vec![Value::Integer(3), Value::Text("Charlie".into())]); + + // Create test data for vendors (id, name, price) + let mut vendors_delta = Delta::new(); + vendors_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Widget Co".into()), + Value::Integer(10), + ], + ); + vendors_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Gadget Inc".into()), + Value::Integer(20), + ], + ); + vendors_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Text("Tool Corp".into()), + Value::Integer(30), + ], + ); + + let inputs = HashMap::from([ + ("customers".to_string(), customers_delta), + ("vendors".to_string(), vendors_delta), + ]); + + let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap(); + + // Expected results: + // For customer 1 (Alice) + vendor 1: + // - HEX(1 + 1) = HEX(2) = "32" + // - UPPER("Alice") = "ALICE" + // - LOWER("Widget Co") = "widget co" + // - 1 * 10 = 10 + assert_eq!(result.len(), 3, "Should have 3 join results"); + + let mut results = result.changes.clone(); + results.sort_by_key(|(row, _)| { + // Sort by the product_value column for predictable ordering + match &row.values[3] { + Value::Integer(n) => *n, + _ => 0, + } + }); + + // First result: Alice + Widget Co + assert_eq!(results[0].0.values[0], Value::Text("32".into())); 
// HEX(2) + assert_eq!(results[0].0.values[1], Value::Text("ALICE".into())); + assert_eq!(results[0].0.values[2], Value::Text("widget co".into())); + assert_eq!(results[0].0.values[3], Value::Integer(10)); // 1 * 10 + + // Second result: Bob + Gadget Inc + assert_eq!(results[1].0.values[0], Value::Text("34".into())); // HEX(4) + assert_eq!(results[1].0.values[1], Value::Text("BOB".into())); + assert_eq!(results[1].0.values[2], Value::Text("gadget inc".into())); + assert_eq!(results[1].0.values[3], Value::Integer(40)); // 2 * 20 + + // Third result: Charlie + Tool Corp + assert_eq!(results[2].0.values[0], Value::Text("36".into())); // HEX(6) + assert_eq!(results[2].0.values[1], Value::Text("CHARLIE".into())); + assert_eq!(results[2].0.values[2], Value::Text("tool corp".into())); + assert_eq!(results[2].0.values[3], Value::Integer(90)); // 3 * 30 + } + + #[test] + fn test_projection_column_selection_after_join() { + // Test selecting specific columns after a join, especially with overlapping column names + // This ensures the projection correctly picks columns by their qualified references + + let sql = "SELECT c.id as customer_id, + c.name as customer_name, + o.order_id, + o.quantity, + p.product_name + FROM users c + JOIN orders o ON c.id = o.user_id + JOIN products p ON o.product_id = p.product_id + WHERE o.quantity > 2"; + + let (mut circuit, pager) = compile_sql!(sql); + + // Create test data for users (id, name, age) + let mut users_delta = Delta::new(); + users_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".into()), + Value::Integer(25), + ], + ); + users_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Bob".into()), + Value::Integer(30), + ], + ); + + // Create test data for orders (order_id, user_id, product_id, quantity) + let mut orders_delta = Delta::new(); + orders_delta.insert( + 1, + vec![ + Value::Integer(101), + Value::Integer(1), // Alice + Value::Integer(201), // Widget + Value::Integer(5), // quantity > 2 + ], + ); + orders_delta.insert( + 2, + vec![ + Value::Integer(102), + Value::Integer(2), // Bob + Value::Integer(202), // Gadget + Value::Integer(1), // quantity <= 2, filtered out + ], + ); + orders_delta.insert( + 3, + vec![ + Value::Integer(103), + Value::Integer(1), // Alice + Value::Integer(202), // Gadget + Value::Integer(3), // quantity > 2 + ], + ); + + // Create test data for products (product_id, product_name, price) + let mut products_delta = Delta::new(); + products_delta.insert( + 201, + vec![ + Value::Integer(201), + Value::Text("Widget".into()), + Value::Integer(10), + ], + ); + products_delta.insert( + 202, + vec![ + Value::Integer(202), + Value::Text("Gadget".into()), + Value::Integer(20), + ], + ); + + let inputs = HashMap::from([ + ("users".to_string(), users_delta), + ("orders".to_string(), orders_delta), + ("products".to_string(), products_delta), + ]); + + let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap(); + + // Should have 2 results (orders with quantity > 2) + assert_eq!(result.len(), 2, "Should have 2 results after filtering"); + + let mut results = result.changes.clone(); + results.sort_by_key(|(row, _)| { + match &row.values[2] { + // Sort by order_id + Value::Integer(n) => *n, + _ => 0, + } + }); + + // First result: Alice's order 101 for Widget + assert_eq!(results[0].0.values[0], Value::Integer(1)); // customer_id + assert_eq!(results[0].0.values[1], Value::Text("Alice".into())); // customer_name + assert_eq!(results[0].0.values[2], Value::Integer(101)); // order_id + 
assert_eq!(results[0].0.values[3], Value::Integer(5)); // quantity + assert_eq!(results[0].0.values[4], Value::Text("Widget".into())); // product_name + + // Second result: Alice's order 103 for Gadget + assert_eq!(results[1].0.values[0], Value::Integer(1)); // customer_id + assert_eq!(results[1].0.values[1], Value::Text("Alice".into())); // customer_name + assert_eq!(results[1].0.values[2], Value::Integer(103)); // order_id + assert_eq!(results[1].0.values[3], Value::Integer(3)); // quantity + assert_eq!(results[1].0.values[4], Value::Text("Gadget".into())); // product_name + } + + #[test] + fn test_projection_column_reordering_and_duplication() { + // Test that projection can reorder columns and select the same column multiple times + // This is important for views that need specific column arrangements + + let sql = "SELECT o.quantity, + u.name, + u.id, + o.quantity * 2 as double_quantity, + u.id as user_id_again + FROM users u + JOIN orders o ON u.id = o.user_id + WHERE u.id = 1"; + + let (mut circuit, pager) = compile_sql!(sql); + + // Create test data for users + let mut users_delta = Delta::new(); + users_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".into()), + Value::Integer(25), + ], + ); + + // Create test data for orders + let mut orders_delta = Delta::new(); + orders_delta.insert( + 1, + vec![ + Value::Integer(101), + Value::Integer(1), // user_id + Value::Integer(201), // product_id + Value::Integer(5), // quantity + ], + ); + orders_delta.insert( + 2, + vec![ + Value::Integer(102), + Value::Integer(1), // user_id + Value::Integer(202), // product_id + Value::Integer(3), // quantity + ], + ); + + let inputs = HashMap::from([ + ("users".to_string(), users_delta), + ("orders".to_string(), orders_delta), + ]); + + let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap(); + + assert_eq!(result.len(), 2, "Should have 2 results for user 1"); + + // Check that columns are in the right order and values are correct + for (row, _) in &result.changes { + // Column 0: o.quantity (5 or 3) + assert!(matches!( + row.values[0], + Value::Integer(5) | Value::Integer(3) + )); + // Column 1: u.name + assert_eq!(row.values[1], Value::Text("Alice".into())); + // Column 2: u.id + assert_eq!(row.values[2], Value::Integer(1)); + // Column 3: o.quantity * 2 (10 or 6) + assert!(matches!( + row.values[3], + Value::Integer(10) | Value::Integer(6) + )); + // Column 4: u.id again + assert_eq!(row.values[4], Value::Integer(1)); + } + } + #[test] fn test_join_with_aggregate_execution() { let (mut circuit, pager) = compile_sql!( From 627f61aa81f36ffcd8834dfb43f5417a7b7fec01 Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Thu, 18 Sep 2025 11:23:47 -0500 Subject: [PATCH 23/34] support column comparisons in the filter operator We currently only support column / literal comparisons in the filter operator. But with JOINs, comparisons are usually against two columns. Do the work to support it. 
--- core/incremental/compiler.rs | 80 ++++++++++++++++++++++++- core/incremental/filter_operator.rs | 90 +++++++++++++++++++++++++++++ 2 files changed, 167 insertions(+), 3 deletions(-) diff --git a/core/incremental/compiler.rs b/core/incremental/compiler.rs index 52f383617..e1b04f607 100644 --- a/core/incremental/compiler.rs +++ b/core/incremental/compiler.rs @@ -1376,10 +1376,84 @@ impl DbspCompiler { match expr { LogicalExpr::BinaryExpr { left, op, right } => { // Extract column name and value for simple predicates - if let (LogicalExpr::Column(col), LogicalExpr::Literal(val)) = + // First check for column-to-column comparisons + if let (LogicalExpr::Column(left_col), LogicalExpr::Column(right_col)) = (left.as_ref(), right.as_ref()) { - // Resolve column name to index using the schema + // Resolve both column names to indices + let left_idx = schema + .columns + .iter() + .position(|c| c.name == left_col.name) + .ok_or_else(|| { + crate::LimboError::ParseError(format!( + "Column '{}' not found in schema for filter", + left_col.name + )) + })?; + + let right_idx = schema + .columns + .iter() + .position(|c| c.name == right_col.name) + .ok_or_else(|| { + crate::LimboError::ParseError(format!( + "Column '{}' not found in schema for filter", + right_col.name + )) + })?; + + match op { + BinaryOperator::Equals => Ok(FilterPredicate::ColumnEquals { + left_idx, + right_idx, + }), + BinaryOperator::NotEquals => Ok(FilterPredicate::ColumnNotEquals { + left_idx, + right_idx, + }), + BinaryOperator::Greater => Ok(FilterPredicate::ColumnGreaterThan { + left_idx, + right_idx, + }), + BinaryOperator::GreaterEquals => { + Ok(FilterPredicate::ColumnGreaterThanOrEqual { + left_idx, + right_idx, + }) + } + BinaryOperator::Less => Ok(FilterPredicate::ColumnLessThan { + left_idx, + right_idx, + }), + BinaryOperator::LessEquals => Ok(FilterPredicate::ColumnLessThanOrEqual { + left_idx, + right_idx, + }), + BinaryOperator::And | BinaryOperator::Or => { + // Handle logical operators recursively + let left_pred = Self::compile_filter_predicate(left, schema)?; + let right_pred = Self::compile_filter_predicate(right, schema)?; + match op { + BinaryOperator::And => Ok(FilterPredicate::And( + Box::new(left_pred), + Box::new(right_pred), + )), + BinaryOperator::Or => Ok(FilterPredicate::Or( + Box::new(left_pred), + Box::new(right_pred), + )), + _ => unreachable!(), + } + } + _ => Err(LimboError::ParseError(format!( + "Unsupported operator in filter: {op:?}" + ))), + } + } else if let (LogicalExpr::Column(col), LogicalExpr::Literal(val)) = + (left.as_ref(), right.as_ref()) + { + // Column-to-literal comparisons let column_idx = schema .columns .iter() @@ -1455,7 +1529,7 @@ impl DbspCompiler { } } else { Err(LimboError::ParseError( - "Filter predicate must be column op value".to_string(), + "Filter predicate must be column op value or column op column".to_string(), )) } } diff --git a/core/incremental/filter_operator.rs b/core/incremental/filter_operator.rs index a0179f9d4..84a3c53ce 100644 --- a/core/incremental/filter_operator.rs +++ b/core/incremental/filter_operator.rs @@ -25,6 +25,20 @@ pub enum FilterPredicate { LessThan { column_idx: usize, value: Value }, /// Column <= value (using column index) LessThanOrEqual { column_idx: usize, value: Value }, + + /// Column = Column comparisons + ColumnEquals { left_idx: usize, right_idx: usize }, + /// Column != Column comparisons + ColumnNotEquals { left_idx: usize, right_idx: usize }, + /// Column > Column comparisons + ColumnGreaterThan { left_idx: usize, right_idx: 
usize }, + /// Column >= Column comparisons + ColumnGreaterThanOrEqual { left_idx: usize, right_idx: usize }, + /// Column < Column comparisons + ColumnLessThan { left_idx: usize, right_idx: usize }, + /// Column <= Column comparisons + ColumnLessThanOrEqual { left_idx: usize, right_idx: usize }, + /// Logical AND of two predicates And(Box, Box), /// Logical OR of two predicates @@ -124,6 +138,82 @@ impl FilterOperator { let right_filter = FilterOperator::new((**right).clone()); left_filter.evaluate_predicate(values) || right_filter.evaluate_predicate(values) } + + // Column-to-column comparisons + FilterPredicate::ColumnEquals { + left_idx, + right_idx, + } => { + if let (Some(left), Some(right)) = (values.get(*left_idx), values.get(*right_idx)) { + return left == right; + } + false + } + FilterPredicate::ColumnNotEquals { + left_idx, + right_idx, + } => { + if let (Some(left), Some(right)) = (values.get(*left_idx), values.get(*right_idx)) { + return left != right; + } + false + } + FilterPredicate::ColumnGreaterThan { + left_idx, + right_idx, + } => { + if let (Some(left), Some(right)) = (values.get(*left_idx), values.get(*right_idx)) { + match (left, right) { + (Value::Integer(a), Value::Integer(b)) => return a > b, + (Value::Float(a), Value::Float(b)) => return a > b, + (Value::Text(a), Value::Text(b)) => return a.as_str() > b.as_str(), + _ => {} + } + } + false + } + FilterPredicate::ColumnGreaterThanOrEqual { + left_idx, + right_idx, + } => { + if let (Some(left), Some(right)) = (values.get(*left_idx), values.get(*right_idx)) { + match (left, right) { + (Value::Integer(a), Value::Integer(b)) => return a >= b, + (Value::Float(a), Value::Float(b)) => return a >= b, + (Value::Text(a), Value::Text(b)) => return a.as_str() >= b.as_str(), + _ => {} + } + } + false + } + FilterPredicate::ColumnLessThan { + left_idx, + right_idx, + } => { + if let (Some(left), Some(right)) = (values.get(*left_idx), values.get(*right_idx)) { + match (left, right) { + (Value::Integer(a), Value::Integer(b)) => return a < b, + (Value::Float(a), Value::Float(b)) => return a < b, + (Value::Text(a), Value::Text(b)) => return a.as_str() < b.as_str(), + _ => {} + } + } + false + } + FilterPredicate::ColumnLessThanOrEqual { + left_idx, + right_idx, + } => { + if let (Some(left), Some(right)) = (values.get(*left_idx), values.get(*right_idx)) { + match (left, right) { + (Value::Integer(a), Value::Integer(b)) => return a <= b, + (Value::Float(a), Value::Float(b)) => return a <= b, + (Value::Text(a), Value::Text(b)) => return a.as_str() <= b.as_str(), + _ => {} + } + } + false + } } } } From 832a4d703448157d5a2bdd299a169ef3d8d5640f Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Thu, 18 Sep 2025 12:18:01 -0500 Subject: [PATCH 24/34] generate projection nodes inside filter clauses We are currently not able to properly compute things like WHERE a+b=2. Let's generate a projection node inside a filter when needed. 
--- core/incremental/compiler.rs | 341 +++++++++++++++++++++++++++++++++-- 1 file changed, 325 insertions(+), 16 deletions(-) diff --git a/core/incremental/compiler.rs b/core/incremental/compiler.rs index e1b04f607..8c8189261 100644 --- a/core/incremental/compiler.rs +++ b/core/incremental/compiler.rs @@ -11,10 +11,12 @@ use crate::incremental::operator::{ create_dbsp_state_index, DbspStateCursors, EvalState, FilterOperator, FilterPredicate, IncrementalOperator, InputOperator, JoinOperator, JoinType, ProjectOperator, }; +use crate::schema::Type; use crate::storage::btree::{BTreeCursor, BTreeKey}; // Note: logical module must be made pub(crate) in translate/mod.rs use crate::translate::logical::{ - BinaryOperator, JoinType as LogicalJoinType, LogicalExpr, LogicalPlan, LogicalSchema, SchemaRef, + BinaryOperator, Column, ColumnInfo, JoinType as LogicalJoinType, LogicalExpr, LogicalPlan, + LogicalSchema, SchemaRef, }; use crate::types::{IOResult, ImmutableRecord, SeekKey, SeekOp, SeekResult, Value}; use crate::Pager; @@ -898,23 +900,151 @@ impl DbspCompiler { // Get input schema for column resolution let input_schema = filter.input.schema(); - // Convert predicate to DBSP expression - let dbsp_predicate = Self::compile_expr(&filter.predicate)?; + // Check if the predicate contains expressions that need to be computed + if Self::predicate_needs_projection(&filter.predicate) { + // Complex expression in WHERE clause - need to add projection first + // 1. Create projection that adds the computed expression as a new column - // Convert to FilterPredicate - let filter_predicate = Self::compile_filter_predicate(&filter.predicate, input_schema)?; + // First, get all existing columns + let mut projection_exprs = Vec::new(); + let mut dbsp_exprs = Vec::new(); - // Create executable operator - let executable: Box = - Box::new(FilterOperator::new(filter_predicate)); + for col in &input_schema.columns { + projection_exprs.push(LogicalExpr::Column(Column { + name: col.name.clone(), + table: None, + })); + dbsp_exprs.push(DbspExpr::Column(col.name.clone())); + } - // Create filter node - let node_id = self.circuit.add_node( - DbspOperator::Filter { predicate: dbsp_predicate }, - vec![input_id], - executable, - ); - Ok(node_id) + // Now add the expression as a computed column + let temp_column_name = "__temp_filter_expr"; + let computed_expr = Self::extract_expression_from_predicate(&filter.predicate)?; + projection_exprs.push(computed_expr.clone()); + + // Compile the projection expressions + let mut compiled_exprs = Vec::new(); + let mut aliases = Vec::new(); + let mut output_names = Vec::new(); + for (i, expr) in projection_exprs.iter().enumerate() { + let (compiled, _alias) = Self::compile_expression(expr, input_schema)?; + compiled_exprs.push(compiled); + if i < input_schema.columns.len() { + aliases.push(None); + output_names.push(input_schema.columns[i].name.clone()); + } else { + aliases.push(Some(temp_column_name.to_string())); + output_names.push(temp_column_name.to_string()); + } + } + + // Get input column names for ProjectOperator + let input_column_names: Vec = input_schema.columns.iter() + .map(|col| col.name.clone()) + .collect(); + + // Create projection operator + let proj_executable: Box = + Box::new(ProjectOperator::from_compiled( + compiled_exprs.clone(), + aliases.clone(), + input_column_names.clone(), + output_names.clone() + )?); + + // Create updated schema for the projection output + let mut proj_schema_columns = input_schema.columns.clone(); + proj_schema_columns.push(ColumnInfo { 
+ name: temp_column_name.to_string(), + table: None, + database: None, + table_alias: None, + ty: Type::Integer, // Computed expressions default to Integer + }); + let proj_schema = SchemaRef::new(LogicalSchema { + columns: proj_schema_columns, + }); + + // Add projection node + let proj_id = self.circuit.add_node( + DbspOperator::Projection { + exprs: dbsp_exprs.clone(), + schema: proj_schema.clone(), + }, + vec![input_id], + proj_executable, + ); + + // Now create a filter that replaces the complex expression with the temp column + // but keeps all other conditions intact + let replaced_predicate = Self::replace_complex_with_temp(&filter.predicate, temp_column_name)?; + let filter_predicate = Self::compile_filter_predicate(&replaced_predicate, &proj_schema)?; + + let filter_executable: Box = + Box::new(FilterOperator::new(filter_predicate)); + + // Create filter node + let filter_id = self.circuit.add_node( + DbspOperator::Filter { predicate: Self::compile_expr(&replaced_predicate)? }, + vec![proj_id], + filter_executable, + ); + + // Finally, project again to remove the temporary column + let mut final_exprs = Vec::new(); + let mut final_aliases = Vec::new(); + let mut final_names = Vec::new(); + let mut final_dbsp_exprs = Vec::new(); + + for (i, column) in input_schema.columns.iter().enumerate() { + let col_name = &column.name; + final_exprs.push(compiled_exprs[i].clone()); + final_aliases.push(None); + final_names.push(col_name.clone()); + final_dbsp_exprs.push(DbspExpr::Column(col_name.clone())); + } + + // Input names for the final projection include the temp column + let filter_output_names = output_names.clone(); + + let final_proj_executable: Box = + Box::new(ProjectOperator::from_compiled( + final_exprs, + final_aliases, + filter_output_names, + final_names.clone() + )?); + + let final_id = self.circuit.add_node( + DbspOperator::Projection { + exprs: final_dbsp_exprs, + schema: input_schema.clone(), // Back to original schema + }, + vec![filter_id], + final_proj_executable, + ); + + Ok(final_id) + } else { + // Simple filter - use existing implementation + // Convert predicate to DBSP expression + let dbsp_predicate = Self::compile_expr(&filter.predicate)?; + + // Convert to FilterPredicate + let filter_predicate = Self::compile_filter_predicate(&filter.predicate, input_schema)?; + + // Create executable operator + let executable: Box = + Box::new(FilterOperator::new(filter_predicate)); + + // Create filter node + let node_id = self.circuit.add_node( + DbspOperator::Filter { predicate: dbsp_predicate }, + vec![input_id], + executable, + ); + Ok(node_id) + } } LogicalPlan::Aggregate(agg) => { // Compile the input first @@ -1285,7 +1415,12 @@ impl DbspCompiler { let lit = match val { Value::Integer(i) => ast::Literal::Numeric(i.to_string()), Value::Float(f) => ast::Literal::Numeric(f.to_string()), - Value::Text(t) => ast::Literal::String(t.to_string()), + Value::Text(t) => { + // Add quotes for string literals as translate_expr expects them + // Also escape any single quotes in the string + let escaped = t.to_string().replace('\'', "''"); + ast::Literal::String(format!("'{escaped}'")) + } Value::Blob(b) => ast::Literal::Blob(format!("{b:?}")), Value::Null => ast::Literal::Null, }; @@ -1368,6 +1503,109 @@ impl DbspCompiler { } } + /// Check if a predicate contains expressions that need projection + fn predicate_needs_projection(expr: &LogicalExpr) -> bool { + match expr { + LogicalExpr::BinaryExpr { left, op, right } => { + match (left.as_ref(), right.as_ref()) { + // Simple column 
to literal - OK + (LogicalExpr::Column(_), LogicalExpr::Literal(_)) => false, + // Simple column to column - OK + (LogicalExpr::Column(_), LogicalExpr::Column(_)) => false, + // AND/OR of simple expressions - check recursively + _ if matches!(op, BinaryOperator::And | BinaryOperator::Or) => { + Self::predicate_needs_projection(left) + || Self::predicate_needs_projection(right) + } + // Any other pattern needs projection + _ => true, + } + } + _ => false, + } + } + + /// Extract the expression part from a predicate that needs to be computed + fn extract_expression_from_predicate(expr: &LogicalExpr) -> Result { + match expr { + LogicalExpr::BinaryExpr { left, op, right } => { + // Handle AND/OR - recursively find the complex expression + if matches!(op, BinaryOperator::And | BinaryOperator::Or) { + // Check left side first + if Self::predicate_needs_projection(left) { + return Self::extract_expression_from_predicate(left); + } + // Then check right side + if Self::predicate_needs_projection(right) { + return Self::extract_expression_from_predicate(right); + } + // Neither side needs projection (shouldn't happen if predicate_needs_projection was true) + return Ok(expr.clone()); + } + + // For expressions like (age * 2) > 30, we want to extract (age * 2) + if matches!( + op, + BinaryOperator::Greater + | BinaryOperator::GreaterEquals + | BinaryOperator::Less + | BinaryOperator::LessEquals + | BinaryOperator::Equals + | BinaryOperator::NotEquals + ) { + // Return the left side if it's not a simple column + if !matches!(left.as_ref(), LogicalExpr::Column(_)) { + Ok((**left).clone()) + } else { + // Must be the whole expression then + Ok(expr.clone()) + } + } else { + Ok(expr.clone()) + } + } + _ => Ok(expr.clone()), + } + } + + /// Replace complex expressions in the predicate with references to the temp column + fn replace_complex_with_temp( + expr: &LogicalExpr, + temp_column_name: &str, + ) -> Result { + match expr { + LogicalExpr::BinaryExpr { left, op, right } => { + // Handle AND/OR - recursively process both sides + if matches!(op, BinaryOperator::And | BinaryOperator::Or) { + let new_left = Self::replace_complex_with_temp(left, temp_column_name)?; + let new_right = Self::replace_complex_with_temp(right, temp_column_name)?; + return Ok(LogicalExpr::BinaryExpr { + left: Box::new(new_left), + op: *op, + right: Box::new(new_right), + }); + } + + // Check if this is a complex comparison that needs replacement + if Self::predicate_needs_projection(expr) { + // Replace the complex expression (left side) with the temp column + return Ok(LogicalExpr::BinaryExpr { + left: Box::new(LogicalExpr::Column(Column { + name: temp_column_name.to_string(), + table: None, + })), + op: *op, + right: right.clone(), + }); + } + + // Simple comparison - keep as is + Ok(expr.clone()) + } + _ => Ok(expr.clone()), + } + } + /// Compile a logical expression to a FilterPredicate for execution fn compile_filter_predicate( expr: &LogicalExpr, @@ -5055,4 +5293,75 @@ mod tests { "customers.name should be Customer Bob" ); } + + #[test] + fn test_expression_in_where_clause() { + // Test expressions in WHERE clauses like (quantity * price) >= 400 + let (mut circuit, pager) = compile_sql!("SELECT * FROM users WHERE (age * 2) > 30"); + + // Create test data + let mut input_delta = Delta::new(); + input_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".into()), + Value::Integer(20), // age * 2 = 40 > 30, should pass + ], + ); + input_delta.insert( + 2, + vec![ + Value::Integer(2), + 
Value::Text("Bob".into()), + Value::Integer(10), // age * 2 = 20 <= 30, should be filtered out + ], + ); + input_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Text("Charlie".into()), + Value::Integer(16), // age * 2 = 32 > 30, should pass + ], + ); + + // Create input map + let mut inputs = HashMap::new(); + inputs.insert("users".to_string(), input_delta); + + let result = test_execute(&mut circuit, inputs.clone(), pager.clone()).unwrap(); + + // Should only have Alice and Charlie (age * 2 > 30) + assert_eq!( + result.changes.len(), + 2, + "Should have 2 rows after filtering" + ); + + // Check Alice + let alice = result + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Integer(1)) + .expect("Alice should be in result"); + assert_eq!(alice.0.values[1], Value::Text("Alice".into())); + assert_eq!(alice.0.values[2], Value::Integer(20)); + + // Check Charlie + let charlie = result + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Integer(3)) + .expect("Charlie should be in result"); + assert_eq!(charlie.0.values[1], Value::Text("Charlie".into())); + assert_eq!(charlie.0.values[2], Value::Integer(16)); + + // Bob should not be in result + let bob = result + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Integer(2)); + assert!(bob.is_none(), "Bob should be filtered out"); + } } From e5a106d8d6f8e1f0e29f4223f948583e87ff9860 Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Thu, 18 Sep 2025 10:40:35 -0500 Subject: [PATCH 25/34] enable joins in IncrementalView --- core/incremental/view.rs | 97 ----- testing/materialized_views.test | 689 ++++++++++++++++++++++++++++---- 2 files changed, 612 insertions(+), 174 deletions(-) diff --git a/core/incremental/view.rs b/core/incremental/view.rs index 77f1d0217..9f5fe52cb 100644 --- a/core/incremental/view.rs +++ b/core/incremental/view.rs @@ -189,20 +189,6 @@ pub struct IncrementalView { } impl IncrementalView { - /// Validate that a CREATE MATERIALIZED VIEW statement can be handled by IncrementalView - /// This should be called early, before updating sqlite_master - pub fn can_create_view(select: &ast::Select) -> Result<()> { - // Check for JOINs - let (join_tables, join_condition) = Self::extract_join_info(select); - if join_tables.is_some() || join_condition.is_some() { - return Err(LimboError::ParseError( - "JOINs in views are not yet supported".to_string(), - )); - } - - Ok(()) - } - /// Try to compile the SELECT statement into a DBSP circuit fn try_compile_circuit( select: &ast::Select, @@ -307,13 +293,6 @@ impl IncrementalView { // Extract output columns using the shared function let column_schema = extract_view_columns(&select, schema)?; - let (join_tables, join_condition) = Self::extract_join_info(&select); - if join_tables.is_some() || join_condition.is_some() { - return Err(LimboError::ParseError( - "JOINs in views are not yet supported".to_string(), - )); - } - // Get all tables from FROM clause and JOINs, along with their aliases let (referenced_tables, table_aliases, qualified_table_names) = Self::extract_all_tables(&select, schema)?; @@ -1046,82 +1025,6 @@ impl IncrementalView { } } - /// Extract JOIN information from SELECT statement - #[allow(clippy::type_complexity)] - pub fn extract_join_info( - select: &ast::Select, - ) -> (Option<(String, String)>, Option<(String, String)>) { - use turso_parser::ast::*; - - if let OneSelect::Select { - from: Some(ref from), - .. 
- } = select.body.select - { - // Check if there are any joins - if !from.joins.is_empty() { - // Get the first (left) table name - let left_table = match from.select.as_ref() { - SelectTable::Table(name, _, _) => Some(name.name.as_str().to_string()), - _ => None, - }; - - // Get the first join (right) table and condition - if let Some(first_join) = from.joins.first() { - let right_table = match &first_join.table.as_ref() { - SelectTable::Table(name, _, _) => Some(name.name.as_str().to_string()), - _ => None, - }; - - // Extract join condition (simplified - assumes single equality) - let join_condition = if let Some(ref constraint) = &first_join.constraint { - match constraint { - JoinConstraint::On(expr) => Self::extract_join_columns_from_expr(expr), - _ => None, - } - } else { - None - }; - - if let (Some(left), Some(right)) = (left_table, right_table) { - return (Some((left, right)), join_condition); - } - } - } - } - - (None, None) - } - - /// Extract join column names from a join condition expression - fn extract_join_columns_from_expr(expr: &ast::Expr) -> Option<(String, String)> { - use turso_parser::ast::*; - - // Look for expressions like: t1.col = t2.col - if let Expr::Binary(left, op, right) = expr { - if matches!(op, Operator::Equals) { - // Extract column names from both sides - let left_col = match &**left { - Expr::Qualified(name, _) => Some(name.as_str().to_string()), - Expr::Id(name) => Some(name.as_str().to_string()), - _ => None, - }; - - let right_col = match &**right { - Expr::Qualified(name, _) => Some(name.as_str().to_string()), - Expr::Id(name) => Some(name.as_str().to_string()), - _ => None, - }; - - if let (Some(l), Some(r)) = (left_col, right_col) { - return Some((l, r)); - } - } - } - - None - } - /// Merge a delta set of changes into the view's current state pub fn merge_delta( &mut self, diff --git a/testing/materialized_views.test b/testing/materialized_views.test index 5d226b016..15229a48c 100755 --- a/testing/materialized_views.test +++ b/testing/materialized_views.test @@ -44,13 +44,13 @@ do_execsql_test_on_specific_db {:memory:} matview-aggregation-population { do_execsql_test_on_specific_db {:memory:} matview-filter-with-groupby { CREATE TABLE t(a INTEGER, b INTEGER); INSERT INTO t(a,b) VALUES (2,2), (3,3), (6,6), (7,7); - + CREATE MATERIALIZED VIEW v AS SELECT b as yourb, SUM(a) as mysum, COUNT(a) as mycount FROM t WHERE b > 2 GROUP BY b; - + SELECT * FROM v ORDER BY yourb; } {3|3|1 6|6|1 @@ -63,13 +63,13 @@ do_execsql_test_on_specific_db {:memory:} matview-insert-maintenance { FROM t WHERE b > 2 GROUP BY b; - + INSERT INTO t VALUES (3,3), (6,6); SELECT * FROM v ORDER BY b; - + INSERT INTO t VALUES (4,3), (5,6); SELECT * FROM v ORDER BY b; - + INSERT INTO t VALUES (1,1), (2,2); SELECT * FROM v ORDER BY b; } {3|3|1 @@ -87,17 +87,17 @@ do_execsql_test_on_specific_db {:memory:} matview-delete-maintenance { (3, 'A', 30), (4, 'B', 40), (5, 'A', 50); - + CREATE MATERIALIZED VIEW category_sums AS SELECT category, SUM(amount) as total, COUNT(*) as cnt FROM items GROUP BY category; - + SELECT * FROM category_sums ORDER BY category; - + DELETE FROM items WHERE id = 3; SELECT * FROM category_sums ORDER BY category; - + DELETE FROM items WHERE category = 'B'; SELECT * FROM category_sums ORDER BY category; } {A|90|3 @@ -113,17 +113,17 @@ do_execsql_test_on_specific_db {:memory:} matview-update-maintenance { (2, 200, 2), (3, 300, 1), (4, 400, 2); - + CREATE MATERIALIZED VIEW status_totals AS SELECT status, SUM(value) as total, COUNT(*) as cnt FROM records GROUP BY 
status; - + SELECT * FROM status_totals ORDER BY status; - + UPDATE records SET value = 150 WHERE id = 1; SELECT * FROM status_totals ORDER BY status; - + UPDATE records SET status = 2 WHERE id = 3; SELECT * FROM status_totals ORDER BY status; } {1|400|2 @@ -136,10 +136,10 @@ do_execsql_test_on_specific_db {:memory:} matview-update-maintenance { do_execsql_test_on_specific_db {:memory:} matview-integer-primary-key-basic { CREATE TABLE t(a INTEGER PRIMARY KEY, b INTEGER); INSERT INTO t(a,b) VALUES (2,2), (3,3), (6,6), (7,7); - + CREATE MATERIALIZED VIEW v AS SELECT * FROM t WHERE b > 2; - + SELECT * FROM v ORDER BY a; } {3|3 6|6 @@ -148,15 +148,15 @@ do_execsql_test_on_specific_db {:memory:} matview-integer-primary-key-basic { do_execsql_test_on_specific_db {:memory:} matview-integer-primary-key-update-rowid { CREATE TABLE t(a INTEGER PRIMARY KEY, b INTEGER); INSERT INTO t(a,b) VALUES (2,2), (3,3), (6,6), (7,7); - + CREATE MATERIALIZED VIEW v AS SELECT * FROM t WHERE b > 2; - + SELECT * FROM v ORDER BY a; - + UPDATE t SET a = 1 WHERE b = 3; SELECT * FROM v ORDER BY a; - + UPDATE t SET a = 10 WHERE a = 6; SELECT * FROM v ORDER BY a; } {3|3 @@ -172,15 +172,15 @@ do_execsql_test_on_specific_db {:memory:} matview-integer-primary-key-update-row do_execsql_test_on_specific_db {:memory:} matview-integer-primary-key-update-value { CREATE TABLE t(a INTEGER PRIMARY KEY, b INTEGER); INSERT INTO t(a,b) VALUES (2,2), (3,3), (6,6), (7,7); - + CREATE MATERIALIZED VIEW v AS SELECT * FROM t WHERE b > 2; - + SELECT * FROM v ORDER BY a; - + UPDATE t SET b = 1 WHERE a = 6; SELECT * FROM v ORDER BY a; - + UPDATE t SET b = 5 WHERE a = 2; SELECT * FROM v ORDER BY a; } {3|3 @@ -200,18 +200,18 @@ do_execsql_test_on_specific_db {:memory:} matview-integer-primary-key-with-aggre (3, 20, 300), (4, 20, 400), (5, 10, 500); - + CREATE MATERIALIZED VIEW v AS SELECT b, SUM(c) as total, COUNT(*) as cnt FROM t WHERE a > 2 GROUP BY b; - + SELECT * FROM v ORDER BY b; - + UPDATE t SET a = 6 WHERE a = 1; SELECT * FROM v ORDER BY b; - + DELETE FROM t WHERE a = 3; SELECT * FROM v ORDER BY b; } {10|500|1 @@ -228,7 +228,7 @@ do_execsql_test_on_specific_db {:memory:} matview-complex-filter-aggregation { amount INTEGER, type INTEGER ); - + INSERT INTO transactions VALUES (1, 100, 50, 1), (2, 100, 30, 2), @@ -236,21 +236,21 @@ do_execsql_test_on_specific_db {:memory:} matview-complex-filter-aggregation { (4, 100, 20, 1), (5, 200, 40, 2), (6, 300, 60, 1); - + CREATE MATERIALIZED VIEW account_deposits AS SELECT account, SUM(amount) as total_deposits, COUNT(*) as deposit_count FROM transactions WHERE type = 1 GROUP BY account; - + SELECT * FROM account_deposits ORDER BY account; - + INSERT INTO transactions VALUES (7, 100, 25, 1); SELECT * FROM account_deposits ORDER BY account; - + UPDATE transactions SET amount = 80 WHERE id = 1; SELECT * FROM account_deposits ORDER BY account; - + DELETE FROM transactions WHERE id = 3; SELECT * FROM account_deposits ORDER BY account; } {100|70|2 @@ -273,19 +273,19 @@ do_execsql_test_on_specific_db {:memory:} matview-sum-count-only { (3, 30, 2), (4, 40, 2), (5, 50, 1); - + CREATE MATERIALIZED VIEW category_stats AS SELECT category, SUM(value) as sum_val, COUNT(*) as cnt FROM data GROUP BY category; - + SELECT * FROM category_stats ORDER BY category; - + INSERT INTO data VALUES (6, 5, 1); SELECT * FROM category_stats ORDER BY category; - + UPDATE data SET value = 35 WHERE id = 3; SELECT * FROM category_stats ORDER BY category; } {1|80|3 @@ -302,9 +302,9 @@ do_execsql_test_on_specific_db {:memory:} 
matview-empty-table-population { FROM t WHERE b > 5 GROUP BY b; - + SELECT COUNT(*) FROM v; - + INSERT INTO t VALUES (1, 3), (2, 7), (3, 9); SELECT * FROM v ORDER BY b; } {0 @@ -314,15 +314,15 @@ do_execsql_test_on_specific_db {:memory:} matview-empty-table-population { do_execsql_test_on_specific_db {:memory:} matview-all-rows-filtered { CREATE TABLE t(a INTEGER, b INTEGER); INSERT INTO t VALUES (1, 1), (2, 2), (3, 3); - + CREATE MATERIALIZED VIEW v AS SELECT * FROM t WHERE b > 10; - + SELECT COUNT(*) FROM v; - + INSERT INTO t VALUES (11, 11); SELECT * FROM v; - + UPDATE t SET b = 1 WHERE a = 11; SELECT COUNT(*) FROM v; } {0 @@ -335,26 +335,26 @@ do_execsql_test_on_specific_db {:memory:} matview-mixed-operations-sequence { customer_id INTEGER, amount INTEGER ); - + INSERT INTO orders VALUES (1, 100, 50); INSERT INTO orders VALUES (2, 200, 75); - + CREATE MATERIALIZED VIEW customer_totals AS SELECT customer_id, SUM(amount) as total, COUNT(*) as order_count FROM orders GROUP BY customer_id; - + SELECT * FROM customer_totals ORDER BY customer_id; - + INSERT INTO orders VALUES (3, 100, 25); SELECT * FROM customer_totals ORDER BY customer_id; - + UPDATE orders SET amount = 100 WHERE order_id = 2; SELECT * FROM customer_totals ORDER BY customer_id; - + DELETE FROM orders WHERE order_id = 1; SELECT * FROM customer_totals ORDER BY customer_id; - + INSERT INTO orders VALUES (4, 300, 150); SELECT * FROM customer_totals ORDER BY customer_id; } {100|50|1 @@ -389,17 +389,17 @@ do_execsql_test_on_specific_db {:memory:} matview-projections { do_execsql_test_on_specific_db {:memory:} matview-rollback-insert { CREATE TABLE t(a INTEGER, b INTEGER); INSERT INTO t VALUES (1, 10), (2, 20), (3, 30); - + CREATE MATERIALIZED VIEW v AS SELECT * FROM t WHERE b > 15; - + SELECT * FROM v ORDER BY a; - + BEGIN; INSERT INTO t VALUES (4, 40), (5, 50); SELECT * FROM v ORDER BY a; ROLLBACK; - + SELECT * FROM v ORDER BY a; } {2|20 3|30 @@ -413,17 +413,17 @@ do_execsql_test_on_specific_db {:memory:} matview-rollback-insert { do_execsql_test_on_specific_db {:memory:} matview-rollback-delete { CREATE TABLE t(a INTEGER, b INTEGER); INSERT INTO t VALUES (1, 10), (2, 20), (3, 30), (4, 40); - + CREATE MATERIALIZED VIEW v AS SELECT * FROM t WHERE b > 15; - + SELECT * FROM v ORDER BY a; - + BEGIN; DELETE FROM t WHERE a IN (2, 3); SELECT * FROM v ORDER BY a; ROLLBACK; - + SELECT * FROM v ORDER BY a; } {2|20 3|30 @@ -436,18 +436,18 @@ do_execsql_test_on_specific_db {:memory:} matview-rollback-delete { do_execsql_test_on_specific_db {:memory:} matview-rollback-update { CREATE TABLE t(a INTEGER, b INTEGER); INSERT INTO t VALUES (1, 10), (2, 20), (3, 30); - + CREATE MATERIALIZED VIEW v AS SELECT * FROM t WHERE b > 15; - + SELECT * FROM v ORDER BY a; - + BEGIN; UPDATE t SET b = 5 WHERE a = 2; UPDATE t SET b = 35 WHERE a = 1; SELECT * FROM v ORDER BY a; ROLLBACK; - + SELECT * FROM v ORDER BY a; } {2|20 3|30 @@ -459,19 +459,19 @@ do_execsql_test_on_specific_db {:memory:} matview-rollback-update { do_execsql_test_on_specific_db {:memory:} matview-rollback-aggregation { CREATE TABLE sales(product_id INTEGER, amount INTEGER); INSERT INTO sales VALUES (1, 100), (1, 200), (2, 150), (2, 250); - + CREATE MATERIALIZED VIEW product_totals AS SELECT product_id, SUM(amount) as total, COUNT(*) as cnt FROM sales GROUP BY product_id; - + SELECT * FROM product_totals ORDER BY product_id; - + BEGIN; INSERT INTO sales VALUES (1, 50), (3, 300); SELECT * FROM product_totals ORDER BY product_id; ROLLBACK; - + SELECT * FROM product_totals ORDER BY product_id; 
} {1|300|2 2|400|2 @@ -484,21 +484,21 @@ do_execsql_test_on_specific_db {:memory:} matview-rollback-aggregation { do_execsql_test_on_specific_db {:memory:} matview-rollback-mixed-operations { CREATE TABLE orders(id INTEGER PRIMARY KEY, customer INTEGER, amount INTEGER); INSERT INTO orders VALUES (1, 100, 50), (2, 200, 75), (3, 100, 25); - + CREATE MATERIALIZED VIEW customer_totals AS SELECT customer, SUM(amount) as total, COUNT(*) as cnt FROM orders GROUP BY customer; - + SELECT * FROM customer_totals ORDER BY customer; - + BEGIN; INSERT INTO orders VALUES (4, 100, 100); UPDATE orders SET amount = 150 WHERE id = 2; DELETE FROM orders WHERE id = 3; SELECT * FROM customer_totals ORDER BY customer; ROLLBACK; - + SELECT * FROM customer_totals ORDER BY customer; } {100|75|2 200|75|1 @@ -514,22 +514,22 @@ do_execsql_test_on_specific_db {:memory:} matview-rollback-filtered-aggregation (2, 100, 30, 'withdraw'), (3, 200, 100, 'deposit'), (4, 200, 40, 'withdraw'); - + CREATE MATERIALIZED VIEW deposits AS SELECT account, SUM(amount) as total_deposits, COUNT(*) as cnt FROM transactions WHERE type = 'deposit' GROUP BY account; - + SELECT * FROM deposits ORDER BY account; - + BEGIN; INSERT INTO transactions VALUES (5, 100, 75, 'deposit'); UPDATE transactions SET amount = 60 WHERE id = 1; DELETE FROM transactions WHERE id = 3; SELECT * FROM deposits ORDER BY account; ROLLBACK; - + SELECT * FROM deposits ORDER BY account; } {100|50|1 200|100|1 @@ -540,12 +540,12 @@ do_execsql_test_on_specific_db {:memory:} matview-rollback-filtered-aggregation do_execsql_test_on_specific_db {:memory:} matview-rollback-empty-view { CREATE TABLE t(a INTEGER, b INTEGER); INSERT INTO t VALUES (1, 5), (2, 8); - + CREATE MATERIALIZED VIEW v AS SELECT * FROM t WHERE b > 10; - + SELECT COUNT(*) FROM v; - + BEGIN; INSERT INTO t VALUES (3, 15), (4, 20); SELECT * FROM v ORDER BY a; @@ -556,3 +556,538 @@ do_execsql_test_on_specific_db {:memory:} matview-rollback-empty-view { 3|15 4|20 0} + +# Join tests for materialized views + +do_execsql_test_on_specific_db {:memory:} matview-simple-join { + CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT, age INTEGER); + CREATE TABLE orders(order_id INTEGER PRIMARY KEY, user_id INTEGER, product_id INTEGER, quantity INTEGER); + + INSERT INTO users VALUES (1, 'Alice', 25), (2, 'Bob', 30), (3, 'Charlie', 35); + INSERT INTO orders VALUES (1, 1, 100, 5), (2, 1, 101, 3), (3, 2, 100, 7); + + CREATE MATERIALIZED VIEW user_orders AS + SELECT u.name, o.quantity + FROM users u + JOIN orders o ON u.id = o.user_id; + + SELECT * FROM user_orders ORDER BY name, quantity; +} {Alice|3 +Alice|5 +Bob|7} + +do_execsql_test_on_specific_db {:memory:} matview-join-with-aggregation { + CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT); + CREATE TABLE orders(order_id INTEGER PRIMARY KEY, user_id INTEGER, amount INTEGER); + + INSERT INTO users VALUES (1, 'Alice'), (2, 'Bob'); + INSERT INTO orders VALUES (1, 1, 100), (2, 1, 150), (3, 2, 200), (4, 2, 50); + + CREATE MATERIALIZED VIEW user_totals AS + SELECT u.name, SUM(o.amount) as total_amount + FROM users u + JOIN orders o ON u.id = o.user_id + GROUP BY u.name; + + SELECT * FROM user_totals ORDER BY name; +} {Alice|250 +Bob|250} + +do_execsql_test_on_specific_db {:memory:} matview-three-way-join { + CREATE TABLE customers(id INTEGER PRIMARY KEY, name TEXT, city TEXT); + CREATE TABLE orders(id INTEGER PRIMARY KEY, customer_id INTEGER, product_id INTEGER, quantity INTEGER); + CREATE TABLE products(id INTEGER PRIMARY KEY, name TEXT, price INTEGER); + + INSERT INTO 
customers VALUES (1, 'Alice', 'NYC'), (2, 'Bob', 'LA'); + INSERT INTO products VALUES (1, 'Widget', 10), (2, 'Gadget', 20); + INSERT INTO orders VALUES (1, 1, 1, 5), (2, 1, 2, 3), (3, 2, 1, 2); + + CREATE MATERIALIZED VIEW sales_summary AS + SELECT c.name as customer_name, p.name as product_name, o.quantity + FROM customers c + JOIN orders o ON c.id = o.customer_id + JOIN products p ON o.product_id = p.id; + + SELECT * FROM sales_summary ORDER BY customer_name, product_name; +} {Alice|Gadget|3 +Alice|Widget|5 +Bob|Widget|2} + +do_execsql_test_on_specific_db {:memory:} matview-three-way-join-with-aggregation { + CREATE TABLE customers(id INTEGER PRIMARY KEY, name TEXT); + CREATE TABLE orders(id INTEGER PRIMARY KEY, customer_id INTEGER, product_id INTEGER, quantity INTEGER); + CREATE TABLE products(id INTEGER PRIMARY KEY, name TEXT, price INTEGER); + + INSERT INTO customers VALUES (1, 'Alice'), (2, 'Bob'); + INSERT INTO products VALUES (1, 'Widget', 10), (2, 'Gadget', 20); + INSERT INTO orders VALUES (1, 1, 1, 5), (2, 1, 2, 3), (3, 2, 1, 2), (4, 1, 1, 4); + + CREATE MATERIALIZED VIEW sales_totals AS + SELECT c.name as customer_name, p.name as product_name, + SUM(o.quantity) as total_quantity, + SUM(o.quantity * p.price) as total_value + FROM customers c + JOIN orders o ON c.id = o.customer_id + JOIN products p ON o.product_id = p.id + GROUP BY c.name, p.name; + + SELECT * FROM sales_totals ORDER BY customer_name, product_name; +} {Alice|Gadget|3|60 +Alice|Widget|9|90 +Bob|Widget|2|20} + +do_execsql_test_on_specific_db {:memory:} matview-join-incremental-insert { + CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT); + CREATE TABLE orders(order_id INTEGER PRIMARY KEY, user_id INTEGER, amount INTEGER); + + INSERT INTO users VALUES (1, 'Alice'); + INSERT INTO orders VALUES (1, 1, 100); + + CREATE MATERIALIZED VIEW user_orders AS + SELECT u.name, o.amount + FROM users u + JOIN orders o ON u.id = o.user_id; + + SELECT COUNT(*) FROM user_orders; + + INSERT INTO orders VALUES (2, 1, 150); + SELECT COUNT(*) FROM user_orders; + + INSERT INTO users VALUES (2, 'Bob'); + INSERT INTO orders VALUES (3, 2, 200); + SELECT COUNT(*) FROM user_orders; +} {1 +2 +3} + +do_execsql_test_on_specific_db {:memory:} matview-join-incremental-delete { + CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT); + CREATE TABLE orders(order_id INTEGER PRIMARY KEY, user_id INTEGER, amount INTEGER); + + INSERT INTO users VALUES (1, 'Alice'), (2, 'Bob'); + INSERT INTO orders VALUES (1, 1, 100), (2, 1, 150), (3, 2, 200); + + CREATE MATERIALIZED VIEW user_orders AS + SELECT u.name, o.amount + FROM users u + JOIN orders o ON u.id = o.user_id; + + SELECT COUNT(*) FROM user_orders; + + DELETE FROM orders WHERE order_id = 2; + SELECT COUNT(*) FROM user_orders; + + DELETE FROM users WHERE id = 2; + SELECT COUNT(*) FROM user_orders; +} {3 +2 +1} + +do_execsql_test_on_specific_db {:memory:} matview-join-incremental-update { + CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT); + CREATE TABLE orders(order_id INTEGER PRIMARY KEY, user_id INTEGER, amount INTEGER); + + INSERT INTO users VALUES (1, 'Alice'), (2, 'Bob'); + INSERT INTO orders VALUES (1, 1, 100), (2, 2, 200); + + CREATE MATERIALIZED VIEW user_orders AS + SELECT u.name, o.amount + FROM users u + JOIN orders o ON u.id = o.user_id; + + SELECT * FROM user_orders ORDER BY name; + + UPDATE orders SET amount = 150 WHERE order_id = 1; + SELECT * FROM user_orders ORDER BY name; + + UPDATE users SET name = 'Robert' WHERE id = 2; + SELECT * FROM user_orders ORDER BY name; +} 
{Alice|100 +Bob|200 +Alice|150 +Bob|200 +Alice|150 +Robert|200} + +do_execsql_test_on_specific_db {:memory:} matview-join-with-filter { + CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT, age INTEGER); + CREATE TABLE orders(order_id INTEGER PRIMARY KEY, user_id INTEGER, amount INTEGER); + + INSERT INTO users VALUES (1, 'Alice', 25), (2, 'Bob', 35), (3, 'Charlie', 20); + INSERT INTO orders VALUES (1, 1, 100), (2, 2, 200), (3, 3, 150); + + CREATE MATERIALIZED VIEW adult_orders AS + SELECT u.name, o.amount + FROM users u + JOIN orders o ON u.id = o.user_id + WHERE u.age > 21; + + SELECT * FROM adult_orders ORDER BY name; +} {Alice|100 +Bob|200} + +do_execsql_test_on_specific_db {:memory:} matview-join-rollback { + CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT); + CREATE TABLE orders(order_id INTEGER PRIMARY KEY, user_id INTEGER, amount INTEGER); + + INSERT INTO users VALUES (1, 'Alice'), (2, 'Bob'); + INSERT INTO orders VALUES (1, 1, 100), (2, 2, 200); + + CREATE MATERIALIZED VIEW user_orders AS + SELECT u.name, o.amount + FROM users u + JOIN orders o ON u.id = o.user_id; + + SELECT COUNT(*) FROM user_orders; + + BEGIN; + INSERT INTO users VALUES (3, 'Charlie'); + INSERT INTO orders VALUES (3, 3, 300); + SELECT COUNT(*) FROM user_orders; + ROLLBACK; + + SELECT COUNT(*) FROM user_orders; +} {2 +3 +2} + +# ===== COMPREHENSIVE JOIN TESTS ===== + +# Test 1: Join with filter BEFORE the join (on base tables) +do_execsql_test_on_specific_db {:memory:} matview-join-with-pre-filter { + CREATE TABLE employees(id INTEGER PRIMARY KEY, name TEXT, department TEXT, salary INTEGER); + CREATE TABLE departments(id INTEGER PRIMARY KEY, dept_name TEXT, budget INTEGER); + + INSERT INTO employees VALUES + (1, 'Alice', 'Engineering', 80000), + (2, 'Bob', 'Engineering', 90000), + (3, 'Charlie', 'Sales', 60000), + (4, 'David', 'Sales', 65000), + (5, 'Eve', 'HR', 70000); + + INSERT INTO departments VALUES + (1, 'Engineering', 500000), + (2, 'Sales', 300000), + (3, 'HR', 200000); + + -- View: Join only high-salary employees with their departments + CREATE MATERIALIZED VIEW high_earners_by_dept AS + SELECT e.name, e.salary, d.dept_name, d.budget + FROM employees e + JOIN departments d ON e.department = d.dept_name + WHERE e.salary > 70000; + + SELECT * FROM high_earners_by_dept ORDER BY salary DESC; +} {Bob|90000|Engineering|500000 +Alice|80000|Engineering|500000} + +# Test 2: Join with filter AFTER the join +do_execsql_test_on_specific_db {:memory:} matview-join-with-post-filter { + CREATE TABLE products(id INTEGER PRIMARY KEY, name TEXT, category_id INTEGER, price INTEGER); + CREATE TABLE categories(id INTEGER PRIMARY KEY, name TEXT, min_price INTEGER); + + INSERT INTO products VALUES + (1, 'Laptop', 1, 1200), + (2, 'Mouse', 1, 25), + (3, 'Shirt', 2, 50), + (4, 'Shoes', 2, 120); + + INSERT INTO categories VALUES + (1, 'Electronics', 100), + (2, 'Clothing', 30); + + -- View: Products that meet or exceed their category's minimum price + CREATE MATERIALIZED VIEW premium_products AS + SELECT p.name as product, c.name as category, p.price, c.min_price + FROM products p + JOIN categories c ON p.category_id = c.id + WHERE p.price >= c.min_price; + + SELECT * FROM premium_products ORDER BY price DESC; +} {Laptop|Electronics|1200|100 +Shoes|Clothing|120|30 +Shirt|Clothing|50|30} + +# Test 3: Join with aggregation BEFORE the join +do_execsql_test_on_specific_db {:memory:} matview-aggregation-before-join { + CREATE TABLE orders(id INTEGER PRIMARY KEY, customer_id INTEGER, product_id INTEGER, quantity INTEGER, 
order_date INTEGER); + CREATE TABLE customers(id INTEGER PRIMARY KEY, name TEXT, tier TEXT); + + INSERT INTO orders VALUES + (1, 1, 101, 2, 1), + (2, 1, 102, 1, 1), + (3, 2, 101, 5, 1), + (4, 1, 101, 3, 2), + (5, 2, 103, 2, 2), + (6, 3, 102, 1, 2); + + INSERT INTO customers VALUES + (1, 'Alice', 'Gold'), + (2, 'Bob', 'Silver'), + (3, 'Charlie', 'Bronze'); + + -- View: Customer order counts joined with customer details + -- Note: Simplified to avoid subquery issues with DBSP compiler + CREATE MATERIALIZED VIEW customer_order_summary AS + SELECT c.name, c.tier, COUNT(o.id) as order_count, SUM(o.quantity) as total_quantity + FROM customers c + JOIN orders o ON c.id = o.customer_id + GROUP BY c.id, c.name, c.tier; + + SELECT * FROM customer_order_summary ORDER BY total_quantity DESC; +} {Bob|Silver|2|7 +Alice|Gold|3|6 +Charlie|Bronze|1|1} + +# Test 4: Join with aggregation AFTER the join +do_execsql_test_on_specific_db {:memory:} matview-aggregation-after-join { + CREATE TABLE sales(id INTEGER PRIMARY KEY, product_id INTEGER, store_id INTEGER, units_sold INTEGER, revenue INTEGER); + CREATE TABLE stores(id INTEGER PRIMARY KEY, name TEXT, region TEXT); + + INSERT INTO sales VALUES + (1, 1, 1, 10, 1000), + (2, 1, 2, 15, 1500), + (3, 2, 1, 5, 250), + (4, 2, 2, 8, 400), + (5, 1, 3, 12, 1200), + (6, 2, 3, 6, 300); + + INSERT INTO stores VALUES + (1, 'StoreA', 'North'), + (2, 'StoreB', 'North'), + (3, 'StoreC', 'South'); + + -- View: Regional sales summary (aggregate after joining) + CREATE MATERIALIZED VIEW regional_sales AS + SELECT st.region, SUM(s.units_sold) as total_units, SUM(s.revenue) as total_revenue + FROM sales s + JOIN stores st ON s.store_id = st.id + GROUP BY st.region; + + SELECT * FROM regional_sales ORDER BY total_revenue DESC; +} {North|38|3150 +South|18|1500} + +# Test 5: Modifying both tables in same transaction +do_execsql_test_on_specific_db {:memory:} matview-join-both-tables-modified { + CREATE TABLE authors(id INTEGER PRIMARY KEY, name TEXT); + CREATE TABLE books(id INTEGER PRIMARY KEY, title TEXT, author_id INTEGER, year INTEGER); + + INSERT INTO authors VALUES (1, 'Orwell'), (2, 'Asimov'); + INSERT INTO books VALUES (1, '1984', 1, 1949), (2, 'Foundation', 2, 1951); + + CREATE MATERIALIZED VIEW author_books AS + SELECT a.name, b.title, b.year + FROM authors a + JOIN books b ON a.id = b.author_id; + + SELECT COUNT(*) FROM author_books; + + BEGIN; + INSERT INTO authors VALUES (3, 'Herbert'); + INSERT INTO books VALUES (3, 'Dune', 3, 1965); + SELECT COUNT(*) FROM author_books; + COMMIT; + + SELECT * FROM author_books ORDER BY year; +} {2 +3 +Orwell|1984|1949 +Asimov|Foundation|1951 +Herbert|Dune|1965} + +# Test 6: Modifying only one table in transaction +do_execsql_test_on_specific_db {:memory:} matview-join-single-table-modified { + CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT, active INTEGER); + CREATE TABLE posts(id INTEGER PRIMARY KEY, user_id INTEGER, content TEXT); + + INSERT INTO users VALUES (1, 'Alice', 1), (2, 'Bob', 1), (3, 'Charlie', 0); + INSERT INTO posts VALUES (1, 1, 'Hello'), (2, 1, 'World'), (3, 2, 'Test'); + + CREATE MATERIALIZED VIEW active_user_posts AS + SELECT u.name, p.content + FROM users u + JOIN posts p ON u.id = p.user_id + WHERE u.active = 1; + + SELECT COUNT(*) FROM active_user_posts; + + -- Add posts for existing user (modify only posts table) + BEGIN; + INSERT INTO posts VALUES (4, 1, 'NewPost'), (5, 2, 'Another'); + SELECT COUNT(*) FROM active_user_posts; + COMMIT; + + SELECT * FROM active_user_posts ORDER BY name, content; +} {3 +5 
+Alice|Hello +Alice|NewPost +Alice|World +Bob|Another +Bob|Test} + + +do_execsql_test_on_specific_db {:memory:} matview-three-way-incremental { + CREATE TABLE students(id INTEGER PRIMARY KEY, name TEXT, major TEXT); + CREATE TABLE courses(id INTEGER PRIMARY KEY, name TEXT, department TEXT, credits INTEGER); + CREATE TABLE enrollments(student_id INTEGER, course_id INTEGER, grade TEXT, PRIMARY KEY(student_id, course_id)); + + INSERT INTO students VALUES (1, 'Alice', 'CS'), (2, 'Bob', 'Math'); + INSERT INTO courses VALUES (1, 'DatabaseSystems', 'CS', 3), (2, 'Calculus', 'Math', 4); + INSERT INTO enrollments VALUES (1, 1, 'A'), (2, 2, 'B'); + + CREATE MATERIALIZED VIEW student_transcripts AS + SELECT s.name as student, c.name as course, c.credits, e.grade + FROM students s + JOIN enrollments e ON s.id = e.student_id + JOIN courses c ON e.course_id = c.id; + + SELECT COUNT(*) FROM student_transcripts; + + -- Add new student + INSERT INTO students VALUES (3, 'Charlie', 'CS'); + SELECT COUNT(*) FROM student_transcripts; + + -- Enroll new student + INSERT INTO enrollments VALUES (3, 1, 'A'), (3, 2, 'A'); + SELECT COUNT(*) FROM student_transcripts; + + -- Add new course + INSERT INTO courses VALUES (3, 'Algorithms', 'CS', 3); + SELECT COUNT(*) FROM student_transcripts; + + -- Enroll existing students in new course + INSERT INTO enrollments VALUES (1, 3, 'B'), (3, 3, 'A'); + SELECT COUNT(*) FROM student_transcripts; + + SELECT * FROM student_transcripts ORDER BY student, course; +} {2 +2 +4 +4 +6 +Alice|Algorithms|3|B +Alice|DatabaseSystems|3|A +Bob|Calculus|4|B +Charlie|Algorithms|3|A +Charlie|Calculus|4|A +Charlie|DatabaseSystems|3|A} + +do_execsql_test_on_specific_db {:memory:} matview-self-join { + CREATE TABLE employees(id INTEGER PRIMARY KEY, name TEXT, manager_id INTEGER, salary INTEGER); + + INSERT INTO employees VALUES + (1, 'CEO', NULL, 150000), + (2, 'VPSales', 1, 120000), + (3, 'VPEngineering', 1, 130000), + (4, 'Engineer1', 3, 90000), + (5, 'Engineer2', 3, 85000), + (6, 'SalesRep', 2, 70000); + + CREATE MATERIALIZED VIEW org_chart AS + SELECT e.name as employee, m.name as manager, e.salary + FROM employees e + JOIN employees m ON e.manager_id = m.id; + + SELECT * FROM org_chart ORDER BY salary DESC; +} {VPEngineering|CEO|130000 +VPSales|CEO|120000 +Engineer1|VPEngineering|90000 +Engineer2|VPEngineering|85000 +SalesRep|VPSales|70000} + +do_execsql_test_on_specific_db {:memory:} matview-join-cascade-update { + CREATE TABLE categories(id INTEGER PRIMARY KEY, name TEXT, discount_rate INTEGER); + CREATE TABLE products(id INTEGER PRIMARY KEY, name TEXT, category_id INTEGER, base_price INTEGER); + + INSERT INTO categories VALUES (1, 'Electronics', 10), (2, 'Books', 5); + INSERT INTO products VALUES + (1, 'Laptop', 1, 1000), + (2, 'Phone', 1, 500), + (3, 'Novel', 2, 20), + (4, 'Textbook', 2, 80); + + CREATE MATERIALIZED VIEW discounted_prices AS + SELECT p.name as product, c.name as category, + p.base_price, c.discount_rate, + (p.base_price * (100 - c.discount_rate) / 100) as final_price + FROM products p + JOIN categories c ON p.category_id = c.id; + + SELECT * FROM discounted_prices ORDER BY final_price DESC; + + -- Update discount rate for Electronics + UPDATE categories SET discount_rate = 20 WHERE id = 1; + + SELECT * FROM discounted_prices ORDER BY final_price DESC; +} {Laptop|Electronics|1000|10|900 +Phone|Electronics|500|10|450 +Textbook|Books|80|5|76 +Novel|Books|20|5|19 +Laptop|Electronics|1000|20|800 +Phone|Electronics|500|20|400 +Textbook|Books|80|5|76 +Novel|Books|20|5|19} + 
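+# The remaining tests exercise incremental maintenance when rows are deleted
+# from one side of a join and when the join is filtered by a WHERE clause
+# that references columns from both tables.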
+do_execsql_test_on_specific_db {:memory:} matview-join-delete-cascade { + CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT, active INTEGER); + CREATE TABLE sessions(id INTEGER PRIMARY KEY, user_id INTEGER, duration INTEGER); + + INSERT INTO users VALUES (1, 'Alice', 1), (2, 'Bob', 1), (3, 'Charlie', 0); + INSERT INTO sessions VALUES + (1, 1, 30), + (2, 1, 45), + (3, 2, 60), + (4, 3, 15), + (5, 2, 90); + + CREATE MATERIALIZED VIEW active_sessions AS + SELECT u.name, s.duration + FROM users u + JOIN sessions s ON u.id = s.user_id + WHERE u.active = 1; + + SELECT COUNT(*) FROM active_sessions; + + -- Delete Bob's sessions + DELETE FROM sessions WHERE user_id = 2; + + SELECT COUNT(*) FROM active_sessions; + SELECT * FROM active_sessions ORDER BY name, duration; +} {4 +2 +Alice|30 +Alice|45} + +do_execsql_test_on_specific_db {:memory:} matview-join-complex-where { + CREATE TABLE orders(id INTEGER PRIMARY KEY, customer_id INTEGER, product_id INTEGER, quantity INTEGER, price INTEGER, order_date INTEGER); + CREATE TABLE customers(id INTEGER PRIMARY KEY, name TEXT, tier TEXT, country TEXT); + + INSERT INTO customers VALUES + (1, 'Alice', 'Gold', 'USA'), + (2, 'Bob', 'Silver', 'Canada'), + (3, 'Charlie', 'Gold', 'USA'), + (4, 'David', 'Bronze', 'UK'); + + INSERT INTO orders VALUES + (1, 1, 1, 5, 100, 20240101), + (2, 2, 2, 3, 50, 20240102), + (3, 3, 1, 10, 100, 20240103), + (4, 4, 3, 2, 75, 20240104), + (5, 1, 2, 4, 50, 20240105), + (6, 3, 3, 6, 75, 20240106); + + -- View: Gold tier USA customers with high-value orders + CREATE MATERIALIZED VIEW premium_usa_orders AS + SELECT c.name, o.quantity, o.price, (o.quantity * o.price) as total + FROM customers c + JOIN orders o ON c.id = o.customer_id + WHERE c.tier = 'Gold' + AND c.country = 'USA' + AND (o.quantity * o.price) >= 400; + + SELECT * FROM premium_usa_orders ORDER by total DESC; +} {Charlie|10|100|1000 +Alice|5|100|500 +Charlie|6|75|450} From f2f7f817e4998e2a755290d10b63e83739925f6c Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Thu, 18 Sep 2025 16:03:37 -0500 Subject: [PATCH 26/34] populate all tables in IncrementalView For joins to work, we have to populate all referenced tables when we create the view. --- core/incremental/view.rs | 444 +++++++++++++++++++++++---------------- 1 file changed, 266 insertions(+), 178 deletions(-) diff --git a/core/incremental/view.rs b/core/incremental/view.rs index 9f5fe52cb..fd7b3988a 100644 --- a/core/incremental/view.rs +++ b/core/incremental/view.rs @@ -22,8 +22,15 @@ use turso_parser::{ pub enum PopulateState { /// Initial state - need to prepare the query Start, + /// All tables that need to be populated + ProcessingAllTables { + queries: Vec, + current_idx: usize, + }, /// Actively processing rows from the query - Processing { + ProcessingOneTable { + queries: Vec, + current_idx: usize, stmt: Box, rows_processed: usize, /// If we're in the middle of processing a row (merge_delta returned I/O) @@ -38,14 +45,26 @@ impl fmt::Debug for PopulateState { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { PopulateState::Start => write!(f, "Start"), - PopulateState::Processing { + PopulateState::ProcessingAllTables { + current_idx, + queries, + } => f + .debug_struct("ProcessingAllTables") + .field("current_idx", current_idx) + .field("num_queries", &queries.len()) + .finish(), + PopulateState::ProcessingOneTable { + current_idx, rows_processed, pending_row, + queries, .. 
} => f - .debug_struct("Processing") + .debug_struct("ProcessingOneTable") + .field("current_idx", current_idx) .field("rows_processed", rows_processed) .field("has_pending", &pending_row.is_some()) + .field("total_queries", &queries.len()) .finish(), PopulateState::Done => write!(f, "Done"), } @@ -604,11 +623,19 @@ impl IncrementalView { } _ => { // For comparison operators, check if this condition references only our table + // AND is simple enough to be pushed down (no complex expressions) let referenced_tables = self.get_referenced_tables_in_expr(expr)?; if referenced_tables.len() == 1 && referenced_tables.contains(&table_name.to_string()) { - Ok(Some(expr.clone())) + // Check if this is a simple comparison that can be pushed down + // Complex expressions like (a * b) >= c should be handled by the circuit + if self.is_simple_comparison(expr) { + Ok(Some(expr.clone())) + } else { + // Complex expression - let the circuit handle it + Ok(None) + } } else { Ok(None) } @@ -624,9 +651,11 @@ impl IncrementalView { } _ => { // For other expressions, check if they reference only our table + // AND are simple enough to be pushed down let referenced_tables = self.get_referenced_tables_in_expr(expr)?; if referenced_tables.len() == 1 && referenced_tables.contains(&table_name.to_string()) + && self.is_simple_comparison(expr) { Ok(Some(expr.clone())) } else { @@ -636,6 +665,39 @@ impl IncrementalView { } } + /// Check if an expression is a simple comparison that can be pushed down to table scan + /// Returns true for simple comparisons like "column = value" or "column > value" + /// Returns false for complex expressions like "(a * b) > value" + fn is_simple_comparison(&self, expr: &ast::Expr) -> bool { + match expr { + ast::Expr::Binary(left, op, right) => { + // Check if it's a comparison operator + matches!( + op, + ast::Operator::Equals + | ast::Operator::NotEquals + | ast::Operator::Greater + | ast::Operator::GreaterEquals + | ast::Operator::Less + | ast::Operator::LessEquals + ) && self.is_simple_operand(left) + && self.is_simple_operand(right) + } + _ => false, + } + } + + /// Check if an operand is simple (column reference or literal) + fn is_simple_operand(&self, expr: &ast::Expr) -> bool { + matches!( + expr, + ast::Expr::Id(_) + | ast::Expr::Qualified(_, _) + | ast::Expr::DoublyQualified(_, _, _) + | ast::Expr::Literal(_) + ) + } + /// Get the set of table names referenced in an expression fn get_referenced_tables_in_expr(&self, expr: &ast::Expr) -> crate::Result> { let mut tables = Vec::new(); @@ -820,216 +882,242 @@ impl IncrementalView { pager: &std::sync::Arc, _btree_cursor: &mut BTreeCursor, ) -> crate::Result> { - // If already populated, return immediately - if matches!(self.populate_state, PopulateState::Done) { - return Ok(IOResult::Done(())); - } - // Assert that this is a materialized view with a root page assert!( self.root_page != 0, "populate_from_table should only be called for materialized views with root_page" ); - loop { - // To avoid borrow checker issues, we need to handle state transitions carefully - let needs_start = matches!(self.populate_state, PopulateState::Start); + 'outer: loop { + match std::mem::replace(&mut self.populate_state, PopulateState::Done) { + PopulateState::Start => { + // Generate the SQL query for populating the view + // It is best to use a standard query than a cursor for two reasons: + // 1) Using a sql query will allow us to be much more efficient in cases where we only want + // some rows, in particular for indexed filters + // 2) There are 
two types of cursors: index and table. In some situations (like for example + // if the table has an integer primary key), the key will be exclusively in the index + // btree and not in the table btree. Using cursors would force us to be aware of this + // distinction (and others), and ultimately lead to reimplementing the whole query + // machinery (next step is which index is best to use, etc) + let queries = self.sql_for_populate()?; - if needs_start { - // Generate the SQL query for populating the view - // It is best to use a standard query than a cursor for two reasons: - // 1) Using a sql query will allow us to be much more efficient in cases where we only want - // some rows, in particular for indexed filters - // 2) There are two types of cursors: index and table. In some situations (like for example - // if the table has an integer primary key), the key will be exclusively in the index - // btree and not in the table btree. Using cursors would force us to be aware of this - // distinction (and others), and ultimately lead to reimplementing the whole query - // machinery (next step is which index is best to use, etc) - let queries = self.sql_for_populate()?; - - // For now, only use the first query (single table population) - if queries.is_empty() { - return Err(LimboError::ParseError( - "No populate queries generated".to_string(), - )); + self.populate_state = PopulateState::ProcessingAllTables { + queries, + current_idx: 0, + }; } - let query = queries[0].clone(); - // Create a new connection for reading to avoid transaction conflicts - // This allows us to read from tables while the parent transaction is writing the view - // The statement holds a reference to this connection, keeping it alive - let read_conn = conn.db.connect()?; - - // Prepare the statement using the read connection - let stmt = read_conn.prepare(&query)?; - - self.populate_state = PopulateState::Processing { - stmt: Box::new(stmt), - rows_processed: 0, - pending_row: None, - }; - // Continue to next state - continue; - } - - // Handle Done state - if matches!(self.populate_state, PopulateState::Done) { - return Ok(IOResult::Done(())); - } - - // Handle Processing state - extract state to avoid borrow issues - let (mut stmt, mut rows_processed, pending_row) = - match std::mem::replace(&mut self.populate_state, PopulateState::Done) { - PopulateState::Processing { - stmt, - rows_processed, - pending_row, - } => (stmt, rows_processed, pending_row), - _ => unreachable!("We already handled Start and Done states"), - }; - - // If we have a pending row from a previous I/O interruption, process it first - if let Some((rowid, values)) = pending_row { - // Create a single-row delta for the pending row - let mut single_row_delta = Delta::new(); - single_row_delta.insert(rowid, values.clone()); - - // Create a DeltaSet with this delta for the first table (for now) - let mut delta_set = DeltaSet::new(); - // TODO: When we support JOINs, determine which table this row came from - delta_set.insert(self.referenced_tables[0].name.clone(), single_row_delta); - - // Process the pending row with the pager - match self.merge_delta(delta_set, pager.clone())? 
{ - IOResult::Done(_) => { - // Row processed successfully, continue to next row - rows_processed += 1; - // Continue to fetch next row from statement - } - IOResult::IO(io) => { - // Still not done, save state with pending row - self.populate_state = PopulateState::Processing { - stmt, - rows_processed, - pending_row: Some((rowid, values)), // Keep the pending row - }; - return Ok(IOResult::IO(io)); + PopulateState::ProcessingAllTables { + queries, + current_idx, + } => { + if current_idx >= queries.len() { + self.populate_state = PopulateState::Done; + return Ok(IOResult::Done(())); } + + let query = queries[current_idx].clone(); + // Create a new connection for reading to avoid transaction conflicts + // This allows us to read from tables while the parent transaction is writing the view + // The statement holds a reference to this connection, keeping it alive + let read_conn = conn.db.connect()?; + + // Prepare the statement using the read connection + let stmt = read_conn.prepare(&query)?; + + self.populate_state = PopulateState::ProcessingOneTable { + queries, + current_idx, + stmt: Box::new(stmt), + rows_processed: 0, + pending_row: None, + }; } - } - // Process rows one at a time - no batching - loop { - // This step() call resumes from where the statement left off - match stmt.step()? { - crate::vdbe::StepResult::Row => { - // Get the row - let row = stmt.row().unwrap(); - - // Extract values from the row - let all_values: Vec = - row.get_values().cloned().collect(); - - // Determine how to extract the rowid - // If there's a rowid alias (INTEGER PRIMARY KEY), the rowid is one of the columns - // Otherwise, it's the last value we explicitly selected - let (rowid, values) = if let Some((idx, _)) = - self.referenced_tables[0].get_rowid_alias_column() - { - // The rowid is the value at the rowid alias column index - let rowid = match all_values.get(idx) { - Some(crate::types::Value::Integer(id)) => *id, - _ => { - // This shouldn't happen - rowid alias must be an integer - rows_processed += 1; - continue; - } - }; - // All values are table columns (no separate rowid was selected) - (rowid, all_values) - } else { - // The last value is the explicitly selected rowid - let rowid = match all_values.last() { - Some(crate::types::Value::Integer(id)) => *id, - _ => { - // This shouldn't happen - rowid must be an integer - rows_processed += 1; - continue; - } - }; - // Get all values except the rowid - let values = all_values[..all_values.len() - 1].to_vec(); - (rowid, values) - }; - - // Create a single-row delta and process it immediately - let mut single_row_delta = Delta::new(); - single_row_delta.insert(rowid, values.clone()); - - // Create a DeltaSet with this delta for the first table (for now) - let mut delta_set = DeltaSet::new(); - // TODO: When we support JOINs, determine which table this row came from - delta_set.insert(self.referenced_tables[0].name.clone(), single_row_delta); - - // Process this single row through merge_delta with the pager - match self.merge_delta(delta_set, pager.clone())? { + PopulateState::ProcessingOneTable { + queries, + current_idx, + mut stmt, + mut rows_processed, + pending_row, + } => { + // If we have a pending row from a previous I/O interruption, process it first + if let Some((rowid, values)) = pending_row { + match self.process_one_row( + rowid, + values.clone(), + current_idx, + pager.clone(), + )? 
{ IOResult::Done(_) => { // Row processed successfully, continue to next row rows_processed += 1; } IOResult::IO(io) => { - // Save state and return I/O - // We'll resume at the SAME row when called again (don't increment rows_processed) - // The circuit still has unfinished work for this row - self.populate_state = PopulateState::Processing { + // Still not done, restore state with pending row and return + self.populate_state = PopulateState::ProcessingOneTable { + queries, + current_idx, stmt, - rows_processed, // Don't increment - row not done yet! - pending_row: Some((rowid, values)), // Save the row for resumption + rows_processed, + pending_row: Some((rowid, values)), }; return Ok(IOResult::IO(io)); } } } - crate::vdbe::StepResult::Done => { - // All rows processed, we're done - self.populate_state = PopulateState::Done; - return Ok(IOResult::Done(())); - } + // Process rows one at a time - no batching + loop { + // This step() call resumes from where the statement left off + match stmt.step()? { + crate::vdbe::StepResult::Row => { + // Get the row + let row = stmt.row().unwrap(); - crate::vdbe::StepResult::Interrupt | crate::vdbe::StepResult::Busy => { - // Save state before returning error - self.populate_state = PopulateState::Processing { - stmt, - rows_processed, - pending_row: None, // No pending row when interrupted between rows - }; - return Err(LimboError::Busy); - } + // Extract values from the row + let all_values: Vec = + row.get_values().cloned().collect(); - crate::vdbe::StepResult::IO => { - // Statement needs I/O - save state and return - self.populate_state = PopulateState::Processing { - stmt, - rows_processed, - pending_row: None, // No pending row when interrupted between rows - }; - // TODO: Get the actual I/O completion from the statement - let completion = crate::io::Completion::new_dummy(); - return Ok(IOResult::IO(crate::types::IOCompletions::Single( - completion, - ))); + // Extract rowid and values using helper + let (rowid, values) = + match self.extract_rowid_and_values(all_values, current_idx) { + Some(result) => result, + None => { + // Invalid rowid, skip this row + rows_processed += 1; + continue; + } + }; + + // Process this row + match self.process_one_row( + rowid, + values.clone(), + current_idx, + pager.clone(), + )? { + IOResult::Done(_) => { + // Row processed successfully, continue to next row + rows_processed += 1; + } + IOResult::IO(io) => { + // Save state and return I/O + // We'll resume at the SAME row when called again (don't increment rows_processed) + // The circuit still has unfinished work for this row + self.populate_state = PopulateState::ProcessingOneTable { + queries, + current_idx, + stmt, + rows_processed, // Don't increment - row not done yet! 
+ pending_row: Some((rowid, values)), // Save the row for resumption + }; + return Ok(IOResult::IO(io)); + } + } + } + + crate::vdbe::StepResult::Done => { + // All rows processed from this table + // Move to next table + self.populate_state = PopulateState::ProcessingAllTables { + queries, + current_idx: current_idx + 1, + }; + continue 'outer; + } + + crate::vdbe::StepResult::Interrupt | crate::vdbe::StepResult::Busy => { + // Save state before returning error + self.populate_state = PopulateState::ProcessingOneTable { + queries, + current_idx, + stmt, + rows_processed, + pending_row: None, // No pending row when interrupted between rows + }; + return Err(LimboError::Busy); + } + + crate::vdbe::StepResult::IO => { + // Statement needs I/O - save state and return + self.populate_state = PopulateState::ProcessingOneTable { + queries, + current_idx, + stmt, + rows_processed, + pending_row: None, // No pending row when interrupted between rows + }; + // TODO: Get the actual I/O completion from the statement + let completion = crate::io::Completion::new_dummy(); + return Ok(IOResult::IO(crate::types::IOCompletions::Single( + completion, + ))); + } + } } } + + PopulateState::Done => { + return Ok(IOResult::Done(())); + } } } } + /// Process a single row through the circuit + fn process_one_row( + &mut self, + rowid: i64, + values: Vec, + table_idx: usize, + pager: Arc, + ) -> crate::Result> { + // Create a single-row delta + let mut single_row_delta = Delta::new(); + single_row_delta.insert(rowid, values); + + // Create a DeltaSet with this delta for the current table + let mut delta_set = DeltaSet::new(); + let table_name = self.referenced_tables[table_idx].name.clone(); + delta_set.insert(table_name, single_row_delta); + + // Process through merge_delta + self.merge_delta(delta_set, pager) + } + + /// Extract rowid and values from a row + fn extract_rowid_and_values( + &self, + all_values: Vec, + table_idx: usize, + ) -> Option<(i64, Vec)> { + if let Some((idx, _)) = self.referenced_tables[table_idx].get_rowid_alias_column() { + // The rowid is the value at the rowid alias column index + let rowid = match all_values.get(idx) { + Some(Value::Integer(id)) => *id, + _ => return None, // Invalid rowid + }; + // All values are table columns (no separate rowid was selected) + Some((rowid, all_values)) + } else { + // The last value is the explicitly selected rowid + let rowid = match all_values.last() { + Some(Value::Integer(id)) => *id, + _ => return None, // Invalid rowid + }; + // Get all values except the rowid + let values = all_values[..all_values.len() - 1].to_vec(); + Some((rowid, values)) + } + } + /// Merge a delta set of changes into the view's current state pub fn merge_delta( &mut self, delta_set: DeltaSet, - pager: std::sync::Arc, + pager: Arc, ) -> crate::Result> { // Early return if all deltas are empty if delta_set.is_empty() { From e1ed12b2841144ed4d1f7f0548948bad405bdb5c Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Fri, 19 Sep 2025 05:20:20 -0400 Subject: [PATCH 27/34] rm claude comment --- core/translate/expr.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 56649f715..38eca1ae4 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -3300,7 +3300,6 @@ pub fn bind_and_rewrite_expr<'a>( ast::Name::Ident(normalize_ident(c.as_str())), ); } - // Expand BETWEEN to binary ops (kept identical to your logic). 
ast::Expr::Between { lhs, not, From d5295fb45cb11ec5d8866db39ba5b38b3cae22b9 Mon Sep 17 00:00:00 2001 From: Avinash Sajjanshetty Date: Fri, 19 Sep 2025 14:55:02 +0530 Subject: [PATCH 28/34] Put the unused variable behind a flag as intended --- core/storage/btree.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/storage/btree.rs b/core/storage/btree.rs index e9c888e80..9673eef67 100644 --- a/core/storage/btree.rs +++ b/core/storage/btree.rs @@ -3631,6 +3631,7 @@ impl BTreeCursor { ); let divider_cell_insert_idx_in_parent = balance_info.first_divider_cell + sibling_page_idx; + #[cfg(debug_assertions)] let overflow_cell_count_before = parent_contents.overflow_cells.len(); insert_into_cell( parent_contents, @@ -3638,9 +3639,9 @@ impl BTreeCursor { divider_cell_insert_idx_in_parent, usable_space, )?; - let overflow_cell_count_after = parent_contents.overflow_cells.len(); #[cfg(debug_assertions)] { + let overflow_cell_count_after = parent_contents.overflow_cells.len(); let divider_cell_is_overflow_cell = overflow_cell_count_after > overflow_cell_count_before; From ba7ae50eff40bad4e0494ef23cc48ded40fe0028 Mon Sep 17 00:00:00 2001 From: Jussi Saurio Date: Fri, 19 Sep 2025 12:55:27 +0300 Subject: [PATCH 29/34] mvcc: remove unused code related to is_logical_log() is always logical log --- core/mvcc/database/mod.rs | 107 +++------------------------- core/mvcc/persistent_storage/mod.rs | 4 -- core/vdbe/execute.rs | 12 +--- 3 files changed, 10 insertions(+), 113 deletions(-) diff --git a/core/mvcc/database/mod.rs b/core/mvcc/database/mod.rs index 54cd1d0cf..547cd9e95 100644 --- a/core/mvcc/database/mod.rs +++ b/core/mvcc/database/mod.rs @@ -1,6 +1,5 @@ use crate::mvcc::clock::LogicalClock; use crate::mvcc::persistent_storage::Storage; -use crate::return_if_io; use crate::state_machine::StateMachine; use crate::state_machine::StateTransition; use crate::state_machine::TransitionResult; @@ -542,24 +541,11 @@ impl StateTransition for CommitStateMachine { if mvcc_store.is_exclusive_tx(&self.tx_id) { mvcc_store.release_exclusive_tx(&self.tx_id); self.commit_coordinator.pager_commit_lock.unlock(); - if !mvcc_store.storage.is_logical_log() { - // FIXME: this function isnt re-entrant - self.pager - .io - .block(|| self.pager.end_tx(false, &self.connection))?; - } - } else if !mvcc_store.storage.is_logical_log() { - self.pager.end_read_tx()?; } self.finalize(mvcc_store)?; return Ok(TransitionResult::Done(())); } - if mvcc_store.storage.is_logical_log() { - self.state = CommitState::Commit { end_ts }; - return Ok(TransitionResult::Continue); - } else { - self.state = CommitState::BeginPagerTxn { end_ts }; - } + self.state = CommitState::Commit { end_ts }; Ok(TransitionResult::Continue) } CommitState::BeginPagerTxn { end_ts } => { @@ -851,7 +837,6 @@ impl StateTransition for CommitStateMachine { return Ok(TransitionResult::Continue); } CommitState::BeginCommitLogicalLog { end_ts, log_record } => { - assert!(mvcc_store.storage.is_logical_log()); if !mvcc_store.is_exclusive_tx(&self.tx_id) { // logical log needs to be serialized let locked = self.commit_coordinator.pager_commit_lock.write(); @@ -866,10 +851,6 @@ impl StateTransition for CommitStateMachine { match result { IOResult::Done(_) => {} IOResult::IO(io) => { - assert!( - mvcc_store.storage.is_logical_log(), - "for now logical log is the only storage that can return IO" - ); if !io.finished() { return Ok(TransitionResult::Io(io)); } @@ -897,13 +878,11 @@ impl StateTransition for CommitStateMachine { let schema = 
connection.schema.borrow().clone(); connection.db.update_schema_if_newer(schema)?; } - if mvcc_store.storage.is_logical_log() { - let tx = mvcc_store.txs.get(&self.tx_id).unwrap(); - let tx_unlocked = tx.value(); - self.header.write().replace(*tx_unlocked.header.borrow()); - tracing::trace!("end_commit_logical_log(tx_id={})", self.tx_id); - self.commit_coordinator.pager_commit_lock.unlock(); - } + let tx = mvcc_store.txs.get(&self.tx_id).unwrap(); + let tx_unlocked = tx.value(); + self.header.write().replace(*tx_unlocked.header.borrow()); + tracing::trace!("end_commit_logical_log(tx_id={})", self.tx_id); + self.commit_coordinator.pager_commit_lock.unlock(); self.state = CommitState::CommitEnd { end_ts: *end_ts }; return Ok(TransitionResult::Continue); } @@ -1422,38 +1401,12 @@ impl MvStore { /// /// This is used for IMMEDIATE and EXCLUSIVE transaction types where we need /// to ensure exclusive write access as per SQLite semantics. + #[instrument(skip_all, level = Level::DEBUG)] pub fn begin_exclusive_tx( &self, pager: Arc, maybe_existing_tx_id: Option, ) -> Result> { - self._begin_exclusive_tx(pager, false, maybe_existing_tx_id) - } - - /// Upgrades a read transaction to an exclusive write transaction. - /// - /// This is used for IMMEDIATE and EXCLUSIVE transaction types where we need - /// to ensure exclusive write access as per SQLite semantics. - pub fn upgrade_to_exclusive_tx( - &self, - pager: Arc, - maybe_existing_tx_id: Option, - ) -> Result> { - self._begin_exclusive_tx(pager, true, maybe_existing_tx_id) - } - - /// Begins an exclusive write transaction that prevents concurrent writes. - /// - /// This is used for IMMEDIATE and EXCLUSIVE transaction types where we need - /// to ensure exclusive write access as per SQLite semantics. - #[instrument(skip_all, level = Level::DEBUG)] - fn _begin_exclusive_tx( - &self, - pager: Arc, - is_upgrade_from_read: bool, - maybe_existing_tx_id: Option, - ) -> Result> { - let is_logical_log = self.storage.is_logical_log(); let tx_id = maybe_existing_tx_id.unwrap_or_else(|| self.get_tx_id()); let begin_ts = if let Some(tx_id) = maybe_existing_tx_id { self.txs.get(&tx_id).unwrap().value().begin_ts @@ -1463,16 +1416,6 @@ impl MvStore { self.acquire_exclusive_tx(&tx_id)?; - // Try to acquire the pager read lock - if !is_upgrade_from_read && !is_logical_log { - pager.begin_read_tx().inspect_err(|_| { - tracing::debug!( - "begin_exclusive_tx: tx_id={} failed with Busy on pager_read_lock", - tx_id - ); - self.release_exclusive_tx(&tx_id); - })?; - } let locked = self.commit_coordinator.pager_commit_lock.write(); if !locked { tracing::debug!( @@ -1480,46 +1423,18 @@ impl MvStore { tx_id ); self.release_exclusive_tx(&tx_id); - pager.end_read_tx()?; return Err(LimboError::Busy); } let header = self.get_new_transaction_database_header(&pager); - if is_logical_log { - let tx = Transaction::new(tx_id, begin_ts, header); - tracing::trace!( - "begin_exclusive_tx(tx_id={}) - exclusive write logical log transaction", - tx_id - ); - tracing::debug!("begin_exclusive_tx: tx_id={} succeeded", tx_id); - self.txs.insert(tx_id, tx); - return Ok(IOResult::Done(tx_id)); - } - // Try to acquire the pager write lock - let begin_w_tx_res = pager.begin_write_tx(); - if let Err(LimboError::Busy) = begin_w_tx_res { - tracing::debug!("begin_exclusive_tx: tx_id={} failed with Busy", tx_id); - // Failed to get pager lock - release our exclusive lock - self.commit_coordinator.pager_commit_lock.unlock(); - self.release_exclusive_tx(&tx_id); - if maybe_existing_tx_id.is_none() { - 
// If we were upgrading an existing non-CONCURRENT mvcc transaction to write, we don't end the read tx on Busy. - // But if we were beginning a completely new non-CONCURRENT mvcc transaction, we do end it because the next time the connection - // attempts to do something, it will open a new read tx, which will fail if we don't end this one here. - pager.end_read_tx()?; - } - return Err(LimboError::Busy); - } - return_if_io!(begin_w_tx_res); let tx = Transaction::new(tx_id, begin_ts, header); tracing::trace!( - "begin_exclusive_tx(tx_id={}) - exclusive write transaction", + "begin_exclusive_tx(tx_id={}) - exclusive write logical log transaction", tx_id ); tracing::debug!("begin_exclusive_tx: tx_id={} succeeded", tx_id); self.txs.insert(tx_id, tx); - Ok(IOResult::Done(tx_id)) } @@ -1532,12 +1447,6 @@ impl MvStore { let tx_id = self.get_tx_id(); let begin_ts = self.get_timestamp(); - // TODO: we need to tie a pager's read transaction to a transaction ID, so that future refactors to read - // pages from WAL/DB read from a consistent state to maintiain snapshot isolation. - if !self.storage.is_logical_log() { - pager.begin_read_tx()?; - } - // Set txn's header to the global header let header = self.get_new_transaction_database_header(&pager); let tx = Transaction::new(tx_id, begin_ts, header); diff --git a/core/mvcc/persistent_storage/mod.rs b/core/mvcc/persistent_storage/mod.rs index b92bf081e..ac12e77c4 100644 --- a/core/mvcc/persistent_storage/mod.rs +++ b/core/mvcc/persistent_storage/mod.rs @@ -29,10 +29,6 @@ impl Storage { todo!() } - pub fn is_logical_log(&self) -> bool { - true - } - pub fn sync(&self) -> Result> { self.logical_log.borrow_mut().sync() } diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index 16f32ba79..5099f7831 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -2276,16 +2276,8 @@ pub fn op_transaction_inner( if matches!(new_transaction_state, TransactionState::Write { .. }) && matches!(actual_tx_mode, TransactionMode::Write) { - let (tx_id, mv_tx_mode) = program.connection.mv_tx.get().unwrap(); - if mv_tx_mode == TransactionMode::Read { - return_if_io!( - mv_store.upgrade_to_exclusive_tx(pager.clone(), Some(tx_id)) - ); - } else { - return_if_io!( - mv_store.begin_exclusive_tx(pager.clone(), Some(tx_id)) - ); - } + let (tx_id, _) = program.connection.mv_tx.get().unwrap(); + return_if_io!(mv_store.begin_exclusive_tx(pager.clone(), Some(tx_id))); } } } else { From 8300d0390e0b92cbf518687025f1b990ec0d3e26 Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Fri, 19 Sep 2025 05:59:46 -0500 Subject: [PATCH 30/34] prevent alter table with materialized views I don't want to even think about the complexity involved in making sure that materialized views are still sane after the base table(s) are altered. 
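
A minimal sketch of the resulting behavior (table and view names here are
hypothetical, not taken from the test suite):

    CREATE TABLE t(a INTEGER);
    CREATE MATERIALIZED VIEW v AS SELECT a FROM t;
    ALTER TABLE t ADD COLUMN b;
    -- fails with: cannot alter table "t": it has dependent materialized view(s): v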
--- core/translate/alter.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/core/translate/alter.rs b/core/translate/alter.rs index e3ecf43ea..d3170fdbe 100644 --- a/core/translate/alter.rs +++ b/core/translate/alter.rs @@ -48,6 +48,15 @@ pub fn translate_alter_table( ))); }; + // Check if this table has dependent materialized views + let dependent_views = schema.get_dependent_materialized_views(table_name); + if !dependent_views.is_empty() { + return Err(LimboError::ParseError(format!( + "cannot alter table \"{table_name}\": it has dependent materialized view(s): {}", + dependent_views.join(", ") + ))); + } + let mut btree = (*original_btree).clone(); Ok(match alter_table { From c63c820bb7731c446be314a5cd1d66f6f92cb7e6 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Fri, 19 Sep 2025 16:48:12 +0400 Subject: [PATCH 31/34] add busy_timeout pragma --- bindings/rust/src/lib.rs | 2 +- core/lib.rs | 6 +++++- core/pragma.rs | 4 ++++ core/translate/pragma.rs | 27 ++++++++++++++++++++++++++- parser/src/ast.rs | 2 ++ 5 files changed, 38 insertions(+), 3 deletions(-) diff --git a/bindings/rust/src/lib.rs b/bindings/rust/src/lib.rs index 39706fdd5..15ae191f7 100644 --- a/bindings/rust/src/lib.rs +++ b/bindings/rust/src/lib.rs @@ -413,7 +413,7 @@ impl Connection { .inner .lock() .map_err(|e| Error::MutexError(e.to_string()))?; - conn.busy_timeout(duration); + conn.set_busy_timeout(duration); Ok(()) } } diff --git a/core/lib.rs b/core/lib.rs index ae8fb31ab..79999a7bb 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -2174,10 +2174,14 @@ impl Connection { /// 5. Step through query -> returns Busy -> return Busy to user /// /// This slight api change demonstrated a better throughtput in `perf/throughput/turso` benchmark - pub fn busy_timeout(&self, mut duration: Option) { + pub fn set_busy_timeout(&self, mut duration: Option) { duration = duration.filter(|duration| !duration.is_zero()); self.busy_timeout.set(duration); } + + pub fn get_busy_timeout(&self) -> Option { + self.busy_timeout.get() + } } #[derive(Debug, Default)] diff --git a/core/pragma.rs b/core/pragma.rs index edcfd21b9..c83509a69 100644 --- a/core/pragma.rs +++ b/core/pragma.rs @@ -102,6 +102,10 @@ pub fn pragma_for(pragma: &PragmaName) -> Pragma { PragmaFlags::NoColumns1 | PragmaFlags::Result0, &["auto_vacuum"], ), + BusyTimeout => Pragma::new( + PragmaFlags::NoColumns1 | PragmaFlags::Result0, + &["busy_timeout"], + ), IntegrityCheck => Pragma::new( PragmaFlags::NeedSchema | PragmaFlags::ReadOnly | PragmaFlags::Result0, &["message"], diff --git a/core/translate/pragma.rs b/core/translate/pragma.rs index 7fa74e9ca..fa4274ed3 100644 --- a/core/translate/pragma.rs +++ b/core/translate/pragma.rs @@ -99,7 +99,7 @@ fn update_pragma( let app_id_value = match data { Value::Integer(i) => i as i32, Value::Float(f) => f as i32, - _ => unreachable!(), + _ => bail_parse_error!("expected integer, got {:?}", data), }; program.emit_insn(Insn::SetCookie { @@ -110,6 +110,19 @@ fn update_pragma( }); Ok((program, TransactionMode::Write)) } + PragmaName::BusyTimeout => { + let data = parse_signed_number(&value)?; + let busy_timeout_ms = match data { + Value::Integer(i) => i as i32, + Value::Float(f) => f as i32, + _ => bail_parse_error!("expected integer, got {:?}", data), + }; + let busy_timeout_ms = busy_timeout_ms.max(0); + connection.set_busy_timeout(Some(std::time::Duration::from_millis( + busy_timeout_ms as u64, + ))); + Ok((program, TransactionMode::Write)) + } PragmaName::CacheSize => { let cache_size = match parse_signed_number(&value)? 
{ Value::Integer(size) => size, @@ -388,6 +401,18 @@ fn query_pragma( program.emit_result_row(register, 1); Ok((program, TransactionMode::Read)) } + PragmaName::BusyTimeout => { + program.emit_int( + connection + .get_busy_timeout() + .map(|t| t.as_millis() as i64) + .unwrap_or_default(), + register, + ); + program.emit_result_row(register, 1); + program.add_pragma_result_column(pragma.to_string()); + Ok((program, TransactionMode::None)) + } PragmaName::CacheSize => { program.emit_int(connection.get_cache_size() as i64, register); program.emit_result_row(register, 1); diff --git a/parser/src/ast.rs b/parser/src/ast.rs index 3c331107a..988bbab95 100644 --- a/parser/src/ast.rs +++ b/parser/src/ast.rs @@ -1312,6 +1312,8 @@ pub enum PragmaName { ApplicationId, /// set the autovacuum mode AutoVacuum, + /// set the busy_timeout + BusyTimeout, /// `cache_size` pragma CacheSize, /// encryption cipher algorithm name for encrypted databases From 57e52077be6efc35e3107ff9de76aad2aee82373 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Fri, 19 Sep 2025 16:48:43 +0400 Subject: [PATCH 32/34] add link to the docs --- parser/src/ast.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parser/src/ast.rs b/parser/src/ast.rs index 988bbab95..ed58d0bba 100644 --- a/parser/src/ast.rs +++ b/parser/src/ast.rs @@ -1312,7 +1312,7 @@ pub enum PragmaName { ApplicationId, /// set the autovacuum mode AutoVacuum, - /// set the busy_timeout + /// set the busy_timeout (see https://www.sqlite.org/pragma.html#pragma_busy_timeout) BusyTimeout, /// `cache_size` pragma CacheSize, From 0597ea722ae0da9914e67f752fabe1510e4eff09 Mon Sep 17 00:00:00 2001 From: Avinash Sajjanshetty Date: Sat, 20 Sep 2025 21:56:58 +0530 Subject: [PATCH 33/34] Add encryption throughput test --- Cargo.lock | 13 + Cargo.toml | 2 + bindings/rust/Cargo.toml | 1 + perf/encryption/Cargo.toml | 17 ++ perf/encryption/README.md | 28 +++ perf/encryption/src/main.rs | 457 ++++++++++++++++++++++++++++++++++++ 6 files changed, 518 insertions(+) create mode 100644 perf/encryption/Cargo.toml create mode 100644 perf/encryption/README.md create mode 100644 perf/encryption/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index 9e2bb330d..17a1d956d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1100,6 +1100,19 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +[[package]] +name = "encryption-throughput" +version = "0.1.0" +dependencies = [ + "clap", + "futures", + "hex", + "rand 0.9.2", + "tokio", + "tracing-subscriber", + "turso", +] + [[package]] name = "endian-type" version = "0.1.2" diff --git a/Cargo.toml b/Cargo.toml index 2771c2a31..aff912890 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,7 +32,9 @@ members = [ "whopper", "perf/throughput/turso", "perf/throughput/rusqlite", + "perf/encryption" ] + exclude = [ "perf/latency/limbo", ] diff --git a/bindings/rust/Cargo.toml b/bindings/rust/Cargo.toml index d799b5320..e50304f01 100644 --- a/bindings/rust/Cargo.toml +++ b/bindings/rust/Cargo.toml @@ -15,6 +15,7 @@ conn_raw_api = ["turso_core/conn_raw_api"] experimental_indexes = [] antithesis = ["turso_core/antithesis"] tracing_release = ["turso_core/tracing_release"] +encryption = ["turso_core/encryption"] [dependencies] turso_core = { workspace = true, features = ["io_uring"] } diff --git a/perf/encryption/Cargo.toml b/perf/encryption/Cargo.toml new file mode 100644 index 000000000..e769c5b0b --- /dev/null +++ 
b/perf/encryption/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "encryption-throughput" +version = "0.1.0" +edition = "2021" + +[[bin]] +name = "encryption-throughput" +path = "src/main.rs" + +[dependencies] +turso = { workspace = true, features = ["encryption"] } +clap = { workspace = true, features = ["derive"] } +tokio = { workspace = true, default-features = true, features = ["full"] } +futures = { workspace = true } +tracing-subscriber = { workspace = true } +rand = { workspace = true, features = ["small_rng"] } +hex = { workspace = true } \ No newline at end of file diff --git a/perf/encryption/README.md b/perf/encryption/README.md new file mode 100644 index 000000000..0ec611258 --- /dev/null +++ b/perf/encryption/README.md @@ -0,0 +1,28 @@ +# Encryption Throughput Benchmarking + +```shell +$ cargo run --release -- --help + +Usage: encryption-throughput [OPTIONS] + +Options: + -t, --threads [default: 1] + -b, --batch-size [default: 100] + -i, --iterations [default: 10] + -r, --read-ratio Percentage of operations that should be reads (0-100) + -w, --write-ratio Percentage of operations that should be writes (0-100) + --encryption Enable database encryption + --cipher Encryption cipher to use (only relevant if --encryption is set) [default: aegis-256] + --think Per transaction think time (ms) [default: 0] + --timeout Busy timeout in milliseconds [default: 30000] + --seed Random seed for reproducible workloads [default: 2167532792061351037] + -h, --help Print help +``` + +```shell +# try these: + +cargo run --release -- -b 100 -i 25000 --read-ratio 75 + +cargo run --release -- -b 100 -i 25000 --read-ratio 75 --encryption +``` \ No newline at end of file diff --git a/perf/encryption/src/main.rs b/perf/encryption/src/main.rs new file mode 100644 index 000000000..7736055c5 --- /dev/null +++ b/perf/encryption/src/main.rs @@ -0,0 +1,457 @@ +use clap::Parser; +use rand::rngs::SmallRng; +use rand::{Rng, RngCore, SeedableRng}; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::{Arc, Barrier}; +use std::time::{Duration, Instant}; +use turso::{Builder, Database, Result}; + +#[derive(Debug, Clone)] +struct EncryptionOpts { + cipher: String, + hexkey: String, +} + +#[derive(Parser)] +#[command(name = "encryption-throughput")] +#[command(about = "Encryption throughput benchmark on Turso DB")] +struct Args { + /// More than one thread does not work yet + #[arg(short = 't', long = "threads", default_value = "1")] + threads: usize, + + /// the number operations per transaction + #[arg(short = 'b', long = "batch-size", default_value = "100")] + batch_size: usize, + + /// number of transactions per thread + #[arg(short = 'i', long = "iterations", default_value = "10")] + iterations: usize, + + #[arg( + short = 'r', + long = "read-ratio", + help = "Percentage of operations that should be reads (0-100)" + )] + read_ratio: Option, + + #[arg( + short = 'w', + long = "write-ratio", + help = "Percentage of operations that should be writes (0-100)" + )] + write_ratio: Option, + + #[arg( + long = "encryption", + action = clap::ArgAction::SetTrue, + help = "Enable database encryption" + )] + encryption: bool, + + #[arg( + long = "cipher", + default_value = "aegis-256", + help = "Encryption cipher to use (only relevant if --encryption is set)" + )] + cipher: String, + + #[arg( + long = "think", + default_value = "0", + help = "Per transaction think time (ms)" + )] + think: u64, + + #[arg( + long = "timeout", + default_value = "30000", + help = "Busy timeout in milliseconds" + )] + timeout: u64, + + 
#[arg( + long = "seed", + default_value = "2167532792061351037", + help = "Random seed for reproducible workloads" + )] + seed: u64, +} + +#[derive(Debug)] +struct WorkerStats { + transactions_completed: u64, + reads_completed: u64, + writes_completed: u64, + reads_found: u64, + reads_not_found: u64, + total_transaction_time: Duration, +} + +#[derive(Debug, Clone)] +struct SharedState { + max_inserted_id: Arc, +} + +#[tokio::main] +async fn main() -> Result<()> { + let _ = tracing_subscriber::fmt::try_init(); + let args = Args::parse(); + + let read_ratio = match (args.read_ratio, args.write_ratio) { + (Some(_), Some(_)) => { + eprintln!("Error: Cannot specify both --read-ratio and --write-ratio"); + std::process::exit(1); + } + (Some(r), None) => { + if r > 100 { + eprintln!("Error: read-ratio must be between 0 and 100"); + std::process::exit(1); + } + r + } + (None, Some(w)) => { + if w > 100 { + eprintln!("Error: write-ratio must be between 0 and 100"); + std::process::exit(1); + } + 100 - w + } + // lets default to 0% reads (100% writes) + (None, None) => 0, + }; + + println!( + "Running encryption throughput benchmark with {} threads, {} batch size, {} iterations", + args.threads, args.batch_size, args.iterations + ); + println!( + "Read/Write ratio: {}% reads, {}% writes", + read_ratio, + 100 - read_ratio + ); + println!("Encryption enabled: {}", args.encryption); + println!("Random seed: {}", args.seed); + + let encryption_opts = if args.encryption { + let mut key_rng = SmallRng::seed_from_u64(args.seed); + let key_size = get_key_size_for_cipher(&args.cipher); + let mut key = vec![0u8; key_size]; + key_rng.fill_bytes(&mut key); + + let config = EncryptionOpts { + cipher: args.cipher.clone(), + hexkey: hex::encode(&key), + }; + + println!("Cipher: {}", config.cipher); + println!("Hexkey: {}", config.hexkey); + Some(config) + } else { + None + }; + + let db_path = "encryption_throughput_test.db"; + if std::path::Path::new(db_path).exists() { + std::fs::remove_file(db_path).expect("Failed to remove existing database"); + } + let wal_path = "encryption_throughput_test.db-wal"; + if std::path::Path::new(wal_path).exists() { + std::fs::remove_file(wal_path).expect("Failed to remove existing WAL file"); + } + + let db = setup_database(db_path, &encryption_opts).await?; + + // for create a var which is shared between all the threads, this we use to track the + // max inserted id so that we only read these + let shared_state = SharedState { + max_inserted_id: Arc::new(AtomicU64::new(0)), + }; + + let start_barrier = Arc::new(Barrier::new(args.threads)); + let mut handles = Vec::new(); + + let timeout = Duration::from_millis(args.timeout); + let overall_start = Instant::now(); + + for thread_id in 0..args.threads { + let db_clone = db.clone(); + let barrier = Arc::clone(&start_barrier); + let encryption_opts_clone = encryption_opts.clone(); + let shared_state_clone = shared_state.clone(); + + let handle = tokio::task::spawn(worker_thread( + thread_id, + db_clone, + args.batch_size, + args.iterations, + barrier, + read_ratio, + encryption_opts_clone, + args.think, + timeout, + shared_state_clone, + args.seed, + )); + + handles.push(handle); + } + + let mut total_transactions = 0; + let mut total_reads = 0; + let mut total_writes = 0; + let mut total_reads_found = 0; + let mut total_reads_not_found = 0; + + for (idx, handle) in handles.into_iter().enumerate() { + match handle.await { + Ok(Ok(stats)) => { + total_transactions += stats.transactions_completed; + total_reads += 
stats.reads_completed; + total_writes += stats.writes_completed; + total_reads_found += stats.reads_found; + total_reads_not_found += stats.reads_not_found; + } + Ok(Err(e)) => { + eprintln!("Thread error {idx}: {e}"); + return Err(e); + } + Err(_) => { + eprintln!("Thread panicked"); + std::process::exit(1); + } + } + } + + let overall_elapsed = overall_start.elapsed(); + let total_operations = total_reads + total_writes; + + let transaction_throughput = (total_transactions as f64) / overall_elapsed.as_secs_f64(); + let operation_throughput = (total_operations as f64) / overall_elapsed.as_secs_f64(); + let read_throughput = if total_reads > 0 { + (total_reads as f64) / overall_elapsed.as_secs_f64() + } else { + 0.0 + }; + let write_throughput = if total_writes > 0 { + (total_writes as f64) / overall_elapsed.as_secs_f64() + } else { + 0.0 + }; + let avg_ops_per_txn = (total_operations as f64) / (total_transactions as f64); + + println!("\n=== BENCHMARK RESULTS ==="); + println!("Total transactions: {total_transactions}"); + println!("Total operations: {total_operations}"); + println!("Operations per transaction: {avg_ops_per_txn:.1}"); + println!("Total time: {:.2}s", overall_elapsed.as_secs_f64()); + println!(); + println!("Transaction throughput: {transaction_throughput:.2} txns/sec"); + println!("Operation throughput: {operation_throughput:.2} ops/sec"); + + // not found should be zero since track the max inserted id + // todo(v): probably handle the not found error and remove max id + if total_reads > 0 { + println!( + " - Read operations: {total_reads} ({total_reads_found} found, {total_reads_not_found} not found)" + ); + println!(" - Read throughput: {read_throughput:.2} reads/sec"); + } + if total_writes > 0 { + println!(" - Write operations: {total_writes}"); + println!(" - Write throughput: {write_throughput:.2} writes/sec"); + } + + println!("\nConfiguration:"); + println!("Threads: {}", args.threads); + println!("Batch size: {}", args.batch_size); + println!("Iterations per thread: {}", args.iterations); + println!("Encryption: {}", args.encryption); + println!("Seed: {}", args.seed); + + if let Ok(metadata) = std::fs::metadata(db_path) { + println!("Database file size: {} bytes", metadata.len()); + } + + Ok(()) +} + +fn get_key_size_for_cipher(cipher: &str) -> usize { + match cipher.to_lowercase().as_str() { + "aes-128-gcm" | "aegis-128l" | "aegis-128x2" | "aegis-128x4" => 16, + "aes-256-gcm" | "aegis-256" | "aegis-256x2" | "aegis-256x4" => 32, + _ => 32, // default to 256-bit key + } +} + +async fn setup_database( + db_path: &str, + encryption_opts: &Option, +) -> Result { + let builder = Builder::new_local(db_path); + let db = builder.build().await?; + let conn = db.connect()?; + + if let Some(config) = encryption_opts { + conn.execute(&format!("PRAGMA cipher='{}'", config.cipher), ()) + .await?; + conn.execute(&format!("PRAGMA hexkey='{}'", config.hexkey), ()) + .await?; + } + + // todo(v): probably store blobs and then have option of randomblob size + conn.execute( + "CREATE TABLE IF NOT EXISTS test_table ( + id INTEGER PRIMARY KEY, + data TEXT NOT NULL + )", + (), + ) + .await?; + + println!("Database created at: {db_path}"); + Ok(db) +} + +#[allow(clippy::too_many_arguments)] +async fn worker_thread( + thread_id: usize, + db: Database, + batch_size: usize, + iterations: usize, + start_barrier: Arc, + read_ratio: u8, + encryption_opts: Option, + think_ms: u64, + timeout: Duration, + shared_state: SharedState, + base_seed: u64, +) -> Result { + start_barrier.wait(); + + let 
start_time = Instant::now(); + let mut stats = WorkerStats { + transactions_completed: 0, + reads_completed: 0, + writes_completed: 0, + reads_found: 0, + reads_not_found: 0, + total_transaction_time: Duration::ZERO, + }; + + let thread_seed = base_seed.wrapping_add(thread_id as u64); + let mut rng = SmallRng::seed_from_u64(thread_seed); + + for iteration in 0..iterations { + let conn = db.connect()?; + + if let Some(config) = &encryption_opts { + conn.execute(&format!("PRAGMA cipher='{}'", config.cipher), ()) + .await?; + conn.execute(&format!("PRAGMA hexkey='{}'", config.hexkey), ()) + .await?; + } + + conn.busy_timeout(Some(timeout))?; + + let mut insert_stmt = conn + .prepare("INSERT INTO test_table (id, data) VALUES (?, ?)") + .await?; + + let transaction_start = Instant::now(); + conn.execute("BEGIN", ()).await?; + + for i in 0..batch_size { + let should_read = rng.random_range(0..100) < read_ratio; + + if should_read { + // only attempt reads if we have inserted some data + let max_id = shared_state.max_inserted_id.load(Ordering::Relaxed); + if max_id > 0 { + let read_id = rng.random_range(1..=max_id); + let row = conn + .query( + "SELECT data FROM test_table WHERE id = ?", + turso::params::Params::Positional(vec![turso::Value::Integer( + read_id as i64, + )]), + ) + .await; + + match row { + Ok(_) => stats.reads_found += 1, + Err(turso::Error::QueryReturnedNoRows) => stats.reads_not_found += 1, + Err(e) => return Err(e), + }; + stats.reads_completed += 1; + } else { + // if no data inserted yet, convert to a write + let id = thread_id * iterations * batch_size + iteration * batch_size + i + 1; + insert_stmt + .execute(turso::params::Params::Positional(vec![ + turso::Value::Integer(id as i64), + turso::Value::Text(format!("data_{id}")), + ])) + .await?; + + shared_state + .max_inserted_id + .fetch_max(id as u64, Ordering::Relaxed); + stats.writes_completed += 1; + } + } else { + let id = thread_id * iterations * batch_size + iteration * batch_size + i + 1; + insert_stmt + .execute(turso::params::Params::Positional(vec![ + turso::Value::Integer(id as i64), + turso::Value::Text(format!("data_{id}")), + ])) + .await?; + + shared_state + .max_inserted_id + .fetch_max(id as u64, Ordering::Relaxed); + stats.writes_completed += 1; + } + } + + if think_ms > 0 { + tokio::time::sleep(Duration::from_millis(think_ms)).await; + } + + conn.execute("COMMIT", ()).await?; + + let transaction_elapsed = transaction_start.elapsed(); + stats.transactions_completed += 1; + stats.total_transaction_time += transaction_elapsed; + } + + let elapsed = start_time.elapsed(); + let total_ops = stats.reads_completed + stats.writes_completed; + let transaction_throughput = (stats.transactions_completed as f64) / elapsed.as_secs_f64(); + let operation_throughput = (total_ops as f64) / elapsed.as_secs_f64(); + let avg_txn_latency = + stats.total_transaction_time.as_secs_f64() * 1000.0 / stats.transactions_completed as f64; + + println!( + "Thread {}: {} txns ({} ops: {} reads, {} writes) in {:.2}s ({:.2} txns/sec, {:.2} ops/sec, {:.2}ms avg latency)", + thread_id, + stats.transactions_completed, + total_ops, + stats.reads_completed, + stats.writes_completed, + elapsed.as_secs_f64(), + transaction_throughput, + operation_throughput, + avg_txn_latency + ); + + if stats.reads_completed > 0 { + println!( + " Thread {} reads: {} found, {} not found", + thread_id, stats.reads_found, stats.reads_not_found + ); + } + + Ok(stats) +} From e5dfc942b12f1c943f97061af3c274797fd287f5 Mon Sep 17 00:00:00 2001 From: pedrocarlo Date: 
Sun, 21 Sep 2025 13:05:46 -0300 Subject: [PATCH 34/34] remove some unnecessary unsafe impls --- core/fast_lock.rs | 1 - core/io/memory.rs | 3 +-- core/io/unix.rs | 5 ----- core/storage/database.rs | 5 ----- 4 files changed, 1 insertion(+), 13 deletions(-) diff --git a/core/fast_lock.rs b/core/fast_lock.rs index 8abda6a17..a02d617ba 100644 --- a/core/fast_lock.rs +++ b/core/fast_lock.rs @@ -34,7 +34,6 @@ impl DerefMut for SpinLockGuard<'_, T> { } } -unsafe impl Send for SpinLock {} unsafe impl Sync for SpinLock {} impl SpinLock { diff --git a/core/io/memory.rs b/core/io/memory.rs index c69d87dcf..fc0549ca7 100644 --- a/core/io/memory.rs +++ b/core/io/memory.rs @@ -12,7 +12,6 @@ use tracing::debug; pub struct MemoryIO { files: Arc>>>, } -unsafe impl Send for MemoryIO {} // TODO: page size flag const PAGE_SIZE: usize = 4096; @@ -76,7 +75,7 @@ pub struct MemoryFile { pages: UnsafeCell>, size: Cell, } -unsafe impl Send for MemoryFile {} + unsafe impl Sync for MemoryFile {} impl File for MemoryFile { diff --git a/core/io/unix.rs b/core/io/unix.rs index a3cfd6f2f..b0d47f30f 100644 --- a/core/io/unix.rs +++ b/core/io/unix.rs @@ -17,9 +17,6 @@ use tracing::{instrument, trace, Level}; pub struct UnixIO {} -unsafe impl Send for UnixIO {} -unsafe impl Sync for UnixIO {} - impl UnixIO { #[cfg(feature = "fs")] pub fn new() -> Result { @@ -128,8 +125,6 @@ impl IO for UnixIO { pub struct UnixFile { file: Arc>, } -unsafe impl Send for UnixFile {} -unsafe impl Sync for UnixFile {} impl File for UnixFile { fn lock_file(&self, exclusive: bool) -> Result<()> { diff --git a/core/storage/database.rs b/core/storage/database.rs index e7aceebbf..3cbc42b9f 100644 --- a/core/storage/database.rs +++ b/core/storage/database.rs @@ -88,11 +88,6 @@ pub struct DatabaseFile { file: Arc, } -#[cfg(feature = "fs")] -unsafe impl Send for DatabaseFile {} -#[cfg(feature = "fs")] -unsafe impl Sync for DatabaseFile {} - #[cfg(feature = "fs")] impl DatabaseStorage for DatabaseFile { #[instrument(skip_all, level = Level::DEBUG)]
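As background on why these impls are redundant, a minimal sketch with hypothetical types (not from this tree): `Send` and `Sync` are auto traits, so the compiler derives them whenever every field already provides them; a hand-written `unsafe impl` is only needed when a field such as `UnsafeCell` opts out, which is why `MemoryFile` keeps its manual `Sync` impl while the `Send` impls can simply be dropped.

```rust
use std::cell::UnsafeCell;
use std::sync::Arc;

// All fields are Send + Sync, so the struct is Send + Sync automatically;
// adding `unsafe impl Send/Sync` here would only bypass the compiler's checks.
struct AutoBoth {
    file: Arc<std::fs::File>,
}

// UnsafeCell<T> is Send (for T: Send) but never Sync, so Send is still derived
// automatically while Sync has to be asserted by hand -- mirroring MemoryFile.
// The author, not the compiler, is then responsible for that claim being sound.
struct InteriorMutable {
    pages: UnsafeCell<Vec<u8>>,
}
unsafe impl Sync for InteriorMutable {}

fn assert_send<T: Send>() {}
fn assert_sync<T: Sync>() {}

fn main() {
    assert_send::<AutoBoth>();
    assert_sync::<AutoBoth>();
    assert_send::<InteriorMutable>(); // auto-derived, no unsafe impl needed
    assert_sync::<InteriorMutable>(); // provided by the manual impl above
}
```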