diff --git a/core/incremental/compiler.rs b/core/incremental/compiler.rs
new file mode 100644
index 000000000..9cd2e3702
--- /dev/null
+++ b/core/incremental/compiler.rs
@@ -0,0 +1,2922 @@
+//! DBSP Compiler: Converts Logical Plans to DBSP Circuits
+//!
+//! This module implements compilation from SQL logical plans to DBSP circuits.
+//! This initial version supports the filter, projection, and aggregation operators.
+//!
+//! Based on the DBSP paper: "DBSP: Automatic Incremental View Maintenance for Rich Query Languages"
+
+use crate::incremental::expr_compiler::CompiledExpression;
+use crate::incremental::operator::{
+    Delta, FilterOperator, FilterPredicate, IncrementalOperator, ProjectOperator,
+};
+// Note: logical module must be made pub(crate) in translate/mod.rs
+use crate::translate::logical::{BinaryOperator, LogicalExpr, LogicalPlan, SchemaRef};
+use crate::types::Value;
+use crate::{LimboError, Result};
+use std::collections::HashMap;
+use std::fmt::{self, Display, Formatter};
+
+/// A set of deltas for multiple tables/operators.
+/// This provides a cleaner API for passing deltas through circuit execution.
+#[derive(Debug, Clone, Default)]
+pub struct DeltaSet {
+    /// Deltas keyed by table/operator name
+    deltas: HashMap<String, Delta>,
+}
+
+impl DeltaSet {
+    /// Create a new empty delta set
+    pub fn new() -> Self {
+        Self {
+            deltas: HashMap::new(),
+        }
+    }
+
+    /// Create an empty delta set (more semantic for "no changes")
+    pub fn empty() -> Self {
+        Self {
+            deltas: HashMap::new(),
+        }
+    }
+
+    /// Add a delta for a table
+    pub fn insert(&mut self, table_name: String, delta: Delta) {
+        self.deltas.insert(table_name, delta);
+    }
+
+    /// Get the delta for a table; returns an empty delta if not found
+    pub fn get(&self, table_name: &str) -> Delta {
+        self.deltas
+            .get(table_name)
+            .cloned()
+            .unwrap_or_else(Delta::new)
+    }
+}
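+
+// Editorial sketch (not part of the original patch): typical DeltaSet usage,
+// built only from the API above; `users_delta` is a placeholder name.
+//
+//     let mut set = DeltaSet::new();
+//     set.insert("users".to_string(), users_delta);
+//     let d = set.get("users"); // returns an empty Delta for unknown keys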
+/// Represents a DBSP operator in the compiled circuit
+#[derive(Debug, Clone, PartialEq)]
+pub enum DbspOperator {
+    /// Filter operator (σ) - filters records based on a predicate
+    Filter { predicate: DbspExpr },
+    /// Projection operator (π) - projects specific columns
+    Projection {
+        exprs: Vec<DbspExpr>,
+        schema: SchemaRef,
+    },
+    /// Aggregate operator (γ) - performs grouping and aggregation
+    Aggregate {
+        group_exprs: Vec<DbspExpr>,
+        aggr_exprs: Vec<crate::incremental::operator::AggregateFunction>,
+        schema: SchemaRef,
+    },
+    /// Input operator - source of data
+    Input { name: String, schema: SchemaRef },
+}
+
+/// Represents an expression in DBSP
+#[derive(Debug, Clone, PartialEq)]
+pub enum DbspExpr {
+    /// Column reference
+    Column(String),
+    /// Literal value
+    Literal(Value),
+    /// Binary expression
+    BinaryExpr {
+        left: Box<DbspExpr>,
+        op: BinaryOperator,
+        right: Box<DbspExpr>,
+    },
+}
+
+/// A node in the DBSP circuit DAG
+pub struct DbspNode {
+    /// Unique identifier for this node
+    pub id: usize,
+    /// The operator metadata
+    pub operator: DbspOperator,
+    /// Input nodes (edges in the DAG)
+    pub inputs: Vec<usize>,
+    /// The actual executable operator (if applicable)
+    pub executable: Option<Box<dyn IncrementalOperator>>,
+}
+
+impl std::fmt::Debug for DbspNode {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("DbspNode")
+            .field("id", &self.id)
+            .field("operator", &self.operator)
+            .field("inputs", &self.inputs)
+            .field("has_executable", &self.executable.is_some())
+            .finish()
+    }
+}
+
+/// Represents a complete DBSP circuit (DAG of operators)
+#[derive(Debug)]
+pub struct DbspCircuit {
+    /// All nodes in the circuit, indexed by their ID
+    pub(super) nodes: HashMap<usize, DbspNode>,
+    /// Counter for generating unique node IDs
+    next_id: usize,
+    /// Root node ID (the final output)
+    pub(super) root: Option<usize>,
+}
+
+impl DbspCircuit {
+    /// Create a new empty circuit
+    pub fn new() -> Self {
+        Self {
+            nodes: HashMap::new(),
+            next_id: 0,
+            root: None,
+        }
+    }
+
+    /// Add a node to the circuit
+    fn add_node(
+        &mut self,
+        operator: DbspOperator,
+        inputs: Vec<usize>,
+        executable: Option<Box<dyn IncrementalOperator>>,
+    ) -> usize {
+        let id = self.next_id;
+        self.next_id += 1;
+
+        let node = DbspNode {
+            id,
+            operator,
+            inputs,
+            executable,
+        };
+
+        self.nodes.insert(id, node);
+        id
+    }
+
+    /// Initialize the circuit with base data. Should be called once before processing deltas.
+    /// If the database is restarting with materialized views, this can be skipped.
+    pub fn initialize(&mut self, input_data: HashMap<String, Delta>) -> Result<Delta> {
+        if let Some(root_id) = self.root {
+            self.initialize_node(root_id, &input_data)
+        } else {
+            Err(LimboError::ParseError(
+                "Circuit has no root node".to_string(),
+            ))
+        }
+    }
+
+    /// Initialize a specific node and its dependencies
+    fn initialize_node(
+        &mut self,
+        node_id: usize,
+        input_data: &HashMap<String, Delta>,
+    ) -> Result<Delta> {
+        // Clone to avoid borrow checker issues
+        let inputs = self
+            .nodes
+            .get(&node_id)
+            .ok_or_else(|| LimboError::ParseError("Node not found".to_string()))?
+            .inputs
+            .clone();
+
+        // Initialize inputs first
+        let mut input_deltas = Vec::new();
+        for input_id in inputs {
+            let delta = self.initialize_node(input_id, input_data)?;
+            input_deltas.push(delta);
+        }
+
+        // Get mutable reference to node
+        let node = self
+            .nodes
+            .get_mut(&node_id)
+            .ok_or_else(|| LimboError::ParseError("Node not found".to_string()))?;
+
+        // Initialize based on operator type
+        let result = match &node.operator {
+            DbspOperator::Input { name, .. } => {
+                // Get data from input map
+                input_data.get(name).cloned().unwrap_or_else(Delta::new)
+            }
+            DbspOperator::Filter { .. }
+            | DbspOperator::Projection { .. }
+            | DbspOperator::Aggregate { .. } => {
+                // Initialize the executable operator
+                if let Some(ref mut op) = node.executable {
+                    if !input_deltas.is_empty() {
+                        let input_delta = input_deltas[0].clone();
+                        op.initialize(input_delta);
+                        op.get_current_state()
+                    } else {
+                        Delta::new()
+                    }
+                } else {
+                    // If no executable, pass through the input
+                    if !input_deltas.is_empty() {
+                        input_deltas[0].clone()
+                    } else {
+                        Delta::new()
+                    }
+                }
+            }
+        };
+
+        Ok(result)
+    }
+
+    /// Execute the circuit with incremental input data (deltas).
+    /// Call initialize() first for initial data, then use execute() for updates.
+    ///
+    /// # Arguments
+    /// * `input_data` - The committed deltas to process
+    /// * `uncommitted_data` - Uncommitted transaction deltas that should be visible
+    ///   during this execution but not stored in operators.
+    ///   Use DeltaSet::empty() for no uncommitted changes.
+    pub fn execute(
+        &self,
+        input_data: HashMap<String, Delta>,
+        uncommitted_data: DeltaSet,
+    ) -> Result<Delta> {
+        if let Some(root_id) = self.root {
+            self.execute_node(root_id, &input_data, &uncommitted_data)
+        } else {
+            Err(LimboError::ParseError(
+                "Circuit has no root node".to_string(),
+            ))
+        }
+    }
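+
+    // Editorial sketch (not part of the original patch): the intended call
+    // sequence for a circuit, using only methods defined in this file.
+    // `plan`, `base_data` and `deltas` are placeholder names.
+    //
+    //     let mut circuit = DbspCompiler::new().compile(&plan)?;
+    //     circuit.initialize(base_data)?;               // load base state once
+    //     let out = circuit.execute(deltas.clone(), DeltaSet::empty())?;
+    //     circuit.commit(deltas)?;                      // make changes permanent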
+    /// Commit deltas to the circuit, updating internal operator state.
+    /// This should be called after execute() when you want to make changes permanent.
+    ///
+    /// # Arguments
+    /// * `input_data` - The deltas to commit (same as what was passed to execute)
+    pub fn commit(&mut self, input_data: HashMap<String, Delta>) -> Result<()> {
+        if let Some(root_id) = self.root {
+            self.commit_node(root_id, &input_data)?;
+        }
+        Ok(())
+    }
+
+    /// Commit a specific node in the circuit
+    fn commit_node(
+        &mut self,
+        node_id: usize,
+        input_data: &HashMap<String, Delta>,
+    ) -> Result<Delta> {
+        // Clone to avoid borrow checker issues
+        let inputs = self
+            .nodes
+            .get(&node_id)
+            .ok_or_else(|| LimboError::ParseError("Node not found".to_string()))?
+            .inputs
+            .clone();
+
+        // Process inputs first
+        let mut input_deltas = Vec::new();
+        for input_id in inputs {
+            let delta = self.commit_node(input_id, input_data)?;
+            input_deltas.push(delta);
+        }
+
+        // Get mutable reference to node
+        let node = self
+            .nodes
+            .get_mut(&node_id)
+            .ok_or_else(|| LimboError::ParseError("Node not found".to_string()))?;
+
+        // Commit based on operator type
+        let result = match &node.operator {
+            DbspOperator::Input { name, .. } => {
+                // For input nodes, just return the committed delta
+                input_data.get(name).cloned().unwrap_or_else(Delta::new)
+            }
+            DbspOperator::Filter { .. }
+            | DbspOperator::Projection { .. }
+            | DbspOperator::Aggregate { .. } => {
+                // Commit the delta to the executable operator
+                if let Some(ref mut op) = node.executable {
+                    if !input_deltas.is_empty() {
+                        let input_delta = input_deltas[0].clone();
+                        // Commit updates state and returns the output delta
+                        op.commit(input_delta)
+                    } else {
+                        Delta::new()
+                    }
+                } else {
+                    // If no executable, pass through the input
+                    if !input_deltas.is_empty() {
+                        input_deltas[0].clone()
+                    } else {
+                        Delta::new()
+                    }
+                }
+            }
+        };
+        Ok(result)
+    }
+
+    /// Execute a specific node in the circuit
+    fn execute_node(
+        &self,
+        node_id: usize,
+        input_data: &HashMap<String, Delta>,
+        uncommitted_data: &DeltaSet,
+    ) -> Result<Delta> {
+        // Clone to avoid borrow checker issues
+        let inputs = self
+            .nodes
+            .get(&node_id)
+            .ok_or_else(|| LimboError::ParseError("Node not found".to_string()))?
+            .inputs
+            .clone();
+
+        // Process inputs first
+        let mut input_deltas = Vec::new();
+        for input_id in inputs {
+            let delta = self.execute_node(input_id, input_data, uncommitted_data)?;
+            input_deltas.push(delta);
+        }
+
+        // Get reference to node (read-only since we're using eval, not commit)
+        let node = self
+            .nodes
+            .get(&node_id)
+            .ok_or_else(|| LimboError::ParseError("Node not found".to_string()))?;
+
+        // Execute based on operator type
+        let result = match &node.operator {
+            DbspOperator::Input { name, .. } => {
+                // Get committed data from input map and merge with uncommitted if present
+                let committed = input_data.get(name).cloned().unwrap_or_else(Delta::new);
+                let uncommitted = uncommitted_data.get(name);
+
+                // If there's uncommitted data for this table, merge it with committed
+                if !uncommitted.is_empty() {
+                    let mut combined = committed;
+                    combined.merge(&uncommitted);
+                    combined
+                } else {
+                    committed
+                }
+            }
+            DbspOperator::Filter { .. }
+            | DbspOperator::Projection { .. }
+            | DbspOperator::Aggregate { .. } => {
+                // Process delta using the executable operator
+                if let Some(ref op) = node.executable {
+                    if !input_deltas.is_empty() {
+                        // Process the delta through the operator
+                        let input_delta = input_deltas[0].clone();
+
+                        // Use eval to compute result without modifying state
+                        // The uncommitted data has already been merged into input_delta if needed
+                        op.eval(input_delta, None)
+                    } else {
+                        Delta::new()
+                    }
+                } else {
+                    // If no executable, pass through the input
+                    if !input_deltas.is_empty() {
+                        input_deltas[0].clone()
+                    } else {
+                        Delta::new()
+                    }
+                }
+            }
+        };
+        Ok(result)
+    }
+}
+
+impl Display for DbspCircuit {
+    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+        writeln!(f, "DBSP Circuit:")?;
+        if let Some(root_id) = self.root {
+            self.fmt_node(f, root_id, 0)?;
+        }
+        Ok(())
+    }
+}
+
+impl DbspCircuit {
+    fn fmt_node(&self, f: &mut Formatter, node_id: usize, depth: usize) -> fmt::Result {
+        let indent = "  ".repeat(depth);
+        if let Some(node) = self.nodes.get(&node_id) {
+            match &node.operator {
+                DbspOperator::Filter { predicate } => {
+                    writeln!(f, "{indent}Filter[{node_id}]: {predicate:?}")?;
+                }
+                DbspOperator::Projection { exprs, .. } => {
+                    writeln!(f, "{indent}Projection[{node_id}]: {exprs:?}")?;
+                }
+                DbspOperator::Aggregate {
+                    group_exprs,
+                    aggr_exprs,
+                    ..
+                } => {
+                    writeln!(
+                        f,
+                        "{indent}Aggregate[{node_id}]: GROUP BY {group_exprs:?}, AGGR {aggr_exprs:?}"
+                    )?;
+                }
+                DbspOperator::Input { name, .. } => {
+                    writeln!(f, "{indent}Input[{node_id}]: {name}")?;
+                }
+            }
+
+            for input_id in &node.inputs {
+                self.fmt_node(f, *input_id, depth + 1)?;
+            }
+        }
+        Ok(())
+    }
+}
+
+/// Compiler from LogicalPlan to DBSP Circuit
+pub struct DbspCompiler {
+    circuit: DbspCircuit,
+}
+
+impl DbspCompiler {
+    /// Create a new DBSP compiler
+    pub fn new() -> Self {
+        Self {
+            circuit: DbspCircuit::new(),
+        }
+    }
+
+    /// Compile a logical plan to a DBSP circuit
+    pub fn compile(mut self, plan: &LogicalPlan) -> Result<DbspCircuit> {
+        let root_id = self.compile_plan(plan)?;
+        self.circuit.root = Some(root_id);
+        Ok(self.circuit)
+    }
+
+    /// Recursively compile a logical plan node
+    fn compile_plan(&mut self, plan: &LogicalPlan) -> Result<usize> {
+        match plan {
+            LogicalPlan::Projection(proj) => {
+                // Compile the input first
+                let input_id = self.compile_plan(&proj.input)?;
+
+                // Get input column names for the ProjectOperator
+                let input_schema = proj.input.schema();
+                let input_column_names: Vec<String> = input_schema
+                    .columns
+                    .iter()
+                    .map(|(name, _)| name.clone())
+                    .collect();
+
+                // Convert logical expressions to DBSP expressions
+                let dbsp_exprs = proj
+                    .exprs
+                    .iter()
+                    .map(Self::compile_expr)
+                    .collect::<Result<Vec<_>>>()?;
+
+                // Compile logical expressions to CompiledExpressions
+                let mut compiled_exprs = Vec::new();
+                let mut aliases = Vec::new();
+                for expr in &proj.exprs {
+                    let (compiled, alias) = Self::compile_expression(expr, &input_column_names)?;
+                    compiled_exprs.push(compiled);
+                    aliases.push(alias);
+                }
+
+                // Get output column names from the projection schema
+                let output_column_names: Vec<String> = proj
+                    .schema
+                    .columns
+                    .iter()
+                    .map(|(name, _)| name.clone())
+                    .collect();
+
+                // Create the ProjectOperator
+                let executable: Option<Box<dyn IncrementalOperator>> =
+                    ProjectOperator::from_compiled(
+                        compiled_exprs,
+                        aliases,
+                        input_column_names,
+                        output_column_names,
+                    )
+                    .ok()
+                    .map(|op| Box::new(op) as Box<dyn IncrementalOperator>);
+
+                // Create projection node
+                let node_id = self.circuit.add_node(
+                    DbspOperator::Projection {
+                        exprs: dbsp_exprs,
+                        schema: proj.schema.clone(),
+                    },
+                    vec![input_id],
+                    executable,
+                );
+                Ok(node_id)
+            }
+            LogicalPlan::Filter(filter) => {
+                // Compile the input first
+                let input_id = self.compile_plan(&filter.input)?;
+
+                // Get column names from input schema
+                let input_schema = filter.input.schema();
+                let column_names: Vec<String> = input_schema
+                    .columns
+                    .iter()
+                    .map(|(name, _)| name.clone())
+                    .collect();
+
+                // Convert predicate to DBSP expression
+                let dbsp_predicate = Self::compile_expr(&filter.predicate)?;
+
+                // Convert to FilterPredicate
+                let filter_predicate = Self::compile_filter_predicate(&filter.predicate)?;
+
+                // Create executable operator
+                let executable: Box<dyn IncrementalOperator> =
+                    Box::new(FilterOperator::new(filter_predicate, column_names));
+
+                // Create filter node
+                let node_id = self.circuit.add_node(
+                    DbspOperator::Filter { predicate: dbsp_predicate },
+                    vec![input_id],
+                    Some(executable),
+                );
+                Ok(node_id)
+            }
+            LogicalPlan::Aggregate(agg) => {
+                // Compile the input first
+                let input_id = self.compile_plan(&agg.input)?;
+
+                // Get input column names
+                let input_schema = agg.input.schema();
+                let input_column_names: Vec<String> = input_schema
+                    .columns
+                    .iter()
+                    .map(|(name, _)| name.clone())
+                    .collect();
+
+                // Compile group by expressions to column names
+                let mut group_by_columns = Vec::new();
+                let mut dbsp_group_exprs = Vec::new();
+                for expr in &agg.group_expr {
+                    // For now, only support simple column references in GROUP BY
+                    if let LogicalExpr::Column(col) = expr {
+                        group_by_columns.push(col.name.clone());
+                        dbsp_group_exprs.push(DbspExpr::Column(col.name.clone()));
+                    } else {
+                        return Err(LimboError::ParseError(
+                            "Only column references are supported in GROUP BY for incremental views".to_string()
+                        ));
+                    }
+                }
+
+                // Compile aggregate expressions
+                let mut aggregate_functions = Vec::new();
+                for expr in &agg.aggr_expr {
+                    if let LogicalExpr::AggregateFunction { fun, args, .. } = expr {
+                        use crate::function::AggFunc;
+                        use crate::incremental::operator::AggregateFunction;
+
+                        let agg_fn = match fun {
+                            AggFunc::Count | AggFunc::Count0 => {
+                                AggregateFunction::Count
+                            }
+                            AggFunc::Sum => {
+                                if args.is_empty() {
+                                    return Err(LimboError::ParseError("SUM requires an argument".to_string()));
+                                }
+                                // Extract column name from the argument
+                                if let LogicalExpr::Column(col) = &args[0] {
+                                    AggregateFunction::Sum(col.name.clone())
+                                } else {
+                                    return Err(LimboError::ParseError(
+                                        "Only column references are supported in aggregate functions for incremental views".to_string()
+                                    ));
+                                }
+                            }
+                            AggFunc::Avg => {
+                                if args.is_empty() {
+                                    return Err(LimboError::ParseError("AVG requires an argument".to_string()));
+                                }
+                                if let LogicalExpr::Column(col) = &args[0] {
+                                    AggregateFunction::Avg(col.name.clone())
+                                } else {
+                                    return Err(LimboError::ParseError(
+                                        "Only column references are supported in aggregate functions for incremental views".to_string()
+                                    ));
+                                }
+                            }
+                            // MIN and MAX are not supported in incremental views due to storage overhead.
+                            // To correctly handle deletions, these operators would need to track all values
+                            // in each group, resulting in O(n) storage overhead. This is prohibitive for
+                            // large datasets. Alternative approaches like maintaining sorted indexes still
+                            // require O(n) storage. Until a more efficient solution is found, MIN/MAX
+                            // aggregations are not supported in materialized views.
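+                            // Illustration (editorial, not part of the original
+                            // patch): for a group holding {3, 5}, MIN is 3. After
+                            // a delta that deletes 3, the new MIN (5) can only be
+                            // recomputed if every surviving value of the group is
+                            // still available - hence the O(n) storage cost.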
+                            AggFunc::Min => {
+                                return Err(LimboError::ParseError(
+                                    "MIN aggregation is not supported in incremental materialized views due to O(n) storage overhead required for handling deletions".to_string()
+                                ));
+                            }
+                            AggFunc::Max => {
+                                return Err(LimboError::ParseError(
+                                    "MAX aggregation is not supported in incremental materialized views due to O(n) storage overhead required for handling deletions".to_string()
+                                ));
+                            }
+                            _ => {
+                                return Err(LimboError::ParseError(
+                                    format!("Unsupported aggregate function in DBSP compiler: {fun:?}")
+                                ));
+                            }
+                        };
+                        aggregate_functions.push(agg_fn);
+                    } else {
+                        return Err(LimboError::ParseError(
+                            "Expected aggregate function in aggregate expressions".to_string()
+                        ));
+                    }
+                }
+
+                // Create the AggregateOperator
+                use crate::incremental::operator::AggregateOperator;
+                let executable: Option<Box<dyn IncrementalOperator>> = Some(
+                    Box::new(AggregateOperator::new(
+                        group_by_columns,
+                        aggregate_functions.clone(),
+                        input_column_names,
+                    ))
+                );
+
+                // Create aggregate node
+                let node_id = self.circuit.add_node(
+                    DbspOperator::Aggregate {
+                        group_exprs: dbsp_group_exprs,
+                        aggr_exprs: aggregate_functions,
+                        schema: agg.schema.clone(),
+                    },
+                    vec![input_id],
+                    executable,
+                );
+                Ok(node_id)
+            }
+            LogicalPlan::TableScan(scan) => {
+                // Create input node (no executable needed for input)
+                let node_id = self.circuit.add_node(
+                    DbspOperator::Input {
+                        name: scan.table_name.clone(),
+                        schema: scan.schema.clone(),
+                    },
+                    vec![],
+                    None,
+                );
+                Ok(node_id)
+            }
+            _ => Err(LimboError::ParseError(
+                format!("Unsupported operator in DBSP compiler: only Filter, Projection and Aggregate are supported, got: {:?}",
+                    match plan {
+                        LogicalPlan::Sort(_) => "Sort",
+                        LogicalPlan::Limit(_) => "Limit",
+                        LogicalPlan::Union(_) => "Union",
+                        LogicalPlan::Distinct(_) => "Distinct",
+                        LogicalPlan::EmptyRelation(_) => "EmptyRelation",
+                        LogicalPlan::Values(_) => "Values",
+                        LogicalPlan::WithCTE(_) => "WithCTE",
+                        LogicalPlan::CTERef(_) => "CTERef",
+                        _ => "Unknown",
+                    }
+                )
+            )),
+        }
+    }
+
+    /// Convert a logical expression to a DBSP expression
+    fn compile_expr(expr: &LogicalExpr) -> Result<DbspExpr> {
+        match expr {
+            LogicalExpr::Column(col) => Ok(DbspExpr::Column(col.name.clone())),
+
+            LogicalExpr::Literal(val) => Ok(DbspExpr::Literal(val.clone())),
+
+            LogicalExpr::BinaryExpr { left, op, right } => {
+                let left_expr = Self::compile_expr(left)?;
+                let right_expr = Self::compile_expr(right)?;
+
+                Ok(DbspExpr::BinaryExpr {
+                    left: Box::new(left_expr),
+                    op: *op,
+                    right: Box::new(right_expr),
+                })
+            }
+
+            LogicalExpr::Alias { expr, .. } => {
+                // For aliases, compile the underlying expression
+                Self::compile_expr(expr)
+            }
+
+            // For complex expressions (functions, etc), we can't represent them as DbspExpr
+            // but that's OK - they'll be handled by the ProjectOperator's VDBE compilation
+            // For now, just use a placeholder
+            _ => {
+                // Use a literal null as placeholder - the actual execution will use the compiled VDBE
+                Ok(DbspExpr::Literal(Value::Null))
+            }
+        }
+    }
+
+    /// Compile a logical expression to a CompiledExpression and optional alias
+    fn compile_expression(
+        expr: &LogicalExpr,
+        input_column_names: &[String],
+    ) -> Result<(CompiledExpression, Option<String>)> {
+        // Check for alias first
+        if let LogicalExpr::Alias { expr, alias } = expr {
+            // For aliases, compile the underlying expression and return with alias
+            let (compiled, _) = Self::compile_expression(expr, input_column_names)?;
+            return Ok((compiled, Some(alias.clone())));
+        }
+
+        // Convert LogicalExpr to AST Expr
+        let ast_expr = Self::logical_to_ast_expr(expr)?;
+
+        // For all expressions (simple or complex), use CompiledExpression::compile
+        // This handles both trivial cases and complex VDBE compilation
+        // We need to set up the necessary context
+        use crate::{Database, MemoryIO, SymbolTable};
+        use std::sync::Arc;
+
+        // Create an internal connection for expression compilation
+        let io = Arc::new(MemoryIO::new());
+        let db = Database::open_file(io, ":memory:", false, false)?;
+        let internal_conn = db.connect()?;
+        internal_conn.query_only.set(true);
+        internal_conn.auto_commit.set(false);
+
+        // Create temporary symbol table
+        let temp_syms = SymbolTable::new();
+
+        // Get a minimal schema for compilation (we don't need the full schema for expressions)
+        let schema = crate::schema::Schema::new(false);
+
+        // Compile the expression using the existing CompiledExpression::compile
+        let compiled = CompiledExpression::compile(
+            &ast_expr,
+            input_column_names,
+            &schema,
+            &temp_syms,
+            internal_conn,
+        )?;
+
+        Ok((compiled, None))
+    }
+
+    /// Convert LogicalExpr to AST Expr
+    fn logical_to_ast_expr(expr: &LogicalExpr) -> Result<turso_parser::ast::Expr> {
+        use turso_parser::ast;
+
+        match expr {
+            LogicalExpr::Column(col) => Ok(ast::Expr::Id(ast::Name::Ident(col.name.clone()))),
+            LogicalExpr::Literal(val) => {
+                let lit = match val {
+                    Value::Integer(i) => ast::Literal::Numeric(i.to_string()),
+                    Value::Float(f) => ast::Literal::Numeric(f.to_string()),
+                    Value::Text(t) => ast::Literal::String(t.to_string()),
+                    Value::Blob(b) => ast::Literal::Blob(format!("{b:?}")),
+                    Value::Null => ast::Literal::Null,
+                };
+                Ok(ast::Expr::Literal(lit))
+            }
+            LogicalExpr::BinaryExpr { left, op, right } => {
+                let left_expr = Self::logical_to_ast_expr(left)?;
+                let right_expr = Self::logical_to_ast_expr(right)?;
+                Ok(ast::Expr::Binary(
+                    Box::new(left_expr),
+                    *op,
+                    Box::new(right_expr),
+                ))
+            }
+            LogicalExpr::ScalarFunction { fun, args } => {
+                let ast_args: Result<Vec<ast::Expr>> =
+                    args.iter().map(Self::logical_to_ast_expr).collect();
+                let ast_args: Vec<Box<ast::Expr>> = ast_args?.into_iter().map(Box::new).collect();
+                Ok(ast::Expr::FunctionCall {
+                    name: ast::Name::Ident(fun.clone()),
+                    distinctness: None,
+                    args: ast_args,
+                    order_by: Vec::new(),
+                    filter_over: ast::FunctionTail {
+                        filter_clause: None,
+                        over_clause: None,
+                    },
+                })
+            }
+            LogicalExpr::Alias { expr, .. } => {
+                // For conversion to AST, ignore the alias and convert the inner expression
+                Self::logical_to_ast_expr(expr)
+            }
+            LogicalExpr::AggregateFunction {
+                fun,
+                args,
+                distinct,
+            } => {
+                // Convert aggregate function to AST
+                let ast_args: Result<Vec<ast::Expr>> =
+                    args.iter().map(Self::logical_to_ast_expr).collect();
+                let ast_args: Vec<Box<ast::Expr>> = ast_args?.into_iter().map(Box::new).collect();
+
+                // Get the function name based on the aggregate type
+                let func_name = match fun {
+                    crate::function::AggFunc::Count => "COUNT",
+                    crate::function::AggFunc::Sum => "SUM",
+                    crate::function::AggFunc::Avg => "AVG",
+                    crate::function::AggFunc::Min => "MIN",
+                    crate::function::AggFunc::Max => "MAX",
+                    _ => {
+                        return Err(LimboError::ParseError(format!(
+                            "Unsupported aggregate function: {fun:?}"
+                        )))
+                    }
+                };
+
+                Ok(ast::Expr::FunctionCall {
+                    name: ast::Name::Ident(func_name.to_string()),
+                    distinctness: if *distinct {
+                        Some(ast::Distinctness::Distinct)
+                    } else {
+                        None
+                    },
+                    args: ast_args,
+                    order_by: Vec::new(),
+                    filter_over: ast::FunctionTail {
+                        filter_clause: None,
+                        over_clause: None,
+                    },
+                })
+            }
+            _ => Err(LimboError::ParseError(format!(
+                "Cannot convert LogicalExpr to AST Expr: {expr:?}"
+            ))),
+        }
+    }
+
+    /// Compile a logical expression to a FilterPredicate for execution
+    fn compile_filter_predicate(expr: &LogicalExpr) -> Result<FilterPredicate> {
+        match expr {
+            LogicalExpr::BinaryExpr { left, op, right } => {
+                // Extract column name and value for simple predicates
+                if let (LogicalExpr::Column(col), LogicalExpr::Literal(val)) =
+                    (left.as_ref(), right.as_ref())
+                {
+                    match op {
+                        BinaryOperator::Equals => Ok(FilterPredicate::Equals {
+                            column: col.name.clone(),
+                            value: val.clone(),
+                        }),
+                        BinaryOperator::NotEquals => Ok(FilterPredicate::NotEquals {
+                            column: col.name.clone(),
+                            value: val.clone(),
+                        }),
+                        BinaryOperator::Greater => Ok(FilterPredicate::GreaterThan {
+                            column: col.name.clone(),
+                            value: val.clone(),
+                        }),
+                        BinaryOperator::GreaterEquals => Ok(FilterPredicate::GreaterThanOrEqual {
+                            column: col.name.clone(),
+                            value: val.clone(),
+                        }),
+                        BinaryOperator::Less => Ok(FilterPredicate::LessThan {
+                            column: col.name.clone(),
+                            value: val.clone(),
+                        }),
+                        BinaryOperator::LessEquals => Ok(FilterPredicate::LessThanOrEqual {
+                            column: col.name.clone(),
+                            value: val.clone(),
+                        }),
+                        BinaryOperator::And => {
+                            // Handle AND of two predicates
+                            let left_pred = Self::compile_filter_predicate(left)?;
+                            let right_pred = Self::compile_filter_predicate(right)?;
+                            Ok(FilterPredicate::And(
+                                Box::new(left_pred),
+                                Box::new(right_pred),
+                            ))
+                        }
+                        BinaryOperator::Or => {
+                            // Handle OR of two predicates
+                            let left_pred = Self::compile_filter_predicate(left)?;
+                            let right_pred = Self::compile_filter_predicate(right)?;
+                            Ok(FilterPredicate::Or(
+                                Box::new(left_pred),
+                                Box::new(right_pred),
+                            ))
+                        }
+                        _ => Err(LimboError::ParseError(format!(
+                            "Unsupported operator in filter: {op:?}"
+                        ))),
+                    }
+                } else if matches!(op, BinaryOperator::And | BinaryOperator::Or) {
+                    // Handle logical operators
+                    let left_pred = Self::compile_filter_predicate(left)?;
+                    let right_pred = Self::compile_filter_predicate(right)?;
+                    match op {
+                        BinaryOperator::And => Ok(FilterPredicate::And(
+                            Box::new(left_pred),
+                            Box::new(right_pred),
+                        )),
+                        BinaryOperator::Or => Ok(FilterPredicate::Or(
+                            Box::new(left_pred),
+                            Box::new(right_pred),
+                        )),
+                        _ => unreachable!(),
+                    }
+                } else {
+                    Err(LimboError::ParseError(
+                        "Filter predicate must be column op value".to_string(),
+                    ))
+                }
+            }
+            _ => Err(LimboError::ParseError(format!(
+                "Unsupported filter expression: {expr:?}"
+            ))),
+        }
+    }
+}
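+
+// Editorial sketch (not part of the original patch): the FilterPredicate that
+// compile_filter_predicate builds for the compound SQL predicate
+// `age > 18 AND name = 'x'`:
+//
+//     FilterPredicate::And(
+//         Box::new(FilterPredicate::GreaterThan {
+//             column: "age".to_string(),
+//             value: Value::Integer(18),
+//         }),
+//         Box::new(FilterPredicate::Equals {
+//             column: "name".to_string(),
+//             value: Value::Text("x".into()),
+//         }),
+//     )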
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::incremental::operator::{Delta, FilterOperator, FilterPredicate};
+    use crate::schema::{BTreeTable, Column as SchemaColumn, Schema, Type};
+    use crate::translate::logical::LogicalPlanBuilder;
+    use crate::translate::logical::LogicalSchema;
+    use std::sync::Arc;
+    use turso_parser::ast;
+    use turso_parser::parser::Parser;
+
+    // Macro to create a test schema with a users table
+    macro_rules! test_schema {
+        () => {{
+            let mut schema = Schema::new(false);
+            let users_table = BTreeTable {
+                name: "users".to_string(),
+                root_page: 2,
+                primary_key_columns: vec![("id".to_string(), turso_parser::ast::SortOrder::Asc)],
+                columns: vec![
+                    SchemaColumn {
+                        name: Some("id".to_string()),
+                        ty: Type::Integer,
+                        ty_str: "INTEGER".to_string(),
+                        primary_key: true,
+                        is_rowid_alias: true,
+                        notnull: true,
+                        default: None,
+                        unique: false,
+                        collation: None,
+                        hidden: false,
+                    },
+                    SchemaColumn {
+                        name: Some("name".to_string()),
+                        ty: Type::Text,
+                        ty_str: "TEXT".to_string(),
+                        primary_key: false,
+                        is_rowid_alias: false,
+                        notnull: false,
+                        default: None,
+                        unique: false,
+                        collation: None,
+                        hidden: false,
+                    },
+                    SchemaColumn {
+                        name: Some("age".to_string()),
+                        ty: Type::Integer,
+                        ty_str: "INTEGER".to_string(),
+                        primary_key: false,
+                        is_rowid_alias: false,
+                        notnull: false,
+                        default: None,
+                        unique: false,
+                        collation: None,
+                        hidden: false,
+                    },
+                ],
+                has_rowid: true,
+                is_strict: false,
+                unique_sets: None,
+            };
+            schema.add_btree_table(Arc::new(users_table));
+            schema
+        }};
+    }
+
+    // Macro to compile SQL to DBSP circuit
+    macro_rules! compile_sql {
+        ($sql:expr) => {{
+            let schema = test_schema!();
+            let mut parser = Parser::new($sql.as_bytes());
+            let cmd = parser
+                .next()
+                .unwrap() // This returns Option<Result<Cmd>>
+                .unwrap(); // This unwraps the Result
+
+            match cmd {
+                ast::Cmd::Stmt(stmt) => {
+                    let mut builder = LogicalPlanBuilder::new(&schema);
+                    let logical_plan = builder.build_statement(&stmt).unwrap();
+                    DbspCompiler::new().compile(&logical_plan).unwrap()
+                }
+                _ => panic!("Only SQL statements are supported"),
+            }
+        }};
+    }
+
+    // Macro to assert circuit structure
+    macro_rules! assert_circuit {
+        ($circuit:expr, depth: $depth:expr, root: $root_type:ident) => {
+            assert_eq!($circuit.nodes.len(), $depth);
+            let node = get_node_at_level(&$circuit, 0);
+            assert!(matches!(node.operator, DbspOperator::$root_type { .. }));
+        };
+    }
+
+    // Macro to assert operator properties
+    macro_rules! assert_operator {
+        ($circuit:expr, $level:expr, Input { name: $name:expr }) => {{
+            let node = get_node_at_level(&$circuit, $level);
+            match &node.operator {
+                DbspOperator::Input { name, .. } => assert_eq!(name, $name),
+                _ => panic!("Expected Input operator at level {}", $level),
+            }
+        }};
+        ($circuit:expr, $level:expr, Filter) => {{
+            let node = get_node_at_level(&$circuit, $level);
+            assert!(matches!(node.operator, DbspOperator::Filter { .. }));
+        }};
+        ($circuit:expr, $level:expr, Projection { columns: [$($col:expr),*] }) => {{
+            let node = get_node_at_level(&$circuit, $level);
+            match &node.operator {
+                DbspOperator::Projection { exprs, .. } => {
+                    let expected_cols = vec![$($col),*];
+                    let actual_cols: Vec<String> = exprs.iter().map(|e| {
+                        match e {
+                            DbspExpr::Column(name) => name.clone(),
+                            _ => "expr".to_string(),
+                        }
+                    }).collect();
+                    assert_eq!(actual_cols, expected_cols);
+                }
+                _ => panic!("Expected Projection operator at level {}", $level),
+            }
+        }};
+    }
+
+    // Macro to assert filter predicate
+    macro_rules! assert_filter_predicate {
+        ($circuit:expr, $level:expr, $col:literal > $val:literal) => {{
+            let node = get_node_at_level(&$circuit, $level);
+            match &node.operator {
+                DbspOperator::Filter { predicate } => match predicate {
+                    DbspExpr::BinaryExpr { left, op, right } => {
+                        assert!(matches!(op, ast::Operator::Greater));
+                        assert!(matches!(&**left, DbspExpr::Column(name) if name == $col));
+                        assert!(matches!(&**right, DbspExpr::Literal(Value::Integer($val))));
+                    }
+                    _ => panic!("Expected binary expression in filter"),
+                },
+                _ => panic!("Expected Filter operator at level {}", $level),
+            }
+        }};
+        ($circuit:expr, $level:expr, $col:literal < $val:literal) => {{
+            let node = get_node_at_level(&$circuit, $level);
+            match &node.operator {
+                DbspOperator::Filter { predicate } => match predicate {
+                    DbspExpr::BinaryExpr { left, op, right } => {
+                        assert!(matches!(op, ast::Operator::Less));
+                        assert!(matches!(&**left, DbspExpr::Column(name) if name == $col));
+                        assert!(matches!(&**right, DbspExpr::Literal(Value::Integer($val))));
+                    }
+                    _ => panic!("Expected binary expression in filter"),
+                },
+                _ => panic!("Expected Filter operator at level {}", $level),
+            }
+        }};
+        ($circuit:expr, $level:expr, $col:literal = $val:literal) => {{
+            let node = get_node_at_level(&$circuit, $level);
+            match &node.operator {
+                DbspOperator::Filter { predicate } => match predicate {
+                    DbspExpr::BinaryExpr { left, op, right } => {
+                        assert!(matches!(op, ast::Operator::Equals));
+                        assert!(matches!(&**left, DbspExpr::Column(name) if name == $col));
+                        assert!(matches!(&**right, DbspExpr::Literal(Value::Integer($val))));
+                    }
+                    _ => panic!("Expected binary expression in filter"),
+                },
+                _ => panic!("Expected Filter operator at level {}", $level),
+            }
+        }};
+    }
+
+    // Helper to get node at specific level from root
+    fn get_node_at_level(circuit: &DbspCircuit, level: usize) -> &DbspNode {
+        let mut current_id = circuit.root.expect("Circuit has no root");
+        for _ in 0..level {
+            let node = circuit.nodes.get(&current_id).expect("Node not found");
+            if node.inputs.is_empty() {
+                panic!("No more levels available, requested level {level}");
+            }
+            current_id = node.inputs[0];
+        }
+        circuit.nodes.get(&current_id).expect("Node not found")
+    }
+
+    // Helper to get the current accumulated state of the circuit (from the root operator)
+    // This returns the internal state including bookkeeping entries
+    fn get_current_state(circuit: &DbspCircuit) -> Result<Delta> {
+        if let Some(root_id) = circuit.root {
+            let node = circuit
+                .nodes
+                .get(&root_id)
+                .ok_or_else(|| LimboError::ParseError("Root node not found".to_string()))?;
+
+            if let Some(ref executable) = node.executable {
+                Ok(executable.get_current_state())
+            } else {
+                // Input nodes don't have executables but also don't have state
+                Ok(Delta::new())
+            }
+        } else {
+            Err(LimboError::ParseError(
+                "Circuit has no root node".to_string(),
+            ))
+        }
+    }
+
+    // Helper to create a DeltaSet from a HashMap (for tests)
+    fn delta_set_from_map(map: HashMap<String, Delta>) -> DeltaSet {
+        let mut delta_set = DeltaSet::new();
+        for (key, value) in map {
+            delta_set.insert(key, value);
+        }
+        delta_set
+    }
+
+    #[test]
+    fn test_simple_projection() {
+        let circuit = compile_sql!("SELECT name FROM users");
+
+        // Circuit has 2 nodes with Projection at root
+        assert_circuit!(circuit, depth: 2, root: Projection);
+
+        // Verify operators at each level
+        assert_operator!(circuit, 0, Projection { columns: ["name"] });
+        assert_operator!(circuit, 1, Input { name: "users" });
+    }
+
+    #[test]
+    fn test_filter_with_projection() {
+        let circuit = compile_sql!("SELECT name FROM users WHERE age > 18");
+
+        // Circuit has 3 nodes with Projection at root
+        assert_circuit!(circuit, depth: 3, root: Projection);
+
+        // Verify operators at each level
+        assert_operator!(circuit, 0, Projection { columns: ["name"] });
+        assert_operator!(circuit, 1, Filter);
+        assert_filter_predicate!(circuit, 1, "age" > 18);
+        assert_operator!(circuit, 2, Input { name: "users" });
+    }
+
+    #[test]
+    fn test_select_star() {
+        let mut circuit = compile_sql!("SELECT * FROM users");
+
+        // Create test data
+        let mut input_delta = Delta::new();
+        input_delta.insert(
+            1,
+            vec![
+                Value::Integer(1),
+                Value::Text("Alice".into()),
+                Value::Integer(25),
+            ],
+        );
+        input_delta.insert(
+            2,
+            vec![
+                Value::Integer(2),
+                Value::Text("Bob".into()),
+                Value::Integer(17),
+            ],
+        );
+
+        // Create input map
+        let mut inputs = HashMap::new();
+        inputs.insert("users".to_string(), input_delta);
+
+        // Initialize circuit with initial data
+        let result = circuit.initialize(inputs).unwrap();
+
+        // Should have all rows with all columns
+        assert_eq!(result.changes.len(), 2);
+
+        // Verify both rows are present with all columns
+        for (row, weight) in &result.changes {
+            assert_eq!(*weight, 1);
+            assert_eq!(row.values.len(), 3); // id, name, age
+        }
+    }
+
+    #[test]
+    fn test_execute_filter() {
+        let mut circuit = compile_sql!("SELECT * FROM users WHERE age > 18");
+
+        // Create test data
+        let mut input_delta = Delta::new();
+        input_delta.insert(
+            1,
+            vec![
+                Value::Integer(1),
+                Value::Text("Alice".into()),
+                Value::Integer(25),
+            ],
+        );
+        input_delta.insert(
+            2,
+            vec![
+                Value::Integer(2),
+                Value::Text("Bob".into()),
+                Value::Integer(17),
+            ],
+        );
+        input_delta.insert(
+            3,
+            vec![
+                Value::Integer(3),
+                Value::Text("Charlie".into()),
+                Value::Integer(30),
+            ],
+        );
+
+        // Create input map
+        let mut inputs = HashMap::new();
+        inputs.insert("users".to_string(), input_delta);
+
+        // Initialize circuit with initial data
+        let result = circuit.initialize(inputs).unwrap();
+
+        // Should only have Alice and Charlie (age > 18)
+        assert_eq!(
+            result.changes.len(),
+            2,
+            "Expected 2 rows after filtering, got {}",
+            result.changes.len()
+        );
+
+        // Check that the filtered rows are correct
+        let names: Vec<String> = result
+            .changes
+            .iter()
+            .filter_map(|(row, weight)| {
+                if *weight > 0 && row.values.len() > 1 {
+                    if let Value::Text(name) = &row.values[1] {
+                        Some(name.to_string())
+                    } else {
+                        None
+                    }
+                } else {
+                    None
+                }
+            })
+            .collect();
+
+        assert!(
+            names.contains(&"Alice".to_string()),
+            "Alice should be in results"
+        );
+        assert!(
+            names.contains(&"Charlie".to_string()),
+            "Charlie should be in results"
+        );
+        assert!(
+            !names.contains(&"Bob".to_string()),
+            "Bob should not be in results"
+        );
+    }
+
+    #[test]
+    fn test_simple_column_projection() {
+        let mut circuit = compile_sql!("SELECT name, age FROM users");
+
+        // Create test data
+        let mut input_delta = Delta::new();
+        input_delta.insert(
+            1,
+            vec![
+                Value::Integer(1),
+                Value::Text("Alice".into()),
+                Value::Integer(25),
+            ],
+        );
+        input_delta.insert(
+            2,
+            vec![
+                Value::Integer(2),
+                Value::Text("Bob".into()),
+                Value::Integer(17),
+            ],
+        );
+
+        // Create input map
+        let mut inputs = HashMap::new();
+        inputs.insert("users".to_string(), input_delta);
+
+        // Initialize circuit with initial data
+        let result = circuit.initialize(inputs).unwrap();
+
+        // Should have all rows but only 2 columns (name, age)
+        assert_eq!(result.changes.len(), 2);
+
+        for (row, _) in &result.changes {
+            assert_eq!(row.values.len(), 2); // Only name and age
+            // First value should be name (Text)
+            assert!(matches!(&row.values[0], Value::Text(_)));
+            // Second value should be age (Integer)
+            assert!(matches!(&row.values[1], Value::Integer(_)));
+        }
+    }
+
+    #[test]
+    fn test_simple_aggregation() {
+        // Test COUNT(*) with GROUP BY
+        let mut circuit = compile_sql!("SELECT age, COUNT(*) FROM users GROUP BY age");
+
+        // Create test data
+        let mut input_delta = Delta::new();
+        input_delta.insert(
+            1,
+            vec![
+                Value::Integer(1),
+                Value::Text("Alice".into()),
+                Value::Integer(25),
+            ],
+        );
+        input_delta.insert(
+            2,
+            vec![
+                Value::Integer(2),
+                Value::Text("Bob".into()),
+                Value::Integer(25),
+            ],
+        );
+        input_delta.insert(
+            3,
+            vec![
+                Value::Integer(3),
+                Value::Text("Charlie".into()),
+                Value::Integer(30),
+            ],
+        );
+
+        // Create input map
+        let mut inputs = HashMap::new();
+        inputs.insert("users".to_string(), input_delta);
+
+        // Initialize circuit with initial data
+        let result = circuit.initialize(inputs).unwrap();
+
+        // Should have 2 groups: age 25 with count 2, age 30 with count 1
+        assert_eq!(result.changes.len(), 2);
+
+        // Check the results
+        let mut found_25 = false;
+        let mut found_30 = false;
+
+        for (row, weight) in &result.changes {
+            assert_eq!(*weight, 1);
+            assert_eq!(row.values.len(), 2); // age, count
+
+            if let (Value::Integer(age), Value::Integer(count)) = (&row.values[0], &row.values[1]) {
+                if *age == 25 {
+                    assert_eq!(*count, 2, "Age 25 should have count 2");
+                    found_25 = true;
+                } else if *age == 30 {
+                    assert_eq!(*count, 1, "Age 30 should have count 1");
+                    found_30 = true;
+                }
+            }
+        }
+
+        assert!(found_25, "Should have group for age 25");
+        assert!(found_30, "Should have group for age 30");
+    }
+
+    #[test]
+    fn test_sum_aggregation() {
+        // Test SUM with GROUP BY
+        let mut circuit = compile_sql!("SELECT name, SUM(age) FROM users GROUP BY name");
+
+        // Create test data - some names appear multiple times
+        let mut input_delta = Delta::new();
+        input_delta.insert(
+            1,
+            vec![
+                Value::Integer(1),
+                Value::Text("Alice".into()),
+                Value::Integer(25),
+            ],
+        );
+        input_delta.insert(
+            2,
+            vec![
+                Value::Integer(2),
+                Value::Text("Alice".into()),
+                Value::Integer(30),
+            ],
+        );
+        input_delta.insert(
+            3,
+            vec![
+                Value::Integer(3),
+                Value::Text("Bob".into()),
+                Value::Integer(20),
+            ],
+        );
+
+        // Create input map
+        let mut inputs = HashMap::new();
+        inputs.insert("users".to_string(), input_delta);
+
+        // Initialize circuit with initial data
+        let result = circuit.initialize(inputs).unwrap();
+
+        // Should have 2 groups: Alice with sum 55, Bob with sum 20
+        assert_eq!(result.changes.len(), 2);
+
+        for (row, weight) in &result.changes {
+            assert_eq!(*weight, 1);
+            assert_eq!(row.values.len(), 2); // name, sum
+
+            if let (Value::Text(name), Value::Float(sum)) = (&row.values[0], &row.values[1]) {
+                if name.as_str() == "Alice" {
+                    assert_eq!(*sum, 55.0, "Alice should have sum 55");
+                } else if name.as_str() == "Bob" {
+                    assert_eq!(*sum, 20.0, "Bob should have sum 20");
+                }
+            }
+        }
+    }
+
+    #[test]
+    fn test_aggregation_without_group_by() {
+        // Test aggregation without GROUP BY - should produce a single row
+        let mut circuit = compile_sql!("SELECT COUNT(*), SUM(age), AVG(age) FROM users");
+
+        // Create test data
+        let mut input_delta = Delta::new();
+        input_delta.insert(
+            1,
+            vec![
+                Value::Integer(1),
+                Value::Text("Alice".into()),
+                Value::Integer(25),
+            ],
+        );
+        input_delta.insert(
+            2,
+            vec![
+                Value::Integer(2),
+                Value::Text("Bob".into()),
+                Value::Integer(30),
+            ],
+        );
+        input_delta.insert(
+            3,
+            vec![
+                Value::Integer(3),
+                Value::Text("Charlie".into()),
+                Value::Integer(20),
+            ],
+        );
+
+        // Create input map
+        let mut inputs = HashMap::new();
+        inputs.insert("users".to_string(), input_delta);
+
+        // Initialize circuit with initial data
+        let result = circuit.initialize(inputs).unwrap();
+
+        // Should have exactly 1 row with all aggregates
+        assert_eq!(
+            result.changes.len(),
+            1,
+            "Should have exactly one result row"
+        );
+
+        let (row, weight) = result.changes.first().unwrap();
+        assert_eq!(*weight, 1);
+        assert_eq!(row.values.len(), 3); // count, sum, avg
+
+        // Check aggregate results
+        // COUNT should be Integer
+        if let Value::Integer(count) = &row.values[0] {
+            assert_eq!(*count, 3, "COUNT(*) should be 3");
+        } else {
+            panic!("COUNT should be Integer, got {:?}", row.values[0]);
+        }
+
+        // SUM can be Integer (if whole number) or Float
+        match &row.values[1] {
+            Value::Integer(sum) => assert_eq!(*sum, 75, "SUM(age) should be 75"),
+            Value::Float(sum) => assert_eq!(*sum, 75.0, "SUM(age) should be 75.0"),
+            other => panic!("SUM should be Integer or Float, got {other:?}"),
+        }
+
+        // AVG should be Float
+        if let Value::Float(avg) = &row.values[2] {
+            assert_eq!(*avg, 25.0, "AVG(age) should be 25.0");
+        } else {
+            panic!("AVG should be Float, got {:?}", row.values[2]);
+        }
+    }
+
+    #[test]
+    fn test_expression_projection_execution() {
+        // Test that complex expressions work through VDBE compilation
+        let mut circuit = compile_sql!("SELECT hex(id) FROM users");
+
+        // Create test data
+        let mut input_delta = Delta::new();
+        input_delta.insert(
+            1,
+            vec![
+                Value::Integer(1),
+                Value::Text("Alice".into()),
+                Value::Integer(25),
+            ],
+        );
+        input_delta.insert(
+            2,
+            vec![
+                Value::Integer(255),
+                Value::Text("Bob".into()),
+                Value::Integer(17),
+            ],
+        );
+
+        // Create input map
+        let mut inputs = HashMap::new();
+        inputs.insert("users".to_string(), input_delta);
+
+        // Initialize circuit with initial data
+        let result = circuit.initialize(inputs).unwrap();
+
+        assert_eq!(result.changes.len(), 2);
+
+        let hex_values: HashMap<i64, String> = result
+            .changes
+            .iter()
+            .map(|(row, _)| {
+                let rowid = row.rowid;
+                if let Value::Text(text) = &row.values[0] {
+                    (rowid, text.to_string())
+                } else {
+                    panic!("Expected Text value for hex() result");
+                }
+            })
+            .collect();
+
+        assert_eq!(
+            hex_values.get(&1).unwrap(),
+            "31",
+            "hex(1) should return '31' (hex of ASCII '1')"
+        );
+
+        assert_eq!(
+            hex_values.get(&2).unwrap(),
+            "323535",
+            "hex(255) should return '323535' (hex of ASCII '2', '5', '5')"
+        );
+    }
+
+    // TODO: This test currently fails on incremental updates.
+    // The initial execution works correctly, but incremental updates produce
+    // incorrect results (3 changes instead of 2, with wrong values).
+    // This tests that the aggregate operator correctly handles incremental
+    // updates when it's sandwiched between projection operators.
+    #[test]
+    fn test_projection_aggregation_projection_pattern() {
+        // Test pattern: projection -> aggregation -> projection
+        // Query: SELECT HEX(SUM(age + 2)) FROM users
+        let mut circuit = compile_sql!("SELECT HEX(SUM(age + 2)) FROM users");
+
+        // Initial input data
+        let mut input_delta = Delta::new();
+        input_delta.insert(
+            1,
+            vec![
+                Value::Integer(1),
+                Value::Text("Alice".to_string().into()),
+                Value::Integer(25),
+            ],
+        );
+        input_delta.insert(
+            2,
+            vec![
+                Value::Integer(2),
+                Value::Text("Bob".to_string().into()),
+                Value::Integer(30),
+            ],
+        );
+        input_delta.insert(
+            3,
+            vec![
+                Value::Integer(3),
+                Value::Text("Charlie".to_string().into()),
+                Value::Integer(35),
+            ],
+        );
+
+        let mut input_data = HashMap::new();
+        input_data.insert("users".to_string(), input_delta);
+
+        // Initialize the circuit with the initial data
+        let result = circuit.initialize(input_data).unwrap();
+
+        // Expected: SUM(age + 2) = (25+2) + (30+2) + (35+2) = 27 + 32 + 37 = 96
+        // HEX(96) should be the hex representation of the string "96" = "3936"
+        assert_eq!(result.changes.len(), 1);
+        let (row, _weight) = &result.changes[0];
+        assert_eq!(row.values.len(), 1);
+
+        // The hex function converts the number to string first, then to hex
+        // 96 as string is "96", which in hex is "3936" (hex of ASCII '9' and '6')
+        assert_eq!(
+            row.values[0],
+            Value::Text("3936".to_string().into()),
+            "HEX(SUM(age + 2)) should return '3936' for sum of 96"
+        );
+
+        // Test incremental update: add a new user
+        let mut input_delta = Delta::new();
+        input_delta.insert(
+            4,
+            vec![
+                Value::Integer(4),
+                Value::Text("David".to_string().into()),
+                Value::Integer(40),
+            ],
+        );
+
+        let mut input_data = HashMap::new();
+        input_data.insert("users".to_string(), input_delta);
+
+        let result = circuit.execute(input_data, DeltaSet::empty()).unwrap();
+
+        // Expected: new SUM(age + 2) = 96 + (40+2) = 138
+        // HEX(138) = hex of "138" = "313338"
+        assert_eq!(result.changes.len(), 2);
+
+        // First change: remove old aggregate (96)
+        let (row, weight) = &result.changes[0];
+        assert_eq!(*weight, -1);
+        assert_eq!(row.values[0], Value::Text("3936".to_string().into()));
+
+        // Second change: add new aggregate (138)
+        let (row, weight) = &result.changes[1];
+        assert_eq!(*weight, 1);
+        assert_eq!(
+            row.values[0],
+            Value::Text("313338".to_string().into()),
+            "HEX(SUM(age + 2)) should return '313338' for sum of 138"
+        );
+    }
+
+    #[test]
+    fn test_nested_projection_with_groupby() {
+        // Test pattern: projection -> aggregation with GROUP BY -> projection
+        // Query: SELECT name, HEX(SUM(age * 2)) FROM users GROUP BY name
+        let mut circuit = compile_sql!("SELECT name, HEX(SUM(age * 2)) FROM users GROUP BY name");
+
+        // Initial input data
+        let mut input_delta = Delta::new();
+        input_delta.insert(
+            1,
+            vec![
+                Value::Integer(1),
+                Value::Text("Alice".to_string().into()),
+                Value::Integer(25),
+            ],
+        );
+        input_delta.insert(
+            2,
+            vec![
+                Value::Integer(2),
+                Value::Text("Bob".to_string().into()),
+                Value::Integer(30),
+            ],
+        );
+        input_delta.insert(
+            3,
+            vec![
+                Value::Integer(3),
+                Value::Text("Alice".to_string().into()),
+                Value::Integer(35),
+            ],
+        );
+
+        let mut input_data = HashMap::new();
+        input_data.insert("users".to_string(), input_delta);
+
+        // Initialize circuit with initial data
+        let result = circuit.initialize(input_data).unwrap();
+
+        // Expected results:
+        // Alice: SUM(25*2 + 35*2) = 50 + 70 = 120, HEX("120") = "313230"
+        // Bob: SUM(30*2) = 60, HEX("60") = "3630"
+        assert_eq!(result.changes.len(), 2);
+
+        let results: HashMap<String, String> = result
+            .changes
+            .iter()
+            .map(|(row, _weight)| {
+                let name = match &row.values[0] {
+                    Value::Text(t) => t.to_string(),
+                    _ => panic!("Expected text for name"),
+                };
+                let hex_sum = match &row.values[1] {
+                    Value::Text(t) => t.to_string(),
+                    _ => panic!("Expected text for hex value"),
+                };
+                (name, hex_sum)
+            })
+            .collect();
+
+        assert_eq!(
+            results.get("Alice").unwrap(),
+            "313230",
+            "Alice's HEX(SUM(age * 2)) should be '313230' (120)"
+        );
+        assert_eq!(
+            results.get("Bob").unwrap(),
+            "3630",
+            "Bob's HEX(SUM(age * 2)) should be '3630' (60)"
+        );
+    }
+
+    #[test]
+    fn test_transaction_context() {
+        // Test that uncommitted changes are visible within a transaction
+        // but don't affect the operator's internal state
+        let mut circuit = compile_sql!("SELECT * FROM users WHERE age > 18");
+
+        // Initialize with some data
+        let mut init_data = HashMap::new();
+        let mut delta = Delta::new();
+        delta.insert(
+            1,
+            vec![
+                Value::Integer(1),
+                Value::Text("Alice".into()),
+                Value::Integer(25),
+            ],
+        );
+        delta.insert(
+            2,
+            vec![
+                Value::Integer(2),
+                Value::Text("Bob".into()),
+                Value::Integer(17),
+            ],
+        );
+        init_data.insert("users".to_string(), delta);
+
+        circuit.initialize(init_data).unwrap();
+
+        // Verify initial state: only Alice (age > 18)
+        let state = get_current_state(&circuit).unwrap();
+        assert_eq!(state.changes.len(), 1);
+        assert_eq!(state.changes[0].0.values[1], Value::Text("Alice".into()));
+
+        // Create uncommitted changes that would be visible in a transaction
+        let mut uncommitted = HashMap::new();
+        let mut uncommitted_delta = Delta::new();
+        // Add Charlie (age 30) - should be visible in transaction
+        uncommitted_delta.insert(
+            3,
+            vec![
+                Value::Integer(3),
+                Value::Text("Charlie".into()),
+                Value::Integer(30),
+            ],
+        );
+        // Add David (age 15) - should NOT be visible (filtered out)
+        uncommitted_delta.insert(
+            4,
+            vec![
+                Value::Integer(4),
+                Value::Text("David".into()),
+                Value::Integer(15),
+            ],
+        );
+        uncommitted.insert("users".to_string(), uncommitted_delta);
+
+        // Execute with uncommitted data - this simulates processing the uncommitted changes
+        // through the circuit to see what would be visible
+        let tx_result = circuit
+            .execute(HashMap::new(), delta_set_from_map(uncommitted.clone()))
+            .unwrap();
+
+        // The result should show Charlie being added (passes filter, age > 18)
+        // David is filtered out (age 15 < 18)
+        assert_eq!(tx_result.changes.len(), 1, "Should see Charlie added");
+        assert_eq!(
+            tx_result.changes[0].0.values[1],
+            Value::Text("Charlie".into())
+        );
+
+        // Now actually commit Charlie (without uncommitted context)
+        let mut commit_data = HashMap::new();
+        let mut commit_delta = Delta::new();
+        commit_delta.insert(
+            3,
+            vec![
+                Value::Integer(3),
+                Value::Text("Charlie".into()),
+                Value::Integer(30),
+            ],
+        );
+        commit_data.insert("users".to_string(), commit_delta);
+
+        let commit_result = circuit
+            .execute(commit_data.clone(), DeltaSet::empty())
+            .unwrap();
+
+        // The commit result should show Charlie being added
+        assert_eq!(commit_result.changes.len(), 1, "Should see Charlie added");
+        assert_eq!(
+            commit_result.changes[0].0.values[1],
+            Value::Text("Charlie".into())
+        );
+
+        // Commit the change to make it permanent
+        circuit.commit(commit_data).unwrap();
+
+        // Now if we execute again with no changes, we should see no delta
+        let empty_result = circuit.execute(HashMap::new(), DeltaSet::empty()).unwrap();
+        assert_eq!(empty_result.changes.len(), 0, "No changes when no new data");
+    }
+
+    #[test]
+    fn test_uncommitted_delete() {
+        // Test that uncommitted deletes are handled correctly without affecting operator state
+        let mut circuit = compile_sql!("SELECT * FROM users WHERE age > 18");
+
+        // Initialize with some data
+        let mut init_data = HashMap::new();
+        let mut delta = Delta::new();
+        delta.insert(
+            1,
+            vec![
+                Value::Integer(1),
+                Value::Text("Alice".into()),
+                Value::Integer(25),
+            ],
+        );
+        delta.insert(
+            2,
+            vec![
+                Value::Integer(2),
+                Value::Text("Bob".into()),
+                Value::Integer(30),
+            ],
+        );
+        delta.insert(
+            3,
+            vec![
+                Value::Integer(3),
+                Value::Text("Charlie".into()),
+                Value::Integer(20),
+            ],
+        );
+        init_data.insert("users".to_string(), delta);
+
+        circuit.initialize(init_data).unwrap();
+
+        // Verify initial state: Alice, Bob, Charlie (all age > 18)
+        let state = get_current_state(&circuit).unwrap();
+        assert_eq!(state.changes.len(), 3);
+
+        // Create uncommitted delete for Bob
+        let mut uncommitted = HashMap::new();
+        let mut uncommitted_delta = Delta::new();
+        uncommitted_delta.delete(
+            2,
+            vec![
+                Value::Integer(2),
+                Value::Text("Bob".into()),
+                Value::Integer(30),
+            ],
+        );
+        uncommitted.insert("users".to_string(), uncommitted_delta);
+
+        // Execute with uncommitted delete
+        let tx_result = circuit
+            .execute(HashMap::new(), delta_set_from_map(uncommitted.clone()))
+            .unwrap();
+
+        // Result should show the deleted row that passed the filter
+        assert_eq!(
+            tx_result.changes.len(),
+            1,
+            "Should see the uncommitted delete"
+        );
+
+        // Verify operator's internal state is unchanged (still has all 3 users)
+        let state_after = get_current_state(&circuit).unwrap();
+        assert_eq!(
+            state_after.changes.len(),
+            3,
+            "Internal state should still have all 3 users"
+        );
+
+        // Now actually commit the delete
+        let mut commit_data = HashMap::new();
+        let mut commit_delta = Delta::new();
+        commit_delta.delete(
+            2,
+            vec![
+                Value::Integer(2),
+                Value::Text("Bob".into()),
+                Value::Integer(30),
+            ],
+        );
+        commit_data.insert("users".to_string(), commit_delta);
+
+        let commit_result = circuit
+            .execute(commit_data.clone(), DeltaSet::empty())
+            .unwrap();
+
+        // Actually commit the delete to update operator state
+        circuit.commit(commit_data).unwrap();
+
+        // The commit result should show Bob being deleted
+        assert_eq!(commit_result.changes.len(), 1, "Should see Bob deleted");
+        assert_eq!(
+            commit_result.changes[0].1, -1,
+            "Delete should have weight -1"
+        );
+        assert_eq!(
+            commit_result.changes[0].0.values[1],
+            Value::Text("Bob".into())
+        );
+
+        // After commit, internal state should have only Alice and Charlie
+        let final_state = get_current_state(&circuit).unwrap();
+        assert_eq!(
+            final_state.changes.len(),
+            2,
+            "After commit, should have Alice and Charlie"
+        );
+
+        let names: Vec<String> = final_state
+            .changes
+            .iter()
+            .map(|(row, _)| {
+                if let Value::Text(name) = &row.values[1] {
+                    name.to_string()
+                } else {
+                    panic!("Expected text value");
+                }
+            })
+            .collect();
+        assert!(names.contains(&"Alice".to_string()));
+        assert!(names.contains(&"Charlie".to_string()));
+        assert!(!names.contains(&"Bob".to_string()));
+    }
+
+    #[test]
+    fn test_uncommitted_update() {
+        // Test that uncommitted updates (delete + insert) are handled correctly
+        let mut circuit = compile_sql!("SELECT * FROM users WHERE age > 18");
+
+        // Initialize with some data
+        let mut init_data = HashMap::new();
+        let mut delta = Delta::new();
+        delta.insert(
+            1,
+            vec![
+                Value::Integer(1),
+                Value::Text("Alice".into()),
+                Value::Integer(25),
+            ],
+        );
+        delta.insert(
+            2,
+            vec![
+                Value::Integer(2),
+                Value::Text("Bob".into()),
+                Value::Integer(17),
+            ],
+        ); // Bob is 17, filtered out
+        init_data.insert("users".to_string(), delta);
+
+        circuit.initialize(init_data).unwrap();
+
+        // Create uncommitted update: Bob turns 19 (update from 17 to 19)
+        // This is modeled as delete + insert
+        let mut uncommitted = HashMap::new();
+        let mut uncommitted_delta = Delta::new();
+        uncommitted_delta.delete(
+            2,
+            vec![
+                Value::Integer(2),
+                Value::Text("Bob".into()),
+                Value::Integer(17),
+            ],
+        );
+        uncommitted_delta.insert(
+            2,
+            vec![
+                Value::Integer(2),
+                Value::Text("Bob".into()),
+                Value::Integer(19),
+            ],
+        );
+        uncommitted.insert("users".to_string(), uncommitted_delta);
+
+        // Execute with uncommitted update
+        let tx_result = circuit
+            .execute(HashMap::new(), delta_set_from_map(uncommitted.clone()))
+            .unwrap();
+
+        // Bob should now appear in the result (age 19 > 18)
+        // Consolidate to see the final state
+        let mut final_result = tx_result;
+        final_result.consolidate();
+
+        assert_eq!(final_result.changes.len(), 1, "Bob should now be in view");
+        assert_eq!(
+            final_result.changes[0].0.values[1],
+            Value::Text("Bob".into())
+        );
+        assert_eq!(final_result.changes[0].0.values[2], Value::Integer(19));
+
+        // Now actually commit the update
+        let mut commit_data = HashMap::new();
+        let mut commit_delta = Delta::new();
+        commit_delta.delete(
+            2,
+            vec![
+                Value::Integer(2),
+                Value::Text("Bob".into()),
+                Value::Integer(17),
+            ],
+        );
+        commit_delta.insert(
+            2,
+            vec![
+                Value::Integer(2),
+                Value::Text("Bob".into()),
+                Value::Integer(19),
+            ],
+        );
+        commit_data.insert("users".to_string(), commit_delta);
+
+        // Commit the update
+        circuit.commit(commit_data).unwrap();
+
+        // After committing, Bob should be in the view's state
+        let state = get_current_state(&circuit).unwrap();
+        let mut consolidated_state = state;
+        consolidated_state.consolidate();
+
+        // Should have both Alice and Bob now
+        assert_eq!(
+            consolidated_state.changes.len(),
+            2,
+            "Should have Alice and Bob"
+        );
+
+        let names: Vec<String> = consolidated_state
+            .changes
+            .iter()
+            .map(|(row, _)| {
+                if let Value::Text(name) = &row.values[1] {
+                    name.as_str().to_string()
+                } else {
+                    panic!("Expected text value");
+                }
+            })
+            .collect();
+        assert!(names.contains(&"Alice".to_string()));
+        assert!(names.contains(&"Bob".to_string()));
+    }
+
+    #[test]
+    fn test_uncommitted_filtered_delete() {
+        // Test deleting a row that doesn't pass the filter
+        let mut circuit = compile_sql!("SELECT * FROM users WHERE age > 18");
+
+        // Initialize with mixed data
+        let mut init_data = HashMap::new();
+        let mut delta = Delta::new();
+        delta.insert(
+            1,
+            vec![
+                Value::Integer(1),
+                Value::Text("Alice".into()),
+                Value::Integer(25),
+            ],
+        );
+        delta.insert(
+            2,
+            vec![
+                Value::Integer(2),
+                Value::Text("Bob".into()),
+                Value::Integer(15),
+            ],
+        ); // Bob doesn't pass filter
+        init_data.insert("users".to_string(), delta);
+
+        circuit.initialize(init_data).unwrap();
+
+        // Create uncommitted delete for Bob (who isn't in the view because age=15)
+        let mut uncommitted = HashMap::new();
+        let mut uncommitted_delta = Delta::new();
+        uncommitted_delta.delete(
+            2,
+            vec![
+                Value::Integer(2),
+                Value::Text("Bob".into()),
+                Value::Integer(15),
+            ],
+        );
+        uncommitted.insert("users".to_string(), uncommitted_delta);
+
+        // Execute with uncommitted delete - should produce no output changes
+        let tx_result = circuit
+            .execute(HashMap::new(), delta_set_from_map(uncommitted))
+            .unwrap();
+
produces no output
+        assert_eq!(
+            tx_result.changes.len(),
+            0,
+            "Deleting filtered row produces no changes"
+        );
+
+        // The view state should still only have Alice
+        let state = get_current_state(&circuit).unwrap();
+        assert_eq!(state.changes.len(), 1, "View still has only Alice");
+        assert_eq!(state.changes[0].0.values[1], Value::Text("Alice".into()));
+    }
+
+    #[test]
+    fn test_uncommitted_mixed_operations() {
+        // Test multiple uncommitted operations together
+        let mut circuit = compile_sql!("SELECT * FROM users WHERE age > 18");
+
+        // Initialize with some data
+        let mut init_data = HashMap::new();
+        let mut delta = Delta::new();
+        delta.insert(
+            1,
+            vec![
+                Value::Integer(1),
+                Value::Text("Alice".into()),
+                Value::Integer(25),
+            ],
+        );
+        delta.insert(
+            2,
+            vec![
+                Value::Integer(2),
+                Value::Text("Bob".into()),
+                Value::Integer(30),
+            ],
+        );
+        init_data.insert("users".to_string(), delta);
+
+        circuit.initialize(init_data).unwrap();
+
+        // Verify initial state
+        let state = get_current_state(&circuit).unwrap();
+        assert_eq!(state.changes.len(), 2);
+
+        // Create uncommitted changes:
+        // - Delete Alice
+        // - Update Bob's age to 35
+        // - Insert Charlie (age 40)
+        // - Insert David (age 16, filtered out)
+        let mut uncommitted = HashMap::new();
+        let mut uncommitted_delta = Delta::new();
+        // Delete Alice
+        uncommitted_delta.delete(
+            1,
+            vec![
+                Value::Integer(1),
+                Value::Text("Alice".into()),
+                Value::Integer(25),
+            ],
+        );
+        // Update Bob (delete + insert)
+        uncommitted_delta.delete(
+            2,
+            vec![
+                Value::Integer(2),
+                Value::Text("Bob".into()),
+                Value::Integer(30),
+            ],
+        );
+        uncommitted_delta.insert(
+            2,
+            vec![
+                Value::Integer(2),
+                Value::Text("Bob".into()),
+                Value::Integer(35),
+            ],
+        );
+        // Insert Charlie
+        uncommitted_delta.insert(
+            3,
+            vec![
+                Value::Integer(3),
+                Value::Text("Charlie".into()),
+                Value::Integer(40),
+            ],
+        );
+        // Insert David (will be filtered)
+        uncommitted_delta.insert(
+            4,
+            vec![
+                Value::Integer(4),
+                Value::Text("David".into()),
+                Value::Integer(16),
+            ],
+        );
+        uncommitted.insert("users".to_string(), uncommitted_delta);
+
+        // Execute with uncommitted changes
+        let tx_result = circuit
+            .execute(HashMap::new(), delta_set_from_map(uncommitted.clone()))
+            .unwrap();
+
+        // Result should show four changes: delete Alice, delete Bob (age 30),
+        // insert Bob (age 35), and insert Charlie; David (age 16) is filtered out
+        assert_eq!(
+            tx_result.changes.len(),
+            4,
+            "Should see all uncommitted mixed operations"
+        );
+
+        // Verify operator's internal state is unchanged
+        let state_after = get_current_state(&circuit).unwrap();
+        assert_eq!(state_after.changes.len(), 2, "Still has Alice and Bob");
+
+        // Commit all changes
+        let mut commit_data = HashMap::new();
+        let mut commit_delta = Delta::new();
+        commit_delta.delete(
+            1,
+            vec![
+                Value::Integer(1),
+                Value::Text("Alice".into()),
+                Value::Integer(25),
+            ],
+        );
+        commit_delta.delete(
+            2,
+            vec![
+                Value::Integer(2),
+                Value::Text("Bob".into()),
+                Value::Integer(30),
+            ],
+        );
+        commit_delta.insert(
+            2,
+            vec![
+                Value::Integer(2),
+                Value::Text("Bob".into()),
+                Value::Integer(35),
+            ],
+        );
+        commit_delta.insert(
+            3,
+            vec![
+                Value::Integer(3),
+                Value::Text("Charlie".into()),
+                Value::Integer(40),
+            ],
+        );
+        commit_delta.insert(
+            4,
+            vec![
+                Value::Integer(4),
+                Value::Text("David".into()),
+                Value::Integer(16),
+            ],
+        );
+        commit_data.insert("users".to_string(), commit_delta);
+
+        let commit_result = circuit
+            .execute(commit_data.clone(), DeltaSet::empty())
+            .unwrap();
+
+        // Should see: Alice deleted, Bob deleted, Bob inserted,
Charlie inserted + // (David filtered out) + assert_eq!(commit_result.changes.len(), 4, "Should see 4 changes"); + + // Actually commit the changes to update operator state + circuit.commit(commit_data).unwrap(); + + // After all commits, execute with no changes should return empty delta + let empty_result = circuit.execute(HashMap::new(), DeltaSet::empty()).unwrap(); + assert_eq!(empty_result.changes.len(), 0, "No changes when no new data"); + } + + #[test] + fn test_uncommitted_aggregation() { + // Test that aggregations work correctly with uncommitted changes + // This tests the specific scenario where a transaction adds new data + // and we need to see correct aggregation results within the transaction + + // Create a sales table schema for testing + let mut schema = Schema::new(false); + let sales_table = BTreeTable { + name: "sales".to_string(), + root_page: 2, + primary_key_columns: vec![], + columns: vec![ + SchemaColumn { + name: Some("product_id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("amount".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + ], + has_rowid: true, + is_strict: false, + unique_sets: None, + }; + schema.add_btree_table(Arc::new(sales_table)); + + // Parse and compile the aggregation query + let sql = "SELECT product_id, SUM(amount) as total, COUNT(*) as cnt FROM sales GROUP BY product_id"; + let mut parser = Parser::new(sql.as_bytes()); + let cmd = parser.next().unwrap().unwrap(); + + let mut circuit = match cmd { + ast::Cmd::Stmt(stmt) => { + let mut builder = LogicalPlanBuilder::new(&schema); + let logical_plan = builder.build_statement(&stmt).unwrap(); + DbspCompiler::new().compile(&logical_plan).unwrap() + } + _ => panic!("Expected SQL statement"), + }; + + // Initialize with base data: (1, 100), (1, 200), (2, 150), (2, 250) + let mut init_data = HashMap::new(); + let mut delta = Delta::new(); + delta.insert(1, vec![Value::Integer(1), Value::Integer(100)]); + delta.insert(2, vec![Value::Integer(1), Value::Integer(200)]); + delta.insert(3, vec![Value::Integer(2), Value::Integer(150)]); + delta.insert(4, vec![Value::Integer(2), Value::Integer(250)]); + init_data.insert("sales".to_string(), delta); + + circuit.initialize(init_data).unwrap(); + + // Verify initial state: product 1 total=300, product 2 total=400 + let state = get_current_state(&circuit).unwrap(); + assert_eq!(state.changes.len(), 2, "Should have 2 product groups"); + + // Build a map of product_id -> (total, count) + let initial_results: HashMap = state + .changes + .iter() + .map(|(row, _)| { + // SUM might return Integer or Float, COUNT returns Integer + let product_id = match &row.values[0] { + Value::Integer(id) => *id, + _ => panic!("Product ID should be Integer, got {:?}", row.values[0]), + }; + + let total = match &row.values[1] { + Value::Integer(t) => *t, + Value::Float(t) => *t as i64, + _ => panic!("Total should be numeric, got {:?}", row.values[1]), + }; + + let count = match &row.values[2] { + Value::Integer(c) => *c, + _ => panic!("Count should be Integer, got {:?}", row.values[2]), + }; + + (product_id, (total, count)) + }) + .collect(); + + assert_eq!( + initial_results.get(&1).unwrap(), + &(300, 2), + "Product 1 should have 
total=300, count=2" + ); + assert_eq!( + initial_results.get(&2).unwrap(), + &(400, 2), + "Product 2 should have total=400, count=2" + ); + + // Create uncommitted changes: INSERT (1, 50), (3, 300) + let mut uncommitted = HashMap::new(); + let mut uncommitted_delta = Delta::new(); + uncommitted_delta.insert(5, vec![Value::Integer(1), Value::Integer(50)]); // Add to product 1 + uncommitted_delta.insert(6, vec![Value::Integer(3), Value::Integer(300)]); // New product 3 + uncommitted.insert("sales".to_string(), uncommitted_delta); + + // Execute with uncommitted data - simulating a read within transaction + let tx_result = circuit + .execute(HashMap::new(), delta_set_from_map(uncommitted.clone())) + .unwrap(); + + // Result should show the aggregate changes from uncommitted data + // Product 1: retraction of (300, 2) and insertion of (350, 3) + // Product 3: insertion of (300, 1) - new product + assert_eq!( + tx_result.changes.len(), + 3, + "Should see aggregate changes from uncommitted data" + ); + + // IMPORTANT: Verify operator's internal state is unchanged + let state_after = get_current_state(&circuit).unwrap(); + assert_eq!( + state_after.changes.len(), + 2, + "Internal state should still have 2 groups" + ); + + // Verify the internal state still has original values + let state_results: HashMap = state_after + .changes + .iter() + .map(|(row, _)| { + let product_id = match &row.values[0] { + Value::Integer(id) => *id, + _ => panic!("Product ID should be Integer"), + }; + + let total = match &row.values[1] { + Value::Integer(t) => *t, + Value::Float(t) => *t as i64, + _ => panic!("Total should be numeric"), + }; + + let count = match &row.values[2] { + Value::Integer(c) => *c, + _ => panic!("Count should be Integer"), + }; + + (product_id, (total, count)) + }) + .collect(); + + assert_eq!( + state_results.get(&1).unwrap(), + &(300, 2), + "Product 1 unchanged" + ); + assert_eq!( + state_results.get(&2).unwrap(), + &(400, 2), + "Product 2 unchanged" + ); + assert!( + !state_results.contains_key(&3), + "Product 3 should not be in committed state" + ); + + // Now actually commit the changes + let mut commit_data = HashMap::new(); + let mut commit_delta = Delta::new(); + commit_delta.insert(5, vec![Value::Integer(1), Value::Integer(50)]); + commit_delta.insert(6, vec![Value::Integer(3), Value::Integer(300)]); + commit_data.insert("sales".to_string(), commit_delta); + + let commit_result = circuit + .execute(commit_data.clone(), DeltaSet::empty()) + .unwrap(); + + // Should see changes for product 1 (updated) and product 3 (new) + assert_eq!( + commit_result.changes.len(), + 3, + "Should see 3 changes (delete old product 1, insert new product 1, insert product 3)" + ); + + // Actually commit the changes to update operator state + circuit.commit(commit_data).unwrap(); + + // After commit, verify final state + let final_state = get_current_state(&circuit).unwrap(); + assert_eq!( + final_state.changes.len(), + 3, + "Should have 3 product groups after commit" + ); + + let final_results: HashMap = final_state + .changes + .iter() + .map(|(row, _)| { + let product_id = match &row.values[0] { + Value::Integer(id) => *id, + _ => panic!("Product ID should be Integer"), + }; + + let total = match &row.values[1] { + Value::Integer(t) => *t, + Value::Float(t) => *t as i64, + _ => panic!("Total should be numeric"), + }; + + let count = match &row.values[2] { + Value::Integer(c) => *c, + _ => panic!("Count should be Integer"), + }; + + (product_id, (total, count)) + }) + .collect(); + + assert_eq!( + 
final_results.get(&1).unwrap(), + &(350, 3), + "Product 1 should have total=350, count=3" + ); + assert_eq!( + final_results.get(&2).unwrap(), + &(400, 2), + "Product 2 should have total=400, count=2" + ); + assert_eq!( + final_results.get(&3).unwrap(), + &(300, 1), + "Product 3 should have total=300, count=1" + ); + } + + #[test] + fn test_uncommitted_data_visible_in_transaction() { + // Test that uncommitted INSERTs are visible within the same transaction + // This simulates: BEGIN; INSERT ...; SELECT * FROM view; COMMIT; + + let mut circuit = compile_sql!("SELECT * FROM users WHERE age > 18"); + + // Initialize with some data - need to match the schema (id, name, age) + let mut init_data = HashMap::new(); + let mut delta = Delta::new(); + delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".into()), + Value::Integer(25), + ], + ); + delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Bob".into()), + Value::Integer(30), + ], + ); + init_data.insert("users".to_string(), delta); + + circuit.initialize(init_data.clone()).unwrap(); + + // Verify initial state + let state = get_current_state(&circuit).unwrap(); + assert_eq!( + state.len(), + 2, + "Should have 2 users initially (both pass age > 18 filter)" + ); + + // Simulate a transaction: INSERT new users that pass the filter - match schema (id, name, age) + let mut uncommitted = HashMap::new(); + let mut tx_delta = Delta::new(); + tx_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Text("Charlie".into()), + Value::Integer(35), + ], + ); + tx_delta.insert( + 4, + vec![ + Value::Integer(4), + Value::Text("David".into()), + Value::Integer(20), + ], + ); + uncommitted.insert("users".to_string(), tx_delta); + + // Execute with uncommitted data - this should return the uncommitted changes + // that passed through the filter (age > 18) + let tx_result = circuit + .execute(HashMap::new(), delta_set_from_map(uncommitted.clone())) + .unwrap(); + + // IMPORTANT: tx_result should contain the filtered uncommitted changes! 
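+        // execute() routes the uncommitted delta through the operators' eval() path,
+        // which folds it into the result without mutating their committed state.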
+ // Both Charlie (35) and David (20) should pass the age > 18 filter + assert_eq!( + tx_result.len(), + 2, + "Should see 2 uncommitted rows that pass filter" + ); + + // Verify the uncommitted results contain the expected rows + let has_charlie = tx_result.changes.iter().any(|(row, _)| row.rowid == 3); + assert!( + has_charlie, + "Should find Charlie (rowid=3) in uncommitted results" + ); + + let has_david = tx_result.changes.iter().any(|(row, _)| row.rowid == 4); + assert!( + has_david, + "Should find David (rowid=4) in uncommitted results" + ); + + // CRITICAL: Verify the operator state wasn't modified by uncommitted execution + let state_after_uncommitted = get_current_state(&circuit).unwrap(); + assert_eq!( + state_after_uncommitted.len(), + 2, + "State should STILL be 2 after uncommitted execution - only Alice and Bob" + ); + + // The state should not contain Charlie or David + let has_charlie_in_state = state_after_uncommitted + .changes + .iter() + .any(|(row, _)| row.rowid == 3); + let has_david_in_state = state_after_uncommitted + .changes + .iter() + .any(|(row, _)| row.rowid == 4); + assert!( + !has_charlie_in_state, + "Charlie should NOT be in operator state (uncommitted)" + ); + assert!( + !has_david_in_state, + "David should NOT be in operator state (uncommitted)" + ); + } + + #[test] + fn test_uncommitted_aggregation_with_rollback() { + // Test that rollback properly discards uncommitted aggregation changes + // Similar to test_uncommitted_aggregation but explicitly tests rollback semantics + + // Create a simple aggregation circuit + let mut circuit = compile_sql!("SELECT age, COUNT(*) as cnt FROM users GROUP BY age"); + + // Initialize with some data + let mut init_data = HashMap::new(); + let mut delta = Delta::new(); + delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".into()), + Value::Integer(25), + ], + ); + delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Bob".into()), + Value::Integer(30), + ], + ); + delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Text("Charlie".into()), + Value::Integer(25), + ], + ); + delta.insert( + 4, + vec![ + Value::Integer(4), + Value::Text("David".into()), + Value::Integer(30), + ], + ); + init_data.insert("users".to_string(), delta); + + circuit.initialize(init_data).unwrap(); + + // Verify initial state: age 25 count=2, age 30 count=2 + let state = get_current_state(&circuit).unwrap(); + assert_eq!(state.changes.len(), 2); + + let initial_counts: HashMap = state + .changes + .iter() + .map(|(row, _)| { + if let (Value::Integer(age), Value::Integer(count)) = + (&row.values[0], &row.values[1]) + { + (*age, *count) + } else { + panic!("Unexpected value types"); + } + }) + .collect(); + + assert_eq!(initial_counts.get(&25).unwrap(), &2); + assert_eq!(initial_counts.get(&30).unwrap(), &2); + + // Create uncommitted changes that would affect aggregations + let mut uncommitted = HashMap::new(); + let mut uncommitted_delta = Delta::new(); + // Add more people aged 25 + uncommitted_delta.insert( + 5, + vec![ + Value::Integer(5), + Value::Text("Eve".into()), + Value::Integer(25), + ], + ); + uncommitted_delta.insert( + 6, + vec![ + Value::Integer(6), + Value::Text("Frank".into()), + Value::Integer(25), + ], + ); + // Add person aged 35 (new group) + uncommitted_delta.insert( + 7, + vec![ + Value::Integer(7), + Value::Text("Grace".into()), + Value::Integer(35), + ], + ); + // Delete Bob (age 30) + uncommitted_delta.delete( + 2, + vec![ + Value::Integer(2), + Value::Text("Bob".into()), + Value::Integer(30), 
+            ],
+        );
+        uncommitted.insert("users".to_string(), uncommitted_delta);
+
+        // Execute with uncommitted changes
+        let tx_result = circuit
+            .execute(HashMap::new(), delta_set_from_map(uncommitted.clone()))
+            .unwrap();
+
+        // Should see the aggregate changes from uncommitted data:
+        // Age 25: retraction of count 2, insertion of count 4 (Eve and Frank added)
+        // Age 30: retraction of count 2, insertion of count 1 (Bob deleted)
+        // Age 35: insertion of count 1 (Grace creates a new group)
+        assert!(
+            !tx_result.changes.is_empty(),
+            "Should see aggregate changes from uncommitted data"
+        );
+
+        // Verify internal state is unchanged (simulating rollback by not committing)
+        let state_after_rollback = get_current_state(&circuit).unwrap();
+        assert_eq!(
+            state_after_rollback.changes.len(),
+            2,
+            "Should still have 2 age groups"
+        );
+
+        let rollback_counts: HashMap<i64, i64> = state_after_rollback
+            .changes
+            .iter()
+            .map(|(row, _)| {
+                if let (Value::Integer(age), Value::Integer(count)) =
+                    (&row.values[0], &row.values[1])
+                {
+                    (*age, *count)
+                } else {
+                    panic!("Unexpected value types");
+                }
+            })
+            .collect();
+
+        // Verify counts are unchanged after rollback
+        assert_eq!(
+            rollback_counts.get(&25).unwrap(),
+            &2,
+            "Age 25 count unchanged"
+        );
+        assert_eq!(
+            rollback_counts.get(&30).unwrap(),
+            &2,
+            "Age 30 count unchanged"
+        );
+        assert!(
+            !rollback_counts.contains_key(&35),
+            "Age 35 should not exist"
+        );
+    }
+
+    #[test]
+    fn test_circuit_rowid_update_consolidation() {
+        // Test that circuit properly consolidates state when rowid changes
+        let mut circuit = DbspCircuit::new();
+
+        // Create a simple filter node
+        let schema = Arc::new(LogicalSchema::new(vec![
+            ("id".to_string(), Type::Integer),
+            ("value".to_string(), Type::Integer),
+        ]));
+
+        // First create an input node
+        let input_id = circuit.add_node(
+            DbspOperator::Input {
+                name: "test".to_string(),
+                schema: schema.clone(),
+            },
+            vec![],
+            None, // Input nodes don't have executables
+        );
+
+        let filter_op = FilterOperator::new(
+            FilterPredicate::GreaterThan {
+                column: "value".to_string(),
+                value: Value::Integer(10),
+            },
+            vec!["id".to_string(), "value".to_string()],
+        );
+
+        // Create the filter predicate using DbspExpr
+        let predicate = DbspExpr::BinaryExpr {
+            left: Box::new(DbspExpr::Column("value".to_string())),
+            op: ast::Operator::Greater,
+            right: Box::new(DbspExpr::Literal(Value::Integer(10))),
+        };
+
+        let filter_id = circuit.add_node(
+            DbspOperator::Filter { predicate },
+            vec![input_id], // Filter takes input from the input node
+            Some(Box::new(filter_op)),
+        );
+
+        circuit.root = Some(filter_id);
+
+        // Initialize with a row
+        let mut init_data = HashMap::new();
+        let mut delta = Delta::new();
+        delta.insert(5, vec![Value::Integer(5), Value::Integer(20)]);
+        init_data.insert("test".to_string(), delta);
+
+        circuit.initialize(init_data).unwrap();
+
+        // Verify initial state
+        let state = get_current_state(&circuit).unwrap();
+        assert_eq!(state.changes.len(), 1);
+        assert_eq!(state.changes[0].0.rowid, 5);
+
+        // Now update the rowid from 5 to 3
+        let mut update_data = HashMap::new();
+        let mut update_delta = Delta::new();
+        update_delta.delete(5, vec![Value::Integer(5), Value::Integer(20)]);
+        update_delta.insert(3, vec![Value::Integer(3), Value::Integer(20)]);
+        update_data.insert("test".to_string(), update_delta);
+
+        circuit
+            .execute(update_data.clone(), DeltaSet::empty())
+            .unwrap();
+
+        // Commit the changes to update operator state
+        circuit.commit(update_data).unwrap();
+
+        // The circuit should consolidate the state properly
+        let final_state = get_current_state(&circuit).unwrap();
+        assert_eq!(
+            final_state.changes.len(),
+            1,
+            "Circuit should consolidate to single row"
+        );
+        assert_eq!(final_state.changes[0].0.rowid, 3);
+        assert_eq!(
+            final_state.changes[0].0.values,
+            vec![Value::Integer(3), Value::Integer(20)]
+        );
+        assert_eq!(final_state.changes[0].1, 1);
+    }
+}
diff --git a/core/incremental/mod.rs b/core/incremental/mod.rs
index 328f1a510..4c26b91ba 100644
--- a/core/incremental/mod.rs
+++ b/core/incremental/mod.rs
@@ -1,3 +1,4 @@
+pub mod compiler;
 pub mod dbsp;
 pub mod expr_compiler;
 pub mod hashable_row;
diff --git a/core/incremental/operator.rs b/core/incremental/operator.rs
index 0391e3c0a..85f7e640c 100644
--- a/core/incremental/operator.rs
+++ b/core/incremental/operator.rs
@@ -359,13 +359,12 @@ pub enum JoinType {
     Right,
 }
 
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, PartialEq)]
 pub enum AggregateFunction {
     Count,
     Sum(String),
     Avg(String),
-    Min(String),
-    Max(String),
+    // MIN and MAX are not supported - see comment in compiler.rs for explanation
 }
 
 impl Display for AggregateFunction {
@@ -374,8 +373,6 @@ impl Display for AggregateFunction {
             AggregateFunction::Count => write!(f, "COUNT(*)"),
             AggregateFunction::Sum(col) => write!(f, "SUM({col})"),
             AggregateFunction::Avg(col) => write!(f, "AVG({col})"),
-            AggregateFunction::Min(col) => write!(f, "MIN({col})"),
-            AggregateFunction::Max(col) => write!(f, "MAX({col})"),
         }
     }
 }
@@ -401,8 +398,8 @@ impl AggregateFunction {
             AggFunc::Count | AggFunc::Count0 => Some(AggregateFunction::Count),
             AggFunc::Sum => input_column.map(AggregateFunction::Sum),
             AggFunc::Avg => input_column.map(AggregateFunction::Avg),
-            AggFunc::Min => input_column.map(AggregateFunction::Min),
-            AggFunc::Max => input_column.map(AggregateFunction::Max),
+            // MIN and MAX are not supported in incremental views - see compiler.rs
+            AggFunc::Min | AggFunc::Max => None,
             _ => None, // Other aggregate functions not yet supported in DBSP
         }
     }
@@ -417,8 +414,18 @@ pub trait IncrementalOperator: Debug {
     /// Initialize with base data
     fn initialize(&mut self, data: Delta);
 
-    /// Process a delta (incremental update)
-    fn process_delta(&mut self, delta: Delta) -> Delta;
+    /// Evaluate the operator with a delta, without modifying internal state.
+    /// This is used during query execution to compute results including uncommitted changes.
+    ///
+    /// # Arguments
+    /// * `delta` - The committed delta to process
+    /// * `uncommitted` - Optional uncommitted changes from the current transaction
+    fn eval(&self, delta: Delta, uncommitted: Option<Delta>) -> Delta;
+
+    /// Commit a delta to the operator's internal state and return the output.
+    /// This is called when a transaction commits, making changes permanent.
+    /// Returns the output delta (what downstream operators should see).
+    fn commit(&mut self, delta: Delta) -> Delta;
 
     /// Get current accumulated state
    fn get_current_state(&self) -> Delta;
@@ -554,20 +561,51 @@ impl IncrementalOperator for FilterOperator {
         }
     }
 
-    fn process_delta(&mut self, delta: Delta) -> Delta {
+    fn eval(&self, delta: Delta, uncommitted: Option<Delta>) -> Delta {
         let mut output_delta = Delta::new();
 
-        // Process only the delta, not the entire state
+        // Merge delta with uncommitted if present
+        let combined_delta = if let Some(uncommitted) = uncommitted {
+            let mut combined = delta;
+            combined.merge(&uncommitted);
+            combined
+        } else {
+            delta
+        };
+
+        // Process the combined delta through the filter
+        for (row, weight) in combined_delta.changes {
+            if let Some(tracker) = &self.tracker {
+                tracker.lock().unwrap().record_filter();
+            }
+
+            // Only
pass through rows that satisfy the filter predicate + // For deletes (weight < 0), we only pass them if the row values + // would have passed the filter (meaning it was in the view) + if self.evaluate_predicate(&row.values) { + output_delta.changes.push((row, weight)); + } + } + + output_delta + } + + fn commit(&mut self, delta: Delta) -> Delta { + let mut output_delta = Delta::new(); + + // Commit the delta to our internal state + // Only pass through and track rows that satisfy the filter predicate for (row, weight) in delta.changes { if let Some(tracker) = &self.tracker { tracker.lock().unwrap().record_filter(); } + // Only track and output rows that pass the filter + // For deletes, this means the row was in the view (its values pass the filter) + // For inserts, this means the row should be in the view if self.evaluate_predicate(&row.values) { - output_delta.changes.push((row.clone(), weight)); - - // Update our state - self.current_state.changes.push((row, weight)); + self.current_state.changes.push((row.clone(), weight)); + output_delta.changes.push((row, weight)); } } @@ -575,7 +613,10 @@ impl IncrementalOperator for FilterOperator { } fn get_current_state(&self) -> Delta { - self.current_state.clone() + // Return a consolidated view of the current state + let mut consolidated = self.current_state.clone(); + consolidated.consolidate(); + consolidated } fn set_tracker(&mut self, tracker: Arc>) { @@ -733,6 +774,46 @@ impl ProjectOperator { }) } + /// Create a ProjectOperator from pre-compiled expressions + pub fn from_compiled( + compiled_exprs: Vec, + aliases: Vec>, + input_column_names: Vec, + output_column_names: Vec, + ) -> crate::Result { + // Set up internal connection for expression evaluation + let io = Arc::new(crate::MemoryIO::new()); + let db = Database::open_file( + io, ":memory:", false, // no MVCC needed for expression evaluation + false, // no indexes needed + )?; + let internal_conn = db.connect()?; + // Set to read-only mode and disable auto-commit since we're only evaluating expressions + internal_conn.query_only.set(true); + internal_conn.auto_commit.set(false); + + // Create ProjectColumn structs from compiled expressions + let columns: Vec = compiled_exprs + .into_iter() + .zip(aliases) + .map(|(compiled, alias)| ProjectColumn { + // Create a placeholder AST expression since we already have the compiled version + expr: turso_parser::ast::Expr::Literal(turso_parser::ast::Literal::Null), + alias, + compiled, + }) + .collect(); + + Ok(Self { + columns, + input_column_names, + output_column_names, + current_state: Delta::new(), + tracker: None, + internal_conn, + }) + } + /// Get the columns for this projection pub fn columns(&self) -> &[ProjectColumn] { &self.columns @@ -899,26 +980,55 @@ impl IncrementalOperator for ProjectOperator { } } - fn process_delta(&mut self, delta: Delta) -> Delta { + fn eval(&self, delta: Delta, uncommitted: Option) -> Delta { let mut output_delta = Delta::new(); - for (row, weight) in &delta.changes { + // Merge delta with uncommitted if present + let combined_delta = if let Some(uncommitted) = uncommitted { + let mut combined = delta; + combined.merge(&uncommitted); + combined + } else { + delta + }; + + for (row, weight) in &combined_delta.changes { if let Some(tracker) = &self.tracker { tracker.lock().unwrap().record_project(); } let projected = self.project_values(&row.values); let projected_row = HashableRow::new(row.rowid, projected); + output_delta.changes.push((projected_row, *weight)); + } - 
output_delta.changes.push((projected_row.clone(), *weight)); - self.current_state.changes.push((projected_row, *weight)); + output_delta + } + + fn commit(&mut self, delta: Delta) -> Delta { + let mut output_delta = Delta::new(); + + // Commit the delta to our internal state and build output + for (row, weight) in &delta.changes { + if let Some(tracker) = &self.tracker { + tracker.lock().unwrap().record_project(); + } + let projected = self.project_values(&row.values); + let projected_row = HashableRow::new(row.rowid, projected); + self.current_state + .changes + .push((projected_row.clone(), *weight)); + output_delta.changes.push((projected_row, *weight)); } output_delta } fn get_current_state(&self) -> Delta { - self.current_state.clone() + // Return a consolidated view of the current state + let mut consolidated = self.current_state.clone(); + consolidated.consolidate(); + consolidated } fn set_tracker(&mut self, tracker: Arc>) { @@ -926,325 +1036,6 @@ impl IncrementalOperator for ProjectOperator { } } -/// Join operator - performs incremental joins using DBSP formula -/// ∂(A ⋈ B) = A ⋈ ∂B + ∂A ⋈ B + ∂A ⋈ ∂B -#[derive(Debug)] -pub struct JoinOperator { - join_type: JoinType, - pub left_on_column: String, - pub right_on_column: String, - left_column_names: Vec, - right_column_names: Vec, - // Current accumulated state for both sides - left_state: Delta, - right_state: Delta, - // Index for efficient lookups: column_value_as_string -> vec of row_keys - // We use String representation of values since Value doesn't implement Hash - left_index: HashMap>, - right_index: HashMap>, - // Result state - current_state: Delta, - tracker: Option>>, - // For generating unique keys for join results - next_result_key: i64, -} - -impl JoinOperator { - pub fn new( - join_type: JoinType, - left_on_column: String, - right_on_column: String, - left_column_names: Vec, - right_column_names: Vec, - ) -> Self { - Self { - join_type, - left_on_column, - right_on_column, - left_column_names, - right_column_names, - left_state: Delta::new(), - right_state: Delta::new(), - left_index: HashMap::new(), - right_index: HashMap::new(), - current_state: Delta::new(), - tracker: None, - next_result_key: 0, - } - } - - pub fn set_tracker(&mut self, tracker: Arc>) { - self.tracker = Some(tracker); - } - - /// Build index for a side of the join - fn build_index( - state: &Delta, - column_names: &[String], - on_column: &str, - ) -> HashMap> { - let mut index = HashMap::new(); - - // Find the column index - let col_idx = column_names.iter().position(|c| c == on_column); - if col_idx.is_none() { - return index; - } - let col_idx = col_idx.unwrap(); - - // Build the index - for (row, weight) in &state.changes { - // Include rows with positive weight in the index - if *weight > 0 { - if let Some(key_value) = row.values.get(col_idx) { - // Convert value to string for indexing - let key_str = format!("{key_value:?}"); - index - .entry(key_str) - .or_insert_with(Vec::new) - .push(row.rowid); - } - } - } - - index - } - - /// Join two deltas - fn join_deltas(&self, left_delta: &Delta, right_delta: &Delta, next_key: &mut i64) -> Delta { - let mut result = Delta::new(); - - // Find column indices - let left_col_idx = self - .left_column_names - .iter() - .position(|c| c == &self.left_on_column) - .unwrap_or(0); - let right_col_idx = self - .right_column_names - .iter() - .position(|c| c == &self.right_on_column) - .unwrap_or(0); - - // For each row in left_delta - for (left_row, left_weight) in &left_delta.changes { - // Process both 
inserts and deletes - - let left_join_value = left_row.values.get(left_col_idx); - if left_join_value.is_none() { - continue; - } - let left_join_value = left_join_value.unwrap(); - - // Look up matching rows in right_delta - for (right_row, right_weight) in &right_delta.changes { - // Process both inserts and deletes - - let right_join_value = right_row.values.get(right_col_idx); - if right_join_value.is_none() { - continue; - } - let right_join_value = right_join_value.unwrap(); - - // Check if values match - if left_join_value == right_join_value { - // Record the join lookup - if let Some(tracker) = &self.tracker { - tracker.lock().unwrap().record_join_lookup(); - } - - // Create joined row - let mut joined_values = left_row.values.clone(); - joined_values.extend(right_row.values.clone()); - - // Generate a unique key for the result - let result_key = *next_key; - *next_key += 1; - - let joined_row = HashableRow::new(result_key, joined_values); - result - .changes - .push((joined_row, left_weight * right_weight)); - } - } - } - - result - } - - /// Join a delta with the full state using the index - fn join_delta_with_state( - &self, - delta: &Delta, - state: &Delta, - delta_on_left: bool, - next_key: &mut i64, - ) -> Delta { - let mut result = Delta::new(); - - let (delta_col_idx, state_col_names) = if delta_on_left { - ( - self.left_column_names - .iter() - .position(|c| c == &self.left_on_column) - .unwrap_or(0), - &self.right_column_names, - ) - } else { - ( - self.right_column_names - .iter() - .position(|c| c == &self.right_on_column) - .unwrap_or(0), - &self.left_column_names, - ) - }; - - // Use index for efficient lookup - let state_index = Self::build_index( - state, - state_col_names, - if delta_on_left { - &self.right_on_column - } else { - &self.left_on_column - }, - ); - - for (delta_row, delta_weight) in &delta.changes { - // Process both inserts and deletes - - let delta_join_value = delta_row.values.get(delta_col_idx); - if delta_join_value.is_none() { - continue; - } - let delta_join_value = delta_join_value.unwrap(); - - // Use index to find matching rows - let delta_key_str = format!("{delta_join_value:?}"); - if let Some(matching_keys) = state_index.get(&delta_key_str) { - for state_key in matching_keys { - // Look up in the state - find the row with this rowid - let state_row_opt = state - .changes - .iter() - .find(|(row, weight)| row.rowid == *state_key && *weight > 0); - - if let Some((state_row, state_weight)) = state_row_opt { - // Record the join lookup - if let Some(tracker) = &self.tracker { - tracker.lock().unwrap().record_join_lookup(); - } - - // Create joined row - let joined_values = if delta_on_left { - let mut v = delta_row.values.clone(); - v.extend(state_row.values.clone()); - v - } else { - let mut v = state_row.values.clone(); - v.extend(delta_row.values.clone()); - v - }; - - let result_key = *next_key; - *next_key += 1; - - let joined_row = HashableRow::new(result_key, joined_values); - result - .changes - .push((joined_row, delta_weight * state_weight)); - } - } - } - } - - result - } - - /// Initialize both sides of the join - pub fn initialize_both(&mut self, left_data: Delta, right_data: Delta) { - self.left_state = left_data.clone(); - self.right_state = right_data.clone(); - - // Build indices - self.left_index = Self::build_index( - &self.left_state, - &self.left_column_names, - &self.left_on_column, - ); - self.right_index = Self::build_index( - &self.right_state, - &self.right_column_names, - &self.right_on_column, - ); - - // Perform 
initial join
-        let mut next_key = self.next_result_key;
-        self.current_state = self.join_deltas(&self.left_state, &self.right_state, &mut next_key);
-        self.next_result_key = next_key;
-    }
-
-    /// Process deltas for both sides using DBSP formula
-    /// ∂(A ⋈ B) = A ⋈ ∂B + ∂A ⋈ B + ∂A ⋈ ∂B
-    pub fn process_both_deltas(&mut self, left_delta: Delta, right_delta: Delta) -> Delta {
-        let mut result = Delta::new();
-        let mut next_key = self.next_result_key;
-
-        // A ⋈ ∂B (existing left with new right)
-        let a_join_db =
-            self.join_delta_with_state(&right_delta, &self.left_state, false, &mut next_key);
-        result.merge(&a_join_db);
-
-        // ∂A ⋈ B (new left with existing right)
-        let da_join_b =
-            self.join_delta_with_state(&left_delta, &self.right_state, true, &mut next_key);
-        result.merge(&da_join_b);
-
-        // ∂A ⋈ ∂B (new left with new right)
-        let da_join_db = self.join_deltas(&left_delta, &right_delta, &mut next_key);
-        result.merge(&da_join_db);
-
-        // Update the next key counter
-        self.next_result_key = next_key;
-
-        // Update state
-        self.left_state.merge(&left_delta);
-        self.right_state.merge(&right_delta);
-        self.current_state.merge(&result);
-
-        // Rebuild indices if needed
-        self.left_index = Self::build_index(
-            &self.left_state,
-            &self.left_column_names,
-            &self.left_on_column,
-        );
-        self.right_index = Self::build_index(
-            &self.right_state,
-            &self.right_column_names,
-            &self.right_on_column,
-        );
-
-        result
-    }
-
-    pub fn get_current_state(&self) -> &Delta {
-        &self.current_state
-    }
-
-    /// Process a delta from the left table only
-    pub fn process_left_delta(&mut self, left_delta: Delta) -> Delta {
-        let empty_delta = Delta::new();
-        self.process_both_deltas(left_delta, empty_delta)
-    }
-
-    /// Process a delta from the right table only
-    pub fn process_right_delta(&mut self, right_delta: Delta) -> Delta {
-        let empty_delta = Delta::new();
-        self.process_both_deltas(empty_delta, right_delta)
-    }
-}
-
 /// Aggregate operator - performs incremental aggregation with GROUP BY
 /// Maintains running totals/counts that are updated incrementally
 #[derive(Debug, Clone)]
@@ -1275,10 +1066,8 @@ struct AggregateState {
     sums: HashMap<String, f64>,
     // For AVG: column_name -> (sum, count) for computing average
     avgs: HashMap<String, (f64, i64)>,
-    // For MIN: column_name -> min value
-    mins: HashMap<String, Value>,
-    // For MAX: column_name -> max value
-    maxs: HashMap<String, Value>,
+    // MIN/MAX are not supported - they require O(n) storage overhead for handling
+    // deletions correctly. See the comment in compiler.rs for details.
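+    // Example: for MAX over the values {3, 7} a scalar running maximum stores 7;
+    // when the row holding 7 is deleted (weight -1), the new maximum (3) cannot be
+    // recovered without having kept every input value of the group.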
} impl AggregateState { @@ -1287,8 +1076,6 @@ impl AggregateState { count: 0, sums: HashMap::new(), avgs: HashMap::new(), - mins: HashMap::new(), - maxs: HashMap::new(), } } @@ -1337,43 +1124,6 @@ impl AggregateState { } } } - AggregateFunction::Min(col_name) => { - // MIN/MAX are more complex for incremental updates - // For now, we'll need to recompute from the full state - // This is a limitation we can improve later - if weight > 0 { - // Only update on insert - if let Some(idx) = column_names.iter().position(|c| c == col_name) { - if let Some(val) = values.get(idx) { - self.mins - .entry(col_name.clone()) - .and_modify(|existing| { - if val < existing { - *existing = val.clone(); - } - }) - .or_insert_with(|| val.clone()); - } - } - } - } - AggregateFunction::Max(col_name) => { - if weight > 0 { - // Only update on insert - if let Some(idx) = column_names.iter().position(|c| c == col_name) { - if let Some(val) = values.get(idx) { - self.maxs - .entry(col_name.clone()) - .and_modify(|existing| { - if val > existing { - *existing = val.clone(); - } - }) - .or_insert_with(|| val.clone()); - } - } - } - } } } } @@ -1407,12 +1157,6 @@ impl AggregateState { result.push(Value::Null); } } - AggregateFunction::Min(col_name) => { - result.push(self.mins.get(col_name).cloned().unwrap_or(Value::Null)); - } - AggregateFunction::Max(col_name) => { - result.push(self.maxs.get(col_name).cloned().unwrap_or(Value::Null)); - } } } @@ -1472,8 +1216,9 @@ impl AggregateOperator { pub fn process_delta(&mut self, delta: Delta) -> Delta { let mut output_delta = Delta::new(); - // Track which groups were modified + // Track which groups were modified and their old values let mut modified_groups = HashSet::new(); + let mut old_values: HashMap> = HashMap::new(); // Process each change in the delta for (row, weight) in &delta.changes { @@ -1484,6 +1229,17 @@ impl AggregateOperator { // Extract group key let group_key = self.extract_group_key(&row.values); let group_key_str = Self::group_key_to_string(&group_key); + + // Store old aggregate values BEFORE applying the delta + // (only for the first time we see this group in this batch) + if !modified_groups.contains(&group_key_str) { + if let Some(state) = self.group_states.get(&group_key_str) { + let mut old_row = group_key.clone(); + old_row.extend(state.to_values(&self.aggregates)); + old_values.insert(group_key_str.clone(), old_row); + } + } + modified_groups.insert(group_key_str.clone()); // Store the actual group key values @@ -1514,17 +1270,25 @@ impl AggregateOperator { .cloned() .unwrap_or_default(); + // Generate a unique key for this group + // We use a hash of the group key to ensure consistency + let result_key = group_key_str + .bytes() + .fold(0i64, |acc, b| acc.wrapping_mul(31).wrapping_add(b as i64)); + + // Emit retraction for old value if it existed + if let Some(old_row_values) = old_values.get(&group_key_str) { + let old_row = HashableRow::new(result_key, old_row_values.clone()); + output_delta.changes.push((old_row.clone(), -1)); + // Also remove from current state + self.current_state.changes.push((old_row, -1)); + } + if let Some(state) = self.group_states.get(&group_key_str) { // Build output row: group_by columns + aggregate values let mut output_values = group_key.clone(); output_values.extend(state.to_values(&self.aggregates)); - // Generate a unique key for this group - // We use a hash of the group key to ensure consistency - let result_key = group_key_str - .bytes() - .fold(0i64, |acc, b| acc.wrapping_mul(31).wrapping_add(b as i64)); 
- // Check if group should be removed (count is 0) if state.count > 0 { // Add to output delta with positive weight @@ -1534,12 +1298,8 @@ impl AggregateOperator { // Update current state self.current_state.changes.push((output_row, 1)); } else { - // Add to output delta with negative weight (deletion) - let output_row = HashableRow::new(result_key, output_values); - output_delta.changes.push((output_row.clone(), -1)); - - // Mark for removal in current state - self.current_state.changes.push((output_row, -1)); + // Group has count=0, remove from state + // (we already emitted the retraction above if needed) self.group_states.remove(&group_key_str); self.group_key_values.remove(&group_key_str); } @@ -1559,16 +1319,108 @@ impl AggregateOperator { impl IncrementalOperator for AggregateOperator { fn initialize(&mut self, data: Delta) { - // Process all initial data - self.process_delta(data); + // Process all initial data - this modifies state during initialization + let _ = self.process_delta(data); } - fn process_delta(&mut self, delta: Delta) -> Delta { + fn eval(&self, delta: Delta, uncommitted: Option) -> Delta { + // Clone the current state to work with temporarily + let mut temp_group_states = self.group_states.clone(); + let mut temp_group_key_values = self.group_key_values.clone(); + + // Merge delta with uncommitted if present + let combined_delta = if let Some(uncommitted) = uncommitted { + let mut combined = delta; + combined.merge(&uncommitted); + combined + } else { + delta + }; + + let mut output_delta = Delta::new(); + let mut modified_groups = HashSet::new(); + let mut old_values: HashMap> = HashMap::new(); + + // Process each change in the combined delta using temporary state + for (row, weight) in &combined_delta.changes { + if let Some(tracker) = &self.tracker { + tracker.lock().unwrap().record_aggregation(); + } + + // Extract group key + let group_key = self.extract_group_key(&row.values); + let group_key_str = Self::group_key_to_string(&group_key); + + // Store old aggregate values BEFORE applying the delta + if !modified_groups.contains(&group_key_str) { + if let Some(state) = temp_group_states.get(&group_key_str) { + let mut old_row = group_key.clone(); + old_row.extend(state.to_values(&self.aggregates)); + old_values.insert(group_key_str.clone(), old_row); + } + } + + modified_groups.insert(group_key_str.clone()); + temp_group_key_values.insert(group_key_str.clone(), group_key.clone()); + + // Get or create aggregate state for this group in temporary state + let state = temp_group_states + .entry(group_key_str.clone()) + .or_insert_with(AggregateState::new); + + // Apply the delta to the temporary aggregate state + state.apply_delta( + &row.values, + *weight, + &self.aggregates, + &self.input_column_names, + ); + } + + // Generate output delta for modified groups using temporary state + for group_key_str in modified_groups { + let group_key = temp_group_key_values + .get(&group_key_str) + .cloned() + .unwrap_or_default(); + + // Generate a unique key for this group + let result_key = group_key_str + .bytes() + .fold(0i64, |acc, b| acc.wrapping_mul(31).wrapping_add(b as i64)); + + // Emit retraction for old value if it existed + if let Some(old_row_values) = old_values.get(&group_key_str) { + let old_row = HashableRow::new(result_key, old_row_values.clone()); + output_delta.changes.push((old_row, -1)); + } + + if let Some(state) = temp_group_states.get(&group_key_str) { + // Build output row: group_by columns + aggregate values + let mut output_values = 
group_key.clone(); + output_values.extend(state.to_values(&self.aggregates)); + + // Check if group should be included (count > 0) + if state.count > 0 { + let output_row = HashableRow::new(result_key, output_values); + output_delta.changes.push((output_row, 1)); + } + } + } + + output_delta + } + + fn commit(&mut self, delta: Delta) -> Delta { + // Actually update the internal state when committing and return the output self.process_delta(delta) } fn get_current_state(&self) -> Delta { - self.current_state.clone() + // Return a consolidated view of the current state + let mut consolidated = self.current_state.clone(); + consolidated.consolidate(); + consolidated } fn set_tracker(&mut self, tracker: Arc>) { @@ -1603,176 +1455,245 @@ mod tests { ); } - // Join tests + // Aggregate tests #[test] - fn test_join_uses_delta_formula() { - let tracker = Arc::new(Mutex::new(ComputationTracker::new())); + fn test_aggregate_incremental_update_emits_retraction() { + // This test verifies that when an aggregate value changes, + // the operator emits both a retraction (-1) of the old value + // and an insertion (+1) of the new value. - // Create join operator - let mut join = JoinOperator::new( - JoinType::Inner, - "user_id".to_string(), - "user_id".to_string(), - vec!["user_id".to_string(), "email".to_string()], - vec![ - "login_id".to_string(), - "user_id".to_string(), - "timestamp".to_string(), - ], + // Create an aggregate operator for SUM(age) with no GROUP BY + let mut agg = AggregateOperator::new( + vec![], // No GROUP BY + vec![AggregateFunction::Sum("age".to_string())], + vec!["id".to_string(), "name".to_string(), "age".to_string()], ); - join.set_tracker(tracker.clone()); - // Initial data: emails table - let mut emails = Delta::new(); - emails.insert( + // Initial data: 3 users + let mut initial_delta = Delta::new(); + initial_delta.insert( 1, vec![ Value::Integer(1), - Value::Text(Text::new("alice@example.com")), + Value::Text("Alice".to_string().into()), + Value::Integer(25), ], ); - emails.insert( + initial_delta.insert( 2, - vec![Value::Integer(2), Value::Text(Text::new("bob@example.com"))], + vec![ + Value::Integer(2), + Value::Text("Bob".to_string().into()), + Value::Integer(30), + ], ); - - // Initial data: logins table - let mut logins = Delta::new(); - logins.insert( - 1, - vec![Value::Integer(1), Value::Integer(1), Value::Integer(1000)], - ); - logins.insert( - 2, - vec![Value::Integer(2), Value::Integer(1), Value::Integer(2000)], - ); - - // Initialize join - join.initialize_both(emails.clone(), logins.clone()); - - // Reset tracker for delta processing - tracker.lock().unwrap().join_lookups = 0; - - // Add one login for bob (user_id=2) - let mut delta_logins = Delta::new(); - delta_logins.insert( + initial_delta.insert( 3, - vec![Value::Integer(3), Value::Integer(2), Value::Integer(3000)], + vec![ + Value::Integer(3), + Value::Text("Charlie".to_string().into()), + Value::Integer(35), + ], ); - // Process delta - should use incremental formula - let empty_delta = Delta::new(); - let output = join.process_both_deltas(empty_delta, delta_logins); + // Initialize with initial data + agg.initialize(initial_delta); - // Should have one join result (bob's new login) - assert_eq!(output.len(), 1); + // Verify initial state: SUM(age) = 25 + 30 + 35 = 90 + let state = agg.get_current_state(); + assert_eq!(state.changes.len(), 1, "Should have one aggregate row"); + let (row, weight) = &state.changes[0]; + assert_eq!(*weight, 1, "Aggregate row should have weight 1"); + 
assert_eq!(row.values[0], Value::Float(90.0), "SUM should be 90"); - // Verify we used index lookups, not nested loops - // Should have done 1 lookup (finding bob's email for the new login) - let lookups = tracker.lock().unwrap().join_lookups; - assert_eq!(lookups, 1, "Should use index lookup, not scan all emails"); + // Now add a new user (incremental update) + let mut update_delta = Delta::new(); + update_delta.insert( + 4, + vec![ + Value::Integer(4), + Value::Text("David".to_string().into()), + Value::Integer(40), + ], + ); - // Verify incremental behavior - we processed only the delta - let t = tracker.lock().unwrap(); - assert_incremental(&t, 1, 3); // 1 operation for 3 total rows + // Process the incremental update + let output_delta = agg.eval(update_delta.clone(), None); + agg.commit(update_delta); + + // CRITICAL: The output delta should contain TWO changes: + // 1. Retraction of old aggregate value (90) with weight -1 + // 2. Insertion of new aggregate value (130) with weight +1 + assert_eq!( + output_delta.changes.len(), + 2, + "Expected 2 changes (retraction + insertion), got {}: {:?}", + output_delta.changes.len(), + output_delta.changes + ); + + // Verify the retraction comes first + let (retraction_row, retraction_weight) = &output_delta.changes[0]; + assert_eq!( + *retraction_weight, -1, + "First change should be a retraction" + ); + assert_eq!( + retraction_row.values[0], + Value::Float(90.0), + "Retracted value should be the old sum (90)" + ); + + // Verify the insertion comes second + let (insertion_row, insertion_weight) = &output_delta.changes[1]; + assert_eq!(*insertion_weight, 1, "Second change should be an insertion"); + assert_eq!( + insertion_row.values[0], + Value::Float(130.0), + "Inserted value should be the new sum (130)" + ); + + // Both changes should have the same row ID (since it's the same aggregate group) + assert_eq!( + retraction_row.rowid, insertion_row.rowid, + "Retraction and insertion should have the same row ID" + ); } #[test] - fn test_join_maintains_index() { - // Create join operator - let mut join = JoinOperator::new( - JoinType::Inner, - "id".to_string(), - "ref_id".to_string(), - vec!["id".to_string(), "name".to_string()], - vec!["ref_id".to_string(), "value".to_string()], + fn test_aggregate_with_group_by_emits_retractions() { + // This test verifies that when aggregate values change for grouped data, + // the operator emits both retractions and insertions correctly for each group. 
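+        // In DBSP, such a change is encoded as a weight pair: the old aggregate row
+        // is retracted with weight -1 and the new row is inserted with weight +1
+        // under the same group key.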
+ + // Create an aggregate operator for SUM(score) GROUP BY team + let mut agg = AggregateOperator::new( + vec!["team".to_string()], // GROUP BY team + vec![AggregateFunction::Sum("score".to_string())], + vec![ + "id".to_string(), + "team".to_string(), + "player".to_string(), + "score".to_string(), + ], ); - // Initial data - let mut left = Delta::new(); - left.insert(1, vec![Value::Integer(1), Value::Text(Text::new("A"))]); - left.insert(2, vec![Value::Integer(2), Value::Text(Text::new("B"))]); - - let mut right = Delta::new(); - right.insert(1, vec![Value::Integer(1), Value::Integer(100)]); - - // Initialize - should build index - join.initialize_both(left.clone(), right.clone()); - - // Verify initial join worked - let state = join.get_current_state(); - assert_eq!(state.changes.len(), 1); // One match: id=1 - - // Add new item to left - let mut delta_left = Delta::new(); - delta_left.insert(3, vec![Value::Integer(3), Value::Text(Text::new("C"))]); - - // Add matching item to right - let mut delta_right = Delta::new(); - delta_right.insert(2, vec![Value::Integer(3), Value::Integer(300)]); - - // Process deltas - let output = join.process_both_deltas(delta_left, delta_right); - - // Should have new join result - assert_eq!(output.len(), 1); - - // Verify the join result has the expected values - assert!(!output.changes.is_empty()); - let (result, _weight) = &output.changes[0]; - assert_eq!(result.values.len(), 4); // id, name, ref_id, value - } - - #[test] - fn test_join_formula_correctness() { - // Test the DBSP formula: ∂(A ⋈ B) = A ⋈ ∂B + ∂A ⋈ B + ∂A ⋈ ∂B - let tracker = Arc::new(Mutex::new(ComputationTracker::new())); - - let mut join = JoinOperator::new( - JoinType::Inner, - "x".to_string(), - "x".to_string(), - vec!["x".to_string(), "a".to_string()], - vec!["x".to_string(), "b".to_string()], + // Initial data: players on different teams + let mut initial_delta = Delta::new(); + initial_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("red".to_string().into()), + Value::Text("Alice".to_string().into()), + Value::Integer(10), + ], + ); + initial_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("blue".to_string().into()), + Value::Text("Bob".to_string().into()), + Value::Integer(15), + ], + ); + initial_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Text("red".to_string().into()), + Value::Text("Charlie".to_string().into()), + Value::Integer(20), + ], ); - join.set_tracker(tracker.clone()); - // Initial state A - let mut a = Delta::new(); - a.insert(1, vec![Value::Integer(1), Value::Text(Text::new("a1"))]); - a.insert(2, vec![Value::Integer(2), Value::Text(Text::new("a2"))]); + // Initialize with initial data + agg.initialize(initial_delta); - // Initial state B - let mut b = Delta::new(); - b.insert(1, vec![Value::Integer(1), Value::Text(Text::new("b1"))]); - b.insert(2, vec![Value::Integer(2), Value::Text(Text::new("b2"))]); + // Verify initial state: red team = 30, blue team = 15 + let state = agg.get_current_state(); + assert_eq!(state.changes.len(), 2, "Should have two groups"); - join.initialize_both(a.clone(), b.clone()); + // Find the red and blue team aggregates + let mut red_sum = None; + let mut blue_sum = None; + for (row, weight) in &state.changes { + assert_eq!(*weight, 1); + if let Value::Text(team) = &row.values[0] { + if team.as_str() == "red" { + red_sum = Some(&row.values[1]); + } else if team.as_str() == "blue" { + blue_sum = Some(&row.values[1]); + } + } + } + assert_eq!( + red_sum, + Some(&Value::Float(30.0)), + "Red team sum 
should be 30" + ); + assert_eq!( + blue_sum, + Some(&Value::Float(15.0)), + "Blue team sum should be 15" + ); - // Reset tracker - tracker.lock().unwrap().join_lookups = 0; + // Now add a new player to the red team (incremental update) + let mut update_delta = Delta::new(); + update_delta.insert( + 4, + vec![ + Value::Integer(4), + Value::Text("red".to_string().into()), + Value::Text("David".to_string().into()), + Value::Integer(25), + ], + ); - // Delta for A (add x=3) - let mut delta_a = Delta::new(); - delta_a.insert(3, vec![Value::Integer(3), Value::Text(Text::new("a3"))]); + // Process the incremental update + let output_delta = agg.eval(update_delta.clone(), None); + agg.commit(update_delta); - // Delta for B (add x=3 and x=1) - let mut delta_b = Delta::new(); - delta_b.insert(3, vec![Value::Integer(3), Value::Text(Text::new("b3"))]); - delta_b.insert(4, vec![Value::Integer(1), Value::Text(Text::new("b1_new"))]); + // Should have 2 changes: retraction of old red team sum, insertion of new red team sum + // Blue team should NOT be affected + assert_eq!( + output_delta.changes.len(), + 2, + "Expected 2 changes for red team only, got {}: {:?}", + output_delta.changes.len(), + output_delta.changes + ); - let output = join.process_both_deltas(delta_a, delta_b); + // Both changes should be for the red team + let mut found_retraction = false; + let mut found_insertion = false; - // Expected results: - // A ⋈ ∂B: (1,a1) ⋈ (1,b1_new) = 1 result - // ∂A ⋈ B: (3,a3) ⋈ nothing = 0 results - // ∂A ⋈ ∂B: (3,a3) ⋈ (3,b3) = 1 result - // Total: 2 results - assert_eq!(output.len(), 2); + for (row, weight) in &output_delta.changes { + if let Value::Text(team) = &row.values[0] { + assert_eq!(team.as_str(), "red", "Only red team should have changes"); - // Verify we're doing incremental work - let lookups = tracker.lock().unwrap().join_lookups; - assert!(lookups <= 4, "Should use efficient index lookups"); + if *weight == -1 { + // Retraction of old value + assert_eq!( + row.values[1], + Value::Float(30.0), + "Should retract old sum of 30" + ); + found_retraction = true; + } else if *weight == 1 { + // Insertion of new value + assert_eq!( + row.values[1], + Value::Float(55.0), + "Should insert new sum of 55" + ); + found_insertion = true; + } + } + } + + assert!(found_retraction, "Should have found retraction"); + assert!(found_insertion, "Should have found insertion"); } // Aggregation tests @@ -1821,21 +1742,25 @@ mod tests { ], ); - let output = agg.process_delta(delta); + let _output = agg.eval(delta.clone(), None); + agg.commit(delta); - // Should only update one group (cat_0), not recount all groups - assert_eq!(tracker.lock().unwrap().aggregation_updates, 1); + // Should update one group (cat_0) twice - once in eval, once in commit + // This is still incremental - we're not recounting all groups + assert_eq!(tracker.lock().unwrap().aggregation_updates, 2); - // Output should show cat_0 now has count 11 - assert_eq!(output.len(), 1); - assert!(!output.changes.is_empty()); - let (change_row, _weight) = &output.changes[0]; - assert_eq!(change_row.values[0], Value::Text(Text::new("cat_0"))); - assert_eq!(change_row.values[1], Value::Integer(11)); + // Check the final state - cat_0 should now have count 11 + let final_state = agg.get_current_state(); + let cat_0 = final_state + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Text(Text::new("cat_0"))) + .unwrap(); + assert_eq!(cat_0.0.values[1], Value::Integer(11)); - // Verify incremental behavior + // Verify incremental behavior - we 
process the delta twice (eval + commit) let t = tracker.lock().unwrap(); - assert_incremental(&t, 1, 101); + assert_incremental(&t, 2, 101); } #[test] @@ -1906,17 +1831,20 @@ mod tests { ], ); - let output = agg.process_delta(delta); + let _output = agg.eval(delta.clone(), None); + agg.commit(delta); - // Should only update Widget group - assert_eq!(tracker.lock().unwrap().aggregation_updates, 1); - assert_eq!(output.len(), 1); + // Should update Widget group twice (once in eval, once in commit) + assert_eq!(tracker.lock().unwrap().aggregation_updates, 2); - // Widget should now be 300 (250 + 50) - assert!(!output.changes.is_empty()); - let (change, _weight) = &output.changes[0]; - assert_eq!(change.values[0], Value::Text(Text::new("Widget"))); - assert_eq!(change.values[1], Value::Integer(300)); + // Check final state - Widget should now be 300 (250 + 50) + let final_state = agg.get_current_state(); + let widget = final_state + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Text(Text::new("Widget"))) + .unwrap(); + assert_eq!(widget.0.values[1], Value::Integer(300)); } #[test] @@ -1981,15 +1909,18 @@ mod tests { 4, vec![Value::Integer(4), Value::Integer(1), Value::Integer(50)], ); - let output = agg.process_delta(delta); + let _output = agg.eval(delta.clone(), None); + agg.commit(delta); - // Should only update user 1 - assert_eq!(output.len(), 1); - assert!(!output.changes.is_empty()); - let (change, _weight) = &output.changes[0]; - assert_eq!(change.values[0], Value::Integer(1)); // user_id - assert_eq!(change.values[1], Value::Integer(3)); // count: 2 + 1 - assert_eq!(change.values[2], Value::Integer(350)); // sum: 300 + 50 + // Check final state - user 1 should have updated count and sum + let final_state = agg.get_current_state(); + let user1 = final_state + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Integer(1)) + .unwrap(); + assert_eq!(user1.0.values[1], Value::Integer(3)); // count: 2 + 1 + assert_eq!(user1.0.values[2], Value::Integer(350)); // sum: 300 + 50 } #[test] @@ -2063,13 +1994,17 @@ mod tests { Value::Integer(30), ], ); - let output = agg.process_delta(delta); + let _output = agg.eval(delta.clone(), None); + agg.commit(delta); - // Category A avg should now be (10 + 20 + 30) / 3 = 20 - assert!(!output.changes.is_empty()); - let (change, _weight) = &output.changes[0]; - assert_eq!(change.values[0], Value::Text(Text::new("A"))); - assert_eq!(change.values[1], Value::Float(20.0)); + // Check final state - Category A avg should now be (10 + 20 + 30) / 3 = 20 + let final_state = agg.get_current_state(); + let cat_a = final_state + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Text(Text::new("A"))) + .unwrap(); + assert_eq!(cat_a.0.values[1], Value::Float(20.0)); } #[test] @@ -2126,13 +2061,662 @@ mod tests { ], ); - let output = agg.process_delta(delta); + let _output = agg.eval(delta.clone(), None); + agg.commit(delta); - // Should update to count=1, sum=200 - assert!(!output.changes.is_empty()); - let (change_row, _weight) = &output.changes[0]; - assert_eq!(change_row.values[0], Value::Text(Text::new("A"))); - assert_eq!(change_row.values[1], Value::Integer(1)); // count: 2 - 1 - assert_eq!(change_row.values[2], Value::Integer(200)); // sum: 300 - 100 + // Check final state - should update to count=1, sum=200 + let final_state = agg.get_current_state(); + let cat_a = final_state + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Text(Text::new("A"))) + .unwrap(); + assert_eq!(cat_a.0.values[1], 
Value::Integer(1)); // count: 2 - 1 + assert_eq!(cat_a.0.values[2], Value::Integer(200)); // sum: 300 - 100 + } + + #[test] + fn test_count_aggregation_with_deletions() { + let aggregates = vec![AggregateFunction::Count]; + let group_by = vec!["category".to_string()]; + let input_columns = vec!["category".to_string(), "value".to_string()]; + + let mut agg = AggregateOperator::new(group_by, aggregates.clone(), input_columns); + + // Initialize with data + let mut init_data = Delta::new(); + init_data.insert(1, vec![Value::Text("A".into()), Value::Integer(10)]); + init_data.insert(2, vec![Value::Text("A".into()), Value::Integer(20)]); + init_data.insert(3, vec![Value::Text("B".into()), Value::Integer(30)]); + agg.initialize(init_data); + + // Check initial counts + let state = agg.get_current_state(); + assert_eq!(state.changes.len(), 2); + + // Find group A and B + let group_a = state + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Text("A".into())) + .unwrap(); + let group_b = state + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Text("B".into())) + .unwrap(); + + assert_eq!(group_a.0.values[1], Value::Integer(2)); // COUNT = 2 for A + assert_eq!(group_b.0.values[1], Value::Integer(1)); // COUNT = 1 for B + + // Delete one row from group A + let mut delete_delta = Delta::new(); + delete_delta.delete(1, vec![Value::Text("A".into()), Value::Integer(10)]); + + let output = agg.eval(delete_delta.clone(), None); + agg.commit(delete_delta); + + // Should emit retraction for old count and insertion for new count + assert_eq!(output.changes.len(), 2); + + // Check final state + let final_state = agg.get_current_state(); + let group_a_final = final_state + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Text("A".into())) + .unwrap(); + assert_eq!(group_a_final.0.values[1], Value::Integer(1)); // COUNT = 1 for A after deletion + + // Delete all rows from group B + let mut delete_all_b = Delta::new(); + delete_all_b.delete(3, vec![Value::Text("B".into()), Value::Integer(30)]); + + let output_b = agg.eval(delete_all_b.clone(), None); + agg.commit(delete_all_b); + assert_eq!(output_b.changes.len(), 1); // Only retraction, no new row + assert_eq!(output_b.changes[0].1, -1); // Retraction + + // Final state should not have group B + let final_state2 = agg.get_current_state(); + assert_eq!(final_state2.changes.len(), 1); // Only group A remains + assert_eq!(final_state2.changes[0].0.values[0], Value::Text("A".into())); + } + + #[test] + fn test_sum_aggregation_with_deletions() { + let aggregates = vec![AggregateFunction::Sum("value".to_string())]; + let group_by = vec!["category".to_string()]; + let input_columns = vec!["category".to_string(), "value".to_string()]; + + let mut agg = AggregateOperator::new(group_by, aggregates.clone(), input_columns); + + // Initialize with data + let mut init_data = Delta::new(); + init_data.insert(1, vec![Value::Text("A".into()), Value::Integer(10)]); + init_data.insert(2, vec![Value::Text("A".into()), Value::Integer(20)]); + init_data.insert(3, vec![Value::Text("B".into()), Value::Integer(30)]); + init_data.insert(4, vec![Value::Text("B".into()), Value::Integer(15)]); + agg.initialize(init_data); + + // Check initial sums + let state = agg.get_current_state(); + let group_a = state + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Text("A".into())) + .unwrap(); + let group_b = state + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Text("B".into())) + .unwrap(); + + assert_eq!(group_a.0.values[1], 
Value::Integer(30)); // SUM = 30 for A (10+20) + assert_eq!(group_b.0.values[1], Value::Integer(45)); // SUM = 45 for B (30+15) + + // Delete one row from group A + let mut delete_delta = Delta::new(); + delete_delta.delete(2, vec![Value::Text("A".into()), Value::Integer(20)]); + + let _ = agg.eval(delete_delta.clone(), None); + agg.commit(delete_delta); + + // Check updated sum + let state = agg.get_current_state(); + let group_a = state + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Text("A".into())) + .unwrap(); + assert_eq!(group_a.0.values[1], Value::Integer(10)); // SUM = 10 for A after deletion + + // Delete all from group B + let mut delete_all_b = Delta::new(); + delete_all_b.delete(3, vec![Value::Text("B".into()), Value::Integer(30)]); + delete_all_b.delete(4, vec![Value::Text("B".into()), Value::Integer(15)]); + + let _ = agg.eval(delete_all_b.clone(), None); + agg.commit(delete_all_b); + + // Group B should be gone + let final_state = agg.get_current_state(); + assert_eq!(final_state.changes.len(), 1); // Only group A remains + assert_eq!(final_state.changes[0].0.values[0], Value::Text("A".into())); + } + + #[test] + fn test_avg_aggregation_with_deletions() { + let aggregates = vec![AggregateFunction::Avg("value".to_string())]; + let group_by = vec!["category".to_string()]; + let input_columns = vec!["category".to_string(), "value".to_string()]; + + let mut agg = AggregateOperator::new(group_by, aggregates.clone(), input_columns); + + // Initialize with data + let mut init_data = Delta::new(); + init_data.insert(1, vec![Value::Text("A".into()), Value::Integer(10)]); + init_data.insert(2, vec![Value::Text("A".into()), Value::Integer(20)]); + init_data.insert(3, vec![Value::Text("A".into()), Value::Integer(30)]); + agg.initialize(init_data); + + // Check initial average + let state = agg.get_current_state(); + assert_eq!(state.changes.len(), 1); + assert_eq!(state.changes[0].0.values[1], Value::Float(20.0)); // AVG = (10+20+30)/3 = 20 + + // Delete the middle value + let mut delete_delta = Delta::new(); + delete_delta.delete(2, vec![Value::Text("A".into()), Value::Integer(20)]); + + let _ = agg.eval(delete_delta.clone(), None); + agg.commit(delete_delta); + + // Check updated average + let state = agg.get_current_state(); + assert_eq!(state.changes[0].0.values[1], Value::Float(20.0)); // AVG = (10+30)/2 = 20 (same!) 
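+ // NOTE: this relies on AVG being maintained incrementally (e.g., as running + // SUM and COUNT state, an assumption about the operator's internals): the + // deletion retracts from both and the quotient is recomputed, so removing a + // value equal to the current mean leaves the average unchanged, as asserted + // above.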
+ + // Delete another to change the average + let mut delete_another = Delta::new(); + delete_another.delete(3, vec![Value::Text("A".into()), Value::Integer(30)]); + + let _ = agg.eval(delete_another.clone(), None); + agg.commit(delete_another); + + let state = agg.get_current_state(); + assert_eq!(state.changes[0].0.values[1], Value::Float(10.0)); // AVG = 10/1 = 10 + } + + #[test] + fn test_multiple_aggregations_with_deletions() { + // Test COUNT, SUM, and AVG together + let aggregates = vec![ + AggregateFunction::Count, + AggregateFunction::Sum("value".to_string()), + AggregateFunction::Avg("value".to_string()), + ]; + let group_by = vec!["category".to_string()]; + let input_columns = vec!["category".to_string(), "value".to_string()]; + + let mut agg = AggregateOperator::new(group_by, aggregates.clone(), input_columns); + + // Initialize with data + let mut init_data = Delta::new(); + init_data.insert(1, vec![Value::Text("A".into()), Value::Integer(100)]); + init_data.insert(2, vec![Value::Text("A".into()), Value::Integer(200)]); + init_data.insert(3, vec![Value::Text("B".into()), Value::Integer(50)]); + agg.initialize(init_data); + + // Check initial state + let state = agg.get_current_state(); + let group_a = state + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Text("A".into())) + .unwrap(); + + assert_eq!(group_a.0.values[1], Value::Integer(2)); // COUNT = 2 + assert_eq!(group_a.0.values[2], Value::Integer(300)); // SUM = 300 + assert_eq!(group_a.0.values[3], Value::Float(150.0)); // AVG = 150 + + // Delete one row from group A + let mut delete_delta = Delta::new(); + delete_delta.delete(1, vec![Value::Text("A".into()), Value::Integer(100)]); + + let _ = agg.eval(delete_delta.clone(), None); + agg.commit(delete_delta); + + // Check all aggregates updated correctly + let state = agg.get_current_state(); + let group_a = state + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Text("A".into())) + .unwrap(); + + assert_eq!(group_a.0.values[1], Value::Integer(1)); // COUNT = 1 + assert_eq!(group_a.0.values[2], Value::Integer(200)); // SUM = 200 + assert_eq!(group_a.0.values[3], Value::Float(200.0)); // AVG = 200 + + // Insert a new row with floating point value + let mut insert_delta = Delta::new(); + insert_delta.insert(4, vec![Value::Text("A".into()), Value::Float(50.5)]); + + let _ = agg.eval(insert_delta.clone(), None); + agg.commit(insert_delta); + + let state = agg.get_current_state(); + let group_a = state + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Text("A".into())) + .unwrap(); + + assert_eq!(group_a.0.values[1], Value::Integer(2)); // COUNT = 2 + assert_eq!(group_a.0.values[2], Value::Float(250.5)); // SUM = 250.5 + assert_eq!(group_a.0.values[3], Value::Float(125.25)); // AVG = 125.25 + } + + #[test] + fn test_filter_operator_rowid_update() { + // When a row's rowid changes (e.g., UPDATE t SET a=1 WHERE a=3 on INTEGER PRIMARY KEY), + // the operator should properly consolidate the state + + let mut filter = FilterOperator::new( + FilterPredicate::GreaterThan { + column: "b".to_string(), + value: Value::Integer(2), + }, + vec!["a".to_string(), "b".to_string()], + ); + + // Initialize with a row (rowid=3, values=[3, 3]) + let mut init_data = Delta::new(); + init_data.insert(3, vec![Value::Integer(3), Value::Integer(3)]); + filter.initialize(init_data); + + // Check initial state + let state = filter.get_current_state(); + assert_eq!(state.changes.len(), 1); + assert_eq!(state.changes[0].0.rowid, 3); + assert_eq!( + 
state.changes[0].0.values, + vec![Value::Integer(3), Value::Integer(3)] + ); + + // Simulate an UPDATE that changes rowid from 3 to 1 + // This is sent as: delete(3) + insert(1) + let mut update_delta = Delta::new(); + update_delta.delete(3, vec![Value::Integer(3), Value::Integer(3)]); + update_delta.insert(1, vec![Value::Integer(1), Value::Integer(3)]); + + let output = filter.eval(update_delta.clone(), None); + filter.commit(update_delta); + + // The output delta should have both changes (both pass the filter b > 2) + assert_eq!(output.changes.len(), 2); + assert_eq!(output.changes[0].1, -1); // delete weight + assert_eq!(output.changes[1].1, 1); // insert weight + + // The current state should be consolidated to only have rows with positive weight + let final_state = filter.get_current_state(); + + // After consolidation, we should have only one row with rowid=1 + assert_eq!( + final_state.changes.len(), + 1, + "State should be consolidated to have only one row" + ); + assert_eq!(final_state.changes[0].0.rowid, 1); + assert_eq!( + final_state.changes[0].0.values, + vec![Value::Integer(1), Value::Integer(3)] + ); + assert_eq!(final_state.changes[0].1, 1); // positive weight + } + + // ============================================================================ + // EVAL/COMMIT PATTERN TESTS + // These tests verify that the eval/commit pattern works correctly: + // - eval() computes results without modifying state + // - eval() with uncommitted data returns correct results + // - commit() updates internal state + // - State remains unchanged when eval() is called with uncommitted data + // ============================================================================ + + #[test] + fn test_filter_eval_with_uncommitted() { + let mut filter = FilterOperator::new( + FilterPredicate::GreaterThan { + column: "age".to_string(), + value: Value::Integer(25), + }, + vec!["id".to_string(), "name".to_string(), "age".to_string()], + ); + + // Initialize with some data + let mut init_data = Delta::new(); + init_data.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".into()), + Value::Integer(30), + ], + ); + init_data.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Bob".into()), + Value::Integer(20), + ], + ); + filter.initialize(init_data); + + // Verify initial state (only Alice passes filter) + let state = filter.get_current_state(); + assert_eq!(state.changes.len(), 1); + assert_eq!(state.changes[0].0.rowid, 1); + + // Create uncommitted changes + let mut uncommitted = Delta::new(); + uncommitted.insert( + 3, + vec![ + Value::Integer(3), + Value::Text("Charlie".into()), + Value::Integer(35), + ], + ); + uncommitted.insert( + 4, + vec![ + Value::Integer(4), + Value::Text("David".into()), + Value::Integer(15), + ], + ); + + // Eval with uncommitted - should return filtered uncommitted rows + let result = filter.eval(Delta::new(), Some(uncommitted.clone())); + assert_eq!( + result.changes.len(), + 1, + "Only Charlie (35) should pass filter" + ); + assert_eq!(result.changes[0].0.rowid, 3); + + // Verify state hasn't changed + let state_after_eval = filter.get_current_state(); + assert_eq!( + state_after_eval.changes.len(), + 1, + "State should still only have Alice" + ); + assert_eq!(state_after_eval.changes[0].0.rowid, 1); + + // Now commit the changes + filter.commit(uncommitted); + + // State should now include Charlie (who passes filter) + let final_state = filter.get_current_state(); + assert_eq!( + final_state.changes.len(), + 2, + "State should now have Alice and Charlie" + 
); + } + + #[test] + fn test_aggregate_eval_with_uncommitted_preserves_state() { + // This is the critical test - aggregations must not modify internal state during eval + let mut agg = AggregateOperator::new( + vec!["category".to_string()], + vec![ + AggregateFunction::Count, + AggregateFunction::Sum("amount".to_string()), + ], + vec![ + "id".to_string(), + "category".to_string(), + "amount".to_string(), + ], + ); + + // Initialize with data + let mut init_data = Delta::new(); + init_data.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("A".into()), + Value::Integer(100), + ], + ); + init_data.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("A".into()), + Value::Integer(200), + ], + ); + init_data.insert( + 3, + vec![ + Value::Integer(3), + Value::Text("B".into()), + Value::Integer(150), + ], + ); + agg.initialize(init_data); + + // Check initial state: A -> (count=2, sum=300), B -> (count=1, sum=150) + let initial_state = agg.get_current_state(); + assert_eq!(initial_state.changes.len(), 2); + + // Store initial state for comparison + let initial_a = initial_state + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Text("A".into())) + .unwrap(); + assert_eq!(initial_a.0.values[1], Value::Integer(2)); // count + assert_eq!(initial_a.0.values[2], Value::Float(300.0)); // sum + + // Create uncommitted changes + let mut uncommitted = Delta::new(); + uncommitted.insert( + 4, + vec![ + Value::Integer(4), + Value::Text("A".into()), + Value::Integer(50), + ], + ); + uncommitted.insert( + 5, + vec![ + Value::Integer(5), + Value::Text("C".into()), + Value::Integer(75), + ], + ); + + // Eval with uncommitted should return the delta (changes to aggregates) + let result = agg.eval(Delta::new(), Some(uncommitted.clone())); + + // Result should contain updates for A and new group C + // For A: retraction of old (2, 300) and insertion of new (3, 350) + // For C: insertion of (1, 75) + assert!(!result.changes.is_empty(), "Should have aggregate changes"); + + // CRITICAL: Verify internal state hasn't changed + let state_after_eval = agg.get_current_state(); + assert_eq!( + state_after_eval.changes.len(), + 2, + "State should still have only A and B" + ); + + let a_after_eval = state_after_eval + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Text("A".into())) + .unwrap(); + assert_eq!( + a_after_eval.0.values[1], + Value::Integer(2), + "A count should still be 2" + ); + assert_eq!( + a_after_eval.0.values[2], + Value::Float(300.0), + "A sum should still be 300" + ); + + // Now commit the changes + agg.commit(uncommitted); + + // State should now be updated + let final_state = agg.get_current_state(); + assert_eq!(final_state.changes.len(), 3, "Should now have A, B, and C"); + + let a_final = final_state + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Text("A".into())) + .unwrap(); + assert_eq!( + a_final.0.values[1], + Value::Integer(3), + "A count should now be 3" + ); + assert_eq!( + a_final.0.values[2], + Value::Float(350.0), + "A sum should now be 350" + ); + + let c_final = final_state + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Text("C".into())) + .unwrap(); + assert_eq!( + c_final.0.values[1], + Value::Integer(1), + "C count should be 1" + ); + assert_eq!( + c_final.0.values[2], + Value::Float(75.0), + "C sum should be 75" + ); + } + + #[test] + fn test_aggregate_eval_multiple_times_without_commit() { + // Test that calling eval multiple times with different uncommitted data + // doesn't pollute the internal state + let 
mut agg = AggregateOperator::new( + vec![], // No GROUP BY + vec![ + AggregateFunction::Count, + AggregateFunction::Sum("value".to_string()), + ], + vec!["id".to_string(), "value".to_string()], + ); + + // Initialize + let mut init_data = Delta::new(); + init_data.insert(1, vec![Value::Integer(1), Value::Integer(100)]); + init_data.insert(2, vec![Value::Integer(2), Value::Integer(200)]); + agg.initialize(init_data); + + // Initial state: count=2, sum=300 + let initial_state = agg.get_current_state(); + assert_eq!(initial_state.changes.len(), 1); + assert_eq!(initial_state.changes[0].0.values[0], Value::Integer(2)); + assert_eq!(initial_state.changes[0].0.values[1], Value::Float(300.0)); + + // First eval with uncommitted + let mut uncommitted1 = Delta::new(); + uncommitted1.insert(3, vec![Value::Integer(3), Value::Integer(50)]); + let _ = agg.eval(Delta::new(), Some(uncommitted1)); + + // State should be unchanged + let state1 = agg.get_current_state(); + assert_eq!(state1.changes[0].0.values[0], Value::Integer(2)); + assert_eq!(state1.changes[0].0.values[1], Value::Float(300.0)); + + // Second eval with different uncommitted + let mut uncommitted2 = Delta::new(); + uncommitted2.insert(4, vec![Value::Integer(4), Value::Integer(75)]); + uncommitted2.insert(5, vec![Value::Integer(5), Value::Integer(25)]); + let _ = agg.eval(Delta::new(), Some(uncommitted2)); + + // State should STILL be unchanged + let state2 = agg.get_current_state(); + assert_eq!(state2.changes[0].0.values[0], Value::Integer(2)); + assert_eq!(state2.changes[0].0.values[1], Value::Float(300.0)); + + // Third eval with deletion as uncommitted + let mut uncommitted3 = Delta::new(); + uncommitted3.delete(1, vec![Value::Integer(1), Value::Integer(100)]); + let _ = agg.eval(Delta::new(), Some(uncommitted3)); + + // State should STILL be unchanged + let state3 = agg.get_current_state(); + assert_eq!(state3.changes[0].0.values[0], Value::Integer(2)); + assert_eq!(state3.changes[0].0.values[1], Value::Float(300.0)); + } + + #[test] + fn test_aggregate_eval_with_mixed_committed_and_uncommitted() { + // Test eval with both committed delta and uncommitted changes + let mut agg = AggregateOperator::new( + vec!["type".to_string()], + vec![AggregateFunction::Count], + vec!["id".to_string(), "type".to_string()], + ); + + // Initialize + let mut init_data = Delta::new(); + init_data.insert(1, vec![Value::Integer(1), Value::Text("X".into())]); + init_data.insert(2, vec![Value::Integer(2), Value::Text("Y".into())]); + agg.initialize(init_data); + + // Create a committed delta (to be processed) + let mut committed_delta = Delta::new(); + committed_delta.insert(3, vec![Value::Integer(3), Value::Text("X".into())]); + + // Create uncommitted changes + let mut uncommitted = Delta::new(); + uncommitted.insert(4, vec![Value::Integer(4), Value::Text("Y".into())]); + uncommitted.insert(5, vec![Value::Integer(5), Value::Text("Z".into())]); + + // Eval with both - should process both but not commit + let result = agg.eval(committed_delta.clone(), Some(uncommitted)); + + // Result should reflect changes from both + assert!(!result.changes.is_empty()); + + // But internal state should be unchanged + let state = agg.get_current_state(); + assert_eq!(state.changes.len(), 2, "Should still have only X and Y"); + + // Now commit only the committed_delta + agg.commit(committed_delta); + + // State should now have X count=2, Y count=1 + let final_state = agg.get_current_state(); + let x = final_state + .changes + .iter() + .find(|(row, _)| row.values[0] == 
Value::Text("X".into())) + .unwrap(); + assert_eq!(x.0.values[1], Value::Integer(2)); } } diff --git a/core/incremental/view.rs b/core/incremental/view.rs index 4f4d4c6e6..bcae0afae 100644 --- a/core/incremental/view.rs +++ b/core/incremental/view.rs @@ -1,13 +1,12 @@ +use super::compiler::{DbspCircuit, DbspCompiler, DeltaSet}; use super::dbsp::{RowKeyStream, RowKeyZSet}; -use super::operator::{ - AggregateFunction, AggregateOperator, ComputationTracker, Delta, FilterOperator, - FilterPredicate, IncrementalOperator, ProjectOperator, -}; +use super::operator::{ComputationTracker, Delta, FilterPredicate}; use crate::schema::{BTreeTable, Column, Schema}; +use crate::translate::logical::LogicalPlanBuilder; use crate::types::{IOCompletions, IOResult, Value}; -use crate::util::{extract_column_name_from_expr, extract_view_columns}; +use crate::util::extract_view_columns; use crate::{io_yield_one, Completion, LimboError, Result, Statement}; -use std::collections::BTreeMap; +use std::collections::{BTreeMap, HashMap}; use std::fmt; use std::sync::{Arc, Mutex}; use turso_parser::ast; @@ -60,8 +59,7 @@ pub struct ViewTransactionState { /// for large aggregations, because then we don't have to re-compute when opening the database /// again. /// -/// Right now we are supporting the simplest views by keeping the operators in the view and -/// applying them in a sane order. But the general solution would turn this into a DBSP circuit. +/// Uses DBSP circuits for incremental computation. #[derive(Debug)] pub struct IncrementalView { // Stream of row keys for this view @@ -75,12 +73,11 @@ pub struct IncrementalView { // The SELECT statement that defines how to transform input data pub select_stmt: ast::Select, - // Internal filter operator for predicate evaluation - filter_operator: Option, - // Internal project operator for value transformation - project_operator: Option, - // Internal aggregate operator for GROUP BY and aggregations - aggregate_operator: Option, + // DBSP circuit that encapsulates the computation + circuit: DbspCircuit, + // Track whether circuit has been initialized with data + circuit_initialized: bool, + // Tables referenced by this view (extracted from FROM clause and JOINs) base_table: Arc, // The view's output columns with their types @@ -108,6 +105,25 @@ impl IncrementalView { Ok(()) } + /// Try to compile the SELECT statement into a DBSP circuit + fn try_compile_circuit( + select: &ast::Select, + schema: &Schema, + _base_table: &Arc, + ) -> Result { + // Build the logical plan from the SELECT statement + let mut builder = LogicalPlanBuilder::new(schema); + // Convert Select to a Stmt for the builder + let stmt = ast::Stmt::Select(select.clone()); + let logical_plan = builder.build_statement(&stmt)?; + + // Compile the logical plan to a DBSP circuit + let compiler = DbspCompiler::new(); + let circuit = compiler.compile(&logical_plan)?; + + Ok(circuit) + } + /// Get an iterator over column names, using enumerated naming for unnamed columns pub fn column_names(&self) -> impl Iterator + '_ { self.columns.iter().enumerate().map(|(i, col)| { @@ -136,14 +152,6 @@ impl IncrementalView { false } - /// Apply filter operator to check if values pass the view's WHERE clause - fn apply_filter(&self, values: &[Value]) -> bool { - if let Some(ref filter_op) = self.filter_operator { - filter_op.evaluate_predicate(values) - } else { - true - } - } pub fn from_sql(sql: &str, schema: &Schema) -> Result { let mut parser = Parser::new(sql.as_bytes()); let cmd = parser.next_cmd()?; @@ -173,10 +181,6 @@ 
impl IncrementalView { // Extract output columns using the shared function let view_columns = extract_view_columns(&select, schema); - // Extract GROUP BY columns and aggregate functions - let (group_by_columns, aggregate_functions, _old_output_names) = - Self::extract_aggregation_info(&select); - let (join_tables, join_condition) = Self::extract_join_info(&select); if join_tables.is_some() || join_condition.is_some() { return Err(LimboError::ParseError( @@ -199,105 +203,43 @@ impl IncrementalView { )); }; - let base_table_column_names = base_table - .columns - .iter() - .enumerate() - .map(|(i, col)| col.name.clone().unwrap_or_else(|| format!("column_{i}"))) - .collect(); - Self::new( name, - Vec::new(), // Empty initial data where_predicate, select.clone(), base_table, - base_table_column_names, view_columns, - group_by_columns, - aggregate_functions, schema, ) } - #[allow(clippy::too_many_arguments)] pub fn new( name: String, - initial_data: Vec<(i64, Vec<Value>)>, where_predicate: FilterPredicate, select_stmt: ast::Select, base_table: Arc<BTreeTable>, - base_table_column_names: Vec<String>, columns: Vec<Column>, - group_by_columns: Vec<String>, - aggregate_functions: Vec<AggregateFunction>, schema: &Schema, ) -> Result<Self> { - let mut records = BTreeMap::new(); - - for (row_key, values) in initial_data { - records.insert(row_key, values); - } - - // Create initial stream with row keys - let mut zset = RowKeyZSet::new(); - for (row_key, values) in &records { - use crate::incremental::hashable_row::HashableRow; - let row = HashableRow::new(*row_key, values.clone()); - zset.insert(row, 1); - } + let records = BTreeMap::new(); // Create the tracker that will be shared by all operators let tracker = Arc::new(Mutex::new(ComputationTracker::new())); - // Create filter operator if we have a predicate - let filter_operator = if !matches!(where_predicate, FilterPredicate::None) { - let mut filter_op = - FilterOperator::new(where_predicate.clone(), base_table_column_names.clone()); - filter_op.set_tracker(tracker.clone()); - Some(filter_op) - } else { - None - }; + // Compile the SELECT statement into a DBSP circuit + let circuit = Self::try_compile_circuit(&select_stmt, schema, &base_table)?; - // Check if this is an aggregated view - let is_aggregated = !group_by_columns.is_empty() || !aggregate_functions.is_empty(); - - // Create aggregate operator if needed - let aggregate_operator = if is_aggregated { - let mut agg_op = AggregateOperator::new( - group_by_columns, - aggregate_functions, - base_table_column_names.clone(), - ); - agg_op.set_tracker(tracker.clone()); - Some(agg_op) - } else { - None - }; - - // Only create project operator for non-aggregated views - let project_operator = if !is_aggregated { - let mut proj_op = ProjectOperator::from_select( - &select_stmt, - base_table_column_names.clone(), - schema, - )?; - proj_op.set_tracker(tracker.clone()); - Some(proj_op) - } else { - None - }; + // Circuit will be initialized when we first call merge_delta + let circuit_initialized = false; Ok(Self { - stream: RowKeyStream::from_zset(zset), + stream: RowKeyStream::from_zset(RowKeyZSet::new()), name, records, where_predicate, select_stmt, - filter_operator, - project_operator, - aggregate_operator, + circuit, + circuit_initialized, base_table, columns, populate_state: PopulateState::Start, @@ -338,46 +280,28 @@ impl IncrementalView { // Get the base table from referenced tables let table = &self.base_table; - // Build column list for SELECT clause - let select_columns = if let Some(ref project_op) = self.project_operator { - // Get the columns used by
the projection operator - let mut columns = Vec::new(); - for col in project_op.columns() { - // Check if it's a simple column reference - if let turso_parser::ast::Expr::Id(name) = &col.expr { - columns.push(name.as_str().to_string()); - } else { - // For expressions, we need all columns (for now) - columns.clear(); - columns.push("*".to_string()); - break; - } - } - if columns.is_empty() || columns.contains(&"*".to_string()) { - "*".to_string() - } else { - // Add the columns and always include rowid - columns.join(", ").to_string() - } - } else { - // No projection, use all columns + // Check if the table has a rowid alias (INTEGER PRIMARY KEY column) + let has_rowid_alias = table.columns.iter().any(|col| col.is_rowid_alias); + + // For now, select all columns since we don't have the static operators + // The circuit will handle filtering and projection + // If there's a rowid alias, we don't need to select rowid separately + let select_clause = if has_rowid_alias { "*".to_string() + } else { + "*, rowid".to_string() }; - // Build WHERE clause from filter operator - let where_clause = if let Some(ref filter_op) = self.filter_operator { - self.build_where_clause(filter_op.predicate())? - } else { - String::new() - }; + // Build WHERE clause from the where_predicate + let where_clause = self.build_where_clause(&self.where_predicate)?; // Construct the final query let query = if where_clause.is_empty() { - format!("SELECT {}, rowid FROM {}", select_columns, table.name) + format!("SELECT {} FROM {}", select_clause, table.name) } else { format!( - "SELECT {}, rowid FROM {} WHERE {}", - select_columns, table.name, where_clause + "SELECT {} FROM {} WHERE {}", + select_clause, table.name, where_clause ) }; Ok(query) @@ -494,20 +418,40 @@ impl IncrementalView { let all_values: Vec = row.get_values().cloned().collect(); - // The last value should be the rowid - let rowid = match all_values.last() { - Some(crate::types::Value::Integer(id)) => *id, - _ => { - // This shouldn't happen - rowid must be an integer - *rows_processed += 1; - batch_count += 1; - continue; - } + // Determine how to extract the rowid + // If there's a rowid alias (INTEGER PRIMARY KEY), the rowid is one of the columns + // Otherwise, it's the last value we explicitly selected + let (rowid, values) = if let Some((idx, _)) = + self.base_table.get_rowid_alias_column() + { + // The rowid is the value at the rowid alias column index + let rowid = match all_values.get(idx) { + Some(crate::types::Value::Integer(id)) => *id, + _ => { + // This shouldn't happen - rowid alias must be an integer + *rows_processed += 1; + batch_count += 1; + continue; + } + }; + // All values are table columns (no separate rowid was selected) + (rowid, all_values) + } else { + // The last value is the explicitly selected rowid + let rowid = match all_values.last() { + Some(crate::types::Value::Integer(id)) => *id, + _ => { + // This shouldn't happen - rowid must be an integer + *rows_processed += 1; + batch_count += 1; + continue; + } + }; + // Get all values except the rowid + let values = all_values[..all_values.len() - 1].to_vec(); + (rowid, values) }; - // Get all values except the rowid - let values = all_values[..all_values.len() - 1].to_vec(); - // Add to batch delta - let merge_delta handle filtering and aggregation batch_delta.insert(rowid, values); @@ -542,120 +486,6 @@ impl IncrementalView { } } - /// Extract GROUP BY columns and aggregate functions from SELECT statement - fn extract_aggregation_info( - select: &ast::Select, - ) -> (Vec, Vec, 
Vec) { - use turso_parser::ast::*; - - let mut group_by_columns = Vec::new(); - let mut aggregate_functions = Vec::new(); - let mut output_column_names = Vec::new(); - - if let OneSelect::Select { - ref group_by, - ref columns, - .. - } = select.body.select - { - // Extract GROUP BY columns - if let Some(group_by) = group_by { - for expr in &group_by.exprs { - if let Some(col_name) = extract_column_name_from_expr(expr) { - group_by_columns.push(col_name); - } - } - } - - // Extract aggregate functions and column names/aliases from SELECT list - for result_col in columns { - match result_col { - ResultColumn::Expr(expr, alias) => { - // Extract aggregate functions - let mut found_aggregates = Vec::new(); - Self::extract_aggregates_from_expr(expr, &mut found_aggregates); - - // Determine the output column name - let col_name = if let Some(As::As(alias_name)) = alias { - // Use the provided alias - alias_name.as_str().to_string() - } else if !found_aggregates.is_empty() { - // Use the default name from the aggregate function - found_aggregates[0].default_output_name() - } else if let Some(name) = extract_column_name_from_expr(expr) { - // Use the column name - name - } else { - // Fallback to a generic name - format!("column{}", output_column_names.len() + 1) - }; - - output_column_names.push(col_name); - aggregate_functions.extend(found_aggregates); - } - ResultColumn::Star => { - // For SELECT *, we'd need to know the base table columns - // This is handled elsewhere - } - ResultColumn::TableStar(_) => { - // Similar to Star, but for a specific table - } - } - } - } - - (group_by_columns, aggregate_functions, output_column_names) - } - - /// Recursively extract aggregate functions from an expression - fn extract_aggregates_from_expr( - expr: &ast::Expr, - aggregate_functions: &mut Vec, - ) { - use crate::function::Func; - use turso_parser::ast::*; - - match expr { - // Handle COUNT(*) and similar aggregate functions with * - Expr::FunctionCallStar { name, .. } => { - // FunctionCallStar is typically COUNT(*), which has 0 args - if let Ok(func) = Func::resolve_function(name.as_str(), 0) { - // Use the centralized mapping from operator.rs - // For COUNT(*), we pass None as the input column - if let Some(agg_func) = AggregateFunction::from_sql_function(&func, None) { - aggregate_functions.push(agg_func); - } - } - } - Expr::FunctionCall { name, args, .. } => { - // Regular function calls with arguments - let arg_count = args.len(); - - if let Ok(func) = Func::resolve_function(name.as_str(), arg_count) { - // Extract the input column if there's an argument - let input_column = if arg_count > 0 { - args.first().and_then(extract_column_name_from_expr) - } else { - None - }; - - // Use the centralized mapping from operator.rs - if let Some(agg_func) = - AggregateFunction::from_sql_function(&func, input_column) - { - aggregate_functions.push(agg_func); - } - } - } - // Recursively check binary expressions, etc. 
- Expr::Binary(left, _, right) => { - Self::extract_aggregates_from_expr(left, aggregate_functions); - Self::extract_aggregates_from_expr(right, aggregate_functions); - } - _ => {} - } - } - /// Extract JOIN information from SELECT statement #[allow(clippy::type_complexity)] pub fn extract_join_info( @@ -743,48 +573,36 @@ impl IncrementalView { /// Get current data merged with transaction state pub fn current_data(&self, tx_state: Option<&ViewTransactionState>) -> Vec<(i64, Vec<Value>)> { - // Start with committed records - if let Some(tx_state) = tx_state { - // processed_delta = input delta for now. Need to apply operations - let processed_delta = &tx_state.delta; + // Use circuit to process uncommitted changes + let mut uncommitted = DeltaSet::new(); + uncommitted.insert(self.base_table.name.clone(), tx_state.delta.clone()); - // For non-aggregation views, merge the processed delta with committed records - let mut result_map: BTreeMap<i64, Vec<Value>> = self.records.clone(); - - for (row, weight) in &processed_delta.changes { - if *weight > 0 && self.apply_filter(&row.values) { - result_map.insert(row.rowid, row.values.clone()); - } else if *weight < 0 { - result_map.remove(&row.rowid); + // Execute with uncommitted changes (won't affect circuit state) + match self.circuit.execute(HashMap::new(), uncommitted) { + Ok(processed_delta) => { + // Merge processed delta with committed records + let mut result_map: BTreeMap<i64, Vec<Value>> = self.records.clone(); + for (row, weight) in &processed_delta.changes { + if *weight > 0 { + result_map.insert(row.rowid, row.values.clone()); + } else if *weight < 0 { + result_map.remove(&row.rowid); + } + } + result_map.into_iter().collect() + } + Err(e) => { + // No fallback path: a circuit that fails to execute is a bug, so fail loudly + panic!("Failed to execute circuit with uncommitted data: {e:?}"); } } - - result_map.into_iter().collect() } else { // No transaction state: return committed records self.records.clone().into_iter().collect() } } - /// Apply filter operator to a delta if present - fn apply_filter_to_delta(&mut self, delta: Delta) -> Delta { - if let Some(ref mut filter_op) = self.filter_operator { - filter_op.process_delta(delta) - } else { - delta - } - } - - /// Apply aggregation operator to a delta if this is an aggregated view - fn apply_aggregation_to_delta(&mut self, delta: Delta) -> Delta { - if let Some(ref mut agg_op) = self.aggregate_operator { - agg_op.process_delta(delta) - } else { - delta - } - } - /// Merge a delta of changes into the view's current state pub fn merge_delta(&mut self, delta: &Delta) { // Early return if delta is empty @@ -792,16 +610,33 @@ return; } - // Apply operators in pipeline - let mut current_delta = delta.clone(); - current_delta = self.apply_filter_to_delta(current_delta); + // Use the circuit to process the delta + let mut input_data = HashMap::new(); + input_data.insert(self.base_table.name.clone(), delta.clone()); - // Apply projection operator if present (for non-aggregated views) - if let Some(ref mut project_op) = self.project_operator { - current_delta = project_op.process_delta(current_delta); + // If circuit hasn't been initialized yet, initialize it first + // This happens during populate_from_table + if !self.circuit_initialized { + // Initialize the circuit with empty state + self.circuit + .initialize(HashMap::new()) + .expect("Failed to initialize circuit"); + self.circuit_initialized = true; } - current_delta = self.apply_aggregation_to_delta(current_delta); + // Execute the circuit to process the delta + let current_delta = match
self.circuit.execute(input_data.clone(), DeltaSet::empty()) { + Ok(output) => { + // Commit the changes to the circuit's internal state + self.circuit + .commit(input_data) + .expect("Failed to commit to circuit"); + output + } + Err(e) => { + panic!("Failed to execute circuit: {e:?}"); + } + }; // Update records and stream with the processed delta let mut zset_delta = RowKeyZSet::new(); @@ -819,441 +654,3 @@ impl IncrementalView { self.stream.apply_delta(&zset_delta); } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::incremental::operator::{Delta, IncrementalOperator}; - use crate::schema::{BTreeTable, Column, Schema, Type}; - use crate::types::Value; - use std::sync::Arc; - fn create_test_schema() -> Schema { - let mut schema = Schema::new(false); - let table = BTreeTable { - root_page: 1, - name: "t".to_string(), - columns: vec![ - Column { - name: Some("a".to_string()), - ty: Type::Integer, - ty_str: "INTEGER".to_string(), - primary_key: false, - is_rowid_alias: false, - notnull: false, - default: None, - unique: false, - collation: None, - hidden: false, - }, - Column { - name: Some("b".to_string()), - ty: Type::Integer, - ty_str: "INTEGER".to_string(), - primary_key: false, - is_rowid_alias: false, - notnull: false, - default: None, - unique: false, - collation: None, - hidden: false, - }, - Column { - name: Some("c".to_string()), - ty: Type::Integer, - ty_str: "INTEGER".to_string(), - primary_key: false, - is_rowid_alias: false, - notnull: false, - default: None, - unique: false, - collation: None, - hidden: false, - }, - ], - primary_key_columns: vec![], - has_rowid: true, - is_strict: false, - unique_sets: None, - }; - schema.add_btree_table(Arc::new(table)); - schema - } - - #[test] - fn test_projection_simple_columns() { - let schema = create_test_schema(); - let sql = "CREATE MATERIALIZED VIEW v AS SELECT a, b FROM t"; - - let view = IncrementalView::from_sql(sql, &schema).unwrap(); - - assert!(view.project_operator.is_some()); - let project_op = view.project_operator.as_ref().unwrap(); - - let mut delta = Delta::new(); - delta.insert( - 1, - vec![Value::Integer(10), Value::Integer(20), Value::Integer(30)], - ); - - let mut temp_project = project_op.clone(); - temp_project.initialize(delta); - let result = temp_project.get_current_state(); - - let (output, _weight) = result.changes.first().unwrap(); - assert_eq!(output.values, vec![Value::Integer(10), Value::Integer(20)]); - } - - #[test] - fn test_projection_arithmetic_expression() { - let schema = create_test_schema(); - let sql = "CREATE MATERIALIZED VIEW v AS SELECT a * 2 as doubled FROM t"; - - let view = IncrementalView::from_sql(sql, &schema).unwrap(); - - assert!(view.project_operator.is_some()); - let project_op = view.project_operator.as_ref().unwrap(); - - let mut delta = Delta::new(); - delta.insert( - 1, - vec![Value::Integer(4), Value::Integer(2), Value::Integer(0)], - ); - - let mut temp_project = project_op.clone(); - temp_project.initialize(delta); - let result = temp_project.get_current_state(); - - let (output, _weight) = result.changes.first().unwrap(); - assert_eq!(output.values, vec![Value::Integer(8)]); - } - - #[test] - fn test_projection_multiple_expressions() { - let schema = create_test_schema(); - let sql = "CREATE MATERIALIZED VIEW v AS SELECT a + b as sum, a - b as diff, c FROM t"; - - let view = IncrementalView::from_sql(sql, &schema).unwrap(); - - assert!(view.project_operator.is_some()); - let project_op = view.project_operator.as_ref().unwrap(); - - let mut delta = Delta::new(); - 
delta.insert( - 1, - vec![Value::Integer(10), Value::Integer(3), Value::Integer(7)], - ); - - let mut temp_project = project_op.clone(); - temp_project.initialize(delta); - let result = temp_project.get_current_state(); - - let (output, _weight) = result.changes.first().unwrap(); - assert_eq!( - output.values, - vec![Value::Integer(13), Value::Integer(7), Value::Integer(7),] - ); - } - - #[test] - fn test_projection_function_call() { - let schema = create_test_schema(); - let sql = "CREATE MATERIALIZED VIEW v AS SELECT abs(a - 300) as abs_diff, b FROM t"; - - let view = IncrementalView::from_sql(sql, &schema).unwrap(); - - assert!(view.project_operator.is_some()); - let project_op = view.project_operator.as_ref().unwrap(); - - let mut delta = Delta::new(); - delta.insert( - 1, - vec![Value::Integer(255), Value::Integer(20), Value::Integer(30)], - ); - - let mut temp_project = project_op.clone(); - temp_project.initialize(delta); - let result = temp_project.get_current_state(); - - let (output, _weight) = result.changes.first().unwrap(); - // abs(255 - 300) = abs(-45) = 45 - assert_eq!(output.values, vec![Value::Integer(45), Value::Integer(20),]); - } - - #[test] - fn test_projection_mixed_columns_and_expressions() { - let schema = create_test_schema(); - let sql = "CREATE MATERIALIZED VIEW v AS SELECT a, b * 2 as doubled, c, a + b + c as total FROM t"; - - let view = IncrementalView::from_sql(sql, &schema).unwrap(); - - assert!(view.project_operator.is_some()); - let project_op = view.project_operator.as_ref().unwrap(); - - let mut delta = Delta::new(); - delta.insert( - 1, - vec![Value::Integer(1), Value::Integer(5), Value::Integer(3)], - ); - - let mut temp_project = project_op.clone(); - temp_project.initialize(delta); - let result = temp_project.get_current_state(); - - let (output, _weight) = result.changes.first().unwrap(); - assert_eq!( - output.values, - vec![ - Value::Integer(1), - Value::Integer(10), - Value::Integer(3), - Value::Integer(9), - ] - ); - } - - #[test] - fn test_projection_complex_expression() { - let schema = create_test_schema(); - let sql = "CREATE MATERIALIZED VIEW v AS SELECT (a * 2) + (b * 3) as weighted, c / 2 as half FROM t"; - - let view = IncrementalView::from_sql(sql, &schema).unwrap(); - - assert!(view.project_operator.is_some()); - let project_op = view.project_operator.as_ref().unwrap(); - - let mut delta = Delta::new(); - delta.insert( - 1, - vec![Value::Integer(5), Value::Integer(2), Value::Integer(10)], - ); - - let mut temp_project = project_op.clone(); - temp_project.initialize(delta); - let result = temp_project.get_current_state(); - - let (output, _weight) = result.changes.first().unwrap(); - assert_eq!(output.values, vec![Value::Integer(16), Value::Integer(5),]); - } - - #[test] - fn test_projection_with_where_clause() { - let schema = create_test_schema(); - let sql = "CREATE MATERIALIZED VIEW v AS SELECT a, a * 2 as doubled FROM t WHERE b > 2"; - - let view = IncrementalView::from_sql(sql, &schema).unwrap(); - - assert!(view.project_operator.is_some()); - assert!(view.filter_operator.is_some()); - - let project_op = view.project_operator.as_ref().unwrap(); - - let mut delta = Delta::new(); - delta.insert( - 1, - vec![Value::Integer(4), Value::Integer(3), Value::Integer(0)], - ); - - let mut temp_project = project_op.clone(); - temp_project.initialize(delta); - let result = temp_project.get_current_state(); - - let (output, _weight) = result.changes.first().unwrap(); - assert_eq!(output.values, vec![Value::Integer(4), Value::Integer(8),]); - 
} - - #[test] - fn test_projection_more_output_columns_than_input() { - let schema = create_test_schema(); - let sql = "CREATE MATERIALIZED VIEW v AS SELECT a, b, a * 2 as doubled_a, b * 3 as tripled_b, a + b as sum, hex(c) as hex_c FROM t"; - - let view = IncrementalView::from_sql(sql, &schema).unwrap(); - - assert!(view.project_operator.is_some()); - let project_op = view.project_operator.as_ref().unwrap(); - - let mut delta = Delta::new(); - delta.insert( - 1, - vec![Value::Integer(5), Value::Integer(2), Value::Integer(15)], - ); - - let mut temp_project = project_op.clone(); - temp_project.initialize(delta); - let result = temp_project.get_current_state(); - - let (output, _weight) = result.changes.first().unwrap(); - // 3 input columns -> 6 output columns - assert_eq!( - output.values, - vec![ - Value::Integer(5), // a - Value::Integer(2), // b - Value::Integer(10), // a * 2 - Value::Integer(6), // b * 3 - Value::Integer(7), // a + b - Value::Text("3135".into()), // hex(15) - SQLite converts to string "15" then hex encodes - ] - ); - } - - #[test] - fn test_aggregation_count_with_group_by() { - let schema = create_test_schema(); - let sql = "CREATE MATERIALIZED VIEW v AS SELECT a, COUNT(*) FROM t GROUP BY a"; - - let mut view = IncrementalView::from_sql(sql, &schema).unwrap(); - - // Verify the view has an aggregate operator - assert!(view.aggregate_operator.is_some()); - - // Insert some test data - let mut delta = Delta::new(); - delta.insert( - 1, - vec![Value::Integer(1), Value::Integer(10), Value::Integer(100)], - ); - delta.insert( - 2, - vec![Value::Integer(2), Value::Integer(20), Value::Integer(200)], - ); - delta.insert( - 3, - vec![Value::Integer(1), Value::Integer(30), Value::Integer(300)], - ); - - // Process the delta - view.merge_delta(&delta); - - // Verify we only processed the 3 rows we inserted - assert_eq!(view.tracker.lock().unwrap().aggregation_updates, 3); - - // Check the aggregated results - let results = view.current_data(None); - - // Should have 2 groups: a=1 with count=2, a=2 with count=1 - assert_eq!(results.len(), 2); - - // Find the group with a=1 - let group1 = results - .iter() - .find(|(_, vals)| vals[0] == Value::Integer(1)) - .unwrap(); - assert_eq!(group1.1[0], Value::Integer(1)); // a=1 - assert_eq!(group1.1[1], Value::Integer(2)); // COUNT(*)=2 - - // Find the group with a=2 - let group2 = results - .iter() - .find(|(_, vals)| vals[0] == Value::Integer(2)) - .unwrap(); - assert_eq!(group2.1[0], Value::Integer(2)); // a=2 - assert_eq!(group2.1[1], Value::Integer(1)); // COUNT(*)=1 - } - - #[test] - fn test_aggregation_sum_with_filter() { - let schema = create_test_schema(); - let sql = "CREATE MATERIALIZED VIEW v AS SELECT SUM(b) FROM t WHERE a > 1"; - - let mut view = IncrementalView::from_sql(sql, &schema).unwrap(); - - assert!(view.aggregate_operator.is_some()); - assert!(view.filter_operator.is_some()); - - let mut delta = Delta::new(); - delta.insert( - 1, - vec![Value::Integer(1), Value::Integer(10), Value::Integer(100)], - ); - delta.insert( - 2, - vec![Value::Integer(2), Value::Integer(20), Value::Integer(200)], - ); - delta.insert( - 3, - vec![Value::Integer(3), Value::Integer(30), Value::Integer(300)], - ); - - view.merge_delta(&delta); - - // Should filter all 3 rows - assert_eq!(view.tracker.lock().unwrap().filter_evaluations, 3); - // But only aggregate the 2 that passed the filter (a > 1) - assert_eq!(view.tracker.lock().unwrap().aggregation_updates, 2); - - let results = view.current_data(None); - - // Should have 1 row with sum of 
b where a > 1 - assert_eq!(results.len(), 1); - assert_eq!(results[0].1[0], Value::Integer(50)); // SUM(b) = 20 + 30 - } - - #[test] - fn test_aggregation_incremental_updates() { - let schema = create_test_schema(); - let sql = "CREATE MATERIALIZED VIEW v AS SELECT a, COUNT(*), SUM(b) FROM t GROUP BY a"; - - let mut view = IncrementalView::from_sql(sql, &schema).unwrap(); - - // Initial insert - let mut delta1 = Delta::new(); - delta1.insert( - 1, - vec![Value::Integer(1), Value::Integer(10), Value::Integer(100)], - ); - delta1.insert( - 2, - vec![Value::Integer(1), Value::Integer(20), Value::Integer(200)], - ); - - view.merge_delta(&delta1); - - // Verify we processed exactly 2 rows for the first batch - assert_eq!(view.tracker.lock().unwrap().aggregation_updates, 2); - - // Check initial state - let results1 = view.current_data(None); - assert_eq!(results1.len(), 1); - assert_eq!(results1[0].1[1], Value::Integer(2)); // COUNT(*)=2 - assert_eq!(results1[0].1[2], Value::Integer(30)); // SUM(b)=30 - - // Reset counter to track second batch separately - view.tracker.lock().unwrap().aggregation_updates = 0; - - // Add more data - let mut delta2 = Delta::new(); - delta2.insert( - 3, - vec![Value::Integer(1), Value::Integer(5), Value::Integer(300)], - ); - delta2.insert( - 4, - vec![Value::Integer(2), Value::Integer(15), Value::Integer(400)], - ); - - view.merge_delta(&delta2); - - // Should only process the 2 new rows, not recompute everything - assert_eq!(view.tracker.lock().unwrap().aggregation_updates, 2); - - // Check updated state - let results2 = view.current_data(None); - assert_eq!(results2.len(), 2); - - // Group a=1 - let group1 = results2 - .iter() - .find(|(_, vals)| vals[0] == Value::Integer(1)) - .unwrap(); - assert_eq!(group1.1[1], Value::Integer(3)); // COUNT(*)=3 - assert_eq!(group1.1[2], Value::Integer(35)); // SUM(b)=35 - - // Group a=2 - let group2 = results2 - .iter() - .find(|(_, vals)| vals[0] == Value::Integer(2)) - .unwrap(); - assert_eq!(group2.1[1], Value::Integer(1)); // COUNT(*)=1 - assert_eq!(group2.1[2], Value::Integer(15)); // SUM(b)=15 - } -} diff --git a/core/translate/logical.rs b/core/translate/logical.rs new file mode 100644 index 000000000..df8bdd13a --- /dev/null +++ b/core/translate/logical.rs @@ -0,0 +1,3076 @@ +//! Logical plan representation for SQL queries +//! +//! This module provides a platform-independent intermediate representation +//! for SQL queries. The logical plan is a DAG (Directed Acyclic Graph) that +//! supports CTEs and can be used for query optimization before being compiled +//! to an execution plan (e.g., DBSP circuits). +//! +//! The main entry point is `LogicalPlanBuilder` which constructs logical plans +//! from SQL AST nodes. 
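+//! +//! For example, a query such as `SELECT a, COUNT(*) FROM t GROUP BY a` is +//! expected to lower to an `Aggregate` node (group_expr = [a], aggr_expr = +//! [COUNT(*)]) over a `TableScan` of `t`, possibly with a `Projection` on top +//! to fix the output column order. (Illustrative sketch only; this module +//! does not define a textual plan format.)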
+use crate::function::AggFunc; +use crate::schema::{Schema, Type}; +use crate::types::Value; +use crate::{LimboError, Result}; +use std::collections::HashMap; +use std::fmt::{self, Display, Formatter}; +use std::sync::Arc; +use turso_parser::ast; + +/// Result type for preprocessing aggregate expressions +type PreprocessAggregateResult = ( + bool, // needs_pre_projection + Vec, // pre_projection_exprs + Vec<(String, Type)>, // pre_projection_schema + Vec, // modified_aggr_exprs +); + +/// Schema information for logical plan nodes +#[derive(Debug, Clone, PartialEq)] +pub struct LogicalSchema { + /// Column names and types + pub columns: Vec<(String, Type)>, +} +/// A reference to a schema that can be shared between nodes +pub type SchemaRef = Arc<LogicalSchema>; + +impl LogicalSchema { + pub fn new(columns: Vec<(String, Type)>) -> Self { + Self { columns } + } + + pub fn empty() -> Self { + Self { + columns: Vec::new(), + } + } + + pub fn column_count(&self) -> usize { + self.columns.len() + } + + pub fn find_column(&self, name: &str) -> Option<(usize, &Type)> { + self.columns + .iter() + .position(|(n, _)| n == name) + .map(|idx| (idx, &self.columns[idx].1)) + } +} + +/// Logical representation of a SQL query plan +#[derive(Debug, Clone, PartialEq)] +pub enum LogicalPlan { + /// Projection - SELECT expressions + Projection(Projection), + /// Filter - WHERE/HAVING clause + Filter(Filter), + /// Aggregate - GROUP BY with aggregate functions + Aggregate(Aggregate), + // TODO: Join - combining two relations (not yet implemented) + // Join(Join), + /// Sort - ORDER BY clause + Sort(Sort), + /// Limit - LIMIT/OFFSET clause + Limit(Limit), + /// Table scan - reading from a base table + TableScan(TableScan), + /// Union - UNION/UNION ALL/INTERSECT/EXCEPT + Union(Union), + /// Distinct - remove duplicates + Distinct(Distinct), + /// Empty relation - no rows + EmptyRelation(EmptyRelation), + /// Values - literal rows (VALUES clause) + Values(Values), + /// CTE support - WITH clause + WithCTE(WithCTE), + /// Reference to a CTE + CTERef(CTERef), +} + +impl LogicalPlan { + /// Get the schema of this plan node + pub fn schema(&self) -> &SchemaRef { + match self { + LogicalPlan::Projection(p) => &p.schema, + LogicalPlan::Filter(f) => f.input.schema(), + LogicalPlan::Aggregate(a) => &a.schema, + // LogicalPlan::Join(j) => &j.schema, + LogicalPlan::Sort(s) => s.input.schema(), + LogicalPlan::Limit(l) => l.input.schema(), + LogicalPlan::TableScan(t) => &t.schema, + LogicalPlan::Union(u) => &u.schema, + LogicalPlan::Distinct(d) => d.input.schema(), + LogicalPlan::EmptyRelation(e) => &e.schema, + LogicalPlan::Values(v) => &v.schema, + LogicalPlan::WithCTE(w) => w.body.schema(), + LogicalPlan::CTERef(c) => &c.schema, + } + } +} + +/// Projection operator - SELECT expressions +#[derive(Debug, Clone, PartialEq)] +pub struct Projection { + pub input: Arc<LogicalPlan>, + pub exprs: Vec<LogicalExpr>, + pub schema: SchemaRef, +} + +/// Filter operator - WHERE/HAVING predicates +#[derive(Debug, Clone, PartialEq)] +pub struct Filter { + pub input: Arc<LogicalPlan>, + pub predicate: LogicalExpr, +} + +/// Aggregate operator - GROUP BY with aggregations +#[derive(Debug, Clone, PartialEq)] +pub struct Aggregate { + pub input: Arc<LogicalPlan>, + pub group_expr: Vec<LogicalExpr>, + pub aggr_expr: Vec<LogicalExpr>, + pub schema: SchemaRef, +} + +// TODO: Join operator (not yet implemented) +// #[derive(Debug, Clone, PartialEq)] +// pub struct Join { +// pub left: Arc<LogicalPlan>, +// pub right: Arc<LogicalPlan>, +// pub join_type: JoinType, +// pub on: Vec<(LogicalExpr, LogicalExpr)>, // Equijoin conditions +// pub filter: Option<LogicalExpr>, //
Additional filter conditions +// pub schema: SchemaRef, +// } + +// TODO: Types of joins (not yet implemented) +// #[derive(Debug, Clone, Copy, PartialEq)] +// pub enum JoinType { +// Inner, +// Left, +// Right, +// Full, +// Cross, +// } + +/// Sort operator - ORDER BY +#[derive(Debug, Clone, PartialEq)] +pub struct Sort { + pub input: Arc<LogicalPlan>, + pub exprs: Vec<SortExpr>, +} + +/// Sort expression with direction +#[derive(Debug, Clone, PartialEq)] +pub struct SortExpr { + pub expr: LogicalExpr, + pub asc: bool, + pub nulls_first: bool, +} + +/// Limit operator - LIMIT/OFFSET +#[derive(Debug, Clone, PartialEq)] +pub struct Limit { + pub input: Arc<LogicalPlan>, + pub skip: Option, + pub fetch: Option, +} + +/// Table scan operator +#[derive(Debug, Clone, PartialEq)] +pub struct TableScan { + pub table_name: String, + pub schema: SchemaRef, + pub projection: Option<Vec<usize>>, // Column indices to project +} + +/// Union operator +#[derive(Debug, Clone, PartialEq)] +pub struct Union { + pub inputs: Vec<Arc<LogicalPlan>>, + pub all: bool, // true for UNION ALL, false for UNION + pub schema: SchemaRef, +} + +/// Distinct operator +#[derive(Debug, Clone, PartialEq)] +pub struct Distinct { + pub input: Arc<LogicalPlan>, +} + +/// Empty relation - produces no rows +#[derive(Debug, Clone, PartialEq)] +pub struct EmptyRelation { + pub produce_one_row: bool, + pub schema: SchemaRef, +} + +/// Values operator - literal rows +#[derive(Debug, Clone, PartialEq)] +pub struct Values { + pub rows: Vec>, + pub schema: SchemaRef, +} + +/// WITH clause - CTEs +#[derive(Debug, Clone, PartialEq)] +pub struct WithCTE { + pub ctes: HashMap<String, Arc<LogicalPlan>>, + pub body: Arc<LogicalPlan>, +} + +/// Reference to a CTE +#[derive(Debug, Clone, PartialEq)] +pub struct CTERef { + pub name: String, + pub schema: SchemaRef, +} + +/// Logical expression representation +#[derive(Debug, Clone, PartialEq)] +pub enum LogicalExpr { + /// Column reference + Column(Column), + /// Literal value + Literal(Value), + /// Binary expression + BinaryExpr { + left: Box<LogicalExpr>, + op: BinaryOperator, + right: Box<LogicalExpr>, + }, + /// Unary expression + UnaryExpr { + op: UnaryOperator, + expr: Box<LogicalExpr>, + }, + /// Aggregate function + AggregateFunction { + fun: AggregateFunction, + args: Vec<LogicalExpr>, + distinct: bool, + }, + /// Scalar function call + ScalarFunction { fun: String, args: Vec<LogicalExpr> }, + /// CASE expression + Case { + expr: Option<Box<LogicalExpr>>, + when_then: Vec<(LogicalExpr, LogicalExpr)>, + else_expr: Option<Box<LogicalExpr>>, + }, + /// IN list + InList { + expr: Box<LogicalExpr>, + list: Vec<LogicalExpr>, + negated: bool, + }, + /// IN subquery + InSubquery { + expr: Box<LogicalExpr>, + subquery: Arc<LogicalPlan>, + negated: bool, + }, + /// EXISTS subquery + Exists { + subquery: Arc<LogicalPlan>, + negated: bool, + }, + /// Scalar subquery + ScalarSubquery(Arc<LogicalPlan>), + /// Alias for an expression + Alias { + expr: Box<LogicalExpr>, + alias: String, + }, + /// IS NULL / IS NOT NULL + IsNull { + expr: Box<LogicalExpr>, + negated: bool, + }, + /// BETWEEN + Between { + expr: Box<LogicalExpr>, + low: Box<LogicalExpr>, + high: Box<LogicalExpr>, + negated: bool, + }, + /// LIKE pattern matching + Like { + expr: Box<LogicalExpr>, + pattern: Box<LogicalExpr>, + escape: Option, + negated: bool, + }, +} + +/// Column reference +#[derive(Debug, Clone, PartialEq)] +pub struct Column { + pub name: String, + pub table: Option<String>, +} + +impl Column { + pub fn new(name: impl Into<String>) -> Self { + Self { + name: name.into(), + table: None, + } + } + + pub fn with_table(name: impl Into<String>, table: impl Into<String>) -> Self { + Self { + name: name.into(), + table: Some(table.into()), + } + } +} + +impl Display for Column { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + match &self.table { + Some(t) => write!(f, "{}.{}", t, self.name), + None => write!(f, "{}",
+
+/// Type alias for binary operators
+pub type BinaryOperator = ast::Operator;
+
+/// Type alias for unary operators
+pub type UnaryOperator = ast::UnaryOperator;
+
+/// Type alias for aggregate functions
+pub type AggregateFunction = AggFunc;
+
+/// Compiler from AST to LogicalPlan
+pub struct LogicalPlanBuilder<'a> {
+    schema: &'a Schema,
+    ctes: HashMap<String, Arc<LogicalPlan>>,
+}
+
+impl<'a> LogicalPlanBuilder<'a> {
+    pub fn new(schema: &'a Schema) -> Self {
+        Self {
+            schema,
+            ctes: HashMap::new(),
+        }
+    }
+
+    /// Main entry point: compile a statement to a logical plan
+    pub fn build_statement(&mut self, stmt: &ast::Stmt) -> Result<LogicalPlan> {
+        match stmt {
+            ast::Stmt::Select(select) => self.build_select(select),
+            _ => Err(LimboError::ParseError(
+                "Only SELECT statements are currently supported in logical plans".to_string(),
+            )),
+        }
+    }
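+
+    // A minimal usage sketch (mirroring `parse_and_build` in the tests below):
+    //
+    //     let mut builder = LogicalPlanBuilder::new(&schema);
+    //     let plan = builder.build_statement(&stmt)?;
+    //
+    // For `SELECT name FROM users WHERE age > 18` this produces
+    // Projection(name) -> Filter(age > 18) -> TableScan(users).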
+
+    // Convert Name to String
+    fn name_to_string(name: &ast::Name) -> String {
+        match name {
+            ast::Name::Ident(s) | ast::Name::Quoted(s) => s.clone(),
+        }
+    }
+
+    // Build a logical plan from a SELECT statement
+    fn build_select(&mut self, select: &ast::Select) -> Result<LogicalPlan> {
+        // Handle WITH clause if present
+        if let Some(with) = &select.with {
+            return self.build_with_cte(with, select);
+        }
+
+        // Build the main query body
+        let order_by = &select.order_by;
+        let limit = &select.limit;
+        self.build_select_body(&select.body, order_by, limit)
+    }
+
+    // Build WITH CTE
+    fn build_with_cte(&mut self, with: &ast::With, select: &ast::Select) -> Result<LogicalPlan> {
+        let mut cte_plans = HashMap::new();
+
+        // Build each CTE
+        for cte in &with.ctes {
+            let cte_plan = self.build_select(&cte.select)?;
+            let cte_name = Self::name_to_string(&cte.tbl_name);
+            cte_plans.insert(cte_name.clone(), Arc::new(cte_plan));
+            self.ctes
+                .insert(cte_name.clone(), cte_plans[&cte_name].clone());
+        }
+
+        // Build the main body with CTEs available
+        let order_by = &select.order_by;
+        let limit = &select.limit;
+        let body = self.build_select_body(&select.body, order_by, limit)?;
+
+        // Clear CTEs from builder context
+        for cte in &with.ctes {
+            self.ctes.remove(&Self::name_to_string(&cte.tbl_name));
+        }
+
+        Ok(LogicalPlan::WithCTE(WithCTE {
+            ctes: cte_plans,
+            body: Arc::new(body),
+        }))
+    }
+
+    // Build SELECT body
+    fn build_select_body(
+        &mut self,
+        body: &ast::SelectBody,
+        order_by: &[ast::SortedColumn],
+        limit: &Option<ast::Limit>,
+    ) -> Result<LogicalPlan> {
+        let mut plan = self.build_one_select(&body.select)?;
+
+        // Handle compound operators (UNION, INTERSECT, EXCEPT)
+        if !body.compounds.is_empty() {
+            for compound in &body.compounds {
+                let right = self.build_one_select(&compound.select)?;
+                plan = Self::build_compound(plan, right, &compound.operator)?;
+            }
+        }
+
+        // Apply ORDER BY
+        if !order_by.is_empty() {
+            plan = self.build_sort(plan, order_by)?;
+        }
+
+        // Apply LIMIT
+        if let Some(limit) = limit {
+            plan = Self::build_limit(plan, limit)?;
+        }
+
+        Ok(plan)
+    }
+
+    // Build a single SELECT (without compounds)
+    fn build_one_select(&mut self, select: &ast::OneSelect) -> Result<LogicalPlan> {
+        match select {
+            ast::OneSelect::Select {
+                distinctness,
+                columns,
+                from,
+                where_clause,
+                group_by,
+                window_clause: _,
+            } => {
+                // Start with FROM clause
+                let mut plan = if let Some(from) = from {
+                    self.build_from(from)?
+                } else {
+                    // No FROM clause - single row
+                    LogicalPlan::EmptyRelation(EmptyRelation {
+                        produce_one_row: true,
+                        schema: Arc::new(LogicalSchema::empty()),
+                    })
+                };
+
+                // Apply WHERE
+                if let Some(where_expr) = where_clause {
+                    let predicate = self.build_expr(where_expr, plan.schema())?;
+                    plan = LogicalPlan::Filter(Filter {
+                        input: Arc::new(plan),
+                        predicate,
+                    });
+                }
+
+                // Apply GROUP BY and aggregations
+                if let Some(group_by) = group_by {
+                    plan = self.build_aggregate(plan, group_by, columns)?;
+                } else if Self::has_aggregates(columns) {
+                    // Aggregation without GROUP BY
+                    plan = self.build_aggregate_no_group(plan, columns)?;
+                } else {
+                    // Regular projection
+                    plan = self.build_projection(plan, columns)?;
+                }
+
+                // Apply HAVING (part of GROUP BY)
+                if let Some(ref group_by) = group_by {
+                    if let Some(ref having_expr) = group_by.having {
+                        let predicate = self.build_expr(having_expr, plan.schema())?;
+                        plan = LogicalPlan::Filter(Filter {
+                            input: Arc::new(plan),
+                            predicate,
+                        });
+                    }
+                }
+
+                // Apply DISTINCT
+                if distinctness.is_some() {
+                    plan = LogicalPlan::Distinct(Distinct {
+                        input: Arc::new(plan),
+                    });
+                }
+
+                Ok(plan)
+            }
+            ast::OneSelect::Values(values) => self.build_values(values),
+        }
+    }
+
+    // Build FROM clause
+    fn build_from(&mut self, from: &ast::FromClause) -> Result<LogicalPlan> {
+        let mut plan = self.build_select_table(&from.select)?;
+
+        // Handle JOINs
+        if !from.joins.is_empty() {
+            for join in &from.joins {
+                let right = self.build_select_table(&join.table)?;
+                plan = self.build_join(plan, right, &join.operator, &join.constraint)?;
+            }
+        }
+
+        Ok(plan)
+    }
+
+    // Build a table reference
+    fn build_select_table(&mut self, table: &ast::SelectTable) -> Result<LogicalPlan> {
+        match table {
+            ast::SelectTable::Table(name, _alias, _indexed) => {
+                let table_name = Self::name_to_string(&name.name);
+                // Check if it's a CTE reference
+                if let Some(cte_plan) = self.ctes.get(&table_name) {
+                    return Ok(LogicalPlan::CTERef(CTERef {
+                        name: table_name.clone(),
+                        schema: cte_plan.schema().clone(),
+                    }));
+                }
+
+                // Regular table scan
+                let table_schema = self.get_table_schema(&table_name)?;
+                Ok(LogicalPlan::TableScan(TableScan {
+                    table_name,
+                    schema: table_schema,
+                    projection: None,
+                }))
+            }
+            ast::SelectTable::Select(subquery, _alias) => self.build_select(subquery),
+            ast::SelectTable::TableCall(_, _, _) => Err(LimboError::ParseError(
+                "Table-valued functions are not supported in logical plans".to_string(),
+            )),
+            ast::SelectTable::Sub(_, _) => Err(LimboError::ParseError(
+                "Subquery in FROM clause not yet supported".to_string(),
+            )),
+        }
+    }
+
+    // Build JOIN
+    fn build_join(
+        &mut self,
+        _left: LogicalPlan,
+        _right: LogicalPlan,
+        _op: &ast::JoinOperator,
+        _constraint: &Option<ast::JoinConstraint>,
+    ) -> Result<LogicalPlan> {
+        Err(LimboError::ParseError(
+            "JOINs are not yet supported in logical plans".to_string(),
+        ))
+    }
+
+    // Build projection
+    fn build_projection(
+        &mut self,
+        input: LogicalPlan,
+        columns: &[ast::ResultColumn],
+    ) -> Result<LogicalPlan> {
+        let input_schema = input.schema();
+        let mut proj_exprs = Vec::new();
+        let mut schema_columns = Vec::new();
+
+        for col in columns {
+            match col {
+                ast::ResultColumn::Expr(expr, alias) => {
+                    let logical_expr = self.build_expr(expr, input_schema)?;
+                    let col_name = match alias {
+                        Some(as_alias) => match as_alias {
+                            ast::As::As(name) | ast::As::Elided(name) => Self::name_to_string(name),
+                        },
+                        None => Self::expr_to_column_name(expr),
+                    };
+                    let col_type = Self::infer_expr_type(&logical_expr, input_schema)?;
+
+                    schema_columns.push((col_name.clone(), col_type));
+
+                    if let Some(as_alias) = alias {
+                        let alias_name = match as_alias {
+                            ast::As::As(name) | ast::As::Elided(name) => Self::name_to_string(name),
+                        };
+                        proj_exprs.push(LogicalExpr::Alias {
+                            expr: Box::new(logical_expr),
+                            alias: alias_name,
+                        });
+                    } else {
+                        proj_exprs.push(logical_expr);
+                    }
+                }
+                ast::ResultColumn::Star => {
+                    // Expand * to all columns
+                    for (name, typ) in &input_schema.columns {
+                        proj_exprs.push(LogicalExpr::Column(Column::new(name.clone())));
+                        schema_columns.push((name.clone(), *typ));
+                    }
+                }
+                ast::ResultColumn::TableStar(table) => {
+                    // Expand table.* to all columns from that table
+                    let table_name = Self::name_to_string(table);
+                    for (name, typ) in &input_schema.columns {
+                        // Simple check - would need proper table tracking in a real implementation
+                        proj_exprs.push(LogicalExpr::Column(Column::with_table(
+                            name.clone(),
+                            table_name.clone(),
+                        )));
+                        schema_columns.push((name.clone(), *typ));
+                    }
+                }
+            }
+        }
+
+        Ok(LogicalPlan::Projection(Projection {
+            input: Arc::new(input),
+            exprs: proj_exprs,
+            schema: Arc::new(LogicalSchema::new(schema_columns)),
+        }))
+    }
+
+    // Helper function to preprocess aggregate expressions that contain complex arguments.
+    // Returns: (needs_pre_projection, pre_projection_exprs, pre_projection_schema, modified_aggr_exprs)
+    //
+    // This is used for expressions like `SELECT sum(hex(a + 2)) FROM tbl`, where `hex(a + 2)`
+    // becomes a pre-projection.
+    //
+    // An alternative would be to always generate a projection together with an aggregation,
+    // with "a" as the identity projection in the simple case, but that would be quite wasteful.
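+    //
+    // For example, for `SELECT sum(a + 2) FROM t GROUP BY b`, the argument `a + 2`
+    // is not a plain column, so this returns needs_pre_projection = true, projects
+    // `a + 2` under the generated name `__agg_arg_proj_0`, and rewrites the
+    // aggregate to `sum(__agg_arg_proj_0)`.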
+    fn preprocess_aggregate_expressions(
+        aggr_exprs: &[LogicalExpr],
+        group_exprs: &[LogicalExpr],
+        input_schema: &SchemaRef,
+    ) -> Result<PreprocessAggregateResult> {
+        let mut needs_pre_projection = false;
+        let mut pre_projection_exprs = Vec::new();
+        let mut pre_projection_schema = Vec::new();
+        let mut modified_aggr_exprs = Vec::new();
+        let mut projected_col_counter = 0;
+
+        // First, add all group by expressions to the pre-projection
+        for expr in group_exprs {
+            if let LogicalExpr::Column(col) = expr {
+                pre_projection_exprs.push(expr.clone());
+                let col_type = Self::infer_expr_type(expr, input_schema)?;
+                pre_projection_schema.push((col.name.clone(), col_type));
+            } else {
+                // Complex group by expression - project it
+                needs_pre_projection = true;
+                let proj_col_name = format!("__group_proj_{projected_col_counter}");
+                projected_col_counter += 1;
+                pre_projection_exprs.push(expr.clone());
+                let col_type = Self::infer_expr_type(expr, input_schema)?;
+                pre_projection_schema.push((proj_col_name.clone(), col_type));
+            }
+        }
+
+        // Check each aggregate expression
+        for agg_expr in aggr_exprs {
+            if let LogicalExpr::AggregateFunction {
+                fun,
+                args,
+                distinct,
+            } = agg_expr
+            {
+                let mut modified_args = Vec::new();
+                for arg in args {
+                    // Check if the argument is a simple column reference or a complex expression
+                    match arg {
+                        LogicalExpr::Column(_) => {
+                            // Simple column - just use it
+                            modified_args.push(arg.clone());
+                            // Make sure the column is in the pre-projection
+                            if !pre_projection_exprs.iter().any(|e| e == arg) {
+                                pre_projection_exprs.push(arg.clone());
+                                let col_type = Self::infer_expr_type(arg, input_schema)?;
+                                if let LogicalExpr::Column(col) = arg {
+                                    pre_projection_schema.push((col.name.clone(), col_type));
+                                }
+                            }
+                        }
+                        _ => {
+                            // Complex expression - we need to project it first
+                            needs_pre_projection = true;
format!("__agg_arg_proj_{projected_col_counter}"); + projected_col_counter += 1; + + // Add the expression to the pre-projection + pre_projection_exprs.push(arg.clone()); + let col_type = Self::infer_expr_type(arg, input_schema)?; + pre_projection_schema.push((proj_col_name.clone(), col_type)); + + // In the aggregate, reference the projected column + modified_args.push(LogicalExpr::Column(Column::new(proj_col_name))); + } + } + } + + // Create the modified aggregate expression + modified_aggr_exprs.push(LogicalExpr::AggregateFunction { + fun: fun.clone(), + args: modified_args, + distinct: *distinct, + }); + } else { + modified_aggr_exprs.push(agg_expr.clone()); + } + } + + Ok(( + needs_pre_projection, + pre_projection_exprs, + pre_projection_schema, + modified_aggr_exprs, + )) + } + + // Build aggregate with GROUP BY + fn build_aggregate( + &mut self, + input: LogicalPlan, + group_by: &ast::GroupBy, + columns: &[ast::ResultColumn], + ) -> Result { + let input_schema = input.schema(); + + // Build grouping expressions + let mut group_exprs = Vec::new(); + for expr in &group_by.exprs { + group_exprs.push(self.build_expr(expr, input_schema)?); + } + + // Use the unified aggregate builder + self.build_aggregate_internal(input, group_exprs, columns) + } + + // Build aggregate without GROUP BY + fn build_aggregate_no_group( + &mut self, + input: LogicalPlan, + columns: &[ast::ResultColumn], + ) -> Result { + // Use the unified aggregate builder with empty group expressions + self.build_aggregate_internal(input, vec![], columns) + } + + // Unified internal aggregate builder that handles both GROUP BY and non-GROUP BY cases + fn build_aggregate_internal( + &mut self, + input: LogicalPlan, + group_exprs: Vec, + columns: &[ast::ResultColumn], + ) -> Result { + let input_schema = input.schema(); + let has_group_by = !group_exprs.is_empty(); + + // Build aggregate expressions and projection expressions + let mut aggr_exprs = Vec::new(); + let mut projection_exprs = Vec::new(); + let mut aggregate_schema_columns = Vec::new(); + + // First, add GROUP BY columns to the aggregate output schema + // These are always part of the aggregate operator's output + for group_expr in &group_exprs { + let col_name = match group_expr { + LogicalExpr::Column(col) => col.name.clone(), + _ => { + // For complex GROUP BY expressions, generate a name + format!("__group_{}", aggregate_schema_columns.len()) + } + }; + let col_type = Self::infer_expr_type(group_expr, input_schema)?; + aggregate_schema_columns.push((col_name, col_type)); + } + + // Track aggregates we've already seen to avoid duplicates + let mut aggregate_map: std::collections::HashMap = + std::collections::HashMap::new(); + + for col in columns { + match col { + ast::ResultColumn::Expr(expr, alias) => { + let logical_expr = self.build_expr(expr, input_schema)?; + + // Determine the column name for this expression + let col_name = match alias { + Some(as_alias) => match as_alias { + ast::As::As(name) | ast::As::Elided(name) => Self::name_to_string(name), + }, + None => Self::expr_to_column_name(expr), + }; + + // Check if the TOP-LEVEL expression is an aggregate + // We only care about immediate aggregates, not nested ones + if Self::is_aggregate_expr(&logical_expr) { + // Pure aggregate function - check if we've seen it before + let agg_key = format!("{logical_expr:?}"); + + let agg_col_name = if let Some(existing_name) = aggregate_map.get(&agg_key) + { + // Reuse existing aggregate + existing_name.clone() + } else { + // New aggregate - add it + let 
col_type = Self::infer_expr_type(&logical_expr, input_schema)?; + aggregate_schema_columns.push((col_name.clone(), col_type)); + aggr_exprs.push(logical_expr); + aggregate_map.insert(agg_key, col_name.clone()); + col_name.clone() + }; + + // In the projection, reference this aggregate by name + projection_exprs.push(LogicalExpr::Column(Column { + name: agg_col_name, + table: None, + })); + } else if Self::contains_aggregate(&logical_expr) { + // This is an expression that contains an aggregate somewhere + // (e.g., sum(a + 2) * 2) + // We need to extract aggregates and replace them with column references + let (processed_expr, extracted_aggs) = + Self::extract_and_replace_aggregates_with_dedup( + logical_expr, + &mut aggregate_map, + )?; + + // Add only new aggregates + for (agg_expr, agg_name) in extracted_aggs { + let agg_type = Self::infer_expr_type(&agg_expr, input_schema)?; + aggregate_schema_columns.push((agg_name, agg_type)); + aggr_exprs.push(agg_expr); + } + + // Add the processed expression (with column refs) to projection + projection_exprs.push(processed_expr); + } else { + // Non-aggregate expression - validation depends on GROUP BY presence + if has_group_by { + // With GROUP BY: only allow constants and grouped columns + // TODO: SQLite actually allows any column here and returns the value from + // the first row encountered in each group. We should support this in the + // future for full SQLite compatibility, but for now we're stricter to + // simplify the DBSP compilation. + if !Self::is_constant_expr(&logical_expr) + && !Self::is_valid_in_group_by(&logical_expr, &group_exprs) + { + return Err(LimboError::ParseError(format!( + "Column '{col_name}' must appear in the GROUP BY clause or be used in an aggregate function" + ))); + } + } else { + // Without GROUP BY: only allow constant expressions + // TODO: SQLite allows any column here and returns a value from an + // arbitrary row. We should support this for full compatibility, + // but for now we're stricter to simplify DBSP compilation. 
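+                            // For example, `SELECT COUNT(*), 1 FROM users` is
+                            // accepted here, while `SELECT name, COUNT(*) FROM users`
+                            // is rejected even though SQLite would accept it.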
+ if !Self::is_constant_expr(&logical_expr) { + return Err(LimboError::ParseError(format!( + "Column '{col_name}' must be used in an aggregate function when using aggregates without GROUP BY" + ))); + } + } + projection_exprs.push(logical_expr); + } + } + _ => { + let error_msg = if has_group_by { + "* not supported with GROUP BY".to_string() + } else { + "* not supported with aggregate functions".to_string() + }; + return Err(LimboError::ParseError(error_msg)); + } + } + } + + // Check if any aggregate functions have complex expressions as arguments + // If so, we need to insert a projection before the aggregate + let ( + needs_pre_projection, + pre_projection_exprs, + pre_projection_schema, + modified_aggr_exprs, + ) = Self::preprocess_aggregate_expressions(&aggr_exprs, &group_exprs, input_schema)?; + + // Build the final schema for the projection + let mut projection_schema_columns = Vec::new(); + for (i, expr) in projection_exprs.iter().enumerate() { + let col_name = if i < columns.len() { + match &columns[i] { + ast::ResultColumn::Expr(e, alias) => match alias { + Some(as_alias) => match as_alias { + ast::As::As(name) | ast::As::Elided(name) => Self::name_to_string(name), + }, + None => Self::expr_to_column_name(e), + }, + _ => format!("col_{i}"), + } + } else { + format!("col_{i}") + }; + + // For type inference, we need the aggregate schema for column references + let aggregate_schema = LogicalSchema::new(aggregate_schema_columns.clone()); + let col_type = Self::infer_expr_type(expr, &Arc::new(aggregate_schema))?; + projection_schema_columns.push((col_name, col_type)); + } + + // Create the input plan (with pre-projection if needed) + let aggregate_input = if needs_pre_projection { + Arc::new(LogicalPlan::Projection(Projection { + input: Arc::new(input), + exprs: pre_projection_exprs, + schema: Arc::new(LogicalSchema::new(pre_projection_schema)), + })) + } else { + Arc::new(input) + }; + + // Use modified aggregate expressions if we inserted a pre-projection + let final_aggr_exprs = if needs_pre_projection { + modified_aggr_exprs + } else { + aggr_exprs + }; + + // Check if we need the outer projection + // We need a projection if: + // 1. Any expression is more complex than a simple column reference (e.g., abs(sum(id))) + // 2. We're selecting a different set of columns than what the aggregate outputs + // 3. 
+        //    Columns are renamed or reordered
+        let needs_outer_projection = {
+            // Check if any expression is more complex than a simple column reference
+            let has_complex_exprs = projection_exprs
+                .iter()
+                .any(|expr| !matches!(expr, LogicalExpr::Column(_)));
+
+            if has_complex_exprs {
+                true
+            } else {
+                // All are simple columns - check if we're selecting exactly what the aggregate outputs.
+                // The projection might be selecting a subset (e.g., only aggregates without group columns),
+                // reordering columns, or using different names.
+
+                // For now, keep it simple: if the schemas don't match exactly, we need the projection.
+                // This handles all cases: subset selection, reordering, renaming
+                projection_schema_columns != aggregate_schema_columns
+            }
+        };
+
+        // Create the aggregate node
+        let aggregate_plan = LogicalPlan::Aggregate(Aggregate {
+            input: aggregate_input,
+            group_expr: group_exprs,
+            aggr_expr: final_aggr_exprs,
+            schema: Arc::new(LogicalSchema::new(aggregate_schema_columns)),
+        });
+
+        if needs_outer_projection {
+            Ok(LogicalPlan::Projection(Projection {
+                input: Arc::new(aggregate_plan),
+                exprs: projection_exprs,
+                schema: Arc::new(LogicalSchema::new(projection_schema_columns)),
+            }))
+        } else {
+            // No projection needed - the aggregate output is exactly what we want
+            Ok(aggregate_plan)
+        }
+    }
+
+    /// Build VALUES clause
+    #[allow(clippy::vec_box)]
+    fn build_values(&mut self, values: &[Vec<Box<ast::Expr>>]) -> Result<LogicalPlan> {
+        if values.is_empty() {
+            return Err(LimboError::ParseError("Empty VALUES clause".to_string()));
+        }
+
+        let mut rows = Vec::new();
+        let first_row_len = values[0].len();
+
+        // Infer schema from first row
+        let mut schema_columns = Vec::new();
+        for (i, _) in values[0].iter().enumerate() {
+            schema_columns.push((format!("column{}", i + 1), Type::Text));
+        }
+
+        for row in values {
+            if row.len() != first_row_len {
+                return Err(LimboError::ParseError(
+                    "All rows in VALUES must have the same number of columns".to_string(),
+                ));
+            }
+
+            let mut logical_row = Vec::new();
+            for expr in row {
+                // VALUES doesn't have an input schema
+                let empty_schema = Arc::new(LogicalSchema::empty());
+                logical_row.push(self.build_expr(expr, &empty_schema)?);
+            }
+            rows.push(logical_row);
+        }
+
+        Ok(LogicalPlan::Values(Values {
+            rows,
+            schema: Arc::new(LogicalSchema::new(schema_columns)),
+        }))
+    }
+
+    // Build SORT
+    fn build_sort(
+        &mut self,
+        input: LogicalPlan,
+        exprs: &[ast::SortedColumn],
+    ) -> Result<LogicalPlan> {
+        let input_schema = input.schema();
+        let mut sort_exprs = Vec::new();
+
+        for sorted_col in exprs {
+            let expr = self.build_expr(&sorted_col.expr, input_schema)?;
+            sort_exprs.push(SortExpr {
+                expr,
+                asc: sorted_col.order != Some(ast::SortOrder::Desc),
+                nulls_first: sorted_col.nulls == Some(ast::NullsOrder::First),
+            });
+        }
+
+        Ok(LogicalPlan::Sort(Sort {
+            input: Arc::new(input),
+            exprs: sort_exprs,
+        }))
+    }
+
+    // Build LIMIT
+    fn build_limit(input: LogicalPlan, limit: &ast::Limit) -> Result<LogicalPlan> {
+        let fetch = match limit.expr.as_ref() {
+            ast::Expr::Literal(ast::Literal::Numeric(s)) => s.parse::<usize>().ok(),
+            _ => {
+                return Err(LimboError::ParseError(
+                    "LIMIT must be a literal integer".to_string(),
+                ))
+            }
+        };
+
+        let skip = if let Some(offset) = &limit.offset {
+            match offset.as_ref() {
+                ast::Expr::Literal(ast::Literal::Numeric(s)) => s.parse::<usize>().ok(),
+                _ => {
+                    return Err(LimboError::ParseError(
+                        "OFFSET must be a literal integer".to_string(),
+                    ))
+                }
+            }
+        } else {
+            None
+        };
+
+        Ok(LogicalPlan::Limit(Limit {
+            input: Arc::new(input),
+            skip,
+            fetch,
+        }))
+    }
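+
+    // e.g. `LIMIT 10 OFFSET 5` yields fetch = Some(10) and skip = Some(5)
+    // (see test_limit_offset below).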
+    // Build compound operator (UNION, INTERSECT, EXCEPT)
+    fn build_compound(
+        left: LogicalPlan,
+        right: LogicalPlan,
+        op: &ast::CompoundOperator,
+    ) -> Result<LogicalPlan> {
+        // Check schema compatibility
+        if left.schema().column_count() != right.schema().column_count() {
+            return Err(LimboError::ParseError(
+                "UNION/INTERSECT/EXCEPT requires same number of columns".to_string(),
+            ));
+        }
+
+        let all = matches!(op, ast::CompoundOperator::UnionAll);
+
+        match op {
+            ast::CompoundOperator::Union | ast::CompoundOperator::UnionAll => {
+                let schema = left.schema().clone();
+                Ok(LogicalPlan::Union(Union {
+                    inputs: vec![Arc::new(left), Arc::new(right)],
+                    all,
+                    schema,
+                }))
+            }
+            _ => Err(LimboError::ParseError(
+                "INTERSECT and EXCEPT not yet supported in logical plans".to_string(),
+            )),
+        }
+    }
+
+    // Build expression from AST
+    fn build_expr(&mut self, expr: &ast::Expr, _schema: &SchemaRef) -> Result<LogicalExpr> {
+        match expr {
+            ast::Expr::Id(name) => Ok(LogicalExpr::Column(Column::new(Self::name_to_string(name)))),
+
+            ast::Expr::DoublyQualified(db, table, col) => {
+                Ok(LogicalExpr::Column(Column::with_table(
+                    Self::name_to_string(col),
+                    format!(
+                        "{}.{}",
+                        Self::name_to_string(db),
+                        Self::name_to_string(table)
+                    ),
+                )))
+            }
+
+            ast::Expr::Qualified(table, col) => Ok(LogicalExpr::Column(Column::with_table(
+                Self::name_to_string(col),
+                Self::name_to_string(table),
+            ))),
+
+            ast::Expr::Literal(lit) => Ok(LogicalExpr::Literal(Self::build_literal(lit)?)),
+
+            ast::Expr::Binary(lhs, op, rhs) => {
+                // Special case: IS NULL and IS NOT NULL
+                if matches!(op, ast::Operator::Is | ast::Operator::IsNot) {
+                    if let ast::Expr::Literal(ast::Literal::Null) = rhs.as_ref() {
+                        let expr = Box::new(self.build_expr(lhs, _schema)?);
+                        return Ok(LogicalExpr::IsNull {
+                            expr,
+                            negated: matches!(op, ast::Operator::IsNot),
+                        });
+                    }
+                }
+
+                let left = Box::new(self.build_expr(lhs, _schema)?);
+                let right = Box::new(self.build_expr(rhs, _schema)?);
+                Ok(LogicalExpr::BinaryExpr {
+                    left,
+                    op: *op,
+                    right,
+                })
+            }
+
+            ast::Expr::Unary(op, expr) => {
+                let inner = Box::new(self.build_expr(expr, _schema)?);
+                Ok(LogicalExpr::UnaryExpr {
+                    op: *op,
+                    expr: inner,
+                })
+            }
+
+            ast::Expr::FunctionCall {
+                name,
+                distinctness,
+                args,
+                filter_over,
+                ..
+            } => {
+                // Check for window functions (OVER clause)
+                if filter_over.over_clause.is_some() {
+                    return Err(LimboError::ParseError(
+                        "Unsupported expression type: window functions are not yet supported"
+                            .to_string(),
+                    ));
+                }
+
+                let func_name = Self::name_to_string(name);
+                let arg_count = args.len();
+                // Check if it's an aggregate function (considering argument count for min/max)
+                if let Some(agg_fun) = Self::parse_aggregate_function(&func_name, arg_count) {
+                    let distinct = distinctness.is_some();
+                    let arg_exprs = args
+                        .iter()
+                        .map(|e| self.build_expr(e, _schema))
+                        .collect::<Result<Vec<_>>>()?;
+                    Ok(LogicalExpr::AggregateFunction {
+                        fun: agg_fun,
+                        args: arg_exprs,
+                        distinct,
+                    })
+                } else {
+                    // Regular scalar function
+                    let arg_exprs = args
+                        .iter()
+                        .map(|e| self.build_expr(e, _schema))
+                        .collect::<Result<Vec<_>>>()?;
+                    Ok(LogicalExpr::ScalarFunction {
+                        fun: func_name,
+                        args: arg_exprs,
+                    })
+                }
+            }
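+
+            // e.g. `COUNT(*)` arrives as FunctionCallStar and maps to
+            // AggFunc::Count with no arguments.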
+            ast::Expr::FunctionCallStar { name, .. } => {
+                // Handle COUNT(*) and similar
+                let func_name = Self::name_to_string(name);
+                // FunctionCallStar always has 0 args (it's the * form)
+                if let Some(agg_fun) = Self::parse_aggregate_function(&func_name, 0) {
+                    Ok(LogicalExpr::AggregateFunction {
+                        fun: agg_fun,
+                        args: vec![],
+                        distinct: false,
+                    })
+                } else {
+                    Err(LimboError::ParseError(format!(
+                        "Function {func_name}(*) is not supported"
+                    )))
+                }
+            }
+
+            ast::Expr::Case {
+                base,
+                when_then_pairs,
+                else_expr,
+            } => {
+                let case_expr = if let Some(e) = base {
+                    Some(Box::new(self.build_expr(e, _schema)?))
+                } else {
+                    None
+                };
+
+                let when_then_exprs = when_then_pairs
+                    .iter()
+                    .map(|(when, then)| {
+                        Ok((
+                            self.build_expr(when, _schema)?,
+                            self.build_expr(then, _schema)?,
+                        ))
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+
+                let else_result = if let Some(e) = else_expr {
+                    Some(Box::new(self.build_expr(e, _schema)?))
+                } else {
+                    None
+                };
+
+                Ok(LogicalExpr::Case {
+                    expr: case_expr,
+                    when_then: when_then_exprs,
+                    else_expr: else_result,
+                })
+            }
+
+            ast::Expr::InList { lhs, not, rhs } => {
+                let expr = Box::new(self.build_expr(lhs, _schema)?);
+                let list = rhs
+                    .iter()
+                    .map(|e| self.build_expr(e, _schema))
+                    .collect::<Result<Vec<_>>>()?;
+                Ok(LogicalExpr::InList {
+                    expr,
+                    list,
+                    negated: *not,
+                })
+            }
+
+            ast::Expr::InSelect { lhs, not, rhs } => {
+                let expr = Box::new(self.build_expr(lhs, _schema)?);
+                let subquery = Arc::new(self.build_select(rhs)?);
+                Ok(LogicalExpr::InSubquery {
+                    expr,
+                    subquery,
+                    negated: *not,
+                })
+            }
+
+            ast::Expr::Exists(select) => {
+                let subquery = Arc::new(self.build_select(select)?);
+                Ok(LogicalExpr::Exists {
+                    subquery,
+                    negated: false,
+                })
+            }
+
+            ast::Expr::Subquery(select) => {
+                let subquery = Arc::new(self.build_select(select)?);
+                Ok(LogicalExpr::ScalarSubquery(subquery))
+            }
+
+            ast::Expr::IsNull(lhs) => {
+                let expr = Box::new(self.build_expr(lhs, _schema)?);
+                Ok(LogicalExpr::IsNull {
+                    expr,
+                    negated: false,
+                })
+            }
+
+            ast::Expr::NotNull(lhs) => {
+                let expr = Box::new(self.build_expr(lhs, _schema)?);
+                Ok(LogicalExpr::IsNull {
+                    expr,
+                    negated: true,
+                })
+            }
+
+            ast::Expr::Between {
+                lhs,
+                not,
+                start,
+                end,
+            } => {
+                let expr = Box::new(self.build_expr(lhs, _schema)?);
+                let low = Box::new(self.build_expr(start, _schema)?);
+                let high = Box::new(self.build_expr(end, _schema)?);
+                Ok(LogicalExpr::Between {
+                    expr,
+                    low,
+                    high,
+                    negated: *not,
+                })
+            }
+
+            ast::Expr::Like {
+                lhs,
+                not,
+                op: _,
+                rhs,
+                escape,
+            } => {
+                let expr = Box::new(self.build_expr(lhs, _schema)?);
+                let pattern = Box::new(self.build_expr(rhs, _schema)?);
+                let escape_char = escape.as_ref().and_then(|e| {
+                    if let ast::Expr::Literal(ast::Literal::String(s)) = e.as_ref() {
+                        s.chars().next()
+                    } else {
+                        None
+                    }
+                });
+                Ok(LogicalExpr::Like {
+                    expr,
+                    pattern,
+                    escape: escape_char,
+                    negated: *not,
+                })
+            }
+
+            ast::Expr::Parenthesized(exprs) => {
+                // The parser guarantees at least one expression inside the
+                // parentheses; the assert below documents that assumption.
+                assert!(!exprs.is_empty());
+                // Multiple expressions in parentheses are unusual, but handle them
+                // by building the first one (SQLite behavior)
+                self.build_expr(&exprs[0], _schema)
+            }
+
+            _ => Err(LimboError::ParseError(format!(
+                "Unsupported expression type in logical plan: {expr:?}"
+            ))),
+        }
+    }
+
+    /// Build literal value
+    fn build_literal(lit: &ast::Literal) -> Result<Value> {
+        match lit {
+            ast::Literal::Null => Ok(Value::Null),
+            ast::Literal::Keyword(k) if k.eq_ignore_ascii_case("true") => Ok(Value::Integer(1)), // SQLite uses int for bool
+            ast::Literal::Keyword(k) if k.eq_ignore_ascii_case("false") => Ok(Value::Integer(0)), // SQLite uses int for bool
+            ast::Literal::Keyword(k) => Ok(Value::Text(k.clone().into())),
+            ast::Literal::Numeric(s) => {
+                if let Ok(i) = s.parse::<i64>() {
+                    Ok(Value::Integer(i))
+                } else if let Ok(f) = s.parse::<f64>() {
+                    Ok(Value::Float(f))
+                } else {
+                    Ok(Value::Text(s.clone().into()))
+                }
+            }
+            ast::Literal::String(s) => {
+                // Strip surrounding quotes from the SQL literal
+                // (the parser includes the quotes in the string value)
+                let unquoted = if s.starts_with('\'') && s.ends_with('\'') && s.len() > 1 {
+                    &s[1..s.len() - 1]
+                } else {
+                    s.as_str()
+                };
+                Ok(Value::Text(unquoted.to_string().into()))
+            }
+            ast::Literal::Blob(b) => Ok(Value::Blob(b.clone().into())),
+            ast::Literal::CurrentDate
+            | ast::Literal::CurrentTime
+            | ast::Literal::CurrentTimestamp => Err(LimboError::ParseError(
+                "Temporal literals not yet supported".to_string(),
+            )),
+        }
+    }
+
+    /// Parse an aggregate function name (considering argument count for min/max)
+    fn parse_aggregate_function(name: &str, arg_count: usize) -> Option<AggFunc> {
+        match name.to_uppercase().as_str() {
+            "COUNT" => Some(AggFunc::Count),
+            "SUM" => Some(AggFunc::Sum),
+            "AVG" => Some(AggFunc::Avg),
+            // MIN and MAX are only aggregates with 1 argument
+            // With 2+ arguments, they're scalar functions
+            "MIN" if arg_count == 1 => Some(AggFunc::Min),
+            "MAX" if arg_count == 1 => Some(AggFunc::Max),
+            "GROUP_CONCAT" => Some(AggFunc::GroupConcat),
+            "STRING_AGG" => Some(AggFunc::StringAgg),
+            "TOTAL" => Some(AggFunc::Total),
+            _ => None,
+        }
+    }
+
+    // Check if any result column contains aggregates
+    fn has_aggregates(columns: &[ast::ResultColumn]) -> bool {
+        for col in columns {
+            if let ast::ResultColumn::Expr(expr, _) = col {
+                if Self::expr_has_aggregate(expr) {
+                    return true;
+                }
+            }
+        }
+        false
+    }
+
+    // Check if an AST expression contains aggregates
+    fn expr_has_aggregate(expr: &ast::Expr) -> bool {
+        match expr {
+            ast::Expr::FunctionCall { name, args, .. } => {
+                // Check if the function itself is an aggregate (considering arg count for min/max)
+                let arg_count = args.len();
+                if Self::parse_aggregate_function(&Self::name_to_string(name), arg_count).is_some()
+                {
+                    return true;
+                }
+                // Also check if any arguments contain aggregates (for nested functions like HEX(SUM(...)))
+                args.iter().any(|arg| Self::expr_has_aggregate(arg))
+            }
+            ast::Expr::FunctionCallStar { name, .. } => {
+                // FunctionCallStar always has 0 args
+                Self::parse_aggregate_function(&Self::name_to_string(name), 0).is_some()
+            }
+            ast::Expr::Binary(lhs, _, rhs) => {
+                Self::expr_has_aggregate(lhs) || Self::expr_has_aggregate(rhs)
+            }
+            ast::Expr::Unary(_, e) => Self::expr_has_aggregate(e),
+            ast::Expr::Case {
+                when_then_pairs,
+                else_expr,
+                ..
+ } => { + when_then_pairs + .iter() + .any(|(w, t)| Self::expr_has_aggregate(w) || Self::expr_has_aggregate(t)) + || else_expr + .as_ref() + .is_some_and(|e| Self::expr_has_aggregate(e)) + } + ast::Expr::Parenthesized(exprs) => { + // Check if any parenthesized expression contains an aggregate + exprs.iter().any(|e| Self::expr_has_aggregate(e)) + } + _ => false, + } + } + + // Check if logical expression is an aggregate + fn is_aggregate_expr(expr: &LogicalExpr) -> bool { + match expr { + LogicalExpr::AggregateFunction { .. } => true, + LogicalExpr::Alias { expr, .. } => Self::is_aggregate_expr(expr), + _ => false, + } + } + + // Check if logical expression contains an aggregate anywhere + fn contains_aggregate(expr: &LogicalExpr) -> bool { + match expr { + LogicalExpr::AggregateFunction { .. } => true, + LogicalExpr::Alias { expr, .. } => Self::contains_aggregate(expr), + LogicalExpr::BinaryExpr { left, right, .. } => { + Self::contains_aggregate(left) || Self::contains_aggregate(right) + } + LogicalExpr::UnaryExpr { expr, .. } => Self::contains_aggregate(expr), + LogicalExpr::ScalarFunction { args, .. } => args.iter().any(Self::contains_aggregate), + LogicalExpr::Case { + when_then, + else_expr, + .. + } => { + when_then + .iter() + .any(|(w, t)| Self::contains_aggregate(w) || Self::contains_aggregate(t)) + || else_expr + .as_ref() + .is_some_and(|e| Self::contains_aggregate(e)) + } + _ => false, + } + } + + // Check if an expression is a constant (contains only literals) + fn is_constant_expr(expr: &LogicalExpr) -> bool { + match expr { + LogicalExpr::Literal(_) => true, + LogicalExpr::BinaryExpr { left, right, .. } => { + Self::is_constant_expr(left) && Self::is_constant_expr(right) + } + LogicalExpr::UnaryExpr { expr, .. } => Self::is_constant_expr(expr), + LogicalExpr::ScalarFunction { args, .. } => args.iter().all(Self::is_constant_expr), + LogicalExpr::Alias { expr, .. } => Self::is_constant_expr(expr), + _ => false, + } + } + + // Check if an expression is valid in GROUP BY context + // An expression is valid if it's: + // 1. A constant literal + // 2. An aggregate function + // 3. A grouping column (or expression involving only grouping columns) + fn is_valid_in_group_by(expr: &LogicalExpr, group_exprs: &[LogicalExpr]) -> bool { + match expr { + LogicalExpr::Literal(_) => true, // Constants are always valid + LogicalExpr::AggregateFunction { .. } => true, // Aggregates are valid + LogicalExpr::Column(col) => { + // Check if this column is in the GROUP BY + group_exprs.iter().any(|g| match g { + LogicalExpr::Column(gcol) => gcol.name == col.name, + _ => false, + }) + } + LogicalExpr::BinaryExpr { left, right, .. } => { + // Both sides must be valid + Self::is_valid_in_group_by(left, group_exprs) + && Self::is_valid_in_group_by(right, group_exprs) + } + LogicalExpr::UnaryExpr { expr, .. } => Self::is_valid_in_group_by(expr, group_exprs), + LogicalExpr::ScalarFunction { args, .. } => { + // All arguments must be valid + args.iter() + .all(|arg| Self::is_valid_in_group_by(arg, group_exprs)) + } + LogicalExpr::Alias { expr, .. 
} => Self::is_valid_in_group_by(expr, group_exprs),
+            _ => false, // Other expressions are not valid
+        }
+    }
+
+    // Extract aggregates from an expression and replace them with column references,
+    // with deduplication.
+    // Returns the modified expression and a list of NEW (aggregate_expr, column_name) pairs
+    fn extract_and_replace_aggregates_with_dedup(
+        expr: LogicalExpr,
+        aggregate_map: &mut std::collections::HashMap<String, String>,
+    ) -> Result<(LogicalExpr, Vec<(LogicalExpr, String)>)> {
+        let mut new_aggregates = Vec::new();
+        let mut counter = aggregate_map.len();
+        let new_expr = Self::replace_aggregates_with_columns_dedup(
+            expr,
+            &mut new_aggregates,
+            aggregate_map,
+            &mut counter,
+        )?;
+        Ok((new_expr, new_aggregates))
+    }
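+
+    // For example, `sum(id) * 2 + count(*)` becomes `__agg_N * 2 + __agg_M`, and a
+    // repeated aggregate such as `sum(id) + sum(id)` is extracted only once and
+    // referenced twice.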
+
+    // Recursively replace aggregate functions with column references, with deduplication
+    fn replace_aggregates_with_columns_dedup(
+        expr: LogicalExpr,
+        new_aggregates: &mut Vec<(LogicalExpr, String)>,
+        aggregate_map: &mut std::collections::HashMap<String, String>,
+        counter: &mut usize,
+    ) -> Result<LogicalExpr> {
+        match expr {
+            LogicalExpr::AggregateFunction { .. } => {
+                // Found an aggregate - check if we've seen it before
+                let agg_key = format!("{expr:?}");
+
+                let col_name = if let Some(existing_name) = aggregate_map.get(&agg_key) {
+                    // Reuse existing aggregate
+                    existing_name.clone()
+                } else {
+                    // New aggregate
+                    let col_name = format!("__agg_{}", *counter);
+                    *counter += 1;
+                    aggregate_map.insert(agg_key, col_name.clone());
+                    new_aggregates.push((expr, col_name.clone()));
+                    col_name
+                };
+
+                Ok(LogicalExpr::Column(Column {
+                    name: col_name,
+                    table: None,
+                }))
+            }
+            LogicalExpr::BinaryExpr { left, op, right } => {
+                let new_left = Self::replace_aggregates_with_columns_dedup(
+                    *left,
+                    new_aggregates,
+                    aggregate_map,
+                    counter,
+                )?;
+                let new_right = Self::replace_aggregates_with_columns_dedup(
+                    *right,
+                    new_aggregates,
+                    aggregate_map,
+                    counter,
+                )?;
+                Ok(LogicalExpr::BinaryExpr {
+                    left: Box::new(new_left),
+                    op,
+                    right: Box::new(new_right),
+                })
+            }
+            LogicalExpr::UnaryExpr { op, expr } => {
+                let new_expr = Self::replace_aggregates_with_columns_dedup(
+                    *expr,
+                    new_aggregates,
+                    aggregate_map,
+                    counter,
+                )?;
+                Ok(LogicalExpr::UnaryExpr {
+                    op,
+                    expr: Box::new(new_expr),
+                })
+            }
+            LogicalExpr::ScalarFunction { fun, args } => {
+                let mut new_args = Vec::new();
+                for arg in args {
+                    new_args.push(Self::replace_aggregates_with_columns_dedup(
+                        arg,
+                        new_aggregates,
+                        aggregate_map,
+                        counter,
+                    )?);
+                }
+                Ok(LogicalExpr::ScalarFunction {
+                    fun,
+                    args: new_args,
+                })
+            }
+            LogicalExpr::Case {
+                expr: case_expr,
+                when_then,
+                else_expr,
+            } => {
+                let new_case_expr = if let Some(e) = case_expr {
+                    Some(Box::new(Self::replace_aggregates_with_columns_dedup(
+                        *e,
+                        new_aggregates,
+                        aggregate_map,
+                        counter,
+                    )?))
+                } else {
+                    None
+                };
+
+                let mut new_when_then = Vec::new();
+                for (when, then) in when_then {
+                    let new_when = Self::replace_aggregates_with_columns_dedup(
+                        when,
+                        new_aggregates,
+                        aggregate_map,
+                        counter,
+                    )?;
+                    let new_then = Self::replace_aggregates_with_columns_dedup(
+                        then,
+                        new_aggregates,
+                        aggregate_map,
+                        counter,
+                    )?;
+                    new_when_then.push((new_when, new_then));
+                }
+
+                let new_else = if let Some(e) = else_expr {
+                    Some(Box::new(Self::replace_aggregates_with_columns_dedup(
+                        *e,
+                        new_aggregates,
+                        aggregate_map,
+                        counter,
+                    )?))
+                } else {
+                    None
+                };
+
+                Ok(LogicalExpr::Case {
+                    expr: new_case_expr,
+                    when_then: new_when_then,
+                    else_expr: new_else,
+                })
+            }
+            LogicalExpr::Alias { expr, alias } => {
+                let new_expr = Self::replace_aggregates_with_columns_dedup(
+                    *expr,
+                    new_aggregates,
+                    aggregate_map,
+                    counter,
+                )?;
+                Ok(LogicalExpr::Alias {
+                    expr: Box::new(new_expr),
+                    alias,
+                })
+            }
+            // Other expressions - keep as is
+            _ => Ok(expr),
+        }
+    }
+
+    // Get column name from expression
+    fn expr_to_column_name(expr: &ast::Expr) -> String {
+        match expr {
+            ast::Expr::Id(name) => Self::name_to_string(name),
+            ast::Expr::Qualified(_, col) => Self::name_to_string(col),
+            ast::Expr::FunctionCall { name, .. } => Self::name_to_string(name),
+            ast::Expr::FunctionCallStar { name, .. } => {
+                format!("{}(*)", Self::name_to_string(name))
+            }
+            _ => "expr".to_string(),
+        }
+    }
+
+    // Get table schema
+    fn get_table_schema(&self, table_name: &str) -> Result<SchemaRef> {
+        // Look up the table in the schema
+        let table = self
+            .schema
+            .get_table(table_name)
+            .ok_or_else(|| LimboError::ParseError(format!("Table '{table_name}' not found")))?;
+
+        let mut columns = Vec::new();
+        for col in table.columns() {
+            if let Some(ref name) = col.name {
+                columns.push((name.clone(), col.ty));
+            }
+        }
+
+        Ok(Arc::new(LogicalSchema::new(columns)))
+    }
+
+    // Infer expression type
+    fn infer_expr_type(expr: &LogicalExpr, schema: &SchemaRef) -> Result<Type> {
+        match expr {
+            LogicalExpr::Column(col) => {
+                if let Some((_, typ)) = schema.find_column(&col.name) {
+                    Ok(*typ)
+                } else {
+                    Ok(Type::Text)
+                }
+            }
+            LogicalExpr::Literal(Value::Integer(_)) => Ok(Type::Integer),
+            LogicalExpr::Literal(Value::Float(_)) => Ok(Type::Real),
+            LogicalExpr::Literal(Value::Text(_)) => Ok(Type::Text),
+            LogicalExpr::Literal(Value::Null) => Ok(Type::Null),
+            LogicalExpr::Literal(Value::Blob(_)) => Ok(Type::Blob),
+            LogicalExpr::BinaryExpr { op, left, right } => {
+                match op {
+                    ast::Operator::Add | ast::Operator::Subtract | ast::Operator::Multiply => {
+                        // Infer types of operands to match SQLite/Numeric behavior
+                        let left_type = Self::infer_expr_type(left, schema)?;
+                        let right_type = Self::infer_expr_type(right, schema)?;
+
+                        // Integer op Integer = Integer (matching core/numeric/mod.rs behavior)
+                        // Any operation with Real = Real
+                        match (left_type, right_type) {
+                            (Type::Integer, Type::Integer) => Ok(Type::Integer),
+                            (Type::Integer, Type::Real)
+                            | (Type::Real, Type::Integer)
+                            | (Type::Real, Type::Real) => Ok(Type::Real),
+                            (Type::Null, _) | (_, Type::Null) => Ok(Type::Null),
+                            // For Text/Blob, SQLite coerces to numeric, defaulting to Real
+                            _ => Ok(Type::Real),
+                        }
+                    }
+                    ast::Operator::Divide => {
+                        // Division always produces Real in SQLite
+                        Ok(Type::Real)
+                    }
+                    ast::Operator::Modulus => {
+                        // Modulus follows the same rules as the other arithmetic ops
+                        let left_type = Self::infer_expr_type(left, schema)?;
+                        let right_type = Self::infer_expr_type(right, schema)?;
+                        match (left_type, right_type) {
+                            (Type::Integer, Type::Integer) => Ok(Type::Integer),
+                            _ => Ok(Type::Real),
+                        }
+                    }
+                    ast::Operator::Equals
+                    | ast::Operator::NotEquals
+                    | ast::Operator::Less
+                    | ast::Operator::LessEquals
+                    | ast::Operator::Greater
+                    | ast::Operator::GreaterEquals
+                    | ast::Operator::And
+                    | ast::Operator::Or
+                    | ast::Operator::Is
+                    | ast::Operator::IsNot => Ok(Type::Integer),
+                    ast::Operator::Concat => Ok(Type::Text),
+                    _ => Ok(Type::Text), // Default for other operators
+                }
+            }
+            LogicalExpr::UnaryExpr { op, expr } => match op {
+                ast::UnaryOperator::Not => Ok(Type::Integer),
+                ast::UnaryOperator::Negative | ast::UnaryOperator::Positive => {
+                    Self::infer_expr_type(expr, schema)
+                }
+                ast::UnaryOperator::BitwiseNot => Ok(Type::Integer),
+            },
+            LogicalExpr::AggregateFunction {
fun, .. } => match fun { + AggFunc::Count | AggFunc::Count0 => Ok(Type::Integer), + AggFunc::Sum | AggFunc::Avg | AggFunc::Total => Ok(Type::Real), + AggFunc::Min | AggFunc::Max => Ok(Type::Text), + AggFunc::GroupConcat | AggFunc::StringAgg => Ok(Type::Text), + #[cfg(feature = "json")] + AggFunc::JsonbGroupArray + | AggFunc::JsonGroupArray + | AggFunc::JsonbGroupObject + | AggFunc::JsonGroupObject => Ok(Type::Text), + AggFunc::External(_) => Ok(Type::Text), // Default for external + }, + LogicalExpr::Alias { expr, .. } => Self::infer_expr_type(expr, schema), + LogicalExpr::IsNull { .. } => Ok(Type::Integer), + LogicalExpr::InList { .. } | LogicalExpr::InSubquery { .. } => Ok(Type::Integer), + LogicalExpr::Exists { .. } => Ok(Type::Integer), + LogicalExpr::Between { .. } => Ok(Type::Integer), + LogicalExpr::Like { .. } => Ok(Type::Integer), + _ => Ok(Type::Text), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::schema::{BTreeTable, Column as SchemaColumn, Schema, Type}; + use turso_parser::parser::Parser; + + fn create_test_schema() -> Schema { + let mut schema = Schema::new(false); + + // Create users table + let users_table = BTreeTable { + name: "users".to_string(), + root_page: 2, + primary_key_columns: vec![("id".to_string(), turso_parser::ast::SortOrder::Asc)], + columns: vec![ + SchemaColumn { + name: Some("id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: true, + is_rowid_alias: true, + notnull: true, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("name".to_string()), + ty: Type::Text, + ty_str: "TEXT".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("age".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("email".to_string()), + ty: Type::Text, + ty_str: "TEXT".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + ], + has_rowid: true, + is_strict: false, + unique_sets: None, + }; + schema.add_btree_table(Arc::new(users_table)); + + // Create orders table + let orders_table = BTreeTable { + name: "orders".to_string(), + root_page: 3, + primary_key_columns: vec![("id".to_string(), turso_parser::ast::SortOrder::Asc)], + columns: vec![ + SchemaColumn { + name: Some("id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: true, + is_rowid_alias: true, + notnull: true, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("user_id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("product".to_string()), + ty: Type::Text, + ty_str: "TEXT".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("amount".to_string()), + ty: Type::Real, + ty_str: "REAL".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + 
collation: None,
+                    hidden: false,
+                },
+            ],
+            has_rowid: true,
+            is_strict: false,
+            unique_sets: None,
+        };
+        schema.add_btree_table(Arc::new(orders_table));
+
+        schema
+    }
+
+    fn parse_and_build(sql: &str, schema: &Schema) -> Result<LogicalPlan> {
+        let mut parser = Parser::new(sql.as_bytes());
+        let cmd = parser
+            .next()
+            .ok_or_else(|| LimboError::ParseError("Empty statement".to_string()))?
+            .map_err(|e| LimboError::ParseError(e.to_string()))?;
+        match cmd {
+            ast::Cmd::Stmt(stmt) => {
+                let mut builder = LogicalPlanBuilder::new(schema);
+                builder.build_statement(&stmt)
+            }
+            _ => Err(LimboError::ParseError(
+                "Only SQL statements are supported".to_string(),
+            )),
+        }
+    }
+
+    #[test]
+    fn test_simple_select() {
+        let schema = create_test_schema();
+        let sql = "SELECT id, name FROM users";
+        let plan = parse_and_build(sql, &schema).unwrap();
+
+        match plan {
+            LogicalPlan::Projection(proj) => {
+                assert_eq!(proj.exprs.len(), 2);
+                assert!(matches!(proj.exprs[0], LogicalExpr::Column(_)));
+                assert!(matches!(proj.exprs[1], LogicalExpr::Column(_)));
+
+                match &*proj.input {
+                    LogicalPlan::TableScan(scan) => {
+                        assert_eq!(scan.table_name, "users");
+                    }
+                    _ => panic!("Expected TableScan"),
+                }
+            }
+            _ => panic!("Expected Projection"),
+        }
+    }
+
+    #[test]
+    fn test_select_with_filter() {
+        let schema = create_test_schema();
+        let sql = "SELECT name FROM users WHERE age > 18";
+        let plan = parse_and_build(sql, &schema).unwrap();
+
+        match plan {
+            LogicalPlan::Projection(proj) => {
+                assert_eq!(proj.exprs.len(), 1);
+
+                match &*proj.input {
+                    LogicalPlan::Filter(filter) => {
+                        assert!(matches!(
+                            filter.predicate,
+                            LogicalExpr::BinaryExpr {
+                                op: ast::Operator::Greater,
+                                ..
+                            }
+                        ));
+
+                        match &*filter.input {
+                            LogicalPlan::TableScan(scan) => {
+                                assert_eq!(scan.table_name, "users");
+                            }
+                            _ => panic!("Expected TableScan"),
+                        }
+                    }
+                    _ => panic!("Expected Filter"),
+                }
+            }
+            _ => panic!("Expected Projection"),
+        }
+    }
+
+    #[test]
+    fn test_aggregate_with_group_by() {
+        let schema = create_test_schema();
+        let sql = "SELECT user_id, SUM(amount) FROM orders GROUP BY user_id";
+        let plan = parse_and_build(sql, &schema).unwrap();
+
+        match plan {
+            LogicalPlan::Aggregate(agg) => {
+                assert_eq!(agg.group_expr.len(), 1);
+                assert_eq!(agg.aggr_expr.len(), 1);
+                assert_eq!(agg.schema.column_count(), 2);
+
+                assert!(matches!(
+                    agg.aggr_expr[0],
+                    LogicalExpr::AggregateFunction {
+                        fun: AggFunc::Sum,
+                        ..
+                    }
+                ));
+
+                match &*agg.input {
+                    LogicalPlan::TableScan(scan) => {
+                        assert_eq!(scan.table_name, "orders");
+                    }
+                    _ => panic!("Expected TableScan"),
+                }
+            }
+            _ => panic!("Expected Aggregate (no projection)"),
+        }
+    }
+
+    #[test]
+    fn test_aggregate_without_group_by() {
+        let schema = create_test_schema();
+        let sql = "SELECT COUNT(*), MAX(age) FROM users";
+        let plan = parse_and_build(sql, &schema).unwrap();
+
+        match plan {
+            LogicalPlan::Aggregate(agg) => {
+                assert_eq!(agg.group_expr.len(), 0);
+                assert_eq!(agg.aggr_expr.len(), 2);
+                assert_eq!(agg.schema.column_count(), 2);
+
+                assert!(matches!(
+                    agg.aggr_expr[0],
+                    LogicalExpr::AggregateFunction {
+                        fun: AggFunc::Count,
+                        ..
+                    }
+                ));
+
+                assert!(matches!(
+                    agg.aggr_expr[1],
+                    LogicalExpr::AggregateFunction {
+                        fun: AggFunc::Max,
+                        ..
+ } + )); + } + _ => panic!("Expected Aggregate (no projection)"), + } + } + + #[test] + fn test_order_by() { + let schema = create_test_schema(); + let sql = "SELECT name FROM users ORDER BY age DESC, name ASC"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Sort(sort) => { + assert_eq!(sort.exprs.len(), 2); + assert!(!sort.exprs[0].asc); // DESC + assert!(sort.exprs[1].asc); // ASC + + match &*sort.input { + LogicalPlan::Projection(_) => {} + _ => panic!("Expected Projection"), + } + } + _ => panic!("Expected Sort"), + } + } + + #[test] + fn test_limit_offset() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users LIMIT 10 OFFSET 5"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Limit(limit) => { + assert_eq!(limit.fetch, Some(10)); + assert_eq!(limit.skip, Some(5)); + } + _ => panic!("Expected Limit"), + } + } + + #[test] + fn test_order_by_with_limit() { + let schema = create_test_schema(); + let sql = "SELECT name FROM users ORDER BY age DESC LIMIT 5"; + let plan = parse_and_build(sql, &schema).unwrap(); + + // Should produce: Limit -> Sort -> Projection -> TableScan + match plan { + LogicalPlan::Limit(limit) => { + assert_eq!(limit.fetch, Some(5)); + assert_eq!(limit.skip, None); + + match &*limit.input { + LogicalPlan::Sort(sort) => { + assert_eq!(sort.exprs.len(), 1); + assert!(!sort.exprs[0].asc); // DESC + + match &*sort.input { + LogicalPlan::Projection(_) => {} + _ => panic!("Expected Projection under Sort"), + } + } + _ => panic!("Expected Sort under Limit"), + } + } + _ => panic!("Expected Limit at top level"), + } + } + + #[test] + fn test_distinct() { + let schema = create_test_schema(); + let sql = "SELECT DISTINCT name FROM users"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Distinct(distinct) => match &*distinct.input { + LogicalPlan::Projection(_) => {} + _ => panic!("Expected Projection"), + }, + _ => panic!("Expected Distinct"), + } + } + + #[test] + fn test_union() { + let schema = create_test_schema(); + let sql = "SELECT id FROM users UNION SELECT user_id FROM orders"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Union(union) => { + assert!(!union.all); + assert_eq!(union.inputs.len(), 2); + } + _ => panic!("Expected Union"), + } + } + + #[test] + fn test_union_all() { + let schema = create_test_schema(); + let sql = "SELECT id FROM users UNION ALL SELECT user_id FROM orders"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Union(union) => { + assert!(union.all); + assert_eq!(union.inputs.len(), 2); + } + _ => panic!("Expected Union"), + } + } + + #[test] + fn test_union_with_order_by() { + let schema = create_test_schema(); + let sql = "SELECT id, name FROM users UNION SELECT user_id, name FROM orders ORDER BY id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Sort(sort) => { + assert_eq!(sort.exprs.len(), 1); + assert!(sort.exprs[0].asc); // Default ASC + + match &*sort.input { + LogicalPlan::Union(union) => { + assert!(!union.all); // UNION (not UNION ALL) + assert_eq!(union.inputs.len(), 2); + } + _ => panic!("Expected Union under Sort"), + } + } + _ => panic!("Expected Sort at top level"), + } + } + + #[test] + fn test_with_cte() { + let schema = create_test_schema(); + let sql = "WITH active_users AS (SELECT * FROM users WHERE age > 18) SELECT name FROM active_users"; + let plan = parse_and_build(sql, 
&schema).unwrap(); + + match plan { + LogicalPlan::WithCTE(with) => { + assert_eq!(with.ctes.len(), 1); + assert!(with.ctes.contains_key("active_users")); + + let cte = &with.ctes["active_users"]; + match &**cte { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Filter(_) => {} + _ => panic!("Expected Filter in CTE"), + }, + _ => panic!("Expected Projection in CTE"), + } + + match &*with.body { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::CTERef(cte_ref) => { + assert_eq!(cte_ref.name, "active_users"); + } + _ => panic!("Expected CTERef"), + }, + _ => panic!("Expected Projection in body"), + } + } + _ => panic!("Expected WithCTE"), + } + } + + #[test] + fn test_case_expression() { + let schema = create_test_schema(); + let sql = "SELECT CASE WHEN age < 18 THEN 'minor' WHEN age < 65 THEN 'adult' ELSE 'senior' END FROM users"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 1); + assert!(matches!(proj.exprs[0], LogicalExpr::Case { .. })); + } + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_in_list() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users WHERE id IN (1, 2, 3)"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Filter(filter) => match &filter.predicate { + LogicalExpr::InList { list, negated, .. } => { + assert!(!negated); + assert_eq!(list.len(), 3); + } + _ => panic!("Expected InList"), + }, + _ => panic!("Expected Filter"), + }, + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_in_subquery() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders)"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Filter(filter) => { + assert!(matches!(filter.predicate, LogicalExpr::InSubquery { .. })); + } + _ => panic!("Expected Filter"), + }, + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_exists_subquery() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users WHERE EXISTS (SELECT 1 FROM orders WHERE orders.user_id = users.id)"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Filter(filter) => { + assert!(matches!(filter.predicate, LogicalExpr::Exists { .. })); + } + _ => panic!("Expected Filter"), + }, + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_between() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users WHERE age BETWEEN 18 AND 65"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Filter(filter) => match &filter.predicate { + LogicalExpr::Between { negated, .. } => { + assert!(!negated); + } + _ => panic!("Expected Between"), + }, + _ => panic!("Expected Filter"), + }, + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_like() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users WHERE name LIKE 'John%'"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Filter(filter) => match &filter.predicate { + LogicalExpr::Like { + negated, escape, .. 
+ } => { + assert!(!negated); + assert!(escape.is_none()); + } + _ => panic!("Expected Like"), + }, + _ => panic!("Expected Filter"), + }, + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_is_null() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users WHERE email IS NULL"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Filter(filter) => match &filter.predicate { + LogicalExpr::IsNull { negated, .. } => { + assert!(!negated); + } + _ => panic!("Expected IsNull, got: {:?}", filter.predicate), + }, + _ => panic!("Expected Filter"), + }, + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_is_not_null() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users WHERE email IS NOT NULL"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Filter(filter) => match &filter.predicate { + LogicalExpr::IsNull { negated, .. } => { + assert!(negated); + } + _ => panic!("Expected IsNull"), + }, + _ => panic!("Expected Filter"), + }, + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_values_clause() { + let schema = create_test_schema(); + let sql = "SELECT * FROM (VALUES (1, 'a'), (2, 'b'))"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Values(values) => { + assert_eq!(values.rows.len(), 2); + assert_eq!(values.rows[0].len(), 2); + } + _ => panic!("Expected Values"), + }, + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_complex_expression_with_aggregation() { + // Test: SELECT sum(id + 2) * 2 FROM orders GROUP BY user_id + let schema = create_test_schema(); + + // Test the complex case: sum((id + 2)) * 2 with parentheses + let sql = "SELECT sum((id + 2)) * 2 FROM orders GROUP BY user_id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 1); + match &proj.exprs[0] { + LogicalExpr::BinaryExpr { left, op, right } => { + assert_eq!(*op, BinaryOperator::Multiply); + assert!(matches!(**left, LogicalExpr::Column(_))); + assert!(matches!(**right, LogicalExpr::Literal(_))); + } + _ => panic!("Expected BinaryExpr in projection"), + } + + match &*proj.input { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.group_expr.len(), 1); + + assert_eq!(agg.aggr_expr.len(), 1); + match &agg.aggr_expr[0] { + LogicalExpr::AggregateFunction { fun, args, .. } => { + assert_eq!(*fun, AggregateFunction::Sum); + assert_eq!(args.len(), 1); + match &args[0] { + LogicalExpr::Column(col) => { + assert!(col.name.starts_with("__agg_arg_proj_")); + } + _ => panic!("Expected Column reference to projected expression in aggregate args, got {:?}", args[0]), + } + } + _ => panic!("Expected AggregateFunction"), + } + + match &*agg.input { + LogicalPlan::Projection(inner_proj) => { + assert!(inner_proj.exprs.len() >= 2); + let has_binary_add = inner_proj.exprs.iter().any(|e| { + matches!( + e, + LogicalExpr::BinaryExpr { + op: BinaryOperator::Add, + .. 
+ } + ) + }); + assert!( + has_binary_add, + "Should have id + 2 expression in inner projection" + ); + } + _ => panic!("Expected Projection as input to Aggregate"), + } + } + _ => panic!("Expected Aggregate under Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_function_on_aggregate_result() { + let schema = create_test_schema(); + + let sql = "SELECT abs(sum(id)) FROM orders GROUP BY user_id"; + let plan = parse_and_build(sql, &schema).unwrap(); + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 1); + match &proj.exprs[0] { + LogicalExpr::ScalarFunction { fun, args } => { + assert_eq!(fun, "abs"); + assert_eq!(args.len(), 1); + assert!(matches!(args[0], LogicalExpr::Column(_))); + } + _ => panic!("Expected ScalarFunction in projection"), + } + } + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_multiple_aggregates_with_arithmetic() { + let schema = create_test_schema(); + + let sql = "SELECT sum(id) * 2 + count(*) FROM orders GROUP BY user_id"; + let plan = parse_and_build(sql, &schema).unwrap(); + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 1); + match &proj.exprs[0] { + LogicalExpr::BinaryExpr { op, .. } => { + assert_eq!(*op, BinaryOperator::Add); + } + _ => panic!("Expected BinaryExpr"), + } + + match &*proj.input { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.aggr_expr.len(), 2); + } + _ => panic!("Expected Aggregate"), + } + } + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_projection_aggregation_projection() { + let schema = create_test_schema(); + + // This tests: projection -> aggregation -> projection + // The inner projection computes (id + 2), then we aggregate sum(), then apply abs() + let sql = "SELECT abs(sum(id + 2)) FROM orders GROUP BY user_id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + // Should produce: Projection(abs) -> Aggregate(sum) -> Projection(id + 2) -> TableScan + match plan { + LogicalPlan::Projection(outer_proj) => { + assert_eq!(outer_proj.exprs.len(), 1); + + // Outer projection should apply abs() function + match &outer_proj.exprs[0] { + LogicalExpr::ScalarFunction { fun, args } => { + assert_eq!(fun, "abs"); + assert_eq!(args.len(), 1); + assert!(matches!(args[0], LogicalExpr::Column(_))); + } + _ => panic!("Expected abs() function in outer projection"), + } + + // Next should be the Aggregate + match &*outer_proj.input { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.group_expr.len(), 1); + assert_eq!(agg.aggr_expr.len(), 1); + + // The aggregate should be summing a column reference + match &agg.aggr_expr[0] { + LogicalExpr::AggregateFunction { fun, args, .. } => { + assert_eq!(*fun, AggregateFunction::Sum); + assert_eq!(args.len(), 1); + + // Should reference the projected column + match &args[0] { + LogicalExpr::Column(col) => { + assert!(col.name.starts_with("__agg_arg_proj_")); + } + _ => panic!("Expected column reference in aggregate"), + } + } + _ => panic!("Expected AggregateFunction"), + } + + // Input to aggregate should be a projection computing id + 2 + match &*agg.input { + LogicalPlan::Projection(inner_proj) => { + // Should have at least the group column and the computed expression + assert!(inner_proj.exprs.len() >= 2); + + // Check for the id + 2 expression + let has_add_expr = inner_proj.exprs.iter().any(|e| { + matches!( + e, + LogicalExpr::BinaryExpr { + op: BinaryOperator::Add, + .. 
+ } + ) + }); + assert!( + has_add_expr, + "Should have id + 2 expression in inner projection" + ); + } + _ => panic!("Expected inner Projection under Aggregate"), + } + } + _ => panic!("Expected Aggregate under outer Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_group_by_validation_allow_grouped_column() { + let schema = create_test_schema(); + + // Test that grouped columns are allowed + let sql = "SELECT user_id, COUNT(*) FROM orders GROUP BY user_id"; + let result = parse_and_build(sql, &schema); + + assert!(result.is_ok(), "Should allow grouped column in SELECT"); + } + + #[test] + fn test_group_by_validation_allow_constants() { + let schema = create_test_schema(); + + // Test that simple constants are allowed even when not grouped + let sql = "SELECT user_id, 42, COUNT(*) FROM orders GROUP BY user_id"; + let result = parse_and_build(sql, &schema); + + assert!( + result.is_ok(), + "Should allow simple constants in SELECT with GROUP BY" + ); + + let sql_complex = "SELECT user_id, (100 + 50) * 2, COUNT(*) FROM orders GROUP BY user_id"; + let result_complex = parse_and_build(sql_complex, &schema); + + assert!( + result_complex.is_ok(), + "Should allow complex constant expressions in SELECT with GROUP BY" + ); + } + + #[test] + fn test_parenthesized_aggregate_expressions() { + let schema = create_test_schema(); + + let sql = "SELECT 25, (MAX(id) / 3), 39 FROM orders"; + let result = parse_and_build(sql, &schema); + + assert!( + result.is_ok(), + "Should handle parenthesized aggregate expressions" + ); + + let plan = result.unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 3); + + assert!(matches!( + proj.exprs[0], + LogicalExpr::Literal(Value::Integer(25)) + )); + + match &proj.exprs[1] { + LogicalExpr::BinaryExpr { left, op, right } => { + assert_eq!(*op, BinaryOperator::Divide); + assert!(matches!(&**left, LogicalExpr::Column(_))); + assert!(matches!(&**right, LogicalExpr::Literal(Value::Integer(3)))); + } + _ => panic!("Expected BinaryExpr for (MAX(id) / 3)"), + } + + assert!(matches!( + proj.exprs[2], + LogicalExpr::Literal(Value::Integer(39)) + )); + + match &*proj.input { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.aggr_expr.len(), 1); + assert!(matches!( + agg.aggr_expr[0], + LogicalExpr::AggregateFunction { + fun: AggFunc::Max, + .. 
+ } + )); + } + _ => panic!("Expected Aggregate node under Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_duplicate_aggregate_reuse() { + let schema = create_test_schema(); + + let sql = "SELECT (COUNT(*) - 225), 30, COUNT(*) FROM orders"; + let result = parse_and_build(sql, &schema); + + assert!(result.is_ok(), "Should handle duplicate aggregates"); + + let plan = result.unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 3); + + match &proj.exprs[0] { + LogicalExpr::BinaryExpr { left, op, right } => { + assert_eq!(*op, BinaryOperator::Subtract); + match &**left { + LogicalExpr::Column(col) => { + assert!(col.name.starts_with("__agg_") || col.name == "COUNT(*)"); + } + _ => panic!("Expected Column reference for COUNT(*)"), + } + assert!(matches!( + &**right, + LogicalExpr::Literal(Value::Integer(225)) + )); + } + _ => panic!("Expected BinaryExpr for (COUNT(*) - 225)"), + } + + assert!(matches!( + proj.exprs[1], + LogicalExpr::Literal(Value::Integer(30)) + )); + + match &proj.exprs[2] { + LogicalExpr::Column(col) => { + assert!(col.name.starts_with("__agg_") || col.name == "COUNT(*)"); + } + _ => panic!("Expected Column reference for COUNT(*)"), + } + + match &*proj.input { + LogicalPlan::Aggregate(agg) => { + assert_eq!( + agg.aggr_expr.len(), + 1, + "Should have only one COUNT(*) aggregate" + ); + assert!(matches!( + agg.aggr_expr[0], + LogicalExpr::AggregateFunction { + fun: AggFunc::Count, + .. + } + )); + } + _ => panic!("Expected Aggregate node under Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_aggregate_without_group_by_allow_constants() { + let schema = create_test_schema(); + + // Test that constants are allowed with aggregates even without GROUP BY + let sql = "SELECT 42, COUNT(*), MAX(amount) FROM orders"; + let result = parse_and_build(sql, &schema); + + assert!( + result.is_ok(), + "Should allow simple constants with aggregates without GROUP BY" + ); + + // Test complex constant expressions + let sql_complex = "SELECT (9 / 6) % 5, COUNT(*), MAX(amount) FROM orders"; + let result_complex = parse_and_build(sql_complex, &schema); + + assert!( + result_complex.is_ok(), + "Should allow complex constant expressions with aggregates without GROUP BY" + ); + } + + #[test] + fn test_aggregate_without_group_by_creates_aggregate_node() { + let schema = create_test_schema(); + + // Test that aggregate without GROUP BY creates proper Aggregate node + let sql = "SELECT MAX(amount) FROM orders"; + let plan = parse_and_build(sql, &schema).unwrap(); + + // Should be: Aggregate -> TableScan (no projection needed for simple aggregate) + match plan { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.group_expr.len(), 0, "Should have no group expressions"); + assert_eq!( + agg.aggr_expr.len(), + 1, + "Should have one aggregate expression" + ); + assert_eq!( + agg.schema.column_count(), + 1, + "Schema should have one column" + ); + } + _ => panic!("Expected Aggregate at top level (no projection)"), + } + } + + #[test] + fn test_scalar_vs_aggregate_function_classification() { + let schema = create_test_schema(); + + // Test MIN/MAX with 1 argument - should be aggregate + let sql = "SELECT MIN(amount) FROM orders"; + let plan = parse_and_build(sql, &schema).unwrap(); + match plan { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.aggr_expr.len(), 1, "MIN(x) should be an aggregate"); + match &agg.aggr_expr[0] { + LogicalExpr::AggregateFunction 
{ fun, args, .. } => { + assert!(matches!(fun, AggFunc::Min)); + assert_eq!(args.len(), 1); + } + _ => panic!("Expected AggregateFunction"), + } + } + _ => panic!("Expected Aggregate node for MIN(x)"), + } + + // Test MIN/MAX with 2 arguments - should be scalar in projection + let sql = "SELECT MIN(amount, user_id) FROM orders"; + let plan = parse_and_build(sql, &schema).unwrap(); + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 1, "Should have one projection expression"); + match &proj.exprs[0] { + LogicalExpr::ScalarFunction { fun, args } => { + assert_eq!( + fun.to_lowercase(), + "min", + "MIN(x,y) should be a scalar function" + ); + assert_eq!(args.len(), 2); + } + _ => panic!("Expected ScalarFunction for MIN(x,y)"), + } + } + _ => panic!("Expected Projection node for scalar MIN(x,y)"), + } + + // Test MAX with 3 arguments - should be scalar + let sql = "SELECT MAX(amount, user_id, id) FROM orders"; + let plan = parse_and_build(sql, &schema).unwrap(); + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 1); + match &proj.exprs[0] { + LogicalExpr::ScalarFunction { fun, args } => { + assert_eq!( + fun.to_lowercase(), + "max", + "MAX(x,y,z) should be a scalar function" + ); + assert_eq!(args.len(), 3); + } + _ => panic!("Expected ScalarFunction for MAX(x,y,z)"), + } + } + _ => panic!("Expected Projection node for scalar MAX(x,y,z)"), + } + + // Test that MIN with 0 args is treated as scalar (will fail later in execution) + let sql = "SELECT MIN() FROM orders"; + let plan = parse_and_build(sql, &schema).unwrap(); + match plan { + LogicalPlan::Projection(proj) => match &proj.exprs[0] { + LogicalExpr::ScalarFunction { fun, args } => { + assert_eq!(fun.to_lowercase(), "min"); + assert_eq!(args.len(), 0, "MIN() should be scalar with 0 args"); + } + _ => panic!("Expected ScalarFunction for MIN()"), + }, + _ => panic!("Expected Projection for MIN()"), + } + + // Test other functions that are always aggregate (COUNT, SUM, AVG) + let sql = "SELECT COUNT(*), SUM(amount), AVG(amount) FROM orders"; + let plan = parse_and_build(sql, &schema).unwrap(); + match plan { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.aggr_expr.len(), 3, "Should have 3 aggregate functions"); + for expr in &agg.aggr_expr { + assert!(matches!(expr, LogicalExpr::AggregateFunction { .. })); + } + } + _ => panic!("Expected Aggregate node"), + } + + // Test scalar functions that are never aggregates (ABS, ROUND, etc.) + let sql = "SELECT ABS(amount), ROUND(amount), LENGTH(product) FROM orders"; + let plan = parse_and_build(sql, &schema).unwrap(); + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 3, "Should have 3 scalar functions"); + for expr in &proj.exprs { + match expr { + LogicalExpr::ScalarFunction { .. 
} => {} + _ => panic!("Expected all ScalarFunctions"), + } + } + } + _ => panic!("Expected Projection node for scalar functions"), + } + } + + #[test] + fn test_mixed_aggregate_and_group_columns() { + let schema = create_test_schema(); + + // When selecting both aggregate and grouping columns + let sql = "SELECT user_id, sum(id) FROM orders GROUP BY user_id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + // No projection needed - aggregate outputs exactly what we select + match plan { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.group_expr.len(), 1); + assert_eq!(agg.aggr_expr.len(), 1); + assert_eq!(agg.schema.column_count(), 2); + } + _ => panic!("Expected Aggregate (no projection)"), + } + } + + #[test] + fn test_scalar_function_wrapping_aggregate_no_group_by() { + // Test: SELECT HEX(SUM(age + 2)) FROM users + // Expected structure: + // Projection { exprs: [ScalarFunction(HEX, [Column])] } + // -> Aggregate { aggr_expr: [Sum(BinaryExpr(age + 2))], group_expr: [] } + // -> Projection { exprs: [BinaryExpr(age + 2)] } + // -> TableScan("users") + + let schema = create_test_schema(); + let sql = "SELECT HEX(SUM(age + 2)) FROM users"; + let mut parser = Parser::new(sql.as_bytes()); + let stmt = parser.next().unwrap().unwrap(); + + let plan = match stmt { + ast::Cmd::Stmt(stmt) => { + let mut builder = LogicalPlanBuilder::new(&schema); + builder.build_statement(&stmt).unwrap() + } + _ => panic!("Expected SQL statement"), + }; + + match &plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 1, "Should have one expression"); + + match &proj.exprs[0] { + LogicalExpr::ScalarFunction { fun, args } => { + assert_eq!(fun, "HEX", "Outer function should be HEX"); + assert_eq!(args.len(), 1, "HEX should have one argument"); + + match &args[0] { + LogicalExpr::Column(_) => {} + LogicalExpr::AggregateFunction { .. } => { + panic!("Aggregate function should not be embedded in projection! 
It should be in a separate Aggregate operator"); + } + _ => panic!( + "Expected column reference as argument to HEX, got: {:?}", + args[0] + ), + } + } + _ => panic!("Expected ScalarFunction (HEX), got: {:?}", proj.exprs[0]), + } + + match &*proj.input { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.group_expr.len(), 0, "Should have no GROUP BY"); + assert_eq!( + agg.aggr_expr.len(), + 1, + "Should have one aggregate expression" + ); + + match &agg.aggr_expr[0] { + LogicalExpr::AggregateFunction { + fun, + args, + distinct, + } => { + assert_eq!(*fun, crate::function::AggFunc::Sum, "Should be SUM"); + assert!(!distinct, "Should not be DISTINCT"); + assert_eq!(args.len(), 1, "SUM should have one argument"); + + match &args[0] { + LogicalExpr::Column(col) => { + // When aggregate arguments are complex, they get pre-projected + assert!(col.name.starts_with("__agg_arg_proj_"), + "Should reference pre-projected column, got: {}", col.name); + } + LogicalExpr::BinaryExpr { left, op, right } => { + // Simple case without pre-projection (shouldn't happen with current implementation) + assert_eq!(*op, ast::Operator::Add, "Should be addition"); + + match (&**left, &**right) { + (LogicalExpr::Column(col), LogicalExpr::Literal(val)) => { + assert_eq!(col.name, "age", "Should reference age column"); + assert_eq!(*val, Value::Integer(2), "Should add 2"); + } + _ => panic!("Expected age + 2"), + } + } + _ => panic!("Expected Column reference or BinaryExpr for aggregate argument, got: {:?}", args[0]), + } + } + _ => panic!("Expected AggregateFunction"), + } + + match &*agg.input { + LogicalPlan::TableScan(scan) => { + assert_eq!(scan.table_name, "users"); + } + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::TableScan(scan) => { + assert_eq!(scan.table_name, "users"); + } + _ => panic!("Expected TableScan under projection"), + }, + _ => panic!("Expected TableScan or Projection under Aggregate"), + } + } + _ => panic!( + "Expected Aggregate operator under Projection, got: {:?}", + proj.input + ), + } + } + _ => panic!("Expected Projection as top-level operator, got: {plan:?}"), + } + } +} diff --git a/core/translate/mod.rs b/core/translate/mod.rs index 1becad3d4..6d5b727c6 100644 --- a/core/translate/mod.rs +++ b/core/translate/mod.rs @@ -21,6 +21,7 @@ pub(crate) mod group_by; pub(crate) mod index; pub(crate) mod insert; pub(crate) mod integrity_check; +pub(crate) mod logical; pub(crate) mod main_loop; pub(crate) mod optimizer; pub(crate) mod order_by; diff --git a/testing/materialized_views.test b/testing/materialized_views.test index 2b6a56be3..5d226b016 100755 --- a/testing/materialized_views.test +++ b/testing/materialized_views.test @@ -385,3 +385,174 @@ do_execsql_test_on_specific_db {:memory:} matview-projections { SELECT * from v; } {4|3|7|22|3 3|4|7|22|3} + +do_execsql_test_on_specific_db {:memory:} matview-rollback-insert { + CREATE TABLE t(a INTEGER, b INTEGER); + INSERT INTO t VALUES (1, 10), (2, 20), (3, 30); + + CREATE MATERIALIZED VIEW v AS + SELECT * FROM t WHERE b > 15; + + SELECT * FROM v ORDER BY a; + + BEGIN; + INSERT INTO t VALUES (4, 40), (5, 50); + SELECT * FROM v ORDER BY a; + ROLLBACK; + + SELECT * FROM v ORDER BY a; +} {2|20 +3|30 +2|20 +3|30 +4|40 +5|50 +2|20 +3|30} + +do_execsql_test_on_specific_db {:memory:} matview-rollback-delete { + CREATE TABLE t(a INTEGER, b INTEGER); + INSERT INTO t VALUES (1, 10), (2, 20), (3, 30), (4, 40); + + CREATE MATERIALIZED VIEW v AS + SELECT * FROM t WHERE b > 15; + + SELECT * FROM v ORDER BY a; + + BEGIN; + DELETE FROM 
t WHERE a IN (2, 3); + SELECT * FROM v ORDER BY a; + ROLLBACK; + + SELECT * FROM v ORDER BY a; +} {2|20 +3|30 +4|40 +4|40 +2|20 +3|30 +4|40} + +do_execsql_test_on_specific_db {:memory:} matview-rollback-update { + CREATE TABLE t(a INTEGER, b INTEGER); + INSERT INTO t VALUES (1, 10), (2, 20), (3, 30); + + CREATE MATERIALIZED VIEW v AS + SELECT * FROM t WHERE b > 15; + + SELECT * FROM v ORDER BY a; + + BEGIN; + UPDATE t SET b = 5 WHERE a = 2; + UPDATE t SET b = 35 WHERE a = 1; + SELECT * FROM v ORDER BY a; + ROLLBACK; + + SELECT * FROM v ORDER BY a; +} {2|20 +3|30 +1|35 +3|30 +2|20 +3|30} + +do_execsql_test_on_specific_db {:memory:} matview-rollback-aggregation { + CREATE TABLE sales(product_id INTEGER, amount INTEGER); + INSERT INTO sales VALUES (1, 100), (1, 200), (2, 150), (2, 250); + + CREATE MATERIALIZED VIEW product_totals AS + SELECT product_id, SUM(amount) as total, COUNT(*) as cnt + FROM sales + GROUP BY product_id; + + SELECT * FROM product_totals ORDER BY product_id; + + BEGIN; + INSERT INTO sales VALUES (1, 50), (3, 300); + SELECT * FROM product_totals ORDER BY product_id; + ROLLBACK; + + SELECT * FROM product_totals ORDER BY product_id; +} {1|300|2 +2|400|2 +1|350|3 +2|400|2 +3|300|1 +1|300|2 +2|400|2} + +do_execsql_test_on_specific_db {:memory:} matview-rollback-mixed-operations { + CREATE TABLE orders(id INTEGER PRIMARY KEY, customer INTEGER, amount INTEGER); + INSERT INTO orders VALUES (1, 100, 50), (2, 200, 75), (3, 100, 25); + + CREATE MATERIALIZED VIEW customer_totals AS + SELECT customer, SUM(amount) as total, COUNT(*) as cnt + FROM orders + GROUP BY customer; + + SELECT * FROM customer_totals ORDER BY customer; + + BEGIN; + INSERT INTO orders VALUES (4, 100, 100); + UPDATE orders SET amount = 150 WHERE id = 2; + DELETE FROM orders WHERE id = 3; + SELECT * FROM customer_totals ORDER BY customer; + ROLLBACK; + + SELECT * FROM customer_totals ORDER BY customer; +} {100|75|2 +200|75|1 +100|150|2 +200|150|1 +100|75|2 +200|75|1} + +do_execsql_test_on_specific_db {:memory:} matview-rollback-filtered-aggregation { + CREATE TABLE transactions(id INTEGER, account INTEGER, amount INTEGER, type TEXT); + INSERT INTO transactions VALUES + (1, 100, 50, 'deposit'), + (2, 100, 30, 'withdraw'), + (3, 200, 100, 'deposit'), + (4, 200, 40, 'withdraw'); + + CREATE MATERIALIZED VIEW deposits AS + SELECT account, SUM(amount) as total_deposits, COUNT(*) as cnt + FROM transactions + WHERE type = 'deposit' + GROUP BY account; + + SELECT * FROM deposits ORDER BY account; + + BEGIN; + INSERT INTO transactions VALUES (5, 100, 75, 'deposit'); + UPDATE transactions SET amount = 60 WHERE id = 1; + DELETE FROM transactions WHERE id = 3; + SELECT * FROM deposits ORDER BY account; + ROLLBACK; + + SELECT * FROM deposits ORDER BY account; +} {100|50|1 +200|100|1 +100|135|2 +100|50|1 +200|100|1} + +do_execsql_test_on_specific_db {:memory:} matview-rollback-empty-view { + CREATE TABLE t(a INTEGER, b INTEGER); + INSERT INTO t VALUES (1, 5), (2, 8); + + CREATE MATERIALIZED VIEW v AS + SELECT * FROM t WHERE b > 10; + + SELECT COUNT(*) FROM v; + + BEGIN; + INSERT INTO t VALUES (3, 15), (4, 20); + SELECT * FROM v ORDER BY a; + ROLLBACK; + + SELECT COUNT(*) FROM v; +} {0 +3|15 +4|20 +0}
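
A further sketch, not part of the patch above, that follows from the cases it already covers: it combines the rollback scenarios with the computed-column projections exercised earlier in this file (matview-projections). The test name matview-rollback-projection, the b * 2 projection, and the expected rows are assumptions derived from the same filter/projection semantics, not output from an actual run:

do_execsql_test_on_specific_db {:memory:} matview-rollback-projection {
    CREATE TABLE t(a INTEGER, b INTEGER);
    INSERT INTO t VALUES (1, 10), (2, 20);

    CREATE MATERIALIZED VIEW v AS
    SELECT a, b * 2 FROM t WHERE b > 15;

    SELECT * FROM v ORDER BY a;

    BEGIN;
    INSERT INTO t VALUES (3, 30);
    SELECT * FROM v ORDER BY a;
    ROLLBACK;

    SELECT * FROM v ORDER BY a;
} {2|40
2|40
3|60
2|40}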
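
A second sketch under the same assumptions (the patch above only exercises BEGIN/ROLLBACK, so the COMMIT interaction here is hypothetical): it checks that ROLLBACK discards only the open transaction's changes, so rows committed by an earlier transaction remain visible in the view after a later rollback:

do_execsql_test_on_specific_db {:memory:} matview-rollback-after-commit {
    CREATE TABLE t(a INTEGER, b INTEGER);
    INSERT INTO t VALUES (1, 20);

    CREATE MATERIALIZED VIEW v AS
    SELECT * FROM t WHERE b > 15;

    BEGIN;
    INSERT INTO t VALUES (2, 30);
    COMMIT;

    SELECT * FROM v ORDER BY a;

    BEGIN;
    INSERT INTO t VALUES (3, 40);
    ROLLBACK;

    SELECT * FROM v ORDER BY a;
} {1|20
2|30
1|20
2|30}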